Bounding Box Prediction from Scratch using PyTorch
Multi-Task learning — Bounding Box Regression + Image Classification
Jul 7 · 5 min read
Image clicked by author
Object detection is a very popular task in computer vision: given an image, you predict (usually rectangular) boxes around the objects present in it and also recognize their types. There could be multiple objects in your image, and there are various state-of-the-art techniques and architectures to tackle this problem, like Faster R-CNN and YOLOv3.
This article talks about the case when there is only one object of interest present in an image. The focus here is more on how to read an image and its bounding box, resize and perform augmentations correctly, rather than on the model itself. The goal is to have a good grasp of the fundamental ideas behind object detection, which you can extend to get a better understanding of the more complex techniques.
Here’s a link to the notebook consisting of all the code I’ve used for this article: https://jovian.ml/aakanksha-ns/road-signs-bounding-box-prediction
If you’re new to Deep Learning or PyTorch, or just need a refresher, this might interest you:
Problem Statement
Given an image consisting of a road sign, predict a bounding box around the road sign and identify the type of road sign.
There are four distinct classes these signs could belong to:
- Traffic Light
- Stop
- Speed Limit
- Crosswalk
This is called a multi-task learning problem, as it involves performing two tasks: 1) regression to find the bounding box coordinates, and 2) classification to identify the type of road sign.
Sample images. Source
Dataset
I’ve used the Road Sign Detection Dataset from Kaggle:
It consists of 877 images. It’s a pretty imbalanced dataset, with most images belonging to the speed limit class, but since we’re more focused on the bounding box prediction, we can ignore the imbalance.
Loading the Data
The annotations for each image were stored in separate `.xml` files. I followed these steps to create the training dataframe:
- Walk through the training directory to get a list of all the `.xml` files.
- Parse each `.xml` file using `xml.etree.ElementTree`.
- Create a dictionary consisting of `filepath`, `width`, `height`, the bounding box coordinates (`xmin`, `xmax`, `ymin`, `ymax`) and `class` for each image, and append the dictionary to a list.
- Create a pandas dataframe using the list of dictionaries of image stats.
- Label encode the `class` column.
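The steps above can be sketched as follows. The VOC-style XML field names (`filename`, `size`, `object/bndbox`) are assumptions based on the dataset's annotation format, and the function names are mine, not the article's code:

```python
import os
import xml.etree.ElementTree as ET

import pandas as pd


def parse_annotation(xml_path):
    # Parse one VOC-style annotation file into a flat dictionary of image stats
    root = ET.parse(xml_path).getroot()
    box = root.find("object/bndbox")
    return {
        "filepath": root.find("filename").text,
        "width": int(root.find("size/width").text),
        "height": int(root.find("size/height").text),
        "xmin": int(box.find("xmin").text),
        "ymin": int(box.find("ymin").text),
        "xmax": int(box.find("xmax").text),
        "ymax": int(box.find("ymax").text),
        "class": root.find("object/name").text,
    }


def build_dataframe(anno_dir):
    # Walk the annotation directory, parse every .xml file, build the dataframe
    rows = [
        parse_annotation(os.path.join(anno_dir, f))
        for f in sorted(os.listdir(anno_dir))
        if f.endswith(".xml")
    ]
    df = pd.DataFrame(rows)
    # Label-encode the class column (category codes stand in for sklearn's LabelEncoder)
    df["class"] = df["class"].astype("category").cat.codes
    return df
```

Any method of turning the class strings into integers works here; category codes keep the sketch dependency-free.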
Resizing Images and Bounding Boxes
Since training a computer vision model needs images to be of the same size, we need to resize our images and their corresponding bounding boxes. Resizing an image is straightforward but resizing the bounding box is a little tricky because each box is relative to an image and its dimensions.
Here’s how resizing a bounding box works:
- Convert the bounding box into an image (called a mask) of the same size as the image it corresponds to. This mask would just have `0` for the background and `1` for the area covered by the bounding box.
- Resize the mask to the required dimensions.
- Extract bounding box coordinates from the resized mask.
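These three steps can be sketched with NumPy alone (using a simple nearest-neighbour resize for illustration; in practice you'd resize the mask with the same image library you use for the images):

```python
import numpy as np


def bbox_to_mask(bbox, rows, cols):
    # Step 1: paint the box (xmin, ymin, xmax, ymax) as 1s on a zero background
    xmin, ymin, xmax, ymax = bbox
    mask = np.zeros((rows, cols), dtype=np.uint8)
    mask[ymin:ymax, xmin:xmax] = 1
    return mask


def resize_mask(mask, new_rows, new_cols):
    # Step 2: nearest-neighbour resize via index mapping (no external deps)
    r_idx = np.arange(new_rows) * mask.shape[0] // new_rows
    c_idx = np.arange(new_cols) * mask.shape[1] // new_cols
    return mask[np.ix_(r_idx, c_idx)]


def mask_to_bbox(mask):
    # Step 3: recover the box from the 1-pixels of the resized mask
    ys, xs = np.where(mask)
    return int(xs.min()), int(ys.min()), int(xs.max()) + 1, int(ys.max()) + 1
```

For example, a box (20, 10, 60, 40) on a 100×200 image becomes (10, 5, 30, 20) after halving both dimensions, exactly as you'd expect from scaling the coordinates directly.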
Data Augmentation
Data augmentation is a technique that helps the model generalize better by creating new training images from variations of the existing ones. We have only 800 images in our current training set, so data augmentation is very important to ensure our model doesn’t overfit.
For this problem, I’ve used flip, rotation, center crop and random crop. I’ve talked about various data augmentation techniques in this article:
The only thing to remember here is ensuring that the bounding box is also transformed the same way as the image. To do this we follow the same approach as resizing — convert bounding box to a mask, apply the same transformations to the mask as the original image, and extract the bounding box coordinates.
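For example, a horizontal flip and a center crop can be applied to the image and the mask with the exact same indexing, so the box stays aligned with its object (a NumPy sketch, not the article's code):

```python
import numpy as np


def hflip(image, mask):
    # Flip the image and its bounding-box mask with the same operation
    return np.fliplr(image).copy(), np.fliplr(mask).copy()


def center_crop(image, mask, crop_rows, crop_cols):
    # Take the same centered window out of both arrays
    r0 = (image.shape[0] - crop_rows) // 2
    c0 = (image.shape[1] - crop_cols) // 2
    return (
        image[r0 : r0 + crop_rows, c0 : c0 + crop_cols],
        mask[r0 : r0 + crop_rows, c0 : c0 + crop_cols],
    )
```

Rotation works the same way: whatever transform hits the image pixels also hits the mask pixels, and the box is re-extracted from the transformed mask afterwards.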
PyTorch Dataset
Now that we have our data augmentations in place, we can do the train-validation split and create our PyTorch dataset. We normalize the images using ImageNet stats because we’re using a pre-trained ResNet model and apply data augmentations in our dataset while training.
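A minimal version of such a dataset, assuming the images have already been resized into H×W×3 float arrays in [0, 1] (the file loading and augmentation plumbing from the previous sections is omitted to keep the sketch self-contained):

```python
import numpy as np
import torch
from torch.utils.data import Dataset

# Standard ImageNet normalization stats for pre-trained torchvision models
IMAGENET_MEAN = np.array([0.485, 0.456, 0.406])
IMAGENET_STD = np.array([0.229, 0.224, 0.225])


class RoadSignDataset(Dataset):
    def __init__(self, images, labels, bboxes):
        # images: list of HxWx3 float arrays in [0, 1]
        # labels: label-encoded class ids; bboxes: (xmin, ymin, xmax, ymax)
        self.images, self.labels, self.bboxes = images, labels, bboxes

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        # Normalize with ImageNet stats, then reorder HxWxC -> CxHxW for PyTorch
        img = (self.images[idx] - IMAGENET_MEAN) / IMAGENET_STD
        x = torch.tensor(img, dtype=torch.float32).permute(2, 0, 1)
        y_bb = torch.tensor(self.bboxes[idx], dtype=torch.float32)
        return x, self.labels[idx], y_bb
```

In the real pipeline the augmentations from the previous section would run inside `__getitem__` (for the training split only) before normalization.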
PyTorch Model
For the model, I’ve used a very simple pre-trained ResNet-34 model. Since we have two tasks to accomplish here, there are two final layers — the bounding box regressor and the image classifier.
Training
For the loss, we need to take into account both the classification loss and the bounding box regression loss, so we use a combination of cross-entropy and L1 loss (the sum of all the absolute differences between the true and predicted coordinates). I’ve scaled the L1 loss down by a factor of 1000 to keep both the classification and regression losses in a similar range. Apart from this, it’s a standard PyTorch training loop (using a GPU):
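A minimal sketch of that combined loss; the scale factor `C=1000` matches the article, while the function name and `reduction` choices are mine:

```python
import torch
import torch.nn.functional as F


def combined_loss(class_logits, bb_preds, class_targets, bb_targets, C=1000):
    # Cross-entropy for the classification head
    loss_class = F.cross_entropy(class_logits, class_targets, reduction="sum")
    # Summed L1 (absolute coordinate error) for the box-regression head
    loss_bb = F.l1_loss(bb_preds, bb_targets, reduction="sum")
    # Scale the box loss down so both terms sit in a similar range
    return loss_class + loss_bb / C
```

With a confident, correct class prediction, the total is dominated by the scaled box term: a 10-pixel coordinate error contributes 10/1000 = 0.01 to the loss.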
Prediction on Test Images
Now that we’re done with training, we can pick a random image and test our model on it. Even though we had a fairly small number of training images, we end up getting a pretty decent prediction on our test image.
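Inference then amounts to a forward pass and an argmax over the class logits; this helper assumes a model returning `(class_logits, bbox_preds)` as sketched above:

```python
import torch


def predict(model, image_tensor, class_names):
    # Run a single CxHxW image through a two-headed model and decode both outputs
    model.eval()
    with torch.no_grad():
        class_logits, bb_pred = model(image_tensor.unsqueeze(0))  # add batch dim
    label = class_names[class_logits.argmax(dim=1).item()]
    return label, bb_pred.squeeze(0).tolist()
```

The image tensor must go through the same resizing and ImageNet normalization as the training data, and the predicted box is in the resized image's coordinates.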
It’ll be a fun exercise to take a real photo using your phone and test out the model. Another interesting experiment would be to not perform any data augmentations and train the model and compare the two models.
Conclusion
Now that we’ve covered the fundamentals of object detection and implemented it from scratch, you can extend these ideas to the multi-object case and try out more complex models like R-CNN and YOLO! Also, check out this super cool library called albumentations to perform data augmentations easily.
References
- Deep Learning summer elective at the University of San Francisco’s Master’s in Data Science program
- https://www.usfca.edu/data-institute/certificates/fundamentals-deep-learning