Skip to content

Commit defa4fa

Browse files
authored
make sure model on target device
1 parent 6237b93 commit defa4fa

File tree

1 file changed

+3
-0
lines changed

1 file changed

+3
-0
lines changed

going_modular/going_modular/engine.py

+3
Original file line numberDiff line numberDiff line change
@@ -160,6 +160,9 @@ def train(model: torch.nn.Module,
160160
"test_loss": [],
161161
"test_acc": []
162162
}
163+
164+
# Make sure model on target device
165+
model.to(device)
163166

164167
# Loop through training and testing steps for a number of epochs
165168
for epoch in tqdm(range(epochs)):

0 commit comments

Comments
 (0)