Skip to content

Commit bc93719

Browse files
authored
Merge pull request #1247 from Junranus/dev-postgresql
2 parents 3f4c53c + 0b9a984 commit bc93719

File tree

1 file changed

+89
-0
lines changed

1 file changed

+89
-0
lines changed

examples/healthcare/data/diaret.py

+89
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,89 @@
1+
#
2+
# Licensed to the Apache Software Foundation (ASF) under one
3+
# or more contributor license agreements. See the NOTICE file
4+
# distributed with this work for additional information
5+
# regarding copyright ownership. The ASF licenses this file
6+
# to you under the Apache License, Version 2.0 (the
7+
# "License"); you may not use this file except in compliance
8+
# with the License. You may obtain a copy of the License at
9+
#
10+
# http://www.apache.org/licenses/LICENSE-2.0
11+
#
12+
# Unless required by applicable law or agreed to in writing, software
13+
# distributed under the License is distributed on an "AS IS" BASIS,
14+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15+
# See the License for the specific language governing permissions and
16+
# limitations under the License.
17+
#
18+
19+
try:
20+
import pickle
21+
except ImportError:
22+
import cPickle as pickle
23+
24+
import os
25+
import sys
26+
import random
27+
import numpy as np
28+
from PIL import Image
29+
30+
31+
# The dataset must already be saved locally, laid out as
# <dir_path>/<integer class label>/<image files>.
32+
def load_data(dir_path="/tmp/diaret", resize_size=(128, 128)):
    """Load every labelled image found under ``dir_path``.

    ``dir_path`` is expected to contain one sub-directory per class, each
    named with the integer class label and holding that class's image files.

    Args:
        dir_path: root directory of the dataset.
        resize_size: size passed to ``Image.resize`` for every image.

    Returns:
        A tuple ``(images, labels)``: ``images`` is a float32 array of
        RGB images transposed to channel-first layout, and ``labels`` is an
        int32 array holding the matching class label for each image.

    Raises:
        ValueError: if a class sub-directory name is not an integer.
    """
    dir_path = check_dataset_exist(dirpath=dir_path)
    images, labels = [], []
    # os.listdir order is arbitrary; sorting makes the sample order stable.
    for label in sorted(os.listdir(dir_path)):
        label_dir = os.path.join(dir_path, label)
        if not os.path.isdir(label_dir):
            # Stray files next to the class folders are not classes.
            continue
        class_id = int(label)
        for img_name in sorted(load_image_path(os.listdir(label_dir))):
            # The context manager releases the file handle deterministically.
            with Image.open(os.path.join(label_dir, img_name)) as img:
                pixels = np.array(img.resize(resize_size).convert("RGB"))
            # (height, width, channel) -> (channel, height, width)
            images.append(pixels.transpose(2, 0, 1))
            labels.append(class_id)

    images = np.array(images, dtype=np.float32)
    labels = np.array(labels, dtype=np.int32)
    return images, labels
48+
49+
50+
def load_image_path(image_pths):
    """Filter a list of file names down to the supported image files.

    A name is kept when the text after its last ``.`` is ``png``, ``jpg``
    or ``jpeg`` (case-insensitive). Names without any ``.`` are rejected,
    so an extension-less file that happens to be called ``png`` is not
    mistaken for an image.

    Args:
        image_pths: iterable of file names or paths.

    Returns:
        A list with the accepted names, in their original order.
    """
    allowed_image_format = ('png', 'jpg', 'jpeg')
    kept = []
    for pth in image_pths:
        # rpartition yields an empty separator when the name has no dot.
        _, dot, ext = pth.rpartition('.')
        if dot and ext.lower() in allowed_image_format:
            kept.append(pth)
    return kept
54+
55+
56+
def check_dataset_exist(dirpath):
    """Return ``dirpath`` if it exists, otherwise abort the program.

    Args:
        dirpath: path where the dataset is expected.

    Returns:
        ``dirpath`` unchanged when the path exists.

    Raises:
        SystemExit: when the path does not exist. The message is written to
            stderr and the process exit status is non-zero, so shells and
            calling scripts can detect the failure.
    """
    if not os.path.exists(dirpath):
        # Passing a string to sys.exit prints it to stderr and exits with
        # status 1; the previous sys.exit(0) reported success.
        sys.exit('Please download the Diabetic Retinopathy dataset first')
    return dirpath
63+
64+
65+
def normalize(train_x, val_x):
    """Divide both arrays by 255 and standardize each channel, in place.

    Both arrays are indexed as ``[:, channel, :, :]``, i.e. the channel is
    the second axis. They are modified in place, so they must be
    floating-point arrays.

    Args:
        train_x: training images.
        val_x: validation images.

    Returns:
        The same two array objects, ``(train_x, val_x)``, after
        normalization.
    """
    mean = [0.5339, 0.4180, 0.4460]  # mean for dataset
    std = [0.3329, 0.2637, 0.2761]  # std for dataset
    train_x /= 255
    val_x /= 255
    # Walk every channel listed in mean/std. The earlier range(0, 2) skipped
    # the last channel, leaving it scaled but not standardized.
    for ch, (ch_mean, ch_std) in enumerate(zip(mean, std)):
        train_x[:, ch, :, :] -= ch_mean
        train_x[:, ch, :, :] /= ch_std
        val_x[:, ch, :, :] -= ch_mean
        val_x[:, ch, :, :] /= ch_std
    return train_x, val_x
76+
77+
78+
def train_test_split(x, y, val_ratio=0.2, seed=None):
    """Randomly split samples and labels into training and validation parts.

    Args:
        x: samples; must support indexing with a list of integers (for
            example a NumPy array).
        y: labels aligned with ``x``; same indexing requirement.
        val_ratio: fraction of samples placed in the validation part; the
            count is ``int(val_ratio * len(x))``.
        seed: optional seed for a private random generator, giving a
            reproducible split. When ``None`` the module-level ``random``
            state is used, as before.

    Returns:
        ``(train_x, train_y, val_x, val_y)``. Training samples keep their
        original relative order; validation samples are in sampled order.
    """
    rng = random if seed is None else random.Random(seed)
    indices = list(range(len(x)))
    val_indices = list(rng.sample(indices, int(val_ratio * len(x))))
    held_out = set(val_indices)
    # Filtering keeps ascending order without relying on set iteration order.
    train_indices = [idx for idx in indices if idx not in held_out]
    return x[train_indices], y[train_indices], x[val_indices], y[val_indices]
83+
84+
85+
def load(dir_path):
    """Load the dataset at ``dir_path`` and return normalized splits.

    Chains ``load_data``, ``train_test_split`` (with its default ratio)
    and ``normalize``.

    Args:
        dir_path: dataset root directory, forwarded to ``load_data``.

    Returns:
        ``(train_x, train_y, val_x, val_y)``.
    """
    images, targets = load_data(dir_path=dir_path)
    split = train_test_split(images, targets)
    x_train, y_train, x_val, y_val = split
    # Only the image arrays are normalized; labels pass through untouched.
    x_train, x_val = normalize(x_train, x_val)
    return x_train, y_train, x_val, y_val

0 commit comments

Comments
 (0)