[DeepseekV32] Add fast_topk_transform_ragged_fused kernel (#11815)
Signed-off-by: Hao Lu <14827759+hlu1@users.noreply.github.com>
This commit is contained in:
@@ -1,6 +1,12 @@
|
||||
from typing import Optional
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from sgl_kernel import fast_topk_transform_fused, fast_topk_v2
|
||||
from sgl_kernel import (
|
||||
fast_topk_transform_fused,
|
||||
fast_topk_transform_ragged_fused,
|
||||
fast_topk_v2,
|
||||
)
|
||||
|
||||
|
||||
def _ref_torch_impl(score: torch.Tensor, seq_len: int, topk: int) -> torch.Tensor:
|
||||
@@ -26,6 +32,21 @@ def _ref_torch_transform_decode_impl(
|
||||
return topk_indices
|
||||
|
||||
|
||||
def _ref_torch_transform_ragged_impl(
|
||||
score: torch.Tensor,
|
||||
seq_len: int,
|
||||
topk_indices_offset: torch.Tensor,
|
||||
topk: int,
|
||||
) -> torch.Tensor:
|
||||
assert score.shape[0] == topk_indices_offset.shape[0]
|
||||
assert seq_len >= topk
|
||||
indices = _ref_torch_impl(score, seq_len, topk)
|
||||
|
||||
mask = indices != -1
|
||||
topk_indices_offset = topk_indices_offset.unsqueeze(1)
|
||||
return torch.where(mask, indices + topk_indices_offset, indices)
|
||||
|
||||
|
||||
MAX_SEQ_LEN = 131072
|
||||
MAX_PERMIT_ERROR = 0
|
||||
|
||||
@@ -37,6 +58,7 @@ def assert_equal(
|
||||
bs: int,
|
||||
k: int,
|
||||
seq_len: int,
|
||||
topk_indices_offset: Optional[torch.Tensor] = None,
|
||||
):
|
||||
indices_our_cpu = indices_our.cpu().tolist()
|
||||
indices_ref_cpu = indices_ref.cpu().tolist()
|
||||
@@ -45,11 +67,13 @@ def assert_equal(
|
||||
indices_our_set_i = set(indices_our_cpu[i])
|
||||
more = indices_our_set_i - indices_ref_set_i
|
||||
less = indices_ref_set_i - indices_our_set_i
|
||||
if len(more) > MAX_PERMIT_ERROR or len(less) > MAX_PERMIT_ERROR:
|
||||
offset = topk_indices_offset[i].item() if topk_indices_offset is not None else 0
|
||||
if len(more) > 0 or len(less) > 0:
|
||||
print(f"{bs=}, {k=}, {seq_len=}, {i=}, {more=}, {less=}")
|
||||
# check whether more values are the same with less values
|
||||
# if so, either one is acceptable, since their values are the same
|
||||
more_values = sorted(score[i, idx].item() for idx in more)
|
||||
less_values = sorted(score[i, idx].item() for idx in less)
|
||||
more_values = sorted(score[i, idx - offset].item() for idx in more)
|
||||
less_values = sorted(score[i, idx - offset].item() for idx in less)
|
||||
assert (
|
||||
more_values == less_values
|
||||
), f"{bs=}, {k=}, {seq_len=}, {i=}, {more=}, {less=} failed, with {more_values=}, {less_values=}"
|
||||
@@ -116,5 +140,52 @@ def test_topk_transform_kernel(bs: int, k: int, seq_len: int) -> None:
|
||||
assert_equal(score, dst_page_table_ref, dst_page_table_our, bs, k, seq_len)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("bs", [1, 132, 256, 4096])
|
||||
@pytest.mark.parametrize("k", [2048]) # we only support 2048 now
|
||||
@pytest.mark.parametrize("seq_len", [2048, 4096, 16384, 65536])
|
||||
@torch.inference_mode()
|
||||
def test_topk_transform_ragged_kernel(bs: int, k: int, seq_len: int) -> None:
|
||||
# TODO(dark): test prefill kernel, though nothing special
|
||||
MAX_PERMIT_ERROR = 1
|
||||
torch.manual_seed(42)
|
||||
|
||||
stream = torch.cuda.Stream()
|
||||
torch.cuda.set_stream(stream)
|
||||
# bs: # of q tokens
|
||||
score = torch.randn(bs, MAX_SEQ_LEN, dtype=torch.float32, device="cuda")
|
||||
# kv_len
|
||||
lengths = torch.full((bs,), seq_len, dtype=torch.int32, device="cuda")
|
||||
topk_indices_offset = torch.randint(
|
||||
0, 1024, (bs,), dtype=torch.int32, device="cuda"
|
||||
)
|
||||
|
||||
dst_page_table_ref = _ref_torch_transform_ragged_impl(
|
||||
score=score,
|
||||
seq_len=seq_len,
|
||||
topk_indices_offset=topk_indices_offset,
|
||||
topk=k,
|
||||
)
|
||||
dst_page_table_our = fast_topk_transform_ragged_fused(
|
||||
score=score,
|
||||
lengths=lengths,
|
||||
topk_indices_offset=topk_indices_offset,
|
||||
topk=k,
|
||||
)
|
||||
|
||||
# sort and compare
|
||||
dst_page_table_our = torch.sort(dst_page_table_our, dim=-1).values
|
||||
dst_page_table_ref = torch.sort(dst_page_table_ref, dim=-1).values
|
||||
|
||||
assert_equal(
|
||||
score,
|
||||
dst_page_table_ref,
|
||||
dst_page_table_our,
|
||||
bs,
|
||||
k,
|
||||
seq_len,
|
||||
topk_indices_offset,
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__])
|
||||
|
||||
Reference in New Issue
Block a user