Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make it so single-member arrays become objects #51

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 44 additions & 7 deletions jsonschema2db.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from collections.abc import Iterable
import change_case
import csv
import datetime
Expand Down Expand Up @@ -37,6 +38,7 @@ def __init__(self, schema, database_flavor, postgres_schema=None, debug=False,
self._database_flavor = database_flavor
self._debug = debug
self._table_definitions = {}
self._table_required = {root_table: set()} # keyed by _table_definitions
self._links = {}
self._backlinks = {}
self._postgres_schema = postgres_schema
Expand Down Expand Up @@ -86,12 +88,18 @@ def __init__(self, schema, database_flavor, postgres_schema=None, debug=False,
warnings.warn('Ignoring_column because it is too long: %s.%s' % (table, column))
columns = sorted(col for col in column_types.keys() if 0 < len(col) <= max_column_length)
self._table_columns[table] = columns
self._table_required[table] = set()

def _table_name(self, path):
return '__'.join(change_case.ChangeCase.camel_to_snake(self._abbreviations.get(p, p)) for p in path)
def _table_name(self, path, fallible=False):
try:
return '__'.join(change_case.ChangeCase.camel_to_snake(self._abbreviations.get(p, p)) for p in path)
except Exception as e:
warnings.warn('%s: %s getting table name!' % (path, e))
if not fallible:
raise e

def _column_name(self, path):
return self._table_name(path) # same
def _column_name(self, path, **kwargs):
return self._table_name(path, **kwargs) # same

def _execute(self, cursor, query, args=None, query_ok_to_print=True):
if self._debug and query_ok_to_print:
Expand All @@ -112,9 +120,14 @@ def _traverse(self, schema, tree, path=tuple(), table='root', parent=None, comme

if table not in self._table_definitions:
self._table_definitions[table] = {}
self._table_required[table] = set()
if comment:
self._table_comments[table] = comment

if 'required' in tree:
for req in tree['required']:
self._table_required[table].add(self._column_name(path))

definition = None
new_json_path = json_path
while '$ref' in tree:
Expand Down Expand Up @@ -144,7 +157,22 @@ def _traverse(self, schema, tree, path=tuple(), table='root', parent=None, comme
elif 'type' not in tree:
res = {}
warnings.warn('%s.%s: Type info missing' % (table, self._column_name(path)))
elif tree['type'] == 'object':

if 'type' in tree:
if tree['type'] != 'null' and isinstance(tree['type'], Iterable) and 'null' in tree['type']: # nullable
tree['type'] = Nullable(next((t for t in tree['type'] if t != 'null')))
elif (tpath := self._column_name(path)) is not None and tpath in self._table_required[table]:
tree['type'] = NotNullable(tree['type'])

if tree['type'] == 'array' and 'items' in tree:
if len(tree['items']) == 1:
tree['type'] = Nullable('object')
tree['patternProperties'] = { '[0-9]+': tree['items'] }
del tree['items']
else:
warnings.warn('%s.%s: Arrays with mismatched item types are not yet gracefully handled.' % (table, self._column_name(path)))

if tree['type'] == 'object':
print('object:', tree)
res = {}
if 'patternProperties' in tree:
Expand Down Expand Up @@ -242,7 +270,7 @@ def create_tables(self, con):

:param con: psycopg2 connection object
'''
postgres_types = {'boolean': 'bool', 'number': 'float', 'string': 'text', 'enum': 'text', 'integer': 'bigint', 'timestamp': 'timestamptz', 'date': 'date', 'link': 'integer'}
postgres_types = {'boolean': 'bool', 'number': 'float', 'string': 'text', 'enum': 'text', 'integer': 'bigint', 'timestamp': 'timestamptz', 'date': 'date', 'link': 'bigint'}
with con.cursor() as cursor:
if self._postgres_schema is not None:
self._execute(cursor, 'drop schema if exists %s cascade' % self._postgres_schema)
Expand All @@ -253,7 +281,7 @@ def create_tables(self, con):

create_q = 'create table %s (id %s, "%s" %s not null, "%s" text not null, %s unique ("%s", "%s"), unique (id))' % \
(self._postgres_table_name(table), id_data_type, self._item_col_name, postgres_types[self._item_col_type], self._prefix_col_name,
''.join('"%s" %s, ' % (c, postgres_types[t]) for c, t in zip(columns, types)),
''.join('"%s" %s %s, ' % (c, postgres_types[str(t)], 'not null' if isinstance(t, NotNullable) else 'null' if isinstance(t, Nullable) else '') for c, t in zip(columns, types)),
self._item_col_name, self._prefix_col_name)
self._execute(cursor, create_q)
if table in self._table_comments:
Expand Down Expand Up @@ -442,8 +470,17 @@ def __init__(self, *args, **kwargs):
return super(JSONSchemaToPostgres, self).__init__(*args, **kwargs)



class JSONSchemaToRedshift(JSONSchemaToDatabase):
'''Shorthand for JSONSchemaToDatabase(..., database_flavor='redshift')'''
def __init__(self, *args, **kwargs):
kwargs['database_flavor'] = 'redshift'
return super(JSONSchemaToRedshift, self).__init__(*args, **kwargs)


class Nullable(str):
def __eq__(a, b):
return str(a).__eq__(str(b))

class NotNullable(str):
__eq__ = Nullable.__eq__