Commit: WIP

keynmol committed Sep 25, 2024
1 parent 97a8e78 commit aed7617
Showing 3 changed files with 340 additions and 3 deletions.
6 changes: 5 additions & 1 deletion backend/project.scala
@@ -8,7 +8,11 @@
//> using dependency "com.outr::scribe::3.15.0"
//> using dependency "org.scalameta::munit-diff::1.0.2"
//> using dependency "org.scalameta:scalafmt-core_2.13:3.8.3"
//> using dependency org.http4s::http4s-ember-server::0.23.28
//> using dependency org.http4s::http4s-dsl::0.23.28
//> using dep org.http4s::http4s-circe::0.23.28
//> using dep io.circe::circe-core::0.14.10
//> using dep com.outr::scribe-cats::3.15.0
//> using file "../shared/protocol.scala"

//> using resourceDirs "../frontend/dist"

4 changes: 2 additions & 2 deletions backend/server.scala → backend/server.cask.scala
@@ -17,7 +17,7 @@ import scala.util.Try

import scala.concurrent.ExecutionContext.global as GlobalEC

object Optimizer extends cask.MainRoutes:
object OptimizerServerCask extends cask.MainRoutes:
override def port = 9999
override def host: String = "0.0.0.0"

@@ -136,7 +136,7 @@ object Optimizer extends cask.MainRoutes:

initialize()

end Optimizer
end OptimizerServerCask

def parseJobId(id: String) =
UUID.fromString(id)
333 changes: 333 additions & 0 deletions backend/server.http4s.scala
@@ -0,0 +1,333 @@
import _root_.io.circe.*
import cats.data.Kleisli
import cats.effect.*
import cats.effect.std.Dispatcher
import cats.effect.std.Queue
import cats.effect.std.Supervisor
import com.comcast.ip4s.*
import fs2.concurrent.Channel
import fs2.concurrent.Signal
import fs2.concurrent.Topic
import genovese.TrainingInstruction
import metaconfig.Conf
import org.http4s.HttpApp
import org.http4s.HttpRoutes
import org.http4s.dsl.io.*
import org.http4s.ember.server.EmberServerBuilder
import org.http4s.server.websocket.WebSocketBuilder2
import org.http4s.websocket.WebSocketFrame
import org.scalafmt.config.ScalafmtConfig
import scribe.Scribe

import java.util.UUID

import std.MapRef
import concurrent.duration.*

object OptimizerServerHttp4s extends IOApp.Simple:

import org.http4s.implicits.*

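  // Boot an Ember server on 0.0.0.0:9999, wiring in the WebSocket-enabled routes
  // and logging any request that fails with an exception.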
override def run: IO[Unit] =
JobManager.create
.map(routes(_))
.flatMap: routes =>
EmberServerBuilder
.default[IO]
.withPort(port"9999")
.withHost(host"0.0.0.0")
.withHttpWebSocketApp(wbs =>
handleErrors(scribe.cats.io, routes(wbs).orNotFound)
)
.build
.useForever

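  // Wrap the HttpApp so that any exception raised while handling a request is
  // logged together with the offending request.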
def handleErrors(logger: Scribe[IO], routes: HttpApp[IO]): HttpApp[IO] =
import cats.syntax.all.*
routes.onError { exc =>
Kleisli(request => logger.error("Request failed", request.toString, exc))
}

end OptimizerServerHttp4s

given Codec[JobAttributes] = Codec.derived[JobAttributes]
given Codec[JobProgress] = Codec.derived[JobProgress]

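// Routes: GET /api/stats lists jobs that are still running, GET /ws/connect/<id>
// streams a job's progress over a WebSocket, and POST /api/create starts a new job.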
def routes(
manager: JobManager
)(wbs: WebSocketBuilder2[IO]): HttpRoutes[IO] =
HttpRoutes.of[IO]:
case GET -> Root / "api" / "stats" =>
import org.http4s.circe.*
manager.get.flatMap: mp =>
val active = mp.filter(_._2.instruction != TrainingInstruction.Halt)
Ok(
Json.fromValues(
active.map((id, j) =>
Json.obj(
"id" -> Json.fromString(id.toString),
"heartbeat" -> Json.fromString(j.heartbeat.toString)
)
)
)
)

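    // WebSocket endpoint: push the job's progress as JSON roughly once per second;
    // incoming "ping" frames refresh the heartbeat, and a Close frame (or a halted
    // job) ends the updates.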
case GET -> Root / "ws" / "connect" / id =>
val uuid = UUID.fromString(id)
manager.get
.map(_.get(uuid))
.flatMap:
case None => NotFound(s"Job $id not found")
case Some(job) =>
IO.deferred[Either[Throwable, Unit]]
.product(Queue.bounded[IO, WebSocketFrame](1024))
.flatMap: (latch, q) =>
def sendJobUpdate =
import io.circe.syntax.*
def sendProgress(progress: JobProgress) =
q.offer(
WebSocketFrame
.Text(progress.asJson.noSpaces)
)

manager.get
.map(_.get(uuid))
.flatMap:
case None =>
q.offer(WebSocketFrame.Close())
case Some(job)
if job.instruction == TrainingInstruction.Halt =>
latch.complete(Right(())) *>
sendProgress(JobProgress.Finished) *>
q.offer(WebSocketFrame.Close())
case Some(job) =>
job.result match
case None =>
sendProgress(JobProgress.Started)
case Some(result) =>
val serialised =
Conf.printHocon(result.config.toScalafmt(base))

val serialisedBase =
Conf.printHocon(base)

val codec = ScalafmtConfig.encoder

val diff = munit.diff
.Diff(job.attributes.file, result.formattedFile)
.unifiedDiff

val configDiff =
Conf.printHocon(
Conf.patch(
codec.write(base),
codec.write(result.config.toScalafmt(base))
)
)

val res = JobProgress.Result(
config = serialised,
formattedFile = result.formattedFile,
fileDiff = diff,
configDiff = configDiff,
generation = result.generation,
generations = job.attributes.generations
)

q.offer(WebSocketFrame.Ping()) *>
sendProgress(res)
end sendJobUpdate

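              // Poll the job once per second; the latch interrupts the stream once
              // the job has halted or the client has closed the connection.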
val repeatedUpdates = fs2.Stream
.repeatEval(sendJobUpdate)
.metered(1.second)
.interruptWhen(latch)

val send = fs2.Stream
.fromQueueUnterminated(q)
.concurrently(repeatedUpdates)

val receive: fs2.Pipe[IO, WebSocketFrame, Unit] =
_.evalTap:
case WebSocketFrame.Text("ping", _) =>
manager.heartbeatJob(uuid).void
case WebSocketFrame.Close(_) => latch.complete(Right(()))
case _ => IO.unit
.drain

wbs.withFilterPingPongs(false).build(send, receive)

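    // Job creation: decode the JobAttributes payload and validate it against the
    // configured limits before starting a job.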
case req @ POST -> Root / "api" / "create" =>
import org.http4s.circe.CirceEntityDecoder.*

req
.as[JobAttributes]
.flatMap: attrs =>
inline def error(msg: String) = BadRequest(msg)

if attrs.file.length > Limits.MaxFileLength then
error(
s"File length [${attrs.file.length}] above maximum length [${Limits.MaxFileLength}]"
)
else if attrs.populationSize > Limits.MaxPopulation then
error(s"Population size above maximum [${Limits.MaxPopulation}]")
else if attrs.generations > Limits.MaxGenerations then
error(
s"Number of generations above maximum [${Limits.MaxGenerations}]"
)
else manager.createJob(attrs).flatMap((id, _) => Ok(id.toString()))

end if

// state.ensuring
end routes

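// JobManager keeps all jobs in a Ref, runs training on fibers supervised by a
// Supervisor, and uses a Dispatcher to run IO from genovese's synchronous callbacks.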
object JobManager:
import cats.syntax.all.*
def create =
(IO.ref(Map.empty).toResource, Supervisor[IO], Dispatcher[IO])
.mapN(JobManager(_, _, _))

class JobManager(
state: Ref[IO, Map[UUID, Job]],
supervisor: Supervisor[IO],
dispatcher: Dispatcher[IO]
):
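  // Drop jobs whose last heartbeat is more than 60 seconds old.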
def cleanupOldJobs =
IO.realTimeInstant.flatMap: now =>
state.update(
_.filterNot((_, job) => now.minusSeconds(60L).isAfter(job.heartbeat))
)

def reportNumberOfJobs =
state.get.flatTap: jobs =>
if jobs.size > 0 then
scribe.cats.io.info(s"Number of active jobs: ${jobs.size}")
else IO.unit

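  // Record a fresh heartbeat for the job, returning it if it still exists.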
def heartbeatJob(id: UUID) =
IO.realTimeInstant.flatMap: now =>
state.modify: jobs =>
jobs.get(id) match
case None => jobs -> None
case Some(value) =>
jobs.updated(id, value.copy(heartbeat = now)) -> Some(value)

def generateJobId() = IO(UUID.randomUUID())

def get = state.get

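  // Register a new job and start a genovese training run over ScalafmtConfigSubset
  // on a supervised fiber; the event handler feeds intermediate results back into state.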
def createJob(
attrs: JobAttributes
): IO[(UUID, Job)] =

import genovese.*
given RuntimeChecks = RuntimeChecks.None

for
id <- generateJobId()
now <- IO.realTimeInstant

job = Job(
heartbeat = now,
instruction = TrainingInstruction.Continue,
attributes = attrs,
result = None
)

trainingConfig = TrainingConfig(
populationSize = attrs.populationSize,
mutationRate = NormalisedFloat(0.1f),
steps = attrs.generations,
random =
attrs.seed.fold(scala.util.Random())(seed => scala.util.Random(seed)),
selection = Selection.Top(0.8)
)

_ <- state.update(_.updated(id, job))

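      // genovese invokes this handler synchronously from the training loop, so state
      // updates are pushed back into IO through the dispatcher.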
handler = new EventHandler:
import TrainingEvent.*, TrainingInstruction.*
def handle[T](
t: TrainingEvent[T],
data: T | Null
): TrainingInstruction =
            // read the current jobs through the dispatcher, since this callback runs outside IO
            dispatcher.unsafeRunSync(state.get).get(id) match
case None => Halt
case Some(value) if value.instruction == Halt => Halt
case _ =>
t match
case TopSpecimen =>
val specimen = Featureful.fromFeatures[ScalafmtConfigSubset](
TopSpecimen.cast(data)
)

dispatcher.unsafeRunSync(updateResult(id, specimen))
case ReportFitness =>
case EpochFinished =>
dispatcher.unsafeRunSync(
updateJob(
id,
job =>
job.copy(result =
job.result.map(
_.copy(generation = EpochFinished.cast(data) + 1)
)
)
)
)
case TrainingFinished =>
dispatcher.unsafeRunSync(haltJob(id, "training finished"))
case _ =>
end match

Continue
end handle

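      // The training run itself, wrapped in IO so it can run on a supervised fiber.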
training = IO:
Train(
featureful = summon[Featureful[ScalafmtConfigSubset]],
config = trainingConfig,
fitness = cachedFitness(trainingConfig.random, 500)(
Fitness(fitness(attrs.file, _))
),
events = handler,
evaluator = ParallelCollectionsEvaluator
).train()

_ <- supervisor.supervise(training)
yield id -> job
end for
end createJob

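  // Apply f to the job, if it is still present.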
def updateJob(id: UUID, f: Job => Job) =
state.update(_.updatedWith(id)(_.map(f)))

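  // Re-format the job's file with the latest candidate config and store it as the
  // current result.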
def updateResult(id: UUID, config: ScalafmtConfigSubset) =
updateJob(
id,
job =>
val formatted =
format(job.attributes.file, config).fold(_.toString(), identity)
job.copy(result = job.result match
case None =>
Some(TrainingResult(config, formatted, 1))
case Some(value) =>
Some(value.copy(config = config, formattedFile = formatted))
)
)

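  // Mark the job as halted and log the reason; unknown ids are a no-op.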
def haltJob(id: UUID, message: String): IO[Unit] =
state
.modify: jobs =>
jobs.get(id) match
case None => jobs -> IO.unit
case Some(value) =>
jobs.updated(
id,
value.copy(instruction = TrainingInstruction.Halt)
) -> scribe.cats.io
.info(s"Job [$id] halted: [$message]")
      .flatten // flatten the IO[IO[Unit]] returned by modify so the log effect runs

end JobManager
