From 04ac79bed183f70b74918e9d5386424d4937132f Mon Sep 17 00:00:00 2001 From: Ismael Mendoza <11745764+ismael-mendoza@users.noreply.github.com> Date: Wed, 26 Jun 2024 15:51:12 -0400 Subject: [PATCH] Stamp size again (#505) * change it so stamp size is used as a mandatory argument of sampling functions * stamp size is automatically propagated from sampling functions * propagate changes to stamp size --- README.md | 2 +- btk/draw_blends.py | 7 +-- btk/sampling_functions.py | 68 +++++++++++++++----------- notebooks/00-quickstart.ipynb | 2 - notebooks/01-advanced-deblending.ipynb | 1 - notebooks/01-advanced-generation.ipynb | 9 ++-- notebooks/01-advanced-metrics.ipynb | 1 - notebooks/02-advanced-plots.ipynb | 1 - tests/test_cosmos.py | 1 - tests/test_deblenders.py | 2 - tests/test_measure.py | 1 - tests/test_pipeline.py | 1 - 12 files changed, 46 insertions(+), 50 deletions(-) diff --git a/README.md b/README.md index 3f9ec9db..90011394 100644 --- a/README.md +++ b/README.md @@ -51,7 +51,7 @@ sampling_function = btk.sampling_functions.DefaultSampling( # setup generator to create batches of blends batch_size = 100 draw_generator = btk.draw_blends.CatsimGenerator( - catalog, sampling_function, survey, batch_size, stamp_size + catalog, sampling_function, survey, batch_size ) # get batch of blends diff --git a/btk/draw_blends.py b/btk/draw_blends.py index 1eb6a483..cbe7ddae 100644 --- a/btk/draw_blends.py +++ b/btk/draw_blends.py @@ -148,7 +148,6 @@ def __init__( sampling_function: SamplingFunction, surveys: Union[List[Survey], Survey], batch_size: int = 8, - stamp_size: float = 24.0, njobs: int = 1, verbose: bool = False, use_bar: bool = False, @@ -165,7 +164,6 @@ def __init__( surveys: List of BTK Survey objects or single BTK Survey object. batch_size: Number of blends generated per batch - stamp_size: Size of the stamps, in arcseconds njobs: Number of njobs to use; defines the number of minibatches verbose: Indicates whether additionnal information should be printed use_bar: Whether to use progress bar (default: False) @@ -187,7 +185,7 @@ def __init__( self.max_number = self.blend_generator.max_number self.apply_shear = apply_shear self.augment_data = augment_data - self.stamp_size = stamp_size + self.stamp_size = sampling_function.stamp_size self.use_bar = use_bar self._set_surveys(surveys) @@ -523,7 +521,6 @@ def __init__( sampling_function: SamplingFunction, surveys: List[Survey], batch_size: int = 8, - stamp_size: float = 24.0, njobs: int = 1, verbose: bool = False, add_noise: str = "all", @@ -541,7 +538,6 @@ def __init__( sampling_function: See parent class. surveys: See parent class. batch_size: See parent class. - stamp_size: See parent class. njobs: See parent class. verbose: See parent class. add_noise: See parent class. @@ -563,7 +559,6 @@ def __init__( sampling_function, surveys, batch_size, - stamp_size, njobs, verbose, use_bar, diff --git a/btk/sampling_functions.py b/btk/sampling_functions.py index 2fba0500..3448cf4b 100644 --- a/btk/sampling_functions.py +++ b/btk/sampling_functions.py @@ -64,14 +64,16 @@ class SamplingFunction(ABC): galaxies chosen for the blend. """ - def __init__(self, max_number: int, min_number: int = 1, seed=DEFAULT_SEED): + def __init__(self, stamp_size: int, max_number: int, min_number: int = 1, seed=DEFAULT_SEED): """Initializes the SamplingFunction. Args: + stamp_size: The size of the stamp in arcseconds. max_number: maximum number of catalog entries returned from sample. min_number: minimum number of catalog entries returned from sample. (Default: 1) seed: Seed to initialize randomness for reproducibility. (Default: btk.DEFAULT_SEED) """ + self.stamp_size = stamp_size self.min_number = min_number self.max_number = max_number @@ -93,9 +95,9 @@ class DefaultSampling(SamplingFunction): def __init__( self, + stamp_size: float = 24.0, max_number: int = 2, min_number: int = 1, - stamp_size: float = 24.0, max_shift: Optional[float] = None, seed: int = DEFAULT_SEED, max_mag: float = 25.3, @@ -105,9 +107,9 @@ def __init__( """Initializes default sampling function. Args: + stamp_size: Size of the desired stamp. max_number: Defined in parent class min_number: Defined in parent class - stamp_size: Size of the desired stamp. max_shift: Magnitude of maximum value of shift. If None then it is set as one-tenth the stamp size. (in arcseconds) seed: Seed to initialize randomness for reproducibility. @@ -115,8 +117,9 @@ def __init__( max_mag: Maximum magnitude allowed in samples. mag_name: Name of the magnitude column in the catalog. """ - super().__init__(max_number=max_number, min_number=min_number, seed=seed) - self.stamp_size = stamp_size + super().__init__( + stamp_size=stamp_size, max_number=max_number, min_number=min_number, seed=seed + ) self.max_shift = max_shift if max_shift is not None else self.stamp_size / 10.0 self.min_mag, self.max_mag = min_mag, max_mag self.mag_name = mag_name @@ -165,10 +168,10 @@ class DensitySampling(SamplingFunction): def __init__( self, + stamp_size: float = 24.0, max_number: int = 40, min_number: int = 0, density: float = 185, - stamp_size: float = 24.0, max_shift: Optional[float] = None, seed: int = DEFAULT_SEED, max_mag: float = 27.3, @@ -178,11 +181,11 @@ def __init__( """Initializes default sampling function. Args: + stamp_size: Size of the desired stamp (in arcseconds) max_number: Defined in parent class min_number: Defined in parent class density: Density of galaxies, default corresponds to 27.3 i-band magnitude cut in CATSIM catalog. (in counts / sq. arcmin) - stamp_size: Size of the desired stamp (in arcseconds) max_shift: Magnitude of maximum value of shift. If None, then centroids can fall anywhere within the image. (in arcseconds) seed: Seed to initialize randomness for reproducibility. @@ -190,8 +193,9 @@ def __init__( max_mag: Maximum magnitude allowed in samples. mag_name: Name of the magnitude column in the catalog. """ - super().__init__(max_number=max_number, min_number=min_number, seed=seed) - self.stamp_size = stamp_size + super().__init__( + stamp_size=stamp_size, max_number=max_number, min_number=min_number, seed=seed + ) self.min_mag, self.max_mag = min_mag, max_mag self.mag_name = mag_name self.max_shift = max_shift if max_shift else self.stamp_size / 2 @@ -249,23 +253,24 @@ class BasicSampling(SamplingFunction): def __init__( self, + stamp_size: float = 24.0, max_number: int = 4, min_number: int = 1, - stamp_size: float = 24.0, mag_name: str = "i_ab", seed: int = DEFAULT_SEED, ): """Initializes the basic sampling function. Args: + stamp_size: Size of the desired stamp. max_number: Defined in parent class. min_number: Defined in parent class. - stamp_size: Size of the desired stamp. seed: Seed to initialize randomness for reproducibility. mag_name: Name of the magnitude column in the catalog for cuts. """ - super().__init__(max_number=max_number, min_number=min_number, seed=seed) - self.stamp_size = stamp_size + super().__init__( + stamp_size=stamp_size, max_number=max_number, min_number=min_number, seed=seed + ) self.mag_name = mag_name if min_number < 1: @@ -328,24 +333,33 @@ class DefaultSamplingShear(DefaultSampling): def __init__( self, + stamp_size: float = 24.0, max_number: int = 2, min_number: int = 1, - stamp_size: float = 24.0, max_shift: Optional[float] = None, seed=DEFAULT_SEED, + max_mag: float = 25.3, + min_mag: float = -np.inf, + mag_name: str = "i_ab", shear: Tuple[float, float] = (0.0, 0.0), ): """Initializes default sampling function with shear. Args: + stamp_size: Defined in parent class. max_number: Defined in parent class. min_number: Defined in parent class. stamp_size: Defined in parent class. max_shift: Defined in parent class. seed: Defined in parent class. + max_mag: Defined in parent class. + min_mag: Defined in parent class. + mag_name: Defined in parent class. shear: Constant (g1,g2) shear to apply to every galaxy. """ - super().__init__(max_number, min_number, stamp_size, max_shift, seed) + super().__init__( + stamp_size, max_number, min_number, max_shift, seed, max_mag, min_mag, mag_name + ) self.shear = shear def __call__(self, table: Table, **kwargs) -> Table: @@ -386,7 +400,7 @@ def __init__( bright_cut: Magnitude cut for bright galaxy. (Default: 25.3) dim_cut: Magnitude cut for dim galaxy. (Default: 28.0) """ - super().__init__(2, 1, seed) + super().__init__(stamp_size, 2, 1, seed) self.stamp_size = stamp_size self.max_shift = max_shift if max_shift is not None else self.stamp_size / 10.0 self.mag_name = mag_name @@ -427,8 +441,8 @@ class RandomSquareSampling(SamplingFunction): def __init__( self, - max_number: int = 2, stamp_size: float = 24.0, + max_number: int = 2, seed: int = DEFAULT_SEED, max_mag: float = 25.3, min_mag: float = -np.inf, @@ -437,16 +451,14 @@ def __init__( """Initializes the RandomSquareSampling sampling function. Args: - max_number: Defined in parent class stamp_size: Size of the desired stamp (arcsec). + max_number: Defined in parent class seed: Seed to initialize randomness for reproducibility. min_mag: Minimum magnitude allowed in samples max_mag: Maximum magnitude allowed in samples. mag_name: Name of the magnitude column in the catalog. """ - super().__init__(max_number=max_number, min_number=0, seed=seed) - self.stamp_size = stamp_size - self.max_number = max_number + super().__init__(stamp_size=stamp_size, max_number=max_number, min_number=0, seed=seed) self.max_mag = max_mag self.min_mag = min_mag self.mag_name = mag_name @@ -527,10 +539,10 @@ class FriendsOfFriendsSampling(SamplingFunction): def __init__( self, + stamp_size: float = 24.0, max_number: int = 10, min_number: int = 2, link_distance: int = 2.5, - stamp_size: float = 24.0, seed: int = DEFAULT_SEED, min_mag: float = -np.inf, max_mag: float = 25.3, @@ -539,8 +551,9 @@ def __init__( """Initializes the FriendsOfFriendsSampling sampling function. Args: - max_number: Defined in parent class - min_number: Defined in parent class + stamp_size: Defined in parent class. + max_number: Defined in parent class. + min_number: Defined in parent class. link_distance: Minimum linkage distance to form a group (arcsec). stamp_size: Size of the desired stamp (arcsec). seed: Seed to initialize randomness for reproducibility. @@ -548,11 +561,10 @@ def __init__( max_mag: Maximum magnitude allowed in samples. mag_name: Name of the magnitude column in the catalog. """ - super().__init__(max_number=max_number, min_number=min_number, seed=seed) - self.stamp_size = stamp_size + super().__init__( + stamp_size=stamp_size, max_number=max_number, min_number=min_number, seed=seed + ) self.link_distance = link_distance - self.max_number = max_number - self.min_number = min_number self.max_mag = max_mag self.min_mag = min_mag self.mag_name = mag_name diff --git a/notebooks/00-quickstart.ipynb b/notebooks/00-quickstart.ipynb index 1f325cbd..4f0ce188 100644 --- a/notebooks/00-quickstart.ipynb +++ b/notebooks/00-quickstart.ipynb @@ -277,7 +277,6 @@ " sampling_function,\n", " LSST,\n", " batch_size=batch_size,\n", - " stamp_size=stamp_size,\n", " njobs=1,\n", " add_noise=\"all\",\n", " seed=seed, # use same seed here\n", @@ -1066,7 +1065,6 @@ " sampling_function,\n", " LSST,\n", " batch_size=batch_size,\n", - " stamp_size=stamp_size,\n", " njobs=1,\n", " add_noise=\"all\",\n", " seed=seed, # use same seed here\n", diff --git a/notebooks/01-advanced-deblending.ipynb b/notebooks/01-advanced-deblending.ipynb index 2e752b94..bec2abba 100644 --- a/notebooks/01-advanced-deblending.ipynb +++ b/notebooks/01-advanced-deblending.ipynb @@ -88,7 +88,6 @@ " sampling_function,\n", " LSST,\n", " batch_size=batch_size,\n", - " stamp_size=24.0,\n", " njobs=1,\n", " add_noise=\"all\",\n", " seed=SEED,\n", diff --git a/notebooks/01-advanced-generation.ipynb b/notebooks/01-advanced-generation.ipynb index 407390b1..dd7f4f14 100644 --- a/notebooks/01-advanced-generation.ipynb +++ b/notebooks/01-advanced-generation.ipynb @@ -219,16 +219,16 @@ "\n", " def __init__(\n", " self,\n", + " stamp_size: float = 24.0,\n", " max_number: int = 2,\n", " min_number: int = 1,\n", - " stamp_size: float = 24.0,\n", " max_shift: Optional[float] = None,\n", " seed: int = DEFAULT_SEED,\n", " max_mag: float = 25.3,\n", " min_mag: float = -np.inf,\n", " mag_name: str = \"i_ab\",\n", " ):\n", - " super().__init__(max_number=max_number, min_number=min_number, seed=seed)\n", + " super().__init__(stamp_size=stamp_size, max_number=max_number, min_number=min_number, seed=seed)\n", " self.stamp_size = stamp_size\n", " self.max_shift = max_shift if max_shift is not None else self.stamp_size / 10.0\n", " self.min_mag, self.max_mag = min_mag, max_mag\n", @@ -553,12 +553,11 @@ "from btk.sampling_functions import DefaultSampling\n", "from btk.draw_blends import CosmosGenerator\n", "\n", - "sampling_func = DefaultSampling(3, 1,\n", - " stamp_size=16,\n", + "sampling_func = DefaultSampling(16, 3, 1,\n", " mag_name=\"MAG\" # use the name available in the COSMOS catalog.\n", " )\n", "\n", - "generator = CosmosGenerator(cosmos_cat, sampling_func, lsst, batch_size=9, stamp_size=16)" + "generator = CosmosGenerator(cosmos_cat, sampling_func, lsst, batch_size=9)" ] }, { diff --git a/notebooks/01-advanced-metrics.ipynb b/notebooks/01-advanced-metrics.ipynb index a3e7e94d..66375c51 100644 --- a/notebooks/01-advanced-metrics.ipynb +++ b/notebooks/01-advanced-metrics.ipynb @@ -100,7 +100,6 @@ " sampling_function,\n", " survey,\n", " batch_size=batch_size,\n", - " stamp_size=stamp_size,\n", " njobs=1,\n", " add_noise=\"all\",\n", " seed=0,\n", diff --git a/notebooks/02-advanced-plots.ipynb b/notebooks/02-advanced-plots.ipynb index 4dc1d122..38d379a9 100644 --- a/notebooks/02-advanced-plots.ipynb +++ b/notebooks/02-advanced-plots.ipynb @@ -141,7 +141,6 @@ " sampling_function,\n", " survey,\n", " batch_size=batch_size,\n", - " stamp_size=stamp_size,\n", " njobs=1,\n", " add_noise=\"background\",\n", " seed=seed, # use same seed here\n", diff --git a/tests/test_cosmos.py b/tests/test_cosmos.py index 4ca9428c..bf6d1e9a 100644 --- a/tests/test_cosmos.py +++ b/tests/test_cosmos.py @@ -40,7 +40,6 @@ def test_cosmos_generator(data_dir): sampling_function, survey, batch_size=batch_size, - stamp_size=stamp_size, njobs=1, add_noise="all", seed=SEED, diff --git a/tests/test_deblenders.py b/tests/test_deblenders.py index d80c7dcd..9f958651 100644 --- a/tests/test_deblenders.py +++ b/tests/test_deblenders.py @@ -33,7 +33,6 @@ def test_sep(data_dir): sampling_function, survey, batch_size=batch_size, - stamp_size=24.0, njobs=1, add_noise="all", seed=SEED, @@ -84,7 +83,6 @@ def test_scarlet(data_dir): sampling_function, LSST, batch_size=batch_size, - stamp_size=stamp_size, njobs=1, add_noise="all", seed=seed, # use same seed here diff --git a/tests/test_measure.py b/tests/test_measure.py index 0b68a85c..030a64fd 100644 --- a/tests/test_measure.py +++ b/tests/test_measure.py @@ -40,7 +40,6 @@ def test_measure(data_dir): sampling_function, survey, batch_size=batch_size, - stamp_size=stamp_size, njobs=1, add_noise="all", seed=SEED, diff --git a/tests/test_pipeline.py b/tests/test_pipeline.py index b9f99fee..a770368e 100644 --- a/tests/test_pipeline.py +++ b/tests/test_pipeline.py @@ -39,7 +39,6 @@ def test_pipeline(data_dir): sampling_function, survey, batch_size=batch_size, - stamp_size=stamp_size, njobs=1, add_noise="all", seed=SEED,