Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[gym/common] Add relative foot odom pose shift tracking termination conditions. #820

Merged
merged 2 commits into from
Jun 29, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
[gym/common] Add unit test checking that obs wrappers preserve key or…
…dering.
  • Loading branch information
duburcqa committed Jun 28, 2024
commit 23f8bdb99a8b4e1411dbc4889f05aa82305cb886
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
MinimizeAngularMomentumReward,
MinimizeFrictionReward,
BaseRollPitchTermination,
BaseHeightTermination,
FallingTermination,
FootCollisionTermination,
FlyingTermination,
ImpactForceTermination)
Expand Down Expand Up @@ -57,7 +57,7 @@
"MechanicalPowerConsumptionTermination",
"FlyingTermination",
"BaseRollPitchTermination",
"BaseHeightTermination",
"FallingTermination",
"FootCollisionTermination",
"ImpactForceTermination"
]
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
"""Rewards mainly relevant for locomotion tasks on floating-base robots.
"""
import math
from functools import partial
from dataclasses import dataclass
from typing import Optional, Union, Sequence, Literal, Callable, cast
Expand Down Expand Up @@ -309,7 +308,7 @@ def __init__(self,
is_training_only=is_training_only)


class BaseHeightTermination(QuantityTermination):
class FallingTermination(QuantityTermination):
"""Terminate the episode immediately if the floating base of the robot
gets too close from the ground.

Expand All @@ -323,14 +322,14 @@ class BaseHeightTermination(QuantityTermination):
"""
def __init__(self,
env: InterfaceJiminyEnv,
min_height: float,
min_base_height: float,
grace_period: float = 0.0,
*,
is_training_only: bool = False) -> None:
"""
:param env: Base or wrapped jiminy environment.
:param min_height: Minimum height of the floating base of the robot
below which termination is triggered.
:param min_base_height: Minimum height of the floating base of the
robot below which termination is triggered.
:param grace_period: Grace period effective only at the very beginning
of the episode, during which the latter is bound
to continue whatever happens.
Expand All @@ -344,7 +343,7 @@ def __init__(self,
env,
"termination_base_height",
(BaseRelativeHeight, {}), # type: ignore[arg-type]
min_height,
min_base_height,
None,
grace_period,
is_truncation=False,
Expand Down Expand Up @@ -542,15 +541,15 @@ class ImpactForceTermination(QuantityTermination):
"""
def __init__(self,
env: InterfaceJiminyEnv,
max_force: float,
max_force_rel: float,
grace_period: float = 0.0,
*,
is_training_only: bool = False) -> None:
"""
:param env: Base or wrapped jiminy environment.
:param max_force: Maximum vertical force applied on any of the contact
points or collision bodies above which termination is
triggered.
:param max_force_rel: Maximum vertical force applied on any of the
contact points or collision bodies above which
termination is triggered.
:param grace_period: Grace period effective only at the very beginning
of the episode, during which the latter is bound
to continue whatever happens.
Expand All @@ -568,7 +567,7 @@ def __init__(self,
axis=0,
keys=(2,))),
None,
max_force,
max_force_rel,
grace_period,
is_truncation=False,
is_training_only=is_training_only)
Expand Down
4 changes: 2 additions & 2 deletions python/gym_jiminy/common/gym_jiminy/common/utils/spaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,9 +237,9 @@ def clip(data: DataNested, space: gym.Space[DataNested]) -> DataNested:
field: clip(data[field], subspace)
for field, subspace in space.spaces.items()})
if tree.issubclass_sequence(data_type):
return data_type(tuple(
return data_type([
clip(data[i], subspace)
for i, subspace in enumerate(space.spaces)))
for i, subspace in enumerate(space.spaces)])
return _array_clip(data, *get_bounds(space))


Expand Down
27 changes: 27 additions & 0 deletions python/gym_jiminy/unit_py/test_pipeline_control.py
Original file line number Diff line number Diff line change
Expand Up @@ -316,3 +316,30 @@ def test_repeatability(self):
assert np.all(a_prev == env.robot_state.a)
for _ in range(n_steps):
env.step(env.action)

def test_preserve_obs_key_order(self):
""" TODO: Write documentation.
"""
env = AtlasPDControlJiminyEnv()

env_stack = StackObservation(
env, skip_frames_ratio=-1, num_stack=2, nested_filter_keys=[["t"]])
env_filter = FilterObservation(
env, nested_filter_keys=env.observation_space.keys())
env_obs_norm = NormalizeObservation(env)
for env in (env, env_stack, env_filter, env_obs_norm):
env.reset(seed=0)
assert [*env.observation_space.keys()] == [*env.observation.keys()]

env_flat = FlattenObservation(env)
env_flat.reset(seed=0)
all_values_flat = []
obs_nodes = list(env.observation.values())
while obs_nodes:
value = obs_nodes.pop()
if isinstance(value, dict):
obs_nodes += value.values()
else:
all_values_flat.append(value.flatten())
obs_flat = np.concatenate(all_values_flat[::-1])
np.testing.assert_allclose(env_flat.observation, obs_flat)
16 changes: 8 additions & 8 deletions python/gym_jiminy/unit_py/test_quantities.py
Original file line number Diff line number Diff line change
Expand Up @@ -587,18 +587,18 @@ def test_foot_relative_pose(self):
for frame_name in ("l_foot", "r_foot"):
frame_index = env.robot.pinocchio_model.getFrameId(frame_name)
foot_poses.append(env.robot.pinocchio_data.oMf[frame_index])
pos_feet = np.stack(tuple(
foot_pose.translation for foot_pose in foot_poses), axis=-1)
quat_feet = np.stack(tuple(
pos_feet = np.stack([
foot_pose.translation for foot_pose in foot_poses], axis=-1)
quat_feet = np.stack([
matrix_to_quat(foot_pose.rotation)
for foot_pose in foot_poses), axis=-1)
for foot_pose in foot_poses], axis=-1)

pos_mean = np.mean(pos_feet, axis=-1, keepdims=True)
rot_mean = quat_to_matrix(quat_average(quat_feet))
pos_rel = rot_mean.T @ (pos_feet - pos_mean)
quat_rel = np.stack(tuple(
quat_rel = np.stack([
matrix_to_quat(rot_mean.T @ foot_pose.rotation)
for foot_pose in foot_poses), axis=-1)
for foot_pose in foot_poses], axis=-1)
quat_rel[-4:] *= np.sign(quat_rel[-1])

value = env.quantities["foot_rel_poses"].copy()
Expand All @@ -621,9 +621,9 @@ def test_contact_spatial_forces(self):

gravity = abs(env.robot.pinocchio_model.gravity.linear[2])
robot_weight = env.robot.pinocchio_data.mass[0] * gravity
force_spatial_rel = np.stack(tuple(np.concatenate(
force_spatial_rel = np.stack([np.concatenate(
(constraint.lambda_c[:3], np.zeros((2,)), constraint.lambda_c[[3]])
) for constraint in env.robot.constraints.contact_frames.values()),
) for constraint in env.robot.constraints.contact_frames.values()],
axis=-1) / robot_weight
np.testing.assert_allclose(
force_spatial_rel, env.quantities["force_spatial_rel"])
Expand Down
4 changes: 2 additions & 2 deletions python/gym_jiminy/unit_py/test_terminations.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
DriftTrackingQuantityTermination,
ShiftTrackingQuantityTermination,
BaseRollPitchTermination,
BaseHeightTermination,
FallingTermination,
FootCollisionTermination,
MechanicalSafetyTermination,
FlyingTermination,
Expand Down Expand Up @@ -345,7 +345,7 @@
_, _, terminated, _, _ = self.env.step(action)
if terminated:
break
terminated, truncated = termination_pos({})

Check warning on line 348 in python/gym_jiminy/unit_py/test_terminations.py

View check run for this annotation

Codacy Production / Codacy Static Code Analysis

python/gym_jiminy/unit_py/test_terminations.py#L348

Unused variable 'truncated'
diff = quantity_pos.value_left - quantity_pos.value_right
is_valid = np.linalg.norm(diff) <= MAX_POS_ERROR
assert terminated ^ is_valid
Expand All @@ -358,7 +358,7 @@
""" TODO: Write documentation
"""
for termination in (
BaseHeightTermination(self.env, 0.6),
FallingTermination(self.env, 0.6),
ImpactForceTermination(self.env, 1.0),
MechanicalPowerConsumptionTermination(self.env, 400.0, 1.0),
ShiftTrackingMotorPositionsTermination(self.env, 0.4, 0.5),
Expand Down
8 changes: 3 additions & 5 deletions python/jiminy_py/src/jiminy_py/tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,8 +175,7 @@ def _unflatten_as(data: StructNested[Any],
if issubclass_mapping(data_type): # type: ignore[arg-type]
flat_items = [
(key, _unflatten_as(value, data_leaf_it))
for key, value in data.items() # type: ignore[union-attr]
]
for key, value in data.items()] # type: ignore[union-attr]
try:
# Initialisation from dict cannot be the default path as
# `gym.spaces.Dict` would sort keys in this specific scenario,
Expand All @@ -188,9 +187,8 @@ def _unflatten_as(data: StructNested[Any],
# sequence of key-value pairs.
return data_type(dict(flat_items)) # type: ignore[call-arg]
if issubclass_sequence(data_type): # type: ignore[arg-type]
return data_type(tuple( # type: ignore[call-arg]
_unflatten_as(value, data_leaf_it) for value in data
))
return data_type([ # type: ignore[call-arg]
_unflatten_as(value, data_leaf_it) for value in data])
return next(data_leaf_it)


Expand Down
Loading