Documentation service: http://47.92.0.57:3000/  Weekly report index: http://47.92.0.57:3000/s/NruNXRYmV

Commit 5fee1f79 by 王肇一

Jaccard score

parent ca83308e
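This commit adds Jaccard-index (IoU) logging to validation. For reference, a minimal sketch of the score being tracked, using the same sklearn.metrics.jaccard_score call the new eval_jac helper below relies on (toy masks, hypothetical values):

```python
import numpy as np
from sklearn.metrics import jaccard_score

true = np.array([[1, 0], [0, 1]]).flatten()  # ground-truth mask, flattened as in eval_jac
pred = np.array([[1, 0], [1, 1]]).flatten()  # binarized prediction
# |true ∩ pred| / |true ∪ pred| = 2 / 3
print(jaccard_score(true, pred))  # ~0.667
```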
@@ -17,7 +17,7 @@ import re
 from unet import UNet
 from mrnet import MultiUnet
-from utils.predict import predict_img
+from utils.predict import predict_img,predict
 from resCalc import save_img, get_subarea_info, save_img_mask
@@ -30,8 +30,8 @@ def step_1(net, args, device, list, position):
     for fn in tqdm(list, position = position):
         logging.info("\nPredicting image {} ...".format(fn[0] + '/' + fn[1]))
         img = Image.open('data/imgs/' + fn[0] + '/' + fn[1])
-        mask = predict_img(net = net, full_img = img, out_threshold = args.mask_threshold,
-                           device = device)
+        #mask = predict_img(net = net, full_img = img, out_threshold = args.mask_threshold, device = device)
+        mask = predict(net = net, full_img = img, out_threshold = args.mask_threshold, device = device)
         result = (mask * 255).astype(np.uint8)
         #save_img({'ori': img, 'mask': result}, fn[0], fn[1])
...
+_background_
+target
\ No newline at end of file
Step size0Dwell time50 Clinical Sample K. p Ampicillcin 256ug 852nm 30mw 300mw tune 43.00 002
Step size0Dwell time50 K.p atcc 700603 12.16 Ampicillin 64ug 40mw 400mw tune 43.08 002
Step size0Dwell time50 Clinical Sample K. p Ceftazidime 512ug 852nm 30mw 300mw tune 43.00 002
Step size0Dwell time50 P.a atcc27853 Levofloxacin-11.29 2ug 852nm 30mw 300mw tune 43.06 002
Step size0Dwell time50 S.a atcc 29213 Linezolid 0.25ug 852nm 40mw 400mw tune 43.08 003
Step size0Dwell time50 K.p atcc 700603 12.16 Tobramycin 4ug 40mw 400mw tune 43.08 001
Step size0Dwell time50 P.a atcc27853 Levofloxacin-11.29 1ug 852nm 30mw 300mw tune 43.06 003
Step size0Dwell time50 S.a atcc 29213 Levofloxacin 0.5ug 852nm 40mw 400mw tune 43.08 002
Step size0Dwell time50 S.a atcc 29213 Levofloxacin 0.5ug 852nm 40mw 400mw tune 43.08 003
Step size0Dwell time50 S.a atcc 29213 Linezolid 0.25ug 852nm 40mw 400mw tune 43.08 002
Step size0Dwell time50 P.a atcc27853 Levofloxacin-11.29 2ug 852nm 30mw 300mw tune 43.06 003
Step size0Dwell time50 Clinical Sample K. p Tobramycin 64ug 852nm 30mw 300mw tune 43.00 004
Step size0Dwell time50 Clinical Sample K. p Ceftazidime 512ug 852nm 30mw 300mw tune 43.00 003
Step size0Dwell time50 K.p atcc 700603 12.16 Ampicillin 64ug 40mw 400mw tune 43.08 003
Step size0Dwell time50 Clinical Sample K. p Ampicillcin 256ug 852nm 30mw 300mw tune 43.00 003
Step size0Dwell time50 P.a 1h lb 852nm 30mw 300mw tune 43.03 005
Step size0Dwell time50 Clinical Sample K. p Ceftazidime 32ug 852nm 30mw 300mw tune 43.00 003
Step size0Dwell time50 P.a atcc27853 Levofloxacin-11.29 4ug 852nm 30mw 300mw tune 43.06 001
Step size0Dwell time50 K.p atcc 700603 12.16 LB 40mw 400mw tune 43.08 002
Step size0Dwell time50 E.coil atcc25922 Ceftazime 1ug 852nm 40mw 400mw tune43.08 60x oil obj 003
Step size0Dwell time50 Clinical Sample K. p Ampicillcin 32ug 852nm 30mw 300mw tune 43.00 003
Step size0Dwell time50 S.a atcc 29213 Ampicillin 1ug 852nm 40mw 400mw tune 43.08 004
Step size0Dwell time50 E.coil atcc25922 Imipenen 1ug 852nm 40mw 400mw tune43.08 60x oil obj 001
Step size0Dwell time50 K.p atcc 700603 12.16 Ceftazidime 64ug 40mw 400mw tune 43.08 003
Step size0Dwell time50 K.p atcc 700603 12.16 Levofloxacin 1ug 40mw 400mw tune 43.08 001
Step size0Dwell time50 K.p atcc700603 Imipenen 1ug 852nm 40mw 400mw tune43.08 001
@@ -40,11 +40,11 @@ class MultiUnet(nn.Module):
         self.res9 = MultiResBlock(self.up9.outc*2, 32)
         self.pool = nn.MaxPool2d(2)
-        # self.outconv = nn.Sequential(
-        #     nn.Conv2d(self.res9.outc, n_classes, kernel_size = 1),
-        #     nn.Sigmoid()
-        # )
-        self.outconv = nn.Conv2d(self.res9.outc, n_classes,kernel_size = 1)
+        self.outconv = nn.Sequential(
+            nn.Conv2d(self.res9.outc, n_classes, kernel_size = 1),
+            nn.Sigmoid()
+        )
+        # self.outconv = nn.Conv2d(self.res9.outc, n_classes,kernel_size = 1)

     def forward(self, x):
         x = self.inconv(x)
...
@@ -4,7 +4,7 @@ import torch
 import torchvision
 import torch.nn as nn
 import torch.nn.functional as F
-import torchsnooper
+# import torchsnooper

 def conv(in_channel, out_channel):
@@ -37,7 +37,7 @@ class MultiResBlock(nn.Module):
         self.norm = nn.BatchNorm2d(self.outc)
         self.seq = nn.Sequential(nn.ReLU(inplace = True), nn.BatchNorm2d(self.outc))
-    #@torchsnooper.snoop()
+    # @torchsnooper.snoop()
     def forward(self, x):
         shortcut = self.shortcut(x)
...
@@ -7,11 +7,12 @@ from tqdm import tqdm
 import torch
 import torch.nn as nn
 from torch import optim
+from torch.optim import lr_scheduler
 from torchvision import transforms
 from torch.utils.data import DataLoader, random_split
-from utils.dataset import BasicDataset,VOCSegmentation
-from utils.eval import eval_net
+from utils.dataset import BasicDataset, VOCSegmentation
+from utils.eval import eval_net,eval_multi,eval_jac

 dir_checkpoint = 'checkpoint/'
@@ -27,7 +28,8 @@ def train_net(net, device, epochs = 5, batch_size = 1, lr = 0.1):
     val_loader = DataLoader(evalset, batch_size = batch_size, shuffle = False, num_workers = 8, pin_memory = True)
     optimizer = optim.Adam(net.parameters(), lr = lr)
-    criterion = nn.BCEWithLogitsLoss()
+    criterion = nn.BCELoss()  # nn.BCEWithLogitsLoss()
+    scheduler = lr_scheduler.StepLR(optimizer, 30, 0.5)  # lr_scheduler.ReduceLROnPlateau(optimizer, 'min')
     for epoch in range(epochs):
         net.train()
@@ -47,9 +49,13 @@ def train_net(net, device, epochs = 5, batch_size = 1, lr = 0.1):
             loss.backward()
             optimizer.step()
             pbar.update(imgs.shape[0])
-        val_score = eval_net(net, val_loader, device, n_val)
-        logging.info('Validation : {}'.format(val_score))
+        dice = eval_net(net, val_loader, device, n_val)
+        jac = eval_jac(net, val_loader, device, n_val)
+        # overall_acc, avg_per_class_acc, avg_jacc, avg_dice = eval_multi(net, val_loader, device, n_val)
+        scheduler.step()
+        logging.info(f'Avg Dice:{dice}\n'
+                     f'Jaccard:{jac}\n'
+                     f'Learning Rate:{scheduler.get_lr()[0]}')
         if epoch % 5 == 0:
             try:
                 os.mkdir(dir_checkpoint)
...
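Two training changes land together in this hunk: the loss switches from BCEWithLogitsLoss to BCELoss (matching the nn.Sigmoid() now baked into MultiUnet's outconv in the hunk above, so the sigmoid is not applied twice), and a StepLR schedule halves the learning rate every 30 epochs. A minimal sketch of why the loss swap is needed, on hypothetical tensors:

```python
import torch
import torch.nn as nn

logits = torch.randn(2, 1, 4, 4)             # raw network outputs
target = torch.empty(2, 1, 4, 4).uniform_()  # hypothetical targets in [0, 1]

# BCEWithLogitsLoss applies the sigmoid itself, so it expects raw logits ...
loss_a = nn.BCEWithLogitsLoss()(logits, target)
# ... while BCELoss expects probabilities, i.e. a model that already ends in nn.Sigmoid().
loss_b = nn.BCELoss()(torch.sigmoid(logits), target)
assert torch.allclose(loss_a, loss_b, atol = 1e-6)
```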
@@ -51,8 +51,7 @@ if __name__ == '__main__':
     logging.info(f'Network:\n'
                  f'\t{net.n_channels} input channels\n'
-                 f'\t{net.n_classes} output channels (classes)\n'
-                 f'\t{"Bilinear" if net.bilinear else "Dilated conv"} upscaling')
+                 f'\t{net.n_classes} output channels (classes)\n')

     if args.load:
         net.load_state_dict(torch.load(args.load, map_location = device))
...
@@ -9,6 +9,7 @@ import torch
 import torch.nn as nn
 from torch import optim
 from torchvision import transforms
+from torch.optim import lr_scheduler
 from tqdm import tqdm

 from utils.eval import eval_net
@@ -46,6 +47,7 @@ def train_net(net, device, epochs = 5, batch_size = 1, lr = 0.1, save_cp = True):
     # optimizer = optim.Adam(net.parameters(), lr=lr, weight_decay = 1e-8)
     optimizer = optim.RMSprop(net.parameters(), lr = lr, weight_decay = 1e-8)
+    scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, 'min')
     # criterion = nn.BCEWithLogitsLoss()
     if net.n_classes > 1:
         criterion = nn.CrossEntropyLoss()
@@ -54,7 +56,6 @@ def train_net(net, device, epochs = 5, batch_size = 1, lr = 0.1, save_cp = True):
     for epoch in range(epochs):
         net.train()
         epoch_loss = 0
-
         with tqdm(total = n_train, desc = f'Epoch {epoch + 1}/{epochs}', unit = 'img') as pbar:
             for imgs,true_masks in train_loader:
@@ -82,8 +83,8 @@ def train_net(net, device, epochs = 5, batch_size = 1, lr = 0.1, save_cp = True):
                 pbar.update(imgs.shape[0])
                 global_step += 1  # if global_step % (len(dataset) // (10 * batch_size)) == 0:
         val_score = eval_net(net, val_loader, device, n_val)
+        scheduler.step(val_score)
         if net.n_classes > 1:
             logging.info('Validation cross entropy: {}'.format(val_score))
             writer.add_scalar('Loss/test', val_score, global_step)
...
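Unlike the StepLR used above, ReduceLROnPlateau takes the monitored metric as an argument to step(), which is why scheduler.step(val_score) follows eval_net in this script. A minimal usage sketch (hypothetical one-parameter model, default factor 0.1 and patience 10):

```python
import torch
from torch import optim
from torch.optim import lr_scheduler

params = [torch.nn.Parameter(torch.zeros(1))]
optimizer = optim.RMSprop(params, lr = 0.1, weight_decay = 1e-8)
scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, 'min')

for epoch in range(15):
    val_score = 1.0            # hypothetical stand-in for eval_net(...); never improves
    scheduler.step(val_score)  # LR is cut only after `patience` epochs without improvement
print(optimizer.param_groups[0]['lr'])  # ~0.01 once the plateau is detected
```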
@@ -8,9 +8,6 @@ import torch
 from torch.utils.data import Dataset
 import logging
 from PIL import Image
-import imgaug as ia
-import imgaug.augmenters as iaa
-from imgaug.augmentables.segmaps import SegmentationMapsOnImage
 import os
 from torchvision.datasets.vision import VisionDataset
@@ -87,10 +84,21 @@ class VOCSegmentation(VisionDataset):
         self.masks = [os.path.join(mask_dir, x + ".png") for x in file_names]
         assert (len(self.images) == len(self.masks))
-        self.seq = iaa.Sequential([iaa.SomeOf((0, 5), [iaa.Noop(), iaa.Fliplr(0.5),
-            iaa.Sometimes(0.25, iaa.Dropout(p = (0, 0.1))), iaa.Affine(rotate = (-45, 45)),
-            iaa.ElasticTransformation(alpha = 50, sigma = 5)
-        ], random_order = True)])
+
+    @classmethod
+    def preprocess(cls, pil_img):
+        pil_img = pil_img.resize((256, 256))
+        img_nd = np.array(pil_img)
+        if len(img_nd.shape) == 2:
+            img_nd = np.expand_dims(img_nd, axis = 2)
+        # HWC to CHW
+        img_trans = img_nd.transpose((2, 0, 1))
+        if img_trans.max() > 1:
+            img_trans = img_trans / 255
+        return img_trans

     def __getitem__(self, index):
         img = Image.open(self.images[index]).convert('L')
@@ -99,10 +107,11 @@ class VOCSegmentation(VisionDataset):
         pim = target.load()
         for i in range(200):
             for j in range(200):
-                pim[i, j] = 1 if pim[i, j] > 0 else 0
+                pim[i, j] = 255 if pim[i, j] > 0 else 0
         # img, target = self.seq(image=np.array(img), segmentation_maps = np.array(target))
+        # img = self.preprocess(img)
+        # target = self.preprocess(img)
         if self.transforms is not None:
             img, target = self.transforms(img, target)
...
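The imgaug pipeline is dropped here in favor of a preprocess classmethod mirroring BasicDataset's. A quick check of its contract (hypothetical blank grayscale input; this assumes numpy is imported as np in dataset.py):

```python
import numpy as np
from PIL import Image
from utils.dataset import VOCSegmentation

img = Image.new('L', (200, 200))   # hypothetical grayscale input
arr = VOCSegmentation.preprocess(img)
assert arr.shape == (1, 256, 256)  # resized, channel dim added, HWC -> CHW
assert arr.max() <= 1              # rescaled to [0, 1] when values exceed 1
```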
 import torch
 from torch.autograd import Function

 class DiceCoeff(Function):
     """Dice coeff for individual examples"""
@@ -10,22 +9,18 @@ class DiceCoeff(Function):
         eps = 0.0001
         self.inter = torch.dot(input.view(-1), target.view(-1))
         self.union = torch.sum(input) + torch.sum(target) + eps
         t = (2 * self.inter.float() + eps) / self.union.float()
         return t

     # This function has only a single output, so it gets only one gradient
     def backward(self, grad_output):
         input, target = self.saved_variables
         grad_input = grad_target = None
         if self.needs_input_grad[0]:
             grad_input = grad_output * 2 * (target * self.union - self.inter) \
                          / (self.union * self.union)
         if self.needs_input_grad[1]:
             grad_target = None
         return grad_input, grad_target
@@ -40,3 +35,13 @@ def dice_coeff(input, target):
         s = s + DiceCoeff().forward(c[0], c[1])
     return s / (i + 1)

+def dice_coef(pred, target):
+    smooth = 1.
+    num = pred.size(0)
+    m1 = pred.view(num, -1)  # Flatten
+    m2 = target.view(num, -1)  # Flatten
+    intersection = (m1 * m2).sum()
+    return (2. * intersection + smooth) / (m1.sum() + m2.sum() + smooth)
\ No newline at end of file
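The new dice_coef is a smoothed, batch-flattened variant of the existing dice_coeff (smooth = 1 keeps the ratio defined even for empty masks). A toy check on hypothetical tensors:

```python
import torch
from utils.dice_loss import dice_coef

pred = torch.tensor([[1., 1., 0., 0.]])    # batch of one flattened mask
target = torch.tensor([[1., 0., 0., 0.]])
# (2 * 1 + 1) / (2 + 1 + 1) = 0.75
print(dice_coef(pred, target))  # tensor(0.7500)
```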
 import torch
 import torch.nn.functional as F
 from tqdm import tqdm
+from sklearn.metrics import jaccard_score

-from utils.dice_loss import dice_coeff
+from utils.dice_loss import dice_coeff, dice_coef
+from .metrics import eval_metrics

 def eval_net(net, loader, device, n_val):
@@ -11,10 +13,7 @@ def eval_net(net, loader, device, n_val):
     tot = 0
     with tqdm(total=n_val, desc='Validation round', unit='img', leave=False) as pbar:
-        for batch in loader:
-            imgs = batch['image']
-            true_masks = batch['mask']
+        for imgs,true_masks in loader:
             imgs = imgs.to(device=device, dtype=torch.float32)
             mask_type = torch.float32 if net.n_classes == 1 else torch.long
             true_masks = true_masks.to(device=device, dtype=mask_type)
@@ -26,7 +25,48 @@ def eval_net(net, loader, device, n_val):
             if net.n_classes > 1:
                 tot += F.cross_entropy(pred.unsqueeze(dim=0), true_mask.unsqueeze(dim=0)).item()
             else:
-                tot += dice_coeff(pred, true_mask.squeeze(dim=1)).item()
+                tot += dice_coef(pred, true_mask.squeeze(dim=1)).item()
             pbar.update(imgs.shape[0])
     return tot / n_val

+def eval_multi(net, loader, device, n_val):
+    net.eval()
+    overall_acc = 0
+    avg_per_class_acc = 0
+    avg_jacc = 0
+    avg_dice = 0
+    with tqdm(total = n_val, desc = 'Validation round', unit = 'img', leave = False) as pbar:
+        for imgs, true_masks in loader:
+            imgs = imgs.to(device = device, dtype = torch.float32)
+            mask_type = torch.float32 if net.n_classes == 1 else torch.long
+            true_masks = true_masks.to(device = device, dtype = mask_type)
+            pred_mask = net(imgs)
+            oac, apca, aj, ad = eval_metrics(true_masks, pred_mask, 1)
+            overall_acc += oac
+            avg_per_class_acc += apca
+            avg_jacc += aj
+            avg_dice += ad
+            pbar.update(imgs.shape[0])
+    return

+def eval_jac(net, loader, device, n_val):
+    net.eval()
+    jac = 0
+    with tqdm(total = n_val, desc = 'Validation round', unit = 'img', leave = False) as pbar:
+        for imgs, true_masks in loader:
+            imgs = imgs.to(device = device, dtype = torch.float32)
+            mask_type = torch.float32 if net.n_classes == 1 else torch.long
+            true_masks = true_masks.to(device = device, dtype = mask_type)
+            pred_masks = net(imgs)
+            pred_masks = torch.round(pred_masks).detach().numpy()
+            true_masks = torch.round(true_masks).numpy()
+            jac += jaccard_score(true_masks.flatten(), pred_masks.flatten())
+            pbar.update(imgs.shape[0])
+    return jac/n_val
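Two caveats in this hunk as committed: eval_multi accumulates its four metrics but returns None, and eval_jac calls .numpy() directly, which only works for CPU tensors. A minimal device-safe sketch of the per-batch Jaccard step (assumption: the net may live on a GPU):

```python
import torch
from sklearn.metrics import jaccard_score

def batch_jaccard(pred_masks, true_masks):
    """Binarize one batch and score it; moves tensors to the CPU first."""
    pred = torch.round(pred_masks).detach().cpu().numpy().flatten()
    true = torch.round(true_masks).detach().cpu().numpy().flatten()
    return jaccard_score(true, pred)
```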
#!/usr/bin/env python
# -*- coding:utf-8 -*-
"""Common image segmentation metrics.
"""
import torch
EPS = 1e-10
def nanmean(x):
"""Computes the arithmetic mean ignoring any NaNs."""
return torch.mean(x[x == x])
def _fast_hist(true, pred, num_classes):
    mask = (true >= 0) & (true < num_classes)
    hist = torch.bincount(num_classes * true[mask] + pred[mask],
                          minlength = num_classes ** 2).reshape(num_classes, num_classes).float()
    return hist
def overall_pixel_accuracy(hist):
"""Computes the total pixel accuracy.
The overall pixel accuracy provides an intuitive
approximation for the qualitative perception of the
label when it is viewed in its overall shape but not
its details.
Args:
hist: confusion matrix.
Returns:
overall_acc: the overall pixel accuracy.
"""
correct = torch.diag(hist).sum()
total = hist.sum()
overall_acc = correct / (total + EPS)
return overall_acc
def per_class_pixel_accuracy(hist):
"""Computes the average per-class pixel accuracy.
The per-class pixel accuracy is a more fine-grained
version of the overall pixel accuracy. A model could
score a relatively high overall pixel accuracy by
correctly predicting the dominant labels or areas
in the image whilst incorrectly predicting the
possibly more important/rare labels. Such a model
will score a low per-class pixel accuracy.
Args:
hist: confusion matrix.
Returns:
avg_per_class_acc: the average per-class pixel accuracy.
"""
correct_per_class = torch.diag(hist)
total_per_class = hist.sum(dim = 1)
per_class_acc = correct_per_class / (total_per_class + EPS)
avg_per_class_acc = nanmean(per_class_acc)
return avg_per_class_acc
def jaccard_index(hist):
"""Computes the Jaccard index, a.k.a the Intersection over Union (IoU).
Args:
hist: confusion matrix.
Returns:
avg_jacc: the average per-class jaccard index.
"""
A_inter_B = torch.diag(hist)
A = hist.sum(dim = 1)
B = hist.sum(dim = 0)
jaccard = A_inter_B / (A + B - A_inter_B + EPS)
avg_jacc = nanmean(jaccard)
return avg_jacc
def dice_coefficient(hist):
"""Computes the Sørensen–Dice coefficient, a.k.a the F1 score.
Args:
hist: confusion matrix.
Returns:
avg_dice: the average per-class dice coefficient.
"""
A_inter_B = torch.diag(hist)
A = hist.sum(dim = 1)
B = hist.sum(dim = 0)
dice = (2 * A_inter_B) / (A + B + EPS)
avg_dice = nanmean(dice)
return avg_dice
def eval_metrics(true, pred, num_classes):
"""Computes various segmentation metrics on 2D feature maps.
Args:
true: a tensor of shape [B, H, W] or [B, 1, H, W].
pred: a tensor of shape [B, H, W] or [B, 1, H, W].
num_classes: the number of classes to segment. This number
should be less than the ID of the ignored class.
Returns:
overall_acc: the overall pixel accuracy.
avg_per_class_acc: the average per-class pixel accuracy.
avg_jacc: the jaccard index.
avg_dice: the dice coefficient.
"""
hist = torch.zeros((num_classes, num_classes))
for t, p in zip(true, pred):
hist += _fast_hist(t.flatten(), p.flatten(), num_classes)
overall_acc = overall_pixel_accuracy(hist)
avg_per_class_acc = per_class_pixel_accuracy(hist)
avg_jacc = jaccard_index(hist)
avg_dice = dice_coefficient(hist)
return overall_acc, avg_per_class_acc, avg_jacc, avg_dice
class AverageMeter(object):
def __init__(self):
self.val = 0
self.avg = 0
self.sum = 0
self.count = 0
def update(self, val, n = 1):
self.val = val
self.sum += val * n
self.count += n
self.avg = self.sum / self.count
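A toy invocation of the new metrics module (hypothetical integer label maps; _fast_hist feeds torch.bincount, so the inputs must be non-negative integer tensors, and eval_metrics averages per-class scores from one pooled confusion matrix):

```python
import torch
from utils.metrics import eval_metrics, AverageMeter

true = torch.tensor([[[0, 1], [1, 1]]])  # [B, H, W] ground truth
pred = torch.tensor([[[0, 1], [0, 1]]])  # [B, H, W] prediction
overall_acc, avg_per_class_acc, avg_jacc, avg_dice = eval_metrics(true, pred, num_classes = 2)
print(float(overall_acc))  # 3 of 4 pixels correct -> ~0.75

# AverageMeter keeps a size-weighted running mean, e.g. of a per-batch loss:
meter = AverageMeter()
for loss, n in [(0.8, 4), (0.6, 4), (0.5, 2)]:
    meter.update(loss, n = n)
print(meter.avg)  # (0.8*4 + 0.6*4 + 0.5*2) / 10 = 0.66
```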
@@ -23,4 +23,25 @@ def predict_img(net, full_img, device, out_threshold = 0.5):
     probs = tf(probs.cpu())
     full_mask = probs.squeeze().cpu().numpy()
-    return full_mask > out_threshold
\ No newline at end of file
+    return full_mask > out_threshold
+
+def predict(net, full_img, device, out_threshold = 0.5):
+    net.eval()
+    img = torch.from_numpy(BasicDataset.preprocess(full_img))
+    img = img.unsqueeze(0)
+    img = img.to(device = device, dtype = torch.float32)
+    with torch.no_grad():
+        output = net(img)
+        # if net.n_classes > 1:
+        #     probs = F.softmax(output, dim = 1)
+        # else:
+        #     probs = torch.sigmoid(output)
+        probs = output.squeeze(0)
+        tf = transforms.Compose([transforms.ToPILImage(), transforms.Resize(full_img.size[1]), transforms.ToTensor()])
+        probs = tf(probs.cpu())
+        full_mask = probs.squeeze().cpu().numpy()
+    return full_mask > out_threshold
\ No newline at end of file
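A hedged usage sketch of the new predict helper (the model and image are stand-ins; any module mapping a 1-channel input to a 1-channel probability map should work, since predict no longer applies sigmoid/softmax itself):

```python
import numpy as np
import torch
import torch.nn as nn
from PIL import Image
from utils.predict import predict

# Stand-in for a trained MultiUnet: 1-channel in, 1-channel probabilities out.
net = nn.Sequential(nn.Conv2d(1, 1, kernel_size = 3, padding = 1), nn.Sigmoid())
device = torch.device('cpu')
net.to(device)

img = Image.new('L', (200, 200))  # hypothetical grayscale input image
mask = predict(net = net, full_img = img, out_threshold = 0.5, device = device)
result = (mask * 255).astype(np.uint8)  # boolean mask -> 8-bit image, as in step_1
```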