Skip to content

Commit

Permalink
ensure backward compat w/ older stats files that need channel_order
Browse files Browse the repository at this point in the history
  • Loading branch information
AdeelH committed Aug 5, 2024
1 parent c58a7fe commit 906ff30
Show file tree
Hide file tree
Showing 14 changed files with 114 additions and 36 deletions.
2 changes: 1 addition & 1 deletion rastervision_core/rastervision/core/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@


def register_plugin(registry):
registry.set_plugin_version('rastervision.core', 13)
registry.set_plugin_version('rastervision.core', 14)
from rastervision.core.cli import predict, predict_scene
registry.add_plugin_command(predict)
registry.add_plugin_command(predict_scene)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,10 @@ def validate_temporal(self) -> Self:
def build(self, tmp_dir: str | None = None,
use_transformers: bool = True) -> MultiRasterSource:
if use_transformers:
raster_transformers = [t.build() for t in self.transformers]
raster_transformers = [
t.build(channel_order=self.channel_order)
for t in self.transformers
]
else:
raster_transformers = []

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,15 @@ class RasterioSourceConfig(RasterSourceConfig):
False,
description='Stream assets as needed rather than downloading them.')

def build(self, tmp_dir, use_transformers=True):
raster_transformers = ([rt.build() for rt in self.transformers]
if use_transformers else [])
def build(self, tmp_dir: str | None,
use_transformers: bool = True) -> RasterioSource:
if use_transformers:
raster_transformers = [
t.build(channel_order=self.channel_order)
for t in self.transformers
]
else:
raster_transformers = []
bbox = Box(*self.bbox) if self.bbox is not None else None
return RasterioSource(
uris=self.uris,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,13 @@ class XarraySourceConfig(RasterSourceConfig):
def build(self, tmp_dir: str | None = None,
use_transformers: bool = True) -> XarraySource:
item_or_item_collection = self.stac.build()
raster_transformers = ([rt.build() for rt in self.transformers]
if use_transformers else [])
if use_transformers:
raster_transformers = [
t.build(channel_order=self.channel_order)
for t in self.transformers
]
else:
raster_transformers = []
raster_source = XarraySource.from_stac(
item_or_item_collection,
raster_transformers=raster_transformers,
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from rastervision.pipeline.config import register_config, Field
from rastervision.core.data.raster_transformer.raster_transformer_config import ( # noqa
RasterTransformerConfig)
from rastervision.core.data.raster_transformer.cast_transformer import ( # noqa
from rastervision.core.data.raster_transformer.cast_transformer import (
CastTransformer)


Expand All @@ -14,5 +14,5 @@ class CastTransformerConfig(RasterTransformerConfig):
description='dtype to cast raster to. Must be a valid Numpy dtype '
'e.g. "uint8", "float32", etc.')

def build(self):
def build(self, channel_order: list[int] | None = None):
return CastTransformer(to_dtype=self.to_dtype)
Original file line number Diff line number Diff line change
Expand Up @@ -7,5 +7,5 @@
class MinMaxTransformerConfig(RasterTransformerConfig):
"""Configure a :class:`.MinMaxTransformer`."""

def build(self):
def build(self, channel_order: list[int] | None = None):
return MinMaxTransformer()
Original file line number Diff line number Diff line change
Expand Up @@ -12,5 +12,5 @@ class NanTransformerConfig(RasterTransformerConfig):
to_value: float | None = Field(
0.0, description=('Turn all NaN values into this value.'))

def build(self):
def build(self, channel_order: list[int] | None = None):
return NanTransformer(to_value=self.to_value)
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from rastervision.pipeline.config import Config, register_config

if TYPE_CHECKING:
from rastervision.core.data import SceneConfig
from rastervision.core.data import RasterTransformer, SceneConfig
from rastervision.core.rv_pipeline import RVPipelineConfig


Expand All @@ -17,3 +17,7 @@ def update(self,

def update_root(self, root_dir: str):
pass

def build(self,
channel_order: list[int] | None = None) -> 'RasterTransformer':
raise NotImplementedError()
Original file line number Diff line number Diff line change
Expand Up @@ -10,5 +10,6 @@ class ReclassTransformerConfig(RasterTransformerConfig):
mapping: dict[int, int] = Field(
..., description=('The reclassification mapping.'))

def build(self) -> ReclassTransformer:
def build(self,
channel_order: list[int] | None = None) -> ReclassTransformer:
return ReclassTransformer(mapping=self.mapping)
Original file line number Diff line number Diff line change
Expand Up @@ -13,5 +13,6 @@ class RGBClassTransformerConfig(RasterTransformerConfig):
description=('The class config defining the mapping between '
'classes and colors.'))

def build(self) -> RGBClassTransformer:
def build(self,
channel_order: list[int] | None = None) -> RGBClassTransformer:
return RGBClassTransformer(class_config=self.class_config)
Original file line number Diff line number Diff line change
Expand Up @@ -111,39 +111,53 @@ def from_raster_sources(cls,
return stats_transformer

@classmethod
def from_stats_json(cls, uri: str, **kwargs) -> Self:
def from_stats_json(cls,
uri: str,
channel_order: list[int] | None = None,
**kwargs) -> Self:
"""Build with stats from a JSON file.
The file is expected to be in the same format as written by
:meth:`.RasterStats.save`.
Args:
uri (str): URI of the JSON file.
uri: URI of the JSON file.
channel_order: Channel order to apply to the means and stds in the
file.
**kwargs: Extra args for :meth:`.__init__`.
Returns:
StatsTransformer: A StatsTransformer.
A StatsTransformer.
"""
stats = RasterStats.load(uri)
stats_transformer = StatsTransformer.from_raster_stats(stats, **kwargs)
stats_transformer = StatsTransformer.from_raster_stats(
stats, channel_order=channel_order, **kwargs)
return stats_transformer

@classmethod
def from_raster_stats(cls, stats: RasterStats, **kwargs) -> Self:
def from_raster_stats(cls,
stats: RasterStats,
channel_order: list[int] | None = None,
**kwargs) -> Self:
"""Build with stats from a :class:`.RasterStats` instance.
The file is expected to be in the same format as written by
:meth:`.RasterStats.save`.
Args:
stats (RasterStats): A :class:`.RasterStats` instance with
non-None stats.
stats: A :class:`.RasterStats` instance with non-None stats.
channel_order: Channel order to apply to the means and stds in the
:class:`.RasterStats`.
**kwargs: Extra args for :meth:`.__init__`.
Returns:
StatsTransformer: A StatsTransformer.
A StatsTransformer.
"""
stats_transformer = StatsTransformer(stats.means, stats.stds, **kwargs)
means, stds = stats.means, stats.stds
if channel_order is not None:
means = means[channel_order]
stds = stds[channel_order]
stats_transformer = StatsTransformer(means, stds, **kwargs)
return stats_transformer

@property
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
from rastervision.pipeline.config import register_config, Field
from rastervision.core.data.raster_transformer import (RasterTransformerConfig,
StatsTransformer)
from rastervision.core.raster_stats import RasterStats

if TYPE_CHECKING:
from rastervision.core.rv_pipeline import RVPipelineConfig
Expand All @@ -18,6 +17,8 @@ def stats_transformer_config_upgrader(cfg_dict: dict, version: int) -> dict:
# `update_root()`, which is called by the predictor, knows to set
# `stats_uri` to the old location of `stats.json`.
cfg_dict['scene_group'] = '__N/A__'
elif version == 13:
cfg_dict['needs_channel_order'] = True
return cfg_dict


Expand All @@ -35,6 +36,14 @@ class StatsTransformerConfig(RasterTransformerConfig):
'train_scenes',
description='Name of the group of scenes whose stats to use. Defaults'
'to "train_scenes".')
needs_channel_order: bool = Field(
False,
description='Whether the means and stds in the stats_uri file need to '
'be re-ordered/subsetted using ``channel_order`` to be compatible '
'with the chips that will be passed to the :class:`.StatsTransformer` '
'by the :class:`.RasterSource`. This field exists for backward '
'compatibility with Raster Vision versions <= 0.30. It will be set '
'automatically when loading stats from older model-bundles.')

def update(self,
pipeline: 'RVPipelineConfig | None' = None,
Expand All @@ -43,9 +52,14 @@ def update(self,
self.stats_uri = join(pipeline.analyze_uri, 'stats',
self.scene_group, 'stats.json')

def build(self):
stats = RasterStats.load(self.stats_uri)
return StatsTransformer(means=stats.means, stds=stats.stds)
def build(self,
channel_order: list[int] | None = None) -> StatsTransformer:
if self.needs_channel_order:
tf = StatsTransformer.from_stats_json(
self.stats_uri, channel_order=channel_order)
else:
tf = StatsTransformer.from_stats_json(self.stats_uri)
return tf

def update_root(self, root_dir: str) -> None:
if self.scene_group == '__N/A__':
Expand Down
12 changes: 6 additions & 6 deletions rastervision_core/rastervision/core/raster_stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,9 @@ class RasterStats:
"""Band-wise means and standard deviations."""

def __init__(self,
means: np.ndarray | None = None,
stds: np.ndarray | None = None,
counts: np.ndarray | None = None):
means: Sequence[float] | None = None,
stds: Sequence[float] | None = None,
counts: Sequence[float] | None = None):
"""Constructor.
Args:
Expand All @@ -27,9 +27,9 @@ def __init__(self,
counts: Band pixel counts (used to compute the specified means and
stds). Defaults to ``None``.
"""
self.means = means
self.stds = stds
self.counts = counts
self.means = np.array(means) if means is not None else None
self.stds = np.array(stds) if stds is not None else None
self.counts = np.array(counts) if counts is not None else None

@classmethod
def load(cls, stats_uri: str) -> Self:
Expand Down
36 changes: 33 additions & 3 deletions tests/core/data/raster_transformer/test_stats_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,13 @@

import numpy as np

from rastervision.pipeline.config import build_config
from rastervision.pipeline.file_system import get_tmp_dir
from rastervision.core.raster_stats import RasterStats
from rastervision.core.data import (RasterioSource, StatsTransformer,
StatsTransformerConfig)
from rastervision.core.data.raster_transformer.stats_transformer_config import ( # noqa
stats_transformer_config_upgrader)

from tests import data_file_path

Expand Down Expand Up @@ -46,6 +49,33 @@ def test_build(self):
np.testing.assert_array_equal(tf.means, np.array([1, 2]))
np.testing.assert_array_equal(tf.stds, np.array([3, 4]))

def test_upgrader_v2(self):
cfg = StatsTransformerConfig()
old_cfg_dict = cfg.dict()
old_cfg_dict.pop('needs_channel_order')
new_cfg_dict = stats_transformer_config_upgrader(old_cfg_dict, 2)
self.assertEqual(new_cfg_dict['scene_group'], '__N/A__')

def test_upgrader_v13(self):
stats = RasterStats(np.array([1, 2]), np.array([3, 4]))

with get_tmp_dir() as tmp_dir:
stats_uri = join(tmp_dir, 'stats.json')
stats.save(stats_uri)

cfg = StatsTransformerConfig(stats_uri=stats_uri)
old_cfg_dict = cfg.dict()
old_cfg_dict.pop('needs_channel_order')
new_cfg_dict = stats_transformer_config_upgrader(old_cfg_dict, 13)
self.assertTrue(new_cfg_dict['needs_channel_order'])

cfg = build_config(new_cfg_dict)
self.assertIsInstance(cfg, StatsTransformerConfig)

tf = cfg.build(channel_order=[1, 0])
np.testing.assert_array_equal(tf.means, np.array([2, 1]))
np.testing.assert_array_equal(tf.stds, np.array([4, 3]))


class TestStatsTransformer(unittest.TestCase):
def test_transform(self):
Expand All @@ -72,9 +102,9 @@ def test_stats(self):

def test_from_raster_stats(self):
stats = RasterStats(np.array([1, 2]), np.array([3, 4]))
tf = StatsTransformer.from_raster_stats(stats)
np.testing.assert_array_equal(tf.means, np.array([1, 2]))
np.testing.assert_array_equal(tf.stds, np.array([3, 4]))
tf = StatsTransformer.from_raster_stats(stats, channel_order=[1, 0])
np.testing.assert_array_equal(tf.means, np.array([2, 1]))
np.testing.assert_array_equal(tf.stds, np.array([4, 3]))

def test_from_stats_json(self):
stats = RasterStats(np.array([1, 2]), np.array([3, 4]))
Expand Down

0 comments on commit 906ff30

Please sign in to comment.