Skip to content

Commit

Permalink
refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
sotetsuk committed Jun 3, 2024
1 parent be09383 commit 57ccb5d
Showing 1 changed file with 12 additions and 24 deletions.
36 changes: 12 additions & 24 deletions pgx/bridge_bidding.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,16 +43,14 @@

def download_dds_results(download_dir="dds_results"):
"""Download and split the results into 100K chunks."""
os.makedirs(download_dir, exist_ok=True)
train_fname = os.path.join(download_dir, "dds_results_2.5M.npy")
if not os.path.exists(train_fname):
_download(DDS_RESULTS_TRAIN_URL, train_fname)
with open(train_fname, "rb") as f:

def split_data(data, prefix, base_i=0):
with open(data, "rb") as f:

Check warning on line 48 in pgx/bridge_bidding.py

View check run for this annotation

Codecov / codecov/patch

pgx/bridge_bidding.py#L47-L48

Added lines #L47 - L48 were not covered by tests
keys, values = jnp.load(f)
n = 100_000
m = keys.shape[0] // n
for i in range(m):
fname = os.path.join(download_dir, f"train_{i:03d}.npy")
fname = os.path.join(download_dir, f"{prefix}_{i:03d}.npy")

Check warning on line 53 in pgx/bridge_bidding.py

View check run for this annotation

Codecov / codecov/patch

pgx/bridge_bidding.py#L53

Added line #L53 was not covered by tests
with open(fname, "wb") as f:
print(
f"saving {fname} ... [{i * n}, {(i + 1) * n})",
Expand All @@ -66,27 +64,17 @@ def download_dds_results(download_dir="dds_results"):
),
)

os.makedirs(download_dir, exist_ok=True)

Check warning on line 67 in pgx/bridge_bidding.py

View check run for this annotation

Codecov / codecov/patch

pgx/bridge_bidding.py#L67

Added line #L67 was not covered by tests

train_fname = os.path.join(download_dir, "dds_results_2.5M.npy")
if not os.path.exists(train_fname):
_download(DDS_RESULTS_TRAIN_URL, train_fname)
split_data(train_fname, "train")

Check warning on line 72 in pgx/bridge_bidding.py

View check run for this annotation

Codecov / codecov/patch

pgx/bridge_bidding.py#L69-L72

Added lines #L69 - L72 were not covered by tests

test_fname = os.path.join(download_dir, "dds_results_500K.npy")
if not os.path.exists(test_fname):
_download(DDS_RESULTS_TEST_URL, test_fname)
with open(test_fname, "rb") as f:
keys, values = jnp.load(f)
n = 100_000
m = keys.shape[0] // n
for i in range(m):
fname = os.path.join(download_dir, f"test_{i:03d}.npy")
with open(fname, "wb") as f:
print(
f"saving {fname} ... [{i * n}, {(i + 1) * n})",
file=sys.stderr,
)
jnp.save(
f,
(
keys[i * n : (i + 1) * n],
values[i * n : (i + 1) * n],
),
)
split_data(test_fname, "test")

Check warning on line 77 in pgx/bridge_bidding.py

View check run for this annotation

Codecov / codecov/patch

pgx/bridge_bidding.py#L77

Added line #L77 was not covered by tests


@dataclass
Expand Down

0 comments on commit 57ccb5d

Please sign in to comment.