Evaluation tool internal refactors
PiperOrigin-RevId: 277008167
Change-Id: Ia89dda649d330e8a68cc6f110d631bd4d949c930
diff --git a/tensorflow/lite/tools/evaluation/tasks/coco_object_detection/preprocess_coco_minival.py b/tensorflow/lite/tools/evaluation/tasks/coco_object_detection/preprocess_coco_minival.py
index 2d7efc3..9859385 100644
--- a/tensorflow/lite/tools/evaluation/tasks/coco_object_detection/preprocess_coco_minival.py
+++ b/tensorflow/lite/tools/evaluation/tasks/coco_object_detection/preprocess_coco_minival.py
@@ -32,6 +32,7 @@
import argparse
import ast
+import collections
import os
import shutil
import sys
@@ -39,7 +40,7 @@
def _get_ground_truth_detections(instances_file,
- whitelist_file,
+ whitelist_file=None,
num_images=None):
"""Processes the annotations JSON file and returns ground truth data corresponding to whitelisted image IDs.
@@ -62,19 +63,24 @@
'bbox' to a list of dimension-normalized [top, left, bottom, right]
bounding-box values.
"""
- # Read whitelist.
- with open(whitelist_file, 'r') as whitelist:
- image_id_whitelist = set([int(x) for x in whitelist.readlines()])
-
# Read JSON data into a dict.
with open(instances_file, 'r') as annotation_dump:
data_dict = ast.literal_eval(annotation_dump.readline())
- image_data = {}
+ image_data = collections.OrderedDict()
all_file_names = []
+
+ # Read whitelist.
+ if whitelist_file is not None:
+ with open(whitelist_file, 'r') as whitelist:
+ image_id_whitelist = set([int(x) for x in whitelist.readlines()])
+ else:
+ image_id_whitelist = [image['id'] for image in data_dict['images']]
+
# Get image names and dimensions.
for image_dict in data_dict['images']:
- if image_dict['id'] not in image_id_whitelist:
+ image_id = image_dict['id']
+ if image_id not in image_id_whitelist:
continue
image_data_dict = {}
image_data_dict['file_name'] = image_dict['file_name']
@@ -82,7 +88,7 @@
image_data_dict['height'] = image_dict['height']
image_data_dict['width'] = image_dict['width']
image_data_dict['detections'] = []
- image_data[image_dict['id']] = image_data_dict
+ image_data[image_id] = image_data_dict
if num_images:
all_file_names.sort()
@@ -92,7 +98,9 @@
# Get detected object annotations per image.
for annotation_dict in data_dict['annotations']:
image_id = annotation_dict['image_id']
- if image_id not in image_id_whitelist or image_id not in image_data:
+ if image_id not in image_id_whitelist:
+ continue
+ if image_id not in image_data:
continue
image_data_dict = image_data[image_id]
if image_data_dict['file_name'] not in all_file_names:
@@ -186,7 +194,7 @@
'--whitelist_file',
type=str,
help='File with COCO image ids to preprocess, one on each line.',
- required=True)
+ required=False)
parser.add_argument(
'--num_images',
type=int,