-
Notifications
You must be signed in to change notification settings - Fork 2
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #32 from dfki-ric/feature/configurable_initial_position
Feature/configurable initial position
- Loading branch information
Showing
5 changed files
with
263 additions
and
8 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,104 @@ | ||
from typing import Protocol, Union, Optional | ||
import numpy.typing as npt | ||
import numpy as np | ||
|
||
|
||
class Sampler(Protocol):
    """Structural interface for objects that provide an initial pose."""

    def sample_initial_pose(self) -> npt.NDArray:
        """Return the next initial pose as a NumPy array."""
        ...
|
||
|
||
class FixedSampler(Sampler):
    """Sampler that yields the same initial pose on every call."""

    def __init__(self, initial_pose: npt.ArrayLike):
        """Store the pose to hand out.

        :param initial_pose: Pose returned by every call to
            :meth:`sample_initial_pose`.
        """
        # np.array copies its input, so later changes to the caller's
        # object do not leak into the sampler.
        self.initial_pose = np.array(initial_pose)

    def sample_initial_pose(self) -> npt.NDArray:
        """Return a copy of the stored pose.

        A copy is returned (as GridSampler does) so that in-place edits made
        by the caller cannot change the pose produced by later calls.
        """
        return self.initial_pose.copy()
|
||
|
||
class GaussianSampler(Sampler):
    """Sampler drawing poses from a normal distribution."""

    def __init__(
            self,
            mu: npt.ArrayLike,
            sigma: npt.ArrayLike,
            seed: Optional[int] = None
    ):
        """Set up the distribution parameters and the random generator.

        :param mu: Mean of the distribution.
        :param sigma: Standard deviation of the distribution.
        :param seed: Seed passed to ``np.random.default_rng``; ``None``
            requests fresh entropy.
        """
        self.mu = np.array(mu)
        self.sigma = np.array(sigma)
        self.rng = np.random.default_rng(seed)

    def sample_initial_pose(self) -> npt.NDArray:
        """Draw one pose from ``Normal(mu, sigma)``."""
        return self.rng.normal(self.mu, self.sigma)
|
||
|
||
class UniformSampler(Sampler):
    """Sampler drawing poses uniformly between two bounds."""

    def __init__(
            self,
            low: npt.ArrayLike,
            high: npt.ArrayLike,
            seed: Optional[int] = None
    ):
        """Set up the bounds and the random generator.

        :param low: Lower bound of the sampling interval.
        :param high: Upper bound of the sampling interval.
        :param seed: Seed passed to ``np.random.default_rng``; ``None``
            requests fresh entropy.
        """
        self.low = np.array(low)
        self.high = np.array(high)
        self.rng = np.random.default_rng(seed)

    def sample_initial_pose(self) -> npt.NDArray:
        """Draw one pose from ``Uniform(low, high)``."""
        return self.rng.uniform(self.low, self.high)
|
||
|
||
class GaussianCurriculumSampler(Sampler):
    """Gaussian sampler whose standard deviation grows with every call."""

    def __init__(
            self,
            mu: npt.ArrayLike,
            sigma: npt.ArrayLike,
            step_size: Union[float, npt.ArrayLike] = 1e-3,
            seed: Optional[int] = None
    ):
        """Set up the distribution parameters and the random generator.

        :param mu: Mean of the distribution.
        :param sigma: Initial standard deviation.
        :param step_size: Amount added to ``sigma`` on every call to
            :meth:`sample_initial_pose`.
        :param seed: Seed passed to ``np.random.default_rng``; ``None``
            requests fresh entropy.
        """
        self.mu = np.array(mu)
        # Force a float dtype: with an integer sigma the in-place addition
        # of a fractional step in sample_initial_pose would raise a casting
        # error.
        self.sigma = np.array(sigma, dtype=float)
        self.rng = np.random.default_rng(seed)
        self.step_size = step_size

    def sample_initial_pose(self) -> npt.NDArray:
        """Widen the distribution by ``step_size``, then draw one pose.

        The increment is applied before drawing, so the first sample already
        uses ``sigma + step_size``.
        """
        self.sigma += self.step_size
        return self.rng.normal(self.mu, self.sigma)
|
||
|
||
class UniformCurriculumSampler(Sampler):
    """Uniform sampler whose interval widens with every call."""

    def __init__(
            self,
            low: npt.ArrayLike,
            high: npt.ArrayLike,
            step_size: Union[float, npt.ArrayLike] = 1e-3,
            seed: Optional[int] = None
    ):
        """Set up the bounds and the random generator.

        :param low: Initial lower bound of the sampling interval.
        :param high: Initial upper bound of the sampling interval.
        :param step_size: Amount by which ``high`` is raised and ``low`` is
            lowered on every call to :meth:`sample_initial_pose`.
        :param seed: Seed passed to ``np.random.default_rng``; ``None``
            requests fresh entropy.
        """
        # Force a float dtype: with integer bounds the in-place update by a
        # fractional step in sample_initial_pose would raise a casting error.
        self.low = np.array(low, dtype=float)
        self.high = np.array(high, dtype=float)
        self.rng = np.random.default_rng(seed)
        self.step_size = step_size

    def sample_initial_pose(self) -> npt.NDArray:
        """Widen the interval by ``step_size`` on both sides, then draw.

        The bounds are moved before drawing, so the first sample already
        uses the widened interval.
        """
        self.high += self.step_size
        self.low -= self.step_size
        return self.rng.uniform(self.low, self.high)
|
||
|
||
class GridSampler(Sampler):
    """Sampler cycling deterministically through a regular grid of poses."""

    def __init__(
            self,
            low: npt.ArrayLike,
            high: npt.ArrayLike,
            n_points_per_axis: npt.ArrayLike,
    ):
        """Build the grid of poses.

        :param low: Lower grid bound, one entry per dimension.
        :param high: Upper grid bound, one entry per dimension.
        :param n_points_per_axis: Number of grid points along each dimension.
        :raises ValueError: If the three arguments differ in length or the
            resulting grid contains no points.
        """
        self.n_dims = len(low)
        if len(high) != self.n_dims or len(n_points_per_axis) != self.n_dims:
            raise ValueError(
                "low, high and n_points_per_axis must have the same length")

        points_per_axis = [
            np.linspace(lo, hi, int(n))
            for lo, hi, n in zip(low, high, n_points_per_axis)
        ]

        # One row per grid point. The row width must follow the number of
        # dimensions; a fixed width of 3 scrambles grids of any other size.
        self.grid = np.array(
            np.meshgrid(*points_per_axis)).T.reshape(-1, self.n_dims)
        self.n_samples = len(self.grid)
        if self.n_samples == 0:
            # An empty grid would cause a modulo-by-zero when sampling.
            raise ValueError("the grid must contain at least one point")
        self.n_calls = 0

    def sample_initial_pose(self) -> npt.NDArray:
        """Return a copy of the next grid point, wrapping around at the end."""
        sample = self.grid[self.n_calls % self.n_samples].copy()
        self.n_calls += 1
        return sample
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,46 @@ | ||
"""
====================
Floating Mia Example
====================

This is an example of how to use the FloatingMiaGraspEnv. A random policy is
then used to generate ten episodes.
"""

import gymnasium

from deformable_gym.envs.floating_mia_grasp_env import FloatingMiaGraspEnv
from deformable_gym.envs.sampler import UniformSampler

# Sampling bounds: the hard initial pose with its first three components
# shifted by -0.03 / +0.03 (presumably the position part -- TODO confirm).
base_initial_pose = FloatingMiaGraspEnv.HARD_INITIAL_POSE.copy()
low = base_initial_pose.copy()
high = base_initial_pose.copy()

low[:3] -= 0.03
high[:3] += 0.03

sampler = UniformSampler(low, high, seed=0)

env = gymnasium.make(
    "FloatingMiaGraspInsole-v0", initial_state_sampler=sampler, gui=True)

# Close the environment even if stepping raises.
try:
    obs, info = env.reset()
    episode_return = 0
    num_episodes = 0

    while num_episodes < 10:

        action = env.action_space.sample()

        obs, reward, terminated, truncated, _ = env.step(action)
        episode_return += reward

        if terminated or truncated:
            print(f"Episode finished with return {episode_return}!")
            num_episodes += 1
            episode_return = 0

            obs, _ = env.reset()
finally:
    env.close()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,104 @@ | ||
import pytest | ||
import numpy as np | ||
import numpy.typing as npt | ||
from numpy.testing import assert_allclose | ||
|
||
from deformable_gym.envs.sampler import FixedSampler, GaussianSampler, \ | ||
UniformSampler, GridSampler | ||
|
||
# Seed shared by the sampler fixtures and the reference generators below.
SEED = 0
|
||
|
||
@pytest.fixture
def gaussian_target_pose() -> npt.NDArray:
    """First normal draw of a generator seeded with ``SEED``."""
    rng = np.random.default_rng(SEED)
    target = rng.normal(np.array([1, 2, 3]), np.array([2, 3, 4]))
    return target
|
||
|
||
@pytest.fixture
def fixed_target_pose() -> npt.NDArray:
    """Pose used to build the fixed sampler and to check its output."""
    return np.array([1, 2, 3])
|
||
|
||
@pytest.fixture
def uniform_target_pose() -> npt.NDArray:
    """First uniform draw of a generator seeded with ``SEED``."""
    rng = np.random.default_rng(SEED)
    target = rng.uniform(np.array([1, 2, 3]), np.array([2, 3, 4]))
    return target
|
||
|
||
@pytest.fixture
def grid_target_pose() -> npt.NDArray:
    """Pose the grid sampler test expects as the first sample."""
    target = np.array([1, 2, 3])
    return target
|
||
|
||
@pytest.fixture
def fixed_sampler(fixed_target_pose: npt.NDArray) -> FixedSampler:
    """FixedSampler built from the fixed target pose."""
    return FixedSampler(fixed_target_pose)
|
||
|
||
@pytest.fixture
def gaussian_sampler() -> GaussianSampler:
    """GaussianSampler seeded with ``SEED``."""
    return GaussianSampler(
        mu=np.array([1, 2, 3]),
        sigma=np.array([2, 3, 4]),
        seed=SEED
    )
|
||
|
||
@pytest.fixture
def uniform_sampler() -> UniformSampler:
    """UniformSampler seeded with ``SEED``."""
    return UniformSampler(
        low=np.array([1, 2, 3]),
        high=np.array([2, 3, 4]),
        seed=SEED
    )
|
||
|
||
@pytest.fixture
def grid_sampler() -> GridSampler:
    """GridSampler with 5, 3 and 1 points along its three axes."""
    return GridSampler(
        low=np.array([1, 2, 3]),
        high=np.array([2, 3, 4]),
        n_points_per_axis=np.array([5, 3, 1])
    )
|
||
|
||
def test_fixed_sampler(
        fixed_sampler: FixedSampler,
        fixed_target_pose: npt.NDArray
):
    """The fixed sampler returns the pose it was built with."""
    sampled_pose = fixed_sampler.sample_initial_pose()
    assert_allclose(sampled_pose, fixed_target_pose)
|
||
|
||
def test_gaussian_sampler(
        gaussian_sampler: GaussianSampler,
        gaussian_target_pose: npt.NDArray
):
    """The first Gaussian sample matches the seeded reference draw."""
    sampled_pose = gaussian_sampler.sample_initial_pose()
    assert_allclose(sampled_pose, gaussian_target_pose)
|
||
|
||
def test_uniform_sampler(
        uniform_sampler: UniformSampler,
        uniform_target_pose: npt.NDArray
):
    """The first uniform sample matches the seeded reference draw."""
    sampled_pose = uniform_sampler.sample_initial_pose()
    assert_allclose(sampled_pose, uniform_target_pose)
|
||
|
||
def test_grid_sampler(
        grid_sampler: GridSampler,
        grid_target_pose: npt.NDArray
):
    """The grid has 15 points, starts at the target pose and wraps around."""
    sampled_pose = grid_sampler.sample_initial_pose()

    assert len(grid_sampler.grid) == 15
    assert_allclose(sampled_pose, grid_target_pose)

    # Consume the remaining grid points; the next draw must start over at
    # the first point.
    for _ in range(len(grid_sampler.grid) - 1):
        grid_sampler.sample_initial_pose()

    assert_allclose(grid_sampler.sample_initial_pose(), grid_target_pose)
|