Skip to content

Commit 551aa21

Browse files
authored
Add file
1 parent cf6adf7 commit 551aa21

File tree

1 file changed

+142
-0
lines changed

1 file changed

+142
-0
lines changed

Diff for: helpers.py

+142
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,142 @@
1+
2+
import os
3+
import matplotlib.pyplot as plt
4+
import matplotlib.image as mplimg
5+
import networkx as nx
6+
import random
7+
8+
from io import BytesIO
9+
from itertools import chain
10+
from collections import namedtuple, OrderedDict
11+
12+
13+
# Record pairing the words of one sentence with their tags (parallel tuples).
Sentence = namedtuple("Sentence", "words tags")


def read_data(filename):
    """Read tagged sentence data.

    The file is split into blocks on blank lines. Within a block the first
    non-blank line is used as the sentence key and every following
    non-blank line is split on tabs into one (word, tag) row.

    Parameters
    ----------
    filename : string
        Path of the tagged sentence file to read.

    Returns
    -------
    OrderedDict
        Maps each sentence key to a Sentence(words, tags) tuple, in the
        order the sentences appear in the file. A block that contains only
        a key maps to Sentence((), ()).
    """
    with open(filename, 'r') as f:
        blocks = f.read().split("\n\n")

    sentences = OrderedDict()
    for block in blocks:
        # Ignore blank lines: a trailing newline at the end of the file (or a
        # stray extra separator line) would otherwise produce a one-field row
        # that truncates the zip below and breaks Sentence construction.
        lines = [line for line in block.split("\n") if line.strip()]
        if not lines:
            continue
        key = lines[0]
        rows = [line.strip().split("\t") for line in lines[1:]]
        # zip(*rows) transposes [(word, tag), ...] into (words, tags).
        sentences[key] = Sentence(*zip(*rows)) if rows else Sentence((), ())
    return sentences
21+
22+
23+
def read_tags(filename):
    """Read a list of word tag classes.

    Parameters
    ----------
    filename : string
        Path of a text file holding one tag per line.

    Returns
    -------
    frozenset
        The distinct tags found in the file. Surrounding whitespace is
        stripped and blank lines are ignored, so a trailing newline does
        not add an empty-string tag to the set.
    """
    with open(filename, 'r') as f:
        lines = f.read().split("\n")
    stripped = (line.strip() for line in lines)
    return frozenset(tag for tag in stripped if tag)
28+
29+
30+
def model2png(model, filename="", overwrite=False, show_ends=False):
    """Convert a Pomegranate model into a PNG image

    The conversion pipeline extracts the underlying NetworkX graph object,
    converts it to a PyDot graph, then renders the PNG data as bytes, which
    can be saved as a file to disk and is decoded with matplotlib for display.

        Model -> NetworkX.Graph -> PyDot.Graph -> bytes -> PNG

    Parameters
    ----------
    model : Pomegranate.Model
        The model object to convert. The model must have an attribute .graph
        referencing a NetworkX.Graph instance, and .start / .end attributes
        naming its end states. Every graph node must have a .name attribute,
        which is used as the node label.

    filename : string (optional)
        The PNG file will be saved to disk with this filename if one is provided.
        By default, the image file will NOT be created if a file with this name
        already exists unless overwrite=True.

    overwrite : bool (optional)
        overwrite=True allows the new PNG to overwrite the specified file if it
        already exists

    show_ends : bool (optional)
        show_ends=True will generate the PNG including the two end states from
        the Pomegranate model (which are not usually an explicit part of the graph)

    Returns
    -------
    The decoded image, as returned by matplotlib.image.imread.

    Raises
    ------
    IOError
        If filename already exists on disk and overwrite is False. The check
        happens before any rendering work is done.
    """
    # Fail fast: refuse to clobber an existing file before paying for the
    # graph conversion and the external layout/render step.
    if filename and os.path.exists(filename) and not overwrite:
        raise IOError("File already exists. Use overwrite=True to replace existing files on disk.")

    nodes = list(model.graph.nodes())
    if not show_ends:
        end_states = (model.start, model.end)
        nodes = [n for n in nodes if n not in end_states]

    # Label the drawn nodes by state name instead of by state object.
    labels = {n: n.name for n in nodes}
    g = nx.relabel_nodes(model.graph.subgraph(nodes), labels)

    pydot_graph = nx.drawing.nx_pydot.to_pydot(g)
    pydot_graph.set_rankdir("LR")  # lay the graph out left-to-right
    png_data = pydot_graph.create_png(prog='dot')

    if filename:
        with open(filename, 'wb') as f:
            f.write(png_data)

    # imread accepts a file-like object, so wrap the raw bytes in memory.
    return mplimg.imread(BytesIO(png_data))
75+
76+
77+
def show_model(model, figsize=(5, 5), **kwargs):
    """Display a Pomegranate model as an image using matplotlib

    Parameters
    ----------
    model : Pomegranate.Model
        The model object to convert. The model must have an attribute .graph
        referencing a NetworkX.Graph instance.

    figsize : tuple(int, int) (optional)
        A tuple specifying the dimensions of a matplotlib Figure that will
        display the converted graph

    **kwargs : dict
        The kwargs dict is passed to the model2png function, see that function
        for details
    """
    # Render first: if the conversion raises (for example because the target
    # file already exists), no empty figure is left behind.
    image = model2png(model, **kwargs)
    plt.figure(figsize=figsize)
    plt.imshow(image)
    plt.axis('off')
97+
98+
99+
class Subset(namedtuple("BaseSet", "sentences keys vocab X tagset Y N stream")):
    """Immutable view over the sentences selected by a collection of keys.

    Fields
    ------
    sentences : dict mapping each selected key to its sentence object
    keys : the keys as given (a one-shot iterator is stored as a tuple)
    vocab : frozenset of every word in the selected sentences
    X : tuple of word sequences, one per key, in key order
    tagset : frozenset of every tag in the selected sentences
    Y : tuple of tag sequences, one per key, in key order
    N : total number of words across the selected sentences
    stream : callable returning a fresh iterator over (word, tag) pairs

    len() and iteration are overridden to report the selected sentences
    (as (key, sentence) pairs) rather than the namedtuple fields.
    """

    def __new__(cls, sentences, keys):
        """Build the subset of `sentences` addressed by `keys`.

        Parameters
        ----------
        sentences : mapping
            Maps keys to objects exposing .words and .tags sequences.

        keys : iterable
            Keys of the sentences to include, in the desired order.
        """
        # keys is traversed several times below; a one-shot iterator would be
        # exhausted after the first pass and silently yield empty fields.
        if iter(keys) is keys:
            keys = tuple(keys)
        word_sequences = tuple(sentences[k].words for k in keys)
        tag_sequences = tuple(sentences[k].tags for k in keys)
        wordset = frozenset(chain.from_iterable(word_sequences))
        tagset = frozenset(chain.from_iterable(tag_sequences))
        N = sum(1 for _ in chain.from_iterable(word_sequences))
        stream = tuple(zip(chain.from_iterable(word_sequences),
                           chain.from_iterable(tag_sequences)))
        # Store the bound __iter__ so that calling .stream() hands out a new
        # iterator over the precomputed pairs each time.
        return super().__new__(cls, {k: sentences[k] for k in keys}, keys, wordset, word_sequences,
                               tagset, tag_sequences, N, stream.__iter__)

    def __len__(self):
        """Return the number of sentences in the subset."""
        return len(self.sentences)

    def __iter__(self):
        """Iterate over (key, sentence) pairs of the subset."""
        return iter(self.sentences.items())
115+
116+
117+
class Dataset(namedtuple("_Dataset", "sentences keys vocab X tagset Y training_set testing_set N stream")):
    """Immutable tagged corpus with a shuffled train/test partition.

    Fields
    ------
    sentences : dict mapping each sentence key to its sentence
    keys : tuple of all sentence keys, in file order
    vocab : frozenset of every word in the corpus
    X : tuple of word sequences, one per key, in key order
    tagset : frozenset of tags read from the tag file
    Y : tuple of tag sequences, one per key, in key order
    training_set : Subset holding the first part of the shuffled keys
    testing_set : Subset holding the remaining shuffled keys
    N : total number of words in the corpus
    stream : callable returning a fresh iterator over (word, tag) pairs

    len() and iteration are overridden to report the sentences (as
    (key, sentence) pairs) rather than the namedtuple fields.
    """

    def __new__(cls, tagfile, datafile, train_test_split=0.8, seed=112890):
        """Load the corpus and partition it into training and testing sets.

        Parameters
        ----------
        tagfile : string
            Path of the tag list file, read with read_tags.

        datafile : string
            Path of the tagged sentence file, read with read_data.

        train_test_split : float (optional)
            Fraction of the sentences assigned to the training set; the
            cut index is int(train_test_split * number_of_sentences).

        seed : int or None (optional)
            Seed for the shuffle that precedes the split. With None the
            shared `random` module generator is used in its current state.
        """
        tagset = read_tags(tagfile)
        sentences = read_data(datafile)
        keys = tuple(sentences.keys())
        word_sequences = tuple(sentences[k].words for k in keys)
        tag_sequences = tuple(sentences[k].tags for k in keys)
        wordset = frozenset(chain.from_iterable(word_sequences))
        N = sum(1 for _ in chain.from_iterable(word_sequences))

        # split data into train/test sets
        _keys = list(keys)
        # A private generator seeded with `seed` produces the same shuffle as
        # seeding the module-level generator, without reseeding the global
        # RNG that unrelated code may be relying on.
        rng = random if seed is None else random.Random(seed)
        rng.shuffle(_keys)
        split = int(train_test_split * len(_keys))
        training_data = Subset(sentences, _keys[:split])
        testing_data = Subset(sentences, _keys[split:])
        stream = tuple(zip(chain.from_iterable(word_sequences),
                           chain.from_iterable(tag_sequences)))
        # Store the bound __iter__ so that calling .stream() hands out a new
        # iterator over the precomputed pairs each time.
        return super().__new__(cls, dict(sentences), keys, wordset, word_sequences, tagset,
                               tag_sequences, training_data, testing_data, N, stream.__iter__)

    def __len__(self):
        """Return the number of sentences in the dataset."""
        return len(self.sentences)

    def __iter__(self):
        """Iterate over (key, sentence) pairs of the dataset."""
        return iter(self.sentences.items())

0 commit comments

Comments
 (0)