Skip to content

Commit a34c4f4

Browse files
committed
chore: dedicated gears scheduler
1 parent 84d00f3 commit a34c4f4

File tree

2 files changed

+92
-11
lines changed

2 files changed

+92
-11
lines changed

build.sbt

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,11 @@ lazy val root =
3131
)
3232
.jvmSettings(
3333
Seq(
34-
javaOptions += "--version 21"
34+
javaOptions += "--version 21",
35+
Test / javaOptions ++= Seq(
36+
"--add-opens=java.base/java.lang=ALL-UNNAMED",
37+
"--add-opens=java.base/jdk.internal.misc=ALL-UNNAMED"
38+
)
3539
)
3640
)
3741
.nativeSettings(

jvm/src/main/scala/async/VThreadSupport.scala

Lines changed: 87 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,94 @@
11
package gears.async
22

3-
import java.lang.invoke.{MethodHandles, VarHandle}
3+
import java.lang.invoke.{MethodHandles, MethodType}
4+
import java.util.concurrent.TimeUnit.SECONDS
45
import java.util.concurrent.locks.ReentrantLock
6+
import java.util.concurrent.{Executor, ForkJoinPool, ForkJoinWorkerThread, ThreadFactory}
57
import scala.annotation.unchecked.uncheckedVariance
68
import scala.concurrent.duration.FiniteDuration
9+
import scala.util.control.NonFatal
710

811
object VThreadScheduler extends Scheduler:
9-
private val VTFactory = Thread
10-
.ofVirtual()
11-
.name("gears.async.VThread-", 0L)
12-
.factory()
12+
private val LOOKUP: MethodHandles.Lookup = MethodHandles.lookup()
13+
14+
private object CarrieThreadFactory extends ForkJoinPool.ForkJoinWorkerThreadFactory {
15+
private val clazz = LOOKUP.findClass("jdk.internal.misc.CarrierThread")
16+
private val constructor = LOOKUP.findConstructor(clazz, MethodType.methodType(classOf[Unit], classOf[ForkJoinPool]))
17+
18+
override def newThread(pool: ForkJoinPool): ForkJoinWorkerThread = {
19+
constructor.invoke(pool).asInstanceOf[ForkJoinWorkerThread]
20+
}
21+
}
22+
23+
private val DEFAULT_SCHEDULER: ForkJoinPool = {
24+
val parallelismValue = sys.props
25+
.get("gears.default-scheduler.parallelism")
26+
.map(_.toInt)
27+
.getOrElse(Runtime.getRuntime.availableProcessors())
28+
29+
val maxPoolSizeValue = sys.props
30+
.get("gears.default-scheduler.max-pool-size")
31+
.map(_.toInt)
32+
.getOrElse(256)
33+
34+
val minRunnableValue = sys.props
35+
.get("gears.default-scheduler.min-runnable")
36+
.map(_.toInt)
37+
.getOrElse(parallelismValue / 2)
38+
39+
new ForkJoinPool(
40+
Runtime.getRuntime.availableProcessors(),
41+
CarrieThreadFactory,
42+
(t: Thread, e: Throwable) => {
43+
// noop for now
44+
},
45+
true,
46+
0,
47+
maxPoolSizeValue,
48+
minRunnableValue,
49+
(pool: ForkJoinPool) => true,
50+
60,
51+
SECONDS
52+
)
53+
}
54+
55+
private val VTFactory = createVirtualThreadFactory("gears", DEFAULT_SCHEDULER)
56+
57+
/** Create a virtual thread factory with an executor, the executor will be used as the scheduler of virtual thread.
58+
*
59+
* The executor should run task on platform threads.
60+
*
61+
* returns null if not supported.
62+
*/
63+
private def createVirtualThreadFactory(prefix: String, executor: Executor): ThreadFactory =
64+
try {
65+
val builderClass = ClassLoader.getSystemClassLoader.loadClass("java.lang.Thread$Builder")
66+
val ofVirtualClass = ClassLoader.getSystemClassLoader.loadClass("java.lang.Thread$Builder$OfVirtual")
67+
val ofVirtualMethod = LOOKUP.findStatic(classOf[Thread], "ofVirtual", MethodType.methodType(ofVirtualClass))
68+
var builder = ofVirtualMethod.invoke()
69+
if (executor != null) {
70+
val clazz = builder.getClass
71+
val privateLookup = MethodHandles.privateLookupIn(
72+
clazz,
73+
LOOKUP
74+
)
75+
val schedulerFieldSetter = privateLookup
76+
.findSetter(clazz, "scheduler", classOf[Executor])
77+
schedulerFieldSetter.invoke(builder, executor)
78+
}
79+
val nameMethod = LOOKUP.findVirtual(
80+
ofVirtualClass,
81+
"name",
82+
MethodType.methodType(ofVirtualClass, classOf[String], classOf[Long])
83+
)
84+
val factoryMethod = LOOKUP.findVirtual(builderClass, "factory", MethodType.methodType(classOf[ThreadFactory]))
85+
builder = nameMethod.invoke(builder, prefix + "-virtual-thread-", 0L)
86+
factoryMethod.invoke(builder).asInstanceOf[ThreadFactory]
87+
} catch {
88+
case NonFatal(e) =>
89+
// --add-opens java.base/java.lang=ALL-UNNAMED
90+
throw new UnsupportedOperationException("Failed to create virtual thread factory.", e)
91+
}
1392

1493
override def execute(body: Runnable): Unit =
1594
val th = VTFactory.newThread(body)
@@ -31,11 +110,9 @@ object VThreadScheduler extends Scheduler:
31110
}
32111

33112
private object ScheduledRunnable:
34-
val interruptGuardVar =
35-
MethodHandles
36-
.lookup()
37-
.in(classOf[ScheduledRunnable])
38-
.findVarHandle(classOf[ScheduledRunnable], "interruptGuard", classOf[Boolean])
113+
val interruptGuardVar = LOOKUP
114+
.in(classOf[ScheduledRunnable])
115+
.findVarHandle(classOf[ScheduledRunnable], "interruptGuard", classOf[Boolean])
39116

40117
object VThreadSupport extends AsyncSupport:
41118

0 commit comments

Comments
 (0)