Skip to content

Commit

Permalink
feat: implements raw mongo query
Browse files Browse the repository at this point in the history
  • Loading branch information
gersmann committed Feb 5, 2024
1 parent 93f9ac2 commit 31f2b60
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 5 deletions.
21 changes: 16 additions & 5 deletions django_mongodb/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,10 @@ class RequiresSearchException(Exception):
pass


class RequiresSearchIndex(Exception):
pass


class Node(ABC):
def __init__(self, node: Expression, mongo_meta):
self.node = node
Expand Down Expand Up @@ -190,26 +194,33 @@ class MongoSearchVectorExact(MongoSearchLookup):
def _get_mongo_search(self, compiler, connection) -> dict:
rhs_expressions = self.lhs.get_source_expressions()
lhs_expressions = self.rhs.get_source_expressions()
columns = [expression.field.column for expression in rhs_expressions]
columns = set([expression.field.column for expression in rhs_expressions])
query = [expression.value for expression in lhs_expressions]
# weight = lhs.weight
# config = lhs.config
auto_complete_columns = [
key
for key, value in self.mongo_meta["search_fields"].items()
if "autocomplete" in value
if "autocomplete" in value and key in columns
]
columns.difference_update(set(auto_complete_columns))
query_columns = [
key
for key, value in self.mongo_meta["search_fields"].items()
if "string" in value and key not in auto_complete_columns
if "string" in value and key not in auto_complete_columns and key in columns
]
columns.difference_update(set(query_columns))
if len(columns) > 0:
raise RequiresSearchIndex(
f"SearchVectorExact requires search fields to be defined for all columns: {columns}"
)

search_query = dict()
if auto_complete_columns and query_columns:
search_query = {
"compound": {
"should": [
{"wildcard": {"path": columns, "query": query}},
{"wildcard": {"path": query_columns, "query": query}},
*[
{"autocomplete": {"path": column, "query": query}}
for column in auto_complete_columns
Expand All @@ -220,7 +231,7 @@ def _get_mongo_search(self, compiler, connection) -> dict:
elif auto_complete_columns:
search_query = {"autocomplete": {"path": auto_complete_columns, "query": query}}
elif query_columns:
search_query = {"wildcard": {"path": columns, "query": query}}
search_query = {"wildcard": {"path": query_columns, "query": query}}
return search_query


Expand Down
8 changes: 8 additions & 0 deletions test/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from django.contrib.postgres.search import SearchQuery, SearchVector
from django.utils.timezone import now

from django_mongodb.query import RequiresSearchIndex
from refapp.models import RefModel
from testapp.models import (
DifferentTableOneToOne,
Expand Down Expand Up @@ -225,6 +226,7 @@ def test_prefer_search_qs():
@pytest.mark.skipif(os.environ.get("CI") == "true", reason="CI does not have mongodb search")
@pytest.mark.django_db(databases=["mongodb"])
def test_mongo_search_index(search_index):
FooModel.objects.all().delete()
FooModel.objects.create(name="test", json_field={"foo": "bar"})
FooModel.objects.create(name="test1", json_field={"foo": "bar"})
# search index needs to sync
Expand All @@ -243,6 +245,12 @@ def test_mongo_search_index(search_index):
prefer_search_qs = FooModel.objects.all().prefer_search().filter(name="test")
assert len(list(prefer_search_qs)) == 1

with pytest.raises(RequiresSearchIndex):
search_qs = FooModel.objects.annotate(search=SearchVector("datetime_field")).filter(
search=SearchQuery("test")
)
assert len(list(search_qs)) == 0


@pytest.mark.django_db(databases=["default"])
def test_reference_model():
Expand Down
3 changes: 3 additions & 0 deletions testapp/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,9 @@ class FooModel(models.Model):
time_field = models.TimeField(auto_now_add=True)
date_field = models.DateField(auto_now_add=True)

class MongoMeta:
search_fields = {"name": ["string"]}


class SameTableChild(FooModel):
dummy_model_ptr = models.OneToOneField(
Expand Down

0 comments on commit 31f2b60

Please sign in to comment.