added csv_to_parquet function
edsu committed Jul 25, 2024
1 parent fe2b6e7 commit 080026d
Showing 3 changed files with 35 additions and 2 deletions.
3 changes: 2 additions & 1 deletion pyproject.toml
@@ -11,7 +11,8 @@ dependencies = [
"dimcli",
"polars>=1.2",
"pyalex",
"more-itertools"
"more-itertools",
"pyarrow"
]

[tool.pytest.ini_options]
22 changes: 21 additions & 1 deletion rialto_airflow/utils.py
@@ -1,7 +1,12 @@
 import csv
 import datetime
-from pathlib import Path
 import re
+import sys
+from itertools import batched
+from pathlib import Path
+
+import pyarrow
+from pyarrow.parquet import ParquetWriter


 def create_snapshot_dir(data_dir):
@@ -63,3 +68,18 @@ def normalize_doi(doi):
     doi = re.sub("^doi: ", "", doi)

     return doi
+
+
+def csv_to_parquet(csv_file, parquet_file, batch_size=10_000):
+    csv.field_size_limit(sys.maxsize)
+
+    csv_input = open(csv_file)
+    reader = csv.DictReader(csv_input)
+
+    # naively assume all columns are strings
+    schema = pyarrow.schema([(name, pyarrow.string()) for name in reader.fieldnames])
+
+    with ParquetWriter(open(parquet_file, "wb"), schema, compression="zstd") as writer:
+        for rows in batched(reader, batch_size):
+            table = pyarrow.Table.from_pylist(rows, schema)
+            writer.write_table(table)
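For orientation, a minimal usage sketch of the new helper (the input path is made up for illustration; only the csv_to_parquet signature and its defaults come from the diff above): convert a CSV snapshot to Parquet in 10,000-row batches, then scan the result lazily with Polars so the whole dataset never has to sit in memory.

import polars

from rialto_airflow.utils import csv_to_parquet

# hypothetical input file; any CSV with a header row will do
csv_to_parquet("data/snapshot/publications.csv", "data/snapshot/publications.parquet")

# the batched ParquetWriter keeps memory bounded while writing, and the
# Parquet output can be scanned lazily when reading back
preview = polars.scan_parquet("data/snapshot/publications.parquet").head(5).collect()
print(preview)

The zstd compression and the 10,000-row batch size are simply the defaults baked into the function in this commit.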
12 changes: 12 additions & 0 deletions test/test_utils.py
@@ -2,6 +2,7 @@
 from pathlib import Path

 import pytest
+import polars

 from rialto_airflow import utils

@@ -67,3 +68,14 @@ def test_normalize_doi():
== "10.1103/physrevlett.96.07390"
)
assert utils.normalize_doi(" doi: 10.1234/5678 ") == "10.1234/5678"


def test_csv_to_parquet(tmp_path):
csv_file = Path("test/data/authors.csv")
parquet_file = tmp_path / "authors.parquet"
utils.csv_to_parquet(csv_file, parquet_file)

assert parquet_file.is_file()
df = polars.read_parquet(parquet_file)

assert df.shape == (10, 2)
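One caveat worth noting: as the in-code comment says, csv_to_parquet writes every column as a string, so callers that need typed columns have to cast after loading. A hedged Polars sketch with a hypothetical column name (the real columns of test/data/authors.csv are not shown in this commit):

import polars

# hypothetical column, for illustration only: numeric IDs come back as
# strings because csv_to_parquet writes an all-string schema
df = polars.read_parquet("authors.parquet").with_columns(
    polars.col("cap_profile_id").cast(polars.Int64, strict=False)
)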
