-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtrain.py
40 lines (29 loc) · 788 Bytes
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
import numpy as np
import tensorflow as tf
from tensorflow.keras import datasets, layers, models, regularizers
from tensorflow.keras.preprocessing.image import ImageDataGenerator
import cv2
import matplotlib.pyplot as plt
import classify
# Fine-tuning script: take the model returned by the project's `classify`
# module, train it on augmented images read from `base_path`, and write the
# result to `new_model.h5`.
model = classify.load_model()
base_path = './dataset/'

# Augmentation applied on the fly to every training image: pixel values are
# scaled by 1/255, and random shear, zoom, shifts and flips are applied.
# validation_split=0 holds nothing back for validation, so the 'training'
# subset requested below should be the complete dataset.
train_datagen = ImageDataGenerator(
    rescale=1./255,
    shear_range=0.1,
    zoom_range=0.1,
    width_shift_range=0.1,
    height_shift_range=0.1,
    horizontal_flip=True,
    vertical_flip=True,
    validation_split=0
)

# Stream batches of 16 images, resized to 300x300, from `base_path`.
# NOTE(review): flow_from_directory presumably expects one sub-directory per
# class under `base_path` -- confirm the dataset layout matches.
# The fixed seed keeps shuffling and augmentation reproducible between runs.
train_generator = train_datagen.flow_from_directory(
    base_path,
    target_size=(300, 300),
    batch_size=16,
    class_mode='categorical',
    subset='training',
    seed=0
)

# `Model.fit_generator` is deprecated and has been removed from recent
# TensorFlow/Keras releases; `Model.fit` accepts the generator directly.
model.fit(train_generator, epochs=25)

# The `.h5` extension selects the HDF5 single-file format.
model.save('new_model.h5')