diff --git a/decode/qlook.py b/decode/qlook.py index 53649f2..5406660 100644 --- a/decode/qlook.py +++ b/decode/qlook.py @@ -29,6 +29,8 @@ DEFAULT_FORMAT = "png" DEFAULT_FREQUENCY_UNITS = "GHz" DEFAULT_INCL_MKID_IDS = None +DEFAULT_OUTDIR = Path() +DEFAULT_OVERWRITE = False DEFAULT_SKYCOORD_GRID = "6 arcsec" DEFAULT_SKYCOORD_UNITS = "arcsec" SIGMA_OVER_MAD = 1.4826 @@ -43,7 +45,8 @@ def pswsc( data_type: Literal["df/f", "brightness", None] = DEFAULT_DATA_TYPE, frequency_units: str = DEFAULT_FREQUENCY_UNITS, format: str = DEFAULT_FORMAT, - outdir: Path = Path(), + outdir: Path = DEFAULT_OUTDIR, + overwrite: bool = DEFAULT_OVERWRITE, ) -> Path: """Quick-look at a PSW observation with sky chopper. @@ -58,6 +61,7 @@ def pswsc( frequency_units: Units of the frequency axis. format: Output data format of the quick-look result. outdir: Output directory for the quick-look result. + overwrite: Whether to overwrite the output if it exists. Returns: Absolute path of the saved file. @@ -77,10 +81,11 @@ def pswsc( spec = da_sub.mean("scan") # save result - filename = Path(dems).with_suffix(f".pswsc.{format}").name + file_name = Path(dems).with_suffix(f".pswsc.{format}").name + file_path = Path(outdir) / file_name if format in DATA_FORMATS: - return save_qlook(spec, Path(outdir) / filename) + return save_qlook(spec, file_path, overwrite=overwrite) fig, axes = plt.subplots(1, 2, figsize=DEFAULT_FIGSIZE) @@ -96,7 +101,7 @@ def pswsc( ax.grid(True) fig.tight_layout() - return save_qlook(fig, Path(outdir) / filename) + return save_qlook(fig, file_path, overwrite=overwrite) def raster( @@ -111,7 +116,8 @@ def raster( skycoord_grid: str = DEFAULT_SKYCOORD_GRID, skycoord_units: str = DEFAULT_SKYCOORD_UNITS, format: str = DEFAULT_FORMAT, - outdir: Path = Path(), + outdir: Path = DEFAULT_OUTDIR, + overwrite: bool = DEFAULT_OVERWRITE, ) -> Path: """Quick-look at a raster scan observation. @@ -134,6 +140,7 @@ def raster( skycoord_units: Units of the sky coordinate axes. format: Output image format of quick-look result. outdir: Output directory for the quick-look result. + overwrite: Whether to overwrite the output if it exists. Returns: Absolute path of the saved file. @@ -174,10 +181,11 @@ def raster( cont = cube.weighted(weight.fillna(0)).mean("chan") # save result - filename = Path(dems).with_suffix(f".raster.{format}").name + file_name = Path(dems).with_suffix(f".raster.{format}").name + file_path = Path(outdir) / file_name if format in DATA_FORMATS: - return save_qlook(cont, Path(outdir) / filename) + return save_qlook(cont, file_path, overwrite=overwrite) fig, axes = plt.subplots(1, 2, figsize=(12, 5.5)) @@ -202,7 +210,7 @@ def raster( ax.grid(True) fig.tight_layout() - return save_qlook(fig, Path(outdir) / filename) + return save_qlook(fig, file_path, overwrite=overwrite) def skydip( @@ -215,7 +223,8 @@ def skydip( chan_weight: Literal["uniform", "std", "std/tx"] = "std/tx", pwv: Literal["0.5", "1.0", "2.0", "3.0", "4.0", "5.0"] = "5.0", format: str = DEFAULT_FORMAT, - outdir: Path = Path(), + outdir: Path = DEFAULT_OUTDIR, + overwrite: bool = DEFAULT_OVERWRITE, ) -> Path: """Quick-look at a skydip observation. @@ -236,6 +245,7 @@ def skydip( the atmospheric transmission when chan_weight is std/tx. format: Output image format of quick-look result. outdir: Output directory for the quick-look result. + overwrite: Whether to overwrite the output if it exists. Returns: Absolute path of the saved file. @@ -255,10 +265,11 @@ def skydip( series = da_on.weighted(weight.fillna(0)).mean("chan") # save result - filename = Path(dems).with_suffix(f".skydip.{format}").name + file_name = Path(dems).with_suffix(f".raster.{format}").name + file_path = Path(outdir) / file_name if format in DATA_FORMATS: - return save_qlook(series, Path(outdir) / filename) + return save_qlook(series, file_path, overwrite=overwrite) fig, axes = plt.subplots(1, 2, figsize=DEFAULT_FIGSIZE) @@ -273,7 +284,7 @@ def skydip( ax.grid(True) fig.tight_layout() - return save_qlook(fig, Path(outdir) / filename) + return save_qlook(series, file_path, overwrite=overwrite) def still( @@ -286,7 +297,8 @@ def still( chan_weight: Literal["uniform", "std", "std/tx"] = "std/tx", pwv: Literal["0.5", "1.0", "2.0", "3.0", "4.0", "5.0"] = "5.0", format: str = DEFAULT_FORMAT, - outdir: Path = Path(), + outdir: Path = DEFAULT_OUTDIR, + overwrite: bool = DEFAULT_OVERWRITE, ) -> Path: """Quick-look at a still observation. @@ -307,6 +319,7 @@ def still( the atmospheric transmission when chan_weight is std/tx. format: Output data format of the quick-look result. outdir: Output directory for the quick-look result. + overwrite: Whether to overwrite the output if it exists. Returns: Absolute path of the saved file. @@ -325,10 +338,11 @@ def still( series = da.weighted(weight.fillna(0)).mean("chan") # save result - filename = Path(dems).with_suffix(f".still.{format}").name + file_name = Path(dems).with_suffix(f".still.{format}").name + file_path = Path(outdir) / file_name if format in DATA_FORMATS: - return save_qlook(series, Path(outdir) / filename) + return save_qlook(series, file_path, overwrite=overwrite) fig, axes = plt.subplots(1, 2, figsize=DEFAULT_FIGSIZE) @@ -343,7 +357,7 @@ def still( ax.grid(True) fig.tight_layout() - return save_qlook(fig, Path(outdir) / filename) + return save_qlook(series, file_path, overwrite=overwrite) def zscan( @@ -356,7 +370,8 @@ def zscan( chan_weight: Literal["uniform", "std", "std/tx"] = "std/tx", pwv: Literal["0.5", "1.0", "2.0", "3.0", "4.0", "5.0"] = "5.0", format: str = DEFAULT_FORMAT, - outdir: Path = Path(), + outdir: Path = DEFAULT_OUTDIR, + overwrite: bool = DEFAULT_OVERWRITE, ) -> Path: """Quick-look at an observation of subref axial focus scan. @@ -377,6 +392,7 @@ def zscan( the atmospheric transmission when chan_weight is std/tx. format: Output image format of quick-look result. outdir: Output directory for the quick-look result. + overwrite: Whether to overwrite the output if it exists. Returns: Absolute path of the saved file. @@ -396,10 +412,11 @@ def zscan( series = da_on.weighted(weight.fillna(0)).mean("chan") # save result - filename = Path(dems).with_suffix(f".zscan.{format}").name + file_name = Path(dems).with_suffix(f".zscan.{format}").name + file_path = Path(outdir) / file_name if format in DATA_FORMATS: - return save_qlook(series, Path(outdir) / filename) + return save_qlook(series, file_path, overwrite=overwrite) fig, axes = plt.subplots(1, 2, figsize=DEFAULT_FIGSIZE) @@ -414,7 +431,7 @@ def zscan( ax.grid(True) fig.tight_layout() - return save_qlook(fig, Path(outdir) / filename) + return save_qlook(series, file_path, overwrite=overwrite) def mean_in_time(dems: xr.DataArray) -> xr.DataArray: @@ -565,28 +582,52 @@ def load_dems( raise ValueError("Data type could not be inferred.") -def save_qlook(qlook: Union[Figure, xr.DataArray], filename: Path) -> Path: +def save_qlook( + qlook: Union[Figure, xr.DataArray], + file: Path, + /, + *, + overwrite: bool = False, +) -> Path: """Save a quick look result to a file with given format. Args: qlook: Matplotlib figure or DataArray to be saved. - filename: Path of the saved file. + file: Path of the saved file. + overwrite: Whether to overwrite the file if it exists. Returns: Absolute path of the saved file. """ + path = Path(file).expanduser().resolve() + + if path.exists() and not overwrite: + raise FileExistsError(f"{path} already exists.") + if isinstance(qlook, Figure): - qlook.savefig(filename) - elif (ext := "".join(filename.suffixes)) == ".csv": + qlook.savefig(path) + return path + + if path.name.endswith(".csv"): name = qlook.attrs["data_type"] - qlook.to_dataset(name=name).to_pandas().to_csv(filename) - elif ext == ".nc": - qlook.to_netcdf(filename) - elif ext == ".zarr" or format == ".zarr.zip": - qlook.to_zarr(filename, mode="w") + ds = qlook.to_dataset(name=name) + ds.to_pandas().to_csv(path) + return path + + if path.name.endswith(".nc"): + qlook.to_netcdf(path) + return path + + if path.name.endswith(".zarr"): + qlook.to_zarr(path, mode="w") + return path + + if path.name.endswith(".zarr.zip"): + qlook.to_zarr(path, mode="w") + return path - return Path(filename).expanduser().resolve() + raise ValueError("Extension of filename is not valid.") def main() -> None: