[SPARK-51974][CONNECT][ML] Limit model size and per-session model cache size #50751
@@ -61,6 +61,31 @@ private[connect] class MLCache(sessionHolder: SessionHolder) extends Logging {

```scala
  // Cumulative size of every object registered in this session; never decremented.
  private[ml] val totalSizeBytes: AtomicLong = new AtomicLong(0)

  // Current total size of the cached models; decremented when a model is removed.
  private[ml] val totalModelCacheSizeBytes: AtomicLong = new AtomicLong(0)

  private[spark] def getModelCacheMaxSize: Long = {
    sessionHolder.session.conf.get(Connect.CONNECT_SESSION_CONNECT_MODEL_CACHE_MAX_SIZE)
  }

  private[spark] def getModelMaxSize: Long = {
    sessionHolder.session.conf.get(Connect.CONNECT_SESSION_CONNECT_MODEL_MAX_SIZE)
  }

  def checkModelSize(estimatedModelSize: Long): Unit = {
    if (totalModelCacheSizeBytes.get() + estimatedModelSize > getModelCacheMaxSize) {
      throw MLModelCacheSizeOverflowException(
        "The model cache size in the current session is about to exceed " +
          f"$getModelCacheMaxSize bytes. " +
          "Please delete existing cached models by executing 'del model' in the Python " +
          "client before fitting or loading a new model, or increase the Spark config " +
          "'spark.connect.session.connectML.modelCache.maxSize'.")
    }
    if (estimatedModelSize > getModelMaxSize) {
      throw MLModelSizeOverflowException(
        f"The fitted or loaded model size exceeds $getModelMaxSize bytes. " +
          f"Please fit or load a model smaller than $getModelMaxSize bytes, " +
          "or increase the Spark config 'spark.connect.session.connectML.model.maxSize'.")
    }
  }

  private def estimateObjectSize(obj: Object): Long = {
    obj match {
      case model: Model[_] =>
```
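The two limits come from per-session configs. For orientation, here is a minimal sketch of how such entries could be declared with Spark's internal `ConfigBuilder`; the keys match the error messages above, but the defaults, doc strings, and version are assumptions, not values taken from this PR:

```scala
import org.apache.spark.internal.config.ConfigBuilder
import org.apache.spark.network.util.ByteUnit

// Sketch only: keys mirror the error messages above; defaults, doc text,
// and version strings are illustrative assumptions.
private[spark] val CONNECT_SESSION_CONNECT_MODEL_MAX_SIZE =
  ConfigBuilder("spark.connect.session.connectML.model.maxSize")
    .doc("Maximum size in bytes of a single fitted or loaded model.")
    .version("4.1.0")
    .bytesConf(ByteUnit.BYTE)
    .createWithDefault(ByteUnit.GiB.toBytes(2))  // assumed default

private[spark] val CONNECT_SESSION_CONNECT_MODEL_CACHE_MAX_SIZE =
  ConfigBuilder("spark.connect.session.connectML.modelCache.maxSize")
    .doc("Maximum total size in bytes of the per-session model cache.")
    .version("4.1.0")
    .bytesConf(ByteUnit.BYTE)
    .createWithDefault(ByteUnit.GiB.toBytes(8))  // assumed default
```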
@@ -81,7 +106,11 @@ private[connect] class MLCache(sessionHolder: SessionHolder) extends Logging {

```scala
  def register(obj: Object): String = {
    val objectId = UUID.randomUUID().toString
    val sizeBytes = estimateObjectSize(obj)
    // Fail fast before admitting the object into the cache.
    checkModelSize(sizeBytes)

    totalSizeBytes.addAndGet(sizeBytes)
    totalModelCacheSizeBytes.addAndGet(sizeBytes)

    cachedModel.put(objectId, CacheItem(obj, sizeBytes))
    objectId
  }
```
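One design note on `register`: `checkModelSize` reads the counter and `addAndGet` updates it in two separate steps, so two concurrent `register` calls can both pass the check and overshoot the cap slightly. If strict enforcement were required, the reservation could be made atomic with a CAS loop; a hypothetical sketch (`tryReserve` is not part of this PR):

```scala
import java.util.concurrent.atomic.AtomicLong

// Hypothetical helper: atomically reserve `delta` bytes against `total`,
// failing if the reservation would exceed `max`.
@annotation.tailrec
def tryReserve(total: AtomicLong, delta: Long, max: Long): Boolean = {
  val current = total.get()
  if (current + delta > max) false
  else if (total.compareAndSet(current, current + delta)) true
  else tryReserve(total, delta, max)  // lost a race; retry with a fresh read
}
```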
@@ -108,6 +137,7 @@ private[connect] class MLCache(sessionHolder: SessionHolder) extends Logging {

```scala
   */
  def remove(refId: String): Boolean = {
    val removed = cachedModel.remove(refId)
    // remove returns null if the key is not present, so only subtract the
    // entry's size when something was actually removed.
    if (removed != null) {
      totalModelCacheSizeBytes.addAndGet(-removed.sizeBytes)
    }
    removed != null
  }
```
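To see the whole accounting cycle in one place, here is a self-contained analogue of the pattern above; the class and its names are illustrative, not code from this PR:

```scala
import java.util.UUID
import java.util.concurrent.ConcurrentHashMap
import java.util.concurrent.atomic.AtomicLong

// Hypothetical stand-alone analogue of the MLCache size accounting.
final class BoundedObjectCache(maxTotalBytes: Long) {
  private final case class Item(obj: AnyRef, sizeBytes: Long)
  private val items = new ConcurrentHashMap[String, Item]()
  private val totalBytes = new AtomicLong(0)

  def register(obj: AnyRef, sizeBytes: Long): String = {
    // Mirror of checkModelSize: reject before admitting the object.
    if (totalBytes.get() + sizeBytes > maxTotalBytes) {
      throw new IllegalStateException(
        s"cache would exceed $maxTotalBytes bytes; remove entries first")
    }
    totalBytes.addAndGet(sizeBytes)
    val id = UUID.randomUUID().toString
    items.put(id, Item(obj, sizeBytes))
    id
  }

  def remove(id: String): Boolean = {
    val removed = items.remove(id)  // null when the key is absent
    if (removed != null) totalBytes.addAndGet(-removed.sizeBytes)
    removed != null
  }
}
```

Registering past the cap throws, and a successful `remove` returns the entry's bytes to the budget, which is the invariant the `MLCache` change maintains.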