-
Notifications
You must be signed in to change notification settings - Fork 24
/
helper.py
116 lines (96 loc) · 4.07 KB
/
helper.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
#!/usr/bin/python3
import os, pickle, re
from string import punctuation

# Categories: mapping of category name -> numeric label emitted by the model.
categories_dict = {
    'algeria': 1,
    'sport': 2,
    'entertainment': 3,
    'society': 4,
    'world': 5,
    'religion': 6,
}

### Tools
## Farasa Arabic NLP Toolkit
# Path to the Farasa segmenter distribution (tokenizer / lemmatizer jar lives here).
farasaSegmenter = 'Tools/farasa/segmenter'

## Arabic stop-words list, one word per line.
# Read with an explicit UTF-8 encoding and close the handle promptly:
# the original left the file object open and used the platform default
# encoding, which can mangle the Arabic text on some systems.
with open("Tools/arabic-stop-words/list.txt", encoding="utf-8") as _fp:
    stopWords = _fp.read().splitlines()

## Directory holding the pickled model artifacts.
models = 'Models/dumps/'

## Remove Numbers and add Other punctuation: extend the ASCII punctuation
## set with Arabic punctuation, curly quotes and digits so the cleaning
## step strips them all in one pass.
punctuation += '،؛؟”0123456789“'
class Helper():
    """Bundle of utilities for the Arabic article classifier: pickle
    persistence, Farasa-based pre-processing, and category prediction.

    Relies on the module-level constants ``models``, ``farasaSegmenter``,
    ``stopWords``, ``punctuation`` and ``categories_dict``.
    """

    def __init__(self, article=False):
        # Keep the original default (False) for backward compatibility;
        # the attribute is stored but not used by the methods below.
        self.article = article

    # ~~ Pickle helpers ~~ #
    def getPickleContent(self, pklFile):
        """Load and return the object stored in the pickle file `pklFile`."""
        with open(pklFile, 'rb') as fp:
            return pickle.load(fp)

    def setPickleContent(self, fileName, itemList):
        """Pickle `itemList` to `fileName` + '.pkl'."""
        with open(fileName + '.pkl', 'wb') as fp:
            pickle.dump(itemList, fp)

    # ~~ Model persistence ~~ #
    def getModel(self, name):
        """Return (model, cv, tfidf) previously saved under `name`.

        `cv` is the count-vectorizer, `tfidf` the tf-idf transformer.
        """
        base = os.path.join(models, name)
        model = self.getPickleContent(os.path.join(base, 'model_' + name + '.pkl'))
        cv = self.getPickleContent(os.path.join(base, 'cv_' + name + '.pkl'))
        tfidf = self.getPickleContent(os.path.join(base, 'tfidf_' + name + '.pkl'))
        return model, cv, tfidf

    def setModel(self, name, model, cv, tfidf):
        """Persist the model, count-vectorizer and tf-idf transformer under `name`."""
        path = os.path.join(models, name)
        # makedirs(exist_ok=True) is race-free, unlike exists()+mkdir.
        os.makedirs(path, exist_ok=True)
        self.setPickleContent(os.path.join(path, 'model_' + name), model)
        self.setPickleContent(os.path.join(path, 'cv_' + name), cv)
        self.setPickleContent(os.path.join(path, 'tfidf_' + name), tfidf)

    # Get the article content
    def getArticleContent(self, article):
        """Return the text of the file at `article`, or None if it does not exist."""
        if os.path.exists(article):
            # `with` closes the handle (the original leaked it); explicit
            # UTF-8 keeps Arabic text intact regardless of platform default.
            with open(article, 'r', encoding='utf-8') as fp:
                return fp.read()

    # Drop empty lines
    def dropNline(self, article):
        """Return the article text with newlines collapsed to spaces (None if missing)."""
        if os.path.exists(article):
            content = self.getArticleContent(article)
            return re.sub(r'\n', ' ', content)

    # Get stemmed content
    def getLemmaArticle(self, content):
        """Lemmatize `content` via the Farasa segmenter jar and return the result.

        Writes `content` to a temp file, runs the jar, and reads its output
        back. Spawns a `java` subprocess — requires Java and the jar on disk.
        """
        import subprocess  # local import: only needed for this external call
        jarFarasaSegmenter = os.path.join(farasaSegmenter, 'FarasaSegmenterJar.jar')
        tmp = os.path.join(farasaSegmenter, 'tmp')
        tmpLemma = os.path.join(farasaSegmenter, 'tmpLemma')
        # Remove stale temp files directly instead of shelling out to `rm`.
        for stale in (tmp, tmpLemma):
            if os.path.exists(stale):
                os.remove(stale)
        # `with` guarantees the temp file is flushed/closed before java reads it
        # (the original's open(...).write(...) leaked the handle).
        with open(tmp, 'w', encoding='utf-8') as fp:
            fp.write(content)
        # Argument list (shell=False) avoids shell injection via crafted paths.
        subprocess.run(['java', '-jar', jarFarasaSegmenter,
                        '-l', 'true', '-i', tmp, '-o', tmpLemma])
        return self.getArticleContent(tmpLemma)

    # Remove Stop words
    def getCleanArticle(self, content):
        """Strip punctuation/digits and drop Arabic stop words from `content`."""
        content = ''.join(c for c in content if c not in punctuation)
        words = content.split()
        return ' '.join(w for w in words if w not in stopWords)

    # Pre-processing Pipeline, before prediction (Get article Bag of Words)
    def pipeline(self, content):
        """Clean -> lemmatize -> clean again; return the normalized word string."""
        cleanArticle = self.getCleanArticle(content)
        lemmaContent = self.getLemmaArticle(cleanArticle)
        return ' '.join(self.getCleanArticle(lemmaContent).split())

    # Main function, predict content category
    def predict(self, content):
        """Predict and return the upper-cased category name for `content`."""
        article = self.pipeline(content)
        model, cv, tfidf = self.getModel('sgd_94')
        vectorized = tfidf.transform(cv.transform([article]))
        predicted = model.predict(vectorized)
        # Map the numeric label back to its category name.
        keys = list(categories_dict.keys())
        values = list(categories_dict.values())
        return keys[values.index(predicted[0])].upper()
if __name__ == '__main__':
    # Demo: classify a sample Arabic news snippet and print the category.
    # Variable renamed from `help` to avoid shadowing the builtin.
    helper = Helper()
    content = 'أمرت السلطات القطرية الأسواق والمراكز التجارية في البلاد برفع وإزالة السلع الواردة من السعودية والبحرين والإمارات ومصر في الذكرى الأولى لإعلان هذه الدول الحصار عليها.'
    category = helper.predict(content)
    print(category)