Skip to content

Add tests for derived signal #850

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 20 commits into from
Jun 11, 2025
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/ophyd_async/core/_derived_signal.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,7 +234,7 @@ def derived_signal_rw(
if raw_to_derived_datatype != set_derived_datatype:
msg = (
f"{raw_to_derived} has datatype {raw_to_derived_datatype} "
f"!= {set_derived_datatype} dataype {set_derived_datatype}"
f"!= {set_derived_datatype} datatype {set_derived_datatype}"
)
raise TypeError(msg)

Expand Down
10 changes: 5 additions & 5 deletions src/ophyd_async/core/_derived_signal_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def derived_to_raw(self, *, derived1: float, derived2: float) -> MyRaw:
TransformT = TypeVar("TransformT", bound=Transform)


def filter_by_type(raw_devices: Mapping[str, Any], type_: type[T]) -> dict[str, T]:
def validate_by_type(raw_devices: Mapping[str, Any], type_: type[T]) -> dict[str, T]:
filtered_devices: dict[str, T] = {}
for name, device in raw_devices.items():
if not isinstance(device, type_):
Expand Down Expand Up @@ -91,21 +91,21 @@ def __init__(

@cached_property
def raw_locatables(self) -> dict[str, AsyncLocatable]:
return filter_by_type(self._raw_devices, AsyncLocatable)
return validate_by_type(self._raw_devices, AsyncLocatable)

@cached_property
def transform_readables(self) -> dict[str, AsyncReadable]:
return filter_by_type(self._transform_devices, AsyncReadable)
return validate_by_type(self._transform_devices, AsyncReadable)

@cached_property
def raw_and_transform_readables(self) -> dict[str, AsyncReadable]:
return filter_by_type(
return validate_by_type(
self._raw_devices | self._transform_devices, AsyncReadable
)

@cached_property
def raw_and_transform_subscribables(self) -> dict[str, Subscribable]:
return filter_by_type(self._raw_devices | self._transform_devices, Subscribable)
return validate_by_type(self._raw_devices | self._transform_devices, Subscribable)

def _complete_cached_reading(self) -> dict[str, Reading] | None:
if self._cached_readings and len(self._cached_readings) == len(
Expand Down
1 change: 1 addition & 0 deletions src/ophyd_async/epics/testing/_example_ioc.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ class EpicsTestCaDevice(EpicsDevice):
longstr: A[SignalRW[str], PvSuffix("longstr")]
longstr2: A[SignalRW[str], PvSuffix("longstr2.VAL$")]
a_bool: A[SignalRW[bool], PvSuffix("bool")]
slowseq: A[SignalRW[int], PvSuffix("slowseq")]
enum: A[SignalRW[EpicsTestEnum], PvSuffix("enum")]
enum2: A[SignalRW[EpicsTestEnum], PvSuffix("enum2")]
subset_enum: A[SignalRW[EpicsTestSubsetEnum], PvSuffix("subset_enum")]
Expand Down
5 changes: 5 additions & 0 deletions src/ophyd_async/epics/testing/test_records.db
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,11 @@ record(mbbo, "$(device)enum_str_fallback") {
field(PINI, "YES")
}

record(seq, "$(device)slowseq") {
field(DLY1, "0.5")
field(LNK1, "$(device)slowseq.DESC")
}

record(waveform, "$(device)uint8a") {
field(NELM, "3")
field(FTVL, "UCHAR")
Expand Down
51 changes: 51 additions & 0 deletions tests/core/test_multi_derived_signal.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,13 @@
DerivedSignalFactory,
soft_signal_rw,
)
from ophyd_async.core._derived_signal_backend import (
DerivedSignalBackend, # noqa: PLC2701
SignalTransformer, # noqa: PLC2701
Transform, # noqa: PLC2701
)
from ophyd_async.core._signal import SignalRW # noqa: PLC2701
from ophyd_async.core._table import Table # noqa: PLC2701
from ophyd_async.sim import (
HorizontalMirror,
HorizontalMirrorDerived,
Expand Down Expand Up @@ -128,3 +135,47 @@ def test_mismatching_args():
jack22=soft_signal_rw(float),
distance=soft_signal_rw(float),
)


@pytest.fixture
def derived_signal_backend() -> DerivedSignalBackend:
return DerivedSignalBackend(Table, "derived_backend",
SignalTransformer(Transform, None, None))


async def test_derived_signal_backend_connect_pass(
derived_signal_backend: DerivedSignalBackend
) -> None:
result = await derived_signal_backend.connect(0.0)
assert result is None


def test_derived_signal_backend_set_value(
derived_signal_backend: DerivedSignalBackend
) -> None:
with pytest.raises(RuntimeError):
derived_signal_backend.set_value(0.0)


async def test_derived_signal_backend_put_fails(
derived_signal_backend: DerivedSignalBackend
) -> None:
with pytest.raises(RuntimeError):
await derived_signal_backend.put(value=None, wait=False)
with pytest.raises(RuntimeError):
await derived_signal_backend.put(value=None, wait=True)


def test_make_rw_signal_type_mismatch():
factory = DerivedSignalFactory(
TwoJackTransform,
set_derived=None,
distance=soft_signal_rw(float),
jack1=soft_signal_rw(float),
jack2=soft_signal_rw(float),
)
with pytest.raises(
ValueError,
match=re.escape("Must define a set_derived method to support derived")
):
factory._make_signal(signal_cls=SignalRW, datatype=Table, name="")
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Public interface is

Suggested change
factory._make_signal(signal_cls=SignalRW, datatype=Table, name="")
factory.derived_signal_rw(datatype=Table, name="")

138 changes: 131 additions & 7 deletions tests/core/test_single_derived_signal.py
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think all of these are possible without using patch or MagicMock or referring to private members of the variable, but let me know if any of them are difficult and I'll have a think how...

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

have to say am really struggling trying to find the way around using public interface without mock for e.g. ?

async def test_set_derived_not_initialized(derived_signal_backend: SignalBackend):
    with patch.object(
        derived_signal_backend.transformer,  # type: ignore
        "_set_derived",
        None,
    ):
        with pytest.raises(RuntimeError):
            await derived_signal_backend.put("name", True)

any ideas on that one?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, I don't think you can test that one with the public interface. Best I can do is:

async def test_set_derived_not_initialized():
    def _get(ts: int) -> float:
        return ts

    sig = derived_signal_r(_get, ts=soft_signal_rw(int, initial_value=4))
    with pytest.raises(
        RuntimeError,
        match="Cannot put as no set_derived method given",
    ):
        await sig._connector.backend.put(1.0, True)

If you've tried using the public interface and it's not possible then I'm happy with patches or private member variables

Original file line number Diff line number Diff line change
@@ -1,13 +1,26 @@
import asyncio
import re
from unittest.mock import call
from unittest.mock import call, patch

import pytest
from bluesky.protocols import Reading

from ophyd_async.core import (
derived_signal_r,
soft_signal_rw,
)
from ophyd_async.core._derived_signal import (
_get_first_arg_datatype, # noqa: PLC2701
_get_return_datatype, # noqa: PLC2701
derived_signal_rw, # noqa: PLC2701
)
from ophyd_async.core._derived_signal_backend import (
SignalTransformer, # noqa: PLC2701
Transform, # noqa: PLC2701
validate_by_type, # noqa: PLC2701
)
from ophyd_async.core._signal import SignalR # noqa: PLC2701
from ophyd_async.core._utils import StrictEnumMeta # noqa: PLC2701
from ophyd_async.testing import (
BeamstopPosition,
Exploder,
Expand All @@ -20,6 +33,16 @@
)


@pytest.fixture
def movable_beamstop() -> MovableBeamstop:
return MovableBeamstop("device")


@pytest.fixture
def readonly_beamstop() -> ReadOnlyBeamstop:
return ReadOnlyBeamstop("device")


@pytest.mark.parametrize(
"x, y, position",
[
Expand Down Expand Up @@ -70,19 +93,18 @@ async def test_monitoring_position(cls: type[ReadOnlyBeamstop | MovableBeamstop]
assert results.empty()


async def test_setting_position():
inst = MovableBeamstop("inst")
async def test_setting_position(movable_beamstop: MovableBeamstop):
# Connect in mock mode so we can see what would have been set
await inst.connect(mock=True)
m = get_mock(inst)
await inst.position.set(BeamstopPosition.OUT_OF_POSITION)
await movable_beamstop.connect(mock=True)
m = get_mock(movable_beamstop)
await movable_beamstop.position.set(BeamstopPosition.OUT_OF_POSITION)
assert m.mock_calls == [
call.position.put(BeamstopPosition.OUT_OF_POSITION, wait=True),
call.x.put(3, wait=True),
call.y.put(5, wait=True),
]
m.reset_mock()
await inst.position.set(BeamstopPosition.IN_POSITION)
await movable_beamstop.position.set(BeamstopPosition.IN_POSITION)
assert m.mock_calls == [
call.position.put(BeamstopPosition.IN_POSITION, wait=True),
call.x.put(0, wait=True),
Expand Down Expand Up @@ -118,3 +140,105 @@ def _get_position(x: float, y: float) -> BeamstopPosition:
derived_signal_r(
_get_position, foo=soft_signal_rw(float), bar=soft_signal_rw(float)
)


@patch("ophyd_async.core._derived_signal_backend.TYPE_CHECKING", True)
def test_validate_by_type(
movable_beamstop: MovableBeamstop,
readonly_beamstop: ReadOnlyBeamstop
) -> None:
invalid_devices_dict = {device.name: device for device in [movable_beamstop,
readonly_beamstop]}
with pytest.raises(TypeError):
validate_by_type(invalid_devices_dict, MovableBeamstop)
with pytest.raises(TypeError):
validate_by_type({movable_beamstop.name: movable_beamstop}, ReadOnlyBeamstop)
valid_devices_dict = {device.name: device for device in [movable_beamstop,
MovableBeamstop("mvb2")]}
assert validate_by_type(valid_devices_dict, MovableBeamstop) == valid_devices_dict


@pytest.fixture
def null_transformer() -> SignalTransformer:
return SignalTransformer(Transform, None, None)


@pytest.fixture
def new_readings() -> dict[str, Reading]:
return {"device-position": Reading(value=0.0, timestamp=0.0)}


async def test_set_derived_not_initialized(null_transformer: SignalTransformer):
with pytest.raises(RuntimeError):
await null_transformer.set_derived("name", None)


async def test_get_transform_cached(
null_transformer: SignalTransformer,
new_readings: dict[str, Reading]
) -> None:
with patch.object(null_transformer, '_cached_readings', new_readings):
with patch.object(null_transformer, 'raw_and_transform_subscribables', {"device": SignalR}): # noqa: E501
assert null_transformer._cached_readings == new_readings
r = await null_transformer.get_transform()
assert isinstance(r, Transform)


def test_update_cached_reading_non_initialized(
null_transformer: SignalTransformer,
new_readings: dict[str, Reading]
) -> None:
with pytest.raises(RuntimeError):
null_transformer._update_cached_reading(new_readings)


def test_update_cached_reading_initialized(
null_transformer: SignalTransformer,
new_readings: dict[str, Reading]
) -> None:
null_transformer._cached_readings = {}
null_transformer._update_cached_reading(new_readings)
assert null_transformer._cached_readings == new_readings


@patch("ophyd_async.core._utils.Callback")
def test_set_callback_already_set(
mock_class,
null_transformer: SignalTransformer
) -> None:
device_name = "device"
with patch.object(null_transformer, '_derived_callbacks', {device_name: mock_class}):
with pytest.raises(
RuntimeError,
match=re.escape(f"Callback already set for {device_name}")
):
null_transformer.set_callback(device_name, mock_class)


@patch("ophyd_async.core._derived_signal.get_type_hints", return_value={})
def test_get_return_datatype_no_type(movable_beamstop: MovableBeamstop):
with pytest.raises(
TypeError,
match=re.escape("does not have a type hint for it's return value")
):
_get_return_datatype(movable_beamstop._get_position)


def test_get_return_datatype(movable_beamstop: MovableBeamstop):
result = _get_return_datatype(movable_beamstop._get_position)
assert isinstance(result, StrictEnumMeta)


@patch("ophyd_async.core._derived_signal.get_type_hints", return_value={})
def test_get_first_arg_datatype_no_type(movable_beamstop: MovableBeamstop):
with pytest.raises(
TypeError,
match=re.escape("does not have a type hinted argument")
):
_get_first_arg_datatype(movable_beamstop._set_from_position)


def test_derived_signal_rw_type_error(movable_beamstop: MovableBeamstop):
with patch.object(movable_beamstop, '_set_from_position', movable_beamstop._get_position): # noqa: E501
with pytest.raises(TypeError):
derived_signal_rw(movable_beamstop._get_position, movable_beamstop._set_from_position) # noqa: E501
22 changes: 22 additions & 0 deletions tests/epics/signal/test_signals.py
Original file line number Diff line number Diff line change
Expand Up @@ -766,3 +766,25 @@ def a_plan():
assert yaml.safe_load(actual_file) == yaml.safe_load(expected_file)

RE(a_plan())


@pytest.mark.parametrize("protocol", get_args(Protocol))
async def test_put_completion(
RE, ioc_devices: EpicsTestIocAndDevices, protocol: Protocol
):
# Check that we can put to an epics signal and wait for put completion
slow_seq_pv = ioc_devices.get_pv(protocol, "slowseq")
slow_seq = epics_signal_rw(int, slow_seq_pv)
await slow_seq.connect()

# First, do a set with blocking and make sure it takes a while
start = time.time()
await slow_seq.set(1, wait=True)
stop = time.time()
assert stop - start == pytest.approx(0.5, rel=0.1)

# Then, make sure if we don't wait it returns ~instantly
start = time.time()
await slow_seq.set(2, wait=False)
stop = time.time()
assert stop - start < 0.1
1 change: 1 addition & 0 deletions tests/epics/signal/test_yaml_save_ca.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ bool_unnamed: true
enum: Bbb
enum2: Bbb
enum_str_fallback: Bbb
slowseq: 0
float32a: [1.9999999949504854e-06, -123.12300109863281]
float64a: [0.1, -12345678.123]
float_prec_0: 3
Expand Down
1 change: 1 addition & 0 deletions tests/epics/signal/test_yaml_save_pva.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ enum2: Bbb
enum_str_fallback: Bbb
float32a: [1.9999999949504854e-06, -123.12300109863281]
float64a: [0.1, -12345678.123]
slowseq: 0
float_prec_0: 3.0
int16a: [-32768, 32767]
int32a: [-2147483648, 2147483647]
Expand Down
Loading