Documentation service: http://47.92.0.57:3000/  Weekly report index: http://47.92.0.57:3000/s/NruNXRYmV

Commit 8c28bb60 by 王肇一

mrnet now ready

parent 9b02e4db
@@ -5,4 +5,6 @@ data/masks/*
data/output/*
data/train_imgs/*
data/train_masks/*
data/train_imgs_32/*
.ipynb_checkpoints/
runs
@@ -14,7 +14,7 @@ class MultiUnet(nn.Module):
self.bilinear = bilinear
self.inconv = nn.Sequential(
nn.Conv2d(n_channels, 8, kernel_size = 3, stride = 1, padding_mode = 'same'),
nn.Conv2d(n_channels, 8, kernel_size = 3, stride = 1, padding = 1),
nn.BatchNorm2d(8),
nn.ReLU(inplace = True)
)
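
Note on the change above: in torch.nn.Conv2d, padding_mode only selects how the border is filled ('zeros', 'reflect', 'replicate', 'circular') and adds no padding by itself; recent PyTorch versions reject padding_mode = 'same' with a ValueError. For a 3x3 kernel at stride 1, padding = 1 keeps the spatial size unchanged, which is the 'same' behaviour the original line was after. A minimal shape check (the 256x256 input size is assumed from the dataset change later in this commit):

import torch
import torch.nn as nn

# 3x3 conv, stride 1, padding 1 -> output keeps the input's spatial size
conv = nn.Conv2d(in_channels = 1, out_channels = 8, kernel_size = 3, stride = 1, padding = 1)
x = torch.randn(1, 1, 256, 256)
print(conv(x).shape)          # torch.Size([1, 8, 256, 256])

# Since PyTorch 1.9, padding = 'same' (note: padding, not padding_mode) is also accepted at stride 1
conv_same = nn.Conv2d(1, 8, kernel_size = 3, stride = 1, padding = 'same')
print(conv_same(x).shape)     # torch.Size([1, 8, 256, 256])
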
@@ -41,7 +41,7 @@ class MultiUnet(nn.Module):
self.pool = nn.MaxPool2d(2)
self.outconv = nn.Conv2d(self.res9.outc, n_classes, kernel_size = 3, padding_mode = 'same')
self.outconv = nn.Conv2d(self.res9.outc, n_classes, kernel_size = 3, padding = 1)
def forward(self, x):
x = self.inconv(x)
@@ -16,8 +16,8 @@ class MultiResBlock(nn.Module):
def conv(self, in_channel, out_channel, kernel_size):
return nn.Sequential(
nn.Conv2d(in_channels = in_channel, out_channels = out_channel, kernel_size = kernel_size, stride = 1,padding = 1,
padding_mode = 'same'), nn.BatchNorm2d(out_channel), nn.ReLU(inplace = True))
nn.Conv2d(in_channels = in_channel, out_channels = out_channel, kernel_size = kernel_size, stride = 1,padding = 1),
nn.BatchNorm2d(out_channel), nn.ReLU(inplace = True))
def forward(self, x):
shortcut = nn.Sequential(
@@ -41,9 +41,9 @@ class TransCompose(nn.Module):
self.inc = in_channels
self.outc = out_channels
def forward(self,x):
def forward(self, x):
return nn.Sequential(
nn.ConvTranspose2d(in_channels = self.inc, out_channels = self.outc, kernel_size = 3, stride = 2),
nn.ConvTranspose2d(in_channels = self.inc, out_channels = self.outc, kernel_size = 2, stride = 2),
nn.BatchNorm2d(self.outc)
)(x)
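
Note on the kernel size change: for nn.ConvTranspose2d the output height is (H_in - 1) * stride - 2 * padding + kernel_size (with dilation 1 and no output_padding). At stride 2 with no padding, kernel_size = 3 gives 2 * H_in + 1, one pixel larger than a clean doubling, while kernel_size = 2 gives exactly 2 * H_in. A quick shape check with illustrative channel counts:

import torch
import torch.nn as nn

x = torch.randn(1, 64, 16, 16)
up3 = nn.ConvTranspose2d(in_channels = 64, out_channels = 32, kernel_size = 3, stride = 2)
up2 = nn.ConvTranspose2d(in_channels = 64, out_channels = 32, kernel_size = 2, stride = 2)

print(up3(x).shape)   # torch.Size([1, 32, 33, 33]) -- 2*16 + 1, off by one on the decoder path
print(up2(x).shape)   # torch.Size([1, 32, 32, 32]) -- exact doubling
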
@@ -61,7 +61,8 @@ class ResPath(nn.Module):
nn.BatchNorm2d(self.outc))(x)
conv = nn.Sequential(
nn.Conv2d(in_channels = in_channels, out_channels = out_channels, kernel_size = 3, stride = 1, padding = 1),
nn.BatchNorm2d(out_channels), nn.ReLU(inplace = True))(x)
nn.BatchNorm2d(out_channels),
nn.ReLU(inplace = True))(x)
result = torch.add(conv, shortcut)
return nn.Sequential(nn.ReLU(inplace = True), nn.BatchNorm2d(out_channels))(result)
@@ -18,7 +18,7 @@ dir_mask = 'data/train_masks/'
dir_checkpoint = 'checkpoint'
def train_net(net, device, epochs = 5, batch_size = 1, lr = 0.1, val_percent = 0.1, save_cp = True, img_scale = 0.5):
def train_net(net, device, epochs = 5, batch_size = 1, lr = 0.1, val_percent = 0.1, img_scale = 0.5):
dataset = BasicDataset(dir_img, dir_mask, img_scale)
n_val = int(len(dataset) * val_percent)
n_train = len(dataset) - n_val
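
The new signature drops the save_cp flag; the val_percent split below it is unchanged. The actual split call and the train loader sit outside this hunk; a sketch of the usual continuation (dataset, n_train, n_val and batch_size come from the code above, while the random_split call and the train_loader arguments are assumptions mirroring the val_loader in the next hunk):

from torch.utils.data import DataLoader, random_split

# Assumed continuation: n_train + n_val == len(dataset)
train, val = random_split(dataset, [n_train, n_val])
train_loader = DataLoader(train, batch_size = batch_size, shuffle = True, num_workers = 8, pin_memory = True)
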
@@ -27,7 +27,7 @@ def train_net(net, device, epochs = 5, batch_size = 1, lr = 0.1, val_percent = 0
val_loader = DataLoader(val, batch_size = batch_size, shuffle = False, num_workers = 8, pin_memory = True)
optimizer = optim.Adam(net.parameters(), lr = lr)
criterion = nn.BCELoss()
criterion = nn.BCEWithLogitsLoss()
for epoch in tqdm(range(epochs)):
net.train()
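
The loss change means the network is now expected to emit raw logits: nn.BCEWithLogitsLoss fuses the sigmoid into the loss with a log-sum-exp formulation, which is numerically stabler than a separate sigmoid followed by nn.BCELoss. A small sanity check of the equivalence (tensor shapes here are illustrative):

import torch
import torch.nn as nn

logits = torch.randn(2, 1, 256, 256)                      # raw network outputs
target = torch.randint(0, 2, (2, 1, 256, 256)).float()    # binary mask

loss_a = nn.BCEWithLogitsLoss()(logits, target)
loss_b = nn.BCELoss()(torch.sigmoid(logits), target)
print(torch.allclose(loss_a, loss_b, atol = 1e-6))        # True (up to float precision)
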
@@ -40,14 +40,13 @@ if __name__ == '__main__':
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
logging.info(f'Using device {device}')
# Change here to adapt to your data
# n_channels=3 for RGB images
# n_classes is the number of probabilities you want to get per pixel
# - For 1 class and background, use n_classes=1
# - For 2 classes, use n_classes=1
# - For N > 2 classes, use n_classes=N
#net = UNet(n_channels = 1, n_classes = 1)
# net = UNet(n_channels = 1, n_classes = 1)
net = MultiUnet(n_channels = 1, n_classes = 1)
logging.info(f'Network:\n'
@@ -64,8 +63,8 @@ if __name__ == '__main__':
# cudnn.benchmark = True
try:
mrnet.train_net(net = net, device=device, epochs = args.epochs, batch_size = args.batchsize, lr = args.lr,
img_scale = args.scale, val_percent = args.val / 100)
mrnet.train_net(net = net, device = device, epochs = args.epochs, batch_size = args.batchsize, lr = args.lr,
img_scale = args.scale, val_percent = args.val / 100)
except KeyboardInterrupt:
torch.save(net.state_dict(), 'INTERRUPTED.pth')
logging.info('Saved interrupt')
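
The mrnet.train_net call above reads all hyperparameters from args, whose parser lies outside this diff. A hypothetical get_args consistent with the attribute names used in that call (epochs, batchsize, lr, scale, val); the flag spellings, defaults, and help text are assumptions, not part of the commit:

import argparse

def get_args():
    # Hypothetical parser: only the attribute names are taken from the train_net call above.
    parser = argparse.ArgumentParser(description = 'Train MultiUnet')
    parser.add_argument('-e', '--epochs', type = int, default = 5)
    parser.add_argument('-b', '--batchsize', type = int, default = 1)
    parser.add_argument('-l', '--lr', type = float, default = 0.1)
    parser.add_argument('-s', '--scale', type = float, default = 0.5)
    parser.add_argument('-v', '--val', type = float, default = 10.0,
                        help = 'Percent of the data used for validation (0-100)')
    return parser.parse_args()
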
@@ -23,11 +23,10 @@ class BasicDataset(Dataset):
return len(self.ids)
@classmethod
def preprocess(cls, pil_img, scale):
w, h = pil_img.size
newW, newH = 256,256 #int(scale * w), int(scale * h)
assert newW > 0 and newH > 0, 'Scale is too small'
pil_img = pil_img.resize((newW, newH))
def preprocess(cls, pil_img):
#newW, newH = 256,256 #int(scale * w), int(scale * h)
#assert newW > 0 and newH > 0, 'Scale is too small'
pil_img = pil_img.resize((256, 256))
img_nd = np.array(pil_img)
@@ -56,7 +55,7 @@ class BasicDataset(Dataset):
assert img.size == mask.size, \
f'Image and mask {idx} should be the same size, but are {img.size} and {mask.size}'
img = self.preprocess(img, self.scale)
mask = self.preprocess(mask, self.scale)
img = self.preprocess(img)
mask = self.preprocess(mask)
return {'image': torch.from_numpy(img), 'mask': torch.from_numpy(mask)}
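
preprocess no longer takes scale: every image and mask is resized to a fixed 256 x 256 before conversion to a NumPy array, and __getitem__ is updated to match. The tail of the method is outside this hunk; a standalone sketch of the plausible full flow, where everything after np.array follows the common HWC -> CHW, [0, 1] normalisation pattern and is an assumption:

import numpy as np
from PIL import Image

def preprocess(pil_img: Image.Image) -> np.ndarray:
    pil_img = pil_img.resize((256, 256))        # fixed target size replaces the scale-based resize
    img_nd = np.array(pil_img)
    if img_nd.ndim == 2:                        # grayscale image: add a channel axis
        img_nd = np.expand_dims(img_nd, axis = 2)
    img_trans = img_nd.transpose((2, 0, 1))     # HWC -> CHW for torch.from_numpy
    if img_trans.max() > 1:                     # 0-255 pixels -> [0, 1]
        img_trans = img_trans / 255
    return img_trans
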