-
Notifications
You must be signed in to change notification settings - Fork 9
/
train.php
72 lines (57 loc) · 1.88 KB
/
train.php
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
<?php
include __DIR__ . '/vendor/autoload.php';
use Rubix\ML\Loggers\Screen;
use Rubix\ML\Datasets\Labeled;
use Rubix\ML\PersistentModel;
use Rubix\ML\Pipeline;
use Rubix\ML\Transformers\ImageResizer;
use Rubix\ML\Transformers\ImageVectorizer;
use Rubix\ML\Transformers\ZScaleStandardizer;
use Rubix\ML\Classifiers\MultilayerPerceptron;
use Rubix\ML\NeuralNet\Layers\Dense;
use Rubix\ML\NeuralNet\Layers\Activation;
use Rubix\ML\NeuralNet\Layers\Dropout;
use Rubix\ML\NeuralNet\Layers\BatchNorm;
use Rubix\ML\NeuralNet\ActivationFunctions\ELU;
use Rubix\ML\NeuralNet\Optimizers\Adam;
use Rubix\ML\Persisters\Filesystem;
use Rubix\ML\Extractors\CSV;
// Lift PHP's memory cap: every training image is decoded into memory at once.
ini_set('memory_limit', '-1');

$logger = new Screen();

$logger->info('Loading data into memory');

// Training images are expected to be named "<digits>_<label>.png", e.g. "123_cat.png".
// NOTE(review): this path is resolved relative to the current working directory,
// not __DIR__ like the autoloader include — run the script from its own folder.
$files = glob('train/*.png');

// glob() returns false on error; an empty list means there is nothing to train on.
if ($files === false || $files === []) {
    throw new RuntimeException('No training images found matching train/*.png');
}

$samples = $labels = [];

foreach ($files as $file) {
    $image = imagecreatefrompng($file);

    // imagecreatefrompng() returns false for unreadable/corrupt files; never let
    // a boolean slip into the dataset as if it were an image sample.
    if ($image === false) {
        throw new RuntimeException("Could not decode PNG image: {$file}");
    }

    $samples[] = [$image];

    // Strip the numeric prefix and the literal ".png" extension, leaving the class label.
    $labels[] = preg_replace('/^[0-9]+_(.*)\.png$/', '$1', basename($file));
}

$dataset = new Labeled($samples, $labels);
// Preprocessing chain applied to every sample before it reaches the network:
// resize to 32x32, turn the image into a feature vector, then standardize features.
$transformers = [
    new ImageResizer(32, 32),
    new ImageVectorizer(),
    new ZScaleStandardizer(),
];

// Hidden layers of the multilayer perceptron, in order.
$hiddenLayers = [
    new Dense(200),
    new Activation(new ELU()),
    new Dropout(0.5),

    new Dense(200),
    new Activation(new ELU()),
    new Dropout(0.5),

    // NOTE(review): the trailing `false` presumably disables this layer's bias,
    // since a BatchNorm layer follows it — confirm against the Rubix ML Dense signature.
    new Dense(100, 0.0, false),
    new BatchNorm(),
    new Activation(new ELU()),

    new Dense(100),
    new Activation(new ELU()),

    new Dense(50),
    new Activation(new ELU()),
];

// NOTE(review): 256 is presumably the mini-batch size and 0.0005 the Adam
// learning rate — confirm against the Rubix ML MultilayerPerceptron/Adam docs.
$network = new MultilayerPerceptron($hiddenLayers, 256, new Adam(0.0005));

// Wrap the preprocessing + network so the trained model can be persisted to
// cifar10.rbx (the `true` flag's meaning is defined by Rubix ML's Filesystem persister).
$estimator = new PersistentModel(
    new Pipeline($transformers, $network),
    new Filesystem('cifar10.rbx', true)
);
// Route training progress messages to the screen logger, then fit the model.
$estimator->setLogger($logger);

$estimator->train($dataset);

// Write the training history returned by steps() to progress.csv
// (the `true` flag is passed to Rubix ML's CSV extractor; presumably it enables a header row).
// NOTE(review): depending on the Rubix ML version, export() may refuse to overwrite an
// existing progress.csv, which would abort here before the model can be saved — verify.
$progressFile = new CSV('progress.csv', true);

$progressFile->export($estimator->steps());

$logger->info('Progress saved to progress.csv');

// Persist the trained model only on an explicit "y"; anything else (including a
// bare Enter) keeps the default of not saving.
$answer = strtolower(trim(readline('Save this model? (y|[n]): ')));

if ($answer === 'y') {
    $estimator->save();
}