TensorFlow Hub (2)-Object Detection Colab
Imports and Setup
import os
import pathlib
import matplotlib
import matplotlib.pyplot as plt
import io
import scipy.misc
import numpy as np
from six import BytesIO
from PIL import Image, ImageDraw, ImageFont
from six.moves.urllib.request import urlopen
import tensorflow as tf
import tensorflow_hub as hub
tf.get_logger().setLevel('ERROR')
Utility functions
Run the following cell to create some utilities that will be needed later:
- load_image_into_numpy_array: a helper method that loads an image into a numpy array
- ALL_MODELS: a map of model names to TF Hub handles
- IMAGES_FOR_TEST: a dictionary of test image names and their locations
- COCO17_HUMAN_POSE_KEYPOINTS: a list of keypoint-edge tuples for the COCO 2017 dataset (needed for models that predict keypoints)
# @title Run this!!
def load_image_into_numpy_array(path):
  """Load an image from file into a numpy array.

  Puts image into numpy array to feed into tensorflow graph.
  Note that by convention we put it into a numpy array with shape
  (1, height, width, channels), where channels=3 for RGB.

  Args:
    path: the file path to the image

  Returns:
    uint8 numpy array with shape (1, img_height, img_width, 3)
  """
  image = None
  if path.startswith('http'):
    response = urlopen(path)
    image_data = response.read()
    image_data = BytesIO(image_data)
    image = Image.open(image_data)
  else:
    image_data = tf.io.gfile.GFile(path, 'rb').read()
    image = Image.open(BytesIO(image_data))

  (im_width, im_height) = image.size
  return np.array(image.getdata()).reshape(
      (1, im_height, im_width, 3)).astype(np.uint8)
ALL_MODELS = {
'CenterNet HourGlass104 512x512' : 'https://hub.tensorflow.google.cn/tensorflow/centernet/hourglass_512x512/1',
'CenterNet HourGlass104 Keypoints 512x512' : 'https://hub.tensorflow.google.cn/tensorflow/centernet/hourglass_512x512_kpts/1',
'CenterNet HourGlass104 1024x1024' : 'https://hub.tensorflow.google.cn/tensorflow/centernet/hourglass_1024x1024/1',
'CenterNet HourGlass104 Keypoints 1024x1024' : 'https://hub.tensorflow.google.cn/tensorflow/centernet/hourglass_1024x1024_kpts/1',
'CenterNet Resnet50 V1 FPN 512x512' : 'https://hub.tensorflow.google.cn/tensorflow/centernet/resnet50v1_fpn_512x512/1',
'CenterNet Resnet50 V1 FPN Keypoints 512x512' : 'https://hub.tensorflow.google.cn/tensorflow/centernet/resnet50v1_fpn_512x512_kpts/1',
'CenterNet Resnet101 V1 FPN 512x512' : 'https://hub.tensorflow.google.cn/tensorflow/centernet/resnet101v1_fpn_512x512/1',
'CenterNet Resnet50 V2 512x512' : 'https://hub.tensorflow.google.cn/tensorflow/centernet/resnet50v2_512x512/1',
'CenterNet Resnet50 V2 Keypoints 512x512' : 'https://hub.tensorflow.google.cn/tensorflow/centernet/resnet50v2_512x512_kpts/1',
'EfficientDet D0 512x512' : 'https://hub.tensorflow.google.cn/tensorflow/efficientdet/d0/1',
'EfficientDet D1 640x640' : 'https://hub.tensorflow.google.cn/tensorflow/efficientdet/d1/1',
'EfficientDet D2 768x768' : 'https://hub.tensorflow.google.cn/tensorflow/efficientdet/d2/1',
'EfficientDet D3 896x896' : 'https://hub.tensorflow.google.cn/tensorflow/efficientdet/d3/1',
'EfficientDet D4 1024x1024' : 'https://hub.tensorflow.google.cn/tensorflow/efficientdet/d4/1',
'EfficientDet D5 1280x1280' : 'https://hub.tensorflow.google.cn/tensorflow/efficientdet/d5/1',
'EfficientDet D6 1280x1280' : 'https://hub.tensorflow.google.cn/tensorflow/efficientdet/d6/1',
'EfficientDet D7 1536x1536' : 'https://hub.tensorflow.google.cn/tensorflow/efficientdet/d7/1',
'SSD MobileNet v2 320x320' : 'https://hub.tensorflow.google.cn/tensorflow/ssd_mobilenet_v2/2',
'SSD MobileNet V1 FPN 640x640' : 'https://hub.tensorflow.google.cn/tensorflow/ssd_mobilenet_v1/fpn_640x640/1',
'SSD MobileNet V2 FPNLite 320x320' : 'https://hub.tensorflow.google.cn/tensorflow/ssd_mobilenet_v2/fpnlite_320x320/1',
'SSD MobileNet V2 FPNLite 640x640' : 'https://hub.tensorflow.google.cn/tensorflow/ssd_mobilenet_v2/fpnlite_640x640/1',
'SSD ResNet50 V1 FPN 640x640 (RetinaNet50)' : 'https://hub.tensorflow.google.cn/tensorflow/retinanet/resnet50_v1_fpn_640x640/1',
'SSD ResNet50 V1 FPN 1024x1024 (RetinaNet50)' : 'https://hub.tensorflow.google.cn/tensorflow/retinanet/resnet50_v1_fpn_1024x1024/1',
'SSD ResNet101 V1 FPN 640x640 (RetinaNet101)' : 'https://hub.tensorflow.google.cn/tensorflow/retinanet/resnet101_v1_fpn_640x640/1',
'SSD ResNet101 V1 FPN 1024x1024 (RetinaNet101)' : 'https://hub.tensorflow.google.cn/tensorflow/retinanet/resnet101_v1_fpn_1024x1024/1',
'SSD ResNet152 V1 FPN 640x640 (RetinaNet152)' : 'https://hub.tensorflow.google.cn/tensorflow/retinanet/resnet152_v1_fpn_640x640/1',
'SSD ResNet152 V1 FPN 1024x1024 (RetinaNet152)' : 'https://hub.tensorflow.google.cn/tensorflow/retinanet/resnet152_v1_fpn_1024x1024/1',
'Faster R-CNN ResNet50 V1 640x640' : 'https://hub.tensorflow.google.cn/tensorflow/faster_rcnn/resnet50_v1_640x640/1',
'Faster R-CNN ResNet50 V1 1024x1024' : 'https://hub.tensorflow.google.cn/tensorflow/faster_rcnn/resnet50_v1_1024x1024/1',
'Faster R-CNN ResNet50 V1 800x1333' : 'https://hub.tensorflow.google.cn/tensorflow/faster_rcnn/resnet50_v1_800x1333/1',
'Faster R-CNN ResNet101 V1 640x640' : 'https://hub.tensorflow.google.cn/tensorflow/faster_rcnn/resnet101_v1_640x640/1',
'Faster R-CNN ResNet101 V1 1024x1024' : 'https://hub.tensorflow.google.cn/tensorflow/faster_rcnn/resnet101_v1_1024x1024/1',
'Faster R-CNN ResNet101 V1 800x1333' : 'https://hub.tensorflow.google.cn/tensorflow/faster_rcnn/resnet101_v1_800x1333/1',
'Faster R-CNN ResNet152 V1 640x640' : 'https://hub.tensorflow.google.cn/tensorflow/faster_rcnn/resnet152_v1_640x640/1',
'Faster R-CNN ResNet152 V1 1024x1024' : 'https://hub.tensorflow.google.cn/tensorflow/faster_rcnn/resnet152_v1_1024x1024/1',
'Faster R-CNN ResNet152 V1 800x1333' : 'https://hub.tensorflow.google.cn/tensorflow/faster_rcnn/resnet152_v1_800x1333/1',
'Faster R-CNN Inception ResNet V2 640x640' : 'https://hub.tensorflow.google.cn/tensorflow/faster_rcnn/inception_resnet_v2_640x640/1',
'Faster R-CNN Inception ResNet V2 1024x1024' : 'https://hub.tensorflow.google.cn/tensorflow/faster_rcnn/inception_resnet_v2_1024x1024/1',
'Mask R-CNN Inception ResNet V2 1024x1024' : 'https://hub.tensorflow.google.cn/tensorflow/mask_rcnn/inception_resnet_v2_1024x1024/1'
}
IMAGES_FOR_TEST = {
'Beach' : 'models/research/object_detection/test_images/image2.jpg',
'Dogs' : 'models/research/object_detection/test_images/image1.jpg',
# By Heiko Gorski, Source: https://commons.wikimedia.org/wiki/File:Naxos_Taverna.jpg
'Naxos Taverna' : 'https://upload.wikimedia.org/wikipedia/commons/6/60/Naxos_Taverna.jpg',
# Source: https://commons.wikimedia.org/wiki/File:The_Coleoptera_of_the_British_islands_(Plate_125)_(8592917784).jpg
'Beatles' : 'https://upload.wikimedia.org/wikipedia/commons/1/1b/The_Coleoptera_of_the_British_islands_%28Plate_125%29_%288592917784%29.jpg',
# By Américo Toledano, Source: https://commons.wikimedia.org/wiki/File:Biblioteca_Maim%C3%B3nides,_Campus_Universitario_de_Rabanales_007.jpg
'Phones' : 'https://upload.wikimedia.org/wikipedia/commons/thumb/0/0d/Biblioteca_Maim%C3%B3nides%2C_Campus_Universitario_de_Rabanales_007.jpg/1024px-Biblioteca_Maim%C3%B3nides%2C_Campus_Universitario_de_Rabanales_007.jpg',
# Source: https://commons.wikimedia.org/wiki/File:The_smaller_British_birds_(8053836633).jpg
'Birds' : 'https://upload.wikimedia.org/wikipedia/commons/0/09/The_smaller_British_birds_%288053836633%29.jpg',
}
COCO17_HUMAN_POSE_KEYPOINTS = [(0, 1),
(0, 2),
(1, 3),
(2, 4),
(0, 5),
(0, 6),
(5, 7),
(7, 9),
(6, 8),
(8, 10),
(5, 6),
(5, 11),
(6, 12),
(11, 12),
(11, 13),
(13, 15),
(12, 14),
(14, 16)]
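For reference (this mapping is not part of the original notebook, but it follows the standard COCO 2017 keypoint ordering), the indices in the tuples above refer to the 17 COCO keypoints, so an edge like (5, 7) connects the left shoulder to the left elbow:
# Standard COCO 2017 keypoint order (assumed; indices match the edge tuples above)
COCO17_KEYPOINT_NAMES = [
    'nose', 'left_eye', 'right_eye', 'left_ear', 'right_ear',
    'left_shoulder', 'right_shoulder', 'left_elbow', 'right_elbow',
    'left_wrist', 'right_wrist', 'left_hip', 'right_hip',
    'left_knee', 'right_knee', 'left_ankle', 'right_ankle']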
Visualization tools
To visualize the images with the detected boxes, keypoints and segmentations, we will use the TensorFlow Object Detection API. To install it, we will clone the repository.
# Clone the tensorflow models repository
git clone --depth 1 https://github.com/tensorflow/models
Cloning into 'models'...
remote: Enumerating objects: 2404, done.
remote: Counting objects: 100% (2404/2404), done.
remote: Compressing objects: 100% (2001/2001), done.
remote: Total 2404 (delta 569), reused 1415 (delta 376), pack-reused 0
Receiving objects: 100% (2404/2404), 30.77 MiB | 11.60 MiB/s, done.
Resolving deltas: 100% (569/569), done.
Install the Object Detection API
sudo apt install -y protobuf-compiler
cd models/research/
protoc object_detection/protos/*.proto --python_out=.
cp object_detection/packages/tf2/setup.py .
python -m pip install -q .
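Optionally (this step is not part of the original notebook, and it assumes the test module still ships with the repository), you can verify the installation by running the model builder test from the models/research directory:
# Should finish with 'OK' if the Object Detection API installed correctly
python object_detection/builders/model_builder_tf2_test.py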
Now we can import the dependencies we will need later.
from object_detection.utils import label_map_util
from object_detection.utils import visualization_utils as viz_utils
from object_detection.utils import ops as utils_ops
%matplotlib inline
Load label map data (for plotting)
Label maps map index numbers to category names, so that when our convolutional network predicts 5, we know that this corresponds to airplane. Here we use an internal utility function, but anything that returns a dictionary mapping integers to appropriate string labels would be fine.
For simplicity, we will load the label map from the repository that we cloned for the Object Detection API code.
PATH_TO_LABELS = './models/research/object_detection/data/mscoco_label_map.pbtxt'
category_index = label_map_util.create_category_index_from_labelmap(PATH_TO_LABELS, use_display_name=True)
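As a quick sanity check (a minimal sketch, not part of the original notebook; it assumes the standard MS COCO label map where id 5 is 'airplane'), you can look up a single entry in the resulting dictionary:
# Hypothetical lookup: id 5 should map to 'airplane' in mscoco_label_map.pbtxt
print(category_index[5])
# {'id': 5, 'name': 'airplane'}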
Build a detection model and load pre-trained model weights
Here we will choose which object detection model we will use. Select the architecture and it will be loaded automatically. If you want to change the model to try other architectures later, just change the next cell and execute the following ones.
Tip: if you want to read more details about the selected model, you can follow the link (the model handle) and read the additional documentation on TF Hub. After you select a model, we will print the handle to make things easier.
model_display_name = 'CenterNet HourGlass104 Keypoints 512x512' # @param ['CenterNet HourGlass104 512x512','CenterNet HourGlass104 Keypoints 512x512','CenterNet HourGlass104 1024x1024','CenterNet HourGlass104 Keypoints 1024x1024','CenterNet Resnet50 V1 FPN 512x512','CenterNet Resnet50 V1 FPN Keypoints 512x512','CenterNet Resnet101 V1 FPN 512x512','CenterNet Resnet50 V2 512x512','CenterNet Resnet50 V2 Keypoints 512x512','EfficientDet D0 512x512','EfficientDet D1 640x640','EfficientDet D2 768x768','EfficientDet D3 896x896','EfficientDet D4 1024x1024','EfficientDet D5 1280x1280','EfficientDet D6 1280x1280','EfficientDet D7 1536x1536','SSD MobileNet v2 320x320','SSD MobileNet V1 FPN 640x640','SSD MobileNet V2 FPNLite 320x320','SSD MobileNet V2 FPNLite 640x640','SSD ResNet50 V1 FPN 640x640 (RetinaNet50)','SSD ResNet50 V1 FPN 1024x1024 (RetinaNet50)','SSD ResNet101 V1 FPN 640x640 (RetinaNet101)','SSD ResNet101 V1 FPN 1024x1024 (RetinaNet101)','SSD ResNet152 V1 FPN 640x640 (RetinaNet152)','SSD ResNet152 V1 FPN 1024x1024 (RetinaNet152)','Faster R-CNN ResNet50 V1 640x640','Faster R-CNN ResNet50 V1 1024x1024','Faster R-CNN ResNet50 V1 800x1333','Faster R-CNN ResNet101 V1 640x640','Faster R-CNN ResNet101 V1 1024x1024','Faster R-CNN ResNet101 V1 800x1333','Faster R-CNN ResNet152 V1 640x640','Faster R-CNN ResNet152 V1 1024x1024','Faster R-CNN ResNet152 V1 800x1333','Faster R-CNN Inception ResNet V2 640x640','Faster R-CNN Inception ResNet V2 1024x1024','Mask R-CNN Inception ResNet V2 1024x1024']
model_handle = ALL_MODELS[model_display_name]
print('Selected model:'+ model_display_name)
print('Model Handle at TensorFlow Hub: {}'.format(model_handle))
Selected model:CenterNet HourGlass104 Keypoints 512x512
Model Handle at TensorFlow Hub: https://hub.tensorflow.google.cn/tensorflow/centernet/hourglass_512x512_kpts/1
Load the selected model from TensorFlow Hub
Here we just need the model handle that was selected, and we use the TensorFlow Hub library to load it into memory.
print('loading model...')
hub_model = hub.load(model_handle)
print('model loaded!')
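As a rough sketch (the exact input size and output keys depend on the model you selected), the loaded object is callable and accepts a batched uint8 image:
# Minimal smoke test (assumes the model accepts a uint8 batch of shape [1, H, W, 3])
dummy_image = np.zeros((1, 512, 512, 3), dtype=np.uint8)
dummy_results = hub_model(dummy_image)
print(list(dummy_results.keys()))  # e.g. detection_boxes, detection_scores, detection_classes, ...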
Load an image
Let's try the model on a simple image. To help with this, we provide a list of test images.
Here are some simple things to try out if you are curious:
- Try running inference on your own images: just upload them to Colab and load them the same way it's done in the cell below.
- Modify some of the input images and see if detection still works. Some simple things to try out here include flipping the image horizontally or converting it to grayscale (note that we still expect the input image to have 3 channels).
Be careful: when using images with an alpha channel, the model expects 3-channel images and the alpha channel will count as a 4th; a simple workaround is shown in the sketch below.
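If your own image does have an alpha channel, one simple workaround (a sketch; 'my_upload.png' is a made-up file name) is to drop it with Pillow before building the numpy array:
# Hypothetical example: strip the alpha channel from an RGBA image before inference
rgba_image = Image.open('my_upload.png')  # made-up file name
rgb_image = rgba_image.convert('RGB')     # keep only the 3 RGB channels
image_np = np.expand_dims(np.array(rgb_image), axis=0).astype(np.uint8)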
Image selection (don't forget to execute the cell!)
You can also choose whether to flip the image horizontally and whether to convert it to grayscale.
selected_image = 'Beach' # @param ['Beach', 'Dogs', 'Naxos Taverna', 'Beatles', 'Phones', 'Birds']
flip_image_horizontally = False
convert_image_to_grayscale = False
image_path = IMAGES_FOR_TEST[selected_image]
image_np = load_image_into_numpy_array(image_path)
# Flip horizontally
if flip_image_horizontally:
  image_np[0] = np.fliplr(image_np[0]).copy()

# Convert image to grayscale
if convert_image_to_grayscale:
  image_np[0] = np.tile(
      np.mean(image_np[0], 2, keepdims=True), (1, 1, 3)).astype(np.uint8)
plt.figure(figsize=(24,32))
plt.imshow(image_np[0])
plt.show()
Doing the inference
To do the inference we just need to call our TF Hub loaded model.
Things you can try:
- Print out result['detection_boxes'] and try to match the box locations to the boxes in the image. Notice that coordinates are given in normalized form (i.e., in the interval [0, 1]); see the sketch after the inference output below.
- Check out other output keys present in the result. Full documentation can be found on the model's documentation page (point your browser to the model handle printed earlier).
image_np is the image we loaded earlier as a numpy array; we pass it to hub_model, then convert each tensor in the results dictionary to a numpy array.
# running inference
results = hub_model(image_np)
# different object detection models have additional results
# all of them are explained in the documentation
result = {key:value.numpy() for key,value in results.items()}
print(result.keys())
dict_keys(['detection_keypoints', 'detection_scores', 'detection_classes', 'num_detections', 'detection_boxes', 'detection_keypoint_scores'])
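Following the first suggestion above, here is a rough sketch (not in the original notebook) that converts the normalized coordinates of the first box, typically the highest-scoring one, into pixel coordinates:
# Boxes are [ymin, xmin, ymax, xmax] in normalized coordinates
ymin, xmin, ymax, xmax = result['detection_boxes'][0][0]
_, height, width, _ = image_np.shape
print('Top box in pixels:',
      (int(ymin * height), int(xmin * width), int(ymax * height), int(xmax * width)),
      'score:', result['detection_scores'][0][0])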
Visualizing the results
Here is where we need the TensorFlow Object Detection API, to draw the boxes from the inference step (and the keypoints when available).
The full documentation of this method can be found in visualization_utils.py.
Here you can, for example, set min_score_thresh to other values (between 0 and 1) to allow more detections in or to filter out more detections (see the small sketch after the plot below).
Execution logic:
Pull detection_keypoints and detection_keypoint_scores out of the results dictionary (when present), then call viz_utils.visualize_boxes_and_labels_on_image_array() with the following arguments:
1. the image as a numpy array
2. result['detection_boxes'][0], the detection boxes
3. (result['detection_classes'][0] + label_id_offset).astype(int), where label_id_offset is an offset applied to the label ids, for cases where the ids need to be shifted
4. result['detection_scores'][0], the detection scores
5. category_index, obtained earlier from label_map_util.create_category_index_from_labelmap(PATH_TO_LABELS, use_display_name=True), which maps indices to class-name strings
6. use_normalized_coordinates=True, max_boxes_to_draw=200, min_score_thresh=.30, agnostic_mode=False
7. keypoints=keypoints, keypoint_scores=keypoint_scores, keypoint_edges=COCO17_HUMAN_POSE_KEYPOINTS (the list of keypoint edges defined earlier)
label_id_offset = 0
image_np_with_detections = image_np.copy()
# Get the keypoints and keypoint scores, if available in the detections
keypoints, keypoint_scores = None, None
if 'detection_keypoints' in result:
  keypoints = result['detection_keypoints'][0]
  keypoint_scores = result['detection_keypoint_scores'][0]
viz_utils.visualize_boxes_and_labels_on_image_array(
    image_np_with_detections[0],
    result['detection_boxes'][0],
    (result['detection_classes'][0] + label_id_offset).astype(int),
    result['detection_scores'][0],
    category_index,
    use_normalized_coordinates=True,
    max_boxes_to_draw=200,
    min_score_thresh=.30,
    agnostic_mode=False,
    keypoints=keypoints,
    keypoint_scores=keypoint_scores,
    keypoint_edges=COCO17_HUMAN_POSE_KEYPOINTS)
plt.figure(figsize=(24,32))
plt.imshow(image_np_with_detections[0])
plt.show()
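To get a feel for how min_score_thresh changes the plot, here is a small sketch (not part of the original notebook) that counts how many detections survive a few candidate thresholds:
# Count detections above a few candidate score thresholds
for thresh in (0.3, 0.5, 0.7):
  kept = int((result['detection_scores'][0] > thresh).sum())
  print('threshold', thresh, ':', kept, 'detections')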
[Optional]
Among the available object detection models there is Mask R-CNN, and the output of that model allows instance segmentation.
To visualize it we will use the same method as before, but adding an additional parameter: instance_masks=result.get('detection_masks_reframed', None)
# Handle models with masks:
image_np_with_mask = image_np.copy()

if 'detection_masks' in result:
  # we need to convert np.arrays to tensors
  detection_masks = tf.convert_to_tensor(result['detection_masks'][0])
  detection_boxes = tf.convert_to_tensor(result['detection_boxes'][0])

  # Reframe the bbox mask to the image size.
  detection_masks_reframed = utils_ops.reframe_box_masks_to_image_masks(
      detection_masks, detection_boxes,
      image_np.shape[1], image_np.shape[2])
  detection_masks_reframed = tf.cast(detection_masks_reframed > 0.5,
                                     tf.uint8)
  result['detection_masks_reframed'] = detection_masks_reframed.numpy()
viz_utils.visualize_boxes_and_labels_on_image_array(
    image_np_with_mask[0],
    result['detection_boxes'][0],
    (result['detection_classes'][0] + label_id_offset).astype(int),
    result['detection_scores'][0],
    category_index,
    use_normalized_coordinates=True,
    max_boxes_to_draw=200,
    min_score_thresh=.30,
    agnostic_mode=False,
    instance_masks=result.get('detection_masks_reframed', None),
    line_thickness=8)
plt.figure(figsize=(24,32))
plt.imshow(image_np_with_mask[0])
plt.show()