Bootstrap

在 NX 上使用 darknet：目标分类模型的 Python 调用

目录

1  预测单张图像

2  摄像头调用模型


1  预测单张图像

需要把本 .py 文件放在与 darknet.py 相同的目录下，这样 import darknet 才能找到该模块。

import darknet

def classify(net, meta, im):
    """Run the classifier on one darknet image and rank every class.

    Returns a list of (class_name, score) tuples ordered from the highest
    score to the lowest. Class names are taken as-is from ``meta.names``
    (presumably bytes -- the camera demo calls ``.decode()`` on them).
    """
    scores = darknet.predict_image(net, im)
    ranked = [(meta.names[idx], scores[idx]) for idx in range(meta.classes)]
    # Negated key gives descending order while keeping ties in class order.
    ranked.sort(key=lambda pair: -pair[1])
    return ranked

# Trained classifier files and the image to classify.
cfg_file = "/home/suyu/darknet/custom_classification/my.cfg"
weights_file = "/home/suyu/darknet/custom_classification/trained_models/my_final.weights"
data_file = "/home/suyu/darknet/custom_classification/custom_training.data"
image_file = "/home/suyu/darknet/custom_classification/dataset/airplane/airplane_149.jpg"

# The darknet bindings are given ASCII-encoded bytes paths.
net = darknet.load_net(cfg_file.encode("ascii"), weights_file.encode("ascii"), 0)
meta = darknet.load_meta(data_file.encode("ascii"))

# Presumably width/height 0, 0 means "keep the file's own size" -- confirm in darknet.py.
im = darknet.load_image(image_file.encode("ascii"), 0, 0)

r = classify(net, meta, im)
print(r[:10])  # the ten highest-scoring (name, score) pairs

最终的 r 是一个由 (类别名, 置信度) 元组组成的列表，按置信度从高到低排序；print(r[:10]) 打印前 10 个结果。

上面的代码直接传入图像路径进行预测；下面的代码则先用 OpenCV 读取图像，再进行预测。

import darknet

def classify(net, meta, im):
    """Run the classifier on one darknet image and rank every class.

    Returns a list of (class_name, score) tuples ordered from the highest
    score to the lowest. Class names are taken as-is from ``meta.names``
    (presumably bytes -- the camera demo calls ``.decode()`` on them).
    """
    scores = darknet.predict_image(net, im)
    ranked = [(meta.names[idx], scores[idx]) for idx in range(meta.classes)]
    # Negated key gives descending order while keeping ties in class order.
    ranked.sort(key=lambda pair: -pair[1])
    return ranked

net = darknet.load_net("/home/suyu/darknet/custom_classification/my.cfg".encode("ascii"), "/home/suyu/darknet/custom_classification/trained_models/my_final.weights".encode("ascii"), 0)
meta = darknet.load_meta("/home/suyu/darknet/custom_classification/custom_training.data".encode("ascii"))

import cv2

# Same /home/suyu layout as the model files above (this previously pointed at
# /home/hdkj, unlike every other path in the file).
image_path = '/home/suyu/darknet/custom_classification/dataset/airplane/airplane_149.jpg'

width = darknet.network_width(net)
height = darknet.network_height(net)
# Darknet image buffer with the network's input dimensions.
im = darknet.make_image(width, height, 3)

img = cv2.imread(image_path)
if img is None:  # cv2.imread returns None instead of raising on a bad path
    raise FileNotFoundError(image_path)
# OpenCV loads BGR at the file's own size. copy_image_from_bytes takes no size
# argument, so it presumably reads exactly width*height*3 bytes into `im`:
# convert to RGB and resize to the network input first (as the camera demo does),
# otherwise the copy is garbled or reads past the end of the buffer.
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
img = cv2.resize(img, (width, height), interpolation=cv2.INTER_LINEAR)
darknet.copy_image_from_bytes(im, img.tobytes())

r = classify(net, meta, im)
darknet.free_image(im)  # `r` holds plain Python values, so the buffer can go
print(r[:10])

经测试，两种方式的预测结果一致。

2  摄像头调用模型

需要把本 .py 文件放在与 darknet.py 相同的目录下，这样 import darknet 才能找到该模块。

import cv2
import numpy as np
import queue
import threading
import time
import darknet

def classify(net, meta, im):
    """Run the classifier on one darknet image and rank every class.

    Returns a list of (class_name, score) tuples ordered from the highest
    score to the lowest. Class names are taken as-is from ``meta.names``
    (presumably bytes, since draw() calls ``.decode()`` on them).
    """
    scores = darknet.predict_image(net, im)
    ranked = [(meta.names[idx], scores[idx]) for idx in range(meta.classes)]
    # Negated key gives descending order while keeping ties in class order.
    ranked.sort(key=lambda pair: -pair[1])
    return ranked

# Trained classifier files (network definition, weights, data file).
cfg_file = "/home/suyu/darknet/custom_classification/my.cfg"
weights_file = "/home/suyu/darknet/custom_classification/trained_models/my_final.weights"
data_file = "/home/suyu/darknet/custom_classification/custom_training.data"

net = darknet.load_net(cfg_file.encode("ascii"), weights_file.encode("ascii"), 0)
meta = darknet.load_meta(data_file.encode("ascii"))

# One darknet image buffer at the network's input size; predict() refills it
# for every frame.
width, height = darknet.network_width(net), darknet.network_height(net)
darknet_image = darknet.make_image(width, height, 3)

# Open camera device 0.
cap = cv2.VideoCapture(0)

# Single-slot hand-off queues between the worker threads.
frame_queue, detection_result_queue = queue.Queue(maxsize=1), queue.Queue(maxsize=1)

# draw() gets its own copy of every frame; frame_queue now feeds predict() only.
# Previously both workers pulled from frame_queue, so each frame reached only one
# of them and the window showed only frames the predictor happened not to take.
display_queue = queue.Queue(maxsize=1)
# Set by any worker (or by pressing "q" in the window) to ask all workers to stop.
stop_event = threading.Event()


def _put_latest(q, item):
    """Put *item* on a size-1 queue, first discarding any unconsumed item.

    Never blocks, so a slow consumer simply skips stale entries. Safe because
    each queue here has exactly one producer thread.
    """
    try:
        q.get_nowait()
    except queue.Empty:
        pass
    q.put_nowait(item)


def video_capture():
    """Grab camera frames, convert BGR->RGB, resize to the network input size,
    and hand each frame to both predict() and draw()."""
    try:
        while not stop_event.is_set() and cap.isOpened():
            ret, frame = cap.read()
            if not ret:
                # A failed read usually means the camera is gone; stop instead
                # of spinning on cap.read() forever.
                break
            image_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            image_resized = cv2.resize(image_rgb, (width, height), interpolation=cv2.INTER_LINEAR)
            _put_latest(frame_queue, image_resized)
            # draw() writes text into its frame in place, so give it a private copy.
            _put_latest(display_queue, image_resized.copy())
    finally:
        stop_event.set()


def predict():
    """Classify the newest frame and publish the ranked result for draw()."""
    try:
        while not stop_event.is_set():
            try:
                # Time out periodically so a stop request is noticed.
                image_resized = frame_queue.get(timeout=0.5)
            except queue.Empty:
                continue
            # darknet_image is a single buffer reused for every frame; it must not
            # be freed inside this loop.
            darknet.copy_image_from_bytes(darknet_image, image_resized.tobytes())
            detections = classify(net, meta, darknet_image)
            print(detections)
            _put_latest(detection_result_queue, detections)
    finally:
        stop_event.set()


def draw():
    """Show frames with the latest prediction overlaid; press "q" to quit.

    The top-1 label is replaced by 'no_detection' when its rounded confidence
    is below 0.9 (a "none of the above" heuristic for a pure classifier).
    """
    latest = None  # most recent ranked result, kept until a newer one arrives
    try:
        while not stop_event.is_set():
            try:
                draw_frame = display_queue.get(timeout=0.5)
            except queue.Empty:
                continue
            try:
                latest = detection_result_queue.get_nowait()
            except queue.Empty:
                pass  # no new prediction yet; keep showing the previous one
            if latest:
                label = latest[0][0].decode()
                confidence = round(latest[0][1], 2)
                if confidence < 0.9:
                    label = 'no_detection'
                text = label + '  ' + str(confidence)
            else:
                text = 'please wait'
            draw_frame = cv2.putText(draw_frame, text, (20, 30), cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 0, 0), 1, lineType=cv2.LINE_AA)
            # Frames were converted to RGB in video_capture(); imshow expects BGR.
            draw_frame = cv2.cvtColor(draw_frame, cv2.COLOR_RGB2BGR)
            cv2.imshow('draw_frame', draw_frame)
            if cv2.waitKey(1) & 0xFF == ord('q'):
                break
    finally:
        stop_event.set()
        cv2.destroyAllWindows()


workers = [threading.Thread(target=target) for target in (video_capture, predict, draw)]
for worker in workers:
    worker.start()
for worker in workers:
    worker.join()
# Release the camera exactly once, after every worker has stopped using it.
cap.release()

分类模型的特点是：无论输入什么图像，模型都会给出一个结果；但对于不属于任何训练类别的图像，这个结果的置信度通常不高。因此可以利用低置信度手动构造一个“什么都不是”的类别：代码中当最高置信度低于 0.9 时，就显示 no_detection。

经测试效果良好：在 NX 上识别一帧图像约需 0.07–0.08 秒，预测过程流畅。

;