Skip to content

Commit

Permalink
Make pylace default transitions hit everything
Browse files Browse the repository at this point in the history
I went in to all a slice row kernel in sams, but found that `sams`
didn't update state_alpha and `fast` didn't update anything but the row
and column reassignment. Fixed.
  • Loading branch information
Baxter Eaves committed Jan 31, 2024
1 parent 66e5a67 commit 73e134e
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 5 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Fixed
- Initializing an engine with a codebook that has a different number of rows than the data will result in an error instead of printing a bunch on nonsense.
- Pylace default transition sets didn't hit all required transitions
- Typo in pylace internal `Dimension` class

## [python-0.6.0] - 2024-01-23

Expand Down
14 changes: 9 additions & 5 deletions pylace/lace/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

class Dimension:
Rows = 0
Colums = 1
Columns = 1


FN_IS_SYMMETRIC = {
Expand All @@ -28,8 +28,8 @@ class Dimension:


FN_DIMENSION = {
"mi": Dimension.Colums,
"depprob": Dimension.Colums,
"mi": Dimension.Columns,
"depprob": Dimension.Columns,
"rowsim": Dimension.Rows,
}

Expand Down Expand Up @@ -132,10 +132,10 @@ def infer_column_metadata(
StateTransition.view_alphas(),
StateTransition.row_assignment(RowKernel.sams()),
StateTransition.view_alphas(),
StateTransition.row_assignment(RowKernel.slice()),
StateTransition.component_parameters(),
StateTransition.column_assignment(ColumnKernel.gibbs()),
StateTransition.column_assignment(ColumnKernel.slice()),
StateTransition.view_alphas(),
StateTransition.state_alpha(),
StateTransition.feature_priors(),
],
"flat": [
Expand All @@ -151,8 +151,12 @@ def infer_column_metadata(
StateTransition.feature_priors(),
],
"fast": [
StateTransition.view_alphas(),
StateTransition.row_assignment(RowKernel.slice()),
StateTransition.component_parameters(),
StateTransition.feature_priors(),
StateTransition.column_assignment(ColumnKernel.slice()),
StateTransition.state_alpha(),
],
}

Expand Down

0 comments on commit 73e134e

Please sign in to comment.