Skip to content
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.

Commit 61b5731

Browse files
committedMay 17, 2024
Added exported pytorch models for NHL94 to be use with Pythorch C++
1 parent 21f426a commit 61b5731

File tree

3 files changed

+7
-2
lines changed

3 files changed

+7
-2
lines changed
 

‎export_model.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,12 @@
99
from stable_baselines3.common.policies import BasePolicy
1010

1111

12-
MODEL_PATH = "./models/ScoreGoal.zip"
12+
#MODEL_PATH = "./models/ScoreGoal.zip"
13+
#jit_path = "ScoreGoal.pt"
14+
15+
MODEL_PATH = "./models/DefenseZone.zip"
16+
jit_path = "DefenseZone.pt"
17+
1318

1419
class OnnxableSB3Policy(th.nn.Module):
1520
def __init__(self, policy: BasePolicy):
@@ -43,7 +48,7 @@ def forward(self, observation: th.Tensor) -> Tuple[th.Tensor, th.Tensor, th.Tens
4348

4449
# Pytorch JIT
4550
# See "ONNX export" for imports and OnnxablePolicy
46-
jit_path = "ppo_traced.pt"
51+
4752

4853
# Trace and optimize the module
4954
traced_module = th.jit.trace(onnx_policy.eval(), dummy_input)

‎models/DefenseZone.pt

567 KB
Binary file not shown.

‎models/ScoreGoal.pt

567 KB
Binary file not shown.

0 commit comments

Comments
 (0)
Please sign in to comment.