-
Notifications
You must be signed in to change notification settings - Fork 173
Open
Description
Thanks for all the effort still being put into this library.
It would be great if you could support RFC 36 (async testbench functions). The `yield`-based syntax creates some issues for me further down the road.
I now have the following fix, which might be of use. The code works and has been tested. As a disclaimer, I partly generated the code using OpenAI's o4 model. For my current case, I only need `SPIGatewareTestCase`.
Furthermore, in `initialize_signals`, you need to set `self.sim = sim`, which is suboptimal.
utils.py
import os
import math
import unittest
from functools import wraps
from amaranth.sim import Simulator
def async_test_case(process_function, *, domain="sync"):
    """Decorator for writing async test cases using Amaranth's async testbench system.

    ``process_function`` must be an ``async def`` method taking ``(self, sim)``,
    where ``sim`` is the simulator context handed to the testbench. The
    returned wrapper is an ordinary (synchronous) test method that:

    1. sets ``self.domain`` to ``domain`` and checks that its clock exists;
    2. registers a testbench that awaits ``self.initialize_signals(sim)`` and
       then ``process_function(self, sim)``;
    3. calls ``self.simulate`` with the test function's name as the VCD suffix.

    While the testbench runs, ``self.sim`` is temporarily rebound to the
    simulator context. Helper coroutines that call ``self.sim.get``,
    ``self.sim.set`` and ``self.sim.tick`` therefore work without each test
    having to assign ``self.sim = sim`` itself. The original ``Simulator``
    object is restored once the simulation finishes, even if it raises.
    """
    @wraps(process_function)
    def run_test(self):
        self.domain = domain
        self._ensure_clocks_present()

        # Keep a handle on the Simulator: self.sim is repointed below.
        simulator = self.sim

        async def testbench(sim):
            # Helpers read ``self.sim``; point it at the live testbench context.
            self.sim = sim
            await self.initialize_signals(sim)
            await process_function(self, sim)

        simulator.add_testbench(testbench)
        try:
            self.simulate(vcd_suffix=process_function.__name__)
        finally:
            # Restore the Simulator object created in setUp.
            self.sim = simulator

    return run_test
def usb_domain_test_case(process_function):
    """Decorate an async test method so that it runs with ``self.domain == "usb"``.

    Shorthand for ``async_test_case(process_function, domain="usb")``.
    """
    return async_test_case(domain="usb", process_function=process_function)
def fast_domain_test_case(process_function):
    """Decorate an async test method so that it runs with ``self.domain == "fast"``.

    Shorthand for ``async_test_case(process_function, domain="fast")``.
    """
    return async_test_case(domain="fast", process_function=process_function)
def ss_domain_test_case(process_function):
    """Decorate an async test method so that it runs with ``self.domain == "ss"``.

    Shorthand for ``async_test_case(process_function, domain="ss")``.
    """
    return async_test_case(domain="ss", process_function=process_function)
class LunaGatewareTestCase(unittest.TestCase):
    """Base class for gateware unit tests built on Amaranth async testbenches.

    Subclasses set ``FRAGMENT_UNDER_TEST`` (plus ``FRAGMENT_ARGUMENTS``) and the
    ``*_CLOCK_FREQUENCY`` attributes for each clock domain they need. A clock
    is added in :meth:`setUp` for every domain with a truthy frequency.

    The async helpers (:meth:`pulse`, :meth:`advance_cycles`,
    :meth:`wait_until`, :meth:`wait`) drive the design through ``self.sim``.
    While they run, ``self.sim`` must therefore be the testbench's simulator
    context (the object passed to the testbench function). All of them count
    ticks of the clock domain named by ``self.domain``.
    """

    # Clock domain that the timing helpers tick; test decorators may override
    # it per test by assigning ``self.domain``.
    domain = "sync"

    FRAGMENT_UNDER_TEST = None
    FRAGMENT_ARGUMENTS = {}

    FAST_CLOCK_FREQUENCY = None
    SYNC_CLOCK_FREQUENCY = 120e6
    USB_CLOCK_FREQUENCY = None
    SS_CLOCK_FREQUENCY = None

    def _domain_frequencies(self):
        """Return a ``{domain name: frequency in Hz or None}`` mapping.

        The insertion order is the order in which :meth:`setUp` adds clocks.
        """
        return {
            "usb": self.USB_CLOCK_FREQUENCY,
            "sync": self.SYNC_CLOCK_FREQUENCY,
            "fast": self.FAST_CLOCK_FREQUENCY,
            "ss": self.SS_CLOCK_FREQUENCY,
        }

    def instantiate_dut(self):
        """Build the design under test from the class-level configuration."""
        return self.FRAGMENT_UNDER_TEST(**self.FRAGMENT_ARGUMENTS)

    def get_vcd_name(self):
        """Return the base filename (without extension) for waveform dumps."""
        return f"test_{self.__class__.__name__}"

    def setUp(self):
        """Create the DUT and a Simulator with one clock per configured domain."""
        self.dut = self.instantiate_dut()
        self.sim = Simulator(self.dut)
        for domain_name, frequency in self._domain_frequencies().items():
            if frequency:
                self.sim.add_clock(1 / frequency, domain=domain_name)

    async def initialize_signals(self, sim):
        """Hook awaited at the start of each testbench; override as needed."""

    def traces_of_interest(self):
        """Return the extra traces passed to ``write_vcd`` when VCDs are generated."""
        return ()

    def simulate(self, *, vcd_suffix=None):
        """Run the simulation to completion.

        When the ``GENERATE_VCDS`` environment variable is set to a non-empty
        value, waveforms are written to ``<vcd name>[_<suffix>].vcd`` along
        with a matching ``.gtkw`` file.
        """
        if not os.getenv("GENERATE_VCDS", default=False):
            self.sim.run()
            return

        vcd_name = self.get_vcd_name()
        if vcd_suffix:
            vcd_name = f"{vcd_name}_{vcd_suffix}"
        traces = self.traces_of_interest()
        with self.sim.write_vcd(vcd_name + ".vcd", vcd_name + ".gtkw", traces=traces):
            self.sim.run()

    async def pulse(self, signal, *, step_after=True):
        """Drive ``signal`` high for one ``self.domain`` tick, then back low.

        If ``step_after`` is true, one more tick is awaited after deasserting.
        """
        sim = self.sim
        sim.set(signal, 1)
        await sim.tick(self.domain)
        sim.set(signal, 0)
        if step_after:
            await sim.tick(self.domain)

    async def advance_cycles(self, cycles):
        """Await ``cycles`` ticks of the ``self.domain`` clock."""
        for _ in range(cycles):
            await self.sim.tick(self.domain)

    async def wait_until(self, strobe, *, timeout=None):
        """Tick ``self.domain`` until ``strobe`` reads as non-zero.

        Raises:
            RuntimeError: if ``timeout`` is truthy and more than ``timeout``
                ticks pass without ``strobe`` going high.
        """
        cycles_passed = 0
        while not self.sim.get(strobe):
            await self.sim.tick(self.domain)
            cycles_passed += 1
            if timeout and cycles_passed > timeout:
                raise RuntimeError(f"Timeout waiting for '{strobe.name}' to go high!")

    def _ensure_clocks_present(self):
        """Fail the test if ``self.domain`` has no configured clock frequency."""
        self.assertIsNotNone(
            self._domain_frequencies()[self.domain],
            f"no frequency provided for `{self.domain}`-domain clock!",
        )

    async def wait(self, time):
        """Wait at least ``time`` seconds, rounded up to whole ``self.domain`` ticks.

        Raises:
            ValueError: if ``self.domain`` is not a known clock domain.
        """
        frequencies = self._domain_frequencies()
        if self.domain not in frequencies:
            raise ValueError(f"Unknown domain: {self.domain}")
        period = 1 / frequencies[self.domain]
        cycles = math.ceil(time / period)
        await self.advance_cycles(cycles)
class LunaUSBGatewareTestCase(LunaGatewareTestCase):
    """Test case whose design is clocked by a 60 MHz ``usb`` domain only.

    No ``sync`` clock is added, so tests should use ``usb_domain_test_case``
    (the default ``sync`` domain would fail the clock-presence check).
    """

    USB_CLOCK_FREQUENCY = 60e6  # add a 60 MHz clock for the `usb` domain
    SYNC_CLOCK_FREQUENCY = None  # disable the inherited `sync` clock
class LunaSSGatewareTestCase(LunaGatewareTestCase):
    """Test case whose design is clocked by a 125 MHz ``ss`` domain only.

    No ``sync`` clock is added, so tests should use ``ss_domain_test_case``
    (the default ``sync`` domain would fail the clock-presence check).
    """

    SS_CLOCK_FREQUENCY = 125e6  # add a 125 MHz clock for the `ss` domain
    SYNC_CLOCK_FREQUENCY = None  # disable the inherited `sync` clock
My spi.py
"""SPI and derived interfaces."""
from .utils import LunaGatewareTestCase
class SPIGatewareTestCase(LunaGatewareTestCase):
    """Extended version of the LunaGatewareTestCase.

    Adds three SPI-simulation methods:
        - spi_send_bit
        - spi_exchange_byte
        - spi_exchange_data

    The DUT must expose an ``spi`` attribute with ``sck``, ``sdo`` and ``cs``
    signals; ``sdi`` is driven only if present. All timing is expressed via
    ``advance_cycles``, so it follows the base class's notion of a cycle.
    """

    # Cycles spent in each of the four steps of one SPI bit (present sdi,
    # raise sck, sample sdo, lower sck); a bit therefore takes 4x this many.
    SPI_CYCLES_PER_STEP = 4

    async def spi_send_bit(self, bit):
        """Send a single bit over the virtual SPI bus.

        Returns the value sampled on ``spi.sdo`` while ``sck`` is high.
        """
        spi = self.dut.spi
        sim = self.sim
        step = self.SPI_CYCLES_PER_STEP

        # Apply the new bit (only if the interface has an input line)...
        if hasattr(spi, "sdi"):
            sim.set(spi.sdi, bit)
        await self.advance_cycles(step)

        # Rising edge of serial clock.
        sim.set(spi.sck, 1)
        await self.advance_cycles(step)

        # Sample the output bit.
        return_value = sim.get(spi.sdo)
        await self.advance_cycles(step)

        # Falling edge of serial clock.
        sim.set(spi.sck, 0)
        await self.advance_cycles(step)

        return return_value

    async def spi_exchange_byte(self, datum, *, msb_first=True):
        """Exchange one byte over the virtual SPI bus.

        Sends ``datum`` bit by bit and returns the integer assembled from the
        bits sampled on ``sdo``; both directions use the same bit order.
        """
        bits = f"{datum:08b}"
        if not msb_first:
            bits = bits[::-1]

        received_bits = []
        for bit in bits:
            received = await self.spi_send_bit(int(bit))
            received_bits.append("1" if received else "0")

        data_received = "".join(received_bits)
        if not msb_first:
            data_received = data_received[::-1]
        return int(data_received, 2)

    async def spi_exchange_data(self, data, msb_first=True):
        """Exchange a sequence of bytes over the virtual SPI bus.

        Asserts ``cs`` for the duration of the transfer (with one cycle of
        setup and one of hold) and returns the received bytes as a
        ``bytearray``.
        """
        sim = self.sim
        sim.set(self.dut.spi.cs, 1)
        # Use advance_cycles so CS timing uses the same clock as the bit timing.
        await self.advance_cycles(1)

        response = bytearray()
        for byte in data:
            response.append(await self.spi_exchange_byte(byte, msb_first=msb_first))

        sim.set(self.dut.spi.cs, 0)
        await self.advance_cycles(1)
        return response
My test_spi.py
from luna.gateware.interface.spi import SPIDeviceInterface
from spi import SPIGatewareTestCase
from utils import async_test_case
class SPIDeviceInterfaceTest(SPIGatewareTestCase):
    """Simulation tests for ``SPIDeviceInterface`` (16-bit words, ``clock_polarity=1``)."""

    FRAGMENT_UNDER_TEST = SPIDeviceInterface
    FRAGMENT_ARGUMENTS = dict(word_size=16, clock_polarity=1)

    async def initialize_signals(self, sim):
        """Start each test with CS deasserted, then advance one clock."""
        # The inherited SPI helpers drive the bus through ``self.sim``, so it
        # must refer to the running testbench context, not the Simulator.
        self.sim = sim
        sim.set(self.dut.spi.cs, 0)
        await sim.tick()

    @async_test_case
    async def test_spi_interface(self, sim):
        """Exchange one word: expect ``word_out`` (0xABCD) out and 0xCAFE in ``word_in``."""
        # Ensure that we don't complete a word while CS is deasserted.
        for _ in range(10):
            self.assertEqual(sim.get(self.dut.word_complete), 0)
            await sim.tick()
        # Set the word to send and assert CS.
        sim.set(self.dut.word_out, 0xABCD)
        await sim.tick()
        sim.set(self.dut.spi.cs, 1)
        await sim.tick()
        # Exchange bytes via SPI and verify response.
        # (spi_exchange_data also asserts CS itself before clocking data.)
        response = await self.spi_exchange_data(b"\xCA\xFE")
        self.assertEqual(response, b"\xAB\xCD")
        self.assertEqual(sim.get(self.dut.word_in), 0xCAFE)

    @async_test_case
    async def test_spi_transmit_second_word(self, sim):
        """Check that a different ``word_out`` value (0x0F00) is shifted out."""
        # NOTE(review): despite the name, this test performs a single exchange;
        # no earlier word is sent before this one.
        sim.set(self.dut.word_out, 0x0F00)
        await sim.tick()
        sim.set(self.dut.spi.cs, 1)
        await sim.tick()
        response = await self.spi_exchange_data(b"\x00\x00")
        self.assertEqual(response, b"\x0F\x00")
if __name__ == "__main__":
    # Allow running this file directly by handing control to unittest's CLI runner.
    from unittest import main as run_unittest_main

    run_unittest_main()
Metadata
Metadata
Assignees
Labels
No labels