-
Notifications
You must be signed in to change notification settings - Fork 6
/
Copy pathtransforms.py
91 lines (68 loc) · 2.61 KB
/
transforms.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
"""Collection of custom PyTorch friendly transform classes.
These transform classes can be called during training loops to perform
data augmentation.
Author: Titouan Lorieul <[email protected]>
Theo Larcher <[email protected]>
"""
import numpy as np
import torch
from torchvision import transforms
class RGBDataTransform:
    """Turn an RGB patch into a tensor.

    The conversion is fully delegated to
    ``torchvision.transforms.functional.to_tensor``; no statistics or
    resizing are applied here.
    """

    def __call__(self, data):
        to_tensor = transforms.functional.to_tensor
        tensor = to_tensor(data)
        return tensor
class NIRDataTransform:
    """Turn a single-band near-infrared patch into a 3-channel tensor.

    A new trailing axis is added and the band is tiled three times along
    it, so for a 2-D input the three resulting channels are identical. The
    stacked array is then passed to
    ``torchvision.transforms.functional.to_tensor``.
    """

    def __call__(self, data):
        # NOTE(review): assumes ``data`` is a 2-D (H, W) array so that the
        # tiled result is (H, W, 3) - confirm with the dataset code.
        with_channel_axis = data[:, :, None]
        three_bands = np.tile(with_channel_axis, 3)
        return transforms.functional.to_tensor(three_bands)
class RasterDataTransform:
    """Standardize a multi-band raster patch and optionally resize it.

    Band ``c`` of the input is mapped to ``(data[c] - mu[c]) / sigma[c]``.
    The result is then resized when ``resize`` is truthy.

    Parameters
    ----------
    mu : sequence of float
        Per-band offsets subtracted from the data, one per leading-axis band.
    sigma : sequence of float
        Per-band divisors applied after centering.
    resize : int or sequence of int, optional
        Size forwarded to ``torchvision.transforms.functional.resize``.
        When falsy (default ``None``) no resizing is done.
    """

    def __init__(self, mu, sigma, resize=None):
        # Reshaped to (C, 1, 1) so the statistics broadcast over the last two
        # axes of the data - presumably a (C, H, W) raster.
        self.mu = np.asarray(mu, dtype=np.float32)[:, None, None]
        self.sigma = np.asarray(sigma, dtype=np.float32)[:, None, None]
        self.resize = resize

    def __call__(self, data):
        """Return ``data`` as a standardized (and possibly resized) float32 tensor.

        ``data`` may be anything ``torch.as_tensor`` accepts. A tensor input
        keeps its device.
        """
        data = torch.as_tensor(data, dtype=torch.float32)
        # Convert the statistics explicitly, on the data's own device,
        # instead of mixing numpy arrays into torch arithmetic. Host-side
        # numpy statistics cannot follow a tensor that lives on a GPU.
        mu = torch.as_tensor(self.mu, device=data.device)
        sigma = torch.as_tensor(self.sigma, device=data.device)
        data = (data - mu) / sigma
        if self.resize:
            data = transforms.functional.resize(data, self.resize)
        return data
class TemperatureDataTransform(RasterDataTransform):
    """Standardize a 3-band temperature raster with fixed per-band statistics.

    Parameters
    ----------
    resize : int or sequence of int, optional
        Output size forwarded to ``RasterDataTransform``. Defaults to 256,
        the value this class always used before it became configurable.
    """

    def __init__(self, resize=256):
        # Hard-coded per-band offsets/divisors; their provenance is not
        # recorded in this file.
        mu = [-12.0, 1.0, 1.0]
        sigma = [40.0, 22.0, 51.0]
        super().__init__(mu, sigma, resize=resize)
class PrecipitationDataTransform(RasterDataTransform):
    """Standardize a 3-band precipitation raster with fixed per-band statistics.

    Parameters
    ----------
    resize : int or sequence of int, optional
        Output size forwarded to ``RasterDataTransform``. Defaults to 256,
        the value this class always used before it became configurable.
    """

    def __init__(self, resize=256):
        # Hard-coded per-band offsets/divisors; their provenance is not
        # recorded in this file.
        mu = [43.0, -1.0, 4.0]
        sigma = [3410.0, 177.0, 139.0]
        super().__init__(mu, sigma, resize=resize)
class PedologicalDataTransform(RasterDataTransform):
    """Standardize a 3-band pedological raster with fixed per-band statistics.

    Parameters
    ----------
    resize : int or sequence of int, optional
        Output size forwarded to ``RasterDataTransform``. Defaults to 256,
        the value this class always used before it became configurable.
    """

    def __init__(self, resize=256):
        # Hard-coded per-band offsets/divisors; their provenance is not
        # recorded in this file.
        mu = [-1.0, 31.0, -1.0]
        sigma = [526.0, 68.0, 88.0]
        super().__init__(mu, sigma, resize=resize)
class DataAugmentation(transforms.Compose):
    """Geometric augmentation pipeline that ends in a 224-pixel crop.

    Parameters
    ----------
    train : bool
        When truthy, compose the following, in this order:
        - a random rotation (``degrees=45``, uncovered pixels filled with 1);
        - a random 224 crop;
        - random horizontal and vertical flips.
        Otherwise use only a deterministic 224 center crop.
    """

    def __init__(self, train):
        if not train:
            super().__init__([transforms.CenterCrop(size=224)])
            return

        pipeline = [
            transforms.RandomRotation(degrees=45, fill=1),
            transforms.RandomCrop(size=224),
            transforms.RandomHorizontalFlip(),
            transforms.RandomVerticalFlip(),
        ]
        super().__init__(pipeline)
class Normalization(transforms.Normalize):
    """Channel normalization for several stacked 3-channel modalities.

    The three mean/std values (the ones commonly used with ImageNet-pretrained
    models) are repeated ``num_modalities`` times. Each stacked 3-channel
    block therefore gets the same statistics.

    Parameters
    ----------
    num_modalities : int
        Number of 3-channel blocks stacked along the channel axis.
    """

    def __init__(self, num_modalities):
        per_modality_mean = [0.485, 0.456, 0.406]
        per_modality_std = [0.229, 0.224, 0.225]
        super().__init__(
            mean=per_modality_mean * num_modalities,
            std=per_modality_std * num_modalities,
        )
class PreprocessRGBTemperatureData:
    """Combine an RGB patch and its temperature rasters into one tensor.

    The incoming mapping must provide two keys:
    - ``"rgb"``, converted with ``RGBDataTransform``;
    - ``"environmental_patches"``, standardized with ``TemperatureDataTransform``.
    The two results are concatenated along dimension 0.
    """

    def __init__(self):
        # Build the sub-transforms once rather than on every sample.
        # TemperatureDataTransform allocates its statistic arrays on
        # construction.
        self.rgb_transform = RGBDataTransform()
        self.temperature_transform = TemperatureDataTransform()

    def __call__(self, data):
        rgb_data, temp_data = data["rgb"], data["environmental_patches"]
        rgb_data = self.rgb_transform(rgb_data)
        temp_data = self.temperature_transform(temp_data)
        # NOTE(review): concatenation along dim 0 requires both tensors to
        # share their remaining dimensions (the temperature data is resized to
        # 256) - confirm the RGB patches match.
        return torch.cat((rgb_data, temp_data))