Skip to content

Commit

Permalink
WIP documentation generation via gh actions (facebookresearch#107)
Browse files Browse the repository at this point in the history
Summary: Pull Request resolved: facebookresearch#107

Differential Revision: D67157339
  • Loading branch information
Dmytro Korenkevych authored and facebook-github-bot committed Dec 12, 2024
1 parent c9576ed commit dedcf28
Show file tree
Hide file tree
Showing 4 changed files with 18 additions and 14 deletions.
8 changes: 6 additions & 2 deletions .github/workflows/CI.yml
Original file line number Diff line number Diff line change
Expand Up @@ -8,18 +8,22 @@ jobs:
os: [ubuntu-latest, macos-latest, windows-latest]
env:
OS: ${{ matrix.os }}
PYTHON: '3.9'
PYTHON: "3.10"
steps:
- uses: actions/checkout@v4
- name: Setup Python
uses: actions/setup-python@v4
with:
python-version: 3.9
python-version: "3.10"
- name: Install dependencies
run: |
pip install pdoc3
pip install pytest
pip install pytest-cov
pip install -e .
- name: Update documentation
run: |
pdoc -o ./html pearl
- name: Generate coverage report
run: |
pytest --cov=./ --cov-report=xml
Expand Down
6 changes: 3 additions & 3 deletions pearl/api/action_result.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,9 @@ class ActionResult:
reward: Reward
terminated: bool
truncated: bool
info: dict[str, Any] | None = None
cost: float | None = None
available_action_space: ActionSpace | None = None
info: Optional[dict[str, Any]] = None
cost: Optional[float] = None
available_action_space: Optional[ActionSpace] = None

@property
def done(self) -> bool:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ class HistorySummarizationModule(ABC, nn.Module):

@abstractmethod
def summarize_history(
self, observation: Observation, action: Action | None
self, observation: Observation, action: Optional[Action]
) -> SubjectiveState:
pass

Expand Down
16 changes: 8 additions & 8 deletions pearl/neural_networks/common/epistemic_neural_networks.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def __init__(
super().__init__()

@abstractmethod
def forward(self, x: Tensor, z: Tensor) -> Tensor:
def forward(self, x: torch.Tensor, z: torch.Tensor) -> torch.Tensor:
"""
Input:
x: Feature vector of state action pairs
Expand Down Expand Up @@ -71,7 +71,7 @@ def __init__(
self.prior_net: nn.Module = mlp_block(input_dim, hidden_dims, output_dim).eval()
self.scale = scale

def forward(self, x: Tensor) -> Tensor:
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Input:
x: Tensor. Feature vector input
Expand Down Expand Up @@ -116,7 +116,7 @@ def __init__(

self._resample_epistemic_index()

def forward(self, x: Tensor, z: Tensor, persistent: bool = False) -> Tensor:
def forward(self, x: torch.Tensor, z: torch.Tensor, persistent: bool = False) -> torch.Tensor:
"""
Input:
x: Feature vector of state action pairs
Expand Down Expand Up @@ -176,14 +176,14 @@ def generate_params_buffers(self) -> None:
self.params, self.buffers = torch.func.stack_module_state(self.models)

def call_single_model(
self, params: dict[str, Any], buffers: dict[str, Any], data: Tensor
) -> Tensor:
self, params: dict[str, Any], buffers: dict[str, Any], data: torch.Tensor
) -> torch.Tensor:
"""
Method for parallelizing priornet forward passes with torch.vmap.
"""
return torch.func.functional_call(self.base_model, (params, buffers), (data,))

def forward(self, x: Tensor, z: Tensor) -> Tensor:
def forward(self, x: torch.Tensor, z: torch.Tensor) -> torch.Tensor:
"""
Perform forward pass on the priornet ensemble and weight by epistemic index
x and z are assumed to already be formatted.
Expand Down Expand Up @@ -235,7 +235,7 @@ def __init__(
self.input_dim, self.prior_hiddens, self.output_dim, self.index_dim
)

def format_xz(self, x: Tensor, z: Tensor) -> Tensor:
def format_xz(self, x: torch.Tensor, z: torch.Tensor) -> torch.Tensor:
"""
Take cartesian product of x and z and concatenate for forward pass.
Input:
Expand All @@ -251,7 +251,7 @@ def format_xz(self, x: Tensor, z: Tensor) -> Tensor:
xz = torch.cat([x_expanded, z_expanded], dim=-1)
return xz.view(batch_size * num_indices, d + self.index_dim)

def forward(self, x: Tensor, z: Tensor, persistent: bool = False) -> Tensor:
def forward(self, x: torch.Tensor, z: torch.Tensor, persistent: bool = False) -> torch.Tensor:
"""
Input:
x: Feature vector containing item and user embeddings and interactions
Expand Down

0 comments on commit dedcf28

Please sign in to comment.