Skip to content

RFC36 async testbench functions #291

@hstarmans

Description

@hstarmans

Thanks for all the effort still being put into this library.
It would be great if you could support RFC 36 (async testbench functions). The generator-based `yield` syntax is causing issues for me further down the road.
I now have the following fix, which might be of use; the code works and is tested. As a disclaimer, I partly generated the code using OpenAI's o4 model. For my current case, I only need `SPIGatewareTestCase`.
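Furthermore, in `initialize_signals` you need to set `self.sim = sim`, which is suboptimal. The helper methods call `self.sim.set()`, `self.sim.get()` and `self.sim.tick()`; these exist on the testbench context passed to the test, not on the `Simulator` created in `setUp`.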

utils.py

import os
import math
import unittest
from functools import wraps

from amaranth.sim import Simulator


def async_test_case(process_function, *, domain="sync"):
    """Decorator for writing async test cases using Amaranth's async testbench system.

    The decorated method must have the signature ``async def test_x(self, sim)``.
    When the test runs, the wrapper:

    1. records *domain* on the test case as ``self.domain`` and checks that a
       clock frequency is configured for it;
    2. registers a testbench that awaits ``self.initialize_signals(sim)`` and
       then the test body, passing the testbench context as ``sim``;
    3. runs the simulation through ``self.simulate``, naming any VCD output
       after the test method.

    While the testbench runs, ``self.sim`` is temporarily rebound to the
    testbench context, so helper methods that call ``self.sim.set/get/tick``
    work without the test having to assign ``self.sim = sim`` itself. The
    original simulator object is put back once the testbench finishes, even if
    it raised.
    """

    @wraps(process_function)
    def run_test(self):
        self.domain = domain
        self._ensure_clocks_present()

        # Keep a handle on the simulator so it can be restored afterwards.
        simulator = self.sim

        async def testbench(sim):
            # Helpers reach the simulation through ``self.sim`` and call
            # set/get/tick on it; those belong to the testbench context.
            self.sim = sim
            try:
                await self.initialize_signals(sim)
                await process_function(self, sim)
            finally:
                self.sim = simulator

        simulator.add_testbench(testbench)
        # ``self.simulate`` looks up ``self.sim`` before the testbench starts,
        # so it still sees (and runs) the simulator object.
        self.simulate(vcd_suffix=process_function.__name__)

    return run_test


def usb_domain_test_case(process_function):
    """Wrap an async test method so it is run against the ``usb`` clock domain."""
    wrapped = async_test_case(process_function, domain="usb")
    return wrapped


def fast_domain_test_case(process_function):
    """Wrap an async test method so it is run against the ``fast`` clock domain."""
    wrapped = async_test_case(process_function, domain="fast")
    return wrapped


def ss_domain_test_case(process_function):
    """Wrap an async test method so it is run against the ``ss`` clock domain."""
    wrapped = async_test_case(process_function, domain="ss")
    return wrapped


class LunaGatewareTestCase(unittest.TestCase):
    """Base class for simulation tests of a single gateware fragment.

    Subclasses set ``FRAGMENT_UNDER_TEST`` (a callable that builds the DUT),
    optionally ``FRAGMENT_ARGUMENTS``, and a clock frequency for every clock
    domain the DUT uses.

    The helper coroutines (``pulse``, ``advance_cycles``, ``wait_until``,
    ``wait``) drive the simulation through ``self.sim``. They must therefore be
    awaited while ``self.sim`` refers to the testbench context, i.e. the object
    providing ``set``/``get``/``tick``. They tick the clock domain stored in
    ``self.domain`` rather than always ticking ``sync``.
    """

    # Clock domain ticked by the helper coroutines; the async test decorators
    # overwrite it on the instance for each test.
    domain = "sync"

    FRAGMENT_UNDER_TEST = None
    # Shared class-level dict; it is only unpacked here, never mutated.
    FRAGMENT_ARGUMENTS = {}

    FAST_CLOCK_FREQUENCY = None
    SYNC_CLOCK_FREQUENCY = 120e6
    USB_CLOCK_FREQUENCY = None
    SS_CLOCK_FREQUENCY = None

    def instantiate_dut(self):
        """Build the device under test as FRAGMENT_UNDER_TEST(**FRAGMENT_ARGUMENTS)."""
        return self.FRAGMENT_UNDER_TEST(**self.FRAGMENT_ARGUMENTS)

    def get_vcd_name(self):
        """Return the base file name (no extension) for VCD/GTKW output."""
        return f"test_{self.__class__.__name__}"

    def _clock_frequency(self, domain):
        """Return the configured clock frequency for *domain* (may be None).

        Raises:
            ValueError: if *domain* is not one of ``sync``, ``usb``, ``fast``, ``ss``.
        """
        frequencies = {
            "sync": self.SYNC_CLOCK_FREQUENCY,
            "usb": self.USB_CLOCK_FREQUENCY,
            "fast": self.FAST_CLOCK_FREQUENCY,
            "ss": self.SS_CLOCK_FREQUENCY,
        }
        try:
            return frequencies[domain]
        except KeyError:
            raise ValueError(f"Unknown domain: {domain}") from None

    def setUp(self):
        """Create the DUT and its Simulator, adding a clock for each configured domain."""
        self.dut = self.instantiate_dut()
        self.sim = Simulator(self.dut)

        # Only domains with a truthy frequency get a clock (period = 1 / frequency).
        for clock_domain in ("usb", "sync", "fast", "ss"):
            frequency = self._clock_frequency(clock_domain)
            if frequency:
                self.sim.add_clock(1 / frequency, domain=clock_domain)

    async def initialize_signals(self, sim):
        """Hook awaited before every async test body; override to set initial signal values."""
        pass  # override in your test class as needed

    def traces_of_interest(self):
        """Return the signals to include in the generated GTKW trace view."""
        return ()

    def simulate(self, *, vcd_suffix=None):
        """Run the simulator, optionally recording a VCD/GTKW pair.

        Waveforms are written only when the ``GENERATE_VCDS`` environment
        variable is set to any non-empty value (even "0").
        """
        if os.getenv("GENERATE_VCDS", default=False):
            vcd_name = self.get_vcd_name()
            if vcd_suffix:
                vcd_name = f"{vcd_name}_{vcd_suffix}"
            traces = self.traces_of_interest()
            with self.sim.write_vcd(
                vcd_name + ".vcd", vcd_name + ".gtkw", traces=traces
            ):
                self.sim.run()
        else:
            self.sim.run()

    async def pulse(self, signal, *, step_after=True):
        """Drive *signal* high for one cycle of ``self.domain``, then low.

        If *step_after* is true, advance one more cycle after deasserting.
        """
        sim = self.sim
        sim.set(signal, 1)
        await sim.tick(self.domain)
        sim.set(signal, 0)
        if step_after:
            await sim.tick(self.domain)

    async def advance_cycles(self, cycles):
        """Advance the simulation by *cycles* clock edges of ``self.domain``."""
        for _ in range(cycles):
            await self.sim.tick(self.domain)

    async def wait_until(self, strobe, *, timeout=None):
        """Tick ``self.domain`` until *strobe* reads as high.

        Raises:
            RuntimeError: once more than *timeout* cycles have elapsed. A falsy
                *timeout* (None or 0) waits forever.
        """
        cycles_passed = 0
        while not self.sim.get(strobe):
            await self.sim.tick(self.domain)
            cycles_passed += 1
            if timeout and cycles_passed > timeout:
                raise RuntimeError(f"Timeout waiting for '{strobe.name}' to go high!")

    def _ensure_clocks_present(self):
        """Fail the test if ``self.domain`` has no configured clock frequency."""
        self.assertIsNotNone(
            self._clock_frequency(self.domain),
            f"no frequency provided for `{self.domain}`-domain clock!",
        )

    async def wait(self, time):
        """Advance by at least *time* seconds, rounded up to whole cycles of ``self.domain``.

        Raises:
            ValueError: if ``self.domain`` is unknown or has no clock frequency.
        """
        frequency = self._clock_frequency(self.domain)
        if frequency is None:
            raise ValueError(f"no frequency provided for `{self.domain}`-domain clock!")

        period = 1 / frequency
        cycles = math.ceil(time / period)
        await self.advance_cycles(cycles)


class LunaUSBGatewareTestCase(LunaGatewareTestCase):
    """Test-case base that provides a 60 MHz ``usb`` clock and disables the ``sync`` clock."""

    USB_CLOCK_FREQUENCY = 60e6
    SYNC_CLOCK_FREQUENCY = None


class LunaSSGatewareTestCase(LunaGatewareTestCase):
    """Test-case base that provides a 125 MHz ``ss`` clock and disables the ``sync`` clock."""

    SS_CLOCK_FREQUENCY = 125e6
    SYNC_CLOCK_FREQUENCY = None

My spi.py

"""SPI and derived interfaces."""

from .utils import LunaGatewareTestCase


class SPIGatewareTestCase(LunaGatewareTestCase):
    """Extended version of the LunaGatewareTestCase.

    Adds three SPI-simulation methods:
        - spi_send_bit
        - spi_exchange_byte
        - spi_exchange_data

    The DUT must expose an ``spi`` bus with ``sck``, ``sdo`` and ``cs``
    signals, and optionally ``sdi``. All methods drive the bus through
    ``self.sim``, which must be the testbench context while they run. Clock
    ticks use the test case's ``self.domain``.
    """

    async def spi_send_bit(self, bit):
        """Clock a single bit over the SPI bus and return the sampled ``sdo`` value.

        Each phase (data setup, SCK high, sample, SCK low) lasts
        ``cycles_per_bit`` clock cycles. If the bus has no ``sdi`` line, the
        data-setup phase is skipped.
        """
        cycles_per_bit = 4
        spi = self.dut.spi
        sim = self.sim

        # Apply the new bit...
        if hasattr(spi, "sdi"):
            sim.set(spi.sdi, bit)
            await self.advance_cycles(cycles_per_bit)

        # Rising edge of serial clock
        sim.set(spi.sck, 1)
        await self.advance_cycles(cycles_per_bit)

        # Sample the output bit
        return_value = sim.get(spi.sdo)
        await self.advance_cycles(cycles_per_bit)

        # Falling edge of serial clock
        sim.set(spi.sck, 0)
        await self.advance_cycles(cycles_per_bit)

        return return_value

    async def spi_exchange_byte(self, datum, *, msb_first=True):
        """Exchange one byte over the virtual SPI bus.

        Sends the eight bits of *datum* (MSB first by default, LSB first when
        *msb_first* is false). Returns the byte assembled from the bits read
        back, each placed at the position of the bit sent alongside it.

        Raises:
            ValueError: if *datum* is outside 0..255.
        """
        # Previously a larger value silently clocked out 9+ bits.
        if not 0 <= datum <= 0xFF:
            raise ValueError(f"SPI datum must be a byte (0-255), got {datum!r}")

        bit_positions = range(7, -1, -1) if msb_first else range(8)
        received = 0
        for position in bit_positions:
            if await self.spi_send_bit((datum >> position) & 1):
                received |= 1 << position
        return received

    async def spi_exchange_data(self, data, msb_first=True):
        """Exchange a sequence of bytes over the virtual SPI bus.

        Drives ``cs`` high for one cycle before the first byte, exchanges each
        byte with ``spi_exchange_byte``, then drives ``cs`` low for one cycle.
        Returns the received bytes as a ``bytearray``.
        """
        sim = self.sim
        cs = self.dut.spi.cs

        sim.set(cs, 1)
        await sim.tick(self.domain)

        response = bytearray()
        for byte in data:
            response.append(await self.spi_exchange_byte(byte, msb_first=msb_first))

        sim.set(cs, 0)
        await sim.tick(self.domain)

        return response

My test_spi.py

from luna.gateware.interface.spi import SPIDeviceInterface
from spi import SPIGatewareTestCase
from utils import async_test_case

class SPIDeviceInterfaceTest(SPIGatewareTestCase):
    """Simulation tests for ``SPIDeviceInterface`` with 16-bit words and ``clock_polarity=1``."""

    FRAGMENT_UNDER_TEST = SPIDeviceInterface
    FRAGMENT_ARGUMENTS = dict(word_size=16, clock_polarity=1)

    async def initialize_signals(self, sim):
        """Per-test setup: drive CS low and advance one clock cycle."""
        # The SPIGatewareTestCase helpers read and drive signals through
        # self.sim, so it has to refer to the testbench context passed in here
        # rather than the Simulator created in setUp. Workaround; not ideal.
        self.sim = sim
        sim.set(self.dut.spi.cs, 0)
        await sim.tick()

    @async_test_case
    async def test_spi_interface(self, sim):
        """Send 0xCAFE to the DUT and expect its word_out (0xABCD) back."""
        # Ensure that we don't complete a word while CS is deasserted.
        for _ in range(10):
            self.assertEqual(sim.get(self.dut.word_complete), 0)
            await sim.tick()
        # Set the word to send and assert CS.
        sim.set(self.dut.word_out, 0xABCD)
        await sim.tick()

        # NOTE: spi_exchange_data also drives CS high before exchanging, so
        # this just holds CS high for one extra cycle first.
        sim.set(self.dut.spi.cs, 1)
        await sim.tick()

        # Exchange bytes via SPI and verify response (a bytearray, which
        # compares equal to the bytes literal).
        response = await self.spi_exchange_data(b"\xCA\xFE")
        self.assertEqual(response, b"\xAB\xCD")
        self.assertEqual(sim.get(self.dut.word_in), 0xCAFE)

    @async_test_case
    async def test_spi_transmit_second_word(self, sim):
        """Load word_out with 0x0F00, exchange two zero bytes, and expect 0x0F00 back."""
        sim.set(self.dut.word_out, 0x0F00)
        await sim.tick()

        sim.set(self.dut.spi.cs, 1)
        await sim.tick()

        response = await self.spi_exchange_data(b"\x00\x00")
        self.assertEqual(response, b"\x0F\x00")

if __name__ == "__main__":
    # Allow running this test module directly as a script.
    from unittest import main as run_unittest_main

    run_unittest_main()

Metadata

Metadata

Assignees

Labels

No labels
No labels

Type

No type

Projects

No projects

Milestone

No milestone

Relationships

None yet

Development

No branches or pull requests

Issue actions