From 0332889afa4840f5a4786e408d1bb9979c751ea0 Mon Sep 17 00:00:00 2001 From: Robert Forkel Date: Thu, 18 Jan 2024 10:00:35 +0100 Subject: [PATCH] type hints --- src/csvw/__init__.py | 2 +- src/csvw/__main__.py | 2 +- src/csvw/datatypes.py | 30 +++++++++-------- src/csvw/db.py | 78 ++++++++++++++++++++++++++++++++++--------- src/csvw/dsv.py | 26 ++++++++++----- 5 files changed, 99 insertions(+), 39 deletions(-) diff --git a/src/csvw/__init__.py b/src/csvw/__init__.py index 22edb22..216da37 100644 --- a/src/csvw/__init__.py +++ b/src/csvw/__init__.py @@ -24,4 +24,4 @@ __version__ = '3.2.3.dev0' __author__ = 'Robert Forkel' __license__ = 'Apache 2.0, see LICENSE' -__copyright__ = 'Copyright (c) 2023 Robert Forkel' +__copyright__ = 'Copyright (c) 2024 Robert Forkel' diff --git a/src/csvw/__main__.py b/src/csvw/__main__.py index 854a169..2d07b77 100644 --- a/src/csvw/__main__.py +++ b/src/csvw/__main__.py @@ -62,7 +62,7 @@ def csvwdescribe(args=None, test=False): def csvwvalidate(args=None, test=False): init() args = parsed_args( - "Describe a (set of) CSV file(s) with basic CSVW metadata.", + "Validate a (set of) CSV file(s) described by CSVW metadata.", args, (['url'], dict(help='URL or local path to CSV or JSON metadata file.')), (['-v', '--verbose'], dict(action='store_true', default=False)), diff --git a/src/csvw/datatypes.py b/src/csvw/datatypes.py index 7c81d2a..ad17973 100644 --- a/src/csvw/datatypes.py +++ b/src/csvw/datatypes.py @@ -8,16 +8,17 @@ .. seealso:: http://w3c.github.io/csvw/metadata/#datatypes """ -import collections import re import json as _json import math import base64 +import typing import decimal as _decimal import binascii import datetime import warnings import itertools +import collections import isodate import rfc3986 @@ -26,6 +27,9 @@ import babel.dates import jsonschema +if typing.TYPE_CHECKING: # pragma: no cover + import csvw + __all__ = ['DATATYPES'] DATATYPES = {} @@ -62,11 +66,11 @@ class anyAtomicType: def value_error(cls, v): raise ValueError('invalid lexical value for {}: {}'.format(cls.name, v)) - def __str__(self): + def __str__(self) -> str: return self.name @staticmethod - def derived_description(datatype): + def derived_description(datatype: "csvw.Datatype") -> dict: return {} @staticmethod @@ -89,7 +93,7 @@ class string(anyAtomicType): name = 'string' @staticmethod - def derived_description(datatype): + def derived_description(datatype: "csvw.Datatype") -> dict: if datatype.format: # We wrap a regex specified as `format` property into a group and add `$` to # make sure the whole string is matched when validating. 
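# A minimal sketch (not from the csvw codebase; the function name is hypothetical) of the
# forward-reference pattern the datatypes.py hunks above rely on: `csvw` is imported only
# under `typing.TYPE_CHECKING`, and the parameter is annotated with the string
# "csvw.Datatype", so type checkers see the type while no circular import happens at runtime.
import typing

if typing.TYPE_CHECKING:  # pragma: no cover - True only for static type checkers
    import csvw  # avoids importing the package at runtime from one of its own submodules


def describe(datatype: "csvw.Datatype") -> dict:
    # The quoted annotation is resolved lazily, so `csvw` need not be imported here.
    return {'base': datatype.base, 'format': datatype.format}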
@@ -251,7 +255,7 @@ class boolean(anyAtomicType): example = 'false' @staticmethod - def derived_description(datatype): + def derived_description(datatype: "csvw.Datatype") -> dict: if datatype.format and isinstance(datatype.format, str) and datatype.format.count('|') == 1: true, false = [[v] for v in datatype.format.split('|')] else: @@ -301,7 +305,7 @@ class dateTime(anyAtomicType): example = '2018-12-10T20:20:20' @staticmethod - def derived_description(datatype): + def derived_description(datatype: "csvw.Datatype") -> dict: return dt_format_and_regex(datatype.format) @staticmethod @@ -364,7 +368,7 @@ class date(dateTime): example = '2018-12-10' @staticmethod - def derived_description(datatype): + def derived_description(datatype: "csvw.Datatype") -> dict: try: return dt_format_and_regex(datatype.format or 'yyyy-MM-dd') except ValueError: @@ -393,7 +397,7 @@ class dateTimeStamp(dateTime): example = '2018-12-10T20:20:20' @staticmethod - def derived_description(datatype): + def derived_description(datatype: "csvw.Datatype") -> dict: res = dt_format_and_regex(datatype.format or 'yyyy-MM-ddTHH:mm:ss.SSSSSSXXX') if not res['tz_marker']: raise ValueError('dateTimeStamp must have timezone marker') @@ -409,7 +413,7 @@ class _time(dateTime): example = '20:20:20' @staticmethod - def derived_description(datatype): + def derived_description(datatype: "csvw.Datatype") -> dict: return dt_format_and_regex(datatype.format or 'HH:mm:ss', no_date=True) @staticmethod @@ -442,7 +446,7 @@ class duration(anyAtomicType): example = 'P3Y6M4DT12H30M5S' @staticmethod - def derived_description(datatype): + def derived_description(datatype: "csvw.Datatype") -> dict: return {'format': datatype.format} @staticmethod @@ -517,7 +521,7 @@ class decimal(anyAtomicType): _reverse_special = {v: k for k, v in _special.items()} @staticmethod - def derived_description(datatype): + def derived_description(datatype: "csvw.Datatype") -> dict: if datatype.format: return datatype.format if isinstance(datatype.format, dict) \ else {'pattern': datatype.format} @@ -820,7 +824,7 @@ class _float(anyAtomicType): example = '5.3' @staticmethod - def derived_description(datatype): + def derived_description(datatype: "csvw.Datatype") -> dict: if datatype.format: return datatype.format if isinstance(datatype.format, dict) \ else {'pattern': datatype.format} @@ -987,7 +991,7 @@ class json(string): example = '{"a": [1,2]}' @staticmethod - def derived_description(datatype): + def derived_description(datatype: "csvw.Datatype") -> dict: if datatype.format: try: schema = _json.loads(datatype.format) diff --git a/src/csvw/db.py b/src/csvw/db.py index ce13b7a..f1a0534 100644 --- a/src/csvw/db.py +++ b/src/csvw/db.py @@ -69,11 +69,37 @@ def identity(s): } +class SchemaTranslator(typing.Protocol): + def __call__(self, table: str, column: typing.Optional[str] = None) -> str: + ... + + +class ColumnTranslator(typing.Protocol): + def __call__(self, column: str) -> str: + ... + + def quoted(*names): return ','.join('`{0}`'.format(name) for name in names) -def insert(db, translate, table, keys, *rows, **kw): +def insert(db: sqlite3.Connection, + translate: SchemaTranslator, + table: str, + keys: typing.Sequence[str], + *rows: list, + single: typing.Optional[bool] = False): + """ + Insert a sequence of rows into a table. + + :param db: Database connection. + :param translate: Callable translating table and column names to proper schema object names. + :param table: Untranslated table name. + :param keys: Untranslated column names. 
+ :param rows: Sequence of rows to insert. + :param single: Flag signaling whether to insert all rows at once using `executemany` or one at \ a time, allowing for more focused debugging output in case of errors. + """ if rows: sql = "INSERT INTO {0} ({1}) VALUES ({2})".format( quoted(translate(table)), @@ -82,7 +108,7 @@ def insert(db, translate, table, keys, *rows, **kw): try: db.executemany(sql, rows) except: # noqa: E722 - this is purely for debugging. - if 'single' not in kw: + if not single: for row in rows: insert(db, translate, table, keys, row, single=True) else: @@ -91,14 +117,14 @@ def insert(db, translate, table, keys, *rows, **kw): raise -def select(db, table): +def select(db: sqlite3.Connection, table: str) -> typing.Tuple[typing.List[str], typing.Sequence]: cu = db.execute("SELECT * FROM {0}".format(quoted(table))) cols = [d[0] for d in cu.description] return cols, list(cu.fetchall()) @attr.s -class ColSpec(object): +class ColSpec: """ A `ColSpec` captures sufficient information about a :class:`csvw.Column` for the DB schema. """ @@ -121,12 +147,12 @@ def __attrs_post_init__(self): if self.separator and self.db_type != 'TEXT': self.db_type = 'TEXT' - def check(self, translate): + def check(self, translate: ColumnTranslator) -> typing.Optional[str]: """ We try to convert as many data constraints as possible into SQLite CHECK constraints. - :param translate: - :return: + :param translate: Callable to translate column names between CSVW metadata and DB schema. + :return: A string suitable as argument of an SQL CHECK constraint. """ if not self.csvw: return @@ -156,7 +182,7 @@ def check(self, translate): constraints.append('length(`{0}`) <= {1}'.format(cname, c.maxLength)) return ' AND '.join(constraints) - def sql(self, translate) -> str: + def sql(self, translate: ColumnTranslator) -> str: _check = self.check(translate) return '`{0}` {1}{2}{3}'.format( translate(self.name), @@ -186,11 +212,16 @@ class TableSpec(object): primary_key = attr.ib(default=None) @classmethod - def from_table_metadata(cls, table: csvw.Table, drop_self_referential_fks=True) -> 'TableSpec': + def from_table_metadata(cls, + table: csvw.Table, + drop_self_referential_fks: typing.Optional[bool] = True) -> 'TableSpec': """ Create a `TableSpec` from the schema description of a `csvw.metadata.Table`. :param table: `csvw.metadata.Table` instance. + :param drop_self_referential_fks: Flag signaling whether to drop self-referential foreign \ keys. This may be necessary if the order of rows in a CSVW table does not guarantee \ referential integrity when inserted in order (e.g. an earlier row referring to a later one). :return: `TableSpec` instance. """ spec = cls(name=table.local_name, primary_key=table.tableSchema.primaryKey) @@ -233,6 +264,13 @@ def from_table_metadata(cls, table: csvw.Table, drop_self_referential_fks=True) @classmethod def association_table(cls, atable, apk, btable, bpk) -> 'TableSpec': + """ + List-valued foreignKeys are supported as follows: For each pair of tables related through a + list-valued foreign key, an association table is created. To make it possible to distinguish + multiple list-valued foreign keys between the same two tables, the association table has + a column `context`, which stores the name of the foreign key column from which a row in the + association table was created.
+ """ afk = ColSpec('{0}_{1}'.format(atable, apk)) bfk = ColSpec('{0}_{1}'.format(btable, bpk)) if afk.name == bfk.name: @@ -247,7 +285,7 @@ def association_table(cls, atable, apk, btable, bpk) -> 'TableSpec': ] ) - def sql(self, translate) -> str: + def sql(self, translate: SchemaTranslator) -> str: """ :param translate: :return: The SQL statement to create the table. @@ -266,12 +304,16 @@ def sql(self, translate) -> str: translate(self.name), ',\n '.join(clauses)) -def schema(tg: csvw.TableGroup, drop_self_referential_fks=True) -> typing.List[TableSpec]: +def schema(tg: csvw.TableGroup, + drop_self_referential_fks: typing.Optional[bool] = True) -> typing.List[TableSpec]: """ Convert the table and column descriptions of a `TableGroup` into specifications for the DB schema. - :param ds: + :param tg: CSVW TableGroup. + :param drop_self_referential_fks: Flag signaling whether to drop self-referential foreign \ keys. This may be necessary if the order of rows in a CSVW table does not guarantee \ referential integrity when inserted in order (e.g. an earlier row referring to a later one). :return: A pair (tables, reference_tables). """ tables = {} @@ -324,7 +366,7 @@ def __init__( self, tg: TableGroup, fname: typing.Optional[typing.Union[pathlib.Path, str]] = None, - translate: typing.Optional[typing.Callable] = None, + translate: typing.Optional[SchemaTranslator] = None, drop_self_referential_fks: typing.Optional[bool] = True, ): self.translate = translate or Database.name_translator @@ -347,8 +389,8 @@ def name_translator(table: str, column: typing.Optional[str] = None) -> str: A callable with this signature can be passed into DB creation to control the names of the schema objects. - :param table: Name of the table before translation - :param column: Name of a column of `table` before translation + :param table: CSVW name of the table before translation + :param column: CSVW name of a column of `table` before translation :return: Translated table name if `column is None` else translated column name """ # By default, no translation is done: @@ -361,7 +403,7 @@ def connection(self) -> typing.Union[sqlite3.Connection, contextlib.closing]: self._connection = sqlite3.connect(':memory:') return self._connection - def select_many_to_many(self, db, table, context): + def select_many_to_many(self, db, table, context) -> dict: if context is not None: context_sql = "WHERE context = '{0}'".format(context) else: @@ -393,6 +435,10 @@ def split_value(self, tname, cname, value) -> typing.Union[typing.List[str], str return (value or '').split(sep) if sep else value def read(self) -> typing.Dict[str, typing.List[typing.OrderedDict]]: + """ + :return: A `dict` where keys are SQL table names corresponding to CSVW tables and values \ are lists of rows, represented as dicts where keys are the SQL column names. + """ res = collections.defaultdict(list) with self.connection() as conn: for tname in self.tg.tabledict: diff --git a/src/csvw/dsv.py b/src/csvw/dsv.py index 8a1a466..9168cb0 100644 --- a/src/csvw/dsv.py +++ b/src/csvw/dsv.py @@ -32,12 +32,14 @@ 'rewrite', 'add_rows', 'filter_rows_as_dict', ] +LINES_OR_PATH = typing.Union[str, pathlib.Path, typing.IO, typing.Iterable[str]] -def normalize_encoding(encoding): + +def normalize_encoding(encoding: str) -> str: return codecs.lookup(encoding).name -class UnicodeWriter(object): +class UnicodeWriter: """ Write Unicode data to a csv file.
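# A minimal usage sketch (not from the csvw codebase; the file name is hypothetical) of the
# inputs admitted by the new LINES_OR_PATH alias introduced above: the dsv readers accept a
# filesystem path, an open file object, or an iterable of already-read lines, as documented
# in the UnicodeReader and iterrows hunks below.
import pathlib

from csvw.dsv import iterrows

rows = list(iterrows(pathlib.Path('table.csv'), dicts=True))  # a path on disk
rows = list(iterrows(['ID,Name', '1,Alice'], dicts=True))      # an iterable of lines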
@@ -122,7 +124,7 @@ def writerows(self, rows: typing.Iterable[typing.Union[tuple, list]]): self.writerow(row) -class UnicodeReader(object): +class UnicodeReader: """ Read Unicode data from a csv file. @@ -145,7 +147,7 @@ class UnicodeReader(object): """ def __init__( self, - f: typing.Union[str, pathlib.Path, typing.IO, typing.Iterable[str]], + f: LINES_OR_PATH, dialect: typing.Optional[typing.Union[Dialect, str]] = None, **kw): self.f = f @@ -333,7 +335,11 @@ def item(self, row): **{self._normalize_fieldname(k): v for k, v in d.items() if k in self.fieldnames}) -def iterrows(lines_or_file, namedtuples=False, dicts=False, encoding='utf-8', **kw): +def iterrows(lines_or_file: LINES_OR_PATH, + namedtuples: typing.Optional[bool] = False, + dicts: typing.Optional[bool] = False, + encoding: typing.Optional[str] = 'utf-8', + **kw) -> typing.Generator: """Convenience factory function for csv reader. :param lines_or_file: Content to be read. Either a file handle, a file path or a list\ @@ -361,7 +367,9 @@ def iterrows(lines_or_file, namedtuples=False, dicts=False, encoding='utf-8', ** reader = iterrows -def rewrite(fname, visitor, **kw): +def rewrite(fname: typing.Union[str, pathlib.Path], + visitor: typing.Callable[[int, typing.List[str]], typing.Union[None, typing.List[str]]], + **kw): """Utility function to rewrite rows in dsv files. :param fname: Path of the dsv file to operate on. @@ -383,7 +391,7 @@ def rewrite(fname, visitor, **kw): shutil.move(str(tmp), str(fname)) # Path.replace is Python 3.3+ -def add_rows(fname, *rows): +def add_rows(fname: typing.Union[str, pathlib.Path], *rows: typing.List[str]): with tempfile.NamedTemporaryFile(delete=False) as fp: tmp = pathlib.Path(fp.name) @@ -397,7 +405,9 @@ def add_rows(fname, *rows): shutil.move(str(tmp), str(fname)) # Path.replace is Python 3.3+ -def filter_rows_as_dict(fname, filter_, **kw): +def filter_rows_as_dict(fname: typing.Union[str, pathlib.Path], + filter_: typing.Callable[[dict], bool], + **kw) -> int: """Rewrite a dsv file, filtering the rows. :param fname: Path to dsv file