Skip to content

Commit 6942404

Browse files
committed
feat: wip
1 parent 758c327 commit 6942404

File tree

2 files changed

+92
-16
lines changed

2 files changed

+92
-16
lines changed

advanced_alchemy/filters.py

+84-11
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,20 @@
4141
from operator import attrgetter
4242
from typing import TYPE_CHECKING, Any, Generic, Literal, cast
4343

44-
from sqlalchemy import BinaryExpression, Delete, Select, Update, and_, any_, exists, or_, select, text
44+
from sqlalchemy import (
45+
BinaryExpression,
46+
Delete,
47+
Select,
48+
Update,
49+
and_,
50+
any_,
51+
exists,
52+
false,
53+
not_,
54+
or_,
55+
select,
56+
text,
57+
)
4558
from typing_extensions import TypeVar
4659

4760
if TYPE_CHECKING:
@@ -615,6 +628,8 @@ class ExistsFilter(StatementFilter):
615628
)
616629
"""
617630

631+
field_name: str
632+
"""Name of model attribute to search on."""
618633
values: list[ColumnElement[bool]]
619634
"""List of SQLAlchemy column expressions to use in the EXISTS clause."""
620635
operator: Literal["and", "or"] = "and"
@@ -644,6 +659,35 @@ def _or(self) -> Callable[..., ColumnElement[bool]]:
644659
"""
645660
return or_
646661

662+
def get_exists_clause(self, model: type[ModelT]) -> ColumnElement[bool]:
663+
"""Generate the EXISTS clause for the statement.
664+
665+
Args:
666+
model: The SQLAlchemy model class
667+
668+
Returns:
669+
ColumnElement[bool]: EXISTS clause
670+
"""
671+
field = self._get_instrumented_attr(model, self.field_name)
672+
673+
# Get the underlying column name of the field
674+
field_column = getattr(field, "comparator", None)
675+
if not field_column:
676+
return false() # Handle cases where the field might not be directly comparable, ie. relations
677+
field_column_name = field_column.key
678+
679+
# Construct a subquery using select()
680+
subquery = select(field).where(
681+
*(
682+
[getattr(model, field_column_name) == getattr(model, field_column_name), self._and(*self.values)]
683+
if self.operator == "and"
684+
else [getattr(model, field_column_name) == getattr(model, field_column_name), self._or(*self.values)]
685+
)
686+
)
687+
688+
# Use the subquery in the exists() clause
689+
return exists(subquery)
690+
647691
def append_to_statement(self, statement: StatementTypeT, model: type[ModelT]) -> StatementTypeT:
648692
"""Apply EXISTS condition to the statement.
649693
@@ -659,10 +703,7 @@ def append_to_statement(self, statement: StatementTypeT, model: type[ModelT]) ->
659703
"""
660704
if not self.values:
661705
return statement
662-
663-
if self.operator == "and":
664-
exists_clause = select(model).where(self._and(*self.values)).exists()
665-
exists_clause = select(model).where(self._or(*self.values)).exists()
706+
exists_clause = self.get_exists_clause(model)
666707
return cast("StatementTypeT", statement.where(exists_clause))
667708

668709

@@ -685,6 +726,7 @@ class NotExistsFilter(StatementFilter):
685726
from advanced_alchemy.filters import NotExistsFilter
686727
687728
filter = NotExistsFilter(
729+
field_name="User.is_active",
688730
values=[User.email.like("%@example.com%")],
689731
)
690732
statement = filter.append_to_statement(
@@ -694,13 +736,16 @@ class NotExistsFilter(StatementFilter):
694736
Using OR conditions::
695737
696738
filter = NotExistsFilter(
739+
field_name="User.role",
697740
values=[User.role == "admin", User.role == "owner"],
698741
operator="or",
699742
)
700743
"""
701744

745+
field_name: str
746+
"""Name of model attribute to search on."""
702747
values: list[ColumnElement[bool]]
703-
"""List of SQLAlchemy column expressions to use in the EXISTS clause."""
748+
"""List of SQLAlchemy column expressions to use in the NOT EXISTS clause."""
704749
operator: Literal["and", "or"] = "and"
705750
"""If "and", combines conditions with AND, otherwise uses OR."""
706751

@@ -728,6 +773,37 @@ def _or(self) -> Callable[..., ColumnElement[bool]]:
728773
"""
729774
return or_
730775

776+
def get_exists_clause(self, model: type[ModelT]) -> ColumnElement[bool]:
777+
"""Generate the NOT EXISTS clause for the statement.
778+
779+
Args:
780+
model: The SQLAlchemy model class
781+
782+
Returns:
783+
ColumnElement[bool]: NOT EXISTS clause
784+
"""
785+
field = self._get_instrumented_attr(model, self.field_name)
786+
787+
# Get the underlying column name of the field
788+
field_column = getattr(field, "comparator", None)
789+
if not field_column:
790+
return false() # Handle cases where the field might not be directly comparable, ie. relations
791+
field_column_name = field_column.key
792+
793+
# Construct a subquery using select()
794+
subquery = select(field).where(
795+
*(
796+
[getattr(model, field_column_name) == getattr(model, field_column_name), self._and(*self.values)]
797+
if self.operator == "and"
798+
else [
799+
getattr(model, field_column_name) == getattr(model, field_column_name),
800+
self._or(*self.values),
801+
]
802+
)
803+
)
804+
# Use the subquery in the exists() clause and negate it with not_()
805+
return not_(exists(subquery))
806+
731807
def append_to_statement(self, statement: StatementTypeT, model: type[ModelT]) -> StatementTypeT:
732808
"""Apply NOT EXISTS condition to the statement.
733809
@@ -743,8 +819,5 @@ def append_to_statement(self, statement: StatementTypeT, model: type[ModelT]) ->
743819
"""
744820
if not self.values:
745821
return statement
746-
747-
if self.operator == "and":
748-
exists_clause = select(model).where(self._and(*self.values)).exists()
749-
exists_clause = select(model).where(self._or(*self.values)).exists()
750-
return cast("StatementTypeT", statement.where(~exists_clause))
822+
exists_clause = self.get_exists_clause(model)
823+
return cast("StatementTypeT", statement.where(exists_clause))

tests/integration/test_filters.py

+8-5
Original file line numberDiff line numberDiff line change
@@ -83,32 +83,35 @@ def test_not_in_collection_filter(db_session: Session) -> None:
8383

8484

8585
def test_exists_filter_basic(db_session: Session) -> None:
86-
exists_filter_1 = ExistsFilter(values=[Movie.genre == "Action"])
86+
exists_filter_1 = ExistsFilter(field_name="genre", values=[Movie.genre == "Action"])
8787
statement = exists_filter_1.append_to_statement(select(Movie), Movie)
8888
results = db_session.execute(statement).scalars().all()
8989
assert len(results) == 1
9090

91-
exists_filter_2 = ExistsFilter(values=[Movie.genre.startswith("Action"), Movie.genre.startswith("Drama")])
91+
exists_filter_2 = ExistsFilter(
92+
field_name="genre", values=[Movie.genre.startswith("Action"), Movie.genre.startswith("Drama")]
93+
)
9294
statement = exists_filter_2.append_to_statement(select(Movie), Movie)
9395
results = db_session.execute(statement).scalars().all()
9496
assert len(results) == 2
9597

9698

9799
def test_exists_filter(db_session: Session) -> None:
98-
exists_filter_1 = ExistsFilter(values=[Movie.title.startswith("The")])
100+
exists_filter_1 = ExistsFilter(field_name="title", values=[Movie.title.startswith("The")])
99101
statement = exists_filter_1.append_to_statement(select(Movie), Movie)
100102
results = db_session.execute(statement).scalars().all()
101103
assert len(results) == 3
102104

103105
exists_filter_2 = ExistsFilter(
106+
field_name="title",
104107
values=[Movie.title.startswith("Shawshank Redemption"), Movie.title.startswith("The")],
105-
operator="and",
106108
)
107109
statement = exists_filter_2.append_to_statement(select(Movie), Movie)
108110
results = db_session.execute(statement).scalars().all()
109111
assert len(results) == 0
110112

111113
exists_filter_3 = ExistsFilter(
114+
field_name="title",
112115
values=[Movie.title.startswith("The"), Movie.title.startswith("Shawshank")],
113116
operator="or",
114117
)
@@ -118,7 +121,7 @@ def test_exists_filter(db_session: Session) -> None:
118121

119122

120123
def test_not_exists_filter(db_session: Session) -> None:
121-
not_exists_filter = NotExistsFilter(values=[Movie.title.like("%Hangover%")])
124+
not_exists_filter = NotExistsFilter(field_name="title", values=[Movie.title.like("%Hangover%")])
122125
statement = not_exists_filter.append_to_statement(select(Movie), Movie)
123126
results = db_session.execute(statement).scalars().all()
124127
assert len(results) == 2

0 commit comments

Comments
 (0)