Skip to content

Commit 0ff6d54

Browse files
authored
added and function (#910)
1 parent 1e8bec1 commit 0ff6d54

File tree

4 files changed

+61
-1
lines changed

4 files changed

+61
-1
lines changed

Diff for: src/datachain/func/__init__.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -16,13 +16,14 @@
1616
sum,
1717
)
1818
from .array import contains, cosine_distance, euclidean_distance, length, sip_hash_64
19-
from .conditional import case, greatest, ifelse, isnone, least, or_
19+
from .conditional import and_, case, greatest, ifelse, isnone, least, or_
2020
from .numeric import bit_and, bit_hamming_distance, bit_or, bit_xor, int_hash_64
2121
from .random import rand
2222
from .string import byte_hamming_distance
2323
from .window import window
2424

2525
__all__ = [
26+
"and_",
2627
"any_value",
2728
"array",
2829
"avg",

Diff for: src/datachain/func/conditional.py

+30
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from typing import Optional, Union
22

33
from sqlalchemy import ColumnElement
4+
from sqlalchemy import and_ as sql_and
45
from sqlalchemy import case as sql_case
56
from sqlalchemy import or_ as sql_or
67

@@ -238,3 +239,32 @@ def or_(*args: Union[ColumnElement, Func]) -> Func:
238239
func_args.append(arg)
239240

240241
return Func("or", inner=sql_or, cols=cols, args=func_args, result_type=bool)
242+
243+
244+
def and_(*args: Union[ColumnElement, Func]) -> Func:
245+
"""
246+
Returns the function that produces conjunction of expressions joined by AND
247+
logical operator.
248+
249+
Args:
250+
args (ColumnElement | Func): The expressions for AND statement.
251+
252+
Returns:
253+
Func: A Func object that represents the and function.
254+
255+
Example:
256+
```py
257+
dc.mutate(
258+
test=ifelse(and_(isnone("name"), isnone("surname")), "Empty", "Not Empty")
259+
)
260+
```
261+
"""
262+
cols, func_args = [], []
263+
264+
for arg in args:
265+
if isinstance(arg, (str, Func)):
266+
cols.append(arg)
267+
else:
268+
func_args.append(arg)
269+
270+
return Func("and", inner=sql_and, cols=cols, args=func_args, result_type=bool)

Diff for: tests/unit/sql/test_conditional.py

+16
Original file line numberDiff line numberDiff line change
@@ -170,3 +170,19 @@ def test_or(warehouse, val1, val2, expected):
170170
query = select(or_(isnone(val1), isnone(val2)))
171171
result = tuple(warehouse.db.execute(query))
172172
assert result == ((expected,),)
173+
174+
175+
@pytest.mark.parametrize(
176+
"val1,val2,expected",
177+
[
178+
[None, func.literal("a"), False],
179+
[None, None, True],
180+
[func.literal("a"), func.literal("a"), False],
181+
],
182+
)
183+
def test_and(warehouse, val1, val2, expected):
184+
from datachain.func.conditional import and_, isnone
185+
186+
query = select(and_(isnone(val1), isnone(val2)))
187+
result = tuple(warehouse.db.execute(query))
188+
assert result == ((expected,),)

Diff for: tests/unit/test_func.py

+13
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
from datachain import C, DataChain
55
from datachain.func import (
6+
and_,
67
bit_hamming_distance,
78
byte_hamming_distance,
89
case,
@@ -319,6 +320,18 @@ def test_or_func_mutate(dc):
319320
]
320321

321322

323+
@skip_if_not_sqlite
324+
def test_and_func_mutate(dc):
325+
res = dc.mutate(test=ifelse(and_(C("num") > 1, C("num") < 4), "Match", "Not Match"))
326+
assert list(res.order_by("num").collect("test")) == [
327+
"Not Match",
328+
"Match",
329+
"Match",
330+
"Not Match",
331+
"Not Match",
332+
]
333+
334+
322335
def test_xor():
323336
rnd1, rnd2 = rand(), rand()
324337

0 commit comments

Comments
 (0)