Bootstrap

PyTorch 内置语义分割模型(torchvision 的 DeepLabV3-ResNet50)

import torch
import torchvision
import torchvision.transforms as T
import matplotlib.pyplot as plt
from PIL import Image
import numpy as np

def decodeSegMap(image, nc=21):
    """Turn a map of integer class labels into an RGB colour image.

    Args:
        image: Array of class indices, one per pixel.
        nc: Number of classes to colour. Labels ``0 .. nc-1`` receive the
            matching palette entry; every other label stays black.

    Returns:
        A ``uint8`` array with the three colour planes stacked along
        axis 2 (``(H, W, 3)`` for a 2-D ``image``).

    Raises:
        IndexError: If ``nc`` is larger than the 21-entry palette.
    """
    # Fixed 21-entry RGB palette, indexed by class id.
    # NOTE(review): looks like the Pascal VOC colour map - confirm.
    palette = np.array([
        (0, 0, 0), (128, 0, 0), (0, 128, 0), (128, 128, 0),
        (0, 0, 128), (128, 0, 128), (0, 128, 128), (128, 128, 128),
        (64, 0, 0), (192, 0, 0), (64, 128, 0), (192, 128, 0),
        (64, 0, 128), (192, 0, 128), (64, 128, 128), (192, 128, 128),
        (0, 64, 0), (128, 64, 0), (0, 192, 0), (128, 192, 0),
        (0, 64, 128),
    ])

    # One zero-initialised uint8 plane per channel, in R, G, B order.
    planes = [np.zeros_like(image).astype(np.uint8) for _ in range(3)]

    for class_id in range(nc):
        mask = image == class_id
        # Paint this class's colour into each channel where the mask is set.
        for plane, value in zip(planes, palette[class_id]):
            plane[mask] = value

    return np.stack(planes, axis=2)


# Build the model with pretrained weights.
# NOTE(review): newer torchvision releases deprecate `pretrained=` in favour
# of `weights=`; kept as-is for compatibility with the installed version -
# confirm before upgrading.
model = torchvision.models.segmentation.deeplabv3_resnet50(pretrained=True)
# Offline alternative: load a local checkpoint instead, e.g.
#   model.load_state_dict(torch.load('./deeplabv3_resnet50_coco-cd0a2569.pth'))
# Switch to inference mode (fixes batch-norm statistics, disables dropout).
model = model.eval()

# Preprocessing: resize, centre-crop, convert to a tensor and normalise with
# the per-channel mean/std given below.
transform = T.Compose([T.Resize(256), T.CenterCrop(224), T.ToTensor(),
                      T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])

# Load the input picture. Force three RGB channels so grayscale, palette or
# RGBA files do not break the 3-channel normalisation above.
img_pth = './dog.2.jpg'
img = Image.open(img_pth).convert('RGB')
plt.imshow(img)
plt.axis('off')
plt.show()

# Add a leading batch dimension: (C, H, W) -> (1, C, H, W).
img = transform(img).unsqueeze(0)
# Show the picture as it looks after `transform` (channels moved last for
# matplotlib). Values are normalised, so matplotlib clips them for display.
img_transform = np.transpose(img.detach().numpy()[0], (1, 2, 0))
plt.imshow(img_transform)
plt.show()

# Pure inference: skip autograd bookkeeping to save memory and time.
with torch.no_grad():
    output = model(img)
print(f"输出结果的形状:{output['out'].shape}")
# Drop the batch dimension, then take the highest-scoring class per pixel.
output = torch.argmax(output['out'].squeeze(0), dim=0).cpu().numpy()
# Distinct class ids present in the prediction, as plain Python ints.
result_class = set(np.unique(output).tolist())
print(result_class)

# Colourise the label map and display it.
rgb = decodeSegMap(output)
img = Image.fromarray(rgb)
plt.axis('off')
plt.imshow(img)
plt.show()

原图:
(图:原始输入图片 dog.2.jpg)
(图:经过 transform 预处理后的图片)
结果图:
(图:语义分割结果的彩色标签图)

;