-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathbuild_data.py
150 lines (112 loc) · 4.99 KB
/
build_data.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
import argparse
import datetime
import os
import time
import cv2
from tqdm import tqdm
import numpy as np
import pathlib
import logging
from utils import augs
from utils.aux_funcs import calc_jaccard, get_runtime, scan_files, plot_hist, show_images
from configs.general_configs import (
N_SAMPLES,
MIN_J,
MAX_J,
SEG_DIR_POSTFIX,
SEG_PREFIX,
IMAGE_PREFIX,
TRAIN_INPUT_DATA_DIR,
TEST_INPUT_DATA_DIR,
TRAIN_OUTPUT_DATA_DIR,
TEST_OUTPUT_DATA_DIR, DATA_ROOT_DIR
)
# - Only let records of level WARNING and above through the 'PIL' logger
logging.getLogger('PIL').setLevel(logging.WARNING)
# - Author contact string for this module
__author__ = '[email protected]'
def _collect_samples(files: list, n_passes: int, min_j: float, max_j: float) -> list:
    """Augment every mask ``n_passes`` times and keep the in-range samples.

    Args:
        files: Pairs of ``(image_file, segmentation_file)`` paths.
        n_passes: How many times the whole file list is iterated.
        min_j: Exclusive lower bound on the accepted Jaccard value.
        max_j: Exclusive upper bound on the accepted Jaccard value.

    Returns:
        A list of ``(image, mask, augmented_mask, jaccard)`` tuples for which
        ``min_j < jaccard < max_j``.
    """
    mask_augs = augs.mask_augs()
    samples = []
    for pass_idx in range(n_passes):
        print(f'\n== Pass: {pass_idx+1}/{n_passes} ==')
        files_pbar = tqdm(files)
        for img_fl, seg_fl in files_pbar:
            img = cv2.imread(str(img_fl), -1)
            mask = cv2.imread(str(seg_fl), -1)

            # - cv2.imread reports an unreadable file by returning None
            #   instead of raising, so such a pair is skipped explicitly
            if img is None or mask is None:
                tqdm.write(f"> Could not read '{img_fl}' or '{seg_fl}' - skipping!")
                continue

            # - Only the augmented mask is used; the stored image and mask
            #   are the ones read from disk
            mask_aug = mask_augs(image=img, mask=mask).get('mask')
            jaccard = calc_jaccard(mask, mask_aug)

            if min_j < jaccard < max_j:
                samples.append((img, mask, mask_aug, jaccard))

            files_pbar.set_postfix(jaccard=f'{jaccard:.4f}')
    return samples


def build_data_file(files: list, output_dir: pathlib.Path, n_samples, min_j: float, max_j: float, plot_samples: bool = False) -> np.ndarray:
    """Generate (image, mask, augmented mask, jaccard) samples and save them.

    The file list is iterated ``n_samples // len(files)`` times. On each
    visit the mask is augmented and the sample is kept only if the Jaccard
    between the mask and its augmented version lies strictly between
    ``min_j`` and ``max_j``. The kept samples are written to
    ``<output_dir>/<N>_samples/data.npy`` (or to a time-stamped
    sub-directory of it, if that directory already exists), together with a
    histogram of the Jaccard values.

    Args:
        files: Pairs of ``(image_file, segmentation_file)`` paths.
        output_dir: Directory under which the data directory is created.
        n_samples: Number of augmentation attempts requested; the number of
            saved samples may be lower because of the Jaccard filter.
        min_j: Exclusive lower bound on the accepted Jaccard value.
        max_j: Exclusive upper bound on the accepted Jaccard value.
        plot_samples: If True, an image of every saved sample is written to
            a ``samples`` sub-directory.

    Returns:
        An object array with one row per saved sample, or an empty object
        array if nothing was generated.
    """
    t_start = time.time()

    # - Each early exit below names its own reason, so the user is not told
    #   that files are missing when the real cause is something else
    n_files = len(files)
    if n_files == 0:
        print('No data was generated - no files were provided!')
        return np.array([], dtype=object)

    n_passes = n_samples // n_files
    if n_passes < 1:
        print(f'No data was generated - n_samples ({n_samples}) is smaller than the number of files ({n_files})!')
        return np.array([], dtype=object)

    samples = _collect_samples(files=files, n_passes=n_passes, min_j=min_j, max_j=max_j)
    data = np.array(samples, dtype=object)
    if len(data) == 0:
        print(f'No data was generated - no augmented mask had a jaccard in ({min_j}, {max_j})!')
        return data

    # - Save the data
    data_dir = output_dir / f'{len(data)}_samples'

    # To avoid data overwrite
    if data_dir.is_dir():
        ts = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
        print(f"'{data_dir}' already exists! Data will be placed in '{data_dir}/{ts}'")
        data_dir = data_dir / ts

    data_dir.mkdir(parents=True, exist_ok=True)

    print(f"> Saving data to '{data_dir}/data.npy'...")
    np.save(str(data_dir / 'data.npy'), data, allow_pickle=True)
    print(f"> Data was saved to '{data_dir}/data.npy'")

    if plot_samples:
        print('Plotting samples..')

        # - Plot samples
        samples_dir = data_dir / 'samples'
        samples_dir.mkdir(parents=True, exist_ok=True)

        for idx, (img, msk, msk_aug, j) in enumerate(tqdm(data)):
            show_images(images=[img, msk, msk_aug], labels=['Image', 'Mask', 'Augmented Mask'], suptitle=f'Jaccard: {j:.4f}', figsize=(25, 10), save_file=samples_dir / f'{idx}.png')

    # - Plot J histogram
    # NOTE(review): the histogram is written on every successful run, not
    # only when plot_samples is set - confirm this is the intended behavior
    print('Plotting the histogram of the Js...')
    jaccards = [j for *_, j in samples]
    plot_hist(data=jaccards, hist_range=(0., 1., 0.1), bins=10, save_name=f'data dist ({len(data)} samples)', output_dir=data_dir, density=True)

    print(f'== Data generation took: {get_runtime(seconds=time.time() - t_start)} ==')

    return data
def get_arg_parser():
    """Build the command line parser of the data generation script.

    Returns:
        An ``argparse.ArgumentParser`` with the options ``--data_type``,
        ``--seg_dir_postfix``, ``--image_prefix``, ``--seg_prefix``,
        ``--n_samples``, ``--min_j``, ``--max_j`` and ``--plot_samples``.
        The defaults come from the project configuration constants.
    """
    # - (flag, keyword options) pairs, registered in this order
    arg_specs = (
        ('--data_type', dict(type=str, choices=['train', 'test'], help='The type of the data (i.e, train or test)')),
        ('--seg_dir_postfix', dict(type=str, default=SEG_DIR_POSTFIX, help='The postfix of the directory which holds the segmentations')),
        ('--image_prefix', dict(type=str, default=IMAGE_PREFIX, help='The prefix of the images')),
        ('--seg_prefix', dict(type=str, default=SEG_PREFIX, help='The prefix of the segmentations')),
        ('--n_samples', dict(type=int, default=N_SAMPLES, help='The total number of samples to generate')),
        ('--min_j', dict(type=float, default=MIN_J, help='The minimal allowed jaccard ')),
        ('--max_j', dict(type=float, default=MAX_J, help='The maximal allowed jaccard ')),
        ('--plot_samples', dict(default=False, action='store_true', help='If to plot the samples images of the data')),
    )

    arg_parser = argparse.ArgumentParser()
    for flag, options in arg_specs:
        arg_parser.add_argument(flag, **options)
    return arg_parser
if __name__ == '__main__':
    # - Parse the command line
    cli_args = get_arg_parser().parse_args()

    # - Resolve the input / output locations for the requested data type
    # NOTE(review): '--data_type' has no default, so when it is omitted the
    # comparison below is False and the test directories are used
    if cli_args.data_type == 'train':
        input_root = DATA_ROOT_DIR / f'train/{TRAIN_INPUT_DATA_DIR}'
        output_root = TRAIN_OUTPUT_DATA_DIR
    else:
        input_root = DATA_ROOT_DIR / f'test/{TEST_INPUT_DATA_DIR}'
        output_root = TEST_OUTPUT_DATA_DIR

    # - Collect the files found under the input location
    data_files = scan_files(
        root_dir=input_root,
        seg_dir_postfix=cli_args.seg_dir_postfix,
        image_prefix=cli_args.image_prefix,
        seg_prefix=cli_args.seg_prefix
    )

    # - Generate the data only if there is something to generate it from
    if not data_files:
        print('No files to generate data from !')
    else:
        build_data_file(
            files=data_files,
            output_dir=output_root,
            n_samples=cli_args.n_samples,
            min_j=cli_args.min_j,
            max_j=cli_args.max_j,
            plot_samples=cli_args.plot_samples,
        )