Commit f9f59944 by 王肇一

Unet now functional ready for test

parent b835e670
# 使用说明
# 使用说明(编写未完成)
## 安装依赖
`pip install -r requirement.txt`
## 目录结构
```.
```
.
├── README.md
├── data
│ ├── module
......@@ -21,7 +22,7 @@
```
## 主程序入口
### 参数
* -m ,--method : 0 使用Kmeans,1 使用阈值法(butterworth滤波),2 使用阈值法(fft)。默认Kmeans
* -m ,--method : 0 使用Kmeans,1 使用阈值法(butterworth滤波),2 使用阈值法(fft),3 使用unet。默认Kmeans
* -c ,--core : Kmeans分为几类,默认5,仅对Kmeans法有效
* -p ,--process : 使用线程数量,默认8,仅对阈值法有效
......@@ -33,8 +34,8 @@
## 使用Unet模型
### 训练
```shell script
> python train.py -h
```
python train.py -h
usage: train.py [-h] [-e E] [-b [B]] [-l [LR]] [-f LOAD] [-s SCALE] [-v VAL]
Train the UNet on images and target masks
......@@ -52,16 +53,16 @@ optional arguments:
-v VAL, --validation VAL
Percent of the data that is used as validation (0-100)
(default: 15.0)
```
训练后将生成一个checkout目录,存放所得模型。可根据需要,选择保留或删除。
### 监控
使用tensorboard可视化监控
`tensorboard --logdir=runs`
## 数据
输入图像存放于imgs文件夹下,单通道灰度图像,8bit,200*200
标注的mask存放于masks文件夹下,像素0为背景,1为目标
- 用于训练模型或用于提取信号的输入图像存放于imgs文件夹下,单通道灰度图像,8bit,200*200
- 标注的mask存放于masks文件夹下,像素0为背景,1为目标
## cli工具
部分可能用得到的cli工具,主要为python或shell脚本。可根据需要自行修改。
......
void filterLargeSmall(ImageProcessor ip, double filterLarge, double filterSmall, int stripesHorVert, double scaleStripes) {
int maxN = ip.getWidth();
float[] fht = (float[])ip.getPixels();
float[] filter = new float[maxN*maxN];
for (int i=0; i<maxN*maxN; i++)
filter[i]=1f;
int row;
int backrow;
float rowFactLarge;
float rowFactSmall;
int col;
int backcol;
float factor;
float colFactLarge;
float colFactSmall;
float factStripes;
// calculate factor in exponent of Gaussian from filterLarge / filterSmall
double scaleLarge = filterLarge*filterLarge;
double scaleSmall = filterSmall*filterSmall;
scaleStripes = scaleStripes*scaleStripes;
//float FactStripes;
// loop over rows
for (int j=1; j<maxN/2; j++) {
row = j * maxN;
backrow = (maxN-j)*maxN;
rowFactLarge = (float) Math.exp(-(j*j) * scaleLarge);
rowFactSmall = (float) Math.exp(-(j*j) * scaleSmall);
// loop over columns
for (col=1; col<maxN/2; col++){
backcol = maxN-col;
colFactLarge = (float) Math.exp(- (col*col) * scaleLarge);
colFactSmall = (float) Math.exp(- (col*col) * scaleSmall);
factor = (1 - rowFactLarge*colFactLarge) * rowFactSmall*colFactSmall;
switch (stripesHorVert) {
case 1: factor *= (1 - (float) Math.exp(- (col*col) * scaleStripes)); break;// hor stripes
case 2: factor *= (1 - (float) Math.exp(- (j*j) * scaleStripes)); // vert stripes
}
fht[col+row] *= factor;
fht[col+backrow] *= factor;
fht[backcol+row] *= factor;
fht[backcol+backrow] *= factor;
filter[col+row] *= factor;
filter[col+backrow] *= factor;
filter[backcol+row] *= factor;
filter[backcol+backrow] *= factor;
}
}
//process meeting points (maxN/2,0) , (0,maxN/2), and (maxN/2,maxN/2)
int rowmid = maxN * (maxN/2);
rowFactLarge = (float) Math.exp(- (maxN/2)*(maxN/2) * scaleLarge);
rowFactSmall = (float) Math.exp(- (maxN/2)*(maxN/2) * scaleSmall);
factStripes = (float) Math.exp(- (maxN/2)*(maxN/2) * scaleStripes);
fht[maxN/2] *= (1 - rowFactLarge) * rowFactSmall; // (maxN/2,0)
fht[rowmid] *= (1 - rowFactLarge) * rowFactSmall; // (0,maxN/2)
fht[maxN/2 + rowmid] *= (1 - rowFactLarge*rowFactLarge) * rowFactSmall*rowFactSmall; // (maxN/2,maxN/2)
filter[maxN/2] *= (1 - rowFactLarge) * rowFactSmall; // (maxN/2,0)
filter[rowmid] *= (1 - rowFactLarge) * rowFactSmall; // (0,maxN/2)
filter[maxN/2 + rowmid] *= (1 - rowFactLarge*rowFactLarge) * rowFactSmall*rowFactSmall; // (maxN/2,maxN/2)
switch (stripesHorVert) {
case 1: fht[maxN/2] *= (1 - factStripes);
fht[rowmid] = 0;
fht[maxN/2 + rowmid] *= (1 - factStripes);
filter[maxN/2] *= (1 - factStripes);
filter[rowmid] = 0;
filter[maxN/2 + rowmid] *= (1 - factStripes);
break; // hor stripes
case 2: fht[maxN/2] = 0;
fht[rowmid] *= (1 - factStripes);
fht[maxN/2 + rowmid] *= (1 - factStripes);
filter[maxN/2] = 0;
filter[rowmid] *= (1 - factStripes);
filter[maxN/2 + rowmid] *= (1 - factStripes);
break; // vert stripes
}
//loop along row 0 and maxN/2
rowFactLarge = (float) Math.exp(- (maxN/2)*(maxN/2) * scaleLarge);
rowFactSmall = (float) Math.exp(- (maxN/2)*(maxN/2) * scaleSmall);
for (col=1; col<maxN/2; col++){
backcol = maxN-col;
colFactLarge = (float) Math.exp(- (col*col) * scaleLarge);
colFactSmall = (float) Math.exp(- (col*col) * scaleSmall);
switch (stripesHorVert) {
case 0:
fht[col] *= (1 - colFactLarge) * colFactSmall;
fht[backcol] *= (1 - colFactLarge) * colFactSmall;
fht[col+rowmid] *= (1 - colFactLarge*rowFactLarge) * colFactSmall*rowFactSmall;
fht[backcol+rowmid] *= (1 - colFactLarge*rowFactLarge) * colFactSmall*rowFactSmall;
filter[col] *= (1 - colFactLarge) * colFactSmall;
filter[backcol] *= (1 - colFactLarge) * colFactSmall;
filter[col+rowmid] *= (1 - colFactLarge*rowFactLarge) * colFactSmall*rowFactSmall;
filter[backcol+rowmid] *= (1 - colFactLarge*rowFactLarge) * colFactSmall*rowFactSmall;
break;
}
}
// loop along column 0 and maxN/2
colFactLarge = (float) Math.exp(- (maxN/2)*(maxN/2) * scaleLarge);
colFactSmall = (float) Math.exp(- (maxN/2)*(maxN/2) * scaleSmall);
for (int j=1; j<maxN/2; j++) {
row = j * maxN;
backrow = (maxN-j)*maxN;
rowFactLarge = (float) Math.exp(- (j*j) * scaleLarge);
rowFactSmall = (float) Math.exp(- (j*j) * scaleSmall);
switch (stripesHorVert) {
case 0:
fht[row] *= (1 - rowFactLarge) * rowFactSmall;
fht[backrow] *= (1 - rowFactLarge) * rowFactSmall;
fht[row+maxN/2] *= (1 - rowFactLarge*colFactLarge) * rowFactSmall*colFactSmall;
fht[backrow+maxN/2] *= (1 - rowFactLarge*colFactLarge) * rowFactSmall*colFactSmall;
filter[row] *= (1 - rowFactLarge) * rowFactSmall;
filter[backrow] *= (1 - rowFactLarge) * rowFactSmall;
filter[row+maxN/2] *= (1 - rowFactLarge*colFactLarge) * rowFactSmall*colFactSmall;
filter[backrow+maxN/2] *= (1 - rowFactLarge*colFactLarge) * rowFactSmall*colFactSmall;
break;
}
}
if (displayFilter && slice==1) {
FHT f = new FHT(new FloatProcessor(maxN, maxN, filter, null));
f.swapQuadrants();
new ImagePlus("Filter", f).show();
}
}
\ No newline at end of file
......@@ -6,6 +6,7 @@ import argparse
from cvBasedMethod.util import *
from cvBasedMethod.kmeans import kmeans, kmeans_back
from cvBasedMethod.threshold import threshold
from predict import predict
def method_kmeans(imglist, core = 5):
......@@ -36,17 +37,27 @@ def get_args():
parser.add_argument('-c', '--core', metavar = 'C', type = int, default = 5, help = 'Num of cluster', dest = 'core')
parser.add_argument('-p', '--process', metavar = 'P', type = int, default = 8, help = 'Num of process',
dest = 'process')
# Unet para
parser.add_argument('--load', '-L', default = 'data/module/MODEL.pth', metavar = 'FILE',
help = "Specify the file in which the model is stored")
parser.add_argument('--mask-threshold', '-t', type = float,
help = "Minimum probability value to consider a mask pixel white", default = 0.5)
parser.add_argument('--scale', '-s', type = float, help = "Scale factor for the input images", default = 0.5)
return parser.parse_args()
if __name__ == '__main__':
args = get_args()
path = [(y, x) for y in os.listdir(args.dir) for x in filter(
lambda x: x.endswith('.tif') and not x.endswith('dc.tif') and not x.endswith('DC.tif') and not x.endswith(
'dc .tif'), os.listdir('img/' + y))]
path = ['data/imgs/'+x for x in os.listdir('data/imgs')]
if args.method == 0:
method_kmeans(path,args.core)
elif args.method == 1:
method_threshold(path,args.process)
elif args.method == 2:
method_newThreshold(path,args.process)
elif args.method == 3:
predict(path, ['data/output/imgs/'+name[10:] for name in path], args.load, args.scale, args.mask_threshold)
import argparse
import logging
import os
import numpy as np
import torch
import torch.nn.functional as F
......@@ -9,107 +6,46 @@ from PIL import Image
from torchvision import transforms
from unet import UNet
from utils.data_vis import plot_img_and_mask
from utils.dataset import BasicDataset
def predict_img(net,
full_img,
device,
scale_factor=1,
out_threshold=0.5):
def predict_img(net, full_img, device, scale_factor = 1, out_threshold = 0.5):
net.eval()
img = torch.from_numpy(BasicDataset.preprocess(full_img, scale_factor))
img = img.unsqueeze(0)
img = img.to(device=device, dtype=torch.float32)
img = img.to(device = device, dtype = torch.float32)
with torch.no_grad():
output = net(img)
if net.n_classes > 1:
probs = F.softmax(output, dim=1)
probs = F.softmax(output, dim = 1)
else:
probs = torch.sigmoid(output)
probs = probs.squeeze(0)
tf = transforms.Compose(
[
transforms.ToPILImage(),
transforms.Resize(full_img.size[1]),
transforms.ToTensor()
]
)
tf = transforms.Compose([transforms.ToPILImage(), transforms.Resize(full_img.size[1]), transforms.ToTensor()])
probs = tf(probs.cpu())
full_mask = probs.squeeze().cpu().numpy()
return full_mask > out_threshold
def get_args():
parser = argparse.ArgumentParser(description='Predict masks from input images',
formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument('--model', '-m', default='MODEL.pth',
metavar='FILE',
help="Specify the file in which the model is stored")
parser.add_argument('--input', '-i', metavar='INPUT', nargs='+',
help='filenames of input images', required=True)
parser.add_argument('--output', '-o', metavar='INPUT', nargs='+',
help='Filenames of ouput images')
parser.add_argument('--viz', '-v', action='store_true',
help="Visualize the images as they are processed",
default=False)
parser.add_argument('--no-save', '-n', action='store_true',
help="Do not save the output masks",
default=False)
parser.add_argument('--mask-threshold', '-t', type=float,
help="Minimum probability value to consider a mask pixel white",
default=0.5)
parser.add_argument('--scale', '-s', type=float,
help="Scale factor for the input images",
default=0.5)
return parser.parse_args()
def get_output_filenames(args):
in_files = args.input
out_files = []
if not args.output:
for f in in_files:
pathsplit = os.path.splitext(f)
out_files.append("{}_OUT{}".format(pathsplit[0], pathsplit[1]))
elif len(in_files) != len(args.output):
logging.error("Input files and output files are not of the same length")
raise SystemExit()
else:
out_files = args.output
return out_files
def mask_to_image(mask):
return Image.fromarray((mask * 255).astype(np.uint8))
if __name__ == "__main__":
args = get_args()
in_files = args.input
out_files = get_output_filenames(args)
def predict(img_name, outdir, model, scale, mask_threshold):
in_files = img_name
out_files = outdir
net = UNet(n_channels=1, n_classes=1)
net = UNet(n_channels = 1, n_classes = 1)
logging.info("Loading model {}".format(args.model))
logging.info("Loading model {}".format(model))
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
logging.info(f'Using device {device}')
net.to(device=device)
net.load_state_dict(torch.load(args.model, map_location=device))
net.to(device = device)
net.load_state_dict(torch.load(model, map_location = device))
logging.info("Model loaded !")
......@@ -118,19 +54,10 @@ if __name__ == "__main__":
img = Image.open(fn)
mask = predict_img(net=net,
full_img=img,
scale_factor=args.scale,
out_threshold=args.mask_threshold,
device=device)
if not args.no_save:
out_fn = out_files[i]
result = mask_to_image(mask)
result.save(out_files[i])
logging.info("Mask saved to {}".format(out_files[i]))
mask = predict_img(net = net, full_img = img, scale_factor = scale, out_threshold = mask_threshold,
device = device)
#out_fn = out_files[i]
result = mask_to_image(mask)
result.save(out_files)
if args.viz:
logging.info("Visualizing results for image {}, close to continue ...".format(fn))
plot_img_and_mask(img, mask)
logging.info("Mask saved to {}".format(out_files[i]))
\ No newline at end of file
import matplotlib.pyplot as plt
def plot_img_and_mask(img, mask):
classes = mask.shape[2] if len(mask.shape) > 2 else 1
fig, ax = plt.subplots(1, classes + 1)
ax[0].set_title('Input image')
ax[0].imshow(img)
if classes > 1:
for i in range(classes):
ax[i+1].set_title(f'Output mask (class {i+1})')
ax[i+1].imshow(mask[:, :, i])
else:
ax[1].set_title(f'Output mask')
ax[1].imshow(mask)
plt.xticks([]), plt.yticks([])
plt.show()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment