Add minimal vLLM 0.16.1 build repo for BI-V150
This commit is contained in:
648
vllm/v1/attention/ops/rocm_aiter_mla_sparse.py
Normal file
648
vllm/v1/attention/ops/rocm_aiter_mla_sparse.py
Normal file
@@ -0,0 +1,648 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import functools
|
||||
import importlib
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.forward_context import get_forward_context
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.triton_utils import tl, triton
|
||||
from vllm.v1.attention.backends.mla.indexer import DeepseekV32IndexerMetadata
|
||||
from vllm.v1.attention.ops.common import pack_seq_triton, unpack_seq_triton
|
||||
|
||||
if current_platform.is_cuda_alike():
|
||||
from vllm import _custom_ops as ops
|
||||
|
||||
|
||||
@triton.jit
def _indexer_k_quant_and_cache_kernel(
    k_ptr,  # [num_tokens, head_dim]
    kv_cache_ptr,  # [n_blks, blk_size//tile_block, head_dim // 16B, tile_block, 16B]
    # [n_blocks, blk_size, head_dim]
    kv_cache_scale_ptr,  # [n_blks, blk_size]
    slot_mapping_ptr,  # [num_tokens]
    kv_cache_scale_stride,
    kv_cache_value_stride,
    block_size,
    num_tokens,
    head_dim: tl.constexpr,
    LAYOUT: tl.constexpr,
    BLOCK_TILE_SIZE: tl.constexpr,
    HEAD_TILE_SIZE: tl.constexpr,
    IS_FNUZ: tl.constexpr,
    USE_UE8M0: tl.constexpr,
):
    """Per-token FP8 quantization of indexer K vectors into the paged cache.

    One program handles one token: it computes a single per-token scale
    (amax-based), quantizes the whole head_dim vector to the cache's FP8
    element type, and writes value + scale into the block-paged cache at the
    position given by ``slot_mapping``.
    """
    # One program per token; grid is (num_tokens,).
    tid = tl.program_id(0)
    offset = tl.arange(0, head_dim)
    if LAYOUT == "SHUFFLE":
        # Tiled destination layout: elements are regrouped in
        # HEAD_TILE_SIZE-wide chunks interleaved across BLOCK_TILE_SIZE
        # token slots (matches the 5-D cache layout in the comment above).
        tile_offset = (
            offset // HEAD_TILE_SIZE * BLOCK_TILE_SIZE * HEAD_TILE_SIZE
            + offset % HEAD_TILE_SIZE
        )
    else:
        # Plain row-major [blk_size, head_dim] layout.
        tile_offset = offset
    tile_store_offset = tile_offset
    # for idx in tl.range(tid, num_tokens, n_program):
    src_ptr = k_ptr + tid * head_dim
    slot_id = tl.load(slot_mapping_ptr + tid)
    # Negative slot means "skip this token" (e.g. padding slot).
    if slot_id < 0:
        return
    # Decompose the flat slot id into (physical block, offset inside block).
    block_id = slot_id // block_size
    block_offset = slot_id % block_size
    tile_block_id = block_offset // BLOCK_TILE_SIZE
    tile_block_offset = block_offset % BLOCK_TILE_SIZE
    val = tl.load(src_ptr + offset)
    # Per-token amax over the head dimension drives the dequant scale.
    amax = tl.max(val.abs(), axis=-1).to(tl.float32)
    if IS_FNUZ:
        # float8_e4m3fnuz (ROCm) max representable magnitude is 224.
        scale = tl.maximum(1e-4, amax) / 224.0
    else:
        # float8_e4m3fn max representable magnitude is 448.
        scale = tl.maximum(1e-4, amax) / 448.0

    if USE_UE8M0:
        # Round the scale up to a power of two (UE8M0 scale format).
        scale = tl.exp2(tl.ceil(tl.log2(scale)))

    fp8_val = (val.to(tl.float32) / scale).to(kv_cache_ptr.type.element_ty)
    if LAYOUT == "SHUFFLE":
        dst_ptr = (
            kv_cache_ptr
            + block_id * kv_cache_value_stride
            + tile_block_id * BLOCK_TILE_SIZE * head_dim
            + tile_block_offset * HEAD_TILE_SIZE
        )
    else:
        dst_ptr = (
            kv_cache_ptr + block_id * kv_cache_value_stride + block_offset * head_dim
        )
    tl.store(dst_ptr + tile_store_offset, fp8_val)
    # Scales live in a separate per-(block, position) float32 region.
    dst_scale_ptr = kv_cache_scale_ptr + block_id * kv_cache_scale_stride + block_offset
    tl.store(dst_scale_ptr, scale)
|
||||
|
||||
|
||||
def indexer_k_quant_and_cache_triton(
    k: torch.Tensor,
    kv_cache: torch.Tensor,  # [num_blocks, block_size, head_dim + 4]
    slot_mapping: torch.Tensor,
    quant_block_size,
    scale_fmt,
    block_tile_size=16,
    head_tile_size=16,
):
    """Quantize indexer K vectors to FP8 and scatter them into the paged cache.

    Each cache row packs the FP8 payload first, followed by one float32
    dequantization scale per token position; this wrapper carves out the two
    regions and launches one kernel program per token in ``slot_mapping``.
    """
    n_blocks = kv_cache.shape[0]
    blk_size = kv_cache.shape[1]
    hdim = k.shape[-1]
    n_tokens = slot_mapping.shape[0]

    # Flatten each block row so values and scales can be sliced apart:
    # first blk_size*hdim bytes hold FP8 values, the remainder holds scales.
    flat_cache = kv_cache.view(n_blocks, -1)
    value_bytes = blk_size * hdim
    cache_values = flat_cache[:, :value_bytes]
    cache_scales = flat_cache[:, value_bytes:].view(torch.float32)

    # Express the 16-byte tile width in elements of the cache's dtype.
    elems_per_tile = head_tile_size // flat_cache.element_size()

    _indexer_k_quant_and_cache_kernel[(n_tokens,)](
        k,
        cache_values,
        cache_scales,
        slot_mapping,
        cache_scales.stride(0),
        cache_values.stride(0),
        blk_size,
        n_tokens,
        hdim,
        "NHD",
        block_tile_size,
        elems_per_tile,
        IS_FNUZ=current_platform.fp8_dtype() == torch.float8_e4m3fnuz,
        USE_UE8M0=scale_fmt == "ue8m0",
    )
|
||||
|
||||
|
||||
@triton.jit
def _cp_gather_indexer_quant_cache_kernel(
    kv_cache_ptr,  # [n_blks,blk_size//tile_blk,head_dim//16B,tile_blk,16B]
    # [n_blks, blk_size, head_dim]
    kv_cache_scale_ptr,  # [n_blks, blk_size]
    k_fp8_ptr,  # [num_tokens, head_dim]
    k_scale_ptr,  # [num_tokens]
    block_table_ptr,  # [batch_size, block_table_stride]
    cu_seqlen_ptr,  # [batch_size + 1]
    token_to_seq_ptr,  # [num_tokens]
    block_size,
    block_table_stride,
    kv_cache_stride,
    kv_cache_scale_stride,
    LAYOUT: tl.constexpr,
    HEAD_DIM: tl.constexpr,
    BLOCK_TILE_SIZE: tl.constexpr,
    HEAD_TILE_SIZE: tl.constexpr,
):
    """Gather FP8 K values + scales from the paged cache into flat tensors.

    Inverse of the quant-and-cache kernel: one program per flattened token
    position reads the token's FP8 vector and its float32 scale out of the
    block-paged cache and writes them contiguously to ``k_fp8`` / ``k_scale``.
    """
    # One program per output token; grid is (num_tokens,).
    tid = tl.program_id(0)
    offset = tl.arange(0, HEAD_DIM)
    # Map the flat token index back to its sequence and in-sequence position.
    batch_id = tl.load(token_to_seq_ptr + tid)
    batch_start = tl.load(cu_seqlen_ptr + batch_id)
    batch_end = tl.load(cu_seqlen_ptr + batch_id + 1)
    batch_offset = tid - batch_start
    # Guard against grid padding past the last token of the sequence.
    if tid >= batch_end:
        return
    # Logical block within the sequence, then physical block via block table.
    block_table_id = batch_offset // block_size
    block_offset = batch_offset % block_size
    block_table_offset = batch_id * block_table_stride + block_table_id
    block_id = tl.load(block_table_ptr + block_table_offset)
    tiled_block_id = block_offset // BLOCK_TILE_SIZE
    tiled_block_offset = block_offset % BLOCK_TILE_SIZE
    if LAYOUT == "SHUFFLE":
        # Tiled cache layout (see the 5-D shape comment on kv_cache_ptr).
        src_cache_offset = (
            block_id * kv_cache_stride
            + tiled_block_id * HEAD_DIM * BLOCK_TILE_SIZE
            + tiled_block_offset * HEAD_TILE_SIZE
        )
    else:
        # Plain row-major [blk_size, HEAD_DIM] layout.
        src_cache_offset = block_id * kv_cache_stride + block_offset * HEAD_DIM
    src_scale_offset = block_id * kv_cache_scale_stride + block_offset
    dst_offset = tid * HEAD_DIM
    src_scale_ptr = kv_cache_scale_ptr + src_scale_offset
    src_cache_ptr = kv_cache_ptr + src_cache_offset
    dst_k_ptr = k_fp8_ptr + dst_offset
    # Copy the per-token scale first, then the FP8 vector.
    scale_val = tl.load(src_scale_ptr)
    tl.store(k_scale_ptr + tid, scale_val)
    if LAYOUT == "SHUFFLE":
        # De-interleave: same permutation as the store side of the quant kernel.
        tiled_src_offset = (
            offset // HEAD_TILE_SIZE * HEAD_TILE_SIZE * BLOCK_TILE_SIZE
            + offset % HEAD_TILE_SIZE
        )
    else:
        tiled_src_offset = offset
    val = tl.load(src_cache_ptr + tiled_src_offset)
    tl.store(dst_k_ptr + offset, val)
|
||||
|
||||
|
||||
def cp_gather_indexer_k_quant_cache_triton(
    k_cache: torch.Tensor,  # [num_blocks, block_size, head_dim + 4]
    k_fp8: torch.Tensor,
    k_fp8_scale: torch.Tensor,
    block_table: torch.Tensor,
    cu_seqlen: torch.Tensor,
    token_to_seq: torch.Tensor,
    block_tile_size: int = 16,
    head_tile_size: int = 16,
):
    """Gather FP8 K values and scales out of the paged cache.

    Fills ``k_fp8`` ([num_tokens, head_dim]) and ``k_fp8_scale``
    ([num_tokens], float32) from the block-paged indexer cache, resolving each
    flat token position through ``token_to_seq``/``cu_seqlen``/``block_table``.
    """
    n_tokens = k_fp8.size(0)
    n_blocks = k_cache.shape[0]
    blk_size = k_cache.size(1)
    hdim = k_fp8.shape[-1]
    fp8_dtype = current_platform.fp8_dtype()

    # The cache was already written as two regions per block row:
    # FP8 payload first, then float32 per-position scales.
    flat_cache = k_cache.view(n_blocks, -1)
    value_bytes = blk_size * hdim
    cache_values = flat_cache[:, :value_bytes].view(fp8_dtype)
    cache_scales = flat_cache[:, value_bytes:].view(torch.float32)

    _cp_gather_indexer_quant_cache_kernel[(n_tokens,)](
        cache_values,
        cache_scales,
        k_fp8,
        k_fp8_scale.view(torch.float32),
        block_table,
        cu_seqlen,
        token_to_seq,
        blk_size,
        block_table.stride(0),
        cache_values.stride(0),
        cache_scales.stride(0),
        "NHD",
        hdim,
        block_tile_size,
        head_tile_size,
    )
|
||||
|
||||
|
||||
# Taken from https://github.com/deepseek-ai/DeepGEMM/blob/main/tests/test_attention.py#L156
|
||||
def fp8_paged_mqa_logits_torch(
    q: torch.Tensor,
    kv_cache: torch.Tensor,
    weights: torch.Tensor,
    context_lens: torch.Tensor,
    block_tables: torch.Tensor,
    max_model_len: int,
):
    """Pure-torch reference for FP8 paged MQA logits.

    Args:
        q: [B, next_n, H, D] query tensor (cast to float32 internally).
        kv_cache: [num_blocks, block_size, 1, D+4] packed cache; the first D
            bytes per position are FP8 values, the trailing 4 bytes are the
            float32 dequant scale.
        weights: [B * next_n, H] float32 per-head mixing weights.
        context_lens: [B] effective context length per batch element.
        block_tables: [B, max_blocks] logical-to-physical block mapping.
        max_model_len: width of the returned logits tensor.

    Returns:
        [B * next_n, max_model_len] float32 logits; positions outside the
        context or ahead of the causal frontier are -inf.
    """
    from vllm.utils.math_utils import cdiv

    fp8_dtype = current_platform.fp8_dtype()
    batch_size, next_n, _, dim = q.size()
    # Split the packed cache into FP8 payload and trailing float32 scales,
    # then dequantize the whole cache up front.
    kv_cache, scale = kv_cache[..., :dim], kv_cache[..., dim:]
    scale = scale.contiguous().view(torch.float)
    q = q.float()
    kv_cache = kv_cache.view(fp8_dtype).float() * scale
    num_block, block_size, _, dim = kv_cache.size()
    # Derive the device from q instead of hard-coding "cuda": the original
    # aranges below were pinned to the default CUDA device, which breaks on
    # CPU tensors and mismatches q on any non-default CUDA device.
    device = q.device
    logits = torch.full(
        [batch_size * next_n, max_model_len],
        float("-inf"),
        device=device,
        dtype=torch.float32,
    )
    context_lens = context_lens.tolist()
    for i in range(batch_size):
        context_len = context_lens[i]
        # Absolute positions of this request's next_n query tokens.
        q_offsets = torch.arange(context_len - next_n, context_len, device=device)
        weight_slice = (
            weights[i * next_n : (i + 1) * next_n, :].transpose(0, 1).contiguous()
        )
        for block_rk in range(cdiv(context_len, block_size)):
            block_idx = block_tables[i][block_rk]
            qx, kx = q[i], kv_cache[block_idx]
            k_offsets = torch.arange(
                block_rk * block_size, (block_rk + 1) * block_size, device=device
            )
            # Keys must lie inside the context and be causally visible.
            mask = (k_offsets[None, :] < context_len) & (
                k_offsets[None, :] <= q_offsets[:, None]
            )
            s = torch.where(
                mask[None, :, :],
                (qx.transpose(0, 1) @ kx.transpose(0, 1).transpose(1, 2)).to(
                    logits.dtype
                ),
                float("-inf"),
            )
            # ReLU (maps masked -inf to 0) then weighted reduction over heads.
            s = torch.relu(s) * weight_slice[..., None]
            s = s.sum(dim=0)
            logits[
                i * next_n : (i + 1) * next_n,
                block_rk * block_size : (block_rk + 1) * block_size,
            ] = torch.where(k_offsets[None, :] <= q_offsets[:, None], s, float("-inf"))
    return logits
|
||||
|
||||
|
||||
@functools.lru_cache(maxsize=1)
def _find_aiter_paged_mqa_logits_module():
    """Resolve the aiter Triton paged-MQA-logits module, or None if unavailable.

    Cached at module level: the previous per-call ``lru_cache`` decorated a
    function recreated on every invocation, so the cache never produced a hit.
    """
    # `importlib.util` is not guaranteed to be bound by a bare
    # `import importlib`; import the submodule explicitly.
    import importlib.util

    # TODO(ganyi): remove the fallback path probing once aiter merges this
    # kernel into main.
    for module_path in (
        "aiter.ops.triton.pa_mqa_logits",
        "aiter.ops.triton.attention.pa_mqa_logits",
    ):
        try:
            if importlib.util.find_spec(module_path) is None:
                continue
            return importlib.import_module(module_path)
        except (ImportError, ModuleNotFoundError):
            # find_spec itself raises ModuleNotFoundError when a parent
            # package (e.g. `aiter`) is missing entirely.
            return None
    return None


def rocm_fp8_paged_mqa_logits(
    q_fp8: torch.Tensor,
    kv_cache_fp8: torch.Tensor,
    weights: torch.Tensor,
    context_lens: torch.Tensor,
    block_tables: torch.Tensor,
    schedule_metadata: torch.Tensor,
    max_model_len: int,
) -> torch.Tensor:
    """Compute FP8 MQA logits using paged KV-cache.

    Args:
        q_fp8: Query tensor of shape [B, next_n, H, D]. Casted to
            `torch.float8_e4m3fn` by caller.
        kv_cache_fp8: Paged KV-cache in packed FP8+scale layout with shape
            [num_blocks, block_size, 1, D+4], dtype `torch.uint8`. The last
            4 bytes per (block,pos) store the `float` dequant scale.
        weights: Tensor of shape [B * next_n, H], dtype `torch.float32`.
        context_lens: Tensor of shape [B], dtype int32; effective context length
            for each batch element.
        block_tables: Tensor of shape [B, max_blocks], dtype int32; maps logical
            block indices to physical blocks in the paged cache.
        schedule_metadata: Returned by `get_paged_mqa_logits_metadata`;
            used to distribute work across SMs.
        max_model_len: Maximum sequence length used to size the logits output.

    Returns:
        Logits tensor of shape [B * next_n, max_model_len], dtype
        `torch.float32`.
    """
    from vllm._aiter_ops import rocm_aiter_ops

    aiter_module = None
    if rocm_aiter_ops.is_enabled():
        aiter_module = _find_aiter_paged_mqa_logits_module()
        # FIXME(ganyi): Temporarily disable the aiter path until nightly docker
        # update aiter to the fix PR.
        aiter_module = None

    if aiter_module is None:
        # Pure-torch reference fallback (schedule_metadata is not needed here).
        return fp8_paged_mqa_logits_torch(
            q_fp8, kv_cache_fp8, weights, context_lens, block_tables, max_model_len
        )

    batch_size, next_n, heads, _ = q_fp8.shape
    # Per-head scratch initialized to -inf; stage1 fills valid positions and
    # the head dimension is reduced below. Use q_fp8's device rather than
    # hard-coding "cuda".
    out_qk = torch.full(
        (heads, batch_size * next_n, max_model_len),
        float("-inf"),
        device=q_fp8.device,
        dtype=torch.float32,
    )
    aiter_module.deepgemm_fp8_paged_mqa_logits_stage1(
        q_fp8,
        kv_cache_fp8,
        weights,
        out_qk,
        context_lens,
        block_tables,
        max_model_len,
    )
    return out_qk.sum(dim=0)
|
||||
|
||||
|
||||
# Take from https://github.com/deepseek-ai/DeepGEMM/blob/main/tests/test_attention.py#L84
|
||||
def fp8_mqa_logits_torch(
    q: torch.Tensor,
    kv: tuple[torch.Tensor, torch.Tensor],
    weights: torch.Tensor,
    cu_seqlen_ks: torch.Tensor,
    cu_seqlen_ke: torch.Tensor,
) -> torch.Tensor:
    """Compute FP8 MQA logits for a single sequence without KV paging.

    Args:
        q: Query tensor of shape [M, H, D]. Casted to
            `torch.float8_e4m3fn` by caller.
        kv: Tuple `(k_fp8, k_scales)` where `k_fp8` has shape [N, D] with
            dtype `torch.float8_e4m3fn` and `k_scales` has shape [N] (or
            [N, 1]) with dtype `torch.float32`.
        weights: weights of shape [M, H], dtype `torch.float32`.
        cu_seqlen_ks: Start indices (inclusive) for valid K per query position,
            shape [M], dtype int32.
        cu_seqlen_ke: End indices (exclusive) for valid K per query position,
            shape [M], dtype int32.

    Returns:
        Logits tensor of shape [M, N], dtype `torch.float32`.
    """
    kv, scale = kv
    seq_len_kv = kv.shape[0]
    # Dequantize in bfloat16 for the matmul; scores are accumulated in fp32.
    k = kv.to(torch.bfloat16)
    q = q.to(torch.bfloat16)

    # Build the per-query [ks, ke) validity window. Use the query tensor's
    # device rather than hard-coding "cuda": the original aranges were pinned
    # to the default CUDA device, which fails on CPU inputs and mismatches q
    # on any non-default CUDA device.
    kv_pos = torch.arange(0, seq_len_kv, device=q.device)[None, :]
    mask_lo = kv_pos >= cu_seqlen_ks[:, None]
    mask_hi = kv_pos < cu_seqlen_ke[:, None]
    mask = mask_lo & mask_hi

    # [H, M, N] scores, ReLU, weighted head reduction, then window masking.
    score = torch.einsum("mhd,nd->hmn", q, k).float() * scale
    logits = (score.relu() * weights.unsqueeze(-1).transpose(0, 1)).sum(dim=0)
    logits = logits.masked_fill(~mask, float("-inf"))

    return logits
|
||||
|
||||
|
||||
@functools.lru_cache(maxsize=1)
def _find_aiter_mqa_logits_module():
    """Resolve the aiter Triton fp8-MQA-logits module, or None if unavailable.

    Cached at module level: the previous per-call ``lru_cache`` decorated a
    function recreated on every invocation, so the cache never produced a hit.
    """
    # `importlib.util` is not guaranteed to be bound by a bare
    # `import importlib`; import the submodule explicitly.
    import importlib.util

    # TODO(ganyi): remove the fallback path probing once aiter merges this
    # kernel into main.
    for module_path in (
        "aiter.ops.triton.fp8_mqa_logits",
        "aiter.ops.triton.attention.fp8_mqa_logits",
    ):
        try:
            if importlib.util.find_spec(module_path) is None:
                continue
            return importlib.import_module(module_path)
        except (ImportError, ModuleNotFoundError):
            # find_spec itself raises ModuleNotFoundError when a parent
            # package (e.g. `aiter`) is missing entirely.
            return None
    return None


def rocm_fp8_mqa_logits(
    q: torch.Tensor,
    kv: tuple[torch.Tensor, torch.Tensor],
    weights: torch.Tensor,
    cu_seqlen_ks: torch.Tensor,
    cu_seqlen_ke: torch.Tensor,
) -> torch.Tensor:
    """Compute FP8 MQA logits for a single sequence without KV paging.

    Args:
        q: Query tensor of shape [M, H, D]. Casted to
            `torch.float8_e4m3fn` by caller.
        kv: Tuple `(k_fp8, k_scales)` where `k_fp8` has shape [N, D] with
            dtype `torch.float8_e4m3fn` and `k_scales` has shape [N] (or
            [N, 1]) with dtype `torch.float32`.
        weights: weights of shape [M, H], dtype `torch.float32`.
        cu_seqlen_ks: Start indices (inclusive) for valid K per query position,
            shape [M], dtype int32.
        cu_seqlen_ke: End indices (exclusive) for valid K per query position,
            shape [M], dtype int32.

    Returns:
        Logits tensor of shape [M, N], dtype `torch.float32`.
    """
    from vllm._aiter_ops import rocm_aiter_ops

    aiter_module = None
    if rocm_aiter_ops.is_enabled():
        aiter_module = _find_aiter_mqa_logits_module()

    if aiter_module is None:
        # Pure-torch reference fallback.
        return fp8_mqa_logits_torch(q, kv, weights, cu_seqlen_ks, cu_seqlen_ke)

    k_fp8, k_scale = kv
    return aiter_module.fp8_mqa_logits(
        q, k_fp8, k_scale, weights, cu_seqlen_ks, cu_seqlen_ke
    )
|
||||
|
||||
|
||||
def rocm_aiter_sparse_attn_indexer_fake(
    hidden_states: torch.Tensor,
    k_cache_prefix: str,
    kv_cache: torch.Tensor,
    q_fp8: torch.Tensor,
    k: torch.Tensor,
    weights: torch.Tensor,
    quant_block_size: int,
    scale_fmt: str | None,
    topk_tokens: int,
    head_dim: int,
    max_model_len: int,
    total_seq_lens: int,
    topk_indices_buffer: torch.Tensor | None,
) -> torch.Tensor:
    """Profile-run stand-in for the sparse attention indexer.

    Allocates the largest flattened KV scratch the real path could need so
    that the profile run observes a worst-case memory footprint, then returns
    the provided top-k buffer untouched.
    """
    # Worst-case flattened KV: head_dim FP8 bytes plus 4 scale bytes per token.
    scratch_kv = torch.empty(
        [total_seq_lens, head_dim + 4], device=k.device, dtype=torch.uint8
    )
    quant_dtype = current_platform.fp8_dtype()
    # Materialize both views with .contiguous() so their copies are charged to
    # the profile run, mirroring the allocations of the real execution path.
    _values = scratch_kv[..., :head_dim].view(quant_dtype).contiguous()
    _scales = scratch_kv[..., head_dim:].view(torch.float32).contiguous()
    return topk_indices_buffer
|
||||
|
||||
|
||||
def rocm_aiter_sparse_attn_indexer(
    hidden_states: torch.Tensor,
    k_cache_prefix: str,
    kv_cache: torch.Tensor,
    q_fp8: torch.Tensor,
    k: torch.Tensor,
    weights: torch.Tensor,
    quant_block_size: int,
    scale_fmt: str | None,
    topk_tokens: int,
    head_dim: int,
    max_model_len: int,
    total_seq_lens: int,
    topk_indices_buffer: torch.Tensor | None,
) -> torch.Tensor:
    """ROCm sparse-attention indexer: select top-k KV positions per token.

    Quantizes K into the paged indexer cache, computes MQA logits for the
    prefill and decode partitions of the batch, and writes the selected
    indices into ``topk_indices_buffer`` (returned; unfilled slots are -1).
    """
    # careful! this will be None in dummy run
    attn_metadata = get_forward_context().attn_metadata
    fp8_dtype = current_platform.fp8_dtype()
    # assert isinstance(attn_metadata, dict)
    if not isinstance(attn_metadata, dict):
        # Dummy/profile run: fall back to the fake path, which only allocates
        # worst-case scratch and returns the buffer unchanged.
        return rocm_aiter_sparse_attn_indexer_fake(
            hidden_states,
            k_cache_prefix,
            kv_cache,
            q_fp8,
            k,
            weights,
            quant_block_size,
            scale_fmt,
            topk_tokens,
            head_dim,
            max_model_len,
            total_seq_lens,
            topk_indices_buffer,
        )
    attn_metadata = attn_metadata[k_cache_prefix]
    assert isinstance(attn_metadata, DeepseekV32IndexerMetadata)
    slot_mapping = attn_metadata.slot_mapping
    has_decode = attn_metadata.num_decodes > 0
    has_prefill = attn_metadata.num_prefills > 0
    num_decode_tokens = attn_metadata.num_decode_tokens

    # Quantize this step's K vectors into the paged cache before any reads.
    ops.indexer_k_quant_and_cache(
        k,
        kv_cache,
        slot_mapping,
        quant_block_size,
        scale_fmt,
    )

    # Reset the output rows for this batch to the -1 "no index" sentinel.
    topk_indices_buffer[: hidden_states.shape[0]] = -1
    if has_prefill:
        prefill_metadata = attn_metadata.prefill
        # Prefill is processed in chunks to bound the gathered K scratch size.
        for chunk in prefill_metadata.chunks:
            k_fp8 = torch.empty(
                [chunk.total_seq_lens, head_dim],
                device=k.device,
                dtype=fp8_dtype,
            )
            # 4 raw bytes per token, reinterpreted as float32 scales below.
            k_scale = torch.empty(
                [chunk.total_seq_lens, 4],
                device=k.device,
                dtype=torch.uint8,
            )

            # Gather this chunk's K values/scales out of the paged cache.
            ops.cp_gather_indexer_k_quant_cache(
                kv_cache,
                k_fp8,
                k_scale,
                chunk.block_table,
                chunk.cu_seq_lens,
            )

            logits = rocm_fp8_mqa_logits(
                q_fp8[chunk.token_start : chunk.token_end],
                (k_fp8, k_scale.view(torch.float32)),
                weights[chunk.token_start : chunk.token_end],
                chunk.cu_seqlen_ks,
                chunk.cu_seqlen_ke,
            )
            num_rows = logits.shape[0]
            assert topk_tokens == 2048, "top_k_per_row assumes size 2048"
            # View into the output buffer; the custom op writes in place.
            topk_indices = topk_indices_buffer[
                chunk.token_start : chunk.token_end, :topk_tokens
            ]
            torch.ops._C.top_k_per_row_prefill(
                logits,
                chunk.cu_seqlen_ks,
                chunk.cu_seqlen_ke,
                topk_indices,
                num_rows,
                logits.stride(0),
                logits.stride(1),
                topk_tokens,
            )

    if has_decode:
        decode_metadata = attn_metadata.decode
        # kv_cache size requirement [num_block, block_size, n_head, head_dim],
        # we only have [num_block, block_size, head_dim],
        kv_cache = kv_cache.unsqueeze(-2)
        decode_lens = decode_metadata.decode_lens
        if decode_metadata.requires_padding:
            # pad in edge case where we have short chunked prefill length <
            # decode_threshold since we unstrictly split
            # prefill and decode by decode_threshold
            # (currently set to 1 + speculative tokens)
            padded_q_fp8_decode_tokens = pack_seq_triton(
                q_fp8[:num_decode_tokens], decode_lens
            )
        else:
            # Uniform decode lengths: a plain reshape to [B, next_n, ...].
            padded_q_fp8_decode_tokens = q_fp8[:num_decode_tokens].reshape(
                decode_lens.shape[0], -1, *q_fp8.shape[1:]
            )
        # TODO: move and optimize below logic with triton kernels
        batch_size = padded_q_fp8_decode_tokens.shape[0]
        next_n = padded_q_fp8_decode_tokens.shape[1]
        assert batch_size == decode_metadata.seq_lens.shape[0]
        num_padded_tokens = batch_size * next_n

        logits = rocm_fp8_paged_mqa_logits(
            padded_q_fp8_decode_tokens,
            kv_cache,
            weights[:num_padded_tokens],
            decode_metadata.seq_lens,
            decode_metadata.block_table,
            decode_metadata.schedule_metadata,
            max_model_len=max_model_len,
        )

        num_rows = logits.shape[0]
        assert topk_tokens == 2048, "top_k_per_row assumes size 2048"
        # View into the output buffer; the custom op writes in place.
        topk_indices = topk_indices_buffer[:num_decode_tokens, :topk_tokens]
        torch.ops._C.top_k_per_row_decode(
            logits,
            next_n,
            decode_metadata.seq_lens,
            topk_indices,
            num_rows,
            logits.stride(0),
            logits.stride(1),
            topk_tokens,
        )

        if decode_metadata.requires_padding:
            # if padded, we need to unpack
            # the topk indices removing padded tokens
            topk_indices = unpack_seq_triton(
                topk_indices.reshape(batch_size, -1, topk_indices.shape[-1]),
                decode_lens,
            )
            topk_indices_buffer[:num_decode_tokens, : topk_indices.shape[-1]] = (
                topk_indices
            )

    return topk_indices_buffer
|
||||
Reference in New Issue
Block a user