Skip to content

Commit 9807e95

Browse files
committed
Commit
1 parent 0679237 commit 9807e95

File tree

1 file changed

+44
-2
lines changed

1 file changed

+44
-2
lines changed

decision_tree/dt_author_id.py

+44-2
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010

1111
import sys
1212
from time import time
13-
sys.path.append("../tools/")
13+
sys.path.append("./tools/")
1414
from email_preprocess import preprocess
1515

1616

@@ -41,4 +41,46 @@
4141
acc = accuracy_score(pred, labels_test)
4242

4343
print("Accuracy:", round(accuracy,3))
44-
print("Metrics Accuracy:", round(acc, 3))
44+
print("Metrics Accuracy:", round(acc, 3))
45+
46+
#########################################################
47+
48+
from sklearn.ensemble import AdaBoostClassifier
49+
from sklearn.metrics import accuracy_score
50+
51+
# Create a AdaBoost Classifier (AB) object
52+
t0 = time()
53+
clf = AdaBoostClassifier(n_estimators=100, random_state=0)
54+
clf.fit(features_train, labels_train)
55+
print("Training Time:", round(time()-t0, 3), "s")
56+
57+
acc = accuracy_score(clf.predict(features_test), labels_test)
58+
59+
print("Metrics Accuracy:", round(acc, 3))
60+
61+
62+
#########################################################
63+
from sklearn.neighbors import KNeighborsClassifier
64+
65+
# Create a KNeighbors Classifier (KNN) object
66+
t0 = time()
67+
clf = KNeighborsClassifier(n_neighbors=3)
68+
clf.fit(features_train, labels_train)
69+
print("Training Time:", round(time()-t0, 3), "s")
70+
71+
acc = accuracy_score(clf.predict(features_test), labels_test)
72+
print("Metrics Accuracy:", round(acc, 3))
73+
74+
75+
#########################################################
76+
77+
from sklearn.ensemble import RandomForestClassifier
78+
79+
# Create a RandomForest Classifier (RF) object
80+
t0 = time()
81+
clf = RandomForestClassifier(n_estimators=100, random_state=0)
82+
clf.fit(features_train, labels_train)
83+
print("Training Time:", round(time()-t0, 3), "s")
84+
85+
acc = accuracy_score(clf.predict(features_test), labels_test)
86+
print("Metrics Accuracy:", round(acc, 3))

0 commit comments

Comments
 (0)