1
+ #
2
+ # Licensed to the Apache Software Foundation (ASF) under one
3
+ # or more contributor license agreements. See the NOTICE file
4
+ # distributed with this work for additional information
5
+ # regarding copyright ownership. The ASF licenses this file
6
+ # to you under the Apache License, Version 2.0 (the
7
+ # "License"); you may not use this file except in compliance
8
+ # with the License. You may obtain a copy of the License at
9
+ #
10
+ # http://www.apache.org/licenses/LICENSE-2.0
11
+ #
12
+ # Unless required by applicable law or agreed to in writing,
13
+ # software distributed under the License is distributed on an
14
+ # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15
+ # KIND, either express or implied. See the License for the
16
+ # specific language governing permissions and limitations
17
+ # under the License.
18
+ #
19
+
20
+ try :
21
+ import pickle
22
+ except ImportError :
23
+ import cPickle as pickle
24
+
25
+ import numpy as np
26
+ import os
27
+ import sys
28
+
29
+
30
+ def load_dataset (filepath ):
31
+ with open (filepath , 'rb' ) as fd :
32
+ try :
33
+ cifar10 = pickle .load (fd , encoding = 'latin1' )
34
+ except TypeError :
35
+ cifar10 = pickle .load (fd )
36
+ image = cifar10 ['data' ].astype (dtype = np .uint8 )
37
+ image = image .reshape ((- 1 , 3 , 32 , 32 ))
38
+ label = np .asarray (cifar10 ['labels' ], dtype = np .uint8 )
39
+ label = label .reshape (label .size , 1 )
40
+ return image , label
41
+
42
+
43
+ def load_train_data (dir_path = '/tmp/cifar-10-batches-py' , num_batches = 5 ): # need to save to specific local directories
44
+ labels = []
45
+ batchsize = 10000
46
+ images = np .empty ((num_batches * batchsize , 3 , 32 , 32 ), dtype = np .uint8 )
47
+ for did in range (1 , num_batches + 1 ):
48
+ fname_train_data = dir_path + "/data_batch_{}" .format (did )
49
+ image , label = load_dataset (check_dataset_exist (fname_train_data ))
50
+ images [(did - 1 ) * batchsize :did * batchsize ] = image
51
+ labels .extend (label )
52
+ images = np .array (images , dtype = np .float32 )
53
+ labels = np .array (labels , dtype = np .int32 )
54
+ return images , labels
55
+
56
+
57
+ def load_test_data (dir_path = '/tmp/cifar-10-batches-py' ): # need to save to specific local directories
58
+ images , labels = load_dataset (check_dataset_exist (dir_path + "/test_batch" ))
59
+ return np .array (images , dtype = np .float32 ), np .array (labels , dtype = np .int32 )
60
+
61
+
62
+ def check_dataset_exist (dirpath ):
63
+ if not os .path .exists (dirpath ):
64
+ print (
65
+ 'Please download the cifar10 dataset.'
66
+ )
67
+ sys .exit (0 )
68
+ return dirpath
69
+
70
+
71
+ def normalize (train_x , val_x ):
72
+ mean = [0.4914 , 0.4822 , 0.4465 ]
73
+ std = [0.2023 , 0.1994 , 0.2010 ]
74
+ train_x /= 255
75
+ val_x /= 255
76
+ for ch in range (0 , 2 ):
77
+ train_x [:, ch , :, :] -= mean [ch ]
78
+ train_x [:, ch , :, :] /= std [ch ]
79
+ val_x [:, ch , :, :] -= mean [ch ]
80
+ val_x [:, ch , :, :] /= std [ch ]
81
+ return train_x , val_x
82
+
83
+ def load (dir_path ):
84
+ train_x , train_y = load_train_data (dir_path )
85
+ val_x , val_y = load_test_data (dir_path )
86
+ train_x , val_x = normalize (train_x , val_x )
87
+ train_y = train_y .flatten ()
88
+ val_y = val_y .flatten ()
89
+ return train_x , train_y , val_x , val_y
0 commit comments