Skip to content

Commit 57bc2ef

Browse files
committed
Update datasets.py
1 parent 5b31be3 commit 57bc2ef

File tree

1 file changed

+4
-4
lines changed

1 file changed

+4
-4
lines changed

datasets.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -34,23 +34,23 @@ def build_dataset(
3434
dataset["channels"] = ds_train.output_shapes["image"][-1].value
3535

3636
ds_train = ds_train.shuffle(1024).repeat()
37-
ds_train = ds_train.map(lambda data: _parse_function(data,shape,dataset["num_classes"]))
37+
ds_train = ds_train.map(lambda data: _parse_function(data, shape, dataset["num_classes"], dataset["channels"]))
3838
dataset["train"] = ds_train.batch(train_batch_size)
3939

4040
ds_test = ds_test.shuffle(1024).repeat()
41-
ds_test = ds_test.map(lambda data: _parse_function(data,shape,dataset["num_classes"]))
41+
ds_test = ds_test.map(lambda data: _parse_function(data, shape, dataset["num_classes"], dataset["channels"]))
4242
dataset["test"] = ds_test.batch(valid_batch_size)
4343

4444
return dataset
4545

46-
def _parse_function(data,shape,num_classes):
46+
def _parse_function(data, shape, num_classes, channels):
4747
height, width = shape
4848
image = data["image"]
4949
label = data["label"]
5050

5151
image = tf.cast(image, dtype=tf.float32)
5252
image = tf.image.resize_images(image, (height,width))
53-
image = tf.reshape(image, (height,width, 1))
53+
image = tf.reshape(image, (height,width, channels))
5454
image = image / 255.0
5555
image = image - 0.5
5656
image = image * 2.0

0 commit comments

Comments
 (0)