-
Notifications
You must be signed in to change notification settings - Fork 6
/
models.py
101 lines (89 loc) · 4.67 KB
/
models.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
import os
import torch
import openai
import anthropic
from transformers import AutoModelForCausalLM, AutoTokenizer
class ModelGPT:
def __init__(self, model_name):
self.model_name = model_name
self.client = openai.OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
def get_response(self, prompt, max_n_tokens, temperature):
# o1 models don't support system messages and max_tokens
if 'o1' in self.model_name:
messages = [
{"role": "user", "content": prompt}
]
# when input filters kick in
try:
response = self.client.chat.completions.create(
model=self.model_name,
messages=messages,
# max_completion_tokens=max_n_tokens, # hard to set since it also counts CoT tokens
temperature=temperature,
seed=0,
)
generation = response.choices[0].message.content
# truncate the generation to save tokens of the GPT-4 judge and make it more comparable
# to the other models that adhere to this limit
generation = generation[:int(4.6*max_n_tokens)]
except Exception as e:
print(f"Error: {e}")
generation = ""
else:
messages = [
{"role": "system", "content": "You are a helpful AI assistant."},
{"role": "user", "content": prompt}
]
response = self.client.chat.completions.create(
model=self.model_name,
messages=messages,
max_tokens=max_n_tokens,
temperature=temperature,
seed=0,
)
generation = response.choices[0].message.content
print(f"Generation: {generation}")
return generation
class ModelClaude:
def __init__(self, model_name):
self.model_name = model_name
self.client = anthropic.Anthropic()
def get_response(self, prompt, max_n_tokens, temperature):
messages = [
{"role": "user", "content": [{"type": "text", "text": prompt}]}
]
output = self.client.messages.create(
model=self.model_name,
max_tokens=max_n_tokens,
temperature=temperature,
messages=messages
)
return output.content[0].text
class ModelHuggingFace:
def __init__(self, model_name):
model_dict = {
"phi3": "microsoft/Phi-3-mini-128k-instruct",
"gemma2-9b": "google/gemma-2-9b-it",
"llama3-8b": "meta-llama/Meta-Llama-3-8B-Instruct",
"r2d2": "cais/zephyr_7b_r2d2",
}
self.system_prompts = {
"phi3": "You are a helpful AI assistant.",
"gemma2-9b": "",
"llama3-8b": "You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature. If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don’t know the answer to a question, please don’t share false information.",
"r2d2": "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human’s questions.",
}
self.device = torch.device("cuda")
self.model_name = model_name
self.model = AutoModelForCausalLM.from_pretrained(model_dict[model_name], torch_dtype=torch.float16, device_map=self.device,token=os.getenv("HF_TOKEN"), trust_remote_code=True).eval()
self.tokenizer = AutoTokenizer.from_pretrained(model_dict[model_name], token=os.getenv("HF_TOKEN"))
def get_response(self, prompt, max_n_tokens, temperature):
conv = [{"role": "user", "content": prompt}]
if self.system_prompts[self.model_name] != "":
conv = [{"role": "system", "content": self.system_prompts[self.model_name]}] + conv
prompt_formatted = self.tokenizer.apply_chat_template(conv, tokenize=False, add_generation_prompt=True)
inputs = self.tokenizer(prompt_formatted, return_tensors='pt').to(self.device)
outputs = self.model.generate(input_ids=inputs['input_ids'], max_new_tokens=max_n_tokens, temperature=temperature, do_sample=True)
outputs_truncated = outputs[0][len(inputs['input_ids'][0]):]
response = self.tokenizer.decode(outputs_truncated, skip_special_tokens=True)
return response