diff --git a/ads/dataset/sampled_dataset.py b/ads/dataset/sampled_dataset.py index 665957ebd..8d53715c0 100644 --- a/ads/dataset/sampled_dataset.py +++ b/ads/dataset/sampled_dataset.py @@ -47,13 +47,13 @@ OptionalDependency, ) +NATURAL_EARTH_DATASET = "naturalearth_lowres" class PandasDataset(object): """ This class provides APIs that can work on a sampled dataset. """ - @runtime_dependency(module="geopandas", install_from=OptionalDependency.GEO) def __init__( self, sampled_df, @@ -67,9 +67,7 @@ def __init__( self.correlation = None self.feature_dist_html_dict = {} self.feature_types = metadata if metadata is not None else {} - self.world = geopandas.read_file( - geopandas.datasets.get_path("naturalearth_lowres") - ) + self.world = None self.numeric_columns = self.sampled_df.select_dtypes( utils.numeric_pandas_dtypes() @@ -562,7 +560,7 @@ def plot_gis_scatter(self, lon="longitude", lat="latitude", ax=None): ), ) world = geopandas.read_file( - geopandas.datasets.get_path("naturalearth_lowres") + geopandas.datasets.get_path(NATURAL_EARTH_DATASET) ) ax1 = world.plot(ax=ax, color="lightgrey", linewidth=0.5, edgecolor="white") gdf.plot(ax=ax1, color="blue", markersize=10) @@ -706,6 +704,12 @@ def _visualize_feature_distribution(self, html_widget): gdf = geopandas.GeoDataFrame( df, geometry=geopandas.points_from_xy(df["lon"], df["lat"]) ) + + if not self.world: + self.world = geopandas.read_file( + geopandas.datasets.get_path(NATURAL_EARTH_DATASET) + ) + self.world.plot( ax=ax, color="lightgrey", linewidth=0.5, edgecolor="white" )