-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathrelation_fetcher.py
120 lines (93 loc) · 4.02 KB
/
relation_fetcher.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
from collections import defaultdict
from pathlib import Path
import asyncio
import csv
import os
from redis import Redis
from wikidata_endpoint import WikidataEndpoint, WikidataEndpointConfiguration
from wikidata_endpoint.return_types import UriReturnType
class NullDatabase:
    """No-op stand-in for the relation store.

    Accepts any constructor arguments and silently ignores every call, so it
    can replace a real store when persistence is not wanted. It exposes the
    same method names the fetcher calls on its store: ``get``, ``sadd`` and
    ``save``.
    """

    def __init__(self, *args, **kwargs):
        """Accept and discard any configuration."""

    def get(self, *args, **kwargs):
        """Always report a cache miss by returning ``None``."""
        return None

    def sadd(self, *args, **kwargs):
        """Discard the value; nothing is stored."""

    def save(self, *args, **kwargs):
        """Do nothing; there is no state to persist."""
class RelationFetcherCache:
    """In-memory relation cache persisted to an append-only CSV file.

    On construction, any rows already present in the CSV file are loaded into
    ``self.cache`` (a dict mapping subject -> list of ``(predicate, object)``
    tuples). New triples passed to :meth:`sadd` are appended both to the file
    and to the in-memory dict.

    The method names ``sadd`` and ``save`` match the ones the fetcher calls on
    its store object, so this class can be used wherever that store is
    expected.
    """

    DEFAULT_FILENAME = "relation_fetcher_cache.csv"
    _HEADER = ["subject", "predicate", "object"]

    def __init__(self, *args, filename=DEFAULT_FILENAME, **kwargs):
        """Load the existing cache file (if any) and open it for appending.

        Args:
            *args: Ignored; accepted for call-compatibility.
            filename: Path of the CSV file backing the cache. Defaults to
                ``relation_fetcher_cache.csv`` in the working directory.
            **kwargs: Ignored; accepted for call-compatibility.
        """
        self.filename = filename
        self.cache = {}
        has_content = (
            os.path.exists(self.filename) and os.stat(self.filename).st_size > 0
        )
        if has_content:
            # newline="" is what the csv module requires for correct handling
            # of embedded newlines and platform line endings.
            with open(self.filename, "r", newline="") as csv_file:
                csv_reader = csv.reader(csv_file)
                next(csv_reader, None)  # skip the header row
                for row in csv_reader:
                    # Blank or truncated lines carry no triple; skip them
                    # instead of failing on row[0].
                    if len(row) < 3:
                        continue
                    self._add(row[0], row[1], row[2])
        # Append mode creates the file when missing and never truncates
        # previously persisted rows.
        self.persisted_cache = open(self.filename, "a", newline="")
        self.csv_writer = csv.writer(self.persisted_cache)
        if not has_content:
            self.csv_writer.writerow(self._HEADER)

    def sadd(self, *args, **kwargs):
        """Record a triple in the CSV file and in memory.

        The first three positional arguments are taken as subject, predicate
        and object; anything else is ignored.
        """
        subject = args[0]
        predicate = args[1]
        object_ = args[2]
        self.csv_writer.writerow([subject, predicate, object_])
        self._add(subject, predicate, object_)

    def _add(self, subject, predicate, object_):
        """Append ``(predicate, object_)`` to the in-memory list for *subject*."""
        self.cache.setdefault(subject, []).append((predicate, object_))

    def get(self, subject):
        """Return the list of ``(predicate, object)`` tuples for *subject*, or ``None``."""
        return self.cache.get(subject, None)

    def save(self, *args, **kwargs):
        """Flush buffered rows to disk; a no-op if the file is not open."""
        handle = getattr(self, "persisted_cache", None)
        if handle is not None and not handle.closed:
            handle.flush()

    def close(self):
        """Flush and close the backing file; safe to call more than once."""
        handle = getattr(self, "persisted_cache", None)
        if handle is not None and not handle.closed:
            handle.close()  # closing a file flushes it first

    def __del__(self):
        # __init__ may have raised before the file was opened, and the file
        # may already be closed, so go through the guarded close().
        self.close()
class RelationFetcher:
    """Fetch (predicate, object) relations for a list of Wikidata entities.

    Entities are queried in batches of ``entities_per_query``. Relations found
    in the store (``self.redis``) are reused; the remaining entities are
    requested from the endpoint using a query template read from
    ``query_path``, and every returned triple is written back to the store.
    """

    DEFAULT_QUERY_PATH = "resources/get_relations.rq"

    def __init__(self, wikidata_ids, entities_per_query, endpoint=None, redis_config=None,
                 query_path=DEFAULT_QUERY_PATH):
        """Set up the store and the endpoint.

        Args:
            wikidata_ids: Sequence of entity ids; each is formatted as
                ``Q{id}`` when building URIs and queries.
            entities_per_query: Number of entities per batch.
            endpoint: Object providing ``request()``; when falsy, a
                ``WikidataEndpoint`` is built from
                ``resources/wikidata_endpoint_config.ini``.
            redis_config: Keyword arguments for ``Redis``; when falsy, a
                ``RelationFetcherCache`` is used instead.
            query_path: Path of the query template. The template is filled
                with ``%`` formatting, receiving the space-separated
                ``wd:Q…`` ids.
        """
        self.entities_per_query = entities_per_query
        self.entities_fetched = 0
        self.wikidata_ids = wikidata_ids
        self.query_path = query_path
        # NOTE(review): get_relations unpacks self.redis.get(...) as an
        # iterable of (predicate, object) pairs, which is what
        # RelationFetcherCache returns. Confirm the Redis-backed path returns
        # a compatible shape before relying on it.
        self.redis = RelationFetcherCache() if not redis_config else Redis(**redis_config)
        self.endpoint = endpoint or WikidataEndpoint(
            WikidataEndpointConfiguration(Path("resources/wikidata_endpoint_config.ini")))

    async def get_relations(self, wikidata_ids, relations_map):
        """Add the relations of one batch of entities to *relations_map*.

        Args:
            wikidata_ids: The ids in this batch.
            relations_map: Mapping from ``(predicate, object)`` to a set of
                subjects; updated in place.

        The body contains no ``await``, so a call runs to completion once it
        is scheduled.
        """
        cached_ids = set()
        for id_ in wikidata_ids:
            subject = f"http://www.wikidata.org/entity/Q{id_}"
            relations = self.redis.get(subject)
            if not relations:
                continue
            for predicate, object_ in relations:
                relations_map[(UriReturnType(predicate), UriReturnType(object_))].add(
                    UriReturnType(subject))
            cached_ids.add(id_)

        # Each id keeps its trailing space so the text substituted into the
        # template is the same as a plain concatenation would give.
        joined_ids = "".join(
            f"wd:Q{id_} " for id_ in wikidata_ids if id_ not in cached_ids)
        if joined_ids:
            # read_text() closes the template file right after reading it.
            query = Path(self.query_path).read_text() % joined_ids
            with self.endpoint.request() as request:
                for results in request.post(query):
                    subject, predicate, object_ = results.values()
                    relations_map[(predicate, object_)].add(subject)
                    self.redis.sadd(subject.value, predicate.value, object_.value)

        self.entities_fetched += len(wikidata_ids)
        print(f"{self.entities_fetched} entities fetched.")

    async def fetch(self):
        """Fetch relations for all configured ids.

        Returns:
            A ``defaultdict(set)`` mapping ``(predicate, object)`` to the set
            of subjects that have that relation.
        """
        self.entities_fetched = 0
        relations_entity_map = defaultdict(set)
        print(f"Fetching {len(self.wikidata_ids)} entities.")
        step = self.entities_per_query
        # Slicing clamps at the end of the sequence, so the last batch may
        # simply be shorter.
        await asyncio.gather(*(
            self.get_relations(self.wikidata_ids[start:start + step], relations_entity_map)
            for start in range(0, len(self.wikidata_ids), step)))
        return relations_entity_map

    def __del__(self):
        # __init__ may have raised before self.redis was assigned.
        store = getattr(self, "redis", None)
        if store is not None:
            store.save()