41
41
from operator import attrgetter
42
42
from typing import TYPE_CHECKING , Any , Generic , Literal , cast
43
43
44
- from sqlalchemy import BinaryExpression , Delete , Select , Update , and_ , any_ , exists , or_ , select , text
44
+ from sqlalchemy import (
45
+ BinaryExpression ,
46
+ Delete ,
47
+ Select ,
48
+ Update ,
49
+ and_ ,
50
+ any_ ,
51
+ exists ,
52
+ false ,
53
+ not_ ,
54
+ or_ ,
55
+ select ,
56
+ text ,
57
+ )
45
58
from typing_extensions import TypeVar
46
59
47
60
if TYPE_CHECKING :
@@ -615,6 +628,8 @@ class ExistsFilter(StatementFilter):
615
628
)
616
629
"""
617
630
631
+ field_name : str
632
+ """Name of model attribute to search on."""
618
633
values : list [ColumnElement [bool ]]
619
634
"""List of SQLAlchemy column expressions to use in the EXISTS clause."""
620
635
operator : Literal ["and" , "or" ] = "and"
@@ -644,6 +659,35 @@ def _or(self) -> Callable[..., ColumnElement[bool]]:
644
659
"""
645
660
return or_
646
661
662
+ def get_exists_clause (self , model : type [ModelT ]) -> ColumnElement [bool ]:
663
+ """Generate the EXISTS clause for the statement.
664
+
665
+ Args:
666
+ model: The SQLAlchemy model class
667
+
668
+ Returns:
669
+ ColumnElement[bool]: EXISTS clause
670
+ """
671
+ field = self ._get_instrumented_attr (model , self .field_name )
672
+
673
+ # Get the underlying column name of the field
674
+ field_column = getattr (field , "comparator" , None )
675
+ if not field_column :
676
+ return false () # Handle cases where the field might not be directly comparable, ie. relations
677
+ field_column_name = field_column .key
678
+
679
+ # Construct a subquery using select()
680
+ subquery = select (field ).where (
681
+ * (
682
+ [getattr (model , field_column_name ) == getattr (model , field_column_name ), self ._and (* self .values )]
683
+ if self .operator == "and"
684
+ else [getattr (model , field_column_name ) == getattr (model , field_column_name ), self ._or (* self .values )]
685
+ )
686
+ )
687
+
688
+ # Use the subquery in the exists() clause
689
+ return exists (subquery )
690
+
647
691
def append_to_statement (self , statement : StatementTypeT , model : type [ModelT ]) -> StatementTypeT :
648
692
"""Apply EXISTS condition to the statement.
649
693
@@ -659,10 +703,7 @@ def append_to_statement(self, statement: StatementTypeT, model: type[ModelT]) ->
659
703
"""
660
704
if not self .values :
661
705
return statement
662
-
663
- if self .operator == "and" :
664
- exists_clause = select (model ).where (self ._and (* self .values )).exists ()
665
- exists_clause = select (model ).where (self ._or (* self .values )).exists ()
706
+ exists_clause = self .get_exists_clause (model )
666
707
return cast ("StatementTypeT" , statement .where (exists_clause ))
667
708
668
709
@@ -685,6 +726,7 @@ class NotExistsFilter(StatementFilter):
685
726
from advanced_alchemy.filters import NotExistsFilter
686
727
687
728
filter = NotExistsFilter(
729
+ field_name="User.is_active",
688
730
values=[User.email.like("%@example.com%")],
689
731
)
690
732
statement = filter.append_to_statement(
@@ -694,13 +736,16 @@ class NotExistsFilter(StatementFilter):
694
736
Using OR conditions::
695
737
696
738
filter = NotExistsFilter(
739
+ field_name="User.role",
697
740
values=[User.role == "admin", User.role == "owner"],
698
741
operator="or",
699
742
)
700
743
"""
701
744
745
+ field_name : str
746
+ """Name of model attribute to search on."""
702
747
values : list [ColumnElement [bool ]]
703
- """List of SQLAlchemy column expressions to use in the EXISTS clause."""
748
+ """List of SQLAlchemy column expressions to use in the NOT EXISTS clause."""
704
749
operator : Literal ["and" , "or" ] = "and"
705
750
"""If "and", combines conditions with AND, otherwise uses OR."""
706
751
@@ -728,6 +773,37 @@ def _or(self) -> Callable[..., ColumnElement[bool]]:
728
773
"""
729
774
return or_
730
775
776
+ def get_exists_clause (self , model : type [ModelT ]) -> ColumnElement [bool ]:
777
+ """Generate the NOT EXISTS clause for the statement.
778
+
779
+ Args:
780
+ model: The SQLAlchemy model class
781
+
782
+ Returns:
783
+ ColumnElement[bool]: NOT EXISTS clause
784
+ """
785
+ field = self ._get_instrumented_attr (model , self .field_name )
786
+
787
+ # Get the underlying column name of the field
788
+ field_column = getattr (field , "comparator" , None )
789
+ if not field_column :
790
+ return false () # Handle cases where the field might not be directly comparable, ie. relations
791
+ field_column_name = field_column .key
792
+
793
+ # Construct a subquery using select()
794
+ subquery = select (field ).where (
795
+ * (
796
+ [getattr (model , field_column_name ) == getattr (model , field_column_name ), self ._and (* self .values )]
797
+ if self .operator == "and"
798
+ else [
799
+ getattr (model , field_column_name ) == getattr (model , field_column_name ),
800
+ self ._or (* self .values ),
801
+ ]
802
+ )
803
+ )
804
+ # Use the subquery in the exists() clause and negate it with not_()
805
+ return not_ (exists (subquery ))
806
+
731
807
def append_to_statement (self , statement : StatementTypeT , model : type [ModelT ]) -> StatementTypeT :
732
808
"""Apply NOT EXISTS condition to the statement.
733
809
@@ -743,8 +819,5 @@ def append_to_statement(self, statement: StatementTypeT, model: type[ModelT]) ->
743
819
"""
744
820
if not self .values :
745
821
return statement
746
-
747
- if self .operator == "and" :
748
- exists_clause = select (model ).where (self ._and (* self .values )).exists ()
749
- exists_clause = select (model ).where (self ._or (* self .values )).exists ()
750
- return cast ("StatementTypeT" , statement .where (~ exists_clause ))
822
+ exists_clause = self .get_exists_clause (model )
823
+ return cast ("StatementTypeT" , statement .where (exists_clause ))
0 commit comments