Skip to content

Commit

Permalink
added and function
Browse files Browse the repository at this point in the history
  • Loading branch information
ilongin committed Feb 10, 2025
1 parent 1e8bec1 commit 306158f
Show file tree
Hide file tree
Showing 4 changed files with 61 additions and 1 deletion.
3 changes: 2 additions & 1 deletion src/datachain/func/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,14 @@
sum,
)
from .array import contains, cosine_distance, euclidean_distance, length, sip_hash_64
from .conditional import case, greatest, ifelse, isnone, least, or_
from .conditional import and_, case, greatest, ifelse, isnone, least, or_
from .numeric import bit_and, bit_hamming_distance, bit_or, bit_xor, int_hash_64
from .random import rand
from .string import byte_hamming_distance
from .window import window

__all__ = [
"and_",
"any_value",
"array",
"avg",
Expand Down
30 changes: 30 additions & 0 deletions src/datachain/func/conditional.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from typing import Optional, Union

from sqlalchemy import ColumnElement
from sqlalchemy import and_ as sql_and
from sqlalchemy import case as sql_case
from sqlalchemy import or_ as sql_or

Expand Down Expand Up @@ -238,3 +239,32 @@ def or_(*args: Union[ColumnElement, Func]) -> Func:
func_args.append(arg)

return Func("or", inner=sql_or, cols=cols, args=func_args, result_type=bool)


def and_(*args: Union[ColumnElement, Func]) -> Func:
"""
Returns the function that produces conjunction of expressions joined by AND
logical operator.
Args:
args (ColumnElement | Func): The expressions for AND statement.
Returns:
Func: A Func object that represents the and function.
Example:
```py
dc.mutate(
test=ifelse(and_(isnone("name"), isnone("surname")), "Empty", "Not Empty")
)
```
"""
cols, func_args = [], []

for arg in args:
if isinstance(arg, (str, Func)):
cols.append(arg)
else:
func_args.append(arg)

return Func("and", inner=sql_and, cols=cols, args=func_args, result_type=bool)
16 changes: 16 additions & 0 deletions tests/unit/sql/test_conditional.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,3 +170,19 @@ def test_or(warehouse, val1, val2, expected):
query = select(or_(isnone(val1), isnone(val2)))
result = tuple(warehouse.db.execute(query))
assert result == ((expected,),)


@pytest.mark.parametrize(
"val1,val2,expected",
[
[None, func.literal("a"), False],
[None, None, True],
[func.literal("a"), func.literal("a"), False],
],
)
def test_and(warehouse, val1, val2, expected):
from datachain.func.conditional import and_, isnone

query = select(and_(isnone(val1), isnone(val2)))
result = tuple(warehouse.db.execute(query))
assert result == ((expected,),)
13 changes: 13 additions & 0 deletions tests/unit/test_func.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

from datachain import C, DataChain
from datachain.func import (
and_,
bit_hamming_distance,
byte_hamming_distance,
case,
Expand Down Expand Up @@ -319,6 +320,18 @@ def test_or_func_mutate(dc):
]


@skip_if_not_sqlite
def test_and_func_mutate(dc):
res = dc.mutate(test=ifelse(and_(C("num") > 1, C("num") < 4), "Match", "Not Match"))
assert list(res.order_by("num").collect("test")) == [
"Not Match",
"Match",
"Match",
"Not Match",
"Not Match",
]


def test_xor():
rnd1, rnd2 = rand(), rand()

Expand Down

0 comments on commit 306158f

Please sign in to comment.