diff --git a/docs/conf.py b/docs/conf.py index 4137e0f..800339a 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -1,4 +1,4 @@ -from crate.theme.rtd.conf.sqlalchemy_cratedb import * +from crate.theme.rtd.conf.sqlalchemy_cratedb import * # noqa: F403 # Fallback guards, when parent theme does not introduce them. if "html_theme_options" not in globals(): @@ -11,21 +11,27 @@ sitemap_url_scheme = "{link}" # Disable version chooser. -html_context.update({ - "display_version": False, - "current_version": None, - "versions": [], -}) +html_context.update( # noqa: F405 + { + "display_version": False, + "current_version": None, + "versions": [], + } +) -intersphinx_mapping.update({ - 'py': ('https://docs.python.org/3/', None), - 'sa': ('https://docs.sqlalchemy.org/en/20/', None), - 'dask': ('https://docs.dask.org/en/stable/', None), - 'pandas': ('https://pandas.pydata.org/docs/', None), - }) +intersphinx_mapping.update( + { + "py": ("https://docs.python.org/3/", None), + "sa": ("https://docs.sqlalchemy.org/en/20/", None), + "dask": ("https://docs.dask.org/en/stable/", None), + "pandas": ("https://pandas.pydata.org/docs/", None), + } +) linkcheck_anchors = True -linkcheck_ignore = [r"https://github.com/crate/cratedb-examples/blob/main/by-language/python-sqlalchemy/.*"] +linkcheck_ignore = [ + r"https://github.com/crate/cratedb-examples/blob/main/by-language/python-sqlalchemy/.*" +] rst_prolog = """ .. |nbsp| unicode:: 0xA0 diff --git a/pyproject.toml b/pyproject.toml index 473ab8f..1b4ddea 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -5,10 +5,6 @@ requires = [ "versioningit", ] -[tool.versioningit.vcs] -method = "git" -default-tag = "0.0.0" - [project] name = "sqlalchemy-cratedb" description = "SQLAlchemy dialect for CrateDB." @@ -84,34 +80,32 @@ dynamic = [ "version", ] dependencies = [ - 'backports.zoneinfo<1; python_version < "3.9"', - "crate==1.0.0dev0", + "backports.zoneinfo<1; python_version<'3.9'", + "crate==1.0.0.dev0", "geojson<4,>=2.5", - 'importlib-resources; python_version < "3.9"', + "importlib-resources; python_version<'3.9'", "sqlalchemy<2.1,>=1", "verlib2==0.2", ] -[project.optional-dependencies] -all = [ +optional-dependencies.all = [ "sqlalchemy-cratedb[vector]", ] -develop = [ - "black<25", +optional-dependencies.develop = [ "mypy<1.12", "poethepoet<0.28", "pyproject-fmt<2.3", "ruff<0.7", "validate-pyproject<0.20", ] -doc = [ +optional-dependencies.doc = [ "crate-docs-theme>=0.26.5", "sphinx>=3.5,<9", ] -release = [ +optional-dependencies.release = [ "build<2", "twine<6", ] -test = [ +optional-dependencies.test = [ "cratedb-toolkit[testing]", "dask[dataframe]", "pandas<2.3", @@ -120,54 +114,81 @@ test = [ "pytest-cov<6", "pytest-mock<4", ] -vector = [ +optional-dependencies.vector = [ "numpy", ] -[project.urls] -changelog = "https://github.com/crate/sqlalchemy-cratedb/blob/main/CHANGES.md" -documentation = "https://cratedb.com/docs/sqlalchemy-cratedb/" -homepage = "https://cratedb.com/docs/sqlalchemy-cratedb/" -repository = "https://github.com/crate/sqlalchemy-cratedb" -[project.entry-points."sqlalchemy.dialects"] -crate = "sqlalchemy_cratedb:dialect" +urls.changelog = "https://github.com/crate/sqlalchemy-cratedb/blob/main/CHANGES.md" +urls.documentation = "https://cratedb.com/docs/sqlalchemy-cratedb/" +urls.homepage = "https://cratedb.com/docs/sqlalchemy-cratedb/" +urls.repository = "https://github.com/crate/sqlalchemy-cratedb" +entry-points."sqlalchemy.dialects".crate = "sqlalchemy_cratedb:dialect" [tool.black] line-length = 100 -[tool.coverage.paths] -source = [ - "src/", +[tool.ruff] +line-length = 100 + +extend-exclude = [ ] -[tool.coverage.run] -branch = false -omit = [ - "tests/*", +lint.select = [ + # Builtins + "A", + # Bugbear + "B", + # comprehensions + "C4", + # Pycodestyle + "E", + # eradicate + "ERA", + # Pyflakes + "F", + # isort + "I", + # pandas-vet + "PD", + # return + "RET", + # Bandit + "S", + # print + "T20", + "W", + # flake8-2020 + "YTT", ] -[tool.coverage.report] -fail_under = 0 -show_missing = true -exclude_lines = [ - "# pragma: no cover", - "raise NotImplemented" +lint.extend-ignore = [ + # zip() without an explicit strict= parameter + "B905", + # Unnecessary generator (rewrite as a `dict` comprehension) + "C402", + # Unnecessary `map` usage (rewrite using a `set` comprehension) + "C417", + # df is a bad variable name. Be kinder to your future self. + "PD901", + # Unnecessary variable assignment before `return` statement + "RET504", + # Unnecessary `elif` after `return` statement + "RET505", + # Probable insecure usage of temporary file or directory + "S108", + # Possible SQL injection vector through string-based query construction + "S608", ] -[tool.mypy] -mypy_path = "src" -packages = ["sqlalchemy_cratedb"] -exclude = [ +lint.per-file-ignores."examples/*" = [ + "T201", # Allow `print` +] + +lint.per-file-ignores."tests/*" = [ + "S101", # Allow use of `assert`, and `print` + "S608", # Possible SQL injection vector through string-based query construction + "W291", # Trailing whitespace + "W293", # Blank line contains whitespace ] -check_untyped_defs = true -explicit_package_bases = true -ignore_missing_imports = true -implicit_optional = true -install_types = true -namespace_packages = true -non_interactive = true -# Needed until `mypy-0.990` for `ConverterDefinition` in `converter.py`. -# https://github.com/python/mypy/issues/731#issuecomment-1260976955 -# enable_recursive_aliases = true [tool.pytest.ini_options] addopts = """ @@ -179,7 +200,7 @@ log_level = "DEBUG" log_cli_level = "DEBUG" log_format = "%(asctime)-15s [%(name)-36s] %(levelname)-8s: %(message)s" pythonpath = [ - "src" + "src", ] testpaths = [ "examples", @@ -194,60 +215,44 @@ xfail_strict = true markers = [ ] -[tool.ruff] -line-length = 100 - -select = [ - # Bandit - "S", - # Bugbear - "B", - # Builtins - "A", - # comprehensions - "C4", - # eradicate - "ERA", - # flake8-2020 - "YTT", - # isort - "I", - # pandas-vet - "PD", - # print - "T20", - # Pycodestyle - "E", - "W", - # Pyflakes - "F", - # return - "RET", +[tool.coverage.paths] +source = [ + "src/", ] -extend-ignore = [ - # zip() without an explicit strict= parameter - "B905", - # df is a bad variable name. Be kinder to your future self. - "PD901", - # Unnecessary variable assignment before `return` statement - "RET504", - # Unnecessary `elif` after `return` statement - "RET505", - # Probable insecure usage of temporary file or directory - "S108", +[tool.coverage.run] +branch = false +omit = [ + "tests/*", ] -extend-exclude = [ +[tool.coverage.report] +fail_under = 0 +show_missing = true +exclude_lines = [ + "# pragma: no cover", + "raise NotImplemented", ] -[tool.ruff.per-file-ignores] -"*/tests/*" = [ - "S101", # Allow use of `assert`, and `print`. - "S608", # Possible SQL injection vector through string-based query construction. +[tool.mypy] +mypy_path = "src" +packages = [ "sqlalchemy_cratedb" ] +exclude = [ ] -"examples/*" = ["T201"] # Allow `print` +check_untyped_defs = true +explicit_package_bases = true +ignore_missing_imports = true +implicit_optional = true +install_types = true +namespace_packages = true +non_interactive = true +# Needed until `mypy-0.990` for `ConverterDefinition` in `converter.py`. +# https://github.com/python/mypy/issues/731#issuecomment-1260976955 +# enable_recursive_aliases = true +[tool.versioningit.vcs] +method = "git" +default-tag = "0.0.0" # =================== # Tasks configuration @@ -256,23 +261,26 @@ extend-exclude = [ [tool.poe.tasks] check = [ - # "lint", + "lint", "test", ] format = [ - { cmd = "black ." }, - # Configure Ruff not to auto-fix (remove!): - # unused imports (F401), unused variables (F841), `print` statements (T201), and commented-out code (ERA001). - { cmd = "ruff --fix --ignore=ERA --ignore=F401 --ignore=F841 --ignore=T20 --ignore=ERA001 ." }, + # Format project metadata. { cmd = "pyproject-fmt --keep-full-version pyproject.toml" }, + + # Format code. + # Configure Ruff not to auto-fix a few items that are useful in workbench mode. + # e.g.: unused imports (F401), unused variables (F841), `print` statements (T201), commented-out code (ERA001) + { cmd = "ruff format" }, + { cmd = "ruff check --fix --ignore=ERA --ignore=F401 --ignore=F841 --ignore=T20 --ignore=ERA001" }, ] lint = [ - { cmd = "ruff ." }, - { cmd = "black --check ." }, + { cmd = "ruff format --check" }, + { cmd = "ruff check" }, { cmd = "validate-pyproject pyproject.toml" }, - { cmd = "mypy" }, + # { cmd = "mypy" }, ] release = [ diff --git a/src/sqlalchemy_cratedb/__init__.py b/src/sqlalchemy_cratedb/__init__.py index 297e8fd..41cbf38 100644 --- a/src/sqlalchemy_cratedb/__init__.py +++ b/src/sqlalchemy_cratedb/__init__.py @@ -22,7 +22,7 @@ from .compat.api13 import monkeypatch_add_exec_driver_sql from .dialect import dialect from .predicate import match -from .sa_version import SA_1_4, SA_2_0, SA_VERSION +from .sa_version import SA_1_4, SA_VERSION from .support import insert_bulk from .type.array import ObjectArray from .type.geo import Geopoint, Geoshape @@ -34,7 +34,8 @@ import warnings # SQLAlchemy 1.3 is effectively EOL. - SA13_DEPRECATION_WARNING = textwrap.dedent(""" + SA13_DEPRECATION_WARNING = textwrap.dedent( + """ WARNING: SQLAlchemy 1.3 is effectively EOL. SQLAlchemy 1.3 is EOL since 2023-01-27. @@ -43,8 +44,9 @@ - https://docs.sqlalchemy.org/en/14/changelog/migration_14.html - https://docs.sqlalchemy.org/en/20/changelog/migration_20.html - """.lstrip("\n")) - warnings.warn(message=SA13_DEPRECATION_WARNING, category=DeprecationWarning) + """.lstrip("\n") + ) + warnings.warn(message=SA13_DEPRECATION_WARNING, category=DeprecationWarning, stacklevel=2) # SQLAlchemy 1.3 does not have the `exec_driver_sql` method, so add it. monkeypatch_add_exec_driver_sql() @@ -59,4 +61,5 @@ ObjectType, match, knn_match, + insert_bulk, ] diff --git a/src/sqlalchemy_cratedb/compat/api13.py b/src/sqlalchemy_cratedb/compat/api13.py index 17774b2..6b19dc8 100644 --- a/src/sqlalchemy_cratedb/compat/api13.py +++ b/src/sqlalchemy_cratedb/compat/api13.py @@ -40,7 +40,6 @@ from sqlalchemy.sql import select as original_select from sqlalchemy.util import immutabledict - # `_distill_params_20` copied from SA14's `sqlalchemy.engine.{base,util}`. _no_tuple = () _no_kw = immutabledict() @@ -52,9 +51,7 @@ def _distill_params_20(params): elif isinstance(params, list): # collections_abc.MutableSequence): # avoid abc.__instancecheck__ if params and not isinstance(params[0], (collections_abc.Mapping, tuple)): - raise exc.ArgumentError( - "List argument must consist only of tuples or dictionaries" - ) + raise exc.ArgumentError("List argument must consist only of tuples or dictionaries") return (params,), _no_kw elif isinstance( @@ -74,8 +71,7 @@ def exec_driver_sql(self, statement, parameters=None, execution_options=None): """ if execution_options is not None: raise ValueError( - "SA13 backward-compatibility: " - "`exec_driver_sql` does not support `execution_options`" + "SA13 backward-compatibility: " "`exec_driver_sql` does not support `execution_options`" ) args_10style, kwargs_10style = _distill_params_20(parameters) return self.execute(statement, *args_10style, **kwargs_10style) @@ -106,12 +102,11 @@ def select_sa14(*columns, **kw) -> Select: Derived from https://github.com/sqlalchemy/alembic/blob/b1fad6b6/alembic/util/sqla_compat.py#L557-L558 sqlalchemy.exc.ArgumentError: columns argument to select() must be a Python list or other iterable - """ + """ # noqa: E501 if isinstance(columns, tuple) and isinstance(columns[0], list): if "whereclause" in kw: raise ValueError( - "SA13 backward-compatibility: " - "`whereclause` is both in kwargs and columns tuple" + "SA13 backward-compatibility: " "`whereclause` is both in kwargs and columns tuple" ) columns, whereclause = columns kw["whereclause"] = whereclause @@ -153,4 +148,5 @@ def connectionfairy_driver_connection_sa14(self): def monkeypatch_add_connectionfairy_driver_connection(): import sqlalchemy.pool.base + sqlalchemy.pool.base._ConnectionFairy.driver_connection = connectionfairy_driver_connection_sa14 diff --git a/src/sqlalchemy_cratedb/compat/core10.py b/src/sqlalchemy_cratedb/compat/core10.py index 1dce6c7..aae9c52 100644 --- a/src/sqlalchemy_cratedb/compat/core10.py +++ b/src/sqlalchemy_cratedb/compat/core10.py @@ -21,18 +21,21 @@ import sqlalchemy as sa from sqlalchemy.dialects.postgresql.base import PGCompiler -from sqlalchemy.sql.crud import (REQUIRED, _create_bind_param, - _extend_values_for_multiparams, - _get_multitable_params, - _get_stmt_parameters_params, - _key_getters_for_crud_column, _scan_cols, - _scan_insert_from_select_cols) +from sqlalchemy.sql.crud import ( + REQUIRED, + _create_bind_param, + _extend_values_for_multiparams, + _get_multitable_params, + _get_stmt_parameters_params, + _key_getters_for_crud_column, + _scan_cols, + _scan_insert_from_select_cols, +) from sqlalchemy_cratedb.compiler import CrateCompiler class CrateCompilerSA10(CrateCompiler): - def returning_clause(self, stmt, returning_cols): """ Generate RETURNING clause, PostgreSQL-compatible. @@ -46,70 +49,58 @@ def visit_update(self, update_stmt, **kw): """ # [10] CrateDB patch. - if not update_stmt.parameters and \ - not hasattr(update_stmt, '_crate_specific'): + if not update_stmt.parameters and not hasattr(update_stmt, "_crate_specific"): return super().visit_update(update_stmt, **kw) self.isupdate = True extra_froms = update_stmt._extra_froms - text = 'UPDATE ' + text = "UPDATE " if update_stmt._prefixes: - text += self._generate_prefixes(update_stmt, - update_stmt._prefixes, **kw) + text += self._generate_prefixes(update_stmt, update_stmt._prefixes, **kw) - table_text = self.update_tables_clause(update_stmt, update_stmt.table, - extra_froms, **kw) + table_text = self.update_tables_clause(update_stmt, update_stmt.table, extra_froms, **kw) dialect_hints = None if update_stmt._hints: - dialect_hints, table_text = self._setup_crud_hints( - update_stmt, table_text - ) + dialect_hints, table_text = self._setup_crud_hints(update_stmt, table_text) # [10] CrateDB patch. crud_params = _get_crud_params(self, update_stmt, **kw) text += table_text - text += ' SET ' + text += " SET " # [10] CrateDB patch begin. - include_table = \ - extra_froms and self.render_table_with_column_in_update_from + include_table = extra_froms and self.render_table_with_column_in_update_from set_clauses = [] for k, v in crud_params: - clause = k._compiler_dispatch(self, - include_table=include_table) + \ - ' = ' + v + clause = k._compiler_dispatch(self, include_table=include_table) + " = " + v set_clauses.append(clause) for k, v in update_stmt.parameters.items(): - if isinstance(k, str) and '[' in k: + if isinstance(k, str) and "[" in k: bindparam = sa.sql.bindparam(k, v) - set_clauses.append(k + ' = ' + self.process(bindparam)) + set_clauses.append(k + " = " + self.process(bindparam)) - text += ', '.join(set_clauses) + text += ", ".join(set_clauses) # [10] CrateDB patch end. if self.returning or update_stmt._returning: if not self.returning: self.returning = update_stmt._returning if self.returning_precedes_values: - text += " " + self.returning_clause( - update_stmt, self.returning) + text += " " + self.returning_clause(update_stmt, self.returning) if extra_froms: extra_from_text = self.update_from_clause( - update_stmt, - update_stmt.table, - extra_froms, - dialect_hints, - **kw) + update_stmt, update_stmt.table, extra_froms, dialect_hints, **kw + ) if extra_from_text: text += " " + extra_from_text @@ -123,8 +114,7 @@ def visit_update(self, update_stmt, **kw): text += " " + limit_clause if self.returning and not self.returning_precedes_values: - text += " " + self.returning_clause( - update_stmt, self.returning) + text += " " + self.returning_clause(update_stmt, self.returning) return text @@ -149,8 +139,7 @@ def _get_crud_params(compiler, stmt, **kw): # compiled params - return binds for all columns if compiler.column_keys is None and stmt.parameters is None: return [ - (c, _create_bind_param(compiler, c, None, required=True)) - for c in stmt.table.columns + (c, _create_bind_param(compiler, c, None, required=True)) for c in stmt.table.columns ] if stmt._has_multi_parameters: diff --git a/src/sqlalchemy_cratedb/compat/core14.py b/src/sqlalchemy_cratedb/compat/core14.py index a77da5b..89a8222 100644 --- a/src/sqlalchemy_cratedb/compat/core14.py +++ b/src/sqlalchemy_cratedb/compat/core14.py @@ -22,18 +22,21 @@ import sqlalchemy as sa from sqlalchemy.dialects.postgresql.base import PGCompiler from sqlalchemy.sql import selectable -from sqlalchemy.sql.crud import (REQUIRED, _create_bind_param, - _extend_values_for_multiparams, - _get_stmt_parameter_tuples_params, - _get_update_multitable_params, - _key_getters_for_crud_column, _scan_cols, - _scan_insert_from_select_cols) +from sqlalchemy.sql.crud import ( + REQUIRED, + _create_bind_param, + _extend_values_for_multiparams, + _get_stmt_parameter_tuples_params, + _get_update_multitable_params, + _key_getters_for_crud_column, + _scan_cols, + _scan_insert_from_select_cols, +) from sqlalchemy_cratedb.compiler import CrateCompiler class CrateCompilerSA14(CrateCompiler): - def returning_clause(self, stmt, returning_cols): """ Generate RETURNING clause, PostgreSQL-compatible. @@ -41,15 +44,11 @@ def returning_clause(self, stmt, returning_cols): return PGCompiler.returning_clause(self, stmt, returning_cols) def visit_update(self, update_stmt, **kw): - - compile_state = update_stmt._compile_state_factory( - update_stmt, self, **kw - ) + compile_state = update_stmt._compile_state_factory(update_stmt, self, **kw) update_stmt = compile_state.statement # [14] CrateDB patch. - if not compile_state._dict_parameters and \ - not hasattr(update_stmt, '_crate_specific'): + if not compile_state._dict_parameters and not hasattr(update_stmt, "_crate_specific"): return super().visit_update(update_stmt, **kw) toplevel = not self.stack @@ -64,9 +63,7 @@ def visit_update(self, update_stmt, **kw): if is_multitable: # main table might be a JOIN main_froms = set(selectable._from_objects(update_stmt.table)) - render_extra_froms = [ - f for f in extra_froms if f not in main_froms - ] + render_extra_froms = [f for f in extra_froms if f not in main_froms] correlate_froms = main_froms.union(extra_froms) else: render_extra_froms = [] @@ -83,23 +80,17 @@ def visit_update(self, update_stmt, **kw): text = "UPDATE " if update_stmt._prefixes: - text += self._generate_prefixes( - update_stmt, update_stmt._prefixes, **kw - ) + text += self._generate_prefixes(update_stmt, update_stmt._prefixes, **kw) table_text = self.update_tables_clause( update_stmt, update_stmt.table, render_extra_froms, **kw ) # [14] CrateDB patch. - crud_params = _get_crud_params( - self, update_stmt, compile_state, **kw - ) + crud_params = _get_crud_params(self, update_stmt, compile_state, **kw) if update_stmt._hints: - dialect_hints, table_text = self._setup_crud_hints( - update_stmt, table_text - ) + dialect_hints, table_text = self._setup_crud_hints(update_stmt, table_text) else: dialect_hints = None @@ -112,23 +103,22 @@ def visit_update(self, update_stmt, **kw): text += " SET " # [14] CrateDB patch begin. - include_table = \ - extra_froms and self.render_table_with_column_in_update_from + include_table = extra_froms and self.render_table_with_column_in_update_from set_clauses = [] - for c, expr, value in crud_params: + for c, expr, value in crud_params: # noqa: B007 key = c._compiler_dispatch(self, include_table=include_table) - clause = key + ' = ' + value + clause = key + " = " + value set_clauses.append(clause) for k, v in compile_state._dict_parameters.items(): - if isinstance(k, str) and '[' in k: + if isinstance(k, str) and "[" in k: bindparam = sa.sql.bindparam(k, v) - clause = k + ' = ' + self.process(bindparam) + clause = k + " = " + self.process(bindparam) set_clauses.append(clause) - text += ', '.join(set_clauses) + text += ", ".join(set_clauses) # [14] CrateDB patch end. if self.returning or update_stmt._returning: @@ -139,19 +129,13 @@ def visit_update(self, update_stmt, **kw): if extra_froms: extra_from_text = self.update_from_clause( - update_stmt, - update_stmt.table, - render_extra_froms, - dialect_hints, - **kw + update_stmt, update_stmt.table, render_extra_froms, dialect_hints, **kw ) if extra_from_text: text += " " + extra_from_text if update_stmt._where_criteria: - t = self._generate_delimited_and_list( - update_stmt._where_criteria, **kw - ) + t = self._generate_delimited_and_list(update_stmt._where_criteria, **kw) if t: text += " WHERE " + t @@ -159,9 +143,7 @@ def visit_update(self, update_stmt, **kw): if limit_clause: text += " " + limit_clause - if ( - self.returning or update_stmt._returning - ) and not self.returning_precedes_values: + if (self.returning or update_stmt._returning) and not self.returning_precedes_values: text += " " + self.returning_clause( update_stmt, self.returning or update_stmt._returning ) @@ -232,14 +214,10 @@ def _get_crud_params(compiler, stmt, compile_state, **kw): parameters = {} elif stmt_parameter_tuples: parameters = dict( - (_column_as_key(key), REQUIRED) - for key in compiler.column_keys - if key not in spd + (_column_as_key(key), REQUIRED) for key in compiler.column_keys if key not in spd ) else: - parameters = dict( - (_column_as_key(key), REQUIRED) for key in compiler.column_keys - ) + parameters = dict((_column_as_key(key), REQUIRED) for key in compiler.column_keys) # create a list of column assignment clauses as tuples values = [] @@ -340,9 +318,9 @@ def _get_crud_params(compiler, stmt, compile_state, **kw): kw, ) elif ( - not values - and compiler.for_executemany # noqa: W503 - and compiler.dialect.supports_default_metavalue # noqa: W503 + not values + and compiler.for_executemany # noqa: W503 + and compiler.dialect.supports_default_metavalue # noqa: W503 ): # convert an "INSERT DEFAULT VALUES" # into INSERT (firstcol) VALUES (DEFAULT) which can be turned diff --git a/src/sqlalchemy_cratedb/compat/core20.py b/src/sqlalchemy_cratedb/compat/core20.py index a398509..0bf0cb2 100644 --- a/src/sqlalchemy_cratedb/compat/core20.py +++ b/src/sqlalchemy_cratedb/compat/core20.py @@ -19,6 +19,8 @@ # with Crate these terms will supersede the license and you may use the # software solely pursuant to the terms of the relevant commercial agreement. +# ruff: noqa: S101 Use of `assert` detected + from typing import Any, Dict, List, MutableMapping, Optional, Tuple, Union import sqlalchemy as sa @@ -26,14 +28,20 @@ from sqlalchemy.sql import dml from sqlalchemy.sql.base import _from_objects from sqlalchemy.sql.compiler import SQLCompiler -from sqlalchemy.sql.crud import (REQUIRED, _as_dml_column, _create_bind_param, - _CrudParamElement, _CrudParams, - _extend_values_for_multiparams, - _get_stmt_parameter_tuples_params, - _get_update_multitable_params, - _key_getters_for_crud_column, _scan_cols, - _scan_insert_from_select_cols, - _setup_delete_return_defaults) +from sqlalchemy.sql.crud import ( + REQUIRED, + _as_dml_column, + _create_bind_param, + _CrudParamElement, + _CrudParams, + _extend_values_for_multiparams, + _get_stmt_parameter_tuples_params, + _get_update_multitable_params, + _key_getters_for_crud_column, + _scan_cols, + _scan_insert_from_select_cols, + _setup_delete_return_defaults, +) from sqlalchemy.sql.dml import DMLState, _DMLColumnElement from sqlalchemy.sql.dml import isinsert as _compile_state_isinsert @@ -41,16 +49,12 @@ class CrateCompilerSA20(CrateCompiler): - def visit_update(self, update_stmt, **kw): - compile_state = update_stmt._compile_state_factory( - update_stmt, self, **kw - ) + compile_state = update_stmt._compile_state_factory(update_stmt, self, **kw) update_stmt = compile_state.statement # [20] CrateDB patch. - if not compile_state._dict_parameters and \ - not hasattr(update_stmt, '_crate_specific'): + if not compile_state._dict_parameters and not hasattr(update_stmt, "_crate_specific"): return super().visit_update(update_stmt, **kw) toplevel = not self.stack @@ -67,9 +71,7 @@ def visit_update(self, update_stmt, **kw): if is_multitable: # main table might be a JOIN main_froms = set(_from_objects(update_stmt.table)) - render_extra_froms = [ - f for f in extra_froms if f not in main_froms - ] + render_extra_froms = [f for f in extra_froms if f not in main_froms] correlate_froms = main_froms.union(extra_froms) else: render_extra_froms = [] @@ -86,23 +88,17 @@ def visit_update(self, update_stmt, **kw): text = "UPDATE " if update_stmt._prefixes: - text += self._generate_prefixes( - update_stmt, update_stmt._prefixes, **kw - ) + text += self._generate_prefixes(update_stmt, update_stmt._prefixes, **kw) table_text = self.update_tables_clause( update_stmt, update_stmt.table, render_extra_froms, **kw ) # [20] CrateDB patch. - crud_params_struct = _get_crud_params( - self, update_stmt, compile_state, toplevel, **kw - ) + crud_params_struct = _get_crud_params(self, update_stmt, compile_state, toplevel, **kw) crud_params = crud_params_struct.single_params if update_stmt._hints: - dialect_hints, table_text = self._setup_crud_hints( - update_stmt, table_text - ) + dialect_hints, table_text = self._setup_crud_hints(update_stmt, table_text) else: dialect_hints = None @@ -114,23 +110,22 @@ def visit_update(self, update_stmt, **kw): text += " SET " # [20] CrateDB patch begin. - include_table = extra_froms and \ - self.render_table_with_column_in_update_from + include_table = extra_froms and self.render_table_with_column_in_update_from set_clauses = [] - for c, expr, value, _ in crud_params: + for c, expr, value, _ in crud_params: # noqa: B007 key = c._compiler_dispatch(self, include_table=include_table) - clause = key + ' = ' + value + clause = key + " = " + value set_clauses.append(clause) for k, v in compile_state._dict_parameters.items(): - if isinstance(k, str) and '[' in k: + if isinstance(k, str) and "[" in k: bindparam = sa.sql.bindparam(k, v) - clause = k + ' = ' + self.process(bindparam) + clause = k + " = " + self.process(bindparam) set_clauses.append(clause) - text += ', '.join(set_clauses) + text += ", ".join(set_clauses) # [20] CrateDB patch end. if self.implicit_returning or update_stmt._returning: @@ -153,9 +148,7 @@ def visit_update(self, update_stmt, **kw): text += " " + extra_from_text if update_stmt._where_criteria: - t = self._generate_delimited_and_list( - update_stmt._where_criteria, **kw - ) + t = self._generate_delimited_and_list(update_stmt._where_criteria, **kw) if t: text += " WHERE " + t @@ -275,15 +268,10 @@ def _get_crud_params( [], ) - stmt_parameter_tuples: Optional[ - List[Tuple[Union[str, ColumnClause[Any]], Any]] - ] + stmt_parameter_tuples: Optional[List[Tuple[Union[str, ColumnClause[Any]], Any]]] spd: Optional[MutableMapping[_DMLColumnElement, Any]] - if ( - _compile_state_isinsert(compile_state) - and compile_state._has_multi_parameters - ): + if _compile_state_isinsert(compile_state) and compile_state._has_multi_parameters: mp = compile_state._multi_parameters assert mp is not None spd = mp[0] @@ -304,14 +292,10 @@ def _get_crud_params( elif stmt_parameter_tuples: assert spd is not None parameters = { - _column_as_key(key): REQUIRED - for key in compiler.column_keys - if key not in spd + _column_as_key(key): REQUIRED for key in compiler.column_keys if key not in spd } else: - parameters = { - _column_as_key(key): REQUIRED for key in compiler.column_keys - } + parameters = {_column_as_key(key): REQUIRED for key in compiler.column_keys} # create a list of column assignment clauses as tuples values: List[_CrudParamElement] = [] @@ -408,10 +392,7 @@ def _get_crud_params( ) """ - if ( - _compile_state_isinsert(compile_state) - and compile_state._has_multi_parameters - ): + if _compile_state_isinsert(compile_state) and compile_state._has_multi_parameters: # is a multiparams, is not an insert from a select assert not stmt._select_names multi_extended_values = _extend_values_for_multiparams( @@ -426,11 +407,7 @@ def _get_crud_params( kw, ) return _CrudParams(values, multi_extended_values) - elif ( - not values - and compiler.for_executemany - and compiler.dialect.supports_default_metavalue - ): + elif not values and compiler.for_executemany and compiler.dialect.supports_default_metavalue: # convert an "INSERT DEFAULT VALUES" # into INSERT (firstcol) VALUES (DEFAULT) which can be turned # into an in-place multi values. This supports diff --git a/src/sqlalchemy_cratedb/compiler.py b/src/sqlalchemy_cratedb/compiler.py index 6b94f8b..7a81982 100644 --- a/src/sqlalchemy_cratedb/compiler.py +++ b/src/sqlalchemy_cratedb/compiler.py @@ -24,17 +24,18 @@ from collections import defaultdict import sqlalchemy as sa -from sqlalchemy.dialects.postgresql.base import PGCompiler from sqlalchemy.dialects.postgresql.base import RESERVED_WORDS as POSTGRESQL_RESERVED_WORDS +from sqlalchemy.dialects.postgresql.base import PGCompiler from sqlalchemy.sql import compiler from sqlalchemy.types import String + +from .sa_version import SA_1_4, SA_VERSION from .type.geo import Geopoint, Geoshape from .type.object import MutableDict, ObjectTypeImpl -from .sa_version import SA_VERSION, SA_1_4 def rewrite_update(clauseelement, multiparams, params): - """ change the params to enable partial updates + """change the params to enable partial updates sqlalchemy by default only supports updates of complex types in the form of @@ -55,9 +56,8 @@ def rewrite_update(clauseelement, multiparams, params): for _params in _multiparams: newparams = {} for key, val in _params.items(): - if ( - not isinstance(val, MutableDict) or - (not any(val._changed_keys) and not any(val._deleted_keys)) + if not isinstance(val, MutableDict) or ( + not any(val._changed_keys) and not any(val._deleted_keys) ): newparams[key] = val continue @@ -68,7 +68,7 @@ def rewrite_update(clauseelement, multiparams, params): for subkey in val._deleted_keys: newparams["{0}['{1}']".format(key, subkey)] = None newmultiparams.append(newparams) - _multiparams = (newmultiparams, ) + _multiparams = (newmultiparams,) clause = clauseelement.values(newmultiparams[0]) clause._crate_specific = True return clause, _multiparams, params @@ -76,7 +76,7 @@ def rewrite_update(clauseelement, multiparams, params): @sa.event.listens_for(sa.engine.Engine, "before_execute", retval=True) def crate_before_execute(conn, clauseelement, multiparams, params, *args, **kwargs): - is_crate = type(conn.dialect).__name__ == 'CrateDialect' + is_crate = type(conn.dialect).__name__ == "CrateDialect" if is_crate and isinstance(clauseelement, sa.sql.expression.Update): if SA_VERSION >= SA_1_4: if params is None: @@ -98,19 +98,19 @@ def crate_before_execute(conn, clauseelement, multiparams, params, *args, **kwar class CrateDDLCompiler(compiler.DDLCompiler): - - __special_opts_tmpl = { - 'partitioned_by': ' PARTITIONED BY ({0})' - } + __special_opts_tmpl = {"partitioned_by": " PARTITIONED BY ({0})"} __clustered_opts_tmpl = { - 'number_of_shards': ' INTO {0} SHARDS', - 'clustered_by': ' BY ({0})', + "number_of_shards": " INTO {0} SHARDS", + "clustered_by": " BY ({0})", } - __clustered_opt_tmpl = ' CLUSTERED{clustered_by}{number_of_shards}' + __clustered_opt_tmpl = " CLUSTERED{clustered_by}{number_of_shards}" def get_column_specification(self, column, **kwargs): - colspec = self.preparer.format_column(column) + " " + \ - self.dialect.type_compiler.process(column.type) + colspec = ( + self.preparer.format_column(column) + + " " + + self.dialect.type_compiler.process(column.type) + ) default = self.get_column_default_string(column) if default is not None: @@ -122,11 +122,9 @@ def get_column_specification(self, column, **kwargs): if column.nullable is False: colspec += " NOT NULL" elif column.nullable and column.primary_key: - raise sa.exc.CompileError( - "Primary key columns cannot be nullable" - ) + raise sa.exc.CompileError("Primary key columns cannot be nullable") - if column.dialect_options['crate'].get('index') is False: + if column.dialect_options["crate"].get("index") is False: if isinstance(column.type, (Geopoint, Geoshape, ObjectTypeImpl)): raise sa.exc.CompileError( "Disabling indexing is not supported for column " @@ -135,8 +133,8 @@ def get_column_specification(self, column, **kwargs): colspec += " INDEX OFF" - if column.dialect_options['crate'].get('columnstore') is False: - if not isinstance(column.type, (String, )): + if column.dialect_options["crate"].get("columnstore") is False: + if not isinstance(column.type, (String,)): raise sa.exc.CompileError( "Controlling the columnstore is only allowed for STRING columns" ) @@ -148,8 +146,7 @@ def get_column_specification(self, column, **kwargs): def visit_computed_column(self, generated): if generated.persisted is False: raise sa.exc.CompileError( - "Virtual computed columns are not supported, set " - "'persisted' to None or True" + "Virtual computed columns are not supported, set " "'persisted' to None or True" ) return "GENERATED ALWAYS AS (%s)" % self.sql_compiler.process( @@ -157,14 +154,14 @@ def visit_computed_column(self, generated): ) def post_create_table(self, table): - special_options = '' + special_options = "" clustered_options = defaultdict(str) table_opts = [] opts = dict( - (k[len(self.dialect.name) + 1:], v) - for k, v, in table.kwargs.items() - if k.startswith('%s_' % self.dialect.name) + (k[len(self.dialect.name) + 1 :], v) + for k, v in table.kwargs.items() + if k.startswith("%s_" % self.dialect.name) ) for k, v in opts.items(): if k in self.__special_opts_tmpl: @@ -172,69 +169,73 @@ def post_create_table(self, table): elif k in self.__clustered_opts_tmpl: clustered_options[k] = self.__clustered_opts_tmpl[k].format(v) else: - table_opts.append('{0} = {1}'.format(k, v)) + table_opts.append("{0} = {1}".format(k, v)) if clustered_options: special_options += string.Formatter().vformat( - self.__clustered_opt_tmpl, (), clustered_options) + self.__clustered_opt_tmpl, (), clustered_options + ) if table_opts: - return special_options + ' WITH ({0})'.format( - ', '.join(sorted(table_opts))) + return special_options + " WITH ({0})".format(", ".join(sorted(table_opts))) return special_options def visit_foreign_key_constraint(self, constraint, **kw): """ CrateDB does not support foreign key constraints. """ - warnings.warn("CrateDB does not support foreign key constraints, " - "they will be omitted when generating DDL statements.") - return None + warnings.warn( + "CrateDB does not support foreign key constraints, " + "they will be omitted when generating DDL statements.", + stacklevel=2, + ) + return def visit_unique_constraint(self, constraint, **kw): """ CrateDB does not support unique key constraints. """ - warnings.warn("CrateDB does not support unique constraints, " - "they will be omitted when generating DDL statements.") - return None + warnings.warn( + "CrateDB does not support unique constraints, " + "they will be omitted when generating DDL statements.", + stacklevel=2, + ) + return class CrateTypeCompiler(compiler.GenericTypeCompiler): - def visit_string(self, type_, **kw): - return 'STRING' + return "STRING" def visit_unicode(self, type_, **kw): - return 'STRING' + return "STRING" def visit_TEXT(self, type_, **kw): - return 'STRING' + return "STRING" def visit_DECIMAL(self, type_, **kw): - return 'DOUBLE' + return "DOUBLE" def visit_BIGINT(self, type_, **kw): - return 'LONG' + return "LONG" def visit_NUMERIC(self, type_, **kw): - return 'LONG' + return "LONG" def visit_INTEGER(self, type_, **kw): - return 'INT' + return "INT" def visit_SMALLINT(self, type_, **kw): - return 'SHORT' + return "SHORT" def visit_datetime(self, type_, **kw): return self.visit_TIMESTAMP(type_, **kw) def visit_date(self, type_, **kw): - return 'TIMESTAMP' + return "TIMESTAMP" def visit_ARRAY(self, type_, **kw): if type_.dimensions is not None and type_.dimensions > 1: - raise NotImplementedError( - "CrateDB doesn't support multidimensional arrays") - return 'ARRAY({0})'.format(self.process(type_.item_type)) + raise NotImplementedError("CrateDB doesn't support multidimensional arrays") + return "ARRAY({0})".format(self.process(type_.item_type)) def visit_OBJECT(self, type_, **kw): return "OBJECT" @@ -251,32 +252,21 @@ def visit_TIMESTAMP(self, type_, **kw): From `sqlalchemy.dialects.postgresql.base.PGTypeCompiler`. """ - return "TIMESTAMP %s" % ( - (type_.timezone and "WITH" or "WITHOUT") + " TIME ZONE", - ) + return "TIMESTAMP %s" % ((type_.timezone and "WITH" or "WITHOUT") + " TIME ZONE",) class CrateCompiler(compiler.SQLCompiler): - def visit_getitem_binary(self, binary, operator, **kw): - return "{0}['{1}']".format( - self.process(binary.left, **kw), - binary.right.value - ) + return "{0}['{1}']".format(self.process(binary.left, **kw), binary.right.value) - def visit_json_getitem_op_binary( - self, binary, operator, _cast_applied=False, **kw - ): - return "{0}['{1}']".format( - self.process(binary.left, **kw), - binary.right.value - ) + def visit_json_getitem_op_binary(self, binary, operator, _cast_applied=False, **kw): + return "{0}['{1}']".format(self.process(binary.left, **kw), binary.right.value) def visit_any(self, element, **kw): return "%s%sANY (%s)" % ( self.process(element.left, **kw), compiler.OPERATORS[element.operator], - self.process(element.right, **kw) + self.process(element.right, **kw), ) def visit_ilike_case_insensitive_operand(self, element, **kw): @@ -331,29 +321,32 @@ def limit_clause(self, select, **kw): def for_update_clause(self, select, **kw): # CrateDB does not support the `INSERT ... FOR UPDATE` clause. # See https://github.com/crate/crate-python/issues/577. - warnings.warn("CrateDB does not support the 'INSERT ... FOR UPDATE' clause, " - "it will be omitted when generating SQL statements.") - return '' + warnings.warn( + "CrateDB does not support the 'INSERT ... FOR UPDATE' clause, " + "it will be omitted when generating SQL statements.", + stacklevel=2, + ) + return "" -CRATEDB_RESERVED_WORDS = \ - "add, alter, between, by, called, costs, delete, deny, directory, drop, escape, exists, " \ - "extract, first, function, if, index, input, insert, last, match, nulls, object, " \ - "persistent, recursive, reset, returns, revoke, set, stratify, transient, try_cast, " \ +CRATEDB_RESERVED_WORDS = ( + "add, alter, between, by, called, costs, delete, deny, directory, drop, escape, exists, " + "extract, first, function, if, index, input, insert, last, match, nulls, object, " + "persistent, recursive, reset, returns, revoke, set, stratify, transient, try_cast, " "unbounded, update".split(", ") +) class CrateIdentifierPreparer(sa.sql.compiler.IdentifierPreparer): """ Define CrateDB's reserved words to be quoted properly. """ + reserved_words = set(list(POSTGRESQL_RESERVED_WORDS) + CRATEDB_RESERVED_WORDS) def _unquote_identifier(self, value): if value[0] == self.initial_quote: - value = value[1:-1].replace( - self.escape_to_quote, self.escape_quote - ) + value = value[1:-1].replace(self.escape_to_quote, self.escape_quote) return value def format_type(self, type_, use_schema=True): @@ -363,10 +356,6 @@ def format_type(self, type_, use_schema=True): name = self.quote(type_.name) effective_schema = self.schema_for_object(type_) - if ( - not self.omit_schema - and use_schema - and effective_schema is not None - ): + if not self.omit_schema and use_schema and effective_schema is not None: name = self.quote_schema(effective_schema) + "." + name return name diff --git a/src/sqlalchemy_cratedb/dialect.py b/src/sqlalchemy_cratedb/dialect.py index 6786051..b26ce81 100644 --- a/src/sqlalchemy_cratedb/dialect.py +++ b/src/sqlalchemy_cratedb/dialect.py @@ -20,7 +20,7 @@ # software solely pursuant to the terms of the relevant commercial agreement. import logging -from datetime import datetime, date +from datetime import date, datetime from sqlalchemy import types as sqltypes from sqlalchemy.engine import default, reflection @@ -28,11 +28,11 @@ from sqlalchemy.util import asbool, to_list from .compiler import ( - CrateTypeCompiler, CrateDDLCompiler, CrateIdentifierPreparer, + CrateTypeCompiler, ) -from .sa_version import SA_VERSION, SA_1_4, SA_2_0 +from .sa_version import SA_1_4, SA_2_0, SA_VERSION from .type import FloatVector, ObjectArray, ObjectType TYPES_MAP = { @@ -54,9 +54,12 @@ "text": sqltypes.String, "float_vector": FloatVector, } + +# Needed for SQLAlchemy >= 1.1. +# TODO: Dissolve. try: - # SQLAlchemy >= 1.1 from sqlalchemy.types import ARRAY + TYPES_MAP["integer_array"] = ARRAY(sqltypes.Integer) TYPES_MAP["boolean_array"] = ARRAY(sqltypes.Boolean) TYPES_MAP["short_array"] = ARRAY(sqltypes.SmallInteger) @@ -71,7 +74,7 @@ TYPES_MAP["real_array"] = ARRAY(sqltypes.Float) TYPES_MAP["string_array"] = ARRAY(sqltypes.String) TYPES_MAP["text_array"] = ARRAY(sqltypes.String) -except Exception: +except Exception: # noqa: S110 pass @@ -82,14 +85,16 @@ class Date(sqltypes.Date): def bind_processor(self, dialect): def process(value): if value is not None: - assert isinstance(value, date) - return value.strftime('%Y-%m-%d') + assert isinstance(value, date) # noqa: S101 + return value.strftime("%Y-%m-%d") + return None + return process def result_processor(self, dialect, coltype): def process(value): if not value: - return + return None try: return datetime.utcfromtimestamp(value / 1e3).date() except TypeError: @@ -103,27 +108,29 @@ def process(value): # the date will be returned in the format it was inserted. log.warning( "Received timestamp isn't a long value." - "Trying to parse as date string and then as datetime string") + "Trying to parse as date string and then as datetime string" + ) try: - return datetime.strptime(value, '%Y-%m-%d').date() + return datetime.strptime(value, "%Y-%m-%d").date() except ValueError: - return datetime.strptime(value, '%Y-%m-%dT%H:%M:%S.%fZ').date() + return datetime.strptime(value, "%Y-%m-%dT%H:%M:%S.%fZ").date() + return process class DateTime(sqltypes.DateTime): - def bind_processor(self, dialect): def process(value): if isinstance(value, (datetime, date)): - return value.strftime('%Y-%m-%dT%H:%M:%S.%f%z') + return value.strftime("%Y-%m-%dT%H:%M:%S.%f%z") return value + return process def result_processor(self, dialect, coltype): def process(value): if not value: - return + return None try: return datetime.utcfromtimestamp(value / 1e3) except TypeError: @@ -137,11 +144,13 @@ def process(value): # the date will be returned in the format it was inserted. log.warning( "Received timestamp isn't a long value." - "Trying to parse as datetime string and then as date string") + "Trying to parse as datetime string and then as date string" + ) try: - return datetime.strptime(value, '%Y-%m-%dT%H:%M:%S.%fZ') + return datetime.strptime(value, "%Y-%m-%dT%H:%M:%S.%fZ") except ValueError: - return datetime.strptime(value, '%Y-%m-%d') + return datetime.strptime(value, "%Y-%m-%d") + return process @@ -154,19 +163,22 @@ def process(value): if SA_VERSION >= SA_2_0: from .compat.core20 import CrateCompilerSA20 + statement_compiler = CrateCompilerSA20 elif SA_VERSION >= SA_1_4: from .compat.core14 import CrateCompilerSA14 + statement_compiler = CrateCompilerSA14 else: from .compat.core10 import CrateCompilerSA10 + statement_compiler = CrateCompilerSA10 class CrateDialect(default.DefaultDialect): - name = 'crate' - driver = 'crate-python' - default_paramstyle = 'qmark' + name = "crate" + driver = "crate-python" + default_paramstyle = "qmark" statement_compiler = statement_compiler ddl_compiler = CrateDDLCompiler type_compiler = CrateTypeCompiler @@ -192,15 +204,13 @@ def __init__(self, **kwargs): # Currently, our SQL parser doesn't support unquoted column names that # start with _. Adding it here causes sqlalchemy to quote such columns. - self.identifier_preparer.illegal_initial_characters.add('_') + self.identifier_preparer.illegal_initial_characters.add("_") def initialize(self, connection): # get lowest server version - self.server_version_info = \ - self._get_server_version_info(connection) + self.server_version_info = self._get_server_version_info(connection) # get default schema name - self.default_schema_name = \ - self._get_default_schema_name(connection) + self.default_schema_name = self._get_default_schema_name(connection) def do_rollback(self, connection): # if any exception is raised by the dbapi, sqlalchemy by default @@ -212,9 +222,9 @@ def do_rollback(self, connection): def connect(self, host=None, port=None, *args, **kwargs): server = None if host: - server = '{0}:{1}'.format(host, port or '4200') - if 'servers' in kwargs: - server = kwargs.pop('servers') + server = "{0}:{1}".format(host, port or "4200") + if "servers" in kwargs: + server = kwargs.pop("servers") servers = to_list(server) if servers: use_ssl = asbool(kwargs.pop("ssl", False)) @@ -224,7 +234,7 @@ def connect(self, host=None, port=None, *args, **kwargs): return self.dbapi.connect(**kwargs) def _get_default_schema_name(self, connection): - return 'doc' + return "doc" def _get_effective_schema_name(self, connection): schema_name_raw = connection.engine.url.query.get("schema") @@ -241,6 +251,7 @@ def _get_server_version_info(self, connection): @classmethod def import_dbapi(cls): from crate import client + return client @classmethod @@ -256,9 +267,7 @@ def has_table(self, connection, table_name, schema=None, **kw): @reflection.cache def get_schema_names(self, connection, **kw): cursor = connection.exec_driver_sql( - "select schema_name " - "from information_schema.schemata " - "order by schema_name asc" + "select schema_name " "from information_schema.schemata " "order by schema_name asc" ) return [row[0] for row in cursor.fetchall()] @@ -271,7 +280,7 @@ def get_table_names(self, connection, schema=None, **kw): "WHERE {0} = ? " "AND table_type = 'BASE TABLE' " "ORDER BY table_name ASC, {0} ASC".format(self.schema_column), - (schema or self.default_schema_name, ) + (schema or self.default_schema_name,), ) return [row[0] for row in cursor.fetchall()] @@ -280,22 +289,25 @@ def get_view_names(self, connection, schema=None, **kw): cursor = connection.exec_driver_sql( "SELECT table_name FROM information_schema.views " "ORDER BY table_name ASC, {0} ASC".format(self.schema_column), - (schema or self.default_schema_name, ) + (schema or self.default_schema_name,), ) return [row[0] for row in cursor.fetchall()] @reflection.cache def get_columns(self, connection, table_name, schema=None, **kw): - query = "SELECT column_name, data_type " \ - "FROM information_schema.columns " \ - "WHERE table_name = ? AND {0} = ? " \ - "AND column_name !~ ?" \ - .format(self.schema_column) + query = ( + "SELECT column_name, data_type " + "FROM information_schema.columns " + "WHERE table_name = ? AND {0} = ? " + "AND column_name !~ ?".format(self.schema_column) + ) cursor = connection.exec_driver_sql( query, - (table_name, - schema or self.default_schema_name, - r"(.*)\[\'(.*)\'\]") # regex to filter subscript + ( + table_name, + schema or self.default_schema_name, + r"(.*)\[\'(.*)\'\]", + ), # regex to filter subscript ) return [self._create_column_info(row) for row in cursor.fetchall()] @@ -330,17 +342,14 @@ def result_fun(result): rows = result.fetchone() return set(rows[0] if rows else []) - pk_result = engine.exec_driver_sql( - query, - (table_name, schema or self.default_schema_name) - ) + pk_result = engine.exec_driver_sql(query, (table_name, schema or self.default_schema_name)) pks = result_fun(pk_result) - return {'constrained_columns': list(sorted(pks)), - 'name': 'PRIMARY KEY'} + return {"constrained_columns": sorted(pks), "name": "PRIMARY KEY"} @reflection.cache - def get_foreign_keys(self, connection, table_name, schema=None, - postgresql_ignore_search_path=False, **kw): + def get_foreign_keys( + self, connection, table_name, schema=None, postgresql_ignore_search_path=False, **kw + ): # Crate doesn't support Foreign Keys, so this stays empty return [] @@ -354,12 +363,12 @@ def schema_column(self): def _create_column_info(self, row): return { - 'name': row[0], - 'type': self._resolve_type(row[1]), + "name": row[0], + "type": self._resolve_type(row[1]), # In Crate every column is nullable except PK # Primary Key Constraints are not nullable anyway, no matter what # we return here, so it's fine to return always `True` - 'nullable': True + "nullable": True, } def _resolve_type(self, type_): diff --git a/src/sqlalchemy_cratedb/predicate.py b/src/sqlalchemy_cratedb/predicate.py index 4f974f9..9be323e 100644 --- a/src/sqlalchemy_cratedb/predicate.py +++ b/src/sqlalchemy_cratedb/predicate.py @@ -19,8 +19,8 @@ # with Crate these terms will supersede the license and you may use the # software solely pursuant to the terms of the relevant commercial agreement. -from sqlalchemy.sql.expression import ColumnElement, literal from sqlalchemy.ext.compiler import compiles +from sqlalchemy.sql.expression import ColumnElement, literal class Match(ColumnElement): @@ -35,9 +35,8 @@ def __init__(self, column, term, match_type=None, options=None): def compile_column(self, compiler): if isinstance(self.column, dict): - column = ', '.join( - sorted(["{0} {1}".format(compiler.process(k), v) - for k, v in self.column.items()]) + column = ", ".join( + sorted(["{0} {1}".format(compiler.process(k), v) for k, v in self.column.items()]) ) return "({0})".format(column) else: @@ -51,21 +50,22 @@ def compile_using(self, compiler): using = "using {0}".format(self.match_type) with_clause = self.with_clause() if with_clause: - using = ' '.join([using, with_clause]) + using = " ".join([using, with_clause]) return using if self.options: - raise ValueError("missing match_type. " + - "It's not allowed to specify options " + - "without match_type") + raise ValueError( + "missing match_type. " + + "It's not allowed to specify options " + + "without match_type" + ) + return None def with_clause(self): if self.options: - options = ', '.join( - sorted(["{0}={1}".format(k, v) - for k, v in self.options.items()]) - ) + options = ", ".join(sorted(["{0}={1}".format(k, v) for k, v in self.options.items()])) return "with ({0})".format(options) + return None def match(column, term, match_type=None, options=None): @@ -89,11 +89,8 @@ def match(column, term, match_type=None, options=None): @compiles(Match) def compile_match(match, compiler, **kwargs): - func = "match(%s, %s)" % ( - match.compile_column(compiler), - match.compile_term(compiler) - ) + func = "match(%s, %s)" % (match.compile_column(compiler), match.compile_term(compiler)) using = match.compile_using(compiler) if using: - func = ' '.join([func, using]) + func = " ".join([func, using]) return func diff --git a/src/sqlalchemy_cratedb/sa_version.py b/src/sqlalchemy_cratedb/sa_version.py index 6b45f8b..22f31e5 100644 --- a/src/sqlalchemy_cratedb/sa_version.py +++ b/src/sqlalchemy_cratedb/sa_version.py @@ -24,5 +24,5 @@ SA_VERSION = Version(sa.__version__) -SA_1_4 = Version('1.4.0b1') -SA_2_0 = Version('2.0.0') +SA_1_4 = Version("1.4.0b1") +SA_2_0 = Version("2.0.0") diff --git a/src/sqlalchemy_cratedb/support/__init__.py b/src/sqlalchemy_cratedb/support/__init__.py index d140d60..673f712 100644 --- a/src/sqlalchemy_cratedb/support/__init__.py +++ b/src/sqlalchemy_cratedb/support/__init__.py @@ -1,7 +1,10 @@ from sqlalchemy_cratedb.support.pandas import insert_bulk, table_kwargs -from sqlalchemy_cratedb.support.polyfill import check_uniqueness_factory, refresh_after_dml, \ - patch_autoincrement_timestamp -from sqlalchemy_cratedb.support.util import refresh_table, refresh_dirty +from sqlalchemy_cratedb.support.polyfill import ( + check_uniqueness_factory, + patch_autoincrement_timestamp, + refresh_after_dml, +) +from sqlalchemy_cratedb.support.util import refresh_dirty, refresh_table __all__ = [ check_uniqueness_factory, diff --git a/src/sqlalchemy_cratedb/support/pandas.py b/src/sqlalchemy_cratedb/support/pandas.py index 90c24ed..1a20b65 100644 --- a/src/sqlalchemy_cratedb/support/pandas.py +++ b/src/sqlalchemy_cratedb/support/pandas.py @@ -18,15 +18,14 @@ # However, if you have executed another commercial license agreement # with Crate these terms will supersede the license and you may use the # software solely pursuant to the terms of the relevant commercial agreement. +import logging from contextlib import contextmanager from typing import Any from unittest.mock import patch -import logging - import sqlalchemy as sa -from sqlalchemy_cratedb import SA_VERSION, SA_2_0 +from sqlalchemy_cratedb.sa_version import SA_2_0, SA_VERSION logger = logging.getLogger(__name__) @@ -51,7 +50,7 @@ def insert_bulk(pd_table, conn, keys, data_iter): [1] https://pandas.pydata.org/docs/reference/api/pandas.DataFrame.to_sql.html [2] https://cratedb.com/docs/crate/reference/en/latest/interfaces/http.html#bulk-operations [3] https://github.com/pandas-dev/pandas/blob/v2.0.1/pandas/io/sql.py#L1011-L1027 - """ + """ # noqa: E501 # Compile SQL statement and materialize batch. sql = str(pd_table.table.insert().compile(bind=conn)) @@ -61,7 +60,7 @@ def insert_bulk(pd_table, conn, keys, data_iter): if logger.level == logging.DEBUG: logger.debug(f"Bulk SQL: {sql}") logger.debug(f"Bulk records: {len(data)}") - # logger.debug(f"Bulk data: {data}") + # logger.debug(f"Bulk data: {data}") # noqa: ERA001 # Invoke bulk insert operation. cursor = conn._dbapi_connection.cursor() diff --git a/src/sqlalchemy_cratedb/support/polyfill.py b/src/sqlalchemy_cratedb/support/polyfill.py index 73177e5..22dad7c 100644 --- a/src/sqlalchemy_cratedb/support/polyfill.py +++ b/src/sqlalchemy_cratedb/support/polyfill.py @@ -1,6 +1,7 @@ +import typing as t + import sqlalchemy as sa from sqlalchemy.event import listen -import typing as t from sqlalchemy_cratedb.support.util import refresh_dirty, refresh_table @@ -39,7 +40,7 @@ def check_uniqueness_factory(sa_entity, *attribute_names): This is used by CrateDB's MLflow adapter. TODO: Maybe enable through a dialect parameter `crate_polyfill_unique` or such. - """ + """ # noqa: E501 # Synthesize a canonical "name" for the constraint, # composed of all column names involved. @@ -52,7 +53,9 @@ def check_uniqueness(mapper, connection, target): # TODO: How to use `session.query(SqlExperiment)` here? stmt = mapper.selectable.select() for attribute_name in attribute_names: - stmt = stmt.filter(getattr(sa_entity, attribute_name) == getattr(target, attribute_name)) + stmt = stmt.filter( + getattr(sa_entity, attribute_name) == getattr(target, attribute_name) + ) stmt = stmt.compile(bind=connection.engine) results = connection.execute(stmt) if results.rowcount > 0: @@ -60,7 +63,8 @@ def check_uniqueness(mapper, connection, target): statement=stmt, params=[], orig=Exception( - f"DuplicateKeyException in table '{target.__tablename__}' " f"on constraint '{constraint_name}'" + f"DuplicateKeyException in table '{target.__tablename__}' " + f"on constraint '{constraint_name}'" ), ) @@ -103,6 +107,7 @@ def refresh_after_dml_engine(engine: sa.engine.Engine): This is used by CrateDB's Singer/Meltano and `rdflib-sqlalchemy` adapters. """ + def receive_after_execute( conn: sa.engine.Connection, clauseelement, multiparams, params, execution_options, result ): diff --git a/src/sqlalchemy_cratedb/support/util.py b/src/sqlalchemy_cratedb/support/util.py index 33cce5f..9b9b07f 100644 --- a/src/sqlalchemy_cratedb/support/util.py +++ b/src/sqlalchemy_cratedb/support/util.py @@ -10,7 +10,9 @@ pass -def refresh_table(connection, target: t.Union[str, "DeclarativeBase", "sa.sql.selectable.TableClause"]): +def refresh_table( + connection, target: t.Union[str, "DeclarativeBase", "sa.sql.selectable.TableClause"] +): """ Invoke a `REFRESH TABLE` statement. """ diff --git a/src/sqlalchemy_cratedb/type/__init__.py b/src/sqlalchemy_cratedb/type/__init__.py index 36ba817..b524bb3 100644 --- a/src/sqlalchemy_cratedb/type/__init__.py +++ b/src/sqlalchemy_cratedb/type/__init__.py @@ -2,3 +2,12 @@ from .geo import Geopoint, Geoshape from .object import ObjectType from .vector import FloatVector, knn_match + +__all__ = [ + Geopoint, + Geoshape, + ObjectArray, + ObjectType, + FloatVector, + knn_match, +] diff --git a/src/sqlalchemy_cratedb/type/array.py b/src/sqlalchemy_cratedb/type/array.py index ae68d4b..801ef21 100644 --- a/src/sqlalchemy_cratedb/type/array.py +++ b/src/sqlalchemy_cratedb/type/array.py @@ -20,16 +20,14 @@ # software solely pursuant to the terms of the relevant commercial agreement. import sqlalchemy.types as sqltypes -from sqlalchemy.sql import operators, expression -from sqlalchemy.sql import default_comparator from sqlalchemy.ext.mutable import Mutable +from sqlalchemy.sql import default_comparator, expression, operators class MutableList(Mutable, list): - @classmethod def coerce(cls, key, value): - """ Convert plain list to MutableList """ + """Convert plain list to MutableList""" if not isinstance(value, MutableList): if isinstance(value, list): return MutableList(value) @@ -85,7 +83,8 @@ class Any(expression.ColumnElement): ARRAY-bound method """ - __visit_name__ = 'any' + + __visit_name__ = "any" inherit_cache = True def __init__(self, left, right, operator=operators.eq): @@ -100,9 +99,7 @@ class _ObjectArray(sqltypes.UserDefinedType): class Comparator(sqltypes.TypeEngine.Comparator): def __getitem__(self, key): - return default_comparator._binary_operate(self.expr, - operators.getitem, - key) + return default_comparator._binary_operate(self.expr, operators.getitem, key) def any(self, other, operator=operators.eq): """Return ``other operator ANY (array)`` clause. diff --git a/src/sqlalchemy_cratedb/type/geo.py b/src/sqlalchemy_cratedb/type/geo.py index 31abd27..6bf8414 100644 --- a/src/sqlalchemy_cratedb/type/geo.py +++ b/src/sqlalchemy_cratedb/type/geo.py @@ -7,20 +7,18 @@ class Geopoint(sqltypes.UserDefinedType): cache_ok = True class Comparator(sqltypes.TypeEngine.Comparator): - def __getitem__(self, key): - return default_comparator._binary_operate(self.expr, - operators.getitem, - key) + return default_comparator._binary_operate(self.expr, operators.getitem, key) def get_col_spec(self): - return 'GEO_POINT' + return "GEO_POINT" def bind_processor(self, dialect): def process(value): if isinstance(value, geojson.Point): return value.coordinates return value + return process def result_processor(self, dialect, coltype): @@ -33,14 +31,11 @@ class Geoshape(sqltypes.UserDefinedType): cache_ok = True class Comparator(sqltypes.TypeEngine.Comparator): - def __getitem__(self, key): - return default_comparator._binary_operate(self.expr, - operators.getitem, - key) + return default_comparator._binary_operate(self.expr, operators.getitem, key) def get_col_spec(self): - return 'GEO_SHAPE' + return "GEO_SHAPE" def result_processor(self, dialect, coltype): return geojson.GeoJSON.to_instance diff --git a/src/sqlalchemy_cratedb/type/object.py b/src/sqlalchemy_cratedb/type/object.py index 2b1b66c..31f55dc 100644 --- a/src/sqlalchemy_cratedb/type/object.py +++ b/src/sqlalchemy_cratedb/type/object.py @@ -5,7 +5,6 @@ class MutableDict(Mutable, dict): - @classmethod def coerce(cls, key, value): "Convert plain dictionaries to MutableDict." @@ -26,17 +25,17 @@ def __init__(self, initval=None, to_update=None, root_change_key=None): self._overwrite_key = root_change_key self.to_update = self if to_update is None else to_update for k in initval: - initval[k] = self._convert_dict(initval[k], - overwrite_key=k if self._overwrite_key is None else self._overwrite_key - ) + initval[k] = self._convert_dict( + initval[k], overwrite_key=k if self._overwrite_key is None else self._overwrite_key + ) dict.__init__(self, initval) def __setitem__(self, key, value): - value = self._convert_dict(value, key if self._overwrite_key is None else self._overwrite_key) - dict.__setitem__(self, key, value) - self.to_update.on_key_changed( - key if self._overwrite_key is None else self._overwrite_key + value = self._convert_dict( + value, key if self._overwrite_key is None else self._overwrite_key ) + dict.__setitem__(self, key, value) + self.to_update.on_key_changed(key if self._overwrite_key is None else self._overwrite_key) def __delitem__(self, key): dict.__delitem__(self, key) @@ -63,7 +62,6 @@ def __eq__(self, other): class ObjectTypeImpl(sqltypes.UserDefinedType, sqltypes.JSON): - __visit_name__ = "OBJECT" cache_ok = False @@ -83,8 +81,12 @@ class ObjectTypeImpl(sqltypes.UserDefinedType, sqltypes.JSON): def __getattr__(name): if name in deprecated_names: - warnings.warn(f"{name} is deprecated and will be removed in future releases. " - f"Please use ObjectType instead.", DeprecationWarning) + warnings.warn( + f"{name} is deprecated and will be removed in future releases. " + f"Please use ObjectType instead.", + category=DeprecationWarning, + stacklevel=2, + ) return globals()[f"_deprecated_{name}"] raise AttributeError(f"module {__name__} has no attribute {name}") diff --git a/src/sqlalchemy_cratedb/type/vector.py b/src/sqlalchemy_cratedb/type/vector.py index 56e1f50..7fc6447 100644 --- a/src/sqlalchemy_cratedb/type/vector.py +++ b/src/sqlalchemy_cratedb/type/vector.py @@ -33,15 +33,15 @@ Copyright (c) 2021-2023 Andrew Kane https://github.com/pgvector/pgvector-python """ + import typing as t if t.TYPE_CHECKING: import numpy.typing as npt # pragma: no cover import sqlalchemy as sa -from sqlalchemy.sql.expression import ColumnElement, literal from sqlalchemy.ext.compiler import compiles - +from sqlalchemy.sql.expression import ColumnElement, literal __all__ = [ "from_db", @@ -73,7 +73,9 @@ def to_db(value: t.Any, dim: t.Optional[int] = None) -> t.Optional[t.List]: if value.ndim != 1: raise ValueError("expected ndim to be 1") - if not np.issubdtype(value.dtype, np.integer) and not np.issubdtype(value.dtype, np.floating): + if not np.issubdtype(value.dtype, np.integer) and not np.issubdtype( + value.dtype, np.floating + ): raise ValueError("dtype must be numeric") value = value.tolist() @@ -128,6 +130,7 @@ class KnnMatch(ColumnElement): https://cratedb.com/docs/crate/reference/en/latest/general/builtins/scalar-functions.html#scalar-knn-match """ + inherit_cache = True def __init__(self, column, term, k=None): diff --git a/tests/__init__.py b/tests/__init__.py index 874c043..5005cec 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -1,7 +1,11 @@ # -*- coding: utf-8 -*- -from sqlalchemy_cratedb.compat.api13 import monkeypatch_amend_select_sa14, monkeypatch_add_connectionfairy_driver_connection +from sqlalchemy_cratedb.compat.api13 import ( + monkeypatch_add_connectionfairy_driver_connection, + monkeypatch_amend_select_sa14, +) from sqlalchemy_cratedb.sa_version import SA_1_4, SA_VERSION + from .util import ParametrizedTestCase # `sql.select()` of SQLAlchemy 1.3 uses old calling semantics, @@ -11,21 +15,21 @@ monkeypatch_add_connectionfairy_driver_connection() from unittest import TestLoader, TestSuite -from .connection_test import SqlAlchemyConnectionTest -from .dict_test import SqlAlchemyDictTypeTest -from .datetime_test import SqlAlchemyDateAndDateTimeTest -from .compiler_test import SqlAlchemyCompilerTest, SqlAlchemyDDLCompilerTest -from .update_test import SqlAlchemyUpdateTest -from .match_test import SqlAlchemyMatchTest + +from .array_test import SqlAlchemyArrayTypeTest from .bulk_test import SqlAlchemyBulkTest -from .insert_from_select_test import SqlAlchemyInsertFromSelectTest +from .compiler_test import SqlAlchemyCompilerTest, SqlAlchemyDDLCompilerTest +from .connection_test import SqlAlchemyConnectionTest from .create_table_test import SqlAlchemyCreateTableTest -from .array_test import SqlAlchemyArrayTypeTest +from .datetime_test import SqlAlchemyDateAndDateTimeTest from .dialect_test import SqlAlchemyDialectTest +from .dict_test import SqlAlchemyDictTypeTest from .function_test import SqlAlchemyFunctionTest -from .warnings_test import SqlAlchemyWarningsTest +from .insert_from_select_test import SqlAlchemyInsertFromSelectTest +from .match_test import SqlAlchemyMatchTest from .query_caching import SqlAlchemyQueryCompilationCaching - +from .update_test import SqlAlchemyUpdateTest +from .warnings_test import SqlAlchemyWarningsTest makeSuite = TestLoader().loadTestsFromTestCase @@ -37,9 +41,21 @@ def test_suite_unit(): tests.addTest(makeSuite(SqlAlchemyDateAndDateTimeTest)) tests.addTest(makeSuite(SqlAlchemyCompilerTest)) tests.addTest(makeSuite(SqlAlchemyDDLCompilerTest)) - tests.addTest(ParametrizedTestCase.parametrize(SqlAlchemyCompilerTest, param={"server_version_info": None})) - tests.addTest(ParametrizedTestCase.parametrize(SqlAlchemyCompilerTest, param={"server_version_info": (4, 0, 12)})) - tests.addTest(ParametrizedTestCase.parametrize(SqlAlchemyCompilerTest, param={"server_version_info": (4, 1, 10)})) + tests.addTest( + ParametrizedTestCase.parametrize( + SqlAlchemyCompilerTest, param={"server_version_info": None} + ) + ) + tests.addTest( + ParametrizedTestCase.parametrize( + SqlAlchemyCompilerTest, param={"server_version_info": (4, 0, 12)} + ) + ) + tests.addTest( + ParametrizedTestCase.parametrize( + SqlAlchemyCompilerTest, param={"server_version_info": (4, 1, 10)} + ) + ) tests.addTest(makeSuite(SqlAlchemyUpdateTest)) tests.addTest(makeSuite(SqlAlchemyMatchTest)) tests.addTest(makeSuite(SqlAlchemyCreateTableTest)) diff --git a/tests/array_test.py b/tests/array_test.py index 6d66332..918c2da 100644 --- a/tests/array_test.py +++ b/tests/array_test.py @@ -21,11 +21,12 @@ from unittest import TestCase -from unittest.mock import patch, MagicMock +from unittest.mock import MagicMock, patch import sqlalchemy as sa -from sqlalchemy.sql import operators from sqlalchemy.orm import Session +from sqlalchemy.sql import operators + try: from sqlalchemy.orm import declarative_base except ImportError: @@ -33,21 +34,20 @@ from crate.client.cursor import Cursor -fake_cursor = MagicMock(name='fake_cursor') -FakeCursor = MagicMock(name='FakeCursor', spec=Cursor) +fake_cursor = MagicMock(name="fake_cursor") +FakeCursor = MagicMock(name="FakeCursor", spec=Cursor) FakeCursor.return_value = fake_cursor -@patch('crate.client.connection.Cursor', FakeCursor) +@patch("crate.client.connection.Cursor", FakeCursor) class SqlAlchemyArrayTypeTest(TestCase): - def setUp(self): - self.engine = sa.create_engine('crate://') + self.engine = sa.create_engine("crate://") Base = declarative_base() self.metadata = sa.MetaData() class User(Base): - __tablename__ = 'users' + __tablename__ = "users" name = sa.Column(sa.String, primary_key=True) friends = sa.Column(sa.ARRAY(sa.String)) @@ -57,55 +57,57 @@ class User(Base): self.session = Session(bind=self.engine) def assertSQL(self, expected_str, actual_expr): - self.assertEqual(expected_str, str(actual_expr).replace('\n', '')) + self.assertEqual(expected_str, str(actual_expr).replace("\n", "")) def test_create_with_array(self): - t1 = sa.Table('t', self.metadata, - sa.Column('int_array', sa.ARRAY(sa.Integer)), - sa.Column('str_array', sa.ARRAY(sa.String)) - ) + t1 = sa.Table( + "t", + self.metadata, + sa.Column("int_array", sa.ARRAY(sa.Integer)), + sa.Column("str_array", sa.ARRAY(sa.String)), + ) t1.create(self.engine) fake_cursor.execute.assert_called_with( - ('\nCREATE TABLE t (\n\t' - 'int_array ARRAY(INT), \n\t' - 'str_array ARRAY(STRING)\n)\n\n'), - ()) + ( + "\nCREATE TABLE t (\n\t" + "int_array ARRAY(INT), \n\t" + "str_array ARRAY(STRING)\n)\n\n" + ), + (), + ) def test_array_insert(self): - trillian = self.User(name='Trillian', friends=['Arthur', 'Ford']) + trillian = self.User(name="Trillian", friends=["Arthur", "Ford"]) self.session.add(trillian) self.session.commit() fake_cursor.execute.assert_called_with( ("INSERT INTO users (name, friends, scores) VALUES (?, ?, ?)"), - ('Trillian', ['Arthur', 'Ford'], None)) + ("Trillian", ["Arthur", "Ford"], None), + ) def test_any(self): - s = self.session.query(self.User.name) \ - .filter(self.User.friends.any("arthur")) + s = self.session.query(self.User.name).filter(self.User.friends.any("arthur")) self.assertSQL( - "SELECT users.name AS users_name FROM users " - "WHERE ? = ANY (users.friends)", - s + "SELECT users.name AS users_name FROM users " "WHERE ? = ANY (users.friends)", s ) def test_any_with_operator(self): - s = self.session.query(self.User.name) \ - .filter(self.User.scores.any(6, operator=operators.lt)) + s = self.session.query(self.User.name).filter( + self.User.scores.any(6, operator=operators.lt) + ) self.assertSQL( - "SELECT users.name AS users_name FROM users " - "WHERE ? < ANY (users.scores)", - s + "SELECT users.name AS users_name FROM users " "WHERE ? < ANY (users.scores)", s ) def test_multidimensional_arrays(self): - t1 = sa.Table('t', self.metadata, - sa.Column('unsupported_array', - sa.ARRAY(sa.Integer, dimensions=2)), - ) + t1 = sa.Table( + "t", + self.metadata, + sa.Column("unsupported_array", sa.ARRAY(sa.Integer, dimensions=2)), + ) err = None try: t1.create(self.engine) except NotImplementedError as e: err = e - self.assertEqual(str(err), - "CrateDB doesn't support multidimensional arrays") + self.assertEqual(str(err), "CrateDB doesn't support multidimensional arrays") diff --git a/tests/bulk_test.py b/tests/bulk_test.py index da22e2d..97c9274 100644 --- a/tests/bulk_test.py +++ b/tests/bulk_test.py @@ -21,12 +21,12 @@ import math import sys from unittest import TestCase, skipIf -from unittest.mock import patch, MagicMock +from unittest.mock import MagicMock, patch import sqlalchemy as sa from sqlalchemy.orm import Session -from sqlalchemy_cratedb import SA_VERSION, SA_2_0 +from sqlalchemy_cratedb.sa_version import SA_2_0, SA_VERSION try: from sqlalchemy.orm import declarative_base @@ -35,19 +35,17 @@ from crate.client.cursor import Cursor - -fake_cursor = MagicMock(name='fake_cursor') -FakeCursor = MagicMock(name='FakeCursor', spec=Cursor, return_value=fake_cursor) +fake_cursor = MagicMock(name="fake_cursor") +FakeCursor = MagicMock(name="FakeCursor", spec=Cursor, return_value=fake_cursor) class SqlAlchemyBulkTest(TestCase): - def setUp(self): - self.engine = sa.create_engine('crate://') + self.engine = sa.create_engine("crate://") Base = declarative_base() class Character(Base): - __tablename__ = 'characters' + __tablename__ = "characters" name = sa.Column(sa.String, primary_key=True) age = sa.Column(sa.Integer) @@ -56,7 +54,7 @@ class Character(Base): self.session = Session(bind=self.engine) @skipIf(SA_VERSION >= SA_2_0, "SQLAlchemy 2.x uses modern bulk INSERT mode") - @patch('crate.client.connection.Cursor', FakeCursor) + @patch("crate.client.connection.Cursor", FakeCursor) def test_bulk_save_legacy(self): """ Verify legacy SQLAlchemy bulk INSERT mode. @@ -85,17 +83,17 @@ def test_bulk_save_legacy(self): > -- https://github.com/sqlalchemy/sqlalchemy/discussions/6935#discussioncomment-4789701 """ chars = [ - self.character(name='Arthur', age=35), - self.character(name='Banshee', age=26), - self.character(name='Callisto', age=37), + self.character(name="Arthur", age=35), + self.character(name="Banshee", age=26), + self.character(name="Callisto", age=37), ] fake_cursor.description = () fake_cursor.rowcount = len(chars) fake_cursor.executemany.return_value = [ - {'rowcount': 1}, - {'rowcount': 1}, - {'rowcount': 1}, + {"rowcount": 1}, + {"rowcount": 1}, + {"rowcount": 1}, ] self.session.bulk_save_objects(chars) (stmt, bulk_args), _ = fake_cursor.executemany.call_args @@ -103,15 +101,11 @@ def test_bulk_save_legacy(self): expected_stmt = "INSERT INTO characters (name, age) VALUES (?, ?)" self.assertEqual(expected_stmt, stmt) - expected_bulk_args = ( - ('Arthur', 35), - ('Banshee', 26), - ('Callisto', 37) - ) + expected_bulk_args = (("Arthur", 35), ("Banshee", 26), ("Callisto", 37)) self.assertSequenceEqual(expected_bulk_args, bulk_args) @skipIf(SA_VERSION < SA_2_0, "SQLAlchemy 1.x uses legacy bulk INSERT mode") - @patch('crate.client.connection.Cursor', FakeCursor) + @patch("crate.client.connection.Cursor", FakeCursor) def test_bulk_save_modern(self): """ Verify modern SQLAlchemy bulk INSERT mode. @@ -143,17 +137,17 @@ def test_bulk_save_modern(self): self.maxDiff = None chars = [ - self.character(name='Arthur', age=35), - self.character(name='Banshee', age=26), - self.character(name='Callisto', age=37), + self.character(name="Arthur", age=35), + self.character(name="Banshee", age=26), + self.character(name="Callisto", age=37), ] fake_cursor.description = () fake_cursor.rowcount = len(chars) fake_cursor.execute.return_value = [ - {'rowcount': 1}, - {'rowcount': 1}, - {'rowcount': 1}, + {"rowcount": 1}, + {"rowcount": 1}, + {"rowcount": 1}, ] self.session.add_all(chars) self.session.commit() @@ -163,20 +157,24 @@ def test_bulk_save_modern(self): self.assertEqual(expected_stmt, stmt) expected_bulk_args = ( - 'Arthur', 35, - 'Banshee', 26, - 'Callisto', 37, + "Arthur", + 35, + "Banshee", + 26, + "Callisto", + 37, ) self.assertSequenceEqual(expected_bulk_args, bulk_args) @skipIf(sys.version_info < (3, 8), "SQLAlchemy/pandas is not supported on Python <3.8") @skipIf(SA_VERSION < SA_2_0, "SQLAlchemy 1.4 is no longer supported by pandas 2.2") - @patch('crate.client.connection.Cursor', mock_cursor=FakeCursor) + @patch("crate.client.connection.Cursor", mock_cursor=FakeCursor) def test_bulk_save_pandas(self, mock_cursor): """ Verify bulk INSERT with pandas. """ from pueblo.testing.pandas import makeTimeDataFrame + from sqlalchemy_cratedb import insert_bulk # 42 records / 8 chunksize = 5.25, which means 6 batches will be emitted. @@ -201,8 +199,8 @@ def test_bulk_save_pandas(self, mock_cursor): # Initializing the query has an overhead of two calls to the cursor object, probably one # initial connection from the DB-API driver, to inquire the database version, and another - # one, for SQLAlchemy. SQLAlchemy will use it to inquire the table schema using `information_schema`, - # and to eventually issue the `CREATE TABLE ...` statement. + # one, for SQLAlchemy. SQLAlchemy will use it to inquire the table schema using + # `information_schema`, and to eventually issue the `CREATE TABLE ...` statement. effective_op_count = mock_cursor.call_count - 2 # Verify number of batches. @@ -210,13 +208,14 @@ def test_bulk_save_pandas(self, mock_cursor): @skipIf(sys.version_info < (3, 8), "SQLAlchemy/Dask is not supported on Python <3.8") @skipIf(SA_VERSION < SA_2_0, "SQLAlchemy 1.4 is no longer supported by pandas 2.2") - @patch('crate.client.connection.Cursor', mock_cursor=FakeCursor) + @patch("crate.client.connection.Cursor", mock_cursor=FakeCursor) def test_bulk_save_dask(self, mock_cursor): """ Verify bulk INSERT with Dask. """ import dask.dataframe as dd from pueblo.testing.pandas import makeTimeDataFrame + from sqlalchemy_cratedb import insert_bulk # 42 records / 4 partitions means each partition has a size of 10.5 elements. diff --git a/tests/compiler_test.py b/tests/compiler_test.py index d1f19f5..83c3925 100644 --- a/tests/compiler_test.py +++ b/tests/compiler_test.py @@ -20,15 +20,14 @@ # software solely pursuant to the terms of the relevant commercial agreement. import warnings from textwrap import dedent -from unittest import mock, skipIf, TestCase +from unittest import TestCase, mock, skipIf from unittest.mock import MagicMock, patch -from crate.client.cursor import Cursor -from sqlalchemy_cratedb.compiler import crate_before_execute - import sqlalchemy as sa -from sqlalchemy.sql import text, Update +from crate.client.cursor import Cursor +from sqlalchemy.sql import Update, text +from sqlalchemy_cratedb.compiler import crate_before_execute from tests.settings import crate_host from tests.util import ExtraAssertions @@ -37,41 +36,41 @@ except ImportError: from sqlalchemy.ext.declarative import declarative_base -from sqlalchemy_cratedb import SA_VERSION, SA_1_4, SA_2_0, ObjectType -from .util import ParametrizedTestCase +from sqlalchemy_cratedb import ObjectType +from sqlalchemy_cratedb.sa_version import SA_1_4, SA_2_0, SA_VERSION +from .util import ParametrizedTestCase class SqlAlchemyCompilerTest(ParametrizedTestCase, ExtraAssertions): - def setUp(self): - self.crate_engine = sa.create_engine('crate://') + self.crate_engine = sa.create_engine("crate://") if isinstance(self.param, dict) and "server_version_info" in self.param: server_version_info = self.param["server_version_info"] self.crate_engine.dialect.server_version_info = server_version_info - self.sqlite_engine = sa.create_engine('sqlite://') + self.sqlite_engine = sa.create_engine("sqlite://") self.metadata = sa.MetaData() - self.mytable = sa.Table('mytable', self.metadata, - sa.Column('name', sa.String), - sa.Column('data', ObjectType)) + self.mytable = sa.Table( + "mytable", self.metadata, sa.Column("name", sa.String), sa.Column("data", ObjectType) + ) - self.update = Update(self.mytable).where(text('name=:name')) - self.values = [{'name': 'crate'}] - self.values = (self.values, ) + self.update = Update(self.mytable).where(text("name=:name")) + self.values = [{"name": "crate"}] + self.values = (self.values,) def test_sqlite_update_not_rewritten(self): clauseelement, multiparams, params = crate_before_execute( self.sqlite_engine, self.update, self.values, {} ) - self.assertFalse(hasattr(clauseelement, '_crate_specific')) + self.assertFalse(hasattr(clauseelement, "_crate_specific")) def test_crate_update_rewritten(self): clauseelement, multiparams, params = crate_before_execute( self.crate_engine, self.update, self.values, {} ) - self.assertTrue(hasattr(clauseelement, '_crate_specific')) + self.assertTrue(hasattr(clauseelement, "_crate_specific")) def test_bulk_update_on_builtin_type(self): """ @@ -84,7 +83,7 @@ def test_bulk_update_on_builtin_type(self): self.crate_engine, self.update, data, None ) - self.assertFalse(hasattr(clauseelement, '_crate_specific')) + self.assertFalse(hasattr(clauseelement, "_crate_specific")) def test_select_with_ilike_no_escape(self): """ @@ -93,17 +92,23 @@ def test_select_with_ilike_no_escape(self): selectable = self.mytable.select().where(self.mytable.c.name.ilike("%foo%")) statement = str(selectable.compile(bind=self.crate_engine)) if self.crate_engine.dialect.has_ilike_operator(): - self.assertEqual(statement, dedent(""" + self.assertEqual( + statement, + dedent(""" SELECT mytable.name, mytable.data FROM mytable WHERE mytable.name ILIKE ? - """).strip()) # noqa: W291 + """).strip(), + ) # noqa: W291 else: - self.assertEqual(statement, dedent(""" + self.assertEqual( + statement, + dedent(""" SELECT mytable.name, mytable.data FROM mytable WHERE lower(mytable.name) LIKE lower(?) - """).strip()) # noqa: W291 + """).strip(), + ) # noqa: W291 def test_select_with_not_ilike_no_escape(self): """ @@ -112,35 +117,44 @@ def test_select_with_not_ilike_no_escape(self): selectable = self.mytable.select().where(self.mytable.c.name.notilike("%foo%")) statement = str(selectable.compile(bind=self.crate_engine)) if SA_VERSION < SA_1_4 or not self.crate_engine.dialect.has_ilike_operator(): - self.assertEqual(statement, dedent(""" + self.assertEqual( + statement, + dedent(""" SELECT mytable.name, mytable.data FROM mytable WHERE lower(mytable.name) NOT LIKE lower(?) - """).strip()) # noqa: W291 + """).strip(), + ) # noqa: W291 else: - self.assertEqual(statement, dedent(""" + self.assertEqual( + statement, + dedent(""" SELECT mytable.name, mytable.data FROM mytable WHERE mytable.name NOT ILIKE ? - """).strip()) # noqa: W291 + """).strip(), + ) # noqa: W291 def test_select_with_ilike_and_escape(self): """ Verify the compiler fails when using CrateDB's native `ILIKE` method together with `ESCAPE`. """ - selectable = self.mytable.select().where(self.mytable.c.name.ilike("%foo%", escape='\\')) + selectable = self.mytable.select().where(self.mytable.c.name.ilike("%foo%", escape="\\")) with self.assertRaises(NotImplementedError) as cmex: selectable.compile(bind=self.crate_engine) self.assertEqual(str(cmex.exception), "Unsupported feature: ESCAPE is not supported") - @skipIf(SA_VERSION < SA_1_4, "SQLAlchemy 1.3 and earlier do not support native `NOT ILIKE` compilation") + @skipIf( + SA_VERSION < SA_1_4, + "SQLAlchemy 1.3 and earlier do not support native `NOT ILIKE` compilation", + ) def test_select_with_not_ilike_and_escape(self): """ Verify the compiler fails when using CrateDB's native `ILIKE` method together with `ESCAPE`. """ - selectable = self.mytable.select().where(self.mytable.c.name.notilike("%foo%", escape='\\')) + selectable = self.mytable.select().where(self.mytable.c.name.notilike("%foo%", escape="\\")) with self.assertRaises(NotImplementedError) as cmex: selectable.compile(bind=self.crate_engine) self.assertEqual(str(cmex.exception), "Unsupported feature: ESCAPE is not supported") @@ -152,9 +166,13 @@ def test_select_with_offset(self): selectable = self.mytable.select().offset(5) statement = str(selectable.compile(bind=self.crate_engine)) if SA_VERSION >= SA_1_4: - self.assertEqual(statement, "SELECT mytable.name, mytable.data \nFROM mytable\n LIMIT ALL OFFSET ?") + self.assertEqual( + statement, "SELECT mytable.name, mytable.data \nFROM mytable\n LIMIT ALL OFFSET ?" + ) else: - self.assertEqual(statement, "SELECT mytable.name, mytable.data \nFROM mytable \n LIMIT ALL OFFSET ?") + self.assertEqual( + statement, "SELECT mytable.name, mytable.data \nFROM mytable \n LIMIT ALL OFFSET ?" + ) def test_select_with_limit(self): """ @@ -170,7 +188,9 @@ def test_select_with_offset_and_limit(self): """ selectable = self.mytable.select().offset(5).limit(42) statement = str(selectable.compile(bind=self.crate_engine)) - self.assertEqual(statement, "SELECT mytable.name, mytable.data \nFROM mytable \n LIMIT ? OFFSET ?") + self.assertEqual( + statement, "SELECT mytable.name, mytable.data \nFROM mytable \n LIMIT ? OFFSET ?" + ) def test_insert_multivalues(self): """ @@ -202,7 +222,10 @@ def test_insert_multivalues(self): statement = str(insertable.compile(bind=self.crate_engine)) self.assertEqual(statement, "INSERT INTO mytable (name) VALUES (?), (?), (?)") - @skipIf(SA_VERSION < SA_2_0, "SQLAlchemy 1.x does not support the 'insertmanyvalues' dialect feature") + @skipIf( + SA_VERSION < SA_2_0, + "SQLAlchemy 1.x does not support the 'insertmanyvalues' dialect feature", + ) def test_insert_manyvalues(self): """ Verify the `use_insertmanyvalues` and `use_insertmanyvalues_wo_returning` dialect features. @@ -242,19 +265,27 @@ def test_insert_manyvalues(self): statement = str(insertable.compile(bind=self.crate_engine)) self.assertEqual(statement, "INSERT INTO mytable (name, data) VALUES (?, ?)") - with mock.patch("crate.client.http.Client.sql", autospec=True, return_value={"cols": []}) as client_mock: - + with mock.patch( + "crate.client.http.Client.sql", autospec=True, return_value={"cols": []} + ) as client_mock: with self.crate_engine.begin() as conn: # Adjust page size on a per-connection level. conn.execution_options(insertmanyvalues_page_size=batch_size) conn.execute(insertable, parameters=records) # Verify that input data has been batched correctly. - self.assertListEqual(client_mock.mock_calls, [ - mock.call(mock.ANY, 'INSERT INTO mytable (name) VALUES (?), (?)', ('foo_0', 'foo_1'), None), - mock.call(mock.ANY, 'INSERT INTO mytable (name) VALUES (?), (?)', ('foo_2', 'foo_3'), None), - mock.call(mock.ANY, 'INSERT INTO mytable (name) VALUES (?)', ('foo_4', ), None), - ]) + self.assertListEqual( + client_mock.mock_calls, + [ + mock.call( + mock.ANY, "INSERT INTO mytable (name) VALUES (?), (?)", ("foo_0", "foo_1"), None + ), + mock.call( + mock.ANY, "INSERT INTO mytable (name) VALUES (?), (?)", ("foo_2", "foo_3"), None + ), + mock.call(mock.ANY, "INSERT INTO mytable (name) VALUES (?)", ("foo_4",), None), + ], + ) def test_for_update(self): """ @@ -263,7 +294,6 @@ def test_for_update(self): """ with warnings.catch_warnings(record=True) as w: - # By default, warnings from a loop will only be emitted once. # This scenario tests exactly this behaviour, to verify logs # don't get flooded. @@ -281,11 +311,14 @@ def test_for_update(self): # Verify if corresponding warning is emitted, once. self.assertEqual(len(w), 1) self.assertIsSubclass(w[-1].category, UserWarning) - self.assertIn("CrateDB does not support the 'INSERT ... FOR UPDATE' clause, " - "it will be omitted when generating SQL statements.", str(w[-1].message)) + self.assertIn( + "CrateDB does not support the 'INSERT ... FOR UPDATE' clause, " + "it will be omitted when generating SQL statements.", + str(w[-1].message), + ) -FakeCursor = MagicMock(name='FakeCursor', spec=Cursor) +FakeCursor = MagicMock(name="FakeCursor", spec=Cursor) @skipIf(SA_VERSION < SA_1_4, "SQLAlchemy 1.3 suddenly has problems with these test cases") @@ -319,7 +352,7 @@ def execute_wrapper(self, query, *args, **kwargs): return self.fake_cursor -@patch('crate.client.connection.Cursor', FakeCursor) +@patch("crate.client.connection.Cursor", FakeCursor) class SqlAlchemyDDLCompilerTest(CompilerTestCase, ExtraAssertions): """ Verify a few scenarios regarding the DDL compiler. @@ -363,24 +396,28 @@ class ItemStore(Base): root = sa.orm.relationship(RootStore, back_populates="items") with warnings.catch_warnings(record=True) as w: - # Cause all warnings to always be triggered. warnings.simplefilter("always") # Verify SQL DDL statement. self.metadata.create_all(self.engine, tables=[RootStore.__table__], checkfirst=False) - self.assertEqual(self.executed_statement, dedent(""" + self.assertEqual( + self.executed_statement, + dedent(""" CREATE TABLE testdrive.root ( \tid INT NOT NULL, \tname STRING, \tPRIMARY KEY (id) ) - """)) # noqa: W291, W293 + """), + ) # noqa: W291, W293 # Verify SQL DDL statement. self.metadata.create_all(self.engine, tables=[ItemStore.__table__], checkfirst=False) - self.assertEqual(self.executed_statement, dedent(""" + self.assertEqual( + self.executed_statement, + dedent(""" CREATE TABLE testdrive.item ( \tid INT NOT NULL, \tname STRING, @@ -388,13 +425,17 @@ class ItemStore(Base): \tPRIMARY KEY (id) ) - """)) # noqa: W291, W293 + """), + ) # noqa: W291, W293 # Verify if corresponding warning is emitted. self.assertEqual(len(w), 1) self.assertIsSubclass(w[-1].category, UserWarning) - self.assertIn("CrateDB does not support foreign key constraints, " - "they will be omitted when generating DDL statements.", str(w[-1].message)) + self.assertIn( + "CrateDB does not support foreign key constraints, " + "they will be omitted when generating DDL statements.", + str(w[-1].message), + ) def test_ddl_with_unique_key(self): """ @@ -412,26 +453,31 @@ class FooBar(Base): name = sa.Column(sa.String, unique=True) with warnings.catch_warnings(record=True) as w: - # Cause all warnings to always be triggered. warnings.simplefilter("always") # Verify SQL DDL statement. self.metadata.create_all(self.engine, tables=[FooBar.__table__], checkfirst=False) - self.assertEqual(self.executed_statement, dedent(""" + self.assertEqual( + self.executed_statement, + dedent(""" CREATE TABLE testdrive.foobar ( \tid INT NOT NULL, \tname STRING, \tPRIMARY KEY (id) ) - """)) # noqa: W291, W293 + """), + ) # noqa: W291, W293 # Verify if corresponding warning is emitted. self.assertEqual(len(w), 1) self.assertIsSubclass(w[-1].category, UserWarning) - self.assertIn("CrateDB does not support unique constraints, " - "they will be omitted when generating DDL statements.", str(w[-1].message)) + self.assertIn( + "CrateDB does not support unique constraints, " + "they will be omitted when generating DDL statements.", + str(w[-1].message), + ) def test_ddl_with_reserved_words_and_uppercase(self): """ @@ -452,7 +498,9 @@ class FooBar(Base): # Verify SQL DDL statement. self.metadata.create_all(self.engine, tables=[FooBar.__table__], checkfirst=False) - self.assertEqual(self.executed_statement, dedent(""" + self.assertEqual( + self.executed_statement, + dedent(""" CREATE TABLE testdrive.foobar ( \t"ID" INT NOT NULL, \t"index" INT, @@ -461,4 +509,5 @@ class FooBar(Base): \tPRIMARY KEY ("ID") ) - """)) # noqa: W291, W293 + """), + ) # noqa: W291, W293 diff --git a/tests/connection_test.py b/tests/connection_test.py index f1a560e..00adb25 100644 --- a/tests/connection_test.py +++ b/tests/connection_test.py @@ -20,39 +20,40 @@ # software solely pursuant to the terms of the relevant commercial agreement. from unittest import TestCase + import sqlalchemy as sa from sqlalchemy.exc import NoSuchModuleError class SqlAlchemyConnectionTest(TestCase): - def test_connection_server_uri_unknown_sa_plugin(self): with self.assertRaises(NoSuchModuleError): sa.create_engine("foobar://otherhost:19201") def test_default_connection(self): - engine = sa.create_engine('crate://') + engine = sa.create_engine("crate://") conn = engine.raw_connection() - self.assertEqual(">", - repr(conn.driver_connection)) + self.assertEqual( + ">", repr(conn.driver_connection) + ) conn.close() engine.dispose() def test_connection_server_uri_http(self): - engine = sa.create_engine( - "crate://otherhost:19201") + engine = sa.create_engine("crate://otherhost:19201") conn = engine.raw_connection() - self.assertEqual(">", - repr(conn.driver_connection)) + self.assertEqual( + ">", repr(conn.driver_connection) + ) conn.close() engine.dispose() def test_connection_server_uri_https(self): - engine = sa.create_engine( - "crate://otherhost:19201/?ssl=true") + engine = sa.create_engine("crate://otherhost:19201/?ssl=true") conn = engine.raw_connection() - self.assertEqual(">", - repr(conn.driver_connection)) + self.assertEqual( + ">", repr(conn.driver_connection) + ) conn.close() engine.dispose() @@ -62,38 +63,36 @@ def test_connection_server_uri_invalid_port(self): self.assertIn("invalid literal for int() with base 10: 'bar'", str(context.exception)) def test_connection_server_uri_https_with_trusted_user(self): - engine = sa.create_engine( - "crate://foo@otherhost:19201/?ssl=true") + engine = sa.create_engine("crate://foo@otherhost:19201/?ssl=true") conn = engine.raw_connection() - self.assertEqual(">", - repr(conn.driver_connection)) + self.assertEqual( + ">", repr(conn.driver_connection) + ) self.assertEqual(conn.driver_connection.client.username, "foo") self.assertEqual(conn.driver_connection.client.password, None) conn.close() engine.dispose() def test_connection_server_uri_https_with_credentials(self): - engine = sa.create_engine( - "crate://foo:bar@otherhost:19201/?ssl=true") + engine = sa.create_engine("crate://foo:bar@otherhost:19201/?ssl=true") conn = engine.raw_connection() - self.assertEqual(">", - repr(conn.driver_connection)) + self.assertEqual( + ">", repr(conn.driver_connection) + ) self.assertEqual(conn.driver_connection.client.username, "foo") self.assertEqual(conn.driver_connection.client.password, "bar") conn.close() engine.dispose() def test_connection_server_uri_parameter_timeout(self): - engine = sa.create_engine( - "crate://otherhost:19201/?timeout=42.42") + engine = sa.create_engine("crate://otherhost:19201/?timeout=42.42") conn = engine.raw_connection() self.assertEqual(conn.driver_connection.client._pool_kw["timeout"], 42.42) conn.close() engine.dispose() def test_connection_server_uri_parameter_pool_size(self): - engine = sa.create_engine( - "crate://otherhost:19201/?pool_size=20") + engine = sa.create_engine("crate://otherhost:19201/?pool_size=20") conn = engine.raw_connection() self.assertEqual(conn.driver_connection.client._pool_kw["maxsize"], 20) conn.close() @@ -101,29 +100,28 @@ def test_connection_server_uri_parameter_pool_size(self): def test_connection_multiple_server_http(self): engine = sa.create_engine( - "crate://", connect_args={ - 'servers': ['localhost:4201', 'localhost:4202'] - } + "crate://", connect_args={"servers": ["localhost:4201", "localhost:4202"]} ) conn = engine.raw_connection() self.assertEqual( - ">", - repr(conn.driver_connection)) + ">", + repr(conn.driver_connection), + ) conn.close() engine.dispose() def test_connection_multiple_server_https(self): engine = sa.create_engine( - "crate://", connect_args={ - 'servers': ['localhost:4201', 'localhost:4202'], - 'ssl': True, - } + "crate://", + connect_args={ + "servers": ["localhost:4201", "localhost:4202"], + "ssl": True, + }, ) conn = engine.raw_connection() self.assertEqual( - ">", - repr(conn.driver_connection)) + ">", + repr(conn.driver_connection), + ) conn.close() engine.dispose() diff --git a/tests/create_table_test.py b/tests/create_table_test.py index e5bca25..ac4b2dd 100644 --- a/tests/create_table_test.py +++ b/tests/create_table_test.py @@ -20,33 +20,33 @@ # software solely pursuant to the terms of the relevant commercial agreement. import pytest import sqlalchemy as sa + try: from sqlalchemy.orm import declarative_base except ImportError: from sqlalchemy.ext.declarative import declarative_base -from sqlalchemy_cratedb import ObjectType, ObjectArray, Geopoint -from crate.client.cursor import Cursor - from unittest import TestCase -from unittest.mock import patch, MagicMock +from unittest.mock import MagicMock, patch +from crate.client.cursor import Cursor + +from sqlalchemy_cratedb import Geopoint, ObjectArray, ObjectType -fake_cursor = MagicMock(name='fake_cursor') -FakeCursor = MagicMock(name='FakeCursor', spec=Cursor) +fake_cursor = MagicMock(name="fake_cursor") +FakeCursor = MagicMock(name="FakeCursor", spec=Cursor) FakeCursor.return_value = fake_cursor -@patch('crate.client.connection.Cursor', FakeCursor) +@patch("crate.client.connection.Cursor", FakeCursor) class SqlAlchemyCreateTableTest(TestCase): - def setUp(self): - self.engine = sa.create_engine('crate://') + self.engine = sa.create_engine("crate://") self.Base = declarative_base() def test_table_basic_types(self): class User(self.Base): - __tablename__ = 'users' + __tablename__ = "users" string_col = sa.Column(sa.String, primary_key=True) unicode_col = sa.Column(sa.Unicode) text_col = sa.Column(sa.Text) @@ -62,198 +62,232 @@ class User(self.Base): self.Base.metadata.create_all(bind=self.engine) fake_cursor.execute.assert_called_with( - ('\nCREATE TABLE users (\n\tstring_col STRING NOT NULL, ' - '\n\tunicode_col STRING, \n\ttext_col STRING, \n\tint_col INT, ' - '\n\tlong_col1 LONG, \n\tlong_col2 LONG, ' - '\n\tbool_col BOOLEAN, ' - '\n\tshort_col SHORT, ' - '\n\tdatetime_col TIMESTAMP WITHOUT TIME ZONE, ' - '\n\tdate_col TIMESTAMP, ' - '\n\tfloat_col FLOAT, ' - '\n\tdouble_col DOUBLE, ' - '\n\tPRIMARY KEY (string_col)\n)\n\n'), - ()) + ( + "\nCREATE TABLE users (\n\tstring_col STRING NOT NULL, " + "\n\tunicode_col STRING, \n\ttext_col STRING, \n\tint_col INT, " + "\n\tlong_col1 LONG, \n\tlong_col2 LONG, " + "\n\tbool_col BOOLEAN, " + "\n\tshort_col SHORT, " + "\n\tdatetime_col TIMESTAMP WITHOUT TIME ZONE, " + "\n\tdate_col TIMESTAMP, " + "\n\tfloat_col FLOAT, " + "\n\tdouble_col DOUBLE, " + "\n\tPRIMARY KEY (string_col)\n)\n\n" + ), + (), + ) def test_column_obj(self): class DummyTable(self.Base): - __tablename__ = 'dummy' + __tablename__ = "dummy" pk = sa.Column(sa.String, primary_key=True) obj_col = sa.Column(ObjectType) + self.Base.metadata.create_all(bind=self.engine) fake_cursor.execute.assert_called_with( - ('\nCREATE TABLE dummy (\n\tpk STRING NOT NULL, \n\tobj_col OBJECT, ' - '\n\tPRIMARY KEY (pk)\n)\n\n'), - ()) + ( + "\nCREATE TABLE dummy (\n\tpk STRING NOT NULL, \n\tobj_col OBJECT, " + "\n\tPRIMARY KEY (pk)\n)\n\n" + ), + (), + ) def test_table_clustered_by(self): class DummyTable(self.Base): - __tablename__ = 't' - __table_args__ = { - 'crate_clustered_by': 'p' - } + __tablename__ = "t" + __table_args__ = {"crate_clustered_by": "p"} pk = sa.Column(sa.String, primary_key=True) p = sa.Column(sa.String) + self.Base.metadata.create_all(bind=self.engine) fake_cursor.execute.assert_called_with( - ('\nCREATE TABLE t (\n\t' - 'pk STRING NOT NULL, \n\t' - 'p STRING, \n\t' - 'PRIMARY KEY (pk)\n' - ') CLUSTERED BY (p)\n\n'), - ()) + ( + "\nCREATE TABLE t (\n\t" + "pk STRING NOT NULL, \n\t" + "p STRING, \n\t" + "PRIMARY KEY (pk)\n" + ") CLUSTERED BY (p)\n\n" + ), + (), + ) def test_column_computed(self): class DummyTable(self.Base): - __tablename__ = 't' + __tablename__ = "t" ts = sa.Column(sa.BigInteger, primary_key=True) p = sa.Column(sa.BigInteger, sa.Computed("date_trunc('day', ts)")) + self.Base.metadata.create_all(bind=self.engine) fake_cursor.execute.assert_called_with( - ('\nCREATE TABLE t (\n\t' - 'ts LONG NOT NULL, \n\t' - 'p LONG GENERATED ALWAYS AS (date_trunc(\'day\', ts)), \n\t' - 'PRIMARY KEY (ts)\n' - ')\n\n'), - ()) + ( + "\nCREATE TABLE t (\n\t" + "ts LONG NOT NULL, \n\t" + "p LONG GENERATED ALWAYS AS (date_trunc('day', ts)), \n\t" + "PRIMARY KEY (ts)\n" + ")\n\n" + ), + (), + ) def test_column_computed_virtual(self): class DummyTable(self.Base): - __tablename__ = 't' + __tablename__ = "t" ts = sa.Column(sa.BigInteger, primary_key=True) p = sa.Column(sa.BigInteger, sa.Computed("date_trunc('day', ts)", persisted=False)) + with self.assertRaises(sa.exc.CompileError): self.Base.metadata.create_all(bind=self.engine) def test_table_partitioned_by(self): class DummyTable(self.Base): - __tablename__ = 't' - __table_args__ = { - 'crate_partitioned_by': 'p', - 'invalid_option': 1 - } + __tablename__ = "t" + __table_args__ = {"crate_partitioned_by": "p", "invalid_option": 1} pk = sa.Column(sa.String, primary_key=True) p = sa.Column(sa.String) + self.Base.metadata.create_all(bind=self.engine) fake_cursor.execute.assert_called_with( - ('\nCREATE TABLE t (\n\t' - 'pk STRING NOT NULL, \n\t' - 'p STRING, \n\t' - 'PRIMARY KEY (pk)\n' - ') PARTITIONED BY (p)\n\n'), - ()) + ( + "\nCREATE TABLE t (\n\t" + "pk STRING NOT NULL, \n\t" + "p STRING, \n\t" + "PRIMARY KEY (pk)\n" + ") PARTITIONED BY (p)\n\n" + ), + (), + ) def test_table_number_of_shards_and_replicas(self): class DummyTable(self.Base): - __tablename__ = 't' - __table_args__ = { - 'crate_number_of_replicas': '2', - 'crate_number_of_shards': 3 - } + __tablename__ = "t" + __table_args__ = {"crate_number_of_replicas": "2", "crate_number_of_shards": 3} pk = sa.Column(sa.String, primary_key=True) self.Base.metadata.create_all(bind=self.engine) fake_cursor.execute.assert_called_with( - ('\nCREATE TABLE t (\n\t' - 'pk STRING NOT NULL, \n\t' - 'PRIMARY KEY (pk)\n' - ') CLUSTERED INTO 3 SHARDS WITH (number_of_replicas = 2)\n\n'), - ()) + ( + "\nCREATE TABLE t (\n\t" + "pk STRING NOT NULL, \n\t" + "PRIMARY KEY (pk)\n" + ") CLUSTERED INTO 3 SHARDS WITH (number_of_replicas = 2)\n\n" + ), + (), + ) def test_table_clustered_by_and_number_of_shards(self): class DummyTable(self.Base): - __tablename__ = 't' - __table_args__ = { - 'crate_clustered_by': 'p', - 'crate_number_of_shards': 3 - } + __tablename__ = "t" + __table_args__ = {"crate_clustered_by": "p", "crate_number_of_shards": 3} pk = sa.Column(sa.String, primary_key=True) p = sa.Column(sa.String, primary_key=True) + self.Base.metadata.create_all(bind=self.engine) fake_cursor.execute.assert_called_with( - ('\nCREATE TABLE t (\n\t' - 'pk STRING NOT NULL, \n\t' - 'p STRING NOT NULL, \n\t' - 'PRIMARY KEY (pk, p)\n' - ') CLUSTERED BY (p) INTO 3 SHARDS\n\n'), - ()) + ( + "\nCREATE TABLE t (\n\t" + "pk STRING NOT NULL, \n\t" + "p STRING NOT NULL, \n\t" + "PRIMARY KEY (pk, p)\n" + ") CLUSTERED BY (p) INTO 3 SHARDS\n\n" + ), + (), + ) def test_table_translog_durability(self): class DummyTable(self.Base): - __tablename__ = 't' + __tablename__ = "t" __table_args__ = { 'crate_"translog.durability"': "'async'", } pk = sa.Column(sa.String, primary_key=True) + self.Base.metadata.create_all(bind=self.engine) fake_cursor.execute.assert_called_with( - ('\nCREATE TABLE t (\n\t' - 'pk STRING NOT NULL, \n\t' - 'PRIMARY KEY (pk)\n' - """) WITH ("translog.durability" = 'async')\n\n"""), - ()) + ( + "\nCREATE TABLE t (\n\t" + "pk STRING NOT NULL, \n\t" + "PRIMARY KEY (pk)\n" + """) WITH ("translog.durability" = 'async')\n\n""" + ), + (), + ) def test_column_object_array(self): class DummyTable(self.Base): - __tablename__ = 't' + __tablename__ = "t" pk = sa.Column(sa.String, primary_key=True) tags = sa.Column(ObjectArray) self.Base.metadata.create_all(bind=self.engine) fake_cursor.execute.assert_called_with( - ('\nCREATE TABLE t (\n\t' - 'pk STRING NOT NULL, \n\t' - 'tags ARRAY(OBJECT), \n\t' - 'PRIMARY KEY (pk)\n)\n\n'), ()) + ( + "\nCREATE TABLE t (\n\t" + "pk STRING NOT NULL, \n\t" + "tags ARRAY(OBJECT), \n\t" + "PRIMARY KEY (pk)\n)\n\n" + ), + (), + ) def test_column_nullable(self): class DummyTable(self.Base): - __tablename__ = 't' + __tablename__ = "t" pk = sa.Column(sa.String, primary_key=True) a = sa.Column(sa.Integer, nullable=True) b = sa.Column(sa.Integer, nullable=False) self.Base.metadata.create_all(bind=self.engine) fake_cursor.execute.assert_called_with( - ('\nCREATE TABLE t (\n\t' - 'pk STRING NOT NULL, \n\t' - 'a INT, \n\t' - 'b INT NOT NULL, \n\t' - 'PRIMARY KEY (pk)\n)\n\n'), ()) + ( + "\nCREATE TABLE t (\n\t" + "pk STRING NOT NULL, \n\t" + "a INT, \n\t" + "b INT NOT NULL, \n\t" + "PRIMARY KEY (pk)\n)\n\n" + ), + (), + ) def test_column_pk_nullable(self): class DummyTable(self.Base): - __tablename__ = 't' + __tablename__ = "t" pk = sa.Column(sa.String, primary_key=True, nullable=True) + with self.assertRaises(sa.exc.CompileError): self.Base.metadata.create_all(bind=self.engine) def test_column_crate_index(self): class DummyTable(self.Base): - __tablename__ = 't' + __tablename__ = "t" pk = sa.Column(sa.String, primary_key=True) a = sa.Column(sa.Integer, crate_index=False) b = sa.Column(sa.Integer, crate_index=True) self.Base.metadata.create_all(bind=self.engine) fake_cursor.execute.assert_called_with( - ('\nCREATE TABLE t (\n\t' - 'pk STRING NOT NULL, \n\t' - 'a INT INDEX OFF, \n\t' - 'b INT, \n\t' - 'PRIMARY KEY (pk)\n)\n\n'), ()) + ( + "\nCREATE TABLE t (\n\t" + "pk STRING NOT NULL, \n\t" + "a INT INDEX OFF, \n\t" + "b INT, \n\t" + "PRIMARY KEY (pk)\n)\n\n" + ), + (), + ) @pytest.mark.skip("CompileError not raised") def test_column_geopoint_without_index(self): class DummyTable(self.Base): - __tablename__ = 't' + __tablename__ = "t" pk = sa.Column(sa.String, primary_key=True) a = sa.Column(Geopoint, crate_index=False) + with self.assertRaises(sa.exc.CompileError): self.Base.metadata.create_all(bind=self.engine) def test_text_column_without_columnstore(self): class DummyTable(self.Base): - __tablename__ = 't' + __tablename__ = "t" pk = sa.Column(sa.String, primary_key=True) a = sa.Column(sa.String, crate_columnstore=False) b = sa.Column(sa.String, crate_columnstore=True) @@ -262,16 +296,20 @@ class DummyTable(self.Base): self.Base.metadata.create_all(bind=self.engine) fake_cursor.execute.assert_called_with( - ('\nCREATE TABLE t (\n\t' - 'pk STRING NOT NULL, \n\t' - 'a STRING STORAGE WITH (columnstore = false), \n\t' - 'b STRING, \n\t' - 'c STRING, \n\t' - 'PRIMARY KEY (pk)\n)\n\n'), ()) + ( + "\nCREATE TABLE t (\n\t" + "pk STRING NOT NULL, \n\t" + "a STRING STORAGE WITH (columnstore = false), \n\t" + "b STRING, \n\t" + "c STRING, \n\t" + "PRIMARY KEY (pk)\n)\n\n" + ), + (), + ) def test_non_text_column_without_columnstore(self): class DummyTable(self.Base): - __tablename__ = 't' + __tablename__ = "t" pk = sa.Column(sa.String, primary_key=True) a = sa.Column(sa.Integer, crate_columnstore=False) @@ -280,52 +318,68 @@ class DummyTable(self.Base): def test_column_server_default_text_func(self): class DummyTable(self.Base): - __tablename__ = 't' + __tablename__ = "t" pk = sa.Column(sa.String, primary_key=True) a = sa.Column(sa.DateTime, server_default=sa.text("now()")) self.Base.metadata.create_all(bind=self.engine) fake_cursor.execute.assert_called_with( - ('\nCREATE TABLE t (\n\t' - 'pk STRING NOT NULL, \n\t' - 'a TIMESTAMP WITHOUT TIME ZONE DEFAULT now(), \n\t' - 'PRIMARY KEY (pk)\n)\n\n'), ()) + ( + "\nCREATE TABLE t (\n\t" + "pk STRING NOT NULL, \n\t" + "a TIMESTAMP WITHOUT TIME ZONE DEFAULT now(), \n\t" + "PRIMARY KEY (pk)\n)\n\n" + ), + (), + ) def test_column_server_default_string(self): class DummyTable(self.Base): - __tablename__ = 't' + __tablename__ = "t" pk = sa.Column(sa.String, primary_key=True) a = sa.Column(sa.String, server_default="Zaphod") self.Base.metadata.create_all(bind=self.engine) fake_cursor.execute.assert_called_with( - ('\nCREATE TABLE t (\n\t' - 'pk STRING NOT NULL, \n\t' - 'a STRING DEFAULT \'Zaphod\', \n\t' - 'PRIMARY KEY (pk)\n)\n\n'), ()) + ( + "\nCREATE TABLE t (\n\t" + "pk STRING NOT NULL, \n\t" + "a STRING DEFAULT 'Zaphod', \n\t" + "PRIMARY KEY (pk)\n)\n\n" + ), + (), + ) def test_column_server_default_func(self): class DummyTable(self.Base): - __tablename__ = 't' + __tablename__ = "t" pk = sa.Column(sa.String, primary_key=True) a = sa.Column(sa.DateTime, server_default=sa.func.now()) self.Base.metadata.create_all(bind=self.engine) fake_cursor.execute.assert_called_with( - ('\nCREATE TABLE t (\n\t' - 'pk STRING NOT NULL, \n\t' - 'a TIMESTAMP WITHOUT TIME ZONE DEFAULT now(), \n\t' - 'PRIMARY KEY (pk)\n)\n\n'), ()) + ( + "\nCREATE TABLE t (\n\t" + "pk STRING NOT NULL, \n\t" + "a TIMESTAMP WITHOUT TIME ZONE DEFAULT now(), \n\t" + "PRIMARY KEY (pk)\n)\n\n" + ), + (), + ) def test_column_server_default_text_constant(self): class DummyTable(self.Base): - __tablename__ = 't' + __tablename__ = "t" pk = sa.Column(sa.String, primary_key=True) answer = sa.Column(sa.Integer, server_default=sa.text("42")) self.Base.metadata.create_all(bind=self.engine) fake_cursor.execute.assert_called_with( - ('\nCREATE TABLE t (\n\t' - 'pk STRING NOT NULL, \n\t' - 'answer INT DEFAULT 42, \n\t' - 'PRIMARY KEY (pk)\n)\n\n'), ()) + ( + "\nCREATE TABLE t (\n\t" + "pk STRING NOT NULL, \n\t" + "answer INT DEFAULT 42, \n\t" + "PRIMARY KEY (pk)\n)\n\n" + ), + (), + ) diff --git a/tests/datetime_test.py b/tests/datetime_test.py index b4712cb..da0ea4a 100644 --- a/tests/datetime_test.py +++ b/tests/datetime_test.py @@ -23,14 +23,14 @@ import datetime as dt from unittest import TestCase, skipIf -from unittest.mock import patch, MagicMock +from unittest.mock import MagicMock, patch import pytest import sqlalchemy as sa from sqlalchemy.orm import Session, sessionmaker -from sqlalchemy_cratedb import SA_VERSION, SA_1_4 from sqlalchemy_cratedb.dialect import DateTime +from sqlalchemy_cratedb.sa_version import SA_1_4, SA_VERSION try: from sqlalchemy.orm import declarative_base @@ -44,15 +44,16 @@ from crate.client.cursor import Cursor - -fake_cursor = MagicMock(name='fake_cursor') -FakeCursor = MagicMock(name='FakeCursor', spec=Cursor) +fake_cursor = MagicMock(name="fake_cursor") +FakeCursor = MagicMock(name="FakeCursor", spec=Cursor) FakeCursor.return_value = fake_cursor INPUT_DATE = dt.date(2009, 5, 13) INPUT_DATETIME_NOTZ = dt.datetime(2009, 5, 13, 19, 00, 30, 123456) -INPUT_DATETIME_TZ = dt.datetime(2009, 5, 13, 19, 00, 30, 123456, tzinfo=zoneinfo.ZoneInfo("Europe/Kyiv")) +INPUT_DATETIME_TZ = dt.datetime( + 2009, 5, 13, 19, 00, 30, 123456, tzinfo=zoneinfo.ZoneInfo("Europe/Kyiv") +) OUTPUT_DATE = INPUT_DATE OUTPUT_TIMETZ_NOTZ = dt.time(19, 00, 30, 123000) OUTPUT_TIMETZ_TZ = dt.time(16, 00, 30, 123000) @@ -61,34 +62,31 @@ @skipIf(SA_VERSION < SA_1_4, "SQLAlchemy 1.3 suddenly has problems with these test cases") -@patch('crate.client.connection.Cursor', FakeCursor) +@patch("crate.client.connection.Cursor", FakeCursor) class SqlAlchemyDateAndDateTimeTest(TestCase): - def setUp(self): - self.engine = sa.create_engine('crate://') + self.engine = sa.create_engine("crate://") Base = declarative_base() class Character(Base): - __tablename__ = 'characters' + __tablename__ = "characters" name = sa.Column(sa.String, primary_key=True) date = sa.Column(sa.Date) datetime = sa.Column(sa.DateTime) fake_cursor.description = ( - ('characters_name', None, None, None, None, None, None), - ('characters_date', None, None, None, None, None, None) + ("characters_name", None, None, None, None, None, None), + ("characters_date", None, None, None, None, None, None), ) self.session = Session(bind=self.engine) self.Character = Character def test_date_can_handle_datetime(self): - """ date type should also be able to handle iso datetime strings. + """date type should also be able to handle iso datetime strings. this verifies that the fallback in the Date result_processor works. """ - fake_cursor.fetchall.return_value = [ - ('Trillian', '2013-07-16T00:00:00.000Z') - ] + fake_cursor.fetchall.return_value = [("Trillian", "2013-07-16T00:00:00.000Z")] self.session.query(self.Character).first() def test_date_can_handle_tz_aware_datetime(self): @@ -137,8 +135,13 @@ def test_datetime_notz(session): session.execute(sa.text("REFRESH TABLE foobar")) # Query record. - result = session.execute(sa.select( - FooBar.name, FooBar.date, FooBar.datetime_notz, FooBar.datetime_tz)).mappings().first() + result = ( + session.execute( + sa.select(FooBar.name, FooBar.date, FooBar.datetime_notz, FooBar.datetime_tz) + ) + .mappings() + .first() + ) # Compare outcome. assert result["date"] == OUTPUT_DATE @@ -171,8 +174,13 @@ def test_datetime_tz(session): # Query record. session.expunge(foo_item) - result = session.execute(sa.select( - FooBar.name, FooBar.date, FooBar.datetime_notz, FooBar.datetime_tz)).mappings().first() + result = ( + session.execute( + sa.select(FooBar.name, FooBar.date, FooBar.datetime_notz, FooBar.datetime_tz) + ) + .mappings() + .first() + ) # Compare outcome. assert result["date"] == OUTPUT_DATE @@ -209,7 +217,13 @@ def test_datetime_date(session): session.execute(sa.text("REFRESH TABLE foobar")) # Query record. - result = session.execute(sa.select(FooBar.name, FooBar.date, FooBar.datetime_notz, FooBar.datetime_tz)).mappings().first() + result = ( + session.execute( + sa.select(FooBar.name, FooBar.date, FooBar.datetime_notz, FooBar.datetime_tz) + ) + .mappings() + .first() + ) # Compare outcome. assert result["datetime_notz"] == dt.datetime(2009, 5, 13, 0, 0, 0) diff --git a/tests/dialect_test.py b/tests/dialect_test.py index d2213ac..74ae11a 100644 --- a/tests/dialect_test.py +++ b/tests/dialect_test.py @@ -24,34 +24,33 @@ from unittest.mock import MagicMock, patch import sqlalchemy as sa - from crate.client.cursor import Cursor -from sqlalchemy_cratedb import SA_VERSION, ObjectType -from sqlalchemy_cratedb import SA_1_4, SA_2_0 from sqlalchemy import inspect from sqlalchemy.orm import Session + +from sqlalchemy_cratedb import ObjectType +from sqlalchemy_cratedb.sa_version import SA_1_4, SA_2_0, SA_VERSION + try: from sqlalchemy.orm import declarative_base except ImportError: from sqlalchemy.ext.declarative import declarative_base from sqlalchemy.testing import eq_, in_, is_true -FakeCursor = MagicMock(name='FakeCursor', spec=Cursor) +FakeCursor = MagicMock(name="FakeCursor", spec=Cursor) -@patch('crate.client.connection.Cursor', FakeCursor) +@patch("crate.client.connection.Cursor", FakeCursor) class SqlAlchemyDialectTest(TestCase): - def execute_wrapper(self, query, *args, **kwargs): self.executed_statement = query return self.fake_cursor def setUp(self): - - self.fake_cursor = MagicMock(name='fake_cursor') + self.fake_cursor = MagicMock(name="fake_cursor") FakeCursor.return_value = self.fake_cursor - self.engine = sa.create_engine('crate://') + self.engine = sa.create_engine("crate://") self.executed_statement = None @@ -62,7 +61,7 @@ def setUp(self): self.base = declarative_base() class Character(self.base): - __tablename__ = 'characters' + __tablename__ = "characters" name = sa.Column(sa.String, primary_key=True) age = sa.Column(sa.Integer, primary_key=True) @@ -73,9 +72,7 @@ class Character(self.base): def init_mock(self, return_value=None): self.fake_cursor.rowcount = 1 - self.fake_cursor.description = ( - ('foo', None, None, None, None, None, None), - ) + self.fake_cursor.description = (("foo", None, None, None, None, None, None),) self.fake_cursor.fetchall = MagicMock(return_value=return_value) def test_primary_keys_2_3_0(self): @@ -83,12 +80,10 @@ def test_primary_keys_2_3_0(self): self.engine.dialect.server_version_info = (2, 3, 0) self.fake_cursor.rowcount = 3 - self.fake_cursor.description = ( - ('foo', None, None, None, None, None, None), - ) + self.fake_cursor.description = (("foo", None, None, None, None, None, None),) self.fake_cursor.fetchall = MagicMock(return_value=[["id"], ["id2"], ["id3"]]) - eq_(insp.get_pk_constraint("characters")['constrained_columns'], ["id", "id2", "id3"]) + eq_(insp.get_pk_constraint("characters")["constrained_columns"], ["id", "id2", "id3"]) self.fake_cursor.fetchall.assert_called_once_with() in_("information_schema.key_column_usage", self.executed_statement) in_("table_catalog = ?", self.executed_statement) @@ -98,58 +93,60 @@ def test_primary_keys_3_0_0(self): self.engine.dialect.server_version_info = (3, 0, 0) self.fake_cursor.rowcount = 3 - self.fake_cursor.description = ( - ('foo', None, None, None, None, None, None), - ) + self.fake_cursor.description = (("foo", None, None, None, None, None, None),) self.fake_cursor.fetchall = MagicMock(return_value=[["id"], ["id2"], ["id3"]]) - eq_(insp.get_pk_constraint("characters")['constrained_columns'], ["id", "id2", "id3"]) + eq_(insp.get_pk_constraint("characters")["constrained_columns"], ["id", "id2", "id3"]) self.fake_cursor.fetchall.assert_called_once_with() in_("information_schema.key_column_usage", self.executed_statement) in_("table_schema = ?", self.executed_statement) def test_get_table_names(self): self.fake_cursor.rowcount = 1 - self.fake_cursor.description = ( - ('foo', None, None, None, None, None, None), - ) + self.fake_cursor.description = (("foo", None, None, None, None, None, None),) self.fake_cursor.fetchall = MagicMock(return_value=[["t1"], ["t2"]]) insp = inspect(self.session.bind) self.engine.dialect.server_version_info = (2, 0, 0) - eq_(insp.get_table_names(schema="doc"), - ['t1', 't2']) - in_("WHERE table_schema = ? AND table_type = 'BASE TABLE' ORDER BY", self.executed_statement) + eq_(insp.get_table_names(schema="doc"), ["t1", "t2"]) + in_( + "WHERE table_schema = ? AND table_type = 'BASE TABLE' ORDER BY", self.executed_statement + ) def test_get_view_names(self): self.fake_cursor.rowcount = 1 - self.fake_cursor.description = ( - ('foo', None, None, None, None, None, None), - ) + self.fake_cursor.description = (("foo", None, None, None, None, None, None),) self.fake_cursor.fetchall = MagicMock(return_value=[["v1"], ["v2"]]) insp = inspect(self.session.bind) self.engine.dialect.server_version_info = (2, 0, 0) - eq_(insp.get_view_names(schema="doc"), - ['v1', 'v2']) - eq_(self.executed_statement, "SELECT table_name FROM information_schema.views " - "ORDER BY table_name ASC, table_schema ASC") + eq_(insp.get_view_names(schema="doc"), ["v1", "v2"]) + eq_( + self.executed_statement, + "SELECT table_name FROM information_schema.views " + "ORDER BY table_name ASC, table_schema ASC", + ) @skipIf(SA_VERSION < SA_1_4, "Inspector.has_table only available on SQLAlchemy>=1.4") def test_has_table(self): self.init_mock(return_value=[["foo"], ["bar"]]) insp = inspect(self.session.bind) is_true(insp.has_table("bar")) - eq_(self.executed_statement, + eq_( + self.executed_statement, "SELECT table_name FROM information_schema.tables " "WHERE table_schema = ? AND table_type = 'BASE TABLE' " - "ORDER BY table_name ASC, table_schema ASC") + "ORDER BY table_name ASC, table_schema ASC", + ) @skipIf(SA_VERSION < SA_2_0, "Inspector.has_schema only available on SQLAlchemy>=2.0") def test_has_schema(self): self.init_mock( - return_value=[["blob"], ["doc"], ["information_schema"], ["pg_catalog"], ["sys"]]) + return_value=[["blob"], ["doc"], ["information_schema"], ["pg_catalog"], ["sys"]] + ) insp = inspect(self.session.bind) is_true(insp.has_schema("doc")) - eq_(self.executed_statement, - "select schema_name from information_schema.schemata order by schema_name asc") + eq_( + self.executed_statement, + "select schema_name from information_schema.schemata order by schema_name asc", + ) diff --git a/tests/dict_test.py b/tests/dict_test.py index 5f2692c..5769df3 100644 --- a/tests/dict_test.py +++ b/tests/dict_test.py @@ -22,108 +22,91 @@ from __future__ import absolute_import from unittest import TestCase, skipIf -from unittest.mock import patch, MagicMock +from unittest.mock import MagicMock, patch import sqlalchemy as sa -from sqlalchemy.sql import select from sqlalchemy.orm import Session +from sqlalchemy.sql import select + try: from sqlalchemy.orm import declarative_base except ImportError: from sqlalchemy.ext.declarative import declarative_base -from sqlalchemy_cratedb import ObjectArray, ObjectType, SA_VERSION, SA_1_4 from crate.client.cursor import Cursor +from sqlalchemy_cratedb import ObjectArray, ObjectType +from sqlalchemy_cratedb.sa_version import SA_1_4, SA_VERSION -fake_cursor = MagicMock(name='fake_cursor') -FakeCursor = MagicMock(name='FakeCursor', spec=Cursor) +fake_cursor = MagicMock(name="fake_cursor") +FakeCursor = MagicMock(name="FakeCursor", spec=Cursor) FakeCursor.return_value = fake_cursor @skipIf(SA_VERSION < SA_1_4, "SQLAlchemy 1.3 suddenly has problems with these test cases") class SqlAlchemyDictTypeTest(TestCase): - def setUp(self): - self.engine = sa.create_engine('crate://') + self.engine = sa.create_engine("crate://") metadata = sa.MetaData() - self.mytable = sa.Table('mytable', metadata, - sa.Column('name', sa.String), - sa.Column('data', ObjectType)) + self.mytable = sa.Table( + "mytable", metadata, sa.Column("name", sa.String), sa.Column("data", ObjectType) + ) def assertSQL(self, expected_str, selectable): actual_expr = selectable.compile(bind=self.engine) - self.assertEqual(expected_str, str(actual_expr).replace('\n', '')) + self.assertEqual(expected_str, str(actual_expr).replace("\n", "")) def test_select_with_dict_column(self): mytable = self.mytable self.assertSQL( - "SELECT mytable.data['x'] AS anon_1 FROM mytable", - select(mytable.c.data['x']) + "SELECT mytable.data['x'] AS anon_1 FROM mytable", select(mytable.c.data["x"]) ) def test_select_with_dict_column_where_clause(self): mytable = self.mytable - s = select(mytable.c.data).\ - where(mytable.c.data['x'] == 1) - self.assertSQL( - "SELECT mytable.data FROM mytable WHERE mytable.data['x'] = ?", - s - ) + s = select(mytable.c.data).where(mytable.c.data["x"] == 1) + self.assertSQL("SELECT mytable.data FROM mytable WHERE mytable.data['x'] = ?", s) def test_select_with_dict_column_nested_where(self): mytable = self.mytable s = select(mytable.c.name) - s = s.where(mytable.c.data['x']['y'] == 1) - self.assertSQL( - "SELECT mytable.name FROM mytable " + - "WHERE mytable.data['x']['y'] = ?", - s - ) + s = s.where(mytable.c.data["x"]["y"] == 1) + self.assertSQL("SELECT mytable.name FROM mytable " + "WHERE mytable.data['x']['y'] = ?", s) def test_select_with_dict_column_where_clause_gt(self): mytable = self.mytable - s = select(mytable.c.data).\ - where(mytable.c.data['x'] > 1) - self.assertSQL( - "SELECT mytable.data FROM mytable WHERE mytable.data['x'] > ?", - s - ) + s = select(mytable.c.data).where(mytable.c.data["x"] > 1) + self.assertSQL("SELECT mytable.data FROM mytable WHERE mytable.data['x'] > ?", s) def test_select_with_dict_column_where_clause_other_col(self): mytable = self.mytable s = select(mytable.c.name) - s = s.where(mytable.c.data['x'] == mytable.c.name) + s = s.where(mytable.c.data["x"] == mytable.c.name) self.assertSQL( - "SELECT mytable.name FROM mytable " + - "WHERE mytable.data['x'] = mytable.name", - s + "SELECT mytable.name FROM mytable " + "WHERE mytable.data['x'] = mytable.name", s ) def test_update_with_dict_column(self): mytable = self.mytable - stmt = mytable.update().\ - where(mytable.c.name == 'Arthur Dent').\ - values({ - "data['x']": "Trillian" - }) - self.assertSQL( - "UPDATE mytable SET data['x'] = ? WHERE mytable.name = ?", - stmt + stmt = ( + mytable.update() + .where(mytable.c.name == "Arthur Dent") + .values({"data['x']": "Trillian"}) ) + self.assertSQL("UPDATE mytable SET data['x'] = ? WHERE mytable.name = ?", stmt) def set_up_character_and_cursor(self, return_value=None): - return_value = return_value or [('Trillian', {})] + return_value = return_value or [("Trillian", {})] fake_cursor.fetchall.return_value = return_value fake_cursor.description = ( - ('characters_name', None, None, None, None, None, None), - ('characters_data', None, None, None, None, None, None) + ("characters_name", None, None, None, None, None, None), + ("characters_data", None, None, None, None, None, None), ) fake_cursor.rowcount = 1 Base = declarative_base() class Character(Base): - __tablename__ = 'characters' + __tablename__ = "characters" name = sa.Column(sa.String, primary_key=True) age = sa.Column(sa.Integer) data = sa.Column(ObjectType) @@ -134,77 +117,76 @@ class Character(Base): def test_assign_null_to_object_array(self): session, Character = self.set_up_character_and_cursor() - char_1 = Character(name='Trillian', data_list=None) + char_1 = Character(name="Trillian", data_list=None) self.assertIsNone(char_1.data_list) - char_2 = Character(name='Trillian', data_list=1) + char_2 = Character(name="Trillian", data_list=1) self.assertEqual(char_2.data_list, [1]) - char_3 = Character(name='Trillian', data_list=[None]) + char_3 = Character(name="Trillian", data_list=[None]) self.assertEqual(char_3.data_list, [None]) - @patch('crate.client.connection.Cursor', FakeCursor) + @patch("crate.client.connection.Cursor", FakeCursor) def test_assign_to_object_type_after_commit(self): - session, Character = self.set_up_character_and_cursor( - return_value=[('Trillian', None)] - ) - char = Character(name='Trillian') + session, Character = self.set_up_character_and_cursor(return_value=[("Trillian", None)]) + char = Character(name="Trillian") session.add(char) session.commit() - char.data = {'x': 1} + char.data = {"x": 1} self.assertIn(char, session.dirty) session.commit() fake_cursor.execute.assert_called_with( "UPDATE characters SET data = ? WHERE characters.name = ?", - ({'x': 1}, 'Trillian',) + ( + {"x": 1}, + "Trillian", + ), ) - @patch('crate.client.connection.Cursor', FakeCursor) + @patch("crate.client.connection.Cursor", FakeCursor) def test_change_tracking(self): session, Character = self.set_up_character_and_cursor() - char = Character(name='Trillian') + char = Character(name="Trillian") session.add(char) session.commit() try: - char.data['x'] = 1 + char.data["x"] = 1 except Exception: - print(fake_cursor.fetchall.called) - print(fake_cursor.mock_calls) + print(fake_cursor.fetchall.called) # noqa: T201 + print(fake_cursor.mock_calls) # noqa: T201 raise self.assertIn(char, session.dirty) try: session.commit() except Exception: - print(fake_cursor.mock_calls) + print(fake_cursor.mock_calls) # noqa: T201 raise self.assertNotIn(char, session.dirty) - @patch('crate.client.connection.Cursor', FakeCursor) + @patch("crate.client.connection.Cursor", FakeCursor) def test_partial_dict_update(self): session, Character = self.set_up_character_and_cursor() - char = Character(name='Trillian') + char = Character(name="Trillian") session.add(char) session.commit() - char.data['x'] = 1 - char.data['y'] = 2 + char.data["x"] = 1 + char.data["y"] = 2 session.commit() # on python 3 dicts aren't sorted so the order if x or y is updated # first isn't deterministic try: fake_cursor.execute.assert_called_with( - ("UPDATE characters SET data['y'] = ?, data['x'] = ? " - "WHERE characters.name = ?"), - (2, 1, 'Trillian') + ("UPDATE characters SET data['y'] = ?, data['x'] = ? " "WHERE characters.name = ?"), + (2, 1, "Trillian"), ) except AssertionError: fake_cursor.execute.assert_called_with( - ("UPDATE characters SET data['x'] = ?, data['y'] = ? " - "WHERE characters.name = ?"), - (1, 2, 'Trillian') + ("UPDATE characters SET data['x'] = ?, data['y'] = ? " "WHERE characters.name = ?"), + (1, 2, "Trillian"), ) - @patch('crate.client.connection.Cursor', FakeCursor) + @patch("crate.client.connection.Cursor", FakeCursor) def test_partial_dict_update_only_one_key_changed(self): """ If only one attribute of Crate is changed @@ -212,143 +194,123 @@ def test_partial_dict_update_only_one_key_changed(self): not all attributes of Crate. """ session, Character = self.set_up_character_and_cursor( - return_value=[('Trillian', dict(x=1, y=2))] + return_value=[("Trillian", {"x": 1, "y": 2})] ) - char = Character(name='Trillian') - char.data = dict(x=1, y=2) + char = Character(name="Trillian") + char.data = {"x": 1, "y": 2} session.add(char) session.commit() - char.data['y'] = 3 + char.data["y"] = 3 session.commit() fake_cursor.execute.assert_called_with( - ("UPDATE characters SET data['y'] = ? " - "WHERE characters.name = ?"), - (3, 'Trillian') + ("UPDATE characters SET data['y'] = ? " "WHERE characters.name = ?"), (3, "Trillian") ) - @patch('crate.client.connection.Cursor', FakeCursor) + @patch("crate.client.connection.Cursor", FakeCursor) def test_partial_dict_update_with_regular_column(self): session, Character = self.set_up_character_and_cursor() - char = Character(name='Trillian') + char = Character(name="Trillian") session.add(char) session.commit() - char.data['x'] = 1 + char.data["x"] = 1 char.age = 20 session.commit() fake_cursor.execute.assert_called_with( - ("UPDATE characters SET age = ?, data['x'] = ? " - "WHERE characters.name = ?"), - (20, 1, 'Trillian') + ("UPDATE characters SET age = ?, data['x'] = ? " "WHERE characters.name = ?"), + (20, 1, "Trillian"), ) - @patch('crate.client.connection.Cursor', FakeCursor) + @patch("crate.client.connection.Cursor", FakeCursor) def test_partial_dict_update_with_delitem(self): - session, Character = self.set_up_character_and_cursor( - return_value=[('Trillian', {'x': 1})] - ) + session, Character = self.set_up_character_and_cursor(return_value=[("Trillian", {"x": 1})]) - char = Character(name='Trillian') - char.data = {'x': 1} + char = Character(name="Trillian") + char.data = {"x": 1} session.add(char) session.commit() - del char.data['x'] + del char.data["x"] self.assertIn(char, session.dirty) session.commit() fake_cursor.execute.assert_called_with( - ("UPDATE characters SET data['x'] = ? " - "WHERE characters.name = ?"), - (None, 'Trillian') + ("UPDATE characters SET data['x'] = ? " "WHERE characters.name = ?"), (None, "Trillian") ) - @patch('crate.client.connection.Cursor', FakeCursor) + @patch("crate.client.connection.Cursor", FakeCursor) def test_partial_dict_update_with_delitem_setitem(self): - """ test that the change tracking doesn't get messed up + """test that the change tracking doesn't get messed up delitem -> setitem """ - session, Character = self.set_up_character_and_cursor( - return_value=[('Trillian', {'x': 1})] - ) + session, Character = self.set_up_character_and_cursor(return_value=[("Trillian", {"x": 1})]) session = Session(bind=self.engine) - char = Character(name='Trillian') - char.data = {'x': 1} + char = Character(name="Trillian") + char.data = {"x": 1} session.add(char) session.commit() - del char.data['x'] - char.data['x'] = 4 + del char.data["x"] + char.data["x"] = 4 self.assertIn(char, session.dirty) session.commit() fake_cursor.execute.assert_called_with( - ("UPDATE characters SET data['x'] = ? " - "WHERE characters.name = ?"), - (4, 'Trillian') + ("UPDATE characters SET data['x'] = ? " "WHERE characters.name = ?"), (4, "Trillian") ) - @patch('crate.client.connection.Cursor', FakeCursor) + @patch("crate.client.connection.Cursor", FakeCursor) def test_partial_dict_update_with_setitem_delitem(self): - """ test that the change tracking doesn't get messed up + """test that the change tracking doesn't get messed up setitem -> delitem """ - session, Character = self.set_up_character_and_cursor( - return_value=[('Trillian', {'x': 1})] - ) + session, Character = self.set_up_character_and_cursor(return_value=[("Trillian", {"x": 1})]) - char = Character(name='Trillian') - char.data = {'x': 1} + char = Character(name="Trillian") + char.data = {"x": 1} session.add(char) session.commit() - char.data['x'] = 4 - del char.data['x'] + char.data["x"] = 4 + del char.data["x"] self.assertIn(char, session.dirty) session.commit() fake_cursor.execute.assert_called_with( - ("UPDATE characters SET data['x'] = ? " - "WHERE characters.name = ?"), - (None, 'Trillian') + ("UPDATE characters SET data['x'] = ? " "WHERE characters.name = ?"), (None, "Trillian") ) - @patch('crate.client.connection.Cursor', FakeCursor) + @patch("crate.client.connection.Cursor", FakeCursor) def test_partial_dict_update_with_setitem_delitem_setitem(self): - """ test that the change tracking doesn't get messed up + """test that the change tracking doesn't get messed up setitem -> delitem -> setitem """ - session, Character = self.set_up_character_and_cursor( - return_value=[('Trillian', {'x': 1})] - ) + session, Character = self.set_up_character_and_cursor(return_value=[("Trillian", {"x": 1})]) - char = Character(name='Trillian') - char.data = {'x': 1} + char = Character(name="Trillian") + char.data = {"x": 1} session.add(char) session.commit() - char.data['x'] = 4 - del char.data['x'] - char.data['x'] = 3 + char.data["x"] = 4 + del char.data["x"] + char.data["x"] = 3 self.assertIn(char, session.dirty) session.commit() fake_cursor.execute.assert_called_with( - ("UPDATE characters SET data['x'] = ? " - "WHERE characters.name = ?"), - (3, 'Trillian') + ("UPDATE characters SET data['x'] = ? " "WHERE characters.name = ?"), (3, "Trillian") ) def set_up_character_and_cursor_data_list(self, return_value=None): - return_value = return_value or [('Trillian', {})] + return_value = return_value or [("Trillian", {})] fake_cursor.fetchall.return_value = return_value fake_cursor.description = ( - ('characters_name', None, None, None, None, None, None), - ('characters_data_list', None, None, None, None, None, None) - + ("characters_name", None, None, None, None, None, None), + ("characters_data_list", None, None, None, None, None, None), ) fake_cursor.rowcount = 1 Base = declarative_base() class Character(Base): - __tablename__ = 'characters' + __tablename__ = "characters" name = sa.Column(sa.String, primary_key=True) data_list = sa.Column(ObjectArray) @@ -357,48 +319,46 @@ class Character(Base): def _setup_object_array_char(self): session, Character = self.set_up_character_and_cursor_data_list( - return_value=[('Trillian', [{'1': 1}, {'2': 2}])] + return_value=[("Trillian", [{"1": 1}, {"2": 2}])] ) - char = Character(name='Trillian', data_list=[{'1': 1}, {'2': 2}]) + char = Character(name="Trillian", data_list=[{"1": 1}, {"2": 2}]) session.add(char) session.commit() return session, char - @patch('crate.client.connection.Cursor', FakeCursor) + @patch("crate.client.connection.Cursor", FakeCursor) def test_object_array_setitem_change_tracking(self): session, char = self._setup_object_array_char() - char.data_list[1] = {'3': 3} + char.data_list[1] = {"3": 3} self.assertIn(char, session.dirty) session.commit() fake_cursor.execute.assert_called_with( - ("UPDATE characters SET data_list = ? " - "WHERE characters.name = ?"), - ([{'1': 1}, {'3': 3}], 'Trillian') + ("UPDATE characters SET data_list = ? " "WHERE characters.name = ?"), + ([{"1": 1}, {"3": 3}], "Trillian"), ) def _setup_nested_object_char(self): session, Character = self.set_up_character_and_cursor( - return_value=[('Trillian', {'nested': {'x': 1, 'y': {'z': 2}}})] + return_value=[("Trillian", {"nested": {"x": 1, "y": {"z": 2}}})] ) - char = Character(name='Trillian') - char.data = {'nested': {'x': 1, 'y': {'z': 2}}} + char = Character(name="Trillian") + char.data = {"nested": {"x": 1, "y": {"z": 2}}} session.add(char) session.commit() return session, char - @patch('crate.client.connection.Cursor', FakeCursor) + @patch("crate.client.connection.Cursor", FakeCursor) def test_nested_object_change_tracking(self): session, char = self._setup_nested_object_char() char.data["nested"]["x"] = 3 self.assertIn(char, session.dirty) session.commit() fake_cursor.execute.assert_called_with( - ("UPDATE characters SET data['nested'] = ? " - "WHERE characters.name = ?"), - ({'y': {'z': 2}, 'x': 3}, 'Trillian') + ("UPDATE characters SET data['nested'] = ? " "WHERE characters.name = ?"), + ({"y": {"z": 2}, "x": 3}, "Trillian"), ) - @patch('crate.client.connection.Cursor', FakeCursor) + @patch("crate.client.connection.Cursor", FakeCursor) def test_deep_nested_object_change_tracking(self): session, char = self._setup_nested_object_char() # change deep nested object @@ -406,12 +366,11 @@ def test_deep_nested_object_change_tracking(self): self.assertIn(char, session.dirty) session.commit() fake_cursor.execute.assert_called_with( - ("UPDATE characters SET data['nested'] = ? " - "WHERE characters.name = ?"), - ({'y': {'z': 5}, 'x': 1}, 'Trillian') + ("UPDATE characters SET data['nested'] = ? " "WHERE characters.name = ?"), + ({"y": {"z": 5}, "x": 1}, "Trillian"), ) - @patch('crate.client.connection.Cursor', FakeCursor) + @patch("crate.client.connection.Cursor", FakeCursor) def test_delete_nested_object_tracking(self): session, char = self._setup_nested_object_char() # delete nested object @@ -419,42 +378,41 @@ def test_delete_nested_object_tracking(self): self.assertIn(char, session.dirty) session.commit() fake_cursor.execute.assert_called_with( - ("UPDATE characters SET data['nested'] = ? " - "WHERE characters.name = ?"), - ({'y': {}, 'x': 1}, 'Trillian') + ("UPDATE characters SET data['nested'] = ? " "WHERE characters.name = ?"), + ({"y": {}, "x": 1}, "Trillian"), ) - @patch('crate.client.connection.Cursor', FakeCursor) + @patch("crate.client.connection.Cursor", FakeCursor) def test_object_array_append_change_tracking(self): session, char = self._setup_object_array_char() - char.data_list.append({'3': 3}) + char.data_list.append({"3": 3}) self.assertIn(char, session.dirty) - @patch('crate.client.connection.Cursor', FakeCursor) + @patch("crate.client.connection.Cursor", FakeCursor) def test_object_array_insert_change_tracking(self): session, char = self._setup_object_array_char() - char.data_list.insert(0, {'3': 3}) + char.data_list.insert(0, {"3": 3}) self.assertIn(char, session.dirty) - @patch('crate.client.connection.Cursor', FakeCursor) + @patch("crate.client.connection.Cursor", FakeCursor) def test_object_array_slice_change_tracking(self): session, char = self._setup_object_array_char() - char.data_list[:] = [{'3': 3}] + char.data_list[:] = [{"3": 3}] self.assertIn(char, session.dirty) - @patch('crate.client.connection.Cursor', FakeCursor) + @patch("crate.client.connection.Cursor", FakeCursor) def test_object_array_extend_change_tracking(self): session, char = self._setup_object_array_char() - char.data_list.extend([{'3': 3}]) + char.data_list.extend([{"3": 3}]) self.assertIn(char, session.dirty) - @patch('crate.client.connection.Cursor', FakeCursor) + @patch("crate.client.connection.Cursor", FakeCursor) def test_object_array_pop_change_tracking(self): session, char = self._setup_object_array_char() char.data_list.pop() self.assertIn(char, session.dirty) - @patch('crate.client.connection.Cursor', FakeCursor) + @patch("crate.client.connection.Cursor", FakeCursor) def test_object_array_remove_change_tracking(self): session, char = self._setup_object_array_char() item = char.data_list[0] diff --git a/tests/function_test.py b/tests/function_test.py index 072ab43..3c62bff 100644 --- a/tests/function_test.py +++ b/tests/function_test.py @@ -23,6 +23,7 @@ import sqlalchemy as sa from sqlalchemy.sql.sqltypes import TIMESTAMP + try: from sqlalchemy.orm import declarative_base except ImportError: diff --git a/tests/insert_from_select_test.py b/tests/insert_from_select_test.py index ac414bc..7d57e81 100644 --- a/tests/insert_from_select_test.py +++ b/tests/insert_from_select_test.py @@ -20,12 +20,12 @@ # software solely pursuant to the terms of the relevant commercial agreement. from datetime import datetime from unittest import TestCase, skipIf -from unittest.mock import patch, MagicMock +from unittest.mock import MagicMock, patch import sqlalchemy as sa from sqlalchemy.orm import Session -from sqlalchemy_cratedb import SA_VERSION, SA_1_4 +from sqlalchemy_cratedb.sa_version import SA_1_4, SA_VERSION try: from sqlalchemy.orm import declarative_base @@ -34,25 +34,23 @@ from crate.client.cursor import Cursor - -fake_cursor = MagicMock(name='fake_cursor') +fake_cursor = MagicMock(name="fake_cursor") fake_cursor.rowcount = 1 -FakeCursor = MagicMock(name='FakeCursor', spec=Cursor) +FakeCursor = MagicMock(name="FakeCursor", spec=Cursor) FakeCursor.return_value = fake_cursor @skipIf(SA_VERSION < SA_1_4, "SQLAlchemy 1.3 suddenly has problems with these test cases") class SqlAlchemyInsertFromSelectTest(TestCase): - def assertSQL(self, expected_str, actual_expr): - self.assertEqual(expected_str, str(actual_expr).replace('\n', '')) + self.assertEqual(expected_str, str(actual_expr).replace("\n", "")) def setUp(self): - self.engine = sa.create_engine('crate://') + self.engine = sa.create_engine("crate://") Base = declarative_base() class Character(Base): - __tablename__ = 'characters' + __tablename__ = "characters" name = sa.Column(sa.String, primary_key=True) age = sa.Column(sa.Integer) @@ -60,7 +58,7 @@ class Character(Base): status = sa.Column(sa.String) class CharacterArchive(Base): - __tablename__ = 'characters_archive' + __tablename__ = "characters_archive" name = sa.Column(sa.String, primary_key=True) age = sa.Column(sa.Integer) @@ -71,17 +69,20 @@ class CharacterArchive(Base): self.character_archived = CharacterArchive self.session = Session(bind=self.engine) - @patch('crate.client.connection.Cursor', FakeCursor) + @patch("crate.client.connection.Cursor", FakeCursor) def test_insert_from_select_triggered(self): - char = self.character(name='Arthur', status='Archived') + char = self.character(name="Arthur", status="Archived") self.session.add(char) self.session.commit() - sel = sa.select(self.character.name, self.character.age).where(self.character.status == "Archived") - ins = sa.insert(self.character_archived).from_select(['name', 'age'], sel) + sel = sa.select(self.character.name, self.character.age).where( + self.character.status == "Archived" + ) + ins = sa.insert(self.character_archived).from_select(["name", "age"], sel) self.session.execute(ins) self.session.commit() self.assertSQL( - "INSERT INTO characters_archive (name, age) SELECT characters.name, characters.age FROM characters WHERE characters.status = ?", - ins.compile(bind=self.engine) + "INSERT INTO characters_archive (name, age) " + "SELECT characters.name, characters.age FROM characters WHERE characters.status = ?", + ins.compile(bind=self.engine), ) diff --git a/tests/integration.py b/tests/integration.py index 80d155e..941a209 100644 --- a/tests/integration.py +++ b/tests/integration.py @@ -21,15 +21,16 @@ from __future__ import absolute_import +import doctest +import logging import os import sys import unittest -import doctest from pprint import pprint -import logging from crate.client import connect -from sqlalchemy_cratedb import SA_VERSION, SA_2_0 + +from sqlalchemy_cratedb.sa_version import SA_2_0, SA_VERSION from tests.settings import crate_host log = logging.getLogger() @@ -40,30 +41,26 @@ def cprint(s): if isinstance(s, bytes): - s = s.decode('utf-8') - print(s) + s = s.decode("utf-8") + print(s) # noqa: T201 def docs_path(*parts): - return os.path.abspath( - os.path.join( - os.path.dirname(os.path.dirname(__file__)), *parts - ) - ) + return os.path.abspath(os.path.join(os.path.dirname(os.path.dirname(__file__)), *parts)) def provision_database(): - drop_tables() with connect(crate_host) as conn: cursor = conn.cursor() - with open(docs_path('tests/assets/locations.sql')) as s: + with open(docs_path("tests/assets/locations.sql")) as s: stmt = s.read() cursor.execute(stmt) - stmt = ("SELECT COUNT(*) FROM information_schema.tables " - "WHERE table_name = 'locations'") + stmt = ( + "SELECT COUNT(*) FROM information_schema.tables " "WHERE table_name = 'locations'" + ) cursor.execute(stmt) assert cursor.fetchall()[0][0] == 1 @@ -76,8 +73,9 @@ def provision_database(): # refresh location table so imported data is visible immediately cursor.execute("REFRESH TABLE locations") # create blob table - cursor.execute("CREATE BLOB TABLE myfiles CLUSTERED INTO 1 SHARDS " + - "WITH (number_of_replicas=0)") + cursor.execute( + "CREATE BLOB TABLE myfiles CLUSTERED INTO 1 SHARDS " + "WITH (number_of_replicas=0)" + ) # create users cursor.execute("CREATE USER me WITH (password = 'my_secret_pw')") @@ -155,7 +153,7 @@ def _execute_statement(cursor, stmt, on_error="raise"): # FIXME: Why does this croak on statements like ``DROP TABLE cities``? # Note: When needing to debug the test environment, you may want to # enable this logger statement. - # log.exception("Executing SQL statement failed") + # log.exception("Executing SQL statement failed") # noqa: ERA001 if on_error == "ignore": pass elif on_error == "raise": @@ -163,12 +161,11 @@ def _execute_statement(cursor, stmt, on_error="raise"): def setUp(test): - provision_database() - test.globs['crate_host'] = crate_host - test.globs['pprint'] = pprint - test.globs['print'] = cprint + test.globs["crate_host"] = crate_host + test.globs["pprint"] = pprint + test.globs["print"] = cprint def tearDown(test): @@ -177,21 +174,21 @@ def tearDown(test): def create_test_suite(): suite = unittest.TestSuite() - flags = (doctest.NORMALIZE_WHITESPACE | doctest.ELLIPSIS) + flags = doctest.NORMALIZE_WHITESPACE | doctest.ELLIPSIS sqlalchemy_integration_tests = [ - 'docs/getting-started.rst', - 'docs/crud.rst', - 'docs/working-with-types.rst', - 'docs/advanced-querying.rst', - 'docs/inspection-reflection.rst', + "docs/getting-started.rst", + "docs/crud.rst", + "docs/working-with-types.rst", + "docs/advanced-querying.rst", + "docs/inspection-reflection.rst", ] # Don't run DataFrame integration tests on SQLAlchemy 1.3 and Python 3.7. skip_dataframe = SA_VERSION < SA_2_0 or sys.version_info < (3, 8) if not skip_dataframe: sqlalchemy_integration_tests += [ - 'docs/dataframe.rst', + "docs/dataframe.rst", ] s = doctest.DocFileSuite( @@ -200,7 +197,7 @@ def create_test_suite(): setUp=setUp, tearDown=tearDown, optionflags=flags, - encoding='utf-8' + encoding="utf-8", ) suite.addTest(s) diff --git a/tests/match_test.py b/tests/match_test.py index 048e590..0ab2cb4 100644 --- a/tests/match_test.py +++ b/tests/match_test.py @@ -25,40 +25,40 @@ import sqlalchemy as sa from sqlalchemy.orm import Session + try: from sqlalchemy.orm import declarative_base except ImportError: from sqlalchemy.ext.declarative import declarative_base -from sqlalchemy_cratedb import ObjectType -from sqlalchemy_cratedb.predicate import match from crate.client.cursor import Cursor +from sqlalchemy_cratedb import ObjectType +from sqlalchemy_cratedb.predicate import match -fake_cursor = MagicMock(name='fake_cursor') -FakeCursor = MagicMock(name='FakeCursor', spec=Cursor) +fake_cursor = MagicMock(name="fake_cursor") +FakeCursor = MagicMock(name="FakeCursor", spec=Cursor) FakeCursor.return_value = fake_cursor class SqlAlchemyMatchTest(TestCase): - def setUp(self): - self.engine = sa.create_engine('crate://') + self.engine = sa.create_engine("crate://") metadata = sa.MetaData() - self.quotes = sa.Table('quotes', metadata, - sa.Column('author', sa.String), - sa.Column('quote', sa.String)) + self.quotes = sa.Table( + "quotes", metadata, sa.Column("author", sa.String), sa.Column("quote", sa.String) + ) self.session, self.Character = self.set_up_character_and_session() self.maxDiff = None def assertSQL(self, expected_str, actual_expr): - self.assertEqual(expected_str, str(actual_expr).replace('\n', '')) + self.assertEqual(expected_str, str(actual_expr).replace("\n", "")) def set_up_character_and_session(self): Base = declarative_base() class Character(Base): - __tablename__ = 'characters' + __tablename__ = "characters" name = sa.Column(sa.String, primary_key=True) info = sa.Column(ObjectType) @@ -66,72 +66,76 @@ class Character(Base): return session, Character def test_simple_match(self): - query = self.session.query(self.Character.name) \ - .filter(match(self.Character.name, 'Trillian')) + query = self.session.query(self.Character.name).filter( + match(self.Character.name, "Trillian") + ) self.assertSQL( - "SELECT characters.name AS characters_name FROM characters " + - "WHERE match(characters.name, ?)", - query + "SELECT characters.name AS characters_name FROM characters " + + "WHERE match(characters.name, ?)", + query, ) def test_match_boost(self): - query = self.session.query(self.Character.name) \ - .filter(match({self.Character.name: 0.5}, 'Trillian')) + query = self.session.query(self.Character.name).filter( + match({self.Character.name: 0.5}, "Trillian") + ) self.assertSQL( - "SELECT characters.name AS characters_name FROM characters " + - "WHERE match((characters.name 0.5), ?)", - query + "SELECT characters.name AS characters_name FROM characters " + + "WHERE match((characters.name 0.5), ?)", + query, ) def test_muli_match(self): - query = self.session.query(self.Character.name) \ - .filter(match({self.Character.name: 0.5, - self.Character.info['race']: 0.9}, - 'Trillian')) + query = self.session.query(self.Character.name).filter( + match({self.Character.name: 0.5, self.Character.info["race"]: 0.9}, "Trillian") + ) self.assertSQL( - "SELECT characters.name AS characters_name FROM characters " + - "WHERE match(" + - "(characters.info['race'] 0.9, characters.name 0.5), ?" + - ")", - query + "SELECT characters.name AS characters_name FROM characters " + + "WHERE match(" + + "(characters.info['race'] 0.9, characters.name 0.5), ?" + + ")", + query, ) def test_match_type_options(self): - query = self.session.query(self.Character.name) \ - .filter(match({self.Character.name: 0.5, - self.Character.info['race']: 0.9}, - 'Trillian', - match_type='phrase', - options={'fuzziness': 3, 'analyzer': 'english'})) + query = self.session.query(self.Character.name).filter( + match( + {self.Character.name: 0.5, self.Character.info["race"]: 0.9}, + "Trillian", + match_type="phrase", + options={"fuzziness": 3, "analyzer": "english"}, + ) + ) self.assertSQL( - "SELECT characters.name AS characters_name FROM characters " + - "WHERE match(" + - "(characters.info['race'] 0.9, characters.name 0.5), ?" + - ") using phrase with (analyzer=english, fuzziness=3)", - query + "SELECT characters.name AS characters_name FROM characters " + + "WHERE match(" + + "(characters.info['race'] 0.9, characters.name 0.5), ?" + + ") using phrase with (analyzer=english, fuzziness=3)", + query, ) def test_score(self): - query = self.session.query(self.Character.name, - sa.literal_column('_score')) \ - .filter(match(self.Character.name, 'Trillian')) + query = self.session.query(self.Character.name, sa.literal_column("_score")).filter( + match(self.Character.name, "Trillian") + ) self.assertSQL( - "SELECT characters.name AS characters_name, _score " + - "FROM characters WHERE match(characters.name, ?)", - query + "SELECT characters.name AS characters_name, _score " + + "FROM characters WHERE match(characters.name, ?)", + query, ) def test_options_without_type(self): query = self.session.query(self.Character.name).filter( - match({self.Character.name: 0.5, self.Character.info['race']: 0.9}, - 'Trillian', - options={'boost': 10.0}) + match( + {self.Character.name: 0.5, self.Character.info["race"]: 0.9}, + "Trillian", + options={"boost": 10.0}, + ) ) err = None try: str(query) except ValueError as e: err = e - msg = "missing match_type. " + \ - "It's not allowed to specify options without match_type" + msg = "missing match_type. " + "It's not allowed to specify options without match_type" self.assertEqual(str(err), msg) diff --git a/tests/query_caching.py b/tests/query_caching.py index 16a7582..34d0180 100644 --- a/tests/query_caching.py +++ b/tests/query_caching.py @@ -20,14 +20,16 @@ # software solely pursuant to the terms of the relevant commercial agreement. from __future__ import absolute_import + from unittest import TestCase, skipIf import sqlalchemy as sa +from crate.testing.settings import crate_host from sqlalchemy.orm import Session from sqlalchemy.sql.operators import eq -from sqlalchemy_cratedb import SA_VERSION, SA_1_4, ObjectArray, ObjectType -from crate.testing.settings import crate_host +from sqlalchemy_cratedb import ObjectArray, ObjectType +from sqlalchemy_cratedb.sa_version import SA_1_4, SA_VERSION try: from sqlalchemy.orm import declarative_base @@ -36,7 +38,6 @@ class SqlAlchemyQueryCompilationCaching(TestCase): - def setUp(self): self.engine = sa.create_engine(f"crate://{crate_host}") self.metadata = sa.MetaData(schema="testdrive") @@ -50,7 +51,7 @@ def setup_entity(self): Base = declarative_base(metadata=self.metadata) class Character(Base): - __tablename__ = 'characters' + __tablename__ = "characters" name = sa.Column(sa.String, primary_key=True) age = sa.Column(sa.Integer) data = sa.Column(ObjectType) @@ -66,8 +67,8 @@ def setup_data(self): self.metadata.create_all(self.engine) Character = self.Character - char1 = Character(name='Trillian', data={'x': 1}, data_list=[{'foo': 1, 'bar': 10}]) - char2 = Character(name='Slartibartfast', data={'y': 2}, data_list=[{'bar': 2}]) + char1 = Character(name="Trillian", data={"x": 1}, data_list=[{"foo": 1, "bar": 10}]) + char2 = Character(name="Slartibartfast", data={"y": 2}, data_list=[{"bar": 2}]) self.session.add(char1) self.session.add(char2) self.session.commit() @@ -89,11 +90,11 @@ def test_object_multiple_select_legacy(self): self.setup_data() Character = self.Character - selectable = sa.select(Character).where(Character.data['x'] == 1) + selectable = sa.select(Character).where(Character.data["x"] == 1) result = self.session.execute(selectable).scalar_one().data self.assertEqual({"x": 1}, result) - selectable = sa.select(Character).where(Character.data['y'] == 2) + selectable = sa.select(Character).where(Character.data["y"] == 2) result = self.session.execute(selectable).scalar_one().data self.assertEqual({"y": 2}, result) @@ -113,11 +114,11 @@ def test_object_multiple_select_modern(self): self.setup_data() Character = self.Character - selectable = sa.select(Character).where(Character.data['x'].as_integer() == 1) + selectable = sa.select(Character).where(Character.data["x"].as_integer() == 1) result = self.session.execute(selectable).scalar_one().data self.assertEqual({"x": 1}, result) - selectable = sa.select(Character).where(Character.data['y'].as_integer() == 2) + selectable = sa.select(Character).where(Character.data["y"].as_integer() == 2) result = self.session.execute(selectable).scalar_one().data self.assertEqual({"y": 2}, result) @@ -132,10 +133,10 @@ def test_objectarray_multiple_select(self): self.setup_data() Character = self.Character - selectable = sa.select(Character).where(Character.data_list['foo'].any(1, operator=eq)) + selectable = sa.select(Character).where(Character.data_list["foo"].any(1, operator=eq)) result = self.session.execute(selectable).scalar_one().data self.assertEqual({"x": 1}, result) - selectable = sa.select(Character).where(Character.data_list['bar'].any(2, operator=eq)) + selectable = sa.select(Character).where(Character.data_list["bar"].any(2, operator=eq)) result = self.session.execute(selectable).scalar_one().data self.assertEqual({"y": 2}, result) diff --git a/tests/settings.py b/tests/settings.py index 2f0b0f1..103a42b 100644 --- a/tests/settings.py +++ b/tests/settings.py @@ -21,8 +21,7 @@ # software solely pursuant to the terms of the relevant commercial agreement. from __future__ import absolute_import - crate_port = 4200 -localhost = '127.0.0.1' +localhost = "127.0.0.1" crate_host = "{host}:{port}".format(host=localhost, port=crate_port) crate_uri = "http://%s" % crate_host diff --git a/tests/test_support_pandas.py b/tests/test_support_pandas.py index d00aae6..47fe9c7 100644 --- a/tests/test_support_pandas.py +++ b/tests/test_support_pandas.py @@ -2,11 +2,10 @@ import sys import pytest -from sqlalchemy.exc import ProgrammingError - from pueblo.testing.pandas import makeTimeDataFrame +from sqlalchemy.exc import ProgrammingError -from sqlalchemy_cratedb import SA_VERSION, SA_2_0 +from sqlalchemy_cratedb.sa_version import SA_2_0, SA_VERSION from sqlalchemy_cratedb.support.pandas import table_kwargs TABLE_NAME = "foobar" @@ -17,8 +16,12 @@ df["time"] = df.index -@pytest.mark.skipif(sys.version_info < (3, 8), reason="Feature not supported on Python 3.7 and earlier") -@pytest.mark.skipif(SA_VERSION < SA_2_0, reason="Feature not supported on SQLAlchemy 1.4 and earlier") +@pytest.mark.skipif( + sys.version_info < (3, 8), reason="Feature not supported on Python 3.7 and earlier" +) +@pytest.mark.skipif( + SA_VERSION < SA_2_0, reason="Feature not supported on SQLAlchemy 1.4 and earlier" +) def test_table_kwargs_partitioned_by(cratedb_service): """ Validate adding CrateDB dialect table option `PARTITIONED BY` at runtime. @@ -49,8 +52,12 @@ def test_table_kwargs_partitioned_by(cratedb_service): assert 'PARTITIONED BY ("time")' in ddl[0][0] -@pytest.mark.skipif(sys.version_info < (3, 8), reason="Feature not supported on Python 3.7 and earlier") -@pytest.mark.skipif(SA_VERSION < SA_2_0, reason="Feature not supported on SQLAlchemy 1.4 and earlier") +@pytest.mark.skipif( + sys.version_info < (3, 8), reason="Feature not supported on Python 3.7 and earlier" +) +@pytest.mark.skipif( + SA_VERSION < SA_2_0, reason="Feature not supported on SQLAlchemy 1.4 and earlier" +) def test_table_kwargs_translog_durability(cratedb_service): """ Validate adding CrateDB dialect table option `translog.durability` at runtime. @@ -81,8 +88,12 @@ def test_table_kwargs_translog_durability(cratedb_service): assert """"translog.durability" = 'ASYNC'""" in ddl[0][0] -@pytest.mark.skipif(sys.version_info < (3, 8), reason="Feature not supported on Python 3.7 and earlier") -@pytest.mark.skipif(SA_VERSION < SA_2_0, reason="Feature not supported on SQLAlchemy 1.4 and earlier") +@pytest.mark.skipif( + sys.version_info < (3, 8), reason="Feature not supported on Python 3.7 and earlier" +) +@pytest.mark.skipif( + SA_VERSION < SA_2_0, reason="Feature not supported on SQLAlchemy 1.4 and earlier" +) def test_table_kwargs_unknown(cratedb_service): """ Validate behaviour when adding an unknown CrateDB dialect table option. @@ -96,5 +107,9 @@ def test_table_kwargs_unknown(cratedb_service): if_exists="replace", index=False, ) - assert ex.match(re.escape('SQLParseException[Invalid property "unknown_option" ' - 'passed to [ALTER | CREATE] TABLE statement]')) + assert ex.match( + re.escape( + 'SQLParseException[Invalid property "unknown_option" ' + "passed to [ALTER | CREATE] TABLE statement]" + ) + ) diff --git a/tests/test_support_polyfill.py b/tests/test_support_polyfill.py index d495fee..d5f39cc 100644 --- a/tests/test_support_polyfill.py +++ b/tests/test_support_polyfill.py @@ -6,17 +6,23 @@ from sqlalchemy.exc import IntegrityError from sqlalchemy.orm import sessionmaker -from sqlalchemy_cratedb import SA_VERSION, SA_1_4 +from sqlalchemy_cratedb.sa_version import SA_1_4, SA_VERSION try: from sqlalchemy.orm import declarative_base except ImportError: from sqlalchemy.ext.declarative import declarative_base -from sqlalchemy_cratedb.support import check_uniqueness_factory, patch_autoincrement_timestamp, refresh_after_dml +from sqlalchemy_cratedb.support import ( + check_uniqueness_factory, + patch_autoincrement_timestamp, + refresh_after_dml, +) -@pytest.mark.skipif(SA_VERSION < SA_1_4, reason="Test case not supported on SQLAlchemy 1.3 and earlier") +@pytest.mark.skipif( + SA_VERSION < SA_1_4, reason="Test case not supported on SQLAlchemy 1.3 and earlier" +) def test_autoincrement_timestamp(cratedb_service): """ Validate autoincrement columns using `sa.DateTime` columns. @@ -31,7 +37,7 @@ def test_autoincrement_timestamp(cratedb_service): # Define DDL. class FooBar(Base): - __tablename__ = 'foobar' + __tablename__ = "foobar" id = sa.Column(sa.String, primary_key=True) date = sa.Column(sa.DateTime, autoincrement=True) number = sa.Column(sa.BigInteger, autoincrement=True) @@ -47,7 +53,9 @@ class FooBar(Base): session.execute(sa.text("REFRESH TABLE foobar")) # Query record. - result = session.execute(sa.select(FooBar.date, FooBar.number, FooBar.string)).mappings().first() + result = ( + session.execute(sa.select(FooBar.date, FooBar.number, FooBar.string)).mappings().first() + ) # Compare outcome. assert result["date"].year == dt.datetime.now().year @@ -55,7 +63,9 @@ class FooBar(Base): assert result["string"] >= "1718846016235" -@pytest.mark.skipif(SA_VERSION < SA_1_4, reason="Feature not supported on SQLAlchemy 1.3 and earlier") +@pytest.mark.skipif( + SA_VERSION < SA_1_4, reason="Feature not supported on SQLAlchemy 1.3 and earlier" +) def test_check_uniqueness_factory(cratedb_service): """ Validate basic synthetic UNIQUE constraints. @@ -69,7 +79,7 @@ def test_check_uniqueness_factory(cratedb_service): # Define DDL. class FooBar(Base): - __tablename__ = 'foobar' + __tablename__ = "foobar" id = sa.Column(sa.String, primary_key=True) name = sa.Column(sa.String) @@ -93,7 +103,9 @@ class FooBar(Base): assert ex.match("DuplicateKeyException in table 'foobar' on constraint 'name'") -@pytest.mark.skipif(SA_VERSION < SA_1_4, reason="Feature not supported on SQLAlchemy 1.3 and earlier") +@pytest.mark.skipif( + SA_VERSION < SA_1_4, reason="Feature not supported on SQLAlchemy 1.3 and earlier" +) @pytest.mark.parametrize("mode", ["engine", "session"]) def test_refresh_after_dml(cratedb_service, mode): """ @@ -115,7 +127,7 @@ def test_refresh_after_dml(cratedb_service, mode): # Define DDL. class FooBar(Base): - __tablename__ = 'foobar' + __tablename__ = "foobar" id = sa.Column(sa.String, primary_key=True) Base.metadata.drop_all(engine, checkfirst=True) @@ -131,7 +143,9 @@ class FooBar(Base): result = query.first() # Sanity checks. - assert result is not None, "Database result is empty. Most probably, `REFRESH TABLE` wasn't issued." + assert ( + result is not None + ), "Database result is empty. Most probably, `REFRESH TABLE` wasn't issued." # Compare outcome. assert result[0] == "foo" diff --git a/tests/update_test.py b/tests/update_test.py index 8a2f139..9a06511 100644 --- a/tests/update_test.py +++ b/tests/update_test.py @@ -20,12 +20,14 @@ # software solely pursuant to the terms of the relevant commercial agreement. from datetime import datetime from unittest import TestCase, skipIf -from unittest.mock import patch, MagicMock - -from sqlalchemy_cratedb import ObjectType, SA_VERSION, SA_1_4 +from unittest.mock import MagicMock, patch import sqlalchemy as sa from sqlalchemy.orm import Session + +from sqlalchemy_cratedb import ObjectType +from sqlalchemy_cratedb.sa_version import SA_1_4, SA_VERSION + try: from sqlalchemy.orm import declarative_base except ImportError: @@ -33,22 +35,20 @@ from crate.client.cursor import Cursor - -fake_cursor = MagicMock(name='fake_cursor') +fake_cursor = MagicMock(name="fake_cursor") fake_cursor.rowcount = 1 -FakeCursor = MagicMock(name='FakeCursor', spec=Cursor) +FakeCursor = MagicMock(name="FakeCursor", spec=Cursor) FakeCursor.return_value = fake_cursor @skipIf(SA_VERSION < SA_1_4, "SQLAlchemy 1.3 suddenly has problems with these test cases") class SqlAlchemyUpdateTest(TestCase): - def setUp(self): - self.engine = sa.create_engine('crate://') + self.engine = sa.create_engine("crate://") self.base = declarative_base() class Character(self.base): - __tablename__ = 'characters' + __tablename__ = "characters" name = sa.Column(sa.String, primary_key=True) age = sa.Column(sa.Integer) @@ -58,58 +58,58 @@ class Character(self.base): self.character = Character self.session = Session(bind=self.engine) - @patch('crate.client.connection.Cursor', FakeCursor) + @patch("crate.client.connection.Cursor", FakeCursor) def test_onupdate_is_triggered(self): - char = self.character(name='Arthur') + char = self.character(name="Arthur") self.session.add(char) self.session.commit() now = datetime.utcnow() - fake_cursor.fetchall.return_value = [('Arthur', None)] + fake_cursor.fetchall.return_value = [("Arthur", None)] fake_cursor.description = ( - ('characters_name', None, None, None, None, None, None), - ('characters_ts', None, None, None, None, None, None), + ("characters_name", None, None, None, None, None, None), + ("characters_ts", None, None, None, None, None, None), ) char.age = 40 self.session.commit() - expected_stmt = ("UPDATE characters SET age = ?, " - "ts = ? WHERE characters.name = ?") + expected_stmt = "UPDATE characters SET age = ?, " "ts = ? WHERE characters.name = ?" args, kwargs = fake_cursor.execute.call_args stmt = args[0] args = args[1] self.assertEqual(expected_stmt, stmt) self.assertEqual(40, args[0]) - dt = datetime.strptime(args[1], '%Y-%m-%dT%H:%M:%S.%f') + dt = datetime.strptime(args[1], "%Y-%m-%dT%H:%M:%S.%f") self.assertIsInstance(dt, datetime) self.assertGreater(dt, now) - self.assertEqual('Arthur', args[2]) + self.assertEqual("Arthur", args[2]) - @patch('crate.client.connection.Cursor', FakeCursor) + @patch("crate.client.connection.Cursor", FakeCursor) def test_bulk_update(self): """ - Checks whether bulk updates work correctly - on native types and Crate types. + Checks whether bulk updates work correctly + on native types and Crate types. """ before_update_time = datetime.utcnow() - self.session.query(self.character).update({ - # change everyone's name to Julia - self.character.name: 'Julia', - self.character.obj: {'favorite_book': 'Romeo & Juliet'} - }) + self.session.query(self.character).update( + { + # change everyone's name to Julia + self.character.name: "Julia", + self.character.obj: {"favorite_book": "Romeo & Juliet"}, + } + ) self.session.commit() - expected_stmt = ("UPDATE characters SET " - "name = ?, obj = ?, ts = ?") + expected_stmt = "UPDATE characters SET " "name = ?, obj = ?, ts = ?" args, kwargs = fake_cursor.execute.call_args stmt = args[0] args = args[1] self.assertEqual(expected_stmt, stmt) - self.assertEqual('Julia', args[0]) - self.assertEqual({'favorite_book': 'Romeo & Juliet'}, args[1]) - dt = datetime.strptime(args[2], '%Y-%m-%dT%H:%M:%S.%f') + self.assertEqual("Julia", args[0]) + self.assertEqual({"favorite_book": "Romeo & Juliet"}, args[1]) + dt = datetime.strptime(args[2], "%Y-%m-%dT%H:%M:%S.%f") self.assertIsInstance(dt, datetime) self.assertGreater(dt, before_update_time) diff --git a/tests/util.py b/tests/util.py index 4acc7a0..cac408e 100644 --- a/tests/util.py +++ b/tests/util.py @@ -15,12 +15,10 @@ def assertIsSubclass(self, cls, superclass, msg=None): r = issubclass(cls, superclass) except TypeError: if not isinstance(cls, type): - self.fail(self._formatMessage(msg, - '%r is not a class' % (cls,))) + self.fail(self._formatMessage(msg, "%r is not a class" % (cls,))) raise if not r: - self.fail(self._formatMessage(msg, - '%r is not a subclass of %r' % (cls, superclass))) + self.fail(self._formatMessage(msg, "%r is not a subclass of %r" % (cls, superclass))) class ParametrizedTestCase(unittest.TestCase): @@ -30,14 +28,15 @@ class ParametrizedTestCase(unittest.TestCase): https://eli.thegreenplace.net/2011/08/02/python-unit-testing-parametrized-test-cases """ + def __init__(self, methodName="runTest", param=None): super(ParametrizedTestCase, self).__init__(methodName) self.param = param @staticmethod def parametrize(testcase_klass, param=None): - """ Create a suite containing all tests taken from the given - subclass, passing them the parameter 'param'. + """Create a suite containing all tests taken from the given + subclass, passing them the parameter 'param'. """ testloader = unittest.TestLoader() testnames = testloader.getTestCaseNames(testcase_klass) diff --git a/tests/vector_test.py b/tests/vector_test.py index 6a564d7..dcd5329 100644 --- a/tests/vector_test.py +++ b/tests/vector_test.py @@ -38,8 +38,8 @@ from crate.client.cursor import Cursor -from sqlalchemy_cratedb import SA_VERSION, SA_1_4 from sqlalchemy_cratedb import FloatVector, knn_match +from sqlalchemy_cratedb.sa_version import SA_1_4, SA_VERSION from sqlalchemy_cratedb.type.vector import from_db, to_db fake_cursor = MagicMock(name="fake_cursor") @@ -48,7 +48,10 @@ if SA_VERSION < SA_1_4: - pytest.skip(reason="The FloatVector type is not supported on SQLAlchemy 1.3 and earlier", allow_module_level=True) + pytest.skip( + reason="The FloatVector type is not supported on SQLAlchemy 1.3 and earlier", + allow_module_level=True, + ) @patch("crate.client.connection.Cursor", FakeCursor) @@ -56,6 +59,7 @@ class SqlAlchemyVectorTypeTest(TestCase): """ Verify compilation of SQL statements where the schema includes the `FloatVector` type. """ + def setUp(self): self.engine = sa.create_engine("crate://") metadata = sa.MetaData() @@ -68,23 +72,17 @@ def setUp(self): self.session = Session(bind=self.engine) def assertSQL(self, expected_str, actual_expr): - self.assertEqual(expected_str, str(actual_expr).replace('\n', '')) + self.assertEqual(expected_str, str(actual_expr).replace("\n", "")) def test_create_invoke(self): self.table.create(self.engine) fake_cursor.execute.assert_called_with( - ( - "\nCREATE TABLE testdrive (\n\t" - "name STRING, \n\t" - "data FLOAT_VECTOR(3)\n)\n\n" - ), + ("\nCREATE TABLE testdrive (\n\t" "name STRING, \n\t" "data FLOAT_VECTOR(3)\n)\n\n"), (), ) def test_insert_invoke(self): - stmt = self.table.insert().values( - name="foo", data=[42.42, 43.43, 44.44] - ) + stmt = self.table.insert().values(name="foo", data=[42.42, 43.43, 44.44]) with self.engine.connect() as conn: conn.execute(stmt) fake_cursor.execute.assert_called_with( @@ -102,16 +100,16 @@ def test_select_invoke(self): ) def test_sql_select(self): - self.assertSQL( - "SELECT testdrive.data FROM testdrive", select(self.table.c.data) - ) + self.assertSQL("SELECT testdrive.data FROM testdrive", select(self.table.c.data)) def test_sql_match(self): - query = self.session.query(self.table.c.name) \ - .filter(knn_match(self.table.c.data, [42.42, 43.43], 3)) + query = self.session.query(self.table.c.name).filter( + knn_match(self.table.c.data, [42.42, 43.43], 3) + ) self.assertSQL( - "SELECT testdrive.name AS testdrive_name FROM testdrive WHERE KNN_MATCH(testdrive.data, ?, ?)", - query + "SELECT testdrive.name AS testdrive_name " + "FROM testdrive WHERE KNN_MATCH(testdrive.data, ?, ?)", + query, ) @@ -121,8 +119,8 @@ def test_from_db_success(): """ np = pytest.importorskip("numpy") assert from_db(None) is None - assert np.array_equal(from_db(False), np.array(0., dtype=np.float32)) - assert np.array_equal(from_db(True), np.array(1., dtype=np.float32)) + assert np.array_equal(from_db(False), np.array(0.0, dtype=np.float32)) + assert np.array_equal(from_db(True), np.array(1.0, dtype=np.float32)) assert np.array_equal(from_db(42), np.array(42, dtype=np.float32)) assert np.array_equal(from_db(42.42), np.array(42.42, dtype=np.float32)) assert np.array_equal(from_db([42.42, 43.43]), np.array([42.42, 43.43], dtype=np.float32)) @@ -227,7 +225,7 @@ def test_float_vector_integration(cratedb_service): # Define DDL. class SearchIndex(Base): - __tablename__ = 'search' + __tablename__ = "search" name = sa.Column(sa.String, primary_key=True) embedding = sa.Column(FloatVector(3)) @@ -241,8 +239,9 @@ class SearchIndex(Base): session.execute(sa.text("REFRESH TABLE search")) # Query record. - query = session.query(SearchIndex.embedding) \ - .filter(knn_match(SearchIndex.embedding, [42.42, 43.43, 41.41], 3)) + query = session.query(SearchIndex.embedding).filter( + knn_match(SearchIndex.embedding, [42.42, 43.43, 41.41], 3) + ) result = query.first() # Compare outcome. diff --git a/tests/warnings_test.py b/tests/warnings_test.py index b74b8b3..88ec7cc 100644 --- a/tests/warnings_test.py +++ b/tests/warnings_test.py @@ -3,7 +3,7 @@ import warnings from unittest import TestCase, skipIf -from sqlalchemy_cratedb import SA_1_4, SA_VERSION +from sqlalchemy_cratedb.sa_version import SA_1_4, SA_VERSION from tests.util import ExtraAssertions @@ -14,14 +14,15 @@ class SqlAlchemyWarningsTest(TestCase, ExtraAssertions): https://docs.python.org/3/library/warnings.html#testing-warnings """ - @skipIf(SA_VERSION >= SA_1_4, "There is no deprecation warning for " - "SQLAlchemy 1.3 on higher versions") + @skipIf( + SA_VERSION >= SA_1_4, + "There is no deprecation warning for " "SQLAlchemy 1.3 on higher versions", + ) def test_sa13_deprecation_warning(self): """ Verify that a `DeprecationWarning` is issued when running SQLAlchemy 1.3. """ with warnings.catch_warnings(record=True) as w: - # Cause all warnings to always be triggered. warnings.simplefilter("always") @@ -42,23 +43,27 @@ def test_craty_object_deprecation_warning(self): """ with warnings.catch_warnings(record=True) as w: - # Import the deprecated symbol. from sqlalchemy_cratedb.type.object import Craty # noqa: F401 # Verify details of the deprecation warning. self.assertEqual(len(w), 1) self.assertIsSubclass(w[-1].category, DeprecationWarning) - self.assertIn("Craty is deprecated and will be removed in future releases. " - "Please use ObjectType instead.", str(w[-1].message)) + self.assertIn( + "Craty is deprecated and will be removed in future releases. " + "Please use ObjectType instead.", + str(w[-1].message), + ) with warnings.catch_warnings(record=True) as w: - # Import the deprecated symbol. from sqlalchemy_cratedb.type.object import Object # noqa: F401 # Verify details of the deprecation warning. self.assertEqual(len(w), 1) self.assertIsSubclass(w[-1].category, DeprecationWarning) - self.assertIn("Object is deprecated and will be removed in future releases. " - "Please use ObjectType instead.", str(w[-1].message)) + self.assertIn( + "Object is deprecated and will be removed in future releases. " + "Please use ObjectType instead.", + str(w[-1].message), + )