Example: MNIST Digits

MNIST is a classic multi-class classification problem: recognizing handwritten digits (0-9).

Dataset

  • Source: MNIST (Modified National Institute of Standards and Technology), via the Kaggle dataset oddrationale/mnist-in-csv
  • Task: Classify digits (0-9)
  • Features: 784 (28×28 pixel values as integers 0-255)
  • Classes: 10 (digits 0-9)
  • Training samples: 60,000
  • Test samples: 10,000
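
Before training, it can help to confirm that a downloaded CSV matches the numbers listed above. The following is a minimal sketch using only the Python standard library. It assumes the label sits in the first column, followed by the 784 pixel columns; the function name `check_mnist_csv` is hypothetical and not part of Pilz.

```python
import csv

EXPECTED_COLUMNS = 1 + 28 * 28  # label column plus 784 pixel columns


def check_mnist_csv(path, max_rows=100):
    """Check the header width and the value ranges of the first rows.

    Returns the number of data rows that were checked.
    """
    with open(path, newline="") as handle:
        reader = csv.reader(handle)
        header = next(reader)
        if len(header) != EXPECTED_COLUMNS:
            raise ValueError(
                f"expected {EXPECTED_COLUMNS} columns, got {len(header)}"
            )
        checked = 0
        for row in reader:
            if checked >= max_rows:
                break
            # Assumption: first column is the label, the rest are pixels.
            label = int(row[0])
            pixels = [int(value) for value in row[1:]]
            if not 0 <= label <= 9:
                raise ValueError(f"label out of range: {label}")
            if any(not 0 <= pixel <= 255 for pixel in pixels):
                raise ValueError("pixel value outside 0-255")
            checked += 1
    return checked
```

Calling `check_mnist_csv("<kagglehub_path>/mnist_train.csv")` should return 100 if the first 100 rows are well-formed.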

Quick Start

The config files for this example are in examples/mnist/:

# 1. Download MNIST CSV data (requires kagglehub)
pip install kagglehub
python3 -c "
import kagglehub
path = kagglehub.dataset_download('oddrationale/mnist-in-csv')
print(f'Downloaded to: {path}')
"

# 2. Point the datacard to your downloaded data
#    Edit examples/mnist/dc_mninst.yaml and update:
#      train_files:
#        - <kagglehub_path>/mnist_train.csv
#      test_files:
#        - <kagglehub_path>/mnist_test.csv

# 3. Train
pilz train \
  --datacard examples/mnist/dc_mninst.yaml \
  --trainsettings examples/mnist/train_settings.yaml

# 4. Evaluate
pilz eval \
  --datacard examples/mnist/dc_mninst.yaml \
  --evalsettings examples/mnist/eval_settings.yaml

Or use the provided script:

cd examples/mnist
# After downloading data and updating dc_mninst.yaml paths
bash run.sh
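
Step 2 asks you to copy the download path into the datacard by hand. As a convenience, the sketch below builds the two YAML entries from a download folder so they can be pasted into the datacard. The helper name `datacard_file_entries` is hypothetical; the file names `mnist_train.csv` and `mnist_test.csv` are the ones used in step 2.

```python
from pathlib import Path


def datacard_file_entries(download_dir):
    """Return the train_files/test_files YAML lines for a download folder."""
    root = Path(download_dir)
    train = root / "mnist_train.csv"
    test = root / "mnist_test.csv"
    missing = [str(path) for path in (train, test) if not path.is_file()]
    if missing:
        raise FileNotFoundError(f"missing CSV files: {missing}")
    return "\n".join(
        [
            "train_files:",
            f"  - {train}",
            "test_files:",
            f"  - {test}",
        ]
    )


# Usage (requires kagglehub from step 1):
#   import kagglehub
#   path = kagglehub.dataset_download("oddrationale/mnist-in-csv")
#   print(datacard_file_entries(path))
```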

DataCard Structure

For MNIST, every pixel is a feature (784 total):

features:
  - name: label
    statistical: categorial
    type: int
  - name: 1x1
    statistical: numerical
    type: int
  - name: 1x2
    statistical: numerical
    type: int
  # ... 784 pixels total (1x1 through 28x28)
  - name: 28x28
    statistical: numerical
    type: int

target:
  feature_name: label
  values:
    - 0
    - 1
    - 2
    - 3
    - 4
    - 5
    - 6
    - 7
    - 8
    - 9

train_files:
  - /path/to/mnist_train.csv
test_files:
  - /path/to/mnist_test.csv
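
Writing 784 pixel entries by hand is error-prone, so the features list is easier to generate. The sketch below emits the same YAML lines as shown above. It assumes the pixel columns are named row by row (1x1, 1x2, ..., 1x28, 2x1, ..., 28x28) and keeps the `statistical` values spelled exactly as in the datacard; the function name is hypothetical.

```python
def mnist_feature_entries(side=28):
    """Return the YAML lines of the features list: label plus side*side pixels."""
    lines = [
        "features:",
        "  - name: label",
        "    statistical: categorial",
        "    type: int",
    ]
    # Assumption: pixel columns are ordered row-major, named "<row>x<col>".
    for row in range(1, side + 1):
        for col in range(1, side + 1):
            lines += [
                f"  - name: {row}x{col}",
                "    statistical: numerical",
                "    type: int",
            ]
    return lines


if __name__ == "__main__":
    print("\n".join(mnist_feature_entries()))
```

The output can be pasted above the `target:` section of the datacard.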

Settings (Quick Start)

The train_settings.yaml is configured for a fast first run:

n: 1                # 1 tree per digit (10 trees total)
out_folder: testi
max_depth: 5        # Shallow trees for speed
frac_eval_cat: 0.8
max_eval_fit: 500   # Fewer samples per evaluation
min_eval_fit: 5
n_dims: 2           # Pairwise feature combinations
n_cat: 3            # 3 bins per pixel (low/medium/high)
calcs_per_dim: 200  # Limited calculations per dimension
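
To get a feel for why n_dims has such a large effect on runtime, it helps to count how many pixel combinations exist in principle. The arithmetic below is exact; how Pilz actually selects combinations is not described here, so reading calcs_per_dim as a cap on that search is an assumption based on its comment above.

```python
import math

N_PIXELS = 28 * 28  # 784 pixel features


def count_combinations(n_dims, n_features=N_PIXELS):
    """Number of distinct unordered feature combinations of size n_dims."""
    return math.comb(n_features, n_dims)


for n_dims in (2, 3):
    print(f"n_dims={n_dims}: {count_combinations(n_dims):,} combinations")
# n_dims=2: 306,936 combinations
# n_dims=3: 80,007,984 combinations
```

Going from pairs to triples multiplies the number of candidate combinations by roughly 260, which is consistent with the longer training times quoted for n_dims=3 later in this page.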

The eval_settings.yaml:

in_folders:
  - testi
out_folder: eval

Training Time

With the quick-start settings on a modern laptop (Apple Silicon):

  • Training: ~15 minutes (10 digits × 1 tree at depth 5)
  • Evaluation: ~1 second

Training is resumable — if interrupted, it continues where it left off.

Actual Results

Overall Accuracy: 86.7%

Per-Digit Accuracy

Digit   Accuracy
0       96.6%
1       96.7%
2       84.4%
3       81.9%
4       82.7%
5       80.0%
6       88.5%
7       87.5%
8       83.6%
9       83.4%
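
The plain average of the ten per-digit values is 86.5%, slightly below the reported overall accuracy. The sketch below shows that weighting each digit by its number of test images reproduces 86.7%. It assumes that per-digit accuracy means the share of images of that digit classified correctly, and it uses the per-digit image counts of the standard MNIST test set.

```python
# Per-digit accuracy in percent, copied from the table above.
PER_DIGIT_ACCURACY = [96.6, 96.7, 84.4, 81.9, 82.7, 80.0, 88.5, 87.5, 83.6, 83.4]

# Assumption: image counts per digit in the standard MNIST test set (sum 10,000).
TEST_COUNTS = [980, 1135, 1032, 1010, 982, 892, 958, 1028, 974, 1009]


def macro_accuracy(accuracies):
    """Unweighted mean over classes."""
    return sum(accuracies) / len(accuracies)


def weighted_accuracy(accuracies, counts):
    """Mean over classes, weighted by the number of samples per class."""
    total = sum(acc * count for acc, count in zip(accuracies, counts))
    return total / sum(counts)


print(f"unweighted mean: {macro_accuracy(PER_DIGIT_ACCURACY):.1f}%")
print(f"weighted mean:   {weighted_accuracy(PER_DIGIT_ACCURACY, TEST_COUNTS):.1f}%")
# unweighted mean: 86.5%
# weighted mean:   86.7%
```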

ROC Curves

Per-digit AUC ranges from ~0.96 to ~0.99 across all 10 classes.

Output Files

testi/
├── 0/0.json         # Model for digit "0"
├── 1/0.json         # Model for digit "1"
├── 2/0.json
├── 3/0.json
├── 4/0.json
├── 5/0.json
├── 6/0.json
├── 7/0.json
├── 8/0.json
└── 9/0.json

eval/
├── 0_roc.html ... 9_roc.html   # ROC curves per digit
├── all_roc.html                 # Combined ROC overlay
└── multi_class_result.html      # Per-digit accuracy chart

Multi-Class Strategy

Pilz uses one-vs-rest for multi-class classification:

flowchart LR
    subgraph "10 Binary Models"
        M0[0 vs rest]
        M1[1 vs rest]
        M2[2 vs rest]
        M3[3 vs rest]
        M4[4 vs rest]
        M5[5 vs rest]
        M6[6 vs rest]
        M7[7 vs rest]
        M8[8 vs rest]
        M9[9 vs rest]
    end
    M0 --> S0[Score: 0.85]
    M1 --> S1[Score: 0.92]
    M2 --> S2[Score: 0.78]
    M3 --> S3[Score: 0.88]
    M4 --> S4[Score: 0.90]
    M5 --> S5[Score: 0.82]
    M6 --> S6[Score: 0.95]
    M7 --> S7[Score: 0.87]
    M8 --> S8[Score: 0.91]
    M9 --> S9[Score: 0.89]
    S0 --> ARG[ARGMAX]
    S1 --> ARG
    S2 --> ARG
    S3 --> ARG
    S4 --> ARG
    S5 --> ARG
    S6 --> ARG
    S7 --> ARG
    S8 --> ARG
    S9 --> ARG
    ARG --> P[Predicted: 6]
    style M0 fill:#e0f0ff
    style ARG fill:#ccffcc
    style P fill:#ffff99

For each digit, Pilz trains a binary classifier ("Is this digit or not?"). At prediction time, all 10 models run and the highest score wins.
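
The final decision step can be illustrated in a few lines. This is a generic one-vs-rest sketch, not Pilz's internal code: it takes one score per binary model and returns the digit with the highest score. The scores are the illustrative values from the diagram, and the function name is hypothetical.

```python
# One score per "digit vs rest" model (example values from the diagram).
scores = {
    0: 0.85, 1: 0.92, 2: 0.78, 3: 0.88, 4: 0.90,
    5: 0.82, 6: 0.95, 7: 0.87, 8: 0.91, 9: 0.89,
}


def predict_one_vs_rest(class_scores):
    """Return the class whose binary model produced the highest score."""
    return max(class_scores, key=class_scores.get)


print(predict_one_vs_rest(scores))  # 6
```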

Tips for MNIST

Quick Start Settings (current)

These are intentionally reduced for fast iteration. Training takes ~15 minutes and gives ~87% accuracy.

For Better Accuracy

Increase these settings in train_settings.yaml:

n: 10               # 10 trees per digit (ensemble)
max_depth: 13       # Deeper trees capture more patterns
n_dims: 3           # Triple feature combinations
n_cat: 5            # Finer pixel value bins
calcs_per_dim: 1000 # More thorough search
max_eval_fit: 5000  # More training samples

With these settings, expect:

  • Training time: 1-2 hours
  • Accuracy: 90-95%
  • AUC per digit: >0.98

Incremental Approach

  1. Start with n=1, max_depth=5, n_dims=2 to verify the pipeline works
  2. Increase max_depth to 8, then 13
  3. Add feature combinations: n_dims=3
  4. Add more trees: n=5, then n=10
  5. Fine-tune with more calculations: calcs_per_dim=5000
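
The steps above can be written down as a sequence of settings, where each stage changes one value and keeps everything from the previous stage. The sketch below only builds the dictionaries; writing them to train_settings.yaml and running `pilz train` is left to you. The names `BASE`, `STEPS`, and `staged_settings` are hypothetical.

```python
# Stage 1: the settings used to verify that the pipeline works.
BASE = {"n": 1, "max_depth": 5, "n_dims": 2}

# Stages 2-5: one change per training run, in the order listed above.
STEPS = [
    {"max_depth": 8},
    {"max_depth": 13},
    {"n_dims": 3},
    {"n": 5},
    {"n": 10},
    {"calcs_per_dim": 5000},
]


def staged_settings(base, steps):
    """Return the cumulative settings for the base run and every later step."""
    current = dict(base)
    stages = [dict(current)]
    for change in steps:
        current.update(change)
        stages.append(dict(current))
    return stages


for number, settings in enumerate(staged_settings(BASE, STEPS), start=1):
    print(number, settings)
```

Changing one setting per run makes it easier to see which change actually moved accuracy or training time.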

Monitor Training

Training is resumable — if interrupted, it picks up from the last saved tree. Check testi/ to see progress.
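
A small script can summarize that progress. The sketch below counts the saved JSON files per digit folder. It assumes the layout shown under Output Files (one folder per digit, one JSON file per tree) and that a run with n trees eventually produces n files per folder; the function name is hypothetical.

```python
from pathlib import Path


def training_progress(out_folder, n_trees, n_classes=10):
    """Map each digit to (saved_trees, expected_trees) based on JSON files."""
    root = Path(out_folder)
    progress = {}
    for digit in range(n_classes):
        folder = root / str(digit)
        # A digit folder that does not exist yet counts as zero saved trees.
        saved = len(list(folder.glob("*.json"))) if folder.is_dir() else 0
        progress[digit] = (saved, n_trees)
    return progress


if __name__ == "__main__":
    for digit, (saved, expected) in training_progress("testi", n_trees=1).items():
        print(f"digit {digit}: {saved}/{expected} trees saved")
```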