Skip to content

Commit

Permalink
Make it so single-member arrays become objects
Browse files Browse the repository at this point in the history
Among other QoL improvements.

Closes better#25.
  • Loading branch information
ctrlcctrlv committed Sep 14, 2022
1 parent fd1c83e commit c0b7b08
Showing 1 changed file with 44 additions and 7 deletions.
51 changes: 44 additions & 7 deletions jsonschema2db.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from collections.abc import Iterable
import change_case
import csv
import datetime
Expand Down Expand Up @@ -37,6 +38,7 @@ def __init__(self, schema, database_flavor, postgres_schema=None, debug=False,
self._database_flavor = database_flavor
self._debug = debug
self._table_definitions = {}
self._table_required = {root_table: set()} # keyed by _table_definitions
self._links = {}
self._backlinks = {}
self._postgres_schema = postgres_schema
Expand Down Expand Up @@ -86,12 +88,18 @@ def __init__(self, schema, database_flavor, postgres_schema=None, debug=False,
warnings.warn('Ignoring_column because it is too long: %s.%s' % (table, column))
columns = sorted(col for col in column_types.keys() if 0 < len(col) <= max_column_length)
self._table_columns[table] = columns
self._table_required[table] = set()

def _table_name(self, path):
return '__'.join(change_case.ChangeCase.camel_to_snake(self._abbreviations.get(p, p)) for p in path)
def _table_name(self, path, fallible=False):
    '''Build the SQL identifier for a JSON path.

    Every component of ``path`` is looked up in ``self._abbreviations``
    (components without an entry are used unchanged), passed through
    ``change_case.ChangeCase.camel_to_snake`` and the results are joined
    with ``'__'``.

    :param path: iterable of path components
    :param fallible: if true, a failure is only reported via
        ``warnings.warn`` and ``None`` is returned; if false (the default)
        the original exception propagates after the warning is issued
    :return: the joined name, or ``None`` when naming failed and
        ``fallible`` is true
    '''
    try:
        return '__'.join(
            change_case.ChangeCase.camel_to_snake(self._abbreviations.get(p, p))
            for p in path
        )
    except Exception as e:
        # Always surface the problem, even when the caller tolerates it.
        warnings.warn('%s: %s getting table name!' % (path, e))
        if not fallible:
            # Bare raise re-raises the active exception with its original
            # traceback instead of re-raising through the bound name.
            raise
        # Make the "no name could be built" result explicit for callers
        # that check ``is not None``.
        return None

def _column_name(self, path):
return self._table_name(path) # same
def _column_name(self, path, **kwargs):
    '''Return the column identifier for ``path``.

    Columns are named by exactly the same rule as tables, so this simply
    forwards ``path`` and any keyword arguments to ``self._table_name``.

    :param path: iterable of path components
    :param kwargs: keyword arguments passed through to ``_table_name``
    :return: whatever ``_table_name`` returns for the same arguments
    '''
    name = self._table_name(path, **kwargs)
    return name

def _execute(self, cursor, query, args=None, query_ok_to_print=True):
if self._debug and query_ok_to_print:
Expand All @@ -112,9 +120,14 @@ def _traverse(self, schema, tree, path=tuple(), table='root', parent=None, comme

if table not in self._table_definitions:
self._table_definitions[table] = {}
self._table_required[table] = set()
if comment:
self._table_comments[table] = comment

if 'required' in tree:
for req in tree['required']:
self._table_required[table].add(self._column_name(path))

definition = None
new_json_path = json_path
while '$ref' in tree:
Expand Down Expand Up @@ -144,7 +157,22 @@ def _traverse(self, schema, tree, path=tuple(), table='root', parent=None, comme
elif 'type' not in tree:
res = {}
warnings.warn('%s.%s: Type info missing' % (table, self._column_name(path)))
elif tree['type'] == 'object':

if 'type' in tree:
if tree['type'] != 'null' and isinstance(tree['type'], Iterable) and 'null' in tree['type']: # nullable
tree['type'] = Nullable(next((t for t in tree['type'] if t != 'null')))
elif (tpath := self._column_name(path)) is not None and tpath in self._table_required[table]:
tree['type'] = NotNullable(tree['type'])

if tree['type'] == 'array' and 'items' in tree:
if len(tree['items']) == 1:
tree['type'] = Nullable('object')
tree['patternProperties'] = { '[0-9]+': tree['items'] }
del tree['items']
else:
warnings.warn('%s.%s: Arrays with mismatched item types are not yet gracefully handled.' % (table, self._column_name(path)))

if tree['type'] == 'object':
print('object:', tree)
res = {}
if 'patternProperties' in tree:
Expand Down Expand Up @@ -242,7 +270,7 @@ def create_tables(self, con):
:param con: psycopg2 connection object
'''
postgres_types = {'boolean': 'bool', 'number': 'float', 'string': 'text', 'enum': 'text', 'integer': 'bigint', 'timestamp': 'timestamptz', 'date': 'date', 'link': 'integer'}
postgres_types = {'boolean': 'bool', 'number': 'float', 'string': 'text', 'enum': 'text', 'integer': 'bigint', 'timestamp': 'timestamptz', 'date': 'date', 'link': 'bigint'}
with con.cursor() as cursor:
if self._postgres_schema is not None:
self._execute(cursor, 'drop schema if exists %s cascade' % self._postgres_schema)
Expand All @@ -253,7 +281,7 @@ def create_tables(self, con):

create_q = 'create table %s (id %s, "%s" %s not null, "%s" text not null, %s unique ("%s", "%s"), unique (id))' % \
(self._postgres_table_name(table), id_data_type, self._item_col_name, postgres_types[self._item_col_type], self._prefix_col_name,
''.join('"%s" %s, ' % (c, postgres_types[t]) for c, t in zip(columns, types)),
''.join('"%s" %s %s, ' % (c, postgres_types[str(t)], 'not null' if isinstance(t, NotNullable) else 'null' if isinstance(t, Nullable) else '') for c, t in zip(columns, types)),
self._item_col_name, self._prefix_col_name)
self._execute(cursor, create_q)
if table in self._table_comments:
Expand Down Expand Up @@ -442,8 +470,17 @@ def __init__(self, *args, **kwargs):
return super(JSONSchemaToPostgres, self).__init__(*args, **kwargs)



class JSONSchemaToRedshift(JSONSchemaToDatabase):
    '''JSONSchemaToDatabase preconfigured with database_flavor='redshift'.'''

    def __init__(self, *args, **kwargs):
        # Pin the flavor; this overrides any database_flavor the caller passed.
        kwargs.update(database_flavor='redshift')
        super().__init__(*args, **kwargs)


class Nullable(str):
    '''A JSON-schema type name tagged as allowing NULL.

    Instances behave like the plain type-name string they wrap: they compare
    equal to a ``str`` with the same characters (including other ``str``
    subclasses) and hash like that string, so they can be used as dict keys
    and set members interchangeably with plain strings. The class itself is
    the marker; code distinguishes it with ``isinstance``.
    '''

    def __eq__(self, other):
        # Only strings can be equal to a type name. Returning NotImplemented
        # for everything else keeps ``==`` consistent with the inherited
        # ``str.__ne__`` instead of coercing the other operand with str().
        if not isinstance(other, str):
            return NotImplemented
        return str.__eq__(self, other)

    # Defining __eq__ in a class body resets __hash__ to None; restore the
    # string hash so equal values hash equally and instances stay hashable.
    __hash__ = str.__hash__

class NotNullable(str):
    '''A JSON-schema type name tagged as NOT NULL.

    Instances behave like the plain type-name string they wrap: they compare
    equal to a ``str`` with the same characters (including other ``str``
    subclasses) and hash like that string, so they can be used as dict keys
    and set members interchangeably with plain strings. The class itself is
    the marker; code distinguishes it with ``isinstance``.
    '''

    def __eq__(self, other):
        # Only strings can be equal to a type name. Returning NotImplemented
        # for everything else keeps ``==`` consistent with the inherited
        # ``str.__ne__`` instead of coercing the other operand with str().
        if not isinstance(other, str):
            return NotImplemented
        return str.__eq__(self, other)

    # Defining __eq__ in a class body resets __hash__ to None; restore the
    # string hash so equal values hash equally and instances stay hashable.
    __hash__ = str.__hash__

0 comments on commit c0b7b08

Please sign in to comment.