Commit 1cab996a by 王肇一

generate mask avoid real signal

parent 9a523da8
#!/usr/bin/env python
# -*- coding:utf-8 -*-
import os
import cv2 as cv
import argparse
def get_args():
parser = argparse.ArgumentParser(description = 'Identify targets from background by KMeans or Threshold',
formatter_class = argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument('-i', '--input', metavar = 'I', type = str, default = './',
help = 'input_dir', dest = 'input')
parser.add_argument('-o','--output',metavar = 'O',type = str,default = './out/',help='output dir',dest = 'output')
return parser.parse_args()
args = get_args()
os.mkdir(args.output)
for name in os.listdir(args.input):
img = cv.imread(args.input+name,flags = cv.IMREAD_GRAYSCALE)
cv.imwrite(args.output+name,img)
\ No newline at end of file
...@@ -4,8 +4,10 @@ import numpy as np ...@@ -4,8 +4,10 @@ import numpy as np
import cv2 as cv import cv2 as cv
import pandas as pd import pandas as pd
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
import seaborn as sns
import logging import logging
import os import os
import re
from cvBasedMethod.filters import fft_mask,butterworth from cvBasedMethod.filters import fft_mask,butterworth
def remove_scratch(img): def remove_scratch(img):
...@@ -71,32 +73,47 @@ def save_img(img_list, dir, name): ...@@ -71,32 +73,47 @@ def save_img(img_list, dir, name):
plt.title(title) plt.title(title)
plt.imshow(img, 'gray') plt.imshow(img, 'gray')
try: try:
os.makedirs('out/' + dir + '/graph') os.makedirs('data/output/' + dir + '/graph')
except: except:
logging.info('Existing dir: out/' + dir + '/graph') logging.info('Existing dir: data/output/' + dir )
plt.savefig('out/' + dir + '/graph/' + name + '.png') plt.savefig('data/output/' + dir + '/graph/' + name[:-4] + '.png')
plt.close() plt.close()
def calcRes(img, mask, dir = 'output', name = 'output'): def calcRes(img, mask, dir = 'default_output', name = 'output'):
dic = get_subarea_infor(img, mask) dic = get_subarea_infor(img, mask)
df = pd.DataFrame(dic) df = pd.DataFrame(dic)
try: try:
os.makedirs('out/' + dir + '/csv') os.makedirs('data/output/' + dir + '/csv')
except: except:
logging.info('Existing dir: out/' + dir + '/csv') logging.info('Existing dir: out/' + dir + '/csv')
df.to_csv('out/' + dir + '/csv/' + name + '.csv') if len(df)!=0:
df.to_csv('data/output/' + dir + '/csv/' + name + '.csv',index = False)
def draw_bar(exName,names):
df = pd.DataFrame(columns = ('class', 'perc', 'Label', 'Area', 'Mean', 'Std', 'BackMean', 'BackStd'))
for name in names:
tmp = pd.read_csv('data/output/' + exName + '/csv/' + name)
match_group = re.match('.*\s([dD]2[oO]|[lL][bB]|.*ug).*\s(.+)\.csv', name)
tmp['perc'] = str.lower(match_group.group(1))[:-2] if str.lower(match_group.group(1)).endswith('ug') else str.lower(
match_group.group(1))
tmp['perc'].replace({'d2o':'0'},inplace = True)
tmp['class'] = str.lower(match_group.group(2))
df = df.append(tmp, ignore_index = True, sort = True)
df = df[df['Area']>19]
df['Pure'] = df['Mean'] - df['BackMean']
sns.set_style("darkgrid")
#sns.catplot(x = 'perc',y = 'Mean',hue = 'class',kind='bar',data = df)
sns.pairplot(df, vars = ['Area', 'Mean', 'perc','class'])
plt.show()
def get_subarea_infor(img, mask): def get_subarea_infor(img, mask):
area_num, labels, stats, centroids = cv.connectedComponentsWithStats(mask) area_num, labels, stats, centroids = cv.connectedComponentsWithStats(mask)
info = []
label_group = []
area_group = []
mean_group = []
std_group = []
back_mean = []
back_std = []
for i in filter(lambda x: x != 0, range(area_num)): for i in filter(lambda x: x != 0, range(area_num)):
group = np.where(labels == i) group = np.where(labels == i)
...@@ -113,19 +130,16 @@ def get_subarea_infor(img, mask): ...@@ -113,19 +130,16 @@ def get_subarea_infor(img, mask):
res[x, y] = mask[x, y] res[x, y] = mask[x, y]
else: else:
res[x, y] = 0 res[x, y] = 0
kernel = np.ones((17, 17), np.uint8) kernel = np.ones((17, 17), np.uint8)
mask_background = cv.erode(255 - res, kernel) mask_background = cv.erode(255 - res, kernel)
minimask = cv.bitwise_xor(mask_background, 255 - res) minimask = cv.bitwise_xor(mask_background, 255 - res)
realminimask = cv.bitwise_and(minimask, 255 - mask)
img_background = img[np.where(minimask != 0)] img_background = img[np.where(realminimask != 0)]
mean_value = np.mean(img_background) mean_value = np.mean(img_background)
std_value = np.std(img_background) std_value = np.std(img_background)
label_group.append(i) info.append({'Label': i, 'Area': area_tmp, 'Mean': mean_tmp, 'Std': std_tmp, 'BackMean': mean_value,
area_group.append(area_tmp) 'BackStd': std_value})
mean_group.append(mean_tmp) return info
std_group.append(std_tmp) \ No newline at end of file
back_mean.append(mean_value)
back_std.append(std_value)
return {'Label': label_group, 'Area': area_group, 'Mean': mean_group, 'Std': std_group, 'BackMean': back_mean,
'BackStd': back_std}
...@@ -24,7 +24,7 @@ def method_threshold(imglist, process = 8): ...@@ -24,7 +24,7 @@ def method_threshold(imglist, process = 8):
def method_newThreshold(imglist, process = 8): def method_newThreshold(imglist, process = 8):
pool = Pool(process) pool = Pool(process)
pool.map(lambda x: threshold(x,'fft'), imglist) pool.map(lambda x: threshold(x, 'fft'), imglist)
pool.close() pool.close()
pool.join() pool.join()
...@@ -32,8 +32,9 @@ def method_newThreshold(imglist, process = 8): ...@@ -32,8 +32,9 @@ def method_newThreshold(imglist, process = 8):
def get_args(): def get_args():
parser = argparse.ArgumentParser(description = 'Identify targets from background by KMeans or Threshold', parser = argparse.ArgumentParser(description = 'Identify targets from background by KMeans or Threshold',
formatter_class = argparse.ArgumentDefaultsHelpFormatter) formatter_class = argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument('-m', '--method', metavar = 'M', type = int, default = 0, parser.add_argument('-m', '--method', metavar = 'M', type = int, default = 3,
help = '0 for KMeans; 1 for Threshold; 2 for newThreshold', dest = 'method') help = '0 for KMeans; 1 for Threshold; 2 for newThreshold; 3 for Unet; 4 for further process',
dest = 'method')
parser.add_argument('-c', '--core', metavar = 'C', type = int, default = 5, help = 'Num of cluster', dest = 'core') parser.add_argument('-c', '--core', metavar = 'C', type = int, default = 5, help = 'Num of cluster', dest = 'core')
parser.add_argument('-p', '--process', metavar = 'P', type = int, default = 8, help = 'Num of process', parser.add_argument('-p', '--process', metavar = 'P', type = int, default = 8, help = 'Num of process',
dest = 'process') dest = 'process')
...@@ -48,16 +49,18 @@ def get_args(): ...@@ -48,16 +49,18 @@ def get_args():
return parser.parse_args() return parser.parse_args()
if __name__ == '__main__': if __name__ == '__main__':
args = get_args() args = get_args()
path = ['data/imgs/'+x for x in os.listdir('data/imgs')] path = [(y, x) for y in filter(lambda x: x != '.DS_Store', os.listdir('data/imgs')) for x in filter(
lambda x: x.endswith('.tif') and not x.endswith('dc.tif') and not x.endswith('DC.tif') and not x.endswith(
'dc .tif'), os.listdir('data/imgs/' + y))]
if args.method == 0: if args.method == 0:
method_kmeans(path,args.core) method_kmeans(path, args.core)
elif args.method == 1: elif args.method == 1:
method_threshold(path,args.process) method_threshold(path, args.process)
elif args.method == 2: elif args.method == 2:
method_newThreshold(path,args.process) method_newThreshold(path, args.process)
elif args.method == 3: elif args.method == 3:
predict(path, ['data/output/imgs/'+name[10:] for name in path], args.load, args.scale, args.mask_threshold) predict(path, args.load, args.scale, args.mask_threshold)
for exName in filter(lambda x: x != '.DS_Store', os.listdir('data/output')):
draw_bar(exName, os.listdir('data/output/' + exName + '/csv'))
...@@ -3,10 +3,13 @@ import numpy as np ...@@ -3,10 +3,13 @@ import numpy as np
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
from PIL import Image from PIL import Image
import cv2 as cv
from tqdm import tqdm
from torchvision import transforms from torchvision import transforms
from unet import UNet from unet import UNet
from utils.dataset import BasicDataset from utils.dataset import BasicDataset
from cvBasedMethod.util import save_img,calcRes
def predict_img(net, full_img, device, scale_factor = 1, out_threshold = 0.5): def predict_img(net, full_img, device, scale_factor = 1, out_threshold = 0.5):
...@@ -30,14 +33,7 @@ def predict_img(net, full_img, device, scale_factor = 1, out_threshold = 0.5): ...@@ -30,14 +33,7 @@ def predict_img(net, full_img, device, scale_factor = 1, out_threshold = 0.5):
return full_mask > out_threshold return full_mask > out_threshold
def mask_to_image(mask): def predict(file_names, model, scale, mask_threshold):
return Image.fromarray((mask * 255).astype(np.uint8))
def predict(img_name, outdir, model, scale, mask_threshold):
in_files = img_name
out_files = outdir
net = UNet(n_channels = 1, n_classes = 1) net = UNet(n_channels = 1, n_classes = 1)
logging.info("Loading model {}".format(model)) logging.info("Loading model {}".format(model))
...@@ -49,15 +45,12 @@ def predict(img_name, outdir, model, scale, mask_threshold): ...@@ -49,15 +45,12 @@ def predict(img_name, outdir, model, scale, mask_threshold):
logging.info("Model loaded !") logging.info("Model loaded !")
for i, fn in enumerate(in_files): for i, fn in enumerate(tqdm(file_names)):
logging.info("\nPredicting image {} ...".format(fn)) logging.info("\nPredicting image {} ...".format(fn[0]+'/'+fn[1]))
img = Image.open('data/imgs/' + fn[0] + '/' + fn[1])
img = Image.open(fn)
mask = predict_img(net = net, full_img = img, scale_factor = scale, out_threshold = mask_threshold, mask = predict_img(net = net, full_img = img, scale_factor = scale, out_threshold = mask_threshold,
device = device) device = device)
#out_fn = out_files[i]
result = mask_to_image(mask)
result.save(out_files[i])
logging.info("Mask saved to {}".format(out_files[i])) result = (mask * 255).astype(np.uint8) # result.save(out_files[i]) # logging.info("Mask saved to {}".format(out_files[i]))
\ No newline at end of file save_img({'ori':img,'mask':result},fn[0],fn[1])
calcRes(cv.cvtColor(np.asarray(img), cv.COLOR_RGB2BGR),result,fn[0],fn[1][:-4])
\ No newline at end of file
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment