Commit abd6559

STOFS scripts: minor edits to make the script clearer.

feiye-vims committed Dec 4, 2023
1 parent d557074 commit abd6559
Showing 1 changed file with 89 additions and 79 deletions.
@@ -10,13 +10,15 @@
from datetime import datetime #, timedelta
import matplotlib.pyplot as plt
import xarray as xr
+import geopandas as gpd

import json
from schism_py_pre_post.Grid.Prop import Prop
from schism_py_pre_post.Timeseries.TimeHistory import TimeHistory
from schism_py_pre_post.Download.Data import ObsData
from schism_py_pre_post.Download.download_usgs_with_api import download_stations, usgs_var_dict, convert_to_ObsData
from pylib import schism_grid
+import pickle


def BinA(A=None, B=None):
@@ -91,14 +93,11 @@ def __init__(self, st_id=None, nearby_nwm_fid=None):


class vsource:
-def __init__(self, xyz=[0, 0, 0], hgrid_ie=0, nwm_fid=0, df=pd.DataFrame(), nwm_idx_g2l=None):
+def __init__(self, xyz=[0, 0, 0], hgrid_ie=0, nwm_fid=0, df=pd.DataFrame()):
self.xyz = xyz
self.usgs_st = []
self.nwm_fid = nwm_fid

-if nwm_idx_g2l is not None:
-self.seg_id = nwm_idx_g2l[self.nwm_fid]

self.upstream_path = search_path()
self.downstream_path = search_path()
self.hgrid_ie = hgrid_ie
Expand All @@ -119,71 +118,84 @@ def df(self, dataframe):


class search_path:
+'''
+This class is to search along the upstream or downstream path of a vsource.
+'''
def __init__(self):
self._segID = []
self._segfID = []
self._order = []

def forward_path(self, seg_id, seg_order):
self._segID.append(seg_id)
self._segfID.append(index_l2g[seg_id])
self._order.append(seg_order)

def reset_path(self):
self._segID = []
self._segfID = []
self._order = []

-def writer(self, fname):
+def writer(self, fname, shp_gpd=None):
with open(fname, 'w') as fout:
fout.write(f'lon lat \n')
-for i in self._segID:
-fout.write(str(nwm_shp_records[i].lon) + " " +
-str(nwm_shp_records[i].lat) + "\n")
+for fid in self._segID:
+fout.write(str(shp_gpd.loc[fid].lon) + " " + str(shp_gpd.loc[fid].lat) + "\n")
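Note: each writer() call produces a whitespace-delimited text file with a "lon lat" header and one row per recorded segment. A minimal sketch of reading such a file back (the file name is hypothetical):

```python
import pandas as pd

# Read back a path file produced by search_path.writer(); it is
# whitespace-delimited with a "lon lat" header. 'Ele123_upstream' is a
# hypothetical example of the files written further below.
path_df = pd.read_csv('Ele123_upstream', sep=r'\s+')
print(path_df.columns.tolist())  # ['lon', 'lat']
```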


-def find_usgs_along_nwm(iupstream=True, starting_seg_id=None, vsource=None, total_seg_length=0.0, order=0):
-"""This is a recursive function to find the up/down-stream USGS stations"""
-if starting_seg_id is None: # boundary reached
-return
+def find_usgs_along_nwm(iupstream=True, starting_seg_id=None, vsource=None, total_seg_length=0.0, order=0, nwm_shp=None):
+"""This is a recursive function to find the up/down-stream USGS stations.
+The search stops when the total length of the path exceeds 30 km.
+The order of the stream is also recorded, but not used as a stopping criterion at this moment."""

+if starting_seg_id is None:
+return # some starting segs may not be in nwm_shp, e.g., near the boundary of the subset domain

+try:
+starting_seg_rec = nwm_shp.loc[starting_seg_id]
+except KeyError:
+return # some starting segs may not be in nwm_shp, e.g., near the boundary of the subset domain

if iupstream:
vsource.upstream_path.forward_path(seg_id=starting_seg_id, seg_order=order)
else:
vsource.downstream_path.forward_path(seg_id=starting_seg_id, seg_order=order)

-total_seg_length += nwm_shp_records[starting_seg_id].Length
-if (total_seg_length > 30e3):
+total_seg_length += nwm_shp.loc[starting_seg_id].Length
+if (total_seg_length > 30e3): # 30 km
return

-if nwm_shp_records[starting_seg_id].gages != '': # gage found
-st = usgs_st(st_id=nwm_shp_records[starting_seg_id].gages, nearby_nwm_fid=index_l2g[starting_seg_id])
+if nwm_shp.loc[starting_seg_id].gages is not None and nwm_shp.loc[starting_seg_id].gages != '': # gage found
+st = usgs_st(st_id=nwm_shp.loc[starting_seg_id].gages, nearby_nwm_fid=starting_seg_id)
vsource.usgs_st.append(st)
return
else: # search upstream or downstream
if iupstream:
-# uids = [None] if no upstream seg, thus "for uid in uids" is skipped
-uids = [index_g2l.get(key) for key in from_seg[starting_seg_id]]
+uids = nwm_shp.loc[starting_seg_id]['from']
else:
-# uids = [None] if no downstream seg
-uids = [index_g2l[nwm_shp_records[starting_seg_id].to]]
+uids = [nwm_shp.loc[starting_seg_id]['to']] # 'to' is a single value, make it a list for the for loop

for uid in uids:
find_usgs_along_nwm(iupstream=iupstream, starting_seg_id=uid,
-vsource=vsource, total_seg_length=total_seg_length, order=order+1)
+vsource=vsource, total_seg_length=total_seg_length, order=order+1, nwm_shp=nwm_shp)
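Note: the function above is a depth-first walk over the NWM segment graph that stops per branch at the first gaged segment or once the accumulated reach length passes 30 km. A self-contained sketch of the same logic, with a plain dict `segments` standing in for the `nwm_shp` GeoDataFrame (keyed by featureID; all IDs and lengths below are made up):

```python
def walk_upstream(segments, seg_id, length=0.0, gages=None):
    """Depth-first upstream walk; each branch stops at a gage or past 30 km."""
    if gages is None:
        gages = []
    seg = segments.get(seg_id)
    if seg is None:           # segment outside the (possibly subset) domain
        return gages
    length += seg['Length']
    if length > 30e3:         # 30 km cutoff, as in find_usgs_along_nwm
        return gages
    if seg['gages']:          # gage found: record it and stop this branch
        gages.append(seg['gages'])
        return gages
    for uid in seg['from']:   # recurse into all upstream segments
        walk_upstream(segments, uid, length, gages)
    return gages

# toy network: water flows 1 -> 2 -> 3; a gage sits on headwater segment 1
segments = {
    1: {'Length': 5e3, 'gages': '01234567', 'from': []},
    2: {'Length': 8e3, 'gages': '', 'from': [1]},
    3: {'Length': 4e3, 'gages': '', 'from': [2]},
}
print(walk_upstream(segments, 3))  # ['01234567']
```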


-if __name__ == "__main__":
-# -----------------------------------------------------------------------------------
-# ------inputs------
-# -----------------------------------------------------------------------------------
-start_time_str = "2017-12-01 00:00:00"
-f_shapefile = "/sciclone/schism10/feiye/schism20/REPO/NWM/Shapefiles/ecgc/ecgc.shp"
-run_dir = '/sciclone/schism10/feiye/Requests/RUN02a_JZ/src/NWM/'
-nwm_data_dir = '/sciclone/schism10/whuang07/schism20/NWM_v2.1/'
-output_dir = '/sciclone/schism10/feiye/Requests/RUN02a_JZ/src/NWM/USGS_adjusted_sources/'
-# -------end inputs ------
+def source_nwm2usgs(
+start_time_str="2017-12-01 00:00:00",
+f_shapefile="/sciclone/schism10/feiye/schism20/REPO/NWM/Shapefiles/ecgc/ecgc.shp",
+run_dir='/sciclone/schism10/feiye/Requests/RUN02a_JZ/src/NWM/',
+nwm_data_dir='/sciclone/schism10/whuang07/schism20/NWM_v2.1/',
+output_dir='/sciclone/schism10/feiye/Requests/RUN02a_JZ/src/NWM/USGS_adjusted_sources/',
+):
+'''
+This function adjusts the vsource.th file based on USGS data.
+The adjustment is done by comparing the NWM data near the vsource with the USGS data near the vsource.
+The USGS data are downloaded from the USGS website using the API.
+The NWM data are pre-downloaded during the original NWM to source/sink conversion by gen_source_sink.py.
+See the Atlas paper draft for details of how USGS stations are selected for each vsource.
+In short, the USGS stations are selected along the upstream and downstream branches of the vsource.
+The NWM shapefile provides the information on the upstream and downstream branches,
+as well as the locations of the USGS stations.
+However, only NWM v1's shapefile has the USGS station information.
+This is acceptable because the segment IDs are the same between NWM v1 and later versions.
+'''

usgs_data_fname = f'{output_dir}/usgs_data.pkl'
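Note: the adjustment formula itself sits in the part of the diff elided here. Purely as a hedged illustration of "comparing the NWM data near the vsource with the USGS data near the vsource", one simple form of such a correction rescales the NWM-based series by the observed-to-modeled flow ratio (all numbers made up; not necessarily the formula this script uses):

```python
import pandas as pd

# Illustration only; not necessarily this script's actual formula.
nwm = pd.Series([100.0, 120.0, 90.0])  # made-up NWM flows at a vsource (m^3/s)
usgs = pd.Series([80.0, 100.0, 75.0])  # made-up USGS flows at a nearby gage (m^3/s)
ratio = usgs.mean() / nwm.mean()       # observed-to-modeled correction factor
adjusted = nwm * ratio                 # candidate adjusted vsource series
print(round(ratio, 3))                 # 0.823
```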

@@ -205,41 +217,40 @@ def find_usgs_along_nwm(...)


# -----------------------------------------------------------------------------------
-# ------Read NWM shapefile------
+# ------Read and process NWM shapefile------
# -----------------------------------------------------------------------------------
t1 = time.time()

-sf = shapefile.Reader(f_shapefile)
-nwm_shp_shapes = sf.shapes()
-nwm_shp_records = sf.records()
+if os.path.exists(f'{f_shapefile}.pkl'): # cached
+with open(f'{f_shapefile}.pkl', 'rb') as file:
+nwm_shp = pickle.load(file)
+else:
+nwm_shp = gpd.read_file(f_shapefile)

-sf.shapeType
-len(sf)
+# append "from" property to each seg
+nwm_shp['from'] = None

-# Another way to get individual records # idx = 200143 # s = sf.shape(idx) # r = sf.record(idx)
+index_g2l = {} # global to local index
+for i, rec in nwm_shp.iterrows():
+index_g2l[rec.featureID] = i

-# Setup mapping between local index and featureID
-index_l2g = np.empty(len(nwm_shp_records), dtype=int)
-from_seg = [] # np.empty(len(nwm_shp_records), dtype=int)
-n_from_seg = np.zeros(len(nwm_shp_records), dtype=int)
+from_list = [[] for _ in range(len(nwm_shp))] # directly using nwm_shp.loc is slow
+for i, rec in nwm_shp.iterrows():
+try:
+from_list[index_g2l[rec.to]].append(rec.featureID) # record the upstream seg's featureID
+except KeyError:
+pass # rec.to may be 0 (no downstream) or outside the subset domain, hence not in index_g2l
+nwm_shp['from'] = from_list

-# Add "from" (upstream seg id) info to each seg
-index_g2l = dict()
-for i, [rec, shp] in enumerate(zip(nwm_shp_records, nwm_shp_shapes)):
-index_l2g[i] = rec.featureID
-index_g2l[rec.featureID] = i
-index_g2l[0] = None # handle no downstream case (to = 0)
+# set featureID as index
+nwm_shp = nwm_shp.set_index('featureID')
+# duplicate featureID as a column
+nwm_shp['featureID'] = nwm_shp.index

-for i, [rec, shp] in enumerate(zip(nwm_shp_records, nwm_shp_shapes)):
-from_seg.append([])
-for i, [rec, shp] in enumerate(zip(nwm_shp_records, nwm_shp_shapes)):
-if rec.to in index_g2l.keys(): # rec.to can be outside if the shp is cut from the original one
-id = index_g2l[rec.to]
-if id is not None:
-from_seg[id].append(rec.featureID)
-n_from_seg[id] += 1
+with open(f'{f_shapefile}.pkl', 'wb') as file:
+pickle.dump(nwm_shp, file)

-print(f'---------------loading shapefile took: {time.time()-t1} s ---------------\n')
+print(f'---------------loading and processing shapefile took: {time.time()-t1} s ---------------\n')
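Note: the 'from' column built above simply inverts the shapefile's 'to' (downstream) pointers so that each segment knows its upstream neighbors. A toy illustration with made-up featureIDs:

```python
import pandas as pd

# Invert "to" (downstream) pointers into "from" (upstream) lists;
# featureIDs are made up, and to == 0 means no downstream segment.
df = pd.DataFrame({'featureID': [1, 2, 3], 'to': [3, 3, 0]}).set_index('featureID')
upstream = {fid: [] for fid in df.index}
for fid, row in df.iterrows():
    if row['to'] in upstream:  # skip to == 0 or targets outside a subset
        upstream[row['to']].append(fid)
df['from'] = pd.Series(upstream)
print(df['from'].loc[3])       # [1, 2]: segments 1 and 2 flow into 3
```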

# -----------------------------------------------------------------------------------
# ------load vsource------
Expand All @@ -250,7 +261,7 @@ def find_usgs_along_nwm(iupstream=True, starting_seg_id=None, vsource=None, tota
hg = schism_grid(f'{run_dir}/hgrid.gr3') # hg.save()
hg.compute_ctr()

-# group 0 is source; 1 is sink
+# read featureID at vsource elements
if os.path.exists(f'{run_dir}/ele_fid_dict.json'):
with open(f'{run_dir}/ele_fid_dict.json') as json_file:
mysrc_nwm_fid = json.load(json_file)
@@ -273,7 +284,6 @@ def find_usgs_along_nwm(...)
hgrid_ie=source_ele,
nwm_fid=mysrc_nwm_fid[str(source_ele)],
df=pd.DataFrame({'datetime': my_th.datetime, 'Data': my_th.data[:, i]}),
-nwm_idx_g2l=index_g2l
)
pass
print(f'---------------loading vsource took: {time.time()-t1} s ---------------\n')
@@ -283,21 +293,12 @@ def find_usgs_along_nwm(...)
# -----------------------------------------------------------------------------------
t1 = time.time()
for i, my_source in enumerate(my_sources):
-source_lon = my_source.xyz[0]
-source_lat = my_source.xyz[1]
-# if my_source.hgrid_ie not in ele_debug:
-#     continue
+# recursively go through upstream and downstream branches to find available USGS stations
+find_usgs_along_nwm(iupstream=True, starting_seg_id=my_source.nwm_fid, vsource=my_source, total_seg_length=0.0, order=0, nwm_shp=nwm_shp)
+my_source.upstream_path.writer(f'{output_dir}/Ele{str(my_source.hgrid_ie)}_upstream', shp_gpd=nwm_shp)

-# recursively go through upstream and downstream branches
-# to find available USGS stations
-find_usgs_along_nwm(
-iupstream=True, starting_seg_id=my_source.seg_id, vsource=my_source, total_seg_length=0.0)
-my_source.upstream_path.writer(f'{output_dir}/Ele{str(my_source.hgrid_ie)}_upstream')
-find_usgs_along_nwm(
-iupstream=False, starting_seg_id=my_source.seg_id, vsource=my_source, total_seg_length=0.0)
-my_source.downstream_path.writer(f'{output_dir}/Ele{str(my_source.hgrid_ie)}_downstream')
+find_usgs_along_nwm(iupstream=False, starting_seg_id=my_source.nwm_fid, vsource=my_source, total_seg_length=0.0, order=0, nwm_shp=nwm_shp)
+my_source.downstream_path.writer(f'{output_dir}/Ele{str(my_source.hgrid_ie)}_downstream', shp_gpd=nwm_shp)

usgs_stations = [st.st_id for my_source in my_sources for st in my_source.usgs_st]
usgs_stations = list(set(usgs_stations)) # remove duplicates
@@ -425,4 +426,13 @@ def find_usgs_along_nwm(...)
my_th.df_propagate()
my_th.writer(f'{output_dir}/adjusted_vsource.th')

-pass

+if __name__ == "__main__":
+# sample usage (the final product is "adjusted_vsource.th" in the output_dir):
+source_nwm2usgs(
+start_time_str="2017-12-01 00:00:00",
+f_shapefile="/sciclone/schism10/feiye/schism20/REPO/NWM/Shapefiles/ecgc/ecgc.shp",
+run_dir='/sciclone/schism10/feiye/Requests/RUN02a_JZ/src/NWM/',
+nwm_data_dir='/sciclone/schism10/whuang07/schism20/NWM_v2.1/',
+output_dir='/sciclone/schism10/feiye/Requests/RUN02a_JZ/src/NWM/USGS_adjusted_sources/',
+)
