diff --git a/sqlite_utils/cli.py b/sqlite_utils/cli.py index c37e83b1..5821db64 100644 --- a/sqlite_utils/cli.py +++ b/sqlite_utils/cli.py @@ -1024,93 +1024,97 @@ def insert_upsert_implementation( if flatten: docs = (_flatten(doc) for doc in docs) - if stop_after: - docs = itertools.islice(docs, stop_after) - - if convert: - variable = "row" - if lines: - variable = "line" - elif text: - variable = "text" - fn = _compile_code(convert, imports, variable=variable) - if lines: - docs = (fn(doc["line"]) for doc in docs) - elif text: - # Special case: this is allowed to be an iterable - text_value = list(docs)[0]["text"] - fn_return = fn(text_value) - if isinstance(fn_return, dict): - docs = [fn_return] + if stop_after: + docs = itertools.islice(docs, stop_after) + + if convert: + variable = "row" + if lines: + variable = "line" + elif text: + variable = "text" + fn = _compile_code(convert, imports, variable=variable) + if lines: + docs = (fn(doc["line"]) for doc in docs) + elif text: + # Special case: this is allowed to be an iterable + text_value = list(docs)[0]["text"] + fn_return = fn(text_value) + if isinstance(fn_return, dict): + docs = [fn_return] + else: + try: + docs = iter(fn_return) + except TypeError: + raise click.ClickException( + "--convert must return dict or iterator" + ) else: - try: - docs = iter(fn_return) - except TypeError: - raise click.ClickException("--convert must return dict or iterator") - else: - docs = (fn(doc) or doc for doc in docs) - - extra_kwargs = { - "ignore": ignore, - "replace": replace, - "truncate": truncate, - "analyze": analyze, - } - if not_null: - extra_kwargs["not_null"] = set(not_null) - if default: - extra_kwargs["defaults"] = dict(default) - if upsert: - extra_kwargs["upsert"] = upsert - - # docs should all be dictionaries - docs = (verify_is_dict(doc) for doc in docs) - - # Apply {"$base64": true, ...} decoding, if needed - docs = (decode_base64_values(doc) for doc in docs) - - # For bulk_sql= we use cursor.executemany() instead - if bulk_sql: - if batch_size: - doc_chunks = chunks(docs, batch_size) - else: - doc_chunks = [docs] - for doc_chunk in doc_chunks: - with db.conn: - db.conn.cursor().executemany(bulk_sql, doc_chunk) - return + docs = (fn(doc) or doc for doc in docs) + + extra_kwargs = { + "ignore": ignore, + "replace": replace, + "truncate": truncate, + "analyze": analyze, + } + if not_null: + extra_kwargs["not_null"] = set(not_null) + if default: + extra_kwargs["defaults"] = dict(default) + if upsert: + extra_kwargs["upsert"] = upsert + + # docs should all be dictionaries + docs = (verify_is_dict(doc) for doc in docs) + + # Apply {"$base64": true, ...} decoding, if needed + docs = (decode_base64_values(doc) for doc in docs) + + # For bulk_sql= we use cursor.executemany() instead + if bulk_sql: + if batch_size: + doc_chunks = chunks(docs, batch_size) + else: + doc_chunks = [docs] + for doc_chunk in doc_chunks: + with db.conn: + db.conn.cursor().executemany(bulk_sql, doc_chunk) + return - try: - db[table].insert_all( - docs, pk=pk, batch_size=batch_size, alter=alter, **extra_kwargs - ) - except Exception as e: - if ( - isinstance(e, OperationalError) - and e.args - and "has no column named" in e.args[0] - ): - raise click.ClickException( - "{}\n\nTry using --alter to add additional columns".format(e.args[0]) + try: + db[table].insert_all( + docs, pk=pk, batch_size=batch_size, alter=alter, **extra_kwargs ) - # If we can find sql= and parameters= arguments, show those - variables = _find_variables(e.__traceback__, ["sql", "parameters"]) - if "sql" in variables and "parameters" in variables: - raise click.ClickException( - "{}\n\nsql = {}\nparameters = {}".format( - str(e), variables["sql"], variables["parameters"] + except Exception as e: + if ( + isinstance(e, OperationalError) + and e.args + and "has no column named" in e.args[0] + ): + raise click.ClickException( + "{}\n\nTry using --alter to add additional columns".format( + e.args[0] + ) ) - ) - else: - raise - if tracker is not None: - db[table].transform(types=tracker.types) - - # Clean up open file-like objects - if sniff_buffer: - sniff_buffer.close() - if decoded_buffer: - decoded_buffer.close() + # If we can find sql= and parameters= arguments, show those + variables = _find_variables(e.__traceback__, ["sql", "parameters"]) + if "sql" in variables and "parameters" in variables: + raise click.ClickException( + "{}\n\nsql = {}\nparameters = {}".format( + str(e), variables["sql"], variables["parameters"] + ) + ) + else: + raise + if tracker is not None: + db[table].transform(types=tracker.types) + + # Clean up open file-like objects + if sniff_buffer: + sniff_buffer.close() + if decoded_buffer: + decoded_buffer.close() def _find_variables(tb, vars):