Skip to content

Commit 4bb8f3e

Browse files
authored
Add "auto" dtype to RayLLMEngineArgs and some minor fixes (#89)
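# What does this PR do?

- Adds `auto` as an accepted `dtype` value in `RayLLMEngineArgs`.
- Fixes a multiprocessing error in the TACO handler by moving `_temp_run` from a nested function inside `check_correctness` to module level.
- Calls `ray.shutdown()` after inference with the Ray backend so that the Ray session does not interfere with the multiprocessing code in the scoring stage.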
1 parent 4ef636c commit 4bb8f3e

File tree

3 files changed

+11
-8
lines changed

3 files changed

+11
-8
lines changed

skythought/evals/common/entities.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -83,8 +83,8 @@ class RayLLMEngineArgs(BaseModel):
8383
gpu_memory_utilization: Optional[float] = Field(
8484
default=None, description="GPU memory utilization for the inference engine"
8585
)
86-
dtype: Optional[Literal["float32", "float16", "bfloat16", "float8"]] = Field(
87-
default=None, description="Data type for inference engine."
86+
dtype: Optional[Literal["float32", "float16", "bfloat16", "float8", "auto"]] = (
87+
Field(default=None, description="Data type for inference engine.")
8888
)
8989

9090
def get_ray_llm_config(self):

skythought/evals/inference_and_check.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -167,6 +167,8 @@ def inference(
167167
# TODO: revisit the underlying issue and remove the deepcopy if possible
168168
responses = copy.deepcopy(responses)
169169
responses = sorted(responses, key=lambda x: x.index)
170+
# Cleanup ray session
171+
ray.shutdown()
170172
elif backend == Backend.OPENAI:
171173
llm = OpenAI(**backend_params.to_dict())
172174
assert isinstance(sampling_params.params, OpenAISamplingParams)

skythought/evals/tasks/taco/taco_handler.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -51,12 +51,6 @@ def generate_prompt(self, problem):
5151
def check_correctness(self, problem, generation):
5252
TIME_OUT = 300
5353

54-
def _temp_run(problem, generation, debug, result):
55-
try:
56-
result.append(taco_run_test(problem, test=generation, debug=debug))
57-
except Exception as e:
58-
print(f"Error in _temp_run: {e}")
59-
6054
manager = Manager()
6155
result = manager.list()
6256
p = multiprocessing.Process(
@@ -106,3 +100,10 @@ def load_and_filter_dataset(
106100
)
107101

108102
return dataset.iloc[start:end] if end > 0 else dataset.iloc[start:]
103+
104+
105+
def _temp_run(problem, generation, debug, result):
106+
try:
107+
result.append(taco_run_test(problem, test=generation, debug=debug))
108+
except Exception as e:
109+
print(f"Error in _temp_run: {e}")

0 commit comments

Comments
 (0)