diff --git a/zeta/nn/modules/simple_mamba.py b/zeta/nn/modules/simple_mamba.py index 362a7059..9df0d9b2 100644 --- a/zeta/nn/modules/simple_mamba.py +++ b/zeta/nn/modules/simple_mamba.py @@ -199,7 +199,7 @@ def selective_scan(self, u, delta, A, B, C, D): ) # Perform selective scan (see scan_SSM() in The Annotated S4 [2]) - x = torch.zeros((b, d_in, n)) + x = torch.zeros((b, d_in, n), device=next(self.parameters()).device) ys = [] for i in range(l): x = deltaA[:, :, i] * x + deltaB_u[:, :, i]