Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
I
Im
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
王肇一
Im
Commits
8c28bb60
Commit
8c28bb60
authored
5 years ago
by
王肇一
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
mrnet now ready
parent
9b02e4db
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
21 additions
and
20 deletions
+21
-20
.gitignore
.gitignore
+2
-0
mrnet_module.py
mrnet/mrnet_module.py
+2
-2
mrnet_parts.py
mrnet/mrnet_parts.py
+6
-5
train.py
mrnet/train.py
+2
-2
train.py
train.py
+3
-4
dataset.py
utils/dataset.py
+6
-7
No files found.
.gitignore
View file @
8c28bb60
...
...
@@ -5,4 +5,6 @@ data/masks/*
data/output/*
data/train_imgs/*
data/train_masks/*
data/train_imgs_32/*
.ipynb_checkpoints/
runs
This diff is collapsed.
Click to expand it.
mrnet/mrnet_module.py
View file @
8c28bb60
...
...
@@ -14,7 +14,7 @@ class MultiUnet(nn.Module):
self
.
bilinear
=
bilinear
self
.
inconv
=
nn
.
Sequential
(
nn
.
Conv2d
(
n_channels
,
8
,
kernel_size
=
3
,
stride
=
1
,
padding
_mode
=
'same'
),
nn
.
Conv2d
(
n_channels
,
8
,
kernel_size
=
3
,
stride
=
1
,
padding
=
1
),
nn
.
BatchNorm2d
(
8
),
nn
.
ReLU
(
inplace
=
True
)
)
...
...
@@ -41,7 +41,7 @@ class MultiUnet(nn.Module):
self
.
pool
=
nn
.
MaxPool2d
(
2
)
self
.
outconv
=
nn
.
Conv2d
(
self
.
res9
.
outc
,
n_classes
,
kernel_size
=
3
,
padding
_mode
=
'same'
)
self
.
outconv
=
nn
.
Conv2d
(
self
.
res9
.
outc
,
n_classes
,
kernel_size
=
3
,
padding
=
1
)
def
forward
(
self
,
x
):
x
=
self
.
inconv
(
x
)
...
...
This diff is collapsed.
Click to expand it.
mrnet/mrnet_parts.py
View file @
8c28bb60
...
...
@@ -16,8 +16,8 @@ class MultiResBlock(nn.Module):
def
conv
(
self
,
in_channel
,
out_channel
,
kernel_size
):
return
nn
.
Sequential
(
nn
.
Conv2d
(
in_channels
=
in_channel
,
out_channels
=
out_channel
,
kernel_size
=
kernel_size
,
stride
=
1
,
padding
=
1
,
padding_mode
=
'same'
),
nn
.
BatchNorm2d
(
out_channel
),
nn
.
ReLU
(
inplace
=
True
))
nn
.
Conv2d
(
in_channels
=
in_channel
,
out_channels
=
out_channel
,
kernel_size
=
kernel_size
,
stride
=
1
,
padding
=
1
)
,
nn
.
BatchNorm2d
(
out_channel
),
nn
.
ReLU
(
inplace
=
True
))
def
forward
(
self
,
x
):
shortcut
=
nn
.
Sequential
(
...
...
@@ -41,9 +41,9 @@ class TransCompose(nn.Module):
self
.
inc
=
in_channels
self
.
outc
=
out_channels
def
forward
(
self
,
x
):
def
forward
(
self
,
x
):
return
nn
.
Sequential
(
nn
.
ConvTranspose2d
(
in_channels
=
self
.
inc
,
out_channels
=
self
.
outc
,
kernel_size
=
3
,
stride
=
2
),
nn
.
ConvTranspose2d
(
in_channels
=
self
.
inc
,
out_channels
=
self
.
outc
,
kernel_size
=
2
,
stride
=
2
),
nn
.
BatchNorm2d
(
self
.
outc
)
)(
x
)
...
...
@@ -61,7 +61,8 @@ class ResPath(nn.Module):
nn
.
BatchNorm2d
(
self
.
outc
))(
x
)
conv
=
nn
.
Sequential
(
nn
.
Conv2d
(
in_channels
=
in_channels
,
out_channels
=
out_channels
,
kernel_size
=
3
,
stride
=
1
,
padding
=
1
),
nn
.
BatchNorm2d
(
out_channels
),
nn
.
ReLU
(
inplace
=
True
))(
x
)
nn
.
BatchNorm2d
(
out_channels
),
nn
.
ReLU
(
inplace
=
True
))(
x
)
result
=
torch
.
add
(
conv
,
shortcut
)
return
nn
.
Sequential
(
nn
.
ReLU
(
inplace
=
True
),
nn
.
BatchNorm2d
(
out_channels
))(
result
)
...
...
This diff is collapsed.
Click to expand it.
mrnet/train.py
View file @
8c28bb60
...
...
@@ -18,7 +18,7 @@ dir_mask = 'data/train_masks/'
dir_checkpoint
=
'checkpoint'
def
train_net
(
net
,
device
,
epochs
=
5
,
batch_size
=
1
,
lr
=
0.1
,
val_percent
=
0.1
,
save_cp
=
True
,
img_scale
=
0.5
):
def
train_net
(
net
,
device
,
epochs
=
5
,
batch_size
=
1
,
lr
=
0.1
,
val_percent
=
0.1
,
img_scale
=
0.5
):
dataset
=
BasicDataset
(
dir_img
,
dir_mask
,
img_scale
)
n_val
=
int
(
len
(
dataset
)
*
val_percent
)
n_train
=
len
(
dataset
)
-
n_val
...
...
@@ -27,7 +27,7 @@ def train_net(net, device, epochs = 5, batch_size = 1, lr = 0.1, val_percent = 0
val_loader
=
DataLoader
(
val
,
batch_size
=
batch_size
,
shuffle
=
False
,
num_workers
=
8
,
pin_memory
=
True
)
optimizer
=
optim
.
Adam
(
net
.
parameters
(),
lr
=
lr
)
criterion
=
nn
.
BCELoss
()
criterion
=
nn
.
BCE
WithLogits
Loss
()
for
epoch
in
tqdm
(
range
(
epochs
)):
net
.
train
()
...
...
This diff is collapsed.
Click to expand it.
train.py
View file @
8c28bb60
...
...
@@ -40,14 +40,13 @@ if __name__ == '__main__':
device
=
torch
.
device
(
'cuda'
if
torch
.
cuda
.
is_available
()
else
'cpu'
)
logging
.
info
(
f
'Using device {device}'
)
# Change here to adapt to your data
# n_channels=3 for RGB images
# n_classes is the number of probabilities you want to get per pixel
# - For 1 class and background, use n_classes=1
# - For 2 classes, use n_classes=1
# - For N > 2 classes, use n_classes=N
#net = UNet(n_channels = 1, n_classes = 1)
#
net = UNet(n_channels = 1, n_classes = 1)
net
=
MultiUnet
(
n_channels
=
1
,
n_classes
=
1
)
logging
.
info
(
f
'Network:
\n
'
...
...
@@ -64,8 +63,8 @@ if __name__ == '__main__':
# cudnn.benchmark = True
try
:
mrnet
.
train_net
(
net
=
net
,
device
=
device
,
epochs
=
args
.
epochs
,
batch_size
=
args
.
batchsize
,
lr
=
args
.
lr
,
img_scale
=
args
.
scale
,
val_percent
=
args
.
val
/
100
)
mrnet
.
train_net
(
net
=
net
,
device
=
device
,
epochs
=
args
.
epochs
,
batch_size
=
args
.
batchsize
,
lr
=
args
.
lr
,
img_scale
=
args
.
scale
,
val_percent
=
args
.
val
/
100
)
except
KeyboardInterrupt
:
torch
.
save
(
net
.
state_dict
(),
'INTERRUPTED.pth'
)
logging
.
info
(
'Saved interrupt'
)
...
...
This diff is collapsed.
Click to expand it.
utils/dataset.py
View file @
8c28bb60
...
...
@@ -23,11 +23,10 @@ class BasicDataset(Dataset):
return
len
(
self
.
ids
)
@classmethod
def
preprocess
(
cls
,
pil_img
,
scale
):
w
,
h
=
pil_img
.
size
newW
,
newH
=
256
,
256
#int(scale * w), int(scale * h)
assert
newW
>
0
and
newH
>
0
,
'Scale is too small'
pil_img
=
pil_img
.
resize
((
newW
,
newH
))
def
preprocess
(
cls
,
pil_img
):
#newW, newH = 256,256 #int(scale * w), int(scale * h)
#assert newW > 0 and newH > 0, 'Scale is too small'
pil_img
=
pil_img
.
resize
((
256
,
256
))
img_nd
=
np
.
array
(
pil_img
)
...
...
@@ -56,7 +55,7 @@ class BasicDataset(Dataset):
assert
img
.
size
==
mask
.
size
,
\
f
'Image and mask {idx} should be the same size, but are {img.size} and {mask.size}'
img
=
self
.
preprocess
(
img
,
self
.
scale
)
mask
=
self
.
preprocess
(
mask
,
self
.
scale
)
img
=
self
.
preprocess
(
img
)
mask
=
self
.
preprocess
(
mask
)
return
{
'image'
:
torch
.
from_numpy
(
img
),
'mask'
:
torch
.
from_numpy
(
mask
)}
This diff is collapsed.
Click to expand it.
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment