@@ -12,6 +12,10 @@ import java.util.concurrent.TimeUnit
12
12
import scala .concurrent .ExecutionContext
13
13
import scala .concurrent .Future
14
14
import scala .jdk .CollectionConverters .*
15
+ import scala .util .control .NonFatal
16
+ import scala .util .Try
17
+
18
+ import scala .concurrent .ExecutionContext .global as GlobalEC
15
19
16
20
object Limits :
17
21
lazy val MaxFileLength =
@@ -38,41 +42,47 @@ object Optimizer extends cask.MainRoutes:
38
42
def assets () = " assets"
39
43
40
44
@ cask.websocket(" /ws/connect/:id" )
41
- def showUserProfile (id : String ): cask.WebsocketResult =
42
- val j = parseJobId(id)
43
- cask.WsHandler { channel =>
45
+ def jobUpdates (id : String ): cask.WebsocketResult =
46
+ val jobId = parseJobId(id)
47
+ val handler = cask.WsHandler { channel =>
44
48
var sc : Option [ScheduledFuture [? ]] = None
45
49
46
50
val actor = cask.WsActor {
51
+ case cask.Ws .Ping (dt) =>
52
+ cask.Ws .Pong (dt)
47
53
48
54
case cask.Ws .Text (" ping" ) =>
49
- heartbeatJob(j ) match
55
+ heartbeatJob(jobId ) match
50
56
case None =>
51
- sc.map (_.cancel(true ))
57
+ sc.foreach (_.cancel(true ))
52
58
channel.send(cask.Ws .Close ())
53
59
54
60
case Some (value) if value.instruction == TrainingInstruction .Halt =>
55
- sc.map (_.cancel(true ))
61
+ sc.foreach (_.cancel(true ))
56
62
channel.send(cask.Ws .Close ())
57
-
58
63
case _ =>
64
+ cask.Ws .Pong ()
65
+
59
66
case cask.Ws .Close (_, _) =>
60
- haltJob(j, " websocket connection closed" )
61
- sc.map(_.cancel(true ))
67
+ haltJob(jobId, " websocket connection closed" )
68
+ sc.foreach(_.cancel(true ))
69
+
62
70
}
63
71
64
72
sc = Some (
65
73
scheduler.scheduleAtFixedRate(
66
- run(sendJobUpdate(j , channel)),
67
- 0 ,
74
+ run(sendJobUpdate(jobId , channel, () => sc.foreach(_.cancel( true )) )),
75
+ 1 ,
68
76
1 ,
69
77
TimeUnit .SECONDS
70
78
)
71
79
)
72
80
73
81
actor
74
82
}
75
- end showUserProfile
83
+
84
+ if jobs.containsKey(jobId) then handler else cask.Response (" " , 404 )
85
+ end jobUpdates
76
86
77
87
@ cask.post(" /api/halt/:id" )
78
88
def doThing (id : String ) =
@@ -94,7 +104,7 @@ object Optimizer extends cask.MainRoutes:
94
104
else if attrs.generations > Limits .MaxGenerations then
95
105
error(s " Number of generations above maximum [ ${Limits .MaxGenerations }] " )
96
106
else
97
- val (id, job) = createJob(attrs)( using executionContext )
107
+ val (id, job) = createJob(attrs, GlobalEC )
98
108
cask.Response (id.toString(), 200 )
99
109
end if
100
110
end doThing
@@ -163,40 +173,64 @@ case class Result(
163
173
generations : Int
164
174
) derives ReadWriter
165
175
166
- def sendJobUpdate (id : UUID , channel : cask.WsChannelActor ) = () =>
167
- Option (jobs.get(id))
168
- .foreach: job =>
169
- job.result.foreach: result =>
170
- val serialised =
171
- Conf .printHocon(result.config.toScalafmt(base))
172
-
173
- val serialisedBase =
174
- Conf .printHocon(base)
175
-
176
- val codec = ScalafmtConfig .encoder
177
-
178
- val diff = munit.diff
179
- .Diff (job.attributes.file, result.formattedFile)
180
- .unifiedDiff
181
-
182
- val configDiff =
183
- Conf .printHocon(
184
- Conf .patch(
185
- codec.write(base),
186
- codec.write(result.config.toScalafmt(base))
176
+ inline def absorb [T ](msg : String , f : => T ) =
177
+ try f
178
+ catch
179
+ case NonFatal (exc) =>
180
+ scribe.error(s " Failed to [ $msg] " , exc)
181
+
182
+ def sendJobUpdate (id : UUID , channel : cask.WsChannelActor , cancel : () => Unit ) =
183
+ () =>
184
+ Option (jobs.get(id)) match
185
+ case None =>
186
+ absorb(
187
+ s " closing WS connection for [ $id] " ,
188
+ channel.send(cask.Ws .Close ())
189
+ )
190
+ cancel()
191
+ case Some (job) if job.instruction == TrainingInstruction .Halt =>
192
+ absorb(
193
+ s " closing WS connection for [ $id] " ,
194
+ channel.send(cask.Ws .Close ())
195
+ )
196
+ cancel()
197
+ case Some (job) =>
198
+ channel.send(cask.Ws .Ping ())
199
+ job.result.foreach: result =>
200
+ val serialised =
201
+ Conf .printHocon(result.config.toScalafmt(base))
202
+
203
+ val serialisedBase =
204
+ Conf .printHocon(base)
205
+
206
+ val codec = ScalafmtConfig .encoder
207
+
208
+ val diff = munit.diff
209
+ .Diff (job.attributes.file, result.formattedFile)
210
+ .unifiedDiff
211
+
212
+ val configDiff =
213
+ Conf .printHocon(
214
+ Conf .patch(
215
+ codec.write(base),
216
+ codec.write(result.config.toScalafmt(base))
217
+ )
187
218
)
188
- )
189
219
190
- val res = Result (
191
- config = serialised,
192
- formattedFile = result.formattedFile,
193
- fileDiff = diff,
194
- configDiff = configDiff,
195
- generation = result.generation,
196
- generations = job.attributes.generations
197
- )
220
+ val res = Result (
221
+ config = serialised,
222
+ formattedFile = result.formattedFile,
223
+ fileDiff = diff,
224
+ configDiff = configDiff,
225
+ generation = result.generation,
226
+ generations = job.attributes.generations
227
+ )
198
228
199
- channel.send(cask.Ws .Text (upickle.default.write(res)))
229
+ absorb(
230
+ s " sending update for [ $id] " ,
231
+ channel.send(cask.Ws .Text (upickle.default.write(res)))
232
+ )
233
+ end match
200
234
201
235
def haltJob (id : UUID , reason : String | Null = null ) =
202
236
Option (
@@ -220,7 +254,7 @@ case class TrainingResult(
220
254
generation : Int
221
255
)
222
256
223
- def createJob (attrs : JobAttributes )( using ExecutionContext ): (UUID , Job ) =
257
+ def createJob (attrs : JobAttributes , ec : ExecutionContext ): (UUID , Job ) =
224
258
val id = generateJobId()
225
259
226
260
val job = Job (
@@ -277,15 +311,15 @@ def createJob(attrs: JobAttributes)(using ExecutionContext): (UUID, Job) =
277
311
end handle
278
312
end handler
279
313
280
- Future :
314
+ ec.execute : () =>
281
315
Train (
282
316
featureful = summon[Featureful [ScalafmtConfigSubset ]],
283
317
config = trainingConfig,
284
318
fitness = cachedFitness(trainingConfig.random, 500 )(
285
319
Fitness (fitness(attrs.file, _))
286
320
),
287
321
events = handler,
288
- evaluator = ParallelCollectionsEvaluator
322
+ evaluator = SequentialEvaluator
289
323
).train()
290
324
291
325
id -> job
0 commit comments