graphs/tests/edges/test_multiscale_edges.py (2 changes: 1 addition, 1 deletion)
@@ -147,7 +147,7 @@ def test_edges(self, tri_graph: HeteroData):
     def test_fast_1_hop_method(self, tri_graph: HeteroData, edge_resolutions):
         nodes = tri_graph["hidden"]
         all_points_mask_builder = KNNAreaMaskBuilder("all_nodes", 1.0)
-        all_points_mask_builder.fit_coords(nodes.x.numpy())
+        all_points_mask_builder.fit_coords(nodes.x.detach().cpu().numpy())

         fast_edges = tri_icosahedron.add_1_hop_edges(
             nodes_coords_rad=nodes["x"],
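Why this change is needed: `Tensor.numpy()` raises an error for CUDA tensors and for tensors that require grad, so the call breaks as soon as the node coordinates are created on a GPU or inside an autograd graph. `detach().cpu().numpy()` covers both cases. A minimal sketch of the failure mode (the tensor below is illustrative, not taken from the test suite):

    import torch

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    x = torch.rand(10, 2, device=device, requires_grad=True)

    # x.numpy() would raise: a RuntimeError for tensors that require grad,
    # a TypeError for CUDA tensors. Detach from autograd, copy to host, convert.
    coords = x.detach().cpu().numpy()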
models/tests/layers/mapper/test_graphconv_mapper.py (53 changes: 32 additions, 21 deletions)
@@ -24,6 +24,11 @@
 from anemoi.utils.config import DotDict


+@pytest.fixture(scope="module")
+def device():
+    return torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+
 class ConcreteGNNBaseMapper(GNNBaseMapper):
     """Concrete implementation of GNNBaseMapper for testing."""

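The module-scoped `device` fixture added above is evaluated once per test module and falls back to CPU when CUDA is absent, so the same suite runs unchanged on CPU-only CI and on GPU runners. If both backends should be exercised on a machine that has a GPU, a parametrized variant is one option (a sketch, not part of this PR):

    import pytest
    import torch

    @pytest.fixture(params=["cpu"] + (["cuda"] if torch.cuda.is_available() else []))
    def device(request):
        # Every test that requests this fixture runs once per available backend.
        return torch.device(request.param)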
@@ -64,41 +69,42 @@ def mapper_init(self):
         return MapperConfig()

     @pytest.fixture
-    def graph_provider(self, fake_graph):
-        return create_graph_provider(
+    def graph_provider(self, fake_graph, device):
+        provider = create_graph_provider(
             graph=fake_graph[("nodes", "to", "nodes")],
             edge_attributes=["edge_attr1", "edge_attr2"],
             src_size=self.NUM_SRC_NODES,
             dst_size=self.NUM_DST_NODES,
             trainable_size=6,
         )
+        return provider.to(device)

     @pytest.fixture
-    def mapper(self, mapper_init, graph_provider):
+    def mapper(self, mapper_init, graph_provider, device):
         config = asdict(mapper_init)
         config["edge_dim"] = graph_provider.edge_dim
-        return ConcreteGNNBaseMapper(**config)
+        return ConcreteGNNBaseMapper(**config).to(device)

     @pytest.fixture
-    def pair_tensor(self, mapper_init):
+    def pair_tensor(self, mapper_init, device):
         return (
-            torch.rand(self.NUM_SRC_NODES, mapper_init.in_channels_src),
-            torch.rand(self.NUM_DST_NODES, mapper_init.in_channels_dst),
+            torch.rand(self.NUM_SRC_NODES, mapper_init.in_channels_src, device=device),
+            torch.rand(self.NUM_DST_NODES, mapper_init.in_channels_dst, device=device),
         )

     @pytest.fixture
-    def fake_graph(self) -> HeteroData:
+    def fake_graph(self, device) -> HeteroData:
         """Fake graph."""
         graph = HeteroData()
         graph[("nodes", "to", "nodes")].edge_index = torch.concat(
             [
-                torch.randint(0, self.NUM_SRC_NODES, (1, self.NUM_EDGES)),
-                torch.randint(0, self.NUM_DST_NODES, (1, self.NUM_EDGES)),
+                torch.randint(0, self.NUM_SRC_NODES, (1, self.NUM_EDGES), device=device),
+                torch.randint(0, self.NUM_DST_NODES, (1, self.NUM_EDGES), device=device),
             ],
             axis=0,
         )
-        graph[("nodes", "to", "nodes")].edge_attr1 = torch.rand((self.NUM_EDGES, 1))
-        graph[("nodes", "to", "nodes")].edge_attr2 = torch.rand((self.NUM_EDGES, 32))
+        graph[("nodes", "to", "nodes")].edge_attr1 = torch.rand((self.NUM_EDGES, 1), device=device)
+        graph[("nodes", "to", "nodes")].edge_attr2 = torch.rand((self.NUM_EDGES, 32), device=device)
         return graph

     def test_initialization(self, mapper, mapper_init):
@@ -141,11 +147,11 @@ class TestGNNForwardMapper(TestGNNBaseMapper):
     """Test the GNNForwardMapper class."""

     @pytest.fixture
-    def mapper(self, mapper_init, graph_provider):
+    def mapper(self, mapper_init, graph_provider, device):
         config = asdict(mapper_init)
         config["edge_dim"] = graph_provider.edge_dim
         del config["out_channels_dst"]  # Not needed for forward mapper
-        return GNNForwardMapper(**config)
+        return GNNForwardMapper(**config).to(device)

     def test_initialization(self, mapper, mapper_init):
         assert isinstance(mapper, GNNBaseMapper)
@@ -183,7 +189,7 @@ def test_forward_backward(self, mapper_init, mapper, pair_tensor, graph_provider
         assert x_dst.shape == torch.Size([self.NUM_DST_NODES, mapper_init.hidden_dim])

         # Dummy loss
-        target = torch.rand(self.NUM_DST_NODES, mapper_init.hidden_dim)
+        target = torch.rand(self.NUM_DST_NODES, mapper_init.hidden_dim, device=x_dst.device)
         loss_fn = nn.MSELoss()

         loss = loss_fn(x_dst, target)
@@ -208,10 +214,10 @@ class TestGNNBackwardMapper(TestGNNBaseMapper):
     """Test the GNNBackwardMapper class."""

     @pytest.fixture
-    def mapper(self, mapper_init, graph_provider):
+    def mapper(self, mapper_init, graph_provider, device):
         config = asdict(mapper_init)
         config["edge_dim"] = graph_provider.edge_dim
-        return GNNBackwardMapper(**config)
+        return GNNBackwardMapper(**config).to(device)

     def test_pre_process(self, mapper, mapper_init, pair_tensor):
         x = pair_tensor
@@ -230,7 +236,11 @@ def test_pre_process(self, mapper, mapper_init, pair_tensor):
         assert shapes_dst == [[self.NUM_DST_NODES, mapper_init.hidden_dim]]

     def test_post_process(self, mapper, mapper_init):
-        x_dst = torch.rand(self.NUM_DST_NODES, mapper_init.hidden_dim)
+        x_dst = torch.rand(
+            self.NUM_DST_NODES,
+            mapper_init.hidden_dim,
+            device=next(mapper.parameters()).device,
+        )
         shapes_dst = [list(x_dst.shape)]

         result = mapper.post_process(x_dst, shapes_dst=shapes_dst)
@@ -243,17 +253,18 @@ def test_forward_backward(self, mapper_init, mapper, pair_tensor, graph_provider
         shard_shapes = [list(pair_tensor[0].shape)], [list(pair_tensor[1].shape)]
         batch_size = 1

+        device = next(mapper.parameters()).device
         x = (
-            torch.rand(self.NUM_SRC_NODES, mapper_init.hidden_dim),
-            torch.rand(self.NUM_DST_NODES, mapper_init.hidden_dim),
+            torch.rand(self.NUM_SRC_NODES, mapper_init.hidden_dim, device=device),
+            torch.rand(self.NUM_DST_NODES, mapper_init.hidden_dim, device=device),
         )

         edge_attr, edge_index, _ = graph_provider.get_edges(batch_size=batch_size)
         result = mapper.forward(x, batch_size, shard_shapes, edge_attr, edge_index)
         assert result.shape == torch.Size([self.NUM_DST_NODES, mapper_init.out_channels_dst])

         # Dummy loss
-        target = torch.rand(self.NUM_DST_NODES, mapper_init.out_channels_dst)
+        target = torch.rand(self.NUM_DST_NODES, mapper_init.out_channels_dst, device=result.device)
         loss_fn = nn.MSELoss()

         loss = loss_fn(result, target)
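Several of the tests above recover the target device from the module itself, via `next(mapper.parameters()).device`, instead of requesting the `device` fixture a second time. Inputs built this way always land on whatever device the mapper was actually moved to, even if a test relocates the module after construction. A minimal sketch of the idiom (the `nn.Linear` stand-in is illustrative):

    import torch
    from torch import nn

    model = nn.Linear(16, 4).to("cuda" if torch.cuda.is_available() else "cpu")

    # Read the device off the parameters rather than tracking it separately.
    device = next(model.parameters()).device
    x = torch.rand(8, 16, device=device)
    out = model(x)  # no cross-device mismatch possible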
models/tests/layers/mapper/test_graphtransformer_mapper.py (63 changes: 38 additions, 25 deletions)
@@ -24,6 +24,11 @@
 from anemoi.utils.config import DotDict


+@pytest.fixture(scope="module")
+def device():
+    return torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+
 class ConcreteGraphTransformerBaseMapper(GraphTransformerBaseMapper):
     """Concrete implementation of GraphTransformerBaseMapper for testing."""

@@ -69,44 +74,45 @@ def mapper_init(self):
         return MapperConfig()

     @pytest.fixture
-    def graph_provider(self, fake_graph):
-        return create_graph_provider(
+    def graph_provider(self, fake_graph, device):
+        provider = create_graph_provider(
             graph=fake_graph[("nodes", "to", "nodes")],
             edge_attributes=["edge_attr1", "edge_attr2"],
             src_size=self.NUM_SRC_NODES,
             dst_size=self.NUM_DST_NODES,
             trainable_size=6,
         )
+        return provider.to(device)

     @pytest.fixture
-    def mapper(self, mapper_init, graph_provider):
+    def mapper(self, mapper_init, graph_provider, device):
         config = asdict(mapper_init)
         config["edge_dim"] = graph_provider.edge_dim
         return ConcreteGraphTransformerBaseMapper(
             **config,
             out_channels_dst=self.OUT_CHANNELS_DST,
-        )
+        ).to(device)

     @pytest.fixture
-    def pair_tensor(self, mapper_init):
+    def pair_tensor(self, mapper_init, device):
         return (
-            torch.rand(self.NUM_SRC_NODES, mapper_init.in_channels_src),
-            torch.rand(self.NUM_DST_NODES, mapper_init.in_channels_dst),
+            torch.rand(self.NUM_SRC_NODES, mapper_init.in_channels_src, device=device),
+            torch.rand(self.NUM_DST_NODES, mapper_init.in_channels_dst, device=device),
         )

     @pytest.fixture
-    def fake_graph(self) -> HeteroData:
+    def fake_graph(self, device) -> HeteroData:
         """Fake graph."""
         graph = HeteroData()
         graph[("nodes", "to", "nodes")].edge_index = torch.concat(
             [
-                torch.randint(0, self.NUM_SRC_NODES, (1, self.NUM_EDGES)),
-                torch.randint(0, self.NUM_DST_NODES, (1, self.NUM_EDGES)),
+                torch.randint(0, self.NUM_SRC_NODES, (1, self.NUM_EDGES), device=device),
+                torch.randint(0, self.NUM_DST_NODES, (1, self.NUM_EDGES), device=device),
             ],
             axis=0,
         )
-        graph[("nodes", "to", "nodes")].edge_attr1 = torch.rand((self.NUM_EDGES, 1))
-        graph[("nodes", "to", "nodes")].edge_attr2 = torch.rand((self.NUM_EDGES, 32))
+        graph[("nodes", "to", "nodes")].edge_attr1 = torch.rand((self.NUM_EDGES, 1), device=device)
+        graph[("nodes", "to", "nodes")].edge_attr2 = torch.rand((self.NUM_EDGES, 32), device=device)
         return graph

     def test_initialization(self, mapper, mapper_init):
@@ -151,10 +157,10 @@ class TestGraphTransformerForwardMapper(TestGraphTransformerBaseMapper):
     OUT_CHANNELS_DST = None

     @pytest.fixture
-    def mapper(self, mapper_init, graph_provider):
+    def mapper(self, mapper_init, graph_provider, device):
         config = asdict(mapper_init)
         config["edge_dim"] = graph_provider.edge_dim
-        return GraphTransformerForwardMapper(**config)
+        return GraphTransformerForwardMapper(**config).to(device)

     def test_pre_process(self, mapper, mapper_init, pair_tensor):
         x = pair_tensor
@@ -183,7 +189,7 @@ def test_forward_backward(self, mapper_init, mapper, pair_tensor, graph_provider
         assert x_dst.shape == torch.Size([self.NUM_DST_NODES, mapper_init.hidden_dim])

         # Dummy loss
-        target = torch.rand(self.NUM_DST_NODES, mapper_init.hidden_dim)
+        target = torch.rand(self.NUM_DST_NODES, mapper_init.hidden_dim, device=x_dst.device)
         loss_fn = nn.MSELoss()

         loss = loss_fn(x_dst, target)
@@ -243,13 +249,13 @@ class TestGraphTransformerBackwardMapper(TestGraphTransformerBaseMapper):
     """Test the GraphTransformerBackwardMapper class."""

     @pytest.fixture
-    def mapper(self, mapper_init, graph_provider):
+    def mapper(self, mapper_init, graph_provider, device):
         config = asdict(mapper_init)
         config["edge_dim"] = graph_provider.edge_dim
         return GraphTransformerBackwardMapper(
             **config,
             out_channels_dst=self.OUT_CHANNELS_DST,
-        )
+        ).to(device)

     def test_pre_process(self, mapper, mapper_init, pair_tensor):
         x = pair_tensor
@@ -268,7 +274,11 @@ def test_pre_process(self, mapper, mapper_init, pair_tensor):
         assert shapes_dst == [[self.NUM_DST_NODES, mapper_init.hidden_dim]]

     def test_post_process(self, mapper, mapper_init):
-        x_dst = torch.rand(self.NUM_DST_NODES, mapper_init.hidden_dim)
+        x_dst = torch.rand(
+            self.NUM_DST_NODES,
+            mapper_init.hidden_dim,
+            device=next(mapper.parameters()).device,
+        )
         shapes_dst = [list(x_dst.shape)]

         result = mapper.post_process(x_dst, shapes_dst=shapes_dst)
@@ -281,17 +291,18 @@ def test_forward_backward(self, mapper_init, mapper, pair_tensor, graph_provider
         batch_size = 1

         # Different size for x_dst, as the backward mapper changes the channel dimension in its pre-processor
+        device = next(mapper.parameters()).device
         x = (
-            torch.rand(self.NUM_SRC_NODES, mapper_init.hidden_dim),
-            torch.rand(self.NUM_DST_NODES, mapper_init.in_channels_src),
+            torch.rand(self.NUM_SRC_NODES, mapper_init.hidden_dim, device=device),
+            torch.rand(self.NUM_DST_NODES, mapper_init.in_channels_src, device=device),
         )

         edge_attr, edge_index, _ = graph_provider.get_edges(batch_size=batch_size)
         result = mapper.forward(x, batch_size, shard_shapes, edge_attr, edge_index)
         assert result.shape == torch.Size([self.NUM_DST_NODES, self.OUT_CHANNELS_DST])

         # Dummy loss
-        target = torch.rand(self.NUM_DST_NODES, self.OUT_CHANNELS_DST)
+        target = torch.rand(self.NUM_DST_NODES, self.OUT_CHANNELS_DST, device=result.device)
         loss_fn = nn.MSELoss()

         loss = loss_fn(result, target)
@@ -315,9 +326,10 @@ def test_chunking(self, mapper_init, mapper, pair_tensor, graph_provider):
         shard_shapes = [list(pair_tensor[0].shape)], [list(pair_tensor[1].shape)]
         batch_size = 1

+        device = next(mapper.parameters()).device
         x = (
-            torch.rand(self.NUM_SRC_NODES, mapper_init.hidden_dim),
-            torch.rand(self.NUM_DST_NODES, mapper_init.in_channels_src),
+            torch.rand(self.NUM_SRC_NODES, mapper_init.hidden_dim, device=device),
+            torch.rand(self.NUM_DST_NODES, mapper_init.in_channels_src, device=device),
         )

         edge_attr, edge_index, _ = graph_provider.get_edges(batch_size=batch_size)
@@ -334,9 +346,10 @@ def test_strategy(self, mapper_init, mapper, pair_tensor, graph_provider):
         shard_shapes = [list(pair_tensor[0].shape)], [list(pair_tensor[1].shape)]
         batch_size = 1

+        device = next(mapper.parameters()).device
         x = (
-            torch.rand(self.NUM_SRC_NODES, mapper_init.hidden_dim),
-            torch.rand(self.NUM_DST_NODES, mapper_init.in_channels_src),
+            torch.rand(self.NUM_SRC_NODES, mapper_init.hidden_dim, device=device),
+            torch.rand(self.NUM_DST_NODES, mapper_init.in_channels_src, device=device),
         )

         edge_attr, edge_index, _ = graph_provider.get_edges(batch_size=batch_size)
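One subtlety behind the fixture changes in this file: `nn.Module.to(device)` moves parameters in place and returns the same module, while `Tensor.to(device)` is out of place and may return a new object. Returning the result of `.to(device)`, as the fixtures do, is correct in both cases. A short sketch of the asymmetry:

    import torch
    from torch import nn

    dev = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    module = nn.Linear(4, 2)
    assert module.to(dev) is module  # modules move in place

    t = torch.rand(3)
    moved = t.to(dev)  # tensors are copied when the device changes
    assert moved.device.type == dev.type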
models/tests/layers/processor/test_graphconv_processor.py (32 changes: 21 additions, 11 deletions)
@@ -23,6 +23,11 @@
 from anemoi.utils.config import DotDict


+@pytest.fixture(scope="module")
+def device():
+    return torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+
 @dataclass
 class GNNProcessorInit:
     num_channels: int = 128
@@ -44,33 +49,36 @@ class TestGNNProcessor:
     NUM_EDGES: int = 200

     @pytest.fixture
-    def fake_graph(self) -> tuple[HeteroData, int]:
+    def fake_graph(self, device) -> tuple[HeteroData, int]:
         graph = HeteroData()
-        graph["nodes"].x = torch.rand((self.NUM_NODES, 2))
-        graph[("nodes", "to", "nodes")].edge_index = torch.randint(0, self.NUM_NODES, (2, self.NUM_EDGES))
-        graph[("nodes", "to", "nodes")].edge_attr1 = torch.rand((self.NUM_EDGES, 3))
-        graph[("nodes", "to", "nodes")].edge_attr2 = torch.rand((self.NUM_EDGES, 4))
+        graph["nodes"].x = torch.rand((self.NUM_NODES, 2), device=device)
+        graph[("nodes", "to", "nodes")].edge_index = torch.randint(
+            0, self.NUM_NODES, (2, self.NUM_EDGES), device=device
+        )
+        graph[("nodes", "to", "nodes")].edge_attr1 = torch.rand((self.NUM_EDGES, 3), device=device)
+        graph[("nodes", "to", "nodes")].edge_attr2 = torch.rand((self.NUM_EDGES, 4), device=device)
         return graph

     @pytest.fixture
     def graphconv_init(self):
         return GNNProcessorInit()

     @pytest.fixture
-    def graph_provider(self, fake_graph):
-        return create_graph_provider(
+    def graph_provider(self, fake_graph, device):
+        provider = create_graph_provider(
             graph=fake_graph[("nodes", "to", "nodes")],
             edge_attributes=["edge_attr1", "edge_attr2"],
             src_size=self.NUM_NODES,
             dst_size=self.NUM_NODES,
             trainable_size=8,
         )
+        return provider.to(device)

     @pytest.fixture
-    def graphconv_processor(self, graphconv_init, graph_provider):
+    def graphconv_processor(self, graphconv_init, graph_provider, device):
         config = asdict(graphconv_init)
         config["edge_dim"] = graph_provider.edge_dim
-        return GNNProcessor(**config)
+        return GNNProcessor(**config).to(device)

     def test_graphconv_processor_init(self, graphconv_processor, graphconv_init, graph_provider):
         assert graphconv_processor.num_chunks == graphconv_init.num_chunks
@@ -83,7 +91,9 @@ def test_all_blocks(self, graphconv_processor):

     def test_forward(self, graphconv_processor, graphconv_init, graph_provider):
         batch_size = 1
-        x = torch.rand((self.NUM_NODES, graphconv_init.num_channels))
+        x = torch.rand(
+            (self.NUM_NODES, graphconv_init.num_channels), device=next(graphconv_processor.parameters()).device
+        )
         shard_shapes = [list(x.shape)]

         # Run forward pass of processor
@@ -93,7 +103,7 @@ def test_forward(self, graphconv_processor, graphconv_init, graph_provider):

         # Generate dummy target and loss function
         loss_fn = torch.nn.MSELoss()
-        target = torch.rand((self.NUM_NODES, graphconv_init.num_channels))
+        target = torch.rand((self.NUM_NODES, graphconv_init.num_channels), device=output.device)
         loss = loss_fn(output, target)

         # Check loss
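An alternative to threading `device=` through every constructor in `fake_graph` would be to build the `HeteroData` object on CPU and move it wholesale, since torch_geometric data objects implement `.to(device)` over all stored tensors. A sketch under that assumption (node and edge counts are illustrative):

    import torch
    from torch_geometric.data import HeteroData

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    graph = HeteroData()
    graph["nodes"].x = torch.rand(100, 2)
    graph[("nodes", "to", "nodes")].edge_index = torch.randint(0, 100, (2, 200))

    # HeteroData.to moves every node- and edge-store tensor to the target device.
    graph = graph.to(device)
    assert graph["nodes"].x.device.type == device.type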