brats_data_loader.py
from time import time
from batchgenerators.augmentations.crop_and_pad_augmentations import crop
from batchgenerators.dataloading import MultiThreadedAugmenter
from config import brats_preprocessed_folder, num_threads_for_brats_example
# from batchgenerators.examples.brats2017.config import brats_preprocessed_folder, num_threads_for_brats_example
from batchgenerators.transforms import Compose
from batchgenerators.utilities.data_splitting import get_split_deterministic
from batchgenerators.utilities.file_and_folder_operations import *
import numpy as np
from batchgenerators.dataloading.data_loader import DataLoader
from batchgenerators.augmentations.utils import pad_nd_image
from batchgenerators.transforms.spatial_transforms import SpatialTransform_2, MirrorTransform
from batchgenerators.transforms.color_transforms import BrightnessMultiplicativeTransform, GammaTransform
from batchgenerators.transforms.noise_transforms import GaussianNoiseTransform, GaussianBlurTransform

channel_indices = {
    't1': 0,
    't1c': 1,
    't2': 2,
    'flair': 3,
    'seg': 4
}


def get_train_transform(patch_size):
    # we now create a list of transforms. These are not necessarily the best transforms to use for BraTS, this is just
    # to showcase some things
    tr_transforms = []

    # the first thing we want to run is the SpatialTransform. It reduces the size of our data to patch_size and thus
    # also reduces the computational cost of all subsequent operations. All subsequent operations do not modify the
    # shape and do not transform spatially, so no border artifacts will be introduced.
    # Here we use the new SpatialTransform_2 which uses a new way of parameterizing elastic_deform.
    # Each of the three spatial transformations (elastic, rotation, scaling) is applied with a probability of 0.1
    # per sample, so roughly 1 - (1 - 0.1) ** 3 = 27% of samples receive at least one spatial augmentation;
    # the rest will just be cropped.
    tr_transforms.append(
        SpatialTransform_2(
            patch_size, [i // 2 for i in patch_size],
            do_elastic_deform=True, deformation_scale=(0, 0.25),
            do_rotation=True,
            angle_x=(- 15 / 360. * 2 * np.pi, 15 / 360. * 2 * np.pi),
            angle_y=(- 15 / 360. * 2 * np.pi, 15 / 360. * 2 * np.pi),
            angle_z=(- 15 / 360. * 2 * np.pi, 15 / 360. * 2 * np.pi),
            do_scale=True, scale=(0.75, 1.25),
            border_mode_data='constant', border_cval_data=0,
            border_mode_seg='constant', border_cval_seg=0,
            order_seg=1, order_data=3,
            random_crop=True,
            p_el_per_sample=0.1, p_rot_per_sample=0.1, p_scale_per_sample=0.1
        )
    )

    # now we mirror along all axes
    tr_transforms.append(MirrorTransform(axes=(0, 1, 2)))

    # brightness transform for 15% of samples
    tr_transforms.append(BrightnessMultiplicativeTransform((0.7, 1.5), per_channel=True, p_per_sample=0.15))

    # gamma transform. This is a nonlinear transformation of intensity values
    # (https://en.wikipedia.org/wiki/Gamma_correction)
    tr_transforms.append(GammaTransform(gamma_range=(0.5, 2), invert_image=False, per_channel=True, p_per_sample=0.15))
    # we can also invert the image, apply the transform and then invert back
    tr_transforms.append(GammaTransform(gamma_range=(0.5, 2), invert_image=True, per_channel=True, p_per_sample=0.15))

    # Gaussian noise
    tr_transforms.append(GaussianNoiseTransform(noise_variance=(0, 0.05), p_per_sample=0.15))

    # blurring. Some BraTS cases have very blurry modalities. This can simulate more patients with this problem and
    # thus make the model more robust to it
    tr_transforms.append(GaussianBlurTransform(blur_sigma=(0.5, 1.5), different_sigma_per_channel=True,
                                               p_per_channel=0.5, p_per_sample=0.15))

    # now we compose these transforms together
    tr_transforms = Compose(tr_transforms)
    return tr_transforms
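

# Note: the Compose object returned above is itself callable on a batch dict
# (e.g. `augmented = tr_transforms(**batch)`), but it is more commonly handed to a
# MultiThreadedAugmenter together with the dataloader defined below, as sketched
# at the bottom of this file.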


def get_list_of_patients(preprocessed_data_folder):
    npy_files = subfiles(preprocessed_data_folder, suffix=".npy", join=True)
    # remove the .npy file extension so each entry can serve as a base name for the .npy/.pkl pair
    patients = [i[:-4] for i in npy_files]
    return patients


# NEW: iterate over patients one at a time, yielding the selected channels (with an added
# batch dimension) together with each patient's metadata.
def iterate_through_patients(patients, in_channels):
    in_channels = [channel_indices[i] for i in in_channels]
    for p in patients:
        patient_data, meta_data = BRATSDataLoader.load_patient(p)
        # patient_data = BRATSDataLoader.load_patient(p)[0][in_channels][None]
        # meta_data = BRATSDataLoader.load_patient(p)[1]
        yield (patient_data[in_channels][None], meta_data)


class BRATSDataLoader(DataLoader):
    """
    Based on Fabian Isensee's BraTS dataloader example.
    """
    def __init__(self, data, batch_size, patch_size, in_channels, num_threads_in_multithreaded=1,
                 seed_for_shuffle=1234, return_incomplete=False, shuffle=True, infinite=True):
        super(BRATSDataLoader, self).__init__(data, batch_size, num_threads_in_multithreaded,
                                              seed_for_shuffle, return_incomplete, shuffle, infinite)
        self.patch_size = patch_size
        # ADDED: in_channels lets the caller select which modalities end up in the batch
        self.num_modalities = len(in_channels)  # e.g. 4 when t1, t1c, t2 and flair are all used
        self.in_channels = [channel_indices[i] for i in in_channels]
        self.indices = list(range(len(data)))

    @staticmethod
    def load_patient(patient):
        data = np.load(patient + ".npy", mmap_mode="r")
        metadata = load_pickle(patient + ".pkl")
        return data, metadata

    def generate_train_batch(self):
        # DataLoader has its own methods for selecting what patients to use next, see its documentation
        idx = self.get_indices()
        patients_for_batch = [self._data[i] for i in idx]

        # initialize empty arrays for data and seg
        data = np.zeros((self.batch_size, self.num_modalities, *self.patch_size), dtype=np.float32)
        seg = np.zeros((self.batch_size, 1, *self.patch_size), dtype=np.float32)
        metadata = []
        patient_names = []

        # iterate over patients_for_batch and include them in the batch
        for i, j in enumerate(patients_for_batch):
            patient_data, patient_metadata = self.load_patient(j)

            # patient data is a memmap. If we extract just one slice then just this one slice will be read from the
            # disk, so no worries!
            # we may need to add this for 2d models!
            # slice_idx = np.random.choice(patient_data.shape[1])
            # patient_data = patient_data[:, slice_idx]

            # this will only pad patient_data if its shape is smaller than self.patch_size
            patient_data = pad_nd_image(patient_data, self.patch_size)

            # now random crop to self.patch_size
            # crop expects the data to be (b, c, x, y, z) but patient_data is (c, x, y, z) so we need to add one
            # dummy dimension in order for it to work (@Todo, could be improved)
            # ADDED: channel selector
            patient_data, patient_seg = crop(patient_data[self.in_channels][None], patient_data[-1:][None],
                                             self.patch_size, crop_type="random")

            data[i] = patient_data[0]
            seg[i] = patient_seg[0]

            metadata.append(patient_metadata)
            patient_names.append(j)

        return {'data': data, 'seg': seg, 'metadata': metadata, 'names': patient_names}
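

# ---------------------------------------------------------------------------
# Usage sketch (not part of the loader itself): a minimal, illustrative example
# of how the pieces above can be wired together with MultiThreadedAugmenter.
# It assumes brats_preprocessed_folder contains the preprocessed .npy/.pkl
# pairs expected by load_patient(); the patch size, batch size, fold and split
# counts below are illustrative values, not requirements of this module.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    patients = get_list_of_patients(brats_preprocessed_folder)
    train_patients, val_patients = get_split_deterministic(patients, fold=0, num_splits=5, random_state=12345)

    patch_size = (128, 128, 128)
    batch_size = 2
    in_channels = ['t1', 't1c', 't2', 'flair']

    dataloader_train = BRATSDataLoader(train_patients, batch_size, patch_size, in_channels,
                                       num_threads_in_multithreaded=num_threads_for_brats_example)
    tr_transforms = get_train_transform(patch_size)

    # wrap the loader in a MultiThreadedAugmenter so loading and augmentation run in background workers
    tr_gen = MultiThreadedAugmenter(dataloader_train, tr_transforms,
                                    num_processes=num_threads_for_brats_example,
                                    num_cached_per_queue=3, seeds=None, pin_memory=False)

    # pull a few batches and report the throughput
    num_batches = 10
    start = time()
    for _ in range(num_batches):
        batch = tr_gen.next()  # batch is a dict with 'data', 'seg', 'metadata' and 'names'
    print("generated %d batches of shape %s in %.2f s" % (num_batches, str(batch['data'].shape), time() - start))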