文档服务地址:http://47.92.0.57:3000/ 周报索引地址:http://47.92.0.57:3000/s/NruNXRYmV

Commit 2a781f86 by 王肇一

adjusted mrnet

parent 8c28bb60
...@@ -15,7 +15,7 @@ from utils.eval import eval_net ...@@ -15,7 +15,7 @@ from utils.eval import eval_net
dir_img = 'data/train_imgs/' dir_img = 'data/train_imgs/'
dir_mask = 'data/train_masks/' dir_mask = 'data/train_masks/'
dir_checkpoint = 'checkpoint' dir_checkpoint = 'checkpoint/'
def train_net(net, device, epochs = 5, batch_size = 1, lr = 0.1, val_percent = 0.1, img_scale = 0.5): def train_net(net, device, epochs = 5, batch_size = 1, lr = 0.1, val_percent = 0.1, img_scale = 0.5):
...@@ -32,21 +32,24 @@ def train_net(net, device, epochs = 5, batch_size = 1, lr = 0.1, val_percent = 0 ...@@ -32,21 +32,24 @@ def train_net(net, device, epochs = 5, batch_size = 1, lr = 0.1, val_percent = 0
for epoch in tqdm(range(epochs)): for epoch in tqdm(range(epochs)):
net.train() net.train()
epoch_loss = 0 epoch_loss = 0
for batch in train_loader: with tqdm(total = n_train, desc = f'Epoch {epoch + 1}/{epochs}', unit = 'img') as pbar:
imgs = batch['image'] for batch in train_loader:
true_masks = batch['mask'] imgs = batch['image']
true_masks = batch['mask']
imgs = imgs.to(device = device, dtype = torch.float32) imgs = imgs.to(device = device, dtype = torch.float32)
mask_type = torch.float32 if net.n_classes == 1 else torch.long mask_type = torch.float32 if net.n_classes == 1 else torch.long
true_masks = true_masks.to(device = device, dtype = mask_type) true_masks = true_masks.to(device = device, dtype = mask_type)
masks_pred = net(imgs) masks_pred = net(imgs)
loss = criterion(masks_pred, true_masks) loss = criterion(masks_pred, true_masks)
epoch_loss += loss.item() epoch_loss += loss.item()
pbar.set_postfix(**{'loss (batch)': loss.item()})
optimizer.zero_grad() optimizer.zero_grad()
loss.backward() loss.backward()
optimizer.step() optimizer.step()
pbar.update(imgs.shape[0])
val_score = eval_net(net, val_loader, device, n_val) val_score = eval_net(net, val_loader, device, n_val)
logging.info('Validation cross entropy: {}'.format(val_score)) logging.info('Validation cross entropy: {}'.format(val_score))
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment