Commit c972332

For the simple use cases where no Enum or unexpected config objects are in the save files, use weights_only=True. Significantly cuts down on the number of torch warnings. #1429
1 parent 1d5a75b commit c972332
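
For context on the change: with weights_only=True, torch.load only unpickles tensors and plain Python containers, so checkpoints that are just state dicts load without torch's pickle warning, while anything that pickles a custom class (an Enum, a config object) is rejected instead of silently executed. A minimal sketch of both behaviors; the file names and the Mode enum below are illustrative, not part of this commit:

    import torch
    from enum import Enum

    # A checkpoint of nothing but tensors in a dict: fine under weights_only=True
    torch.save({'model': {'embedding.weight': torch.zeros(10, 4)}}, 'plain.pt')
    checkpoint = torch.load('plain.pt', lambda storage, loc: storage, weights_only=True)

    # A checkpoint holding a custom class: weights_only=True refuses to unpickle it
    class Mode(Enum):      # illustrative stand-in for the enums mentioned in the commit message
        TRAIN = 1

    torch.save({'model': {}, 'mode': Mode.TRAIN}, 'with_enum.pt')
    try:
        torch.load('with_enum.pt', weights_only=True)
    except Exception as err:   # raises an UnpicklingError-style failure in recent torch
        print("rejected:", err)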

File tree

18 files changed, +24 -16 lines

stanza/models/charlm.py

+1-1
@@ -206,7 +206,7 @@ def get_current_lr(trainer, args):
     return trainer.scheduler.state_dict().get('_last_lr', [args['lr0']])[0]
 
 def load_char_vocab(vocab_file):
-    return {'char': CharVocab.load_state_dict(torch.load(vocab_file, lambda storage, loc: storage))}
+    return {'char': CharVocab.load_state_dict(torch.load(vocab_file, lambda storage, loc: storage, weights_only=True))}
 
 def train(args):
     utils.log_training_args(args, logger)

stanza/models/classifiers/trainer.py

+2
@@ -69,6 +69,8 @@ def load(filename, args, foundation_cache=None, load_optimizer=False):
     else:
         raise FileNotFoundError("Cannot find model in {} or in {}".format(filename, os.path.join(args.save_dir, filename)))
     try:
+        # TODO: switch to weights_only=True
+        # need to convert enums to int first
         checkpoint = torch.load(filename, lambda storage, loc: storage)
     except BaseException:
         logger.exception("Cannot load model from {}".format(filename))
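
The TODO above says the classifier checkpoint needs its enums converted to ints before weights_only=True can be used. A sketch of what that conversion could look like at save/load time; ModelType and the key names are hypothetical, not the classifier's actual config:

    from enum import Enum

    class ModelType(Enum):              # hypothetical stand-in for the classifier's config enums
        CNN = 1
        CONSTITUENCY = 2

    def config_to_saveable(config):
        # replace every Enum value with its underlying int so torch.save writes only plain types
        return {k: (v.value if isinstance(v, Enum) else v) for k, v in config.items()}

    def config_from_saved(saved):
        config = dict(saved)
        config['model_type'] = ModelType(config['model_type'])   # restore the enum on load
        return config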

stanza/models/common/char_model.py

+2-2
@@ -268,7 +268,7 @@ def from_full_state(cls, state, finetune=False):
 
     @classmethod
     def load(cls, filename, finetune=False):
-        state = torch.load(filename, lambda storage, loc: storage)
+        state = torch.load(filename, lambda storage, loc: storage, weights_only=True)
         # allow saving just the Model object,
         # and allow for old charlms to still work
         if 'state_dict' in state:
@@ -342,7 +342,7 @@ def load(cls, args, filename, finetune=False):
         Note that you MUST set finetune=True if planning to continue training
         Otherwise the only benefit you will get will be a warm GPU
         """
-        state = torch.load(filename, lambda storage, loc: storage)
+        state = torch.load(filename, lambda storage, loc: storage, weights_only=True)
         model = CharacterLanguageModel.from_full_state(state['model'], finetune)
         model = model.to(args['device'])

stanza/models/common/pretrain.py

+1
@@ -53,6 +53,7 @@ def emb(self):
     def load(self):
         if self.filename is not None and os.path.exists(self.filename):
             try:
+                # TODO: update all pretrains to satisfy weights_only=True
                 data = torch.load(self.filename, lambda storage, loc: storage)
                 logger.debug("Loaded pretrain from {}".format(self.filename))
                 if not isinstance(data, dict):
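
The TODO above points at the pretrain files themselves: they would all need to be re-saved with only plain types before the weights_only=True switch. A rough sketch of such a one-off conversion, with illustrative file names and keys (the real pretrain layout may differ):

    import torch

    # one-off conversion sketch: read an old pretrain once without the restriction,
    # keep only a tensor and a plain list, and write it back out
    data = torch.load('old_pretrain.pt', lambda storage, loc: storage, weights_only=False)
    clean = {'emb': torch.as_tensor(data['emb']),
             'vocab': list(data['vocab'])}
    torch.save(clean, 'new_pretrain.pt')
    torch.load('new_pretrain.pt', weights_only=True)   # now loads under the stricter mode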

stanza/models/common/trainer.py

+1-1
@@ -13,7 +13,7 @@ def save(self, filename):
         torch.save(savedict, filename)
 
     def load(self, filename):
-        savedict = torch.load(filename, lambda storage, loc: storage)
+        savedict = torch.load(filename, lambda storage, loc: storage, weights_only=True)
 
         self.model.load_state_dict(savedict['model'])
         if self.args['mode'] == 'train':

stanza/models/constituency/base_trainer.py

+3
@@ -84,6 +84,9 @@ def load(filename, args=None, load_optimizer=False, foundation_cache=None, peft_
     else:
         raise FileNotFoundError("Cannot find model in {} or in {}".format(filename, os.path.join(args['save_dir'], filename)))
     try:
+        # TODO: currently cannot switch this to weights_only=True
+        # without in some way changing the model to save enums in
+        # a safe manner, probably by converting to int
        checkpoint = torch.load(filename, lambda storage, loc: storage)
     except BaseException:
         logger.exception("Cannot load model from %s", filename)
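
The constituency parser has the same enum problem as the classifier. Besides converting the enums to ints, recent PyTorch (roughly 2.4 and later) offers torch.serialization.add_safe_globals to allowlist specific classes for the weights_only unpickler; a sketch with a stand-in enum, since the parser's real enum classes are not shown in this diff:

    import torch
    from enum import Enum

    class TransitionScheme(Enum):       # stand-in; allowlist the parser's actual enum classes instead
        TOP_DOWN = 1
        IN_ORDER = 2

    torch.serialization.add_safe_globals([TransitionScheme])
    torch.save({'scheme': TransitionScheme.IN_ORDER, 'model': {}}, 'parser.pt')
    checkpoint = torch.load('parser.pt', weights_only=True)   # accepted once the class is allowlisted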

stanza/models/coref/model.py

+1
@@ -224,6 +224,7 @@ def load_weights(self,
         if map_location is None:
             map_location = self.config.device
         logger.debug(f"Loading from {path}...")
+        # TODO: the config is preventing us from using weights_only=True
         state_dicts = torch.load(path, map_location=map_location)
         self.epochs_trained = state_dicts.pop("epochs_trained", 0)
         # just ignore a config in the model, since we should already have one
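
Here the blocker is the config object stored alongside the state dicts. One hedged way around it, if the config is a dataclass or similarly simple, would be to flatten it to a plain dict at save time; attribute and key names below are illustrative, not the coref model's real API:

    import dataclasses
    import torch

    def save_weights(modules, config, epochs_trained, path):
        state_dicts = {name: module.state_dict() for name, module in modules.items()}
        state_dicts['epochs_trained'] = epochs_trained
        # a plain dict instead of the config object keeps the file weights_only-friendly
        state_dicts['config'] = dataclasses.asdict(config) if dataclasses.is_dataclass(config) else vars(config)
        torch.save(state_dicts, path)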

stanza/models/depparse/trainer.py

+1-1
@@ -191,7 +191,7 @@ def load(self, filename, pretrain, args=None, foundation_cache=None, device=None
         and the actual use of pretrain embeddings will depend on the boolean config "pretrain" in the loaded args.
         """
         try:
-            checkpoint = torch.load(filename, lambda storage, loc: storage)
+            checkpoint = torch.load(filename, lambda storage, loc: storage, weights_only=True)
         except BaseException:
             logger.error("Cannot load model from {}".format(filename))
             raise

stanza/models/langid/model.py

+1-1
@@ -114,7 +114,7 @@ def load(cls, path, device=None, batch_size=64, lang_subset=None):
             raise FileNotFoundError("Trying to load langid model, but path not specified! Try --load_name")
         if not os.path.exists(path):
             raise FileNotFoundError("Trying to load langid model from path which does not exist: %s" % path)
-        checkpoint = torch.load(path, map_location=torch.device("cpu"))
+        checkpoint = torch.load(path, map_location=torch.device("cpu"), weights_only=True)
         weights = checkpoint["model_state_dict"]["loss_train.weight"]
         model = cls(checkpoint["char_to_idx"], checkpoint["tag_to_idx"], checkpoint["num_layers"],
                     checkpoint["embedding_dim"], checkpoint["hidden_dim"], batch_size=batch_size, weights=weights,

stanza/models/lemma/trainer.py

+1-1
@@ -236,7 +236,7 @@ def save(self, filename, skip_modules=True):
 
     def load(self, filename, args, foundation_cache):
         try:
-            checkpoint = torch.load(filename, lambda storage, loc: storage)
+            checkpoint = torch.load(filename, lambda storage, loc: storage, weights_only=True)
         except BaseException:
             logger.error("Cannot load model from {}".format(filename))
             raise

stanza/models/mwt/trainer.py

+1-1
@@ -198,7 +198,7 @@ def save(self, filename):
 
     def load(self, filename):
         try:
-            checkpoint = torch.load(filename, lambda storage, loc: storage)
+            checkpoint = torch.load(filename, lambda storage, loc: storage, weights_only=True)
         except BaseException:
             logger.error("Cannot load model from {}".format(filename))
             raise

stanza/models/ner/data.py

+1-1
@@ -55,7 +55,7 @@ def __init__(self, doc, batch_size, args, pretrain=None, vocab=None, evaluation=
     def init_vocab(self, data):
         def from_model(model_filename):
             """ Try loading vocab from charLM model file. """
-            state_dict = torch.load(model_filename, lambda storage, loc: storage)
+            state_dict = torch.load(model_filename, lambda storage, loc: storage, weights_only=True)
             if 'vocab' in state_dict:
                 return state_dict['vocab']
             if 'model' in state_dict and 'vocab' in state_dict['model']:

stanza/models/ner/trainer.py

+1-1
@@ -194,7 +194,7 @@ def save(self, filename, skip_modules=True):
 
     def load(self, filename, pretrain=None, args=None, foundation_cache=None):
         try:
-            checkpoint = torch.load(filename, lambda storage, loc: storage)
+            checkpoint = torch.load(filename, lambda storage, loc: storage, weights_only=True)
         except BaseException:
             logger.error("Cannot load model from {}".format(filename))
             raise

stanza/models/pos/trainer.py

+1-1
@@ -136,7 +136,7 @@ def load(self, filename, pretrain, args=None, foundation_cache=None):
         and the actual use of pretrain embeddings will depend on the boolean config "pretrain" in the loaded args.
         """
         try:
-            checkpoint = torch.load(filename, lambda storage, loc: storage)
+            checkpoint = torch.load(filename, lambda storage, loc: storage, weights_only=True)
         except BaseException:
             logger.error("Cannot load model from {}".format(filename))
             raise

stanza/models/tokenization/trainer.py

+1
@@ -79,6 +79,7 @@ def save(self, filename):
 
     def load(self, filename):
         try:
+            # the tokenizers with dictionaries won't properly load weights_only=True because they have a set
            checkpoint = torch.load(filename, lambda storage, loc: storage)
        except BaseException:
            logger.error("Cannot load model from {}".format(filename))
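
The comment above explains the tokenizer case: checkpoints for dictionary-aware tokenizers contain a set, which, per that comment, the weights_only unpickler does not handle. A sketch of storing the set as a sorted list instead and rebuilding it on load; function and key names are illustrative, not the tokenizer trainer's actual API:

    import torch

    def save_checkpoint(model, lexicon, filename):
        # a sorted list survives weights_only=True where a set would be rejected
        torch.save({'model': model.state_dict(),
                    'lexicon': sorted(lexicon) if lexicon is not None else None}, filename)

    def load_checkpoint(filename):
        checkpoint = torch.load(filename, lambda storage, loc: storage, weights_only=True)
        lexicon = set(checkpoint['lexicon']) if checkpoint['lexicon'] is not None else None
        return checkpoint['model'], lexicon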

stanza/tests/depparse/test_parser.py

+2-2
@@ -145,7 +145,7 @@ def test_with_bert_finetuning_resaved(self, tmp_path, wordvec_pretrain_file):
         save_name = trainer.args['save_name']
         filename = tmp_path / save_name
         assert os.path.exists(filename)
-        checkpoint = torch.load(filename, lambda storage, loc: storage)
+        checkpoint = torch.load(filename, lambda storage, loc: storage, weights_only=True)
         assert any(x.startswith("bert_model") for x in checkpoint['model'].keys())
 
         # Test loading the saved model, saving it, and still having bert in it
@@ -157,7 +157,7 @@ def test_with_bert_finetuning_resaved(self, tmp_path, wordvec_pretrain_file):
         saved_model.save(filename)
 
         # This is the part that would fail if the force_bert_saved option did not exist
-        checkpoint = torch.load(filename, lambda storage, loc: storage)
+        checkpoint = torch.load(filename, lambda storage, loc: storage, weights_only=True)
         assert any(x.startswith("bert_model") for x in checkpoint['model'].keys())
 
     def test_with_peft(self, tmp_path, wordvec_pretrain_file):

stanza/tests/lemma/test_lemma_trainer.py

+1-1
@@ -150,5 +150,5 @@ def test_charlm_train(self, tmp_path, charlm_args):
         # check that the charlm wasn't saved in here
         args = saved_model.args
         save_name = os.path.join(args['save_dir'], args['save_name'])
-        checkpoint = torch.load(save_name, lambda storage, loc: storage)
+        checkpoint = torch.load(save_name, lambda storage, loc: storage, weights_only=True)
         assert not any(x.startswith("contextual_embedding") for x in checkpoint['model'].keys())

stanza/tests/ner/test_ner_training.py

+2-2
@@ -226,7 +226,7 @@ def test_train_model_cpu(pretrain_file, tmp_path):
     assert str(device).startswith("cpu")
 
 def model_file_has_bert(filename):
-    checkpoint = torch.load(filename, lambda storage, loc: storage)
+    checkpoint = torch.load(filename, lambda storage, loc: storage, weights_only=True)
     return any(x.startswith("bert_model.") for x in checkpoint['model'].keys())
 
 def test_with_bert(pretrain_file, tmp_path):
@@ -253,7 +253,7 @@ def test_with_peft_finetune(pretrain_file, tmp_path):
     # TODO: check that the peft tensors are moving when training?
     trainer = run_training(pretrain_file, tmp_path, '--bert_model', 'hf-internal-testing/tiny-bert', '--use_peft')
     model_file = os.path.join(trainer.args['save_dir'], trainer.args['save_name'])
-    checkpoint = torch.load(model_file, lambda storage, loc: storage)
+    checkpoint = torch.load(model_file, lambda storage, loc: storage, weights_only=True)
     assert 'bert_lora' in checkpoint
     assert not any(x.startswith("bert_model.") for x in checkpoint['model'].keys())
