1. Download and install JetCam
git clone https://github.com/NVIDIA-AI-IOT/jetcam
cd jetcam
sudo python3 setup.py install
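Before wiring up a camera, you can optionally confirm that the package is importable; a minimal check run in a Python session or notebook cell is enough:

# Sanity check: this import fails if jetcam was not installed correctly
from jetcam.csi_camera import CSICamera
print('jetcam is installed')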
2. Test the camera
2.1 Check the camera devices
ls -ltrh /dev/video*
2.2 Create a camera object
from jetcam.csi_camera import CSICamera
camera = CSICamera(width=224, height=224)
2.3 Capture an image with the camera
image = camera.read()
print(image.shape)
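With the settings above, read() returns a BGR frame as a NumPy array of shape (224, 224, 3). As an optional check (assuming OpenCV is available, as it is on a standard JetPack image), you can save a frame to disk and open it to verify the camera output:

import cv2

image = camera.read()                 # BGR uint8 frame from the CSI camera
print(image.shape, image.dtype)       # expected: (224, 224, 3) uint8
cv2.imwrite('test_frame.jpg', image)  # write the frame so it can be inspected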
2.4 Check the image's pixel height, pixel width, and number of color channels
print(camera.value.shape)
2.5 Display the camera image in Jupyter Notebook / JupyterLab
import ipywidgets
from IPython.display import display
from jetcam.utils import bgr8_to_jpeg
image_widget = ipywidgets.Image(format='jpeg')
image_widget.value = bgr8_to_jpeg(image)
display(image_widget)
camera.running = True
def update_image(change):
    image = change['new']
    image_widget.value = bgr8_to_jpeg(image)

camera.observe(update_image, names='value')
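The observe callback above pushes each new frame into the widget. An equivalent, more compact option is a one-way traitlets link, the same pattern used later in section 3.3:

import traitlets

# Link the camera's 'value' trait straight to the image widget,
# converting every BGR frame to JPEG along the way.
camera_link = traitlets.dlink((camera, 'value'), (image_widget, 'value'), transform=bgr8_to_jpeg)
# camera_link.unlink() detaches the widget again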
2.6 Stop the video stream
camera.unobserve(update_image, names='value')
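Removing the observer only detaches the callback; the capture thread started by camera.running = True keeps grabbing frames in the background. A fuller cleanup, based on the same running flag and unobserve_all call used elsewhere in this tutorial, looks like:

# Detach all callbacks and stop the background capture thread
camera.unobserve_all()
camera.running = False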
3. Train an image classification model
3.1 Open the camera
from jetcam.csi_camera import CSICamera
camera = CSICamera(width=224, height=224)
camera.running = True
3.2 Define the training labels
import torchvision.transforms as transforms
from dataset import ImageClassificationDataset
TASK = 'classification_test'
CATEGORIES = ['dinosaur', 'giraffe']
DATASETS = ['A']
TRANSFORMS = transforms.Compose([
    transforms.ColorJitter(0.2, 0.2, 0.2, 0.2),
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])
datasets = {}
for name in DATASETS:
    datasets[name] = ImageClassificationDataset(TASK + '_' + name, CATEGORIES, TRANSFORMS)

print("Task {} created".format(TASK))
3.3 Define the data collection method
import ipywidgets
import traitlets
from IPython.display import display
from jetcam.utils import bgr8_to_jpeg
dataset = datasets[DATASETS[0]]
camera.unobserve_all()
camera_widget = ipywidgets.Image()
traitlets.dlink((camera, 'value'), (camera_widget, 'value'), transform=bgr8_to_jpeg)
dataset_widget = ipywidgets.Dropdown(options=DATASETS, description='dataset')
category_widget = ipywidgets.Dropdown(options=dataset.categories, description='category')
count_widget = ipywidgets.IntText(description='count')
save_widget = ipywidgets.Button(description='add')
count_widget.value = dataset.get_count(category_widget.value)
def set_dataset(change):
    global dataset
    dataset = datasets[change['new']]
    count_widget.value = dataset.get_count(category_widget.value)
dataset_widget.observe(set_dataset, names='value')

def update_counts(change):
    count_widget.value = dataset.get_count(change['new'])
category_widget.observe(update_counts, names='value')

def save(c):
    dataset.save_entry(camera.value, category_widget.value)
    count_widget.value = dataset.get_count(category_widget.value)
save_widget.on_click(save)

data_collection_widget = ipywidgets.VBox([
    ipywidgets.HBox([camera_widget]), dataset_widget, category_widget, count_widget, save_widget
])
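data_collection_widget is combined with the other panels and shown in section 3.7; if you want to start collecting images right away, you can also display this panel on its own:

# Optionally show just the data collection panel now
display(data_collection_widget)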
3.4 Define the model
import torch
import torchvision
device = torch.device('cuda')
model = torchvision.models.resnet18(pretrained=True)
model.fc = torch.nn.Linear(512, len(dataset.categories))
model = model.to(device)
model_save_button = ipywidgets.Button(description='save model')
model_load_button = ipywidgets.Button(description='load model')
model_path_widget = ipywidgets.Text(description='model path', value='classification_model.pth')
def load_model(c):
    model.load_state_dict(torch.load(model_path_widget.value))
model_load_button.on_click(load_model)

def save_model(c):
    torch.save(model.state_dict(), model_path_widget.value)
model_save_button.on_click(save_model)

model_widget = ipywidgets.VBox([
    model_path_widget,
    ipywidgets.HBox([model_load_button, model_save_button])
])
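Because the final fully connected layer was replaced, it is worth verifying that the network now produces one score per category. A minimal check with a dummy input:

# Verify the modified ResNet-18 head: one output per category
dummy = torch.zeros(1, 3, 224, 224).to(device)
with torch.no_grad():
    out = model(dummy)
print(out.shape)  # expected: torch.Size([1, len(dataset.categories)]), e.g. [1, 2]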
3.5 Define the live execution method
import threading
import time
from utils import preprocess
import torch.nn.functional as F
state_widget = ipywidgets.ToggleButtons(options=['stop', 'live'], description='state', value='stop')
prediction_widget = ipywidgets.Text(description='prediction')
score_widgets = []
for category in dataset.categories:
    score_widget = ipywidgets.FloatSlider(min=0.0, max=1.0, description=category, orientation='vertical')
    score_widgets.append(score_widget)

def live(state_widget, model, camera, prediction_widget, score_widgets):
    global dataset
    while state_widget.value == 'live':
        image = camera.value
        preprocessed = preprocess(image)
        output = model(preprocessed)
        output = F.softmax(output, dim=1).detach().cpu().numpy().flatten()
        category_index = output.argmax()
        prediction_widget.value = dataset.categories[category_index]
        for i, score in enumerate(list(output)):
            score_widgets[i].value = score

def start_live(change):
    if change['new'] == 'live':
        execute_thread = threading.Thread(target=live, args=(state_widget, model, camera, prediction_widget, score_widgets))
        execute_thread.start()
state_widget.observe(start_live, names='value')

live_execution_widget = ipywidgets.VBox([
    ipywidgets.HBox(score_widgets),
    prediction_widget,
    state_widget
])
print("Live execution widgets created")
3.6 Define the training and evaluation method
BATCH_SIZE = 8
optimizer = torch.optim.Adam(model.parameters())
epochs_widget = ipywidgets.IntText(description='epochs', value=1)
eval_button = ipywidgets.Button(description='evaluate')
train_button = ipywidgets.Button(description='train')
loss_widget = ipywidgets.FloatText(description='loss')
accuracy_widget = ipywidgets.FloatText(description='accuracy')
progress_widget = ipywidgets.FloatProgress(min=0.0, max=1.0, description='progress')
def train_eval(is_training):
    global BATCH_SIZE, model, dataset, optimizer, eval_button, train_button, accuracy_widget, loss_widget, progress_widget, state_widget
    try:
        train_loader = torch.utils.data.DataLoader(
            dataset,
            batch_size=BATCH_SIZE,
            shuffle=True
        )
        state_widget.value = 'stop'
        train_button.disabled = True
        eval_button.disabled = True
        time.sleep(1)
        if is_training:
            model = model.train()
        else:
            model = model.eval()
        while epochs_widget.value > 0:
            i = 0
            sum_loss = 0.0
            error_count = 0.0
            for images, labels in iter(train_loader):
                images = images.to(device)
                labels = labels.to(device)
                if is_training:
                    optimizer.zero_grad()
                outputs = model(images)
                loss = F.cross_entropy(outputs, labels)
                if is_training:
                    loss.backward()
                    optimizer.step()
                error_count += len(torch.nonzero(outputs.argmax(1) - labels).flatten())
                count = len(labels.flatten())
                i += count
                sum_loss += float(loss)
                progress_widget.value = i / len(dataset)
                loss_widget.value = sum_loss / i
                accuracy_widget.value = 1.0 - error_count / i
            if is_training:
                epochs_widget.value = epochs_widget.value - 1
            else:
                break
    except Exception:
        pass
    model = model.eval()
    train_button.disabled = False
    eval_button.disabled = False
    state_widget.value = 'live'
train_button.on_click(lambda c: train_eval(is_training=True))
eval_button.on_click(lambda c: train_eval(is_training=False))
train_eval_widget = ipywidgets.VBox([
    epochs_widget,
    progress_widget,
    loss_widget,
    accuracy_widget,
    ipywidgets.HBox([train_button, eval_button])
])
print("Training and evaluation method created")
3.7 Display the interactive tool
all_widget = ipywidgets.VBox([
    ipywidgets.HBox([data_collection_widget, live_execution_widget]),
    train_eval_widget,
    model_widget
])
display(all_widget)