Update data importers (#1118)
* start updating argo

* fix argo imports

* update glider script

* updated seal and nafc_ctd scripts

* updated the argo and glider scripts

* fix indent errors and formatting

---------

Co-authored-by: JustinElms <justinwelms@gmail.com>
Co-authored-by: JustinElms <justin.elms@dfo-mpo.gc.ca>
3 people authored Apr 2, 2024
1 parent ffd4dba commit 76bf1d6
Showing 4 changed files with 289 additions and 233 deletions.
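The change the updated importers share is the move from the module-level data.observational.init_db / data.observational.db.session helpers to an explicit SQLAlchemy engine plus a context-managed Session. Below is a minimal sketch of that pattern, assuming an in-memory SQLite URI as a stand-in for the database URI the real scripts take on the command line.

# Minimal sketch of the engine/Session pattern the updated importers adopt.
# The SQLite URI is a placeholder; the real scripts receive the database URI
# as a CLI argument and also pass connect_args={"connect_timeout": 10} and
# pool_recycle=3600 to create_engine.
from sqlalchemy import create_engine
from sqlalchemy.orm import Session

engine = create_engine("sqlite:///:memory:")

with Session(engine) as session:
    # Inside this block the importers open their NetCDF files, build
    # platform/station/sample objects, add them, and commit.
    session.commit()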
208 changes: 114 additions & 94 deletions scripts/data_importers/argo.py
File mode changed: 100755 → 100644
@@ -2,13 +2,19 @@

import glob
import os
import sys

import defopt
import pandas as pd
import gsw
import xarray as xr
from sqlalchemy import create_engine, select
from sqlalchemy.orm import Session

current = os.path.dirname(os.path.realpath(__file__))
parent = os.path.dirname(os.path.dirname(current))
sys.path.append(parent)

import data.observational
from data.observational import DataType, Platform, Sample, Station

VARIABLES = ["TEMP", "PSAL"]
@@ -33,103 +39,117 @@ def main(uri: str, filename: str):
:param str uri: Database URI
:param str filename: Argo NetCDF Filename, or directory of files
"""
data.observational.init_db(uri, echo=False)
data.observational.create_tables()

session = data.observational.db.session

if os.path.isdir(filename):
filenames = sorted(glob.glob(os.path.join(filename, "*.nc")))
else:
filenames = [filename]

for fname in filenames:
print(fname)
with xr.open_dataset(fname) as ds:
times = pd.to_datetime(ds.JULD.values)

for f in META_FIELDS:
META_FIELDS[f] = ds[f].values.astype(str)

for prof in ds.N_PROF.values:
plat_number = ds.PLATFORM_NUMBER.values.astype(str)[prof]
unique_id = f"argo_{plat_number}"

# Grab the platform from the db based on the unique id
platform = (
session.query(Platform)
.filter(
Platform.unique_id == unique_id,
Platform.type == Platform.Type.argo,
)
.first()
)
if platform is None:
# ... or make a new platform
platform = Platform(type=Platform.Type.argo, unique_id=unique_id)
attrs = {}
for f in META_FIELDS:
attrs[ds[f].long_name] = META_FIELDS[f][prof].strip()

platform.attrs = attrs
session.add(platform)

# Make a new Station
station = Station(
time=times[prof],
latitude=ds.LATITUDE.values[prof],
longitude=ds.LONGITUDE.values[prof],
)
platform.stations.append(station)
# We need to commit the station here so that it'll have an id
session.commit()

depth = abs(gsw.conversions.z_from_p(
ds.PRES[prof].dropna("N_LEVELS").values, ds.LATITUDE.values[prof]
))

samples = []
for variable in VARIABLES:
# First check our local cache for the DataType object, if
# that comes up empty, check the db, and failing that,
# create a new one from the variable's attributes
if variable not in datatype_map:
dt = DataType.query.get(ds[variable].standard_name)
if dt is None:
dt = DataType(
key=ds[variable].standard_name,
name=ds[variable].long_name,
unit=ds[variable].units,
)

data.observational.db.session.add(dt)
# Commit the DataType right away. This might lead
# to a few extra commits on the first import, but
# reduces overall complexity in having to
# 'remember' if we added a new one later.
data.observational.db.session.commit()
datatype_map[variable] = dt
else:
dt = datatype_map[variable]

values = ds[variable][prof].dropna("N_LEVELS").values

# Using station_id and datatype_key here instead of the
# actual objects so that we can use bulk_save_objects--this
# is much faster, but it doesn't follow any relationships.
samples = [
    Sample(
        depth=pair[0],
        datatype_key=dt.key,
        value=pair[1],
        station_id=station.id,
    )
    for pair in zip(depth, values)
]
engine = create_engine(
uri,
connect_args={"connect_timeout": 10},
pool_recycle=3600,
)

with Session(engine) as session:

if os.path.isdir(filename):
filenames = sorted(glob.glob(os.path.join(filename, "*.nc")))
else:
filenames = [filename]

for fname in filenames:
print(fname)
with xr.open_dataset(fname) as ds:
times = pd.to_datetime(ds.JULD.values)

for f in META_FIELDS:
META_FIELDS[f] = ds[f].values.astype(str)

for prof in ds.N_PROF.values:
plat_number = ds.PLATFORM_NUMBER.values.astype(str)[prof]
unique_id = f"argo_{plat_number}"

# Grab the platform from the db based on the unique id
platform = (
session.query(Platform)
.filter(
Platform.unique_id == unique_id,
Platform.type == Platform.Type.argo,
)
.first()
)
if platform is None:
# ... or make a new platform
platform = Platform(
type=Platform.Type.argo, unique_id=unique_id
)
attrs = {}
for f in META_FIELDS:
attrs[ds[f].long_name] = META_FIELDS[f][prof].strip()

platform.attrs = attrs
session.add(platform)

# Make a new Station
station = Station(
time=times[prof],
latitude=ds.LATITUDE.values[prof],
longitude=ds.LONGITUDE.values[prof],
)
platform.stations.append(station)
# We need to commit the station here so that it'll have an id
session.commit()

depth = abs(
gsw.conversions.z_from_p(
ds.PRES[prof].dropna("N_LEVELS").values,
ds.LATITUDE.values[prof],
)
)

samples = []
for variable in VARIABLES:
# First check our local cache for the DataType object, if
# that comes up empty, check the db, and failing that,
# create a new one from the variable's attributes
if variable not in datatype_map:
statement = select(DataType).where(
DataType.key == ds[variable].standard_name
)
dt = session.execute(statement).all()
if not dt:
dt = DataType(
key=ds[variable].standard_name,
name=ds[variable].long_name,
unit=ds[variable].units,
)

session.add(dt)
# Commit the DataType right away. This might lead
# to a few extra commits on the first import, but
# reduces overall complexity in having to
# 'remember' if we added a new one later.
session.commit()
datatype_map[variable] = dt
else:
dt = dt[0][0] # fix this
else:
dt = datatype_map[variable]

values = ds[variable][prof].dropna("N_LEVELS").values

# Using station_id and datatype_key here instead of the
# actual objects so that we can use bulk_save_objects--this
# is much faster, but it doesn't follow any relationships.
samples = [
Sample(
depth=pair[0],
datatype_key=dt.key,
value=pair[1],
station_id=station.id,
)
for pair in zip(depth, values)
]

data.observational.db.session.bulk_save_objects(samples)
session.bulk_save_objects(samples)

session.commit()
session.commit()


if __name__ == "__main__":
(remainder of argo.py not shown)
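The heart of the argo.py change is replacing the old DataType.query.get(...) lookup with a 2.0-style select() executed on the session. Session.execute() returns Row tuples, which is why the new code unpacks with dt[0][0]; Session.scalars() is an alternative form that yields the mapped objects directly. The sketch below shows both, using a stand-in DataType model rather than the project's data.observational.DataType.

# Illustrative stand-in model and lookup, not the project's actual classes.
from sqlalchemy import create_engine, select
from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column


class Base(DeclarativeBase):
    pass


class DataType(Base):  # stand-in for data.observational.DataType
    __tablename__ = "datatypes"

    key: Mapped[str] = mapped_column(primary_key=True)
    name: Mapped[str]
    unit: Mapped[str]


engine = create_engine("sqlite:///:memory:")
Base.metadata.create_all(engine)

with Session(engine) as session:
    session.add(
        DataType(key="sea_water_temperature", name="Water Temperature", unit="Celsius")
    )
    session.commit()

    stmt = select(DataType).where(DataType.key == "sea_water_temperature")

    rows = session.execute(stmt).all()  # list of Row tuples, as in the new argo.py
    dt = rows[0][0]                     # hence the dt[0][0] unpacking

    dt = session.scalars(stmt).first()  # alternative: mapped instances directly
    print(dt.name, dt.unit)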
77 changes: 43 additions & 34 deletions scripts/data_importers/glider.py
@@ -2,12 +2,18 @@

import glob
import os
import sys

import defopt
import gsw
import xarray as xr
from sqlalchemy import create_engine, select
from sqlalchemy.orm import Session

current = os.path.dirname(os.path.realpath(__file__))
parent = os.path.dirname(os.path.dirname(current))
sys.path.append(parent)

import data.observational
from data.observational import DataType, Platform, Sample, Station

VARIABLES = ["CHLA", "PSAL", "TEMP", "CNDC"]
@@ -19,42 +25,49 @@ def main(uri: str, filename: str):
:param str uri: Database URI
:param str filename: Glider Filename, or directory of NetCDF files
"""
data.observational.init_db(uri, echo=False)
data.observational.create_tables()

if os.path.isdir(filename):
filenames = sorted(glob.glob(os.path.join(filename, "*.nc")))
else:
filenames = [filename]

datatype_map = {}
for fname in filenames:
print(fname)
with xr.open_dataset(fname) as ds:
variables = [v for v in VARIABLES if v in ds.variables]
df = (
ds[["TIME", "LATITUDE", "LONGITUDE", "PRES", *variables]]
.to_dataframe()
.reset_index()
.dropna()
)
engine = create_engine(
uri,
connect_args={"connect_timeout": 10},
pool_recycle=3600,
)

with Session(engine) as session:
if os.path.isdir(filename):
filenames = sorted(glob.glob(os.path.join(filename, "*.nc")))
else:
filenames = [filename]

datatype_map = {}
for fname in filenames:
print(fname)
with xr.open_dataset(fname) as ds:
variables = [v for v in VARIABLES if v in ds.variables]
df = (
ds[["TIME", "LATITUDE", "LONGITUDE", "PRES", *variables]]
.to_dataframe()
.reset_index()
.dropna()
)

df["DEPTH"] = abs(gsw.conversions.z_from_p(df.PRES, df.LATITUDE))

for variable in variables:
if variable not in datatype_map:
dt = DataType.query.get(ds[variable].standard_name)
if dt is None:
statement = select(DataType).where(
DataType.key == ds[variable].standard_name
)
dt = session.execute(statement).all()
if not dt:
dt = DataType(
key=ds[variable].standard_name,
name=ds[variable].long_name,
unit=ds[variable].units,
)
data.observational.db.session.add(dt)
session.add(dt)

datatype_map[variable] = dt

data.observational.db.session.commit()
session.commit()

p = Platform(
type=Platform.Type.glider, unique_id=f"glider_{ds.deployment_label}"
@@ -67,8 +80,8 @@ def main(uri: str, filename: str):
"Contact": ds.contact,
}
p.attrs = attrs
data.observational.db.session.add(p)
data.observational.db.session.commit()
session.add(p)
session.commit()

stations = [
Station(
@@ -84,9 +97,7 @@
# updated with id's. It's slower, but it means that we can just
# put all the station ids into a pandas series to use when
# constructing the samples.
data.observational.db.session.bulk_save_objects(
stations, return_defaults=True
)
session.bulk_save_objects(stations, return_defaults=True)
df["STATION_ID"] = [s.id for s in stations]

samples = [
@@ -101,12 +112,10 @@
]
for idx, row in df.iterrows()
]
data.observational.db.session.bulk_save_objects(
[item for sublist in samples for item in sublist]
)
data.observational.db.session.commit()
session.bulk_save_objects([item for sublist in samples for item in sublist])
session.commit()

data.observational.db.session.commit()
session.commit()


if __name__ == "__main__":
(remainder of glider.py not shown)
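Both scripts keep the bulk_save_objects() fast path described in their comments: stations are saved with return_defaults=True so their generated primary keys are written back onto the objects, and the samples are then bulk-saved referencing station_id directly instead of going through ORM relationships. A sketch of that two-step pattern follows, with Station and Sample as illustrative stand-ins for the project's models.

# Sketch of the two-step bulk insert used for stations and samples.
# Station and Sample here are stand-ins, not the project's actual models.
from sqlalchemy import ForeignKey, create_engine
from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column


class Base(DeclarativeBase):
    pass


class Station(Base):
    __tablename__ = "stations"

    id: Mapped[int] = mapped_column(primary_key=True)
    latitude: Mapped[float]
    longitude: Mapped[float]


class Sample(Base):
    __tablename__ = "samples"

    id: Mapped[int] = mapped_column(primary_key=True)
    station_id: Mapped[int] = mapped_column(ForeignKey("stations.id"))
    depth: Mapped[float]
    value: Mapped[float]


engine = create_engine("sqlite:///:memory:")
Base.metadata.create_all(engine)

with Session(engine) as session:
    stations = [
        Station(latitude=47.5, longitude=-52.7),
        Station(latitude=48.1, longitude=-53.0),
    ]

    # return_defaults=True populates the generated primary keys on the
    # objects, so the samples can reference station.id without relationships.
    session.bulk_save_objects(stations, return_defaults=True)

    samples = [
        Sample(station_id=station.id, depth=depth, value=value)
        for station in stations
        for depth, value in [(5.0, 4.2), (10.0, 3.9)]
    ]
    session.bulk_save_objects(samples)
    session.commit()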
(Diffs for the remaining two changed files, the seal and nafc_ctd importer scripts, are not shown.)
