Skip to content

Commit

Permalink
Add option to pass driver instead of database_url - WIP
Browse files Browse the repository at this point in the history
  • Loading branch information
Marius Conjeaud committed Aug 29, 2023
1 parent 1a3a194 commit c322145
Show file tree
Hide file tree
Showing 7 changed files with 115 additions and 79 deletions.
6 changes: 6 additions & 0 deletions neomodel/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,3 +16,9 @@
RESOLVER = None
TRUSTED_CERTIFICATES = neo4j.TrustSystemCAs()
USER_AGENT = f"neomodel/v{__version__}"

DRIVER = neo4j.GraphDatabase().driver(
"bolt://localhost:7687", auth=("neo4j", "foobarbaz")
)
# TODO : Try passing a different database name
# DATABASE_NAME = "testdatabase"
10 changes: 6 additions & 4 deletions neomodel/scripts/neomodel_install_labels.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,8 @@
from __future__ import print_function

import sys
from argparse import ArgumentParser, RawDescriptionHelpFormatter
import textwrap
from argparse import ArgumentParser, RawDescriptionHelpFormatter
from importlib import import_module
from os import environ, path

Expand Down Expand Up @@ -70,14 +70,16 @@ def load_python_module_or_file(name):
def main():
parser = ArgumentParser(
formatter_class=RawDescriptionHelpFormatter,
description=textwrap.dedent("""
description=textwrap.dedent(
"""
Setup indexes and constraints on labels in Neo4j for your neomodel schema.
If a connection URL is not specified, the tool will look up the environment
variable NEO4J_BOLT_URL. If that environment variable is not set, the tool
will attempt to connect to the default URL bolt://neo4j:neo4j@localhost:7687
"""
))
),
)

parser.add_argument(
"apps",
Expand Down Expand Up @@ -107,7 +109,7 @@ def main():

# Connect after to override any code in the module that may set the connection
print(f"Connecting to {bolt_url}")
db.set_connection(bolt_url)
db.set_connection(url=bolt_url)

install_all_labels()

Expand Down
12 changes: 7 additions & 5 deletions neomodel/scripts/neomodel_remove_labels.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,24 +23,26 @@
"""
from __future__ import print_function

from argparse import ArgumentParser, RawDescriptionHelpFormatter
import textwrap
from argparse import ArgumentParser, RawDescriptionHelpFormatter
from os import environ

from .. import db, remove_all_labels


def main():
parser = ArgumentParser(
formatter_class=RawDescriptionHelpFormatter,
description=textwrap.dedent("""
formatter_class=RawDescriptionHelpFormatter,
description=textwrap.dedent(
"""
Drop all indexes and constraints on labels from schema in Neo4j database.
If a connection URL is not specified, the tool will look up the environment
variable NEO4J_BOLT_URL. If that environment variable is not set, the tool
will attempt to connect to the default URL bolt://neo4j:neo4j@localhost:7687
"""
))
),
)

parser.add_argument(
"--db",
Expand All @@ -59,7 +61,7 @@ def main():

# Connect after to override any code in the module that may set the connection
print(f"Connecting to {bolt_url}")
db.set_connection(bolt_url)
db.set_connection(url=bolt_url)

remove_all_labels()

Expand Down
122 changes: 73 additions & 49 deletions neomodel/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from typing import Optional, Sequence
from urllib.parse import quote, unquote, urlparse

from neo4j import DEFAULT_DATABASE, GraphDatabase, basic_auth
from neo4j import DEFAULT_DATABASE, Driver, GraphDatabase, basic_auth
from neo4j.api import Bookmarks
from neo4j.exceptions import ClientError, ServiceUnavailable, SessionExpired
from neo4j.graph import Node, Path, Relationship
Expand All @@ -33,8 +33,11 @@ def wrapper(self, *args, **kwargs):
else:
_db = self

if not _db.url:
_db.set_connection(config.DATABASE_URL)
if not _db.driver:
if config.DRIVER:
_db.set_connection(driver=config.DRIVER)
elif config.DATABASE_URL:
_db.set_connection(url=config.DATABASE_URL)

return func(self, *args, **kwargs)

Expand Down Expand Up @@ -78,65 +81,85 @@ def __init__(self):
self._database_edition = None
self.impersonated_user = None

def set_connection(self, url):
def set_connection(self, url: str = None, driver: Driver = None):
"""
Sets the connection URL to the address a Neo4j server is set up at
"""
p_start = url.replace(":", "", 1).find(":") + 2
p_end = url.rfind("@")
password = url[p_start:p_end]
url = url.replace(password, quote(password))
parsed_url = urlparse(url)

valid_schemas = [
"bolt",
"bolt+s",
"bolt+ssc",
"bolt+routing",
"neo4j",
"neo4j+s",
"neo4j+ssc",
]

if parsed_url.netloc.find("@") > -1 and parsed_url.scheme in valid_schemas:
credentials, hostname = parsed_url.netloc.rsplit("@", 1)
username, password = credentials.split(":")
password = unquote(password)
database_name = parsed_url.path.strip("/")
else:
raise ValueError(
f"Expecting url format: bolt://user:password@localhost:7687 got {url}"
if driver:
self.driver = driver
if hasattr(config, "DATABASE_NAME"):
self._database_name = config.DATABASE_NAME
elif url:
p_start = url.replace(":", "", 1).find(":") + 2
p_end = url.rfind("@")
password = url[p_start:p_end]
url = url.replace(password, quote(password))
parsed_url = urlparse(url)

valid_schemas = [
"bolt",
"bolt+s",
"bolt+ssc",
"bolt+routing",
"neo4j",
"neo4j+s",
"neo4j+ssc",
]

if parsed_url.netloc.find("@") > -1 and parsed_url.scheme in valid_schemas:
credentials, hostname = parsed_url.netloc.rsplit("@", 1)
username, password = credentials.split(":")
password = unquote(password)
database_name = parsed_url.path.strip("/")
else:
raise ValueError(
f"Expecting url format: bolt://user:password@localhost:7687 got {url}"
)

options = {
"auth": basic_auth(username, password),
"connection_acquisition_timeout": config.CONNECTION_ACQUISITION_TIMEOUT,
"connection_timeout": config.CONNECTION_TIMEOUT,
"keep_alive": config.KEEP_ALIVE,
"max_connection_lifetime": config.MAX_CONNECTION_LIFETIME,
"max_connection_pool_size": config.MAX_CONNECTION_POOL_SIZE,
"max_transaction_retry_time": config.MAX_TRANSACTION_RETRY_TIME,
"resolver": config.RESOLVER,
"user_agent": config.USER_AGENT,
}

if "+s" not in parsed_url.scheme:
options["encrypted"] = config.ENCRYPTED
options["trusted_certificates"] = config.TRUSTED_CERTIFICATES

self.driver = GraphDatabase.driver(
parsed_url.scheme + "://" + hostname, **options
)
self.url = url
self._database_name = (
DEFAULT_DATABASE if database_name == "" else database_name
)

options = {
"auth": basic_auth(username, password),
"connection_acquisition_timeout": config.CONNECTION_ACQUISITION_TIMEOUT,
"connection_timeout": config.CONNECTION_TIMEOUT,
"keep_alive": config.KEEP_ALIVE,
"max_connection_lifetime": config.MAX_CONNECTION_LIFETIME,
"max_connection_pool_size": config.MAX_CONNECTION_POOL_SIZE,
"max_transaction_retry_time": config.MAX_TRANSACTION_RETRY_TIME,
"resolver": config.RESOLVER,
"user_agent": config.USER_AGENT,
}

if "+s" not in parsed_url.scheme:
options["encrypted"] = config.ENCRYPTED
options["trusted_certificates"] = config.TRUSTED_CERTIFICATES

self.driver = GraphDatabase.driver(
parsed_url.scheme + "://" + hostname, **options
)
self.url = url
self._pid = os.getpid()
self._active_transaction = None
self._database_name = DEFAULT_DATABASE if database_name == "" else database_name

# Getting the information about the database version requires a connection to the database
self._database_version = None
self._database_edition = None
self._update_database_version()

def close_connection(self):
"""
Closes the currently open driver.
The driver should always be called at the end of the application's lifecyle.
If you pass your own driver to neomodel, you can also close it yourself without this method.
"""
self._database_version = None
self._database_edition = None
self._database_name = None
self.driver.close()
self.driver = None

@property
def database_version(self):
if self._database_version is None:
Expand Down Expand Up @@ -420,6 +443,7 @@ def _run_cypher_query(
raise exc_info[1].with_traceback(exc_info[2])
except SessionExpired:
if retry_on_session_expire:
# TODO : What about if config passes driver instead of url ?
self.set_connection(self.url)
return self.cypher_query(
query=query,
Expand Down
15 changes: 6 additions & 9 deletions test/test_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,16 +5,14 @@

from neomodel import config, db

INITIAL_URL = db.url


@pytest.fixture(autouse=True)
def setup_teardown():
yield
# Teardown actions after tests have run
# Reconnect to initial URL for potential subsequent tests
db.driver.close()
db.set_connection(INITIAL_URL)
db.close_connection()
db.set_connection(url=config.DATABASE_URL)


@pytest.fixture(autouse=True, scope="session")
Expand All @@ -27,7 +25,7 @@ def neo4j_logging():
def test_connect_to_aura(protocol):
cypher_return = "hello world"
default_cypher_query = f"RETURN '{cypher_return}'"
db.driver.close()
db.close_connection()

_set_connection(protocol=protocol)
result, _ = db.cypher_query(default_cypher_query)
Expand All @@ -41,17 +39,16 @@ def _set_connection(protocol):
AURA_TEST_DB_PASSWORD = os.environ["AURA_TEST_DB_PASSWORD"]
AURA_TEST_DB_HOSTNAME = os.environ["AURA_TEST_DB_HOSTNAME"]

config.DATABASE_URL = f"{protocol}://{AURA_TEST_DB_USER}:{AURA_TEST_DB_PASSWORD}@{AURA_TEST_DB_HOSTNAME}"
db.set_connection(config.DATABASE_URL)
database_url = f"{protocol}://{AURA_TEST_DB_USER}:{AURA_TEST_DB_PASSWORD}@{AURA_TEST_DB_HOSTNAME}"
db.set_connection(url=database_url)


@pytest.mark.parametrize(
"url", ["bolt://user:password", "http://user:password@localhost:7687"]
)
def test_wrong_url_format(url):
prev_url = db.url
with pytest.raises(
ValueError,
match=rf"Expecting url format: bolt://user:password@localhost:7687 got {url}",
):
db.set_connection(url)
db.set_connection(url=url)
8 changes: 4 additions & 4 deletions test/test_database_management.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,11 +61,11 @@ def test_change_password():

util.change_neo4j_password(db, "neo4j", new_password)

db.set_connection(new_url)
db.set_connection(url=new_url)

with pytest.raises(AuthError):
db.set_connection(prev_url)
db.set_connection(url=prev_url)

db.set_connection(new_url)
db.set_connection(url=new_url)
util.change_neo4j_password(db, "neo4j", prev_password)
db.set_connection(prev_url)
db.set_connection(url=prev_url)
21 changes: 13 additions & 8 deletions test/test_transactions.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,14 @@
from neo4j.exceptions import ClientError, TransactionError
from pytest import raises

from neomodel import StringProperty, StructuredNode, UniqueProperty, db, install_labels
from neomodel import (
StringProperty,
StructuredNode,
UniqueProperty,
config,
db,
install_labels,
)


class APerson(StructuredNode):
Expand Down Expand Up @@ -80,9 +87,9 @@ def test_read_transaction():
people = APerson.nodes.all()
assert people

with pytest.raises(TransactionError):
with raises(TransactionError):
with db.read_transaction:
with pytest.raises(ClientError) as e:
with raises(ClientError) as e:
APerson(name="Gina").save()
assert e.value.code == "Neo.ClientError.Statement.AccessMode"

Expand All @@ -97,21 +104,19 @@ def test_write_transaction():

def double_transaction():
db.begin()
with pytest.raises(SystemError, match=r"Transaction in progress"):
with raises(SystemError, match=r"Transaction in progress"):
db.begin()

db.rollback()


def test_set_connection_works():
assert APerson(name="New guy 1").save()
from socket import gaierror

old_url = db.url
with raises(ValueError):
db.set_connection("bolt://user:password@6.6.6.6.6.6.6.6:7687")
db.set_connection(url="bolt://user:password@6.6.6.6.6.6.6.6:7687")
APerson(name="New guy 2").save()
db.set_connection(old_url)
db.set_connection(url=config.DATABASE_URL)
# set connection back
assert APerson(name="New guy 3").save()

Expand Down

0 comments on commit c322145

Please sign in to comment.