Skip to content

Commit 323ffd0

Browse files
committed
rework dependencies; fix openai stuff
Signed-off-by: SumanthRH <[email protected]>
1 parent 8e3daad commit 323ffd0

File tree

4 files changed

+15
-36
lines changed

4 files changed

+15
-36
lines changed

pyproject.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
[project]
2-
name = "skythought"
2+
name = "skythought_evals"
33
version = "0.1.0"
4-
description = "Skythought Evals"
4+
description = "Skythought Evals: Evaluation and Data Generation Tools for Reasoning Models"
55
authors = [
66
{ name = "NovaSky Team"}
77
]

setup.py

Lines changed: 0 additions & 23 deletions
This file was deleted.

skythought/skythought_evals/common/entities.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,11 @@
33
from enum import Enum
44
from importlib import resources
55
from pathlib import Path
6-
from typing import Any, Dict, Literal, Optional, Union
6+
from typing import Literal, Optional, Union
77

88
import yaml
9+
from openai import NOT_GIVEN, NotGiven
10+
from openai.types.chat import ChatCompletionReasoningEffort
911
from pydantic import BaseModel, ConfigDict, Field
1012
from vllm import SamplingParams as VLLMSamplingParams
1113

@@ -21,18 +23,20 @@ class Backend(str, Enum):
2123

2224

2325
class OpenAISamplingParams(BaseModel):
26+
model_config = ConfigDict(arbitrary_types_allowed=True)
27+
2428
temperature: float = TEMPERATURE_DEFAULT
2529
top_p: float = TOP_P_DEFAULT
2630
n: int = 1
2731
max_tokens: int = MAX_TOKENS_DEFAULT
28-
reasoning_effort: Optional[float] = None
32+
reasoning_effort: Union[ChatCompletionReasoningEffort, NotGiven] = NOT_GIVEN
2933
frequency_penalty: Optional[float] = None
3034

3135

3236
class SamplingParameters(BaseModel):
3337
model_config = ConfigDict(arbitrary_types_allowed=True)
3438

35-
params: Union[Dict[str, Any], OpenAISamplingParams, VLLMSamplingParams]
39+
params: Union[OpenAISamplingParams, VLLMSamplingParams]
3640

3741
@classmethod
3842
def from_dict(cls, backend: Backend, params: dict):

skythought/skythought_evals/inference_and_check.py

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -77,29 +77,26 @@ def fetch_response_openai(
7777
model_name = model_config.name
7878
# Ensure model_name has been resolved to a string
7979
assert model_name
80-
if "o1" in model_name:
80+
if model_name.startswith("o1") or model_name.startswith("o3"):
8181
# O1 doesn't support system prompt
8282
# NOTE: might want to implement this inside handler instead
8383
for p in prompt:
8484
p["role"] = "user"
85-
8685
response = client.chat.completions.create(
8786
model=model_config.model_id,
8887
messages=prompt,
8988
n=sampling_params.n,
90-
temperature=sampling_params.temperature,
91-
max_tokens=sampling_params.max_tokens,
9289
reasoning_effort=sampling_params.reasoning_effort,
93-
frequency_penalty=sampling_params.frequency_penalty,
9490
max_completion_tokens=sampling_params.max_tokens,
9591
)
9692
else:
93+
if sampling_params.reasoning_effort is not None:
94+
raise ValueError("Reasoning effort is only supported for reasoning models")
9795
response = client.chat.completions.create(
9896
model=model_config.model_id,
9997
messages=prompt,
10098
n=sampling_params.n,
10199
temperature=sampling_params.temperature,
102-
max_tokens=sampling_params.max_tokens,
103100
frequency_penalty=sampling_params.frequency_penalty,
104101
max_completion_tokens=sampling_params.max_tokens,
105102
)
@@ -170,12 +167,13 @@ def inference(
170167
responses = copy.deepcopy(responses)
171168
responses = sorted(responses, key=lambda x: x.index)
172169
elif backend == Backend.OPENAI:
173-
llm = OpenAI(**backend_params)
170+
llm = OpenAI(**backend_params.to_dict())
171+
assert isinstance(sampling_params.params, OpenAISamplingParams)
174172
fetch_partial = partial(
175173
fetch_response_openai,
176174
llm,
177175
model_config,
178-
sampling_params,
176+
sampling_params.params,
179177
)
180178

181179
with concurrent.futures.ThreadPoolExecutor(max_workers=16) as e:

0 commit comments

Comments (0)