diff --git a/.github/workflows/catalog_consistency.yml b/.github/workflows/catalog_consistency.yml
index 4b42a8843b..96eab4c6e5 100644
--- a/.github/workflows/catalog_consistency.yml
+++ b/.github/workflows/catalog_consistency.yml
@@ -30,7 +30,7 @@ jobs:
     - uses: actions/setup-python@v5
       with:
-        python-version: '3.9'
+        python-version: '3.10'
     - run: curl -LsSf https://astral.sh/uv/install.sh | sh
     - run: uv pip install --system -e ".[tests]"
diff --git a/examples/inference_using_cross_provider.py b/examples/inference_using_cross_provider.py
index 3abc5371c0..292059a272 100644
--- a/examples/inference_using_cross_provider.py
+++ b/examples/inference_using_cross_provider.py
@@ -2,7 +2,7 @@
 from unitxt.text_utils import print_dict

 if __name__ == "__main__":
-    for provider in ["watsonx", "rits", "watsonx-sdk", "hf-local"]:
+    for provider in ["vllm", "watsonx", "rits", "watsonx-sdk", "hf-local"]:
         print()
         print("------------------------------------------------ ")
         print("PROVIDER:", provider)
diff --git a/pyproject.toml b/pyproject.toml
index 5a7db6c150..e9430c8ecc 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -108,7 +108,8 @@ tests = [
     "sqlparse",
     "diskcache",
     "pydantic",
-    "jsonschema_rs"
+    "jsonschema_rs",
+    "vllm"
 ]
 ui = [
     "gradio",
@@ -128,7 +129,8 @@ inference-tests = [
     "tenacity",
     "diskcache",
     "numpy==1.26.4",
-    "ollama"
+    "ollama",
+    "vllm"
 ]
 assistant = [
     "streamlit",
diff --git a/src/unitxt/inference.py b/src/unitxt/inference.py
index c31d612a3a..164ce25b27 100644
--- a/src/unitxt/inference.py
+++ b/src/unitxt/inference.py
@@ -3014,26 +3014,30 @@ class VLLMParamsMixin(Artifact):
     model: str
     n: int = 1
     best_of: Optional[int] = None
-    _real_n: Optional[int] = None
-    presence_penalty: float = 0.0
-    frequency_penalty: float = 0.0
-    repetition_penalty: float = 1.0
-    temperature: float = 0.0
+    temperature: float = 1.0
     top_p: float = 1.0
-    top_k: int = -1
+    top_k: int = 0
     min_p: float = 0.0
     seed: Optional[int] = None
+    presence_penalty: float = 0.0
+    frequency_penalty: float = 0.0
+    repetition_penalty: float = 1.0
     stop: Optional[Union[str, List[str]]] = None
     stop_token_ids: Optional[List[int]] = None
     bad_words: Optional[List[str]] = None
+    include_stop_str_in_output: bool = False
     ignore_eos: bool = False
     max_tokens: Optional[int] = 16
     min_tokens: int = 0
     logprobs: Optional[int] = None
     prompt_logprobs: Optional[int] = None
+    detokenize: bool = True
+    skip_special_tokens: bool = True
+    spaces_between_special_tokens: bool = True


 class VLLMInferenceEngine(InferenceEngine, PackageRequirementsMixin, VLLMParamsMixin):
+    _requirements_list: list = ["vllm"]
     label = "vllm"

     def get_engine_id(self):
@@ -3047,7 +3051,6 @@ def prepare_engine(self):
         self.sampling_params = SamplingParams(**args)
         self.llm = LLM(
             model=self.model,
-            device="auto",
             trust_remote_code=True,
             max_num_batched_tokens=4096,
             gpu_memory_utilization=0.7,
@@ -3231,6 +3234,7 @@ def get_return_object(self, responses, return_meta_data):
     "vertex-ai",
     "replicate",
     "hf-local",
+    "vllm",
 ]


@@ -3477,6 +3481,7 @@ class CrossProviderInferenceEngine(
     provider_model_map["watsonx"] = {
         k: f"watsonx/{v}" for k, v in provider_model_map["watsonx-sdk"].items()
     }
+    provider_model_map["vllm"] = provider_model_map["hf-local"]

     _provider_to_base_class = {
         "watsonx": LiteLLMInferenceEngine,
@@ -3490,12 +3495,14 @@ class CrossProviderInferenceEngine(
         "vertex-ai": LiteLLMInferenceEngine,
         "replicate": LiteLLMInferenceEngine,
         "hf-local": HFAutoModelInferenceEngine,
+        "vllm": VLLMInferenceEngine,
     }

     _provider_param_renaming = {
         "watsonx-sdk": {"model": "model_name"},
         "rits": {"model": "model_name"},
         "hf-local": {"model": "model_name", "max_tokens": "max_new_tokens"},
+        "vllm": {"top_logprobs": "logprobs", "logprobs": "prompt_logprobs"},
     }

     def get_return_object(self, **kwargs):
diff --git a/tests/inference/test_inference_engine.py b/tests/inference/test_inference_engine.py
index 48261b8f01..64fc90b8f3 100644
--- a/tests/inference/test_inference_engine.py
+++ b/tests/inference/test_inference_engine.py
@@ -19,6 +19,7 @@
     OptionSelectingByLogProbsInferenceEngine,
     RITSInferenceEngine,
     TextGenerationInferenceOutput,
+    VLLMInferenceEngine,
     WMLInferenceEngineChat,
     WMLInferenceEngineGeneration,
 )
@@ -189,6 +190,20 @@ def test_watsonx_chat_inference(self):

         self.assertListEqual(predictions, ["7", "2"])

+    def test_vllm_chat_inference(self):
+        model = VLLMInferenceEngine(
+            model=local_decoder_model,
+            data_classification_policy=["public"],
+            temperature=0,
+            max_tokens=1,
+        )
+
+        dataset = get_text_dataset()
+
+        predictions = model(dataset)
+
+        self.assertListEqual(list(predictions), ["7", "1"])
+
     def test_watsonx_inference_with_external_client(self):
         from ibm_watsonx_ai.client import APIClient, Credentials

@@ -279,7 +294,7 @@ def test_option_selecting_by_log_prob_inference_engines(self):
         ]

         watsonx_engine = WMLInferenceEngineGeneration(
-            model_name="meta-llama/llama-3-2-1b-instruct"
+            model_name="meta-llama/llama-3-3-70b-instruct"
         )

         for engine in [watsonx_engine]:
@@ -383,7 +398,7 @@ def test_lite_llm_inference_engine(self):

     def test_lite_llm_inference_engine_without_task_data_not_failing(self):
         LiteLLMInferenceEngine(
-            model="watsonx/meta-llama/llama-3-2-1b-instruct",
+            model="watsonx/meta-llama/llama-3-3-70b-instruct",
             max_tokens=2,
             temperature=0,
             top_p=1,