forked from shanujshekhar/Improvement-on-Knowledge-backed-Generation-Model-Using-Post-Modifier-Dataset
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathutil.py
71 lines (58 loc) · 2.57 KB
/
util.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
import os
import tempfile
import json
# Name of the OpenNMT-py directory. Presumably a path to the OpenNMT-py
# checkout relative to the working directory -- TODO confirm against callers;
# it is not referenced by the functions visible in this module.
OpenNMT_dir = "OpenNMT-py"
def load_kb(infile, verbose=False):
    """Load a tab-separated knowledge-base file into a dict keyed by id.

    Every non-blank line must provide at least five tab-separated columns:
    id, label, comma-separated aliases, comma-separated description and a
    JSON-encoded claims value. Blank lines are skipped.

    Args:
        infile: Path of the UTF-8 encoded knowledge-base file.
        verbose: When true, print every field of each entry as it is loaded.

    Returns:
        A dict mapping each id (first column) to a dict with the keys
        ``'label'`` (str), ``'aliases'`` (list of str), ``'description'``
        (list of str) and ``'claims'`` (the decoded JSON value). When an id
        occurs more than once, the last occurrence wins.

    Raises:
        IndexError: If a non-blank line has fewer than five columns.
        json.JSONDecodeError: If the fifth column is not valid JSON.
    """
    wiki_kb = {}
    with open(infile, "r", encoding='utf8') as fin:
        for line in fin:
            stripped = line.strip()
            if not stripped:
                # A blank line carries no entry; indexing its columns would
                # only raise IndexError.
                continue
            row = stripped.split("\t")
            # Build the entry once instead of re-looking it up per field.
            entry = {
                'label': row[1],
                'aliases': row[2].split(","),
                'description': row[3].split(","),
                'claims': json.loads(row[4]),
            }
            wiki_kb[row[0]] = entry
            if verbose:
                for k, v in entry.items():
                    print("{}:\t{}".format(k, v))
    return wiki_kb
def _render_claims(claims, line):
    """Flatten one entity's claims into the markup tokens appended to the source.

    ``line`` is the raw input line; it is only echoed in diagnostic messages.
    """
    tokens = []
    for claim in claims:
        if 'property' not in claim:
            # Without a property there is nothing to emit for this claim.
            print("This claim misses [property]: %s" % (line))
            continue
        key, val = claim['property']
        tokens.append('<claim>')
        tokens.append("<prop> {} </prop> <val> {} </val>".format(key, val))
        if 'qualifiers' in claim:
            for qkey, qval in claim['qualifiers']:
                tokens.append(
                    "<qual_prop> {} </qual_prop> <qual_val> {} </qual_val>".format(qkey, qval)
                )
        else:
            print("This claim misses [qualifiers]: %s" % (line))
        tokens.append('</claim>')
    return tokens


def generate_tmp_file(infile, wiki_kb, verbose=False):
    """Write parallel source/target temporary files for a tab-separated input.

    For every line of ``infile`` the source file receives the first column
    followed by the claims of the entity whose id is in the fifth column,
    rendered as ``<claim> ... </claim>`` markup; the target file receives the
    third column. The verbose output labels the six columns as Sent_Wo_PM,
    Entity, PM, Sent, Wiki_ID and Prev_Sent.

    Args:
        infile: Path of the UTF-8 encoded, tab-separated input file.
        wiki_kb: Mapping from entity id to a dict that has a ``'claims'``
            list. Each claim is a dict with a ``'property'`` pair and an
            optional ``'qualifiers'`` list of pairs.
        verbose: When true, print the columns and rendered claims per line.

    Returns:
        A ``(src_path, tgt_path)`` tuple of temporary file paths. The caller
        is responsible for deleting both files.

    Raises:
        IndexError: If a line has fewer than five columns.
        KeyError: If a line's entity id is not present in ``wiki_kb``.
    """
    # mkstemp returns an already-open OS-level descriptor. Close it right
    # away, otherwise every call leaks two descriptors; the files are
    # reopened by path below.
    src_fd, src_path = tempfile.mkstemp(suffix='.tmp')
    os.close(src_fd)
    tgt_fd, tgt_path = tempfile.mkstemp(suffix='.tmp')
    os.close(tgt_fd)

    try:
        with open(infile, "r", encoding='utf8') as fin, \
                open(src_path, "w", encoding='utf8') as src_out, \
                open(tgt_path, "w", encoding='utf8') as tgt_out:
            for line in fin:
                row = line.strip().split("\t")
                claims = _render_claims(wiki_kb[row[4]]['claims'], line)
                src_out.write(row[0] + " " + " ".join(claims) + "\n")
                tgt_out.write(row[2] + "\n")

                if verbose:
                    # strip() removes a trailing tab, so an empty last column
                    # is absent from ``row`` rather than an empty string.
                    prev_sent = row[5] if len(row) > 5 else ""
                    print("Sent_Wo_PM: {}".format(row[0]))
                    print("Entity: {}".format(row[1]))
                    print("PM: {}".format(row[2]))
                    print("Sent: {}".format(row[3]))
                    print("Wiki_ID: {}".format(row[4]))
                    print("Prev_Sent: {}".format(prev_sent))
                    print("claims: {}".format("\n".join(claims)))
    except BaseException:
        # The paths are never handed to the caller on failure, so remove the
        # temporary files here (best effort) before re-raising.
        for path in (src_path, tgt_path):
            try:
                os.unlink(path)
            except OSError:
                pass
        raise
    return src_path, tgt_path
def unlink_tmp_file(file_path):
    """Delete the file at ``file_path``.

    Any ``OSError`` raised by the removal (for example when the file does
    not exist) propagates to the caller.
    """
    # os.remove is the same operation as os.unlink under its other name.
    os.remove(file_path)