文档服务地址:http://47.92.0.57:3000/ 周报索引地址:http://47.92.0.57:3000/s/NruNXRYmV

Commit 9ef386bf by 王肇一

CE-net based module, pause for a while

parent af27b9fa
...@@ -82,13 +82,13 @@ def get_args(): ...@@ -82,13 +82,13 @@ def get_args():
return parser.parse_args() return parser.parse_args()
if __name__ == '__main__': def cli_main():
args = get_args() args = get_args()
path = [(y, x) for y in filter(lambda x: x != '.DS_Store', os.listdir('data/imgs')) for x in filter( path = [(y, x) for y in filter(lambda x: x != '.DS_Store', os.listdir('data/imgs')) for x in filter(
lambda x: x.endswith('.tif') and not x.endswith('dc.tif') and not x.endswith('DC.tif') and not x.endswith( lambda x: x.endswith('.tif') and not x.endswith('dc.tif') and not x.endswith('DC.tif') and not x.endswith(
'dc .tif'), os.listdir('data/imgs/' + y))] 'dc .tif'), os.listdir('data/imgs/' + y))]
seperate_path = divide_list(path,args.process) seperate_path = divide_list(path, args.process)
if args.step == 1: if args.step == 1:
net = UNet(n_channels = 1, n_classes = 1) net = UNet(n_channels = 1, n_classes = 1)
...@@ -107,10 +107,13 @@ if __name__ == '__main__': ...@@ -107,10 +107,13 @@ if __name__ == '__main__':
elif args.step == 2: elif args.step == 2:
dir = [x for x in filter(lambda x: x != '.DS_Store', os.listdir('data/imgs/'))] dir = [x for x in filter(lambda x: x != '.DS_Store', os.listdir('data/imgs/'))]
sep_dir = divide_list(dir,args.process) sep_dir = divide_list(dir, args.process)
pool = Pool(args.process) pool = Pool(args.process)
for i, list in enumerate(sep_dir): for i, list in enumerate(sep_dir):
pool.apply_async(step_2, args = (list, i)) pool.apply_async(step_2, args = (list, i))
pool.close() pool.close()
pool.join() pool.join()
if __name__ == '__main__':
cli_main()
\ No newline at end of file
#!/usr/bin/env python #!/usr/bin/env python
# -*- coding:utf-8 -*- # -*- coding:utf-8 -*-
\ No newline at end of file from .ceunet_model import CEUnet
\ No newline at end of file
...@@ -15,11 +15,14 @@ class CEUnet(nn.Module): ...@@ -15,11 +15,14 @@ class CEUnet(nn.Module):
resnet = models.resnet34(pretrained = True) resnet = models.resnet34(pretrained = True)
weight = resnet.conv1.weight weight = resnet.conv1.weight
self.n_channels = n_channels
self.n_classes = n_classes
self.inc = nn.Conv2d(in_channels = n_channels, out_channels = 64, kernel_size = 7, stride = 2, padding = 1) self.inc = nn.Conv2d(in_channels = n_channels, out_channels = 64, kernel_size = 7, stride = 2, padding = 1)
self.inc.weight = nn.Parameter(weight[:, :1, :, :]) self.inc.weight = nn.Parameter(weight[:, :1, :, :])
self.bn1 = resnet.bn1 self.bn1 = resnet.bn1
self.relu1 = resnet.relu1 self.relu = resnet.relu
self.maxpool1 = resnet.maxpool self.maxpool1 = resnet.maxpool
self.encoder1 = resnet.layer1 self.encoder1 = resnet.layer1
...@@ -44,7 +47,7 @@ class CEUnet(nn.Module): ...@@ -44,7 +47,7 @@ class CEUnet(nn.Module):
def forward(self,x): def forward(self,x):
x = self.inc(x) x = self.inc(x)
x = self.bn1(x) x = self.bn1(x)
x = self.relu1(x) x = self.relu(x)
x = self.maxpool1(x) x = self.maxpool1(x)
e1 = self.encoder1(x) e1 = self.encoder1(x)
......
...@@ -57,11 +57,10 @@ def train_net(net, device, epochs = 5, batch_size = 1, lr = 0.1, val_percent = 0 ...@@ -57,11 +57,10 @@ def train_net(net, device, epochs = 5, batch_size = 1, lr = 0.1, val_percent = 0
for batch in train_loader: for batch in train_loader:
imgs = batch['image'] imgs = batch['image']
true_masks = batch['mask'] true_masks = batch['mask']
assert imgs.shape[ assert imgs.shape[1] == net.n_channels, \
1] == net.n_channels, f'Network has been defined with {net.n_channels} input channels, ' \ f'Network has been defined with {net.n_channels} input channels, ' \
f'but loaded images have {imgs.shape[1]} channels. Please check that ' \ f'but loaded images have {imgs.shape[1]} channels. Please check that ' \
f'' \ 'the images are loaded correctly.'
'the images are loaded correctly.'
imgs = imgs.to(device = device, dtype = torch.float32) imgs = imgs.to(device = device, dtype = torch.float32)
mask_type = torch.float32 if net.n_classes == 1 else torch.long mask_type = torch.float32 if net.n_classes == 1 else torch.long
...@@ -153,7 +152,7 @@ if __name__ == '__main__': ...@@ -153,7 +152,7 @@ if __name__ == '__main__':
# cudnn.benchmark = True # cudnn.benchmark = True
try: try:
train_net(net = net, epochs = args.epochs, batch_size = args.batchsize, lr = args.lr, device = device, train_net(net = net, device=device, epochs = args.epochs, batch_size = args.batchsize, lr = args.lr,
img_scale = args.scale, val_percent = args.val / 100) img_scale = args.scale, val_percent = args.val / 100)
except KeyboardInterrupt: except KeyboardInterrupt:
torch.save(net.state_dict(), 'INTERRUPTED.pth') torch.save(net.state_dict(), 'INTERRUPTED.pth')
......
#!/usr/bin/env python
# -*- coding:utf-8 -*-
import argparse
import logging
import os
import sys
import torch
from ceunet import CEUnet
from train import get_args,train_net
dir_img = 'data/train_imgs/'
dir_mask = 'data/train_masks/'
dir_checkpoint = 'ce_checkpoints/'
if __name__ == '__main__':
logging.basicConfig(level = logging.INFO, format = '%(levelname)s: %(message)s')
args = get_args()
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
logging.info(f'Using device {device}')
net = CEUnet(n_channels = 1, n_classes = 1)
logging.info(f'Network:\n'
f'\t{net.n_channels} input channels\n'
f'\t{net.n_classes} output channels (classes)\n')
if args.load:
net.load_state_dict(torch.load(args.load, map_location = device))
logging.info(f'Model loaded from {args.load}')
net.to(device = device)
try:
train_net(net = net, device=device, epochs = args.epochs, batch_size = args.batchsize, lr = args.lr,
img_scale = args.scale, val_percent = args.val / 100)
except KeyboardInterrupt:
torch.save(net.state_dict(), 'INTERRUPTED.pth')
logging.info('Saved interrupt')
try:
sys.exit(0)
except SystemExit:
os._exit(0)
\ No newline at end of file
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment