
Commit cd95e29

Committed on Apr 1, 2022
Add Aws Batch native retry on spot reclaim
1 parent: 0044351

File tree: 4 files changed (+38 -63 lines)

docs/config.rst (+1)
@@ -163,6 +163,7 @@ volumes One or more container mounts. Mounts can be specifie
 delayBetweenAttempts        Delay between download attempts from S3 (default `10 sec`).
 maxParallelTransfers        Max parallel upload/download transfer operations *per job* (default: ``4``).
 maxTransferAttempts         Max number of downloads attempts from S3 (default: `1`).
+maxSpotAttempts             Max number of execution attempts of a job interrupted by a EC2 spot reclaim event (default: ``5``, requires ``22.04.0`` or later)
 =========================== ================

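Note: the new setting lives in the `aws.batch` scope of the Nextflow configuration. A minimal, illustrative nextflow.config snippet (the queue name and region are placeholders, not values from this commit):

    process.executor = 'awsbatch'
    process.queue    = 'my-batch-queue'        // placeholder queue name
    aws.region = 'eu-west-1'                   // placeholder region
    aws.batch.maxSpotAttempts = 5              // attempts after an EC2 spot reclaim; 0 disables the Batch-native retry
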
plugins/nf-amazon/src/main/nextflow/cloud/aws/batch/AwsBatchTaskHandler.groovy (+22 -35)
@@ -17,11 +17,10 @@
 
 package nextflow.cloud.aws.batch
 
-import static AwsContainerOptionsMapper.*
+import static nextflow.cloud.aws.batch.AwsContainerOptionsMapper.*
 
 import java.nio.file.Path
 import java.nio.file.Paths
-import java.util.regex.Pattern
 
 import com.amazonaws.services.batch.AWSBatch
 import com.amazonaws.services.batch.model.AWSBatchException
@@ -33,6 +32,7 @@ import com.amazonaws.services.batch.model.DescribeJobDefinitionsRequest
 import com.amazonaws.services.batch.model.DescribeJobDefinitionsResult
 import com.amazonaws.services.batch.model.DescribeJobsRequest
 import com.amazonaws.services.batch.model.DescribeJobsResult
+import com.amazonaws.services.batch.model.EvaluateOnExit
 import com.amazonaws.services.batch.model.Host
 import com.amazonaws.services.batch.model.JobDefinition
 import com.amazonaws.services.batch.model.JobDefinitionType
@@ -53,14 +53,12 @@ import groovy.transform.CompileStatic
 import groovy.util.logging.Slf4j
 import nextflow.cloud.types.CloudMachineInfo
 import nextflow.container.ContainerNameValidator
-import nextflow.exception.NodeTerminationException
 import nextflow.exception.ProcessSubmitException
 import nextflow.exception.ProcessUnrecoverableException
 import nextflow.executor.BashWrapperBuilder
 import nextflow.executor.res.AcceleratorResource
 import nextflow.processor.BatchContext
 import nextflow.processor.BatchHandler
-import nextflow.processor.ErrorStrategy
 import nextflow.processor.TaskBean
 import nextflow.processor.TaskHandler
 import nextflow.processor.TaskRun
@@ -74,8 +72,6 @@ import nextflow.util.CacheHelper
 @Slf4j
 class AwsBatchTaskHandler extends TaskHandler implements BatchHandler<String,JobDetail> {
 
-    private static Pattern TERMINATED = ~/^Host EC2 .* terminated.*/
-
     private final Path exitFile
 
     private final Path wrapperFile
@@ -108,8 +104,6 @@ class AwsBatchTaskHandler extends TaskHandler implements BatchHandler<String,Job
 
     private Map<String,String> environment
 
-    private boolean batchNativeRetry
-
     final static private Map<String,String> jobDefinitions = [:]
 
     /**
@@ -256,23 +250,15 @@ class AwsBatchTaskHandler extends TaskHandler implements BatchHandler<String,Job
         final job = describeJob(jobId)
         final done = job?.status in ['SUCCEEDED', 'FAILED']
         if( done ) {
-            if( !batchNativeRetry && TERMINATED.matcher(job.statusReason).find() ) {
-                // kee track of the node termination error
-                task.error = new NodeTerminationException(job.statusReason)
-                // mark the task as ABORTED since thr failure is caused by a node failure
-                task.aborted = true
+            // finalize the task
+            task.exitStatus = readExitFile()
+            task.stdout = outputFile
+            if( job?.status == 'FAILED' ) {
+                task.error = new ProcessUnrecoverableException(errReason(job))
+                task.stderr = executor.getJobOutputStream(jobId) ?: errorFile
             }
             else {
-                // finalize the task
-                task.exitStatus = readExitFile()
-                task.stdout = outputFile
-                if( job?.status == 'FAILED' ) {
-                    task.error = new ProcessUnrecoverableException(errReason(job))
-                    task.stderr = executor.getJobOutputStream(jobId) ?: errorFile
-                }
-                else {
-                    task.stderr = errorFile
-                }
+                task.stderr = errorFile
             }
             status = TaskStatus.COMPLETED
             return true
@@ -620,6 +606,10 @@ class AwsBatchTaskHandler extends TaskHandler implements BatchHandler<String,Job
         return ['bash','-o','pipefail','-c', cmd.toString() ]
     }
 
+    protected maxSpotAttempts() {
+        return executor.awsOptions.maxSpotAttempts
+    }
+
     /**
      * Create a new Batch job request for the given NF {@link TaskRun}
      *
@@ -636,19 +626,16 @@ class AwsBatchTaskHandler extends TaskHandler implements BatchHandler<String,Job
         result.setJobQueue(getJobQueue(task))
         result.setJobDefinition(getJobDefinition(task))
 
-        // -- NF uses `maxRetries` *only* if `retry` error strategy is specified
-        //    otherwise delegates the the retry to AWS Batch
-        // -- NOTE: make sure the `errorStrategy` is a static value before invoking `getMaxRetries` and `getErrorStrategy`
-        //    when the errorStrategy is closure (ie. dynamic evaluated) value, the `task.config.getMaxRetries() && task.config.getErrorStrategy()`
-        //    condition should not be evaluated because otherwise the closure value is cached using the wrong task.attempt and task.exitStatus values.
-        // -- use of `config.getRawValue('errorStrategy')` instead of `config.getErrorStrategy()` to prevent the resolution
-        //    of values dynamic values i.e. closures
-        final strategy = task.config.getRawValue('errorStrategy')
-        final canCheck = strategy == null || strategy instanceof CharSequence
-        if( canCheck && task.config.getMaxRetries() && task.config.getErrorStrategy() != ErrorStrategy.RETRY ) {
-            def retry = new RetryStrategy().withAttempts( task.config.getMaxRetries()+1 )
+        /*
+         * retry on spot reclaim
+         * https://aws.amazon.com/blogs/compute/introducing-retry-strategies-for-aws-batch/
+         */
+        final attempts = maxSpotAttempts()
+        if( attempts>0 ) {
+            final retry = new RetryStrategy()
+                    .withAttempts( attempts )
+                    .withEvaluateOnExit( new EvaluateOnExit().withOnReason('Host EC2*').withAction('RETRY') )
             result.setRetryStrategy(retry)
-            this.batchNativeRetry = true
         }
 
         // set task timeout

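Net effect of the handler changes: recovery from an EC2 spot reclaim is delegated to AWS Batch's native retry policy rather than being detected and re-run by Nextflow. A minimal, standalone Groovy sketch of the same submission pattern, using the AWS SDK v1 classes referenced above (the job name, queue and job definition are placeholders, not values from this commit):

    import com.amazonaws.services.batch.AWSBatchClientBuilder
    import com.amazonaws.services.batch.model.EvaluateOnExit
    import com.amazonaws.services.batch.model.RetryStrategy
    import com.amazonaws.services.batch.model.SubmitJobRequest

    int attempts = 5   // e.g. the value resolved from aws.batch.maxSpotAttempts

    def request = new SubmitJobRequest()
            .withJobName('nf-task')               // placeholder job name
            .withJobQueue('my-queue')             // placeholder queue
            .withJobDefinition('my-job-def:1')    // placeholder job definition

    if( attempts > 0 ) {
        // retry only when the job's status reason matches an EC2 spot reclaim,
        // mirroring the EvaluateOnExit condition set by the handler above
        request.setRetryStrategy( new RetryStrategy()
                .withAttempts(attempts)
                .withEvaluateOnExit( new EvaluateOnExit().withOnReason('Host EC2*').withAction('RETRY') ) )
    }

    // requires AWS credentials and region to be configured in the environment
    AWSBatchClientBuilder.defaultClient().submitJob(request)
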
plugins/nf-amazon/src/main/nextflow/cloud/aws/batch/AwsOptions.groovy (+5)
@@ -41,6 +41,8 @@ class AwsOptions implements CloudTransferOptions {
 
     public static final int DEFAULT_AWS_MAX_ATTEMPTS = 5
 
+    public static final int DEFAULT_MAX_SPOT_ATTEMPTS = 5
+
     private Map<String,String> env = System.getenv()
 
     String cliPath
@@ -61,6 +63,8 @@ class AwsOptions implements CloudTransferOptions {
 
     String retryMode
 
+    int maxSpotAttempts
+
     volatile Boolean fetchInstanceType
 
     /**
@@ -93,6 +97,7 @@ class AwsOptions implements CloudTransferOptions {
         maxParallelTransfers = session.config.navigate('aws.batch.maxParallelTransfers', MAX_TRANSFER) as int
         maxTransferAttempts = session.config.navigate('aws.batch.maxTransferAttempts', defaultMaxTransferAttempts()) as int
         delayBetweenAttempts = session.config.navigate('aws.batch.delayBetweenAttempts', DEFAULT_DELAY_BETWEEN_ATTEMPTS) as Duration
+        maxSpotAttempts = session.config.navigate('aws.batch.maxSpotAttempts', DEFAULT_MAX_SPOT_ATTEMPTS) as int
         region = session.config.navigate('aws.region') as String
         volumes = makeVols(session.config.navigate('aws.batch.volumes'))
         jobRole = session.config.navigate('aws.batch.jobRole')

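As with the other transfer settings, the value falls back to the constant default when `aws.batch.maxSpotAttempts` is absent from the configuration. An illustrative Groovy sketch of that resolution using a plain map in place of the Nextflow session (the helper below is hypothetical, not part of the plugin):

    // hypothetical helper, illustrating "configured value or default" resolution
    static int resolveMaxSpotAttempts(Map config, int defaultValue = 5) {
        def value = config?.aws?.batch?.maxSpotAttempts
        return value != null ? value as int : defaultValue
    }

    assert resolveMaxSpotAttempts([aws: [batch: [maxSpotAttempts: 3]]]) == 3
    assert resolveMaxSpotAttempts([:]) == 5   // falls back to DEFAULT_MAX_SPOT_ATTEMPTS
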
plugins/nf-amazon/src/test/nextflow/cloud/aws/batch/AwsBatchTaskHandlerTest.groovy (+10 -28)
@@ -25,6 +25,7 @@ import com.amazonaws.services.batch.model.DescribeJobDefinitionsRequest
 import com.amazonaws.services.batch.model.DescribeJobDefinitionsResult
 import com.amazonaws.services.batch.model.DescribeJobsRequest
 import com.amazonaws.services.batch.model.DescribeJobsResult
+import com.amazonaws.services.batch.model.EvaluateOnExit
 import com.amazonaws.services.batch.model.JobDefinition
 import com.amazonaws.services.batch.model.JobDetail
 import com.amazonaws.services.batch.model.KeyValuePair
@@ -36,7 +37,6 @@ import com.amazonaws.services.batch.model.SubmitJobResult
 import com.amazonaws.services.batch.model.TerminateJobRequest
 import nextflow.cloud.types.CloudMachineInfo
 import nextflow.cloud.types.PriceModel
-import nextflow.exception.NodeTerminationException
 import nextflow.exception.ProcessUnrecoverableException
 import nextflow.executor.Executor
 import nextflow.processor.BatchContext
@@ -84,6 +84,7 @@ class AwsBatchTaskHandlerTest extends Specification {
         when:
         def req = handler.newSubmitRequest(task)
         then:
+        1 * handler.maxSpotAttempts() >> 5
         1 * handler.getAwsOptions() >> { new AwsOptions(cliPath: '/bin/aws') }
         1 * handler.getJobQueue(task) >> 'queue1'
         1 * handler.getJobDefinition(task) >> 'job-def:1'
@@ -98,11 +99,12 @@ class AwsBatchTaskHandlerTest extends Specification {
         req.getContainerOverrides().getResourceRequirements().find { it.type=='MEMORY'}.getValue() == '8192'
         req.getContainerOverrides().getEnvironment() == [VAR_FOO, VAR_BAR]
         req.getContainerOverrides().getCommand() == ['bash', '-o','pipefail','-c', "trap \"{ ret=\$?; /bin/aws s3 cp --only-show-errors .command.log s3://bucket/test/.command.log||true; exit \$ret; }\" EXIT; /bin/aws s3 cp --only-show-errors s3://bucket/test/.command.run - | bash 2>&1 | tee .command.log".toString()]
-        req.getRetryStrategy() == null // <-- retry is managed by NF, hence this must be null
+        req.getRetryStrategy() == new RetryStrategy().withAttempts(5).withEvaluateOnExit( new EvaluateOnExit().withAction('RETRY').withOnReason('Host EC2*') )
 
         when:
         req = handler.newSubmitRequest(task)
         then:
+        1 * handler.maxSpotAttempts() >> 0
         1 * handler.getAwsOptions() >> { new AwsOptions(cliPath: '/bin/aws', region: 'eu-west-1') }
         1 * handler.getJobQueue(task) >> 'queue1'
         1 * handler.getJobDefinition(task) >> 'job-def:1'
@@ -135,6 +137,7 @@ class AwsBatchTaskHandlerTest extends Specification {
         then:
         handler.getAwsOptions() >> { new AwsOptions(cliPath: '/bin/aws', region: 'eu-west-1') }
         and:
+        1 * handler.maxSpotAttempts() >> 0
         1 * handler.getJobQueue(task) >> 'queue1'
         1 * handler.getJobDefinition(task) >> 'job-def:1'
         and:
@@ -160,6 +163,7 @@ class AwsBatchTaskHandlerTest extends Specification {
         task.getConfig() >> new TaskConfig()
         handler.getAwsOptions() >> { new AwsOptions(cliPath: '/bin/aws') }
         and:
+        1 * handler.maxSpotAttempts() >> 0
         1 * handler.getJobQueue(task) >> 'queue1'
         1 * handler.getJobDefinition(task) >> 'job-def:1'
         and:
@@ -176,6 +180,7 @@ class AwsBatchTaskHandlerTest extends Specification {
         task.getConfig() >> new TaskConfig(time: '5 sec')
         handler.getAwsOptions() >> { new AwsOptions(cliPath: '/bin/aws') }
         and:
+        1 * handler.maxSpotAttempts() >> 0
         1 * handler.getJobQueue(task) >> 'queue2'
         1 * handler.getJobDefinition(task) >> 'job-def:2'
         and:
@@ -193,6 +198,7 @@ class AwsBatchTaskHandlerTest extends Specification {
         task.getConfig() >> new TaskConfig(time: '1 hour')
         handler.getAwsOptions() >> { new AwsOptions(cliPath: '/bin/aws') }
         and:
+        1 * handler.maxSpotAttempts() >> 0
         1 * handler.getJobQueue(task) >> 'queue3'
         1 * handler.getJobDefinition(task) >> 'job-def:3'
         and:
@@ -221,6 +227,7 @@ class AwsBatchTaskHandlerTest extends Specification {
         then:
         handler.getAwsOptions() >> { new AwsOptions(cliPath: '/bin/aws', retryMode: 'adaptive', maxTransferAttempts: 10) }
         and:
+        1 * handler.maxSpotAttempts() >> 3
         1 * handler.getJobQueue(task) >> 'queue1'
         1 * handler.getJobDefinition(task) >> 'job-def:1'
         1 * handler.wrapperFile >> Paths.get('/bucket/test/.command.run')
@@ -230,7 +237,7 @@ class AwsBatchTaskHandlerTest extends Specification {
         req.getJobQueue() == 'queue1'
         req.getJobDefinition() == 'job-def:1'
         // no error `retry` error strategy is defined by NF, use `maxRetries` to se Batch attempts
-        req.getRetryStrategy() == new RetryStrategy().withAttempts(3)
+        req.getRetryStrategy() == new RetryStrategy().withAttempts(3).withEvaluateOnExit( new EvaluateOnExit().withAction('RETRY').withOnReason('Host EC2*') )
         req.getContainerOverrides().getEnvironment() == [VAR_RETRY_MODE, VAR_MAX_ATTEMPTS, VAR_METADATA_ATTEMPTS]
     }
@@ -727,29 +734,4 @@ class AwsBatchTaskHandlerTest extends Specification {
         trace.machineInfo.priceModel == PriceModel.spot
     }
 
-    def 'should check spot termination' () {
-        given:
-        def JOB_ID = 'job-2'
-        def client = Mock(AWSBatch)
-        def task = new TaskRun()
-        def handler = Spy(AwsBatchTaskHandler)
-        handler.client = client
-        handler.jobId = JOB_ID
-        handler.task = task
-        and:
-        handler.isRunning() >> true
-        handler.describeJob(JOB_ID) >> Mock(JobDetail) {
-            getStatus() >> 'FAILED'
-            getStatusReason() >> "Host EC2 (instance i-0e2d5c2edc932b4e8) terminated."
-        }
-
-        when:
-        def done = handler.checkIfCompleted()
-        then:
-        task.aborted
-        task.error instanceof NodeTerminationException
-        and:
-        done == true
-
-    }
 }

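A note on the updated assertions: they compare whole RetryStrategy objects, which works because the AWS SDK model classes implement value-based equals(). An illustrative, standalone check (not part of the commit):

    import com.amazonaws.services.batch.model.EvaluateOnExit
    import com.amazonaws.services.batch.model.RetryStrategy

    def expected = new RetryStrategy().withAttempts(5)
            .withEvaluateOnExit( new EvaluateOnExit().withAction('RETRY').withOnReason('Host EC2*') )
    def actual = new RetryStrategy().withAttempts(5)
            .withEvaluateOnExit( new EvaluateOnExit().withOnReason('Host EC2*').withAction('RETRY') )

    assert expected == actual   // wither call order does not matter, only the field values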