From 22d159933d909926287da18c7251c9f524ac3d75 Mon Sep 17 00:00:00 2001 From: WangYihang Date: Thu, 11 Apr 2024 15:02:38 +0800 Subject: [PATCH] Fixed issue #181 --- zeta/nn/modules/simple_mamba.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/zeta/nn/modules/simple_mamba.py b/zeta/nn/modules/simple_mamba.py index 362a7059..9df0d9b2 100644 --- a/zeta/nn/modules/simple_mamba.py +++ b/zeta/nn/modules/simple_mamba.py @@ -199,7 +199,7 @@ def selective_scan(self, u, delta, A, B, C, D): ) # Perform selective scan (see scan_SSM() in The Annotated S4 [2]) - x = torch.zeros((b, d_in, n)) + x = torch.zeros((b, d_in, n), device=next(self.parameters()).device) ys = [] for i in range(l): x = deltaA[:, :, i] * x + deltaB_u[:, :, i]