diff --git a/src/masoniteorm/query/QueryBuilder.py b/src/masoniteorm/query/QueryBuilder.py index c3b99226..22562f77 100644 --- a/src/masoniteorm/query/QueryBuilder.py +++ b/src/masoniteorm/query/QueryBuilder.py @@ -1970,7 +1970,7 @@ def _register_relationships_to_model( def _map_related(self, related_result, related): if related.__class__.__name__ == 'MorphTo': return related_result - elif related.__class__.__name__ == 'HasOneThrough': + elif related.__class__.__name__ in ['HasOneThrough', 'HasManyThrough']: return related_result.group_by(related.local_key) return related_result.group_by(related.foreign_key) diff --git a/src/masoniteorm/relationships/HasManyThrough.py b/src/masoniteorm/relationships/HasManyThrough.py index 86174cfb..e044f9a7 100644 --- a/src/masoniteorm/relationships/HasManyThrough.py +++ b/src/masoniteorm/relationships/HasManyThrough.py @@ -1,5 +1,5 @@ -from .BaseRelationship import BaseRelationship from ..collection import Collection +from .BaseRelationship import BaseRelationship class HasManyThrough(BaseRelationship): @@ -57,33 +57,46 @@ def __get__(self, instance, owner): if attribute in instance._relationships: return instance._relationships[attribute] - result = self.apply_query( + result = self.apply_related_query( self.distant_builder, self.intermediary_builder, instance ) return result else: return self - def apply_query(self, distant_builder, intermediary_builder, owner): - """Apply the query and return a dictionary to be hydrated. - Used during accessing a relationship on a model + def apply_related_query(self, distant_builder, intermediary_builder, owner): + """ + Apply the query to return a Collection of data for the distant models to be hydrated with. - Arguments: - query {oject} -- The relationship object - owner {object} -- The current model oject. + Method is used when accessing a relationship on a model if its not + already eager loaded - Returns: - dict -- A dictionary of data which will be hydrated. + Arguments + distant_builder (QueryBuilder): QueryBuilder attached to the distant table + intermediate_builder (QueryBuilder): QueryBuilder attached to the intermediate (linking) table + owner (Any): the model this relationship is starting from + + Returns + Collection: Collection of dicts which will be used for hydrating models. """ - # select * from `countries` inner join `ports` on `ports`.`country_id` = `countries`.`country_id` where `ports`.`port_id` is null and `countries`.`deleted_at` is null and `ports`.`deleted_at` is null - result = distant_builder.join( - f"{self.intermediary_builder.get_table_name()}", - f"{self.intermediary_builder.get_table_name()}.{self.foreign_key}", - "=", - f"{distant_builder.get_table_name()}.{self.other_owner_key}", - ).where(f"{self.intermediary_builder.get_table_name()}.{self.local_owner_key}", getattr(owner, self.other_owner_key)).get() - return result + distant_table = distant_builder.get_table_name() + intermediate_table = intermediary_builder.get_table_name() + + return ( + self.distant_builder.select(f"{distant_table}.*, {intermediate_table}.{self.local_key}") + .join( + f"{intermediate_table}", + f"{intermediate_table}.{self.foreign_key}", + "=", + f"{distant_table}.{self.other_owner_key}", + ) + .where( + f"{intermediate_table}.{self.local_key}", + getattr(owner, self.local_owner_key), + ) + .get() + ) def relate(self, related_model): return self.distant_builder.join( @@ -104,51 +117,144 @@ def make_builder(self, eagers=None): return builder - def get_related(self, query, relation, eagers=None, callback=None): - builder = self.distant_builder + def register_related(self, key, model, collection): + """ + Attach the related model to source models attribute + + Arguments + key (str): The attribute name + model (Any): The model instance + collection (Collection): The data for the related models + + Returns + None + """ + related = collection.get(getattr(model, self.local_owner_key), None) + if related and not isinstance(related, Collection): + related = Collection(related) + + model.add_relation({key: related if related else None}) + + def get_related(self, current_builder, relation, eagers=None, callback=None): + """ + Get a Collection to hydrate the models for the distant table with + Used when eager loading the model attribute + + Arguments + current_builder (QueryBuilder): The source models QueryBuilder object + relation (HasManyThrough): this relationship object + eagers (Any): + callback (Any): + + Returns + Collection the collection of dicts to hydrate the distant models with + """ + + distant_table = self.distant_builder.get_table_name() + intermediate_table = self.intermediary_builder.get_table_name() if callback: - callback(builder) + callback(current_builder) + + ( + self.distant_builder.select( + f"{distant_table}.*, {intermediate_table}.{self.local_key}" + ).join( + f"{intermediate_table}", + f"{intermediate_table}.{self.foreign_key}", + "=", + f"{distant_table}.{self.other_owner_key}", + ) + ) if isinstance(relation, Collection): - return builder.where_in( - f"{builder.get_table_name()}.{self.foreign_key}", - Collection(relation._get_value(self.local_key)).unique(), + return self.distant_builder.where_in( + f"{intermediate_table}.{self.local_key}", + Collection(relation._get_value(self.local_owner_key)).unique(), ).get() else: - return builder.where( - f"{builder.get_table_name()}.{self.foreign_key}", + return self.distant_builder.where( + f"{intermediate_table}.{self.local_key}", getattr(relation, self.local_owner_key), ).get() - def get_with_count_query(self, builder, callback): - query = self.distant_builder + def attach(self, current_model, related_record): + raise NotImplementedError( + "HasOneThrough relationship does not implement the attach method" + ) + + def attach_related(self, current_model, related_record): + raise NotImplementedError( + "HasOneThrough relationship does not implement the attach_related method" + ) - if not builder._columns: - builder = builder.select("*") + def query_has(self, current_builder, method="where_exists"): + distant_table = self.distant_builder.get_table_name() + intermediate_table = self.intermediary_builder.get_table_name() - return_query = builder.add_select( + getattr(current_builder, method)( + self.distant_builder.join( + f"{intermediate_table}", + f"{intermediate_table}.{self.foreign_key}", + "=", + f"{distant_table}.{self.other_owner_key}", + ).where_column( + f"{intermediate_table}.{self.local_key}", + f"{current_builder.get_table_name()}.{self.local_owner_key}", + ) + ) + + return self.distant_builder + + def query_where_exists(self, current_builder, callback, method="where_exists"): + distant_table = self.distant_builder.get_table_name() + intermediate_table = self.intermediary_builder.get_table_name() + + getattr(current_builder, method)( + self.distant_builder.join( + f"{intermediate_table}", + f"{intermediate_table}.{self.foreign_key}", + "=", + f"{distant_table}.{self.other_owner_key}", + ) + .where_column( + f"{intermediate_table}.{self.local_key}", + f"{current_builder.get_table_name()}.{self.local_owner_key}", + ) + .when(callback, lambda q: (callback(q))) + ) + + def get_with_count_query(self, current_builder, callback): + distant_table = self.distant_builder.get_table_name() + intermediate_table = self.intermediary_builder.get_table_name() + + if not current_builder._columns: + current_builder.select("*") + + return_query = current_builder.add_select( f"{self.attribute}_count", lambda q: ( ( q.count("*") .join( - f"{self.intermediary_builder.get_table_name()}", - f"{self.intermediary_builder.get_table_name()}.{self.foreign_key}", + f"{intermediate_table}", + f"{intermediate_table}.{self.foreign_key}", "=", - f"{query.get_table_name()}.{self.other_owner_key}", + f"{distant_table}.{self.other_owner_key}", ) .where_column( - f"{builder.get_table_name()}.{self.local_owner_key}", - f"{self.intermediary_builder.get_table_name()}.{self.local_key}", + f"{intermediate_table}.{self.local_key}", + f"{current_builder.get_table_name()}.{self.local_owner_key}", ) - .table(query.get_table_name()) + .table(distant_table) .when( callback, lambda q: ( q.where_in( self.foreign_key, - callback(query.select(self.other_owner_key)), + callback( + self.distant_builder.select(self.other_owner_key) + ), ) ), ) @@ -157,47 +263,3 @@ def get_with_count_query(self, builder, callback): ) return return_query - - def attach(self, current_model, related_record): - raise NotImplementedError( - "HasOneThrough relationship does not implement the attach method" - ) - - def attach_related(self, current_model, related_record): - raise NotImplementedError( - "HasOneThrough relationship does not implement the attach_related method" - ) - - def query_has(self, current_query_builder, method="where_exists"): - related_builder = self.get_builder() - - getattr(current_query_builder, method)( - self.distant_builder.where_column( - f"{current_query_builder.get_table_name()}.{self.local_owner_key}", - f"{self.intermediary_builder.get_table_name()}.{self.local_key}", - ).join( - f"{self.intermediary_builder.get_table_name()}", - f"{self.intermediary_builder.get_table_name()}.{self.foreign_key}", - "=", - f"{self.distant_builder.get_table_name()}.{self.other_owner_key}", - ) - ) - - return related_builder - - def query_where_exists( - self, current_query_builder, callback, method="where_exists" - ): - query = self.distant_builder - - getattr(current_query_builder, method)( - query.join( - f"{self.intermediary_builder.get_table_name()}", - f"{self.intermediary_builder.get_table_name()}.{self.foreign_key}", - "=", - f"{query.get_table_name()}.{self.other_owner_key}", - ).where_column( - f"{current_query_builder.get_table_name()}.{self.local_owner_key}", - f"{self.intermediary_builder.get_table_name()}.{self.local_key}", - ) - ).when(callback, lambda q: (callback(q))) diff --git a/tests/mysql/relationships/test_has_many_through.py b/tests/mysql/relationships/test_has_many_through.py index 61830c2b..3c6d5b7e 100644 --- a/tests/mysql/relationships/test_has_many_through.py +++ b/tests/mysql/relationships/test_has_many_through.py @@ -31,7 +31,7 @@ def test_has_query(self): self.assertEqual( sql, - """SELECT * FROM `inbound_shipments` WHERE EXISTS (SELECT * FROM `countries` INNER JOIN `ports` ON `ports`.`country_id` = `countries`.`country_id` WHERE `inbound_shipments`.`from_port_id` = `ports`.`port_id`)""", + """SELECT * FROM `inbound_shipments` WHERE EXISTS (SELECT * FROM `countries` INNER JOIN `ports` ON `ports`.`country_id` = `countries`.`country_id` WHERE `ports`.`port_id` = `inbound_shipments`.`from_port_id`)""", ) def test_or_has(self): @@ -39,7 +39,7 @@ def test_or_has(self): self.assertEqual( sql, - """SELECT * FROM `inbound_shipments` WHERE `inbound_shipments`.`name` = 'Joe' OR EXISTS (SELECT * FROM `countries` INNER JOIN `ports` ON `ports`.`country_id` = `countries`.`country_id` WHERE `inbound_shipments`.`from_port_id` = `ports`.`port_id`)""", + """SELECT * FROM `inbound_shipments` WHERE `inbound_shipments`.`name` = 'Joe' OR EXISTS (SELECT * FROM `countries` INNER JOIN `ports` ON `ports`.`country_id` = `countries`.`country_id` WHERE `ports`.`port_id` = `inbound_shipments`.`from_port_id`)""", ) def test_where_has_query(self): @@ -49,7 +49,7 @@ def test_where_has_query(self): self.assertEqual( sql, - """SELECT * FROM `inbound_shipments` WHERE EXISTS (SELECT * FROM `countries` INNER JOIN `ports` ON `ports`.`country_id` = `countries`.`country_id` WHERE `inbound_shipments`.`from_port_id` = `ports`.`port_id`) AND `inbound_shipments`.`name` = 'USA'""", + """SELECT * FROM `inbound_shipments` WHERE EXISTS (SELECT * FROM `countries` INNER JOIN `ports` ON `ports`.`country_id` = `countries`.`country_id` WHERE `ports`.`port_id` = `inbound_shipments`.`from_port_id` AND `countries`.`name` = 'USA')""", ) def test_or_where_has(self): @@ -61,7 +61,7 @@ def test_or_where_has(self): self.assertEqual( sql, - """SELECT * FROM `inbound_shipments` WHERE `inbound_shipments`.`name` = 'Joe' OR EXISTS (SELECT * FROM `countries` INNER JOIN `ports` ON `ports`.`country_id` = `countries`.`country_id` WHERE `inbound_shipments`.`from_port_id` = `ports`.`port_id`) AND `inbound_shipments`.`name` = 'USA'""", + """SELECT * FROM `inbound_shipments` WHERE `inbound_shipments`.`name` = 'Joe' OR EXISTS (SELECT * FROM `countries` INNER JOIN `ports` ON `ports`.`country_id` = `countries`.`country_id` WHERE `ports`.`port_id` = `inbound_shipments`.`from_port_id` AND `countries`.`name` = 'USA')""", ) def test_doesnt_have(self): @@ -69,7 +69,7 @@ def test_doesnt_have(self): self.assertEqual( sql, - """SELECT * FROM `inbound_shipments` WHERE NOT EXISTS (SELECT * FROM `countries` INNER JOIN `ports` ON `ports`.`country_id` = `countries`.`country_id` WHERE `inbound_shipments`.`from_port_id` = `ports`.`port_id`)""", + """SELECT * FROM `inbound_shipments` WHERE NOT EXISTS (SELECT * FROM `countries` INNER JOIN `ports` ON `ports`.`country_id` = `countries`.`country_id` WHERE `ports`.`port_id` = `inbound_shipments`.`from_port_id`)""", ) def test_or_where_doesnt_have(self): @@ -83,5 +83,5 @@ def test_or_where_doesnt_have(self): self.assertEqual( sql, - """SELECT * FROM `inbound_shipments` WHERE `inbound_shipments`.`name` = 'Joe' OR NOT EXISTS (SELECT * FROM `countries` INNER JOIN `ports` ON `ports`.`country_id` = `countries`.`country_id` WHERE `inbound_shipments`.`from_port_id` = `ports`.`port_id`) AND `inbound_shipments`.`name` = 'USA'""", + """SELECT * FROM `inbound_shipments` WHERE `inbound_shipments`.`name` = 'Joe' OR NOT EXISTS (SELECT * FROM `countries` INNER JOIN `ports` ON `ports`.`country_id` = `countries`.`country_id` WHERE `ports`.`port_id` = `inbound_shipments`.`from_port_id` AND `countries`.`name` = 'USA')""", ) diff --git a/tests/sqlite/relationships/test_sqlite_has_many_through_relationship.py b/tests/sqlite/relationships/test_sqlite_has_many_through_relationship.py new file mode 100644 index 00000000..baf68eae --- /dev/null +++ b/tests/sqlite/relationships/test_sqlite_has_many_through_relationship.py @@ -0,0 +1,144 @@ +import unittest + +from src.masoniteorm.collection import Collection +from src.masoniteorm.models import Model +from src.masoniteorm.relationships import has_many_through +from tests.integrations.config.database import DATABASES +from src.masoniteorm.schema import Schema +from src.masoniteorm.schema.platforms import SQLitePlatform + + +class Enrolment(Model): + __table__ = "enrolment" + __connection__ = "dev" + __fillable__ = ["active_student_id", "in_course_id"] + + +class Student(Model): + __table__ = "student" + __connection__ = "dev" + __fillable__ = ["student_id", "name"] + + +class Course(Model): + __table__ = "course" + __connection__ = "dev" + __fillable__ = ["course_id", "name"] + + @has_many_through( + None, + "in_course_id", + "active_student_id", + "course_id", + "student_id" + ) + def students(self): + return [Student, Enrolment] + + +class TestHasManyThroughRelationship(unittest.TestCase): + def setUp(self): + self.schema = Schema( + connection="dev", + connection_details=DATABASES, + platform=SQLitePlatform, + ).on("dev") + + with self.schema.create_table_if_not_exists("student") as table: + table.integer("student_id").primary() + table.string("name") + + with self.schema.create_table_if_not_exists("course") as table: + table.integer("course_id").primary() + table.string("name") + + with self.schema.create_table_if_not_exists("enrolment") as table: + table.integer("enrolment_id").primary() + table.integer("active_student_id") + table.integer("in_course_id") + + if not Course.count(): + Course.builder.new().bulk_create( + [ + {"course_id": 10, "name": "Math 101"}, + {"course_id": 20, "name": "History 101"}, + {"course_id": 30, "name": "Math 302"}, + {"course_id": 40, "name": "Biology 302"}, + ] + ) + + if not Student.count(): + Student.builder.new().bulk_create( + [ + {"student_id": 100, "name": "Bob"}, + {"student_id": 200, "name": "Alice"}, + {"student_id": 300, "name": "Steve"}, + {"student_id": 400, "name": "Megan"}, + ] + ) + + if not Enrolment.count(): + Enrolment.builder.new().bulk_create( + [ + {"active_student_id": 100, "in_course_id": 30}, + {"active_student_id": 200, "in_course_id": 10}, + {"active_student_id": 100, "in_course_id": 10}, + {"active_student_id": 400, "in_course_id": 20}, + ] + ) + + def test_has_many_through_can_eager_load(self): + courses = Course.where("name", "Math 101").with_("students").get() + students = courses.first().students + + self.assertIsInstance(students, Collection) + self.assertEqual(students.count(), 2) + + student1 = students.shift() + self.assertIsInstance(student1, Student) + self.assertEqual(student1.name, "Alice") + + student2 = students.shift() + self.assertIsInstance(student2, Student) + self.assertEqual(student2.name, "Bob") + + # check .first() and .get() produce the same result + single = ( + Course.where("name", "History 101") + .with_("students") + .first() + ) + self.assertIsInstance(single.students, Collection) + + single_get = ( + Course.where("name", "History 101").with_("students").get() + ) + + print(single.students) + print(single_get.first().students) + self.assertEqual(single.students.count(), 1) + self.assertEqual(single_get.first().students.count(), 1) + + single_name = single.students.first().name + single_get_name = single_get.first().students.first().name + self.assertEqual(single_name, single_get_name) + + def test_has_many_through_eager_load_can_be_empty(self): + courses = ( + Course.where("name", "Biology 302") + .with_("students") + .get() + ) + self.assertIsNone(courses.first().students) + + def test_has_many_through_can_get_related(self): + course = Course.where("name", "Math 101").first() + self.assertIsInstance(course.students, Collection) + self.assertIsInstance(course.students.first(), Student) + self.assertEqual(course.students.count(), 2) + + def test_has_many_through_has_query(self): + courses = Course.where_has( + "students", lambda query: query.where("name", "Bob") + ) + self.assertEqual(courses.count(), 2) diff --git a/tests/sqlite/relationships/test_sqlite_has_through_relationships.py b/tests/sqlite/relationships/test_sqlite_has_one_through_relationship.py similarity index 98% rename from tests/sqlite/relationships/test_sqlite_has_through_relationships.py rename to tests/sqlite/relationships/test_sqlite_has_one_through_relationship.py index d23f4847..dee1bff9 100644 --- a/tests/sqlite/relationships/test_sqlite_has_through_relationships.py +++ b/tests/sqlite/relationships/test_sqlite_has_one_through_relationship.py @@ -29,7 +29,8 @@ def from_country(self): return [Country, Port] -class TestRelationships(unittest.TestCase): + +class TestHasOneThroughRelationship(unittest.TestCase): def setUp(self): self.schema = Schema( connection="dev",