
Commit ed19746

feat(eval): added ability to run evals against different providers/models
1 parent f4b1f40 commit ed19746

File tree

2 files changed: +130 -34 lines changed


eval/agents.py

+11 -3

@@ -10,6 +10,10 @@
 
 
 class Agent:
+    def __init__(self, llm: str, model: str):
+        self.llm = llm
+        self.model = model
+
     @abstractmethod
     def act(self, files: Files | None, prompt: str) -> Files:
         """
@@ -28,14 +32,18 @@ def act(self, files: Files | None, prompt: str):
 
         print("\n--- Start of generation ---")
         print(f"Working in {store.working_dir}")
+        prompt_sys = get_prompt()
+        prompt_sys.content += (
+            "\n\nIf you have trouble and dont seem to make progress, stop trying."
+        )
         # TODO: add timeout
         try:
             gptme_chat(
                 [Message("user", prompt)],
-                [get_prompt()],
+                [prompt_sys],
                 f"gptme-evals-{store.id}",
-                llm=None,
-                model=None,
+                llm=self.llm,
+                model=self.model,
                 no_confirm=True,
                 interactive=False,
             )
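
Note: the Agent base class now carries the provider ("llm") and model explicitly, and GPTMe forwards them to gptme_chat instead of hard-coded None values. Below is a minimal sketch (not part of the commit) of driving such an agent directly, assuming the eval directory is importable as the package eval and that Files behaves like a mapping of file names to contents; the prompt is made up for illustration and a real run requires API access.

from eval.agents import GPTMe

# Provider ("llm") and model are now chosen per agent instance.
agent = GPTMe(llm="openai", model="gpt-4o-mini")

# act() takes optional starting files and a prompt, and returns the files
# produced by the generation run.
files = agent.act(None, "write hello.py that prints 'Hello, world!'")
print(list(files))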

eval/main.py

+119 -31

@@ -4,12 +4,16 @@
 
 Inspired by a document by Anton Osika and Axel Theorell.
 """
 
+import csv
 import inspect
 import logging
+import subprocess
 import sys
 import time
+from datetime import datetime
+from pathlib import Path
 
-from .agents import GPTMe
+from .agents import Agent, GPTMe
 from .evals import tests, tests_map
 from .execenv import SimpleExecutionEnv
 from .types import (
@@ -21,13 +25,16 @@
 
 logger = logging.getLogger(__name__)
 
+project_dir = Path(__file__).parent.parent
 
-def execute(test: ExecTest) -> ExecResult:
+
+def execute(test: ExecTest, agent: Agent) -> ExecResult:
     """
-    Executes the code.
+    Executes the code for a specific model.
     """
-    print(f"Running test {test['name']} with prompt: {test['prompt']}")
-    agent = GPTMe()
+    print(
+        f"Running test {test['name']} with prompt: {test['prompt']} for model: {agent.model}"
+    )
 
     # generate code
     gen_start = time.time()
@@ -74,48 +81,129 @@ def execute(test: ExecTest) -> ExecResult:
 
 
 def main():
+    models = [
+        # "openai/gpt-3.5-turbo",
+        # "openai/gpt-4-turbo",
+        # "openai/gpt-4o",
+        "openai/gpt-4o-mini",
+        # "anthropic/claude-3-5-sonnet-20240620",
+        "anthropic/claude-3-haiku-20240307",
+    ]
     test_name = sys.argv[1] if len(sys.argv) > 1 else None
-    results = []
-    if test_name:
-        print(f"=== Running test {test_name} ===")
-        result = execute(tests_map[test_name])
-        results.append(result)
-    else:
-        print("=== Running all tests ===")
-        for test in tests:
-            result = execute(test)
+
+    all_results = {}
+    for model in models:
+        print(f"\n=== Running tests for model: {model} ===")
+        llm, model = model.split("/")
+        agent = GPTMe(llm=llm, model=model)
+
+        results = []
+        if test_name:
+            print(f"=== Running test {test_name} ===")
+            result = execute(tests_map[test_name], agent)
             results.append(result)
+        else:
+            print("=== Running all tests ===")
+            for test in tests:
+                result = execute(test, agent)
+                results.append(result)
 
-    print("=== Finished ===\n")
-    duration_total = sum(
-        result["timings"]["gen"] + result["timings"]["run"] + result["timings"]["eval"]
-        for result in results
-    )
-    print(f"Completed {len(results)} tests in {duration_total:.2f}s:")
-    for result in results:
-        checkmark = "✅" if all(case["passed"] for case in result["results"]) else "❌"
-        duration_result = (
+        all_results[model] = results
+
+    print("\n=== Finished ===\n")
+
+    for model, results in all_results.items():
+        print(f"\nResults for model: {model}")
+        duration_total = sum(
             result["timings"]["gen"]
             + result["timings"]["run"]
             + result["timings"]["eval"]
+            for result in results
         )
-        print(
-            f"- {result['name']} in {duration_result:.2f}s (gen: {result['timings']['gen']:.2f}s, run: {result['timings']['run']:.2f}s, eval: {result['timings']['eval']:.2f}s)"
-        )
-        for case in result["results"]:
-            checkmark = "✅" if case["passed"] else "❌"
-            print(f" {checkmark} {case['name']}")
+        print(f"Completed {len(results)} tests in {duration_total:.2f}s:")
+        for result in results:
+            checkmark = (
+                "✅" if all(case["passed"] for case in result["results"]) else "❌"
+            )
+            duration_result = (
+                result["timings"]["gen"]
+                + result["timings"]["run"]
+                + result["timings"]["eval"]
+            )
+            print(
+                f"- {result['name']} in {duration_result:.2f}s (gen: {result['timings']['gen']:.2f}s, run: {result['timings']['run']:.2f}s, eval: {result['timings']['eval']:.2f}s)"
+            )
+            for case in result["results"]:
+                checkmark = "✅" if case["passed"] else "❌"
+                print(f" {checkmark} {case['name']}")
+
+    print("\n=== Model Comparison ===")
+    for test in tests:
+        print(f"\nTest: {test['name']}")
+        for model, results in all_results.items():
+            result = next(r for r in results if r["name"] == test["name"])
+            passed = all(case["passed"] for case in result["results"])
+            checkmark = "✅" if passed else "❌"
+            duration = sum(result["timings"].values())
+            print(f"{model}: {checkmark} {duration:.2f}s")
 
     all_success = all(
-        all(case["passed"] for case in result["results"]) for result in results
+        all(all(case["passed"] for case in result["results"]) for result in results)
+        for results in all_results.values()
    )
     if all_success:
-        print("\n✅ All tests passed!")
+        print("\n✅ All tests passed for all models!")
     else:
         print("\n❌ Some tests failed!")
 
+    # Write results to CSV
+    write_results_to_csv(all_results)
+
     sys.exit(0 if all_success else 1)
 
 
+def write_results_to_csv(all_results):
+    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
+    # get current commit hash and dirty status, like: a8b2ef0-dirty
+    commit_hash = subprocess.run(
+        ["git", "describe", "--always", "--dirty", "--exclude", "'*'"],
+        text=True,
+        capture_output=True,
+    ).stdout.strip()
+    filename = project_dir / f"eval_results_{timestamp}.csv"
+
+    with open(filename, "w", newline="") as csvfile:
+        fieldnames = [
+            "Model",
+            "Test",
+            "Passed",
+            "Total Duration",
+            "Generation Time",
+            "Run Time",
+            "Eval Time",
+            "Commit Hash",
+        ]
+        writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
+
+        writer.writeheader()
+        for model, results in all_results.items():
+            for result in results:
+                passed = all(case["passed"] for case in result["results"])
+                writer.writerow(
+                    {
+                        "Model": model,
+                        "Test": result["name"],
+                        "Passed": "true" if passed else "false",
+                        "Total Duration": sum(result["timings"].values()),
+                        "Generation Time": result["timings"]["gen"],
+                        "Run Time": result["timings"]["run"],
+                        "Eval Time": result["timings"]["eval"],
+                        "Commit Hash": commit_hash,
+                    }
+                )
+
+    print(f"\nResults saved to {filename.resolve()}")
+
+
 if __name__ == "__main__":
     main()
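
Note: main() now expects model identifiers of the form "provider/model", splits each on "/" to construct a GPTMe agent, prints a per-model summary and a model comparison, and hands the collected results to write_results_to_csv(). The sketch below (not part of the commit) mirrors that flow and reads the CSV back; it assumes the module can be run as eval.main and that the results file is looked up in the current directory. The model identifier is illustrative, and the column names match the fieldnames in the diff above.

import csv
from pathlib import Path

# How a models[] entry is interpreted (mirrors llm, model = model.split("/")):
spec = "openai/gpt-4o-mini"
llm, model = spec.split("/")  # -> ("openai", "gpt-4o-mini")

# An optional test name can be passed as the first CLI argument, e.g.:
#   python -m eval.main <test-name>

# Read back the most recent results file written by write_results_to_csv().
latest = sorted(Path(".").glob("eval_results_*.csv"))[-1]
with open(latest, newline="") as f:
    for row in csv.DictReader(f):
        print(row["Model"], row["Test"], row["Passed"], row["Total Duration"])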
