
Bounding Box Prediction from Scratch using PyTorch

source link: https://towardsdatascience.com/bounding-box-prediction-from-scratch-using-pytorch-a8525da51ddc?gi=a3cbe7eff990


Multi-Task learning — Bounding Box Regression + Image Classification

Jul 7 · 5 min read


Image clicked by author

Object detection is a very popular task in computer vision where, given an image, you predict (usually rectangular) boxes around the objects present in the image and also recognize their types. There can be multiple objects in an image, and there are various state-of-the-art techniques and architectures to tackle this problem, like Faster R-CNN and YOLOv3.

This article talks about the case when there is only one object of interest present in an image. The focus here is more on how to read an image and its bounding box, resize and perform augmentations correctly, rather than on the model itself. The goal is to have a good grasp of the fundamental ideas behind object detection, which you can extend to get a better understanding of the more complex techniques.

Here’s a link to the notebook consisting of all the code I’ve used for this article: https://jovian.ml/aakanksha-ns/road-signs-bounding-box-prediction

If you’re new to Deep Learning or PyTorch, or just need a refresher, this might interest you:

Problem Statement

Given an image consisting of a road sign, predict a bounding box around the road sign and identify the type of road sign.

There are four distinct classes these signs could belong to:

  • Traffic Light
  • Stop
  • Speed Limit
  • Crosswalk

This is called a multi-task learning problem because it involves performing two tasks: 1) regression to find the bounding box coordinates, and 2) classification to identify the type of road sign.

Sample images. Source

Dataset

I’ve used the Road Sign Detection dataset from Kaggle.

It consists of 877 images. It’s a pretty imbalanced dataset, with most images belonging to the speed limit class, but since we’re more focused on the bounding box prediction, we can ignore the imbalance.

Loading the Data

The annotations for each image are stored in separate XML files. I followed these steps to create the training dataframe:

  • Walk through the training directory to get a list of all the .xml files.
  • Parse each .xml file using xml.etree.ElementTree.
  • Create a dictionary consisting of filepath, width, height, the bounding box coordinates (xmin, xmax, ymin, ymax), and class for each image, and append the dictionary to a list.
  • Create a pandas dataframe using the list of dictionaries of image stats.
  • Label encode the class column.
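
The notebook's own parsing code isn't reproduced here, but the steps above can be sketched as follows (the function names `parse_annotation` and `build_dataframe` are my own, not from the notebook):

```python
import glob
import os
import xml.etree.ElementTree as ET

import pandas as pd

def parse_annotation(xml_path):
    """Parse one Pascal VOC-style annotation file into a dict of image stats."""
    root = ET.parse(xml_path).getroot()
    bbox = root.find("object/bndbox")
    return {
        "filename": root.find("filename").text,
        "width": int(root.find("size/width").text),
        "height": int(root.find("size/height").text),
        "class": root.find("object/name").text,
        "xmin": int(bbox.find("xmin").text),
        "ymin": int(bbox.find("ymin").text),
        "xmax": int(bbox.find("xmax").text),
        "ymax": int(bbox.find("ymax").text),
    }

def build_dataframe(anno_dir):
    """Walk the annotation directory and build one dataframe row per image."""
    records = [parse_annotation(p) for p in sorted(glob.glob(os.path.join(anno_dir, "*.xml")))]
    df = pd.DataFrame(records)
    # label encode the class column: each class name becomes an integer code
    df["class"] = df["class"].astype("category").cat.codes
    return df
```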

Resizing Images and Bounding Boxes

Since training a computer vision model needs images to be of the same size, we need to resize our images and their corresponding bounding boxes. Resizing an image is straightforward but resizing the bounding box is a little tricky because each box is relative to an image and its dimensions.

Here’s how resizing a bounding box works:

  • Convert the bounding box into an image (called mask) of the same size as the image it corresponds to. This mask would just have 0 for background and 1 for the area covered by the bounding box.
Original Image
Mask of the bounding box
  • Resize the mask to the required dimensions.
  • Extract bounding box coordinates from the resized mask.
Helper functions to create mask from bounding box, extract bounding box coordinates from a mask
Function to resize an image, write to a new path, and get resized bounding box coordinates
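
A minimal sketch of these helpers. The notebook uses OpenCV for resizing; here a nearest-neighbour resize in plain NumPy keeps the example dependency-light. Boxes are assumed to be stored as (ymin, xmin, ymax, xmax):

```python
import numpy as np

def create_mask(bb, shape):
    """Binary mask of the same height/width as the image: 1 inside the box, 0 elsewhere."""
    ymin, xmin, ymax, xmax = bb
    mask = np.zeros(shape[:2], dtype=np.uint8)
    mask[int(ymin):int(ymax), int(xmin):int(xmax)] = 1
    return mask

def mask_to_bb(mask):
    """Recover (ymin, xmin, ymax, xmax) from a binary mask."""
    rows = np.any(mask, axis=1)
    cols = np.any(mask, axis=0)
    if not rows.any():
        return np.zeros(4, dtype=np.float32)
    ymin, ymax = np.where(rows)[0][[0, -1]]
    xmin, xmax = np.where(cols)[0][[0, -1]]
    return np.array([ymin, xmin, ymax + 1, xmax + 1], dtype=np.float32)

def resize_mask(mask, new_h, new_w):
    """Nearest-neighbour resize of a 2-D mask."""
    h, w = mask.shape
    row_idx = np.arange(new_h) * h // new_h
    col_idx = np.arange(new_w) * w // new_w
    return mask[row_idx[:, None], col_idx]

def resize_bb(bb, old_shape, new_h, new_w):
    """Resize a bounding box by resizing its mask and reading the box back out."""
    mask = create_mask(bb, old_shape)
    return mask_to_bb(resize_mask(mask, new_h, new_w))
```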

Data Augmentation

Data augmentation is a technique that helps the model generalize better by creating new training images as variations of the existing ones. We have only about 800 images in our training set, so data augmentation is very important to ensure our model doesn’t overfit.

For this problem, I’ve used flip, rotation, center crop and random crop. I’ve talked about various data augmentation techniques in this article:

The only thing to remember here is to ensure that the bounding box is transformed the same way as the image. To do this, we follow the same approach as for resizing: convert the bounding box to a mask, apply the same transformations to the mask as to the original image, and extract the bounding box coordinates from the result.

Helper functions to center crop and random crop an image
Transforming image and mask
Displaying bounding box
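
A sketch of the crop helpers; the key point is that the random crop uses one offset for both the image and the mask, so the box stays aligned with the image content (function names are illustrative):

```python
import numpy as np

def center_crop(arr, crop_h, crop_w):
    """Crop a crop_h x crop_w window from the centre of an image or mask."""
    h, w = arr.shape[:2]
    top, left = (h - crop_h) // 2, (w - crop_w) // 2
    return arr[top:top + crop_h, left:left + crop_w]

def random_crop(im, mask, crop_h, crop_w, rng=None):
    """Crop the image and its bounding-box mask at the SAME random offset."""
    if rng is None:
        rng = np.random.default_rng()
    h, w = im.shape[:2]
    top = int(rng.integers(0, h - crop_h + 1))
    left = int(rng.integers(0, w - crop_w + 1))
    return (im[top:top + crop_h, left:left + crop_w],
            mask[top:top + crop_h, left:left + crop_w])
```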

PyTorch Dataset

Now that we have our data augmentations in place, we can do the train-validation split and create our PyTorch dataset. We normalize the images using ImageNet stats because we’re using a pre-trained ResNet model and apply data augmentations in our dataset while training.

train-valid split
Creating train and valid datasets
Setting the batch size and creating data loaders
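
A sketch of such a dataset; `load_image` is a hypothetical callable that reads a path into an HxWx3 float array in [0, 1], and `transforms` applies the augmentations above to the image and box together:

```python
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader

# ImageNet channel statistics, used because the backbone is ImageNet-pretrained
IMAGENET_MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32)
IMAGENET_STD = np.array([0.229, 0.224, 0.225], dtype=np.float32)

class RoadSignDataset(Dataset):
    """Yields (image, class_id, bounding_box) triples."""

    def __init__(self, paths, bbs, ys, load_image, transforms=None):
        self.paths, self.bbs, self.ys = paths, bbs, ys
        self.load_image, self.transforms = load_image, transforms

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, idx):
        x = self.load_image(self.paths[idx])
        bb = np.asarray(self.bbs[idx], dtype=np.float32)
        if self.transforms is not None:
            x, bb = self.transforms(x, bb)          # augmentations must move the box too
        x = (x - IMAGENET_MEAN) / IMAGENET_STD      # normalize with ImageNet stats
        x = torch.from_numpy(x.transpose(2, 0, 1))  # HWC -> CHW
        return x, self.ys[idx], torch.from_numpy(bb)
```

The loaders then follow the usual pattern: `DataLoader(train_ds, batch_size=64, shuffle=True)` for training and an unshuffled loader for validation.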

PyTorch Model

For the model, I’ve used a very simple pre-trained ResNet-34. Since we have two tasks to accomplish here, there are two final layers — the bounding box regressor and the image classifier.

Training

For the loss, we need to take into account both the classification loss and the bounding box regression loss, so we use a combination of cross-entropy and L1 loss (the sum of the absolute differences between the true and predicted coordinates). I’ve scaled the L1 loss down by a factor of 1000 to keep the classification and regression losses in a similar range. Apart from this, it’s a standard PyTorch training loop (using a GPU):
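
The combined loss and training step might look like this (a sketch, not the notebook's exact loop):

```python
import torch
import torch.nn.functional as F

def train_epoch(model, optimizer, train_dl, device, C=1000):
    """One epoch over the combined objective. C scales the L1 box loss down so
    the classification and regression terms land in a similar range."""
    model.train()
    total, total_loss = 0, 0.0
    for x, y_class, y_bb in train_dl:
        x, y_class, y_bb = x.to(device), y_class.to(device), y_bb.to(device)
        out_class, out_bb = model(x)
        loss_class = F.cross_entropy(out_class, y_class, reduction="sum")
        # L1 loss: sum of absolute differences between true and predicted coordinates
        loss_bb = F.l1_loss(out_bb, y_bb, reduction="sum")
        loss = loss_class + loss_bb / C
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total += x.size(0)
        total_loss += loss.item()
    return total_loss / total
```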

Prediction on Test Images

Now that we’re done with training, we can pick a random image and test our model on it. Even though we had a fairly small number of training images, we end up getting a pretty decent prediction on our test image.
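
A minimal inference sketch, assuming the test image has already been resized and normalized the same way as the training data:

```python
import numpy as np
import torch

def predict(model, im, device):
    """Predict (class_id, box) for a single HxWx3 image that has already been
    resized and normalized like the training images."""
    model.eval()
    x = torch.from_numpy(im.transpose(2, 0, 1)).unsqueeze(0).to(device)
    with torch.no_grad():
        out_class, out_bb = model(x)
    return out_class.argmax(dim=1).item(), out_bb.squeeze(0).cpu().numpy()
```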

It would be a fun exercise to take a real photo with your phone and test the model on it. Another interesting experiment would be to train the model without any data augmentation and compare the two models.

Conclusion

Now that we’ve covered the fundamentals of object detection and implemented it from scratch, you can extend these ideas to the multi-object case and try out more complex models like RCNN and YOLO! Also, check out this super cool library called albumentations to perform data augmentations easily.
