Sync from v0.13

This commit is contained in:
2026-01-19 10:38:50 +08:00
parent b2ef04d792
commit 5aef6c175a
3714 changed files with 854317 additions and 89342 deletions

0
tests/v1/tpu/__init__.py Normal file
View File

177
tests/v1/tpu/test_basic.py Normal file
View File

@@ -0,0 +1,177 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""A basic correctness check for TPUs
Run `pytest tests/v1/tpu/test_basic.py`.
"""
from typing import TYPE_CHECKING
import pytest
from torch_xla._internal import tpu
from vllm.platforms import current_platform
if TYPE_CHECKING:
from tests.conftest import VllmRunner
else:
VllmRunner = object
MODELS = [
"Qwen/Qwen2.5-1.5B-Instruct",
# TODO: Enable this model when fixed.
# "Qwen/Qwen1.5-MoE-A2.7B",
# TODO: Enable this models with v6e
# "Qwen/Qwen2-7B-Instruct",
# "meta-llama/Llama-3.1-8B",
]
TENSOR_PARALLEL_SIZES = [1]
MAX_NUM_REQS = [16, 1024]
# TODO: Enable when CI/CD will have a multi-tpu instance
# TENSOR_PARALLEL_SIZES = [1, 4]
@pytest.mark.skipif(
not current_platform.is_tpu(), reason="This is a basic test for TPU only"
)
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("max_tokens", [5])
@pytest.mark.parametrize("tensor_parallel_size", TENSOR_PARALLEL_SIZES)
@pytest.mark.parametrize("max_num_seqs", MAX_NUM_REQS)
def test_basic(
vllm_runner: type[VllmRunner],
model: str,
max_tokens: int,
tensor_parallel_size: int,
max_num_seqs: int,
) -> None:
prompt = (
"The next numbers of the sequence "
+ ", ".join(str(i) for i in range(1024))
+ " are:"
)
example_prompts = [prompt]
with vllm_runner(
model,
# Note: max_num_batched_tokens == 1024 is needed here to
# actually test chunked prompt
max_num_batched_tokens=1024,
max_model_len=8192,
gpu_memory_utilization=0.7,
max_num_seqs=max_num_seqs,
tensor_parallel_size=tensor_parallel_size,
) as vllm_model:
vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens)
output = vllm_outputs[0][1]
assert "1024" in output or "0, 1" in output
@pytest.mark.skip(reason="Temporarily disabled due to timeout")
@pytest.mark.skipif(
not current_platform.is_tpu(), reason="This is a basic test for TPU only"
)
@pytest.mark.parametrize("max_tokens", [8])
@pytest.mark.parametrize("max_num_seqs", [16])
def test_phi3(
vllm_runner: type[VllmRunner],
max_tokens: int,
max_num_seqs: int,
) -> None:
prompts = [
"A robot may not injure a human being",
"It is only with the heart that one can see rightly;",
"The greatest glory in living lies not in never falling,",
]
answers = [
" or, by violating privacy",
" what is essential is love.",
" but in rising every time we fall.",
]
# test head dim = 96
model = "microsoft/Phi-3-mini-128k-instruct"
with vllm_runner(
model, max_num_batched_tokens=256, max_num_seqs=max_num_seqs
) as vllm_model:
vllm_outputs = vllm_model.generate_greedy(prompts, max_tokens)
# vllm_outputs is a list of tuples whose first element is the token id
# and the second element is the output (including the prompt).
for output, answer in zip(vllm_outputs, answers):
generated_text = output[1]
assert answer in generated_text
TP_SIZE_8 = 8
@pytest.mark.skipif(not current_platform.is_tpu(), reason="This is a test for TPU only")
@pytest.mark.skipif(
tpu.num_available_chips() < TP_SIZE_8,
reason=f"This test requires {TP_SIZE_8} TPU chips.",
)
def test_gemma3_27b_with_text_input_and_tp(
vllm_runner: type[VllmRunner],
) -> None:
model = "google/gemma-3-27b-it"
max_tokens = 16
tensor_parallel_size = TP_SIZE_8
max_num_seqs = 4
prompts = [
"A robot may not injure a human being",
"It is only with the heart that one can see rightly;",
"The greatest glory in living lies not in never falling,",
]
answers = [
" or, through inaction, allow a human being to come to harm.",
" what is essential is invisible to the eye.",
" but in rising every time we fall.",
]
with vllm_runner(
model,
max_num_batched_tokens=256,
max_num_seqs=max_num_seqs,
tensor_parallel_size=tensor_parallel_size,
) as vllm_model:
vllm_outputs = vllm_model.generate_greedy(prompts, max_tokens)
# vllm_outputs is a list of tuples whose first element is the token id
# and the second element is the output (including the prompt).
for output, answer in zip(vllm_outputs, answers):
generated_text = output[1]
assert answer in generated_text
@pytest.mark.skipif(
not current_platform.is_tpu(), reason="This is a basic test for TPU only"
)
def test_w8a8_quantization(
vllm_runner: type[VllmRunner],
) -> None:
model = "neuralmagic/Meta-Llama-3.1-8B-Instruct-quantized.w8a8"
max_tokens = 5
tensor_parallel_size = 1
max_num_seqs = 4
prompt = (
"The next numbers of the sequence "
+ ", ".join(str(i) for i in range(1024))
+ " are:"
)
example_prompts = [prompt]
with vllm_runner(
model,
max_num_batched_tokens=64,
max_model_len=4096,
gpu_memory_utilization=0.7,
max_num_seqs=max_num_seqs,
tensor_parallel_size=tensor_parallel_size,
) as vllm_model:
vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens)
output = vllm_outputs[0][1]
assert "1024" in output or "0, 1" in output

View File

@@ -0,0 +1,78 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import numpy as np
import pytest
import torch
import torch_xla
import vllm.v1.attention.backends.pallas # noqa: F401
from vllm.platforms import current_platform
@pytest.mark.skipif(not current_platform.is_tpu(), reason="This is a test for TPU only")
@pytest.mark.parametrize("page_size", [32, 33])
@pytest.mark.parametrize("combined_kv_head_num", [2, 16])
@pytest.mark.parametrize("head_dim", [128, 256])
@pytest.mark.parametrize("num_slices_per_block", [4, 8])
def test_kv_cache_update_kernel(
page_size: int, combined_kv_head_num: int, head_dim: int, num_slices_per_block: int
):
page_num = 1000
padded_num_tokens = 128
kv_cache_cpu = torch.zeros(
(page_num * page_size, combined_kv_head_num, head_dim),
dtype=torch.bfloat16,
device="cpu",
)
kv_cache_xla = kv_cache_cpu.to(torch_xla.device())
new_kv_cpu = torch.randn(
(padded_num_tokens, combined_kv_head_num, head_dim),
dtype=torch.bfloat16,
device="cpu",
)
new_kv_xla = new_kv_cpu.to(torch_xla.device())
slice_lens = np.array([7, page_size, page_size, 1, 1, 1, 9], dtype=np.int32)
num_kv_update_slices = len(slice_lens)
kv_cache_start_indices = np.array(
[
page_size * 2 - 7,
page_size * 2,
page_size * 3,
page_size * 4 + 6,
page_size * 5 + 7,
page_size * 6 + 8,
page_size * 15 + 3,
],
dtype=np.int32,
)
new_kv_cache_indices = np.concatenate(
[np.array([0], dtype=np.int32), np.cumsum(slice_lens[:-1])]
)
slot_mapping = np.stack(
[kv_cache_start_indices, new_kv_cache_indices, slice_lens], axis=1
)
slot_mapping = np.transpose(slot_mapping)
slot_mapping_cpu = torch.tensor(slot_mapping, device="cpu", dtype=torch.int32)
slot_mapping_xla = slot_mapping_cpu.to(torch_xla.device())
num_kv_update_slices_xla = torch.tensor(
[num_kv_update_slices], device=torch_xla.device(), dtype=torch.int32
)
torch_xla.sync()
torch.ops.xla.dynamo_set_buffer_donor_(kv_cache_xla, True)
new_kv_cache_xla = torch.ops.xla.kv_cache_update_op(
new_kv_xla,
slot_mapping_xla,
kv_cache_xla,
num_kv_update_slices_xla,
page_size,
num_slices_per_block,
)
kv_cache_xla.copy_(new_kv_cache_xla)
torch_xla.sync()
for ni, ci, sl in zip(new_kv_cache_indices, kv_cache_start_indices, slice_lens):
kv_cache_cpu[ci : ci + sl, :, :] = new_kv_cpu[ni : ni + sl, :, :]
assert torch.allclose(kv_cache_xla.cpu(), kv_cache_cpu, atol=1e-4, rtol=1e-4)

View File

@@ -0,0 +1,94 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Test:
* Tests for MultiHeadAttention layer
"""
import pytest
import torch
import torch_xla
import torch_xla.core
import torch_xla.core.xla_model
from vllm.attention.layer import MultiHeadAttention
from vllm.attention.selector import _cached_get_attn_backend
from vllm.platforms import current_platform
@pytest.fixture(autouse=True)
def clear_cache():
"""Clear lru cache to ensure each test case runs without caching."""
_cached_get_attn_backend.cache_clear()
def ref_attention(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
scale: float,
) -> torch.Tensor:
"""
Native implementation of scaled dot product attention without mask:
- query, key, value: [batch_size, seq_len, num_heads, head_size]
- attn_mask: [batch_size, seq_len, seq_len]
"""
query, key, value = (x.transpose(1, 2) for x in (query, key, value))
attn_weights = scale * torch.matmul(query, key.transpose(2, 3))
attn_weights = torch.softmax(attn_weights, dim=-1).to(value.dtype)
out = torch.matmul(attn_weights, value).transpose(1, 2)
return out
BATCH_SIZES = [1, 16]
SEQ_LENS = [1]
NUM_HEADS = [1, 16]
NUM_KV_HEADS = [1]
HEAD_SIZES = [64, 80]
@pytest.mark.skipif(not current_platform.is_tpu(), reason="This test needs a TPU")
@pytest.mark.parametrize("batch_size", BATCH_SIZES)
@pytest.mark.parametrize("seq_len", SEQ_LENS)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("num_kv_heads", NUM_KV_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("device", [torch_xla.core.xla_model.xla_device()])
def test_mha_attn_forward(
batch_size: int,
seq_len: int,
num_heads: int,
num_kv_heads: int,
head_size: int,
device: str,
):
current_platform.seed_everything(0)
# These are expected to be f32
q = torch.randn(batch_size, seq_len, num_heads * head_size, device=device)
k = torch.randn(batch_size, seq_len, num_kv_heads * head_size, device=device)
v = torch.randn(batch_size, seq_len, num_kv_heads * head_size, device=device)
scale = 1.0 / head_size**0.5
attn = MultiHeadAttention(
num_heads, head_size, scale=scale, num_kv_heads=num_kv_heads
)
output = attn(q, k, v)
assert num_heads % num_kv_heads == 0
num_queries_per_kv = num_heads // num_kv_heads
q = q.reshape(batch_size, seq_len, num_heads, head_size)
k = k.reshape(batch_size, seq_len, num_kv_heads, head_size)
v = v.reshape(batch_size, seq_len, num_kv_heads, head_size)
if num_queries_per_kv > 1:
k = torch.repeat_interleave(k, num_queries_per_kv, dim=2)
v = torch.repeat_interleave(v, num_queries_per_kv, dim=2)
ref_output = ref_attention(
q,
k,
v,
scale=scale,
).reshape(batch_size, seq_len, num_heads * head_size)
# torch_xla flash_attn kernel is less accurate but much faster
torch.testing.assert_close(output, ref_output, atol=1e-2, rtol=1e-3)

View File

@@ -0,0 +1,81 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import openai
import pytest
from vllm.multimodal.utils import encode_image_base64
from vllm.platforms import current_platform
from ...entrypoints.openai.test_vision import TEST_IMAGE_ASSETS
from ...utils import RemoteOpenAIServer
@pytest.fixture(scope="session")
def base64_encoded_image(local_asset_server) -> dict[str, str]:
return {
image_asset: encode_image_base64(
local_asset_server.get_image_asset(image_asset)
)
for image_asset in TEST_IMAGE_ASSETS
}
@pytest.mark.asyncio
@pytest.mark.skipif(not current_platform.is_tpu(), reason="This test needs a TPU")
@pytest.mark.parametrize("model_name", ["llava-hf/llava-1.5-7b-hf"])
async def test_basic_vision(model_name: str, base64_encoded_image: dict[str, str]):
pytest.skip("Skip this test until it's fixed.")
def whats_in_this_image_msg(b64):
return [
{
"role": "user",
"content": [
{"type": "text", "text": "What's in this image?"},
{
"type": "image_url",
"image_url": {"url": f"data:image/jpeg;base64,{b64}"},
},
],
}
]
server_args = [
"--max-model-len",
"1024",
"--max-num-seqs",
"16",
"--gpu-memory-utilization",
"0.95",
"--trust-remote-code",
"--max-num-batched-tokens",
"576",
# NOTE: max-num-batched-tokens>=mm_item_size
"--disable_chunked_mm_input",
]
# Server will pre-compile on first startup (takes a long time).
with RemoteOpenAIServer(
model_name, server_args, max_wait_seconds=600
) as remote_server:
client: openai.AsyncOpenAI = remote_server.get_async_client()
# Other requests now should be much faster
for image_url in TEST_IMAGE_ASSETS:
image_base64 = base64_encoded_image[image_url]
chat_completion_from_base64 = await client.chat.completions.create(
model=model_name,
messages=whats_in_this_image_msg(image_base64),
max_completion_tokens=24,
temperature=0.0,
)
result = chat_completion_from_base64
assert result
choice = result.choices[0]
assert choice.finish_reason == "length"
message = choice.message
message = result.choices[0].message
assert message.content is not None and len(message.content) >= 10
assert message.role == "assistant"

100
tests/v1/tpu/test_pallas.py Normal file
View File

@@ -0,0 +1,100 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from unittest.mock import ANY, patch
import torch
from vllm.attention.backends.abstract import AttentionType
from vllm.v1.attention.backends.pallas import PallasAttentionBackendImpl, PallasMetadata
def test_ragged_paged_attention():
# We verify that the kernel inputs such as sliding_window, etc. are passed
# in from the model correctly.
# The correctness of the paged attention kernel is tested in the kernel
# library.
num_heads = 4
head_size = 128
scale = 1.0
num_kv_heads = 4
sliding_window = 128
logits_soft_cap = 50.0
attn_impl = PallasAttentionBackendImpl(
num_heads=num_heads,
head_size=head_size,
scale=scale,
num_kv_heads=num_kv_heads,
alibi_slopes=None,
sliding_window=sliding_window,
kv_cache_dtype="auto",
logits_soft_cap=logits_soft_cap,
attn_type=AttentionType.DECODER,
)
class FakeAttentionLayer:
_q_scale_float: float
_k_scale_float: float
_v_scale_float: float
layer = FakeAttentionLayer()
layer._q_scale_float = 1.0
layer._k_scale_float = 1.0
layer._v_scale_float = 1.0
num_tokens = 16
num_blocks = 1024
block_size = 16
query = torch.zeros(num_tokens, num_heads * head_size)
key = torch.zeros(num_tokens, num_kv_heads * head_size)
value = torch.zeros(num_tokens, num_kv_heads * head_size)
kv_cache = torch.zeros(num_blocks, block_size, num_kv_heads * 2, head_size)
slot_mapping = torch.zeros((3, num_tokens), dtype=torch.int64)
max_num_reqs = 8
max_num_blocks_per_req = 8
num_kv_update_slices = torch.tensor([num_tokens], dtype=torch.int32)
block_tables = torch.zeros(
(max_num_reqs, max_num_blocks_per_req), dtype=torch.int32
)
context_lens = torch.ones((max_num_reqs,), dtype=torch.int32)
query_lens = [1] * max_num_reqs
query_start_loc = torch.cumsum(
torch.tensor([0] + query_lens, dtype=torch.int32), dim=0, dtype=torch.int32
)
num_seqs = torch.tensor([max_num_reqs], dtype=torch.int32)
attn_metadata = PallasMetadata(
slot_mapping=slot_mapping,
block_tables=block_tables,
context_lens=context_lens,
query_start_loc=query_start_loc,
num_seqs=num_seqs,
num_kv_update_slices=num_kv_update_slices,
num_slices_per_kv_cache_update_block=8,
)
with patch("torch.ops.xla.ragged_paged_attention") as mock_ragged_paged_attention:
attn_impl.forward(
layer=layer,
query=query,
key=key,
value=value,
kv_cache=kv_cache,
attn_metadata=attn_metadata,
)
mock_ragged_paged_attention.assert_called_once_with(
ANY, # query
ANY, # kv_cache
ANY, # context_lens
ANY, # block_tables
ANY, # query_start_loc
ANY, # num_seqs
num_kv_pages_per_block=None,
num_queries_per_block=None,
vmem_limit_bytes=None,
use_kernel=True,
sm_scale=scale,
sliding_window=sliding_window,
soft_cap=logits_soft_cap,
k_scale=1.0,
v_scale=1.0,
)

150
tests/v1/tpu/test_perf.py Normal file
View File

@@ -0,0 +1,150 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""A basic performance regression test for TPUs
Run `pytest tests/v1/tpu/test_perf.py`.
"""
import time
from dataclasses import dataclass
from typing import TYPE_CHECKING
import numpy as np
import pytest
from vllm.platforms import current_platform
from vllm.sampling_params import SamplingParams
from vllm.tokenizers import get_tokenizer
if TYPE_CHECKING:
from tests.conftest import VllmRunner
else:
VllmRunner = object
@dataclass
class TestParams:
model: str
num_prompts: int
prefix_len: int
decode_len: int
expected_avg_time: float
err_tol: float
TEST_PARAMS = [
# TODO: Cannot run a series of tests because:
# RuntimeError: Bad StatusOr access: UNKNOWN: TPU initialization failed:
# open(/dev/vfio/0): Device or resource busy: Device or resource busy;
# Couldn't open iommu group /dev/vfio/0
# => Investigate
# TestParams(
# model="Qwen/Qwen2.5-1.5B-Instruct",
# num_prompts=1,
# prefix_len=10,
# decode_len=5,
# expected_avg_time=0.03,
# err_tol=0.01,
# ),
# TestParams(
# model="Qwen/Qwen2.5-1.5B-Instruct",
# num_prompts=10,
# prefix_len=100,
# decode_len=50,
# expected_avg_time=0.234,
# err_tol=0.020,
# ),
TestParams(
model="Qwen/Qwen2.5-1.5B-Instruct",
num_prompts=64,
prefix_len=500,
decode_len=50,
# commit id: ccb246776d93ef105904a8ec015b3587240a1183
# tpu: v5lite (old vllm CI/CD)
# expected_avg_time=1.4,
# err_tol=0.30,
# (This is the active CI/CD instance)
# commit id: ccb246776d93ef105904a8ec015b3587240a1183
# tpu: v6e (current vllm CI/CD)
expected_avg_time=1.7, # measured with VLLM_XLA_CACHE_PATH=
err_tol=0.20,
),
]
NUM_WARMUPS = 5
NUM_RUNS = 10
MAX_MODEL_LEN = 1024
MAX_NUM_SEQS = 32
GPU_UTIL = 0.9
@pytest.mark.skipif(
not current_platform.is_tpu(),
reason="This is a basic performance test for TPU only",
)
@pytest.mark.parametrize("params", TEST_PARAMS)
def test_perf(
vllm_runner: type[VllmRunner],
params: TestParams,
) -> None:
tokenizer = get_tokenizer(
params.model, tokenizer_mode="auto", trust_remote_code=True
)
prompts = []
for i in range(params.num_prompts):
prefix_token_ids = np.random.randint(
0, tokenizer.vocab_size, size=params.prefix_len
).tolist()
prompt = tokenizer.decode(prefix_token_ids)
prompts.append(prompt)
print(
"-- Running: num_prompts = {} prefix_len = {} decode_len = {}".format(
len(prompts), params.prefix_len, params.decode_len
)
)
sampling_params = SamplingParams(
max_tokens=params.decode_len, temperature=1.0, min_p=0.0
)
with vllm_runner(
params.model,
max_num_batched_tokens=MAX_MODEL_LEN,
max_model_len=MAX_MODEL_LEN,
max_num_seqs=MAX_NUM_SEQS,
gpu_memory_utilization=GPU_UTIL,
enforce_eager=False,
tensor_parallel_size=1,
) as vllm_model:
print(" -- Warmup / Compile")
for i in range(NUM_WARMUPS):
_ = vllm_model.generate(prompts, sampling_params)
print(" -- Benchmarking... ")
times = []
for i in range(NUM_RUNS):
start_time = time.time()
_ = vllm_model.generate(prompts, sampling_params)
times.append(time.time() - start_time)
avg_time = sum(times) / len(times)
print(" -- avg_time = {}".format(avg_time))
print(
" -- expected_avg_time = {} with err_tol = {}".format(
params.expected_avg_time, params.err_tol
)
)
diff = avg_time - params.expected_avg_time
ok = diff < params.err_tol
if diff < -params.err_tol:
print(
" !! WARNING !! Performance has improved by {}, "
"it may be necessary to fine-tune the "
"expected_avg_time = {}".format(-diff, params.expected_avg_time)
)
assert ok, " !! ERROR !! Regression detected"

View File

@@ -0,0 +1,105 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import random
import pytest
from vllm import LLM
from vllm.platforms import current_platform
from vllm.sampling_params import SamplingParams
@pytest.mark.parametrize("model_name", ["Qwen/Qwen2.5-1.5B-Instruct"])
@pytest.mark.skipif(not current_platform.is_tpu(), reason="This test needs a TPU")
def test_sampler_different(model_name: str):
"""
Test significantly different sampling params to assert the model produces
different results.
"""
llm = LLM(
model_name,
enforce_eager=False,
max_num_seqs=1,
max_model_len=512,
max_num_batched_tokens=256,
)
prompts = ["Write a short story about a robot that dreams for the first time."]
sampling_params = SamplingParams(temperature=0.9, min_p=0.2, max_tokens=64)
output = llm.generate(prompts, sampling_params)
sampling_params = SamplingParams(temperature=0.1, min_p=0.8, max_tokens=64)
output2 = llm.generate(prompts, sampling_params)
assert output[0].outputs[0].text != output2[0].outputs[0].text
with pytest.raises(ValueError):
# Unsupported `seed` param.
sampling_params = SamplingParams(temperature=0.3, seed=42)
output2 = llm.generate(prompts, sampling_params)
# Batch-case with TopK/P
for B in [4, 16]:
p = prompts * B
sampling_params = [
SamplingParams(
temperature=0.1,
min_p=0.8,
max_tokens=64,
# Vary number of ks
top_k=random.randint(4, 12),
top_p=random.random(),
)
for _ in range(B)
]
# Make sure first two reqs have the same K/P
sampling_params[0] = sampling_params[1]
output = llm.generate(p, sampling_params)
# There are natural numerical instabilities that make it difficult
# to have deterministic results over many tokens, tests the first ~20
# tokens match.
assert output[0].outputs[0].text[:20] == output[1].outputs[0].text[:20]
@pytest.mark.parametrize("model_name", ["Qwen/Qwen2.5-1.5B-Instruct"])
# TODO TPU will appear busy if we fan-out test params here
@pytest.mark.parametrize("n_prompts", [1])
@pytest.mark.skipif(not current_platform.is_tpu(), reason="This test needs a TPU")
def test_logprobs(model_name: str, n_prompts: int):
"""
Request top logprobs with different sampling settings and check
that results contains the requested number, ordered ascendingly.
"""
def check_num_logprobs(logprobs, expected_num: int):
for step in logprobs:
prev_logp = 1.0
# order by rank
sorted_step = dict(sorted(step.items(), key=lambda item: item[1].rank))
# Can contain the sampled token
assert len(step) == expected_num or len(step) == expected_num + 1
# Check results are ordered by prob value
for rankno, (tid, logp) in enumerate(sorted_step.items()):
assert logp.logprob <= prev_logp
prev_logp = logp.logprob
assert logp.rank == rankno + 1
llm = LLM(
model_name,
enforce_eager=False,
max_num_seqs=1,
max_model_len=128,
max_num_batched_tokens=128,
)
prompts = [
"Write a short story about a robot that dreams for the first time."
] * n_prompts
greedy_sampling_params = SamplingParams(temperature=0.0, max_tokens=64, logprobs=4)
regular_sampling_params = SamplingParams(temperature=0.4, max_tokens=64, logprobs=4)
topkp_sampling_params = SamplingParams(
temperature=0.4, max_tokens=64, logprobs=4, top_k=12, top_p=0.5
)
for sp in [greedy_sampling_params, regular_sampling_params, topkp_sampling_params]:
output = llm.generate(prompts, sp)
for o in output:
check_num_logprobs(o.outputs[0].logprobs, 4)

View File

@@ -0,0 +1,78 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import gc
import tempfile
import numpy as np
import pytest
import torch_xla.distributed.spmd as xs
import torch_xla.runtime as xr
from vllm.config import set_current_vllm_config
from vllm.distributed.parallel_state import (
ensure_model_parallel_initialized,
init_distributed_environment,
)
from vllm.engine.arg_utils import EngineArgs
from vllm.model_executor.model_loader.tpu import TPUModelLoader
def _setup_environment(model):
engine_args = EngineArgs(
model=model,
)
vllm_config = engine_args.create_engine_config()
with set_current_vllm_config(vllm_config):
temp_file = tempfile.mkstemp()[1]
init_distributed_environment(
1,
0,
local_rank=0,
distributed_init_method=f"file://{temp_file}",
backend="gloo",
)
# Under single worker mode, full model is init first and then
# partitioned using GSPMD.
ensure_model_parallel_initialized(1, 1)
return vllm_config
MESH = None
def _get_spmd_mesh():
global MESH
if MESH is None:
xr.use_spmd()
num_devices = xr.global_runtime_device_count()
mesh_shape = (num_devices, 1)
device_ids = np.array(range(num_devices))
MESH = xs.Mesh(device_ids, mesh_shape, ("x", "y"))
return MESH
@pytest.mark.parametrize(
"model",
[
"Qwen/Qwen2-1.5B-Instruct",
# Skip large models due to CI runner disk space limitations
# "meta-llama/Llama-3.1-8B-Instruct",
# "meta-llama/Llama-3.1-70B-Instruct",
],
)
def test_tpu_model_loader(model):
# Skip the 70B test if there are less than 8 chips
# TODO: Query using torch xla API, the query API is not working
# with SPMD now. However, This test is running under SPMD mode.
if "70B" in model and xr.global_runtime_device_count() < 8:
pytest.skip(
"Skipping 70B model if the TPU VM has less than 8 chips to \
avoid OOM."
)
vllm_config = _setup_environment(model)
loader = TPUModelLoader(load_config=vllm_config.load_config)
mesh = _get_spmd_mesh()
model = loader.load_model(vllm_config, vllm_config.model_config, mesh)
del model
gc.collect()

View File

@@ -0,0 +1,149 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import math
import pytest
import torch
import torch_xla
from vllm.platforms import current_platform
from vllm.v1.sample.ops.topk_topp_sampler import apply_top_k_top_p
from vllm.v1.sample.tpu.sampler import apply_top_k_top_p as apply_top_k_top_p_tpu
if not current_platform.is_tpu():
pytest.skip("This test needs a TPU.", allow_module_level=True)
import torch_xla.core.xla_model as xm
BATCH_SIZE = 1024
VOCAB_SIZE = 128 * 1024
TOLERANCE = 1e-6
def test_topk_equivalence_to_native_impl():
with torch.device(xm.xla_device()):
xm.set_rng_state(seed=33)
logits = torch.rand((BATCH_SIZE, VOCAB_SIZE))
# Random top-k values between 1 and 10.
k = torch.randint(1, 10, (BATCH_SIZE,))
# Set k=vocab_size for ~50% of requests in the batch (top-k disabled).
k.masked_fill_(torch.randint(0, 2, (BATCH_SIZE,), dtype=bool), VOCAB_SIZE)
result_tpu = apply_top_k_top_p_tpu(logits=logits.clone(), k=k, p=None)
result_native = apply_top_k_top_p(logits=logits.clone(), k=k, p=None)
assert torch.allclose(result_native, result_tpu)
def test_topp_result_sums_past_p():
with torch.device(xm.xla_device()):
xm.set_rng_state(seed=33)
logits = torch.rand((BATCH_SIZE, VOCAB_SIZE))
probs = logits.softmax(dim=-1)
# Random top-p values between 0 and 1.
p = torch.rand((BATCH_SIZE,))
# Set p=1 for ~50% of requests in the batch (top-p disabled).
p.masked_fill_(torch.randint(0, 2, (BATCH_SIZE,), dtype=bool), 1)
no_op_k = torch.tensor([VOCAB_SIZE])
logits_masked = apply_top_k_top_p_tpu(logits=logits.clone(), k=no_op_k, p=p)
# Verify that the masked logit's probability sums to at least p.
probs.masked_fill_(logits_masked.isinf(), 0)
masked_prob_sum = probs.sum(dim=-1)
torch_xla.sync()
# Perform assertion on CPU.
assert torch.all(torch.ge(masked_prob_sum.cpu() + TOLERANCE, p.cpu()))
def test_topp_basic():
with torch.device(xm.xla_device()):
logits = torch.tensor(
[
[math.log(0.2), math.log(0.3), math.log(0.5)],
[math.log(0.5), math.log(0.1), math.log(0.4)],
]
)
result = apply_top_k_top_p_tpu(
logits=logits.clone(), k=torch.tensor([3, 3]), p=torch.tensor([0.79, 0.79])
)
torch_xla.sync()
# Expect the smallest elements to be dropped.
expected_result = logits.clone().cpu()
expected_result[0, 0] = float("-inf")
expected_result[1, 1] = float("-inf")
assert torch.allclose(expected_result, result.cpu())
def test_topp_select_all():
with torch.device(xm.xla_device()):
logits = torch.tensor(
[
[math.log(0.2), math.log(0.3), math.log(0.5)],
[math.log(0.5), math.log(0.1), math.log(0.4)],
]
)
result = apply_top_k_top_p_tpu(
logits=logits.clone(), k=torch.tensor([3, 3]), p=torch.tensor([1.0, 1.0])
)
torch_xla.sync()
assert torch.allclose(logits.cpu(), result.cpu())
def test_topp_with_ties():
with torch.device(xm.xla_device()):
# Input has multiple math.log(0.3).
logits = torch.tensor(
[[math.log(0.3), math.log(0.3), math.log(0.3), math.log(0.1)]]
)
result = apply_top_k_top_p_tpu(
logits=logits.clone(), k=torch.tensor([4]), p=torch.tensor([0.2])
)
torch_xla.sync()
# All tie values are included in the top-p set. Tie breaking is left
# to be done during final sampling (all tie tokens have equal
# probability of being chosen).
expected_result = logits.clone().cpu()
expected_result[0, 3] = float("-inf")
assert torch.allclose(expected_result, result.cpu())
def test_both_topk_topp():
with torch.device(xm.xla_device()):
logits = torch.tensor(
[
[math.log(0.2), math.log(0.3), math.log(0.5)],
[math.log(0.5), math.log(0.1), math.log(0.4)],
]
)
# Set k=1 for the first batch.
result = apply_top_k_top_p_tpu(
logits=logits.clone(), k=torch.tensor([1, 3]), p=torch.tensor([0.79, 0.79])
)
torch_xla.sync()
# Since for the first batch k=1, expect only the largest element gets
# selected.
expected_result = logits.clone().cpu()
expected_result[0, 0] = float("-inf")
expected_result[0, 1] = float("-inf")
expected_result[1, 1] = float("-inf")
assert torch.allclose(expected_result, result.cpu())

View File

@@ -0,0 +1,78 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Tests whether TPU Int8 computation is enabled correctly.
Run `pytest tests/quantization/test_tpu_int8.py`.
"""
import pytest
from vllm.model_executor.layers.linear import LinearBase
from vllm.model_executor.layers.quantization.tpu_int8 import TPUInt8LinearMethod
from vllm.platforms import current_platform
from ...models.registry import HF_EXAMPLE_MODELS
MODELS = ["Qwen/Qwen2.5-0.5B-Instruct"]
@pytest.mark.skipif(
not current_platform.is_tpu(), reason="TPU Int8 is only enabled for TPUs."
)
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["bfloat16"])
@pytest.mark.parametrize("max_tokens", [10])
@pytest.mark.parametrize(
"hf_overrides",
[
# w8a8 dynamic activation
{
"quantization_config": {
"quant_method": "tpu_int8",
"activation_scheme": "dynamic",
}
}
],
)
def test_model_tpu_int8(
vllm_runner,
model: str,
dtype: str,
max_tokens: int,
hf_overrides: dict,
monkeypatch,
) -> None:
model_info = HF_EXAMPLE_MODELS.find_hf_info(model)
model_info.check_transformers_version(on_fail="skip")
activation_scheme = hf_overrides.get("quantization_config", {}).get(
"activation_scheme"
)
quantize_activation = activation_scheme == "dynamic"
# Allows using apply_model
monkeypatch.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0")
# Prevent error from re-initializing cache
monkeypatch.setenv("VLLM_XLA_CACHE_PATH", "")
prompts = [
"A robot may not injure a human being",
]
answers = [
"or kill a human being",
]
with vllm_runner(model, dtype=dtype, hf_overrides=hf_overrides) as vllm:
def check_model(model):
for name, module in model.named_modules():
if not isinstance(module, LinearBase):
continue
quant_method = module.quant_method
assert isinstance(quant_method, TPUInt8LinearMethod)
assert quant_method.quantize_activation == quantize_activation
vllm.apply_model(check_model)
outputs = vllm.generate_greedy(prompts, max_tokens)
for (_, output), answer in zip(outputs, answers):
assert answer in output

View File

@@ -0,0 +1,93 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import tempfile
import numpy as np
import pytest
import torch
import torch_xla.distributed.spmd as xs
import torch_xla.runtime as xr
from vllm.config import set_current_vllm_config
from vllm.distributed.parallel_state import (
ensure_model_parallel_initialized,
init_distributed_environment,
)
from vllm.distributed.tpu_distributed_utils import XlaQKVParallelLinear
from vllm.engine.arg_utils import EngineArgs
from vllm.model_executor.layers.linear import QKVParallelLinear
@pytest.fixture(autouse=True)
def setup_environment():
# This is a fake config used for init dist env.
# QKVParallelLinear needs dist env to be initialized.
engine_args = EngineArgs(
model="Qwen/Qwen2-1.5B-Instruct",
max_model_len=64,
max_num_batched_tokens=64,
max_num_seqs=4,
)
vllm_config = engine_args.create_engine_config()
with set_current_vllm_config(vllm_config):
temp_file = tempfile.mkstemp()[1]
init_distributed_environment(
1,
0,
local_rank=0,
distributed_init_method=f"file://{temp_file}",
backend="gloo",
)
ensure_model_parallel_initialized(1, 1)
yield
MESH = None
def _get_spmd_mesh():
global MESH
if MESH is None:
xr.use_spmd()
num_devices = xr.global_runtime_device_count()
mesh_shape = (num_devices, 1)
device_ids = np.array(range(num_devices))
MESH = xs.Mesh(device_ids, mesh_shape, ("x", "y"))
return MESH
@pytest.mark.parametrize("bias", [False, True])
# `xr.use_spmd()` will set a global state, and this state is not reversible.
# Therefore, non-SPMD tests should be run before SPMD tests.
@pytest.mark.parametrize("mesh", [None, _get_spmd_mesh()])
@pytest.mark.parametrize("device", ["cpu", "xla"])
@torch.no_grad()
def test_xla_qkv_linear(bias, mesh, device):
torch.manual_seed(123)
qkv_linear = QKVParallelLinear(
hidden_size=4096,
head_size=128,
total_num_heads=32,
total_num_kv_heads=8,
bias=bias,
params_dtype=torch.bfloat16,
return_bias=False,
)
qkv_linear.weight.data = torch.rand_like(qkv_linear.weight.data) / 10
if bias:
qkv_linear.bias.data = torch.rand_like(qkv_linear.bias.data)
xla_qkv_linear = XlaQKVParallelLinear(qkv_linear, mesh=mesh)
qkv_linear = qkv_linear.to(device)
xla_qkv_linear = xla_qkv_linear.to(device)
input_tensor = torch.rand(10, 4096, dtype=torch.bfloat16) / 10
input_tensor = input_tensor.to(device)
output = qkv_linear(input_tensor)
xla_output = xla_qkv_linear(input_tensor)
assert torch.allclose(output.cpu(), xla_output.cpu())

View File

View File

@@ -0,0 +1,587 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest
from vllm.attention.layer import Attention
from vllm.config import (
CacheConfig,
ModelConfig,
SchedulerConfig,
VllmConfig,
set_current_vllm_config,
)
from vllm.pooling_params import PoolingParams
from vllm.sampling_params import SamplingParams
from vllm.utils.mem_constants import GiB_bytes
from vllm.v1.core.kv_cache_utils import estimate_max_model_len, get_kv_cache_configs
from vllm.v1.core.sched.output import CachedRequestData, NewRequestData, SchedulerOutput
from vllm.v1.worker.tpu_model_runner import (
TPUModelRunner,
_get_padded_num_reqs_with_upper_limit,
_get_padded_token_len,
_get_req_paddings,
_get_token_paddings,
)
def get_vllm_config():
model_config = ModelConfig(
model="facebook/opt-125m",
dtype="bfloat16", # TPUs typically use bfloat16
seed=42,
)
scheduler_config = SchedulerConfig(
max_num_seqs=10,
max_num_batched_tokens=512,
max_model_len=512,
is_encoder_decoder=model_config.is_encoder_decoder,
)
cache_config = CacheConfig(
block_size=16,
gpu_memory_utilization=0.9,
swap_space=0,
cache_dtype="auto",
)
vllm_config = VllmConfig(
model_config=model_config,
cache_config=cache_config,
scheduler_config=scheduler_config,
)
return vllm_config
def get_model_runner(vllm_config):
device = "xla:0" # Mocking TPU device
return TPUModelRunner(vllm_config, device)
@pytest.fixture
def model_runner():
# Patchers have already been started at module level.
vllm_config = get_vllm_config()
return get_model_runner(vllm_config)
def _schedule_new_request(*req_ids: str) -> SchedulerOutput:
new_reqs = []
num_scheduled_tokens = {}
total_num_scheduled_tokens = 0
for req_id in req_ids:
new_reqs.append(
NewRequestData(
req_id=req_id,
prompt_token_ids=[1, 2, 3],
mm_features=[],
sampling_params=SamplingParams(),
pooling_params=PoolingParams(),
block_ids=([0],), # block_ids should be tuple[list[int]]
num_computed_tokens=0,
lora_request=None,
)
)
num_scheduled_tokens[req_id] = 3
total_num_scheduled_tokens += num_scheduled_tokens[req_id]
return SchedulerOutput(
scheduled_new_reqs=new_reqs,
scheduled_cached_reqs=CachedRequestData.make_empty(),
num_scheduled_tokens=num_scheduled_tokens,
total_num_scheduled_tokens=total_num_scheduled_tokens,
scheduled_spec_decode_tokens={},
scheduled_encoder_inputs={},
num_common_prefix_blocks=[],
finished_req_ids=set(),
free_encoder_mm_hashes=[],
)
def _is_req_scheduled(model_runner, req_id: str) -> bool:
return req_id in model_runner.input_batch.req_id_to_index
def _is_req_added(model_runner, req_id: str) -> bool:
return req_id in model_runner.requests
def _is_req_state_block_table_match(model_runner, req_id: str) -> bool:
"""Check if the request state block IDs match the block table.
This function handles both legacy BlockTable and new MultiGroupBlockTable
structures for backward compatibility.
"""
req_index = model_runner.input_batch.req_id_to_index[req_id]
multi_group_block_table = model_runner.input_batch.block_table
req_state = model_runner.requests[req_id]
# Access the first block table from MultiGroupBlockTable
# This is safe since we currently only use single KV cache groups
block_table = multi_group_block_table[0]
# req_state.block_ids is now tuple[list[int], ...] for MultiGroupBlockTable
# Extract the first group's block IDs
if isinstance(req_state.block_ids[0], list):
# New format: tuple[list[int], ...] - extract first group
req_block_ids = req_state.block_ids[0]
else:
# Legacy format: list[int] - use directly
req_block_ids = req_state.block_ids
if block_table.num_blocks_per_row[req_index] != len(req_block_ids):
return False
num_blocks = block_table.num_blocks_per_row[req_index]
block_table_values = block_table.block_table.np[req_index, :num_blocks]
return (block_table_values == req_block_ids).all()
def test_update_states_new_request(model_runner):
req_id = "req_0"
# new req
scheduler_output = _schedule_new_request(req_id)
model_runner._update_states(scheduler_output)
assert _is_req_added(model_runner, req_id)
assert _is_req_scheduled(model_runner, req_id)
assert _is_req_state_block_table_match(model_runner, req_id)
def test_update_states_request_finished(model_runner):
req_id = "req_0"
# new req
scheduler_output = _schedule_new_request(req_id)
model_runner._update_states(scheduler_output)
assert _is_req_added(model_runner, req_id)
assert _is_req_scheduled(model_runner, req_id)
# finish req
scheduler_output = SchedulerOutput(
scheduled_new_reqs=[],
scheduled_cached_reqs=CachedRequestData.make_empty(),
num_scheduled_tokens={},
total_num_scheduled_tokens=0,
scheduled_spec_decode_tokens={},
scheduled_encoder_inputs={},
num_common_prefix_blocks=[],
finished_req_ids={req_id},
free_encoder_mm_hashes=[],
)
model_runner._update_states(scheduler_output)
assert not _is_req_added(model_runner, req_id)
assert not _is_req_scheduled(model_runner, req_id)
def test_update_states_request_resumed(model_runner):
req_id = "req_0"
# new req
scheduler_output = _schedule_new_request(req_id)
model_runner._update_states(scheduler_output)
assert _is_req_added(model_runner, req_id)
assert _is_req_scheduled(model_runner, req_id)
# unschedule req
scheduler_output = SchedulerOutput(
scheduled_new_reqs=[],
scheduled_cached_reqs=CachedRequestData.make_empty(),
num_scheduled_tokens={},
total_num_scheduled_tokens=0,
scheduled_spec_decode_tokens={},
scheduled_encoder_inputs={},
num_common_prefix_blocks=[],
finished_req_ids=set(),
free_encoder_mm_hashes=[],
)
model_runner._update_states(scheduler_output)
assert _is_req_added(model_runner, req_id)
assert not _is_req_scheduled(model_runner, req_id)
# resume req
cached_req_data = CachedRequestData(
req_ids=[req_id],
resumed_req_ids={req_id},
new_token_ids=[[]],
all_token_ids={req_id: scheduler_output.scheduled_new_reqs[0].prompt_token_ids},
new_block_ids=[([],)],
num_computed_tokens=[0],
num_output_tokens=[0],
)
scheduler_output = SchedulerOutput(
scheduled_new_reqs=[],
scheduled_cached_reqs=cached_req_data,
num_scheduled_tokens={req_id: 1},
total_num_scheduled_tokens=1,
scheduled_spec_decode_tokens={},
scheduled_encoder_inputs={},
num_common_prefix_blocks=[],
finished_req_ids=set(),
free_encoder_mm_hashes=[],
)
model_runner._update_states(scheduler_output)
assert _is_req_added(model_runner, req_id)
assert _is_req_scheduled(model_runner, req_id)
assert _is_req_state_block_table_match(model_runner, req_id)
def test_update_states_no_changes(model_runner):
req_id = "req_0"
# new req
scheduler_output = _schedule_new_request(req_id)
model_runner._update_states(scheduler_output)
assert _is_req_added(model_runner, req_id)
assert _is_req_scheduled(model_runner, req_id)
# schedule req
scheduler_output = SchedulerOutput(
scheduled_new_reqs=[],
scheduled_cached_reqs=CachedRequestData.make_empty(),
num_scheduled_tokens={req_id: 1},
total_num_scheduled_tokens=1,
scheduled_spec_decode_tokens={},
scheduled_encoder_inputs={},
num_common_prefix_blocks=[],
finished_req_ids=set(),
free_encoder_mm_hashes=[],
)
model_runner._update_states(scheduler_output)
assert _is_req_added(model_runner, req_id)
assert _is_req_scheduled(model_runner, req_id)
assert _is_req_state_block_table_match(model_runner, req_id)
def test_update_states_request_unscheduled(model_runner):
req_ids = ("req_0", "req_1")
# new reqs
scheduler_output = _schedule_new_request(*req_ids)
model_runner._update_states(scheduler_output)
assert _is_req_added(model_runner, req_ids[0])
assert _is_req_scheduled(model_runner, req_ids[0])
assert _is_req_added(model_runner, req_ids[1])
assert _is_req_scheduled(model_runner, req_ids[1])
# unschedule req_1
scheduler_output = SchedulerOutput(
scheduled_new_reqs=[],
scheduled_cached_reqs=CachedRequestData.make_empty(),
num_scheduled_tokens={req_ids[0]: 1},
total_num_scheduled_tokens=1,
scheduled_spec_decode_tokens={},
scheduled_encoder_inputs={},
num_common_prefix_blocks=[],
finished_req_ids=set(),
free_encoder_mm_hashes=[],
)
model_runner._update_states(scheduler_output)
assert _is_req_added(model_runner, req_ids[0])
assert _is_req_scheduled(model_runner, req_ids[0])
assert _is_req_added(model_runner, req_ids[1])
assert not _is_req_scheduled(model_runner, req_ids[1])
def test_get_paddings():
# Bucketed padding
min_token_size, max_token_size, padding_gap = 16, 512, 64
expected_paddings = [16, 32, 64, 128, 192, 256, 320, 384, 448, 512]
actual_paddings = _get_token_paddings(min_token_size, max_token_size, padding_gap)
# Bucketed padding with max_token_size not a power of two.
max_token_size = 317
expected_paddings = [16, 32, 64, 128, 192, 256, 320]
actual_paddings = _get_token_paddings(min_token_size, max_token_size, padding_gap)
assert actual_paddings == expected_paddings
# Exponential padding.
max_token_size, padding_gap = 1024, 0
expected_paddings = [16, 32, 64, 128, 256, 512, 1024]
actual_paddings = _get_token_paddings(min_token_size, max_token_size, padding_gap)
assert actual_paddings == expected_paddings
# Exponential padding with max_token_size not a power of two.
max_token_size = 317
expected_paddings = [16, 32, 64, 128, 256, 512]
actual_paddings = _get_token_paddings(min_token_size, max_token_size, padding_gap)
assert actual_paddings == expected_paddings
def test_get_padded_token_len():
min_token_size, max_token_size, padding_gap = 16, 512, 64
paddings = _get_token_paddings(min_token_size, max_token_size, padding_gap)
assert _get_padded_token_len(paddings, 1) == 16
assert _get_padded_token_len(paddings, 16) == 16
assert _get_padded_token_len(paddings, 20) == 32
assert _get_padded_token_len(paddings, 300) == 320
assert _get_padded_token_len(paddings, 512) == 512
def test_get_padded_num_reqs_with_upper_limit():
assert _get_padded_num_reqs_with_upper_limit(3, 32) == 8
assert _get_padded_num_reqs_with_upper_limit(9, 32) == 16
assert _get_padded_num_reqs_with_upper_limit(19, 32) == 32
assert _get_padded_num_reqs_with_upper_limit(17, 28) == 28
def test_get_req_paddings():
assert _get_req_paddings(1, 32) == [8, 16, 32]
assert _get_req_paddings(8, 32) == [8, 16, 32]
assert _get_req_paddings(8, 36) == [8, 16, 32, 36]
def test_init_kv_cache_with_kv_sharing_invalid_target_layer_order(model_runner):
layer_0 = "model.layers.0.self_attn.attn"
layer_1 = "model.layers.1.self_attn.attn"
error_msg = f"{layer_1} must come before the current layer"
vllm_config = model_runner.vllm_config
with (
pytest.raises(ValueError, match=error_msg),
set_current_vllm_config(vllm_config),
):
fwd_context = {
# initialization below will fail because target layer is invalid;
# the target layer needs to come before layer 1
layer_0: Attention(
num_heads=8,
head_size=128,
scale=1.0,
prefix=layer_0,
kv_sharing_target_layer_name=layer_1,
),
layer_1: Attention(
num_heads=8,
head_size=128,
scale=1.0,
prefix=layer_1,
),
}
# suppress var not used error
assert fwd_context is not None
def test_init_kv_cache_with_kv_sharing_target_layer_not_exist(model_runner):
layer_0 = "model.layers.0.self_attn.attn"
layer_1 = "model.layers.1.self_attn.attn"
invalid_layer = "model.layers.0.cross_attn.attn"
error_msg = f"{invalid_layer} is not a valid Attention layer in the model"
vllm_config = model_runner.vllm_config
with (
pytest.raises(ValueError, match=error_msg),
set_current_vllm_config(vllm_config),
):
fwd_context = {
layer_0: Attention(
num_heads=8,
head_size=128,
scale=1.0,
prefix=layer_0,
),
layer_1: Attention(
num_heads=8,
head_size=128,
scale=1.0,
prefix=layer_1,
# invalid layer: cross_attn.atn doesn't exist!
kv_sharing_target_layer_name=invalid_layer,
),
}
# suppress var not used error
assert fwd_context is not None
def test_init_kv_cache_with_kv_sharing_target_same_as_current(model_runner):
layer_0 = "model.layers.0.self_attn.attn"
layer_1 = "model.layers.1.self_attn.attn"
error_msg = f"{layer_1} cannot be the same as the current layer"
vllm_config = model_runner.vllm_config
with (
pytest.raises(ValueError, match=error_msg),
set_current_vllm_config(vllm_config),
):
fwd_context = {
# initialization below will fail because target layer is invalid;
# the target layer needs to come before layer 1
layer_0: Attention(
num_heads=8,
head_size=128,
scale=1.0,
prefix=layer_0,
),
layer_1: Attention(
num_heads=8,
head_size=128,
scale=1.0,
prefix=layer_1,
kv_sharing_target_layer_name=layer_1,
),
}
# suppress var not used error
assert fwd_context is not None
def test_init_kv_cache_without_kv_sharing():
layer_0 = "model.layers.0.self_attn.attn"
layer_1 = "model.layers.1.self_attn.attn"
vllm_config = get_vllm_config()
with set_current_vllm_config(vllm_config):
fwd_context = {
layer_0: Attention(
num_heads=8,
head_size=128,
scale=1.0,
prefix=layer_0,
),
layer_1: Attention(
num_heads=8,
head_size=128,
scale=1.0,
prefix=layer_1,
),
}
# suppress var not used error
assert fwd_context is not None
# Set high context length to test max context length estimation
vllm_config.model_config.max_model_len = 1_000_000
vllm_ctx = vllm_config.compilation_config.static_forward_context
model_runner = get_model_runner(vllm_config)
kv_cache_spec = model_runner.get_kv_cache_spec()
assert len(kv_cache_spec) == 2
assert len(model_runner.shared_kv_cache_layers) == 0
available_memory = 20 * GiB_bytes
# page size for each layer KV can be calculated as
# 2 (non-MLA) * 8 (num_heads) * 128 (head_dim)
# * 2 (bfloat16, kv_cache dtype) * 128 (block_size) = 512KB
num_expected_blocks = 20480 # 20GB / 512KB / 2 (num layers)
kv_cache_config = get_kv_cache_configs(
vllm_config, [kv_cache_spec], [available_memory]
)[0]
assert kv_cache_config.num_blocks == num_expected_blocks
assert len(kv_cache_config.kv_cache_tensors) == 2
assert kv_cache_config.kv_cache_tensors[0].size == available_memory // 2
assert kv_cache_config.kv_cache_tensors[1].size == available_memory // 2
max_context_len = estimate_max_model_len(vllm_config, kv_cache_spec, 5 * GiB_bytes)
# max context len with KV sharing should be 2x as large as without
# max_context_len = available_memory / (page_size / block_size) / num_caches
# max_context_len = 5GB / (512KB / 128) / 2 = 655360
assert max_context_len == 655360
# important: override tensor size to prevent large mem alloc during test
# this will only allocate 2 block worth of memory (2 * 512kb)
kv_cache_config.num_blocks = 1
for kv_cache_tensor in kv_cache_config.kv_cache_tensors:
kv_cache_tensor.size = kv_cache_spec[
kv_cache_tensor.shared_by[0]
].page_size_bytes
model_runner.initialize_kv_cache(kv_cache_config)
layer_0_kv = vllm_ctx[layer_0].kv_cache[0]
layer_1_kv = vllm_ctx[layer_1].kv_cache[0]
# check layer 1 kv cache does NOT share memory with layer 0
assert id(layer_1_kv) != id(layer_0_kv)
# check layer 1 added to kv cache group's layer names
assert len(kv_cache_config.kv_cache_groups) == 1
assert len(kv_cache_config.kv_cache_groups[0].layer_names) == 2
assert kv_cache_config.kv_cache_groups[0].layer_names[0] == layer_0
assert kv_cache_config.kv_cache_groups[0].layer_names[1] == layer_1
def test_init_kv_cache_with_kv_sharing_valid():
layer_0 = "model.layers.0.self_attn.attn"
layer_1 = "model.layers.1.self_attn.attn"
vllm_config = get_vllm_config()
with set_current_vllm_config(vllm_config):
fwd_context = {
layer_0: Attention(
num_heads=8,
head_size=128,
scale=1.0,
prefix=layer_0,
),
layer_1: Attention(
num_heads=8,
head_size=128,
scale=1.0,
prefix=layer_1,
kv_sharing_target_layer_name="model.layers.0.self_attn.attn",
),
}
# suppress var not used error
assert fwd_context is not None
# Set high context length to test max context length estimation
vllm_config.model_config.max_model_len = 3_000_000
vllm_ctx = vllm_config.compilation_config.static_forward_context
model_runner = get_model_runner(vllm_config)
kv_cache_spec = model_runner.get_kv_cache_spec()
assert len(kv_cache_spec) == 1
assert layer_0 in kv_cache_spec
assert model_runner.shared_kv_cache_layers[layer_1] == layer_0
available_memory = 20 * GiB_bytes
# page size for layer 0's kv_cache_spec is 512KB
# with KV sharing, we can allocate (available_mem//page_size//1) blocks
# which is twice as many as without KV sharing
num_expected_blocks = 2 * 20480 # 20GB / 512KB
kv_cache_config = get_kv_cache_configs(
vllm_config, [kv_cache_spec], [available_memory]
)[0]
assert kv_cache_config.num_blocks == num_expected_blocks
assert len(kv_cache_config.kv_cache_tensors) == 1
# Each layer now has twice the available memory for KV cache
# compared to no KV sharing
assert kv_cache_config.kv_cache_tensors[0].size == available_memory
max_context_len = estimate_max_model_len(vllm_config, kv_cache_spec, 5 * GiB_bytes)
# max context len with KV sharing should be 2x as large as without
assert max_context_len == (2 * 655360)
# important: override tensor size to prevent large mem alloc during test
# this will only allocate 1 block worth of memory (512kb)
kv_cache_config.num_blocks = 1
kv_cache_config.kv_cache_tensors[0].size = kv_cache_spec[layer_0].page_size_bytes
model_runner.initialize_kv_cache(kv_cache_config)
layer_0_kv = vllm_ctx[layer_0].kv_cache[0]
layer_1_kv = vllm_ctx[layer_1].kv_cache[0]
# check layer 1 kv cache shares memory with layer 0
assert id(layer_1_kv) == id(layer_0_kv)
# check layer 1 added to kv cache group's layer names
assert len(kv_cache_config.kv_cache_groups) == 1
assert len(kv_cache_config.kv_cache_groups[0].layer_names) == 2
assert kv_cache_config.kv_cache_groups[0].layer_names[0] == layer_0
assert kv_cache_config.kv_cache_groups[0].layer_names[1] == layer_1
def test_most_model_len(monkeypatch: pytest.MonkeyPatch):
monkeypatch.setenv("VLLM_TPU_MOST_MODEL_LEN", "2048")
vllm_config = get_vllm_config()
vllm_config.model_config.max_model_len = 32000
vllm_config.scheduler_config.max_num_seqs = 1200
model_runner = get_model_runner(vllm_config)
# verify model runner will adjust num_reqs to avoid SMEM OOM.
assert model_runner.num_reqs_most_model_len == 1200
# num_page_per_req = 32k // 128
# num_reqs = 1024 ** 2 // 2 // num_page_per_req // 4 = 524
assert model_runner.num_reqs_max_model_len == 524