-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathutil_train.py
58 lines (47 loc) · 2.41 KB
/
util_train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
import sagemaker
from sagemaker.pytorch import PyTorch
from datetime import datetime
def aug_exp_train(model_arch, batch_size, aug_operator, aug_load_factor, instance_type, sm_role):
# Amazon S3 bucket and prefix for fetching and storing data:
BUCKET = 'sm-aug-test'
# Full size download of https://github.com/fastai/imagenette
# 1.3GB — 13,395 (9469 train, 3925 val images) from 10 classes
train_data_s3 = 's3://{}/{}'.format(BUCKET, 'imagenette2')
src_code_s3 = 's3://{}/{}'.format(BUCKET, 'training_jobs')
training_job_output_s3 = 's3://{}/{}'.format(BUCKET, 'training_jobs_output')
# Encapsulate training on SageMaker with PyTorch:
train_estimator = PyTorch(entry_point='sm_augmentation_train-script.py',
source_dir='./src',
role=sm_role,
framework_version='1.8.1',
py_version='py3',
debugger_hook_config=False,
instance_count=1,
instance_type=instance_type,
output_path=training_job_output_s3,
code_location=src_code_s3,
hyperparameters={'epochs': 2,
'backend': 'nccl',
'model-type': model_arch,
'lr': 0.001,
'batch-size': batch_size,
'aug': aug_operator,
'aug-load-factor': aug_load_factor
}
)
# Setting up File-system to import data from S3
data_channels = {'train': sagemaker.inputs.TrainingInput(
s3_data_type='S3Prefix',
s3_data=train_data_s3,
content_type='image/jpeg',
input_mode='File'),
'val': sagemaker.inputs.TrainingInput(
s3_data_type='S3Prefix',
s3_data=train_data_s3,
content_type='image/jpeg',
input_mode='File')
}
# Launching SageMaker training job
train_job_id = 'sm-aug-' + str(datetime.now().strftime("%H-%M-%S"))
train_estimator.fit(inputs=data_channels, job_name=train_job_id)
return train_job_id, train_estimator