Train an image classifier using F# and ML.NET
Introduction
This post is part of F# Advent 2020. Thank you to Sergey Tihon for organizing the event, and to the rest of the contributors for producing high-quality content. Make sure to check out the rest of the F# Advent 2020 content.
I often wonder how Santa decides who's on the naughty or nice list. I took a shot at answering this question by training an image classifier with the ML.NET image classification API, using images of Disney heroes and villains to tell whether they're naughty or nice. You shouldn't judge someone by the way they look (even if they are the Evil Queen), so don't try this at home or with your neighbors 😉. This sample is just for demo purposes. You can find the full code on GitHub.
Prerequisites
This sample was built on a Windows 10 PC, but it should also work on macOS and Linux.
The data
The dataset contains images of Disney characters, both live-action and animated, obtained from the Disney Fandom Wiki. The characters are split into two categories: villains and heroes. For the purpose of this sample, heroes are labeled nice and villains are labeled naughty. The dataset used to train this model contains 2,400 villain (naughty) and 675 hero (nice) images, stored in subdirectories named after their label (naughty and nice). This means the dataset is unbalanced, with roughly 3.6 naughty images for every nice one. That imbalance can skew the model toward predicting naughty, as the predictions at the end of this post suggest.
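Before training, you can measure how unbalanced a dataset laid out this way is by counting the images in each label directory. The following is a minimal sketch that assumes the layout described above (one subdirectory per label under a training root). countImagesPerLabel, imbalanceRatio, and majorityBaseline are hypothetical helpers, not part of the original sample.

```fsharp
open System
open System.IO

/// Hypothetical helper: counts the .jpg/.png files directly inside each
/// subdirectory of `root`, using the subdirectory name as the label
let countImagesPerLabel (root: string) =
    Directory.GetDirectories(root)
    |> Array.map (fun dir ->
        let count =
            Directory.GetFiles(dir)
            |> Array.filter (fun file ->
                let extension = Path.GetExtension(file).ToLowerInvariant()
                extension = ".jpg" || extension = ".png")
            |> Array.length
        Path.GetFileName(dir), count)
    |> Array.sortBy fst

/// Ratio between the largest and the smallest class (1.0 = perfectly balanced)
let imbalanceRatio (counts: (string * int)[]) =
    let sizes = counts |> Array.map snd
    float (Array.max sizes) / float (Array.min sizes)

/// Accuracy of a "model" that always predicts the largest class
let majorityBaseline (counts: (string * int)[]) =
    let sizes = counts |> Array.map snd
    float (Array.max sizes) / float (Array.sum sizes)
```

With this post's counts (2,400 naughty and 675 nice), imbalanceRatio gives about 3.56 and majorityBaseline about 0.78. In other words, a model that always answers "naughty" would already be right about 78% of the time on the training data, which is worth keeping in mind when you read the training metrics.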
Install NuGet packages
Use the #r convention to install the NuGet packages used in this sample.
```fsharp
#r "nuget:Microsoft.ML"
#r "nuget:Microsoft.ML.Vision"
#r "nuget:Microsoft.ML.ImageAnalytics"
// Native TensorFlow binaries used by the image classification trainer
#r "nuget:SciSharp.TensorFlow.Redist"
```
Then, open the namespaces used in this sample.
```fsharp
open System
open System.IO
open Microsoft.ML
open Microsoft.ML.Data
open Microsoft.ML.Vision
```
Define data types
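Start by defining the data types that describe your input and output schema. You can do this by creating two records, ImageData and ImagePrediction. ImageData is the input: it contains the path to an image file and the category the image belongs to. ImagePrediction contains the prediction generated by the model. Both records are marked [<CLIMutable>], which gives them a parameterless constructor and property setters so that ML.NET can create and populate instances of them.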
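```fsharp
// Input schema: the path to an image file and its category (naughty or nice)
[<CLIMutable>]
type ImageData = {
    ImagePath: string
    Label: string
}

// Output schema: ImagePath is passed through from the input so each prediction
// can be matched to its image; PredictedLabel is the model's output
[<CLIMutable>]
type ImagePrediction = {
    ImagePath: string
    PredictedLabel: string
}
```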
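ML.NET creates and fills these records through reflection, which is why the [<CLIMutable>] attribute matters. The sketch below mimics that access pattern with plain .NET reflection to show what the attribute provides. It only illustrates the mechanism and is not ML.NET's internal code. ImageData is copied from above so the snippet runs on its own, and setColumn and the example path are hypothetical.

```fsharp
open System

[<CLIMutable>]
type ImageData = {
    ImagePath: string
    Label: string
}

// CLIMutable adds a parameterless constructor, so an "empty" row can be created...
let row = Activator.CreateInstance<ImageData>()

// ...and property setters, so each column can then be written by name
let setColumn (name: string) (value: obj) (target: obj) =
    target.GetType().GetProperty(name).SetValue(target, value)

setColumn "ImagePath" (box "C:/Datasets/fsadvent2020/Train/nice/example.jpg") (box row)
setColumn "Label" (box "nice") (box row)

printfn "%s -> %s" row.ImagePath row.Label
```

Without the attribute, the record would have only a constructor that takes every field and read-only properties, so this create-then-set pattern would fail.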
Training
The training process loads a set of training images, preprocesses them, and uses the ML.NET image classification API to train an image classification model.
Initialize MLContext
Once you've defined the data types, initialize the MLContext. MLContext is the entry point for ML.NET applications. It gives you access to data loading, transforms, trainers, and model saving.
```fsharp
let ctx = new MLContext()
```
Load training data
Then, load the data with the loadImagesFromDirectory helper function, pointing it at the top-level directory that contains the naughty and nice subdirectories. Because F# scripts are evaluated from top to bottom, define the helper (shown below) before you call it.
```fsharp
// true: label each image with the name of its parent directory (naughty or nice)
let imageData = loadImagesFromDirectory "C:/Datasets/fsadvent2020/Train" true
```
The loadImagesFromDirectory function looks like the following:
```fsharp
let loadImagesFromDirectory (path: string) (useDirectoryAsLabel: bool) =
    Directory.GetFiles(path, "*", SearchOption.AllDirectories)
    // Keep only .jpg and .png images (case-insensitive)
    |> Array.filter (fun file ->
        let extension = Path.GetExtension(file).ToLowerInvariant()
        extension = ".jpg" || extension = ".png")
    |> Array.map (fun file ->
        let label =
            if useDirectoryAsLabel then
                // Use the parent directory's name (e.g. naughty or nice) as the label
                Directory.GetParent(file).Name
            else
                // Use the leading letters of the file name (e.g. "villain01.jpg" -> "villain")
                let fileName = Path.GetFileName(file)
                String(fileName |> Seq.takeWhile Char.IsLetter |> Seq.toArray)
        { ImageData.ImagePath = file; Label = label })
```
Then, create an IDataView for the training images. IDataView is the interface ML.NET uses to represent data.
```fsharp
let imageIdv = ctx.Data.LoadFromEnumerable<ImageData>(imageData)
```
Define training pipeline
Once your data is loaded into an IDataView, set the classifier options using ImageClassificationTrainer.Options. Use them to specify the feature and label columns, the network architecture, and a few additional parameters. The architecture used in this case is ResNet v2 101 (ResnetV2101).
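```fsharp
let classifierOptions = ImageClassificationTrainer.Options()
// Column holding the raw image bytes produced by LoadRawImageBytes
classifierOptions.FeatureColumnName <- "Image"
// Key-typed label column produced by MapValueToKey
classifierOptions.LabelColumnName <- "LabelKey"
// Also evaluate the model on the training set during training
classifierOptions.TestOnTrainSet <- true
// Pretrained network used for transfer learning
classifierOptions.Arch <- ImageClassificationTrainer.Architecture.ResnetV2101
// Print progress and metrics as the trainer reports them
classifierOptions.MetricsCallback <-
    Action<ImageClassificationTrainer.ImageClassificationMetrics>(fun metrics -> printfn "%O" metrics)
```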
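Then, define the pipeline: the preprocessing steps, the image classification trainer (configured with the options above), and the postprocessing step. LoadRawImageBytes reads each file listed in the ImagePath column into an Image column of raw bytes. MapValueToKey converts the string Label into the key-typed LabelKey column that the trainer expects. MapKeyToValue converts the trainer's key-typed PredictedLabel back into a readable string.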
```fsharp
let pipeline =
    EstimatorChain()
        .Append(ctx.Transforms.LoadRawImageBytes("Image", null, "ImagePath"))
        .Append(ctx.Transforms.Conversion.MapValueToKey("LabelKey", "Label"))
        .Append(ctx.MulticlassClassification.Trainers.ImageClassification(classifierOptions))
        .Append(ctx.Transforms.Conversion.MapKeyToValue("PredictedLabel"))
```
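The two key-mapping steps around the trainer exist because the trainer works with key-typed labels (numeric category IDs) rather than strings. The following plain-F# sketch only illustrates the idea and is not ML.NET's implementation. It assumes ML.NET's default settings, which, as far as I know, number distinct values in order of first occurrence starting at 1 and reserve 0 for missing or unseen values. buildKeyMap, toKey, and toValue are hypothetical names.

```fsharp
open System.Collections.Generic

/// Numbers distinct values by order of first occurrence, starting at 1
/// (0 is left free to mean "missing")
let buildKeyMap (values: seq<string>) : IDictionary<string, uint32> =
    values
    |> Seq.distinct
    |> Seq.mapi (fun i value -> value, uint32 (i + 1))
    |> dict

/// Forward mapping, like MapValueToKey: unseen values become the missing key 0
let toKey (keyMap: IDictionary<string, uint32>) (value: string) =
    match keyMap.TryGetValue value with
    | true, key -> key
    | _ -> 0u

/// Reverse mapping, like MapKeyToValue: turns a predicted key back into its label
let toValue (keyMap: IDictionary<string, uint32>) (key: uint32) =
    keyMap |> Seq.tryPick (fun (KeyValue (value, k)) -> if k = key then Some value else None)

let keyMap = buildKeyMap [ "Naughty"; "Nice"; "Naughty"; "Naughty" ]
```

In this sketch, "Naughty" becomes key 1 and "Nice" key 2, and an empty label maps to the missing key 0.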
The ML.NET image classification API uses a technique known as transfer learning. Transfer learning takes a pretrained model (usually a neural network) and retrains only its last few layers on new data. This significantly cuts down the time, resources, and data needed to train a deep learning model. In ML.NET's case, the pretrained network first computes bottleneck values for each image (the Bottleneck Computation phase in the training output), and a new final classification layer is then trained on those values. ML.NET does this with the help of TensorFlow.NET, a set of .NET bindings for the TensorFlow deep learning framework.

Although transfer learning makes training deep learning models less resource intensive, the TensorFlow API is low level and still requires a significant amount of code. See this transfer learning example from TensorFlow.NET for how you'd do it directly in TensorFlow.NET. The low-level API gives you control over what you're building, but often you don't need that much control. Through its image classification trainer, ML.NET greatly simplifies the process by providing a high-level API for the same task.
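To make "retrain only the last layer" concrete, here is a toy sketch in plain F#. The pretrained network is treated as a frozen feature extractor, and only a new logistic-regression output layer is trained on its outputs with batch gradient descent. The feature vectors are made-up stand-ins for bottleneck values, and the learning rate and epoch count are arbitrary. ML.NET's trainer does the real equivalent of this through TensorFlow.NET, so this is an illustration, not its implementation.

```fsharp
// Made-up "bottleneck" features: in real transfer learning, these would be
// computed once per image by the frozen pretrained network
let features =
    [| [| 0.9; 0.1 |]; [| 0.8; 0.3 |]; [| 0.2; 0.7 |]; [| 0.1; 0.9 |] |]
// 1.0 = naughty, 0.0 = nice
let labels = [| 1.0; 1.0; 0.0; 0.0 |]

let sigmoid z = 1.0 / (1.0 + exp (-z))

/// Probability of "naughty" given the new layer's weights and bias
let predict (weights: float[]) (bias: float) (x: float[]) =
    sigmoid (Array.fold2 (fun acc w xi -> acc + w * xi) bias weights x)

/// Trains only the new output layer with batch gradient descent on log loss.
/// The feature extractor is never updated, which is what keeps this cheap.
let trainLastLayer (learningRate: float) (epochs: int) =
    let dims = features.[0].Length
    let mutable weights = Array.zeroCreate<float> dims
    let mutable bias = 0.0
    for _ in 1 .. epochs do
        let gradW = Array.zeroCreate<float> dims
        let mutable gradB = 0.0
        for i in 0 .. features.Length - 1 do
            // Gradient of log loss w.r.t. the layer's pre-activation: p - y
            let error = (predict weights bias features.[i]) - labels.[i]
            for j in 0 .. dims - 1 do
                gradW.[j] <- gradW.[j] + error * features.[i].[j]
            gradB <- gradB + error
        let n = float features.Length
        weights <- Array.map2 (fun w g -> w - learningRate * g / n) weights gradW
        bias <- bias - learningRate * gradB / n
    weights, bias

let weights, bias = trainLastLayer 1.0 500
```

To classify a new image, you would run it through the frozen network once and pass the resulting features to predict.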
To train the model, call the pipeline's Fit method on the training IDataView.
```fsharp
let model = pipeline.Fit(imageIdv)
```
Throughout the training process, you should see output similar to the following:
```text
Phase: Bottleneck Computation, Dataset used: Train, Image Index: 279
Phase: Bottleneck Computation, Dataset used: Train, Image Index: 280
Phase: Bottleneck Computation, Dataset used: Validation, Image Index: 1
```
With the model trained, it's time to use it to make predictions. Optionally, you can save it for use in other applications.
```fsharp
// Save the trained model together with the schema of its input data
ctx.Model.Save(model, imageIdv.Schema, "fsadvent2020-model.zip")
```
Make predictions
Load the test images and create an IDataView for them. The test images are of Jack Skellington and The Grinch.
```fsharp
let testImages =
    Directory.GetFiles("C:/Datasets/fsadvent2020/Test")
    // The true label is unknown at prediction time, so leave it empty
    |> Array.map (fun file -> { ImageData.ImagePath = file; Label = "" })

let testImageIdv = ctx.Data.LoadFromEnumerable<ImageData>(testImages)
```
Then, use the model to make predictions.
```fsharp
let predictionIdv = model.Transform(testImageIdv)
```
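One of the easiest ways to access the predictions is to convert the IDataView into an IEnumerable with the CreateEnumerable method. Passing false for its reuseRowObject argument creates a new ImagePrediction object for each row.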
```fsharp
let predictions = ctx.Data.CreateEnumerable<ImagePrediction>(predictionIdv, reuseRowObject = false)
```
Then, use the built-in F# sequence functions to display the predictions.
```fsharp
predictions
|> Seq.iter (fun pred ->
    printfn "%s is %s" (Path.GetFileNameWithoutExtension(pred.ImagePath)) pred.PredictedLabel)
```
The output should look like the following:
```text
grinch is Naughty
jack is Naughty
```
Conclusion
In this post, I showed how you can use ML.NET and TensorFlow.NET to train an image classification model that classifies Disney characters as naughty or nice. Depending on the level of control you need, you might choose to use TensorFlow.NET directly, or, if you want a high-level API for training an image classifier, ML.NET. Most importantly, we figured out that Jack Skellington and The Grinch are naughty, so I guess no gifts for them this year? Happy coding!
Call to Action
Originally, I had planned to write this article using TensorFlow.Keras, which is part of the SciSharp TensorFlow.NET project. TensorFlow.Keras provides .NET bindings for Keras, a high-level API for TensorFlow that makes building custom neural networks much simpler than working with the TensorFlow API directly. Unfortunately, while trying to adapt my scenario to an existing sample, I ran into an issue that I couldn't resolve in time to publish this post, so I fell back to ML.NET.
I'm a big fan of the work being done by the SciSharp community and the machine learning and data science capabilities it brings to the .NET ecosystem. The work is entirely community driven, so there are plenty of opportunities to contribute. Here are just some examples of ways to contribute, especially from an F# perspective. For my part, I plan to eventually convert this sample to use TensorFlow.Keras. See you in the SciSharp repos! 🙂