 Inspired by a document by Anton Osika and Axel Theorell.
 """
 
+import csv
 import inspect
 import logging
+import subprocess
 import sys
 import time
+from datetime import datetime
+from pathlib import Path
 
-from .agents import GPTMe
+from .agents import Agent, GPTMe
 from .evals import tests, tests_map
 from .execenv import SimpleExecutionEnv
 from .types import (
@@ -21,13 +25,16 @@
 
 logger = logging.getLogger(__name__)
 
+project_dir = Path(__file__).parent.parent
 
-def execute(test: ExecTest) -> ExecResult:
+
+def execute(test: ExecTest, agent: Agent) -> ExecResult:
     """
-    Executes the code.
+    Executes the code for a specific model.
     """
-    print(f"Running test {test['name']} with prompt: {test['prompt']}")
-    agent = GPTMe()
+    print(
+        f"Running test {test['name']} with prompt: {test['prompt']} for model: {agent.model}"
+    )
 
     # generate code
     gen_start = time.time()
@@ -74,48 +81,129 @@ def execute(test: ExecTest) -> ExecResult:
 
 
 def main():
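+    # models to evaluate, given as "<llm>/<model>" strings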
+    models = [
+        # "openai/gpt-3.5-turbo",
+        # "openai/gpt-4-turbo",
+        # "openai/gpt-4o",
+        "openai/gpt-4o-mini",
+        # "anthropic/claude-3-5-sonnet-20240620",
+        "anthropic/claude-3-haiku-20240307",
+    ]
     test_name = sys.argv[1] if len(sys.argv) > 1 else None
-    results = []
-    if test_name:
-        print(f"=== Running test {test_name} ===")
-        result = execute(tests_map[test_name])
-        results.append(result)
-    else:
-        print("=== Running all tests ===")
-        for test in tests:
-            result = execute(test)
+
+    all_results = {}
+    for model in models:
+        print(f"\n=== Running tests for model: {model} ===")
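+        # split "<llm>/<model>" into LLM backend and model name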
+        llm, model = model.split("/")
+        agent = GPTMe(llm=llm, model=model)
+
+        results = []
+        if test_name:
+            print(f"=== Running test {test_name} ===")
+            result = execute(tests_map[test_name], agent)
             results.append(result)
+        else:
+            print("=== Running all tests ===")
+            for test in tests:
+                result = execute(test, agent)
+                results.append(result)
 
-    print("=== Finished ===\n")
-    duration_total = sum(
-        result["timings"]["gen"] + result["timings"]["run"] + result["timings"]["eval"]
-        for result in results
-    )
-    print(f"Completed {len(results)} tests in {duration_total:.2f}s:")
-    for result in results:
-        checkmark = "✅" if all(case["passed"] for case in result["results"]) else "❌"
-        duration_result = (
+        all_results[model] = results
+
+    print("\n=== Finished ===\n")
+
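+    # per-model summary: total duration and per-test pass/fail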
+    for model, results in all_results.items():
+        print(f"\nResults for model: {model}")
+        duration_total = sum(
             result["timings"]["gen"]
             + result["timings"]["run"]
             + result["timings"]["eval"]
+            for result in results
         )
-        print(
-            f"- {result['name']} in {duration_result:.2f}s (gen: {result['timings']['gen']:.2f}s, run: {result['timings']['run']:.2f}s, eval: {result['timings']['eval']:.2f}s)"
-        )
-        for case in result["results"]:
-            checkmark = "✅" if case["passed"] else "❌"
-            print(f" {checkmark} {case['name']}")
+        print(f"Completed {len(results)} tests in {duration_total:.2f}s:")
+        for result in results:
+            checkmark = (
+                "✅" if all(case["passed"] for case in result["results"]) else "❌"
+            )
+            duration_result = (
+                result["timings"]["gen"]
+                + result["timings"]["run"]
+                + result["timings"]["eval"]
+            )
+            print(
+                f"- {result['name']} in {duration_result:.2f}s (gen: {result['timings']['gen']:.2f}s, run: {result['timings']['run']:.2f}s, eval: {result['timings']['eval']:.2f}s)"
+            )
+            for case in result["results"]:
+                checkmark = "✅" if case["passed"] else "❌"
+                print(f" {checkmark} {case['name']}")
+
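+    # per-test comparison across models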
+    print("\n=== Model Comparison ===")
+    for test in tests:
+        print(f"\nTest: {test['name']}")
+        for model, results in all_results.items():
+            result = next(r for r in results if r["name"] == test["name"])
+            passed = all(case["passed"] for case in result["results"])
+            checkmark = "✅" if passed else "❌"
+            duration = sum(result["timings"].values())
+            print(f"{model}: {checkmark} {duration:.2f}s")
 
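+    # overall success: every case of every test must pass for every model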
     all_success = all(
-        all(case["passed"] for case in result["results"]) for result in results
+        all(all(case["passed"] for case in result["results"]) for result in results)
+        for results in all_results.values()
     )
     if all_success:
-        print("\n✅ All tests passed!")
+        print("\n✅ All tests passed for all models!")
     else:
         print("\n❌ Some tests failed!")
 
+    # Write results to CSV
+    write_results_to_csv(all_results)
+
     sys.exit(0 if all_success else 1)
 
 
+def write_results_to_csv(all_results):
+    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
+    # get current commit hash and dirty status, like: a8b2ef0-dirty
+    commit_hash = subprocess.run(
+        ["git", "describe", "--always", "--dirty", "--exclude", "'*'"],
+        text=True,
+        capture_output=True,
+    ).stdout.strip()
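+    # one timestamped CSV per run, written under project_dir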
+    filename = project_dir / f"eval_results_{timestamp}.csv"
+
+    with open(filename, "w", newline="") as csvfile:
+        fieldnames = [
+            "Model",
+            "Test",
+            "Passed",
+            "Total Duration",
+            "Generation Time",
+            "Run Time",
+            "Eval Time",
+            "Commit Hash",
+        ]
+        writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
+
+        writer.writeheader()
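+        # write one row per (model, test) pair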
+        for model, results in all_results.items():
+            for result in results:
+                passed = all(case["passed"] for case in result["results"])
+                writer.writerow(
+                    {
+                        "Model": model,
+                        "Test": result["name"],
+                        "Passed": "true" if passed else "false",
+                        "Total Duration": sum(result["timings"].values()),
+                        "Generation Time": result["timings"]["gen"],
+                        "Run Time": result["timings"]["run"],
+                        "Eval Time": result["timings"]["eval"],
+                        "Commit Hash": commit_hash,
+                    }
+                )
+
+    print(f"\nResults saved to {filename.resolve()}")
+
+
 if __name__ == "__main__":
     main()