Commit: reconstruction

Jorgedavyd committed Nov 14, 2024
1 parent ad0b7e1 commit c2d7c2e
Showing 6 changed files with 392 additions and 116 deletions.
.gitignore: 6 changes (5 additions & 1 deletion)
@@ -120,5 +120,9 @@ dmypy.json
# Cython debug symbols
cython_debug/

# Debugging scripts and files
# Debugging scripts and files
debug*

# Corkit
data
corkit/data
corkit/dataset.py: 57 changes (33 additions & 24 deletions)
@@ -13,13 +13,36 @@
import os
from dataclasses import dataclass
import gdown

from torch.utils.data import Dataset
from typing import List, Optional
from torchvision.transforms import Compose, Resize, ToTensor, Lambda
from astropy.visualization import ImageNormalize, HistEqStretch
from torch import Tensor
from .lasco import CME
from astropy.io import fits
import asyncio
from aiohttp import ClientPayloadError
from functools import wraps
from . import __version__

__all__ = ["update", "CorKitDatasets"]

def timeout(retries=5, delay=10):
    def decorator(func):
        @wraps(func)
        async def wrapper(*args, **kwargs):
            attempts = 0
            while attempts < retries:
                try:
                    return await func(*args, **kwargs)
                except (asyncio.TimeoutError, ClientPayloadError) as e:
                    print(f"Error: {e}. Retrying {attempts + 1}/{retries}...")
                    attempts += 1
                    await asyncio.sleep(delay)
            raise TimeoutError(f"Failed after {retries} retries.")
        return wrapper
    return decorator
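
For reference, a minimal usage sketch of this retry decorator (not part of the commit), assuming the `timeout` decorator above is in scope; the coroutine name `flaky_download` and its simulated failure are hypothetical:

import asyncio
from itertools import count

_attempt = count(1)

@timeout(retries=3, delay=1)
async def flaky_download() -> str:
    # Hypothetical example: fail once with a transient error, then succeed.
    if next(_attempt) == 1:
        raise asyncio.TimeoutError("simulated transient failure")
    return "ok"

print(asyncio.run(flaky_download()))  # one retry message from the decorator, then "ok"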

# done
def clean_old(file_list):
    dat_files = filter(lambda filename: filename.lower().endswith(".dat"), file_list)
    fts_files = filter(
@@ -30,7 +53,6 @@ def clean_old(file_list):
    return [dat_file, fts_file]


# done
async def get_names(url: str, href):
    async with aiohttp.ClientSession() as client:
        async with client.get(url, ssl=False) as response:
@@ -40,8 +62,7 @@ async def get_names(url: str, href):
            names = [name["href"] for name in names]
            return names


# done
@timeout()
async def download_single(url: str, filepath: str):
    if not os.path.exists(filepath):
        async with aiohttp.ClientSession() as client:
@@ -50,8 +71,6 @@ async def download_single(url: str, filepath: str):
            ) as f:
                await f.write(await response.read())


# done
async def update() -> None:
"""
# update
@@ -80,9 +99,9 @@ async def update() -> None:
    links = list(map(base_link, names))
    roots = list(map(base_root, names))

    await asyncio.gather(
        *[download_single(url, filepath) for url, filepath in zip(links, roots)]
    )
    for url, filepath in zip(links, roots):
        await download_single(url, filepath)

    print("Checking data_anal directory...")

    os.makedirs(os.path.join(DEFAULT_SAVE_DIR, "data/data_anal/"), exist_ok=True)
@@ -207,9 +226,7 @@ async def update() -> None:

def download_recons():
    root = os.path.join(DEFAULT_SAVE_DIR, 'models')

    os.makedirs(root, exist_ok=True)

    url = 'https://drive.google.com/uc?id=102orHwKGr9BL6s-M4bc1a3HVgHO_JmDd'
    output = os.path.join(root, 'partial_conv.pt')
    gdown.download(url, output, quiet=False)
@@ -219,8 +236,7 @@ def download_recons():
# await download_single("", os.path.join(DEFAULT_SAVE_DIR, "models/cross.pt"))



# done
@timeout()
async def update_single(url: str, filepath: str):
    if os.path.exists(filepath):
        os.remove(filepath)
@@ -245,7 +261,9 @@ async def medv_update():
"c3_post_recovery_adj_xyr_medv2.sav",
]

await asyncio.gather(*[update_single(link(file), pathfile(file)) for file in files])
for file in files:
await update_single(link(file), pathfile(file))

filepaths = [pathfile(file) for file in files]
dfs = []

@@ -264,15 +282,6 @@ async def medv_update():
df.to_csv(filepath)


from torch.utils.data import Dataset
from typing import List, Optional
from torchvision.transforms import Compose, Resize, ToTensor, Lambda
from astropy.visualization import ImageNormalize, HistEqStretch
from torch import Tensor
from .lasco import CME
from astropy.io import fits


@dataclass
class CorKitDatasets(Dataset):
"""
(Diff truncated: the remaining changes in this file and in the other changed files are not shown.)
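
For context, a hypothetical driver script (not shown in this commit) illustrating how the updated entry points might be invoked from synchronous code, assuming `update`, `medv_update`, and `download_recons` keep the signatures visible in this diff:

import asyncio

from corkit.dataset import download_recons, medv_update, update

async def main() -> None:
    await update()        # download calibration files sequentially, with retries
    await medv_update()   # refresh the medv2 .sav-derived CSV tables
    download_recons()     # fetch the partial_conv.pt reconstruction model via gdown

if __name__ == "__main__":
    asyncio.run(main())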
