Bootstrap

训练常用API函数及方法

Tensor转Image

# Convert an image tensor to a uint8 NumPy image (values are clamped to [0, 1]).
# detach()/cpu() make this work for tensors that live on the GPU, and the
# out-of-place clamp() leaves `img` itself unmodified.
output = img.detach().squeeze().float().clamp(0, 1).cpu().numpy()
if output.ndim == 3:
    # 1. reverse the channel order (e.g. RGB <-> BGR); 2. move axis 0 last (CHW -> HWC)
    output = np.transpose(output[[2, 1, 0], :, :], (1, 2, 0))
output = (output * 255.0).round().astype(np.uint8)  # float32 to uint8

维度转置

# Swap the second and third dimensions
array.transpose(0, 2, 1)

训练模型

def train(model):
    """Train ``model`` for ``num_epochs`` passes over ``train_loader``.

    Relies on the module-level names ``num_epochs``, ``train_loader``,
    ``criterion``, ``optimizer`` and ``tqdm``. Prints the mean per-batch loss
    after every epoch.

    NOTE: the model's train/eval mode is left untouched; call
    ``model.train()`` beforehand if that is required.

    Args:
        model: Callable applied to each batch of inputs; its output is passed
            to ``criterion`` together with the batch labels.
    """
    for epoch in range(num_epochs):
        running_loss = 0.0
        # Named `inputs` rather than `input` to avoid shadowing the builtin.
        for inputs, labels in tqdm(train_loader):
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            running_loss += loss.item()

        # An empty loader would otherwise raise ZeroDivisionError here.
        num_batches = len(train_loader)
        epoch_loss = running_loss / num_batches if num_batches else float('nan')
        print(f'Epoch {epoch + 1} loss: {epoch_loss:.4f}')

颜色转换

# Convert an RGB image to BGR channel order with OpenCV
bgr_image = cv2.cvtColor(rgb_image, cv2.COLOR_RGB2BGR)

自定义数据集

    def __init__(self, folder_path, transform=None):
        """Record the image folder and index its entries.

        Args:
            folder_path: Path of the folder holding the images.
            transform: Optional callable applied to each loaded image.
        """
        self.folder_path = folder_path
        self.transform = transform
        # os.listdir returns entries in arbitrary order; sorting keeps the
        # index -> file mapping stable across runs and machines.
        # NOTE: every directory entry is listed, not only image files.
        self.image_files = sorted(os.listdir(folder_path))

    def __len__(self):
        """Return the number of entries in ``self.image_files``."""
        entry_count = len(self.image_files)
        return entry_count

    def __getitem__(self, idx):
        """Load the image at position ``idx`` and return ``(image, label)``.

        The image is opened from ``folder_path`` joined with the ``idx``-th
        entry of ``image_files`` and passed through ``self.transform`` when
        one is set. The label is always the constant 0.
        """
        path = os.path.join(self.folder_path, self.image_files[idx])
        sample = Image.open(path)

        if self.transform:
            sample = self.transform(sample)

        # Placeholder label: define it as needed (e.g. 0 for a single class).
        target = 0
        return sample, target

图表显示

import matplotlib.pyplot as plt
num_images = 50
# Arguments: number of rows, number of columns, figsize=(width, height) in inches.
# squeeze=False keeps `axes` two-dimensional even when num_images == 1,
# so every row can always be unpacked as a (left, right) pair.
fig, axes = plt.subplots(num_images, 2, figsize=(10, 5 * num_images), squeeze=False)
# NOTE(review): every row shows the same `image` / `enhanced_image`;
# index per-row data here if a different pair per row is intended.
for left_ax, right_ax in axes:
    left_ax.imshow(image)
    right_ax.imshow(enhanced_image)
    left_ax.axis('off')
    right_ax.axis('off')

plt.show()

保存模型恢复训练

import torch
# Save only the model parameters (state_dict); optimizer state is not stored.
torch.save(model.state_dict(), 'model_checkpoint.pth')
# later
model = Model()
# map_location='cpu' lets a checkpoint saved from GPU tensors load on a machine
# without CUDA; load_state_dict then copies the values into the model's parameters.
model.load_state_dict(torch.load('model_checkpoint.pth', map_location='cpu'))
# A fresh optimizer is created, so it starts without any previous state.
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)
# continue to train
;