Skip to content

Commit

Permalink
#100 Merge pull request from deshima-dev/astropenguin/issue94
Browse files Browse the repository at this point in the history
Add fit module
  • Loading branch information
astropenguin authored Oct 25, 2023
2 parents beec656 + cbefef5 commit 12b4ccb
Show file tree
Hide file tree
Showing 8 changed files with 10,253 additions and 20 deletions.
2 changes: 2 additions & 0 deletions decode/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
__all__ = [
"assign",
"correct",
"fit",
"load",
"plot",
"select",
Expand All @@ -11,6 +12,7 @@
# submodules
from . import assign
from . import correct
from . import fit
from . import load
from . import plot
from . import select
10,008 changes: 10,008 additions & 0 deletions decode/data/alma_atm.txt

Large diffs are not rendered by default.

95 changes: 95 additions & 0 deletions decode/fit.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
__all__ = ["baseline"]


# standard library
from typing import Any, Optional, Union


# dependencies
import numpy as np
import xarray as xr
from numpy.typing import NDArray
from sklearn import linear_model
from . import load


def baseline(
dems: xr.DataArray,
/,
*,
order: int = 0,
model: str = "LinearRegression",
weight: Optional[Union[NDArray[np.float_], float]] = None,
**options: Any,
) -> xr.DataArray:
"""Fit baseline by polynomial and atmospheric models.
Args:
dems: DEMS DataArray to be fit.
order: Maximum order of the polynomial model.
weight: One-dimensional weight along channel axis.
If it is a scalar, then ``(dtau/dpwv)^weight`` will be used.
It is only for ``'LinearRegression'`` or ``'Ridge'`` models.
model: Name of the model class in ``sklearn.linear_model``.
options: Optional arguments used for the model initialization.
Returns:
baseline: DataArray of the fit baseline.
"""
freq = dems.d2_mkid_frequency.values
slope = dtau_dpwv(freq).values
n_freq, n_poly = len(freq), order + 1

# create data to be fit
X = np.zeros([n_freq, n_poly + 1])
X[:, 0] = slope

for exp in range(n_poly):
X[:, exp + 1] = (freq - freq.mean()) ** exp

X /= np.linalg.norm(X, axis=0)
y = dems.values.T

if weight is None:
weight = np.ones_like(freq)
elif isinstance(weight, float):
weight = slope**weight
else:
weight = np.array(weight)

# fit model to data
options = {"fit_intercept": False, **options}
model = getattr(linear_model, model)(**options)

if model in ("LinearRegression", "Ridge"):
model.fit(X, y, sample_weight=weight) # type: ignore
else:
model.fit(X, y) # type: ignore

coeff: NDArray[np.float_] = model.coef_ # type: ignore

# create baseline
baseline = xr.zeros_like(dems)
baseline += np.outer(coeff[:, 0], X[:, 0])

for exp in range(n_poly + 1):
baseline.coords[f"basis_{exp}"] = "chan", X[:, exp]
baseline.coords[f"coeff_{exp}"] = "time", coeff[:, exp]

return baseline


def dtau_dpwv(freq: NDArray[np.float_]) -> xr.DataArray:
"""Calculate dtau/dpwv as a function of frequency.
Args:
freq: Frequency in units of Hz.
Returns:
DataArray that stores dtau/dpwv.
"""
tau = load.atm(type="tau").interp(freq=freq, method="linear")
fit = tau.curvefit("pwv", lambda x, a, b: a * x + b)
return fit["curvefit_coefficients"].sel(param="a", drop=True)
53 changes: 51 additions & 2 deletions decode/load.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,71 @@
__all__ = ["dems"]
__all__ = ["atm", "dems"]


# standard library
from pathlib import Path
from typing import Any, Union
from typing import Any, Literal, Union
from warnings import catch_warnings, simplefilter


# dependencies
import numpy as np
import pandas as pd
import xarray as xr


# constants
ALMA_ATM = "alma_atm.txt"
DATA_DIR = Path(__file__).parent / "data"
NETCDF_ENGINE = "scipy"
NETCDF_SUFFIX = ".nc"
ZARR_ENGINE = "zarr"
ZARR_SUFFIX = ".zarr"


def atm(*, type: Literal["eta", "tau"] = "tau") -> xr.DataArray:
"""Load an ALMA ATM model as a DataArray.
Args:
type: Type of model to be stored in the DataArray.
Either ``'eta'`` (transmission) or ``'tau'`` (opacity).
Returns:
DataArray that stores the ALMA ATM model.
"""
atm = pd.read_csv(
DATA_DIR / ALMA_ATM,
comment="#",
index_col=0,
sep=r"\s+",
)
freq = xr.DataArray(
atm.index * 1e9,
dims="freq",
attrs={
"long_name": "Frequency",
"units": "Hz",
},
)
pwv = xr.DataArray(
atm.columns.astype(float),
dims="pwv",
attrs={
"long_name": "Precipitable water vapor",
"units": "mm",
},
)

if type == "eta":
return xr.DataArray(atm, coords=(freq, pwv))
elif type == "tau":
with catch_warnings():
simplefilter("ignore")
return xr.DataArray(-np.log(atm), coords=(freq, pwv))
else:
raise ValueError("Type must be either eta or tau.")


def dems(dems: Union[Path, str], /, **options: Any) -> xr.DataArray:
"""Load a DEMS file as a DataArray.
Expand Down
108 changes: 90 additions & 18 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ documentation = "https://deshima-dev.github.io/decode/"
python = ">=3.9, <3.13"
matplotlib = "^3.7"
numpy = "^1.23"
pandas = ">=1.5, <3.0"
scikit-learn = "^1.2"
scipy = "^1.10"
xarray = "^2023.1"
zarr = "^2.14"
Expand Down
Empty file added tests/test_fit.py
Empty file.
Loading

0 comments on commit 12b4ccb

Please sign in to comment.