Skip to content

Commit 2dfbf71

Browse files
committed
added check for GPU into eval.py
1 parent a73d9c3 commit 2dfbf71

File tree

1 file changed

+4
-0
lines changed

1 file changed

+4
-0
lines changed

eval.py

+4
Original file line numberDiff line numberDiff line change
@@ -109,9 +109,13 @@
109109

110110
vocab = infos['vocab'] # ix -> word mapping
111111

112+
#check for GPU
113+
device = 'cuda' if torch.cuda.is_available() else 'cpu'
114+
112115
# Setup the model
113116
model = models.setup(opt)
114117
model.load_state_dict(torch.load(opt.model, map_location=torch.device(device)))
118+
model.to(device=device)
115119
model.eval()
116120
crit = utils.LanguageModelCriterion()
117121

0 commit comments

Comments
 (0)