Skip to content

Commit 46b9b42

Browse files
committed
Save steps with sockets
1 parent f732fb4 commit 46b9b42

File tree

3 files changed

+98
-61
lines changed

3 files changed

+98
-61
lines changed

VSharp.Explorer/AISearcher.fs

+94-59
Original file line numberDiff line numberDiff line change
@@ -2,64 +2,25 @@ namespace VSharp.Explorer
22

33
open System.Collections.Generic
44
open Microsoft.ML.OnnxRuntime
5+
open System.IO
6+
open System
7+
open System.Net
8+
open System.Net.Sockets
9+
open System.Text
510
open System.Text.Json
611
open VSharp
712
open VSharp.IL.Serializer
813
open VSharp.ML.GameServer.Messages
9-
open System.IO
1014

1115
type AIMode =
1216
| Runner
1317
| TrainingSendModel
1418
| TrainingSendEachStep
1519

16-
type internal AISearcher(oracle: Oracle, aiAgentTrainingMode: Option<AIAgentTrainingMode>) =
17-
let stepsToSwitchToAI =
18-
match aiAgentTrainingMode with
19-
| None -> 0u<step>
20-
| Some (SendModel options) -> options.aiAgentTrainingOptions.stepsToSwitchToAI
21-
| Some (SendEachStep options) -> options.aiAgentTrainingOptions.stepsToSwitchToAI
22-
23-
let stepsToPlay =
24-
match aiAgentTrainingMode with
25-
| None -> 0u<step>
26-
| Some (SendModel options) -> options.aiAgentTrainingOptions.stepsToPlay
27-
| Some (SendEachStep options) -> options.aiAgentTrainingOptions.stepsToPlay
28-
29-
let mutable lastCollectedStatistics =
30-
Statistics ()
31-
let mutable defaultSearcherSteps = 0u<step>
32-
let mutable (gameState: Option<GameState>) =
33-
None
34-
let mutable useDefaultSearcher =
35-
stepsToSwitchToAI > 0u<step>
36-
let mutable afterFirstAIPeek = false
37-
let mutable incorrectPredictedStateId =
38-
false
39-
40-
let defaultSearcher =
41-
let pickSearcher =
42-
function
43-
| BFSMode -> BFSSearcher () :> IForwardSearcher
44-
| DFSMode -> DFSSearcher () :> IForwardSearcher
45-
| x -> failwithf $"Unexpected default searcher {x}. DFS and BFS supported for now."
46-
47-
match aiAgentTrainingMode with
48-
| None -> BFSSearcher () :> IForwardSearcher
49-
| Some (SendModel options) -> pickSearcher options.aiAgentTrainingOptions.aiBaseOptions.defaultSearchStrategy
50-
| Some (SendEachStep options) -> pickSearcher options.aiAgentTrainingOptions.aiBaseOptions.defaultSearchStrategy
51-
52-
let mutable stepsPlayed = 0u<step>
53-
54-
let isInAIMode () =
55-
(not useDefaultSearcher) && afterFirstAIPeek
56-
57-
let q = ResizeArray<_> ()
58-
let availableStates = HashSet<_> ()
59-
60-
let updateGameState (delta: GameState) =
20+
module GameUtils =
21+
let updateGameState (delta: GameState) (gameState: Option<GameState>) =
6122
match gameState with
62-
| None -> gameState <- Some delta
23+
| None -> Some delta
6324
| Some s ->
6425
let updatedBasicBlocks =
6526
delta.GraphVertices |> Array.map (fun b -> b.Id) |> HashSet
@@ -106,7 +67,52 @@ type internal AISearcher(oracle: Oracle, aiAgentTrainingMode: Option<AIAgentTrai
10667
)
10768
)
10869

109-
gameState <- Some <| GameState (vertices.ToArray (), states, edges.ToArray ())
70+
Some <| GameState (vertices.ToArray (), states, edges.ToArray ())
71+
type internal AISearcher(oracle: Oracle, aiAgentTrainingMode: Option<AIAgentTrainingMode>) =
72+
let stepsToSwitchToAI =
73+
match aiAgentTrainingMode with
74+
| None -> 0u<step>
75+
| Some (SendModel options) -> options.aiAgentTrainingOptions.stepsToSwitchToAI
76+
| Some (SendEachStep options) -> options.aiAgentTrainingOptions.stepsToSwitchToAI
77+
78+
let stepsToPlay =
79+
match aiAgentTrainingMode with
80+
| None -> 0u<step>
81+
| Some (SendModel options) -> options.aiAgentTrainingOptions.stepsToPlay
82+
| Some (SendEachStep options) -> options.aiAgentTrainingOptions.stepsToPlay
83+
84+
let mutable lastCollectedStatistics =
85+
Statistics ()
86+
let mutable defaultSearcherSteps = 0u<step>
87+
let mutable (gameState: Option<GameState>) =
88+
None
89+
let mutable useDefaultSearcher =
90+
stepsToSwitchToAI > 0u<step>
91+
let mutable afterFirstAIPeek = false
92+
let mutable incorrectPredictedStateId =
93+
false
94+
95+
let defaultSearcher =
96+
let pickSearcher =
97+
function
98+
| BFSMode -> BFSSearcher () :> IForwardSearcher
99+
| DFSMode -> DFSSearcher () :> IForwardSearcher
100+
| x -> failwithf $"Unexpected default searcher {x}. DFS and BFS supported for now."
101+
102+
match aiAgentTrainingMode with
103+
| None -> BFSSearcher () :> IForwardSearcher
104+
| Some (SendModel options) -> pickSearcher options.aiAgentTrainingOptions.aiBaseOptions.defaultSearchStrategy
105+
| Some (SendEachStep options) -> pickSearcher options.aiAgentTrainingOptions.aiBaseOptions.defaultSearchStrategy
106+
107+
let mutable stepsPlayed = 0u<step>
108+
109+
let isInAIMode () =
110+
(not useDefaultSearcher) && afterFirstAIPeek
111+
112+
let q = ResizeArray<_> ()
113+
let availableStates = HashSet<_> ()
114+
115+
110116

111117
let init states =
112118
q.AddRange states
@@ -153,7 +159,7 @@ type internal AISearcher(oracle: Oracle, aiAgentTrainingMode: Option<AIAgentTrai
153159
if Seq.length availableStates > 0 then
154160
let gameStateDelta =
155161
collectGameStateDelta ()
156-
updateGameState gameStateDelta
162+
gameState <- GameUtils.updateGameState gameStateDelta gameState
157163
let statistics =
158164
computeStatistics gameState.Value
159165
Application.applicationGraphDelta.Clear ()
@@ -168,7 +174,7 @@ type internal AISearcher(oracle: Oracle, aiAgentTrainingMode: Option<AIAgentTrai
168174
else
169175
let gameStateDelta =
170176
collectGameStateDelta ()
171-
updateGameState gameStateDelta
177+
gameState <- GameUtils.updateGameState gameStateDelta gameState
172178
let statistics =
173179
computeStatistics gameState.Value
174180

@@ -184,12 +190,12 @@ type internal AISearcher(oracle: Oracle, aiAgentTrainingMode: Option<AIAgentTrai
184190
else
185191
let toPredict =
186192
match aiMode with
187-
| TrainingSendEachStep ->
193+
| TrainingSendEachStep
194+
| TrainingSendModel ->
188195
if stepsPlayed > 0u<step> then
189196
gameStateDelta
190197
else
191198
gameState.Value
192-
| TrainingSendModel
193199
| Runner -> gameState.Value
194200

195201
let stateId = oracle.Predict toPredict
@@ -225,15 +231,37 @@ type internal AISearcher(oracle: Oracle, aiAgentTrainingMode: Option<AIAgentTrai
225231
let arrayOutputJson =
226232
JsonSerializer.Serialize arrayOutput
227233
arrayOutputJson
228-
229-
let writeStep (gameState: GameState) (output: IDisposableReadOnlyCollection<OrtValue>) (filePath: string) =
234+
let stepToString (gameState: GameState) (output: IDisposableReadOnlyCollection<OrtValue>) =
230235
let gameStateJson =
231236
JsonSerializer.Serialize gameState
232-
let stateJson = serializeOutput output
233-
File.WriteAllText (filePath + "_gameState", gameStateJson)
234-
File.WriteAllText (filePath + "_nn_output", stateJson)
237+
let outputJson = serializeOutput output
238+
let DELIM = Environment.NewLine
239+
let strToSaveAsList =
240+
[
241+
gameStateJson
242+
DELIM
243+
outputJson
244+
DELIM
245+
]
246+
String.concat " " strToSaveAsList
235247

236248
let createOracleRunner (pathToONNX: string, aiAgentTrainingModelOptions: Option<AIAgentTrainingModelOptions>) =
249+
let host = "localhost"
250+
let port =
251+
match aiAgentTrainingModelOptions with
252+
| Some options -> options.port
253+
| None -> 0
254+
255+
let client = new TcpClient ()
256+
client.Connect (host, port)
257+
client.SendBufferSize <- 2048
258+
let stream = client.GetStream ()
259+
260+
let saveStep (gameState: GameState) (output: IDisposableReadOnlyCollection<OrtValue>) =
261+
let bytes =
262+
Encoding.UTF8.GetBytes (stepToString gameState output)
263+
stream.Write (bytes, 0, bytes.Length)
264+
stream.Flush ()
237265

238266
let sessionOptions =
239267
if useGPU then
@@ -254,8 +282,15 @@ type internal AISearcher(oracle: Oracle, aiAgentTrainingMode: Option<AIAgentTrai
254282
let feedback (x: Feedback) = ()
255283

256284
let mutable stepsPlayed = 0
285+
let mutable currentGameState = None
257286

258-
let predict (gameState: GameState) =
287+
let predict (gameStateOrDelta: GameState) =
288+
let _ =
289+
match aiAgentTrainingModelOptions with
290+
| Some _ when not (stepsPlayed = 0) ->
291+
currentGameState <- GameUtils.updateGameState gameStateOrDelta currentGameState
292+
| _ -> currentGameState <- Some gameStateOrDelta
293+
let gameState = currentGameState.Value
259294
let stateIds =
260295
Dictionary<uint<stateId>, int> ()
261296
let verticesIds =
@@ -441,7 +476,7 @@ type internal AISearcher(oracle: Oracle, aiAgentTrainingMode: Option<AIAgentTrai
441476

442477
let _ =
443478
match aiAgentTrainingModelOptions with
444-
| Some options -> writeStep gameState output (options.outputDirectory + ($"/{stepsPlayed}"))
479+
| Some _ -> saveStep gameStateOrDelta output
445480
| None -> ()
446481

447482
stepsPlayed <- stepsPlayed + 1

VSharp.Explorer/Options.fs

+1
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,7 @@ type AIAgentTrainingModelOptions =
7878
{
7979
aiAgentTrainingOptions: AIAgentTrainingOptions
8080
outputDirectory: string
81+
port: int
8182
}
8283

8384

VSharp.ML.GameServer.Runner/Main.fs

+3-2
Original file line numberDiff line numberDiff line change
@@ -334,7 +334,7 @@ let generateDataForPretraining outputDirectory datasetBasePath (maps: ResizeArra
334334
API.Reset ()
335335
HashMap.hashMap.Clear ()
336336

337-
let runTrainingSendModelMode outputDirectory (gameMap: GameMap) (pathToModel: string) (useGPU: bool) (optimize: bool) =
337+
let runTrainingSendModelMode outputDirectory (gameMap: GameMap) (pathToModel: string) (useGPU: bool) (optimize: bool) (port: int) =
338338
printfn $"Run infer on {gameMap.MapName} have started."
339339

340340
let aiTrainingOptions =
@@ -360,6 +360,7 @@ let runTrainingSendModelMode outputDirectory (gameMap: GameMap) (pathToModel: st
360360
{
361361
aiAgentTrainingOptions = aiTrainingOptions
362362
outputDirectory = outputDirectory
363+
port = port
363364
}
364365
)
365366

@@ -473,7 +474,7 @@ let main args =
473474
let optimize =
474475
(args.TryGetResult <@ Optimize @>).IsSome
475476

476-
runTrainingSendModelMode outputDirectory gameMap model useGPU optimize
477+
runTrainingSendModelMode outputDirectory gameMap model useGPU optimize port
477478
| Mode.Generator ->
478479
let datasetDescription =
479480
args.GetResult <@ DatasetDescription @>

0 commit comments

Comments
 (0)