Skip to content

Commit

Permalink
Merge pull request #81 from JPMastrogiacomo/fix_dotprod_extradims
Browse files Browse the repository at this point in the history
Fix dotprod extradims
  • Loading branch information
ks905383 authored Sep 26, 2024
2 parents cc5fd6c + c037845 commit dbbc8b3
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 11 deletions.
10 changes: 3 additions & 7 deletions xagg/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -670,7 +670,7 @@ def aggregate(ds,wm,impl=None,silent=None):
normed_weights = normed_weights.fillna(0)

# finally we do the dot product to get the weighted averages
aggregated_array = normed_weights.dot(var_array_filled)
aggregated_array = normed_weights.dot(var_array_filled, dim='loc')

# if the original gridded values were all nan, make the final
# aggregation nan
Expand All @@ -679,17 +679,13 @@ def aggregate(ds,wm,impl=None,silent=None):

data_dict[var] = aggregated_array

ds_combined = xr.Dataset(data_dict)
df_combined = ds_combined.to_dataframe().reset_index()
df_combined = df_combined.groupby('poly_idx').agg(list_or_first)
ds_combined = xr.Dataset(data_dict)

wm.agg = pd.merge(wm.agg, df_combined, on='poly_idx')
for var in ds:
if ('bnds' not in ds[var].sizes) & ('loc' in ds[var].sizes):
# convert to list of arrays - NOT SURE THIS IS THE RIGHT THING TO
# DO, JUST TRYING TO MATCH ORIGINAL FORMAT
wm.agg[var] = wm.agg[var].apply(np.array).apply(lambda x: [x])

wm.agg[var]=pd.Series([[[ds_combined[var].isel(poly_idx=i).values]] for i in range(len(ds_combined.poly_idx))])
# Put in class format
agg_out = aggregated(agg=wm.agg,source_grid=wm.source_grid,
geometry=wm.geometry,ds_in=ds_combined,weights=wm.weights)
Expand Down
7 changes: 3 additions & 4 deletions xagg/export.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,10 +113,9 @@ def prep_for_nc(agg_obj,loc_dim='poly_idx'):
else:
# For data variables (from the input grid), create empty array
ds_out[var] = xr.DataArray(data=np.zeros((len(agg_obj.agg),
*[agg_obj.ds_in[var].sizes[k] for k in agg_obj.ds_in[var].sizes.keys() if k not in ['lat','lon','loc']]))*np.nan,
dims=['poly_idx',*[k for k in agg_obj.ds_in[var].sizes.keys() if k not in ['lat','lon','loc']]],
coords=[[k for k in agg_obj.agg.poly_idx],*[agg_obj.ds_in[var][k].values for k in agg_obj.ds_in[var].sizes.keys() if k not in ['lat','lon','loc']]])

*[agg_obj.ds_in[var].sizes[k] for k in agg_obj.ds_in[var].sizes.keys() if k not in ['lat','lon','loc','poly_idx']]))*np.nan,
dims=['poly_idx',*[k for k in agg_obj.ds_in[var].sizes.keys() if k not in ['lat','lon','loc','poly_idx']]],
coords=[[k for k in agg_obj.agg.poly_idx],*[agg_obj.ds_in[var][k].values for k in agg_obj.ds_in[var].sizes.keys() if k not in ['lat','lon','loc','poly_idx']]])
# Now insert aggregated values
for poly_idx in agg_obj.agg.poly_idx:
ds_out[var].loc[{'poly_idx':poly_idx}] = np.squeeze(agg_obj.agg.loc[poly_idx,var])
Expand Down

0 comments on commit dbbc8b3

Please sign in to comment.