Skip to content

Commit 5328641

Browse files
committed
Refactor/fix S3 express integrations
1 parent 2761bf9 commit 5328641

File tree

3 files changed

+126
-41
lines changed

3 files changed

+126
-41
lines changed

codegen/aws-sdk-codegen/src/main/kotlin/aws/sdk/kotlin/codegen/customization/s3/express/S3ExpressIntegration.kt

+1-34
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,6 @@ import software.amazon.smithy.kotlin.codegen.rendering.protocol.ProtocolGenerato
1616
import software.amazon.smithy.kotlin.codegen.rendering.protocol.ProtocolMiddleware
1717
import software.amazon.smithy.kotlin.codegen.rendering.util.ConfigProperty
1818
import software.amazon.smithy.kotlin.codegen.rendering.util.ConfigPropertyType
19-
import software.amazon.smithy.kotlin.codegen.utils.dq
20-
import software.amazon.smithy.kotlin.codegen.utils.getOrNull
2119
import software.amazon.smithy.model.Model
2220
import software.amazon.smithy.model.shapes.*
2321
import software.amazon.smithy.model.traits.*
@@ -99,7 +97,6 @@ class S3ExpressIntegration : KotlinIntegration {
9997
resolved + listOf(
10098
addClientToExecutionContext,
10199
addBucketToExecutionContext,
102-
useCrc32Checksum,
103100
uploadPartDisableChecksum,
104101
)
105102

@@ -132,44 +129,14 @@ class S3ExpressIntegration : KotlinIntegration {
132129
}
133130
}
134131

135-
/**
136-
* For any operations that require a checksum, set CRC32 if the user has not already configured a checksum.
137-
*/
138-
private val useCrc32Checksum = object : ProtocolMiddleware {
139-
override val name: String = "UseCrc32Checksum"
140-
141-
override val order: Byte = -1 // Render before flexible checksums
142-
143-
override fun isEnabledFor(ctx: ProtocolGenerator.GenerationContext, op: OperationShape): Boolean = !op.isS3UploadPart &&
144-
(op.hasTrait<HttpChecksumRequiredTrait>() || (op.hasTrait<HttpChecksumTrait>() && op.expectTrait<HttpChecksumTrait>().isRequestChecksumRequired))
145-
146-
override fun render(ctx: ProtocolGenerator.GenerationContext, op: OperationShape, writer: KotlinWriter) {
147-
val interceptorSymbol = buildSymbol {
148-
namespace = "aws.sdk.kotlin.services.s3.express"
149-
name = "S3ExpressCrc32ChecksumInterceptor"
150-
}
151-
152-
val httpChecksumTrait = op.getTrait<HttpChecksumTrait>()
153-
154-
val checksumAlgorithmMember = ctx.model.expectShape<StructureShape>(op.input.get())
155-
.members()
156-
.firstOrNull { it.memberName == httpChecksumTrait?.requestAlgorithmMember?.getOrNull() }
157-
158-
// S3 models a header name x-amz-sdk-checksum-algorithm representing the name of the checksum algorithm used
159-
val checksumHeaderName = checksumAlgorithmMember?.getTrait<HttpHeaderTrait>()?.value
160-
161-
writer.write("op.interceptors.add(#T(${checksumHeaderName?.dq() ?: ""}))", interceptorSymbol)
162-
}
163-
}
164-
165132
/**
166133
* Disable all checksums for s3:UploadPart
167134
*/
168135
private val uploadPartDisableChecksum = object : ProtocolMiddleware {
169136
override val name: String = "UploadPartDisableChecksum"
170137

171138
override fun isEnabledFor(ctx: ProtocolGenerator.GenerationContext, op: OperationShape): Boolean =
172-
op.isS3UploadPart
139+
op.isS3UploadPart && op.hasTrait<HttpChecksumTrait>()
173140

174141
override fun render(ctx: ProtocolGenerator.GenerationContext, op: OperationShape, writer: KotlinWriter) {
175142
val interceptorSymbol = buildSymbol {

services/s3/common/src/aws/sdk/kotlin/services/s3/express/S3ExpressDisableChecksumInterceptor.kt

+42-7
Original file line numberDiff line numberDiff line change
@@ -6,14 +6,18 @@ package aws.sdk.kotlin.services.s3.express
66

77
import aws.smithy.kotlin.runtime.client.ProtocolRequestInterceptorContext
88
import aws.smithy.kotlin.runtime.collections.AttributeKey
9+
import aws.smithy.kotlin.runtime.http.DeferredHeadersBuilder
10+
import aws.smithy.kotlin.runtime.http.HeadersBuilder
911
import aws.smithy.kotlin.runtime.http.interceptors.HttpInterceptor
10-
import aws.smithy.kotlin.runtime.http.operation.HttpOperationContext
1112
import aws.smithy.kotlin.runtime.http.request.HttpRequest
13+
import aws.smithy.kotlin.runtime.http.request.toBuilder
1214
import aws.smithy.kotlin.runtime.telemetry.logging.logger
1315
import kotlin.coroutines.coroutineContext
1416

17+
private const val CHECKSUM_HEADER_PREFIX = "x-amz-checksum-"
18+
1519
/**
16-
* Disable checksums entirely for s3:UploadPart requests.
20+
* Disables checksums for s3:UploadPart requests that use S3 express.
1721
*/
1822
internal class S3ExpressDisableChecksumInterceptor : HttpInterceptor {
1923
override suspend fun modifyBeforeSigning(context: ProtocolRequestInterceptorContext<Any, HttpRequest>): HttpRequest {
@@ -22,14 +26,45 @@ internal class S3ExpressDisableChecksumInterceptor : HttpInterceptor {
2226
}
2327

2428
val logger = coroutineContext.logger<S3ExpressDisableChecksumInterceptor>()
29+
logger.warn { "Checksums must not be sent with S3 express upload part operation, removing checksum(s)" }
30+
31+
val request = context.protocolRequest.toBuilder()
32+
33+
request.headers.removeChecksumHeaders()
34+
request.trailingHeaders.removeChecksumTrailingHeaders()
35+
request.headers.removeChecksumTrailingHeadersFromXAmzTrailer()
36+
37+
return request.build()
38+
}
39+
}
2540

26-
val configuredChecksumAlgorithm = context.executionContext.getOrNull(HttpOperationContext.ChecksumAlgorithm)
41+
/**
42+
* Removes any checksums sent in the request's headers
43+
*/
44+
internal fun HeadersBuilder.removeChecksumHeaders(): Unit =
45+
names().forEach { name ->
46+
if (name.startsWith(CHECKSUM_HEADER_PREFIX)) {
47+
remove(name)
48+
}
49+
}
2750

28-
configuredChecksumAlgorithm?.let {
29-
logger.warn { "Disabling configured checksum $it for S3 Express UploadPart" }
30-
context.executionContext.remove(HttpOperationContext.ChecksumAlgorithm)
51+
/**
52+
* Removes any checksums sent in the request's trailing headers
53+
*/
54+
internal fun DeferredHeadersBuilder.removeChecksumTrailingHeaders(): Unit =
55+
names().forEach { name ->
56+
if (name.startsWith(CHECKSUM_HEADER_PREFIX)) {
57+
remove(name)
3158
}
59+
}
3260

33-
return context.protocolRequest
61+
/**
62+
* Removes any checksums sent in the request's trailing headers from `x-amz-trailer`
63+
*/
64+
internal fun HeadersBuilder.removeChecksumTrailingHeadersFromXAmzTrailer() {
65+
this.getAll("x-amz-trailer")?.forEach { trailingHeader ->
66+
if (trailingHeader.startsWith(CHECKSUM_HEADER_PREFIX)) {
67+
this.remove("x-amz-trailer", trailingHeader)
68+
}
3469
}
3570
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,83 @@
1+
package aws.sdk.kotlin.services.s3.express
2+
3+
import aws.smithy.kotlin.runtime.http.DeferredHeadersBuilder
4+
import aws.smithy.kotlin.runtime.http.HeadersBuilder
5+
import kotlin.test.Test
6+
import kotlin.test.assertFalse
7+
import kotlin.test.assertTrue
8+
9+
class ChecksumRemovalTest {
10+
@Test
11+
fun removeChecksumHeaders() {
12+
val headers = HeadersBuilder()
13+
14+
headers.append("x-amz-checksum-crc32", "foo")
15+
headers.append("x-amz-checksum-sha256", "bar")
16+
17+
assertTrue(
18+
headers.contains("x-amz-checksum-crc32"),
19+
)
20+
assertTrue(
21+
headers.contains("x-amz-checksum-sha256"),
22+
)
23+
24+
headers.removeChecksumHeaders()
25+
26+
assertFalse(
27+
headers.contains("x-amz-checksum-crc32"),
28+
)
29+
assertFalse(
30+
headers.contains("x-amz-checksum-sha256"),
31+
)
32+
}
33+
34+
@Test
35+
fun removeChecksumTrailingHeaders() {
36+
val trailingHeaders = DeferredHeadersBuilder()
37+
38+
trailingHeaders.add("x-amz-checksum-crc32", "foo")
39+
trailingHeaders.add("x-amz-checksum-sha256", "bar")
40+
41+
assertTrue(
42+
trailingHeaders.contains("x-amz-checksum-crc32"),
43+
)
44+
assertTrue(
45+
trailingHeaders.contains("x-amz-checksum-sha256"),
46+
)
47+
48+
trailingHeaders.removeChecksumTrailingHeaders()
49+
50+
assertFalse(
51+
trailingHeaders.contains("x-amz-checksum-crc32"),
52+
)
53+
assertFalse(
54+
trailingHeaders.contains("x-amz-checksum-sha256"),
55+
)
56+
}
57+
58+
@Test
59+
fun removeChecksumTrailingHeadersFromXAmzTrailer() {
60+
val headers = HeadersBuilder()
61+
62+
headers.append("x-amz-trailer", "x-amz-checksum-crc32")
63+
headers.append("x-amz-trailer", "x-amz-trailing-header")
64+
65+
val xAmzTrailer = headers.getAll("x-amz-trailer")
66+
67+
assertTrue(
68+
xAmzTrailer?.contains("x-amz-checksum-crc32") ?: false,
69+
)
70+
assertTrue(
71+
xAmzTrailer?.contains("x-amz-trailing-header") ?: false,
72+
)
73+
74+
headers.removeChecksumTrailingHeadersFromXAmzTrailer()
75+
76+
assertFalse(
77+
xAmzTrailer?.contains("x-amz-checksum-crc32") ?: false,
78+
)
79+
assertTrue(
80+
xAmzTrailer?.contains("x-amz-trailing-header") ?: false,
81+
)
82+
}
83+
}

0 commit comments

Comments
 (0)