Skip to content

Commit

Permalink
Merge pull request #903 from circulon/fix/has_many_through_not_working
Browse files Browse the repository at this point in the history
Fix has many through relationship not working
  • Loading branch information
josephmancuso authored Nov 21, 2024
2 parents c9055e9 + a844ce6 commit 5da7c38
Show file tree
Hide file tree
Showing 5 changed files with 297 additions and 90 deletions.
2 changes: 1 addition & 1 deletion src/masoniteorm/query/QueryBuilder.py
Original file line number Diff line number Diff line change
Expand Up @@ -1970,7 +1970,7 @@ def _register_relationships_to_model(
def _map_related(self, related_result, related):
if related.__class__.__name__ == 'MorphTo':
return related_result
elif related.__class__.__name__ == 'HasOneThrough':
elif related.__class__.__name__ in ['HasOneThrough', 'HasManyThrough']:
return related_result.group_by(related.local_key)

return related_result.group_by(related.foreign_key)
Expand Down
226 changes: 144 additions & 82 deletions src/masoniteorm/relationships/HasManyThrough.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from .BaseRelationship import BaseRelationship
from ..collection import Collection
from .BaseRelationship import BaseRelationship


class HasManyThrough(BaseRelationship):
Expand Down Expand Up @@ -57,33 +57,46 @@ def __get__(self, instance, owner):
if attribute in instance._relationships:
return instance._relationships[attribute]

result = self.apply_query(
result = self.apply_related_query(
self.distant_builder, self.intermediary_builder, instance
)
return result
else:
return self

def apply_query(self, distant_builder, intermediary_builder, owner):
"""Apply the query and return a dictionary to be hydrated.
Used during accessing a relationship on a model
def apply_related_query(self, distant_builder, intermediary_builder, owner):
"""
Apply the query to return a Collection of data for the distant models to be hydrated with.
Arguments:
query {oject} -- The relationship object
owner {object} -- The current model oject.
Method is used when accessing a relationship on a model if its not
already eager loaded
Returns:
dict -- A dictionary of data which will be hydrated.
Arguments
distant_builder (QueryBuilder): QueryBuilder attached to the distant table
intermediate_builder (QueryBuilder): QueryBuilder attached to the intermediate (linking) table
owner (Any): the model this relationship is starting from
Returns
Collection: Collection of dicts which will be used for hydrating models.
"""
# select * from `countries` inner join `ports` on `ports`.`country_id` = `countries`.`country_id` where `ports`.`port_id` is null and `countries`.`deleted_at` is null and `ports`.`deleted_at` is null
result = distant_builder.join(
f"{self.intermediary_builder.get_table_name()}",
f"{self.intermediary_builder.get_table_name()}.{self.foreign_key}",
"=",
f"{distant_builder.get_table_name()}.{self.other_owner_key}",
).where(f"{self.intermediary_builder.get_table_name()}.{self.local_owner_key}", getattr(owner, self.other_owner_key)).get()

return result
distant_table = distant_builder.get_table_name()
intermediate_table = intermediary_builder.get_table_name()

return (
self.distant_builder.select(f"{distant_table}.*, {intermediate_table}.{self.local_key}")
.join(
f"{intermediate_table}",
f"{intermediate_table}.{self.foreign_key}",
"=",
f"{distant_table}.{self.other_owner_key}",
)
.where(
f"{intermediate_table}.{self.local_key}",
getattr(owner, self.local_owner_key),
)
.get()
)

def relate(self, related_model):
return self.distant_builder.join(
Expand All @@ -104,51 +117,144 @@ def make_builder(self, eagers=None):

return builder

def get_related(self, query, relation, eagers=None, callback=None):
builder = self.distant_builder
def register_related(self, key, model, collection):
"""
Attach the related model to source models attribute
Arguments
key (str): The attribute name
model (Any): The model instance
collection (Collection): The data for the related models
Returns
None
"""
related = collection.get(getattr(model, self.local_owner_key), None)
if related and not isinstance(related, Collection):
related = Collection(related)

model.add_relation({key: related if related else None})

def get_related(self, current_builder, relation, eagers=None, callback=None):
"""
Get a Collection to hydrate the models for the distant table with
Used when eager loading the model attribute
Arguments
current_builder (QueryBuilder): The source models QueryBuilder object
relation (HasManyThrough): this relationship object
eagers (Any):
callback (Any):
Returns
Collection the collection of dicts to hydrate the distant models with
"""

distant_table = self.distant_builder.get_table_name()
intermediate_table = self.intermediary_builder.get_table_name()

if callback:
callback(builder)
callback(current_builder)

(
self.distant_builder.select(
f"{distant_table}.*, {intermediate_table}.{self.local_key}"
).join(
f"{intermediate_table}",
f"{intermediate_table}.{self.foreign_key}",
"=",
f"{distant_table}.{self.other_owner_key}",
)
)

if isinstance(relation, Collection):
return builder.where_in(
f"{builder.get_table_name()}.{self.foreign_key}",
Collection(relation._get_value(self.local_key)).unique(),
return self.distant_builder.where_in(
f"{intermediate_table}.{self.local_key}",
Collection(relation._get_value(self.local_owner_key)).unique(),
).get()
else:
return builder.where(
f"{builder.get_table_name()}.{self.foreign_key}",
return self.distant_builder.where(
f"{intermediate_table}.{self.local_key}",
getattr(relation, self.local_owner_key),
).get()

def get_with_count_query(self, builder, callback):
query = self.distant_builder
def attach(self, current_model, related_record):
raise NotImplementedError(
"HasOneThrough relationship does not implement the attach method"
)

def attach_related(self, current_model, related_record):
raise NotImplementedError(
"HasOneThrough relationship does not implement the attach_related method"
)

if not builder._columns:
builder = builder.select("*")
def query_has(self, current_builder, method="where_exists"):
distant_table = self.distant_builder.get_table_name()
intermediate_table = self.intermediary_builder.get_table_name()

return_query = builder.add_select(
getattr(current_builder, method)(
self.distant_builder.join(
f"{intermediate_table}",
f"{intermediate_table}.{self.foreign_key}",
"=",
f"{distant_table}.{self.other_owner_key}",
).where_column(
f"{intermediate_table}.{self.local_key}",
f"{current_builder.get_table_name()}.{self.local_owner_key}",
)
)

return self.distant_builder

def query_where_exists(self, current_builder, callback, method="where_exists"):
distant_table = self.distant_builder.get_table_name()
intermediate_table = self.intermediary_builder.get_table_name()

getattr(current_builder, method)(
self.distant_builder.join(
f"{intermediate_table}",
f"{intermediate_table}.{self.foreign_key}",
"=",
f"{distant_table}.{self.other_owner_key}",
)
.where_column(
f"{intermediate_table}.{self.local_key}",
f"{current_builder.get_table_name()}.{self.local_owner_key}",
)
.when(callback, lambda q: (callback(q)))
)

def get_with_count_query(self, current_builder, callback):
distant_table = self.distant_builder.get_table_name()
intermediate_table = self.intermediary_builder.get_table_name()

if not current_builder._columns:
current_builder.select("*")

return_query = current_builder.add_select(
f"{self.attribute}_count",
lambda q: (
(
q.count("*")
.join(
f"{self.intermediary_builder.get_table_name()}",
f"{self.intermediary_builder.get_table_name()}.{self.foreign_key}",
f"{intermediate_table}",
f"{intermediate_table}.{self.foreign_key}",
"=",
f"{query.get_table_name()}.{self.other_owner_key}",
f"{distant_table}.{self.other_owner_key}",
)
.where_column(
f"{builder.get_table_name()}.{self.local_owner_key}",
f"{self.intermediary_builder.get_table_name()}.{self.local_key}",
f"{intermediate_table}.{self.local_key}",
f"{current_builder.get_table_name()}.{self.local_owner_key}",
)
.table(query.get_table_name())
.table(distant_table)
.when(
callback,
lambda q: (
q.where_in(
self.foreign_key,
callback(query.select(self.other_owner_key)),
callback(
self.distant_builder.select(self.other_owner_key)
),
)
),
)
Expand All @@ -157,47 +263,3 @@ def get_with_count_query(self, builder, callback):
)

return return_query

def attach(self, current_model, related_record):
raise NotImplementedError(
"HasOneThrough relationship does not implement the attach method"
)

def attach_related(self, current_model, related_record):
raise NotImplementedError(
"HasOneThrough relationship does not implement the attach_related method"
)

def query_has(self, current_query_builder, method="where_exists"):
related_builder = self.get_builder()

getattr(current_query_builder, method)(
self.distant_builder.where_column(
f"{current_query_builder.get_table_name()}.{self.local_owner_key}",
f"{self.intermediary_builder.get_table_name()}.{self.local_key}",
).join(
f"{self.intermediary_builder.get_table_name()}",
f"{self.intermediary_builder.get_table_name()}.{self.foreign_key}",
"=",
f"{self.distant_builder.get_table_name()}.{self.other_owner_key}",
)
)

return related_builder

def query_where_exists(
self, current_query_builder, callback, method="where_exists"
):
query = self.distant_builder

getattr(current_query_builder, method)(
query.join(
f"{self.intermediary_builder.get_table_name()}",
f"{self.intermediary_builder.get_table_name()}.{self.foreign_key}",
"=",
f"{query.get_table_name()}.{self.other_owner_key}",
).where_column(
f"{current_query_builder.get_table_name()}.{self.local_owner_key}",
f"{self.intermediary_builder.get_table_name()}.{self.local_key}",
)
).when(callback, lambda q: (callback(q)))
12 changes: 6 additions & 6 deletions tests/mysql/relationships/test_has_many_through.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,15 +31,15 @@ def test_has_query(self):

self.assertEqual(
sql,
"""SELECT * FROM `inbound_shipments` WHERE EXISTS (SELECT * FROM `countries` INNER JOIN `ports` ON `ports`.`country_id` = `countries`.`country_id` WHERE `inbound_shipments`.`from_port_id` = `ports`.`port_id`)""",
"""SELECT * FROM `inbound_shipments` WHERE EXISTS (SELECT * FROM `countries` INNER JOIN `ports` ON `ports`.`country_id` = `countries`.`country_id` WHERE `ports`.`port_id` = `inbound_shipments`.`from_port_id`)""",
)

def test_or_has(self):
sql = InboundShipment.where("name", "Joe").or_has("from_country").to_sql()

self.assertEqual(
sql,
"""SELECT * FROM `inbound_shipments` WHERE `inbound_shipments`.`name` = 'Joe' OR EXISTS (SELECT * FROM `countries` INNER JOIN `ports` ON `ports`.`country_id` = `countries`.`country_id` WHERE `inbound_shipments`.`from_port_id` = `ports`.`port_id`)""",
"""SELECT * FROM `inbound_shipments` WHERE `inbound_shipments`.`name` = 'Joe' OR EXISTS (SELECT * FROM `countries` INNER JOIN `ports` ON `ports`.`country_id` = `countries`.`country_id` WHERE `ports`.`port_id` = `inbound_shipments`.`from_port_id`)""",
)

def test_where_has_query(self):
Expand All @@ -49,7 +49,7 @@ def test_where_has_query(self):

self.assertEqual(
sql,
"""SELECT * FROM `inbound_shipments` WHERE EXISTS (SELECT * FROM `countries` INNER JOIN `ports` ON `ports`.`country_id` = `countries`.`country_id` WHERE `inbound_shipments`.`from_port_id` = `ports`.`port_id`) AND `inbound_shipments`.`name` = 'USA'""",
"""SELECT * FROM `inbound_shipments` WHERE EXISTS (SELECT * FROM `countries` INNER JOIN `ports` ON `ports`.`country_id` = `countries`.`country_id` WHERE `ports`.`port_id` = `inbound_shipments`.`from_port_id` AND `countries`.`name` = 'USA')""",
)

def test_or_where_has(self):
Expand All @@ -61,15 +61,15 @@ def test_or_where_has(self):

self.assertEqual(
sql,
"""SELECT * FROM `inbound_shipments` WHERE `inbound_shipments`.`name` = 'Joe' OR EXISTS (SELECT * FROM `countries` INNER JOIN `ports` ON `ports`.`country_id` = `countries`.`country_id` WHERE `inbound_shipments`.`from_port_id` = `ports`.`port_id`) AND `inbound_shipments`.`name` = 'USA'""",
"""SELECT * FROM `inbound_shipments` WHERE `inbound_shipments`.`name` = 'Joe' OR EXISTS (SELECT * FROM `countries` INNER JOIN `ports` ON `ports`.`country_id` = `countries`.`country_id` WHERE `ports`.`port_id` = `inbound_shipments`.`from_port_id` AND `countries`.`name` = 'USA')""",
)

def test_doesnt_have(self):
sql = InboundShipment.doesnt_have("from_country").to_sql()

self.assertEqual(
sql,
"""SELECT * FROM `inbound_shipments` WHERE NOT EXISTS (SELECT * FROM `countries` INNER JOIN `ports` ON `ports`.`country_id` = `countries`.`country_id` WHERE `inbound_shipments`.`from_port_id` = `ports`.`port_id`)""",
"""SELECT * FROM `inbound_shipments` WHERE NOT EXISTS (SELECT * FROM `countries` INNER JOIN `ports` ON `ports`.`country_id` = `countries`.`country_id` WHERE `ports`.`port_id` = `inbound_shipments`.`from_port_id`)""",
)

def test_or_where_doesnt_have(self):
Expand All @@ -83,5 +83,5 @@ def test_or_where_doesnt_have(self):

self.assertEqual(
sql,
"""SELECT * FROM `inbound_shipments` WHERE `inbound_shipments`.`name` = 'Joe' OR NOT EXISTS (SELECT * FROM `countries` INNER JOIN `ports` ON `ports`.`country_id` = `countries`.`country_id` WHERE `inbound_shipments`.`from_port_id` = `ports`.`port_id`) AND `inbound_shipments`.`name` = 'USA'""",
"""SELECT * FROM `inbound_shipments` WHERE `inbound_shipments`.`name` = 'Joe' OR NOT EXISTS (SELECT * FROM `countries` INNER JOIN `ports` ON `ports`.`country_id` = `countries`.`country_id` WHERE `ports`.`port_id` = `inbound_shipments`.`from_port_id` AND `countries`.`name` = 'USA')""",
)
Loading

0 comments on commit 5da7c38

Please sign in to comment.