Skip to content

Commit f044638

Browse files
committed
chore: dedicated gears scheduler
1 parent 84d00f3 commit f044638

File tree

2 files changed

+96
-11
lines changed

2 files changed

+96
-11
lines changed

build.sbt

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,11 @@ lazy val root =
3131
)
3232
.jvmSettings(
3333
Seq(
34-
javaOptions += "--version 21"
34+
javaOptions += "--version 21",
35+
Test / javaOptions ++= Seq(
36+
"--add-opens=java.base/java.lang=ALL-UNNAMED",
37+
"--add-opens=java.base/jdk.internal.misc=ALL-UNNAMED"
38+
)
3539
)
3640
)
3741
.nativeSettings(

jvm/src/main/scala/async/VThreadSupport.scala

Lines changed: 91 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,98 @@
11
package gears.async
22

3-
import java.lang.invoke.{MethodHandles, VarHandle}
3+
import java.lang.invoke.{MethodHandles, MethodType}
4+
import java.util.concurrent.TimeUnit.SECONDS
5+
import java.util.concurrent.atomic.AtomicLong
46
import java.util.concurrent.locks.ReentrantLock
7+
import java.util.concurrent.{Executor, ForkJoinPool, ForkJoinWorkerThread, ThreadFactory}
58
import scala.annotation.unchecked.uncheckedVariance
69
import scala.concurrent.duration.FiniteDuration
10+
import scala.util.control.NonFatal
711

812
object VThreadScheduler extends Scheduler:
9-
private val VTFactory = Thread
10-
.ofVirtual()
11-
.name("gears.async.VThread-", 0L)
12-
.factory()
13+
private val LOOKUP: MethodHandles.Lookup = MethodHandles.lookup()
14+
15+
private object CarrierThreadFactory extends ForkJoinPool.ForkJoinWorkerThreadFactory {
16+
private val clazz = LOOKUP.findClass("jdk.internal.misc.CarrierThread")
17+
private val constructor = LOOKUP.findConstructor(clazz, MethodType.methodType(classOf[Unit], classOf[ForkJoinPool]))
18+
private val counter = new AtomicLong(0L)
19+
20+
override def newThread(pool: ForkJoinPool): ForkJoinWorkerThread = {
21+
val t = constructor.invoke(pool).asInstanceOf[ForkJoinWorkerThread]
22+
t.setName("gears-CarrierThread-" + counter.getAndIncrement())
23+
t
24+
}
25+
}
26+
27+
private val DEFAULT_SCHEDULER: ForkJoinPool = {
28+
val parallelismValue = sys.props
29+
.get("gears.default-scheduler.parallelism")
30+
.map(_.toInt)
31+
.getOrElse(Runtime.getRuntime.availableProcessors())
32+
33+
val maxPoolSizeValue = sys.props
34+
.get("gears.default-scheduler.max-pool-size")
35+
.map(_.toInt)
36+
.getOrElse(256)
37+
38+
val minRunnableValue = sys.props
39+
.get("gears.default-scheduler.min-runnable")
40+
.map(_.toInt)
41+
.getOrElse(parallelismValue / 2)
42+
43+
new ForkJoinPool(
44+
Runtime.getRuntime.availableProcessors(),
45+
CarrierThreadFactory,
46+
(t: Thread, e: Throwable) => {
47+
// noop for now
48+
},
49+
true,
50+
0,
51+
maxPoolSizeValue,
52+
minRunnableValue,
53+
(pool: ForkJoinPool) => true,
54+
60,
55+
SECONDS
56+
)
57+
}
58+
59+
private val VTFactory = createVirtualThreadFactory("gears", DEFAULT_SCHEDULER)
60+
61+
/** Create a virtual thread factory with an executor, the executor will be used as the scheduler of virtual thread.
62+
*
63+
* The executor should run task on platform threads.
64+
*
65+
* returns null if not supported.
66+
*/
67+
private def createVirtualThreadFactory(prefix: String, executor: Executor): ThreadFactory =
68+
try {
69+
val builderClass = ClassLoader.getSystemClassLoader.loadClass("java.lang.Thread$Builder")
70+
val ofVirtualClass = ClassLoader.getSystemClassLoader.loadClass("java.lang.Thread$Builder$OfVirtual")
71+
val ofVirtualMethod = LOOKUP.findStatic(classOf[Thread], "ofVirtual", MethodType.methodType(ofVirtualClass))
72+
var builder = ofVirtualMethod.invoke()
73+
if (executor != null) {
74+
val clazz = builder.getClass
75+
val privateLookup = MethodHandles.privateLookupIn(
76+
clazz,
77+
LOOKUP
78+
)
79+
val schedulerFieldSetter = privateLookup
80+
.findSetter(clazz, "scheduler", classOf[Executor])
81+
schedulerFieldSetter.invoke(builder, executor)
82+
}
83+
val nameMethod = LOOKUP.findVirtual(
84+
ofVirtualClass,
85+
"name",
86+
MethodType.methodType(ofVirtualClass, classOf[String], classOf[Long])
87+
)
88+
val factoryMethod = LOOKUP.findVirtual(builderClass, "factory", MethodType.methodType(classOf[ThreadFactory]))
89+
builder = nameMethod.invoke(builder, prefix + "-virtual-thread-", 0L)
90+
factoryMethod.invoke(builder).asInstanceOf[ThreadFactory]
91+
} catch {
92+
case NonFatal(e) =>
93+
// --add-opens java.base/java.lang=ALL-UNNAMED
94+
throw new UnsupportedOperationException("Failed to create virtual thread factory.", e)
95+
}
1396

1497
override def execute(body: Runnable): Unit =
1598
val th = VTFactory.newThread(body)
@@ -31,11 +114,9 @@ object VThreadScheduler extends Scheduler:
31114
}
32115

33116
private object ScheduledRunnable:
34-
val interruptGuardVar =
35-
MethodHandles
36-
.lookup()
37-
.in(classOf[ScheduledRunnable])
38-
.findVarHandle(classOf[ScheduledRunnable], "interruptGuard", classOf[Boolean])
117+
val interruptGuardVar = LOOKUP
118+
.in(classOf[ScheduledRunnable])
119+
.findVarHandle(classOf[ScheduledRunnable], "interruptGuard", classOf[Boolean])
39120

40121
object VThreadSupport extends AsyncSupport:
41122

0 commit comments

Comments
 (0)