Skip to content

Commit cfc8a94

Browse files
committed
Resolve merge conflict
2 parents c7a57aa + 981b11e commit cfc8a94

21 files changed

+1810
-202
lines changed

.github/workflows/ci.yml

+1-1
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ on:
1010

1111
jobs:
1212
build:
13-
runs-on: ubuntu-latest
13+
runs-on: ubuntu-24.04
1414

1515
strategy:
1616
matrix:
Original file line numberDiff line numberDiff line change
@@ -1 +1,5 @@
1-
OPENAI_API_KEY=xxx
1+
ENV=dev
2+
# LLM_MODE can be explicitly set to local, remote, or mock.
3+
# For testing, you can set ENV=test (defaulting LLM mode to mock) or override LLM_MODE directly.
4+
# LLM_MODE=local
5+
REMOTE_MODEL_URL=https://tai.berkeley.edu/api/chat

evaluation/README.md

+63
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
# Evaluation Pipeline
2+
3+
TAI is equipped with the ability to measure the reliability and accuracy of its underlying Retrieval-Augmented Generation (RAG) agent. To simplify dataset creation and evaluation, this module provides customized evaluation functionality, ranging from creating evaluation datasets to implementing evaluation algorithms specifically designed for TAI.
4+
5+
## Features
6+
7+
- **Dataset Generation**: Seamlessly generate evaluation datasets tailored to the needs of the TAI system.
8+
- **Analysis Tools**: Analyze generated datasets to uncover biases and visualize relationships using Sankey graphs.
9+
10+
## Setup
11+
12+
1. **Install Requirements**: Install the required dependencies by running the following command:
13+
```sh
14+
pip install -r requirements.txt
15+
```
16+
17+
2. **Set Environment Variables**: Ensure the `OPENAI_API_KEY` is stored as an environment variable. You can add it to your `.bashrc`, `.zshrc`, or `.env` file for persistent configuration:
18+
```sh
19+
export OPENAI_API_KEY="your_api_key_here"
20+
```
21+
22+
3. **Prepare Input Data**: Place your input JSON file in the following directory:
23+
```
24+
/evaluation/dataset_generate/input
25+
```
26+
27+
## Dataset Generation
28+
29+
To generate an evaluation dataset, run the following command:
30+
```sh
31+
python -m evaluation.dataset_generate.generate <input_filename> [--num_pairs] [--quiet]
32+
```
33+
34+
### Arguments:
35+
- `<input_filename>`: Name of the input JSON file located in `/evaluation/dataset_generate/input`.
36+
- `--num_pairs`: (Optional) Specify the number of pairs to generate.
37+
- `--quiet`: (Optional) Suppress output logs for a cleaner console experience.
38+
39+
Example:
40+
```sh
41+
python -m evaluation.dataset_generate.generate sample_input.json --num_pairs 50
42+
```
43+
44+
## Dataset Analysis
45+
46+
To analyze the generated dataset for bias statistics and visualize relationships using a Sankey graph, use the following command:
47+
```sh
48+
python -m evaluation.dataset_analyze.analyze <input_filename> [--graph]
49+
```
50+
51+
### Arguments:
52+
- `<input_filename>`: Name of the input dataset file to analyze; biases are computed from the category labels of the input and output datasets.
53+
- `--graph`: Option to generate the Sankey graph visualization.
54+
55+
Example:
56+
```sh
57+
python -m evaluation.dataset_analyze.analyze generated_dataset.json --graph
58+
```
59+
60+
## Output Files
61+
62+
- **Generated Datasets**: Saved in `/evaluation/dataset_generate/output`.
63+
- **Analysis Results**: Saved in `/evaluation/dataset_analyze/output`.
+162
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,162 @@
1+
import json
2+
import os
3+
from pprint import pprint
4+
import argparse
5+
import plotly.graph_objects as go
6+
from evaluation.dataset_generate.generate import generate_qa_pairs
7+
from rag.file_conversion_router.conversion.ed_converter import json_kb_filter
8+
9+
def compute_bias(data):
10+
11+
categories = []
12+
category_dict = {}
13+
total_count = 0
14+
15+
for entry in data:
16+
17+
category = entry['category']
18+
if category not in categories and category not in category_dict:
19+
categories.append(category)
20+
category_dict[category] = 1
21+
else:
22+
category_dict[category] += 1
23+
total_count += 1
24+
25+
biases = {key: value / total_count for key, value in category_dict.items()}
26+
27+
return biases
28+
29+
30+
def analyze(input_filename, graph=True):
    """Compare category labels before and after dataset generation.

    Loads the raw input JSON from evaluation/dataset_generate/input and the
    generated dataset from evaluation/dataset_generate/output, computes
    per-category bias statistics for both, records how each
    (question, answer) pair's category changed, writes the combined result
    as JSON, and optionally renders a Sankey diagram of the transitions.

    Args:
        input_filename: Name of the input JSON file (file name only, not a path).
        graph: When True, also write a Sankey diagram PNG of the transitions.
    """
    original_file_path = os.path.join("evaluation", "dataset_generate", "input", input_filename)
    generated_file_path = os.path.join("evaluation", "dataset_generate", "output", f"evaluation_dataset_{input_filename}")

    with open(original_file_path, 'r') as file:
        original_file = json.load(file)

    with open(generated_file_path, 'r') as file:
        generated_dataset = json.load(file)

    # Re-derive the "before" dataset from the raw input so its category
    # labels can be compared against the generated ("after") dataset.
    original_cleaned = json_kb_filter(original_file)
    original_dataset = generate_qa_pairs(original_cleaned)

    biases_stats = {
        "Before": compute_bias(original_dataset),
        "After": compute_bias(generated_dataset)
    }

    # Index generated pairs by (question, answer) for O(1) lookup.
    after_dict = {(entry["question"], entry["answer"]): entry["category"] for entry in generated_dataset}

    change = []
    for entry in original_dataset:
        question, answer = entry["question"], entry["answer"]
        # Pairs absent from the generated dataset are marked "Invalid".
        # Using a .get default merges the original's two duplicated branches
        # and also keeps a falsy-but-present category, which the original
        # truthiness check silently converted to "Invalid".
        change.append({
            "question": question,
            "answer": answer,
            "before_category": entry["category"],
            "after_category": after_dict.get((question, answer), "Invalid")
        })

    result = {
        "biases": biases_stats,
        "results_comparison": change
    }

    pprint(biases_stats)

    output_path = os.path.join("evaluation", "dataset_generate", "output", f"datasets_analysis_{input_filename}")
    with open(output_path, 'w') as file:
        json.dump(result, file, indent=4)

    if graph:
        # Count how many pairs moved from each before-category to each
        # after-category.
        transitions = {}
        for entry in change:
            key = (entry["before_category"], entry["after_category"])
            transitions[key] = transitions.get(key, 0) + 1

        before_labels = sorted({before for before, _ in transitions})
        after_labels = sorted({after for _, after in transitions})
        all_labels = before_labels + after_labels

        # Keep separate index maps for the two columns. The original built a
        # single dict over all_labels, so a category appearing on both sides
        # collapsed onto one node index and its links were drawn wrong.
        before_index = {label: i for i, label in enumerate(before_labels)}
        after_index = {label: len(before_labels) + i for i, label in enumerate(after_labels)}

        # Link colors follow the source (before) category; unknown categories
        # fall back to light grey.
        category_colors = {
            "General": "#aec7e8",       # Light Blue
            "Problem Sets": "#ffbb78",  # Light Orange
            "Assignments": "#98df8a",   # Light Green
            "Lectures": "#ff9896",      # Light Red
            "Sections": "#c5b0d5",      # Light Purple
            "Social": "#c49c94"         # Light Brown
        }

        link_sources = []
        link_targets = []
        link_values = []
        link_colors = []

        for (before, after), count in transitions.items():
            link_sources.append(before_index[before])
            link_targets.append(after_index[after])
            link_values.append(count)
            link_colors.append(category_colors.get(before, "#d3d3d3"))

        fig = go.Figure(go.Sankey(
            node=dict(
                pad=15,
                thickness=20,
                line=dict(color="black", width=0.5),
                label=all_labels,
                color=["#a0c4ff"] * len(before_labels) + ["#ffc09f"] * len(after_labels)
            ),
            link=dict(
                source=link_sources,
                target=link_targets,
                value=link_values,
                color=link_colors
            )
        ))

        output_path = os.path.join("evaluation", "dataset_generate", "output", f"category_flow_diagram_{os.path.splitext(input_filename)[0]}.png")

        fig.update_layout(title_text="Diagram of Category Transitions", font_size=10)
        fig.write_image(output_path)

        print(f"Sankey diagram saved to {output_path}")
152+
153+
if __name__ == "__main__":
    # CLI entry point: analyze a generated dataset against its source input.
    # The description previously read "Generate evaluation dataset" — a
    # copy-paste from the generate script; this is the analyze tool.
    parser = argparse.ArgumentParser(description="Analyze evaluation dataset")
    parser.add_argument("input_filename", type=str, help="The input JSON file")
    parser.add_argument("--graph", action="store_true", help="Generate Graph")

    args = parser.parse_args()
    analyze(args.input_filename, args.graph)
160+
161+
162+

0 commit comments

Comments
 (0)