Skip to content

Commit

Permalink
default
Browse files Browse the repository at this point in the history
  • Loading branch information
Jorgedavyd committed Nov 15, 2024
1 parent c2d7c2e commit 8cbb1b2
Show file tree
Hide file tree
Showing 7 changed files with 45 additions and 205 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -126,3 +126,4 @@ debug*
# Corkit
data
corkit/data
debug.ipynb
4 changes: 2 additions & 2 deletions corkit/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ async def wrapper(*args, **kwargs):
except (asyncio.TimeoutError, ClientPayloadError) as e:
print(f"Error: {e}. Retrying {attempts + 1}/{retries}...")
attempts += 1
await asyncio.sleep(delay)
await asyncio.sleep(delay*attempts)
raise TimeoutError(f"Failed after {retries} retries.")
return wrapper
return decorator
Expand Down Expand Up @@ -161,7 +161,7 @@ async def update() -> None:

download_filenames = [*chain.from_iterable(download_filenames)]

batch_size = 500
batch_size = 50
for i in range(0, len(download_filenames), batch_size):
await asyncio.gather(
*[
Expand Down
14 changes: 4 additions & 10 deletions corkit/lasco.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
"""Utils dependencies"""
import matplotlib.pyplot as plt
from numpy.typing import NDArray
from .utils import (
Expand All @@ -9,7 +8,6 @@
reduce_std_size,
save,
adjust_hdr,
get_roll_or_xy,
get_sun_center,
rotate,
rot,
Expand All @@ -31,8 +29,6 @@
)

from .reconstruction import dl_image, normal_model_reconstruction, transforms, cross_model_reconstruction, fourier_model_reconstruction

"""Pypi dependencies"""
from astropy.visualization import HistEqStretch, ImageNormalize
from typing import Union, Dict, Tuple, List, Optional
from datetime import datetime, timedelta
Expand All @@ -52,7 +48,6 @@

__all__ = ["level_1", "CME", "LASCOplot", "downloader", "c3_calibrate", "c2_calibrate"]


##############
# LASCO #
##############
Expand Down Expand Up @@ -134,16 +129,15 @@ def level_1(
case "C2":
b, header = c2_calibrate(img0, header, **kwargs)
b, header = c2_warp(b, header)
zz: NDArray = np.where(img0 <= 0)
maskall: NDArray = np.ones((header["naxis1"], header["naxis2"]))
maskall[zz] = 0
maskall: NDArray = np.ones_like(img0)
maskall[img0 <= 0] = 0
maskall, _ = c2_warp(maskall, header)
b *= maskall
case "C3":
b, header = c3_calibrate(img0, header, *args, **kwargs)
bn, header = c3_warp(b, header)
zz: NDArray = np.where(img0 <= 0)
maskall: NDArray = np.ones((header["naxis1"], header["naxis2"]))
maskall: NDArray = np.ones((header["naxis2"], header["naxis1"]))
maskall[zz] = 0
maskallw, _ = c3_warp(maskall, header)
b = bn * maskallw * correct_var(header, args[1])[0]
Expand Down Expand Up @@ -389,7 +383,7 @@ def _read_bkg_full():
with fits.open(bkg_path) as hdul:
bkg = hdul[0].data.astype(float)
bkg *= 0.8 / hdul[0].header["exptime"]
return bkg
return bkg

def _read_ramp_full() -> NDArray:
ramp_path = os.path.join(DEFAULT_SAVE_DIR, "C3ramp.fts")
Expand Down
31 changes: 3 additions & 28 deletions corkit/reconstruction.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from skimage.restoration import inpaint
from scipy.fft import fft, ifft
import numpy as np
from numpy.typing import NDArray
from .utils import deprecation, DEFAULT_SAVE_DIR
import torch
import torchvision.transforms as tt
Expand All @@ -13,13 +14,10 @@

def transforms():
def forward(x, bkg):
#Equalizing features
forward_eq = lambda x: ImageNormalize(stretch=HistEqStretch(x[np.isfinite(x)]))
forward_eq = forward_eq(x)
x = forward_eq(x)

bkg = ImageNormalize(stretch=HistEqStretch(bkg[np.isfinite(bkg)]))(bkg) < 0.2
# Defining the mask
msk = torch.from_numpy((((x <0.168) + bkg) < 0.99).astype(np.float32)).unsqueeze(0).unsqueeze(0)
x = torch.from_numpy(x).unsqueeze(0).unsqueeze(0)
x = interpolate(x, (1024,1024), mode = 'bilinear', align_corners=False)
Expand Down Expand Up @@ -47,29 +45,16 @@ def normal_model_reconstruction():

def dl_image(model, img, bkg, forward_transform, inverse_transform):
init_shape = img.shape

x, forward_eq, mask = forward_transform(img.astype(np.float32), bkg)

plt.imshow(mask.detach().cpu().view(1024,1024).numpy())
plt.show()

if len(np.where(mask == 0.)[0]) > 32*32:
print('Reconstructing...')

device = 'cuda' if torch.cuda.is_available() else 'cpu'

model = model.to(device)

x, _ = model(x.view(1, 1, 1024,1024).to(device).float(), mask.view(1, 1, 1024,1024).to(device).float())

x = interpolate(x, size = init_shape)

x = inverse_transform(x, forward_eq, init_shape)

return x

else:

return img


Expand All @@ -79,31 +64,21 @@ def dl_image_cross(
img_1 += bkg
mask = img_1 < 0
img_1 -= bkg

img_1 = forward_transform(img_1)
img_2 = forward_transform(img_2)

img_2, _ = model(
img_1.unsqueeze(0), img_2.unsqueeze(0), time.unsqueeze(0), mask.unsqueeze(0)
)

img_2 = inverse_transform(img_2)

return img_2


def image_reconstruction(img: np.array):
def image_reconstruction(img: NDArray):
deprecation("1.1.0")
map_miss_blocks = -np.fix(img > 0.1)

# Find locs
mask = np.zeros(img.shape, dtype=bool)

mask[map_miss_blocks == 0] = 1

# Perform restoration
img_restored = inpaint.inpaint_biharmonic(img, mask)

return img_restored


Expand Down Expand Up @@ -137,7 +112,7 @@ def read_zone(img, list_miss_blocks, rebindex):
return zone_width, zone_height, zone


def dct(array: np.array, inverse: bool = False):
def dct(array: NDArray, inverse: bool = False):
deprecation("1.1.0")
shape = array.shape
dim = len(shape)
Expand Down
4 changes: 2 additions & 2 deletions corkit/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -328,7 +328,7 @@ def reduce_statistics2(img, header, **kwargs):
mn = np.min(img[wmn])
mx = np.max(img[wmn])
medyen = np.median(img[wmn])
bscale: float = 1.0
bscale: float = 1.
if medyen > 0:
while medyen * bscale < 1000:
bscale *= 10
Expand Down Expand Up @@ -374,7 +374,7 @@ def reduce_statistics2(img, header, **kwargs):
header["DATASIG"] = sig

temparr = img[w] * bscale
h, _ = np.histogram(temparr, bins=np.arange(temparr.min(), temparr.max() + 1))
h, _ = np.histogram(temparr)
nh = len(h)
tot = 0
limits = np.array([0.01, 0.10, 0.25, 0.5, 0.75, 0.90, 0.95, 0.98, 0.99])
Expand Down
195 changes: 33 additions & 162 deletions sample.ipynb

Large diffs are not rendered by default.

1 change: 0 additions & 1 deletion tests/test_corkit.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
from datetime import datetime
import asyncio
from corkit.lasco import CME, level_1, downloader
from corkit.dataset import update
from typing import List

async def tool_downloader(tool: str, scrap_date_list: List[datetime]) -> None:
Expand Down

0 comments on commit 8cbb1b2

Please sign in to comment.