MNIST: Digits Classification


The MNIST dataset is considered to be the “hello world” dataset of machine learning. It consists of 70,000 small, square, 28×28 pixel grayscale images of handwritten single digits between 0 and 9 (60,000 training images and 10,000 test images).

In that same spirit, we’ll be making the “hello world” UnionML app using this dataset and a simple linear classifier with sklearn.

With this dataset, we’ll see just how easy it is to create a single-script UnionML app.


This tutorial is adapted from this sklearn guide.

Setup and importing libraries

First, let’s import our dependencies and create the UnionML Dataset and Model objects:

from typing import List, Union

import pandas as pd
from sklearn.datasets import fetch_openml
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import accuracy_score

from unionml import Dataset, Model

dataset = Dataset(name="mnist_dataset", test_size=0.2, shuffle=True, targets=["class"])
model = Model(name="mnist_classifier", dataset=dataset)

Let’s break down the code cell above.

We first define a Dataset, which specifies the data that can be used for training and prediction. We also pass it a few keyword arguments:

  • test_size: the fraction of the data held out for testing. In this case, the dataset is split into a test set (20%) and a training set (80%).

  • shuffle: randomly shuffles the data before splitting it into training and test sets.

  • targets: a list of strings naming the target column(s) of the dataset. For MNIST, the digit labels are stored in the class column.
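To make these options concrete, here is a minimal sketch of the kind of split they describe, written with plain pandas and NumPy on a toy DataFrame. It is not UnionML's internal implementation; the split_dataset helper, its seed parameter, and the toy data are hypothetical and exist only for illustration.

```python
import numpy as np
import pandas as pd


def split_dataset(df: pd.DataFrame, targets: list, test_size: float = 0.2,
                  shuffle: bool = True, seed: int = 0):
    """Hypothetical helper: shuffle rows, hold out test_size, separate targets."""
    if shuffle:
        # randomly permute the rows before splitting
        df = df.sample(frac=1.0, random_state=seed)
    n_test = int(round(len(df) * test_size))
    test, train = df.iloc[:n_test], df.iloc[n_test:]
    # features are every column that is not listed in `targets`
    features = [c for c in df.columns if c not in targets]
    return (train[features], train[targets]), (test[features], test[targets])


# toy frame with two "pixel" columns and a "class" target column
toy = pd.DataFrame({
    "pixel1": np.arange(10),
    "pixel2": np.arange(10) * 2,
    "class": [str(i % 3) for i in range(10)],
})
(train_X, train_y), (test_X, test_y) = split_dataset(toy, targets=["class"])
print(len(train_X), len(test_X))  # prints: 8 2
```

With test_size=0.2, two of the ten rows land in the test set, and the class column is kept apart from the pixel features.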

Then we define a Model, which specifies how to train the model, evaluate it, and generate predictions from it. Note that we bind the dataset we just defined to the model.

Caching Data

For convenience, we cache the results of the fetch_openml function with joblib, so that loading MNIST is faster on subsequent calls:

from pathlib import Path
from joblib import Memory

memory = Memory(Path.home() / "tmp")
fetch_openml_cached = memory.cache(fetch_openml)

We do this so we don’t have to re-download the dataset every time we need to train a model.
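The effect of memory.cache can be seen with a toy function in place of fetch_openml. This minimal sketch counts how many times the wrapped function body actually runs; slow_square and the temporary cache directory are illustrative assumptions, not part of the tutorial.

```python
import tempfile

from joblib import Memory

calls = []  # records each time the function body actually executes


def slow_square(x: int) -> int:
    # stand-in for an expensive call such as downloading a dataset
    calls.append(x)
    return x * x


with tempfile.TemporaryDirectory() as cache_dir:
    memory = Memory(cache_dir, verbose=0)
    cached_square = memory.cache(slow_square)

    first = cached_square(4)   # computed, then written to the cache directory
    second = cached_square(4)  # loaded from disk; slow_square's body is skipped

print(first, second, len(calls))  # prints: 16 16 1
```

Because the second call is served from disk, the body runs only once, which is exactly what saves the repeated MNIST download.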

Define Core UnionML Functions

Run the following code to define our core UnionML app functions:

@dataset.reader(cache=True, cache_version="1")
def reader() -> pd.DataFrame:
    # download (or load from the joblib cache) the 28x28 MNIST images as a DataFrame
    mnist = fetch_openml_cached(
        "mnist_784",
        version=1,
        as_frame=True,
    )
    # randomly sample a subset for faster training
    return mnist.frame.sample(1000, random_state=42)

@model.init
def init(hyperparameters: dict) -> Pipeline:
    # standardize pixel values, then fit a logistic regression classifier
    estimator = Pipeline(
        [
            ("scaler", StandardScaler()),
            ("classifier", LogisticRegression()),
        ]
    )
    return estimator.set_params(**hyperparameters)

@model.trainer(cache=True, cache_version="1")
def trainer(
    estimator: Pipeline,
    features: pd.DataFrame,
    target: pd.DataFrame,
) -> Pipeline:
    # squeeze() turns the one-column target DataFrame into a Series
    return estimator.fit(features, target.squeeze())

@model.predictor
def predictor(
    estimator: Pipeline,
    features: pd.DataFrame,
) -> List[float]:
    return [float(x) for x in estimator.predict(features)]

@model.evaluator
def evaluator(
    estimator: Pipeline,
    features: pd.DataFrame,
    target: pd.DataFrame,
) -> float:
    # fraction of correctly classified digits
    return float(accuracy_score(target.squeeze(), estimator.predict(features)))

The Dataset and Model objects expose function decorators where we define the behavior of our machine learning app:

  • reader() - Register a function for getting data from some external source.

  • init() - Register a function for initializing a model object. This is equivalent to specifying a class or callable using the init kwarg in the Model constructor.

  • trainer() - Register a function for training a model object.

  • predictor() - Register a function that generates predictions from a model object.

  • evaluator() - Register a function for evaluating a given model object.
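Decorators like these typically follow a registration pattern: each decorator stores the user's function, and a method such as train() later calls the stored functions in order. The sketch below is a hypothetical, heavily simplified illustration of that pattern; ToyModel, registry, and the toy_* functions are invented for this example and are not UnionML's actual implementation.

```python
from typing import Any, Callable, Dict, List


class ToyModel:
    """Hypothetical registry illustrating the pattern; not UnionML's Model."""

    def __init__(self) -> None:
        self._funcs: Dict[str, Callable] = {}

    def _register(self, name: str, fn: Callable) -> Callable:
        self._funcs[name] = fn  # remember the user's function under `name`
        return fn               # return it unchanged so it stays callable

    def init(self, fn: Callable) -> Callable:
        return self._register("init", fn)

    def trainer(self, fn: Callable) -> Callable:
        return self._register("trainer", fn)

    def predictor(self, fn: Callable) -> Callable:
        return self._register("predictor", fn)

    def train(self, hyperparameters: Dict[str, Any],
              features: List[float], target: List[float]) -> Any:
        # wire the registered functions together: init first, then trainer
        estimator = self._funcs["init"](hyperparameters)
        return self._funcs["trainer"](estimator, features, target)

    def predict(self, estimator: Any, features: List[float]) -> List[float]:
        return self._funcs["predictor"](estimator, features)


registry = ToyModel()


@registry.init
def toy_init(hyperparameters: dict) -> dict:
    return {"offset": hyperparameters["offset"], "mean": None}


@registry.trainer
def toy_trainer(estimator: dict, features: List[float], target: List[float]) -> dict:
    # a deliberately trivial "model": store the mean target value
    estimator["mean"] = sum(target) / len(target)
    return estimator


@registry.predictor
def toy_predictor(estimator: dict, features: List[float]) -> List[float]:
    return [estimator["mean"] + estimator["offset"] for _ in features]


fitted = registry.train({"offset": 1.0}, features=[0.0, 1.0, 2.0], target=[2.0, 4.0, 6.0])
print(registry.predict(fitted, [10.0, 20.0]))  # prints: [5.0, 5.0]
```

UnionML's real decorators also accept options such as cache=True, as in @model.trainer(cache=True, cache_version="1"); this toy version omits them.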

Training a Model Locally

Then we can train our model locally:

estimator, metrics = model.train(
    hyperparameters={
        "classifier__penalty": "l2",
        "classifier__C": 0.1,
        "classifier__max_iter": 1000,
    }
)
print(estimator, metrics, sep="\n")

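Note that we pass a dictionary of hyperparameters when we invoke model.train(). Its keys follow the sklearn convention for setting the parameters of Pipeline steps: the step name and the parameter name joined by a double underscore, so classifier__C sets the C parameter of the classifier step.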
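The same pipeline and hyperparameter dictionary can be exercised with plain sklearn, independently of UnionML. This sketch swaps MNIST for sklearn's bundled load_digits dataset (an assumption made so it runs offline; those images are 8×8 rather than 28×28) and mimics the shuffled 80/20 split defined on the Dataset.

```python
from sklearn.datasets import load_digits
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler

hyperparameters = {
    "classifier__penalty": "l2",
    "classifier__C": 0.1,
    "classifier__max_iter": 1000,
}

# "<step name>__<parameter>" routes each value to the named pipeline step
estimator = Pipeline(
    [
        ("scaler", StandardScaler()),
        ("classifier", LogisticRegression()),
    ]
).set_params(**hyperparameters)

# load_digits ships with sklearn, so no download is needed
X, y = load_digits(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, shuffle=True, random_state=42
)
estimator.fit(X_train, y_train)
accuracy = accuracy_score(y_test, estimator.predict(X_test))

print(estimator.named_steps["classifier"].C)  # prints: 0.1
print(f"accuracy: {accuracy:.3f}")
```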

Serving on a Gradio Widget

Finally, let’s serve the model in a gradio widget by passing the model’s predict() method to a gradio.Interface object.

Before we do this, however, we want to define a feature_loader() function to handle the raw input coming from the gradio widget:

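import numpy as np

@dataset.feature_loader
def feature_loader(data: np.ndarray) -> pd.DataFrame:
    # flatten the 28x28 sketch into one row of 784 pixel columns,
    # named pixel1 ... pixel784 to match the MNIST training features
    return (
        pd.DataFrame(data.ravel())
        .transpose()
        .rename(columns=lambda x: f"pixel{x + 1}")
        .astype(float)
    )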
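Because the widget hands feature_loader a raw image array, it can help to check the conversion on its own. This sketch copies the transformation into an undecorated helper (to_features, a name chosen for this example) and applies it to a synthetic 28×28 array; it assumes the widget sends a 2-D array of that shape.

```python
import numpy as np
import pandas as pd


def to_features(data: np.ndarray) -> pd.DataFrame:
    # same transformation as feature_loader, without the UnionML decorator
    return (
        pd.DataFrame(data.ravel())
        .transpose()
        .rename(columns=lambda x: f"pixel{x + 1}")
        .astype(float)
    )


# a hypothetical blank 28x28 sketch with a single "ink" pixel in row 0, column 1
sketch = np.zeros((28, 28), dtype=np.uint8)
sketch[0, 1] = 255

features = to_features(sketch)
print(features.shape)                      # prints: (1, 784)
print(features.columns[[0, -1]].tolist())  # prints: ['pixel1', 'pixel784']
print(features.loc[0, "pixel2"])           # prints: 255.0
```

Since ravel() flattens row by row, the pixel at row 0, column 1 ends up in the pixel2 column.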

We also use a lambda function to handle the None input that the widget sends when we press its clear button:

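import gradio as gr

# assumes the Gradio 3.x "sketchpad" input, which passes the drawing as a NumPy array
gr.Interface(
    # return None when the widget is cleared; otherwise predict on the drawing
    fn=lambda img: img if img is None else model.predict(img)[0],
    inputs="sketchpad",
    outputs="label",
    live=True,
).launch()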
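The lambda is equivalent to a small named wrapper, which can be easier to read or unit-test. In this sketch, none_safe and fake_predict are hypothetical names, with fake_predict standing in for model.predict so the guard can be checked without a trained model.

```python
from typing import Callable, List, Optional


def none_safe(predict: Callable[[object], List[float]]) -> Callable[[Optional[object]], Optional[float]]:
    """Wrap `predict` so that a cleared widget (None input) yields None."""
    def wrapped(img: Optional[object]) -> Optional[float]:
        if img is None:
            return None  # nothing drawn: skip prediction entirely
        return predict(img)[0]  # predict returns a list; show its first entry
    return wrapped


def fake_predict(img: object) -> List[float]:
    # hypothetical stand-in for model.predict
    return [7.0]


fn = none_safe(fake_predict)
print(fn(None), fn("any image"))  # prints: None 7.0
```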

You might notice that the model doesn’t perform as well as you might expect… welcome to the world of machine learning practice! To obtain a better model given a fixed dataset, feel free to play around with the model hyperparameters, or even switch up the model type/architecture defined in the init function.
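One systematic way to explore hyperparameters is a cross-validated grid search. The sketch below uses sklearn's GridSearchCV on the same scaler-plus-logistic-regression pipeline. As before, it substitutes the bundled load_digits dataset for MNIST so it runs offline, and the grid of C values is an arbitrary choice for illustration.

```python
from sklearn.datasets import load_digits
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import GridSearchCV
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler

pipeline = Pipeline(
    [
        ("scaler", StandardScaler()),
        ("classifier", LogisticRegression(max_iter=1000)),
    ]
)

# candidate values use the same "<step>__<param>" keys as the hyperparameters dict
param_grid = {"classifier__C": [0.01, 0.1, 1.0]}

X, y = load_digits(return_X_y=True)
search = GridSearchCV(pipeline, param_grid, cv=3, scoring="accuracy")
search.fit(X, y)

print(search.best_params_)
print(f"best cross-validated accuracy: {search.best_score_:.3f}")
```

The winning value could then be passed back to model.train() through the same hyperparameters dictionary.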