Skip to content

Commit 8836271

Browse files
authored
Fixed an issue in multipart S3 client (#5959)
* Fixed an issue in multipart S3 client * Fix integ test * Fix checkstyle
1 parent 1acf8a3 commit 8836271

File tree

3 files changed

+191
-23
lines changed

3 files changed

+191
-23
lines changed

services/s3/src/it/java/software/amazon/awssdk/services/s3/multipart/S3MultipartClientPutObjectIntegrationTest.java

+43-14
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,10 @@
1515

1616
package software.amazon.awssdk.services.s3.multipart;
1717

18+
import static java.util.Base64.getEncoder;
1819
import static java.util.concurrent.TimeUnit.SECONDS;
1920
import static org.assertj.core.api.Assertions.assertThat;
21+
import static org.junit.Assert.assertEquals;
2022
import static software.amazon.awssdk.services.s3.model.ServerSideEncryption.AES256;
2123
import static software.amazon.awssdk.testutils.service.S3BucketUtils.temporaryBucketName;
2224

@@ -25,16 +27,22 @@
2527
import java.io.File;
2628
import java.io.FileInputStream;
2729
import java.io.IOException;
30+
import java.io.InputStream;
31+
import java.io.OutputStream;
2832
import java.nio.ByteBuffer;
2933
import java.nio.charset.Charset;
3034
import java.nio.file.Files;
35+
import java.security.DigestInputStream;
3136
import java.security.MessageDigest;
37+
import java.security.NoSuchAlgorithmException;
3238
import java.security.SecureRandom;
3339
import java.util.Base64;
3440
import java.util.List;
3541
import java.util.Map;
3642
import java.util.Optional;
43+
import java.util.Random;
3744
import java.util.UUID;
45+
import java.util.concurrent.CompletableFuture;
3846
import java.util.zip.CRC32;
3947
import java.util.concurrent.ExecutorService;
4048
import java.util.concurrent.Executors;
@@ -48,22 +56,31 @@
4856
import org.junit.jupiter.api.Timeout;
4957
import org.reactivestreams.Subscriber;
5058
import software.amazon.awssdk.core.ClientType;
59+
import software.amazon.awssdk.core.ResponseBytes;
5160
import software.amazon.awssdk.core.ResponseInputStream;
5261
import software.amazon.awssdk.core.async.AsyncRequestBody;
62+
import software.amazon.awssdk.core.async.AsyncResponseTransformer;
63+
import software.amazon.awssdk.core.async.BlockingInputStreamAsyncRequestBody;
5364
import software.amazon.awssdk.core.interceptor.Context;
5465
import software.amazon.awssdk.core.interceptor.ExecutionAttributes;
5566
import software.amazon.awssdk.core.interceptor.ExecutionInterceptor;
5667
import software.amazon.awssdk.core.internal.async.FileAsyncRequestBody;
5768
import software.amazon.awssdk.core.sync.ResponseTransformer;
5869
import software.amazon.awssdk.services.s3.S3AsyncClient;
70+
import software.amazon.awssdk.services.s3.S3Client;
5971
import software.amazon.awssdk.services.s3.S3IntegrationTestBase;
6072
import software.amazon.awssdk.services.s3.model.ChecksumAlgorithm;
6173
import software.amazon.awssdk.services.s3.model.ChecksumMode;
6274
import software.amazon.awssdk.services.s3.model.CompleteMultipartUploadRequest;
6375
import software.amazon.awssdk.services.s3.model.CreateMultipartUploadRequest;
76+
import software.amazon.awssdk.services.s3.model.GetObjectRequest;
6477
import software.amazon.awssdk.services.s3.model.GetObjectResponse;
78+
import software.amazon.awssdk.services.s3.model.PutObjectResponse;
6579
import software.amazon.awssdk.services.s3.model.UploadPartRequest;
6680
import software.amazon.awssdk.services.s3.utils.ChecksumUtils;
81+
import software.amazon.awssdk.testutils.FileUtils;
82+
import software.amazon.awssdk.testutils.RandomTempFile;
83+
import software.amazon.awssdk.utils.BinaryUtils;
6784
import software.amazon.awssdk.utils.IoUtils;
6885
import software.amazon.awssdk.utils.Md5Utils;
6986

@@ -72,9 +89,8 @@ public class S3MultipartClientPutObjectIntegrationTest extends S3IntegrationTest
7289

7390
private static final String TEST_BUCKET = temporaryBucketName(S3MultipartClientPutObjectIntegrationTest.class);
7491
private static final String TEST_KEY = "testfile.dat";
75-
private static final int OBJ_SIZE = 19 * 1024 * 1024;
92+
private static final int OBJ_SIZE = 1024 * 1024 * 30;
7693
private static final CapturingInterceptor CAPTURING_INTERCEPTOR = new CapturingInterceptor();
77-
private static final byte[] CONTENT = RandomStringUtils.randomAscii(OBJ_SIZE).getBytes(Charset.defaultCharset());
7894
private static File testFile;
7995
private static S3AsyncClient mpuS3Client;
8096
private static ExecutorService executorService = Executors.newFixedThreadPool(2);
@@ -83,8 +99,8 @@ public class S3MultipartClientPutObjectIntegrationTest extends S3IntegrationTest
8399
public static void setup() throws Exception {
84100
setUp();
85101
createBucket(TEST_BUCKET);
86-
testFile = File.createTempFile("SplittingPublisherTest", UUID.randomUUID().toString());
87-
Files.write(testFile.toPath(), CONTENT);
102+
103+
testFile = new RandomTempFile(OBJ_SIZE);
88104
mpuS3Client = S3AsyncClient
89105
.builder()
90106
.region(DEFAULT_REGION)
@@ -108,6 +124,26 @@ public void reset() {
108124
CAPTURING_INTERCEPTOR.reset();
109125
}
110126

127+
@Test
128+
public void upload_blockingInputStream_shouldSucceed() throws IOException {
129+
String objectPath = UUID.randomUUID().toString();
130+
String expectedMd5 = Md5Utils.md5AsBase64(testFile);
131+
132+
BlockingInputStreamAsyncRequestBody body = AsyncRequestBody.forBlockingInputStream(null);
133+
134+
CompletableFuture<PutObjectResponse> put =
135+
mpuS3Client.putObject(req -> req.bucket(TEST_BUCKET).key(objectPath)
136+
.build(), body);
137+
body.writeInputStream(new FileInputStream(testFile));
138+
put.join();
139+
140+
ResponseInputStream<GetObjectResponse> objContent = s3.getObject(r -> r.bucket(TEST_BUCKET).key(objectPath),
141+
ResponseTransformer.toInputStream());
142+
143+
String actualMd5 = BinaryUtils.toBase64(Md5Utils.computeMD5Hash(objContent));
144+
assertEquals(expectedMd5, actualMd5);
145+
}
146+
111147
@Test
112148
void putObject_fileRequestBody_objectSentCorrectly() throws Exception {
113149
AsyncRequestBody body = AsyncRequestBody.fromFile(testFile.toPath());
@@ -127,7 +163,7 @@ void putObject_fileRequestBody_objectSentCorrectly() throws Exception {
127163
@Test
128164
void putObject_inputStreamAsyncRequestBody_objectSentCorrectly() throws Exception {
129165
AsyncRequestBody body = AsyncRequestBody.fromInputStream(
130-
new ByteArrayInputStream(CONTENT),
166+
new FileInputStream(testFile),
131167
Long.valueOf(OBJ_SIZE),
132168
executorService);
133169
mpuS3Client.putObject(r -> r.bucket(TEST_BUCKET)
@@ -193,7 +229,7 @@ public void subscribe(Subscriber<? super ByteBuffer> s) {
193229
@Test
194230
void putObject_withSSECAndChecksum_objectSentCorrectly() throws Exception {
195231
byte[] secretKey = generateSecretKey();
196-
String b64Key = Base64.getEncoder().encodeToString(secretKey);
232+
String b64Key = getEncoder().encodeToString(secretKey);
197233
String b64KeyMd5 = Md5Utils.md5AsBase64(secretKey);
198234

199235
AsyncRequestBody body = AsyncRequestBody.fromFile(testFile.toPath());
@@ -282,17 +318,10 @@ private static String calculateCRC32AsString(String filePath) throws IOException
282318
IoUtils.drainInputStream(cis);
283319
long checksumValue = cis.getChecksum().getValue();
284320
byte[] checksumBytes = ByteBuffer.allocate(4).putInt((int) checksumValue).array();
285-
return Base64.getEncoder().encodeToString(checksumBytes);
321+
return getEncoder().encodeToString(checksumBytes);
286322
}
287323
}
288324

289-
private static String calculateSHA1AsString() throws Exception {
290-
MessageDigest md = MessageDigest.getInstance("SHA-1");
291-
md.update(CONTENT);
292-
byte[] checksum = md.digest();
293-
return Base64.getEncoder().encodeToString(checksum);
294-
}
295-
296325
private static byte[] generateSecretKey() {
297326
KeyGenerator generator;
298327
try {

services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/UploadWithUnknownContentLengthHelper.java

+13-9
Original file line numberDiff line numberDiff line change
@@ -112,7 +112,7 @@ private class UnknownContentLengthAsyncRequestBodySubscriber implements Subscrib
112112

113113
private final AtomicBoolean failureActionInitiated = new AtomicBoolean(false);
114114

115-
private AtomicInteger partNumber = new AtomicInteger(1);
115+
private AtomicInteger partNumber = new AtomicInteger(0);
116116
private AtomicLong contentLength = new AtomicLong(0);
117117

118118
private final Queue<CompletedPart> completedParts = new ConcurrentLinkedQueue<>();
@@ -160,6 +160,7 @@ public void onSubscribe(Subscription s) {
160160

161161
@Override
162162
public void onNext(AsyncRequestBody asyncRequestBody) {
163+
int currentPartNum = partNumber.incrementAndGet();
163164
log.trace(() -> "Received asyncRequestBody " + asyncRequestBody.contentLength());
164165
asyncRequestBodyInFlight.incrementAndGet();
165166

@@ -187,8 +188,8 @@ public void onNext(AsyncRequestBody asyncRequestBody) {
187188
uploadId = createMultipartUploadResponse.uploadId();
188189
log.debug(() -> "Initiated a new multipart upload, uploadId: " + uploadId);
189190

190-
sendUploadPartRequest(uploadId, firstRequestBody);
191-
sendUploadPartRequest(uploadId, asyncRequestBody);
191+
sendUploadPartRequest(uploadId, firstRequestBody, 1);
192+
sendUploadPartRequest(uploadId, asyncRequestBody, 2);
192193

193194
// We need to complete the uploadIdFuture *after* the first two requests have been sent
194195
uploadIdFuture.complete(uploadId);
@@ -197,21 +198,24 @@ public void onNext(AsyncRequestBody asyncRequestBody) {
197198
CompletableFutureUtils.forwardExceptionTo(returnFuture, createMultipartUploadFuture);
198199
} else {
199200
uploadIdFuture.whenComplete((r, t) -> {
200-
sendUploadPartRequest(uploadId, asyncRequestBody);
201+
sendUploadPartRequest(uploadId, asyncRequestBody, currentPartNum);
201202
});
202203
}
203204
}
204205

205-
private void sendUploadPartRequest(String uploadId, AsyncRequestBody asyncRequestBody) {
206+
private void sendUploadPartRequest(String uploadId,
207+
AsyncRequestBody asyncRequestBody,
208+
int currentPartNum) {
206209
Optional<Long> contentLength = asyncRequestBody.contentLength();
207210
if (!contentLength.isPresent()) {
208211
SdkClientException e = SdkClientException.create("Content length must be present on the AsyncRequestBody");
209212
multipartUploadHelper.failRequestsElegantly(futures, e, uploadId, returnFuture, putObjectRequest);
210213
}
211214
this.contentLength.getAndAdd(contentLength.get());
212215

213-
multipartUploadHelper.sendIndividualUploadPartRequest(uploadId, completedParts::add, futures,
214-
uploadPart(asyncRequestBody), progressListener)
216+
multipartUploadHelper
217+
.sendIndividualUploadPartRequest(uploadId, completedParts::add, futures,
218+
uploadPart(asyncRequestBody, currentPartNum), progressListener)
215219
.whenComplete((r, t) -> {
216220
if (t != null) {
217221
if (failureActionInitiated.compareAndSet(false, true)) {
@@ -226,10 +230,10 @@ private void sendUploadPartRequest(String uploadId, AsyncRequestBody asyncReques
226230
};
227231
}
228232

229-
private Pair<UploadPartRequest, AsyncRequestBody> uploadPart(AsyncRequestBody asyncRequestBody) {
233+
private Pair<UploadPartRequest, AsyncRequestBody> uploadPart(AsyncRequestBody asyncRequestBody, int partNum) {
230234
UploadPartRequest uploadRequest =
231235
SdkPojoConversionUtils.toUploadPartRequest(putObjectRequest,
232-
partNumber.getAndIncrement(),
236+
partNum,
233237
uploadId);
234238
return Pair.of(uploadRequest, asyncRequestBody);
235239
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,135 @@
1+
/*
2+
* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License").
5+
* You may not use this file except in compliance with the License.
6+
* A copy of the License is located at
7+
*
8+
* http://aws.amazon.com/apache2.0
9+
*
10+
* or in the "license" file accompanying this file. This file is distributed
11+
* on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
12+
* express or implied. See the License for the specific language governing
13+
* permissions and limitations under the License.
14+
*/
15+
16+
package software.amazon.awssdk.services.s3.internal.multipart;
17+
18+
import static org.assertj.core.api.Assertions.assertThat;
19+
import static org.mockito.Mockito.times;
20+
import static org.mockito.Mockito.verify;
21+
import static software.amazon.awssdk.services.s3.internal.multipart.MpuTestUtils.stubSuccessfulCompleteMultipartCall;
22+
import static software.amazon.awssdk.services.s3.internal.multipart.MpuTestUtils.stubSuccessfulCreateMultipartCall;
23+
import static software.amazon.awssdk.services.s3.internal.multipart.MpuTestUtils.stubSuccessfulUploadPartCalls;
24+
25+
import java.io.FileInputStream;
26+
import java.io.FileNotFoundException;
27+
import java.io.IOException;
28+
import java.util.List;
29+
import java.util.concurrent.CompletableFuture;
30+
import java.util.stream.Collectors;
31+
import java.util.stream.IntStream;
32+
import org.junit.jupiter.api.AfterAll;
33+
import org.junit.jupiter.api.BeforeAll;
34+
import org.junit.jupiter.api.BeforeEach;
35+
import org.junit.jupiter.api.Test;
36+
import org.mockito.ArgumentCaptor;
37+
import org.mockito.Mockito;
38+
import software.amazon.awssdk.core.async.AsyncRequestBody;
39+
import software.amazon.awssdk.core.async.BlockingInputStreamAsyncRequestBody;
40+
import software.amazon.awssdk.services.s3.S3AsyncClient;
41+
import software.amazon.awssdk.services.s3.model.CompleteMultipartUploadRequest;
42+
import software.amazon.awssdk.services.s3.model.CompletedPart;
43+
import software.amazon.awssdk.services.s3.model.PutObjectRequest;
44+
import software.amazon.awssdk.services.s3.model.PutObjectResponse;
45+
import software.amazon.awssdk.services.s3.model.UploadPartRequest;
46+
import software.amazon.awssdk.testutils.RandomTempFile;
47+
48+
public class UploadWithUnknownContentLengthHelperTest {
49+
private static final String BUCKET = "bucket";
50+
private static final String KEY = "key";
51+
private static final String UPLOAD_ID = "1234";
52+
53+
// Should contain 126 parts
54+
private static final long MPU_CONTENT_SIZE = 1005 * 1024;
55+
private static final long PART_SIZE = 8 * 1024;
56+
57+
private UploadWithUnknownContentLengthHelper helper;
58+
private S3AsyncClient s3AsyncClient;
59+
private static RandomTempFile testFile;
60+
61+
@BeforeAll
62+
public static void beforeAll() throws IOException {
63+
testFile = new RandomTempFile("testfile.dat", MPU_CONTENT_SIZE);
64+
}
65+
66+
@AfterAll
67+
public static void afterAll() throws Exception {
68+
testFile.delete();
69+
}
70+
71+
@BeforeEach
72+
public void beforeEach() {
73+
s3AsyncClient = Mockito.mock(S3AsyncClient.class);
74+
helper = new UploadWithUnknownContentLengthHelper(s3AsyncClient, PART_SIZE, PART_SIZE, PART_SIZE * 4);
75+
}
76+
77+
@Test
78+
void upload_blockingInputStream_shouldInOrder() throws FileNotFoundException {
79+
stubSuccessfulCreateMultipartCall(UPLOAD_ID, s3AsyncClient);
80+
stubSuccessfulUploadPartCalls(s3AsyncClient);
81+
stubSuccessfulCompleteMultipartCall(BUCKET, KEY, s3AsyncClient);
82+
83+
BlockingInputStreamAsyncRequestBody body = AsyncRequestBody.forBlockingInputStream(null);
84+
85+
CompletableFuture<PutObjectResponse> future = helper.uploadObject(putObjectRequest(), body);
86+
87+
body.writeInputStream(new FileInputStream(testFile));
88+
89+
future.join();
90+
91+
ArgumentCaptor<UploadPartRequest> requestArgumentCaptor = ArgumentCaptor.forClass(UploadPartRequest.class);
92+
ArgumentCaptor<AsyncRequestBody> requestBodyArgumentCaptor = ArgumentCaptor.forClass(AsyncRequestBody.class);
93+
int numTotalParts = 126;
94+
verify(s3AsyncClient, times(numTotalParts)).uploadPart(requestArgumentCaptor.capture(),
95+
requestBodyArgumentCaptor.capture());
96+
97+
List<UploadPartRequest> actualRequests = requestArgumentCaptor.getAllValues();
98+
List<AsyncRequestBody> actualRequestBodies = requestBodyArgumentCaptor.getAllValues();
99+
assertThat(actualRequestBodies).hasSize(numTotalParts);
100+
assertThat(actualRequests).hasSize(numTotalParts);
101+
102+
for (int i = 0; i < actualRequests.size(); i++) {
103+
UploadPartRequest request = actualRequests.get(i);
104+
AsyncRequestBody requestBody = actualRequestBodies.get(i);
105+
assertThat(request.partNumber()).isEqualTo( i + 1);
106+
assertThat(request.bucket()).isEqualTo(BUCKET);
107+
assertThat(request.key()).isEqualTo(KEY);
108+
109+
if (i == actualRequests.size() - 1) {
110+
assertThat(requestBody.contentLength()).hasValue(5120L);
111+
} else{
112+
assertThat(requestBody.contentLength()).hasValue(PART_SIZE);
113+
}
114+
}
115+
116+
ArgumentCaptor<CompleteMultipartUploadRequest> completeMpuArgumentCaptor = ArgumentCaptor.forClass(CompleteMultipartUploadRequest.class);
117+
verify(s3AsyncClient).completeMultipartUpload(completeMpuArgumentCaptor.capture());
118+
119+
CompleteMultipartUploadRequest actualRequest = completeMpuArgumentCaptor.getValue();
120+
assertThat(actualRequest.multipartUpload().parts()).isEqualTo(completedParts(numTotalParts));
121+
122+
}
123+
124+
private static PutObjectRequest putObjectRequest() {
125+
return PutObjectRequest.builder()
126+
.bucket(BUCKET)
127+
.key(KEY)
128+
.build();
129+
}
130+
131+
private List<CompletedPart> completedParts(int totalNumParts) {
132+
return IntStream.range(1, totalNumParts + 1).mapToObj(i -> CompletedPart.builder().partNumber(i).build()).collect(Collectors.toList());
133+
}
134+
135+
}

0 commit comments

Comments
 (0)