@@ -37,7 +37,7 @@ def __init__(
37
37
super ().__init__ ()
38
38
39
39
@abstractmethod
40
- def forward (self , x : Tensor , z : Tensor ) -> Tensor :
40
+ def forward (self , x : torch . Tensor , z : torch . Tensor ) -> torch . Tensor :
41
41
"""
42
42
Input:
43
43
x: Feature vector of state action pairs
@@ -71,7 +71,7 @@ def __init__(
71
71
self .prior_net : nn .Module = mlp_block (input_dim , hidden_dims , output_dim ).eval ()
72
72
self .scale = scale
73
73
74
- def forward (self , x : Tensor ) -> Tensor :
74
+ def forward (self , x : torch . Tensor ) -> torch . Tensor :
75
75
"""
76
76
Input:
77
77
x: Tensor. Feature vector input
@@ -116,7 +116,7 @@ def __init__(
116
116
117
117
self ._resample_epistemic_index ()
118
118
119
- def forward (self , x : Tensor , z : Tensor , persistent : bool = False ) -> Tensor :
119
+ def forward (self , x : torch . Tensor , z : torch . Tensor , persistent : bool = False ) -> torch . Tensor :
120
120
"""
121
121
Input:
122
122
x: Feature vector of state action pairs
@@ -176,14 +176,14 @@ def generate_params_buffers(self) -> None:
176
176
self .params , self .buffers = torch .func .stack_module_state (self .models )
177
177
178
178
def call_single_model (
179
- self , params : dict [str , Any ], buffers : dict [str , Any ], data : Tensor
180
- ) -> Tensor :
179
+ self , params : dict [str , Any ], buffers : dict [str , Any ], data : torch . Tensor
180
+ ) -> torch . Tensor :
181
181
"""
182
182
Method for parallelizing priornet forward passes with torch.vmap.
183
183
"""
184
184
return torch .func .functional_call (self .base_model , (params , buffers ), (data ,))
185
185
186
- def forward (self , x : Tensor , z : Tensor ) -> Tensor :
186
+ def forward (self , x : torch . Tensor , z : torch . Tensor ) -> torch . Tensor :
187
187
"""
188
188
Perform forward pass on the priornet ensemble and weight by epistemic index
189
189
x and z are assumed to already be formatted.
@@ -235,7 +235,7 @@ def __init__(
235
235
self .input_dim , self .prior_hiddens , self .output_dim , self .index_dim
236
236
)
237
237
238
- def format_xz (self , x : Tensor , z : Tensor ) -> Tensor :
238
+ def format_xz (self , x : torch . Tensor , z : torch . Tensor ) -> torch . Tensor :
239
239
"""
240
240
Take cartesian product of x and z and concatenate for forward pass.
241
241
Input:
@@ -251,7 +251,7 @@ def format_xz(self, x: Tensor, z: Tensor) -> Tensor:
251
251
xz = torch .cat ([x_expanded , z_expanded ], dim = - 1 )
252
252
return xz .view (batch_size * num_indices , d + self .index_dim )
253
253
254
- def forward (self , x : Tensor , z : Tensor , persistent : bool = False ) -> Tensor :
254
+ def forward (self , x : torch . Tensor , z : torch . Tensor , persistent : bool = False ) -> torch . Tensor :
255
255
"""
256
256
Input:
257
257
x: Feature vector containing item and user embeddings and interactions
0 commit comments