Skip to content

Commit

Permalink
chore: fixes for mongo queries (new django 5 expressions, objectid fi…
Browse files Browse the repository at this point in the history
…eld for PG, date with timezone) (#8)

* chore: support new integer lookups fully + isNull lookup

* fix: objectid field postgresql support

* fix: properly handle datetime fields timezones, even if not supported by mongodb
  • Loading branch information
gersmann authored Jul 8, 2024
1 parent 802b661 commit c85028f
Show file tree
Hide file tree
Showing 13 changed files with 268 additions and 195 deletions.
3 changes: 3 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,9 @@ MyModel.objects.select_related("same_table_child", "extends").all()

# simple aggregations
MyModel.objects.filter(name_in=["foo", "bar"]).count()

# raw mongo filter
MyModel.objects.filter(RawMongoDBQuery({"name": "1"})).delete()
```

### Search
Expand Down
14 changes: 14 additions & 0 deletions django_mongodb/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,8 @@ def as_operation(self, with_limits=True, with_col_aliases=False): # noqa: C901
or mongo_where.requires_search()
) and not self.query.distinct # search not supported / efficient for distinct queries

self._extend_with_stage(pipeline, "prepend")

has_attname_as_key = False
if mongo_where and build_search_pipeline:
search = mongo_where.get_mongo_search(self, self.connection)
Expand All @@ -88,6 +90,8 @@ def as_operation(self, with_limits=True, with_col_aliases=False): # noqa: C901
{"$match": mongo_where.get_mongo_query(self, self.connection, is_search=False)}
)

self._extend_with_stage(pipeline, "pre-sort")

if self.query.distinct:
has_attname_as_key = True
pipeline.extend(self.get_distinct_clause())
Expand All @@ -102,6 +106,8 @@ def as_operation(self, with_limits=True, with_col_aliases=False): # noqa: C901
if with_limit_offset and self.query.high_mark:
pipeline.append({"$limit": self.query.high_mark - self.query.low_mark})

self._extend_with_stage(pipeline, "append")

if (select_cols := self.select + extra_select) and not has_attname_as_key:
select_pipeline = MongoSelect(select_cols, self.mongo_meta).get_mongo()
pipeline.extend(select_pipeline)
Expand All @@ -114,6 +120,14 @@ def as_operation(self, with_limits=True, with_col_aliases=False): # noqa: C901
],
}

def _extend_with_stage(self, pipeline, position):
if not hasattr(self.query, "aggregation_stages"):
return
if self.query.aggregation_stages and any(
stages := [stage for pos, stage in self.query.aggregation_stages if pos == position]
):
pipeline.extend(stages)

@cached_property
def mongo_meta(self):
if hasattr(self.query.model, "MongoMeta"):
Expand Down
2 changes: 1 addition & 1 deletion django_mongodb/cursor.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ def rowcount(self):
return -1
return self.result.retrieved
if isinstance(self.result, UpdateResult):
return self.result.modified_count
return self.result.matched_count # update might be a no-op in case there are no changes
raise NotSupportedError

@property
Expand Down
1 change: 1 addition & 0 deletions django_mongodb/features.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,3 +6,4 @@ class DatabaseFeatures(BaseDatabaseFeatures):
supports_explaining_query_execution = False
supports_json_field = True
has_native_json_field = True
supports_unlimited_charfield = True
29 changes: 21 additions & 8 deletions django_mongodb/managers.py
Original file line number Diff line number Diff line change
@@ -1,44 +1,57 @@
from typing import TypeVar
from typing import Generic, Literal, TypeVar

from django.db import models

T = TypeVar("T")


class MongoQuerySet(models.QuerySet):
class MongoQuerySet(Generic[T], models.QuerySet[T]):
"""QuerySet which uses MongoDB as backend"""

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._prefer_search = False
self._aggregation_stages = []

def prefer_search(self, prefer_search=True):
obj = self._chain()
obj._prefer_search = prefer_search
obj.query.prefer_search = prefer_search
return obj

def add_aggregation_stage(
self,
stage: dict,
position: Literal["prepend", "pre-sort", "append"] = "prepend",
):
obj = self._chain()
obj._aggregation_stages.append((position, stage))
obj.query.aggregation_stages = obj._aggregation_stages
return obj

def _chain(self):
"""
Add the _prefer_search hint to the chained query
"""
obj = super()._chain()
if obj._prefer_search:
obj.query.prefer_search = obj._prefer_search
if obj._aggregation_stages:
obj.query.aggregation_stages = obj._aggregation_stages
return obj

def _clone(self):
obj = super()._clone()
obj._prefer_search = self._prefer_search
obj._aggregation_stages = self._aggregation_stages
return obj


T = TypeVar("T")


class MongoManager(models.Manager[T]):
class MongoManager(Generic[T], models.Manager[T]):
"""Manager which uses MongoDB as backend"""

def get_queryset(self):
def get_queryset(self) -> MongoQuerySet[T]:
return MongoQuerySet(self.model, using=self._db)

def prefer_search(self, require_search=True):
def prefer_search(self, require_search=True) -> MongoQuerySet[T]:
return self.get_queryset().prefer_search(require_search)
14 changes: 10 additions & 4 deletions django_mongodb/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,26 +16,32 @@ def validators(self):
return self._validators

def db_type(self, connection):
return "ObjectId"
return "ObjectId" if "ObjectIdField" in connection.data_types else "CHAR(24)"

def to_python(self, value):
if value is None or isinstance(value, bson.ObjectId):
return value
else:
return bson.ObjectId(value)

def from_db_value(self, value, expression, connection):
return self.to_python(value)

def get_prep_value(self, value):
if value is None:
return None
if isinstance(value, str):
return bson.ObjectId(value)
return bson.ObjectId(value)

def get_db_prep_value(self, value, connection, prepared=False):
return (
self.get_prep_value(value) if "ObjectIdField" in connection.data_types else str(value)
)


class ObjectIdField(ObjectIdFieldMixin, models.CharField):
def __init__(self, *args, **kwargs):
kwargs["max_length"] = 24
super().__init__(*args, **kwargs)
pass


class ObjectIdAutoField(ObjectIdFieldMixin, AutoField, metaclass=AutoFieldMeta):
Expand Down
9 changes: 9 additions & 0 deletions django_mongodb/operations.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import datetime

from bson import ObjectId
from django.conf import settings
from django.db.backends.base.operations import BaseDatabaseOperations
from django.utils.timezone import is_aware, make_aware


class DatabaseOperations(BaseDatabaseOperations):
Expand Down Expand Up @@ -52,6 +54,11 @@ def convert_date_value(self, value, expression, connection):
return value.date()
return value

def convert_datetime_value(self, value, expression, connection):
if settings.USE_TZ and isinstance(value, datetime.datetime) and not is_aware(value):
return make_aware(value)
return value

def get_db_converters(self, expression):
converters = super().get_db_converters(expression)
internal_type = expression.output_field.get_internal_type()
Expand All @@ -60,4 +67,6 @@ def get_db_converters(self, expression):
converters.append(self.convert_time_value)
case "DateField":
converters.append(self.convert_date_value)
case "DateTimeField":
converters.append(self.convert_datetime_value)
return converters
37 changes: 36 additions & 1 deletion django_mongodb/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,11 @@
GreaterThanOrEqual,
In,
IntegerFieldExact,
IntegerGreaterThan,
IntegerGreaterThanOrEqual,
IntegerLessThan,
IntegerLessThanOrEqual,
IsNull,
LessThan,
LessThanOrEqual,
Lookup,
Expand Down Expand Up @@ -147,14 +152,24 @@ class MongoEqualityComparison(MongoLookup):

def __init__(
self,
operator: LessThan | LessThanOrEqual | GreaterThan | GreaterThanOrEqual,
operator: LessThan
| LessThanOrEqual
| GreaterThan
| GreaterThanOrEqual
| IntegerLessThan
| IntegerLessThanOrEqual
| IntegerGreaterThan,
mongo_meta,
):
super().__init__(operator, mongo_meta)
self.filter_operator = {
IntegerLessThan: "$lt",
LessThan: "$lt",
IntegerLessThanOrEqual: "$lt",
LessThanOrEqual: "$lte",
IntegerGreaterThan: "$gt",
GreaterThan: "$gt",
IntegerGreaterThanOrEqual: "$gte",
GreaterThanOrEqual: "$gte",
}[type(operator)]

Expand All @@ -167,6 +182,21 @@ def _get_mongo_search(self, compiler, connection) -> dict:
}


class MongoIsNull(MongoLookup):
def _get_mongo_query(self, compiler, connection, is_search=False) -> dict:
return {self.lhs.target.column: None if self.rhs else {"$ne": None}}

def _get_mongo_search(self, compiler, connection) -> dict:
if self.lhs.target.attname not in self.mongo_meta["search_fields"]:
return {}
return {
"exists": {
"path": self.lhs.target.column,
"value": self.rhs,
}
}


class SearchNode(Node):
"""MongoDB Search Query Base Node"""

Expand Down Expand Up @@ -264,11 +294,16 @@ class MongoWhereNode:
RelatedExact: MongoExact,
In: MongoIn,
LessThan: MongoEqualityComparison,
IntegerLessThan: MongoEqualityComparison,
LessThanOrEqual: MongoEqualityComparison,
IntegerLessThanOrEqual: MongoEqualityComparison,
GreaterThan: MongoEqualityComparison,
IntegerGreaterThan: MongoEqualityComparison,
GreaterThanOrEqual: MongoEqualityComparison,
IntegerGreaterThanOrEqual: MongoEqualityComparison,
SearchVectorExact: MongoSearchVectorExact,
RawMongoDBQuery: RawMongoQueryExpression,
IsNull: MongoIsNull,
}

def __init__(self, where: WhereNode, mongo_meta):
Expand Down
Loading

0 comments on commit c85028f

Please sign in to comment.