Sync from v0.13
This commit is contained in:
0
tests/kernels/__init__.py
Normal file
0
tests/kernels/__init__.py
Normal file
@@ -1,13 +1,12 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import torch
|
||||
|
||||
# Reference default values of atol and rtol are from
|
||||
# https://github.com/pytorch/pytorch/blob/6d96beb6bec24d73ee3f080bac54d2104068f675/test/test_transformers.py#L67
|
||||
default_atol = {torch.float16: 1e-3, torch.bfloat16: 1e-3, torch.float: 1e-5}
|
||||
default_rtol = {
|
||||
torch.float16: 1e-3,
|
||||
torch.bfloat16: 1.6e-2,
|
||||
torch.float: 1.3e-6
|
||||
}
|
||||
default_rtol = {torch.float16: 1e-3, torch.bfloat16: 1.6e-2, torch.float: 1.3e-6}
|
||||
|
||||
|
||||
def get_default_atol(output) -> float:
|
||||
|
||||
19
tests/kernels/attention/conftest.py
Normal file
19
tests/kernels/attention/conftest.py
Normal file
@@ -0,0 +1,19 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import pytest
|
||||
|
||||
from vllm.utils.torch_utils import (
|
||||
create_kv_caches_with_random,
|
||||
create_kv_caches_with_random_flash,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def kv_cache_factory():
|
||||
return create_kv_caches_with_random
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def kv_cache_factory_flashinfer():
|
||||
return create_kv_caches_with_random_flash
|
||||
190
tests/kernels/attention/test_aiter_flash_attn.py
Normal file
190
tests/kernels/attention/test_aiter_flash_attn.py
Normal file
@@ -0,0 +1,190 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
import vllm.v1.attention.backends.rocm_aiter_fa # noqa: F401
|
||||
from vllm.attention.utils.fa_utils import is_flash_attn_varlen_func_available
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
NUM_HEADS = [(4, 4), (8, 2)]
|
||||
HEAD_SIZES = [128, 256]
|
||||
BLOCK_SIZES = [16]
|
||||
DTYPES = [torch.bfloat16]
|
||||
QDTYPES = [None]
|
||||
# one value large enough to test overflow in index calculation.
|
||||
# one value small enough to test the schema op check
|
||||
NUM_BLOCKS = [32768, 2048]
|
||||
|
||||
|
||||
def ref_paged_attn(
|
||||
query: torch.Tensor,
|
||||
key_cache: torch.Tensor,
|
||||
value_cache: torch.Tensor,
|
||||
query_lens: list[int],
|
||||
kv_lens: list[int],
|
||||
block_tables: torch.Tensor,
|
||||
scale: float,
|
||||
sliding_window: int | None = None,
|
||||
soft_cap: float | None = None,
|
||||
) -> torch.Tensor:
|
||||
num_seqs = len(query_lens)
|
||||
block_tables = block_tables.cpu().numpy()
|
||||
_, block_size, num_kv_heads, head_size = key_cache.shape
|
||||
|
||||
outputs: list[torch.Tensor] = []
|
||||
start_idx = 0
|
||||
for i in range(num_seqs):
|
||||
query_len = query_lens[i]
|
||||
kv_len = kv_lens[i]
|
||||
q = query[start_idx : start_idx + query_len]
|
||||
q *= scale
|
||||
|
||||
num_kv_blocks = (kv_len + block_size - 1) // block_size
|
||||
block_indices = block_tables[i, :num_kv_blocks]
|
||||
|
||||
k = key_cache[block_indices].view(-1, num_kv_heads, head_size)
|
||||
k = k[:kv_len]
|
||||
v = value_cache[block_indices].view(-1, num_kv_heads, head_size)
|
||||
v = v[:kv_len]
|
||||
|
||||
if q.shape[1] != k.shape[1]:
|
||||
k = torch.repeat_interleave(k, q.shape[1] // k.shape[1], dim=1)
|
||||
v = torch.repeat_interleave(v, q.shape[1] // v.shape[1], dim=1)
|
||||
attn = torch.einsum("qhd,khd->hqk", q, k).float()
|
||||
empty_mask = torch.ones(query_len, kv_len)
|
||||
mask = torch.triu(empty_mask, diagonal=kv_len - query_len + 1).bool()
|
||||
if sliding_window is not None:
|
||||
sliding_window_mask = (
|
||||
torch.triu(
|
||||
empty_mask, diagonal=kv_len - (query_len + sliding_window) + 1
|
||||
)
|
||||
.bool()
|
||||
.logical_not()
|
||||
)
|
||||
mask |= sliding_window_mask
|
||||
if soft_cap is not None:
|
||||
attn = soft_cap * torch.tanh(attn / soft_cap)
|
||||
attn.masked_fill_(mask, float("-inf"))
|
||||
attn = torch.softmax(attn, dim=-1).to(v.dtype)
|
||||
out = torch.einsum("hqk,khd->qhd", attn, v)
|
||||
|
||||
outputs.append(out)
|
||||
start_idx += query_len
|
||||
|
||||
return torch.cat(outputs, dim=0)
|
||||
|
||||
|
||||
@pytest.mark.skipif(not current_platform.is_rocm(), reason="Only ROCm is supported")
|
||||
@pytest.mark.parametrize(
|
||||
"seq_lens", [[(10, 1328), (5, 18), (129, 463)], [(8, 523), (24, 37), (3, 2011)]]
|
||||
)
|
||||
@pytest.mark.parametrize("num_heads", NUM_HEADS)
|
||||
@pytest.mark.parametrize("head_size", HEAD_SIZES)
|
||||
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
|
||||
@pytest.mark.parametrize("sliding_window", [None, 256])
|
||||
@pytest.mark.parametrize("dtype", DTYPES)
|
||||
@pytest.mark.parametrize("soft_cap", [None])
|
||||
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
|
||||
@pytest.mark.parametrize("q_dtype", QDTYPES)
|
||||
@torch.inference_mode()
|
||||
def test_varlen_with_paged_kv(
|
||||
seq_lens: list[tuple[int, int]],
|
||||
num_heads: tuple[int, int],
|
||||
head_size: int,
|
||||
sliding_window: int | None,
|
||||
dtype: torch.dtype,
|
||||
block_size: int,
|
||||
soft_cap: float | None,
|
||||
num_blocks: int,
|
||||
q_dtype: torch.dtype | None,
|
||||
) -> None:
|
||||
if not is_flash_attn_varlen_func_available():
|
||||
pytest.skip("flash_attn_varlen_func required to run this test.")
|
||||
torch.set_default_device("cuda")
|
||||
current_platform.seed_everything(0)
|
||||
num_seqs = len(seq_lens)
|
||||
query_lens = [x[0] for x in seq_lens]
|
||||
kv_lens = [x[1] for x in seq_lens]
|
||||
num_query_heads = num_heads[0]
|
||||
num_kv_heads = num_heads[1]
|
||||
assert num_query_heads % num_kv_heads == 0
|
||||
max_query_len = max(query_lens)
|
||||
max_kv_len = max(kv_lens)
|
||||
window_size = (sliding_window - 1, 0) if sliding_window is not None else (-1, -1)
|
||||
scale = head_size**-0.5
|
||||
|
||||
query = torch.randn(sum(query_lens), num_query_heads, head_size, dtype=dtype)
|
||||
key_cache = torch.randn(
|
||||
num_blocks, block_size, num_kv_heads, head_size, dtype=dtype
|
||||
)
|
||||
value_cache = torch.randn_like(key_cache)
|
||||
cu_query_lens = torch.tensor([0] + query_lens, dtype=torch.int32).cumsum(
|
||||
dim=0, dtype=torch.int32
|
||||
)
|
||||
|
||||
cu_seq_lens = torch.tensor([0] + kv_lens, dtype=torch.int32).cumsum(
|
||||
dim=0, dtype=torch.int32
|
||||
)
|
||||
kv_lens = torch.tensor(kv_lens, dtype=torch.int32)
|
||||
|
||||
max_num_blocks_per_seq = (max_kv_len + block_size - 1) // block_size
|
||||
block_tables = torch.randint(
|
||||
0, num_blocks, (num_seqs, max_num_blocks_per_seq), dtype=torch.int32
|
||||
)
|
||||
|
||||
output = torch.empty_like(query)
|
||||
|
||||
maybe_quantized_query = query
|
||||
maybe_quantized_key_cache = key_cache
|
||||
maybe_quantized_value_cache = value_cache
|
||||
k_descale = None
|
||||
v_descale = None
|
||||
if q_dtype is not None:
|
||||
# QKV are drawn from N(0, 1): no need for a fp8 scaling factor
|
||||
maybe_quantized_query = query.to(q_dtype)
|
||||
maybe_quantized_key_cache = key_cache.to(q_dtype)
|
||||
maybe_quantized_value_cache = value_cache.to(q_dtype)
|
||||
|
||||
scale_shape = (num_seqs, num_kv_heads)
|
||||
k_descale = torch.ones(scale_shape, dtype=torch.float32)
|
||||
v_descale = torch.ones(scale_shape, dtype=torch.float32)
|
||||
|
||||
torch.ops.vllm.flash_attn_varlen_func(
|
||||
maybe_quantized_query,
|
||||
maybe_quantized_key_cache,
|
||||
maybe_quantized_value_cache,
|
||||
out=output,
|
||||
cu_seqlens_q=cu_query_lens,
|
||||
max_seqlen_q=max_query_len,
|
||||
max_seqlen_k=max_kv_len,
|
||||
softmax_scale=scale,
|
||||
alibi_slopes=None,
|
||||
window_size=window_size,
|
||||
block_table=block_tables,
|
||||
cu_seqlens_k=cu_seq_lens,
|
||||
k_scale=k_descale,
|
||||
v_scale=v_descale,
|
||||
)
|
||||
|
||||
ref_output = ref_paged_attn(
|
||||
query=query,
|
||||
key_cache=key_cache,
|
||||
value_cache=value_cache,
|
||||
query_lens=query_lens,
|
||||
kv_lens=kv_lens,
|
||||
block_tables=block_tables,
|
||||
scale=scale,
|
||||
sliding_window=sliding_window,
|
||||
soft_cap=soft_cap,
|
||||
)
|
||||
|
||||
atol, rtol = 2e-2, 2e-2
|
||||
if q_dtype is not None:
|
||||
atol, rtol = 1.5e-1, 1.5e-1
|
||||
(
|
||||
torch.testing.assert_close(output, ref_output, atol=atol, rtol=rtol),
|
||||
f"{torch.max(torch.abs(output - ref_output))}",
|
||||
)
|
||||
457
tests/kernels/attention/test_attention.py
Normal file
457
tests/kernels/attention/test_attention.py
Normal file
@@ -0,0 +1,457 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import random
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from tests.kernels.allclose_default import get_default_atol, get_default_rtol
|
||||
from tests.kernels.utils import opcheck
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.attention.layer import Attention, MultiHeadAttention
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.mem_utils import get_max_shared_memory_bytes
|
||||
|
||||
FLOAT32_BYTES = torch.finfo(torch.float).bits // 8
|
||||
# This will change depending on the compute capability.
|
||||
# - 512 as a buffer
|
||||
MAX_SEQ_LEN = get_max_shared_memory_bytes() // FLOAT32_BYTES - 512
|
||||
# There may not be enough gpu memory due to large NUM_BLOCKS.
|
||||
# Reduce NUM_BLOCKS when it happens.
|
||||
NUM_BLOCKS = 4321 # Arbitrary values for testing
|
||||
PARTITION_SIZE = 512
|
||||
PARTITION_SIZE_ROCM = 256
|
||||
DTYPES = [torch.bfloat16]
|
||||
NUM_GEN_SEQS = [7] # Arbitrary values for testing
|
||||
NUM_PREFILL_SEQS = [3] # Arbitrary values for testing
|
||||
NUM_HEADS = [(40, 40), (64, 8)] # Arbitrary values for testing
|
||||
|
||||
# This should be sync with get_supported_head_sizes() in
|
||||
# vllm.attention.ops.paged_attn.PagedAttention
|
||||
HEAD_SIZES = [32, 80, 128, 256]
|
||||
|
||||
BLOCK_SIZES = [16, 32]
|
||||
USE_ALIBI = [False, True]
|
||||
KV_CACHE_DTYPE = ["auto", "fp8"]
|
||||
SEEDS = [0]
|
||||
CUDA_DEVICES = [f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)]
|
||||
|
||||
|
||||
def ref_masked_attention(
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
scale: float,
|
||||
attn_mask: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
attn_weights = scale * torch.einsum("qhd,khd->hqk", query, key).float()
|
||||
if attn_mask is not None:
|
||||
attn_weights = attn_weights + attn_mask.float()
|
||||
attn_weights = torch.softmax(attn_weights, dim=-1).to(value.dtype)
|
||||
out = torch.einsum("hqk,khd->qhd", attn_weights, value)
|
||||
return out
|
||||
|
||||
|
||||
def ref_single_query_cached_kv_attention(
|
||||
output: torch.Tensor,
|
||||
query: torch.Tensor,
|
||||
num_queries_per_kv: int,
|
||||
key_cache: torch.Tensor,
|
||||
value_cache: torch.Tensor,
|
||||
block_tables: torch.Tensor,
|
||||
seq_lens: torch.Tensor,
|
||||
scale: float,
|
||||
alibi_slopes: torch.Tensor | None,
|
||||
) -> None:
|
||||
num_query_heads = query.shape[1]
|
||||
num_kv_heads = value_cache.shape[1]
|
||||
head_size = value_cache.shape[2]
|
||||
block_size = value_cache.shape[3]
|
||||
num_seqs = query.shape[0]
|
||||
|
||||
block_tables_lst = block_tables.cpu().tolist()
|
||||
seq_lens_lst = seq_lens.cpu().tolist()
|
||||
for i in range(num_seqs):
|
||||
q = query[i].unsqueeze(0)
|
||||
block_table = block_tables_lst[i]
|
||||
seq_len = int(seq_lens_lst[i])
|
||||
|
||||
keys_lst: list[torch.Tensor] = []
|
||||
values_lst: list[torch.Tensor] = []
|
||||
for j in range(seq_len):
|
||||
block_number = int(block_table[j // block_size])
|
||||
block_offset = j % block_size
|
||||
|
||||
k = key_cache[block_number, :, :, block_offset, :]
|
||||
k = k.reshape(num_kv_heads, head_size)
|
||||
keys_lst.append(k)
|
||||
|
||||
v = value_cache[block_number, :, :, block_offset]
|
||||
values_lst.append(v)
|
||||
keys = torch.stack(keys_lst, dim=0)
|
||||
values = torch.stack(values_lst, dim=0)
|
||||
if num_queries_per_kv > 1:
|
||||
# Handle MQA and GQA
|
||||
keys = torch.repeat_interleave(keys, num_queries_per_kv, dim=1)
|
||||
values = torch.repeat_interleave(values, num_queries_per_kv, dim=1)
|
||||
|
||||
alibi_bias = None
|
||||
if alibi_slopes is not None:
|
||||
# Create the ALiBi bias used in the paged attention kernel.
|
||||
position_ids = torch.arange(seq_len).int()
|
||||
alibi_bias = (position_ids - seq_len + 1).float()
|
||||
alibi_bias = alibi_slopes.view(-1, 1, 1) * alibi_bias.view(1, 1, -1)
|
||||
|
||||
out = ref_masked_attention(q, keys, values, scale, alibi_bias)
|
||||
out = out.view(num_query_heads, head_size)
|
||||
output[i].copy_(out, non_blocking=True)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"version", ["v1", "v2"] if not current_platform.is_rocm() else ["v1", "v2", "rocm"]
|
||||
)
|
||||
@pytest.mark.parametrize("num_seqs", NUM_GEN_SEQS)
|
||||
@pytest.mark.parametrize("num_heads", NUM_HEADS)
|
||||
@pytest.mark.parametrize("head_size", HEAD_SIZES)
|
||||
@pytest.mark.parametrize("use_alibi", USE_ALIBI)
|
||||
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
|
||||
@pytest.mark.parametrize("dtype", DTYPES)
|
||||
@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPE)
|
||||
@pytest.mark.parametrize("seed", SEEDS)
|
||||
@pytest.mark.parametrize("device", CUDA_DEVICES)
|
||||
def test_paged_attention(
|
||||
kv_cache_factory,
|
||||
version: str,
|
||||
num_seqs: int,
|
||||
num_heads: tuple[int, int],
|
||||
head_size: int,
|
||||
use_alibi: bool,
|
||||
block_size: int,
|
||||
dtype: torch.dtype,
|
||||
kv_cache_dtype: str,
|
||||
seed: int,
|
||||
device: str,
|
||||
) -> None:
|
||||
if (kv_cache_dtype == "fp8" and head_size % 16) or (
|
||||
version == "rocm" and head_size not in (64, 128)
|
||||
):
|
||||
pytest.skip()
|
||||
|
||||
if (
|
||||
version == "rocm"
|
||||
and current_platform.is_navi()
|
||||
and (
|
||||
kv_cache_dtype == "fp8" or head_size != 128 or block_size != 16 or use_alibi
|
||||
)
|
||||
):
|
||||
pytest.skip()
|
||||
|
||||
global PARTITION_SIZE
|
||||
|
||||
current_platform.seed_everything(seed)
|
||||
torch.set_default_device(device)
|
||||
scale = float(1.0 / (head_size**0.5))
|
||||
num_query_heads, num_kv_heads = num_heads
|
||||
query = torch.empty(num_seqs, num_query_heads, head_size, dtype=dtype)
|
||||
query.uniform_(-scale, scale)
|
||||
|
||||
assert num_query_heads % num_kv_heads == 0
|
||||
num_queries_per_kv = num_query_heads // num_kv_heads
|
||||
alibi_slopes = None
|
||||
if use_alibi:
|
||||
alibi_slopes = torch.randn(num_query_heads, dtype=torch.float)
|
||||
|
||||
seq_lens = [random.randint(1, MAX_SEQ_LEN) for _ in range(num_seqs)]
|
||||
seq_lens[-1] = MAX_SEQ_LEN
|
||||
max_seq_len = max(seq_lens)
|
||||
seq_lens = torch.tensor(seq_lens, dtype=torch.int)
|
||||
|
||||
# Create the block tables.
|
||||
max_num_blocks_per_seq = (max_seq_len + block_size - 1) // block_size
|
||||
block_tables_lst: list[list[int]] = []
|
||||
for _ in range(num_seqs):
|
||||
block_table = [
|
||||
random.randint(0, NUM_BLOCKS - 1) for _ in range(max_num_blocks_per_seq)
|
||||
]
|
||||
block_tables_lst.append(block_table)
|
||||
|
||||
block_tables = torch.tensor(block_tables_lst, dtype=torch.int)
|
||||
|
||||
# Create the KV caches.
|
||||
key_caches, value_caches = kv_cache_factory(
|
||||
NUM_BLOCKS,
|
||||
block_size,
|
||||
1,
|
||||
num_kv_heads,
|
||||
head_size,
|
||||
kv_cache_dtype,
|
||||
dtype,
|
||||
seed,
|
||||
device,
|
||||
)
|
||||
key_cache, value_cache = key_caches[0], value_caches[0]
|
||||
|
||||
# Using default kv_scale
|
||||
k_scale = v_scale = torch.tensor(1.0, dtype=torch.float32, device=device)
|
||||
|
||||
# Call the paged attention kernel.
|
||||
output = torch.empty_like(query)
|
||||
if version == "v1":
|
||||
ops.paged_attention_v1(
|
||||
output,
|
||||
query,
|
||||
key_cache,
|
||||
value_cache,
|
||||
num_kv_heads,
|
||||
scale,
|
||||
block_tables,
|
||||
seq_lens,
|
||||
block_size,
|
||||
max_seq_len,
|
||||
alibi_slopes,
|
||||
kv_cache_dtype,
|
||||
k_scale,
|
||||
v_scale,
|
||||
)
|
||||
|
||||
opcheck(
|
||||
torch.ops._C.paged_attention_v1,
|
||||
(
|
||||
output,
|
||||
query,
|
||||
key_cache,
|
||||
value_cache,
|
||||
num_kv_heads,
|
||||
scale,
|
||||
block_tables,
|
||||
seq_lens,
|
||||
block_size,
|
||||
max_seq_len,
|
||||
alibi_slopes,
|
||||
kv_cache_dtype,
|
||||
k_scale,
|
||||
v_scale,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
64,
|
||||
0,
|
||||
),
|
||||
cond=(head_size == HEAD_SIZES[0] and block_size == BLOCK_SIZES[0]),
|
||||
)
|
||||
|
||||
elif version in ("v2", "rocm"):
|
||||
if current_platform.is_rocm() and version == "rocm":
|
||||
PARTITION_SIZE = PARTITION_SIZE_ROCM
|
||||
|
||||
num_partitions = (max_seq_len + PARTITION_SIZE - 1) // PARTITION_SIZE
|
||||
assert PARTITION_SIZE % block_size == 0
|
||||
num_seqs, num_heads, head_size = output.shape
|
||||
tmp_output = torch.empty(
|
||||
size=(num_seqs, num_heads, num_partitions, head_size),
|
||||
dtype=output.dtype,
|
||||
)
|
||||
exp_sums = torch.empty(
|
||||
size=(num_seqs, num_heads, num_partitions),
|
||||
dtype=torch.float32,
|
||||
)
|
||||
max_logits = torch.empty_like(exp_sums)
|
||||
if version == "v2":
|
||||
ops.paged_attention_v2(
|
||||
output,
|
||||
exp_sums,
|
||||
max_logits,
|
||||
tmp_output,
|
||||
query,
|
||||
key_cache,
|
||||
value_cache,
|
||||
num_kv_heads,
|
||||
scale,
|
||||
block_tables,
|
||||
seq_lens,
|
||||
block_size,
|
||||
max_seq_len,
|
||||
alibi_slopes,
|
||||
kv_cache_dtype,
|
||||
k_scale,
|
||||
v_scale,
|
||||
)
|
||||
|
||||
opcheck(
|
||||
torch.ops._C.paged_attention_v2,
|
||||
(
|
||||
output,
|
||||
exp_sums,
|
||||
max_logits,
|
||||
tmp_output,
|
||||
query,
|
||||
key_cache,
|
||||
value_cache,
|
||||
num_kv_heads,
|
||||
scale,
|
||||
block_tables,
|
||||
seq_lens,
|
||||
block_size,
|
||||
max_seq_len,
|
||||
alibi_slopes,
|
||||
kv_cache_dtype,
|
||||
k_scale,
|
||||
v_scale,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
64,
|
||||
0,
|
||||
),
|
||||
cond=(head_size == HEAD_SIZES[0] and block_size == BLOCK_SIZES[0]),
|
||||
)
|
||||
|
||||
else:
|
||||
ops.paged_attention_rocm(
|
||||
output,
|
||||
exp_sums,
|
||||
max_logits,
|
||||
tmp_output,
|
||||
query,
|
||||
key_cache,
|
||||
value_cache,
|
||||
num_kv_heads,
|
||||
scale,
|
||||
block_tables,
|
||||
seq_lens,
|
||||
None,
|
||||
block_size,
|
||||
max_seq_len,
|
||||
alibi_slopes,
|
||||
kv_cache_dtype,
|
||||
k_scale,
|
||||
v_scale,
|
||||
)
|
||||
|
||||
opcheck(
|
||||
torch.ops._rocm_C.paged_attention,
|
||||
(
|
||||
output,
|
||||
exp_sums,
|
||||
max_logits,
|
||||
tmp_output,
|
||||
query,
|
||||
key_cache,
|
||||
value_cache,
|
||||
num_kv_heads,
|
||||
scale,
|
||||
block_tables,
|
||||
seq_lens,
|
||||
None,
|
||||
block_size,
|
||||
max_seq_len,
|
||||
alibi_slopes,
|
||||
kv_cache_dtype,
|
||||
k_scale,
|
||||
v_scale,
|
||||
),
|
||||
cond=(head_size == HEAD_SIZES[0] and block_size == BLOCK_SIZES[0]),
|
||||
)
|
||||
|
||||
else:
|
||||
raise AssertionError(f"Unknown version: {version}")
|
||||
|
||||
# Run the reference implementation.
|
||||
if kv_cache_dtype == "fp8":
|
||||
# Convert cache data back to dtype.
|
||||
x = 16 // torch.tensor([], dtype=dtype).element_size()
|
||||
key_cache_shape = (NUM_BLOCKS, num_kv_heads, head_size // x, block_size, x)
|
||||
dequantized_key_cache = torch.empty(
|
||||
size=key_cache_shape, dtype=dtype, device=device
|
||||
)
|
||||
ops.convert_fp8(dequantized_key_cache, key_cache)
|
||||
key_cache = dequantized_key_cache
|
||||
|
||||
value_cache_shape = value_cache.shape
|
||||
dequantized_value_cache = torch.empty(
|
||||
size=value_cache_shape, dtype=dtype, device=device
|
||||
)
|
||||
ops.convert_fp8(dequantized_value_cache, value_cache)
|
||||
value_cache = dequantized_value_cache
|
||||
|
||||
ref_output = torch.empty_like(query)
|
||||
ref_single_query_cached_kv_attention(
|
||||
ref_output,
|
||||
query,
|
||||
num_queries_per_kv,
|
||||
key_cache,
|
||||
value_cache,
|
||||
block_tables,
|
||||
seq_lens,
|
||||
scale,
|
||||
alibi_slopes,
|
||||
)
|
||||
|
||||
# NOTE(woosuk): Due to the kernel-level differences in the two
|
||||
# implementations, there is a small numerical difference in the two
|
||||
# outputs. Thus, we use a relaxed tolerance for the test.
|
||||
atol = get_default_atol(output) if current_platform.is_rocm() else 1e-3
|
||||
rtol = get_default_rtol(output) if current_platform.is_rocm() else 1e-5
|
||||
|
||||
# NOTE(zhaoyang): FP8 KV Cache will introduce quantization error,
|
||||
# so we use a relaxed tolerance for the test.
|
||||
atol, rtol = 1e-3, 1e-5
|
||||
if kv_cache_dtype == "fp8":
|
||||
atol, rtol = 1e-2, 1e-5
|
||||
torch.testing.assert_close(output, ref_output, atol=atol, rtol=rtol)
|
||||
|
||||
|
||||
def ref_multi_query_kv_attention(
|
||||
cu_seq_lens: list[int],
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
scale: float,
|
||||
alibi_bias: list[torch.Tensor] | None,
|
||||
dtype: torch.dtype,
|
||||
) -> torch.Tensor:
|
||||
num_seqs = len(cu_seq_lens) - 1
|
||||
ref_outputs: list[torch.Tensor] = []
|
||||
if alibi_bias:
|
||||
assert len(alibi_bias) == num_seqs
|
||||
for i in range(num_seqs):
|
||||
start_idx = cu_seq_lens[i]
|
||||
end_idx = cu_seq_lens[i + 1]
|
||||
seq_len = end_idx - start_idx
|
||||
|
||||
# Create attention mask. ALiBi already includes a tril causal mask.
|
||||
if alibi_bias:
|
||||
attn_mask = alibi_bias[i]
|
||||
else:
|
||||
attn_mask = torch.triu(
|
||||
torch.ones(seq_len, seq_len, dtype=dtype), diagonal=1
|
||||
)
|
||||
attn_mask = attn_mask * torch.finfo(dtype).min
|
||||
attn_mask = attn_mask.to(dtype=dtype)
|
||||
|
||||
ref_output = ref_masked_attention(
|
||||
query[start_idx:end_idx],
|
||||
key[start_idx:end_idx],
|
||||
value[start_idx:end_idx],
|
||||
scale,
|
||||
attn_mask=attn_mask,
|
||||
)
|
||||
ref_outputs.append(ref_output)
|
||||
|
||||
return torch.cat(ref_outputs, dim=0)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("attention_cls", [Attention, MultiHeadAttention])
|
||||
def test_num_heads_not_divisble_by_num_kv_heads(attention_cls: type) -> None:
|
||||
head_size = 64
|
||||
scale = float(1.0 / (head_size**0.5))
|
||||
num_heads = 16
|
||||
num_kv_heads = 5
|
||||
with pytest.raises(AssertionError):
|
||||
_ = attention_cls(
|
||||
num_heads=num_heads,
|
||||
head_size=head_size,
|
||||
scale=scale,
|
||||
num_kv_heads=num_kv_heads,
|
||||
)
|
||||
291
tests/kernels/attention/test_attention_selector.py
Normal file
291
tests/kernels/attention/test_attention_selector.py
Normal file
@@ -0,0 +1,291 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from vllm.attention.selector import _cached_get_attn_backend, get_attn_backend
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.platforms.cpu import CpuPlatform
|
||||
from vllm.platforms.cuda import CudaPlatform
|
||||
from vllm.platforms.rocm import RocmPlatform
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def clear_cache():
|
||||
"""Clear lru cache to ensure each test case runs without caching."""
|
||||
_cached_get_attn_backend.cache_clear()
|
||||
|
||||
|
||||
# Define MLA and non-MLA backends separately
|
||||
DEVICE_MLA_BACKENDS = {
|
||||
"cuda": [
|
||||
"TRITON_MLA",
|
||||
"FLASHMLA",
|
||||
"FLASHINFER_MLA",
|
||||
"FLASH_ATTN_MLA",
|
||||
"CUTLASS_MLA",
|
||||
],
|
||||
"hip": ["TRITON_MLA", "ROCM_AITER_MLA"],
|
||||
"cpu": [],
|
||||
}
|
||||
|
||||
DEVICE_REGULAR_ATTN_BACKENDS = {
|
||||
"cuda": ["FLASHINFER", "FLASH_ATTN"],
|
||||
"hip": ["ROCM_ATTN"],
|
||||
"cpu": ["CPU_ATTN"],
|
||||
}
|
||||
|
||||
DEVICE_MLA_BLOCK_SIZES = {
|
||||
"cuda": [16, 64], # CUDA supports both standard and extended block sizes
|
||||
"hip": [16, 1], # HIP requires special handling for block_size=1
|
||||
# "cpu": [16] # CPU uses fixed block size from test cases
|
||||
"cpu": [], # FIXME(woosuk): Temporarily disable CPU tests
|
||||
}
|
||||
|
||||
|
||||
def generate_params():
|
||||
is_rocm = current_platform.is_rocm()
|
||||
params = []
|
||||
device_list = ["cuda", "cpu"] if not is_rocm else ["hip", "cpu"]
|
||||
for use_mla in [True, False]:
|
||||
for device in device_list:
|
||||
backends = (
|
||||
DEVICE_MLA_BACKENDS[device]
|
||||
if use_mla
|
||||
else DEVICE_REGULAR_ATTN_BACKENDS[device]
|
||||
)
|
||||
for name in backends:
|
||||
block_sizes = DEVICE_MLA_BLOCK_SIZES[device] if use_mla else [16]
|
||||
for block_size in block_sizes:
|
||||
params.append(
|
||||
pytest.param(
|
||||
device,
|
||||
name,
|
||||
use_mla,
|
||||
block_size,
|
||||
id=f"{device}_{name}_mla_{str(use_mla)[0]}_blks{block_size}",
|
||||
)
|
||||
)
|
||||
return params
|
||||
|
||||
|
||||
@pytest.mark.parametrize("device, name, use_mla, block_size", generate_params())
|
||||
def test_env(
|
||||
device: str,
|
||||
name: str,
|
||||
use_mla: bool,
|
||||
block_size: int,
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
):
|
||||
"""Test attention backend selection with valid device-backend pairs."""
|
||||
with monkeypatch.context() as m:
|
||||
m.setenv("VLLM_ATTENTION_BACKEND", name)
|
||||
m.setenv("VLLM_MLA_DISABLE", "1" if use_mla else "0")
|
||||
|
||||
if device == "cpu":
|
||||
with patch("vllm.platforms.current_platform", CpuPlatform()):
|
||||
backend = get_attn_backend(16, torch.float16, None, block_size)
|
||||
assert backend.get_name() == "CPU_ATTN"
|
||||
|
||||
elif device == "hip":
|
||||
with patch("vllm.platforms.current_platform", RocmPlatform()):
|
||||
if use_mla:
|
||||
# ROCm MLA backend logic:
|
||||
# - TRITON_MLA: supported when block_size != 1
|
||||
# - ROCM_AITER_MLA: supported when block_size == 1
|
||||
# If backend is forced but doesn't match block_size,
|
||||
# should raise ValueError
|
||||
|
||||
if name == "TRITON_MLA" and block_size == 1:
|
||||
# TRITON_MLA doesn't support block_size == 1
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
get_attn_backend(
|
||||
16, torch.float16, None, block_size, use_mla=use_mla
|
||||
)
|
||||
assert f"The selected backend, {name}" in str(exc_info.value)
|
||||
else:
|
||||
# Valid backend-block_size combination
|
||||
backend = get_attn_backend(
|
||||
16, torch.float16, None, block_size, use_mla=use_mla
|
||||
)
|
||||
expected = name
|
||||
assert backend.get_name() == expected
|
||||
else:
|
||||
backend = get_attn_backend(
|
||||
16, torch.float16, None, block_size, use_mla=use_mla
|
||||
)
|
||||
expected = "ROCM_ATTN"
|
||||
assert backend.get_name() == expected
|
||||
|
||||
elif device == "cuda":
|
||||
with patch("vllm.platforms.current_platform", CudaPlatform()):
|
||||
capability = torch.cuda.get_device_capability()
|
||||
if use_mla:
|
||||
# CUDA MLA backend logic:
|
||||
# - CUTLASS_MLA: only supported with block_size == 128
|
||||
# and Blackwell GPUs (SM 10.x), V1 only
|
||||
# - FLASHINFER_MLA: only supported on Blackwell GPUs
|
||||
# (SM 10.x), V1 only
|
||||
# - FLASHMLA: only supported with block_size == 64
|
||||
# - FLASH_ATTN_MLA: V1 only
|
||||
# - TRITON_MLA: fallback for other cases
|
||||
|
||||
if name == "CUTLASS_MLA":
|
||||
if block_size != 128:
|
||||
# CUTLASS_MLA only supports block_size == 128
|
||||
pytest.skip("CUTLASS_MLA only supports block_size 128")
|
||||
if capability[0] != 10:
|
||||
pytest.skip("CUTLASS MLA is not supported on this platform")
|
||||
backend = get_attn_backend(
|
||||
576, torch.float16, None, block_size, use_mla=use_mla
|
||||
)
|
||||
expected = "CUTLASS_MLA"
|
||||
assert backend.get_name() == expected
|
||||
elif name == "FLASHINFER_MLA":
|
||||
if capability[0] != 10:
|
||||
pytest.skip(
|
||||
"FlashInfer MLA is not supported on this platform"
|
||||
)
|
||||
if block_size not in [32, 64]:
|
||||
# FlashInfer MLA only supports block_size 32 or 64
|
||||
pytest.skip(
|
||||
"FlashInfer MLA only supports block_size 32 or 64"
|
||||
)
|
||||
backend = get_attn_backend(
|
||||
576, torch.float16, None, block_size, use_mla=use_mla
|
||||
)
|
||||
expected = "FLASHINFER_MLA"
|
||||
assert backend.get_name() == expected
|
||||
elif name == "FLASHMLA":
|
||||
if block_size != 64:
|
||||
# FlashMLA only supports block_size == 64
|
||||
pytest.skip("FlashMLA only supports block_size 64")
|
||||
from vllm.v1.attention.backends.mla.flashmla import (
|
||||
is_flashmla_dense_supported,
|
||||
)
|
||||
|
||||
is_supported, _ = is_flashmla_dense_supported()
|
||||
if not is_supported:
|
||||
pytest.skip("FlashMLA not supported on this platform")
|
||||
backend = get_attn_backend(
|
||||
576,
|
||||
torch.float16,
|
||||
None,
|
||||
block_size,
|
||||
use_mla=use_mla,
|
||||
)
|
||||
expected = name
|
||||
assert backend.get_name() == expected
|
||||
elif name == "FLASH_ATTN_MLA":
|
||||
from vllm.attention.utils.fa_utils import (
|
||||
flash_attn_supports_mla,
|
||||
)
|
||||
|
||||
if not flash_attn_supports_mla():
|
||||
pytest.skip(
|
||||
"FlashAttention MLA not supported on this platform"
|
||||
)
|
||||
backend = get_attn_backend(
|
||||
576, torch.float16, None, block_size, use_mla=use_mla
|
||||
)
|
||||
expected = "FLASH_ATTN_MLA"
|
||||
assert backend.get_name() == expected
|
||||
else:
|
||||
# TRITON_MLA or other fallback
|
||||
backend = get_attn_backend(
|
||||
576, torch.float16, None, block_size, use_mla=use_mla
|
||||
)
|
||||
expected = "TRITON_MLA"
|
||||
assert backend.get_name() == expected
|
||||
elif name == "FLASHINFER":
|
||||
backend = get_attn_backend(
|
||||
64, torch.float16, None, block_size, use_mla=use_mla
|
||||
)
|
||||
expected = "FLASHINFER"
|
||||
assert backend.get_name() == expected
|
||||
elif name == "FLASH_ATTN":
|
||||
backend = get_attn_backend(
|
||||
32, torch.float16, None, block_size, use_mla=use_mla
|
||||
)
|
||||
expected = "FLASH_ATTN"
|
||||
assert backend.get_name() == expected
|
||||
|
||||
|
||||
@pytest.mark.parametrize("device", ["cpu", "cuda"])
|
||||
def test_fp32_fallback(device: str):
|
||||
"""Test attention backend selection with fp32."""
|
||||
if device == "cpu":
|
||||
with patch("vllm.platforms.current_platform", CpuPlatform()):
|
||||
backend = get_attn_backend(16, torch.float32, None, 16)
|
||||
assert backend.get_name() == "CPU_ATTN"
|
||||
|
||||
elif device == "cuda":
|
||||
with patch("vllm.platforms.current_platform", CudaPlatform()):
|
||||
backend = get_attn_backend(16, torch.float32, None, 16)
|
||||
assert backend.get_name() == "FLEX_ATTENTION"
|
||||
|
||||
|
||||
def test_flash_attn(monkeypatch: pytest.MonkeyPatch):
|
||||
"""Test FlashAttn validation."""
|
||||
pytest.skip(
|
||||
"Skipping as current backend selector does not "
|
||||
"handle fallbacks when a backend is set via env var."
|
||||
)
|
||||
|
||||
with monkeypatch.context() as m:
|
||||
m.setenv("VLLM_ATTENTION_BACKEND", "FLASH_ATTN")
|
||||
|
||||
# Unsupported CUDA arch
|
||||
monkeypatch.setattr(torch.cuda, "get_device_capability", lambda _=None: (7, 5))
|
||||
backend = get_attn_backend(16, torch.float16, None, 16)
|
||||
assert backend.get_name() != "FLASH_ATTN"
|
||||
|
||||
# Reset the monkeypatch for subsequent tests
|
||||
monkeypatch.undo()
|
||||
|
||||
# Unsupported data type
|
||||
backend = get_attn_backend(16, torch.float8_e4m3fn, None, 16)
|
||||
assert backend.get_name() != "FLASH_ATTN"
|
||||
|
||||
# Unsupported kv cache data type
|
||||
backend = get_attn_backend(16, torch.float16, "fp8", 16)
|
||||
assert backend.get_name() != "FLASH_ATTN"
|
||||
|
||||
# Unsupported block size
|
||||
backend = get_attn_backend(16, torch.float16, None, 8)
|
||||
assert backend.get_name() != "FLASH_ATTN"
|
||||
|
||||
# flash-attn is not installed
|
||||
import sys
|
||||
|
||||
original_module = sys.modules.get("vllm_flash_attn")
|
||||
monkeypatch.setitem(sys.modules, "vllm_flash_attn", None)
|
||||
backend = get_attn_backend(16, torch.float16, None, 16)
|
||||
assert backend.get_name() != "FLASH_ATTN"
|
||||
|
||||
# Restore the original module if it existed
|
||||
if original_module is not None:
|
||||
monkeypatch.setitem(sys.modules, "vllm_flash_attn", original_module)
|
||||
else:
|
||||
monkeypatch.delitem(sys.modules, "vllm_flash_attn", raising=False)
|
||||
|
||||
# Unsupported head size
|
||||
backend = get_attn_backend(17, torch.float16, None, 16)
|
||||
assert backend.get_name() != "FLASH_ATTN"
|
||||
|
||||
|
||||
def test_invalid_env(monkeypatch: pytest.MonkeyPatch):
|
||||
"""Test that invalid attention backend names raise ValueError."""
|
||||
with (
|
||||
monkeypatch.context() as m,
|
||||
patch("vllm.platforms.current_platform", CudaPlatform()),
|
||||
):
|
||||
m.setenv("VLLM_ATTENTION_BACKEND", "INVALID")
|
||||
|
||||
# Should raise ValueError for invalid backend
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
get_attn_backend(32, torch.float16, None, 16)
|
||||
assert "Invalid value 'INVALID'" in str(exc_info.value)
|
||||
1141
tests/kernels/attention/test_cache.py
Normal file
1141
tests/kernels/attention/test_cache.py
Normal file
File diff suppressed because it is too large
Load Diff
186
tests/kernels/attention/test_cascade_flash_attn.py
Executable file
186
tests/kernels/attention/test_cascade_flash_attn.py
Executable file
@@ -0,0 +1,186 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.v1.attention.backends.flash_attn import cascade_attention, merge_attn_states
|
||||
|
||||
try:
|
||||
from vllm.vllm_flash_attn import (
|
||||
fa_version_unsupported_reason,
|
||||
flash_attn_varlen_func,
|
||||
is_fa_version_supported,
|
||||
)
|
||||
except ImportError:
|
||||
if current_platform.is_rocm():
|
||||
pytest.skip(
|
||||
"vllm_flash_attn is not supported for vLLM on ROCm.",
|
||||
allow_module_level=True,
|
||||
)
|
||||
|
||||
NUM_HEADS = [(4, 4), (8, 2), (16, 2)]
|
||||
HEAD_SIZES = [128, 192, 256]
|
||||
BLOCK_SIZES = [16]
|
||||
DTYPES = [torch.float16, torch.bfloat16]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("num_tokens", [1, 39, 16912])
|
||||
@pytest.mark.parametrize("num_heads", NUM_HEADS)
|
||||
@pytest.mark.parametrize("head_size", HEAD_SIZES)
|
||||
@pytest.mark.parametrize("dtype", DTYPES)
|
||||
@torch.inference_mode()
|
||||
def test_merge_kernel(
|
||||
num_tokens: int,
|
||||
num_heads: tuple[int, int],
|
||||
head_size: int,
|
||||
dtype: torch.dtype,
|
||||
):
|
||||
torch.set_default_device("cuda")
|
||||
current_platform.seed_everything(0)
|
||||
num_query_heads = num_heads[0]
|
||||
num_kv_heads = num_heads[1]
|
||||
assert num_query_heads % num_kv_heads == 0
|
||||
|
||||
# Prepare inputs.
|
||||
prefix_output = torch.randn(num_tokens, num_query_heads, head_size, dtype=dtype)
|
||||
suffix_output = torch.randn(num_tokens, num_query_heads, head_size, dtype=dtype)
|
||||
prefix_lse = torch.randn(num_query_heads, num_tokens, dtype=torch.float32)
|
||||
suffix_lse = torch.randn(num_query_heads, num_tokens, dtype=torch.float32)
|
||||
|
||||
# Run the kernel.
|
||||
output = torch.empty(num_tokens, num_query_heads, head_size, dtype=dtype)
|
||||
merge_attn_states(output, prefix_output, prefix_lse, suffix_output, suffix_lse)
|
||||
|
||||
# Reference implementation.
|
||||
max_lse = torch.maximum(prefix_lse, suffix_lse)
|
||||
p_lse = torch.exp(prefix_lse - max_lse)
|
||||
s_lse = torch.exp(suffix_lse - max_lse)
|
||||
p_scale = p_lse / (p_lse + s_lse)
|
||||
s_scale = s_lse / (p_lse + s_lse)
|
||||
p_scale = p_scale.transpose(0, 1).unsqueeze(2)
|
||||
s_scale = s_scale.transpose(0, 1).unsqueeze(2)
|
||||
ref_output = p_scale * prefix_output + s_scale * suffix_output
|
||||
ref_output = ref_output.to(dtype)
|
||||
|
||||
# Compare the results.
|
||||
torch.testing.assert_close(output, ref_output, atol=1e-2, rtol=1e-2)
|
||||
|
||||
|
||||
CASES = [
|
||||
# Case 1. A general case.
|
||||
([(129, 871), (18, 280), (37, 988), (1023, 2304), (1, 257)], 256),
|
||||
# Case 2. Flash-decoding case.
|
||||
([(1, 1023), (1, 879), (1, 778), (1, 1777)] * 100, 512),
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("seq_lens_and_common_prefix", CASES)
|
||||
@pytest.mark.parametrize("num_heads", NUM_HEADS)
|
||||
@pytest.mark.parametrize("head_size", HEAD_SIZES)
|
||||
@pytest.mark.parametrize("dtype", DTYPES)
|
||||
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
|
||||
@pytest.mark.parametrize("soft_cap", [None, 50])
|
||||
@pytest.mark.parametrize("num_blocks", [2048])
|
||||
@pytest.mark.parametrize("fa_version", [2, 3])
|
||||
@torch.inference_mode()
|
||||
def test_cascade(
|
||||
seq_lens_and_common_prefix: tuple[list[tuple[int, int]], int],
|
||||
num_heads: tuple[int, int],
|
||||
head_size: int,
|
||||
dtype: torch.dtype,
|
||||
block_size: int,
|
||||
soft_cap: float | None,
|
||||
num_blocks: int,
|
||||
fa_version: int,
|
||||
) -> None:
|
||||
torch.set_default_device("cuda")
|
||||
if not is_fa_version_supported(fa_version):
|
||||
pytest.skip(
|
||||
f"Flash attention version {fa_version} not supported due "
|
||||
f'to: "{fa_version_unsupported_reason(fa_version)}"'
|
||||
)
|
||||
|
||||
current_platform.seed_everything(0)
|
||||
|
||||
window_size = (-1, -1)
|
||||
scale = head_size**-0.5
|
||||
num_query_heads = num_heads[0]
|
||||
num_kv_heads = num_heads[1]
|
||||
assert num_query_heads % num_kv_heads == 0
|
||||
key_cache = torch.randn(
|
||||
num_blocks, block_size, num_kv_heads, head_size, dtype=dtype
|
||||
)
|
||||
value_cache = torch.randn_like(key_cache)
|
||||
|
||||
seq_lens, common_prefix_len = seq_lens_and_common_prefix
|
||||
num_seqs = len(seq_lens)
|
||||
query_lens = [x[0] for x in seq_lens]
|
||||
kv_lens = [x[1] for x in seq_lens]
|
||||
max_query_len = max(query_lens)
|
||||
max_kv_len = max(kv_lens)
|
||||
|
||||
total_num_query_tokens = sum(query_lens)
|
||||
query = torch.randn(total_num_query_tokens, num_query_heads, head_size, dtype=dtype)
|
||||
cu_query_lens = torch.tensor([0] + query_lens, dtype=torch.int32).cumsum(
|
||||
dim=0, dtype=torch.int32
|
||||
)
|
||||
kv_lens_tensor = torch.tensor(kv_lens, dtype=torch.int32)
|
||||
max_num_blocks_per_seq = (max_kv_len + block_size - 1) // block_size
|
||||
block_tables = torch.randint(
|
||||
0, num_blocks, (num_seqs, max_num_blocks_per_seq), dtype=torch.int32
|
||||
)
|
||||
|
||||
assert common_prefix_len > 0
|
||||
assert common_prefix_len % block_size == 0
|
||||
num_common_kv_blocks = common_prefix_len // block_size
|
||||
# Make sure the first `num_common_kv_blocks` blocks are the same.
|
||||
block_tables[:, :num_common_kv_blocks] = block_tables[0, :num_common_kv_blocks]
|
||||
|
||||
# Run the regular attention.
|
||||
ref_output = flash_attn_varlen_func(
|
||||
q=query,
|
||||
k=key_cache,
|
||||
v=value_cache,
|
||||
cu_seqlens_q=cu_query_lens,
|
||||
seqused_k=kv_lens_tensor,
|
||||
max_seqlen_q=max_query_len,
|
||||
max_seqlen_k=max_kv_len,
|
||||
softmax_scale=scale,
|
||||
causal=True,
|
||||
window_size=window_size,
|
||||
block_table=block_tables,
|
||||
softcap=soft_cap if soft_cap is not None else 0,
|
||||
)
|
||||
|
||||
# Run cascade attention.
|
||||
assert all(common_prefix_len < kv_len for kv_len in kv_lens)
|
||||
cu_prefix_query_lens = torch.tensor([0, total_num_query_tokens], dtype=torch.int32)
|
||||
prefix_kv_lens = torch.tensor([common_prefix_len], dtype=torch.int32)
|
||||
suffix_kv_lens = kv_lens_tensor - common_prefix_len
|
||||
output = torch.empty_like(query)
|
||||
cascade_attention(
|
||||
output=output,
|
||||
query=query,
|
||||
key_cache=key_cache,
|
||||
value_cache=value_cache,
|
||||
cu_query_lens=cu_query_lens,
|
||||
max_query_len=max_query_len,
|
||||
cu_prefix_query_lens=cu_prefix_query_lens,
|
||||
prefix_kv_lens=prefix_kv_lens,
|
||||
suffix_kv_lens=suffix_kv_lens,
|
||||
max_kv_len=max_kv_len,
|
||||
softmax_scale=scale,
|
||||
alibi_slopes=None,
|
||||
sliding_window=window_size,
|
||||
logits_soft_cap=soft_cap if soft_cap is not None else 0,
|
||||
block_table=block_tables,
|
||||
common_prefix_len=common_prefix_len,
|
||||
max_num_splits=0, # no max
|
||||
fa_version=fa_version,
|
||||
)
|
||||
|
||||
# Compare the results.
|
||||
torch.testing.assert_close(output, ref_output, atol=1e-2, rtol=1e-2)
|
||||
628
tests/kernels/attention/test_cpu_attn.py
Normal file
628
tests/kernels/attention/test_cpu_attn.py
Normal file
@@ -0,0 +1,628 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import functools
|
||||
import math
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from vllm.platforms import CpuArchEnum, current_platform
|
||||
from vllm.v1.attention.backends.cpu_attn import _get_attn_isa
|
||||
|
||||
if not current_platform.is_cpu():
|
||||
pytest.skip("skipping CPU-only tests", allow_module_level=True)
|
||||
|
||||
from vllm._custom_ops import (
|
||||
cpu_attention_with_kv_cache,
|
||||
cpu_attn_get_scheduler_metadata,
|
||||
cpu_attn_reshape_and_cache,
|
||||
)
|
||||
|
||||
NUM_HEADS = [
|
||||
(4, 4),
|
||||
(8, 2),
|
||||
(9, 3),
|
||||
]
|
||||
HEAD_SIZES = [96, 128]
|
||||
QTYPES = [torch.bfloat16, torch.half, torch.float32]
|
||||
SLIDING_WINDOWS = [None, 256]
|
||||
NUM_BLOCKS = [
|
||||
1024,
|
||||
]
|
||||
SEQ_LENS = [ # (q_len, kv_len)
|
||||
[(1, 213), (1, 1), (1, 312), (1, 7), (1, 7812)], # decode batch
|
||||
[(2345, 2345), (5, 5), (3, 16), (134, 5131)], # prefill batch
|
||||
[(992, 2456), (1, 1234), (98, 1145), (1, 4162), (2345, 2345)], # mixed batch
|
||||
]
|
||||
|
||||
|
||||
def get_attn_isa(
|
||||
block_size: int | None = None,
|
||||
dtype: torch.dtype | None = None,
|
||||
):
|
||||
if block_size and dtype:
|
||||
return _get_attn_isa(dtype, block_size)
|
||||
else:
|
||||
if current_platform.get_cpu_architecture() == CpuArchEnum.ARM:
|
||||
return "neon"
|
||||
elif torch._C._cpu._is_amx_tile_supported():
|
||||
return "amx"
|
||||
else:
|
||||
return "vec"
|
||||
|
||||
|
||||
# rand number generation takes too much time, cache rand tensors
|
||||
@functools.lru_cache(maxsize=128, typed=False)
|
||||
def tensor_cache(
|
||||
elem_num: int,
|
||||
dtype: torch.dtype,
|
||||
) -> torch.Tensor:
|
||||
tensor = torch.randn(elem_num, dtype=dtype)
|
||||
|
||||
return tensor
|
||||
|
||||
|
||||
def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor:
|
||||
closest_power_of_2 = 2 ** math.floor(math.log2(total_num_heads))
|
||||
base = torch.tensor(
|
||||
2 ** (-(2 ** -(math.log2(closest_power_of_2) - 3))),
|
||||
dtype=torch.float32,
|
||||
)
|
||||
powers = torch.arange(1, 1 + closest_power_of_2, dtype=torch.int32)
|
||||
slopes = torch.pow(base, powers)
|
||||
|
||||
if closest_power_of_2 != total_num_heads:
|
||||
extra_base = torch.tensor(
|
||||
2 ** (-(2 ** -(math.log2(2 * closest_power_of_2) - 3))),
|
||||
dtype=torch.float32,
|
||||
)
|
||||
num_remaining_heads = min(
|
||||
closest_power_of_2, total_num_heads - closest_power_of_2
|
||||
)
|
||||
extra_powers = torch.arange(
|
||||
start=1, end=1 + 2 * num_remaining_heads, step=2, dtype=torch.int32
|
||||
)
|
||||
slopes = torch.cat([slopes, torch.pow(extra_base, extra_powers)], dim=0)
|
||||
return slopes.float()
|
||||
|
||||
|
||||
def ref_paged_attn(
|
||||
query: torch.Tensor,
|
||||
key_cache: torch.Tensor,
|
||||
value_cache: torch.Tensor,
|
||||
query_lens: list[int],
|
||||
kv_lens: list[int],
|
||||
block_tables: torch.Tensor,
|
||||
scale: float,
|
||||
sliding_window: int | None = None,
|
||||
soft_cap: float | None = None,
|
||||
alibi_slopes: torch.Tensor | None = None,
|
||||
s_aux: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
num_seqs = len(query_lens)
|
||||
block_tables = block_tables.cpu().numpy()
|
||||
_, block_size, num_kv_heads, head_size = key_cache.shape
|
||||
dtype = query.dtype
|
||||
|
||||
outputs: list[torch.Tensor] = []
|
||||
start_idx = 0
|
||||
|
||||
if alibi_slopes is not None:
|
||||
alibi_slopes = alibi_slopes[:, None, None]
|
||||
|
||||
if s_aux is not None:
|
||||
s_aux = s_aux.float()
|
||||
s_aux = s_aux[:, None, None]
|
||||
|
||||
for i in range(num_seqs):
|
||||
query_len = query_lens[i]
|
||||
kv_len = kv_lens[i]
|
||||
q = query[start_idx : start_idx + query_len].float()
|
||||
q *= scale
|
||||
|
||||
num_kv_blocks = (kv_len + block_size - 1) // block_size
|
||||
block_indices = block_tables[i, :num_kv_blocks]
|
||||
|
||||
k = key_cache[block_indices].view(-1, num_kv_heads, head_size)
|
||||
k = k[:kv_len].float()
|
||||
v = value_cache[block_indices].view(-1, num_kv_heads, head_size)
|
||||
v = v[:kv_len].float()
|
||||
|
||||
if q.shape[1] != k.shape[1]:
|
||||
k = torch.repeat_interleave(k, q.shape[1] // k.shape[1], dim=1)
|
||||
v = torch.repeat_interleave(v, q.shape[1] // v.shape[1], dim=1)
|
||||
attn = torch.einsum("qhd,khd->hqk", q, k).float()
|
||||
empty_mask = torch.ones(query_len, kv_len)
|
||||
mask = torch.triu(empty_mask, diagonal=kv_len - query_len + 1).bool()
|
||||
|
||||
if sliding_window is not None:
|
||||
sliding_window_mask = (
|
||||
torch.triu(
|
||||
empty_mask, diagonal=kv_len - (query_len + sliding_window) + 1
|
||||
)
|
||||
.bool()
|
||||
.logical_not()
|
||||
)
|
||||
mask |= sliding_window_mask
|
||||
|
||||
if soft_cap is not None:
|
||||
attn = soft_cap * torch.tanh(attn / soft_cap)
|
||||
|
||||
if alibi_slopes is not None:
|
||||
q_start_pos = kv_len - query_len
|
||||
q_pos = q_start_pos + torch.arange(0, query_len)[None, :, None]
|
||||
kv_pos = torch.arange(0, kv_len)[None, None, :]
|
||||
dist = q_pos - kv_pos
|
||||
alibi_bias = -alibi_slopes * dist
|
||||
attn += alibi_bias
|
||||
|
||||
attn.masked_fill_(mask, float("-inf"))
|
||||
|
||||
if s_aux is not None:
|
||||
s_aux_ext = s_aux.repeat(1, query_len, 1)
|
||||
attn = torch.cat((s_aux_ext, attn), dim=-1)
|
||||
|
||||
attn = torch.softmax(attn, dim=-1)
|
||||
|
||||
if s_aux is not None:
|
||||
attn = attn[:, :, 1:]
|
||||
|
||||
out = torch.einsum("hqk,khd->qhd", attn, v).to(dtype=dtype)
|
||||
|
||||
outputs.append(out)
|
||||
start_idx += query_len
|
||||
|
||||
return torch.cat(outputs, dim=0)
|
||||
|
||||
|
||||
@torch.inference_mode()
|
||||
def varlen_with_paged_kv(
|
||||
seq_lens: list[tuple[int, int]],
|
||||
num_heads: tuple[int, int],
|
||||
head_size: int,
|
||||
sliding_window: int | None,
|
||||
dtype: torch.dtype,
|
||||
block_size: int,
|
||||
soft_cap: float | None,
|
||||
num_blocks: int,
|
||||
use_alibi: bool,
|
||||
use_sink: bool,
|
||||
isa: str,
|
||||
) -> None:
|
||||
current_platform.seed_everything(0)
|
||||
num_seqs = len(seq_lens)
|
||||
query_lens = [x[0] for x in seq_lens]
|
||||
kv_lens = [x[1] for x in seq_lens]
|
||||
num_query_heads = num_heads[0]
|
||||
num_kv_heads = num_heads[1]
|
||||
assert num_query_heads % num_kv_heads == 0
|
||||
max_kv_len = max(kv_lens)
|
||||
window_size = (sliding_window - 1, 0) if sliding_window is not None else (-1, -1)
|
||||
scale = head_size**-0.5
|
||||
token_num = sum(query_lens)
|
||||
|
||||
# for n heads the set of slopes is the geometric sequence that starts
|
||||
# 2^(-8/n)
|
||||
alibi_slopes = _get_alibi_slopes(num_query_heads) if use_alibi else None
|
||||
|
||||
s_aux = (
|
||||
15 * torch.rand((num_query_heads,), dtype=torch.bfloat16) if use_sink else None
|
||||
)
|
||||
|
||||
query = tensor_cache(
|
||||
elem_num=token_num * num_query_heads * head_size,
|
||||
dtype=dtype,
|
||||
)
|
||||
query = query.view(
|
||||
token_num,
|
||||
num_query_heads,
|
||||
head_size,
|
||||
)
|
||||
|
||||
key_value = tensor_cache(
|
||||
elem_num=2 * num_blocks * num_kv_heads * block_size * head_size,
|
||||
dtype=dtype,
|
||||
)
|
||||
key_value = key_value.view(
|
||||
2,
|
||||
num_blocks,
|
||||
block_size,
|
||||
num_kv_heads,
|
||||
head_size,
|
||||
)
|
||||
key_cache, value_cache = key_value.unbind(0)
|
||||
|
||||
# KV cache for CPU attention
|
||||
packed_key_cache = torch.empty(
|
||||
num_blocks, num_kv_heads, block_size, head_size, dtype=dtype
|
||||
)
|
||||
packed_value_cache = torch.empty_like(packed_key_cache)
|
||||
|
||||
cu_query_lens = torch.tensor([0] + query_lens, dtype=torch.int32).cumsum(
|
||||
dim=0, dtype=torch.int32
|
||||
)
|
||||
kv_lens_tensor = torch.tensor(kv_lens, dtype=torch.int32)
|
||||
max_num_blocks_per_seq = (max_kv_len + block_size - 1) // block_size
|
||||
block_tables = torch.randint(
|
||||
0, num_blocks, (num_seqs, max_num_blocks_per_seq), dtype=torch.int32
|
||||
)
|
||||
|
||||
# use reshape_and_cache to pack key_cache and value_cache
|
||||
slot_mapping = torch.arange(0, num_blocks * block_size, dtype=torch.int64)
|
||||
cpu_attn_reshape_and_cache(
|
||||
key=key_cache.view(-1, num_kv_heads, head_size),
|
||||
value=value_cache.view(-1, num_kv_heads, head_size),
|
||||
key_cache=packed_key_cache,
|
||||
value_cache=packed_value_cache,
|
||||
slot_mapping=slot_mapping,
|
||||
isa=isa,
|
||||
)
|
||||
|
||||
metadata = cpu_attn_get_scheduler_metadata(
|
||||
num_reqs=num_seqs,
|
||||
num_heads=num_query_heads,
|
||||
num_kv_heads=num_kv_heads,
|
||||
head_dim=head_size,
|
||||
seq_lens=kv_lens_tensor,
|
||||
dtype=dtype,
|
||||
query_start_loc=cu_query_lens,
|
||||
causal=True,
|
||||
sliding_window_size=sliding_window if sliding_window is not None else -1,
|
||||
isa=isa,
|
||||
enable_kv_split=False,
|
||||
)
|
||||
|
||||
out_without_split = torch.empty_like(query)
|
||||
cpu_attention_with_kv_cache(
|
||||
query=query,
|
||||
key_cache=packed_key_cache,
|
||||
value_cache=packed_value_cache,
|
||||
output=out_without_split,
|
||||
query_start_loc=cu_query_lens,
|
||||
seq_lens=kv_lens_tensor,
|
||||
scale=scale,
|
||||
causal=True,
|
||||
alibi_slopes=alibi_slopes,
|
||||
sliding_window=window_size,
|
||||
block_table=block_tables,
|
||||
softcap=soft_cap if soft_cap is not None else 0,
|
||||
scheduler_metadata=metadata,
|
||||
s_aux=s_aux,
|
||||
)
|
||||
|
||||
metadata = cpu_attn_get_scheduler_metadata(
|
||||
num_reqs=num_seqs,
|
||||
num_heads=num_query_heads,
|
||||
num_kv_heads=num_kv_heads,
|
||||
head_dim=head_size,
|
||||
seq_lens=kv_lens_tensor,
|
||||
dtype=dtype,
|
||||
query_start_loc=cu_query_lens,
|
||||
causal=True,
|
||||
sliding_window_size=sliding_window if sliding_window is not None else -1,
|
||||
isa=isa,
|
||||
enable_kv_split=True,
|
||||
)
|
||||
|
||||
out_with_split = torch.empty_like(query)
|
||||
cpu_attention_with_kv_cache(
|
||||
query=query,
|
||||
key_cache=packed_key_cache,
|
||||
value_cache=packed_value_cache,
|
||||
output=out_with_split,
|
||||
query_start_loc=cu_query_lens,
|
||||
seq_lens=kv_lens_tensor,
|
||||
scale=scale,
|
||||
causal=True,
|
||||
alibi_slopes=alibi_slopes,
|
||||
sliding_window=window_size,
|
||||
block_table=block_tables,
|
||||
softcap=soft_cap if soft_cap is not None else 0,
|
||||
scheduler_metadata=metadata,
|
||||
s_aux=s_aux,
|
||||
)
|
||||
|
||||
ref_output = ref_paged_attn(
|
||||
query=query,
|
||||
key_cache=key_cache,
|
||||
value_cache=value_cache,
|
||||
query_lens=query_lens,
|
||||
kv_lens=kv_lens,
|
||||
block_tables=block_tables,
|
||||
scale=scale,
|
||||
sliding_window=sliding_window,
|
||||
soft_cap=soft_cap,
|
||||
alibi_slopes=alibi_slopes,
|
||||
s_aux=s_aux,
|
||||
)
|
||||
|
||||
atol, rtol = 1.5e-2, 1e-2
|
||||
(
|
||||
torch.testing.assert_close(out_with_split, ref_output, atol=atol, rtol=rtol),
|
||||
f"{torch.max(torch.abs(out_with_split - ref_output))}",
|
||||
)
|
||||
(
|
||||
torch.testing.assert_close(out_without_split, ref_output, atol=atol, rtol=rtol),
|
||||
f"{torch.max(torch.abs(out_without_split - ref_output))}",
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("seq_lens", SEQ_LENS)
|
||||
@pytest.mark.parametrize("num_heads", NUM_HEADS)
|
||||
@pytest.mark.parametrize("head_size", HEAD_SIZES)
|
||||
@pytest.mark.parametrize("block_size", [96, 128])
|
||||
@pytest.mark.parametrize("sliding_window", SLIDING_WINDOWS)
|
||||
@pytest.mark.parametrize("dtype", QTYPES)
|
||||
@pytest.mark.parametrize("soft_cap", [None])
|
||||
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
|
||||
@pytest.mark.parametrize("use_alibi", [False])
|
||||
@pytest.mark.parametrize("use_sink", [False])
|
||||
@pytest.mark.parametrize("isa", ["vec"])
|
||||
def test_varlen_with_paged_kv_normal_vec(
|
||||
seq_lens: list[tuple[int, int]],
|
||||
num_heads: tuple[int, int],
|
||||
head_size: int,
|
||||
sliding_window: int | None,
|
||||
dtype: torch.dtype,
|
||||
block_size: int,
|
||||
soft_cap: float | None,
|
||||
num_blocks: int,
|
||||
use_alibi: bool,
|
||||
use_sink: bool,
|
||||
isa: str,
|
||||
) -> None:
|
||||
varlen_with_paged_kv(
|
||||
seq_lens=seq_lens,
|
||||
num_heads=num_heads,
|
||||
head_size=head_size,
|
||||
sliding_window=sliding_window,
|
||||
dtype=dtype,
|
||||
block_size=block_size,
|
||||
soft_cap=soft_cap,
|
||||
num_blocks=num_blocks,
|
||||
use_alibi=use_alibi,
|
||||
use_sink=use_sink,
|
||||
isa=isa,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("seq_lens", SEQ_LENS)
|
||||
@pytest.mark.parametrize("num_heads", NUM_HEADS)
|
||||
@pytest.mark.parametrize("head_size", HEAD_SIZES)
|
||||
@pytest.mark.parametrize("block_size", [96, 128])
|
||||
@pytest.mark.parametrize("sliding_window", SLIDING_WINDOWS)
|
||||
@pytest.mark.parametrize("dtype", [torch.bfloat16])
|
||||
@pytest.mark.parametrize("soft_cap", [None])
|
||||
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
|
||||
@pytest.mark.parametrize("use_alibi", [False])
|
||||
@pytest.mark.parametrize("use_sink", [False])
|
||||
@pytest.mark.parametrize("isa", ["amx"])
|
||||
@pytest.mark.skipif(
|
||||
not torch._C._cpu._is_amx_tile_supported(), reason="no AMX support."
|
||||
)
|
||||
def test_varlen_with_paged_kv_normal_amx(
|
||||
seq_lens: list[tuple[int, int]],
|
||||
num_heads: tuple[int, int],
|
||||
head_size: int,
|
||||
sliding_window: int | None,
|
||||
dtype: torch.dtype,
|
||||
block_size: int,
|
||||
soft_cap: float | None,
|
||||
num_blocks: int,
|
||||
use_alibi: bool,
|
||||
use_sink: bool,
|
||||
isa: str,
|
||||
) -> None:
|
||||
varlen_with_paged_kv(
|
||||
seq_lens=seq_lens,
|
||||
num_heads=num_heads,
|
||||
head_size=head_size,
|
||||
sliding_window=sliding_window,
|
||||
dtype=dtype,
|
||||
block_size=block_size,
|
||||
soft_cap=soft_cap,
|
||||
num_blocks=num_blocks,
|
||||
use_alibi=use_alibi,
|
||||
use_sink=use_sink,
|
||||
isa=isa,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("seq_lens", SEQ_LENS)
|
||||
@pytest.mark.parametrize("num_heads", NUM_HEADS)
|
||||
@pytest.mark.parametrize("head_size", HEAD_SIZES)
|
||||
@pytest.mark.parametrize("block_size", [48])
|
||||
@pytest.mark.parametrize("sliding_window", SLIDING_WINDOWS)
|
||||
@pytest.mark.parametrize("dtype", [torch.bfloat16])
|
||||
@pytest.mark.parametrize("soft_cap", [None])
|
||||
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
|
||||
@pytest.mark.parametrize("use_alibi", [False])
|
||||
@pytest.mark.parametrize("use_sink", [False])
|
||||
@pytest.mark.parametrize("isa", ["vec16"])
|
||||
def test_varlen_with_paged_kv_normal_vec16(
|
||||
seq_lens: list[tuple[int, int]],
|
||||
num_heads: tuple[int, int],
|
||||
head_size: int,
|
||||
sliding_window: int | None,
|
||||
dtype: torch.dtype,
|
||||
block_size: int,
|
||||
soft_cap: float | None,
|
||||
num_blocks: int,
|
||||
use_alibi: bool,
|
||||
use_sink: bool,
|
||||
isa: str,
|
||||
) -> None:
|
||||
varlen_with_paged_kv(
|
||||
seq_lens=seq_lens,
|
||||
num_heads=num_heads,
|
||||
head_size=head_size,
|
||||
sliding_window=sliding_window,
|
||||
dtype=dtype,
|
||||
block_size=block_size,
|
||||
soft_cap=soft_cap,
|
||||
num_blocks=num_blocks,
|
||||
use_alibi=use_alibi,
|
||||
use_sink=use_sink,
|
||||
isa=isa,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("seq_lens", SEQ_LENS)
|
||||
@pytest.mark.parametrize("num_heads", NUM_HEADS)
|
||||
@pytest.mark.parametrize("head_size", HEAD_SIZES)
|
||||
@pytest.mark.parametrize("block_size", [96, 128])
|
||||
@pytest.mark.parametrize("sliding_window", SLIDING_WINDOWS)
|
||||
@pytest.mark.parametrize("dtype", QTYPES)
|
||||
@pytest.mark.parametrize("soft_cap", [None])
|
||||
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
|
||||
@pytest.mark.parametrize("use_alibi", [False])
|
||||
@pytest.mark.parametrize("use_sink", [False])
|
||||
@pytest.mark.parametrize("isa", ["neon"])
|
||||
@pytest.mark.skipif(
|
||||
current_platform.get_cpu_architecture() != CpuArchEnum.ARM,
|
||||
reason="Not an Arm CPU.",
|
||||
)
|
||||
def test_varlen_with_paged_kv_normal_neon(
|
||||
seq_lens: list[tuple[int, int]],
|
||||
num_heads: tuple[int, int],
|
||||
head_size: int,
|
||||
sliding_window: int | None,
|
||||
dtype: torch.dtype,
|
||||
block_size: int,
|
||||
soft_cap: float | None,
|
||||
num_blocks: int,
|
||||
use_alibi: bool,
|
||||
use_sink: bool,
|
||||
isa: str,
|
||||
) -> None:
|
||||
varlen_with_paged_kv(
|
||||
seq_lens=seq_lens,
|
||||
num_heads=num_heads,
|
||||
head_size=head_size,
|
||||
sliding_window=sliding_window,
|
||||
dtype=dtype,
|
||||
block_size=block_size,
|
||||
soft_cap=soft_cap,
|
||||
num_blocks=num_blocks,
|
||||
use_alibi=use_alibi,
|
||||
use_sink=use_sink,
|
||||
isa=isa,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("seq_lens", SEQ_LENS)
|
||||
@pytest.mark.parametrize("num_heads", NUM_HEADS)
|
||||
@pytest.mark.parametrize("head_size", [96])
|
||||
@pytest.mark.parametrize("block_size", [128])
|
||||
@pytest.mark.parametrize("sliding_window", SLIDING_WINDOWS)
|
||||
@pytest.mark.parametrize("dtype", [torch.bfloat16])
|
||||
@pytest.mark.parametrize("soft_cap", [50])
|
||||
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
|
||||
@pytest.mark.parametrize("use_alibi", [False])
|
||||
@pytest.mark.parametrize("use_sink", [False])
|
||||
@pytest.mark.parametrize("isa", [get_attn_isa()])
|
||||
def test_varlen_with_paged_kv_softcap(
|
||||
seq_lens: list[tuple[int, int]],
|
||||
num_heads: tuple[int, int],
|
||||
head_size: int,
|
||||
sliding_window: int | None,
|
||||
dtype: torch.dtype,
|
||||
block_size: int,
|
||||
soft_cap: float | None,
|
||||
num_blocks: int,
|
||||
use_alibi: bool,
|
||||
use_sink: bool,
|
||||
isa: str,
|
||||
) -> None:
|
||||
varlen_with_paged_kv(
|
||||
seq_lens=seq_lens,
|
||||
num_heads=num_heads,
|
||||
head_size=head_size,
|
||||
sliding_window=sliding_window,
|
||||
dtype=dtype,
|
||||
block_size=block_size,
|
||||
soft_cap=soft_cap,
|
||||
num_blocks=num_blocks,
|
||||
use_alibi=use_alibi,
|
||||
use_sink=use_sink,
|
||||
isa=isa,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("seq_lens", SEQ_LENS)
|
||||
@pytest.mark.parametrize("num_heads", NUM_HEADS)
|
||||
@pytest.mark.parametrize("head_size", [96])
|
||||
@pytest.mark.parametrize("block_size", [128])
|
||||
@pytest.mark.parametrize("sliding_window", SLIDING_WINDOWS)
|
||||
@pytest.mark.parametrize("dtype", [torch.bfloat16])
|
||||
@pytest.mark.parametrize("soft_cap", [None])
|
||||
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
|
||||
@pytest.mark.parametrize("use_alibi", [True])
|
||||
@pytest.mark.parametrize("use_sink", [False])
|
||||
@pytest.mark.parametrize("isa", [get_attn_isa()])
|
||||
def test_varlen_with_paged_kv_alibi(
|
||||
seq_lens: list[tuple[int, int]],
|
||||
num_heads: tuple[int, int],
|
||||
head_size: int,
|
||||
sliding_window: int | None,
|
||||
dtype: torch.dtype,
|
||||
block_size: int,
|
||||
soft_cap: float | None,
|
||||
num_blocks: int,
|
||||
use_alibi: bool,
|
||||
use_sink: bool,
|
||||
isa: str,
|
||||
) -> None:
|
||||
varlen_with_paged_kv(
|
||||
seq_lens=seq_lens,
|
||||
num_heads=num_heads,
|
||||
head_size=head_size,
|
||||
sliding_window=sliding_window,
|
||||
dtype=dtype,
|
||||
block_size=block_size,
|
||||
soft_cap=soft_cap,
|
||||
num_blocks=num_blocks,
|
||||
use_alibi=use_alibi,
|
||||
use_sink=use_sink,
|
||||
isa=isa,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("seq_lens", SEQ_LENS)
|
||||
@pytest.mark.parametrize("num_heads", NUM_HEADS)
|
||||
@pytest.mark.parametrize("head_size", [96])
|
||||
@pytest.mark.parametrize("block_size", [128])
|
||||
@pytest.mark.parametrize("sliding_window", SLIDING_WINDOWS)
|
||||
@pytest.mark.parametrize("dtype", [torch.bfloat16])
|
||||
@pytest.mark.parametrize("soft_cap", [None])
|
||||
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
|
||||
@pytest.mark.parametrize("use_alibi", [False])
|
||||
@pytest.mark.parametrize("use_sink", [True])
|
||||
@pytest.mark.parametrize("isa", [get_attn_isa()])
|
||||
def test_varlen_with_paged_kv_sink(
|
||||
seq_lens: list[tuple[int, int]],
|
||||
num_heads: tuple[int, int],
|
||||
head_size: int,
|
||||
sliding_window: int | None,
|
||||
dtype: torch.dtype,
|
||||
block_size: int,
|
||||
soft_cap: float | None,
|
||||
num_blocks: int,
|
||||
use_alibi: bool,
|
||||
use_sink: bool,
|
||||
isa: str,
|
||||
) -> None:
|
||||
varlen_with_paged_kv(
|
||||
seq_lens=seq_lens,
|
||||
num_heads=num_heads,
|
||||
head_size=head_size,
|
||||
sliding_window=sliding_window,
|
||||
dtype=dtype,
|
||||
block_size=block_size,
|
||||
soft_cap=soft_cap,
|
||||
num_blocks=num_blocks,
|
||||
use_alibi=use_alibi,
|
||||
use_sink=use_sink,
|
||||
isa=isa,
|
||||
)
|
||||
214
tests/kernels/attention/test_cutlass_mla_decode.py
Normal file
214
tests/kernels/attention/test_cutlass_mla_decode.py
Normal file
@@ -0,0 +1,214 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import math
|
||||
import random
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
import vllm._custom_ops as ops
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.triton_utils import triton
|
||||
|
||||
|
||||
def cal_diff(
|
||||
x: torch.Tensor,
|
||||
y: torch.Tensor,
|
||||
name: str,
|
||||
use_fp8: bool = False,
|
||||
diff_threshold: float | None = None,
|
||||
) -> None:
|
||||
x, y = x.double(), y.double()
|
||||
cos_diff = 1 - 2 * (x * y).sum().item() / max((x * x + y * y).sum().item(), 1e-12)
|
||||
if diff_threshold is not None:
|
||||
# directly compare the cos_diff with the threshold
|
||||
assert cos_diff < diff_threshold
|
||||
else:
|
||||
# use the default threshold
|
||||
if use_fp8:
|
||||
assert cos_diff < 1e-4
|
||||
else:
|
||||
assert cos_diff < 1e-5
|
||||
|
||||
|
||||
CUTLASS_MLA_UNSUPPORTED_REASON = (
|
||||
"Cutlass MLA Requires compute capability of 100 or above."
|
||||
if not current_platform.is_device_capability_family(100)
|
||||
else "Cutlass MLA is supported"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
not current_platform.has_device_capability(100),
|
||||
reason=CUTLASS_MLA_UNSUPPORTED_REASON,
|
||||
)
|
||||
@pytest.mark.parametrize("b", [128])
|
||||
@pytest.mark.parametrize("s_q", [1])
|
||||
@pytest.mark.parametrize("mean_sk", [4096, 8192, 16384])
|
||||
@pytest.mark.parametrize("h_q", [16, 32, 64, 128])
|
||||
@pytest.mark.parametrize("h_kv", [1])
|
||||
@pytest.mark.parametrize("d", [576])
|
||||
@pytest.mark.parametrize("dv", [512])
|
||||
@pytest.mark.parametrize("block_size", [64])
|
||||
@pytest.mark.parametrize("causal", [True])
|
||||
@pytest.mark.parametrize("varlen", [False, True])
|
||||
@pytest.mark.parametrize(
|
||||
"torch_dtype",
|
||||
[
|
||||
torch.bfloat16,
|
||||
# fp8 can have occasional precision-related failures.
|
||||
pytest.param(torch.float8_e4m3fn, marks=pytest.mark.flaky(reruns=2)),
|
||||
],
|
||||
)
|
||||
@torch.inference_mode()
|
||||
def test_cutlass_mla_decode(
|
||||
b, s_q, mean_sk, h_q, h_kv, d, dv, block_size, causal, varlen, torch_dtype
|
||||
):
|
||||
device = torch.device("cuda:0")
|
||||
init_dtype = torch.bfloat16 if torch_dtype == torch.float8_e4m3fn else torch_dtype
|
||||
torch.set_default_dtype(init_dtype)
|
||||
torch.set_default_device(device)
|
||||
torch.cuda.set_device(device)
|
||||
torch.manual_seed(42)
|
||||
random.seed(42)
|
||||
|
||||
print(
|
||||
f"{b=}, {s_q=}, {mean_sk=}, {h_q=}, {h_kv=}, "
|
||||
f"{d=}, {dv=}, {causal=}, {varlen=}, {torch_dtype=}"
|
||||
)
|
||||
|
||||
use_fp8 = torch_dtype == torch.float8_e4m3fn
|
||||
scale = math.sqrt(d) ** (-1)
|
||||
cache_seqlens = torch.full((b,), mean_sk, dtype=torch.int32)
|
||||
if varlen:
|
||||
for i in range(b):
|
||||
cache_seqlens[i] = max(random.normalvariate(mean_sk, mean_sk / 2), s_q)
|
||||
total_seqlens = cache_seqlens.sum().item()
|
||||
max_seqlen = cache_seqlens.max().item()
|
||||
max_seqlen_pad = triton.cdiv(max_seqlen, 256) * 256
|
||||
|
||||
q = torch.randn(b, s_q, h_q, d)
|
||||
block_table = torch.arange(
|
||||
b * max_seqlen_pad // block_size, dtype=torch.int32
|
||||
).view(b, max_seqlen_pad // block_size)
|
||||
blocked_k = torch.randn(block_table.numel(), block_size, h_kv, d)
|
||||
blocked_v = blocked_k[..., :dv]
|
||||
|
||||
init_dtype = q.dtype
|
||||
if use_fp8:
|
||||
fp8_dtype = torch.float8_e4m3fn
|
||||
descale_q = torch.ones((1), dtype=torch.float32)
|
||||
descale_k = torch.ones((1), dtype=torch.float32)
|
||||
|
||||
q = q.to(fp8_dtype)
|
||||
blocked_k = blocked_k.to(fp8_dtype)
|
||||
blocked_v = blocked_v.to(fp8_dtype)
|
||||
else:
|
||||
descale_q = None
|
||||
descale_k = None
|
||||
|
||||
def cutlass_mla():
|
||||
MAX_HEADS = 128
|
||||
|
||||
q_reshaped = q.squeeze(1)
|
||||
q_nope = q_reshaped[:, :, :dv].clone()
|
||||
q_pe = q_reshaped[:, :, dv:].clone()
|
||||
|
||||
if h_q < MAX_HEADS:
|
||||
q_nope_padded = q_nope.new_empty((b, MAX_HEADS, dv))
|
||||
q_nope_padded[:, :h_q] = q_nope
|
||||
q_nope = q_nope_padded
|
||||
|
||||
q_pe_padded = q_pe.new_empty((b, MAX_HEADS, d - dv))
|
||||
q_pe_padded[:, :h_q] = q_pe
|
||||
q_pe = q_pe_padded
|
||||
|
||||
kv_cache_flat = blocked_k.squeeze(2)
|
||||
device_properties = torch.cuda.get_device_properties(torch.device("cuda:0"))
|
||||
sm_count = device_properties.multi_processor_count
|
||||
workspace_size = ops.sm100_cutlass_mla_get_workspace_size(
|
||||
max_seqlen * block_size, b, sm_count, num_kv_splits=1
|
||||
)
|
||||
workspace = torch.empty(workspace_size, device="cuda", dtype=torch.uint8)
|
||||
|
||||
out_ans = torch.empty(b, MAX_HEADS, dv, dtype=init_dtype)
|
||||
output_lse = torch.empty(
|
||||
(b, MAX_HEADS), dtype=torch.float32, device=q_nope.device
|
||||
)
|
||||
ops.sm100_cutlass_mla_decode(
|
||||
out_ans,
|
||||
output_lse,
|
||||
q_nope,
|
||||
q_pe,
|
||||
kv_cache_flat,
|
||||
cache_seqlens,
|
||||
block_table,
|
||||
workspace,
|
||||
scale,
|
||||
1,
|
||||
)
|
||||
return out_ans[:, :h_q].contiguous(), output_lse[:, :h_q].contiguous()
|
||||
|
||||
def scaled_dot_product_attention(query, key, value, is_causal=False):
|
||||
query = query.float()
|
||||
key = key.float()
|
||||
value = value.float()
|
||||
key = key.repeat_interleave(h_q // h_kv, dim=0)
|
||||
value = value.repeat_interleave(h_q // h_kv, dim=0)
|
||||
attn_weight = query @ key.transpose(-2, -1) / math.sqrt(query.size(-1))
|
||||
if is_causal:
|
||||
s_q = query.shape[-2]
|
||||
s_k = key.shape[-2]
|
||||
attn_bias = torch.zeros(s_q, s_k, dtype=query.dtype)
|
||||
temp_mask = torch.ones(s_q, s_k, dtype=torch.bool).tril(diagonal=s_k - s_q)
|
||||
attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))
|
||||
attn_bias.to(query.dtype)
|
||||
attn_weight += attn_bias
|
||||
lse = attn_weight.logsumexp(dim=-1)
|
||||
attn_weight = torch.softmax(attn_weight, dim=-1, dtype=torch.float32)
|
||||
return attn_weight @ value, lse
|
||||
|
||||
def ref_mla():
|
||||
q_ = (q.to(torch.float) * descale_q).to(init_dtype) if use_fp8 else q
|
||||
blocked_k_ = (
|
||||
(blocked_k.to(torch.float) * descale_k).to(init_dtype)
|
||||
if use_fp8
|
||||
else blocked_k
|
||||
)
|
||||
blocked_v_ = (
|
||||
(blocked_v.to(torch.float) * descale_k).to(init_dtype)
|
||||
if use_fp8
|
||||
else blocked_v
|
||||
)
|
||||
out = torch.empty(b, s_q, h_q, dv, dtype=torch.float32)
|
||||
lse = torch.empty(b, h_q, s_q, dtype=torch.float32)
|
||||
for i in range(b):
|
||||
begin = i * max_seqlen_pad
|
||||
end = begin + cache_seqlens[i]
|
||||
out_i, lse_i = scaled_dot_product_attention(
|
||||
q_[i].transpose(0, 1),
|
||||
blocked_k_.view(-1, h_kv, d)[begin:end].transpose(0, 1),
|
||||
blocked_v_.view(-1, h_kv, dv)[begin:end].transpose(0, 1),
|
||||
is_causal=causal,
|
||||
)
|
||||
out[i] = out_i.transpose(0, 1)
|
||||
lse[i] = lse_i
|
||||
return out, lse
|
||||
|
||||
out_cutlass, lse_cutlass = cutlass_mla()
|
||||
out_torch, lse_torch = ref_mla()
|
||||
# Extract the single token (s_q=1) slice to match cutlass output shape
|
||||
out_torch_slice = out_torch[:, 0, :, :] # [b, h_q, dv]
|
||||
lse_torch_slice = lse_torch[:, 0, :] # [b, h_q]
|
||||
cal_diff(out_cutlass, out_torch_slice, "out", use_fp8)
|
||||
# lse has larger numerical error, so use a larger threshold
|
||||
cal_diff(lse_cutlass, lse_torch_slice, "lse", use_fp8, diff_threshold=1e-3)
|
||||
|
||||
t = triton.testing.do_bench(cutlass_mla)
|
||||
FLOPS = s_q * total_seqlens * h_q * (d + dv) * 2
|
||||
bytes = (total_seqlens * h_kv * d + b * s_q * h_q * d) * (
|
||||
torch.finfo(torch_dtype).bits // 8
|
||||
) + (b * s_q * h_q * dv) * (torch.finfo(init_dtype).bits // 8)
|
||||
print(
|
||||
f"{t:.3f} ms, {FLOPS / 10**9 / t:.0f} TFLOPS,", f"{bytes / 10**6 / t:.0f} GB/s"
|
||||
)
|
||||
294
tests/kernels/attention/test_deepgemm_attention.py
Normal file
294
tests/kernels/attention/test_deepgemm_attention.py
Normal file
@@ -0,0 +1,294 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import random
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.deep_gemm import (
|
||||
_ceil_to_ue8m0,
|
||||
calc_diff,
|
||||
fp8_mqa_logits,
|
||||
fp8_paged_mqa_logits,
|
||||
get_num_sms,
|
||||
get_paged_mqa_logits_metadata,
|
||||
)
|
||||
from vllm.utils.import_utils import has_deep_gemm
|
||||
from vllm.utils.math_utils import cdiv
|
||||
|
||||
|
||||
def kv_cache_cast_to_fp8(x: torch.Tensor) -> torch.Tensor:
|
||||
# x: (num_blocks, block_size, 1, head_dim)
|
||||
num_blocks, block_size, num_heads, head_dim = x.shape
|
||||
assert num_heads == 1
|
||||
x_amax = x.abs().float().amax(dim=3, keepdim=True).clamp(1e-4)
|
||||
sf = x_amax / 448.0
|
||||
x_scaled = (x * (1.0 / sf)).to(torch.float8_e4m3fn)
|
||||
x_fp8 = torch.empty(
|
||||
(num_blocks, block_size * (head_dim + 4)),
|
||||
device=x.device,
|
||||
dtype=torch.uint8,
|
||||
)
|
||||
x_fp8[:, : block_size * head_dim] = x_scaled.view(
|
||||
num_blocks, block_size * head_dim
|
||||
).view(dtype=torch.uint8)
|
||||
x_fp8[:, block_size * head_dim :] = sf.view(num_blocks, block_size).view(
|
||||
dtype=torch.uint8
|
||||
)
|
||||
return x_fp8.view(num_blocks, block_size, num_heads, head_dim + 4)
|
||||
|
||||
|
||||
def per_custom_dims_cast_to_fp8(
|
||||
x: torch.Tensor, dims: tuple, use_ue8m0: bool
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
excluded_dims = tuple([i for i in range(x.dim()) if i not in set(dims)])
|
||||
x_amax = x.abs().float().amax(dim=excluded_dims, keepdim=True).clamp(1e-4)
|
||||
sf = x_amax / 448.0
|
||||
sf = _ceil_to_ue8m0(sf) if use_ue8m0 else sf
|
||||
x_scaled = (x * (1.0 / sf)).to(torch.float8_e4m3fn)
|
||||
return x_scaled, sf.squeeze()
|
||||
|
||||
|
||||
def _generate_cp_test_data(seq_len: int, seq_len_kv: int):
|
||||
assert seq_len_kv % seq_len == 0 and seq_len % 2 == 0
|
||||
chunk_size = seq_len // 2
|
||||
cp_size = seq_len_kv // seq_len
|
||||
cp_id = cp_size // 3
|
||||
ks = torch.zeros(seq_len, dtype=torch.int, device="cuda")
|
||||
ke = torch.zeros(seq_len, dtype=torch.int, device="cuda")
|
||||
for i in range(chunk_size):
|
||||
ke[i] = cp_id * chunk_size + i
|
||||
ke[i + chunk_size] = (cp_size * 2 - 1 - cp_id) * chunk_size + i
|
||||
return ks, ke
|
||||
|
||||
|
||||
def _ref_fp8_mqa_logits(
|
||||
q: torch.Tensor,
|
||||
kv: torch.Tensor,
|
||||
weights: torch.Tensor,
|
||||
cu_seqlen_ks: torch.Tensor,
|
||||
cu_seqlen_ke: torch.Tensor,
|
||||
):
|
||||
seq_len_kv = kv.shape[0]
|
||||
|
||||
k = kv
|
||||
q = q.float()
|
||||
k = k.float()
|
||||
|
||||
mask_lo = (
|
||||
torch.arange(0, seq_len_kv, device="cuda")[None, :] >= cu_seqlen_ks[:, None]
|
||||
)
|
||||
mask_hi = (
|
||||
torch.arange(0, seq_len_kv, device="cuda")[None, :] < cu_seqlen_ke[:, None]
|
||||
)
|
||||
mask = mask_lo & mask_hi
|
||||
score = torch.einsum("mhd,nd->hmn", q, k)
|
||||
logits = (score.relu() * weights.unsqueeze(-1).transpose(0, 1)).sum(dim=0)
|
||||
logits = logits.masked_fill(~mask, float("-inf"))
|
||||
|
||||
return logits
|
||||
|
||||
|
||||
@pytest.mark.skipif(not current_platform.is_cuda(), reason="CUDA only")
|
||||
@pytest.mark.skipif(not has_deep_gemm(), reason="DeepGEMM not available")
|
||||
@pytest.mark.skipif(
|
||||
not current_platform.has_device_capability(90), reason="SM90 and SM100 only"
|
||||
)
|
||||
def test_deepgemm_fp8_mqa_logits():
|
||||
torch.manual_seed(0)
|
||||
random.seed(0)
|
||||
num_heads, head_dim = 32, 128
|
||||
for seq_len in (512,):
|
||||
for seq_len_kv in (1024,):
|
||||
for disable_cp in (False, True):
|
||||
q = torch.randn(
|
||||
seq_len,
|
||||
num_heads,
|
||||
head_dim,
|
||||
device="cuda",
|
||||
dtype=torch.bfloat16,
|
||||
)
|
||||
kv = torch.randn(
|
||||
seq_len_kv, head_dim, device="cuda", dtype=torch.bfloat16
|
||||
)
|
||||
weights = torch.randn(
|
||||
seq_len, num_heads, device="cuda", dtype=torch.float32
|
||||
)
|
||||
|
||||
if disable_cp:
|
||||
ks = torch.zeros(seq_len, dtype=torch.int, device="cuda")
|
||||
ke = torch.arange(seq_len, dtype=torch.int, device="cuda") + (
|
||||
seq_len_kv - seq_len
|
||||
)
|
||||
else:
|
||||
ks, ke = _generate_cp_test_data(seq_len, seq_len_kv)
|
||||
|
||||
q_fp8 = q.to(torch.float8_e4m3fn)
|
||||
kv_fp8 = per_custom_dims_cast_to_fp8(kv, (0,), False)
|
||||
logits = fp8_mqa_logits(q_fp8, kv_fp8, weights, ks, ke)
|
||||
|
||||
ref_logits = _ref_fp8_mqa_logits(
|
||||
q=q,
|
||||
kv=kv,
|
||||
weights=weights,
|
||||
cu_seqlen_ks=ks,
|
||||
cu_seqlen_ke=ke,
|
||||
)
|
||||
|
||||
ref_neginf_mask = ref_logits == float("-inf")
|
||||
neginf_mask = logits == float("-inf")
|
||||
assert torch.equal(neginf_mask, ref_neginf_mask)
|
||||
|
||||
ref_logits = ref_logits.masked_fill(ref_neginf_mask, 0)
|
||||
logits = logits.masked_fill(neginf_mask, 0)
|
||||
diff = calc_diff(logits, ref_logits)
|
||||
assert diff < 1e-3, f"{diff=}"
|
||||
|
||||
|
||||
def _ref_fp8_paged_mqa_logits(
|
||||
q: torch.Tensor,
|
||||
kv_cache: torch.Tensor,
|
||||
weights: torch.Tensor,
|
||||
context_lens: torch.Tensor,
|
||||
block_tables: torch.Tensor,
|
||||
max_model_len: int,
|
||||
):
|
||||
batch_size, next_n, _, _ = q.size()
|
||||
_, block_size, _, _ = kv_cache.size()
|
||||
logits = torch.full(
|
||||
[batch_size * next_n, max_model_len],
|
||||
float("-inf"),
|
||||
device=q.device,
|
||||
dtype=torch.float32,
|
||||
)
|
||||
context_lens_list = context_lens.tolist()
|
||||
for i in range(batch_size):
|
||||
context_len = context_lens_list[i]
|
||||
q_offsets = torch.arange(context_len - next_n, context_len, device="cuda")
|
||||
weight_slice = (
|
||||
weights[i * next_n : (i + 1) * next_n, :].transpose(0, 1).contiguous()
|
||||
)
|
||||
for block_rk in range(cdiv(context_len, block_size)):
|
||||
block_idx = block_tables[i][block_rk]
|
||||
qx, kx = q[i], kv_cache[block_idx]
|
||||
k_offsets = torch.arange(
|
||||
block_rk * block_size,
|
||||
(block_rk + 1) * block_size,
|
||||
device="cuda",
|
||||
)
|
||||
mask = (k_offsets[None, :] < context_len) & (
|
||||
k_offsets[None, :] <= q_offsets[:, None]
|
||||
)
|
||||
s = torch.where(
|
||||
mask[None, :, :],
|
||||
(qx.transpose(0, 1) @ kx.transpose(0, 1).transpose(1, 2)).to(
|
||||
logits.dtype
|
||||
),
|
||||
float("-inf"),
|
||||
)
|
||||
s = torch.relu(s) * weight_slice[..., None]
|
||||
s = s.sum(dim=0)
|
||||
logits[
|
||||
i * next_n : (i + 1) * next_n,
|
||||
block_rk * block_size : (block_rk + 1) * block_size,
|
||||
] = torch.where(k_offsets[None, :] <= q_offsets[:, None], s, float("-inf"))
|
||||
return logits
|
||||
|
||||
|
||||
@pytest.mark.skipif(not current_platform.is_cuda(), reason="CUDA only")
|
||||
@pytest.mark.skipif(not has_deep_gemm(), reason="DeepGEMM not available")
|
||||
@pytest.mark.skipif(
|
||||
not current_platform.has_device_capability(90), reason="SM90 and SM100 only"
|
||||
)
|
||||
def test_deepgemm_fp8_paged_mqa_logits():
|
||||
torch.manual_seed(0)
|
||||
random.seed(0)
|
||||
|
||||
max_model_len = 4096
|
||||
for batch_size, next_n in [(4, 1), (2, 2)]:
|
||||
for heads, index_dim in [(32, 128)]:
|
||||
for avg_kv in (2048,):
|
||||
num_blocks, blocksize = max_model_len * 2, 64
|
||||
|
||||
q = torch.randn(
|
||||
(batch_size, next_n, heads, index_dim),
|
||||
device="cuda",
|
||||
dtype=torch.bfloat16,
|
||||
)
|
||||
kv_cache = torch.randn(
|
||||
(num_blocks, blocksize, 1, index_dim),
|
||||
device="cuda",
|
||||
dtype=torch.bfloat16,
|
||||
)
|
||||
weights = torch.randn(
|
||||
(batch_size * next_n, heads),
|
||||
device="cuda",
|
||||
dtype=torch.float32,
|
||||
)
|
||||
|
||||
context_lens = (
|
||||
torch.randint(int(0.8 * avg_kv), int(1.2 * avg_kv), (batch_size,))
|
||||
.cuda()
|
||||
.to(torch.int32)
|
||||
)
|
||||
max_block_len = (
|
||||
(context_lens.max().item() + blocksize - 1) // blocksize * blocksize
|
||||
)
|
||||
block_tables = torch.zeros(
|
||||
(batch_size, max_block_len),
|
||||
device="cuda",
|
||||
dtype=torch.int32,
|
||||
)
|
||||
|
||||
counter = 0
|
||||
block_idx_pool = list(range(num_blocks))
|
||||
random.shuffle(block_idx_pool)
|
||||
for i in range(batch_size):
|
||||
ctx_len = int(context_lens[i].item())
|
||||
for j in range((ctx_len + blocksize - 1) // blocksize):
|
||||
block_tables[i][j] = block_idx_pool[counter]
|
||||
counter += 1
|
||||
|
||||
q_fp8 = q.to(torch.float8_e4m3fn)
|
||||
kv_cache_fp8 = kv_cache_cast_to_fp8(kv_cache)
|
||||
|
||||
schedule_metadata = get_paged_mqa_logits_metadata(
|
||||
context_lens, blocksize, get_num_sms()
|
||||
)
|
||||
logits = fp8_paged_mqa_logits(
|
||||
q_fp8,
|
||||
kv_cache_fp8,
|
||||
weights,
|
||||
context_lens,
|
||||
block_tables,
|
||||
schedule_metadata,
|
||||
max_model_len,
|
||||
)
|
||||
|
||||
ref_logits = _ref_fp8_paged_mqa_logits(
|
||||
q,
|
||||
kv_cache,
|
||||
weights,
|
||||
context_lens,
|
||||
block_tables,
|
||||
max_model_len,
|
||||
)
|
||||
|
||||
positions = (
|
||||
torch.arange(max_model_len, device="cuda")
|
||||
.unsqueeze(0)
|
||||
.expand(batch_size * next_n, -1)
|
||||
)
|
||||
row_indices = torch.arange(batch_size * next_n, device="cuda") // next_n
|
||||
next_n_offset = (
|
||||
torch.arange(batch_size * next_n, device="cuda") % next_n
|
||||
)
|
||||
mask = positions <= (
|
||||
context_lens[row_indices] - next_n + next_n_offset
|
||||
).unsqueeze(1)
|
||||
|
||||
logits = logits.masked_fill(~mask, 0)
|
||||
ref_logits = ref_logits.masked_fill(~mask, 0)
|
||||
diff = calc_diff(logits, ref_logits)
|
||||
assert diff < 1e-3, f"{diff=}"
|
||||
216
tests/kernels/attention/test_flash_attn.py
Normal file
216
tests/kernels/attention/test_flash_attn.py
Normal file
@@ -0,0 +1,216 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
try:
|
||||
from vllm.vllm_flash_attn import (
|
||||
fa_version_unsupported_reason,
|
||||
flash_attn_varlen_func,
|
||||
is_fa_version_supported,
|
||||
)
|
||||
except ImportError:
|
||||
if current_platform.is_rocm():
|
||||
pytest.skip(
|
||||
"vllm_flash_attn is not supported for vLLM on ROCm.",
|
||||
allow_module_level=True,
|
||||
)
|
||||
|
||||
|
||||
NUM_HEADS = [(4, 4), (8, 2)]
|
||||
HEAD_SIZES = [40, 72, 80, 128, 256]
|
||||
BLOCK_SIZES = [16]
|
||||
DTYPES = [torch.bfloat16]
|
||||
QDTYPES = [None, torch.float8_e4m3fn]
|
||||
# one value large enough to test overflow in index calculation.
|
||||
# one value small enough to test the schema op check
|
||||
NUM_BLOCKS = [32768, 2048]
|
||||
SOFT_CAPS = [None]
|
||||
SLIDING_WINDOWS = [None, 256]
|
||||
|
||||
|
||||
def ref_paged_attn(
|
||||
query: torch.Tensor,
|
||||
key_cache: torch.Tensor,
|
||||
value_cache: torch.Tensor,
|
||||
query_lens: list[int],
|
||||
kv_lens: list[int],
|
||||
block_tables: torch.Tensor,
|
||||
scale: float,
|
||||
sliding_window: int | None = None,
|
||||
soft_cap: float | None = None,
|
||||
) -> torch.Tensor:
|
||||
num_seqs = len(query_lens)
|
||||
block_tables = block_tables.cpu().numpy()
|
||||
_, block_size, num_kv_heads, head_size = key_cache.shape
|
||||
|
||||
outputs: list[torch.Tensor] = []
|
||||
start_idx = 0
|
||||
for i in range(num_seqs):
|
||||
query_len = query_lens[i]
|
||||
kv_len = kv_lens[i]
|
||||
q = query[start_idx : start_idx + query_len]
|
||||
q *= scale
|
||||
|
||||
num_kv_blocks = (kv_len + block_size - 1) // block_size
|
||||
block_indices = block_tables[i, :num_kv_blocks]
|
||||
|
||||
k = key_cache[block_indices].view(-1, num_kv_heads, head_size)
|
||||
k = k[:kv_len]
|
||||
v = value_cache[block_indices].view(-1, num_kv_heads, head_size)
|
||||
v = v[:kv_len]
|
||||
|
||||
if q.shape[1] != k.shape[1]:
|
||||
k = torch.repeat_interleave(k, q.shape[1] // k.shape[1], dim=1)
|
||||
v = torch.repeat_interleave(v, q.shape[1] // v.shape[1], dim=1)
|
||||
attn = torch.einsum("qhd,khd->hqk", q, k).float()
|
||||
empty_mask = torch.ones(query_len, kv_len)
|
||||
mask = torch.triu(empty_mask, diagonal=kv_len - query_len + 1).bool()
|
||||
if sliding_window is not None:
|
||||
sliding_window_mask = (
|
||||
torch.triu(
|
||||
empty_mask, diagonal=kv_len - (query_len + sliding_window) + 1
|
||||
)
|
||||
.bool()
|
||||
.logical_not()
|
||||
)
|
||||
mask |= sliding_window_mask
|
||||
if soft_cap is not None:
|
||||
attn = soft_cap * torch.tanh(attn / soft_cap)
|
||||
attn.masked_fill_(mask, float("-inf"))
|
||||
attn = torch.softmax(attn, dim=-1).to(v.dtype)
|
||||
out = torch.einsum("hqk,khd->qhd", attn, v)
|
||||
|
||||
outputs.append(out)
|
||||
start_idx += query_len
|
||||
|
||||
return torch.cat(outputs, dim=0)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("use_out", [True, False])
|
||||
@pytest.mark.parametrize(
|
||||
"seq_lens", [[(1, 1328), (5, 18), (129, 463)], [(1, 523), (1, 37), (1, 2011)]]
|
||||
)
|
||||
@pytest.mark.parametrize("num_heads", NUM_HEADS)
|
||||
@pytest.mark.parametrize("head_size", HEAD_SIZES)
|
||||
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
|
||||
@pytest.mark.parametrize("sliding_window", SLIDING_WINDOWS)
|
||||
@pytest.mark.parametrize("dtype", DTYPES)
|
||||
@pytest.mark.parametrize("soft_cap", SOFT_CAPS)
|
||||
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
|
||||
@pytest.mark.parametrize("fa_version", [2, 3])
|
||||
@pytest.mark.parametrize("q_dtype", QDTYPES)
|
||||
@torch.inference_mode()
|
||||
def test_varlen_with_paged_kv(
|
||||
use_out: bool,
|
||||
seq_lens: list[tuple[int, int]],
|
||||
num_heads: tuple[int, int],
|
||||
head_size: int,
|
||||
sliding_window: int | None,
|
||||
dtype: torch.dtype,
|
||||
block_size: int,
|
||||
soft_cap: float | None,
|
||||
num_blocks: int,
|
||||
fa_version: int,
|
||||
q_dtype: torch.dtype | None,
|
||||
) -> None:
|
||||
torch.set_default_device("cuda")
|
||||
if not is_fa_version_supported(fa_version):
|
||||
pytest.skip(
|
||||
f"Flash attention version {fa_version} not supported due "
|
||||
f'to: "{fa_version_unsupported_reason(fa_version)}"'
|
||||
)
|
||||
if q_dtype is not None and (dtype != torch.bfloat16 or fa_version == 2):
|
||||
pytest.skip(
|
||||
"Flash attention with quantized inputs is only "
|
||||
"supported on version 3 with bfloat16 base type"
|
||||
)
|
||||
current_platform.seed_everything(0)
|
||||
num_seqs = len(seq_lens)
|
||||
query_lens = [x[0] for x in seq_lens]
|
||||
kv_lens = [x[1] for x in seq_lens]
|
||||
num_query_heads = num_heads[0]
|
||||
num_kv_heads = num_heads[1]
|
||||
assert num_query_heads % num_kv_heads == 0
|
||||
max_query_len = max(query_lens)
|
||||
max_kv_len = max(kv_lens)
|
||||
window_size = (sliding_window - 1, 0) if sliding_window is not None else (-1, -1)
|
||||
scale = head_size**-0.5
|
||||
|
||||
query = torch.randn(sum(query_lens), num_query_heads, head_size, dtype=dtype)
|
||||
key_cache = torch.randn(
|
||||
num_blocks, block_size, num_kv_heads, head_size, dtype=dtype
|
||||
)
|
||||
value_cache = torch.randn_like(key_cache)
|
||||
cu_query_lens = torch.tensor([0] + query_lens, dtype=torch.int32).cumsum(
|
||||
dim=0, dtype=torch.int32
|
||||
)
|
||||
kv_lens = torch.tensor(kv_lens, dtype=torch.int32)
|
||||
|
||||
max_num_blocks_per_seq = (max_kv_len + block_size - 1) // block_size
|
||||
block_tables = torch.randint(
|
||||
0, num_blocks, (num_seqs, max_num_blocks_per_seq), dtype=torch.int32
|
||||
)
|
||||
|
||||
out = torch.empty_like(query) if use_out else None
|
||||
|
||||
maybe_quantized_query = query
|
||||
maybe_quantized_key_cache = key_cache
|
||||
maybe_quantized_value_cache = value_cache
|
||||
q_descale = None
|
||||
k_descale = None
|
||||
v_descale = None
|
||||
if q_dtype is not None:
|
||||
# QKV are drawn from N(0, 1): no need for a fp8 scaling factor
|
||||
maybe_quantized_query = query.to(q_dtype)
|
||||
maybe_quantized_key_cache = key_cache.to(q_dtype)
|
||||
maybe_quantized_value_cache = value_cache.to(q_dtype)
|
||||
|
||||
scale_shape = (num_seqs, num_kv_heads)
|
||||
q_descale = torch.ones(scale_shape, dtype=torch.float32)
|
||||
k_descale = torch.ones(scale_shape, dtype=torch.float32)
|
||||
v_descale = torch.ones(scale_shape, dtype=torch.float32)
|
||||
|
||||
output = flash_attn_varlen_func(
|
||||
q=maybe_quantized_query,
|
||||
k=maybe_quantized_key_cache,
|
||||
v=maybe_quantized_value_cache,
|
||||
out=out,
|
||||
cu_seqlens_q=cu_query_lens,
|
||||
seqused_k=kv_lens,
|
||||
max_seqlen_q=max_query_len,
|
||||
max_seqlen_k=max_kv_len,
|
||||
softmax_scale=scale,
|
||||
causal=True,
|
||||
window_size=window_size,
|
||||
block_table=block_tables,
|
||||
softcap=soft_cap if soft_cap is not None else 0,
|
||||
fa_version=fa_version,
|
||||
q_descale=q_descale,
|
||||
k_descale=k_descale,
|
||||
v_descale=v_descale,
|
||||
)
|
||||
output = output if not use_out else out
|
||||
|
||||
ref_output = ref_paged_attn(
|
||||
query=query,
|
||||
key_cache=key_cache,
|
||||
value_cache=value_cache,
|
||||
query_lens=query_lens,
|
||||
kv_lens=kv_lens,
|
||||
block_tables=block_tables,
|
||||
scale=scale,
|
||||
sliding_window=sliding_window,
|
||||
soft_cap=soft_cap,
|
||||
)
|
||||
atol, rtol = 1.5e-2, 1e-2
|
||||
if q_dtype is not None:
|
||||
atol, rtol = 1.5e-1, 1.5e-1
|
||||
(
|
||||
torch.testing.assert_close(output, ref_output, atol=atol, rtol=rtol),
|
||||
f"{torch.max(torch.abs(output - ref_output))}",
|
||||
)
|
||||
497
tests/kernels/attention/test_flashinfer.py
Normal file
497
tests/kernels/attention/test_flashinfer.py
Normal file
@@ -0,0 +1,497 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
|
||||
import pytest
|
||||
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
try:
|
||||
import flashinfer
|
||||
except ImportError:
|
||||
if current_platform.is_rocm():
|
||||
pytest.skip(
|
||||
"flashinfer is not supported for vLLM on ROCm.", allow_module_level=True
|
||||
)
|
||||
|
||||
import torch
|
||||
|
||||
NUM_HEADS = [(32, 8), (6, 1)]
|
||||
HEAD_SIZES = [128, 256]
|
||||
BLOCK_SIZES = [16, 32]
|
||||
DTYPES = [torch.bfloat16]
|
||||
NUM_BLOCKS = 32768 # Large enough to test overflow in index calculation.
|
||||
SOFT_CAPS = [None, 30.0]
|
||||
SLIDING_WINDOWS = [None, 64]
|
||||
|
||||
|
||||
def ref_paged_attn(
|
||||
query: torch.Tensor,
|
||||
key_cache: torch.Tensor,
|
||||
value_cache: torch.Tensor,
|
||||
query_lens: list[int],
|
||||
kv_lens: list[int],
|
||||
block_tables: torch.Tensor,
|
||||
scale: float,
|
||||
sliding_window: int | None = None,
|
||||
soft_cap: float | None = None,
|
||||
) -> torch.Tensor:
|
||||
num_seqs = len(query_lens)
|
||||
block_tables = block_tables.cpu().numpy()
|
||||
_, block_size, num_kv_heads, head_size = key_cache.shape
|
||||
|
||||
outputs: list[torch.Tensor] = []
|
||||
start_idx = 0
|
||||
for i in range(num_seqs):
|
||||
query_len = query_lens[i]
|
||||
kv_len = kv_lens[i]
|
||||
q = query[start_idx : start_idx + query_len]
|
||||
q *= scale
|
||||
|
||||
num_kv_blocks = (kv_len + block_size - 1) // block_size
|
||||
block_indices = block_tables[i, :num_kv_blocks]
|
||||
|
||||
k = key_cache[block_indices].view(-1, num_kv_heads, head_size)
|
||||
k = k[:kv_len]
|
||||
v = value_cache[block_indices].view(-1, num_kv_heads, head_size)
|
||||
v = v[:kv_len]
|
||||
|
||||
if q.shape[1] != k.shape[1]:
|
||||
k = torch.repeat_interleave(k, q.shape[1] // k.shape[1], dim=1)
|
||||
v = torch.repeat_interleave(v, q.shape[1] // v.shape[1], dim=1)
|
||||
attn = torch.einsum("qhd,khd->hqk", q, k).float()
|
||||
empty_mask = torch.ones(query_len, kv_len)
|
||||
mask = torch.triu(empty_mask, diagonal=kv_len - query_len + 1).bool()
|
||||
if sliding_window is not None:
|
||||
sliding_window_mask = (
|
||||
torch.triu(
|
||||
empty_mask, diagonal=kv_len - (query_len + sliding_window) + 1
|
||||
)
|
||||
.bool()
|
||||
.logical_not()
|
||||
)
|
||||
mask |= sliding_window_mask
|
||||
if soft_cap is not None:
|
||||
attn = soft_cap * torch.tanh(attn / soft_cap)
|
||||
attn.masked_fill_(mask, float("-inf"))
|
||||
attn = torch.softmax(attn, dim=-1).to(v.dtype)
|
||||
out = torch.einsum("hqk,khd->qhd", attn, v)
|
||||
|
||||
outputs.append(out)
|
||||
start_idx += query_len
|
||||
|
||||
return torch.cat(outputs, dim=0)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("kv_lens", [[1328, 18, 463], [1, 54, 293, 70]])
|
||||
@pytest.mark.parametrize("num_heads", NUM_HEADS)
|
||||
@pytest.mark.parametrize("head_size", HEAD_SIZES)
|
||||
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
|
||||
@pytest.mark.parametrize("dtype", DTYPES)
|
||||
@pytest.mark.parametrize("soft_cap", SOFT_CAPS)
|
||||
@pytest.mark.parametrize("sliding_window", SLIDING_WINDOWS)
|
||||
@torch.inference_mode
|
||||
def test_flashinfer_decode_with_paged_kv(
|
||||
kv_lens: list[int],
|
||||
num_heads: tuple[int, int],
|
||||
head_size: int,
|
||||
dtype: torch.dtype,
|
||||
block_size: int,
|
||||
soft_cap: float | None,
|
||||
sliding_window: int | None,
|
||||
) -> None:
|
||||
torch.set_default_device("cuda")
|
||||
current_platform.seed_everything(0)
|
||||
num_seqs = len(kv_lens)
|
||||
num_query_heads = num_heads[0]
|
||||
num_kv_heads = num_heads[1]
|
||||
assert num_query_heads % num_kv_heads == 0
|
||||
max_kv_len = max(kv_lens)
|
||||
scale = head_size**-0.5
|
||||
|
||||
query = torch.randn(num_seqs, num_query_heads, head_size, dtype=dtype)
|
||||
|
||||
key_value_cache = torch.randn(
|
||||
NUM_BLOCKS, 2, block_size, num_kv_heads, head_size, dtype=dtype
|
||||
)
|
||||
key_cache = key_value_cache[:, 0, :, :, :].squeeze(1)
|
||||
value_cache = key_value_cache[:, 1, :, :, :].squeeze(1)
|
||||
|
||||
max_num_blocks_per_seq = (max_kv_len + block_size - 1) // block_size
|
||||
block_tables = torch.randint(
|
||||
0, NUM_BLOCKS, (num_seqs, max_num_blocks_per_seq), dtype=torch.int32
|
||||
)
|
||||
|
||||
kv_indptr = [0]
|
||||
kv_indices = []
|
||||
kv_last_page_lens = []
|
||||
for i in range(num_seqs):
|
||||
seq_len = kv_lens[i]
|
||||
assert seq_len > 0
|
||||
num_blocks = (seq_len + block_size - 1) // block_size
|
||||
kv_indices.extend(block_tables[i, :num_blocks])
|
||||
kv_indptr.append(kv_indptr[-1] + num_blocks)
|
||||
kv_last_page_len = seq_len % block_size
|
||||
if kv_last_page_len == 0:
|
||||
kv_last_page_len = block_size
|
||||
kv_last_page_lens.append(kv_last_page_len)
|
||||
|
||||
kv_indptr = torch.tensor(kv_indptr, dtype=torch.int32)
|
||||
kv_indices = torch.tensor(kv_indices, dtype=torch.int32)
|
||||
kv_last_page_lens = torch.tensor(kv_last_page_lens, dtype=torch.int32)
|
||||
|
||||
workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8)
|
||||
wrapper = flashinfer.BatchDecodeWithPagedKVCacheWrapper(
|
||||
workspace_buffer, "NHD", use_tensor_cores=True
|
||||
)
|
||||
wrapper.plan(
|
||||
kv_indptr,
|
||||
kv_indices,
|
||||
kv_last_page_lens,
|
||||
num_query_heads,
|
||||
num_kv_heads,
|
||||
head_size,
|
||||
block_size,
|
||||
"NONE",
|
||||
window_left=sliding_window - 1 if sliding_window is not None else -1,
|
||||
q_data_type=dtype,
|
||||
kv_data_type=dtype,
|
||||
logits_soft_cap=soft_cap,
|
||||
)
|
||||
|
||||
output = wrapper.run(query, key_value_cache)
|
||||
|
||||
ref_output = ref_paged_attn(
|
||||
query=query,
|
||||
key_cache=key_cache,
|
||||
value_cache=value_cache,
|
||||
query_lens=[1] * num_seqs,
|
||||
kv_lens=kv_lens,
|
||||
block_tables=block_tables,
|
||||
scale=scale,
|
||||
soft_cap=soft_cap,
|
||||
sliding_window=sliding_window,
|
||||
)
|
||||
(
|
||||
torch.testing.assert_close(output, ref_output, atol=1e-2, rtol=1e-2),
|
||||
f"{torch.max(torch.abs(output - ref_output))}",
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("seq_lens", [[(1, 1328), (5, 18), (129, 463)]])
|
||||
@pytest.mark.parametrize("num_heads", NUM_HEADS)
|
||||
@pytest.mark.parametrize("head_size", HEAD_SIZES)
|
||||
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
|
||||
@pytest.mark.parametrize("dtype", DTYPES)
|
||||
@pytest.mark.parametrize("soft_cap", SOFT_CAPS)
|
||||
@pytest.mark.parametrize("sliding_window", SLIDING_WINDOWS)
|
||||
@torch.inference_mode
|
||||
def test_flashinfer_prefill_with_paged_kv(
|
||||
seq_lens: list[tuple[int, int]],
|
||||
num_heads: tuple[int, int],
|
||||
head_size: int,
|
||||
dtype: torch.dtype,
|
||||
block_size: int,
|
||||
soft_cap: float | None,
|
||||
sliding_window: int | None,
|
||||
) -> None:
|
||||
torch.set_default_device("cuda")
|
||||
current_platform.seed_everything(0)
|
||||
num_seqs = len(seq_lens)
|
||||
query_lens = [x[0] for x in seq_lens]
|
||||
kv_lens = [x[1] for x in seq_lens]
|
||||
num_query_heads = num_heads[0]
|
||||
num_kv_heads = num_heads[1]
|
||||
assert num_query_heads % num_kv_heads == 0
|
||||
max_kv_len = max(kv_lens)
|
||||
scale = head_size**-0.5
|
||||
|
||||
query = torch.randn(sum(query_lens), num_query_heads, head_size, dtype=dtype)
|
||||
key_value_cache = torch.randn(
|
||||
NUM_BLOCKS, 2, block_size, num_kv_heads, head_size, dtype=dtype
|
||||
)
|
||||
key_cache = key_value_cache[:, 0, :, :, :].squeeze(1)
|
||||
value_cache = key_value_cache[:, 1, :, :, :].squeeze(1)
|
||||
|
||||
# Normalize the scale of the key and value caches to mitigate
|
||||
# numerical instability.
|
||||
key_cache /= head_size**0.5
|
||||
value_cache /= head_size**0.5
|
||||
|
||||
max_num_blocks_per_seq = (max_kv_len + block_size - 1) // block_size
|
||||
block_tables = torch.randint(
|
||||
0, NUM_BLOCKS, (num_seqs, max_num_blocks_per_seq), dtype=torch.int32
|
||||
)
|
||||
|
||||
qo_indptr = [0]
|
||||
kv_indptr = [0]
|
||||
kv_indices = []
|
||||
kv_last_page_lens = []
|
||||
for i in range(num_seqs):
|
||||
seq_len = kv_lens[i]
|
||||
assert seq_len > 0
|
||||
num_blocks = (seq_len + block_size - 1) // block_size
|
||||
kv_indices.extend(block_tables[i, :num_blocks])
|
||||
kv_indptr.append(kv_indptr[-1] + num_blocks)
|
||||
kv_last_page_len = seq_len % block_size
|
||||
if kv_last_page_len == 0:
|
||||
kv_last_page_len = block_size
|
||||
kv_last_page_lens.append(kv_last_page_len)
|
||||
qo_indptr.append(qo_indptr[-1] + query_lens[i])
|
||||
|
||||
qo_indptr = torch.tensor(qo_indptr, dtype=torch.int32)
|
||||
kv_indptr = torch.tensor(kv_indptr, dtype=torch.int32)
|
||||
kv_indices = torch.tensor(kv_indices, dtype=torch.int32)
|
||||
kv_last_page_lens = torch.tensor(kv_last_page_lens, dtype=torch.int32)
|
||||
|
||||
workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8)
|
||||
wrapper = flashinfer.BatchPrefillWithPagedKVCacheWrapper(workspace_buffer, "NHD")
|
||||
wrapper.plan(
|
||||
qo_indptr,
|
||||
kv_indptr,
|
||||
kv_indices,
|
||||
kv_last_page_lens,
|
||||
num_query_heads,
|
||||
num_kv_heads,
|
||||
head_size,
|
||||
block_size,
|
||||
window_left=sliding_window - 1 if sliding_window is not None else -1,
|
||||
q_data_type=dtype,
|
||||
kv_data_type=dtype,
|
||||
logits_soft_cap=soft_cap,
|
||||
)
|
||||
|
||||
output = wrapper.run(
|
||||
query,
|
||||
key_value_cache,
|
||||
)
|
||||
|
||||
ref_output = ref_paged_attn(
|
||||
query=query,
|
||||
key_cache=key_cache,
|
||||
value_cache=value_cache,
|
||||
query_lens=query_lens,
|
||||
kv_lens=kv_lens,
|
||||
block_tables=block_tables,
|
||||
scale=scale,
|
||||
soft_cap=soft_cap,
|
||||
sliding_window=sliding_window,
|
||||
)
|
||||
(
|
||||
torch.testing.assert_close(output, ref_output, atol=5e-2, rtol=1e-2),
|
||||
f"{torch.max(torch.abs(output - ref_output))}",
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("seq_lens", [[(1, 132), (5, 18)]])
|
||||
@pytest.mark.parametrize("num_heads", NUM_HEADS)
|
||||
@pytest.mark.parametrize("head_size", HEAD_SIZES)
|
||||
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
|
||||
@pytest.mark.parametrize("dtype", DTYPES)
|
||||
@pytest.mark.parametrize("soft_cap", SOFT_CAPS)
|
||||
def test_flashinfer_prefill_with_paged_fp8_kv(
|
||||
seq_lens: list[tuple[int, int]],
|
||||
num_heads: tuple[int, int],
|
||||
head_size: int,
|
||||
dtype: torch.dtype,
|
||||
block_size: int,
|
||||
soft_cap: float | None,
|
||||
) -> None:
|
||||
pytest.skip("TODO: fix the accuracy issue")
|
||||
torch.set_default_device("cuda")
|
||||
current_platform.seed_everything(0)
|
||||
num_seqs = len(seq_lens)
|
||||
query_lens = [x[0] for x in seq_lens]
|
||||
kv_lens = [x[1] for x in seq_lens]
|
||||
num_query_heads = num_heads[0]
|
||||
num_kv_heads = num_heads[1]
|
||||
assert num_query_heads % num_kv_heads == 0
|
||||
max_kv_len = max(kv_lens)
|
||||
scale = head_size**-0.5
|
||||
|
||||
kv_cache_dtype = torch.float8_e4m3fn
|
||||
|
||||
query = torch.randn(sum(query_lens), num_query_heads, head_size, dtype=dtype)
|
||||
NUM_BLOCKS_FP8 = 2048
|
||||
key_value_cache = torch.randn(
|
||||
NUM_BLOCKS_FP8, 2, block_size, num_kv_heads, head_size, dtype=dtype
|
||||
)
|
||||
key_cache, value_cache = torch.chunk(key_value_cache, 2, dim=1)
|
||||
key_cache /= head_size**0.5
|
||||
value_cache /= head_size**0.5
|
||||
|
||||
k_scale = key_cache.amax().item() / 448.0
|
||||
v_scale = value_cache.amax().item() / 448.0
|
||||
|
||||
kv_cache_fp8 = torch.cat([key_cache / k_scale, value_cache / v_scale], dim=1).to(
|
||||
kv_cache_dtype
|
||||
)
|
||||
|
||||
assert kv_cache_fp8.shape == key_value_cache.shape
|
||||
max_num_blocks_per_seq = (max_kv_len + block_size - 1) // block_size
|
||||
block_tables = torch.randint(
|
||||
0, NUM_BLOCKS_FP8, (num_seqs, max_num_blocks_per_seq), dtype=torch.int32
|
||||
)
|
||||
|
||||
qo_indptr = [0]
|
||||
kv_indptr = [0]
|
||||
kv_indices = []
|
||||
kv_last_page_lens = []
|
||||
for i in range(num_seqs):
|
||||
seq_len = kv_lens[i]
|
||||
assert seq_len > 0
|
||||
num_blocks = (seq_len + block_size - 1) // block_size
|
||||
kv_indices.extend(block_tables[i, :num_blocks])
|
||||
kv_indptr.append(kv_indptr[-1] + num_blocks)
|
||||
kv_last_page_len = seq_len % block_size
|
||||
if kv_last_page_len == 0:
|
||||
kv_last_page_len = block_size
|
||||
kv_last_page_lens.append(kv_last_page_len)
|
||||
qo_indptr.append(qo_indptr[-1] + query_lens[i])
|
||||
|
||||
qo_indptr = torch.tensor(qo_indptr, dtype=torch.int32)
|
||||
kv_indptr = torch.tensor(kv_indptr, dtype=torch.int32)
|
||||
kv_indices = torch.tensor(kv_indices, dtype=torch.int32)
|
||||
kv_last_page_lens = torch.tensor(kv_last_page_lens, dtype=torch.int32)
|
||||
|
||||
workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8)
|
||||
wrapper = flashinfer.BatchPrefillWithPagedKVCacheWrapper(workspace_buffer, "NHD")
|
||||
wrapper.plan(
|
||||
qo_indptr,
|
||||
kv_indptr,
|
||||
kv_indices,
|
||||
kv_last_page_lens,
|
||||
num_query_heads,
|
||||
num_kv_heads,
|
||||
head_size,
|
||||
block_size,
|
||||
q_data_type=dtype,
|
||||
kv_data_type=kv_cache_dtype,
|
||||
logits_soft_cap=soft_cap,
|
||||
)
|
||||
|
||||
output = wrapper.run(query, kv_cache_fp8, k_scale=k_scale, v_scale=v_scale)
|
||||
|
||||
ref_output = ref_paged_attn(
|
||||
query=query,
|
||||
key_cache=key_cache.squeeze(1),
|
||||
value_cache=value_cache.squeeze(1),
|
||||
query_lens=query_lens,
|
||||
kv_lens=kv_lens,
|
||||
block_tables=block_tables,
|
||||
scale=scale,
|
||||
soft_cap=soft_cap,
|
||||
)
|
||||
del query
|
||||
del block_tables
|
||||
# verify prefill fp8
|
||||
(
|
||||
torch.testing.assert_close(output, ref_output, atol=5e-2, rtol=1e-2),
|
||||
f"{torch.max(torch.abs(output - ref_output))}",
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("kv_lens", [[1328, 18, 463], [1, 54, 293, 70]])
|
||||
@pytest.mark.parametrize("num_heads", NUM_HEADS)
|
||||
@pytest.mark.parametrize("head_size", HEAD_SIZES)
|
||||
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
|
||||
@pytest.mark.parametrize("dtype", DTYPES)
|
||||
@pytest.mark.parametrize("soft_cap", SOFT_CAPS)
|
||||
@pytest.mark.skip(reason="TODO: fix the accuracy issue")
|
||||
@torch.inference_mode
|
||||
def test_flashinfer_decode_with_paged_fp8_kv(
|
||||
kv_lens: list[int],
|
||||
num_heads: tuple[int, int],
|
||||
head_size: int,
|
||||
dtype: torch.dtype,
|
||||
block_size: int,
|
||||
soft_cap: float | None,
|
||||
) -> None:
|
||||
# test doesn't work for num_heads = (16,16)
|
||||
torch.set_default_device("cuda")
|
||||
current_platform.seed_everything(0)
|
||||
num_seqs = len(kv_lens)
|
||||
num_query_heads = num_heads[0]
|
||||
num_kv_heads = num_heads[1]
|
||||
assert num_query_heads % num_kv_heads == 0
|
||||
max_kv_len = max(kv_lens)
|
||||
scale = head_size**-0.5
|
||||
use_tensor_cores = True
|
||||
kv_cache_dtype = torch.float8_e4m3fn
|
||||
|
||||
query = torch.randn(num_seqs, num_query_heads, head_size, dtype=dtype)
|
||||
NUM_BLOCKS_FP8 = 2048
|
||||
key_value_cache = torch.randn(
|
||||
NUM_BLOCKS_FP8, 2, block_size, num_kv_heads, head_size, dtype=dtype
|
||||
)
|
||||
key_cache, value_cache = torch.chunk(key_value_cache, 2, dim=1)
|
||||
key_cache /= head_size**0.5
|
||||
value_cache /= head_size**0.5
|
||||
|
||||
k_scale = key_cache.amax().item() / 448.0
|
||||
v_scale = value_cache.amax().item() / 448.0
|
||||
|
||||
key_cache_fp8 = (key_cache / k_scale).to(kv_cache_dtype)
|
||||
value_cache_fp8 = (value_cache / v_scale).to(kv_cache_dtype)
|
||||
assert key_cache_fp8.shape[1] == 1 and value_cache_fp8.shape[1] == 1
|
||||
kv_cache_fp8 = torch.cat([key_cache_fp8, value_cache_fp8], dim=1)
|
||||
|
||||
max_num_blocks_per_seq = (max_kv_len + block_size - 1) // block_size
|
||||
block_tables = torch.randint(
|
||||
0, NUM_BLOCKS_FP8, (num_seqs, max_num_blocks_per_seq), dtype=torch.int32
|
||||
)
|
||||
|
||||
kv_indptr = [0]
|
||||
kv_indices = []
|
||||
kv_last_page_lens = []
|
||||
for i in range(num_seqs):
|
||||
seq_len = kv_lens[i]
|
||||
assert seq_len > 0
|
||||
num_blocks = (seq_len + block_size - 1) // block_size
|
||||
kv_indices.extend(block_tables[i, :num_blocks])
|
||||
kv_indptr.append(kv_indptr[-1] + num_blocks)
|
||||
kv_last_page_len = seq_len % block_size
|
||||
if kv_last_page_len == 0:
|
||||
kv_last_page_len = block_size
|
||||
kv_last_page_lens.append(kv_last_page_len)
|
||||
|
||||
kv_indptr = torch.tensor(kv_indptr, dtype=torch.int32)
|
||||
kv_indices = torch.tensor(kv_indices, dtype=torch.int32)
|
||||
kv_last_page_lens = torch.tensor(kv_last_page_lens, dtype=torch.int32)
|
||||
|
||||
workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8)
|
||||
wrapper = flashinfer.BatchDecodeWithPagedKVCacheWrapper(
|
||||
workspace_buffer, "NHD", use_tensor_cores=use_tensor_cores
|
||||
)
|
||||
wrapper.plan(
|
||||
kv_indptr,
|
||||
kv_indices,
|
||||
kv_last_page_lens,
|
||||
num_query_heads,
|
||||
num_kv_heads,
|
||||
head_size,
|
||||
block_size,
|
||||
"NONE",
|
||||
q_data_type=dtype,
|
||||
kv_data_type=kv_cache_dtype,
|
||||
logits_soft_cap=soft_cap,
|
||||
)
|
||||
output = wrapper.run(query, kv_cache_fp8, k_scale=k_scale, v_scale=v_scale)
|
||||
key_cache = key_value_cache[:, 0, :, :, :].squeeze(1)
|
||||
value_cache = key_value_cache[:, 1, :, :, :].squeeze(1)
|
||||
|
||||
ref_output = ref_paged_attn(
|
||||
query=query,
|
||||
key_cache=key_cache,
|
||||
value_cache=value_cache,
|
||||
query_lens=[1] * num_seqs,
|
||||
kv_lens=kv_lens,
|
||||
block_tables=block_tables,
|
||||
scale=scale,
|
||||
soft_cap=soft_cap,
|
||||
)
|
||||
# Temporary fix: Increasing the tolerance. Seems like a flashinfer issue
|
||||
(
|
||||
torch.testing.assert_close(output, ref_output, atol=2e-2, rtol=1e-2),
|
||||
f"{torch.max(torch.abs(output - ref_output))}",
|
||||
)
|
||||
119
tests/kernels/attention/test_flashinfer_mla_decode.py
Normal file
119
tests/kernels/attention/test_flashinfer_mla_decode.py
Normal file
@@ -0,0 +1,119 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import pytest
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import Tensor
|
||||
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
FLASHINFER_WORKSPACE_BUFFER_SIZE = 128 * 1024 * 1024
|
||||
|
||||
if not current_platform.has_device_capability(100):
|
||||
pytest.skip(
|
||||
reason="FlashInfer MLA Requires compute capability of 10 or above.",
|
||||
allow_module_level=True,
|
||||
)
|
||||
else:
|
||||
from flashinfer.decode import trtllm_batch_decode_with_kv_cache_mla
|
||||
|
||||
|
||||
def ref_mla(
|
||||
out: Tensor, # (bs, num_heads, v_head_dim)
|
||||
query: Tensor, # (bs, num_heads, head_dim)
|
||||
kv_cache: Tensor, # (num_blocks, block_size, head_dim)
|
||||
scale: float,
|
||||
block_tables: Tensor, # (bs, max_num_blocks)
|
||||
seq_lens: Tensor, # (bs,)
|
||||
):
|
||||
bs, num_heads, v_head_dim = out.shape
|
||||
head_dim = query.shape[2]
|
||||
|
||||
for i in range(bs):
|
||||
# gather and flatten KV-cache
|
||||
kv = kv_cache[block_tables[i]] # (max_num_blocks, block_size, head_dim)
|
||||
kv = kv.view(1, -1, head_dim)[:, : seq_lens[i]] # (1, seq_len, head_dim)
|
||||
v = kv[:, :, :v_head_dim]
|
||||
|
||||
q = query[i].view(num_heads, 1, head_dim)
|
||||
o = F.scaled_dot_product_attention(q, kv, v, scale=scale, enable_gqa=True)
|
||||
out[i] = o.view(num_heads, v_head_dim)
|
||||
|
||||
return out
|
||||
|
||||
|
||||
@pytest.mark.parametrize("dtype", [torch.bfloat16])
|
||||
@pytest.mark.parametrize("bs", [1, 2, 4, 16])
|
||||
@pytest.mark.parametrize("block_size", [32, 64])
|
||||
def test_flashinfer_mla_decode(dtype: torch.dtype, bs: int, block_size: int):
|
||||
torch.set_default_device("cuda")
|
||||
torch.manual_seed(42)
|
||||
|
||||
# Deepseek R1 config
|
||||
num_heads = 128
|
||||
kv_lora_rank = 512
|
||||
qk_nope_head_dim = 128
|
||||
qk_rope_head_dim = 64
|
||||
qk_head_dim = kv_lora_rank + qk_rope_head_dim
|
||||
scale = (qk_nope_head_dim + qk_rope_head_dim) ** -0.5
|
||||
|
||||
MAX_SEQ_LEN = 1024
|
||||
|
||||
seq_lens = [torch.randint(2, MAX_SEQ_LEN, (1,)).item() for _ in range(bs)]
|
||||
seq_lens[-1] = MAX_SEQ_LEN
|
||||
max_seq_len = max(seq_lens)
|
||||
seq_lens_tensor = torch.tensor(seq_lens, dtype=torch.int32)
|
||||
|
||||
# Generate block tables with random but unique block IDs
|
||||
# From https://github.com/flashinfer-ai/flashinfer/pull/1222
|
||||
blocks_per_seq = (seq_lens_tensor + block_size - 1) // block_size
|
||||
max_num_blocks_per_seq = max(blocks_per_seq.max().item(), 4)
|
||||
total_blocks_needed = sum(blocks_per_seq)
|
||||
# Get random unique IDs for all blocks
|
||||
all_block_ids = torch.randperm(total_blocks_needed)
|
||||
|
||||
block_id = 0
|
||||
block_tables = torch.zeros(
|
||||
(bs, max_num_blocks_per_seq),
|
||||
dtype=torch.int32,
|
||||
)
|
||||
|
||||
# Populate block tables and track block assignments
|
||||
block_id = 0
|
||||
for i in range(bs):
|
||||
num_blocks_needed = blocks_per_seq[i]
|
||||
block_tables[i, :num_blocks_needed] = all_block_ids[
|
||||
block_id : block_id + num_blocks_needed
|
||||
]
|
||||
block_id += num_blocks_needed
|
||||
|
||||
kv_cache = torch.randn(block_tables.numel(), block_size, qk_head_dim).to(dtype)
|
||||
q = torch.randn(bs, num_heads, qk_head_dim).to(dtype)
|
||||
|
||||
out_ref = q.new_zeros(bs, num_heads, kv_lora_rank)
|
||||
ref_mla(out_ref, q, kv_cache, scale, block_tables, seq_lens_tensor)
|
||||
|
||||
workspace_buffer = torch.zeros(
|
||||
FLASHINFER_WORKSPACE_BUFFER_SIZE,
|
||||
dtype=torch.uint8,
|
||||
device=q.device,
|
||||
)
|
||||
# Flashinfer MLA expects the query to be of shape
|
||||
# (bs, q_len_per_request, num_heads, qk_head_dim),
|
||||
# where q_len_per_request is the MTP query length (=1 without MTP)
|
||||
q = q.unsqueeze(1)
|
||||
|
||||
out_ans = trtllm_batch_decode_with_kv_cache_mla(
|
||||
query=q,
|
||||
kv_cache=kv_cache.unsqueeze(1),
|
||||
workspace_buffer=workspace_buffer,
|
||||
qk_nope_head_dim=qk_nope_head_dim,
|
||||
kv_lora_rank=kv_lora_rank,
|
||||
qk_rope_head_dim=qk_rope_head_dim,
|
||||
block_tables=block_tables,
|
||||
seq_lens=seq_lens_tensor,
|
||||
max_seq_len=max_seq_len,
|
||||
bmm1_scale=scale,
|
||||
)
|
||||
out_ans = out_ans.squeeze(1)
|
||||
torch.testing.assert_close(out_ans, out_ref, atol=1e-2, rtol=1e-2)
|
||||
457
tests/kernels/attention/test_flashinfer_trtllm_attention.py
Normal file
457
tests/kernels/attention/test_flashinfer_trtllm_attention.py
Normal file
@@ -0,0 +1,457 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from tests.kernels.quantization.nvfp4_utils import (
|
||||
dequantize_nvfp4_to_dtype,
|
||||
get_nvfp4_global_scale,
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.math_utils import round_up
|
||||
|
||||
if not current_platform.is_device_capability_family(100):
|
||||
pytest.skip(
|
||||
"This TRTLLM kernel requires NVIDIA Blackwell.", allow_module_level=True
|
||||
)
|
||||
else:
|
||||
import flashinfer
|
||||
|
||||
FLOAT32_BYTES = torch.finfo(torch.float).bits // 8
|
||||
FP8_DTYPE = current_platform.fp8_dtype()
|
||||
FP4_DTYPE = torch.uint8
|
||||
|
||||
|
||||
def to_float8(x, dtype=torch.float8_e4m3fn):
|
||||
finfo = torch.finfo(dtype)
|
||||
min_val, max_val = x.aminmax()
|
||||
amax = torch.maximum(min_val.abs(), max_val.abs()).clamp(min=1e-12)
|
||||
scale = finfo.max / amax * 0.1
|
||||
x_scl_sat = (x * scale).clamp(min=finfo.min, max=finfo.max)
|
||||
return x_scl_sat.to(dtype), scale.float().reciprocal()
|
||||
|
||||
|
||||
DTYPE = [torch.bfloat16]
|
||||
QUANT_DTYPES = [
|
||||
# (q_quant_dtype, kv_quant_dtype, o_quant_dtype)
|
||||
(None, None, None),
|
||||
(None, FP8_DTYPE, None),
|
||||
(FP8_DTYPE, FP8_DTYPE, None),
|
||||
(FP8_DTYPE, FP8_DTYPE, FP8_DTYPE),
|
||||
(FP8_DTYPE, FP8_DTYPE, FP4_DTYPE),
|
||||
]
|
||||
BATCH_SIZE = [4, 12]
|
||||
MAX_SEQ_LENS = [(1024, 4096)]
|
||||
NUM_HEADS = [(64, 8), (40, 8)]
|
||||
HEAD_SIZE = [128]
|
||||
KV_LAYOUT = ["HND"] # currently only HND is supported
|
||||
BLOCK_SIZE = [16]
|
||||
WINDOW_LEFT = [-1, 127]
|
||||
SOFT_CAP = [None, 50.0]
|
||||
HAS_SINKS = [True, False]
|
||||
|
||||
NUM_BLOCKS = 32768 # Large enough to test overflow in index calculation.
|
||||
|
||||
|
||||
@pytest.mark.parametrize("dtype", DTYPE)
|
||||
@pytest.mark.parametrize("quant_dtypes", QUANT_DTYPES)
|
||||
@pytest.mark.parametrize("batch_size", BATCH_SIZE)
|
||||
@pytest.mark.parametrize("max_seq_lens", MAX_SEQ_LENS)
|
||||
@pytest.mark.parametrize("num_heads", NUM_HEADS)
|
||||
@pytest.mark.parametrize("head_size", HEAD_SIZE)
|
||||
@pytest.mark.parametrize("kv_layout", KV_LAYOUT)
|
||||
@pytest.mark.parametrize("block_size", BLOCK_SIZE)
|
||||
@pytest.mark.parametrize("window_left", WINDOW_LEFT)
|
||||
@pytest.mark.parametrize("soft_cap", SOFT_CAP)
|
||||
@pytest.mark.parametrize("has_sinks", HAS_SINKS)
|
||||
@torch.inference_mode
|
||||
def test_flashinfer_trtllm_decode_with_baseline(
|
||||
dtype: torch.dtype,
|
||||
quant_dtypes: tuple[torch.dtype | None, torch.dtype | None, torch.dtype | None],
|
||||
batch_size: int,
|
||||
max_seq_lens: tuple[int, int],
|
||||
num_heads: tuple[int, int],
|
||||
head_size: int,
|
||||
kv_layout: str,
|
||||
block_size: int,
|
||||
window_left: int,
|
||||
soft_cap: float | None,
|
||||
has_sinks: bool,
|
||||
) -> None:
|
||||
torch.set_default_device("cuda")
|
||||
current_platform.seed_everything(42)
|
||||
|
||||
q_quant_dtype, kv_quant_dtype, o_quant_dtype = quant_dtypes
|
||||
q_quant_dtype = q_quant_dtype or dtype
|
||||
kv_quant_dtype = kv_quant_dtype or dtype
|
||||
o_quant_dtype = o_quant_dtype or dtype
|
||||
|
||||
_, max_kv_len = max_seq_lens
|
||||
|
||||
num_qo_heads, num_kv_heads = num_heads
|
||||
assert num_qo_heads % num_kv_heads == 0
|
||||
|
||||
sm_scale = float(1.0 / (head_size**0.5))
|
||||
|
||||
kv_cache_shape = None
|
||||
if kv_layout == "NHD":
|
||||
kv_cache_shape = (NUM_BLOCKS, 2, block_size, num_kv_heads, head_size)
|
||||
elif kv_layout == "HND":
|
||||
kv_cache_shape = (NUM_BLOCKS, 2, num_kv_heads, block_size, head_size)
|
||||
else:
|
||||
raise ValueError(f"Invalid kv_layout: {kv_layout}")
|
||||
|
||||
# max_q_len = 1
|
||||
q_lens = torch.ones((batch_size,), dtype=torch.int32)
|
||||
q_indptr = torch.cat(
|
||||
[
|
||||
torch.tensor([0], dtype=torch.int32),
|
||||
torch.cumsum(q_lens, dim=0, dtype=torch.int32),
|
||||
]
|
||||
)
|
||||
|
||||
query = torch.randn(torch.sum(q_lens).item(), num_qo_heads, head_size, dtype=dtype)
|
||||
if q_quant_dtype == FP8_DTYPE:
|
||||
query, q_scale = to_float8(query)
|
||||
ref_query = query.to(dtype) * q_scale
|
||||
else:
|
||||
q_scale = 1.0
|
||||
ref_query = query
|
||||
|
||||
kv_lens = torch.randint(1, max_kv_len, (batch_size,), dtype=torch.int32)
|
||||
kv_lens[-1] = max_kv_len
|
||||
|
||||
seq_lens = kv_lens + q_lens
|
||||
max_seq_len = torch.max(seq_lens).item()
|
||||
|
||||
kv_cache = torch.randn(kv_cache_shape, dtype=dtype)
|
||||
if kv_quant_dtype == FP8_DTYPE:
|
||||
kv_cache, kv_scale = to_float8(kv_cache)
|
||||
ref_kv_cache = kv_cache.to(dtype) * kv_scale
|
||||
else:
|
||||
kv_scale = 1.0
|
||||
ref_kv_cache = kv_cache
|
||||
k_scale = v_scale = kv_scale
|
||||
|
||||
max_num_blocks_per_seq = (max_seq_len + block_size - 1) // block_size
|
||||
block_tables = torch.randint(
|
||||
0, NUM_BLOCKS, (batch_size, max_num_blocks_per_seq), dtype=torch.int32
|
||||
)
|
||||
kv_indptr = [0]
|
||||
kv_indices = []
|
||||
kv_last_page_lens = []
|
||||
for i in range(batch_size):
|
||||
seq_len = seq_lens[i]
|
||||
assert seq_len > 0
|
||||
num_blocks = (seq_len + block_size - 1) // block_size
|
||||
kv_indices.extend(block_tables[i, :num_blocks])
|
||||
kv_indptr.append(kv_indptr[-1] + num_blocks)
|
||||
kv_last_page_len = seq_len % block_size
|
||||
if kv_last_page_len == 0:
|
||||
kv_last_page_len = block_size
|
||||
kv_last_page_lens.append(kv_last_page_len)
|
||||
|
||||
kv_indptr = torch.tensor(kv_indptr, dtype=torch.int32)
|
||||
kv_indices = torch.tensor(kv_indices, dtype=torch.int32)
|
||||
kv_last_page_lens = torch.tensor(kv_last_page_lens, dtype=torch.int32)
|
||||
workspace_buffer = torch.zeros(128 * 1024 * 1024, dtype=torch.int8)
|
||||
|
||||
# Baseline Decode
|
||||
if has_sinks:
|
||||
sinks = torch.rand(num_qo_heads, dtype=torch.float32) * 5
|
||||
wrapper = flashinfer.BatchAttentionWithAttentionSinkWrapper(
|
||||
float_workspace_buffer=workspace_buffer, kv_layout=kv_layout, backend="fa2"
|
||||
)
|
||||
else:
|
||||
sinks = None
|
||||
wrapper = flashinfer.BatchPrefillWithPagedKVCacheWrapper(
|
||||
float_workspace_buffer=workspace_buffer, kv_layout=kv_layout, backend="fa2"
|
||||
)
|
||||
|
||||
wrapper.plan(
|
||||
qo_indptr=q_indptr,
|
||||
paged_kv_indptr=kv_indptr,
|
||||
paged_kv_indices=kv_indices,
|
||||
paged_kv_last_page_len=kv_last_page_lens,
|
||||
num_qo_heads=num_qo_heads,
|
||||
num_kv_heads=num_kv_heads,
|
||||
head_dim_qk=head_size,
|
||||
page_size=block_size,
|
||||
causal=True,
|
||||
sm_scale=sm_scale,
|
||||
window_left=window_left,
|
||||
logits_soft_cap=soft_cap,
|
||||
q_data_type=dtype,
|
||||
kv_data_type=dtype,
|
||||
)
|
||||
output = torch.empty(ref_query.shape, dtype=dtype)
|
||||
wrapper.run(ref_query, ref_kv_cache, sinks, sm_scale, out=output)
|
||||
|
||||
o_scale = 1.0
|
||||
o_sf_scale_float = None
|
||||
if o_quant_dtype == FP8_DTYPE:
|
||||
_, o_scale = to_float8(output)
|
||||
elif o_quant_dtype == FP4_DTYPE:
|
||||
o_sf_scale = get_nvfp4_global_scale(output)
|
||||
o_sf_scale_float = o_sf_scale.item()
|
||||
|
||||
# TRTLLM Decode
|
||||
if o_quant_dtype == FP4_DTYPE:
|
||||
output_trtllm = flashinfer.utils.FP4Tensor(
|
||||
torch.empty(query.shape[:-1] + (query.shape[-1] // 2,), dtype=torch.uint8),
|
||||
torch.empty(
|
||||
(
|
||||
round_up(query.shape[0], 128),
|
||||
round_up(query.shape[1] * query.shape[2] // 16, 4),
|
||||
),
|
||||
dtype=torch.float8_e4m3fn,
|
||||
),
|
||||
)
|
||||
else:
|
||||
output_trtllm = torch.empty(query.shape, dtype=o_quant_dtype)
|
||||
|
||||
flashinfer.decode.trtllm_batch_decode_with_kv_cache(
|
||||
query=query,
|
||||
kv_cache=kv_cache,
|
||||
workspace_buffer=workspace_buffer,
|
||||
block_tables=block_tables,
|
||||
seq_lens=seq_lens,
|
||||
max_seq_len=max_seq_len,
|
||||
bmm1_scale=q_scale * k_scale * sm_scale,
|
||||
bmm2_scale=v_scale / o_scale,
|
||||
window_left=window_left,
|
||||
sinks=sinks,
|
||||
o_sf_scale=o_sf_scale_float,
|
||||
out=output_trtllm,
|
||||
)
|
||||
if o_quant_dtype == FP8_DTYPE:
|
||||
output_trtllm = output_trtllm.to(dtype) * o_scale
|
||||
elif o_quant_dtype == FP4_DTYPE:
|
||||
output_trtllm.data = output_trtllm.data.reshape(
|
||||
-1, query.shape[1] * query.shape[2] // 2
|
||||
)
|
||||
output_trtllm = dequantize_nvfp4_to_dtype(
|
||||
output_trtllm.data, output_trtllm.scale, o_sf_scale, dtype, query.device
|
||||
)
|
||||
output_trtllm = output_trtllm.reshape(-1, query.shape[1], query.shape[2])
|
||||
|
||||
if q_quant_dtype == FP8_DTYPE and o_quant_dtype == FP4_DTYPE:
|
||||
rtol, atol = 7e-2, 9e-2
|
||||
elif q_quant_dtype == FP8_DTYPE and o_quant_dtype == FP8_DTYPE:
|
||||
rtol, atol = 3e-2, 4e-2
|
||||
elif q_quant_dtype == FP8_DTYPE and o_quant_dtype == dtype:
|
||||
rtol, atol = 2e-2, 2e-2
|
||||
elif kv_quant_dtype == FP8_DTYPE:
|
||||
rtol, atol = 4e-2, 6e-2
|
||||
else:
|
||||
rtol, atol = 1e-2, 1e-2
|
||||
|
||||
(
|
||||
torch.testing.assert_close(output, output_trtllm, atol=atol, rtol=rtol),
|
||||
f"{torch.max(torch.abs(output - output_trtllm))}",
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("dtype", DTYPE)
|
||||
@pytest.mark.parametrize("quant_dtypes", QUANT_DTYPES)
|
||||
@pytest.mark.parametrize("batch_size", BATCH_SIZE)
|
||||
@pytest.mark.parametrize("max_seq_lens", MAX_SEQ_LENS)
|
||||
@pytest.mark.parametrize("num_heads", NUM_HEADS)
|
||||
@pytest.mark.parametrize("head_size", HEAD_SIZE)
|
||||
@pytest.mark.parametrize("kv_layout", KV_LAYOUT)
|
||||
@pytest.mark.parametrize("block_size", BLOCK_SIZE)
|
||||
@pytest.mark.parametrize("window_left", WINDOW_LEFT)
|
||||
@pytest.mark.parametrize("soft_cap", [None])
|
||||
@pytest.mark.parametrize("has_sinks", HAS_SINKS)
|
||||
@torch.inference_mode
|
||||
def test_flashinfer_trtllm_prefill_with_baseline(
|
||||
dtype: torch.dtype,
|
||||
quant_dtypes: tuple[torch.dtype | None, torch.dtype | None, torch.dtype | None],
|
||||
batch_size: int,
|
||||
max_seq_lens: tuple[int, int],
|
||||
num_heads: tuple[int, int],
|
||||
head_size: int,
|
||||
kv_layout: str,
|
||||
block_size: int,
|
||||
window_left: int,
|
||||
soft_cap: float | None,
|
||||
has_sinks: bool,
|
||||
) -> None:
|
||||
torch.set_default_device("cuda")
|
||||
current_platform.seed_everything(42)
|
||||
|
||||
q_quant_dtype, kv_quant_dtype, o_quant_dtype = quant_dtypes
|
||||
q_quant_dtype = q_quant_dtype or dtype
|
||||
kv_quant_dtype = kv_quant_dtype or dtype
|
||||
o_quant_dtype = o_quant_dtype or dtype
|
||||
|
||||
if q_quant_dtype != kv_quant_dtype:
|
||||
pytest.skip("Skipped mixed QKV dtypes for prefill")
|
||||
|
||||
max_q_len, max_kv_len = max_seq_lens
|
||||
|
||||
num_qo_heads, num_kv_heads = num_heads
|
||||
assert num_qo_heads % num_kv_heads == 0
|
||||
|
||||
sm_scale = float(1.0 / (head_size**0.5))
|
||||
|
||||
kv_cache_shape = None
|
||||
if kv_layout == "NHD":
|
||||
kv_cache_shape = (NUM_BLOCKS, 2, block_size, num_kv_heads, head_size)
|
||||
elif kv_layout == "HND":
|
||||
kv_cache_shape = (NUM_BLOCKS, 2, num_kv_heads, block_size, head_size)
|
||||
else:
|
||||
raise ValueError(f"Invalid kv_layout: {kv_layout}")
|
||||
|
||||
q_lens = torch.randint(1, max_q_len, (batch_size,), dtype=torch.int32)
|
||||
q_lens[-1] = max_q_len
|
||||
q_indptr = torch.cat(
|
||||
[
|
||||
torch.tensor([0], dtype=torch.int32),
|
||||
torch.cumsum(q_lens, dim=0, dtype=torch.int32),
|
||||
]
|
||||
)
|
||||
|
||||
query = torch.randn(torch.sum(q_lens).item(), num_qo_heads, head_size, dtype=dtype)
|
||||
if q_quant_dtype == FP8_DTYPE:
|
||||
query, q_scale = to_float8(query)
|
||||
ref_query = query.to(dtype) * q_scale
|
||||
else:
|
||||
q_scale = 1.0
|
||||
ref_query = query
|
||||
|
||||
kv_lens = torch.randint(1, max_kv_len, (batch_size,), dtype=torch.int32)
|
||||
kv_lens[-1] = max_kv_len
|
||||
|
||||
seq_lens = kv_lens + q_lens
|
||||
max_seq_len = torch.max(seq_lens).item()
|
||||
|
||||
kv_cache = torch.randn(kv_cache_shape, dtype=dtype)
|
||||
if kv_quant_dtype == FP8_DTYPE:
|
||||
kv_cache, kv_scale = to_float8(kv_cache)
|
||||
ref_kv_cache = kv_cache.to(dtype) * kv_scale
|
||||
else:
|
||||
kv_scale = 1.0
|
||||
ref_kv_cache = kv_cache
|
||||
k_scale = v_scale = kv_scale
|
||||
|
||||
max_num_blocks_per_seq = (max_seq_len + block_size - 1) // block_size
|
||||
block_tables = torch.randint(
|
||||
0, NUM_BLOCKS, (batch_size, max_num_blocks_per_seq), dtype=torch.int32
|
||||
)
|
||||
kv_indptr = [0]
|
||||
kv_indices = []
|
||||
kv_last_page_lens = []
|
||||
for i in range(batch_size):
|
||||
seq_len = seq_lens[i]
|
||||
assert seq_len > 0
|
||||
num_blocks = (seq_len + block_size - 1) // block_size
|
||||
kv_indices.extend(block_tables[i, :num_blocks])
|
||||
kv_indptr.append(kv_indptr[-1] + num_blocks)
|
||||
kv_last_page_len = seq_len % block_size
|
||||
if kv_last_page_len == 0:
|
||||
kv_last_page_len = block_size
|
||||
kv_last_page_lens.append(kv_last_page_len)
|
||||
|
||||
kv_indptr = torch.tensor(kv_indptr, dtype=torch.int32)
|
||||
kv_indices = torch.tensor(kv_indices, dtype=torch.int32)
|
||||
kv_last_page_lens = torch.tensor(kv_last_page_lens, dtype=torch.int32)
|
||||
workspace_buffer = torch.zeros(128 * 1024 * 1024, dtype=torch.int8)
|
||||
|
||||
# Baseline Prefill
|
||||
if has_sinks:
|
||||
sinks = torch.rand(num_qo_heads, dtype=torch.float32) * 5
|
||||
wrapper = flashinfer.BatchAttentionWithAttentionSinkWrapper(
|
||||
float_workspace_buffer=workspace_buffer, kv_layout=kv_layout, backend="fa2"
|
||||
)
|
||||
else:
|
||||
sinks = None
|
||||
wrapper = flashinfer.BatchPrefillWithPagedKVCacheWrapper(
|
||||
float_workspace_buffer=workspace_buffer, kv_layout=kv_layout, backend="fa2"
|
||||
)
|
||||
|
||||
wrapper.plan(
|
||||
qo_indptr=q_indptr,
|
||||
paged_kv_indptr=kv_indptr,
|
||||
paged_kv_indices=kv_indices,
|
||||
paged_kv_last_page_len=kv_last_page_lens,
|
||||
num_qo_heads=num_qo_heads,
|
||||
num_kv_heads=num_kv_heads,
|
||||
head_dim_qk=head_size,
|
||||
page_size=block_size,
|
||||
causal=True,
|
||||
sm_scale=sm_scale,
|
||||
window_left=window_left,
|
||||
logits_soft_cap=soft_cap,
|
||||
q_data_type=dtype,
|
||||
kv_data_type=dtype,
|
||||
)
|
||||
output = torch.empty(ref_query.shape, dtype=dtype)
|
||||
wrapper.run(ref_query, ref_kv_cache, sinks, sm_scale, out=output)
|
||||
|
||||
o_scale = 1.0
|
||||
o_sf_scale_float = None
|
||||
if o_quant_dtype == FP8_DTYPE:
|
||||
_, o_scale = to_float8(output)
|
||||
elif o_quant_dtype == FP4_DTYPE:
|
||||
o_sf_scale = get_nvfp4_global_scale(output)
|
||||
o_sf_scale_float = o_sf_scale.item()
|
||||
|
||||
# TRTLLM Prefill
|
||||
if o_quant_dtype == FP4_DTYPE:
|
||||
output_trtllm = flashinfer.utils.FP4Tensor(
|
||||
torch.empty(query.shape[:-1] + (query.shape[-1] // 2,), dtype=torch.uint8),
|
||||
torch.empty(
|
||||
(
|
||||
round_up(query.shape[0], 128),
|
||||
round_up(query.shape[1] * query.shape[2] // 16, 4),
|
||||
),
|
||||
dtype=torch.float8_e4m3fn,
|
||||
),
|
||||
)
|
||||
else:
|
||||
output_trtllm = torch.empty(query.shape, dtype=o_quant_dtype)
|
||||
|
||||
flashinfer.prefill.trtllm_batch_context_with_kv_cache(
|
||||
query=query,
|
||||
kv_cache=kv_cache,
|
||||
workspace_buffer=workspace_buffer,
|
||||
block_tables=block_tables,
|
||||
seq_lens=seq_lens,
|
||||
max_q_len=max_q_len,
|
||||
max_kv_len=max_seq_len,
|
||||
bmm1_scale=q_scale * k_scale * sm_scale,
|
||||
bmm2_scale=v_scale / o_scale,
|
||||
batch_size=batch_size,
|
||||
cum_seq_lens_q=q_indptr,
|
||||
cum_seq_lens_kv=kv_indptr,
|
||||
window_left=window_left,
|
||||
sinks=sinks,
|
||||
o_sf_scale=o_sf_scale_float,
|
||||
out=output_trtllm,
|
||||
)
|
||||
if o_quant_dtype == FP8_DTYPE:
|
||||
output_trtllm = output_trtllm.to(dtype) * o_scale
|
||||
elif o_quant_dtype == FP4_DTYPE:
|
||||
output_trtllm.data = output_trtllm.data.reshape(
|
||||
-1, query.shape[1] * query.shape[2] // 2
|
||||
)
|
||||
output_trtllm = dequantize_nvfp4_to_dtype(
|
||||
output_trtllm.data, output_trtllm.scale, o_sf_scale, dtype, query.device
|
||||
)
|
||||
output_trtllm = output_trtllm.reshape(-1, query.shape[1], query.shape[2])
|
||||
|
||||
if q_quant_dtype == FP8_DTYPE and o_quant_dtype == FP4_DTYPE:
|
||||
rtol, atol = 3e-1, 4e-1
|
||||
elif q_quant_dtype == FP8_DTYPE and o_quant_dtype == FP8_DTYPE:
|
||||
rtol, atol = 4e-2, 6e-2
|
||||
elif q_quant_dtype == FP8_DTYPE and o_quant_dtype == dtype:
|
||||
rtol, atol = 2e-2, 3e-2
|
||||
else:
|
||||
rtol, atol = 1e-2, 1e-2
|
||||
|
||||
(
|
||||
torch.testing.assert_close(output, output_trtllm, atol=atol, rtol=rtol),
|
||||
f"{torch.max(torch.abs(output - output_trtllm))}",
|
||||
)
|
||||
178
tests/kernels/attention/test_flashmla.py
Normal file
178
tests/kernels/attention/test_flashmla.py
Normal file
@@ -0,0 +1,178 @@
|
||||
# Adapted from: https://github.com/deepseek-ai/FlashMLA/blob/main/tests/test_flash_mla.py
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import math
|
||||
import random
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from vllm.attention.ops.flashmla import (
|
||||
flash_mla_with_kvcache,
|
||||
get_mla_metadata,
|
||||
is_flashmla_dense_supported,
|
||||
)
|
||||
from vllm.triton_utils import triton
|
||||
|
||||
|
||||
def cal_diff(
|
||||
x: torch.Tensor, y: torch.Tensor, name: str, use_fp8: bool = False
|
||||
) -> None:
|
||||
x, y = x.double(), y.double()
|
||||
cos_diff = 1 - 2 * (x * y).sum().item() / max((x * x + y * y).sum().item(), 1e-12)
|
||||
if use_fp8:
|
||||
assert cos_diff < 1e-4
|
||||
else:
|
||||
assert cos_diff < 1e-5
|
||||
|
||||
|
||||
FLASH_MLA_UNSUPPORTED_REASON = (
|
||||
is_flashmla_dense_supported()[1]
|
||||
if not is_flashmla_dense_supported()[0]
|
||||
else "FlashMLA is supported"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
not is_flashmla_dense_supported()[0], reason=FLASH_MLA_UNSUPPORTED_REASON
|
||||
)
|
||||
@pytest.mark.parametrize("b", [128])
|
||||
@pytest.mark.parametrize("s_q", [1, 2])
|
||||
@pytest.mark.parametrize("mean_sk", [4096, 8192, 16384])
|
||||
@pytest.mark.parametrize("h_q", [16, 32, 64, 128])
|
||||
@pytest.mark.parametrize("h_kv", [1])
|
||||
@pytest.mark.parametrize("d", [576])
|
||||
@pytest.mark.parametrize("dv", [512])
|
||||
@pytest.mark.parametrize("block_size", [64])
|
||||
@pytest.mark.parametrize("causal", [True])
|
||||
@pytest.mark.parametrize("varlen", [False, True])
|
||||
@pytest.mark.parametrize(
|
||||
"torch_dtype", [torch.bfloat16, torch.float16, torch.float8_e4m3fn]
|
||||
)
|
||||
@torch.inference_mode()
|
||||
def test_flash_mla(
|
||||
b, s_q, mean_sk, h_q, h_kv, d, dv, block_size, causal, varlen, torch_dtype
|
||||
):
|
||||
device = torch.device("cuda:0")
|
||||
init_dtype = torch.bfloat16 if torch_dtype == torch.float8_e4m3fn else torch_dtype
|
||||
torch.set_default_dtype(init_dtype)
|
||||
torch.set_default_device(device)
|
||||
torch.cuda.set_device(device)
|
||||
torch.manual_seed(0)
|
||||
random.seed(0)
|
||||
|
||||
print(
|
||||
f"{b=}, {s_q=}, {mean_sk=}, {h_q=}, {h_kv=}, "
|
||||
f"{d=}, {dv=}, {causal=}, {varlen=}, {torch_dtype=}"
|
||||
)
|
||||
|
||||
use_fp8 = torch_dtype == torch.float8_e4m3fn
|
||||
cache_seqlens = torch.full((b,), mean_sk, dtype=torch.int32)
|
||||
if varlen:
|
||||
for i in range(b):
|
||||
cache_seqlens[i] = max(random.normalvariate(mean_sk, mean_sk / 2), s_q)
|
||||
total_seqlens = cache_seqlens.sum().item()
|
||||
max_seqlen = cache_seqlens.max().item()
|
||||
max_seqlen_pad = triton.cdiv(max_seqlen, 256) * 256
|
||||
|
||||
q = torch.randn(b, s_q, h_q, d)
|
||||
block_table = torch.arange(
|
||||
b * max_seqlen_pad // block_size, dtype=torch.int32
|
||||
).view(b, max_seqlen_pad // block_size)
|
||||
blocked_k = torch.randn(block_table.numel(), block_size, h_kv, d)
|
||||
for i in range(b):
|
||||
blocked_k.view(b, max_seqlen_pad, h_kv, d)[i, cache_seqlens[i].item() :] = (
|
||||
float("nan")
|
||||
)
|
||||
blocked_v = blocked_k[..., :dv]
|
||||
|
||||
tile_scheduler_metadata, num_splits = get_mla_metadata(
|
||||
cache_seqlens, s_q * h_q // h_kv, h_kv
|
||||
)
|
||||
|
||||
init_dtype = q.dtype
|
||||
if use_fp8:
|
||||
fp8_dtype = torch.float8_e4m3fn
|
||||
descale_q = torch.ones((1), dtype=torch.float32)
|
||||
descale_k = torch.ones((1), dtype=torch.float32)
|
||||
|
||||
q = q.to(fp8_dtype)
|
||||
blocked_k = blocked_k.to(fp8_dtype)
|
||||
blocked_v = blocked_v.to(fp8_dtype)
|
||||
else:
|
||||
descale_q = None
|
||||
descale_k = None
|
||||
|
||||
def flash_mla():
|
||||
return flash_mla_with_kvcache(
|
||||
q,
|
||||
blocked_k,
|
||||
block_table,
|
||||
cache_seqlens,
|
||||
dv,
|
||||
tile_scheduler_metadata,
|
||||
num_splits,
|
||||
causal=causal,
|
||||
descale_q=descale_q,
|
||||
descale_k=descale_k,
|
||||
)
|
||||
|
||||
def scaled_dot_product_attention(query, key, value, is_causal=False):
|
||||
query = query.float()
|
||||
key = key.float()
|
||||
value = value.float()
|
||||
key = key.repeat_interleave(h_q // h_kv, dim=0)
|
||||
value = value.repeat_interleave(h_q // h_kv, dim=0)
|
||||
attn_weight = query @ key.transpose(-2, -1) / math.sqrt(query.size(-1))
|
||||
if is_causal:
|
||||
s_q = query.shape[-2]
|
||||
s_k = key.shape[-2]
|
||||
attn_bias = torch.zeros(s_q, s_k, dtype=query.dtype)
|
||||
temp_mask = torch.ones(s_q, s_k, dtype=torch.bool).tril(diagonal=s_k - s_q)
|
||||
attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))
|
||||
attn_bias.to(query.dtype)
|
||||
attn_weight += attn_bias
|
||||
lse = attn_weight.logsumexp(dim=-1)
|
||||
attn_weight = torch.softmax(attn_weight, dim=-1, dtype=torch.float32)
|
||||
return attn_weight @ value, lse
|
||||
|
||||
def ref_mla():
|
||||
q_ = (q.to(torch.float) * descale_q).to(init_dtype) if use_fp8 else q
|
||||
blocked_k_ = (
|
||||
(blocked_k.to(torch.float) * descale_k).to(init_dtype)
|
||||
if use_fp8
|
||||
else blocked_k
|
||||
)
|
||||
blocked_v_ = (
|
||||
(blocked_v.to(torch.float) * descale_k).to(init_dtype)
|
||||
if use_fp8
|
||||
else blocked_v
|
||||
)
|
||||
out = torch.empty(b, s_q, h_q, dv, dtype=torch.float32)
|
||||
lse = torch.empty(b, h_q, s_q, dtype=torch.float32)
|
||||
for i in range(b):
|
||||
begin = i * max_seqlen_pad
|
||||
end = begin + cache_seqlens[i]
|
||||
out_i, lse_i = scaled_dot_product_attention(
|
||||
q_[i].transpose(0, 1),
|
||||
blocked_k_.view(-1, h_kv, d)[begin:end].transpose(0, 1),
|
||||
blocked_v_.view(-1, h_kv, dv)[begin:end].transpose(0, 1),
|
||||
is_causal=causal,
|
||||
)
|
||||
out[i] = out_i.transpose(0, 1)
|
||||
lse[i] = lse_i
|
||||
return out, lse
|
||||
|
||||
out_flash, lse_flash = flash_mla()
|
||||
out_torch, lse_torch = ref_mla()
|
||||
cal_diff(out_flash, out_torch, "out", use_fp8)
|
||||
cal_diff(lse_flash, lse_torch, "lse")
|
||||
|
||||
t = triton.testing.do_bench(flash_mla)
|
||||
FLOPS = s_q * total_seqlens * h_q * (d + dv) * 2
|
||||
bytes = (total_seqlens * h_kv * d + b * s_q * h_q * d) * (
|
||||
torch.finfo(torch_dtype).bits // 8
|
||||
) + (b * s_q * h_q * dv) * (torch.finfo(init_dtype).bits // 8)
|
||||
print(
|
||||
f"{t:.3f} ms, {FLOPS / 10**9 / t:.0f} TFLOPS,", f"{bytes / 10**6 / t:.0f} GB/s"
|
||||
)
|
||||
122
tests/kernels/attention/test_flashmla_sparse.py
Normal file
122
tests/kernels/attention/test_flashmla_sparse.py
Normal file
@@ -0,0 +1,122 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
|
||||
def test_sparse_flashmla_metadata_smoke():
|
||||
import vllm.attention.ops.flashmla as fm
|
||||
|
||||
ok, reason = fm.is_flashmla_sparse_supported()
|
||||
if not ok:
|
||||
pytest.skip(reason)
|
||||
|
||||
device = torch.device("cuda")
|
||||
batch_size = 1
|
||||
seqlen_q = 1
|
||||
num_heads_q = 128
|
||||
num_heads_k = 1
|
||||
q_seq_per_hk = seqlen_q * num_heads_q // num_heads_k
|
||||
topk = 128
|
||||
|
||||
cache_seqlens = torch.zeros(batch_size, dtype=torch.int32, device=device)
|
||||
|
||||
tile_md, num_splits = fm.get_mla_metadata(
|
||||
cache_seqlens,
|
||||
q_seq_per_hk,
|
||||
num_heads_k,
|
||||
num_heads_q=num_heads_q,
|
||||
topk=topk,
|
||||
is_fp8_kvcache=True,
|
||||
)
|
||||
assert tile_md.dtype == torch.int32
|
||||
assert num_splits.dtype == torch.int32
|
||||
|
||||
|
||||
def test_sparse_flashmla_decode_smoke():
|
||||
import vllm.attention.ops.flashmla as fm
|
||||
|
||||
ok, reason = fm.is_flashmla_sparse_supported()
|
||||
if not ok:
|
||||
pytest.skip(reason)
|
||||
|
||||
device = torch.device("cuda")
|
||||
batch_size = 1
|
||||
seqlen_q = 1
|
||||
num_heads_q = 1
|
||||
head_dim_k = 576
|
||||
head_dim_v = 512
|
||||
num_heads_k = 1
|
||||
page_block_size = 64
|
||||
bytes_per_token = 656
|
||||
topk = 128
|
||||
|
||||
# Metadata
|
||||
q_seq_per_hk = seqlen_q * num_heads_q // num_heads_k
|
||||
# q_heads_per_hk = num_heads_q // num_heads_k
|
||||
cache_seqlens = torch.zeros(batch_size, dtype=torch.int32, device=device)
|
||||
tile_md, num_splits = fm.get_mla_metadata(
|
||||
cache_seqlens,
|
||||
q_seq_per_hk,
|
||||
num_heads_k,
|
||||
num_heads_q=num_heads_q,
|
||||
topk=topk,
|
||||
is_fp8_kvcache=True,
|
||||
)
|
||||
|
||||
# Inputs
|
||||
q = torch.zeros(
|
||||
(batch_size, seqlen_q, num_heads_q, head_dim_k),
|
||||
dtype=torch.bfloat16,
|
||||
device=device,
|
||||
)
|
||||
k_cache = torch.zeros(
|
||||
(1, page_block_size, num_heads_k, bytes_per_token),
|
||||
dtype=torch.uint8,
|
||||
device=device,
|
||||
)
|
||||
indices = torch.zeros(
|
||||
(batch_size, seqlen_q, topk), dtype=torch.int32, device=device
|
||||
)
|
||||
|
||||
block_table = torch.zeros((batch_size, 128), dtype=torch.int32, device=device)
|
||||
out, lse = fm.flash_mla_with_kvcache(
|
||||
q,
|
||||
k_cache,
|
||||
block_table,
|
||||
cache_seqlens,
|
||||
head_dim_v,
|
||||
tile_md,
|
||||
num_splits,
|
||||
indices=indices,
|
||||
is_fp8_kvcache=True,
|
||||
)
|
||||
assert out.shape[0] == batch_size
|
||||
assert out.shape[-1] == head_dim_v
|
||||
assert lse.shape[0] == batch_size
|
||||
|
||||
|
||||
def test_sparse_flashmla_prefill_smoke():
|
||||
import vllm.attention.ops.flashmla as fm
|
||||
|
||||
ok, reason = fm.is_flashmla_sparse_supported()
|
||||
if not ok:
|
||||
pytest.skip(reason)
|
||||
|
||||
device = torch.device("cuda")
|
||||
s_q = 1
|
||||
s_kv = 1
|
||||
h_q = 64 # kernel expects multiple of 64
|
||||
h_kv = 1
|
||||
d_qk = 576
|
||||
d_v = 512
|
||||
topk = 128
|
||||
|
||||
q = torch.zeros((s_q, h_q, d_qk), dtype=torch.bfloat16, device=device)
|
||||
kv = torch.zeros((s_kv, h_kv, d_qk), dtype=torch.bfloat16, device=device)
|
||||
indices = torch.zeros((s_q, h_kv, topk), dtype=torch.int32, device=device)
|
||||
|
||||
out, max_logits, lse = fm.flash_mla_sparse_prefill(q, kv, indices, 1.0, d_v)
|
||||
assert out.shape == (s_q, h_q, d_v)
|
||||
assert max_logits.shape == (s_q, h_q)
|
||||
assert lse.shape == (s_q, h_q)
|
||||
266
tests/kernels/attention/test_lightning_attn.py
Normal file
266
tests/kernels/attention/test_lightning_attn.py
Normal file
@@ -0,0 +1,266 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from vllm.model_executor.layers.lightning_attn import linear_decode_forward_triton
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
NUM_HEADS = [4, 8]
|
||||
HEAD_SIZES = [64]
|
||||
BATCH_SIZES = [1, 2]
|
||||
SEQ_LENGTHS = [16]
|
||||
DTYPES = [torch.float32]
|
||||
|
||||
|
||||
def reference_lightning_attention(q, k, v, ed, block_size, kv_history):
|
||||
"""Reference implementation of lightning attention core algorithm
|
||||
|
||||
The difference from the main implementation is that this processes
|
||||
each step sequentially, instead of using parallelized triton kernels
|
||||
"""
|
||||
B, H, S, D = q.shape
|
||||
E = v.shape[-1]
|
||||
dtype = q.dtype
|
||||
output = torch.zeros((B, H, S, E), dtype=dtype, device=q.device)
|
||||
|
||||
# Use clone() to ensure an independent copy
|
||||
if kv_history is None:
|
||||
kv_cache = torch.zeros((B, H, D, E), dtype=dtype, device=q.device)
|
||||
else:
|
||||
kv_cache = kv_history.clone()
|
||||
|
||||
# More efficient implementation
|
||||
# Convert decay factors to matrix form
|
||||
decay = torch.exp(-ed).view(1, -1, 1, 1) if ed.dim() == 1 else torch.exp(-ed)
|
||||
|
||||
for b in range(B):
|
||||
for step in range(S):
|
||||
# Process all heads at once for this position
|
||||
q_bs = q[b, :, step] # [H, D]
|
||||
k_bs = k[b, :, step] # [H, D]
|
||||
v_bs = v[b, :, step] # [H, E]
|
||||
|
||||
# Calculate KV outer products for all heads
|
||||
for h in range(H):
|
||||
# Calculate KV outer product
|
||||
kv_outer = torch.outer(k_bs[h], v_bs[h])
|
||||
|
||||
# Update KV cache with decay
|
||||
# Note: Using the same order as in the Triton kernel
|
||||
kv_cache[b, h] = decay[0, h, 0, 0] * kv_cache[b, h] + kv_outer
|
||||
|
||||
# Calculate attention output
|
||||
output[b, h, step] = torch.matmul(q_bs[h], kv_cache[b, h])
|
||||
|
||||
# Match the shape returned by the actual implementation
|
||||
# The actual implementation returns a tensor of shape [B, H, 2, D, E]
|
||||
# where dimension 2 contains both KV and KV history
|
||||
kv_reshaped = kv_cache.unsqueeze(2) # [B, H, 1, D, E]
|
||||
final_kv_cache = torch.cat([kv_reshaped, kv_reshaped], dim=2) # [B, H, 2, D, E]
|
||||
|
||||
return output, final_kv_cache
|
||||
|
||||
|
||||
def reference_linear_decode(q, k, v, kv_caches, slope_rate, slot_idx):
|
||||
"""Reference implementation: linear attention decode function"""
|
||||
B, H, _, D = q.shape
|
||||
output = torch.zeros(B, H * D, dtype=q.dtype, device=q.device)
|
||||
|
||||
# Calculate decay factors once (more efficient)
|
||||
decay = torch.exp(-slope_rate).view(-1, 1, 1) # [H, 1, 1]
|
||||
|
||||
# Process each batch
|
||||
for b in range(B):
|
||||
slot_id = slot_idx[b].item()
|
||||
|
||||
# Skip padding positions
|
||||
if slot_id == -1:
|
||||
continue
|
||||
|
||||
# Process all heads at once for this batch
|
||||
q_b = q[b, :, 0] # [H, D]
|
||||
k_b = k[b, :, 0] # [H, D]
|
||||
v_b = v[b, :, 0] # [H, D]
|
||||
|
||||
# Process each attention head
|
||||
for h in range(H):
|
||||
# Get current query, key and value
|
||||
q_bh = q_b[h]
|
||||
k_bh = k_b[h]
|
||||
v_bh = v_b[h]
|
||||
|
||||
# Get cache
|
||||
kv_cache_old = kv_caches[b, h]
|
||||
|
||||
# Calculate new key-value outer product
|
||||
kv_outer = torch.outer(k_bh, v_bh)
|
||||
|
||||
# Apply decay and update cache
|
||||
kv_new = kv_outer + decay[h, 0, 0] * kv_cache_old
|
||||
|
||||
# Calculate output
|
||||
out_h = torch.matmul(q_bh, kv_new)
|
||||
|
||||
# Update output and cache
|
||||
output[b, h * D : (h + 1) * D] = out_h
|
||||
kv_caches[b, h] = kv_new
|
||||
|
||||
return output
|
||||
|
||||
|
||||
@pytest.mark.parametrize("batch_size", BATCH_SIZES)
|
||||
@pytest.mark.parametrize("num_heads", NUM_HEADS)
|
||||
@pytest.mark.parametrize("head_size", HEAD_SIZES)
|
||||
@pytest.mark.parametrize("dtype", DTYPES)
|
||||
@torch.inference_mode()
|
||||
def test_linear_decode_forward_triton(
|
||||
batch_size: int,
|
||||
num_heads: int,
|
||||
head_size: int,
|
||||
dtype: torch.dtype,
|
||||
):
|
||||
torch.set_default_device("cuda")
|
||||
torch.manual_seed(42)
|
||||
torch.cuda.manual_seed_all(42)
|
||||
current_platform.seed_everything(42)
|
||||
base = 0.01
|
||||
q = base * torch.randn(batch_size, num_heads, 1, head_size, dtype=dtype)
|
||||
k = base * torch.randn(batch_size, num_heads, 1, head_size, dtype=dtype)
|
||||
v = base * torch.randn(batch_size, num_heads, 1, head_size, dtype=dtype)
|
||||
|
||||
kv_caches = base * torch.randn(
|
||||
batch_size, num_heads, head_size, head_size, dtype=dtype, device="cuda"
|
||||
)
|
||||
|
||||
kv_caches_copy = kv_caches.clone()
|
||||
|
||||
slope_rate = torch.zeros(num_heads, device="cuda")
|
||||
for h in range(num_heads):
|
||||
slope_rate[h] = 0.1 * (h + 1)
|
||||
|
||||
slot_idx = torch.arange(batch_size, device="cuda")
|
||||
|
||||
triton_output = linear_decode_forward_triton(
|
||||
q, k, v, kv_caches, slope_rate, slot_idx
|
||||
)
|
||||
|
||||
reference_output = reference_linear_decode(
|
||||
q, k, v, kv_caches_copy, slope_rate, slot_idx
|
||||
)
|
||||
torch.testing.assert_close(triton_output, reference_output, rtol=1e-1, atol=1e-1)
|
||||
torch.testing.assert_close(kv_caches, kv_caches_copy, rtol=1e-1, atol=1e-1)
|
||||
|
||||
assert triton_output.shape == (batch_size, num_heads * head_size)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("num_heads", NUM_HEADS)
|
||||
@pytest.mark.parametrize("head_size", HEAD_SIZES)
|
||||
@pytest.mark.parametrize("dtype", DTYPES)
|
||||
@torch.inference_mode()
|
||||
def test_linear_decode_forward_triton_with_padding(
|
||||
num_heads: int,
|
||||
head_size: int,
|
||||
dtype: torch.dtype,
|
||||
):
|
||||
torch.set_default_device("cuda")
|
||||
torch.manual_seed(42)
|
||||
torch.cuda.manual_seed_all(42)
|
||||
current_platform.seed_everything(42)
|
||||
|
||||
batch_size = 4
|
||||
base = 0.01
|
||||
q = base * torch.randn(batch_size, num_heads, 1, head_size, dtype=dtype)
|
||||
k = base * torch.randn(batch_size, num_heads, 1, head_size, dtype=dtype)
|
||||
v = base * torch.randn(batch_size, num_heads, 1, head_size, dtype=dtype)
|
||||
|
||||
kv_caches = base * torch.randn(
|
||||
batch_size, num_heads, head_size, head_size, dtype=dtype, device="cuda"
|
||||
)
|
||||
|
||||
kv_caches_copy = kv_caches.clone()
|
||||
|
||||
slope_rate = torch.zeros(num_heads, device="cuda")
|
||||
for h in range(num_heads):
|
||||
slope_rate[h] = 0.1 * (h + 1)
|
||||
|
||||
slot_idx = torch.tensor([0, 1, -1, 2], device="cuda")
|
||||
|
||||
triton_output = linear_decode_forward_triton(
|
||||
q, k, v, kv_caches, slope_rate, slot_idx
|
||||
)
|
||||
|
||||
reference_output = reference_linear_decode(
|
||||
q, k, v, kv_caches_copy, slope_rate, slot_idx
|
||||
)
|
||||
|
||||
padding_mask = (slot_idx != -1).unsqueeze(1).expand(-1, num_heads * head_size)
|
||||
|
||||
triton_masked = triton_output[padding_mask]
|
||||
reference_masked = reference_output[padding_mask]
|
||||
|
||||
atol, rtol = 1.5e-1, 1.5e-1
|
||||
|
||||
valid_indices = slot_idx != -1
|
||||
|
||||
for i in range(batch_size):
|
||||
if valid_indices[i] > 0:
|
||||
torch.testing.assert_close(
|
||||
kv_caches[i], kv_caches_copy[i], rtol=rtol, atol=atol
|
||||
)
|
||||
|
||||
torch.testing.assert_close(triton_masked, reference_masked, rtol=rtol, atol=atol)
|
||||
|
||||
assert triton_output.shape == (batch_size, num_heads * head_size)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("batch_size", BATCH_SIZES)
|
||||
@pytest.mark.parametrize("num_heads", NUM_HEADS)
|
||||
@pytest.mark.parametrize("head_size", HEAD_SIZES)
|
||||
@pytest.mark.parametrize("seq_len", SEQ_LENGTHS)
|
||||
@pytest.mark.parametrize("dtype", DTYPES)
|
||||
@torch.inference_mode()
|
||||
def test_lightning_attention_reference(
|
||||
batch_size: int,
|
||||
num_heads: int,
|
||||
head_size: int,
|
||||
seq_len: int,
|
||||
dtype: torch.dtype,
|
||||
):
|
||||
torch.set_default_device("cuda")
|
||||
torch.manual_seed(42)
|
||||
torch.cuda.manual_seed_all(42)
|
||||
current_platform.seed_everything(42)
|
||||
|
||||
base = 0.01
|
||||
q = base * torch.randn(batch_size, num_heads, seq_len, head_size, dtype=dtype)
|
||||
k = base * torch.randn(batch_size, num_heads, seq_len, head_size, dtype=dtype)
|
||||
v = base * torch.randn(batch_size, num_heads, seq_len, head_size, dtype=dtype)
|
||||
|
||||
ed = torch.zeros(num_heads, device="cuda")
|
||||
for h in range(num_heads):
|
||||
ed[h] = 0.1 * (h + 1)
|
||||
|
||||
kv_history = base * torch.randn(
|
||||
batch_size, num_heads, head_size, head_size, dtype=dtype, device="cuda"
|
||||
)
|
||||
|
||||
kv_history_clone = kv_history.clone()
|
||||
|
||||
ref_output, ref_kv_cache = reference_lightning_attention(
|
||||
q, k, v, ed, 256, kv_history
|
||||
)
|
||||
|
||||
from vllm.model_executor.layers.lightning_attn import lightning_attention
|
||||
|
||||
actual_output, actual_kv_cache = lightning_attention(
|
||||
q, k, v, ed, 256, kv_history_clone
|
||||
)
|
||||
|
||||
atol, rtol = 1.5e-1, 1.5e-1
|
||||
torch.testing.assert_close(ref_output, actual_output, rtol=rtol, atol=atol)
|
||||
torch.testing.assert_close(ref_kv_cache, actual_kv_cache, rtol=rtol, atol=atol)
|
||||
|
||||
assert ref_output.shape == (batch_size, num_heads, seq_len, head_size)
|
||||
assert ref_kv_cache.shape == actual_kv_cache.shape
|
||||
319
tests/kernels/attention/test_merge_attn_states.py
Normal file
319
tests/kernels/attention/test_merge_attn_states.py
Normal file
@@ -0,0 +1,319 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from vllm._custom_ops import merge_attn_states as merge_attn_states_cuda
|
||||
from vllm.attention.ops.triton_merge_attn_states import (
|
||||
merge_attn_states as merge_attn_states_triton,
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
|
||||
# Naive PyTorch Implements section 2.2 of https://www.arxiv.org/pdf/2501.01005
|
||||
# can be used to combine partial attention results (in the split-KV case)
|
||||
def merge_attn_states_torch(
|
||||
output: torch.Tensor, # [NUM_TOKENS, NUM_HEADS, HEAD_SIZE]
|
||||
prefix_output: torch.Tensor, # [NUM_TOKENS, NUM_HEADS, HEAD_SIZE]
|
||||
prefix_lse: torch.Tensor, # [NUM_HEADS, NUM_TOKENS]
|
||||
suffix_output: torch.Tensor, # [NUM_TOKENS, NUM_HEADS, HEAD_SIZE]
|
||||
suffix_lse: torch.Tensor, # [NUM_HEADS, NUM_TOKENS]
|
||||
output_lse: torch.Tensor | None = None, # [NUM_HEADS, NUM_TOKENS]
|
||||
):
|
||||
p_lse = prefix_lse
|
||||
s_lse = suffix_lse
|
||||
# inf -> -inf
|
||||
p_lse[p_lse == torch.inf] = -torch.inf
|
||||
s_lse[s_lse == torch.inf] = -torch.inf
|
||||
# max_lse [NUM_HEADS, NUM_TOKENS]
|
||||
max_lse = torch.maximum(p_lse, s_lse)
|
||||
p_lse = p_lse - max_lse
|
||||
s_lse = s_lse - max_lse
|
||||
p_lse_exp = torch.exp(p_lse)
|
||||
s_lse_exp = torch.exp(s_lse)
|
||||
out_se = p_lse_exp + s_lse_exp
|
||||
if output_lse is not None:
|
||||
output_lse = torch.log(out_se) + max_lse
|
||||
p_scale = p_lse_exp / out_se # [NUM_HEADS, NUM_TOKENS]
|
||||
s_scale = s_lse_exp / out_se # [NUM_HEADS, NUM_TOKENS]
|
||||
p_scale = torch.transpose(p_scale, 0, 1).unsqueeze(2) # [NUM_TOKENS, NUM_HEADS, 1]
|
||||
s_scale = torch.transpose(s_scale, 0, 1).unsqueeze(2) # [NUM_TOKENS, NUM_HEADS, 1]
|
||||
output = prefix_output * p_scale + suffix_output * s_scale
|
||||
return output, output_lse
|
||||
|
||||
|
||||
NUM_BATCH_TOKENS = [256, 512, 613, 1024, 1536, 4096]
|
||||
NUM_QUERY_HEADS = [4, 8, 16, 32, 48, 64]
|
||||
HEAD_SIZES = [32, 48, 64, 96, 128, 256]
|
||||
DTYPES = [torch.float32, torch.half, torch.bfloat16]
|
||||
|
||||
all_case_info: list[tuple] = []
|
||||
|
||||
|
||||
def generate_markdown_table():
|
||||
global all_case_info
|
||||
table_header = (
|
||||
"| tokens | heads | headsize | dtype "
|
||||
"| device | torch | triton | cuda | speedup |"
|
||||
)
|
||||
table_separator = "| --- | --- | --- | --- | --- | --- | --- | --- | --- |"
|
||||
|
||||
def shortly_dtype(dtype: torch.dtype) -> str:
|
||||
return str(dtype).removeprefix("torch.")
|
||||
|
||||
def shortly_device(device: str) -> str:
|
||||
return device.removeprefix("NVIDIA").strip()
|
||||
|
||||
print(table_header)
|
||||
print(table_separator)
|
||||
for info in all_case_info:
|
||||
(
|
||||
num_tokens,
|
||||
num_heads,
|
||||
head_size,
|
||||
dtype,
|
||||
device,
|
||||
avg_time_torch_kernel,
|
||||
avg_time_triton_kernel,
|
||||
avg_time_cuda_kernel,
|
||||
performance_improved,
|
||||
) = info
|
||||
dtype = shortly_dtype(dtype)
|
||||
device = shortly_device(device)
|
||||
print(
|
||||
f"| {num_tokens} | {num_heads} | {head_size} "
|
||||
f"| {dtype} | {device} | {avg_time_torch_kernel:.5f}ms "
|
||||
f"| {avg_time_triton_kernel:.5f}ms "
|
||||
f"| {avg_time_cuda_kernel:.5f}ms "
|
||||
f"| {performance_improved:.4f}x |"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("num_tokens", NUM_BATCH_TOKENS)
|
||||
@pytest.mark.parametrize("num_query_heads", NUM_QUERY_HEADS)
|
||||
@pytest.mark.parametrize("head_size", HEAD_SIZES)
|
||||
@pytest.mark.parametrize("output_dtype", DTYPES)
|
||||
@torch.inference_mode()
|
||||
def test_merge_attn_states(
|
||||
num_tokens: int, num_query_heads: int, head_size: int, output_dtype: torch.dtype
|
||||
):
|
||||
if not current_platform.is_cuda():
|
||||
pytest.skip(
|
||||
"Currently only support compare triton merge_attn_states "
|
||||
"with custom cuda merge_attn_states kernel"
|
||||
)
|
||||
|
||||
NUM_TOKENS = num_tokens
|
||||
NUM_HEADS = num_query_heads
|
||||
HEAD_SIZE = head_size
|
||||
|
||||
print(
|
||||
f"\nNUM_TOKENS:{NUM_TOKENS}, NUM_HEADS:{NUM_HEADS}, "
|
||||
f"HEAD_SIZE:{HEAD_SIZE}, DTYPE: {output_dtype}, "
|
||||
f"Device: {current_platform.get_device_name()}"
|
||||
)
|
||||
|
||||
# prefix_lse and suffix_lse contain inf and normal values
|
||||
prefix_lse = torch.randn(NUM_HEADS, NUM_TOKENS, dtype=torch.float32, device="cuda")
|
||||
suffix_lse = torch.randn(NUM_HEADS, NUM_TOKENS, dtype=torch.float32, device="cuda")
|
||||
|
||||
# Generate boolean masks
|
||||
mask_prefix = torch.rand(NUM_HEADS, NUM_TOKENS) < 0.1
|
||||
mask_suffix = torch.rand(NUM_HEADS, NUM_TOKENS) < 0.1
|
||||
# Ensure that the same position is not True at the same time
|
||||
combined_mask = torch.logical_and(mask_prefix, mask_suffix)
|
||||
mask_prefix = torch.logical_and(mask_prefix, ~combined_mask)
|
||||
mask_suffix = torch.logical_and(mask_suffix, ~combined_mask)
|
||||
|
||||
prefix_lse[mask_prefix] = float("inf")
|
||||
suffix_lse[mask_suffix] = float("inf")
|
||||
|
||||
# Other input tensors (need to be initialized but
|
||||
# no actual calculation needed)
|
||||
output = torch.zeros(
|
||||
(NUM_TOKENS, NUM_HEADS, HEAD_SIZE), dtype=output_dtype, device="cuda"
|
||||
)
|
||||
output_lse = torch.zeros(
|
||||
(NUM_HEADS, NUM_TOKENS), dtype=torch.float32, device="cuda"
|
||||
)
|
||||
prefix_output = torch.randn(
|
||||
(NUM_TOKENS, NUM_HEADS, HEAD_SIZE), dtype=output_dtype, device="cuda"
|
||||
)
|
||||
suffix_output = torch.randn(
|
||||
(NUM_TOKENS, NUM_HEADS, HEAD_SIZE), dtype=output_dtype, device="cuda"
|
||||
)
|
||||
|
||||
warmup_times = 2
|
||||
repeat_times = 20
|
||||
|
||||
output_torch = output.clone()
|
||||
output_lse_torch = output_lse.clone()
|
||||
total_time_torch_kernel = 0
|
||||
start = torch.Event(enable_timing=True)
|
||||
end = torch.Event(enable_timing=True)
|
||||
|
||||
# 0. Run the Torch kernel
|
||||
prefix_lse_torch = prefix_lse.clone()
|
||||
suffix_lse_torch = suffix_lse.clone()
|
||||
for _ in range(warmup_times):
|
||||
output_torch, output_lse_torch = merge_attn_states_torch(
|
||||
output_torch,
|
||||
prefix_output,
|
||||
prefix_lse_torch,
|
||||
suffix_output,
|
||||
suffix_lse_torch,
|
||||
output_lse_torch,
|
||||
)
|
||||
torch.cuda.synchronize()
|
||||
|
||||
for _ in range(repeat_times):
|
||||
start.record()
|
||||
output_torch, output_lse_torch = merge_attn_states_torch(
|
||||
output_torch,
|
||||
prefix_output,
|
||||
prefix_lse_torch,
|
||||
suffix_output,
|
||||
suffix_lse_torch,
|
||||
output_lse_torch,
|
||||
)
|
||||
end.record()
|
||||
torch.cuda.synchronize()
|
||||
total_time_torch_kernel += start.elapsed_time(end)
|
||||
|
||||
avg_time_torch_kernel = total_time_torch_kernel / repeat_times
|
||||
|
||||
# 1. Run the Triton kernel
|
||||
output_ref_triton = output.clone()
|
||||
output_lse_ref_triton = output_lse.clone()
|
||||
|
||||
total_time_triton_kernel = 0
|
||||
start = torch.Event(enable_timing=True)
|
||||
end = torch.Event(enable_timing=True)
|
||||
|
||||
for _ in range(warmup_times):
|
||||
merge_attn_states_triton(
|
||||
output_ref_triton,
|
||||
prefix_output,
|
||||
prefix_lse,
|
||||
suffix_output,
|
||||
suffix_lse,
|
||||
output_lse_ref_triton,
|
||||
)
|
||||
torch.cuda.synchronize()
|
||||
|
||||
for _ in range(repeat_times):
|
||||
start.record()
|
||||
merge_attn_states_triton(
|
||||
output_ref_triton,
|
||||
prefix_output,
|
||||
prefix_lse,
|
||||
suffix_output,
|
||||
suffix_lse,
|
||||
output_lse_ref_triton,
|
||||
)
|
||||
end.record()
|
||||
torch.cuda.synchronize()
|
||||
total_time_triton_kernel += start.elapsed_time(end)
|
||||
|
||||
avg_time_triton_kernel = total_time_triton_kernel / repeat_times
|
||||
|
||||
# 2. Run the CUDA kernel
|
||||
total_time_cuda_kernel = 0
|
||||
output_cuda = output.clone()
|
||||
output_lse_cuda = output_lse.clone()
|
||||
|
||||
for _ in range(warmup_times):
|
||||
merge_attn_states_cuda(
|
||||
output_cuda,
|
||||
prefix_output,
|
||||
prefix_lse,
|
||||
suffix_output,
|
||||
suffix_lse,
|
||||
output_lse_cuda,
|
||||
)
|
||||
torch.cuda.synchronize()
|
||||
|
||||
for _ in range(repeat_times):
|
||||
start.record()
|
||||
merge_attn_states_cuda(
|
||||
output_cuda,
|
||||
prefix_output,
|
||||
prefix_lse,
|
||||
suffix_output,
|
||||
suffix_lse,
|
||||
output_lse_cuda,
|
||||
)
|
||||
end.record()
|
||||
torch.cuda.synchronize()
|
||||
total_time_cuda_kernel += start.elapsed_time(end)
|
||||
|
||||
avg_time_cuda_kernel = total_time_cuda_kernel / repeat_times
|
||||
|
||||
# 3. Performance compare
|
||||
performance_improved = avg_time_triton_kernel / avg_time_cuda_kernel
|
||||
print(f" Torch time: {avg_time_torch_kernel:.6f}ms")
|
||||
print(f"Triton time: {avg_time_triton_kernel:.6f}ms")
|
||||
print(
|
||||
f" CUDA time: {avg_time_cuda_kernel:.6f}ms, "
|
||||
f"Performance: {performance_improved:.5f}x"
|
||||
)
|
||||
print("-" * 100)
|
||||
|
||||
# 4. Correctness compare
|
||||
# Liger Kernel: Efficient Triton Kernels for LLM Training
|
||||
# https://arxiv.org/pdf/2410.10989, 3.3 Correctness
|
||||
# use rtol = 1e-2 for bfloat16.
|
||||
rtol = 1e-2 if output_dtype == torch.bfloat16 else 1e-3
|
||||
|
||||
def diff(a: torch.Tensor, b: torch.Tensor):
|
||||
max_diff = torch.max(torch.abs(a.float() - b.float()))
|
||||
return max_diff
|
||||
|
||||
# Use Triton output as reference because we want to replace
|
||||
# the Triton kernel with custom CUDA kernel for merge attn
|
||||
# states operation.
|
||||
output_ref = output_ref_triton
|
||||
output_lse_ref = output_lse_ref_triton
|
||||
torch.testing.assert_close(
|
||||
output_cuda.float(), output_ref.float(), atol=1e-3, rtol=rtol
|
||||
)
|
||||
print("Output all match, max abs diff:")
|
||||
print(f"(Triton vs Torch) : {diff(output_torch, output_ref)}")
|
||||
print(f" (CUDA vs Torch) : {diff(output_torch, output_cuda)}")
|
||||
print(f" (CUDA vs Triton): {diff(output_ref, output_cuda)}")
|
||||
print("-" * 100)
|
||||
|
||||
torch.testing.assert_close(
|
||||
output_lse_cuda.float(), output_lse_ref.float(), atol=1e-3, rtol=rtol
|
||||
)
|
||||
print("Output LSE all match, max abs diff:")
|
||||
print(f"(Triton vs Torch) : {diff(output_lse_torch, output_lse_ref)}")
|
||||
print(f" (CUDA vs Torch) : {diff(output_lse_torch, output_lse_cuda)}")
|
||||
print(f" (CUDA vs Triton): {diff(output_lse_ref, output_lse_cuda)}")
|
||||
print("-" * 100)
|
||||
|
||||
print(
|
||||
"All output values test passed! All inf values "
|
||||
"are correctly replaced with -inf."
|
||||
)
|
||||
print("-" * 100)
|
||||
|
||||
device = current_platform.get_device_name()
|
||||
all_case_info.append(
|
||||
(
|
||||
NUM_TOKENS,
|
||||
NUM_HEADS,
|
||||
HEAD_SIZE,
|
||||
output_dtype,
|
||||
device,
|
||||
avg_time_torch_kernel,
|
||||
avg_time_triton_kernel,
|
||||
avg_time_cuda_kernel,
|
||||
performance_improved,
|
||||
)
|
||||
)
|
||||
if len(all_case_info) == (
|
||||
len(NUM_BATCH_TOKENS) * len(HEAD_SIZES) * len(NUM_QUERY_HEADS) * len(DTYPES)
|
||||
):
|
||||
generate_markdown_table()
|
||||
153
tests/kernels/attention/test_mha_attn.py
Normal file
153
tests/kernels/attention/test_mha_attn.py
Normal file
@@ -0,0 +1,153 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""
|
||||
Test:
|
||||
|
||||
* Tests for MultiHeadAttention layer
|
||||
"""
|
||||
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from vllm.attention.backends.registry import AttentionBackendEnum
|
||||
from vllm.attention.layer import MultiHeadAttention
|
||||
from vllm.attention.selector import _cached_get_attn_backend
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.platforms.cpu import CpuPlatform
|
||||
from vllm.platforms.cuda import CudaPlatform
|
||||
from vllm.platforms.rocm import RocmPlatform
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def clear_cache():
|
||||
"""Clear lru cache to ensure each test case runs without caching."""
|
||||
_cached_get_attn_backend.cache_clear()
|
||||
|
||||
|
||||
devices = ["cpu"]
|
||||
if current_platform.is_cuda():
|
||||
devices.append("cuda")
|
||||
if current_platform.is_rocm():
|
||||
devices.append("hip")
|
||||
|
||||
|
||||
@pytest.mark.parametrize("device", devices)
|
||||
def test_mha_attn_platform(device: str):
|
||||
"""
|
||||
Test the attention selector between different platform and device.
|
||||
"""
|
||||
torch.set_default_dtype(torch.float16)
|
||||
|
||||
if device == "cpu":
|
||||
with (
|
||||
patch("vllm.attention.layer.current_platform", CpuPlatform()),
|
||||
patch("vllm.model_executor.models.vision.current_platform", CpuPlatform()),
|
||||
):
|
||||
attn = MultiHeadAttention(16, 64, scale=1)
|
||||
assert attn.attn_backend == AttentionBackendEnum.TORCH_SDPA
|
||||
elif device == "hip":
|
||||
with (
|
||||
patch("vllm.attention.layer.current_platform", RocmPlatform()),
|
||||
patch("vllm.model_executor.models.vision.current_platform", RocmPlatform()),
|
||||
):
|
||||
attn = MultiHeadAttention(16, 64, scale=1)
|
||||
assert attn.attn_backend == AttentionBackendEnum.FLASH_ATTN
|
||||
else:
|
||||
# Test CUDA with head_size=64 (divisible by 32)
|
||||
# - should use vLLM's FlashAttention
|
||||
with (
|
||||
patch("vllm.attention.layer.current_platform", CudaPlatform()),
|
||||
patch("vllm.model_executor.models.vision.current_platform", CudaPlatform()),
|
||||
):
|
||||
attn = MultiHeadAttention(16, 64, scale=1)
|
||||
assert attn.attn_backend == AttentionBackendEnum.FLASH_ATTN
|
||||
|
||||
# Test CUDA with head_size=72 (not divisible by 32)
|
||||
# - should use vLLM's FlashAttention
|
||||
with (
|
||||
patch("vllm.attention.layer.current_platform", CudaPlatform()),
|
||||
patch("vllm.model_executor.models.vision.current_platform", CudaPlatform()),
|
||||
):
|
||||
attn = MultiHeadAttention(16, 72, scale=1)
|
||||
assert attn.attn_backend == AttentionBackendEnum.FLASH_ATTN
|
||||
|
||||
|
||||
def ref_attention(
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
scale: float,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Native implementation of scaled dot product attention without mask:
|
||||
- query, key, value: [batch_size, seq_len, num_heads, head_size]
|
||||
- attn_mask: [batch_size, seq_len, seq_len]
|
||||
"""
|
||||
query, key, value = (x.transpose(1, 2) for x in (query, key, value))
|
||||
attn_weights = scale * torch.matmul(query, key.transpose(2, 3))
|
||||
attn_weights = torch.softmax(attn_weights, dim=-1).to(value.dtype)
|
||||
out = torch.matmul(attn_weights, value).transpose(1, 2)
|
||||
return out
|
||||
|
||||
|
||||
BATCH_SIZES = [1, 16]
|
||||
SEQ_LENS = [1]
|
||||
NUM_HEADS = [1, 16]
|
||||
NUM_KV_HEADS = [1]
|
||||
HEAD_SIZES = [64, 80]
|
||||
# flshattF and tritonflashattF supported: {torch.float16, torch.bfloat16}
|
||||
DTYPES = (
|
||||
[torch.half, torch.bfloat16, torch.float]
|
||||
if not current_platform.is_rocm()
|
||||
else [torch.half, torch.bfloat16]
|
||||
)
|
||||
CUDA_DEVICES = ["cuda"]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("batch_size", BATCH_SIZES)
|
||||
@pytest.mark.parametrize("seq_len", SEQ_LENS)
|
||||
@pytest.mark.parametrize("num_heads", NUM_HEADS)
|
||||
@pytest.mark.parametrize("num_kv_heads", NUM_KV_HEADS)
|
||||
@pytest.mark.parametrize("head_size", HEAD_SIZES)
|
||||
@pytest.mark.parametrize("dtype", DTYPES)
|
||||
@pytest.mark.parametrize("device", CUDA_DEVICES)
|
||||
def test_mha_attn_forward(
|
||||
batch_size: int,
|
||||
seq_len: int,
|
||||
num_heads: int,
|
||||
num_kv_heads: int,
|
||||
head_size: int,
|
||||
dtype: torch.dtype,
|
||||
device: str,
|
||||
):
|
||||
current_platform.seed_everything(0)
|
||||
torch.set_default_device(device)
|
||||
torch.set_default_dtype(dtype)
|
||||
|
||||
q = torch.randn(batch_size, seq_len, num_heads * head_size)
|
||||
k = torch.randn(batch_size, seq_len, num_kv_heads * head_size)
|
||||
v = torch.randn(batch_size, seq_len, num_kv_heads * head_size)
|
||||
scale = 1.0 / head_size**0.5
|
||||
attn = MultiHeadAttention(
|
||||
num_heads, head_size, scale=scale, num_kv_heads=num_kv_heads
|
||||
)
|
||||
output = attn(q, k, v)
|
||||
|
||||
assert num_heads % num_kv_heads == 0
|
||||
num_queries_per_kv = num_heads // num_kv_heads
|
||||
q = q.reshape(batch_size, seq_len, num_heads, head_size)
|
||||
k = k.reshape(batch_size, seq_len, num_kv_heads, head_size)
|
||||
v = v.reshape(batch_size, seq_len, num_kv_heads, head_size)
|
||||
if num_queries_per_kv > 1:
|
||||
k = torch.repeat_interleave(k, num_queries_per_kv, dim=2)
|
||||
v = torch.repeat_interleave(v, num_queries_per_kv, dim=2)
|
||||
|
||||
ref_output = ref_attention(
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
scale=scale,
|
||||
).reshape(batch_size, seq_len, num_heads * head_size)
|
||||
torch.testing.assert_close(output, ref_output)
|
||||
84
tests/kernels/attention/test_mla_decode_cpu.py
Normal file
84
tests/kernels/attention/test_mla_decode_cpu.py
Normal file
@@ -0,0 +1,84 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import pytest
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import Tensor
|
||||
|
||||
import vllm._custom_ops as ops
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.math_utils import cdiv
|
||||
|
||||
|
||||
def ref_mla(
|
||||
out: Tensor, # (bs, num_heads, v_head_dim)
|
||||
query: Tensor, # (bs, num_heads, head_dim)
|
||||
kv_cache: Tensor, # (num_blocks, block_size, head_dim)
|
||||
scale: float,
|
||||
block_tables: Tensor, # (bs, max_num_blocks)
|
||||
seq_lens: Tensor, # (bs,)
|
||||
):
|
||||
bs, num_heads, v_head_dim = out.shape
|
||||
head_dim = query.shape[2]
|
||||
|
||||
for i in range(bs):
|
||||
# gather and flatten KV-cache
|
||||
kv = kv_cache[block_tables[i]] # (max_num_blocks, block_size, head_dim)
|
||||
kv = kv.view(1, -1, head_dim)[:, : seq_lens[i]] # (1, seq_len, head_dim)
|
||||
v = kv[:, :, :v_head_dim]
|
||||
|
||||
q = query[i].view(num_heads, 1, head_dim)
|
||||
o = F.scaled_dot_product_attention(q, kv, v, scale=scale, enable_gqa=True)
|
||||
out[i] = o.view(num_heads, v_head_dim)
|
||||
|
||||
return out
|
||||
|
||||
|
||||
@pytest.mark.parametrize("bs", [4])
|
||||
@pytest.mark.parametrize("mean_seq_len", [256])
|
||||
@pytest.mark.parametrize("h_q", [16])
|
||||
@pytest.mark.parametrize("d", [576])
|
||||
@pytest.mark.parametrize("dv", [512])
|
||||
@pytest.mark.parametrize("block_size", [16])
|
||||
@pytest.mark.parametrize("dtype", [torch.float, torch.half, torch.bfloat16])
|
||||
@pytest.mark.parametrize("varlen", [False, True])
|
||||
@pytest.mark.cpu_model
|
||||
@pytest.mark.skipif(not current_platform.is_cpu(), reason="CPU only")
|
||||
def test_mla_decode_cpu(
|
||||
bs: int,
|
||||
mean_seq_len: int,
|
||||
h_q: int,
|
||||
d: int,
|
||||
dv: int,
|
||||
block_size: int,
|
||||
dtype: torch.dtype,
|
||||
varlen: bool,
|
||||
):
|
||||
torch.set_default_dtype(dtype)
|
||||
torch.manual_seed(0)
|
||||
|
||||
scale = d ** (-0.5)
|
||||
if varlen:
|
||||
seq_lens = torch.empty(bs).normal_(mean_seq_len, mean_seq_len / 2)
|
||||
seq_lens = seq_lens.clip(2).to(torch.int32)
|
||||
else:
|
||||
seq_lens = torch.full((bs,), mean_seq_len, dtype=torch.int32)
|
||||
max_seq_len = seq_lens.max().item()
|
||||
seqlen_pad = cdiv(max_seq_len, 256) * 256 # is this necessary?
|
||||
|
||||
q = torch.randn(bs, h_q, d)
|
||||
block_table = torch.arange(bs * seqlen_pad // block_size, dtype=torch.int32)
|
||||
block_table = block_table.view(bs, seqlen_pad // block_size)
|
||||
|
||||
kv_cache = torch.randn(block_table.numel(), block_size, d)
|
||||
for i, seq_len in enumerate(seq_lens.tolist()):
|
||||
kv_cache.view(bs, seqlen_pad, d)[i, seq_len:] = float("nan")
|
||||
|
||||
out_mla = q.new_zeros(bs, h_q, dv)
|
||||
ops.mla_decode_kvcache_cpu(out_mla, q, kv_cache, scale, block_table, seq_lens)
|
||||
|
||||
out_ref = q.new_zeros(bs, h_q, dv)
|
||||
ref_mla(out_ref, q, kv_cache, scale, block_table, seq_lens)
|
||||
|
||||
assert not out_mla.isnan().any(), "Likely read out of bounds"
|
||||
torch.testing.assert_close(out_mla, out_ref)
|
||||
234
tests/kernels/attention/test_pack_unpack_triton.py
Normal file
234
tests/kernels/attention/test_pack_unpack_triton.py
Normal file
@@ -0,0 +1,234 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import torch
|
||||
from torch.testing import assert_close
|
||||
|
||||
from vllm.attention.ops.common import pack_seq_triton, unpack_seq_triton
|
||||
|
||||
|
||||
def test_pack_seq_basic_fp8():
|
||||
"""Test basic functionality of pack_seq_triton with fp8 and 3D tensors."""
|
||||
device = "cuda"
|
||||
dtype = torch.float8_e4m3fn
|
||||
|
||||
# Test cases with 3D tensors (N, H, D)
|
||||
test_cases = [
|
||||
(6, 8, 4, 2, [3, 3]), # (6, 8, 4) -> (2, 3, 8, 4)
|
||||
(10, 4, 8, 3, [2, 4, 4]), # (10, 4, 8) -> (3, 4, 4, 8)
|
||||
(20, 16, 32, 4, [5, 5, 5, 5]), # (20, 16, 32) -> (4, 5, 16, 32)
|
||||
]
|
||||
|
||||
for N, H, D, B, lengths_list in test_cases:
|
||||
# Create input tensor with small values for fp8
|
||||
x = torch.randn(N, H, D, dtype=torch.float32, device=device) * 0.1
|
||||
x = x.to(dtype=dtype)
|
||||
lengths = torch.tensor(lengths_list, device=device)
|
||||
|
||||
# Pack the data
|
||||
packed = pack_seq_triton(x, lengths)
|
||||
|
||||
# Check output shape and properties
|
||||
expected_shape = (B, max(lengths_list), H, D)
|
||||
assert packed.shape == expected_shape
|
||||
assert packed.dtype == dtype
|
||||
assert packed.device == x.device
|
||||
|
||||
# Check that valid data is preserved (within fp8 precision)
|
||||
for b in range(B):
|
||||
start_idx = sum(lengths_list[:b])
|
||||
seq_len = lengths_list[b]
|
||||
|
||||
expected_data = x[start_idx : start_idx + seq_len].to(torch.float32)
|
||||
actual_data = packed[b, :seq_len].to(torch.float32)
|
||||
|
||||
assert_close(actual_data, expected_data, rtol=1e-1, atol=1e-2)
|
||||
|
||||
|
||||
def test_pack_seq_custom_padding_fp8():
|
||||
"""Test pack_seq_triton with custom padding values for fp8."""
|
||||
device = "cuda"
|
||||
dtype = torch.float8_e4m3fn
|
||||
N, H, D, B = 20, 8, 16, 2
|
||||
lengths = torch.tensor([10, 10], device=device)
|
||||
|
||||
x = torch.randn(N, H, D, dtype=torch.float32, device=device) * 0.1
|
||||
x = x.to(dtype=dtype)
|
||||
|
||||
# Test with different padding values
|
||||
for pad_value in [-100.0, -10.0, 0.0, 10.0, 100.0]:
|
||||
result = pack_seq_triton(x, lengths, pad_value=pad_value)
|
||||
|
||||
# Check valid data
|
||||
for b in range(B):
|
||||
start_idx = b * 10
|
||||
expected_data = x[start_idx : start_idx + 10].to(torch.float32)
|
||||
actual_data = result[b, :10].to(torch.float32)
|
||||
assert_close(actual_data, expected_data, rtol=1e-1, atol=1e-2)
|
||||
|
||||
# Check padding (fp8 has limited range, so check for large values)
|
||||
padded_data = result[:, 10:].to(torch.float32)
|
||||
if pad_value < 0:
|
||||
assert torch.all(padded_data < -50) # Large negative values
|
||||
elif pad_value > 0:
|
||||
assert torch.all(padded_data > 50) # Large positive values
|
||||
else:
|
||||
assert torch.allclose(padded_data, torch.zeros_like(padded_data), atol=1e-2)
|
||||
|
||||
|
||||
def test_pack_seq_default_negative_inf_padding_fp8():
|
||||
"""Test that pack_seq_triton uses -inf padding by default for fp8."""
|
||||
device = "cuda"
|
||||
dtype = torch.float8_e4m3fn
|
||||
# B = 2
|
||||
N, H, D = 20, 8, 16
|
||||
lengths = torch.tensor([10, 10], device=device)
|
||||
|
||||
x = torch.randn(N, H, D, dtype=torch.float32, device=device) * 0.1
|
||||
x = x.to(dtype=dtype)
|
||||
result = pack_seq_triton(x, lengths)
|
||||
|
||||
# Check that padding is large negative values (fp8 representation of -inf)
|
||||
padded_data = result[:, 10:].to(torch.float32)
|
||||
assert torch.all(
|
||||
padded_data < -100
|
||||
) # fp8 -inf is represented as large negative number
|
||||
|
||||
|
||||
def test_pack_seq_edge_cases_fp8():
|
||||
"""Test pack_seq_triton with edge cases for fp8."""
|
||||
device = "cuda"
|
||||
dtype = torch.float8_e4m3fn
|
||||
|
||||
# Test with single batch element
|
||||
x = torch.randn(10, 8, 16, dtype=torch.float32, device=device) * 0.1
|
||||
x = x.to(dtype=dtype)
|
||||
lengths = torch.tensor([10], device=device)
|
||||
result = pack_seq_triton(x, lengths)
|
||||
assert result.shape == (1, 10, 8, 16)
|
||||
|
||||
# Test with very short sequences
|
||||
x = torch.randn(20, 4, 8, dtype=torch.float32, device=device) * 0.1
|
||||
x = x.to(dtype=dtype)
|
||||
lengths = torch.tensor([1, 1, 1], device=device)
|
||||
result = pack_seq_triton(x, lengths)
|
||||
assert result.shape == (3, 1, 4, 8)
|
||||
|
||||
# Test with different sequence lengths
|
||||
x = torch.randn(15, 8, 16, dtype=torch.float32, device=device) * 0.1
|
||||
x = x.to(dtype=dtype)
|
||||
lengths = torch.tensor([5, 7, 3], device=device)
|
||||
result = pack_seq_triton(x, lengths)
|
||||
assert result.shape == (3, 7, 8, 16)
|
||||
|
||||
|
||||
def test_pack_seq_different_block_sizes_fp8():
|
||||
"""Test pack_seq_triton with different block sizes for fp8."""
|
||||
device = "cuda"
|
||||
dtype = torch.float8_e4m3fn
|
||||
N, H, D, B = 100, 16, 32, 4
|
||||
lengths = torch.tensor([25, 25, 25, 25], device=device)
|
||||
|
||||
x = torch.randn(N, H, D, dtype=torch.float32, device=device) * 0.1
|
||||
x = x.to(dtype=dtype)
|
||||
|
||||
# Test different block sizes
|
||||
for block_t, block_d in [(32, 32), (64, 64), (128, 128)]:
|
||||
result = pack_seq_triton(x, lengths, block_t=block_t, block_d=block_d)
|
||||
|
||||
assert result.shape == (B, 25, H, D)
|
||||
|
||||
# Check that valid data is preserved (within fp8 precision)
|
||||
for b in range(B):
|
||||
start_idx = b * 25
|
||||
expected_data = x[start_idx : start_idx + 25].to(torch.float32)
|
||||
actual_data = result[b, :25].to(torch.float32)
|
||||
assert_close(actual_data, expected_data, rtol=1e-1, atol=1e-2)
|
||||
|
||||
|
||||
def test_pack_seq_shape_consistency():
|
||||
"""Test that pack_seq_triton maintains shape consistency."""
|
||||
device = "cuda"
|
||||
dtype = torch.float8_e4m3fn
|
||||
N, H, D, B = 20, 8, 16, 2
|
||||
lengths = torch.tensor([10, 10], device=device)
|
||||
|
||||
x = torch.randn(N, H, D, dtype=torch.float32, device=device) * 0.1
|
||||
x = x.to(dtype=dtype)
|
||||
|
||||
result = pack_seq_triton(x, lengths)
|
||||
|
||||
# Check shape consistency
|
||||
assert result.shape[0] == B # Batch dimension
|
||||
assert result.shape[1] == lengths.max().item() # Max sequence length
|
||||
assert result.shape[2:] == x.shape[1:] # Feature dimensions preserved
|
||||
|
||||
|
||||
def test_pack_unpack_roundtrip_fp8():
|
||||
"""Test that pack -> unpack gives us back the original data for fp8."""
|
||||
device = "cuda"
|
||||
dtype = torch.float8_e4m3fn
|
||||
|
||||
# Test cases with 3D tensors
|
||||
test_cases = [
|
||||
(6, 8, 4, 2, [3, 3]),
|
||||
(10, 4, 8, 3, [2, 4, 4]),
|
||||
(20, 16, 32, 4, [5, 5, 5, 5]),
|
||||
(15, 8, 16, 3, [7, 5, 3]),
|
||||
]
|
||||
|
||||
for N, H, D, B, lengths_list in test_cases:
|
||||
# Create input tensor with small values for fp8
|
||||
x = torch.randn(N, H, D, dtype=torch.float32, device=device) * 0.1
|
||||
x = x.to(dtype=dtype)
|
||||
lengths = torch.tensor(lengths_list, device=device)
|
||||
|
||||
# Pack the data
|
||||
packed = pack_seq_triton(x, lengths)
|
||||
|
||||
# Unpack the data
|
||||
unpacked = unpack_seq_triton(packed, lengths)
|
||||
|
||||
# Check that we get back the original data (within fp8 precision)
|
||||
assert unpacked.shape == x.shape
|
||||
x_f32 = x.to(torch.float32)
|
||||
unpacked_f32 = unpacked.to(torch.float32)
|
||||
assert_close(x_f32, unpacked_f32, rtol=1e-3, atol=1e-3)
|
||||
|
||||
# Unpack without explicit start locations (computed in kernel)
|
||||
unpacked_with_loc = unpack_seq_triton(packed, lengths)
|
||||
assert_close(x_f32, unpacked_with_loc.to(torch.float32), rtol=1e-3, atol=1e-2)
|
||||
|
||||
|
||||
def test_unpack_seq_triton_edge_cases_fp8():
|
||||
"""Test unpack function with edge cases for fp8."""
|
||||
device = "cuda"
|
||||
dtype = torch.float8_e4m3fn
|
||||
|
||||
# Test with single batch element
|
||||
x = torch.randn(10, 8, 16, dtype=torch.float32, device=device) * 0.1
|
||||
x = x.to(dtype=dtype)
|
||||
lengths = torch.tensor([10], device=device)
|
||||
packed = pack_seq_triton(x, lengths)
|
||||
unpacked = unpack_seq_triton(packed, lengths)
|
||||
assert unpacked.shape == x.shape
|
||||
assert_close(x.to(torch.float32), unpacked.to(torch.float32), rtol=1e-1, atol=1e-2)
|
||||
|
||||
# Test with very short sequences
|
||||
x = torch.randn(20, 4, 8, dtype=torch.float32, device=device) * 0.1
|
||||
x = x.to(dtype=dtype)
|
||||
lengths = torch.tensor([1, 1, 1], device=device)
|
||||
packed = pack_seq_triton(x, lengths)
|
||||
unpacked = unpack_seq_triton(packed, lengths)
|
||||
# Only compare the first 3 elements that were actually packed
|
||||
assert_close(
|
||||
x[:3].to(torch.float32), unpacked.to(torch.float32), rtol=1e-1, atol=1e-2
|
||||
)
|
||||
|
||||
x = torch.randn(15, 8, 16, dtype=torch.float32, device=device) * 0.1
|
||||
x = x.to(dtype=dtype)
|
||||
lengths = torch.tensor([5, 7, 3], device=device)
|
||||
packed = pack_seq_triton(x, lengths)
|
||||
unpacked = unpack_seq_triton(packed, lengths)
|
||||
assert unpacked.shape == x.shape
|
||||
assert_close(x.to(torch.float32), unpacked.to(torch.float32), rtol=1e-1, atol=1e-2)
|
||||
639
tests/kernels/attention/test_prefix_prefill.py
Normal file
639
tests/kernels/attention/test_prefix_prefill.py
Normal file
@@ -0,0 +1,639 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import math
|
||||
import random
|
||||
import time
|
||||
from collections.abc import Callable
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
from vllm.attention.ops.chunked_prefill_paged_decode import chunked_prefill_paged_decode
|
||||
from vllm.attention.ops.prefix_prefill import context_attention_fwd
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE
|
||||
|
||||
NUM_HEADS = [64]
|
||||
NUM_QUERIES_PER_KV = [1, 64]
|
||||
HEAD_SIZES = [24, 128]
|
||||
DTYPES = [torch.float16]
|
||||
CUDA_DEVICES = [f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)]
|
||||
SLIDING_WINDOW = [0, 16, 2048]
|
||||
KV_CACHE_DTYPES = ["auto", "fp8", "fp8_e5m2"]
|
||||
|
||||
OPS = [chunked_prefill_paged_decode, context_attention_fwd]
|
||||
|
||||
|
||||
def create_causal_attention_mask_for_sdpa(
|
||||
query_lens: list[int],
|
||||
seq_lens: list[int],
|
||||
sliding_window: int = 0,
|
||||
device: torch.device = None,
|
||||
dtype: torch.dtype = None,
|
||||
) -> torch.Tensor:
|
||||
total_queries = sum(query_lens)
|
||||
total_keys = sum(seq_lens)
|
||||
|
||||
# Create a mask filled with -inf
|
||||
mask = torch.full(
|
||||
(total_queries, total_keys), float("-inf"), device=device, dtype=dtype
|
||||
)
|
||||
|
||||
query_start = 0
|
||||
key_start = 0
|
||||
|
||||
for query_len, seq_len in zip(query_lens, seq_lens):
|
||||
query_end = query_start + query_len
|
||||
key_end = key_start + seq_len
|
||||
q_indices = torch.arange(query_len, device=device)
|
||||
k_indices = torch.arange(seq_len, device=device)
|
||||
q_pos_in_seq = seq_len - query_len + q_indices
|
||||
|
||||
valid_mask = k_indices[None, :] <= q_pos_in_seq[:, None]
|
||||
|
||||
if sliding_window > 0:
|
||||
valid_mask &= k_indices[None, :] >= (
|
||||
q_pos_in_seq[:, None] - sliding_window + 1
|
||||
)
|
||||
|
||||
mask[query_start:query_end, key_start:key_end][valid_mask] = 0.0
|
||||
|
||||
query_start = query_end
|
||||
key_start = key_end
|
||||
|
||||
return mask
|
||||
|
||||
|
||||
def create_alibi_causal_mask(
|
||||
query_len: int,
|
||||
seq_len: int,
|
||||
alibi_slopes: torch.Tensor,
|
||||
device: torch.device,
|
||||
dtype: torch.dtype,
|
||||
) -> torch.Tensor:
|
||||
query_pos = torch.arange(
|
||||
seq_len - query_len, seq_len, device=device, dtype=torch.float32
|
||||
)
|
||||
key_pos = torch.arange(seq_len, device=device, dtype=torch.float32)
|
||||
|
||||
rel_pos = key_pos[None, :] - query_pos[:, None]
|
||||
|
||||
# Apply ALiBi slopes: [num_heads, query_len, seq_len]
|
||||
alibi_bias = alibi_slopes[:, None, None] * rel_pos[None, :, :]
|
||||
alibi_bias = alibi_bias.to(dtype)
|
||||
|
||||
# Apply causal mask: prevent attending to future positions
|
||||
# causal_mask[i, j] = True if key_pos[j] <= query_pos[i]
|
||||
causal_mask = key_pos[None, :] <= query_pos[:, None]
|
||||
alibi_bias = alibi_bias.masked_fill(~causal_mask[None, :, :], float("-inf"))
|
||||
|
||||
# Add batch dimension: [1, num_heads, query_len, seq_len]
|
||||
# SDPA expects batch dimension even for single sequences
|
||||
return alibi_bias.unsqueeze(0)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("num_heads", NUM_HEADS)
|
||||
@pytest.mark.parametrize("num_queries_per_kv", NUM_QUERIES_PER_KV)
|
||||
@pytest.mark.parametrize("head_size", HEAD_SIZES)
|
||||
@pytest.mark.parametrize("dtype", DTYPES)
|
||||
@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPES)
|
||||
@pytest.mark.parametrize("device", CUDA_DEVICES)
|
||||
@pytest.mark.parametrize("sliding_window", SLIDING_WINDOW)
|
||||
@pytest.mark.parametrize("op", OPS)
|
||||
@torch.inference_mode()
|
||||
def test_contexted_kv_attention(
|
||||
num_heads: int,
|
||||
num_queries_per_kv: int,
|
||||
head_size: int,
|
||||
sliding_window: int,
|
||||
dtype: torch.dtype,
|
||||
kv_cache_dtype: str,
|
||||
device: str,
|
||||
op: Callable,
|
||||
) -> None:
|
||||
if "fp8" in kv_cache_dtype and not current_platform.has_device_capability(89):
|
||||
pytest.skip(
|
||||
"Triton limitation: fp8e4nv data type is not supported on CUDA arch < 89"
|
||||
)
|
||||
|
||||
if (
|
||||
current_platform.is_rocm()
|
||||
and op is chunked_prefill_paged_decode
|
||||
and kv_cache_dtype == "fp8_e5m2"
|
||||
):
|
||||
pytest.skip("ROCm custom paged attention does not support fp8_e5m2 KV cache")
|
||||
|
||||
current_platform.seed_everything(0)
|
||||
torch.set_default_device(device)
|
||||
|
||||
# Need this, otherwise when we capture the graph the process
|
||||
# for GPU 1 would run on both GPU0 and GPU1 and things would hang
|
||||
#
|
||||
# see also similar issue: https://github.com/Dao-AILab/flash-attention/issues/523
|
||||
torch.cuda.set_device(device)
|
||||
|
||||
MAX_SEQ_LEN = 1024
|
||||
MAX_CTX_LEN = 1024
|
||||
BS = 10
|
||||
cache_size = 640
|
||||
block_size = 32
|
||||
max_block_per_request = 64
|
||||
query_lens = [random.randint(16, MAX_SEQ_LEN) for _ in range(BS)]
|
||||
# ensure one sequence in batch is a decode
|
||||
query_lens[-1] = 1
|
||||
|
||||
ctx_lens = [random.randint(16, MAX_CTX_LEN) for _ in range(BS)]
|
||||
seq_lens = [a + b for a, b in zip(query_lens, ctx_lens)]
|
||||
num_kv_heads = num_heads // num_queries_per_kv
|
||||
|
||||
num_tokens = sum(query_lens)
|
||||
query = torch.empty(num_tokens, num_heads, head_size, dtype=dtype)
|
||||
query.uniform_(-1e-3, 1e-3)
|
||||
output = torch.empty(num_tokens, num_heads, head_size, dtype=dtype)
|
||||
|
||||
kv = torch.empty(sum(seq_lens), 2, num_kv_heads, head_size, dtype=dtype)
|
||||
kv.uniform_(-1e-3, 1e-3)
|
||||
key, value = kv.unbind(dim=1)
|
||||
|
||||
if kv_cache_dtype == "auto":
|
||||
cache_dtype = dtype
|
||||
else:
|
||||
cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[kv_cache_dtype]
|
||||
k_cache = torch.zeros(
|
||||
cache_size, block_size, num_kv_heads, head_size, dtype=cache_dtype
|
||||
)
|
||||
v_cache = torch.zeros(
|
||||
cache_size, block_size, num_kv_heads, head_size, dtype=cache_dtype
|
||||
)
|
||||
k = torch.zeros(sum(query_lens), num_kv_heads, head_size, dtype=dtype)
|
||||
v = torch.zeros(sum(query_lens), num_kv_heads, head_size, dtype=dtype)
|
||||
values = torch.arange(0, cache_size, dtype=torch.int32)
|
||||
values = values[torch.randperm(cache_size)]
|
||||
block_table = values[: BS * max_block_per_request].view(BS, max_block_per_request)
|
||||
b_seq_len = torch.tensor(seq_lens, dtype=torch.int32)
|
||||
b_ctx_len = torch.tensor(ctx_lens, dtype=torch.int32)
|
||||
b_start_loc = torch.cumsum(torch.tensor([0] + query_lens), dim=0).to(torch.int32)
|
||||
max_input_len = MAX_SEQ_LEN
|
||||
# copy kv to cache
|
||||
b_seq_start_loc = torch.cumsum(torch.tensor([0] + seq_lens[:-1]), dim=0).to(
|
||||
torch.int32
|
||||
)
|
||||
for i in range(BS):
|
||||
for j in range(query_lens[i]):
|
||||
k[b_start_loc[i] + j].copy_(key[b_seq_start_loc[i] + b_ctx_len[i] + j])
|
||||
v[b_start_loc[i] + j].copy_(value[b_seq_start_loc[i] + b_ctx_len[i] + j])
|
||||
cur_ctx = 0
|
||||
block_id = 0
|
||||
while cur_ctx < b_ctx_len[i]:
|
||||
start_loc = b_seq_start_loc[i] + cur_ctx
|
||||
if cur_ctx + block_size > b_ctx_len[i]:
|
||||
end_loc = b_seq_start_loc[i] + b_ctx_len[i]
|
||||
else:
|
||||
end_loc = start_loc + block_size
|
||||
start_slot = block_table[i, block_id] * block_size
|
||||
end_slot = start_slot + end_loc - start_loc
|
||||
k_cache.view(-1, num_kv_heads, head_size)[start_slot:end_slot].copy_(
|
||||
key[start_loc:end_loc]
|
||||
)
|
||||
v_cache.view(-1, num_kv_heads, head_size)[start_slot:end_slot].copy_(
|
||||
value[start_loc:end_loc]
|
||||
)
|
||||
cur_ctx += block_size
|
||||
block_id += 1
|
||||
# transpose K_cache[num_blocks, block_size, num_kv_heads, head_size]
|
||||
# to K_cache[num_blocks, num_kv_heads, head_size/8, block_size, 8]
|
||||
k_cache = (
|
||||
k_cache.view(-1, block_size, num_kv_heads, head_size // 8, 8)
|
||||
.permute(0, 2, 3, 1, 4)
|
||||
.contiguous()
|
||||
)
|
||||
# transpose V_cache[num_blocks, block_size, num_kv_heads, head_size]
|
||||
# to V_cache[num_blocks, num_kv_heads, head_size, block_size]
|
||||
v_cache = (
|
||||
v_cache.view(-1, block_size, num_kv_heads, head_size)
|
||||
.permute(0, 2, 3, 1)
|
||||
.contiguous()
|
||||
)
|
||||
k_scale = v_scale = torch.tensor(1.0, dtype=torch.float32, device=device)
|
||||
|
||||
# Warm up the Triton kernel by calling it once before actually measuring
|
||||
# generation time
|
||||
op(
|
||||
query,
|
||||
k,
|
||||
v,
|
||||
output,
|
||||
kv_cache_dtype,
|
||||
k_cache,
|
||||
v_cache,
|
||||
block_table,
|
||||
b_start_loc,
|
||||
b_seq_len,
|
||||
MAX_CTX_LEN,
|
||||
max_input_len,
|
||||
k_scale,
|
||||
v_scale,
|
||||
sliding_window=sliding_window,
|
||||
)
|
||||
torch.cuda.synchronize()
|
||||
start_time = time.time()
|
||||
op(
|
||||
query,
|
||||
k,
|
||||
v,
|
||||
output,
|
||||
kv_cache_dtype,
|
||||
k_cache,
|
||||
v_cache,
|
||||
block_table,
|
||||
b_start_loc,
|
||||
b_seq_len,
|
||||
MAX_CTX_LEN,
|
||||
max_input_len,
|
||||
k_scale,
|
||||
v_scale,
|
||||
sliding_window=sliding_window,
|
||||
)
|
||||
torch.cuda.synchronize()
|
||||
end_time = time.time()
|
||||
print(f"triton Time: {(end_time - start_time) * 1000:.2f} ms")
|
||||
|
||||
scale = float(1.0 / (head_size**0.5))
|
||||
|
||||
# Reshape for SDPA: (seq_len, num_heads, head_size) ->
|
||||
# (1, num_heads, seq_len, head_size)
|
||||
query_sdpa = query.view(num_tokens, num_kv_heads, num_queries_per_kv, head_size)
|
||||
query_sdpa = query_sdpa.permute(1, 2, 0, 3).reshape(
|
||||
1, num_heads, num_tokens, head_size
|
||||
)
|
||||
|
||||
# Expand key and value for GQA/MQA to match query heads
|
||||
key_sdpa = key[:, :, None, :].expand(
|
||||
key.shape[0], num_kv_heads, num_queries_per_kv, key.shape[-1]
|
||||
)
|
||||
key_sdpa = key_sdpa.permute(1, 2, 0, 3).reshape(
|
||||
1, num_heads, sum(seq_lens), head_size
|
||||
)
|
||||
|
||||
value_sdpa = value[:, :, None, :].expand(
|
||||
value.shape[0], num_kv_heads, num_queries_per_kv, value.shape[-1]
|
||||
)
|
||||
value_sdpa = value_sdpa.permute(1, 2, 0, 3).reshape(
|
||||
1, num_heads, sum(seq_lens), head_size
|
||||
)
|
||||
|
||||
attn_mask = create_causal_attention_mask_for_sdpa(
|
||||
query_lens, seq_lens, sliding_window, device=device, dtype=dtype
|
||||
)
|
||||
|
||||
output_ref = F.scaled_dot_product_attention(
|
||||
query_sdpa,
|
||||
key_sdpa,
|
||||
value_sdpa,
|
||||
attn_mask=attn_mask,
|
||||
dropout_p=0.0,
|
||||
scale=scale,
|
||||
)
|
||||
torch.cuda.synchronize()
|
||||
start_time = time.time()
|
||||
output_ref = F.scaled_dot_product_attention(
|
||||
query_sdpa,
|
||||
key_sdpa,
|
||||
value_sdpa,
|
||||
attn_mask=attn_mask,
|
||||
dropout_p=0.0,
|
||||
scale=scale,
|
||||
)
|
||||
torch.cuda.synchronize()
|
||||
end_time = time.time()
|
||||
print(f"PyTorch SDPA Time: {(end_time - start_time) * 1000:.2f} ms")
|
||||
|
||||
# Reshape output back to (num_tokens, num_heads, head_size)
|
||||
output_ref = output_ref.view(num_heads, num_tokens, head_size)
|
||||
output_ref = output_ref.permute(1, 0, 2).contiguous()
|
||||
atol = 1e-3 if "fp8" in kv_cache_dtype else 1e-4
|
||||
torch.testing.assert_close(output, output_ref, atol=atol, rtol=0)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("num_heads", NUM_HEADS)
|
||||
@pytest.mark.parametrize("num_queries_per_kv", NUM_QUERIES_PER_KV)
|
||||
@pytest.mark.parametrize("head_size", HEAD_SIZES)
|
||||
@pytest.mark.parametrize("dtype", DTYPES)
|
||||
@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPES)
|
||||
@pytest.mark.parametrize("device", CUDA_DEVICES)
|
||||
@pytest.mark.parametrize("op", OPS)
|
||||
@torch.inference_mode()
|
||||
def test_contexted_kv_attention_alibi(
|
||||
num_heads: int,
|
||||
num_queries_per_kv: int,
|
||||
head_size: int,
|
||||
dtype: torch.dtype,
|
||||
kv_cache_dtype: str,
|
||||
device: str,
|
||||
op: Callable,
|
||||
) -> None:
|
||||
if "fp8" in kv_cache_dtype and not current_platform.has_device_capability(89):
|
||||
pytest.skip(
|
||||
"Triton limitation: fp8e4nv data type is not supported on CUDA arch < 89"
|
||||
)
|
||||
|
||||
if (
|
||||
current_platform.is_rocm()
|
||||
and op is chunked_prefill_paged_decode
|
||||
and kv_cache_dtype == "fp8_e5m2"
|
||||
):
|
||||
pytest.skip("ROCm custom paged attention does not support fp8_e5m2 KV cache")
|
||||
|
||||
current_platform.seed_everything(0)
|
||||
torch.set_default_device(device)
|
||||
|
||||
# Need this, otherwise when we capture the graph the process
|
||||
# for GPU 1 would run on both GPU0 and GPU1 and things would hang
|
||||
#
|
||||
# see also similar issue: https://github.com/Dao-AILab/flash-attention/issues/523
|
||||
torch.cuda.set_device(device)
|
||||
|
||||
def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor:
|
||||
# Fork from: vllm/vllm/model_executor/models/bloom.py#L44
|
||||
closest_power_of_2 = 2 ** math.floor(math.log2(total_num_heads))
|
||||
base = torch.tensor(
|
||||
2 ** (-(2 ** -(math.log2(closest_power_of_2) - 3))),
|
||||
dtype=torch.float32,
|
||||
)
|
||||
powers = torch.arange(1, 1 + closest_power_of_2, dtype=torch.int32)
|
||||
slopes = torch.pow(base, powers)
|
||||
|
||||
if closest_power_of_2 != total_num_heads:
|
||||
extra_base = torch.tensor(
|
||||
2 ** (-(2 ** -(math.log2(2 * closest_power_of_2) - 3))),
|
||||
dtype=torch.float32,
|
||||
)
|
||||
num_remaining_heads = min(
|
||||
closest_power_of_2, total_num_heads - closest_power_of_2
|
||||
)
|
||||
extra_powers = torch.arange(
|
||||
start=1, end=1 + 2 * num_remaining_heads, step=2, dtype=torch.int32
|
||||
)
|
||||
slopes = torch.cat([slopes, torch.pow(extra_base, extra_powers)], dim=0)
|
||||
return slopes
|
||||
|
||||
alibi_slopes = _get_alibi_slopes(num_heads).to(device)
|
||||
|
||||
MAX_SEQ_LEN = 1024
|
||||
MAX_CTX_LEN = 1024
|
||||
BS = 10
|
||||
cache_size = 640
|
||||
block_size = 32
|
||||
max_block_per_request = 64
|
||||
query_lens = [random.randint(16, MAX_SEQ_LEN) for _ in range(BS)]
|
||||
ctx_lens = [random.randint(16, MAX_CTX_LEN) for _ in range(BS)]
|
||||
seq_lens = [a + b for a, b in zip(query_lens, ctx_lens)]
|
||||
num_kv_heads = num_heads // num_queries_per_kv
|
||||
|
||||
num_tokens = sum(query_lens)
|
||||
query = torch.empty(num_tokens, num_heads, head_size, dtype=dtype)
|
||||
query.uniform_(-1e-3, 1e-3)
|
||||
output = torch.empty(num_tokens, num_heads, head_size, dtype=dtype)
|
||||
|
||||
kv = torch.empty(sum(seq_lens), 2, num_kv_heads, head_size, dtype=dtype)
|
||||
kv.uniform_(-1e-3, 1e-3)
|
||||
key, value = kv.unbind(dim=1)
|
||||
if kv_cache_dtype == "auto":
|
||||
cache_dtype = dtype
|
||||
else:
|
||||
cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[kv_cache_dtype]
|
||||
k_cache = torch.zeros(
|
||||
cache_size, block_size, num_kv_heads, head_size, dtype=cache_dtype
|
||||
)
|
||||
v_cache = torch.zeros(
|
||||
cache_size, block_size, num_kv_heads, head_size, dtype=cache_dtype
|
||||
)
|
||||
k = torch.zeros(sum(query_lens), num_kv_heads, head_size, dtype=dtype)
|
||||
v = torch.zeros(sum(query_lens), num_kv_heads, head_size, dtype=dtype)
|
||||
values = torch.arange(0, cache_size, dtype=torch.int32)
|
||||
values = values[torch.randperm(cache_size)]
|
||||
block_table = values[: BS * max_block_per_request].view(BS, max_block_per_request)
|
||||
b_seq_len = torch.tensor(seq_lens, dtype=torch.int32)
|
||||
b_ctx_len = torch.tensor(ctx_lens, dtype=torch.int32)
|
||||
b_start_loc = torch.cumsum(torch.tensor([0] + query_lens), dim=0).to(torch.int32)
|
||||
max_input_len = MAX_SEQ_LEN
|
||||
# copy kv to cache
|
||||
b_seq_start_loc = torch.cumsum(torch.tensor([0] + seq_lens[:-1]), dim=0).to(
|
||||
torch.int32
|
||||
)
|
||||
for i in range(BS):
|
||||
for j in range(query_lens[i]):
|
||||
k[b_start_loc[i] + j].copy_(key[b_seq_start_loc[i] + b_ctx_len[i] + j])
|
||||
v[b_start_loc[i] + j].copy_(value[b_seq_start_loc[i] + b_ctx_len[i] + j])
|
||||
cur_ctx = 0
|
||||
block_id = 0
|
||||
while cur_ctx < b_ctx_len[i]:
|
||||
start_loc = b_seq_start_loc[i] + cur_ctx
|
||||
if cur_ctx + block_size > b_ctx_len[i]:
|
||||
end_loc = b_seq_start_loc[i] + b_ctx_len[i]
|
||||
else:
|
||||
end_loc = start_loc + block_size
|
||||
start_slot = block_table[i, block_id] * block_size
|
||||
end_slot = start_slot + end_loc - start_loc
|
||||
k_cache.view(-1, num_kv_heads, head_size)[start_slot:end_slot].copy_(
|
||||
key[start_loc:end_loc]
|
||||
)
|
||||
v_cache.view(-1, num_kv_heads, head_size)[start_slot:end_slot].copy_(
|
||||
value[start_loc:end_loc]
|
||||
)
|
||||
cur_ctx += block_size
|
||||
block_id += 1
|
||||
# transpose K_cache[num_blocks, block_size, num_kv_heads, head_size]
|
||||
# to K_cache[num_blocks, num_kv_heads, head_size/8, block_size, 8]
|
||||
k_cache = (
|
||||
k_cache.view(-1, block_size, num_kv_heads, head_size // 8, 8)
|
||||
.permute(0, 2, 3, 1, 4)
|
||||
.contiguous()
|
||||
)
|
||||
# transpose V_cache[num_blocks, block_size, num_kv_heads, head_size]
|
||||
# to V_cache[num_blocks, num_kv_heads, head_size, block_size]
|
||||
v_cache = (
|
||||
v_cache.view(-1, block_size, num_kv_heads, head_size)
|
||||
.permute(0, 2, 3, 1)
|
||||
.contiguous()
|
||||
)
|
||||
k_scale = v_scale = torch.tensor(1.0, dtype=torch.float32, device=device)
|
||||
|
||||
# Warm up the Triton kernel by calling it once before actually measuring
|
||||
# generation time
|
||||
op(
|
||||
query,
|
||||
k,
|
||||
v,
|
||||
output,
|
||||
kv_cache_dtype,
|
||||
k_cache,
|
||||
v_cache,
|
||||
block_table,
|
||||
b_start_loc,
|
||||
b_seq_len,
|
||||
MAX_CTX_LEN,
|
||||
max_input_len,
|
||||
k_scale,
|
||||
v_scale,
|
||||
alibi_slopes=alibi_slopes,
|
||||
)
|
||||
torch.cuda.synchronize()
|
||||
start_time = time.time()
|
||||
op(
|
||||
query,
|
||||
k,
|
||||
v,
|
||||
output,
|
||||
kv_cache_dtype,
|
||||
k_cache,
|
||||
v_cache,
|
||||
block_table,
|
||||
b_start_loc,
|
||||
b_seq_len,
|
||||
MAX_CTX_LEN,
|
||||
max_input_len,
|
||||
k_scale,
|
||||
v_scale,
|
||||
alibi_slopes=alibi_slopes,
|
||||
)
|
||||
torch.cuda.synchronize()
|
||||
end_time = time.time()
|
||||
print(f"triton Time: {(end_time - start_time) * 1000:.2f} ms")
|
||||
scale = float(1.0 / (head_size**0.5))
|
||||
|
||||
# Prepare query, key, value for SDPA
|
||||
# Expand key and value for GQA/MQA to match query heads
|
||||
key_expanded = key[:, :, None, :].expand(
|
||||
key.shape[0], num_kv_heads, num_queries_per_kv, key.shape[-1]
|
||||
)
|
||||
value_expanded = value[:, :, None, :].expand(
|
||||
value.shape[0], num_kv_heads, num_queries_per_kv, value.shape[-1]
|
||||
)
|
||||
|
||||
output_ref = torch.empty_like(output)
|
||||
|
||||
torch.cuda.synchronize()
|
||||
start_time = time.time()
|
||||
|
||||
query_start = 0
|
||||
key_start = 0
|
||||
for i, (query_len, seq_len) in enumerate(zip(query_lens, seq_lens)):
|
||||
query_end = query_start + query_len
|
||||
key_end = key_start + seq_len
|
||||
|
||||
# Get query, key, value for this sequence
|
||||
q = query[query_start:query_end] # [query_len, num_heads, head_size]
|
||||
k = key_expanded[
|
||||
key_start:key_end
|
||||
] # [seq_len, num_kv_heads, num_queries_per_kv, head_size]
|
||||
v = value_expanded[
|
||||
key_start:key_end
|
||||
] # [seq_len, num_kv_heads, num_queries_per_kv, head_size]
|
||||
|
||||
# Reshape for SDPA: (batch=1, num_heads, seq_len, head_size)
|
||||
q_sdpa = q.view(query_len, num_kv_heads, num_queries_per_kv, head_size)
|
||||
q_sdpa = (
|
||||
q_sdpa.permute(1, 2, 0, 3)
|
||||
.reshape(1, num_heads, query_len, head_size)
|
||||
.contiguous()
|
||||
)
|
||||
|
||||
k_sdpa = (
|
||||
k.permute(1, 2, 0, 3).reshape(1, num_heads, seq_len, head_size).contiguous()
|
||||
)
|
||||
v_sdpa = (
|
||||
v.permute(1, 2, 0, 3).reshape(1, num_heads, seq_len, head_size).contiguous()
|
||||
)
|
||||
|
||||
# Create ALiBi causal mask for this sequence using utility function
|
||||
alibi_mask = create_alibi_causal_mask(
|
||||
query_len, seq_len, alibi_slopes, device, dtype
|
||||
)
|
||||
|
||||
# Compute attention
|
||||
out = F.scaled_dot_product_attention(
|
||||
q_sdpa,
|
||||
k_sdpa,
|
||||
v_sdpa,
|
||||
attn_mask=alibi_mask,
|
||||
dropout_p=0.0,
|
||||
scale=scale,
|
||||
)
|
||||
|
||||
# Reshape output back to [query_len, num_heads, head_size]
|
||||
out = out.view(num_heads, query_len, head_size).permute(1, 0, 2)
|
||||
output_ref[query_start:query_end].copy_(out)
|
||||
|
||||
query_start = query_end
|
||||
key_start = key_end
|
||||
|
||||
torch.cuda.synchronize()
|
||||
end_time = time.time()
|
||||
print(f"PyTorch SDPA Time: {(end_time - start_time) * 1000:.2f} ms")
|
||||
atol = 1e-3 if "fp8" in kv_cache_dtype else 1e-6
|
||||
torch.testing.assert_close(output, output_ref, atol=atol, rtol=0)
|
||||
|
||||
|
||||
# These tests are optional to only run when explicitly invoked
|
||||
#
|
||||
# pytest -v -s --optional \
|
||||
# tests/kernels/test_prefix_prefill.py::test_contexted_kv_attention_f32
|
||||
#
|
||||
# These tests are useful to test model dtype float32 on Turing devices.
|
||||
# We skip them to not increase the time when running tests on CI
|
||||
@pytest.mark.optional
|
||||
@pytest.mark.parametrize("num_heads", NUM_HEADS)
|
||||
@pytest.mark.parametrize("num_queries_per_kv", NUM_QUERIES_PER_KV)
|
||||
@pytest.mark.parametrize("head_size", HEAD_SIZES)
|
||||
@pytest.mark.parametrize("dtype", [torch.float32])
|
||||
@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPES)
|
||||
@pytest.mark.parametrize("device", CUDA_DEVICES)
|
||||
@pytest.mark.parametrize("sliding_window", SLIDING_WINDOW)
|
||||
@pytest.mark.parametrize("op", OPS)
|
||||
@torch.inference_mode()
|
||||
def test_contexted_kv_attention_f32(
|
||||
num_heads: int,
|
||||
num_queries_per_kv: int,
|
||||
head_size: int,
|
||||
sliding_window: int,
|
||||
dtype: torch.dtype,
|
||||
kv_cache_dtype: str,
|
||||
device: str,
|
||||
op: Callable,
|
||||
) -> None:
|
||||
test_contexted_kv_attention(
|
||||
num_heads,
|
||||
num_queries_per_kv,
|
||||
head_size,
|
||||
sliding_window,
|
||||
dtype,
|
||||
kv_cache_dtype,
|
||||
device,
|
||||
op,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.optional
|
||||
@pytest.mark.parametrize("num_heads", NUM_HEADS)
|
||||
@pytest.mark.parametrize("num_queries_per_kv", NUM_QUERIES_PER_KV)
|
||||
@pytest.mark.parametrize("head_size", HEAD_SIZES)
|
||||
@pytest.mark.parametrize("dtype", [torch.float32])
|
||||
@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPES)
|
||||
@pytest.mark.parametrize("device", CUDA_DEVICES)
|
||||
@pytest.mark.parametrize("op", OPS)
|
||||
@torch.inference_mode()
|
||||
def test_contexted_kv_attention_alibi_f32(
|
||||
num_heads: int,
|
||||
num_queries_per_kv: int,
|
||||
head_size: int,
|
||||
dtype: torch.dtype,
|
||||
kv_cache_dtype: str,
|
||||
device: str,
|
||||
op: Callable,
|
||||
) -> None:
|
||||
test_contexted_kv_attention_alibi(
|
||||
num_heads, num_queries_per_kv, head_size, dtype, kv_cache_dtype, device, op
|
||||
)
|
||||
55
tests/kernels/attention/test_rocm_attention_selector.py
Normal file
55
tests/kernels/attention/test_rocm_attention_selector.py
Normal file
@@ -0,0 +1,55 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from vllm.attention.selector import _cached_get_attn_backend, get_attn_backend
|
||||
from vllm.platforms.rocm import RocmPlatform
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def clear_cache():
|
||||
"""Clear lru cache to ensure each test case runs without caching."""
|
||||
_cached_get_attn_backend.cache_clear()
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="Skipped for now. Should be revisited.")
|
||||
def test_selector(monkeypatch: pytest.MonkeyPatch):
|
||||
with monkeypatch.context() as m:
|
||||
m.setenv("VLLM_ATTENTION_BACKEND", "ROCM_ATTN")
|
||||
|
||||
# Set the current platform to ROCm using monkeypatch
|
||||
monkeypatch.setattr("vllm.attention.selector.current_platform", RocmPlatform())
|
||||
|
||||
# Test standard ROCm attention
|
||||
backend = get_attn_backend(16, torch.float16, torch.float16, 16, False)
|
||||
assert backend.get_name() == "ROCM_FLASH" or backend.get_name() == "TRITON_ATTN"
|
||||
|
||||
# MLA test for deepseek related
|
||||
|
||||
# change the attention backend to triton MLA
|
||||
m.setenv("VLLM_ATTENTION_BACKEND", "TRITON_MLA")
|
||||
backend = get_attn_backend(576, torch.bfloat16, "auto", 16, False, use_mla=True)
|
||||
assert backend.get_name() == "TRITON_MLA"
|
||||
|
||||
# If attention backend is None
|
||||
# If use_mla is true
|
||||
# The selected backend is triton MLA
|
||||
m.setenv("VLLM_ATTENTION_BACKEND", "")
|
||||
backend = get_attn_backend(576, torch.bfloat16, "auto", 16, False, use_mla=True)
|
||||
assert backend.get_name() == "TRITON_MLA"
|
||||
|
||||
# change the attention backend to AITER MLA
|
||||
m.setenv("VLLM_ATTENTION_BACKEND", "ROCM_AITER_MLA")
|
||||
backend = get_attn_backend(576, torch.bfloat16, "auto", 1, False, use_mla=True)
|
||||
assert backend.get_name() == "ROCM_AITER_MLA"
|
||||
|
||||
# If attention backend is None
|
||||
# If use_mla is true
|
||||
# If VLLM_ROCM_USE_AITER is enabled
|
||||
# The selected backend is ROCM_AITER_MLA
|
||||
m.setenv("VLLM_ATTENTION_BACKEND", "")
|
||||
m.setenv("VLLM_ROCM_USE_AITER", "1")
|
||||
backend = get_attn_backend(576, torch.bfloat16, "auto", 1, False, use_mla=True)
|
||||
assert backend.get_name() == "ROCM_AITER_MLA"
|
||||
92
tests/kernels/attention/test_triton_decode_attention.py
Normal file
92
tests/kernels/attention/test_triton_decode_attention.py
Normal file
@@ -0,0 +1,92 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from vllm.attention.ops.triton_decode_attention import decode_attention_fwd
|
||||
from vllm.utils.math_utils import cdiv
|
||||
|
||||
|
||||
@pytest.mark.parametrize("B", [3, 5])
|
||||
@pytest.mark.parametrize("L", [1027, 1025])
|
||||
@pytest.mark.parametrize("H_Q", [32])
|
||||
@pytest.mark.parametrize("H_KV", [32, 8])
|
||||
@pytest.mark.parametrize("D_QK", [128, 192, 576])
|
||||
@pytest.mark.parametrize("D_V", [128, 512])
|
||||
@pytest.mark.parametrize("CACHE_SIZE", [16384])
|
||||
@pytest.mark.parametrize("PAGE_SIZE", [1, 16])
|
||||
def test_decode_attention(B, L, H_Q, H_KV, D_QK, D_V, CACHE_SIZE, PAGE_SIZE):
|
||||
assert CACHE_SIZE % PAGE_SIZE == 0
|
||||
dtype = torch.bfloat16
|
||||
seq_len = L # This represents the number of tokens already in the sequence
|
||||
sm_scale = 1.0 / (D_QK**0.5)
|
||||
num_kv_splits = 8
|
||||
|
||||
num_pages_per_batch = cdiv(seq_len, PAGE_SIZE)
|
||||
req_to_page = torch.randint(
|
||||
0, CACHE_SIZE // PAGE_SIZE, (B, num_pages_per_batch, 1), device="cuda"
|
||||
)
|
||||
req_to_token = req_to_page * PAGE_SIZE
|
||||
req_to_token = req_to_token.expand(B, num_pages_per_batch, PAGE_SIZE)
|
||||
req_to_token = req_to_token + torch.arange(PAGE_SIZE, device="cuda").view(1, 1, -1)
|
||||
req_to_token = req_to_token.view(B, -1)
|
||||
req_to_token = req_to_token[:, :seq_len].contiguous()
|
||||
|
||||
# q represents the new token being generated, one per batch
|
||||
q = torch.randn(B, H_Q, D_QK, dtype=dtype, device="cuda")
|
||||
|
||||
# k_buffer and v_buffer represent all previous tokens
|
||||
# Page size is 1.
|
||||
k_buffer = torch.randn(CACHE_SIZE, H_KV, D_QK, dtype=dtype, device="cuda")
|
||||
v_buffer = torch.randn(CACHE_SIZE, H_KV, D_V, dtype=dtype, device="cuda")
|
||||
|
||||
# o will have the same shape as q
|
||||
o = torch.zeros(B, H_Q, D_V, dtype=dtype, device="cuda")
|
||||
|
||||
lse = torch.zeros(B, H_Q, dtype=dtype, device="cuda")
|
||||
|
||||
b_seq_len = torch.full((B,), seq_len, device="cuda")
|
||||
|
||||
attn_logits = torch.empty(
|
||||
(B, H_Q, num_kv_splits, D_V + 1),
|
||||
dtype=torch.float32,
|
||||
device="cuda",
|
||||
)
|
||||
|
||||
# Call the original implementation.
|
||||
decode_attention_fwd(
|
||||
q,
|
||||
k_buffer,
|
||||
v_buffer,
|
||||
o,
|
||||
lse,
|
||||
req_to_token,
|
||||
b_seq_len,
|
||||
attn_logits,
|
||||
num_kv_splits,
|
||||
sm_scale,
|
||||
)
|
||||
|
||||
# Page size can be larger than 1.
|
||||
k_buffer = k_buffer.view(CACHE_SIZE // PAGE_SIZE, PAGE_SIZE, H_KV, D_QK)
|
||||
v_buffer = v_buffer.view(CACHE_SIZE // PAGE_SIZE, PAGE_SIZE, H_KV, D_V)
|
||||
|
||||
o1 = torch.zeros_like(o)
|
||||
lse1 = torch.zeros_like(lse)
|
||||
|
||||
decode_attention_fwd(
|
||||
q,
|
||||
k_buffer,
|
||||
v_buffer,
|
||||
o1,
|
||||
lse1,
|
||||
req_to_page,
|
||||
b_seq_len,
|
||||
attn_logits,
|
||||
num_kv_splits,
|
||||
sm_scale,
|
||||
PAGE_SIZE,
|
||||
)
|
||||
|
||||
assert torch.allclose(o, o1)
|
||||
218
tests/kernels/attention/test_triton_unified_attention.py
Normal file
218
tests/kernels/attention/test_triton_unified_attention.py
Normal file
@@ -0,0 +1,218 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from vllm.attention.ops.triton_unified_attention import unified_attention
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.math_utils import next_power_of_2
|
||||
|
||||
NUM_HEADS = [(4, 4), (8, 2)]
|
||||
HEAD_SIZES = [128, 256]
|
||||
BLOCK_SIZES = [16]
|
||||
|
||||
DTYPES = [torch.bfloat16]
|
||||
QDTYPES = (
|
||||
[None, torch.float8_e4m3fn]
|
||||
if not current_platform.is_rocm()
|
||||
else [None, torch.float8_e4m3fnuz]
|
||||
)
|
||||
# one value large enough to test overflow in index calculation.
|
||||
# one value small enough to test the schema op check
|
||||
NUM_BLOCKS = [32768, 2048]
|
||||
|
||||
# 0: use 2D kernel for decode
|
||||
# 8: use 3D kernel for decode
|
||||
SEQ_THRESHOLD_3D_VALUES = [0, 8]
|
||||
|
||||
|
||||
def ref_paged_attn(
|
||||
query: torch.Tensor,
|
||||
key_cache: torch.Tensor,
|
||||
value_cache: torch.Tensor,
|
||||
query_lens: list[int],
|
||||
kv_lens: list[int],
|
||||
block_tables: torch.Tensor,
|
||||
scale: float,
|
||||
sliding_window: int | None = None,
|
||||
soft_cap: float | None = None,
|
||||
) -> torch.Tensor:
|
||||
num_seqs = len(query_lens)
|
||||
block_tables = block_tables.cpu().numpy()
|
||||
_, block_size, num_kv_heads, head_size = key_cache.shape
|
||||
|
||||
outputs: list[torch.Tensor] = []
|
||||
start_idx = 0
|
||||
for i in range(num_seqs):
|
||||
query_len = query_lens[i]
|
||||
kv_len = kv_lens[i]
|
||||
q = query[start_idx : start_idx + query_len]
|
||||
q *= scale
|
||||
|
||||
num_kv_blocks = (kv_len + block_size - 1) // block_size
|
||||
block_indices = block_tables[i, :num_kv_blocks]
|
||||
|
||||
k = key_cache[block_indices].view(-1, num_kv_heads, head_size)
|
||||
k = k[:kv_len]
|
||||
v = value_cache[block_indices].view(-1, num_kv_heads, head_size)
|
||||
v = v[:kv_len]
|
||||
|
||||
if q.shape[1] != k.shape[1]:
|
||||
k = torch.repeat_interleave(k, q.shape[1] // k.shape[1], dim=1)
|
||||
v = torch.repeat_interleave(v, q.shape[1] // v.shape[1], dim=1)
|
||||
attn = torch.einsum("qhd,khd->hqk", q, k).float()
|
||||
empty_mask = torch.ones(query_len, kv_len)
|
||||
mask = torch.triu(empty_mask, diagonal=kv_len - query_len + 1).bool()
|
||||
if sliding_window is not None:
|
||||
sliding_window_mask = (
|
||||
torch.triu(
|
||||
empty_mask, diagonal=kv_len - (query_len + sliding_window) + 1
|
||||
)
|
||||
.bool()
|
||||
.logical_not()
|
||||
)
|
||||
mask |= sliding_window_mask
|
||||
if soft_cap is not None and soft_cap > 0:
|
||||
attn = soft_cap * torch.tanh(attn / soft_cap)
|
||||
attn.masked_fill_(mask, float("-inf"))
|
||||
attn = torch.softmax(attn, dim=-1).to(v.dtype)
|
||||
out = torch.einsum("hqk,khd->qhd", attn, v)
|
||||
|
||||
outputs.append(out)
|
||||
start_idx += query_len
|
||||
|
||||
return torch.cat(outputs, dim=0)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"seq_lens", [[(1, 1328), (5, 18), (129, 463)], [(1, 523), (1, 37), (1, 2011)]]
|
||||
)
|
||||
@pytest.mark.parametrize("num_heads", NUM_HEADS)
|
||||
@pytest.mark.parametrize("head_size", HEAD_SIZES)
|
||||
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
|
||||
@pytest.mark.parametrize("sliding_window", [None, 64, 128, 256])
|
||||
@pytest.mark.parametrize("dtype", DTYPES)
|
||||
@pytest.mark.parametrize("soft_cap", [None, 50.0])
|
||||
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
|
||||
@pytest.mark.parametrize("q_dtype", QDTYPES)
|
||||
@pytest.mark.parametrize("seq_threshold_3D", SEQ_THRESHOLD_3D_VALUES)
|
||||
@torch.inference_mode()
|
||||
def test_triton_unified_attn(
|
||||
seq_lens: list[tuple[int, int]],
|
||||
num_heads: tuple[int, int],
|
||||
head_size: int,
|
||||
sliding_window: int | None,
|
||||
dtype: torch.dtype,
|
||||
block_size: int,
|
||||
soft_cap: float | None,
|
||||
num_blocks: int,
|
||||
q_dtype: torch.dtype | None,
|
||||
seq_threshold_3D: int,
|
||||
) -> None:
|
||||
torch.set_default_device("cuda")
|
||||
|
||||
current_platform.seed_everything(0)
|
||||
num_seqs = len(seq_lens)
|
||||
query_lens = [x[0] for x in seq_lens]
|
||||
kv_lens = [x[1] for x in seq_lens]
|
||||
num_query_heads = num_heads[0]
|
||||
num_kv_heads = num_heads[1]
|
||||
assert num_query_heads % num_kv_heads == 0
|
||||
max_query_len = max(query_lens)
|
||||
max_kv_len = max(kv_lens)
|
||||
window_size = (sliding_window - 1, 0) if sliding_window is not None else (-1, -1)
|
||||
scale = head_size**-0.5
|
||||
|
||||
query = torch.randn(sum(query_lens), num_query_heads, head_size, dtype=dtype)
|
||||
key_cache = torch.randn(
|
||||
num_blocks, block_size, num_kv_heads, head_size, dtype=dtype
|
||||
)
|
||||
value_cache = torch.randn_like(key_cache)
|
||||
cu_query_lens = torch.tensor([0] + query_lens, dtype=torch.int32).cumsum(
|
||||
dim=0, dtype=torch.int32
|
||||
)
|
||||
kv_lens = torch.tensor(kv_lens, dtype=torch.int32)
|
||||
|
||||
max_num_blocks_per_seq = (max_kv_len + block_size - 1) // block_size
|
||||
block_tables = torch.randint(
|
||||
0, num_blocks, (num_seqs, max_num_blocks_per_seq), dtype=torch.int32
|
||||
)
|
||||
|
||||
output = torch.empty_like(query)
|
||||
|
||||
maybe_quantized_query = query
|
||||
maybe_quantized_key_cache = key_cache
|
||||
maybe_quantized_value_cache = value_cache
|
||||
q_descale = None
|
||||
k_descale = None
|
||||
v_descale = None
|
||||
if q_dtype is not None:
|
||||
# QKV are drawn from N(0, 1): no need for a fp8 scaling factor
|
||||
maybe_quantized_query = query.to(q_dtype)
|
||||
maybe_quantized_key_cache = key_cache.to(q_dtype)
|
||||
maybe_quantized_value_cache = value_cache.to(q_dtype)
|
||||
|
||||
scale_shape = (num_seqs, num_kv_heads)
|
||||
q_descale = None # Not yet supported
|
||||
k_descale = torch.rand(scale_shape, dtype=torch.float32)
|
||||
v_descale = torch.rand(scale_shape, dtype=torch.float32)
|
||||
|
||||
num_par_softmax_segments = 16
|
||||
head_size_padded = next_power_of_2(head_size)
|
||||
softmax_segm_output = torch.empty(
|
||||
(seq_threshold_3D, num_query_heads, num_par_softmax_segments, head_size_padded),
|
||||
dtype=torch.float32,
|
||||
)
|
||||
softmax_segm_max = torch.empty(
|
||||
(seq_threshold_3D, num_query_heads, num_par_softmax_segments),
|
||||
dtype=torch.float32,
|
||||
)
|
||||
softmax_segm_expsum = torch.empty(
|
||||
(seq_threshold_3D, num_query_heads, num_par_softmax_segments),
|
||||
dtype=torch.float32,
|
||||
)
|
||||
|
||||
unified_attention(
|
||||
q=maybe_quantized_query,
|
||||
k=maybe_quantized_key_cache,
|
||||
v=maybe_quantized_value_cache,
|
||||
out=output,
|
||||
cu_seqlens_q=cu_query_lens,
|
||||
seqused_k=kv_lens,
|
||||
max_seqlen_q=max_query_len,
|
||||
max_seqlen_k=max_kv_len,
|
||||
softmax_scale=scale,
|
||||
causal=True,
|
||||
window_size=window_size,
|
||||
block_table=block_tables,
|
||||
softcap=soft_cap if soft_cap is not None else 0,
|
||||
q_descale=q_descale,
|
||||
k_descale=k_descale,
|
||||
v_descale=v_descale,
|
||||
seq_threshold_3D=seq_threshold_3D,
|
||||
num_par_softmax_segments=num_par_softmax_segments,
|
||||
softmax_segm_output=softmax_segm_output,
|
||||
softmax_segm_max=softmax_segm_max,
|
||||
softmax_segm_expsum=softmax_segm_expsum,
|
||||
)
|
||||
|
||||
ref_output = ref_paged_attn(
|
||||
query=query,
|
||||
key_cache=key_cache,
|
||||
value_cache=value_cache,
|
||||
query_lens=query_lens,
|
||||
kv_lens=kv_lens,
|
||||
block_tables=block_tables,
|
||||
scale=scale,
|
||||
sliding_window=sliding_window,
|
||||
soft_cap=soft_cap,
|
||||
)
|
||||
atol, rtol = 1.5e-2, 1e-2
|
||||
if q_dtype is not None:
|
||||
atol, rtol = 1.5e-1, 1.5e-1
|
||||
(
|
||||
torch.testing.assert_close(output, ref_output, atol=atol, rtol=rtol),
|
||||
f"{torch.max(torch.abs(output - ref_output))}",
|
||||
)
|
||||
@@ -1,14 +0,0 @@
|
||||
import pytest
|
||||
|
||||
from vllm.utils import (create_kv_caches_with_random,
|
||||
create_kv_caches_with_random_flash)
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def kv_cache_factory():
|
||||
return create_kv_caches_with_random
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def kv_cache_factory_flashinfer():
|
||||
return create_kv_caches_with_random_flash
|
||||
144
tests/kernels/core/test_activation.py
Normal file
144
tests/kernels/core/test_activation.py
Normal file
@@ -0,0 +1,144 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import random
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from tests.kernels.allclose_default import get_default_atol, get_default_rtol
|
||||
from tests.kernels.utils import opcheck
|
||||
from vllm.model_executor.layers.activation import (
|
||||
FastGELU,
|
||||
FatreluAndMul,
|
||||
GeluAndMul,
|
||||
MulAndSilu,
|
||||
NewGELU,
|
||||
QuickGELU,
|
||||
SiluAndMul,
|
||||
SwigluOAIAndMul,
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
DTYPES = [torch.half, torch.bfloat16, torch.float]
|
||||
NUM_TOKENS = [7, 83, 2048] # Arbitrary values for testing
|
||||
D = [512, 13824] # Arbitrary values for testing
|
||||
SEEDS = [0]
|
||||
CUDA_DEVICES = [f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)]
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"activation",
|
||||
[
|
||||
"silu_and_mul",
|
||||
"mul_and_silu",
|
||||
"gelu",
|
||||
"gelu_tanh",
|
||||
"fatrelu",
|
||||
"swigluoai_and_mul",
|
||||
],
|
||||
)
|
||||
@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
|
||||
@pytest.mark.parametrize("d", D)
|
||||
@pytest.mark.parametrize("dtype", DTYPES)
|
||||
@pytest.mark.parametrize("seed", SEEDS)
|
||||
@pytest.mark.parametrize("device", CUDA_DEVICES)
|
||||
@torch.inference_mode()
|
||||
def test_act_and_mul(
|
||||
activation: str,
|
||||
num_tokens: int,
|
||||
d: int,
|
||||
dtype: torch.dtype,
|
||||
seed: int,
|
||||
device: str,
|
||||
) -> None:
|
||||
current_platform.seed_everything(seed)
|
||||
torch.set_default_device(device)
|
||||
x = torch.randn(num_tokens, 2 * d, dtype=dtype)
|
||||
if activation == "silu_and_mul":
|
||||
layer = SiluAndMul()
|
||||
fn = torch.ops._C.silu_and_mul
|
||||
if activation == "mul_and_silu":
|
||||
layer = MulAndSilu()
|
||||
fn = torch.ops._C.mul_and_silu
|
||||
elif activation == "gelu":
|
||||
layer = GeluAndMul(approximate="none")
|
||||
fn = torch.ops._C.gelu_and_mul
|
||||
elif activation == "gelu_tanh":
|
||||
layer = GeluAndMul(approximate="tanh")
|
||||
fn = torch.ops._C.gelu_tanh_and_mul
|
||||
elif activation == "fatrelu":
|
||||
threshold = random.uniform(0, 1)
|
||||
layer = FatreluAndMul(threshold)
|
||||
fn = torch.ops._C.fatrelu_and_mul
|
||||
elif activation == "swigluoai_and_mul":
|
||||
layer = SwigluOAIAndMul()
|
||||
fn = torch.ops._C.swigluoai_and_mul
|
||||
out = layer(x)
|
||||
ref_out = layer.forward_native(x)
|
||||
if activation == "swigluoai_and_mul":
|
||||
rtol = {
|
||||
# For fp16, change the relative tolerance from 1e-3 to 2e-3
|
||||
torch.float16: 2e-3,
|
||||
torch.bfloat16: 2e-2,
|
||||
torch.float: 1.3e-6,
|
||||
}
|
||||
|
||||
def _get_rtol(output) -> float:
|
||||
return rtol[output.dtype]
|
||||
|
||||
torch.testing.assert_close(
|
||||
out, ref_out, atol=get_default_atol(out), rtol=_get_rtol(out)
|
||||
)
|
||||
else:
|
||||
# The SiluAndMul, MulAndSilu, GELU and FatReLU implementations are
|
||||
# equivalent to the native PyTorch implementations, so we can do exact
|
||||
# comparison.
|
||||
torch.testing.assert_close(out, ref_out, atol=0.0, rtol=0.0)
|
||||
|
||||
d = x.shape[-1] // 2
|
||||
output_shape = x.shape[:-1] + (d,)
|
||||
out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
|
||||
if activation == "fatrelu":
|
||||
opcheck(fn, (out, x, threshold))
|
||||
elif activation == "swigluoai_and_mul":
|
||||
opcheck(fn, (out, x, layer.alpha, layer.limit))
|
||||
else:
|
||||
opcheck(fn, (out, x))
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"activation",
|
||||
[
|
||||
(FastGELU, torch.ops._C.gelu_fast),
|
||||
(NewGELU, torch.ops._C.gelu_new),
|
||||
(QuickGELU, torch.ops._C.gelu_quick),
|
||||
],
|
||||
)
|
||||
@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
|
||||
@pytest.mark.parametrize("d", D)
|
||||
@pytest.mark.parametrize("dtype", DTYPES)
|
||||
@pytest.mark.parametrize("seed", SEEDS)
|
||||
@pytest.mark.parametrize("device", CUDA_DEVICES)
|
||||
@torch.inference_mode()
|
||||
def test_activation(
|
||||
activation: type[torch.nn.Module],
|
||||
num_tokens: int,
|
||||
d: int,
|
||||
dtype: torch.dtype,
|
||||
seed: int,
|
||||
device: str,
|
||||
) -> None:
|
||||
current_platform.seed_everything(seed)
|
||||
torch.set_default_device(device)
|
||||
x = torch.randn(num_tokens, d, dtype=dtype)
|
||||
layer = activation[0]()
|
||||
fn = activation[1]
|
||||
out = layer(x)
|
||||
ref_out = layer.forward_native(x)
|
||||
torch.testing.assert_close(
|
||||
out, ref_out, atol=get_default_atol(out), rtol=get_default_rtol(out)
|
||||
)
|
||||
|
||||
out = torch.empty_like(x)
|
||||
opcheck(fn, (out, x))
|
||||
203
tests/kernels/core/test_apply_rotary_emb.py
Normal file
203
tests/kernels/core/test_apply_rotary_emb.py
Normal file
@@ -0,0 +1,203 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""
|
||||
Tests for ApplyRotaryEmb CustomOp dispatch behavior.
|
||||
|
||||
This test ensures that RotaryEmbedding classes correctly call the appropriate
|
||||
ApplyRotaryEmb methods based on the calling context:
|
||||
|
||||
1. RotaryEmbedding.forward_native() -> ApplyRotaryEmb.forward_native()
|
||||
2. RotaryEmbedding.forward_cuda() -> ApplyRotaryEmb.forward() (auto-dispatch)
|
||||
3. RotaryEmbedding.forward_hip() -> ApplyRotaryEmb.forward() (auto-dispatch)
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from vllm.config import (
|
||||
CompilationConfig,
|
||||
VllmConfig,
|
||||
get_cached_compilation_config,
|
||||
set_current_vllm_config,
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
CUDA_DEVICES = ["cuda:0"]
|
||||
|
||||
|
||||
@dataclass
|
||||
class RotaryEmbeddingTestCase:
|
||||
"""Test case configuration for RotaryEmbedding dispatch tests."""
|
||||
|
||||
name: str
|
||||
rope_class: type
|
||||
rope_kwargs: dict
|
||||
method_name: str # forward_native, forward_cuda, forward
|
||||
positions_shape: tuple # (num_tokens,) or (3, num_tokens) or (4, num_tokens)
|
||||
expect_forward_native: bool # Should call ApplyRotaryEmb.forward_native()
|
||||
expect_forward: bool # Should call ApplyRotaryEmb.forward()
|
||||
|
||||
|
||||
def get_test_cases() -> list[RotaryEmbeddingTestCase]:
|
||||
"""Generate test cases for all RotaryEmbedding classes."""
|
||||
from vllm.model_executor.layers.rotary_embedding.ernie45_vl_rope import (
|
||||
Ernie4_5_VLRotaryEmbedding,
|
||||
)
|
||||
from vllm.model_executor.layers.rotary_embedding.mrope import MRotaryEmbedding
|
||||
from vllm.model_executor.layers.rotary_embedding.xdrope import XDRotaryEmbedding
|
||||
|
||||
common_kwargs = {
|
||||
"head_size": 128,
|
||||
"rotary_dim": 128,
|
||||
"max_position_embeddings": 4096,
|
||||
"base": 10000,
|
||||
"is_neox_style": True,
|
||||
"dtype": torch.bfloat16,
|
||||
}
|
||||
|
||||
return [
|
||||
# MRotaryEmbedding tests
|
||||
RotaryEmbeddingTestCase(
|
||||
name="MRotaryEmbedding.forward_native",
|
||||
rope_class=MRotaryEmbedding,
|
||||
rope_kwargs={**common_kwargs, "mrope_section": [16, 24, 24]},
|
||||
method_name="forward_native",
|
||||
positions_shape=(3, 32), # 2D for multimodal
|
||||
expect_forward_native=True,
|
||||
expect_forward=False,
|
||||
),
|
||||
RotaryEmbeddingTestCase(
|
||||
name="MRotaryEmbedding.forward_cuda_1d",
|
||||
rope_class=MRotaryEmbedding,
|
||||
rope_kwargs={**common_kwargs, "mrope_section": [16, 24, 24]},
|
||||
method_name="forward_cuda",
|
||||
positions_shape=(32,), # 1D triggers apply_rotary_emb path
|
||||
expect_forward_native=False,
|
||||
expect_forward=True,
|
||||
),
|
||||
# XDRotaryEmbedding tests
|
||||
RotaryEmbeddingTestCase(
|
||||
name="XDRotaryEmbedding.forward",
|
||||
rope_class=XDRotaryEmbedding,
|
||||
rope_kwargs={
|
||||
**common_kwargs,
|
||||
"scaling_alpha": 1.0,
|
||||
"xdrope_section": [16, 16, 16, 16],
|
||||
},
|
||||
method_name="forward",
|
||||
positions_shape=(4, 32), # 4D for P/W/H/T
|
||||
expect_forward_native=False,
|
||||
expect_forward=True,
|
||||
),
|
||||
# Ernie4_5_VLRotaryEmbedding tests
|
||||
RotaryEmbeddingTestCase(
|
||||
name="Ernie4_5_VLRotaryEmbedding.forward_native",
|
||||
rope_class=Ernie4_5_VLRotaryEmbedding,
|
||||
rope_kwargs={**common_kwargs, "mrope_section": [22, 22, 20]},
|
||||
method_name="forward_native",
|
||||
positions_shape=(3, 32), # 2D for multimodal
|
||||
expect_forward_native=True,
|
||||
expect_forward=False,
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
def run_dispatch_test(
|
||||
test_case: RotaryEmbeddingTestCase,
|
||||
device: str,
|
||||
):
|
||||
"""Run a dispatch test for a RotaryEmbedding class."""
|
||||
vllm_config = VllmConfig(
|
||||
compilation_config=CompilationConfig(custom_ops=["all", "+apply_rotary_emb"])
|
||||
)
|
||||
get_cached_compilation_config.cache_clear()
|
||||
|
||||
with set_current_vllm_config(vllm_config):
|
||||
rope = test_case.rope_class(**test_case.rope_kwargs).to(device=device)
|
||||
|
||||
apply_rotary_emb = rope.apply_rotary_emb
|
||||
|
||||
# Verify custom op is enabled
|
||||
if test_case.expect_forward_native:
|
||||
assert (
|
||||
apply_rotary_emb._forward_method != apply_rotary_emb.forward_native
|
||||
), "Test setup error: ApplyRotaryEmb custom op should be enabled"
|
||||
|
||||
# Setup call tracking
|
||||
call_tracker = {"forward_native_called": False, "forward_called": False}
|
||||
original_forward_native = apply_rotary_emb.forward_native
|
||||
original_forward = apply_rotary_emb.forward
|
||||
|
||||
def tracked_forward_native(*args, **kwargs):
|
||||
call_tracker["forward_native_called"] = True
|
||||
return original_forward_native(*args, **kwargs)
|
||||
|
||||
def tracked_forward(*args, **kwargs):
|
||||
call_tracker["forward_called"] = True
|
||||
return original_forward(*args, **kwargs)
|
||||
|
||||
apply_rotary_emb.forward_native = tracked_forward_native
|
||||
apply_rotary_emb.forward = tracked_forward
|
||||
|
||||
try:
|
||||
num_tokens = test_case.positions_shape[-1]
|
||||
num_q_heads = 8
|
||||
num_kv_heads = 2
|
||||
head_size = test_case.rope_kwargs["head_size"]
|
||||
max_position = test_case.rope_kwargs["max_position_embeddings"]
|
||||
|
||||
positions = torch.randint(
|
||||
0, max_position // 4, test_case.positions_shape, device=device
|
||||
)
|
||||
query = torch.randn(
|
||||
num_tokens, num_q_heads * head_size, dtype=torch.bfloat16, device=device
|
||||
)
|
||||
key = torch.randn(
|
||||
num_tokens,
|
||||
num_kv_heads * head_size,
|
||||
dtype=torch.bfloat16,
|
||||
device=device,
|
||||
)
|
||||
|
||||
# Call the method under test
|
||||
method = getattr(rope, test_case.method_name)
|
||||
method(positions, query.clone(), key.clone())
|
||||
|
||||
# Verify expectations
|
||||
if test_case.expect_forward_native:
|
||||
assert call_tracker["forward_native_called"], (
|
||||
f"{test_case.name} should call ApplyRotaryEmb.forward_native()"
|
||||
)
|
||||
if not test_case.expect_forward:
|
||||
assert not call_tracker["forward_called"], (
|
||||
f"{test_case.name} should NOT call ApplyRotaryEmb.forward(). "
|
||||
"Bug: when +apply_rotary_emb is enabled, forward_native() "
|
||||
"incorrectly dispatches to CUDA/HIP kernels."
|
||||
)
|
||||
if test_case.expect_forward:
|
||||
assert call_tracker["forward_called"], (
|
||||
f"{test_case.name} should call ApplyRotaryEmb.forward()"
|
||||
)
|
||||
finally:
|
||||
apply_rotary_emb.forward_native = original_forward_native
|
||||
apply_rotary_emb.forward = original_forward
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
not current_platform.is_cuda_alike(), reason="Skipping CUDA/ROCm only tests."
|
||||
)
|
||||
@pytest.mark.parametrize("test_case", get_test_cases(), ids=lambda tc: tc.name)
|
||||
@pytest.mark.parametrize("device", CUDA_DEVICES)
|
||||
def test_rotary_embedding_dispatch(
|
||||
test_case: RotaryEmbeddingTestCase,
|
||||
device: str,
|
||||
):
|
||||
"""
|
||||
Test that RotaryEmbedding classes dispatch to the correct ApplyRotaryEmb method.
|
||||
|
||||
- forward_native methods should call ApplyRotaryEmb.forward_native()
|
||||
- forward_cuda/forward methods should call ApplyRotaryEmb.forward()
|
||||
"""
|
||||
run_dispatch_test(test_case, device)
|
||||
141
tests/kernels/core/test_fused_qk_norm_rope.py
Normal file
141
tests/kernels/core/test_fused_qk_norm_rope.py
Normal file
@@ -0,0 +1,141 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from tests.kernels.utils import opcheck
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
DTYPES = [torch.bfloat16, torch.float16]
|
||||
IS_NEOX = [True, False]
|
||||
EPS_VALUES = [1e-5, 1e-6]
|
||||
SEEDS = [13]
|
||||
CUDA_DEVICES = ["cuda:0"]
|
||||
|
||||
|
||||
def _apply_qk_norm_rope(
|
||||
qkv: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
q_norm: RMSNorm,
|
||||
k_norm: RMSNorm,
|
||||
rope: RotaryEmbedding,
|
||||
num_heads_q: int,
|
||||
num_heads_kv: int,
|
||||
head_dim: int,
|
||||
) -> torch.Tensor:
|
||||
q_size = num_heads_q * head_dim
|
||||
kv_size = num_heads_kv * head_dim
|
||||
|
||||
q, k, v = qkv.split([q_size, kv_size, kv_size], dim=-1)
|
||||
|
||||
q_by_head = q.view(*q.shape[:-1], q.shape[-1] // head_dim, head_dim)
|
||||
q_by_head = q_norm.forward_native(q_by_head)
|
||||
q = q_by_head.view(q.shape)
|
||||
|
||||
k_by_head = k.view(*k.shape[:-1], k.shape[-1] // head_dim, head_dim)
|
||||
k_by_head = k_norm.forward_native(k_by_head)
|
||||
k = k_by_head.view(k.shape)
|
||||
|
||||
q, k = rope.forward_native(positions, q, k)
|
||||
return torch.cat([q, k, v], dim=-1)
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
not current_platform.is_cuda_alike(),
|
||||
reason="fused_qk_norm_rope custom op requires cuda and rocm platform",
|
||||
)
|
||||
@pytest.mark.parametrize("device", CUDA_DEVICES)
|
||||
@pytest.mark.parametrize("dtype", DTYPES)
|
||||
@pytest.mark.parametrize("is_neox", IS_NEOX)
|
||||
@pytest.mark.parametrize("eps", EPS_VALUES)
|
||||
@pytest.mark.parametrize("seed", SEEDS)
|
||||
@torch.inference_mode()
|
||||
def test_fused_qk_norm_rope_matches_reference(
|
||||
device: str,
|
||||
dtype: torch.dtype,
|
||||
is_neox: bool,
|
||||
eps: float,
|
||||
seed: int,
|
||||
):
|
||||
torch.set_default_device(device)
|
||||
current_platform.seed_everything(seed)
|
||||
num_heads, num_kv_heads, head_dim = 16, 4, 128
|
||||
num_tokens = 4
|
||||
|
||||
total_dim = (num_heads + 2 * num_kv_heads) * head_dim
|
||||
qkv_base = torch.randn(num_tokens, total_dim, dtype=dtype, device=device)
|
||||
qkv_fused = qkv_base.clone()
|
||||
positions = torch.arange(num_tokens, dtype=torch.long, device=device)
|
||||
|
||||
q_norm = RMSNorm(head_dim, eps=eps).to(device=device, dtype=dtype)
|
||||
k_norm = RMSNorm(head_dim, eps=eps).to(device=device, dtype=dtype)
|
||||
q_norm.weight.data.normal_(mean=1.0, std=0.1)
|
||||
k_norm.weight.data.normal_(mean=1.0, std=0.1)
|
||||
q_weight = q_norm.weight.data
|
||||
k_weight = k_norm.weight.data
|
||||
|
||||
rope = RotaryEmbedding(
|
||||
head_size=head_dim,
|
||||
rotary_dim=head_dim,
|
||||
max_position_embeddings=4096,
|
||||
base=10000.0,
|
||||
is_neox_style=is_neox,
|
||||
dtype=dtype,
|
||||
).to(device)
|
||||
|
||||
ref_result = _apply_qk_norm_rope(
|
||||
qkv=qkv_base,
|
||||
positions=positions,
|
||||
q_norm=q_norm,
|
||||
k_norm=k_norm,
|
||||
rope=rope,
|
||||
num_heads_q=num_heads,
|
||||
num_heads_kv=num_kv_heads,
|
||||
head_dim=head_dim,
|
||||
)
|
||||
|
||||
opcheck(
|
||||
torch.ops._C.fused_qk_norm_rope,
|
||||
(
|
||||
qkv_fused.clone(),
|
||||
num_heads,
|
||||
num_kv_heads,
|
||||
num_kv_heads,
|
||||
head_dim,
|
||||
eps,
|
||||
q_weight,
|
||||
k_weight,
|
||||
rope.cos_sin_cache,
|
||||
is_neox,
|
||||
positions.view(-1),
|
||||
),
|
||||
)
|
||||
|
||||
torch.ops._C.fused_qk_norm_rope(
|
||||
qkv_fused,
|
||||
num_heads,
|
||||
num_kv_heads,
|
||||
num_kv_heads,
|
||||
head_dim,
|
||||
eps,
|
||||
q_weight,
|
||||
k_weight,
|
||||
rope.cos_sin_cache,
|
||||
is_neox,
|
||||
positions.view(-1),
|
||||
)
|
||||
|
||||
if dtype == torch.float16:
|
||||
ATOL, RTOL = (2e-3, 2e-3)
|
||||
else:
|
||||
ATOL, RTOL = (1e-2, 1e-2)
|
||||
|
||||
torch.testing.assert_close(
|
||||
qkv_fused,
|
||||
ref_result,
|
||||
atol=ATOL,
|
||||
rtol=RTOL,
|
||||
)
|
||||
237
tests/kernels/core/test_fused_quant_layernorm.py
Normal file
237
tests/kernels/core/test_fused_quant_layernorm.py
Normal file
@@ -0,0 +1,237 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
import vllm._custom_ops as ops
|
||||
from tests.kernels.utils import opcheck
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
|
||||
per_token_group_quant_fp8,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.int8_utils import (
|
||||
per_token_group_quant_int8,
|
||||
)
|
||||
|
||||
DTYPES = [torch.bfloat16, torch.float]
|
||||
QUANT_DTYPES = [torch.int8, torch.float8_e4m3fn]
|
||||
VEC_HIDDEN_SIZES = [1024, 1025, 1027, 1029]
|
||||
# Avoid combinatorial explosion with full Cartesian product
|
||||
NUM_TOKENS_HIDDEN_SIZES = [
|
||||
*[(1, i) for i in [1, 64, *VEC_HIDDEN_SIZES, 5120, 5137]],
|
||||
*[(2048, i) for i in [1, 64, *VEC_HIDDEN_SIZES, 5137]],
|
||||
*[(4096, i) for i in [1, 64, 5137]],
|
||||
]
|
||||
|
||||
ADD_RESIDUAL = [False, True]
|
||||
SCALE_UBS = [True, False]
|
||||
GROUP_SIZES = [None, [1, 64], [1, 128]]
|
||||
SEEDS = [0]
|
||||
CUDA_DEVICES = [f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)]
|
||||
|
||||
EPS = 1e-6
|
||||
|
||||
## Helpers
|
||||
|
||||
|
||||
def as_float32_tensor(x: float | torch.Tensor) -> torch.Tensor:
|
||||
return torch.as_tensor(x, dtype=torch.float32, device="cuda")
|
||||
|
||||
|
||||
def ref_rms_norm(
|
||||
rms_norm_layer: RMSNorm, x: torch.Tensor, residual: torch.Tensor | None
|
||||
) -> tuple[torch.Tensor, torch.Tensor | None]:
|
||||
if residual is not None:
|
||||
residual = residual.clone()
|
||||
out, residual = rms_norm_layer.forward_native(x, residual)
|
||||
else:
|
||||
out = rms_norm_layer.forward_native(x)
|
||||
|
||||
return out, residual
|
||||
|
||||
|
||||
def ref_dynamic_per_token_or_block_quant(
|
||||
rms_norm_layer: RMSNorm,
|
||||
x: torch.Tensor,
|
||||
quant_dtype: torch.dtype,
|
||||
residual: torch.Tensor | None,
|
||||
scale_ub: torch.Tensor | None,
|
||||
group_size: list[int] | None,
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None]:
|
||||
if scale_ub is not None:
|
||||
assert quant_dtype == torch.float8_e4m3fn
|
||||
|
||||
# Norm
|
||||
torch_out, residual = ref_rms_norm(rms_norm_layer, x, residual)
|
||||
|
||||
# Quant
|
||||
if group_size is not None:
|
||||
if quant_dtype == torch.float8_e4m3fn:
|
||||
torch_out, scales = per_token_group_quant_fp8(
|
||||
torch_out, group_size=group_size[1], use_ue8m0=False
|
||||
)
|
||||
else:
|
||||
assert quant_dtype == torch.int8
|
||||
torch_out, scales = per_token_group_quant_int8(
|
||||
torch_out, group_size=group_size[1]
|
||||
)
|
||||
else:
|
||||
if quant_dtype == torch.float8_e4m3fn:
|
||||
torch_out, scales = ops.scaled_fp8_quant(
|
||||
torch_out, scale_ub=scale_ub, use_per_token_if_dynamic=True
|
||||
)
|
||||
else:
|
||||
assert quant_dtype == torch.int8
|
||||
torch_out, scales, _ = ops.scaled_int8_quant(torch_out)
|
||||
|
||||
return torch_out, scales, residual
|
||||
|
||||
|
||||
def ref_impl(
|
||||
rms_norm_layer: RMSNorm,
|
||||
x: torch.Tensor,
|
||||
quant_dtype: torch.dtype,
|
||||
residual: torch.Tensor | None,
|
||||
scale_ub: torch.Tensor | None,
|
||||
group_size: list[int] | None,
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None]:
|
||||
return ref_dynamic_per_token_or_block_quant(
|
||||
rms_norm_layer, x, quant_dtype, residual, scale_ub, group_size
|
||||
)
|
||||
|
||||
|
||||
def ops_dynamic_per_token_or_block_quant(
|
||||
weight: torch.Tensor,
|
||||
x: torch.Tensor,
|
||||
quant_dtype: torch.dtype,
|
||||
residual: torch.Tensor | None,
|
||||
scale_ub: torch.Tensor | None,
|
||||
group_size: list[int] | None,
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None]:
|
||||
if residual is not None:
|
||||
residual = residual.clone()
|
||||
if group_size is not None:
|
||||
out, scales = ops.rms_norm_per_block_quant(
|
||||
x, weight, EPS, quant_dtype, group_size, scale_ub, residual, True
|
||||
)
|
||||
scales = scales.contiguous()
|
||||
else:
|
||||
out, scales = ops.rms_norm_dynamic_per_token_quant(
|
||||
x, weight, EPS, quant_dtype, scale_ub, residual
|
||||
)
|
||||
return out, scales, residual
|
||||
|
||||
|
||||
def ops_impl(
|
||||
weight: torch.Tensor,
|
||||
x: torch.Tensor,
|
||||
quant_dtype: torch.dtype,
|
||||
residual: torch.Tensor | None,
|
||||
scale_ub: torch.Tensor | None,
|
||||
group_size: list[int] | None,
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None]:
|
||||
return ops_dynamic_per_token_or_block_quant(
|
||||
weight, x, quant_dtype, residual, scale_ub, group_size
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("num_tokens, hidden_size", NUM_TOKENS_HIDDEN_SIZES)
|
||||
@pytest.mark.parametrize("add_residual", ADD_RESIDUAL)
|
||||
@pytest.mark.parametrize("has_scale_ub", SCALE_UBS)
|
||||
@pytest.mark.parametrize("dtype", DTYPES)
|
||||
@pytest.mark.parametrize("quant_dtype", QUANT_DTYPES)
|
||||
@pytest.mark.parametrize("group_size", GROUP_SIZES)
|
||||
@pytest.mark.parametrize("seed", SEEDS)
|
||||
@pytest.mark.parametrize("device", CUDA_DEVICES)
|
||||
@torch.inference_mode()
|
||||
def test_rms_norm(
|
||||
num_tokens: int,
|
||||
hidden_size: int,
|
||||
add_residual: bool,
|
||||
has_scale_ub: bool,
|
||||
dtype: torch.dtype,
|
||||
quant_dtype: torch.dtype,
|
||||
group_size: list[int] | None,
|
||||
seed: int,
|
||||
device: str,
|
||||
) -> None:
|
||||
torch.random.manual_seed(seed)
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.manual_seed(seed)
|
||||
torch.set_default_device(device)
|
||||
|
||||
if group_size is not None and hidden_size % group_size[1] != 0:
|
||||
# skip
|
||||
return
|
||||
|
||||
if group_size is not None and has_scale_ub:
|
||||
# blockwise baseline doesn't support scale_ub
|
||||
return
|
||||
|
||||
if has_scale_ub and quant_dtype != torch.float8_e4m3fn:
|
||||
# skip
|
||||
return
|
||||
|
||||
layer = RMSNorm(hidden_size, EPS).to(dtype=dtype)
|
||||
|
||||
# Make weights
|
||||
layer.weight.data.normal_(mean=1.0, std=0.1)
|
||||
|
||||
# Make inputs
|
||||
scale = 1 / (hidden_size)
|
||||
x = torch.randn(num_tokens, hidden_size, dtype=dtype) * scale
|
||||
residual = torch.randn_like(x) * scale if add_residual else None
|
||||
if has_scale_ub:
|
||||
rms_x, _ = ref_rms_norm(layer, x, residual)
|
||||
scale_ub = torch.mean(rms_x).to(dtype=torch.float32, device="cuda")
|
||||
else:
|
||||
scale_ub = None
|
||||
|
||||
ref_out, ref_scales, ref_residual = ref_impl(
|
||||
layer, x, quant_dtype, residual, scale_ub, group_size
|
||||
)
|
||||
ops_out, ops_scales, ops_residual = ops_impl(
|
||||
layer.weight, x, quant_dtype, residual, scale_ub, group_size
|
||||
)
|
||||
|
||||
assert ref_out.dtype == quant_dtype
|
||||
assert ops_out.dtype == quant_dtype
|
||||
if quant_dtype == torch.int8:
|
||||
assert torch.allclose(ref_scales, ops_scales, atol=1e-6)
|
||||
# big atol to account for round-off errors.
|
||||
assert torch.allclose(ref_out, ops_out, atol=1)
|
||||
else:
|
||||
assert torch.allclose(ref_scales, ops_scales)
|
||||
a = ref_out.to(dtype=torch.float32)
|
||||
b = ops_out.to(dtype=torch.float32)
|
||||
ok = torch.allclose(a, b, atol=1e-6)
|
||||
if not ok:
|
||||
# fallback: compare dequantized values with relaxed tolerance
|
||||
if group_size is None:
|
||||
a_deq = a * ref_scales.view(-1, 1)
|
||||
b_deq = b * ops_scales.view(-1, 1)
|
||||
else:
|
||||
a_deq = a * ref_scales.repeat_interleave(group_size[1], dim=1)
|
||||
b_deq = b * ops_scales.repeat_interleave(group_size[1], dim=1)
|
||||
# NOTE: It is possible that some future test cases trigger this
|
||||
# max diff due to precision issues. If such an error is
|
||||
# encountered, it's recommended to inspect the differences between
|
||||
# all corresponding elements from each tensor (e.g. by looping over
|
||||
# them) and checking how many the max diff error shows up on (just
|
||||
# a few bad elements should still be considered acceptable).
|
||||
ok = torch.allclose(a_deq, b_deq, rtol=5e-2, atol=5e-2)
|
||||
assert ok
|
||||
if add_residual:
|
||||
assert torch.allclose(ref_residual, ops_residual)
|
||||
|
||||
output = torch.empty_like(x, dtype=quant_dtype)
|
||||
scales = torch.empty(
|
||||
(x.numel() // x.shape[-1], 1), device=x.device, dtype=torch.float32
|
||||
)
|
||||
|
||||
opcheck(
|
||||
torch.ops._C.rms_norm_dynamic_per_token_quant,
|
||||
(output, x, layer.weight, scales, 1e-5, scale_ub, residual),
|
||||
)
|
||||
153
tests/kernels/core/test_layernorm.py
Normal file
153
tests/kernels/core/test_layernorm.py
Normal file
@@ -0,0 +1,153 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from tests.kernels.quant_utils import FP8_DTYPE
|
||||
from tests.kernels.utils import opcheck
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
DTYPES = [torch.half, torch.bfloat16, torch.float]
|
||||
NUM_TOKENS = [7, 83, 4096] # Arbitrary values for testing
|
||||
HIDDEN_SIZES = [8, 768, 769, 5120, 5125, 8192] # Arbitrary values for testing
|
||||
ADD_RESIDUAL = [False, True]
|
||||
SEEDS = [0]
|
||||
CUDA_DEVICES = [f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
|
||||
@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES)
|
||||
@pytest.mark.parametrize("add_residual", ADD_RESIDUAL)
|
||||
@pytest.mark.parametrize("dtype", DTYPES)
|
||||
@pytest.mark.parametrize("seed", SEEDS)
|
||||
@pytest.mark.parametrize("device", CUDA_DEVICES)
|
||||
@pytest.mark.parametrize("strided_input", [False, True])
|
||||
@torch.inference_mode()
|
||||
def test_rms_norm(
|
||||
num_tokens: int,
|
||||
hidden_size: int,
|
||||
add_residual: bool,
|
||||
dtype: torch.dtype,
|
||||
seed: int,
|
||||
device: str,
|
||||
strided_input: bool,
|
||||
) -> None:
|
||||
current_platform.seed_everything(seed)
|
||||
torch.set_default_device(device)
|
||||
layer = RMSNorm(hidden_size).to(dtype=dtype)
|
||||
layer.weight.data.normal_(mean=1.0, std=0.1)
|
||||
scale = 1 / (2 * hidden_size)
|
||||
last_dim = 2 * hidden_size if strided_input else hidden_size
|
||||
x = torch.randn(num_tokens, last_dim, dtype=dtype)
|
||||
x = x[..., :hidden_size]
|
||||
assert x.is_contiguous() != strided_input
|
||||
x *= scale
|
||||
residual = torch.randn_like(x) * scale if add_residual else None
|
||||
|
||||
# NOTE(woosuk): The reference implementation should be executed first
|
||||
# because the custom kernel is in-place.
|
||||
ref_out = layer.forward_native(x, residual)
|
||||
out = layer(x, residual)
|
||||
# NOTE(woosuk): LayerNorm operators (including RMS) typically have larger
|
||||
# numerical errors than other operators because they involve reductions.
|
||||
# Therefore, we use a larger tolerance.
|
||||
if add_residual:
|
||||
torch.testing.assert_close(out[0], ref_out[0], atol=1e-2, rtol=1e-2)
|
||||
torch.testing.assert_close(out[1], ref_out[1], atol=1e-2, rtol=1e-2)
|
||||
else:
|
||||
torch.testing.assert_close(out, ref_out, atol=1e-2, rtol=1e-2)
|
||||
|
||||
if residual is not None:
|
||||
opcheck(
|
||||
torch.ops._C.fused_add_rms_norm,
|
||||
(x, residual, layer.weight.data, layer.variance_epsilon),
|
||||
)
|
||||
else:
|
||||
opcheck(
|
||||
torch.ops._C.rms_norm, (out, x, layer.weight.data, layer.variance_epsilon)
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
|
||||
@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES)
|
||||
@pytest.mark.parametrize("add_residual", ADD_RESIDUAL)
|
||||
@pytest.mark.parametrize("dtype", DTYPES)
|
||||
@pytest.mark.parametrize("quant_scale", [0.01, 1.0, 10.0])
|
||||
@pytest.mark.parametrize("seed", SEEDS)
|
||||
@pytest.mark.parametrize("device", CUDA_DEVICES)
|
||||
@pytest.mark.parametrize("strided_input", [False, True])
|
||||
def test_fused_rms_norm_quant(
|
||||
num_tokens: int,
|
||||
hidden_size: int,
|
||||
add_residual: bool,
|
||||
dtype: torch.dtype,
|
||||
quant_scale: float,
|
||||
seed: int,
|
||||
device: str,
|
||||
strided_input: bool,
|
||||
) -> None:
|
||||
current_platform.seed_everything(seed)
|
||||
torch.set_default_device(device)
|
||||
|
||||
weight = torch.empty(hidden_size, dtype=dtype).normal_(mean=1.0, std=0.1)
|
||||
scale = 1 / (2 * hidden_size)
|
||||
last_dim = 2 * hidden_size if strided_input else hidden_size
|
||||
x_base = torch.randn(num_tokens, last_dim, dtype=dtype)
|
||||
x = x_base[..., :hidden_size]
|
||||
assert x.is_contiguous() != strided_input
|
||||
|
||||
x *= scale
|
||||
if add_residual:
|
||||
residual = torch.randn_like(x) * scale
|
||||
residual_fused = residual.clone()
|
||||
else:
|
||||
residual = residual_fused = None
|
||||
|
||||
out_norm = torch.empty_like(x)
|
||||
out_quant = torch.empty_like(x, dtype=FP8_DTYPE)
|
||||
out_quant_fused = torch.empty_like(out_quant)
|
||||
|
||||
quant_scale_t = torch.tensor(quant_scale, dtype=torch.float32)
|
||||
|
||||
if add_residual:
|
||||
torch.ops._C.fused_add_rms_norm_static_fp8_quant(
|
||||
out_quant_fused, x, residual_fused, weight, quant_scale_t, 1e-6
|
||||
)
|
||||
|
||||
# Unfused kernel is in-place so it goes second
|
||||
# Also use a separate clone of x to avoid modifying the input
|
||||
x_unfused_base = x_base.clone()
|
||||
x_unfused = x_unfused_base[..., :hidden_size]
|
||||
assert x_unfused.is_contiguous() != strided_input
|
||||
torch.ops._C.fused_add_rms_norm(x_unfused, residual, weight, 1e-6)
|
||||
torch.ops._C.static_scaled_fp8_quant(
|
||||
out_quant, x_unfused.contiguous(), quant_scale_t
|
||||
)
|
||||
|
||||
torch.cuda.synchronize()
|
||||
torch.testing.assert_close(residual_fused, residual, atol=1e-2, rtol=1e-2)
|
||||
opcheck(
|
||||
torch.ops._C.fused_add_rms_norm_static_fp8_quant,
|
||||
(out_quant_fused, x, residual_fused, weight, quant_scale_t, 1e-6),
|
||||
)
|
||||
else:
|
||||
torch.ops._C.rms_norm_static_fp8_quant(
|
||||
out_quant_fused, x, weight, quant_scale_t, 1e-6
|
||||
)
|
||||
|
||||
torch.ops._C.rms_norm(out_norm, x, weight, 1e-6)
|
||||
torch.ops._C.static_scaled_fp8_quant(out_quant, out_norm, quant_scale_t)
|
||||
|
||||
opcheck(
|
||||
torch.ops._C.rms_norm_static_fp8_quant,
|
||||
(out_quant_fused, x, weight, quant_scale_t, 1e-6),
|
||||
)
|
||||
|
||||
torch.testing.assert_close(
|
||||
out_quant.to(dtype=torch.float32),
|
||||
out_quant_fused.to(dtype=torch.float32),
|
||||
atol=1e-3,
|
||||
rtol=1e-3,
|
||||
)
|
||||
253
tests/kernels/core/test_mrope.py
Normal file
253
tests/kernels/core/test_mrope.py
Normal file
@@ -0,0 +1,253 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from typing import NamedTuple
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from packaging.version import Version
|
||||
from transformers import __version__ as TRANSFORMERS_VERSION
|
||||
|
||||
from vllm.model_executor.layers.rotary_embedding import get_rope
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.transformers_utils.config import get_config
|
||||
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
|
||||
|
||||
def generate_test_data(
|
||||
num_tokens: int,
|
||||
num_q_heads: int,
|
||||
num_kv_heads: int,
|
||||
head_size: int,
|
||||
max_position_embeddings: int,
|
||||
dtype: torch.dtype,
|
||||
device: torch.device,
|
||||
):
|
||||
"""Generate test data for given configuration."""
|
||||
current_platform.seed_everything(42)
|
||||
# Create 2D positions (3, num_tokens) for multimodal case
|
||||
positions = torch.randint(
|
||||
0, max_position_embeddings // 4, (3, num_tokens), device=device
|
||||
)
|
||||
|
||||
# Create query and key tensors
|
||||
query = torch.randn(num_tokens, num_q_heads * head_size, dtype=dtype, device=device)
|
||||
key = torch.randn(num_tokens, num_kv_heads * head_size, dtype=dtype, device=device)
|
||||
|
||||
return positions, query, key
|
||||
|
||||
|
||||
class MRoPETestInfo(NamedTuple):
|
||||
model_name: str
|
||||
# https://github.com/pytorch/pytorch/blob/main/torch/testing/_comparison.py#L1317
|
||||
atol: float = 1e-2
|
||||
rtol: float = 1.6e-2
|
||||
marks: list[pytest.MarkDecorator] = []
|
||||
|
||||
|
||||
TRANSFORMERS_BASE_VERSION = Version(TRANSFORMERS_VERSION).base_version
|
||||
|
||||
MODELS_TO_TEST = [
|
||||
MRoPETestInfo(model_name="zai-org/GLM-4.1V-9B-Thinking"),
|
||||
MRoPETestInfo(model_name="Qwen/Qwen2-VL-7B-Instruct"),
|
||||
MRoPETestInfo(model_name="Qwen/Qwen2-VL-72B-Instruct"),
|
||||
MRoPETestInfo(model_name="Qwen/Qwen2.5-VL-72B-Instruct"),
|
||||
MRoPETestInfo(
|
||||
model_name="Qwen/Qwen3-VL-4B-Instruct",
|
||||
marks=[
|
||||
pytest.mark.skipif(
|
||||
Version(TRANSFORMERS_BASE_VERSION) < Version("4.57.0"),
|
||||
reason="Qwen3-VL only available after Transformers v4.57",
|
||||
)
|
||||
],
|
||||
),
|
||||
MRoPETestInfo(
|
||||
model_name="Qwen/Qwen3-VL-30B-A3B-Instruct",
|
||||
marks=[
|
||||
pytest.mark.skipif(
|
||||
Version(TRANSFORMERS_BASE_VERSION) < Version("4.57.0"),
|
||||
reason="Qwen3-VL only available after Transformers v4.57",
|
||||
)
|
||||
],
|
||||
),
|
||||
]
|
||||
|
||||
num_tokens_list = [11, 8192]
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
not current_platform.is_cuda_alike(), reason="Skipping CUDA/ROCm only tests."
|
||||
)
|
||||
@pytest.mark.parametrize(
|
||||
"model_info, model_name",
|
||||
[
|
||||
pytest.param(test_config, test_config.model_name, marks=test_config.marks)
|
||||
for test_config in MODELS_TO_TEST
|
||||
],
|
||||
)
|
||||
@pytest.mark.parametrize("tp_size", [1, 2])
|
||||
@pytest.mark.parametrize("dtype", [torch.bfloat16])
|
||||
@pytest.mark.parametrize("num_tokens", num_tokens_list)
|
||||
def test_mrope(
|
||||
model_name: str,
|
||||
model_info: MRoPETestInfo,
|
||||
tp_size: int,
|
||||
dtype: torch.dtype,
|
||||
num_tokens: int,
|
||||
):
|
||||
atol = model_info.atol
|
||||
rtol = model_info.rtol
|
||||
|
||||
config = get_config(model_name, False).get_text_config()
|
||||
|
||||
# get the model config
|
||||
total_num_kv_heads = config.num_key_value_heads
|
||||
total_num_heads = config.num_attention_heads
|
||||
num_heads = total_num_heads // tp_size
|
||||
num_kv_heads = max(1, total_num_kv_heads // tp_size)
|
||||
head_dim = (
|
||||
config.head_dim
|
||||
if hasattr(config, "head_dim")
|
||||
else config.hidden_size // total_num_heads
|
||||
)
|
||||
is_neox_style = True
|
||||
|
||||
max_position = config.max_position_embeddings
|
||||
|
||||
mrope_helper_class = get_rope(
|
||||
head_size=head_dim,
|
||||
max_position=max_position,
|
||||
is_neox_style=is_neox_style,
|
||||
rope_parameters=config.rope_parameters,
|
||||
dtype=dtype,
|
||||
).to(device=device)
|
||||
|
||||
# create q k v input tensors
|
||||
# create rotary pos emb input tensors
|
||||
positions, query, key = generate_test_data(
|
||||
num_tokens, num_heads, num_kv_heads, head_dim, max_position, dtype, device
|
||||
)
|
||||
|
||||
query_native, key_native = mrope_helper_class.forward_native(
|
||||
positions,
|
||||
query.clone(),
|
||||
key.clone(),
|
||||
)
|
||||
|
||||
query_cuda, key_cuda = mrope_helper_class.forward_cuda(
|
||||
positions,
|
||||
query.clone(),
|
||||
key.clone(),
|
||||
)
|
||||
|
||||
torch.testing.assert_close(query_native, query_cuda, atol=atol, rtol=rtol)
|
||||
torch.testing.assert_close(key_native, key_cuda, atol=atol, rtol=rtol)
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
not current_platform.is_cuda_alike(), reason="Skipping CUDA/ROCm only tests."
|
||||
)
|
||||
@pytest.mark.parametrize(
|
||||
"model_info, model_name",
|
||||
[
|
||||
pytest.param(test_config, test_config.model_name, marks=test_config.marks)
|
||||
for test_config in MODELS_TO_TEST
|
||||
],
|
||||
)
|
||||
@pytest.mark.parametrize("tp_size", [1, 2])
|
||||
@pytest.mark.parametrize("dtype", [torch.bfloat16])
|
||||
@pytest.mark.parametrize("num_tokens", num_tokens_list)
|
||||
def test_mrope_torch_compile_tracing(
|
||||
model_name: str,
|
||||
model_info: MRoPETestInfo,
|
||||
tp_size: int,
|
||||
dtype: torch.dtype,
|
||||
num_tokens: int,
|
||||
):
|
||||
atol = model_info.atol
|
||||
rtol = model_info.rtol
|
||||
|
||||
config = get_config(model_name, False).get_text_config()
|
||||
|
||||
# get the model config
|
||||
total_num_kv_heads = config.num_key_value_heads
|
||||
total_num_heads = config.num_attention_heads
|
||||
num_heads = total_num_heads // tp_size
|
||||
num_kv_heads = max(1, total_num_kv_heads // tp_size)
|
||||
head_dim = (
|
||||
config.head_dim
|
||||
if hasattr(config, "head_dim")
|
||||
else config.hidden_size // total_num_heads
|
||||
)
|
||||
is_neox_style = True
|
||||
max_position = config.max_position_embeddings
|
||||
|
||||
mrope_helper_class = get_rope(
|
||||
head_size=head_dim,
|
||||
max_position=max_position,
|
||||
is_neox_style=is_neox_style,
|
||||
rope_parameters=config.rope_parameters,
|
||||
dtype=dtype,
|
||||
).to(device=device)
|
||||
|
||||
# Generate test data
|
||||
positions, query, key = generate_test_data(
|
||||
num_tokens, num_heads, num_kv_heads, head_dim, max_position, dtype, device
|
||||
)
|
||||
|
||||
# Create a wrapper that makes the in-place function appear functional
|
||||
def functional_forward_cuda(pos, q, k):
|
||||
"""Wrapper that converts in-place operation to functional style
|
||||
|
||||
CUDA Graph does not support in-place operations.
|
||||
This wrapper creates working copies of the
|
||||
input tensors and modifies them.
|
||||
"""
|
||||
q_work = q.clone() # Create working copies
|
||||
k_work = k.clone()
|
||||
# Your in-place function modifies q_work and k_work
|
||||
mrope_helper_class.forward_cuda(pos, q_work, k_work)
|
||||
return q_work, k_work # Return the modified tensors
|
||||
|
||||
# Get reference results
|
||||
query_native, key_native = mrope_helper_class.forward_native(
|
||||
positions,
|
||||
query.clone(),
|
||||
key.clone(),
|
||||
)
|
||||
|
||||
try:
|
||||
compiled_forward_cuda = torch.compile(
|
||||
functional_forward_cuda,
|
||||
fullgraph=True,
|
||||
backend="inductor",
|
||||
mode="reduce-overhead",
|
||||
dynamic=False,
|
||||
)
|
||||
|
||||
# Run compiled version
|
||||
query_compiled_cuda, key_compiled_cuda = compiled_forward_cuda(
|
||||
positions,
|
||||
query,
|
||||
key,
|
||||
)
|
||||
|
||||
# Run original version for comparison
|
||||
query_cuda = query.clone()
|
||||
key_cuda = key.clone()
|
||||
mrope_helper_class.forward_cuda(positions, query_cuda, key_cuda)
|
||||
|
||||
# Verify results
|
||||
torch.testing.assert_close(
|
||||
query_compiled_cuda, query_cuda, atol=atol, rtol=rtol
|
||||
)
|
||||
torch.testing.assert_close(key_compiled_cuda, key_cuda, atol=atol, rtol=rtol)
|
||||
torch.testing.assert_close(
|
||||
query_compiled_cuda, query_native, atol=atol, rtol=rtol
|
||||
)
|
||||
torch.testing.assert_close(key_compiled_cuda, key_native, atol=atol, rtol=rtol)
|
||||
|
||||
print("✓ forward_cuda successfully traced with torch.compile inductor")
|
||||
|
||||
except Exception as e:
|
||||
pytest.fail(f"forward_cuda failed to trace with torch.compile inductor: {e}")
|
||||
26
tests/kernels/core/test_opcheck.py
Normal file
26
tests/kernels/core/test_opcheck.py
Normal file
@@ -0,0 +1,26 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""
|
||||
Tests for miscellaneous utilities
|
||||
"""
|
||||
|
||||
import torch
|
||||
|
||||
from tests.kernels.utils import opcheck
|
||||
|
||||
|
||||
def test_convert_fp8_opcheck():
|
||||
data = torch.randn((256, 256), dtype=torch.float32, device="cuda")
|
||||
result = torch.empty_like(data, dtype=torch.float8_e4m3fn)
|
||||
opcheck(torch.ops._C_cache_ops.convert_fp8, (result, data, 1.0, "fp8"))
|
||||
|
||||
|
||||
# TODO: Add this back, currently fails with
|
||||
# csrc/cuda_utils_kernels.cu:15 'invalid argument'
|
||||
# @pytest.mark.skipif(not current_platform.is_cuda(),
|
||||
# reason="Only supported for CUDA")
|
||||
# def test_cuda_utils_opcheck():
|
||||
# opcheck(torch.ops._C_cuda_utils.get_device_attribute, (0, 0))
|
||||
# opcheck(
|
||||
# torch.ops._C_cuda_utils.
|
||||
# get_max_shared_memory_per_block_device_attribute, (0, ))
|
||||
18
tests/kernels/core/test_permute_cols.py
Normal file
18
tests/kernels/core/test_permute_cols.py
Normal file
@@ -0,0 +1,18 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from tests.kernels.utils import opcheck
|
||||
from vllm._custom_ops import permute_cols
|
||||
|
||||
|
||||
@pytest.mark.parametrize("shape", [(1, 512), (544, 4096), (67, 8192)])
|
||||
@pytest.mark.parametrize("dtype", [torch.bfloat16])
|
||||
def test_permute_cols(shape, dtype):
|
||||
x = torch.randn(shape, dtype=dtype).cuda()
|
||||
perm = torch.randperm(x.shape[1]).to(torch.int).cuda()
|
||||
opcheck(torch.ops._C.permute_cols, (x, perm))
|
||||
y = permute_cols(x, perm)
|
||||
torch.testing.assert_close(y, x[:, perm])
|
||||
193
tests/kernels/core/test_pos_encoding.py
Normal file
193
tests/kernels/core/test_pos_encoding.py
Normal file
@@ -0,0 +1,193 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from collections.abc import Callable
|
||||
from itertools import product
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from tests.kernels.allclose_default import get_default_atol, get_default_rtol
|
||||
from vllm.model_executor.layers.rotary_embedding import get_rope
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
IS_NEOX_STYLE = [True, False]
|
||||
DTYPES = [torch.bfloat16, torch.float]
|
||||
HEAD_SIZES = [64, 80, 120, 256]
|
||||
ROTARY_DIMS = [None, 32] # None means rotary dim == head size
|
||||
NUM_HEADS = [17] # Arbitrary values for testing
|
||||
BATCH_SIZES = [5] # Arbitrary values for testing
|
||||
SEQ_LENS = [11, 8192] # Arbitrary values for testing
|
||||
SEEDS = [0]
|
||||
CUDA_DEVICES = [f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)]
|
||||
USE_KEY = [True, False]
|
||||
|
||||
|
||||
def _get_flat_tensor_shape(
|
||||
batch_size: int, seq_len: int, num_heads: int, head_size: int
|
||||
) -> tuple[int, ...]:
|
||||
return (batch_size, seq_len, num_heads * head_size)
|
||||
|
||||
|
||||
# For testing sliced tensors
|
||||
def _get_padded_tensor_shape(
|
||||
batch_size: int, seq_len: int, num_heads: int, head_size: int
|
||||
) -> tuple[int, ...]:
|
||||
return (batch_size, seq_len, num_heads, head_size + 64)
|
||||
|
||||
|
||||
def _get_batch_tensor_shape(
|
||||
batch_size: int, seq_len: int, num_heads: int, head_size: int
|
||||
) -> tuple[int, ...]:
|
||||
return (batch_size, seq_len, num_heads, head_size)
|
||||
|
||||
|
||||
TENSORS_SHAPES_FN = [
|
||||
_get_batch_tensor_shape,
|
||||
_get_flat_tensor_shape,
|
||||
_get_padded_tensor_shape,
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("is_neox_style", IS_NEOX_STYLE)
|
||||
@pytest.mark.parametrize("tensor_shape_fn", TENSORS_SHAPES_FN)
|
||||
@pytest.mark.parametrize("batch_size", BATCH_SIZES)
|
||||
@pytest.mark.parametrize("seq_len", SEQ_LENS)
|
||||
@pytest.mark.parametrize("num_heads", NUM_HEADS)
|
||||
@pytest.mark.parametrize("head_size", HEAD_SIZES)
|
||||
@pytest.mark.parametrize("rotary_dim", ROTARY_DIMS)
|
||||
@pytest.mark.parametrize("dtype", DTYPES)
|
||||
@pytest.mark.parametrize("seed", SEEDS)
|
||||
@pytest.mark.parametrize("device", CUDA_DEVICES)
|
||||
@pytest.mark.parametrize("use_key", USE_KEY)
|
||||
@torch.inference_mode()
|
||||
def test_rotary_embedding(
|
||||
is_neox_style: bool,
|
||||
tensor_shape_fn: Callable[[int, int, int, int], tuple[int, ...]],
|
||||
batch_size: int,
|
||||
seq_len: int,
|
||||
num_heads: int,
|
||||
head_size: int,
|
||||
rotary_dim: int | None,
|
||||
dtype: torch.dtype,
|
||||
seed: int,
|
||||
device: str,
|
||||
use_key: bool,
|
||||
max_position: int = 8192,
|
||||
rope_theta: float = 10000,
|
||||
) -> None:
|
||||
if rotary_dim is None:
|
||||
rotary_dim = head_size
|
||||
|
||||
current_platform.seed_everything(seed)
|
||||
torch.set_default_device(device)
|
||||
if rotary_dim is None:
|
||||
rotary_dim = head_size
|
||||
rope_parameters = {
|
||||
"rope_type": "default",
|
||||
"rope_theta": rope_theta,
|
||||
"partial_rotary_factor": rotary_dim / head_size,
|
||||
}
|
||||
rope = get_rope(head_size, max_position, is_neox_style, rope_parameters)
|
||||
rope = rope.to(dtype=dtype, device=torch.get_default_device())
|
||||
|
||||
positions = torch.randint(0, max_position, (batch_size, seq_len))
|
||||
query_shape = tensor_shape_fn(batch_size, seq_len, num_heads, head_size)
|
||||
query = torch.randn(query_shape, dtype=dtype)
|
||||
key = torch.randn_like(query) if use_key else None
|
||||
|
||||
# slice tensor if required, noop otherwise
|
||||
query = query[..., :head_size]
|
||||
key = key[..., :head_size] if use_key else None
|
||||
|
||||
# NOTE(woosuk): The reference implementation should be executed first
|
||||
# because the custom kernel is in-place.
|
||||
ref_query, ref_key = rope.forward_native(positions, query, key)
|
||||
out_query, out_key = rope.forward(positions, query, key)
|
||||
# Compare the results.
|
||||
torch.testing.assert_close(
|
||||
out_query,
|
||||
ref_query,
|
||||
atol=get_default_atol(out_query),
|
||||
rtol=get_default_rtol(out_query),
|
||||
)
|
||||
if use_key:
|
||||
torch.testing.assert_close(
|
||||
out_key,
|
||||
ref_key,
|
||||
atol=get_default_atol(out_key),
|
||||
rtol=get_default_rtol(out_key),
|
||||
)
|
||||
else:
|
||||
assert ref_key is None and out_key is None, "expected returned key to be None"
|
||||
|
||||
|
||||
@torch.inference_mode()
|
||||
def test_rope_module_cache():
|
||||
MAX_POSITIONS = [123, 1234]
|
||||
ROPE_THETAS = [10000, 1000000]
|
||||
ROPE_PARAMETERS = (
|
||||
{"rope_type": "default"},
|
||||
{"rope_type": "linear", "factor": (1,)},
|
||||
{"rope_type": "dynamic", "factor": 1},
|
||||
)
|
||||
settings = (
|
||||
HEAD_SIZES,
|
||||
ROTARY_DIMS,
|
||||
MAX_POSITIONS,
|
||||
ROPE_THETAS,
|
||||
IS_NEOX_STYLE,
|
||||
ROPE_PARAMETERS,
|
||||
DTYPES,
|
||||
)
|
||||
rope_setting_id_map: dict[str, int] = {}
|
||||
for setting in product(*settings):
|
||||
(
|
||||
head_size,
|
||||
rotary_dim,
|
||||
max_position,
|
||||
rope_theta,
|
||||
is_neox_style,
|
||||
rope_parameters,
|
||||
dtype,
|
||||
) = setting
|
||||
if rotary_dim is None:
|
||||
rotary_dim = head_size
|
||||
rope_parameters["rope_theta"] = rope_theta
|
||||
rope_parameters["partial_rotary_factor"] = rotary_dim / head_size
|
||||
rope = get_rope(
|
||||
head_size,
|
||||
max_position,
|
||||
is_neox_style,
|
||||
rope_parameters,
|
||||
dtype,
|
||||
)
|
||||
# different settings cannot share the same rope module
|
||||
assert id(rope) not in rope_setting_id_map.values()
|
||||
assert all(x.dtype == dtype for x in rope.buffers())
|
||||
assert all(x.dtype == dtype for x in rope.parameters())
|
||||
rope_setting_id_map[str(setting)] = id(rope)
|
||||
|
||||
for setting in product(*settings):
|
||||
(
|
||||
head_size,
|
||||
rotary_dim,
|
||||
max_position,
|
||||
rope_theta,
|
||||
is_neox_style,
|
||||
rope_parameters,
|
||||
dtype,
|
||||
) = setting
|
||||
if rotary_dim is None:
|
||||
rotary_dim = head_size
|
||||
rope_parameters["rope_theta"] = rope_theta
|
||||
rope_parameters["partial_rotary_factor"] = rotary_dim / head_size
|
||||
rope = get_rope(
|
||||
head_size,
|
||||
max_position,
|
||||
is_neox_style,
|
||||
rope_parameters,
|
||||
dtype,
|
||||
)
|
||||
# check if cache take effect
|
||||
assert id(rope) == rope_setting_id_map[str(setting)]
|
||||
76
tests/kernels/core/test_rotary_embedding.py
Normal file
76
tests/kernels/core/test_rotary_embedding.py
Normal file
@@ -0,0 +1,76 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""
|
||||
Tests for miscellaneous utilities
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from tests.kernels.utils import opcheck
|
||||
from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding
|
||||
|
||||
|
||||
def rotary_embedding_opcheck(
|
||||
rot,
|
||||
positions: torch.Tensor,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor | None = None,
|
||||
):
|
||||
cos_sin_cache = rot.cos_sin_cache.to(query.device, dtype=query.dtype)
|
||||
|
||||
# ops.rotary_embedding() is a in-place operation
|
||||
# that updates the query and key tensors.
|
||||
opcheck(
|
||||
torch.ops._C.rotary_embedding,
|
||||
(positions, query, key, rot.head_size, cos_sin_cache, rot.is_neox_style),
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("device", ["cuda"])
|
||||
@pytest.mark.parametrize("max_position", [11, 4096, 32768])
|
||||
@pytest.mark.parametrize("is_neox_style", [True, False])
|
||||
@pytest.mark.parametrize("rotary_dim", [32])
|
||||
@pytest.mark.parametrize("head_size", [32, 108])
|
||||
@pytest.mark.parametrize("seq_len", [11, 1024])
|
||||
@pytest.mark.parametrize("use_key", [True, False])
|
||||
@pytest.mark.parametrize("head_stride_is_contiguous", [True, False])
|
||||
def test_rotary_embedding_opcheck(
|
||||
dist_init,
|
||||
device,
|
||||
max_position,
|
||||
is_neox_style,
|
||||
rotary_dim,
|
||||
head_size,
|
||||
seq_len,
|
||||
use_key,
|
||||
head_stride_is_contiguous,
|
||||
):
|
||||
batch_size = 1
|
||||
base = 10000
|
||||
num_heads = 7
|
||||
rot = RotaryEmbedding(
|
||||
head_size, rotary_dim, max_position, base, is_neox_style, torch.float32
|
||||
)
|
||||
|
||||
positions = torch.randint(0, max_position, (batch_size, seq_len), device=device)
|
||||
head_stride = head_size + (64 if head_stride_is_contiguous else 0)
|
||||
|
||||
query = torch.randn(
|
||||
batch_size, seq_len, num_heads, head_stride, dtype=torch.float32, device=device
|
||||
)
|
||||
key = torch.randn_like(query) if use_key else None
|
||||
query = query[..., :head_size]
|
||||
key = key[..., :head_size] if use_key else None
|
||||
|
||||
rotary_embedding_opcheck(rot, positions, query, key)
|
||||
|
||||
# if we have a contiguous head stride, test the alternate
|
||||
# [..., num_heads * head_dim] shape/layout
|
||||
if head_stride_is_contiguous:
|
||||
rotary_embedding_opcheck(
|
||||
rot,
|
||||
positions,
|
||||
query.flatten(start_dim=-2),
|
||||
key.flatten(start_dim=-2) if use_key else None,
|
||||
)
|
||||
53
tests/kernels/core/test_uva.py
Normal file
53
tests/kernels/core/test_uva.py
Normal file
@@ -0,0 +1,53 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from vllm.utils.platform_utils import is_uva_available
|
||||
from vllm.utils.torch_utils import get_cuda_view_from_cpu_tensor
|
||||
|
||||
CUDA_DEVICES = [f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)]
|
||||
|
||||
|
||||
@pytest.mark.skipif(not is_uva_available(), reason="UVA is not available.")
|
||||
@pytest.mark.parametrize("device", CUDA_DEVICES)
|
||||
def test_cpu_write(device):
|
||||
torch.set_default_device(device)
|
||||
cpu_tensor = torch.zeros(10, 10, device="cpu", pin_memory=True, dtype=torch.int32)
|
||||
cuda_view = get_cuda_view_from_cpu_tensor(cpu_tensor)
|
||||
assert cuda_view.device.type == "cuda"
|
||||
|
||||
assert cuda_view[0, 0] == 0
|
||||
assert cuda_view[2, 3] == 0
|
||||
assert cuda_view[4, 5] == 0
|
||||
|
||||
cpu_tensor[0, 0] = 1
|
||||
cpu_tensor[2, 3] = 2
|
||||
cpu_tensor[4, 5] = -1
|
||||
|
||||
cuda_view.mul_(2)
|
||||
assert cuda_view[0, 0] == 2
|
||||
assert cuda_view[2, 3] == 4
|
||||
assert cuda_view[4, 5] == -2
|
||||
|
||||
|
||||
@pytest.mark.skipif(not is_uva_available(), reason="UVA is not available.")
|
||||
@pytest.mark.parametrize("device", CUDA_DEVICES)
|
||||
def test_gpu_write(device):
|
||||
torch.set_default_device(device)
|
||||
cpu_tensor = torch.zeros(10, 10, device="cpu", pin_memory=True, dtype=torch.int32)
|
||||
cuda_view = get_cuda_view_from_cpu_tensor(cpu_tensor)
|
||||
assert cuda_view.device.type == "cuda"
|
||||
|
||||
assert cuda_view[0, 0] == 0
|
||||
assert cuda_view[2, 3] == 0
|
||||
assert cuda_view[4, 5] == 0
|
||||
|
||||
cuda_view[0, 0] = 1
|
||||
cuda_view[2, 3] = 2
|
||||
cuda_view[4, 5] = -1
|
||||
cuda_view.mul_(2)
|
||||
|
||||
assert cpu_tensor[0, 0] == 2
|
||||
assert cpu_tensor[2, 3] == 4
|
||||
assert cpu_tensor[4, 5] == -2
|
||||
373
tests/kernels/mamba/test_causal_conv1d.py
Normal file
373
tests/kernels/mamba/test_causal_conv1d.py
Normal file
@@ -0,0 +1,373 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from einops import rearrange
|
||||
|
||||
from vllm.attention.backends.utils import PAD_SLOT_ID
|
||||
from vllm.model_executor.layers.mamba.ops.causal_conv1d import (
|
||||
causal_conv1d_fn,
|
||||
causal_conv1d_update,
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
|
||||
def causal_conv1d_ref(
|
||||
x: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
bias: torch.Tensor | None = None,
|
||||
initial_states: torch.Tensor | None = None,
|
||||
return_final_states: bool = False,
|
||||
final_states_out: torch.Tensor | None = None,
|
||||
activation: str | None = "silu",
|
||||
):
|
||||
"""
|
||||
x: (batch, dim, seqlen)
|
||||
weight: (dim, width)
|
||||
bias: (dim,)
|
||||
initial_states: (batch, dim, width - 1)
|
||||
final_states_out: (batch, dim, width - 1)
|
||||
|
||||
out: (batch, dim, seqlen)
|
||||
"""
|
||||
if activation not in [None, "silu", "swish"]:
|
||||
raise NotImplementedError("activation must be None, silu, or swish")
|
||||
dtype_in = x.dtype
|
||||
x = x.to(weight.dtype)
|
||||
seqlen = x.shape[-1]
|
||||
dim, width = weight.shape
|
||||
if initial_states is None:
|
||||
out = F.conv1d(x, weight.unsqueeze(1), bias, padding=width - 1, groups=dim)
|
||||
else:
|
||||
x = torch.cat([initial_states, x], dim=-1)
|
||||
out = F.conv1d(x, weight.unsqueeze(1), bias, padding=0, groups=dim)
|
||||
out = out[..., :seqlen]
|
||||
if return_final_states:
|
||||
final_states = F.pad(x, (width - 1 - x.shape[-1], 0)).to(
|
||||
dtype_in
|
||||
) # (batch, dim, width - 1)
|
||||
if final_states_out is not None:
|
||||
final_states_out.copy_(final_states)
|
||||
else:
|
||||
final_states_out = final_states
|
||||
out = (out if activation is None else F.silu(out)).to(dtype=dtype_in)
|
||||
return (out, None) if not return_final_states else (out, final_states_out)
|
||||
|
||||
|
||||
def causal_conv1d_update_ref(
|
||||
x, conv_state, weight, bias=None, activation=None, cache_seqlens=None
|
||||
):
|
||||
"""
|
||||
x: (batch, dim) or (batch, dim, seqlen)
|
||||
conv_state: (batch, dim, state_len), where state_len >= width - 1
|
||||
weight: (dim, width)
|
||||
bias: (dim,)
|
||||
cache_seqlens: (batch,), dtype int32.
|
||||
If not None, the conv_state is treated as a circular buffer.
|
||||
The conv_state will be updated by copying x to the
|
||||
conv_state starting at the index
|
||||
@cache_seqlens % state_len before performing the convolution.
|
||||
|
||||
out: (batch, dim) or (batch, dim, seqlen)
|
||||
"""
|
||||
if activation not in [None, "silu", "swish"]:
|
||||
raise NotImplementedError("activation must be None, silu, or swish")
|
||||
dtype_in = x.dtype
|
||||
unsqueeze = x.dim() == 2
|
||||
if unsqueeze:
|
||||
x = x.unsqueeze(-1)
|
||||
batch, dim, seqlen = x.shape
|
||||
width = weight.shape[1]
|
||||
state_len = conv_state.shape[-1]
|
||||
assert conv_state.shape == (batch, dim, state_len)
|
||||
assert weight.shape == (dim, width)
|
||||
if cache_seqlens is None:
|
||||
x_new = torch.cat([conv_state, x], dim=-1).to(
|
||||
weight.dtype
|
||||
) # (batch, dim, state_len + seqlen)
|
||||
conv_state.copy_(x_new[:, :, -state_len:])
|
||||
else:
|
||||
width_idx = torch.arange(
|
||||
-(width - 1), 0, dtype=torch.long, device=x.device
|
||||
).unsqueeze(0) + cache_seqlens.unsqueeze(1)
|
||||
width_idx = (
|
||||
torch.remainder(width_idx, state_len).unsqueeze(1).expand(-1, dim, -1)
|
||||
)
|
||||
x_new = torch.cat([conv_state.gather(2, width_idx), x], dim=-1).to(weight.dtype)
|
||||
copy_idx = torch.arange(seqlen, dtype=torch.long, device=x.device).unsqueeze(
|
||||
0
|
||||
) + cache_seqlens.unsqueeze(1)
|
||||
copy_idx = torch.remainder(copy_idx, state_len).unsqueeze(1).expand(-1, dim, -1)
|
||||
conv_state.scatter_(2, copy_idx, x)
|
||||
out = F.conv1d(x_new, weight.unsqueeze(1), bias, padding=0, groups=dim)[
|
||||
:, :, -seqlen:
|
||||
]
|
||||
if unsqueeze:
|
||||
out = out.squeeze(-1)
|
||||
return (out if activation is None else F.silu(out)).to(dtype=dtype_in)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("itype", [torch.bfloat16, torch.float])
|
||||
@pytest.mark.parametrize("silu_activation", [True])
|
||||
@pytest.mark.parametrize("has_bias", [True])
|
||||
def causal_conv1d_opcheck_fn(
|
||||
x: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
bias: torch.Tensor | None = None,
|
||||
cu_seq_len: torch.Tensor | None = None,
|
||||
cache_indices: torch.Tensor | None = None,
|
||||
has_initial_state: torch.Tensor | None = None,
|
||||
conv_states: torch.Tensor | None = None,
|
||||
activation: str | None = "silu",
|
||||
pad_slot_id: int = PAD_SLOT_ID,
|
||||
):
|
||||
"""
|
||||
x: (batch, dim, seqlen)
|
||||
weight: (dim, width)
|
||||
bias: (dim,)
|
||||
seq_idx: (batch, seqlen)
|
||||
initial_states: (batch, dim, width - 1)
|
||||
final_states_out: (batch, dim, width - 1), to be written to
|
||||
activation: either None or "silu" or "swish"
|
||||
|
||||
out: (batch, dim, seqlen)
|
||||
"""
|
||||
if activation not in [None, "silu", "swish"]:
|
||||
raise NotImplementedError("activation must be None, silu, or swish")
|
||||
if x.stride(-1) != 1:
|
||||
x = x.contiguous()
|
||||
bias = bias.contiguous() if bias is not None else None
|
||||
|
||||
|
||||
@pytest.mark.parametrize("itype", [torch.bfloat16])
|
||||
@pytest.mark.parametrize("silu_activation", [False, True])
|
||||
@pytest.mark.parametrize("has_bias", [False, True])
|
||||
@pytest.mark.parametrize("seqlen", [1])
|
||||
@pytest.mark.parametrize("width", [4])
|
||||
@pytest.mark.parametrize("dim", [2048, 2048 + 16, 4096])
|
||||
def test_causal_conv1d_update(dim, width, seqlen, has_bias, silu_activation, itype):
|
||||
device = "cuda"
|
||||
rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (3e-3, 5e-3)
|
||||
if itype == torch.bfloat16:
|
||||
rtol, atol = 1e-2, 5e-2
|
||||
# set seed
|
||||
current_platform.seed_everything(0)
|
||||
batch = 2
|
||||
x = torch.randn(batch, dim, seqlen, device=device, dtype=itype)
|
||||
x_ref = x.clone()
|
||||
conv_state = torch.randn(batch, dim, width - 1, device=device, dtype=itype)
|
||||
|
||||
weight = torch.randn(dim, width, device=device, dtype=itype)
|
||||
bias = torch.randn(dim, device=device, dtype=itype) if has_bias else None
|
||||
conv_state_ref = conv_state.detach().clone()
|
||||
activation = None if not silu_activation else "silu"
|
||||
|
||||
conv_state_indices = torch.arange(batch, dtype=torch.int32, device=device)
|
||||
|
||||
out = causal_conv1d_update(
|
||||
x,
|
||||
conv_state,
|
||||
weight,
|
||||
bias,
|
||||
activation=activation,
|
||||
conv_state_indices=conv_state_indices,
|
||||
)
|
||||
out_ref = causal_conv1d_update_ref(
|
||||
x_ref, conv_state_ref, weight, bias, activation=activation
|
||||
)
|
||||
|
||||
assert torch.equal(conv_state, conv_state_ref)
|
||||
assert torch.allclose(out, out_ref, rtol=rtol, atol=atol)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("itype", [torch.float32, torch.bfloat16])
|
||||
@pytest.mark.parametrize("silu_activation", [False, True])
|
||||
@pytest.mark.parametrize("has_bias", [False, True])
|
||||
@pytest.mark.parametrize("seqlen", [1, 3])
|
||||
@pytest.mark.parametrize("width", [3, 4])
|
||||
@pytest.mark.parametrize("dim", [2048 + 16, 4096])
|
||||
# tests correctness in case subset of the sequences are padded
|
||||
@pytest.mark.parametrize("with_padding", [True, False])
|
||||
@pytest.mark.parametrize("batch_size", [3])
|
||||
def test_causal_conv1d_update_with_batch_gather(
|
||||
batch_size, with_padding, dim, width, seqlen, has_bias, silu_activation, itype
|
||||
):
|
||||
device = "cuda"
|
||||
rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (3e-3, 5e-3)
|
||||
if itype == torch.bfloat16:
|
||||
rtol, atol = 1e-2, 5e-2
|
||||
|
||||
# set seed
|
||||
current_platform.seed_everything(0)
|
||||
|
||||
padding = 5 if with_padding else 0
|
||||
padded_batch_size = batch_size + padding
|
||||
# total_entries = number of cache line
|
||||
total_entries = 10 * batch_size
|
||||
|
||||
# x will be (batch, dim, seqlen) with contiguous along dim-axis
|
||||
x = torch.randn(
|
||||
padded_batch_size, seqlen, dim, device=device, dtype=itype
|
||||
).transpose(1, 2)
|
||||
|
||||
x_ref = x.clone()
|
||||
|
||||
conv_state_indices = torch.randperm(total_entries)[:batch_size].to(
|
||||
dtype=torch.int32, device=device
|
||||
)
|
||||
unused_states_bool = torch.ones(total_entries, dtype=torch.bool, device=device)
|
||||
unused_states_bool[conv_state_indices] = False
|
||||
padded_state_indices = torch.concat(
|
||||
[
|
||||
conv_state_indices,
|
||||
torch.as_tensor([PAD_SLOT_ID] * padding, dtype=torch.int32, device=device),
|
||||
],
|
||||
dim=0,
|
||||
)
|
||||
|
||||
# conv_state will be (cache_lines, dim, state_len)
|
||||
# with contiguous along dim-axis
|
||||
conv_state = torch.randn(
|
||||
total_entries, width - 1, dim, device=device, dtype=itype
|
||||
).transpose(1, 2)
|
||||
|
||||
conv_state_for_padding_test = conv_state.clone()
|
||||
|
||||
weight = torch.randn(dim, width, device=device, dtype=itype)
|
||||
bias = torch.randn(dim, device=device, dtype=itype) if has_bias else None
|
||||
conv_state_ref = conv_state[conv_state_indices, :].detach().clone()
|
||||
activation = None if not silu_activation else "silu"
|
||||
|
||||
out = causal_conv1d_update(
|
||||
x,
|
||||
conv_state,
|
||||
weight,
|
||||
bias,
|
||||
activation=activation,
|
||||
conv_state_indices=padded_state_indices,
|
||||
pad_slot_id=PAD_SLOT_ID,
|
||||
)
|
||||
out_ref = causal_conv1d_update_ref(
|
||||
x_ref[:batch_size], conv_state_ref, weight, bias, activation=activation
|
||||
)
|
||||
|
||||
assert torch.equal(conv_state[conv_state_indices, :], conv_state_ref)
|
||||
assert torch.equal(
|
||||
conv_state[unused_states_bool], conv_state_for_padding_test[unused_states_bool]
|
||||
)
|
||||
assert torch.allclose(out[:batch_size], out_ref, rtol=rtol, atol=atol)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("itype", [torch.bfloat16])
|
||||
@pytest.mark.parametrize("silu_activation", [True])
|
||||
@pytest.mark.parametrize("has_bias", [True])
|
||||
@pytest.mark.parametrize("width", [4])
|
||||
@pytest.mark.parametrize("seqlen", [8, 249, 4096])
|
||||
@pytest.mark.parametrize("dim", [64, 4096])
|
||||
@pytest.mark.parametrize("with_padding", [True, False])
|
||||
@pytest.mark.parametrize("batch", [4, 10])
|
||||
def test_causal_conv1d_varlen(
|
||||
batch, with_padding, dim, seqlen, width, has_bias, silu_activation, itype
|
||||
):
|
||||
device = "cuda"
|
||||
torch.cuda.empty_cache()
|
||||
rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (3e-3, 5e-3)
|
||||
if itype == torch.bfloat16:
|
||||
rtol, atol = 1e-2, 5e-2
|
||||
# set seed
|
||||
current_platform.seed_everything(0)
|
||||
seqlens = []
|
||||
batch_size = batch
|
||||
padding = 3 if with_padding else 0
|
||||
padded_batch_size = batch_size + padding
|
||||
nsplits = padded_batch_size - 1
|
||||
|
||||
eos_pos = torch.randperm(seqlen - 1)[:nsplits].sort().values
|
||||
|
||||
seqlens.append(
|
||||
torch.diff(
|
||||
torch.cat([torch.tensor([-1]), eos_pos, torch.tensor([seqlen - 1])])
|
||||
).tolist()
|
||||
)
|
||||
assert sum(seqlens[-1]) == seqlen
|
||||
assert all(s > 0 for s in seqlens[-1])
|
||||
|
||||
total_entries = batch_size * 10
|
||||
cumsum = torch.cumsum(torch.tensor(seqlens[0]), dim=0).to(torch.int32)
|
||||
cumsum = torch.concat([torch.tensor([0], dtype=torch.int32), cumsum], dim=0)
|
||||
x = rearrange(
|
||||
torch.randn(1, seqlen, 4096 + dim + 64, device=device, dtype=itype),
|
||||
"b s d -> b d s",
|
||||
)[:, 4096 : 4096 + dim, :]
|
||||
|
||||
weight = torch.randn(dim, width, device=device, dtype=itype)
|
||||
|
||||
bias = torch.randn(dim, device=device, dtype=itype) if has_bias else None
|
||||
x_ref = x.clone()
|
||||
weight_ref = weight.clone()
|
||||
bias_ref = bias.clone() if bias is not None else None
|
||||
activation = None if not silu_activation else "silu"
|
||||
final_states = torch.randn(
|
||||
total_entries, width - 1, dim, device=x.device, dtype=x.dtype
|
||||
).transpose(1, 2)
|
||||
final_states_ref = final_states.clone()
|
||||
has_initial_states = torch.randint(
|
||||
0, 2, (cumsum.shape[0] - 1,), dtype=torch.bool, device=x.device
|
||||
)
|
||||
state_indices = torch.randperm(total_entries, dtype=torch.int32, device=x.device)[
|
||||
:batch_size
|
||||
]
|
||||
padded_state_indices = torch.concat(
|
||||
[
|
||||
state_indices,
|
||||
torch.as_tensor([PAD_SLOT_ID] * padding, dtype=torch.int32, device=device),
|
||||
],
|
||||
dim=-1,
|
||||
)
|
||||
out = causal_conv1d_fn(
|
||||
x.squeeze(0),
|
||||
weight,
|
||||
bias=bias,
|
||||
conv_states=final_states,
|
||||
query_start_loc=cumsum.cuda(),
|
||||
cache_indices=padded_state_indices,
|
||||
has_initial_state=has_initial_states,
|
||||
activation=activation,
|
||||
pad_slot_id=PAD_SLOT_ID,
|
||||
)
|
||||
|
||||
out_ref = []
|
||||
out_ref_b = []
|
||||
|
||||
splits = [torch.split(var, seqlens[0], dim=-1) for var in (x_ref)]
|
||||
for i in range(len(seqlens[0])):
|
||||
x_s = [v[i].unsqueeze(0) for v in splits][0]
|
||||
if padded_state_indices[i] == PAD_SLOT_ID:
|
||||
continue
|
||||
out_ref_b.append(
|
||||
causal_conv1d_ref(
|
||||
x_s,
|
||||
weight_ref,
|
||||
bias_ref,
|
||||
activation=activation,
|
||||
return_final_states=True,
|
||||
final_states_out=final_states_ref[padded_state_indices[i]].unsqueeze(0),
|
||||
initial_states=final_states_ref[padded_state_indices[i]].unsqueeze(0)
|
||||
if has_initial_states[i]
|
||||
else None,
|
||||
)
|
||||
)
|
||||
out_ref.append(torch.cat([t[0] for t in out_ref_b], dim=2))
|
||||
out_ref_tensor = torch.cat(out_ref, dim=0)
|
||||
|
||||
assert torch.allclose(
|
||||
final_states[state_indices],
|
||||
final_states_ref[state_indices],
|
||||
rtol=rtol,
|
||||
atol=atol,
|
||||
)
|
||||
unpadded_out = out[:, : out_ref_tensor.shape[-1]]
|
||||
assert torch.allclose(unpadded_out, out_ref_tensor, rtol=rtol, atol=atol)
|
||||
137
tests/kernels/mamba/test_mamba_mixer2.py
Normal file
137
tests/kernels/mamba/test_mamba_mixer2.py
Normal file
@@ -0,0 +1,137 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import unittest
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from tests.utils import multi_gpu_test
|
||||
from vllm.distributed.parallel_state import (
|
||||
init_distributed_environment,
|
||||
initialize_model_parallel,
|
||||
)
|
||||
from vllm.model_executor.layers.mamba.mamba_mixer2 import Mixer2RMSNormGated
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.system_utils import update_environment_variables
|
||||
|
||||
|
||||
@multi_gpu_test(num_gpus=2)
|
||||
@pytest.mark.parametrize("batch_size", [8])
|
||||
@pytest.mark.parametrize("seq_len", [128])
|
||||
@pytest.mark.parametrize(
|
||||
"hidden_size_n_groups",
|
||||
[
|
||||
(64, 1),
|
||||
(64, 2),
|
||||
(64, 4), # hidden_size be divisible by num_gpus
|
||||
],
|
||||
)
|
||||
@pytest.mark.parametrize("dtype", [torch.float16])
|
||||
def test_mixer2_gated_norm_multi_gpu(
|
||||
batch_size: int,
|
||||
seq_len: int,
|
||||
hidden_size_n_groups: tuple[int, int],
|
||||
dtype: torch.dtype,
|
||||
device: str = "cuda",
|
||||
):
|
||||
hidden_size, n_groups = hidden_size_n_groups
|
||||
num_processes = 2
|
||||
|
||||
def run_torch_spawn(fn, nprocs):
|
||||
# need to use torch.mp.spawn otherwise will have problems with
|
||||
# torch.distributed and cuda
|
||||
torch.multiprocessing.spawn(
|
||||
fn,
|
||||
args=(
|
||||
num_processes,
|
||||
batch_size,
|
||||
seq_len,
|
||||
hidden_size,
|
||||
n_groups,
|
||||
dtype,
|
||||
device,
|
||||
),
|
||||
nprocs=nprocs,
|
||||
)
|
||||
|
||||
run_torch_spawn(mixer2_gated_norm_tensor_parallel, 2)
|
||||
|
||||
|
||||
def mixer2_gated_norm_tensor_parallel(
|
||||
local_rank: int,
|
||||
world_size: int,
|
||||
batch_size: int,
|
||||
seq_len: int,
|
||||
hidden_size: int,
|
||||
n_groups: int,
|
||||
dtype: torch.dtype,
|
||||
device: str,
|
||||
):
|
||||
current_platform.seed_everything(0)
|
||||
|
||||
device = torch.device(f"cuda:{local_rank}")
|
||||
torch.cuda.set_device(device)
|
||||
torch.set_default_device(device)
|
||||
torch.set_default_dtype(dtype)
|
||||
|
||||
update_environment_variables(
|
||||
{
|
||||
"RANK": str(local_rank),
|
||||
"LOCAL_RANK": str(local_rank),
|
||||
"WORLD_SIZE": str(world_size),
|
||||
"MASTER_ADDR": "localhost",
|
||||
"MASTER_PORT": "12345",
|
||||
}
|
||||
)
|
||||
|
||||
# initialize distributed
|
||||
init_distributed_environment()
|
||||
initialize_model_parallel(tensor_model_parallel_size=world_size)
|
||||
|
||||
# create random weights an inputs
|
||||
weight = torch.rand((hidden_size,), dtype=dtype, device=device)
|
||||
hidden_states = torch.randn(batch_size, seq_len, hidden_size)
|
||||
gate_states = torch.randn(batch_size, seq_len, hidden_size)
|
||||
|
||||
# create gated-norm with TP
|
||||
mixer = Mixer2RMSNormGated(
|
||||
full_hidden_size=hidden_size,
|
||||
full_n_groups=n_groups,
|
||||
)
|
||||
mixer.weight.weight_loader(mixer.weight, weight) # load
|
||||
|
||||
# create gated-norm without TP to compute reference
|
||||
# - utilize mock patching to disable TP when
|
||||
with (
|
||||
unittest.mock.patch(
|
||||
"vllm.model_executor.layers.mamba.mamba_mixer2."
|
||||
"get_tensor_model_parallel_world_size",
|
||||
return_value=1,
|
||||
),
|
||||
unittest.mock.patch(
|
||||
"vllm.model_executor.layers.mamba.mamba_mixer2."
|
||||
"get_tensor_model_parallel_rank",
|
||||
return_value=0,
|
||||
),
|
||||
):
|
||||
mixer_single_gpu = Mixer2RMSNormGated(
|
||||
full_hidden_size=hidden_size,
|
||||
full_n_groups=n_groups,
|
||||
)
|
||||
# assign weight to single-gpu mixer
|
||||
mixer_single_gpu.weight.data = weight
|
||||
|
||||
# generate and compare
|
||||
N = hidden_size // world_size
|
||||
output = mixer(
|
||||
hidden_states[..., local_rank * N : (local_rank + 1) * N],
|
||||
gate_states[..., local_rank * N : (local_rank + 1) * N],
|
||||
)
|
||||
ref_output = mixer_single_gpu(hidden_states, gate_states)
|
||||
torch.testing.assert_close(
|
||||
output,
|
||||
ref_output[..., local_rank * N : (local_rank + 1) * N],
|
||||
atol=5e-3,
|
||||
rtol=1e-3,
|
||||
)
|
||||
1093
tests/kernels/mamba/test_mamba_ssm.py
Normal file
1093
tests/kernels/mamba/test_mamba_ssm.py
Normal file
File diff suppressed because it is too large
Load Diff
569
tests/kernels/mamba/test_mamba_ssm_ssd.py
Normal file
569
tests/kernels/mamba/test_mamba_ssm_ssd.py
Normal file
@@ -0,0 +1,569 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from einops import rearrange, repeat
|
||||
|
||||
from vllm.model_executor.layers.mamba.ops.ssd_combined import (
|
||||
mamba_chunk_scan_combined_varlen,
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.v1.attention.backends.mamba2_attn import compute_varlen_chunk_metadata
|
||||
|
||||
# Added by the IBM Team, 2024
|
||||
|
||||
# Adapted from https://github.com/state-spaces/mamba/blob/v2.2.4/mamba_ssm/modules/ssd_minimal.py
|
||||
|
||||
|
||||
# this is the segsum implementation taken from above
|
||||
def segsum(x):
|
||||
"""Calculates segment sum."""
|
||||
T = x.size(-1)
|
||||
x = repeat(x, "... d -> ... d e", e=T)
|
||||
mask = torch.tril(torch.ones(T, T, device=x.device, dtype=bool), diagonal=-1)
|
||||
x = x.masked_fill(~mask, 0)
|
||||
x_segsum = torch.cumsum(x, dim=-2)
|
||||
mask = torch.tril(torch.ones(T, T, device=x.device, dtype=bool), diagonal=0)
|
||||
x_segsum = x_segsum.masked_fill(~mask, -torch.inf)
|
||||
return x_segsum
|
||||
|
||||
|
||||
def ssd_minimal_discrete(X, A, B, C, block_len, initial_states=None):
|
||||
"""
|
||||
Arguments:
|
||||
X: (batch, length, n_heads, d_head)
|
||||
A: (batch, length, n_heads)
|
||||
B: (batch, length, n_heads, d_state)
|
||||
C: (batch, length, n_heads, d_state)
|
||||
Return:
|
||||
Y: (batch, length, n_heads, d_head)
|
||||
"""
|
||||
assert X.dtype == A.dtype == B.dtype == C.dtype
|
||||
assert X.shape[1] % block_len == 0
|
||||
|
||||
# Rearrange into blocks/chunks
|
||||
X, A, B, C = (
|
||||
rearrange(x, "b (c l) ... -> b c l ...", l=block_len) for x in (X, A, B, C)
|
||||
)
|
||||
|
||||
A = rearrange(A, "b c l h -> b h c l")
|
||||
A_cumsum = torch.cumsum(A, dim=-1)
|
||||
|
||||
# 1. Compute the output for each intra-chunk (diagonal blocks)
|
||||
L = torch.exp(segsum(A))
|
||||
Y_diag = torch.einsum("bclhn,bcshn,bhcls,bcshp->bclhp", C, B, L, X)
|
||||
|
||||
# 2. Compute the state for each intra-chunk
|
||||
# (right term of low-rank factorization of off-diagonal blocks; B terms)
|
||||
decay_states = torch.exp(A_cumsum[:, :, :, -1:] - A_cumsum)
|
||||
states = torch.einsum("bclhn,bhcl,bclhp->bchpn", B, decay_states, X)
|
||||
|
||||
# 3. Compute the inter-chunk SSM recurrence; produces correct SSM states at
|
||||
# chunk boundaries
|
||||
# (middle term of factorization of off-diag blocks; A terms)
|
||||
if initial_states is None:
|
||||
initial_states = torch.zeros_like(states[:, :1])
|
||||
states = torch.cat([initial_states, states], dim=1)
|
||||
decay_chunk = torch.exp(segsum(F.pad(A_cumsum[:, :, :, -1], (1, 0))))
|
||||
new_states = torch.einsum("bhzc,bchpn->bzhpn", decay_chunk, states)
|
||||
states, final_state = new_states[:, :-1], new_states[:, -1]
|
||||
|
||||
# 4. Compute state -> output conversion per chunk
|
||||
# (left term of low-rank factorization of off-diagonal blocks; C terms)
|
||||
state_decay_out = torch.exp(A_cumsum)
|
||||
Y_off = torch.einsum("bclhn,bchpn,bhcl->bclhp", C, states, state_decay_out)
|
||||
|
||||
# Add output of intra-chunk and inter-chunk terms
|
||||
# (diagonal and off-diagonal blocks)
|
||||
Y = rearrange(Y_diag + Y_off, "b c l h p -> b (c l) h p")
|
||||
return Y, final_state
|
||||
|
||||
|
||||
def generate_random_inputs(batch_size, seqlen, n_heads, d_head, itype, device="cuda"):
|
||||
current_platform.seed_everything(0)
|
||||
A = -torch.exp(torch.rand(n_heads, dtype=itype, device=device))
|
||||
dt = F.softplus(
|
||||
torch.randn(batch_size, seqlen, n_heads, dtype=itype, device=device) - 4
|
||||
)
|
||||
X = torch.randn((batch_size, seqlen, n_heads, d_head), dtype=itype, device=device)
|
||||
B = torch.randn((batch_size, seqlen, n_heads, d_head), dtype=itype, device=device)
|
||||
C = torch.randn((batch_size, seqlen, n_heads, d_head), dtype=itype, device=device)
|
||||
|
||||
return A, dt, X, B, C
|
||||
|
||||
|
||||
def generate_continuous_batched_examples(
|
||||
example_lens_by_batch,
|
||||
num_examples,
|
||||
full_length,
|
||||
last_taken,
|
||||
exhausted,
|
||||
n_heads,
|
||||
d_head,
|
||||
itype,
|
||||
device="cuda",
|
||||
return_naive_ref=True,
|
||||
):
|
||||
# this function generates a random examples of certain length
|
||||
# and then cut according to "example_lens_by_batch" and feed
|
||||
# them in continuous batches to the kernels.
|
||||
# If if return_naive_ref=True, the naive torch implementation
|
||||
# ssd_minimal_discrete will be used to compute and return
|
||||
# reference output.
|
||||
|
||||
# generate the full-length example
|
||||
A, dt, X, B, C = generate_random_inputs(
|
||||
num_examples, full_length, n_heads, d_head, itype
|
||||
)
|
||||
|
||||
if return_naive_ref:
|
||||
Y_min, final_state_min = ssd_minimal_discrete(
|
||||
X * dt.unsqueeze(-1), A * dt, B, C, block_len=full_length // 4
|
||||
)
|
||||
|
||||
# internal function that outputs a cont batch of examples
|
||||
# given a tuple of lengths for each example in the batch
|
||||
# e.g., example_lens=(8, 4) means take 8 samples from first eg,
|
||||
# 4 examples from second eg, etc
|
||||
def get_continuous_batch(example_lens: tuple[int, ...]):
|
||||
indices = []
|
||||
for i, x in enumerate(example_lens):
|
||||
c = last_taken.get(i, 0)
|
||||
indices.append((c, c + x))
|
||||
last_taken[i] = (c + x) % full_length
|
||||
exhausted[i] = last_taken[i] == 0
|
||||
|
||||
return (
|
||||
torch.concat([x[i, s:e] for i, (s, e) in enumerate(indices)]).unsqueeze(0)
|
||||
for x in (dt, X, B, C)
|
||||
)
|
||||
|
||||
# internal function that maps "n" to the appropriate right boundary
|
||||
# value when forming continuous batches from examples of length given
|
||||
# by "full_length".
|
||||
# - e.g., when n > full_length, returns n % full_length
|
||||
# when n == full_length, returns full_length
|
||||
def end_boundary(n: int):
|
||||
return n - ((n - 1) // full_length) * full_length
|
||||
|
||||
IND_E = None
|
||||
for spec in example_lens_by_batch:
|
||||
# get the (maybe partial) example seen in this cont batch
|
||||
dt2, X2, B2, C2 = get_continuous_batch(spec)
|
||||
|
||||
# get the metadata
|
||||
cu_seqlens = torch.tensor((0,) + spec, device=device).cumsum(dim=0)
|
||||
seq_idx = torch.zeros(
|
||||
cu_seqlens[-1], dtype=torch.int32, device=cu_seqlens.device
|
||||
)
|
||||
for i, (srt, end) in enumerate(
|
||||
zip(
|
||||
cu_seqlens,
|
||||
cu_seqlens[1:],
|
||||
)
|
||||
):
|
||||
seq_idx[srt:end] = i
|
||||
|
||||
# for cont batch
|
||||
if IND_E is None:
|
||||
IND_S = [0 for _ in range(len(spec))]
|
||||
else:
|
||||
IND_S = [x % full_length for x in IND_E]
|
||||
IND_E = [end_boundary(x + y) for x, y in zip(IND_S, spec)]
|
||||
|
||||
# varlen has implicit batch=1
|
||||
dt2 = dt2.squeeze(0)
|
||||
X2 = X2.squeeze(0)
|
||||
B2 = B2.squeeze(0)
|
||||
C2 = C2.squeeze(0)
|
||||
yield (
|
||||
[Y_min[s, IND_S[s] : IND_E[s]] for s in range(num_examples)]
|
||||
if return_naive_ref
|
||||
else None,
|
||||
cu_seqlens,
|
||||
seq_idx,
|
||||
(A, dt2, X2, B2, C2),
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("itype", [torch.float32, torch.bfloat16])
|
||||
@pytest.mark.parametrize("n_heads", [4, 16, 32])
|
||||
@pytest.mark.parametrize("d_head", [5, 8, 32, 128])
|
||||
@pytest.mark.parametrize("seq_len_chunk_size", [(112, 16), (128, 32)])
|
||||
def test_mamba_chunk_scan_single_example(d_head, n_heads, seq_len_chunk_size, itype):
|
||||
# this tests the kernels on a single example (bs=1)
|
||||
|
||||
# TODO: the bfloat16 case requires higher thresholds. To be investigated
|
||||
|
||||
if itype == torch.bfloat16:
|
||||
atol, rtol = 5e-2, 5e-2
|
||||
else:
|
||||
atol, rtol = 8e-3, 5e-3
|
||||
|
||||
# set seed
|
||||
batch_size = 1 # batch_size
|
||||
# ssd_minimal_discrete requires chunk_size divide seqlen
|
||||
# - this is only required for generating the reference seqs,
|
||||
# it is not an operational limitation.
|
||||
seqlen, chunk_size = seq_len_chunk_size
|
||||
|
||||
A, dt, X, B, C = generate_random_inputs(batch_size, seqlen, n_heads, d_head, itype)
|
||||
|
||||
Y_min, final_state_min = ssd_minimal_discrete(
|
||||
X * dt.unsqueeze(-1), A * dt, B, C, chunk_size
|
||||
)
|
||||
|
||||
cu_seqlens = torch.tensor((0, seqlen), device="cuda").cumsum(dim=0)
|
||||
cu_chunk_seqlens, last_chunk_indices, seq_idx_chunks = (
|
||||
compute_varlen_chunk_metadata(cu_seqlens, chunk_size)
|
||||
)
|
||||
# varlen has implicit batch=1
|
||||
X = X.squeeze(0)
|
||||
dt = dt.squeeze(0)
|
||||
A = A.squeeze(0)
|
||||
B = B.squeeze(0)
|
||||
C = C.squeeze(0)
|
||||
Y = torch.empty_like(X)
|
||||
final_state = mamba_chunk_scan_combined_varlen(
|
||||
X,
|
||||
dt,
|
||||
A,
|
||||
B,
|
||||
C,
|
||||
chunk_size,
|
||||
cu_seqlens=cu_seqlens.to(torch.int32),
|
||||
cu_chunk_seqlens=cu_chunk_seqlens,
|
||||
last_chunk_indices=last_chunk_indices,
|
||||
seq_idx=seq_idx_chunks,
|
||||
out=Y,
|
||||
D=None,
|
||||
)
|
||||
|
||||
# just test the last in sequence
|
||||
torch.testing.assert_close(Y[-1], Y_min[0, -1], atol=atol, rtol=rtol)
|
||||
|
||||
# just test the last head
|
||||
# NOTE, in the kernel we always cast states to fp32
|
||||
torch.testing.assert_close(
|
||||
final_state[:, -1].to(torch.float32),
|
||||
final_state_min[:, -1].to(torch.float32),
|
||||
atol=atol,
|
||||
rtol=rtol,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("itype", [torch.float32])
|
||||
@pytest.mark.parametrize("n_heads", [4, 8])
|
||||
@pytest.mark.parametrize("d_head", [5, 16, 32])
|
||||
@pytest.mark.parametrize(
|
||||
"seq_len_chunk_size_cases",
|
||||
[
|
||||
# small-ish chunk_size (8)
|
||||
(64, 8, 2, [(64, 32), (64, 32)]),
|
||||
(64, 8, 2, [(8, 8), (8, 8), (8, 8)]), # chunk size boundary
|
||||
(
|
||||
64,
|
||||
8,
|
||||
2,
|
||||
[(4, 4), (4, 4), (4, 4), (4, 4)],
|
||||
), # chunk_size larger than cont batches
|
||||
(64, 8, 5, [(64, 32, 16, 8, 8)]),
|
||||
# large-ish chunk_size (256)
|
||||
(64, 256, 1, [(5,), (1,), (1,), (1,)]), # irregular sizes with small sequences
|
||||
(
|
||||
64,
|
||||
256,
|
||||
2,
|
||||
[(5, 30), (1, 2), (1, 2), (1, 2)],
|
||||
), # irregular sizes with small sequences
|
||||
# we also need to test some large seqlen
|
||||
# to catch errors with init states decay
|
||||
(768, 128, 2, [(138, 225), (138, 225)]),
|
||||
],
|
||||
)
|
||||
def test_mamba_chunk_scan_cont_batch(d_head, n_heads, seq_len_chunk_size_cases, itype):
|
||||
# this test with multiple examples in a continuous batch
|
||||
# (i.e. chunked prefill)
|
||||
|
||||
seqlen, chunk_size, num_examples, cases = seq_len_chunk_size_cases
|
||||
|
||||
# This test can have larger error for longer sequences
|
||||
if seqlen > 256:
|
||||
atol, rtol = 1e-2, 5e-3
|
||||
else:
|
||||
atol, rtol = 5e-3, 5e-3
|
||||
|
||||
# hold state during the cutting process so we know if an
|
||||
# example has been exhausted and needs to cycle
|
||||
last_taken: dict = {} # map: eg -> pointer to last taken sample
|
||||
exhausted: dict = {} # map: eg -> boolean indicating example is exhausted
|
||||
|
||||
states = None
|
||||
for Y_min, cu_seqlens, _token_seq_idx, (
|
||||
A,
|
||||
dt,
|
||||
X,
|
||||
B,
|
||||
C,
|
||||
) in generate_continuous_batched_examples(
|
||||
cases, num_examples, seqlen, last_taken, exhausted, n_heads, d_head, itype
|
||||
):
|
||||
cu_chunk_seqlens, last_chunk_indices, seq_idx_chunks = (
|
||||
compute_varlen_chunk_metadata(cu_seqlens, chunk_size)
|
||||
)
|
||||
|
||||
Y = torch.empty_like(X)
|
||||
new_states = mamba_chunk_scan_combined_varlen(
|
||||
X,
|
||||
dt,
|
||||
A,
|
||||
B,
|
||||
C,
|
||||
chunk_size,
|
||||
cu_seqlens=cu_seqlens.to(torch.int32),
|
||||
cu_chunk_seqlens=cu_chunk_seqlens,
|
||||
last_chunk_indices=last_chunk_indices,
|
||||
seq_idx=seq_idx_chunks,
|
||||
out=Y,
|
||||
D=None,
|
||||
initial_states=states,
|
||||
)
|
||||
|
||||
# just test the last in sequence
|
||||
for i in range(num_examples):
|
||||
# just test one dim and dstate
|
||||
Y_eg = Y[cu_seqlens[i] : cu_seqlens[i + 1], 0, 0]
|
||||
Y_min_eg = Y_min[i][:, 0, 0]
|
||||
torch.testing.assert_close(Y_eg, Y_min_eg, atol=atol, rtol=rtol)
|
||||
|
||||
# update states
|
||||
states = new_states
|
||||
for i, clear in exhausted.items():
|
||||
if clear:
|
||||
states[i].fill_(0.0)
|
||||
exhausted[i] = False
|
||||
|
||||
|
||||
@pytest.mark.parametrize("chunk_size", [8, 256])
|
||||
@pytest.mark.parametrize(
|
||||
"seqlens",
|
||||
[(16, 20), (270, 88, 212, 203)],
|
||||
)
|
||||
def test_mamba_chunk_scan_cont_batch_prefill_chunking(chunk_size, seqlens):
|
||||
# This test verifies the correctness of the chunked prefill implementation
|
||||
# in the mamba2 ssd kernels, by comparing concatenation (in the sequence
|
||||
# dimension) of chunked results with the full sequence result.
|
||||
# It is different from test_mamba_chunk_scan_cont_batch by:
|
||||
# 1. Not using the naive torch implementation (ssd_minimal_discrete) to get
|
||||
# reference outputs. Instead, it compares chunked kernel outputs to full
|
||||
# sequence kernel outputs. This is the most straightforward way to
|
||||
# assert chunked prefill correctness.
|
||||
# 2. It focuses on cases where sequences change in the middle of mamba
|
||||
# chunks, and not necessarily on chunk boundaries.
|
||||
|
||||
max_seqlen = max(seqlens)
|
||||
# This test can have larger error for longer sequences
|
||||
if max_seqlen > 256:
|
||||
atol, rtol = 1e-2, 5e-3
|
||||
else:
|
||||
atol, rtol = 5e-3, 5e-3
|
||||
|
||||
num_sequences = len(seqlens)
|
||||
n_heads = 16
|
||||
d_head = 64
|
||||
itype = torch.float32
|
||||
|
||||
# hold state during the cutting process so we know if an
|
||||
# example has been exhausted and needs to cycle
|
||||
last_taken: dict = {} # map: eg -> pointer to last taken sample
|
||||
exhausted: dict = {} # map: eg -> boolean indicating example is exhausted
|
||||
_, cu_seqlens, seq_idx, (A, dt, X, B, C) = next(
|
||||
generate_continuous_batched_examples(
|
||||
[seqlens],
|
||||
num_sequences,
|
||||
max_seqlen,
|
||||
last_taken,
|
||||
exhausted,
|
||||
n_heads,
|
||||
d_head,
|
||||
itype,
|
||||
return_naive_ref=False,
|
||||
)
|
||||
)
|
||||
seqlens = torch.tensor(seqlens, dtype=torch.int32, device=X.device)
|
||||
device = X.device
|
||||
|
||||
## full seqlen computation
|
||||
cu_chunk_seqlens, last_chunk_indices, seq_idx_chunks = (
|
||||
compute_varlen_chunk_metadata(cu_seqlens, chunk_size)
|
||||
)
|
||||
Y_ref = torch.empty_like(X)
|
||||
state_ref = mamba_chunk_scan_combined_varlen(
|
||||
X,
|
||||
dt,
|
||||
A,
|
||||
B,
|
||||
C,
|
||||
chunk_size,
|
||||
cu_seqlens=cu_seqlens.to(torch.int32),
|
||||
cu_chunk_seqlens=cu_chunk_seqlens,
|
||||
last_chunk_indices=last_chunk_indices,
|
||||
seq_idx=seq_idx_chunks,
|
||||
out=Y_ref,
|
||||
D=None,
|
||||
initial_states=None,
|
||||
)
|
||||
|
||||
## chunked seqlen computation
|
||||
# first chunk
|
||||
chunked_seqlens = seqlens // 2
|
||||
chunked_cu_seqlens = torch.cat(
|
||||
[torch.tensor([0], device=device), torch.cumsum(chunked_seqlens, dim=0)], dim=0
|
||||
)
|
||||
chunked_input_seq_len = chunked_cu_seqlens[-1]
|
||||
X_chunked = torch.zeros_like(X)[:chunked_input_seq_len, ...]
|
||||
dt_chunked = torch.zeros_like(dt)[:chunked_input_seq_len, ...]
|
||||
B_chunked = torch.zeros_like(B)[:chunked_input_seq_len, ...]
|
||||
C_chunked = torch.zeros_like(C)[:chunked_input_seq_len, ...]
|
||||
for i in range(num_sequences):
|
||||
chunk_f = lambda x, i: x[
|
||||
cu_seqlens[i] : cu_seqlens[i] + chunked_seqlens[i], ...
|
||||
]
|
||||
|
||||
X_chunked[chunked_cu_seqlens[i] : chunked_cu_seqlens[i + 1], ...] = chunk_f(
|
||||
X, i
|
||||
)
|
||||
dt_chunked[chunked_cu_seqlens[i] : chunked_cu_seqlens[i + 1], ...] = chunk_f(
|
||||
dt, i
|
||||
)
|
||||
B_chunked[chunked_cu_seqlens[i] : chunked_cu_seqlens[i + 1], ...] = chunk_f(
|
||||
B, i
|
||||
)
|
||||
C_chunked[chunked_cu_seqlens[i] : chunked_cu_seqlens[i + 1], ...] = chunk_f(
|
||||
C, i
|
||||
)
|
||||
|
||||
cu_chunk_seqlens, last_chunk_indices, seq_idx_chunks = (
|
||||
compute_varlen_chunk_metadata(chunked_cu_seqlens, chunk_size)
|
||||
)
|
||||
Y_partial = torch.empty_like(X_chunked)
|
||||
partial_state = mamba_chunk_scan_combined_varlen(
|
||||
X_chunked,
|
||||
dt_chunked,
|
||||
A,
|
||||
B_chunked,
|
||||
C_chunked,
|
||||
chunk_size,
|
||||
cu_seqlens=chunked_cu_seqlens.to(torch.int32),
|
||||
cu_chunk_seqlens=cu_chunk_seqlens,
|
||||
last_chunk_indices=last_chunk_indices,
|
||||
seq_idx=seq_idx_chunks,
|
||||
out=Y_partial,
|
||||
D=None,
|
||||
initial_states=None,
|
||||
)
|
||||
|
||||
# remaining chunk
|
||||
remaining_chunked_seqlens = seqlens - chunked_seqlens
|
||||
remaining_chunked_cu_seqlens = torch.cat(
|
||||
[
|
||||
torch.tensor([0], device=device),
|
||||
torch.cumsum(remaining_chunked_seqlens, dim=0),
|
||||
],
|
||||
dim=0,
|
||||
)
|
||||
remaining_chunked_input_seq_len = remaining_chunked_cu_seqlens[-1]
|
||||
remaining_X_chunked = torch.zeros_like(X)[:remaining_chunked_input_seq_len, ...]
|
||||
remaining_dt_chunked = torch.zeros_like(dt)[:remaining_chunked_input_seq_len, ...]
|
||||
remaining_B_chunked = torch.zeros_like(B)[:remaining_chunked_input_seq_len, ...]
|
||||
remaining_C_chunked = torch.zeros_like(C)[:remaining_chunked_input_seq_len, ...]
|
||||
for i in range(num_sequences):
|
||||
remaining_chunk_f = lambda x, i: x[
|
||||
cu_seqlens[i] + chunked_seqlens[i] : cu_seqlens[i + 1], ...
|
||||
]
|
||||
|
||||
remaining_X_chunked[
|
||||
remaining_chunked_cu_seqlens[i] : remaining_chunked_cu_seqlens[i + 1], ...
|
||||
] = remaining_chunk_f(X, i)
|
||||
remaining_dt_chunked[
|
||||
remaining_chunked_cu_seqlens[i] : remaining_chunked_cu_seqlens[i + 1], ...
|
||||
] = remaining_chunk_f(dt, i)
|
||||
remaining_B_chunked[
|
||||
remaining_chunked_cu_seqlens[i] : remaining_chunked_cu_seqlens[i + 1], ...
|
||||
] = remaining_chunk_f(B, i)
|
||||
remaining_C_chunked[
|
||||
remaining_chunked_cu_seqlens[i] : remaining_chunked_cu_seqlens[i + 1], ...
|
||||
] = remaining_chunk_f(C, i)
|
||||
|
||||
# assert input chunking is correct
|
||||
concat_chunk_f = lambda pt1, pt2, i: torch.cat(
|
||||
[
|
||||
pt1[chunked_cu_seqlens[i] : chunked_cu_seqlens[i + 1], ...],
|
||||
pt2[
|
||||
remaining_chunked_cu_seqlens[i] : remaining_chunked_cu_seqlens[i + 1],
|
||||
...,
|
||||
],
|
||||
],
|
||||
dim=0,
|
||||
)
|
||||
concat_batch_f = lambda pt1, pt2: torch.cat(
|
||||
[concat_chunk_f(pt1, pt2, i) for i in range(num_sequences)], dim=0
|
||||
)
|
||||
|
||||
assert concat_batch_f(X_chunked, remaining_X_chunked).equal(X)
|
||||
assert concat_batch_f(dt_chunked, remaining_dt_chunked).equal(dt)
|
||||
assert concat_batch_f(B_chunked, remaining_B_chunked).equal(B)
|
||||
assert concat_batch_f(C_chunked, remaining_C_chunked).equal(C)
|
||||
|
||||
cu_chunk_seqlens, last_chunk_indices, seq_idx_chunks = (
|
||||
compute_varlen_chunk_metadata(remaining_chunked_cu_seqlens, chunk_size)
|
||||
)
|
||||
|
||||
Y_chunked = torch.empty_like(remaining_X_chunked)
|
||||
state_chunked = mamba_chunk_scan_combined_varlen(
|
||||
remaining_X_chunked,
|
||||
remaining_dt_chunked,
|
||||
A,
|
||||
remaining_B_chunked,
|
||||
remaining_C_chunked,
|
||||
chunk_size,
|
||||
cu_seqlens=remaining_chunked_cu_seqlens.to(torch.int32),
|
||||
cu_chunk_seqlens=cu_chunk_seqlens,
|
||||
last_chunk_indices=last_chunk_indices,
|
||||
seq_idx=seq_idx_chunks,
|
||||
out=Y_chunked,
|
||||
D=None,
|
||||
initial_states=partial_state,
|
||||
)
|
||||
Y = concat_batch_f(Y_partial, Y_chunked)
|
||||
|
||||
# kernel chunked is same as kernel overall
|
||||
for i in range(num_sequences):
|
||||
Y_seq = Y[cu_seqlens[i] : cu_seqlens[i + 1], ...]
|
||||
Y_ref_seq = Y_ref[cu_seqlens[i] : cu_seqlens[i + 1], ...]
|
||||
torch.testing.assert_close(
|
||||
Y_seq[: chunked_seqlens[i], ...],
|
||||
Y_ref_seq[: chunked_seqlens[i], ...],
|
||||
atol=atol,
|
||||
rtol=rtol,
|
||||
msg=lambda x, i=i: f"seq{i} output part1 " + x,
|
||||
)
|
||||
torch.testing.assert_close(
|
||||
Y_seq[chunked_seqlens[i] :, ...],
|
||||
Y_ref_seq[chunked_seqlens[i] :, ...],
|
||||
atol=atol,
|
||||
rtol=rtol,
|
||||
msg=lambda x, i=i: f"seq{i} output part2 " + x,
|
||||
)
|
||||
|
||||
state_seq = state_chunked[i]
|
||||
state_seq_ref = state_ref[i]
|
||||
torch.testing.assert_close(
|
||||
state_seq,
|
||||
state_seq_ref,
|
||||
atol=atol,
|
||||
rtol=rtol,
|
||||
msg=lambda x, i=i: f"seq{i} state " + x,
|
||||
)
|
||||
0
tests/kernels/moe/__init__.py
Normal file
0
tests/kernels/moe/__init__.py
Normal file
0
tests/kernels/moe/modular_kernel_tools/__init__.py
Normal file
0
tests/kernels/moe/modular_kernel_tools/__init__.py
Normal file
164
tests/kernels/moe/modular_kernel_tools/cli_args.py
Normal file
164
tests/kernels/moe/modular_kernel_tools/cli_args.py
Normal file
@@ -0,0 +1,164 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import argparse
|
||||
|
||||
import torch
|
||||
|
||||
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
|
||||
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
|
||||
|
||||
from .common import Config
|
||||
from .mk_objects import (
|
||||
MK_ALL_PREPARE_FINALIZE_TYPES,
|
||||
MK_FUSED_EXPERT_TYPES,
|
||||
MK_SINGLE_GPU_PREPARE_FINALIZE_TYPES,
|
||||
)
|
||||
|
||||
|
||||
def make_config_arg_parser(description: str):
|
||||
def to_pf_class_type(s: str) -> mk.FusedMoEPrepareAndFinalize:
|
||||
for pf in MK_ALL_PREPARE_FINALIZE_TYPES:
|
||||
if pf.__name__ == s:
|
||||
return pf
|
||||
raise ValueError(f"Cannot find a PrepareFinalize type that matches {s}")
|
||||
|
||||
def to_experts_class_type(s: str) -> mk.FusedMoEPermuteExpertsUnpermute:
|
||||
for fe in MK_FUSED_EXPERT_TYPES:
|
||||
if fe.__name__ == s:
|
||||
return fe
|
||||
raise ValueError(f"Cannot find a FusedExperts type that matches {s}")
|
||||
|
||||
def to_quant_torch_dtype(s: str) -> torch.dtype:
|
||||
if s == "torch.float8_e4m3fn":
|
||||
return torch.float8_e4m3fn
|
||||
raise ValueError(f"Unsupported quant type {s}")
|
||||
|
||||
parser = argparse.ArgumentParser(description=description)
|
||||
|
||||
parser.add_argument(
|
||||
"--world-size",
|
||||
type=int,
|
||||
default=2,
|
||||
help="Number of ranks that participate in all2all",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--pf-type",
|
||||
type=to_pf_class_type,
|
||||
required=True,
|
||||
help=(
|
||||
"Choose a PrepareFinalize Type : "
|
||||
f"{[x.__name__ for x in MK_ALL_PREPARE_FINALIZE_TYPES]}"
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--experts-type",
|
||||
type=to_experts_class_type,
|
||||
required=True,
|
||||
help=(
|
||||
f"Choose a FusedExpert type : {[x.__name__ for x in MK_FUSED_EXPERT_TYPES]}"
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"-m",
|
||||
nargs="+",
|
||||
type=int,
|
||||
default=[64],
|
||||
help="num tokens per rank",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-k",
|
||||
type=int,
|
||||
default=7168,
|
||||
help="hidden-size",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-n",
|
||||
type=int,
|
||||
default=1024,
|
||||
help="N dimension of the first fused-moe matmul",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--num-experts", type=int, default=32, help="Global num experts"
|
||||
)
|
||||
parser.add_argument("--topk", nargs="+", type=int, default=[4, 1], help="num topk")
|
||||
parser.add_argument(
|
||||
"--fused-moe-chunk-size",
|
||||
type=int,
|
||||
help="Fused moe chunk size used for the non-batched fused experts impl.",
|
||||
)
|
||||
|
||||
# Quant args
|
||||
parser.add_argument(
|
||||
"--quant-dtype", type=to_quant_torch_dtype, help="Quant datatype"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--per-token-quantized-activations",
|
||||
action="store_true",
|
||||
help=("The input activations must be per-token quantized"),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--per-channel-quantized-weights",
|
||||
action="store_true",
|
||||
help="The weights must be per-channel quantized.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--block-shape", nargs="+", type=int, help="Quantization block shape"
|
||||
)
|
||||
|
||||
# Torch trace profile generation args
|
||||
parser.add_argument(
|
||||
"--torch-trace-dir-path",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Get torch trace for single execution",
|
||||
)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
def _validate_args(args: argparse.Namespace):
|
||||
if args.quant_dtype is not None:
|
||||
assert args.quant_dtype == torch.float8_e4m3fn
|
||||
if args.block_shape is not None:
|
||||
assert len(args.block_shape) == 2, (
|
||||
f"block shape must have 2 elements. got {args.block_shape}"
|
||||
)
|
||||
|
||||
if args.experts_type in MK_SINGLE_GPU_PREPARE_FINALIZE_TYPES:
|
||||
assert args.world_size == 1, "Single GPU objects need world size set to 1"
|
||||
|
||||
if args.torch_trace_dir_path is not None:
|
||||
from pathlib import Path
|
||||
|
||||
assert Path(args.torch_trace_dir_path).is_dir(), (
|
||||
f"Please create {args.torch_trace_dir_path}"
|
||||
)
|
||||
|
||||
|
||||
def make_config(args: argparse.Namespace) -> Config:
|
||||
_validate_args(args)
|
||||
|
||||
quant_config = None
|
||||
if args.quant_dtype is not None:
|
||||
quant_config = FusedMoEQuantConfig(
|
||||
quant_dtype=args.quant_dtype,
|
||||
per_act_token_quant=args.per_token_quantized_activations,
|
||||
per_out_ch_quant=args.per_channel_quantized_weights,
|
||||
block_shape=args.block_shape,
|
||||
)
|
||||
|
||||
return Config(
|
||||
Ms=args.m,
|
||||
K=args.k,
|
||||
N=args.n,
|
||||
E=args.num_experts,
|
||||
topks=args.topk,
|
||||
dtype=torch.bfloat16, # hard-code
|
||||
quant_config=quant_config,
|
||||
prepare_finalize_type=args.pf_type,
|
||||
fused_experts_type=args.experts_type,
|
||||
fused_moe_chunk_size=args.fused_moe_chunk_size,
|
||||
world_size=args.world_size,
|
||||
torch_trace_dir_path=args.torch_trace_dir_path,
|
||||
)
|
||||
668
tests/kernels/moe/modular_kernel_tools/common.py
Normal file
668
tests/kernels/moe/modular_kernel_tools/common.py
Normal file
@@ -0,0 +1,668 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
|
||||
import vllm._custom_ops as ops
|
||||
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
|
||||
from tests.kernels.moe.utils import make_test_weights, per_token_cast_to_fp8
|
||||
from tests.kernels.quantization.nvfp4_utils import (
|
||||
FLOAT4_E2M1_MAX,
|
||||
FLOAT8_E4M3_MAX,
|
||||
dequantize_nvfp4_to_dtype,
|
||||
)
|
||||
from tests.kernels.utils import torch_experts
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.distributed import (
|
||||
get_dp_group,
|
||||
get_pcp_group,
|
||||
get_tensor_model_parallel_world_size,
|
||||
)
|
||||
from vllm.forward_context import set_forward_context
|
||||
from vllm.model_executor.layers.fused_moe.config import (
|
||||
FusedMoEConfig,
|
||||
FusedMoEParallelConfig,
|
||||
FusedMoEQuantConfig,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk
|
||||
from vllm.utils.import_utils import has_deep_ep, has_deep_gemm, has_pplx
|
||||
|
||||
from .mk_objects import (
|
||||
TestMoEQuantConfig,
|
||||
expert_info,
|
||||
make_fused_experts,
|
||||
make_prepare_finalize,
|
||||
prepare_finalize_info,
|
||||
)
|
||||
from .parallel_utils import ProcessGroupInfo
|
||||
|
||||
|
||||
def _describe_tensor(t: torch.Tensor | None, name: str) -> str:
|
||||
if t is None:
|
||||
return f"{name} : None"
|
||||
else:
|
||||
return f"{name} : {t.shape} {t.dtype} {t.device}"
|
||||
|
||||
|
||||
@dataclass
|
||||
class Config:
|
||||
Ms: list[int] | int
|
||||
K: int
|
||||
N: int
|
||||
E: int
|
||||
topks: list[int] | int
|
||||
dtype: torch.dtype
|
||||
quant_config: TestMoEQuantConfig | None
|
||||
|
||||
prepare_finalize_type: mk.FusedMoEPrepareAndFinalize
|
||||
fused_experts_type: mk.FusedMoEPermuteExpertsUnpermute
|
||||
|
||||
fused_moe_chunk_size: int | None
|
||||
world_size: int
|
||||
|
||||
torch_trace_dir_path: str | None = None
|
||||
|
||||
def __post_init__(self):
|
||||
if self.quant_config is None:
|
||||
self.quant_config = TestMoEQuantConfig(None, False, False, None)
|
||||
|
||||
def describe(self) -> str:
|
||||
s = ""
|
||||
s += "== Config:\n"
|
||||
s += f" world_size={self.world_size}\n"
|
||||
s += f" PF={self.prepare_finalize_type.__name__}\n"
|
||||
s += f" FE={self.fused_experts_type.__name__}\n"
|
||||
s += f" E={self.E}\n"
|
||||
s += f" Ms={self.Ms}\n"
|
||||
s += f" N={self.N}\n"
|
||||
s += f" K={self.K}\n"
|
||||
s += f" topk={self.topks}\n"
|
||||
s += f" dtype={self.dtype}\n"
|
||||
s += f" fused_moe_chunk_size={self.fused_moe_chunk_size}\n"
|
||||
s += " Quant:\n"
|
||||
if self.quant_config is not None:
|
||||
s += f" q_dtype={self.quant_dtype}\n"
|
||||
s += f" q_block_shape={self.quant_block_shape}\n"
|
||||
s += f" q_per_out_ch_quant={self.is_per_out_ch_quant}\n"
|
||||
s += f" q_per_act_token={self.is_per_act_token_quant}\n"
|
||||
else:
|
||||
s += " quant=None\n"
|
||||
return s
|
||||
|
||||
@property
|
||||
def M(self) -> int:
|
||||
assert isinstance(self.Ms, int)
|
||||
return self.Ms
|
||||
|
||||
@property
|
||||
def quant_dtype(self) -> torch.dtype | str | None:
|
||||
assert self.quant_config is not None
|
||||
return self.quant_config.quant_dtype
|
||||
|
||||
@property
|
||||
def is_per_act_token_quant(self) -> bool:
|
||||
assert self.quant_config is not None
|
||||
return self.quant_config.per_act_token_quant
|
||||
|
||||
@property
|
||||
def is_per_tensor_act_quant(self) -> bool:
|
||||
return not self.is_per_act_token_quant and self.quant_block_shape is None
|
||||
|
||||
@property
|
||||
def is_per_out_ch_quant(self) -> bool:
|
||||
assert self.quant_config is not None
|
||||
return self.quant_config.per_out_ch_quant
|
||||
|
||||
@property
|
||||
def quant_block_shape(self) -> list[int] | None:
|
||||
assert self.quant_config is not None
|
||||
return self.quant_config.block_shape
|
||||
|
||||
@property
|
||||
def topk(self) -> int:
|
||||
assert isinstance(self.topks, int)
|
||||
return self.topks
|
||||
|
||||
@property
|
||||
def num_local_experts(self) -> int:
|
||||
return self.E // self.world_size
|
||||
|
||||
def make_env_data(self) -> tuple[VllmConfig, dict[Any, Any]]:
|
||||
"""
|
||||
make env data for vllm launch.
|
||||
"""
|
||||
vllm_config = VllmConfig()
|
||||
vllm_config.parallel_config.data_parallel_size = self.world_size
|
||||
vllm_config.parallel_config.enable_expert_parallel = True
|
||||
|
||||
env_dict = {
|
||||
"VLLM_USE_DEEP_GEMM": str(int(self.needs_deep_gemm())),
|
||||
}
|
||||
|
||||
backend = self.all2all_backend()
|
||||
vllm_config.parallel_config.all2all_backend = backend
|
||||
if backend is not None:
|
||||
env_dict.update({"VLLM_ALL2ALL_BACKEND": backend})
|
||||
|
||||
if self.fused_moe_chunk_size is not None:
|
||||
env_dict.update(
|
||||
{"VLLM_FUSED_MOE_CHUNK_SIZE": str(self.fused_moe_chunk_size)}
|
||||
)
|
||||
|
||||
return vllm_config, env_dict
|
||||
|
||||
def is_fp8_block_quantized(self):
|
||||
return (
|
||||
self.quant_dtype == torch.float8_e4m3fn
|
||||
and self.quant_block_shape is not None
|
||||
)
|
||||
|
||||
def is_batched_prepare_finalize(self):
|
||||
info = prepare_finalize_info(self.prepare_finalize_type)
|
||||
return mk.FusedMoEActivationFormat.BatchedExperts == info.activation_format
|
||||
|
||||
def is_batched_fused_experts(self):
|
||||
info = expert_info(self.fused_experts_type)
|
||||
return mk.FusedMoEActivationFormat.BatchedExperts == info.activation_format
|
||||
|
||||
def is_standard_fused_experts(self):
|
||||
info = expert_info(self.fused_experts_type)
|
||||
return mk.FusedMoEActivationFormat.Standard == info.activation_format
|
||||
|
||||
def fe_supported_types(self):
|
||||
info = expert_info(self.fused_experts_type)
|
||||
return info.supported_dtypes
|
||||
|
||||
def pf_supported_types(self):
|
||||
info = prepare_finalize_info(self.prepare_finalize_type)
|
||||
return info.supported_dtypes
|
||||
|
||||
def is_block_quant_supported(self):
|
||||
info = expert_info(self.fused_experts_type)
|
||||
return info.blocked_quantization_support
|
||||
|
||||
def is_fe_supports_chunking(self):
|
||||
info = expert_info(self.fused_experts_type)
|
||||
return info.supports_chunking
|
||||
|
||||
def supports_expert_map(self):
|
||||
info = expert_info(self.fused_experts_type)
|
||||
return info.supports_expert_map
|
||||
|
||||
def supports_apply_weight_on_input(self):
|
||||
info = prepare_finalize_info(self.prepare_finalize_type)
|
||||
return info.supports_apply_weight_on_input
|
||||
|
||||
def needs_deep_gemm(self):
|
||||
info = expert_info(self.fused_experts_type)
|
||||
return info.needs_deep_gemm
|
||||
|
||||
def needs_pplx(self):
|
||||
info = prepare_finalize_info(self.prepare_finalize_type)
|
||||
return info.backend == "pplx"
|
||||
|
||||
def needs_deep_ep(self):
|
||||
info = prepare_finalize_info(self.prepare_finalize_type)
|
||||
return (
|
||||
info.backend == "deepep_high_throughput"
|
||||
or info.backend == "deepep_low_latency"
|
||||
)
|
||||
|
||||
def all2all_backend(self):
|
||||
info = prepare_finalize_info(self.prepare_finalize_type)
|
||||
return info.backend
|
||||
|
||||
def is_valid(self) -> tuple[bool, str | None]:
|
||||
# Check prepare-finalize and fused-experts compatibility
|
||||
if self.is_batched_prepare_finalize():
|
||||
if not self.is_batched_fused_experts():
|
||||
return False, "Mismatched format."
|
||||
else:
|
||||
if not self.is_standard_fused_experts():
|
||||
return False, "Mismatched format."
|
||||
|
||||
use_chunking = self.fused_moe_chunk_size is not None
|
||||
if use_chunking and not self.is_fe_supports_chunking():
|
||||
return False, "Chunking not supported."
|
||||
|
||||
# Check quantization sanity
|
||||
if (
|
||||
int(self.is_per_act_token_quant)
|
||||
+ int(self.is_per_tensor_act_quant)
|
||||
+ int(self.quant_block_shape is not None)
|
||||
) > 1:
|
||||
# invalid quant config
|
||||
return False, f"Bad quant_config {self.quant_config}."
|
||||
|
||||
# check type support
|
||||
if self.quant_dtype is None:
|
||||
if (
|
||||
self.dtype not in self.pf_supported_types()
|
||||
or self.dtype not in self.fe_supported_types()
|
||||
):
|
||||
return False, (
|
||||
f"Unsupported type {self.dtype} not in "
|
||||
f"{self.pf_supported_types()} and "
|
||||
f"{self.fe_supported_types()}."
|
||||
)
|
||||
else:
|
||||
if (
|
||||
self.quant_dtype not in self.pf_supported_types()
|
||||
or self.quant_dtype not in self.fe_supported_types()
|
||||
):
|
||||
return False, (
|
||||
f"Unsupported quant type {self.quant_dtype} "
|
||||
f"not in {self.pf_supported_types()} and "
|
||||
f"{self.fe_supported_types()}."
|
||||
)
|
||||
|
||||
# Check block quanization support
|
||||
is_block_quatized = self.quant_block_shape is not None
|
||||
if is_block_quatized and self.quant_dtype is None:
|
||||
return False, "No block quantization support."
|
||||
|
||||
if is_block_quatized and not self.is_block_quant_supported():
|
||||
return False, "Mismatched block quantization support."
|
||||
|
||||
# deep_gemm only works with block-quantized
|
||||
if self.needs_deep_gemm() and not is_block_quatized:
|
||||
return False, "Needs DeepGEMM but not block quantized."
|
||||
|
||||
# Check dependencies (turn into asserts?)
|
||||
if self.needs_deep_ep() and not has_deep_ep():
|
||||
return False, "Needs DeepEP, but DeepEP not available."
|
||||
if self.needs_deep_gemm() and not has_deep_gemm():
|
||||
return False, "Needs DeepGEMM, but DeepGEMM not available."
|
||||
if self.needs_pplx() and not has_pplx(): # noqa: SIM103
|
||||
return False, "Needs PPLX, but PPLX not available."
|
||||
|
||||
return True, None
|
||||
|
||||
|
||||
@dataclass
|
||||
class WeightTensors:
|
||||
w1: torch.Tensor
|
||||
w2: torch.Tensor
|
||||
w1_scale: torch.Tensor | None
|
||||
w2_scale: torch.Tensor | None
|
||||
w1_gs: torch.Tensor | None = None
|
||||
w2_gs: torch.Tensor | None = None
|
||||
|
||||
def describe(self):
|
||||
s = ""
|
||||
s += "== Weight Tensors: \n"
|
||||
s += f" - {_describe_tensor(self.w1, 'w1')} \n"
|
||||
s += f" - {_describe_tensor(self.w2, 'w2')} \n"
|
||||
s += f" - {_describe_tensor(self.w1_scale, 'w1_scale')} \n"
|
||||
s += f" - {_describe_tensor(self.w2_scale, 'w2_scale')} \n"
|
||||
s += f" - {_describe_tensor(self.w1_gs, 'w1_gs')} \n"
|
||||
s += f" - {_describe_tensor(self.w2_gs, 'w2_gs')} \n"
|
||||
return s
|
||||
|
||||
def is_quantized(self) -> bool:
|
||||
# or w1_scale is not None?
|
||||
return (
|
||||
self.w1.dtype == torch.float8_e4m3fn
|
||||
or self.w1.dtype == torch.uint8
|
||||
or self.w1.dtype == torch.int8
|
||||
)
|
||||
|
||||
def to_current_device(self):
|
||||
device = torch.cuda.current_device()
|
||||
self.w1 = self.w1.to(device=device)
|
||||
self.w2 = self.w2.to(device=device)
|
||||
|
||||
if self.w1_scale is not None:
|
||||
self.w1_scale = self.w1_scale.to(device=device)
|
||||
if self.w2_scale is not None:
|
||||
self.w2_scale = self.w2_scale.to(device=device)
|
||||
|
||||
if self.w1_gs is not None:
|
||||
self.w1_gs = self.w1_gs.to(device=device)
|
||||
if self.w2_gs is not None:
|
||||
self.w2_gs = self.w2_gs.to(device=device)
|
||||
|
||||
def slice_weights(self, rank: int, num_local_experts: int) -> "WeightTensors":
|
||||
s = rank * num_local_experts
|
||||
e = s + num_local_experts
|
||||
w1 = self.w1[s:e, :, :]
|
||||
w2 = self.w2[s:e, :, :]
|
||||
w1_scale = self.w1_scale[s:e, :, :] if self.w1_scale is not None else None
|
||||
w2_scale = self.w2_scale[s:e, :, :] if self.w2_scale is not None else None
|
||||
w1_gs = self.w1_gs[s:e] if self.w1_gs is not None else None
|
||||
w2_gs = self.w2_gs[s:e] if self.w2_gs is not None else None
|
||||
|
||||
return WeightTensors(w1, w2, w1_scale, w2_scale, w1_gs, w2_gs)
|
||||
|
||||
@staticmethod
|
||||
def make(config: Config) -> "WeightTensors":
|
||||
(_, w1, w1_scale, w1_gs), (_, w2, w2_scale, w2_gs) = make_test_weights(
|
||||
e=config.E,
|
||||
n=config.N,
|
||||
k=config.K,
|
||||
in_dtype=config.dtype,
|
||||
quant_dtype=config.quant_dtype,
|
||||
block_shape=config.quant_block_shape,
|
||||
# or config.is_per_out_ch_quant
|
||||
per_out_ch_quant=config.is_per_act_token_quant,
|
||||
)
|
||||
return WeightTensors(
|
||||
w1=w1, w2=w2, w1_scale=w1_scale, w2_scale=w2_scale, w1_gs=w1_gs, w2_gs=w2_gs
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class RankTensors:
|
||||
hidden_states: torch.Tensor
|
||||
hidden_states_scale: torch.Tensor | None
|
||||
|
||||
topk_weights: torch.Tensor
|
||||
topk_ids: torch.Tensor
|
||||
expert_map: torch.Tensor | None
|
||||
|
||||
def describe(self):
|
||||
s = ""
|
||||
s += "== Rank Tensors: \n"
|
||||
s += f" - {_describe_tensor(self.hidden_states, 'HS')} \n"
|
||||
s += f" - {_describe_tensor(self.hidden_states_scale, 'HS_scale')} \n"
|
||||
s += f" - {_describe_tensor(self.topk_weights, 'topk_weights')} \n"
|
||||
s += f" - {_describe_tensor(self.topk_ids, 'topk_ids')} \n"
|
||||
s += f" - {_describe_tensor(self.expert_map, 'expert_map')} \n"
|
||||
return s
|
||||
|
||||
@staticmethod
|
||||
def make_hidden_states(
|
||||
config: Config,
|
||||
) -> tuple[torch.Tensor, torch.Tensor | None]:
|
||||
"""
|
||||
Return hidden_states
|
||||
"""
|
||||
m, k, dtype = (config.M, config.K, config.dtype)
|
||||
a = torch.randn((m, k), device=torch.cuda.current_device(), dtype=dtype) / 15.0
|
||||
|
||||
if config.quant_dtype is None:
|
||||
return a, None
|
||||
|
||||
# We dequant and use that as hidden_states so the tests are stable.
|
||||
# quantizing and dequantizing yield slightly different results
|
||||
# depending on the hardware. Here we, quantize and dequantize
|
||||
# first - so further quantize and dequantize will yield the same
|
||||
# values.
|
||||
if config.is_per_tensor_act_quant:
|
||||
a_q, a_scales = ops.scaled_fp8_quant(a, use_per_token_if_dynamic=False)
|
||||
return a_q.float().mul(a_scales).to(dtype), a_scales
|
||||
|
||||
if config.is_per_act_token_quant:
|
||||
a_q, a_scales = ops.scaled_fp8_quant(a, use_per_token_if_dynamic=True)
|
||||
return a_q.float().mul(a_scales).to(dtype), None
|
||||
|
||||
assert config.quant_block_shape is not None
|
||||
block_k = config.quant_block_shape[1]
|
||||
a_q, a_scales = per_token_cast_to_fp8(a, block_size=block_k)
|
||||
return a_q.float().view((-1, block_k)).mul(a_scales.view(-1, 1)).view(m, k).to(
|
||||
dtype
|
||||
), None
|
||||
|
||||
@staticmethod
|
||||
def make(config: Config, pgi: ProcessGroupInfo):
|
||||
dtype = config.dtype
|
||||
topk, m, _ = (config.topk, config.M, config.K)
|
||||
hidden_states, hidden_states_scale = RankTensors.make_hidden_states(config)
|
||||
|
||||
num_local_experts, global_num_experts = (config.num_local_experts, config.E)
|
||||
score = torch.randn((m, global_num_experts), device="cuda", dtype=dtype)
|
||||
topk_weights, topk_ids, _ = fused_topk(hidden_states, score, topk, False)
|
||||
|
||||
# distribute topk_ids evenly
|
||||
for mi in range(m):
|
||||
topk_ids[mi] = torch.randperm(config.E)[:topk]
|
||||
topk_ids = topk_ids.to(device=torch.cuda.current_device())
|
||||
|
||||
expert_map = None
|
||||
if config.world_size > 1 and config.supports_expert_map():
|
||||
expert_map = torch.full(
|
||||
(global_num_experts,), fill_value=-1, dtype=torch.int32
|
||||
)
|
||||
s = pgi.rank * num_local_experts
|
||||
e = s + num_local_experts
|
||||
expert_map[s:e] = torch.tensor(list(range(num_local_experts)))
|
||||
expert_map = expert_map.to(
|
||||
device=torch.cuda.current_device(), dtype=torch.int32
|
||||
)
|
||||
|
||||
return RankTensors(
|
||||
hidden_states=hidden_states,
|
||||
hidden_states_scale=hidden_states_scale,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
expert_map=expert_map,
|
||||
)
|
||||
|
||||
|
||||
def reference_moe_impl(
|
||||
config: Config, weights: WeightTensors, rank_tensors: RankTensors
|
||||
) -> torch.Tensor:
|
||||
if config.quant_dtype == "nvfp4":
|
||||
quant_blocksize = 16
|
||||
dtype = config.dtype
|
||||
|
||||
w1_q = weights.w1
|
||||
w1_blockscale = weights.w1_scale
|
||||
w1_gs = weights.w1_gs
|
||||
|
||||
w2_q = weights.w2
|
||||
w2_blockscale = weights.w2_scale
|
||||
w2_gs = weights.w2_gs
|
||||
|
||||
a_global_scale = (
|
||||
(FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX)
|
||||
/ torch.amax(rank_tensors.hidden_states.flatten(), dim=-1)
|
||||
).to(torch.float32)
|
||||
|
||||
assert w1_gs is not None
|
||||
assert w2_gs is not None
|
||||
assert w1_blockscale is not None
|
||||
assert w2_blockscale is not None
|
||||
|
||||
assert w1_blockscale.shape[1] % 128 == 0
|
||||
assert w1_blockscale.shape[2] % 4 == 0
|
||||
assert w2_blockscale.shape[1] % 128 == 0
|
||||
assert w2_blockscale.shape[2] % 4 == 0
|
||||
|
||||
a_fp4, a_scale_interleaved = ops.scaled_fp4_quant(
|
||||
rank_tensors.hidden_states, a_global_scale
|
||||
)
|
||||
|
||||
a = dequantize_nvfp4_to_dtype(
|
||||
a_fp4,
|
||||
a_scale_interleaved,
|
||||
a_global_scale,
|
||||
dtype=dtype,
|
||||
device=a_fp4.device,
|
||||
block_size=quant_blocksize,
|
||||
)
|
||||
|
||||
e = w1_q.shape[0]
|
||||
n = w1_q.shape[1] // 2
|
||||
k = w2_q.shape[1]
|
||||
|
||||
w1 = torch.zeros((e, 2 * n, k), device="cuda", dtype=dtype)
|
||||
w2 = torch.zeros((e, k, n), device="cuda", dtype=dtype)
|
||||
|
||||
for idx in range(0, e):
|
||||
w1[idx] = dequantize_nvfp4_to_dtype(
|
||||
w1_q[idx],
|
||||
w1_blockscale[idx],
|
||||
w1_gs[idx],
|
||||
dtype=dtype,
|
||||
device=w1_q.device,
|
||||
block_size=quant_blocksize,
|
||||
)
|
||||
w2[idx] = dequantize_nvfp4_to_dtype(
|
||||
w2_q[idx],
|
||||
w2_blockscale[idx],
|
||||
w2_gs[idx],
|
||||
dtype=dtype,
|
||||
device=w2_q.device,
|
||||
block_size=quant_blocksize,
|
||||
)
|
||||
a_scale = None
|
||||
w1_scale = None
|
||||
w2_scale = None
|
||||
quant_dtype = None
|
||||
per_act_token_quant = False
|
||||
block_shape = None
|
||||
else:
|
||||
a = rank_tensors.hidden_states
|
||||
a_scale = rank_tensors.hidden_states_scale
|
||||
w1 = weights.w1
|
||||
w1_scale = weights.w1_scale
|
||||
w2 = weights.w2
|
||||
w2_scale = weights.w2_scale
|
||||
quant_dtype = config.quant_dtype
|
||||
per_act_token_quant = config.is_per_act_token_quant
|
||||
block_shape = config.quant_block_shape
|
||||
|
||||
return torch_experts(
|
||||
a=a,
|
||||
w1=w1,
|
||||
w2=w2,
|
||||
topk_weight=rank_tensors.topk_weights,
|
||||
topk_ids=rank_tensors.topk_ids,
|
||||
global_num_experts=config.E,
|
||||
expert_map=None,
|
||||
w1_scale=w1_scale,
|
||||
w2_scale=w2_scale,
|
||||
a1_scale=a_scale,
|
||||
quant_dtype=quant_dtype,
|
||||
per_act_token_quant=per_act_token_quant,
|
||||
block_shape=block_shape,
|
||||
apply_router_weights_on_input=config.topk == 1
|
||||
and config.supports_apply_weight_on_input(),
|
||||
)
|
||||
|
||||
|
||||
def _make_gscale(num_experts: int) -> torch.Tensor:
|
||||
return torch.ones(
|
||||
(num_experts,), device=torch.cuda.current_device(), dtype=torch.float32
|
||||
)
|
||||
|
||||
|
||||
def make_modular_kernel(
|
||||
config: Config,
|
||||
vllm_config: VllmConfig,
|
||||
quant_config: FusedMoEQuantConfig,
|
||||
) -> mk.FusedMoEModularKernel:
|
||||
def next_power_of_2(x):
|
||||
import math
|
||||
|
||||
if x == 0:
|
||||
return 1
|
||||
return 2 ** math.ceil(math.log2(x))
|
||||
|
||||
# make moe config
|
||||
moe_parallel_config: FusedMoEParallelConfig = FusedMoEParallelConfig.make(
|
||||
tp_size_=get_tensor_model_parallel_world_size(),
|
||||
pcp_size_=get_pcp_group().world_size,
|
||||
dp_size_=get_dp_group().world_size,
|
||||
vllm_parallel_config=vllm_config.parallel_config,
|
||||
)
|
||||
|
||||
moe = FusedMoEConfig(
|
||||
num_experts=config.E,
|
||||
experts_per_token=config.topk,
|
||||
hidden_dim=config.K,
|
||||
num_local_experts=config.num_local_experts,
|
||||
moe_parallel_config=moe_parallel_config,
|
||||
in_dtype=config.dtype,
|
||||
max_num_tokens=next_power_of_2(config.M),
|
||||
)
|
||||
|
||||
# make modular kernel
|
||||
prepare_finalize = make_prepare_finalize(
|
||||
config.prepare_finalize_type, config.all2all_backend(), moe, quant_config
|
||||
)
|
||||
|
||||
fused_experts = make_fused_experts(
|
||||
config.fused_experts_type,
|
||||
moe,
|
||||
quant_config,
|
||||
prepare_finalize.num_dispatchers(),
|
||||
config.N,
|
||||
)
|
||||
|
||||
modular_kernel = mk.FusedMoEModularKernel(
|
||||
prepare_finalize=prepare_finalize,
|
||||
fused_experts=fused_experts,
|
||||
)
|
||||
|
||||
return modular_kernel
|
||||
|
||||
|
||||
def run_modular_kernel(
|
||||
pgi: ProcessGroupInfo,
|
||||
vllm_config: VllmConfig,
|
||||
config: Config,
|
||||
weights: WeightTensors,
|
||||
rank_tensors: RankTensors,
|
||||
) -> torch.Tensor:
|
||||
assert isinstance(config.Ms, int)
|
||||
assert isinstance(config.topks, int)
|
||||
|
||||
# weights for rank
|
||||
rank_weights = weights.slice_weights(pgi.rank, config.num_local_experts)
|
||||
|
||||
if config.quant_dtype == "nvfp4":
|
||||
gscale = _make_gscale(config.num_local_experts)
|
||||
else:
|
||||
gscale = None
|
||||
|
||||
quant_config = FusedMoEQuantConfig.make(
|
||||
config.quant_dtype,
|
||||
w1_scale=rank_weights.w1_scale,
|
||||
w2_scale=rank_weights.w2_scale,
|
||||
a1_scale=rank_tensors.hidden_states_scale,
|
||||
g1_alphas=(1 / rank_weights.w1_gs) if rank_weights.w1_gs is not None else None,
|
||||
g2_alphas=(1 / rank_weights.w2_gs) if rank_weights.w2_gs is not None else None,
|
||||
a1_gscale=gscale,
|
||||
a2_gscale=gscale,
|
||||
block_shape=config.quant_block_shape,
|
||||
per_act_token_quant=config.is_per_act_token_quant,
|
||||
per_out_ch_quant=config.is_per_out_ch_quant,
|
||||
)
|
||||
|
||||
mk = make_modular_kernel(config, vllm_config, quant_config)
|
||||
|
||||
# impls might update the tensor in place
|
||||
hidden_states = rank_tensors.hidden_states.clone()
|
||||
|
||||
topk_ids = rank_tensors.topk_ids.to(mk.prepare_finalize.topk_indices_dtype())
|
||||
|
||||
mk_kwargs = {
|
||||
"hidden_states": hidden_states,
|
||||
"w1": rank_weights.w1,
|
||||
"w2": rank_weights.w2,
|
||||
"topk_weights": rank_tensors.topk_weights,
|
||||
"topk_ids": topk_ids,
|
||||
"expert_map": rank_tensors.expert_map,
|
||||
"global_num_experts": config.E,
|
||||
"apply_router_weight_on_input": config.topk == 1
|
||||
and config.supports_apply_weight_on_input(),
|
||||
}
|
||||
|
||||
num_tokens = rank_tensors.hidden_states.shape[0]
|
||||
num_tokens_across_dp = torch.tensor(
|
||||
[num_tokens] * config.world_size, device="cuda", dtype=torch.int
|
||||
)
|
||||
|
||||
with set_forward_context(
|
||||
None,
|
||||
vllm_config,
|
||||
num_tokens=num_tokens,
|
||||
num_tokens_across_dp=num_tokens_across_dp,
|
||||
):
|
||||
out = mk.forward(**mk_kwargs)
|
||||
|
||||
return out
|
||||
196
tests/kernels/moe/modular_kernel_tools/make_feature_matrix.py
Normal file
196
tests/kernels/moe/modular_kernel_tools/make_feature_matrix.py
Normal file
@@ -0,0 +1,196 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import copy
|
||||
from enum import Enum
|
||||
from itertools import product
|
||||
|
||||
import torch
|
||||
from tqdm import tqdm
|
||||
|
||||
from vllm.config import VllmConfig, set_current_vllm_config
|
||||
from vllm.model_executor.layers.fused_moe.config import FUSED_MOE_UNQUANTIZED_CONFIG
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
from .common import (
|
||||
Config,
|
||||
RankTensors,
|
||||
WeightTensors,
|
||||
reference_moe_impl,
|
||||
run_modular_kernel,
|
||||
)
|
||||
from .mk_objects import (
|
||||
MK_FUSED_EXPERT_TYPES,
|
||||
MK_MULTI_GPU_PREPARE_FINALIZE_TYPES,
|
||||
MK_QUANT_CONFIGS,
|
||||
)
|
||||
from .parallel_utils import ProcessGroupInfo, parallel_launch_with_config
|
||||
|
||||
|
||||
class Result(Enum):
|
||||
PASS = 1
|
||||
FAIL = 2
|
||||
SKIP = 3
|
||||
|
||||
|
||||
def rank_worker(
|
||||
pgi: ProcessGroupInfo,
|
||||
vllm_config: VllmConfig,
|
||||
cpu_group,
|
||||
config: Config,
|
||||
weights: WeightTensors,
|
||||
):
|
||||
current_platform.seed_everything(pgi.rank)
|
||||
|
||||
# sanity check
|
||||
from vllm import envs
|
||||
|
||||
if config.fused_moe_chunk_size is not None:
|
||||
assert config.fused_moe_chunk_size == envs.VLLM_FUSED_MOE_CHUNK_SIZE
|
||||
|
||||
# get weights to this device
|
||||
weights.to_current_device()
|
||||
|
||||
Ms = config.Ms
|
||||
assert isinstance(Ms, list)
|
||||
TOPKs = config.topks
|
||||
assert isinstance(TOPKs, list)
|
||||
|
||||
for m, topk in product(Ms, TOPKs):
|
||||
print(f"Running m={m}, topk={topk} ...")
|
||||
# override m and topk
|
||||
cfgx = copy.deepcopy(config)
|
||||
cfgx.Ms = m
|
||||
cfgx.topks = topk
|
||||
|
||||
# inputs for rank
|
||||
rank_tensors = RankTensors.make(cfgx, pgi)
|
||||
|
||||
# modular kernel out
|
||||
mk_out = run_modular_kernel(pgi, vllm_config, cfgx, weights, rank_tensors)
|
||||
|
||||
with set_current_vllm_config(vllm_config):
|
||||
ref_out = reference_moe_impl(cfgx, weights, rank_tensors)
|
||||
|
||||
torch.testing.assert_close(ref_out, mk_out, atol=3e-2, rtol=3e-2)
|
||||
|
||||
|
||||
def make_feature_matrix(csv_file_path: str):
|
||||
from dataclasses import asdict
|
||||
|
||||
import pandas as pd
|
||||
|
||||
def add_to_results(
|
||||
config: Config, success: Result, results_df: pd.DataFrame | None = None
|
||||
):
|
||||
config_dict = asdict(config)
|
||||
config_dict["prepare_finalize_type"] = config_dict[
|
||||
"prepare_finalize_type"
|
||||
].__name__
|
||||
config_dict["fused_experts_type"] = config_dict["fused_experts_type"].__name__
|
||||
config_dict["per_tensor_act_quant"] = config.is_per_tensor_act_quant
|
||||
quant_config_dict = config_dict["quant_config"]
|
||||
del config_dict["quant_config"]
|
||||
if quant_config_dict is None:
|
||||
quant_config = FUSED_MOE_UNQUANTIZED_CONFIG
|
||||
quant_config_dict = asdict(quant_config)
|
||||
|
||||
config_dict |= quant_config_dict
|
||||
result_dict = config_dict | {"success": success.name}
|
||||
|
||||
result_df = pd.DataFrame([result_dict])
|
||||
if results_df is None:
|
||||
results_df = result_df
|
||||
else:
|
||||
results_df = pd.concat([results_df, result_df], ignore_index=True)
|
||||
|
||||
return results_df
|
||||
|
||||
Ms = [64]
|
||||
Ks = [7168] # hidden sizes
|
||||
Ns = [2048]
|
||||
TOPKs = [[4, 1]]
|
||||
Es = [32]
|
||||
DTYPEs = [torch.bfloat16]
|
||||
PF_TYPES = MK_MULTI_GPU_PREPARE_FINALIZE_TYPES
|
||||
FE_TYPES = MK_FUSED_EXPERT_TYPES
|
||||
Q_TYPES = MK_QUANT_CONFIGS
|
||||
|
||||
combinations = list(
|
||||
product(Ms, Ks, Ns, Es, TOPKs, DTYPEs, PF_TYPES, FE_TYPES, Q_TYPES)
|
||||
)
|
||||
|
||||
results_df: pd.DataFrame | None = None
|
||||
for m, k, n, e, topks, dtype, pf_type, experts_type, quant_config in tqdm(
|
||||
combinations
|
||||
):
|
||||
config = Config(
|
||||
Ms=[m],
|
||||
K=k,
|
||||
N=n,
|
||||
E=e,
|
||||
topks=topks,
|
||||
dtype=dtype,
|
||||
prepare_finalize_type=pf_type,
|
||||
fused_experts_type=experts_type,
|
||||
quant_config=quant_config,
|
||||
world_size=2,
|
||||
fused_moe_chunk_size=None,
|
||||
)
|
||||
|
||||
success = None
|
||||
if config.is_valid()[0]:
|
||||
print(f"Running config : {config.describe()} ...")
|
||||
try:
|
||||
weights: WeightTensors = WeightTensors.make(config)
|
||||
vllm_config, env_dict = config.make_env_data()
|
||||
parallel_launch_with_config(
|
||||
config.world_size,
|
||||
rank_worker,
|
||||
vllm_config,
|
||||
env_dict,
|
||||
config,
|
||||
weights,
|
||||
)
|
||||
success = Result.PASS
|
||||
except Exception as _:
|
||||
success = Result.FAIL
|
||||
else:
|
||||
success = Result.SKIP
|
||||
|
||||
results_df = add_to_results(config, success, results_df)
|
||||
|
||||
if results_df is not None:
|
||||
results_df.to_csv(f"{csv_file_path}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import argparse
|
||||
from pathlib import Path
|
||||
|
||||
parser = argparse.ArgumentParser(
|
||||
description=(
|
||||
"Make ModularKernel feature matrix \n"
|
||||
"Example : python3 -m tests.kernels.moe.modular_kernel_tools.make_feature_matrix " # noqa: E501
|
||||
"-f ./feature_matrices/feature_matrix.csv"
|
||||
)
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"-f",
|
||||
"--feature-matrix-csv-file-path",
|
||||
type=str,
|
||||
required=True,
|
||||
help="File name to Generate a .csv file",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
csv_path = args.feature_matrix_csv_file_path
|
||||
assert csv_path.endswith("csv"), (
|
||||
f"Need a file path ending with .csv, got {csv_path}"
|
||||
)
|
||||
assert Path(csv_path).parent.is_dir(), (
|
||||
f"Cannot find parent directory for {Path(csv_path).parent}"
|
||||
)
|
||||
|
||||
make_feature_matrix(args.feature_matrix_csv_file_path)
|
||||
509
tests/kernels/moe/modular_kernel_tools/mk_objects.py
Normal file
509
tests/kernels/moe/modular_kernel_tools/mk_objects.py
Normal file
@@ -0,0 +1,509 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from dataclasses import dataclass
|
||||
|
||||
import torch
|
||||
|
||||
# Fused experts and PrepareFinalize imports
|
||||
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
|
||||
from vllm.model_executor.layers.fused_moe import TritonExperts
|
||||
from vllm.model_executor.layers.fused_moe.all2all_utils import (
|
||||
maybe_make_prepare_finalize,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe import (
|
||||
BatchedDeepGemmExperts,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.config import (
|
||||
FusedMoEConfig,
|
||||
FusedMoEQuantConfig,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.deep_gemm_moe import DeepGemmExperts
|
||||
from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
|
||||
BatchedTritonExperts,
|
||||
NaiveBatchedExperts,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.prepare_finalize import (
|
||||
MoEPrepareAndFinalizeNoEP,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.triton_deep_gemm_moe import (
|
||||
TritonOrDeepGemmExperts,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
cutlass_fp4_supported,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
|
||||
cutlass_fp8_supported,
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.deep_gemm import is_deep_gemm_supported
|
||||
from vllm.utils.flashinfer import has_flashinfer_cutlass_fused_moe
|
||||
from vllm.utils.import_utils import has_deep_ep, has_deep_gemm, has_pplx
|
||||
|
||||
|
||||
@dataclass
|
||||
class TestMoEQuantConfig:
|
||||
quant_dtype: torch.dtype | str | None
|
||||
per_out_ch_quant: bool
|
||||
per_act_token_quant: bool
|
||||
block_shape: list[int] | None
|
||||
|
||||
|
||||
@dataclass
|
||||
class PrepareFinalizeInfo:
|
||||
activation_format: mk.FusedMoEActivationFormat
|
||||
supported_dtypes: list[torch.dtype | str]
|
||||
blocked_quantization_support: bool
|
||||
backend: str | None
|
||||
supports_apply_weight_on_input: bool = True
|
||||
|
||||
|
||||
@dataclass
|
||||
class ExpertInfo:
|
||||
activation_format: mk.FusedMoEActivationFormat
|
||||
supported_dtypes: list[torch.dtype | str]
|
||||
blocked_quantization_support: bool
|
||||
supports_chunking: bool
|
||||
supports_expert_map: bool
|
||||
needs_matching_quant: bool = False
|
||||
needs_deep_gemm: bool = False
|
||||
|
||||
|
||||
PREPARE_FINALIZE_INFO: dict[mk.FusedMoEPrepareAndFinalize, PrepareFinalizeInfo] = {}
|
||||
EXPERT_INFO: dict[mk.FusedMoEPermuteExpertsUnpermute, ExpertInfo] = {}
|
||||
MK_ALL_PREPARE_FINALIZE_TYPES: list[mk.FusedMoEPrepareAndFinalize] = []
|
||||
MK_MULTI_GPU_PREPARE_FINALIZE_TYPES: list[mk.FusedMoEPrepareAndFinalize] = []
|
||||
MK_SINGLE_GPU_PREPARE_FINALIZE_TYPES: list[mk.FusedMoEPrepareAndFinalize] = []
|
||||
MK_FUSED_EXPERT_TYPES: list[mk.FusedMoEPermuteExpertsUnpermute] = []
|
||||
|
||||
standard_format = mk.FusedMoEActivationFormat.Standard
|
||||
batched_format = mk.FusedMoEActivationFormat.BatchedExperts
|
||||
common_float_types: list[torch.dtype | str] = [
|
||||
torch.float8_e4m3fn,
|
||||
torch.bfloat16,
|
||||
torch.float16,
|
||||
torch.float32,
|
||||
]
|
||||
common_float_and_int_types = common_float_types + [torch.int8]
|
||||
nvfp4_types = ["nvfp4"]
|
||||
fp8_types = [torch.float8_e4m3fn]
|
||||
|
||||
|
||||
def register_prepare_and_finalize(
|
||||
kind,
|
||||
activation_format: mk.FusedMoEActivationFormat,
|
||||
supported_dtypes: list[torch.dtype | str],
|
||||
blocked_quantization_support: bool,
|
||||
backend: str | None,
|
||||
force_multigpu: bool = False,
|
||||
supports_apply_weight_on_input: bool = True,
|
||||
):
|
||||
global PREPARE_FINALIZE_INFO
|
||||
global MK_ALL_PREPARE_FINALIZE_TYPES
|
||||
global MK_MULTI_GPU_PREPARE_FINALIZE_TYPES
|
||||
global MK_SINGLE_GPU_PREPARE_FINALIZE_TYPES
|
||||
assert kind not in PREPARE_FINALIZE_INFO
|
||||
|
||||
PREPARE_FINALIZE_INFO[kind] = PrepareFinalizeInfo(
|
||||
activation_format,
|
||||
supported_dtypes,
|
||||
blocked_quantization_support,
|
||||
backend,
|
||||
supports_apply_weight_on_input,
|
||||
)
|
||||
MK_ALL_PREPARE_FINALIZE_TYPES.append(kind)
|
||||
if backend is not None or force_multigpu:
|
||||
MK_MULTI_GPU_PREPARE_FINALIZE_TYPES.append(kind)
|
||||
else:
|
||||
MK_SINGLE_GPU_PREPARE_FINALIZE_TYPES.append(kind)
|
||||
|
||||
|
||||
def register_experts(
|
||||
kind,
|
||||
activation_format: mk.FusedMoEActivationFormat,
|
||||
supported_dtypes: list[torch.dtype | str],
|
||||
blocked_quantization_support: bool,
|
||||
supports_chunking: bool,
|
||||
supports_expert_map: bool,
|
||||
needs_matching_quant: bool = False,
|
||||
needs_deep_gemm: bool = False,
|
||||
):
|
||||
global EXPERT_INFO
|
||||
global MK_FUSED_EXPERT_TYPES
|
||||
assert kind not in EXPERT_INFO
|
||||
|
||||
EXPERT_INFO[kind] = ExpertInfo(
|
||||
activation_format,
|
||||
supported_dtypes,
|
||||
blocked_quantization_support,
|
||||
supports_chunking,
|
||||
supports_expert_map,
|
||||
needs_matching_quant,
|
||||
needs_deep_gemm,
|
||||
)
|
||||
|
||||
MK_FUSED_EXPERT_TYPES.append(kind)
|
||||
|
||||
|
||||
def prepare_finalize_info(kind) -> PrepareFinalizeInfo:
|
||||
info = PREPARE_FINALIZE_INFO.get(kind)
|
||||
assert info is not None
|
||||
return info
|
||||
|
||||
|
||||
def expert_info(kind) -> ExpertInfo:
|
||||
info = EXPERT_INFO.get(kind)
|
||||
assert info is not None
|
||||
return info
|
||||
|
||||
|
||||
register_prepare_and_finalize(
|
||||
MoEPrepareAndFinalizeNoEP,
|
||||
standard_format,
|
||||
common_float_types,
|
||||
blocked_quantization_support=True,
|
||||
backend=None,
|
||||
)
|
||||
|
||||
register_experts(
|
||||
BatchedTritonExperts,
|
||||
batched_format,
|
||||
common_float_types,
|
||||
blocked_quantization_support=True,
|
||||
supports_chunking=False,
|
||||
supports_expert_map=False,
|
||||
needs_matching_quant=True,
|
||||
)
|
||||
|
||||
register_experts(
|
||||
TritonExperts,
|
||||
standard_format,
|
||||
common_float_and_int_types,
|
||||
blocked_quantization_support=True,
|
||||
supports_chunking=True,
|
||||
supports_expert_map=True,
|
||||
needs_matching_quant=True,
|
||||
)
|
||||
|
||||
register_experts(
|
||||
NaiveBatchedExperts,
|
||||
batched_format,
|
||||
common_float_and_int_types,
|
||||
blocked_quantization_support=True,
|
||||
supports_chunking=False,
|
||||
supports_expert_map=True,
|
||||
)
|
||||
|
||||
# Disable on blackwell for now
|
||||
if has_deep_ep() and not current_platform.has_device_capability(100):
|
||||
from vllm.model_executor.layers.fused_moe.deepep_ht_prepare_finalize import (
|
||||
DeepEPHTPrepareAndFinalize,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.deepep_ll_prepare_finalize import (
|
||||
DeepEPLLPrepareAndFinalize,
|
||||
)
|
||||
|
||||
register_prepare_and_finalize(
|
||||
DeepEPHTPrepareAndFinalize,
|
||||
standard_format,
|
||||
common_float_types,
|
||||
blocked_quantization_support=True,
|
||||
backend="deepep_high_throughput",
|
||||
)
|
||||
|
||||
register_prepare_and_finalize(
|
||||
DeepEPLLPrepareAndFinalize,
|
||||
batched_format,
|
||||
common_float_types,
|
||||
blocked_quantization_support=True,
|
||||
backend="deepep_low_latency",
|
||||
)
|
||||
|
||||
if has_pplx():
|
||||
from vllm.model_executor.layers.fused_moe.pplx_prepare_finalize import (
|
||||
PplxPrepareAndFinalize,
|
||||
)
|
||||
|
||||
register_prepare_and_finalize(
|
||||
PplxPrepareAndFinalize,
|
||||
batched_format,
|
||||
common_float_and_int_types,
|
||||
blocked_quantization_support=True,
|
||||
backend="pplx",
|
||||
)
|
||||
|
||||
if has_flashinfer_cutlass_fused_moe() and current_platform.has_device_capability(100):
|
||||
from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import (
|
||||
FlashInferExperts,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_prepare_finalize import ( # noqa: E501
|
||||
FlashInferCutlassMoEPrepareAndFinalize,
|
||||
create_flashinfer_prepare_finalize,
|
||||
)
|
||||
|
||||
register_prepare_and_finalize(
|
||||
FlashInferCutlassMoEPrepareAndFinalize,
|
||||
standard_format,
|
||||
nvfp4_types + fp8_types,
|
||||
blocked_quantization_support=True,
|
||||
backend=None,
|
||||
force_multigpu=True,
|
||||
supports_apply_weight_on_input=False,
|
||||
)
|
||||
|
||||
register_experts(
|
||||
FlashInferExperts,
|
||||
standard_format,
|
||||
nvfp4_types + fp8_types,
|
||||
blocked_quantization_support=True,
|
||||
supports_chunking=True,
|
||||
# Note: this is a hack to get it to run for now
|
||||
supports_expert_map=True,
|
||||
)
|
||||
else:
|
||||
FlashInferCutlassMoEPrepareAndFinalize = None
|
||||
|
||||
if has_deep_gemm() and is_deep_gemm_supported():
|
||||
register_experts(
|
||||
BatchedDeepGemmExperts,
|
||||
batched_format,
|
||||
fp8_types,
|
||||
blocked_quantization_support=True,
|
||||
supports_chunking=False,
|
||||
supports_expert_map=False,
|
||||
needs_matching_quant=False,
|
||||
needs_deep_gemm=True,
|
||||
)
|
||||
register_experts(
|
||||
DeepGemmExperts,
|
||||
standard_format,
|
||||
fp8_types,
|
||||
blocked_quantization_support=True,
|
||||
supports_chunking=True,
|
||||
supports_expert_map=True,
|
||||
needs_matching_quant=False,
|
||||
needs_deep_gemm=True,
|
||||
)
|
||||
register_experts(
|
||||
TritonOrDeepGemmExperts,
|
||||
standard_format,
|
||||
common_float_and_int_types,
|
||||
blocked_quantization_support=True,
|
||||
supports_chunking=True,
|
||||
supports_expert_map=True,
|
||||
needs_matching_quant=True,
|
||||
needs_deep_gemm=True,
|
||||
)
|
||||
|
||||
if cutlass_fp8_supported():
|
||||
from vllm.model_executor.layers.fused_moe import (
|
||||
CutlassBatchedExpertsFp8,
|
||||
CutlassExpertsFp8,
|
||||
)
|
||||
|
||||
register_experts(
|
||||
CutlassExpertsFp8,
|
||||
standard_format,
|
||||
fp8_types,
|
||||
blocked_quantization_support=False,
|
||||
supports_chunking=True,
|
||||
supports_expert_map=False,
|
||||
)
|
||||
register_experts(
|
||||
CutlassBatchedExpertsFp8,
|
||||
batched_format,
|
||||
fp8_types,
|
||||
blocked_quantization_support=False,
|
||||
supports_chunking=False,
|
||||
supports_expert_map=False,
|
||||
)
|
||||
|
||||
if cutlass_fp4_supported():
|
||||
from vllm.model_executor.layers.fused_moe.cutlass_moe import CutlassExpertsFp4
|
||||
|
||||
register_experts(
|
||||
CutlassExpertsFp4,
|
||||
standard_format,
|
||||
nvfp4_types,
|
||||
blocked_quantization_support=True,
|
||||
supports_chunking=True,
|
||||
supports_expert_map=False,
|
||||
)
|
||||
|
||||
MK_QUANT_CONFIGS: list[TestMoEQuantConfig | None] = [
|
||||
None,
|
||||
# per-channel / per-column weights and per-tensor activations
|
||||
TestMoEQuantConfig(
|
||||
quant_dtype=torch.float8_e4m3fn,
|
||||
per_out_ch_quant=True,
|
||||
per_act_token_quant=False,
|
||||
block_shape=None,
|
||||
),
|
||||
# per-channel / per-column weights and per-token activations
|
||||
TestMoEQuantConfig(
|
||||
quant_dtype=torch.float8_e4m3fn,
|
||||
per_out_ch_quant=True,
|
||||
per_act_token_quant=True,
|
||||
block_shape=None,
|
||||
),
|
||||
# per-tensor weights and per-tensor activations
|
||||
TestMoEQuantConfig(
|
||||
quant_dtype=torch.float8_e4m3fn,
|
||||
per_out_ch_quant=False,
|
||||
per_act_token_quant=False,
|
||||
block_shape=None,
|
||||
),
|
||||
# per-tensor weights and per-token activations
|
||||
TestMoEQuantConfig(
|
||||
quant_dtype=torch.float8_e4m3fn,
|
||||
per_out_ch_quant=False,
|
||||
per_act_token_quant=True,
|
||||
block_shape=None,
|
||||
),
|
||||
# block-quantized weights and 128 block per-token activations
|
||||
TestMoEQuantConfig(
|
||||
quant_dtype=torch.float8_e4m3fn,
|
||||
per_out_ch_quant=False,
|
||||
per_act_token_quant=False,
|
||||
block_shape=[128, 128],
|
||||
),
|
||||
# TODO (varun) : Should we test the following combinations ?
|
||||
# block-quantized weights and per-token activations
|
||||
# block-quantized weights and per-tensor activations
|
||||
]
|
||||
|
||||
if cutlass_fp4_supported() or has_flashinfer_cutlass_fused_moe():
|
||||
MK_QUANT_CONFIGS += [
|
||||
TestMoEQuantConfig(
|
||||
quant_dtype="nvfp4",
|
||||
per_out_ch_quant=False,
|
||||
per_act_token_quant=False,
|
||||
block_shape=None,
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
def make_prepare_finalize(
|
||||
prepare_finalize_type: mk.FusedMoEPrepareAndFinalize,
|
||||
backend: str | None,
|
||||
moe: FusedMoEConfig,
|
||||
quant_config: FusedMoEQuantConfig,
|
||||
) -> mk.FusedMoEPrepareAndFinalize:
|
||||
if backend != "naive" and backend is not None:
|
||||
prepare_finalize = maybe_make_prepare_finalize(moe, quant_config)
|
||||
assert prepare_finalize is not None
|
||||
return prepare_finalize
|
||||
elif prepare_finalize_type == FlashInferCutlassMoEPrepareAndFinalize:
|
||||
return create_flashinfer_prepare_finalize(
|
||||
use_dp=moe.moe_parallel_config.dp_size > 1
|
||||
)
|
||||
else:
|
||||
return MoEPrepareAndFinalizeNoEP()
|
||||
|
||||
|
||||
def _slice(rank: int, num_local_experts: int, t: torch.Tensor) -> torch.Tensor:
|
||||
s = rank * num_local_experts
|
||||
e = s + num_local_experts
|
||||
return t[s:e]
|
||||
|
||||
|
||||
def make_cutlass_strides(
|
||||
e: int,
|
||||
n: int,
|
||||
k: int,
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
ab_strides1 = torch.full((e,), k, device="cuda", dtype=torch.int64)
|
||||
ab_strides2 = torch.full((e,), n, device="cuda", dtype=torch.int64)
|
||||
c_strides1 = torch.full((e,), 2 * n, device="cuda", dtype=torch.int64)
|
||||
c_strides2 = torch.full((e,), k, device="cuda", dtype=torch.int64)
|
||||
return ab_strides1, ab_strides2, c_strides1, c_strides2
|
||||
|
||||
|
||||
def make_fused_experts(
|
||||
fused_experts_type: mk.FusedMoEPermuteExpertsUnpermute,
|
||||
moe: FusedMoEConfig,
|
||||
quant_config: FusedMoEQuantConfig,
|
||||
num_dispatchers: int,
|
||||
N: int,
|
||||
) -> mk.FusedMoEPermuteExpertsUnpermute:
|
||||
batch_kwargs = {
|
||||
"max_num_tokens": moe.max_num_tokens,
|
||||
"num_dispatchers": num_dispatchers,
|
||||
}
|
||||
quant_kwargs = {
|
||||
"quant_config": quant_config,
|
||||
}
|
||||
deepgemm_kwargs = {"allow_deep_gemm": has_deep_gemm()}
|
||||
|
||||
torch.set_printoptions(threshold=0, edgeitems=0, linewidth=10000)
|
||||
|
||||
if fused_experts_type == BatchedDeepGemmExperts:
|
||||
kwargs = batch_kwargs | quant_kwargs
|
||||
print(f"Making BatchedDeepGemmExperts {kwargs} ...")
|
||||
experts = BatchedDeepGemmExperts(**kwargs)
|
||||
elif fused_experts_type == BatchedTritonExperts:
|
||||
kwargs = batch_kwargs | quant_kwargs
|
||||
print(f"Making BatchedTritonExperts {kwargs} ...")
|
||||
experts = BatchedTritonExperts(**kwargs)
|
||||
elif fused_experts_type == DeepGemmExperts:
|
||||
print(f"Making DeepGemmExperts {quant_config} ...")
|
||||
experts = DeepGemmExperts(quant_config)
|
||||
elif fused_experts_type == TritonExperts:
|
||||
kwargs = quant_kwargs
|
||||
print(f"Making TritonExperts {kwargs} ...")
|
||||
experts = TritonExperts(**kwargs)
|
||||
elif fused_experts_type == TritonOrDeepGemmExperts:
|
||||
kwargs = quant_kwargs | deepgemm_kwargs
|
||||
print(f"Making TritonOrDeepGemmExperts {kwargs} ...")
|
||||
experts = TritonOrDeepGemmExperts(**kwargs)
|
||||
elif fused_experts_type == NaiveBatchedExperts:
|
||||
kwargs = batch_kwargs | quant_kwargs
|
||||
print(f"Making NaiveBatchedExperts {kwargs} ...")
|
||||
experts = NaiveBatchedExperts(**kwargs)
|
||||
elif fused_experts_type == CutlassExpertsFp8:
|
||||
strides = make_cutlass_strides(moe.num_experts, N, moe.hidden_dim)
|
||||
kwargs = {
|
||||
"out_dtype": moe.in_dtype,
|
||||
"ab_strides1": strides[0],
|
||||
"ab_strides2": strides[1],
|
||||
"c_strides1": strides[2],
|
||||
"c_strides2": strides[3],
|
||||
} | quant_kwargs
|
||||
print(f"Making CutlassExpertsFp8 {kwargs} ...")
|
||||
experts = CutlassExpertsFp8(**kwargs)
|
||||
elif fused_experts_type == CutlassBatchedExpertsFp8:
|
||||
strides = make_cutlass_strides(moe.num_experts, N, moe.hidden_dim)
|
||||
kwargs = {
|
||||
"max_experts_per_worker": moe.num_local_experts,
|
||||
"num_dispatchers": num_dispatchers,
|
||||
"out_dtype": moe.in_dtype,
|
||||
"ab_strides1": strides[0],
|
||||
"ab_strides2": strides[1],
|
||||
"c_strides1": strides[2],
|
||||
"c_strides2": strides[3],
|
||||
} | quant_kwargs
|
||||
print(f"Making CutlassBatchedExpertsFp8 {kwargs} ...")
|
||||
experts = CutlassBatchedExpertsFp8(**kwargs)
|
||||
elif fused_experts_type == CutlassExpertsFp4:
|
||||
kwargs = {
|
||||
"max_experts_per_worker": moe.num_local_experts,
|
||||
"num_dispatchers": num_dispatchers,
|
||||
"out_dtype": moe.in_dtype,
|
||||
} | quant_kwargs
|
||||
print(f"Making CutlassExpertsFp4 {kwargs} ...")
|
||||
experts = CutlassExpertsFp4(**kwargs)
|
||||
elif fused_experts_type == FlashInferExperts:
|
||||
kwargs = {
|
||||
"out_dtype": moe.in_dtype,
|
||||
"ep_rank": moe.ep_rank,
|
||||
"ep_size": moe.ep_size,
|
||||
"tp_rank": moe.tp_rank,
|
||||
"tp_size": moe.tp_size,
|
||||
} | quant_kwargs
|
||||
print(f"Making FlashInferExperts {kwargs} ...")
|
||||
experts = FlashInferExperts(**kwargs)
|
||||
else:
|
||||
raise RuntimeError(f"Unknown fused experts type: {fused_experts_type}")
|
||||
|
||||
torch.set_printoptions(threshold=1000, edgeitems=5, linewidth=80)
|
||||
|
||||
return experts
|
||||
134
tests/kernels/moe/modular_kernel_tools/parallel_utils.py
Normal file
134
tests/kernels/moe/modular_kernel_tools/parallel_utils.py
Normal file
@@ -0,0 +1,134 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import dataclasses
|
||||
import os
|
||||
import traceback
|
||||
from collections.abc import Callable
|
||||
from typing import Any, Concatenate
|
||||
|
||||
import torch
|
||||
from torch.multiprocessing import spawn # pyright: ignore[reportPrivateImportUsage]
|
||||
from typing_extensions import ParamSpec
|
||||
|
||||
from vllm.config import VllmConfig, set_current_vllm_config
|
||||
from vllm.distributed import init_distributed_environment, initialize_model_parallel
|
||||
from vllm.utils.network_utils import get_open_port
|
||||
|
||||
## Parallel Processes Utils
|
||||
|
||||
P = ParamSpec("P")
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class ProcessGroupInfo:
|
||||
world_size: int
|
||||
world_local_size: int
|
||||
rank: int
|
||||
node_rank: int
|
||||
local_rank: int
|
||||
device: torch.device
|
||||
|
||||
|
||||
def _set_vllm_config(
|
||||
vllm_config: VllmConfig, world_size: int, rank: int, local_rank: int
|
||||
):
|
||||
import tempfile
|
||||
|
||||
temp_file = tempfile.mkstemp()[1]
|
||||
|
||||
with set_current_vllm_config(vllm_config):
|
||||
init_distributed_environment(
|
||||
world_size=world_size,
|
||||
rank=rank,
|
||||
distributed_init_method=f"file://{temp_file}",
|
||||
local_rank=local_rank,
|
||||
backend="nccl",
|
||||
)
|
||||
|
||||
initialize_model_parallel(
|
||||
tensor_model_parallel_size=vllm_config.parallel_config.tensor_parallel_size,
|
||||
pipeline_model_parallel_size=vllm_config.parallel_config.pipeline_parallel_size,
|
||||
)
|
||||
cpu_group = torch.distributed.new_group(list(range(world_size)), backend="gloo")
|
||||
return cpu_group
|
||||
|
||||
|
||||
def _worker_parallel_launch(
|
||||
local_rank: int,
|
||||
world_size: int,
|
||||
world_local_size: int,
|
||||
node_rank: int,
|
||||
init_method: str,
|
||||
worker: Callable[Concatenate[ProcessGroupInfo, VllmConfig | None, Any, P], None],
|
||||
vllm_config: VllmConfig | None,
|
||||
env_dict: dict | None,
|
||||
*args: P.args,
|
||||
**kwargs: P.kwargs,
|
||||
) -> None:
|
||||
rank = node_rank * world_local_size + local_rank
|
||||
torch.cuda.set_device(local_rank)
|
||||
device = torch.device("cuda", local_rank)
|
||||
torch.distributed.init_process_group(
|
||||
backend="cpu:gloo,cuda:nccl",
|
||||
init_method=init_method,
|
||||
rank=rank,
|
||||
world_size=world_size,
|
||||
device_id=device,
|
||||
)
|
||||
barrier = torch.tensor([rank], device=device)
|
||||
torch.distributed.all_reduce(barrier)
|
||||
|
||||
if env_dict is not None:
|
||||
os.environ.update(env_dict)
|
||||
|
||||
cpu_group = None
|
||||
if vllm_config is not None:
|
||||
cpu_group = _set_vllm_config(vllm_config, world_size, rank, local_rank)
|
||||
|
||||
try:
|
||||
worker(
|
||||
ProcessGroupInfo(
|
||||
world_size=world_size,
|
||||
world_local_size=world_local_size,
|
||||
rank=rank,
|
||||
node_rank=node_rank,
|
||||
local_rank=local_rank,
|
||||
device=device,
|
||||
),
|
||||
vllm_config,
|
||||
cpu_group,
|
||||
*args,
|
||||
**kwargs,
|
||||
)
|
||||
except Exception as ex:
|
||||
print(ex)
|
||||
traceback.print_exc()
|
||||
raise
|
||||
finally:
|
||||
torch.distributed.destroy_process_group()
|
||||
|
||||
|
||||
def parallel_launch_with_config(
|
||||
world_size: int,
|
||||
worker: Callable[Concatenate[ProcessGroupInfo, VllmConfig, Any, P], None],
|
||||
vllm_config: VllmConfig,
|
||||
env_dict: dict[Any, Any],
|
||||
*args: P.args,
|
||||
**kwargs: P.kwargs,
|
||||
) -> None:
|
||||
assert not kwargs
|
||||
spawn(
|
||||
_worker_parallel_launch,
|
||||
args=(
|
||||
world_size,
|
||||
world_size,
|
||||
0,
|
||||
f"tcp://{os.getenv('LOCALHOST', 'localhost')}:{get_open_port()}",
|
||||
worker,
|
||||
vllm_config,
|
||||
env_dict,
|
||||
)
|
||||
+ args,
|
||||
nprocs=world_size,
|
||||
join=True,
|
||||
)
|
||||
137
tests/kernels/moe/modular_kernel_tools/profile_modular_kernel.py
Normal file
137
tests/kernels/moe/modular_kernel_tools/profile_modular_kernel.py
Normal file
@@ -0,0 +1,137 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import copy
|
||||
from collections.abc import Callable
|
||||
from itertools import product
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
from .common import Config, RankTensors, WeightTensors, make_modular_kernel
|
||||
from .parallel_utils import ProcessGroupInfo, parallel_launch_with_config
|
||||
|
||||
|
||||
def do_profile(
|
||||
fn: Callable,
|
||||
fn_kwargs: dict[Any, Any],
|
||||
pgi: ProcessGroupInfo,
|
||||
config: Config,
|
||||
num_warmups: int = 5,
|
||||
):
|
||||
for _ in range(num_warmups):
|
||||
fn(**fn_kwargs)
|
||||
|
||||
with torch.profiler.profile(
|
||||
activities=[
|
||||
torch.profiler.ProfilerActivity.CPU,
|
||||
torch.profiler.ProfilerActivity.CUDA,
|
||||
],
|
||||
with_stack=True,
|
||||
record_shapes=True,
|
||||
) as tprof:
|
||||
fn(**fn_kwargs)
|
||||
torch.cuda.synchronize(torch.cuda.current_device())
|
||||
|
||||
# TODO (varun): Add a descriptive trace file name
|
||||
tprof.export_chrome_trace(
|
||||
f"{config.torch_trace_dir_path}/m{config.M}_{pgi.rank}_trace.json"
|
||||
)
|
||||
|
||||
|
||||
def profile_modular_kernel(
|
||||
pgi: ProcessGroupInfo,
|
||||
vllm_config: VllmConfig,
|
||||
config: Config,
|
||||
weights: WeightTensors,
|
||||
rank_tensors: RankTensors,
|
||||
) -> None:
|
||||
assert isinstance(config.Ms, int)
|
||||
assert isinstance(config.topks, int)
|
||||
|
||||
# weights for rank
|
||||
rank_weights = weights.slice_weights(pgi.rank, config.num_local_experts)
|
||||
|
||||
# make modular kernel
|
||||
mk = make_modular_kernel(config, vllm_config, weights)
|
||||
|
||||
mk_kwargs = {
|
||||
"hidden_states": rank_tensors.hidden_states,
|
||||
"w1": rank_weights.w1,
|
||||
"w2": rank_weights.w2,
|
||||
"topk_weights": rank_tensors.topk_weights,
|
||||
"topk_ids": rank_tensors.topk_ids,
|
||||
"expert_map": rank_tensors.expert_map,
|
||||
"w1_scale": rank_weights.w1_scale,
|
||||
"w2_scale": rank_weights.w2_scale,
|
||||
"a1_scale": rank_tensors.hidden_states_scale,
|
||||
"global_num_experts": config.E,
|
||||
"apply_router_weight_on_input": config.topk == 1,
|
||||
}
|
||||
|
||||
do_profile(mk.forward, mk_kwargs, pgi, config)
|
||||
|
||||
|
||||
def rank_worker(
|
||||
pgi: ProcessGroupInfo,
|
||||
vllm_config: VllmConfig,
|
||||
cpu_group,
|
||||
config: Config,
|
||||
weights: WeightTensors,
|
||||
):
|
||||
current_platform.seed_everything(pgi.rank)
|
||||
|
||||
# sanity check
|
||||
from vllm import envs
|
||||
|
||||
if config.fused_moe_chunk_size is not None:
|
||||
assert config.fused_moe_chunk_size == envs.VLLM_FUSED_MOE_CHUNK_SIZE
|
||||
|
||||
# get weights to this device
|
||||
weights.to_current_device()
|
||||
|
||||
Ms = config.Ms
|
||||
assert isinstance(Ms, list)
|
||||
TOPKs = config.topks
|
||||
assert isinstance(TOPKs, list)
|
||||
|
||||
for m, topk in product(Ms, TOPKs):
|
||||
print(f"Running m={m}, topk={topk} ...")
|
||||
# override m and topk
|
||||
cfgx = copy.deepcopy(config)
|
||||
cfgx.Ms = m
|
||||
cfgx.topks = topk
|
||||
|
||||
# inputs for rank
|
||||
rank_tensors = RankTensors.make(cfgx, pgi)
|
||||
profile_modular_kernel(pgi, vllm_config, cfgx, weights, rank_tensors)
|
||||
|
||||
|
||||
def run(config: Config):
|
||||
weights: WeightTensors = WeightTensors.make(config)
|
||||
vllm_config, env_dict = config.make_env_data()
|
||||
parallel_launch_with_config(
|
||||
config.world_size, rank_worker, vllm_config, env_dict, config, weights
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
from .cli_args import make_config, make_config_arg_parser
|
||||
|
||||
parser = make_config_arg_parser(
|
||||
description=(
|
||||
"Run single prepare-finalize & fused-experts combination test"
|
||||
"Example : python3 -m tests.kernels.moe.modular_kernel_tools.profile_modular_kernel " # noqa: E501
|
||||
"--pf-type PplxPrepareAndFinalize --experts-type BatchedTritonExperts"
|
||||
)
|
||||
)
|
||||
args = parser.parse_args()
|
||||
assert args.torch_trace_dir_path is not None, (
|
||||
"Please pass in a directory to store torch traces"
|
||||
)
|
||||
config = make_config(args)
|
||||
|
||||
run(config)
|
||||
202
tests/kernels/moe/parallel_utils.py
Normal file
202
tests/kernels/moe/parallel_utils.py
Normal file
@@ -0,0 +1,202 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""
|
||||
DeepEP test utilities
|
||||
"""
|
||||
|
||||
import dataclasses
|
||||
import os
|
||||
import traceback
|
||||
from collections.abc import Callable
|
||||
from typing import Concatenate
|
||||
|
||||
import torch
|
||||
from torch.distributed import ProcessGroup
|
||||
from torch.multiprocessing import spawn # pyright: ignore[reportPrivateImportUsage]
|
||||
from typing_extensions import ParamSpec
|
||||
|
||||
from vllm.utils.import_utils import has_deep_ep
|
||||
from vllm.utils.network_utils import get_open_port
|
||||
|
||||
if has_deep_ep():
|
||||
from vllm.model_executor.layers.fused_moe.deepep_ht_prepare_finalize import (
|
||||
DeepEPHTPrepareAndFinalize,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.deepep_ll_prepare_finalize import (
|
||||
DeepEPLLPrepareAndFinalize,
|
||||
)
|
||||
|
||||
## Parallel Processes Utils
|
||||
|
||||
P = ParamSpec("P")
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class ProcessGroupInfo:
|
||||
world_size: int
|
||||
world_local_size: int
|
||||
rank: int
|
||||
node_rank: int
|
||||
local_rank: int
|
||||
device: torch.device
|
||||
|
||||
|
||||
def _worker_parallel_launch(
|
||||
local_rank: int,
|
||||
world_size: int,
|
||||
world_local_size: int,
|
||||
node_rank: int,
|
||||
init_method: str,
|
||||
worker: Callable[Concatenate[ProcessGroupInfo, P], None],
|
||||
*args: P.args,
|
||||
**kwargs: P.kwargs,
|
||||
) -> None:
|
||||
rank = node_rank * world_local_size + local_rank
|
||||
torch.cuda.set_device(local_rank)
|
||||
device = torch.device("cuda", local_rank)
|
||||
torch.distributed.init_process_group(
|
||||
backend="cpu:gloo,cuda:nccl",
|
||||
init_method=init_method,
|
||||
rank=rank,
|
||||
world_size=world_size,
|
||||
device_id=device,
|
||||
)
|
||||
barrier = torch.tensor([rank], device=device)
|
||||
torch.distributed.all_reduce(barrier)
|
||||
|
||||
try:
|
||||
worker(
|
||||
ProcessGroupInfo(
|
||||
world_size=world_size,
|
||||
world_local_size=world_local_size,
|
||||
rank=rank,
|
||||
node_rank=node_rank,
|
||||
local_rank=local_rank,
|
||||
device=device,
|
||||
),
|
||||
*args,
|
||||
**kwargs,
|
||||
)
|
||||
except Exception as ex:
|
||||
print(ex)
|
||||
traceback.print_exc()
|
||||
raise
|
||||
finally:
|
||||
torch.distributed.destroy_process_group()
|
||||
|
||||
|
||||
def parallel_launch(
|
||||
world_size: int,
|
||||
worker: Callable[Concatenate[ProcessGroupInfo, P], None],
|
||||
*args: P.args,
|
||||
**kwargs: P.kwargs,
|
||||
) -> None:
|
||||
assert not kwargs
|
||||
spawn(
|
||||
_worker_parallel_launch,
|
||||
args=(
|
||||
world_size,
|
||||
world_size,
|
||||
0,
|
||||
f"tcp://{os.getenv('LOCALHOST', 'localhost')}:{get_open_port()}",
|
||||
worker,
|
||||
)
|
||||
+ args,
|
||||
nprocs=world_size,
|
||||
join=True,
|
||||
)
|
||||
|
||||
|
||||
## DeepEP specific utils
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class DeepEPHTArgs:
|
||||
num_local_experts: int
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class DeepEPLLArgs:
|
||||
max_tokens_per_rank: int
|
||||
hidden_size: int
|
||||
num_experts: int
|
||||
use_fp8_dispatch: bool
|
||||
|
||||
|
||||
def make_deepep_ht_a2a(
|
||||
pg: ProcessGroup,
|
||||
pgi: ProcessGroupInfo,
|
||||
dp_size: int,
|
||||
ht_args: DeepEPHTArgs,
|
||||
q_dtype: torch.dtype | None = None,
|
||||
block_shape: list[int] | None = None,
|
||||
):
|
||||
import deep_ep
|
||||
|
||||
# high throughput a2a
|
||||
num_nvl_bytes = 1024 * 1024 * 1024 # 1GB
|
||||
num_rdma_bytes, low_latency_mode, num_qps_per_rank = 0, False, 1
|
||||
buffer = deep_ep.Buffer(
|
||||
group=pg,
|
||||
num_nvl_bytes=num_nvl_bytes,
|
||||
num_rdma_bytes=num_rdma_bytes,
|
||||
low_latency_mode=low_latency_mode,
|
||||
num_qps_per_rank=num_qps_per_rank,
|
||||
)
|
||||
return DeepEPHTPrepareAndFinalize(
|
||||
buffer=buffer,
|
||||
num_dispatchers=pgi.world_size,
|
||||
dp_size=dp_size,
|
||||
rank_expert_offset=pgi.rank * ht_args.num_local_experts,
|
||||
)
|
||||
|
||||
|
||||
def make_deepep_ll_a2a(
|
||||
pg: ProcessGroup,
|
||||
pgi: ProcessGroupInfo,
|
||||
deepep_ll_args: DeepEPLLArgs,
|
||||
q_dtype: torch.dtype | None = None,
|
||||
block_shape: list[int] | None = None,
|
||||
):
|
||||
import deep_ep
|
||||
|
||||
# low-latency a2a
|
||||
num_rdma_bytes = deep_ep.Buffer.get_low_latency_rdma_size_hint(
|
||||
deepep_ll_args.max_tokens_per_rank,
|
||||
deepep_ll_args.hidden_size,
|
||||
pgi.world_size,
|
||||
deepep_ll_args.num_experts,
|
||||
)
|
||||
|
||||
buffer = deep_ep.Buffer(
|
||||
group=pg,
|
||||
num_rdma_bytes=num_rdma_bytes,
|
||||
low_latency_mode=True,
|
||||
num_qps_per_rank=deepep_ll_args.num_experts // pgi.world_size,
|
||||
)
|
||||
|
||||
return DeepEPLLPrepareAndFinalize(
|
||||
buffer=buffer,
|
||||
num_dispatchers=pgi.world_size,
|
||||
max_tokens_per_rank=deepep_ll_args.max_tokens_per_rank,
|
||||
use_fp8_dispatch=deepep_ll_args.use_fp8_dispatch,
|
||||
)
|
||||
|
||||
|
||||
def make_deepep_a2a(
|
||||
pg: ProcessGroup,
|
||||
pgi: ProcessGroupInfo,
|
||||
dp_size: int,
|
||||
deepep_ht_args: DeepEPHTArgs | None,
|
||||
deepep_ll_args: DeepEPLLArgs | None,
|
||||
q_dtype: torch.dtype | None = None,
|
||||
block_shape: list[int] | None = None,
|
||||
):
|
||||
if deepep_ht_args is not None:
|
||||
assert deepep_ll_args is None
|
||||
return make_deepep_ht_a2a(
|
||||
pg, pgi, dp_size, deepep_ht_args, q_dtype, block_shape
|
||||
)
|
||||
|
||||
assert deepep_ll_args is not None
|
||||
return make_deepep_ll_a2a(pg, pgi, deepep_ll_args, q_dtype, block_shape)
|
||||
106
tests/kernels/moe/test_batched_deepgemm.py
Normal file
106
tests/kernels/moe/test_batched_deepgemm.py
Normal file
@@ -0,0 +1,106 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe import (
|
||||
BatchedDeepGemmExperts,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.config import fp8_w8a8_moe_quant_config
|
||||
from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
|
||||
BatchedPrepareAndFinalize,
|
||||
BatchedTritonExperts,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.modular_kernel import FusedMoEModularKernel
|
||||
from vllm.utils.deep_gemm import calc_diff, is_deep_gemm_supported
|
||||
|
||||
from .test_deepgemm import make_block_quant_fp8_weights
|
||||
|
||||
BLOCK_SIZE = [128, 128]
|
||||
|
||||
|
||||
@pytest.mark.skipif(not is_deep_gemm_supported(), reason="Requires deep_gemm kernels")
|
||||
@pytest.mark.parametrize("E", [16, 32]) # number of experts
|
||||
@pytest.mark.parametrize("T", [256, 512]) # tokens per expert
|
||||
@pytest.mark.parametrize("K", [128, 256]) # hidden dim
|
||||
@pytest.mark.parametrize("N", [512, 1024]) # intermediate dim per expert
|
||||
@pytest.mark.parametrize("topk", [2, 4])
|
||||
def test_batched_deepgemm_vs_triton(
|
||||
E: int, T: int, K: int, N: int, topk: int, monkeypatch, workspace_init
|
||||
):
|
||||
"""Compare BatchedDeepGemmExperts to BatchedTritonExperts."""
|
||||
|
||||
monkeypatch.setenv("VLLM_USE_DEEP_GEMM", "1")
|
||||
|
||||
device = "cuda"
|
||||
w1, w2, w1_s, w2_s = make_block_quant_fp8_weights(E, N, K, BLOCK_SIZE)
|
||||
|
||||
M = E * T # total tokens
|
||||
a = torch.randn(M, K, device=device, dtype=torch.bfloat16) / 10.0
|
||||
fp8_info = torch.finfo(torch.float8_e4m3fn)
|
||||
a.clamp_(fp8_info.min, fp8_info.max)
|
||||
|
||||
# random router outputs → top-k indices / weights
|
||||
router_logits = torch.randn(M, E, device=device, dtype=torch.float32)
|
||||
topk_weights, topk_ids = torch.topk(router_logits, k=topk, dim=-1)
|
||||
topk_weights = torch.nn.functional.softmax(topk_weights, dim=-1)
|
||||
|
||||
# token number for each expert
|
||||
cnt = torch.bincount(topk_ids.flatten(), minlength=E)
|
||||
max_cnt = int(cnt.max().item())
|
||||
# next power of 2 for max token number
|
||||
max_num_tokens = 1 << (max_cnt - 1).bit_length()
|
||||
|
||||
prep_finalize = BatchedPrepareAndFinalize(
|
||||
max_num_tokens=max_num_tokens,
|
||||
num_local_experts=E,
|
||||
num_dispatchers=1,
|
||||
rank=0,
|
||||
)
|
||||
|
||||
quant_config = fp8_w8a8_moe_quant_config(
|
||||
w1_scale=w1_s,
|
||||
w2_scale=w2_s,
|
||||
per_act_token_quant=False,
|
||||
block_shape=BLOCK_SIZE,
|
||||
)
|
||||
|
||||
# triton (reference)
|
||||
triton_experts = BatchedTritonExperts(
|
||||
max_num_tokens=max_num_tokens,
|
||||
num_dispatchers=1,
|
||||
quant_config=quant_config,
|
||||
)
|
||||
mk_triton = FusedMoEModularKernel(prep_finalize, triton_experts)
|
||||
|
||||
out_triton = mk_triton(
|
||||
hidden_states=a,
|
||||
w1=w1,
|
||||
w2=w2,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
inplace=False,
|
||||
global_num_experts=E,
|
||||
)
|
||||
|
||||
# deepgemm
|
||||
deepgemm_experts = BatchedDeepGemmExperts(
|
||||
max_num_tokens=max_num_tokens,
|
||||
num_dispatchers=1,
|
||||
quant_config=quant_config,
|
||||
)
|
||||
mk_deepgemm = FusedMoEModularKernel(prep_finalize, deepgemm_experts)
|
||||
|
||||
out_deepgemm = mk_deepgemm(
|
||||
hidden_states=a,
|
||||
w1=w1,
|
||||
w2=w2,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
inplace=False,
|
||||
global_num_experts=E,
|
||||
)
|
||||
|
||||
diff = calc_diff(out_deepgemm, out_triton)
|
||||
assert diff < 1e-3, f"Output diff too large: {diff}"
|
||||
352
tests/kernels/moe/test_batched_moe.py
Normal file
352
tests/kernels/moe/test_batched_moe.py
Normal file
@@ -0,0 +1,352 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from tests.kernels.moe.utils import (
|
||||
batched_moe,
|
||||
make_quantized_test_activations,
|
||||
make_test_weights,
|
||||
naive_batched_moe,
|
||||
)
|
||||
from tests.kernels.quant_utils import native_batched_masked_quant_matmul
|
||||
from tests.kernels.utils import torch_experts
|
||||
from vllm.config import VllmConfig, set_current_vllm_config
|
||||
from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
|
||||
invoke_moe_batched_triton_kernel,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.triton_utils import tl
|
||||
|
||||
MNK_FACTORS = [
|
||||
(1, 128, 128),
|
||||
(1, 512, 512),
|
||||
(1, 1024, 2048),
|
||||
(32, 128, 128),
|
||||
(32, 512, 512),
|
||||
(32, 1024, 2048),
|
||||
(45, 128, 2048),
|
||||
(45, 1024, 128),
|
||||
(64, 512, 512),
|
||||
(64, 1024, 2048),
|
||||
(222, 128, 2048),
|
||||
(222, 1024, 2048),
|
||||
]
|
||||
NUM_EXPERTS = [8, 64]
|
||||
TOP_KS = [1, 2, 6]
|
||||
|
||||
DTYPES = [torch.bfloat16]
|
||||
|
||||
if not current_platform.is_fp8_fnuz():
|
||||
DTYPES.append(torch.float8_e4m3fn)
|
||||
|
||||
vllm_config = VllmConfig()
|
||||
|
||||
|
||||
@dataclass
|
||||
class BatchedMMConfig:
|
||||
in_dtype: torch.dtype
|
||||
quant_dtype: torch.dtype | None
|
||||
out_dtype: torch.dtype
|
||||
num_experts: int
|
||||
max_tokens_per_expert: int
|
||||
K: int
|
||||
N: int
|
||||
|
||||
|
||||
@dataclass
|
||||
class BatchedMMTensors:
|
||||
A: torch.Tensor # [E, max_tokens, K]
|
||||
B: torch.Tensor # [E, K, N] - column major
|
||||
C: torch.Tensor # [E, max_tokens, N]
|
||||
num_expert_tokens: torch.Tensor # [E]
|
||||
|
||||
@staticmethod
|
||||
def make_tensors(config: BatchedMMConfig):
|
||||
A = (
|
||||
torch.randn(
|
||||
(config.num_experts, config.max_tokens_per_expert, config.K),
|
||||
device="cuda",
|
||||
dtype=config.in_dtype,
|
||||
)
|
||||
/ 10
|
||||
)
|
||||
B = torch.randn(
|
||||
(config.num_experts, config.N, config.K),
|
||||
device="cuda",
|
||||
dtype=config.in_dtype,
|
||||
)
|
||||
C = torch.zeros(
|
||||
(config.num_experts, config.max_tokens_per_expert, config.N),
|
||||
device="cuda",
|
||||
dtype=config.out_dtype,
|
||||
)
|
||||
|
||||
num_expert_tokens = torch.randint(
|
||||
low=0,
|
||||
high=config.max_tokens_per_expert,
|
||||
size=(config.num_experts,),
|
||||
device="cuda",
|
||||
dtype=torch.int32,
|
||||
)
|
||||
|
||||
return BatchedMMTensors(A, B, C, num_expert_tokens)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("num_experts", [8, 32])
|
||||
@pytest.mark.parametrize("max_tokens_per_expert", [32, 224, 512])
|
||||
@pytest.mark.parametrize("K", [128, 1024])
|
||||
@pytest.mark.parametrize("N", [128, 1024])
|
||||
@pytest.mark.parametrize("dtype", DTYPES)
|
||||
@pytest.mark.parametrize("block_shape", [None, [128, 128]])
|
||||
@pytest.mark.parametrize("per_act_token_quant", [False, True])
|
||||
def test_batched_mm(
|
||||
num_experts: int,
|
||||
max_tokens_per_expert: int,
|
||||
K: int,
|
||||
N: int,
|
||||
dtype: torch.dtype,
|
||||
block_shape: list[int] | None,
|
||||
per_act_token_quant: bool,
|
||||
):
|
||||
"""Note: float8_e4m3fn is not supported on CUDA architecture < 89,
|
||||
and those tests will be skipped on unsupported hardware."""
|
||||
current_platform.seed_everything(7)
|
||||
|
||||
use_fp8_w8a8 = dtype == torch.float8_e4m3fn
|
||||
|
||||
if (dtype == torch.float8_e4m3fn) and not current_platform.has_device_capability(
|
||||
89
|
||||
):
|
||||
pytest.skip(
|
||||
"Triton limitation: fp8e4nv data type is not supported on CUDA arch < 89"
|
||||
)
|
||||
|
||||
if (per_act_token_quant or block_shape is not None) and not use_fp8_w8a8:
|
||||
pytest.skip("Don't test blocking for non-quantized types.")
|
||||
|
||||
if per_act_token_quant and block_shape is not None:
|
||||
pytest.skip("Skip illegal quantization test.")
|
||||
|
||||
if dtype.itemsize == 1:
|
||||
act_dtype = torch.bfloat16
|
||||
quant_dtype = dtype
|
||||
else:
|
||||
act_dtype = dtype
|
||||
quant_dtype = None
|
||||
|
||||
num_expert_tokens = torch.randint(
|
||||
low=0,
|
||||
high=max_tokens_per_expert,
|
||||
size=(num_experts,),
|
||||
device="cuda",
|
||||
dtype=torch.int32,
|
||||
)
|
||||
|
||||
A, A_q, A_scale = make_quantized_test_activations(
|
||||
num_experts,
|
||||
max_tokens_per_expert,
|
||||
K,
|
||||
in_dtype=act_dtype,
|
||||
quant_dtype=quant_dtype,
|
||||
block_shape=block_shape,
|
||||
per_act_token_quant=per_act_token_quant,
|
||||
)
|
||||
|
||||
(B, B_q, B_scale, _), _ = make_test_weights(
|
||||
num_experts,
|
||||
N // 2,
|
||||
K,
|
||||
in_dtype=act_dtype,
|
||||
quant_dtype=quant_dtype,
|
||||
block_shape=block_shape,
|
||||
per_out_ch_quant=per_act_token_quant,
|
||||
)
|
||||
|
||||
out_shape = (num_experts, max_tokens_per_expert, N)
|
||||
test_output = torch.zeros(out_shape, dtype=act_dtype, device="cuda")
|
||||
ref_output = torch.zeros(out_shape, dtype=act_dtype, device="cuda")
|
||||
q_ref_output = torch.zeros(out_shape, dtype=act_dtype, device="cuda")
|
||||
|
||||
compute_tl_dtype = {
|
||||
torch.float16: tl.float16,
|
||||
torch.bfloat16: tl.bfloat16,
|
||||
torch.float32: tl.float32,
|
||||
}[test_output.dtype]
|
||||
|
||||
assert A_q.dtype == B_q.dtype
|
||||
|
||||
invoke_moe_batched_triton_kernel(
|
||||
A_q,
|
||||
B_q,
|
||||
test_output,
|
||||
num_expert_tokens,
|
||||
compute_tl_dtype,
|
||||
# Quantization data
|
||||
A_scale,
|
||||
B_scale,
|
||||
None,
|
||||
# Quantization schemes
|
||||
use_fp8_w8a8,
|
||||
False,
|
||||
False,
|
||||
config={
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 16,
|
||||
"BLOCK_SIZE_K": 16 if dtype.itemsize > 1 else 32,
|
||||
},
|
||||
per_act_token_quant=per_act_token_quant,
|
||||
block_shape=block_shape,
|
||||
)
|
||||
|
||||
ref_output = native_batched_masked_quant_matmul(
|
||||
A,
|
||||
B,
|
||||
ref_output,
|
||||
num_expert_tokens,
|
||||
)
|
||||
|
||||
q_ref_output = native_batched_masked_quant_matmul(
|
||||
A_q,
|
||||
B_q,
|
||||
q_ref_output,
|
||||
num_expert_tokens,
|
||||
A_scale,
|
||||
B_scale,
|
||||
block_shape,
|
||||
per_act_token_quant,
|
||||
)
|
||||
|
||||
rtol, atol = {
|
||||
torch.float16: (6e-2, 6e-2),
|
||||
torch.bfloat16: (6e-2, 6e-2),
|
||||
torch.float32: (1e-2, 1e-2),
|
||||
}[test_output.dtype]
|
||||
|
||||
torch.testing.assert_close(ref_output, q_ref_output, atol=atol, rtol=rtol)
|
||||
torch.testing.assert_close(test_output, q_ref_output, atol=atol, rtol=rtol)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(("m", "n", "k"), MNK_FACTORS)
|
||||
@pytest.mark.parametrize("e", NUM_EXPERTS)
|
||||
@pytest.mark.parametrize("topk", TOP_KS)
|
||||
@pytest.mark.parametrize("dtype", DTYPES)
|
||||
@pytest.mark.parametrize("per_act_token_quant", [False, True])
|
||||
@pytest.mark.parametrize("block_shape", [None, [128, 128]])
|
||||
@pytest.mark.parametrize("input_scales", [False])
|
||||
def test_fused_moe_batched_experts(
|
||||
m: int,
|
||||
n: int,
|
||||
k: int,
|
||||
e: int,
|
||||
topk: int,
|
||||
dtype: torch.dtype,
|
||||
per_act_token_quant: bool,
|
||||
block_shape: list[int] | None,
|
||||
input_scales: bool,
|
||||
workspace_init,
|
||||
):
|
||||
"""Note: float8_e4m3fn is not supported on CUDA architecture < 89,
|
||||
and those tests will be skipped on unsupported hardware."""
|
||||
current_platform.seed_everything(7)
|
||||
|
||||
use_fp8_w8a8 = dtype == torch.float8_e4m3fn
|
||||
|
||||
if (dtype == torch.float8_e4m3fn) and not current_platform.has_device_capability(
|
||||
89
|
||||
):
|
||||
pytest.skip(
|
||||
"Triton limitation: fp8e4nv data type is not supported on CUDA arch < 89"
|
||||
)
|
||||
|
||||
if topk > e:
|
||||
pytest.skip("topk > e")
|
||||
|
||||
if not use_fp8_w8a8 and (per_act_token_quant or block_shape is not None):
|
||||
pytest.skip("Skip quantization test for non-quantized type")
|
||||
|
||||
if per_act_token_quant and block_shape is not None:
|
||||
pytest.skip("Skip illegal quantization test.")
|
||||
|
||||
a = torch.randn((m, k), device="cuda", dtype=torch.bfloat16) / 10
|
||||
score = torch.randn((m, e), device="cuda", dtype=torch.bfloat16)
|
||||
|
||||
if dtype.itemsize == 1:
|
||||
act_dtype = torch.bfloat16
|
||||
quant_dtype = dtype
|
||||
else:
|
||||
act_dtype = dtype
|
||||
quant_dtype = None
|
||||
|
||||
(w1_16, w1, w1_s, _), (w2_16, w2, w2_s, _) = make_test_weights(
|
||||
e,
|
||||
n,
|
||||
k,
|
||||
block_shape=block_shape,
|
||||
in_dtype=act_dtype,
|
||||
quant_dtype=quant_dtype,
|
||||
per_out_ch_quant=per_act_token_quant,
|
||||
)
|
||||
|
||||
if input_scales and quant_dtype is not None:
|
||||
a1_scale = torch.tensor(1, device="cuda", dtype=torch.float32)
|
||||
a2_scale = torch.tensor(1, device="cuda", dtype=torch.float32)
|
||||
else:
|
||||
a1_scale = None
|
||||
a2_scale = None
|
||||
|
||||
with set_current_vllm_config(vllm_config):
|
||||
topk_weight, topk_ids, _ = fused_topk(a, score, topk, False)
|
||||
|
||||
baseline_output = torch_experts(
|
||||
a,
|
||||
w1,
|
||||
w2,
|
||||
topk_weight,
|
||||
topk_ids,
|
||||
w1_scale=w1_s,
|
||||
w2_scale=w2_s,
|
||||
a1_scale=a1_scale,
|
||||
a2_scale=a2_scale,
|
||||
quant_dtype=quant_dtype,
|
||||
per_act_token_quant=per_act_token_quant,
|
||||
block_shape=block_shape,
|
||||
)
|
||||
|
||||
batched_output = naive_batched_moe(
|
||||
a,
|
||||
w1,
|
||||
w2,
|
||||
topk_weight,
|
||||
topk_ids,
|
||||
w1_scale=w1_s,
|
||||
w2_scale=w2_s,
|
||||
a1_scale=a1_scale,
|
||||
a2_scale=a2_scale,
|
||||
quant_dtype=quant_dtype,
|
||||
per_act_token_quant=per_act_token_quant,
|
||||
block_shape=block_shape,
|
||||
)
|
||||
|
||||
triton_output = batched_moe(
|
||||
a,
|
||||
w1,
|
||||
w2,
|
||||
topk_weight,
|
||||
topk_ids,
|
||||
w1_scale=w1_s,
|
||||
w2_scale=w2_s,
|
||||
a1_scale=a1_scale,
|
||||
a2_scale=a2_scale,
|
||||
quant_dtype=quant_dtype,
|
||||
per_act_token_quant=per_act_token_quant,
|
||||
block_shape=block_shape,
|
||||
)
|
||||
|
||||
torch.testing.assert_close(batched_output, baseline_output, atol=3e-2, rtol=2e-2)
|
||||
|
||||
torch.testing.assert_close(triton_output, batched_output, atol=2e-2, rtol=2e-2)
|
||||
267
tests/kernels/moe/test_block_fp8.py
Normal file
267
tests/kernels/moe/test_block_fp8.py
Normal file
@@ -0,0 +1,267 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from tests.kernels.moe.utils import make_test_quant_config, make_test_weights
|
||||
from tests.kernels.quant_utils import (
|
||||
native_per_token_group_quant_fp8,
|
||||
native_w8a8_block_matmul,
|
||||
)
|
||||
from vllm.config import VllmConfig, set_current_vllm_config
|
||||
from vllm.model_executor.layers.activation import SiluAndMul
|
||||
from vllm.model_executor.layers.fused_moe import fused_experts
|
||||
from vllm.model_executor.layers.fused_moe.deep_gemm_moe import (
|
||||
_valid_deep_gemm_shape,
|
||||
deep_gemm_moe_fp8,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.fused_moe import (
|
||||
fused_topk,
|
||||
modular_triton_fused_moe,
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.deep_gemm import (
|
||||
get_mk_alignment_for_contiguous_layout,
|
||||
is_deep_gemm_e8m0_used,
|
||||
)
|
||||
from vllm.utils.import_utils import has_deep_gemm
|
||||
|
||||
dg_available = has_deep_gemm()
|
||||
|
||||
if current_platform.get_device_capability() < (9, 0):
|
||||
pytest.skip("FP8 Triton requires CUDA 9.0 or higher", allow_module_level=True)
|
||||
if current_platform.is_fp8_fnuz():
|
||||
pytest.skip(
|
||||
"Tests in this file require float8_e4m3fn and platform does not support",
|
||||
allow_module_level=True,
|
||||
)
|
||||
|
||||
vllm_config = VllmConfig()
|
||||
|
||||
# Test configurations
|
||||
DTYPES = [torch.bfloat16] # [torch.half, torch.bfloat16, torch.float32]
|
||||
# Deepseek-V3's intermediate size 18432, so N is 18432*2/8=4608 at TP8
|
||||
# and its hidden size is 7168.
|
||||
MNK_FACTORS = [
|
||||
(1, 128, 128),
|
||||
(1, 128, 7168),
|
||||
(1, 1024, 7168),
|
||||
(1, 4608, 128),
|
||||
(1, 4608, 7168),
|
||||
(83, 128, 128),
|
||||
(83, 512, 512),
|
||||
(83, 4608, 512),
|
||||
(83, 4608, 7168),
|
||||
(128, 512, 512),
|
||||
(128, 1024, 7168),
|
||||
(128, 4608, 7168),
|
||||
(2048, 128, 128),
|
||||
(2048, 1024, 7168),
|
||||
(2048, 4608, 512),
|
||||
(2048, 4608, 7168),
|
||||
(8192, 128, 128),
|
||||
(8192, 128, 7168),
|
||||
(8192, 1024, 7168),
|
||||
(8192, 4608, 7168),
|
||||
]
|
||||
|
||||
MNK_FACTORS_DG = [
|
||||
(128, 128, 128),
|
||||
(128, 128, 7168),
|
||||
(128, 1024, 7168),
|
||||
(128, 4608, 128),
|
||||
(128, 4608, 7168),
|
||||
(192, 512, 512),
|
||||
(192, 1024, 7168),
|
||||
(192, 4608, 7168),
|
||||
(1335, 128, 128),
|
||||
(1335, 1024, 7168),
|
||||
(1335, 4608, 512),
|
||||
(1335, 4608, 7168),
|
||||
(2048, 128, 128),
|
||||
(2048, 128, 7168),
|
||||
(2048, 1024, 7168),
|
||||
(2048, 4608, 7168),
|
||||
]
|
||||
|
||||
BLOCK_SIZE = [[128, 128]]
|
||||
E = [2, 8, 16] # [128, 256]
|
||||
TOP_KS = [1, 2, 6]
|
||||
SEEDS = [0]
|
||||
|
||||
|
||||
def torch_w8a8_block_fp8_moe(a, w1, w2, w1_s, w2_s, topk_weight, topk_ids, block_shape):
|
||||
"""Fused moe with block-wise quantization using native torch."""
|
||||
B, D = a.shape
|
||||
topk = topk_ids.size(1)
|
||||
a = a.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D)
|
||||
out = torch.zeros(B * topk, w2.shape[1], dtype=a.dtype, device=a.device)
|
||||
|
||||
topk_weight = topk_weight.view(-1)
|
||||
topk_ids = topk_ids.view(-1)
|
||||
|
||||
_, block_k = block_shape[0], block_shape[1]
|
||||
a_q, a_s = native_per_token_group_quant_fp8(a, block_k)
|
||||
a_q = a_q.to(torch.float32)
|
||||
for i in range(w1.shape[0]):
|
||||
mask = topk_ids == i
|
||||
if mask.sum():
|
||||
inter_out = native_w8a8_block_matmul(
|
||||
a_q[mask], w1[i], a_s[mask], w1_s[i], block_shape, output_dtype=a.dtype
|
||||
)
|
||||
act_out = SiluAndMul().forward_native(inter_out)
|
||||
act_out_q, act_out_s = native_per_token_group_quant_fp8(act_out, block_k)
|
||||
out[mask] = native_w8a8_block_matmul(
|
||||
act_out_q, w2[i], act_out_s, w2_s[i], block_shape, output_dtype=a.dtype
|
||||
)
|
||||
return (
|
||||
out.view(B, -1, w2.shape[1]) * topk_weight.view(B, -1, 1).to(out.dtype)
|
||||
).sum(dim=1)
|
||||
|
||||
|
||||
# Skip all tests if CUDA is not available
|
||||
pytest.importorskip("torch.cuda")
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def setup_cuda():
|
||||
torch.set_default_device("cuda")
|
||||
|
||||
|
||||
@pytest.mark.parametrize(("M", "N", "K"), MNK_FACTORS)
|
||||
@pytest.mark.parametrize("E", E)
|
||||
@pytest.mark.parametrize("topk", TOP_KS)
|
||||
@pytest.mark.parametrize("block_size", BLOCK_SIZE)
|
||||
@pytest.mark.parametrize("dtype", DTYPES)
|
||||
@pytest.mark.parametrize("seed", SEEDS)
|
||||
@torch.inference_mode()
|
||||
def test_w8a8_block_fp8_fused_moe(
|
||||
M, N, K, E, topk, block_size, dtype, seed, monkeypatch, workspace_init
|
||||
):
|
||||
if topk > E:
|
||||
pytest.skip(f"Skipping test; topk={topk} > E={E}")
|
||||
|
||||
torch.manual_seed(seed)
|
||||
|
||||
monkeypatch.setenv("VLLM_FUSED_MOE_CHUNK_SIZE", "2048")
|
||||
|
||||
a = torch.randn((M, K), dtype=dtype) / 10
|
||||
score = torch.randn((M, E), dtype=dtype)
|
||||
|
||||
w1, w2, quant_config = make_test_quant_config(
|
||||
E,
|
||||
N,
|
||||
K,
|
||||
dtype,
|
||||
quant_dtype=torch.float8_e4m3fn,
|
||||
per_act_token_quant=False,
|
||||
block_shape=block_size,
|
||||
)
|
||||
|
||||
m_fused_moe = modular_triton_fused_moe(quant_config)
|
||||
|
||||
topk_weights, topk_ids, _ = fused_topk(a, score.float(), topk, False)
|
||||
|
||||
# Set the context to avoid lots of warning spam.
|
||||
with set_current_vllm_config(vllm_config):
|
||||
ref_out = torch_w8a8_block_fp8_moe(
|
||||
a,
|
||||
w1,
|
||||
w2,
|
||||
quant_config.w1_scale,
|
||||
quant_config.w2_scale,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
block_size,
|
||||
)
|
||||
|
||||
out = fused_experts(
|
||||
a, w1, w2, topk_weights, topk_ids, quant_config=quant_config
|
||||
)
|
||||
|
||||
m_out = m_fused_moe(a, w1, w2, topk_weights, topk_ids)
|
||||
|
||||
# 0.039 only needed for M >= 8192
|
||||
tol = 0.035 if M < 8192 else 0.039
|
||||
torch.testing.assert_close(out, ref_out, atol=tol, rtol=tol)
|
||||
torch.testing.assert_close(m_out, ref_out, atol=tol, rtol=tol)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(("M", "N", "K"), MNK_FACTORS_DG)
|
||||
@pytest.mark.parametrize("E", E)
|
||||
@pytest.mark.parametrize("topk", TOP_KS)
|
||||
@pytest.mark.parametrize("seed", SEEDS)
|
||||
@pytest.mark.skipif(not dg_available, reason="DeepGemm kernels not available.")
|
||||
@pytest.mark.skipif(is_deep_gemm_e8m0_used(), reason="Not E8M0 scale MOE")
|
||||
@torch.inference_mode()
|
||||
def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, seed, monkeypatch):
|
||||
if topk > E:
|
||||
pytest.skip(f"Skipping test: topk={topk} > E={E}")
|
||||
|
||||
if not _valid_deep_gemm_shape(M, N, K):
|
||||
pytest.skip(f"Skipping test: invalid size m={M}, n={N}, k={K}")
|
||||
|
||||
chunk_size = 1024
|
||||
|
||||
torch.manual_seed(seed)
|
||||
|
||||
monkeypatch.setenv("VLLM_FUSED_MOE_CHUNK_SIZE", str(chunk_size))
|
||||
block_size = get_mk_alignment_for_contiguous_layout()
|
||||
dtype = torch.bfloat16
|
||||
|
||||
a = torch.randn((M, K), dtype=dtype) / 10
|
||||
score = torch.randn((M, E), dtype=dtype)
|
||||
|
||||
(_, w1, w1_s, _), (_, w2, w2_s, _) = make_test_weights(
|
||||
E,
|
||||
N,
|
||||
K,
|
||||
dtype,
|
||||
torch.float8_e4m3fn,
|
||||
per_out_ch_quant=False,
|
||||
block_shape=block_size,
|
||||
)
|
||||
|
||||
# Note: for now use_compile will error out if the problem size is
|
||||
# large enough to trigger chunking. I'm leaving the flag and
|
||||
# setup code in case we are able to revisit this later.
|
||||
use_compile = False
|
||||
|
||||
use_cudagraph = (
|
||||
chunk_size < M and N >= 1024 and K >= 1024 and current_platform.is_cuda_alike()
|
||||
)
|
||||
|
||||
topk_weights, topk_ids, _ = fused_topk(a, score.float(), topk, False)
|
||||
|
||||
# Set the context to avoid lots of warning spam.
|
||||
with set_current_vllm_config(vllm_config):
|
||||
ref_out = torch_w8a8_block_fp8_moe(
|
||||
a, w1, w2, w1_s, w2_s, topk_weights, topk_ids, block_size
|
||||
)
|
||||
|
||||
if use_compile:
|
||||
deep_gemm_moe_fp8_fn = torch.compile(
|
||||
deep_gemm_moe_fp8, backend="inductor", fullgraph=True
|
||||
)
|
||||
torch._dynamo.mark_dynamic(a, 0)
|
||||
torch._dynamo.mark_dynamic(topk_weights, 0)
|
||||
torch._dynamo.mark_dynamic(topk_ids, 0)
|
||||
else:
|
||||
deep_gemm_moe_fp8_fn = deep_gemm_moe_fp8
|
||||
|
||||
out = deep_gemm_moe_fp8_fn(a, w1, w2, w1_s, w2_s, topk_weights, topk_ids)
|
||||
|
||||
if use_cudagraph:
|
||||
out.fill_(0)
|
||||
stream = torch.cuda.Stream()
|
||||
graph = torch.cuda.CUDAGraph()
|
||||
with torch.cuda.graph(graph, stream=stream):
|
||||
out = deep_gemm_moe_fp8_fn(
|
||||
a, w1, w2, w1_s, w2_s, topk_weights, topk_ids
|
||||
)
|
||||
torch.cuda.synchronize()
|
||||
graph.replay()
|
||||
torch.cuda.synchronize()
|
||||
|
||||
torch.testing.assert_close(out, ref_out, atol=0.035, rtol=0.035)
|
||||
134
tests/kernels/moe/test_block_int8.py
Normal file
134
tests/kernels/moe/test_block_int8.py
Normal file
@@ -0,0 +1,134 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from tests.kernels.moe.utils import make_test_quant_config
|
||||
from tests.kernels.quant_utils import (
|
||||
native_per_token_group_quant_int8,
|
||||
native_w8a8_block_matmul,
|
||||
)
|
||||
from vllm.config import VllmConfig, set_current_vllm_config
|
||||
from vllm.model_executor.layers.activation import SiluAndMul
|
||||
from vllm.model_executor.layers.fused_moe import fused_experts, fused_topk
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
if current_platform.get_device_capability() < (7, 0):
|
||||
pytest.skip("INT8 Triton requires CUDA 7.0 or higher", allow_module_level=True)
|
||||
|
||||
vllm_config = VllmConfig()
|
||||
|
||||
DTYPES = [torch.bfloat16]
|
||||
|
||||
MNK_FACTORS = [
|
||||
(1, 128, 128),
|
||||
(1, 128, 7168),
|
||||
(1, 1024, 7168),
|
||||
(1, 4096, 512),
|
||||
(1, 4096, 7168),
|
||||
(33, 512, 512),
|
||||
(33, 128, 7168),
|
||||
(33, 1024, 7168),
|
||||
(33, 4096, 128),
|
||||
(33, 4096, 7168),
|
||||
(128, 128, 128),
|
||||
(128, 1024, 7168),
|
||||
(128, 4096, 512),
|
||||
(128, 4096, 7168),
|
||||
(222, 512, 512),
|
||||
(222, 1024, 7168),
|
||||
(222, 4096, 7168),
|
||||
(2048, 128, 128),
|
||||
(2048, 1024, 7168),
|
||||
(2048, 4096, 4096),
|
||||
]
|
||||
|
||||
E = [8, 24]
|
||||
TOP_KS = [2, 6]
|
||||
# BLOCK_SIZE = [[64, 64], [64, 128], [128, 64], [128, 128]]
|
||||
BLOCK_SIZE = [[128, 128]]
|
||||
SEEDS = [0]
|
||||
|
||||
|
||||
# For test
|
||||
def torch_w8a8_block_int8_moe(a, w1, w2, w1_s, w2_s, score, topk, block_shape):
|
||||
"""This function performs fused moe with block-wise quantization using
|
||||
native torch."""
|
||||
B, D = a.shape
|
||||
a = a.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D)
|
||||
out = torch.zeros(B * topk, w2.shape[1], dtype=a.dtype, device=a.device)
|
||||
score = torch.softmax(score, dim=-1, dtype=torch.float32)
|
||||
topk_weight, topk_ids = torch.topk(score, topk)
|
||||
topk_weight = topk_weight.view(-1)
|
||||
topk_ids = topk_ids.view(-1)
|
||||
|
||||
_, block_k = block_shape[0], block_shape[1]
|
||||
a_q, a_s = native_per_token_group_quant_int8(a, block_k)
|
||||
for i in range(w1.shape[0]):
|
||||
mask = topk_ids == i
|
||||
if mask.sum():
|
||||
inter_out = native_w8a8_block_matmul(
|
||||
a_q[mask], w1[i], a_s[mask], w1_s[i], block_shape, output_dtype=a.dtype
|
||||
)
|
||||
act_out = SiluAndMul().forward_native(inter_out)
|
||||
act_out_q, act_out_s = native_per_token_group_quant_int8(act_out, block_k)
|
||||
act_out = act_out.to(torch.float32)
|
||||
out[mask] = native_w8a8_block_matmul(
|
||||
act_out_q, w2[i], act_out_s, w2_s[i], block_shape, output_dtype=a.dtype
|
||||
)
|
||||
return (
|
||||
out.view(B, -1, w2.shape[1]) * topk_weight.view(B, -1, 1).to(out.dtype)
|
||||
).sum(dim=1)
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True, scope="module")
|
||||
def setup_cuda():
|
||||
"""Sets the default CUDA device for all tests in this module."""
|
||||
torch.set_default_device("cuda")
|
||||
|
||||
|
||||
@pytest.mark.parametrize(("M", "N", "K"), MNK_FACTORS)
|
||||
@pytest.mark.parametrize("E", E)
|
||||
@pytest.mark.parametrize("topk", TOP_KS)
|
||||
@pytest.mark.parametrize("block_size", BLOCK_SIZE)
|
||||
@pytest.mark.parametrize("dtype", DTYPES)
|
||||
@pytest.mark.parametrize("seed", SEEDS)
|
||||
@torch.inference_mode()
|
||||
def test_w8a8_block_int8_fused_moe(M, N, K, E, topk, block_size, dtype, seed):
|
||||
"""Tests the fused_moe kernel with W8A8 INT8 block quantization against a
|
||||
native torch reference."""
|
||||
torch.manual_seed(seed)
|
||||
|
||||
a = torch.randn((M, K), dtype=dtype) / 10
|
||||
score = torch.randn((M, E), dtype=dtype)
|
||||
topk_weights, topk_ids, _ = fused_topk(a, score.float(), topk, False)
|
||||
|
||||
w1, w2, quant_config = make_test_quant_config(
|
||||
E,
|
||||
N,
|
||||
K,
|
||||
dtype,
|
||||
quant_dtype=torch.int8,
|
||||
per_act_token_quant=False,
|
||||
block_shape=block_size,
|
||||
)
|
||||
|
||||
# Set the context to avoid lots of warning spam.
|
||||
with set_current_vllm_config(vllm_config):
|
||||
out = fused_experts(
|
||||
a, w1, w2, topk_weights, topk_ids, quant_config=quant_config
|
||||
)
|
||||
ref_out = torch_w8a8_block_int8_moe(
|
||||
a,
|
||||
w1,
|
||||
w2,
|
||||
quant_config.w1_scale,
|
||||
quant_config.w2_scale,
|
||||
score,
|
||||
topk,
|
||||
block_size,
|
||||
)
|
||||
|
||||
# Check results
|
||||
torch.testing.assert_close(out, ref_out, atol=0.065, rtol=0.065)
|
||||
143
tests/kernels/moe/test_count_expert_num_tokens.py
Normal file
143
tests/kernels/moe/test_count_expert_num_tokens.py
Normal file
@@ -0,0 +1,143 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""
|
||||
Tests compute_expert_num_tokens kernels
|
||||
"""
|
||||
|
||||
import dataclasses
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from vllm.model_executor.layers.fused_moe.utils import count_expert_num_tokens
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class TestTensors:
|
||||
topk_ids: torch.Tensor
|
||||
expert_map: torch.Tensor | None = None
|
||||
|
||||
def to_device(self, device: str):
|
||||
self.topk_ids = self.topk_ids.to(device=device)
|
||||
if self.expert_map is not None:
|
||||
self.expert_map = self.expert_map.to(device=device)
|
||||
|
||||
@staticmethod
|
||||
def make(
|
||||
num_tokens: int,
|
||||
num_topk: int,
|
||||
num_experts: int,
|
||||
device: str,
|
||||
topk_ids_dtype: torch.dtype,
|
||||
) -> "TestTensors":
|
||||
# make topk ids
|
||||
topk_ids = torch.empty((num_tokens, num_topk), device=device, dtype=torch.int64)
|
||||
for x in range(num_tokens):
|
||||
topk_ids[x] = torch.randperm(num_experts)[:num_topk]
|
||||
topk_ids = topk_ids.to(dtype=torch.int64)
|
||||
return TestTensors(topk_ids=topk_ids)
|
||||
|
||||
def with_ep_rank(
|
||||
self, ep_rank: int, num_global_experts: int, num_local_experts: int, device: str
|
||||
):
|
||||
# make an expert map
|
||||
expert_map = torch.empty((num_global_experts), device=device, dtype=torch.int32)
|
||||
expert_map.fill_(-1)
|
||||
s = ep_rank * num_local_experts
|
||||
e = s + num_local_experts
|
||||
expert_map[s:e] = torch.tensor(list(range(num_local_experts)), device=device)
|
||||
|
||||
return TestTensors(topk_ids=self.topk_ids.clone(), expert_map=expert_map)
|
||||
|
||||
|
||||
def ref_impl(tt: TestTensors, expert_num_tokens: torch.Tensor):
|
||||
# do the reference in cpu
|
||||
tt.to_device("cpu")
|
||||
expert_ids, counts = tt.topk_ids.unique(return_counts=True)
|
||||
|
||||
for eid, count in zip(expert_ids, counts):
|
||||
if eid != -1 and tt.expert_map is not None:
|
||||
eid = tt.expert_map[eid]
|
||||
|
||||
if eid == -1:
|
||||
continue
|
||||
|
||||
expert_num_tokens[eid] += count
|
||||
|
||||
|
||||
def do_test_compute_expert_num_tokens(
|
||||
num_tokens: int,
|
||||
num_topk: int,
|
||||
num_experts: int,
|
||||
ep_size: int,
|
||||
topk_ids_dtype: torch.dtype,
|
||||
):
|
||||
assert num_topk <= num_experts
|
||||
|
||||
tt = TestTensors.make(
|
||||
num_tokens, num_topk, num_experts, topk_ids_dtype=topk_ids_dtype, device="cpu"
|
||||
)
|
||||
|
||||
num_global_experts = num_experts
|
||||
assert num_global_experts % ep_size == 0
|
||||
num_local_experts = num_global_experts // ep_size
|
||||
for ep_rank in range(ep_size):
|
||||
tt_rank = tt.with_ep_rank(ep_rank, num_global_experts, num_local_experts, "cpu")
|
||||
|
||||
ref_expert_num_tokens = torch.zeros(
|
||||
(num_local_experts), device="cpu", dtype=torch.int32
|
||||
)
|
||||
ref_impl(tt_rank, ref_expert_num_tokens)
|
||||
ref_expert_num_tokens = ref_expert_num_tokens.to("cuda")
|
||||
|
||||
tt_rank.to_device("cuda")
|
||||
# Test with expert_map
|
||||
triton_expert_num_tokens_w_emap = count_expert_num_tokens(
|
||||
tt_rank.topk_ids, num_local_experts, tt_rank.expert_map
|
||||
)
|
||||
|
||||
# Test without expert map
|
||||
topk_ids = tt_rank.expert_map[tt_rank.topk_ids].to(topk_ids_dtype)
|
||||
triton_expert_num_tokens_wo_emap = count_expert_num_tokens(
|
||||
topk_ids, num_local_experts, expert_map=None
|
||||
)
|
||||
|
||||
torch.testing.assert_close(
|
||||
ref_expert_num_tokens, triton_expert_num_tokens_w_emap, atol=0, rtol=0
|
||||
)
|
||||
torch.testing.assert_close(
|
||||
ref_expert_num_tokens, triton_expert_num_tokens_wo_emap, atol=0, rtol=0
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("num_tokens", [1, 4, 8, 11, 127, 128, 3333, 7317])
|
||||
@pytest.mark.parametrize("num_topk", [2, 6, 8])
|
||||
@pytest.mark.parametrize("num_experts", [64])
|
||||
@pytest.mark.parametrize("ep_size", [1, 2, 4])
|
||||
@pytest.mark.parametrize("topk_ids_dtype", [torch.int64])
|
||||
def test_compute_expert_num_tokens(
|
||||
num_tokens: int,
|
||||
num_topk: int,
|
||||
num_experts: int,
|
||||
ep_size: int,
|
||||
topk_ids_dtype: torch.dtype,
|
||||
):
|
||||
do_test_compute_expert_num_tokens(
|
||||
num_tokens, num_topk, num_experts, ep_size, topk_ids_dtype
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("numel", list(range(1, 8192, 111)))
|
||||
@pytest.mark.parametrize("num_experts", [32])
|
||||
@pytest.mark.parametrize("ep_size", [2])
|
||||
@pytest.mark.parametrize("topk_ids_dtype", [torch.int64])
|
||||
def test_compute_expert_num_tokens_from_numel(
|
||||
numel: int, num_experts: int, ep_size: int, topk_ids_dtype: torch.dtype
|
||||
):
|
||||
do_test_compute_expert_num_tokens(
|
||||
num_tokens=numel,
|
||||
num_topk=1,
|
||||
num_experts=num_experts,
|
||||
ep_size=ep_size,
|
||||
topk_ids_dtype=topk_ids_dtype,
|
||||
)
|
||||
582
tests/kernels/moe/test_cutedsl_moe.py
Normal file
582
tests/kernels/moe/test_cutedsl_moe.py
Normal file
@@ -0,0 +1,582 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
|
||||
import pytest
|
||||
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
if not current_platform.has_device_capability(100):
|
||||
pytest.skip(
|
||||
reason="Nvfp4 Requires compute capability of 10 or above.",
|
||||
allow_module_level=True,
|
||||
)
|
||||
|
||||
import torch
|
||||
from flashinfer import fp4_quantize
|
||||
from torch.nn import functional as F
|
||||
|
||||
from vllm.model_executor.layers.activation import SiluAndMul
|
||||
from vllm.model_executor.layers.fused_moe.flashinfer_cutedsl_moe import (
|
||||
flashinfer_cutedsl_moe_masked,
|
||||
)
|
||||
from vllm.utils.flashinfer import (
|
||||
flashinfer_cutedsl_grouped_gemm_nt_masked as cutedsl_gmm_masked,
|
||||
)
|
||||
from vllm.utils.flashinfer import (
|
||||
scaled_fp4_grouped_quantize,
|
||||
)
|
||||
|
||||
kE2M1ToFloat = torch.tensor(
|
||||
[0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0], dtype=torch.float32
|
||||
)
|
||||
|
||||
FLOAT8_E4M3_MAX = 448.0
|
||||
FLOAT4_E2M1_MAX = 6.0
|
||||
|
||||
|
||||
def convert_swizzled_to_linear(a_sf_swizzled: torch.Tensor, m, k, block_size):
|
||||
m_tiles = (m + 128 - 1) // 128
|
||||
f = block_size * 4
|
||||
k_tiles = (k + f - 1) // f
|
||||
tmp = torch.reshape(a_sf_swizzled, (1, m_tiles, k_tiles, 32, 4, 4))
|
||||
tmp = torch.permute(tmp, (0, 1, 4, 3, 2, 5))
|
||||
out = tmp.reshape(m_tiles * 128, k_tiles * f // block_size)
|
||||
return out[0:m, 0:k]
|
||||
|
||||
|
||||
def dequantize_nvfp4_to_dtype(
|
||||
tensor_fp4, tensor_sf, global_scale, dtype, device, block_size=16
|
||||
):
|
||||
"""Dequantize the fp4 tensor back to high precision."""
|
||||
# Two fp4 values are packed into one uint8.
|
||||
assert tensor_fp4.dtype == torch.uint8
|
||||
m, packed_k = tensor_fp4.shape
|
||||
k = packed_k * 2
|
||||
tensor_f32 = break_fp4_bytes(tensor_fp4, dtype)
|
||||
tensor_f32 = tensor_f32.reshape(m, k // block_size, block_size)
|
||||
tensor_sf = tensor_sf.view(torch.float8_e4m3fn)
|
||||
tensor_sf = convert_swizzled_to_linear(tensor_sf, m, k, block_size)
|
||||
tensor_sf_dtype = tensor_sf.to(torch.float32) / global_scale
|
||||
|
||||
# scale the tensor
|
||||
out = (tensor_f32 * tensor_sf_dtype.unsqueeze(-1)).reshape(m, k)
|
||||
return out.to(dtype=dtype)
|
||||
|
||||
|
||||
def break_fp4_bytes(a, dtype):
|
||||
assert a.dtype == torch.uint8
|
||||
m, n = a.shape
|
||||
|
||||
# Vectorized nibble processing
|
||||
a_flat = a.flatten()
|
||||
high = (a_flat & 0xF0) >> 4 # Upper nibbles
|
||||
low = a_flat & 0x0F # Lower nibbles
|
||||
|
||||
# Combine nibbles for batch processing
|
||||
combined = torch.stack((low, high), dim=1).flatten()
|
||||
|
||||
# Vectorized sign and magnitude extraction
|
||||
signs = (combined & 0x08).to(torch.bool) # Sign bits
|
||||
abs_vals = (combined & 0x07).to(torch.long) # Magnitude indices
|
||||
|
||||
# Device-aware lookup and sign application
|
||||
kE2M1 = kE2M1ToFloat.to(device=a.device)
|
||||
values = kE2M1[abs_vals] * torch.where(signs, -1.0, 1.0)
|
||||
|
||||
# Reshape to final form
|
||||
return values.reshape(m, n * 2).to(dtype=dtype)
|
||||
|
||||
|
||||
def generate_balanced_routing(
|
||||
hidden_states: torch.Tensor, num_experts: int, top_k: int
|
||||
):
|
||||
"""
|
||||
Generate routing weights and topk indices such that every expert is active.
|
||||
Returns routing_weights, topk_idx
|
||||
"""
|
||||
|
||||
num_tokens, hidden_dim = hidden_states.shape
|
||||
# num_tokens = batch_size * seq_len
|
||||
|
||||
# First, assign at least one token per expert
|
||||
tokens_per_expert = torch.arange(num_tokens) % num_experts
|
||||
tokens_per_expert = tokens_per_expert[torch.randperm(num_tokens)] # shuffle
|
||||
|
||||
# Each token has top_k experts — start with one guaranteed expert
|
||||
topk_idx = torch.full((num_tokens, top_k), -1, dtype=torch.long)
|
||||
topk_idx[:, 0] = tokens_per_expert
|
||||
|
||||
# For remaining top_k - 1 experts, pick randomly (allowing repeats)
|
||||
if top_k > 1:
|
||||
random_choices = torch.randint(0, num_experts, (num_tokens, top_k - 1))
|
||||
topk_idx[:, 1:] = random_choices
|
||||
|
||||
# Normalize routing weights so each token's weights sum to 1
|
||||
routing_weights = torch.rand(num_tokens, top_k)
|
||||
routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
|
||||
|
||||
# Reshape back if needed
|
||||
routing_weights = routing_weights.view(num_tokens, top_k)
|
||||
topk_idx = topk_idx.view(num_tokens, top_k)
|
||||
|
||||
return routing_weights, topk_idx
|
||||
|
||||
|
||||
def prepare_inputs(
|
||||
hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
num_experts: int,
|
||||
topk: int,
|
||||
):
|
||||
routing_weights, topk_idx = generate_balanced_routing(
|
||||
router_logits, num_experts, topk
|
||||
)
|
||||
|
||||
masked_m = []
|
||||
for i in range(num_experts):
|
||||
mask = topk_idx.view(-1) == i
|
||||
masked_m.append(mask.sum())
|
||||
|
||||
masked_m = torch.tensor(masked_m, dtype=torch.int32)
|
||||
# Intialize the hidden_states_3d with ones instead of empty to avoid nan
|
||||
# issue.
|
||||
hidden_states_3d = torch.ones(
|
||||
(num_experts, max(masked_m), hidden_states.shape[1]), dtype=hidden_states.dtype
|
||||
)
|
||||
for i in range(num_experts):
|
||||
hidden_states_3d[i, : masked_m[i], :] = hidden_states[topk_idx.view(-1) == i]
|
||||
|
||||
return hidden_states_3d, masked_m, topk_idx, routing_weights
|
||||
|
||||
|
||||
MNK_FACTORS = [
|
||||
(2, 1024, 1024),
|
||||
(2, 1024, 1536),
|
||||
(2, 3072, 1024),
|
||||
(2, 3072, 1536),
|
||||
(64, 1024, 1024),
|
||||
(64, 1024, 1536),
|
||||
(64, 3072, 1024),
|
||||
(64, 2048, 1024),
|
||||
(224, 1024, 1024),
|
||||
(224, 1024, 1536),
|
||||
]
|
||||
|
||||
|
||||
# Reference implementation of torch_moe
|
||||
def torch_moe(a, w1, w2, score, topk, expert_map):
|
||||
B, D = a.shape
|
||||
a = a.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D)
|
||||
out = torch.zeros(B * topk, w2.shape[1], dtype=a.dtype, device=a.device)
|
||||
score = torch.softmax(score, dim=-1, dtype=torch.float32)
|
||||
topk_weight, topk_ids = torch.topk(score, topk)
|
||||
topk_weight = topk_weight.view(-1)
|
||||
topk_ids = topk_ids.view(-1)
|
||||
if expert_map is not None:
|
||||
topk_ids = expert_map[topk_ids]
|
||||
for i in range(w1.shape[0]):
|
||||
mask = topk_ids == i
|
||||
if mask.sum():
|
||||
out[mask] = SiluAndMul()(a[mask] @ w1[i].transpose(0, 1)) @ w2[i].transpose(
|
||||
0, 1
|
||||
)
|
||||
return (
|
||||
out.view(B, -1, w2.shape[1]) * topk_weight.view(B, -1, 1).to(out.dtype)
|
||||
).sum(dim=1)
|
||||
|
||||
|
||||
def torch_moe_nvfp4(a, w1, w2, topk, topk_weight, topk_ids):
|
||||
B, D = a.shape
|
||||
a = a.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D)
|
||||
out = torch.zeros(B * topk, w2.shape[1], dtype=a.dtype, device=a.device)
|
||||
|
||||
topk_weight = topk_weight.view(-1)
|
||||
topk_ids = topk_ids.view(-1)
|
||||
|
||||
for i in range(w1.shape[0]):
|
||||
mask = topk_ids == i
|
||||
if mask.sum():
|
||||
m = w1[i].shape[0]
|
||||
assert m % 2 == 0
|
||||
# Note: w1 and w3 are swapped!
|
||||
w3_expert, w1_expert = w1[i][m // 2 :, :], w1[i][: m // 2, :]
|
||||
inter = F.silu(a[mask] @ w1_expert.t()) * (a[mask] @ w3_expert.t())
|
||||
inter_gs = torch.tensor(1.0).cuda()
|
||||
inter_q, inter_blockscale = fp4_quantize(inter, inter_gs)
|
||||
inter = dequantize_nvfp4_to_dtype(
|
||||
inter_q,
|
||||
inter_blockscale,
|
||||
inter_gs,
|
||||
dtype=inter.dtype,
|
||||
device=inter.device,
|
||||
block_size=16,
|
||||
).cuda()
|
||||
out[mask] = inter @ w2[i].transpose(0, 1)
|
||||
return (
|
||||
out.view(B, -1, w2.shape[1]) * topk_weight.view(B, -1, 1).to(out.dtype)
|
||||
).sum(dim=1)
|
||||
|
||||
|
||||
def grouped_gemm_ref(
|
||||
hidden_states_expanded: torch.Tensor,
|
||||
hidden_states_3d: torch.Tensor,
|
||||
weights: torch.Tensor,
|
||||
topk_idx: torch.Tensor,
|
||||
masked_m: torch.Tensor,
|
||||
B: int,
|
||||
topk: int,
|
||||
num_experts: int,
|
||||
*,
|
||||
block_size: int = 16,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Computes the reference grouped GEMM (fp4 quantized per-expert loop),
|
||||
computes flashinfer grouped GEMM (for scale consistency),
|
||||
and returns ONLY the repacked reference output: out_ref.
|
||||
|
||||
Returns:
|
||||
out_ref: Tensor [num_experts, max_m, n_out]
|
||||
"""
|
||||
device_hs = hidden_states_expanded.device
|
||||
device_w = weights.device
|
||||
out_dtype = weights.dtype
|
||||
n_out = weights.shape[1]
|
||||
|
||||
# Flattened reference output (B*topk, n_out)
|
||||
out = torch.zeros((B * topk, n_out), dtype=out_dtype, device=device_w)
|
||||
|
||||
# Per-expert reference compute loop
|
||||
for i in range(num_experts):
|
||||
mask = topk_idx.view(-1) == i
|
||||
if mask.any():
|
||||
lhs = hidden_states_expanded[mask]
|
||||
rhs = weights[i]
|
||||
|
||||
a_amax = lhs.abs().max().to(torch.float32).to(device_hs)
|
||||
b_amax = rhs.abs().max().to(torch.float32).to(device_w)
|
||||
|
||||
a_gs = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / a_amax
|
||||
b_gs = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / b_amax
|
||||
|
||||
lhsq, lhsq_sf = fp4_quantize(lhs, a_gs)
|
||||
rhsq, rhsq_sf = fp4_quantize(rhs, b_gs)
|
||||
|
||||
lhs_in_dtype = dequantize_nvfp4_to_dtype(
|
||||
lhsq,
|
||||
lhsq_sf,
|
||||
a_gs,
|
||||
dtype=lhs.dtype,
|
||||
device=device_hs,
|
||||
block_size=block_size,
|
||||
)
|
||||
rhs_in_dtype = dequantize_nvfp4_to_dtype(
|
||||
rhsq,
|
||||
rhsq_sf,
|
||||
b_gs,
|
||||
dtype=rhs.dtype,
|
||||
device=device_w,
|
||||
block_size=block_size,
|
||||
)
|
||||
|
||||
out[mask] = lhs_in_dtype @ rhs_in_dtype.t()
|
||||
|
||||
# Determine per-expert max_m
|
||||
max_m_val = int(masked_m.max().item())
|
||||
|
||||
# Repack into [num_experts, max_m, n_out]
|
||||
out_ref = torch.zeros(
|
||||
(num_experts, max_m_val, n_out),
|
||||
dtype=out.dtype,
|
||||
device=out.device,
|
||||
)
|
||||
expert_slot = [0] * num_experts
|
||||
|
||||
for i, expert_id in enumerate(topk_idx.view(-1).tolist()):
|
||||
slot = expert_slot[expert_id]
|
||||
if slot < max_m_val:
|
||||
out_ref[expert_id, slot, :] = out[i]
|
||||
expert_slot[expert_id] += 1
|
||||
else:
|
||||
raise IndexError(
|
||||
f"Expert {expert_id} exceeded max slots ({max_m_val}). "
|
||||
"Increase max_m or check masked_m."
|
||||
)
|
||||
|
||||
return out_ref
|
||||
|
||||
|
||||
def flashinfer_cutedsl_grouped_gemm_nt_masked(
|
||||
hidden_states: torch.Tensor, # 3d
|
||||
input_global_scale: torch.Tensor, # (l,)
|
||||
weights: torch.Tensor,
|
||||
w_global_scale: torch.Tensor, # (l,)
|
||||
masked_m: torch.Tensor,
|
||||
):
|
||||
# hidden_states: [l, m, k]
|
||||
# weights: [l, n, k]
|
||||
aq, aq_sf = scaled_fp4_grouped_quantize(
|
||||
hidden_states,
|
||||
masked_m.to(hidden_states.device),
|
||||
input_global_scale,
|
||||
)
|
||||
num_experts, n, k = weights.shape
|
||||
bq, bq_sf = scaled_fp4_grouped_quantize(
|
||||
weights,
|
||||
torch.full((num_experts,), n, device=weights.device, dtype=torch.int32),
|
||||
w_global_scale,
|
||||
)
|
||||
|
||||
out = torch.zeros(
|
||||
(num_experts, max(masked_m), n), dtype=weights.dtype, device=aq.device
|
||||
)
|
||||
out = out.permute(1, 2, 0) # requirement of kernel
|
||||
sf_vec_size = 16
|
||||
ab_dtype = "float4_e2m1fn"
|
||||
sf_dtype = "float8_e4m3fn"
|
||||
c_dtype = "bfloat16"
|
||||
alpha = 1.0 / (input_global_scale * w_global_scale).to(out.dtype).view(
|
||||
1, 1, num_experts
|
||||
)
|
||||
|
||||
def get_cute_dtype(input: torch.Tensor) -> str:
|
||||
if input.dtype == torch.bfloat16:
|
||||
return "bfloat16"
|
||||
elif input.dtype == torch.float16:
|
||||
return "float16"
|
||||
elif input.dtype == torch.float32:
|
||||
return "float32"
|
||||
else:
|
||||
raise ValueError(f"Unsupported cute dtype {input.dtype}")
|
||||
|
||||
cutedsl_gmm_masked(
|
||||
(aq, aq_sf),
|
||||
(bq, bq_sf),
|
||||
out,
|
||||
masked_m.to(aq.device),
|
||||
ab_dtype=ab_dtype,
|
||||
sf_dtype=sf_dtype,
|
||||
c_dtype=c_dtype,
|
||||
sf_vec_size=sf_vec_size,
|
||||
alpha=alpha,
|
||||
alpha_dtype=get_cute_dtype(alpha),
|
||||
)
|
||||
|
||||
return out
|
||||
|
||||
|
||||
@pytest.mark.parametrize("bs, hidden_dim, inter_dim", [(2, 128, 256), (16, 128, 512)])
|
||||
@pytest.mark.parametrize("topk", [1, 2, 4])
|
||||
@torch.inference_mode()
|
||||
def test_flashinfer_cutedsl_moe_masked(
|
||||
bs: int, hidden_dim: int, inter_dim: int, topk: int
|
||||
):
|
||||
torch.manual_seed(42)
|
||||
device = "cuda"
|
||||
num_experts = 8
|
||||
hidden_states = (
|
||||
torch.randn(bs, hidden_dim, dtype=torch.bfloat16, device=device) / 5.0
|
||||
)
|
||||
w1 = (
|
||||
torch.randn(
|
||||
num_experts, 2 * inter_dim, hidden_dim, dtype=torch.bfloat16, device=device
|
||||
)
|
||||
/ 10.0
|
||||
)
|
||||
w2 = (
|
||||
torch.randn(
|
||||
num_experts, hidden_dim, inter_dim, dtype=torch.bfloat16, device=device
|
||||
)
|
||||
/ 10.0
|
||||
)
|
||||
router_logits = torch.randn(bs, num_experts, dtype=torch.float32)
|
||||
|
||||
hidden_states_expanded = (
|
||||
hidden_states.view(bs, -1, hidden_dim)
|
||||
.repeat(1, topk, 1)
|
||||
.reshape(-1, hidden_dim)
|
||||
)
|
||||
hidden_states_3d, masked_m, topk_idx, routing_weights = prepare_inputs(
|
||||
hidden_states_expanded, router_logits, num_experts, topk
|
||||
)
|
||||
|
||||
w1_amax = w1.abs().amax(dim=(1, 2)).to(torch.float32).to(w1.device)
|
||||
w2_amax = w2.abs().amax(dim=(1, 2)).to(torch.float32).to(w2.device)
|
||||
input_global_scale = torch.ones(
|
||||
(num_experts,), dtype=torch.float32, device=hidden_states.device
|
||||
)
|
||||
|
||||
w1_global_scale = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / w1_amax
|
||||
w2_global_scale = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / w2_amax
|
||||
a2_global_scale = torch.ones(
|
||||
(num_experts,), dtype=torch.float32, device=hidden_states.device
|
||||
) # assume intermediate scale is 1.0
|
||||
|
||||
w1_fp4, w1_blockscale = scaled_fp4_grouped_quantize(
|
||||
w1,
|
||||
torch.ones(num_experts, dtype=torch.int32, device=w1.device) * 2 * inter_dim,
|
||||
w1_global_scale,
|
||||
)
|
||||
w2_fp4, w2_blockscale = scaled_fp4_grouped_quantize(
|
||||
w2,
|
||||
torch.ones(num_experts, dtype=torch.int32, device=w2.device) * hidden_dim,
|
||||
w2_global_scale,
|
||||
)
|
||||
|
||||
w1_alpha = 1.0 / (input_global_scale * w1_global_scale)
|
||||
w2_alpha = 1.0 / (a2_global_scale * w2_global_scale)
|
||||
|
||||
out = torch.empty_like(hidden_states_3d)
|
||||
# Note: the 1st dim shouldn't be bs
|
||||
wk = torch.empty(
|
||||
num_experts,
|
||||
hidden_states_3d.shape[1],
|
||||
inter_dim * 2,
|
||||
dtype=hidden_states_3d.dtype,
|
||||
device=hidden_states.device,
|
||||
)
|
||||
flashinfer_cutedsl_moe_masked(
|
||||
hidden_states_3d.to(hidden_states.device),
|
||||
input_global_scale,
|
||||
w1_fp4.permute(2, 0, 1),
|
||||
w1_blockscale,
|
||||
w1_alpha,
|
||||
w2_fp4.permute(2, 0, 1),
|
||||
a2_global_scale,
|
||||
w2_blockscale,
|
||||
w2_alpha,
|
||||
masked_m.to(hidden_states.device),
|
||||
wk,
|
||||
out,
|
||||
)
|
||||
|
||||
# reference
|
||||
a_fp4, a_scale_interleaved = fp4_quantize(hidden_states, input_global_scale)
|
||||
a_in_dtype = dequantize_nvfp4_to_dtype(
|
||||
a_fp4,
|
||||
a_scale_interleaved,
|
||||
input_global_scale,
|
||||
dtype=hidden_states.dtype,
|
||||
device=hidden_states.device,
|
||||
block_size=16,
|
||||
)
|
||||
w1_d = torch.empty(
|
||||
(num_experts, 2 * inter_dim, hidden_dim), device=w1.device, dtype=w1.dtype
|
||||
)
|
||||
w2_d = torch.empty(
|
||||
(num_experts, hidden_dim, inter_dim), device=w2.device, dtype=w2.dtype
|
||||
)
|
||||
|
||||
for idx in range(0, num_experts):
|
||||
w1_fp4_sliced, w1_blockscale_sliced = fp4_quantize(
|
||||
w1[idx], w1_global_scale[idx]
|
||||
)
|
||||
w2_fp4_sliced, w2_blockscale_sliced = fp4_quantize(
|
||||
w2[idx], w2_global_scale[idx]
|
||||
)
|
||||
w1_d[idx] = dequantize_nvfp4_to_dtype(
|
||||
w1_fp4_sliced,
|
||||
w1_blockscale_sliced,
|
||||
w1_global_scale[idx],
|
||||
dtype=w1.dtype,
|
||||
device=w1.device,
|
||||
block_size=16,
|
||||
)
|
||||
w2_d[idx] = dequantize_nvfp4_to_dtype(
|
||||
w2_fp4_sliced,
|
||||
w2_blockscale_sliced,
|
||||
w2_global_scale[idx],
|
||||
dtype=w2.dtype,
|
||||
device=w2.device,
|
||||
block_size=16,
|
||||
)
|
||||
|
||||
ref_output = torch_moe_nvfp4(
|
||||
a_in_dtype,
|
||||
w1_d,
|
||||
w2_d,
|
||||
topk,
|
||||
routing_weights.to(a_in_dtype.device),
|
||||
topk_idx.to(a_in_dtype.device),
|
||||
)
|
||||
out_weighted = torch.zeros_like(ref_output, device=out.device, dtype=out.dtype)
|
||||
|
||||
positions = torch.nonzero(masked_m[topk_idx], as_tuple=False)
|
||||
rows, cols = positions[:, 0], positions[:, 1]
|
||||
experts = topk_idx[rows, cols]
|
||||
for i in range(num_experts):
|
||||
mask = experts == i
|
||||
if mask.any():
|
||||
idx = torch.nonzero(mask, as_tuple=False).squeeze(-1)
|
||||
r, c = rows[idx], cols[idx]
|
||||
out_weighted[r] += out[i, : len(r), :] * routing_weights[r, c].to(
|
||||
out.device
|
||||
).unsqueeze(-1)
|
||||
torch.testing.assert_close(
|
||||
out_weighted.cpu(), ref_output.cpu(), atol=2e-1, rtol=2e-1
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"bs, hidden_dim, inter_dim, topk", [(2, 128, 256, 2), (16, 128, 512, 5)]
|
||||
)
|
||||
@torch.inference_mode()
|
||||
def test_grouped_gemm_nt_masked(
|
||||
bs: int, hidden_dim: int, inter_dim: int, topk: int
|
||||
) -> None:
|
||||
torch.manual_seed(42)
|
||||
B = bs
|
||||
D = hidden_dim
|
||||
N = inter_dim
|
||||
# CuteDSL group gemm has issue when not all experts are active.
|
||||
# i.e. masked = [2, 3, 0, 0, 1] where the 2nd and 3rd experts are inactive
|
||||
# see https://github.com/flashinfer-ai/flashinfer/issues/1856
|
||||
num_experts = bs
|
||||
hidden_states = torch.randn(B, D, dtype=torch.bfloat16, device="cuda")
|
||||
weights = torch.randn(num_experts, N, D, dtype=torch.bfloat16, device="cuda")
|
||||
router_logits = torch.randn(B, num_experts, dtype=torch.float32)
|
||||
|
||||
hidden_states_expanded = (
|
||||
hidden_states.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D)
|
||||
)
|
||||
hidden_states_3d, masked_m, topk_idx, _ = prepare_inputs(
|
||||
hidden_states_expanded, router_logits, num_experts, topk
|
||||
)
|
||||
|
||||
a_amax = (
|
||||
hidden_states_3d.abs()
|
||||
.amax(dim=(1, 2))
|
||||
.to(torch.float32)
|
||||
.to(hidden_states.device)
|
||||
)
|
||||
b_amax = weights.abs().amax(dim=(1, 2)).to(torch.float32).to(weights.device)
|
||||
a_gs = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / a_amax
|
||||
b_gs = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / b_amax
|
||||
out_flashinfer = flashinfer_cutedsl_grouped_gemm_nt_masked(
|
||||
hidden_states_3d.to(hidden_states.device), a_gs, weights, b_gs, masked_m
|
||||
)
|
||||
# reference
|
||||
out_ref = grouped_gemm_ref(
|
||||
hidden_states_expanded=hidden_states_expanded,
|
||||
hidden_states_3d=hidden_states_3d,
|
||||
weights=weights,
|
||||
topk_idx=topk_idx,
|
||||
masked_m=masked_m,
|
||||
B=B,
|
||||
topk=topk,
|
||||
num_experts=num_experts,
|
||||
)
|
||||
# Note: just to compare the masked position due to cutedsl may write nan
|
||||
# into unmasked position.
|
||||
for i in range(num_experts):
|
||||
torch.testing.assert_close(
|
||||
out_flashinfer.permute(2, 0, 1)[i, : masked_m[i]],
|
||||
out_ref.to(out_flashinfer.device)[i, : masked_m[i]],
|
||||
atol=1e-1,
|
||||
rtol=1e-1,
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_flashinfer_cutedsl_moe_masked(16, 128, 512, 4)
|
||||
test_grouped_gemm_nt_masked(16, 128, 512, 4)
|
||||
92
tests/kernels/moe/test_cutlass_grouped_gemm.py
Normal file
92
tests/kernels/moe/test_cutlass_grouped_gemm.py
Normal file
@@ -0,0 +1,92 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
# DeepGEMM Style Cutlass Grouped GEMM Test
|
||||
# See https://github.com/deepseek-ai/DeepGEMM/blob/main/tests/test_core.py
|
||||
|
||||
import random
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from tests.kernels.moe.utils import per_token_cast_to_fp8
|
||||
from tests.kernels.utils import baseline_scaled_mm
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.deep_gemm import per_block_cast_to_fp8
|
||||
from vllm.utils.math_utils import cdiv
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"num_groups, expected_m_per_group, k, n",
|
||||
[
|
||||
(4, 8192, 7168, 4096),
|
||||
(4, 8192, 2048, 7168),
|
||||
(8, 4096, 7168, 4096),
|
||||
(8, 4096, 2048, 7168),
|
||||
(32, 1024, 7168, 4096),
|
||||
(32, 1024, 2048, 7168),
|
||||
],
|
||||
)
|
||||
@pytest.mark.parametrize("out_dtype", [torch.float16])
|
||||
@pytest.mark.skipif(
|
||||
(lambda x: x is None or x.to_int() != 100)(
|
||||
current_platform.get_device_capability()
|
||||
),
|
||||
reason="Block Scaled Grouped GEMM is only supported on SM100.",
|
||||
)
|
||||
def test_cutlass_grouped_gemm(
|
||||
num_groups: int,
|
||||
expected_m_per_group: int,
|
||||
k: int,
|
||||
n: int,
|
||||
out_dtype: torch.dtype,
|
||||
):
|
||||
device = "cuda"
|
||||
alignment = 128
|
||||
group_ms = [
|
||||
int(expected_m_per_group * random.uniform(0.7, 1.3)) for _ in range(num_groups)
|
||||
]
|
||||
m = sum([cdiv(m, alignment) * alignment for m in group_ms])
|
||||
|
||||
x = torch.randn((m, k), device=device, dtype=out_dtype)
|
||||
y = torch.randn((num_groups, n, k), device=device, dtype=out_dtype)
|
||||
out = torch.empty((m, n), device=device, dtype=out_dtype)
|
||||
ref_out = torch.randn((m, n), device=device, dtype=out_dtype)
|
||||
|
||||
ep_offset = [0] + [sum(group_ms[:i]) for i in range(1, num_groups)] + [m]
|
||||
pb_size = []
|
||||
for i in range(num_groups):
|
||||
pb_size.append([ep_offset[i + 1] - ep_offset[i], n, k])
|
||||
problem_sizes = torch.tensor(pb_size, device=device, dtype=torch.int32)
|
||||
expert_offsets = torch.tensor(ep_offset, device=device, dtype=torch.int32)
|
||||
|
||||
x_fp8 = per_token_cast_to_fp8(x)
|
||||
y_fp8 = (
|
||||
torch.empty_like(y, dtype=torch.float8_e4m3fn),
|
||||
torch.empty(
|
||||
(num_groups, cdiv(n, 128), k // 128), device=device, dtype=torch.float
|
||||
),
|
||||
)
|
||||
for i in range(num_groups):
|
||||
y_fp8[0][i], y_fp8[1][i] = per_block_cast_to_fp8(y[i], [128, 128])
|
||||
|
||||
for i in range(num_groups):
|
||||
a = x_fp8[0][ep_offset[i] : ep_offset[i + 1]]
|
||||
a_scale = x_fp8[1][ep_offset[i] : ep_offset[i + 1]]
|
||||
b = y_fp8[0][i].t()
|
||||
b_scale = y_fp8[1][i].t()
|
||||
baseline = baseline_scaled_mm(a, b, a_scale, b_scale, out_dtype)
|
||||
ref_out[ep_offset[i] : ep_offset[i + 1]] = baseline
|
||||
|
||||
ops.cutlass_blockwise_scaled_grouped_mm(
|
||||
out,
|
||||
x_fp8[0],
|
||||
y_fp8[0],
|
||||
x_fp8[1],
|
||||
y_fp8[1],
|
||||
problem_sizes,
|
||||
expert_offsets[:-1],
|
||||
)
|
||||
|
||||
torch.testing.assert_close(ref_out, out, atol=5e-1, rtol=1e-3)
|
||||
554
tests/kernels/moe/test_cutlass_moe.py
Normal file
554
tests/kernels/moe/test_cutlass_moe.py
Normal file
@@ -0,0 +1,554 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import copy
|
||||
import dataclasses
|
||||
from math import prod
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config
|
||||
from vllm.model_executor.layers.fused_moe.config import (
|
||||
FUSED_MOE_UNQUANTIZED_CONFIG,
|
||||
fp8_w8a8_moe_quant_config,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.cutlass_moe import (
|
||||
cutlass_moe_fp8,
|
||||
run_cutlass_moe_fp8,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.fused_moe import fused_experts, fused_topk
|
||||
from vllm.model_executor.layers.fused_moe.utils import moe_kernel_quantize_input
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
NUM_EXPERTS = [40, 64]
|
||||
TOP_KS = [6, 8]
|
||||
|
||||
MNK_FACTORS = [
|
||||
(2, 1024, 1024),
|
||||
(2, 3072, 1024),
|
||||
(2, 3072, 1536),
|
||||
(7, 3072, 1536),
|
||||
(64, 1024, 1024),
|
||||
(64, 1024, 1536),
|
||||
(64, 3072, 1024),
|
||||
(224, 1024, 1024),
|
||||
(224, 3072, 1024),
|
||||
(224, 3072, 1536),
|
||||
(32768, 1024, 1024),
|
||||
# These sizes trigger wrong answers.
|
||||
# (7232, 2048, 5120),
|
||||
# (40000, 2048, 5120),
|
||||
]
|
||||
|
||||
vllm_config = VllmConfig(parallel_config=ParallelConfig(pipeline_parallel_size=1))
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class MOETensors:
|
||||
a: torch.Tensor
|
||||
w1: torch.Tensor
|
||||
w2: torch.Tensor
|
||||
ab_strides1: torch.Tensor
|
||||
c_strides1: torch.Tensor
|
||||
ab_strides2: torch.Tensor
|
||||
c_strides2: torch.Tensor
|
||||
|
||||
@staticmethod
|
||||
def make_moe_tensors(
|
||||
m: int, k: int, n: int, e: int, dtype: torch.dtype
|
||||
) -> "MOETensors":
|
||||
a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
|
||||
w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10
|
||||
w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10
|
||||
ab_strides1 = torch.full((e,), k, device="cuda", dtype=torch.int64)
|
||||
c_strides1 = torch.full((e,), 2 * n, device="cuda", dtype=torch.int64)
|
||||
ab_strides2 = torch.full((e,), n, device="cuda", dtype=torch.int64)
|
||||
c_strides2 = torch.full((e,), k, device="cuda", dtype=torch.int64)
|
||||
return MOETensors(
|
||||
a=a,
|
||||
w1=w1,
|
||||
w2=w2,
|
||||
ab_strides1=ab_strides1,
|
||||
c_strides1=c_strides1,
|
||||
ab_strides2=ab_strides2,
|
||||
c_strides2=c_strides2,
|
||||
)
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class MOETensors8Bit(MOETensors):
|
||||
# quantized
|
||||
a_q: torch.Tensor | None = None # a -> a_q
|
||||
w1_q: torch.Tensor | None = None # w1 -> w1_q
|
||||
w2_q: torch.Tensor | None = None # w2 -> w2_q
|
||||
a_scale: torch.Tensor | None = None
|
||||
w1_scale: torch.Tensor | None = None
|
||||
w2_scale: torch.Tensor | None = None
|
||||
# dequantized
|
||||
a_d: torch.Tensor | None = None # a -> a_q -> a_d
|
||||
w1_d: torch.Tensor | None = None # w1 -> w1_q -> w1_d
|
||||
w2_d: torch.Tensor | None = None # w2 -> w2_q -> w2_d
|
||||
|
||||
@staticmethod
|
||||
def make_moe_tensors_8bit(
|
||||
m: int, k: int, n: int, e: int, per_act_token: bool, per_out_channel: bool
|
||||
) -> "MOETensors8Bit":
|
||||
dtype = torch.half
|
||||
q_dtype = torch.float8_e4m3fn
|
||||
|
||||
moe_tensors_fp16 = MOETensors.make_moe_tensors(m, k, n, e, dtype)
|
||||
|
||||
# a -> a_q, w1 -> w1_q, w2 -> w2_q
|
||||
n_b_scales = 2 * n if per_out_channel else 1
|
||||
k_b_scales = k if per_out_channel else 1
|
||||
# Get the right scale for tests.
|
||||
a_q, a_scale = ops.scaled_fp8_quant(
|
||||
moe_tensors_fp16.a, None, use_per_token_if_dynamic=per_act_token
|
||||
)
|
||||
|
||||
w1_q = torch.empty((e, 2 * n, k), device="cuda", dtype=q_dtype)
|
||||
w2_q = torch.empty((e, k, n), device="cuda", dtype=q_dtype)
|
||||
|
||||
w1_scale = torch.empty((e, n_b_scales, 1), device="cuda", dtype=torch.float32)
|
||||
w2_scale = torch.empty((e, k_b_scales, 1), device="cuda", dtype=torch.float32)
|
||||
for expert in range(e):
|
||||
w1_q[expert], w1_scale[expert] = ops.scaled_fp8_quant(
|
||||
moe_tensors_fp16.w1[expert], use_per_token_if_dynamic=per_out_channel
|
||||
)
|
||||
w2_q[expert], w2_scale[expert] = ops.scaled_fp8_quant(
|
||||
moe_tensors_fp16.w2[expert], use_per_token_if_dynamic=per_out_channel
|
||||
)
|
||||
|
||||
# a_q -> a_d, w1_q -> w1_d, w2_q -> w2_d
|
||||
a_d = a_q.float().mul(a_scale).to(dtype)
|
||||
w1_d = torch.empty_like(moe_tensors_fp16.w1)
|
||||
w2_d = torch.empty_like(moe_tensors_fp16.w2)
|
||||
for expert in range(e):
|
||||
w1_d[expert] = (w1_q[expert].float() * w1_scale[expert]).half()
|
||||
w2_d[expert] = (w2_q[expert].float() * w2_scale[expert]).half()
|
||||
|
||||
return MOETensors8Bit(
|
||||
a=moe_tensors_fp16.a,
|
||||
w1=moe_tensors_fp16.w1,
|
||||
w2=moe_tensors_fp16.w2,
|
||||
ab_strides1=moe_tensors_fp16.ab_strides1,
|
||||
c_strides1=moe_tensors_fp16.c_strides1,
|
||||
ab_strides2=moe_tensors_fp16.ab_strides2,
|
||||
c_strides2=moe_tensors_fp16.c_strides2,
|
||||
a_q=a_q,
|
||||
w1_q=w1_q,
|
||||
w2_q=w2_q,
|
||||
a_scale=a_scale,
|
||||
w1_scale=w1_scale,
|
||||
w2_scale=w2_scale,
|
||||
a_d=a_d,
|
||||
w1_d=w1_d,
|
||||
w2_d=w2_d,
|
||||
)
|
||||
|
||||
|
||||
def run_with_expert_maps(
|
||||
num_experts: int, num_local_experts: int, **cutlass_moe_kwargs
|
||||
):
|
||||
def slice_experts():
|
||||
slice_params = [
|
||||
"w1_q",
|
||||
"w2_q",
|
||||
"ab_strides1",
|
||||
"ab_strides2",
|
||||
"c_strides1",
|
||||
"c_strides2",
|
||||
]
|
||||
full_tensors = {
|
||||
k: v
|
||||
for k, v in cutlass_moe_kwargs.items()
|
||||
if k in slice_params and k in cutlass_moe_kwargs
|
||||
}
|
||||
|
||||
quant_config = cutlass_moe_kwargs["quant_config"]
|
||||
|
||||
for i in range(0, num_experts, num_local_experts):
|
||||
s, e = i, i + num_local_experts
|
||||
|
||||
# make expert map
|
||||
expert_map = [-1] * num_experts
|
||||
expert_map[s:e] = list(range(num_local_experts))
|
||||
expert_map = torch.tensor(expert_map, dtype=torch.int32, device="cuda")
|
||||
|
||||
# update cutlass moe arg with expert_map
|
||||
cutlass_moe_kwargs["expert_map"] = expert_map
|
||||
# update cutlass moe arg tensors
|
||||
for k, t in full_tensors.items():
|
||||
cutlass_moe_kwargs[k] = t[s:e]
|
||||
|
||||
new_quant_config = copy.deepcopy(quant_config)
|
||||
new_quant_config._w1.scale = quant_config.w1_scale[s:e]
|
||||
new_quant_config._w2.scale = quant_config.w2_scale[s:e]
|
||||
|
||||
cutlass_moe_kwargs["quant_config"] = new_quant_config
|
||||
|
||||
yield cutlass_moe_kwargs
|
||||
|
||||
out_tensor = torch.zeros_like(cutlass_moe_kwargs["a"])
|
||||
for kwargs in slice_experts():
|
||||
out_tensor = out_tensor + cutlass_moe_fp8(**kwargs)
|
||||
|
||||
return out_tensor
|
||||
|
||||
|
||||
def run_8_bit(
|
||||
moe_tensors: MOETensors8Bit,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
per_act_token: bool,
|
||||
per_out_ch: bool,
|
||||
num_local_experts: int | None = None,
|
||||
) -> torch.Tensor:
|
||||
assert not any(
|
||||
[
|
||||
t is None
|
||||
for t in [
|
||||
moe_tensors.w1_q,
|
||||
moe_tensors.w2_q,
|
||||
moe_tensors.w1_scale,
|
||||
moe_tensors.w2_scale,
|
||||
moe_tensors.a_scale,
|
||||
]
|
||||
]
|
||||
)
|
||||
|
||||
quant_config = fp8_w8a8_moe_quant_config(
|
||||
w1_scale=moe_tensors.w1_scale,
|
||||
w2_scale=moe_tensors.w2_scale,
|
||||
per_act_token_quant=per_act_token,
|
||||
per_out_ch_quant=per_out_ch,
|
||||
# Set to moe_tensors.a_scale iff static scales + per tensor.
|
||||
# This is not currently being tested.
|
||||
a1_scale=None,
|
||||
)
|
||||
|
||||
kwargs = {
|
||||
"a": moe_tensors.a,
|
||||
"w1_q": moe_tensors.w1_q, # type: ignore[union-attr]
|
||||
"w2_q": moe_tensors.w2_q, # type: ignore[union-attr]
|
||||
"topk_weights": topk_weights,
|
||||
"topk_ids": topk_ids,
|
||||
"ab_strides1": moe_tensors.ab_strides1,
|
||||
"ab_strides2": moe_tensors.ab_strides2,
|
||||
"c_strides1": moe_tensors.c_strides1,
|
||||
"c_strides2": moe_tensors.c_strides2,
|
||||
"quant_config": quant_config,
|
||||
}
|
||||
|
||||
num_experts = moe_tensors.w1.size(0)
|
||||
with_ep = num_local_experts is not None or num_local_experts == num_experts
|
||||
if not with_ep:
|
||||
return cutlass_moe_fp8(**kwargs)
|
||||
|
||||
assert num_local_experts is not None
|
||||
return run_with_expert_maps(
|
||||
num_experts,
|
||||
num_local_experts, # type: ignore[arg-type]
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("m,n,k", MNK_FACTORS)
|
||||
@pytest.mark.parametrize("e", NUM_EXPERTS)
|
||||
@pytest.mark.parametrize("topk", TOP_KS)
|
||||
@pytest.mark.parametrize("per_act_token", [True, False])
|
||||
@pytest.mark.parametrize("per_out_ch", [True, False])
|
||||
@pytest.mark.skipif(
|
||||
(lambda x: x is None or not ops.cutlass_group_gemm_supported(x.to_int()))(
|
||||
current_platform.get_device_capability()
|
||||
),
|
||||
reason="Grouped gemm is not supported on this GPU type.",
|
||||
)
|
||||
def test_cutlass_moe_8_bit_no_graph(
|
||||
m: int,
|
||||
n: int,
|
||||
k: int,
|
||||
e: int,
|
||||
topk: int,
|
||||
per_act_token: bool,
|
||||
per_out_ch: bool,
|
||||
monkeypatch,
|
||||
workspace_init,
|
||||
ep_size: int | None = None,
|
||||
):
|
||||
current_platform.seed_everything(7)
|
||||
monkeypatch.setenv("VLLM_FUSED_MOE_CHUNK_SIZE", "8192")
|
||||
with set_current_vllm_config(vllm_config):
|
||||
mt = MOETensors8Bit.make_moe_tensors_8bit(m, k, n, e, per_act_token, per_out_ch)
|
||||
|
||||
score = torch.randn((m, e), device="cuda", dtype=torch.half)
|
||||
topk_weights, topk_ids, _ = fused_topk(mt.a, score, topk, renormalize=False)
|
||||
|
||||
# Note that we are using the dequantized versions of the tensors.
|
||||
# Using a, w1 and w2 directly results in minor output differences.
|
||||
|
||||
quant_config = FUSED_MOE_UNQUANTIZED_CONFIG
|
||||
triton_output = fused_experts(
|
||||
mt.a_d, mt.w1_d, mt.w2_d, topk_weights, topk_ids, quant_config=quant_config
|
||||
)
|
||||
|
||||
if ep_size is not None:
|
||||
assert e % ep_size == 0, "Cannot distribute experts evenly"
|
||||
number_local_experts = e // ep_size
|
||||
else:
|
||||
number_local_experts = None
|
||||
|
||||
cutlass_output = run_8_bit(
|
||||
mt, topk_weights, topk_ids, per_act_token, per_out_ch, number_local_experts
|
||||
)
|
||||
|
||||
# Note 5.5 only needed for larger problem sizes, 5 works ok for
|
||||
# the rest.
|
||||
torch.testing.assert_close(
|
||||
triton_output, cutlass_output, atol=5.5e-2, rtol=1e-2
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("m,n,k", MNK_FACTORS)
|
||||
@pytest.mark.parametrize("e", NUM_EXPERTS)
|
||||
@pytest.mark.parametrize("topk", TOP_KS)
|
||||
@pytest.mark.parametrize("per_act_token", [True, False])
|
||||
@pytest.mark.parametrize("per_out_ch", [True, False])
|
||||
@pytest.mark.skipif(
|
||||
(lambda x: x is None or not ops.cutlass_group_gemm_supported(x.to_int()))(
|
||||
current_platform.get_device_capability()
|
||||
),
|
||||
reason="Grouped gemm is not supported on this GPU type.",
|
||||
)
|
||||
def test_cutlass_moe_8_bit_cuda_graph(
|
||||
m: int,
|
||||
n: int,
|
||||
k: int,
|
||||
e: int,
|
||||
topk: int,
|
||||
per_act_token: bool,
|
||||
per_out_ch: bool,
|
||||
monkeypatch,
|
||||
workspace_init,
|
||||
):
|
||||
current_platform.seed_everything(7)
|
||||
monkeypatch.setenv("VLLM_FUSED_MOE_CHUNK_SIZE", "8192")
|
||||
with set_current_vllm_config(vllm_config):
|
||||
dtype = torch.half
|
||||
|
||||
mt = MOETensors8Bit.make_moe_tensors_8bit(m, k, n, e, per_act_token, per_out_ch)
|
||||
|
||||
score = torch.randn((m, e), device="cuda", dtype=dtype)
|
||||
topk_weights, topk_ids, _ = fused_topk(mt.a, score, topk, renormalize=False)
|
||||
|
||||
# Note that we are using the dequantized versions of the tensors.
|
||||
# Using a, w1 and w2 directly results in minor output differences.
|
||||
quant_config = FUSED_MOE_UNQUANTIZED_CONFIG
|
||||
triton_output = fused_experts(
|
||||
mt.a_d, mt.w1_d, mt.w2_d, topk_weights, topk_ids, quant_config=quant_config
|
||||
)
|
||||
|
||||
stream = torch.cuda.Stream()
|
||||
graph = torch.cuda.CUDAGraph()
|
||||
with torch.cuda.graph(graph, stream=stream):
|
||||
cutlass_output = run_8_bit(
|
||||
mt, topk_weights, topk_ids, per_act_token, per_out_ch
|
||||
)
|
||||
|
||||
torch.cuda.synchronize()
|
||||
graph.replay()
|
||||
torch.cuda.synchronize()
|
||||
|
||||
torch.testing.assert_close(triton_output, cutlass_output, atol=9e-2, rtol=1e-2)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("m", [64])
|
||||
@pytest.mark.parametrize("n", [1024])
|
||||
@pytest.mark.parametrize("k", [4096])
|
||||
@pytest.mark.parametrize("e", [16])
|
||||
@pytest.mark.parametrize("topk", [1, 8])
|
||||
@pytest.mark.parametrize("per_act_token", [True])
|
||||
@pytest.mark.parametrize("per_out_channel", [True])
|
||||
@pytest.mark.parametrize("ep_size", [1, 2, 4, 8, 16])
|
||||
@pytest.mark.skipif(
|
||||
(lambda x: x is None or not ops.cutlass_group_gemm_supported(x.to_int()))(
|
||||
current_platform.get_device_capability()
|
||||
),
|
||||
reason="Grouped gemm is not supported on this GPU type.",
|
||||
)
|
||||
def test_cutlass_moe_8_bit_EP(
|
||||
m: int,
|
||||
n: int,
|
||||
k: int,
|
||||
e: int,
|
||||
topk: int,
|
||||
per_act_token: bool,
|
||||
per_out_channel: bool,
|
||||
ep_size: int,
|
||||
monkeypatch,
|
||||
workspace_init,
|
||||
):
|
||||
test_cutlass_moe_8_bit_no_graph(
|
||||
m,
|
||||
n,
|
||||
k,
|
||||
e,
|
||||
topk,
|
||||
per_act_token,
|
||||
per_out_channel,
|
||||
monkeypatch,
|
||||
workspace_init,
|
||||
ep_size,
|
||||
)
|
||||
|
||||
|
||||
LARGE_MNK_FACTORS = [
|
||||
(1, 8192, 5120, 31),
|
||||
(32768, 1024, 1024, 16),
|
||||
(65536, 512, 1024, 16),
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("m,n,k,topk", LARGE_MNK_FACTORS)
|
||||
@pytest.mark.parametrize("e", [128])
|
||||
@pytest.mark.parametrize("per_act_token", [False])
|
||||
@pytest.mark.parametrize("per_out_channel", [True])
|
||||
@pytest.mark.parametrize("ep_size", [8])
|
||||
@pytest.mark.skipif(
|
||||
(lambda x: x is None or not ops.cutlass_group_gemm_supported(x.to_int()))(
|
||||
current_platform.get_device_capability()
|
||||
),
|
||||
reason="Grouped gemm is not supported on this GPU type.",
|
||||
)
|
||||
def test_cutlass_moe_8_bit_EP_large(
|
||||
m: int,
|
||||
n: int,
|
||||
k: int,
|
||||
e: int,
|
||||
topk: int,
|
||||
per_act_token: bool,
|
||||
per_out_channel: bool,
|
||||
ep_size: int,
|
||||
monkeypatch,
|
||||
workspace_init,
|
||||
):
|
||||
test_cutlass_moe_8_bit_no_graph(
|
||||
m,
|
||||
n,
|
||||
k,
|
||||
e,
|
||||
topk,
|
||||
per_act_token,
|
||||
per_out_channel,
|
||||
monkeypatch,
|
||||
workspace_init,
|
||||
ep_size,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("m,n,k,topk", [(1, 8192, 5120, 31)])
|
||||
@pytest.mark.parametrize("e", [128])
|
||||
@pytest.mark.parametrize("per_act_token", [False])
|
||||
@pytest.mark.parametrize("per_out_channel", [True])
|
||||
@pytest.mark.parametrize("ep_size", [8])
|
||||
@pytest.mark.skipif(
|
||||
(lambda x: x is None or not ops.cutlass_group_gemm_supported(x.to_int()))(
|
||||
current_platform.get_device_capability()
|
||||
),
|
||||
reason="Grouped gemm is not supported on this GPU type.",
|
||||
)
|
||||
def test_run_cutlass_moe_fp8(
|
||||
m: int,
|
||||
n: int,
|
||||
k: int,
|
||||
e: int,
|
||||
topk: int,
|
||||
per_act_token: bool,
|
||||
per_out_channel: bool,
|
||||
ep_size: int,
|
||||
workspace_init,
|
||||
):
|
||||
current_platform.seed_everything(7)
|
||||
with set_current_vllm_config(vllm_config):
|
||||
mt = MOETensors8Bit.make_moe_tensors_8bit(
|
||||
m, k, n, e, per_act_token, per_out_channel
|
||||
)
|
||||
|
||||
score = torch.randn((m, e), device="cuda", dtype=torch.half)
|
||||
topk_weights, topk_ids, _ = fused_topk(mt.a, score, topk, renormalize=False)
|
||||
# we want to make sure there is at least one token that's generated in
|
||||
# this expert shard and at least one token that's NOT generated in this
|
||||
# expert shard
|
||||
topk_ids[0][0] = -1
|
||||
topk_ids[0][1] = 1
|
||||
|
||||
workspace13_shape = (m * topk, max(2 * n, k))
|
||||
workspace2_shape = (m * topk, max(n, k))
|
||||
output_shape = (m, k)
|
||||
|
||||
workspace13 = torch.empty(
|
||||
prod(workspace13_shape), device="cuda", dtype=mt.a.dtype
|
||||
)
|
||||
workspace2 = torch.empty(
|
||||
prod(workspace2_shape), device="cuda", dtype=mt.a.dtype
|
||||
)
|
||||
|
||||
num_local_experts = e // ep_size
|
||||
start, end = 0, num_local_experts
|
||||
expert_map = [-1] * e
|
||||
expert_map[start:end] = list(range(num_local_experts))
|
||||
expert_map = torch.tensor(expert_map, dtype=torch.int32, device="cuda")
|
||||
|
||||
ab_strides1 = torch.full((e,), k, device="cuda", dtype=torch.int64)
|
||||
ab_strides2 = torch.full((e,), n, device="cuda", dtype=torch.int64)
|
||||
c_strides1 = torch.full((e,), 2 * n, device="cuda", dtype=torch.int64)
|
||||
c_strides2 = torch.full((e,), k, device="cuda", dtype=torch.int64)
|
||||
|
||||
activation = lambda o, i: torch.ops._C.silu_and_mul(o, i)
|
||||
a1q, a1q_scale = moe_kernel_quantize_input(
|
||||
mt.a, mt.a_scale, torch.float8_e4m3fn, per_act_token
|
||||
)
|
||||
global_num_experts = -1 if mt.w1_q is None else mt.w1_q.size(0)
|
||||
func = lambda output: run_cutlass_moe_fp8(
|
||||
output,
|
||||
a1q,
|
||||
mt.w1_q,
|
||||
mt.w2_q,
|
||||
topk_ids,
|
||||
activation,
|
||||
global_num_experts,
|
||||
expert_map,
|
||||
mt.w1_scale,
|
||||
mt.w2_scale,
|
||||
a1q_scale,
|
||||
None,
|
||||
ab_strides1,
|
||||
ab_strides2,
|
||||
c_strides1,
|
||||
c_strides2,
|
||||
workspace13,
|
||||
workspace2,
|
||||
None,
|
||||
mt.a.dtype,
|
||||
per_act_token,
|
||||
per_out_channel,
|
||||
False,
|
||||
topk_weights,
|
||||
)
|
||||
|
||||
workspace13.random_()
|
||||
output_random_workspace = torch.empty(
|
||||
output_shape, device="cuda", dtype=mt.a.dtype
|
||||
)
|
||||
func(output_random_workspace)
|
||||
|
||||
workspace13.fill_(0)
|
||||
output_zero_workspace = torch.zeros(
|
||||
output_shape, device="cuda", dtype=mt.a.dtype
|
||||
)
|
||||
func(output_zero_workspace)
|
||||
|
||||
torch.testing.assert_close(
|
||||
output_random_workspace, output_zero_workspace, atol=5e-3, rtol=1e-3
|
||||
)
|
||||
565
tests/kernels/moe/test_deepep_deepgemm_moe.py
Normal file
565
tests/kernels/moe/test_deepep_deepgemm_moe.py
Normal file
@@ -0,0 +1,565 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""
|
||||
Test DeepEP + DeepGEMM integration
|
||||
DeepGEMM are gemm kernels specialized for the
|
||||
fp8 block-quantized case.
|
||||
"""
|
||||
|
||||
import dataclasses
|
||||
from contextlib import contextmanager
|
||||
|
||||
import pytest
|
||||
import torch.distributed
|
||||
from torch.distributed import ProcessGroup
|
||||
from typing_extensions import ParamSpec
|
||||
|
||||
from vllm.config import VllmConfig, set_current_vllm_config
|
||||
from vllm.forward_context import set_forward_context
|
||||
from vllm.model_executor.layers.fused_moe.config import (
|
||||
FusedMoEQuantConfig,
|
||||
fp8_w8a8_moe_quant_config,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.fused_moe import fused_experts
|
||||
from vllm.model_executor.layers.fused_moe.modular_kernel import FusedMoEModularKernel
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.deep_gemm import (
|
||||
get_mk_alignment_for_contiguous_layout,
|
||||
is_deep_gemm_e8m0_used,
|
||||
is_deep_gemm_supported,
|
||||
)
|
||||
from vllm.utils.import_utils import has_deep_ep, has_deep_gemm
|
||||
from vllm.v1.worker.workspace import init_workspace_manager
|
||||
|
||||
from ...utils import multi_gpu_test
|
||||
from .parallel_utils import ProcessGroupInfo, parallel_launch
|
||||
from .utils import make_test_weights
|
||||
|
||||
if has_deep_ep():
|
||||
from vllm.model_executor.layers.fused_moe.deepep_ht_prepare_finalize import (
|
||||
DeepEPHTPrepareAndFinalize,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.deepep_ll_prepare_finalize import (
|
||||
DeepEPLLPrepareAndFinalize,
|
||||
)
|
||||
|
||||
from .parallel_utils import DeepEPHTArgs, DeepEPLLArgs, make_deepep_a2a
|
||||
|
||||
if has_deep_gemm():
|
||||
from vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe import (
|
||||
BatchedDeepGemmExperts,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.deep_gemm_moe import DeepGemmExperts
|
||||
|
||||
requires_deep_ep = pytest.mark.skipif(
|
||||
not has_deep_ep(),
|
||||
reason="Requires deep_ep kernels",
|
||||
)
|
||||
|
||||
requires_deep_gemm = pytest.mark.skipif(
|
||||
not is_deep_gemm_supported(),
|
||||
reason="Requires deep_gemm kernels",
|
||||
)
|
||||
|
||||
P = ParamSpec("P")
|
||||
|
||||
|
||||
@contextmanager
|
||||
def with_dp_metadata(M: int, world_size: int):
|
||||
num_tokens_across_dp = torch.tensor([M] * world_size, device="cpu", dtype=torch.int)
|
||||
|
||||
vllm_config = VllmConfig()
|
||||
vllm_config.parallel_config.data_parallel_size = world_size
|
||||
vllm_config.parallel_config.enable_expert_parallel = True
|
||||
|
||||
with set_forward_context(
|
||||
None,
|
||||
vllm_config,
|
||||
num_tokens=M,
|
||||
num_tokens_across_dp=num_tokens_across_dp,
|
||||
):
|
||||
yield
|
||||
|
||||
|
||||
def next_power_of_2(x):
|
||||
import math
|
||||
|
||||
if x == 0:
|
||||
return 1
|
||||
return 2 ** math.ceil(math.log2(x))
|
||||
|
||||
|
||||
def make_block_quant_fp8_weights(
|
||||
e: int,
|
||||
n: int,
|
||||
k: int,
|
||||
block_size: list[int],
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Return weights w1q, w2q, w1_scale, w2_scale
|
||||
"""
|
||||
(_, w1q, w1_scale, _), (_, w2q, w2_scale, _) = make_test_weights(
|
||||
e, n, k, torch.bfloat16, torch.float8_e4m3fn, block_shape=block_size
|
||||
)
|
||||
return w1q, w2q, w1_scale, w2_scale
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class TestConfig:
|
||||
topk: int
|
||||
m: int
|
||||
k: int
|
||||
n: int
|
||||
num_experts: int
|
||||
per_act_token_quant: bool
|
||||
block_size: list[int]
|
||||
# configs for testing low-latency kernels
|
||||
low_latency: bool
|
||||
use_fp8_dispatch: bool | None = False
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class TestTensors:
|
||||
rank_tokens: torch.Tensor # all ranks make this many tokens
|
||||
rank_token_scales: torch.Tensor | None
|
||||
topk: torch.Tensor
|
||||
topk_weights: torch.Tensor
|
||||
config: TestConfig
|
||||
|
||||
@staticmethod
|
||||
def make(config: TestConfig, rank) -> "TestTensors":
|
||||
dtype = torch.bfloat16
|
||||
topk, m, k = (config.topk, config.m, config.k)
|
||||
|
||||
fp8_info = torch.finfo(torch.float8_e4m3fn)
|
||||
fp8_max, fp8_min = fp8_info.max, fp8_info.min
|
||||
|
||||
rank_tokens = (
|
||||
torch.randn((m, k), device=torch.cuda.current_device(), dtype=dtype) / 10.0
|
||||
)
|
||||
rank_tokens = rank_tokens.clamp(min=fp8_min, max=fp8_max)
|
||||
rank_token_scales = None
|
||||
|
||||
topk_ids = torch.randint(
|
||||
low=0,
|
||||
high=config.num_experts,
|
||||
size=(m, topk),
|
||||
device=torch.cuda.current_device(),
|
||||
).to(dtype=torch.int64)
|
||||
|
||||
topk_weights = torch.randn(
|
||||
topk_ids.shape, dtype=torch.float32, device=torch.cuda.current_device()
|
||||
)
|
||||
|
||||
return TestTensors(
|
||||
rank_tokens=rank_tokens,
|
||||
rank_token_scales=rank_token_scales,
|
||||
topk=topk_ids,
|
||||
topk_weights=topk_weights,
|
||||
config=config,
|
||||
)
|
||||
|
||||
|
||||
def make_ll_modular_kernel(
|
||||
pg: ProcessGroup,
|
||||
pgi: ProcessGroupInfo,
|
||||
max_tokens_per_rank: int,
|
||||
dp_size: int,
|
||||
hidden_size: int,
|
||||
q_dtype: torch.dtype | None,
|
||||
test_config: TestConfig,
|
||||
quant_config: FusedMoEQuantConfig,
|
||||
) -> FusedMoEModularKernel:
|
||||
assert test_config.low_latency
|
||||
assert test_config.use_fp8_dispatch is not None
|
||||
|
||||
a2a: DeepEPLLPrepareAndFinalize = make_deepep_a2a(
|
||||
pg=pg,
|
||||
pgi=pgi,
|
||||
dp_size=dp_size,
|
||||
deepep_ht_args=None,
|
||||
deepep_ll_args=DeepEPLLArgs(
|
||||
max_tokens_per_rank=max_tokens_per_rank,
|
||||
hidden_size=hidden_size,
|
||||
num_experts=test_config.num_experts,
|
||||
use_fp8_dispatch=test_config.use_fp8_dispatch,
|
||||
),
|
||||
q_dtype=q_dtype,
|
||||
block_shape=test_config.block_size,
|
||||
)
|
||||
|
||||
fused_experts = BatchedDeepGemmExperts(
|
||||
max_num_tokens=max_tokens_per_rank,
|
||||
num_dispatchers=pgi.world_size // dp_size,
|
||||
quant_config=quant_config,
|
||||
)
|
||||
mk = FusedMoEModularKernel(prepare_finalize=a2a, fused_experts=fused_experts)
|
||||
return mk
|
||||
|
||||
|
||||
def make_ht_modular_kernel(
|
||||
pg: ProcessGroup,
|
||||
pgi: ProcessGroupInfo,
|
||||
dp_size: int,
|
||||
num_local_experts: int,
|
||||
q_dtype: torch.dtype | None,
|
||||
test_config: TestConfig,
|
||||
quant_config: FusedMoEQuantConfig,
|
||||
) -> FusedMoEModularKernel:
|
||||
assert not test_config.low_latency
|
||||
assert test_config.use_fp8_dispatch is None
|
||||
|
||||
a2a: DeepEPHTPrepareAndFinalize = make_deepep_a2a(
|
||||
pg=pg,
|
||||
pgi=pgi,
|
||||
dp_size=dp_size,
|
||||
deepep_ht_args=DeepEPHTArgs(num_local_experts=num_local_experts),
|
||||
deepep_ll_args=None,
|
||||
q_dtype=q_dtype,
|
||||
block_shape=test_config.block_size,
|
||||
)
|
||||
|
||||
fused_experts = DeepGemmExperts(quant_config)
|
||||
mk = FusedMoEModularKernel(prepare_finalize=a2a, fused_experts=fused_experts)
|
||||
return mk
|
||||
|
||||
|
||||
def make_modular_kernel(
|
||||
pg: ProcessGroup,
|
||||
pgi: ProcessGroupInfo,
|
||||
dp_size: int,
|
||||
num_local_experts: int,
|
||||
test_tensors: TestTensors,
|
||||
quant_config: FusedMoEQuantConfig,
|
||||
) -> FusedMoEModularKernel:
|
||||
q_dtype = torch.float8_e4m3fn
|
||||
test_config = test_tensors.config
|
||||
|
||||
mk: FusedMoEModularKernel
|
||||
# Make modular kernel
|
||||
if test_config.low_latency:
|
||||
max_tokens_per_rank = max(64, next_power_of_2(test_tensors.rank_tokens.size(0)))
|
||||
hidden_size = test_tensors.rank_tokens.size(-1)
|
||||
|
||||
mk = make_ll_modular_kernel(
|
||||
pg=pg,
|
||||
pgi=pgi,
|
||||
max_tokens_per_rank=max_tokens_per_rank,
|
||||
dp_size=dp_size,
|
||||
hidden_size=hidden_size,
|
||||
q_dtype=q_dtype,
|
||||
test_config=test_config,
|
||||
quant_config=quant_config,
|
||||
)
|
||||
else:
|
||||
mk = make_ht_modular_kernel(
|
||||
pg,
|
||||
pgi,
|
||||
dp_size,
|
||||
num_local_experts,
|
||||
q_dtype,
|
||||
test_config,
|
||||
quant_config=quant_config,
|
||||
)
|
||||
|
||||
return mk
|
||||
|
||||
|
||||
def deepep_deepgemm_moe_impl(
|
||||
pg: ProcessGroup,
|
||||
pgi: ProcessGroupInfo,
|
||||
dp_size: int,
|
||||
test_tensors: TestTensors,
|
||||
w1: torch.Tensor,
|
||||
w2: torch.Tensor,
|
||||
w1_scale: torch.Tensor | None,
|
||||
w2_scale: torch.Tensor | None,
|
||||
) -> torch.Tensor:
|
||||
test_config = test_tensors.config
|
||||
num_experts = test_config.num_experts
|
||||
num_local_experts = w1.size(0)
|
||||
|
||||
def build_expert_map():
|
||||
num_local_experts = w1.size(0)
|
||||
expert_map = torch.full((num_experts,), fill_value=-1, dtype=torch.int32)
|
||||
s = pgi.rank * num_local_experts
|
||||
e = s + num_local_experts
|
||||
expert_map[s:e] = torch.tensor(list(range(num_local_experts)))
|
||||
return expert_map.to(device=torch.cuda.current_device(), dtype=torch.int32)
|
||||
|
||||
quant_config = fp8_w8a8_moe_quant_config(
|
||||
w1_scale=w1_scale,
|
||||
w2_scale=w2_scale,
|
||||
# Low-Latency kernels can't dispatch scales.
|
||||
a1_scale=(None if test_config.low_latency else test_tensors.rank_token_scales),
|
||||
block_shape=test_config.block_size,
|
||||
)
|
||||
|
||||
# Make modular kernel
|
||||
mk: FusedMoEModularKernel = make_modular_kernel(
|
||||
pg=pg,
|
||||
pgi=pgi,
|
||||
dp_size=dp_size,
|
||||
num_local_experts=num_local_experts,
|
||||
test_tensors=test_tensors,
|
||||
quant_config=quant_config,
|
||||
)
|
||||
|
||||
with with_dp_metadata(
|
||||
M=test_tensors.rank_tokens.size(0), world_size=pgi.world_size
|
||||
):
|
||||
out = mk.forward(
|
||||
hidden_states=test_tensors.rank_tokens,
|
||||
w1=w1,
|
||||
w2=w2,
|
||||
topk_weights=test_tensors.topk_weights,
|
||||
topk_ids=test_tensors.topk,
|
||||
inplace=False,
|
||||
activation="silu",
|
||||
global_num_experts=num_experts,
|
||||
expert_map=build_expert_map(),
|
||||
apply_router_weight_on_input=False,
|
||||
)
|
||||
return out
|
||||
|
||||
|
||||
def triton_impl(
|
||||
a: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
w1: torch.Tensor,
|
||||
w2: torch.Tensor,
|
||||
w1_scale: torch.Tensor,
|
||||
w2_scale: torch.Tensor,
|
||||
a1_scale: torch.Tensor,
|
||||
block_shape: list[int],
|
||||
):
|
||||
quant_config = fp8_w8a8_moe_quant_config(
|
||||
w1_scale=w1_scale,
|
||||
w2_scale=w2_scale,
|
||||
a1_scale=a1_scale,
|
||||
block_shape=block_shape,
|
||||
)
|
||||
|
||||
return fused_experts(
|
||||
hidden_states=a,
|
||||
w1=w1,
|
||||
w2=w2,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
inplace=False,
|
||||
quant_config=quant_config,
|
||||
# Make sure this is set to False so we
|
||||
# don't end up comparing the same implementation.
|
||||
allow_deep_gemm=False,
|
||||
)
|
||||
|
||||
|
||||
def _test_deepep_deepgemm_moe(
|
||||
pgi: ProcessGroupInfo,
|
||||
dp_size: int,
|
||||
config: TestConfig,
|
||||
w1: torch.Tensor,
|
||||
w2: torch.Tensor,
|
||||
w1_scale: torch.Tensor,
|
||||
w2_scale: torch.Tensor,
|
||||
):
|
||||
device = torch.device(f"cuda:{pgi.local_rank}")
|
||||
init_workspace_manager(device)
|
||||
|
||||
current_platform.seed_everything(pgi.rank)
|
||||
|
||||
w1 = w1.to(device=torch.cuda.current_device())
|
||||
w2 = w2.to(device=torch.cuda.current_device())
|
||||
w1_scale = w1_scale.to(device=torch.cuda.current_device())
|
||||
w2_scale = w2_scale.to(device=torch.cuda.current_device())
|
||||
|
||||
pg = torch.distributed.new_group(list(range(pgi.world_size)))
|
||||
test_tensors = TestTensors.make(config, pgi.rank)
|
||||
block_shape = [w1.size(1) // w1_scale.size(1), w1.size(2) // w1_scale.size(2)]
|
||||
|
||||
with set_current_vllm_config(VllmConfig()):
|
||||
# Reference
|
||||
triton_moe = triton_impl(
|
||||
a=test_tensors.rank_tokens,
|
||||
topk_ids=test_tensors.topk,
|
||||
topk_weights=test_tensors.topk_weights,
|
||||
w1=w1,
|
||||
w2=w2,
|
||||
w1_scale=w1_scale,
|
||||
w2_scale=w2_scale,
|
||||
a1_scale=test_tensors.rank_token_scales,
|
||||
block_shape=block_shape,
|
||||
)
|
||||
|
||||
# Slice experts for this rank.
|
||||
num_local_experts = config.num_experts // pgi.world_size
|
||||
e_start = num_local_experts * pgi.rank
|
||||
e_end = e_start + num_local_experts
|
||||
w1_ep = w1[e_start:e_end]
|
||||
w2_ep = w2[e_start:e_end]
|
||||
w1_scale_ep = w1_scale[e_start:e_end]
|
||||
w2_scale_ep = w2_scale[e_start:e_end]
|
||||
|
||||
deepep_moe = deepep_deepgemm_moe_impl(
|
||||
pg,
|
||||
pgi,
|
||||
dp_size,
|
||||
test_tensors,
|
||||
w1_ep,
|
||||
w2_ep,
|
||||
w1_scale_ep,
|
||||
w2_scale_ep,
|
||||
)
|
||||
|
||||
torch.testing.assert_close(
|
||||
triton_moe,
|
||||
deepep_moe,
|
||||
atol=6e-2,
|
||||
rtol=6e-2,
|
||||
)
|
||||
|
||||
|
||||
MNKs = [
|
||||
(8, 128, 128),
|
||||
(8, 128, 512),
|
||||
(3, 1024, 2048),
|
||||
(32, 128, 1024),
|
||||
(45, 512, 2048),
|
||||
(64, 1024, 1024),
|
||||
(129, 128, 256),
|
||||
(129, 1024, 2048),
|
||||
(222, 1024, 2048),
|
||||
]
|
||||
|
||||
TOPKS = [2, 6]
|
||||
NUM_EXPERTS = [32]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("mnk", MNKs)
|
||||
@pytest.mark.parametrize("num_experts", NUM_EXPERTS)
|
||||
@pytest.mark.parametrize("topk", TOPKS)
|
||||
@pytest.mark.parametrize("world_dp_size", [(2, 1)])
|
||||
@multi_gpu_test(num_gpus=2)
|
||||
@requires_deep_ep
|
||||
@requires_deep_gemm
|
||||
def test_ht_deepep_deepgemm_moe(
|
||||
mnk: tuple[int, int, int],
|
||||
num_experts: int,
|
||||
topk: int,
|
||||
world_dp_size: tuple[int, int],
|
||||
disable_deepgemm_ue8m0,
|
||||
workspace_init,
|
||||
):
|
||||
"""
|
||||
Tests for High-Throughput DeepEP + DeepGemm integration.
|
||||
"""
|
||||
|
||||
m, n, k = mnk
|
||||
current_platform.seed_everything(7)
|
||||
|
||||
if topk > num_experts:
|
||||
pytest.skip(f"Skipping test: topk={topk} > E={num_experts}")
|
||||
|
||||
block_m = get_mk_alignment_for_contiguous_layout()[0]
|
||||
block_size = [block_m, block_m]
|
||||
|
||||
world_size, dp_size = world_dp_size
|
||||
config = TestConfig(
|
||||
topk=topk,
|
||||
m=m,
|
||||
k=k,
|
||||
n=n,
|
||||
num_experts=num_experts,
|
||||
per_act_token_quant=False,
|
||||
block_size=block_size,
|
||||
low_latency=False,
|
||||
use_fp8_dispatch=None,
|
||||
)
|
||||
|
||||
w1, w2, w1_scale, w2_scale = make_block_quant_fp8_weights(
|
||||
num_experts, n, k, block_size
|
||||
)
|
||||
|
||||
parallel_launch(
|
||||
world_size,
|
||||
_test_deepep_deepgemm_moe,
|
||||
dp_size,
|
||||
config,
|
||||
w1,
|
||||
w2,
|
||||
w1_scale,
|
||||
w2_scale,
|
||||
)
|
||||
|
||||
|
||||
MNKs = [
|
||||
(1, 128, 2560),
|
||||
(2, 128, 2560),
|
||||
(3, 1024, 2560),
|
||||
(32, 128, 2560),
|
||||
(45, 512, 2560),
|
||||
(64, 1024, 2560),
|
||||
(222, 1024, 2560),
|
||||
]
|
||||
# Fix tests for USE_FP8_DISPATCH=True
|
||||
USE_FP8_DISPATCH = [False]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("mnk", MNKs)
|
||||
@pytest.mark.parametrize("num_experts", NUM_EXPERTS)
|
||||
@pytest.mark.parametrize("topk", TOPKS)
|
||||
@pytest.mark.parametrize("use_fp8_dispatch", USE_FP8_DISPATCH)
|
||||
@pytest.mark.parametrize("block_size", [[128, 128]])
|
||||
@pytest.mark.parametrize("world_dp_size", [(2, 1)])
|
||||
@multi_gpu_test(num_gpus=2)
|
||||
@requires_deep_ep
|
||||
@requires_deep_gemm
|
||||
def test_ll_deepep_deepgemm_moe(
|
||||
mnk: tuple[int, int, int],
|
||||
num_experts: int,
|
||||
topk: int,
|
||||
use_fp8_dispatch: bool,
|
||||
block_size: list[int],
|
||||
world_dp_size: tuple[int, int],
|
||||
disable_deepgemm_ue8m0,
|
||||
workspace_init,
|
||||
):
|
||||
"""
|
||||
Tests for Low-Latency DeepEP + DeepGemm integration.
|
||||
"""
|
||||
assert not is_deep_gemm_e8m0_used()
|
||||
|
||||
m, n, k = mnk
|
||||
current_platform.seed_everything(7)
|
||||
|
||||
if topk > num_experts:
|
||||
pytest.skip(f"Skipping test: topk={topk} > E={num_experts}")
|
||||
|
||||
world_size, dp_size = world_dp_size
|
||||
config = TestConfig(
|
||||
topk=topk,
|
||||
m=m,
|
||||
k=k,
|
||||
n=n,
|
||||
num_experts=num_experts,
|
||||
per_act_token_quant=False,
|
||||
block_size=block_size,
|
||||
low_latency=True,
|
||||
use_fp8_dispatch=use_fp8_dispatch,
|
||||
)
|
||||
|
||||
w1, w2, w1_scale, w2_scale = make_block_quant_fp8_weights(
|
||||
num_experts, n, k, block_size
|
||||
)
|
||||
|
||||
parallel_launch(
|
||||
world_size,
|
||||
_test_deepep_deepgemm_moe,
|
||||
dp_size,
|
||||
config,
|
||||
w1,
|
||||
w2,
|
||||
w1_scale,
|
||||
w2_scale,
|
||||
)
|
||||
528
tests/kernels/moe/test_deepep_moe.py
Normal file
528
tests/kernels/moe/test_deepep_moe.py
Normal file
@@ -0,0 +1,528 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""
|
||||
Test deepep dispatch-combine logic
|
||||
"""
|
||||
|
||||
import dataclasses
|
||||
|
||||
import pytest
|
||||
import torch.distributed
|
||||
from torch.distributed import ProcessGroup
|
||||
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.config import VllmConfig, set_current_vllm_config
|
||||
from vllm.model_executor.layers.activation import SiluAndMul
|
||||
from vllm.model_executor.layers.fused_moe import TritonExperts
|
||||
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
|
||||
from vllm.model_executor.layers.fused_moe.fused_batched_moe import BatchedTritonExperts
|
||||
from vllm.model_executor.layers.fused_moe.modular_kernel import FusedMoEModularKernel
|
||||
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
|
||||
per_token_group_quant_fp8,
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.import_utils import has_deep_ep
|
||||
from vllm.v1.worker.workspace import init_workspace_manager
|
||||
|
||||
from ...utils import multi_gpu_test
|
||||
from .parallel_utils import ProcessGroupInfo, parallel_launch
|
||||
|
||||
if has_deep_ep():
|
||||
from vllm.model_executor.layers.fused_moe.deepep_ht_prepare_finalize import (
|
||||
DeepEPHTPrepareAndFinalize,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.deepep_ll_prepare_finalize import (
|
||||
DeepEPLLPrepareAndFinalize,
|
||||
)
|
||||
|
||||
from .parallel_utils import DeepEPHTArgs, DeepEPLLArgs, make_deepep_a2a
|
||||
|
||||
requires_deep_ep = pytest.mark.skipif(
|
||||
not has_deep_ep(),
|
||||
reason="Requires deep_ep kernels",
|
||||
)
|
||||
|
||||
MAX_TOKENS_PER_RANK = 64
|
||||
|
||||
|
||||
def make_weights(
|
||||
e, n, k, dtype
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Return weights w1, w2, w1_scale, w2_scale
|
||||
"""
|
||||
if dtype in [torch.float16, torch.bfloat16]:
|
||||
w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10
|
||||
w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10
|
||||
return w1, w2, None, None
|
||||
|
||||
# per-out-channel weight quantization
|
||||
assert dtype == torch.float8_e4m3fn
|
||||
w1 = torch.empty((e, 2 * n, k), device="cuda", dtype=torch.float16)
|
||||
w2 = torch.empty((e, k, n), device="cuda", dtype=torch.float16)
|
||||
|
||||
n_b_scales = 2 * n
|
||||
k_b_scales = k
|
||||
w1_q = torch.empty_like(w1, dtype=dtype)
|
||||
w2_q = torch.empty_like(w2, dtype=dtype)
|
||||
w1_scale = torch.empty((e, n_b_scales, 1), device="cuda", dtype=torch.float32)
|
||||
w2_scale = torch.empty((e, k_b_scales, 1), device="cuda", dtype=torch.float32)
|
||||
for expert in range(e):
|
||||
w1_q[expert], w1_scale[expert] = ops.scaled_fp8_quant(
|
||||
w1[expert], use_per_token_if_dynamic=True
|
||||
)
|
||||
w2_q[expert], w2_scale[expert] = ops.scaled_fp8_quant(
|
||||
w2[expert], use_per_token_if_dynamic=True
|
||||
)
|
||||
return w1_q, w2_q, w1_scale, w2_scale
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class TestConfig:
|
||||
dtype: torch.dtype
|
||||
topk: int
|
||||
m: int
|
||||
k: int
|
||||
n: int
|
||||
num_experts: int
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class TestTensors:
|
||||
rank_tokens: torch.Tensor # all ranks make this many tokens
|
||||
rank_token_scales: torch.Tensor | None
|
||||
topk: torch.Tensor
|
||||
topk_weights: torch.Tensor
|
||||
config: TestConfig
|
||||
|
||||
@staticmethod
|
||||
def make(config: TestConfig, low_latency_mode: bool) -> "TestTensors":
|
||||
# TODO (varun) - check that float16 works ?
|
||||
assert config.dtype in [torch.bfloat16, torch.float8_e4m3fn]
|
||||
token_dtype = (
|
||||
torch.bfloat16 if config.dtype == torch.float8_e4m3fn else config.dtype
|
||||
)
|
||||
rank_tokens = (
|
||||
torch.randn((config.m, config.k), device="cuda", dtype=token_dtype) / 10
|
||||
)
|
||||
rank_token_scales = None
|
||||
|
||||
topk = torch.randint(
|
||||
low=0, high=config.num_experts, size=(config.m, config.topk), device="cuda"
|
||||
).to(dtype=torch.int64)
|
||||
topk_weights = torch.randn(topk.shape, dtype=torch.float32, device="cuda")
|
||||
return TestTensors(
|
||||
rank_tokens=rank_tokens,
|
||||
rank_token_scales=rank_token_scales,
|
||||
topk=topk,
|
||||
topk_weights=topk_weights,
|
||||
config=config,
|
||||
)
|
||||
|
||||
|
||||
def make_modular_kernel(
|
||||
pg: ProcessGroup,
|
||||
pgi: ProcessGroupInfo,
|
||||
low_latency_mode: bool,
|
||||
hidden_size: int,
|
||||
dp_size: int,
|
||||
num_experts: int,
|
||||
num_local_experts: int,
|
||||
q_dtype: torch.dtype | None,
|
||||
use_fp8_dispatch: bool,
|
||||
quant_config: FusedMoEQuantConfig,
|
||||
) -> FusedMoEModularKernel:
|
||||
ht_args: DeepEPHTArgs | None = None
|
||||
ll_args: DeepEPLLArgs | None = None
|
||||
|
||||
if low_latency_mode:
|
||||
ll_args = DeepEPLLArgs(
|
||||
max_tokens_per_rank=MAX_TOKENS_PER_RANK,
|
||||
hidden_size=hidden_size,
|
||||
num_experts=num_experts,
|
||||
use_fp8_dispatch=use_fp8_dispatch,
|
||||
)
|
||||
else:
|
||||
assert not use_fp8_dispatch, (
|
||||
"FP8 Dispatch is valid only for low-latency kernels"
|
||||
)
|
||||
ht_args = DeepEPHTArgs(num_local_experts=num_local_experts)
|
||||
|
||||
a2a: DeepEPHTPrepareAndFinalize | DeepEPLLPrepareAndFinalize = make_deepep_a2a(
|
||||
pg=pg,
|
||||
pgi=pgi,
|
||||
dp_size=dp_size,
|
||||
q_dtype=q_dtype,
|
||||
block_shape=None,
|
||||
deepep_ht_args=ht_args,
|
||||
deepep_ll_args=ll_args,
|
||||
)
|
||||
|
||||
num_dispatchers = pgi.world_size // dp_size
|
||||
|
||||
if low_latency_mode:
|
||||
assert not quant_config.per_act_token_quant, "not supported in ll mode"
|
||||
fused_experts = BatchedTritonExperts(
|
||||
max_num_tokens=MAX_TOKENS_PER_RANK,
|
||||
num_dispatchers=num_dispatchers,
|
||||
quant_config=quant_config,
|
||||
)
|
||||
else:
|
||||
fused_experts = TritonExperts(quant_config=quant_config)
|
||||
|
||||
mk = FusedMoEModularKernel(prepare_finalize=a2a, fused_experts=fused_experts)
|
||||
return mk
|
||||
|
||||
|
||||
def deep_ep_moe_impl(
|
||||
pg: ProcessGroup,
|
||||
pgi: ProcessGroupInfo,
|
||||
low_latency_mode: bool,
|
||||
dp_size: int,
|
||||
test_tensors: TestTensors,
|
||||
w1: torch.Tensor,
|
||||
w2: torch.Tensor,
|
||||
w1_scale: torch.Tensor | None,
|
||||
w2_scale: torch.Tensor | None,
|
||||
num_experts: int,
|
||||
use_fp8_dispatch: bool,
|
||||
per_act_token_quant: bool,
|
||||
) -> torch.Tensor:
|
||||
num_local_experts = w1.size(0)
|
||||
|
||||
def build_expert_map():
|
||||
num_local_experts = w1.size(0)
|
||||
expert_map = torch.full((num_experts,), fill_value=-1, dtype=torch.int32)
|
||||
s = pgi.rank * num_local_experts
|
||||
e = s + num_local_experts
|
||||
expert_map[s:e] = torch.tensor(list(range(num_local_experts)))
|
||||
return expert_map.to(device=torch.cuda.current_device(), dtype=torch.int32)
|
||||
|
||||
hidden_size = test_tensors.rank_tokens.size(1)
|
||||
is_quantized = w1.dtype == torch.float8_e4m3fn
|
||||
q_dtype = None
|
||||
if is_quantized:
|
||||
q_dtype = torch.float8_e4m3fn
|
||||
|
||||
out_hidden_states = torch.empty_like(test_tensors.rank_tokens)
|
||||
total_num_tokens = test_tensors.rank_tokens.size(0)
|
||||
|
||||
def process_chunk(chunk_start, chunk_end, skip_result_store=False):
|
||||
rank_tokens_chunk = test_tensors.rank_tokens[chunk_start:chunk_end]
|
||||
topk_weights_chunk = test_tensors.topk_weights[chunk_start:chunk_end]
|
||||
topk_chunk = test_tensors.topk[chunk_start:chunk_end]
|
||||
rank_token_scales_chunk = test_tensors.rank_token_scales
|
||||
if (
|
||||
rank_token_scales_chunk is not None
|
||||
and rank_token_scales_chunk.size(0) == total_num_tokens
|
||||
):
|
||||
# per act token
|
||||
rank_token_scales_chunk = rank_token_scales_chunk[chunk_start:chunk_end]
|
||||
|
||||
quant_config = FusedMoEQuantConfig.make(
|
||||
q_dtype,
|
||||
w1_scale=w1_scale,
|
||||
w2_scale=w2_scale,
|
||||
per_act_token_quant=per_act_token_quant,
|
||||
a1_scale=rank_token_scales_chunk,
|
||||
)
|
||||
|
||||
# Make modular kernel
|
||||
mk: FusedMoEModularKernel = make_modular_kernel(
|
||||
pg,
|
||||
pgi,
|
||||
low_latency_mode,
|
||||
hidden_size,
|
||||
dp_size,
|
||||
num_experts,
|
||||
num_local_experts,
|
||||
q_dtype,
|
||||
use_fp8_dispatch,
|
||||
quant_config,
|
||||
)
|
||||
|
||||
out = mk.forward(
|
||||
hidden_states=rank_tokens_chunk,
|
||||
w1=w1,
|
||||
w2=w2,
|
||||
topk_weights=topk_weights_chunk,
|
||||
topk_ids=topk_chunk,
|
||||
inplace=False,
|
||||
activation="silu",
|
||||
global_num_experts=num_experts,
|
||||
expert_map=build_expert_map(),
|
||||
apply_router_weight_on_input=False,
|
||||
)
|
||||
|
||||
if not skip_result_store:
|
||||
out_hidden_states[chunk_start:chunk_end, :].copy_(out, non_blocking=True)
|
||||
|
||||
max_num_tokens_per_dp = (
|
||||
MAX_TOKENS_PER_RANK if low_latency_mode else total_num_tokens
|
||||
)
|
||||
|
||||
for chunk_start_ in range(0, total_num_tokens, max_num_tokens_per_dp):
|
||||
chunk_start = chunk_start_
|
||||
chunk_end = min(chunk_start + max_num_tokens_per_dp, total_num_tokens)
|
||||
# clamp start and end
|
||||
chunk_start = min(chunk_start, total_num_tokens - 1)
|
||||
chunk_end = min(chunk_end, total_num_tokens)
|
||||
|
||||
process_chunk(
|
||||
chunk_start, chunk_end, skip_result_store=chunk_start_ >= total_num_tokens
|
||||
)
|
||||
|
||||
return out_hidden_states
|
||||
|
||||
|
||||
def torch_moe_impl(
|
||||
test_tensors: TestTensors,
|
||||
w1: torch.Tensor,
|
||||
w2: torch.Tensor,
|
||||
w1_scale: torch.Tensor | None,
|
||||
w2_scale: torch.Tensor | None,
|
||||
using_fp8_dispatch: bool,
|
||||
per_act_token_quant: bool,
|
||||
):
|
||||
a, topk_ids, topk_weights = (
|
||||
test_tensors.rank_tokens,
|
||||
test_tensors.topk,
|
||||
test_tensors.topk_weights,
|
||||
)
|
||||
if using_fp8_dispatch:
|
||||
# The DeepEP implementation is requested to dispatch using FP8.
|
||||
# For numerical stability for testing, emulate the fp8 dispatch by
|
||||
# blockwise quant and de-quant.
|
||||
assert not per_act_token_quant
|
||||
a = test_tensors.rank_tokens
|
||||
aq, aq_scale = per_token_group_quant_fp8(a, 128, use_ue8m0=False)
|
||||
a = (
|
||||
(aq.view(-1, 128).to(torch.float32) * aq_scale.view(-1, 1))
|
||||
.view(a.shape)
|
||||
.to(a.dtype)
|
||||
)
|
||||
|
||||
is_quantized = w1.dtype == torch.float8_e4m3fn
|
||||
a_dtype = a.dtype
|
||||
if is_quantized:
|
||||
w1 = w1.to(dtype=torch.float32) * w1_scale
|
||||
w2 = w2.to(dtype=torch.float32) * w2_scale
|
||||
a = a.to(dtype=torch.float32)
|
||||
|
||||
m, _ = a.shape
|
||||
topk = topk_ids.size(1)
|
||||
out = torch.zeros_like(a)
|
||||
|
||||
for i in range(m):
|
||||
a_i = a[i]
|
||||
o_i = out[i]
|
||||
for j in range(topk):
|
||||
e = topk_ids[i][j]
|
||||
e_w = topk_weights[i][j]
|
||||
w1_e = w1[e]
|
||||
w2_e = w2[e]
|
||||
o_i += (
|
||||
SiluAndMul()(a_i @ w1_e.transpose(0, 1)) @ w2_e.transpose(0, 1)
|
||||
) * e_w
|
||||
|
||||
if is_quantized:
|
||||
out = out.to(dtype=a_dtype)
|
||||
|
||||
return out
|
||||
|
||||
|
||||
def _deep_ep_moe(
|
||||
pgi: ProcessGroupInfo,
|
||||
low_latency_mode: bool,
|
||||
dp_size: int,
|
||||
config: TestConfig,
|
||||
w1: torch.Tensor,
|
||||
w2: torch.Tensor,
|
||||
w1_scale: torch.Tensor | None,
|
||||
w2_scale: torch.Tensor | None,
|
||||
use_fp8_dispatch: bool,
|
||||
per_act_token_quant: bool,
|
||||
):
|
||||
device = torch.device(f"cuda:{pgi.local_rank}")
|
||||
init_workspace_manager(device)
|
||||
|
||||
if not low_latency_mode:
|
||||
assert not use_fp8_dispatch, (
|
||||
"FP8 dispatch interface is available only in low-latency mode"
|
||||
)
|
||||
|
||||
is_quantized = w1.dtype == torch.float8_e4m3fn
|
||||
w1 = w1.to(device=torch.cuda.current_device())
|
||||
w2 = w2.to(device=torch.cuda.current_device())
|
||||
if is_quantized:
|
||||
w1_scale = w1_scale.to( # type: ignore
|
||||
device=torch.cuda.current_device()
|
||||
)
|
||||
w2_scale = w2_scale.to( # type: ignore
|
||||
device=torch.cuda.current_device()
|
||||
)
|
||||
|
||||
pg = torch.distributed.new_group(list(range(pgi.world_size)))
|
||||
test_tensors = TestTensors.make(config, low_latency_mode)
|
||||
|
||||
with set_current_vllm_config(VllmConfig()):
|
||||
# Reference
|
||||
torch_combined = torch_moe_impl(
|
||||
test_tensors,
|
||||
w1,
|
||||
w2,
|
||||
w1_scale,
|
||||
w2_scale,
|
||||
use_fp8_dispatch,
|
||||
per_act_token_quant,
|
||||
)
|
||||
|
||||
# Splice experts for this rank.
|
||||
num_local_experts = config.num_experts // pgi.world_size
|
||||
e_start = num_local_experts * pgi.rank
|
||||
e_end = e_start + num_local_experts
|
||||
w1_ep = w1[e_start:e_end]
|
||||
w2_ep = w2[e_start:e_end]
|
||||
|
||||
w1_scale_ep, w2_scale_ep = None, None
|
||||
if is_quantized:
|
||||
w1_scale_ep = w1_scale[e_start:e_end] # type: ignore
|
||||
w2_scale_ep = w2_scale[e_start:e_end] # type: ignore
|
||||
deepep_combined = deep_ep_moe_impl(
|
||||
pg,
|
||||
pgi,
|
||||
low_latency_mode,
|
||||
dp_size,
|
||||
test_tensors,
|
||||
w1_ep,
|
||||
w2_ep,
|
||||
w1_scale_ep,
|
||||
w2_scale_ep,
|
||||
config.num_experts,
|
||||
use_fp8_dispatch,
|
||||
per_act_token_quant,
|
||||
)
|
||||
|
||||
torch.testing.assert_close(
|
||||
torch_combined,
|
||||
deepep_combined,
|
||||
atol=6e-2,
|
||||
rtol=6e-2,
|
||||
)
|
||||
|
||||
|
||||
MNKs = [
|
||||
(1, 128, 128),
|
||||
(2, 128, 512),
|
||||
(3, 1024, 2048),
|
||||
(32, 128, 1024),
|
||||
(45, 512, 2048),
|
||||
(64, 1024, 1024),
|
||||
(222, 1024, 2048),
|
||||
]
|
||||
|
||||
DTYPES = [torch.bfloat16, torch.float8_e4m3fn]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("dtype", DTYPES)
|
||||
@pytest.mark.parametrize("m,n,k", MNKs)
|
||||
@pytest.mark.parametrize("num_experts", [32])
|
||||
@pytest.mark.parametrize("topk", [6])
|
||||
@pytest.mark.parametrize("world_dp_size", [(2, 1)])
|
||||
@pytest.mark.parametrize("per_act_token_quant", [False, True])
|
||||
@multi_gpu_test(num_gpus=2)
|
||||
@requires_deep_ep
|
||||
def test_deep_ep_moe(
|
||||
dtype: torch.dtype,
|
||||
m: int,
|
||||
n: int,
|
||||
k: int,
|
||||
num_experts: int,
|
||||
topk: int,
|
||||
world_dp_size: tuple[int, int],
|
||||
per_act_token_quant: bool,
|
||||
workspace_init,
|
||||
):
|
||||
low_latency_mode = False
|
||||
use_fp8_dispatch = False
|
||||
|
||||
current_platform.seed_everything(7)
|
||||
world_size, dp_size = world_dp_size
|
||||
config = TestConfig(dtype=dtype, topk=topk, m=m, k=k, n=n, num_experts=num_experts)
|
||||
|
||||
w1, w2, w1_scale, w2_scale = make_weights(num_experts, n, k, dtype)
|
||||
|
||||
parallel_launch(
|
||||
world_size,
|
||||
_deep_ep_moe,
|
||||
low_latency_mode,
|
||||
dp_size,
|
||||
config,
|
||||
w1,
|
||||
w2,
|
||||
w1_scale,
|
||||
w2_scale,
|
||||
use_fp8_dispatch,
|
||||
per_act_token_quant,
|
||||
)
|
||||
|
||||
|
||||
MNKs = [
|
||||
(1, 128, 2560),
|
||||
(2, 128, 2560),
|
||||
(3, 1024, 2560),
|
||||
(32, 128, 2560),
|
||||
(45, 512, 2560),
|
||||
(64, 1024, 2560),
|
||||
(222, 1024, 2560),
|
||||
]
|
||||
DTYPES = [torch.float8_e4m3fn, torch.bfloat16]
|
||||
USE_FP8_DISPATCH = [True, False]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("dtype", DTYPES)
|
||||
@pytest.mark.parametrize("m,n,k", MNKs)
|
||||
@pytest.mark.parametrize("num_experts", [32])
|
||||
@pytest.mark.parametrize("topk", [6])
|
||||
@pytest.mark.parametrize("world_dp_size", [(2, 1)])
|
||||
@pytest.mark.parametrize("use_fp8_dispatch", USE_FP8_DISPATCH)
|
||||
@multi_gpu_test(num_gpus=2)
|
||||
@requires_deep_ep
|
||||
def test_low_latency_deep_ep_moe(
|
||||
dtype: torch.dtype,
|
||||
m: int,
|
||||
n: int,
|
||||
k: int,
|
||||
num_experts: int,
|
||||
topk: int,
|
||||
world_dp_size: tuple[int, int],
|
||||
use_fp8_dispatch: bool,
|
||||
workspace_init,
|
||||
):
|
||||
low_latency_mode = True
|
||||
|
||||
if low_latency_mode and k not in DeepEPLLPrepareAndFinalize.SUPPORTED_HIDDEN_SIZES:
|
||||
pytest.skip(
|
||||
f"Skipping test as hidden size {k} is not in list of supported "
|
||||
f"hidden sizes {DeepEPLLPrepareAndFinalize.SUPPORTED_HIDDEN_SIZES}"
|
||||
)
|
||||
|
||||
current_platform.seed_everything(7)
|
||||
world_size, dp_size = world_dp_size
|
||||
config = TestConfig(dtype=dtype, topk=topk, m=m, k=k, n=n, num_experts=num_experts)
|
||||
|
||||
w1, w2, w1_scale, w2_scale = make_weights(num_experts, n, k, dtype)
|
||||
|
||||
parallel_launch(
|
||||
world_size,
|
||||
_deep_ep_moe,
|
||||
low_latency_mode,
|
||||
dp_size,
|
||||
config,
|
||||
w1,
|
||||
w2,
|
||||
w1_scale,
|
||||
w2_scale,
|
||||
use_fp8_dispatch,
|
||||
False,
|
||||
)
|
||||
180
tests/kernels/moe/test_deepgemm.py
Normal file
180
tests/kernels/moe/test_deepgemm.py
Normal file
@@ -0,0 +1,180 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""
|
||||
Unit-test DeepGEMM FP8 kernels (no DeepEP).
|
||||
Compare DeepGEMM path against the Triton fallback inside vLLM's fused_experts.
|
||||
"""
|
||||
|
||||
import importlib
|
||||
import math
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from vllm.model_executor.layers.fused_moe.config import fp8_w8a8_moe_quant_config
|
||||
|
||||
# vLLM fused-expert reference (Triton fallback + DeepGEMM option)
|
||||
from vllm.model_executor.layers.fused_moe.fused_moe import fused_experts
|
||||
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
|
||||
per_token_group_quant_fp8,
|
||||
)
|
||||
from vllm.utils.deep_gemm import (
|
||||
calc_diff,
|
||||
is_deep_gemm_supported,
|
||||
per_block_cast_to_fp8,
|
||||
)
|
||||
|
||||
BLOCK_SIZE = [128, 128]
|
||||
|
||||
|
||||
def make_block_quant_fp8_weights(
|
||||
e: int,
|
||||
n: int,
|
||||
k: int,
|
||||
block_size: list[int],
|
||||
):
|
||||
"""
|
||||
Generate (w1, w2) expert weights and their per-block scale tensors
|
||||
in FP8 block-quantized format.
|
||||
|
||||
w1 shape: (E, 2N, K)
|
||||
w2 shape: (E, K, N)
|
||||
"""
|
||||
dtype = torch.bfloat16
|
||||
fp8_max, fp8_min = (
|
||||
torch.finfo(torch.float8_e4m3fn).max,
|
||||
torch.finfo(torch.float8_e4m3fn).min,
|
||||
)
|
||||
|
||||
# bf16 reference weights
|
||||
w1_bf16 = torch.randn(e, 2 * n, k, device="cuda", dtype=dtype) / 10
|
||||
w2_bf16 = torch.randn(e, k, n, device="cuda", dtype=dtype) / 10
|
||||
w1_bf16.clamp_(fp8_min, fp8_max)
|
||||
w2_bf16.clamp_(fp8_min, fp8_max)
|
||||
|
||||
block_n, block_k = block_size
|
||||
n_tiles_w1 = math.ceil((2 * n) / block_n)
|
||||
k_tiles_w1 = math.ceil(k / block_k)
|
||||
n_tiles_w2 = math.ceil(k / block_n)
|
||||
k_tiles_w2 = math.ceil(n / block_k)
|
||||
|
||||
w1 = torch.empty_like(w1_bf16, dtype=torch.float8_e4m3fn)
|
||||
w2 = torch.empty_like(w2_bf16, dtype=torch.float8_e4m3fn)
|
||||
w1_s = torch.empty(e, n_tiles_w1, k_tiles_w1, device="cuda", dtype=torch.float32)
|
||||
w2_s = torch.empty(e, n_tiles_w2, k_tiles_w2, device="cuda", dtype=torch.float32)
|
||||
|
||||
for i in range(e):
|
||||
w1[i], w1_s[i] = per_block_cast_to_fp8(
|
||||
w1_bf16[i], block_size=block_size, use_ue8m0=True
|
||||
)
|
||||
w2[i], w2_s[i] = per_block_cast_to_fp8(
|
||||
w2_bf16[i], block_size=block_size, use_ue8m0=True
|
||||
)
|
||||
|
||||
return w1, w2, w1_s, w2_s
|
||||
|
||||
|
||||
def run_single_case(m, n, k, topk, num_experts, block_size):
|
||||
"""
|
||||
Run one (M,N,K) configuration on a single GPU and assert DeepGEMM ==
|
||||
Triton baseline within tolerance.
|
||||
"""
|
||||
tokens_bf16 = (
|
||||
torch.randn(m, k, device="cuda", dtype=torch.bfloat16)
|
||||
.clamp_min_(-1)
|
||||
.clamp_max_(1)
|
||||
)
|
||||
_, a1_scale = per_token_group_quant_fp8(tokens_bf16, block_size[1])
|
||||
|
||||
# expert weight tensors
|
||||
w1, w2, w1_s, w2_s = make_block_quant_fp8_weights(num_experts, n, k, block_size)
|
||||
|
||||
router_logits = torch.randn(m, num_experts, device="cuda", dtype=torch.float32)
|
||||
topk_weights, topk_ids = torch.topk(router_logits, k=topk, dim=-1)
|
||||
topk_weights = torch.nn.functional.softmax(topk_weights, dim=-1)
|
||||
|
||||
quant_config = fp8_w8a8_moe_quant_config(
|
||||
w1_scale=w1_s,
|
||||
w2_scale=w2_s,
|
||||
a1_scale=a1_scale,
|
||||
block_shape=block_size,
|
||||
)
|
||||
|
||||
# triton reference
|
||||
out_triton = fused_experts(
|
||||
hidden_states=tokens_bf16,
|
||||
w1=w1,
|
||||
w2=w2,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
inplace=False,
|
||||
quant_config=quant_config,
|
||||
allow_deep_gemm=False,
|
||||
)
|
||||
|
||||
# DeepGemm
|
||||
out_deepgemm = fused_experts(
|
||||
hidden_states=tokens_bf16,
|
||||
w1=w1,
|
||||
w2=w2,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
inplace=False,
|
||||
quant_config=quant_config,
|
||||
allow_deep_gemm=True,
|
||||
)
|
||||
diff = calc_diff(out_deepgemm, out_triton)
|
||||
assert diff < 0.001, f"Diff exceeded 1%: {diff}"
|
||||
|
||||
|
||||
# Note: N <= 512 will disable the deepgemm path due to performance issues.
|
||||
MNKs = [
|
||||
(1024, 768, 128),
|
||||
(2048, 768, 512),
|
||||
(512, 1024, 1024),
|
||||
(4096, 4096, 1024),
|
||||
]
|
||||
|
||||
TOPKS = [2, 6]
|
||||
NUM_EXPERTS = [32]
|
||||
|
||||
|
||||
@pytest.mark.parametrize(("m", "n", "k"), MNKs)
|
||||
@pytest.mark.parametrize("topk", TOPKS)
|
||||
@pytest.mark.parametrize("num_experts", NUM_EXPERTS)
|
||||
@pytest.mark.skipif(not is_deep_gemm_supported(), reason="Requires deep_gemm kernels")
|
||||
def test_deepgemm_vs_triton(m, n, k, topk, num_experts, monkeypatch, workspace_init):
|
||||
with monkeypatch.context() as mp:
|
||||
mp.setenv("VLLM_USE_DEEP_GEMM", "1")
|
||||
|
||||
_fused_moe_mod = importlib.import_module(
|
||||
"vllm.model_executor.layers.fused_moe.fused_moe"
|
||||
)
|
||||
|
||||
call_counter = {"cnt": 0}
|
||||
|
||||
orig_fn = _fused_moe_mod.deep_gemm_moe_fp8
|
||||
|
||||
def _spy_deep_gemm_moe_fp8(*args, **kwargs):
|
||||
call_counter["cnt"] += 1
|
||||
return orig_fn(*args, **kwargs)
|
||||
|
||||
monkeypatch.setattr(_fused_moe_mod, "deep_gemm_moe_fp8", _spy_deep_gemm_moe_fp8)
|
||||
|
||||
if topk > num_experts:
|
||||
pytest.skip(f"topk={topk} > num_experts={num_experts}")
|
||||
|
||||
run_single_case(
|
||||
m=m,
|
||||
n=n,
|
||||
k=k,
|
||||
topk=topk,
|
||||
num_experts=num_experts,
|
||||
block_size=BLOCK_SIZE,
|
||||
)
|
||||
|
||||
# ensure that the DeepGEMM path was indeed taken.
|
||||
assert call_counter["cnt"] == 1, (
|
||||
f"DeepGEMM path was not executed during the test. "
|
||||
f"Call counter: {call_counter['cnt']}"
|
||||
)
|
||||
287
tests/kernels/moe/test_flashinfer.py
Normal file
287
tests/kernels/moe/test_flashinfer.py
Normal file
@@ -0,0 +1,287 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from dataclasses import dataclass
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
|
||||
from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config
|
||||
from vllm.model_executor.layers.fused_moe.config import (
|
||||
FusedMoEQuantConfig,
|
||||
fp8_w8a8_moe_quant_config,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.fused_moe import fused_experts
|
||||
from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
|
||||
apply_flashinfer_per_tensor_scale_fp8,
|
||||
flashinfer_cutlass_moe_fp8,
|
||||
register_moe_scaling_factors,
|
||||
rotate_flashinfer_fp8_moe_weights,
|
||||
swap_w13_to_w31,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.fp8_utils import input_to_float8
|
||||
from vllm.model_executor.models.llama4 import Llama4MoE
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
try:
|
||||
from vllm.utils.flashinfer import has_flashinfer_cutlass_fused_moe
|
||||
except ImportError:
|
||||
if current_platform.is_rocm():
|
||||
pytest.skip(
|
||||
"flashinfer not supported for vLLM on ROCm", allow_module_level=True
|
||||
)
|
||||
|
||||
if not has_flashinfer_cutlass_fused_moe() or not current_platform.has_device_capability(
|
||||
90
|
||||
):
|
||||
pytest.skip(
|
||||
"Supported for sm >= 90",
|
||||
allow_module_level=True,
|
||||
)
|
||||
|
||||
NUM_EXPERTS = [16]
|
||||
TOP_KS = [1]
|
||||
|
||||
MNK_FACTORS = [
|
||||
(256, 8192, 5120),
|
||||
(127, 4096, 5120),
|
||||
(10, 8192, 5120),
|
||||
(10, 4096, 5120),
|
||||
(1, 8192, 5120),
|
||||
(1, 4096, 5120),
|
||||
]
|
||||
|
||||
vllm_config = VllmConfig(parallel_config=ParallelConfig(pipeline_parallel_size=1))
|
||||
|
||||
|
||||
def quant_fp8_per_tensor_batches(a):
|
||||
num_batches = a.size(0)
|
||||
a_quant = []
|
||||
a_scales = []
|
||||
|
||||
for i in range(num_batches):
|
||||
a_fp8, a_global_sf = input_to_float8(a[i])
|
||||
a_global_sf = 1.0 / a_global_sf
|
||||
a_quant.append(a_fp8)
|
||||
a_scales.append(a_global_sf)
|
||||
|
||||
result_a_quant = torch.stack(a_quant)
|
||||
result_a_scales = torch.stack(a_scales)
|
||||
|
||||
return result_a_quant, result_a_scales
|
||||
|
||||
|
||||
@dataclass
|
||||
class TestData:
|
||||
hidden_states: torch.Tensor
|
||||
w13_quantized: torch.Tensor
|
||||
w2_quantized: torch.Tensor
|
||||
a1_scale: torch.Tensor
|
||||
a2_scale: torch.Tensor
|
||||
w13_weight_scale: torch.Tensor
|
||||
w2_weight_scale: torch.Tensor
|
||||
layer: torch.nn.Module
|
||||
|
||||
@staticmethod
|
||||
def make_moe_tensors_8bit(
|
||||
m: int, k: int, n: int, e: int, reorder: bool, activation: str = "silu"
|
||||
) -> "TestData":
|
||||
is_gated = activation != "relu2_no_mul"
|
||||
|
||||
hidden_states = torch.randn((m, k), device="cuda", dtype=torch.bfloat16) / 10
|
||||
w13 = torch.randn(
|
||||
(e, (2 * n) if is_gated else n, k), device="cuda", dtype=torch.bfloat16
|
||||
)
|
||||
w2 = torch.randn((e, k, n), device="cuda", dtype=torch.bfloat16)
|
||||
|
||||
# Scale to fp8
|
||||
_, a1_scale = input_to_float8(hidden_states)
|
||||
a1_scale = 1.0 / a1_scale
|
||||
a2_scale = torch.scalar_tensor(1.0).to(device="cuda").to(dtype=torch.float32)
|
||||
w13_quantized, w13_weight_scale = quant_fp8_per_tensor_batches(w13)
|
||||
w2_quantized, w2_weight_scale = quant_fp8_per_tensor_batches(w2)
|
||||
|
||||
layer = torch.nn.Module()
|
||||
layer.w13_weight = w13_quantized.clone()
|
||||
layer.w2_weight = w2_quantized.clone()
|
||||
layer.w13_input_scale = a1_scale
|
||||
layer.w2_input_scale = a2_scale
|
||||
layer.w13_weight_scale = w13_weight_scale
|
||||
layer.w2_weight_scale = w2_weight_scale
|
||||
# Setup dummy config.
|
||||
layer.moe_parallel_config = mk.FusedMoEParallelConfig(
|
||||
tp_size=1,
|
||||
pcp_size=1,
|
||||
dp_size=1,
|
||||
ep_size=1,
|
||||
tp_rank=1,
|
||||
pcp_rank=1,
|
||||
dp_rank=1,
|
||||
ep_rank=1,
|
||||
use_ep=False,
|
||||
all2all_backend="naive",
|
||||
)
|
||||
|
||||
register_moe_scaling_factors(layer)
|
||||
|
||||
# flashinfer expects swapped rows for w13
|
||||
layer.w13_weight.data = swap_w13_to_w31(layer.w13_weight.data)
|
||||
if reorder:
|
||||
rotate_flashinfer_fp8_moe_weights(layer.w13_weight, layer.w2_weight)
|
||||
layer.custom_routing_function = Llama4MoE.custom_routing_function
|
||||
layer.intermediate_size_per_partition = n
|
||||
layer.ep_rank = 0
|
||||
layer.local_num_experts = e
|
||||
|
||||
return TestData(
|
||||
hidden_states=hidden_states,
|
||||
w13_quantized=w13_quantized,
|
||||
w2_quantized=w2_quantized,
|
||||
a1_scale=a1_scale,
|
||||
a2_scale=a2_scale,
|
||||
w13_weight_scale=w13_weight_scale,
|
||||
w2_weight_scale=w2_weight_scale,
|
||||
layer=layer,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("m,n,k", MNK_FACTORS)
|
||||
@pytest.mark.parametrize("e", NUM_EXPERTS)
|
||||
@pytest.mark.parametrize("topk", TOP_KS)
|
||||
def test_flashinfer_per_tensor_moe_fp8_no_graph(
|
||||
m: int,
|
||||
n: int,
|
||||
k: int,
|
||||
e: int,
|
||||
topk: int,
|
||||
monkeypatch,
|
||||
):
|
||||
if not current_platform.has_device_capability(100):
|
||||
pytest.skip("Test is only supported for sm >= 100")
|
||||
current_platform.seed_everything(7)
|
||||
monkeypatch.setenv("VLLM_FUSED_MOE_CHUNK_SIZE", "8192")
|
||||
with set_current_vllm_config(vllm_config):
|
||||
td = TestData.make_moe_tensors_8bit(m, k, n, e, reorder=True)
|
||||
|
||||
score = torch.randn((m, e), device="cuda", dtype=torch.bfloat16)
|
||||
topk_weights, topk_ids = Llama4MoE.custom_routing_function(
|
||||
hidden_states=td.hidden_states,
|
||||
gating_output=score,
|
||||
topk=topk,
|
||||
renormalize=False,
|
||||
)
|
||||
|
||||
quant_config = fp8_w8a8_moe_quant_config(
|
||||
w1_scale=td.w13_weight_scale,
|
||||
w2_scale=td.w2_weight_scale,
|
||||
a1_scale=td.a1_scale,
|
||||
a2_scale=td.a2_scale,
|
||||
per_act_token_quant=False,
|
||||
)
|
||||
|
||||
output = fused_experts(
|
||||
td.hidden_states,
|
||||
td.w13_quantized,
|
||||
td.w2_quantized,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
inplace=False,
|
||||
activation="silu",
|
||||
global_num_experts=e,
|
||||
expert_map=None,
|
||||
apply_router_weight_on_input=True,
|
||||
quant_config=quant_config,
|
||||
)
|
||||
|
||||
flashinfer_output = apply_flashinfer_per_tensor_scale_fp8(
|
||||
layer=td.layer,
|
||||
hidden_states=td.hidden_states,
|
||||
router_logits=score,
|
||||
routing_bias=None,
|
||||
global_num_experts=e,
|
||||
top_k=topk,
|
||||
num_expert_group=None,
|
||||
topk_group=None,
|
||||
apply_router_weight_on_input=True,
|
||||
)
|
||||
|
||||
torch.testing.assert_close(output, flashinfer_output, atol=5.5e-2, rtol=1e-2)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("m,n,k", MNK_FACTORS)
|
||||
@pytest.mark.parametrize("e", NUM_EXPERTS)
|
||||
@pytest.mark.parametrize("topk", TOP_KS)
|
||||
@pytest.mark.parametrize("activation", ["silu", "relu2_no_mul"])
|
||||
def test_flashinfer_cutlass_moe_fp8_no_graph(
|
||||
m: int,
|
||||
n: int,
|
||||
k: int,
|
||||
e: int,
|
||||
topk: int,
|
||||
activation: str,
|
||||
monkeypatch,
|
||||
workspace_init,
|
||||
):
|
||||
current_platform.seed_everything(7)
|
||||
monkeypatch.setenv("VLLM_FUSED_MOE_CHUNK_SIZE", "8192")
|
||||
with set_current_vllm_config(vllm_config):
|
||||
td = TestData.make_moe_tensors_8bit(
|
||||
m, k, n, e, reorder=False, activation=activation
|
||||
)
|
||||
|
||||
score = torch.randn((m, e), device="cuda", dtype=torch.bfloat16)
|
||||
topk_weights, topk_ids = Llama4MoE.custom_routing_function(
|
||||
hidden_states=td.hidden_states,
|
||||
gating_output=score,
|
||||
topk=topk,
|
||||
renormalize=False,
|
||||
)
|
||||
|
||||
quant_config = fp8_w8a8_moe_quant_config(
|
||||
w1_scale=td.w13_weight_scale,
|
||||
g1_alphas=(td.w13_weight_scale * td.a1_scale).squeeze(),
|
||||
w2_scale=td.w2_weight_scale,
|
||||
g2_alphas=(td.w2_weight_scale * td.a2_scale).squeeze(),
|
||||
a1_scale=td.a1_scale,
|
||||
a1_gscale=td.a1_scale,
|
||||
a2_scale=td.a2_scale,
|
||||
a2_gscale=1.0 / td.a2_scale,
|
||||
per_act_token_quant=False,
|
||||
)
|
||||
|
||||
output = fused_experts(
|
||||
td.hidden_states,
|
||||
td.w13_quantized,
|
||||
td.w2_quantized,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
inplace=False,
|
||||
activation=activation,
|
||||
global_num_experts=e,
|
||||
expert_map=None,
|
||||
apply_router_weight_on_input=True,
|
||||
quant_config=quant_config,
|
||||
)
|
||||
|
||||
td.layer.dp_size = 1
|
||||
|
||||
def get_fused_moe_quant_config(n: torch.nn.Module) -> FusedMoEQuantConfig:
|
||||
return quant_config
|
||||
|
||||
td.layer.get_fused_moe_quant_config = get_fused_moe_quant_config
|
||||
td.layer.quant_method = td.layer
|
||||
|
||||
flashinfer_cutlass_output = flashinfer_cutlass_moe_fp8(
|
||||
td.hidden_states,
|
||||
td.layer,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
activation=activation,
|
||||
global_num_experts=e,
|
||||
expert_map=None,
|
||||
apply_router_weight_on_input=True,
|
||||
)
|
||||
|
||||
torch.testing.assert_close(
|
||||
output, flashinfer_cutlass_output, atol=5.5e-2, rtol=1e-2
|
||||
)
|
||||
152
tests/kernels/moe/test_flashinfer_moe.py
Normal file
152
tests/kernels/moe/test_flashinfer_moe.py
Normal file
@@ -0,0 +1,152 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from tests.kernels.moe.utils import make_test_quant_config
|
||||
from tests.kernels.quantization.nvfp4_utils import (
|
||||
FLOAT4_E2M1_MAX,
|
||||
FLOAT8_E4M3_MAX,
|
||||
dequantize_nvfp4_to_dtype,
|
||||
)
|
||||
from tests.kernels.utils import torch_moe
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config
|
||||
from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import (
|
||||
FlashInferExperts,
|
||||
is_valid_flashinfer_cutlass_fused_moe,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_prepare_finalize import (
|
||||
create_flashinfer_prepare_finalize,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk
|
||||
from vllm.model_executor.layers.fused_moe.modular_kernel import FusedMoEModularKernel
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.flashinfer import has_flashinfer_cutlass_fused_moe
|
||||
|
||||
if not has_flashinfer_cutlass_fused_moe() or not current_platform.has_device_capability(
|
||||
100
|
||||
):
|
||||
pytest.skip(
|
||||
"Requires flashinfer_cutlass_fused_moe and nvfp4 support",
|
||||
allow_module_level=True,
|
||||
)
|
||||
|
||||
MNK_FACTORS = [
|
||||
(2, 1024, 1024),
|
||||
(2, 3072, 1024),
|
||||
(2, 3072, 1536),
|
||||
(64, 1024, 1536),
|
||||
(64, 3072, 1024),
|
||||
(64, 2048, 1536),
|
||||
(224, 1024, 1024),
|
||||
(224, 1024, 1536),
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("m,n,k", MNK_FACTORS)
|
||||
@pytest.mark.parametrize("e", [40, 64, 256])
|
||||
@pytest.mark.parametrize("topk", [1, 6, 8])
|
||||
@pytest.mark.parametrize("dtype", [torch.bfloat16])
|
||||
@pytest.mark.parametrize("activation", ["silu_and_mul", "relu2"])
|
||||
@torch.inference_mode()
|
||||
def test_flashinfer_fp4_moe_no_graph(
|
||||
m: int,
|
||||
n: int,
|
||||
k: int,
|
||||
e: int,
|
||||
topk: int,
|
||||
dtype: torch.dtype,
|
||||
activation: str,
|
||||
workspace_init,
|
||||
):
|
||||
current_platform.seed_everything(7)
|
||||
with set_current_vllm_config(
|
||||
VllmConfig(parallel_config=ParallelConfig(pipeline_parallel_size=1))
|
||||
):
|
||||
a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
|
||||
|
||||
quant_blocksize = 16
|
||||
is_gated_act = activation == "silu_and_mul"
|
||||
|
||||
w1_q, w2_q, quant_config = make_test_quant_config(
|
||||
e,
|
||||
n,
|
||||
k,
|
||||
in_dtype=dtype,
|
||||
quant_dtype="nvfp4",
|
||||
block_shape=None,
|
||||
per_act_token_quant=False,
|
||||
make_gate=is_gated_act,
|
||||
)
|
||||
|
||||
score = torch.randn((m, e), device="cuda", dtype=dtype)
|
||||
topk_weights, topk_ids, _ = fused_topk(a, score, topk, renormalize=False)
|
||||
|
||||
assert is_valid_flashinfer_cutlass_fused_moe(a, w1_q, w2_q)
|
||||
|
||||
flashinfer_experts = FusedMoEModularKernel(
|
||||
create_flashinfer_prepare_finalize(use_dp=False, use_nvfp4=True),
|
||||
FlashInferExperts(out_dtype=dtype, quant_config=quant_config),
|
||||
)
|
||||
|
||||
fi_activation = {"silu_and_mul": "silu", "relu2": "relu2_no_mul"}[activation]
|
||||
|
||||
flashinfer_output = flashinfer_experts(
|
||||
hidden_states=a,
|
||||
w1=w1_q,
|
||||
w2=w2_q,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
activation=fi_activation,
|
||||
)
|
||||
|
||||
# Reference check:
|
||||
a_global_scale = (
|
||||
(FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) / torch.amax(a.flatten(), dim=-1)
|
||||
).to(torch.float32)
|
||||
a_fp4, a_scale_interleaved = ops.scaled_fp4_quant(a, a_global_scale)
|
||||
_, m_k = a_fp4.shape
|
||||
a_in_dtype = dequantize_nvfp4_to_dtype(
|
||||
a_fp4,
|
||||
a_scale_interleaved,
|
||||
a_global_scale,
|
||||
dtype=a.dtype,
|
||||
device=a.device,
|
||||
block_size=quant_blocksize,
|
||||
)
|
||||
|
||||
w1_d = torch.empty(
|
||||
(e, (2 if is_gated_act else 1) * n, k), device="cuda", dtype=dtype
|
||||
)
|
||||
w2_d = torch.empty((e, k, n), device="cuda", dtype=dtype)
|
||||
|
||||
for idx in range(0, e):
|
||||
w1_d[idx] = dequantize_nvfp4_to_dtype(
|
||||
w1_q[idx],
|
||||
quant_config.w1_scale[idx],
|
||||
(1 / quant_config.g1_alphas[idx]),
|
||||
dtype=dtype,
|
||||
device=w1_q.device,
|
||||
block_size=quant_blocksize,
|
||||
)
|
||||
w2_d[idx] = dequantize_nvfp4_to_dtype(
|
||||
w2_q[idx],
|
||||
quant_config.w2_scale[idx],
|
||||
(1 / quant_config.g2_alphas[idx]),
|
||||
dtype=dtype,
|
||||
device=w2_q.device,
|
||||
block_size=quant_blocksize,
|
||||
)
|
||||
|
||||
torch_output = torch_moe(
|
||||
a_in_dtype, w1_d, w2_d, score, topk, activation=activation
|
||||
)
|
||||
|
||||
torch.testing.assert_close(
|
||||
torch_output, flashinfer_output, atol=1e-1, rtol=1e-1
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_flashinfer_fp4_moe_no_graph((2, 1024, 1024), 40, 1, torch.half)
|
||||
350
tests/kernels/moe/test_gpt_oss_triton_kernels.py
Normal file
350
tests/kernels/moe/test_gpt_oss_triton_kernels.py
Normal file
@@ -0,0 +1,350 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from dataclasses import dataclass, fields
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
from vllm.utils.import_utils import has_triton_kernels
|
||||
|
||||
if not has_triton_kernels():
|
||||
pytest.skip(
|
||||
"triton_kernels not found, skipping all related tests",
|
||||
allow_module_level=True,
|
||||
)
|
||||
|
||||
import triton_kernels.swiglu
|
||||
from triton_kernels.matmul_ogs import FlexCtx, PrecisionConfig
|
||||
from triton_kernels.numerics import InFlexData
|
||||
from triton_kernels.numerics_details.mxfp import downcast_to_mxfp, upcast_from_mxfp
|
||||
from triton_kernels.tensor import FP4, convert_layout, wrap_torch_tensor
|
||||
from triton_kernels.tensor_details import layout
|
||||
from triton_kernels.testing import assert_close
|
||||
|
||||
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
|
||||
from vllm.model_executor.layers.fused_moe.gpt_oss_triton_kernels_moe import (
|
||||
triton_kernel_moe_forward,
|
||||
)
|
||||
from vllm.model_executor.layers.utils import shuffle_weight
|
||||
from vllm.utils.math_utils import round_up
|
||||
|
||||
|
||||
def deshuffle(w: torch.Tensor):
|
||||
first = w[..., ::2]
|
||||
second = w[..., 1::2]
|
||||
|
||||
deshuffled = torch.concat((first, second), dim=-1)
|
||||
return deshuffled
|
||||
|
||||
|
||||
def init_compute_data(M, K, N, E, a_dtype: str, w_dtype: str, num_warps: int):
|
||||
randbits = [torch.randperm(E) for _ in range(M)]
|
||||
x_list = [
|
||||
(-1) ** i
|
||||
* ((16384 + ((i * 512) % 4096) + bits).to(torch.int16).view(torch.bfloat16))
|
||||
for i, bits in enumerate(randbits)
|
||||
]
|
||||
exp_data = torch.stack(x_list).to(device="cuda") # simulating gate_output (M, E)
|
||||
|
||||
# create input tensor
|
||||
x = torch.randn((M, K), dtype=torch.bfloat16, device="cuda")
|
||||
w1 = torch.randn((E, 2 * N, K), dtype=torch.bfloat16, device="cuda")
|
||||
w1_bias = torch.randn((E, 2 * N), dtype=torch.bfloat16, device="cuda")
|
||||
|
||||
w2 = torch.randn((E, K, N), dtype=torch.bfloat16, device="cuda")
|
||||
w2_bias = torch.randn((E, K), dtype=torch.bfloat16, device="cuda")
|
||||
|
||||
exp_data_tri = exp_data.clone()
|
||||
x_tri = x.clone()
|
||||
w1_tri = w1.clone()
|
||||
w2_tri = w2.clone()
|
||||
|
||||
w1_bias_tri = w1_bias.clone()
|
||||
w2_bias_tri = w2_bias.clone()
|
||||
w1_bias_tri = w1_bias_tri.to(torch.float32)
|
||||
w2_bias_tri = w2_bias_tri.to(torch.float32)
|
||||
|
||||
dtype_dict = {
|
||||
"bf16": torch.bfloat16,
|
||||
"fp8_e4m3": torch.float8_e4m3fn,
|
||||
"fp8_e5m2": torch.float8_e5m2,
|
||||
}
|
||||
|
||||
x = x.to(dtype_dict[a_dtype]).to(torch.bfloat16)
|
||||
if w_dtype != "mx4":
|
||||
# simulate quantization support on reference impl
|
||||
w1 = w1.to(dtype_dict[w_dtype]).to(torch.bfloat16)
|
||||
w2 = w2.to(dtype_dict[w_dtype]).to(torch.bfloat16)
|
||||
|
||||
# triton moe kernel use transposed shape for matmul
|
||||
w1_tri = w1_tri.transpose(-2, -1)
|
||||
w2_tri = w2_tri.transpose(-2, -1)
|
||||
|
||||
# shuffle weights
|
||||
w1_tri = shuffle_weight(w1_tri)
|
||||
w1_bias_tri = shuffle_weight(w1_bias_tri)
|
||||
|
||||
# quant triton_weights
|
||||
x_tri = x.to(dtype_dict[a_dtype])
|
||||
if w_dtype != "mx4":
|
||||
pytest.skip("NYI")
|
||||
else: # quantize to mx4
|
||||
# careful on the padding here, the activation padding need to be
|
||||
# multiple of 64, the actual engine is not implemented
|
||||
w1_bottom_pad = round_up(w1_tri.shape[1], 64) - w1_tri.shape[1]
|
||||
w1_right_pad = round_up(w1_tri.shape[2], 128) - w1_tri.shape[2]
|
||||
|
||||
w2_bottom_pad = w1_right_pad // 2
|
||||
w2_right_pad = w1_bottom_pad
|
||||
|
||||
x_pad = w1_bottom_pad
|
||||
|
||||
w1_tri = F.pad(
|
||||
w1_tri,
|
||||
(0, w1_right_pad, 0, w1_bottom_pad, 0, 0),
|
||||
mode="constant",
|
||||
value=0,
|
||||
)
|
||||
w2_tri = F.pad(
|
||||
w2_tri,
|
||||
(0, w2_right_pad, 0, w2_bottom_pad, 0, 0),
|
||||
mode="constant",
|
||||
value=0,
|
||||
)
|
||||
|
||||
w1_bias_tri = F.pad(
|
||||
w1_bias_tri, (0, w1_right_pad, 0, 0), mode="constant", value=0
|
||||
)
|
||||
w2_bias_tri = F.pad(
|
||||
w2_bias_tri, (0, w2_right_pad, 0, 0), mode="constant", value=0
|
||||
)
|
||||
|
||||
x_tri = F.pad(x_tri, (0, x_pad, 0, 0), mode="constant", value=0)
|
||||
|
||||
w_layout, w_layout_opts = layout.make_default_matmul_mxfp4_w_layout(mx_axis=1)
|
||||
w_scale_layout, w_scale_layout_opts = (
|
||||
layout.make_default_matmul_mxfp4_w_scale_layout(
|
||||
mx_axis=1, num_warps=num_warps
|
||||
)
|
||||
)
|
||||
|
||||
w1_tri, w1_scale_tri = downcast_to_mxfp(w1_tri, torch.uint8, axis=1)
|
||||
w1 = upcast_from_mxfp(w1_tri, w1_scale_tri, torch.bfloat16, axis=1)
|
||||
|
||||
w2_tri, w2_scale_tri = downcast_to_mxfp(w2_tri, torch.uint8, axis=1)
|
||||
w2 = upcast_from_mxfp(w2_tri, w2_scale_tri, torch.bfloat16, axis=1)
|
||||
|
||||
w1_tri = convert_layout(
|
||||
wrap_torch_tensor(w1_tri, FP4), w_layout, **w_layout_opts
|
||||
)
|
||||
w1_scale_tri = convert_layout(
|
||||
wrap_torch_tensor(w1_scale_tri),
|
||||
w_scale_layout,
|
||||
**w_scale_layout_opts,
|
||||
)
|
||||
|
||||
w2_tri = convert_layout(
|
||||
wrap_torch_tensor(w2_tri, FP4), w_layout, **w_layout_opts
|
||||
)
|
||||
w2_scale_tri = convert_layout(
|
||||
wrap_torch_tensor(w2_scale_tri),
|
||||
w_scale_layout,
|
||||
**w_scale_layout_opts,
|
||||
)
|
||||
|
||||
pc1 = PrecisionConfig(
|
||||
weight_scale=w1_scale_tri, flex_ctx=FlexCtx(rhs_data=InFlexData())
|
||||
)
|
||||
pc2 = PrecisionConfig(
|
||||
weight_scale=w2_scale_tri, flex_ctx=FlexCtx(rhs_data=InFlexData())
|
||||
)
|
||||
|
||||
# tucuate so the rest can run properly
|
||||
w1 = w1[..., :K, : 2 * N]
|
||||
w2 = w2[..., :N, :K]
|
||||
|
||||
w1 = deshuffle(w1)
|
||||
|
||||
w1 = w1.transpose(-1, -2).contiguous()
|
||||
w2 = w2.transpose(-1, -2).contiguous()
|
||||
|
||||
return (
|
||||
x,
|
||||
w1,
|
||||
w1_bias,
|
||||
w2,
|
||||
w2_bias,
|
||||
exp_data,
|
||||
x_tri,
|
||||
w1_tri,
|
||||
w2_tri,
|
||||
exp_data_tri,
|
||||
w1_bias_tri,
|
||||
w2_bias_tri,
|
||||
pc1,
|
||||
pc2,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelConfig:
|
||||
num_hidden_layers: int = 36
|
||||
num_experts: int = 128
|
||||
experts_per_token: int = 4
|
||||
vocab_size: int = 201088
|
||||
hidden_size: int = 2880
|
||||
intermediate_size: int = 2880
|
||||
head_dim: int = 64
|
||||
num_attention_heads: int = 64
|
||||
num_key_value_heads: int = 8
|
||||
sliding_window: int = 128
|
||||
initial_context_length: int = 4096
|
||||
rope_theta: float = 150000.0
|
||||
rope_parameters_factor: float = 32.0
|
||||
rope_ntk_alpha: float = 1.0
|
||||
rope_ntk_beta: float = 32.0
|
||||
|
||||
|
||||
def swiglu(x, alpha: float = 1.702, limit: float = 1.0):
|
||||
# Note we add an extra bias of 1 to the linear layer
|
||||
x_glu, x_linear = torch.chunk(x, 2, dim=-1)
|
||||
if limit is not None:
|
||||
x_glu = x_glu.clamp(max=limit)
|
||||
out_glu = x_glu * torch.sigmoid(alpha * x_glu)
|
||||
if limit is not None:
|
||||
x_linear = x_linear.clamp(min=-limit, max=limit)
|
||||
return out_glu * (x_linear + 1)
|
||||
|
||||
|
||||
def oai_moe_forward(
|
||||
hidden_states: torch.Tensor, # (M, K)
|
||||
w1: torch.Tensor, # (E, 2N)
|
||||
w1_bias: torch.Tensor, # (E, 2N, K)
|
||||
w2: torch.Tensor, # (E, K, N)
|
||||
w2_bias: torch.Tensor, # (E, N)
|
||||
gating_output: torch.Tensor, # (M, E)
|
||||
topk: int,
|
||||
):
|
||||
# model.py 309:330, assuming gating and norm
|
||||
t = hidden_states
|
||||
experts = torch.topk(gating_output, k=topk, dim=-1, sorted=True)
|
||||
expert_weights = torch.nn.functional.softmax(experts.values, dim=1)
|
||||
expert_indices = experts.indices
|
||||
|
||||
# MLP #1
|
||||
mlp1_weight = w1[expert_indices, ...]
|
||||
mlp1_bias = w1_bias[expert_indices, ...]
|
||||
t = torch.einsum("beck,bk->bec", mlp1_weight, t) + mlp1_bias
|
||||
t = swiglu(t, limit=7)
|
||||
|
||||
# MLP #2
|
||||
mlp2_weight = w2[expert_indices, ...]
|
||||
mlp2_bias = w2_bias[expert_indices, ...]
|
||||
t = torch.einsum("beck,bek->bec", mlp2_weight, t)
|
||||
t += mlp2_bias
|
||||
|
||||
# Weighted sum of experts
|
||||
t = torch.einsum("bec,be->bc", t, expert_weights)
|
||||
|
||||
return t
|
||||
|
||||
|
||||
@dataclass
|
||||
class Case:
|
||||
a_dtype: str
|
||||
w_dtype: str
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
", ".join(f.name for f in fields(Case)),
|
||||
[
|
||||
tuple(getattr(case, f.name) for f in fields(Case))
|
||||
for case in [
|
||||
# Case(a_dtype="bf16", w_dtype="bf16"),
|
||||
# Case(a_dtype="fp8_e4m3", w_dtype="fp8_e5m2"),
|
||||
Case(a_dtype="bf16", w_dtype="mx4")
|
||||
]
|
||||
],
|
||||
)
|
||||
@pytest.mark.parametrize("num_token", [2])
|
||||
@pytest.mark.parametrize("tp", [1, 2, 4, 8])
|
||||
def test_equiv(num_token, a_dtype, w_dtype, tp, workspace_init):
|
||||
from triton_kernels.tensor_details import layout
|
||||
|
||||
if not hasattr(layout, "make_default_matmul_mxfp4_w_layout"):
|
||||
pytest.skip("make_default_matmul_mxfp4_w_layout not available")
|
||||
|
||||
M = num_token
|
||||
E = ModelConfig.num_experts
|
||||
K = ModelConfig.hidden_size
|
||||
N = ModelConfig.intermediate_size // tp
|
||||
topk = ModelConfig.experts_per_token
|
||||
|
||||
(
|
||||
x,
|
||||
w1,
|
||||
w1_bias,
|
||||
w2,
|
||||
w2_bias,
|
||||
exp_data,
|
||||
x_tri,
|
||||
w1_tri,
|
||||
w2_tri,
|
||||
exp_data_tri,
|
||||
w1_bias_tri,
|
||||
w2_bias_tri,
|
||||
pc1,
|
||||
pc2,
|
||||
) = init_compute_data(M, K, N, E, a_dtype, w_dtype, num_warps=8)
|
||||
|
||||
quant_config = FusedMoEQuantConfig.make(
|
||||
w1_bias=w1_bias_tri,
|
||||
w2_bias=w2_bias_tri,
|
||||
w1_scale=pc1,
|
||||
w2_scale=pc2,
|
||||
)
|
||||
|
||||
out_triton_monolithic = triton_kernel_moe_forward(
|
||||
hidden_states=x_tri,
|
||||
w1=w1_tri,
|
||||
w2=w2_tri,
|
||||
gating_output=exp_data_tri,
|
||||
topk=topk,
|
||||
renormalize=True,
|
||||
quant_config=quant_config,
|
||||
)
|
||||
out_triton_monolithic = out_triton_monolithic[..., :K]
|
||||
|
||||
out_ref = oai_moe_forward(
|
||||
hidden_states=x,
|
||||
w1=w1,
|
||||
w1_bias=w1_bias,
|
||||
w2=w2,
|
||||
w2_bias=w2_bias,
|
||||
gating_output=exp_data,
|
||||
topk=topk,
|
||||
)
|
||||
assert_close(ref=out_ref, tri=out_triton_monolithic, maxtol=0.025, rmstol=0.005)
|
||||
|
||||
|
||||
def test_unit_shuffle():
|
||||
N = ModelConfig.intermediate_size
|
||||
K = ModelConfig.hidden_size
|
||||
m = torch.randn((K, 2 * N), dtype=torch.bfloat16, device="cuda")
|
||||
|
||||
x = torch.randn(K, dtype=torch.bfloat16, device="cuda")
|
||||
|
||||
m_shuffled = shuffle_weight(m)
|
||||
|
||||
out_ref = x @ m
|
||||
out_ref = swiglu(out_ref, limit=1.0)
|
||||
|
||||
out = x @ m_shuffled
|
||||
out = triton_kernels.swiglu.swiglu_torch(
|
||||
out,
|
||||
alpha=1.702,
|
||||
precision_config=triton_kernels.swiglu.PrecisionConfig(limit=1.0),
|
||||
)
|
||||
|
||||
assert_close(ref=out_ref, tri=out)
|
||||
81
tests/kernels/moe/test_grouped_topk.py
Normal file
81
tests/kernels/moe/test_grouped_topk.py
Normal file
@@ -0,0 +1,81 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""Tests for the MoE grouped topk kernel
|
||||
|
||||
Run `pytest tests/kernels/moe/test_grouped_topk.py`.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from vllm.model_executor.layers.fused_moe.fused_moe import (
|
||||
fused_grouped_topk,
|
||||
grouped_topk,
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
not current_platform.is_cuda(), reason="This test is skipped on non-CUDA platform."
|
||||
)
|
||||
@pytest.mark.parametrize("n_token", [1, 33, 64])
|
||||
@pytest.mark.parametrize("n_hidden", [1024, 2048])
|
||||
@pytest.mark.parametrize("n_expert", [16])
|
||||
@pytest.mark.parametrize("topk", [2])
|
||||
@pytest.mark.parametrize("renormalize", [True, False])
|
||||
@pytest.mark.parametrize("num_expert_group", [8])
|
||||
@pytest.mark.parametrize("topk_group", [2])
|
||||
@pytest.mark.parametrize("scoring_func", ["softmax", "sigmoid"])
|
||||
@pytest.mark.parametrize("routed_scaling_factor", [1.0, 2.5])
|
||||
@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float32])
|
||||
def test_grouped_topk(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
n_token: int,
|
||||
n_hidden: int,
|
||||
n_expert: int,
|
||||
topk: int,
|
||||
renormalize: bool,
|
||||
num_expert_group: int,
|
||||
topk_group: int,
|
||||
scoring_func: str,
|
||||
routed_scaling_factor: float,
|
||||
dtype: torch.dtype,
|
||||
):
|
||||
current_platform.seed_everything(0)
|
||||
hidden_states = torch.randn((n_token, n_hidden), dtype=dtype, device="cuda")
|
||||
gating_output = torch.randn((n_token, n_expert), dtype=dtype, device="cuda")
|
||||
e_score_correction_bias = torch.randn(
|
||||
(n_expert,), dtype=torch.float32, device="cuda"
|
||||
)
|
||||
|
||||
with monkeypatch.context() as m:
|
||||
m.setenv("VLLM_USE_FUSED_MOE_GROUPED_TOPK", "0")
|
||||
baseline_topk_weights, baseline_topk_ids = grouped_topk(
|
||||
hidden_states=hidden_states,
|
||||
gating_output=gating_output,
|
||||
topk=topk,
|
||||
renormalize=renormalize,
|
||||
num_expert_group=num_expert_group,
|
||||
topk_group=topk_group,
|
||||
scoring_func=scoring_func,
|
||||
routed_scaling_factor=routed_scaling_factor,
|
||||
e_score_correction_bias=e_score_correction_bias,
|
||||
)
|
||||
|
||||
test_topk_weights, test_topk_ids = fused_grouped_topk(
|
||||
hidden_states=hidden_states,
|
||||
gating_output=gating_output,
|
||||
topk=topk,
|
||||
renormalize=renormalize,
|
||||
num_expert_group=num_expert_group,
|
||||
topk_group=topk_group,
|
||||
scoring_func=scoring_func,
|
||||
routed_scaling_factor=routed_scaling_factor,
|
||||
e_score_correction_bias=e_score_correction_bias,
|
||||
)
|
||||
|
||||
if renormalize:
|
||||
torch.testing.assert_close(
|
||||
baseline_topk_weights, test_topk_weights, atol=2e-2, rtol=0
|
||||
)
|
||||
torch.testing.assert_close(baseline_topk_ids, test_topk_ids, atol=0, rtol=0)
|
||||
350
tests/kernels/moe/test_modular_kernel_combinations.py
Normal file
350
tests/kernels/moe/test_modular_kernel_combinations.py
Normal file
@@ -0,0 +1,350 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import copy
|
||||
import textwrap
|
||||
import traceback
|
||||
from itertools import product
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
|
||||
from vllm.config import VllmConfig, set_current_vllm_config
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.flashinfer import has_flashinfer_cutlass_fused_moe
|
||||
from vllm.utils.import_utils import has_deep_ep, has_deep_gemm, has_pplx
|
||||
from vllm.utils.torch_utils import cuda_device_count_stateless
|
||||
from vllm.v1.worker.workspace import init_workspace_manager
|
||||
|
||||
from .modular_kernel_tools.common import (
|
||||
Config,
|
||||
RankTensors,
|
||||
WeightTensors,
|
||||
reference_moe_impl,
|
||||
run_modular_kernel,
|
||||
)
|
||||
from .modular_kernel_tools.mk_objects import (
|
||||
MK_FUSED_EXPERT_TYPES,
|
||||
MK_MULTI_GPU_PREPARE_FINALIZE_TYPES,
|
||||
MK_QUANT_CONFIGS,
|
||||
MK_SINGLE_GPU_PREPARE_FINALIZE_TYPES,
|
||||
TestMoEQuantConfig,
|
||||
expert_info,
|
||||
)
|
||||
from .modular_kernel_tools.parallel_utils import (
|
||||
ProcessGroupInfo,
|
||||
parallel_launch_with_config,
|
||||
)
|
||||
|
||||
has_any_multi_gpu_package = (
|
||||
has_deep_ep() or has_deep_gemm() or has_pplx() or has_flashinfer_cutlass_fused_moe()
|
||||
)
|
||||
|
||||
meets_multi_gpu_requirements = pytest.mark.skipif(
|
||||
not has_any_multi_gpu_package,
|
||||
reason="Requires deep_ep or deep_gemm or pplx or flashinfer packages",
|
||||
)
|
||||
|
||||
if current_platform.is_fp8_fnuz():
|
||||
pytest.skip(
|
||||
"Tests in this file require float8_e4m3fn and platform does not support",
|
||||
allow_module_level=True,
|
||||
)
|
||||
|
||||
|
||||
def format_result(verbose, msg, ex=None):
|
||||
if ex is not None:
|
||||
x = str(ex)
|
||||
newx = x.strip(" \n\t")[:16]
|
||||
if len(newx) < len(x):
|
||||
newx = newx + " ..."
|
||||
|
||||
prefix = "E\t"
|
||||
print(f"{textwrap.indent(traceback.format_exc(), prefix)}")
|
||||
print(f"FAILED {msg} - {newx}\n")
|
||||
elif verbose:
|
||||
print(f"PASSED {msg}")
|
||||
else:
|
||||
print(".", end="")
|
||||
|
||||
|
||||
def rank_worker(
|
||||
pgi: ProcessGroupInfo,
|
||||
vllm_config: VllmConfig,
|
||||
cpu_group,
|
||||
base_config: Config,
|
||||
weights: WeightTensors,
|
||||
verbose: bool,
|
||||
):
|
||||
# Initialize workspace manager in child process
|
||||
device = torch.device(f"cuda:{pgi.local_rank}")
|
||||
init_workspace_manager(device)
|
||||
|
||||
current_platform.seed_everything(pgi.rank)
|
||||
|
||||
# sanity check
|
||||
from vllm import envs
|
||||
|
||||
if base_config.fused_moe_chunk_size is not None:
|
||||
assert base_config.fused_moe_chunk_size == envs.VLLM_FUSED_MOE_CHUNK_SIZE
|
||||
|
||||
# get weights to this device
|
||||
weights.to_current_device()
|
||||
|
||||
Ms = base_config.Ms
|
||||
assert isinstance(Ms, list)
|
||||
TOPKs = base_config.topks
|
||||
assert isinstance(TOPKs, list)
|
||||
|
||||
exceptions = []
|
||||
count = 0
|
||||
|
||||
for m, topk in product(Ms, TOPKs):
|
||||
# override m and topk
|
||||
config = copy.deepcopy(base_config)
|
||||
config.Ms = m
|
||||
config.topks = topk
|
||||
|
||||
try:
|
||||
print(f"Running[{pgi.rank}]: m={m}, topk={topk} ...")
|
||||
count = count + 1
|
||||
|
||||
# inputs for rank
|
||||
rank_tensors = RankTensors.make(config, pgi)
|
||||
|
||||
# modular kernel out
|
||||
mk_out = run_modular_kernel(pgi, vllm_config, config, weights, rank_tensors)
|
||||
|
||||
with set_current_vllm_config(vllm_config):
|
||||
ref_out = reference_moe_impl(config, weights, rank_tensors)
|
||||
|
||||
if config.quant_dtype == "nvfp4":
|
||||
atol = 1e-1 if config.K < 4096 else 2e-1
|
||||
rtol = 1e-1 if config.K < 4096 else 2e-1
|
||||
else:
|
||||
atol = 3e-2
|
||||
rtol = 3e-2
|
||||
|
||||
torch.testing.assert_close(ref_out, mk_out, atol=atol, rtol=rtol)
|
||||
format_result(verbose, config.describe())
|
||||
except Exception as ex:
|
||||
format_result(verbose, config.describe(), ex)
|
||||
exceptions.append(ex)
|
||||
|
||||
if len(exceptions) > 0:
|
||||
raise RuntimeError(
|
||||
f"{len(exceptions)} of {count} tests failed in child process, "
|
||||
f"rank={pgi.rank}."
|
||||
)
|
||||
else:
|
||||
print(f"{count} of {count} tests passed in child process, rank={pgi.rank}.")
|
||||
|
||||
|
||||
def run(config: Config, verbose: bool):
|
||||
assert config.is_valid()[0]
|
||||
assert not is_nyi_config(config)
|
||||
|
||||
weights: WeightTensors = WeightTensors.make(config)
|
||||
|
||||
vllm_config, env_dict = config.make_env_data()
|
||||
parallel_launch_with_config(
|
||||
config.world_size, rank_worker, vllm_config, env_dict, config, weights, verbose
|
||||
)
|
||||
|
||||
|
||||
Ms = [32, 64]
|
||||
# hidden sizes, making this too large will cause fp4 tests to fail.
|
||||
# Also needs to be a multiple of 1024 for deep_gemm.
|
||||
Ks = [2048]
|
||||
Ns = [1024]
|
||||
TOPKs = [4, 1]
|
||||
Es = [32]
|
||||
DTYPEs = [torch.bfloat16]
|
||||
FUSED_MOE_CHUNK_SIZEs = [None, 16]
|
||||
|
||||
|
||||
def is_nyi_config(config: Config) -> bool:
|
||||
# We know these configs to be legitimate. but still fail.
|
||||
info = expert_info(config.fused_experts_type)
|
||||
|
||||
if info.needs_matching_quant:
|
||||
# The triton kernels expect both per-act-token-quant and
|
||||
# per-out-ch-quant or neither.
|
||||
unsupported_quant_config = (
|
||||
config.is_per_act_token_quant + config.is_per_out_ch_quant
|
||||
) == 1
|
||||
return unsupported_quant_config
|
||||
|
||||
return not info.supports_expert_map
|
||||
|
||||
|
||||
def generate_valid_test_cases(
|
||||
world_size: int, prepare_finalize_types
|
||||
) -> list[tuple[Any, ...]]:
|
||||
cases = []
|
||||
total = 0
|
||||
|
||||
for k, n, e, dtype, quant_config, combination, chunk_size in product(
|
||||
Ks,
|
||||
Ns,
|
||||
Es,
|
||||
DTYPEs,
|
||||
MK_QUANT_CONFIGS,
|
||||
product(prepare_finalize_types, MK_FUSED_EXPERT_TYPES),
|
||||
FUSED_MOE_CHUNK_SIZEs,
|
||||
):
|
||||
total = total + 1
|
||||
|
||||
config = Config(
|
||||
Ms=Ms,
|
||||
K=k,
|
||||
N=n,
|
||||
E=e,
|
||||
topks=TOPKs,
|
||||
dtype=dtype,
|
||||
quant_config=quant_config,
|
||||
prepare_finalize_type=combination[0],
|
||||
fused_experts_type=combination[1],
|
||||
fused_moe_chunk_size=chunk_size,
|
||||
world_size=world_size,
|
||||
)
|
||||
|
||||
# TODO(bnell): figure out how to get verbose flag here.
|
||||
verbose = False # pytestconfig.getoption('verbose') > 0
|
||||
|
||||
valid, reason = config.is_valid()
|
||||
|
||||
if not valid:
|
||||
if verbose:
|
||||
print(f"Test config {config} is not valid: {reason}")
|
||||
continue
|
||||
|
||||
if is_nyi_config(config):
|
||||
if verbose:
|
||||
print(f"Test config {config} is nyi.")
|
||||
continue
|
||||
|
||||
cases.append(
|
||||
(
|
||||
k,
|
||||
n,
|
||||
e,
|
||||
dtype,
|
||||
quant_config,
|
||||
combination[0],
|
||||
combination[1],
|
||||
chunk_size,
|
||||
world_size,
|
||||
)
|
||||
)
|
||||
|
||||
print(f"{len(cases)} of {total} valid configs generated.")
|
||||
|
||||
return cases
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"k,n,e,dtype,quant_config,prepare_finalize_type,fused_experts_type,chunk_size,world_size",
|
||||
generate_valid_test_cases(
|
||||
world_size=2, prepare_finalize_types=MK_MULTI_GPU_PREPARE_FINALIZE_TYPES
|
||||
),
|
||||
)
|
||||
@meets_multi_gpu_requirements
|
||||
def test_modular_kernel_combinations_multigpu(
|
||||
k: int,
|
||||
n: int,
|
||||
e: int,
|
||||
dtype: torch.dtype,
|
||||
quant_config: TestMoEQuantConfig | None,
|
||||
prepare_finalize_type: mk.FusedMoEPrepareAndFinalize,
|
||||
fused_experts_type: mk.FusedMoEPermuteExpertsUnpermute,
|
||||
chunk_size: int | None,
|
||||
world_size: int,
|
||||
pytestconfig,
|
||||
):
|
||||
if cuda_device_count_stateless() < world_size:
|
||||
pytest.skip(
|
||||
f"Not enough GPUs available to run, got "
|
||||
f"{cuda_device_count_stateless()} exepected "
|
||||
f"{world_size}."
|
||||
)
|
||||
|
||||
config = Config(
|
||||
Ms=Ms,
|
||||
K=k,
|
||||
N=n,
|
||||
E=e,
|
||||
topks=TOPKs,
|
||||
dtype=dtype,
|
||||
quant_config=quant_config,
|
||||
prepare_finalize_type=prepare_finalize_type,
|
||||
fused_experts_type=fused_experts_type,
|
||||
fused_moe_chunk_size=chunk_size,
|
||||
world_size=world_size,
|
||||
)
|
||||
verbosity = pytestconfig.getoption("verbose")
|
||||
run(config, verbosity > 0)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"k,n,e,dtype,quant_config,prepare_finalize_type,fused_experts_type,chunk_size,world_size",
|
||||
generate_valid_test_cases(
|
||||
world_size=1, prepare_finalize_types=MK_SINGLE_GPU_PREPARE_FINALIZE_TYPES
|
||||
),
|
||||
)
|
||||
def test_modular_kernel_combinations_singlegpu(
|
||||
k: int,
|
||||
n: int,
|
||||
e: int,
|
||||
dtype: torch.dtype,
|
||||
quant_config: TestMoEQuantConfig | None,
|
||||
prepare_finalize_type: mk.FusedMoEPrepareAndFinalize,
|
||||
fused_experts_type: mk.FusedMoEPermuteExpertsUnpermute,
|
||||
chunk_size: int | None,
|
||||
world_size: int,
|
||||
pytestconfig,
|
||||
workspace_init,
|
||||
):
|
||||
"""Note: float8_e4m3fn is not supported on CUDA architecture < 89,
|
||||
and those tests will be skipped on unsupported hardware."""
|
||||
config = Config(
|
||||
Ms=Ms,
|
||||
K=k,
|
||||
N=n,
|
||||
E=e,
|
||||
topks=TOPKs,
|
||||
dtype=dtype,
|
||||
quant_config=quant_config,
|
||||
prepare_finalize_type=prepare_finalize_type,
|
||||
fused_experts_type=fused_experts_type,
|
||||
fused_moe_chunk_size=chunk_size,
|
||||
world_size=world_size,
|
||||
)
|
||||
|
||||
if (
|
||||
quant_config is not None and quant_config.quant_dtype == torch.float8_e4m3fn
|
||||
) and not current_platform.has_device_capability(89):
|
||||
pytest.skip(
|
||||
"Triton limitation: fp8e4nv data type is not supported on CUDA arch < 89"
|
||||
)
|
||||
verbosity = pytestconfig.getoption("verbose")
|
||||
run(config, verbosity > 0)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Ability to test individual PrepareAndFinalize and FusedExperts combination
|
||||
from .modular_kernel_tools.cli_args import make_config, make_config_arg_parser
|
||||
|
||||
parser = make_config_arg_parser(
|
||||
description=(
|
||||
"Run single prepare-finalize & fused-experts combination test"
|
||||
"Example : python3 -m tests.kernels.moe.test_modular_kernel_combinations "
|
||||
"--pf-type PplxPrepareAndFinalize --experts-type BatchedTritonExperts"
|
||||
)
|
||||
)
|
||||
args = parser.parse_args()
|
||||
config = make_config(args)
|
||||
|
||||
run(config, True)
|
||||
250
tests/kernels/moe/test_modular_oai_triton_moe.py
Normal file
250
tests/kernels/moe/test_modular_oai_triton_moe.py
Normal file
@@ -0,0 +1,250 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""
|
||||
Test modular OAI Triton MoE
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from vllm.utils.import_utils import has_triton_kernels
|
||||
|
||||
if not has_triton_kernels():
|
||||
pytest.skip(
|
||||
"triton_kernels not found, skipping all related tests",
|
||||
allow_module_level=True,
|
||||
)
|
||||
|
||||
from triton_kernels.matmul_ogs import FlexCtx, PrecisionConfig
|
||||
from triton_kernels.numerics import InFlexData
|
||||
from triton_kernels.numerics_details.mxfp import downcast_to_mxfp, upcast_from_mxfp
|
||||
from triton_kernels.tensor import FP4, convert_layout, wrap_torch_tensor
|
||||
from triton_kernels.tensor_details import layout
|
||||
from triton_kernels.testing import assert_close
|
||||
|
||||
from vllm.config import VllmConfig, set_current_vllm_config
|
||||
from vllm.model_executor.layers.fused_moe.config import mxfp4_w4a16_moe_quant_config
|
||||
from vllm.model_executor.layers.fused_moe.gpt_oss_triton_kernels_moe import (
|
||||
OAITritonExperts,
|
||||
UnfusedOAITritonExperts,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.modular_kernel import FusedMoEModularKernel
|
||||
from vllm.model_executor.layers.fused_moe.prepare_finalize import (
|
||||
MoEPrepareAndFinalizeNoEP,
|
||||
)
|
||||
from vllm.model_executor.layers.utils import shuffle_weight
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
MNK = [
|
||||
(1, 512, 384),
|
||||
(1, 2880, 2880),
|
||||
(2, 512, 384),
|
||||
(2, 2880, 2880),
|
||||
(16, 2880, 2880),
|
||||
]
|
||||
|
||||
|
||||
def unshuffle_weight(w: torch.Tensor):
|
||||
first = w[..., ::2]
|
||||
second = w[..., 1::2]
|
||||
return torch.concat((first, second), dim=-1)
|
||||
|
||||
|
||||
def make_weights(dtype, k, n, e):
|
||||
w1 = torch.randn((e, k, 2 * n), dtype=dtype, device="cuda")
|
||||
w1_bias = torch.randn((e, 2 * n), dtype=dtype, device="cuda")
|
||||
|
||||
w2 = torch.randn((e, n, k), dtype=dtype, device="cuda")
|
||||
w2_bias = torch.randn((e, k), dtype=dtype, device="cuda")
|
||||
|
||||
w1_tri = w1.clone()
|
||||
w2_tri = w2.clone()
|
||||
|
||||
w1_bias_tri = w1_bias.clone()
|
||||
w2_bias_tri = w2_bias.clone()
|
||||
w1_bias_tri = w1_bias_tri.to(torch.float32)
|
||||
w2_bias_tri = w2_bias_tri.to(torch.float32)
|
||||
|
||||
# shuffle weights
|
||||
w1_tri = shuffle_weight(w1_tri)
|
||||
w1_bias_tri = shuffle_weight(w1_bias_tri)
|
||||
|
||||
# quant triton_weights
|
||||
w1_tri, w1_scale_tri = downcast_to_mxfp(w1_tri, torch.uint8, axis=1)
|
||||
w1 = upcast_from_mxfp(w1_tri, w1_scale_tri, dtype, axis=1)
|
||||
w1 = unshuffle_weight(w1)
|
||||
|
||||
w2_tri, w2_scale_tri = downcast_to_mxfp(w2_tri, torch.uint8, axis=1)
|
||||
w2 = upcast_from_mxfp(w2_tri, w2_scale_tri, dtype, axis=1)
|
||||
|
||||
num_warps = 8
|
||||
w_layout, w_layout_opts = layout.make_default_matmul_mxfp4_w_layout(mx_axis=1)
|
||||
w_scale_layout, w_scale_layout_opts = (
|
||||
layout.make_default_matmul_mxfp4_w_scale_layout(mx_axis=1, num_warps=num_warps)
|
||||
)
|
||||
|
||||
w1_tri = convert_layout(wrap_torch_tensor(w1_tri, FP4), w_layout, **w_layout_opts)
|
||||
w1_scale_tri = convert_layout(
|
||||
wrap_torch_tensor(w1_scale_tri),
|
||||
w_scale_layout,
|
||||
**w_scale_layout_opts,
|
||||
)
|
||||
|
||||
w2_tri = convert_layout(wrap_torch_tensor(w2_tri, FP4), w_layout, **w_layout_opts)
|
||||
w2_scale_tri = convert_layout(
|
||||
wrap_torch_tensor(w2_scale_tri),
|
||||
w_scale_layout,
|
||||
**w_scale_layout_opts,
|
||||
)
|
||||
|
||||
w1_precision_config = PrecisionConfig(
|
||||
weight_scale=w1_scale_tri, flex_ctx=FlexCtx(rhs_data=InFlexData())
|
||||
)
|
||||
w2_precision_config = PrecisionConfig(
|
||||
weight_scale=w2_scale_tri, flex_ctx=FlexCtx(rhs_data=InFlexData())
|
||||
)
|
||||
|
||||
return (
|
||||
w1,
|
||||
w2,
|
||||
w1_bias,
|
||||
w2_bias,
|
||||
w1_tri,
|
||||
w2_tri,
|
||||
w1_bias_tri,
|
||||
w2_bias_tri,
|
||||
w1_precision_config,
|
||||
w2_precision_config,
|
||||
)
|
||||
|
||||
|
||||
def swiglu(x, alpha: float = 1.702, limit: float = 1.0):
|
||||
# Note we add an extra bias of 1 to the linear layer
|
||||
x_glu, x_linear = torch.chunk(x, 2, dim=-1)
|
||||
if limit is not None:
|
||||
x_glu = x_glu.clamp(max=limit)
|
||||
out_glu = x_glu * torch.sigmoid(alpha * x_glu)
|
||||
if limit is not None:
|
||||
x_linear = x_linear.clamp(min=-limit, max=limit)
|
||||
return out_glu * (x_linear + 1)
|
||||
|
||||
|
||||
def torch_moe_impl(
|
||||
hidden_states: torch.Tensor, # (M, K)
|
||||
w1: torch.Tensor, # (E, K, 2N)
|
||||
w2: torch.Tensor, # (E, N, K)
|
||||
w1_bias: torch.Tensor, # (E, 2N)
|
||||
w2_bias: torch.Tensor, # (E, K)
|
||||
topk_weights: torch.Tensor, # (M, topk)
|
||||
topk_ids: torch.Tensor, # (M, topk)
|
||||
):
|
||||
w1 = w1[topk_ids, ...]
|
||||
w1_bias = w1_bias[topk_ids, ...]
|
||||
hidden_states = torch.einsum("bekc,bk->bec", w1, hidden_states) + w1_bias
|
||||
hidden_states = swiglu(hidden_states, limit=7)
|
||||
|
||||
w2 = w2[topk_ids, ...]
|
||||
w2_bias = w2_bias[topk_ids, ...]
|
||||
hidden_states = torch.einsum("bekc,bek->bec", w2, hidden_states) + w2_bias
|
||||
|
||||
# Weighted sum of experts
|
||||
hidden_states = torch.einsum("bec,be->bc", hidden_states, topk_weights)
|
||||
return hidden_states
|
||||
|
||||
|
||||
def oai_triton_moe_impl(
|
||||
x: torch.Tensor,
|
||||
w1: torch.Tensor,
|
||||
w2: torch.Tensor,
|
||||
w1_scale: "PrecisionConfig",
|
||||
w2_scale: "PrecisionConfig",
|
||||
w1_bias: torch.Tensor | None,
|
||||
w2_bias: torch.Tensor | None,
|
||||
num_experts: int,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
unfused: bool = False,
|
||||
) -> torch.Tensor:
|
||||
quant_config = mxfp4_w4a16_moe_quant_config(
|
||||
w1_bias=w1_bias,
|
||||
w2_bias=w2_bias,
|
||||
w1_scale=w1_scale,
|
||||
w2_scale=w2_scale,
|
||||
)
|
||||
|
||||
if unfused:
|
||||
fused_experts = UnfusedOAITritonExperts(quant_config)
|
||||
else:
|
||||
fused_experts = OAITritonExperts(quant_config)
|
||||
|
||||
mk = FusedMoEModularKernel(MoEPrepareAndFinalizeNoEP(), fused_experts)
|
||||
|
||||
return mk.forward(
|
||||
hidden_states=x,
|
||||
w1=w1,
|
||||
w2=w2,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
inplace=True,
|
||||
activation="swigluoai",
|
||||
global_num_experts=num_experts,
|
||||
expert_map=None,
|
||||
apply_router_weight_on_input=False,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
not current_platform.is_cuda(), reason="This test is skipped on non-CUDA platform."
|
||||
)
|
||||
@pytest.mark.parametrize("dtype", [torch.bfloat16])
|
||||
@pytest.mark.parametrize("m,n,k", MNK)
|
||||
@pytest.mark.parametrize("num_experts", [32, 128])
|
||||
@pytest.mark.parametrize("topk", [4])
|
||||
@pytest.mark.parametrize("unfused", [True, False])
|
||||
def test_oai_triton_moe(
|
||||
dtype: torch.dtype,
|
||||
m: int,
|
||||
n: int,
|
||||
k: int,
|
||||
num_experts: int,
|
||||
topk: int,
|
||||
unfused: bool,
|
||||
workspace_init,
|
||||
):
|
||||
current_platform.seed_everything(0)
|
||||
(
|
||||
w1,
|
||||
w2,
|
||||
w1_bias,
|
||||
w2_bias,
|
||||
w1_tri,
|
||||
w2_tri,
|
||||
w1_bias_tri,
|
||||
w2_bias_tri,
|
||||
w1_precision_config,
|
||||
w2_precision_config,
|
||||
) = make_weights(dtype, k, n, num_experts)
|
||||
|
||||
x = torch.randn((m, k), dtype=dtype, device="cuda")
|
||||
router_logits = torch.randn(m, num_experts, device="cuda", dtype=dtype)
|
||||
topk_weights, topk_ids = torch.topk(router_logits, k=topk, dim=-1, sorted=True)
|
||||
topk_weights = torch.nn.functional.softmax(topk_weights, dim=-1)
|
||||
|
||||
with set_current_vllm_config(VllmConfig()):
|
||||
out_ref = torch_moe_impl(x, w1, w2, w1_bias, w2_bias, topk_weights, topk_ids)
|
||||
|
||||
out = oai_triton_moe_impl(
|
||||
x,
|
||||
w1_tri,
|
||||
w2_tri,
|
||||
w1_precision_config,
|
||||
w2_precision_config,
|
||||
w1_bias_tri,
|
||||
w2_bias_tri,
|
||||
num_experts,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
unfused,
|
||||
)
|
||||
|
||||
assert_close(ref=out_ref, tri=out, maxtol=0.025, rmstol=0.005)
|
||||
1288
tests/kernels/moe/test_moe.py
Normal file
1288
tests/kernels/moe/test_moe.py
Normal file
File diff suppressed because it is too large
Load Diff
408
tests/kernels/moe/test_moe_align_block_size.py
Normal file
408
tests/kernels/moe/test_moe_align_block_size.py
Normal file
@@ -0,0 +1,408 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""Tests for the MOE align block size function.
|
||||
|
||||
Run `pytest tests/kernels/moe/test_moe_align_block_size.py`.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from vllm.model_executor.layers.fused_moe.moe_align_block_size import (
|
||||
batched_moe_align_block_size,
|
||||
moe_align_block_size,
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.math_utils import round_up
|
||||
|
||||
NUM_TOKENS = [1, 3, 256, 2256, 4096]
|
||||
NUM_EXPERTS = [32, 160, 256, 257]
|
||||
TOP_KS = [1, 2, 16, 32]
|
||||
BLOCK_SIZES = [32, 128]
|
||||
current_platform.seed_everything(0)
|
||||
|
||||
|
||||
def _group_tokens_by_expert(
|
||||
sorted_ids: torch.Tensor,
|
||||
expert_ids: torch.Tensor,
|
||||
block_size: int,
|
||||
valid_length: int,
|
||||
total_tokens: int,
|
||||
) -> dict:
|
||||
num_blocks = valid_length // block_size
|
||||
expert_tokens: dict[int, list[int]] = {}
|
||||
|
||||
for block_idx in range(num_blocks):
|
||||
expert_id = expert_ids[block_idx].item()
|
||||
block_start = block_idx * block_size
|
||||
block_end = min(block_start + block_size, valid_length)
|
||||
|
||||
block_tokens = sorted_ids[block_start:block_end]
|
||||
valid_tokens = block_tokens[block_tokens < total_tokens]
|
||||
|
||||
if expert_id not in expert_tokens:
|
||||
expert_tokens[expert_id] = []
|
||||
expert_tokens[expert_id].extend(valid_tokens.tolist())
|
||||
return expert_tokens
|
||||
|
||||
|
||||
def _verify_expert_level_sorting(
|
||||
actual_sorted_ids: torch.Tensor,
|
||||
golden_sorted_ids: torch.Tensor,
|
||||
expert_ids: torch.Tensor,
|
||||
block_size: int,
|
||||
valid_length: int,
|
||||
total_tokens: int,
|
||||
):
|
||||
"""
|
||||
Verify that actual_sorted_ids follows the correct expert-level sorting.
|
||||
The kerne limplementation may or may not preserve original token order
|
||||
in topk_ids in the final sorted_ids however this does not impact quality.
|
||||
"""
|
||||
# Group tokens by expert from the golden implementation
|
||||
golden_expert_tokens = _group_tokens_by_expert(
|
||||
golden_sorted_ids, expert_ids, block_size, valid_length, total_tokens
|
||||
)
|
||||
|
||||
actual_expert_tokens = _group_tokens_by_expert(
|
||||
actual_sorted_ids, expert_ids, block_size, valid_length, total_tokens
|
||||
)
|
||||
|
||||
assert set(golden_expert_tokens.keys()) == set(actual_expert_tokens.keys()), (
|
||||
f"Expert IDs mismatch: golden={set(golden_expert_tokens.keys())}, "
|
||||
f"actual={set(actual_expert_tokens.keys())}"
|
||||
)
|
||||
|
||||
for expert_id in golden_expert_tokens:
|
||||
golden_tokens = torch.tensor(
|
||||
golden_expert_tokens[expert_id], device=actual_sorted_ids.device
|
||||
)
|
||||
actual_tokens = torch.tensor(
|
||||
actual_expert_tokens[expert_id], device=actual_sorted_ids.device
|
||||
)
|
||||
assert torch.equal(
|
||||
torch.sort(golden_tokens)[0], torch.sort(actual_tokens)[0]
|
||||
), (
|
||||
f"Expert {expert_id} token mismatch: "
|
||||
f"golden={golden_expert_tokens[expert_id]}, "
|
||||
f"actual={actual_expert_tokens[expert_id]}"
|
||||
)
|
||||
|
||||
|
||||
def torch_moe_align_block_size(
|
||||
topk_ids: torch.Tensor,
|
||||
block_size: int,
|
||||
num_experts: int,
|
||||
expert_map: torch.Tensor | None = None,
|
||||
pad_sorted_ids: bool = False,
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Golden torch implementation of moe_align_block_size.
|
||||
|
||||
This function aligns the token distribution across experts to be compatible
|
||||
with block size for matrix multiplication by sorting tokens by expert and
|
||||
padding to block boundaries.
|
||||
"""
|
||||
max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1)
|
||||
if pad_sorted_ids:
|
||||
max_num_tokens_padded = round_up(max_num_tokens_padded, block_size)
|
||||
if topk_ids.numel() < num_experts:
|
||||
max_num_tokens_padded = topk_ids.numel() * block_size
|
||||
|
||||
flattened_token_indices = torch.arange(
|
||||
topk_ids.numel(), device=topk_ids.device, dtype=torch.int32
|
||||
)
|
||||
flattened_expert_ids = topk_ids.flatten()
|
||||
sorted_expert_ids, sort_indices = torch.sort(flattened_expert_ids, stable=True)
|
||||
sorted_token_indices = flattened_token_indices[sort_indices]
|
||||
|
||||
expert_token_counts = torch.zeros(
|
||||
num_experts, dtype=torch.int64, device=topk_ids.device
|
||||
)
|
||||
for expert_id in range(num_experts):
|
||||
mask = sorted_expert_ids == expert_id
|
||||
expert_token_counts[expert_id] = mask.sum()
|
||||
|
||||
expert_padded_counts = torch.zeros(
|
||||
num_experts, dtype=torch.int64, device=topk_ids.device
|
||||
)
|
||||
for expert_id in range(num_experts):
|
||||
original_count = expert_token_counts[expert_id]
|
||||
if expert_map is not None and expert_map[expert_id] == -1:
|
||||
continue
|
||||
if original_count > 0:
|
||||
expert_padded_counts[expert_id] = (
|
||||
(original_count + block_size - 1) // block_size
|
||||
) * block_size
|
||||
|
||||
sorted_token_ids = torch.full(
|
||||
(max_num_tokens_padded,),
|
||||
topk_ids.numel(),
|
||||
dtype=torch.int32,
|
||||
device=topk_ids.device,
|
||||
)
|
||||
max_num_blocks = (max_num_tokens_padded + block_size - 1) // block_size
|
||||
expert_ids = torch.zeros(max_num_blocks, dtype=torch.int32, device=topk_ids.device)
|
||||
|
||||
current_pos = 0
|
||||
current_block = 0
|
||||
for expert_id in range(num_experts):
|
||||
if expert_map is not None and expert_map[expert_id] == -1:
|
||||
continue
|
||||
|
||||
expert_mask = sorted_expert_ids == expert_id
|
||||
expert_tokens = sorted_token_indices[expert_mask]
|
||||
num_expert_tokens = expert_tokens.shape[0]
|
||||
|
||||
if num_expert_tokens > 0:
|
||||
sorted_token_ids[current_pos : current_pos + num_expert_tokens] = (
|
||||
expert_tokens
|
||||
)
|
||||
|
||||
expert_blocks_needed = expert_padded_counts[expert_id] // block_size
|
||||
|
||||
expert_id_new = expert_id
|
||||
if expert_map is not None:
|
||||
expert_id_new = expert_map[expert_id]
|
||||
expert_ids[current_block : current_block + expert_blocks_needed] = (
|
||||
expert_id_new
|
||||
)
|
||||
|
||||
current_pos += expert_padded_counts[expert_id]
|
||||
current_block += expert_blocks_needed
|
||||
|
||||
total_padded_tokens = expert_padded_counts.sum()
|
||||
num_tokens_post_pad = torch.tensor(
|
||||
[total_padded_tokens], dtype=torch.int32, device=topk_ids.device
|
||||
)
|
||||
|
||||
return sorted_token_ids, expert_ids, num_tokens_post_pad
|
||||
|
||||
|
||||
@pytest.mark.parametrize("m", NUM_TOKENS)
|
||||
@pytest.mark.parametrize("topk", TOP_KS)
|
||||
@pytest.mark.parametrize("num_experts", NUM_EXPERTS)
|
||||
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
|
||||
@pytest.mark.parametrize("pad_sorted_ids", [False, True])
|
||||
@pytest.mark.skipif(current_platform.is_rocm(), reason="Skip for rocm")
|
||||
def test_moe_align_block_size(
|
||||
m: int, topk: int, num_experts: int, block_size: int, pad_sorted_ids: bool
|
||||
):
|
||||
"""Test moe_align_block_size without expert mapping"""
|
||||
topk_ids = torch.zeros((m, topk), device="cuda", dtype=torch.int32)
|
||||
for i in range(m):
|
||||
experts = torch.randperm(num_experts, device="cuda")[:topk]
|
||||
topk_ids[i] = experts
|
||||
|
||||
actual_sorted_ids, actual_expert_ids, actual_num_tokens = moe_align_block_size(
|
||||
topk_ids=topk_ids,
|
||||
block_size=block_size,
|
||||
num_experts=num_experts,
|
||||
pad_sorted_ids=pad_sorted_ids,
|
||||
)
|
||||
golden_sorted_ids, golden_expert_ids, golden_num_tokens = (
|
||||
torch_moe_align_block_size(
|
||||
topk_ids=topk_ids,
|
||||
block_size=block_size,
|
||||
num_experts=num_experts,
|
||||
pad_sorted_ids=pad_sorted_ids,
|
||||
)
|
||||
)
|
||||
|
||||
torch.testing.assert_close(actual_num_tokens, golden_num_tokens, atol=0, rtol=0)
|
||||
torch.testing.assert_close(actual_expert_ids, golden_expert_ids, atol=0, rtol=0)
|
||||
|
||||
# For sorted_token_ids, verify block-level correctness rather than exact
|
||||
# order Tokens within each expert's blocks can be in any order, but expert
|
||||
# regions must be correct
|
||||
_verify_expert_level_sorting(
|
||||
actual_sorted_ids,
|
||||
golden_sorted_ids,
|
||||
actual_expert_ids,
|
||||
block_size,
|
||||
actual_num_tokens.item(),
|
||||
m * topk,
|
||||
)
|
||||
|
||||
total_tokens = m * topk
|
||||
assert actual_num_tokens.item() % block_size == 0, (
|
||||
"num_tokens_post_pad should be divisible by block_size"
|
||||
)
|
||||
assert actual_num_tokens.item() >= total_tokens, (
|
||||
"num_tokens_post_pad should be at least total_tokens"
|
||||
)
|
||||
valid_tokens = actual_sorted_ids[actual_sorted_ids < total_tokens]
|
||||
assert len(valid_tokens) == total_tokens, (
|
||||
f"Should have exactly {total_tokens} valid tokens, got {len(valid_tokens)}"
|
||||
)
|
||||
assert (actual_expert_ids >= 0).all() and (actual_expert_ids < num_experts).all(), (
|
||||
"expert_ids should contain valid expert indices"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("m", [16, 32, 2048])
|
||||
@pytest.mark.parametrize("topk", [2, 4])
|
||||
@pytest.mark.parametrize("num_experts", [8, 64])
|
||||
@pytest.mark.parametrize("block_size", [64])
|
||||
@pytest.mark.skipif(current_platform.is_rocm(), reason="Skip for rocm")
|
||||
def test_moe_align_block_size_with_expert_map(
|
||||
m: int, topk: int, num_experts: int, block_size: int
|
||||
):
|
||||
"""Test moe_align_block_size with expert mapping (EP scenario)"""
|
||||
topk_ids = torch.zeros((m, topk), device="cuda", dtype=torch.int32)
|
||||
for i in range(m):
|
||||
experts = torch.randperm(num_experts, device="cuda")[:topk]
|
||||
topk_ids[i] = experts
|
||||
|
||||
expert_map = torch.full((num_experts,), -1, device="cuda", dtype=torch.int32)
|
||||
local_experts = list(range(0, num_experts, 2))
|
||||
for i, expert_id in enumerate(local_experts):
|
||||
expert_map[expert_id] = i
|
||||
|
||||
actual_sorted_ids, actual_expert_ids, actual_num_tokens = moe_align_block_size(
|
||||
topk_ids=topk_ids,
|
||||
block_size=block_size,
|
||||
num_experts=num_experts,
|
||||
expert_map=expert_map,
|
||||
ignore_invalid_experts=True,
|
||||
)
|
||||
golden_sorted_ids, golden_expert_ids, golden_num_tokens = (
|
||||
torch_moe_align_block_size(
|
||||
topk_ids=topk_ids,
|
||||
block_size=block_size,
|
||||
num_experts=num_experts,
|
||||
expert_map=expert_map,
|
||||
)
|
||||
)
|
||||
|
||||
torch.testing.assert_close(actual_num_tokens, golden_num_tokens, atol=0, rtol=0)
|
||||
torch.testing.assert_close(actual_expert_ids, golden_expert_ids, atol=0, rtol=0)
|
||||
_verify_expert_level_sorting(
|
||||
actual_sorted_ids,
|
||||
golden_sorted_ids,
|
||||
actual_expert_ids,
|
||||
block_size,
|
||||
actual_num_tokens.item(),
|
||||
m * topk,
|
||||
)
|
||||
|
||||
|
||||
def test_moe_align_block_size_deterministic():
|
||||
m, topk, num_experts, block_size = 128, 2, 32, 64
|
||||
|
||||
torch.manual_seed(42)
|
||||
topk_ids = torch.randint(
|
||||
0, num_experts, (m, topk), device="cuda", dtype=torch.int32
|
||||
)
|
||||
|
||||
# expect the results to be reproducible
|
||||
results = []
|
||||
for _ in range(5):
|
||||
sorted_ids, expert_ids, num_tokens = moe_align_block_size(
|
||||
topk_ids=topk_ids, block_size=block_size, num_experts=num_experts
|
||||
)
|
||||
results.append((sorted_ids.clone(), expert_ids.clone(), num_tokens.clone()))
|
||||
|
||||
for i in range(1, len(results)):
|
||||
assert torch.equal(results[0][0], results[i][0]), (
|
||||
"sorted_ids should be deterministic"
|
||||
)
|
||||
assert torch.equal(results[0][1], results[i][1]), (
|
||||
"expert_ids should be deterministic"
|
||||
)
|
||||
assert torch.equal(results[0][2], results[i][2]), (
|
||||
"num_tokens should be deterministic"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("max_tokens_per_batch", [13, 16, 512])
|
||||
@pytest.mark.parametrize("num_experts", [8, 16, 32, 64])
|
||||
@pytest.mark.parametrize("block_size", [8, 16, 32, 64])
|
||||
@pytest.mark.parametrize("simulate_empty_batches", [False, True])
|
||||
def test_batched_moe_align_block_size(
|
||||
max_tokens_per_batch: int,
|
||||
num_experts: int,
|
||||
block_size: int,
|
||||
simulate_empty_batches: bool,
|
||||
):
|
||||
def ref_outputs(
|
||||
expert_num_tokens: torch.Tensor,
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
E = expert_num_tokens.size(0)
|
||||
|
||||
# Round up so each batch can be split to blocks evenly.
|
||||
Msum = round_up(max_tokens_per_batch, block_size) * E
|
||||
ref_sorted_ids = torch.empty((Msum,), dtype=torch.int32)
|
||||
ref_expert_ids = torch.empty((Msum // block_size,), dtype=torch.int32)
|
||||
ref_num_tokens_post_pad = torch.empty((1,), dtype=torch.int32)
|
||||
|
||||
# Intialize
|
||||
sentinel = E * max_tokens_per_batch
|
||||
ref_sorted_ids.fill_(sentinel)
|
||||
ref_expert_ids.fill_(-1)
|
||||
|
||||
# Fill ref_sorted_ids
|
||||
i = 0
|
||||
for expert_id, expert_nt in enumerate(expert_num_tokens):
|
||||
token_offset = expert_id * max_tokens_per_batch
|
||||
for j in range(expert_nt):
|
||||
ref_sorted_ids[i] = token_offset + j
|
||||
i += 1
|
||||
# round up i to the next block_size
|
||||
i = round_up(i, block_size)
|
||||
|
||||
ref_num_tokens_post_pad[0] = i
|
||||
|
||||
# Fill expert_ids
|
||||
nt_ceil_sum = 0
|
||||
for expert_id, expert_nt in enumerate(expert_num_tokens):
|
||||
expert_ids_offset = nt_ceil_sum // block_size
|
||||
ceil_expert_nt = round_up(int(expert_nt.item()), block_size)
|
||||
num_blocks = ceil_expert_nt // block_size
|
||||
for x in range(num_blocks):
|
||||
ref_expert_ids[expert_ids_offset + x] = expert_id
|
||||
nt_ceil_sum += ceil_expert_nt
|
||||
|
||||
return (
|
||||
ref_sorted_ids.to("cuda"),
|
||||
ref_expert_ids.to("cuda"),
|
||||
ref_num_tokens_post_pad.to("cuda"),
|
||||
)
|
||||
|
||||
# Compute expert_num_tokens
|
||||
expert_num_tokens = torch.randint(
|
||||
low=0,
|
||||
high=max_tokens_per_batch,
|
||||
size=(num_experts,),
|
||||
device="cpu",
|
||||
dtype=torch.int32,
|
||||
)
|
||||
if simulate_empty_batches:
|
||||
# mark half the batches to have 0 tokens
|
||||
zero_batches = torch.randperm(num_experts)[: num_experts // 2]
|
||||
expert_num_tokens[zero_batches] = 0
|
||||
|
||||
# ref outputs
|
||||
ref_sorted_ids, ref_expert_ids, ref_num_tokens_post_pad = ref_outputs(
|
||||
expert_num_tokens
|
||||
)
|
||||
|
||||
# outputs
|
||||
sorted_ids, expert_ids, num_tokens_post_pad = batched_moe_align_block_size(
|
||||
max_tokens_per_batch, block_size, expert_num_tokens.to("cuda")
|
||||
)
|
||||
|
||||
assert ref_sorted_ids.size() == sorted_ids.size(), (
|
||||
f"{ref_sorted_ids.size()} vs {sorted_ids.size()}"
|
||||
)
|
||||
assert ref_expert_ids.size() == expert_ids.size(), (
|
||||
f"{ref_expert_ids.size()} vs {expert_ids.size()}"
|
||||
)
|
||||
assert ref_num_tokens_post_pad.size() == num_tokens_post_pad.size(), (
|
||||
f"{ref_num_tokens_post_pad.size()} vs {num_tokens_post_pad.size()}"
|
||||
)
|
||||
torch.testing.assert_close(ref_sorted_ids, sorted_ids, atol=0, rtol=0)
|
||||
torch.testing.assert_close(ref_expert_ids, expert_ids, atol=0, rtol=0)
|
||||
torch.testing.assert_close(
|
||||
ref_num_tokens_post_pad, num_tokens_post_pad, atol=0, rtol=0
|
||||
)
|
||||
311
tests/kernels/moe/test_moe_permute_unpermute.py
Normal file
311
tests/kernels/moe/test_moe_permute_unpermute.py
Normal file
@@ -0,0 +1,311 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""Tests for the MOE permute/unpermute kernel
|
||||
|
||||
Run `pytest tests/kernels/test_moe_permute_unpermute.py`.
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk
|
||||
from vllm.model_executor.layers.fused_moe.layer import determine_expert_map
|
||||
from vllm.model_executor.layers.fused_moe.moe_permute_unpermute import (
|
||||
moe_permute,
|
||||
moe_permute_unpermute_supported,
|
||||
moe_unpermute,
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
NUM_EXPERTS = [16, 64, 256]
|
||||
TOP_KS = [2, 6, 8]
|
||||
EP_SIZE = [1, 4, 16]
|
||||
current_platform.seed_everything(0)
|
||||
|
||||
if current_platform.is_rocm():
|
||||
pytest.skip(
|
||||
"moe_permute_unpermute_supported is not defined for ROCm",
|
||||
allow_module_level=True,
|
||||
)
|
||||
|
||||
|
||||
def torch_permute(
|
||||
hidden_states: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
# token_expert_indices: torch.Tensor,
|
||||
topk: int,
|
||||
n_expert: int,
|
||||
n_local_expert: int,
|
||||
start_expert: int,
|
||||
expert_map: torch.Tensor | None = None,
|
||||
align_block_size: int | None = None,
|
||||
fill_invalid_expert: int = -1,
|
||||
) -> list[torch.Tensor]:
|
||||
n_token, n_hidden = hidden_states.shape[0], hidden_states.shape[1]
|
||||
if expert_map is not None:
|
||||
is_local_expert = expert_map[topk_ids] != -1
|
||||
not_local_expert = expert_map[topk_ids] == -1
|
||||
topk_ids = is_local_expert * (topk_ids - start_expert) + not_local_expert * (
|
||||
topk_ids + n_expert
|
||||
)
|
||||
token_expert_indices = torch.arange(
|
||||
0, n_token * topk, dtype=torch.int32, device=hidden_states.device
|
||||
).reshape((n_token, topk))
|
||||
|
||||
sorted_topk_ids, sorted_indices = torch.sort(topk_ids.flatten(), stable=True)
|
||||
dst_row_id2src_row_id_map = token_expert_indices.flatten()[sorted_indices]
|
||||
|
||||
expert_first_token_offset = torch.zeros(
|
||||
n_local_expert + 1, dtype=torch.int64, device="cuda"
|
||||
)
|
||||
idx = 0
|
||||
for i in range(0, n_local_expert):
|
||||
cnt = 0
|
||||
while idx < sorted_topk_ids.numel() and sorted_topk_ids[idx] == i:
|
||||
cnt += 1
|
||||
idx += 1
|
||||
expert_first_token_offset[i + 1] = expert_first_token_offset[i] + cnt
|
||||
|
||||
_, src2dst_idx = torch.sort(dst_row_id2src_row_id_map)
|
||||
valid_row_idx = []
|
||||
if align_block_size is None:
|
||||
permuted_hidden_states = hidden_states[dst_row_id2src_row_id_map // topk, ...]
|
||||
permuted_row_size = permuted_hidden_states.shape[0]
|
||||
m_indices = torch.empty(
|
||||
permuted_row_size, device="cuda", dtype=torch.int32
|
||||
).fill_(fill_invalid_expert)
|
||||
for i in range(1, n_local_expert + 1):
|
||||
first_token_offset = expert_first_token_offset[i - 1]
|
||||
last_token_offset = expert_first_token_offset[i]
|
||||
m_indices[first_token_offset:last_token_offset] = i - 1
|
||||
src_row_id2dst_row_id_map = torch.arange(
|
||||
0, n_token * topk, device="cuda", dtype=torch.int32
|
||||
)[src2dst_idx].reshape((n_token, topk))
|
||||
valid_row_idx += [i for i in range(expert_first_token_offset[-1])]
|
||||
dst_row_id2src_row_id_map[expert_first_token_offset[-1] :] = n_token * topk
|
||||
return [
|
||||
permuted_hidden_states,
|
||||
expert_first_token_offset,
|
||||
src_row_id2dst_row_id_map,
|
||||
dst_row_id2src_row_id_map,
|
||||
m_indices,
|
||||
valid_row_idx,
|
||||
]
|
||||
else:
|
||||
permuted_row_size = (
|
||||
(topk * n_token + n_expert * (align_block_size - 1) + align_block_size - 1)
|
||||
// align_block_size
|
||||
* align_block_size
|
||||
)
|
||||
permuted_idx = torch.full(
|
||||
(permuted_row_size,),
|
||||
n_token * topk,
|
||||
dtype=torch.int32,
|
||||
device=hidden_states.device,
|
||||
)
|
||||
permuted_hidden_states = torch.empty(
|
||||
(permuted_row_size, n_hidden), device="cuda", dtype=hidden_states.dtype
|
||||
)
|
||||
align_src_row_id2dst_row_id = torch.empty(
|
||||
n_token * topk, device="cuda", dtype=torch.int32
|
||||
)
|
||||
align_expert_first_token_offset = torch.zeros_like(expert_first_token_offset)
|
||||
m_indices = torch.empty(
|
||||
permuted_row_size, device="cuda", dtype=torch.int32
|
||||
).fill_(fill_invalid_expert)
|
||||
# get align_permuted_hidden_states,
|
||||
# valid row_idx and align_expert_first_token_offset
|
||||
for i in range(1, n_local_expert + 1):
|
||||
first_token_offset = expert_first_token_offset[i - 1]
|
||||
last_token_offset = expert_first_token_offset[i]
|
||||
n_token_in_expert = last_token_offset - first_token_offset
|
||||
align_expert_first_token_offset[i] = (
|
||||
align_expert_first_token_offset[i - 1]
|
||||
+ (n_token_in_expert + align_block_size - 1)
|
||||
// align_block_size
|
||||
* align_block_size
|
||||
)
|
||||
align_first_token_offset = align_expert_first_token_offset[i - 1]
|
||||
align_last_token_offset = align_expert_first_token_offset[i]
|
||||
dst_row_id2src_row_id_in_expert = dst_row_id2src_row_id_map[
|
||||
first_token_offset : first_token_offset + n_token_in_expert
|
||||
]
|
||||
# store token in current expert with align_first_token_offset
|
||||
permuted_hidden_states[
|
||||
align_first_token_offset : align_first_token_offset + n_token_in_expert,
|
||||
...,
|
||||
] = hidden_states[dst_row_id2src_row_id_in_expert // topk, ...]
|
||||
permuted_idx[
|
||||
align_first_token_offset : align_first_token_offset + n_token_in_expert
|
||||
] = dst_row_id2src_row_id_in_expert
|
||||
# set current expert m_indices
|
||||
m_indices[align_first_token_offset:align_last_token_offset] = i - 1
|
||||
valid_row_idx += [
|
||||
i
|
||||
for i in range(
|
||||
align_first_token_offset,
|
||||
align_first_token_offset + n_token_in_expert,
|
||||
)
|
||||
]
|
||||
# get align_src_row_id2dst_row_id
|
||||
for i in range(n_token * topk):
|
||||
eid = sorted_topk_ids[i]
|
||||
if eid >= n_local_expert:
|
||||
# check token not in local expert
|
||||
align_src_row_id2dst_row_id[i] = align_expert_first_token_offset[-1]
|
||||
continue
|
||||
first_token_offset = expert_first_token_offset[eid]
|
||||
align_first_token_offset = align_expert_first_token_offset[eid]
|
||||
token_offset = i - first_token_offset
|
||||
align_src_row_id2dst_row_id[i] = align_first_token_offset + token_offset
|
||||
align_src_row_id2dst_row_id = align_src_row_id2dst_row_id[src2dst_idx].reshape(
|
||||
(n_token, topk)
|
||||
)
|
||||
return [
|
||||
permuted_hidden_states,
|
||||
align_expert_first_token_offset,
|
||||
align_src_row_id2dst_row_id,
|
||||
permuted_idx,
|
||||
m_indices,
|
||||
valid_row_idx,
|
||||
]
|
||||
|
||||
|
||||
def torch_unpermute(
|
||||
permuted_hidden_states: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
token_expert_indices: torch.Tensor,
|
||||
src_row_id2dst_row_id_map: torch.Tensor,
|
||||
valid_row_idx: torch.Tensor,
|
||||
topk: int,
|
||||
n_expert: int,
|
||||
) -> torch.Tensor:
|
||||
# ignore invalid row
|
||||
n_hidden = permuted_hidden_states.shape[1]
|
||||
mask = torch.zeros(permuted_hidden_states.shape[0], dtype=bool, device="cuda")
|
||||
mask[valid_row_idx] = True
|
||||
permuted_hidden_states[~mask] = 0
|
||||
|
||||
permuted_hidden_states = permuted_hidden_states[
|
||||
src_row_id2dst_row_id_map.flatten(), ...
|
||||
]
|
||||
permuted_hidden_states = permuted_hidden_states.view(-1, topk, n_hidden)
|
||||
output = (
|
||||
(permuted_hidden_states * topk_weights.unsqueeze(2))
|
||||
.sum(1)
|
||||
.to(permuted_hidden_states.dtype)
|
||||
)
|
||||
return output
|
||||
|
||||
|
||||
@pytest.mark.parametrize("n_token", [1, 33, 1024, 5000])
|
||||
@pytest.mark.parametrize("n_hidden", [2048, 7168])
|
||||
@pytest.mark.parametrize("n_expert", NUM_EXPERTS)
|
||||
@pytest.mark.parametrize("topk", TOP_KS)
|
||||
@pytest.mark.parametrize("dtype", [torch.bfloat16])
|
||||
@pytest.mark.parametrize("ep_size", EP_SIZE)
|
||||
@pytest.mark.parametrize("align_block_size", [None, 128])
|
||||
def test_moe_permute_unpermute(
|
||||
n_token: int,
|
||||
n_hidden: int,
|
||||
topk: int,
|
||||
n_expert: int,
|
||||
ep_size: int,
|
||||
dtype: torch.dtype,
|
||||
align_block_size: int | None,
|
||||
):
|
||||
if not moe_permute_unpermute_supported():
|
||||
pytest.skip("moe_permute_unpermute is not supported on this platform.")
|
||||
fill_invalid_expert = 0
|
||||
ep_rank = np.random.randint(0, ep_size)
|
||||
expert_map = None
|
||||
n_local_expert = n_expert
|
||||
if ep_size != 1:
|
||||
n_local_expert, expert_map, _ = determine_expert_map(ep_size, ep_rank, n_expert)
|
||||
expert_map = expert_map.cuda()
|
||||
start_expert = n_local_expert * ep_rank
|
||||
current_platform.seed_everything(0)
|
||||
hidden_states = torch.randn((n_token, n_hidden), device="cuda").to(dtype)
|
||||
gating_output = torch.randn((n_token, n_expert), device="cuda").to(dtype)
|
||||
topk_weights, topk_ids, token_expert_indices = fused_topk(
|
||||
hidden_states, gating_output, topk, False
|
||||
)
|
||||
(
|
||||
gold_permuted_hidden_states,
|
||||
gold_expert_first_token_offset,
|
||||
gold_inv_permuted_idx,
|
||||
gold_permuted_idx,
|
||||
gold_m_indices,
|
||||
valid_row_idx,
|
||||
) = torch_permute(
|
||||
hidden_states,
|
||||
topk_ids,
|
||||
# token_expert_indices,
|
||||
topk,
|
||||
n_expert,
|
||||
n_local_expert,
|
||||
start_expert,
|
||||
expert_map=expert_map,
|
||||
align_block_size=align_block_size,
|
||||
fill_invalid_expert=fill_invalid_expert,
|
||||
)
|
||||
|
||||
(
|
||||
permuted_hidden_states,
|
||||
_,
|
||||
expert_first_token_offset,
|
||||
inv_permuted_idx,
|
||||
m_indices,
|
||||
) = moe_permute(
|
||||
hidden_states=hidden_states,
|
||||
a1q_scale=None,
|
||||
topk_ids=topk_ids,
|
||||
n_expert=n_expert,
|
||||
n_local_expert=n_local_expert,
|
||||
expert_map=expert_map,
|
||||
align_block_size=align_block_size,
|
||||
fill_invalid_expert=fill_invalid_expert,
|
||||
)
|
||||
|
||||
# check expert_first_token_offset
|
||||
torch.testing.assert_close(
|
||||
gold_expert_first_token_offset, expert_first_token_offset, atol=0, rtol=0
|
||||
)
|
||||
# check src_row_id2dst_row_id_map
|
||||
torch.testing.assert_close(
|
||||
gold_inv_permuted_idx.flatten(), inv_permuted_idx, atol=0, rtol=0
|
||||
)
|
||||
# check mindice
|
||||
# current kernel usage assumes deepgemm requires align_block_size
|
||||
# when it's not provided then we don't compute m_indices (for cutlass)
|
||||
if align_block_size is not None:
|
||||
torch.testing.assert_close(gold_m_indices, m_indices, atol=0, rtol=0)
|
||||
|
||||
# check permuted_hidden_states, only valid token
|
||||
torch.testing.assert_close(
|
||||
gold_permuted_hidden_states[valid_row_idx],
|
||||
permuted_hidden_states[valid_row_idx],
|
||||
atol=0,
|
||||
rtol=0,
|
||||
)
|
||||
# add a random tensor to simulate group gemm
|
||||
result0 = 0.5 * permuted_hidden_states + torch.randn_like(permuted_hidden_states)
|
||||
result4 = torch.empty_like(hidden_states)
|
||||
moe_unpermute(
|
||||
result4, result0, topk_weights, inv_permuted_idx, expert_first_token_offset
|
||||
)
|
||||
|
||||
gold4 = torch_unpermute(
|
||||
result0,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
token_expert_indices,
|
||||
inv_permuted_idx,
|
||||
valid_row_idx,
|
||||
topk,
|
||||
n_local_expert,
|
||||
)
|
||||
# check unpermuted hidden
|
||||
torch.testing.assert_close(result4, gold4, atol=2e-2, rtol=0)
|
||||
140
tests/kernels/moe/test_nvfp4_moe.py
Normal file
140
tests/kernels/moe/test_nvfp4_moe.py
Normal file
@@ -0,0 +1,140 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from tests.kernels.moe.utils import make_test_weights
|
||||
from tests.kernels.quantization.nvfp4_utils import (
|
||||
FLOAT4_E2M1_MAX,
|
||||
FLOAT8_E4M3_MAX,
|
||||
dequantize_nvfp4_to_dtype,
|
||||
)
|
||||
from tests.kernels.utils import torch_moe
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config
|
||||
from vllm.model_executor.layers.fused_moe.config import nvfp4_moe_quant_config
|
||||
from vllm.model_executor.layers.fused_moe.cutlass_moe import cutlass_moe_fp4
|
||||
from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
if not current_platform.has_device_capability(100):
|
||||
pytest.skip(
|
||||
"Nvfp4 Requires compute capability of 10 or above.", allow_module_level=True
|
||||
)
|
||||
|
||||
MNK_FACTORS = [
|
||||
(2, 1024, 1024),
|
||||
(2, 1024, 1536),
|
||||
(2, 3072, 1024),
|
||||
(64, 1024, 1024),
|
||||
(64, 3072, 1024),
|
||||
(64, 2048, 1536),
|
||||
(224, 1024, 1024),
|
||||
(224, 1024, 1536),
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("m,n,k", MNK_FACTORS)
|
||||
@pytest.mark.parametrize("e", [40, 64, 256])
|
||||
@pytest.mark.parametrize("topk", [1, 6, 8])
|
||||
@pytest.mark.parametrize("dtype", [torch.bfloat16])
|
||||
@torch.inference_mode()
|
||||
def test_cutlass_fp4_moe_no_graph(
|
||||
m: int, n: int, k: int, e: int, topk: int, dtype: torch.dtype, workspace_init
|
||||
):
|
||||
current_platform.seed_everything(7)
|
||||
with set_current_vllm_config(
|
||||
VllmConfig(parallel_config=ParallelConfig(pipeline_parallel_size=1))
|
||||
):
|
||||
quant_blocksize = 16
|
||||
|
||||
a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
|
||||
|
||||
(_, w1_q, w1_blockscale, w1_gs), (_, w2_q, w2_blockscale, w2_gs) = (
|
||||
make_test_weights(
|
||||
e,
|
||||
n,
|
||||
k,
|
||||
in_dtype=dtype,
|
||||
quant_dtype="nvfp4",
|
||||
block_shape=None, # use quant_blocksize?
|
||||
per_out_ch_quant=False,
|
||||
)
|
||||
)
|
||||
|
||||
score = torch.randn((m, e), device="cuda", dtype=dtype)
|
||||
topk_weights, topk_ids, _ = fused_topk(a, score, topk, renormalize=False)
|
||||
|
||||
a1_gs = torch.ones((e,), device="cuda", dtype=torch.float32)
|
||||
a2_gs = torch.ones((e,), device="cuda", dtype=torch.float32)
|
||||
|
||||
assert w1_gs is not None
|
||||
assert w2_gs is not None
|
||||
assert w1_blockscale is not None
|
||||
assert w2_blockscale is not None
|
||||
|
||||
quant_config = nvfp4_moe_quant_config(
|
||||
g1_alphas=(1 / w1_gs),
|
||||
g2_alphas=(1 / w2_gs),
|
||||
a1_gscale=a1_gs,
|
||||
a2_gscale=a2_gs,
|
||||
w1_scale=w1_blockscale,
|
||||
w2_scale=w2_blockscale,
|
||||
)
|
||||
|
||||
cutlass_output = cutlass_moe_fp4(
|
||||
a=a,
|
||||
w1_fp4=w1_q,
|
||||
w2_fp4=w2_q,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
quant_config=quant_config,
|
||||
m=m,
|
||||
n=n,
|
||||
k=k,
|
||||
e=e,
|
||||
)
|
||||
|
||||
# Reference check:
|
||||
a_global_scale = (
|
||||
(FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) / torch.amax(a.flatten(), dim=-1)
|
||||
).to(torch.float32)
|
||||
a_fp4, a_scale_interleaved = ops.scaled_fp4_quant(a, a_global_scale)
|
||||
|
||||
a_in_dtype = dequantize_nvfp4_to_dtype(
|
||||
a_fp4,
|
||||
a_scale_interleaved,
|
||||
a_global_scale,
|
||||
dtype=a.dtype,
|
||||
device=a.device,
|
||||
block_size=quant_blocksize,
|
||||
)
|
||||
|
||||
w1_d = torch.empty((e, 2 * n, k), device="cuda", dtype=dtype)
|
||||
w2_d = torch.empty((e, k, n), device="cuda", dtype=dtype)
|
||||
|
||||
for idx in range(0, e):
|
||||
w1_d[idx] = dequantize_nvfp4_to_dtype(
|
||||
w1_q[idx],
|
||||
w1_blockscale[idx],
|
||||
w1_gs[idx],
|
||||
dtype=dtype,
|
||||
device=w1_q.device,
|
||||
block_size=quant_blocksize,
|
||||
)
|
||||
w2_d[idx] = dequantize_nvfp4_to_dtype(
|
||||
w2_q[idx],
|
||||
w2_blockscale[idx],
|
||||
w2_gs[idx],
|
||||
dtype=dtype,
|
||||
device=w2_q.device,
|
||||
block_size=quant_blocksize,
|
||||
)
|
||||
|
||||
torch_output = torch_moe(a_in_dtype, w1_d, w2_d, score, topk)
|
||||
|
||||
torch.testing.assert_close(torch_output, cutlass_output, atol=1e-1, rtol=1e-1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_cutlass_fp4_moe_no_graph((2, 1024, 1024), 40, 1, torch.half)
|
||||
993
tests/kernels/moe/test_ocp_mx_moe.py
Normal file
993
tests/kernels/moe/test_ocp_mx_moe.py
Normal file
@@ -0,0 +1,993 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import importlib.metadata
|
||||
from dataclasses import dataclass
|
||||
from importlib.util import find_spec
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from packaging import version
|
||||
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.flashinfer import has_flashinfer
|
||||
|
||||
QUARK_MXFP4_AVAILABLE = find_spec("quark") is not None and version.parse(
|
||||
importlib.metadata.version("amd-quark")
|
||||
) >= version.parse("0.8.99")
|
||||
|
||||
TRTLLM_GEN_MXFP4_AVAILABLE = (
|
||||
current_platform.is_cuda() and current_platform.is_device_capability_family(100)
|
||||
)
|
||||
|
||||
HOPPER_MXFP4_BF16_AVAILABLE = (
|
||||
current_platform.is_cuda()
|
||||
and current_platform.is_device_capability(90)
|
||||
and has_flashinfer()
|
||||
)
|
||||
|
||||
if TRTLLM_GEN_MXFP4_AVAILABLE:
|
||||
from flashinfer import (
|
||||
fp4_quantize,
|
||||
mxfp8_quantize,
|
||||
next_positive_power_of_2,
|
||||
reorder_rows_for_gated_act_gemm,
|
||||
shuffle_matrix_a,
|
||||
shuffle_matrix_sf_a,
|
||||
trtllm_fp4_block_scale_moe,
|
||||
)
|
||||
from flashinfer.fp4_quantization import nvfp4_block_scale_interleave
|
||||
from flashinfer.fused_moe.core import get_w2_permute_indices_with_cache
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelCase:
|
||||
model_id: str
|
||||
tp: int
|
||||
|
||||
|
||||
@pytest.fixture(scope="function", autouse=True)
|
||||
def enable_pickle(monkeypatch):
|
||||
"""`LLM.apply_model` requires pickling a function."""
|
||||
monkeypatch.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1")
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"model_case",
|
||||
[
|
||||
ModelCase("fxmarty/qwen_1.5-moe-a2.7b-mxfp4", tp=2),
|
||||
ModelCase("fxmarty/deepseek_r1_3_layers_mxfp4", tp=8),
|
||||
ModelCase("fxmarty/Llama-4-Scout-17B-16E-Instruct-2-layers-mxfp4", tp=1),
|
||||
ModelCase("fxmarty/Llama-3.1-70B-Instruct-2-layers-mxfp6", tp=1),
|
||||
ModelCase("fxmarty/Llama-3.1-70B-Instruct-2-layers-mxfp6", tp=4),
|
||||
],
|
||||
)
|
||||
@pytest.mark.skipif(not QUARK_MXFP4_AVAILABLE, reason="amd-quark>=0.9 is not available")
|
||||
def test_mxfp4_loading_and_execution_moe(vllm_runner, model_case: ModelCase):
|
||||
if torch.cuda.device_count() < model_case.tp:
|
||||
pytest.skip(
|
||||
f"This test requires >={model_case.tp} gpus, got only "
|
||||
f"{torch.cuda.device_count()}"
|
||||
)
|
||||
|
||||
# `cudagraph_capture_sizes=[16]` to reduce load time.
|
||||
with vllm_runner(
|
||||
model_case.model_id,
|
||||
tensor_parallel_size=model_case.tp,
|
||||
load_format="dummy",
|
||||
cudagraph_capture_sizes=[16],
|
||||
) as llm:
|
||||
# Disabled as check_model is broken: https://github.com/vllm-project/vllm/pull/18465#issuecomment-3329880562
|
||||
# def check_model(model):
|
||||
# from vllm.model_executor.layers.quantization.quark.quark import ( # noqa: E501
|
||||
# QuarkLinearMethod)
|
||||
# from vllm.model_executor.layers.quantization.quark.schemes.quark_ocp_mx import QuarkOCP_MX # noqa: E501
|
||||
# from vllm.model_executor.layers.quantization.quark.quark_moe import ( # noqa: E501
|
||||
# QuarkOCP_MX_MoEMethod)
|
||||
|
||||
# layer = model.model.layers[0]
|
||||
|
||||
# qkv_proj = layer.self_attn.qkv_proj
|
||||
|
||||
# assert isinstance(qkv_proj.quant_method, QuarkLinearMethod)
|
||||
# assert isinstance(qkv_proj.scheme, QuarkOCP_MX)
|
||||
|
||||
# assert isinstance(layer.mlp.experts.quant_method,
|
||||
# QuarkOCP_MX_MoEMethod)
|
||||
|
||||
# if model_case.model_id == "fxmarty/qwen_1.5-moe-a2.7b-mxfp4":
|
||||
# llm.apply_model(check_model)
|
||||
|
||||
output = llm.generate_greedy("Today I am in the French Alps and", max_tokens=20)
|
||||
assert output
|
||||
|
||||
|
||||
def swiglu(x, alpha: float = 1.702, beta: float = 1.0, limit: float | None = None):
|
||||
# Note we add an extra bias of 1 to the linear layer
|
||||
x_glu, x_linear = torch.chunk(x, 2, dim=-1)
|
||||
if limit is not None:
|
||||
x_glu = x_glu.clamp(max=limit)
|
||||
x_linear = x_linear.clamp(min=-limit, max=limit)
|
||||
out_glu = x_glu * torch.sigmoid(alpha * x_glu)
|
||||
return out_glu * (x_linear + beta)
|
||||
|
||||
|
||||
fp4_lookup_table = [0, 0.5, 1, 1.5, 2, 3, 4, 6, -0, -0.5, -1, -1.5, -2, -3, -4, -6]
|
||||
|
||||
|
||||
def mxfp4_dequantize(x, scale):
|
||||
assert x.dtype == torch.uint8
|
||||
x = x.view(torch.uint8).to(torch.int32)
|
||||
x_unpacked = torch.zeros(
|
||||
*x.shape[:-1], x.shape[-1] * 2, dtype=torch.int32, device=x.device
|
||||
)
|
||||
x_unpacked[..., 0::2].copy_(x & 0xF)
|
||||
x_unpacked[..., 1::2].copy_((x >> 4) & 0xF)
|
||||
|
||||
x_float = torch.zeros(x_unpacked.shape, dtype=torch.float32, device=x.device)
|
||||
for i, val in enumerate(fp4_lookup_table):
|
||||
x_float[x_unpacked == i] = val
|
||||
|
||||
scale = scale.view(torch.uint8).to(torch.int32)
|
||||
scale = (scale << 23).view(torch.float32)
|
||||
scale = scale.reshape(*x.shape[:-1], -1)
|
||||
scale = torch.stack([scale] * 32, dim=-1).reshape(*x_float.shape)
|
||||
|
||||
return x_float * scale
|
||||
|
||||
|
||||
def mxfp8_dequantize(x, scale):
|
||||
assert x.dtype == torch.float8_e4m3fn
|
||||
x_float = x.to(torch.float32)
|
||||
|
||||
scale = scale.view(torch.uint8).to(torch.int32)
|
||||
scale = (scale << 23).view(torch.float32)
|
||||
scale = scale.reshape(*x.shape[:-1], -1)
|
||||
scale = torch.stack([scale] * 32, dim=-1).reshape(*x_float.shape)
|
||||
|
||||
return x_float * scale
|
||||
|
||||
|
||||
def reference_moe(
|
||||
roouting_logits,
|
||||
topk,
|
||||
num_experts,
|
||||
hidden_states,
|
||||
w13,
|
||||
bias13,
|
||||
w2,
|
||||
bias2,
|
||||
alpha,
|
||||
beta,
|
||||
limit,
|
||||
act_type,
|
||||
):
|
||||
# renormalize routing
|
||||
experts = torch.topk(roouting_logits, k=topk, dim=-1, sorted=True)
|
||||
expert_weights = torch.nn.functional.softmax(experts.values, dim=1)
|
||||
expert_indices = experts.indices
|
||||
t = hidden_states.clone()
|
||||
# MLP #1
|
||||
mlp1_weight = w13[expert_indices, ...]
|
||||
mlp1_bias = bias13[expert_indices, ...]
|
||||
t = torch.einsum("beck,bk->bec", mlp1_weight, t) + mlp1_bias
|
||||
t = swiglu(t, alpha=alpha, beta=beta, limit=limit)
|
||||
|
||||
if act_type == "mxfp8":
|
||||
t_quantized, t_scale = mxfp8_quantize(
|
||||
t.to(torch.bfloat16), is_sf_swizzled_layout=False
|
||||
)
|
||||
t = mxfp8_dequantize(t_quantized, t_scale)
|
||||
# MLP #2
|
||||
mlp2_weight = w2[expert_indices, ...]
|
||||
mlp2_bias = bias2[expert_indices, ...]
|
||||
t = torch.einsum("beck,bek->bec", mlp2_weight, t) + mlp2_bias
|
||||
# Weighted sum of experts
|
||||
t = torch.einsum("bec,be->bc", t, expert_weights)
|
||||
assert t.shape == hidden_states.shape
|
||||
return t.to(torch.bfloat16)
|
||||
|
||||
|
||||
def get_tile_tokens_dim(x: torch.Tensor, top_k: int, num_experts: int):
|
||||
# Number of tokens in the input tensor.
|
||||
num_tokens = x.shape[0]
|
||||
# Factor to account for the imbalance of the experts.
|
||||
# factor equals to the
|
||||
# max_real_num_tokens_per_expert / perfect_num_tokens_per_expert
|
||||
# - 1.0 means perfect expert distribution.
|
||||
# - > 1.0 means some experts have more
|
||||
# tokens than the perfect distribution.
|
||||
# - < 1.0 does not make sense.
|
||||
imbalance_factor = 1.3
|
||||
# Calculate the number of tokens per expert
|
||||
# assuming perfect distribution.
|
||||
num_tokens_per_expert = (num_tokens * top_k) // num_experts
|
||||
# Apply the imbalance factor.
|
||||
num_tokens_per_expert = int(num_tokens_per_expert * imbalance_factor)
|
||||
# And pad the number to the next power of 2.
|
||||
tile_tokens_dim = next_positive_power_of_2(num_tokens_per_expert)
|
||||
# Cap to 8-64 tokens per CTA tile
|
||||
# as it's the range supported by the kernel.
|
||||
tile_tokens_dim = min(max(tile_tokens_dim, 8), 64)
|
||||
return tile_tokens_dim
|
||||
|
||||
|
||||
def tg_mxfp4_moe(
|
||||
router_logits,
|
||||
topk,
|
||||
num_experts,
|
||||
intermediate_size,
|
||||
hidden_size,
|
||||
hidden_states,
|
||||
hidden_states_scale,
|
||||
w13_weight,
|
||||
w13_weight_scale,
|
||||
w13_bias,
|
||||
w2_weight,
|
||||
w2_weight_scale,
|
||||
w2_bias,
|
||||
act_type,
|
||||
alpha,
|
||||
beta,
|
||||
limit,
|
||||
transpose_optimized: bool = False,
|
||||
) -> torch.Tensor:
|
||||
sf_block_size = 32
|
||||
assert (
|
||||
w13_weight.dim() == 3
|
||||
and w13_weight.shape[0] == num_experts
|
||||
and w13_weight.shape[1] == intermediate_size * 2
|
||||
and w13_weight.shape[2] == hidden_size // 2
|
||||
)
|
||||
assert (
|
||||
w13_weight_scale.dim() == 3
|
||||
and w13_weight_scale.shape[0] == num_experts
|
||||
and w13_weight_scale.shape[1] == intermediate_size * 2
|
||||
and w13_weight_scale.shape[2] == hidden_size // sf_block_size
|
||||
)
|
||||
assert (
|
||||
w2_weight.dim() == 3
|
||||
and w2_weight.shape[0] == num_experts
|
||||
and w2_weight.shape[1] == hidden_size
|
||||
and w2_weight.shape[2] == intermediate_size // 2
|
||||
)
|
||||
assert (
|
||||
w2_weight_scale.dim() == 3
|
||||
and w2_weight_scale.shape[1] == hidden_size
|
||||
and w2_weight_scale.shape[2] == intermediate_size // sf_block_size
|
||||
)
|
||||
assert (
|
||||
w13_bias.dim() == 2
|
||||
and w13_bias.shape[0] == num_experts
|
||||
and w13_bias.shape[1] == intermediate_size * 2
|
||||
)
|
||||
assert (
|
||||
w2_bias.dim() == 2
|
||||
and w2_bias.shape[0] == num_experts
|
||||
and w2_bias.shape[1] == hidden_size
|
||||
)
|
||||
|
||||
# Swap w1 and w3 as the definition of
|
||||
# swiglu is different in the trtllm-gen
|
||||
w13_weight_scale_ = w13_weight_scale.clone()
|
||||
w13_weight_ = w13_weight.clone()
|
||||
w13_bias_ = w13_bias.clone()
|
||||
w13_weight[:, :intermediate_size, :].copy_(w13_weight_[:, intermediate_size:, :])
|
||||
w13_weight[:, intermediate_size:, :].copy_(w13_weight_[:, :intermediate_size, :])
|
||||
w13_weight_scale[:, :intermediate_size, :].copy_(
|
||||
w13_weight_scale_[:, intermediate_size:, :]
|
||||
)
|
||||
w13_weight_scale[:, intermediate_size:, :].copy_(
|
||||
w13_weight_scale_[:, :intermediate_size, :]
|
||||
)
|
||||
w13_bias[:, :intermediate_size].copy_(w13_bias_[:, intermediate_size:])
|
||||
w13_bias[:, intermediate_size:].copy_(w13_bias_[:, :intermediate_size])
|
||||
|
||||
# Interleave the weights and scaling factors for activation
|
||||
w13_weight_interleaved = []
|
||||
w13_weight_scale_interleaved = []
|
||||
w13_bias_interleaved = []
|
||||
for i in range(num_experts):
|
||||
w13_weight_interleaved.append(
|
||||
reorder_rows_for_gated_act_gemm(w13_weight[i].clone())
|
||||
)
|
||||
w13_weight_scale_interleaved.append(
|
||||
reorder_rows_for_gated_act_gemm(w13_weight_scale[i].clone())
|
||||
)
|
||||
w13_bias_interleaved.append(
|
||||
reorder_rows_for_gated_act_gemm(w13_bias[i].clone().reshape(-1, 1))
|
||||
)
|
||||
w13_weight = torch.stack(w13_weight_interleaved).reshape(
|
||||
num_experts, 2 * intermediate_size, hidden_size // 2
|
||||
)
|
||||
w13_weight_scale = torch.stack(w13_weight_scale_interleaved).reshape(
|
||||
num_experts, 2 * intermediate_size, hidden_size // 32
|
||||
)
|
||||
w13_bias = torch.stack(w13_bias_interleaved).reshape(
|
||||
num_experts, 2 * intermediate_size
|
||||
)
|
||||
|
||||
# Shuffle weights and scaling factors for transposed mma output
|
||||
gemm1_weights_shuffled = []
|
||||
gemm1_scales_shuffled = []
|
||||
gemm2_weights_shuffled = []
|
||||
gemm2_scales_shuffled = []
|
||||
gemm1_bias_shuffled = []
|
||||
gemm2_bias_shuffled = []
|
||||
epilogue_tile_m = 128 # FIXME: this depends on the kernel internals
|
||||
_cache_permute_indices: dict[torch.Size, torch.Tensor] = {}
|
||||
if transpose_optimized:
|
||||
for i in range(num_experts):
|
||||
# w13 weight shuffling
|
||||
permute_indices = get_w2_permute_indices_with_cache(
|
||||
_cache_permute_indices,
|
||||
w13_weight[i].view(torch.uint8),
|
||||
epilogue_tile_m,
|
||||
)
|
||||
gemm1_weights_shuffled.append(
|
||||
w13_weight[i]
|
||||
.view(torch.uint8)[permute_indices.to(w13_weight.device)]
|
||||
.contiguous()
|
||||
)
|
||||
# w13 scale shuffling
|
||||
permute_sf_indices = get_w2_permute_indices_with_cache(
|
||||
_cache_permute_indices,
|
||||
w13_weight_scale[i].view(torch.uint8),
|
||||
epilogue_tile_m,
|
||||
num_elts_per_sf=16,
|
||||
)
|
||||
gemm1_scales_shuffled.append(
|
||||
nvfp4_block_scale_interleave(
|
||||
w13_weight_scale[i]
|
||||
.view(torch.uint8)[permute_sf_indices.to(w13_weight_scale.device)]
|
||||
.contiguous()
|
||||
)
|
||||
)
|
||||
# w13 bias shuffling
|
||||
permute_bias_indices = get_w2_permute_indices_with_cache(
|
||||
_cache_permute_indices,
|
||||
w13_bias[i].clone().reshape(-1, 1),
|
||||
epilogue_tile_m,
|
||||
)
|
||||
gemm1_bias_shuffled.append(
|
||||
w13_bias[i]
|
||||
.clone()
|
||||
.reshape(-1, 1)[permute_bias_indices.to(w13_bias.device)]
|
||||
.contiguous()
|
||||
)
|
||||
# w2 weight shuffling
|
||||
permute_indices = get_w2_permute_indices_with_cache(
|
||||
_cache_permute_indices,
|
||||
w2_weight[i].view(torch.uint8),
|
||||
epilogue_tile_m,
|
||||
)
|
||||
gemm2_weights_shuffled.append(
|
||||
w2_weight[i]
|
||||
.view(torch.uint8)[permute_indices.to(w2_weight.device)]
|
||||
.contiguous()
|
||||
)
|
||||
# w2 scale shuffling
|
||||
permute_sf_indices = get_w2_permute_indices_with_cache(
|
||||
_cache_permute_indices,
|
||||
w2_weight_scale[i].view(torch.uint8),
|
||||
epilogue_tile_m,
|
||||
num_elts_per_sf=16,
|
||||
)
|
||||
gemm2_scales_shuffled.append(
|
||||
nvfp4_block_scale_interleave(
|
||||
w2_weight_scale[i]
|
||||
.view(torch.uint8)[permute_sf_indices.to(w2_weight_scale.device)]
|
||||
.contiguous()
|
||||
)
|
||||
)
|
||||
# w2 bias shuffling
|
||||
permute_indices = get_w2_permute_indices_with_cache(
|
||||
_cache_permute_indices,
|
||||
w2_bias[i].clone().reshape(-1, 1),
|
||||
epilogue_tile_m,
|
||||
)
|
||||
gemm2_bias_shuffled.append(
|
||||
w2_bias[i]
|
||||
.clone()
|
||||
.reshape(-1, 1)[permute_indices.to(w2_bias.device)]
|
||||
.contiguous()
|
||||
)
|
||||
|
||||
else:
|
||||
for i in range(num_experts):
|
||||
gemm1_weights_shuffled.append(
|
||||
shuffle_matrix_a(w13_weight[i].view(torch.uint8), epilogue_tile_m)
|
||||
)
|
||||
gemm1_scales_shuffled.append(
|
||||
shuffle_matrix_sf_a(
|
||||
w13_weight_scale[i].view(torch.uint8), epilogue_tile_m
|
||||
)
|
||||
)
|
||||
|
||||
gemm2_weights_shuffled.append(
|
||||
shuffle_matrix_a(w2_weight[i].view(torch.uint8), epilogue_tile_m)
|
||||
)
|
||||
gemm2_scales_shuffled.append(
|
||||
shuffle_matrix_sf_a(
|
||||
w2_weight_scale[i].view(torch.uint8), epilogue_tile_m
|
||||
)
|
||||
)
|
||||
gemm1_bias_shuffled.append(
|
||||
shuffle_matrix_a(w13_bias[i].reshape(-1, 1), epilogue_tile_m)
|
||||
)
|
||||
gemm2_bias_shuffled.append(
|
||||
shuffle_matrix_a(w2_bias[i].reshape(-1, 1), epilogue_tile_m)
|
||||
)
|
||||
|
||||
w13_weight = torch.stack(gemm1_weights_shuffled)
|
||||
w13_weight_scale = (
|
||||
torch.stack(gemm1_scales_shuffled)
|
||||
.reshape(num_experts, 2 * intermediate_size, hidden_size // sf_block_size)
|
||||
.view(torch.float8_e4m3fn)
|
||||
)
|
||||
w13_bias = torch.stack(gemm1_bias_shuffled).reshape(num_experts, -1)
|
||||
|
||||
w2_weight = torch.stack(gemm2_weights_shuffled)
|
||||
w2_weight_scale = (
|
||||
torch.stack(gemm2_scales_shuffled)
|
||||
.reshape(num_experts, hidden_size, intermediate_size // sf_block_size)
|
||||
.view(torch.float8_e4m3fn)
|
||||
)
|
||||
w2_bias = torch.stack(gemm2_bias_shuffled).reshape(num_experts, -1)
|
||||
|
||||
tg_result = trtllm_fp4_block_scale_moe(
|
||||
routing_logits=router_logits.to(torch.bfloat16),
|
||||
routing_bias=None,
|
||||
hidden_states=hidden_states,
|
||||
hidden_states_scale=hidden_states_scale,
|
||||
gemm1_weights=w13_weight,
|
||||
gemm1_weights_scale=w13_weight_scale,
|
||||
gemm1_bias=w13_bias,
|
||||
gemm1_alpha=alpha,
|
||||
gemm1_beta=beta,
|
||||
gemm1_clamp_limit=limit,
|
||||
gemm2_weights=w2_weight,
|
||||
gemm2_weights_scale=w2_weight_scale,
|
||||
gemm2_bias=w2_bias,
|
||||
output1_scale_scalar=None,
|
||||
output1_scale_gate_scalar=None,
|
||||
output2_scale_scalar=None,
|
||||
num_experts=num_experts,
|
||||
top_k=topk,
|
||||
n_group=None,
|
||||
topk_group=None,
|
||||
intermediate_size=intermediate_size,
|
||||
local_expert_offset=0,
|
||||
local_num_experts=num_experts,
|
||||
routed_scaling_factor=None,
|
||||
tile_tokens_dim=get_tile_tokens_dim(hidden_states, topk, num_experts),
|
||||
routing_method_type=1, # renormalize
|
||||
do_finalize=True,
|
||||
)[0]
|
||||
return tg_result
|
||||
|
||||
|
||||
def check_accuracy(a, b, atol, rtol, percent):
|
||||
"""Allow a mismatch percentage of 1 - percent."""
|
||||
if torch.any(torch.isnan(a)):
|
||||
raise Exception("NaN in reference output")
|
||||
if torch.any(torch.isnan(b)):
|
||||
raise Exception("NaN in actual output")
|
||||
if torch.any(torch.isinf(a)):
|
||||
raise Exception("Inf in reference output")
|
||||
if torch.any(torch.isinf(b)):
|
||||
raise Exception("Inf in actual output")
|
||||
assert a.shape == b.shape, f"Shape mismatch: {a.shape} vs {b.shape}"
|
||||
|
||||
left = torch.abs(a - b)
|
||||
right = atol + rtol * torch.abs(b)
|
||||
count = torch.sum(left > right)
|
||||
mismatch_percent = count / a.numel()
|
||||
if mismatch_percent > 1 - percent:
|
||||
raise Exception(
|
||||
f"Mismatch percentage is {mismatch_percent:.4f} for rtol {rtol} "
|
||||
f"(threshold: {1 - percent:.4f})"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("topk", [1, 4])
|
||||
@pytest.mark.parametrize("num_experts", [32, 128])
|
||||
@pytest.mark.parametrize("num_tokens", [1, 128, 1024])
|
||||
@pytest.mark.parametrize("intermediate_size,hidden_size", [(3072, 3072)])
|
||||
@pytest.mark.parametrize("alpha,beta,limit", [(1.0, 1.0, None), (1.702, 1.0, 7.0)])
|
||||
@pytest.mark.parametrize("act_type", ["mxfp8", "bf16"])
|
||||
@pytest.mark.parametrize("transpose_optimized", [False, True])
|
||||
@pytest.mark.skipif(
|
||||
not TRTLLM_GEN_MXFP4_AVAILABLE,
|
||||
reason="nvidia gpu and compute capability sm100 is required for this test",
|
||||
)
|
||||
def test_trtllm_gen_mxfp4_fused_moe(
|
||||
topk: int,
|
||||
num_experts: int,
|
||||
num_tokens: int,
|
||||
intermediate_size: int,
|
||||
hidden_size: int,
|
||||
alpha: float,
|
||||
beta: float,
|
||||
limit: float | None,
|
||||
act_type: str,
|
||||
transpose_optimized: bool,
|
||||
):
|
||||
seed = 42
|
||||
torch.manual_seed(seed)
|
||||
hidden_states = torch.randn(
|
||||
num_tokens, hidden_size, device="cuda:0", dtype=torch.bfloat16
|
||||
)
|
||||
w13 = torch.randn(
|
||||
num_experts,
|
||||
intermediate_size * 2,
|
||||
hidden_size,
|
||||
device="cuda:0",
|
||||
dtype=torch.bfloat16,
|
||||
)
|
||||
w2 = torch.randn(
|
||||
num_experts,
|
||||
hidden_size,
|
||||
intermediate_size,
|
||||
device="cuda:0",
|
||||
dtype=torch.bfloat16,
|
||||
)
|
||||
bias13 = torch.randn(num_experts, intermediate_size * 2, device="cuda:0") * 10
|
||||
bias2 = torch.randn(num_experts, hidden_size, device="cuda:0") * 10
|
||||
router_logits = torch.rand(num_tokens, num_experts, dtype=torch.float32).cuda()
|
||||
|
||||
w13, w13_scale = fp4_quantize(
|
||||
w13,
|
||||
torch.tensor(1.0, device="cuda:0"),
|
||||
32,
|
||||
sf_use_ue8m0=True,
|
||||
is_sf_swizzled_layout=False,
|
||||
)
|
||||
w13_scale = w13_scale.view(torch.float8_e4m3fn).reshape(
|
||||
num_experts, intermediate_size * 2, hidden_size // 32
|
||||
)
|
||||
w2, w2_scale = fp4_quantize(
|
||||
w2,
|
||||
torch.tensor(1.0, device="cuda:0"),
|
||||
32,
|
||||
sf_use_ue8m0=True,
|
||||
is_sf_swizzled_layout=False,
|
||||
)
|
||||
w2_scale = w2_scale.view(torch.float8_e4m3fn).reshape(
|
||||
num_experts, hidden_size, intermediate_size // 32
|
||||
)
|
||||
if act_type == "mxfp8":
|
||||
hidden_states, hidden_states_scale = mxfp8_quantize(
|
||||
hidden_states, is_sf_swizzled_layout=False
|
||||
)
|
||||
hidden_states_scale = hidden_states_scale.view(torch.float8_e4m3fn).reshape(-1)
|
||||
else:
|
||||
hidden_states_scale = None
|
||||
|
||||
# reference result
|
||||
ref_result = torch.empty_like(hidden_states, dtype=torch.bfloat16)
|
||||
w13_ref = mxfp4_dequantize(w13.clone(), w13_scale.clone())
|
||||
w2_ref = mxfp4_dequantize(w2.clone(), w2_scale.clone())
|
||||
bias13_ref = bias13
|
||||
bias2_ref = bias2
|
||||
if act_type == "mxfp8":
|
||||
hidden_states_ref = mxfp8_dequantize(hidden_states, hidden_states_scale).to(
|
||||
torch.float32
|
||||
)
|
||||
else:
|
||||
hidden_states_ref = hidden_states.to(torch.float32)
|
||||
# Process tokens in chunks of 32 to reduce memory usage
|
||||
chunk_size = 32
|
||||
num_chunks = (num_tokens + chunk_size - 1) // chunk_size
|
||||
for i in range(num_chunks):
|
||||
start_idx = i * chunk_size
|
||||
end_idx = min(start_idx + chunk_size, num_tokens)
|
||||
chunk_result = reference_moe(
|
||||
router_logits[start_idx:end_idx].to(torch.float32),
|
||||
topk,
|
||||
num_experts,
|
||||
hidden_states_ref[start_idx:end_idx],
|
||||
w13_ref,
|
||||
bias13_ref,
|
||||
w2_ref,
|
||||
bias2_ref,
|
||||
alpha,
|
||||
beta,
|
||||
limit,
|
||||
act_type,
|
||||
)
|
||||
ref_result[start_idx:end_idx].copy_(chunk_result)
|
||||
|
||||
# trtllm-gen result
|
||||
if alpha is not None:
|
||||
alpha = torch.full((num_experts,), alpha, device=hidden_states.device)
|
||||
if limit is not None:
|
||||
limit = torch.full((num_experts,), limit, device=hidden_states.device)
|
||||
if beta is not None:
|
||||
beta = torch.full((num_experts,), beta, device=hidden_states.device)
|
||||
tg_result = tg_mxfp4_moe(
|
||||
router_logits,
|
||||
topk,
|
||||
num_experts,
|
||||
intermediate_size,
|
||||
hidden_size,
|
||||
hidden_states,
|
||||
hidden_states_scale,
|
||||
w13,
|
||||
w13_scale,
|
||||
bias13,
|
||||
w2,
|
||||
w2_scale,
|
||||
bias2,
|
||||
act_type,
|
||||
alpha=alpha,
|
||||
beta=beta,
|
||||
limit=limit,
|
||||
transpose_optimized=transpose_optimized,
|
||||
)
|
||||
# relatively loose check since the mxfp4 quantization is less accurate
|
||||
check_accuracy(ref_result, tg_result, atol=0, rtol=0.3, percent=0.8)
|
||||
|
||||
|
||||
def _interleave_scales_lastdim_by4(scales: torch.Tensor) -> torch.Tensor:
|
||||
"""Interleave scales on the last dimension by groups of 4, matching
|
||||
the transformation in mxfp4.py's BF16 (Hopper) path."""
|
||||
s = scales.to(torch.uint8)
|
||||
s_shape = s.shape
|
||||
assert s_shape[-1] % 4 == 0
|
||||
s = s.reshape(*s_shape[:-1], s_shape[-1] // 4, 4)
|
||||
# Move the 4-group dimension before the row dimension
|
||||
permuted = s.permute(0, 2, 1, 3)
|
||||
# Merge the row dim with the 4-group dim
|
||||
return permuted.reshape(s_shape[0], s_shape[-1] // 4, s_shape[1] * 4)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("topk", [1, 4])
|
||||
@pytest.mark.parametrize("num_experts", [32])
|
||||
@pytest.mark.parametrize("num_tokens", [1, 128])
|
||||
@pytest.mark.parametrize("intermediate_size,hidden_size", [(3072, 3072)])
|
||||
@pytest.mark.parametrize("alpha,beta,limit", [(1.0, 1.0, None), (1.702, 1.0, 7.0)])
|
||||
@pytest.mark.skipif(
|
||||
not HOPPER_MXFP4_BF16_AVAILABLE,
|
||||
reason="nvidia gpu sm90 and flashinfer are required for this test",
|
||||
)
|
||||
def test_flashinfer_cutlass_mxfp4_fused_moe(
|
||||
topk: int,
|
||||
num_experts: int,
|
||||
num_tokens: int,
|
||||
intermediate_size: int,
|
||||
hidden_size: int,
|
||||
alpha: float,
|
||||
beta: float,
|
||||
limit: float | None,
|
||||
):
|
||||
torch.manual_seed(42)
|
||||
device = "cuda:0"
|
||||
|
||||
# Inputs
|
||||
hidden_states = torch.randn(
|
||||
num_tokens, hidden_size, device=device, dtype=torch.bfloat16
|
||||
)
|
||||
# Random MXFP4 weights and scales (uint8), contiguous [w1; w3]
|
||||
w13_q = torch.randint(
|
||||
0,
|
||||
256,
|
||||
(num_experts, 2 * intermediate_size, hidden_size // 2),
|
||||
device=device,
|
||||
dtype=torch.uint8,
|
||||
)
|
||||
w13_scale = torch.randint(
|
||||
118,
|
||||
123,
|
||||
(num_experts, 2 * intermediate_size, hidden_size // 32),
|
||||
device=device,
|
||||
dtype=torch.uint8,
|
||||
)
|
||||
|
||||
w2_q = torch.randint(
|
||||
0,
|
||||
256,
|
||||
(num_experts, hidden_size, intermediate_size // 2),
|
||||
device=device,
|
||||
dtype=torch.uint8,
|
||||
)
|
||||
w2_scale = torch.randint(
|
||||
118,
|
||||
123,
|
||||
(num_experts, hidden_size, intermediate_size // 32),
|
||||
device=device,
|
||||
dtype=torch.uint8,
|
||||
)
|
||||
# Bias contiguous [b1; b3]
|
||||
bias13 = (
|
||||
torch.randn(
|
||||
num_experts, 2 * intermediate_size, device=device, dtype=torch.bfloat16
|
||||
)
|
||||
* 10
|
||||
)
|
||||
bias2 = (
|
||||
torch.randn(num_experts, hidden_size, device=device, dtype=torch.bfloat16) * 10
|
||||
)
|
||||
router_logits = torch.rand(
|
||||
num_tokens, num_experts, dtype=torch.float32, device=device
|
||||
)
|
||||
|
||||
w13_ref = mxfp4_dequantize(w13_q.clone(), w13_scale.clone()).reshape(
|
||||
num_experts, 2 * intermediate_size, hidden_size
|
||||
)
|
||||
w2_ref = mxfp4_dequantize(w2_q.clone(), w2_scale.clone()).reshape(
|
||||
num_experts, hidden_size, intermediate_size
|
||||
)
|
||||
ref = reference_moe(
|
||||
router_logits.to(torch.float32),
|
||||
topk,
|
||||
num_experts,
|
||||
hidden_states.to(torch.float32),
|
||||
w13_ref,
|
||||
bias13.to(torch.float32),
|
||||
w2_ref,
|
||||
bias2.to(torch.float32),
|
||||
alpha,
|
||||
beta,
|
||||
limit,
|
||||
"bf16",
|
||||
)
|
||||
|
||||
from vllm.utils.flashinfer import flashinfer_cutlass_fused_moe
|
||||
|
||||
# Swap halves to arrange as [w3; w1] (kernel expectation)
|
||||
w1_w, w3_w = torch.chunk(w13_q, 2, dim=1)
|
||||
w13_q_swapped = torch.cat([w3_w, w1_w], dim=1)
|
||||
|
||||
b1, b3 = torch.chunk(bias13.to(torch.float32), 2, dim=-1)
|
||||
w13_b = torch.cat([b3, b1], dim=-1).to(torch.bfloat16)
|
||||
|
||||
w1_s, w3_s = torch.chunk(w13_scale, 2, dim=1)
|
||||
w13_s = torch.cat([w3_s, w1_s], dim=1)
|
||||
w13_s_inter = _interleave_scales_lastdim_by4(w13_s)
|
||||
w2_s_inter = _interleave_scales_lastdim_by4(w2_scale)
|
||||
|
||||
routing_weights = torch.nn.functional.softmax(
|
||||
router_logits, dim=1, dtype=torch.float32
|
||||
)
|
||||
token_final_scales, token_selected_experts = torch.topk(
|
||||
routing_weights, topk, dim=-1
|
||||
)
|
||||
token_final_scales = token_final_scales / token_final_scales.sum(
|
||||
dim=-1, keepdim=True
|
||||
)
|
||||
token_selected_experts = token_selected_experts.to(torch.int).contiguous()
|
||||
|
||||
out = torch.empty_like(hidden_states, dtype=torch.bfloat16)
|
||||
if alpha is not None:
|
||||
alpha = torch.full((num_experts,), alpha, device=hidden_states.device)
|
||||
if beta is not None:
|
||||
beta = torch.full((num_experts,), beta, device=hidden_states.device)
|
||||
if limit is not None:
|
||||
limit = torch.full((num_experts,), limit, device=hidden_states.device)
|
||||
|
||||
_ = flashinfer_cutlass_fused_moe(
|
||||
input=hidden_states,
|
||||
token_selected_experts=token_selected_experts,
|
||||
token_final_scales=token_final_scales,
|
||||
fc1_expert_weights=w13_q_swapped,
|
||||
fc2_expert_weights=w2_q,
|
||||
output_dtype=torch.bfloat16,
|
||||
output=out,
|
||||
quant_scales=[w13_s_inter.to(torch.uint8), w2_s_inter.to(torch.uint8)],
|
||||
fc1_expert_biases=w13_b,
|
||||
fc2_expert_biases=bias2.to(torch.bfloat16),
|
||||
swiglu_alpha=alpha,
|
||||
swiglu_beta=beta,
|
||||
swiglu_limit=limit,
|
||||
tp_size=1,
|
||||
tp_rank=0,
|
||||
ep_size=1,
|
||||
ep_rank=0,
|
||||
use_w4_group_scaling=True,
|
||||
)
|
||||
|
||||
# Allow some mismatch due to MXFP4 quantization
|
||||
check_accuracy(ref, out, atol=0, rtol=0.3, percent=0.8)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("topk", [1, 4])
|
||||
@pytest.mark.parametrize("num_experts", [32])
|
||||
@pytest.mark.parametrize("num_tokens", [1, 128])
|
||||
@pytest.mark.parametrize("intermediate_size,hidden_size", [(3072, 3072)])
|
||||
@pytest.mark.parametrize("alpha,beta,limit", [(1.0, 1.0, None), (1.702, 1.0, 7.0)])
|
||||
@pytest.mark.skipif(
|
||||
not (
|
||||
current_platform.is_cuda()
|
||||
and current_platform.is_device_capability_family(100)
|
||||
and has_flashinfer()
|
||||
),
|
||||
reason="NVIDIA GPU sm100 and flashinfer are required for this test",
|
||||
)
|
||||
def test_flashinfer_cutlass_mxfp4_mxfp8_fused_moe(
|
||||
topk: int,
|
||||
num_experts: int,
|
||||
num_tokens: int,
|
||||
intermediate_size: int,
|
||||
hidden_size: int,
|
||||
alpha: float | None,
|
||||
beta: float | None,
|
||||
limit: float | None,
|
||||
):
|
||||
torch.manual_seed(42)
|
||||
device = "cuda:0"
|
||||
|
||||
# Inputs
|
||||
hidden_states = torch.randn(
|
||||
num_tokens, hidden_size, device=device, dtype=torch.bfloat16
|
||||
)
|
||||
# Float weights in w13 format [w1; w3]
|
||||
w13 = (
|
||||
torch.randn(
|
||||
num_experts,
|
||||
2 * intermediate_size,
|
||||
hidden_size,
|
||||
device=device,
|
||||
dtype=torch.bfloat16,
|
||||
)
|
||||
/ 10
|
||||
)
|
||||
w2 = (
|
||||
torch.randn(
|
||||
num_experts,
|
||||
hidden_size,
|
||||
intermediate_size,
|
||||
device=device,
|
||||
dtype=torch.bfloat16,
|
||||
)
|
||||
/ 10
|
||||
)
|
||||
# Bias contiguous [b1; b3]
|
||||
bias13 = (
|
||||
torch.randn(
|
||||
num_experts, 2 * intermediate_size, device=device, dtype=torch.bfloat16
|
||||
)
|
||||
* 10
|
||||
)
|
||||
bias2 = (
|
||||
torch.randn(num_experts, hidden_size, device=device, dtype=torch.bfloat16) * 10
|
||||
)
|
||||
router_logits = torch.rand(
|
||||
num_tokens, num_experts, dtype=torch.float32, device=device
|
||||
)
|
||||
|
||||
# Quantize weights to MXFP4 per expert (SM100 path)
|
||||
from flashinfer import mxfp4_quantize
|
||||
|
||||
def quant_mxfp4_batches(a: torch.Tensor, e: int):
|
||||
qs, sfs = [], []
|
||||
for i in range(e):
|
||||
q, sf = mxfp4_quantize(a[i].cuda())
|
||||
qs.append(q)
|
||||
sfs.append(sf)
|
||||
return torch.stack(qs), torch.stack(sfs)
|
||||
|
||||
def dequant_mxfp4_batches(mat_fp4: torch.Tensor, scale_tensor: torch.Tensor):
|
||||
num_batches = mat_fp4.size(0)
|
||||
scale_tensor = scale_tensor.view(num_batches, -1)
|
||||
from flashinfer import mxfp4_dequantize
|
||||
|
||||
return torch.stack(
|
||||
[
|
||||
mxfp4_dequantize(mat_fp4[b, :, :], scale_tensor[b, :])
|
||||
for b in range(num_batches)
|
||||
]
|
||||
)
|
||||
|
||||
w13_q, w13_scale = quant_mxfp4_batches(w13, num_experts)
|
||||
w2_q, w2_scale = quant_mxfp4_batches(w2, num_experts)
|
||||
|
||||
# Reference result using dequantized tensors and reference_moe
|
||||
w13_ref = (
|
||||
dequant_mxfp4_batches(
|
||||
w13_q.view(torch.uint8), w13_scale.view(torch.uint8).reshape(-1)
|
||||
)
|
||||
.to(torch.float32)
|
||||
.reshape(num_experts, 2 * intermediate_size, hidden_size)
|
||||
.to(device)
|
||||
)
|
||||
w2_ref = (
|
||||
dequant_mxfp4_batches(
|
||||
w2_q.view(torch.uint8), w2_scale.view(torch.uint8).reshape(-1)
|
||||
)
|
||||
.to(torch.float32)
|
||||
.reshape(num_experts, hidden_size, intermediate_size)
|
||||
.to(device)
|
||||
)
|
||||
|
||||
# Quantize activations for SM100 path and dequantize for reference
|
||||
hidden_states_q, hidden_states_sf = mxfp8_quantize(hidden_states, True, 32)
|
||||
# Reference uses BF16 input but quantizes intermediate activation to MXFP8
|
||||
ref = reference_moe(
|
||||
router_logits.to(torch.float32),
|
||||
topk,
|
||||
num_experts,
|
||||
hidden_states.to(torch.float32),
|
||||
w13_ref,
|
||||
bias13.to(torch.float32),
|
||||
w2_ref,
|
||||
bias2.to(torch.float32),
|
||||
alpha,
|
||||
beta,
|
||||
limit,
|
||||
"mxfp8",
|
||||
)
|
||||
|
||||
# Prepare inputs for FlashInfer CUTLASS fused MoE
|
||||
from vllm.utils.flashinfer import flashinfer_cutlass_fused_moe
|
||||
|
||||
# Swap halves to arrange as [w3; w1] (kernel expectation)
|
||||
w1_w, w3_w = torch.chunk(w13_q, 2, dim=1)
|
||||
w13_q_swapped = torch.cat([w3_w, w1_w], dim=1)
|
||||
|
||||
# Swap scales halves to match swapped weights
|
||||
s1, s3 = torch.chunk(w13_scale, 2, dim=1)
|
||||
w13_scale_swapped = torch.cat([s3, s1], dim=1)
|
||||
|
||||
b1, b3 = torch.chunk(bias13.to(torch.float32), 2, dim=-1)
|
||||
w13_b = torch.cat([b3, b1], dim=-1).to(torch.bfloat16)
|
||||
|
||||
# Build routing for kernel
|
||||
routing_weights = torch.nn.functional.softmax(
|
||||
router_logits, dim=1, dtype=torch.float32
|
||||
)
|
||||
token_final_scales, token_selected_experts = torch.topk(
|
||||
routing_weights, topk, dim=-1
|
||||
)
|
||||
token_final_scales = token_final_scales / token_final_scales.sum(
|
||||
dim=-1, keepdim=True
|
||||
)
|
||||
token_selected_experts = token_selected_experts.to(torch.int).contiguous()
|
||||
|
||||
out = torch.empty_like(hidden_states, dtype=torch.bfloat16)
|
||||
if alpha is not None:
|
||||
alpha_t = torch.full((num_experts,), alpha, device=hidden_states.device)
|
||||
else:
|
||||
alpha_t = None
|
||||
if beta is not None:
|
||||
beta_t = torch.full((num_experts,), beta, device=hidden_states.device)
|
||||
else:
|
||||
beta_t = None
|
||||
if limit is not None:
|
||||
limit_t = torch.full((num_experts,), limit, device=hidden_states.device)
|
||||
else:
|
||||
limit_t = None
|
||||
|
||||
# Quant scales for SM100 MXFP8+MXFP4 path
|
||||
fake_input_scale = torch.ones(num_experts, device=device)
|
||||
quant_scales = [
|
||||
w13_scale_swapped.view(torch.int32),
|
||||
fake_input_scale,
|
||||
w2_scale.view(torch.int32),
|
||||
fake_input_scale,
|
||||
]
|
||||
|
||||
_ = flashinfer_cutlass_fused_moe(
|
||||
input=hidden_states_q,
|
||||
token_selected_experts=token_selected_experts,
|
||||
token_final_scales=token_final_scales,
|
||||
fc1_expert_weights=w13_q_swapped.contiguous().view(torch.long),
|
||||
fc2_expert_weights=w2_q.contiguous().view(torch.long),
|
||||
output_dtype=torch.bfloat16,
|
||||
output=out,
|
||||
quant_scales=quant_scales,
|
||||
fc1_expert_biases=w13_b,
|
||||
fc2_expert_biases=bias2.to(torch.bfloat16),
|
||||
swiglu_alpha=alpha_t,
|
||||
swiglu_beta=beta_t,
|
||||
swiglu_limit=limit_t,
|
||||
tp_size=1,
|
||||
tp_rank=0,
|
||||
ep_size=1,
|
||||
ep_rank=0,
|
||||
use_mxfp8_act_scaling=True,
|
||||
input_sf=hidden_states_sf,
|
||||
)
|
||||
|
||||
# Allow some mismatch due to MXFP4 quantization
|
||||
check_accuracy(ref, out, atol=0, rtol=0.3, percent=0.8)
|
||||
356
tests/kernels/moe/test_pplx_cutlass_moe.py
Normal file
356
tests/kernels/moe/test_pplx_cutlass_moe.py
Normal file
@@ -0,0 +1,356 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from tests.kernels.utils import torch_experts
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.config import VllmConfig, set_current_vllm_config
|
||||
from vllm.model_executor.layers.fused_moe.config import fp8_w8a8_moe_quant_config
|
||||
from vllm.model_executor.layers.fused_moe.cutlass_moe import CutlassBatchedExpertsFp8
|
||||
from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk
|
||||
from vllm.model_executor.layers.fused_moe.modular_kernel import FusedMoEModularKernel
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.math_utils import cdiv
|
||||
|
||||
from ...utils import multi_gpu_test
|
||||
from .parallel_utils import ProcessGroupInfo, parallel_launch
|
||||
|
||||
try:
|
||||
from pplx_kernels import AllToAll
|
||||
from pplx_kernels.nvshmem import (
|
||||
nvshmem_alloc_empty_unique_id,
|
||||
nvshmem_finalize,
|
||||
nvshmem_get_unique_id,
|
||||
nvshmem_init,
|
||||
)
|
||||
|
||||
has_pplx = True
|
||||
except ImportError:
|
||||
has_pplx = False
|
||||
|
||||
requires_pplx = pytest.mark.skipif(
|
||||
not has_pplx,
|
||||
reason="Requires PPLX kernels",
|
||||
)
|
||||
|
||||
NUM_EXPERTS = [40, 64]
|
||||
TOP_KS = [6, 8]
|
||||
|
||||
|
||||
def rank_chunk(num, r, w):
|
||||
rem = num % w
|
||||
return (num // w) + (1 if r < rem else 0)
|
||||
|
||||
|
||||
def chunk_by_rank(t, r, w):
|
||||
num = t.shape[0]
|
||||
chunk = rank_chunk(num, r, w)
|
||||
rem = num % w
|
||||
if rem == 0 or r < rem:
|
||||
return t[(r * chunk) : (r + 1) * chunk].contiguous()
|
||||
else:
|
||||
long_chunks = (num // w + 1) * rem
|
||||
short_chunks = (r - rem) * chunk
|
||||
start = long_chunks + short_chunks
|
||||
return t[start : start + chunk].contiguous()
|
||||
|
||||
|
||||
def pplx_cutlass_moe(
|
||||
pgi: ProcessGroupInfo,
|
||||
dp_size: int,
|
||||
a: torch.Tensor,
|
||||
w1: torch.Tensor,
|
||||
w2: torch.Tensor,
|
||||
w1_scale: torch.Tensor,
|
||||
w2_scale: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
a1_scale: torch.Tensor,
|
||||
out_dtype,
|
||||
per_act_token: bool,
|
||||
per_out_ch: bool,
|
||||
group_name: str | None,
|
||||
):
|
||||
from vllm.model_executor.layers.fused_moe.pplx_prepare_finalize import (
|
||||
PplxPrepareAndFinalize,
|
||||
)
|
||||
|
||||
assert torch.cuda.current_device() == pgi.local_rank
|
||||
|
||||
num_tokens, hidden_dim = a.shape
|
||||
intermediate_dim = w2.shape[2]
|
||||
num_experts = w1.shape[0]
|
||||
block_size = hidden_dim # TODO support more cases
|
||||
device = pgi.device
|
||||
rank = pgi.rank
|
||||
world_size = pgi.world_size
|
||||
rank_num_tokens = rank_chunk(num_tokens, rank, world_size)
|
||||
max_num_tokens = rank_chunk(num_tokens, 0, world_size)
|
||||
topk = topk_ids.shape[1]
|
||||
|
||||
if block_size == hidden_dim:
|
||||
scale_elems = 4 # hack to circumvent pplx data format requirements
|
||||
else:
|
||||
scale_elems = (hidden_dim + block_size - 1) // block_size
|
||||
|
||||
args = dict(
|
||||
max_num_tokens=max_num_tokens,
|
||||
num_experts=num_experts,
|
||||
experts_per_token=topk,
|
||||
rank=rank,
|
||||
world_size=world_size,
|
||||
dp_size=dp_size,
|
||||
hidden_dim=hidden_dim,
|
||||
hidden_dim_bytes=hidden_dim, # because a.dtype.itemsize == 1
|
||||
hidden_dim_scale_bytes=scale_elems * torch.float32.itemsize,
|
||||
)
|
||||
|
||||
if group_name is None:
|
||||
ata = AllToAll.internode(**args)
|
||||
else:
|
||||
args["group_name"] = group_name
|
||||
ata = AllToAll.intranode(**args)
|
||||
|
||||
w1 = w1.to(device)
|
||||
w2 = w2.to(device)
|
||||
w1_scale = w1_scale.to(device)
|
||||
w2_scale = w2_scale.to(device)
|
||||
a1_scale = a1_scale.to(device)
|
||||
|
||||
assert num_experts % world_size == 0
|
||||
num_local_experts = cdiv(num_experts, world_size)
|
||||
num_dispatchers = pgi.world_size // dp_size
|
||||
|
||||
prepare_finalize = PplxPrepareAndFinalize(
|
||||
ata,
|
||||
max_num_tokens=max_num_tokens,
|
||||
num_local_experts=num_local_experts,
|
||||
num_dispatchers=num_dispatchers,
|
||||
)
|
||||
|
||||
ab_strides1 = torch.full(
|
||||
(num_local_experts,), hidden_dim, device="cuda", dtype=torch.int64
|
||||
)
|
||||
ab_strides2 = torch.full(
|
||||
(num_local_experts,), intermediate_dim, device="cuda", dtype=torch.int64
|
||||
)
|
||||
c_strides1 = torch.full(
|
||||
(num_local_experts,), 2 * intermediate_dim, device="cuda", dtype=torch.int64
|
||||
)
|
||||
c_strides2 = torch.full(
|
||||
(num_local_experts,), hidden_dim, device="cuda", dtype=torch.int64
|
||||
)
|
||||
|
||||
experts = CutlassBatchedExpertsFp8(
|
||||
num_local_experts,
|
||||
num_dispatchers,
|
||||
out_dtype,
|
||||
ab_strides1,
|
||||
ab_strides2,
|
||||
c_strides1,
|
||||
c_strides2,
|
||||
fp8_w8a8_moe_quant_config(
|
||||
per_act_token_quant=per_act_token,
|
||||
per_out_ch_quant=per_out_ch,
|
||||
w1_scale=chunk_by_rank(w1_scale, rank, world_size),
|
||||
w2_scale=chunk_by_rank(w2_scale, rank, world_size),
|
||||
a1_scale=chunk_by_rank(a1_scale, rank, world_size)
|
||||
if per_act_token
|
||||
else a1_scale[rank],
|
||||
),
|
||||
)
|
||||
|
||||
fused_cutlass_experts = FusedMoEModularKernel(
|
||||
prepare_finalize,
|
||||
experts,
|
||||
)
|
||||
|
||||
a_chunk = chunk_by_rank(a, rank, world_size).to(device)
|
||||
chunk_topk_weight = chunk_by_rank(topk_weights, rank, world_size).to(device)
|
||||
chunk_topk_ids = (
|
||||
chunk_by_rank(topk_ids, rank, world_size).to(torch.uint32).to(device)
|
||||
)
|
||||
|
||||
out = fused_cutlass_experts(
|
||||
a_chunk,
|
||||
chunk_by_rank(w1, rank, world_size),
|
||||
chunk_by_rank(w2, rank, world_size),
|
||||
chunk_topk_weight,
|
||||
chunk_topk_ids,
|
||||
global_num_experts=num_experts,
|
||||
expert_map=None, # TODO
|
||||
)
|
||||
|
||||
torch.cuda.synchronize()
|
||||
|
||||
ata.destroy()
|
||||
|
||||
return out[:rank_num_tokens]
|
||||
|
||||
|
||||
vllm_config = VllmConfig()
|
||||
|
||||
|
||||
def _pplx_moe(
|
||||
pgi: ProcessGroupInfo,
|
||||
dp_size: int,
|
||||
a: torch.Tensor,
|
||||
w1: torch.Tensor,
|
||||
w2: torch.Tensor,
|
||||
w1_scale: torch.Tensor,
|
||||
w2_scale: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
a1_scale: torch.Tensor,
|
||||
out_dtype,
|
||||
a_full: torch.Tensor,
|
||||
w1_full: torch.Tensor,
|
||||
w2_full: torch.Tensor,
|
||||
per_act_token: bool,
|
||||
per_out_ch: bool,
|
||||
use_internode: bool,
|
||||
):
|
||||
try:
|
||||
if use_internode:
|
||||
uid = (
|
||||
nvshmem_get_unique_id()
|
||||
if pgi.rank == 0
|
||||
else nvshmem_alloc_empty_unique_id()
|
||||
)
|
||||
torch.distributed.broadcast(uid, src=0)
|
||||
nvshmem_init(uid, pgi.rank, pgi.world_size)
|
||||
else:
|
||||
group_ranks = list(range(pgi.world_size))
|
||||
cpu_group = torch.distributed.new_group(group_ranks, backend="gloo")
|
||||
group_name = cpu_group.group_name
|
||||
|
||||
with set_current_vllm_config(vllm_config):
|
||||
torch_output = torch_experts(
|
||||
a_full, w1_full, w2_full, topk_weights, topk_ids
|
||||
)
|
||||
pplx_output = pplx_cutlass_moe(
|
||||
pgi,
|
||||
dp_size,
|
||||
a,
|
||||
w1,
|
||||
w2,
|
||||
w1_scale,
|
||||
w2_scale,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
a1_scale,
|
||||
out_dtype,
|
||||
per_act_token,
|
||||
per_out_ch,
|
||||
group_name,
|
||||
)
|
||||
|
||||
torch_output = chunk_by_rank(torch_output, pgi.rank, pgi.world_size).to(
|
||||
pplx_output.device
|
||||
)
|
||||
|
||||
# Uncomment if more debugging is needed
|
||||
# print("PPLX OUT:", pplx_output)
|
||||
# print("TORCH OUT:", torch_output)
|
||||
|
||||
torch.testing.assert_close(pplx_output, torch_output, atol=0.05, rtol=0)
|
||||
finally:
|
||||
if use_internode:
|
||||
nvshmem_finalize()
|
||||
|
||||
|
||||
@pytest.mark.parametrize("m", [2, 224])
|
||||
@pytest.mark.parametrize("n", [3072])
|
||||
@pytest.mark.parametrize("k", [1536])
|
||||
@pytest.mark.parametrize("e", NUM_EXPERTS)
|
||||
@pytest.mark.parametrize("topk", TOP_KS)
|
||||
@pytest.mark.parametrize("per_act_token", [True, False])
|
||||
@pytest.mark.parametrize("per_out_ch", [True, False])
|
||||
@pytest.mark.parametrize("world_dp_size", [[2, 1]]) # , [4, 2]])
|
||||
@pytest.mark.parametrize("use_internode", [False])
|
||||
@multi_gpu_test(num_gpus=2)
|
||||
@pytest.mark.skipif(
|
||||
(lambda x: x is None or not ops.cutlass_group_gemm_supported(x.to_int()))(
|
||||
current_platform.get_device_capability()
|
||||
),
|
||||
reason="Grouped gemm is not supported on this GPU type.",
|
||||
)
|
||||
@requires_pplx
|
||||
def test_cutlass_moe_pplx(
|
||||
m: int,
|
||||
n: int,
|
||||
k: int,
|
||||
e: int,
|
||||
topk: int,
|
||||
per_act_token: bool,
|
||||
per_out_ch: bool,
|
||||
world_dp_size: tuple[int, int],
|
||||
use_internode: bool,
|
||||
):
|
||||
current_platform.seed_everything(7)
|
||||
|
||||
with set_current_vllm_config(vllm_config):
|
||||
dtype = torch.half
|
||||
|
||||
a = torch.randn((m, k), device="cuda", dtype=dtype) / 10.0
|
||||
w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10.0
|
||||
w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10.0
|
||||
|
||||
n_b_scales = 2 * n if per_out_ch else 1
|
||||
k_b_scales = k if per_out_ch else 1
|
||||
|
||||
w1_q = torch.empty((e, 2 * n, k), device="cuda", dtype=torch.float8_e4m3fn)
|
||||
w2_q = torch.empty((e, k, n), device="cuda", dtype=torch.float8_e4m3fn)
|
||||
w1_scale = torch.empty((e, n_b_scales, 1), device="cuda", dtype=torch.float32)
|
||||
w2_scale = torch.empty((e, k_b_scales, 1), device="cuda", dtype=torch.float32)
|
||||
|
||||
for expert in range(e):
|
||||
w1_q[expert], w1_scale[expert] = ops.scaled_fp8_quant(
|
||||
w1[expert], use_per_token_if_dynamic=per_out_ch
|
||||
)
|
||||
w2_q[expert], w2_scale[expert] = ops.scaled_fp8_quant(
|
||||
w2[expert], use_per_token_if_dynamic=per_out_ch
|
||||
)
|
||||
|
||||
w1_d = torch.empty_like(w1)
|
||||
w2_d = torch.empty_like(w2)
|
||||
for expert in range(e):
|
||||
w1_d[expert] = (w1_q[expert].float() * w1_scale[expert]).half()
|
||||
w2_d[expert] = (w2_q[expert].float() * w2_scale[expert]).half()
|
||||
|
||||
score = torch.randn((m, e), device="cuda", dtype=dtype)
|
||||
topk_weights, topk_ids, _ = fused_topk(a, score, topk, renormalize=False)
|
||||
|
||||
world_size, dp_size = world_dp_size
|
||||
a_scale1 = (
|
||||
torch.randn(
|
||||
(m if per_act_token else 1, 1), device="cuda", dtype=torch.float32
|
||||
)
|
||||
/ 10.0
|
||||
)
|
||||
if not per_act_token:
|
||||
a_scale1 = a_scale1.repeat(world_size, 1)
|
||||
|
||||
parallel_launch(
|
||||
world_size,
|
||||
_pplx_moe,
|
||||
dp_size,
|
||||
a,
|
||||
w1_q,
|
||||
w2_q,
|
||||
w1_scale,
|
||||
w2_scale,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
a_scale1,
|
||||
dtype,
|
||||
a,
|
||||
w1_d,
|
||||
w2_d,
|
||||
per_act_token,
|
||||
per_out_ch,
|
||||
use_internode,
|
||||
)
|
||||
1018
tests/kernels/moe/test_pplx_moe.py
Normal file
1018
tests/kernels/moe/test_pplx_moe.py
Normal file
File diff suppressed because it is too large
Load Diff
219
tests/kernels/moe/test_rocm_aiter_topk.py
Normal file
219
tests/kernels/moe/test_rocm_aiter_topk.py
Normal file
@@ -0,0 +1,219 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
# This is a test for the AITER ops.
|
||||
# It tests if the AITER ops are
|
||||
# 1. correctly registered as custom ops
|
||||
# 2. correctly defined the relationship between
|
||||
# implementation and fake function
|
||||
# 3. can be used with torch.compile
|
||||
# This file will be skipped if AITER is not installed
|
||||
# and the platform is not ROCm.
|
||||
|
||||
import importlib.util
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
# this import statement is needed to ensure the ops are registered
|
||||
import vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe # noqa: F401
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
# need to import once to ensure the ops are registered
|
||||
# Check if aiter package is installed
|
||||
aiter_available = importlib.util.find_spec("aiter") is not None
|
||||
|
||||
pytestmark = pytest.mark.skipif(
|
||||
not (current_platform.is_rocm() and aiter_available),
|
||||
reason="AITER ops are only available on ROCm with aiter package installed",
|
||||
)
|
||||
|
||||
|
||||
def test_rocm_aiter_biased_grouped_topk_custom_op_registration():
|
||||
"""Test that the custom op is correctly registered."""
|
||||
# Check if the op exists in torch.ops.vllm
|
||||
assert hasattr(torch.ops.vllm, "rocm_aiter_biased_grouped_topk")
|
||||
|
||||
# Check if the op is callable
|
||||
assert callable(torch.ops.vllm.rocm_aiter_biased_grouped_topk)
|
||||
|
||||
|
||||
def test_rocm_aiter_grouped_topk_custom_op_registration():
|
||||
"""Test that the custom op is correctly registered."""
|
||||
# Check if the op exists in torch.ops.vllm
|
||||
assert hasattr(torch.ops.vllm, "rocm_aiter_grouped_topk")
|
||||
|
||||
# Check if the op is callable
|
||||
assert callable(torch.ops.vllm.rocm_aiter_grouped_topk)
|
||||
|
||||
|
||||
def test_rocm_aiter_biased_grouped_topk_torch_compile_compatibility():
|
||||
"""Test that the op can be used with torch.compile."""
|
||||
# Create test tensors
|
||||
token = 64
|
||||
expert = 256
|
||||
num_expert_group = 8
|
||||
topk = 8
|
||||
topk_group = 4
|
||||
renormalize = True
|
||||
scale_factor = 1.0
|
||||
|
||||
gating_output = torch.randn((token, expert), dtype=torch.bfloat16, device="cuda")
|
||||
e_score_correction_bias = torch.randn(
|
||||
(expert,), dtype=torch.bfloat16, device="cuda"
|
||||
)
|
||||
|
||||
device = gating_output.device
|
||||
topk_ids = torch.empty((token, topk), dtype=torch.int32, device=device)
|
||||
topk_weights = torch.empty((token, topk), dtype=torch.float32, device=device)
|
||||
|
||||
# Define a function that uses the op
|
||||
def biased_grouped_topk_fn(
|
||||
gating_output, e_score_correction_bias, topk_weights, topk_ids
|
||||
):
|
||||
return torch.ops.vllm.rocm_aiter_biased_grouped_topk(
|
||||
gating_output,
|
||||
e_score_correction_bias,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
num_expert_group,
|
||||
topk_group,
|
||||
renormalize,
|
||||
scale_factor,
|
||||
)
|
||||
|
||||
# Verify the op's fake implementation
|
||||
torch.library.opcheck(
|
||||
torch.ops.vllm.rocm_aiter_biased_grouped_topk,
|
||||
(gating_output, e_score_correction_bias, topk_weights, topk_ids),
|
||||
kwargs={
|
||||
"num_expert_group": num_expert_group,
|
||||
"topk_group": topk_group,
|
||||
"need_renorm": renormalize,
|
||||
"routed_scaling_factor": scale_factor,
|
||||
},
|
||||
test_utils=("test_faketensor"),
|
||||
)
|
||||
|
||||
# Compile the function with appropriate settings
|
||||
compiled_fn = torch.compile(
|
||||
biased_grouped_topk_fn,
|
||||
fullgraph=True,
|
||||
backend="inductor",
|
||||
mode="reduce-overhead",
|
||||
dynamic=False,
|
||||
)
|
||||
|
||||
topk_weights_original = torch.empty(
|
||||
(token, topk), dtype=torch.float32, device=device
|
||||
)
|
||||
topk_ids_original = torch.empty((token, topk), dtype=torch.int32, device=device)
|
||||
|
||||
topk_weights_compiled = torch.empty(
|
||||
(token, topk), dtype=torch.float32, device=device
|
||||
)
|
||||
topk_ids_compiled = torch.empty((token, topk), dtype=torch.int32, device=device)
|
||||
|
||||
# Run both compiled (V1 graph mode) and uncompiled versions (V1 eager mode)
|
||||
biased_grouped_topk_fn(
|
||||
gating_output, e_score_correction_bias, topk_weights_original, topk_ids_original
|
||||
)
|
||||
compiled_fn(
|
||||
gating_output, e_score_correction_bias, topk_weights_compiled, topk_ids_compiled
|
||||
)
|
||||
|
||||
# Sort the results for comparison since the order might not be deterministic
|
||||
topk_ids_original, indices_original = torch.sort(topk_ids_original)
|
||||
topk_weights_original = torch.gather(topk_weights_original, 1, indices_original)
|
||||
|
||||
topk_ids_compiled, indices_compiled = torch.sort(topk_ids_compiled)
|
||||
topk_weights_compiled = torch.gather(topk_weights_compiled, 1, indices_compiled)
|
||||
|
||||
# Verify results match
|
||||
assert torch.allclose(
|
||||
topk_weights_original, topk_weights_compiled, rtol=1e-2, atol=1e-2
|
||||
)
|
||||
assert torch.allclose(topk_ids_original, topk_ids_compiled)
|
||||
|
||||
|
||||
def test_rocm_aiter_grouped_topk_torch_compile_compatibility():
|
||||
"""Test that the op can be used with torch.compile."""
|
||||
# Create test tensors
|
||||
token = 64
|
||||
expert = 256
|
||||
num_expert_group = 8
|
||||
topk = 8
|
||||
topk_group = 4
|
||||
renormalize = True
|
||||
scoring_func = "softmax"
|
||||
scale_factor = 1.0
|
||||
|
||||
gating_output = torch.randn((token, expert), dtype=torch.bfloat16, device="cuda")
|
||||
|
||||
device = gating_output.device
|
||||
topk_ids = torch.empty((token, topk), dtype=torch.int32, device=device)
|
||||
topk_weights = torch.empty((token, topk), dtype=torch.float32, device=device)
|
||||
|
||||
# Define a function that uses the op
|
||||
def grouped_topk_fn(gating_output, topk_weights, topk_ids, scoring_func):
|
||||
return torch.ops.vllm.rocm_aiter_grouped_topk(
|
||||
gating_output,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
num_expert_group,
|
||||
topk_group,
|
||||
renormalize,
|
||||
scoring_func,
|
||||
scale_factor,
|
||||
)
|
||||
|
||||
# Verify the op's fake implementation
|
||||
torch.library.opcheck(
|
||||
torch.ops.vllm.rocm_aiter_grouped_topk,
|
||||
(gating_output, topk_weights, topk_ids),
|
||||
kwargs={
|
||||
"num_expert_group": num_expert_group,
|
||||
"topk_group": topk_group,
|
||||
"need_renorm": renormalize,
|
||||
"scoring_func": scoring_func,
|
||||
"routed_scaling_factor": scale_factor,
|
||||
},
|
||||
test_utils=("test_faketensor"),
|
||||
)
|
||||
|
||||
# Compile the function with appropriate settings
|
||||
compiled_fn = torch.compile(
|
||||
grouped_topk_fn,
|
||||
fullgraph=True,
|
||||
backend="inductor",
|
||||
mode="reduce-overhead",
|
||||
dynamic=False,
|
||||
)
|
||||
|
||||
topk_weights_original = torch.empty(
|
||||
(token, topk), dtype=torch.float32, device=device
|
||||
)
|
||||
topk_ids_original = torch.empty((token, topk), dtype=torch.int32, device=device)
|
||||
|
||||
topk_weights_compiled = torch.empty(
|
||||
(token, topk), dtype=torch.float32, device=device
|
||||
)
|
||||
topk_ids_compiled = torch.empty((token, topk), dtype=torch.int32, device=device)
|
||||
|
||||
# Run both compiled (V1 graph mode) and uncompiled versions (V1 eager mode)
|
||||
grouped_topk_fn(
|
||||
gating_output, topk_weights_original, topk_ids_original, scoring_func
|
||||
)
|
||||
compiled_fn(gating_output, topk_weights_compiled, topk_ids_compiled, scoring_func)
|
||||
|
||||
# Sort the results for comparison since the order might not be deterministic
|
||||
topk_ids_original, indices_original = torch.sort(topk_ids_original)
|
||||
topk_weights_original = torch.gather(topk_weights_original, 1, indices_original)
|
||||
|
||||
topk_ids_compiled, indices_compiled = torch.sort(topk_ids_compiled)
|
||||
topk_weights_compiled = torch.gather(topk_weights_compiled, 1, indices_compiled)
|
||||
|
||||
# Verify results match
|
||||
assert torch.allclose(
|
||||
topk_weights_original, topk_weights_compiled, rtol=1e-2, atol=1e-2
|
||||
)
|
||||
assert torch.allclose(topk_ids_original, topk_ids_compiled)
|
||||
293
tests/kernels/moe/test_silu_mul_fp8_quant_deep_gemm.py
Normal file
293
tests/kernels/moe/test_silu_mul_fp8_quant_deep_gemm.py
Normal file
@@ -0,0 +1,293 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
|
||||
import random
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe import (
|
||||
persistent_masked_m_silu_mul_quant,
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.deep_gemm import DeepGemmQuantScaleFMT, has_deep_gemm
|
||||
from vllm.utils.math_utils import cdiv, round_up
|
||||
|
||||
if current_platform.is_fp8_fnuz():
|
||||
pytest.skip(
|
||||
"Tests in this file require float8_e4m3fn and platform does not support",
|
||||
allow_module_level=True,
|
||||
)
|
||||
|
||||
fp8_dtype = torch.float8_e4m3fn
|
||||
|
||||
CASES = [
|
||||
(1, 1, 128, fp8_dtype),
|
||||
(1, 4, 128 * 1, fp8_dtype),
|
||||
(2, 4, 128 * 2, fp8_dtype),
|
||||
(1, 4, 128 * 3, fp8_dtype),
|
||||
(8, 16, 128 * 4, fp8_dtype),
|
||||
(8, 16, 128 * 5, fp8_dtype),
|
||||
(8, 16, 128 * 6, fp8_dtype),
|
||||
(8, 16, 128 * 7, fp8_dtype),
|
||||
(8, 16, 128 * 8, fp8_dtype),
|
||||
(8, 16, 128 * 9, fp8_dtype),
|
||||
(8, 64, 7168, fp8_dtype),
|
||||
(8, 128, 128 * 33, fp8_dtype),
|
||||
(1, 4, 128 * 10, fp8_dtype),
|
||||
(8, 128, 7168, fp8_dtype),
|
||||
(8, 512, 7168, fp8_dtype),
|
||||
(8, 1024, 7168, fp8_dtype),
|
||||
(17, 31, 768, fp8_dtype),
|
||||
(32, 64, 256, fp8_dtype),
|
||||
(256, 8, 7168, fp8_dtype),
|
||||
(256, 32, 7168, fp8_dtype),
|
||||
(256, 64, 7168, fp8_dtype),
|
||||
# Only add a few fnuz tests to help with long CI times.
|
||||
(8, 512, 7168, torch.float8_e4m3fnuz),
|
||||
(8, 1024, 7168, torch.float8_e4m3fnuz),
|
||||
]
|
||||
|
||||
|
||||
def as_uint8(x) -> torch.Tensor:
|
||||
return (
|
||||
torch.empty(x.shape, dtype=x.dtype, device=x.device).copy_(x).view(torch.uint8)
|
||||
)
|
||||
|
||||
|
||||
def silu(x: torch.Tensor) -> torch.Tensor:
|
||||
one_f32 = torch.tensor([1.0], device=x.device, dtype=torch.float32)
|
||||
x_f32 = x.to(torch.float32)
|
||||
act_f32 = x_f32 / (one_f32 + torch.exp(-x_f32))
|
||||
assert act_f32.dtype == torch.float32
|
||||
return act_f32.to(torch.bfloat16)
|
||||
|
||||
|
||||
def do_quant(x: torch.Tensor, group_size: int, ceil_ue8m0: bool):
|
||||
eps_bf16 = torch.tensor([1e-10], device=x.device, dtype=torch.bfloat16)
|
||||
one_bf16 = torch.tensor([1.0], device=x.device, dtype=torch.bfloat16)
|
||||
fp8_max_bf16 = torch.tensor(
|
||||
[torch.finfo(fp8_dtype).max], device=x.device, dtype=torch.bfloat16
|
||||
)
|
||||
fp8_min_bf16 = torch.tensor(
|
||||
[torch.finfo(fp8_dtype).min], device=x.device, dtype=torch.bfloat16
|
||||
)
|
||||
fp8_max_inv = one_bf16 / fp8_max_bf16
|
||||
assert fp8_max_inv.dtype == torch.bfloat16
|
||||
|
||||
assert x.size(-1) % group_size == 0
|
||||
num_groups = x.numel() // group_size
|
||||
x_og_shape = x.shape
|
||||
|
||||
x = x.to(torch.bfloat16)
|
||||
x = x.view((-1, group_size))
|
||||
amax = x.abs().amax(dim=1).clamp(min=eps_bf16)
|
||||
assert amax.dtype == torch.bfloat16
|
||||
s = amax * fp8_max_inv
|
||||
|
||||
if ceil_ue8m0:
|
||||
s = torch.exp2(
|
||||
torch.ceil(torch.log2(s).to(torch.bfloat16)).to(torch.bfloat16)
|
||||
).to(torch.bfloat16)
|
||||
|
||||
inv_s = one_bf16 / s
|
||||
inv_s = inv_s.view((num_groups, 1))
|
||||
xq = torch.clamp(x * inv_s, min=fp8_min_bf16.item(), max=fp8_max_bf16.item()).to(
|
||||
fp8_dtype
|
||||
)
|
||||
|
||||
xq = xq.view(x_og_shape)
|
||||
xs = s.view((-1, xq.size(-1) // group_size))
|
||||
return xq, xs
|
||||
|
||||
|
||||
def silu_mul_quant(
|
||||
gate: torch.Tensor, up: torch.Tensor, group_size: int, ceil_ue8m0: bool
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
assert gate.size(-1) % group_size == 0
|
||||
assert up.size(-1) % group_size == 0
|
||||
|
||||
assert gate.dtype == torch.bfloat16
|
||||
assert up.dtype == torch.bfloat16
|
||||
|
||||
act_bf16 = silu(gate)
|
||||
assert act_bf16.dtype == torch.bfloat16
|
||||
|
||||
# act & mul
|
||||
a_m = act_bf16 * up
|
||||
assert a_m.dtype == torch.bfloat16
|
||||
|
||||
q, s = do_quant(a_m, group_size, ceil_ue8m0)
|
||||
return q, s
|
||||
|
||||
|
||||
def pack_scales(x: torch.Tensor, tokens_per_expert: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
pack float32 scales into a int32 tensor
|
||||
"""
|
||||
assert x.dtype == torch.float32
|
||||
E, T, G = x.size()
|
||||
|
||||
# Add i32_padding here so we can view it as a i32 tensor later on.
|
||||
i32_padding = round_up(G, 4) - G
|
||||
ref_s_i8 = torch.empty((E, T, G + i32_padding), dtype=torch.uint8, device="cuda")
|
||||
for e in range(E):
|
||||
nt = tokens_per_expert[e].item()
|
||||
ref_s_i8[e, :nt, :G] = x[e, :nt].view(torch.int32) >> 23
|
||||
|
||||
ref_s_i32 = ref_s_i8.view(torch.int32)
|
||||
|
||||
return ref_s_i32
|
||||
|
||||
|
||||
def ref_with_scale_fmt(
|
||||
E: int,
|
||||
T: int,
|
||||
H: int,
|
||||
group_size: int,
|
||||
tokens_per_expert: torch.Tensor,
|
||||
gate: torch.Tensor,
|
||||
up: torch.Tensor,
|
||||
scale_fmt: DeepGemmQuantScaleFMT,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
The precision types of the operations triggered by this function
|
||||
match closely with the kernel implementation so we compare more
|
||||
accurately.
|
||||
"""
|
||||
scale_dtype = (
|
||||
torch.int32 if scale_fmt == DeepGemmQuantScaleFMT.UE8M0 else torch.float32
|
||||
)
|
||||
ceil_ue8m0 = scale_fmt in [
|
||||
DeepGemmQuantScaleFMT.UE8M0,
|
||||
DeepGemmQuantScaleFMT.FLOAT32_CEIL_UE8M0,
|
||||
]
|
||||
|
||||
ref_q = torch.empty((E, T, H), dtype=fp8_dtype, device="cuda")
|
||||
ref_s_f32 = torch.empty(
|
||||
(E, T, cdiv(H, group_size)), dtype=torch.float32, device="cuda"
|
||||
)
|
||||
|
||||
for e in range(E):
|
||||
nt = tokens_per_expert[e].item()
|
||||
if nt == 0:
|
||||
continue
|
||||
ref_q[e, :nt], ref_s_f32[e, :nt] = silu_mul_quant(
|
||||
gate[e, :nt], up[e, :nt], group_size, ceil_ue8m0=ceil_ue8m0
|
||||
)
|
||||
|
||||
if scale_dtype == torch.float32:
|
||||
return ref_q, ref_s_f32
|
||||
|
||||
assert scale_dtype == torch.int32
|
||||
return ref_q, pack_scales(ref_s_f32, tokens_per_expert)
|
||||
|
||||
|
||||
def token_random(E, T, H2, tokens_per_expert):
|
||||
"""
|
||||
Initialize each token in a random range so we test a range of
|
||||
scale values.
|
||||
"""
|
||||
y = torch.empty((E, T, H2), dtype=torch.bfloat16, device="cuda")
|
||||
for e in range(E):
|
||||
for t in range(tokens_per_expert[e].item()):
|
||||
exp = random.choice(range(1, 20))
|
||||
y[e, t].uniform_(-(2**exp), 2**exp)
|
||||
return y
|
||||
|
||||
|
||||
@pytest.mark.parametrize("E,T,H,fp8_type", CASES)
|
||||
@torch.inference_mode()
|
||||
def test_silu_mul_fp8_quant_deep_gemm(E: int, T: int, H: int, fp8_type: torch.dtype):
|
||||
group_size = 128
|
||||
current_platform.seed_everything(42)
|
||||
|
||||
tokens_per_expert = torch.randint(
|
||||
low=0,
|
||||
high=T,
|
||||
size=(E,),
|
||||
dtype=torch.int32,
|
||||
device="cuda",
|
||||
)
|
||||
|
||||
# Input tensor of shape (E, T, 2*H)
|
||||
y = token_random(E, T, 2 * H, tokens_per_expert)
|
||||
|
||||
gate = y[..., :H].to(torch.bfloat16)
|
||||
up = y[..., H:].to(torch.bfloat16)
|
||||
|
||||
scale_fmts = [
|
||||
DeepGemmQuantScaleFMT.FLOAT32,
|
||||
DeepGemmQuantScaleFMT.FLOAT32_CEIL_UE8M0,
|
||||
DeepGemmQuantScaleFMT.UE8M0,
|
||||
]
|
||||
|
||||
# Run the SiLU V2 kernel
|
||||
for scale_fmt in scale_fmts:
|
||||
y_q, y_s = persistent_masked_m_silu_mul_quant(
|
||||
y,
|
||||
tokens_per_expert,
|
||||
group_size=group_size,
|
||||
quant_scale_fmt=scale_fmt,
|
||||
)
|
||||
|
||||
ref_y_q, ref_y_s = ref_with_scale_fmt(
|
||||
E, T, H, group_size, tokens_per_expert, gate, up, scale_fmt=scale_fmt
|
||||
)
|
||||
|
||||
# deepgemm scales transform
|
||||
dg_scales = None
|
||||
if (
|
||||
has_deep_gemm()
|
||||
and current_platform.has_device_capability(100)
|
||||
and scale_fmt == DeepGemmQuantScaleFMT.UE8M0
|
||||
):
|
||||
from deep_gemm import transform_sf_into_required_layout
|
||||
|
||||
_q, _s = ref_with_scale_fmt(
|
||||
E,
|
||||
T,
|
||||
H,
|
||||
group_size,
|
||||
tokens_per_expert,
|
||||
gate,
|
||||
up,
|
||||
scale_fmt=DeepGemmQuantScaleFMT.FLOAT32_CEIL_UE8M0,
|
||||
)
|
||||
dg_scales = transform_sf_into_required_layout(
|
||||
sf=_s,
|
||||
mn=_q.size(1),
|
||||
k=_q.size(2),
|
||||
recipe=(1, 128, 128),
|
||||
num_groups=_q.size(0),
|
||||
is_sfa=True,
|
||||
)
|
||||
|
||||
expected_scale_dtype = (
|
||||
torch.int32 if scale_fmt == DeepGemmQuantScaleFMT.UE8M0 else torch.float32
|
||||
)
|
||||
assert y_s.dtype == expected_scale_dtype
|
||||
assert ref_y_s.dtype == expected_scale_dtype
|
||||
|
||||
for e in range(E):
|
||||
nt = tokens_per_expert[e].item()
|
||||
|
||||
torch.testing.assert_close(
|
||||
y_q[e, :nt].to(torch.float32),
|
||||
ref_y_q[e, :nt].to(torch.float32),
|
||||
)
|
||||
|
||||
if scale_fmt == DeepGemmQuantScaleFMT.UE8M0:
|
||||
G = H // group_size
|
||||
y_s_sliced = as_uint8(y_s[e])
|
||||
ref_s_sliced = as_uint8(ref_y_s[e])
|
||||
torch.testing.assert_close(y_s_sliced[:nt, :G], ref_s_sliced[:nt, :G])
|
||||
if dg_scales is not None:
|
||||
dg_sliced = as_uint8(dg_scales[e])
|
||||
torch.testing.assert_close(y_s_sliced[:nt, :G], dg_sliced[:nt, :G])
|
||||
else:
|
||||
torch.testing.assert_close(
|
||||
y_s[e, :nt],
|
||||
ref_y_s[e, :nt],
|
||||
)
|
||||
@@ -0,0 +1,86 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
|
||||
_per_token_group_quant_fp8_colmajor,
|
||||
silu_mul_per_token_group_quant_fp8_colmajor,
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.triton_utils import triton
|
||||
from vllm.utils.deep_gemm import is_deep_gemm_e8m0_used
|
||||
|
||||
FLOAT8_DTYPE = torch.float8_e4m3fn
|
||||
GROUP_SIZE = 128
|
||||
|
||||
|
||||
def reference_quant(x: torch.Tensor, use_ue8m0: bool):
|
||||
"""
|
||||
Reference triton quant kernel from,
|
||||
vllm.model_executor.layers.quantization.utils.fp8_utils
|
||||
"""
|
||||
|
||||
x_q = torch.empty_like(x, device=x.device, dtype=FLOAT8_DTYPE)
|
||||
|
||||
# Allocate the scale tensor in column-major format.
|
||||
shape = (x.shape[-1] // GROUP_SIZE,) + x.shape[:-1]
|
||||
x_s = torch.empty(shape, device=x.device, dtype=torch.float32).permute(-1, -2)
|
||||
|
||||
M = x.numel() // GROUP_SIZE
|
||||
N = GROUP_SIZE
|
||||
BLOCK = triton.next_power_of_2(N)
|
||||
# heuristics for number of warps
|
||||
num_warps = min(max(BLOCK // 256, 1), 8)
|
||||
num_stages = 1
|
||||
|
||||
finfo = torch.finfo(FLOAT8_DTYPE)
|
||||
fp8_min = finfo.min
|
||||
fp8_max = finfo.max
|
||||
|
||||
_per_token_group_quant_fp8_colmajor[(M,)](
|
||||
x,
|
||||
x_q,
|
||||
x_s,
|
||||
GROUP_SIZE,
|
||||
x.shape[1],
|
||||
x.stride(0),
|
||||
x_s.stride(1),
|
||||
eps=1e-10,
|
||||
fp8_min=fp8_min,
|
||||
fp8_max=fp8_max,
|
||||
use_ue8m0=use_ue8m0,
|
||||
BLOCK=BLOCK,
|
||||
num_warps=num_warps,
|
||||
num_stages=num_stages,
|
||||
)
|
||||
return x_q, x_s
|
||||
|
||||
|
||||
def reference(x: torch.Tensor, use_ue8m0: bool) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
T, N = x.size()
|
||||
ref_act_out = torch.empty((T, N // 2), dtype=torch.bfloat16, device="cuda")
|
||||
torch.ops._C.silu_and_mul(ref_act_out, x)
|
||||
return reference_quant(ref_act_out, use_ue8m0)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("T", [128, 256, 512])
|
||||
@pytest.mark.parametrize("N", [128 * 2, 256 * 2, 768 * 2, 2048 * 2, 7168 * 2])
|
||||
def test_silu_mul_fp8_quant_deep_gemm(T: int, N: int):
|
||||
current_platform.seed_everything(42)
|
||||
|
||||
input = torch.rand((T, N), dtype=torch.bfloat16, device="cuda")
|
||||
|
||||
use_ue8m0 = is_deep_gemm_e8m0_used()
|
||||
|
||||
# Test
|
||||
output, output_scales = silu_mul_per_token_group_quant_fp8_colmajor(
|
||||
input, use_ue8m0=use_ue8m0
|
||||
)
|
||||
|
||||
# Reference
|
||||
ref_output, ref_output_scales = reference(input, use_ue8m0)
|
||||
|
||||
torch.testing.assert_close(output.to(torch.float32), ref_output.to(torch.float32))
|
||||
torch.testing.assert_close(output_scales, ref_output_scales)
|
||||
170
tests/kernels/moe/test_triton_moe_ptpc_fp8.py
Normal file
170
tests/kernels/moe/test_triton_moe_ptpc_fp8.py
Normal file
@@ -0,0 +1,170 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
# Adapted from https://github.com/sgl-project/sglang/blob/main/test/srt/test_triton_moe_channel_fp8_kernel.py
|
||||
import itertools
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from tests.kernels.moe.utils import fused_moe
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.config import VllmConfig, set_current_vllm_config
|
||||
from vllm.model_executor.layers.activation import SiluAndMul
|
||||
from vllm.model_executor.layers.fused_moe.config import fp8_w8a8_moe_quant_config
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
if current_platform.get_device_capability() < (9, 0):
|
||||
pytest.skip("FP8 Triton requires CUDA 9.0 or higher", allow_module_level=True)
|
||||
|
||||
vllm_config = VllmConfig()
|
||||
|
||||
if current_platform.is_fp8_fnuz():
|
||||
pytest.skip(
|
||||
"Tests in this file require float8_e4m3fn and platform does not support",
|
||||
allow_module_level=True,
|
||||
)
|
||||
|
||||
|
||||
def native_w8a8_per_token_matmul(A, B, As, Bs, output_dtype=torch.float16):
|
||||
"""Matrix multiplication function that supports per-token input
|
||||
quantization and per-column weight quantization"""
|
||||
A = A.to(torch.float32)
|
||||
B = B.to(torch.float32)
|
||||
|
||||
assert A.shape[-1] == B.shape[-1], "Dimension mismatch"
|
||||
assert B.ndim == 2 and B.is_contiguous(), "B must be a 2D contiguous tensor"
|
||||
|
||||
# Reshape input
|
||||
M = A.numel() // A.shape[-1]
|
||||
B = B.t() # Transpose weight matrix
|
||||
N, K = B.shape
|
||||
origin_C_shape = A.shape[:-1] + (K,)
|
||||
A = A.reshape(M, N)
|
||||
|
||||
# As is per-token [M, 1], Bs is per-column [1, K]
|
||||
C = torch.matmul(A, B) # [M, K]
|
||||
C = As * C * Bs.view(1, -1) # Broadcast per-column scale
|
||||
|
||||
return C.reshape(origin_C_shape).to(output_dtype)
|
||||
|
||||
|
||||
def fp8_mask(a, mask):
|
||||
dtype = a.dtype
|
||||
return a.view(torch.int8)[mask].view(dtype)
|
||||
|
||||
|
||||
def torch_w8a8_per_column_moe(a, w1, w2, w1_s, w2_s, score, topk):
|
||||
"""This function performs fused moe with per-column int8
|
||||
quantization using native torch."""
|
||||
|
||||
B, D = a.shape
|
||||
# Perform per-token quantization
|
||||
a_q, a_s = ops.scaled_fp8_quant(a, use_per_token_if_dynamic=True)
|
||||
# Repeat tokens to match topk
|
||||
a_q = a_q.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D)
|
||||
# Also repeat the scale
|
||||
a_s = a_s.view(B, -1, 1).repeat(1, topk, 1).reshape(-1, 1) # [B*topk, 1]
|
||||
|
||||
out = torch.zeros(B * topk, w2.shape[1], dtype=a.dtype, device=a.device)
|
||||
|
||||
# Calculate routing
|
||||
score = torch.softmax(score, dim=-1, dtype=torch.float32)
|
||||
topk_weight, topk_ids = torch.topk(score, topk)
|
||||
topk_weight = topk_weight.view(-1)
|
||||
topk_ids = topk_ids.view(-1)
|
||||
# Process each expert
|
||||
for i in range(w1.shape[0]):
|
||||
mask = topk_ids == i
|
||||
if mask.sum():
|
||||
# First MLP layer: note that a_s is now per-token
|
||||
inter_out = native_w8a8_per_token_matmul(
|
||||
fp8_mask(a_q, mask),
|
||||
w1[i],
|
||||
fp8_mask(a_s, mask),
|
||||
w1_s[i],
|
||||
output_dtype=a.dtype,
|
||||
)
|
||||
# Activation function
|
||||
act_out = SiluAndMul().forward_native(inter_out)
|
||||
# Quantize activation output with per-token
|
||||
act_out_q, act_out_s = ops.scaled_fp8_quant(
|
||||
act_out, use_per_token_if_dynamic=True
|
||||
)
|
||||
|
||||
# Second MLP layer
|
||||
out[mask] = native_w8a8_per_token_matmul(
|
||||
act_out_q, w2[i], act_out_s, w2_s[i], output_dtype=a.dtype
|
||||
)
|
||||
# Apply routing weights and sum
|
||||
return (
|
||||
out.view(B, -1, w2.shape[1]) * topk_weight.view(B, -1, 1).to(out.dtype)
|
||||
).sum(dim=1)
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True, scope="module")
|
||||
def setup_cuda():
|
||||
"""Sets the default CUDA device for all tests in this module."""
|
||||
torch.set_default_device("cuda")
|
||||
|
||||
|
||||
DTYPES = [torch.half, torch.bfloat16]
|
||||
M = [1, 33]
|
||||
N = [128, 1024]
|
||||
K = [256, 4096]
|
||||
E = [8]
|
||||
TOP_KS = [2, 6]
|
||||
SEEDS = [0]
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"M, N, K, E, topk, dtype, seed",
|
||||
itertools.product(M, N, K, E, TOP_KS, DTYPES, SEEDS),
|
||||
)
|
||||
@torch.inference_mode()
|
||||
def test_w8a8_fp8_fused_moe(M, N, K, E, topk, dtype, seed):
|
||||
torch.manual_seed(seed)
|
||||
# Initialize int8 quantization parameters
|
||||
factor_for_scale = 1e-2
|
||||
finfo = torch.finfo(torch.float8_e4m3fn)
|
||||
fp8_max = finfo.max
|
||||
fp8_min = finfo.min
|
||||
|
||||
# Input tensor
|
||||
# M * K
|
||||
a = torch.randn((M, K), dtype=dtype) / 10
|
||||
|
||||
# Generate int8 weights
|
||||
w1_fp32 = (torch.rand((E, 2 * N, K), dtype=torch.float32) - 0.5) * 2
|
||||
w1 = (w1_fp32 * fp8_max).clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn)
|
||||
|
||||
w2_fp32 = (torch.rand((E, K, N), dtype=torch.float32) - 0.5) * 2
|
||||
w2 = (w2_fp32 * fp8_max).clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn)
|
||||
|
||||
# Generate scale for each column (per-column quantization)
|
||||
w1_s = torch.rand(E, 2 * N, device=w1_fp32.device) * factor_for_scale
|
||||
w2_s = torch.rand(E, K, device=w2_fp32.device) * factor_for_scale
|
||||
score = torch.randn((M, E), dtype=dtype)
|
||||
|
||||
with set_current_vllm_config(vllm_config):
|
||||
ref_out = torch_w8a8_per_column_moe(a, w1, w2, w1_s, w2_s, score, topk)
|
||||
out = fused_moe(
|
||||
a,
|
||||
w1,
|
||||
w2,
|
||||
score,
|
||||
topk,
|
||||
renormalize=False,
|
||||
quant_config=fp8_w8a8_moe_quant_config(
|
||||
per_act_token_quant=True,
|
||||
w1_scale=w1_s,
|
||||
w2_scale=w2_s,
|
||||
block_shape=None, # Not using block quantization
|
||||
),
|
||||
)
|
||||
|
||||
# Check results
|
||||
rel_diff = torch.mean(
|
||||
torch.abs(out.to(torch.float32) - ref_out.to(torch.float32))
|
||||
) / torch.mean(torch.abs(ref_out.to(torch.float32)))
|
||||
assert rel_diff < 0.05
|
||||
521
tests/kernels/moe/utils.py
Normal file
521
tests/kernels/moe/utils.py
Normal file
@@ -0,0 +1,521 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import torch
|
||||
|
||||
import vllm._custom_ops as ops
|
||||
from tests.kernels.quant_utils import per_block_cast_to_int8
|
||||
from tests.kernels.quantization.nvfp4_utils import FLOAT4_E2M1_MAX, FLOAT8_E4M3_MAX
|
||||
from vllm.model_executor.layers.activation import SiluAndMul
|
||||
from vllm.model_executor.layers.fused_moe import fused_experts, fused_topk
|
||||
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
|
||||
from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
|
||||
BatchedPrepareAndFinalize,
|
||||
BatchedTritonExperts,
|
||||
NaiveBatchedExperts,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.modular_kernel import FusedMoEModularKernel
|
||||
from vllm.model_executor.layers.fused_moe.utils import moe_kernel_quantize_input
|
||||
from vllm.utils.deep_gemm import per_block_cast_to_fp8
|
||||
from vllm.utils.math_utils import round_up
|
||||
|
||||
|
||||
def triton_moe(
|
||||
a: torch.Tensor,
|
||||
w1: torch.Tensor,
|
||||
w2: torch.Tensor,
|
||||
topk_weight: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
w1_scale: torch.Tensor | None = None,
|
||||
w2_scale: torch.Tensor | None = None,
|
||||
a1_scale: torch.Tensor | None = None,
|
||||
a2_scale: torch.Tensor | None = None,
|
||||
quant_dtype: torch.dtype | None = None,
|
||||
per_act_token_quant=False,
|
||||
block_shape: list[int] | None = None,
|
||||
) -> torch.Tensor:
|
||||
quant_config = FusedMoEQuantConfig.make(
|
||||
quant_dtype,
|
||||
per_act_token_quant=per_act_token_quant,
|
||||
block_shape=block_shape,
|
||||
w1_scale=w1_scale,
|
||||
w2_scale=w2_scale,
|
||||
a1_scale=a1_scale,
|
||||
a2_scale=a2_scale,
|
||||
)
|
||||
|
||||
return fused_experts(a, w1, w2, topk_weight, topk_ids, quant_config=quant_config)
|
||||
|
||||
|
||||
def batched_moe(
|
||||
a: torch.Tensor,
|
||||
w1: torch.Tensor,
|
||||
w2: torch.Tensor,
|
||||
topk_weight: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
w1_scale: torch.Tensor | None = None,
|
||||
w2_scale: torch.Tensor | None = None,
|
||||
a1_scale: torch.Tensor | None = None,
|
||||
a2_scale: torch.Tensor | None = None,
|
||||
quant_dtype: torch.dtype | None = None,
|
||||
per_act_token_quant: bool = False,
|
||||
block_shape: list[int] | None = None,
|
||||
) -> torch.Tensor:
|
||||
max_num_tokens = round_up(a.shape[0], 64)
|
||||
|
||||
quant_config = FusedMoEQuantConfig.make(
|
||||
quant_dtype,
|
||||
per_act_token_quant=per_act_token_quant,
|
||||
block_shape=block_shape,
|
||||
w1_scale=w1_scale,
|
||||
w2_scale=w2_scale,
|
||||
a1_scale=a1_scale,
|
||||
a2_scale=a2_scale,
|
||||
)
|
||||
|
||||
fused_experts = FusedMoEModularKernel(
|
||||
BatchedPrepareAndFinalize(
|
||||
max_num_tokens, num_dispatchers=1, num_local_experts=w1.shape[0], rank=0
|
||||
),
|
||||
BatchedTritonExperts(
|
||||
max_num_tokens=max_num_tokens,
|
||||
num_dispatchers=1,
|
||||
quant_config=quant_config,
|
||||
),
|
||||
)
|
||||
|
||||
return fused_experts(a, w1, w2, topk_weight, topk_ids)
|
||||
|
||||
|
||||
def naive_batched_moe(
|
||||
a: torch.Tensor,
|
||||
w1: torch.Tensor,
|
||||
w2: torch.Tensor,
|
||||
topk_weight: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
w1_scale: torch.Tensor | None = None,
|
||||
w2_scale: torch.Tensor | None = None,
|
||||
a1_scale: torch.Tensor | None = None,
|
||||
a2_scale: torch.Tensor | None = None,
|
||||
quant_dtype: torch.dtype | None = None,
|
||||
per_act_token_quant: bool = False,
|
||||
block_shape: list[int] | None = None,
|
||||
) -> torch.Tensor:
|
||||
max_num_tokens = round_up(a.shape[0], 64)
|
||||
|
||||
quant_config = FusedMoEQuantConfig.make(
|
||||
quant_dtype,
|
||||
per_act_token_quant=per_act_token_quant,
|
||||
block_shape=block_shape,
|
||||
w1_scale=w1_scale,
|
||||
w2_scale=w2_scale,
|
||||
a1_scale=a1_scale,
|
||||
a2_scale=a2_scale,
|
||||
)
|
||||
|
||||
fused_experts = FusedMoEModularKernel(
|
||||
BatchedPrepareAndFinalize(
|
||||
max_num_tokens, num_dispatchers=1, num_local_experts=w1.shape[0], rank=0
|
||||
),
|
||||
NaiveBatchedExperts(
|
||||
max_num_tokens=max_num_tokens,
|
||||
num_dispatchers=1,
|
||||
quant_config=quant_config,
|
||||
),
|
||||
)
|
||||
|
||||
return fused_experts(a, w1, w2, topk_weight, topk_ids)
|
||||
|
||||
|
||||
def chunk_scales(
|
||||
scales: torch.Tensor | None, start: int, end: int
|
||||
) -> torch.Tensor | None:
|
||||
if scales is not None:
|
||||
if scales.numel() == 1:
|
||||
return scales
|
||||
else:
|
||||
return scales[start:end]
|
||||
return None
|
||||
|
||||
|
||||
def make_quantized_test_activations(
|
||||
E: int,
|
||||
m: int,
|
||||
k: int,
|
||||
in_dtype: torch.dtype,
|
||||
quant_dtype: torch.dtype | None = None,
|
||||
block_shape: list[int] | None = None,
|
||||
per_act_token_quant: bool = False,
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None]:
|
||||
a = torch.randn((E, m, k), device="cuda", dtype=in_dtype) / 10
|
||||
a_q = a
|
||||
a_scale = None
|
||||
|
||||
if quant_dtype is not None:
|
||||
assert quant_dtype == torch.float8_e4m3fn or quant_dtype == torch.int8, (
|
||||
"only fp8/int8 supported"
|
||||
)
|
||||
a_q = torch.zeros_like(a, dtype=quant_dtype)
|
||||
a_scale_l = [None] * E
|
||||
for e in range(E):
|
||||
a_q[e], a_scale_l[e] = moe_kernel_quantize_input(
|
||||
a[e], None, quant_dtype, per_act_token_quant, block_shape
|
||||
)
|
||||
a_scale = torch.stack(a_scale_l)
|
||||
|
||||
if not per_act_token_quant and block_shape is None:
|
||||
a_scale = a_scale.view(E, 1, 1)
|
||||
|
||||
return a, a_q, a_scale
|
||||
|
||||
|
||||
def moe_quantize_weights(
|
||||
w: torch.Tensor,
|
||||
w_s: torch.Tensor | None,
|
||||
quant_dtype: torch.dtype | str | None,
|
||||
per_token_quant: bool,
|
||||
block_shape: list[int] | None,
|
||||
) -> tuple[torch.Tensor, torch.Tensor | None, torch.Tensor | None]:
|
||||
assert (
|
||||
quant_dtype == torch.float8_e4m3fn
|
||||
or quant_dtype == torch.int8
|
||||
or quant_dtype == "nvfp4"
|
||||
), "only fp8/int8/nvfp4 supported"
|
||||
|
||||
w_gs = None
|
||||
|
||||
if block_shape is not None:
|
||||
assert not per_token_quant
|
||||
if quant_dtype == torch.int8:
|
||||
w, w_s = per_block_cast_to_int8(w, block_shape)
|
||||
elif quant_dtype == torch.float8_e4m3fn:
|
||||
w, w_s = per_block_cast_to_fp8(w, block_shape)
|
||||
elif quant_dtype == "nvfp4":
|
||||
raise RuntimeError("blocked quantization not supported for nvfp4")
|
||||
else:
|
||||
raise RuntimeError(f"Unsupported quant type {quant_dtype}")
|
||||
else:
|
||||
if quant_dtype == torch.int8:
|
||||
w, w_s = ops.scaled_int8_quant(
|
||||
w, w_s, use_per_token_if_dynamic=per_token_quant
|
||||
)
|
||||
elif quant_dtype == torch.float8_e4m3fn:
|
||||
w, w_s = ops.scaled_fp8_quant(
|
||||
w, w_s, use_per_token_if_dynamic=per_token_quant
|
||||
)
|
||||
elif quant_dtype == "nvfp4":
|
||||
assert not per_token_quant
|
||||
w_amax = torch.abs(w).max().to(torch.float32)
|
||||
w_gs = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / w_amax
|
||||
w, w_s = ops.scaled_fp4_quant(w, w_gs)
|
||||
else:
|
||||
raise RuntimeError(f"Unsupported quant type {quant_dtype}")
|
||||
|
||||
return w, w_s, w_gs
|
||||
|
||||
|
||||
def make_test_weight(
|
||||
e: int,
|
||||
rows: int,
|
||||
cols: int,
|
||||
in_dtype: torch.dtype = torch.bfloat16,
|
||||
quant_dtype: torch.dtype | str | None = None,
|
||||
block_shape: list[int] | None = None,
|
||||
per_out_ch_quant: bool = False,
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None, torch.Tensor | None]:
|
||||
w_16 = torch.randn((e, rows, cols), device="cuda", dtype=in_dtype) / 15
|
||||
w_gs = None
|
||||
|
||||
if quant_dtype is not None:
|
||||
w_l = [None] * e
|
||||
w_s_l = [None] * e
|
||||
w_gs_l = [None] * e
|
||||
for idx in range(e):
|
||||
w_l[idx], w_s_l[idx], w_gs_l[idx] = moe_quantize_weights(
|
||||
w_16[idx], None, quant_dtype, per_out_ch_quant, block_shape
|
||||
)
|
||||
|
||||
w = torch.stack(w_l)
|
||||
w_s = torch.stack(w_s_l)
|
||||
if e > 0 and w_gs_l[0] is not None:
|
||||
w_gs = torch.stack(w_gs_l)
|
||||
if w_s.ndim == 2:
|
||||
assert w_s.shape[-1] == 1
|
||||
w_s = w_s.view(-1, 1, 1)
|
||||
|
||||
if block_shape is not None:
|
||||
block_n, block_k = block_shape
|
||||
n_tiles = (rows + block_n - 1) // block_n
|
||||
k_tiles = (cols + block_k - 1) // block_k
|
||||
assert w_s.shape == (e, n_tiles, k_tiles)
|
||||
else:
|
||||
w = w_16
|
||||
w_s = None
|
||||
w_gs = None
|
||||
|
||||
return w_16, w, w_s, w_gs
|
||||
|
||||
|
||||
def make_test_weights(
|
||||
e: int,
|
||||
n: int,
|
||||
k: int,
|
||||
in_dtype: torch.dtype = torch.bfloat16,
|
||||
quant_dtype: torch.dtype | str | None = None,
|
||||
block_shape: list[int] | None = None,
|
||||
per_out_ch_quant: bool = False,
|
||||
make_gate: bool = True,
|
||||
) -> tuple[
|
||||
tuple[torch.Tensor, torch.Tensor, torch.Tensor | None, torch.Tensor | None],
|
||||
tuple[torch.Tensor, torch.Tensor, torch.Tensor | None, torch.Tensor | None],
|
||||
]:
|
||||
return (
|
||||
make_test_weight(
|
||||
e,
|
||||
(2 if make_gate else 1) * n,
|
||||
k,
|
||||
in_dtype,
|
||||
quant_dtype,
|
||||
block_shape,
|
||||
per_out_ch_quant,
|
||||
),
|
||||
make_test_weight(e, k, n, in_dtype, quant_dtype, block_shape, per_out_ch_quant),
|
||||
)
|
||||
|
||||
|
||||
def per_token_cast_to_fp8(
|
||||
x: torch.Tensor, block_size: int = 128
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
assert x.dim() == 2
|
||||
m, n = x.shape
|
||||
pad_size = (block_size - (n % block_size)) % block_size
|
||||
x = torch.nn.functional.pad(x, (0, pad_size), value=0) if pad_size > 0 else x
|
||||
x_view = x.view(m, -1, block_size)
|
||||
x_amax = x_view.abs().float().amax(dim=2).view(m, -1).clamp(1e-4)
|
||||
fp8_data = (x_view * (448.0 / x_amax.unsqueeze(2))).to(torch.float8_e4m3fn)
|
||||
return fp8_data.view(m, n + pad_size)[:, :n], (x_amax / 448.0).view(m, -1)
|
||||
|
||||
|
||||
def make_test_quant_config(
|
||||
e: int,
|
||||
n: int,
|
||||
k: int,
|
||||
in_dtype: torch.dtype,
|
||||
quant_dtype: torch.dtype | str | None = None,
|
||||
per_act_token_quant: bool = False,
|
||||
block_shape: list[int] | None = None,
|
||||
make_gate: bool = True,
|
||||
) -> tuple[torch.Tensor, torch.Tensor, FusedMoEQuantConfig]:
|
||||
(_, w1, w1_s, w1_gs), (_, w2, w2_s, w2_gs) = make_test_weights(
|
||||
e,
|
||||
n,
|
||||
k,
|
||||
in_dtype,
|
||||
quant_dtype,
|
||||
per_out_ch_quant=per_act_token_quant,
|
||||
block_shape=block_shape,
|
||||
make_gate=make_gate,
|
||||
)
|
||||
|
||||
# Hacky/trivial scales for nvfp4.
|
||||
a1_gscale: torch.Tensor | None = None
|
||||
a2_gscale: torch.Tensor | None = None
|
||||
if quant_dtype == "nvfp4":
|
||||
a1_gscale = torch.ones((e,), device="cuda", dtype=torch.float32)
|
||||
a2_gscale = torch.ones((e,), device="cuda", dtype=torch.float32)
|
||||
a1_scale = a1_gscale
|
||||
a2_scale = a2_gscale
|
||||
else:
|
||||
a1_scale = None
|
||||
a2_scale = None
|
||||
|
||||
return (
|
||||
w1,
|
||||
w2,
|
||||
FusedMoEQuantConfig.make(
|
||||
quant_dtype,
|
||||
per_act_token_quant=per_act_token_quant,
|
||||
block_shape=block_shape,
|
||||
w1_scale=w1_s,
|
||||
w2_scale=w2_s,
|
||||
a1_gscale=a1_gscale,
|
||||
a2_gscale=a2_gscale,
|
||||
a1_scale=a1_scale,
|
||||
a2_scale=a2_scale,
|
||||
# TODO: make sure this is handled properly
|
||||
g1_alphas=(1 / w1_gs) if w1_gs is not None else None,
|
||||
g2_alphas=(1 / w2_gs) if w2_gs is not None else None,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def fused_moe(
|
||||
hidden_states: torch.Tensor,
|
||||
w1: torch.Tensor,
|
||||
w2: torch.Tensor,
|
||||
score: torch.Tensor,
|
||||
topk: int,
|
||||
renormalize: bool = False,
|
||||
quant_config: FusedMoEQuantConfig | None = None,
|
||||
global_num_experts: int = -1,
|
||||
expert_map: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
topk_weights, topk_ids, _ = fused_topk(
|
||||
hidden_states, score.float(), topk, renormalize
|
||||
)
|
||||
return fused_experts(
|
||||
hidden_states,
|
||||
w1,
|
||||
w2,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
global_num_experts=global_num_experts,
|
||||
expert_map=expert_map,
|
||||
quant_config=quant_config,
|
||||
)
|
||||
|
||||
|
||||
# CustomOp?
|
||||
class BaselineMM(torch.nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
b: torch.Tensor,
|
||||
out_dtype: torch.dtype,
|
||||
):
|
||||
super().__init__()
|
||||
self.b = b.to(dtype=torch.float32)
|
||||
self.out_dtype = out_dtype
|
||||
|
||||
def forward(self, a: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor | None]:
|
||||
return torch.mm(a.to(dtype=torch.float32), self.b).to(self.out_dtype), None
|
||||
|
||||
|
||||
class TestMLP(torch.nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
w1: torch.Tensor,
|
||||
w2: torch.Tensor,
|
||||
out_dtype: torch.dtype,
|
||||
):
|
||||
super().__init__()
|
||||
self.gate_up_proj = BaselineMM(w1, out_dtype)
|
||||
self.down_proj = BaselineMM(w2, out_dtype)
|
||||
self.act_fn = SiluAndMul()
|
||||
|
||||
def forward(self, x):
|
||||
x, _ = self.gate_up_proj(x)
|
||||
x = self.act_fn(x)
|
||||
x, _ = self.down_proj(x)
|
||||
return x
|
||||
|
||||
|
||||
def make_naive_shared_experts(
|
||||
N: int,
|
||||
K: int,
|
||||
in_dtype: torch.dtype = torch.bfloat16,
|
||||
) -> torch.nn.Module:
|
||||
w1 = torch.randn((K, N * 2), device="cuda", dtype=in_dtype) / 15
|
||||
w2 = torch.randn((N, K), device="cuda", dtype=in_dtype) / 15
|
||||
return TestMLP(w1, w2, out_dtype=in_dtype)
|
||||
|
||||
|
||||
class RealMLP(torch.nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size: int,
|
||||
intermediate_size: int,
|
||||
w1: torch.Tensor,
|
||||
w2: torch.Tensor,
|
||||
hidden_act: str = "silu",
|
||||
quant_config=None,
|
||||
reduce_results: bool = True,
|
||||
prefix: str = "",
|
||||
w1_s: torch.Tensor | None = None,
|
||||
w2_s: torch.Tensor | None = None,
|
||||
) -> None:
|
||||
from vllm.model_executor.layers.linear import (
|
||||
MergedColumnParallelLinear,
|
||||
RowParallelLinear,
|
||||
)
|
||||
|
||||
super().__init__()
|
||||
self.gate_up_proj = MergedColumnParallelLinear(
|
||||
hidden_size,
|
||||
[intermediate_size] * 2,
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.gate_up_proj",
|
||||
)
|
||||
self.gate_up_proj.register_parameter(
|
||||
"weight", torch.nn.Parameter(w1, requires_grad=False)
|
||||
)
|
||||
self.gate_up_proj.register_parameter(
|
||||
"weight_scale", torch.nn.Parameter(w1_s, requires_grad=False)
|
||||
)
|
||||
self.gate_up_proj.register_parameter(
|
||||
"input_scale", None
|
||||
) # torch.nn.Parameter(None, requires_grad=False))
|
||||
self.down_proj = RowParallelLinear(
|
||||
intermediate_size,
|
||||
hidden_size,
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
reduce_results=reduce_results,
|
||||
prefix=f"{prefix}.down_proj",
|
||||
)
|
||||
self.down_proj.register_parameter(
|
||||
"weight", torch.nn.Parameter(w2, requires_grad=False)
|
||||
)
|
||||
self.down_proj.register_parameter(
|
||||
"weight_scale", torch.nn.Parameter(w2_s, requires_grad=False)
|
||||
)
|
||||
self.down_proj.register_parameter(
|
||||
"input_scale", None
|
||||
) # torch.nn.Parameter(None, requires_grad=False))
|
||||
if hidden_act != "silu":
|
||||
raise ValueError(
|
||||
f"Unsupported activation: {hidden_act}. Only silu is supported for now."
|
||||
)
|
||||
self.act_fn = SiluAndMul()
|
||||
|
||||
def forward(self, x):
|
||||
gate_up, _ = self.gate_up_proj(x)
|
||||
x = self.act_fn(gate_up)
|
||||
x, _ = self.down_proj(x)
|
||||
return x
|
||||
|
||||
|
||||
def make_shared_experts(
|
||||
N: int,
|
||||
K: int,
|
||||
in_dtype: torch.dtype = torch.bfloat16,
|
||||
quant_dtype: torch.dtype | str | None = None,
|
||||
) -> torch.nn.Module:
|
||||
from vllm.model_executor.layers.quantization.fp8 import Fp8Config
|
||||
|
||||
(_, w1, w1_s, _), (_, w2, w2_s, _) = make_test_weights(
|
||||
1,
|
||||
N,
|
||||
K,
|
||||
in_dtype=in_dtype,
|
||||
quant_dtype=quant_dtype,
|
||||
)
|
||||
old_dtype = torch.get_default_dtype()
|
||||
try:
|
||||
torch.set_default_dtype(in_dtype)
|
||||
if quant_dtype == torch.float8_e4m3fn:
|
||||
w1 = w1[0].transpose(0, 1)
|
||||
w2 = w2[0].transpose(0, 1)
|
||||
w1_s = w1_s[0].transpose(0, 1) if w1_s is not None else None
|
||||
w2_s = w2_s[0].transpose(0, 1) if w2_s is not None else None
|
||||
quant_config = Fp8Config(True)
|
||||
else:
|
||||
w1 = w1[0]
|
||||
w2 = w2[0]
|
||||
w1_s = None
|
||||
w2_s = None
|
||||
quant_config = None
|
||||
|
||||
return RealMLP(K, N, w1, w2, "silu", quant_config, w1_s=w1_s, w2_s=w2_s)
|
||||
finally:
|
||||
torch.set_default_dtype(old_dtype)
|
||||
315
tests/kernels/quant_utils.py
Normal file
315
tests/kernels/quant_utils.py
Normal file
@@ -0,0 +1,315 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import group_broadcast
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.math_utils import round_up
|
||||
|
||||
# Using the default value (240.0) from pytorch will cause accuracy
|
||||
# issue on dynamic quantization models. Here use 224.0 for rocm.
|
||||
ROCM_FP8FNUZ_MAX = 224.0
|
||||
FP8_DTYPE = current_platform.fp8_dtype()
|
||||
|
||||
|
||||
def as_float32_tensor(x: float | torch.Tensor) -> torch.Tensor:
|
||||
return torch.as_tensor(x, dtype=torch.float32, device="cuda")
|
||||
|
||||
|
||||
def ref_dynamic_per_token_quant(
|
||||
x: torch.Tensor, quant_dtype: torch.dtype, scale_ub: torch.Tensor | None = None
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
assert quant_dtype in [torch.int8, FP8_DTYPE]
|
||||
if scale_ub is not None:
|
||||
assert quant_dtype == FP8_DTYPE
|
||||
|
||||
qtype_traits = (
|
||||
torch.iinfo(quant_dtype)
|
||||
if quant_dtype == torch.int8
|
||||
else torch.finfo(quant_dtype)
|
||||
)
|
||||
use_fp8fnuz = (
|
||||
current_platform.is_fp8_fnuz() and quant_dtype == current_platform.fp8_dtype()
|
||||
)
|
||||
qtype_traits_max = ROCM_FP8FNUZ_MAX if use_fp8fnuz else qtype_traits.max
|
||||
qtype_traits_min = -ROCM_FP8FNUZ_MAX if use_fp8fnuz else qtype_traits.min
|
||||
qtype_max = as_float32_tensor(qtype_traits_max)
|
||||
s_1 = as_float32_tensor(1.0)
|
||||
s_512 = as_float32_tensor(512.0)
|
||||
|
||||
# For fp8, in order to match the cuda kernel output, we have to do exactly
|
||||
# the same operations as in the corresponding fp8 kernel to prevent
|
||||
# rounding errors.
|
||||
|
||||
# Compute scales
|
||||
x_token_max, _ = x.abs().max(dim=-1)
|
||||
x_token_max = as_float32_tensor(x_token_max)
|
||||
if scale_ub is not None:
|
||||
x_token_max = x_token_max.clamp(max=scale_ub)
|
||||
scales = (x_token_max / qtype_max)[:, None]
|
||||
|
||||
# Quant
|
||||
if quant_dtype == torch.int8:
|
||||
iscales = as_float32_tensor(s_1 / scales)
|
||||
torch_out = as_float32_tensor(x) * iscales
|
||||
torch_out = torch_out.round()
|
||||
torch_out = torch_out.clamp(qtype_traits_min, qtype_traits_max).to(quant_dtype)
|
||||
else:
|
||||
assert quant_dtype == FP8_DTYPE
|
||||
min_scaling_factor = s_1 / (qtype_max * s_512)
|
||||
scales = scales.clamp(min=min_scaling_factor)
|
||||
torch_out = as_float32_tensor(x) / scales
|
||||
torch_out = torch_out.clamp(qtype_traits_min, qtype_traits_max).to(quant_dtype)
|
||||
|
||||
return torch_out, scales
|
||||
|
||||
|
||||
# The int8 version is very similar. Incorporate the int8 version, like in
|
||||
# ref_dynamic_per_token_quant, when we have a dynamic_per_tensor int8 quant
|
||||
# kernel
|
||||
def ref_dynamic_per_tensor_fp8_quant(
|
||||
x: torch.Tensor,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
fp8_traits = torch.finfo(FP8_DTYPE)
|
||||
fp8_traits_max = (
|
||||
ROCM_FP8FNUZ_MAX
|
||||
if current_platform.is_rocm() and current_platform.is_fp8_fnuz()
|
||||
else fp8_traits.max
|
||||
)
|
||||
fp8_traits_min = (
|
||||
-ROCM_FP8FNUZ_MAX
|
||||
if current_platform.is_rocm() and current_platform.is_fp8_fnuz()
|
||||
else fp8_traits.min
|
||||
)
|
||||
fp8_max = as_float32_tensor(fp8_traits_max)
|
||||
one = as_float32_tensor(1.0)
|
||||
|
||||
# For fp8, in order to match the cuda kernel output, we have to do exactly
|
||||
# the same operations as in the corresponding fp8 kernel to prevent
|
||||
# rounding errors.
|
||||
|
||||
x_max = as_float32_tensor(x.abs().max())
|
||||
ref_scale = x_max / fp8_max
|
||||
ref_iscale = one / ref_scale
|
||||
ref_out = (
|
||||
(as_float32_tensor(x) * ref_iscale)
|
||||
.clamp(fp8_traits_min, fp8_traits_max)
|
||||
.to(FP8_DTYPE)
|
||||
)
|
||||
return ref_out, ref_scale.view(1)
|
||||
|
||||
|
||||
def native_w8a8_block_matmul(
|
||||
A: torch.Tensor,
|
||||
B: torch.Tensor,
|
||||
As: torch.Tensor,
|
||||
Bs: torch.Tensor,
|
||||
block_size: list[int],
|
||||
output_dtype: torch.dtype,
|
||||
compute_type: torch.dtype = torch.float32,
|
||||
) -> torch.Tensor:
|
||||
"""This function performs matrix multiplication with block-wise
|
||||
quantization using native torch.
|
||||
It is agnostic to the input data type and can be used for both int8 and
|
||||
fp8 data types.
|
||||
|
||||
It takes two input tensors `A` and `B` (int8) with scales `As` and
|
||||
`Bs` (float32).
|
||||
The output is returned in the specified `output_dtype`.
|
||||
"""
|
||||
A = A.to(compute_type)
|
||||
B = B.to(compute_type)
|
||||
assert A.shape[-1] == B.shape[-1]
|
||||
assert B.ndim == 2 and B.is_contiguous() and Bs.ndim == 2
|
||||
assert len(block_size) == 2
|
||||
block_n, block_k = block_size[0], block_size[1]
|
||||
assert (A.shape[-1] + block_k - 1) // block_k == As.shape[-1]
|
||||
assert A.shape[:-1] == As.shape[:-1]
|
||||
|
||||
M = A.numel() // A.shape[-1]
|
||||
N, K = B.shape
|
||||
origin_C_shape = A.shape[:-1] + (N,)
|
||||
A = A.reshape(M, A.shape[-1])
|
||||
As = As.reshape(M, As.shape[-1])
|
||||
n_tiles = (N + block_n - 1) // block_n
|
||||
k_tiles = (K + block_k - 1) // block_k
|
||||
assert n_tiles == Bs.shape[0], f"{n_tiles} == {Bs.shape[0]}"
|
||||
assert k_tiles == Bs.shape[1], f"{k_tiles} == {Bs.shape[1]}"
|
||||
|
||||
C_shape = (M, N)
|
||||
C = torch.zeros(C_shape, dtype=compute_type, device=A.device)
|
||||
|
||||
A_tiles = [A[:, i * block_k : min((i + 1) * block_k, K)] for i in range(k_tiles)]
|
||||
B_tiles = [
|
||||
[
|
||||
B[
|
||||
j * block_n : min((j + 1) * block_n, N),
|
||||
i * block_k : min((i + 1) * block_k, K),
|
||||
]
|
||||
for i in range(k_tiles)
|
||||
]
|
||||
for j in range(n_tiles)
|
||||
]
|
||||
C_tiles = [C[:, j * block_n : min((j + 1) * block_n, N)] for j in range(n_tiles)]
|
||||
As_tiles = [As[:, i : i + 1] for i in range(k_tiles)]
|
||||
|
||||
for i in range(k_tiles):
|
||||
for j in range(n_tiles):
|
||||
a = A_tiles[i]
|
||||
b = B_tiles[j][i]
|
||||
c = C_tiles[j]
|
||||
s = As_tiles[i] * Bs[j][i]
|
||||
c[:, :] += torch.matmul(a, b.t()) * s
|
||||
|
||||
C = C.reshape(origin_C_shape).to(output_dtype)
|
||||
return C
|
||||
|
||||
|
||||
def native_per_token_group_quant_fp8(
|
||||
x, group_size, eps=1e-10, dtype=torch.float8_e4m3fn
|
||||
):
|
||||
"""Function to perform per-token-group quantization on an input tensor
|
||||
`x` using native torch."""
|
||||
assert x.shape[-1] % group_size == 0, (
|
||||
"the last dimension of `x` must be divisible by `group_size`"
|
||||
)
|
||||
assert x.is_contiguous(), "`x` is not contiguous"
|
||||
|
||||
finfo = torch.finfo(dtype)
|
||||
fp8_min = finfo.min
|
||||
fp8_max = finfo.max
|
||||
|
||||
x_ = x.reshape(x.numel() // group_size, group_size)
|
||||
amax = x_.abs().max(dim=-1, keepdim=True)[0].clamp(min=eps).to(torch.float32)
|
||||
x_s = amax / fp8_max
|
||||
x_q = (x_ / x_s).clamp(min=fp8_min, max=fp8_max).to(dtype)
|
||||
x_q = x_q.reshape(x.shape)
|
||||
x_s = x_s.reshape(x.shape[:-1] + (x.shape[-1] // group_size,))
|
||||
|
||||
return x_q, x_s
|
||||
|
||||
|
||||
def native_per_token_group_quant_int8(x, group_size, eps=1e-10, dtype=torch.int8):
|
||||
"""Function to perform per-token-group quantization on an input tensor
|
||||
`x` using native torch.
|
||||
|
||||
It converts the tensor values into int8 values and returns the
|
||||
quantized tensor along with the scaling factor used for quantization.
|
||||
"""
|
||||
assert x.shape[-1] % group_size == 0, (
|
||||
"the last dimension of `x` must be divisible by `group_size`"
|
||||
)
|
||||
assert x.is_contiguous(), "`x` is not contiguous"
|
||||
|
||||
iinfo = torch.iinfo(dtype)
|
||||
int8_min = iinfo.min
|
||||
int8_max = iinfo.max
|
||||
|
||||
x_ = x.reshape(x.numel() // group_size, group_size)
|
||||
# Use float32 for scale calculation for stability
|
||||
amax = x_.abs().max(dim=-1, keepdim=True)[0].clamp(min=eps).to(torch.float32)
|
||||
x_s = amax / int8_max
|
||||
x_q = (
|
||||
(x_.to(torch.float32) / x_s).round().clamp(min=int8_min, max=int8_max).to(dtype)
|
||||
) # Round before clamping
|
||||
x_q = x_q.reshape(x.shape)
|
||||
x_s = x_s.reshape(x.shape[:-1] + (x.shape[-1] // group_size,))
|
||||
|
||||
return x_q, x_s
|
||||
|
||||
|
||||
DEFAULT_BLOCK_SHAPE = [128, 128]
|
||||
|
||||
|
||||
def per_block_cast_to_int8(
|
||||
x: torch.Tensor,
|
||||
block_shape: list[int] = DEFAULT_BLOCK_SHAPE,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
block_m, block_n = block_shape
|
||||
assert x.dim() == 2
|
||||
m, n = x.shape
|
||||
x_padded = torch.zeros(
|
||||
(round_up(m, block_m), round_up(n, block_n)), dtype=x.dtype, device=x.device
|
||||
)
|
||||
x_padded[:m, :n] = x
|
||||
x_view = x_padded.view(-1, block_m, x_padded.size(1) // block_n, block_n)
|
||||
x_amax = x_view.abs().float().amax(dim=(1, 3), keepdim=True).clamp(1e-4)
|
||||
x_scaled = (x_view * (256.0 / x_amax)).to(torch.int8)
|
||||
x_scaled_sub = x_scaled.view_as(x_padded)[:m, :n].contiguous()
|
||||
scales = (x_amax / 256.0).view(x_view.size(0), x_view.size(2))
|
||||
return x_scaled_sub, scales
|
||||
|
||||
|
||||
def dequant(
|
||||
t: torch.Tensor,
|
||||
scale: torch.Tensor | None,
|
||||
block_shape: list[int] | None,
|
||||
per_act_token_quant: bool,
|
||||
out_dtype: torch.dtype | None = torch.float32,
|
||||
) -> torch.Tensor:
|
||||
if scale is not None:
|
||||
f32 = torch.float32
|
||||
if per_act_token_quant or block_shape is None:
|
||||
return (t.to(f32) * scale).to(out_dtype)
|
||||
else:
|
||||
return (t.to(f32) * group_broadcast(scale, t.shape)).to(out_dtype)
|
||||
else:
|
||||
return t.to(out_dtype)
|
||||
|
||||
|
||||
def batched_dequant(
|
||||
t: torch.Tensor,
|
||||
scale: torch.Tensor | None,
|
||||
block_shape: list[int] | None,
|
||||
per_act_token_quant: bool,
|
||||
out_dtype: torch.dtype | None = torch.float32,
|
||||
) -> torch.Tensor:
|
||||
if scale is not None:
|
||||
assert t.shape[0] == scale.shape[0]
|
||||
out = torch.empty_like(t, dtype=out_dtype)
|
||||
for e in range(t.shape[0]):
|
||||
out[e] = dequant(
|
||||
t[e], scale[e], block_shape, per_act_token_quant, out_dtype
|
||||
)
|
||||
return out
|
||||
|
||||
return t.to(out_dtype)
|
||||
|
||||
|
||||
def native_batched_masked_quant_matmul(
|
||||
A: torch.Tensor,
|
||||
B: torch.Tensor,
|
||||
C: torch.Tensor,
|
||||
num_expert_tokens: torch.Tensor,
|
||||
A_scale: torch.Tensor | None = None,
|
||||
B_scale: torch.Tensor | None = None,
|
||||
block_shape: list[int] | None = None,
|
||||
per_act_token_quant: bool = False,
|
||||
) -> torch.Tensor:
|
||||
num_expert_tokens_cpu = num_expert_tokens.clone()
|
||||
num_expert_tokens_cpu = num_expert_tokens_cpu.to(device="cpu")
|
||||
num_experts = num_expert_tokens.size(0)
|
||||
|
||||
for e in range(num_experts):
|
||||
num_tokens = num_expert_tokens_cpu[e]
|
||||
if A.dtype.itemsize == 1 and block_shape is not None:
|
||||
assert A_scale is not None and B_scale is not None
|
||||
tmp = native_w8a8_block_matmul(
|
||||
A[e], B[e], A_scale[e], B_scale[e], block_shape, C.dtype
|
||||
)
|
||||
C[e, :num_tokens, :] = tmp[:num_tokens, :]
|
||||
elif A.dtype.itemsize == 1 and block_shape is None:
|
||||
assert A_scale is not None and B_scale is not None
|
||||
A_dq = dequant(A[e], A_scale[e], block_shape, per_act_token_quant)
|
||||
B_dq = dequant(B[e], B_scale[e], block_shape, per_act_token_quant)
|
||||
C[e, :num_tokens, :] = (A_dq[:num_tokens] @ B_dq.transpose(0, 1)).to(
|
||||
C.dtype
|
||||
)
|
||||
else:
|
||||
assert A_scale is None
|
||||
assert B_scale is None
|
||||
C[e, :num_tokens, :] = A[e, :num_tokens, :] @ B[e].transpose(0, 1)
|
||||
|
||||
return C
|
||||
76
tests/kernels/quantization/nvfp4_utils.py
Normal file
76
tests/kernels/quantization/nvfp4_utils.py
Normal file
@@ -0,0 +1,76 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import torch
|
||||
|
||||
from vllm._custom_ops import scaled_fp4_quant
|
||||
from vllm.scalar_type import scalar_types
|
||||
|
||||
FLOAT4_E2M1_MAX = scalar_types.float4_e2m1f.max()
|
||||
FLOAT8_E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max
|
||||
|
||||
kE2M1ToFloat = torch.tensor(
|
||||
[0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0], dtype=torch.float32
|
||||
)
|
||||
|
||||
|
||||
def convert_swizzled_to_linear(a_sf_swizzled: torch.Tensor, m, k, block_size):
|
||||
m_tiles = (m + 128 - 1) // 128
|
||||
f = block_size * 4
|
||||
k_tiles = (k + f - 1) // f
|
||||
tmp = torch.reshape(a_sf_swizzled, (1, m_tiles, k_tiles, 32, 4, 4))
|
||||
tmp = torch.permute(tmp, (0, 1, 4, 3, 2, 5))
|
||||
out = tmp.reshape(m_tiles * 128, k_tiles * f // block_size)
|
||||
return out[0:m, 0:k]
|
||||
|
||||
|
||||
def dequantize_nvfp4_to_dtype(
|
||||
tensor_fp4, tensor_sf, global_scale, dtype, device, block_size=16
|
||||
):
|
||||
"""Dequantize the fp4 tensor back to high precision."""
|
||||
# Two fp4 values are packed into one uint8.
|
||||
assert tensor_fp4.dtype == torch.uint8
|
||||
m, packed_k = tensor_fp4.shape
|
||||
k = packed_k * 2
|
||||
tensor_f32 = break_fp4_bytes(tensor_fp4, dtype)
|
||||
tensor_f32 = tensor_f32.reshape(m, k // block_size, block_size)
|
||||
tensor_sf = tensor_sf.view(torch.float8_e4m3fn)
|
||||
tensor_sf = convert_swizzled_to_linear(tensor_sf, m, k, block_size)
|
||||
tensor_sf_dtype = tensor_sf.to(torch.float32) / global_scale
|
||||
|
||||
# scale the tensor
|
||||
out = (tensor_f32 * tensor_sf_dtype.unsqueeze(-1)).reshape(m, k)
|
||||
return out.to(dtype=dtype)
|
||||
|
||||
|
||||
def break_fp4_bytes(a, dtype):
|
||||
assert a.dtype == torch.uint8
|
||||
m, n = a.shape
|
||||
|
||||
# Vectorized nibble processing
|
||||
a_flat = a.flatten()
|
||||
high = (a_flat & 0xF0) >> 4 # Upper nibbles
|
||||
low = a_flat & 0x0F # Lower nibbles
|
||||
|
||||
# Combine nibbles for batch processing
|
||||
combined = torch.stack((low, high), dim=1).flatten()
|
||||
|
||||
# Vectorized sign and magnitude extraction
|
||||
signs = (combined & 0x08).to(torch.bool) # Sign bits
|
||||
abs_vals = (combined & 0x07).to(torch.long) # Magnitude indices
|
||||
|
||||
# Device-aware lookup and sign application
|
||||
kE2M1 = kE2M1ToFloat.to(device=a.device)
|
||||
values = kE2M1[abs_vals] * torch.where(signs, -1.0, 1.0)
|
||||
|
||||
# Reshape to final form
|
||||
return values.reshape(m, n * 2).to(dtype=dtype)
|
||||
|
||||
|
||||
def get_nvfp4_global_scale(a: torch.Tensor):
|
||||
return (FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) / torch.abs(a).max().to(torch.float32)
|
||||
|
||||
|
||||
def quant_nvfp4_tensor(a: torch.Tensor):
|
||||
a_global_scale = get_nvfp4_global_scale(a)
|
||||
a_quant, a_block_scale = scaled_fp4_quant(a, a_global_scale)
|
||||
return a_quant, a_block_scale, a_global_scale
|
||||
127
tests/kernels/quantization/test_allspark_gemm.py
Normal file
127
tests/kernels/quantization/test_allspark_gemm.py
Normal file
@@ -0,0 +1,127 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from tests.kernels.utils import DEFAULT_OPCHECK_TEST_UTILS, opcheck
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.model_executor.layers.quantization.utils.allspark_utils import (
|
||||
ALLSPARK_AMPERE_K_ALIGN,
|
||||
ALLSPARK_AMPERE_M_CUBLAS_THRESHOLD,
|
||||
ALLSPARK_AMPERE_N_ALIGN,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import quantize_weights
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.scalar_type import scalar_types
|
||||
|
||||
|
||||
def is_gptq_allspark_supported(min_capability: int, max_capability: int) -> bool:
|
||||
if not current_platform.is_cuda():
|
||||
return False
|
||||
|
||||
capability = current_platform.get_device_capability()
|
||||
assert capability is not None
|
||||
|
||||
return (
|
||||
capability.to_int() >= min_capability and capability.to_int() <= max_capability
|
||||
)
|
||||
|
||||
|
||||
MNK_FACTORS = [
|
||||
(1, 4, 8),
|
||||
(13, 17, 67),
|
||||
(26, 37, 13),
|
||||
(48, 16, 24),
|
||||
(67, 13, 88),
|
||||
(257, 13, 11),
|
||||
(658, 13, 11),
|
||||
(1033, 9, 17),
|
||||
]
|
||||
|
||||
DTYPES = [torch.float16, torch.bfloat16]
|
||||
HAS_ZP_OPTS = [False, True]
|
||||
|
||||
|
||||
def compute_max_diff(output, output_ref):
|
||||
return torch.mean(torch.abs(output - output_ref)) / torch.mean(
|
||||
torch.abs(output_ref)
|
||||
)
|
||||
|
||||
|
||||
def rand_data(shape, dtype=torch.float16):
|
||||
return torch.randn(shape, dtype=dtype, device="cuda")
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
not is_gptq_allspark_supported(80, 89),
|
||||
reason="AllSpark Ampere kernel is not supported on this GPU type.",
|
||||
)
|
||||
@pytest.mark.parametrize("mnk_factors", MNK_FACTORS)
|
||||
@pytest.mark.parametrize("group_size", [-1])
|
||||
@pytest.mark.parametrize("has_zp", HAS_ZP_OPTS)
|
||||
@pytest.mark.parametrize("dtype", DTYPES)
|
||||
def test_gptq_allspark_gemm_ampere(mnk_factors, group_size, has_zp, dtype):
|
||||
m_factor, n_factor, k_factor = mnk_factors
|
||||
m = m_factor
|
||||
n = n_factor * ALLSPARK_AMPERE_N_ALIGN
|
||||
k = k_factor * ALLSPARK_AMPERE_K_ALIGN
|
||||
|
||||
input = rand_data((m, k), dtype=dtype)
|
||||
weight = rand_data((k, n), dtype=dtype)
|
||||
|
||||
# Quantize (and apply act_order if provided)
|
||||
w_ref, qw, s, zp = quantize_weights(
|
||||
weight, scalar_types.uint8b128, group_size, has_zp
|
||||
)
|
||||
|
||||
qw = qw.to(torch.uint8)
|
||||
if has_zp:
|
||||
zp = zp.to(dtype)
|
||||
properties = torch.cuda.get_device_properties(qw.device.index)
|
||||
sm_count = properties.multi_processor_count
|
||||
sm_version = properties.major * 10 + properties.minor
|
||||
|
||||
n_32align = (n + 32 - 1) // 32 * 32
|
||||
|
||||
qw_reorder, s_reorder, zp_reorder = ops.allspark_repack_weight(qw, s, zp, has_zp)
|
||||
opcheck(
|
||||
torch.ops._C.rearrange_kn_weight_as_n32k16_order,
|
||||
(qw, s, zp, has_zp, qw_reorder, s_reorder, zp_reorder, k, n, n_32align),
|
||||
)
|
||||
|
||||
opcheck(
|
||||
torch.ops._C.allspark_w8a16_gemm,
|
||||
(
|
||||
input,
|
||||
qw_reorder,
|
||||
s_reorder,
|
||||
zp_reorder,
|
||||
n,
|
||||
group_size,
|
||||
sm_count,
|
||||
sm_version,
|
||||
ALLSPARK_AMPERE_M_CUBLAS_THRESHOLD,
|
||||
has_zp,
|
||||
True,
|
||||
),
|
||||
test_utils=DEFAULT_OPCHECK_TEST_UTILS,
|
||||
)
|
||||
output = ops.allspark_w8a16_gemm(
|
||||
input,
|
||||
qw_reorder,
|
||||
s_reorder,
|
||||
zp_reorder,
|
||||
n,
|
||||
group_size,
|
||||
sm_count,
|
||||
sm_version,
|
||||
ALLSPARK_AMPERE_M_CUBLAS_THRESHOLD,
|
||||
has_zp,
|
||||
True,
|
||||
)
|
||||
|
||||
output_ref = torch.matmul(input, w_ref)
|
||||
torch.cuda.synchronize()
|
||||
max_diff = compute_max_diff(output, output_ref)
|
||||
|
||||
assert max_diff < 0.04
|
||||
49
tests/kernels/quantization/test_awq.py
Normal file
49
tests/kernels/quantization/test_awq.py
Normal file
@@ -0,0 +1,49 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from tests.kernels.utils import opcheck
|
||||
from vllm import _custom_ops as ops # noqa: F401
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
not hasattr(torch.ops._C, "awq_dequantize"),
|
||||
reason="AWQ is not supported on this GPU type.",
|
||||
)
|
||||
def test_awq_dequantize_opcheck(monkeypatch: pytest.MonkeyPatch):
|
||||
with monkeypatch.context() as m:
|
||||
m.setenv("VLLM_USE_TRITON_AWQ", "0")
|
||||
qweight = torch.randint(
|
||||
-2000000000, 2000000000, (8192, 256), device="cuda", dtype=torch.int32
|
||||
)
|
||||
scales = torch.rand((64, 2048), device="cuda", dtype=torch.float16)
|
||||
zeros = torch.empty((64, 256), device="cuda", dtype=torch.int32)
|
||||
split_k_iters = 0
|
||||
thx = 0
|
||||
thy = 0
|
||||
opcheck(
|
||||
torch.ops._C.awq_dequantize,
|
||||
(qweight, scales, zeros, split_k_iters, thx, thy),
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="Not working; needs investigation.")
|
||||
@pytest.mark.skipif(
|
||||
not hasattr(torch.ops._C, "awq_gemm"),
|
||||
reason="AWQ is not supported on this GPU type.",
|
||||
)
|
||||
def test_awq_gemm_opcheck(monkeypatch: pytest.MonkeyPatch):
|
||||
with monkeypatch.context() as m:
|
||||
m.setenv("VLLM_USE_TRITON_AWQ", "0")
|
||||
input = torch.rand((2, 8192), device="cuda", dtype=torch.float16)
|
||||
qweight = torch.randint(
|
||||
-2000000000, 2000000000, (8192, 256), device="cuda", dtype=torch.int32
|
||||
)
|
||||
scales = torch.empty((64, 2048), device="cuda", dtype=torch.float16)
|
||||
qzeros = torch.randint(
|
||||
-2000000000, 2000000000, (64, 256), device="cuda", dtype=torch.int32
|
||||
)
|
||||
split_k_iters = 8
|
||||
opcheck(torch.ops._C.awq_gemm, (input, qweight, scales, qzeros, split_k_iters))
|
||||
171
tests/kernels/quantization/test_awq_triton.py
Normal file
171
tests/kernels/quantization/test_awq_triton.py
Normal file
@@ -0,0 +1,171 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""Tests for the AWQ Triton kernel.
|
||||
|
||||
Run `pytest tests/kernels/quantization/test_awq_triton.py`.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from vllm.model_executor.layers.quantization.awq_triton import (
|
||||
AWQ_TRITON_SUPPORTED_GROUP_SIZES,
|
||||
awq_dequantize_triton,
|
||||
awq_gemm_triton,
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
device = "cuda"
|
||||
|
||||
|
||||
def reverse_awq_order(t: torch.Tensor):
|
||||
bits = 4
|
||||
AWQ_REVERSE_ORDER = [0, 4, 1, 5, 2, 6, 3, 7]
|
||||
reverse_order_tensor = torch.arange(
|
||||
t.shape[-1],
|
||||
dtype=torch.int32,
|
||||
device=t.device,
|
||||
)
|
||||
reverse_order_tensor = reverse_order_tensor.view(-1, 32 // bits)
|
||||
reverse_order_tensor = reverse_order_tensor[:, AWQ_REVERSE_ORDER]
|
||||
reverse_order_tensor = reverse_order_tensor.view(-1)
|
||||
|
||||
t = t[:, reverse_order_tensor] & 0xF
|
||||
return t
|
||||
|
||||
|
||||
# qweights - [R , C // 8], int32
|
||||
# scales - [R // G, C ], float16
|
||||
# zeros - [R // G, C // 8], int32
|
||||
def awq_dequantize_torch(
|
||||
qweight: torch.Tensor, scales: torch.Tensor, qzeros: torch.Tensor, group_size: int
|
||||
) -> torch.Tensor:
|
||||
if group_size == -1:
|
||||
group_size = qweight.shape[0]
|
||||
|
||||
bits = 4
|
||||
shifts = torch.arange(0, 32, bits, device=qzeros.device)
|
||||
|
||||
iweights = torch.bitwise_right_shift(qweight[:, :, None], shifts[None, None, :]).to(
|
||||
torch.int8
|
||||
)
|
||||
|
||||
iweights = iweights.view(iweights.shape[0], -1)
|
||||
|
||||
zeros = torch.bitwise_right_shift(qzeros[:, :, None], shifts[None, None, :]).to(
|
||||
torch.int8
|
||||
)
|
||||
zeros = zeros.view(qzeros.shape[0], -1)
|
||||
zeros = reverse_awq_order(zeros)
|
||||
|
||||
iweights = reverse_awq_order(iweights)
|
||||
|
||||
iweights = torch.bitwise_and(iweights, (2**bits) - 1)
|
||||
zeros = torch.bitwise_and(zeros, (2**bits) - 1)
|
||||
|
||||
scales = scales.repeat_interleave(group_size, dim=0)
|
||||
zeros = zeros.repeat_interleave(group_size, dim=0)
|
||||
return (iweights - zeros) * scales
|
||||
|
||||
|
||||
# qweights - [R , C // 8], int32
|
||||
# scales - [R // G, C ], float16
|
||||
# zeros - [R // G, C // 8], int32
|
||||
@pytest.mark.parametrize("qweight_rows", [3584, 18944, 128, 256, 512, 1024])
|
||||
@pytest.mark.parametrize("qweight_cols", [448, 576, 4736, 16, 32, 64, 128])
|
||||
@pytest.mark.parametrize("group_size", AWQ_TRITON_SUPPORTED_GROUP_SIZES)
|
||||
def test_dequantize(qweight_rows, qweight_cols, group_size):
|
||||
if group_size == -1:
|
||||
group_size = qweight_rows
|
||||
|
||||
qweight_dtype = torch.int32
|
||||
scales_rows = qweight_rows // group_size
|
||||
scales_cols = qweight_cols * 8
|
||||
scales_dtype = torch.float16
|
||||
zeros_rows = scales_rows
|
||||
zeros_cols = qweight_cols
|
||||
zeros_dtype = torch.int32
|
||||
|
||||
current_platform.seed_everything(0)
|
||||
|
||||
qweight = torch.randint(
|
||||
0,
|
||||
torch.iinfo(torch.int32).max,
|
||||
(qweight_rows, qweight_cols),
|
||||
dtype=qweight_dtype,
|
||||
device=device,
|
||||
)
|
||||
scales = torch.rand(scales_rows, scales_cols, dtype=scales_dtype, device=device)
|
||||
zeros = torch.randint(
|
||||
0,
|
||||
torch.iinfo(torch.int32).max,
|
||||
(zeros_rows, zeros_cols),
|
||||
dtype=zeros_dtype,
|
||||
device=device,
|
||||
)
|
||||
|
||||
iweights_triton = awq_dequantize_triton(qweight, scales, zeros)
|
||||
|
||||
assert not torch.any(torch.isinf(iweights_triton)) and not torch.any(
|
||||
torch.isnan(iweights_triton)
|
||||
)
|
||||
|
||||
iweights_torch = awq_dequantize_torch(qweight, scales, zeros, group_size)
|
||||
|
||||
torch.testing.assert_close(iweights_triton, iweights_torch)
|
||||
|
||||
|
||||
# input - [N, K]
|
||||
# qweight - [K, M // 8]
|
||||
# qzeros - [K // G, M // 8]
|
||||
# scales - [K // G, M]
|
||||
@pytest.mark.parametrize("N", [1, 2, 4, 8, 14, 17, 23, 32])
|
||||
@pytest.mark.parametrize("K", [128])
|
||||
@pytest.mark.parametrize("M", [16, 24, 32])
|
||||
@pytest.mark.parametrize("group_size", AWQ_TRITON_SUPPORTED_GROUP_SIZES)
|
||||
@pytest.mark.parametrize("splitK", [1, 8])
|
||||
def test_gemm(N, K, M, splitK, group_size):
|
||||
if group_size == -1:
|
||||
group_size = K
|
||||
|
||||
split_k_iters = splitK
|
||||
|
||||
input_rows = N
|
||||
input_cols = K
|
||||
input_dtype = torch.float32
|
||||
qweight_rows = input_cols
|
||||
qweight_cols = M // 8
|
||||
scales_rows = qweight_rows // group_size
|
||||
scales_cols = M
|
||||
scales_dtype = torch.float32
|
||||
qzeros_rows = scales_rows
|
||||
qzeros_cols = qweight_cols
|
||||
|
||||
current_platform.seed_everything(0)
|
||||
|
||||
input = torch.rand((input_rows, input_cols), dtype=input_dtype, device=device)
|
||||
qweight = torch.randint(
|
||||
0, torch.iinfo(torch.int32).max, (qweight_rows, qweight_cols), device=device
|
||||
)
|
||||
qzeros = torch.randint(
|
||||
0, torch.iinfo(torch.int32).max, (qzeros_rows, qzeros_cols), device=device
|
||||
)
|
||||
scales = torch.rand((scales_rows, scales_cols), dtype=scales_dtype, device=device)
|
||||
|
||||
output_triton = awq_gemm_triton(input, qweight, scales, qzeros, split_k_iters)
|
||||
|
||||
assert not torch.any(torch.isinf(output_triton)) and not torch.any(
|
||||
torch.isnan(output_triton)
|
||||
)
|
||||
|
||||
dequantized_weights = awq_dequantize_triton(qweight, scales, qzeros)
|
||||
|
||||
output_torch = torch.matmul(input, dequantized_weights)
|
||||
|
||||
assert not torch.any(torch.isinf(output_torch)) and not torch.any(
|
||||
torch.isnan(output_torch)
|
||||
)
|
||||
|
||||
torch.testing.assert_close(
|
||||
output_triton.cpu(), output_torch.cpu(), atol=1e-1, rtol=1e-1
|
||||
)
|
||||
207
tests/kernels/quantization/test_block_fp8.py
Normal file
207
tests/kernels/quantization/test_block_fp8.py
Normal file
@@ -0,0 +1,207 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
# Adapted from https://github.com/sgl-project/sglang/pull/2575
|
||||
import itertools
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from tests.kernels.quant_utils import (
|
||||
native_per_token_group_quant_fp8,
|
||||
native_w8a8_block_matmul,
|
||||
)
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
|
||||
cutlass_scaled_mm,
|
||||
per_token_group_quant_fp8,
|
||||
w8a8_triton_block_scaled_mm,
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.deep_gemm import (
|
||||
fp8_gemm_nt,
|
||||
get_col_major_tma_aligned_tensor,
|
||||
per_block_cast_to_fp8,
|
||||
should_use_deepgemm_for_fp8_linear,
|
||||
)
|
||||
from vllm.utils.import_utils import has_deep_gemm
|
||||
|
||||
if current_platform.get_device_capability() < (9, 0):
|
||||
pytest.skip("FP8 Triton requires CUDA 9.0 or higher", allow_module_level=True)
|
||||
|
||||
vllm_config = VllmConfig()
|
||||
|
||||
# Test configurations
|
||||
DTYPES = [torch.bfloat16] # [torch.half, torch.bfloat16, torch.float32]
|
||||
NUM_TOKENS = [7, 2050]
|
||||
D = [512, 4096, 5120, 13824]
|
||||
GROUP_SIZE = [64, 128, 512]
|
||||
M = [1, 7, 8, 83, 84, 4096]
|
||||
N = [128, 512, 7168, 7748, 13824]
|
||||
K = [256, 3884, 4096, 13824, 16384]
|
||||
# Deepseek-V3's intermediate size 18432, so N is 18432*2/8=4608 at TP8
|
||||
# and its hidden size is 7168.
|
||||
BLOCK_SIZE = [[128, 128]]
|
||||
OUT_DTYPES = [torch.bfloat16] # [torch.float32, torch.half, torch.bfloat16]
|
||||
SEEDS = [0]
|
||||
|
||||
# Skip all tests if CUDA is not available
|
||||
pytest.importorskip("torch.cuda")
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def setup_cuda():
|
||||
torch.set_default_device("cuda")
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
current_platform.is_fp8_fnuz(),
|
||||
reason="This platform supports e4m3fnuz, not e4m3fn.",
|
||||
)
|
||||
@pytest.mark.parametrize(
|
||||
"num_tokens,d,dtype,group_size,seed",
|
||||
itertools.product(NUM_TOKENS, D, DTYPES, GROUP_SIZE, SEEDS),
|
||||
)
|
||||
@torch.inference_mode()
|
||||
def test_per_token_group_quant_fp8(num_tokens, d, dtype, group_size, seed):
|
||||
torch.manual_seed(seed)
|
||||
x = torch.rand(num_tokens, d, dtype=dtype)
|
||||
|
||||
ref_out, ref_scale = native_per_token_group_quant_fp8(x, group_size)
|
||||
out, scale = per_token_group_quant_fp8(x, group_size)
|
||||
|
||||
assert torch.allclose(out.to(torch.float32), ref_out.to(torch.float32), rtol=0.15)
|
||||
assert torch.allclose(scale, ref_scale)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"M,N,K,block_size,out_dtype,seed",
|
||||
itertools.product(M, N, K, BLOCK_SIZE, OUT_DTYPES, SEEDS),
|
||||
)
|
||||
@torch.inference_mode()
|
||||
def test_w8a8_block_fp8_matmul(M, N, K, block_size, out_dtype, seed):
|
||||
torch.manual_seed(seed)
|
||||
factor_for_scale = 1e-2
|
||||
fp8_info = torch.finfo(current_platform.fp8_dtype())
|
||||
fp8_max, fp8_min = fp8_info.max, fp8_info.min
|
||||
|
||||
A_fp32 = (torch.rand(M, K, dtype=torch.float32) - 0.5) * 2 * fp8_max
|
||||
A_fp8 = A_fp32.clamp(min=fp8_min, max=fp8_max).to(current_platform.fp8_dtype())
|
||||
|
||||
B_fp32 = (torch.rand(N, K, dtype=torch.float32) - 0.5) * 2 * fp8_max
|
||||
B_fp8 = B_fp32.clamp(min=fp8_min, max=fp8_max).to(current_platform.fp8_dtype())
|
||||
|
||||
block_n, block_k = block_size[0], block_size[1]
|
||||
n_tiles = (N + block_n - 1) // block_n
|
||||
k_tiles = (K + block_k - 1) // block_k
|
||||
|
||||
As = torch.rand(M, k_tiles, dtype=torch.float32) * factor_for_scale
|
||||
Bs = torch.rand(n_tiles, k_tiles, dtype=torch.float32) * factor_for_scale
|
||||
|
||||
ref_out = native_w8a8_block_matmul(A_fp8, B_fp8, As, Bs, block_size, out_dtype)
|
||||
out = w8a8_triton_block_scaled_mm(A_fp8, B_fp8, As, Bs, block_size, out_dtype)
|
||||
|
||||
rel_diff = torch.mean(
|
||||
torch.abs(out.to(torch.float32) - ref_out.to(torch.float32))
|
||||
) / torch.mean(torch.abs(ref_out.to(torch.float32)))
|
||||
assert rel_diff < 0.001
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
not current_platform.is_cuda(), reason="CUTLASS only supported on CUDA platform."
|
||||
)
|
||||
@torch.inference_mode()
|
||||
def test_w8a8_block_fp8_cutlass_matmul():
|
||||
# Test simple case where weight.shape % 128 != 0,
|
||||
# like in DSV3 kv_a_proj_with_mqa
|
||||
M = 32
|
||||
N = 576
|
||||
K = 7168
|
||||
block_size = [128, 128]
|
||||
out_dtype = torch.bfloat16
|
||||
seed = 0
|
||||
|
||||
torch.manual_seed(seed)
|
||||
factor_for_scale = 1e-2
|
||||
fp8_info = torch.finfo(torch.float8_e4m3fn)
|
||||
fp8_max, fp8_min = fp8_info.max, fp8_info.min
|
||||
|
||||
A_fp32 = (torch.rand(M, K, dtype=torch.float32) - 0.5) * 2 * fp8_max
|
||||
|
||||
B_fp32 = (torch.rand(N, K, dtype=torch.float32) - 0.5) * 2 * fp8_max
|
||||
B_fp8 = B_fp32.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn)
|
||||
|
||||
block_n, block_k = block_size[0], block_size[1]
|
||||
n_tiles = (N + block_n - 1) // block_n
|
||||
k_tiles = (K + block_k - 1) // block_k
|
||||
|
||||
Bs = torch.rand(n_tiles, k_tiles, dtype=torch.float32) * factor_for_scale
|
||||
# Hopper requires row-major format for scales
|
||||
Bs_cutlass = Bs.T.contiguous() if current_platform.is_device_capability(90) else Bs
|
||||
|
||||
A_fp8, As = per_token_group_quant_fp8(
|
||||
A_fp32, block_size[1], column_major_scales=False
|
||||
)
|
||||
# CUTLASS uses column-major format for scales
|
||||
A_fp8_cutlass, As_cutlass = per_token_group_quant_fp8(
|
||||
A_fp32, block_size[1], column_major_scales=True
|
||||
)
|
||||
|
||||
ref_out = native_w8a8_block_matmul(A_fp8, B_fp8, As, Bs, block_size, out_dtype)
|
||||
out = cutlass_scaled_mm(
|
||||
A_fp8_cutlass, B_fp8, As_cutlass, Bs_cutlass, block_size, out_dtype
|
||||
)
|
||||
|
||||
rel_diff = torch.mean(
|
||||
torch.abs(out.to(torch.float32) - ref_out.to(torch.float32))
|
||||
) / torch.mean(torch.abs(ref_out.to(torch.float32)))
|
||||
assert rel_diff < 0.001
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
current_platform.is_fp8_fnuz(),
|
||||
reason="This platform supports e4m3fnuz, not e4m3fn.",
|
||||
)
|
||||
@pytest.mark.parametrize(
|
||||
"M,N,K,block_size,out_dtype,seed",
|
||||
itertools.product(M, N, K, BLOCK_SIZE, OUT_DTYPES, SEEDS),
|
||||
)
|
||||
@pytest.mark.skipif(not has_deep_gemm(), reason="DeepGemm kernels not available.")
|
||||
@torch.inference_mode()
|
||||
def test_w8a8_block_fp8_deep_gemm_matmul(M, N, K, block_size, out_dtype, seed):
|
||||
torch.manual_seed(seed)
|
||||
fp8_info = torch.finfo(torch.float8_e4m3fn)
|
||||
fp8_max = fp8_info.max
|
||||
|
||||
A_fp32 = (torch.rand(M, K, dtype=torch.float32) - 0.5) * 2 * fp8_max
|
||||
B_fp32 = (torch.rand(N, K, dtype=torch.float32) - 0.5) * 2 * fp8_max
|
||||
|
||||
# only aligned sizes are supported by deepgemm
|
||||
if not should_use_deepgemm_for_fp8_linear(
|
||||
output_dtype=out_dtype, weight=B_fp32, supports_deep_gemm=True
|
||||
):
|
||||
pytest.skip(f"Skipping test; invalid size {M}, {N}, {K}")
|
||||
|
||||
A_fp8, As_fp8 = per_token_group_quant_fp8(A_fp32, block_size[1])
|
||||
B_fp8, Bs_fp8 = per_block_cast_to_fp8(B_fp32, block_size=block_size)
|
||||
|
||||
As = As_fp8.to(torch.float32)
|
||||
Bs = Bs_fp8.to(torch.float32)
|
||||
|
||||
ref_out = native_w8a8_block_matmul(A_fp8, B_fp8, As, Bs, block_size, out_dtype)
|
||||
|
||||
# Transpose earlier so that the testing will not trigger transposing kernels
|
||||
As_fp8 = get_col_major_tma_aligned_tensor(As_fp8)
|
||||
|
||||
out = torch.zeros((M, N), device="cuda", dtype=out_dtype)
|
||||
|
||||
assert As_fp8.shape == (M, (K + 127) // 128), (
|
||||
f"{As_fp8.shape} != {(M, (K + 127) // 128)}"
|
||||
)
|
||||
|
||||
fp8_gemm_nt((A_fp8, As_fp8), (B_fp8, Bs_fp8), out)
|
||||
|
||||
rel_diff = torch.mean(
|
||||
torch.abs(out.to(torch.float32) - ref_out.to(torch.float32))
|
||||
) / torch.mean(torch.abs(ref_out.to(torch.float32)))
|
||||
assert rel_diff < 0.001
|
||||
67
tests/kernels/quantization/test_block_int8.py
Normal file
67
tests/kernels/quantization/test_block_int8.py
Normal file
@@ -0,0 +1,67 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
# Adapted from https://github.com/sgl-project/sglang/blob/main/test/srt/test_block_int8.py
|
||||
import itertools
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from tests.kernels.quant_utils import native_w8a8_block_matmul
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.model_executor.layers.quantization.utils.int8_utils import (
|
||||
w8a8_block_int8_matmul,
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
if current_platform.get_device_capability() < (7, 0):
|
||||
pytest.skip("INT8 Triton requires CUDA 7.0 or higher", allow_module_level=True)
|
||||
|
||||
vllm_config = VllmConfig()
|
||||
|
||||
DTYPES = [torch.half, torch.bfloat16]
|
||||
M = [1, 33, 64, 222]
|
||||
N = [128, 1024]
|
||||
K = [256, 4096]
|
||||
# BLOCK_SIZE = [[64, 64], [64, 128], [128, 64], [128, 128]]
|
||||
BLOCK_SIZE = [[128, 128]]
|
||||
SEEDS = [0]
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True, scope="module")
|
||||
def setup_cuda():
|
||||
"""Sets the default CUDA device for all tests in this module."""
|
||||
torch.set_default_device("cuda")
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"M,N,K,block_size,out_dtype,seed",
|
||||
itertools.product(M, N, K, BLOCK_SIZE, DTYPES, SEEDS),
|
||||
)
|
||||
@torch.inference_mode()
|
||||
def test_w8a8_block_int8_matmul(M, N, K, block_size, out_dtype, seed):
|
||||
torch.manual_seed(seed)
|
||||
factor_for_scale = 1e-2
|
||||
int8_info = torch.iinfo(torch.int8)
|
||||
int8_max, int8_min = int8_info.max, int8_info.min
|
||||
|
||||
A_fp32 = (torch.rand(M, K, dtype=torch.float32) - 0.5) * 2 * int8_max
|
||||
A_fp8 = A_fp32.clamp(min=int8_min, max=int8_max).to(torch.float8_e4m3fn)
|
||||
|
||||
B_fp32 = (torch.rand(N, K, dtype=torch.float32) - 0.5) * 2 * int8_max
|
||||
B_fp8 = B_fp32.clamp(min=int8_min, max=int8_max).to(torch.float8_e4m3fn)
|
||||
|
||||
block_n, block_k = block_size[0], block_size[1]
|
||||
n_tiles = (N + block_n - 1) // block_n
|
||||
k_tiles = (K + block_k - 1) // block_k
|
||||
|
||||
As = torch.rand(M, k_tiles, dtype=torch.float32) * factor_for_scale
|
||||
Bs = torch.rand(n_tiles, k_tiles, dtype=torch.float32) * factor_for_scale
|
||||
|
||||
ref_out = native_w8a8_block_matmul(A_fp8, B_fp8, As, Bs, block_size, out_dtype)
|
||||
out = w8a8_block_int8_matmul(A_fp8, B_fp8, As, Bs, block_size, out_dtype)
|
||||
|
||||
rel_diff = torch.mean(
|
||||
torch.abs(out.to(torch.float32) - ref_out.to(torch.float32))
|
||||
) / torch.mean(torch.abs(ref_out.to(torch.float32)))
|
||||
assert rel_diff < 0.001
|
||||
236
tests/kernels/quantization/test_cutlass_2of4_sparse.py
Normal file
236
tests/kernels/quantization/test_cutlass_2of4_sparse.py
Normal file
@@ -0,0 +1,236 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""Tests for sparse cutlass kernels
|
||||
|
||||
Run `pytest tests/kernels/quantization/test_cutlass_2of4_sparse.py`.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from tests.kernels.utils import baseline_scaled_mm, to_fp8, to_int8
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
|
||||
sparse_cutlass_supported,
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
CUDA_DEVICES = [f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)]
|
||||
|
||||
capability = current_platform.get_device_capability()
|
||||
capability = capability[0] * 10 + capability[1]
|
||||
|
||||
|
||||
def to_bf16(tensor: torch.Tensor) -> torch.Tensor:
|
||||
return tensor.to(dtype=torch.bfloat16)
|
||||
|
||||
|
||||
def to_fp16(tensor: torch.Tensor) -> torch.Tensor:
|
||||
return tensor.to(dtype=torch.float16)
|
||||
|
||||
|
||||
def prune_to_2_4(tensor):
|
||||
# Reshape tensor to [N, 4] where N is number of groups of 4
|
||||
original_shape = tensor.shape
|
||||
reshaped = tensor.reshape(-1, 4)
|
||||
|
||||
# Get indices of top 2 absolute values in each group of 4
|
||||
_, indices = torch.topk(torch.abs(reshaped), k=2, dim=1)
|
||||
|
||||
# Create binary mask
|
||||
mask = torch.zeros_like(reshaped)
|
||||
mask.scatter_(dim=1, index=indices, src=torch.ones_like(indices, dtype=mask.dtype))
|
||||
|
||||
# Apply mask and reshape back
|
||||
pruned = reshaped * mask
|
||||
|
||||
# Turn all -0.0 to 0.0
|
||||
pruned[pruned == -0.0] = 0.0
|
||||
|
||||
return pruned.reshape(original_shape)
|
||||
|
||||
|
||||
# This function checks that applying an identity matrix multiplication
|
||||
# to the compressed weights yields the original uncompressed weights.
|
||||
def check_compress_decompress_invariance(
|
||||
dtype: torch.dtype,
|
||||
b: torch.Tensor,
|
||||
b_compressed: torch.Tensor,
|
||||
b_metadata: torch.Tensor,
|
||||
):
|
||||
# For float16 and bfloat16, cutlass_scaled_sparse_mm's output must be the
|
||||
# same dtype as its inputs. This line addresses that constraint while
|
||||
# arbitrarily using bfloat16 for the int8/fp8 cases.
|
||||
out_dtype = torch.float16 if dtype is torch.float16 else torch.bfloat16
|
||||
|
||||
eye = torch.eye(b.shape[0], device="cuda", dtype=dtype)
|
||||
eye_scale = torch.ones(1, device="cuda", dtype=torch.float32)
|
||||
b_decomp = ops.cutlass_scaled_sparse_mm(
|
||||
eye, b_compressed, b_metadata, eye_scale, eye_scale, out_dtype=out_dtype
|
||||
)
|
||||
|
||||
torch.testing.assert_close(b.to(dtype=out_dtype), b_decomp)
|
||||
|
||||
|
||||
def make_rand_sparse_tensors(
|
||||
dtype: torch.dtype, m: int, n: int, k: int
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
a = torch.randn((m, k), device="cuda")
|
||||
b = torch.randn((n, k), device="cuda").t()
|
||||
|
||||
if dtype == torch.int8:
|
||||
# ensure A and B aren't all zeros after rounding
|
||||
a = a * 5.0
|
||||
b = b * 5.0
|
||||
|
||||
b = prune_to_2_4(b.t()).t()
|
||||
|
||||
if dtype == torch.int8:
|
||||
a, b = to_int8(a), to_int8(b)
|
||||
elif dtype == torch.float8_e4m3fn:
|
||||
a, b = to_fp8(a), to_fp8(b)
|
||||
elif dtype == torch.float16:
|
||||
a, b = to_fp16(a), to_fp16(b)
|
||||
elif dtype == torch.bfloat16:
|
||||
a, b = to_bf16(a), to_bf16(b)
|
||||
else:
|
||||
raise ValueError("unsupported dtype")
|
||||
|
||||
b_compressed, e = ops.cutlass_sparse_compress(b.t())
|
||||
check_compress_decompress_invariance(dtype, b, b_compressed, e)
|
||||
|
||||
# Compressed B, Metadata, Original A, B
|
||||
return b_compressed, e, a, b
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
not sparse_cutlass_supported(),
|
||||
reason="Sparse CUTLASS is not supported on this GPU type.",
|
||||
)
|
||||
# Test working with a subset of A and B for sparse matmul
|
||||
def test_cutlass_sparse_subset():
|
||||
big_m = 1024
|
||||
m, n, k = 512, 512, 512
|
||||
|
||||
# Create tensors
|
||||
b_comp, e, whole_a, b = make_rand_sparse_tensors(torch.float8_e4m3fn, big_m, n, k)
|
||||
a = whole_a[0:m, 0:k]
|
||||
scale_a = torch.randn((1, 1), device="cuda", dtype=torch.float32) / 10
|
||||
scale_b = torch.randn((1, 1), device="cuda", dtype=torch.float32) / 10
|
||||
|
||||
out = ops.cutlass_scaled_sparse_mm(
|
||||
a, b_comp, e, scale_a, scale_b, out_dtype=torch.bfloat16
|
||||
)
|
||||
baseline = baseline_scaled_mm(a, b, scale_a, scale_b, out_dtype=torch.bfloat16)
|
||||
|
||||
torch.testing.assert_close(out, baseline, rtol=1e-1, atol=1e0)
|
||||
|
||||
|
||||
MNK_FACTORS = [
|
||||
(1, 256, 128),
|
||||
(1, 16384, 1024),
|
||||
(1, 24576, 512),
|
||||
(16, 256, 512),
|
||||
(16, 16384, 128),
|
||||
(16, 24576, 4096),
|
||||
(32, 8192, 4096),
|
||||
(32, 16384, 4096),
|
||||
(33, 1024, 1024),
|
||||
(33, 8192, 128),
|
||||
(64, 2048, 512),
|
||||
(64, 16384, 1024),
|
||||
(100, 8192, 512),
|
||||
(128, 32768, 4096),
|
||||
(256, 4096, 4096),
|
||||
(512, 256, 1024),
|
||||
(512, 8192, 4096),
|
||||
(512, 16384, 128),
|
||||
(512, 24576, 128),
|
||||
]
|
||||
|
||||
|
||||
# Test working with a subset of A and B for sparse matmul
|
||||
@pytest.mark.skipif(
|
||||
not sparse_cutlass_supported(),
|
||||
reason="Sparse CUTLASS is not supported on this GPU type.",
|
||||
)
|
||||
@pytest.mark.parametrize("m, n, k", MNK_FACTORS)
|
||||
@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16])
|
||||
@pytest.mark.parametrize("use_bias", [True, False])
|
||||
def test_cutlass_sparse_gemm(
|
||||
m: int, k: int, n: int, dtype: type[torch.dtype], use_bias: bool
|
||||
):
|
||||
# Create tensors
|
||||
b_comp, e, a, b = make_rand_sparse_tensors(dtype, m, n, k)
|
||||
scale_a = torch.ones((1, 1), device="cuda", dtype=torch.float32)
|
||||
scale_b = torch.ones((1, 1), device="cuda", dtype=torch.float32)
|
||||
|
||||
bias = torch.rand((n,), device="cuda", dtype=dtype) if use_bias else None
|
||||
|
||||
out = ops.cutlass_scaled_sparse_mm(
|
||||
a, b_comp, e, scale_a, scale_b, out_dtype=dtype, bias=bias
|
||||
)
|
||||
|
||||
baseline = baseline_scaled_mm(a, b, scale_a, scale_b, out_dtype=dtype, bias=bias)
|
||||
|
||||
torch.testing.assert_close(out, baseline, rtol=1e-2, atol=3e-1)
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
not sparse_cutlass_supported(),
|
||||
reason="Sparse CUTLASS is not supported on this GPU type.",
|
||||
)
|
||||
@pytest.mark.parametrize("m, k, n", MNK_FACTORS)
|
||||
@pytest.mark.skipif(
|
||||
not current_platform.has_device_capability(89),
|
||||
reason="FP8 is not supported on this GPU type.",
|
||||
)
|
||||
@pytest.mark.parametrize("use_bias", [True, False])
|
||||
def test_cutlass_sparse_fp8_gemm(m: int, n: int, k: int, use_bias: bool):
|
||||
# Create tensors
|
||||
b_comp, e, a, b = make_rand_sparse_tensors(torch.float8_e4m3fn, m, n, k)
|
||||
scale_a = torch.randn((1, 1), device="cuda", dtype=torch.float32)
|
||||
scale_b = torch.randn((1, 1), device="cuda", dtype=torch.float32)
|
||||
out_dtype = torch.bfloat16
|
||||
|
||||
bias = torch.rand((n,), device="cuda", dtype=out_dtype) * 10 if use_bias else None
|
||||
|
||||
out = ops.cutlass_scaled_sparse_mm(
|
||||
a, b_comp, e, scale_a, scale_b, out_dtype=out_dtype, bias=bias
|
||||
)
|
||||
|
||||
baseline = baseline_scaled_mm(
|
||||
a, b, scale_a, scale_b, out_dtype=out_dtype, bias=bias
|
||||
)
|
||||
|
||||
torch.testing.assert_close(out, baseline, rtol=1e-2, atol=3e-1)
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
not sparse_cutlass_supported(),
|
||||
reason="Sparse CUTLASS is not supported on this GPU type.",
|
||||
)
|
||||
@pytest.mark.parametrize("m,k,n", MNK_FACTORS)
|
||||
@pytest.mark.parametrize("per_act_token", [True, False])
|
||||
@pytest.mark.parametrize("per_out_ch", [True, False])
|
||||
@pytest.mark.parametrize("use_bias", [True, False])
|
||||
def test_cutlass_sparse_int8_gemm(
|
||||
m: int, n: int, k: int, per_act_token: bool, per_out_ch: bool, use_bias: bool
|
||||
):
|
||||
# Create tensors
|
||||
b_comp, e, a, b = make_rand_sparse_tensors(torch.int8, m, n, k)
|
||||
scale_a = torch.randn((1, 1), device="cuda", dtype=torch.float32)
|
||||
scale_b = torch.randn((1, 1), device="cuda", dtype=torch.float32)
|
||||
out_dtype = torch.bfloat16
|
||||
|
||||
bias = torch.rand((n,), device="cuda", dtype=out_dtype) * 10 if use_bias else None
|
||||
|
||||
out = ops.cutlass_scaled_sparse_mm(
|
||||
a, b_comp, e, scale_a, scale_b, out_dtype=out_dtype, bias=bias
|
||||
)
|
||||
|
||||
baseline = baseline_scaled_mm(
|
||||
a, b, scale_a, scale_b, out_dtype=out_dtype, bias=bias
|
||||
)
|
||||
|
||||
torch.testing.assert_close(out, baseline, rtol=1e0, atol=2e0)
|
||||
682
tests/kernels/quantization/test_cutlass_scaled_mm.py
Normal file
682
tests/kernels/quantization/test_cutlass_scaled_mm.py
Normal file
@@ -0,0 +1,682 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""Tests for cutlass kernels
|
||||
|
||||
Run `pytest tests/kernels/quantization/test_cutlass_scaled_mm.py`.
|
||||
"""
|
||||
|
||||
import random
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from tests.kernels.utils import baseline_scaled_mm, opcheck, to_fp8, to_int8
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.math_utils import cdiv
|
||||
|
||||
if not current_platform.is_cuda():
|
||||
pytest.skip("These tests use CUTLASS which requires CUDA", allow_module_level=True)
|
||||
|
||||
MNK_FACTORS = [
|
||||
(1, 256, 128),
|
||||
(1, 16384, 1024),
|
||||
(1, 24576, 496),
|
||||
(16, 256, 496),
|
||||
(16, 16384, 128),
|
||||
(16, 24576, 4096),
|
||||
(32, 8192, 4096),
|
||||
(32, 16384, 4096),
|
||||
(33, 1024, 1024),
|
||||
(33, 8192, 128),
|
||||
(64, 2048, 496),
|
||||
(64, 16384, 1024),
|
||||
(100, 8192, 496),
|
||||
(128, 32768, 4096),
|
||||
(256, 4096, 4096),
|
||||
(512, 256, 1024),
|
||||
(512, 8192, 4096),
|
||||
(512, 16384, 128),
|
||||
(512, 24576, 128),
|
||||
]
|
||||
|
||||
CUDA_DEVICES = [f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)]
|
||||
|
||||
# -1 means full extent in that dimension
|
||||
TENSORWISE_GROUP_SHAPE = (-1, -1)
|
||||
PER_TOKEN_GROUP_SHAPE = (1, -1)
|
||||
PER_OUT_CH_GROUP_SHAPE = (-1, 1)
|
||||
|
||||
capability = current_platform.get_device_capability()
|
||||
capability = capability[0] * 10 + capability[1]
|
||||
|
||||
|
||||
def rand_int8(shape: tuple, device: str = "cuda"):
|
||||
return to_int8(torch.rand(shape, device=device) * 255 - 128)
|
||||
|
||||
|
||||
def group_scale_helper(shape, group_shape):
|
||||
return [shape[i] if s < 0 else s for i, s in enumerate(group_shape)]
|
||||
|
||||
|
||||
def scale_shape(shape, group_shape):
|
||||
assert len(shape) == len(group_shape)
|
||||
group_shape = group_scale_helper(shape, group_shape)
|
||||
return tuple(cdiv(shape[i], group_shape[i]) for i in range(len(group_shape)))
|
||||
|
||||
|
||||
def cutlass_fp8_gemm_helper(
|
||||
m: int,
|
||||
n: int,
|
||||
k: int,
|
||||
a_scale_group_shape: tuple,
|
||||
b_scale_group_shape: tuple,
|
||||
use_bias: bool,
|
||||
out_dtype: type[torch.dtype] = torch.bfloat16,
|
||||
device: str = "cuda",
|
||||
):
|
||||
# Test for a cutlass kernel with per-token activation quantization
|
||||
# and per-output channel weight quantization.
|
||||
a = to_fp8(torch.randn((m, k), device=device))
|
||||
b = to_fp8(torch.randn((n, k), device=device).t())
|
||||
|
||||
a_scales_shape = scale_shape(a.shape, a_scale_group_shape)
|
||||
b_scales_shape = scale_shape(b.shape, b_scale_group_shape)
|
||||
|
||||
scale_a = torch.randn(a_scales_shape, device=device, dtype=torch.float32)
|
||||
scale_b = torch.randn(b_scales_shape, device=device, dtype=torch.float32)
|
||||
|
||||
# make scales M-major for blockwise quant, doesn't affect 1D scales
|
||||
scale_a = scale_a.t().contiguous().t()
|
||||
# make scales K-major for blockwise quant, doesn't affect 1D scales
|
||||
scale_b = scale_b.t().contiguous().t()
|
||||
|
||||
bias = torch.rand((n,), device=device, dtype=out_dtype) * 10 if use_bias else None
|
||||
|
||||
out = ops.cutlass_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias)
|
||||
baseline = baseline_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias)
|
||||
|
||||
torch.testing.assert_close(out, baseline, rtol=5e-1, atol=1.5e-1)
|
||||
|
||||
opcheck(torch.ops._C.cutlass_scaled_mm, (out, a, b, scale_a, scale_b, bias))
|
||||
|
||||
|
||||
def cutlass_int8_gemm_helper(
|
||||
m: int,
|
||||
n: int,
|
||||
k: int,
|
||||
a_scale_group_shape: tuple,
|
||||
b_scale_group_shape: tuple,
|
||||
use_bias: bool,
|
||||
out_dtype: type[torch.dtype] = torch.bfloat16,
|
||||
device: str = "cuda",
|
||||
):
|
||||
# Test for a cutlass kernel with per-token activation quantization
|
||||
# and per-output channel weight quantization.
|
||||
a = to_int8(torch.randn((m, k), device=device) * 5)
|
||||
b = to_int8(torch.randn((n, k), device=device).t() * 5)
|
||||
|
||||
a_scales_shape = scale_shape(a.shape, a_scale_group_shape)
|
||||
b_scales_shape = scale_shape(b.shape, b_scale_group_shape)
|
||||
|
||||
scale_a = torch.randn(a_scales_shape, device=device, dtype=torch.float32)
|
||||
scale_b = torch.randn(b_scales_shape, device=device, dtype=torch.float32)
|
||||
|
||||
bias = torch.rand((n,), device=device, dtype=out_dtype) * 10 if use_bias else None
|
||||
|
||||
out = ops.cutlass_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias)
|
||||
baseline = baseline_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias)
|
||||
|
||||
torch.testing.assert_close(out, baseline, rtol=1e-1, atol=1e0)
|
||||
|
||||
opcheck(torch.ops._C.cutlass_scaled_mm, (out, a, b, scale_a, scale_b, bias))
|
||||
|
||||
|
||||
@pytest.mark.parametrize("m,n,k", MNK_FACTORS)
|
||||
@pytest.mark.parametrize(
|
||||
"a_scale_group_shape", [PER_TOKEN_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE]
|
||||
)
|
||||
@pytest.mark.parametrize(
|
||||
"b_scale_group_shape", [PER_OUT_CH_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE]
|
||||
)
|
||||
@pytest.mark.parametrize("use_bias", [True, False])
|
||||
@pytest.mark.skipif(
|
||||
not current_platform.has_device_capability(89),
|
||||
reason="FP8 is not supported on this GPU type.",
|
||||
)
|
||||
def test_cutlass_fp8_gemm(
|
||||
m: int, n: int, k: int, a_scale_group_shape, b_scale_group_shape, use_bias: bool
|
||||
):
|
||||
cutlass_fp8_gemm_helper(m, n, k, a_scale_group_shape, b_scale_group_shape, use_bias)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("m,n,k", MNK_FACTORS)
|
||||
@pytest.mark.parametrize(
|
||||
"a_scale_group_shape,b_scale_group_shape", [((1, 128), (128, 128))]
|
||||
)
|
||||
@pytest.mark.parametrize("use_bias", [False])
|
||||
@pytest.mark.skipif(
|
||||
not current_platform.has_device_capability(90),
|
||||
reason="FP8 blockwise is not supported on this GPU type.",
|
||||
)
|
||||
def test_cutlass_fp8_blockwise_scale_gemm(
|
||||
m: int, n: int, k: int, a_scale_group_shape, b_scale_group_shape, use_bias: bool
|
||||
):
|
||||
if k % b_scale_group_shape[0] != 0 or n % b_scale_group_shape[1] != 0:
|
||||
return
|
||||
if m % a_scale_group_shape[0] != 0 or k % a_scale_group_shape[1] != 0:
|
||||
return
|
||||
if m % 4 != 0 and current_platform.has_device_capability(100):
|
||||
return
|
||||
cutlass_fp8_gemm_helper(m, n, k, a_scale_group_shape, b_scale_group_shape, use_bias)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("m,n,k", MNK_FACTORS)
|
||||
@pytest.mark.parametrize(
|
||||
"a_scale_group_shape", [PER_TOKEN_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE]
|
||||
)
|
||||
@pytest.mark.parametrize(
|
||||
"b_scale_group_shape", [PER_OUT_CH_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE]
|
||||
)
|
||||
@pytest.mark.parametrize("use_bias", [True, False])
|
||||
def test_cutlass_int8_gemm(
|
||||
m: int, n: int, k: int, a_scale_group_shape, b_scale_group_shape, use_bias: bool
|
||||
):
|
||||
cutlass_int8_gemm_helper(
|
||||
m, n, k, a_scale_group_shape, b_scale_group_shape, use_bias
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"a_scale_group_shape", [PER_TOKEN_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE]
|
||||
)
|
||||
@pytest.mark.parametrize(
|
||||
"b_scale_group_shape", [PER_OUT_CH_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE]
|
||||
)
|
||||
@pytest.mark.parametrize("out_dtype", [torch.bfloat16, torch.float16])
|
||||
@pytest.mark.parametrize("use_bias", [True, False])
|
||||
def test_cutlass_int8_gemm_output_dtype(
|
||||
a_scale_group_shape,
|
||||
b_scale_group_shape,
|
||||
out_dtype: type[torch.dtype],
|
||||
use_bias: bool,
|
||||
):
|
||||
cutlass_int8_gemm_helper(
|
||||
512,
|
||||
512,
|
||||
512,
|
||||
a_scale_group_shape,
|
||||
b_scale_group_shape,
|
||||
use_bias,
|
||||
out_dtype=out_dtype,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"a_scale_group_shape", [PER_TOKEN_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE]
|
||||
)
|
||||
@pytest.mark.parametrize(
|
||||
"b_scale_group_shape", [PER_OUT_CH_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE]
|
||||
)
|
||||
@pytest.mark.parametrize("out_dtype", [torch.bfloat16, torch.float16])
|
||||
@pytest.mark.parametrize("use_bias", [True, False])
|
||||
@pytest.mark.skipif(
|
||||
not current_platform.has_device_capability(89),
|
||||
reason="FP8 is not supported on this GPU type.",
|
||||
)
|
||||
def test_cutlass_fp8_gemm_output_dtype(
|
||||
a_scale_group_shape,
|
||||
b_scale_group_shape,
|
||||
out_dtype: type[torch.dtype],
|
||||
use_bias: bool,
|
||||
):
|
||||
cutlass_fp8_gemm_helper(
|
||||
512,
|
||||
512,
|
||||
512,
|
||||
a_scale_group_shape,
|
||||
b_scale_group_shape,
|
||||
use_bias,
|
||||
out_dtype=out_dtype,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"a_scale_group_shape,b_scale_group_shape", [((1, 128), (128, 128))]
|
||||
)
|
||||
@pytest.mark.parametrize("out_dtype", [torch.bfloat16, torch.float16])
|
||||
@pytest.mark.parametrize("use_bias", [False])
|
||||
@pytest.mark.skipif(
|
||||
not current_platform.has_device_capability(90),
|
||||
reason="FP8 blockwise is not supported on this GPU type.",
|
||||
)
|
||||
def test_cutlass_fp8_blockwise_scale_gemm_dtype(
|
||||
a_scale_group_shape,
|
||||
b_scale_group_shape,
|
||||
out_dtype: type[torch.dtype],
|
||||
use_bias: bool,
|
||||
):
|
||||
cutlass_fp8_gemm_helper(
|
||||
512,
|
||||
512,
|
||||
512,
|
||||
a_scale_group_shape,
|
||||
b_scale_group_shape,
|
||||
use_bias,
|
||||
out_dtype=out_dtype,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"a_scale_group_shape", [PER_TOKEN_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE]
|
||||
)
|
||||
@pytest.mark.parametrize(
|
||||
"b_scale_group_shape", [PER_OUT_CH_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE]
|
||||
)
|
||||
@pytest.mark.parametrize("use_bias", [True, False])
|
||||
@pytest.mark.parametrize("device", CUDA_DEVICES)
|
||||
@pytest.mark.skipif(
|
||||
not current_platform.has_device_capability(89),
|
||||
reason="FP8 is not supported on this GPU type.",
|
||||
)
|
||||
def test_cutlass_fp8_gemm_devices(
|
||||
a_scale_group_shape, b_scale_group_shape, use_bias: bool, device: str
|
||||
):
|
||||
cutlass_fp8_gemm_helper(
|
||||
512,
|
||||
512,
|
||||
512,
|
||||
a_scale_group_shape,
|
||||
b_scale_group_shape,
|
||||
use_bias,
|
||||
torch.bfloat16,
|
||||
device,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"a_scale_group_shape", [PER_TOKEN_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE]
|
||||
)
|
||||
@pytest.mark.parametrize(
|
||||
"b_scale_group_shape", [PER_OUT_CH_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE]
|
||||
)
|
||||
@pytest.mark.parametrize("use_bias", [True, False])
|
||||
@pytest.mark.parametrize("device", CUDA_DEVICES)
|
||||
def test_cutlass_int8_gemm_devices(
|
||||
a_scale_group_shape, b_scale_group_shape, use_bias: bool, device: str
|
||||
):
|
||||
cutlass_int8_gemm_helper(
|
||||
512,
|
||||
512,
|
||||
512,
|
||||
a_scale_group_shape,
|
||||
b_scale_group_shape,
|
||||
use_bias,
|
||||
out_dtype=torch.bfloat16,
|
||||
device=device,
|
||||
)
|
||||
|
||||
|
||||
# For the following two tests:
|
||||
# N and K correspond to the size of the weight matrix and likely to be multiples
|
||||
# of a large power of two. In any case, the kernel will have a naive fallback
|
||||
# when N and K are not divisible by 16. But M is the number of tokens and the
|
||||
# kernel must handle any M thrown at it.
|
||||
@pytest.mark.parametrize(
|
||||
"a_scale_group_shape", [PER_TOKEN_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE]
|
||||
)
|
||||
@pytest.mark.parametrize(
|
||||
"b_scale_group_shape", [PER_OUT_CH_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE]
|
||||
)
|
||||
@pytest.mark.parametrize("use_bias", [True, False])
|
||||
@pytest.mark.skipif(
|
||||
not current_platform.has_device_capability(89),
|
||||
reason="FP8 is not supported on this GPU type.",
|
||||
)
|
||||
def test_cutlass_fp8_gemm_m_sweep(
|
||||
a_scale_group_shape, b_scale_group_shape, use_bias: bool
|
||||
):
|
||||
for nk in range(32, 128, 32):
|
||||
for m in range(1, 128):
|
||||
cutlass_fp8_gemm_helper(
|
||||
m, nk, nk, a_scale_group_shape, b_scale_group_shape, use_bias
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"a_scale_group_shape", [PER_TOKEN_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE]
|
||||
)
|
||||
@pytest.mark.parametrize(
|
||||
"b_scale_group_shape", [PER_OUT_CH_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE]
|
||||
)
|
||||
@pytest.mark.parametrize("use_bias", [True, False])
|
||||
def test_cutlass_int8_gemm_m_sweep(
|
||||
a_scale_group_shape, b_scale_group_shape, use_bias: bool
|
||||
):
|
||||
for nk in range(32, 128, 32):
|
||||
for m in range(1, 128):
|
||||
cutlass_int8_gemm_helper(
|
||||
m, nk, nk, a_scale_group_shape, b_scale_group_shape, use_bias
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("m", [32, 64, 128])
|
||||
@pytest.mark.parametrize("n", [16, 32, 64])
|
||||
@pytest.mark.parametrize("k", [64, 128, 256])
|
||||
@pytest.mark.parametrize("out_dtype", [torch.bfloat16, torch.float16])
|
||||
@pytest.mark.skip
|
||||
def test_cutlass_int8_azp_bias_fold(m: int, n: int, k: int, out_dtype: torch.dtype):
|
||||
# Currently, the test is failing because folding azp into
|
||||
# 16-bit bias loses too much precision
|
||||
scale_a = torch.randn((1, 1), device="cuda", dtype=torch.float32) / 10
|
||||
scale_b = torch.randn((1, n), device="cuda", dtype=torch.float32) / 10
|
||||
|
||||
aq_i8 = rand_int8((m, k))
|
||||
bq_i8 = rand_int8((n, k)).t()
|
||||
|
||||
aq_i32 = aq_i8.to(dtype=torch.int32)
|
||||
bq_i32 = bq_i8.to(dtype=torch.int32)
|
||||
|
||||
aq_f32 = aq_i8.to(dtype=torch.float32)
|
||||
bq_f32 = bq_i8.to(dtype=torch.float32)
|
||||
|
||||
b_dq = scale_b * bq_f32
|
||||
|
||||
azp_a = torch.rand((1,), device="cuda", dtype=torch.float32) * 10 + 1.5
|
||||
azp_aq_i8 = (azp_a / scale_a).to(dtype=torch.int8)
|
||||
azp_a = azp_aq_i8.to(dtype=torch.float32) * scale_a # correct for rounding
|
||||
|
||||
a_dq = scale_a * (aq_i32 + azp_aq_i8).to(dtype=torch.float32)
|
||||
torch.testing.assert_close(a_dq, scale_a * aq_f32 + azp_a)
|
||||
|
||||
baseline_dq = torch.mm(a_dq, b_dq).to(out_dtype)
|
||||
|
||||
J = torch.ones((1, k), device="cuda", dtype=torch.float32)
|
||||
azp_bias = (azp_a * scale_b * (J @ bq_f32)).to(out_dtype)
|
||||
assert azp_bias.shape == (1, n)
|
||||
assert azp_bias[0, :].shape == (n,)
|
||||
|
||||
baseline_q = (
|
||||
scale_a.to(device="cpu")
|
||||
* scale_b.to(device="cpu")
|
||||
* ((aq_i32 + azp_aq_i8).to(device="cpu") @ bq_i32.to(device="cpu"))
|
||||
).to(dtype=out_dtype, device="cuda")
|
||||
|
||||
out = ops.cutlass_scaled_mm(
|
||||
aq_i8, bq_i8, scale_a, scale_b, out_dtype=out_dtype, bias=azp_bias[0, :]
|
||||
)
|
||||
torch.testing.assert_close(out, baseline_dq, rtol=1e-2, atol=1e0)
|
||||
torch.testing.assert_close(out, baseline_q, rtol=1e-2, atol=1e0)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("m", [32, 64, 128])
|
||||
@pytest.mark.parametrize("n", [16, 32, 64])
|
||||
@pytest.mark.parametrize("k", [64, 128, 256])
|
||||
@pytest.mark.parametrize("out_dtype", [torch.bfloat16, torch.float16])
|
||||
@pytest.mark.parametrize("use_bias", [True, False])
|
||||
@pytest.mark.parametrize("azp_per_token", [True, False])
|
||||
def test_cutlass_int8_azp(
|
||||
m: int, n: int, k: int, out_dtype: torch.dtype, use_bias: bool, azp_per_token: bool
|
||||
):
|
||||
m_azp = m if azp_per_token else 1
|
||||
scale_a = torch.randn((m_azp, 1), device="cuda", dtype=torch.float32) / 10
|
||||
scale_b = torch.randn((1, n), device="cuda", dtype=torch.float32) / 10
|
||||
|
||||
aq_i8 = rand_int8((m, k))
|
||||
aq_i32 = aq_i8.to(dtype=torch.int32)
|
||||
aq_f32 = aq_i8.to(dtype=torch.float32)
|
||||
|
||||
bq_i8 = rand_int8((n, k)).t()
|
||||
bq_i32 = bq_i8.to(dtype=torch.int32)
|
||||
bq_f32 = bq_i8.to(dtype=torch.float32)
|
||||
b_dq = scale_b * bq_f32
|
||||
|
||||
azp_a = torch.rand((m_azp, 1), device="cuda", dtype=torch.float32) * 10 + 1.5
|
||||
azp_aq_i8 = (azp_a / scale_a).to(dtype=torch.int8)
|
||||
azp_a = azp_aq_i8.to(dtype=torch.float32) * scale_a # correct for rounding
|
||||
|
||||
a_dq = scale_a * (aq_i32 - azp_aq_i8).to(dtype=torch.float32)
|
||||
torch.testing.assert_close(a_dq, scale_a * aq_f32 - azp_a, rtol=1e-4, atol=1e-3)
|
||||
|
||||
if use_bias:
|
||||
bias = torch.rand((1, n), device="cuda", dtype=out_dtype) * 10 + 2.5
|
||||
else:
|
||||
bias = torch.zeros((1, n), device="cuda", dtype=out_dtype)
|
||||
|
||||
baseline_dq = (torch.mm(a_dq, b_dq) + bias).to(out_dtype)
|
||||
|
||||
# int32 mm not supported on CUDA
|
||||
a_noazp_i32_cpu = (aq_i32 - azp_aq_i8).to(device="cpu")
|
||||
cq = (a_noazp_i32_cpu @ bq_i32.to(device="cpu")).to(device="cuda")
|
||||
baseline_q = (scale_a * scale_b * cq + bias).to(dtype=out_dtype)
|
||||
|
||||
# Hadamard is just the sum of the cols
|
||||
azp_adj_i32 = bq_i32.sum(dim=0, keepdim=True, dtype=torch.int32)
|
||||
azp_i32 = azp_aq_i8.to(dtype=torch.int32)
|
||||
func_bias = bias if use_bias else None
|
||||
|
||||
if azp_per_token:
|
||||
out = ops.cutlass_scaled_mm_azp(
|
||||
aq_i8, bq_i8, scale_a, scale_b, out_dtype, azp_adj_i32, azp_i32, func_bias
|
||||
)
|
||||
else:
|
||||
azp_with_adj_i32 = azp_i32 * azp_adj_i32
|
||||
out = ops.cutlass_scaled_mm_azp(
|
||||
aq_i8, bq_i8, scale_a, scale_b, out_dtype, azp_with_adj_i32, None, func_bias
|
||||
)
|
||||
|
||||
# bfloat16 precision is 7-bit mantissa -> 2^-8 ~ 0.4%
|
||||
# float16 precision is 10-bit mantissa -> 2^-11 ~ 0.05%
|
||||
rtol = 1e-2 if out_dtype == torch.bfloat16 else 1e-3
|
||||
atol = 1e-3
|
||||
torch.testing.assert_close(out, baseline_dq, rtol=rtol, atol=atol)
|
||||
torch.testing.assert_close(out, baseline_q, rtol=rtol, atol=atol)
|
||||
|
||||
if azp_per_token:
|
||||
opcheck(
|
||||
torch.ops._C.cutlass_scaled_mm_azp,
|
||||
(out, aq_i8, bq_i8, scale_a, scale_b, azp_adj_i32, azp_i32, func_bias),
|
||||
)
|
||||
else:
|
||||
opcheck(
|
||||
torch.ops._C.cutlass_scaled_mm_azp,
|
||||
(out, aq_i8, bq_i8, scale_a, scale_b, azp_with_adj_i32, None, func_bias),
|
||||
)
|
||||
|
||||
|
||||
# Test working with a subset of A and B
|
||||
def test_cutlass_subset():
|
||||
big_m, big_n, big_k = 1024, 1024, 1024
|
||||
m, n, k = 512, 512, 512
|
||||
|
||||
whole_a = to_int8(torch.randn((big_m, big_k), device="cuda") * 5)
|
||||
whole_b = to_int8(torch.randn((big_n, big_k), device="cuda").t() * 5)
|
||||
a = whole_a[0:m, 0:k]
|
||||
b = whole_b[0:k, 0:n]
|
||||
|
||||
scale_a = torch.randn((1, 1), device="cuda", dtype=torch.float32) / 10
|
||||
scale_b = torch.randn((1, 1), device="cuda", dtype=torch.float32) / 10
|
||||
|
||||
out = ops.cutlass_scaled_mm(a, b, scale_a, scale_b, out_dtype=torch.bfloat16)
|
||||
baseline = baseline_scaled_mm(a, b, scale_a, scale_b, out_dtype=torch.bfloat16)
|
||||
|
||||
torch.testing.assert_close(out, baseline, rtol=1e-1, atol=1e0)
|
||||
|
||||
|
||||
# Test to make sure cuda graphs work
|
||||
class CutlassLayer(torch.nn.Module):
|
||||
def __init__(self, b, scale_a, scale_b, out_dtype):
|
||||
super().__init__()
|
||||
self.b = b
|
||||
self.scale_a = scale_a
|
||||
self.scale_b = scale_b
|
||||
self.out_dtype = out_dtype
|
||||
|
||||
def forward(self, a):
|
||||
return ops.cutlass_scaled_mm(
|
||||
a, self.b, self.scale_a, self.scale_b, self.out_dtype
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("per_act_token", [True, False])
|
||||
@pytest.mark.parametrize("per_out_ch", [True, False])
|
||||
def test_cutlass_cuda_graph(per_act_token: bool, per_out_ch: bool):
|
||||
m, n, k = 512, 512, 512
|
||||
|
||||
a = to_int8(torch.randn((m, k), device="cuda"))
|
||||
b = to_int8(torch.randn((n, k), device="cuda").t())
|
||||
|
||||
m_a_scales = m if per_act_token else 1
|
||||
n_b_scales = n if per_out_ch else 1
|
||||
|
||||
scale_a = torch.randn((m_a_scales, 1), device="cuda", dtype=torch.float32) / 10
|
||||
scale_b = torch.randn((1, n_b_scales), device="cuda", dtype=torch.float32) / 10
|
||||
|
||||
# Construct a trivial model with a single layer that calls a CUTLASS kernel
|
||||
model = CutlassLayer(b, scale_a, scale_b, torch.bfloat16)
|
||||
|
||||
# Run the model with a cuda graph
|
||||
stream = torch.cuda.Stream()
|
||||
with torch.cuda.stream(stream):
|
||||
g = torch.cuda.CUDAGraph()
|
||||
with torch.cuda.graph(g):
|
||||
out = model(a)
|
||||
out.zero_()
|
||||
g.replay()
|
||||
|
||||
baseline = torch.mm(
|
||||
scale_a * a.to(dtype=torch.float32), scale_b * b.to(dtype=torch.float32)
|
||||
).to(torch.bfloat16)
|
||||
torch.testing.assert_close(out, baseline, rtol=1e-1, atol=1e0)
|
||||
|
||||
|
||||
def test_cutlass_support_opcheck():
|
||||
opcheck(torch.ops._C.cutlass_scaled_mm_supports_fp8, (capability,))
|
||||
|
||||
|
||||
@pytest.mark.parametrize("num_experts", [8, 64])
|
||||
@pytest.mark.parametrize("per_act_token", [True, False])
|
||||
@pytest.mark.parametrize("per_out_ch", [True, False])
|
||||
@pytest.mark.parametrize("use_bias", [False])
|
||||
@pytest.mark.skipif(
|
||||
(lambda x: x is None or not ops.cutlass_group_gemm_supported(x.to_int()))(
|
||||
current_platform.get_device_capability()
|
||||
),
|
||||
reason="Grouped gemm is not supported on this GPU type.",
|
||||
)
|
||||
def test_cutlass_fp8_group_gemm(
|
||||
num_experts: int, per_act_token: bool, per_out_ch: bool, use_bias: bool
|
||||
):
|
||||
# Device and dtype setup
|
||||
device = "cuda"
|
||||
out_dtype = torch.half
|
||||
|
||||
# Create separate A, B, C tensors for each group
|
||||
a_tensors = []
|
||||
b_tensors = []
|
||||
a_scales_tensors = []
|
||||
b_scales_tensors = []
|
||||
baseline_tensors = []
|
||||
|
||||
expert_offsets = torch.zeros((num_experts + 1), device=device, dtype=torch.int64)
|
||||
|
||||
problem_sizes = torch.zeros((num_experts, 3), device=device, dtype=torch.int32)
|
||||
|
||||
if not per_act_token:
|
||||
one_scale_a = torch.randn((1, 1), device=device, dtype=torch.float32)
|
||||
|
||||
alignment = 16 # 128 // 8
|
||||
# For variation, each group has dimensions
|
||||
n_g = alignment * random.randint(1, 64)
|
||||
k_g = alignment * random.randint(1, 64)
|
||||
for g in range(num_experts):
|
||||
m_g = alignment * random.randint(1, 64)
|
||||
|
||||
expert_offsets[g + 1] = expert_offsets[g] + m_g
|
||||
problem_sizes[g][0] = m_g
|
||||
problem_sizes[g][1] = n_g
|
||||
problem_sizes[g][2] = k_g
|
||||
|
||||
m_a_scales = m_g if per_act_token else 1
|
||||
n_b_scales = n_g if per_out_ch else 1
|
||||
|
||||
# Create group-specific A and B (FP8) and output (FP16/FP32)
|
||||
a_g = to_fp8(torch.randn((m_g, k_g), device=device))
|
||||
b_g = to_fp8(torch.randn((n_g, k_g), device=device).t())
|
||||
a_tensors.append(a_g)
|
||||
b_tensors.append(b_g)
|
||||
|
||||
# Set up A/B scales
|
||||
scale_b = torch.randn((1, n_b_scales), device=device, dtype=torch.float32)
|
||||
b_scales_tensors.append(scale_b)
|
||||
|
||||
if per_act_token:
|
||||
scale_a = torch.randn((m_a_scales, 1), device=device, dtype=torch.float32)
|
||||
a_scales_tensors.append(scale_a)
|
||||
else:
|
||||
scale_a = one_scale_a
|
||||
|
||||
# Compute baseline result for this group
|
||||
baseline_g = baseline_scaled_mm(a_g, b_g, scale_a, scale_b, out_dtype, None)
|
||||
baseline_tensors.append(baseline_g)
|
||||
|
||||
a_tensors_stacked = torch.empty(
|
||||
(expert_offsets[num_experts], k_g), device=device, dtype=torch.float8_e4m3fn
|
||||
)
|
||||
b_tensors_stacked = torch.empty(
|
||||
(num_experts, n_g, k_g), device=device, dtype=torch.float8_e4m3fn
|
||||
)
|
||||
|
||||
for g in range(num_experts):
|
||||
a_tensors_stacked[expert_offsets[g] : expert_offsets[g + 1]] = a_tensors[g]
|
||||
b_tensors_stacked[g] = b_tensors[g].t()
|
||||
b_tensors_stacked = b_tensors_stacked.transpose(1, 2)
|
||||
|
||||
if per_act_token:
|
||||
a_scales_tensors_stacked = torch.empty(
|
||||
(expert_offsets[num_experts], 1), device=device, dtype=torch.float32
|
||||
)
|
||||
for g in range(num_experts):
|
||||
a_scales_tensors_stacked[expert_offsets[g] : expert_offsets[g + 1]] = (
|
||||
a_scales_tensors[g]
|
||||
)
|
||||
else:
|
||||
a_scales_tensors_stacked = one_scale_a
|
||||
|
||||
b_scales_tensors_stacked = torch.empty(
|
||||
(num_experts, n_b_scales), device=device, dtype=torch.float32
|
||||
)
|
||||
for g in range(num_experts):
|
||||
b_scales_tensors_stacked[g] = b_scales_tensors[g]
|
||||
|
||||
out_tensors_stacked = torch.zeros(
|
||||
(expert_offsets[num_experts], n_g), device=device, dtype=out_dtype
|
||||
)
|
||||
|
||||
ab_strides = torch.full(
|
||||
(num_experts,), a_tensors_stacked.stride(0), device="cuda", dtype=torch.int64
|
||||
)
|
||||
c_strides = torch.full(
|
||||
(num_experts,), out_tensors_stacked.stride(0), device="cuda", dtype=torch.int64
|
||||
)
|
||||
|
||||
ops.cutlass_moe_mm(
|
||||
out_tensors_stacked,
|
||||
a_tensors_stacked,
|
||||
b_tensors_stacked,
|
||||
a_scales_tensors_stacked,
|
||||
b_scales_tensors_stacked,
|
||||
expert_offsets[:-1],
|
||||
problem_sizes,
|
||||
ab_strides,
|
||||
ab_strides,
|
||||
c_strides,
|
||||
per_act_token,
|
||||
per_out_ch,
|
||||
)
|
||||
|
||||
# Validate each group's result against the baseline
|
||||
for g in range(num_experts):
|
||||
baseline = baseline_tensors[g]
|
||||
c = out_tensors_stacked[expert_offsets[g] : expert_offsets[g + 1]]
|
||||
torch.testing.assert_close(c, baseline, rtol=1e-2, atol=5e-4)
|
||||
329
tests/kernels/quantization/test_cutlass_w4a8.py
Normal file
329
tests/kernels/quantization/test_cutlass_w4a8.py
Normal file
@@ -0,0 +1,329 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""Tests for the CUTLASS W4A8 kernel.
|
||||
|
||||
Run `pytest tests/kernels/quantization/test_cutlass_w4a8.py`.
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
convert_packed_uint4b8_to_signed_int4_inplace,
|
||||
pack_cols,
|
||||
pack_rows,
|
||||
quantize_weights,
|
||||
unpack_quantized_values_into_int32,
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.scalar_type import ScalarType, scalar_types
|
||||
|
||||
if not current_platform.is_cuda():
|
||||
pytest.skip("These tests use CUTLASS which requires CUDA", allow_module_level=True)
|
||||
|
||||
# TODO: in future PR refactor this and `is_quant_method_supported` in the kernel
|
||||
# unit tests to a common utility function. Currently the use of
|
||||
# `is_quant_method_supported` conflates kernels with quantization methods
|
||||
# an assumption which is breaking down as quantizations methods can have
|
||||
# have kernels and some kernels support multiple quantization methods.
|
||||
IS_SUPPORTED_BY_GPU = current_platform.get_device_capability()[0] >= 9
|
||||
|
||||
MNK_SHAPES = [
|
||||
(1, 128, 128),
|
||||
(1, 512, 1024),
|
||||
(1, 4096, 4096),
|
||||
(1, 8192, 28672),
|
||||
(13, 8192, 4096),
|
||||
(26, 4096, 8192),
|
||||
(64, 4096, 4096),
|
||||
(64, 8192, 28672),
|
||||
(257, 128, 4096),
|
||||
(257, 4096, 4096),
|
||||
(1024, 4096, 8192),
|
||||
(1024, 8192, 4096),
|
||||
]
|
||||
|
||||
# TODO(czhu): get supported schedules from fn
|
||||
SCHEDULES = [
|
||||
"128x16_1x1x1",
|
||||
"256x16_1x1x1",
|
||||
"128x32_1x1x1",
|
||||
"256x32_1x1x1",
|
||||
"128x64_1x1x1",
|
||||
"256x64_1x1x1",
|
||||
"128x128_1x1x1",
|
||||
"256x128_1x1x1",
|
||||
"128x256_1x1x1",
|
||||
"128x256_2x1x1",
|
||||
]
|
||||
|
||||
|
||||
@dataclass
|
||||
class TypeConfig:
|
||||
act_type: torch.dtype
|
||||
weight_type: ScalarType
|
||||
output_type: torch.dtype | None
|
||||
group_scale_type: torch.dtype | None
|
||||
channel_scale_type: torch.dtype | None
|
||||
token_scale_type: torch.dtype | None
|
||||
|
||||
|
||||
@dataclass
|
||||
class Tensors:
|
||||
w_ref: torch.Tensor
|
||||
a_ref: torch.Tensor
|
||||
a: torch.Tensor
|
||||
w_q: torch.Tensor
|
||||
w_g_s: torch.Tensor
|
||||
w_ch_s: torch.Tensor
|
||||
w_tok_s: torch.Tensor
|
||||
|
||||
|
||||
# (Act Type, Weight Type, Output Type, Scale Type, ZeroPoints,
|
||||
# Ch Scales Type, Tok Scales Type)
|
||||
TestTypeTuple = tuple[
|
||||
list[torch.dtype], ScalarType, torch.dtype | None, torch.dtype | None, bool
|
||||
]
|
||||
TEST_TYPES = [
|
||||
*(
|
||||
TypeConfig(
|
||||
act_type=torch.float8_e4m3fn,
|
||||
weight_type=w_type,
|
||||
output_type=o_type,
|
||||
group_scale_type=torch.float8_e4m3fn,
|
||||
channel_scale_type=torch.float32,
|
||||
token_scale_type=torch.float32,
|
||||
)
|
||||
for w_type in [scalar_types.int4]
|
||||
# TODO(czhu): fp16 out type
|
||||
for o_type in [torch.bfloat16]
|
||||
),
|
||||
]
|
||||
|
||||
# TODO: in future PR refactor this and `is_quant_method_supported` in the kernel
|
||||
# unit tests to a common utility function. Currently the use of
|
||||
# `is_quant_method_supported` conflates kernels with quantization methods
|
||||
# an assumption which is breaking down as quantizations methods can have
|
||||
# have kernels and some kernels support multiple quantization methods.
|
||||
IS_SUPPORTED_BY_GPU = current_platform.has_device_capability(90)
|
||||
|
||||
|
||||
# For testing quantized linear kernels
|
||||
def to_fp8(tensor: torch.Tensor):
|
||||
finfo = torch.finfo(torch.float8_e4m3fn)
|
||||
return tensor.clamp(min=finfo.min, max=finfo.max).to(dtype=torch.float8_e4m3fn)
|
||||
|
||||
|
||||
def cutlass_quantize_and_pack(
|
||||
atype: torch.dtype,
|
||||
w: torch.Tensor,
|
||||
wtype: ScalarType,
|
||||
stype: torch.dtype | None,
|
||||
group_size: int | None,
|
||||
zero_points: bool = False,
|
||||
):
|
||||
assert wtype.is_integer(), "TODO: support floating point weights"
|
||||
|
||||
w_ref, w_q, w_s, w_zp = quantize_weights(
|
||||
w, wtype, group_size=group_size, zero_points=zero_points
|
||||
)
|
||||
|
||||
# since scales are cast to fp8, we need to compute w_ref this way
|
||||
w_ref = (
|
||||
(w_q).to(torch.float32)
|
||||
* w_s.to(atype).to(torch.float32).repeat_interleave(group_size, dim=0)
|
||||
).to(atype)
|
||||
|
||||
# bit mask prevents sign extending int4 when packing
|
||||
w_q = pack_rows(w_q & 0x0F, wtype.size_bits, *w_q.shape)
|
||||
w_q = w_q.t().contiguous().t() # convert to col major
|
||||
|
||||
w_q_packed = ops.cutlass_encode_and_reorder_int4b(w_q)
|
||||
w_s_packed = ops.cutlass_pack_scale_fp8(w_s.to(atype))
|
||||
|
||||
return w_ref, w_q_packed, w_s_packed, w_zp
|
||||
|
||||
|
||||
def create_test_tensors(
|
||||
shape: tuple[int, int, int], types: TypeConfig, group_size: int | None
|
||||
) -> Tensors:
|
||||
m, n, k = shape
|
||||
|
||||
print(
|
||||
"create_test_tensors, shape:", shape, "types:", types, "group_size:", group_size
|
||||
)
|
||||
|
||||
a = to_fp8(torch.randn((m, k), device="cuda"))
|
||||
w = to_fp8(torch.randn((k, n), device="cuda"))
|
||||
|
||||
if types.group_scale_type is not None:
|
||||
w = w.to(types.group_scale_type)
|
||||
if w.dtype.itemsize == 1:
|
||||
w = w.to(torch.float16)
|
||||
|
||||
w_ref, w_q_packed, w_s, _ = cutlass_quantize_and_pack(
|
||||
a.dtype, w, types.weight_type, types.group_scale_type, group_size, False
|
||||
)
|
||||
|
||||
a_ref = a.to(torch.float32)
|
||||
w_ref = w_ref.to(torch.float32)
|
||||
|
||||
# for the practical use case we need per-tok scales for fp8 activations
|
||||
w_tok_s = torch.randn((m,), device="cuda", dtype=types.token_scale_type)
|
||||
w_ch_s = torch.randn((n,), device="cuda", dtype=types.channel_scale_type)
|
||||
|
||||
return Tensors(
|
||||
w_ref=w_ref,
|
||||
a_ref=a_ref,
|
||||
a=a,
|
||||
w_q=w_q_packed,
|
||||
w_g_s=w_s,
|
||||
w_ch_s=w_ch_s,
|
||||
w_tok_s=w_tok_s,
|
||||
)
|
||||
|
||||
|
||||
def mm_test_helper(
|
||||
types: TypeConfig,
|
||||
tensors: Tensors,
|
||||
group_size: int | None = None,
|
||||
schedule: str | None = None,
|
||||
):
|
||||
# CUTLASS upstream uses fp8 with fastaccum as reference
|
||||
# https://github.com/NVIDIA/cutlass/blob/main/examples/55_hopper_mixed_dtype_gemm/55_hopper_int4_fp8_gemm.cu#L406
|
||||
output_ref = torch._scaled_mm(
|
||||
tensors.a_ref.to(types.act_type),
|
||||
tensors.w_ref.to(types.act_type).t().contiguous().t(), # col major
|
||||
tensors.w_tok_s.unsqueeze(1),
|
||||
tensors.w_ch_s.unsqueeze(0),
|
||||
out_dtype=types.output_type,
|
||||
use_fast_accum=True,
|
||||
)
|
||||
|
||||
output = ops.cutlass_w4a8_mm(
|
||||
a=tensors.a,
|
||||
b_q=tensors.w_q,
|
||||
b_group_scales=tensors.w_g_s,
|
||||
b_group_size=group_size,
|
||||
b_channel_scales=tensors.w_ch_s,
|
||||
a_token_scales=tensors.w_tok_s,
|
||||
)
|
||||
|
||||
print(output)
|
||||
print(output_ref)
|
||||
|
||||
torch.testing.assert_close(
|
||||
output, output_ref.to(output.dtype), rtol=1e-2, atol=1e-2
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
not IS_SUPPORTED_BY_GPU, reason="CUTLASS W4A8 is not supported on this GPU type."
|
||||
)
|
||||
@pytest.mark.parametrize("shape", MNK_SHAPES, ids=lambda x: "x".join(str(v) for v in x))
|
||||
@pytest.mark.parametrize("types", TEST_TYPES)
|
||||
@pytest.mark.parametrize("schedule", SCHEDULES)
|
||||
def test_cutlass_w4a8(shape, types: TypeConfig, schedule):
|
||||
group_sizes = [128]
|
||||
for group_size in group_sizes:
|
||||
tensors = create_test_tensors(shape, types, group_size)
|
||||
mm_test_helper(types, tensors, group_size, schedule)
|
||||
|
||||
|
||||
# Test to make sure cuda graphs work
|
||||
class W4A8Layer(torch.nn.Module):
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__()
|
||||
self.kwargs = kwargs
|
||||
|
||||
def forward(self, a):
|
||||
return ops.cutlass_w4a8_mm(a=a, **self.kwargs)
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
not IS_SUPPORTED_BY_GPU, reason="CUTLASS W4A8 is not supported on this GPU type."
|
||||
)
|
||||
def test_w4a8_cuda_graph():
|
||||
m, n, k = 512, 4096, 4096
|
||||
|
||||
a = to_fp8(torch.randn((m, k), device="cuda"))
|
||||
b = to_fp8(torch.randn((k, n), device="cuda"))
|
||||
|
||||
wtype = scalar_types.int4
|
||||
stype = torch.float8_e4m3fn
|
||||
group_size = 128
|
||||
zero_points = False
|
||||
|
||||
w_ref, w_q_packed, w_s, _ = cutlass_quantize_and_pack(
|
||||
a.dtype, b.to(torch.float16), wtype, stype, group_size, zero_points
|
||||
)
|
||||
|
||||
w_tok_s = torch.randn((m,), device="cuda", dtype=torch.float32)
|
||||
w_ch_s = torch.randn((n,), device="cuda", dtype=torch.float32)
|
||||
|
||||
# Construct a trivial model with a single layer that calls the kernel
|
||||
model = W4A8Layer(
|
||||
b_q=w_q_packed,
|
||||
b_group_scales=w_s,
|
||||
b_group_size=group_size,
|
||||
b_channel_scales=w_ch_s,
|
||||
a_token_scales=w_tok_s,
|
||||
)
|
||||
|
||||
output_ref = torch._scaled_mm(
|
||||
a,
|
||||
w_ref.to(a.dtype).t().contiguous().t(), # col major
|
||||
w_tok_s.unsqueeze(1),
|
||||
w_ch_s.unsqueeze(0),
|
||||
out_dtype=torch.bfloat16,
|
||||
use_fast_accum=True,
|
||||
)
|
||||
|
||||
# Run the model with a cuda graph
|
||||
stream = torch.cuda.Stream()
|
||||
with torch.cuda.stream(stream):
|
||||
g = torch.cuda.CUDAGraph()
|
||||
with torch.cuda.graph(g):
|
||||
output = model(a)
|
||||
|
||||
output.zero_()
|
||||
g.replay()
|
||||
|
||||
torch.testing.assert_close(output, output_ref, rtol=1e-2, atol=1e-2)
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
not IS_SUPPORTED_BY_GPU, reason="CUTLASS W4A8 is not supported on this GPU type."
|
||||
)
|
||||
@pytest.mark.parametrize("shape", MNK_SHAPES)
|
||||
def test_convert_packed_uint4b8_to_signed_int4_inplace(shape):
|
||||
"""
|
||||
The W4A16 checkpoints encode the weights as int4b8 packed to int32.
|
||||
The CUTLASS kernels expect signed int4 packed to int32.
|
||||
This tests checks that the runtime int4b8 -> signed int4 conversion
|
||||
matches the offline conversion step exactly.
|
||||
"""
|
||||
_, N, K = shape
|
||||
# random weights packed to int32
|
||||
t = torch.randint(
|
||||
low=torch.iinfo(torch.int32).min,
|
||||
high=torch.iinfo(torch.int32).max + 1,
|
||||
size=(N, K // 8),
|
||||
dtype=torch.int32,
|
||||
device="cuda",
|
||||
)
|
||||
|
||||
# compute reference
|
||||
unpacked = unpack_quantized_values_into_int32(
|
||||
t.clone(), scalar_types.uint4b8, packed_dim=1
|
||||
)
|
||||
unpacked = unpacked - 8 # int4b8 -> signed int4
|
||||
ref = pack_cols(unpacked & 0x0F, 4, *unpacked.shape)
|
||||
|
||||
out = convert_packed_uint4b8_to_signed_int4_inplace(t.clone())
|
||||
|
||||
assert torch.equal(ref, out)
|
||||
assert not torch.equal(ref, t)
|
||||
342
tests/kernels/quantization/test_cutlass_w4a8_moe.py
Normal file
342
tests/kernels/quantization/test_cutlass_w4a8_moe.py
Normal file
@@ -0,0 +1,342 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""
|
||||
Tests for the CUTLASS-based W4A8 grouped GEMM kernel and the full MoE layer.
|
||||
"""
|
||||
|
||||
import random
|
||||
from dataclasses import dataclass
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
pack_rows,
|
||||
quantize_weights,
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.scalar_type import ScalarType, scalar_types
|
||||
|
||||
IS_SUPPORTED_BY_GPU = (
|
||||
current_platform.is_cuda() and current_platform.get_device_capability()[0] >= 9
|
||||
)
|
||||
|
||||
|
||||
def to_fp8(tensor: torch.Tensor) -> torch.Tensor:
|
||||
finfo = torch.finfo(torch.float8_e4m3fn)
|
||||
return tensor.clamp(min=finfo.min, max=finfo.max).to(dtype=torch.float8_e4m3fn)
|
||||
|
||||
|
||||
def cutlass_quantize(
|
||||
atype: torch.dtype,
|
||||
w: torch.Tensor,
|
||||
wtype: ScalarType,
|
||||
stype: torch.dtype | None,
|
||||
group_size: int | None,
|
||||
zero_points: bool = False,
|
||||
):
|
||||
"""
|
||||
Quantize weights into W4 and compute reference dequantized weights.
|
||||
|
||||
Encoding/reordering of weights and packing of scales is deferred
|
||||
until after all experts are combined.
|
||||
"""
|
||||
assert wtype.is_integer(), "TODO: support floating point weights"
|
||||
|
||||
w_ref, w_q, w_s, w_zp = quantize_weights(
|
||||
w, wtype, group_size=group_size, zero_points=zero_points
|
||||
)
|
||||
|
||||
# Since scales are later cast to fp8, recompute w_ref in atype here.
|
||||
w_ref = (
|
||||
w_q.to(torch.float32)
|
||||
* w_s.to(atype).to(torch.float32).repeat_interleave(group_size, dim=0)
|
||||
).to(atype)
|
||||
|
||||
# Bit mask prevents sign extension of int4 when packing.
|
||||
w_q = pack_rows(w_q & 0x0F, wtype.size_bits, *w_q.shape)
|
||||
# Make weights row-major (N, K).
|
||||
w_q = w_q.t().contiguous()
|
||||
|
||||
return w_ref, w_q, w_s.to(atype), w_zp
|
||||
|
||||
|
||||
def cutlass_preprocess(
|
||||
w_q_experts: list[torch.Tensor], w_s_experts: list[torch.Tensor]
|
||||
):
|
||||
"""
|
||||
Reorder/encode expert weights and pack scales.
|
||||
|
||||
Returns:
|
||||
w_q_packed: Packed/encoded int4 weights for all experts.
|
||||
w_s_packed: Packed fp8 scales for all experts.
|
||||
packed_layout: Layout/stride metadata for grouped GEMM.
|
||||
"""
|
||||
w_s_packed = ops.cutlass_pack_scale_fp8(torch.stack(w_s_experts))
|
||||
w_q_packed, packed_layout = ops.cutlass_encode_and_reorder_int4b_grouped(
|
||||
torch.stack(w_q_experts)
|
||||
) # expects dim 3
|
||||
return w_q_packed, w_s_packed, packed_layout
|
||||
|
||||
|
||||
GROUP_SIZE = 128
|
||||
# (num_experts, N, K)
|
||||
TEST_SHAPES = [
|
||||
(8, 512, 2048),
|
||||
(8, 2048, 2048),
|
||||
(64, 512, 1024),
|
||||
(64, 2048, 2048),
|
||||
(4, 2048, 768),
|
||||
(8, 768, 2048),
|
||||
(64, 1536, 2048),
|
||||
(128, 8192, 4096), # test overflow int32
|
||||
]
|
||||
ALIGNMENT = 16 # torch._scaled_mm alignment for M, needed for reference check
|
||||
|
||||
|
||||
@dataclass
|
||||
class MoETestSetup:
|
||||
num_experts: int
|
||||
K: int
|
||||
N: int
|
||||
Ms: list[int]
|
||||
M_full: int
|
||||
a: torch.Tensor
|
||||
a_ref: torch.Tensor
|
||||
a_strides: torch.Tensor
|
||||
out: torch.Tensor
|
||||
c_strides: torch.Tensor
|
||||
per_tok_scales: torch.Tensor
|
||||
per_chan_scales: torch.Tensor
|
||||
w_refs: list[torch.Tensor]
|
||||
w_q_packed: torch.Tensor
|
||||
w_s_packed: torch.Tensor
|
||||
problem_sizes: torch.Tensor
|
||||
expert_offsets: torch.Tensor
|
||||
b_strides: torch.Tensor
|
||||
group_scale_strides: torch.Tensor
|
||||
|
||||
|
||||
def make_moe_test_setup(
|
||||
num_experts: int,
|
||||
K: int,
|
||||
N: int,
|
||||
*,
|
||||
alignment: int = ALIGNMENT,
|
||||
max_blocks: int = 64,
|
||||
device: str = "cuda",
|
||||
random_zero: bool = False,
|
||||
) -> MoETestSetup:
|
||||
"""Create a full set of tensors for testing cutlass_w4a8_moe_mm."""
|
||||
|
||||
assert K % GROUP_SIZE == 0
|
||||
# Token counts per expert (multiples of `alignment`).
|
||||
Ms = [alignment * random.randint(1, max_blocks) for _ in range(num_experts)]
|
||||
|
||||
# set random experts to 0 tokens
|
||||
if random_zero and num_experts > 1:
|
||||
num_zero = max(1, num_experts // 8)
|
||||
zero_indices = random.sample(range(num_experts), k=num_zero)
|
||||
for idx in zero_indices:
|
||||
Ms[idx] = 0
|
||||
|
||||
M_full = sum(Ms)
|
||||
assert M_full > 0
|
||||
|
||||
# Activations.
|
||||
a = to_fp8(torch.randn((M_full, K), device=device))
|
||||
a_ref = a.to(torch.float32)
|
||||
a_strides = torch.full((num_experts,), K, dtype=torch.int64, device=device)
|
||||
|
||||
# Output buffer.
|
||||
out = torch.empty((M_full, N), dtype=torch.bfloat16, device=device)
|
||||
c_strides = torch.full((num_experts,), N, dtype=torch.int64, device=device)
|
||||
|
||||
# Channel/token scales.
|
||||
per_tok_scales = torch.randn((M_full, 1), dtype=torch.float32, device=device)
|
||||
per_chan_scales = torch.randn(
|
||||
(num_experts, N, 1), dtype=torch.float32, device=device
|
||||
)
|
||||
|
||||
# Expert weights and scales.
|
||||
wtype = scalar_types.int4
|
||||
atype = stype = torch.float8_e4m3fn
|
||||
w_refs, w_qs, w_ss = [], [], []
|
||||
for _ in range(num_experts):
|
||||
b = to_fp8(torch.randn((K, N), device=device))
|
||||
w_ref, w_q, w_s, _ = cutlass_quantize(
|
||||
atype, b.to(torch.float16), wtype, stype, GROUP_SIZE, zero_points=False
|
||||
)
|
||||
w_refs.append(w_ref)
|
||||
w_qs.append(w_q)
|
||||
w_ss.append(w_s)
|
||||
|
||||
w_q_packed, w_s_packed, packed_layout = cutlass_preprocess(w_qs, w_ss)
|
||||
|
||||
problem_sizes = torch.tensor(
|
||||
[[N, M, K] for M in Ms], dtype=torch.int32, device=device
|
||||
)
|
||||
|
||||
expert_offsets = torch.cat(
|
||||
[
|
||||
torch.tensor([0], dtype=torch.int64),
|
||||
torch.cumsum(torch.tensor(Ms, dtype=torch.int64), dim=0)[:-1],
|
||||
]
|
||||
).to(device=device)
|
||||
|
||||
# B strides and group scale strides.
|
||||
b_strides = packed_layout
|
||||
group_scale_strides = torch.zeros(
|
||||
(num_experts, 2), dtype=torch.int64, device=device
|
||||
)
|
||||
group_scale_strides[:, 0] = N
|
||||
|
||||
return MoETestSetup(
|
||||
num_experts=num_experts,
|
||||
K=K,
|
||||
N=N,
|
||||
Ms=Ms,
|
||||
M_full=M_full,
|
||||
a=a,
|
||||
a_ref=a_ref,
|
||||
a_strides=a_strides,
|
||||
out=out,
|
||||
c_strides=c_strides,
|
||||
per_tok_scales=per_tok_scales,
|
||||
per_chan_scales=per_chan_scales,
|
||||
w_refs=w_refs,
|
||||
w_q_packed=w_q_packed,
|
||||
w_s_packed=w_s_packed,
|
||||
problem_sizes=problem_sizes,
|
||||
expert_offsets=expert_offsets,
|
||||
b_strides=b_strides,
|
||||
group_scale_strides=group_scale_strides,
|
||||
)
|
||||
|
||||
|
||||
def compute_moe_reference_output(setup: MoETestSetup) -> torch.Tensor:
|
||||
"""Compute reference output using torch._scaled_mm per expert."""
|
||||
out_ref = torch.empty_like(setup.out)
|
||||
|
||||
ends = torch.cumsum(torch.tensor(setup.Ms), 0).tolist()
|
||||
starts = setup.expert_offsets.cpu().tolist()
|
||||
|
||||
for i in range(setup.num_experts):
|
||||
start, end = starts[i], ends[i]
|
||||
if start == end:
|
||||
continue
|
||||
|
||||
out_ref_i = torch._scaled_mm(
|
||||
setup.a_ref[start:end].to(torch.float8_e4m3fn),
|
||||
setup.w_refs[i].to(torch.float8_e4m3fn).t().contiguous().t(),
|
||||
setup.per_tok_scales[start:end], # (M, 1)
|
||||
setup.per_chan_scales[i].reshape(1, -1), # (1, N)
|
||||
out_dtype=torch.bfloat16,
|
||||
use_fast_accum=True,
|
||||
)
|
||||
out_ref[start:end] = out_ref_i
|
||||
|
||||
return out_ref
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
not IS_SUPPORTED_BY_GPU,
|
||||
reason="W4A8 Grouped GEMM is not supported on this GPU type.",
|
||||
)
|
||||
@pytest.mark.parametrize("shape", TEST_SHAPES)
|
||||
@pytest.mark.parametrize("random_zero", [True, False])
|
||||
def test_cutlass_w4a8_moe_mm_end_to_end(shape, random_zero):
|
||||
num_experts, N, K = shape
|
||||
current_platform.seed_everything(42)
|
||||
setup = make_moe_test_setup(
|
||||
num_experts=num_experts, K=K, N=N, max_blocks=64, random_zero=random_zero
|
||||
)
|
||||
|
||||
ops.cutlass_w4a8_moe_mm(
|
||||
setup.out,
|
||||
setup.a,
|
||||
setup.w_q_packed,
|
||||
setup.per_tok_scales,
|
||||
setup.per_chan_scales,
|
||||
setup.w_s_packed,
|
||||
GROUP_SIZE,
|
||||
setup.expert_offsets,
|
||||
setup.problem_sizes,
|
||||
setup.a_strides,
|
||||
setup.b_strides,
|
||||
setup.c_strides,
|
||||
setup.group_scale_strides,
|
||||
)
|
||||
torch.cuda.synchronize()
|
||||
|
||||
out_ref = compute_moe_reference_output(setup)
|
||||
torch.testing.assert_close(setup.out, out_ref, rtol=1e-2, atol=1e-2)
|
||||
|
||||
|
||||
class W4A8MoELayer(torch.nn.Module):
|
||||
"""
|
||||
Minimal wrapper module to test cuda graphs
|
||||
"""
|
||||
|
||||
def __init__(self, setup: MoETestSetup):
|
||||
super().__init__()
|
||||
self.setup = setup
|
||||
|
||||
def forward(self, a: torch.Tensor) -> torch.Tensor:
|
||||
s = self.setup
|
||||
ops.cutlass_w4a8_moe_mm(
|
||||
s.out,
|
||||
a,
|
||||
s.w_q_packed,
|
||||
s.per_tok_scales,
|
||||
s.per_chan_scales,
|
||||
s.w_s_packed,
|
||||
GROUP_SIZE,
|
||||
s.expert_offsets,
|
||||
s.problem_sizes,
|
||||
s.a_strides,
|
||||
s.b_strides,
|
||||
s.c_strides,
|
||||
s.group_scale_strides,
|
||||
)
|
||||
return s.out
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
not IS_SUPPORTED_BY_GPU,
|
||||
reason="W4A8 Grouped GEMM is not supported on this GPU type.",
|
||||
)
|
||||
def test_cutlass_w4a8_moe_mm_cuda_graph():
|
||||
current_platform.seed_everything(42)
|
||||
# Fixed config for CUDA graph test (single parameter point).
|
||||
num_experts = 8
|
||||
K = 512
|
||||
N = 2048
|
||||
|
||||
setup = make_moe_test_setup(
|
||||
num_experts=num_experts,
|
||||
K=K,
|
||||
N=N,
|
||||
max_blocks=32,
|
||||
)
|
||||
|
||||
# Construct model that calls the grouped GEMM kernel.
|
||||
model = W4A8MoELayer(setup)
|
||||
|
||||
# Build reference output once.
|
||||
out_ref = compute_moe_reference_output(setup)
|
||||
|
||||
# Capture and run the model in a CUDA graph.
|
||||
a_static = setup.a.clone() # static input tensor for graph replay
|
||||
|
||||
stream = torch.cuda.Stream()
|
||||
with torch.cuda.stream(stream):
|
||||
g = torch.cuda.CUDAGraph()
|
||||
with torch.cuda.graph(g):
|
||||
out_static = model(a_static)
|
||||
|
||||
out_static.zero_()
|
||||
g.replay()
|
||||
|
||||
torch.testing.assert_close(out_static, out_ref, rtol=1e-2, atol=1e-2)
|
||||
139
tests/kernels/quantization/test_flashinfer_nvfp4_scaled_mm.py
Normal file
139
tests/kernels/quantization/test_flashinfer_nvfp4_scaled_mm.py
Normal file
@@ -0,0 +1,139 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import pytest
|
||||
import torch
|
||||
from nvfp4_utils import (
|
||||
FLOAT4_E2M1_MAX,
|
||||
FLOAT8_E4M3_MAX,
|
||||
convert_swizzled_to_linear,
|
||||
dequantize_nvfp4_to_dtype,
|
||||
)
|
||||
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.flashinfer import flashinfer_scaled_fp4_mm
|
||||
|
||||
if not current_platform.has_device_capability(100):
|
||||
pytest.skip(
|
||||
reason="Nvfp4 Requires compute capability of 10 or above.",
|
||||
allow_module_level=True,
|
||||
)
|
||||
|
||||
DTYPES = [torch.float16, torch.bfloat16]
|
||||
# m, n, k
|
||||
SHAPES = [(128, 128, 64), (128, 128, 128), (256, 128, 64), (128, 256, 128)]
|
||||
PAD_SHAPES = [(150, 128, 64), (128, 128, 96)]
|
||||
SHAPES.extend(PAD_SHAPES)
|
||||
|
||||
SEEDS = [42]
|
||||
CUDA_DEVICES = ["cuda:0"]
|
||||
|
||||
|
||||
def get_ref_results(
|
||||
a_fp4,
|
||||
b_fp4,
|
||||
a_sf,
|
||||
b_sf,
|
||||
a_global_scale,
|
||||
b_global_scale,
|
||||
m,
|
||||
n,
|
||||
dtype,
|
||||
block_size,
|
||||
device,
|
||||
):
|
||||
_, m_k = a_fp4.shape
|
||||
_, n_k = b_fp4.shape
|
||||
assert m_k == n_k
|
||||
a_in_dtype = dequantize_nvfp4_to_dtype(
|
||||
a_fp4, a_sf, a_global_scale, dtype=dtype, device=device, block_size=block_size
|
||||
)
|
||||
b_in_dtype = dequantize_nvfp4_to_dtype(
|
||||
b_fp4, b_sf, b_global_scale, dtype=dtype, device=device, block_size=block_size
|
||||
)
|
||||
return torch.matmul(a_in_dtype, b_in_dtype.t())
|
||||
|
||||
|
||||
@pytest.mark.parametrize("dtype", DTYPES)
|
||||
@pytest.mark.parametrize("shape", SHAPES)
|
||||
@pytest.mark.parametrize("seed", SEEDS)
|
||||
@pytest.mark.parametrize("device", CUDA_DEVICES)
|
||||
@pytest.mark.parametrize("backend", ["cutlass", "trtllm"])
|
||||
@pytest.mark.parametrize("autotune", [False, True])
|
||||
@torch.inference_mode()
|
||||
def test_flashinfer_nvfp4_gemm(
|
||||
dtype: torch.dtype,
|
||||
shape: tuple[int, int, int],
|
||||
seed: int,
|
||||
device: str,
|
||||
backend: str,
|
||||
autotune: bool,
|
||||
) -> None:
|
||||
if backend == "trtllm" and dtype == torch.float16:
|
||||
pytest.skip("Only torch.bfloat16 is supported for TRTLLM FP4 GEMM operations")
|
||||
|
||||
current_platform.seed_everything(seed)
|
||||
m, n, packed_k = shape
|
||||
k = packed_k * 2
|
||||
block_size = 16
|
||||
a_dtype = torch.randn((m, k), dtype=dtype, device=device)
|
||||
b_dtype = torch.randn((n, k), dtype=dtype, device=device)
|
||||
|
||||
a_global_scale = (
|
||||
(FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) / torch.amax(a_dtype.flatten(), dim=-1)
|
||||
).to(torch.float32)
|
||||
b_global_scale = (
|
||||
(FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) / torch.amax(b_dtype.flatten(), dim=-1)
|
||||
).to(torch.float32)
|
||||
alpha = 1.0 / (a_global_scale * b_global_scale)
|
||||
# ops.scaled_fp4_quant returns swizzled scales, while weights
|
||||
# from checkpoints are in linear scales.
|
||||
# So instead of needing to swizzle for cutlass as in modelopt.py,
|
||||
# we need to unswizzle for trtllm here.
|
||||
a_fp4, a_scale_interleaved = ops.scaled_fp4_quant(a_dtype, a_global_scale)
|
||||
b_fp4, b_scale_interleaved = ops.scaled_fp4_quant(b_dtype, b_global_scale)
|
||||
|
||||
# get_ref_results unswizzles the scales internally.
|
||||
expected_out = get_ref_results(
|
||||
a_fp4,
|
||||
b_fp4,
|
||||
a_scale_interleaved,
|
||||
b_scale_interleaved,
|
||||
a_global_scale,
|
||||
b_global_scale,
|
||||
m,
|
||||
n,
|
||||
dtype,
|
||||
block_size,
|
||||
device,
|
||||
)
|
||||
|
||||
import flashinfer
|
||||
|
||||
if backend == "trtllm":
|
||||
epilogue_tile_m = 128
|
||||
b_fp4 = flashinfer.shuffle_matrix_a(b_fp4.view(torch.uint8), epilogue_tile_m)
|
||||
|
||||
b_scale_interleaved = convert_swizzled_to_linear(
|
||||
b_scale_interleaved, n, k, block_size
|
||||
)
|
||||
b_scale_interleaved = (
|
||||
flashinfer.shuffle_matrix_sf_a(
|
||||
b_scale_interleaved.view(torch.uint8), epilogue_tile_m
|
||||
)
|
||||
.reshape(b_scale_interleaved.shape)
|
||||
.view(torch.float8_e4m3fn)
|
||||
)
|
||||
|
||||
with flashinfer.autotune(autotune):
|
||||
out = flashinfer_scaled_fp4_mm(
|
||||
a_fp4,
|
||||
b_fp4,
|
||||
a_scale_interleaved,
|
||||
b_scale_interleaved,
|
||||
alpha,
|
||||
dtype,
|
||||
backend=backend,
|
||||
)
|
||||
|
||||
torch.testing.assert_close(out, expected_out.to(dtype=dtype), atol=1e-1, rtol=1e-1)
|
||||
72
tests/kernels/quantization/test_flashinfer_scaled_mm.py
Normal file
72
tests/kernels/quantization/test_flashinfer_scaled_mm.py
Normal file
@@ -0,0 +1,72 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.flashinfer import flashinfer_scaled_fp8_mm
|
||||
|
||||
if not current_platform.has_device_capability(100):
|
||||
pytest.skip(
|
||||
reason="Flashinfer FP8 gemms requires compute capability of 10.0 or above.",
|
||||
allow_module_level=True,
|
||||
)
|
||||
|
||||
DTYPES = [torch.float16, torch.bfloat16]
|
||||
# m, n, k
|
||||
SHAPES = [(128, 128, 64), (128, 128, 128), (256, 128, 64), (128, 256, 128)]
|
||||
PAD_SHAPES = [(150, 128, 64), (128, 128, 96)]
|
||||
SHAPES.extend(PAD_SHAPES)
|
||||
|
||||
SEEDS = [42]
|
||||
CUDA_DEVICES = ["cuda:0"]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("dtype", DTYPES)
|
||||
@pytest.mark.parametrize("shape", SHAPES)
|
||||
@pytest.mark.parametrize("use_bias", [True, False])
|
||||
@pytest.mark.parametrize("seed", SEEDS)
|
||||
@pytest.mark.parametrize("device", CUDA_DEVICES)
|
||||
@pytest.mark.parametrize("autotune", [False, True])
|
||||
@torch.inference_mode()
|
||||
def test_flashinfer_fp8_gemm(
|
||||
dtype: torch.dtype,
|
||||
shape: tuple[int, int, int],
|
||||
use_bias: bool,
|
||||
seed: int,
|
||||
device: str,
|
||||
autotune: bool,
|
||||
) -> None:
|
||||
current_platform.seed_everything(seed)
|
||||
m, n, k = shape
|
||||
a = torch.randn((m, k), dtype=dtype, device=device)
|
||||
b = torch.randn((n, k), dtype=dtype, device=device) / k
|
||||
|
||||
a_fp8, a_scale = ops.scaled_fp8_quant(a)
|
||||
b_fp8, b_scale = ops.scaled_fp8_quant(b)
|
||||
|
||||
expected_out = torch.mm(
|
||||
a_scale * a_fp8.to(dtype=torch.float32),
|
||||
b_scale * b_fp8.to(dtype=torch.float32).t(),
|
||||
).to(dtype=dtype)
|
||||
|
||||
if use_bias:
|
||||
bias = torch.randn((n,), dtype=dtype, device=device)
|
||||
expected_out = expected_out + bias
|
||||
else:
|
||||
bias = None
|
||||
|
||||
import flashinfer
|
||||
|
||||
with flashinfer.autotune(autotune):
|
||||
out = flashinfer_scaled_fp8_mm(
|
||||
a_fp8,
|
||||
b_fp8.t(),
|
||||
a_scale,
|
||||
b_scale,
|
||||
dtype,
|
||||
bias=bias,
|
||||
)
|
||||
|
||||
torch.testing.assert_close(out, expected_out, atol=1e-2, rtol=1e-2)
|
||||
120
tests/kernels/quantization/test_fp8_quant.py
Normal file
120
tests/kernels/quantization/test_fp8_quant.py
Normal file
@@ -0,0 +1,120 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
import vllm._custom_ops as ops
|
||||
from tests.kernels.quant_utils import (
|
||||
FP8_DTYPE,
|
||||
ref_dynamic_per_tensor_fp8_quant,
|
||||
ref_dynamic_per_token_quant,
|
||||
)
|
||||
from tests.kernels.utils import opcheck
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
DTYPES = [torch.bfloat16, torch.float]
|
||||
HIDDEN_SIZES = [17, 1024, 1025, 1026, 5137, 8193]
|
||||
NUM_TOKENS = [1, 7, 4096]
|
||||
SCALE_UBS = [True, False]
|
||||
SEEDS = [0]
|
||||
|
||||
|
||||
def opcheck_fp8_quant(
|
||||
output, input, scale=None, scale_ub=None, use_per_token_if_dynamic=False
|
||||
):
|
||||
if scale is not None:
|
||||
opcheck(torch.ops._C.static_scaled_fp8_quant, (output, input, scale))
|
||||
elif use_per_token_if_dynamic:
|
||||
scale = torch.empty(
|
||||
(input.shape[0], 1), device=input.device, dtype=torch.float32
|
||||
)
|
||||
opcheck(
|
||||
torch.ops._C.dynamic_per_token_scaled_fp8_quant,
|
||||
(output, input, scale, scale_ub),
|
||||
)
|
||||
else:
|
||||
scale = torch.empty(
|
||||
(input.numel() // input.shape[-1], 1),
|
||||
device=input.device,
|
||||
dtype=torch.float32,
|
||||
)
|
||||
opcheck(torch.ops._C.dynamic_scaled_fp8_quant, (output, input, scale))
|
||||
|
||||
|
||||
@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
|
||||
@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES)
|
||||
@pytest.mark.parametrize("dtype", DTYPES)
|
||||
@pytest.mark.parametrize("scale_ub", SCALE_UBS)
|
||||
@pytest.mark.parametrize("seed", SEEDS)
|
||||
@torch.inference_mode()
|
||||
def test_dynamic_per_token_fp8_quant(
|
||||
num_tokens: int, hidden_size: int, dtype: torch.dtype, scale_ub: bool, seed: int
|
||||
) -> None:
|
||||
current_platform.seed_everything(seed)
|
||||
|
||||
x = (
|
||||
torch.rand(num_tokens, hidden_size, dtype=dtype, device="cuda") + 1e-6
|
||||
) # avoid nans
|
||||
|
||||
scale_ub = (
|
||||
torch.mean(x).to(dtype=torch.float32, device="cuda") if scale_ub else None
|
||||
)
|
||||
ref_out, ref_scales = ref_dynamic_per_token_quant(x, FP8_DTYPE, scale_ub)
|
||||
ops_out, ops_scales = ops.scaled_fp8_quant(
|
||||
x, scale_ub=scale_ub, use_per_token_if_dynamic=True
|
||||
)
|
||||
|
||||
torch.testing.assert_close(ref_scales, ops_scales)
|
||||
torch.testing.assert_close(
|
||||
ref_out.to(dtype=torch.float32), ops_out.to(dtype=torch.float32)
|
||||
)
|
||||
|
||||
opcheck_fp8_quant(ops_out, x, None, scale_ub, use_per_token_if_dynamic=True)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
|
||||
@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES)
|
||||
@pytest.mark.parametrize("dtype", DTYPES)
|
||||
@pytest.mark.parametrize("seed", SEEDS)
|
||||
@torch.inference_mode()
|
||||
def test_dynamic_per_tensor_fp8_quant(
|
||||
num_tokens: int, hidden_size: int, dtype: torch.dtype, seed: int
|
||||
) -> None:
|
||||
current_platform.seed_everything(seed)
|
||||
|
||||
x = torch.rand(num_tokens, hidden_size, dtype=dtype, device="cuda")
|
||||
|
||||
ref_out, ref_scale = ref_dynamic_per_tensor_fp8_quant(x)
|
||||
ops_out, ops_scale = ops.scaled_fp8_quant(x)
|
||||
|
||||
torch.testing.assert_close(ref_scale, ops_scale)
|
||||
torch.testing.assert_close(
|
||||
ref_out.to(dtype=torch.float32), ops_out.to(dtype=torch.float32)
|
||||
)
|
||||
|
||||
opcheck_fp8_quant(ops_out, x)
|
||||
|
||||
|
||||
# Regression test for a case with large activations where an int32 index cannot
|
||||
# represent the number of elements.
|
||||
@torch.inference_mode()
|
||||
@pytest.mark.parametrize("seed", SEEDS)
|
||||
def test_fp8_quant_large(seed: int) -> None:
|
||||
current_platform.seed_everything(seed)
|
||||
|
||||
num_tokens = 1024000 # Mistral-Nemo's max_position_embeddings
|
||||
hidden_size = 1152 # Smallest hidden_size to reproduce the error
|
||||
dtype = torch.bfloat16
|
||||
|
||||
x = torch.rand(num_tokens, hidden_size, dtype=dtype, device="cuda")
|
||||
ref_out, scale = ref_dynamic_per_tensor_fp8_quant(x)
|
||||
ops_out, _ = ops.scaled_fp8_quant(x, scale)
|
||||
|
||||
# Minimize memory footprint in this test by freeing x and upconverting
|
||||
# the outputs in place. (torch.allclose does not support fp8)
|
||||
del x
|
||||
ref_out = ref_out.to(dtype=dtype)
|
||||
ops_out = ops_out.to(dtype=dtype)
|
||||
|
||||
torch.testing.assert_close(ref_out, ops_out)
|
||||
166
tests/kernels/quantization/test_fp8_quant_group.py
Normal file
166
tests/kernels/quantization/test_fp8_quant_group.py
Normal file
@@ -0,0 +1,166 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""Tests for QuantFP8 Group Quantization implementation."""
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"batch_size,hidden_dim,group_size",
|
||||
[
|
||||
(16, 256, 32), # Small
|
||||
(64, 1024, 64), # Medium
|
||||
(128, 2048, 128), # Large
|
||||
(8, 513, 64), # Non-divisible (native only)
|
||||
],
|
||||
)
|
||||
@pytest.mark.parametrize("seed", [42])
|
||||
@pytest.mark.parametrize("use_ue8m0", [True, False])
|
||||
@torch.inference_mode()
|
||||
def test_quantfp8_group_functionality(
|
||||
batch_size: int, hidden_dim: int, group_size: int, seed: int, use_ue8m0: bool
|
||||
) -> None:
|
||||
"""Test QuantFP8 group quantization with various configurations.
|
||||
|
||||
Tests both CUDA and native implementations, column-major scales,
|
||||
and verifies consistency between implementations.
|
||||
"""
|
||||
current_platform.seed_everything(seed)
|
||||
|
||||
x = torch.randn((batch_size, hidden_dim), dtype=torch.bfloat16, device="cuda") * 8
|
||||
expected_num_groups = (hidden_dim + group_size - 1) // group_size
|
||||
is_divisible = hidden_dim % group_size == 0
|
||||
|
||||
group_shape = GroupShape(1, group_size)
|
||||
quant_op = QuantFP8(
|
||||
static=False,
|
||||
group_shape=group_shape,
|
||||
column_major_scales=False,
|
||||
use_ue8m0=use_ue8m0,
|
||||
)
|
||||
|
||||
# 1. Test native implementation (always available)
|
||||
x_quant_native, scales_native = quant_op.forward_native(x.clone())
|
||||
assert x_quant_native.shape == x.shape
|
||||
assert scales_native.shape == (batch_size, expected_num_groups)
|
||||
|
||||
# 2. Test column-major scales configuration
|
||||
quant_op_col = QuantFP8(
|
||||
static=False,
|
||||
group_shape=group_shape,
|
||||
column_major_scales=True,
|
||||
use_ue8m0=use_ue8m0,
|
||||
)
|
||||
_, scales_col = quant_op_col.forward_native(x.clone())
|
||||
assert scales_col.shape == (batch_size, expected_num_groups)
|
||||
assert scales_col.stride(0) == 1
|
||||
assert scales_col.stride(1) == batch_size
|
||||
|
||||
# Test column-major scales consistency
|
||||
torch.testing.assert_close(scales_col, scales_native, rtol=1e-9, atol=1e-8)
|
||||
|
||||
# 3. Test CUDA implementation (only for divisible dimensions)
|
||||
if is_divisible:
|
||||
x_quant_cuda, scales_cuda = quant_op.forward_cuda(x.clone())
|
||||
assert x_quant_cuda.shape == x.shape
|
||||
assert scales_cuda.shape == (batch_size, expected_num_groups)
|
||||
|
||||
# Verify CUDA/native consistency
|
||||
torch.testing.assert_close(scales_cuda, scales_native, rtol=2e-7, atol=2e-8)
|
||||
|
||||
# Quantized values should mostly match
|
||||
diff_count = (x_quant_cuda != x_quant_native).sum().item()
|
||||
diff_ratio = diff_count / x_quant_cuda.numel()
|
||||
assert diff_ratio < 0.002, f"Too many differences: {diff_ratio:.4%}"
|
||||
|
||||
|
||||
@pytest.mark.parametrize("seed", [42])
|
||||
@pytest.mark.parametrize("use_ue8m0", [True, False])
|
||||
@torch.inference_mode()
|
||||
def test_quantfp8_group_multidimensional(seed: int, use_ue8m0: bool) -> None:
|
||||
current_platform.seed_everything(seed)
|
||||
|
||||
group_size = 64
|
||||
|
||||
# Test with 3D input
|
||||
batch1, batch2, hidden_dim = 4, 8, 1024
|
||||
x_3d = (
|
||||
torch.randn((batch1, batch2, hidden_dim), dtype=torch.bfloat16, device="cuda")
|
||||
* 8
|
||||
)
|
||||
|
||||
group_shape = GroupShape(1, group_size)
|
||||
quant_op = QuantFP8(
|
||||
static=False,
|
||||
group_shape=group_shape,
|
||||
column_major_scales=False,
|
||||
use_ue8m0=use_ue8m0,
|
||||
)
|
||||
|
||||
x_quant, scales = quant_op.forward_native(x_3d.clone())
|
||||
assert x_quant.shape == x_3d.shape
|
||||
assert scales.shape == (batch1, batch2, hidden_dim // group_size)
|
||||
|
||||
# Test column_major_scales with multi-dim
|
||||
quant_op_col = QuantFP8(
|
||||
static=False,
|
||||
group_shape=group_shape,
|
||||
column_major_scales=True,
|
||||
use_ue8m0=use_ue8m0,
|
||||
)
|
||||
_, scales_col = quant_op_col.forward_native(x_3d.clone())
|
||||
assert scales_col.shape == (batch1, batch2, hidden_dim // group_size)
|
||||
|
||||
# Test with 4D input
|
||||
batch1, batch2, batch3, hidden_dim = 2, 3, 4, 256
|
||||
x_4d = (
|
||||
torch.randn(
|
||||
(batch1, batch2, batch3, hidden_dim), dtype=torch.bfloat16, device="cuda"
|
||||
)
|
||||
* 8
|
||||
)
|
||||
|
||||
x_quant_4d, scales_4d = quant_op.forward_native(x_4d.clone())
|
||||
assert x_quant_4d.shape == x_4d.shape
|
||||
assert scales_4d.shape == (batch1, batch2, batch3, hidden_dim // group_size)
|
||||
|
||||
_, scales_4d_col = quant_op_col.forward_native(x_4d.clone())
|
||||
assert scales_4d_col.shape == (batch1, batch2, hidden_dim // group_size, batch3)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("seed", [42])
|
||||
@torch.inference_mode()
|
||||
def test_quantfp8_group_edge_cases(seed: int) -> None:
|
||||
current_platform.seed_everything(seed)
|
||||
|
||||
batch_size = 16
|
||||
group_size = 64
|
||||
|
||||
# Test with single group (group_size >= hidden_dim)
|
||||
x_small = torch.randn((batch_size, 32), dtype=torch.bfloat16, device="cuda") * 8
|
||||
group_shape = GroupShape(1, group_size)
|
||||
quant_op = QuantFP8(
|
||||
static=False, group_shape=group_shape, column_major_scales=False
|
||||
)
|
||||
|
||||
x_quant_small, scales_small = quant_op.forward_native(x_small.clone())
|
||||
assert x_quant_small.shape == x_small.shape
|
||||
assert scales_small.shape == (batch_size, 1)
|
||||
|
||||
# Test with zero inputs
|
||||
x_zero = torch.zeros((batch_size, 256), dtype=torch.bfloat16, device="cuda")
|
||||
x_quant_zero, scales_zero = quant_op.forward_native(x_zero.clone())
|
||||
assert x_quant_zero.shape == x_zero.shape
|
||||
assert (scales_zero > 0).all(), "Scales should be clamped to minimum"
|
||||
|
||||
# Test very large values
|
||||
x_large = torch.full((batch_size, 256), 1000.0, dtype=torch.bfloat16, device="cuda")
|
||||
x_quant_large, scales_large = quant_op.forward_native(x_large.clone())
|
||||
assert x_quant_large.shape == x_large.shape
|
||||
# FP8 max is typically 448 or 224, so scales should be > 1
|
||||
assert (scales_large > 1.0).all(), "Large values should have scales > 1"
|
||||
54
tests/kernels/quantization/test_ggml.py
Normal file
54
tests/kernels/quantization/test_ggml.py
Normal file
@@ -0,0 +1,54 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import gguf
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from tests.kernels.utils import opcheck
|
||||
from vllm import _custom_ops as ops # noqa: F401
|
||||
|
||||
|
||||
@pytest.mark.parametrize("quant_type", [12])
|
||||
def test_ggml_opcheck(quant_type):
|
||||
block_size, type_size = gguf.GGML_QUANT_SIZES[quant_type]
|
||||
shape = [256, 1152]
|
||||
qweight = torch.randint(0, 100, shape, device="cuda", dtype=torch.uint8)
|
||||
m = qweight.shape[0]
|
||||
n = qweight.shape[1] // type_size * block_size
|
||||
opcheck(torch.ops._C.ggml_dequantize, (qweight, quant_type, m, n, torch.float16))
|
||||
|
||||
x = torch.rand((m, 512), device="cuda", dtype=torch.float16)
|
||||
opcheck(torch.ops._C.ggml_mul_mat_a8, (qweight, x, quant_type, qweight.shape[0]))
|
||||
opcheck(
|
||||
torch.ops._C.ggml_mul_mat_vec_a8, (qweight, x, quant_type, qweight.shape[0])
|
||||
)
|
||||
|
||||
shape = [256, 1024, 336]
|
||||
qweight = torch.randint(0, 100, shape, device="cuda", dtype=torch.uint8)
|
||||
x = torch.rand((1, 1024), device="cuda", dtype=torch.float16)
|
||||
sorted_token_ids = torch.arange(776, device="cuda")
|
||||
expert_ids = torch.randint(0, 256, (194,), device="cuda")
|
||||
num_tokens_post_padded = torch.tensor([1], dtype=torch.int64, device="cuda")
|
||||
|
||||
opcheck(
|
||||
torch.ops._C.ggml_moe_a8,
|
||||
(
|
||||
x,
|
||||
qweight,
|
||||
sorted_token_ids,
|
||||
expert_ids,
|
||||
num_tokens_post_padded,
|
||||
quant_type,
|
||||
qweight.shape[0],
|
||||
1,
|
||||
x.shape[0],
|
||||
),
|
||||
)
|
||||
|
||||
topk_ids = torch.zeros((1, 1), device="cuda", dtype=torch.int32)
|
||||
|
||||
opcheck(
|
||||
torch.ops._C.ggml_moe_a8_vec,
|
||||
(x, qweight, topk_ids, 1, quant_type, qweight.shape[0], x.shape[0]),
|
||||
)
|
||||
207
tests/kernels/quantization/test_gguf.py
Normal file
207
tests/kernels/quantization/test_gguf.py
Normal file
@@ -0,0 +1,207 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from gguf import GGMLQuantizationType, GGUFReader, ReaderTensor, dequantize
|
||||
from huggingface_hub import snapshot_download
|
||||
|
||||
import vllm._custom_ops as ops
|
||||
from vllm.model_executor.layers.fused_moe import fused_experts
|
||||
from vllm.model_executor.layers.quantization.gguf import _fused_moe_gguf
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
GGUF_SAMPLE = snapshot_download("Isotr0py/test-gguf-sample")
|
||||
GGUF_SAMPLE_MOE = snapshot_download("SzymonOzog/test-gguf-moe-sample")
|
||||
|
||||
|
||||
def get_gguf_sample_tensors(
|
||||
hidden_size: int, quant_type: GGMLQuantizationType
|
||||
) -> list[ReaderTensor]:
|
||||
sample_dir = GGUF_SAMPLE
|
||||
filename = f"Quant_{quant_type.name}_{hidden_size}.gguf"
|
||||
sample_file = Path(sample_dir) / filename
|
||||
return GGUFReader(sample_file).tensors
|
||||
|
||||
|
||||
def get_gguf_MoE_tensors(
|
||||
hidden_size: int, quant_type: GGMLQuantizationType
|
||||
) -> list[ReaderTensor]:
|
||||
sample_dir = GGUF_SAMPLE_MOE
|
||||
filename = f"Quant_{quant_type.name}_{hidden_size}.gguf"
|
||||
sample_file = Path(sample_dir) / filename
|
||||
return GGUFReader(sample_file).tensors
|
||||
|
||||
|
||||
DTYPES = [torch.bfloat16] # [torch.half, torch.bfloat16, torch.float32]
|
||||
# Hidden_size for testing, must match the sample file in HF repo,
|
||||
# we have `hidden_size = 256, 1024` for test in HF repo currently.
|
||||
HIDDEN_SIZES = [256, 1024]
|
||||
NUM_TOKENS = [7, 2050] # Arbitrary values for testing
|
||||
SEEDS = [0]
|
||||
QUANT_TYPES = [
|
||||
# i-matrix
|
||||
GGMLQuantizationType.IQ1_M,
|
||||
GGMLQuantizationType.IQ1_S,
|
||||
GGMLQuantizationType.IQ2_S,
|
||||
GGMLQuantizationType.IQ2_XS,
|
||||
GGMLQuantizationType.IQ3_S,
|
||||
GGMLQuantizationType.IQ3_XXS,
|
||||
GGMLQuantizationType.IQ4_NL,
|
||||
GGMLQuantizationType.IQ4_XS,
|
||||
# k-quants
|
||||
GGMLQuantizationType.Q2_K,
|
||||
GGMLQuantizationType.Q3_K,
|
||||
GGMLQuantizationType.Q4_K,
|
||||
GGMLQuantizationType.Q5_K,
|
||||
GGMLQuantizationType.Q6_K,
|
||||
# standard quantization
|
||||
GGMLQuantizationType.Q4_0,
|
||||
GGMLQuantizationType.Q5_0,
|
||||
GGMLQuantizationType.Q8_0,
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES)
|
||||
@pytest.mark.parametrize("dtype", DTYPES)
|
||||
@pytest.mark.parametrize("quant_type", QUANT_TYPES)
|
||||
@torch.inference_mode()
|
||||
def test_dequantize(
|
||||
hidden_size: int, dtype: torch.dtype, quant_type: GGMLQuantizationType
|
||||
):
|
||||
tensors = get_gguf_sample_tensors(hidden_size, quant_type)
|
||||
for tensor in tensors:
|
||||
shape_str = tensor.name.split("_")[-1]
|
||||
shape = map(int, shape_str.split("x"))
|
||||
|
||||
ref_output = torch.tensor(
|
||||
dequantize(tensor.data, quant_type), device="cuda"
|
||||
).to(dtype)
|
||||
output = ops.ggml_dequantize(
|
||||
torch.tensor(tensor.data, device="cuda"), quant_type, *list(shape), dtype
|
||||
)
|
||||
|
||||
torch.testing.assert_close(output, ref_output, atol=1e-2, rtol=4e-2)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES)
|
||||
@pytest.mark.parametrize("dtype", DTYPES)
|
||||
@pytest.mark.parametrize("quant_type", QUANT_TYPES)
|
||||
@torch.inference_mode()
|
||||
def test_mmvq(hidden_size: int, dtype: torch.dtype, quant_type: GGMLQuantizationType):
|
||||
current_platform.seed_everything(0)
|
||||
|
||||
tensors = get_gguf_sample_tensors(hidden_size, quant_type)
|
||||
x = torch.rand((1, hidden_size), dtype=dtype, device="cuda")
|
||||
for tensor in tensors:
|
||||
weight = torch.tensor(dequantize(tensor.data, quant_type), device="cuda").to(
|
||||
dtype
|
||||
)
|
||||
ref_output = x @ weight.T
|
||||
|
||||
qweight = torch.tensor(tensor.data, device="cuda")
|
||||
output = ops.ggml_mul_mat_vec_a8(qweight, x, quant_type, qweight.shape[0]).to(
|
||||
dtype
|
||||
)
|
||||
|
||||
torch.testing.assert_close(output, ref_output, atol=1, rtol=1e-1)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
|
||||
@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES)
|
||||
@pytest.mark.parametrize("dtype", DTYPES)
|
||||
@pytest.mark.parametrize(
|
||||
"quant_type",
|
||||
[
|
||||
# k-quants
|
||||
GGMLQuantizationType.Q2_K,
|
||||
GGMLQuantizationType.Q3_K,
|
||||
GGMLQuantizationType.Q4_K,
|
||||
GGMLQuantizationType.Q5_K,
|
||||
GGMLQuantizationType.Q6_K,
|
||||
# standard quants
|
||||
GGMLQuantizationType.Q4_0,
|
||||
GGMLQuantizationType.Q5_0,
|
||||
GGMLQuantizationType.Q8_0,
|
||||
],
|
||||
)
|
||||
@torch.inference_mode()
|
||||
def test_mmq(
|
||||
num_tokens: int,
|
||||
hidden_size: int,
|
||||
dtype: torch.dtype,
|
||||
quant_type: GGMLQuantizationType,
|
||||
):
|
||||
current_platform.seed_everything(0)
|
||||
|
||||
tensors = get_gguf_sample_tensors(hidden_size, quant_type)
|
||||
x = torch.rand((num_tokens, hidden_size), dtype=dtype, device="cuda")
|
||||
for tensor in tensors:
|
||||
weight = torch.tensor(dequantize(tensor.data, quant_type), device="cuda").to(
|
||||
dtype
|
||||
)
|
||||
ref_output = x @ weight.T
|
||||
|
||||
qweight = torch.tensor(tensor.data, device="cuda")
|
||||
output = ops.ggml_mul_mat_a8(qweight, x, quant_type, qweight.shape[0])
|
||||
atols = {torch.half: 1, torch.bfloat16: 1.5, torch.float: 1.2}
|
||||
# test matrix has inputs centered around 0 and lower precision from
|
||||
# bfloat16 tends to accumulate and can greatly inflate rtol
|
||||
# since outputs are also very close to 0
|
||||
rtols = {torch.half: 1e-1, torch.bfloat16: 1e4, torch.float: 2e1}
|
||||
torch.testing.assert_close(
|
||||
output, ref_output, atol=atols[dtype], rtol=rtols[dtype]
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
|
||||
@pytest.mark.parametrize("hidden_size", [512])
|
||||
@pytest.mark.parametrize("top_k", [4, 8])
|
||||
@pytest.mark.parametrize("dtype", DTYPES)
|
||||
@pytest.mark.parametrize("quant_type", QUANT_TYPES)
|
||||
@torch.inference_mode()
|
||||
def test_moe(
|
||||
num_tokens: int,
|
||||
hidden_size: int,
|
||||
dtype: torch.dtype,
|
||||
quant_type: GGMLQuantizationType,
|
||||
top_k: int,
|
||||
):
|
||||
current_platform.seed_everything(0)
|
||||
H, E = 1024, 256
|
||||
|
||||
x = torch.rand((num_tokens, H), dtype=dtype, device="cuda")
|
||||
|
||||
topk_weights = torch.rand(num_tokens, top_k, device="cuda", dtype=dtype)
|
||||
topk_ids = torch.randint(
|
||||
0, E, (num_tokens, top_k), device="cuda", dtype=torch.int32
|
||||
)
|
||||
|
||||
tensors = get_gguf_MoE_tensors(hidden_size, quant_type)
|
||||
|
||||
w13 = tensors[0]
|
||||
w2 = tensors[1]
|
||||
|
||||
w13_dequant = torch.tensor(dequantize(w13.data, quant_type), device="cuda").to(
|
||||
dtype
|
||||
)
|
||||
|
||||
w2_dequant = torch.tensor(dequantize(w2.data, quant_type), device="cuda").to(dtype)
|
||||
|
||||
output = _fused_moe_gguf(
|
||||
x,
|
||||
torch.tensor(w13.data, device="cuda"),
|
||||
torch.tensor(w2.data, device="cuda"),
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
quant_type,
|
||||
quant_type,
|
||||
"silu",
|
||||
)
|
||||
|
||||
ref_output = fused_experts(
|
||||
x, w13_dequant, w2_dequant, topk_weights, topk_ids
|
||||
).reshape(output.shape)
|
||||
torch.testing.assert_close(output, ref_output, atol=1, rtol=1e-1)
|
||||
35
tests/kernels/quantization/test_gptq.py
Normal file
35
tests/kernels/quantization/test_gptq.py
Normal file
@@ -0,0 +1,35 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import torch
|
||||
|
||||
from tests.kernels.utils import opcheck
|
||||
from vllm import _custom_ops as ops # noqa: F401
|
||||
|
||||
|
||||
def test_gptq_shuffle_opcheck():
|
||||
weight = torch.randint(
|
||||
-2000000, 2000000, (1792, 4096), device="cuda", dtype=torch.int32
|
||||
)
|
||||
perm = torch.empty((0,), device="cuda", dtype=torch.int32)
|
||||
bit = 4
|
||||
opcheck(torch.ops._C.gptq_shuffle, (weight, perm, bit))
|
||||
|
||||
|
||||
def test_gptq_gemm_opcheck():
|
||||
a = torch.rand((240, 4096), device="cuda", dtype=torch.float16)
|
||||
weight = torch.randint(
|
||||
-2000000, 2000000, (512, 6144), device="cuda", dtype=torch.int32
|
||||
)
|
||||
zeros = torch.zeros((32, 768), device="cuda", dtype=torch.int32)
|
||||
scales = torch.rand((32, 6144), device="cuda", dtype=torch.float16)
|
||||
idx = torch.empty((0,), device="cuda", dtype=torch.int32)
|
||||
use_exllama = True
|
||||
bit = 4
|
||||
# Test both GPTQv1 and GPTQv2 format
|
||||
opcheck(
|
||||
torch.ops._C.gptq_gemm, (a, weight, zeros, scales, idx, use_exllama, True, bit)
|
||||
)
|
||||
opcheck(
|
||||
torch.ops._C.gptq_gemm, (a, weight, zeros, scales, idx, use_exllama, False, bit)
|
||||
)
|
||||
33
tests/kernels/quantization/test_hadacore.py
Normal file
33
tests/kernels/quantization/test_hadacore.py
Normal file
@@ -0,0 +1,33 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import math
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from compressed_tensors.transform import deterministic_hadamard_matrix
|
||||
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
if current_platform.is_rocm():
|
||||
pytest.skip(
|
||||
"These tests require hadacore_transform, not supported on ROCm.",
|
||||
allow_module_level=True,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("batch_size", [1, 32])
|
||||
@pytest.mark.parametrize("hidden_dim", [2**n for n in range(10)])
|
||||
def test_hadacore(batch_size, hidden_dim, dtype=torch.bfloat16, device="cuda"):
|
||||
x = torch.eye(hidden_dim, dtype=dtype, device=device)
|
||||
hadamard = deterministic_hadamard_matrix(
|
||||
hidden_dim, dtype=torch.float64, device="cuda"
|
||||
) / math.sqrt(hidden_dim)
|
||||
|
||||
y = ops.hadacore_transform(x.clone())
|
||||
y_true = (x.to(hadamard.dtype) @ hadamard.T).to(y.dtype)
|
||||
assert torch.allclose(y, y_true)
|
||||
|
||||
y = ops.hadacore_transform(y)
|
||||
assert torch.allclose(y, x)
|
||||
155
tests/kernels/quantization/test_int8_kernel.py
Normal file
155
tests/kernels/quantization/test_int8_kernel.py
Normal file
@@ -0,0 +1,155 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
# Adapted from https://github.com/sgl-project/sglang/blob/main/test/srt/test_int8_kernel.py
|
||||
import itertools
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from vllm.model_executor.layers.activation import SiluAndMul
|
||||
from vllm.model_executor.layers.fused_moe import fused_experts
|
||||
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
|
||||
from vllm.model_executor.layers.quantization.utils.int8_utils import (
|
||||
per_token_quant_int8,
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
if current_platform.get_device_capability() < (7, 0):
|
||||
pytest.skip("INT8 Triton requires CUDA 7.0 or higher", allow_module_level=True)
|
||||
|
||||
|
||||
def native_w8a8_per_token_matmul(A, B, As, Bs, output_dtype=torch.float16):
|
||||
"""Matrix multiplication function that supports per-token input
|
||||
quantization and per-column weight quantization"""
|
||||
A = A.to(torch.float32)
|
||||
B = B.to(torch.float32)
|
||||
|
||||
assert A.shape[-1] == B.shape[-1], "Dimension mismatch"
|
||||
assert B.ndim == 2 and B.is_contiguous(), "B must be a 2D contiguous tensor"
|
||||
|
||||
# Reshape input
|
||||
M = A.numel() // A.shape[-1]
|
||||
B = B.t() # Transpose weight matrix
|
||||
N, K = B.shape
|
||||
origin_C_shape = A.shape[:-1] + (K,)
|
||||
A = A.reshape(M, N)
|
||||
|
||||
# As is per-token [M, 1], Bs is per-column [1, K]
|
||||
C = torch.matmul(A, B) # [M, K]
|
||||
C = As * C * Bs.view(1, -1) # Broadcast per-column scale
|
||||
|
||||
return C.reshape(origin_C_shape).to(output_dtype)
|
||||
|
||||
|
||||
def torch_w8a8_per_column_moe(a, w1, w2, w1_s, w2_s, topk, topk_weight, topk_ids):
|
||||
"""This function performs fused moe with per-column int8 quantization
|
||||
using native torch."""
|
||||
|
||||
B, D = a.shape
|
||||
# Perform per-token quantization
|
||||
a_q, a_s = per_token_quant_int8(a)
|
||||
# Repeat tokens to match topk
|
||||
a_q = a_q.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D)
|
||||
# Also repeat the scale
|
||||
a_s = a_s.view(B, -1, 1).repeat(1, topk, 1).reshape(-1, 1) # [B*topk, 1]
|
||||
|
||||
out = torch.zeros(B * topk, w2.shape[1], dtype=a.dtype, device=a.device)
|
||||
|
||||
# Calculate routing
|
||||
topk_weight = topk_weight.view(-1)
|
||||
topk_ids = topk_ids.view(-1)
|
||||
# Process each expert
|
||||
for i in range(w1.shape[0]):
|
||||
mask = topk_ids == i
|
||||
if mask.sum():
|
||||
# First MLP layer: note that a_s is now per-token
|
||||
inter_out = native_w8a8_per_token_matmul(
|
||||
a_q[mask], w1[i], a_s[mask], w1_s[i], output_dtype=a.dtype
|
||||
)
|
||||
# Activation function
|
||||
act_out = SiluAndMul().forward_native(inter_out)
|
||||
# Quantize activation output with per-token
|
||||
act_out_q, act_out_s = per_token_quant_int8(act_out)
|
||||
|
||||
# Second MLP layer
|
||||
out[mask] = native_w8a8_per_token_matmul(
|
||||
act_out_q, w2[i], act_out_s, w2_s[i], output_dtype=a.dtype
|
||||
)
|
||||
# Apply routing weights and sum
|
||||
return (
|
||||
out.view(B, -1, w2.shape[1]) * topk_weight.view(B, -1, 1).to(out.dtype)
|
||||
).sum(dim=1)
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True, scope="module")
|
||||
def setup_cuda():
|
||||
"""Sets the default CUDA device for all tests in this module."""
|
||||
torch.set_default_device("cuda")
|
||||
|
||||
|
||||
DTYPES = [torch.half, torch.bfloat16]
|
||||
M = [1, 33]
|
||||
N = [128, 1024]
|
||||
K = [256, 4096]
|
||||
E = [8]
|
||||
TOP_KS = [2, 6]
|
||||
SEEDS = [0]
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"M, N, K, E, topk, dtype, seed",
|
||||
itertools.product(M, N, K, E, TOP_KS, DTYPES, SEEDS),
|
||||
)
|
||||
@torch.inference_mode()
|
||||
def test_w8a8_fp8_fused_moe(M, N, K, E, topk, dtype, seed):
|
||||
torch.manual_seed(seed)
|
||||
# Initialize int8 quantization parameters
|
||||
factor_for_scale = 1e-2
|
||||
int8_max = 127
|
||||
int8_min = -128
|
||||
|
||||
# Input tensor
|
||||
# M * K
|
||||
a = torch.randn((M, K), dtype=dtype) / 10
|
||||
|
||||
# Generate int8 weights
|
||||
w1_fp32 = (torch.rand((E, 2 * N, K), dtype=torch.float32) - 0.5) * 2
|
||||
w1 = (w1_fp32 * int8_max).clamp(min=int8_min, max=int8_max).to(torch.int8)
|
||||
|
||||
w2_fp32 = (torch.rand((E, K, N), dtype=torch.float32) - 0.5) * 2
|
||||
w2 = (w2_fp32 * int8_max).clamp(min=int8_min, max=int8_max).to(torch.int8)
|
||||
|
||||
# Generate scale for each column (per-column quantization)
|
||||
w1_s = torch.rand(E, 2 * N, device=w1_fp32.device) * factor_for_scale
|
||||
w2_s = torch.rand(E, K, device=w2_fp32.device) * factor_for_scale
|
||||
score = torch.randn((M, E), dtype=dtype)
|
||||
score = torch.softmax(score, dim=-1, dtype=torch.float32)
|
||||
topk_weights, topk_ids = torch.topk(score, topk)
|
||||
|
||||
ref_out = torch_w8a8_per_column_moe(
|
||||
a, w1, w2, w1_s, w2_s, topk, topk_weights, topk_ids
|
||||
)
|
||||
|
||||
quant_config = FusedMoEQuantConfig.make(
|
||||
torch.int8,
|
||||
per_act_token_quant=True,
|
||||
block_shape=None,
|
||||
w1_scale=w1_s,
|
||||
w2_scale=w2_s,
|
||||
)
|
||||
|
||||
out = fused_experts(
|
||||
a,
|
||||
w1,
|
||||
w2,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
quant_config=quant_config,
|
||||
)
|
||||
|
||||
# Check results
|
||||
rel_diff = torch.mean(
|
||||
torch.abs(out.to(torch.float32) - ref_out.to(torch.float32))
|
||||
) / torch.mean(torch.abs(ref_out.to(torch.float32)))
|
||||
assert rel_diff < 0.05
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user