Skip to content

Commit 0a94259

Browse files
authored
Merge pull request #7 from lantunes/csv_vectors
Adding CSV file with learned vectors
2 parents a0c9f19 + 119b2a7 commit 0a94259

File tree

3 files changed

+142
-1
lines changed

3 files changed

+142
-1
lines changed

bin/create_csv_vectors_file.py

+54
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,54 @@
"""Create a CSV file with the SkipAtom vectors.

The output has a header row (``element`` followed by one column per vector
component, named ``0 .. dim-1``) and then one row per element, ordered by
atomic number.

e.g.
--model ../data/mp_2020_10_09.dim200.model
--data ../data/mp_2020_10_09.training.data
--induced --min-count 2e7 --top-n 5
--out ../data/skipatom_20201009_induced.csv
"""
import sys
# Must run before the skipatom import below; presumably this lets the script
# find the local skipatom package when launched from bin/ or the repo root.
sys.path.extend([".", ".."])
from sys import argv
import argparse
import csv
from pymatgen import Element

from skipatom import SkipAtomModel, SkipAtomInducedModel


def main():
    """Parse the command line, load the model and write its vectors as CSV.

    Exits with an argparse usage error (status 2) when ``--induced`` is given
    without both ``--min-count`` and ``--top-n``.
    """
    parser = argparse.ArgumentParser(
        description="Create a CSV file with the SkipAtom vectors."
    )
    parser.add_argument("--model", nargs="?", required=True, type=str,
                        help="path to SkipAtom .model file")
    parser.add_argument("--data", nargs="?", required=True, type=str,
                        help="path to SkipAtom .training.data file")
    parser.add_argument("--out", nargs="?", required=True, type=str,
                        help="path to the output file; a .csv extension should be used")
    parser.add_argument("--induced", action="store_true",
                        help="whether to use induced SkipAtom vectors")
    # int(float(x)) so that scientific notation such as "2e7" is accepted.
    parser.add_argument("--min-count", type=lambda x: int(float(x)),
                        help="the min. count to use if induced vectors are specified")
    parser.add_argument("--top-n", type=int,
                        help="the top N to use if induced vectors are specified")

    args = parser.parse_args(argv[1:])

    if args.induced:
        # The conditional requirement is checked after parsing: a test such as
        # `"induced" in argv` never matches the actual "--induced" token, and
        # inspecting argv would also miss argparse's prefix abbreviations.
        missing = [
            flag
            for flag, value in (("--min-count", args.min_count), ("--top-n", args.top_n))
            if value is None
        ]
        if missing:
            parser.error(
                "the following arguments are required with --induced: %s" % ", ".join(missing)
            )
        model = SkipAtomInducedModel.load(
            args.model, args.data, min_count=args.min_count, top_n=args.top_n
        )
    else:
        model = SkipAtomModel.load(args.model, args.data)

    # Order the element symbols by atomic number (sorted() is stable).
    sorted_elems = sorted(model.dictionary, key=lambda e: Element(e).number)

    dim = len(model.vectors[0])

    # newline="" is what the csv module expects for files handed to a writer;
    # without it, rows end in "\r\r\n" on platforms that translate newlines.
    with open(args.out, "wt", newline="") as f:
        writer = csv.writer(f)
        header = ["element"]
        header.extend(str(i) for i in range(dim))
        writer.writerow(header)
        for elem in sorted_elems:
            # model.dictionary maps an element symbol to its row in model.vectors.
            vec = model.vectors[model.dictionary[elem]]
            row = [elem]
            row.extend(str(v) for v in vec)
            writer.writerow(row)


if __name__ == '__main__':
    main()

0 commit comments

Comments
 (0)