-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathFeatureExtractor.py
40 lines (25 loc) · 919 Bytes
/
FeatureExtractor.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
from tensorflow.keras.preprocessing import image
from tensorflow.keras.applications.vgg16 import VGG16, preprocess_input
from tensorflow.keras.models import Model
from pathlib import Path
from PIL import Image
import numpy as np
class FeatureExtractor:
    """Turn images into fixed-length feature vectors with a pretrained VGG16.

    The VGG16 network (ImageNet weights) is truncated at its ``fc1`` layer, so
    each image is mapped to that layer's activations instead of to ImageNet
    class scores.
    """

    def __init__(self):
        base_model = VGG16(weights='imagenet')
        # Expose the activations of the 'fc1' layer rather than the final
        # classifier output; these serve as the image descriptor.
        self.model = Model(
            inputs=base_model.input,
            outputs=base_model.get_layer('fc1').output,
        )
        # The model input is (batch, height, width, channels) while PIL's
        # resize() expects (width, height); deriving it keeps the resize in
        # sync with the network (224x224 for this VGG16).
        height, width = base_model.input.shape[1:3]
        self.target_size = (int(width), int(height))

    def extract_inputs(self, img):
        """Return the ``fc1`` feature vector for a single image.

        Args:
            img: A PIL image, or a path (``str`` / ``pathlib.Path``) to an
                image file, which is opened with Pillow and closed afterwards.

        Returns:
            numpy array of the ``fc1`` activations for this image (the batch
            axis is removed).
        """
        if isinstance(img, (str, Path)):
            with Image.open(img) as opened:
                return self.extract_inputs(opened)

        # Convert *before* resizing: Pillow always uses nearest-neighbour
        # resampling for mode "1" and "P" images, so resizing a palette image
        # first would alias it. Converting also gives grayscale/alpha inputs
        # the three channels the network expects.
        rgb = img.convert('RGB').resize(self.target_size)

        # Add the batch axis the model expects, then apply VGG16's own
        # input preprocessing.
        batch = np.asarray(rgb, dtype=np.float32)[np.newaxis, ...]
        batch = preprocess_input(batch)

        # Calling the model directly is cheaper than predict() for a single
        # small batch and does not print a progress bar on every call.
        features = np.asarray(self.model(batch, training=False))[0]
        return features
# img = Image.open("Database/C01-00001.png")
# FeatureExtractor = FeatureExtractor()
# feeature = FeatureExtractor.extract_inputs(img)
# print(feeature)