
[WIP] [V1] TPU support #11936

Open

wants to merge 18 commits into main
4 changes: 2 additions & 2 deletions examples/offline_inference/offline_inference.py
@@ -8,10 +8,10 @@
"The future of AI is",
]
# Create a sampling params object.
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
sampling_params = SamplingParams()#temperature=0.8, top_p=0.95)

# Create an LLM.
llm = LLM(model="facebook/opt-125m")
llm = LLM(model="Qwen/Qwen2-1.5B-Instruct", max_model_len=512, max_num_seqs=16, enforce_eager=True)
# Generate texts from the prompts. The output is a list of RequestOutput objects
# that contain the prompt, generated text, and other information.
outputs = llm.generate(prompts, sampling_params)
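For quick local runs, here is a self-contained sketch of the example as modified above. It assumes only the public vllm entry points (LLM, SamplingParams) and abridges the prompt list, so treat it as illustrative rather than the file's exact contents:

from vllm import LLM, SamplingParams

# Abridged prompt list; the full example file defines a few more prompts.
prompts = [
    "Hello, my name is",
    "The future of AI is",
]

# Default sampling; the explicit temperature/top_p are commented out in this
# WIP change.
sampling_params = SamplingParams()

# Small model and conservative limits to keep memory use modest, with eager
# mode to avoid compilation while the TPU path is still work in progress.
llm = LLM(model="Qwen/Qwen2-1.5B-Instruct",
          max_model_len=512,
          max_num_seqs=16,
          enforce_eager=True)

outputs = llm.generate(prompts, sampling_params)
for output in outputs:
    print(f"Prompt: {output.prompt!r}, Generated: {output.outputs[0].text!r}")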
7 changes: 5 additions & 2 deletions tests/entrypoints/openai/test_accuracy.py
@@ -20,7 +20,7 @@
FILTER = "exact_match,strict-match"
RTOL = 0.03
EXPECTED_VALUE = 0.58
-DEFAULT_ARGS = ["--max-model-len", "2048", "--disable-log-requests"]
+DEFAULT_ARGS = ["--max-model-len", "2048", "--disable-log-requests", "--enforce-eager", "--max-num-seqs", "64"]
MORE_ARGS_LIST = [
    [],  # Default
    ["--enable-chunked-prefill"],  # Chunked
@@ -61,12 +61,15 @@ def run_test(more_args):
    )

    measured_value = results["results"][TASK][FILTER]
+    print("measured_value = {}".format(measured_value))

    assert (measured_value - RTOL < EXPECTED_VALUE
            and measured_value + RTOL > EXPECTED_VALUE
            ), f"Expected: {EXPECTED_VALUE} | Measured: {measured_value}"


-@pytest.mark.skipif(not current_platform.is_cuda(),
+@pytest.mark.skipif(not current_platform.is_cuda()
+                    and not current_platform.is_tpu(),
                    reason="V1 currently only supported on CUDA")
def test_lm_eval_accuracy_v1_engine(monkeypatch):
    """Run with the V1 Engine."""
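A possible readability tweak for the widened skip condition (a sketch only, not part of this PR; it reuses the current_platform helpers shown above and also updates the reason string to mention TPU):

import pytest

from vllm.platforms import current_platform

V1_SUPPORTED = current_platform.is_cuda() or current_platform.is_tpu()


@pytest.mark.skipif(not V1_SUPPORTED,
                    reason="V1 currently only supported on CUDA and TPU")
def test_lm_eval_accuracy_v1_engine(monkeypatch):
    """Run with the V1 Engine."""
    ...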
4 changes: 4 additions & 0 deletions vllm/attention/selector.py
@@ -163,6 +163,10 @@ def _cached_get_attn_backend(
logger.info("Using Pallas backend.")
from vllm.attention.backends.pallas import PallasAttentionBackend
return PallasAttentionBackend
elif backend == _Backend.PALLAS_VLLM_V1:
logger.info("Using Pallas backend.")
from vllm.v1.attention.backends.pallas import PallasAttentionBackendV1
return PallasAttentionBackendV1
elif backend == _Backend.NO_ATTENTION:
from vllm.attention.backends.placeholder_attn import (
PlaceholderAttentionBackend)
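The new branch mirrors the existing PALLAS case and keeps its import lazy, so torch_xla/Pallas dependencies are only loaded when that backend is actually chosen. A condensed sketch of just this dispatch pattern, with select_pallas_backend as a hypothetical stand-in for the real selector:

from vllm.platforms.interface import _Backend


def select_pallas_backend(backend: _Backend):
    # Keep imports inside the branches so TPU-only dependencies stay optional.
    if backend == _Backend.PALLAS:
        from vllm.attention.backends.pallas import PallasAttentionBackend
        return PallasAttentionBackend
    if backend == _Backend.PALLAS_VLLM_V1:
        from vllm.v1.attention.backends.pallas import PallasAttentionBackendV1
        return PallasAttentionBackendV1
    raise ValueError(f"Unsupported attention backend: {backend}")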
1 change: 1 addition & 0 deletions vllm/platforms/interface.py
@@ -32,6 +32,7 @@ class _Backend(enum.Enum):
    FLASHINFER = enum.auto()
    HPU_ATTN = enum.auto()
    PALLAS = enum.auto()
+    PALLAS_VLLM_V1 = enum.auto()
    IPEX = enum.auto()
    NO_ATTENTION = enum.auto()

11 changes: 9 additions & 2 deletions vllm/platforms/tpu.py
@@ -55,8 +55,9 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
        if compilation_config.level == CompilationLevel.NO_COMPILATION:
            # TPU does not support NO_COMPILATION
            compilation_config.level = CompilationLevel.DYNAMO_ONCE
-        assert compilation_config.level < CompilationLevel.PIECEWISE,\
-            "TPU does not support Inductor."
+        compilation_config.level = 2
+        # assert compilation_config.level < CompilationLevel.PIECEWISE,\
+        #     "TPU does not support Inductor. compilation_config.level = {}".format(compilation_config.level)

        if compilation_config.backend == "":
            compilation_config.backend = "openxla"
@@ -72,3 +73,9 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
"vllm.worker.multi_step_tpu_worker.MultiStepTPUWorker"
else:
parallel_config.worker_cls = "vllm.worker.tpu_worker.TPUWorker"

@classmethod
def is_pin_memory_available(cls):
# TODO: Verify if it is indeed the case
logger.warning("Pin memory is not supported on TPU.")
return False
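One small note on the first hunk above: in current vLLM, level 2 corresponds to CompilationLevel.DYNAMO_ONCE, so the named constant would be self-documenting where the magic number is used. A sketch, assuming CompilationConfig and CompilationLevel are importable from vllm.config:

from vllm.config import CompilationConfig, CompilationLevel

compilation_config = CompilationConfig()
# Equivalent to `compilation_config.level = 2`, but self-documenting.
compilation_config.level = CompilationLevel.DYNAMO_ONCE
assert compilation_config.level == CompilationLevel.DYNAMO_ONCE == 2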