fix server I think
keynmol committed Sep 23, 2024
1 parent a762850 commit 0dd3cf3
Showing 3 changed files with 100 additions and 50 deletions.
19 changes: 17 additions & 2 deletions Dockerfile
@@ -9,6 +9,21 @@ RUN wget https://raw.githubusercontent.com/VirtusLab/scala-cli/main/scala-cli.sh
 scala-cli version && \
 echo '@main def hello = println(42)' | scala-cli run _ --js -S 3.5.0
 
+WORKDIR /scratch
+
+COPY backend/project.scala .
+
+RUN scala-cli compile project.scala
+
+COPY frontend/src/project.scala .
+
+RUN scala-cli compile project.scala
+
+WORKDIR /source/frontend
+COPY frontend/package.json .
+COPY frontend/package-lock.json .
+RUN npm install
+
 WORKDIR /source
 
 COPY . .
@@ -21,9 +36,9 @@ RUN npm install
 RUN npm run build
 
 WORKDIR /source/backend
-RUN scala-cli package . --assembly -f -o ./optimizer-backend
+RUN scala-cli package . --assembly -f -o ./optimizer-backend --offline --server=false
 
-FROM eclipse-temurin:22
+FROM ghcr.io/graalvm/jdk-community:23
 
 COPY --from=build /source/backend/optimizer-backend /app/optimizer-backend
 
128 changes: 81 additions & 47 deletions backend/server.scala
@@ -12,6 +12,10 @@ import java.util.concurrent.TimeUnit
 import scala.concurrent.ExecutionContext
 import scala.concurrent.Future
 import scala.jdk.CollectionConverters.*
+import scala.util.control.NonFatal
+import scala.util.Try
+
+import scala.concurrent.ExecutionContext.global as GlobalEC
 
 object Limits:
   lazy val MaxFileLength =
@@ -38,41 +42,47 @@ object Optimizer extends cask.MainRoutes:
   def assets() = "assets"
 
   @cask.websocket("/ws/connect/:id")
-  def showUserProfile(id: String): cask.WebsocketResult =
-    val j = parseJobId(id)
-    cask.WsHandler { channel =>
+  def jobUpdates(id: String): cask.WebsocketResult =
+    val jobId = parseJobId(id)
+    val handler = cask.WsHandler { channel =>
       var sc: Option[ScheduledFuture[?]] = None
 
       val actor = cask.WsActor {
         case cask.Ws.Ping(dt) =>
           cask.Ws.Pong(dt)
 
         case cask.Ws.Text("ping") =>
-          heartbeatJob(j) match
+          heartbeatJob(jobId) match
             case None =>
-              sc.map(_.cancel(true))
+              sc.foreach(_.cancel(true))
               channel.send(cask.Ws.Close())
 
             case Some(value) if value.instruction == TrainingInstruction.Halt =>
-              sc.map(_.cancel(true))
+              sc.foreach(_.cancel(true))
               channel.send(cask.Ws.Close())
 
             case _ =>
               cask.Ws.Pong()
 
         case cask.Ws.Close(_, _) =>
-          haltJob(j, "websocket connection closed")
-          sc.map(_.cancel(true))
+          haltJob(jobId, "websocket connection closed")
+          sc.foreach(_.cancel(true))
 
       }
 
       sc = Some(
         scheduler.scheduleAtFixedRate(
-          run(sendJobUpdate(j, channel)),
-          0,
+          run(sendJobUpdate(jobId, channel, () => sc.foreach(_.cancel(true)))),
+          1,
           1,
           TimeUnit.SECONDS
         )
       )
 
       actor
     }
-  end showUserProfile
+
+    if jobs.containsKey(jobId) then handler else cask.Response("", 404)
+  end jobUpdates
 
   @cask.post("/api/halt/:id")
   def doThing(id: String) =
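The handler above keeps the `ScheduledFuture` in a local `var` so that both the WS actor and the scheduled task itself (through the `cancel` callback passed to `sendJobUpdate`) can stop the heartbeat. A minimal standalone sketch of that schedule-and-self-cancel pattern, using a plain `ScheduledExecutorService` and a hypothetical `jobStillRunning` flag in place of the real job lookup:

```scala
import java.util.concurrent.{Executors, ScheduledFuture, TimeUnit}

@main def heartbeatSketch(): Unit =
  val scheduler = Executors.newSingleThreadScheduledExecutor()
  var handle: Option[ScheduledFuture[?]] = None

  // The task receives a cancel hook so it can stop itself once the job is gone,
  // mirroring the `() => sc.foreach(_.cancel(true))` argument passed to sendJobUpdate.
  def tick(cancel: () => Unit): Runnable = () =>
    val jobStillRunning = false // hypothetical stand-in for the real job lookup
    if jobStillRunning then println("pretend to send an update")
    else
      println("job is gone, stopping the heartbeat")
      cancel()

  handle = Some(
    scheduler.scheduleAtFixedRate(
      tick(() => handle.foreach(_.cancel(true))),
      1, // initial delay, matching the change from 0 to 1 second in this commit
      1,
      TimeUnit.SECONDS
    )
  )

  Thread.sleep(3000)
  scheduler.shutdown()
```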
@@ -94,7 +104,7 @@
     else if attrs.generations > Limits.MaxGenerations then
       error(s"Number of generations above maximum [${Limits.MaxGenerations}]")
     else
-      val (id, job) = createJob(attrs)(using executionContext)
+      val (id, job) = createJob(attrs, GlobalEC)
       cask.Response(id.toString(), 200)
     end if
   end doThing
@@ -163,40 +173,64 @@ case class Result(
   generations: Int
 ) derives ReadWriter
 
-def sendJobUpdate(id: UUID, channel: cask.WsChannelActor) = () =>
-  Option(jobs.get(id))
-    .foreach: job =>
-      job.result.foreach: result =>
-        val serialised =
-          Conf.printHocon(result.config.toScalafmt(base))
-
-        val serialisedBase =
-          Conf.printHocon(base)
-
-        val codec = ScalafmtConfig.encoder
-
-        val diff = munit.diff
-          .Diff(job.attributes.file, result.formattedFile)
-          .unifiedDiff
-
-        val configDiff =
-          Conf.printHocon(
-            Conf.patch(
-              codec.write(base),
-              codec.write(result.config.toScalafmt(base))
-            )
-          )
-
-        val res = Result(
-          config = serialised,
-          formattedFile = result.formattedFile,
-          fileDiff = diff,
-          configDiff = configDiff,
-          generation = result.generation,
-          generations = job.attributes.generations
-        )
-
-        channel.send(cask.Ws.Text(upickle.default.write(res)))
+inline def absorb[T](msg: String, f: => T) =
+  try f
+  catch
+    case NonFatal(exc) =>
+      scribe.error(s"Failed to [$msg]", exc)
+
+def sendJobUpdate(id: UUID, channel: cask.WsChannelActor, cancel: () => Unit) =
+  () =>
+    Option(jobs.get(id)) match
+      case None =>
+        absorb(
+          s"closing WS connection for [$id]",
+          channel.send(cask.Ws.Close())
+        )
+        cancel()
+      case Some(job) if job.instruction == TrainingInstruction.Halt =>
+        absorb(
+          s"closing WS connection for [$id]",
+          channel.send(cask.Ws.Close())
+        )
+        cancel()
+      case Some(job) =>
+        channel.send(cask.Ws.Ping())
+        job.result.foreach: result =>
+          val serialised =
+            Conf.printHocon(result.config.toScalafmt(base))
+
+          val serialisedBase =
+            Conf.printHocon(base)
+
+          val codec = ScalafmtConfig.encoder
+
+          val diff = munit.diff
+            .Diff(job.attributes.file, result.formattedFile)
+            .unifiedDiff
+
+          val configDiff =
+            Conf.printHocon(
+              Conf.patch(
+                codec.write(base),
+                codec.write(result.config.toScalafmt(base))
+              )
+            )
+
+          val res = Result(
+            config = serialised,
+            formattedFile = result.formattedFile,
+            fileDiff = diff,
+            configDiff = configDiff,
+            generation = result.generation,
+            generations = job.attributes.generations
+          )
+
+          absorb(
+            s"sending update for [$id]",
+            channel.send(cask.Ws.Text(upickle.default.write(res)))
+          )
+    end match
 
 def haltJob(id: UUID, reason: String | Null = null) =
   Option(
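The new `absorb` helper exists so that a `channel.send` that fails on an already-closed socket is logged instead of killing the scheduled update task. A small usage sketch of that swallow-and-log idea, with `println` standing in for `scribe.error` so it runs without a logging dependency:

```scala
import scala.util.control.NonFatal

// Same shape as the helper added above, but logging with println instead of scribe,
// and deliberately discarding the result of `f` so the sketch is self-contained.
inline def absorb[T](msg: String, f: => T): Unit =
  try { f; () }
  catch
    case NonFatal(exc) =>
      println(s"Failed to [$msg]: ${exc.getMessage}")

@main def absorbSketch(): Unit =
  absorb("sending update for [demo-job]", sys.error("socket already closed"))
  println("the caller keeps going")
```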
@@ -220,7 +254,7 @@ case class TrainingResult(
   generation: Int
 )
 
-def createJob(attrs: JobAttributes)(using ExecutionContext): (UUID, Job) =
+def createJob(attrs: JobAttributes, ec: ExecutionContext): (UUID, Job) =
   val id = generateJobId()
 
   val job = Job(
@@ -277,15 +311,15 @@ def createJob(attrs: JobAttributes)(using ExecutionContext): (UUID, Job) =
     end handle
   end handler
 
-  Future:
+  ec.execute: () =>
     Train(
       featureful = summon[Featureful[ScalafmtConfigSubset]],
       config = trainingConfig,
      fitness = cachedFitness(trainingConfig.random, 500)(
         Fitness(fitness(attrs.file, _))
       ),
       events = handler,
-      evaluator = ParallelCollectionsEvaluator
+      evaluator = SequentialEvaluator
     ).train()
 
   id -> job
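With `createJob(attrs, ec)` the training run is now submitted via `ec.execute` on an explicitly passed `ExecutionContext` (the renamed global one) rather than wrapped in a `Future` built from an implicit. A minimal sketch of that fire-and-forget submission style, with a stand-in workload instead of `Train(...).train()`:

```scala
import scala.concurrent.ExecutionContext
import scala.concurrent.ExecutionContext.global as GlobalEC

// Hypothetical stand-in for the training run kicked off inside createJob.
def startWork(label: String, ec: ExecutionContext): Unit =
  ec.execute: () =>
    println(s"[$label] long-running work on ${Thread.currentThread().getName}")

@main def executeSketch(): Unit =
  startWork("demo-job", GlobalEC)
  Thread.sleep(500) // the global EC uses daemon threads, so give the task time to run
```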
3 changes: 2 additions & 1 deletion frontend/src/main.scala
@@ -2,6 +2,7 @@ import com.raquo.laminar.api.L.*
 import io.laminext.websocket.WebSocket
 
 import upickle.default.ReadWriter
+import com.raquo.airstream.core.Signal
 
 case class Result(
   config: String,
@@ -42,7 +43,7 @@ case class JobAttributes(
 
 val error = Var(Option.empty[String])
 
-val updateJob =
+val updateJob: Signal[Option[Result]] =
   id.signal.flatMapSwitch:
     case None => Signal.fromValue(None)
     case Some(id) =>
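The frontend change only pins down an inferred type: annotating `updateJob` as `Signal[Option[Result]]` makes explicit what `flatMapSwitch` produces. A compile-level sketch of the same shape outside the app, assuming Laminar/Airstream on the classpath as in frontend/src/main.scala, with a plain `String` standing in for the app's `Result`:

```scala
import com.raquo.laminar.api.L.*
import com.raquo.airstream.core.Signal

// Stand-in for the app's currently selected job id.
val id: Var[Option[String]] = Var(Option.empty[String])

// Explicit annotation, as in the commit: the switch yields Signal[Option[String]] here.
val updateJob: Signal[Option[String]] =
  id.signal.flatMapSwitch:
    case None     => Signal.fromValue(None)
    case Some(id) => Signal.fromValue(Some(s"result for job $id"))
```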
