Skip to content

Commit

Permalink
KTOR-7470 Rethrow UnsupportedMediaTypeException if content type heade…
Browse files Browse the repository at this point in the history
…r is not present or is invalid
  • Loading branch information
Gleb Nazarov authored and Gleb Nazarov committed Sep 25, 2024
1 parent 68d447c commit cc81b68
Show file tree
Hide file tree
Showing 3 changed files with 75 additions and 11 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -85,9 +85,11 @@ public class CannotTransformContentToTypeException(
*/
@OptIn(ExperimentalCoroutinesApi::class)
public class UnsupportedMediaTypeException(
private val contentType: ContentType
) : ContentTransformationException("Content type $contentType is not supported"),
CopyableThrowable<UnsupportedMediaTypeException> {
private val contentType: ContentType?
) : ContentTransformationException(
contentType?.let { "Content type $it is not supported" }
?: "Content-Type header is required for multipart processing"
), CopyableThrowable<UnsupportedMediaTypeException> {

override fun createCopy(): UnsupportedMediaTypeException = UnsupportedMediaTypeException(contentType).also {
it.initCauseBridge(this)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import io.ktor.http.*
import io.ktor.http.cio.*
import io.ktor.http.content.*
import io.ktor.server.application.*
import io.ktor.server.plugins.UnsupportedMediaTypeException
import io.ktor.server.request.*
import io.ktor.util.pipeline.*
import io.ktor.utils.io.*
Expand All @@ -17,6 +18,7 @@ import io.ktor.utils.io.streams.*
import kotlinx.coroutines.*
import kotlinx.io.*
import java.io.*
import java.io.IOException

internal actual suspend fun PipelineContext<Any, PipelineCall>.defaultPlatformTransformations(
query: Any
Expand All @@ -33,16 +35,21 @@ internal actual suspend fun PipelineContext<Any, PipelineCall>.defaultPlatformTr
@OptIn(InternalAPI::class)
internal actual fun PipelineContext<*, PipelineCall>.multiPartData(rc: ByteReadChannel): MultiPartData {
val contentType = call.request.header(HttpHeaders.ContentType)
?: throw IllegalStateException("Content-Type header is required for multipart processing")
?: throw UnsupportedMediaTypeException(null)

val contentLength = call.request.header(HttpHeaders.ContentLength)?.toLong()
return CIOMultipartDataBase(
coroutineContext + Dispatchers.Unconfined,
rc,
contentType,
contentLength,
call.formFieldLimit
)

try {
return CIOMultipartDataBase(
coroutineContext + Dispatchers.Unconfined,
rc,
contentType,
contentLength,
call.formFieldLimit
)
} catch (_: IOException) {
throw UnsupportedMediaTypeException(ContentType.parse(contentType))
}
}

internal actual fun Source.readTextWithCustomCharset(charset: Charset): String =
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
/*
* Copyright 2014-2024 JetBrains s.r.o and contributors. Use of this source code is governed by the Apache 2.0 license.
*/

package io.ktor.tests.server.engine

import io.ktor.http.*
import io.ktor.server.application.*
import io.ktor.server.engine.multiPartData
import io.ktor.server.plugins.UnsupportedMediaTypeException
import io.ktor.server.request.*
import io.ktor.util.pipeline.*
import io.ktor.utils.io.*
import io.mockk.*
import kotlinx.coroutines.*
import kotlinx.coroutines.test.TestScope
import kotlinx.coroutines.test.runTest
import kotlin.test.*

class MultiPartDataTest {
private val mockContext = mockk<PipelineContext<*, PipelineCall>>(relaxed = true)
private val mockRequest = mockk<PipelineRequest>(relaxed = true)
private val testScope = TestScope()

@Test
fun givenRequest_whenNoContentTypeHeaderPresent_thenUnsupportedMediaTypeException() {
// Setup
every { mockContext.call.request } returns mockRequest
every { mockRequest.header(HttpHeaders.ContentType) } returns null

// Act & Assert
assertFailsWith<UnsupportedMediaTypeException> {
runBlocking { mockContext.multiPartData(ByteReadChannel("sample data")) }
}
}

@Test
fun givenWrongContentType_whenProcessMultiPart_thenUnsupportedMediaTypeException() {
// Given
val rc = ByteReadChannel("sample data")
val contentType = "test/plain; boundary=test"
val contentLength = "123"
every { mockContext.call.request } returns mockRequest
every { mockContext.call.attributes.getOrNull<Long>(any()) } returns 0L
every { mockRequest.header(HttpHeaders.ContentType) } returns contentType
every { mockRequest.header(HttpHeaders.ContentLength) } returns contentLength

// When & Then
testScope.runTest {
assertFailsWith<UnsupportedMediaTypeException> {
mockContext.multiPartData(rc)
}
}
}
}

0 comments on commit cc81b68

Please sign in to comment.