From 6fc9d695b0f839c3aa5e8ead2b8bd9ef11656fa0 Mon Sep 17 00:00:00 2001
From: Kevin Schwarzwald
Date: Thu, 6 Jun 2024 13:30:57 -0400
Subject: [PATCH 1/3] Add new tests for exports, fix extra 'name' variable in exports

---
 tests/test_export.py | 139 +++++++++++++++++++++++++++++++++++++------
 xagg/classes.py      |  19 ++++--
 xagg/export.py       |  38 +++++++++---
 3 files changed, 166 insertions(+), 30 deletions(-)

diff --git a/tests/test_export.py b/tests/test_export.py
index f9de8dc..a9d4ec2 100644
--- a/tests/test_export.py
+++ b/tests/test_export.py
@@ -3,6 +3,7 @@
 import numpy as np
 import xarray as xr
 import shutil
+import os
 import geopandas as gpd
 from geopandas import testing as gpdt
 from unittest import TestCase
@@ -10,6 +11,7 @@
 
 from xagg.core import (process_weights,create_raster_polygons,get_pixel_overlaps,aggregate,read_wm)
 from xagg.wrappers import (pixel_overlaps)
+from xagg.export import (prep_for_csv)
 
 ##### to_dataset() tests #####
@@ -38,8 +40,7 @@ def test_to_dataset(agg=agg):
     ds_out = agg.to_dataset()
 
     # Build reference output dataset
-    ds_ref = xr.Dataset({'name':(['poly_idx'],np.array(['test']).astype(object)),
-                         'test':(['poly_idx','run'],np.array([[1.0,2.0]]))},
+    ds_ref = xr.Dataset({'test':(['poly_idx','run'],np.array([[1.0,2.0]]))},
                         coords={'poly_idx':(['poly_idx'],np.array([0])),
                                 'run':(['run'],np.array([0,1]))})
 
@@ -47,22 +48,139 @@
     # Assert equal within tolerance, again likely due to very slight
     # variation off from actual 1.0, 2.0 due to crs
     xr.testing.assert_allclose(ds_out,ds_ref,atol=0.0001)
 
+def test_to_dataset_renamelocdim(agg=agg):
+    # Change to dataset, renaming `poly_idx`
+    ds_out = agg.to_dataset(loc_dim = 'county_id')
+
+    # Build reference output dataset
+    ds_ref = xr.Dataset({'test':(['county_id','run'],np.array([[1.0,2.0]]))},
+                        coords={'county_id':(['county_id'],np.array([0])),
+                                'run':(['run'],np.array([0,1]))})
+
+    # Assert equal within tolerance, again likely due to very slight
+    # variation off from actual 1.0, 2.0 due to crs
+    xr.testing.assert_allclose(ds_out,ds_ref,atol=0.0001)
+
+##### to_dataframe() tests #####
 def test_to_dataframe(agg=agg):
     # Change to dataframe
     df_out = agg.to_dataframe()
 
     # Build reference output dataframe
-    df_ref = pd.DataFrame({'poly_idx':[0,0],'run':[0,1],'name':['test','test'],'test':[0.9999,1.9999]})
+    df_ref = pd.DataFrame({'poly_idx':[0,0],'run':[0,1],'test':[0.9999,1.9999]})
     df_ref = df_ref.set_index(['poly_idx','run'])
 
     # Assert equal within tolerance, again likely due to very slight
     # variation off from actual 1.0, 2.0 due to crs
     pd.testing.assert_frame_equal(df_out,df_ref,atol=0.0001)
 
+def test_to_dataframe_renamelocdim(agg=agg):
+    # Change to dataframe, with a new name for the location dimension
+    df_out = agg.to_dataframe(loc_dim='county_id')
+
+    # Build reference output dataframe
+    df_ref = pd.DataFrame({'county_id':[0,0],'run':[0,1],'test':[0.9999,1.9999]})
+    df_ref = df_ref.set_index(['county_id','run'])
+
+    # Assert equal within tolerance
+    pd.testing.assert_frame_equal(df_out,df_ref,atol=0.0001)
+
+##### to_geodataframe() tests #####
+def test_to_geodataframe(agg=agg):
+    # Change to geodataframe
+    df_out = agg.to_geodataframe()
+
+    # Build reference output geodataframe
+    gdf_ref = gpd.GeoDataFrame({'name':['test'],
+                                'test0':[0.9999629411369734],
+                                'test1':[1.9999629411369735]},
+                               geometry = [Polygon([(0,0),(0,1),(1,1),(1,0),(0,0)])],
+                               crs = 'EPSG:4326')
+
+    # Assert equal within tolerance, again likely due to very slight
+    # variation off from actual 1.0, 2.0 due to crs
+    gpdt.assert_geodataframe_equal(df_out,gdf_ref,check_less_precise=True)
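Taken together, the tests above pin down the export surface: to_dataset() and to_dataframe() return long output indexed by `poly_idx` (or a renamed `loc_dim`), while to_geodataframe() returns wide output with one column per step of the non-location dimension. A minimal end-to-end sketch of that surface, assuming the `ds`/`gdf` fixtures defined at the top of this test file and xagg's top-level API:

    import xagg as xa

    wm = xa.pixel_overlaps(ds, gdf)    # overlap weights for each polygon
    agg = xa.aggregate(ds, wm)         # area-weighted aggregation

    agg.to_dataset(loc_dim='county_id')   # xr.Dataset, dims county_id x run
    agg.to_dataframe()                    # long: one row per (poly_idx, run)
    agg.to_geodataframe()                 # wide: test0, test1, + geometry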
+def test_prep_for_csv_multd():
+    # Test to make sure .prep_for_csv() (for .to_geodataframe()
+    # and .to_csv()) fails if you have a variable with more than one
+    # non-location dimension
+
+    # Have a 4-D variable (run and time in addition to geographic data)
+    ds_extrad = xr.Dataset({'test':(['lon','lat','run','time'],np.random.rand(2,2,2,5)),
+                            'lat_bnds':(['lat','bnds'],np.array([[-0.5,0.5],[0.5,1.5]])),
+                            'lon_bnds':(['lon','bnds'],np.array([[-0.5,0.5],[0.5,1.5]]))},
+                           coords={'lat':(['lat'],np.array([0,1])),
+                                   'lon':(['lon'],np.array([0,1])),
+                                   'run':(['run'],np.array([0,1])),
+                                   'time':(['time'],pd.date_range('2001-01-01','2001-01-05')),
+                                   'bnds':(['bnds'],np.array([0,1]))})
+
+    # Create polygon covering multiple pixels
+    gdf = {'name':['test'],
+           'geometry':[Polygon([(0,0),(0,1),(1,1),(1,0),(0,0)])]}
+    gdf = gpd.GeoDataFrame(gdf,crs="EPSG:4326")
+
+    # Get pixel overlaps
+    wm_extrad = pixel_overlaps(ds_extrad,gdf)
+
+    # Get aggregate
+    agg_extrad = aggregate(ds_extrad,wm_extrad)
+
+    # Make sure prep_for_csv() fails on the extra non-location dimension
+    with pytest.raises(NotImplementedError):
+        prep_for_csv(agg_extrad)
+
+##### to_netcdf() tests #####
+def test_to_netcdf(agg=agg):
+    # Export to netcdf
+    agg.to_netcdf('test.nc')
+
+    # Make reference dataset
+    ds_ref = xr.Dataset({'test':(('poly_idx','run'),[[0.9999629411369734,1.9999629411369735]])},
+                        coords = {'poly_idx':(('poly_idx'),[0]),
+                                  'run':(('run'),[0,1])})
+
+    # Load
+    ds_out = xr.open_dataset('test.nc')
+
+    # Test
+    xr.testing.assert_allclose(ds_ref,ds_out)
+
+    # Remove test export file
+    os.remove('test.nc')
 
 ##### pixel_overlaps() export tests #####
+## Create weightmap to export
+# Add a simple weights grid
+weights = xr.DataArray(data=np.array([[0.,1.],[2.,3.]]),
+                       dims=['lat','lon'],
+                       coords=[ds.lat,ds.lon])
 
-def test_pixel_overlaps_export_and_import(ds=ds):
+# Create polygon covering one pixel
+gdf = {'name':['test'],
+       'geometry':[Polygon([(-0.5,-0.5),(-0.5,0.5),(0.5,0.5),(0.5,-0.5),(-0.5,-0.5)])]}
+gdf = gpd.GeoDataFrame(gdf,crs="EPSG:4326")
+
+# Calculate weightmap
+wm_out = pixel_overlaps(ds,gdf,weights=weights)
+
+def test_export_wm_nooverwrite(wm_out=wm_out):
+    # Test to make sure FileExistsError is thrown if the target
+    # directory already exists and overwrite=False
+    fn = 'wm_export_test'
+
+    # Create temporary directory
+    os.mkdir(fn)
+
+    # Try to save with overwrite=False
+    with pytest.raises(FileExistsError):
+        wm_out.to_file(fn,overwrite=False)
+
+    # Clean
+    shutil.rmtree(fn)
+
+def test_export_wm_standard(wm_out=wm_out):
     # Testing the .to_file() --> read_wm() workflow.
     # Rather complex because of the many different components of wm.
     # wm.agg in particular is a dataframe with lists in it (which is
@@ -73,19 +191,6 @@
 
     fn = 'wm_export_test'
 
-    # Add a simple weights grid
-    weights = xr.DataArray(data=np.array([[0.,1.],[2.,3.]]),
-                           dims=['lat','lon'],
-                           coords=[ds.lat,ds.lon])
-
-    # Create polygon covering one pixel
-    gdf = {'name':['test'],
-           'geometry':[Polygon([(-0.5,-0.5),(-0.5,0.5),(0.5,0.5),(0.5,-0.5),(-0.5,-0.5)])]}
-    gdf = gpd.GeoDataFrame(gdf,crs="EPSG:4326")
-
-    # Calculate weightmap
-    wm_out = pixel_overlaps(ds,gdf,weights=weights)
-
     # Export weightmap
     wm_out.to_file(fn,overwrite=True)
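In isolation, the workflow these weightmap tests exercise is: compute the (expensive) overlaps once, write them out, and reload them later. A sketch using the same imports as the tests, with `ds`, `gdf`, and `weights` as in the fixtures above:

    from xagg.core import read_wm
    from xagg.wrappers import pixel_overlaps

    wm = pixel_overlaps(ds, gdf, weights=weights)  # the expensive step
    wm.to_file('wm_dir', overwrite=True)           # writes a directory of components
    wm_reloaded = read_wm('wm_dir')                # reuse without recomputing overlaps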
diff --git a/xagg/classes.py b/xagg/classes.py
index 920a377..dc892dd 100644
--- a/xagg/classes.py
+++ b/xagg/classes.py
@@ -74,18 +74,28 @@ def __init__(self,agg,source_grid,geometry,ds_in,weights='nowghts'):
     # Conversion functions
     def to_dataset(self,loc_dim='poly_idx'):
         """ Convert to xarray dataset.
+
+        Parameters
+        -----------------
+        loc_dim : :py:class:`str`, by default `'poly_idx'`
+            What to name the polygon dimension (e.g., 'county')
         """
         ds_out = prep_for_nc(self,loc_dim=loc_dim)
 
         return ds_out
 
     def to_geodataframe(self):
-        """ Convert to geopandas geodataframe.
+        """ Convert to wide geopandas geodataframe.
         """
         df_out = prep_for_csv(self,add_geom=True)
 
         return df_out
 
     def to_dataframe(self,loc_dim='poly_idx'):
         """ Convert to pandas dataframe.
+
+        Parameters
+        -----------------
+        loc_dim : :py:class:`str`, by default `'poly_idx'`
+            What to name the polygon dimension (e.g., 'county')
         """
         df_out = prep_for_nc(self,loc_dim=loc_dim)
         df_out = df_out.to_dataframe()
@@ -97,7 +107,6 @@ def to_netcdf(self,fn,loc_dim='poly_idx',silent=None):
 
         Parameters
         -----------------
-
         fn : :py:class:`str`
             The target filename
 
@@ -105,7 +114,7 @@ def to_netcdf(self,fn,loc_dim='poly_idx',silent=None):
             What to name the polygon dimension
 
         silent : :py:class:`bool`, by default False
-            If `True`, silences standard out
+            If `True`, silences the status update
 
         """
         if silent is None:
@@ -127,7 +136,7 @@ def to_csv(self,fn,silent=None):
             The target filename
 
         silent : :py:class:`bool`, by default False
-            If `True`, silences standard out
+            If `True`, silences the status update
 
         """
         if silent is None:
@@ -144,7 +153,7 @@ def to_shp(self,fn,silent=None):
             The target filename
 
         silent : :py:class:`bool`, by default False
-            If `True`, silences standard out
+            If `True`, silences the status update
 
         """
         if silent is None:
diff --git a/xagg/export.py b/xagg/export.py
index 6a28bad..87592f2 100644
--- a/xagg/export.py
+++ b/xagg/export.py
@@ -127,7 +127,13 @@ def prep_for_nc(agg_obj,loc_dim='poly_idx'):
                                 data=agg_obj.ds_in[crd].values,
                                 coords=[agg_obj.ds_in[crd].values])
 
-
+    # Remove "name" variable created in get_pixel_overlaps
+    if 'name' in ds_out:
+        if ((ds_out['name'].dims == ('poly_idx',)) and
+            (len(np.unique(ds_out['name'])) == 1) and
+            np.unique(ds_out['name'])[0] in ds_out):
+            ds_out = ds_out.drop_vars(['name'])
+
     # Rename poly_idx if desired
     if loc_dim != 'poly_idx':
         ds_out = ds_out.rename({'poly_idx':loc_dim})
@@ -140,11 +146,13 @@ def prep_for_csv(agg_obj,add_geom=False):
    """ Preps aggregated data for output as a csv
 
    Concretely, aggregated data is placed in a new pandas dataframe
-   and expanded wide - each aggregated variable is placed in new
+   and expanded **wide** - each aggregated variable is placed in new
    columns; one column per coordinate in each dimension that isn't
    the location (polygon). So, for example, a lat x lon x time
    variable "tas", aggregated to location x time, would be reshaped
    wide into columns "tas0", "tas1", "tas2",... for timestep 0, 1, etc.
+
+   For **long** data, use ``agg_obj.to_dataframe().to_csv()`` instead.
 
    Note: Currently no support for variables with more than one
    extra dimension
 
    a pandas dataframe containing all the fields from the original
    location polygons + columns containing the values of the aggregated
    variables at each location. This can then easily be exported as a
-   csv directly (using ``df.to_csv``) or to shapefiles by first turning into
+   csv directly (using ``df.to_csv()``) or to shapefiles by first turning into
    a geodataframe.
""" + # Test to make sure there's only one non-location dimension + var_dims = {var:[d for d in agg_obj.ds_in[var].sizes if d != 'loc'] for var in agg_obj.ds_in} + var_ndims = {var:len(dims) for var,dims in var_dims.items()} + if np.max([n for d,n in var_ndims.items()]) > 1: + raise NotImplementedError('The `agg` object has variables with more than 1 non-location dimension; '+ + 'agg.to_csv() and agg.to_geodataframe() return wide arrays, but the code can not yet create wide arrays spanning data with multiple `wide` dimensions. '+ + 'Try agg.to_dataframe() instead. (the offending variables with their non-location dimensions are '+ + str({var:dims for var,dims in var_dims.items() if var_ndims[var]>1})+')') + # For output into csv, work with existing geopandas data frame csv_out = agg_obj.agg.drop(columns=['rel_area','pix_idxs','coords','poly_idx']) @@ -191,11 +208,15 @@ def prep_for_csv(agg_obj,add_geom=False): if len(dimsteps) == 0: # (in this case, all it does is move from one list per row to # one value per row) - expanded_var = (pd.DataFrame(pd.DataFrame(csv_out[var].to_list())[0].to_list(), - columns=[var])) + #expanded_var = (pd.DataFrame(pd.DataFrame(csv_out[var].to_list())[0].to_list(), + # columns=[var])) + expanded_var = pd.DataFrame(csv_out[var].apply(np.squeeze).to_list(), + columns=[var]) else: - expanded_var = (pd.DataFrame(pd.DataFrame(csv_out[var].to_list())[0].to_list(), - columns=[var+str(idx) for idx in np.arange(0,len(csv_out[var][0][0]))])) + #expanded_var = (pd.DataFrame(pd.DataFrame(csv_out[var].to_list())[0].to_list(), + # columns=[var+str(idx) for idx in np.arange(0,len(csv_out[var][0][0]))])) + expanded_var = pd.DataFrame(csv_out[var].apply(np.squeeze).to_list(), + columns = [var+str(idx) for idx in range(len(csv_out[var].apply(np.squeeze))+1)]) # Append to existing series csv_out = pd.concat([csv_out.drop(columns=(var)), expanded_var], @@ -204,7 +225,8 @@ def prep_for_csv(agg_obj,add_geom=False): if add_geom: # Return the geometry from the original geopandas.GeoDataFrame - csv_out.geometry = agg_obj.geometry + csv_out['geometry'] = agg_obj.geometry + csv_out = csv_out.set_geometry('geometry') # Return return csv_out From 41e77f07867916de164f8df141068c1f1a061b94 Mon Sep 17 00:00:00 2001 From: Kevin Schwarzwald Date: Thu, 6 Jun 2024 14:41:19 -0400 Subject: [PATCH 2/3] new export tests for file exports --- tests/test_core.py | 23 +++++ tests/test_export.py | 198 +++++++++++++++++++++++++++++------------- tests/test_options.py | 38 ++++++++ xagg/export.py | 6 +- 4 files changed, 202 insertions(+), 63 deletions(-) create mode 100644 tests/test_options.py diff --git a/tests/test_core.py b/tests/test_core.py index 793be90..64bfd3a 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -408,6 +408,29 @@ def test_aggregate_basic(ds=ds): # Possibly worth examining more closely later assert np.allclose([v for v in agg.agg.test.values],1.4999,rtol=1e-4) +def test_aggregate_wdataarray(ds=ds): + da = ds['test'].copy() + # Create polygon covering multiple pixels + gdf = {'name':['test'], + 'geometry':[Polygon([(0,0),(0,1),(1,1),(1,0),(0,0)])]} + gdf = gpd.GeoDataFrame(gdf,crs="EPSG:4326") + + # calculate the pix_agg variable tested above, to be used in the + # tests below + pix_agg = create_raster_polygons(ds) + + # Get pixel overlaps + wm = get_pixel_overlaps(gdf,pix_agg) + + # Get aggregate + agg = aggregate(da,wm) + + # This requires shifting rtol to 1e-4 for some reason, in that + # it's actually 1.499981, whereas multiplying out + # 
From 41e77f07867916de164f8df141068c1f1a061b94 Mon Sep 17 00:00:00 2001
From: Kevin Schwarzwald
Date: Thu, 6 Jun 2024 14:41:19 -0400
Subject: [PATCH 2/3] new export tests for file exports

---
 tests/test_core.py    |  23 +++++
 tests/test_export.py  | 198 +++++++++++++++++++++++++++++-------------
 tests/test_options.py |  38 ++++++++
 xagg/export.py        |   6 +-
 4 files changed, 202 insertions(+), 63 deletions(-)
 create mode 100644 tests/test_options.py

diff --git a/tests/test_core.py b/tests/test_core.py
index 793be90..64bfd3a 100644
--- a/tests/test_core.py
+++ b/tests/test_core.py
@@ -408,6 +408,29 @@ def test_aggregate_basic(ds=ds):
     # Possibly worth examining more closely later
     assert np.allclose([v for v in agg.agg.test.values],1.4999,rtol=1e-4)
 
+def test_aggregate_wdataarray(ds=ds):
+    da = ds['test'].copy()
+
+    # Create polygon covering multiple pixels
+    gdf = {'name':['test'],
+           'geometry':[Polygon([(0,0),(0,1),(1,1),(1,0),(0,0)])]}
+    gdf = gpd.GeoDataFrame(gdf,crs="EPSG:4326")
+
+    # Calculate the pix_agg variable tested above, to be used in the
+    # tests below
+    pix_agg = create_raster_polygons(ds)
+
+    # Get pixel overlaps
+    wm = get_pixel_overlaps(gdf,pix_agg)
+
+    # Get aggregate
+    agg = aggregate(da,wm)
+
+    # This requires shifting rtol to 1e-4 for some reason, in that
+    # it's actually 1.499981, whereas multiplying out
+    # np.sum(agg.agg.rel_area[0]*np.array([0,1,2,3])) gives 1.499963...
+    # Possibly worth examining more closely later
+    assert np.allclose([v for v in agg.agg.test.values],1.4999,rtol=1e-4)
+
 def test_aggregate_basic_wdotproduct(ds=ds):
     # Create multiple polygons, to double check, since dot product
     # implementation takes a slightly different approach to indices
diff --git a/tests/test_export.py b/tests/test_export.py
index a9d4ec2..ee290ac 100644
--- a/tests/test_export.py
+++ b/tests/test_export.py
@@ -7,11 +7,14 @@
 import geopandas as gpd
 from geopandas import testing as gpdt
 from unittest import TestCase
+from unittest.mock import patch
+from io import StringIO
 from shapely.geometry import Polygon
 
 from xagg.core import (process_weights,create_raster_polygons,get_pixel_overlaps,aggregate,read_wm)
-from xagg.wrappers import (pixel_overlaps)
-from xagg.export import (prep_for_csv)
+from xagg.wrappers import pixel_overlaps
+from xagg.export import prep_for_csv
+from xagg.options import (set_options,get_options)
 
 ##### to_dataset() tests #####
@@ -85,6 +88,25 @@ def test_to_dataframe_renamelocdim(agg=agg):
     # Assert equal within tolerance
     pd.testing.assert_frame_equal(df_out,df_ref,atol=0.0001)
 
+##### to_csv() tests #####
+def test_to_csv(agg=agg):
+    # Export to csv
+    agg.to_csv('test.csv')
+
+    try:
+        # Build reference output dataframe
+        df_ref = pd.DataFrame({'poly_idx':[0,0],'run':[0,1],'test':[0.9999,1.9999]})
+
+        # Load
+        df_in = pd.read_csv('test.csv')
+
+        # Assert equal within tolerance
+        pd.testing.assert_frame_equal(df_in,df_ref,atol=0.0001)
+
+    finally:
+        # Clean
+        os.remove('test.csv')
+
 ##### to_geodataframe() tests #####
 def test_to_geodataframe(agg=agg):
     # Change to geodataframe
     df_out = agg.to_geodataframe()
@@ -131,24 +153,75 @@ def test_prep_for_csv_multd():
     with pytest.raises(NotImplementedError):
         prep_for_csv(agg_extrad)
 
+##### to_shp() tests #####
+def test_to_shp(agg=agg):
+    # Export to shapefile
+    agg.to_shp('test.shp')
+
+    try:
+        # Make reference geodataframe
+        gdf_ref = gpd.GeoDataFrame({'name':['test'],
+                                    'test0':[0.9999629411369734],
+                                    'test1':[1.9999629411369735]},
+                                   geometry = [Polygon([(0,0),(0,1),(1,1),(1,0),(0,0)])],
+                                   crs = 'EPSG:4326')
+
+        # Read file
+        gdf_in = gpd.read_file('test.shp')
+
+        # Assert equal within tolerance
+        gpdt.assert_geodataframe_equal(gdf_in,gdf_ref,check_less_precise=True)
+    finally:
+        # Clean
+        for suff in ['cpg','dbf','prj','shp','shx']:
+            os.remove('test.'+suff)
+
 
 ##### to_netcdf() tests #####
 def test_to_netcdf(agg=agg):
     # Export to netcdf
     agg.to_netcdf('test.nc')
 
-    # Make reference dataset
-    ds_ref = xr.Dataset({'test':(('poly_idx','run'),[[0.9999629411369734,1.9999629411369735]])},
-                        coords = {'poly_idx':(('poly_idx'),[0]),
-                                  'run':(('run'),[0,1])})
-
-    # Load
-    ds_out = xr.open_dataset('test.nc')
-
-    # Test
-    xr.testing.assert_allclose(ds_ref,ds_out)
-
-    # Remove test export file
-    os.remove('test.nc')
+    try:
+        # Make reference dataset
+        ds_ref = xr.Dataset({'test':(('poly_idx','run'),[[0.9999629411369734,1.9999629411369735]])},
+                            coords = {'poly_idx':(('poly_idx'),[0]),
+                                      'run':(('run'),[0,1])})
+
+        # Load
+        ds_out = xr.open_dataset('test.nc')
+
+        # Test
+        xr.testing.assert_allclose(ds_ref,ds_out)
+    finally:
+        # Remove test export file
+        os.remove('test.nc')
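The silent-export test that follows captures everything written to standard out; the underlying pattern, in isolation:

    from io import StringIO
    from unittest.mock import patch

    with patch('sys.stdout', new_callable=StringIO) as mock_stdout:
        print('status update!')   # stand-in for a chatty save function

    assert mock_stdout.getvalue() == 'status update!\n'   # captured, not shown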
+
+##### silent export tests #####
+@patch('sys.stdout', new_callable=StringIO)
+def test_silent_filesaves(mock_stdout):
+    # Define fns before the try block, so the finally clause can't
+    # hit a NameError if an export fails
+    fns = {'nc':'test.nc',
+           'csv':'test.csv',
+           'shp':'test'}
+
+    try:
+        with set_options(silent=True):
+            agg.to_netcdf(fns['nc'])
+            agg.to_csv(fns['csv'])
+            agg.to_shp(fns['shp'])
+
+        # Check that nothing was printed
+        assert mock_stdout.getvalue() == ''
+
+    finally:
+        for k in fns:
+            # Clean
+            if (k == 'shp') and os.path.exists(fns[k]+'.shp'):
+                # Clean shapefile aux files
+                for suff in ['cpg','dbf','prj','shp','shx']:
+                    os.remove(fns[k]+'.'+suff)
+            elif os.path.exists(fns[k]):
+                os.remove(fns[k])
 
 ##### pixel_overlaps() export tests #####
 ## Create weightmap to export
@@ -170,15 +243,17 @@ def test_export_wm_nooverwrite(wm_out=wm_out):
     # directory already exists and overwrite=False
     fn = 'wm_export_test'
 
-    # Create temporary directory
-    os.mkdir(fn)
-
-    # Try to save with overwrite=False
-    with pytest.raises(FileExistsError):
-        wm_out.to_file(fn,overwrite=False)
+    try:
+        # Create temporary directory
+        os.mkdir(fn)
 
-    # Clean
-    shutil.rmtree(fn)
+        # Try to save with overwrite=False
+        with pytest.raises(FileExistsError):
+            wm_out.to_file(fn,overwrite=False)
+    finally:
+        if os.path.exists(fn):
+            # Clean
+            shutil.rmtree(fn)
 
 def test_export_wm_standard(wm_out=wm_out):
     # Testing the .to_file() --> read_wm() workflow.
     # Rather complex because of the many different components of wm.
     # wm.agg in particular is a dataframe with lists in it (which is
@@ -191,41 +266,44 @@ def test_export_wm_standard(wm_out=wm_out):
 
     fn = 'wm_export_test'
 
-    # Export weightmap
-    wm_out.to_file(fn,overwrite=True)
-
-    # Load weightmap
-    wm_in = read_wm(fn)
-
+    try:
+        # Export weightmap
+        wm_out.to_file(fn,overwrite=True)
+
+        # Load weightmap
+        wm_in = read_wm(fn)
+
+        ###### agg
+        # This is just checking the parts of the thing that aren't lists. Unfortunately putting lists into
+        # dataframes is a bad idea, but that's the best method I've come up with so far.
+        pd.testing.assert_frame_equal(wm_out.agg[[v for v in wm_out.agg if v not in ['rel_area','pix_idxs','coords']]],
+                                      wm_in.agg[[v for v in wm_in.agg if v not in ['rel_area','pix_idxs','coords']]])
+        # Now for the columns with lists
+        for v in ['rel_area','pix_idxs','coords']:
+            for it in np.arange(0,wm_out.agg[v].shape[0]):
+                np.testing.assert_array_equal(wm_out.agg[v].values[it],wm_in.agg[v].values[it])
+        # Now for the column names, to make sure all the columns are retained
+        pd.testing.assert_index_equal(wm_out.agg.columns,wm_in.agg.columns)
+
+        ###### geometry
+        gpdt.assert_geoseries_equal(wm_out.geometry,wm_in.geometry)
+
+        ###### source grids
+        for k in wm_out.source_grid:
+            xr.testing.assert_allclose(wm_out.source_grid[k],wm_in.source_grid[k])
+
+        ###### weights
+        if (type(wm_out.weights) is str) and (wm_out.weights=='nowghts'):
+            np.testing.assert_string_equal(wm_in.weights,wm_out.weights)
+        else:
+            # `read_wm()` reads in weights as objects (see notes in the relevant
+            # section of `read_wm()`); this shouldn't have too big of a
+            # consequence, but it does make this test more complicated
+            pd.testing.assert_series_equal(wm_in.weights,wm_out.weights.astype(object))
+    finally:
+        ##### clean
+        if os.path.exists(fn):
+            shutil.rmtree(fn)
diff --git a/tests/test_options.py b/tests/test_options.py
new file mode 100644
index 0000000..6d9fc44
--- /dev/null
+++ b/tests/test_options.py
@@ -0,0 +1,38 @@
+import pytest
+
+from xagg.options import set_options,get_options
+
+
+##### set_options() tests #####
+def test_set_options_withtemp():
+    # Test to make sure using set_options in a
+    # with block doesn't change the original
+    # options
+    # (Hardcoded... would have to change if the
+    # settings defaults are changed...)
+    with set_options(silent=True,impl='dot_product'):
+        pass
+    assert not get_options()['silent']
+    assert get_options()['impl'] == 'for_loop'
+
+def test_set_options():
+    # Test that changing options works
+    with set_options(silent=True):
+        assert get_options()['silent']
+    with set_options(impl='dot_product'):
+        assert get_options()['impl'] == 'dot_product'
+
+def test_set_options_badoption():
+    # Test error for an unsupported option name
+    with pytest.raises(ValueError):
+        set_options(fake_option=True)
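xagg.options itself isn't included in this patch; the xarray-style pattern these tests assume looks roughly like the sketch below (value validation, which the last test in this file exercises, is elided here):

    OPTIONS = {'silent': False, 'impl': 'for_loop'}   # module-level defaults

    def get_options():
        return dict(OPTIONS)   # return a copy, so callers can't mutate state

    class set_options:
        # Works globally (set_options(...)) or temporarily (with set_options(...):)
        def __init__(self, **kwargs):
            self.old = {}
            for k in kwargs:
                if k not in OPTIONS:
                    raise ValueError('argument name %r is not a valid option' % k)
                self.old[k] = OPTIONS[k]
            OPTIONS.update(kwargs)

        def __enter__(self):
            return self

        def __exit__(self, *exc):
            OPTIONS.update(self.old)   # restore previous values on block exit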
+def test_set_options_badoptioninput():
+    # Test error for an unsupported option input
+    # Testing with silent = not a bool
+    # (seems too much to test for every option...
+    # but maybe required?)
+    with pytest.raises(ValueError):
+        set_options(silent='a')
+    with pytest.raises(ValueError):
+        set_options(impl='fake_option')
\ No newline at end of file
diff --git a/xagg/export.py b/xagg/export.py
index 87592f2..5fd7ff4 100644
--- a/xagg/export.py
+++ b/xagg/export.py
@@ -270,7 +270,7 @@ def output_data(agg_obj,output_format,output_fn,loc_dim='poly_idx',silent=False):
         if not output_fn.endswith('.nc'):
             output_fn = output_fn+'.nc'
         ds_out.to_netcdf(output_fn)
-        if silent:
+        if not silent:
             print(output_fn+' saved!')
 
         # Return
@@ -285,7 +285,7 @@ def output_data(agg_obj,output_format,output_fn,loc_dim='poly_idx',silent=False):
         if not output_fn.endswith('.csv'):
             output_fn = output_fn+'.csv'
         csv_out.to_csv(output_fn)
-        if silent:
+        if not silent:
             print(output_fn+' saved!')
 
         # Return
@@ -306,7 +306,7 @@ def output_data(agg_obj,output_format,output_fn,loc_dim='poly_idx',silent=False):
         if not output_fn.endswith('.shp'):
             output_fn = output_fn+'.shp'
         shp_out.to_file(output_fn)
-        if silent:
+        if not silent:
             print(output_fn+' saved!')
 
         # Return

From 6a7ed2deff327d71cb291967040565efd7f0e4cf Mon Sep 17 00:00:00 2001
From: Kevin Schwarzwald
Date: Thu, 6 Jun 2024 15:01:19 -0400
Subject: [PATCH 3/3] extra auxfunc tests

---
 tests/test_auxfuncs.py | 63 +++++++++++++++++++++++++++++++++++++++++-
 xagg/auxfuncs.py       |  7 +++--
 2 files changed, 66 insertions(+), 4 deletions(-)

diff --git a/tests/test_auxfuncs.py b/tests/test_auxfuncs.py
index 5b70845..8ada5dd 100644
--- a/tests/test_auxfuncs.py
+++ b/tests/test_auxfuncs.py
@@ -107,6 +107,27 @@ def test_get_bnds_null():
 
     xr.testing.assert_allclose(ds,get_bnds(ds))
 
+def test_get_bnds_badlons():
+    # Test to make sure > 180 lons throw an error
+    ds = xr.Dataset({'lat_bnds':(['lat','bnds'],np.array([[-0.5,0.5],[0.5,1.5],[1.5,2.5]])),
+                     'lon_bnds':(['lon','bnds'],np.array([[179.5,180.5],[180.5,181.5],[181.5,182.5]]))},
+                    coords={'lat':(['lat'],np.array([0,1,2])),
+                            'lon':(['lon'],np.array([180,181,182]))})
+
+    with pytest.raises(ValueError):
+        get_bnds(ds)
+
+def test_get_bnds_missingdims():
+    # Test to make sure missing lat/lon throws an error
+    ds = xr.Dataset({'latitude_bnds':(['latitude','bnds'],np.array([[-0.5,0.5],[0.5,1.5],[1.5,2.5]])),
+                     'longitude_bnds':(['longitude','bnds'],np.array([[-0.5,0.5],[0.5,1.5],[1.5,2.5]]))},
+                    coords={'latitude':(['latitude'],np.array([0,1,2])),
+                            'longitude':(['longitude'],np.array([0,1,2]))})
+
+    with pytest.raises(KeyError):
+        get_bnds(ds)
+
 def test_get_bnds_basic():
     # Basic attempt to get the bounds (far away from the wraparound)
     ds = xr.Dataset(coords={'lat':(['lat'],np.array([0,1,2])),
@@ -169,6 +190,22 @@ def test_get_bnds_partialgrid():
 
     xr.testing.assert_allclose(ds,ds_compare)
 
+def test_get_bnds_partialgrid_customwat():
+    # Partial grid that should _not_ be wrapped around, with a
+    # manual wrap-around threshold
+    ds = xr.Dataset(coords={'lat':(['lat'],np.arange(-89.5,89.51)),
+                            'lon':(['lon'],np.arange(-179.5,177.51))})
+    ds = get_bnds(ds,wrap_around_thresh=1)
+
+    lat_bnds = np.array(list(zip(np.arange(-90,89.91),np.arange(-89,90.1))))
+    lon_bnds = np.array(list(zip(np.arange(-180,177.01),np.arange(-179,178.01))))
+    ds_compare = xr.Dataset({'lat_bnds':(['lat','bnds'],lat_bnds),
+                             'lon_bnds':(['lon','bnds'],lon_bnds)},
+                            coords={'lat':(['lat'],np.arange(-89.5,89.51)),
+                                    'lon':(['lon'],np.arange(-179.5,177.51))})
+
+    xr.testing.assert_allclose(ds,ds_compare)
+
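For orientation, here is the midpoint construction get_bnds() is being tested on, stripped of any wraparound handling (a standalone numpy sketch):

    import numpy as np

    lat = np.array([0., 1., 2.])          # cell centers
    mids = (lat[:-1] + lat[1:]) / 2       # interior edges
    edges = np.concatenate([[2*lat[0] - mids[0]], mids, [2*lat[-1] - mids[-1]]])
    lat_bnds = np.stack([edges[:-1], edges[1:]], axis=1)
    # -> [[-0.5, 0.5], [0.5, 1.5], [1.5, 2.5]]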
 def test_get_bnds_offsetgrid():
     # Full planetary grid, but with pixels crossing the antimeridian
     ds = xr.Dataset(coords={'lat':(['lat'],np.arange(-89.4,89.7)), # -180:180, offset
@@ -229,8 +266,15 @@ def test_get_bnds_1pixelinWH():
 
     xr.testing.assert_allclose(ds,ds_compare)
 
+def test_get_bnds_badwraparoundthresh():
+    # Even with bounds already present (in which case get_bnds() would
+    # normally do nothing), an invalid `wrap_around_thresh` should
+    # throw an error
+    ds = xr.Dataset({'lat_bnds':(['lat','bnds'],np.array([[-0.5,0.5],[0.5,1.5],[1.5,2.5]])),
+                     'lon_bnds':(['lon','bnds'],np.array([[-0.5,0.5],[0.5,1.5],[1.5,2.5]]))},
+                    coords={'lat':(['lat'],np.array([0,1,2])),
+                            'lon':(['lon'],np.array([0,1,2]))})
 
-
+    with pytest.raises(ValueError):
+        get_bnds(ds,wrap_around_thresh='bad_option')
 
 ###### subset_find tests #####
 def test_subset_find_basic():
@@ -248,6 +292,23 @@ def test_subset_find_basic():
 
     xr.testing.assert_allclose(subset_find(ds0,ds1),ds_compare)
 
+def test_subset_find_nomatch():
+    # Test that it breaks if it can't find one grid within the other
+    # Create two grids, with one offset by a half degree lat from the other
+    ds0 = xr.Dataset({'test':(['lat','lon'],np.array([[0,1,2],[3,4,5],[6,7,8]]))},
+                     coords={'lat':(['lat'],np.array([0.5,1.5,2.5])),
+                             'lon':(['lon'],np.array([-1,0,1]))})
+
+    ds1 = xr.Dataset({'lat':(['lat'],np.array([0,1])),
+                      'lon':(['lon'],np.array([-1,0]))})
+    ds1 = ds1.stack(loc=('lat','lon'))
+
+    with pytest.raises(ValueError):
+        subset_find(ds0,ds1)
+
diff --git a/xagg/auxfuncs.py b/xagg/auxfuncs.py
index 72e2187..d52a504 100644
--- a/xagg/auxfuncs.py
+++ b/xagg/auxfuncs.py
@@ -243,6 +243,10 @@ def get_bnds(ds,wrap_around_thresh='dynamic',
     if (type(wrap_around_thresh) == str) and (wrap_around_thresh != 'dynamic'):
         raise ValueError('`wrap_around_thresh` must either be numeric or the string "dynamic"; instead, it is '+str(wrap_around_thresh)+'.')
 
+    if ('lat' not in ds) or ('lon' not in ds):
+        raise KeyError('"lat"/"lon" not found in [ds]. Make sure the '+
+                       'geographic dimensions follow this naming convention (e.g., run `xa.fix_ds(ds)` before inputting).')
+
     if ds.lon.max()>180:
         raise ValueError('Longitude seems to be in the 0:360 format.'+
                          ' -180:180 format required (e.g., run `xa.fix_ds(ds)` before inputting).')
         # Future versions should be able to work with 0:360 long grids
         # honestly, it *may* already work by just changing edges['lon']
         # to [0,360], but it's not tested yet.
 
-    if ('lat' not in ds) or ('lon' not in ds):
-        raise KeyError('"lat"/"lon" not found in [ds]. Make sure the '+
-                       'geographic dimensions follow this naming convention (e.g., run `xa.fix_ds(ds)` before inputting.')
 
     if ('lat_bnds' in ds) and ('lon_bnds' in ds):
         # `xa.fix_ds()` should rename bounds to `lat/lon_bnds`
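The longitude check above assumes -180:180 input; converting a 0:360 grid is a one-liner. A sketch of the convention fix fix_ds() is expected to apply (the actual implementation in xagg.auxfuncs may differ):

    import numpy as np
    import xarray as xr

    # Hypothetical grid in 0:360 longitude format
    ds = xr.Dataset(coords={'lon': (['lon'], np.array([0., 90., 180., 270.]))})

    # Convert to the -180:180 convention get_bnds() expects, then re-sort
    ds = ds.assign_coords(lon=(((ds.lon + 180) % 360) - 180)).sortby('lon')
    # lon -> [-180., -90., 0., 90.]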