-
Notifications
You must be signed in to change notification settings - Fork 299
/
Copy pathrun.lua
61 lines (51 loc) · 2.2 KB
/
run.lua
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
----------------------------------------------------------------------
-- Train a ConvNet on faces.
--
-- original: Clement Farabet
-- new version by: E. Culurciello
-- Mon Oct 14 14:58:50 EDT 2013
----------------------------------------------------------------------
require 'pl'
require 'trepl'
require 'torch' -- torch
require 'image' -- to visualize the dataset
require 'nn' -- provides all sorts of trainable modules/layers
----------------------------------------------------------------------
-- Command-line options, parsed with Penlight's lapp into `opt`.
--
-- `opt` is deliberately left global: this file never passes it to the
-- data/train/test modules, so presumably they read the global directly
-- (TODO confirm before making it local).
--
-- Notes on the option spec below:
--   * Short flags must be unique. `-d` used to be declared for both
--     --learningRateDecay and --dropout. lapp keeps one entry per short
--     alias, so the later declaration (--dropout) won and
--     --learningRateDecay had no working short form. `-d` keeps meaning
--     --dropout; --learningRateDecay now also answers to `-l`.
--   * NOTE(review): lapp infers a string type for `(default true)`, so
--     opt.visualize is the string 'true', and `--visualize false` yields
--     the (truthy) string 'false'. Check how consumers test this flag.
----------------------------------------------------------------------
print(sys.COLORS.red .. '==> processing options')
opt = lapp[[
   -r,--learningRate       (default 1e-3)        learning rate
   -l,--learningRateDecay  (default 1e-7)        learning rate decay (in # samples)
   -w,--weightDecay        (default 1e-5)        L2 penalty on the weights
   -m,--momentum           (default 0.1)         momentum
   -d,--dropout            (default 0.5)         dropout amount
   -b,--batchSize          (default 128)         batch size
   -t,--threads            (default 8)           number of threads
   -p,--type               (default float)       float or cuda
   -i,--devid              (default 1)           device ID (if using CUDA)
   -s,--size               (default small)       dataset: small or full or extra
   -o,--save               (default results)     save directory
      --patches            (default all)         percentage of samples to use for testing
      --visualize          (default true)        visualize dataset
]]
-- Runtime configuration: worker thread count from the CLI, a constant
-- RNG seed so repeated runs are reproducible, and float tensors as the
-- default tensor type.
torch.setnumthreads(opt.threads)
torch.manualSeed(1)
torch.setdefaulttensortype('torch.FloatTensor')

-- Optional GPU backend, selected with `--type cuda`.
-- `cutorch` is not required explicitly here; it is expected to come into
-- scope through the 'cunn' require.
local useCuda = (opt.type == 'cuda')
if useCuda then
   print(sys.COLORS.red .. '==> switching to CUDA')
   require 'cunn'
   cutorch.setDevice(opt.devid)
   -- Read the device back so the log shows what is actually active.
   local activeGpu = cutorch.getDevice()
   print(sys.COLORS.red .. '==> using GPU #' .. activeGpu)
end
----------------------------------------------------------------------
-- Load the project modules. 'data' yields a table exposing
-- `trainData` and `testData`; 'train' and 'test' each yield a callable.
print(sys.COLORS.red .. '==> load modules')
local datasets = require 'data'
local doTrain = require 'train'
local doTest = require 'test'
----------------------------------------------------------------------
-- Alternate one training call and one test call indefinitely. There is
-- no exit condition, so the process runs until it is interrupted.
print(sys.COLORS.red .. '==> training!')
repeat
   doTrain(datasets.trainData)
   doTest(datasets.testData)
until false