
Commit ad55824

Clean up RagasEvaluator interfaces
Refactor the RagasEvaluator class for use with the `ilab` interface.

Signed-off-by: Ali Maredia <[email protected]>
1 parent 8034f7e commit ad55824

File tree: 1 file changed (+38 −58)

src/instructlab/eval/ragas.py
Lines changed: 38 additions & 58 deletions
@@ -83,7 +83,7 @@ class ModelConfig(BaseModel):
     max_tokens: int = 768
 
     # Random seed for reproducibility. Caution: this isn't supported by all model serving runtimes.
-    seed: int = DEFAULT_SEED
+    seed: int = 42
 
 
 class RagasEvaluator(Evaluator):
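
The seed default is now a literal 42 on the Pydantic model rather than the DEFAULT_SEED constant, and generation reads it from the config (see the final hunk) instead of hardcoding it at the call site. A minimal sketch of how the field behaves, assuming `ModelConfig` can be constructed with only `model_name`:

```python
from instructlab.eval.ragas import ModelConfig

config = ModelConfig(model_name="granite-7b-lab")  # hypothetical model name
assert config.seed == 42  # new literal default

# Callers wanting reproducibility on a different seed can still override it.
deterministic = ModelConfig(model_name="granite-7b-lab", seed=1337)
```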
@@ -105,29 +105,25 @@ def __init__(
         self.judge_openai_api_key = judge_openai_api_key
 
     @staticmethod
-    def _validate_dataset(df: DataFrame):
+    def validate_dataset(df: DataFrame):
         """
         Validates whether or not the given `df` is a valid dataset of `Sample` objects.
 
         Args:
-            df (DataFrame): DataFrame containing the dataset to be evaluated.
+            df (DataFrame): DataFrame containing the dataset to be evaluated.
         """
-        # We have to hardcode these fields because the automated way of resolving the required fields from a TypedDict
-        # is only included by default in Python3.11+. For earlier versions, the `typing_extensions` package is required.
-        # See: https://docs.python.org/3/whatsnew/3.11.html#pep-655-marking-individual-typeddict-items-as-required-or-not-required
-        required_keys = {"user_input", "reference"}
-        missing_keys = required_keys - set(df.columns)
-        if missing_keys:
+        required_keys = {"user_input", "reference", "response"}
+
+        columns_list = set(df.columns)
+        if not columns_list.issubset(required_keys):
             raise ValueError(
-                f"invalid dataset provided, missing the following keys: {', '.join(missing_keys)}"
+                f"Dataset can only have the following keys: {', '.join(required_keys)}. Keys provided were: {', '.join(columns_list)}"
             )
 
     def run(
         self,
-        dataset: List[Sample] | Path,
-        student_model: ModelConfig | None = None,
+        dataset: List[Sample] | DataFrame,
         run_config: RunConfig | None = None,
-        student_openai_client: OpenAIClient | None = None,
         judge_model_name: str | None = None,
         judge_openai_api_key: str | None = None,
     ) -> EvaluationResult:
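
`_validate_dataset` becomes the public static method `validate_dataset`, and the check inverts: instead of requiring `user_input` and `reference` to be present, it now requires the columns to be a subset of `{"user_input", "reference", "response"}`. A sketch of the new behavior (data values are placeholders):

```python
from pandas import DataFrame
from instructlab.eval.ragas import RagasEvaluator

ok = DataFrame([{"user_input": "What is RAGAS?", "reference": "A RAG evaluation toolkit."}])
RagasEvaluator.validate_dataset(ok)  # passes: columns are a subset of the allowed keys

bad = ok.assign(extra="x")
RagasEvaluator.validate_dataset(bad)  # raises ValueError: "extra" is not an allowed key
```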
@@ -141,17 +137,12 @@ def run(
             dataset (List[Sample] | Path):
                 Can be either a list of `Sample` objects or a path to a jsonl file containing
                 records matching `Sample`.
-            student_model: (StudentModelConfig):
-                When this parameter is provided, we'll attempt to use the described model in order to
                 generate the responses from the given list of questions.
             run_config (RunConfig | None, optional):
                 Configuration to use when running evaluations. If none is provided, then
                 a default one is created containing extremely permissive settings when handling
                 timeouts. This is because by default, OpenAI tier-1 usage accounts have very high
                 rate limits resulting in heavy throttling during evaluations.
-            student_openai_client (openai.Client | None, optional):
-                The client to use when generating questions from the student model, must be compatible with the OpenAI API.
-                This field is required when `student_model` is provided.
             judge_model_name (str | None, optional):
                 Name of the OpenAI model to use as the judge model. Defaults to "gpt-4o" when none is specified.
             judge_openai_api_key (str | None, optional):
@@ -167,50 +158,29 @@ def run(
         judge_openai_api_key = (
             judge_openai_api_key if judge_openai_api_key else self.judge_openai_api_key
         )
-        student_model = student_model if student_model else self.student_model
         run_config = run_config if run_config else self.run_config
-        student_openai_client = (
-            student_openai_client
-            if student_openai_client
-            else self.student_openai_client
-        )
 
         # ensure we are in the dataframe format
-        input_df = None
+        input_df = dataset
         if isinstance(dataset, list):
             input_df = DataFrame(dataset)
-        elif isinstance(dataset, Path):
-            input_df = read_json(dataset, orient="records", lines=True)
-        else:
+        elif not isinstance(dataset, DataFrame):
             raise TypeError(f"invalid type of dataset: {type(dataset)}")
 
         # this should never happen, but pylint is not smart enough to detect it
         if TYPE_CHECKING:
             assert input_df is not None
 
         # ensure the dataset is in the format we expect it
-        self._validate_dataset(input_df)
-
-        need_to_generate_questions = "response" not in input_df.columns
-        if need_to_generate_questions:
-            logger.debug(
-                "`response` is missing in the input dataframe columns, generating questions from the model is required."
-            )
-            if not student_model or not student_openai_client:
-                raise ValueError(
-                    "provided dataset doesn't contain the model `response`, but either `student_model` or `student_openai_client` wasn't provided for inference"
-                )
-
-        # if the student model was provided then we always generate regardless
-        if student_model:
-            if not student_openai_client:
-                raise ValueError(
-                    "`student_model` was specified but `student_openai_client` was not provided"
-                )
-            input_df = self._generate_answers_from_model(
-                input_df, student_model, student_openai_client
+        # this looks similar to validate_dataset but here we want an exact match, not a subset
+        required_keys = {"user_input", "reference", "response"}
+        columns = set(input_df.columns)
+        if columns != required_keys:
+            raise ValueError(
+                f"Input Dataset can only have the following keys: {', '.join(required_keys)}. Keys provided were: {', '.join(columns)}"
             )
 
+
         if not run_config:
             # we set extreme timeout/retry values by default since OpenAI tier-1 rate limits
             # are horrible and will result in half of our evaluation results being NaN or 0
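
With generation removed from the evaluation path, `run()` now rejects any dataset that does not arrive with exactly the `user_input`, `reference`, and `response` columns. A hedged sketch of a call under the new signature, assuming the evaluator can be constructed here without extra arguments:

```python
from pandas import DataFrame
from instructlab.eval.ragas import RagasEvaluator

dataset = DataFrame(
    [
        {
            "user_input": "What is InstructLab?",
            "reference": "An open source project for tuning LLMs.",
            "response": "InstructLab is an open source LLM tuning project.",
        }
    ]
)
evaluator = RagasEvaluator()
result = evaluator.run(
    dataset=dataset,
    judge_model_name="gpt-4o",      # the documented default judge
    judge_openai_api_key="sk-...",  # placeholder; supply a real key
)
```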
@@ -238,15 +208,25 @@ def run(
         )
         return results
 
-    def _generate_answers_from_model(
-        self,
+    @staticmethod
+    def generate_answers_from_model(
         questions: DataFrame,
-        student_model: ModelConfig,
-        student_openai_client: OpenAIClient,
+        model_config: ModelConfig,
+        openai_client: OpenAIClient,
     ) -> DataFrame:
         """
         Given a DataFrame containing `user_input` columns, generates responses from the given model
         and returns a new DataFrame containing its answers in the `response` column.
+
+        Args:
+            questions: (DataFrame):
+                Questions and reference answers to be returned with the responses from the model
+            model_config: (ModelConfig):
+                Configuration settings for the model when getting responses.
+            openai_client (openai.Client | None, optional):
+                The client to use when generating questions from the model, must be compatible with the OpenAI API.
+        Returns:
+            DataFrame with user_input, reference, and response columns. Responses for the user_input from the model
         """
         # initialize response to write into
         updated_df = questions.copy()
@@ -256,17 +236,17 @@ def _generate_answers_from_model(
             messages: List[ChatCompletionMessageParam] = [
                 {
                     "role": "system",
-                    "content": student_model.system_prompt,
+                    "content": model_config.system_prompt,
                 },
                 {"role": "user", "content": qna["user_input"]},
             ]
-            response = student_openai_client.chat.completions.create(
+            response = openai_client.chat.completions.create(
                 messages=messages,
-                model=student_model.model_name,
+                model=model_config.model_name,
                 # specify the seed so we can at least try to have some reproducibility when the clients support it
-                seed=42,
-                max_tokens=student_model.max_tokens,
-                temperature=student_model.temperature,
+                seed=model_config.seed,
+                max_tokens=model_config.max_tokens,
+                temperature=model_config.temperature,
             )
             updated_df.at[i, "response"] = response.choices[0].message.content
         return updated_df
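
Because answer generation is now the public static method `generate_answers_from_model` taking an explicit client and config, it can run as a standalone preprocessing step that fills the `response` column before `run()` is called. A sketch, with the endpoint and model name as placeholder assumptions:

```python
from openai import OpenAI
from pandas import DataFrame
from instructlab.eval.ragas import ModelConfig, RagasEvaluator

questions = DataFrame(
    [{"user_input": "What does ilab do?", "reference": "It drives InstructLab workflows."}]
)
client = OpenAI(base_url="http://localhost:8000/v1", api_key="none")  # placeholder local endpoint
config = ModelConfig(model_name="granite-7b-lab")  # hypothetical model name

answered = RagasEvaluator.generate_answers_from_model(questions, config, client)
assert "response" in answered.columns  # dataset is now ready for RagasEvaluator.run()
```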
