rdsx.dev

Thu Feb 09 2023

Digit Classifier API

Digit Classifier API

A web-based API for classifying handwritten digits.

Overview

This was a quick project to learn how to launch an API that uses a trained model.

Code Explanation

Training in train_model.py

import pickle
 
from sklearn.datasets import fetch_openml
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split
 
# Fetch the 'mnist_784' dataset from OpenML. as_frame=False returns plain
# NumPy arrays instead of a pandas DataFrame, so the fitted model stores no
# feature names and accepts the bare arrays that main.py passes to predict()
# without scikit-learn's feature-name mismatch warning.
X, y = fetch_openml('mnist_784', version=1, return_X_y=True, as_frame=False)

# Hold out 20% of the samples for evaluation; the fixed seed makes the split
# (and therefore the printed score) reproducible between runs.
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42
)

# n_jobs=-1 builds the trees using all available CPU cores.
clf = RandomForestClassifier(n_jobs=-1, random_state=42)

clf.fit(X_train, y_train)

# Mean accuracy on the held-out split.
print(clf.score(X_test, y_test))

# Persist the fitted classifier so the API can load it at startup.
with open('mnist_model.pkl', 'wb') as f:
    pickle.dump(clf, f)
 

Backend in main.py

import io
import pickle

import numpy as np
import PIL.Image
import PIL.ImageOps
from fastapi import FastAPI, File, HTTPException, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from PIL import Image
 
# Load the pickled classifier once at import time so every request reuses the
# same in-memory model.
# NOTE: pickle.load can execute arbitrary code embedded in the file; only load
# a model file you produced yourself (e.g. via train_model.py).
with open('mnist_model.pkl', 'rb') as f:
    model = pickle.load(f)

app = FastAPI()

# Let browser pages on any origin (such as the index.html frontend) call the
# API. Credentials are disabled because a wildcard origin must not be combined
# with credentialed requests, and the frontend sends none.
app.add_middleware(
    CORSMiddleware,
    allow_origins=['*'],
    allow_credentials=False,
    allow_methods=['*'],
    allow_headers=['*'],
)
 
@app.post('/predict-image')
async def predict(file: UploadFile = File(...)):
    """Classify the digit shown in an uploaded image.

    The upload is decoded, converted to 8-bit grayscale, inverted, resized to
    28x28 pixels and flattened into a single row before being passed to the
    model.

    Args:
        file: Image uploaded as multipart form data under the ``file`` field.

    Returns:
        A dict of the form ``{'prediction': <int>}``.

    Raises:
        HTTPException: With status 400 when the upload cannot be decoded as
            an image.
    """
    raw_bytes = await file.read()
    try:
        # convert() forces a full decode, so truncated or corrupt files are
        # rejected here rather than failing later in the pipeline.
        pil_image = PIL.Image.open(io.BytesIO(raw_bytes)).convert('L')
    except (PIL.UnidentifiedImageError, OSError) as exc:
        raise HTTPException(
            status_code=400,
            detail='Uploaded file is not a valid image.',
        ) from exc

    # NOTE(review): the inversion presumably maps dark-digit-on-light uploads
    # onto the light-digit-on-dark convention of the training data -- confirm.
    pil_image = PIL.ImageOps.invert(pil_image)
    pil_image = pil_image.resize((28, 28), PIL.Image.Resampling.LANCZOS)

    # Flatten the 28x28 image into one sample row of 784 pixel values.
    image_array = np.array(pil_image).reshape(1, -1)
    prediction = model.predict(image_array)
    return {'prediction': int(prediction[0])}
 

Frontend in index.html

<!DOCTYPE html>
<html>
  <head>
    <meta charset="utf-8" />
    <title>Digit Classifier</title>
  </head>
  <body>
    <input type="file" id="imageInput" accept="image/*" />
    <button onclick="uploadImage()">Classify</button>
    <p id="predictionResult"></p>
 
    <script type="text/javascript">
      // Upload the selected image to the prediction API and show the result.
      // Alerts the user when no file is chosen or when the request fails.
      async function uploadImage() {
        const input = document.getElementById("imageInput");
        const predictionResult = document.getElementById("predictionResult");

        const file = input.files[0];
        if (!file) {
          alert("Please select an image");
          return;
        }

        // The API expects the image under the multipart field name "file".
        const formData = new FormData();
        formData.append("file", file);

        try {
          const response = await fetch("http://127.0.0.1:8000/predict-image", {
            method: "POST",
            body: formData,
          });
          // fetch() only rejects on network failure; an HTTP error status
          // must be checked explicitly, otherwise the error body would be
          // rendered as "Prediction: undefined".
          if (!response.ok) {
            throw new Error(`Request failed with status ${response.status}`);
          }
          const result = await response.json();
          predictionResult.textContent = `Prediction: ${result.prediction}`;
        } catch (error) {
          console.error("Error:", error);
          alert("An error occurred. Please try again.");
        }
      }
    </script>
  </body>
</html>