From 57ccb5dd468c4327046d804ef007a20869fc4b1f Mon Sep 17 00:00:00 2001 From: Sotetsu KOYAMADA Date: Mon, 3 Jun 2024 15:45:27 +0900 Subject: [PATCH 1/3] refactor --- pgx/bridge_bidding.py | 36 ++++++++++++------------------------ 1 file changed, 12 insertions(+), 24 deletions(-) diff --git a/pgx/bridge_bidding.py b/pgx/bridge_bidding.py index c4c622799..268f62989 100644 --- a/pgx/bridge_bidding.py +++ b/pgx/bridge_bidding.py @@ -43,16 +43,14 @@ def download_dds_results(download_dir="dds_results"): """Download and split the results into 100K chunks.""" - os.makedirs(download_dir, exist_ok=True) - train_fname = os.path.join(download_dir, "dds_results_2.5M.npy") - if not os.path.exists(train_fname): - _download(DDS_RESULTS_TRAIN_URL, train_fname) - with open(train_fname, "rb") as f: + + def split_data(data, prefix, base_i=0): + with open(data, "rb") as f: keys, values = jnp.load(f) n = 100_000 m = keys.shape[0] // n for i in range(m): - fname = os.path.join(download_dir, f"train_{i:03d}.npy") + fname = os.path.join(download_dir, f"{prefix}_{i:03d}.npy") with open(fname, "wb") as f: print( f"saving {fname} ... [{i * n}, {(i + 1) * n})", @@ -66,27 +64,17 @@ def download_dds_results(download_dir="dds_results"): ), ) + os.makedirs(download_dir, exist_ok=True) + + train_fname = os.path.join(download_dir, "dds_results_2.5M.npy") + if not os.path.exists(train_fname): + _download(DDS_RESULTS_TRAIN_URL, train_fname) + split_data(train_fname, "train") + test_fname = os.path.join(download_dir, "dds_results_500K.npy") if not os.path.exists(test_fname): _download(DDS_RESULTS_TEST_URL, test_fname) - with open(test_fname, "rb") as f: - keys, values = jnp.load(f) - n = 100_000 - m = keys.shape[0] // n - for i in range(m): - fname = os.path.join(download_dir, f"test_{i:03d}.npy") - with open(fname, "wb") as f: - print( - f"saving {fname} ... [{i * n}, {(i + 1) * n})", - file=sys.stderr, - ) - jnp.save( - f, - ( - keys[i * n : (i + 1) * n], - values[i * n : (i + 1) * n], - ), - ) + split_data(test_fname, "test") @dataclass From 7b9daf2fc267e8156934792d769910e2c7c273e0 Mon Sep 17 00:00:00 2001 From: Sotetsu KOYAMADA Date: Wed, 5 Jun 2024 14:44:23 +0900 Subject: [PATCH 2/3] fix --- pgx/bridge_bidding.py | 20 +++++++++++++------- 1 file changed, 13 insertions(+), 7 deletions(-) diff --git a/pgx/bridge_bidding.py b/pgx/bridge_bidding.py index 268f62989..678db2a13 100644 --- a/pgx/bridge_bidding.py +++ b/pgx/bridge_bidding.py @@ -37,8 +37,9 @@ REDOUBLE_ACTION_NUM = 2 BID_OFFSET_NUM = 3 -DDS_RESULTS_TRAIN_URL = "https://drive.google.com/uc?id=1qINu6uIVLJj95oEK3QodsI3aqvOpEozp" -DDS_RESULTS_TEST_URL = "https://drive.google.com/uc?id=1fNPdJTPw03QrxyOgo-7PvVi5kRI_IZST" +DDS_RESULTS_TRAIN_SMALL_URL = "https://huggingface.co/datasets/sotetsuk/dds_dataset/resolve/main/dds_results_2.5M.npy" +DDS_RESULTS_TRAIN_LARGE_URL = "https://huggingface.co/datasets/sotetsuk/dds_dataset/resolve/main/dds_results_10M.npy" +DDS_RESULTS_TEST_URL = "https://huggingface.co/datasets/sotetsuk/dds_dataset/resolve/main/dds_results_500K.npy" def download_dds_results(download_dir="dds_results"): @@ -50,7 +51,7 @@ def split_data(data, prefix, base_i=0): n = 100_000 m = keys.shape[0] // n for i in range(m): - fname = os.path.join(download_dir, f"{prefix}_{i:03d}.npy") + fname = os.path.join(download_dir, f"{prefix}_{base_i + i:03d}.npy") with open(fname, "wb") as f: print( f"saving {fname} ... [{i * n}, {(i + 1) * n})", @@ -66,10 +67,15 @@ def split_data(data, prefix, base_i=0): os.makedirs(download_dir, exist_ok=True) - train_fname = os.path.join(download_dir, "dds_results_2.5M.npy") - if not os.path.exists(train_fname): - _download(DDS_RESULTS_TRAIN_URL, train_fname) - split_data(train_fname, "train") + train_small_fname = os.path.join(download_dir, "dds_results_2.5M.npy") + if not os.path.exists(train_small_fname): + _download(DDS_RESULTS_TRAIN_SMALL_URL, train_small_fname) + split_data(train_small_fname, "train") + + train_large_fname = os.path.join(download_dir, "dds_results_10M.npy") + if not os.path.exists(train_large_fname): + _download(DDS_RESULTS_TRAIN_LARGE_URL, train_large_fname) + split_data(train_large_fname, "train", base_i=25) test_fname = os.path.join(download_dir, "dds_results_500K.npy") if not os.path.exists(test_fname): From 3db1e1bc70422c1bbd8dd88a837d6c5a5ab71e8e Mon Sep 17 00:00:00 2001 From: Sotetsu KOYAMADA Date: Wed, 5 Jun 2024 14:47:08 +0900 Subject: [PATCH 3/3] fmt --- pgx/bridge_bidding.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/pgx/bridge_bidding.py b/pgx/bridge_bidding.py index 678db2a13..c9639671b 100644 --- a/pgx/bridge_bidding.py +++ b/pgx/bridge_bidding.py @@ -66,17 +66,17 @@ def split_data(data, prefix, base_i=0): ) os.makedirs(download_dir, exist_ok=True) - + train_small_fname = os.path.join(download_dir, "dds_results_2.5M.npy") if not os.path.exists(train_small_fname): _download(DDS_RESULTS_TRAIN_SMALL_URL, train_small_fname) - split_data(train_small_fname, "train") - + split_data(train_small_fname, "train") + train_large_fname = os.path.join(download_dir, "dds_results_10M.npy") if not os.path.exists(train_large_fname): _download(DDS_RESULTS_TRAIN_LARGE_URL, train_large_fname) - split_data(train_large_fname, "train", base_i=25) - + split_data(train_large_fname, "train", base_i=25) + test_fname = os.path.join(download_dir, "dds_results_500K.npy") if not os.path.exists(test_fname): _download(DDS_RESULTS_TEST_URL, test_fname)