Skip to content

Commit 99cf7d7

Browse files
Merge pull request #348 from MartianCoder-git/patch-1
KNN_algorithm
2 parents bbb65f1 + 5de973a commit 99cf7d7

File tree

1 file changed

+65
-0
lines changed

1 file changed

+65
-0
lines changed

KNN_algorithm

+65
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
from sklearn import datasets
2+
from sklearn.model_selection import train_test_split
3+
from sklearn.neighbors import KNeighborsClassifier
4+
from sklearn.metrics import accuracy_score
5+
from scipy.spatial import distance
6+
7+
def euc(a,b):
8+
return distance.euclidean(a, b)
9+
10+
11+
class ScrappyKNN():
12+
def fit(self, features_train, labels_train):
13+
self.features_train = features_train
14+
self.labels_train = labels_train
15+
16+
def predict(self, features_test):
17+
predictions = []
18+
for item in features_test:
19+
label = self.closest(item)
20+
predictions.append(label)
21+
22+
return predictions
23+
24+
def closest(self, item):
25+
best_dist = euc(item, self.features_train[0])
26+
best_index = 0
27+
for i in range(1,len(self.features_train)):
28+
dist = euc(item, self.features_train[i])
29+
if dist < best_dist:
30+
best_dist = dist
31+
best_index = i
32+
return self.labels_train[best_index]
33+
34+
iris = datasets.load_iris()
35+
36+
print(iris)
37+
38+
features = iris.data
39+
labels = iris.target
40+
41+
print(features)
42+
print(labels)
43+
44+
features_train, features_test, labels_train, labels_test = train_test_split(features, labels, test_size=.5)
45+
#print(len(features))
46+
#print(len(features_train))
47+
48+
my_classifier = ScrappyKNN()
49+
#my_classifier = KNeighborsClassifier()
50+
my_classifier.fit(features_train, labels_train)
51+
52+
prediction = my_classifier.predict(features_test)
53+
54+
print(prediction)
55+
print(accuracy_score(labels_test, prediction))
56+
57+
iris1 = [[7.1, 2.9, 5.3, 2.4]] #virginica
58+
iris_prediction = my_classifier.predict(iris1)
59+
60+
if iris_prediction == 0:
61+
print("Setosa")
62+
if iris_prediction == 1:
63+
print("Versicolor")
64+
if iris_prediction == 2:
65+
print("Virginica")

0 commit comments

Comments
 (0)