Bootstrap

深度学习第三站——图片分类任务(扩展半监督学习)

1、半监督具体实行过程

        半监督学习流程:先由“有标签数据集”训练模型,使模型达到一个较好的分类准确率(0.6),接着让模型对无标签的数据进行分类,对分类的结果中,让置信度高于0.99的部分作为有标签的数据集,加入到对模型的训练中。

(个人难以使分类准确率达到0.6,本人在代码中调用了resnet的特征提取器来完成这一过程。

2、增添的代码块

class semiDataset(Dataset):  # 评估得到的标签可靠性
    def __init__(self, no_label_loader, model, device, thres):  # 传入无标签参数,模型,gpu设备,置信度
        X, Y = self.data_pred(no_label_loader, model, device, thres)
        if X == []:
            self.flag = False
        else:
            self.flag = True
            self.X = np.array(X)  # 转为矩阵
            self.Y = torch.LongTensor(Y)
            self.transform = train_transform

    def data_pred(self, no_label_loader, model, device, thres):  # 打标签函数
        model = model.to(device)
        soft = nn.Softmax(dim=1)
        pred_prob = []  # 记录预测值
        labels = []  # 记录Y,标签
        x = []
        y = []
        with torch.no_grad():  # 非训练过程,不用计算梯度
            for data in no_label_loader:
                data = data[0].to(device)  # 取增广后的数据
                pred = model(data)
                pred_soft = soft(pred)  # 预测值->概率值
                pred_max, pred_value = pred_soft.max(1)  # 返回最大值以及对应下标
                pred_prob.extend(pred_max.cpu().numpy().tolist())  # 把全部概率存入数组  #append放一个数,extend放一组
                labels.extend(pred_value.cpu().numpy().tolist())  # 把分类结果存入数组
            for index, prob in enumerate(pred_prob):  # 枚举、列举的意思,用于将一个可遍历的数据对象(如列表、元组或字符串)组合为一个索引序列,同时列出数据的下标和数据
                if prob > thres:
                    x.append(no_label_loader.dataset[index][1])  # 把对应的图片数据放入X中(未增广的原数据)
                    y.append(labels[index])  # x对应的标签
            return x, y

    def __getitem__(self, item):
        return self.transform(self.X[item]), self.Y[item]

    def __len__(self):
        return len(self.Y)

        本semiDataset类主要实现:传入无标签数据,对其进行分类,再由所设定置信度为门槛,决定对哪些数据进打上标签,最后返回两个列表,即X和对应的标签Y。

def get_semi_loader(no_label_loader, model, device, thres):
    semi_set = semiDataset(no_label_loader, model, device, thres)  # 经过半监督学习得到的训练集
    if semi_set.flag == False:
        return None
    semi_loader = DataLoader(semi_set, batch_size=16, shuffle=False)  # 一次取一批
    return semi_loader

        本get_semi_loader函数,当半监督训练集非空的时候,创建并返回一个数据提供器semi_loader(一次提供16份数据,不进行打乱)

;