Fixed issue #433 - CLI eats cursor (#598)
The issue is that the underlying iterator is not fully consumed within the body of
the `with file_progress()` block. Instead, that block creates generator
expressions like `docs = (dict(zip(headers, row)) for row in reader)`.

These generators are consumed later, outside the `with file_progress()` block;
consuming them drains the underlying iterator, which in turn updates the progress bar.

This means that the `ProgressBar.__exit__` method gets called before the last
call to the `ProgressBar.update` method. The result is that the code that makes
the cursor invisible (inside the `update()` method) runs after the cleanup code
that makes it visible again (in the `__exit__` method).

The fix is to move consumption of the `docs` iterators within the progress bar block.
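
A minimal sketch of the failure mode and the fix, using `click.progressbar` directly rather than sqlite-utils' `file_progress()` helper that wraps it; the function names and the shape of `rows` are hypothetical:

```python
import click

def import_rows_buggy(rows):
    # The generator is built inside the progress bar context but consumed
    # outside it, so every row pulled by list() advances the bar, re-hiding
    # the cursor, after __exit__() has already restored it.
    with click.progressbar(rows, label="Importing") as bar:
        docs = ({"value": row} for row in bar)  # nothing consumed yet
    return list(docs)  # update() fires here, after __exit__()

def import_rows_fixed(rows):
    # Drain the generator while the context is still open, so the last
    # update() happens before __exit__() runs its cursor cleanup.
    with click.progressbar(rows, label="Importing") as bar:
        docs = ({"value": row} for row in bar)
        return list(docs)
```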

(An additional fix, to make ProgressBar more robust against this kind of misuse,
would be to make it refuse to update after its `__exit__` method has been called,
just as files cannot be `read()` after they are closed. That requires a change in
the click library.)
spookylukey authored Nov 4, 2023
1 parent b92ea47 · commit 37273d7
Showing 1 changed file with 87 additions and 83 deletions.

sqlite_utils/cli.py
```diff
@@ -1024,93 +1024,97 @@ def insert_upsert_implementation(
         if flatten:
             docs = (_flatten(doc) for doc in docs)

-    if stop_after:
-        docs = itertools.islice(docs, stop_after)
+        if stop_after:
+            docs = itertools.islice(docs, stop_after)

-    if convert:
-        variable = "row"
-        if lines:
-            variable = "line"
-        elif text:
-            variable = "text"
-        fn = _compile_code(convert, imports, variable=variable)
-        if lines:
-            docs = (fn(doc["line"]) for doc in docs)
-        elif text:
-            # Special case: this is allowed to be an iterable
-            text_value = list(docs)[0]["text"]
-            fn_return = fn(text_value)
-            if isinstance(fn_return, dict):
-                docs = [fn_return]
-            else:
-                try:
-                    docs = iter(fn_return)
-                except TypeError:
-                    raise click.ClickException("--convert must return dict or iterator")
-        else:
-            docs = (fn(doc) or doc for doc in docs)
+        if convert:
+            variable = "row"
+            if lines:
+                variable = "line"
+            elif text:
+                variable = "text"
+            fn = _compile_code(convert, imports, variable=variable)
+            if lines:
+                docs = (fn(doc["line"]) for doc in docs)
+            elif text:
+                # Special case: this is allowed to be an iterable
+                text_value = list(docs)[0]["text"]
+                fn_return = fn(text_value)
+                if isinstance(fn_return, dict):
+                    docs = [fn_return]
+                else:
+                    try:
+                        docs = iter(fn_return)
+                    except TypeError:
+                        raise click.ClickException(
+                            "--convert must return dict or iterator"
+                        )
+            else:
+                docs = (fn(doc) or doc for doc in docs)

-    extra_kwargs = {
-        "ignore": ignore,
-        "replace": replace,
-        "truncate": truncate,
-        "analyze": analyze,
-    }
-    if not_null:
-        extra_kwargs["not_null"] = set(not_null)
-    if default:
-        extra_kwargs["defaults"] = dict(default)
-    if upsert:
-        extra_kwargs["upsert"] = upsert
+        extra_kwargs = {
+            "ignore": ignore,
+            "replace": replace,
+            "truncate": truncate,
+            "analyze": analyze,
+        }
+        if not_null:
+            extra_kwargs["not_null"] = set(not_null)
+        if default:
+            extra_kwargs["defaults"] = dict(default)
+        if upsert:
+            extra_kwargs["upsert"] = upsert

-    # docs should all be dictionaries
-    docs = (verify_is_dict(doc) for doc in docs)
+        # docs should all be dictionaries
+        docs = (verify_is_dict(doc) for doc in docs)

-    # Apply {"$base64": true, ...} decoding, if needed
-    docs = (decode_base64_values(doc) for doc in docs)
+        # Apply {"$base64": true, ...} decoding, if needed
+        docs = (decode_base64_values(doc) for doc in docs)

-    # For bulk_sql= we use cursor.executemany() instead
-    if bulk_sql:
-        if batch_size:
-            doc_chunks = chunks(docs, batch_size)
-        else:
-            doc_chunks = [docs]
-        for doc_chunk in doc_chunks:
-            with db.conn:
-                db.conn.cursor().executemany(bulk_sql, doc_chunk)
-        return
+        # For bulk_sql= we use cursor.executemany() instead
+        if bulk_sql:
+            if batch_size:
+                doc_chunks = chunks(docs, batch_size)
+            else:
+                doc_chunks = [docs]
+            for doc_chunk in doc_chunks:
+                with db.conn:
+                    db.conn.cursor().executemany(bulk_sql, doc_chunk)
+            return

-    try:
-        db[table].insert_all(
-            docs, pk=pk, batch_size=batch_size, alter=alter, **extra_kwargs
-        )
-    except Exception as e:
-        if (
-            isinstance(e, OperationalError)
-            and e.args
-            and "has no column named" in e.args[0]
-        ):
-            raise click.ClickException(
-                "{}\n\nTry using --alter to add additional columns".format(e.args[0])
-            )
-        # If we can find sql= and parameters= arguments, show those
-        variables = _find_variables(e.__traceback__, ["sql", "parameters"])
-        if "sql" in variables and "parameters" in variables:
-            raise click.ClickException(
-                "{}\n\nsql = {}\nparameters = {}".format(
-                    str(e), variables["sql"], variables["parameters"]
-                )
-            )
-        else:
-            raise
-    if tracker is not None:
-        db[table].transform(types=tracker.types)
+        try:
+            db[table].insert_all(
+                docs, pk=pk, batch_size=batch_size, alter=alter, **extra_kwargs
+            )
+        except Exception as e:
+            if (
+                isinstance(e, OperationalError)
+                and e.args
+                and "has no column named" in e.args[0]
+            ):
+                raise click.ClickException(
+                    "{}\n\nTry using --alter to add additional columns".format(
+                        e.args[0]
+                    )
+                )
+            # If we can find sql= and parameters= arguments, show those
+            variables = _find_variables(e.__traceback__, ["sql", "parameters"])
+            if "sql" in variables and "parameters" in variables:
+                raise click.ClickException(
+                    "{}\n\nsql = {}\nparameters = {}".format(
+                        str(e), variables["sql"], variables["parameters"]
+                    )
+                )
+            else:
+                raise
+        if tracker is not None:
+            db[table].transform(types=tracker.types)

-    # Clean up open file-like objects
-    if sniff_buffer:
-        sniff_buffer.close()
-    if decoded_buffer:
-        decoded_buffer.close()
+        # Clean up open file-like objects
+        if sniff_buffer:
+            sniff_buffer.close()
+        if decoded_buffer:
+            decoded_buffer.close()


 def _find_variables(tb, vars):
```
