文档服务地址:http://47.92.0.57:3000/ 周报索引地址:http://47.92.0.57:3000/s/NruNXRYmV

Commit 8c28bb60 by 王肇一

mrnet now ready

parent 9b02e4db
...@@ -5,4 +5,6 @@ data/masks/* ...@@ -5,4 +5,6 @@ data/masks/*
data/output/* data/output/*
data/train_imgs/* data/train_imgs/*
data/train_masks/* data/train_masks/*
data/train_imgs_32/*
.ipynb_checkpoints/ .ipynb_checkpoints/
runs
...@@ -14,7 +14,7 @@ class MultiUnet(nn.Module): ...@@ -14,7 +14,7 @@ class MultiUnet(nn.Module):
self.bilinear = bilinear self.bilinear = bilinear
self.inconv = nn.Sequential( self.inconv = nn.Sequential(
nn.Conv2d(n_channels, 8, kernel_size = 3, stride = 1, padding_mode = 'same'), nn.Conv2d(n_channels, 8, kernel_size = 3, stride = 1, padding = 1),
nn.BatchNorm2d(8), nn.BatchNorm2d(8),
nn.ReLU(inplace = True) nn.ReLU(inplace = True)
) )
...@@ -41,7 +41,7 @@ class MultiUnet(nn.Module): ...@@ -41,7 +41,7 @@ class MultiUnet(nn.Module):
self.pool = nn.MaxPool2d(2) self.pool = nn.MaxPool2d(2)
self.outconv = nn.Conv2d(self.res9.outc, n_classes, kernel_size = 3, padding_mode = 'same') self.outconv = nn.Conv2d(self.res9.outc, n_classes, kernel_size = 3, padding = 1)
def forward(self, x): def forward(self, x):
x = self.inconv(x) x = self.inconv(x)
......
...@@ -16,8 +16,8 @@ class MultiResBlock(nn.Module): ...@@ -16,8 +16,8 @@ class MultiResBlock(nn.Module):
def conv(self, in_channel, out_channel, kernel_size): def conv(self, in_channel, out_channel, kernel_size):
return nn.Sequential( return nn.Sequential(
nn.Conv2d(in_channels = in_channel, out_channels = out_channel, kernel_size = kernel_size, stride = 1,padding = 1, nn.Conv2d(in_channels = in_channel, out_channels = out_channel, kernel_size = kernel_size, stride = 1,padding = 1),
padding_mode = 'same'), nn.BatchNorm2d(out_channel), nn.ReLU(inplace = True)) nn.BatchNorm2d(out_channel), nn.ReLU(inplace = True))
def forward(self, x): def forward(self, x):
shortcut = nn.Sequential( shortcut = nn.Sequential(
...@@ -41,9 +41,9 @@ class TransCompose(nn.Module): ...@@ -41,9 +41,9 @@ class TransCompose(nn.Module):
self.inc = in_channels self.inc = in_channels
self.outc = out_channels self.outc = out_channels
def forward(self,x): def forward(self, x):
return nn.Sequential( return nn.Sequential(
nn.ConvTranspose2d(in_channels = self.inc, out_channels = self.outc, kernel_size = 3, stride = 2), nn.ConvTranspose2d(in_channels = self.inc, out_channels = self.outc, kernel_size = 2, stride = 2),
nn.BatchNorm2d(self.outc) nn.BatchNorm2d(self.outc)
)(x) )(x)
...@@ -61,7 +61,8 @@ class ResPath(nn.Module): ...@@ -61,7 +61,8 @@ class ResPath(nn.Module):
nn.BatchNorm2d(self.outc))(x) nn.BatchNorm2d(self.outc))(x)
conv = nn.Sequential( conv = nn.Sequential(
nn.Conv2d(in_channels = in_channels, out_channels = out_channels, kernel_size = 3, stride = 1, padding = 1), nn.Conv2d(in_channels = in_channels, out_channels = out_channels, kernel_size = 3, stride = 1, padding = 1),
nn.BatchNorm2d(out_channels), nn.ReLU(inplace = True))(x) nn.BatchNorm2d(out_channels),
nn.ReLU(inplace = True))(x)
result = torch.add(conv, shortcut) result = torch.add(conv, shortcut)
return nn.Sequential(nn.ReLU(inplace = True), nn.BatchNorm2d(out_channels))(result) return nn.Sequential(nn.ReLU(inplace = True), nn.BatchNorm2d(out_channels))(result)
......
...@@ -18,7 +18,7 @@ dir_mask = 'data/train_masks/' ...@@ -18,7 +18,7 @@ dir_mask = 'data/train_masks/'
dir_checkpoint = 'checkpoint' dir_checkpoint = 'checkpoint'
def train_net(net, device, epochs = 5, batch_size = 1, lr = 0.1, val_percent = 0.1, save_cp = True, img_scale = 0.5): def train_net(net, device, epochs = 5, batch_size = 1, lr = 0.1, val_percent = 0.1, img_scale = 0.5):
dataset = BasicDataset(dir_img, dir_mask, img_scale) dataset = BasicDataset(dir_img, dir_mask, img_scale)
n_val = int(len(dataset) * val_percent) n_val = int(len(dataset) * val_percent)
n_train = len(dataset) - n_val n_train = len(dataset) - n_val
...@@ -27,7 +27,7 @@ def train_net(net, device, epochs = 5, batch_size = 1, lr = 0.1, val_percent = 0 ...@@ -27,7 +27,7 @@ def train_net(net, device, epochs = 5, batch_size = 1, lr = 0.1, val_percent = 0
val_loader = DataLoader(val, batch_size = batch_size, shuffle = False, num_workers = 8, pin_memory = True) val_loader = DataLoader(val, batch_size = batch_size, shuffle = False, num_workers = 8, pin_memory = True)
optimizer = optim.Adam(net.parameters(), lr = lr) optimizer = optim.Adam(net.parameters(), lr = lr)
criterion = nn.BCELoss() criterion = nn.BCEWithLogitsLoss()
for epoch in tqdm(range(epochs)): for epoch in tqdm(range(epochs)):
net.train() net.train()
......
...@@ -40,14 +40,13 @@ if __name__ == '__main__': ...@@ -40,14 +40,13 @@ if __name__ == '__main__':
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
logging.info(f'Using device {device}') logging.info(f'Using device {device}')
# Change here to adapt to your data
# n_channels=3 for RGB images # n_channels=3 for RGB images
# n_classes is the number of probabilities you want to get per pixel # n_classes is the number of probabilities you want to get per pixel
# - For 1 class and background, use n_classes=1 # - For 1 class and background, use n_classes=1
# - For 2 classes, use n_classes=1 # - For 2 classes, use n_classes=1
# - For N > 2 classes, use n_classes=N # - For N > 2 classes, use n_classes=N
#net = UNet(n_channels = 1, n_classes = 1) # net = UNet(n_channels = 1, n_classes = 1)
net = MultiUnet(n_channels = 1, n_classes = 1) net = MultiUnet(n_channels = 1, n_classes = 1)
logging.info(f'Network:\n' logging.info(f'Network:\n'
...@@ -64,7 +63,7 @@ if __name__ == '__main__': ...@@ -64,7 +63,7 @@ if __name__ == '__main__':
# cudnn.benchmark = True # cudnn.benchmark = True
try: try:
mrnet.train_net(net = net, device=device, epochs = args.epochs, batch_size = args.batchsize, lr = args.lr, mrnet.train_net(net = net, device = device, epochs = args.epochs, batch_size = args.batchsize, lr = args.lr,
img_scale = args.scale, val_percent = args.val / 100) img_scale = args.scale, val_percent = args.val / 100)
except KeyboardInterrupt: except KeyboardInterrupt:
torch.save(net.state_dict(), 'INTERRUPTED.pth') torch.save(net.state_dict(), 'INTERRUPTED.pth')
......
...@@ -23,11 +23,10 @@ class BasicDataset(Dataset): ...@@ -23,11 +23,10 @@ class BasicDataset(Dataset):
return len(self.ids) return len(self.ids)
@classmethod @classmethod
def preprocess(cls, pil_img, scale): def preprocess(cls, pil_img):
w, h = pil_img.size #newW, newH = 256,256 #int(scale * w), int(scale * h)
newW, newH = 256,256 #int(scale * w), int(scale * h) #assert newW > 0 and newH > 0, 'Scale is too small'
assert newW > 0 and newH > 0, 'Scale is too small' pil_img = pil_img.resize((256, 256))
pil_img = pil_img.resize((newW, newH))
img_nd = np.array(pil_img) img_nd = np.array(pil_img)
...@@ -56,7 +55,7 @@ class BasicDataset(Dataset): ...@@ -56,7 +55,7 @@ class BasicDataset(Dataset):
assert img.size == mask.size, \ assert img.size == mask.size, \
f'Image and mask {idx} should be the same size, but are {img.size} and {mask.size}' f'Image and mask {idx} should be the same size, but are {img.size} and {mask.size}'
img = self.preprocess(img, self.scale) img = self.preprocess(img)
mask = self.preprocess(mask, self.scale) mask = self.preprocess(mask)
return {'image': torch.from_numpy(img), 'mask': torch.from_numpy(mask)} return {'image': torch.from_numpy(img), 'mask': torch.from_numpy(mask)}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment