"""
Transient Detection Tool
Version: 1.0.0
Author: YoungPyo Hong
Date: 2024-11-06
A GUI application for viewing and classifying astronomical transient candidates.
This tool provides:
- FITS and PNG image support
- Configurable display settings
- Classification categories
- Image caching and preloading
- Progress tracking
- Keyboard shortcuts
"""
# Standard library imports
import glob
import logging
import os
import re
import threading
import configparser
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Tuple, Union
# Third-party imports
import numpy as np
import pandas as pd
from astropy.io import fits
from astropy.visualization import ZScaleInterval
from matplotlib import colors
from matplotlib.backends.backend_tkagg import FigureCanvasTkAgg
from matplotlib.figure import Figure
from matplotlib.widgets import Slider
import matplotlib.pyplot as plt
from tkinter import (Button, Checkbutton, Entry, Frame, IntVar, Label,
Tk, ttk, Text, messagebox)
from tkinter.ttk import Progressbar, Style
import concurrent.futures
from concurrent.futures import ThreadPoolExecutor
def handle_exceptions(func):
"""
Decorator to handle exceptions in user-facing methods.
Catches any exceptions, logs them, and displays error messages to the user.
Args:
func: The function to wrap
Returns:
Wrapped function that handles exceptions gracefully
"""
def wrapper(self, *args, **kwargs):
try:
return func(self, *args, **kwargs)
except Exception as e:
logging.exception(f"Error in {func.__name__}: {e}")
messagebox.showerror("Error", f"An error occurred: {e}")
return wrapper
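# Illustrative sketch of applying the decorator to a user-facing method; the
# class and method names below are hypothetical, not part of this module:
#
#   class SomeTransientGui:
#       @handle_exceptions
#       def on_next_image(self):
#           ...  # any exception raised here is logged and shown via messagebox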
@dataclass
class Config:
"""
Configuration settings for the TransientTool application.
This class manages all configuration parameters loaded from config.ini,
providing type hints and validation for each setting.
Key Features:
- Type validation for all settings
- Default values for optional parameters
- Configuration file loading and parsing
- Logging setup
"""
data_directory: str
file_pattern: str
output_csv_file: str
zoom_min: float
zoom_max: float
zoom_step: float
initial_zoom: float
default_sci_ref_visible: bool
scale: str
vmin_subtracted: str
vmax_subtracted: str
vmin_science: str
vmax_science: str
vmin_reference: str
vmax_reference: str
log_file: str
log_level: str
shortcuts: Dict[str, str]
file_type: str
tile_ids: List[str]
cache_size: int
classification_labels: List[str]
cache_window: int
preload_batch_size: int
columns_order: List[str] = field(default_factory=lambda: [
'file_index', 'tile_id', 'unique_number', 'Memo', 'Scale'
])
# Optional parameters (with defaults)
view_mode: bool = False
specific_view_mode: Optional[str] = None
quick_start: bool = False
@staticmethod
def load_config(config_path: str = 'config.ini') -> 'Config':
"""
Load and validate configuration settings from INI file.
Performs the following:
1. Reads the INI file
2. Validates required settings
3. Sets default values for optional settings
4. Configures logging
5. Validates data types and value ranges
Args:
config_path: Path to configuration file
Returns:
Config object with validated settings
Raises:
ValueError: If required settings are missing or invalid
"""
try:
config = configparser.ConfigParser()
config.read(config_path)
def get_config_option(section: str, option: str, type_func: Any, default: Any) -> Any:
"""
Get config option value, ignoring comments after #.
Args:
section: Config section name
option: Option name within section
type_func: Type conversion function
default: Default value if option not found
Returns:
Converted option value or default
"""
try:
# Get raw value and strip comments after #
value = config.get(section, option)
if '#' in value:
value = value.split('#')[0].strip()
# Handle special case for None values
if value.lower() == 'none':
return None
# Handle boolean values specially
if type_func == bool:
return value.lower() in ['true', '1', 'yes']
# Convert value to specified type
return type_func(value)
except (configparser.NoSectionError, configparser.NoOptionError):
return default
except ValueError as e:
logging.warning(f"Error parsing {option} from config: {e}")
return default
# Load shortcuts
shortcuts = {}
if config.has_section('Shortcuts'):
for key in config.options('Shortcuts'):
shortcuts[key] = config.get('Shortcuts', key).strip()
# Load mode settings
view_mode = get_config_option('Mode', 'view_mode', bool, False)
specific_view_mode = get_config_option('Mode', 'specific_view_mode', str, None)
# Load tile IDs
raw_tile_ids = config.get('TileSettings', 'tile_ids', fallback='').split(',')
tile_ids = []
if any(tid.strip() for tid in raw_tile_ids): # If tile_ids is not empty
for tid in raw_tile_ids:
if tid.strip():
tile_id = DataManager.get_tile_id(tid.strip())
if tile_id:
tile_ids.append(tile_id)
if not tile_ids:
logging.warning("No valid tile IDs found in config, will auto-detect")
else:
logging.info("No tile IDs specified in config, will auto-detect")
# Load classification labels, skipping empty entries
classification_labels = [label.strip() for label in
config.get('Settings', 'classification_labels', fallback='').split(',')
if label.strip()]
# Set up logging configuration
log_file = get_config_option('Logging', 'log_file', str, 'transient_tool.log')
log_level = get_config_option('Logging', 'log_level', str, 'INFO').upper()
# Configure logging
logging.basicConfig(
filename=log_file,
level=getattr(logging, log_level),
format='%(asctime)s - %(levelname)s - %(message)s',
datefmt='%Y-%m-%d %H:%M:%S'
)
logging.info("Configuration loaded successfully")
# Create config instance with all parameters
config_obj = Config(
data_directory=get_config_option('Paths', 'data_directory', str, ''),
file_pattern=get_config_option('Paths', 'file_pattern', str, ''),
output_csv_file=get_config_option('Paths', 'output_csv_file', str, ''),
zoom_min=get_config_option('Settings', 'zoom_min', float, 1.0),
zoom_max=get_config_option('Settings', 'zoom_max', float, 10.0),
zoom_step=get_config_option('Settings', 'zoom_step', float, 0.1),
initial_zoom=get_config_option('Settings', 'initial_zoom', float, 1.0),
default_sci_ref_visible=get_config_option('Settings', 'default_sci_ref_visible', bool, True),
scale=get_config_option('Settings', 'scale', str, 'zscale').lower(),
vmin_subtracted=get_config_option('Settings', 'vmin_subtracted', str, 'median').lower(),
vmax_subtracted=get_config_option('Settings', 'vmax_subtracted', str, 'max').lower(),
vmin_science=get_config_option('Settings', 'vmin_science', str, 'median').lower(),
vmax_science=get_config_option('Settings', 'vmax_science', str, 'max').lower(),
vmin_reference=get_config_option('Settings', 'vmin_reference', str, 'median').lower(),
vmax_reference=get_config_option('Settings', 'vmax_reference', str, 'max').lower(),
log_file=log_file,
log_level=log_level,
shortcuts=shortcuts,
file_type=get_config_option('Settings', 'file_type', str, 'fits').lower(),
tile_ids=tile_ids,  # an empty list is fine; tile IDs are auto-detected later
cache_size=get_config_option('TileSettings', 'cache_size', int, 100),
classification_labels=classification_labels,
cache_window=get_config_option('TileSettings', 'cache_window', int, 10),
preload_batch_size=get_config_option('TileSettings', 'preload_batch_size', int, 5),
view_mode=view_mode,
specific_view_mode=specific_view_mode,
quick_start=get_config_option('Mode', 'quick_start', bool, False)
)
# Validation of required fields
required_options = ['data_directory', 'file_pattern', 'output_csv_file']
for option in required_options:
if not getattr(config_obj, option):
raise ValueError(f"Missing required configuration option: {option} in section 'Paths'.")
# Validate scale option
if config_obj.scale not in ['zscale', 'linear', 'log']:
logging.warning(f"Invalid scale '{config_obj.scale}' in configuration. Using 'linear' as default.")
config_obj.scale = 'linear'
# Validate file_type option
if config_obj.file_type not in ['fits', 'png']:
raise ValueError("Invalid file_type option in configuration. Choose 'fits' or 'png'.")
return config_obj
except Exception as e:
raise ValueError(f"Failed to load configuration: {e}")
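# Sketch of a config.ini consistent with the options read above. Paths and
# values are illustrative placeholders, not shipped defaults:
#
#   [Paths]
#   data_directory = /data/transients
#   file_pattern = *.sub.fits
#   output_csv_file = classifications.csv
#
#   [Settings]
#   zoom_min = 1.0
#   zoom_max = 10.0
#   zoom_step = 0.1
#   initial_zoom = 1.0
#   default_sci_ref_visible = true
#   scale = zscale
#   file_type = fits
#   classification_labels = Real, Bogus
#   ; the vmin_*/vmax_* display-range options also belong in [Settings]
#
#   [TileSettings]
#   ; leave tile_ids empty to auto-detect tiles from the data directory
#   tile_ids =
#   cache_size = 100
#   cache_window = 10
#   preload_batch_size = 5
#
#   [Mode]
#   view_mode = false
#   specific_view_mode = none
#   quick_start = false
#
#   [Logging]
#   log_file = transient_tool.log
#   log_level = INFO
#
# An optional [Shortcuts] section is loaded verbatim into Config.shortcuts.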
class DataManager:
"""
Handles data loading, processing and persistence for astronomical images.
Key responsibilities:
1. Image data loading and caching
2. Classification data management
3. Progress tracking
4. File operations (CSV read/write)
5. Cache management
The DataManager maintains a thread-safe environment for concurrent operations
and implements efficient caching strategies for optimal performance.
"""
def __init__(self, config: Config):
"""
Initialize DataManager with configuration.
Sets up:
- Data structures for image and metadata management
- Thread locks for concurrent operations
- Cache initialization
- Helper components (ImageProcessor, DataValidator)
Args:
config: Configuration settings
"""
# Configuration and basic attributes
self.config = config
self.region_df = None
self.index = 0
self.total_images = 0
# Cache related attributes
self.image_cache = {}
self.cache_size = config.cache_size
self.cache_window = config.cache_window
self.preload_batch_size = config.preload_batch_size
# Thread locks
self.cache_lock = threading.Lock()
self.preload_lock = threading.Lock()
self.preload_thread = threading.Thread()  # placeholder so is_alive() is safe before the first preload starts
self.file_lock = threading.Lock()
# Helper components
self.image_processor = ImageProcessor(config)
self.data_validator = DataValidator(config)
# Load data based on quick_start mode
if hasattr(self.config, 'quick_start') and self.config.quick_start:
self._quick_start_load()
else:
self._full_load()
if hasattr(self.config, 'specific_view_mode') and self.config.specific_view_mode:
self.valid_indices = self.region_df[
self.region_df[self.config.specific_view_mode] == 1
].index.tolist()
logging.info("DataManager initialized.")
def _quick_start_load(self):
"""Load data directly from CSV without scanning directories."""
try:
if os.path.exists(self.config.output_csv_file):
self.region_df = pd.read_csv(self.config.output_csv_file)
logging.info(f"Found existing CSV with {len(self.region_df)} entries")
# Initialize DataFrame structure
self.init_dataframe()
else:
raise FileNotFoundError("Required CSV file not found for quick start mode")
except Exception as e:
logging.error(f"Error in quick start load: {e}")
raise
def _full_load(self):
"""Perform full load with directory scanning."""
try:
self.load_files() # This includes directory scanning
except Exception as e:
logging.error(f"Error in full load: {e}")
raise
def load_files(self):
"""Load files and initialize DataFrame."""
try:
# Load existing data if available
existing_data = None
if os.path.exists(self.config.output_csv_file):
existing_data = pd.read_csv(self.config.output_csv_file)
logging.info(f"Found existing CSV with {len(existing_data)} entries")
# Get current files, using existing data to avoid reprocessing
if existing_data is not None:
self.region_df = existing_data.copy()
logging.info("Using existing data without rescanning")
else:
# No existing data, scan all files
self.region_df = self.scan_directory_for_files().copy()
logging.info(f"Created new index with {len(self.region_df)} entries")
# Initialize DataFrame with proper structure
self.init_dataframe()
except Exception as e:
logging.exception(f"Error loading files: {e}")
raise
def scan_directory_for_files(self, existing_keys=None):
"""Scan directory and create new index for files."""
try:
file_data = []
base_dir = self.config.data_directory
tile_ids = self.config.tile_ids if self.config.tile_ids else self.get_all_tile_ids()
for tile_id in tile_ids:
pattern = f"**/*{tile_id}*.com.*.sub.{self.config.file_type}"
full_pattern = os.path.join(base_dir, pattern)
files = glob.glob(full_pattern, recursive=True)
if files:
logging.info(f"Found {len(files)} files for tile {tile_id}")
# Create temporary list for this tile's files
tile_data = []
for filename in files:
unique_number = self.get_unique_number(filename)
if unique_number is not None:
# Skip if file already exists in CSV
if existing_keys and (tile_id, unique_number) in existing_keys:
continue
file_data_dict = {
'tile_id': tile_id,
'unique_number': unique_number,
'Memo': '',
'Scale': ''
}
for label in self.config.classification_labels:
file_data_dict[label] = 0
tile_data.append(file_data_dict)
# Sort tile_data by unique_number before adding to main list
tile_data.sort(key=lambda x: x['unique_number'])
file_data.extend(tile_data)
else:
logging.info(f"No files found for tile {tile_id}")
# Add file_index after sorting
for i, data in enumerate(file_data):
data['file_index'] = i
df = pd.DataFrame(file_data)
if len(file_data) > 0:
logging.info(f"Found {len(df)} new files across all tiles")
return df
except Exception as e:
logging.exception(f"Error scanning directory: {e}")
raise
@staticmethod
def get_unique_number(filename: str) -> Optional[int]:
"""
Extract unique identifier from filename.
Args:
filename: Full path to image file
Returns:
Unique number from filename or None if not found
Example:
>>> get_unique_number("path/to/com.123.sub.fits")
123
"""
basename = os.path.basename(filename)
match = re.search(r'com\.(\d+)\.', basename)
if match:
return int(match.group(1))
return None
@staticmethod
def get_tile_id(filename: str) -> Optional[str]:
"""
Extract tile ID from filename.
Expected format: 'T<number>' in the filename
Returns the full ID including 'T' prefix
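Example (illustrative filename):
>>> DataManager.get_tile_id("RASA36-T12345-20211001.com.42.sub.fits")
'T12345'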
"""
match = re.search(r'(T\d+)', filename)
if match:
return match.group(1)
return None
def init_dataframe(self):
"""Initialize DataFrame with proper structure and defaults."""
try:
# Create a copy to avoid chained assignment
df = self.region_df.copy()
# Remove total row (file_index == -1) if exists
df = df[df['file_index'] != -1]
# Validate DataFrame structure
is_valid, errors = self.data_validator.validate_dataframe(df)
if not is_valid:
logging.warning(f"DataFrame validation failed: {errors}")
# Add missing columns if needed
for col in self.config.classification_labels:
if col not in df.columns:
df.loc[:, col] = 0
df.loc[:, col] = df[col].fillna(0).astype(int)
if 'Memo' not in df.columns:
df.loc[:, 'Memo'] = ''
df['Memo'] = df['Memo'].fillna('').astype(str)
if 'Scale' not in df.columns:
df.loc[:, 'Scale'] = self.config.scale
df['Scale'] = df['Scale'].fillna(self.config.scale).astype(str)
# Ensure proper column order including classification labels
all_columns = [*self.config.columns_order, *self.config.classification_labels]
existing_columns = [col for col in all_columns if col in df.columns]
# Verify all required columns exist
missing_columns = set(self.config.columns_order) - set(df.columns)
if missing_columns:
raise ValueError(f"Missing required columns: {missing_columns}")
df = df[existing_columns]
# Sort by file_index and reset index to ensure alignment
df = df.sort_values('file_index').reset_index(drop=True)
# Fix file_index alignment
if not df['file_index'].equals(pd.Series(range(len(df)))):
logging.warning("Fixing non-sequential file_index values")
df['file_index'] = range(len(df))
# Assign back to self.region_df
self.region_df = df
# Save DataFrame to CSV
self.save_dataframe()
logging.info(f"DataFrame initialized with {len(self.region_df)} rows")
except Exception as e:
logging.error(f"Error initializing DataFrame: {e}")
raise
def get_starting_index(self) -> int:
"""
Determine starting index based on first unprocessed image.
Returns index of first image where all classification values are 0.
"""
try:
if self.region_df.empty:
return 0
# Check if any classification label is 1 for each row
classified = self.region_df[self.config.classification_labels].any(axis=1)
# Find first unclassified image (where all labels are 0)
unclassified = self.region_df[~classified]
if not unclassified.empty:
# Get the file_index of first unclassified image
first_unclassified_index = unclassified['file_index'].iloc[0]
logging.info(f"Starting from first unclassified image at index {first_unclassified_index}")
return first_unclassified_index
# If all images are classified, start from beginning
logging.info("All images are classified, starting from index 0")
return 0
except Exception as e:
logging.error(f"Critical error in get_starting_index: {e}")
raise
def save_dataframe(self, mode='w', callback=None):
"""Save the DataFrame to CSV file."""
try:
# Validate DataFrame before saving
is_valid, errors = self.data_validator.validate_dataframe(self.region_df)
if not is_valid:
error_msg = "\n".join(errors)
logging.error(f"DataFrame validation failed:\n{error_msg}")
raise ValueError(f"DataFrame validation failed:\n{error_msg}")
# Create a copy to avoid modifying the original
df_to_save = self.region_df.copy()
# Ensure Memo column is preserved as string
df_to_save['Memo'] = df_to_save['Memo'].astype(str)
# Calculate totals for classifications
totals = {}
for col in self.config.classification_labels:
totals[col] = int(df_to_save[col].sum())
# Count total images and processed images
total_images = len(df_to_save)
total_processed = len(df_to_save[df_to_save[self.config.classification_labels].any(axis=1)])
percent_processed = (total_processed/total_images*100) if total_images > 0 else 0
# Create total row
total_dict = {
'file_index': -1,
'tile_id': 'Total',
'unique_number': len(df_to_save['tile_id'].unique()),
'Memo': f"{total_processed}/{total_images}",
'Scale': f"{percent_processed:.2f}%"
}
total_dict.update(totals)
# Remove any existing total rows and duplicates
df_to_save = df_to_save[df_to_save['file_index'] != -1].drop_duplicates(subset=['tile_id', 'unique_number'])
# Append total row
df_to_save = pd.concat([df_to_save, pd.DataFrame([total_dict])], ignore_index=True)
# Save with file lock
with self.file_lock:
df_to_save.to_csv(self.config.output_csv_file, index=False, mode='w', na_rep='')
logging.info(f"DataFrame saved successfully to {self.config.output_csv_file}")
if callback:
callback()
except Exception as e:
logging.error(f"Error saving DataFrame: {e}")
raise
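# For reference, the saved CSV ends with one summary row shaped as sketched
# below (numbers are illustrative; the classification columns follow the
# configured labels): unique_number holds the count of distinct tiles,
# Memo holds "processed/total", and Scale holds the percentage processed.
#
#   file_index,tile_id,unique_number,Memo,Scale,<label_1>,<label_2>,...
#   -1,Total,3,120/450,26.67%,95,25,...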
def load_image_data(self, index: int) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
"""
Load image data with improved caching.
"""
try:
index = int(index)
if index < 0 or index >= len(self.region_df):
raise ValueError(f"Index {index} out of bounds")
current_row = self.region_df.iloc[index]
tile_id = current_row['tile_id']
unique_number = current_row['unique_number']
cache_key = f"{tile_id}_{unique_number}"
# Check cache with minimal locking
with self.cache_lock:
if cache_key in self.image_cache:
return self.image_cache[cache_key]
# Load images using ImageProcessor
images = self.image_processor.load_and_process_images(tile_id, unique_number)
# Update cache
with self.cache_lock:
self.image_cache[cache_key] = images
# Cleanup cache if needed
if len(self.image_cache) > self.cache_size:
self.cleanup_cache(index)
return images
except Exception as e:
logging.error(f"Error in load_image_data: {e}")
raise
def start_preloading(self, current_index: int):
"""
Start preloading images in a background thread, prioritizing current tile ID.
"""
try:
if self.preload_thread.is_alive():
return # Skip if previous preload is still running
current_row = self.region_df.iloc[current_index]
current_tile = current_row['tile_id']
# Get indices to preload (prioritize same tile, then sequential)
next_indices = self._get_preload_indices(current_index, current_tile)
if next_indices:
self.preload_thread = threading.Thread(
target=self.preload_images,
args=(next_indices,),
daemon=True
)
self.preload_thread.start()
logging.debug(f"Started preloading {len(next_indices)} images")
except Exception as e:
logging.error(f"Error starting preload thread: {e}")
def _get_preload_indices(self, current_index: int, current_tile: str) -> List[int]:
"""
Get optimized list of indices to preload, prioritizing same tile ID.
"""
try:
indices_to_preload = []
# First, get next few images from same tile
same_tile_indices = self.region_df[
(self.region_df['tile_id'] == current_tile) &
(self.region_df.index > current_index)
].index[:self.preload_batch_size].tolist()
# Then, get next sequential images
sequential_indices = range(
current_index + 1,
min(current_index + self.preload_batch_size, len(self.region_df))
)
# Combine and remove duplicates while maintaining order
all_indices = []
for idx in same_tile_indices + list(sequential_indices):
if idx not in all_indices:
all_indices.append(idx)
# Filter out already cached images
for idx in all_indices:
row = self.region_df.iloc[idx]
cache_key = f"{row['tile_id']}_{row['unique_number']}"
if cache_key not in self.image_processor.image_cache:
indices_to_preload.append(idx)
if len(indices_to_preload) >= self.preload_batch_size:
break
return indices_to_preload
except Exception as e:
logging.error(f"Error getting preload indices: {e}")
return []
def preload_images(self, indices: Union[int, List[int]]):
"""
Preload images with improved performance.
"""
try:
if isinstance(indices, int):
indices = [indices]
# Group indices by tile_id for efficient loading
tile_groups = {}
for idx in indices:
row = self.region_df.iloc[idx]
tile_id = row['tile_id']
if tile_id not in tile_groups:
tile_groups[tile_id] = []
tile_groups[tile_id].append((idx, row['unique_number']))
# Process each tile group
for tile_id, index_pairs in tile_groups.items():
try:
# Sort by unique_number for sequential file access
index_pairs.sort(key=lambda x: x[1])
for idx, unique_number in index_pairs:
cache_key = f"{tile_id}_{unique_number}"
if cache_key not in self.image_processor.image_cache:
self.load_image_data(idx)
except Exception as e:
logging.error(f"Error preloading tile {tile_id}: {e}")
continue
except Exception as e:
logging.error(f"Error in preload_images: {e}")
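# Sketch of the intended call pattern; the caller is assumed to be the GUI
# layer, which lies outside this excerpt:
#
#   sub, sci, ref = data_manager.load_image_data(index)  # blocking load of the current image
#   data_manager.start_preloading(index)                 # background preload of likely-next images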
def calculate_progress(self) -> dict:
"""Calculate progress statistics for total and per-tile."""
try:
progress_stats = {}
# Calculate total progress
total_images = len(self.region_df)
total_classified = len(self.region_df[
self.region_df[self.config.classification_labels].any(axis=1)
])
total_percent = (total_classified / total_images * 100) if total_images > 0 else 0
progress_stats['total'] = {
'classified': total_classified,
'total': total_images,
'percent': total_percent
}
# Calculate progress by tile
progress_stats['tiles'] = {}
for tile_id in sorted(self.region_df['tile_id'].unique()):
tile_df = self.region_df[self.region_df['tile_id'] == tile_id]
tile_total = len(tile_df)
tile_classified = len(tile_df[tile_df[self.config.classification_labels].any(axis=1)])
tile_percent = (tile_classified / tile_total * 100) if tile_total > 0 else 0
progress_stats['tiles'][tile_id] = {
'classified': tile_classified,
'total': tile_total,
'percent': tile_percent
}
return progress_stats
except Exception as e:
logging.error(f"Error calculating progress: {e}")
return {}  # keep the documented dict return type even on failure
def cleanup_cache(self, current_index: int):
"""Remove images outside cache window with tile ID priority."""
try:
with self.cache_lock:
current_tile = self.region_df.iloc[current_index]['tile_id']
# Keep images from current tile and within window
window_start = max(0, current_index - self.cache_window)
# Clamp to the last valid row so the iloc lookups below never go out of range
window_end = min(len(self.region_df) - 1, current_index + self.cache_window)
valid_indices = set(range(window_start, window_end + 1))
valid_keys = set()
# Add keys for current tile
for idx in self.region_df[self.region_df['tile_id'] == current_tile].index:
row = self.region_df.iloc[idx]
valid_keys.add(f"{row['tile_id']}_{row['unique_number']}")
# Add keys for window
for idx in valid_indices:
row = self.region_df.iloc[idx]
valid_keys.add(f"{row['tile_id']}_{row['unique_number']}")
# Remove invalid keys
for key in list(self.image_cache.keys()):
if key not in valid_keys:
del self.image_cache[key]
except Exception as e:
logging.error(f"Error cleaning cache: {e}")
def get_all_tile_ids(self):
"""Scan directory to find all available tile IDs."""
try:
base_dir = self.config.data_directory
pattern = "**/*RASA36-T*-*.com.*.sub." + self.config.file_type
all_files = glob.glob(os.path.join(base_dir, pattern), recursive=True)
# Extract tile IDs using regex
tile_ids = set()
for filepath in all_files:
match = re.search(r'RASA36-(T\d{5})-', filepath)
if match:
tile_ids.add(match.group(1))
if not tile_ids:
logging.warning("No tile IDs found in directory")
return []
sorted_tile_ids = sorted(list(tile_ids))
logging.info(f"Found {len(sorted_tile_ids)} tile IDs: {', '.join(sorted_tile_ids)}")
return sorted_tile_ids
except Exception as e:
logging.error(f"Error finding tile IDs: {e}")
return []
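# Both the glob pattern and the regex above assume file names of the form
# (path is illustrative):
#
#   <data_directory>/.../RASA36-T01234-xxxx.com.42.sub.fits
#
# where the five-digit tile ID (here 'T01234') is what gets collected.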
class DataValidator:
"""Class to handle data validation."""
def __init__(self, config: Config):
self.config = config
def validate_dataframe(self, df: pd.DataFrame) -> Tuple[bool, List[str]]:
"""
Validate DataFrame structure and content.
Single source of truth for DataFrame validation.
"""
errors = []
try:
# Validate required columns
errors.extend(self._validate_columns(df))
# Validate data types
errors.extend(self._validate_data_types(df))
# Validate classification values
errors.extend(self._validate_classifications(df))
return len(errors) == 0, errors
except Exception as e:
errors.append(f"Validation error: {str(e)}")
return False, errors
def _validate_columns(self, df: pd.DataFrame) -> List[str]:
"""Validate required columns exist."""
errors = []
required_cols = self.config.columns_order + self.config.classification_labels
missing_cols = [col for col in required_cols if col not in df.columns]
if missing_cols:
errors.append(f"Missing required columns: {missing_cols}")
return errors
def _validate_data_types(self, df: pd.DataFrame) -> List[str]:
"""Validate data types of key columns."""
errors = []
if 'unique_number' in df.columns:
if not pd.to_numeric(df['unique_number'], errors='coerce').notna().all():
errors.append("Invalid unique_number values found")
return errors
def _validate_classifications(self, df: pd.DataFrame) -> List[str]:
"""Validate classification column values."""
errors = []
for col in self.config.classification_labels:
if col in df.columns:
invalid = ~df[col].isin([0, 1, np.nan])
if invalid.any():
invalid_rows = df.loc[invalid, 'unique_number'].tolist()
errors.append(f"Invalid {col} values in rows: {invalid_rows}")
return errors
class ImageProcessor:
"""
Handles image processing operations and caching.
Key features:
1. FITS and PNG image support
2. Image scaling (zscale, linear, log)
3. Normalization and value range management
4. Thread-safe caching
5. Memory optimization
The ImageProcessor ensures efficient image loading and processing
while maintaining memory usage within configured limits.
"""
def __init__(self, config: Config):
self.config = config
self.image_cache = {}
self.cache_size = config.cache_size
self.cache_lock = threading.Lock()
self.zscale = ZScaleInterval()
# Add image type flag
self.is_fits = config.file_type.lower() == 'fits'
def _update_cache(self, key: str, value: Tuple[np.ndarray, np.ndarray, np.ndarray]):
"""Thread-safe cache update."""
try:
with self.cache_lock:
if len(self.image_cache) >= self.cache_size:
oldest_key = next(iter(self.image_cache))
del self.image_cache[oldest_key]
self.image_cache[key] = value
except Exception as e:
logging.error(f"Error updating cache: {e}")
def load_and_process_images(self, tile_id: str, unique_number: int) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
"""
Load and process images with parallel loading for better performance.
"""
try:
# Get paths for all image types
paths = self._get_image_paths(tile_id, unique_number)
# Parallel loading using ThreadPoolExecutor
images = {'sub': None, 'new': None, 'ref': None}
# Create a thread pool with 3 workers (one for each image type)
with ThreadPoolExecutor(max_workers=3) as executor:
futures = {
'sub': executor.submit(self._load_single_image, paths['sub']),
'new': executor.submit(self._load_single_image, paths['new']),
'ref': executor.submit(self._load_single_image, paths['ref'])
}
# Wait for all tasks to complete and get results
for img_type, future in futures.items():
try:
images[img_type] = future.result(timeout=30)
except Exception as e:
logging.error(f"Error loading {img_type} image: {e}")
# Validate required images
if images['sub'] is None:
raise FileNotFoundError(f"Failed to load subtracted image for {tile_id}-{unique_number}")
# Keep only the images that loaded successfully
processed_images = {
img_type: img_data
for img_type, img_data in images.items()
if img_data is not None
}
return (processed_images.get('sub'),
processed_images.get('new'),
processed_images.get('ref'))
except Exception as e:
logging.error(f"Error in load_and_process_images for {tile_id}-{unique_number}: {e}")
raise
def _load_single_image(self, filepath: str) -> Optional[np.ndarray]:
"""Load a single image file with proper error handling."""
try:
if not os.path.exists(filepath):
logging.warning(f"File not found: {filepath}")
return None
if self.is_fits:
with fits.open(filepath, memmap=True) as hdul:
data = hdul[0].data
if data is None:
logging.error(f"No data in FITS file: {filepath}")
return None
return data.astype(np.float32)
else: # PNG case
return plt.imread(filepath)
except Exception as e:
logging.error(f"Error loading image {filepath}: {e}")
return None
def _get_image_paths(self, tile_id: str, unique_number: int) -> dict:
"""Get paths for all image types."""
try:
# First find the .sub file
base_pattern = f"**/*{tile_id}*.com.{unique_number}.sub.{self.config.file_type}"
sub_files = glob.glob(os.path.join(self.config.data_directory, base_pattern), recursive=True)