Commit

fixing ft bug
mlkakram committed Jul 30, 2024
1 parent 81e235b commit 0175c39
Showing 3 changed files with 21 additions and 394 deletions.
@@ -152,15 +152,15 @@ class ForceDiffusionConfig:
     clip_sample_range: float = 1.0

     # Transformer
-    n_layer: int = 12
-    n_head: int = 12
+    n_layer: int = 8
+    n_head: int = 4
     n_emb: int = 768
     p_drop_emb: float = 0.0
     p_drop_attn: float = 0.01
     casual_attn: bool = True
     time_as_cond: bool = True
     obs_as_cond: bool = True
-    n_cond_layers: int = 4
+    n_cond_layers: int = 2

     # Inference
     num_inference_steps: int | None = None
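These new defaults roughly halve the noise-prediction transformer. A back-of-envelope sketch, assuming standard GPT-style blocks with about 4*n_emb^2 attention and 8*n_emb^2 MLP parameters per layer (the real architecture may differ):

def approx_transformer_params(n_layer: int, n_emb: int) -> int:
    # Per GPT-style block: ~4*d^2 (attention projections) + ~8*d^2 (MLP) = ~12*d^2.
    return n_layer * 12 * n_emb**2

print(approx_transformer_params(12, 768))  # ~85M parameters with the old defaults
print(approx_transformer_params(8, 768))   # ~57M with the new; n_head changes attention shape, not parameter count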
@@ -199,7 +199,7 @@ def __post_init__(self):
             raise ValueError(
                 f"`prediction_type` must be one of {supported_prediction_types}. Got {self.prediction_type}."
             )
-        supported_noise_schedulers = ["DDPM", "DDIM", "DPM"]
+        supported_noise_schedulers = ["DDPM", "DDIM"]
         if self.noise_scheduler_type not in supported_noise_schedulers:
             raise ValueError(
                 f"`noise_scheduler_type` must be one of {supported_noise_schedulers}. "
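A usage sketch of the tightened validation; the import path is an assumption inferred from the repository layout (only modeling_force_diffusion.py is named in this commit):

# Assumed module path, by analogy with modeling_force_diffusion.py below.
from lerobot.common.policies.diffusion.configuration_force_diffusion import ForceDiffusionConfig

ForceDiffusionConfig(noise_scheduler_type="DDIM")  # passes __post_init__
ForceDiffusionConfig(noise_scheduler_type="DPM")   # now raises ValueError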
lerobot/common/policies/diffusion/modeling_force_diffusion.py
27 changes: 17 additions & 10 deletions
@@ -103,6 +103,8 @@ def __init__(
         self.expected_image_keys = [k for k in config.input_shapes if k.startswith("observation.image")]

+        self.other_obs = [k for k in config.input_shapes if not k.startswith("observation.image")]
+
         self.reset()

     def get_optimizer_parameters(self):
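A minimal sketch of the key split that self.other_obs performs; the input_shapes entries below are illustrative assumptions:

input_shapes = {
    "observation.image.top": (3, 96, 96),
    "observation.state": (14,),
    "observation.ft": (6,),  # hypothetical force-torque reading: 3 force + 3 torque axes
}
expected_image_keys = [k for k in input_shapes if k.startswith("observation.image")]
other_obs = [k for k in input_shapes if not k.startswith("observation.image")]
assert expected_image_keys == ["observation.image.top"]
assert other_obs == ["observation.state", "observation.ft"]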
@@ -121,7 +123,6 @@ def reset(self):
         self._queues = {
             "observation.images": deque(maxlen=self.config.n_obs_steps),
             "observation.state": deque(maxlen=self.config.n_obs_steps),
-            "observation.ft": deque(maxlen=self.config.n_obs_steps),
             "action": deque(maxlen=self.config.n_action_steps),
         }

@@ -149,6 +150,7 @@ def select_action(self, batch: dict[str, Tensor]) -> Tensor:
         """
         batch = self.normalize_inputs(batch)
         batch["observation.images"] = torch.stack([batch[k] for k in self.expected_image_keys], dim=-4)
+        batch["observation.state"] = torch.cat([batch[k] for k in self.other_obs], dim=-1)
         # Note: It's important that this happens after stacking the images into a single key.
         self._queues = populate_queues(self._queues, batch)
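The added torch.cat is what makes the dedicated observation.ft queue removed in reset() unnecessary: force-torque readings now travel inside observation.state. A shape sketch with assumed dimensions:

import torch

state = torch.zeros(2, 14)  # (B, state_dim); dims are assumptions
ft = torch.zeros(2, 6)      # (B, ft_dim)
merged = torch.cat([state, ft], dim=-1)
assert merged.shape == (2, 20)  # queued and consumed downstream as a single observation.state tensor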

@@ -169,6 +171,7 @@ def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
         """Run the batch through the model and compute the loss for training or validation."""
         batch = self.normalize_inputs(batch)
         batch["observation.images"] = torch.stack([batch[k] for k in self.expected_image_keys], dim=-4)
+        batch["observation.state"] = torch.cat([batch[k] for k in self.other_obs], dim=-1)
         batch = self.normalize_targets(batch)
         loss = self.diffusion.compute_loss(batch)
         return {"loss": loss}
@@ -183,8 +186,6 @@ def _make_noise_scheduler(name: str, **kwargs: dict) -> DDPMScheduler | DDIMScheduler:
         return DDPMScheduler(**kwargs)
     elif name == "DDIM":
         return DDIMScheduler(**kwargs)
-    elif name == "DPM":
-        return DPMSolverMultistepScheduler(**kwargs)
     else:
         raise ValueError(f"Unsupported noise scheduler type {name}")
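A usage sketch of the factory after the change. The kwargs are illustrative assumptions (in the real code they come from ForceDiffusionConfig fields); DDPMScheduler and DDIMScheduler are existing diffusers constructors that accept these arguments:

noise_scheduler = _make_noise_scheduler(
    "DDIM",
    num_train_timesteps=100,
    beta_schedule="squaredcos_cap_v2",
    clip_sample=True,
    prediction_type="epsilon",
)
noise_scheduler.set_timesteps(10)  # sample with fewer steps than training
_make_noise_scheduler("DPM", num_train_timesteps=100)  # now raises ValueError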

@@ -197,13 +198,19 @@ def __init__(self, config: ForceDiffusionConfig):
         self.rgb_encoder = DiffusionRgbEncoder(config)
         num_images = len([k for k in config.input_shapes if k.startswith("observation.image")])

-        global_cond_dim = config.input_shapes["observation.state"][0] + config.input_shapes["observation.ft"][0] + self.rgb_encoder.feature_dim * num_images
+        # Get the keys that do not start with "observation.image".
+        other_obs_keys = [k for k in config.input_shapes if not k.startswith("observation.image")]
+
+        # Sum the first dimension of the shapes of these keys.
+        state_shape = sum(config.input_shapes[key][0] for key in other_obs_keys)
+
+        global_cond_dim = state_shape + self.rgb_encoder.feature_dim * num_images

         if self.config.model == "FILM":
             self.unet = DiffusionConditionalUnet1d(
                 config,
                 global_cond_dim=(
-                    config.input_shapes["observation.state"][0] + config.input_shapes["observation.ft"][0] + self.rgb_encoder.feature_dim * num_images
+                    state_shape + self.rgb_encoder.feature_dim * num_images
                 )
                 * config.n_obs_steps,
             )
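A worked example of the new global_cond_dim computation; every dimension below is an assumption:

input_shapes = {
    "observation.image.top": (3, 96, 96),
    "observation.image.wrist": (3, 96, 96),
    "observation.state": (14,),
    "observation.ft": (6,),
}
rgb_feature_dim, num_images = 64, 2  # stand-ins for rgb_encoder.feature_dim and the camera count

state_shape = sum(v[0] for k, v in input_shapes.items() if not k.startswith("observation.image"))
global_cond_dim = state_shape + rgb_feature_dim * num_images
assert state_shape == 20 and global_cond_dim == 148
# The FILM UNet then multiplies by n_obs_steps, since it flattens the conditioning over time.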
@@ -282,12 +289,12 @@ def _prepare_global_conditioning(self, batch: dict[str, Tensor]) -> Tensor:
         img_features = einops.rearrange(
             img_features, "(b s n) ... -> b s (n ...)", b=batch_size, s=n_obs_steps
         )  # img_features shape is config.softmax * 2 * num_cameras

         # Concatenate state and image features, then flatten to (B, global_cond_dim); in the transformer case it stays (B, T, global_cond_dim).
         if self.config.model == "FILM":
-            output = torch.cat([batch["observation.state"], batch["observation.ft"], img_features], dim=-1).flatten(start_dim=1)
+            output = torch.cat([batch["observation.state"], img_features], dim=-1).flatten(start_dim=1)
         elif self.config.model == "TRANSFORMER":
-            output = torch.cat([batch["observation.state"], batch["observation.ft"], img_features], dim=-1)
+            output = torch.cat([batch["observation.state"], img_features], dim=-1)  # append the state to the end of the image features: [B, To, sp*num_cam + state_dim]

         return output
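A shape walkthrough of the conditioning path above; batch size, step count, and feature dims are assumptions:

import einops
import torch

b, s, n, d = 2, 3, 2, 64  # batch, n_obs_steps, num_cameras, per-image feature dim
img_features = torch.zeros(b * s * n, d)
img_features = einops.rearrange(img_features, "(b s n) ... -> b s (n ...)", b=b, s=s)
assert img_features.shape == (2, 3, 128)  # cameras folded into the feature axis

state = torch.zeros(b, s, 20)  # merged state + force-torque, per the earlier hunks
film_cond = torch.cat([state, img_features], dim=-1).flatten(start_dim=1)
transformer_cond = torch.cat([state, img_features], dim=-1)
assert film_cond.shape == (2, 444)            # (B, n_obs_steps * global_cond_dim)
assert transformer_cond.shape == (2, 3, 148)  # (B, T, global_cond_dim)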

@@ -321,14 +328,14 @@ def compute_loss(self, batch: dict[str, Tensor]) -> Tensor:
         """
         This function expects `batch` to have (at least):
         {
-            "observation.state": (B, n_obs_steps, state_dim)
+            "observation.state": (B, n_obs_steps, state_dim + extra low-dim obs dims)
             "observation.images": (B, n_obs_steps, num_cameras, C, H, W)
             "action": (B, horizon, action_dim)
             "action_is_pad": (B, horizon)
         }
         """
         # Input validation.
-        assert set(batch).issuperset({"observation.state", "observation.ft", "observation.images", "action", "action_is_pad"})
+        assert set(batch).issuperset({"observation.state", "observation.images", "action", "action_is_pad"})  # TODO: needs to be changed
         n_obs_steps = batch["observation.state"].shape[1]
         horizon = batch["action"].shape[1]
         assert horizon == self.config.horizon
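A minimal batch sketch that satisfies the relaxed assertion; all dimensions are assumptions:

import torch

B, n_obs_steps, horizon = 2, 3, 16
batch = {
    "observation.state": torch.zeros(B, n_obs_steps, 20),  # state + ft, pre-merged by forward()
    "observation.images": torch.zeros(B, n_obs_steps, 2, 3, 96, 96),
    "action": torch.zeros(B, horizon, 14),
    "action_is_pad": torch.zeros(B, horizon, dtype=torch.bool),
}
assert set(batch).issuperset({"observation.state", "observation.images", "action", "action_is_pad"})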