Skip to content

Commit ade40a7

Browse files
add test for observables
1 parent bbf298f commit ade40a7

File tree

1 file changed

+68
-0
lines changed

1 file changed

+68
-0
lines changed
Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
import os
2+
from dataclasses import dataclass
3+
4+
import torch
5+
from rydberggpt.models.rydberg_encoder_decoder import get_rydberg_graph_encoder_decoder
6+
from rydberggpt.models.utils import generate_prompt
7+
from rydberggpt.observables.rydberg_energy import (
8+
get_rydberg_energy,
9+
get_staggered_magnetization,
10+
get_x_magnetization,
11+
)
12+
from rydberggpt.utils import create_config_from_yaml, load_yaml_file
13+
from rydberggpt.utils_ckpt import get_model_from_ckpt
14+
from torch_geometric.data import Batch
15+
16+
device = "cuda" if torch.cuda.is_available() else "cpu"
17+
18+
base_path = os.path.abspath(".")
19+
log_path = os.path.join(base_path, "models/ds_1/")
20+
21+
yaml_dict = load_yaml_file(log_path, "hparams.yaml")
22+
config: dataclass = create_config_from_yaml(yaml_dict)
23+
24+
model = get_model_from_ckpt(
25+
log_path, model=get_rydberg_graph_encoder_decoder(config), ckpt="best"
26+
)
27+
model.to(device=device)
28+
model.eval() # don't forget to set to eval mode
29+
30+
31+
def test_observables():
32+
L = 5
33+
delta = 1.0
34+
omega = 1.0
35+
beta = 64.0
36+
Rb = 1.15
37+
num_samples = 5
38+
39+
pyg_graph = generate_prompt(
40+
model_config=config,
41+
n_rows=L,
42+
n_cols=L,
43+
delta=delta,
44+
omega=omega,
45+
beta=beta,
46+
Rb=Rb,
47+
)
48+
49+
# duplicate the prompt for num_samples
50+
cond = [pyg_graph for _ in range(num_samples)]
51+
cond = Batch.from_data_list(cond)
52+
53+
samples = model.get_samples(
54+
batch_size=len(cond), cond=cond, num_atoms=L**2, fmt_onehot=False
55+
)
56+
57+
energy = get_rydberg_energy(model, samples, cond=pyg_graph, device=device)
58+
print(energy.mean() / L**2)
59+
60+
staggered_magnetization = get_staggered_magnetization(samples, L, L, device=device)
61+
print(staggered_magnetization.mean() / L**2)
62+
63+
x_magnetization = get_x_magnetization(model, samples, cond=pyg_graph, device=device)
64+
print(x_magnetization.mean() / L**2)
65+
assert True
66+
67+
68+
test_observables()

0 commit comments

Comments
 (0)