root/trunk/opencv/samples/cpp/bagofwords_classification.cpp @ 3714

Revision 3714, 112.4 KB (checked in by mdim, 4 years ago)

added a sample on bag-of-words (BOW) usage for image classification (training and testing are done on the Pascal VOC dataset)

  • Property svn:eol-style set to native
Line 
1#include <highgui.h>
2#include "opencv2/imgproc/imgproc.hpp"
3#include "opencv2/features2d/features2d.hpp"
4#include "opencv2/ml/ml.hpp"
5#include <fstream>
6#include <iostream>
7#include <memory>
8
9#if defined WIN32 || defined _WIN32
10#include "sys/types.h"
11#endif
12#include <sys/stat.h>
13
14#define DEBUG_DESC_PROGRESS
15
16using namespace cv;
17using namespace std;
18
19const string paramsFile = "params.xml";
20const string vocabularyFile = "vocabulary.xml.gz";
21const string bowImageDescriptorsDir = "/bowImageDescriptors";
22const string svmsDir = "/svms";
23const string plotsDir = "/plots";
24
//Create a single directory (best-effort: errors are ignored, e.g. if it already exists)
void makeDir( const string& dir )
{
#if defined WIN32 || defined _WIN32
    CreateDirectory( dir.c_str(), 0 );
#else
    //rwx for user+group, r-x for others
    const mode_t mode = S_IRWXU | S_IRWXG | S_IROTH | S_IXOTH;
    mkdir( dir.c_str(), mode );
#endif
}
33
34void makeUsedDirs( const string& rootPath )
35{
36    makeDir(rootPath + bowImageDescriptorsDir);
37    makeDir(rootPath + svmsDir);
38    makeDir(rootPath + plotsDir);
39}
40
41/****************************************************************************************\
42*                    Classes to work with PASCAL VOC dataset                             *
43\****************************************************************************************/
44//
45// TODO: refactor this part of the code
46//
47
48
//used to specify the (sub-)dataset over which operations are performed
//CV_OBD_TRAIN selects the training image set, CV_OBD_TEST the test image set
enum ObdDatasetType {CV_OBD_TRAIN, CV_OBD_TEST};
51
//a single annotated object instance: its class name and its location in the image
class ObdObject
{
public:
    string object_class;   //VOC object class identifier string (e.g. "car")
    Rect boundingBox;      //bounding box in image pixel coordinates
};
58
//extended object data specific to VOC
//object pose as annotated in the VOC ground truth
enum VocPose {CV_VOC_POSE_UNSPECIFIED, CV_VOC_POSE_FRONTAL, CV_VOC_POSE_REAR, CV_VOC_POSE_LEFT, CV_VOC_POSE_RIGHT};
//per-object VOC annotation flags (parsed from the annotation xml)
class VocObjectData
{
public:
    bool difficult;   //marked 'difficult' in the ground truth (such objects are normally ignored in evaluation)
    bool occluded;    //object is partially occluded by another object
    bool truncated;   //object extends beyond the image boundary
    VocPose pose;     //annotated viewing pose of the object
};
//enum VocDataset {CV_VOC2007, CV_VOC2008, CV_VOC2009, CV_VOC2010};
//output target for precision-recall plots: on-screen gnuplot window or PNG file
enum VocPlotType {CV_VOC_PLOT_SCREEN, CV_VOC_PLOT_PNG};
//per-image classifier ground truth: class absent, present only as 'difficult', or present
enum VocGT {CV_VOC_GT_NONE, CV_VOC_GT_DIFFICULT, CV_VOC_GT_PRESENT};
//how the score threshold for a confusion-matrix row is chosen (by recall level or explicit value)
enum VocConfCond {CV_VOC_CCOND_RECALL, CV_VOC_CCOND_SCORETHRESH};
//the VOC challenge task being evaluated
enum VocTask {CV_VOC_TASK_CLASSIFICATION, CV_VOC_TASK_DETECTION};
74
//a dataset image: its unique VOC identifier and the path of the image file on disk
class ObdImage
{
public:
    ObdImage(string p_id, string p_path)
    {
        id = p_id;
        path = p_path;
    }
    string id;     //VOC unique image identifier
    string path;   //filesystem path of the image file
};
82
83//used by getDetectorGroundTruth to sort a two dimensional list of floats in descending order
//used by getDetectorGroundTruth to sort a two dimensional list of floats in descending order
//(stores a detection score together with the image/object indices it came from;
// sorting then reversing the vector yields descending score order)
class ObdScoreIndexSorter
{
public:
    float score;     //detection confidence score
    int image_idx;   //index of the source image
    int obj_idx;     //index of the detection within that image
    //orders by score only; index fields do not participate in the comparison
    bool operator < (const ObdScoreIndexSorter& compare) const
    {
        return score < compare.score;
    }
};
92
//Encapsulates a PASCAL VOC dataset located at a given root path: provides
//image lists and object annotations, classifier/detector ground truth,
//VOC-compatible result-file I/O, and precision-recall / confusion-matrix
//computation for the classification and detection tasks.
class VocData
{
public:
    //construction immediately parses the dataset description via initVoc()
    VocData( const string& vocPath, bool useTestDataset )
        { initVoc( vocPath, useTestDataset ); }
    ~VocData(){}
    /* functions for returning classification/object data for multiple images given an object class */
    void getClassImages(const string& obj_class, const ObdDatasetType dataset, vector<ObdImage>& images, vector<char>& object_present);
    void getClassObjects(const string& obj_class, const ObdDatasetType dataset, vector<ObdImage>& images, vector<vector<ObdObject> >& objects);
    void getClassObjects(const string& obj_class, const ObdDatasetType dataset, vector<ObdImage>& images, vector<vector<ObdObject> >& objects, vector<vector<VocObjectData> >& object_data, vector<VocGT>& ground_truth);
    /* functions for returning object data for a single image given an image id */
    ObdImage getObjects(const string& id, vector<ObdObject>& objects);
    ObdImage getObjects(const string& id, vector<ObdObject>& objects, vector<VocObjectData>& object_data);
    ObdImage getObjects(const string& obj_class, const string& id, vector<ObdObject>& objects, vector<VocObjectData>& object_data, VocGT& ground_truth);
    /* functions for returning the ground truth (present/absent) for groups of images */
    void getClassifierGroundTruth(const string& obj_class, const vector<ObdImage>& images, vector<char>& ground_truth);
    void getClassifierGroundTruth(const string& obj_class, const vector<string>& images, vector<char>& ground_truth);
    //returns the total number of object hits in the ground truth (for p-r normalization)
    int getDetectorGroundTruth(const string& obj_class, const ObdDatasetType dataset, const vector<ObdImage>& images, const vector<vector<Rect> >& bounding_boxes, const vector<vector<float> >& scores, vector<vector<char> >& ground_truth, vector<vector<char> >& detection_difficult, bool ignore_difficult = true);
    /* functions for writing VOC-compatible results files */
    void writeClassifierResultsFile(const string& out_dir, const string& obj_class, const ObdDatasetType dataset, const vector<ObdImage>& images, const vector<float>& scores, const int competition = 1, const bool overwrite_ifexists = false);
    /* functions for calculating metrics from a set of classification/detection results */
    string getResultsFilename(const string& obj_class, const VocTask task, const ObdDatasetType dataset, const int competition = -1, const int number = -1);
    void calcClassifierPrecRecall(const string& obj_class, const vector<ObdImage>& images, const vector<float>& scores, vector<float>& precision, vector<float>& recall, float& ap, vector<size_t>& ranking);
    void calcClassifierPrecRecall(const string& obj_class, const vector<ObdImage>& images, const vector<float>& scores, vector<float>& precision, vector<float>& recall, float& ap);
    void calcClassifierPrecRecall(const string& input_file, vector<float>& precision, vector<float>& recall, float& ap, bool outputRankingFile = false);
    /* functions for calculating confusion matrices */
    void calcClassifierConfMatRow(const string& obj_class, const vector<ObdImage>& images, const vector<float>& scores, const VocConfCond cond, const float threshold, vector<string>& output_headers, vector<float>& output_values);
    void calcDetectorConfMatRow(const string& obj_class, const ObdDatasetType dataset, const vector<ObdImage>& images, const vector<vector<float> >& scores, const vector<vector<Rect> >& bounding_boxes, const VocConfCond cond, const float threshold, vector<string>& output_headers, vector<float>& output_values, bool ignore_difficult = true);
    /* functions for outputting gnuplot output files */
    void savePrecRecallToGnuplot(const string& output_file, const vector<float>& precision, const vector<float>& recall, const float ap, const string title = string(), const VocPlotType plot_type = CV_VOC_PLOT_SCREEN);
    /* functions for reading in result/ground truth files */
    void readClassifierGroundTruth(const string& obj_class, const ObdDatasetType dataset, vector<ObdImage>& images, vector<char>& object_present);
    void readClassifierResultsFile(const std:: string& input_file, vector<ObdImage>& images, vector<float>& scores);
    void readDetectorResultsFile(const string& input_file, vector<ObdImage>& images, vector<vector<float> >& scores, vector<vector<Rect> >& bounding_boxes);
    /* functions for getting dataset info */
    const vector<string>& getObjectClasses();
    string getResultsDirectory();
protected:
    void initVoc( const string& vocPath, const bool useTestDataset );
    void initVoc2007to2010( const string& vocPath, const bool useTestDataset);
    void readClassifierGroundTruth(const string& filename, vector<string>& image_codes, vector<char>& object_present);
    void readClassifierResultsFile(const string& input_file, vector<string>& image_codes, vector<float>& scores);
    void readDetectorResultsFile(const string& input_file, vector<string>& image_codes, vector<vector<float> >& scores, vector<vector<Rect> >& bounding_boxes);
    void extractVocObjects(const string filename, vector<ObdObject>& objects, vector<VocObjectData>& object_data);
    string getImagePath(const string& input_str);

    void getClassImages_impl(const string& obj_class, const string& dataset_str, vector<ObdImage>& images, vector<char>& object_present);
    void calcPrecRecall_impl(const vector<char>& ground_truth, const vector<float>& scores, vector<float>& precision, vector<float>& recall, float& ap, vector<size_t>& ranking, int recall_normalization = -1);

    //test two bounding boxes to see if they meet the overlap criteria defined in the VOC documentation
    float testBoundingBoxesForOverlap(const Rect detection, const Rect ground_truth);
    //extract class and dataset name from a VOC-standard classification/detection results filename
    void extractDataFromResultsFilename(const string& input_file, string& class_name, string& dataset_name);
    //get classifier ground truth for a single image
    bool getClassifierGroundTruthImage(const string& obj_class, const string& id);

    //utility functions
    void getSortOrder(const vector<float>& values, vector<size_t>& order, bool descending = true);
    int stringToInteger(const string input_str);
    void readFileToString(const string filename, string& file_contents);
    string integerToString(const int input_int);
    string checkFilenamePathsep(const string filename, bool add_trailing_slash = false);
    void convertImageCodesToObdImages(const vector<string>& image_codes, vector<ObdImage>& images);
    int extractXMLBlock(const string src, const string tag, const int searchpos, string& tag_contents);
    //utility sorter: orders (index, iterator-to-score) pairs by descending score
    struct orderingSorter
    {
        bool operator ()(std::pair<size_t, vector<float>::const_iterator> const& a, std::pair<size_t, vector<float>::const_iterator> const& b)
        {
            return (*a.second) > (*b.second);
        }
    };
    //data members
    string m_vocPath;   //root path of the VOC dataset
    string m_vocName;   //dataset name -- presumably derived from the path in initVoc; verify
    //string m_resPath;

    //path templates; occurrences of "%s" are substituted with an image id,
    //class name and/or dataset name before use (see getObjects and getClassImages_impl)
    string m_annotation_path;
    string m_image_path;
    string m_imageset_path;
    string m_class_imageset_path;

    //NOTE(review): presumably a cache of the last classifier ground-truth file
    //read by getClassifierGroundTruthImage -- confirm against its definition
    vector<string> m_classifier_gt_all_ids;
    vector<char> m_classifier_gt_all_present;
    string m_classifier_gt_class;

    //data members
    string m_train_set;   //name of the training image set (e.g. "train" or "trainval")
    string m_test_set;    //name of the test image set

    vector<string> m_object_classes;   //all object class identifier strings in the dataset


    float m_min_overlap;   //minimum bounding-box overlap ratio -- presumably used by testBoundingBoxesForOverlap; verify
    bool m_sampled_ap;     //whether average precision is computed by sampling -- TODO confirm against calcPrecRecall_impl
};
189
190
191//Return the classification ground truth data for all images of a given VOC object class
192//--------------------------------------------------------------------------------------
193//INPUTS:
194// - obj_class          The VOC object class identifier string
195// - dataset            Specifies whether to extract images from the training or test set
196//OUTPUTS:
197// - images             An array of ObdImage containing info of all images extracted from the ground truth file
198// - object_present     An array of bools specifying whether the object defined by 'obj_class' is present in each image or not
199//NOTES:
200// This function is primarily useful for the classification task, where only
201// whether a given object is present or not in an image is required, and not each object instance's
202// position etc.
203void VocData::getClassImages(const string& obj_class, const ObdDatasetType dataset, vector<ObdImage>& images, vector<char>& object_present)
204{
205    string dataset_str;
206    //generate the filename of the classification ground-truth textfile for the object class
207    if (dataset == CV_OBD_TRAIN)
208    {
209        dataset_str = m_train_set;
210    } else {
211        dataset_str = m_test_set;
212    }
213
214    getClassImages_impl(obj_class, dataset_str, images, object_present);
215}
216
217void VocData::getClassImages_impl(const string& obj_class, const string& dataset_str, vector<ObdImage>& images, vector<char>& object_present)
218{
219    //generate the filename of the classification ground-truth textfile for the object class
220    string gtFilename = m_class_imageset_path;
221    gtFilename.replace(gtFilename.find("%s"),2,obj_class);
222    gtFilename.replace(gtFilename.find("%s"),2,dataset_str);
223
224    //parse the ground truth file, storing in two separate vectors
225    //for the image code and the ground truth value
226    vector<string> image_codes;
227    readClassifierGroundTruth(gtFilename, image_codes, object_present);
228
229    //prepare output arrays
230    images.clear();
231
232    convertImageCodesToObdImages(image_codes, images);
233}
234
235//Return the object data for all images of a given VOC object class
236//-----------------------------------------------------------------
237//INPUTS:
238// - obj_class          The VOC object class identifier string
239// - dataset            Specifies whether to extract images from the training or test set
240//OUTPUTS:
241// - images             An array of ObdImage containing info of all images in chosen dataset (tag, path etc.)
242// - objects            Contains the extended object info (bounding box etc.) for each object instance in each image
243// - object_data        Contains VOC-specific extended object info (marked difficult etc.)
244// - ground_truth       Specifies whether there are any difficult/non-difficult instances of the current
245//                          object class within each image
246//NOTES:
247// This function returns extended object information in addition to the absent/present
248// classification data returned by getClassImages. The objects returned for each image in the 'objects'
249// array are of all object classes present in the image, and not just the class defined by 'obj_class'.
250// 'ground_truth' can be used to determine quickly whether an object instance of the given class is present
251// in an image or not.
252void VocData::getClassObjects(const string& obj_class, const ObdDatasetType dataset, vector<ObdImage>& images, vector<vector<ObdObject> >& objects)
253{
254    vector<vector<VocObjectData> > object_data;
255    vector<VocGT> ground_truth;
256
257    getClassObjects(obj_class,dataset,images,objects,object_data,ground_truth);
258}
259
260void VocData::getClassObjects(const string& obj_class, const ObdDatasetType dataset, vector<ObdImage>& images, vector<vector<ObdObject> >& objects, vector<vector<VocObjectData> >& object_data, vector<VocGT>& ground_truth)
261{
262    //generate the filename of the classification ground-truth textfile for the object class
263    string gtFilename = m_class_imageset_path;
264    gtFilename.replace(gtFilename.find("%s"),2,obj_class);
265    if (dataset == CV_OBD_TRAIN)
266    {
267        gtFilename.replace(gtFilename.find("%s"),2,m_train_set);
268    } else {
269        gtFilename.replace(gtFilename.find("%s"),2,m_test_set);
270    }
271
272    //parse the ground truth file, storing in two separate vectors
273    //for the image code and the ground truth value
274    vector<string> image_codes;
275    vector<char> object_present;
276    readClassifierGroundTruth(gtFilename, image_codes, object_present);
277
278    //prepare output arrays
279    images.clear();
280    objects.clear();
281    object_data.clear();
282    ground_truth.clear();
283
284    string annotationFilename;
285    vector<ObdObject> image_objects;
286    vector<VocObjectData> image_object_data;
287    VocGT image_gt;
288
289    //transfer to output arrays and read in object data for each image
290    for (size_t i = 0; i < image_codes.size(); ++i)
291    {
292        ObdImage image = getObjects(obj_class, image_codes[i], image_objects, image_object_data, image_gt);
293
294        images.push_back(image);
295        objects.push_back(image_objects);
296        object_data.push_back(image_object_data);
297        ground_truth.push_back(image_gt);
298    }
299}
300
301//Return ground truth data for the objects present in an image with a given UID
302//-----------------------------------------------------------------------------
303//INPUTS:
304// - id                 VOC Dataset unique identifier (string code in form YYYY_XXXXXX where YYYY is the year)
305//OUTPUTS:
306// - obj_class (*3)     Specifies the object class to use to resolve 'ground_truth'
307// - objects            Contains the extended object info (bounding box etc.) for each object in the image
308// - object_data (*2,3) Contains VOC-specific extended object info (marked difficult etc.)
309// - ground_truth (*3)  Specifies whether there are any difficult/non-difficult instances of the current
310//                          object class within the image
311//RETURN VALUE:
312// ObdImage containing path and other details of image file with given code
313//NOTES:
314// There are three versions of this function
315//  * One returns a simple array of objects given an id [1]
316//  * One returns the same as (1) plus VOC specific object data [2]
317//  * One returns the same as (2) plus the ground_truth flag. This also requires an extra input obj_class [3]
318ObdImage VocData::getObjects(const string& id, vector<ObdObject>& objects)
319{
320    vector<VocObjectData> object_data;
321    ObdImage image = getObjects(id, objects, object_data);
322
323    return image;
324}
325
326ObdImage VocData::getObjects(const string& id, vector<ObdObject>& objects, vector<VocObjectData>& object_data)
327{
328    //first generate the filename of the annotation file
329    string annotationFilename = m_annotation_path;
330
331    annotationFilename.replace(annotationFilename.find("%s"),2,id);
332
333    //extract objects contained in the current image from the xml
334    extractVocObjects(annotationFilename,objects,object_data);
335
336    //generate image path from extracted string code
337    string path = getImagePath(id);
338
339    ObdImage image(id, path);
340    return image;
341}
342
343ObdImage VocData::getObjects(const string& obj_class, const string& id, vector<ObdObject>& objects, vector<VocObjectData>& object_data, VocGT& ground_truth)
344{
345
346    //extract object data (except for ground truth flag)
347    ObdImage image = getObjects(id,objects,object_data);
348
349    //pregenerate a flag to indicate whether the current class is present or not in the image
350    ground_truth = CV_VOC_GT_NONE;
351    //iterate through all objects in current image
352    for (size_t j = 0; j < objects.size(); ++j)
353    {
354        if (objects[j].object_class == obj_class)
355        {
356            if (object_data[j].difficult == false)
357            {
358                //if at least one non-difficult example is present, this flag is always set to CV_VOC_GT_PRESENT
359                ground_truth = CV_VOC_GT_PRESENT;
360                break;
361            } else {
362                //set if at least one object instance is present, but it is marked difficult
363                ground_truth = CV_VOC_GT_DIFFICULT;
364            }
365        }
366    }
367
368    return image;
369}
370
371//Return ground truth data for the presence/absence of a given object class in an arbitrary array of images
372//---------------------------------------------------------------------------------------------------------
373//INPUTS:
374// - obj_class          The VOC object class identifier string
375// - images             An array of ObdImage OR strings containing the images for which ground truth
376//                          will be computed
377//OUTPUTS:
378// - ground_truth       An output array indicating the presence/absence of obj_class within each image
379void VocData::getClassifierGroundTruth(const string& obj_class, const vector<ObdImage>& images, vector<char>& ground_truth)
380{
381    vector<char>(images.size()).swap(ground_truth);
382
383    vector<ObdObject> objects;
384    vector<VocObjectData> object_data;
385    vector<char>::iterator gt_it = ground_truth.begin();
386    for (vector<ObdImage>::const_iterator it = images.begin(); it != images.end(); ++it, ++gt_it)
387    {
388        //getObjects(obj_class, it->id, objects, object_data, voc_ground_truth);
389        (*gt_it) = (getClassifierGroundTruthImage(obj_class, it->id));
390    }
391}
392
393void VocData::getClassifierGroundTruth(const string& obj_class, const vector<string>& images, vector<char>& ground_truth)
394{
395    vector<char>(images.size()).swap(ground_truth);
396
397    vector<ObdObject> objects;
398    vector<VocObjectData> object_data;
399    vector<char>::iterator gt_it = ground_truth.begin();
400    for (vector<string>::const_iterator it = images.begin(); it != images.end(); ++it, ++gt_it)
401    {
402        //getObjects(obj_class, (*it), objects, object_data, voc_ground_truth);
403        (*gt_it) = (getClassifierGroundTruthImage(obj_class, (*it)));
404    }
405}
406
407//Return ground truth data for the accuracy of detection results
408//--------------------------------------------------------------
409//INPUTS:
410// - obj_class          The VOC object class identifier string
411// - images             An array of ObdImage containing the images for which ground truth
412//                          will be computed
413// - bounding_boxes     A 2D input array containing the bounding box rects of the objects of
414//                          obj_class which were detected in each image
415//OUTPUTS:
416// - ground_truth       A 2D output array indicating whether each object detection was accurate
417//                          or not
418// - detection_difficult A 2D output array indicating whether the detection fired on an object
419//                          marked as 'difficult'. This allows it to be ignored if necessary
420//                          (the voc documentation specifies objects marked as difficult
421//                          have no effects on the results and are effectively ignored)
422// - (ignore_difficult) If set to true, objects marked as difficult will be ignored when returning
423//                          the number of hits for p-r normalization (default = true)
424//RETURN VALUE:
425//                      Returns the number of object hits in total in the gt to allow proper normalization
426//                          of a p-r curve
427//NOTES:
428// As stated in the VOC documentation, multiple detections of the same object in an image are
429// considered FALSE detections e.g. 5 detections of a single object is counted as 1 correct
430// detection and 4 false detections - it is the responsibility of the participant's system
431// to filter multiple detections from its output
//Greedily assigns each detection (in descending score order) to at most one
//unclaimed ground-truth box of obj_class, following the VOC matching protocol.
//Returns the recall-normalization count (total ground-truth hits in 'dataset').
int VocData::getDetectorGroundTruth(const string& obj_class, const ObdDatasetType dataset, const vector<ObdImage>& images, const vector<vector<Rect> >& bounding_boxes, const vector<vector<float> >& scores, vector<vector<char> >& ground_truth, vector<vector<char> >& detection_difficult, bool ignore_difficult)
{
    int recall_normalization = 0;

    /* first create a list of indices referring to the elements of bounding_boxes and scores in
     * descending order of scores */
    vector<ObdScoreIndexSorter> sorted_ids;
    {
        /* first count how many objects to allow preallocation */
        //NOTE(review): obj_count is an int accumulating size_t values -- a
        //narrowing conversion that only matters for pathologically large inputs
        int obj_count = 0;
        CV_Assert(images.size() == bounding_boxes.size());
        CV_Assert(scores.size() == bounding_boxes.size());
        for (size_t im_idx = 0; im_idx < scores.size(); ++im_idx)
        {
            CV_Assert(scores[im_idx].size() == bounding_boxes[im_idx].size());
            obj_count += scores[im_idx].size();
        }
        /* preallocate id vector */
        sorted_ids.resize(obj_count);
        /* now copy across scores and indexes to preallocated vector */
        int flat_pos = 0;
        for (size_t im_idx = 0; im_idx < scores.size(); ++im_idx)
        {
            for (size_t ob_idx = 0; ob_idx < scores[im_idx].size(); ++ob_idx)
            {
                sorted_ids[flat_pos].score = scores[im_idx][ob_idx];
                sorted_ids[flat_pos].image_idx = im_idx;
                sorted_ids[flat_pos].obj_idx = ob_idx;
                ++flat_pos;
            }
        }
        /* and sort the vector in descending order of score */
        //ObdScoreIndexSorter::operator< sorts ascending by score, so reverse afterwards
        std::sort(sorted_ids.begin(),sorted_ids.end());
        std::reverse(sorted_ids.begin(),sorted_ids.end());
    }

    /* prepare ground truth + difficult vector (1st dimension) */
    vector<vector<char> >(images.size()).swap(ground_truth);
    vector<vector<char> >(images.size()).swap(detection_difficult);
    //tracks which ground-truth objects have already been claimed by a detection
    vector<vector<char> > detected(images.size());

    //per-image ground-truth objects of obj_class only (filtered below)
    vector<vector<ObdObject> > img_objects(images.size());
    vector<vector<VocObjectData> > img_object_data(images.size());
    /* preload object ground truth bounding box data */
    {
        vector<vector<ObdObject> > img_objects_all(images.size());
        vector<vector<VocObjectData> > img_object_data_all(images.size());
        for (size_t image_idx = 0; image_idx < images.size(); ++image_idx)
        {
            /* prepopulate ground truth bounding boxes */
            getObjects(images[image_idx].id, img_objects_all[image_idx], img_object_data_all[image_idx]);
            /* meanwhile, also set length of target ground truth + difficult vector to same as number of object detections (2nd dimension) */
            ground_truth[image_idx].resize(bounding_boxes[image_idx].size());
            detection_difficult[image_idx].resize(bounding_boxes[image_idx].size());
        }

        /* save only instances of the object class concerned */
        for (size_t image_idx = 0; image_idx < images.size(); ++image_idx)
        {
            for (size_t obj_idx = 0; obj_idx < img_objects_all[image_idx].size(); ++obj_idx)
            {
                if (img_objects_all[image_idx][obj_idx].object_class == obj_class)
                {
                    img_objects[image_idx].push_back(img_objects_all[image_idx][obj_idx]);
                    img_object_data[image_idx].push_back(img_object_data_all[image_idx][obj_idx]);
                }
            }
            detected[image_idx].resize(img_objects[image_idx].size(), false);
        }
    }

    /* calculate the total number of objects in the ground truth for the current dataset */
    //(this scans the whole dataset, not just 'images', so recall is normalized
    // against every ground-truth instance of the class)
    {
        vector<ObdImage> gt_images;
        vector<char> gt_object_present;
        getClassImages(obj_class, dataset, gt_images, gt_object_present);

        for (size_t image_idx = 0; image_idx < gt_images.size(); ++image_idx)
        {
            vector<ObdObject> gt_img_objects;
            vector<VocObjectData> gt_img_object_data;
            getObjects(gt_images[image_idx].id, gt_img_objects, gt_img_object_data);
            for (size_t obj_idx = 0; obj_idx < gt_img_objects.size(); ++obj_idx)
            {
                if (gt_img_objects[obj_idx].object_class == obj_class)
                {
                    //difficult objects are excluded from the count unless ignore_difficult == false
                    if ((gt_img_object_data[obj_idx].difficult == false) || (ignore_difficult == false))
                        ++recall_normalization;
                }
            }
        }
    }

#ifdef PR_DEBUG
    int printed_count = 0;
#endif
    /* now iterate through detections in descending order of score, assigning to ground truth bounding boxes if possible */
    for (size_t detect_idx = 0; detect_idx < sorted_ids.size(); ++detect_idx)
    {
        //read in indexes to make following code easier to read
        int im_idx = sorted_ids[detect_idx].image_idx;
        int ob_idx = sorted_ids[detect_idx].obj_idx;
        //set ground truth for the current object to false by default
        ground_truth[im_idx][ob_idx] = false;
        detection_difficult[im_idx][ob_idx] = false;
        //maxov stays -1.0 if no ground-truth box overlaps sufficiently
        float maxov = -1.0;
        bool max_is_difficult = false;
        int max_gt_obj_idx = -1;
        //-- for each detected object iterate through objects present in the bounding box ground truth --
        for (size_t gt_obj_idx = 0; gt_obj_idx < img_objects[im_idx].size(); ++gt_obj_idx)
        {
            //only consider ground-truth objects not yet claimed by a higher-scoring detection
            if (detected[im_idx][gt_obj_idx] == false)
            {
                //check if the detected object and ground truth object overlap by a sufficient margin
                float ov = testBoundingBoxesForOverlap(bounding_boxes[im_idx][ob_idx], img_objects[im_idx][gt_obj_idx].boundingBox);
                if (ov != -1.0)
                {
                    //if all conditions are met store the overlap score and index (as objects are assigned to the highest scoring match)
                    if (ov > maxov)
                    {
                        maxov = ov;
                        max_gt_obj_idx = gt_obj_idx;
                        //store whether the maximum detection is marked as difficult or not
                        max_is_difficult = (img_object_data[im_idx][gt_obj_idx].difficult);
                    }
                }
            }
        }
        //-- if a match was found, set the ground truth of the current object to true --
        if (maxov != -1.0)
        {
            CV_Assert(max_gt_obj_idx != -1);
            ground_truth[im_idx][ob_idx] = true;
            //store whether the maximum detection was marked as 'difficult' or not
            detection_difficult[im_idx][ob_idx] = max_is_difficult;
            //remove the ground truth object so it doesn't match with subsequent detected objects
            //** this is the behaviour defined by the voc documentation **
            detected[im_idx][max_gt_obj_idx] = true;
        }
#ifdef PR_DEBUG
        if (printed_count < 10)
        {
            cout << printed_count << ": id=" << images[im_idx].id << ", score=" << scores[im_idx][ob_idx] << " (" << ob_idx << ") [" << bounding_boxes[im_idx][ob_idx].x << "," <<
                    bounding_boxes[im_idx][ob_idx].y << "," << bounding_boxes[im_idx][ob_idx].width + bounding_boxes[im_idx][ob_idx].x <<
                    "," << bounding_boxes[im_idx][ob_idx].height + bounding_boxes[im_idx][ob_idx].y << "] detected=" << ground_truth[im_idx][ob_idx] <<
                    ", difficult=" << detection_difficult[im_idx][ob_idx] << endl;
            ++printed_count;
            /* print ground truth */
            for (int gt_obj_idx = 0; gt_obj_idx < img_objects[im_idx].size(); ++gt_obj_idx)
            {
                cout << "    GT: [" << img_objects[im_idx][gt_obj_idx].boundingBox.x << "," <<
                        img_objects[im_idx][gt_obj_idx].boundingBox.y << "," << img_objects[im_idx][gt_obj_idx].boundingBox.width + img_objects[im_idx][gt_obj_idx].boundingBox.x <<
                        "," << img_objects[im_idx][gt_obj_idx].boundingBox.height + img_objects[im_idx][gt_obj_idx].boundingBox.y << "]";
                if (gt_obj_idx == max_gt_obj_idx) cout << " <--- (" << maxov << " overlap)";
                cout << endl;
            }
        }
#endif
    }

    return recall_normalization;
}
594
595//Write VOC-compliant classifier results file
596//-------------------------------------------
597//INPUTS:
598// - obj_class          The VOC object class identifier string
599// - dataset            Specifies whether working with the training or test set
600// - images             An array of ObdImage containing the images for which data will be saved to the result file
601// - scores             A corresponding array of confidence scores given a query
602// - (competition)      If specified, defines which competition the results are for (see VOC documentation - default 1)
603//NOTES:
604// The result file path and filename are determined automatically using m_results_directory as a base
605void VocData::writeClassifierResultsFile( const string& out_dir, const string& obj_class, const ObdDatasetType dataset, const vector<ObdImage>& images, const vector<float>& scores, const int competition, const bool overwrite_ifexists)
606{
607    CV_Assert(images.size() == scores.size());
608
609    string output_file_base, output_file;
610    if (dataset == CV_OBD_TRAIN)
611    {
612        output_file_base = out_dir + "/comp" + integerToString(competition) + "_cls_" + m_train_set + "_" + obj_class;
613    } else {
614        output_file_base = out_dir + "/comp" + integerToString(competition) + "_cls_" + m_test_set + "_" + obj_class;
615    }
616    output_file = output_file_base + ".txt";
617
618    //check if file exists, and if so create a numbered new file instead
619    if (overwrite_ifexists == false)
620    {
621        struct stat stFileInfo;
622        if (stat(output_file.c_str(),&stFileInfo) == 0)
623        {
624            string output_file_new;
625            int filenum = 0;
626            do
627            {
628                ++filenum;
629                output_file_new = output_file_base + "_" + integerToString(filenum);
630                output_file = output_file_new + ".txt";
631            } while (stat(output_file.c_str(),&stFileInfo) == 0);
632        }
633    }
634
635    //output data to file
636    std::ofstream result_file(output_file.c_str());
637    if (result_file.is_open())
638    {
639        for (size_t i = 0; i < images.size(); ++i)
640        {
641            result_file << images[i].id << " " << scores[i] << endl;
642        }
643        result_file.close();
644    } else {
645        string err_msg = "could not open classifier results file '" + output_file + "' for writing. Before running for the first time, a 'results' subdirectory should be created within the VOC dataset base directory. e.g. if the VOC data is stored in /VOC/VOC2010 then the path /VOC/results must be created.";
646        CV_Error(CV_StsError,err_msg.c_str());
647    }
648}
649
650//---------------------------------------
651//CALCULATE METRICS FROM VOC RESULTS DATA
652//---------------------------------------
653
654//Utility function to construct a VOC-standard classification results filename
655//----------------------------------------------------------------------------
656//INPUTS:
657// - obj_class          The VOC object class identifier string
658// - task               Specifies whether to generate a filename for the classification or detection task
659// - dataset            Specifies whether working with the training or test set
660// - (competition)      If specified, defines which competition the results are for (see VOC documentation
661//                      default of -1 means this is set to 1 for the classification task and 3 for the detection task)
// - (number)           If specified and above 0, defines which of a number of duplicate results files produced for a given
//                      set of settings should be used (this number will be added as a postfix to the filename)
664//NOTES:
665// This is primarily useful for returning the filename of a classification file previously computed using writeClassifierResultsFile
666// for example when calling calcClassifierPrecRecall
667string VocData::getResultsFilename(const string& obj_class, const VocTask task, const ObdDatasetType dataset, const int competition, const int number)
668{
669    if ((competition < 1) && (competition != -1))
670        CV_Error(CV_StsBadArg,"competition argument should be a positive non-zero number or -1 to accept the default");
671    if ((number < 1) && (number != -1))
672        CV_Error(CV_StsBadArg,"number argument should be a positive non-zero number or -1 to accept the default");
673
674    string dset, task_type;
675
676    if (dataset == CV_OBD_TRAIN)
677    {
678        dset = m_train_set;
679    } else {
680        dset = m_test_set;
681    }
682
683    int comp = competition;
684    if (task == CV_VOC_TASK_CLASSIFICATION)
685    {
686        task_type = "cls";
687        if (comp == -1) comp = 1;
688    } else {
689        task_type = "det";
690        if (comp == -1) comp = 3;
691    }
692
693    stringstream ss;
694    if (number < 1)
695    {
696        ss << "comp" << comp << "_" << task_type << "_" << dset << "_" << obj_class << ".txt";
697    } else {
698        ss << "comp" << comp << "_" << task_type << "_" << dset << "_" << obj_class << "_" << number << ".txt";
699    }
700
701    string filename = ss.str();
702    return filename;
703}
704
705//Calculate metrics for classification results
706//--------------------------------------------
707//INPUTS:
708// - ground_truth       A vector of booleans determining whether the currently tested class is present in each input image
709// - scores             A vector containing the similarity score for each input image (higher is more similar)
710//OUTPUTS:
711// - precision          A vector containing the precision calculated at each datapoint of a p-r curve generated from the result set
712// - recall             A vector containing the recall calculated at each datapoint of a p-r curve generated from the result set
713// - ap                The ap metric calculated from the result set
714// - (ranking)          A vector of the same length as 'ground_truth' and 'scores' containing the order of the indices in both of
715//                      these arrays when sorting by the ranking score in descending order
716//NOTES:
717// The result file path and filename are determined automatically using m_results_directory as a base
718void VocData::calcClassifierPrecRecall(const string& obj_class, const vector<ObdImage>& images, const vector<float>& scores, vector<float>& precision, vector<float>& recall, float& ap, vector<size_t>& ranking)
719{
720    vector<char> res_ground_truth;
721    getClassifierGroundTruth(obj_class, images, res_ground_truth);
722
723    calcPrecRecall_impl(res_ground_truth, scores, precision, recall, ap, ranking);
724}
725
726void VocData::calcClassifierPrecRecall(const string& obj_class, const vector<ObdImage>& images, const vector<float>& scores, vector<float>& precision, vector<float>& recall, float& ap)
727{
728    vector<char> res_ground_truth;
729    getClassifierGroundTruth(obj_class, images, res_ground_truth);
730
731    vector<size_t> ranking;
732    calcPrecRecall_impl(res_ground_truth, scores, precision, recall, ap, ranking);
733}
734
735//< Overloaded version which accepts VOC classification result file input instead of array of scores/ground truth >
736//INPUTS:
737// - input_file         The path to the VOC standard results file to use for calculating precision/recall
738//                      If a full path is not specified, it is assumed this file is in the VOC standard results directory
739//                      A VOC standard filename can be retrieved (as used by writeClassifierResultsFile) by calling  getClassifierResultsFilename
740
741void VocData::calcClassifierPrecRecall(const string& input_file, vector<float>& precision, vector<float>& recall, float& ap, bool outputRankingFile)
742{
743    //read in classification results file
744    vector<string> res_image_codes;
745    vector<float> res_scores;
746
747    string input_file_std = checkFilenamePathsep(input_file);
748    readClassifierResultsFile(input_file_std, res_image_codes, res_scores);
749
750    //extract the object class and dataset from the results file filename
751    string class_name, dataset_name;
752    extractDataFromResultsFilename(input_file_std, class_name, dataset_name);
753
754    //generate the ground truth for the images extracted from the results file
755    vector<char> res_ground_truth;
756
757    getClassifierGroundTruth(class_name, res_image_codes, res_ground_truth);
758
759    if (outputRankingFile)
760    {
761        /* 1. store sorting order by score (descending) in 'order' */
762        vector<std::pair<size_t, vector<float>::const_iterator> > order(res_scores.size());
763
764        size_t n = 0;
765        for (vector<float>::const_iterator it = res_scores.begin(); it != res_scores.end(); ++it, ++n)
766            order[n] = make_pair(n, it);
767
768        std::sort(order.begin(),order.end(),orderingSorter());
769
770        /* 2. save ranking results to text file */
771        string input_file_std = checkFilenamePathsep(input_file);
772        size_t fnamestart = input_file_std.rfind("/");
773        string scoregt_file_str = input_file_std.substr(0,fnamestart+1) + "scoregt_" + class_name + ".txt";
774        std::ofstream scoregt_file(scoregt_file_str.c_str());
775        if (scoregt_file.is_open())
776        {
777            for (size_t i = 0; i < res_scores.size(); ++i)
778            {
779                scoregt_file << res_image_codes[order[i].first] << " " << res_scores[order[i].first] << " " << res_ground_truth[order[i].first] << endl;
780            }
781            scoregt_file.close();
782        } else {
783            string err_msg = "could not open scoregt file '" + scoregt_file_str + "' for writing.";
784            CV_Error(CV_StsError,err_msg.c_str());
785        }
786    }
787
788    //finally, calculate precision+recall+ap
789    vector<size_t> ranking;
790    calcPrecRecall_impl(res_ground_truth,res_scores,precision,recall,ap,ranking);
791}
792
793//< Protected implementation of Precision-Recall calculation used by both calcClassifierPrecRecall and calcDetectorPrecRecall >
794
795void VocData::calcPrecRecall_impl(const vector<char>& ground_truth, const vector<float>& scores, vector<float>& precision, vector<float>& recall, float& ap, vector<size_t>& ranking, int recall_normalization)
796{
797    CV_Assert(ground_truth.size() == scores.size());
798
799    //add extra element for p-r at 0 recall (in case that first retrieved is positive)
800    vector<float>(scores.size()+1).swap(precision);
801    vector<float>(scores.size()+1).swap(recall);
802
803    // SORT RESULTS BY THEIR SCORE
804    /* 1. store sorting order in 'order' */
805    VocData::getSortOrder(scores, ranking);
806
807#ifdef PR_DEBUG
808    std::ofstream scoregt_file("D:/pr.txt");
809    if (scoregt_file.is_open())
810    {
811       for (int i = 0; i < scores.size(); ++i)
812       {
813           scoregt_file << scores[ranking[i]] << " " << ground_truth[ranking[i]] << endl;
814       }
815       scoregt_file.close();
816    }
817#endif
818
819    // CALCULATE PRECISION+RECALL
820
821    int retrieved_hits = 0;
822
823    int recall_norm;
824    if (recall_normalization != -1)
825    {
826        recall_norm = recall_normalization;
827    } else {
828        recall_norm = std::count_if(ground_truth.begin(),ground_truth.end(),std::bind2nd(std::equal_to<bool>(),true));
829    }
830
831    ap = 0;
832    recall[0] = 0;
833    for (size_t idx = 0; idx < ground_truth.size(); ++idx)
834    {
835        if (ground_truth[ranking[idx]] == true) ++retrieved_hits;
836
837        precision[idx+1] = static_cast<float>(retrieved_hits)/static_cast<float>(idx+1);
838        recall[idx+1] = static_cast<float>(retrieved_hits)/static_cast<float>(recall_norm);
839
840        if (idx == 0)
841        {
842            //add further point at 0 recall with the same precision value as the first computed point
843            precision[idx] = precision[idx+1];
844        }
845        if (recall[idx+1] == 1.0)
846        {
847            //if recall = 1, then end early as all positive images have been found
848            recall.resize(idx+2);
849            precision.resize(idx+2);
850            break;
851        }
852    }
853
854    /* ap calculation */
855    if (m_sampled_ap == false)
856    {
857        // FOR VOC2010+ AP IS CALCULATED FROM ALL DATAPOINTS
858        /* make precision monotonically decreasing for purposes of calculating ap */
859        vector<float> precision_monot(precision.size());
860        vector<float>::iterator prec_m_it = precision_monot.begin();
861        for (vector<float>::iterator prec_it = precision.begin(); prec_it != precision.end(); ++prec_it, ++prec_m_it)
862        {
863            vector<float>::iterator max_elem;
864            max_elem = std::max_element(prec_it,precision.end());
865            (*prec_m_it) = (*max_elem);
866        }
867        /* calculate ap */
868        for (size_t idx = 0; idx < (recall.size()-1); ++idx)
869        {
870            ap += (recall[idx+1] - recall[idx])*precision_monot[idx+1] +   //no need to take min of prec - is monotonically decreasing
871                    0.5*(recall[idx+1] - recall[idx])*std::abs(precision_monot[idx+1] - precision_monot[idx]);
872        }
873    } else {
874        // FOR BEFORE VOC2010 AP IS CALCULATED BY SAMPLING PRECISION AT RECALL 0.0,0.1,..,1.0
875
876        for (float recall_pos = 0.0; recall_pos <= 1.0; recall_pos += 0.1)
877        {
878            //find iterator of the precision corresponding to the first recall >= recall_pos
879            vector<float>::iterator recall_it = recall.begin();
880            vector<float>::iterator prec_it = precision.begin();
881
882            while ((*recall_it) < recall_pos)
883            {
884                ++recall_it;
885                ++prec_it;
886                if (recall_it == recall.end()) break;
887            }
888
889            /* if no recall >= recall_pos found, this level of recall is never reached so stop adding to ap */
890            if (recall_it == recall.end()) break;
891
892            /* if the prec_it is valid, compute the max precision at this level of recall or higher */
893            vector<float>::iterator max_prec = std::max_element(prec_it,precision.end());
894
895            ap += (*max_prec)/11;
896        }
897    }
898}
899
900/* functions for calculating confusion matrix rows */
901
902//Calculate rows of a confusion matrix
903//------------------------------------
904//INPUTS:
905// - obj_class          The VOC object class identifier string for the confusion matrix row to compute
906// - images             An array of ObdImage containing the images to use for the computation
907// - scores             A corresponding array of confidence scores for the presence of obj_class in each image
908// - cond               Defines whether to use a cut off point based on recall (CV_VOC_CCOND_RECALL) or score
909//                      (CV_VOC_CCOND_SCORETHRESH) the latter is useful for classifier detections where positive
910//                      values are positive detections and negative values are negative detections
911// - threshold          Threshold value for cond. In case of CV_VOC_CCOND_RECALL, is proportion recall (e.g. 0.5).
912//                      In the case of CV_VOC_CCOND_SCORETHRESH is the value above which to count results.
913//OUTPUTS:
914// - output_headers     An output vector of object class headers for the confusion matrix row
915// - output_values      An output vector of values for the confusion matrix row corresponding to the classes
916//                      defined in output_headers
917//NOTES:
918// The methodology used by the classifier version of this function is that true positives have a single unit
919// added to the obj_class column in the confusion matrix row, whereas false positives have a single unit
920// distributed in proportion between all the columns in the confusion matrix row corresponding to the objects
921// present in the image.
//Computes one row of the classification confusion matrix for obj_class (see the
//methodology notes above): images are visited in descending score order until the
//stop condition (score threshold or recall level) is met; true positives add a full
//unit to the obj_class column, false positives are spread evenly across the columns
//of the (non-difficult) objects actually present in the image. The row is then
//normalized by the number of images processed.
void VocData::calcClassifierConfMatRow(const string& obj_class, const vector<ObdImage>& images, const vector<float>& scores, const VocConfCond cond, const float threshold, vector<string>& output_headers, vector<float>& output_values)
{
    CV_Assert(images.size() == scores.size());

    // SORT RESULTS BY THEIR SCORE
    /* 1. store sorting order in 'ranking' */
    vector<size_t> ranking;
    VocData::getSortOrder(scores, ranking);

    // CALCULATE CONFUSION MATRIX ENTRIES
    /* prepare object category headers */
    output_headers = m_object_classes;
    vector<float>(output_headers.size(),0.0).swap(output_values);
    /* find the index of the target object class in the headers for later use */
    int target_idx;
    {
        vector<string>::iterator target_idx_it = std::find(output_headers.begin(),output_headers.end(),obj_class);
        /* if the target class can not be found, raise an exception */
        if (target_idx_it == output_headers.end())
        {
            string err_msg = "could not find the target object class '" + obj_class + "' in list of valid classes.";
            CV_Error(CV_StsError,err_msg.c_str());
        }
        /* convert iterator to index */
        target_idx = std::distance(output_headers.begin(),target_idx_it);
    }

    /* prepare variables related to calculating recall if using the recall threshold */
    int retrieved_hits = 0;
    // total_relevant is only initialized (and only read) when cond == CV_VOC_CCOND_RECALL
    int total_relevant;
    if (cond == CV_VOC_CCOND_RECALL)
    {
        vector<char> ground_truth;
        /* in order to calculate the total number of relevant images for normalization of recall
            it's necessary to extract the ground truth for the images under consideration */
        getClassifierGroundTruth(obj_class, images, ground_truth);
        total_relevant = std::count_if(ground_truth.begin(),ground_truth.end(),std::bind2nd(std::equal_to<bool>(),true));
    }

    /* iterate through images */
    vector<ObdObject> img_objects;
    vector<VocObjectData> img_object_data;
    int total_images = 0;
    for (size_t image_idx = 0; image_idx < images.size(); ++image_idx)
    {
        /* if using the score as the break condition, check for it now */
        if (cond == CV_VOC_CCOND_SCORETHRESH)
        {
            if (scores[ranking[image_idx]] <= threshold) break;
        }
        /* if continuing for this iteration, increment the image counter for later normalization */
        ++total_images;
        /* for each image retrieve the objects contained */
        getObjects(images[ranking[image_idx]].id, img_objects, img_object_data);
        //check if the tested for object class is present
        if (getClassifierGroundTruthImage(obj_class, images[ranking[image_idx]].id))
        {
            //if the target class is present, assign fully to the target class element in the confusion matrix row
            output_values[target_idx] += 1.0;
            if (cond == CV_VOC_CCOND_RECALL) ++retrieved_hits;
        } else {
            //first delete all objects marked as difficult
            // NOTE: erases from both parallel vectors while indexing; when obj_idx is 0 the
            // trailing --obj_idx wraps the unsigned index to SIZE_MAX and the loop's ++obj_idx
            // brings it back to 0, so element 0 is correctly re-examined after an erase
            for (size_t obj_idx = 0; obj_idx < img_objects.size(); ++obj_idx)
            {
                if (img_object_data[obj_idx].difficult == true)
                {
                    vector<ObdObject>::iterator it1 = img_objects.begin();
                    std::advance(it1,obj_idx);
                    img_objects.erase(it1);
                    vector<VocObjectData>::iterator it2 = img_object_data.begin();
                    std::advance(it2,obj_idx);
                    img_object_data.erase(it2);
                    --obj_idx;
                }
            }
            //if the target class is not present, add values to the confusion matrix row in equal proportions to all objects present in the image
            for (size_t obj_idx = 0; obj_idx < img_objects.size(); ++obj_idx)
            {
                //find the index of the currently considered object
                vector<string>::iterator class_idx_it = std::find(output_headers.begin(),output_headers.end(),img_objects[obj_idx].object_class);
                //if the class name extracted from the ground truth file could not be found in the list of available classes, raise an exception
                if (class_idx_it == output_headers.end())
                {
                    string err_msg = "could not find object class '" + img_objects[obj_idx].object_class + "' specified in the ground truth file of '" + images[ranking[image_idx]].id +"'in list of valid classes.";
                    CV_Error(CV_StsError,err_msg.c_str());
                }
                /* convert iterator to index */
                int class_idx = std::distance(output_headers.begin(),class_idx_it);
                //add to confusion matrix row in proportion
                output_values[class_idx] += 1.0/static_cast<float>(img_objects.size());
            }
        }
        //check break conditions if breaking on certain level of recall
        if (cond == CV_VOC_CCOND_RECALL)
        {
            if(static_cast<float>(retrieved_hits)/static_cast<float>(total_relevant) >= threshold) break;
        }
    }
    /* finally, normalize confusion matrix row */
    // NOTE(review): if the first-ranked score is already <= threshold, total_images stays 0
    // and this division yields inf/nan entries — confirm callers guarantee a non-empty pass
    for (vector<float>::iterator it = output_values.begin(); it < output_values.end(); ++it)
    {
        (*it) /= static_cast<float>(total_images);
    }
}
1026
1027// NOTE: doesn't ignore repeated detections
1028void VocData::calcDetectorConfMatRow(const string& obj_class, const ObdDatasetType dataset, const vector<ObdImage>& images, const vector<vector<float> >& scores, const vector<vector<Rect> >& bounding_boxes, const VocConfCond cond, const float threshold, vector<string>& output_headers, vector<float>& output_values, bool ignore_difficult)
1029{
1030    CV_Assert(images.size() == scores.size());
1031    CV_Assert(images.size() == bounding_boxes.size());
1032
1033    //collapse scores and ground_truth vectors into 1D vectors to allow ranking
1034    /* define final flat vectors */
1035    vector<string> images_flat;
1036    vector<float> scores_flat;
1037    vector<Rect> bounding_boxes_flat;
1038    {
1039        /* first count how many objects to allow preallocation */
1040        int obj_count = 0;
1041        CV_Assert(scores.size() == bounding_boxes.size());
1042        for (size_t img_idx = 0; img_idx < scores.size(); ++img_idx)
1043        {
1044            CV_Assert(scores[img_idx].size() == bounding_boxes[img_idx].size());
1045            for (size_t obj_idx = 0; obj_idx < scores[img_idx].size(); ++obj_idx)
1046            {
1047                ++obj_count;
1048            }
1049        }
1050        /* preallocate vectors */
1051        images_flat.resize(obj_count);
1052        scores_flat.resize(obj_count);
1053        bounding_boxes_flat.resize(obj_count);
1054        /* now copy across to preallocated vectors */
1055        int flat_pos = 0;
1056        for (size_t img_idx = 0; img_idx < scores.size(); ++img_idx)
1057        {
1058            for (size_t obj_idx = 0; obj_idx < scores[img_idx].size(); ++obj_idx)
1059            {
1060                images_flat[flat_pos] = images[img_idx].id;
1061                scores_flat[flat_pos] = scores[img_idx][obj_idx];
1062                bounding_boxes_flat[flat_pos] = bounding_boxes[img_idx][obj_idx];
1063                ++flat_pos;
1064            }
1065        }
1066    }
1067
1068    // SORT RESULTS BY THEIR SCORE
1069    /* 1. store sorting order in 'ranking' */
1070    vector<size_t> ranking;
1071    VocData::getSortOrder(scores_flat, ranking);
1072
1073    // CALCULATE CONFUSION MATRIX ENTRIES
1074    /* prepare object category headers */
1075    output_headers = m_object_classes;
1076    output_headers.push_back("background");
1077    vector<float>(output_headers.size(),0.0).swap(output_values);
1078
1079    /* prepare variables related to calculating recall if using the recall threshold */
1080    int retrieved_hits = 0;
1081    int total_relevant = 0;
1082    if (cond == CV_VOC_CCOND_RECALL)
1083    {
1084//        vector<char> ground_truth;
1085//        /* in order to calculate the total number of relevant images for normalization of recall
1086//            it's necessary to extract the ground truth for the images under consideration */
1087//        getClassifierGroundTruth(obj_class, images, ground_truth);
1088//        total_relevant = std::count_if(ground_truth.begin(),ground_truth.end(),std::bind2nd(std::equal_to<bool>(),true));
1089        /* calculate the total number of objects in the ground truth for the current dataset */
1090        vector<ObdImage> gt_images;
1091        vector<char> gt_object_present;
1092        getClassImages(obj_class, dataset, gt_images, gt_object_present);
1093
1094        for (size_t image_idx = 0; image_idx < gt_images.size(); ++image_idx)
1095        {
1096            vector<ObdObject> gt_img_objects;
1097            vector<VocObjectData> gt_img_object_data;
1098            getObjects(gt_images[image_idx].id, gt_img_objects, gt_img_object_data);
1099            for (size_t obj_idx = 0; obj_idx < gt_img_objects.size(); ++obj_idx)
1100            {
1101                if (gt_img_objects[obj_idx].object_class == obj_class)
1102                {
1103                    if ((gt_img_object_data[obj_idx].difficult == false) || (ignore_difficult == false))
1104                        ++total_relevant;
1105                }
1106            }
1107        }
1108    }
1109
1110    /* iterate through objects */
1111    vector<ObdObject> img_objects;
1112    vector<VocObjectData> img_object_data;
1113    int total_objects = 0;
1114    for (size_t image_idx = 0; image_idx < images.size(); ++image_idx)
1115    {
1116        /* if using the score as the break condition, check for it now */
1117        if (cond == CV_VOC_CCOND_SCORETHRESH)
1118        {
1119            if (scores_flat[ranking[image_idx]] <= threshold) break;
1120        }
1121        /* increment the image counter for later normalization */
1122        ++total_objects;
1123        /* for each image retrieve the objects contained */
1124        getObjects(images[ranking[image_idx]].id, img_objects, img_object_data);
1125
1126        //find the ground truth object which has the highest overlap score with the detected object
1127        float maxov = -1.0;
1128        size_t max_gt_obj_idx = -1;
1129        //-- for each detected object iterate through objects present in ground truth --
1130        for (size_t gt_obj_idx = 0; gt_obj_idx < img_objects.size(); ++gt_obj_idx)
1131        {
1132            //check difficulty flag
1133            if (ignore_difficult || (img_object_data[gt_obj_idx].difficult = false))
1134            {
1135                //if the class matches, then check if the detected object and ground truth object overlap by a sufficient margin
1136                int ov = testBoundingBoxesForOverlap(bounding_boxes_flat[ranking[image_idx]], img_objects[gt_obj_idx].boundingBox);
1137                if (ov != -1.0)
1138                {
1139                    //if all conditions are met store the overlap score and index (as objects are assigned to the highest scoring match)
1140                    if (ov > maxov)
1141                    {
1142                        maxov = ov;
1143                        max_gt_obj_idx = gt_obj_idx;
1144                    }
1145                }
1146            }
1147        }
1148
1149        //assign to appropriate object class if an object was detected
1150        if (maxov != -1.0)
1151        {
1152            //find the index of the currently considered object
1153            vector<string>::iterator class_idx_it = std::find(output_headers.begin(),output_headers.end(),img_objects[max_gt_obj_idx].object_class);
1154            //if the class name extracted from the ground truth file could not be found in the list of available classes, raise an exception
1155            if (class_idx_it == output_headers.end())
1156            {
1157                string err_msg = "could not find object class '" + img_objects[max_gt_obj_idx].object_class + "' specified in the ground truth file of '" + images[ranking[image_idx]].id +"'in list of valid classes.";
1158                CV_Error(CV_StsError,err_msg.c_str());
1159            }
1160            /* convert iterator to index */
1161            int class_idx = std::distance(output_headers.begin(),class_idx_it);
1162            //add to confusion matrix row in proportion
1163            output_values[class_idx] += 1.0;
1164        } else {
1165            //otherwise assign to background class
1166            output_values[output_values.size()-1] += 1.0;
1167        }
1168
1169        //check break conditions if breaking on certain level of recall
1170        if (cond == CV_VOC_CCOND_RECALL)
1171        {
1172            if(static_cast<float>(retrieved_hits)/static_cast<float>(total_relevant) >= threshold) break;
1173        }
1174    }
1175
1176    /* finally, normalize confusion matrix row */
1177    for (vector<float>::iterator it = output_values.begin(); it < output_values.end(); ++it)
1178    {
1179        (*it) /= static_cast<float>(total_objects);
1180    }
1181}
1182
1183//Save Precision-Recall results to a p-r curve in GNUPlot format
1184//--------------------------------------------------------------
1185//INPUTS:
1186// - output_file        The file to which to save the GNUPlot data file. If only a filename is specified, the data
1187//                      file is saved to the standard VOC results directory.
1188// - precision          Vector of precisions as returned from calcClassifier/DetectorPrecRecall
1189// - recall             Vector of recalls as returned from calcClassifier/DetectorPrecRecall
1190// - ap                ap as returned from calcClassifier/DetectorPrecRecall
1191// - (title)            Title to use for the plot (if not specified, just the ap is printed as the title)
1192//                      This also specifies the filename of the output file if printing to pdf
1193// - (plot_type)        Specifies whether to instruct GNUPlot to save to a PDF file (CV_VOC_PLOT_PDF) or directly
1194//                      to screen (CV_VOC_PLOT_SCREEN) in the datafile
1195//NOTES:
1196// The GNUPlot data file can be executed using GNUPlot from the commandline in the following way:
1197//      >> GNUPlot <output_file>
1198// This will then display the p-r curve on the screen or save it to a pdf file depending on plot_type
1199
1200void VocData::savePrecRecallToGnuplot(const string& output_file, const vector<float>& precision, const vector<float>& recall, const float ap, const string title, const VocPlotType plot_type)
1201{
1202    string output_file_std = checkFilenamePathsep(output_file);
1203
1204    //if no directory is specified, by default save the output file in the results directory
1205//    if (output_file_std.find("/") == output_file_std.npos)
1206//    {
1207//        output_file_std = m_results_directory + output_file_std;
1208//    }
1209
1210    std::ofstream plot_file(output_file_std.c_str());
1211
1212    if (plot_file.is_open())
1213    {
1214        plot_file << "set xrange [0:1]" << endl;
1215        plot_file << "set yrange [0:1]" << endl;
1216        plot_file << "set size square" << endl;
1217        string title_text = title;
1218        if (title_text.size() == 0) title_text = "Precision-Recall Curve";
1219        plot_file << "set title \"" << title_text << " (ap: " << ap << ")\"" << endl;
1220        plot_file << "set xlabel \"Recall\"" << endl;
1221        plot_file << "set ylabel \"Precision\"" << endl;
1222        plot_file << "set style data lines" << endl;
1223        plot_file << "set nokey" << endl;
1224        if (plot_type == CV_VOC_PLOT_PNG)
1225        {
1226            plot_file << "set terminal png" << endl;
1227            string pdf_filename;
1228            if (title.size() != 0)
1229            {
1230                pdf_filename = title;
1231            } else {
1232                pdf_filename = "prcurve";
1233            }
1234            plot_file << "set out \"" << title << ".png\"" << endl;
1235        }
1236        plot_file << "plot \"-\" using 1:2" << endl;
1237        plot_file << "# X Y" << endl;
1238        CV_Assert(precision.size() == recall.size());
1239        for (size_t i = 0; i < precision.size(); ++i)
1240        {
1241            plot_file << "  " << recall[i] << " " << precision[i] << endl;
1242        }
1243        plot_file << "end" << endl;
1244        if (plot_type == CV_VOC_PLOT_SCREEN)
1245        {
1246            plot_file << "pause -1" << endl;
1247        }
1248        plot_file.close();
1249    } else {
1250        string err_msg = "could not open plot file '" + output_file_std + "' for writing.";
1251        CV_Error(CV_StsError,err_msg.c_str());
1252    }
1253}
1254
1255void VocData::readClassifierGroundTruth(const string& obj_class, const ObdDatasetType dataset, vector<ObdImage>& images, vector<char>& object_present)
1256{
1257    images.clear();
1258
1259    string gtFilename = m_class_imageset_path;
1260    gtFilename.replace(gtFilename.find("%s"),2,obj_class);
1261    if (dataset == CV_OBD_TRAIN)
1262    {
1263        gtFilename.replace(gtFilename.find("%s"),2,m_train_set);
1264    } else {
1265        gtFilename.replace(gtFilename.find("%s"),2,m_test_set);
1266    }
1267
1268    vector<string> image_codes;
1269    readClassifierGroundTruth(gtFilename, image_codes, object_present);
1270
1271    convertImageCodesToObdImages(image_codes, images);
1272}
1273
1274void VocData::readClassifierResultsFile(const std:: string& input_file, vector<ObdImage>& images, vector<float>& scores)
1275{
1276    images.clear();
1277
1278    string input_file_std = checkFilenamePathsep(input_file);
1279
1280    //if no directory is specified, by default search for the input file in the results directory
1281//    if (input_file_std.find("/") == input_file_std.npos)
1282//    {
1283//        input_file_std = m_results_directory + input_file_std;
1284//    }
1285
1286    vector<string> image_codes;
1287    readClassifierResultsFile(input_file_std, image_codes, scores);
1288
1289    convertImageCodesToObdImages(image_codes, images);
1290}
1291
1292void VocData::readDetectorResultsFile(const string& input_file, vector<ObdImage>& images, vector<vector<float> >& scores, vector<vector<Rect> >& bounding_boxes)
1293{
1294    images.clear();
1295
1296    string input_file_std = checkFilenamePathsep(input_file);
1297
1298    //if no directory is specified, by default search for the input file in the results directory
1299//    if (input_file_std.find("/") == input_file_std.npos)
1300//    {
1301//        input_file_std = m_results_directory + input_file_std;
1302//    }
1303
1304    vector<string> image_codes;
1305    readDetectorResultsFile(input_file_std, image_codes, scores, bounding_boxes);
1306
1307    convertImageCodesToObdImages(image_codes, images);
1308}
1309
//Returns the list of object class names for the loaded dataset (populated by
//initVoc2007to2010). The returned reference remains valid for the lifetime of
//this VocData instance.
const vector<string>& VocData::getObjectClasses()
{
    return m_object_classes;
}
1314
1315//string VocData::getResultsDirectory()
1316//{
1317//    return m_results_directory;
1318//}
1319
1320//---------------------------------------------------------
1321// Protected Functions ------------------------------------
1322//---------------------------------------------------------
1323
//Extracts the dataset name (e.g. "VOC2007") from a dataset root path by taking
//the text after the last '/' separator, falling back to '\' separators, and
//returning the input unchanged when no separator is present.
std::string getVocName( const std::string& vocPath )
{
    std::string::size_type sep = vocPath.rfind('/');
    if (sep == std::string::npos)
        sep = vocPath.rfind('\\');
    return (sep == std::string::npos) ? vocPath : vocPath.substr(sep + 1);
}
1335
//Top-level dataset initializer. All VOC releases supported by this sample
//(2007-2010) share the same directory layout, so this simply forwards to
//initVoc2007to2010.
void VocData::initVoc( const string& vocPath, const bool useTestDataset )
{
    initVoc2007to2010( vocPath, useTestDataset );
}
1340
//Initialize file paths and settings for the VOC 2007-2010 datasets
//-----------------------------------------------------------------
1343void VocData::initVoc2007to2010( const string& vocPath, const bool useTestDataset )
1344{
1345    //check format of root directory and modify if necessary
1346
1347    m_vocName = getVocName( vocPath );
1348
1349    CV_Assert( !m_vocName.compare("VOC2007") || !m_vocName.compare("VOC2008") ||
1350               !m_vocName.compare("VOC2009") || !m_vocName.compare("VOC2010") )
1351
1352    m_vocPath = checkFilenamePathsep( vocPath, true );
1353
1354    if (useTestDataset)
1355    {
1356        m_train_set = "trainval";
1357        m_test_set = "test";
1358    } else {
1359        m_train_set = "train";
1360        m_test_set = "val";
1361    }
1362
1363    // initialize main classification/detection challenge paths
1364    m_annotation_path = m_vocPath + "/Annotations/%s.xml";
1365    m_image_path = m_vocPath + "/JPEGImages/%s.jpg";
1366    m_imageset_path = m_vocPath + "/ImageSets/Main/%s.txt";
1367    m_class_imageset_path = m_vocPath + "/ImageSets/Main/%s_%s.txt";
1368
1369    //define available object_classes for VOC2010 dataset
1370    m_object_classes.push_back("aeroplane");
1371    m_object_classes.push_back("bicycle");
1372    m_object_classes.push_back("bird");
1373    m_object_classes.push_back("boat");
1374    m_object_classes.push_back("bottle");
1375    m_object_classes.push_back("bus");
1376    m_object_classes.push_back("car");
1377    m_object_classes.push_back("cat");
1378    m_object_classes.push_back("chair");
1379    m_object_classes.push_back("cow");
1380    m_object_classes.push_back("diningtable");
1381    m_object_classes.push_back("dog");
1382    m_object_classes.push_back("horse");
1383    m_object_classes.push_back("motorbike");
1384    m_object_classes.push_back("person");
1385    m_object_classes.push_back("pottedplant");
1386    m_object_classes.push_back("sheep");
1387    m_object_classes.push_back("sofa");
1388    m_object_classes.push_back("train");
1389    m_object_classes.push_back("tvmonitor");
1390
1391    m_min_overlap = 0.5;
1392
1393    //up until VOC 2010, ap was calculated by sampling p-r curve, not taking complete curve
1394    m_sampled_ap = ((m_vocName == "VOC2007") || (m_vocName == "VOC2008") || (m_vocName == "VOC2009"));
1395}
1396
1397//Read a VOC classification ground truth text file for a given object class and dataset
1398//-------------------------------------------------------------------------------------
1399//INPUTS:
1400// - filename           The path of the text file to read
1401//OUTPUTS:
1402// - image_codes        VOC image codes extracted from the GT file in the form 20XX_XXXXXX where the first four
1403//                          digits specify the year of the dataset, and the last group specifies a unique ID
1404// - object_present     For each image in the 'image_codes' array, specifies whether the object class described
1405//                          in the loaded GT file is present or not
1406void VocData::readClassifierGroundTruth(const string& filename, vector<string>& image_codes, vector<char>& object_present)
1407{
1408    image_codes.clear();
1409    object_present.clear();
1410
1411    std::ifstream gtfile(filename.c_str());
1412    if (!gtfile.is_open())
1413    {
1414        string err_msg = "could not open VOC ground truth textfile '" + filename + "'.";
1415        CV_Error(CV_StsError,err_msg.c_str());
1416    }
1417
1418    string line;
1419    string image;
1420    int obj_present;
1421    while (!gtfile.eof())
1422    {
1423        std::getline(gtfile,line);
1424        std::istringstream iss(line);
1425        iss >> image >> obj_present;
1426        if (!iss.fail())
1427        {
1428            image_codes.push_back(image);
1429            object_present.push_back(obj_present == 1);
1430        } else {
1431            if (!gtfile.eof()) CV_Error(CV_StsParseError,"error parsing VOC ground truth textfile.");
1432        }
1433    }
1434    gtfile.close();
1435}
1436
1437void VocData::readClassifierResultsFile(const string& input_file, vector<string>& image_codes, vector<float>& scores)
1438{
1439    //check if results file exists
1440    std::ifstream result_file(input_file.c_str());
1441    if (result_file.is_open())
1442    {
1443        string line;
1444        string image;
1445        float score;
1446        //read in the results file
1447        while (!result_file.eof())
1448        {
1449            std::getline(result_file,line);
1450            std::istringstream iss(line);
1451            iss >> image >> score;
1452            if (!iss.fail())
1453            {
1454                image_codes.push_back(image);
1455                scores.push_back(score);
1456            } else {
1457                if(!result_file.eof()) CV_Error(CV_StsParseError,"error parsing VOC classifier results file.");
1458            }
1459        }
1460        result_file.close();
1461    } else {
1462        string err_msg = "could not open classifier results file '" + input_file + "' for reading.";
1463        CV_Error(CV_StsError,err_msg.c_str());
1464    }
1465}
1466
//Read a VOC detection results file (lines of "<image_code> <confidence> <left> <top> <right> <bottom>")
//------------------------------------------------------------------------------------------------------
//OUTPUTS:
// - image_codes        unique image codes encountered in the file
// - scores             for each image code, the confidences of all its detections
// - bounding_boxes     for each image code, the corresponding detection rectangles,
//                      converted from 1-based left/top/right/bottom to 0-based x/y/width/height
void VocData::readDetectorResultsFile(const string& input_file, vector<string>& image_codes, vector<vector<float> >& scores, vector<vector<Rect> >& bounding_boxes)
{
    image_codes.clear();
    scores.clear();
    bounding_boxes.clear();

    //check if results file exists
    std::ifstream result_file(input_file.c_str());
    if (result_file.is_open())
    {
        string line;
        string image;
        Rect bounding_box;
        float score;
        //read in the results file
        while (!result_file.eof())
        {
            std::getline(result_file,line);
            std::istringstream iss(line);
            //width/height temporarily hold the right/bottom coordinates until converted below
            iss >> image >> score >> bounding_box.x >> bounding_box.y >> bounding_box.width >> bounding_box.height;
            if (!iss.fail())
            {
                //convert right and bottom positions to width and height
                bounding_box.width -= bounding_box.x;
                bounding_box.height -= bounding_box.y;
                //convert to 0-indexing
                bounding_box.x -= 1;
                bounding_box.y -= 1;
                //store in output vectors
                /* first check if the current image code has been seen before */
                vector<string>::iterator image_codes_it = std::find(image_codes.begin(),image_codes.end(),image);
                if (image_codes_it == image_codes.end())
                {
                    //first detection for this image: start new per-image rows in the 2D outputs
                    image_codes.push_back(image);
                    vector<float> score_vect(1);
                    score_vect[0] = score;
                    scores.push_back(score_vect);
                    vector<Rect> bounding_box_vect(1);
                    bounding_box_vect[0] = bounding_box;
                    bounding_boxes.push_back(bounding_box_vect);
                } else {
                    /* if the image index has been seen before, add the current object below it in the 2D arrays */
                    int image_idx = std::distance(image_codes.begin(),image_codes_it);
                    scores[image_idx].push_back(score);
                    bounding_boxes[image_idx].push_back(bounding_box);
                }
            } else {
                //a failed parse on the final (empty) line is the normal EOF condition; elsewhere it is an error
                if(!result_file.eof()) CV_Error(CV_StsParseError,"error parsing VOC detector results file.");
            }
        }
        result_file.close();
    } else {
        string err_msg = "could not open detector results file '" + input_file + "' for reading.";
        CV_Error(CV_StsError,err_msg.c_str());
    }
}
1523
1524
1525//Read a VOC annotation xml file for a given image
1526//------------------------------------------------
1527//INPUTS:
1528// - filename           The path of the xml file to read
1529//OUTPUTS:
// - objects            Array of ObdObject describing all object instances present in the given image
// - object_data        Array of VocObjectData giving the VOC-specific details (difficult/occluded/truncated flags, pose) for each object
void VocData::extractVocObjects(const string filename, vector<ObdObject>& objects, vector<VocObjectData>& object_data)
{
#ifdef PR_DEBUG
    int block = 1;
    cout << "SAMPLE VOC OBJECT EXTRACTION for " << filename << ":" << endl;
#endif
    objects.clear();
    object_data.clear();

    string contents, object_contents, tag_contents;

    //load the whole annotation XML into memory, then parse it with extractXMLBlock
    readFileToString(filename, contents);

    //keep on extracting 'object' blocks until no more can be found
    if (extractXMLBlock(contents, "annotation", 0, contents) != -1)
    {
        int searchpos = 0;
        searchpos = extractXMLBlock(contents, "object", searchpos, object_contents);
        while (searchpos != -1)
        {
#ifdef PR_DEBUG
            cout << "SEARCHPOS:" << searchpos << endl;
            cout << "start block " << block << " ---------" << endl;
            cout << object_contents << endl;
            cout << "end block " << block << " -----------" << endl;
            ++block;
#endif

            ObdObject object;
            VocObjectData object_d;

            //object class -------------

            if (extractXMLBlock(object_contents, "name", 0, tag_contents) == -1) CV_Error(CV_StsError,"missing <name> tag in object definition of '" + filename + "'");
            object.object_class.swap(tag_contents);

            //object bounding box -------------
            //coordinates in the XML are 1-based and inclusive

            int xmax, xmin, ymax, ymin;

            if (extractXMLBlock(object_contents, "xmax", 0, tag_contents) == -1) CV_Error(CV_StsError,"missing <xmax> tag in object definition of '" + filename + "'");
            xmax = stringToInteger(tag_contents);

            if (extractXMLBlock(object_contents, "xmin", 0, tag_contents) == -1) CV_Error(CV_StsError,"missing <xmin> tag in object definition of '" + filename + "'");
            xmin = stringToInteger(tag_contents);

            if (extractXMLBlock(object_contents, "ymax", 0, tag_contents) == -1) CV_Error(CV_StsError,"missing <ymax> tag in object definition of '" + filename + "'");
            ymax = stringToInteger(tag_contents);

            if (extractXMLBlock(object_contents, "ymin", 0, tag_contents) == -1) CV_Error(CV_StsError,"missing <ymin> tag in object definition of '" + filename + "'");
            ymin = stringToInteger(tag_contents);

            object.boundingBox.x = xmin-1;      //convert to 0-based indexing
            object.boundingBox.width = xmax - xmin;
            object.boundingBox.y = ymin-1;
            object.boundingBox.height = ymax - ymin;

            //coordinates are 1-based in VOC, so a 0 value indicates a corrupt annotation
            CV_Assert(xmin != 0);
            CV_Assert(xmax > xmin);
            CV_Assert(ymin != 0);
            CV_Assert(ymax > ymin);


            //object tags -------------
            //the flag tags are optional; absence is treated as "0" (false)

            if (extractXMLBlock(object_contents, "difficult", 0, tag_contents) != -1)
            {
                object_d.difficult = (tag_contents == "1");
            } else object_d.difficult = false;
            if (extractXMLBlock(object_contents, "occluded", 0, tag_contents) != -1)
            {
                object_d.occluded = (tag_contents == "1");
            } else object_d.occluded = false;
            if (extractXMLBlock(object_contents, "truncated", 0, tag_contents) != -1)
            {
                object_d.truncated = (tag_contents == "1");
            } else object_d.truncated = false;
            //NOTE(review): 'pose' is not assigned when the <pose> tag is absent or its value
            //is unrecognized -- presumably VocObjectData's constructor defaults it; verify.
            if (extractXMLBlock(object_contents, "pose", 0, tag_contents) != -1)
            {
                if (tag_contents == "Frontal") object_d.pose = CV_VOC_POSE_FRONTAL;
                if (tag_contents == "Rear") object_d.pose = CV_VOC_POSE_REAR;
                if (tag_contents == "Left") object_d.pose = CV_VOC_POSE_LEFT;
                if (tag_contents == "Right") object_d.pose = CV_VOC_POSE_RIGHT;
            }

            //add to array of objects
            objects.push_back(object);
            object_data.push_back(object_d);

            //extract next 'object' block from file if it exists
            searchpos = extractXMLBlock(contents, "object", searchpos, object_contents);
        }
    }
}
1625
//Returns the full path on disk of the image with the given identifier (in the
//format YYYY_XXXXXX) by substituting it into the image path template
//-----------------------------------------------------------------------------
1629string VocData::getImagePath(const string& input_str)
1630{
1631    string path = m_image_path;
1632    path.replace(path.find("%s"),2,input_str);
1633    return path;
1634}
1635
1636//Tests two boundary boxes for overlap (using the intersection over union metric) and returns the overlap if the objects
1637//defined by the two bounding boxes are considered to be matched according to the criterion outlined in
1638//the VOC documentation [namely intersection/union > some threshold] otherwise returns -1.0 (no match)
1639//----------------------------------------------------------------------------------------------------------
1640float VocData::testBoundingBoxesForOverlap(const Rect detection, const Rect ground_truth)
1641{
1642    int detection_x2 = detection.x + detection.width;
1643    int detection_y2 = detection.y + detection.height;
1644    int ground_truth_x2 = ground_truth.x + ground_truth.width;
1645    int ground_truth_y2 = ground_truth.y + ground_truth.height;
1646    //first calculate the boundaries of the intersection of the rectangles
1647    int intersection_x = std::max(detection.x, ground_truth.x); //rightmost left
1648    int intersection_y = std::max(detection.y, ground_truth.y); //bottommost top
1649    int intersection_x2 = std::min(detection_x2, ground_truth_x2); //leftmost right
1650    int intersection_y2 = std::min(detection_y2, ground_truth_y2); //topmost bottom
1651    //then calculate the width and height of the intersection rect
1652    int intersection_width = intersection_x2 - intersection_x + 1;
1653    int intersection_height = intersection_y2 - intersection_y + 1;
1654    //if there is no overlap then return false straight away
1655    if ((intersection_width <= 0) || (intersection_height <= 0)) return -1.0;
1656    //otherwise calculate the intersection
1657    int intersection_area = intersection_width*intersection_height;
1658
1659    //now calculate the union
1660    int union_area = (detection.width+1)*(detection.height+1) + (ground_truth.width+1)*(ground_truth.height+1) - intersection_area;
1661
1662    //calculate the intersection over union and use as threshold as per VOC documentation
1663    float overlap = static_cast<float>(intersection_area)/static_cast<float>(union_area);
1664    if (overlap > m_min_overlap)
1665    {
1666        return overlap;
1667    } else {
1668        return -1.0;
1669    }
1670}
1671
1672//Extracts the object class and dataset from the filename of a VOC standard results text file, which takes
1673//the format 'comp<n>_{cls/det}_<dataset>_<objclass>.txt'
1674//----------------------------------------------------------------------------------------------------------
1675void VocData::extractDataFromResultsFilename(const string& input_file, string& class_name, string& dataset_name)
1676{
1677    string input_file_std = checkFilenamePathsep(input_file);
1678
1679    size_t fnamestart = input_file_std.rfind("/");
1680    size_t fnameend = input_file_std.rfind(".txt");
1681
1682    if ((fnamestart == input_file_std.npos) || (fnameend == input_file_std.npos))
1683        CV_Error(CV_StsError,"Could not extract filename of results file.");
1684
1685    ++fnamestart;
1686    if (fnamestart >= fnameend)
1687        CV_Error(CV_StsError,"Could not extract filename of results file.");
1688
1689    //extract dataset and class names, triggering exception if the filename format is not correct
1690    string filename = input_file_std.substr(fnamestart, fnameend-fnamestart);
1691    size_t datasetstart = filename.find("_");
1692    datasetstart = filename.find("_",datasetstart+1);
1693    size_t classstart = filename.find("_",datasetstart+1);
1694    //allow for appended index after a further '_' by discarding this part if it exists
1695    size_t classend = filename.find("_",classstart+1);
1696    if (classend == filename.npos) classend = filename.size();
1697    if ((datasetstart == filename.npos) || (classstart == filename.npos))
1698        CV_Error(CV_StsError,"Error parsing results filename. Is it in standard format of 'comp<n>_{cls/det}_<dataset>_<objclass>.txt'?");
1699    ++datasetstart;
1700    ++classstart;
1701    if (((datasetstart-classstart) < 1) || ((classend-datasetstart) < 1))
1702        CV_Error(CV_StsError,"Error parsing results filename. Is it in standard format of 'comp<n>_{cls/det}_<dataset>_<objclass>.txt'?");
1703
1704    dataset_name = filename.substr(datasetstart,classstart-datasetstart-1);
1705    class_name = filename.substr(classstart,classend-classstart);
1706}
1707
//Returns the classification ground truth (object present or not) for a single image
//----------------------------------------------------------------------------------
//INPUTS:
// - obj_class      object class whose ground truth is queried
// - id             VOC image code of the image
//RETURN VALUE:
// - true if the class is present in the image; raises CV_Error if the image code is
//   not found in either the training or test set of the class
bool VocData::getClassifierGroundTruthImage(const string& obj_class, const string& id)
{
    /* if the classifier ground truth data for all images of the current class has not been loaded yet, load it now */
    if (m_classifier_gt_all_ids.empty() || (m_classifier_gt_class != obj_class))
    {
        //cache is empty or holds a different class: rebuild it for 'obj_class'
        m_classifier_gt_all_ids.clear();
        m_classifier_gt_all_present.clear();
        m_classifier_gt_class = obj_class;
        for (int i=0; i<2; ++i) //run twice (once over test set and once over training set)
        {
            //generate the filename of the classification ground-truth textfile for the object class
            string gtFilename = m_class_imageset_path;
            gtFilename.replace(gtFilename.find("%s"),2,obj_class);
            if (i == 0)
            {
                gtFilename.replace(gtFilename.find("%s"),2,m_train_set);
            } else {
                gtFilename.replace(gtFilename.find("%s"),2,m_test_set);
            }

            //parse the ground truth file, storing in two separate vectors
            //for the image code and the ground truth value
            vector<string> image_codes;
            vector<char> object_present;
            readClassifierGroundTruth(gtFilename, image_codes, object_present);

            //append this set's entries to the class-wide cache (parallel arrays)
            m_classifier_gt_all_ids.insert(m_classifier_gt_all_ids.end(),image_codes.begin(),image_codes.end());
            m_classifier_gt_all_present.insert(m_classifier_gt_all_present.end(),object_present.begin(),object_present.end());

            CV_Assert(m_classifier_gt_all_ids.size() == m_classifier_gt_all_present.size());
        }
    }


    //search for the image code
    vector<string>::iterator it = find (m_classifier_gt_all_ids.begin(), m_classifier_gt_all_ids.end(), id);
    if (it != m_classifier_gt_all_ids.end())
    {
        //image found, so return corresponding ground truth
        return m_classifier_gt_all_present[std::distance(m_classifier_gt_all_ids.begin(),it)];
    } else {
        string err_msg = "could not find classifier ground truth for image '" + id + "' and class '" + obj_class + "'";
        CV_Error(CV_StsError,err_msg.c_str());
    }

    //unreachable: CV_Error above does not return normally
    return true;
}
1755
1756//-------------------------------------------------------------------
1757// Protected Functions (utility) ------------------------------------
1758//-------------------------------------------------------------------
1759
1760//returns a vector containing indexes of the input vector in sorted ascending/descending order
1761void VocData::getSortOrder(const vector<float>& values, vector<size_t>& order, bool descending)
1762{
1763    /* 1. store sorting order in 'order_pair' */
1764    vector<std::pair<size_t, vector<float>::const_iterator> > order_pair(values.size());
1765
1766    size_t n = 0;
1767    for (vector<float>::const_iterator it = values.begin(); it != values.end(); ++it, ++n)
1768        order_pair[n] = make_pair(n, it);
1769
1770    std::sort(order_pair.begin(),order_pair.end(),orderingSorter());
1771    if (descending == false) std::reverse(order_pair.begin(),order_pair.end());
1772
1773    vector<size_t>(order_pair.size()).swap(order);
1774    for (size_t i = 0; i < order_pair.size(); ++i)
1775    {
1776        order[i] = order_pair[i].first;
1777    }
1778}
1779
1780void VocData::readFileToString(const string filename, string& file_contents)
1781{
1782    std::ifstream ifs(filename.c_str());
1783    if (ifs == false) CV_Error(CV_StsError,"could not open text file");
1784
1785    stringstream oss;
1786    oss << ifs.rdbuf();
1787
1788    file_contents = oss.str();
1789}
1790
1791int VocData::stringToInteger(const string input_str)
1792{
1793    int result;
1794
1795    stringstream ss(input_str);
1796    if ((ss >> result).fail())
1797    {
1798        CV_Error(CV_StsBadArg,"could not perform string to integer conversion");
1799    }
1800    return result;
1801}
1802
1803string VocData::integerToString(const int input_int)
1804{
1805    string result;
1806
1807    stringstream ss;
1808    if ((ss << input_int).fail())
1809    {
1810        CV_Error(CV_StsBadArg,"could not perform integer to string conversion");
1811    }
1812    result = ss.str();
1813    return result;
1814}
1815
1816string VocData::checkFilenamePathsep( const string filename, bool add_trailing_slash )
1817{
1818    string filename_new = filename;
1819
1820    size_t pos = filename_new.find("\\\\");
1821    while (pos != filename_new.npos)
1822    {
1823        filename_new.replace(pos,2,"/");
1824        pos = filename_new.find("\\\\", pos);
1825    }
1826    pos = filename_new.find("\\");
1827    while (pos != filename_new.npos)
1828    {
1829        filename_new.replace(pos,2,"/");
1830        pos = filename_new.find("\\", pos);
1831    }
1832    if (add_trailing_slash)
1833    {
1834        //add training slash if this is missing
1835        if (filename_new.rfind("/") != filename_new.length()-1) filename_new += "/";
1836    }
1837
1838    return filename_new;
1839}
1840
1841void VocData::convertImageCodesToObdImages(const vector<string>& image_codes, vector<ObdImage>& images)
1842{
1843    images.clear();
1844    images.reserve(image_codes.size());
1845
1846    string path;
1847    //transfer to output arrays
1848    for (size_t i = 0; i < image_codes.size(); ++i)
1849    {
1850        //generate image path and indices from extracted string code
1851        path = getImagePath(image_codes[i]);
1852        images.push_back(ObdImage(image_codes[i], path));
1853    }
1854}
1855
1856//Extract text from within a given tag from an XML file
1857//-----------------------------------------------------
1858//INPUTS:
1859// - src            XML source file
1860// - tag            XML tag delimiting block to extract
1861// - searchpos      position within src at which to start search
1862//OUTPUTS:
1863// - tag_contents   text extracted between <tag> and </tag> tags
1864//RETURN VALUE:
1865// - the position of the final character extracted in tag_contents within src
1866//      (can be used to call extractXMLBlock recursively to extract multiple blocks)
1867//      returns -1 if the tag could not be found
int VocData::extractXMLBlock(const string src, const string tag, const int searchpos, string& tag_contents)
{
    size_t startpos, next_startpos, endpos;
    //number of currently-open <tag> elements; starts at 1 for the opening tag found below
    int embed_count = 1;

    //find position of opening tag
    startpos = src.find("<" + tag + ">", searchpos);
    if (startpos == string::npos) return -1;

    //initialize endpos -
    // start searching for end tag anywhere after opening tag
    endpos = startpos;

    //find position of next opening tag
    next_startpos = src.find("<" + tag + ">", startpos+1);

    //match opening tags with closing tags, and only
    //accept final closing tag of same level as original
    //opening tag
    while (embed_count > 0)
    {
        endpos = src.find("</" + tag + ">", endpos+1);
        if (endpos == string::npos) return -1;

        //the next code is only executed if there are embedded tags with the same name
        if (next_startpos != string::npos)
        {
            while (next_startpos<endpos)
            {
                //counting embedded start tags
                ++embed_count;
                next_startpos = src.find("<" + tag + ">", next_startpos+1);
                if (next_startpos == string::npos) break;
            }
        }
        //passing end tag so decrement nesting level
        --embed_count;
    }

    //finally, extract the tag region
    //skip past "<tag>": tag length plus 2 for the angle brackets
    startpos += tag.length() + 2;
    if (startpos > src.length()) return -1;
    if (endpos > src.length()) return -1;
    tag_contents = src.substr(startpos,endpos-startpos);
    return static_cast<int>(endpos);
}
1914
1915/****************************************************************************************\
1916*                            Sample on image classification                             *
1917\****************************************************************************************/
1918//
// This part of the code has been partially refactored
1920//
1921struct DDMParams
1922{
1923    DDMParams() : detectorType("SURF"), descriptorType("SURF"), matcherType("BruteForce") {}
1924    DDMParams( const string _detectorType, const string _descriptorType, const string& _matcherType ) :
1925        detectorType(_detectorType), descriptorType(_descriptorType), matcherType(_matcherType){}
1926    void read( const FileNode& fn )
1927    {
1928        fn["detectorType"] >> detectorType;
1929        fn["descriptorType"] >> descriptorType;
1930        fn["matcherType"] >> matcherType;
1931    }
1932    void write( FileStorage& fs ) const
1933    {
1934        fs << "detectorType" << detectorType;
1935        fs << "descriptorType" << descriptorType;
1936        fs << "matcherType" << matcherType;
1937    }
1938    void print() const
1939    {
1940        cout << "detectorType: " << detectorType << endl;
1941        cout << "descriptorType: " << descriptorType << endl;
1942        cout << "matcherType: " << matcherType << endl;
1943    }
1944
1945    string detectorType;
1946    string descriptorType;
1947    string matcherType;
1948};
1949
1950struct VocabTrainParams
1951{
1952    VocabTrainParams() : trainObjClass("chair"), vocabSize(1000), memoryUse(200), descProportion(0.3f) {}
1953    VocabTrainParams( const string _trainObjClass, size_t _vocabSize, size_t _memoryUse, float _descProportion ) :
1954            trainObjClass(_trainObjClass), vocabSize(_vocabSize), memoryUse(_memoryUse), descProportion(_descProportion) {}
1955    void read( const FileNode& fn )
1956    {
1957        fn["trainObjClass"] >> trainObjClass;
1958        fn["vocabSize"] >> vocabSize;
1959        fn["memoryUse"] >> memoryUse;
1960        fn["descProportion"] >> descProportion;
1961    }
1962    void write( FileStorage& fs ) const
1963    {
1964        fs << "trainObjClass" << trainObjClass;
1965        fs << "vocabSize" << vocabSize;
1966        fs << "memoryUse" << memoryUse;
1967        fs << "descProportion" << descProportion;
1968    }
1969    void print() const
1970    {
1971        cout << "trainObjClass: " << trainObjClass << endl;
1972        cout << "vocabSize: " << vocabSize << endl;
1973        cout << "memoryUse: " << memoryUse << endl;
1974        cout << "descProportion: " << descProportion << endl;
1975    }
1976
1977
1978    string trainObjClass; // Object class used for training visual vocabulary.
1979                          // It shouldn't matter which object class is specified here - visual vocab will still be the same.
1980    int vocabSize; //number of visual words in vocabulary to train
1981    int memoryUse; // Memory to preallocate (in MB) when training vocab.
1982                      // Change this depending on the size of the dataset/available memory.
1983    float descProportion; // Specifies the number of descriptors to use from each image as a proportion of the total num descs.
1984};
1985
1986struct SVMTrainParamsExt
1987{
1988    SVMTrainParamsExt() : descPercent(0.5f), targetRatio(0.4f), balanceClasses(true) {}
1989    SVMTrainParamsExt( float _descPercent, float _targetRatio, bool _balanceClasses,
1990                       int _svmType, int _kernelType, double _degree, double _gamma, double _coef0,
1991                       double _C, double _nu, double _p, Mat& _class_weights, TermCriteria _termCrit ) :
1992            descPercent(_descPercent), targetRatio(_targetRatio), balanceClasses(_balanceClasses) {}
1993    void read( const FileNode& fn )
1994    {
1995        fn["descPercent"] >> descPercent;
1996        fn["targetRatio"] >> targetRatio;
1997        fn["balanceClasses"] >> balanceClasses;
1998    }
1999    void write( FileStorage& fs ) const
2000    {
2001        fs << "descPercent" << descPercent;
2002        fs << "targetRatio" << targetRatio;
2003        fs << "balanceClasses" << balanceClasses;
2004    }
2005    void print() const
2006    {
2007        cout << "descPercent: " << descPercent << endl;
2008        cout << "targetRatio: " << targetRatio << endl;
2009        cout << "balanceClasses: " << balanceClasses << endl;
2010    }
2011
2012    float descPercent; // Percentage of extracted descriptors to use for training.
2013    float targetRatio; // Try to get this ratio of positive to negative samples (minimum).
2014    bool balanceClasses;    // Balance class weights by number of samples in each (if true cSvmTrainTargetRatio is ignored).
2015};
2016
2017void readUsedParams( const FileNode& fn, string& vocName, DDMParams& ddmParams, VocabTrainParams& vocabTrainParams, SVMTrainParamsExt& svmTrainParamsExt )
2018{
2019    fn["vocName"] >> vocName;
2020
2021    FileNode currFn = fn;
2022
2023    currFn = fn["ddmParams"];
2024    ddmParams.read( currFn );
2025
2026    currFn = fn["vocabTrainParams"];
2027    vocabTrainParams.read( currFn );
2028
2029    currFn = fn["svmTrainParamsExt"];
2030    svmTrainParamsExt.read( currFn );
2031}
2032
2033void writeUsedParams( FileStorage& fs, const string& vocName, const DDMParams& ddmParams, const VocabTrainParams& vocabTrainParams, const SVMTrainParamsExt& svmTrainParamsExt )
2034{
2035    fs << "vocName" << vocName;
2036
2037    fs << "ddmParams" << "{";
2038    ddmParams.write(fs);
2039    fs << "}";
2040
2041    fs << "vocabTrainParams" << "{";
2042    vocabTrainParams.write(fs);
2043    fs << "}";
2044
2045    fs << "svmTrainParamsExt" << "{";
2046    svmTrainParamsExt.write(fs);
2047    fs << "}";
2048}
2049
2050void printUsedParams( const string& vocPath, const string& resDir,
2051                      const DDMParams& ddmParams, const VocabTrainParams& vocabTrainParams,
2052                      const SVMTrainParamsExt& svmTrainParamsExt )
2053{
2054    cout << "CURRENT CONFIGURATION" << endl;
2055    cout << "----------------------------------------------------------------" << endl;
2056    cout << "vocPath: " << vocPath << endl;
2057    cout << "resDir: " << resDir << endl;
2058    cout << endl; ddmParams.print();
2059    cout << endl; vocabTrainParams.print();
2060    cout << endl; svmTrainParamsExt.print();
2061    cout << "----------------------------------------------------------------" << endl << endl;
2062}
2063
2064bool readVocabulary( const string& filename, Mat& vocabulary )
2065{
2066    cout << "Reading vocabulary...";
2067    FileStorage fs( filename, FileStorage::READ );
2068    if( fs.isOpened() )
2069    {
2070        fs["vocabulary"] >> vocabulary;
2071        cout << "done" << endl;
2072        return true;
2073    }
2074    return false;
2075}
2076
2077bool writeVocabulary( const string& filename, const Mat& vocabulary )
2078{
2079    cout << "Saving vocabulary..." << endl;
2080    FileStorage fs( filename, FileStorage::WRITE );
2081    if( fs.isOpened() )
2082    {
2083        fs << "vocabulary" << vocabulary;
2084        return true;
2085    }
2086    return false;
2087}
2088
// Return the visual-word vocabulary: load it from 'filename' if a previous run saved
// one there, otherwise train it with k-means on descriptors randomly sampled from the
// training images of trainParams.trainObjClass, then save it to 'filename'.
Mat trainVocabulary( const string& filename, VocData& vocData, const VocabTrainParams& trainParams,
                     const Ptr<FeatureDetector>& fdetector, const Ptr<DescriptorExtractor>& dextractor )
{
    Mat vocabulary;
    if( !readVocabulary( filename, vocabulary) )
    {
        // Descriptors must be single-channel float: the byte-size computation below
        // (4 bytes per element) relies on it.
        CV_Assert( dextractor->descriptorType() == CV_32FC1 );
        const int descByteSize = dextractor->descriptorSize()*4;
        const int maxDescCount = (trainParams.memoryUse * 1048576) / descByteSize; // Total number of descs to use for training.

        cout << "Extracting VOC data..." << endl;
        vector<ObdImage> images;
        vector<char> objectPresent;
        vocData.getClassImages( trainParams.trainObjClass, CV_OBD_TRAIN, images, objectPresent );

        cout << "Computing descriptors..." << endl;
        RNG& rng = theRNG();
        TermCriteria terminate_criterion;
        terminate_criterion.epsilon = FLT_EPSILON;
        BOWKMeansTrainer bowTrainer( trainParams.vocabSize, terminate_criterion, 3, KMEANS_PP_CENTERS );

        while( images.size() > 0 )
        {
            // NOTE: 'descripotorsCount' (sic) is the actual spelling of this method
            // in the OpenCV 2.x BOWTrainer API.
            if( bowTrainer.descripotorsCount() >= maxDescCount )
            {
                // The adding loop below never lets the count exceed maxDescCount,
                // so equality must hold when we get here.
                assert( bowTrainer.descripotorsCount() == maxDescCount );
#ifdef DEBUG_DESC_PROGRESS
                cout << "Breaking due to full memory ( descriptors count = " << bowTrainer.descripotorsCount()
                        << "; descriptor size in bytes = " << descByteSize << "; all used memory = "
                        << bowTrainer.descripotorsCount()*descByteSize << endl;
#endif
                break;
            }

            // Randomly pick an image from the dataset which hasn't yet been seen
            // and compute the descriptors from that image.
            int randImgIdx = rng( images.size() );
            Mat colorImage = imread( images[randImgIdx].path );
            vector<KeyPoint> imageKeypoints;
            fdetector->detect( colorImage, imageKeypoints );
            Mat imageDescriptors;
            dextractor->compute( colorImage, imageKeypoints, imageDescriptors );

            //check that there were descriptors calculated for the current image
            if( !imageDescriptors.empty() )
            {
                int descCount = imageDescriptors.rows;
                // Extract trainParams.descProportion descriptors from the image, breaking if the 'allDescriptors' matrix becomes full
                int descsToExtract = static_cast<int>(trainParams.descProportion * static_cast<float>(descCount));
                // Fill mask of used descriptors: set the first descsToExtract entries,
                // then randomize the mask with descCount random transpositions so the
                // kept subset is a random selection of rows.
                vector<char> usedMask( descCount, false );
                fill( usedMask.begin(), usedMask.begin() + descsToExtract, true );
                for( int i = 0; i < descCount; i++ )
                {
                    int i1 = rng(descCount), i2 = rng(descCount);
                    char tmp = usedMask[i1]; usedMask[i1] = usedMask[i2]; usedMask[i2] = tmp;
                }

                // Add the selected descriptor rows, stopping at the memory budget.
                for( int i = 0; i < descCount; i++ )
                {
                    if( usedMask[i] && bowTrainer.descripotorsCount() < maxDescCount )
                        bowTrainer.add( imageDescriptors.row(i) );
                }
            }

#ifdef DEBUG_DESC_PROGRESS
            cout << images.size() << " images left, " << images[randImgIdx].id << " processed - "
                    <</* descs_extracted << "/" << image_descriptors.rows << " extracted - " << */
                    cvRound((static_cast<double>(bowTrainer.descripotorsCount())/static_cast<double>(maxDescCount))*100.0)
                    << " % memory used" << ( imageDescriptors.empty() ? " -> no descriptors extracted, skipping" : "") << endl;
#endif

            // Delete the current element from images so it is not added again
            images.erase( images.begin() + randImgIdx );
        }

        cout << "Maximum allowed descriptor count: " << maxDescCount << ", Actual descriptor count: " << bowTrainer.descripotorsCount() << endl;

        cout << "Training vocabulary..." << endl;
        vocabulary = bowTrainer.cluster();

        // Failing to save is fatal: every later stage expects the cache file to exist.
        if( !writeVocabulary(filename, vocabulary) )
        {
            cout << "Error: file " << filename << " can not be opened to write" << endl;
            exit(-1);
        }
    }
    return vocabulary;
}
2178
2179bool readBowImageDescriptor( const string& file, Mat& bowImageDescriptor )
2180{
2181    FileStorage fs( file, FileStorage::READ );
2182    if( fs.isOpened() )
2183    {
2184        fs["imageDescriptor"] >> bowImageDescriptor;
2185        return true;
2186    }
2187    return false;
2188}
2189
2190bool writeBowImageDescriptor( const string& file, const Mat& bowImageDescriptor )
2191{
2192    FileStorage fs( file, FileStorage::WRITE );
2193    if( fs.isOpened() )
2194    {
2195        fs << "imageDescriptor" << bowImageDescriptor;
2196        return true;
2197    }
2198    return false;
2199}
2200
2201// Load in the bag of words vectors for a set of images, from file if possible
2202void calculateImageDescriptors( const vector<ObdImage>& images, vector<Mat>& imageDescriptors,
2203                                const Ptr<BOWImgDescriptorExtractor>& bowExtractor, const Ptr<FeatureDetector>& fdetector,
2204                                const string& resPath )
2205{
2206    CV_Assert( !bowExtractor->getVocabulary().empty() );
2207    imageDescriptors.resize( images.size() );
2208
2209    for( size_t i = 0; i < images.size(); i++ )
2210    {
2211        string filename = resPath + bowImageDescriptorsDir + "/" + images[i].id + ".xml.gz";
2212        if( readBowImageDescriptor( filename, imageDescriptors[i] ) )
2213        {
2214#ifdef DEBUG_DESC_PROGRESS
2215            cout << "Loaded bag of word vector for image " << i+1 << " of " << images.size() << " (" << images[i].id << ")" << endl;
2216#endif
2217        }
2218        else
2219        {
2220            Mat colorImage = imread( images[i].path );
2221#ifdef DEBUG_DESC_PROGRESS
2222            cout << "Computing descriptors for image " << i+1 << " of " << images.size() << " (" << images[i].id << ")" << flush;
2223#endif
2224            vector<KeyPoint> keypoints;
2225            fdetector->detect( colorImage, keypoints );
2226#ifdef DEBUG_DESC_PROGRESS
2227                cout << " + generating BoW vector" << std::flush;
2228#endif
2229            bowExtractor->compute( colorImage, keypoints, imageDescriptors[i] );
2230#ifdef DEBUG_DESC_PROGRESS
2231            cout << " ...DONE " << static_cast<int>(static_cast<float>(i+1)/static_cast<float>(images.size())*100.0)
2232                 << " % complete" << endl;
2233#endif
2234            if( !imageDescriptors[i].empty() )
2235            {
2236                if( !writeBowImageDescriptor( filename, imageDescriptors[i] ) )
2237                {
2238                    cout << "Error: file " << filename << "can not be opened to write bow image descriptor" << endl;
2239                    exit(-1);
2240                }
2241            }
2242        }
2243    }
2244}
2245
2246void removeEmptyBowImageDescriptors( vector<ObdImage>& images, vector<Mat>& bowImageDescriptors,
2247                                     vector<char>& objectPresent )
2248{
2249    CV_Assert( !images.empty() );
2250    for( int i = (int)images.size() - 1; i >= 0; i-- )
2251    {
2252        bool res = bowImageDescriptors[i].empty();
2253        if( res )
2254        {
2255            cout << "Removing image " << images[i].id << " due to no descriptors..." << endl;
2256            images.erase( images.begin() + i );
2257            bowImageDescriptors.erase( bowImageDescriptors.begin() + i );
2258            objectPresent.erase( objectPresent.begin() + i );
2259        }
2260    }
2261}
2262
2263void removeBowImageDescriptorsByCount( vector<ObdImage>& images, vector<Mat> bowImageDescriptors, vector<char> objectPresent,
2264                                       const SVMTrainParamsExt& svmParamsExt, int descsToDelete )
2265{
2266    RNG& rng = theRNG();
2267    int pos_ex = std::count( objectPresent.begin(), objectPresent.end(), true );
2268    int neg_ex = std::count( objectPresent.begin(), objectPresent.end(), false );
2269
2270    while( descsToDelete != 0 )
2271    {
2272        int randIdx = rng(images.size());
2273
2274        // Prefer positive training examples according to svmParamsExt.targetRatio if required
2275        if( objectPresent[randIdx] )
2276        {
2277            if( (static_cast<float>(pos_ex)/static_cast<float>(neg_ex+pos_ex)  < svmParamsExt.targetRatio) &&
2278                (neg_ex > 0) && (svmParamsExt.balanceClasses == false) )
2279            { continue; }
2280            else
2281            { pos_ex--; }
2282        }
2283        else
2284        { neg_ex--; }
2285
2286        images.erase( images.begin() + randIdx );
2287        bowImageDescriptors.erase( bowImageDescriptors.begin() + randIdx );
2288        objectPresent.erase( objectPresent.begin() + randIdx );
2289
2290        descsToDelete--;
2291    }
2292    CV_Assert( bowImageDescriptors.size() == objectPresent.size() );
2293}
2294
2295void setSVMParams( CvSVMParams& svmParams, CvMat& class_wts_cv, const Mat& responses, bool balanceClasses )
2296{
2297    int pos_ex = countNonZero(responses == 1);
2298    int neg_ex = countNonZero(responses == -1);
2299    cout << pos_ex << " positive training samples; " << neg_ex << " negative training samples" << endl;
2300
2301    svmParams.svm_type = CvSVM::C_SVC;
2302    svmParams.kernel_type = CvSVM::RBF;
2303    if( balanceClasses )
2304    {
2305        Mat class_wts( 2, 1, CV_32FC1 );
2306        // The first training sample determines the '+1' class internally, even if it is negative,
2307        // so store whether this is the case so that the class weights can be reversed accordingly.
2308        bool reversed_classes = (responses.at<float>(0) < 0.f);
2309        if( reversed_classes == false )
2310        {
2311            class_wts.at<float>(0) = static_cast<float>(pos_ex)/static_cast<float>(pos_ex+neg_ex); // weighting for costs of positive class + 1 (i.e. cost of false positive - larger gives greater cost)
2312            class_wts.at<float>(1) = static_cast<float>(neg_ex)/static_cast<float>(pos_ex+neg_ex); // weighting for costs of negative class - 1 (i.e. cost of false negative)
2313        }
2314        else
2315        {
2316            class_wts.at<float>(0) = static_cast<float>(neg_ex)/static_cast<float>(pos_ex+neg_ex);
2317            class_wts.at<float>(1) = static_cast<float>(pos_ex)/static_cast<float>(pos_ex+neg_ex);
2318        }
2319        class_wts_cv = class_wts;
2320        svmParams.class_weights = &class_wts_cv;
2321    }
2322}
2323
2324void setSVMTrainAutoParams( CvParamGrid& c_grid, CvParamGrid& gamma_grid,
2325                            CvParamGrid& p_grid, CvParamGrid& nu_grid,
2326                            CvParamGrid& coef_grid, CvParamGrid& degree_grid )
2327{
2328    c_grid = CvSVM::get_default_grid(CvSVM::C);
2329
2330    gamma_grid = CvSVM::get_default_grid(CvSVM::GAMMA);
2331
2332    p_grid = CvSVM::get_default_grid(CvSVM::P);
2333    p_grid.step = 0;
2334
2335    nu_grid = CvSVM::get_default_grid(CvSVM::NU);
2336    nu_grid.step = 0;
2337
2338    coef_grid = CvSVM::get_default_grid(CvSVM::COEF);
2339    coef_grid.step = 0;
2340
2341    degree_grid = CvSVM::get_default_grid(CvSVM::DEGREE);
2342    degree_grid.step = 0;
2343}
2344
// Train (or load from cache) a binary one-vs-rest SVM for objClassName on the VOC
// training set, using bag-of-words vectors as features. The trained model is saved
// under <resPath>/svms/<class>.xml.gz and reused on subsequent runs.
void trainSVMClassifier( CvSVM& svm, const SVMTrainParamsExt& svmParamsExt, const string& objClassName, VocData& vocData,
                         const Ptr<BOWImgDescriptorExtractor>& bowExtractor, const Ptr<FeatureDetector>& fdetector,
                         const string& resPath )
{
    /* first check if a previously trained svm for the current class has been saved to file */
    string svmFilename = resPath + svmsDir + "/" + objClassName + ".xml.gz";

    FileStorage fs( svmFilename, FileStorage::READ);
    if( fs.isOpened() )
    {
        cout << "*** LOADING SVM CLASSIFIER FOR CLASS " << objClassName << " ***" << endl;
        svm.load( svmFilename.c_str() );
    }
    else
    {
        cout << "*** TRAINING CLASSIFIER FOR CLASS " << objClassName << " ***" << endl;
        cout << "CALCULATING BOW VECTORS FOR TRAINING SET OF " << objClassName << "..." << endl;

        // Get classification ground truth for images in the training set
        vector<ObdImage> images;
        vector<Mat> bowImageDescriptors;
        vector<char> objectPresent;
        vocData.getClassImages( objClassName, CV_OBD_TRAIN, images, objectPresent );

        // Compute the bag of words vector for each image in the training set.
        calculateImageDescriptors( images, bowImageDescriptors, bowExtractor, fdetector, resPath );

        // Remove any images for which descriptors could not be calculated
        removeEmptyBowImageDescriptors( images, bowImageDescriptors, objectPresent );

        // Optionally subsample the training set down to descPercent of the images.
        CV_Assert( svmParamsExt.descPercent > 0.f && svmParamsExt.descPercent <= 1.f );
        if( svmParamsExt.descPercent < 1.f )
        {
            int descsToDelete = static_cast<int>(static_cast<float>(images.size())*(1.0-svmParamsExt.descPercent));

            cout << "Using " << (images.size() - descsToDelete) << " of " << images.size() <<
                    " descriptors for training (" << svmParamsExt.descPercent*100.0 << " %)" << endl;
            removeBowImageDescriptorsByCount( images, bowImageDescriptors, objectPresent, svmParamsExt, descsToDelete );
        }

        // Prepare the input matrices for SVM training: one BoW vector per row,
        // responses stored as CV_32SC1 labels (+1 present / -1 absent).
        Mat trainData( images.size(), bowExtractor->getVocabulary().rows, CV_32FC1 );
        Mat responses( images.size(), 1, CV_32SC1 );

        // Transfer bag of words vectors and responses across to the training data matrices
        for( size_t imageIdx = 0; imageIdx < images.size(); imageIdx++ )
        {
            // Transfer image descriptor (bag of words vector) to training data matrix
            Mat submat = trainData.row(imageIdx);
            if( bowImageDescriptors[imageIdx].cols != bowExtractor->descriptorSize() )
            {
                cout << "Error: computed bow image descriptor size " << bowImageDescriptors[imageIdx].cols
                     << " differs from vocabulary size" << bowExtractor->getVocabulary().cols << endl;
                exit(-1);
            }
            bowImageDescriptors[imageIdx].copyTo( submat );

            // Set response value
            responses.at<int>(imageIdx) = objectPresent[imageIdx] ? 1 : -1;
        }

        // Train with automatic hyper-parameter search (10-fold cross-validation);
        // only the C and gamma grids are actually searched (see setSVMTrainAutoParams).
        cout << "TRAINING SVM FOR CLASS ..." << objClassName << "..." << endl;
        CvSVMParams svmParams;
        CvMat class_wts_cv;
        setSVMParams( svmParams, class_wts_cv, responses, svmParamsExt.balanceClasses );
        CvParamGrid c_grid, gamma_grid, p_grid, nu_grid, coef_grid, degree_grid;
        setSVMTrainAutoParams( c_grid, gamma_grid,  p_grid, nu_grid, coef_grid, degree_grid );
        svm.train_auto( trainData, responses, Mat(), Mat(), svmParams, 10, c_grid, gamma_grid, p_grid, nu_grid, coef_grid, degree_grid );
        cout << "SVM TRAINING FOR CLASS " << objClassName << " COMPLETED" << endl;

        // Cache the trained model so the next run can skip training.
        svm.save( svmFilename.c_str() );
        cout << "SAVED CLASSIFIER TO FILE" << endl;
    }
}
2419
2420void computeConfidences( CvSVM& svm, const string& objClassName, VocData& vocData,
2421                         const Ptr<BOWImgDescriptorExtractor>& bowExtractor, const Ptr<FeatureDetector>& fdetector,
2422                         const string& resPath )
2423{
2424    cout << "*** CALCULATING CONFIDENCES FOR CLASS " << objClassName << " ***" << endl;
2425    cout << "CALCULATING BOW VECTORS FOR TEST SET OF " << objClassName << "..." << endl;
2426    // Get classification ground truth for images in the test set
2427    vector<ObdImage> images;
2428    vector<Mat> bowImageDescriptors;
2429    vector<char> objectPresent;
2430    vocData.getClassImages( objClassName, CV_OBD_TEST, images, objectPresent );
2431
2432    // Compute the bag of words vector for each image in the test set
2433    calculateImageDescriptors( images, bowImageDescriptors, bowExtractor, fdetector, resPath );
2434    // Remove any images for which descriptors could not be calculated
2435    removeEmptyBowImageDescriptors( images, bowImageDescriptors, objectPresent);
2436
2437    // Use the bag of words vectors to calculate classifier output for each image in test set
2438    cout << "CALCULATING CONFIDENCE SCORES FOR CLASS " << objClassName << "..." << endl;
2439    vector<float> confidences( images.size() );
2440    float signMul;
2441    for( size_t imageIdx = 0; imageIdx < images.size(); imageIdx++ )
2442    {
2443        if( imageIdx == 0 )
2444        {
2445            // In the first iteration, determine the sign of the positive class
2446            float classVal = confidences[imageIdx] = svm.predict( bowImageDescriptors[imageIdx], false );
2447            float scoreVal = confidences[imageIdx] = svm.predict( bowImageDescriptors[imageIdx], true );
2448            signMul = (classVal < 0) == (scoreVal < 0) ? 1.f : -1.f;
2449        }
2450        // svm output of decision function
2451        confidences[imageIdx] = signMul * svm.predict( bowImageDescriptors[imageIdx], true );
2452    }
2453
2454    cout << "WRITING QUERY RESULTS TO VOC RESULTS FILE FOR CLASS " << objClassName << "..." << endl;
2455    vocData.writeClassifierResultsFile( resPath + plotsDir, objClassName, CV_OBD_TEST, images, confidences, 1, true );
2456
2457    cout << "DONE - " << objClassName << endl;
2458    cout << "---------------------------------------------------------------" << endl;
2459}
2460
2461void computeGnuPlotOutput( const string& resPath, const string& objClassName, VocData& vocData )
2462{
2463    vector<float> precision, recall;
2464    float ap;
2465
2466    const string resultFile = vocData.getResultsFilename( objClassName, CV_VOC_TASK_CLASSIFICATION, CV_OBD_TEST);
2467    const string plotFile = resultFile.substr(0, resultFile.size()-4) + ".plt";
2468
2469    cout << "Calculating precision recall curve for class '" <<objClassName << "'" << endl;
2470    vocData.calcClassifierPrecRecall( resPath + plotsDir + "/" + resultFile, precision, recall, ap, true );
2471    cout << "Outputting to GNUPlot file..." << endl;
2472    vocData.savePrecRecallToGnuplot( resPath + plotsDir + "/" + plotFile, precision, recall, ap, objClassName, CV_VOC_PLOT_PNG );
2473}
2474
2475/* Input parameters
2476 * [VOC path]             Path to Pascal VOC data (e.g. /home/my/VOCdevkit/VOC2010). Note: VOC2007-VOC2010 are supported.
2477 * [result directory]     Path to result diractory. Following folders will be created in [result directory]:
2478 *                          bowImageDescriptors - to store image descriptors,
2479 *                          svms - to store trained svms,
2480 *                          plots - to store files for plots creating.
2481 * [feature detector]     Feature detector name (e.g. SURF, FAST...) - see createFeatureDetector() function.
2482 * [descriptor extractor] Descriptor extractor name (e.g. SURF, SIFT) - see createDescriptorExtractor() function.
2483 * [descriptor matcher]   Descriptor matcher name (e.g. BruteForce) - see createDescriptorMatcher() function.
2484 */
// Entry point: set up (or restore) run parameters, train/load the visual
// vocabulary, then for every VOC object class train an SVM, score the test
// set, and produce precision-recall plots. See the usage comment above.
int main(int argc, char** argv)
{
    if( argc != 3 && argc != 6 )
    {
        cout << "Format: " << endl <<
                "   ./" << argv[0] << " [VOC path] [result directory] " << endl <<
                "       or" << endl <<
                "   ./" << argv[0] << " [VOC path] [result directory] [feature detector] [descriptor extractor] [descriptor matcher]" << endl;
        return -1;
    }

    const string vocPath = argv[1], resPath = argv[2];

    // Read or set default parameters
    string vocName;
    DDMParams ddmParams;
    VocabTrainParams vocabTrainParams;
    SVMTrainParamsExt svmTrainParamsExt;

    // Create the result subdirectories (bowImageDescriptors, svms, plots).
    makeUsedDirs( resPath );

    // If a params.xml exists from a previous run, reuse its configuration; the
    // VOC dataset must be the same one the cached results were produced from.
    FileStorage paramsFS( resPath + "/" + paramsFile, FileStorage::READ );
    if( paramsFS.isOpened() )
    {
       readUsedParams( paramsFS.root(), vocName, ddmParams, vocabTrainParams, svmTrainParamsExt );
       CV_Assert( vocName == getVocName(vocPath) );
    }
    else
    {
        // First run: detector/extractor/matcher names must come from the command line.
        vocName = getVocName(vocPath);
        if( argc!= 6 )
        {
            cout << "Feature detector, descriptor extractor, descriptor matcher must be set" << endl;
            return -1;
        }
        ddmParams = DDMParams( argv[3], argv[4], argv[5] ); // from command line
        // vocabTrainParams and svmTrainParamsExt is set by defaults
        paramsFS.open( resPath + "/" + paramsFile, FileStorage::WRITE );
        if( paramsFS.isOpened() )
        {
            writeUsedParams( paramsFS, vocName, ddmParams, vocabTrainParams, svmTrainParamsExt );
            paramsFS.release();
        }
        else
        {
            cout << "File " << (resPath + "/" + paramsFile) << "can not be opened to write" << endl;
            return -1;
        }
    }

    // Create detector, descriptor, matcher.
    Ptr<FeatureDetector> featureDetector = createFeatureDetector( ddmParams.detectorType );
    Ptr<DescriptorExtractor> descExtractor = createDescriptorExtractor( ddmParams.descriptorType );
    Ptr<BOWImgDescriptorExtractor> bowExtractor;
    if( featureDetector.empty() || descExtractor.empty() )
    {
        cout << "featureDetector or descExtractor was not created" << endl;
        return -1;
    }
    {
        // Scoped so the matcher Ptr is released once handed to the BoW extractor.
        Ptr<DescriptorMatcher> descMatcher = createDescriptorMatcher( ddmParams.matcherType );
        if( featureDetector.empty() || descExtractor.empty() || descMatcher.empty() )
        {
            cout << "descMatcher was not created" << endl;
            return -1;
        }
        bowExtractor = new BOWImgDescriptorExtractor( descExtractor, descMatcher );
    }

    // Print configuration to screen
    printUsedParams( vocPath, resPath, ddmParams, vocabTrainParams, svmTrainParamsExt );
    // Create object to work with VOC
    VocData vocData( vocPath, false );

    // 1. Train visual word vocabulary if a pre-calculated vocabulary file doesn't already exist from previous run
    Mat vocabulary = trainVocabulary( resPath + "/" + vocabularyFile, vocData, vocabTrainParams,
                                      featureDetector, descExtractor );
    bowExtractor->setVocabulary( vocabulary );

    // 2. Train a classifier and run a sample query for each object class
    const vector<string>& objClasses = vocData.getObjectClasses(); // object class list
    for( size_t classIdx = 0; classIdx < objClasses.size(); ++classIdx )
    {
        // Train a classifier on train dataset
        CvSVM svm;
        trainSVMClassifier( svm, svmTrainParamsExt, objClasses[classIdx], vocData,
                            bowExtractor, featureDetector, resPath );

        // Now use the classifier over all images on the test dataset and rank according to score order
        // also calculating precision-recall etc.
        computeConfidences( svm, objClasses[classIdx], vocData,
                            bowExtractor, featureDetector, resPath );
        // Calculate precision/recall/ap and use GNUPlot to output to a pdf file
        computeGnuPlotOutput( resPath, objClasses[classIdx], vocData );
    }
    return 0;
}
// Note: See TracBrowser for help on using the repository browser.