Skip to content

Commit 3276492

Browse files
Merge pull request #149 from GFNOrg/no_more_class_factories
No more class factories
2 parents eedc7e8 + ae3fa2e commit 3276492

22 files changed

+531
-380
lines changed

docs/requirements_docs.txt

+5-5
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
pre-commit
22
black
33
pytest
4-
sphinx==5.3.0
5-
myst-parser==0.18.1
6-
sphinx_rtd_theme==1.1.1
7-
sphinx-math-dollar==1.2.1
8-
sphinx-autoapi==2.0.0
4+
sphinx>=6.2.1
5+
myst-parser
6+
sphinx_rtd_theme
7+
sphinx-math-dollar
8+
sphinx-autoapi>=3.0.0
99
renku-sphinx-theme

pyproject.toml

+2-4
Original file line numberDiff line numberDiff line change
@@ -33,9 +33,9 @@ myst-parser = { version = "*", optional = true }
3333
pre-commit = { version = "*", optional = true }
3434
pytest = { version = "*", optional = true }
3535
renku-sphinx-theme = { version = "*", optional = true }
36-
sphinx = { version = "*", optional = true }
36+
sphinx = { version = ">=6.2.1", optional = true }
3737
sphinx_rtd_theme = { version = "*", optional = true }
38-
sphinx-autoapi = { version = "*", optional = true }
38+
sphinx-autoapi = { version = ">=3.0.0", optional = true }
3939
sphinx-math-dollar = { version = "*", optional = true }
4040
tox = { version = "*", optional = true }
4141

@@ -85,8 +85,6 @@ all = [
8585
"Homepage" = "https://gfn.readthedocs.io/en/latest/"
8686
"Bug Tracker" = "https://github.com/saleml/gfn/issues"
8787

88-
89-
9088
[tool.black]
9189
py36 = true
9290
include = '\.pyi?$'

src/gfn/containers/replay_buffer.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -46,8 +46,8 @@ def __init__(
4646
self.training_objects = Transitions(env)
4747
self.objects_type = "transitions"
4848
elif objects_type == "states":
49-
self.training_objects = env.States.from_batch_shape((0,))
50-
self.terminating_states = env.States.from_batch_shape((0,))
49+
self.training_objects = env.states_from_batch_shape((0,))
50+
self.terminating_states = env.states_from_batch_shape((0,))
5151
self.objects_type = "states"
5252
else:
5353
raise ValueError(f"Unknown objects_type: {objects_type}")

src/gfn/containers/trajectories.py

+8-6
Original file line numberDiff line numberDiff line change
@@ -79,13 +79,11 @@ def __init__(
7979
self.states = (
8080
states
8181
if states is not None
82-
else env.States.from_batch_shape(batch_shape=(0, 0))
82+
else env.states_from_batch_shape((0, 0))
8383
)
8484
assert len(self.states.batch_shape) == 2
8585
self.actions = (
86-
actions
87-
if actions is not None
88-
else env.Actions.make_dummy_actions(batch_shape=(0, 0))
86+
actions if actions is not None else env.actions_from_batch_shape((0, 0))
8987
)
9088
assert len(self.actions.batch_shape) == 2
9189
self.when_is_done = (
@@ -253,9 +251,13 @@ def extend(self, other: Trajectories) -> None:
253251

254252
# Either set, or append, estimator outputs if they exist in the submitted
255253
# trajectory.
256-
if self.estimator_outputs is None and is_tensor(other.estimator_outputs):
254+
if self.estimator_outputs is None and isinstance(
255+
other.estimator_outputs, Tensor
256+
):
257257
self.estimator_outputs = other.estimator_outputs
258-
elif is_tensor(self.estimator_outputs) and is_tensor(other.estimator_outputs):
258+
elif isinstance(self.estimator_outputs, Tensor) and isinstance(
259+
other.estimator_outputs, Tensor
260+
):
259261
batch_shape = self.actions.batch_shape
260262
n_bs = len(batch_shape)
261263
output_dtype = self.estimator_outputs.dtype

src/gfn/containers/transitions.py

+3-5
Original file line numberDiff line numberDiff line change
@@ -68,14 +68,12 @@ def __init__(
6868
self.states = (
6969
states
7070
if states is not None
71-
else env.States.from_batch_shape(batch_shape=(0,))
71+
else env.states_from_batch_shape(batch_shape=(0,))
7272
)
7373
assert len(self.states.batch_shape) == 1
7474

7575
self.actions = (
76-
actions
77-
if actions is not None
78-
else env.Actions.make_dummy_actions(batch_shape=(0,))
76+
actions if actions is not None else env.actions_from_batch_shape((0,))
7977
)
8078
self.is_done = (
8179
is_done
@@ -85,7 +83,7 @@ def __init__(
8583
self.next_states = (
8684
next_states
8785
if next_states is not None
88-
else env.States.from_batch_shape(batch_shape=(0,))
86+
else env.states_from_batch_shape(batch_shape=(0,))
8987
)
9088
assert (
9189
len(self.next_states.batch_shape) == 1

0 commit comments

Comments
 (0)