rdsx.dev
Fri Mar 15 2024
Image Classifier with PyTorch
PyTorch
Neural Networks
Machine Learning
A simple image classifier built using PyTorch, trained on the CIFAR-10 dataset.
Overview
This project demonstrates how to build a basic image classifier using PyTorch, utilizing the CIFAR-10 dataset. The neural network is trained to classify images into one of ten classes.
Code Explanation
Importing Libraries
We start by importing the necessary libraries:
- NumPy and PIL: handle array data and image files.
- PyTorch libraries (torch, torch.nn, torch.optim): provide the building blocks for defining and training the neural network.
- Torchvision libraries: help load and transform the CIFAR-10 dataset.
Data Loading and Transformation
- Transform: Converts images to tensors and normalizes each channel using a mean of 0.5 and a standard deviation of 0.5, which maps pixel values from [0, 1] to [-1, 1].
- CIFAR-10 Dataset: The training and testing sets are loaded separately, each with the transform applied.
- DataLoader: Provides an iterator over the dataset with batching and shuffling for training and testing.
Visualizing an Image
Retrieves and displays the size of a sample image from the training dataset.
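As a rough illustration, each dataset item is an `(image, label)` pair, so the size can be read from the image tensor. The `sample_size` helper and the one-item stand-in dataset below are hypothetical, used so the sketch runs without downloading CIFAR-10.

```python
import torch


def sample_size(dataset, index=0):
    """Return the size of the image tensor stored at `index`."""
    image, _label = dataset[index]
    return image.size()


# Stand-in for the transformed training set: one (image, label) pair
# with the CIFAR-10 image shape (3 channels, 32x32 pixels).
fake_train_data = [(torch.zeros(3, 32, 32), 0)]
print(sample_size(fake_train_data))
```

With the real training dataset in place of `fake_train_data`, the same call reports the shape of an actual CIFAR-10 image.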
Defining the Neural Network
- NeuralNet Class: Defines a convolutional neural network with two convolutional layers followed by three fully connected layers.
- Convolutional Layers: Extract features from images.
- Pooling Layer: Downsamples the feature maps, reducing their spatial size.
- Fully Connected Layers: Perform the final classification.
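One way to write a network with this shape is sketched below. The channel counts, kernel sizes, hidden-layer widths, ReLU activations, and the choice of max pooling are assumptions borrowed from common CIFAR-10 examples; the project may use different values.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class NeuralNet(nn.Module):
    """Two convolutional layers followed by three fully connected layers.

    All layer sizes here are assumed, not taken from the project.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)    # 3x32x32 -> 6x28x28
        self.pool = nn.MaxPool2d(2, 2)     # halves height and width
        self.conv2 = nn.Conv2d(6, 16, 5)   # 6x14x14 -> 16x10x10
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)       # one output score per class

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))  # -> 6x14x14
        x = self.pool(F.relu(self.conv2(x)))  # -> 16x5x5
        x = torch.flatten(x, 1)               # keep the batch dimension
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return self.fc3(x)                    # raw scores (logits)


# Pass a random batch of four 32x32 RGB images through the network.
net = NeuralNet()
logits = net(torch.randn(4, 3, 32, 32))
print(logits.shape)
```

The final layer returns raw scores rather than probabilities because `CrossEntropyLoss` applies the softmax internally.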
Training the Model
- Network Initialization: Creates an instance of the neural network, sets up the loss function (CrossEntropyLoss), and defines the optimizer (SGD).
- Training Loop: Runs for 30 epochs; for each batch it calculates the loss and updates the network weights. As the next image shows, the loss is still decreasing at the end of training, which suggests that training for more epochs could improve the model further.
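The loop described above can be sketched as follows. The `train` helper, the learning rate of 0.001, and the momentum of 0.9 are assumptions; only the loss function, the optimizer type, and the 30-epoch default come from the text. The toy model and random data at the end are stand-ins so the sketch runs quickly without the real dataset.

```python
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset


def train(model, loader, epochs=30, lr=0.001, momentum=0.9):
    """Train `model` with SGD and return the mean loss of each epoch.

    `lr` and `momentum` are assumed values, not taken from the project.
    """
    loss_function = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), lr=lr, momentum=momentum)
    epoch_losses = []
    model.train()
    for epoch in range(epochs):
        running_loss = 0.0
        for inputs, labels in loader:
            optimizer.zero_grad()                  # clear old gradients
            outputs = model(inputs)                # forward pass
            loss = loss_function(outputs, labels)
            loss.backward()                        # backpropagation
            optimizer.step()                       # weight update
            running_loss += loss.item()
        epoch_losses.append(running_loss / len(loader))
        print(f"epoch {epoch + 1}: loss {epoch_losses[-1]:.4f}")
    return epoch_losses


# Tiny synthetic run: 16 random "images" and a one-layer stand-in model.
torch.manual_seed(0)
images = torch.randn(16, 3, 32, 32)
labels = torch.randint(0, 10, (16,))
toy_loader = DataLoader(TensorDataset(images, labels), batch_size=8)
toy_model = nn.Sequential(nn.Flatten(), nn.Linear(3 * 32 * 32, 10))
losses = train(toy_model, toy_loader, epochs=2)
```

Returning the per-epoch losses makes it easy to plot the curve and check whether it has flattened out or is still falling.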
Saving and Loading the Model
- Saving: The trained model's state is saved to a file.
- Loading: The model is reinitialized and loaded with the saved weights.
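A minimal sketch of the save and load round trip using the state dictionary. The file name `trained_net.pth`, the temporary directory, and the small `nn.Linear` stand-in model are hypothetical; in the project the model would be the trained classifier.

```python
import os
import tempfile

import torch
import torch.nn as nn

# Stand-in for the trained network.
model = nn.Linear(4, 2)

# Save only the learned parameters (the state dict), not the whole object.
path = os.path.join(tempfile.mkdtemp(), "trained_net.pth")
torch.save(model.state_dict(), path)

# Reinitialize a model with the same architecture, then load the weights.
restored = nn.Linear(4, 2)
restored.load_state_dict(torch.load(path, map_location="cpu"))
restored.eval()  # switch to inference mode before predicting

# True if the restored weights match the saved ones.
print(torch.equal(model.weight, restored.weight))
```

Loading a state dict requires constructing the same architecture first, which is why the model is reinitialized before the weights are loaded.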
Evaluating the Model
The model's accuracy is computed on the test set without updating the weights.
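The evaluation can be sketched as below. The `evaluate` helper is a hypothetical name, and the toy data uses `nn.Identity` as a stand-in model so the inputs themselves act as class scores and the example runs without the real dataset.

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset


def evaluate(model, loader):
    """Return the accuracy of `model` on `loader` as a percentage."""
    model.eval()
    correct, total = 0, 0
    with torch.no_grad():  # no gradients, so no weights are updated
        for inputs, labels in loader:
            outputs = model(inputs)
            predicted = outputs.argmax(dim=1)  # index of the highest score
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    return 100.0 * correct / total


# Row i of the identity matrix has its largest score at index i,
# so the stand-in model predicts classes 0, 1, 2, 3.
scores = torch.eye(4)
labels = torch.tensor([0, 1, 2, 0])  # the last label is deliberately wrong
toy_loader = DataLoader(TensorDataset(scores, labels), batch_size=2)
accuracy = evaluate(nn.Identity(), toy_loader)
print(f"accuracy: {accuracy:.1f}%")  # 3 of 4 correct -> 75.0%
```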
Using the Model for Predictions
To test the model on images from outside the dataset, we could do something like this:
- Image Transformation: Resizes and normalizes new images for prediction.
- Prediction: Loads images, performs predictions, and prints the results.