Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Appearance settings

Commit 26d776d

Browse filesBrowse files
samsjaJoan Fontanals
and
Joan Fontanals
authored
fix: validate before (#1806)
Signed-off-by: samsja <sami.jaghouar@hotmail.fr> Co-authored-by: Joan Fontanals <joan.martinez@jina.ai>
1 parent 7209b78 commit 26d776d
Copy full SHA for 26d776d

File tree

Expand file treeCollapse file tree

14 files changed

+153
-86
lines changed
Filter options
Expand file treeCollapse file tree

14 files changed

+153
-86
lines changed

‎docarray/documents/audio.py

Copy file name to clipboardExpand all lines: docarray/documents/audio.py
+25-9Lines changed: 25 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
1-
from typing import TYPE_CHECKING, Any, Optional, Type, TypeVar, Union
1+
from typing import TYPE_CHECKING, Any, Dict, Optional, Type, TypeVar, Union
22

33
import numpy as np
4-
54
from pydantic import Field
65

76
from docarray.base_doc import BaseDoc
@@ -10,6 +9,10 @@
109
from docarray.typing.tensor.abstract_tensor import AbstractTensor
1110
from docarray.typing.tensor.audio.audio_tensor import AudioTensor
1211
from docarray.utils._internal.misc import import_library
12+
from docarray.utils._internal.pydantic import is_pydantic_v2
13+
14+
if is_pydantic_v2:
15+
from pydantic import model_validator
1316

1417
if TYPE_CHECKING:
1518
import tensorflow as tf # type: ignore
@@ -121,17 +124,30 @@ class MultiModalDoc(BaseDoc):
121124
)
122125

123126
@classmethod
124-
def validate(
125-
cls: Type[T],
126-
value: Union[str, AbstractTensor, Any],
127-
) -> T:
127+
def _validate(cls, value) -> Dict[str, Any]:
128128
if isinstance(value, str):
129-
value = cls(url=value)
129+
value = dict(url=value)
130130
elif isinstance(value, (AbstractTensor, np.ndarray)) or (
131131
torch is not None
132132
and isinstance(value, torch.Tensor)
133133
or (tf is not None and isinstance(value, tf.Tensor))
134134
):
135-
value = cls(tensor=value)
135+
value = dict(tensor=value)
136+
137+
return value
138+
139+
if is_pydantic_v2:
140+
141+
@model_validator(mode='before')
142+
@classmethod
143+
def validate_model_before(cls, value):
144+
return cls._validate(value)
145+
146+
else:
136147

137-
return super().validate(value)
148+
@classmethod
149+
def validate(
150+
cls: Type[T],
151+
value: Union[str, AbstractTensor, Any],
152+
) -> T:
153+
return super().validate(cls._validate(value))

‎docarray/documents/image.py

Copy file name to clipboardExpand all lines: docarray/documents/image.py
+25-10Lines changed: 25 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,17 @@
1-
from typing import TYPE_CHECKING, Any, Optional, Type, TypeVar, Union
1+
from typing import TYPE_CHECKING, Any, Dict, Optional, Type, TypeVar, Union
22

33
import numpy as np
4-
54
from pydantic import Field
65

76
from docarray.base_doc import BaseDoc
87
from docarray.typing import AnyEmbedding, ImageBytes, ImageUrl
98
from docarray.typing.tensor.abstract_tensor import AbstractTensor
109
from docarray.typing.tensor.image.image_tensor import ImageTensor
1110
from docarray.utils._internal.misc import import_library
11+
from docarray.utils._internal.pydantic import is_pydantic_v2
1212

13+
if is_pydantic_v2:
14+
from pydantic import model_validator
1315

1416
if TYPE_CHECKING:
1517
import tensorflow as tf # type: ignore
@@ -115,19 +117,32 @@ class MultiModalDoc(BaseDoc):
115117
)
116118

117119
@classmethod
118-
def validate(
119-
cls: Type[T],
120-
value: Union[str, AbstractTensor, Any],
121-
) -> T:
120+
def _validate(cls, value) -> Dict[str, Any]:
122121
if isinstance(value, str):
123-
value = cls(url=value)
122+
value = dict(url=value)
124123
elif (
125124
isinstance(value, (AbstractTensor, np.ndarray))
126125
or (torch is not None and isinstance(value, torch.Tensor))
127126
or (tf is not None and isinstance(value, tf.Tensor))
128127
):
129-
value = cls(tensor=value)
128+
value = dict(tensor=value)
130129
elif isinstance(value, bytes):
131-
value = cls(byte=value)
130+
value = dict(byte=value)
131+
132+
return value
133+
134+
if is_pydantic_v2:
135+
136+
@model_validator(mode='before')
137+
@classmethod
138+
def validate_model_before(cls, value):
139+
return cls._validate(value)
140+
141+
else:
132142

133-
return super().validate(value)
143+
@classmethod
144+
def validate(
145+
cls: Type[T],
146+
value: Union[str, AbstractTensor, Any],
147+
) -> T:
148+
return super().validate(cls._validate(value))

‎docarray/documents/mesh/mesh_3d.py

Copy file name to clipboardExpand all lines: docarray/documents/mesh/mesh_3d.py
+22-8Lines changed: 22 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,10 @@
66
from docarray.documents.mesh.vertices_and_faces import VerticesAndFaces
77
from docarray.typing.tensor.embedding import AnyEmbedding
88
from docarray.typing.url.url_3d.mesh_url import Mesh3DUrl
9+
from docarray.utils._internal.pydantic import is_pydantic_v2
910

11+
if is_pydantic_v2:
12+
from pydantic import model_validator
1013

1114
T = TypeVar('T', bound='Mesh3D')
1215

@@ -125,11 +128,22 @@ class MultiModalDoc(BaseDoc):
125128
default=None,
126129
)
127130

128-
@classmethod
129-
def validate(
130-
cls: Type[T],
131-
value: Union[str, Any],
132-
) -> T:
133-
if isinstance(value, str):
134-
value = cls(url=value)
135-
return super().validate(value)
131+
if is_pydantic_v2:
132+
133+
@model_validator(mode='before')
134+
@classmethod
135+
def validate_model_before(cls, value):
136+
if isinstance(value, str):
137+
return {'url': value}
138+
return value
139+
140+
else:
141+
142+
@classmethod
143+
def validate(
144+
cls: Type[T],
145+
value: Union[str, Any],
146+
) -> T:
147+
if isinstance(value, str):
148+
value = cls(url=value)
149+
return super().validate(value)

‎docarray/documents/point_cloud/point_cloud_3d.py

Copy file name to clipboardExpand all lines: docarray/documents/point_cloud/point_cloud_3d.py
+24-8Lines changed: 24 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,17 @@
11
from typing import TYPE_CHECKING, Any, Optional, Type, TypeVar, Union
22

33
import numpy as np
4-
54
from pydantic import Field
65

76
from docarray.base_doc import BaseDoc
87
from docarray.documents.point_cloud.points_and_colors import PointsAndColors
98
from docarray.typing import AnyEmbedding, PointCloud3DUrl
109
from docarray.typing.tensor.abstract_tensor import AbstractTensor
1110
from docarray.utils._internal.misc import import_library
11+
from docarray.utils._internal.pydantic import is_pydantic_v2
12+
13+
if is_pydantic_v2:
14+
from pydantic import model_validator
1215

1316
if TYPE_CHECKING:
1417
import tensorflow as tf # type: ignore
@@ -130,17 +133,30 @@ class MultiModalDoc(BaseDoc):
130133
)
131134

132135
@classmethod
133-
def validate(
134-
cls: Type[T],
135-
value: Union[str, AbstractTensor, Any],
136-
) -> T:
136+
def _validate(self, value: Union[str, AbstractTensor, Any]) -> Any:
137137
if isinstance(value, str):
138-
value = cls(url=value)
138+
value = {'url': value}
139139
elif isinstance(value, (AbstractTensor, np.ndarray)) or (
140140
torch is not None
141141
and isinstance(value, torch.Tensor)
142142
or (tf is not None and isinstance(value, tf.Tensor))
143143
):
144-
value = cls(tensors=PointsAndColors(points=value))
144+
value = {'tensors': PointsAndColors(points=value)}
145+
146+
return value
147+
148+
if is_pydantic_v2:
149+
150+
@model_validator(mode='before')
151+
@classmethod
152+
def validate_model_before(cls, value):
153+
return cls._validate(value)
154+
155+
else:
145156

146-
return super().validate(value)
157+
@classmethod
158+
def validate(
159+
cls: Type[T],
160+
value: Union[str, AbstractTensor, Any],
161+
) -> T:
162+
return super().validate(cls._validate(value))

‎docarray/documents/text.py

Copy file name to clipboardExpand all lines: docarray/documents/text.py
+24-8Lines changed: 24 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,10 @@
55
from docarray.base_doc import BaseDoc
66
from docarray.typing import TextUrl
77
from docarray.typing.tensor.embedding import AnyEmbedding
8+
from docarray.utils._internal.pydantic import is_pydantic_v2
9+
10+
if is_pydantic_v2:
11+
from pydantic import model_validator
812

913
T = TypeVar('T', bound='TextDoc')
1014

@@ -129,14 +133,26 @@ def __init__(self, text: Optional[str] = None, **kwargs):
129133
kwargs['text'] = text
130134
super().__init__(**kwargs)
131135

132-
@classmethod
133-
def validate(
134-
cls: Type[T],
135-
value: Union[str, Any],
136-
) -> T:
137-
if isinstance(value, str):
138-
value = cls(text=value)
139-
return super().validate(value)
136+
if is_pydantic_v2:
137+
138+
@model_validator(mode='before')
139+
@classmethod
140+
def validate_model_before(cls, values):
141+
if isinstance(values, str):
142+
return {'text': values}
143+
else:
144+
return values
145+
146+
else:
147+
148+
@classmethod
149+
def validate(
150+
cls: Type[T],
151+
value: Union[str, Any],
152+
) -> T:
153+
if isinstance(value, str):
154+
value = cls(text=value)
155+
return super().validate(value)
140156

141157
def __eq__(self, other: Any) -> bool:
142158
if isinstance(other, str):

‎docarray/documents/video.py

Copy file name to clipboardExpand all lines: docarray/documents/video.py
+25-9Lines changed: 25 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
1-
from typing import TYPE_CHECKING, Any, Optional, Type, TypeVar, Union
1+
from typing import TYPE_CHECKING, Any, Dict, Optional, Type, TypeVar, Union
22

33
import numpy as np
4-
54
from pydantic import Field
65

76
from docarray.base_doc import BaseDoc
@@ -11,6 +10,10 @@
1110
from docarray.typing.tensor.video.video_tensor import VideoTensor
1211
from docarray.typing.url.video_url import VideoUrl
1312
from docarray.utils._internal.misc import import_library
13+
from docarray.utils._internal.pydantic import is_pydantic_v2
14+
15+
if is_pydantic_v2:
16+
from pydantic import model_validator
1417

1518
if TYPE_CHECKING:
1619
import tensorflow as tf # type: ignore
@@ -131,17 +134,30 @@ class MultiModalDoc(BaseDoc):
131134
)
132135

133136
@classmethod
134-
def validate(
135-
cls: Type[T],
136-
value: Union[str, AbstractTensor, Any],
137-
) -> T:
137+
def _validate(cls, value) -> Dict[str, Any]:
138138
if isinstance(value, str):
139-
value = cls(url=value)
139+
value = dict(url=value)
140140
elif isinstance(value, (AbstractTensor, np.ndarray)) or (
141141
torch is not None
142142
and isinstance(value, torch.Tensor)
143143
or (tf is not None and isinstance(value, tf.Tensor))
144144
):
145-
value = cls(tensor=value)
145+
value = dict(tensor=value)
146+
147+
return value
148+
149+
if is_pydantic_v2:
150+
151+
@model_validator(mode='before')
152+
@classmethod
153+
def validate_model_before(cls, value):
154+
return cls._validate(value)
155+
156+
else:
146157

147-
return super().validate(value)
158+
@classmethod
159+
def validate(
160+
cls: Type[T],
161+
value: Union[str, AbstractTensor, Any],
162+
) -> T:
163+
return super().validate(cls._validate(value))

‎docarray/typing/tensor/ndarray.py

Copy file name to clipboardExpand all lines: docarray/typing/tensor/ndarray.py
-1Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -142,7 +142,6 @@ def _docarray_validate(
142142
return cls._docarray_from_native(arr)
143143
except Exception:
144144
pass # handled below
145-
breakpoint()
146145
raise ValueError(f'Expected a numpy.ndarray compatible type, got {type(value)}')
147146

148147
@classmethod

‎tests/integrations/predefined_document/test_audio.py

Copy file name to clipboardExpand all lines: tests/integrations/predefined_document/test_audio.py
-6Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@
1111
from docarray.typing import AudioUrl
1212
from docarray.typing.tensor.audio import AudioNdArray, AudioTorchTensor
1313
from docarray.utils._internal.misc import is_tf_available
14-
from docarray.utils._internal.pydantic import is_pydantic_v2
1514
from tests import TOYDATA_DIR
1615

1716
tf_available = is_tf_available()
@@ -184,32 +183,27 @@ class MyAudio(AudioDoc):
184183

185184

186185
# Validating predefined docs against url or tensor is not yet working with pydantic v28
187-
@pytest.mark.skipif(is_pydantic_v2, reason="Not working with pydantic v2 for now")
188186
def test_audio_np():
189187
audio = parse_obj_as(AudioDoc, np.zeros((10, 10, 3)))
190188
assert (audio.tensor == np.zeros((10, 10, 3))).all()
191189

192190

193-
@pytest.mark.skipif(is_pydantic_v2, reason="Not working with pydantic v2 for now")
194191
def test_audio_torch():
195192
audio = parse_obj_as(AudioDoc, torch.zeros(10, 10, 3))
196193
assert (audio.tensor == torch.zeros(10, 10, 3)).all()
197194

198195

199-
@pytest.mark.skipif(is_pydantic_v2, reason="Not working with pydantic v2 for now")
200196
@pytest.mark.tensorflow
201197
def test_audio_tensorflow():
202198
audio = parse_obj_as(AudioDoc, tf.zeros((10, 10, 3)))
203199
assert tnp.allclose(audio.tensor.tensor, tf.zeros((10, 10, 3)))
204200

205201

206-
@pytest.mark.skipif(is_pydantic_v2, reason="Not working with pydantic v2 for now")
207202
def test_audio_bytes():
208203
audio = parse_obj_as(AudioDoc, torch.zeros(10, 10, 3))
209204
audio.bytes_ = audio.tensor.to_bytes()
210205

211206

212-
@pytest.mark.skipif(is_pydantic_v2, reason="Not working with pydantic v2 for now")
213207
def test_audio_shortcut_doc():
214208
class MyDoc(BaseDoc):
215209
audio: AudioDoc

0 commit comments

Comments
0 (0)
Morty Proxy This is a proxified and sanitized view of the page, visit original site.