Using a neural network to classify images is an attractive proposition, but getting set up to do it may seem like a challenge. In many cases, especially in research, it can be difficult and expensive to get enough labelled data to train a neural network from scratch. One solution is to use a network that has already been trained on a larger data set as the basis for your neural network, an approach known as ‘transfer learning’. The idea is to take advantage of the information embedded in the pre-trained network so that you need far less data to train your own network.
Once you’ve trained a neural network to classify your images, you’ll probably want to use it to analyze lots of data—if not, you might be better off just having an expert do the analysis. You can run distributed predictions with Dask and easily shift the workload to a remote cluster using Coiled Cloud. With a personal account on Coiled Cloud, you get 1000 Coiled Credits for free each month.
In this post, we'll cover how to:
- use transfer learning to retrain a pre-trained neural network (ResNet18) to classify microscopy images, and
- run distributed batch predictions with Dask, either locally or on a remote cluster with Coiled Cloud.
This post draws heavily on the Batch Prediction with PyTorch tutorial from Dask, which in turn builds on the Transfer Learning for Computer Vision tutorial from PyTorch. A notebook accompanying this post is available here: Analyzing Microscopy Images with PyTorch and Dask. You can repeat the analysis or experiment with modifying it.
We’ll be analyzing microscopy images to determine whether they show normal cells or leukemic blast cells. The images are from 118 patients (almost 10 GB in the full dataset) who were either healthy or had acute lymphoblastic leukemia (ALL), which is the most common childhood cancer and the leading cause of cancer-related deaths among children. ALL can be treated with chemotherapy, but identifying ALL cells under the microscope can be challenging. A reliable, high-throughput system for scoring microscopy images could be used to distinguish cases that are almost certainly positive or negative from those that need further analysis, such as immunological assays or expert review.
For this example analysis, we’re going to download a subset of the data. Though this post explains how to run the image analysis on Coiled Cloud, we’re going to run it locally on a small data set rather than remotely on a larger data set from cloud storage. We’ll be using a randomly selected set of images from this acute lymphoblastic leukemia dataset which we download into leukemia_data (download our subset of data here). The images are divided into training and validation sets (train and val folders) which will be used to retrain the model. Note that there are only 120 training images and 75 validation images—far fewer than would be needed to train a neural network from scratch!
If you’d like to try this with your own data instead, create the same folder structure with your images and change data_dir in helper_functions to point to that folder. You can train the model locally with about the same number of images and then use it to classify images from a larger dataset remotely.
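For reference, the expected layout looks like the sketch below. The ALL and Nor class folder names match the labels used later in this post; substitute your own category names if your data is different:

```text
leukemia_data/
├── train/
│   ├── ALL/    # leukemic blast cell images
│   └── Nor/    # normal cell images
└── val/
    ├── ALL/
    └── Nor/
```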
There are several approaches to transfer learning. The one we’ll be using here is to retrain the final layer of a trained neural network so it classifies images into the categories we’re interested in. We’ll be using ResNet18, a neural network trained on the ImageNet dataset.
The idea is to take advantage of the fact that ResNet18 can extract useful features from images—things like edges and shapes—but retrain its final layer to classify images into the two categories we’re interested in (leukemia vs. healthy cells) rather than the 1,000 categories it was originally trained to recognize. (If you want to read more about transfer learning or learn about the other approaches, have a look at these course notes.)
To carry out the transfer learning, we’re going to use a few functions from the PyTorch tutorial that we’ve put into helper_functions:
```python
from helper_functions import (imshow, train_model, visualize_model,
                              dataloaders, class_names, finetune_model)
```
The transfer learning is carried out by finetune_model. The details of how transfer learning works aren’t the focus of this post, but we’ll go over it briefly. For a more thorough explanation, have a look at the PyTorch tutorial (in particular this section).
The first thing the function does is load the pre-trained ResNet18 model from PyTorch:
```python
model_ft = models.resnet18(pretrained=True)
```
Next, it freezes all of the parameters in the network so they won’t be adjusted when we retrain it:
```python
for param in model_ft.parameters():
    param.requires_grad = False
```
The function then resets the final layer to have the same number of input features but only two output classes:
```python
num_ftrs = model_ft.fc.in_features
model_ft.fc = nn.Linear(num_ftrs, 2)
```
The remainder of the function trains the model (using the train_model helper function). Since every layer except the (new) last layer has been frozen, the last layer is the only one being trained. In other words, we’re taking advantage of all of the feature extraction embedded in the earlier layers of ResNet18 but retraining it to recognize the categories we care about.
```python
model = finetune_model()
```
(Note that we’re only training the model for four epochs in this example to save time. If you want to get better results, try increasing num_epochs in the last line of finetune_model.)
Let’s have a look at the predictions for some images the model hasn’t seen before. We’re looking at images from the validation set, so we know the correct labels for these images, but the model wasn’t trained on them. This gives us a quick visual overview of how the model performs on new data. Of course, the idea is to eventually use the model to analyze unlabeled images.
At this point, you’ve got a model trained to classify your data. If you were only going to analyze a handful of images, you could stop here and use the neural network to classify them. Instead, we’re going to show you how to run a distributed analysis on a remote cluster!
The Batch Prediction with PyTorch tutorial uses dask.delayed. With dask.delayed, functions aren’t executed immediately but lazily. The execution is deferred, and the dask scheduler builds a task graph connecting all the functions in a computation. This enables the scheduler to identify opportunities for parallel execution and use them to speed up execution.
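As a tiny illustration of that laziness (the functions here are toy stand-ins, not part of this analysis): nothing runs when the delayed functions are called, only when we ask for the result.

```python
import dask

@dask.delayed
def load_image(path):
    # Toy stand-in for loading an image from disk
    return f"img:{path}"

@dask.delayed
def score(img):
    # Toy stand-in for scoring an image
    return len(img)

# Nothing has executed yet; dask has only built a task graph
tasks = [score(load_image(p)) for p in ["a.bmp", "bb.bmp"]]

# Now the scheduler runs the graph, executing independent tasks in parallel
results = dask.compute(*tasks)
```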
We’re going to take a different approach here. Instead of building a task graph and then executing it, we'll use Dask’s Futures interface to run the analysis asynchronously but not lazily: the work starts in the background as soon as it’s submitted, and we collect the results as they become ready. We still get the benefits of Dask and Coiled—easy local or remote distributed computing—but the local Python session never blocks.
Both approaches are effective. The choice between them depends to some extent on your workflow. Futures offers a bit more flexibility so you can adjust your computations and explore different avenues in real-time, which might be preferable for this kind of analysis. If you’d like to use delayed instead, compare the code below with the tutorial linked above.
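If you've used Python's built-in concurrent.futures module, the semantics will feel familiar, since Dask's interface deliberately mirrors it. Here's a stdlib-only sketch with a toy stand-in for the real work:

```python
from concurrent.futures import ThreadPoolExecutor, as_completed

def classify(path):
    # Toy stand-in: imagine loading and scoring an image here
    return path, len(path) % 2

with ThreadPoolExecutor(max_workers=4) as pool:
    # submit returns immediately with a Future; work runs in the background
    futures = [pool.submit(classify, p) for p in ["a.bmp", "bb.bmp", "ccc.bmp"]]
    # Collect results as they finish, in completion order
    results = dict(f.result() for f in as_completed(futures))
```

Dask's Client.submit follows the same submit-then-collect pattern, but can run the work on a distributed cluster rather than local threads.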
Since we have the data locally, we’re going to use Dask locally instead of creating a cluster on Coiled Cloud:
```python
from distributed import Client

# Adjust the parameters to match your system
client = Client(n_workers=4, threads_per_worker=2)
```
If you have data that’s available remotely, it’s quite easy to use a cluster on Coiled Cloud instead:
```python
import coiled

# Start a cluster on Coiled Cloud
cluster = coiled.Cluster(software="sedeer/pytorch-example")

# Connect the client
from distributed import Client
client = Client(cluster)
```
Initiating the cluster takes a few minutes while resources are provisioned. Note that we’re creating a cluster using the sedeer/pytorch-example software environment, which was built using the same environment file provided in Setup. For more information about software environments and the importance of using the same environment locally and on Coiled Cloud, have a look at the Coiled documentation.
Aside from a small difference in how the data is loaded, the rest of the code—the data processing and predictions—is almost exactly the same whether you’re using Dask locally or on Coiled Cloud. We’ll write a couple of functions to load the images and transform them into tensors:
```python
from torchvision import transforms
from PIL import Image

def load(path, fs=__builtins__):
    # fs defaults to the builtins module, so fs.open is the ordinary open();
    # pass a filesystem object (e.g. s3fs) to read from cloud storage instead
    with fs.open(path, 'rb') as f:
        img = Image.open(f).convert("RGB")
        return img

def transform(img):
    # This is the same as the 'val' transformation in helper_functions
    trn = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])
    return trn(img)
```
We’re going to use client.submit to submit these functions to the Dask scheduler to run on the cluster. This returns a Future object pointing to the result, which may be completed or pending.
We can use Future objects as input to client.submit without worrying about their status, as we do in the second line of the for loop below. Unless we explicitly request a result, everything will run asynchronously with the scheduler tracking dependencies and passing results forward.
```python
import glob

tensors = []
for f in glob.glob("leukemia_data/val/*/*.bmp"):
    img = client.submit(load, f)
    tensor = client.submit(transform, img)
    tensors.append(tensor)
```
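The chaining in that loop works because Dask resolves any Future passed as an argument before running the dependent task. Here's a minimal self-contained illustration of the same pattern, using toy functions and a lightweight in-process cluster (this assumes dask.distributed is installed):

```python
from dask.distributed import Client

def inc(x):
    return x + 1

def double(x):
    return 2 * x

# processes=False starts a small in-process cluster, purely for illustration
client = Client(processes=False)

a = client.submit(inc, 10)     # returns a Future immediately
b = client.submit(double, a)   # a Future can be passed straight to submit
val = b.result()               # blocks only when we explicitly ask for the result
client.close()
```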
Note that we loaded the data from the local filesystem (which the local Dask workers have access to). When you’re running this on Coiled Cloud, the data needs to be somewhere where the remote workers can access it, such as an Amazon S3 bucket. In that case, we use the same load function but point it towards our cloud storage with something like:
```python
import s3fs

fs = s3fs.S3FileSystem(anon=False, key='<Access Key>',
                       secret='<Secret Key>')
```
and then our for loop would be:
```python
tensors = []
for f in fs.glob("s3://coiled-data/leukemia_data/val/*/*.bmp"):
    img = client.submit(load, f, fs=fs)
    tensor = client.submit(transform, img)
    tensors.append(tensor)
```
Either way, the next step is to collect the transformed tensors into batches because the model expects batched input. We’ll loop over sequences of up to 10 tensors and collect them into a batch with torch.stack:
```python
import toolz

batches = []
for b in toolz.partition_all(10, tensors):
    batch = client.submit(torch.stack, b)
    batches.append(batch)
```
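If you're curious what toolz.partition_all does, it simply chunks a sequence into tuples of up to n items; a stdlib-only sketch of the same behavior looks like this:

```python
from itertools import islice

def partition_all(n, seq):
    # Yield consecutive tuples of up to n items, like toolz.partition_all
    it = iter(seq)
    while True:
        chunk = tuple(islice(it, n))
        if not chunk:
            return
        yield chunk
```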
Our data is now ready for analysis. Strictly speaking, it may not all be ready yet, but we don’t have to worry about that: we’ve got a bunch of futures pointing to the data, which could be pending or finished, and we can go ahead and use them as though the data were ready.
Before we continue, we’ll delete the futures stored in tensors. This isn’t crucial, but if we don’t, Dask will hold on to the futures in case we ever explicitly request their results. We don’t want to use up memory holding intermediate results, especially when we’re dealing with large data sets. Deleting them lets the scheduler know we don’t need them, so the memory can be cleared as soon as all of the computations depending on them are done.
```python
del tensors
```
The next thing to do is prepare the neural network for use with Dask. To do that, we send it to the cluster with client.scatter, which serializes the model and returns a future pointing to it. This way, the scheduler can pass the model future to the computations instead of shipping a copy of the model once per batch.
```python
modelf = client.scatter(model)
```
We’ll use the following function to make the predictions. It returns the most likely category for each image, which is handy when we’re looking at the validation data. Later, we’ll use a similar function to return the probability for each category, which would probably be more useful in a real analysis.
```python
def predict(batch, model):
    with torch.no_grad():  # no gradient tracking needed for inference
        out = model(batch)
    _, predicted = torch.max(out, 1)
    return predicted.numpy()
```
And with that, we have everything we need to classify the images:
```python
predictions = [client.submit(predict, batch, modelf) for batch in batches]
```
Again, these are just futures. In this case, it’s fine to get the output immediately because we’re just looking at the (small) validation set. If you were running this on a larger (remote) data set, you might prefer to write the results to cloud storage instead of displaying them.
To get the predictions, we use the .result method, which blocks until the future is finished and returns the result. We can do this using a list comprehension or with client.gather, which can be faster:
```python
# results = [prediction.result() for prediction in predictions]
results = client.gather(predictions)
```
We can easily convert the predictions into a more readable format:
```python
for rslts in results:
    print([class_names[result] for result in rslts])
```
That output is a helpful overview of how the model performed on the validation data. We ran predictions on the entire validation set, so a perfect result would label every ALL image as ALL, followed by every Nor image as Nor.
In a real analysis, you might be more interested in getting the probability score for each category for an image. You might use a certain cut-off to identify samples which need further analysis, such as expert review, immunological assays, or maybe even additional machine-learning approaches. It only takes a small change to the predict function to output probabilities; we just switch to using a softmax instead of retrieving the class with the max score:
```python
def predict_probabilities(batch, model):
    with torch.no_grad():  # avoid tracking gradients so the output can be converted to NumPy
        out = model(batch)
    predicted = torch.nn.functional.softmax(out, dim=1)
    return predicted.numpy()

probabilities = [client.submit(predict_probabilities, batch, modelf) for batch in batches]
```
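For intuition, softmax simply rescales the raw scores (logits) for each image into probabilities that sum to one. A stdlib-only sketch for a single two-class score pair (the example logits are made up):

```python
import math

def softmax(scores):
    # Numerically stable softmax over a list of raw scores
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

# Hypothetical logits for one image: (ALL, Nor)
probs = softmax([2.0, 0.5])
```

The larger logit always gets the larger probability, so predict and predict_probabilities agree on the most likely class; the probabilities just carry extra information about how confident the model is.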
At this point, you’d probably want to feed the probabilities into the next step in your workflow/pipeline. Since this is just an example, let’s have a look at them instead:
```python
result_probs = client.gather(probabilities)
```
And finally, we shut down the client with client.close() when we’re done.
And with that, you’re ready to train an image classifier and run it distributed across a local or remote cluster!
Dig further into the leukemia data set—we only analyzed a small fraction of it here—or use these tools to analyze your own data. Get hassle-free cloud deployments of Dask clusters for your analyses by clicking below to sign up for a free Coiled Cloud account. Get in touch and let us know how your analyses go—we’d love to help with any challenges and hear about your successes!