rdsx.dev
Thu Feb 09 2023


Digit Classifier API
Jupyter Notebook
Python
FastAPI
Web-based API for classifying digits.
Overview
This was a quick project to learn how to launch an API that uses a trained model.
Code Explanation
Training in train_model.py
import pickle
from sklearn.datasets import fetch_openml
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split
# Fetch the 'mnist_784' dataset from OpenML as (features, labels).
# as_frame=False requests plain NumPy arrays rather than a pandas DataFrame,
# so the classifier is fitted without column names. The serving code feeds
# the model a bare ndarray; fitting on a DataFrame would make every
# prediction emit a feature-name mismatch warning.
X, y = fetch_openml('mnist_784', version=1, return_X_y=True, as_frame=False)

# Hold out 20% of the samples to evaluate the fitted model.
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2)

# n_jobs=-1 lets the forest use all available CPU cores.
clf = RandomForestClassifier(n_jobs=-1)
clf.fit(X_train, y_train)

# Report the score on the held-out split.
print(clf.score(X_test, y_test))

# Persist the fitted classifier so the API process can load it at startup.
with open('mnist_model.pkl', 'wb') as f:
    pickle.dump(clf, f)
Backend in main.py
import io
import pickle

import numpy as np
import PIL.Image
import PIL.ImageOps
from fastapi import FastAPI, File, HTTPException, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from PIL import Image
# Load the pickled classifier once at import time so every request reuses it.
# NOTE(review): unpickling executes arbitrary code embedded in the file, so
# 'mnist_model.pkl' must only ever come from a trusted training run.
with open('mnist_model.pkl', 'rb') as model_file:
    model = pickle.load(model_file)
# Application instance that the route decorators attach to.
app = FastAPI()

# Allow browser pages served from any origin to call the API.
# Credentials stay disabled: a wildcard origin must not be combined with
# allow_credentials=True, which would let any website make credentialed
# requests. The endpoint in this file does not read cookies or auth headers.
app.add_middleware(
    CORSMiddleware,
    allow_origins=['*'],
    allow_credentials=False,
    allow_methods=['*'],
    allow_headers=['*'],
)
@app.post('/predict-image')
async def predict(file: UploadFile = File(...)):
    """Classify the digit shown in an uploaded image.

    The upload is decoded, converted to 8-bit grayscale, colour-inverted,
    resized to 28x28 pixels and flattened into one row before being passed
    to the loaded model.

    Args:
        file: The uploaded image, sent as the multipart form field ``file``.

    Returns:
        A dict ``{'prediction': <int>}`` holding the predicted digit.

    Raises:
        HTTPException: With status 400 when the upload cannot be decoded
            as an image.
    """
    raw_bytes = await file.read()
    try:
        # convert() forces a full decode, so unrecognised, truncated or
        # corrupt uploads all fail here rather than further down.
        pil_image = PIL.Image.open(io.BytesIO(raw_bytes)).convert('L')
    except OSError as exc:
        # PIL.UnidentifiedImageError is an OSError subclass; report bad
        # input as a client error instead of an opaque HTTP 500.
        raise HTTPException(
            status_code=400,
            detail='Uploaded file is not a valid image.',
        ) from exc

    # NOTE(review): presumably uploads are dark digits on a light background
    # while the training data uses the opposite polarity - confirm.
    pil_image = PIL.ImageOps.invert(pil_image)
    pil_image = pil_image.resize((28, 28), PIL.Image.Resampling.LANCZOS)

    # One sample with 28 * 28 = 784 features.
    image_array = np.array(pil_image).reshape(1, -1)
    prediction = model.predict(image_array)
    return {'prediction': int(prediction[0])}
Frontend in index.html
<!DOCTYPE html>
<html lang="en">
  <head>
    <meta charset="utf-8" />
    <title>Digit Classifier</title>
  </head>
  <body>
    <!-- File picker, submit button and the element that shows the result. -->
    <input type="file" id="imageInput" accept="image/*" />
    <button onclick="uploadImage()">Classify</button>
    <p id="predictionResult"></p>
    <script type="text/javascript">
      /**
       * Send the selected image to the prediction endpoint and display
       * the returned digit in the #predictionResult element.
       *
       * Alerts the user when no file is selected or when the request fails.
       */
      async function uploadImage() {
        const input = document.getElementById("imageInput");
        const predictionResult = document.getElementById("predictionResult");

        if (!input.files[0]) {
          alert("Please select an image");
          return;
        }

        const file = input.files[0];
        const formData = new FormData();
        // The field name must match the backend's `file` upload parameter.
        formData.append("file", file);

        try {
          const response = await fetch("http://127.0.0.1:8000/predict-image", {
            method: "POST",
            body: formData,
          });
          // fetch() only rejects on network failure; an HTTP error status
          // still resolves, so check it before trusting the body. Otherwise
          // a 4xx/5xx reply would be rendered as "Prediction: undefined".
          if (!response.ok) {
            throw new Error(`Server responded with status ${response.status}`);
          }
          const result = await response.json();
          predictionResult.textContent = `Prediction: ${result.prediction}`;
        } catch (error) {
          console.error("Error:", error);
          alert("An error occurred. Please try again.");
        }
      }
    </script>
  </body>
</html>