diff --git a/.github/workflows/integration-tests.yml b/.github/workflows/integration-tests.yml index d0a6d730..8532c799 100644 --- a/.github/workflows/integration-tests.yml +++ b/.github/workflows/integration-tests.yml @@ -16,14 +16,15 @@ jobs: fail-fast: false matrix: python-version: ["3.11", "3.10", "3.9", "3.8", "3.7"] - neo4j-version: ["enterprise", "5.5-enterprise", "4.4-enterprise"] + neo4j-version: ["community", "enterprise", "5.5-enterprise", "4.4-enterprise", "4.4-community"] steps: - uses: actions/checkout@v3 - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v3 + uses: actions/setup-python@v4 with: python-version: ${{ matrix.python-version }} + cache: 'pip' - name: Creating Neo4j Container run: | chmod +x ./docker-scripts/docker-neo4j.sh diff --git a/.sonarcloud.properties b/.sonarcloud.properties index c74ae42a..d1ab9656 100644 --- a/.sonarcloud.properties +++ b/.sonarcloud.properties @@ -1,2 +1,3 @@ sonar.sources = neomodel/ -sonar.tests = test/ \ No newline at end of file +sonar.tests = test/ +sonar.python.version = 3.7, 3.8, 3.9, 3.10, 3.11 \ No newline at end of file diff --git a/Changelog b/Changelog index 3a870e01..9e0a1486 100644 --- a/Changelog +++ b/Changelog @@ -1,3 +1,9 @@ +Version 5.1.1 2023-08 +* Add impersonation +* Bumped neo4j-driver to 5.11.0 +* Add automatic path inflation #715 +* Improve code quality and tooling + Version 5.1.0 2023-07 * Bumped neo4j-driver version to 5.10.0 * Breaking change : When using neomodel along with Neo4j version 5, use StructuredNode and StructuredRel's element_id property instead of id. If you have Cypher queries which currently use the id() function, migrate them to elementId() instead. diff --git a/doc/source/conf.py b/doc/source/conf.py index 672ef163..41e50bbf 100644 --- a/doc/source/conf.py +++ b/doc/source/conf.py @@ -7,7 +7,8 @@ import alabaster -from neomodel import __author__, __package__, __version__ +from neomodel import __author__, __package__ +from neomodel._version import __version__ # # neomodel documentation build configuration file, created by diff --git a/doc/source/configuration.rst b/doc/source/configuration.rst index 6decef62..854af8a8 100644 --- a/doc/source/configuration.rst +++ b/doc/source/configuration.rst @@ -22,7 +22,7 @@ Adjust driver configuration:: config.MAX_TRANSACTION_RETRY_TIME = 30.0 # default config.RESOLVER = None # default config.TRUST = neo4j.TRUST_SYSTEM_CA_SIGNED_CERTIFICATES # default - config.USER_AGENT = None # default + config.USER_AGENT = neomodel/vNeo4j.Major.minor # default Setting the database name, for neo4j >= 4:: @@ -38,7 +38,7 @@ constraints and indexes at compile time. However this method is only recommended # before loading your node definitions config.AUTO_INSTALL_LABELS = True -Neomodel also provides the `neomodel_install_labels` script for this task, +Neomodel also provides the :ref:`neomodel_install_labels` script for this task, however if you want to handle this manually see below. Install indexes and constraints for a single class:: diff --git a/doc/source/getting_started.rst b/doc/source/getting_started.rst index 3d97464a..3ec09ceb 100644 --- a/doc/source/getting_started.rst +++ b/doc/source/getting_started.rst @@ -99,8 +99,8 @@ and constraints. 
Applying constraints and indexes ================================ -After creating a model in Python, any constraints or indexes need must be applied to Neo4j and ``neomodel`` provides a -script to automate this: :: +After creating a model in Python, any constraints or indexes must be applied to Neo4j and ``neomodel`` provides a +script (:ref:`neomodel_install_labels`) to automate this: :: $ neomodel_install_labels yourapp.py someapp.models --db bolt://neo4j:neo4j@localhost:7687 @@ -108,7 +108,7 @@ It is important to execute this after altering the schema and observe the number Remove existing constraints and indexes ======================================= -Similarly, ``neomodel`` provides a script to automate the removal of all existing constraints and indexes from +Similarly, ``neomodel`` provides a script (:ref:`neomodel_remove_labels`) to automate the removal of all existing constraints and indexes from the database, when this is required: :: $ neomodel_remove_labels --db bolt://neo4j:neo4j@localhost:7687 diff --git a/doc/source/module_documentation.rst b/doc/source/module_documentation.rst index f32a76eb..16a4acf2 100644 --- a/doc/source/module_documentation.rst +++ b/doc/source/module_documentation.rst @@ -3,37 +3,37 @@ Modules documentation ===================== Database --------- +======== .. module:: neomodel.util .. autoclass:: neomodel.util.Database :members: :undoc-members: Core ----- +==== .. automodule:: neomodel.core :members: .. _semistructurednode_doc: ``SemiStructuredNode`` -^^^^^^^^^^^^^^^^^^^^^^ +---------------------- .. autoclass:: neomodel.contrib.SemiStructuredNode Properties ----------- +========== .. automodule:: neomodel.properties :members: :show-inheritance: Spatial Properties & Datatypes -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +------------------------------ .. automodule:: neomodel.contrib.spatial_properties :members: :show-inheritance: Relationships -------------- +============= .. automodule:: neomodel.relationship :members: :show-inheritance: @@ -46,8 +46,18 @@ Relationships :members: :show-inheritance: +Paths +===== + +.. automodule:: neomodel.path + :members: + :show-inheritance: + + + + Match ------ +===== .. module:: neomodel.match .. autoclass:: neomodel.match.BaseSet :members: @@ -61,9 +71,23 @@ Match Exceptions ----------- +========== .. automodule:: neomodel.exceptions :members: :undoc-members: :show-inheritance: + +Scripts +======= + +.. automodule:: neomodel.scripts.neomodel_install_labels + :members: + :undoc-members: + :show-inheritance: + +.. automodule:: neomodel.scripts.neomodel_remove_labels + :members: + :undoc-members: + :show-inheritance: + diff --git a/doc/source/queries.rst b/doc/source/queries.rst index 7362c856..c3e37629 100644 --- a/doc/source/queries.rst +++ b/doc/source/queries.rst @@ -186,3 +186,57 @@ For random ordering simply pass '?' to the order_by method:: Coffee.nodes.order_by('?') +Retrieving paths +================ + +You can retrieve a whole path of already instantiated objects corresponding to +the nodes and relationship classes with a single query. 
+
+Suppose the following schema:
+
+::
+
+    class PersonLivesInCity(StructuredRel):
+        some_num = IntegerProperty(index=True,
+                                   default=12)
+
+    class CountryOfOrigin(StructuredNode):
+        code = StringProperty(unique_index=True,
+                              required=True)
+
+    class CityOfResidence(StructuredNode):
+        name = StringProperty(required=True)
+        country = RelationshipTo(CountryOfOrigin,
+                                 'FROM_COUNTRY')
+
+    class PersonOfInterest(StructuredNode):
+        uid = UniqueIdProperty()
+        name = StringProperty(unique_index=True)
+        age = IntegerProperty(index=True,
+                              default=0)
+
+        country = RelationshipTo(CountryOfOrigin,
+                                 'IS_FROM')
+        city = RelationshipTo(CityOfResidence,
+                              'LIVES_IN',
+                              model=PersonLivesInCity)
+
+Then, paths can be retrieved with:
+
+::
+
+    q = db.cypher_query("MATCH p=(:CityOfResidence)<-[:LIVES_IN]-(:PersonOfInterest)-[:IS_FROM]->(:CountryOfOrigin) RETURN p LIMIT 1",
+                        resolve_objects = True)
+
+Notice here that ``resolve_objects`` is set to ``True``. This results in ``q`` being a
+``(results, meta)`` tuple, with ``q[0][0][0]`` being a ``NeomodelPath`` object.
+
+The ``nodes`` and ``relationships`` attributes of ``NeomodelPath`` contain the already
+instantiated objects of the nodes and relationships in the query, *in order of appearance*.
+
+Note that each object is read exactly once from the database. Nodes are instantiated
+to their neomodel node classes and relationships to their relationship models
+*if such a model exists*. In other words, relationships with data (such as
+``PersonLivesInCity`` above) are instantiated to their respective classes, and to
+``StructuredRel`` otherwise. Relationships do not "reload" their end-points
+(unless this is required).
diff --git a/doc/source/transactions.rst b/doc/source/transactions.rst
index f8d025d2..dfa97ee6 100644
--- a/doc/source/transactions.rst
+++ b/doc/source/transactions.rst
@@ -162,3 +162,39 @@ or manually::
         bookmark = db.commit()
     except Exception as e:
         db.rollback()
+
+Impersonation
+-------------
+
+*Neo4j Enterprise feature*
+
+Impersonation (see the Neo4j driver documentation)
+can be enabled via a context manager::
+
+    from neomodel import db
+
+    with db.impersonate(user="writeuser"):
+        Person(name='Bob').save()
+
+or as a function decorator::
+
+    @db.impersonate(user="writeuser")
+    def update_user_name(uid, name):
+        user = Person.nodes.filter(uid=uid)[0]
+        user.name = name
+        user.save()
+
+This can be mixed with other context managers, such as transactions::
+
+    from neomodel import db
+
+    @db.impersonate(user="tempuser")
+    # Both transactions will be run as the same impersonated user
+    def func0():
+        @db.transaction()
+        def func1():
+            ...
+
+        @db.transaction()
+        def func2():
+            ...
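A minimal sketch (not part of the changeset) of how the resolved path documented in the ``queries.rst`` addition above can be consumed; it assumes the hypothetical schema from that example and that the query returned at least one path::

    # q comes from the db.cypher_query(..., resolve_objects=True) call shown above
    path = q[0][0][0]                  # a NeomodelPath

    for node in path.nodes:            # already-inflated StructuredNode instances
        print(node.__class__.__name__, node.__properties__)

    for rel in path.relationships:     # StructuredRel (or subclass) instances
        print(rel.__class__.__name__)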
\ No newline at end of file diff --git a/neomodel/__init__.py b/neomodel/__init__.py index 3899ecee..23e0142a 100644 --- a/neomodel/__init__.py +++ b/neomodel/__init__.py @@ -33,9 +33,9 @@ ) from .relationship import StructuredRel from .util import change_neo4j_password, clear_neo4j_database +from .path import NeomodelPath __author__ = "Robin Edwards" __email__ = "robin.ge@gmail.com" __license__ = "MIT" __package__ = "neomodel" -__version__ = "5.1.0" diff --git a/neomodel/_version.py b/neomodel/_version.py new file mode 100644 index 00000000..a9c316e2 --- /dev/null +++ b/neomodel/_version.py @@ -0,0 +1 @@ +__version__ = "5.1.1" diff --git a/neomodel/config.py b/neomodel/config.py index 44858141..1f6df10b 100644 --- a/neomodel/config.py +++ b/neomodel/config.py @@ -1,5 +1,7 @@ import neo4j +from ._version import __version__ + AUTO_INSTALL_LABELS = False DATABASE_URL = "bolt://neo4j:foobarbaz@localhost:7687" FORCE_TIMEZONE = False @@ -13,4 +15,4 @@ MAX_TRANSACTION_RETRY_TIME = 30.0 RESOLVER = None TRUSTED_CERTIFICATES = neo4j.TrustSystemCAs() -USER_AGENT = None +USER_AGENT = f"neomodel/v{__version__}" diff --git a/neomodel/core.py b/neomodel/core.py index d6bba57f..00198fa4 100644 --- a/neomodel/core.py +++ b/neomodel/core.py @@ -12,6 +12,11 @@ db = Database() +RULE_ALREADY_EXISTS = "Neo.ClientError.Schema.EquivalentSchemaRuleAlreadyExists" +INDEX_ALREADY_EXISTS = "Neo.ClientError.Schema.IndexAlreadyExists" +CONSTRAINT_ALREADY_EXISTS = "Neo.ClientError.Schema.ConstraintAlreadyExists" +STREAMING_WARNING = "streaming is not supported by bolt, please remove the kwarg" + def drop_constraints(quiet=True, stdout=None): """ @@ -100,76 +105,102 @@ def install_labels(cls, quiet=True, stdout=None): ) return - # Create indexes and constraints for node properties for name, property in cls.defined_properties(aliases=False, rels=False).items(): - db_property = property.db_property or name - if property.index: - if not quiet: - stdout.write( - f" + Creating node index {name} on label {cls.__label__} for class {cls.__module__}.{cls.__name__}\n" - ) - try: - db.cypher_query( - f"CREATE INDEX index_{cls.__label__}_{db_property} FOR (n:{cls.__label__}) ON (n.{db_property}); " - ) - except ClientError as e: - if e.code in ( - "Neo.ClientError.Schema.EquivalentSchemaRuleAlreadyExists", - "Neo.ClientError.Schema.IndexAlreadyExists", - ): - stdout.write(f"{str(e)}\n") - else: - raise + _install_node(cls, name, property, quiet, stdout) - elif property.unique_index: - if not quiet: - stdout.write( - f" + Creating node unique constraint for {name} on label {cls.__label__} for class {cls.__module__}.{cls.__name__}\n" - ) - try: - db.cypher_query( - f"""CREATE CONSTRAINT constraint_unique_{cls.__label__}_{db_property} - FOR (n:{cls.__label__}) REQUIRE n.{db_property} IS UNIQUE""" - ) - except ClientError as e: - if e.code in ( - "Neo.ClientError.Schema.EquivalentSchemaRuleAlreadyExists", - "Neo.ClientError.Schema.ConstraintAlreadyExists", - ): - stdout.write(f"{str(e)}\n") - else: - raise - - # TODO : Add support for existence constraints - - # Create indexes and constraints for relationship properties for _, relationship in cls.defined_properties( aliases=False, rels=True, properties=False ).items(): - relationship_cls = relationship.definition["model"] - if relationship_cls is not None: - relationship_type = relationship.definition["relation_type"] - for prop_name, property in relationship_cls.defined_properties( - aliases=False, rels=False - ).items(): - db_property = property.db_property or prop_name - if 
property.index: - if not quiet: - stdout.write( - f" + Creating relationship index {prop_name} on relationship type {relationship_type} for relationship model {cls.__module__}.{relationship_cls.__name__}\n" - ) - try: - db.cypher_query( - f"CREATE INDEX index_{relationship_type}_{db_property} FOR ()-[r:{relationship_type}]-() ON (r.{db_property}); " - ) - except ClientError as e: - if e.code in ( - "Neo.ClientError.Schema.EquivalentSchemaRuleAlreadyExists", - "Neo.ClientError.Schema.IndexAlreadyExists", - ): - stdout.write(f"{str(e)}\n") - else: - raise + _install_relationship(cls, relationship, quiet, stdout) + + +def _create_node_index(label: str, property_name: str, stdout): + try: + db.cypher_query( + f"CREATE INDEX index_{label}_{property_name} FOR (n:{label}) ON (n.{property_name}); " + ) + except ClientError as e: + if e.code in ( + RULE_ALREADY_EXISTS, + INDEX_ALREADY_EXISTS, + ): + stdout.write(f"{str(e)}\n") + else: + raise + + +def _create_node_constraint(label: str, property_name: str, stdout): + try: + db.cypher_query( + f"""CREATE CONSTRAINT constraint_unique_{label}_{property_name} + FOR (n:{label}) REQUIRE n.{property_name} IS UNIQUE""" + ) + except ClientError as e: + if e.code in ( + RULE_ALREADY_EXISTS, + CONSTRAINT_ALREADY_EXISTS, + ): + stdout.write(f"{str(e)}\n") + else: + raise + + +def _create_relationship_index(relationship_type: str, property_name: str, stdout): + try: + db.cypher_query( + f"CREATE INDEX index_{relationship_type}_{property_name} FOR ()-[r:{relationship_type}]-() ON (r.{property_name}); " + ) + except ClientError as e: + if e.code in ( + RULE_ALREADY_EXISTS, + INDEX_ALREADY_EXISTS, + ): + stdout.write(f"{str(e)}\n") + else: + raise + + +def _install_node(cls, name, property, quiet, stdout): + # Create indexes and constraints for node property + db_property = property.db_property or name + if property.index: + if not quiet: + stdout.write( + f" + Creating node index {name} on label {cls.__label__} for class {cls.__module__}.{cls.__name__}\n" + ) + _create_node_index( + label=cls.__label__, property_name=db_property, stdout=stdout + ) + + elif property.unique_index: + if not quiet: + stdout.write( + f" + Creating node unique constraint for {name} on label {cls.__label__} for class {cls.__module__}.{cls.__name__}\n" + ) + _create_node_constraint( + label=cls.__label__, property_name=db_property, stdout=stdout + ) + + +def _install_relationship(cls, relationship, quiet, stdout): + # Create indexes and constraints for relationship property + relationship_cls = relationship.definition["model"] + if relationship_cls is not None: + relationship_type = relationship.definition["relation_type"] + for prop_name, property in relationship_cls.defined_properties( + aliases=False, rels=False + ).items(): + db_property = property.db_property or prop_name + if property.index: + if not quiet: + stdout.write( + f" + Creating relationship index {prop_name} on relationship type {relationship_type} for relationship model {cls.__module__}.{relationship_cls.__name__}\n" + ) + _create_relationship_index( + relationship_type=relationship_type, + property_name=db_property, + stdout=stdout, + ) def install_all_labels(stdout=None): @@ -185,9 +216,10 @@ def install_all_labels(stdout=None): stdout = sys.stdout def subsub(cls): # recursively return all subclasses - return cls.__subclasses__() + [ - g for s in cls.__subclasses__() for g in subsub(s) - ] + subclasses = cls.__subclasses__() + if not subclasses: # base case: no more subclasses + return [] + return subclasses + [g for 
s in cls.__subclasses__() for g in subsub(s)] stdout.write("Setting up indexes and constraints...\n\n") @@ -248,24 +280,28 @@ def __new__(mcs, name, bases, namespace): if config.AUTO_INSTALL_LABELS: install_labels(cls, quiet=False) - base_label_set = frozenset(cls.inherited_labels()) - optional_label_set = set(cls.inherited_optional_labels()) + build_class_registry(cls) + + return cls - # Construct all possible combinations of labels + optional labels - possible_label_combinations = [ - frozenset(set(x).union(base_label_set)) - for i in range(1, len(optional_label_set) + 1) - for x in combinations(optional_label_set, i) - ] - possible_label_combinations.append(base_label_set) - for label_set in possible_label_combinations: - if label_set not in db._NODE_CLASS_REGISTRY: - db._NODE_CLASS_REGISTRY[label_set] = cls - else: - raise NodeClassAlreadyDefined(cls, db._NODE_CLASS_REGISTRY) +def build_class_registry(cls): + base_label_set = frozenset(cls.inherited_labels()) + optional_label_set = set(cls.inherited_optional_labels()) - return cls + # Construct all possible combinations of labels + optional labels + possible_label_combinations = [ + frozenset(set(x).union(base_label_set)) + for i in range(1, len(optional_label_set) + 1) + for x in combinations(optional_label_set, i) + ] + possible_label_combinations.append(base_label_set) + + for label_set in possible_label_combinations: + if label_set not in db._NODE_CLASS_REGISTRY: + db._NODE_CLASS_REGISTRY[label_set] = cls + else: + raise NodeClassAlreadyDefined(cls, db._NODE_CLASS_REGISTRY) NodeBase = NodeMeta("NodeBase", (PropertyManager,), {"__abstract_node__": True}) @@ -421,7 +457,7 @@ def create(cls, *props, **kwargs): if "streaming" in kwargs: warnings.warn( - "streaming is not supported by bolt, please remove the kwarg", + STREAMING_WARNING, category=DeprecationWarning, stacklevel=1, ) @@ -490,7 +526,7 @@ def create_or_update(cls, *props, **kwargs): if "streaming" in kwargs: warnings.warn( - "streaming is not supported by bolt, please remove the kwarg", + STREAMING_WARNING, category=DeprecationWarning, stacklevel=1, ) @@ -559,7 +595,7 @@ def get_or_create(cls, *props, **kwargs): if "streaming" in kwargs: warnings.warn( - "streaming is not supported by bolt, please remove the kwarg", + STREAMING_WARNING, category=DeprecationWarning, stacklevel=1, ) diff --git a/neomodel/exceptions.py b/neomodel/exceptions.py index edcfe901..36b3ba5b 100644 --- a/neomodel/exceptions.py +++ b/neomodel/exceptions.py @@ -233,6 +233,11 @@ def __init__(self, msg): self.message = msg +class FeatureNotSupported(NeomodelException): + def __init__(self, msg): + self.message = msg + + __all__ = ( AttemptedCardinalityViolation.__name__, CardinalityViolation.__name__, @@ -251,4 +256,5 @@ def __init__(self, msg): NodeClassAlreadyDefined.__name__, RelationshipClassNotDefined.__name__, RelationshipClassRedefined.__name__, + FeatureNotSupported.__name__, ) diff --git a/neomodel/match.py b/neomodel/match.py index b458039a..b18842df 100644 --- a/neomodel/match.py +++ b/neomodel/match.py @@ -2,6 +2,7 @@ import re from collections import defaultdict from dataclasses import dataclass +from typing import Optional from .core import StructuredNode, db from .exceptions import MultipleNodesReturned @@ -11,13 +12,6 @@ OUTGOING, INCOMING, EITHER = 1, -1, 0 -# basestring python 3.x fallback -try: - basestring -except NameError: - basestring = str - - def _rel_helper( lhs, rhs, @@ -25,7 +19,7 @@ def _rel_helper( relation_type=None, direction=None, relation_properties=None, - **kwargs, + 
**kwargs, # NOSONAR ): """ Generate a relationship matching string, with specified parameters. @@ -55,7 +49,7 @@ def _rel_helper( rel_props = f" {{{rel_props_str}}}" rel_def = "" - # direct, relation_type=None is unspecified, relation_type + # relation_type is unspecified if relation_type is None: rel_def = "" # all("*" wildcard) relation_type @@ -89,7 +83,7 @@ def _rel_merge_helper( relation_type=None, direction=None, relation_properties=None, - **kwargs, + **kwargs, # NOSONAR ): """ Generate a relationship merging string, with specified parameters. @@ -141,7 +135,7 @@ def _rel_merge_helper( rel_none_props = ( f" ON CREATE SET {rel_prop_val_str} ON MATCH SET {rel_prop_val_str}" ) - # direct, relation_type=None is unspecified, relation_type + # relation_type is unspecified if relation_type is None: stmt = stmt.format("") # all("*" wildcard) relation_type @@ -213,7 +207,7 @@ def install_traversals(cls, node_set): raise ValueError(f"Cannot install traversal '{key}' exists on NodeSet") rel = getattr(cls, key) - rel._lookup_node_class() + rel.lookup_node_class() traversal = Traversal(source=node_set, name=key, definition=rel.definition) setattr(node_set, key, traversal) @@ -243,30 +237,12 @@ def process_filter_args(cls, kwargs): prop = property_obj.aliased_to() deflated_value = getattr(cls, prop).deflate(value) else: - # handle special operators - if operator == _SPECIAL_OPERATOR_IN: - if not isinstance(value, tuple) and not isinstance(value, list): - raise ValueError( - f"Value must be a tuple or list for IN operation {key}={value}" - ) - deflated_value = [property_obj.deflate(v) for v in value] - elif operator == _SPECIAL_OPERATOR_ISNULL: - if not isinstance(value, bool): - raise ValueError( - f"Value must be a bool for isnull operation on {key}" - ) - operator = "IS NULL" if value else "IS NOT NULL" - deflated_value = None - elif operator in _REGEX_OPERATOR_TABLE.values(): - deflated_value = property_obj.deflate(value) - if not isinstance(deflated_value, basestring): - raise ValueError(f"Must be a string value for {key}") - if operator in _STRING_REGEX_OPERATOR_TABLE.values(): - deflated_value = re.escape(deflated_value) - deflated_value = operator.format(deflated_value) - operator = _SPECIAL_OPERATOR_REGEX - else: - deflated_value = property_obj.deflate(value) + operator, deflated_value = transform_operator_to_filter( + operator=operator, + filter_key=key, + filter_value=value, + property_obj=property_obj, + ) # map property to correct property name in the database db_property = cls.defined_properties(rels=False)[prop].db_property or prop @@ -276,6 +252,35 @@ def process_filter_args(cls, kwargs): return output +def transform_operator_to_filter(operator, filter_key, filter_value, property_obj): + # handle special operators + if operator == _SPECIAL_OPERATOR_IN: + if not isinstance(filter_value, tuple) and not isinstance(filter_value, list): + raise ValueError( + f"Value must be a tuple or list for IN operation {filter_key}={filter_value}" + ) + deflated_value = [property_obj.deflate(v) for v in filter_value] + elif operator == _SPECIAL_OPERATOR_ISNULL: + if not isinstance(filter_value, bool): + raise ValueError( + f"Value must be a bool for isnull operation on {filter_key}" + ) + operator = "IS NULL" if filter_value else "IS NOT NULL" + deflated_value = None + elif operator in _REGEX_OPERATOR_TABLE.values(): + deflated_value = property_obj.deflate(filter_value) + if not isinstance(deflated_value, str): + raise ValueError(f"Must be a string value for {filter_key}") + if operator in 
_STRING_REGEX_OPERATOR_TABLE.values(): + deflated_value = re.escape(deflated_value) + deflated_value = operator.format(deflated_value) + operator = _SPECIAL_OPERATOR_REGEX + else: + deflated_value = property_obj.deflate(filter_value) + + return operator, deflated_value + + def process_has_args(cls, kwargs): """ loop through has parameters check they correspond to class rels defined @@ -290,7 +295,7 @@ def process_has_args(cls, kwargs): rhs_ident = key - rel_definitions[key]._lookup_node_class() + rel_definitions[key].lookup_node_class() if value is True: match[rhs_ident] = rel_definitions[key].definition @@ -304,10 +309,50 @@ def process_has_args(cls, kwargs): return match, dont_match +class QueryAST: + match: Optional[list] + optional_match: Optional[list] + where: Optional[list] + with_clause: Optional[str] + return_clause: Optional[str] + order_by: Optional[str] + skip: Optional[int] + limit: Optional[int] + result_class: Optional[type] + lookup: Optional[str] + additional_return: Optional[list] + + def __init__( + self, + match: Optional[list] = None, + optional_match: Optional[list] = None, + where: Optional[list] = None, + with_clause: Optional[str] = None, + return_clause: Optional[str] = None, + order_by: Optional[str] = None, + skip: Optional[int] = None, + limit: Optional[int] = None, + result_class: Optional[type] = None, + lookup: Optional[str] = None, + additional_return: Optional[list] = None, + ): + self.match = match if match else [] + self.optional_match = optional_match if optional_match else [] + self.where = where if where else [] + self.with_clause = with_clause + self.return_clause = return_clause + self.order_by = order_by + self.skip = skip + self.limit = limit + self.result_class = result_class + self.lookup = lookup + self.additional_return = additional_return if additional_return else [] + + class QueryBuilder: def __init__(self, node_set): self.node_set = node_set - self._ast = {"match": [], "where": [], "optional match": []} + self._ast = QueryAST() self._query_params = {} self._place_holder_registry = {} self._ident_count = 0 @@ -321,9 +366,9 @@ def build_ast(self): self.build_source(self.node_set) if hasattr(self.node_set, "skip"): - self._ast["skip"] = self.node_set.skip + self._ast.skip = self.node_set.skip if hasattr(self.node_set, "limit"): - self._ast["limit"] = self.node_set.limit + self._ast.limit = self.node_set.limit return self @@ -340,7 +385,7 @@ def build_source(self, source): self.build_additional_match(ident, source) - if hasattr(source, "_order_by"): + if hasattr(source, "order_by_elements"): self.build_order_by(ident, source) if source.filters or source.q_filters: @@ -361,11 +406,11 @@ def create_ident(self): return "r" + str(self._ident_count) def build_order_by(self, ident, source): - if "?" in source._order_by: - self._ast["with"] = f"{ident}, rand() as r" - self._ast["order_by"] = "r" + if "?" 
in source.order_by_elements: + self._ast.with_clause = f"{ident}, rand() as r" + self._ast.order_by = "r" else: - self._ast["order_by"] = [f"{ident}.{p}" for p in source._order_by] + self._ast.order_by = [f"{ident}.{p}" for p in source.order_by_elements] def build_traversal(self, traversal): """ @@ -377,8 +422,8 @@ def build_traversal(self, traversal): # build source lhs_ident = self.build_source(traversal.source) rhs_ident = traversal.name + rhs_label - self._ast["return"] = traversal.name - self._ast["result_class"] = traversal.target_class + self._ast.return_clause = traversal.name + self._ast.result_class = traversal.target_class rel_ident = self.create_ident() stmt = _rel_helper( @@ -387,7 +432,7 @@ def build_traversal(self, traversal): ident=rel_ident, **traversal.definition, ) - self._ast["match"].append(stmt) + self._ast.match.append(stmt) if traversal.filters: self.build_where_stmt(rel_ident, traversal.filters) @@ -395,11 +440,8 @@ def build_traversal(self, traversal): return traversal.name def _additional_return(self, name): - key = "additional_return" - if key not in self._ast: - self._ast[key] = [] - if name not in self._ast[key] and name != self._ast.get("return"): - self._ast[key].append(name) + if name not in self._ast.additional_return and name != self._ast.return_clause: + self._ast.additional_return.append(name) def build_traversal_from_path(self, relation: dict, source_class) -> str: path: str = relation["path"] @@ -409,7 +451,7 @@ def build_traversal_from_path(self, relation: dict, source_class) -> str: relationship = getattr(source_class_iterator, part) # build source if "node_class" not in relationship.definition: - relationship._lookup_node_class() + relationship.lookup_node_class() rhs_label = relationship.definition["node_class"].__label__ rel_reference = f'{relationship.definition["node_class"]}_{part}' self._node_counters[rel_reference] += 1 @@ -426,7 +468,7 @@ def build_traversal_from_path(self, relation: dict, source_class) -> str: # This is the first one, we make sure that 'return' # contains the primary node so _contains() works # as usual - self._ast["return"] = lhs_name + self._ast.return_clause = lhs_name else: self._additional_return(lhs_name) else: @@ -444,9 +486,9 @@ def build_traversal_from_path(self, relation: dict, source_class) -> str: source_class_iterator = relationship.definition["node_class"] if relation.get("optional"): - self._ast["optional match"].append(stmt) + self._ast.optional_match.append(stmt) else: - self._ast["match"].append(stmt) + self._ast.match.append(stmt) return rhs_name def build_node(self, node): @@ -455,12 +497,12 @@ def build_node(self, node): # Hack to emulate START to lookup a node by id _node_lookup = f"MATCH ({ident}) WHERE {db.get_id_method()}({ident})=${place_holder} WITH {ident}" - self._ast["lookup"] = _node_lookup + self._ast.lookup = _node_lookup self._query_params[place_holder] = node.element_id - self._ast["return"] = ident - self._ast["result_class"] = node.__class__ + self._ast.return_clause = ident + self._ast.result_class = node.__class__ return ident def build_label(self, ident, cls): @@ -469,13 +511,12 @@ def build_label(self, ident, cls): """ ident_w_label = ident + ":" + cls.__label__ - if not self._ast.get("return") and ( - "additional_return" not in self._ast - or ident not in self._ast["additional_return"] + if not self._ast.return_clause and ( + not self._ast.additional_return or ident not in self._ast.additional_return ): - self._ast["match"].append(f"({ident_w_label})") - self._ast["return"] = ident 
- self._ast["result_class"] = cls + self._ast.match.append(f"({ident_w_label})") + self._ast.return_clause = ident + self._ast.result_class = cls return ident def build_additional_match(self, ident, node_set): @@ -488,7 +529,7 @@ def build_additional_match(self, ident, node_set): if isinstance(value, dict): label = ":" + value["node_class"].__label__ stmt = _rel_helper(lhs=source_ident, rhs=label, ident="", **value) - self._ast["where"].append(stmt) + self._ast.where.append(stmt) else: raise ValueError("Expecting dict got: " + repr(value)) @@ -496,7 +537,7 @@ def build_additional_match(self, ident, node_set): if isinstance(val, dict): label = ":" + val["node_class"].__label__ stmt = _rel_helper(lhs=source_ident, rhs=label, ident="", **val) - self._ast["where"].append("NOT " + stmt) + self._ast.where.append("NOT " + stmt) else: raise ValueError("Expecting dict got: " + repr(val)) @@ -540,7 +581,7 @@ def build_where_stmt(self, ident, filters, q_filters=None, source_class=None): if q_filters is not None: stmts = self._parse_q_filters(ident, q_filters, source_class) if stmts: - self._ast["where"].append(stmts) + self._ast.where.append(stmts) else: stmts = [] for row in filters: @@ -564,85 +605,87 @@ def build_where_stmt(self, ident, filters, q_filters=None, source_class=None): self._query_params[place_holder] = val stmts.append(statement) - self._ast["where"].append(" AND ".join(stmts)) + self._ast.where.append(" AND ".join(stmts)) def build_query(self): query = "" - if "lookup" in self._ast: - query += self._ast["lookup"] + if self._ast.lookup: + query += self._ast.lookup # Instead of using only one MATCH statement for every relation # to follow, we use one MATCH per relation (to avoid cartesian # product issues...). # There might be optimizations to be done, using projections, # or pusing patterns instead of a chain of OPTIONAL MATCH. 
- if len(self._ast["match"]) > 0: + if self._ast.match: query += " MATCH " - query += " MATCH ".join(i for i in self._ast["match"]) + query += " MATCH ".join(i for i in self._ast.match) - if len(self._ast["optional match"]): + if self._ast.optional_match: query += " OPTIONAL MATCH " - query += " OPTIONAL MATCH ".join(i for i in self._ast["optional match"]) + query += " OPTIONAL MATCH ".join(i for i in self._ast.optional_match) - if "where" in self._ast and self._ast["where"]: + if self._ast.where: query += " WHERE " - query += " AND ".join(self._ast["where"]) + query += " AND ".join(self._ast.where) - if "with" in self._ast and self._ast["with"]: + if self._ast.with_clause: query += " WITH " - query += self._ast["with"] + query += self._ast.with_clause query += " RETURN " - if "return" in self._ast: - query += self._ast["return"] - if "additional_return" in self._ast: - if "return" in self._ast: + if self._ast.return_clause: + query += self._ast.return_clause + if self._ast.additional_return: + if self._ast.return_clause: query += ", " - query += ", ".join(self._ast["additional_return"]) + query += ", ".join(self._ast.additional_return) - if "order_by" in self._ast and self._ast["order_by"]: + if self._ast.order_by: query += " ORDER BY " - query += ", ".join(self._ast["order_by"]) + query += ", ".join(self._ast.order_by) - if "skip" in self._ast: - query += f" SKIP {self._ast['skip']:d}" + if self._ast.skip: + query += f" SKIP {self._ast.skip}" - if "limit" in self._ast: - query += f" LIMIT {self._ast['limit']:d}" + if self._ast.limit: + query += f" LIMIT {self._ast.limit}" return query def _count(self): - self._ast["return"] = f"count({self._ast['return']})" + self._ast.return_clause = f"count({self._ast.return_clause})" # drop order_by, results in an invalid query - self._ast.pop("order_by", None) + self._ast.order_by = None # drop additional_return to avoid unexpected result - self._ast.pop("additional_return", None) + self._ast.additional_return = None query = self.build_query() results, _ = db.cypher_query(query, self._query_params) return int(results[0][0]) def _contains(self, node_element_id): # inject id = into ast - if "return" not in self._ast: - print(self._ast["additional_return"]) - self._ast["return"] = self._ast["additional_return"][0] - ident = self._ast["return"] + if not self._ast.return_clause: + print(self._ast.additional_return) + self._ast.return_clause = self._ast.additional_return[0] + ident = self._ast.return_clause place_holder = self._register_place_holder(ident + "_contains") - self._ast["where"].append(f"{db.get_id_method()}({ident}) = ${place_holder}") + self._ast.where.append(f"{db.get_id_method()}({ident}) = ${place_holder}") self._query_params[place_holder] = node_element_id return self._count() >= 1 def _execute(self, lazy=False): if lazy: # inject id() into return or return_set - if "return" in self._ast: - self._ast["return"] = f"{db.get_id_method()}({self._ast['return']})" + if self._ast.return_clause: + self._ast.return_clause = ( + f"{db.get_id_method()}({self._ast.return_clause})" + ) else: - self._ast["additional_return"] = [ + self._ast.additional_return = [ f"{db.get_id_method()}({item})" - for item in self._ast["additional_return"] + for item in self._ast.additional_return ] query = self.build_query() results, _ = db.cypher_query(query, self._query_params, resolve_objects=True) @@ -869,12 +912,12 @@ def order_by(self, *props): remove ordering. 
""" should_remove = len(props) == 1 and props[0] is None - if not hasattr(self, "_order_by") or should_remove: - self._order_by = [] + if not hasattr(self, "order_by_elements") or should_remove: + self.order_by_elements = [] if should_remove: return self if "?" in props: - self._order_by.append("?") + self.order_by_elements.append("?") else: for prop in props: prop = prop.strip() @@ -893,7 +936,7 @@ def order_by(self, *props): if isinstance(property_obj, AliasProperty): prop = property_obj.aliased_to() - self._order_by.append(prop + (" DESC" if desc else "")) + self.order_by_elements.append(prop + (" DESC" if desc else "")) return self diff --git a/neomodel/path.py b/neomodel/path.py new file mode 100644 index 00000000..5f063d11 --- /dev/null +++ b/neomodel/path.py @@ -0,0 +1,53 @@ +from neo4j.graph import Path +from .core import db +from .relationship import StructuredRel +from .exceptions import RelationshipClassNotDefined + + +class NeomodelPath(Path): + """ + Represents paths within neomodel. + + This object is instantiated when you include whole paths in your ``cypher_query()`` + result sets and turn ``resolve_objects`` to True. + + That is, any query of the form: + :: + + MATCH p=(:SOME_NODE_LABELS)-[:SOME_REL_LABELS]-(:SOME_OTHER_NODE_LABELS) return p + + ``NeomodelPath`` are simple objects that reference their nodes and relationships, each of which is already + resolved to their neomodel objects if such mapping is possible. + + + :param nodes: Neomodel nodes appearing in the path in order of appearance. + :param relationships: Neomodel relationships appearing in the path in order of appearance. + :type nodes: List[StructuredNode] + :type relationships: List[StructuredRel] + """ + def __init__(self, a_neopath): + self._nodes=[] + self._relationships = [] + + for a_node in a_neopath.nodes: + self._nodes.append(db._object_resolution(a_node)) + + for a_relationship in a_neopath.relationships: + # This check is required here because if the relationship does not bear data + # then it does not have an entry in the registry. In that case, we instantiate + # an "unspecified" StructuredRel. + rel_type = frozenset([a_relationship.type]) + if rel_type in db._NODE_CLASS_REGISTRY: + new_rel = db._object_resolution(a_relationship) + else: + new_rel = StructuredRel.inflate(a_relationship) + self._relationships.append(new_rel) + @property + def nodes(self): + return self._nodes + + @property + def relationships(self): + return self._relationships + + diff --git a/neomodel/properties.py b/neomodel/properties.py index a760a2c4..0d325480 100644 --- a/neomodel/properties.py +++ b/neomodel/properties.py @@ -265,9 +265,7 @@ def __init__(self, expression=None, **kwargs): def normalize(self, value): normal = Unicode(value) if not re.match(self.expression, normal): - raise ValueError( - f"{value!r} does not match {self.expression!r}" - ) + raise ValueError(f"{value!r} does not match {self.expression!r}") return normal @@ -324,7 +322,9 @@ def normalize(self, value): if self.choices is not None and value not in self.choices: raise ValueError(f"Invalid choice: {value}") if self.max_length is not None and len(value) > self.max_length: - raise ValueError(f"Property max length exceeded. Expected {self.max_length}, got {len(value)} == len('{value}')") + raise ValueError( + f"Property max length exceeded. 
Expected {self.max_length}, got {len(value)} == len('{value}')" + ) return Unicode(value) def default_value(self): @@ -453,9 +453,8 @@ class DateProperty(Property): def inflate(self, value): if isinstance(value, neo4j.time.DateTime): value = date(value.year, value.month, value.day) - elif isinstance(value, str): - if "T" in value: - value = value[: value.find("T")] + elif isinstance(value, str) and "T" in value: + value = value[: value.find("T")] return datetime.strptime(Unicode(value), "%Y-%m-%d").date() @validator @@ -523,9 +522,13 @@ def inflate(self, value): try: epoch = float(value) except ValueError as exc: - raise ValueError(f"Float or integer expected, got {type(value)} cannot inflate to datetime.") from exc + raise ValueError( + f"Float or integer expected, got {type(value)} cannot inflate to datetime." + ) from exc except TypeError as exc: - raise TypeError(f"Float or integer expected. Can't inflate {type(value)} to datetime.") from exc + raise TypeError( + f"Float or integer expected. Can't inflate {type(value)} to datetime." + ) from exc return datetime.utcfromtimestamp(epoch).replace(tzinfo=pytz.utc) @validator diff --git a/neomodel/relationship.py b/neomodel/relationship.py index 9320bfe8..31eba41c 100644 --- a/neomodel/relationship.py +++ b/neomodel/relationship.py @@ -153,3 +153,4 @@ def inflate(cls, rel): srel._end_node_element_id_property = rel.end_node.element_id srel.element_id_property = rel.element_id return srel + diff --git a/neomodel/relationship_manager.py b/neomodel/relationship_manager.py index 84608168..1e9cf79e 100644 --- a/neomodel/relationship_manager.py +++ b/neomodel/relationship_manager.py @@ -246,8 +246,7 @@ def reconnect(self, old_node, new_node): q += " MERGE" + new_rel # copy over properties if we have - for p in existing_properties: - q += "".join([f" SET r2.{prop} = r.{prop}" for prop in existing_properties]) + q += "".join([f" SET r2.{prop} = r.{prop}" for prop in existing_properties]) q += " WITH r DELETE r" self.source.cypher(q, {"old": old_node.element_id, "new": new_node.element_id}) @@ -449,7 +448,7 @@ def _validate_class(self, cls_name, model): if model and not issubclass(model, (StructuredRel,)): raise ValueError("model must be a StructuredRel") - def _lookup_node_class(self): + def lookup_node_class(self): if not isinstance(self._raw_class, basestring): self.definition["node_class"] = self._raw_class else: @@ -487,7 +486,7 @@ def _lookup_node_class(self): self.definition["node_class"] = getattr(sys.modules[module], name) def build_manager(self, source, name): - self._lookup_node_class() + self.lookup_node_class() return self.manager(source, name, self.definition) diff --git a/neomodel/scripts/neomodel_install_labels.py b/neomodel/scripts/neomodel_install_labels.py index fdbf5184..444838b2 100755 --- a/neomodel/scripts/neomodel_install_labels.py +++ b/neomodel/scripts/neomodel_install_labels.py @@ -1,8 +1,34 @@ #!/usr/bin/env python +""" +.. _neomodel_install_labels: + +``neomodel_install_labels`` +--------------------------- + +:: + + usage: neomodel_install_labels [-h] [--db bolt://neo4j:neo4j@localhost:7687] [ ...] + + Setup indexes and constraints on labels in Neo4j for your neomodel schema. + + If a connection URL is not specified, the tool will look up the environment + variable NEO4J_BOLT_URL. If that environment variable is not set, the tool + will attempt to connect to the default URL bolt://neo4j:neo4j@localhost:7687 + + positional arguments: + + python modules or files with neomodel schema declarations. 
+ + options: + -h, --help show this help message and exit + --db bolt://neo4j:neo4j@localhost:7687 + Neo4j Server URL +""" from __future__ import print_function import sys -from argparse import ArgumentParser +from argparse import ArgumentParser, RawDescriptionHelpFormatter +import textwrap from importlib import import_module from os import environ, path @@ -10,6 +36,15 @@ def load_python_module_or_file(name): + """ + Imports an existing python module or file into the current workspace. + + In both cases, *the resource must exist*. + + :param name: A string that refers either to a Python module or a source coe + file to load in the current workspace. + :type name: str + """ # Is a file if name.lower().endswith(".py"): basedir = path.dirname(path.abspath(name)) @@ -29,24 +64,27 @@ def load_python_module_or_file(name): pkg = None import_module(module_name, package=pkg) - print("Loaded {}.".format(name)) + print(f"Loaded {name}") def main(): parser = ArgumentParser( - description=""" - Setup indexes and constraints on labels in Neo4j for your neomodel schema. + formatter_class=RawDescriptionHelpFormatter, + description=textwrap.dedent(""" + Setup indexes and constraints on labels in Neo4j for your neomodel schema. - Database credentials can be set by the environment variable NEO4J_BOLT_URL. - """ - ) + If a connection URL is not specified, the tool will look up the environment + variable NEO4J_BOLT_URL. If that environment variable is not set, the tool + will attempt to connect to the default URL bolt://neo4j:neo4j@localhost:7687 + """ + )) parser.add_argument( "apps", metavar="", type=str, nargs="+", - help="python modules or files to load schema from.", + help="python modules or files with neomodel schema declarations.", ) parser.add_argument( @@ -55,7 +93,7 @@ def main(): dest="neo4j_bolt_url", type=str, default="", - help="address of your neo4j database", + help="Neo4j Server URL", ) args = parser.parse_args() @@ -68,7 +106,7 @@ def main(): load_python_module_or_file(app) # Connect after to override any code in the module that may set the connection - print("Connecting to {}\n".format(bolt_url)) + print(f"Connecting to {bolt_url}") db.set_connection(bolt_url) install_all_labels() diff --git a/neomodel/scripts/neomodel_remove_labels.py b/neomodel/scripts/neomodel_remove_labels.py index 81fa5b3f..58a57cdd 100755 --- a/neomodel/scripts/neomodel_remove_labels.py +++ b/neomodel/scripts/neomodel_remove_labels.py @@ -1,7 +1,30 @@ #!/usr/bin/env python +""" +.. _neomodel_remove_labels: + +``neomodel_remove_labels`` +-------------------------- + +:: + + usage: neomodel_remove_labels [-h] [--db bolt://neo4j:neo4j@localhost:7687] + + Drop all indexes and constraints on labels from schema in Neo4j database. + + If a connection URL is not specified, the tool will look up the environment + variable NEO4J_BOLT_URL. If that environment variable is not set, the tool + will attempt to connect to the default URL bolt://neo4j:neo4j@localhost:7687 + + options: + -h, --help show this help message and exit + --db bolt://neo4j:neo4j@localhost:7687 + Neo4j Server URL + +""" from __future__ import print_function -from argparse import ArgumentParser +from argparse import ArgumentParser, RawDescriptionHelpFormatter +import textwrap from os import environ from .. import db, remove_all_labels @@ -9,12 +32,15 @@ def main(): parser = ArgumentParser( - description=""" - Drop all indexes and constraints on labels from schema in Neo4j database. 
+ formatter_class=RawDescriptionHelpFormatter, + description=textwrap.dedent(""" + Drop all indexes and constraints on labels from schema in Neo4j database. - Database credentials can be set by the environment variable NEO4J_BOLT_URL. - """ - ) + If a connection URL is not specified, the tool will look up the environment + variable NEO4J_BOLT_URL. If that environment variable is not set, the tool + will attempt to connect to the default URL bolt://neo4j:neo4j@localhost:7687 + """ + )) parser.add_argument( "--db", @@ -22,7 +48,7 @@ def main(): dest="neo4j_bolt_url", type=str, default="", - help="address of your neo4j database", + help="Neo4j Server URL", ) args = parser.parse_args() @@ -32,7 +58,7 @@ def main(): bolt_url = environ.get("NEO4J_BOLT_URL", "bolt://neo4j:neo4j@localhost:7687") # Connect after to override any code in the module that may set the connection - print("Connecting to {}\n".format(bolt_url)) + print(f"Connecting to {bolt_url}") db.set_connection(bolt_url) remove_all_labels() diff --git a/neomodel/util.py b/neomodel/util.py index ca008a16..c03a1ce2 100644 --- a/neomodel/util.py +++ b/neomodel/util.py @@ -9,12 +9,13 @@ from neo4j import DEFAULT_DATABASE, GraphDatabase, basic_auth from neo4j.api import Bookmarks -from neo4j.exceptions import ClientError, SessionExpired -from neo4j.graph import Node, Relationship +from neo4j.exceptions import ClientError, ServiceUnavailable, SessionExpired +from neo4j.graph import Node, Path, Relationship from neomodel import config, core from neomodel.exceptions import ( ConstraintValidationFailed, + FeatureNotSupported, NodeClassNotDefined, RelationshipClassNotDefined, UniqueProperty, @@ -34,6 +35,7 @@ def wrapper(self, *args, **kwargs): if not _db.url: _db.set_connection(config.DATABASE_URL) + return func(self, *args, **kwargs) return wrapper @@ -72,7 +74,9 @@ def __init__(self): self._pid = None self._database_name = DEFAULT_DATABASE self.protocol_version = None - self.database_version = None + self._database_version = None + self._database_edition = None + self.impersonated_user = None def set_connection(self, url): """ @@ -128,10 +132,24 @@ def set_connection(self, url): self._active_transaction = None self._database_name = DEFAULT_DATABASE if database_name == "" else database_name - results = self.cypher_query( - "CALL dbms.components() yield versions return versions[0]" - ) - self.database_version = results[0][0][0] + # Getting the information about the database version requires a connection to the database + self._database_version = None + self._database_edition = None + self._update_database_version() + + @property + def database_version(self): + if self._database_version is None: + self._update_database_version() + + return self._database_version + + @property + def database_edition(self): + if self._database_edition is None: + self._update_database_version() + + return self._database_edition @property def transaction(self): @@ -148,6 +166,21 @@ def write_transaction(self): def read_transaction(self): return TransactionProxy(self, access_mode="READ") + def impersonate(self, user: str) -> "ImpersonationHandler": + """All queries executed within this context manager will be executed as impersonated user + + Args: + user (str): User to impersonate + + Returns: + ImpersonationHandler: Context manager to set/unset the user to impersonate + """ + if self.database_edition != "enterprise": + raise FeatureNotSupported( + "Impersonation is only available in Neo4j Enterprise edition" + ) + return ImpersonationHandler(self, 
impersonated_user=user) + @ensure_connection def begin(self, access_mode=None, **parameters): """ @@ -161,6 +194,7 @@ def begin(self, access_mode=None, **parameters): self._session = self.driver.session( default_access_mode=access_mode, database=self._database_name, + impersonated_user=self.impersonated_user, **parameters, ) self._active_transaction = self._session.begin_transaction() @@ -201,7 +235,67 @@ def rollback(self): self._active_transaction = None self._session = None - def _object_resolution(self, result_list): + def _update_database_version(self): + """ + Updates the database server information when it is required + """ + try: + results = self.cypher_query( + "CALL dbms.components() yield versions, edition return versions[0], edition" + ) + self._database_version = results[0][0][0] + self._database_edition = results[0][0][1] + except ServiceUnavailable: + # The database server is not running yet + pass + + def _object_resolution(self, object_to_resolve): + """ + Performs in place automatic object resolution on a result + returned by cypher_query. + + The function operates recursively in order to be able to resolve Nodes + within nested list structures and Path objects. Not meant to be called + directly, used primarily by _result_resolution. + + :param object_to_resolve: A result as returned by cypher_query. + :type Any: + + :return: An instantiated object. + """ + # Below is the original comment that came with the code extracted in + # this method. It is not very clear but I decided to keep it just in + # case + # + # + # For some reason, while the type of `a_result_attribute[1]` + # as reported by the neo4j driver is `Node` for Node-type data + # retrieved from the database. + # When the retrieved data are Relationship-Type, + # the returned type is `abc.[REL_LABEL]` which is however + # a descendant of Relationship. + # Consequently, the type checking was changed for both + # Node, Relationship objects + if isinstance(object_to_resolve, Node): + return self._NODE_CLASS_REGISTRY[ + frozenset(object_to_resolve.labels) + ].inflate(object_to_resolve) + + if isinstance(object_to_resolve, Relationship): + rel_type = frozenset([object_to_resolve.type]) + return self._NODE_CLASS_REGISTRY[rel_type].inflate(object_to_resolve) + + if isinstance(object_to_resolve, Path): + from .path import NeomodelPath + + return NeomodelPath(object_to_resolve) + + if isinstance(object_to_resolve, list): + return self._result_resolution([object_to_resolve]) + + return object_to_resolve + + def _result_resolution(self, result_list): """ Performs in place automatic object resolution on a set of results returned by cypher_query. @@ -224,28 +318,7 @@ def _object_resolution(self, result_list): # Nodes to be resolved to native objects resolved_object = a_result_attribute[1] - # For some reason, while the type of `a_result_attribute[1]` - # as reported by the neo4j driver is `Node` for Node-type data - # retrieved from the database. - # When the retrieved data are Relationship-Type, - # the returned type is `abc.[REL_LABEL]` which is however - # a descendant of Relationship. 
- # Consequently, the type checking was changed for both - # Node, Relationship objects - if isinstance(a_result_attribute[1], Node): - resolved_object = self._NODE_CLASS_REGISTRY[ - frozenset(a_result_attribute[1].labels) - ].inflate(a_result_attribute[1]) - - if isinstance(a_result_attribute[1], Relationship): - resolved_object = self._NODE_CLASS_REGISTRY[ - frozenset([a_result_attribute[1].type]) - ].inflate(a_result_attribute[1]) - - if isinstance(a_result_attribute[1], list): - resolved_object = self._object_resolution( - [a_result_attribute[1]] - ) + resolved_object = self._object_resolution(resolved_object) result_list[a_result_item[0]][ a_result_attribute[0] @@ -303,7 +376,9 @@ def cypher_query( ) else: # Otherwise create a new session in a with to dispose of it after it has been run - with self.driver.session(database=self._database_name) as session: + with self.driver.session( + database=self._database_name, impersonated_user=self.impersonated_user + ) as session: results, meta = self._run_cypher_query( session, query, @@ -333,7 +408,7 @@ def _run_cypher_query( if resolve_objects: # Do any automatic resolution required - results = self._object_resolution(results) + results = self._result_resolution(results) except ClientError as e: if e.code == "Neo.ClientError.Schema.ConstraintValidationFailed": @@ -422,9 +497,11 @@ def __exit__(self, exc_type, exc_value, traceback): if exc_value: self.db.rollback() - if exc_type is ClientError: - if exc_value.code == "Neo.ClientError.Schema.ConstraintValidationFailed": - raise UniqueProperty(exc_value.message) + if ( + exc_type is ClientError + and exc_value.code == "Neo.ClientError.Schema.ConstraintValidationFailed" + ): + raise UniqueProperty(exc_value.message) if not exc_value: self.last_bookmark = self.db.commit() @@ -441,6 +518,30 @@ def with_bookmark(self): return BookmarkingTransactionProxy(self.db, self.access_mode) +class ImpersonationHandler: + def __init__(self, db, impersonated_user: str): + self.db = db + self.impersonated_user = impersonated_user + + def __enter__(self): + self.db.impersonated_user = self.impersonated_user + return self + + def __exit__(self, exception_type, exception_value, exception_traceback): + self.db.impersonated_user = None + + print("\nException type:", exception_type) + print("\nException value:", exception_value) + print("\nTraceback:", exception_traceback) + + def __call__(self, func): + def wrapper(*args, **kwargs): + with self: + return func(*args, **kwargs) + + return wrapper + + class BookmarkingTransactionProxy(TransactionProxy): def __call__(self, func): def wrapper(*args, **kwargs): diff --git a/pyproject.toml b/pyproject.toml index 4a2b666b..8fbca288 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -33,12 +33,12 @@ classifiers = [ "Topic :: Database", ] dependencies = [ - "neo4j==5.10.0", + "neo4j==5.11.0", "pytz>=2021.1", "neobolt==1.7.17", "six==1.16.0", ] -version='5.1.0' +version='5.1.1' [project.urls] documentation = "https://neomodel.readthedocs.io/en/latest/" diff --git a/test/conftest.py b/test/conftest.py index ef48be09..1cf682df 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -74,6 +74,13 @@ def pytest_sessionstart(session): else: clear_neo4j_database(db, clear_constraints=True, clear_indexes=True) + db.cypher_query( + "CREATE OR REPLACE USER troygreene SET PASSWORD 'foobarbaz' CHANGE NOT REQUIRED" + ) + if db.database_edition == "enterprise": + db.cypher_query("GRANT ROLE publisher TO troygreene") + db.cypher_query("GRANT IMPERSONATE (troygreene) ON DBMS TO admin") + def 
version_to_dec(a_version_string): """ diff --git a/test/test_connection.py b/test/test_connection.py index cc53f944..702a4122 100644 --- a/test/test_connection.py +++ b/test/test_connection.py @@ -1,27 +1,40 @@ import os import pytest +from neo4j.debug import watch from neomodel import config, db +INITIAL_URL = db.url + + +@pytest.fixture(autouse=True) +def setup_teardown(): + yield + # Teardown actions after tests have run + # Reconnect to initial URL for potential subsequent tests + db.driver.close() + db.set_connection(INITIAL_URL) + + +@pytest.fixture(autouse=True, scope="session") +def neo4j_logging(): + with watch("neo4j"): + yield + @pytest.mark.parametrize("protocol", ["neo4j+s", "neo4j+ssc", "bolt+s", "bolt+ssc"]) def test_connect_to_aura(protocol): - prev_url = db.url cypher_return = "hello world" default_cypher_query = f"RETURN '{cypher_return}'" db.driver.close() _set_connection(protocol=protocol) result, _ = db.cypher_query(default_cypher_query) - db.driver.close() assert len(result) > 0 assert result[0][0] == cypher_return - # Finally, reconnect to base URL for subsequent tests - db.set_connection(prev_url) - def _set_connection(protocol): AURA_TEST_DB_USER = os.environ["AURA_TEST_DB_USER"] @@ -42,6 +55,3 @@ def test_wrong_url_format(url): match=rf"Expecting url format: bolt://user:password@localhost:7687 got {url}", ): db.set_connection(url) - - # Finally, reconnect to base URL for subsequent tests - db.set_connection(prev_url) diff --git a/test/test_contrib/test_spatial_datatypes.py b/test/test_contrib/test_spatial_datatypes.py index bd5ace45..c38c79e4 100644 --- a/test/test_contrib/test_spatial_datatypes.py +++ b/test/test_contrib/test_spatial_datatypes.py @@ -80,7 +80,7 @@ def basic_type_assertions( ) assert len(tested_object) == len( ground_truth - ), "{} dimensionality mismatch. Expected {}, had {}".format( + ), "Dimensionality mismatch. Expected {}, had {}".format( len(ground_truth.coords), len(tested_object.coords) ) else: @@ -94,7 +94,7 @@ def basic_type_assertions( ) assert len(tested_object.coords[0]) == len( ground_truth.coords[0] - ), "{} dimensionality mismatch. Expected {}, had {}".format( + ), "Dimensionality mismatch. Expected {}, had {}".format( len(ground_truth.coords[0]), len(tested_object.coords[0]) ) @@ -262,13 +262,17 @@ def test_prohibited_constructor_forms(): _ = neomodel.contrib.spatial_properties.NeomodelPoint((0, 0), crs="blue_hotel") # Absurd coord dimensionality - with pytest.raises(ValueError,): + with pytest.raises( + ValueError, + ): _ = neomodel.contrib.spatial_properties.NeomodelPoint( (0, 0, 0, 0, 0, 0, 0), crs="cartesian" ) # Absurd datatype passed to copy constructor - with pytest.raises(TypeError,): + with pytest.raises( + TypeError, + ): _ = neomodel.contrib.spatial_properties.NeomodelPoint( "it don't mean a thing if it ain't got that swing", crs="cartesian", @@ -333,7 +337,7 @@ def test_property_accessors_depending_on_crs_shapely_lt_2(): with pytest.raises(AttributeError, match=r'Invalid coordinate \("z"\)'): new_point.z - + def test_property_accessors_depending_on_crs_shapely_gte_2(): """ Tests that points are accessed via their respective accessors. 
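Ahead of the new driver-option tests below, a short self-contained sketch (not part of the changeset) of the impersonation API they exercise; the connection URL is the default from ``neomodel/config.py`` and ``troygreene`` is the test user created in ``test/conftest.py`` earlier in this diff::

    from neomodel import config, db
    from neomodel.exceptions import FeatureNotSupported

    config.DATABASE_URL = "bolt://neo4j:foobarbaz@localhost:7687"

    try:
        # Every query inside the block is executed as the impersonated user.
        with db.impersonate(user="troygreene"):
            results, _ = db.cypher_query("SHOW CURRENT USER")
            print(results[0][0])       # expected: "troygreene"
    except FeatureNotSupported:
        # db.impersonate() raises this when the server is not an Enterprise edition.
        print("Impersonation requires Neo4j Enterprise")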
diff --git a/test/test_driver_options.py b/test/test_driver_options.py
new file mode 100644
index 00000000..9e2af27a
--- /dev/null
+++ b/test/test_driver_options.py
@@ -0,0 +1,50 @@
+import pytest
+from neo4j.exceptions import ClientError
+from pytest import raises
+
+from neomodel import db
+from neomodel.exceptions import FeatureNotSupported
+
+
+@pytest.mark.skipif(
+    db.database_edition != "enterprise", reason="Skipping test for community edition"
+)
+def test_impersonate():
+    with db.impersonate(user="troygreene"):
+        results, _ = db.cypher_query("RETURN 'Doo Wacko !'")
+        assert results[0][0] == "Doo Wacko !"
+
+
+@pytest.mark.skipif(
+    db.database_edition != "enterprise", reason="Skipping test for community edition"
+)
+def test_impersonate_unauthorized():
+    with db.impersonate(user="unknownuser"):
+        with raises(ClientError):
+            _ = db.cypher_query("RETURN 'Gabagool'")
+
+
+@pytest.mark.skipif(
+    db.database_edition != "enterprise", reason="Skipping test for community edition"
+)
+def test_impersonate_multiple_transactions():
+    with db.impersonate(user="troygreene"):
+        with db.transaction:
+            results, _ = db.cypher_query("RETURN 'Doo Wacko !'")
+            assert results[0][0] == "Doo Wacko !"
+
+        with db.transaction:
+            results, _ = db.cypher_query("SHOW CURRENT USER")
+            assert results[0][0] == "troygreene"
+
+    results, _ = db.cypher_query("SHOW CURRENT USER")
+    assert results[0][0] == "neo4j"
+
+
+@pytest.mark.skipif(
+    db.database_edition == "enterprise", reason="Skipping test for enterprise edition"
+)
+def test_impersonate_community():
+    with raises(FeatureNotSupported):
+        with db.impersonate(user="troygreene"):
+            _ = db.cypher_query("RETURN 'Gabagoogoo'")
diff --git a/test/test_indexing.py b/test/test_indexing.py
index 3fb4930c..4b930715 100644
--- a/test/test_indexing.py
+++ b/test/test_indexing.py
@@ -1,3 +1,4 @@
+import pytest
 from pytest import raises

 from neomodel import (
@@ -28,6 +29,9 @@ def test_unique_error():
         assert False, "UniqueProperty not raised."


+@pytest.mark.skipif(
+    db.database_edition != "enterprise", reason="Skipping test for community edition"
+)
 def test_existence_constraint_error():
     db.cypher_query(
         "CREATE CONSTRAINT test_existence_constraint FOR (n:Human) REQUIRE n.age IS NOT NULL"
diff --git a/test/test_issue283.py b/test/test_issue283.py
index af285aeb..b7a63f8f 100644
--- a/test/test_issue283.py
+++ b/test/test_issue283.py
@@ -81,16 +81,22 @@ class SomePerson(BaseOtherPerson):


 # Test cases
-def test_automatic_object_resolution():
+def test_automatic_result_resolution():
     """
     Node objects at the end of relationships are instantiated to their
     corresponding Python object.
     """

     # Create a few entities
-    A = TechnicalPerson.get_or_create({"name": "Grumpy", "expertise": "Grumpiness"})[0]
-    B = TechnicalPerson.get_or_create({"name": "Happy", "expertise": "Unicorns"})[0]
-    C = TechnicalPerson.get_or_create({"name": "Sleepy", "expertise": "Pillows"})[0]
+    A = TechnicalPerson.get_or_create(
+        {"name": "Grumpy", "expertise": "Grumpiness"}
+    )[0]
+    B = TechnicalPerson.get_or_create(
+        {"name": "Happy", "expertise": "Unicorns"}
+    )[0]
+    C = TechnicalPerson.get_or_create(
+        {"name": "Sleepy", "expertise": "Pillows"}
+    )[0]

     # Add connections
     A.friends_with.connect(B)
@@ -106,7 +112,7 @@ def test_automatic_object_resolution():
     C.delete()


-def test_recursive_automatic_object_resolution():
+def test_recursive_automatic_result_resolution():
     """
     Node objects are instantiated to native Python objects, both at the top
     level of returned results and in the case where they are returned within
     """

     # Create a few entities
-    A = TechnicalPerson.get_or_create({"name": "Grumpier", "expertise": "Grumpiness"})[
-        0
-    ]
-    B = TechnicalPerson.get_or_create({"name": "Happier", "expertise": "Grumpiness"})[0]
-    C = TechnicalPerson.get_or_create({"name": "Sleepier", "expertise": "Pillows"})[0]
-    D = TechnicalPerson.get_or_create({"name": "Sneezier", "expertise": "Pillows"})[0]
+    A = TechnicalPerson.get_or_create(
+        {"name": "Grumpier", "expertise": "Grumpiness"}
+    )[0]
+    B = TechnicalPerson.get_or_create(
+        {"name": "Happier", "expertise": "Grumpiness"}
+    )[0]
+    C = TechnicalPerson.get_or_create(
+        {"name": "Sleepier", "expertise": "Pillows"}
+    )[0]
+    D = TechnicalPerson.get_or_create(
+        {"name": "Sneezier", "expertise": "Pillows"}
+    )[0]

     # Retrieve mixed results, both at the top level and nested
     L, _ = neomodel.db.cypher_query(
@@ -154,9 +166,15 @@ def test_validation_with_inheritance_from_db():

     # Create a few entities
     # Technical Persons
-    A = TechnicalPerson.get_or_create({"name": "Grumpy", "expertise": "Grumpiness"})[0]
-    B = TechnicalPerson.get_or_create({"name": "Happy", "expertise": "Unicorns"})[0]
-    C = TechnicalPerson.get_or_create({"name": "Sleepy", "expertise": "Pillows"})[0]
+    A = TechnicalPerson.get_or_create(
+        {"name": "Grumpy", "expertise": "Grumpiness"}
+    )[0]
+    B = TechnicalPerson.get_or_create(
+        {"name": "Happy", "expertise": "Unicorns"}
+    )[0]
+    C = TechnicalPerson.get_or_create(
+        {"name": "Sleepy", "expertise": "Pillows"}
+    )[0]

     # Pilot Persons
     D = PilotPerson.get_or_create(
@@ -205,9 +223,15 @@ def test_validation_enforcement_to_db():

     # Create a few entities
     # Technical Persons
-    A = TechnicalPerson.get_or_create({"name": "Grumpy", "expertise": "Grumpiness"})[0]
-    B = TechnicalPerson.get_or_create({"name": "Happy", "expertise": "Unicorns"})[0]
-    C = TechnicalPerson.get_or_create({"name": "Sleepy", "expertise": "Pillows"})[0]
+    A = TechnicalPerson.get_or_create(
+        {"name": "Grumpy", "expertise": "Grumpiness"}
+    )[0]
+    B = TechnicalPerson.get_or_create(
+        {"name": "Happy", "expertise": "Unicorns"}
+    )[0]
+    C = TechnicalPerson.get_or_create(
+        {"name": "Sleepy", "expertise": "Pillows"}
+    )[0]

     # Pilot Persons
     D = PilotPerson.get_or_create(
@@ -241,7 +265,7 @@ def test_validation_enforcement_to_db():
     F.delete()

-def test_failed_object_resolution():
+def test_failed_result_resolution():
    """
     A Neo4j driver node FROM the database contains labels that are unaware to
     neomodel's Database class.
     This condition raises ClassDefinitionNotFound
@@ -252,7 +276,9 @@ class RandomPerson(BasePerson):
         randomness = neomodel.FloatProperty(default=random.random)

     # A Technical Person...
-    A = TechnicalPerson.get_or_create({"name": "Grumpy", "expertise": "Grumpiness"})[0]
+    A = TechnicalPerson.get_or_create(
+        {"name": "Grumpy", "expertise": "Grumpiness"}
+    )[0]

     # A Random Person...
     B = RandomPerson.get_or_create({"name": "Mad Hatter"})[0]
@@ -261,10 +287,14 @@ class RandomPerson(BasePerson):

     # Simulate the condition where the definition of class RandomPerson is not
     # known yet.
-    del neomodel.db._NODE_CLASS_REGISTRY[frozenset(["RandomPerson", "BasePerson"])]
+    del neomodel.db._NODE_CLASS_REGISTRY[
+        frozenset(["RandomPerson", "BasePerson"])
+    ]

     # Now try to instantiate a RandomPerson
-    A = TechnicalPerson.get_or_create({"name": "Grumpy", "expertise": "Grumpiness"})[0]
+    A = TechnicalPerson.get_or_create(
+        {"name": "Grumpy", "expertise": "Grumpiness"}
+    )[0]
     with pytest.raises(
         neomodel.exceptions.NodeClassNotDefined,
         match=r"Node with labels .* does not resolve to any of the known objects.*",
@@ -289,9 +319,13 @@ class UltraTechnicalPerson(SuperTechnicalPerson):
         ultraness = neomodel.FloatProperty(default=3.1415928)

     # Create a TechnicalPerson...
-    A = TechnicalPerson.get_or_create({"name": "Grumpy", "expertise": "Grumpiness"})[0]
+    A = TechnicalPerson.get_or_create(
+        {"name": "Grumpy", "expertise": "Grumpiness"}
+    )[0]
     # ...that is connected to an UltraTechnicalPerson
-    F = UltraTechnicalPerson(name="Chewbaka", expertise="Aarrr wgh ggwaaah").save()
+    F = UltraTechnicalPerson(
+        name="Chewbaka", expertise="Aarrr wgh ggwaaah"
+    ).save()
     A.friends_with.connect(F)

     # Forget about the UltraTechnicalPerson
@@ -309,7 +343,9 @@ class UltraTechnicalPerson(SuperTechnicalPerson):
     # Recall a TechnicalPerson and enumerate its friends.
     # One of them is UltraTechnicalPerson which would be returned as a valid
     # node to a friends_with query but is currently unknown to the node class registry.
-    A = TechnicalPerson.get_or_create({"name": "Grumpy", "expertise": "Grumpiness"})[0]
+    A = TechnicalPerson.get_or_create(
+        {"name": "Grumpy", "expertise": "Grumpiness"}
+    )[0]
     with pytest.raises(neomodel.exceptions.NodeClassNotDefined):
         for some_friend in A.friends_with:
             print(some_friend.name)
@@ -334,12 +370,14 @@ class SomePerson(BaseOtherPerson):

     redefine_class_locally()


-def test_relationship_object_resolution():
+def test_relationship_result_resolution():
     """
     A query returning a "Relationship" object can now instantiate it to a
     data model class
     """
     # Test specific data
-    A = PilotPerson(name="Zantford Granville", airplane="Gee Bee Model R").save()
+    A = PilotPerson(
+        name="Zantford Granville", airplane="Gee Bee Model R"
+    ).save()
     B = PilotPerson(name="Thomas Granville", airplane="Gee Bee Model R").save()
     C = PilotPerson(name="Robert Granville", airplane="Gee Bee Model R").save()
     D = PilotPerson(name="Mark Granville", airplane="Gee Bee Model R").save()
diff --git a/test/test_match_api.py b/test/test_match_api.py
index 1dcb569b..af9e2ea5 100644
--- a/test/test_match_api.py
+++ b/test/test_match_api.py
@@ -49,8 +49,8 @@ def test_filter_exclude_via_labels():

     results = qb._execute()

-    assert "(coffee:Coffee)" in qb._ast["match"]
-    assert "result_class" in qb._ast
+    assert "(coffee:Coffee)" in qb._ast.match
+    assert qb._ast.result_class
     assert len(results) == 1
     assert isinstance(results[0], Coffee)
     assert results[0].name == "Java"
@@ -61,8 +61,8 @@
     qb = QueryBuilder(node_set).build_ast()
     results = qb._execute()

-    assert "(coffee:Coffee)" in qb._ast["match"]
-    assert "NOT" in qb._ast["where"][0]
+    assert "(coffee:Coffee)" in qb._ast.match
+    assert "NOT" in qb._ast.where[0]
     assert len(results) == 1
     assert results[0].name == "Kenco"

@@ -75,7 +75,7 @@ def test_simple_has_via_label():
     ns = NodeSet(Coffee).has(suppliers=True)
     qb = QueryBuilder(ns).build_ast()
     results = qb._execute()
-    assert "COFFEE SUPPLIERS" in qb._ast["where"][0]
+    assert "COFFEE SUPPLIERS" in qb._ast.where[0]
     assert len(results) == 1
     assert results[0].name == "Nescafe"

@@ -84,7 +84,7 @@ def test_simple_has_via_label():
     qb = QueryBuilder(ns).build_ast()
     results = qb._execute()
     assert len(results) > 0
-    assert "NOT" in qb._ast["where"][0]
+    assert "NOT" in qb._ast.where[0]


 def test_get():
@@ -109,9 +109,9 @@ def test_simple_traverse_with_filter():

     results = qb.build_ast()._execute()

-    assert "lookup" in qb._ast
-    assert "match" in qb._ast
-    assert qb._ast["return"] == "suppliers"
+    assert qb._ast.lookup
+    assert qb._ast.match
+    assert qb._ast.return_clause == "suppliers"
     assert len(results) == 1
     assert results[0].name == "Sainsburys"
@@ -209,14 +209,14 @@ def test_order_by():
     ns = Coffee.nodes.order_by("-price")
     qb = QueryBuilder(ns).build_ast()
-    assert qb._ast["order_by"]
+    assert qb._ast.order_by

     ns = ns.order_by(None)
     qb = QueryBuilder(ns).build_ast()
-    assert not qb._ast["order_by"]
+    assert not qb._ast.order_by

     ns = ns.order_by("?")
     qb = QueryBuilder(ns).build_ast()
-    assert qb._ast["with"] == "coffee, rand() as r"
-    assert qb._ast["order_by"] == "r"
+    assert qb._ast.with_clause == "coffee, rand() as r"
+    assert qb._ast.order_by == "r"

     # Test order by on a relationship
     l = Supplier(name="lidl2").save()
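The `test_match_api.py` assertions above track an internal change: the query builder's `_ast` is now an object with attributes rather than a plain dict. A rough sketch of the new access pattern, using a stand-in `Coffee` model similar to the one defined in that test module (internal API, shown for orientation only)::

    from neomodel import IntegerProperty, StringProperty, StructuredNode
    from neomodel.match import QueryBuilder

    class Coffee(StructuredNode):  # stand-in for the test model, not from the patch
        name = StringProperty(unique_index=True)
        price = IntegerProperty()

    qb = QueryBuilder(Coffee.nodes.order_by("-price")).build_ast()

    # Former dict lookups such as qb._ast["order_by"] become attribute access.
    print(qb._ast.match)
    print(qb._ast.order_by)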
diff --git a/test/test_paths.py b/test/test_paths.py
new file mode 100644
index 00000000..8c6fef28
--- /dev/null
+++ b/test/test_paths.py
@@ -0,0 +1,79 @@
+from neomodel import (StringProperty, StructuredNode, UniqueIdProperty,
+                      db, RelationshipTo, IntegerProperty, NeomodelPath,
+                      StructuredRel)
+
+class PersonLivesInCity(StructuredRel):
+    """
+    Relationship with data that will be instantiated as "stand-alone"
+    """
+    some_num = IntegerProperty(index=True, default=12)
+
+class CountryOfOrigin(StructuredNode):
+    code = StringProperty(unique_index=True, required=True)
+
+class CityOfResidence(StructuredNode):
+    name = StringProperty(required=True)
+    country = RelationshipTo(CountryOfOrigin, 'FROM_COUNTRY')
+
+class PersonOfInterest(StructuredNode):
+    uid = UniqueIdProperty()
+    name = StringProperty(unique_index=True)
+    age = IntegerProperty(index=True, default=0)
+
+    country = RelationshipTo(CountryOfOrigin, 'IS_FROM')
+    city = RelationshipTo(CityOfResidence, 'LIVES_IN', model=PersonLivesInCity)
+
+
+def test_path_instantiation():
+    """
+    Neo4j driver paths should be instantiated as neomodel paths, with all of
+    their nodes and relationships resolved to their Python objects wherever
+    such a mapping is available.
+    """
+
+    c1=CountryOfOrigin(code="GR").save()
+    c2=CountryOfOrigin(code="FR").save()
+
+    ct1 = CityOfResidence(name="Athens", country = c1).save()
+    ct2 = CityOfResidence(name="Paris", country = c2).save()
+
+
+    p1 = PersonOfInterest(name="Bill", age=22).save()
+    p1.country.connect(c1)
+    p1.city.connect(ct1)
+
+    p2 = PersonOfInterest(name="Jean", age=28).save()
+    p2.country.connect(c2)
+    p2.city.connect(ct2)
+
+    p3 = PersonOfInterest(name="Bo", age=32).save()
+    p3.country.connect(c1)
+    p3.city.connect(ct2)
+
+    p4 = PersonOfInterest(name="Drop", age=16).save()
+    p4.country.connect(c1)
+    p4.city.connect(ct2)
+
+    # Retrieve a single path
+    q = db.cypher_query("MATCH p=(:CityOfResidence)<-[:LIVES_IN]-(:PersonOfInterest)-[:IS_FROM]->(:CountryOfOrigin) RETURN p LIMIT 1", resolve_objects = True)
+
+    path_object = q[0][0][0]
+    path_nodes = path_object.nodes
+    path_rels = path_object.relationships
+
+    assert type(path_object) is NeomodelPath
+    assert type(path_nodes[0]) is CityOfResidence
+    assert type(path_nodes[1]) is PersonOfInterest
+    assert type(path_nodes[2]) is CountryOfOrigin
+
+    assert type(path_rels[0]) is PersonLivesInCity
+    assert type(path_rels[1]) is StructuredRel
+
+    c1.delete()
+    c2.delete()
+    ct1.delete()
+    ct2.delete()
+    p1.delete()
+    p2.delete()
+    p3.delete()
+    p4.delete()
+
diff --git a/test/test_properties.py b/test/test_properties.py
index c72a46db..454ada26 100644
--- a/test/test_properties.py
+++ b/test/test_properties.py
@@ -192,12 +192,6 @@ class DefaultTestValue(StructuredNode):
     a = DefaultTestValue()
     assert a.name_xx == "jim"
     a.save()
-    return
-    b = DefaultTestValue.index.get(name="jim")
-    assert b.name == "jim"
-
-    c = DefaultTestValue(name=None)
-    assert c.name == "jim"


 def test_default_value_callable():
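For reference, the automatic path inflation that `test_paths.py` exercises boils down to the following usage pattern (a condensed sketch reusing the models defined in that new test file)::

    from neomodel import NeomodelPath, db

    # With resolve_objects=True a returned Cypher path comes back as a
    # NeomodelPath whose nodes and relationships are already neomodel objects.
    results, _ = db.cypher_query(
        "MATCH p=(:CityOfResidence)<-[:LIVES_IN]-(:PersonOfInterest) RETURN p LIMIT 1",
        resolve_objects=True,
    )
    path = results[0][0]
    assert isinstance(path, NeomodelPath)
    print(path.nodes)          # e.g. CityOfResidence and PersonOfInterest instances
    print(path.relationships)  # e.g. a PersonLivesInCity instance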