The aim of this tutorial is to use the TensorFlow Object Detection API to detect custom objects. We will train the network to recognize a battery charging sign. (Why battery charging? Later, the trained network can be used in a robot to locate the charging point in a picture.) This is basically a condensed version of sentdex's TensorFlow tutorial series; I have listed the steps I followed to train on custom images, for quick reference.
Download files here
To train the model, we first need to collect training data. This can be done by collecting images from Google Images; I used the Chrome extension ‘Fatkun Batch Download Image’ to save images in bulk. Once the images are downloaded, download and install labelImg to annotate the training data.
git clone https://github.com/tzutalin/labelImg.git
cd labelImg
sudo apt-get install pyqt5-dev-tools
sudo pip3 install lxml
make qt5py3
python3 labelImg.py
Browse to the folder that contains the downloaded images. The idea is to create an XML label file for every image. Select the images one by one, click the create rectangle box button, draw the box, give it the label ‘charging sign’, and save it as an XML file (the default format).
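The XML files saved by labelImg have to be flattened into CSV before the TFRecord step. The following is a minimal sketch of that conversion, assuming the default Pascal VOC layout (`filename`, `size`, `object/bndbox`). The function names and the column order are illustrative assumptions, not the script used for this tutorial.

```python
import csv
import glob
import os
import xml.etree.ElementTree as ET

# Assumed column order; it must match whatever the TFRecord script reads.
COLUMNS = ["filename", "width", "height", "class", "xmin", "ymin", "xmax", "ymax"]


def rows_from_root(root):
    """Return one CSV row per <object> in a Pascal VOC annotation tree."""
    filename = root.findtext("filename")
    width = int(root.findtext("size/width"))
    height = int(root.findtext("size/height"))
    rows = []
    for obj in root.findall("object"):
        box = obj.find("bndbox")
        rows.append([
            filename, width, height, obj.findtext("name"),
            int(box.findtext("xmin")), int(box.findtext("ymin")),
            int(box.findtext("xmax")), int(box.findtext("ymax")),
        ])
    return rows


def folder_to_csv(xml_dir, csv_path):
    """Collect every *.xml file in xml_dir into a single CSV file."""
    with open(csv_path, "w", newline="") as f:
        writer = csv.writer(f)
        writer.writerow(COLUMNS)
        for path in sorted(glob.glob(os.path.join(xml_dir, "*.xml"))):
            writer.writerows(rows_from_root(ET.parse(path).getroot()))
```

Calling `folder_to_csv("images/train", "data/train_labels.csv")` would produce the kind of file the next step expects; the `images/train` folder name is an assumption.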
The next step is to generate TFRecord files for the train and test data from the CSV data (the XML annotations must first be converted to CSV; the sentdex series uses an xml_to_csv.py script for this). Use the modified generate_tfrecord.py for this step.
python3 generate_tfrecord.py --csv_input=data/train_labels.csv --output_path=data/train.record
python3 generate_tfrecord.py --csv_input=data/test_labels.csv --output_path=data/test.record
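The modified generate_tfrecord.py is not reproduced here. The sketch below only illustrates the two pieces that usually need attention when adapting such a script: mapping the class name to an integer id, and scaling the pixel box corners to the 0..1 range that the TFRecord examples store. The CSV column names, the `LABEL_IDS` table, and the sample row are assumptions made for illustration.

```python
import csv
import io

# Hypothetical label table. Ids start at 1; the Object Detection API
# reserves 0 for the background class.
LABEL_IDS = {"charging sign": 1}


def class_text_to_int(label):
    """Map a class name from the CSV to its integer id."""
    return LABEL_IDS[label]


def normalize_box(row):
    """Scale pixel box corners to the 0..1 range using the image size."""
    width = float(row["width"])
    height = float(row["height"])
    return {
        "xmin": float(row["xmin"]) / width,
        "xmax": float(row["xmax"]) / width,
        "ymin": float(row["ymin"]) / height,
        "ymax": float(row["ymax"]) / height,
    }


if __name__ == "__main__":
    # Made-up sample row, only for illustration.
    sample = io.StringIO(
        "filename,width,height,class,xmin,ymin,xmax,ymax\n"
        "img1.jpg,100,50,charging sign,10,5,60,25\n"
    )
    for row in csv.DictReader(sample):
        print(row["filename"], class_text_to_int(row["class"]), normalize_box(row))
```

If more classes are added later, the label table has to stay in step with the ids in the .pbtxt label map.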
If you get an error saying that the object_detection folder does not exist, export the path below. This tutorial requires the TensorFlow Object Detection API to be installed already. Please follow this link for more information.
# From tensorflow/models/research/
export PYTHONPATH=$PYTHONPATH:`pwd`:`pwd`/slim
Copy the data, training, images and ssd_mobilenet_v1_coco_11_06_2017 directories to the TensorFlow object_detection folder and start training.
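wget http://download.tensorflow.org/models/object_detection/ssd_mobilenet_v1_coco_11_06_2017.tar.gz
# Extract the pretrained model directory (skip if it is already present)
tar -xzf ssd_mobilenet_v1_coco_11_06_2017.tar.gz
python train.py --logtostderr --train_dir=training/ --pipeline_config_path=training/ssd_mobilenet_v1_pets.config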
ssd_mobilenet_v1_pets.config holds the paths to both TFRecord files, the pretrained model checkpoint, and the .pbtxt file that lists the classes to detect. The checkpoint files will be created inside the training directory.
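The .pbtxt label map is a short text file with one `item` entry per class, with ids starting at 1. As a minimal sketch, the helper below renders such entries from a list of class names; the function name is hypothetical and the file is just as easy to write by hand.

```python
def label_map_text(class_names):
    """Render class names as label-map 'item' entries, with ids starting at 1."""
    items = []
    for idx, name in enumerate(class_names, start=1):
        items.append("item {\n  id: %d\n  name: '%s'\n}\n" % (idx, name))
    return "\n".join(items)


if __name__ == "__main__":
    # Single class, as in this tutorial. Redirect the output to the .pbtxt
    # file that the config's label map path points to.
    print(label_map_text(["charging sign"]))
```

The number of entries should agree with the class count set in the config file.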
Next, we need to create a frozen inference graph from the latest checkpoint. Once that is done, use the inference program to detect the charging sign.
python export_inference_graph.py --input_type image_tensor --pipeline_config_path training/ssd_mobilenet_v1_pets.config --trained_checkpoint_prefix training/model.ckpt-9871 --output_directory charging_spot_det__inference_graph
python custom_inference.py
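The step number in `model.ckpt-9871` depends on where training stopped, so it will differ from run to run. The sketch below is one way to pick the most recent prefix from the file names in the training directory. It assumes the usual `model.ckpt-<step>.index` naming, and the helper name is hypothetical.

```python
import os
import re

# Matches checkpoint index files such as "model.ckpt-9871.index".
INDEX_PATTERN = re.compile(r"^model\.ckpt-(\d+)\.index$")


def latest_checkpoint_prefix(file_names):
    """Return 'model.ckpt-<highest step>' or None if no checkpoint is found."""
    steps = []
    for name in file_names:
        match = INDEX_PATTERN.match(name)
        if match:
            steps.append(int(match.group(1)))
    if not steps:
        return None
    return "model.ckpt-%d" % max(steps)


if __name__ == "__main__":
    # "training" is the directory passed to --train_dir above.
    print(latest_checkpoint_prefix(os.listdir("training")))
```

The printed value, prefixed with `training/`, is what goes into `--trained_checkpoint_prefix`.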
Since my training data set was small (fewer than 100 images) and there was only one class, the inference is unreliable: it identifies almost everything as a charging sign. This can be improved by adding more classes and more training data.
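Short of retraining, one common way to reduce spurious detections is to discard boxes whose score falls below a threshold. The sketch below assumes the inference program yields parallel lists of boxes, scores and class ids; the function name and the 0.5 default are illustrative choices. It will not help if the network assigns high scores to wrong objects.

```python
def filter_detections(boxes, scores, classes, min_score=0.5):
    """Keep only detections whose score is at least min_score."""
    kept = []
    for box, score, cls in zip(boxes, scores, classes):
        if score >= min_score:
            kept.append((box, score, cls))
    return kept


if __name__ == "__main__":
    # Made-up detections, only for illustration.
    boxes = [[0.1, 0.1, 0.4, 0.4], [0.5, 0.5, 0.9, 0.9]]
    scores = [0.92, 0.18]
    classes = [1, 1]
    print(filter_detections(boxes, scores, classes))
```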