
Commit 1d288f8

Merge pull request #134 from pipeless-ai/yolo-world
Support several inference outputs for ONNX Runtime + YOLO World example
2 parents 99676e4 + f15d183 · commit 1d288f8
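For hook authors, the practical effect of this change is that the inference output exposed to hooks is now a mapping with one entry per model output, instead of a single array. A minimal sketch, assuming an ONNX model whose outputs are named "boxes", "labels" and "scores" (as in the YOLO World example below) — the names depend entirely on the model:

    def hook(frame_data, _):
        outputs = frame_data["inference_output"]  # now a dict: output name -> array
        boxes = outputs.get("boxes", [])    # assumed output names, model-specific
        labels = outputs.get("labels", [])
        scores = outputs.get("scores", [])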

File tree

11 files changed: +267 −25 lines


examples/onnx-candy/post-process.py

+1 −1
@@ -2,7 +2,7 @@
 import cv2
 
 def hook(frame_data, _):
-    inference_results = frame_data["inference_output"]
+    inference_results = frame_data["inference_output"].get("output1", [])
     candy_image = inference_results[0] # Remove batch axis
     candy_image = np.clip(candy_image, 0, 255)
     candy_image = candy_image.transpose(1,2,0).astype("uint8")

examples/onnx-yolo/post-process.py

+1 −1
@@ -6,7 +6,7 @@ def hook(frame_data, _):
     model_output = frame_data['inference_output']
     if len(model_output) > 0:
         yolo_input_shape = (640, 640, 3) # h,w,c
-        boxes, scores, class_ids = postprocess_yolo(frame.shape, yolo_input_shape, model_output)
+        boxes, scores, class_ids = postprocess_yolo(frame.shape, yolo_input_shape, model_output.get("output0", []))
         class_labels = [yolo_classes[id] for id in class_ids]
         for i in range(len(boxes)):
             draw_bbox(frame, boxes[i], class_labels[i], scores[i], color_palette[class_ids[i]])

examples/yolo-world/post-process.py

+100
@@ -0,0 +1,100 @@ (new file)
import cv2
import numpy as np

def hook(frame_data, _):
    frame = frame_data['original']
    model_output = frame_data['inference_output']
    if len(model_output) > 0:
        yolo_input_shape = (640, 640, 3) # h,w,c
        boxes, scores, class_ids = postprocess_yolo_world(frame.shape, yolo_input_shape, model_output)
        class_labels = [yolo_classes[int(id)] for id in class_ids]
        for i in range(len(boxes)):
            draw_bbox(frame, boxes[i], class_labels[i], scores[i], color_palette[int(class_ids[i])])

    frame_data['modified'] = frame

#################################################
# Util functions to make the hook more readable #
#################################################
yolo_classes = ['hard hat', 'gloves', 'protective boot', 'reflective vest', 'person']
color_palette = np.random.uniform(0, 255, size=(len(yolo_classes), 3))

def draw_bbox(image, box, label='', score=None, color=(255, 0, 255), txt_color=(255, 255, 255)):
    lw = max(round(sum(image.shape) / 2 * 0.003), 2)
    p1, p2 = (int(box[0]), int(box[1])), (int(box[2]), int(box[3]))
    cv2.rectangle(image, p1, p2, color, thickness=lw, lineType=cv2.LINE_AA)
    if label:
        tf = max(lw - 1, 1)  # font thickness
        w, h = cv2.getTextSize(str(label), 0, fontScale=lw / 3, thickness=tf)[0]  # text width, height
        outside = p1[1] - h >= 3
        p2 = p1[0] + w, p1[1] - h - 3 if outside else p1[1] + h + 3
        cv2.rectangle(image, p1, p2, color, -1, cv2.LINE_AA)  # filled
        if score is not None:
            cv2.putText(image, f'{label} - {score}', (p1[0], p1[1] - 2 if outside else p1[1] + h + 2),
                        0, lw / 3, txt_color, thickness=tf, lineType=cv2.LINE_AA)
        else:
            cv2.putText(image, label, (p1[0], p1[1] - 2 if outside else p1[1] + h + 2),
                        0, lw / 3, txt_color, thickness=tf, lineType=cv2.LINE_AA)

def postprocess_yolo_world(original_frame_shape, resized_img_shape, output):
    original_height, original_width, _ = original_frame_shape
    resized_height, resized_width, _ = resized_img_shape

    boxes = np.array(output['boxes'][0])
    classes = np.array(output['labels'][0])
    scores = np.array(output['scores'][0])

    # Filter out negative indexes
    neg_indexes_classes = np.where(classes < 0)[0]
    neg_indexes_scores = np.where(scores < 0)[0]
    neg_indexes = np.concatenate((neg_indexes_classes, neg_indexes_scores))

    mask = np.ones(classes.shape, dtype=bool)
    mask[neg_indexes] = False

    boxes = boxes[mask]
    classes = classes[mask]
    scores = scores[mask]

    # Arrays to accumulate the results
    result_boxes = []
    result_classes = []
    result_scores = []

    # Calculate the scaling factor for the bounding box coordinates
    if original_height > original_width:
        scale_factor = original_height / resized_height
    else:
        scale_factor = original_width / resized_width

    # Resize the output boxes
    for i, score in enumerate(scores):
        if score < 0.05:  # apply confidence threshold
            continue
        if not score < 1:
            continue  # Remove bad predictions that return a score of 1.0

        x1, y1, x2, y2 = boxes[i][0], boxes[i][1], boxes[i][2], boxes[i][3]

        ## Calculate the scaled coordinates of the bounding box;
        ## the original image was padded to be square
        if original_height > original_width:
            # We added padding on the width
            pad = (resized_width - original_width / scale_factor) // 2
            x1 = int((x1 - pad) * scale_factor)
            y1 = int(y1 * scale_factor)
            x2 = int((x2 - pad) * scale_factor)
            y2 = int(y2 * scale_factor)
        else:
            # We added padding on the height
            pad = (resized_height - original_height / scale_factor) // 2
            x1 = int(x1 * scale_factor)
            y1 = int((y1 - pad) * scale_factor)
            x2 = int(x2 * scale_factor)
            y2 = int((y2 - pad) * scale_factor)

        result_classes.append(classes[i])
        result_scores.append(score)
        result_boxes.append([x1, y1, x2, y2])

    return result_boxes, result_scores, result_classes
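To sanity-check the unpadding arithmetic above, a worked example with hypothetical numbers: a 1080×1920 (h,w) original frame letterboxed into the 640×640 model input. Width is the long side, so the else branch applies:

    scale_factor = 1920 / 640               # = 3.0
    pad = (640 - 1080 / scale_factor) // 2  # = (640 - 360) // 2 = 140 px above and below
    # A model-space box edge y1 = 150 maps back to the original frame as:
    y1 = int((150 - pad) * scale_factor)    # = int(10 * 3.0) = 30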

examples/yolo-world/pre-process.py

+50
@@ -0,0 +1,50 @@ (new file)
import cv2
import numpy as np

def is_cuda_available():
    return cv2.cuda.getCudaEnabledDeviceCount() > 0

"""
Resize and pad image. Uses CUDA when available
"""
def resize_and_pad(frame, target_dim, pad_top, pad_bottom, pad_left, pad_right):
    target_height, target_width = target_dim
    if is_cuda_available():
        # FIXME: due to the memory allocation here, this could be even slower than running on CPU.
        # We should provide the frame to the hook already in GPU memory
        frame_gpu = cv2.cuda_GpuMat(frame)
        resized_frame_gpu = cv2.cuda.resize(frame_gpu, (target_width, target_height), interpolation=cv2.INTER_CUBIC)
        padded_frame_gpu = cv2.cuda.copyMakeBorder(resized_frame_gpu, pad_top, pad_bottom, pad_left, pad_right, cv2.BORDER_CONSTANT, value=(0, 0, 0))
        result = padded_frame_gpu.download()
        return result
    else:
        resized_frame = cv2.resize(frame, (target_width, target_height), interpolation=cv2.INTER_CUBIC)
        padded_frame = cv2.copyMakeBorder(resized_frame, pad_top, pad_bottom, pad_left, pad_right,
                                          borderType=cv2.BORDER_CONSTANT, value=(0, 0, 0))
        return padded_frame

def resize_with_padding(frame, target_dim):
    target_height, target_width, _ = target_dim
    frame_height, frame_width, _ = frame.shape

    width_ratio = target_width / frame_width
    height_ratio = target_height / frame_height
    # Choose the minimum scaling factor to maintain aspect ratio
    scale_factor = min(width_ratio, height_ratio)
    # Calculate new dimensions after resizing
    new_width = int(frame_width * scale_factor)
    new_height = int(frame_height * scale_factor)
    # Calculate padding per side; put any odd remainder on the bottom/right
    # so the padded frame is exactly target-sized
    pad_width = (target_width - new_width) // 2
    pad_height = (target_height - new_height) // 2

    padded_image = resize_and_pad(frame, (new_height, new_width),
                                  pad_height, target_height - new_height - pad_height,
                                  pad_width, target_width - new_width - pad_width)
    return padded_image

def hook(frame_data, _):
    frame = frame_data["original"].view()
    yolo_input_shape = (640, 640, 3) # h,w,c
    frame = resize_with_padding(frame, yolo_input_shape)
    frame = np.array(frame) / 255.0 # Normalize pixel values
    frame = np.transpose(frame, axes=(2,0,1)) # Convert to c,h,w
    inference_inputs = frame.astype("float32")
    frame_data['inference_input'] = inference_inputs
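A quick way to exercise this hook standalone, assuming only numpy and the functions above (the frame here is a hypothetical dummy):

    import numpy as np

    frame_data = {"original": np.zeros((1080, 1920, 3), dtype=np.uint8)}  # dummy 1080p RGB frame
    hook(frame_data, None)
    print(frame_data["inference_input"].shape)  # -> (3, 640, 640), ready for the ONNX session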

examples/yolo-world/process.json

+7
@@ -0,0 +1,7 @@ (new file)
{
    "runtime": "onnx",
    "model_uri": "https://pipeless-public.s3.eu-west-3.amazonaws.com/yolow-l-ppe.onnx",
    "inference_params": {
        "execution_provider": "cpu"
    }
}

pipeless/Cargo.lock

+1 −1
Some generated files are not rendered by default.

pipeless/Cargo.toml

+1 −1
@@ -1,6 +1,6 @@
 [package]
 name = "pipeless-ai"
-version = "1.10.0"
+version = "1.11.0"
 edition = "2021"
 authors = ["Miguel A. Cabrera Minagorri"]
 description = "An open-source computer vision framework to build and deploy applications in minutes"

pipeless/src/data.rs

+13 −7
@@ -14,6 +14,11 @@ pub enum UserData {
     Dictionary(Vec<(String, UserData)>),
 }
 
+pub enum InferenceOutput {
+    Default(ndarray::ArrayBase<ndarray::OwnedRepr<f32>, ndarray::Dim<ndarray::IxDynImpl>>),
+    OnnxInferenceOutput(crate::stages::inference::onnx::OnnxInferenceOutput)
+}
+
 pub struct RgbFrame {
     uuid: uuid::Uuid,
     original: ndarray::Array3<u8>,
@@ -26,7 +31,8 @@ pub struct RgbFrame {
     fps: u8,
     input_ts: f64, // epoch in seconds
     inference_input: ndarray::ArrayBase<ndarray::OwnedRepr<f32>, ndarray::Dim<ndarray::IxDynImpl>>,
-    inference_output: ndarray::ArrayBase<ndarray::OwnedRepr<f32>, ndarray::Dim<ndarray::IxDynImpl>>,
+    // We could convert the output into an array view, since the user does not need to modify it and the inference runtimes return a view, which would avoid a copy
+    inference_output: InferenceOutput,
     pipeline_id: uuid::Uuid,
     user_data: UserData,
     frame_number: u64,
@@ -47,7 +53,7 @@ impl RgbFrame {
         pts, dts, duration, fps,
         input_ts,
         inference_input: ndarray::ArrayBase::zeros(ndarray::IxDyn(&[0])),
-        inference_output: ndarray::ArrayBase::zeros(ndarray::IxDyn(&[0])),
+        inference_output: InferenceOutput::Default(ndarray::ArrayBase::zeros(ndarray::IxDyn(&[0]))),
         pipeline_id,
         user_data: UserData::Empty,
         frame_number,
@@ -62,7 +68,7 @@ impl RgbFrame {
         pts: u64, dts: u64, duration: u64,
         fps: u8, input_ts: f64,
         inference_input: ndarray::ArrayBase<ndarray::OwnedRepr<f32>, ndarray::Dim<ndarray::IxDynImpl>>,
-        inference_output: ndarray::ArrayBase<ndarray::OwnedRepr<f32>, ndarray::Dim<ndarray::IxDynImpl>>,
+        inference_output: InferenceOutput,
         pipeline_id: &str,
         user_data: UserData, frame_number: u64,
     ) -> Self {
@@ -122,13 +128,13 @@ impl RgbFrame {
     pub fn get_inference_input(&self) -> &ndarray::ArrayBase<ndarray::OwnedRepr<f32>, ndarray::Dim<ndarray::IxDynImpl>> {
         &self.inference_input
     }
-    pub fn get_inference_output(&self) -> &ndarray::ArrayBase<ndarray::OwnedRepr<f32>, ndarray::Dim<ndarray::IxDynImpl>> {
+    pub fn get_inference_output(&self) -> &InferenceOutput {
         &self.inference_output
     }
     pub fn set_inference_input(&mut self, input_data: ndarray::ArrayBase<ndarray::OwnedRepr<f32>, ndarray::Dim<ndarray::IxDynImpl>>) {
         self.inference_input = input_data;
     }
-    pub fn set_inference_output(&mut self, output_data: ndarray::ArrayBase<ndarray::OwnedRepr<f32>, ndarray::Dim<ndarray::IxDynImpl>>) {
+    pub fn set_inference_output(&mut self, output_data: InferenceOutput) {
         self.inference_output = output_data;
     }
     pub fn get_pipeline_id(&self) -> &uuid::Uuid {
@@ -180,7 +186,7 @@ impl Frame {
             Frame::RgbFrame(frame) => frame.get_inference_input()
         }
     }
-    pub fn get_inference_output(&self) -> &ndarray::ArrayBase<ndarray::OwnedRepr<f32>, ndarray::Dim<ndarray::IxDynImpl>> {
+    pub fn get_inference_output(&self) -> &InferenceOutput {
         match self {
             Frame::RgbFrame(frame) => frame.get_inference_output()
         }
@@ -190,7 +196,7 @@ impl Frame {
             Frame::RgbFrame(frame) => { frame.set_inference_input(input_data); },
         }
     }
-    pub fn set_inference_output(&mut self, output_data: ndarray::ArrayBase<ndarray::OwnedRepr<f32>, ndarray::Dim<ndarray::IxDynImpl>>) {
+    pub fn set_inference_output(&mut self, output_data: InferenceOutput) {
         match self {
             Frame::RgbFrame(frame) => { frame.set_inference_output(output_data); },
         }
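From a Python hook's point of view, the two enum variants above surface as two possible shapes of frame_data['inference_output']. A hedged sketch of a hook that copes with both — the dict/array mapping is an assumption about how the bindings expose the enum, not confirmed by this diff:

    def hook(frame_data, _):
        out = frame_data["inference_output"]
        if isinstance(out, dict):   # ONNX runtime: one entry per named model output
            boxes = out.get("boxes", [])
        else:                       # Default variant: a single array, as before this change
            boxes = out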

pipeless/src/events.rs

+2
@@ -203,6 +203,8 @@ pub fn publish_new_frame_change_event_sync(
 ) {
     let new_frame_event = Event::new_frame_change(frame);
     // By using try_send, frames are discarded when the channel is full
+    // However, note this does not produce a fluid output video. For that, instead of discarding
+    // the frame, we would need to send it to the output without processing it
     if let Err(err) = bus_sender.try_send(new_frame_event) {
         debug!("Discarding frame: {}", err);
     }

pipeless/src/stages/inference/onnx.rs

+41 −4
@@ -1,8 +1,12 @@
+use std::collections::HashMap;
+
 use log::{error, warn};
 use ort;
 
 use crate as pipeless;
 
+pub type OnnxInferenceOutput = HashMap<String, ndarray::ArrayBase<ndarray::OwnedRepr<f32>, ndarray::Dim<ndarray::IxDynImpl>>>;
+
 pub struct OnnxSessionParams {
     stage_name: String, // Name of the stage this session belongs to
     execution_provider: String, // The user has to provide the execution provider
@@ -163,10 +167,43 @@ impl super::session::SessionTrait for OnnxSession {
         match self.session.run_with_binding(&io_bindings) {
             Ok(()) => {
                 let outputs = io_bindings.outputs().unwrap();
-                // TODO: iterate over the outputs hashmap to return all the model outputs not just the first
-                let output = outputs[&self.session.outputs[0].name].try_extract().unwrap();
-                let output_ndarray = output.view().to_owned();
-                frame.set_inference_output(output_ndarray);
+                let mut frame_inference_output = OnnxInferenceOutput::new();
+                for (output_name, output_value) in outputs {
+                    // FIXME: this extract code is inelegant. try_extract can return several different numeric types
+                    // depending on the model used, and there is no number wrapper we can apply, so we check type by type
+                    match output_value.try_extract() {
+                        Ok(output) => {
+                            // FIXME: we could use an array view for the inference output instead of an owned array to avoid copying here
+                            let output_ndarray = output.view().to_owned();
+                            frame_inference_output.insert(output_name, output_ndarray);
+                        },
+                        Err(_err) => {
+                            // Try to convert from i64, since some models do not return floats
+                            match output_value.try_extract() {
+                                Ok(output) => {
+                                    // FIXME: this copies the array twice, first to_owned and then mapv
+                                    let output_ndarray: ndarray::ArrayBase<ndarray::OwnedRepr<i64>, _> = output.view().to_owned();
+                                    let float_output = output_ndarray.mapv(|v| v as f32);
+                                    frame_inference_output.insert(output_name, float_output);
+                                }
+                                Err(_err) => {
+                                    // Try to convert from i32 as a last resort
+                                    match output_value.try_extract() {
+                                        Ok(output) => {
+                                            // FIXME: this copies the array twice, first to_owned and then mapv
+                                            let output_ndarray: ndarray::ArrayBase<ndarray::OwnedRepr<i32>, _> = output.view().to_owned();
+                                            let float_output = output_ndarray.mapv(|v| v as f32);
+                                            frame_inference_output.insert(output_name, float_output);
+                                        }
+                                        Err(err) => warn!("Error extracting inference results: {}", err.to_string()),
+                                    }
+                                },
+                            }
+                        }
+                    }
+                }
+                frame.set_inference_output(pipeless::data::InferenceOutput::OnnxInferenceOutput(frame_inference_output));
             },
             Err(err) => error!("There was an error running inference: {}", err)
         }
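For intuition, the same dtype normalization expressed in Python terms — a sketch only; the Rust code above does this through ort's try_extract fallbacks, not through numpy:

    import numpy as np

    def to_float32(arr):
        # Models may emit float32, int64 or int32 outputs; normalize everything
        # to float32, mirroring the nested try_extract fallback above
        if arr.dtype == np.float32:
            return arr
        if arr.dtype in (np.int64, np.int32):
            return arr.astype(np.float32)
        raise TypeError(f"unsupported output dtype: {arr.dtype}")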
