Commit
Automatically formatted with black
Jorgedavyd committed Nov 16, 2024
1 parent 6e078a9 commit 2295c58
Showing 13 changed files with 122 additions and 49 deletions.
2 changes: 1 addition & 1 deletion corkit/__init__.py
@@ -1,3 +1,3 @@
__version__ = '{{VERSION_PLACEHOLDER}}'
__version__ = "0.0.1731770678"
__author__ = "Jorgedavyd"
__email__ = "jorged.encyso@gmail.com"
7 changes: 3 additions & 4 deletions corkit/cli.py
@@ -2,15 +2,14 @@
import asyncio
from .dataset import update


def main():
parser = argparse.ArgumentParser(
description="Corkit CLI dataset update manager."
)
parser = argparse.ArgumentParser(description="Corkit CLI dataset update manager.")
parser.add_argument(
"--batch-size",
type=int,
default=100,
help="Batch size for dataset updates (default: 10)"
help="Batch size for dataset updates (default: 10)",
)
args = parser.parse_args()

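For reference, a minimal sketch of how this argparse entry point is presumably wired to the async dataset updater. The asyncio.run call is an assumption based on the `from .dataset import update` import and is not shown in this hunk; note also that the parser's default is 100 while the help string still says 10.

```python
# Hypothetical wiring of the CLI shown above; asyncio.run(update(...)) is an
# assumption, not part of this commit.
import argparse
import asyncio

from corkit.dataset import update


def main() -> None:
    parser = argparse.ArgumentParser(description="Corkit CLI dataset update manager.")
    parser.add_argument(
        "--batch-size",
        type=int,
        default=100,
        help="Batch size for dataset updates (default: 100)",
    )
    args = parser.parse_args()
    # Hand the parsed batch size to the async updater.
    asyncio.run(update(batch_size=args.batch_size))


if __name__ == "__main__":
    main()
```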
15 changes: 11 additions & 4 deletions corkit/dataset.py
@@ -27,6 +27,7 @@

__all__ = ["update", "CorKitDatasets"]


def timeout(retries=5, delay=10):
def decorator(func):
@wraps(func)
@@ -38,11 +39,14 @@ async def wrapper(*args, **kwargs):
except (asyncio.TimeoutError, ClientPayloadError) as e:
print(f"Error: {e}. Retrying {attempts + 1}/{retries}...")
attempts += 1
await asyncio.sleep(delay*attempts)
await asyncio.sleep(delay * attempts)
raise TimeoutError(f"Failed after {retries} retries.")

return wrapper

return decorator


def clean_old(file_list):
dat_files = filter(lambda filename: filename.lower().endswith(".dat"), file_list)
fts_files = filter(
@@ -62,6 +66,7 @@ async def get_names(url: str, href):
names = [name["href"] for name in names]
return names


@timeout()
async def download_single(url: str, filepath: str):
if not os.path.exists(filepath):
@@ -71,6 +76,7 @@ async def download_single(url: str, filepath: str):
) as f:
await f.write(await response.read())


async def update(batch_size: int = 500) -> None:
"""
# update
@@ -223,11 +229,12 @@ async def update(batch_size: int = 500) -> None:
print("Downloading reconstructors and their utilities...")
download_recons()


def download_recons():
root = os.path.join(DEFAULT_SAVE_DIR, 'models')
root = os.path.join(DEFAULT_SAVE_DIR, "models")
os.makedirs(root, exist_ok=True)
url = 'https://drive.google.com/uc?id=102orHwKGr9BL6s-M4bc1a3HVgHO_JmDd'
output = os.path.join(root, 'partial_conv.pt')
url = "https://drive.google.com/uc?id=102orHwKGr9BL6s-M4bc1a3HVgHO_JmDd"
output = os.path.join(root, "partial_conv.pt")
gdown.download(url, output, quiet=False)

# await download_single("", os.path.join(DEFAULT_SAVE_DIR, "models/fourier.pt"))
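The retry decorator in this file backs off linearly (delay * attempts) between attempts. Below is a self-contained sketch of the same pattern; the surrounding while-loop and the toy flaky_download coroutine are assumptions for illustration, and the aiohttp ClientPayloadError branch is omitted to keep the sketch dependency-free.

```python
# Sketch of the timeout/retry decorator above; the loop structure and the
# flaky_download() coroutine are illustrative assumptions.
import asyncio
from functools import wraps


def timeout(retries: int = 5, delay: int = 10):
    def decorator(func):
        @wraps(func)
        async def wrapper(*args, **kwargs):
            attempts = 0
            while attempts < retries:
                try:
                    return await func(*args, **kwargs)
                except asyncio.TimeoutError as e:
                    print(f"Error: {e}. Retrying {attempts + 1}/{retries}...")
                    attempts += 1
                    await asyncio.sleep(delay * attempts)  # linear backoff
            raise TimeoutError(f"Failed after {retries} retries.")

        return wrapper

    return decorator


@timeout(retries=3, delay=1)
async def flaky_download() -> None:
    # Stand-in for download_single(); always times out to exercise the retries.
    raise asyncio.TimeoutError("simulated stall")


# asyncio.run(flaky_download())  # retries 3 times, then raises TimeoutError
```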
43 changes: 32 additions & 11 deletions corkit/lasco.py
@@ -28,7 +28,13 @@
check_05,
)

from .reconstruction import dl_image, normal_model_reconstruction, transforms, cross_model_reconstruction, fourier_model_reconstruction
from .reconstruction import (
dl_image,
normal_model_reconstruction,
transforms,
cross_model_reconstruction,
fourier_model_reconstruction,
)
from astropy.visualization import HistEqStretch, ImageNormalize
from typing import Union, Dict, Tuple, List, Optional
from datetime import datetime, timedelta
@@ -48,6 +54,7 @@

__all__ = ["level_1", "CME", "LASCOplot", "downloader", "c3_calibrate", "c2_calibrate"]


##############
# LASCO #
##############
@@ -120,7 +127,7 @@ def level_1(
- 1023
!= 0
):
img0: NDArray= reduce_std_size(img0, header, FULL=dofull)
img0: NDArray = reduce_std_size(img0, header, FULL=dofull)

print(
f'LASCO-{header["detector"]}:{header["filename"]}:{header["date-obs"]}T{header["time-obs"]}...'
@@ -177,15 +184,15 @@ def level_1(
ramp = _read_ramp_full()
bkg = _read_bkg_full()
forward, inverse = transforms()
if 'model' not in kwargs:
if "model" not in kwargs:
model = normal_model_reconstruction()
else:
match kwargs['model']:
case 'normal':
match kwargs["model"]:
case "normal":
model = normal_model_reconstruction()
case 'fourier':
case "fourier":
model = fourier_model_reconstruction()
case 'cross':
case "cross":
model = cross_model_reconstruction()
args = (vig_pre, vig_post, mask, ramp, bkg, model, forward, inverse)
for filepath in fits_files:
@@ -202,6 +209,7 @@

return out


def final_step(
target_path: Optional[str],
filetype: str,
@@ -259,7 +267,9 @@ def final_step(
f"CorKit Level 1 calibration with python modules: level_1.py, open source level 1 implementation."
)
header["date"] = datetime.now().strftime("%Y/%m/%d %H:%M:%S.%f")
header["filename"] = os.path.basename(target_path) if target_path is not None else 'null'
header["filename"] = (
os.path.basename(target_path) if target_path is not None else "null"
)
header["CRPIX1"] = crpix_x
header["CRPIX2"] = crpix_y
header["CROTA"] = r_hdr
@@ -310,6 +320,7 @@ def final_step(

return bout, header


class downloader:
tools = ["c2", "c3"]
batch_size = 2
@@ -379,30 +390,35 @@ async def __call__(self, scrap_date_list):
for scrap_date in scrap_date_list:
await self.downloader_pipeline(scrap_date)


def _read_bkg_full():
bkg_path = os.path.join(DEFAULT_SAVE_DIR, "3m_clcl_all.fts")
with fits.open(bkg_path) as hdul:
bkg = hdul[0].data.astype(float)
bkg *= 0.8 / hdul[0].header["exptime"]
return bkg


def _read_ramp_full() -> NDArray:
ramp_path = os.path.join(DEFAULT_SAVE_DIR, "C3ramp.fts")
ramp = fits.getdata(ramp_path)
return ramp


def _read_mask_full() -> NDArray:
msk_fn = os.path.join(DEFAULT_SAVE_DIR, "c3_cl_mask_lvl1.fts")
mask = fits.getdata(msk_fn)
return mask


def _read_vig_full() -> Tuple[NDArray, NDArray]:
vig_pre = os.path.join(DEFAULT_SAVE_DIR, "c3vig_preint_final.fts")
vig_post = os.path.join(DEFAULT_SAVE_DIR, "c3vig_postint_final.fts")
vig_pre = fits.getdata(vig_pre)
vig_post = fits.getdata(vig_post)
return vig_pre, vig_post


def c3_calibrate(img0: NDArray, header: fits.Header, *args, **kwargs):
assert header["detector"] == "C3", "Not valid C3 fits file"
if check_05(header):
@@ -437,7 +453,7 @@ def c3_calibrate(img0: NDArray, header: fits.Header, *args, **kwargs):
ramp = _read_ramp_full()
bkg = _read_bkg_full()
forward, inverse = transforms()
if 'model' not in kwargs:
if "model" not in kwargs:
model = normal_model_reconstruction()

header.add_history("C3ramp.fts 1999/03/18")
@@ -457,7 +473,9 @@ def c3_calibrate(img0: NDArray, header: fits.Header, *args, **kwargs):

vig, ramp, bkg, mask = correct_var(header, vig, ramp, bkg, mask)

img = c3_calibration_forward(img0, header, calfac, vig, mask, bkg, ramp, model, forward, inverse, **kwargs)
img = c3_calibration_forward(
img0, header, calfac, vig, mask, bkg, ramp, model, forward, inverse, **kwargs
)

header.add_history(
f"corkit/lasco.py c3_calibrate: (function) {__version__}, 12/04/24"
@@ -522,6 +540,7 @@ def c3_calibration_forward(
img = dl_image(model, img.T, bkg, forward, inverse)
return img


def c3_calfactor(header: fits.Header, **kwargs) -> Tuple[fits.Header, float]:
# Set calibration factor for the various filters
filter_ = header["filter"].upper().strip()
@@ -680,6 +699,7 @@ def c2_calfactor(header: fits.Header, **kwargs) -> Tuple[fits.Header, float]:

return header, cal_factor


def c2_calibrate(
img0: NDArray, header: fits.Header, **kwargs
) -> Tuple[NDArray, fits.Header]:
@@ -691,7 +711,7 @@ def c2_calibrate(
print("This file is already a Level 1 product.")
return img0, header

vig_full = kwargs.get('vig_full', None)
vig_full = kwargs.get("vig_full", None)
header, expfac, bias = get_exp_factor(header)
header["exptime"] *= expfac
header["offset"] = bias
@@ -723,6 +743,7 @@ def c2_calibration_forward(img, header, calfac, vig):

return img, header


def c2_calibration_forward(img, header, calfac, vig):
if header["polar"] in [
"PB",
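The new match on kwargs["model"] in level_1 (and the related check in c3_calibrate) chooses between three reconstruction backends. Restated as a standalone helper for clarity; the helper name select_reconstruction_model is hypothetical, while the loader functions are the ones imported from corkit.reconstruction in this file.

```python
# Standalone restatement of the backend dispatch added in level_1; the helper
# name is hypothetical, the loaders are those imported in lasco.py.
from corkit.reconstruction import (
    cross_model_reconstruction,
    fourier_model_reconstruction,
    normal_model_reconstruction,
)


def select_reconstruction_model(**kwargs):
    """Pick the inpainting model the same way level_1 does."""
    if "model" not in kwargs:
        return normal_model_reconstruction()
    match kwargs["model"]:
        case "normal":
            return normal_model_reconstruction()
        case "fourier":
            return fourier_model_reconstruction()
        case "cross":
            return cross_model_reconstruction()
        case _:
            # Not handled in the diff; added here so the sketch fails loudly.
            raise ValueError(f"Unknown reconstruction model: {kwargs['model']!r}")
```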
33 changes: 22 additions & 11 deletions corkit/reconstruction.py
@@ -18,10 +18,14 @@ def forward(x, bkg):
forward_eq = forward_eq(x)
x = forward_eq(x)
bkg = ImageNormalize(stretch=HistEqStretch(bkg[np.isfinite(bkg)]))(bkg) < 0.2
msk = torch.from_numpy((((x <0.168) + bkg) < 0.99).astype(np.float32)).unsqueeze(0).unsqueeze(0)
msk = (
torch.from_numpy((((x < 0.168) + bkg) < 0.99).astype(np.float32))
.unsqueeze(0)
.unsqueeze(0)
)
x = torch.from_numpy(x).unsqueeze(0).unsqueeze(0)
x = interpolate(x, (1024,1024), mode = 'bilinear', align_corners=False)
msk = interpolate(msk, size = (1024,1024), mode = 'nearest')
x = interpolate(x, (1024, 1024), mode="bilinear", align_corners=False)
msk = interpolate(msk, size=(1024, 1024), mode="nearest")
return x, forward_eq, msk

def inverse(x, forward_eq, init_shape):
@@ -32,33 +36,40 @@ def inverse(x, forward_eq, init_shape):


def load_model(path: str):
device = 'cuda' if torch.cuda.is_available() else 'cpu'
device = "cuda" if torch.cuda.is_available() else "cpu"
match device:
case 'cuda':
case "cuda":
return torch.jit.load(path)
case 'cpu':
return torch.jit.load(path, map_location = torch.device('cpu'))
case "cpu":
return torch.jit.load(path, map_location=torch.device("cpu"))


def cross_model_reconstruction():
path: str = os.path.join(DEFAULT_SAVE_DIR, "models/cross.pt")
return load_model(path)


def fourier_model_reconstruction():
path: str = os.path.join(DEFAULT_SAVE_DIR, "models/fourier.pt")
return load_model(path)


def normal_model_reconstruction():
path: str = os.path.join(DEFAULT_SAVE_DIR, "models/partial_conv.pt")
return load_model(path)


def dl_image(model, img, bkg, forward_transform, inverse_transform):
init_shape = img.shape
x, forward_eq, mask = forward_transform(img.astype(np.float32), bkg)

if len(np.where(mask == 0.)[0]) > 32*32:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
x, _ = model(x.view(1, 1, 1024,1024).to(device).float(), mask.view(1, 1, 1024,1024).to(device).float())
x = interpolate(x, size = init_shape)
if len(np.where(mask == 0.0)[0]) > 32 * 32:
device = "cuda" if torch.cuda.is_available() else "cpu"
x, _ = model(
x.view(1, 1, 1024, 1024).to(device).float(),
mask.view(1, 1, 1024, 1024).to(device).float(),
)
x = interpolate(x, size=init_shape)
x = inverse_transform(x, forward_eq, init_shape)
return x
else:
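load_model above dispatches torch.jit.load on CUDA availability, and dl_image only runs the network when more than 32*32 pixels are masked out. A condensed sketch of the loading half, assuming only what the diff shows; the checkpoint path in the comment is a placeholder.

```python
# Condensed version of load_model above; the checkpoint path is a placeholder.
import torch


def load_scripted_model(path: str) -> torch.jit.ScriptModule:
    """Load a TorchScript checkpoint onto GPU when available, else CPU."""
    if torch.cuda.is_available():
        return torch.jit.load(path)
    return torch.jit.load(path, map_location=torch.device("cpu"))


# model = load_scripted_model("models/partial_conv.pt")
```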
5 changes: 3 additions & 2 deletions corkit/secchi.py
@@ -46,7 +46,9 @@ def __init__(self) -> None:
self.url = (
lambda date, name: f"https://secchi.nrl.navy.mil/postflight/cor2/L0/a/img/{date}/{name}"
)
self.png_path = lambda date, hour: f"./data/SECCHI/COR2/{date}_{hour}.png"
self.png_path = (
lambda date, hour: f"./data/SECCHI/COR2/{date}_{hour}.png"
)
self.fits_path = (
lambda date, hour: f"./data/SECCHI/COR2/{date}_{hour}.fits"
)
@@ -103,4 +105,3 @@ def data_prep(self, scrap_date):
scrap_date[0], scrap_date[-1], timedelta(days=1)
)
return self.get_days(scrap_date)

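The reformatted lambdas above build SECCHI COR2 remote and local paths from a date and an hour or filename token. A quick illustration with hypothetical values:

```python
# Illustration of the path builders reformatted above; the date, hour, and
# filename values are hypothetical.
url = lambda date, name: f"https://secchi.nrl.navy.mil/postflight/cor2/L0/a/img/{date}/{name}"
png_path = lambda date, hour: f"./data/SECCHI/COR2/{date}_{hour}.png"
fits_path = lambda date, hour: f"./data/SECCHI/COR2/{date}_{hour}.fits"

print(url("20240101", "20240101_000815_d4c2A.fts"))  # hypothetical filename
print(png_path("20240101", "000815"))
print(fits_path("20240101", "000815"))
```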