# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import functools
import importlib.util

import torch

from vllm.forward_context import get_forward_context
from vllm.platforms import current_platform
from vllm.triton_utils import tl, triton
from vllm.v1.attention.backends.mla.indexer import DeepseekV32IndexerMetadata
from vllm.v1.attention.ops.common import pack_seq_triton, unpack_seq_triton

if current_platform.is_cuda_alike():
    from vllm import _custom_ops as ops
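
# This module provides the ROCm path for the DeepSeek V3.2 sparse-attention
# indexer: Triton kernels that FP8-quantize indexer keys into the paged cache
# and gather them back, plus the MQA-logits helpers used for top-k selection.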


@triton.jit
def _indexer_k_quant_and_cache_kernel(
    k_ptr,  # [num_tokens, head_dim]
    kv_cache_ptr,  # [n_blks, blk_size//tile_block, head_dim//16B, tile_block, 16B]
    # or [n_blks, blk_size, head_dim] for the non-shuffled layout
    kv_cache_scale_ptr,  # [n_blks, blk_size]
    slot_mapping_ptr,  # [num_tokens]
    kv_cache_scale_stride,
    kv_cache_value_stride,
    block_size,
    num_tokens,
    head_dim: tl.constexpr,
    LAYOUT: tl.constexpr,
    BLOCK_TILE_SIZE: tl.constexpr,
    HEAD_TILE_SIZE: tl.constexpr,
    IS_FNUZ: tl.constexpr,
    USE_UE8M0: tl.constexpr,
):
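    # One program per token: compute a per-token absmax scale, quantize the
    # head_dim values to FP8, and write both the values and the scale into
    # the paged cache slot given by slot_mapping.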
    tid = tl.program_id(0)
    offset = tl.arange(0, head_dim)
    if LAYOUT == "SHUFFLE":
        tile_offset = (
            offset // HEAD_TILE_SIZE * BLOCK_TILE_SIZE * HEAD_TILE_SIZE
            + offset % HEAD_TILE_SIZE
        )
    else:
        tile_offset = offset
    tile_store_offset = tile_offset
    src_ptr = k_ptr + tid * head_dim
    slot_id = tl.load(slot_mapping_ptr + tid)
    if slot_id < 0:
        return
    block_id = slot_id // block_size
    block_offset = slot_id % block_size
    tile_block_id = block_offset // BLOCK_TILE_SIZE
    tile_block_offset = block_offset % BLOCK_TILE_SIZE
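    # Per-token absmax quantization produces one float32 scale per token.
    # With USE_UE8M0 the scale is rounded up to the nearest power of two so
    # it can be represented as a pure 8-bit exponent (UE8M0).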
    val = tl.load(src_ptr + offset)
    amax = tl.max(val.abs(), axis=-1).to(tl.float32)
    if IS_FNUZ:
        scale = tl.maximum(1e-4, amax) / 224.0
    else:
        scale = tl.maximum(1e-4, amax) / 448.0

    if USE_UE8M0:
        scale = tl.exp2(tl.ceil(tl.log2(scale)))

    fp8_val = (val.to(tl.float32) / scale).to(kv_cache_ptr.type.element_ty)
    if LAYOUT == "SHUFFLE":
        dst_ptr = (
            kv_cache_ptr
            + block_id * kv_cache_value_stride
            + tile_block_id * BLOCK_TILE_SIZE * head_dim
            + tile_block_offset * HEAD_TILE_SIZE
        )
    else:
        dst_ptr = (
            kv_cache_ptr + block_id * kv_cache_value_stride + block_offset * head_dim
        )
    tl.store(dst_ptr + tile_store_offset, fp8_val)
    dst_scale_ptr = kv_cache_scale_ptr + block_id * kv_cache_scale_stride + block_offset
    tl.store(dst_scale_ptr, scale)


def indexer_k_quant_and_cache_triton(
    k: torch.Tensor,
    kv_cache: torch.Tensor,  # [num_blocks, block_size, head_dim + 4]
    slot_mapping: torch.Tensor,
    quant_block_size: int,
    scale_fmt: str | None,
    block_tile_size: int = 16,
    head_tile_size: int = 16,
):
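    """Quantize indexer K values to FP8 and scatter them into the paged cache."""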
    num_blocks = kv_cache.shape[0]
    head_dim = k.shape[-1]
    num_tokens = slot_mapping.shape[0]
    block_size = kv_cache.shape[1]
    # In the real layout, each block stores the kv cache values in its first
    # portion and the kv cache scales in its second portion.
    kv_cache = kv_cache.view(num_blocks, -1)
    kv_cache_value = kv_cache[:, : block_size * head_dim]
    kv_cache_scale = kv_cache[:, block_size * head_dim :].view(torch.float32)
    head_tile_size = head_tile_size // kv_cache.element_size()
    grid = (num_tokens,)
    _indexer_k_quant_and_cache_kernel[grid](
        k,
        kv_cache_value,
        kv_cache_scale,
        slot_mapping,
        kv_cache_scale.stride(0),
        kv_cache_value.stride(0),
        block_size,
        num_tokens,
        head_dim,
        "NHD",
        block_tile_size,
        head_tile_size,
        IS_FNUZ=current_platform.fp8_dtype() == torch.float8_e4m3fnuz,
        USE_UE8M0=scale_fmt == "ue8m0",
    )
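
# For example, with block_size=64 and head_dim=128, each uint8 cache block of
# shape [64, 128 + 4] is viewed as one flat row: 64 * 128 FP8 value bytes
# first, then 64 float32 scales (64 * 4 bytes), matching the split above.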


@triton.jit
def _cp_gather_indexer_quant_cache_kernel(
    kv_cache_ptr,  # [n_blks, blk_size//tile_blk, head_dim//16B, tile_blk, 16B]
    # or [n_blks, blk_size, head_dim] for the non-shuffled layout
    kv_cache_scale_ptr,  # [n_blks, blk_size]
    k_fp8_ptr,  # [num_tokens, head_dim]
    k_scale_ptr,  # [num_tokens]
    block_table_ptr,  # [batch_size, block_table_stride]
    cu_seqlen_ptr,  # [batch_size + 1]
    token_to_seq_ptr,  # [num_tokens]
    block_size,
    block_table_stride,
    kv_cache_stride,
    kv_cache_scale_stride,
    LAYOUT: tl.constexpr,
    HEAD_DIM: tl.constexpr,
    BLOCK_TILE_SIZE: tl.constexpr,
    HEAD_TILE_SIZE: tl.constexpr,
):
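    # One program per gathered token: find its sequence via token_to_seq,
    # resolve the physical cache block through the block table, then copy the
    # FP8 values and float32 scale into contiguous per-token outputs.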
    tid = tl.program_id(0)
    offset = tl.arange(0, HEAD_DIM)
    batch_id = tl.load(token_to_seq_ptr + tid)
    batch_start = tl.load(cu_seqlen_ptr + batch_id)
    batch_end = tl.load(cu_seqlen_ptr + batch_id + 1)
    batch_offset = tid - batch_start
    if tid >= batch_end:
        return
    block_table_id = batch_offset // block_size
    block_offset = batch_offset % block_size
    block_table_offset = batch_id * block_table_stride + block_table_id
    block_id = tl.load(block_table_ptr + block_table_offset)
    tiled_block_id = block_offset // BLOCK_TILE_SIZE
    tiled_block_offset = block_offset % BLOCK_TILE_SIZE
    if LAYOUT == "SHUFFLE":
        src_cache_offset = (
            block_id * kv_cache_stride
            + tiled_block_id * HEAD_DIM * BLOCK_TILE_SIZE
            + tiled_block_offset * HEAD_TILE_SIZE
        )
    else:
        src_cache_offset = block_id * kv_cache_stride + block_offset * HEAD_DIM
    src_scale_offset = block_id * kv_cache_scale_stride + block_offset
    dst_offset = tid * HEAD_DIM
    src_scale_ptr = kv_cache_scale_ptr + src_scale_offset
    src_cache_ptr = kv_cache_ptr + src_cache_offset
    dst_k_ptr = k_fp8_ptr + dst_offset
    scale_val = tl.load(src_scale_ptr)
    tl.store(k_scale_ptr + tid, scale_val)
    if LAYOUT == "SHUFFLE":
        tiled_src_offset = (
            offset // HEAD_TILE_SIZE * HEAD_TILE_SIZE * BLOCK_TILE_SIZE
            + offset % HEAD_TILE_SIZE
        )
    else:
        tiled_src_offset = offset
    val = tl.load(src_cache_ptr + tiled_src_offset)
    tl.store(dst_k_ptr + offset, val)


def cp_gather_indexer_k_quant_cache_triton(
    k_cache: torch.Tensor,  # [num_blocks, block_size, head_dim + 4]
    k_fp8: torch.Tensor,
    k_fp8_scale: torch.Tensor,
    block_table: torch.Tensor,
    cu_seqlen: torch.Tensor,
    token_to_seq: torch.Tensor,
    block_tile_size: int = 16,
    head_tile_size: int = 16,
):
    num_tokens = k_fp8.size(0)
    block_size = k_cache.size(1)
    block_table_stride = block_table.stride(0)
    head_dim = k_fp8.shape[-1]
    num_blocks = k_cache.shape[0]
    # We assume each kv cache block has already been split into two portions:
    # FP8 values first, float32 scales second.
    k_cache = k_cache.view(num_blocks, -1)
    fp8_dtype = current_platform.fp8_dtype()
    k_cache_value = k_cache[:, : block_size * head_dim].view(fp8_dtype)
    k_cache_scale = k_cache[:, block_size * head_dim :].view(torch.float32)
    grid = (num_tokens,)
    k_fp8_scale = k_fp8_scale.view(torch.float32)
    _cp_gather_indexer_quant_cache_kernel[grid](
        k_cache_value,
        k_cache_scale,
        k_fp8,
        k_fp8_scale,
        block_table,
        cu_seqlen,
        token_to_seq,
        block_size,
        block_table_stride,
        k_cache_value.stride(0),
        k_cache_scale.stride(0),
        "NHD",
        head_dim,
        block_tile_size,
        head_tile_size,
    )
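
# A minimal sketch of how the gather inputs relate (an illustrative
# assumption, not an API fixed by this module): given per-sequence lengths
# `seq_lens`, `cu_seqlen` is their prefix sum with a leading 0, and
# `token_to_seq` maps each flat token index to its sequence index, e.g.
# torch.repeat_interleave(torch.arange(len(seq_lens)), seq_lens).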


# Taken from https://github.com/deepseek-ai/DeepGEMM/blob/main/tests/test_attention.py#L156
def fp8_paged_mqa_logits_torch(
    q: torch.Tensor,
    kv_cache: torch.Tensor,
    weights: torch.Tensor,
    context_lens: torch.Tensor,
    block_tables: torch.Tensor,
    max_model_len: int,
):
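    # Reference (unfused) implementation: dequantize the paged FP8 cache
    # block by block, apply the causal mask, and accumulate the per-head
    # ReLU(q @ k^T) scores weighted by `weights` into the logits buffer.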
    from vllm.utils.math_utils import cdiv

    fp8_dtype = current_platform.fp8_dtype()
    batch_size, next_n, _, dim = q.size()
    kv_cache, scale = kv_cache[..., :dim], kv_cache[..., dim:]
    scale = scale.contiguous().view(torch.float)
    q = q.float()
    kv_cache = kv_cache.view(fp8_dtype).float() * scale
    num_block, block_size, _, dim = kv_cache.size()
    logits = torch.full(
        [batch_size * next_n, max_model_len],
        float("-inf"),
        device=q.device,
        dtype=torch.float32,
    )
    context_lens = context_lens.tolist()
    for i in range(batch_size):
        context_len = context_lens[i]
        q_offsets = torch.arange(context_len - next_n, context_len, device="cuda")
        weight_slice = (
            weights[i * next_n : (i + 1) * next_n, :].transpose(0, 1).contiguous()
        )
        for block_rk in range(cdiv(context_len, block_size)):
            block_idx = block_tables[i][block_rk]
            qx, kx = q[i], kv_cache[block_idx]
            k_offsets = torch.arange(
                block_rk * block_size, (block_rk + 1) * block_size, device="cuda"
            )
            mask = (k_offsets[None, :] < context_len) & (
                k_offsets[None, :] <= q_offsets[:, None]
            )
            s = torch.where(
                mask[None, :, :],
                (qx.transpose(0, 1) @ kx.transpose(0, 1).transpose(1, 2)).to(
                    logits.dtype
                ),
                float("-inf"),
            )
            s = torch.relu(s) * weight_slice[..., None]
            s = s.sum(dim=0)
            logits[
                i * next_n : (i + 1) * next_n,
                block_rk * block_size : (block_rk + 1) * block_size,
            ] = torch.where(k_offsets[None, :] <= q_offsets[:, None], s, float("-inf"))
    return logits


def rocm_fp8_paged_mqa_logits(
    q_fp8: torch.Tensor,
    kv_cache_fp8: torch.Tensor,
    weights: torch.Tensor,
    context_lens: torch.Tensor,
    block_tables: torch.Tensor,
    schedule_metadata: torch.Tensor,
    max_model_len: int,
) -> torch.Tensor:
    """Compute FP8 MQA logits using the paged KV-cache.

    Args:
        q_fp8: Query tensor of shape [B, next_n, H, D]. Cast to
            `torch.float8_e4m3fn` by the caller.
        kv_cache_fp8: Paged KV-cache in packed FP8+scale layout with shape
            [num_blocks, block_size, 1, D+4], dtype `torch.uint8`. The last
            4 bytes per (block, pos) store the `float` dequant scale.
        weights: Tensor of shape [B * next_n, H], dtype `torch.float32`.
        context_lens: Tensor of shape [B], dtype int32; effective context
            length for each batch element.
        block_tables: Tensor of shape [B, max_blocks], dtype int32; maps
            logical block indices to physical blocks in the paged cache.
        schedule_metadata: Returned by `get_paged_mqa_logits_metadata`;
            used to distribute work across SMs.
        max_model_len: Maximum sequence length used to size the logits
            output.

    Returns:
        Logits tensor of shape [B * next_n, max_model_len], dtype
        `torch.float32`.
    """
    from vllm._aiter_ops import rocm_aiter_ops

    @functools.lru_cache
    def paged_mqa_logits_module():
        paged_mqa_logits_module_path = None
        if importlib.util.find_spec("aiter.ops.triton.pa_mqa_logits") is not None:
            paged_mqa_logits_module_path = "aiter.ops.triton.pa_mqa_logits"
        elif (
            importlib.util.find_spec("aiter.ops.triton.attention.pa_mqa_logits")
            is not None
        ):
            paged_mqa_logits_module_path = "aiter.ops.triton.attention.pa_mqa_logits"

        if paged_mqa_logits_module_path is not None:
            try:
                return importlib.import_module(paged_mqa_logits_module_path)
            except ImportError:
                return None
        return None

    aiter_paged_mqa_logits_module = None
    if rocm_aiter_ops.is_enabled():
        aiter_paged_mqa_logits_module = paged_mqa_logits_module()
    # FIXME(ganyi): Temporarily disable the aiter path until the nightly
    # docker updates aiter to include the fix PR.
    aiter_paged_mqa_logits_module = None

    if aiter_paged_mqa_logits_module is not None:
        deepgemm_fp8_paged_mqa_logits_stage1 = (
            aiter_paged_mqa_logits_module.deepgemm_fp8_paged_mqa_logits_stage1
        )
        batch_size, next_n, heads, _ = q_fp8.shape
        out_qk = torch.full(
            (heads, batch_size * next_n, max_model_len),
            float("-inf"),
            device="cuda",
            dtype=torch.float32,
        )
        deepgemm_fp8_paged_mqa_logits_stage1(
            q_fp8,
            kv_cache_fp8,
            weights,
            out_qk,
            context_lens,
            block_tables,
            max_model_len,
        )
        return out_qk.sum(dim=0)
    else:
        return fp8_paged_mqa_logits_torch(
            q_fp8, kv_cache_fp8, weights, context_lens, block_tables, max_model_len
        )
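
# Note: the aiter stage1 kernel used above fills per-head logits of shape
# [H, B * next_n, max_model_len]; summing over dim 0 reduces them to the
# final [B * next_n, max_model_len] logits, matching the torch reference.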


# Taken from https://github.com/deepseek-ai/DeepGEMM/blob/main/tests/test_attention.py#L84
def fp8_mqa_logits_torch(
    q: torch.Tensor,
    kv: tuple[torch.Tensor, torch.Tensor],
    weights: torch.Tensor,
    cu_seqlen_ks: torch.Tensor,
    cu_seqlen_ke: torch.Tensor,
) -> torch.Tensor:
    """Compute FP8 MQA logits for a single sequence without KV paging.

    Args:
        q: Query tensor of shape [M, H, D]. Cast to
            `torch.float8_e4m3fn` by the caller.
        kv: Tuple `(k_fp8, k_scales)` where `k_fp8` has shape [N, D] with
            dtype `torch.float8_e4m3fn` and `k_scales` has shape [N] (or
            [N, 1]) with dtype `torch.float32`.
        weights: Weights of shape [M, H], dtype `torch.float32`.
        cu_seqlen_ks: Start indices (inclusive) for valid K per query
            position, shape [M], dtype int32.
        cu_seqlen_ke: End indices (exclusive) for valid K per query
            position, shape [M], dtype int32.

    Returns:
        Logits tensor of shape [M, N], dtype `torch.float32`.
    """
    kv, scale = kv
    seq_len_kv = kv.shape[0]
    k = kv.to(torch.bfloat16)
    q = q.to(torch.bfloat16)

    mask_lo = (
        torch.arange(0, seq_len_kv, device="cuda")[None, :] >= cu_seqlen_ks[:, None]
    )
    mask_hi = (
        torch.arange(0, seq_len_kv, device="cuda")[None, :] < cu_seqlen_ke[:, None]
    )
    mask = mask_lo & mask_hi

    score = torch.einsum("mhd,nd->hmn", q, k).float() * scale
    logits = (score.relu() * weights.unsqueeze(-1).transpose(0, 1)).sum(dim=0)
    logits = logits.masked_fill(~mask, float("-inf"))

    return logits
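
# For example (illustrative values): with cu_seqlen_ks = [0, 0, 2] and
# cu_seqlen_ke = [1, 2, 3], query m attends to keys in [ks[m], ke[m]), so
# query 0 sees key 0, query 1 sees keys 0-1, and query 2 sees only key 2.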


def rocm_fp8_mqa_logits(
    q: torch.Tensor,
    kv: tuple[torch.Tensor, torch.Tensor],
    weights: torch.Tensor,
    cu_seqlen_ks: torch.Tensor,
    cu_seqlen_ke: torch.Tensor,
) -> torch.Tensor:
    """Compute FP8 MQA logits for a single sequence without KV paging.

    Args:
        q: Query tensor of shape [M, H, D]. Cast to
            `torch.float8_e4m3fn` by the caller.
        kv: Tuple `(k_fp8, k_scales)` where `k_fp8` has shape [N, D] with
            dtype `torch.float8_e4m3fn` and `k_scales` has shape [N] (or
            [N, 1]) with dtype `torch.float32`.
        weights: Weights of shape [M, H], dtype `torch.float32`.
        cu_seqlen_ks: Start indices (inclusive) for valid K per query
            position, shape [M], dtype int32.
        cu_seqlen_ke: End indices (exclusive) for valid K per query
            position, shape [M], dtype int32.

    Returns:
        Logits tensor of shape [M, N], dtype `torch.float32`.
    """

    # TODO(ganyi): Temporary workaround; remove the module check and the
    # reference path once aiter merges this kernel into main.
    from vllm._aiter_ops import rocm_aiter_ops

    @functools.lru_cache
    def mqa_logits_module():
        mqa_logits_module_path = None
        if importlib.util.find_spec("aiter.ops.triton.fp8_mqa_logits") is not None:
            mqa_logits_module_path = "aiter.ops.triton.fp8_mqa_logits"
        elif (
            importlib.util.find_spec("aiter.ops.triton.attention.fp8_mqa_logits")
            is not None
        ):
            mqa_logits_module_path = "aiter.ops.triton.attention.fp8_mqa_logits"

        if mqa_logits_module_path is not None:
            try:
                return importlib.import_module(mqa_logits_module_path)
            except ImportError:
                return None
        return None

    aiter_mqa_logits_module = None
    if rocm_aiter_ops.is_enabled():
        aiter_mqa_logits_module = mqa_logits_module()

    if aiter_mqa_logits_module is not None:
        fp8_mqa_logits = aiter_mqa_logits_module.fp8_mqa_logits
        kv, scale = kv
        return fp8_mqa_logits(q, kv, scale, weights, cu_seqlen_ks, cu_seqlen_ke)
    else:
        return fp8_mqa_logits_torch(q, kv, weights, cu_seqlen_ks, cu_seqlen_ke)


def rocm_aiter_sparse_attn_indexer_fake(
    hidden_states: torch.Tensor,
    k_cache_prefix: str,
    kv_cache: torch.Tensor,
    q_fp8: torch.Tensor,
    k: torch.Tensor,
    weights: torch.Tensor,
    quant_block_size: int,
    scale_fmt: str | None,
    topk_tokens: int,
    head_dim: int,
    max_model_len: int,
    total_seq_lens: int,
    topk_indices_buffer: torch.Tensor | None,
) -> torch.Tensor:
    # Profile run.
    # NOTE(Chen): create the largest possible flattened_kv so that
    # profile_run measures the correct memory usage.
    _flattened_kv = torch.empty(
        [total_seq_lens, head_dim + 4], device=k.device, dtype=torch.uint8
    )
    fp8_dtype = current_platform.fp8_dtype()
    _k_fp8 = _flattened_kv[..., :head_dim].view(fp8_dtype).contiguous()
    _k_scale = _flattened_kv[..., head_dim:].view(torch.float32).contiguous()
    return topk_indices_buffer


def rocm_aiter_sparse_attn_indexer(
    hidden_states: torch.Tensor,
    k_cache_prefix: str,
    kv_cache: torch.Tensor,
    q_fp8: torch.Tensor,
    k: torch.Tensor,
    weights: torch.Tensor,
    quant_block_size: int,
    scale_fmt: str | None,
    topk_tokens: int,
    head_dim: int,
    max_model_len: int,
    total_seq_lens: int,
    topk_indices_buffer: torch.Tensor | None,
) -> torch.Tensor:
    # Careful: attn_metadata will be None during a dummy run.
    attn_metadata = get_forward_context().attn_metadata
    fp8_dtype = current_platform.fp8_dtype()
    if not isinstance(attn_metadata, dict):
        return rocm_aiter_sparse_attn_indexer_fake(
            hidden_states,
            k_cache_prefix,
            kv_cache,
            q_fp8,
            k,
            weights,
            quant_block_size,
            scale_fmt,
            topk_tokens,
            head_dim,
            max_model_len,
            total_seq_lens,
            topk_indices_buffer,
        )
    attn_metadata = attn_metadata[k_cache_prefix]
    assert isinstance(attn_metadata, DeepseekV32IndexerMetadata)
    slot_mapping = attn_metadata.slot_mapping
    has_decode = attn_metadata.num_decodes > 0
    has_prefill = attn_metadata.num_prefills > 0
    num_decode_tokens = attn_metadata.num_decode_tokens

    ops.indexer_k_quant_and_cache(
        k,
        kv_cache,
        slot_mapping,
        quant_block_size,
        scale_fmt,
    )
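
    # Invalidate every top-k slot up front; -1 marks "no token selected".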
    topk_indices_buffer[: hidden_states.shape[0]] = -1
    if has_prefill:
        prefill_metadata = attn_metadata.prefill
        for chunk in prefill_metadata.chunks:
            k_fp8 = torch.empty(
                [chunk.total_seq_lens, head_dim],
                device=k.device,
                dtype=fp8_dtype,
            )
            k_scale = torch.empty(
                [chunk.total_seq_lens, 4],
                device=k.device,
                dtype=torch.uint8,
            )

            ops.cp_gather_indexer_k_quant_cache(
                kv_cache,
                k_fp8,
                k_scale,
                chunk.block_table,
                chunk.cu_seq_lens,
            )

            logits = rocm_fp8_mqa_logits(
                q_fp8[chunk.token_start : chunk.token_end],
                (k_fp8, k_scale.view(torch.float32)),
                weights[chunk.token_start : chunk.token_end],
                chunk.cu_seqlen_ks,
                chunk.cu_seqlen_ke,
            )
            num_rows = logits.shape[0]
            assert topk_tokens == 2048, "top_k_per_row assumes size 2048"
            topk_indices = topk_indices_buffer[
                chunk.token_start : chunk.token_end, :topk_tokens
            ]
            torch.ops._C.top_k_per_row_prefill(
                logits,
                chunk.cu_seqlen_ks,
                chunk.cu_seqlen_ke,
                topk_indices,
                num_rows,
                logits.stride(0),
                logits.stride(1),
                topk_tokens,
            )

    if has_decode:
        decode_metadata = attn_metadata.decode
        # The kernel expects a kv_cache of shape
        # [num_blocks, block_size, n_head, head_dim]; we only have
        # [num_blocks, block_size, head_dim], so add a singleton head dim.
        kv_cache = kv_cache.unsqueeze(-2)
        decode_lens = decode_metadata.decode_lens
        if decode_metadata.requires_padding:
            # Pad in the edge case where a chunked-prefill length is shorter
            # than decode_threshold, since prefill and decode are split
            # loosely by decode_threshold (currently set to 1 + the number
            # of speculative tokens).
            padded_q_fp8_decode_tokens = pack_seq_triton(
                q_fp8[:num_decode_tokens], decode_lens
            )
        else:
            padded_q_fp8_decode_tokens = q_fp8[:num_decode_tokens].reshape(
                decode_lens.shape[0], -1, *q_fp8.shape[1:]
            )
        # TODO: move and optimize the logic below with triton kernels
        batch_size = padded_q_fp8_decode_tokens.shape[0]
        next_n = padded_q_fp8_decode_tokens.shape[1]
        assert batch_size == decode_metadata.seq_lens.shape[0]
        num_padded_tokens = batch_size * next_n

        logits = rocm_fp8_paged_mqa_logits(
            padded_q_fp8_decode_tokens,
            kv_cache,
            weights[:num_padded_tokens],
            decode_metadata.seq_lens,
            decode_metadata.block_table,
            decode_metadata.schedule_metadata,
            max_model_len=max_model_len,
        )

        num_rows = logits.shape[0]
        assert topk_tokens == 2048, "top_k_per_row assumes size 2048"
        topk_indices = topk_indices_buffer[:num_decode_tokens, :topk_tokens]
        torch.ops._C.top_k_per_row_decode(
            logits,
            next_n,
            decode_metadata.seq_lens,
            topk_indices,
            num_rows,
            logits.stride(0),
            logits.stride(1),
            topk_tokens,
        )

        if decode_metadata.requires_padding:
            # If padded, unpack the top-k indices, removing the padded tokens.
            topk_indices = unpack_seq_triton(
                topk_indices.reshape(batch_size, -1, topk_indices.shape[-1]),
                decode_lens,
            )
            topk_indices_buffer[:num_decode_tokens, : topk_indices.shape[-1]] = (
                topk_indices
            )

    return topk_indices_buffer