Skip to content

Commit d0d01d6

Browse files
committed
Implement a new process for saving steps in validation with model
1 parent c9728c8 commit d0d01d6

File tree

3 files changed

+51
-48
lines changed

3 files changed

+51
-48
lines changed

VSharp.Explorer/AISearcher.fs

+9-37
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,11 @@ module GameUtils =
6565
)
6666

6767
Some <| GameState (vertices.ToArray (), states, edges.ToArray ())
68+
69+
let convertOutputToJson (output: IDisposableReadOnlyCollection<OrtValue>) =
70+
seq { 0 .. output.Count - 1 }
71+
|> Seq.map (fun i -> output[i].GetTensorDataAsSpan<float32>().ToArray ())
72+
6873
type internal AISearcher(oracle: Oracle, aiAgentTrainingMode: Option<AIAgentTrainingMode>) =
6974
let stepsToSwitchToAI =
7075
match aiAgentTrainingMode with
@@ -220,44 +225,8 @@ type internal AISearcher(oracle: Oracle, aiAgentTrainingMode: Option<AIAgentTrai
220225
let numOfStateAttributes = 7
221226
let numOfHistoryEdgeAttributes = 2
222227

223-
let serializeOutput (output: IDisposableReadOnlyCollection<OrtValue>) =
224-
let arrayOutput =
225-
seq { 0 .. output.Count - 1 }
226-
|> Seq.map (fun i -> output[i].GetTensorDataAsSpan<float32>().ToArray ())
227-
228-
let arrayOutputJson =
229-
JsonSerializer.Serialize arrayOutput
230-
arrayOutputJson
231-
232-
let stepToString (gameState: GameState) (output: IDisposableReadOnlyCollection<OrtValue>) =
233-
let gameStateJson =
234-
JsonSerializer.Serialize gameState
235-
let outputJson = serializeOutput output
236-
let DELIM = Environment.NewLine
237-
let strToSaveAsList =
238-
[
239-
gameStateJson
240-
DELIM
241-
outputJson
242-
DELIM
243-
]
244-
String.concat " " strToSaveAsList
245228

246229
let createOracleRunner (pathToONNX: string, aiAgentTrainingModelOptions: Option<AIAgentTrainingModelOptions>) =
247-
let stream =
248-
match aiAgentTrainingModelOptions with
249-
| Some options -> options.stream
250-
| None -> None
251-
252-
let saveStep (gameState: GameState) (output: IDisposableReadOnlyCollection<OrtValue>) =
253-
match stream with
254-
| Some stream ->
255-
let bytes =
256-
Encoding.UTF8.GetBytes (stepToString gameState output)
257-
stream.Write (bytes, 0, bytes.Length)
258-
stream.Flush ()
259-
| None -> ()
260-
261230
let sessionOptions =
262231
if useGPU then
263232
SessionOptions.MakeSessionOptionWithCudaProvider (0)
@@ -471,7 +440,10 @@ type internal AISearcher(oracle: Oracle, aiAgentTrainingMode: Option<AIAgentTrai
471440

472441
let _ =
473442
match aiAgentTrainingModelOptions with
474-
| Some _ -> saveStep gameStateOrDelta output
443+
| Some aiAgentOptions ->
444+
aiAgentOptions.stepSaver (
445+
AIGameStep (gameState = gameStateOrDelta, output = GameUtils.convertOutputToJson output)
446+
)
475447
| None -> ()
476448

477449
stepsPlayed <- stepsPlayed + 1

VSharp.Explorer/Options.fs

+13-1
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ open System.Diagnostics
44
open System.IO
55
open VSharp.ML.GameServer.Messages
66
open System.Net.Sockets
7+
open Microsoft.ML.OnnxRuntime
78

89
type searchMode =
910
| DFSMode
@@ -54,6 +55,17 @@ type Oracle =
5455
/// <param name="mapName">Name of map to play.</param>
5556
/// <param name="mapName">Name of map to play.</param>
5657
58+
[<Struct>]
59+
type AIGameStep =
60+
interface IRawOutgoingMessageBody
61+
val GameState: GameState
62+
val Output: seq<array<float32>>
63+
new(gameState, output) =
64+
{
65+
GameState = gameState
66+
Output = output
67+
}
68+
5769

5870
type AIBaseOptions =
5971
{
@@ -79,7 +91,7 @@ type AIAgentTrainingModelOptions =
7991
{
8092
aiAgentTrainingOptions: AIAgentTrainingOptions
8193
outputDirectory: string
82-
stream: Option<NetworkStream> // use it for sending steps
94+
stepSaver: AIGameStep -> Unit
8395
}
8496

8597

VSharp.ML.GameServer.Runner/Main.fs

+29-10
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
11
open System.IO
2+
open System.Text
3+
open System.Text.Json
24
open System.Net.Sockets
35
open System.Reflection
46
open Argu
@@ -370,20 +372,16 @@ let runTrainingSendModelMode
370372
oracle = None
371373
}
372374

373-
let stream =
374-
let host = "localhost" // TODO: working within a local network
375-
let client = new TcpClient ()
376-
client.Connect (host, port)
377-
client.SendBufferSize <- 2048
378-
Some <| client.GetStream ()
379-
375+
let mutable steps = []
376+
let stepSaver (aiGameStep: AIGameStep) = steps <- aiGameStep :: steps in
377+
()
380378
let aiOptions: AIOptions =
381379
Training (
382380
SendModel
383381
{
384382
aiAgentTrainingOptions = aiTrainingOptions
385383
outputDirectory = outputDirectory
386-
stream = stream
384+
stepSaver = stepSaver
387385
}
388386
)
389387

@@ -410,8 +408,29 @@ let runTrainingSendModelMode
410408

411409
printfn
412410
$"Running for {gameMap.MapName} finished with coverage {explorationResult.ActualCoverage}, tests {explorationResult.TestsCount}, steps {explorationResult.StepsCount},errors {explorationResult.ErrorsCount}."
413-
414-
411+
let steps = List.rev steps
412+
let stream =
413+
let host = "localhost" // TODO: working within a local network
414+
let client = new TcpClient ()
415+
client.Connect (host, port)
416+
client.SendBufferSize <- 4096
417+
client.GetStream ()
418+
419+
let needToSendSteps =
420+
let buffer = Array.zeroCreate<byte> 1
421+
let bytesRead = stream.Read (buffer, 0, 1)
422+
if bytesRead = 0 then
423+
failwith "Connection is closed?!"
424+
buffer.[0] <> byte 0
425+
426+
if needToSendSteps then
427+
let bytes =
428+
Encoding.UTF8.GetBytes (JsonSerializer.Serialize steps)
429+
stream.Write (bytes, 0, bytes.Length)
430+
stream.Flush ()
431+
stream.Close ()
432+
else
433+
()
415434

416435
[<EntryPoint>]
417436
let main args =

0 commit comments

Comments
 (0)