-
Notifications
You must be signed in to change notification settings - Fork 8
/
Copy pathonnx_utils.py
370 lines (317 loc) · 16.8 KB
/
onnx_utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
# /*---------------------------------------------------------------------------------------------
# * Copyright (c) 2022 STMicroelectronics.
# * All rights reserved.
# *
# * This software is licensed under terms that can be found in the LICENSE file in
# * the root directory of this software component.
# * If no LICENSE file comes with this software, it is provided AS-IS.
# *--------------------------------------------------------------------------------------------*/
import tensorflow as tf
import numpy as np
import os
import cv2
import onnxruntime as rt
from onnx import version_converter
from onnx import ModelProto
from sklearn.metrics import accuracy_score, confusion_matrix, ConfusionMatrixDisplay
import matplotlib.pyplot as plt
import onnx
#====================== onnx evaluation =======================#
# Per-channel normalization parameters used when 'torch' is selected as preproc_mode.
# They are applied on the last (channel) axis of an image loaded in 'rgb' colour mode.
# These numbers are stated to come from the 'imagenet' dataset.
torch_means = [np.float32(0.485),np.float32(0.456),np.float32(0.406)]
# NOTE(review): the commonly published ImageNet std values are (0.229, 0.224, 0.225);
# here 0.224 is used for every channel. Confirm this matches how the models were
# trained before changing it.
torch_std = [np.float32(0.224), np.float32(0.224), np.float32(0.224)]
# Per-channel means subtracted in 'caffe' preproc_mode, applied to the image as
# returned by cv2.imread (BGR channel order).
caffe_mean = [np.float32(103.939), np.float32(116.779), np.float32(123.68)]
def get_model(onnx_model_path):
    '''
    get_model(onnx_model_path)
    Load an onnx model from disk and open an onnxruntime inference session on the same file.
    input:
        onnx_model_path : (str) full path to the onnx model, e.g. "xyz/model.onnx"
    outputs:
        onx: the onnx ModelProto parsed from the file content
        sess: onnxruntime inference session object used to perform the inference
    '''
    # Read the serialized protobuf once, in binary mode.
    with open(onnx_model_path, mode='rb') as model_file:
        serialized_model = model_file.read()

    onx = ModelProto()
    onx.ParseFromString(serialized_model)

    # The runtime session loads the model from the path on its own.
    sess = rt.InferenceSession(onnx_model_path)
    return onx, sess
def predict_onnx(sess, data):
    '''
    predict_onnx(sess, data)
    Run one inference with the given onnxruntime session and return the first model output.
    inputs:
        sess: onnxruntime inference session obtained from the get_model(onnx_model_path) function
        data: input data to run the inference on (numpy array);
              it is cast to float32 and fed to the first input of the model,
              so it should have the same shape as that input
    outputs:
        onx_pred: prediction results of the onnx model (first output only) on the provided input data
    '''
    first_input = sess.get_inputs()[0]
    first_output = sess.get_outputs()[0]
    feed = {first_input.name: data.astype(np.float32)}
    results = sess.run([first_output.name], feed)
    # Only one output was requested, so the result list has a single entry.
    return results[0]
def plot_confusion_matrix(cm, class_labels, save_path = "float_model", test_accuracy = '_'):
    '''
    plot_confusion_matrix(cm, class_labels, save_path, test_accuracy)
    Print the confusion matrix, plot its row-normalized version and save the plot as a png file.
    inputs:
        cm: (nxn) shape confusion matrix, where n is num_classes (numpy array or nested list)
        class_labels: (an array of str) list of class labels, used as tick labels on both axes
        save_path: (str) prefix of the output file; the plot is written to
                   '<save_path>_confusion-matrix.png'
        test_accuracy: (str) accuracy to print in the title of the plot
    outputs:
        None. The figure is saved to disk and shown with plt.show().
    '''
    print(f'confusion_matrix : \n{cm}')

    # Normalize every row by its total so each cell is the fraction of the true
    # class. Rows without any sample are left at 0 instead of producing NaN.
    counts = np.asarray(cm, dtype=np.float64)
    row_totals = counts.sum(axis=1, keepdims=True)
    cm_normalized = np.divide(counts, row_totals,
                              out=np.zeros_like(counts), where=row_totals != 0)

    # ConfusionMatrixDisplay.plot() opens its own figure when no axes are given,
    # so no separate plt.figure() call is made here (it would only leave an
    # empty extra figure behind).
    disp = ConfusionMatrixDisplay(cm_normalized)
    disp.plot(cmap = "Blues")
    plt.title(f'Model_accuracy : {test_accuracy}', fontsize = 10)
    plt.tight_layout(pad=3)
    plt.ylabel("True labels")
    plt.xlabel("Predicted labels")
    tick_positions = np.arange(0, len(class_labels))
    plt.xticks(tick_positions, class_labels)
    plt.yticks(tick_positions, class_labels)
    # NOTE: this changes the global matplotlib font size and therefore only
    # affects text created after this point (including later plots).
    plt.rcParams.update({'font.size': 14})
    plt.savefig(f'{save_path}_confusion-matrix.png')
    plt.show()
def get_preprocessed_image(image_path, width = 128, height = 128, preproc_mode = 'tf', interpolation = 'bilinear'):
    '''
    get_preprocessed_image(image_path, width, height, preproc_mode, interpolation)
    Load one image, resize it, apply the selected preprocessing and return it as a
    numpy array in (n, c, h, w) layout with n = 1.
    inputs:
        image_path: (str) path to the image file to be loaded and preprocessed
        width: (int) width of the output image (the image is resized if needed)
        height: (int) height of the output image (the image is resized if needed)
        preproc_mode: (str) preprocessing mode, 'caffe', 'tf' or 'torch' (case-insensitive)
            'caffe' loads the image with cv2 ('bgr' order) and subtracts caffe_mean per channel
            'tf'    loads the image in 'rgb' mode and maps it to [-1,+1]:
                    preproc_image_pixels = -1 + (image_pixels/127.5)
            'torch' loads the image in 'rgb' mode, scales it to [0,1] and normalizes it:
                    _image_pixels = (image_pixels/255.)
                    preproc_image_pixels = (_image_pixels - torch_means)/torch_std
        interpolation: (str) interpolation method, supported values are 'bilinear', 'nearest'
    outputs:
        img_array: numpy array of shape (1, channels, height, width)
    raises:
        ValueError: if preproc_mode is not supported, or (caffe mode) the image cannot be read
    '''
    mode = preproc_mode.lower()

    if mode == 'caffe':
        image = cv2.imread(image_path)  # BGR channel order
        if image is None:
            # cv2.imread signals an unreadable/missing file by returning None.
            raise ValueError(f'Could not read an image from \'{image_path}\'.')
        # Any value other than 'nearest' keeps the cv2 default (bilinear).
        cv2_interpolation = cv2.INTER_NEAREST if interpolation == 'nearest' else cv2.INTER_LINEAR
        # cv2.resize expects dsize as (width, height).
        image_resized = cv2.resize(image, (width, height), interpolation=cv2_interpolation)
        # Zero-mean the channels in float32, then add the batch dimension (b, h, w, c).
        input_image = image_resized.astype(np.float32) - np.asarray(caffe_mean, dtype=np.float32)
        img_array = np.expand_dims(input_image, axis=0)
    elif mode in ('tf', 'torch'):
        # load_img expects target_size as (height, width). The 'grayscale' keyword
        # is not passed because color_mode already selects RGB.
        img = tf.keras.utils.load_img(image_path, color_mode = 'rgb',
                                      target_size = (height, width), interpolation=interpolation)
        img_array = np.array([tf.keras.utils.img_to_array(img)])
        if mode == 'tf':
            img_array = -1 + img_array / 127.5
        else:
            img_array = img_array / 255.0
            img_array = (img_array - torch_means)
            img_array = img_array / torch_std
    else:
        raise ValueError('Only \'caffe\', \'tf\' or \'torch\' preprocessings are supported.')

    # (b, h, w, c) -> (b, c, h, w)
    img_array = img_array.transpose((0,3,1,2))
    return img_array
def evaluate_onnx_model(onnx_model_path, test_dir, img_width = 128, img_height = 128, save_path='float_onnx', preproc_mode = 'tf', interpolation = 'bilinear'):
    '''
    evaluate_onnx_model(onnx_model_path, test_dir, img_width, img_height, save_path, preproc_mode, interpolation)
    Evaluate the onnx model at onnx_model_path on the images found in test_dir,
    print the top-1 accuracy, and plot/save the confusion matrix.
    inputs:
        onnx_model_path: (str) path of the onnx model to evaluate
        test_dir: (str) path to the test dataset (one subdir per class with all images in them belonging to that class);
                  class indices follow the alphabetical order of the sub-directory names,
                  entries of test_dir that are not directories are ignored
        img_width: width of the image to be used for the preprocessing before passing to model
        img_height: height of the image to be used for the preprocessing before passing to model
        save_path: (str) path to save the confusion matrix (example: ./results_mobilenetv2/exports/model_name [without.png])
        preproc_mode: (str) preprocessing mode, forwarded to get_preprocessed_image
        interpolation: (str) interpolation type, supported types 'bilinear', 'nearest'
    outputs:
        test_acc: test accuracy (float number, rounded to 6 decimals)
        test_cm: test confusion matrix (nxn matrix, n = number of class sub-directories)
    '''
    _, sess = get_model(onnx_model_path)

    # Only sub-directories are classes; stray files in test_dir are skipped.
    class_labels = sorted(
        entry for entry in os.listdir(test_dir)
        if os.path.isdir(os.path.join(test_dir, entry))
    )

    gt_labels = []
    prd_labels = []
    for class_index, class_label in enumerate(class_labels):
        class_dir = os.path.join(test_dir, class_label)
        for file in os.listdir(class_dir):
            image_path = os.path.join(class_dir, file)
            img = get_preprocessed_image(image_path, width = img_width, height = img_height,
                                         preproc_mode = preproc_mode, interpolation = interpolation)
            # One image per call, so argmax yields a single predicted class index.
            pred = predict_onnx(sess, img).argmax(axis = 1)
            gt_labels.append(class_index)
            prd_labels.extend(pred.tolist())

    test_acc = round(accuracy_score(gt_labels, prd_labels), 6)
    print(f'Evaluation Top 1 accuracy : {test_acc}')
    # Passing the full label list keeps the matrix n x n even when a class has
    # no samples and is never predicted, so rows/columns stay aligned with
    # class_labels in the plot.
    test_cm = confusion_matrix(gt_labels, prd_labels, labels = list(range(len(class_labels))))
    plot_confusion_matrix(test_cm, class_labels = class_labels, save_path = save_path, test_accuracy = test_acc)
    return test_acc, test_cm
# ============================ Onnx Quantization ============================== #
from datetime import datetime
import time
import onnxruntime.quantization as quantization
from onnxruntime.quantization import QuantFormat, QuantType, quantize_static, CalibrationDataReader
def preprocess_image_batch(images_folder: str, height: int, width: int, preproc_mode = 'tf', interpolation = 'bilinear', size_limit=0):
    """
    preprocess_image_batch(images_folder, height, width, preproc_mode, interpolation, size_limit)
    Load the images found in a folder and preprocess each of them with get_preprocessed_image.
    inputs:
        images_folder: (str) path to the folder storing the images; images may sit
                       directly in the folder or inside sub-directories (e.g. one per class)
        height: (int) image height in pixels
        width: (int) image width in pixels
        preproc_mode: (str) preprocessing type, forwarded to get_preprocessed_image
        interpolation: (str) interpolation method, supported values are 'bilinear', or 'nearest'
        size_limit: (int) number of images to load. Default is 0 which means all images are picked.
    outputs:
        batch_data: numpy array stacking the preprocessed images along a new first axis,
                    i.e. one entry of shape (1, c, h, w) per image;
                    an empty array if no file was found
    raises:
        FileNotFoundError: if images_folder is not an existing directory
    """
    if not os.path.isdir(images_folder):
        raise FileNotFoundError(f"Image folder not found: {images_folder}")

    # Collect files recursively, in sorted order so that size_limit always
    # selects the same subset.
    image_paths = []
    for root, dirs, files in os.walk(images_folder):
        dirs.sort()
        image_paths.extend(os.path.join(root, name) for name in sorted(files))

    if size_limit > 0:
        image_paths = image_paths[:size_limit]

    if not image_paths:
        # Nothing to stack: keep returning an empty array rather than failing.
        return np.empty((0,))

    preprocessed = [
        get_preprocessed_image(path, width = width, height = height,
                               preproc_mode = preproc_mode, interpolation = interpolation)
        for path in image_paths
    ]
    batch_data = np.stack(preprocessed, axis=0)
    return batch_data
class ImageDataReader(CalibrationDataReader):
    '''
    Calibration data reader to pass to the quantize_static function for onnx model quantization.
    All calibration images are preprocessed once in the constructor and then served
    one at a time through get_next().
    '''

    def __init__(self, calibration_image_folder: str, model_path: str, preproc_mode: str, interpolation: str):
        '''
        inputs:
            calibration_image_folder: (str) a dataset to be used for performing the quantization (subset of the original training dataset)
            model_path: (str) path of the model whose first input defines the image size and the feed name
            preproc_mode: (str) preprocessing mode, forwarded to preprocess_image_batch
            interpolation: (str) interpolation method, supported values are 'bilinear' and 'nearest'
        '''
        # Iterator over the feed dicts; created lazily by get_next().
        self.enum_data = None

        # A throw-away inference session is only used to query the first input.
        session = rt.InferenceSession(model_path, None)
        first_input = session.get_inputs()[0]
        # The input shape is unpacked as (batch, channels, height, width).
        _, _, height, width = first_input.shape

        # Preprocess every calibration image up front (size_limit=0 -> no limit).
        self.nhwc_data_list = preprocess_image_batch(
            calibration_image_folder, height, width, preproc_mode, interpolation, size_limit=0
        )
        self.input_name = first_input.name
        self.datasize = len(self.nhwc_data_list)

    def get_next(self):
        '''Return the next {input_name: image} feed dict, or None when the data is exhausted.'''
        if self.enum_data is None:
            feeds = [{self.input_name: sample} for sample in self.nhwc_data_list]
            self.enum_data = iter(feeds)
        return next(self.enum_data, None)

    def rewind(self):
        '''Reset the reader so that the next get_next() call starts from the first image again.'''
        self.enum_data = None
def onnx_benchmark(model_path, runs=10):
    '''
    onnx_benchmark(model_path, runs)
    Measure the average onnxruntime inference time of a model on an all-zero float32 input
    and print it.
    inputs:
        model_path: (str) path of the onnx model to benchmark
        runs: (int) number of timed inferences to average over (default 10);
              one extra warm-up inference is executed first and not timed
    outputs:
        average inference time in milliseconds (float)
    raises:
        ValueError: if runs is smaller than 1
    '''
    if runs < 1:
        raise ValueError("runs must be at least 1")

    session = rt.InferenceSession(model_path)
    model_input = session.get_inputs()[0]
    input_name = model_input.name

    # Build the dummy input from the shape declared by the model instead of
    # assuming 3 channels. Dimensions that are not fixed positive integers
    # (dynamic/symbolic axes) are replaced by 1.
    input_shape = [dim if isinstance(dim, int) and dim > 0 else 1 for dim in model_input.shape]
    input_data = np.zeros(input_shape, np.float32)

    # Warming up
    _ = session.run([], {input_name: input_data})

    total = 0.0
    for _run_index in range(runs):
        start = time.perf_counter()
        _ = session.run([], {input_name: input_data})
        total += (time.perf_counter() - start) * 1000  # elapsed time in ms

    total /= runs
    print(f"Avg: {total:.2f}ms")
    return total
def quantize_onnx_model(input_model, calibration_dataset_path, preproc_mode = 'tf', interpolation='bilinear'):
    '''
    quantize_onnx_model(input_model, calibration_dataset_path, preproc_mode, interpolation)
    Quantize input_model with onnxruntime static quantization (QDQ format, int8 weights and
    activations, per-channel) using the images in calibration_dataset_path for calibration,
    then benchmark both the original and the quantized model.
    The quantized model is written next to input_model as '<name>_QDQ_quant.onnx'
    ('<name>_QDQ_fakequant.onnx' when calibration_dataset_path is None).
    inputs:
        input_model: (str) path of the onnx_model which is to be quantized using onnxruntime
        calibration_dataset_path: (str) path to the calibration dataset which will be used to quantize the model (a subset of the training dataset)
        preproc_mode: (str) preprocessing mode, forwarded to the calibration data reader
        interpolation: (str) interpolation method: supported options are 'bilinear', and 'nearest'
    outputs:
        path of the quantized model (str)
    raises:
        Exception: if input_model does not end with '.onnx'
    '''
    if not input_model.endswith('.onnx'):
        raise Exception("Error! The model must be in onnx format")

    def _log(message):
        # Prefix every progress message with the current wall-clock time.
        print(datetime.now().strftime("%H:%M:%S") + ' - ' + message)

    model_name = os.path.basename(input_model)
    _log('Found a model to be quantized: {}'.format(model_name))

    # set the data reader pointing to the representative dataset
    _log('Prepare the data reader for the representative dataset...')
    image_datareader = ImageDataReader(calibration_dataset_path, input_model, preproc_mode, interpolation)

    # prepare the quantized and the temporary pre-processed model filenames
    base_name, extension = os.path.splitext(input_model)
    suffix = '_QDQ_quant' if calibration_dataset_path is not None else '_QDQ_fakequant'
    quant_model = base_name + suffix + extension
    infer_model = base_name + '_infer' + extension

    _log('Quantize the model {}, please wait...'.format(model_name))
    try:
        # Pre-process the model into a temporary file, then calibrate and quantize it.
        quantization.quant_pre_process(input_model_path=input_model, output_model_path=infer_model)
        # NOTE(review): an earlier comment here said model optimization is turned
        # off during quantization, but optimize_model=True is what is passed;
        # the value is kept as-is - confirm the intended setting against the
        # installed onnxruntime version.
        quantize_static(
            infer_model,
            quant_model,
            image_datareader,
            quant_format=QuantFormat.QDQ,
            per_channel=True,
            weight_type=QuantType.QInt8,
            activation_type = QuantType.QInt8,
            optimize_model=True,
            reduce_range=True
        )
    finally:
        # delete the temp file, also when pre-processing or quantization failed
        if os.path.exists(infer_model):
            os.remove(infer_model)

    _log('{} model has been created'.format(os.path.basename(quant_model)))

    print("benchmarking fp32 model...")
    onnx_benchmark(input_model)
    print("benchmarking int8 model...")
    onnx_benchmark(quant_model)
    return quant_model
def change_opset(input_model, target_opset=14):
    '''
    change_opset(input_model, target_opset)
    Convert an onnx model to the requested opset version, in place.
    The converted model is first written to '<input_model>.temp', loaded with onnxruntime
    as a validity check, and only then moved over the original file.
    inputs:
        input_model: (str) path of the onnx model to convert
        target_opset: (int) opset version to convert the model to (default 14)
    outputs:
        input_model (str) when the model already uses target_opset or was converted,
        None when the converted model could not be loaded (the original file is left untouched)
    raises:
        Exception: if input_model does not end with '.onnx'
    '''
    if not input_model.endswith('.onnx'):
        raise Exception("Error! The model must be in onnx format")

    model = onnx.load(input_model)

    # Check the current opset version of the default ONNX operator domain
    # ('' or 'ai.onnx'); fall back to the first entry if it is not listed.
    current_opset = model.opset_import[0].version
    for opset in model.opset_import:
        if opset.domain in ('', 'ai.onnx'):
            current_opset = opset.version
            break

    if current_opset == target_opset:
        print(f"The model is already using opset {target_opset}")
        return input_model

    # Modify the opset version in the model
    converted_model = version_converter.convert_version(model, target_opset)
    temp_model_path = input_model + '.temp'
    onnx.save(converted_model, temp_model_path)

    # Load the modified model using ONNX Runtime to check that it is valid.
    # Creating the session is the step that fails on an invalid model, so it
    # has to be inside the try block.
    try:
        session = rt.InferenceSession(temp_model_path)
        session.get_inputs()
    except Exception as e:
        print(f"An error occurred while loading the modified model: {e}")
        # Do not leave the rejected temporary model behind.
        if os.path.exists(temp_model_path):
            os.remove(temp_model_path)
        return None

    # Replace the original model file with the modified model
    os.replace(temp_model_path, input_model)
    print(f"The model has been converted to opset {target_opset} and saved at the same location.")
    return input_model