Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
I
Im
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
王肇一
Im
Commits
f9f59944
Commit
f9f59944
authored
Jan 09, 2020
by
王肇一
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
Unet now functional ready for test
parent
b835e670
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
41 additions
and
259 deletions
+41
-259
README.md
README.md
+9
-8
MODEL.pth
data/module/MODEL.pth
+0
-0
filter.java
imageJ/filter.java
+0
-141
main.py
main.py
+14
-3
predict.py
predict.py
+18
-90
data_vis.py
utils/data_vis.py
+0
-17
quality.py
utils/quality.py
+0
-0
No files found.
README.md
View file @
f9f59944
# 使用说明
# 使用说明
(编写未完成)
## 安装依赖
`pip install -r requirement.txt`
## 目录结构
```
.
```
.
├── README.md
├── data
│ ├── module
...
...
@@ -21,7 +22,7 @@
```
## 主程序入口
### 参数
*
-m ,--method : 0 使用Kmeans,1 使用阈值法(butterworth滤波),2 使用阈值法(fft)。默认Kmeans
*
-m ,--method : 0 使用Kmeans,1 使用阈值法(butterworth滤波),2 使用阈值法(fft)
,3 使用unet
。默认Kmeans
*
-c ,--core : Kmeans分为几类,默认5,仅对Kmeans法有效
*
-p ,--process : 使用线程数量,默认8,仅对阈值法有效
...
...
@@ -33,8 +34,8 @@
## 使用Unet模型
### 训练
```
shell script
>
python train.py -h
```
python train.py -h
usage: train.py [-h] [-e E] [-b [B]] [-l [LR]] [-f LOAD] [-s SCALE] [-v VAL]
Train the UNet on images and target masks
...
...
@@ -52,16 +53,16 @@ optional arguments:
-v VAL, --validation VAL
Percent of the data that is used as validation (0-100)
(default: 15.0)
```
训练后将生成一个checkout目录,存放所得模型。可根据需要,选择保留或删除。
### 监控
使用tensorboard可视化监控
`tensorboard --logdir=runs`
## 数据
输入图像存放于imgs文件夹下,单通道灰度图像,8bit,200
*
200
标注的mask存放于masks文件夹下,像素0为背景,1为目标
-
用于训练模型或用于提取信号的
输入图像存放于imgs文件夹下,单通道灰度图像,8bit,200
*
200
-
标注的mask存放于masks文件夹下,像素0为背景,1为目标
## cli工具
部分可能用得到的cli工具,主要为python或shell脚本。可根据需要自行修改。
...
...
data/module/MODEL
_52
.pth
→
data/module/MODEL.pth
View file @
f9f59944
File moved
imageJ/filter.java
deleted
100644 → 0
View file @
b835e670
void
filterLargeSmall
(
ImageProcessor
ip
,
double
filterLarge
,
double
filterSmall
,
int
stripesHorVert
,
double
scaleStripes
)
{
int
maxN
=
ip
.
getWidth
();
float
[]
fht
=
(
float
[])
ip
.
getPixels
();
float
[]
filter
=
new
float
[
maxN
*
maxN
];
for
(
int
i
=
0
;
i
<
maxN
*
maxN
;
i
++)
filter
[
i
]=
1
f
;
int
row
;
int
backrow
;
float
rowFactLarge
;
float
rowFactSmall
;
int
col
;
int
backcol
;
float
factor
;
float
colFactLarge
;
float
colFactSmall
;
float
factStripes
;
// calculate factor in exponent of Gaussian from filterLarge / filterSmall
double
scaleLarge
=
filterLarge
*
filterLarge
;
double
scaleSmall
=
filterSmall
*
filterSmall
;
scaleStripes
=
scaleStripes
*
scaleStripes
;
//float FactStripes;
// loop over rows
for
(
int
j
=
1
;
j
<
maxN
/
2
;
j
++)
{
row
=
j
*
maxN
;
backrow
=
(
maxN
-
j
)*
maxN
;
rowFactLarge
=
(
float
)
Math
.
exp
(-(
j
*
j
)
*
scaleLarge
);
rowFactSmall
=
(
float
)
Math
.
exp
(-(
j
*
j
)
*
scaleSmall
);
// loop over columns
for
(
col
=
1
;
col
<
maxN
/
2
;
col
++){
backcol
=
maxN
-
col
;
colFactLarge
=
(
float
)
Math
.
exp
(-
(
col
*
col
)
*
scaleLarge
);
colFactSmall
=
(
float
)
Math
.
exp
(-
(
col
*
col
)
*
scaleSmall
);
factor
=
(
1
-
rowFactLarge
*
colFactLarge
)
*
rowFactSmall
*
colFactSmall
;
switch
(
stripesHorVert
)
{
case
1
:
factor
*=
(
1
-
(
float
)
Math
.
exp
(-
(
col
*
col
)
*
scaleStripes
));
break
;
// hor stripes
case
2
:
factor
*=
(
1
-
(
float
)
Math
.
exp
(-
(
j
*
j
)
*
scaleStripes
));
// vert stripes
}
fht
[
col
+
row
]
*=
factor
;
fht
[
col
+
backrow
]
*=
factor
;
fht
[
backcol
+
row
]
*=
factor
;
fht
[
backcol
+
backrow
]
*=
factor
;
filter
[
col
+
row
]
*=
factor
;
filter
[
col
+
backrow
]
*=
factor
;
filter
[
backcol
+
row
]
*=
factor
;
filter
[
backcol
+
backrow
]
*=
factor
;
}
}
//process meeting points (maxN/2,0) , (0,maxN/2), and (maxN/2,maxN/2)
int
rowmid
=
maxN
*
(
maxN
/
2
);
rowFactLarge
=
(
float
)
Math
.
exp
(-
(
maxN
/
2
)*(
maxN
/
2
)
*
scaleLarge
);
rowFactSmall
=
(
float
)
Math
.
exp
(-
(
maxN
/
2
)*(
maxN
/
2
)
*
scaleSmall
);
factStripes
=
(
float
)
Math
.
exp
(-
(
maxN
/
2
)*(
maxN
/
2
)
*
scaleStripes
);
fht
[
maxN
/
2
]
*=
(
1
-
rowFactLarge
)
*
rowFactSmall
;
// (maxN/2,0)
fht
[
rowmid
]
*=
(
1
-
rowFactLarge
)
*
rowFactSmall
;
// (0,maxN/2)
fht
[
maxN
/
2
+
rowmid
]
*=
(
1
-
rowFactLarge
*
rowFactLarge
)
*
rowFactSmall
*
rowFactSmall
;
// (maxN/2,maxN/2)
filter
[
maxN
/
2
]
*=
(
1
-
rowFactLarge
)
*
rowFactSmall
;
// (maxN/2,0)
filter
[
rowmid
]
*=
(
1
-
rowFactLarge
)
*
rowFactSmall
;
// (0,maxN/2)
filter
[
maxN
/
2
+
rowmid
]
*=
(
1
-
rowFactLarge
*
rowFactLarge
)
*
rowFactSmall
*
rowFactSmall
;
// (maxN/2,maxN/2)
switch
(
stripesHorVert
)
{
case
1
:
fht
[
maxN
/
2
]
*=
(
1
-
factStripes
);
fht
[
rowmid
]
=
0
;
fht
[
maxN
/
2
+
rowmid
]
*=
(
1
-
factStripes
);
filter
[
maxN
/
2
]
*=
(
1
-
factStripes
);
filter
[
rowmid
]
=
0
;
filter
[
maxN
/
2
+
rowmid
]
*=
(
1
-
factStripes
);
break
;
// hor stripes
case
2
:
fht
[
maxN
/
2
]
=
0
;
fht
[
rowmid
]
*=
(
1
-
factStripes
);
fht
[
maxN
/
2
+
rowmid
]
*=
(
1
-
factStripes
);
filter
[
maxN
/
2
]
=
0
;
filter
[
rowmid
]
*=
(
1
-
factStripes
);
filter
[
maxN
/
2
+
rowmid
]
*=
(
1
-
factStripes
);
break
;
// vert stripes
}
//loop along row 0 and maxN/2
rowFactLarge
=
(
float
)
Math
.
exp
(-
(
maxN
/
2
)*(
maxN
/
2
)
*
scaleLarge
);
rowFactSmall
=
(
float
)
Math
.
exp
(-
(
maxN
/
2
)*(
maxN
/
2
)
*
scaleSmall
);
for
(
col
=
1
;
col
<
maxN
/
2
;
col
++){
backcol
=
maxN
-
col
;
colFactLarge
=
(
float
)
Math
.
exp
(-
(
col
*
col
)
*
scaleLarge
);
colFactSmall
=
(
float
)
Math
.
exp
(-
(
col
*
col
)
*
scaleSmall
);
switch
(
stripesHorVert
)
{
case
0
:
fht
[
col
]
*=
(
1
-
colFactLarge
)
*
colFactSmall
;
fht
[
backcol
]
*=
(
1
-
colFactLarge
)
*
colFactSmall
;
fht
[
col
+
rowmid
]
*=
(
1
-
colFactLarge
*
rowFactLarge
)
*
colFactSmall
*
rowFactSmall
;
fht
[
backcol
+
rowmid
]
*=
(
1
-
colFactLarge
*
rowFactLarge
)
*
colFactSmall
*
rowFactSmall
;
filter
[
col
]
*=
(
1
-
colFactLarge
)
*
colFactSmall
;
filter
[
backcol
]
*=
(
1
-
colFactLarge
)
*
colFactSmall
;
filter
[
col
+
rowmid
]
*=
(
1
-
colFactLarge
*
rowFactLarge
)
*
colFactSmall
*
rowFactSmall
;
filter
[
backcol
+
rowmid
]
*=
(
1
-
colFactLarge
*
rowFactLarge
)
*
colFactSmall
*
rowFactSmall
;
break
;
}
}
// loop along column 0 and maxN/2
colFactLarge
=
(
float
)
Math
.
exp
(-
(
maxN
/
2
)*(
maxN
/
2
)
*
scaleLarge
);
colFactSmall
=
(
float
)
Math
.
exp
(-
(
maxN
/
2
)*(
maxN
/
2
)
*
scaleSmall
);
for
(
int
j
=
1
;
j
<
maxN
/
2
;
j
++)
{
row
=
j
*
maxN
;
backrow
=
(
maxN
-
j
)*
maxN
;
rowFactLarge
=
(
float
)
Math
.
exp
(-
(
j
*
j
)
*
scaleLarge
);
rowFactSmall
=
(
float
)
Math
.
exp
(-
(
j
*
j
)
*
scaleSmall
);
switch
(
stripesHorVert
)
{
case
0
:
fht
[
row
]
*=
(
1
-
rowFactLarge
)
*
rowFactSmall
;
fht
[
backrow
]
*=
(
1
-
rowFactLarge
)
*
rowFactSmall
;
fht
[
row
+
maxN
/
2
]
*=
(
1
-
rowFactLarge
*
colFactLarge
)
*
rowFactSmall
*
colFactSmall
;
fht
[
backrow
+
maxN
/
2
]
*=
(
1
-
rowFactLarge
*
colFactLarge
)
*
rowFactSmall
*
colFactSmall
;
filter
[
row
]
*=
(
1
-
rowFactLarge
)
*
rowFactSmall
;
filter
[
backrow
]
*=
(
1
-
rowFactLarge
)
*
rowFactSmall
;
filter
[
row
+
maxN
/
2
]
*=
(
1
-
rowFactLarge
*
colFactLarge
)
*
rowFactSmall
*
colFactSmall
;
filter
[
backrow
+
maxN
/
2
]
*=
(
1
-
rowFactLarge
*
colFactLarge
)
*
rowFactSmall
*
colFactSmall
;
break
;
}
}
if
(
displayFilter
&&
slice
==
1
)
{
FHT
f
=
new
FHT
(
new
FloatProcessor
(
maxN
,
maxN
,
filter
,
null
));
f
.
swapQuadrants
();
new
ImagePlus
(
"Filter"
,
f
).
show
();
}
}
\ No newline at end of file
main.py
View file @
f9f59944
...
...
@@ -6,6 +6,7 @@ import argparse
from
cvBasedMethod.util
import
*
from
cvBasedMethod.kmeans
import
kmeans
,
kmeans_back
from
cvBasedMethod.threshold
import
threshold
from
predict
import
predict
def
method_kmeans
(
imglist
,
core
=
5
):
...
...
@@ -36,17 +37,27 @@ def get_args():
parser
.
add_argument
(
'-c'
,
'--core'
,
metavar
=
'C'
,
type
=
int
,
default
=
5
,
help
=
'Num of cluster'
,
dest
=
'core'
)
parser
.
add_argument
(
'-p'
,
'--process'
,
metavar
=
'P'
,
type
=
int
,
default
=
8
,
help
=
'Num of process'
,
dest
=
'process'
)
# Unet para
parser
.
add_argument
(
'--load'
,
'-L'
,
default
=
'data/module/MODEL.pth'
,
metavar
=
'FILE'
,
help
=
"Specify the file in which the model is stored"
)
parser
.
add_argument
(
'--mask-threshold'
,
'-t'
,
type
=
float
,
help
=
"Minimum probability value to consider a mask pixel white"
,
default
=
0.5
)
parser
.
add_argument
(
'--scale'
,
'-s'
,
type
=
float
,
help
=
"Scale factor for the input images"
,
default
=
0.5
)
return
parser
.
parse_args
()
if
__name__
==
'__main__'
:
args
=
get_args
()
path
=
[(
y
,
x
)
for
y
in
os
.
listdir
(
args
.
dir
)
for
x
in
filter
(
lambda
x
:
x
.
endswith
(
'.tif'
)
and
not
x
.
endswith
(
'dc.tif'
)
and
not
x
.
endswith
(
'DC.tif'
)
and
not
x
.
endswith
(
'dc .tif'
),
os
.
listdir
(
'img/'
+
y
))]
path
=
[
'data/imgs/'
+
x
for
x
in
os
.
listdir
(
'data/imgs'
)]
if
args
.
method
==
0
:
method_kmeans
(
path
,
args
.
core
)
elif
args
.
method
==
1
:
method_threshold
(
path
,
args
.
process
)
elif
args
.
method
==
2
:
method_newThreshold
(
path
,
args
.
process
)
elif
args
.
method
==
3
:
predict
(
path
,
[
'data/output/imgs/'
+
name
[
10
:]
for
name
in
path
],
args
.
load
,
args
.
scale
,
args
.
mask_threshold
)
predict.py
View file @
f9f59944
import
argparse
import
logging
import
os
import
numpy
as
np
import
torch
import
torch.nn.functional
as
F
...
...
@@ -9,107 +6,46 @@ from PIL import Image
from
torchvision
import
transforms
from
unet
import
UNet
from
utils.data_vis
import
plot_img_and_mask
from
utils.dataset
import
BasicDataset
def
predict_img
(
net
,
full_img
,
device
,
scale_factor
=
1
,
out_threshold
=
0.5
):
def
predict_img
(
net
,
full_img
,
device
,
scale_factor
=
1
,
out_threshold
=
0.5
):
net
.
eval
()
img
=
torch
.
from_numpy
(
BasicDataset
.
preprocess
(
full_img
,
scale_factor
))
img
=
img
.
unsqueeze
(
0
)
img
=
img
.
to
(
device
=
device
,
dtype
=
torch
.
float32
)
img
=
img
.
to
(
device
=
device
,
dtype
=
torch
.
float32
)
with
torch
.
no_grad
():
output
=
net
(
img
)
if
net
.
n_classes
>
1
:
probs
=
F
.
softmax
(
output
,
dim
=
1
)
probs
=
F
.
softmax
(
output
,
dim
=
1
)
else
:
probs
=
torch
.
sigmoid
(
output
)
probs
=
probs
.
squeeze
(
0
)
tf
=
transforms
.
Compose
(
[
transforms
.
ToPILImage
(),
transforms
.
Resize
(
full_img
.
size
[
1
]),
transforms
.
ToTensor
()
]
)
tf
=
transforms
.
Compose
([
transforms
.
ToPILImage
(),
transforms
.
Resize
(
full_img
.
size
[
1
]),
transforms
.
ToTensor
()])
probs
=
tf
(
probs
.
cpu
())
full_mask
=
probs
.
squeeze
()
.
cpu
()
.
numpy
()
return
full_mask
>
out_threshold
def
get_args
():
parser
=
argparse
.
ArgumentParser
(
description
=
'Predict masks from input images'
,
formatter_class
=
argparse
.
ArgumentDefaultsHelpFormatter
)
parser
.
add_argument
(
'--model'
,
'-m'
,
default
=
'MODEL.pth'
,
metavar
=
'FILE'
,
help
=
"Specify the file in which the model is stored"
)
parser
.
add_argument
(
'--input'
,
'-i'
,
metavar
=
'INPUT'
,
nargs
=
'+'
,
help
=
'filenames of input images'
,
required
=
True
)
parser
.
add_argument
(
'--output'
,
'-o'
,
metavar
=
'INPUT'
,
nargs
=
'+'
,
help
=
'Filenames of ouput images'
)
parser
.
add_argument
(
'--viz'
,
'-v'
,
action
=
'store_true'
,
help
=
"Visualize the images as they are processed"
,
default
=
False
)
parser
.
add_argument
(
'--no-save'
,
'-n'
,
action
=
'store_true'
,
help
=
"Do not save the output masks"
,
default
=
False
)
parser
.
add_argument
(
'--mask-threshold'
,
'-t'
,
type
=
float
,
help
=
"Minimum probability value to consider a mask pixel white"
,
default
=
0.5
)
parser
.
add_argument
(
'--scale'
,
'-s'
,
type
=
float
,
help
=
"Scale factor for the input images"
,
default
=
0.5
)
return
parser
.
parse_args
()
def
get_output_filenames
(
args
):
in_files
=
args
.
input
out_files
=
[]
if
not
args
.
output
:
for
f
in
in_files
:
pathsplit
=
os
.
path
.
splitext
(
f
)
out_files
.
append
(
"{}_OUT{}"
.
format
(
pathsplit
[
0
],
pathsplit
[
1
]))
elif
len
(
in_files
)
!=
len
(
args
.
output
):
logging
.
error
(
"Input files and output files are not of the same length"
)
raise
SystemExit
()
else
:
out_files
=
args
.
output
return
out_files
def
mask_to_image
(
mask
):
return
Image
.
fromarray
((
mask
*
255
)
.
astype
(
np
.
uint8
))
if
__name__
==
"__main__"
:
args
=
get_args
()
in_files
=
args
.
input
out_files
=
get_output_filenames
(
args
)
def
predict
(
img_name
,
outdir
,
model
,
scale
,
mask_threshold
):
in_files
=
img_name
out_files
=
outdir
net
=
UNet
(
n_channels
=
1
,
n_classes
=
1
)
net
=
UNet
(
n_channels
=
1
,
n_classes
=
1
)
logging
.
info
(
"Loading model {}"
.
format
(
args
.
model
))
logging
.
info
(
"Loading model {}"
.
format
(
model
))
device
=
torch
.
device
(
'cuda'
if
torch
.
cuda
.
is_available
()
else
'cpu'
)
logging
.
info
(
f
'Using device {device}'
)
net
.
to
(
device
=
device
)
net
.
load_state_dict
(
torch
.
load
(
args
.
model
,
map_location
=
device
))
net
.
to
(
device
=
device
)
net
.
load_state_dict
(
torch
.
load
(
model
,
map_location
=
device
))
logging
.
info
(
"Model loaded !"
)
...
...
@@ -118,19 +54,10 @@ if __name__ == "__main__":
img
=
Image
.
open
(
fn
)
mask
=
predict_img
(
net
=
net
,
full_img
=
img
,
scale_factor
=
args
.
scale
,
out_threshold
=
args
.
mask_threshold
,
device
=
device
)
if
not
args
.
no_save
:
out_fn
=
out_files
[
i
]
result
=
mask_to_image
(
mask
)
result
.
save
(
out_files
[
i
])
logging
.
info
(
"Mask saved to {}"
.
format
(
out_files
[
i
]))
mask
=
predict_img
(
net
=
net
,
full_img
=
img
,
scale_factor
=
scale
,
out_threshold
=
mask_threshold
,
device
=
device
)
#out_fn = out_files[i]
result
=
mask_to_image
(
mask
)
result
.
save
(
out_files
)
if
args
.
viz
:
logging
.
info
(
"Visualizing results for image {}, close to continue ..."
.
format
(
fn
))
plot_img_and_mask
(
img
,
mask
)
logging
.
info
(
"Mask saved to {}"
.
format
(
out_files
[
i
]))
\ No newline at end of file
utils/data_vis.py
deleted
100644 → 0
View file @
b835e670
import
matplotlib.pyplot
as
plt
def
plot_img_and_mask
(
img
,
mask
):
classes
=
mask
.
shape
[
2
]
if
len
(
mask
.
shape
)
>
2
else
1
fig
,
ax
=
plt
.
subplots
(
1
,
classes
+
1
)
ax
[
0
]
.
set_title
(
'Input image'
)
ax
[
0
]
.
imshow
(
img
)
if
classes
>
1
:
for
i
in
range
(
classes
):
ax
[
i
+
1
]
.
set_title
(
f
'Output mask (class {i+1})'
)
ax
[
i
+
1
]
.
imshow
(
mask
[:,
:,
i
])
else
:
ax
[
1
]
.
set_title
(
f
'Output mask'
)
ax
[
1
]
.
imshow
(
mask
)
plt
.
xticks
([]),
plt
.
yticks
([])
plt
.
show
()
cli
/quality.py
→
utils
/quality.py
View file @
f9f59944
File moved
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment