Skip to content

Commit 715252a

Browse files
feat: hybrid pydantic support for both v1 and v2 (#1652)
Signed-off-by: samsja <[email protected]> Signed-off-by: samsja <[email protected]> Co-authored-by: Johannes Messner <[email protected]>
1 parent 805a982 commit 715252a

File tree

101 files changed

+1596
-1174
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

101 files changed

+1596
-1174
lines changed

.github/workflows/ci.yml

+32-18
Original file line numberDiff line numberDiff line change
@@ -71,21 +71,21 @@ jobs:
7171
- name: Test basic import
7272
run: poetry run python -c 'from docarray import DocList, BaseDoc'
7373

74-
75-
check-mypy:
76-
runs-on: ubuntu-latest
77-
steps:
78-
- uses: actions/[email protected]
79-
- name: Set up Python 3.8
80-
uses: actions/setup-python@v4
81-
with:
82-
python-version: 3.8
83-
- name: check mypy
84-
run: |
85-
python -m pip install --upgrade pip
86-
python -m pip install poetry
87-
poetry install --all-extras
88-
poetry run mypy docarray
74+
# it is time to say bye bye to mypy because of the way we handle support of pydantic v1 and v2
75+
# check-mypy:
76+
# runs-on: ubuntu-latest
77+
# steps:
78+
# - uses: actions/[email protected]
79+
# - name: Set up Python 3.8
80+
# uses: actions/setup-python@v4
81+
# with:
82+
# python-version: 3.8
83+
# - name: check mypy
84+
# run: |
85+
# python -m pip install --upgrade pip
86+
# python -m pip install poetry
87+
# poetry install --all-extras
88+
# poetry run mypy docarray
8989

9090

9191
docarray-test:
@@ -95,6 +95,7 @@ jobs:
9595
fail-fast: false
9696
matrix:
9797
python-version: [3.8]
98+
pydantic-version: ["pydantic-v2", "pydantic-v1"]
9899
test-path: [tests/integrations, tests/units, tests/documentation]
99100
steps:
100101
- uses: actions/[email protected]
@@ -108,11 +109,12 @@ jobs:
108109
python -m pip install poetry
109110
poetry install --all-extras
110111
poetry run pip install elasticsearch==8.6.2
112+
./scripts/install_pydantic_v2.sh ${{ matrix.pydantic-version }}
111113
poetry run pip uninstall -y torch
112114
poetry run pip install torch
113115
sudo apt-get update
114116
sudo apt-get install --no-install-recommends ffmpeg
115-
117+
116118
- name: Test
117119
id: test
118120
run: |
@@ -145,6 +147,7 @@ jobs:
145147
fail-fast: false
146148
matrix:
147149
python-version: [3.8]
150+
pydantic-version: ["pydantic-v2", "pydantic-v1"]
148151
steps:
149152
- uses: actions/[email protected]
150153
- name: Set up Python ${{ matrix.python-version }}
@@ -156,6 +159,7 @@ jobs:
156159
python -m pip install --upgrade pip
157160
python -m pip install poetry
158161
poetry install --all-extras
162+
./scripts/install_pydantic_v2.sh ${{ matrix.pydantic-version }}
159163
poetry run pip install elasticsearch==8.6.2
160164
poetry run pip uninstall -y torch
161165
poetry run pip install torch
@@ -193,6 +197,7 @@ jobs:
193197
fail-fast: false
194198
matrix:
195199
python-version: [3.8]
200+
pydantic-version: ["pydantic-v2", "pydantic-v1"]
196201
steps:
197202
- uses: actions/[email protected]
198203
- name: Set up Python ${{ matrix.python-version }}
@@ -203,7 +208,8 @@ jobs:
203208
run: |
204209
python -m pip install --upgrade pip
205210
python -m pip install poetry
206-
poetry install --all-extras
211+
poetry install --all-extras
212+
./scripts/install_pydantic_v2.sh ${{ matrix.pydantic-version }}
207213
poetry run pip install protobuf==3.20.0 # we check that we support 3.19
208214
poetry run pip uninstall -y torch
209215
poetry run pip install torch
@@ -239,6 +245,7 @@ jobs:
239245
matrix:
240246
python-version: [3.8]
241247
db_test_folder: [base_classes, elastic, hnswlib, qdrant, weaviate, redis, milvus]
248+
pydantic-version: ["pydantic-v2", "pydantic-v1"]
242249
steps:
243250
- uses: actions/[email protected]
244251
- name: Set up Python ${{ matrix.python-version }}
@@ -250,6 +257,7 @@ jobs:
250257
python -m pip install --upgrade pip
251258
python -m pip install poetry
252259
poetry install --all-extras
260+
./scripts/install_pydantic_v2.sh ${{ matrix.pydantic-version }}
253261
poetry run pip install protobuf==3.20.0
254262
poetry run pip install tensorflow==2.12.0
255263
poetry run pip uninstall -y torch
@@ -286,6 +294,7 @@ jobs:
286294
fail-fast: false
287295
matrix:
288296
python-version: [3.8]
297+
pydantic-version: ["pydantic-v2", "pydantic-v1"]
289298
steps:
290299
- uses: actions/[email protected]
291300
- name: Set up Python ${{ matrix.python-version }}
@@ -297,6 +306,7 @@ jobs:
297306
python -m pip install --upgrade pip
298307
python -m pip install poetry
299308
poetry install --all-extras
309+
./scripts/install_pydantic_v2.sh ${{ matrix.pydantic-version }}
300310
poetry run pip install protobuf==3.20.0
301311
poetry run pip install tensorflow==2.12.0
302312
poetry run pip install elasticsearch==8.6.2
@@ -333,6 +343,7 @@ jobs:
333343
fail-fast: false
334344
matrix:
335345
python-version: [3.8]
346+
pydantic-version: ["pydantic-v2", "pydantic-v1"]
336347
steps:
337348
- uses: actions/[email protected]
338349
- name: Set up Python ${{ matrix.python-version }}
@@ -344,6 +355,7 @@ jobs:
344355
python -m pip install --upgrade pip
345356
python -m pip install poetry
346357
poetry install --all-extras
358+
./scripts/install_pydantic_v2.sh ${{ matrix.pydantic-version }}
347359
poetry run pip install protobuf==3.20.0
348360
poetry run pip install tensorflow==2.12.0
349361
poetry run pip uninstall -y torch
@@ -379,6 +391,7 @@ jobs:
379391
fail-fast: false
380392
matrix:
381393
python-version: [3.8]
394+
pydantic-version: ["pydantic-v2", "pydantic-v1"]
382395
steps:
383396
- uses: actions/[email protected]
384397
- name: Set up Python ${{ matrix.python-version }}
@@ -390,6 +403,7 @@ jobs:
390403
python -m pip install --upgrade pip
391404
python -m pip install poetry
392405
poetry install --all-extras
406+
./scripts/install_pydantic_v2.sh ${{ matrix.pydantic-version }}
393407
poetry run pip uninstall -y torch
394408
poetry run pip install torch
395409
poetry run pip install jaxlib
@@ -462,7 +476,7 @@ jobs:
462476

463477
# just for blocking the merge until all parallel tests are successful
464478
success-all-test:
465-
needs: [docarray-test, docarray-test-proto3, docarray-doc-index, docarray-elastic-v8, docarray-test-tensorflow, docarray-test-benchmarks, import-test, check-black, check-mypy, lint-ruff]
479+
needs: [docarray-test, docarray-test-proto3, docarray-doc-index, docarray-elastic-v8, docarray-test-tensorflow, docarray-test-benchmarks, import-test, check-black, lint-ruff]
466480
if: always()
467481
runs-on: ubuntu-latest
468482
steps:

.gitignore

+3-1
Original file line numberDiff line numberDiff line change
@@ -151,4 +151,6 @@ output/
151151
.pytest-kind
152152
.kube
153153

154-
*.ipynb
154+
*.ipynb
155+
156+
.python-version

docarray/array/any_array.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ def __class_getitem__(cls, item: Union[Type[BaseDoc], TypeVar, str]):
6868
class _DocArrayTyped(cls): # type: ignore
6969
doc_type: Type[BaseDoc] = cast(Type[BaseDoc], item)
7070

71-
for field in _DocArrayTyped.doc_type.__fields__.keys():
71+
for field in _DocArrayTyped.doc_type._docarray_fields().keys():
7272

7373
def _property_generator(val: str):
7474
def _getter(self):

docarray/array/doc_list/doc_list.py

+23-7
Original file line numberDiff line numberDiff line change
@@ -24,11 +24,15 @@
2424
from docarray.array.list_advance_indexing import IndexIterType, ListAdvancedIndexing
2525
from docarray.base_doc import AnyDoc, BaseDoc
2626
from docarray.typing import NdArray
27+
from docarray.utils._internal.pydantic import is_pydantic_v2
28+
29+
if is_pydantic_v2:
30+
from pydantic import GetCoreSchemaHandler
31+
from pydantic_core import core_schema
32+
2733
from docarray.utils._internal._typing import safe_issubclass
2834

2935
if TYPE_CHECKING:
30-
from pydantic import BaseConfig
31-
from pydantic.fields import ModelField
3236

3337
from docarray.array.doc_vec.doc_vec import DocVec
3438
from docarray.proto import DocListProto
@@ -215,11 +219,15 @@ def __class_getitem__(cls, item: Union[Type[BaseDoc], TypeVar, str]):
215219
:return: Returns a list of the field value for each document
216220
in the doc_list like container
217221
"""
218-
field_type = self.__class__.doc_type._get_field_type(field)
222+
field_type = self.__class__.doc_type._get_field_annotation(field)
223+
field_info = self.__class__.doc_type._docarray_fields()[field]
224+
is_field_required = (
225+
field_info.is_required() if is_pydantic_v2 else field_info.required
226+
)
219227

220228
if (
221229
not is_union_type(field_type)
222-
and self.__class__.doc_type.__fields__[field].required
230+
and is_field_required
223231
and isinstance(field_type, type)
224232
and safe_issubclass(field_type, BaseDoc)
225233
):
@@ -263,11 +271,9 @@ def to_doc_vec(
263271
return DocVec.__class_getitem__(self.doc_type)(self, tensor_type=tensor_type)
264272

265273
@classmethod
266-
def validate(
274+
def _docarray_validate(
267275
cls: Type[T],
268276
value: Union[T, Iterable[BaseDoc]],
269-
field: 'ModelField',
270-
config: 'BaseConfig',
271277
):
272278
from docarray.array.doc_vec.doc_vec import DocVec
273279

@@ -336,3 +342,13 @@ def __class_getitem__(cls, item: Union[Type[BaseDoc], TypeVar, str]):
336342

337343
def __repr__(self):
338344
return AnyDocArray.__repr__(self) # type: ignore
345+
346+
if is_pydantic_v2:
347+
348+
@classmethod
349+
def __get_pydantic_core_schema__(
350+
cls, _source_type: Any, _handler: GetCoreSchemaHandler
351+
) -> core_schema.CoreSchema:
352+
return core_schema.general_plain_validator_function(
353+
cls.validate,
354+
)

docarray/array/doc_vec/column_storage.py

+8
Original file line numberDiff line numberDiff line change
@@ -215,3 +215,11 @@ def values(self) -> ValuesView: # type: ignore
215215
# context: https://github.com/python/typing/discussions/1033
216216
def items(self) -> ItemsView: # type: ignore
217217
return ItemsView(self._local_dict())
218+
219+
def to_dict(self) -> Dict[str, Any]:
220+
"""
221+
Return a dictionary with the same keys as the storage.columns
222+
and the values at position self.index.
223+
Warning: modification on the dict will not be reflected on the storage.
224+
"""
225+
return {key: self[key] for key in self.storage.columns.keys()}

docarray/array/doc_vec/doc_vec.py

+28-14
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
from collections import ChainMap
22
from typing import (
3-
TYPE_CHECKING,
43
Any,
54
Dict,
65
Iterable,
@@ -17,7 +16,7 @@
1716
overload,
1817
)
1918

20-
from pydantic import BaseConfig, parse_obj_as
19+
from pydantic import parse_obj_as
2120
from typing_inspect import typingGenericAlias
2221

2322
from docarray.array.any_array import AnyDocArray
@@ -28,17 +27,19 @@
2827
from docarray.base_doc import AnyDoc, BaseDoc
2928
from docarray.typing import NdArray
3029
from docarray.typing.tensor.abstract_tensor import AbstractTensor
30+
from docarray.utils._internal.pydantic import is_pydantic_v2
31+
32+
if is_pydantic_v2:
33+
from pydantic import GetCoreSchemaHandler
34+
from pydantic_core import core_schema
35+
3136
from docarray.utils._internal._typing import is_tensor_union, safe_issubclass
3237
from docarray.utils._internal.misc import (
3338
is_jax_available,
3439
is_tf_available,
3540
is_torch_available,
3641
)
3742

38-
if TYPE_CHECKING:
39-
from pydantic.fields import ModelField
40-
41-
4243
torch_available = is_torch_available()
4344
if torch_available:
4445
from docarray.typing import TorchTensor
@@ -147,12 +148,15 @@ def __init__(
147148
else DocList.__class_getitem__(self.doc_type)(docs)
148149
)
149150

150-
for field_name, field in self.doc_type.__fields__.items():
151+
for field_name, field in self.doc_type._docarray_fields().items():
151152
# here we iterate over the field of the docs schema, and we collect the data
152153
# from each document and put them in the corresponding column
153-
field_type: Type = self.doc_type._get_field_type(field_name)
154+
field_type: Type = self.doc_type._get_field_annotation(field_name)
154155

155-
is_field_required = self.doc_type.__fields__[field_name].required
156+
field_info = self.doc_type._docarray_fields()[field_name]
157+
is_field_required = (
158+
field_info.is_required() if is_pydantic_v2 else field_info.required
159+
)
156160

157161
first_doc_is_none = getattr(docs[0], field_name) is None
158162

@@ -317,11 +321,9 @@ def from_columns_storage(cls: Type[T], storage: ColumnStorage) -> T:
317321
return docs
318322

319323
@classmethod
320-
def validate(
324+
def _docarray_validate(
321325
cls: Type[T],
322326
value: Union[T, Iterable[T_doc]],
323-
field: 'ModelField',
324-
config: 'BaseConfig',
325327
) -> T:
326328
if isinstance(value, cls):
327329
return value
@@ -512,7 +514,7 @@ def _set_data_column(
512514
if col is not None:
513515
validation_class = col.__unparametrizedcls__ or col.__class__
514516
else:
515-
validation_class = self.doc_type.__fields__[field].type_
517+
validation_class = self.doc_type._get_field_annotation(field)
516518

517519
# TODO shape check should be handle by the tensor validation
518520

@@ -521,7 +523,9 @@ def _set_data_column(
521523

522524
elif field in self._storage.doc_columns.keys():
523525
values_ = parse_obj_as(
524-
DocVec.__class_getitem__(self.doc_type._get_field_type(field)),
526+
DocVec.__class_getitem__(
527+
self.doc_type._get_field_annotation(field)
528+
),
525529
values,
526530
)
527531
self._storage.doc_columns[field] = values_
@@ -657,3 +661,13 @@ def traverse_flat(
657661
def __class_getitem__(cls, item: Union[Type[BaseDoc], TypeVar, str]):
658662
# call implementation in AnyDocArray
659663
return super(IOMixinDocVec, cls).__class_getitem__(item)
664+
665+
if is_pydantic_v2:
666+
667+
@classmethod
668+
def __get_pydantic_core_schema__(
669+
cls, _source_type: Any, _handler: GetCoreSchemaHandler
670+
) -> core_schema.CoreSchema:
671+
return core_schema.general_plain_validator_function(
672+
cls.validate,
673+
)

0 commit comments

Comments
 (0)