Skip to content

Commit

Permalink
fix trackplot db
Browse files Browse the repository at this point in the history
  • Loading branch information
JustinElms committed Jul 25, 2024
1 parent 8ad39c4 commit be63945
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 25 deletions.
45 changes: 21 additions & 24 deletions plotting/track.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
from geopy.distance import distance
from mpl_toolkits.axes_grid1 import make_axes_locatable
from scipy.interpolate import interp1d
from sqlalchemy import func
from sqlalchemy.orm import Session

import plotting.colormap as colormap
import plotting.utils as utils
Expand All @@ -18,26 +20,18 @@
DataType,
Platform,
Sample,
SessionLocal,
Station,
)
from data.observational.queries import get_platform_variable_track
from data.utils import datetime_to_timestamp
from plotting.plotter import Plotter


def get_db():
try:
db = SessionLocal()
yield db
finally:
db.close()


class TrackPlotter(Plotter):
def __init__(self, dataset_name: str, query: str, **kwargs):
def __init__(self, dataset_name: str, query: str, db: Session, **kwargs):
self.plottype: str = "track"
super().__init__(dataset_name, query, **kwargs)
self.db = db
self.size: str = "11x5"
self.model_depths = None

Expand Down Expand Up @@ -66,21 +60,21 @@ def parse_query(self, query):
self.track_quantum = query.get("track_quantum", "hour")

def load_data(self):
db = get_db
platform = db.session.query(Platform).get(self.platform)
platform = self.db.query(Platform).get(self.platform)
self.name = platform.unique_id

# First get the variable
st0 = db.session.query(Station).filter(Station.platform == platform).first()
datatype_keys = (
db.session.query(db.func.distinct(Sample.datatype_key))
st0 = self.db.query(Station).filter(Station.platform == platform).first()
datatype_keys = [
k[0]
for k in self.db.query(func.distinct(Sample.datatype_key))
.filter(Sample.station == st0)
.all()
)
]

datatypes = (
db.session.query(DataType)
.filter(DataType.key.in_([d for d, in datatype_keys]))
self.db.query(DataType)
.filter(DataType.key.in_([d for d in datatype_keys]))
.order_by(DataType.key)
.all()
)
Expand All @@ -94,7 +88,7 @@ def load_data(self):
for v in variables:
d.append(
get_platform_variable_track(
db.session,
self.db,
platform,
v.key,
self.track_quantum,
Expand Down Expand Up @@ -188,9 +182,9 @@ def load_data(self):
od = md

# Clear model data beneath observed data
od[
np.where(self.model_depths > max(self.depth))[0][1:], :
] = np.nan
od[np.where(self.model_depths > max(self.depth))[0][1:], :] = (
np.nan
)

d.append(od)

Expand Down Expand Up @@ -414,7 +408,10 @@ def plot(self):
# latlon
if self.latlon:
for j, label in enumerate(
["Latitude (degrees)", "Longitude (degrees)"] #[gettext("Latitude (degrees)"), gettext("Longitude (degrees)")]
[
"Latitude (degrees)",
"Longitude (degrees)",
] # [gettext("Latitude (degrees)"), gettext("Longitude (degrees)")]
):
plt.subplot(gs[subplot])
subplot += subplot_inc
Expand All @@ -425,7 +422,7 @@ def plot(self):
plt.setp(plt.gca().get_xticklabels(), rotation=30)

fig.suptitle(
"Track Plot (Observed %s - %s, Modelled %s - %s)" # gettext("Track Plot (Observed %s - %s, Modelled %s - %s)")
"Track Plot (Observed %s - %s, Modelled %s - %s)" # gettext("Track Plot (Observed %s - %s, Modelled %s - %s)")
% (
self.times[0].strftime("%Y-%m-%d"),
self.times[-1].strftime("%Y-%m-%d"),
Expand Down
2 changes: 1 addition & 1 deletion routes/api_v2_0.py
Original file line number Diff line number Diff line change
Expand Up @@ -646,7 +646,7 @@ def make_response(data, mime):
elif plot_type == "observation":
plotter = ObservationPlotter(dataset, query, db, **options)
elif plot_type == "track":
plotter = TrackPlotter(dataset, query, **options)
plotter = TrackPlotter(dataset, query, db, **options)
elif plot_type == "class4":
plotter = Class4Plotter(dataset, query, **options)
elif plot_type == "stick":
Expand Down

0 comments on commit be63945

Please sign in to comment.