Skip to content

Commit bd9e866

Browse files
authored
Improve models and file methods (#966)
* Improve models and file methods - BBox model now has methods to convert to/from different bbox formats: Albumentations, COCO, PASCAL VOC and YOLO - Add more validations and tests for BBox, Pose and Segment models - Add tests for Ultralytics models - Add File conversion methods: to TextFile, ImageFile and VideoFile (see Ultralytics examples for usage) - Add 'image_info' method to get meta from ImageFile (same as for video file) * Update 'openimage-detect' example to use new BBox.from_albumentations method * Add more tests, refactor File class 'read' and 'read_bytes' methods
1 parent 6846558 commit bd9e866

32 files changed

+3597
-265
lines changed

examples/computer_vision/openimage-detect.py

+3-7
Original file line numberDiff line numberDiff line change
@@ -22,13 +22,9 @@ def openimage_detect(args):
2222
detections = json.load(stream_json).get("detections", [])
2323

2424
for i, detect in enumerate(detections):
25-
bbox = model.BBox.from_list(
26-
[
27-
detect["XMin"] * img.width,
28-
detect["XMax"] * img.width,
29-
detect["YMin"] * img.height,
30-
detect["YMax"] * img.height,
31-
]
25+
bbox = model.BBox.from_albumentations(
26+
[detect[k] for k in ("XMin", "YMin", "XMax", "YMax")],
27+
img_size=(img.width, img.height),
3228
)
3329

3430
fstream = File(

examples/computer_vision/ultralytics-bbox.py

+1-9
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,11 @@
1-
import os
2-
3-
os.environ["YOLO_VERBOSE"] = "false"
4-
5-
6-
from io import BytesIO
7-
8-
from PIL import Image
91
from ultralytics import YOLO
102

113
from datachain import C, DataChain, File
124
from datachain.model.ultralytics import YoloBBoxes
135

146

157
def process_bboxes(yolo: YOLO, file: File) -> YoloBBoxes:
16-
results = yolo(Image.open(BytesIO(file.read())))
8+
results = yolo(file.as_image_file().read(), verbose=False)
179
return YoloBBoxes.from_results(results)
1810

1911

examples/computer_vision/ultralytics-pose.py

+1-9
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,11 @@
1-
import os
2-
3-
os.environ["YOLO_VERBOSE"] = "false"
4-
5-
6-
from io import BytesIO
7-
8-
from PIL import Image
91
from ultralytics import YOLO
102

113
from datachain import C, DataChain, File
124
from datachain.model.ultralytics import YoloPoses
135

146

157
def process_poses(yolo: YOLO, file: File) -> YoloPoses:
16-
results = yolo(Image.open(BytesIO(file.read())))
8+
results = yolo(file.as_image_file().read(), verbose=False)
179
return YoloPoses.from_results(results)
1810

1911

examples/computer_vision/ultralytics-segment.py

+1-9
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,11 @@
1-
import os
2-
3-
os.environ["YOLO_VERBOSE"] = "false"
4-
5-
6-
from io import BytesIO
7-
8-
from PIL import Image
91
from ultralytics import YOLO
102

113
from datachain import C, DataChain, File
124
from datachain.model.ultralytics import YoloSegments
135

146

157
def process_segments(yolo: YOLO, file: File) -> YoloSegments:
16-
results = yolo(Image.open(BytesIO(file.read())))
8+
results = yolo(file.as_image_file().read(), verbose=False)
179
return YoloSegments.from_results(results)
1810

1911

pyproject.toml

+3-2
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,8 @@ tests = [
102102
"hypothesis",
103103
"aiotools>=1.7.0",
104104
"requests-mock",
105-
"scipy"
105+
"scipy",
106+
"ultralytics"
106107
]
107108
dev = [
108109
"datachain[docs,tests]",
@@ -118,7 +119,7 @@ examples = [
118119
"defusedxml",
119120
"accelerate",
120121
"huggingface_hub[hf_transfer]",
121-
"ultralytics==8.3.87",
122+
"ultralytics",
122123
"open_clip_torch"
123124
]
124125

src/datachain/lib/file.py

+48-8
Original file line numberDiff line numberDiff line change
@@ -242,6 +242,30 @@ def __init__(self, **kwargs):
242242
self._catalog = None
243243
self._caching_enabled: bool = False
244244

245+
def as_text_file(self) -> "TextFile":
246+
"""Convert the file to a `TextFile` object."""
247+
if isinstance(self, TextFile):
248+
return self
249+
file = TextFile(**self.model_dump())
250+
file._set_stream(self._catalog, caching_enabled=self._caching_enabled)
251+
return file
252+
253+
def as_image_file(self) -> "ImageFile":
254+
"""Convert the file to a `ImageFile` object."""
255+
if isinstance(self, ImageFile):
256+
return self
257+
file = ImageFile(**self.model_dump())
258+
file._set_stream(self._catalog, caching_enabled=self._caching_enabled)
259+
return file
260+
261+
def as_video_file(self) -> "VideoFile":
262+
"""Convert the file to a `VideoFile` object."""
263+
if isinstance(self, VideoFile):
264+
return self
265+
file = VideoFile(**self.model_dump())
266+
file._set_stream(self._catalog, caching_enabled=self._caching_enabled)
267+
return file
268+
245269
@classmethod
246270
def upload(
247271
cls, data: bytes, path: str, catalog: Optional["Catalog"] = None
@@ -291,20 +315,20 @@ def open(self, mode: Literal["rb", "r"] = "rb") -> Iterator[Any]:
291315
) as f:
292316
yield io.TextIOWrapper(f) if mode == "r" else f
293317

294-
def read(self, length: int = -1):
295-
"""Returns file contents."""
318+
def read_bytes(self, length: int = -1):
319+
"""Returns file contents as bytes."""
296320
with self.open() as stream:
297321
return stream.read(length)
298322

299-
def read_bytes(self):
300-
"""Returns file contents as bytes."""
301-
return self.read()
302-
303323
def read_text(self):
304324
"""Returns file contents as text."""
305325
with self.open(mode="r") as stream:
306326
return stream.read()
307327

328+
def read(self, length: int = -1):
329+
"""Returns file contents."""
330+
return self.read_bytes(length)
331+
308332
def save(self, destination: str, client_config: Optional[dict] = None):
309333
"""Writes it's content to destination"""
310334
destination = stringify_path(destination)
@@ -547,20 +571,36 @@ def save(self, destination: str, client_config: Optional[dict] = None):
547571
class ImageFile(File):
548572
"""`DataModel` for reading image files."""
549573

574+
def get_info(self) -> "Image":
575+
"""
576+
Retrieves metadata and information about the image file.
577+
578+
Returns:
579+
Image: A Model containing image metadata such as width, height and format.
580+
"""
581+
from .image import image_info
582+
583+
return image_info(self)
584+
550585
def read(self):
551586
"""Returns `PIL.Image.Image` object."""
552587
from PIL import Image as PilImage
553588

554589
fobj = super().read()
555590
return PilImage.open(BytesIO(fobj))
556591

557-
def save(self, destination: str, client_config: Optional[dict] = None):
592+
def save( # type: ignore[override]
593+
self,
594+
destination: str,
595+
format: Optional[str] = None,
596+
client_config: Optional[dict] = None,
597+
):
558598
"""Writes it's content to destination"""
559599
destination = stringify_path(destination)
560600

561601
client: Client = self._catalog.get_client(destination, **(client_config or {}))
562602
with client.fs.open(destination, mode="wb") as f:
563-
self.read().save(f)
603+
self.read().save(f, format=format)
564604

565605

566606
class Image(DataModel):

src/datachain/lib/image.py

+30-6
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,41 @@
11
from typing import Callable, Optional, Union
22

33
import torch
4-
from PIL import Image
4+
from PIL import Image as PILImage
5+
6+
from datachain.lib.file import File, FileError, Image, ImageFile
7+
8+
9+
def image_info(file: Union[File, ImageFile]) -> Image:
10+
"""
11+
Returns image file information.
12+
13+
Args:
14+
file (ImageFile): Image file object.
15+
16+
Returns:
17+
Image: Image file information.
18+
"""
19+
try:
20+
img = file.as_image_file().read()
21+
except Exception as exc:
22+
raise FileError(file, "unable to open image file") from exc
23+
24+
return Image(
25+
width=img.width,
26+
height=img.height,
27+
format=img.format or "",
28+
)
529

630

731
def convert_image(
8-
img: Image.Image,
32+
img: PILImage.Image,
933
mode: str = "RGB",
1034
size: Optional[tuple[int, int]] = None,
1135
transform: Optional[Callable] = None,
1236
encoder: Optional[Callable] = None,
1337
device: Optional[Union[str, torch.device]] = None,
14-
) -> Union[Image.Image, torch.Tensor]:
38+
) -> Union[PILImage.Image, torch.Tensor]:
1539
"""
1640
Resize, transform, and otherwise convert an image.
1741
@@ -47,13 +71,13 @@ def convert_image(
4771

4872

4973
def convert_images(
50-
images: Union[Image.Image, list[Image.Image]],
74+
images: Union[PILImage.Image, list[PILImage.Image]],
5175
mode: str = "RGB",
5276
size: Optional[tuple[int, int]] = None,
5377
transform: Optional[Callable] = None,
5478
encoder: Optional[Callable] = None,
5579
device: Optional[Union[str, torch.device]] = None,
56-
) -> Union[list[Image.Image], torch.Tensor]:
80+
) -> Union[list[PILImage.Image], torch.Tensor]:
5781
"""
5882
Resize, transform, and otherwise convert one or more images.
5983
@@ -65,7 +89,7 @@ def convert_images(
6589
encoder (Callable): Encode image using model.
6690
device (str or torch.device): Device to use.
6791
"""
68-
if isinstance(images, Image.Image):
92+
if isinstance(images, PILImage.Image):
6993
images = [images]
7094

7195
converted = [

src/datachain/lib/video.py

+5-3
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
11
import posixpath
22
import shutil
33
import tempfile
4-
from typing import Optional
4+
from typing import Optional, Union
55

66
from numpy import ndarray
77

8-
from datachain.lib.file import FileError, ImageFile, Video, VideoFile
8+
from datachain.lib.file import File, FileError, ImageFile, Video, VideoFile
99

1010
try:
1111
import ffmpeg
@@ -18,7 +18,7 @@
1818
) from exc
1919

2020

21-
def video_info(file: VideoFile) -> Video:
21+
def video_info(file: Union[File, VideoFile]) -> Video:
2222
"""
2323
Returns video file information.
2424
@@ -28,6 +28,8 @@ def video_info(file: VideoFile) -> Video:
2828
Returns:
2929
Video: Video file information.
3030
"""
31+
file = file.as_video_file()
32+
3133
if not (file_path := file.get_local_path()):
3234
file.ensure_cached()
3335
file_path = file.get_local_path()

0 commit comments

Comments
 (0)