REDOUBLE_ACTION_NUM = 2
BID_OFFSET_NUM = 3

DDS_RESULTS_TRAIN_SMALL_URL = "https://huggingface.co/datasets/sotetsuk/dds_dataset/resolve/main/dds_results_2.5M.npy"
DDS_RESULTS_TRAIN_LARGE_URL = "https://huggingface.co/datasets/sotetsuk/dds_dataset/resolve/main/dds_results_10M.npy"
DDS_RESULTS_TEST_URL = "https://huggingface.co/datasets/sotetsuk/dds_dataset/resolve/main/dds_results_500K.npy"


def download_dds_results(download_dir="dds_results", chunk_size=100_000):
    """Download the DDS result files and split each into fixed-size chunks.

    Three files are fetched into ``download_dir`` -- ``dds_results_2.5M.npy``
    and ``dds_results_10M.npy`` (training) and ``dds_results_500K.npy``
    (test) -- each only when it is not already present on disk. Every file
    is then cut into consecutive chunks of ``chunk_size`` rows, written as
    ``train_XXX.npy`` / ``test_XXX.npy``. Training chunk numbers continue
    from the small file into the large one, so with the default chunk size
    and the full 2.5M-row file the large set starts at ``train_025.npy``.

    Rows beyond the last full chunk are not written. Chunk files are
    rewritten on every call, even if they already exist.

    Args:
        download_dir: Directory that receives the downloads and the chunks;
            created if missing.
        chunk_size: Number of rows per chunk file (100K by default).

    Raises:
        ValueError: If ``chunk_size`` is not positive.
    """
    if chunk_size <= 0:
        raise ValueError(f"chunk_size must be positive, got {chunk_size}")

    def split_data(data, prefix, base_i=0):
        """Split the (keys, values) pair stored in ``data`` into chunk files.

        Chunks are named ``{prefix}_{base_i + i:03d}.npy``. Returns the number
        of chunks written so the caller can continue the numbering.
        """
        with open(data, "rb") as src:
            keys, values = jnp.load(src)
        # Floor division: a trailing partial chunk is intentionally skipped.
        num_chunks = keys.shape[0] // chunk_size
        for i in range(num_chunks):
            start, stop = i * chunk_size, (i + 1) * chunk_size
            fname = os.path.join(download_dir, f"{prefix}_{base_i + i:03d}.npy")
            with open(fname, "wb") as dst:
                print(f"saving {fname} ... [{start}, {stop})", file=sys.stderr)
                jnp.save(dst, (keys[start:stop], values[start:stop]))
        return num_chunks

    os.makedirs(download_dir, exist_ok=True)

    train_small_fname = os.path.join(download_dir, "dds_results_2.5M.npy")
    if not os.path.exists(train_small_fname):
        _download(DDS_RESULTS_TRAIN_SMALL_URL, train_small_fname)
    num_small_chunks = split_data(train_small_fname, "train")

    train_large_fname = os.path.join(download_dir, "dds_results_10M.npy")
    if not os.path.exists(train_large_fname):
        _download(DDS_RESULTS_TRAIN_LARGE_URL, train_large_fname)
    # Continue numbering after the small set instead of assuming 25 chunks.
    split_data(train_large_fname, "train", base_i=num_small_chunks)

    test_fname = os.path.join(download_dir, "dds_results_500K.npy")
    if not os.path.exists(test_fname):
        _download(DDS_RESULTS_TEST_URL, test_fname)
    split_data(test_fname, "test")