Skip to content

Commit

Permalink
refactor to add sim_features for type
Browse files Browse the repository at this point in the history
  • Loading branch information
yichao-liang committed Jan 18, 2025
1 parent d5b49e5 commit 79d7458
Show file tree
Hide file tree
Showing 7 changed files with 110 additions and 31 deletions.
50 changes: 26 additions & 24 deletions predicators/envs/pybullet_ant.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,10 +82,12 @@ class PyBulletAntEnv(PyBulletEnv):

# Food has color channels + "attractive" as 0.0 or 1.0
_food_type = Type(
"food", ["x", "y", "z", "rot", "is_held", "attractive", "r", "g", "b"])
"food", ["x", "y", "z", "rot", "is_held", "attractive", "r", "g", "b"],
sim_features=["r", "g", "b", "attractive"])

# Each ant might have orientation, but minimal for demonstration
_ant_type = Type("ant", ["x", "y", "z", "rot"])
_ant_type = Type("ant", ["x", "y", "z", "rot"],
sim_features=["target_food"])

def __init__(self,
use_gui: bool = True,
Expand Down Expand Up @@ -225,22 +227,22 @@ def _get_state(self) -> State:
is_held_val = 1.0 if (food_obj.id == self._held_obj_id) else 0.0
# Just keep placeholders for r,g,b,attractive for now—will read from init_dict
# or store them in environment if needed.
if not hasattr(food_obj, "_r"):
if not hasattr(food_obj, "r"):
# fallback if not yet assigned
food_obj._r = 0.5
food_obj._g = 0.5
food_obj._b = 0.5
food_obj._attractive = 0.0
food_obj.r = 0.5
food_obj.g = 0.5
food_obj.b = 0.5
food_obj.attractive = 0.0
state_dict[food_obj] = {
"x": fx,
"y": fy,
"z": fz,
"rot": utils.wrap_angle(yaw),
"is_held": is_held_val,
"attractive": food_obj._attractive,
"r": food_obj._r,
"g": food_obj._g,
"b": food_obj._b,
"attractive": food_obj.attractive,
"r": food_obj.r,
"g": food_obj.g,
"b": food_obj.b,
}

# 3) Ants
Expand Down Expand Up @@ -319,7 +321,7 @@ def _update_ant_positions(self, state: State) -> None:
food."""
for ant_obj in self.ants:
# Retrieve this ant’s assigned food
target_food_obj = getattr(ant_obj, "_target_food", None)
target_food_obj = getattr(ant_obj, "target_food", None)
if target_food_obj is None:
continue

Expand Down Expand Up @@ -432,30 +434,30 @@ def _make_tasks(self, num_tasks: int,
color_idx = self._color_indices[i]
color_rgba = self.color_palette[color_idx]
# Store color in object attributes
fobj._r = color_rgba[0]
fobj._g = color_rgba[1]
fobj._b = color_rgba[2]
fobj.r = color_rgba[0]
fobj.g = color_rgba[1]
fobj.b = color_rgba[2]
# If color is in attractive_colors, set "attractive"=1
if color_rgba in attractive_colors:
fobj._attractive = 1.0
fobj.attractive = 1.0
else:
fobj._attractive = 0.0
fobj.attractive = 0.0

init_dict[fobj] = {
"x": x,
"y": y,
"z": self.z_lb + self.food_half_extents[2], # on table
"rot": rot,
"is_held": 0.0,
"attractive": fobj._attractive,
"r": fobj._r,
"g": fobj._g,
"b": fobj._b,
"attractive": fobj.attractive,
"r": fobj.r,
"g": fobj.g,
"b": fobj.b,
}

# Collect the "attractive" foods for random assignment
attractive_food_objs = [
f for f in self.food if f._attractive == 1.0
f for f in self.food if f.attractive == 1.0
]

# 3) Ants
Expand All @@ -472,9 +474,9 @@ def _make_tasks(self, num_tasks: int,
# Assign a random attractive block if any exist
# (store that choice as an attribute in the Python object)
if attractive_food_objs:
aobj._target_food = rng.choice(attractive_food_objs)
aobj.target_food = rng.choice(attractive_food_objs)
else:
aobj._target_food = None
aobj.target_food = None

self._objects = [self._robot] + self.food + self.ants
init_state = utils.create_state_from_dict(init_dict)
Expand Down
3 changes: 2 additions & 1 deletion predicators/envs/pybullet_circuit.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,8 @@ class PyBulletCircuitEnv(PyBulletEnv):
# Types
_robot_type = Type("robot", ["x", "y", "z", "fingers", "tilt", "wrist"])
_wire_type = Type("wire", ["x", "y", "z", "rot", "is_held"])
_battery_type = Type("battery", ["x", "y", "z", "rot"])
_battery_type = Type("battery", ["x", "y", "z", "rot"],
sim_features=["id", "joint_id", "joint_scale"])
_light_type = Type("light", ["x", "y", "z", "rot", "is_on"])

def __init__(self, use_gui: bool = True) -> None:
Expand Down
19 changes: 19 additions & 0 deletions predicators/envs/pybullet_coffee.py
Original file line number Diff line number Diff line change
Expand Up @@ -1265,3 +1265,22 @@ def _get_jug_handle_grasp(cls, state: State,
target_y += 0.02
target_z = cls.z_lb + cls.jug_handle_height()
return (target_x, target_y, target_z)

if __name__ == "__main__":
"""Run a simple simulation to test the environment."""
import time

# Make a task
CFG.seed = 1
CFG.pybullet_sim_steps_per_action = 1
env = PyBulletCoffeeEnv(use_gui=True)
rng = np.random.default_rng(CFG.seed)
task = env._make_tasks(1, rng)[0]
env._reset_state(task.init)

while True:
# Robot does nothing
action = Action(np.array(env._pybullet_robot.initial_joint_positions))

env.step(action)
time.sleep(0.01)
6 changes: 4 additions & 2 deletions predicators/envs/pybullet_fan.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,8 @@ class PyBulletFanEnv(PyBulletEnv):
"rot", # base orientation (Z euler)
"side", # 0=left,1=right,2=back,3=front
"is_on", # whether the controlling switch is on
])
],
sim_features=["id", "joint_id", "side_idx"])
# New separate switch type:
_switch_type = Type(
"switch",
Expand All @@ -99,7 +100,8 @@ class PyBulletFanEnv(PyBulletEnv):
"rot", # switch orientation
"side", # matches fan side
"is_on", # is this switch on
])
],
sim_features=["id", "joint_id", "side_idx"])
_wall_type = Type("wall", ["x", "y", "z", "rot", "length"])
_ball_type = Type("ball", ["x", "y", "z"])
_target_type = Type("target", ["x", "y", "z", "rot", "is_hit"])
Expand Down
20 changes: 20 additions & 0 deletions predicators/envs/pybullet_grow.py
Original file line number Diff line number Diff line change
Expand Up @@ -543,3 +543,23 @@ def _create_pybullet_liquid_for_cup(self, cup: Object,
basePosition=pose,
baseOrientation=orientation,
physicsClientId=self._physics_client_id)


if __name__ == "__main__":
"""Run a simple simulation to test the environment."""
import time

# Make a task
CFG.seed = 1
CFG.pybullet_sim_steps_per_action = 1
env = PyBulletGrowEnv(use_gui=True)
rng = np.random.default_rng(CFG.seed)
task = env._make_tasks(1, rng)[0]
env._reset_state(task.init)

while True:
# Robot does nothing
action = Action(np.array(env._pybullet_robot.initial_joint_positions))

env.step(action)
time.sleep(0.01)
3 changes: 2 additions & 1 deletion predicators/envs/pybullet_laser.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,8 @@ class PyBulletLaserEnv(PyBulletEnv):
# Types
# -------------
_robot_type = Type("robot", ["x", "y", "z", "fingers", "tilt", "wrist"])
_station_type = Type("station", ["x", "y", "z", "rot", "is_on"])
_station_type = Type("station", ["x", "y", "z", "rot", "is_on"],
sim_features=["id", "joint_id"])
_mirror_type = Type("mirror", ["x", "y", "z", "rot", "split_mirror"])
_target_type = Type("target", ["x", "y", "z", "rot", "is_hit"])

Expand Down
40 changes: 37 additions & 3 deletions predicators/structs.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,14 @@

@dataclass(frozen=True, order=True)
class Type:
"""Struct defining a type."""
"""Struct defining a type.
sim_feature_names are features stored in an object, and usually won't change
throughout and across tasks. An example is the object's pybullet id."""
name: str
feature_names: Sequence[str] = field(repr=False)
parent: Optional[Type] = field(default=None, repr=False)
sim_features: Sequence[str] = field(default_factory=lambda: ["id"],
repr=False)

@property
def dim(self) -> int:
Expand Down Expand Up @@ -62,7 +66,6 @@ class _TypedEntity:
"""
name: str
type: Type
id: Optional[int] = None

@cached_property
def _str(self) -> str:
Expand All @@ -88,14 +91,45 @@ def is_instance(self, t: Type) -> bool:
cur_type = cur_type.parent
return False


@dataclass(frozen=False, order=True, repr=False)
class Object(_TypedEntity):
"""Struct defining an Object, which is just a _TypedEntity whose name does
not start with "?"."""
sim_data: Dict[str, Any] = field(default_factory=dict)

def __post_init__(self) -> None:
assert not self.name.startswith("?")
# Initialize sim_data from the Type's sim_features
for sim_feature in self.type.sim_features:
self.sim_data[sim_feature] = None # Default to None
# Keep track of allowed attributes
self._allowed_attributes = {"id", "sim_data"}.union(self.sim_data.keys())

def __getattr__(self, name: str) -> Any:
# Bypass custom logic for internal attributes
# Use object.__getattribute__(...) instead of self.sim_data
sim_data = object.__getattribute__(self, "sim_data")
if name in sim_data:
return sim_data[name]
raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'")

def __setattr__(self, name: str, value: Any) -> None:
# Always allow the dataclass fields (e.g., "name", "type", "sim_data").
if name in {"name", "type", "sim_data", "_allowed_attributes"}:
super().__setattr__(name, value)
return

# For anything else, check _allowed_attributes.
allowed_attrs = object.__getattribute__(self, "_allowed_attributes") \
if object.__getattribute__(self, "__dict__").get("_allowed_attributes") else set()
if name in allowed_attrs:
sim_data = object.__getattribute__(self, "sim_data")
if name in sim_data:
sim_data[name] = value
else:
super().__setattr__(name, value)
else:
raise AttributeError(f"Cannot set unknown attribute '{name}'")

def __hash__(self) -> int:
# By default, the dataclass generates a new __hash__ method when
Expand Down

0 comments on commit 79d7458

Please sign in to comment.