|
9 | 9 | "import sys\n",
|
10 | 10 | "sys.path.insert(0, './language_model/')\n",
|
11 | 11 | "\n",
|
12 |
| - "# for suppressing T.save warnings\n", |
13 |
| - "# see https://discuss.pyT.org/t/got-warning-couldnt-retrieve-source-code-for-container/7689\n", |
14 | 12 | "import warnings\n",
|
15 | 13 | "warnings.simplefilter('ignore', UserWarning)"
|
16 | 14 | ]
|
|
77 | 75 | "DJANGO_DIR = os.path.join(ROOT_DIR, 'raw-datasets/django')\n",
|
78 | 76 | "CONALA_DIR = os.path.join(ROOT_DIR, 'raw-datasets/conala-corpus')\n",
|
79 | 77 | "\n",
|
80 |
| - "DATASET_DIR = CONALA_DIR\n", |
| 78 | + "DATASET_DIR = DJANGO_DIR\n", |
81 | 79 | "EMB_DIR = os.path.join(ROOT_DIR, 'embeddings')\n",
|
82 | 80 | "\n",
|
83 | 81 | "print(f'Dataset: {os.path.basename(DATASET_DIR)}')"
|
|
103 | 101 | "d = pd.DataFrame([{'a': _a, 'c': _c} for (_a, _c) in zip(a, c)])\n",
|
104 | 102 | "d.describe()\n",
|
105 | 103 | "\n",
|
106 |
| - "a = round(len(list(filter(lambda x: x <= 10, a))) / len(a), 3)\n", |
107 |
| - "c = round(len(list(filter(lambda x: x <= 10, c))) / len(c), 3)\n", |
| 104 | + "a = round(len(list(filter(lambda x: x <= 24, a))) / len(a), 3)\n", |
| 105 | + "c = round(len(list(filter(lambda x: x <= 20, c))) / len(c), 3)\n", |
108 | 106 | "a, c"
|
109 | 107 | ]
|
110 | 108 | },
|
|
127 | 125 | "CFG.dataset_cfg = Config()\n",
|
128 | 126 | "CFG.dataset_cfg.__dict__ = {\n",
|
129 | 127 | " 'root_dir': DATASET_DIR,\n",
|
130 |
| - " 'anno_min_freq': 1,\n", |
131 |
| - " 'code_min_freq': 1,\n", |
132 |
| - " 'anno_seq_maxlen': 10,\n", |
133 |
| - " 'code_seq_maxlen': 10,\n", |
| 128 | + " 'anno_min_freq': 10,\n", |
| 129 | + " 'code_min_freq': 10,\n", |
| 130 | + " 'anno_seq_maxlen': 24,\n", |
| 131 | + " 'code_seq_maxlen': 20,\n", |
134 | 132 | " 'emb_file': os.path.join(EMB_DIR, 'glove.6B.200d-ft-9-1.txt.pickle'),\n",
|
135 | 133 | "}\n",
|
136 | 134 | "\n",
|
|
140 | 138 | "CFG.anno = Config() \n",
|
141 | 139 | "CFG.anno.__dict__ = {\n",
|
142 | 140 | " 'lstm_hidden_size': 64,\n",
|
143 |
| - " 'lstm_dropout_p': 0.0,\n", |
144 |
| - " 'att_dropout_p': 0.0,\n", |
| 141 | + " 'lstm_dropout_p': 0.2,\n", |
| 142 | + " 'att_dropout_p': 0.1,\n", |
145 | 143 | " 'lang': dataset.anno_lang,\n",
|
146 | 144 | " 'load_pretrained_emb': True,\n",
|
147 | 145 | " 'emb_size': 200,\n",
|
|
151 | 149 | "CFG.code = Config() \n",
|
152 | 150 | "CFG.code.__dict__ = {\n",
|
153 | 151 | " 'lstm_hidden_size': 64,\n",
|
154 |
| - " 'lstm_dropout_p': 0.0,\n", |
155 |
| - " 'att_dropout_p': 0.0,\n", |
| 152 | + " 'lstm_dropout_p': 0.2,\n", |
| 153 | + " 'att_dropout_p': 0.1,\n", |
156 | 154 | " 'lang': dataset.code_lang,\n",
|
157 | 155 | " 'load_pretrained_emb': False,\n",
|
158 |
| - " 'emb_size': 50,\n", |
| 156 | + " 'emb_size': 32,\n", |
159 | 157 | "}\n",
|
160 | 158 | "\n",
|
161 | 159 | "CFG.__dict__.update({\n",
|
162 |
| - " 'exp_name': f'{os.path.basename(DATASET_DIR)}-p{1}-a{1}',\n", |
| 160 | + " 'exp_name': f'{os.path.basename(DATASET_DIR)}-p{0}-a{0}-minfreq10',\n", |
163 | 161 | " 'cuda': True,\n",
|
164 | 162 | " 'batch_size': 128,\n",
|
165 | 163 | " 'num_epochs': 50,\n",
|
|
337 | 335 | "for f in lm_paths.values():\n",
|
338 | 336 | " assert os.path.exists(f), f'Language Model: file <{f}> does not exist!'\n",
|
339 | 337 | " \n",
|
340 |
| - "dataset.compute_lm_probs(lm_paths)" |
| 338 | + "_ = dataset.compute_lm_probs(lm_paths)" |
341 | 339 | ]
|
342 | 340 | },
|
343 | 341 | {
|
|
465 | 463 | " emb.weight = nn.Parameter(T.tensor(config.lang.emb_matrix, dtype=T.float32))\n",
|
466 | 464 | " emb.weight.requires_grad = False\n",
|
467 | 465 | " \n",
|
468 |
| - " return emb" |
469 |
| - ] |
470 |
| - }, |
471 |
| - { |
472 |
| - "cell_type": "code", |
473 |
| - "execution_count": null, |
474 |
| - "metadata": {}, |
475 |
| - "outputs": [], |
476 |
| - "source": [ |
| 466 | + " return emb\n", |
| 467 | + "\n", |
| 468 | + "\n", |
477 | 469 | "class Model(nn.Module):\n",
|
478 | 470 | " def __init__(self, config: Config, model_type):\n",
|
479 | 471 | " \"\"\"\n",
|
|
682 | 674 | " xa = T.sum(xa, dim=1) / n\n",
|
683 | 675 | " xb = T.sum(xb, dim=1) / n\n",
|
684 | 676 | " \n",
|
685 |
| - " return 0.5 * (xa + xb)\n", |
686 |
| - "\n", |
687 |
| - "def JSD_2(A, B, mask=None):\n", |
688 |
| - " eps = 1e-8\n", |
689 |
| - " \n", |
690 |
| - " assert A.shape == B.shape\n", |
691 |
| - " b, n, m = A.shape\n", |
692 |
| - " \n", |
693 |
| - " js = []\n", |
694 |
| - " for bi in range(b):\n", |
695 |
| - " kl_a, kl_b = 0, 0\n", |
696 |
| - " \n", |
697 |
| - " for i in range(n):\n", |
698 |
| - " a = A[bi, i, :]\n", |
699 |
| - " b = B[bi, i, :]\n", |
700 |
| - " \n", |
701 |
| - " if mask is not None:\n", |
702 |
| - " a[mask[i]] = -(1e8)\n", |
703 |
| - " b[mask[i]] = -(1e8)\n", |
704 |
| - " \n", |
705 |
| - " a = F.softmax(a) + eps\n", |
706 |
| - " b = F.softmax(b) + eps\n", |
707 |
| - " m = 0.5 * (a + b)\n", |
708 |
| - " kl_a += stats.entropy(a, m) / n\n", |
709 |
| - " kl_b += stats.entropy(b, m) / n\n", |
710 |
| - " \n", |
711 |
| - " js += [0.5 * (kl_a + kl_b)]\n", |
712 |
| - " \n", |
713 |
| - " return T.tensor(js)" |
| 677 | + " return 0.5 * (xa + xb)" |
714 | 678 | ]
|
715 | 679 | },
|
716 | 680 | {
|
|
802 | 766 | " \n",
|
803 | 767 | " # final loss\n",
|
804 | 768 | " p, a = 0, 0\n",
|
805 |
| - " l_cg = T.mean(l_cg_ce + p * 0.01 * l_dual + a * 0.2 * l_att)\n", |
806 |
| - " l_cs = T.mean(l_cs_ce + p * 0.01 * l_dual + a * 0.2 * l_att)\n", |
| 769 | + " l_cg = T.mean(l_cg_ce + p * 0.5 * l_dual + a * 0.9 * l_att)\n", |
| 770 | + " l_cs = T.mean(l_cs_ce + p * 0.5 * l_dual + a * 0.9 * l_att)\n", |
807 | 771 | " \n",
|
808 | 772 | " # optimize CG\n",
|
809 | 773 | " cg_model.opt.zero_grad()\n",
|
|
848 | 812 | "outputs": [],
|
849 | 813 | "source": [
|
850 | 814 | "torch.save(cg_model.state_dict(), os.path.join(exp_dir, 'cg_model.pt'))\n",
|
851 |
| - "torch.save(cs_model.state_dict(), os.path.join(exp_dir, 'cs_model.pt'))" |
| 815 | + "torch.save(cs_model.state_dict(), os.path.join(exp_dir, 'cs_model.pt'))\n", |
| 816 | + "\n", |
| 817 | + "tb_writer.close()" |
852 | 818 | ]
|
853 | 819 | },
|
854 | 820 | {
|
|
858 | 824 | "# 5. Evaluate"
|
859 | 825 | ]
|
860 | 826 | },
|
| 827 | + { |
| 828 | + "cell_type": "code", |
| 829 | + "execution_count": null, |
| 830 | + "metadata": {}, |
| 831 | + "outputs": [], |
| 832 | + "source": [ |
| 833 | + "cg_model = Model(CFG, model_type='cg')\n", |
| 834 | + "cs_model = Model(CFG, model_type='cs')\n", |
| 835 | + "\n", |
| 836 | + "# exp_dir = f'./experiments/{CFG.exp_name}'\n", |
| 837 | + "exp_dir = f'./experiments/{os.path.basename(DATASET_DIR)}-p{0}-a{1}'\n", |
| 838 | + "\n", |
| 839 | + "cg_model.load_state_dict(torch.load(os.path.join(exp_dir, 'cg_model.pt')))\n", |
| 840 | + "cs_model.load_state_dict(torch.load(os.path.join(exp_dir, 'cs_model.pt')))\n", |
| 841 | + "\n", |
| 842 | + "exp_dir" |
| 843 | + ] |
| 844 | + }, |
861 | 845 | {
|
862 | 846 | "cell_type": "markdown",
|
863 | 847 | "metadata": {},
|
|
872 | 856 | "outputs": [],
|
873 | 857 | "source": [
|
874 | 858 | "def is_valid_code(line):\n",
|
| 859 | + " \"valid <=> (complete ^ valid) v (incomplete ^ valid_prefix)\"\n", |
875 | 860 | " try:\n",
|
876 | 861 | " codeop.compile_command(line)\n",
|
877 | 862 | " except SyntaxError:\n",
|
878 | 863 | " return False\n",
|
879 |
| - " else:\n", |
880 |
| - " return True" |
881 |
| - ] |
882 |
| - }, |
883 |
| - { |
884 |
| - "cell_type": "code", |
885 |
| - "execution_count": 54, |
886 |
| - "metadata": {}, |
887 |
| - "outputs": [ |
888 |
| - { |
889 |
| - "name": "stdout", |
890 |
| - "output_type": "stream", |
891 |
| - "text": [ |
892 |
| - "tensor([1, 2, 3, 4, 5, 6, 7, 8, 9, 1, 2, 3])\n", |
893 |
| - "tensor([1, 2, 3, 4, 4, 4, 7, 1, 1, 2, 2, 9])\n", |
894 |
| - "tensor([1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 1])\n" |
895 |
| - ] |
896 |
| - } |
897 |
| - ], |
898 |
| - "source": [ |
899 |
| - "import torch as T\n", |
900 |
| - "\n", |
901 |
| - "x = T.tensor([1,2,3,4,5,6,7,8,9,1,2,3])\n", |
902 |
| - "y = T.tensor([1,2,3,4,4,4,7,1,1,2,2,9])\n", |
903 |
| - "m = T.tensor([1,1,1,1,1,1,1,0,0,0,0,1])\n", |
904 |
| - "\n", |
905 |
| - "code_pred = y\n", |
906 |
| - "code = x\n", |
907 |
| - "code_mask = m\n", |
| 864 | + " \n", |
| 865 | + " return True\n", |
908 | 866 | "\n",
|
909 |
| - "print(x)\n", |
910 |
| - "print(y)\n", |
911 |
| - "print(m)" |
912 |
| - ] |
913 |
| - }, |
914 |
| - { |
915 |
| - "cell_type": "code", |
916 |
| - "execution_count": 55, |
917 |
| - "metadata": {}, |
918 |
| - "outputs": [ |
919 |
| - { |
920 |
| - "data": { |
921 |
| - "text/plain": [ |
922 |
| - "(tensor(0.4167), tensor(0.6250), 0.625, tensor(0.6250))" |
923 |
| - ] |
924 |
| - }, |
925 |
| - "execution_count": 55, |
926 |
| - "metadata": {}, |
927 |
| - "output_type": "execute_result" |
928 |
| - } |
929 |
| - ], |
930 |
| - "source": [ |
931 |
| - "r1 = T.mean(((x == y) * m).float()).cpu()\n", |
932 |
| - "r2 = ((x == y) * m).float().sum() / m.sum()\n", |
933 |
| - "r1, r2, (5 / 8), (((code_pred == code) * code_mask).float().sum() / code_mask.sum()).cpu()" |
| 867 | + "def to_tok(xs, mode):\n", |
| 868 | + " z = (xs)[0].cpu()\n", |
| 869 | + " z = z[(z!=0)&(z!=1)&(z!=2)&(z!=3)]\n", |
| 870 | + " if mode == 'code':\n", |
| 871 | + " return dataset.code_lang.to_tokens(z)[0]\n", |
| 872 | + " if mode == 'anno':\n", |
| 873 | + " return dataset.anno_lang.to_tokens(z)[0]" |
934 | 874 | ]
|
935 | 875 | },
|
936 | 876 | {
|
|
950 | 890 | "# ---\n",
|
951 | 891 | "\n",
|
952 | 892 | "test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False)\n",
|
953 |
| - "\n", |
| 893 | + "assert len(test_loader) == len(dataset) - n" |
| 894 | + ] |
| 895 | + }, |
| 896 | + { |
| 897 | + "cell_type": "code", |
| 898 | + "execution_count": null, |
| 899 | + "metadata": {}, |
| 900 | + "outputs": [], |
| 901 | + "source": [ |
954 | 902 | "ms = ['ind_match', 'exact_match', 'coverage']\n",
|
955 | 903 | "metrics = {\n",
|
956 | 904 | " 'anno': {k: 0 for k in ms},\n",
|
957 | 905 | " 'code': {k: 0 for k in ms}\n",
|
958 | 906 | "}\n",
|
959 | 907 | "metrics['code']['pov'] = 0\n",
|
960 | 908 | "\n",
|
| 909 | + "anno_toks, code_toks = [], []\n", |
| 910 | + "\n", |
961 | 911 | "with T.no_grad():\n",
|
962 | 912 | " cg_model.eval()\n",
|
963 | 913 | " cs_model.eval()\n",
|
|
967 | 917 | " anno, code = anno.cuda(), code.cuda() \n",
|
968 | 918 | " \n",
|
969 | 919 | " # binary mask indicating the presence of padding token\n",
|
970 |
| - " anno_mask = T.tensor(anno != dataset.anno_lang.token2index['<pad>']).byte()\n", |
971 |
| - " code_mask = T.tensor(code != dataset.code_lang.token2index['<pad>']).byte()\n", |
| 920 | + "# anno_mask = T.tensor(anno != dataset.anno_lang.token2index['<pad>']).byte()\n", |
| 921 | + "# code_mask = T.tensor(code != dataset.code_lang.token2index['<pad>']).byte()\n", |
| 922 | + "\n", |
| 923 | + " anno_mask = T.tensor((anno != 0) * (anno != 1)).byte()\n", |
| 924 | + " code_mask = T.tensor((code != 0) * (code != 1)).byte()\n", |
972 | 925 | " \n",
|
973 | 926 | " # forward pass\n",
|
974 | 927 | " code_pred, code_att_mat = cg_model(src=anno, tgt=code)\n",
|
|
992 | 945 | " metrics['anno']['exact_match'] += 1 / len(test_loader)\n",
|
993 | 946 | " \n",
|
994 | 947 | " # 3)\n",
|
995 |
| - " cc = set([x.item() for x in code[0].cpu().data if x.item() != 0]) - \\\n", |
996 |
| - " set([x.item() for x in code_pred[0].cpu().data if x.item() != 0])\n", |
997 |
| - " if len(cc) == 0:\n", |
| 948 | + " sy = set([x.item() for x in (code * code_mask)[0].cpu().data if x.item() != 0])\n", |
| 949 | + " sy_ = set([x.item() for x in (code_pred * code_mask)[0].cpu().data if x.item() != 0])\n", |
| 950 | + " if len(set.difference(sy_, sy)) == 0:\n", |
998 | 951 | " metrics['code']['coverage'] += 1 / len(test_loader)\n",
|
| 952 | + " else:\n", |
| 953 | + " if np.isclose(code_score, 1):\n", |
| 954 | + " print(set.difference(sy_, sy))\n", |
999 | 955 | " \n",
|
1000 |
| - " ac = set([x.item() for x in anno[0].cpu().data if x.item() != 0]) - \\\n", |
1001 |
| - " set([x.item() for x in anno_pred[0].cpu().data if x.item() != 0])\n", |
1002 |
| - " if len(ac) == 0:\n", |
| 956 | + " sy = set([x.item() for x in (anno * anno_mask)[0].cpu().data if x.item() != 0])\n", |
| 957 | + " sy_ = set([x.item() for x in (anno_pred * anno_mask)[0].cpu().data if x.item() != 0])\n", |
| 958 | + " if len(set.difference(sy_, sy)) == 0:\n", |
1003 | 959 | " metrics['anno']['coverage'] += 1 / len(test_loader)\n",
|
1004 | 960 | " \n",
|
1005 | 961 | " # 4)\n",
|
1006 |
| - " c = (code_pred * code_mask)[0].cpu()\n", |
1007 |
| - " c = c[(c!=0)&(c!=1)&(c!=2)&(c!=3)]\n", |
1008 |
| - " c = dataset.code_lang.to_tokens(c)[0]\n", |
1009 |
| - " if is_valid_code(' '.join(c)):\n", |
| 962 | + " if is_valid_code(' '.join(to_tok(code_pred * code_mask, 'code'))):\n", |
1010 | 963 | " metrics['code']['pov'] += 1 / len(test_loader)\n",
|
| 964 | + "\n", |
| 965 | + " # save tokens\n", |
| 966 | + " code_toks += [(round(code_score.item(), 5), \n", |
| 967 | + " to_tok(code_pred * code_mask, 'code'), \n", |
| 968 | + " to_tok(code * code_mask, 'code'),\n", |
| 969 | + " code_pred[0].cpu(),\n", |
| 970 | + " code[0].cpu())]\n", |
| 971 | + " \n", |
| 972 | + " anno_toks += [(round(anno_score.item(), 5), \n", |
| 973 | + " to_tok(anno_pred * anno_mask, 'anno'), \n", |
| 974 | + " to_tok(anno * anno_mask, 'anno'),\n", |
| 975 | + " anno_pred[0].cpu(),\n", |
| 976 | + " anno[0].cpu())]\n", |
1011 | 977 | " \n",
|
| 978 | + "code_toks = sorted(code_toks, key=lambda x: x[0])\n", |
| 979 | + "anno_toks = sorted(anno_toks, key=lambda x: x[0])\n", |
| 980 | + "\n", |
| 981 | + "with open(os.path.join(exp_dir, 'eval_code.txt'), 'wt') as fp:\n", |
| 982 | + " for i, (s, pt, tt, p, t) in enumerate(code_toks, start=1):\n", |
| 983 | + " fp.write(f'{i}\\n')\n", |
| 984 | + " fp.write(f'{s}\\n')\n", |
| 985 | + " fp.write(f'pred: {\" \".join(pt)}\\n')\n", |
| 986 | + " fp.write(f'true: {\" \".join(tt)}\\n')\n", |
| 987 | + " fp.write(f'pred_raw: {p}\\n')\n", |
| 988 | + " fp.write(f'true_raw: {t}\\n')\n", |
| 989 | + " fp.write(f'{\"-\"*80}\\n')\n", |
| 990 | + " \n", |
| 991 | + "with open(os.path.join(exp_dir, 'eval_anno.txt'), 'wt') as fp:\n", |
| 992 | + " for i, (s, pt, tt, p, t) in enumerate(anno_toks, start=1):\n", |
| 993 | + " fp.write(f'{i}\\n')\n", |
| 994 | + " fp.write(f'{s}\\n')\n", |
| 995 | + " fp.write(f'pred: {\" \".join(pt)}\\n')\n", |
| 996 | + " fp.write(f'true: {\" \".join(tt)}\\n')\n", |
| 997 | + " fp.write(f'pred_raw: {p}\\n')\n", |
| 998 | + " fp.write(f'true_raw: {t}\\n')\n", |
| 999 | + " fp.write(f'{\"-\"*80}\\n')\n", |
1012 | 1000 | "\n",
|
1013 | 1001 | "# results\n",
|
1014 |
| - "for t in metrics:\n", |
1015 |
| - " print(t)\n", |
1016 |
| - " for k, v in metrics[t].items():\n", |
1017 |
| - " print(f'{k:>16s}: {v:7.5f}')" |
| 1002 | + "print(exp_dir.split('/')[-1])\n", |
| 1003 | + "print(len(test_loader))\n", |
| 1004 | + "for k in ms:\n", |
| 1005 | + " print(f\"{metrics['anno'][k]:7.5f}/{metrics['code'][k]:7.5f}\", end=' ')\n", |
| 1006 | + "print(round(metrics['code']['pov'], 5))" |
1018 | 1007 | ]
|
1019 | 1008 | },
|
1020 | 1009 | {
|
1021 | 1010 | "cell_type": "markdown",
|
1022 | 1011 | "metadata": {},
|
1023 | 1012 | "source": [
|
1024 |
| - "## 5.2. Translate" |
| 1013 | + "## 5.2. Attention matrices" |
1025 | 1014 | ]
|
1026 | 1015 | },
|
1027 | 1016 | {
|
|
1030 | 1019 | "metadata": {},
|
1031 | 1020 | "outputs": [],
|
1032 | 1021 | "source": [
|
1033 |
| - "cg_model = Model(CFG, model_type='cg')\n", |
1034 |
| - "cs_model = Model(CFG, model_type='cs')\n", |
1035 |
| - "\n", |
1036 |
| - "exp_dir = f'./experiments/{CFG.exp_name}'\n", |
| 1022 | + "a = T.tensor([ 2, 576, 16, 84, 474, 695, 0, 0, 0, 3])\n", |
| 1023 | + "c = T.tensor([ 2, 155, 489, 10, 159, 5, 8, 0, 0, 3])\n", |
1037 | 1024 | "\n",
|
1038 |
| - "cg_model.load_state_dict(torch.load(os.path.join(exp_dir, 'cg_model.pt')))\n", |
1039 |
| - "cs_model.load_state_dict(torch.load(os.path.join(exp_dir, 'cs_model.pt')))\n", |
1040 |
| - "\n", |
1041 |
| - "exp_dir" |
1042 |
| - ] |
1043 |
| - }, |
1044 |
| - { |
1045 |
| - "cell_type": "markdown", |
1046 |
| - "metadata": {}, |
1047 |
| - "source": [ |
1048 |
| - "## 5.3. Attention matrices" |
1049 |
| - ] |
1050 |
| - }, |
1051 |
| - { |
1052 |
| - "cell_type": "code", |
1053 |
| - "execution_count": null, |
1054 |
| - "metadata": {}, |
1055 |
| - "outputs": [], |
1056 |
| - "source": [ |
1057 | 1025 | "with T.no_grad():\n",
|
1058 |
| - " i = np.random.randint(len(train_dataset))\n", |
1059 |
| - " a, c, _, _ = train_dataset[i]\n", |
| 1026 | + " i = np.random.randint(len(test_dataset))\n", |
| 1027 | + "# i = 5557\n", |
| 1028 | + " a, c, _, _ = test_dataset[-1]\n", |
1060 | 1029 | " a, c = a.cuda(), c.cuda()\n",
|
| 1030 | + " anno_mask = T.tensor((a != 0) * (a != 1)).byte().cuda()\n", |
| 1031 | + " code_mask = T.tensor((c != 0) * (c != 1)).byte().cuda()\n", |
1061 | 1032 | " x, x_mat = cg_model(src=a.unsqueeze(0), tgt=c.unsqueeze(0))\n",
|
1062 | 1033 | " y, y_mat = cs_model(src=c.unsqueeze(0), tgt=a.unsqueeze(0))\n",
|
1063 |
| - " x = x[0].argmax(dim=1).cpu()\n", |
| 1034 | + " x = x[0].argmax(dim=-1)\n", |
1064 | 1035 | " x_mat = x_mat[0].cpu()\n",
|
1065 |
| - " y = y[0].argmax(dim=1).cpu()\n", |
| 1036 | + " y = y[0].argmax(dim=-1)\n", |
1066 | 1037 | " y_mat = y_mat[0].cpu()\n",
|
1067 | 1038 | " \n",
|
1068 |
| - " ct = dataset.code_lang.to_tokens(c)[0]\n", |
1069 |
| - " at = dataset.anno_lang.to_tokens(a)[0]\n", |
1070 |
| - " xt = dataset.code_lang.to_tokens(x)[0]\n", |
1071 |
| - " yt = dataset.anno_lang.to_tokens(y)[0]\n", |
| 1039 | + " ct = to_tok((c * code_mask).unsqueeze(0), 'code')\n", |
| 1040 | + " xt = to_tok((x * code_mask).unsqueeze(0), 'code')\n", |
| 1041 | + " at = to_tok((a * anno_mask).unsqueeze(0), 'anno')\n", |
| 1042 | + " yt = to_tok((y * anno_mask).unsqueeze(0), 'anno')\n", |
1072 | 1043 | " \n",
|
1073 | 1044 | "\n",
|
1074 |
| - "plt.figure(figsize=(16, 10))\n", |
| 1045 | + "plt.figure(figsize=(12, 8))\n", |
1075 | 1046 | "\n",
|
1076 | 1047 | "# plt.subplot(1, 2, 1)\n",
|
1077 | 1048 | "plt.imshow(F.softmax(y_mat, -1), cmap='jet')\n",
|
1078 | 1049 | "plt.grid(False)\n",
|
1079 |
| - "# plt.xticks(np.arange(len(ct)), labels=ct, rotation=90)\n", |
1080 |
| - "# plt.yticks(np.arange(len(at)), labels=at)\n", |
| 1050 | + "plt.xticks(np.arange(len(ct)), labels=ct, rotation=90)\n", |
| 1051 | + "plt.yticks(np.arange(len(at)), labels=at)\n", |
1081 | 1052 | "\n",
|
1082 | 1053 | "# plt.subplot(1, 2, 2)\n",
|
1083 | 1054 | "# plt.imshow(F.softmax(y_mat, -1), cmap='jet')\n",
|
|
0 commit comments