From 31f2b60f8c7cc3c171f50887a991b75d7050e14f Mon Sep 17 00:00:00 2001 From: gersmann Date: Mon, 5 Feb 2024 22:18:51 +0100 Subject: [PATCH] feat: implements raw mongo query --- django_mongodb/query.py | 21 ++++++++++++++++----- test/test_models.py | 8 ++++++++ testapp/models.py | 3 +++ 3 files changed, 27 insertions(+), 5 deletions(-) diff --git a/django_mongodb/query.py b/django_mongodb/query.py index 70ffaee..1185c37 100644 --- a/django_mongodb/query.py +++ b/django_mongodb/query.py @@ -23,6 +23,10 @@ class RequiresSearchException(Exception): pass +class RequiresSearchIndex(Exception): + pass + + class Node(ABC): def __init__(self, node: Expression, mongo_meta): self.node = node @@ -190,26 +194,33 @@ class MongoSearchVectorExact(MongoSearchLookup): def _get_mongo_search(self, compiler, connection) -> dict: rhs_expressions = self.lhs.get_source_expressions() lhs_expressions = self.rhs.get_source_expressions() - columns = [expression.field.column for expression in rhs_expressions] + columns = set([expression.field.column for expression in rhs_expressions]) query = [expression.value for expression in lhs_expressions] # weight = lhs.weight # config = lhs.config auto_complete_columns = [ key for key, value in self.mongo_meta["search_fields"].items() - if "autocomplete" in value + if "autocomplete" in value and key in columns ] + columns.difference_update(set(auto_complete_columns)) query_columns = [ key for key, value in self.mongo_meta["search_fields"].items() - if "string" in value and key not in auto_complete_columns + if "string" in value and key not in auto_complete_columns and key in columns ] + columns.difference_update(set(query_columns)) + if len(columns) > 0: + raise RequiresSearchIndex( + f"SearchVectorExact requires search fields to be defined for all columns: {columns}" + ) + search_query = dict() if auto_complete_columns and query_columns: search_query = { "compound": { "should": [ - {"wildcard": {"path": columns, "query": query}}, + {"wildcard": {"path": query_columns, "query": query}}, *[ {"autocomplete": {"path": column, "query": query}} for column in auto_complete_columns @@ -220,7 +231,7 @@ def _get_mongo_search(self, compiler, connection) -> dict: elif auto_complete_columns: search_query = {"autocomplete": {"path": auto_complete_columns, "query": query}} elif query_columns: - search_query = {"wildcard": {"path": columns, "query": query}} + search_query = {"wildcard": {"path": query_columns, "query": query}} return search_query diff --git a/test/test_models.py b/test/test_models.py index b397d8a..a085986 100644 --- a/test/test_models.py +++ b/test/test_models.py @@ -8,6 +8,7 @@ from django.contrib.postgres.search import SearchQuery, SearchVector from django.utils.timezone import now +from django_mongodb.query import RequiresSearchIndex from refapp.models import RefModel from testapp.models import ( DifferentTableOneToOne, @@ -225,6 +226,7 @@ def test_prefer_search_qs(): @pytest.mark.skipif(os.environ.get("CI") == "true", reason="CI does not have mongodb search") @pytest.mark.django_db(databases=["mongodb"]) def test_mongo_search_index(search_index): + FooModel.objects.all().delete() FooModel.objects.create(name="test", json_field={"foo": "bar"}) FooModel.objects.create(name="test1", json_field={"foo": "bar"}) # search index needs to sync @@ -243,6 +245,12 @@ def test_mongo_search_index(search_index): prefer_search_qs = FooModel.objects.all().prefer_search().filter(name="test") assert len(list(prefer_search_qs)) == 1 + with pytest.raises(RequiresSearchIndex): + search_qs = FooModel.objects.annotate(search=SearchVector("datetime_field")).filter( + search=SearchQuery("test") + ) + assert len(list(search_qs)) == 0 + @pytest.mark.django_db(databases=["default"]) def test_reference_model(): diff --git a/testapp/models.py b/testapp/models.py index 8c87f3b..bfb8ffc 100644 --- a/testapp/models.py +++ b/testapp/models.py @@ -13,6 +13,9 @@ class FooModel(models.Model): time_field = models.TimeField(auto_now_add=True) date_field = models.DateField(auto_now_add=True) + class MongoMeta: + search_fields = {"name": ["string"]} + class SameTableChild(FooModel): dummy_model_ptr = models.OneToOneField(