
Commit 12421c2

puririshi98, pre-commit-ci[bot], akihironitta, and rusty1s authored
Add G-retriever (GNN+LLM) example (#9167)
1. #9462  2. #9480  3. #9481  4. **->** #9167

---

Repro: latest NVIDIA PyG container + `git config --global credential.helper store; huggingface-cli login; cd /opt/pyg; pip uninstall -y torch-geometric; rm -rf pytorch_geometric; git clone -b gnn-llm-model-integration https://github.com/pyg-team/pytorch_geometric.git; cd /opt/pyg/pytorch_geometric; pip install .; pip install peft datasets transformers pcst_fast sentencepiece; python3 examples/llm_plus_gnn/g_retriever.py`

Old PR: #9154

Note: pure CPU is 220x slower than pure GPU on a single Grace Hopper (for llama-7b).

Info: Gemma was tried and performs worse on all train/val/test metrics. It most likely needs some tuning, which is left as future work for the community sprint on trying and tuning many LLM and GNN combos; the default therefore stays Llama 2. The new Gemma v2 is also much worse than Llama 2.

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Akihiro Nitta <[email protected]>
Co-authored-by: rusty1s <[email protected]>
1 parent bfc6d1a commit 12421c2
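
For context, the example added by this commit lands at `examples/llm/g_retriever.py` (the repro command above targets the pre-merge `examples/llm_plus_gnn/` layout of the feature branch). A minimal invocation sketch, assuming a PyG checkout containing this commit, a CUDA GPU, and Hugging Face access to the gated Llama 2 weights; flag names come from the example's own argparse setup:

```bash
# Install the example's extra dependencies (see the docstring of g_retriever.py):
pip install datasets transformers pcst_fast sentencepiece accelerate

# Default run: Llama-2-7b-chat backbone, 2 epochs, batch size 8:
python3 examples/llm/g_retriever.py

# Lower-memory variant: TinyLlama backbone plus best-val-loss checkpointing:
python3 examples/llm/g_retriever.py --tiny_llama --checkpointing
```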

File tree: 5 files changed, +296 −8 lines changed

Diff for: CHANGELOG.md (+1 −1)

@@ -8,7 +8,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 ### Added

 - Added the `WebQSPDataset` dataset ([#9481](https://github.com/pyg-team/pytorch_geometric/pull/9481))
-- Added the `GRetriever` model ([#9480](https://github.com/pyg-team/pytorch_geometric/pull/9480))
+- Added the `GRetriever` model and an example ([#9480](https://github.com/pyg-team/pytorch_geometric/pull/9480), [#9167](https://github.com/pyg-team/pytorch_geometric/pull/9167))
 - Added the `ClusterPooling` layer ([#9627](https://github.com/pyg-team/pytorch_geometric/pull/9627))
 - Added the `LinkPredMRR` metric ([#9632](https://github.com/pyg-team/pytorch_geometric/pull/9632))
 - Added PyTorch 2.4 support ([#9594](https://github.com/pyg-team/pytorch_geometric/pull/9594))

Diff for: examples/llm/README.md (+3 −3)

@@ -1,5 +1,5 @@
 # Examples for Co-training LLMs and GNNs

-| Example | Description |
-| ------- | ----------- |
-| | |
+| Example                              | Description                                                                                                                                                   |
+| ------------------------------------ | ------------------------------------------------------------------------------------------------------------------------------------------------------------- |
+| [`g_retriever.py`](./g_retriever.py) | Example for Retrieval-Augmented Generation (RAG) w/ GNN+LLM by co-training `LLAMA2` with `GAT` for answering questions based on knowledge graph information |

Diff for: examples/llm/g_retriever.py (+272, new file)

@@ -0,0 +1,272 @@
"""This example implements the G-Retriever model
(https://arxiv.org/abs/2402.07630) using PyG.

G-Retriever significantly reduces hallucinations by 54% compared to the
stand-alone LLM baseline.

Requirements:
`pip install datasets transformers pcst_fast sentencepiece accelerate`
"""
import argparse
import math
import os.path as osp
import re
import time

import pandas as pd
import torch
from torch import Tensor
from torch.nn.utils import clip_grad_norm_
from tqdm import tqdm

from torch_geometric import seed_everything
from torch_geometric.datasets import WebQSPDataset
from torch_geometric.loader import DataLoader
from torch_geometric.nn.models import GAT, GRetriever
from torch_geometric.nn.nlp import LLM


def compute_metrics(eval_output):
    df = pd.concat([pd.DataFrame(d) for d in eval_output])
    all_hit = []
    all_precision = []
    all_recall = []
    all_f1 = []

    for pred, label in zip(df.pred.tolist(), df.label.tolist()):
        try:
            pred = pred.split('[/s]')[0].strip().split('|')
            hit = re.findall(pred[0], label)
            all_hit.append(len(hit) > 0)

            label = label.split('|')
            matches = set(pred).intersection(set(label))
            precision = len(matches) / len(set(label))
            recall = len(matches) / len(set(pred))
            if recall + precision == 0:
                f1 = 0
            else:
                f1 = 2 * precision * recall / (precision + recall)

            all_precision.append(precision)
            all_recall.append(recall)
            all_f1.append(f1)

        except Exception as e:
            print(f'Label: {label}')
            print(f'Pred: {pred}')
            print(f'Exception: {e}')
            print('------------------')

    hit = sum(all_hit) / len(all_hit)
    precision = sum(all_precision) / len(all_precision)
    recall = sum(all_recall) / len(all_recall)
    f1 = sum(all_f1) / len(all_f1)

    print(f'Hit: {hit:.4f}')
    print(f'Precision: {precision:.4f}')
    print(f'Recall: {recall:.4f}')
    print(f'F1: {f1:.4f}')


def save_params_dict(model, save_path):
    state_dict = model.state_dict()
    param_grad_dict = {
        k: v.requires_grad
        for (k, v) in model.named_parameters()
    }
    for k in list(state_dict.keys()):
        if k in param_grad_dict.keys() and not param_grad_dict[k]:
            del state_dict[k]  # Delete parameters that do not require gradient
    torch.save(state_dict, save_path)


def load_params_dict(model, save_path):
    state_dict = torch.load(save_path)
    model.load_state_dict(state_dict)
    return model


def get_loss(model, batch, model_save_name) -> Tensor:
    if model_save_name == 'llm':
        return model(batch.question, batch.label, batch.desc)
    else:
        return model(batch.question, batch.x, batch.edge_index, batch.batch,
                     batch.label, batch.edge_attr, batch.desc)


def inference_step(model, batch, model_save_name):
    if model_save_name == 'llm':
        return model.inference(batch.question, batch.desc)
    else:
        return model.inference(batch.question, batch.x, batch.edge_index,
                               batch.batch, batch.edge_attr, batch.desc)


def train(
    num_epochs,
    hidden_channels,
    num_gnn_layers,
    batch_size,
    eval_batch_size,
    lr,
    checkpointing=False,
    tiny_llama=False,
):
    def adjust_learning_rate(param_group, LR, epoch):
        # Decay the learning rate with half-cycle cosine after warmup
        min_lr = 5e-6
        warmup_epochs = 1
        if epoch < warmup_epochs:
            lr = LR
        else:
            lr = min_lr + (LR - min_lr) * 0.5 * (
                1.0 + math.cos(math.pi * (epoch - warmup_epochs) /
                               (num_epochs - warmup_epochs)))
        param_group['lr'] = lr
        return lr

    start_time = time.time()
    path = osp.dirname(osp.realpath(__file__))
    path = osp.join(path, '..', '..', 'data', 'WebQSPDataset')
    train_dataset = WebQSPDataset(path, split='train')
    val_dataset = WebQSPDataset(path, split='val')
    test_dataset = WebQSPDataset(path, split='test')

    seed_everything(42)

    train_loader = DataLoader(train_dataset, batch_size=batch_size,
                              drop_last=True, pin_memory=True, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=eval_batch_size,
                            drop_last=False, pin_memory=True, shuffle=False)
    test_loader = DataLoader(test_dataset, batch_size=eval_batch_size,
                             drop_last=False, pin_memory=True, shuffle=False)

    gnn = GAT(
        in_channels=1024,
        hidden_channels=hidden_channels,
        out_channels=1024,
        num_layers=num_gnn_layers,
        heads=4,
    )
    if tiny_llama:
        llm = LLM(
            model_name='TinyLlama/TinyLlama-1.1B-Chat-v0.1',
            num_params=1,
        )
        model = GRetriever(llm=llm, gnn=gnn, mlp_out_channels=2048)
    else:
        llm = LLM(model_name='meta-llama/Llama-2-7b-chat-hf', num_params=7)
        model = GRetriever(llm=llm, gnn=gnn)

    model_save_name = 'gnn_llm' if num_gnn_layers is not None else 'llm'
    params = [p for _, p in model.named_parameters() if p.requires_grad]
    optimizer = torch.optim.AdamW([
        {
            'params': params,
            'lr': lr,
            'weight_decay': 0.05
        },
    ], betas=(0.9, 0.95))
    grad_steps = 2

    best_epoch = 0
    best_val_loss = float('inf')
    for epoch in range(num_epochs):
        model.train()
        epoch_loss = 0
        if epoch == 0:
            print(f"Total Preparation Time: {time.time() - start_time:2f}s")
            start_time = time.time()
            print("Training beginning...")
        epoch_str = f'Epoch: {epoch + 1}|{num_epochs}'
        loader = tqdm(train_loader, desc=epoch_str)
        for step, batch in enumerate(loader):
            optimizer.zero_grad()
            loss = get_loss(model, batch, model_save_name)
            loss.backward()

            clip_grad_norm_(optimizer.param_groups[0]['params'], 0.1)

            if (step + 1) % grad_steps == 0:
                adjust_learning_rate(optimizer.param_groups[0], lr,
                                     step / len(train_loader) + epoch)

            optimizer.step()
            epoch_loss = epoch_loss + float(loss)

            if (step + 1) % grad_steps == 0:
                lr = optimizer.param_groups[0]['lr']
        train_loss = epoch_loss / len(train_loader)
        print(epoch_str + f', Train Loss: {train_loss:4f}')

        val_loss = 0
        eval_output = []
        model.eval()
        with torch.no_grad():
            for step, batch in enumerate(val_loader):
                loss = get_loss(model, batch, model_save_name)
                val_loss += loss.item()
            val_loss = val_loss / len(val_loader)
            print(epoch_str + f", Val Loss: {val_loss:4f}")
        if checkpointing and val_loss < best_val_loss:
            print("Checkpointing best model...")
            best_val_loss = val_loss
            best_epoch = epoch
            save_params_dict(model, f'{model_save_name}_best_val_loss_ckpt.pt')
    torch.cuda.empty_cache()
    torch.cuda.reset_max_memory_allocated()

    if checkpointing and best_epoch != num_epochs - 1:
        print("Loading best checkpoint...")
        model = load_params_dict(
            model,
            f'{model_save_name}_best_val_loss_ckpt.pt',
        )

    model.eval()
    eval_output = []
    print("Final evaluation...")
    progress_bar_test = tqdm(range(len(test_loader)))
    for step, batch in enumerate(test_loader):
        with torch.no_grad():
            pred = inference_step(model, batch, model_save_name)
            eval_data = {
                'pred': pred,
                'question': batch.question,
                'desc': batch.desc,
                'label': batch.label
            }
            eval_output.append(eval_data)
        progress_bar_test.update(1)

    compute_metrics(eval_output)
    print(f"Total Training Time: {time.time() - start_time:2f}s")
    save_params_dict(model, f'{model_save_name}.pt')
    torch.save(eval_output, f'{model_save_name}_eval_outs.pt')


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--gnn_hidden_channels', type=int, default=1024)
    parser.add_argument('--num_gnn_layers', type=int, default=4)
    parser.add_argument('--lr', type=float, default=1e-5)
    parser.add_argument('--epochs', type=int, default=2)
    parser.add_argument('--batch_size', type=int, default=8)
    parser.add_argument('--eval_batch_size', type=int, default=16)
    parser.add_argument('--checkpointing', action='store_true')
    parser.add_argument('--tiny_llama', action='store_true')
    args = parser.parse_args()

    start_time = time.time()
    train(
        args.epochs,
        args.gnn_hidden_channels,
        args.num_gnn_layers,
        args.batch_size,
        args.eval_batch_size,
        args.lr,
        checkpointing=args.checkpointing,
        tiny_llama=args.tiny_llama,
    )
    print(f"Total Time: {time.time() - start_time:2f}s")

Diff for: torch_geometric/nn/models/g_retriever.py (+18 −4)

@@ -3,7 +3,6 @@
 import torch
 from torch import Tensor

-from torch_geometric.nn.models import GAT
 from torch_geometric.nn.nlp.llm import BOS, LLM, MAX_NEW_TOKENS
 from torch_geometric.utils import scatter

@@ -43,7 +42,6 @@ def __init__(
         llm: LLM,
         gnn: torch.nn.Module,
         use_lora: bool = False,
-        gnn_to_use=GAT,
         mlp_out_channels: int = 4096,
     ) -> None:
         super().__init__()
@@ -126,7 +124,15 @@ def forward(
         """
         x = self.encode(x, edge_index, batch, edge_attr)
         x = self.projector(x)
-        xs = x.split(x.size(0), dim=0)
+        xs = x.split(1, dim=0)
+
+        # Handle questions without node features:
+        batch_unique = batch.unique()
+        batch_size = len(question)
+        if len(batch_unique) < batch_size:
+            xs = [
+                xs[i] if i in batch_unique else None for i in range(batch_size)
+            ]

         (
             inputs_embeds,
@@ -174,7 +180,15 @@ def inference(
         """
         x = self.encode(x, edge_index, batch, edge_attr)
         x = self.projector(x)
-        xs = x.split(x.size(0), dim=0)
+        xs = x.split(1, dim=0)
+
+        # Handle questions without node features:
+        batch_unique = batch.unique()
+        batch_size = len(question)
+        if len(batch_unique) < batch_size:
+            xs = [
+                xs[i] if i in batch_unique else None for i in range(batch_size)
+            ]

         inputs_embeds, attention_mask, _ = self.llm._get_embeds(
             question, additional_text_context, xs)
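
A toy sketch of the new empty-graph handling above: after pooling there is one row per graph index, but a retrieved subgraph can be empty, in which case its chunk is replaced by `None` so that, presumably, no graph embedding is injected for that question. The dimensions and the direct `scatter` call are illustrative stand-ins for what `self.encode` and `self.projector` produce:

```python
import torch
from torch_geometric.utils import scatter

node_x = torch.randn(5, 4096)            # projected node features (toy dimensions)
batch = torch.tensor([0, 0, 2, 2, 2])     # node-to-graph assignment; graph 1 has no nodes
question = ['q0', 'q1', 'q2']             # one question per graph in the batch

x = scatter(node_x, batch, dim=0, reduce='mean')   # (3, 4096); row 1 is all zeros
xs = x.split(1, dim=0)                              # one (1, 4096) chunk per graph row

batch_unique = batch.unique()                       # tensor([0, 2])
batch_size = len(question)
if len(batch_unique) < batch_size:
    xs = [xs[i] if i in batch_unique else None for i in range(batch_size)]

print([None if t is None else tuple(t.shape) for t in xs])
# [(1, 4096), None, (1, 4096)]
```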

Diff for: torch_geometric/nn/nlp/llm.py (+2)

@@ -1,3 +1,4 @@
+import warnings
 from contextlib import nullcontext
 from typing import Any, Dict, List, Optional

@@ -85,6 +86,7 @@ def __init__(
         self.word_embedding = self.llm.model.get_input_embeddings()

         if 'max_memory' not in kwargs:  # Pure CPU:
+            warnings.warn("LLM is being used on CPU, which may be slow")
             self.device = torch.device('cpu')
             self.autocast_context = nullcontext()
         else:
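
A hedged sketch of when the new warning fires: `LLM` takes the pure-CPU branch whenever no `max_memory` mapping ends up in the loader kwargs (for example, no GPU or insufficient GPU memory), and now warns about it. The model name and `num_params` mirror the TinyLlama configuration from the example above; running this downloads the weights:

```python
import warnings

from torch_geometric.nn.nlp import LLM

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter('always')
    # On a machine without enough GPU memory this takes the pure-CPU branch above:
    llm = LLM(model_name='TinyLlama/TinyLlama-1.1B-Chat-v0.1', num_params=1)

for w in caught:
    print(w.message)  # "LLM is being used on CPU, which may be slow"
```

During development the warning can also be promoted to an error with `warnings.filterwarnings('error', message='LLM is being used on CPU.*')` so that an accidental CPU fallback of the 7B model fails fast.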
