train 部分的代码通常拆分为 `dataset.py`、`module.py` 和 `trainer.py` 三个文件。为了测试这三个文件中的代码,我们准备了以下 Tiny 系列测试组件(TinyDataset、TinyTrainer、TinyNet)。
其中,x.1 节是不带 `core.py` 的版本,x.2 节是带 `core.py` 的版本。
x.1.1 dataset.py
`dataset.py` 主要编写 Dataset 的派生类,测试代码如下:
if __name__ == "__main__":
    # Smoke test for the Dataset: wrap it in a DataLoader and print the shape
    # of the first item of the first batch (presumably the input tensor).
    data_root = "/home/yingmuzhi/BioAI/data/data1_output/phase2dna_microdl_patches/tiles_256-256_step_128-128"
    # NOTE(review): the meaning of the positional `True, None` arguments is
    # defined by MicroDLDataset (not visible here) - confirm before changing.
    dataset = MicroDLDataset(data_root, True, None)
    loader = DataLoader(dataset, batch_size=4, num_workers=2)
    first_batch = next(iter(loader))
    print(first_batch[0].shape)
x.1.2 module.py
`module.py` 主要编写网络结构。为了测试它,我们需要创建简易的 Dataset 和 Trainer,代码如下:
if __name__ == "__main__":
    import torch
    import torch.nn as nn
    from torch.utils.data import DataLoader, Dataset

    # Placeholder: assign the network class under test here (e.g. `Net = MyNet`).
    # It must be callable with no arguments and return an nn.Module.
    Net = "your network"

    class TinyDataset(Dataset):
        """In-memory Dataset returning `(X[index], Y[index])` pairs."""

        def __init__(self, X, Y):
            # Keep the pre-built input and target tensors as-is.
            self.X, self.Y = X, Y

        def __getitem__(self, index):
            return self.X[index], self.Y[index]

        def __len__(self):
            return len(self.X)

    class TinyTrainer():
        """Minimal trainer: runs an MSE + SGD loop on constant dummy data."""

        def fit(self, net_cls=None, epochs=10):
            """Instantiate the network and train it on dummy data.

            Args:
                net_cls: network class/factory to test; defaults to `Net`.
                epochs: number of passes over the dummy dataset (default 10).

            Raises:
                TypeError: if `Net` is still the string placeholder.
            """
            if net_cls is None:
                net_cls = Net
            if isinstance(net_cls, str):
                # Fail with a clear message instead of "'str' object is not callable".
                raise TypeError(
                    "Net is still a placeholder string; assign your network "
                    "class to Net before running this test."
                )

            # 4 samples of shape (1, 32, 256, 256): inputs all ones, targets all zeros.
            # NOTE(review): presumably (C, D, H, W) for a 3-D network - confirm.
            X_tensor = torch.ones((4, 1, 32, 256, 256))
            Y_tensor = torch.zeros((4, 1, 32, 256, 256))
            mydataset = TinyDataset(X_tensor, Y_tensor)
            train_loader = DataLoader(mydataset, batch_size=2, shuffle=True)

            net = net_cls()
            print(net)
            loss_fn = nn.MSELoss()
            optimizer = torch.optim.SGD(net.parameters(), lr=1e-3)

            # Training loop: forward pass, loss, backward pass, parameter update.
            for epoch in range(epochs):
                for i, (X, y) in enumerate(train_loader):
                    pred = net(X)
                    loss = loss_fn(pred, y)
                    optimizer.zero_grad()
                    loss.backward()
                    optimizer.step()
                    print('epoch={},i={}'.format(epoch, i))

    TinyTrainer().fit()
x.1.3 trainer.py
`trainer.py` 主要负责网络训练。为了测试训练流程,我们需要创建一个极简网络,代码如下:
if __name__ == "__main__":
    import torch.nn as nn

    class TinyNet(nn.Module):
        """Tiny fully-connected encoder/decoder used to smoke-test training.

        Args:
            input: size of the last dimension of the input tensor.
            output: size of the last dimension of the network output.
            hidden: width of the hidden layer in both encoder and decoder.
            latent: size of the bottleneck between encoder and decoder.

        The defaults reproduce the original 784 -> 64 -> 3 -> 64 -> 784 network.
        `input` shadows the builtin; the name is kept so existing keyword
        callers keep working.
        """

        def __init__(self, input=28*28, output=28*28, hidden=64, latent=3):
            super().__init__()
            self.encoder = nn.Sequential(
                nn.Linear(input, hidden), nn.ReLU(), nn.Linear(hidden, latent)
            )
            self.decoder = nn.Sequential(
                nn.Linear(latent, hidden), nn.ReLU(), nn.Linear(hidden, output)
            )

        def forward(self, x):
            """Encode `x` to the bottleneck and decode it back."""
            # nn.Linear acts on the last dimension, so x.shape[-1] must equal `input`.
            y = self.encoder(x)
            return self.decoder(y)

    # Exposed under the generic name `Net` - presumably what the training
    # code that follows instantiates; confirm against trainer.py.
    Net = TinyNet
x.2.1 dataset.py
`dataset.py` 主要编写 Dataset 派生类和 DataModule 派生类,测试代码如下:
if __name__ == "__main__":
    # Both smoke tests below read patches from the same tile directory.
    root = "/home/yingmuzhi/BioAI/data/data1_output/phase2dna_microdl_patches/tiles_256-256_step_128-128"

    # 1) Dataset wrapped in a plain DataLoader: print the shape of the first
    #    item of the first batch.
    #    NOTE(review): meaning of the positional `True, None` arguments is
    #    defined by MicroDLDataset (not visible here) - confirm.
    loader = DataLoader(MicroDLDataset(root, True, None), batch_size=4, num_workers=2)
    first_item = next(iter(loader))[0]
    print(first_item.shape)

    # 2) DataModule: same check on the DataLoader returned by train_dataloader().
    dm = MicroDLDM(root=root)
    first_train_item = next(iter(dm.train_dataloader()))[0]
    print(first_train_item.shape)