Merge pull request #23 from edsu/black
Add black formatting
Florents-Tselai committed Oct 24, 2023
2 parents c938e1b + 3c2712a commit 626f443
Showing 3 changed files with 135 additions and 102 deletions.
3 changes: 3 additions & 0 deletions .github/workflows/run-tests.yaml
@@ -20,6 +20,9 @@ jobs:
pip install poetry
poetry install
- name: Check formatting
run: poetry run black --check .

- name: Run tests
run: |
poetry run pytest
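
For context, a minimal sketch (not part of this commit) of the kind of reformatting the new `poetry run black --check .` step enforces, using black's format_str API; the sample source string is hypothetical.

import black

# Black normalizes quoting and spacing to its canonical style.
src = "x = { 'a':1,'b':2 }\n"
print(black.format_str(src, mode=black.Mode()))
# x = {"a": 1, "b": 2}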
40 changes: 27 additions & 13 deletions tests/test_warcdb.py
@@ -13,40 +13,54 @@
# all these WARC files were created with wget except for apod.warc.gz which was
# created with browsertrix-crawler

@pytest.mark.parametrize("warc_path", [str(tests_dir / "google.warc"),
str(tests_dir / "google.warc.gz"),
str(tests_dir / "no-warc-info.warc"),
str(tests_dir / "scoop.wacz"),
"https://tselai.com/data/google.warc",
"https://tselai.com/data/google.warc.gz"
])


@pytest.mark.parametrize(
"warc_path",
[
str(tests_dir / "google.warc"),
str(tests_dir / "google.warc.gz"),
str(tests_dir / "no-warc-info.warc"),
str(tests_dir / "scoop.wacz"),
"https://tselai.com/data/google.warc",
"https://tselai.com/data/google.warc.gz",
],
)
def test_import(warc_path):
runner = CliRunner()
args = ["import", db_file, warc_path]
result = runner.invoke(warcdb_cli, args)
assert result.exit_code == 0
db = sqlite_utils.Database(db_file)
assert set(db.table_names()) == {
'metadata', 'request', 'resource', 'response', 'warcinfo', '_sqlite_migrations'
"metadata",
"request",
"resource",
"response",
"warcinfo",
"_sqlite_migrations",
}

if warc_path == str(tests_dir / "google.warc"):
assert db.table('warcinfo').get('<urn:uuid:7ABED2CA-7CBD-48A0-92E5-0059EBFC111A>')
assert db.table('request').get('<urn:uuid:524F62DD-D788-4085-B14D-22B0CDC0AC53>')
assert db.table("warcinfo").get(
"<urn:uuid:7ABED2CA-7CBD-48A0-92E5-0059EBFC111A>"
)
assert db.table("request").get(
"<urn:uuid:524F62DD-D788-4085-B14D-22B0CDC0AC53>"
)

os.remove(db_file)


def test_column_names():
runner = CliRunner()
runner.invoke(warcdb_cli, ["import", db_file, str(pathlib.Path('tests/google.warc'))])
runner.invoke(
warcdb_cli, ["import", db_file, str(pathlib.Path("tests/google.warc"))]
)

# make sure that the columns are named correctly (lowercase with underscores)
db = sqlite_utils.Database(db_file)
for table in db.tables:
for col in table.columns:
assert re.match(r'^[a-z_]+', col.name), f'column {col.name} named correctly'
assert re.match(r"^[a-z_]+", col.name), f"column {col.name} named correctly"

os.remove(db_file)
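
A hedged illustration (not part of the commit) of the normalization this column-name check relies on: WARC header names are lowercased and hyphens become underscores before they are used as column names.

import re

# Header names as they appear in a WARC record, mapped to column names
# the same way record_as_dict() does in warcdb/__init__.py.
for header in ["WARC-Record-ID", "WARC-Date", "Content-Length"]:
    column = header.lower().replace("-", "_")
    assert re.match(r"^[a-z_]+", column)
    print(header, "->", column)
# WARC-Record-ID -> warc_record_id
# WARC-Date -> warc_date
# Content-Length -> content_length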
194 changes: 105 additions & 89 deletions warcdb/__init__.py
@@ -17,7 +17,7 @@


def dict_union(*args):
""" Utility function to union multiple dicts """
"""Utility function to union multiple dicts"""
# https://stackoverflow.com/a/15936211/1333954
return dict(chain.from_iterable(d.iteritems() for d in args))

@@ -26,10 +26,10 @@ def dict_union(*args):


def headers_to_json(self):
return dumps([{'header': h, 'value': v} for h, v in self.headers])
return dumps([{"header": h, "value": v} for h, v in self.headers])


setattr(StatusAndHeaders, 'to_json', headers_to_json)
setattr(StatusAndHeaders, "to_json", headers_to_json)

""" Monkeypatch warcio.ArcWarcRecord.payload """

@@ -39,18 +39,18 @@ def record_payload(self: ArcWarcRecord):
return self.content_stream().read()


setattr(ArcWarcRecord, 'payload', record_payload)
setattr(ArcWarcRecord, "payload", record_payload)

""" Monkeypatch warcio.ArcWarcRecord.as_dict() """


@cache
def record_as_dict(self: ArcWarcRecord):
"""Method to easily represent a record as a dict, to be fed into db_utils.Database.insert()"""
return {k.lower().replace('-', '_'): v for k, v in self.rec_headers.headers}
return {k.lower().replace("-", "_"): v for k, v in self.rec_headers.headers}


setattr(ArcWarcRecord, 'as_dict', record_as_dict)
setattr(ArcWarcRecord, "as_dict", record_as_dict)

""" Monkeypatch warcio.ArcWarcRecord.to_json() """

@@ -78,8 +78,8 @@ class WarcDB(MutableMapping):

def __init__(self, *args, **kwargs):
# First pop warcdb - specific params
self._batch_size = kwargs.pop('batch_size', 1000)
self._records_table = kwargs.get('records_table', 'records')
self._batch_size = kwargs.pop("batch_size", 1000)
self._records_table = kwargs.get("records_table", "records")

# Pass the rest to sqlite_utils
self._db = sqlite_utils.Database(*args, **kwargs)
@@ -99,16 +99,16 @@ def records(self):

@property
def http_headers(self):
return self.table('http_headers')
return self.table("http_headers")

@property
def payloads(self):
return self.table('payloads')
return self.table("payloads")

"""MutableMapping abstract methods"""

def __setitem__(self, key, value: ArcWarcRecord):
""" This is the only client-facing way to mutate the file.
"""This is the only client-facing way to mutate the file.
Any normalization should happen here.
"""
# Any normalizations happens here
@@ -140,103 +140,112 @@ def __iadd__(self, r: ArcWarcRecord):
All 'warcinfo' and 'metadata' records shall not have a payload.
"""
col_type_conversions = {
'content_length': int,
'payload': str,
'warc_date': datetime.datetime,

"content_length": int,
"payload": str,
"warc_date": datetime.datetime,
}
record_dict = r.as_dict()

# Certain rec_types have payload
has_payload = r.rec_type in ['warcinfo', 'request', 'response', 'metadata', 'resource']
has_payload = r.rec_type in [
"warcinfo",
"request",
"response",
"metadata",
"resource",
]
if has_payload:
record_dict['payload'] = r.payload()
record_dict["payload"] = r.payload()

# Certain rec_types have http_headers
has_http_headers = r.http_headers is not None
if has_http_headers:
record_dict['http_headers'] = r.http_headers.to_json()
record_dict["http_headers"] = r.http_headers.to_json()

"""Depending on the record type we insert to appropriate record"""
if r.rec_type == 'warcinfo':

self.db.table('warcinfo').insert(record_dict,
pk='warc_record_id',
alter=True,
ignore=True,
columns=col_type_conversions)
elif r.rec_type == 'request':
self.db.table('request').insert(record_dict,
pk='warc_record_id',
foreign_keys=[
("warc_warcinfo_id", "warcinfo", "warc-record-id")
],
alter=True,
ignore=True,
columns=col_type_conversions
)

elif r.rec_type == 'response':
self.db.table('response').insert(record_dict,
pk='warc_record_id',
foreign_keys=[
("warc_warcinfo_id", "warcinfo", "warc_record_id"),
("warc_concurrent_to", "request", "warc_record_id")
],
alter=True,
ignore=True,
columns=col_type_conversions
)

elif r.rec_type == 'metadata':
self.db.table('metadata').insert(record_dict,
pk='warc_record_id',
foreign_keys=[
("warc-warcinfo-id", "warcinfo", "warc_record_id"),
("warc_concurrent_to", "response", "warc_record_id")
],
alter=True,
ignore=True,
columns=col_type_conversions
)

elif r.rec_type == 'resource':
self.db.table('resource').insert(record_dict,
pk='warc_record_id',
foreign_keys=[
("warc-warcinfo-id", "warcinfo", "warc_record_id"),
("warc_concurrent_to", "metadata", "warc_record_id")
],
alter=True,
ignore=True,
columns=col_type_conversions
)
if r.rec_type == "warcinfo":
self.db.table("warcinfo").insert(
record_dict,
pk="warc_record_id",
alter=True,
ignore=True,
columns=col_type_conversions,
)
elif r.rec_type == "request":
self.db.table("request").insert(
record_dict,
pk="warc_record_id",
foreign_keys=[("warc_warcinfo_id", "warcinfo", "warc-record-id")],
alter=True,
ignore=True,
columns=col_type_conversions,
)

elif r.rec_type == "response":
self.db.table("response").insert(
record_dict,
pk="warc_record_id",
foreign_keys=[
("warc_warcinfo_id", "warcinfo", "warc_record_id"),
("warc_concurrent_to", "request", "warc_record_id"),
],
alter=True,
ignore=True,
columns=col_type_conversions,
)

elif r.rec_type == "metadata":
self.db.table("metadata").insert(
record_dict,
pk="warc_record_id",
foreign_keys=[
("warc-warcinfo-id", "warcinfo", "warc_record_id"),
("warc_concurrent_to", "response", "warc_record_id"),
],
alter=True,
ignore=True,
columns=col_type_conversions,
)

elif r.rec_type == "resource":
self.db.table("resource").insert(
record_dict,
pk="warc_record_id",
foreign_keys=[
("warc-warcinfo-id", "warcinfo", "warc_record_id"),
("warc_concurrent_to", "metadata", "warc_record_id"),
],
alter=True,
ignore=True,
columns=col_type_conversions,
)

else:
raise ValueError(f"Record type <{r.rec_type}> is not supported"
f"Only [warcinfo, request, response, metadata, resource] are.")
raise ValueError(
f"Record type <{r.rec_type}> is not supported"
f"Only [warcinfo, request, response, metadata, resource] are."
)
return self


from sqlite_utils import cli as sqlite_utils_cli

warcdb_cli = sqlite_utils_cli.cli
warcdb_cli.help = \
"Commands for interacting with .warcdb files\n\nBased on SQLite-Utils"
warcdb_cli.help = "Commands for interacting with .warcdb files\n\nBased on SQLite-Utils"


@warcdb_cli.command('import')
@warcdb_cli.command("import")
@click.argument(
"db_path",
type=click.Path(file_okay=True, dir_okay=False, allow_dash=False),
)
@click.argument('warc_path',
type=click.STRING,
nargs=-1
)
@click.option('--batch-size',
type=click.INT, default=1000,
help="Batch size for chunked INSERTs [Note: ignored for now]", )
@click.argument("warc_path", type=click.STRING, nargs=-1)
@click.option(
"--batch-size",
type=click.INT,
default=1000,
help="Batch size for chunked INSERTs [Note: ignored for now]",
)
def import_(db_path, warc_path, batch_size):
"""
Import a WARC file into the database
Expand All @@ -251,16 +260,23 @@ def import_(db_path, warc_path, batch_size):

def to_import():
for f in always_iterable(warc_path):
if f.startswith('http'):
yield from tqdm(ArchiveIterator(req.get(f, stream=True).raw, arc2warc=True), desc=f)
elif f.endswith('.wacz'):
if f.startswith("http"):
yield from tqdm(
ArchiveIterator(req.get(f, stream=True).raw, arc2warc=True), desc=f
)
elif f.endswith(".wacz"):
# TODO: can we support loading WACZ files by URL?
wacz = zipfile.ZipFile(f)
warcs = filter(lambda f: f.filename.endswith('warc.gz'), wacz.infolist())
warcs = filter(
lambda f: f.filename.endswith("warc.gz"), wacz.infolist()
)
for warc in warcs:
yield from tqdm(ArchiveIterator(wacz.open(warc.filename, 'r'), arc2warc=True), desc=warc.filename)
yield from tqdm(
ArchiveIterator(wacz.open(warc.filename, "r"), arc2warc=True),
desc=warc.filename,
)
else:
yield from tqdm(ArchiveIterator(open(f, 'rb'), arc2warc=True), desc=f)
yield from tqdm(ArchiveIterator(open(f, "rb"), arc2warc=True), desc=f)

for r in to_import():
db += r
