Upgrade to vllm 0.17.0 corex v4.1 overlay
This commit is contained in:
@@ -9,58 +9,108 @@ from vllm.forward_context import get_forward_context
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.custom_op import CustomOp
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.deep_gemm import fp8_mqa_logits, fp8_paged_mqa_logits
|
||||
from vllm.utils.import_utils import has_deep_gemm
|
||||
from vllm.utils.deep_gemm import (
|
||||
fp8_mqa_logits,
|
||||
fp8_mqa_logits_torch,
|
||||
fp8_paged_mqa_logits,
|
||||
fp8_paged_mqa_logits_torch,
|
||||
is_deep_gemm_supported,
|
||||
)
|
||||
from vllm.utils.torch_utils import direct_register_custom_op
|
||||
from vllm.v1.attention.backends.mla.indexer import (
|
||||
DeepseekV32IndexerMetadata,
|
||||
)
|
||||
from vllm.v1.attention.ops.common import pack_seq_triton, unpack_seq_triton
|
||||
from vllm.v1.worker.workspace import current_workspace_manager
|
||||
|
||||
from vllm.utils.math_utils import cdiv
|
||||
if current_platform.is_cuda_alike():
|
||||
from vllm import _custom_ops as ops
|
||||
elif current_platform.is_xpu():
|
||||
from vllm._xpu_ops import xpu_ops as ops
|
||||
|
||||
import ixformer.inference.functions as ixfops
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@torch.inference_mode()
|
||||
def cp_gather_indexer_k_quant_cache(
|
||||
kv_cache, # [num_blocks, block_size, head_dim]
|
||||
dst_value, # [cu_seq_lens[-1], head_dim]
|
||||
block_table, # [batch_size, num_blocks]
|
||||
cu_seq_lens, # [batch_size + 1, ]
|
||||
batch_size,
|
||||
):
|
||||
num_blocks, block_size, _ = kv_cache.shape
|
||||
head_dim = dst_value.shape[-1]
|
||||
kv_cache = kv_cache.view(num_blocks, -1)
|
||||
|
||||
expected_value = []
|
||||
# expected_scale = []
|
||||
for b in range(batch_size):
|
||||
s = cu_seq_lens[b + 1] - cu_seq_lens[b]
|
||||
if s == 0:
|
||||
continue
|
||||
tot = cdiv(s, block_size)
|
||||
blocks = block_table[b, :tot]
|
||||
|
||||
value = []
|
||||
scale = []
|
||||
full_block = torch.arange(tot - 1,
|
||||
device=kv_cache.device,
|
||||
dtype=torch.int32)
|
||||
non_remaining_value = kv_cache[blocks[full_block], :block_size *
|
||||
head_dim].view(-1, head_dim)
|
||||
# non_remaining_scale = kv_cache[blocks[full_block],
|
||||
# block_size * head_dim:].view(-1, 4)
|
||||
|
||||
remaining = s - (tot - 1) * block_size
|
||||
|
||||
value = torch.cat([
|
||||
non_remaining_value,
|
||||
kv_cache[blocks[-1], :remaining * head_dim].view(-1, head_dim)
|
||||
],
|
||||
dim=0)
|
||||
# scale = torch.cat([
|
||||
# non_remaining_scale,
|
||||
# kv_cache[blocks[-1], block_size * head_dim:block_size * head_dim +
|
||||
# remaining * 4].view(-1, 4)
|
||||
# ],
|
||||
# dim=0)
|
||||
|
||||
expected_value.append(value)
|
||||
# expected_scale.append(scale)
|
||||
|
||||
gather_value = torch.cat(expected_value, dim=0).view(-1, head_dim)
|
||||
# gather_scale = torch.cat(expected_scale, dim=0).view(-1, 4)
|
||||
gather_value = gather_value.view(torch.bfloat16)
|
||||
# gather_scale = gather_scale.view(torch.float32)
|
||||
dst_value.copy_(gather_value)
|
||||
# dst_scale.copy_(gather_scale)
|
||||
|
||||
def sparse_attn_indexer(
|
||||
hidden_states: torch.Tensor,
|
||||
k_cache_prefix: str,
|
||||
kv_cache: torch.Tensor,
|
||||
q_fp8: torch.Tensor,
|
||||
q: torch.Tensor,
|
||||
k: torch.Tensor,
|
||||
weights: torch.Tensor,
|
||||
quant_block_size: int,
|
||||
scale_fmt: str | None,
|
||||
topk_tokens: int,
|
||||
head_dim: int,
|
||||
max_model_len: int,
|
||||
total_seq_lens: int,
|
||||
topk_indices_buffer: torch.Tensor,
|
||||
topk_indices_buffer: torch.Tensor | None,
|
||||
) -> torch.Tensor:
|
||||
# careful! this will be None in dummy run
|
||||
attn_metadata = get_forward_context().attn_metadata
|
||||
fp8_dtype = current_platform.fp8_dtype()
|
||||
|
||||
# assert isinstance(attn_metadata, dict)
|
||||
if not isinstance(attn_metadata, dict):
|
||||
# Reserve workspace for indexer during profiling run
|
||||
current_workspace_manager().get_simultaneous(
|
||||
((total_seq_lens, head_dim), torch.float8_e4m3fn),
|
||||
((total_seq_lens, 4), torch.uint8),
|
||||
)
|
||||
return sparse_attn_indexer_fake(
|
||||
hidden_states,
|
||||
k_cache_prefix,
|
||||
kv_cache,
|
||||
q_fp8,
|
||||
q,
|
||||
k,
|
||||
weights,
|
||||
quant_block_size,
|
||||
scale_fmt,
|
||||
topk_tokens,
|
||||
head_dim,
|
||||
max_model_len,
|
||||
@@ -74,12 +124,118 @@ def sparse_attn_indexer(
|
||||
has_prefill = attn_metadata.num_prefills > 0
|
||||
num_decode_tokens = attn_metadata.num_decode_tokens
|
||||
|
||||
ops.indexer_k_quant_and_cache(
|
||||
ops.indexer_k_cache(
|
||||
k,
|
||||
kv_cache,
|
||||
slot_mapping
|
||||
)
|
||||
|
||||
# topk_indices_buffer[: hidden_states.shape[0]] = -1
|
||||
if has_prefill:
|
||||
prefill_metadata = attn_metadata.prefill
|
||||
for chunk in prefill_metadata.chunks:
|
||||
logits = ixfops.dsa_indexer_mqa_logits_with_blocks(
|
||||
q[chunk.token_start:chunk.token_end],
|
||||
chunk.cu_seqlens_q,
|
||||
chunk.cu_seq_lens,
|
||||
kv_cache,
|
||||
chunk.block_table,
|
||||
weights[chunk.token_start : chunk.token_end],
|
||||
max_q_len=chunk.max_q_len,
|
||||
max_kv_len=chunk.max_kv_len,
|
||||
max_context_len=chunk.max_context_len
|
||||
)
|
||||
ixfops.dsa_update_topk_indices(
|
||||
logits, chunk.cu_seqlen_ks, chunk.cu_seqlen_ke, topk_tokens,
|
||||
topk_indices_buffer[chunk.token_start:chunk.token_end]
|
||||
)
|
||||
|
||||
if has_decode:
|
||||
decode_metadata = attn_metadata.decode
|
||||
# TODO: support speculative decode
|
||||
if decode_metadata.requires_padding:
|
||||
raise NotImplementedError(
|
||||
"Sparse attention indexer does not support requires_padding"
|
||||
)
|
||||
|
||||
# Use dsa_indexer_mqa_logits_with_blocks similar to prefill
|
||||
logits = ixfops.dsa_indexer_mqa_logits_with_blocks(
|
||||
q[:num_decode_tokens],
|
||||
decode_metadata.cu_seqlens_q,
|
||||
decode_metadata.cu_seqlens_kv,
|
||||
kv_cache,
|
||||
decode_metadata.block_table,
|
||||
weights[:num_decode_tokens],
|
||||
max_q_len=decode_metadata.max_q_len,
|
||||
max_kv_len=decode_metadata.max_kv_len,
|
||||
max_context_len=decode_metadata.max_context_len,
|
||||
)
|
||||
|
||||
ixfops.dsa_update_topk_indices(
|
||||
logits,
|
||||
decode_metadata.cu_seqlen_ks,
|
||||
decode_metadata.cu_seqlen_ke,
|
||||
topk_tokens,
|
||||
topk_indices_buffer[:num_decode_tokens],
|
||||
)
|
||||
|
||||
return topk_indices_buffer
|
||||
|
||||
|
||||
def sparse_attn_indexer_original(
|
||||
hidden_states: torch.Tensor,
|
||||
k_cache_prefix: str,
|
||||
kv_cache: torch.Tensor,
|
||||
q: torch.Tensor,
|
||||
k: torch.Tensor,
|
||||
weights: torch.Tensor,
|
||||
topk_tokens: int,
|
||||
head_dim: int,
|
||||
max_model_len: int,
|
||||
total_seq_lens: int,
|
||||
topk_indices_buffer: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
# careful! this will be None in dummy run
|
||||
attn_metadata = get_forward_context().attn_metadata
|
||||
# fp8_dtype = current_platform.fp8_dtype()
|
||||
|
||||
# assert isinstance(attn_metadata, dict)
|
||||
if not isinstance(attn_metadata, dict):
|
||||
# Reserve workspace for indexer during profiling run
|
||||
current_workspace_manager().get_simultaneous(
|
||||
((total_seq_lens, head_dim), torch.float8_e4m3fn),
|
||||
((total_seq_lens, 4), torch.uint8),
|
||||
)
|
||||
return sparse_attn_indexer_fake(
|
||||
hidden_states,
|
||||
k_cache_prefix,
|
||||
kv_cache,
|
||||
q,
|
||||
k,
|
||||
weights,
|
||||
topk_tokens,
|
||||
head_dim,
|
||||
max_model_len,
|
||||
total_seq_lens,
|
||||
topk_indices_buffer,
|
||||
)
|
||||
attn_metadata = attn_metadata[k_cache_prefix]
|
||||
assert isinstance(attn_metadata, DeepseekV32IndexerMetadata)
|
||||
slot_mapping = attn_metadata.slot_mapping
|
||||
has_decode = attn_metadata.num_decodes > 0
|
||||
has_prefill = attn_metadata.num_prefills > 0
|
||||
num_decode_tokens = attn_metadata.num_decode_tokens
|
||||
|
||||
# During speculative decoding, k may be padded to the CUDA graph batch
|
||||
# size while slot_mapping only covers actual tokens. Truncate k to avoid
|
||||
# out-of-bounds reads in the kernel.
|
||||
num_tokens = slot_mapping.shape[0]
|
||||
k = k[:num_tokens]
|
||||
|
||||
ops.indexer_k_cache(
|
||||
k,
|
||||
kv_cache,
|
||||
slot_mapping,
|
||||
quant_block_size,
|
||||
scale_fmt,
|
||||
)
|
||||
|
||||
topk_indices_buffer[: hidden_states.shape[0]] = -1
|
||||
@@ -88,44 +244,42 @@ def sparse_attn_indexer(
|
||||
|
||||
# Get the full shared workspace buffers once (will allocate on first use)
|
||||
workspace_manager = current_workspace_manager()
|
||||
k_fp8_full, k_scale_full = workspace_manager.get_simultaneous(
|
||||
((total_seq_lens, head_dim), fp8_dtype),
|
||||
((total_seq_lens, 4), torch.uint8),
|
||||
)
|
||||
k_full = workspace_manager.get_simultaneous(
|
||||
((total_seq_lens, head_dim), torch.bfloat16),
|
||||
)[0]
|
||||
for chunk in prefill_metadata.chunks:
|
||||
k_fp8 = k_fp8_full[: chunk.total_seq_lens]
|
||||
k_scale = k_scale_full[: chunk.total_seq_lens]
|
||||
ops.cp_gather_indexer_k_quant_cache(
|
||||
k = k_full[: chunk.total_seq_lens]
|
||||
# k_scale = k_scale_full[: chunk.total_seq_lens]
|
||||
cp_gather_indexer_k_quant_cache(
|
||||
kv_cache,
|
||||
k_fp8,
|
||||
k_scale,
|
||||
k,
|
||||
chunk.block_table,
|
||||
chunk.cu_seq_lens,
|
||||
chunk.num_reqs,
|
||||
)
|
||||
|
||||
logits = fp8_mqa_logits(
|
||||
q_fp8[chunk.token_start : chunk.token_end],
|
||||
(k_fp8, k_scale.view(torch.float32).flatten()),
|
||||
logits = ops.ref_mqa_logits(
|
||||
q[chunk.token_start:chunk.token_end],
|
||||
k,
|
||||
weights[chunk.token_start : chunk.token_end],
|
||||
chunk.cu_seqlen_ks,
|
||||
chunk.cu_seqlen_ke,
|
||||
clean_logits=False,
|
||||
)
|
||||
num_rows = logits.shape[0]
|
||||
|
||||
topk_indices = topk_indices_buffer[
|
||||
chunk.token_start : chunk.token_end, :topk_tokens
|
||||
]
|
||||
torch.ops._C.top_k_per_row_prefill(
|
||||
logits,
|
||||
chunk.cu_seqlen_ks,
|
||||
chunk.cu_seqlen_ke,
|
||||
topk_indices,
|
||||
num_rows,
|
||||
logits.stride(0),
|
||||
logits.stride(1),
|
||||
topk_tokens,
|
||||
)
|
||||
topk_indices = logits.topk(min(topk_tokens, logits.shape[-1]),
|
||||
dim=-1)[1]
|
||||
topk_indices -= chunk.cu_seqlen_ks[:, None]
|
||||
mask_lo = topk_indices >= 0
|
||||
mask_hi = topk_indices - (chunk.cu_seqlen_ke -
|
||||
chunk.cu_seqlen_ks)[:, None] < 0
|
||||
mask = torch.full_like(topk_indices,
|
||||
False,
|
||||
dtype=torch.bool,
|
||||
device=topk_indices.device)
|
||||
mask = mask_lo & mask_hi
|
||||
topk_indices = topk_indices.masked_fill(~mask, -1)
|
||||
topk_indices_buffer[
|
||||
chunk.token_start:chunk.token_end, :topk_indices.
|
||||
shape[-1]] = topk_indices.to(dtype=torch.int32)
|
||||
|
||||
# Compute lengths from row spans
|
||||
# lengths = (chunk.cu_seqlen_ke - chunk.cu_seqlen_ks).to(torch.int32)
|
||||
@@ -147,63 +301,50 @@ def sparse_attn_indexer(
|
||||
# decode_threshold since we unstrictly split
|
||||
# prefill and decode by decode_threshold
|
||||
# (currently set to 1 + speculative tokens)
|
||||
padded_q_fp8_decode_tokens = pack_seq_triton(
|
||||
q_fp8[:num_decode_tokens], decode_lens
|
||||
)
|
||||
padded_q_decode_tokens = pack_seq_triton(
|
||||
q[:num_decode_tokens], decode_lens)
|
||||
else:
|
||||
padded_q_fp8_decode_tokens = q_fp8[:num_decode_tokens].reshape(
|
||||
decode_lens.shape[0], -1, *q_fp8.shape[1:]
|
||||
)
|
||||
padded_q_decode_tokens = q[:num_decode_tokens].reshape(
|
||||
decode_lens.shape[0], -1, *q.shape[1:])
|
||||
# TODO: move and optimize below logic with triton kernels
|
||||
batch_size = padded_q_fp8_decode_tokens.shape[0]
|
||||
next_n = padded_q_fp8_decode_tokens.shape[1]
|
||||
batch_size = padded_q_decode_tokens.shape[0]
|
||||
next_n = padded_q_decode_tokens.shape[1]
|
||||
assert batch_size == decode_metadata.seq_lens.shape[0]
|
||||
num_padded_tokens = batch_size * next_n
|
||||
|
||||
logits = fp8_paged_mqa_logits(
|
||||
padded_q_fp8_decode_tokens,
|
||||
logits = ops.ref_paged_mqa_logits(
|
||||
padded_q_decode_tokens,
|
||||
kv_cache,
|
||||
weights[:num_padded_tokens],
|
||||
decode_metadata.seq_lens,
|
||||
decode_metadata.block_table,
|
||||
decode_metadata.schedule_metadata,
|
||||
max_model_len=max_model_len,
|
||||
clean_logits=False,
|
||||
)
|
||||
|
||||
num_rows = logits.shape[0]
|
||||
topk_indices = topk_indices_buffer[:num_padded_tokens, :topk_tokens]
|
||||
|
||||
if decode_metadata.use_large_context_topk:
|
||||
if next_n == 1:
|
||||
lengths = decode_metadata.seq_lens
|
||||
else:
|
||||
# (bs,) -> (bs, 1) + (next_n,) -> (bs, next_n) -> (bs * next_n,)
|
||||
lengths = (
|
||||
decode_metadata.seq_lens.unsqueeze(1)
|
||||
- next_n
|
||||
+ 1
|
||||
+ decode_metadata.offsets
|
||||
).flatten()
|
||||
|
||||
torch.ops._C.large_context_topk(
|
||||
logits,
|
||||
topk_indices,
|
||||
lengths,
|
||||
None,
|
||||
)
|
||||
else:
|
||||
torch.ops._C.top_k_per_row_decode(
|
||||
logits,
|
||||
next_n,
|
||||
decode_metadata.seq_lens,
|
||||
topk_indices,
|
||||
num_rows,
|
||||
logits.stride(0),
|
||||
logits.stride(1),
|
||||
topk_tokens,
|
||||
)
|
||||
|
||||
# padded query len
|
||||
current_device = padded_q_decode_tokens.device
|
||||
padded_num_tokens = batch_size * next_n
|
||||
positions = torch.arange(max_model_len,
|
||||
device=current_device).unsqueeze(0).expand(
|
||||
batch_size * next_n, -1)
|
||||
row_indices = torch.arange(padded_num_tokens,
|
||||
device=current_device) // next_n
|
||||
next_n_offset = torch.arange(
|
||||
padded_num_tokens,
|
||||
device=padded_q_decode_tokens.device) % next_n
|
||||
index_end_pos = (decode_metadata.seq_lens[row_indices] - next_n +
|
||||
next_n_offset).unsqueeze(1)
|
||||
# index_end_pos: [B * N, 1]
|
||||
mask = positions <= index_end_pos
|
||||
# mask: [B * N, L]
|
||||
logits = logits.masked_fill(~mask, float('-inf'))
|
||||
topk_indices = logits.topk(topk_tokens,
|
||||
dim=-1)[1].to(torch.int32) # [B * N, K]
|
||||
# ensure we don't set indices for the top k
|
||||
# that is out of range(masked already)
|
||||
# this will happen if context length is shorter than K
|
||||
topk_indices[topk_indices > index_end_pos] = -1
|
||||
if decode_metadata.requires_padding:
|
||||
# if padded, we need to unpack
|
||||
# the topk indices removing padded tokens
|
||||
@@ -211,9 +352,8 @@ def sparse_attn_indexer(
|
||||
topk_indices.reshape(batch_size, -1, topk_indices.shape[-1]),
|
||||
decode_lens,
|
||||
)
|
||||
topk_indices_buffer[:num_decode_tokens, : topk_indices.shape[-1]] = (
|
||||
topk_indices
|
||||
)
|
||||
topk_indices_buffer[:num_decode_tokens, :topk_indices.
|
||||
shape[-1]] = topk_indices.to(dtype=torch.int32)
|
||||
|
||||
return topk_indices_buffer
|
||||
|
||||
@@ -222,11 +362,9 @@ def sparse_attn_indexer_fake(
|
||||
hidden_states: torch.Tensor,
|
||||
k_cache_prefix: str,
|
||||
kv_cache: torch.Tensor,
|
||||
q_fp8: torch.Tensor,
|
||||
q: torch.Tensor,
|
||||
k: torch.Tensor,
|
||||
weights: torch.Tensor,
|
||||
quant_block_size: int,
|
||||
scale_fmt: str | None,
|
||||
topk_tokens: int,
|
||||
head_dim: int,
|
||||
max_model_len: int,
|
||||
@@ -278,9 +416,12 @@ class SparseAttnIndexer(CustomOp):
|
||||
self.max_model_len = max_model_len
|
||||
self.max_total_seq_len = max_total_seq_len
|
||||
self.topk_indices_buffer = topk_indices_buffer
|
||||
if current_platform.is_cuda() and not has_deep_gemm():
|
||||
raise RuntimeError(
|
||||
"Sparse Attention Indexer CUDA op requires DeepGEMM to be installed."
|
||||
if current_platform.is_cuda() and not is_deep_gemm_supported():
|
||||
logger.warning_once(
|
||||
"DeepGEMM is not supported or available. SparseAttnIndexer will use a "
|
||||
"less efficient PyTorch implementation. "
|
||||
"Please make sure you have the required hardware and software setup "
|
||||
"for DeepGEMM to achieve optimal performance."
|
||||
)
|
||||
|
||||
def forward_native(
|
||||
@@ -303,7 +444,7 @@ class SparseAttnIndexer(CustomOp):
|
||||
def forward_cuda(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
q_fp8: torch.Tensor,
|
||||
q: torch.Tensor,
|
||||
k: torch.Tensor,
|
||||
weights: torch.Tensor,
|
||||
):
|
||||
@@ -311,11 +452,9 @@ class SparseAttnIndexer(CustomOp):
|
||||
hidden_states,
|
||||
self.k_cache.prefix,
|
||||
self.k_cache.kv_cache[0],
|
||||
q_fp8,
|
||||
q,
|
||||
k,
|
||||
weights,
|
||||
self.quant_block_size,
|
||||
self.scale_fmt,
|
||||
self.topk_tokens,
|
||||
self.head_dim,
|
||||
self.max_model_len,
|
||||
|
||||
Reference in New Issue
Block a user