import torch
import torchvision
import torchvision.transforms as T
import matplotlib.pyplot as plt
from PIL import Image
import numpy as np
def decodeSegMap(image, nc=21):
    """Turn a map of integer class indices into an RGB colour image.

    Each pixel whose value equals a class id in ``range(nc)`` is painted with
    that class's entry from a fixed 21-colour palette (the standard PASCAL VOC
    colour map). Every other pixel stays black ``(0, 0, 0)``.

    Args:
        image: array of class indices, e.g. the per-pixel ``argmax`` of a
            segmentation model's output. Assumed to be 2-D (H, W); the colour
            channels are stacked on axis 2.
        nc: number of classes to colour. Must not exceed the palette size
            (21), otherwise the palette lookup raises ``IndexError``.

    Returns:
        ``uint8`` array with the red, green and blue planes stacked on
        axis 2, i.e. shape (H, W, 3) for a 2-D ``image``.
    """
    palette = np.array([
        (0, 0, 0), (128, 0, 0), (0, 128, 0), (128, 128, 0),
        (0, 0, 128), (128, 0, 128), (0, 128, 128), (128, 128, 128),
        (64, 0, 0), (192, 0, 0), (64, 128, 0), (192, 128, 0),
        (64, 0, 128), (192, 0, 128), (64, 128, 128), (192, 128, 128),
        (0, 64, 0), (128, 64, 0), (0, 192, 0), (128, 192, 0),
        (0, 64, 128),
    ])
    # One zero-initialised uint8 plane per colour channel (red, green, blue);
    # pixels never matched below therefore remain black.
    planes = [np.zeros_like(image).astype(np.uint8) for _ in range(3)]
    for class_id in range(nc):
        hit = image == class_id
        for plane, value in zip(planes, palette[class_id]):
            plane[hit] = value
    return np.stack(planes, axis=2)
# Load DeepLabV3 (ResNet-50 backbone) with pretrained weights and put it in
# inference mode; eval() switches layers such as dropout and batch-norm to
# their inference behaviour.
# NOTE(review): torchvision >= 0.13 deprecates `pretrained=` in favour of
# `weights=...`; kept here for compatibility with older torchvision versions.
model = torchvision.models.segmentation.deeplabv3_resnet50(pretrained=True)
# To load the weights from a local checkpoint instead of downloading them:
# model.load_state_dict(torch.load('./deeplabv3_resnet50_coco-cd0a2569.pth'))
model = model.eval()

# Preprocessing: resize the shorter side to 256, center-crop to 224x224,
# convert to a float tensor in [0, 1], then normalize per channel. The
# mean/std are named so the normalization can be undone for display below.
NORM_MEAN = [0.485, 0.456, 0.406]
NORM_STD = [0.229, 0.224, 0.225]
transform = T.Compose([
    T.Resize(256),
    T.CenterCrop(224),
    T.ToTensor(),
    T.Normalize(mean=NORM_MEAN, std=NORM_STD),
])

# Load the input image. convert('RGB') guarantees three channels (grayscale,
# palette or RGBA files would otherwise break the 3-channel Normalize), and
# the `with` block closes the underlying file handle.
img_pth = './dog.2.jpg'
with Image.open(img_pth) as src:
    img = src.convert('RGB')
plt.imshow(img)
plt.axis('off')
plt.show()

# Add a batch dimension: (3, H, W) -> (1, 3, H, W).
input_tensor = transform(img).unsqueeze(0)

# Show the image as the model sees it (resized + cropped). Undo the
# normalization first: normalized values fall outside [0, 1], which imshow
# would clip and render as a distorted image.
mean = torch.tensor(NORM_MEAN).view(3, 1, 1)
std = torch.tensor(NORM_STD).view(3, 1, 1)
img_transform = (input_tensor[0] * std + mean).clamp(0, 1).permute(1, 2, 0).numpy()
plt.imshow(img_transform)
plt.axis('off')
plt.show()

# Inference only: no_grad() avoids building an autograd graph.
with torch.no_grad():
    output = model(input_tensor)
print(f"输出结果的形状:{output['out'].shape}")

# Per-pixel predicted class: argmax over the class-score dimension of the
# single batch item.
output = torch.argmax(output['out'][0], dim=0).cpu().numpy()
# Distinct class ids present in the prediction, as plain Python ints.
result_class = set(np.unique(output).tolist())
print(result_class)

# Colorize the class map and display it.
rgb = decodeSegMap(output)
img = Image.fromarray(rgb)
plt.axis('off')
plt.imshow(img)
plt.show()
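# Original image (原图): figure not included in this file.
# Segmentation result (结果图): figure not included in this file.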