Skip to content

Commit 5f6d41c

Browse files
Committed with message: "make pytorch=1.1 and cuda 10, should work"
1 parent 664225e · commit 5f6d41c

File tree

7 files changed

+36
-18
lines changed

7 files changed

+36
-18
lines changed

components/vocab.py

+9-3
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
11
# coding=utf-8
22

33
from __future__ import print_function
4-
import argparse
4+
55
from collections import Counter
66
from itertools import chain
7-
import torch
7+
88

99
class VocabEntry(object):
1010
def __init__(self):
@@ -46,6 +46,11 @@ def add(self, word):
4646
def is_unk(self, word):
4747
return word not in self
4848

49+
def merge(self, other_vocab_entry):
50+
for word in other_vocab_entry.word2id:
51+
self.add(word)
52+
53+
4954
@staticmethod
5055
def from_corpus(corpus, size, freq_cutoff=0):
5156
vocab_entry = VocabEntry()
@@ -55,7 +60,8 @@ def from_corpus(corpus, size, freq_cutoff=0):
5560
singletons = [w for w in word_freq if word_freq[w] == 1]
5661
print('number of word types: %d, number of word types w/ frequency > 1: %d' % (len(word_freq),
5762
len(non_singletons)))
58-
print('singletons: %s' % singletons)
63+
print('number of singletons: ', len(singletons))
64+
# print('singletons: %s' % singletons)
5965

6066
top_k_words = sorted(word_freq.keys(), reverse=True, key=word_freq.get)[:size]
6167
words_not_included = []

config/env/frank.yml

+6-9
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,14 @@
11
name: tranX
22
channels:
3+
- pytorch
34
- anaconda
45
- defaults
56
dependencies:
6-
- cudatoolkit=10.1.168=0
7-
- cudnn=7.6.0=cuda10.1_0
8-
- cupti=10.1.168
7+
- astor=0.7.1
8+
- cudatoolkit=10.0.130
99
- python=3.7.3
10+
- pytorch=1.1.0
1011
- pip:
11-
- compare-mt==0.2.7
12-
- elasticsearch==7.0.5
13-
- six==1.12.0
14-
- xgboost==0.90
15-
- torch==1.0.1.post2
16-
- astor==0.7.1
12+
- six
13+
- xgboost
1714
- tqdm

datasets/conala/dataset.py

+14
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from datasets.conala.evaluator import ConalaEvaluator
1616
from datasets.conala.util import *
1717

18+
assert astor.__version__ == '0.7.1'
1819

1920
def preprocess_conala_dataset(train_file, test_file, grammar_file, src_freq=3, code_freq=3,
2021
mined_data_file=None, vocab_size=20000, num_mined=0, out_dir='data/conala'):
@@ -33,11 +34,23 @@ def preprocess_conala_dataset(train_file, test_file, grammar_file, src_freq=3, c
3334
dev_examples = train_examples[:200]
3435
train_examples = train_examples[200:]
3536

37+
mined_examples = None
3638
if mined_data_file and num_mined > 0:
3739
print("use mined data: ", num_mined)
3840
print("from file: ", mined_data_file)
3941
mined_examples = preprocess_dataset(mined_data_file, name='mined', transition_system=transition_system,
4042
firstk=num_mined)
43+
# mined_src_vocab = VocabEntry.from_corpus([e.src_sent for e in train_examples], size=vocab_size,
44+
# freq_cutoff=src_freq)
45+
# mined_primitive_tokens = [map(lambda a: a.action.token,
46+
# filter(lambda a: isinstance(a.action, GenTokenAction), e.tgt_actions))
47+
# for e in train_examples]
48+
# mined_primitive_vocab = VocabEntry.from_corpus(mined_primitive_tokens, size=vocab_size, freq_cutoff=code_freq)
49+
#
50+
# # generate vocabulary for the code tokens!
51+
# mined_code_tokens = [transition_system.tokenize_code(e.tgt_code, mode='decoder') for e in train_examples]
52+
# mined_code_vocab = VocabEntry.from_corpus(mined_code_tokens, size=vocab_size, freq_cutoff=code_freq)
53+
4154
pickle.dump(mined_examples, open(os.path.join(out_dir, 'pre_{}.bin'.format(num_mined)), 'wb'))
4255
train_examples += mined_examples
4356

@@ -57,6 +70,7 @@ def preprocess_conala_dataset(train_file, test_file, grammar_file, src_freq=3, c
5770

5871
# generate vocabulary for the code tokens!
5972
code_tokens = [transition_system.tokenize_code(e.tgt_code, mode='decoder') for e in train_examples]
73+
6074
code_vocab = VocabEntry.from_corpus(code_tokens, size=vocab_size, freq_cutoff=code_freq)
6175

6276
vocab = Vocab(source=src_vocab, primitive=primitive_vocab, code=code_vocab)

exp.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
import time
55

6+
import astor
67
import six.moves.cPickle as pickle
78
from six.moves import input
89
from six.moves import xrange as range
@@ -21,6 +22,7 @@
2122
from model.reconstruction_model import Reconstructor
2223
from model.utils import GloveHelper
2324

25+
assert astor.__version__ == "0.7.1"
2426
if six.PY3:
2527
# import additional packages for wikisql dataset (works only under Python 3)
2628
pass
@@ -146,7 +148,7 @@ def train(args):
146148
print('[Epoch %d] begin validation' % epoch, file=sys.stderr)
147149
eval_start = time.time()
148150
eval_results = evaluation.evaluate(dev_set.examples, model, evaluator, args,
149-
verbose=True, eval_top_pred_only=args.eval_top_pred_only)
151+
verbose=False, eval_top_pred_only=args.eval_top_pred_only)
150152
dev_score = eval_results[evaluator.default_metric]
151153

152154
print('[Epoch %d] evaluate details: %s, dev %s: %.5f (took %ds)' % (

model/nn_utils.py

-2
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,7 @@
33
import torch
44
import torch.nn.functional as F
55
import torch.nn.init as init
6-
import numpy as np
76

8-
import torch
97
import torch.nn as nn
108
from torch.autograd import Variable
119
import numpy as np

scripts/conala/finetune.sh

+3-2
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,9 @@ set -e
33

44
seed=0
55
mined_num=$1
6-
vocab="data/conala/vocab.src_freq3.code_freq3.mined_${mined_num}.bin"
6+
pretrained_model_name=$2
7+
freq=${3:-3}
8+
vocab="data/conala/vocab.src_freq${freq}.code_freq${freq}.mined_${mined_num}.bin"
79
finetune_file="data/conala/train.bin"
810
dev_file="data/conala/dev.bin"
911
dropout=0.3
@@ -17,7 +19,6 @@ lr_decay=0.5
1719
beam_size=15
1820
lstm='lstm' # lstm
1921
lr_decay_after_epoch=15
20-
pretrained_model_name=$2
2122
model_name=finetune.conala.${lstm}.hidden${hidden_size}.embed${embed_size}.action${action_embed_size}.field${field_embed_size}.type${type_embed_size}.dr${dropout}.lr${lr}.lr_de${lr_decay}.lr_da${lr_decay_after_epoch}.beam${beam_size}.seed${seed}.pre_${mined_num}
2223

2324
echo "**** Writing results to logs/conala/${model_name}.log ****"

scripts/conala/vanilla.sh

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
#!/bin/bash
22
set -e
33

4-
seed=0
4+
seed=${1:-0}
55
vocab="data/conala/vocab.src_freq3.code_freq3.mined_0.bin"
66
train_file="data/conala/train.mined_0.bin"
77
dev_file="data/conala/dev.bin"

0 commit comments

Comments (0)