@@ -106,6 +106,12 @@ def __init__(
106
106
self .batch_size = batch_size
107
107
self .model_version = model_dict ["model" ]["version" ]
108
108
109
+ self .nodata_mask = np .ones (shape = (img .shape [0 ], img .shape [1 ], 1 ), dtype = np .uint8 )
110
+ if self .nodata_value is not None :
111
+ # create image nodata mask
112
+ # needs to happen before normalization
113
+ self .nodata_mask [np .all (img == self .nodata_value , axis = 2 )] = 0
114
+
109
115
# adjust band order and normalize image
110
116
self .img = self .normalize (
111
117
img = self .adjust_band_order (
@@ -188,13 +194,13 @@ def _valid(self, invalid_buffer):
188
194
class_dict = {"reclass_value_from" : [0 , 1 , 2 ], "reclass_value_to" : [1 , 0 , 0 ]}
189
195
valid = reclassify (self .csm , class_dict )
190
196
197
+ if self .nodata_value is not None :
198
+ # add image nodata pixels to valid pixel mask
199
+ valid [self .nodata_mask == self .nodata_value ] = 0
200
+
191
201
# dilate the inverse of the binary valid pixel mask (invalid=0)
192
202
# this effectively buffers the invalid pixels
193
203
valid_i = ~ valid .astype (bool )
194
204
valid = (~ scipy .ndimage .binary_dilation (valid_i , iterations = invalid_buffer ).astype (bool )).astype (np .uint8 )
195
205
196
- if self .nodata_value is not None :
197
- # add image nodata pixels to valid pixel mask
198
- valid [np .all (self .img == self .nodata_value , axis = 2 )] = 0
199
-
200
206
return valid
0 commit comments