
Commit 32e6026

committed Feb 5, 2020
Bulk updates.
1 parent 2ef31d0 commit 32e6026

File tree

2 files changed, +135 −164 lines changed

 

cscg.ipynb

+134 −163
@@ -9,8 +9,6 @@
 "import sys\n",
 "sys.path.insert(0, './language_model/')\n",
 "\n",
-"# for suppressing T.save warnings\n",
-"# see https://discuss.pyT.org/t/got-warning-couldnt-retrieve-source-code-for-container/7689\n",
 "import warnings\n",
 "warnings.simplefilter('ignore', UserWarning)"
 ]
@@ -77,7 +75,7 @@
 "DJANGO_DIR = os.path.join(ROOT_DIR, 'raw-datasets/django')\n",
 "CONALA_DIR = os.path.join(ROOT_DIR, 'raw-datasets/conala-corpus')\n",
 "\n",
-"DATASET_DIR = CONALA_DIR\n",
+"DATASET_DIR = DJANGO_DIR\n",
 "EMB_DIR = os.path.join(ROOT_DIR, 'embeddings')\n",
 "\n",
 "print(f'Dataset: {os.path.basename(DATASET_DIR)}')"
@@ -103,8 +101,8 @@
 "d = pd.DataFrame([{'a': _a, 'c': _c} for (_a, _c) in zip(a, c)])\n",
 "d.describe()\n",
 "\n",
-"a = round(len(list(filter(lambda x: x <= 10, a))) / len(a), 3)\n",
-"c = round(len(list(filter(lambda x: x <= 10, c))) / len(c), 3)\n",
+"a = round(len(list(filter(lambda x: x <= 24, a))) / len(a), 3)\n",
+"c = round(len(list(filter(lambda x: x <= 20, c))) / len(c), 3)\n",
 "a, c"
 ]
 },
@@ -127,10 +125,10 @@
 "CFG.dataset_cfg = Config()\n",
 "CFG.dataset_cfg.__dict__ = {\n",
 " 'root_dir': DATASET_DIR,\n",
-" 'anno_min_freq': 1,\n",
-" 'code_min_freq': 1,\n",
-" 'anno_seq_maxlen': 10,\n",
-" 'code_seq_maxlen': 10,\n",
+" 'anno_min_freq': 10,\n",
+" 'code_min_freq': 10,\n",
+" 'anno_seq_maxlen': 24,\n",
+" 'code_seq_maxlen': 20,\n",
 " 'emb_file': os.path.join(EMB_DIR, 'glove.6B.200d-ft-9-1.txt.pickle'),\n",
 "}\n",
 "\n",
@@ -140,8 +138,8 @@
 "CFG.anno = Config() \n",
 "CFG.anno.__dict__ = {\n",
 " 'lstm_hidden_size': 64,\n",
-" 'lstm_dropout_p': 0.0,\n",
-" 'att_dropout_p': 0.0,\n",
+" 'lstm_dropout_p': 0.2,\n",
+" 'att_dropout_p': 0.1,\n",
 " 'lang': dataset.anno_lang,\n",
 " 'load_pretrained_emb': True,\n",
 " 'emb_size': 200,\n",
@@ -151,15 +149,15 @@
 "CFG.code = Config() \n",
 "CFG.code.__dict__ = {\n",
 " 'lstm_hidden_size': 64,\n",
-" 'lstm_dropout_p': 0.0,\n",
-" 'att_dropout_p': 0.0,\n",
+" 'lstm_dropout_p': 0.2,\n",
+" 'att_dropout_p': 0.1,\n",
 " 'lang': dataset.code_lang,\n",
 " 'load_pretrained_emb': False,\n",
-" 'emb_size': 50,\n",
+" 'emb_size': 32,\n",
 "}\n",
 "\n",
 "CFG.__dict__.update({\n",
-" 'exp_name': f'{os.path.basename(DATASET_DIR)}-p{1}-a{1}',\n",
+" 'exp_name': f'{os.path.basename(DATASET_DIR)}-p{0}-a{0}-minfreq10',\n",
 " 'cuda': True,\n",
 " 'batch_size': 128,\n",
 " 'num_epochs': 50,\n",
@@ -337,7 +335,7 @@
 "for f in lm_paths.values():\n",
 " assert os.path.exists(f), f'Language Model: file <{f}> does not exist!'\n",
 " \n",
-"dataset.compute_lm_probs(lm_paths)"
+"_ = dataset.compute_lm_probs(lm_paths)"
 ]
 },
 {
@@ -465,15 +463,9 @@
 " emb.weight = nn.Parameter(T.tensor(config.lang.emb_matrix, dtype=T.float32))\n",
 " emb.weight.requires_grad = False\n",
 " \n",
-" return emb"
-]
-},
-{
-"cell_type": "code",
-"execution_count": null,
-"metadata": {},
-"outputs": [],
-"source": [
+" return emb\n",
+"\n",
+"\n",
 "class Model(nn.Module):\n",
 " def __init__(self, config: Config, model_type):\n",
 " \"\"\"\n",
@@ -682,35 +674,7 @@
 " xa = T.sum(xa, dim=1) / n\n",
 " xb = T.sum(xb, dim=1) / n\n",
 " \n",
-" return 0.5 * (xa + xb)\n",
-"\n",
-"def JSD_2(A, B, mask=None):\n",
-" eps = 1e-8\n",
-" \n",
-" assert A.shape == B.shape\n",
-" b, n, m = A.shape\n",
-" \n",
-" js = []\n",
-" for bi in range(b):\n",
-" kl_a, kl_b = 0, 0\n",
-" \n",
-" for i in range(n):\n",
-" a = A[bi, i, :]\n",
-" b = B[bi, i, :]\n",
-" \n",
-" if mask is not None:\n",
-" a[mask[i]] = -(1e8)\n",
-" b[mask[i]] = -(1e8)\n",
-" \n",
-" a = F.softmax(a) + eps\n",
-" b = F.softmax(b) + eps\n",
-" m = 0.5 * (a + b)\n",
-" kl_a += stats.entropy(a, m) / n\n",
-" kl_b += stats.entropy(b, m) / n\n",
-" \n",
-" js += [0.5 * (kl_a + kl_b)]\n",
-" \n",
-" return T.tensor(js)"
+" return 0.5 * (xa + xb)"
 ]
 },
 {
@@ -802,8 +766,8 @@
 " \n",
 " # final loss\n",
 " p, a = 0, 0\n",
-" l_cg = T.mean(l_cg_ce + p * 0.01 * l_dual + a * 0.2 * l_att)\n",
-" l_cs = T.mean(l_cs_ce + p * 0.01 * l_dual + a * 0.2 * l_att)\n",
+" l_cg = T.mean(l_cg_ce + p * 0.5 * l_dual + a * 0.9 * l_att)\n",
+" l_cs = T.mean(l_cs_ce + p * 0.5 * l_dual + a * 0.9 * l_att)\n",
 " \n",
 " # optimize CG\n",
 " cg_model.opt.zero_grad()\n",
@@ -848,7 +812,9 @@
 "outputs": [],
 "source": [
 "torch.save(cg_model.state_dict(), os.path.join(exp_dir, 'cg_model.pt'))\n",
-"torch.save(cs_model.state_dict(), os.path.join(exp_dir, 'cs_model.pt'))"
+"torch.save(cs_model.state_dict(), os.path.join(exp_dir, 'cs_model.pt'))\n",
+"\n",
+"tb_writer.close()"
 ]
 },
 {
@@ -858,6 +824,24 @@
 "# 5. Evaluate"
 ]
 },
+{
+"cell_type": "code",
+"execution_count": null,
+"metadata": {},
+"outputs": [],
+"source": [
+"cg_model = Model(CFG, model_type='cg')\n",
+"cs_model = Model(CFG, model_type='cs')\n",
+"\n",
+"# exp_dir = f'./experiments/{CFG.exp_name}'\n",
+"exp_dir = f'./experiments/{os.path.basename(DATASET_DIR)}-p{0}-a{1}'\n",
+"\n",
+"cg_model.load_state_dict(torch.load(os.path.join(exp_dir, 'cg_model.pt')))\n",
+"cs_model.load_state_dict(torch.load(os.path.join(exp_dir, 'cs_model.pt')))\n",
+"\n",
+"exp_dir"
+]
+},
 {
 "cell_type": "markdown",
 "metadata": {},
@@ -872,65 +856,21 @@
 "outputs": [],
 "source": [
 "def is_valid_code(line):\n",
+" \"valid <=> (complete ^ valid) v (incomplete ^ valid_prefix)\"\n",
 " try:\n",
 " codeop.compile_command(line)\n",
 " except SyntaxError:\n",
 " return False\n",
-" else:\n",
-" return True"
-]
-},
-{
-"cell_type": "code",
-"execution_count": 54,
-"metadata": {},
-"outputs": [
-{
-"name": "stdout",
-"output_type": "stream",
-"text": [
-"tensor([1, 2, 3, 4, 5, 6, 7, 8, 9, 1, 2, 3])\n",
-"tensor([1, 2, 3, 4, 4, 4, 7, 1, 1, 2, 2, 9])\n",
-"tensor([1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 1])\n"
-]
-}
-],
-"source": [
-"import torch as T\n",
-"\n",
-"x = T.tensor([1,2,3,4,5,6,7,8,9,1,2,3])\n",
-"y = T.tensor([1,2,3,4,4,4,7,1,1,2,2,9])\n",
-"m = T.tensor([1,1,1,1,1,1,1,0,0,0,0,1])\n",
-"\n",
-"code_pred = y\n",
-"code = x\n",
-"code_mask = m\n",
+" \n",
+" return True\n",
 "\n",
-"print(x)\n",
-"print(y)\n",
-"print(m)"
-]
-},
-{
-"cell_type": "code",
-"execution_count": 55,
-"metadata": {},
-"outputs": [
-{
-"data": {
-"text/plain": [
-"(tensor(0.4167), tensor(0.6250), 0.625, tensor(0.6250))"
-]
-},
-"execution_count": 55,
-"metadata": {},
-"output_type": "execute_result"
-}
-],
-"source": [
-"r1 = T.mean(((x == y) * m).float()).cpu()\n",
-"r2 = ((x == y) * m).float().sum() / m.sum()\n",
-"r1, r2, (5 / 8), (((code_pred == code) * code_mask).float().sum() / code_mask.sum()).cpu()"
+"def to_tok(xs, mode):\n",
+" z = (xs)[0].cpu()\n",
+" z = z[(z!=0)&(z!=1)&(z!=2)&(z!=3)]\n",
+" if mode == 'code':\n",
+" return dataset.code_lang.to_tokens(z)[0]\n",
+" if mode == 'anno':\n",
+" return dataset.anno_lang.to_tokens(z)[0]"
 ]
 },
 {
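The is_valid_code helper added in this hunk leans on codeop.compile_command, which returns None for an incomplete-but-well-formed prefix and raises SyntaxError only on genuinely invalid input. A small standalone sketch; the example strings are made up:

import codeop

def is_valid_code(line):
    "valid <=> (complete ^ valid) v (incomplete ^ valid_prefix)"
    try:
        codeop.compile_command(line)
    except SyntaxError:
        return False

    return True

print(is_valid_code('x = len ( y )'))   # True: complete, valid statement
print(is_valid_code('x = [ 1 , 2'))     # True: valid prefix, merely incomplete
print(is_valid_code('x = = 1'))         # False: SyntaxError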
@@ -950,14 +890,24 @@
 "# ---\n",
 "\n",
 "test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False)\n",
-"\n",
+"assert len(test_loader) == len(dataset) - n"
+]
+},
+{
+"cell_type": "code",
+"execution_count": null,
+"metadata": {},
+"outputs": [],
+"source": [
 "ms = ['ind_match', 'exact_match', 'coverage']\n",
 "metrics = {\n",
 " 'anno': {k: 0 for k in ms},\n",
 " 'code': {k: 0 for k in ms}\n",
 "}\n",
 "metrics['code']['pov'] = 0\n",
 "\n",
+"anno_toks, code_toks = [], []\n",
+"\n",
 "with T.no_grad():\n",
 " cg_model.eval()\n",
 " cs_model.eval()\n",
@@ -967,8 +917,11 @@
 " anno, code = anno.cuda(), code.cuda() \n",
 " \n",
 " # binary mask indicating the presence of padding token\n",
-" anno_mask = T.tensor(anno != dataset.anno_lang.token2index['<pad>']).byte()\n",
-" code_mask = T.tensor(code != dataset.code_lang.token2index['<pad>']).byte()\n",
+"# anno_mask = T.tensor(anno != dataset.anno_lang.token2index['<pad>']).byte()\n",
+"# code_mask = T.tensor(code != dataset.code_lang.token2index['<pad>']).byte()\n",
+"\n",
+" anno_mask = T.tensor((anno != 0) * (anno != 1)).byte()\n",
+" code_mask = T.tensor((code != 0) * (code != 1)).byte()\n",
 " \n",
 " # forward pass\n",
 " code_pred, code_att_mat = cg_model(src=anno, tgt=code)\n",
@@ -992,36 +945,72 @@
 " metrics['anno']['exact_match'] += 1 / len(test_loader)\n",
 " \n",
 " # 3)\n",
-" cc = set([x.item() for x in code[0].cpu().data if x.item() != 0]) - \\\n",
-" set([x.item() for x in code_pred[0].cpu().data if x.item() != 0])\n",
-" if len(cc) == 0:\n",
+" sy = set([x.item() for x in (code * code_mask)[0].cpu().data if x.item() != 0])\n",
+" sy_ = set([x.item() for x in (code_pred * code_mask)[0].cpu().data if x.item() != 0])\n",
+" if len(set.difference(sy_, sy)) == 0:\n",
 " metrics['code']['coverage'] += 1 / len(test_loader)\n",
+" else:\n",
+" if np.isclose(code_score, 1):\n",
+" print(set.difference(sy_, sy))\n",
 " \n",
-" ac = set([x.item() for x in anno[0].cpu().data if x.item() != 0]) - \\\n",
-" set([x.item() for x in anno_pred[0].cpu().data if x.item() != 0])\n",
-" if len(ac) == 0:\n",
+" sy = set([x.item() for x in (anno * anno_mask)[0].cpu().data if x.item() != 0])\n",
+" sy_ = set([x.item() for x in (anno_pred * anno_mask)[0].cpu().data if x.item() != 0])\n",
+" if len(set.difference(sy_, sy)) == 0:\n",
 " metrics['anno']['coverage'] += 1 / len(test_loader)\n",
 " \n",
 " # 4)\n",
-" c = (code_pred * code_mask)[0].cpu()\n",
-" c = c[(c!=0)&(c!=1)&(c!=2)&(c!=3)]\n",
-" c = dataset.code_lang.to_tokens(c)[0]\n",
-" if is_valid_code(' '.join(c)):\n",
+" if is_valid_code(' '.join(to_tok(code_pred * code_mask, 'code'))):\n",
 " metrics['code']['pov'] += 1 / len(test_loader)\n",
+"\n",
+" # save tokens\n",
+" code_toks += [(round(code_score.item(), 5), \n",
+" to_tok(code_pred * code_mask, 'code'), \n",
+" to_tok(code * code_mask, 'code'),\n",
+" code_pred[0].cpu(),\n",
+" code[0].cpu())]\n",
+" \n",
+" anno_toks += [(round(anno_score.item(), 5), \n",
+" to_tok(anno_pred * anno_mask, 'anno'), \n",
+" to_tok(anno * anno_mask, 'anno'),\n",
+" anno_pred[0].cpu(),\n",
+" anno[0].cpu())]\n",
 " \n",
+"code_toks = sorted(code_toks, key=lambda x: x[0])\n",
+"anno_toks = sorted(anno_toks, key=lambda x: x[0])\n",
+"\n",
+"with open(os.path.join(exp_dir, 'eval_code.txt'), 'wt') as fp:\n",
+" for i, (s, pt, tt, p, t) in enumerate(code_toks, start=1):\n",
+" fp.write(f'{i}\\n')\n",
+" fp.write(f'{s}\\n')\n",
+" fp.write(f'pred: {\" \".join(pt)}\\n')\n",
+" fp.write(f'true: {\" \".join(tt)}\\n')\n",
+" fp.write(f'pred_raw: {p}\\n')\n",
+" fp.write(f'true_raw: {t}\\n')\n",
+" fp.write(f'{\"-\"*80}\\n')\n",
+" \n",
+"with open(os.path.join(exp_dir, 'eval_anno.txt'), 'wt') as fp:\n",
+" for i, (s, pt, tt, p, t) in enumerate(anno_toks, start=1):\n",
+" fp.write(f'{i}\\n')\n",
+" fp.write(f'{s}\\n')\n",
+" fp.write(f'pred: {\" \".join(pt)}\\n')\n",
+" fp.write(f'true: {\" \".join(tt)}\\n')\n",
+" fp.write(f'pred_raw: {p}\\n')\n",
+" fp.write(f'true_raw: {t}\\n')\n",
+" fp.write(f'{\"-\"*80}\\n')\n",
 "\n",
 "# results\n",
-"for t in metrics:\n",
-" print(t)\n",
-" for k, v in metrics[t].items():\n",
-" print(f'{k:>16s}: {v:7.5f}')"
+"print(exp_dir.split('/')[-1])\n",
+"print(len(test_loader))\n",
+"for k in ms:\n",
+" print(f\"{metrics['anno'][k]:7.5f}/{metrics['code'][k]:7.5f}\", end=' ')\n",
+"print(round(metrics['code']['pov'], 5))"
 ]
 },
 {
 "cell_type": "markdown",
 "metadata": {},
 "source": [
-"## 5.2. Translate"
+"## 5.2. Attention matrices"
 ]
 },
 {
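Similarly, a toy run of the set-based coverage check introduced in this hunk; the id tensors and the mask below are fabricated for illustration only:

import torch as T

# fabricated reference ids, predicted ids, and padding mask (1 = keep)
code      = T.tensor([[2, 15, 7, 42, 3, 0, 0]])
code_pred = T.tensor([[2, 15, 42, 7, 3, 9, 9]])
code_mask = T.tensor([[1, 1, 1, 1, 1, 0, 0]])

sy  = set(x.item() for x in (code * code_mask)[0] if x.item() != 0)
sy_ = set(x.item() for x in (code_pred * code_mask)[0] if x.item() != 0)

# covered when the masked prediction introduces no ids outside the reference set
print(len(set.difference(sy_, sy)) == 0)  # True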
@@ -1030,54 +1019,36 @@
 "metadata": {},
 "outputs": [],
 "source": [
-"cg_model = Model(CFG, model_type='cg')\n",
-"cs_model = Model(CFG, model_type='cs')\n",
-"\n",
-"exp_dir = f'./experiments/{CFG.exp_name}'\n",
+"a = T.tensor([ 2, 576, 16, 84, 474, 695, 0, 0, 0, 3])\n",
+"c = T.tensor([ 2, 155, 489, 10, 159, 5, 8, 0, 0, 3])\n",
 "\n",
-"cg_model.load_state_dict(torch.load(os.path.join(exp_dir, 'cg_model.pt')))\n",
-"cs_model.load_state_dict(torch.load(os.path.join(exp_dir, 'cs_model.pt')))\n",
-"\n",
-"exp_dir"
-]
-},
-{
-"cell_type": "markdown",
-"metadata": {},
-"source": [
-"## 5.3. Attention matrices"
-]
-},
-{
-"cell_type": "code",
-"execution_count": null,
-"metadata": {},
-"outputs": [],
-"source": [
 "with T.no_grad():\n",
-" i = np.random.randint(len(train_dataset))\n",
-" a, c, _, _ = train_dataset[i]\n",
+" i = np.random.randint(len(test_dataset))\n",
+"# i = 5557\n",
+" a, c, _, _ = test_dataset[-1]\n",
 " a, c = a.cuda(), c.cuda()\n",
+" anno_mask = T.tensor((a != 0) * (a != 1)).byte().cuda()\n",
+" code_mask = T.tensor((c != 0) * (c != 1)).byte().cuda()\n",
 " x, x_mat = cg_model(src=a.unsqueeze(0), tgt=c.unsqueeze(0))\n",
 " y, y_mat = cs_model(src=c.unsqueeze(0), tgt=a.unsqueeze(0))\n",
-" x = x[0].argmax(dim=1).cpu()\n",
+" x = x[0].argmax(dim=-1)\n",
 " x_mat = x_mat[0].cpu()\n",
-" y = y[0].argmax(dim=1).cpu()\n",
+" y = y[0].argmax(dim=-1)\n",
 " y_mat = y_mat[0].cpu()\n",
 " \n",
-" ct = dataset.code_lang.to_tokens(c)[0]\n",
-" at = dataset.anno_lang.to_tokens(a)[0]\n",
-" xt = dataset.code_lang.to_tokens(x)[0]\n",
-" yt = dataset.anno_lang.to_tokens(y)[0]\n",
+" ct = to_tok((c * code_mask).unsqueeze(0), 'code')\n",
+" xt = to_tok((x * code_mask).unsqueeze(0), 'code')\n",
+" at = to_tok((a * anno_mask).unsqueeze(0), 'anno')\n",
+" yt = to_tok((y * anno_mask).unsqueeze(0), 'anno')\n",
 " \n",
 "\n",
-"plt.figure(figsize=(16, 10))\n",
+"plt.figure(figsize=(12, 8))\n",
 "\n",
 "# plt.subplot(1, 2, 1)\n",
 "plt.imshow(F.softmax(y_mat, -1), cmap='jet')\n",
 "plt.grid(False)\n",
-"# plt.xticks(np.arange(len(ct)), labels=ct, rotation=90)\n",
-"# plt.yticks(np.arange(len(at)), labels=at)\n",
+"plt.xticks(np.arange(len(ct)), labels=ct, rotation=90)\n",
+"plt.yticks(np.arange(len(at)), labels=at)\n",
 "\n",
 "# plt.subplot(1, 2, 2)\n",
 "# plt.imshow(F.softmax(y_mat, -1), cmap='jet')\n",

notes.ipynb

+1 −1
@@ -92,7 +92,7 @@
 "name": "python",
 "nbconvert_exporter": "python",
 "pygments_lexer": "ipython3",
-"version": "3.7.5"
+"version": "3.7.6"
 }
 },
 "nbformat": 4,
