diff --git a/src/Utility/Pre-Processing/STOFS-3D-Atl-shadow-VIMS/Pre_processing/Source_sink/Replace_with_USGS/replace_with_obs.py b/src/Utility/Pre-Processing/STOFS-3D-Atl-shadow-VIMS/Pre_processing/Source_sink/Replace_with_USGS/replace_with_obs.py index bcf44127c..18649c870 100644 --- a/src/Utility/Pre-Processing/STOFS-3D-Atl-shadow-VIMS/Pre_processing/Source_sink/Replace_with_USGS/replace_with_obs.py +++ b/src/Utility/Pre-Processing/STOFS-3D-Atl-shadow-VIMS/Pre_processing/Source_sink/Replace_with_USGS/replace_with_obs.py @@ -10,6 +10,7 @@ from datetime import datetime #, timedelta import matplotlib.pyplot as plt import xarray as xr +import geopandas as gpd import json from schism_py_pre_post.Grid.Prop import Prop @@ -17,6 +18,7 @@ from schism_py_pre_post.Download.Data import ObsData from schism_py_pre_post.Download.download_usgs_with_api import download_stations, usgs_var_dict, convert_to_ObsData from pylib import schism_grid +import pickle def BinA(A=None, B=None): @@ -91,14 +93,11 @@ def __init__(self, st_id=None, nearby_nwm_fid=None): class vsource: - def __init__(self, xyz=[0, 0, 0], hgrid_ie=0, nwm_fid=0, df=pd.DataFrame(), nwm_idx_g2l=None): + def __init__(self, xyz=[0, 0, 0], hgrid_ie=0, nwm_fid=0, df=pd.DataFrame()): self.xyz = xyz self.usgs_st = [] self.nwm_fid = nwm_fid - if nwm_idx_g2l is not None: - self.seg_id = nwm_idx_g2l[self.nwm_fid] - self.upstream_path = search_path() self.downstream_path = search_path() self.hgrid_ie = hgrid_ie @@ -119,71 +118,84 @@ def df(self, dataframe): class search_path: + ''' + This class is to search along the upstream or downstream path of a vsource. + ''' def __init__(self): self._segID = [] - self._segfID = [] self._order = [] def forward_path(self, seg_id, seg_order): self._segID.append(seg_id) - self._segfID.append(index_l2g[seg_id]) self._order.append(seg_order) def reset_path(self): self._segID = [] - self._segfID = [] self._order = [] - def writer(self, fname): + def writer(self, fname, shp_gpd=None): with open(fname, 'w') as fout: fout.write(f'lon lat \n') - for i in self._segID: - fout.write(str(nwm_shp_records[i].lon) + " " + - str(nwm_shp_records[i].lat) + "\n") + for fid in self._segID: + fout.write(str(shp_gpd.loc[fid].lon) + " " + str(shp_gpd.loc[fid].lat) + "\n") -def find_usgs_along_nwm(iupstream=True, starting_seg_id=None, vsource=None, total_seg_length=0.0, order=0): - """This is a recursive function to find the up/down-stream USGS stations""" - if starting_seg_id is None: # boundary reached - return +def find_usgs_along_nwm(iupstream=True, starting_seg_id=None, vsource=None, total_seg_length=0.0, order=0, nwm_shp=None): + """This is a recursive function to find the up/down-stream USGS stations, + The search stops when the total length of the path exceeds 30 km. + The order of the stream is also recorded, but not used as a stopping criterion at this moment.""" + + if starting_seg_id is None: + return # some starting segs may not be in nwm_shp, e.g., near the boundary of the subset domain + + try: + starting_seg_rec = nwm_shp.loc[starting_seg_id] + except KeyError: + return # some starting segs may not be in nwm_shp, e.g., near the boundary of the subset domain if iupstream: vsource.upstream_path.forward_path(seg_id=starting_seg_id, seg_order=order) else: - vsource.downstream_path.forward_path(seg_id=starting_seg_id, seg_order=order) + vsource.upstream_path.forward_path(seg_id=starting_seg_id, seg_order=order) - total_seg_length += nwm_shp_records[starting_seg_id].Length - if (total_seg_length > 30e3): + total_seg_length += nwm_shp.loc[starting_seg_id].Length + if (total_seg_length > 30e3): # 30 km return - if nwm_shp_records[starting_seg_id].gages != '': # gage found - st = usgs_st(st_id=nwm_shp_records[starting_seg_id].gages, nearby_nwm_fid=index_l2g[starting_seg_id]) + if nwm_shp.loc[starting_seg_id].gages is not None and nwm_shp.loc[starting_seg_id].gages != '': # gage found + st = usgs_st(st_id=nwm_shp.loc[starting_seg_id].gages, nearby_nwm_fid=starting_seg_id) vsource.usgs_st.append(st) return else: # search upstream or downstream if iupstream: - # uids = [None] if no upstream seg, thus "for uid in uids" is skipped - uids = [index_g2l.get(key) for key in from_seg[starting_seg_id]] + uids = nwm_shp.loc[starting_seg_id]['from'] else: - # uids = [None] if no downstream seg - uids = [index_g2l[nwm_shp_records[starting_seg_id].to]] + uids = [nwm_shp.loc[starting_seg_id]['to']] # 'to' is a single value, make it a list for the for loop for uid in uids: find_usgs_along_nwm(iupstream=iupstream, starting_seg_id=uid, - vsource=vsource, total_seg_length=total_seg_length, order=order+1) + vsource=vsource, total_seg_length=total_seg_length, order=order+1, nwm_shp=nwm_shp) -if __name__ == "__main__": - # ----------------------------------------------------------------------------------- - # ------inputs------ - # ----------------------------------------------------------------------------------- - start_time_str = "2017-12-01 00:00:00" - f_shapefile = "/sciclone/schism10/feiye/schism20/REPO/NWM/Shapefiles/ecgc/ecgc.shp" - run_dir = '/sciclone/schism10/feiye/Requests/RUN02a_JZ/src/NWM/' - nwm_data_dir = '/sciclone/schism10/whuang07/schism20/NWM_v2.1/' - output_dir = '/sciclone/schism10/feiye/Requests/RUN02a_JZ/src/NWM/USGS_adjusted_sources/' - - # -------end inputs ------ +def source_nwm2usgs( + start_time_str="2017-12-01 00:00:00", + f_shapefile="/sciclone/schism10/feiye/schism20/REPO/NWM/Shapefiles/ecgc/ecgc.shp", + run_dir='/sciclone/schism10/feiye/Requests/RUN02a_JZ/src/NWM/', + nwm_data_dir='/sciclone/schism10/whuang07/schism20/NWM_v2.1/', + output_dir='/sciclone/schism10/feiye/Requests/RUN02a_JZ/src/NWM/USGS_adjusted_sources/', +): + ''' + This script is to adjust the vsource.th file based on USGS data. + The adjustment is done by comparing the NWM data near the vsource and the USGS data near the vsource. + The USGS data is downloaded from USGS website using the API. + The NWM data is pre-downloaded during the original NWM to source/sink conversion by gen_source_sink.py. + See the Atlas paper draft for details of how USGS stations are selected for each vsource. + In short, the USGS stations are selected along the upstream and downstream branches of the vsource. + The NWM shapefile provides the information of the upstream and downstream branches, + as well as the location of the USGS stations. + However, only NWM v1's shapefile has the USGS station information. + This is acceptable because the segment IDs are the same between NWM v1 and later versions. + ''' usgs_data_fname = f'{output_dir}/usgs_data.pkl' @@ -205,41 +217,40 @@ def find_usgs_along_nwm(iupstream=True, starting_seg_id=None, vsource=None, tota # ----------------------------------------------------------------------------------- - # ------Read NWM shapefile------ + # ------Read and process NWM shapefile------ # ----------------------------------------------------------------------------------- t1 = time.time() - sf = shapefile.Reader(f_shapefile) - nwm_shp_shapes = sf.shapes() - nwm_shp_records = sf.records() + if os.path.exists(f'{f_shapefile}.pkl'): # cached + with open(f'{f_shapefile}.pkl', 'rb') as file: + nwm_shp = pickle.load(file) + else: + nwm_shp = gpd.read_file(f_shapefile) - sf.shapeType - len(sf) + # append "from" property to each seg + nwm_shp['from'] = None - # Another way to get individual records # idx = 200143 # s = sf.shape(idx) # r = sf.record(idx) + index_g2l = {} # global to local index + for i, rec in nwm_shp.iterrows(): + index_g2l[rec.featureID] = i - # Setup mapping between local index and featureID - index_l2g = np.empty(len(nwm_shp_records), dtype=int) - from_seg = [] # np.empty(len(nwm_shp_records), dtype=int) - n_from_seg = np.zeros(len(nwm_shp_records), dtype=int) + from_list = [[] for _ in range(len(nwm_shp))] # directly use nwm_shp.loc is slow + for i, rec in nwm_shp.iterrows(): + try: + from_list[index_g2l[rec.to]].append(rec.featureID) # index is featureID + except KeyError: + pass # index_g2l[rec.to] can be None, then 0 is pre-assigned + nwm_shp['from'] = from_list - # Add "from" (upstream seg id) info to each seg - index_g2l = dict() - for i, [rec, shp] in enumerate(zip(nwm_shp_records, nwm_shp_shapes)): - index_l2g[i] = rec.featureID - index_g2l[rec.featureID] = i - index_g2l[0] = None # handle no downstream case (to = 0) + # set featureID as index + nwm_shp = nwm_shp.set_index('featureID') + # duplicate featureID as a column + nwm_shp['featureID'] = nwm_shp.index - for i, [rec, shp] in enumerate(zip(nwm_shp_records, nwm_shp_shapes)): - from_seg.append([]) - for i, [rec, shp] in enumerate(zip(nwm_shp_records, nwm_shp_shapes)): - if rec.to in index_g2l.keys(): # rec.to can be outside if the shp is cut from the original one - id = index_g2l[rec.to] - if id is not None: - from_seg[id].append(rec.featureID) - n_from_seg[id] += 1 + with open(f'{f_shapefile}.pkl', 'wb') as file: + pickle.dump(nwm_shp, file) - print(f'---------------loading shapefile took: {time.time()-t1} s ---------------\n') + print(f'---------------loading and processing shapefile took: {time.time()-t1} s ---------------\n') # ----------------------------------------------------------------------------------- # ------load vsource------ @@ -250,7 +261,7 @@ def find_usgs_along_nwm(iupstream=True, starting_seg_id=None, vsource=None, tota hg = schism_grid(f'{run_dir}/hgrid.gr3') # hg.save() hg.compute_ctr() - # group 0 is source; 1 is sink + # read featureID at vsource elements if os.path.exists(f'{run_dir}/ele_fid_dict.json'): with open(f'{run_dir}/ele_fid_dict.json') as json_file: mysrc_nwm_fid = json.load(json_file) @@ -273,7 +284,6 @@ def find_usgs_along_nwm(iupstream=True, starting_seg_id=None, vsource=None, tota hgrid_ie=source_ele, nwm_fid=mysrc_nwm_fid[str(source_ele)], df=pd.DataFrame({'datetime': my_th.datetime, 'Data': my_th.data[:, i]}), - nwm_idx_g2l=index_g2l ) pass print(f'---------------loading vsource took: {time.time()-t1} s ---------------\n') @@ -283,21 +293,12 @@ def find_usgs_along_nwm(iupstream=True, starting_seg_id=None, vsource=None, tota # ----------------------------------------------------------------------------------- t1 = time.time() for i, my_source in enumerate(my_sources): - source_lon = my_source.xyz[0] - source_lat = my_source.xyz[1] - - # if my_source.hgrid_ie not in ele_debug: - # continue + # recursively go through upstream and downstream branches to find available USGS stations + find_usgs_along_nwm(iupstream=True, starting_seg_id=my_source.nwm_fid, vsource=my_source, total_seg_length=0.0, order=0, nwm_shp=nwm_shp) + my_source.upstream_path.writer(f'{output_dir}/Ele{str(my_source.hgrid_ie)}_upstream', shp_gpd=nwm_shp) - # recursively go through upstream and downstream branches - # to find available USGS stations - find_usgs_along_nwm( - iupstream=True, starting_seg_id=my_source.seg_id, vsource=my_source, total_seg_length=0.0) - my_source.upstream_path.writer(f'{output_dir}/Ele{str(my_source.hgrid_ie)}_upstream') - - find_usgs_along_nwm( - iupstream=False, starting_seg_id=my_source.seg_id, vsource=my_source, total_seg_length=0.0) - my_source.downstream_path.writer(f'{output_dir}/Ele{str(my_source.hgrid_ie)}_downstream') + find_usgs_along_nwm(iupstream=False, starting_seg_id=my_source.nwm_fid, vsource=my_source, total_seg_length=0.0, order=0, nwm_shp=nwm_shp) + my_source.downstream_path.writer(f'{output_dir}/Ele{str(my_source.hgrid_ie)}_downstream', shp_gpd=nwm_shp) usgs_stations = [st.st_id for my_source in my_sources for st in my_source.usgs_st] usgs_stations = list(set(usgs_stations)) # remove duplicates @@ -425,4 +426,13 @@ def find_usgs_along_nwm(iupstream=True, starting_seg_id=None, vsource=None, tota my_th.df_propagate() my_th.writer(f'{output_dir}/adjusted_vsource.th') - pass + +if __name__ == "__main__": + # sample usage (the final product is "adjusted_vsource.th" in the output_dir): + source_nwm2usgs( + start_time_str="2017-12-01 00:00:00", + f_shapefile="/sciclone/schism10/feiye/schism20/REPO/NWM/Shapefiles/ecgc/ecgc.shp", + run_dir='/sciclone/schism10/feiye/Requests/RUN02a_JZ/src/NWM/', + nwm_data_dir='/sciclone/schism10/whuang07/schism20/NWM_v2.1/', + output_dir='/sciclone/schism10/feiye/Requests/RUN02a_JZ/src/NWM/USGS_adjusted_sources/', + )