@@ -34,23 +34,23 @@ def build_dataset(
34
34
dataset ["channels" ] = ds_train .output_shapes ["image" ][- 1 ].value
35
35
36
36
ds_train = ds_train .shuffle (1024 ).repeat ()
37
- ds_train = ds_train .map (lambda data : _parse_function (data ,shape ,dataset ["num_classes" ]))
37
+ ds_train = ds_train .map (lambda data : _parse_function (data , shape , dataset ["num_classes" ], dataset [ "channels " ]))
38
38
dataset ["train" ] = ds_train .batch (train_batch_size )
39
39
40
40
ds_test = ds_test .shuffle (1024 ).repeat ()
41
- ds_test = ds_test .map (lambda data : _parse_function (data ,shape ,dataset ["num_classes" ]))
41
+ ds_test = ds_test .map (lambda data : _parse_function (data , shape , dataset ["num_classes" ], dataset [ "channels " ]))
42
42
dataset ["test" ] = ds_test .batch (valid_batch_size )
43
43
44
44
return dataset
45
45
46
- def _parse_function (data ,shape ,num_classes ):
46
+ def _parse_function (data , shape , num_classes , channels ):
47
47
height , width = shape
48
48
image = data ["image" ]
49
49
label = data ["label" ]
50
50
51
51
image = tf .cast (image , dtype = tf .float32 )
52
52
image = tf .image .resize_images (image , (height ,width ))
53
- image = tf .reshape (image , (height ,width , 1 ))
53
+ image = tf .reshape (image , (height ,width , channels ))
54
54
image = image / 255.0
55
55
image = image - 0.5
56
56
image = image * 2.0
0 commit comments