Skip to content

Commit

Permalink
[Bridge] Enhance DDS dataset (#1187)
Browse files Browse the repository at this point in the history
  • Loading branch information
sotetsuk committed Jun 5, 2024
1 parent be09383 commit 51e9055
Showing 1 changed file with 20 additions and 26 deletions.
46 changes: 20 additions & 26 deletions pgx/bridge_bidding.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,22 +37,21 @@
REDOUBLE_ACTION_NUM = 2
BID_OFFSET_NUM = 3

DDS_RESULTS_TRAIN_URL = "https://drive.google.com/uc?id=1qINu6uIVLJj95oEK3QodsI3aqvOpEozp"
DDS_RESULTS_TEST_URL = "https://drive.google.com/uc?id=1fNPdJTPw03QrxyOgo-7PvVi5kRI_IZST"
DDS_RESULTS_TRAIN_SMALL_URL = "https://huggingface.co/datasets/sotetsuk/dds_dataset/resolve/main/dds_results_2.5M.npy"
DDS_RESULTS_TRAIN_LARGE_URL = "https://huggingface.co/datasets/sotetsuk/dds_dataset/resolve/main/dds_results_10M.npy"
DDS_RESULTS_TEST_URL = "https://huggingface.co/datasets/sotetsuk/dds_dataset/resolve/main/dds_results_500K.npy"


def download_dds_results(download_dir="dds_results"):
"""Download and split the results into 100K chunks."""
os.makedirs(download_dir, exist_ok=True)
train_fname = os.path.join(download_dir, "dds_results_2.5M.npy")
if not os.path.exists(train_fname):
_download(DDS_RESULTS_TRAIN_URL, train_fname)
with open(train_fname, "rb") as f:

def split_data(data, prefix, base_i=0):
with open(data, "rb") as f:
keys, values = jnp.load(f)
n = 100_000
m = keys.shape[0] // n
for i in range(m):
fname = os.path.join(download_dir, f"train_{i:03d}.npy")
fname = os.path.join(download_dir, f"{prefix}_{base_i + i:03d}.npy")
with open(fname, "wb") as f:
print(
f"saving {fname} ... [{i * n}, {(i + 1) * n})",
Expand All @@ -66,27 +65,22 @@ def download_dds_results(download_dir="dds_results"):
),
)

os.makedirs(download_dir, exist_ok=True)

train_small_fname = os.path.join(download_dir, "dds_results_2.5M.npy")
if not os.path.exists(train_small_fname):
_download(DDS_RESULTS_TRAIN_SMALL_URL, train_small_fname)
split_data(train_small_fname, "train")

train_large_fname = os.path.join(download_dir, "dds_results_10M.npy")
if not os.path.exists(train_large_fname):
_download(DDS_RESULTS_TRAIN_LARGE_URL, train_large_fname)
split_data(train_large_fname, "train", base_i=25)

test_fname = os.path.join(download_dir, "dds_results_500K.npy")
if not os.path.exists(test_fname):
_download(DDS_RESULTS_TEST_URL, test_fname)
with open(test_fname, "rb") as f:
keys, values = jnp.load(f)
n = 100_000
m = keys.shape[0] // n
for i in range(m):
fname = os.path.join(download_dir, f"test_{i:03d}.npy")
with open(fname, "wb") as f:
print(
f"saving {fname} ... [{i * n}, {(i + 1) * n})",
file=sys.stderr,
)
jnp.save(
f,
(
keys[i * n : (i + 1) * n],
values[i * n : (i + 1) * n],
),
)
split_data(test_fname, "test")


@dataclass
Expand Down

0 comments on commit 51e9055

Please sign in to comment.