Skip to content

Commit

Permalink
type hints
Browse files Browse the repository at this point in the history
  • Loading branch information
xrotwang committed Jan 18, 2024
1 parent c796e4f commit 0332889
Show file tree
Hide file tree
Showing 5 changed files with 99 additions and 39 deletions.
2 changes: 1 addition & 1 deletion src/csvw/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,4 +24,4 @@
__version__ = '3.2.3.dev0'
__author__ = 'Robert Forkel'
__license__ = 'Apache 2.0, see LICENSE'
__copyright__ = 'Copyright (c) 2023 Robert Forkel'
__copyright__ = 'Copyright (c) 2024 Robert Forkel'
2 changes: 1 addition & 1 deletion src/csvw/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ def csvwdescribe(args=None, test=False):
def csvwvalidate(args=None, test=False):
init()
args = parsed_args(
"Describe a (set of) CSV file(s) with basic CSVW metadata.",
"Validate a (set of) CSV file(s) described by CSVW metadata.",
args,
(['url'], dict(help='URL or local path to CSV or JSON metadata file.')),
(['-v', '--verbose'], dict(action='store_true', default=False)),
Expand Down
30 changes: 17 additions & 13 deletions src/csvw/datatypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,16 +8,17 @@
.. seealso:: http://w3c.github.io/csvw/metadata/#datatypes
"""
import collections
import re
import json as _json
import math
import base64
import typing
import decimal as _decimal
import binascii
import datetime
import warnings
import itertools
import collections

import isodate
import rfc3986
Expand All @@ -26,6 +27,9 @@
import babel.dates
import jsonschema

if typing.TYPE_CHECKING: # pragma: no cover
import csvw

__all__ = ['DATATYPES']

DATATYPES = {}
Expand Down Expand Up @@ -62,11 +66,11 @@ class anyAtomicType:
def value_error(cls, v):
raise ValueError('invalid lexical value for {}: {}'.format(cls.name, v))

def __str__(self):
def __str__(self) -> str:
return self.name

@staticmethod
def derived_description(datatype):
def derived_description(datatype: "csvw.Datatype") -> dict:
return {}

@staticmethod
Expand All @@ -89,7 +93,7 @@ class string(anyAtomicType):
name = 'string'

@staticmethod
def derived_description(datatype):
def derived_description(datatype: "csvw.Datatype") -> dict:
if datatype.format:
# We wrap a regex specified as `format` property into a group and add `$` to
# make sure the whole string is matched when validating.
Expand Down Expand Up @@ -251,7 +255,7 @@ class boolean(anyAtomicType):
example = 'false'

@staticmethod
def derived_description(datatype):
def derived_description(datatype: "csvw.Datatype") -> dict:
if datatype.format and isinstance(datatype.format, str) and datatype.format.count('|') == 1:
true, false = [[v] for v in datatype.format.split('|')]
else:
Expand Down Expand Up @@ -301,7 +305,7 @@ class dateTime(anyAtomicType):
example = '2018-12-10T20:20:20'

@staticmethod
def derived_description(datatype):
def derived_description(datatype: "csvw.Datatype") -> dict:
return dt_format_and_regex(datatype.format)

@staticmethod
Expand Down Expand Up @@ -364,7 +368,7 @@ class date(dateTime):
example = '2018-12-10'

@staticmethod
def derived_description(datatype):
def derived_description(datatype: "csvw.Datatype") -> dict:
try:
return dt_format_and_regex(datatype.format or 'yyyy-MM-dd')
except ValueError:
Expand Down Expand Up @@ -393,7 +397,7 @@ class dateTimeStamp(dateTime):
example = '2018-12-10T20:20:20'

@staticmethod
def derived_description(datatype):
def derived_description(datatype: "csvw.Datatype") -> dict:
res = dt_format_and_regex(datatype.format or 'yyyy-MM-ddTHH:mm:ss.SSSSSSXXX')
if not res['tz_marker']:
raise ValueError('dateTimeStamp must have timezone marker')
Expand All @@ -409,7 +413,7 @@ class _time(dateTime):
example = '20:20:20'

@staticmethod
def derived_description(datatype):
def derived_description(datatype: "csvw.Datatype") -> dict:
return dt_format_and_regex(datatype.format or 'HH:mm:ss', no_date=True)

@staticmethod
Expand Down Expand Up @@ -442,7 +446,7 @@ class duration(anyAtomicType):
example = 'P3Y6M4DT12H30M5S'

@staticmethod
def derived_description(datatype):
def derived_description(datatype: "csvw.Datatype") -> dict:
return {'format': datatype.format}

@staticmethod
Expand Down Expand Up @@ -517,7 +521,7 @@ class decimal(anyAtomicType):
_reverse_special = {v: k for k, v in _special.items()}

@staticmethod
def derived_description(datatype):
def derived_description(datatype: "csvw.Datatype") -> dict:
if datatype.format:
return datatype.format if isinstance(datatype.format, dict) \
else {'pattern': datatype.format}
Expand Down Expand Up @@ -820,7 +824,7 @@ class _float(anyAtomicType):
example = '5.3'

@staticmethod
def derived_description(datatype):
def derived_description(datatype: "csvw.Datatype") -> dict:
if datatype.format:
return datatype.format if isinstance(datatype.format, dict) \
else {'pattern': datatype.format}
Expand Down Expand Up @@ -987,7 +991,7 @@ class json(string):
example = '{"a": [1,2]}'

@staticmethod
def derived_description(datatype):
def derived_description(datatype: "csvw.Datatype") -> dict:
if datatype.format:
try:
schema = _json.loads(datatype.format)
Expand Down
78 changes: 62 additions & 16 deletions src/csvw/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,11 +69,37 @@ def identity(s):
}


class SchemaTranslator(typing.Protocol):
def __call__(self, table: str, column: typing.Optional[str] = None) -> str:
...


class ColumnTranslator(typing.Protocol):
def __call__(self, column: str) -> str:
...


def quoted(*names):
return ','.join('`{0}`'.format(name) for name in names)


def insert(db, translate, table, keys, *rows, **kw):
def insert(db: sqlite3.Connection,
translate: SchemaTranslator,
table: str,
keys: typing.Sequence[str],
*rows: list,
single: typing.Optional[bool] = False):
"""
Insert a sequence of rows into a table.
:param db: Database connection.
:param translate: Callable translating table and column names to proper schema object names.
:param table: Untranslated table name.
:param keys: Untranslated column names.
:param rows: Sequence of rows to insert.
:param single: Flag signaling whether to insert all rows at once using `executemany` or one at \
a time, allowing for more focused debugging output in case of errors.
"""
if rows:
sql = "INSERT INTO {0} ({1}) VALUES ({2})".format(
quoted(translate(table)),
Expand All @@ -82,7 +108,7 @@ def insert(db, translate, table, keys, *rows, **kw):
try:
db.executemany(sql, rows)
except: # noqa: E722 - this is purely for debugging.
if 'single' not in kw:
if not single:
for row in rows:
insert(db, translate, table, keys, row, single=True)
else:
Expand All @@ -91,14 +117,14 @@ def insert(db, translate, table, keys, *rows, **kw):
raise


def select(db, table):
def select(db: sqlite3.Connection, table: str) -> typing.Tuple[typing.List[str], typing.Sequence]:
cu = db.execute("SELECT * FROM {0}".format(quoted(table)))
cols = [d[0] for d in cu.description]
return cols, list(cu.fetchall())


@attr.s
class ColSpec(object):
class ColSpec:
"""
A `ColSpec` captures sufficient information about a :class:`csvw.Column` for the DB schema.
"""
Expand All @@ -121,12 +147,12 @@ def __attrs_post_init__(self):
if self.separator and self.db_type != 'TEXT':
self.db_type = 'TEXT'

def check(self, translate):
def check(self, translate: ColumnTranslator) -> typing.Optional[str]:
"""
We try to convert as many data constraints as possible into SQLite CHECK constraints.
:param translate:
:return:
:param translate: Callable to translate column names between CSVW metadata and DB schema.
:return: A string suitable as argument of an SQL CHECK constraint.
"""
if not self.csvw:
return
Expand Down Expand Up @@ -156,7 +182,7 @@ def check(self, translate):
constraints.append('length(`{0}`) <= {1}'.format(cname, c.maxLength))
return ' AND '.join(constraints)

def sql(self, translate) -> str:
def sql(self, translate: ColumnTranslator) -> str:
_check = self.check(translate)
return '`{0}` {1}{2}{3}'.format(
translate(self.name),
Expand Down Expand Up @@ -186,11 +212,16 @@ class TableSpec(object):
primary_key = attr.ib(default=None)

@classmethod
def from_table_metadata(cls, table: csvw.Table, drop_self_referential_fks=True) -> 'TableSpec':
def from_table_metadata(cls,
table: csvw.Table,
drop_self_referential_fks: typing.Optional[bool] = True) -> 'TableSpec':
"""
Create a `TableSpec` from the schema description of a `csvw.metadata.Table`.
:param table: `csvw.metadata.Table` instance.
:param drop_self_referential_fks: Flag signaling whether to drop self-referential foreign \
keys. This may be necessary, if the order of rows in a CSVW table does not guarantee \
referential integrity when inserted in order (e.g. an eralier row refering to a later one).
:return: `TableSpec` instance.
"""
spec = cls(name=table.local_name, primary_key=table.tableSchema.primaryKey)
Expand Down Expand Up @@ -233,6 +264,13 @@ def from_table_metadata(cls, table: csvw.Table, drop_self_referential_fks=True)

@classmethod
def association_table(cls, atable, apk, btable, bpk) -> 'TableSpec':
"""
List-valued foreignKeys are supported as follows: For each pair of tables related through a
list-valued foreign key, an association table is created. To make it possible to distinguish
multiple list-valued foreign keys between the same two tables, the association table has
a column `context`, which stores the name of the foreign key column from which a row in the
assocation table was created.
"""
afk = ColSpec('{0}_{1}'.format(atable, apk))
bfk = ColSpec('{0}_{1}'.format(btable, bpk))
if afk.name == bfk.name:
Expand All @@ -247,7 +285,7 @@ def association_table(cls, atable, apk, btable, bpk) -> 'TableSpec':
]
)

def sql(self, translate) -> str:
def sql(self, translate: SchemaTranslator) -> str:
"""
:param translate:
:return: The SQL statement to create the table.
Expand All @@ -266,12 +304,16 @@ def sql(self, translate) -> str:
translate(self.name), ',\n '.join(clauses))


def schema(tg: csvw.TableGroup, drop_self_referential_fks=True) -> typing.List[TableSpec]:
def schema(tg: csvw.TableGroup,
drop_self_referential_fks: typing.Optional[bool] = True) -> typing.List[TableSpec]:
"""
Convert the table and column descriptions of a `TableGroup` into specifications for the
DB schema.
:param ds:
:param tg: CSVW TableGroup.
:param drop_self_referential_fks: Flag signaling whether to drop self-referential foreign \
keys. This may be necessary, if the order of rows in a CSVW table does not guarantee \
referential integrity when inserted in order (e.g. an eralier row refering to a later one).
:return: A pair (tables, reference_tables).
"""
tables = {}
Expand Down Expand Up @@ -324,7 +366,7 @@ def __init__(
self,
tg: TableGroup,
fname: typing.Optional[typing.Union[pathlib.Path, str]] = None,
translate: typing.Optional[typing.Callable] = None,
translate: typing.Optional[SchemaTranslator] = None,
drop_self_referential_fks: typing.Optional[bool] = True,
):
self.translate = translate or Database.name_translator
Expand All @@ -347,8 +389,8 @@ def name_translator(table: str, column: typing.Optional[str] = None) -> str:
A callable with this signature can be passed into DB creation to control the names
of the schema objects.
:param table: Name of the table before translation
:param column: Name of a column of `table` before translation
:param table: CSVW name of the table before translation
:param column: CSVW name of a column of `table` before translation
:return: Translated table name if `column is None` else translated column name
"""
# By default, no translation is done:
Expand All @@ -361,7 +403,7 @@ def connection(self) -> typing.Union[sqlite3.Connection, contextlib.closing]:
self._connection = sqlite3.connect(':memory:')
return self._connection

def select_many_to_many(self, db, table, context):
def select_many_to_many(self, db, table, context) -> dict:
if context is not None:
context_sql = "WHERE context = '{0}'".format(context)
else:
Expand Down Expand Up @@ -393,6 +435,10 @@ def split_value(self, tname, cname, value) -> typing.Union[typing.List[str], str
return (value or '').split(sep) if sep else value

def read(self) -> typing.Dict[str, typing.List[typing.OrderedDict]]:
"""
:return: A `dict` where keys are SQL table names corresponding to CSVW tables and values \
are lists of rows, represented as dicts where keys are the SQL column names.
"""
res = collections.defaultdict(list)
with self.connection() as conn:
for tname in self.tg.tabledict:
Expand Down
Loading

0 comments on commit 0332889

Please sign in to comment.