Sync from v0.13
This commit is contained in:
193
tests/kernels/core/test_pos_encoding.py
Normal file
193
tests/kernels/core/test_pos_encoding.py
Normal file
@@ -0,0 +1,193 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from collections.abc import Callable
|
||||
from itertools import product
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from tests.kernels.allclose_default import get_default_atol, get_default_rtol
|
||||
from vllm.model_executor.layers.rotary_embedding import get_rope
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
IS_NEOX_STYLE = [True, False]
|
||||
DTYPES = [torch.bfloat16, torch.float]
|
||||
HEAD_SIZES = [64, 80, 120, 256]
|
||||
ROTARY_DIMS = [None, 32] # None means rotary dim == head size
|
||||
NUM_HEADS = [17] # Arbitrary values for testing
|
||||
BATCH_SIZES = [5] # Arbitrary values for testing
|
||||
SEQ_LENS = [11, 8192] # Arbitrary values for testing
|
||||
SEEDS = [0]
|
||||
CUDA_DEVICES = [f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)]
|
||||
USE_KEY = [True, False]
|
||||
|
||||
|
||||
def _get_flat_tensor_shape(
|
||||
batch_size: int, seq_len: int, num_heads: int, head_size: int
|
||||
) -> tuple[int, ...]:
|
||||
return (batch_size, seq_len, num_heads * head_size)
|
||||
|
||||
|
||||
# For testing sliced tensors
|
||||
def _get_padded_tensor_shape(
|
||||
batch_size: int, seq_len: int, num_heads: int, head_size: int
|
||||
) -> tuple[int, ...]:
|
||||
return (batch_size, seq_len, num_heads, head_size + 64)
|
||||
|
||||
|
||||
def _get_batch_tensor_shape(
|
||||
batch_size: int, seq_len: int, num_heads: int, head_size: int
|
||||
) -> tuple[int, ...]:
|
||||
return (batch_size, seq_len, num_heads, head_size)
|
||||
|
||||
|
||||
TENSORS_SHAPES_FN = [
|
||||
_get_batch_tensor_shape,
|
||||
_get_flat_tensor_shape,
|
||||
_get_padded_tensor_shape,
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("is_neox_style", IS_NEOX_STYLE)
|
||||
@pytest.mark.parametrize("tensor_shape_fn", TENSORS_SHAPES_FN)
|
||||
@pytest.mark.parametrize("batch_size", BATCH_SIZES)
|
||||
@pytest.mark.parametrize("seq_len", SEQ_LENS)
|
||||
@pytest.mark.parametrize("num_heads", NUM_HEADS)
|
||||
@pytest.mark.parametrize("head_size", HEAD_SIZES)
|
||||
@pytest.mark.parametrize("rotary_dim", ROTARY_DIMS)
|
||||
@pytest.mark.parametrize("dtype", DTYPES)
|
||||
@pytest.mark.parametrize("seed", SEEDS)
|
||||
@pytest.mark.parametrize("device", CUDA_DEVICES)
|
||||
@pytest.mark.parametrize("use_key", USE_KEY)
|
||||
@torch.inference_mode()
|
||||
def test_rotary_embedding(
|
||||
is_neox_style: bool,
|
||||
tensor_shape_fn: Callable[[int, int, int, int], tuple[int, ...]],
|
||||
batch_size: int,
|
||||
seq_len: int,
|
||||
num_heads: int,
|
||||
head_size: int,
|
||||
rotary_dim: int | None,
|
||||
dtype: torch.dtype,
|
||||
seed: int,
|
||||
device: str,
|
||||
use_key: bool,
|
||||
max_position: int = 8192,
|
||||
rope_theta: float = 10000,
|
||||
) -> None:
|
||||
if rotary_dim is None:
|
||||
rotary_dim = head_size
|
||||
|
||||
current_platform.seed_everything(seed)
|
||||
torch.set_default_device(device)
|
||||
if rotary_dim is None:
|
||||
rotary_dim = head_size
|
||||
rope_parameters = {
|
||||
"rope_type": "default",
|
||||
"rope_theta": rope_theta,
|
||||
"partial_rotary_factor": rotary_dim / head_size,
|
||||
}
|
||||
rope = get_rope(head_size, max_position, is_neox_style, rope_parameters)
|
||||
rope = rope.to(dtype=dtype, device=torch.get_default_device())
|
||||
|
||||
positions = torch.randint(0, max_position, (batch_size, seq_len))
|
||||
query_shape = tensor_shape_fn(batch_size, seq_len, num_heads, head_size)
|
||||
query = torch.randn(query_shape, dtype=dtype)
|
||||
key = torch.randn_like(query) if use_key else None
|
||||
|
||||
# slice tensor if required, noop otherwise
|
||||
query = query[..., :head_size]
|
||||
key = key[..., :head_size] if use_key else None
|
||||
|
||||
# NOTE(woosuk): The reference implementation should be executed first
|
||||
# because the custom kernel is in-place.
|
||||
ref_query, ref_key = rope.forward_native(positions, query, key)
|
||||
out_query, out_key = rope.forward(positions, query, key)
|
||||
# Compare the results.
|
||||
torch.testing.assert_close(
|
||||
out_query,
|
||||
ref_query,
|
||||
atol=get_default_atol(out_query),
|
||||
rtol=get_default_rtol(out_query),
|
||||
)
|
||||
if use_key:
|
||||
torch.testing.assert_close(
|
||||
out_key,
|
||||
ref_key,
|
||||
atol=get_default_atol(out_key),
|
||||
rtol=get_default_rtol(out_key),
|
||||
)
|
||||
else:
|
||||
assert ref_key is None and out_key is None, "expected returned key to be None"
|
||||
|
||||
|
||||
@torch.inference_mode()
|
||||
def test_rope_module_cache():
|
||||
MAX_POSITIONS = [123, 1234]
|
||||
ROPE_THETAS = [10000, 1000000]
|
||||
ROPE_PARAMETERS = (
|
||||
{"rope_type": "default"},
|
||||
{"rope_type": "linear", "factor": (1,)},
|
||||
{"rope_type": "dynamic", "factor": 1},
|
||||
)
|
||||
settings = (
|
||||
HEAD_SIZES,
|
||||
ROTARY_DIMS,
|
||||
MAX_POSITIONS,
|
||||
ROPE_THETAS,
|
||||
IS_NEOX_STYLE,
|
||||
ROPE_PARAMETERS,
|
||||
DTYPES,
|
||||
)
|
||||
rope_setting_id_map: dict[str, int] = {}
|
||||
for setting in product(*settings):
|
||||
(
|
||||
head_size,
|
||||
rotary_dim,
|
||||
max_position,
|
||||
rope_theta,
|
||||
is_neox_style,
|
||||
rope_parameters,
|
||||
dtype,
|
||||
) = setting
|
||||
if rotary_dim is None:
|
||||
rotary_dim = head_size
|
||||
rope_parameters["rope_theta"] = rope_theta
|
||||
rope_parameters["partial_rotary_factor"] = rotary_dim / head_size
|
||||
rope = get_rope(
|
||||
head_size,
|
||||
max_position,
|
||||
is_neox_style,
|
||||
rope_parameters,
|
||||
dtype,
|
||||
)
|
||||
# different settings cannot share the same rope module
|
||||
assert id(rope) not in rope_setting_id_map.values()
|
||||
assert all(x.dtype == dtype for x in rope.buffers())
|
||||
assert all(x.dtype == dtype for x in rope.parameters())
|
||||
rope_setting_id_map[str(setting)] = id(rope)
|
||||
|
||||
for setting in product(*settings):
|
||||
(
|
||||
head_size,
|
||||
rotary_dim,
|
||||
max_position,
|
||||
rope_theta,
|
||||
is_neox_style,
|
||||
rope_parameters,
|
||||
dtype,
|
||||
) = setting
|
||||
if rotary_dim is None:
|
||||
rotary_dim = head_size
|
||||
rope_parameters["rope_theta"] = rope_theta
|
||||
rope_parameters["partial_rotary_factor"] = rotary_dim / head_size
|
||||
rope = get_rope(
|
||||
head_size,
|
||||
max_position,
|
||||
is_neox_style,
|
||||
rope_parameters,
|
||||
dtype,
|
||||
)
|
||||
# check if cache take effect
|
||||
assert id(rope) == rope_setting_id_map[str(setting)]
|
||||
Reference in New Issue
Block a user