-
Notifications
You must be signed in to change notification settings - Fork 4.6k
/
linear_discriminant_analysis.py
43 lines (36 loc) · 1.36 KB
/
linear_discriminant_analysis.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
from __future__ import print_function, division
import numpy as np
from mlfromscratch.utils import calculate_covariance_matrix, normalize, standardize
class LDA():
    """Linear Discriminant Analysis classifier, also known as Fisher's linear
    discriminant.

    Besides binary classification, it can also be used to reduce the
    dimensionality of a dataset by projecting it onto the discriminant vector.

    Only two classes are handled: ``fit`` uses the samples labelled 0 and the
    samples labelled 1; samples with any other label are ignored.
    """
    def __init__(self):
        # Discriminant vector; None until fit() is called.
        self.w = None
        # Decision threshold on the projection x . w. fit() sets it to the
        # midpoint of the two projected class means; 0.0 until then.
        self.threshold = 0.0

    def transform(self, X, y):
        """Fit on (X, y), then return X projected onto the discriminant vector.

        The result is ``X.dot(w)``.
        """
        self.fit(X, y)
        # Project data onto vector
        X_transform = X.dot(self.w)
        return X_transform

    def fit(self, X, y):
        """Compute the discriminant vector ``w`` and the decision ``threshold``.

        Parameters
        ----------
        X : numpy array of samples (boolean-mask indexed by ``y``; ``mean(0)``
            is taken per column).
        y : numpy array of labels aligned with ``X``; labels 0 and 1 are used.

        Raises
        ------
        ValueError
            If ``y`` contains no sample of class 0 or no sample of class 1.
        """
        # Separate data by class
        X1 = X[y == 0]
        X2 = X[y == 1]
        # Without both classes, means/covariances are undefined (NaN); fail loudly.
        if len(X1) == 0 or len(X2) == 0:
            raise ValueError("LDA.fit requires samples of both class 0 and class 1")

        # Pooled (summed) within-class covariance of the two datasets
        cov_tot = calculate_covariance_matrix(X1) + calculate_covariance_matrix(X2)

        # Mean of each class
        mean1 = X1.mean(0)
        mean2 = X2.mean(0)
        mean_diff = np.atleast_1d(mean1 - mean2)

        # Direction that best separates the classes when X is projected onto it:
        # w = (cov1 + cov2)^-1 (mean1 - mean2). The pseudo-inverse is used so a
        # singular covariance matrix does not raise.
        self.w = np.linalg.pinv(cov_tot).dot(mean_diff)

        # Decision boundary halfway between the projected class means. Class 0
        # projects onto the high side of it (mean1.w - mean2.w equals
        # d^T pinv(S) d, non-negative assuming the helper returns a symmetric
        # PSD covariance).
        self.threshold = 0.5 * np.dot(np.atleast_1d(mean1 + mean2), self.w)

    def predict(self, X):
        """Label each sample of X as 0 or 1.

        A sample is labelled 1 when its projection onto ``w`` falls below
        ``threshold``, and 0 otherwise. Returns a list with one label per sample.
        """
        return [1 * (np.dot(sample, self.w) < self.threshold) for sample in X]