1、半监督具体实行过程
半监督学习流程:先由“有标签数据集”训练模型,使模型达到一个较好的分类准确率(0.6),接着让模型对无标签的数据进行分类,对分类的结果中,让置信度高于0.99的部分作为有标签的数据集,加入到对模型的训练中。
(个人难以使分类准确率达到0.6,本人在代码中调用了resnet的特征提取器来完成这一过程。
)
2、增添的代码块
class semiDataset(Dataset): # 评估得到的标签可靠性
def __init__(self, no_label_loader, model, device, thres): # 传入无标签参数,模型,gpu设备,置信度
X, Y = self.data_pred(no_label_loader, model, device, thres)
if X == []:
self.flag = False
else:
self.flag = True
self.X = np.array(X) # 转为矩阵
self.Y = torch.LongTensor(Y)
self.transform = train_transform
def data_pred(self, no_label_loader, model, device, thres): # 打标签函数
model = model.to(device)
soft = nn.Softmax(dim=1)
pred_prob = [] # 记录预测值
labels = [] # 记录Y,标签
x = []
y = []
with torch.no_grad(): # 非训练过程,不用计算梯度
for data in no_label_loader:
data = data[0].to(device) # 取增广后的数据
pred = model(data)
pred_soft = soft(pred) # 预测值->概率值
pred_max, pred_value = pred_soft.max(1) # 返回最大值以及对应下标
pred_prob.extend(pred_max.cpu().numpy().tolist()) # 把全部概率存入数组 #append放一个数,extend放一组
labels.extend(pred_value.cpu().numpy().tolist()) # 把分类结果存入数组
for index, prob in enumerate(pred_prob): # 枚举、列举的意思,用于将一个可遍历的数据对象(如列表、元组或字符串)组合为一个索引序列,同时列出数据的下标和数据
if prob > thres:
x.append(no_label_loader.dataset[index][1]) # 把对应的图片数据放入X中(未增广的原数据)
y.append(labels[index]) # x对应的标签
return x, y
def __getitem__(self, item):
return self.transform(self.X[item]), self.Y[item]
def __len__(self):
return len(self.Y)
本semiDataset类主要实现:传入无标签数据,对其进行分类,再由所设定置信度为门槛,决定对哪些数据进打上标签,最后返回两个列表,即X和对应的标签Y。
def get_semi_loader(no_label_loader, model, device, thres):
semi_set = semiDataset(no_label_loader, model, device, thres) # 经过半监督学习得到的训练集
if semi_set.flag == False:
return None
semi_loader = DataLoader(semi_set, batch_size=16, shuffle=False) # 一次取一批
return semi_loader
本get_semi_loader函数,当半监督训练集非空的时候,创建并返回一个数据提供器semi_loader(一次提供16份数据,不进行打乱)