Commit 4575595

Merge pull request #1238 from zmeihui/24-12-2-dev
Add a sample test dataset for the TED CT Detection
2 parents 87b4cfd + f053d79 commit 4575595

File tree

1 file changed: +89 -0 lines changed


examples/healthcare/data/cifar10.py (+89)
@@ -0,0 +1,89 @@
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#

try:
    import pickle
except ImportError:
    import cPickle as pickle

import numpy as np
import os
import sys


def load_dataset(filepath):
    # Unpickle one CIFAR-10 batch; the encoding argument only exists in
    # Python 3's pickle, so fall back to the plain call on Python 2.
    with open(filepath, 'rb') as fd:
        try:
            cifar10 = pickle.load(fd, encoding='latin1')
        except TypeError:
            cifar10 = pickle.load(fd)
    image = cifar10['data'].astype(dtype=np.uint8)
    image = image.reshape((-1, 3, 32, 32))
    label = np.asarray(cifar10['labels'], dtype=np.uint8)
    label = label.reshape(label.size, 1)
    return image, label


def load_train_data(dir_path='/tmp/cifar-10-batches-py', num_batches=5):
    # The extracted dataset must already be saved under dir_path.
    labels = []
    batchsize = 10000
    images = np.empty((num_batches * batchsize, 3, 32, 32), dtype=np.uint8)
    for did in range(1, num_batches + 1):
        fname_train_data = dir_path + "/data_batch_{}".format(did)
        image, label = load_dataset(check_dataset_exist(fname_train_data))
        images[(did - 1) * batchsize:did * batchsize] = image
        labels.extend(label)
    images = np.array(images, dtype=np.float32)
    labels = np.array(labels, dtype=np.int32)
    return images, labels


def load_test_data(dir_path='/tmp/cifar-10-batches-py'):
    # The extracted dataset must already be saved under dir_path.
    images, labels = load_dataset(check_dataset_exist(dir_path + "/test_batch"))
    return np.array(images, dtype=np.float32), np.array(labels, dtype=np.int32)


def check_dataset_exist(dirpath):
    if not os.path.exists(dirpath):
        print('Please download the cifar10 dataset.')
        sys.exit(0)
    return dirpath


def normalize(train_x, val_x):
    # Per-channel mean and std of CIFAR-10 after scaling pixels to [0, 1].
    mean = [0.4914, 0.4822, 0.4465]
    std = [0.2023, 0.1994, 0.2010]
    train_x /= 255
    val_x /= 255
    for ch in range(3):  # normalize all three colour channels
        train_x[:, ch, :, :] -= mean[ch]
        train_x[:, ch, :, :] /= std[ch]
        val_x[:, ch, :, :] -= mean[ch]
        val_x[:, ch, :, :] /= std[ch]
    return train_x, val_x


def load(dir_path):
    train_x, train_y = load_train_data(dir_path)
    val_x, val_y = load_test_data(dir_path)
    train_x, val_x = normalize(train_x, val_x)
    train_y = train_y.flatten()
    val_y = val_y.flatten()
    return train_x, train_y, val_x, val_y
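
For context, a minimal usage sketch of the module added in this commit (an assumption for illustration, not part of the diff): it presumes the CIFAR-10 python batches are already extracted to /tmp/cifar-10-batches-py and that the script is run from examples/healthcare/data so the file imports as cifar10. The shapes in the comments follow from the loaders above (5 training batches of 10000 images plus one test batch).

# Hypothetical usage sketch, not included in the commit.
import cifar10

train_x, train_y, val_x, val_y = cifar10.load('/tmp/cifar-10-batches-py')
print(train_x.shape, train_y.shape)  # (50000, 3, 32, 32) (50000,)
print(val_x.shape, val_y.shape)      # (10000, 3, 32, 32) (10000,)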

0 commit comments
