import os
import matplotlib.pyplot as plt
import matplotlib.image as mplimg
import networkx as nx
import random

from io import BytesIO
from itertools import chain
from collections import namedtuple, OrderedDict


Sentence = namedtuple("Sentence", "words tags")


def read_data(filename):
    """Read tagged sentence data"""
    with open(filename, 'r') as f:
        sentence_lines = [l.split("\n") for l in f.read().split("\n\n")]
    return OrderedDict(((s[0], Sentence(*zip(*[l.strip().split("\t")
                                               for l in s[1:]]))) for s in sentence_lines if s[0]))
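
# Example (illustrative, based on the parsing above): each sentence block in the data file
# starts with an identifier line, followed by one "word<TAB>tag" pair per line, and blocks
# are separated by blank lines. The identifier, tokens, and filename below are hypothetical.
#
#     sentence-001
#     The     DET
#     dog     NOUN
#     barked  VERB
#
#     sentences = read_data("corpus-example.txt")
#     sentences["sentence-001"].words   # -> ('The', 'dog', 'barked')
#     sentences["sentence-001"].tags    # -> ('DET', 'NOUN', 'VERB')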


def read_tags(filename):
    """Read a list of word tag classes"""
    with open(filename, 'r') as f:
        tags = f.read().split("\n")
    return frozenset(tags)
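
# The tag file is expected to list one tag per line; the filename below is a placeholder.
#
#     tagset = read_tags("tags-example.txt")
#     "NOUN" in tagset   # -> True, provided the file lists it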


def model2png(model, filename="", overwrite=False, show_ends=False):
    """Convert a Pomegranate model into a PNG image

    The conversion pipeline extracts the underlying NetworkX graph object,
    converts it to a PyDot graph, then writes the PNG data to a bytes array,
    which can be saved as a file to disk or imported with matplotlib for display.

    Model -> NetworkX.Graph -> PyDot.Graph -> bytes -> PNG

    Parameters
    ----------
    model : Pomegranate.Model
        The model object to convert. The model must have an attribute .graph
        referencing a NetworkX.Graph instance.

    filename : string (optional)
        The PNG file will be saved to disk with this filename if one is provided.
        By default, the image file will NOT be created if a file with this name
        already exists unless overwrite=True.

    overwrite : bool (optional)
        overwrite=True allows the new PNG to overwrite the specified file if it
        already exists.

    show_ends : bool (optional)
        show_ends=True will generate the PNG including the two end states from
        the Pomegranate model (which are not usually an explicit part of the graph).
    """
    nodes = model.graph.nodes()
    if not show_ends:
        # Hide the model's start/end states unless explicitly requested
        nodes = [n for n in nodes if n not in (model.start, model.end)]
    # Relabel each state node with its human-readable name before rendering
    g = nx.relabel_nodes(model.graph.subgraph(nodes), {n: n.name for n in model.graph.nodes()})
    pydot_graph = nx.drawing.nx_pydot.to_pydot(g)
    pydot_graph.set_rankdir("LR")
    png_data = pydot_graph.create_png(prog='dot')
    img_data = BytesIO()
    img_data.write(png_data)
    img_data.seek(0)
    if filename:
        if os.path.exists(filename) and not overwrite:
            raise IOError("File already exists. Use overwrite=True to replace existing files on disk.")
        with open(filename, 'wb') as f:
            f.write(img_data.read())
        img_data.seek(0)
    return mplimg.imread(img_data)


def show_model(model, figsize=(5, 5), **kwargs):
    """Display a Pomegranate model as an image using matplotlib

    Parameters
    ----------
    model : Pomegranate.Model
        The model object to convert. The model must have an attribute .graph
        referencing a NetworkX.Graph instance.

    figsize : tuple(int, int) (optional)
        A tuple specifying the dimensions of a matplotlib Figure that will
        display the converted graph.

    **kwargs : dict
        The kwargs dict is passed to the model2png function; see that function
        for details.
    """
    plt.figure(figsize=figsize)
    plt.imshow(model2png(model, **kwargs))
    plt.axis('off')
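
# Usage sketch (illustrative): assumes `hmm` is a pomegranate HiddenMarkovModel, which
# exposes the .graph, .start, and .end attributes that model2png relies on above.
#
#     show_model(hmm, figsize=(8, 8), filename="hmm.png", overwrite=True)
#     plt.show()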


class Subset(namedtuple("BaseSet", "sentences keys vocab X tagset Y N stream")):
    """Immutable view over a subset of the tagged corpus.

    Fields: `sentences` (dict of key -> Sentence), `keys` (sentence ids), `vocab`
    (frozenset of words), `X` (word sequences), `tagset` (frozenset of tags), `Y`
    (tag sequences), `N` (total token count), and `stream` (callable returning an
    iterator of (word, tag) pairs over the flattened subset).
    """
    def __new__(cls, sentences, keys):
        word_sequences = tuple([sentences[k].words for k in keys])
        tag_sequences = tuple([sentences[k].tags for k in keys])
        wordset = frozenset(chain(*word_sequences))
        tagset = frozenset(chain(*tag_sequences))
        N = sum(1 for _ in chain(*(sentences[k].words for k in keys)))
        stream = tuple(zip(chain(*word_sequences), chain(*tag_sequences)))
        return super().__new__(cls, {k: sentences[k] for k in keys}, keys, wordset, word_sequences,
                               tagset, tag_sequences, N, stream.__iter__)

    def __len__(self):
        return len(self.sentences)

    def __iter__(self):
        return iter(self.sentences.items())


class Dataset(namedtuple("_Dataset", "sentences keys vocab X tagset Y training_set testing_set N stream")):
    """Full tagged corpus plus a random train/test split.

    Carries the same fields as Subset for the complete corpus, along with
    `training_set` and `testing_set` (both Subset instances) produced by
    shuffling the sentence keys with the given seed and splitting at
    `train_test_split`.
    """
    def __new__(cls, tagfile, datafile, train_test_split=0.8, seed=112890):
        tagset = read_tags(tagfile)
        sentences = read_data(datafile)
        keys = tuple(sentences.keys())
        wordset = frozenset(chain(*[s.words for s in sentences.values()]))
        word_sequences = tuple([sentences[k].words for k in keys])
        tag_sequences = tuple([sentences[k].tags for k in keys])
        N = sum(1 for _ in chain(*(s.words for s in sentences.values())))

        # split data into train/test sets
        _keys = list(keys)
        if seed is not None:
            random.seed(seed)
        random.shuffle(_keys)
        split = int(train_test_split * len(_keys))
        training_data = Subset(sentences, _keys[:split])
        testing_data = Subset(sentences, _keys[split:])
        stream = tuple(zip(chain(*word_sequences), chain(*tag_sequences)))
        return super().__new__(cls, dict(sentences), keys, wordset, word_sequences, tagset,
                               tag_sequences, training_data, testing_data, N, stream.__iter__)

    def __len__(self):
        return len(self.sentences)

    def __iter__(self):
        return iter(self.sentences.items())
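

# Usage sketch (illustrative; both filenames are placeholders, not files shipped with
# this module):
#
#     data = Dataset("tags-example.txt", "corpus-example.txt", train_test_split=0.8)
#     len(data), len(data.training_set), len(data.testing_set)   # sentence counts
#     data.N                          # total number of (word, tag) tokens in the corpus
#     sorted(data.tagset)             # tag classes read from the tag file
#     next(data.stream())             # first (word, tag) pair in the flattened corpus
#     for key, sentence in data.training_set:
#         print(key, sentence.words, sentence.tags)
#         break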