Commit 2a781f86 by 王肇一

adjusted mrnet

parent 8c28bb60
......@@ -15,7 +15,7 @@ from utils.eval import eval_net
dir_img = 'data/train_imgs/'
dir_mask = 'data/train_masks/'
dir_checkpoint = 'checkpoint'
dir_checkpoint = 'checkpoint/'
def train_net(net, device, epochs = 5, batch_size = 1, lr = 0.1, val_percent = 0.1, img_scale = 0.5):
......@@ -32,21 +32,24 @@ def train_net(net, device, epochs = 5, batch_size = 1, lr = 0.1, val_percent = 0
for epoch in tqdm(range(epochs)):
net.train()
epoch_loss = 0
for batch in train_loader:
imgs = batch['image']
true_masks = batch['mask']
with tqdm(total = n_train, desc = f'Epoch {epoch + 1}/{epochs}', unit = 'img') as pbar:
for batch in train_loader:
imgs = batch['image']
true_masks = batch['mask']
imgs = imgs.to(device = device, dtype = torch.float32)
mask_type = torch.float32 if net.n_classes == 1 else torch.long
true_masks = true_masks.to(device = device, dtype = mask_type)
imgs = imgs.to(device = device, dtype = torch.float32)
mask_type = torch.float32 if net.n_classes == 1 else torch.long
true_masks = true_masks.to(device = device, dtype = mask_type)
masks_pred = net(imgs)
loss = criterion(masks_pred, true_masks)
epoch_loss += loss.item()
masks_pred = net(imgs)
loss = criterion(masks_pred, true_masks)
epoch_loss += loss.item()
pbar.set_postfix(**{'loss (batch)': loss.item()})
optimizer.zero_grad()
loss.backward()
optimizer.step()
optimizer.zero_grad()
loss.backward()
optimizer.step()
pbar.update(imgs.shape[0])
val_score = eval_net(net, val_loader, device, n_val)
logging.info('Validation cross entropy: {}'.format(val_score))
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment