COCO 数据集十分庞大,最近打算从中截取一个只包含 person、bicycle、bus 三个类别、图片总数约 1000 张的小型数据集,在此记录截取过程。
step1、给标注 json 瘦身:只保留这三个类别的标注;segmentation 字段用不到,可以去掉;为了后续能对应到图片,给每条标注增加 file_name 字段存储图片路径(同时增加 md5 字段存储图片文件的 MD5)。
import json
import hashlib
from tqdm import tqdm
import time
import os
def filter_annotations(input_path, output_path_1, output_path_2, category_ids,
                       num_images_1, num_images_2,
                       image_dir='/data/det_coco2017/train2017/'):
    """Write a slimmed-down COCO annotation file restricted to ``category_ids``.

    Every annotation whose ``category_id`` is in ``category_ids`` is kept,
    its ``segmentation`` field is dropped, and two fields are added:
    ``file_name`` (``image_dir`` joined with the image's file name) and
    ``md5`` (MD5 hex digest of that image file, computed by ``get_md5``).
    The result, a dict with ``annotations`` and ``categories`` keys, is
    dumped as JSON to ``output_path_1``. Per-category annotation counts are
    printed before writing.

    Args:
        input_path: Path of the COCO-style annotation JSON to read. It must
            contain ``images`` and ``annotations`` lists.
        output_path_1: Path of the JSON file to write.
        output_path_2: Unused; kept so existing callers keep working.
        category_ids: Iterable of category ids to keep.
        num_images_1: Unused; kept so existing callers keep working.
        num_images_2: Unused; kept so existing callers keep working.
        image_dir: Directory holding the image files referenced by the
            annotations.

    Raises:
        KeyError: If an annotation refers to an image id that is missing
            from ``images``.
        OSError: If the input, the output, or an image file cannot be opened.
    """
    wanted = set(category_ids)

    st = time.time()
    with open(input_path, 'r', encoding='utf-8') as f:
        data = json.load(f)
    print('read cost: {}'.format(time.time() - st))

    # image id -> bare file name, used to attach a path to each annotation.
    id2name = {img['id']: img['file_name'] for img in tqdm(data['images'])}

    # Take the category entries from the input instead of hard-coding them,
    # so the output always matches the ids that were actually requested.
    categories = [
        {"id": cat["id"], "name": cat["name"]}
        for cat in data.get('categories', [])
        if cat["id"] in wanted
    ]

    # One counter per requested id, so any id in category_ids can be counted.
    cate_counts = {cid: 0 for cid in category_ids}
    kept = []
    for ann in tqdm(data['annotations']):
        cid = ann['category_id']
        if cid not in wanted:
            continue
        cate_counts[cid] += 1
        ann.pop('segmentation', None)
        name = id2name[ann['image_id']]
        ann['file_name'] = os.path.join(image_dir, name)
        # Pass the bare file name: get_md5 joins it with the folder itself,
        # so a relative image_dir is not applied twice.
        ann['md5'] = get_md5(image_dir, name)
        kept.append(ann)
    print(cate_counts)

    with open(output_path_1, 'w') as f:
        json.dump({'annotations': kept, 'categories': categories}, f)
def get_md5(folder_path, image_file):
    """Return the MD5 hex digest of the file ``image_file`` under ``folder_path``.

    The two arguments are combined with ``os.path.join`` and the file is
    read in binary mode in a single call.
    """
    full_path = os.path.join(folder_path, image_file)
    digest = hashlib.md5()
    with open(full_path, 'rb') as stream:
        digest.update(stream.read())
    return digest.hexdigest()
if __name__ == '__main__':
    # Step-1 driver: slim the full COCO train2017 annotations down to the
    # person (1), bicycle (2) and bus (6) categories.
    filter_annotations(
        input_path='/data/det_coco2017/annotations/instances_train2017.json',
        output_path_1='train.json',
        output_path_2='test.json',
        category_ids=[1, 2, 6],
        num_images_1=1000,
        num_images_2=200,
    )
step2、通过以下命令查看各类别的标注数目(以 category_id 为 6 的 bus 为例):
# Count the annotations whose category_id is 6 (bus) in the step-1 output (train.json).
jq '[.annotations[] | select(.category_id == 6)] | length' train.json
step3、选取固定数目的图片(这里是 200 张含 bicycle 或 bus 的图片),保留这些图片的全部标注,重新生成 json
import json
import hashlib
from tqdm import tqdm
import time
import os
def filter_annotations(input_path, output_path_1, output_path_2, category_ids, num_images_1, num_images_2):
    """Cut an annotation file down to a fixed number of images.

    Images are candidates when they carry at least one annotation of
    category 2 or 6. At most ``num_images_2`` candidate image ids are
    selected, and every annotation belonging to a selected image is kept,
    whatever its category. The loaded document, with its ``annotations``
    list replaced by the kept ones, is dumped as JSON to ``output_path_1``.
    The number of selected images is printed before writing.

    Args:
        input_path: Path of the annotation JSON to read. It must contain an
            ``annotations`` list whose items have ``image_id`` and
            ``category_id`` keys.
        output_path_1: Path of the JSON file to write.
        output_path_2: Unused; kept so existing callers keep working.
        category_ids: Unused; kept so existing callers keep working.
        num_images_1: Unused; kept so existing callers keep working.
        num_images_2: Maximum number of images to keep.
    """
    st = time.time()
    with open(input_path, 'r') as f:
        data = json.load(f)
    print('read cost: {}'.format(time.time() - st))

    # Images are seeded from these two categories only; annotations of other
    # categories are still kept for the images that end up selected.
    seed_categories = (2, 6)
    candidate_ids = {
        ann['image_id']
        for ann in tqdm(data['annotations'])
        if ann['category_id'] in seed_categories
    }

    # Set iteration order is arbitrary, so which ids survive the cut is not
    # tied to their order in the file. Keeping them in a set makes the
    # per-annotation membership test below O(1) instead of a list scan.
    selected = set(list(candidate_ids)[:num_images_2])

    data['annotations'] = [
        ann for ann in tqdm(data['annotations']) if ann['image_id'] in selected
    ]
    print(len(selected), ' imgs')

    with open(output_path_1, 'w') as f:
        json.dump(data, f)
def get_md5(folder_path, image_file):
    """Compute the MD5 hex digest of ``image_file`` located in ``folder_path``.

    The path is built with ``os.path.join`` and the whole file is read in
    binary mode before hashing.
    """
    target = os.path.join(folder_path, image_file)
    with open(target, 'rb') as handle:
        return hashlib.md5(handle.read()).hexdigest()
if __name__ == '__main__':
    # Step-3 driver: read the slimmed annotation file and write the
    # fixed-size subset. Both output-path arguments point at the same file.
    subset_json = 'test.json'
    filter_annotations(
        '/data/det_coco2017/train.json',
        subset_json,
        subset_json,
        [1, 2, 6],
        1000,
        200,
    )
step4、将 test.json 中用到的图片单独复制到 test/ 文件夹
# Copy every image referenced by test.json into test/.
# file_name repeats once per annotation, so de-duplicate before copying;
# create the target directory first so cp does not fail when it is missing.
mkdir -p test
jq -r '.annotations[].file_name' test.json | sort -u | xargs -I {} cp {} test/
step5、将图片文件夹和标注文件一起打包为 test.zip
# Bundle the image folder (test/, added recursively) and test.json into test.zip.
zip -r test.zip test/ test.json