Documentation service: http://47.92.0.57:3000/  Weekly report index: http://47.92.0.57:3000/s/NruNXRYmV

Commit 863c51ff by 王肇一

Several new CLI tools:

1. image quality evaluation
2. file renaming
3. pd denoise (not finished)
4. U-Net segmentation (not finished)
parent 06fb2e4b
#!/usr/bin/env python
# -*- coding:utf-8 -*-
\ No newline at end of file
#!/usr/bin/env python
# -*- coding:utf-8 -*-
from sklearn.cluster import KMeans
-from util import *
+from cvBasedMethod.util import *
+from filters import butterworth,fft_mask
def kmeans(pair, cluster_num = 5,filter='butter'):
......
#!/usr/bin/env python
# -*- coding:utf-8 -*-
-from util import *
+from cvBasedMethod.util import *
+from filters import butterworth, fft_mask
def threshold(pair,filter='butter'):
......
...@@ -6,7 +6,7 @@ import pandas as pd
import matplotlib.pyplot as plt
import logging
import os
-from filters import fft_mask,butterworth
+from cvBasedMethod.filters import fft_mask,butterworth
def remove_scratch(img):
f = np.fft.fft2(img)
......
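Only the first line of remove_scratch survives in this hunk. For context, a minimal sketch of the FFT-based idea (mask out the frequency band that carries the scratch, then invert the transform) is shown below; the mask construction is purely illustrative and is not the project's actual fft_mask/butterworth implementation.

import numpy as np

def remove_scratch_sketch(img, band=2):
    # forward FFT, with the zero frequency shifted to the centre of the spectrum
    f = np.fft.fftshift(np.fft.fft2(img))
    h, w = f.shape
    mask = np.ones((h, w))
    # a horizontal scratch concentrates its energy along the central vertical
    # line of the spectrum, so suppress a narrow vertical band there ...
    mask[:, w // 2 - band:w // 2 + band + 1] = 0
    # ... but keep the very low frequencies (overall brightness, slow gradients)
    mask[h // 2 - band:h // 2 + band + 1, w // 2 - band:w // 2 + band + 1] = 1
    return np.fft.ifft2(np.fft.ifftshift(f * mask)).real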
#!/usr/bin/env python
# -*- coding:utf-8 -*-
import torch
import torchvision
from torchvision import transforms, models, datasets
from torch import nn
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.conv1 = nn.Conv2d(in_channels = 1, out_channels = 64, kernel_size = 3)
self.conv2 = nn.Conv2d(in_channels = 64, out_channels = 64, kernel_size = 3)
self.pool1 = nn.MaxPool2d(kernel_size = 2)
self.conv3 = nn.Conv2d(in_channels = 64,out_channels = 128,kernel_size = 3)
self.conv4 = nn.Conv2d(in_channels = 128,out_channels = 128,kernel_size = 3)
self.pool2 = nn.MaxPool2d(kernel_size = 2)
self.conv5 = nn.Conv2d(in_channels = 128,out_channels = 256,kernel_size = 3)
self.conv6 = nn.Conv2d(in_channels = 256,out_channels = 256,kernel_size = 3)
self.pool3 = nn.MaxPool2d(kernel_size = 2)
self.conv7 = nn.Conv2d(in_channels = 256,out_channels = 512,kernel_size = 3)
self.conv8 = nn.Conv2d(in_channels = 512,out_channels = 512,kernel_size = 3)
self.pool4 = nn.MaxPool2d(kernel_size = 2)
self.conv9 = nn.Conv2d(in_channels = 512,out_channels = 1024,kernel_size = 3)
self.conv10 = nn.Conv2d(in_channels = 1024,out_channels = 1024,kernel_size = 3)
# the decoder path is unfinished in this commit ("unet segmentation (not finished)");
# the file breaks off in the middle of the first up-sampling layer
self.up1 = nn.C
\ No newline at end of file
#!/usr/bin/env python
# -*- coding:utf-8 -*-
\ No newline at end of file
void filterLargeSmall(ImageProcessor ip, double filterLarge, double filterSmall, int stripesHorVert, double scaleStripes) {
int maxN = ip.getWidth();
float[] fht = (float[])ip.getPixels();
float[] filter = new float[maxN*maxN];
for (int i=0; i<maxN*maxN; i++)
filter[i]=1f;
int row;
int backrow;
float rowFactLarge;
float rowFactSmall;
int col;
int backcol;
float factor;
float colFactLarge;
float colFactSmall;
float factStripes;
// calculate factor in exponent of Gaussian from filterLarge / filterSmall
double scaleLarge = filterLarge*filterLarge;
double scaleSmall = filterSmall*filterSmall;
scaleStripes = scaleStripes*scaleStripes;
//float FactStripes;
// loop over rows
for (int j=1; j<maxN/2; j++) {
row = j * maxN;
backrow = (maxN-j)*maxN;
rowFactLarge = (float) Math.exp(-(j*j) * scaleLarge);
rowFactSmall = (float) Math.exp(-(j*j) * scaleSmall);
// loop over columns
for (col=1; col<maxN/2; col++){
backcol = maxN-col;
colFactLarge = (float) Math.exp(- (col*col) * scaleLarge);
colFactSmall = (float) Math.exp(- (col*col) * scaleSmall);
factor = (1 - rowFactLarge*colFactLarge) * rowFactSmall*colFactSmall;
switch (stripesHorVert) {
case 1: factor *= (1 - (float) Math.exp(- (col*col) * scaleStripes)); break;// hor stripes
case 2: factor *= (1 - (float) Math.exp(- (j*j) * scaleStripes)); // vert stripes
}
fht[col+row] *= factor;
fht[col+backrow] *= factor;
fht[backcol+row] *= factor;
fht[backcol+backrow] *= factor;
filter[col+row] *= factor;
filter[col+backrow] *= factor;
filter[backcol+row] *= factor;
filter[backcol+backrow] *= factor;
}
}
//process meeting points (maxN/2,0) , (0,maxN/2), and (maxN/2,maxN/2)
int rowmid = maxN * (maxN/2);
rowFactLarge = (float) Math.exp(- (maxN/2)*(maxN/2) * scaleLarge);
rowFactSmall = (float) Math.exp(- (maxN/2)*(maxN/2) * scaleSmall);
factStripes = (float) Math.exp(- (maxN/2)*(maxN/2) * scaleStripes);
fht[maxN/2] *= (1 - rowFactLarge) * rowFactSmall; // (maxN/2,0)
fht[rowmid] *= (1 - rowFactLarge) * rowFactSmall; // (0,maxN/2)
fht[maxN/2 + rowmid] *= (1 - rowFactLarge*rowFactLarge) * rowFactSmall*rowFactSmall; // (maxN/2,maxN/2)
filter[maxN/2] *= (1 - rowFactLarge) * rowFactSmall; // (maxN/2,0)
filter[rowmid] *= (1 - rowFactLarge) * rowFactSmall; // (0,maxN/2)
filter[maxN/2 + rowmid] *= (1 - rowFactLarge*rowFactLarge) * rowFactSmall*rowFactSmall; // (maxN/2,maxN/2)
switch (stripesHorVert) {
case 1: fht[maxN/2] *= (1 - factStripes);
fht[rowmid] = 0;
fht[maxN/2 + rowmid] *= (1 - factStripes);
filter[maxN/2] *= (1 - factStripes);
filter[rowmid] = 0;
filter[maxN/2 + rowmid] *= (1 - factStripes);
break; // hor stripes
case 2: fht[maxN/2] = 0;
fht[rowmid] *= (1 - factStripes);
fht[maxN/2 + rowmid] *= (1 - factStripes);
filter[maxN/2] = 0;
filter[rowmid] *= (1 - factStripes);
filter[maxN/2 + rowmid] *= (1 - factStripes);
break; // vert stripes
}
//loop along row 0 and maxN/2
rowFactLarge = (float) Math.exp(- (maxN/2)*(maxN/2) * scaleLarge);
rowFactSmall = (float) Math.exp(- (maxN/2)*(maxN/2) * scaleSmall);
for (col=1; col<maxN/2; col++){
backcol = maxN-col;
colFactLarge = (float) Math.exp(- (col*col) * scaleLarge);
colFactSmall = (float) Math.exp(- (col*col) * scaleSmall);
switch (stripesHorVert) {
case 0:
fht[col] *= (1 - colFactLarge) * colFactSmall;
fht[backcol] *= (1 - colFactLarge) * colFactSmall;
fht[col+rowmid] *= (1 - colFactLarge*rowFactLarge) * colFactSmall*rowFactSmall;
fht[backcol+rowmid] *= (1 - colFactLarge*rowFactLarge) * colFactSmall*rowFactSmall;
filter[col] *= (1 - colFactLarge) * colFactSmall;
filter[backcol] *= (1 - colFactLarge) * colFactSmall;
filter[col+rowmid] *= (1 - colFactLarge*rowFactLarge) * colFactSmall*rowFactSmall;
filter[backcol+rowmid] *= (1 - colFactLarge*rowFactLarge) * colFactSmall*rowFactSmall;
break;
}
}
// loop along column 0 and maxN/2
colFactLarge = (float) Math.exp(- (maxN/2)*(maxN/2) * scaleLarge);
colFactSmall = (float) Math.exp(- (maxN/2)*(maxN/2) * scaleSmall);
for (int j=1; j<maxN/2; j++) {
row = j * maxN;
backrow = (maxN-j)*maxN;
rowFactLarge = (float) Math.exp(- (j*j) * scaleLarge);
rowFactSmall = (float) Math.exp(- (j*j) * scaleSmall);
switch (stripesHorVert) {
case 0:
fht[row] *= (1 - rowFactLarge) * rowFactSmall;
fht[backrow] *= (1 - rowFactLarge) * rowFactSmall;
fht[row+maxN/2] *= (1 - rowFactLarge*colFactLarge) * rowFactSmall*colFactSmall;
fht[backrow+maxN/2] *= (1 - rowFactLarge*colFactLarge) * rowFactSmall*colFactSmall;
filter[row] *= (1 - rowFactLarge) * rowFactSmall;
filter[backrow] *= (1 - rowFactLarge) * rowFactSmall;
filter[row+maxN/2] *= (1 - rowFactLarge*colFactLarge) * rowFactSmall*colFactSmall;
filter[backrow+maxN/2] *= (1 - rowFactLarge*colFactLarge) * rowFactSmall*colFactSmall;
break;
}
}
if (displayFilter && slice==1) {
FHT f = new FHT(new FloatProcessor(maxN, maxN, filter, null));
f.swapQuadrants();
new ImagePlus("Filter", f).show();
}
}
\ No newline at end of file
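The Java routine above appears to be copied from ImageJ's FFT bandpass filter (filterLargeSmall); displayFilter and slice are fields of the enclosing plugin class, so the snippet does not compile on its own. A rough numpy sketch of the same Gaussian band-pass idea (suppress structures larger than filter_large pixels and smaller than filter_small pixels) follows; it works on a plain FFT rather than ImageJ's Hartley transform and is an approximation, not a port.

import numpy as np

def gaussian_bandpass(img, filter_large, filter_small):
    # frequency coordinates in cycles per pixel
    h, w = img.shape
    fy = np.fft.fftfreq(h)[:, None]
    fx = np.fft.fftfreq(w)[None, :]
    r2 = fx ** 2 + fy ** 2
    # Gaussian weights: "large" passes only very low frequencies,
    # "small" passes low and mid frequencies
    large = np.exp(-r2 * filter_large ** 2)
    small = np.exp(-r2 * filter_small ** 2)
    band = (1.0 - large) * small          # keep the band in between
    F = np.fft.fft2(img)
    return np.fft.ifft2(F * band).real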
#!/usr/bin/env python
# -*- coding:utf-8 -*-
import argparse
import json
import os
import os.path as osp
import warnings
import PIL.Image
import yaml
import imgviz
from labelme import utils
import base64
def main():
warnings.warn("This script is aimed to demonstrate how to convert the\n"
"JSON file to a single image dataset, and not to handle\n"
"multiple JSON files to generate a real-use dataset.")
parser = argparse.ArgumentParser()
parser.add_argument('json_file')
parser.add_argument('-o', '--out', default = None)
args = parser.parse_args()
json_file = args.json_file
if args.out is None:
out_dir = osp.basename(json_file).replace('.', '_')
out_dir = osp.join(osp.dirname(json_file), out_dir)
else:
out_dir = args.out
if not osp.exists(out_dir):
os.mkdir(out_dir)
count = [x for x in filter(lambda x: x.endswith('json'), os.listdir(json_file))]
for i in range(len(count)):
path = os.path.join(json_file, count[i])
if os.path.isfile(path):
print(path)
data = json.load(open(path))
if data['imageData']:
imageData = data['imageData']
else:
imagePath = os.path.join(os.path.dirname(path), data['imagePath'])
with open(imagePath, 'rb') as f:
imageData = f.read()
imageData = base64.b64encode(imageData).decode('utf-8')
img = utils.img_b64_to_arr(imageData)
label_name_to_value = {'_background_': 0}
for shape in data['shapes']:
label_name = shape['label']
if label_name in label_name_to_value:
label_value = label_name_to_value[label_name]
else:
label_value = len(label_name_to_value)
label_name_to_value[label_name] = label_value
# label_values must be dense
label_values, label_names = [], []
for ln, lv in sorted(label_name_to_value.items(), key = lambda x: x[1]):
label_values.append(lv)
label_names.append(ln)
assert label_values == list(range(len(label_values)))
lbl = utils.shapes_to_label(img.shape, data['shapes'], label_name_to_value)
label_names = [None] * (max(label_name_to_value.values()) + 1)
for name, value in label_name_to_value.items():
label_names[value] = name
lbl_viz = imgviz.label2rgb(label = lbl, img = img, label_names = label_names, loc = 'rb')
out_dir = osp.basename(count[i]).replace('.', '_')
out_dir = osp.join(osp.dirname(count[i]), out_dir)
if not osp.exists(out_dir):
os.mkdir(out_dir)
PIL.Image.fromarray(img).save(osp.join(out_dir, 'img.png'))
# PIL.Image.fromarray(lbl).save(osp.join(out_dir, 'label.png'))
utils.lblsave(osp.join(out_dir, 'label.png'), lbl)
PIL.Image.fromarray(lbl_viz).save(osp.join(out_dir, 'label_viz.png'))
with open(osp.join(out_dir, 'label_names.txt'), 'w') as f:
for lbl_name in label_names:
f.write(lbl_name + '\n')
warnings.warn('info.yaml is being replaced by label_names.txt')
info = dict(label_names = label_names)
with open(osp.join(out_dir, 'info.yaml'), 'w') as f:
yaml.safe_dump(info, f, default_flow_style = False)
print('Saved to: %s' % out_dir)
if __name__ == '__main__':
main()
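The script above appears to be an adaptation of labelme's json_to_dataset example that loops over every *.json file in a directory instead of converting a single file. Assuming it is saved as json2dataset.py (the filename is an assumption), it would be run as

python json2dataset.py path/to/json_dir

and, for each foo.json, it writes a foo_json/ folder containing img.png, label.png, label_viz.png, label_names.txt and info.yaml.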
...@@ -3,9 +3,9 @@
from multiprocessing import Pool
from tqdm import tqdm
import argparse
-from util import *
+from cvBasedMethod.util import *
-from kmeans import kmeans, kmeans_back
+from cvBasedMethod.kmeans import kmeans, kmeans_back
-from threshold import threshold
+from cvBasedMethod.threshold import threshold
def method_kmeans(imglist, core = 5):
......
#!/usr/bin/env python
# -*- coding:utf-8 -*-
import cv2 as cv
import numpy as np
import math
def brenner(img):
'''
:param img: ndarray, 2D grayscale image
:return: float, larger value indicates a sharper image
'''
shape = np.shape(img)
out = 0
for x in range(0, shape[0] - 2):
for y in range(0, shape[1]):
out += (int(img[x + 2, y]) - int(img[x, y])) ** 2
return out
def Laplacian(img):
'''
:param img: ndarray, 2D grayscale image
:return: float, larger value indicates a sharper image
'''
return cv.Laplacian(img, cv.CV_64F).var()
def SMD(img):
'''
:param img: ndarray, 2D grayscale image
:return: float, larger value indicates a sharper image
'''
shape = np.shape(img)
out = 0
for x in range(1, shape[0] - 1):
for y in range(0, shape[1]):
out += math.fabs(int(img[x, y]) - int(img[x, y - 1]))
out += math.fabs(int(img[x, y]) - int(img[x + 1, y]))
return out
def SMD2(img):
'''
:param img: ndarray, 2D grayscale image
:return: float, larger value indicates a sharper image
'''
shape = np.shape(img)
out = 0
for x in range(0, shape[0] - 1):
for y in range(0, shape[1] - 1):
out += math.fabs(int(img[x, y]) - int(img[x + 1, y])) * math.fabs(int(img[x, y]) - int(img[x, y + 1]))
return out
def variance(img):
'''
:param img: ndarray, 2D grayscale image
:return: float, larger value indicates a sharper image
'''
out = 0
u = np.mean(img)
shape = np.shape(img)
for x in range(0, shape[0]):
for y in range(0, shape[1]):
out += (img[x, y] - u) ** 2
return out
def energy(img):
'''
:param img: ndarray, 2D grayscale image
:return: float, larger value indicates a sharper image
'''
shape = np.shape(img)
out = 0
for x in range(0, shape[0] - 1):
for y in range(0, shape[1] - 1):
out += ((int(img[x + 1, y]) - int(img[x, y])) ** 2) + ((int(img[x, y + 1]) - int(img[x, y])) ** 2)
return out
def Vollath(img):
'''
:param img: ndarray, 2D grayscale image
:return: float, larger value indicates a sharper image
'''
shape = np.shape(img)
u = np.mean(img)
out = -shape[0] * shape[1] * (u ** 2)
for x in range(0, shape[0] - 1):
for y in range(0, shape[1]):
out += int(img[x, y]) * int(img[x + 1, y])
return out
def entropy(img):
'''
:param img: ndarray, 2D grayscale image
:return: float, larger value indicates a sharper image
'''
out = 0
count = np.shape(img)[0] * np.shape(img)[1]
p = np.bincount(np.array(img).flatten())
for i in range(0, len(p)):
if p[i] != 0:
out -= p[i] * math.log(p[i] / count) / count
return out
def combine(img):
value_array = [entropy(img), Vollath(img), energy(img), variance(img), SMD2(img), SMD(img), Laplacian(img),
brenner(img)]
return value_array
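A minimal usage sketch for these focus measures (the image path is hypothetical); every metric returns a larger value for a sharper image, and combine packs all eight into one list:

import cv2 as cv
img = cv.imread('example.tif', cv.IMREAD_GRAYSCALE)   # hypothetical input image
print(combine(img))   # [entropy, Vollath, energy, variance, SMD2, SMD, Laplacian, brenner]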
#!/usr/bin/env python
# -*- coding:utf-8 -*-
import os
import re
import shutil
names = [(path, name) for path in filter(lambda x: os.path.isdir('data/' + x), os.listdir('data')) for name in
filter(lambda x: x != '.DS_Store', os.listdir('data/' + path))]
# Note that input/output paths should be changed
for path, name in names:
res = re.match(r'(\d{8})\-(.*)\s(atcc|ATCC)\s?(\d+)\s(.*)', path)
new_path = res.group(2) + res.group(4) + res.group(5)
res = re.match(r'.*\s(.*[uU][gG]|[dD]2[oO]|[lL][bB]|TOB|levo).*\s(\d+.*)', name)
new_name = res.group(1) + res.group(2)
try:
os.makedirs('./dest/' + new_path)
except FileExistsError:
print('exist')
shutil.copyfile('./data/' + path + '/' + name, 'dest/' + new_path + '/' + new_name)
\ No newline at end of file
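A hypothetical illustration of the naming patterns the two regular expressions expect (the example strings below are made up, not taken from the real dataset): a directory such as '20191101-E.coli ATCC 25922 group1' yields new_path = 'E.coli' + '25922' + 'group1', and a file such as 'E.coli 4ug 03.tif' yields new_name = '4ug' + '03.tif'.

import re
d = re.match(r'(\d{8})\-(.*)\s(atcc|ATCC)\s?(\d+)\s(.*)', '20191101-E.coli ATCC 25922 group1')
print(d.group(2), d.group(4), d.group(5))   # E.coli 25922 group1
f = re.match(r'.*\s(.*[uU][gG]|[dD]2[oO]|[lL][bB]|TOB|levo).*\s(\d+.*)', 'E.coli 4ug 03.tif')
print(f.group(1), f.group(2))               # 4ug 03.tif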
import argparse
import logging
import os
import sys
import numpy as np
import torch
import torch.nn as nn
from torch import optim
from tqdm import tqdm
from eval import eval_net
from unet import UNet
from torch.utils.tensorboard import SummaryWriter
from utils.dataset import BasicDataset
from torch.utils.data import DataLoader, random_split
dir_img = 'data/imgs/'
dir_mask = 'data/masks/'
dir_checkpoint = 'checkpoints/'
def train_net(net,
device,
epochs=5,
batch_size=1,
lr=0.1,
val_percent=0.1,
save_cp=True,
img_scale=0.5):
dataset = BasicDataset(dir_img, dir_mask, img_scale)
n_val = int(len(dataset) * val_percent)
n_train = len(dataset) - n_val
train, val = random_split(dataset, [n_train, n_val])
train_loader = DataLoader(train, batch_size=batch_size, shuffle=True, num_workers=8, pin_memory=True)
val_loader = DataLoader(val, batch_size=batch_size, shuffle=False, num_workers=8, pin_memory=True)
writer = SummaryWriter(comment=f'LR_{lr}_BS_{batch_size}_SCALE_{img_scale}')
global_step = 0
logging.info(f'''Starting training:
Epochs: {epochs}
Batch size: {batch_size}
Learning rate: {lr}
Training size: {n_train}
Validation size: {n_val}
Checkpoints: {save_cp}
Device: {device.type}
Images scaling: {img_scale}
''')
optimizer = optim.RMSprop(net.parameters(), lr=lr, weight_decay=1e-8)
if net.n_classes > 1:
criterion = nn.CrossEntropyLoss()
else:
criterion = nn.BCEWithLogitsLoss()
for epoch in range(epochs):
net.train()
epoch_loss = 0
with tqdm(total=n_train, desc=f'Epoch {epoch + 1}/{epochs}', unit='img') as pbar:
for batch in train_loader:
imgs = batch['image']
true_masks = batch['mask']
assert imgs.shape[1] == net.n_channels, \
f'Network has been defined with {net.n_channels} input channels, ' \
f'but loaded images have {imgs.shape[1]} channels. Please check that ' \
'the images are loaded correctly.'
imgs = imgs.to(device=device, dtype=torch.float32)
mask_type = torch.float32 if net.n_classes == 1 else torch.long
true_masks = true_masks.to(device=device, dtype=mask_type)
masks_pred = net(imgs)
loss = criterion(masks_pred, true_masks)
epoch_loss += loss.item()
writer.add_scalar('Loss/train', loss.item(), global_step)
pbar.set_postfix(**{'loss (batch)': loss.item()})
optimizer.zero_grad()
loss.backward()
optimizer.step()
pbar.update(imgs.shape[0])
global_step += 1
if global_step % (len(dataset) // (10 * batch_size)) == 0:
val_score = eval_net(net, val_loader, device, n_val)
if net.n_classes > 1:
logging.info('Validation cross entropy: {}'.format(val_score))
writer.add_scalar('Loss/test', val_score, global_step)
else:
logging.info('Validation Dice Coeff: {}'.format(val_score))
writer.add_scalar('Dice/test', val_score, global_step)
writer.add_images('images', imgs, global_step)
if net.n_classes == 1:
writer.add_images('masks/true', true_masks, global_step)
writer.add_images('masks/pred', torch.sigmoid(masks_pred) > 0.5, global_step)
if save_cp:
try:
os.mkdir(dir_checkpoint)
logging.info('Created checkpoint directory')
except OSError:
pass
torch.save(net.state_dict(),
dir_checkpoint + f'CP_epoch{epoch + 1}.pth')
logging.info(f'Checkpoint {epoch + 1} saved !')
writer.close()
def get_args():
parser = argparse.ArgumentParser(description='Train the UNet on images and target masks',
formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument('-e', '--epochs', metavar='E', type=int, default=5,
help='Number of epochs', dest='epochs')
parser.add_argument('-b', '--batch-size', metavar='B', type=int, nargs='?', default=1,
help='Batch size', dest='batchsize')
parser.add_argument('-l', '--learning-rate', metavar='LR', type=float, nargs='?', default=0.1,
help='Learning rate', dest='lr')
parser.add_argument('-f', '--load', dest='load', type=str, default=False,
help='Load model from a .pth file')
parser.add_argument('-s', '--scale', dest='scale', type=float, default=0.5,
help='Downscaling factor of the images')
parser.add_argument('-v', '--validation', dest='val', type=float, default=10.0,
help='Percent of the data that is used as validation (0-100)')
return parser.parse_args()
if __name__ == '__main__':
logging.basicConfig(level=logging.INFO, format='%(levelname)s: %(message)s')
args = get_args()
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
logging.info(f'Using device {device}')
# Change here to adapt to your data
# n_channels=3 for RGB images
# n_classes is the number of probabilities you want to get per pixel
# - For 1 class and background, use n_classes=1
# - For 2 classes, use n_classes=1
# - For N > 2 classes, use n_classes=N
net = UNet(n_channels=1, n_classes=1)
logging.info(f'Network:\n'
f'\t{net.n_channels} input channels\n'
f'\t{net.n_classes} output channels (classes)\n'
f'\t{"Bilinear" if net.bilinear else "Dilated conv"} upscaling')
if args.load:
net.load_state_dict(
torch.load(args.load, map_location=device)
)
logging.info(f'Model loaded from {args.load}')
net.to(device=device)
# faster convolutions, but more memory
# cudnn.benchmark = True
try:
train_net(net=net,
epochs=args.epochs,
batch_size=args.batchsize,
lr=args.lr,
device=device,
img_scale=args.scale,
val_percent=args.val / 100)
except KeyboardInterrupt:
torch.save(net.state_dict(), 'INTERRUPTED.pth')
logging.info('Saved interrupt')
try:
sys.exit(0)
except SystemExit:
os._exit(0)
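Assuming the training script above is saved as train.py (the imports follow the milesial Pytorch-UNet layout, with eval.py, unet/ and utils/dataset.py alongside it and data under data/imgs and data/masks), a typical invocation might be

python train.py -e 5 -b 1 -l 0.1 -s 0.5 -v 10

i.e. 5 epochs, batch size 1, learning rate 0.1, images downscaled to 50%, and 10% of the data held out for validation.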
from .unet_model import UNet
""" Full assembly of the parts to form the complete network """
import torch.nn.functional as F
from .unet_parts import *
class UNet(nn.Module):
def __init__(self, n_channels, n_classes, bilinear=True):
super(UNet, self).__init__()
self.n_channels = n_channels
self.n_classes = n_classes
self.bilinear = bilinear
self.inc = DoubleConv(n_channels, 64)
self.down1 = Down(64, 128)
self.down2 = Down(128, 256)
self.down3 = Down(256, 512)
self.down4 = Down(512, 512)
self.up1 = Up(1024, 256, bilinear)
self.up2 = Up(512, 128, bilinear)
self.up3 = Up(256, 64, bilinear)
self.up4 = Up(128, 64, bilinear)
self.outc = OutConv(64, n_classes)
def forward(self, x):
x1 = self.inc(x)
x2 = self.down1(x1)
x3 = self.down2(x2)
x4 = self.down3(x3)
x5 = self.down4(x4)
x = self.up1(x5, x4)
x = self.up2(x, x3)
x = self.up3(x, x2)
x = self.up4(x, x1)
logits = self.outc(x)
return logits
""" Parts of the U-Net model """
import torch
import torch.nn as nn
import torch.nn.functional as F
class DoubleConv(nn.Module):
"""(convolution => [BN] => ReLU) * 2"""
def __init__(self, in_channels, out_channels):
super().__init__()
self.double_conv = nn.Sequential(
nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
nn.BatchNorm2d(out_channels),
nn.ReLU(inplace=True),
nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
nn.BatchNorm2d(out_channels),
nn.ReLU(inplace=True)
)
def forward(self, x):
return self.double_conv(x)
class Down(nn.Module):
"""Downscaling with maxpool then double conv"""
def __init__(self, in_channels, out_channels):
super().__init__()
self.maxpool_conv = nn.Sequential(
nn.MaxPool2d(2),
DoubleConv(in_channels, out_channels)
)
def forward(self, x):
return self.maxpool_conv(x)
class Up(nn.Module):
"""Upscaling then double conv"""
def __init__(self, in_channels, out_channels, bilinear=True):
super().__init__()
# if bilinear, use the normal convolutions to reduce the number of channels
if bilinear:
self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
else:
self.up = nn.ConvTranspose2d(in_channels // 2, in_channels // 2, kernel_size=2, stride=2)
self.conv = DoubleConv(in_channels, out_channels)
def forward(self, x1, x2):
x1 = self.up(x1)
# input is CHW
diffY = torch.tensor([x2.size()[2] - x1.size()[2]])
diffX = torch.tensor([x2.size()[3] - x1.size()[3]])
x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2,
diffY // 2, diffY - diffY // 2])
# if you have padding issues, see
# https://github.com/HaiyongJiang/U-Net-Pytorch-Unstructured-Buggy/commit/0e854509c2cea854e247a9c615f175f76fbb2e3a
# https://github.com/xiaopeng-liao/Pytorch-UNet/commit/8ebac70e633bac59fc22bb5195e513d5832fb3bd
x = torch.cat([x2, x1], dim=1)
return self.conv(x)
class OutConv(nn.Module):
def __init__(self, in_channels, out_channels):
super(OutConv, self).__init__()
self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)
def forward(self, x):
return self.conv(x)
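A quick sanity check, not part of the commit: instantiate the network the same way train.py does and confirm that the predicted mask keeps the spatial size of the input.

import torch
from unet import UNet          # as exported by the package __init__.py above

net = UNet(n_channels=1, n_classes=1)
x = torch.randn(1, 1, 128, 128)        # dummy single-channel image
with torch.no_grad():
    y = net(x)
print(y.shape)                         # torch.Size([1, 1, 128, 128])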