Resnet50进行迁移学习实现图片二分类
内容简介
本文使用预训练的Resnet50网络对皮肤病图片进行二分类,基于portch框架。
数据集说明
数据集存放目录为: used_dataset , 共200张图片,标签为:benign(良性)、malignant(患病)。
- 数据集划分如下:
数据集类型 | benign | malignant | totall |
---|---|---|---|
train | 64 | 64 | 128 |
val | 16 | 16 | 32 |
test | 20 | 20 | 40 |
代码目录介绍
- args.py 存放训练和测试所用的各种参数。
–mode字段表示运行模式:train or test.
–model_path字段是训练模型的保存路径。
其余字段都有默认值。 - create_dataset.py 该脚本是用来读json中的数据的,可以忽略。
- data_gen.py 该脚本实现划分数据集以及数据增强和数据加载。
- main.py 包含训练、评估和测试。
- transform.py 实现图片增强。
- utils.py 存放一些工具函数。
- models/Res.py 是重写的ResNet各种类型的网络。
- checkpoints 保存模型
运行命令
# 训练模型
python main.py --mode=train
# 测试模型
python main.py --mode=test --model_path='训练好的模型文件路径'
main.py 脚本介绍
main()函数 实现模型的训练和评估
- step1: 加载数据
# data
transformations = get_transforms(input_size=args.image_size,test_size=args.image_size)
train_set = data_gen.Dataset(root=args.train_txt_path,transform=transformations['val_train'])
train_loader = data.DataLoader(train_set,batch_size=args.batch_size,shuffle=True)
val_set = data_gen.ValDataset(root=args.val_txt_path,transform=transformations['val_test'])
val_loader = data.DataLoader(val_set,batch_size=args.batch_size,shuffle=False)
- step2: 构建模型
model = make_model(args)
if use_cuda:
model.cuda()
# define loss function and optimizer
if use_cuda:
criterion = nn.CrossEntropyLoss().cuda()
else:
criterion = nn.CrossEntropyLoss()
optimizer = get_optimizer(model,args)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.2, patience=5, verbose=