Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
I
Im
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
文档服务地址:
http://47.92.0.57:3000/
周报索引地址:
http://47.92.0.57:3000/s/NruNXRYmV
Open sidebar
王肇一
Im
Commits
9ef386bf
Commit
9ef386bf
authored
Feb 02, 2020
by
王肇一
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
CE-net based module, pause for a while
parent
af27b9fa
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
64 additions
and
13 deletions
+64
-13
UnetBasedMethod.py
UnetBasedMethod.py
+7
-3
__init__.py
ceunet/__init__.py
+3
-2
ceunet_model.py
ceunet/ceunet_model.py
+5
-2
ceunet_parts.py
ceunet/ceunet_parts.py
+0
-0
train.py
train.py
+5
-6
trainCE-Net.py
trainCE-Net.py
+44
-0
No files found.
UnetBasedMethod.py
View file @
9ef386bf
...
@@ -82,13 +82,13 @@ def get_args():
...
@@ -82,13 +82,13 @@ def get_args():
return
parser
.
parse_args
()
return
parser
.
parse_args
()
if
__name__
==
'__main__'
:
def
cli_main
()
:
args
=
get_args
()
args
=
get_args
()
path
=
[(
y
,
x
)
for
y
in
filter
(
lambda
x
:
x
!=
'.DS_Store'
,
os
.
listdir
(
'data/imgs'
))
for
x
in
filter
(
path
=
[(
y
,
x
)
for
y
in
filter
(
lambda
x
:
x
!=
'.DS_Store'
,
os
.
listdir
(
'data/imgs'
))
for
x
in
filter
(
lambda
x
:
x
.
endswith
(
'.tif'
)
and
not
x
.
endswith
(
'dc.tif'
)
and
not
x
.
endswith
(
'DC.tif'
)
and
not
x
.
endswith
(
lambda
x
:
x
.
endswith
(
'.tif'
)
and
not
x
.
endswith
(
'dc.tif'
)
and
not
x
.
endswith
(
'DC.tif'
)
and
not
x
.
endswith
(
'dc .tif'
),
os
.
listdir
(
'data/imgs/'
+
y
))]
'dc .tif'
),
os
.
listdir
(
'data/imgs/'
+
y
))]
seperate_path
=
divide_list
(
path
,
args
.
process
)
seperate_path
=
divide_list
(
path
,
args
.
process
)
if
args
.
step
==
1
:
if
args
.
step
==
1
:
net
=
UNet
(
n_channels
=
1
,
n_classes
=
1
)
net
=
UNet
(
n_channels
=
1
,
n_classes
=
1
)
...
@@ -107,10 +107,13 @@ if __name__ == '__main__':
...
@@ -107,10 +107,13 @@ if __name__ == '__main__':
elif
args
.
step
==
2
:
elif
args
.
step
==
2
:
dir
=
[
x
for
x
in
filter
(
lambda
x
:
x
!=
'.DS_Store'
,
os
.
listdir
(
'data/imgs/'
))]
dir
=
[
x
for
x
in
filter
(
lambda
x
:
x
!=
'.DS_Store'
,
os
.
listdir
(
'data/imgs/'
))]
sep_dir
=
divide_list
(
dir
,
args
.
process
)
sep_dir
=
divide_list
(
dir
,
args
.
process
)
pool
=
Pool
(
args
.
process
)
pool
=
Pool
(
args
.
process
)
for
i
,
list
in
enumerate
(
sep_dir
):
for
i
,
list
in
enumerate
(
sep_dir
):
pool
.
apply_async
(
step_2
,
args
=
(
list
,
i
))
pool
.
apply_async
(
step_2
,
args
=
(
list
,
i
))
pool
.
close
()
pool
.
close
()
pool
.
join
()
pool
.
join
()
if
__name__
==
'__main__'
:
cli_main
()
\ No newline at end of file
CE-U
net/__init__.py
→
ceu
net/__init__.py
View file @
9ef386bf
#!/usr/bin/env python
#!/usr/bin/env python
# -*- coding:utf-8 -*-
# -*- coding:utf-8 -*-
\ No newline at end of file
from
.ceunet_model
import
CEUnet
\ No newline at end of file
CE-U
net/ceunet_model.py
→
ceu
net/ceunet_model.py
View file @
9ef386bf
...
@@ -15,11 +15,14 @@ class CEUnet(nn.Module):
...
@@ -15,11 +15,14 @@ class CEUnet(nn.Module):
resnet
=
models
.
resnet34
(
pretrained
=
True
)
resnet
=
models
.
resnet34
(
pretrained
=
True
)
weight
=
resnet
.
conv1
.
weight
weight
=
resnet
.
conv1
.
weight
self
.
n_channels
=
n_channels
self
.
n_classes
=
n_classes
self
.
inc
=
nn
.
Conv2d
(
in_channels
=
n_channels
,
out_channels
=
64
,
kernel_size
=
7
,
stride
=
2
,
padding
=
1
)
self
.
inc
=
nn
.
Conv2d
(
in_channels
=
n_channels
,
out_channels
=
64
,
kernel_size
=
7
,
stride
=
2
,
padding
=
1
)
self
.
inc
.
weight
=
nn
.
Parameter
(
weight
[:,
:
1
,
:,
:])
self
.
inc
.
weight
=
nn
.
Parameter
(
weight
[:,
:
1
,
:,
:])
self
.
bn1
=
resnet
.
bn1
self
.
bn1
=
resnet
.
bn1
self
.
relu
1
=
resnet
.
relu1
self
.
relu
=
resnet
.
relu
self
.
maxpool1
=
resnet
.
maxpool
self
.
maxpool1
=
resnet
.
maxpool
self
.
encoder1
=
resnet
.
layer1
self
.
encoder1
=
resnet
.
layer1
...
@@ -44,7 +47,7 @@ class CEUnet(nn.Module):
...
@@ -44,7 +47,7 @@ class CEUnet(nn.Module):
def
forward
(
self
,
x
):
def
forward
(
self
,
x
):
x
=
self
.
inc
(
x
)
x
=
self
.
inc
(
x
)
x
=
self
.
bn1
(
x
)
x
=
self
.
bn1
(
x
)
x
=
self
.
relu
1
(
x
)
x
=
self
.
relu
(
x
)
x
=
self
.
maxpool1
(
x
)
x
=
self
.
maxpool1
(
x
)
e1
=
self
.
encoder1
(
x
)
e1
=
self
.
encoder1
(
x
)
...
...
CE-U
net/ceunet_parts.py
→
ceu
net/ceunet_parts.py
View file @
9ef386bf
File moved
train.py
View file @
9ef386bf
...
@@ -57,11 +57,10 @@ def train_net(net, device, epochs = 5, batch_size = 1, lr = 0.1, val_percent = 0
...
@@ -57,11 +57,10 @@ def train_net(net, device, epochs = 5, batch_size = 1, lr = 0.1, val_percent = 0
for
batch
in
train_loader
:
for
batch
in
train_loader
:
imgs
=
batch
[
'image'
]
imgs
=
batch
[
'image'
]
true_masks
=
batch
[
'mask'
]
true_masks
=
batch
[
'mask'
]
assert
imgs
.
shape
[
assert
imgs
.
shape
[
1
]
==
net
.
n_channels
,
\
1
]
==
net
.
n_channels
,
f
'Network has been defined with {net.n_channels} input channels, '
\
f
'Network has been defined with {net.n_channels} input channels, '
\
f
'but loaded images have {imgs.shape[1]} channels. Please check that '
\
f
'but loaded images have {imgs.shape[1]} channels. Please check that '
\
f
''
\
'the images are loaded correctly.'
'the images are loaded correctly.'
imgs
=
imgs
.
to
(
device
=
device
,
dtype
=
torch
.
float32
)
imgs
=
imgs
.
to
(
device
=
device
,
dtype
=
torch
.
float32
)
mask_type
=
torch
.
float32
if
net
.
n_classes
==
1
else
torch
.
long
mask_type
=
torch
.
float32
if
net
.
n_classes
==
1
else
torch
.
long
...
@@ -153,7 +152,7 @@ if __name__ == '__main__':
...
@@ -153,7 +152,7 @@ if __name__ == '__main__':
# cudnn.benchmark = True
# cudnn.benchmark = True
try
:
try
:
train_net
(
net
=
net
,
epochs
=
args
.
epochs
,
batch_size
=
args
.
batchsize
,
lr
=
args
.
lr
,
device
=
device
,
train_net
(
net
=
net
,
device
=
device
,
epochs
=
args
.
epochs
,
batch_size
=
args
.
batchsize
,
lr
=
args
.
lr
,
img_scale
=
args
.
scale
,
val_percent
=
args
.
val
/
100
)
img_scale
=
args
.
scale
,
val_percent
=
args
.
val
/
100
)
except
KeyboardInterrupt
:
except
KeyboardInterrupt
:
torch
.
save
(
net
.
state_dict
(),
'INTERRUPTED.pth'
)
torch
.
save
(
net
.
state_dict
(),
'INTERRUPTED.pth'
)
...
...
trainCE-Net.py
0 → 100644
View file @
9ef386bf
#!/usr/bin/env python
# -*- coding:utf-8 -*-
import
argparse
import
logging
import
os
import
sys
import
torch
from
ceunet
import
CEUnet
from
train
import
get_args
,
train_net
dir_img
=
'data/train_imgs/'
dir_mask
=
'data/train_masks/'
dir_checkpoint
=
'ce_checkpoints/'
if
__name__
==
'__main__'
:
logging
.
basicConfig
(
level
=
logging
.
INFO
,
format
=
'
%(levelname)
s:
%(message)
s'
)
args
=
get_args
()
device
=
torch
.
device
(
'cuda'
if
torch
.
cuda
.
is_available
()
else
'cpu'
)
logging
.
info
(
f
'Using device {device}'
)
net
=
CEUnet
(
n_channels
=
1
,
n_classes
=
1
)
logging
.
info
(
f
'Network:
\n
'
f
'
\t
{net.n_channels} input channels
\n
'
f
'
\t
{net.n_classes} output channels (classes)
\n
'
)
if
args
.
load
:
net
.
load_state_dict
(
torch
.
load
(
args
.
load
,
map_location
=
device
))
logging
.
info
(
f
'Model loaded from {args.load}'
)
net
.
to
(
device
=
device
)
try
:
train_net
(
net
=
net
,
device
=
device
,
epochs
=
args
.
epochs
,
batch_size
=
args
.
batchsize
,
lr
=
args
.
lr
,
img_scale
=
args
.
scale
,
val_percent
=
args
.
val
/
100
)
except
KeyboardInterrupt
:
torch
.
save
(
net
.
state_dict
(),
'INTERRUPTED.pth'
)
logging
.
info
(
'Saved interrupt'
)
try
:
sys
.
exit
(
0
)
except
SystemExit
:
os
.
_exit
(
0
)
\ No newline at end of file
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment