Train an image classifier using F# and ML.NET


This post is part of F# Advent 2020. Thank you to Sergey Tihon for organizing it, and to the rest of the contributors for producing high-quality content. Make sure to check out the rest of the F# Advent 2020 content.

When picking who's on the naughty or nice list, I often wonder how Santa decides. I took a shot at answering this question by training an image classifier using the ML.NET image classification API and images of Disney heroes and villains to tell whether they're naughty or nice. You shouldn't judge someone by the way they look (even if they are the Evil Queen), so it's safe to say, don't try this at home or with your neighbors 😉. This sample is just for demo purposes. You can find the full code on GitHub.


This sample was built on a Windows 10 PC, but it should also work on macOS and Linux.

The data

The dataset contains images of Disney characters, both live-action and animated, obtained from the Disney Fandom Wiki. The characters are split into two categories, villains and heroes. For the purpose of this sample, we'll label heroes as nice and villains as naughty. The dataset used to train this model contains 2400 villain (naughty) and 675 hero (nice) images, stored in top-level directories named naughty and nice. Because the dataset is imbalanced, predictions may be skewed toward the majority (naughty) category, as you'll see when making predictions.
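To put a number on the imbalance: naughty images make up 2400 of the 3075 images, so a model that always answered "naughty" would already be right about 78% of the time on data with this mix. The short F# sketch below computes the class proportions and that majority-class baseline from the counts above; `classProportions` is an illustrative helper, not part of the sample.

```fsharp
// Image counts per category, from the dataset description above
let counts = [ "naughty", 2400; "nice", 675 ]

// Illustrative helper: the fraction of the dataset each label accounts for
let classProportions (counts: (string * int) list) =
    let total = counts |> List.sumBy snd |> float
    counts |> List.map (fun (label, n) -> label, float n / total)

let proportions = classProportions counts

// Accuracy of a trivial baseline that always predicts the most common label
let majorityBaseline = proportions |> List.map snd |> List.max

for (label, p) in proportions do
    printfn "%-8s %5.1f%%" label (p * 100.0)
printfn "Majority-class baseline accuracy: %.1f%%" (majorityBaseline * 100.0)
```

Any model trained on this data should be judged against that baseline rather than against 50%.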

Install NuGet packages

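In an F# script (run with dotnet fsi) or a .NET Interactive notebook, use the #r "nuget:" directive to install the NuGet packages used in this sample. SciSharp.TensorFlow.Redist supplies the native TensorFlow binaries that Microsoft.ML.Vision needs at runtime.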

#r "nuget:Microsoft.ML"
#r "nuget:Microsoft.ML.Vision"
#r "nuget:Microsoft.ML.ImageAnalytics"
#r "nuget:SciSharp.TensorFlow.Redist" 

Then, open the namespaces used in this sample.

open System
open System.IO
open Microsoft.ML
open Microsoft.ML.Data
open Microsoft.ML.Vision

Define data types

Start by defining the data types for your input and output schema. Create two records, ImageData and ImagePrediction. ImageData is the input: it contains the path to an image file and the category it belongs to. ImagePrediction contains the prediction generated by the model, along with the path of the image it was made for.

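// CLIMutable adds a parameterless constructor and property setters,
// which ML.NET needs when it creates records such as ImagePrediction
[<CLIMutable>]
type ImageData = {
    ImagePath: string
    Label: string
}

[<CLIMutable>]
type ImagePrediction = {
    ImagePath: string
    PredictedLabel: string
}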


The training process loads a set of training images, preprocesses them, and uses the ML.NET image classification API to train an image classification model.

Initialize MLContext

Once you've defined the data types, initialize the MLContext. MLContext is the entry point for ML.NET applications.

let ctx = new MLContext()

Load training data

Then, load the data using the helper function loadImagesFromDirectory, pointing it to the top-level directory that contains the naughty and nice subdirectories of images. Passing true as the second argument uses each image's directory name as its label.

let imageData = loadImagesFromDirectory "C:/Datasets/fsadvent2020/Train" true

The loadImagesFromDirectory function looks like the following. In a script, define it before the line above that calls it.

let loadImagesFromDirectory (path: string) (useDirectoryAsLabel: bool) =
    Directory.GetFiles(path, "*", searchOption = SearchOption.AllDirectories)
    // Keep only .jpg and .png files (case-insensitive)
    |> Array.filter (fun file ->
        let extension = Path.GetExtension(file).ToLowerInvariant()
        extension = ".jpg" || extension = ".png")
    |> Array.map (fun file ->
        let label =
            if useDirectoryAsLabel then
                // Label each image with the name of its parent directory (naughty or nice)
                Directory.GetParent(file).Name
            else
                // Otherwise use the leading letters of the file name, e.g. "grinch01.jpg" -> "grinch"
                let fileName = Path.GetFileName(file)
                let firstNonLetter =
                    fileName |> Seq.tryFindIndex (fun (c: char) -> not (Char.IsLetter c))
                match firstNonLetter with
                | Some index -> fileName.Substring(0, index)
                | None -> fileName
        { ImagePath = file; Label = label })

Then, create an IDataView for the training images. An IDataView is the way data is represented in ML.NET.

let imageIdv = ctx.Data.LoadFromEnumerable<ImageData>(imageData)

Define training pipeline

Once your data is loaded into an IDataView, set the classifier options using ImageClassificationTrainer.Options. Use them to define the network architecture, the feature and label columns, and some additional parameters. The network architecture used in this case is ResNet V2 101.

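let classifierOptions = ImageClassificationTrainer.Options()
// Column the trainer reads the image data from
classifierOptions.FeatureColumnName <- "Image"
// Column holding the label converted to a key type
classifierOptions.LabelColumnName <- "LabelKey"
// Also evaluate the model on the training set after each epoch
classifierOptions.TestOnTrainSet <- true
// Pretrained network used for transfer learning
classifierOptions.Arch <- ImageClassificationTrainer.Architecture.ResnetV2101
// Print metrics as the trainer reports them
classifierOptions.MetricsCallback <- Action<ImageClassificationTrainer.ImageClassificationMetrics>(fun x -> printfn "%s" (x.ToString()))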

Define the preprocessing steps, image classification trainer (along with the previously defined options) and postprocessing steps.

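let pipeline = 
    EstimatorChain()
        // Preprocessing: map the string label (naughty/nice) to a numeric key
        .Append(ctx.Transforms.Conversion.MapValueToKey("LabelKey", "Label"))
        // Preprocessing: load each image's raw bytes into the "Image" column
        // (no image folder is given because ImagePath already holds full paths)
        .Append(ctx.Transforms.LoadRawImageBytes("Image", null, "ImagePath"))
        // Image classification trainer, configured with the options above
        .Append(ctx.MulticlassClassification.Trainers.ImageClassification(classifierOptions))
        // Postprocessing: map the predicted key back to its original string label
        .Append(ctx.Transforms.Conversion.MapKeyToValue("PredictedLabel"))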

The ML.NET image classification API uses a technique known as transfer learning. Transfer learning takes a pretrained model (usually a neural network) and retrains its last few layers using new data. This significantly cuts down the time, resources, and data you need to train a deep learning model. ML.NET does this with the help of TensorFlow.NET, a set of .NET bindings for the TensorFlow deep learning framework. Although transfer learning usually makes training a deep learning model less resource intensive, the TensorFlow API is low-level and still requires a significant amount of code. See this transfer learning example from TensorFlow.NET for how you'd do it directly in TensorFlow.NET. The low-level nature of the TensorFlow API gives you control over what you're building, but often you don't need that level of control. ML.NET's image classification trainer greatly simplifies the process by providing a high-level API for the same task.
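To make the idea concrete, here is a toy sketch of transfer learning in plain F#. It is not ML.NET's implementation. A stand-in "pretrained" feature extractor (the hypothetical `frozenFeatureExtractor`) is never updated; its outputs, the bottleneck values, are computed once and cached, and only a new final layer (a logistic regression here) is trained on them. The tiny dataset is made up for illustration; in the real sample, the ResNet V2 101 network roughly plays the role of the frozen extractor.

```fsharp
// Hypothetical stand-in for a pretrained network whose weights are frozen.
// It maps a raw input (here a small float array) to a feature vector,
// the "bottleneck" values, and is never updated during training.
let frozenFeatureExtractor (input: float[]) : float[] =
    [| Array.average input; Array.max input; Array.min input |]

// Made-up labeled inputs: 1.0 = naughty, 0.0 = nice
let trainingData =
    [ ([| 0.9; 0.8; 0.7 |], 1.0)
      ([| 0.8; 0.9; 0.6 |], 1.0)
      ([| 0.1; 0.2; 0.3 |], 0.0)
      ([| 0.2; 0.1; 0.2 |], 0.0) ]

// Step 1: run the frozen extractor once per example and cache the results
let bottlenecks =
    trainingData |> List.map (fun (input, label) -> frozenFeatureExtractor input, label)

let sigmoid z = 1.0 / (1.0 + exp (-z))

// Weighted sum of the features plus the bias
let score (weights: float[]) (bias: float) (features: float[]) =
    Array.fold2 (fun acc w f -> acc + w * f) bias weights features

// Step 2: train only a new final layer (logistic regression) on the cached features
let trainFinalLayer (data: (float[] * float) list) (epochs: int) (learningRate: float) =
    let dim = (fst data.Head).Length
    let weights = Array.zeroCreate<float> dim
    let mutable bias = 0.0
    for _ in 1 .. epochs do
        for (features, label) in data do
            // Gradient of the log-loss with respect to the score
            let error = sigmoid (score weights bias features) - label
            for i in 0 .. dim - 1 do
                weights.[i] <- weights.[i] - learningRate * error * features.[i]
            bias <- bias - learningRate * error
    weights, bias

let weights, bias = trainFinalLayer bottlenecks 500 0.5

// A prediction is the frozen extractor followed by the newly trained final layer
let predict (input: float[]) =
    sigmoid (score weights bias (frozenFeatureExtractor input))
```

Because the extractor never changes, its expensive forward pass only has to run once per image, which is the idea behind the "Bottleneck Computation" phase in the training output.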

To train the model, call the pipeline's Fit method with the training image IDataView.

let model = pipeline.Fit(imageIdv)

Throughout the training process, you should see output similar to the following:

Phase: Bottleneck Computation, Dataset used:      Train, Image Index: 279
Phase: Bottleneck Computation, Dataset used:      Train, Image Index: 280
Phase: Bottleneck Computation, Dataset used: Validation, Image Index:   1
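The log shows the trainer working through both a Train and a Validation set. If you'd rather choose the validation images yourself, for example to keep the naughty/nice ratio the same in both sets given the imbalance in this dataset, a stratified split is one option. The sketch below is plain F#; `shuffle` and `stratifiedSplit` are hypothetical helpers, and the `ImageData` record mirrors the one defined earlier. ImageClassificationTrainer.Options also has a ValidationSet field, so the validation images could in principle be loaded with ctx.Data.LoadFromEnumerable and passed in that way; treat that wiring as an untested assumption.

```fsharp
open System

// Mirrors the ImageData record defined earlier in the post
type ImageData = { ImagePath: string; Label: string }

// Fisher-Yates shuffle that returns a shuffled copy of the array
let shuffle (rng: Random) (items: 'T[]) =
    let result = Array.copy items
    for i in result.Length - 1 .. -1 .. 1 do
        let j = rng.Next(i + 1)
        let tmp = result.[i]
        result.[i] <- result.[j]
        result.[j] <- tmp
    result

// Hypothetical helper: hold out the same fraction of every label for validation,
// so the training and validation sets keep the dataset's naughty/nice ratio.
// Returns (training images, validation images).
let stratifiedSplit (validationFraction: float) (seed: int) (images: ImageData[]) =
    let rng = Random(seed)
    let perLabel =
        images
        |> Array.groupBy (fun image -> image.Label)
        |> Array.map (fun (_, group) ->
            let shuffled = shuffle rng group
            let validationCount = int (round (float shuffled.Length * validationFraction))
            // splitAt returns (the first validationCount items, the rest)
            let validation, train = Array.splitAt validationCount shuffled
            train, validation)
    Array.collect fst perLabel, Array.collect snd perLabel
```

With the counts from this dataset and a 20% validation fraction, this would hold out 480 naughty and 135 nice images.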

With the model trained, it's time to use it to make predictions. Optionally, you can save it for use in other applications.
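In ML.NET, saving goes through ctx.Model.Save, which writes a trained transformer and the schema of its input to a .zip file, and ctx.Model.Load reads both back. The sketch below is self-contained, so it fits a trivial stand-in pipeline (a single MapValueToKey step over two made-up rows) instead of the image classifier; in the sample you'd pass model and imageIdv.Schema. The `Row` record and the file name are hypothetical.

```fsharp
#r "nuget: Microsoft.ML"

open System.IO
open Microsoft.ML

// Hypothetical stand-in input type; in the post this would be ImageData
type Row = { Label: string }

let ctx = MLContext()
let data = ctx.Data.LoadFromEnumerable<Row>([ { Label = "naughty" }; { Label = "nice" } ])

// Stand-in for the trained image classifier (model = pipeline.Fit(imageIdv) in the post)
let model = ctx.Transforms.Conversion.MapValueToKey("LabelKey", "Label").Fit(data)

// Save the model along with the schema of the data it expects as input
let modelPath = Path.Combine(Path.GetTempPath(), "naughty-nice-model.zip")
ctx.Model.Save(model, data.Schema, modelPath)

// Later, possibly in another application: load the model and its input schema
let loadedModel, inputSchema = ctx.Model.Load(modelPath)
```

Load's out parameter for the schema becomes part of a tuple in F#, which is why both values are bound at once.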


Make predictions

Load the test images and create an IDataView for them. The test images used are of Jack Skellington and The Grinch.

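let testImagesDirectory = "<path-to-test-images>" // placeholder: the original test path isn't shown
let testImages = 
    Directory.GetFiles(testImagesDirectory)
    |> Array.map (fun file -> { ImagePath = file; Label = "" })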

let testImageIdv = ctx.Data.LoadFromEnumerable<ImageData>(testImages)


Then, use the model to make predictions.

let predictionIdv = model.Transform(testImageIdv)

One of the easiest ways to access the predictions is to create an IEnumerable. To do so, use the CreateEnumerable method.

let predictions = ctx.Data.CreateEnumerable<ImagePrediction>(predictionIdv,false)

Then, use the built-in F# sequence operations to display the predictions:

predictions |> Seq.iter(fun pred -> 
    printfn "%s is %s" (Path.GetFileNameWithoutExtension(pred.ImagePath)) pred.PredictedLabel)

The output should look like the following:

grinch is Naughty
jack is Naughty
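If you'd rather summarize the results than print each prediction, Seq.countBy tallies how many images received each label. This sketch uses a stand-in `Prediction` record (hypothetical, mirroring ImagePrediction) filled with the two results shown above, so it runs without ML.NET; with the real model you'd pipe predictions into the same functions.

```fsharp
// Hypothetical stand-in for the ImagePrediction records returned by CreateEnumerable
type Prediction = { ImagePath: string; PredictedLabel: string }

// The two results shown above
let predictions =
    [ { ImagePath = "grinch.jpg"; PredictedLabel = "Naughty" }
      { ImagePath = "jack.jpg"; PredictedLabel = "Naughty" } ]

// Count how many images were assigned each label
let countsByLabel =
    predictions
    |> Seq.countBy (fun p -> p.PredictedLabel)
    |> Map.ofSeq

countsByLabel |> Map.iter (fun label count -> printfn "%s: %d" label count)
```

On a larger test set, a tally like this makes it easier to spot the skew toward the naughty category that the imbalanced training data can cause.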


In this post, I showed how you can use ML.NET and TensorFlow.NET to train an image classification model that classifies Disney characters as naughty or nice. Depending on the level of control you need, you might choose TensorFlow.NET, or, if you want a high-level API for training an image classifier, ML.NET. Most importantly, we figured out that Jack Skellington and The Grinch are naughty, so I guess no gifts for them this year? Happy coding!

Call to Action

Originally, I had planned on writing this article using TensorFlow.Keras, which is part of the SciSharp TensorFlow.NET project. TensorFlow.Keras provides .NET bindings for Keras. Keras provides a high-level API for TensorFlow which makes the process of building custom neural networks much simpler than working with the TensorFlow API. Unfortunately, while trying to adapt my scenario to an existing sample, I ran into an issue. This is not something I would have been able to resolve in time to publish this post, so I defaulted to using ML.NET.

I'm a big fan of the work being done by the SciSharp community and the machine learning and data science capabilities it brings to the .NET ecosystem. The work and efforts are all community driven, and as such, there are plenty of opportunities to contribute. Here are just some examples of ways to contribute, especially from an F# perspective. From my end, I plan on eventually converting this sample to use TensorFlow.Keras. See you in the SciSharp repos! 🙂

