Skip to content

Commit 62e4d9d

Browse files
authored
Fix leaks in TLSWebViewClient (#5445)
* Introduce SafeKeyChainAliasCallback * Add missing request ignore when activity is null
1 parent 6b398a4 commit 62e4d9d

File tree

1 file changed

+65
-38
lines changed

1 file changed

+65
-38
lines changed

app/src/main/kotlin/io/homeassistant/companion/android/util/TLSWebViewClient.kt

Lines changed: 65 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -10,19 +10,18 @@ import android.webkit.ClientCertRequest
1010
import android.webkit.WebView
1111
import android.webkit.WebViewClient
1212
import androidx.annotation.RequiresApi
13-
import androidx.appcompat.app.AppCompatActivity
14-
import androidx.lifecycle.lifecycleScope
1513
import io.homeassistant.companion.android.common.data.keychain.KeyChainRepository
16-
import java.security.Principal
14+
import java.lang.ref.WeakReference
1715
import java.security.PrivateKey
1816
import java.security.cert.CertificateException
1917
import java.security.cert.X509Certificate
20-
import javax.inject.Inject
21-
import javax.inject.Named
18+
import kotlinx.coroutines.CoroutineScope
19+
import kotlinx.coroutines.Dispatchers
20+
import kotlinx.coroutines.Job
2221
import kotlinx.coroutines.launch
22+
import timber.log.Timber
2323

24-
open class TLSWebViewClient @Inject constructor(@Named("keyChainRepository") private var keyChainRepository: KeyChainRepository) : WebViewClient() {
25-
24+
open class TLSWebViewClient(private var keyChainRepository: KeyChainRepository) : WebViewClient() {
2625
var isTLSClientAuthNeeded = false
2726
private set
2827

@@ -39,17 +38,14 @@ open class TLSWebViewClient @Inject constructor(@Named("keyChainRepository") pri
3938
if (context == null) {
4039
return null
4140
} else if (context is ContextWrapper) {
42-
return if (context is Activity) {
43-
context
44-
} else {
45-
getActivity(context.baseContext)
46-
}
41+
return context as? Activity ?: getActivity(context.baseContext)
4742
}
4843
return null
4944
}
5045

5146
@RequiresApi(Build.VERSION_CODES.M)
5247
override fun onReceivedClientCertRequest(view: WebView, request: ClientCertRequest) {
48+
Timber.d("onReceivedClientCertRequest invoked looking for cert in local storage or ask the user for it")
5349
// Let the WebViewActivity know the endpoint requires TLS Client Auth
5450
isTLSClientAuthNeeded = true
5551

@@ -75,42 +71,37 @@ open class TLSWebViewClient @Inject constructor(@Named("keyChainRepository") pri
7571
// If no key is available, then the user must be prompt for a key
7672
// The whole operation is wrapped in the selectPrivateKey method but caution as it must occurs outside of the main thread
7773
// see: https://developer.android.com/reference/android/security/KeyChain#getPrivateKey(android.content.Context,%20java.lang.String)
78-
selectClientCert(activity, request.principals, request)
74+
selectClientCert(activity, request)
7975
}
8076
}
77+
} else {
78+
request.ignore()
8179
}
8280
}
8381
}
8482

8583
@RequiresApi(Build.VERSION_CODES.M)
86-
private fun selectClientCert(activity: Activity, principals: Array<Principal>?, request: ClientCertRequest) {
87-
require(activity is AppCompatActivity)
88-
89-
val kcac = KeyChainAliasCallback { alias ->
90-
if (alias != null) {
91-
activity.lifecycleScope.launch {
92-
// Load the key and the chain
93-
keyChainRepository.load(activity.applicationContext, alias)
94-
95-
key = keyChainRepository.getPrivateKey()
96-
chain = keyChainRepository.getCertificateChain()
97-
98-
// If we got the key and the cert
99-
if (key == null || chain == null) {
100-
// Either the user didn't choose a key or no key was available
101-
hasUserDeniedAccess = true
102-
}
103-
84+
private fun selectClientCert(activity: Activity, request: ClientCertRequest) {
85+
// prompt the user for a key
86+
KeyChain.choosePrivateKeyAlias(
87+
activity,
88+
SafeKeyChainAliasCallback(keyChainRepository, activity.applicationContext) { key, chain ->
89+
if (key == null || chain == null) {
90+
hasUserDeniedAccess = true
91+
request.ignore()
92+
} else {
10493
checkChainValidity()
94+
this.key = key
95+
this.chain = chain
10596
request.proceed(key, chain)
10697
}
107-
} else {
108-
request.proceed(key, chain)
109-
}
110-
}
111-
112-
// prompt the user for a key
113-
KeyChain.choosePrivateKeyAlias(activity, kcac, arrayOf<String>(), principals, null, null)
98+
},
99+
request.keyTypes,
100+
request.principals,
101+
request.host,
102+
request.port,
103+
null,
104+
)
114105
}
115106

116107
private fun checkChainValidity() {
@@ -125,3 +116,39 @@ open class TLSWebViewClient @Inject constructor(@Named("keyChainRepository") pri
125116
}
126117
}
127118
}
119+
120+
/**
121+
* Addresses a potential memory leak with [KeyChain.choosePrivateKeyAlias].
122+
*
123+
* [KeyChain.choosePrivateKeyAlias] holds a strong reference to its callback even after
124+
* invocation. To prevent this callback from leaking its capturing context (e.g., a WebView),
125+
* this wrapper stores the actual result consumer ([onResult]) in a [WeakReference].
126+
*
127+
* If the consumer (e.g., WebView) is destroyed before the user selects a key,
128+
* the [WeakReference] will allow it to be garbage collected, and the result will not be
129+
* delivered to the (now-gone) consumer. The user's selection is intended to be
130+
* handled independently by [KeyChainRepository.load] within its coroutine scope,
131+
* ensuring the choice is persisted even if the initial UI component is gone.
132+
*/
133+
private class SafeKeyChainAliasCallback(
134+
private var keyChainRepository: KeyChainRepository,
135+
context: Context,
136+
onResult: (key: PrivateKey?, chain: Array<X509Certificate>?) -> Unit,
137+
) : KeyChainAliasCallback {
138+
private val ioScope: CoroutineScope = CoroutineScope(Dispatchers.IO + Job())
139+
private val context = context.applicationContext
140+
private val onResult = WeakReference(onResult)
141+
142+
override fun alias(alias: String?) {
143+
if (alias != null) {
144+
ioScope.launch {
145+
keyChainRepository.load(context, alias)
146+
val key = keyChainRepository.getPrivateKey()
147+
val chain = keyChainRepository.getCertificateChain()
148+
onResult.get()?.invoke(key, chain)
149+
}
150+
} else {
151+
onResult.get()?.invoke(null, null)
152+
}
153+
}
154+
}

0 commit comments

Comments
 (0)