
Commit 884215f

Add function to visualize Yolo results
1 parent 86fc806 commit 884215f

File tree

1 file changed (+141, -0)

src/datachain/toolkit/ultralytics.py

@@ -0,0 +1,141 @@
from typing import Union

import numpy as np
import torch
from PIL import Image
from ultralytics.engine.results import Results

from datachain.model.ultralytics.bbox import YoloBBox, YoloBBoxes
from datachain.model.ultralytics.pose import YoloPose, YoloPoses
from datachain.model.ultralytics.segment import YoloSegment, YoloSegments

YoloSignal = Union[YoloBBox, YoloBBoxes, YoloPose, YoloPoses, YoloSegment, YoloSegments]


def _signal_to_results(img: np.ndarray, signal: YoloSignal) -> Results:
    # Convert RGB to BGR
    if img.ndim == 3 and img.shape[2] == 3:
        bgr_array = img[:, :, ::-1]
    else:
        # If the image is not RGB (e.g., grayscale or RGBA), use as is
        bgr_array = img

    names = {}
    boxes_list = []
    keypoints_list = []
    masks_list = []

    # Get the boxes, keypoints, and masks from the signal
    if isinstance(signal, YoloBBox):
        names[signal.cls] = signal.name
        boxes_list.append(
            torch.tensor([[*signal.box.coords, signal.confidence, signal.cls]])
        )
    elif isinstance(signal, YoloBBoxes):
        for i, _ in enumerate(signal.cls):
            names[signal.cls[i]] = signal.name[i]
            boxes_list.append(
                torch.tensor(
                    [[*signal.box[i].coords, signal.confidence[i], signal.cls[i]]]
                )
            )
    elif isinstance(signal, YoloPose):
        names[signal.cls] = signal.name
        boxes_list.append(
            torch.tensor([[*signal.box.coords, signal.confidence, signal.cls]])
        )
        keypoints_list.append(
            torch.tensor([list(zip(signal.pose.x, signal.pose.y, signal.pose.visible))])
        )
    elif isinstance(signal, YoloPoses):
        for i, _ in enumerate(signal.cls):
            names[signal.cls[i]] = signal.name[i]
            boxes_list.append(
                torch.tensor(
                    [[*signal.box[i].coords, signal.confidence[i], signal.cls[i]]]
                )
            )
            keypoints_list.append(
                torch.tensor(
                    [
                        list(
                            zip(
                                signal.pose[i].x,
                                signal.pose[i].y,
                                signal.pose[i].visible,
                            )
                        )
                    ]
                )
            )
    elif isinstance(signal, YoloSegment):
        names[signal.cls] = signal.name
        boxes_list.append(
            torch.tensor([[*signal.box.coords, signal.confidence, signal.cls]])
        )
        masks_list.append(torch.tensor([list(zip(signal.segment.x, signal.segment.y))]))
    elif isinstance(signal, YoloSegments):
        for i, _ in enumerate(signal.cls):
            names[signal.cls[i]] = signal.name[i]
            boxes_list.append(
                torch.tensor(
                    [[*signal.box[i].coords, signal.confidence[i], signal.cls[i]]]
                )
            )
            masks_list.append(
                torch.tensor([list(zip(signal.segment[i].x, signal.segment[i].y))])
            )

    boxes = torch.cat(boxes_list, dim=0) if len(boxes_list) > 0 else None
    keypoints = torch.cat(keypoints_list, dim=0) if len(keypoints_list) > 0 else None
    masks = torch.cat(masks_list, dim=0) if len(masks_list) > 0 else None

    return Results(
        bgr_array,
        path="",
        names=names,
        boxes=boxes,
        keypoints=keypoints,
        masks=masks,
    )


def visualize_yolo(
    img: np.ndarray,
    signal: YoloSignal,
    scale: float = 1.0,
    line_width: int = 1,
    font_size: int = 20,
    kpt_radius: int = 3,
) -> Image.Image:
    """
    Visualize signals detected by YOLO.

    Args:
        img (np.ndarray): The image to visualize as a NumPy array.
        signal: The signal detected by YOLO. Possible signals are YoloBBox, YoloBBoxes,
            YoloPose, YoloPoses, YoloSegment, and YoloSegments.
        scale (float): The scale factor for the image. Default is 1.0.
        line_width (int): The line width for drawing boxes and lines. Default is 1.
        font_size (int): The font size for text. Default is 20.
        kpt_radius (int): The radius for drawing keypoints. Default is 3.

    Returns:
        PIL.Image.Image: The image with the detected signals visualized.
    """
    results = _signal_to_results(img, signal)

    im_bgr = results.plot(
        line_width=line_width,
        font_size=font_size,
        kpt_radius=kpt_radius,
    )

    im_rgb = Image.fromarray(im_bgr[..., ::-1])

    if scale != 1.0:
        orig_height, orig_width = results.orig_shape
        new_size = (int(orig_width * scale), int(orig_height * scale))
        im_rgb = im_rgb.resize(new_size, Image.Resampling.LANCZOS)

    return im_rgb
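For context, a minimal usage sketch (not part of this commit): it builds a single detection by hand and renders it with visualize_yolo. The YoloBBox fields used here (cls, name, confidence, box) and a box carrying pixel (x1, y1, x2, y2) coords mirror how _signal_to_results reads them above; the datachain.model.bbox.BBox import path and all concrete values are assumptions for illustration.

import numpy as np

from datachain.model.bbox import BBox  # assumed import path for the box type
from datachain.model.ultralytics.bbox import YoloBBox
from datachain.toolkit.ultralytics import visualize_yolo

# A blank 640x480 RGB frame; in practice this would come from PIL, cv2, or a dataset.
img = np.zeros((480, 640, 3), dtype=np.uint8)

# One hypothetical detection: class 0 ("person") with a pixel-space box.
signal = YoloBBox(
    cls=0,
    name="person",
    confidence=0.92,
    box=BBox(title="person", coords=[100, 80, 300, 400]),
)

# Draw the labeled box and return the upscaled preview as a PIL image.
preview = visualize_yolo(img, signal, scale=2.0, line_width=2)
preview.save("preview.png")

The two channel flips in the implementation follow ultralytics conventions: Results expects its original image in BGR (OpenCV order) and Results.plot() returns a BGR array, so the RGB input is flipped on the way in and the plotted array is flipped back before constructing the PIL image.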
