Upgrade to vLLM 0.17.0 (corex v4.1 overlay)

@@ -13,7 +13,7 @@ from flashinfer import (
     BatchPrefillWithRaggedKVCacheWrapper,
     MultiLevelCascadeAttentionWrapper,
 )
-from flashinfer.decode import _get_range_buf, trtllm_batch_decode_with_kv_cache
+from flashinfer.decode import fast_decode_plan, trtllm_batch_decode_with_kv_cache
 from flashinfer.prefill import trtllm_batch_context_with_kv_cache
 from flashinfer.utils import FP4Tensor
 from typing_extensions import override
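
The only change in this hunk is the import swap: the overlay stops reaching for FlashInfer's private _get_range_buf helper and instead imports fast_decode_plan, which the rewritten fast_plan_decode at the bottom of the file delegates to. A minimal, hedged sketch of tolerating both FlashInfer generations at import time; the flag name is invented and not part of vLLM or FlashInfer:

import importlib

# Sketch only: fall back gracefully when the installed FlashInfer wheel does
# not export fast_decode_plan yet. HAVE_FAST_DECODE_PLAN is an invented name.
try:
    from flashinfer.decode import fast_decode_plan
    HAVE_FAST_DECODE_PLAN = True
except ImportError:
    fast_decode_plan = None
    HAVE_FAST_DECODE_PLAN = False
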
@@ -199,14 +199,14 @@ class BatchDCPPrefillWrapper:
     ):
         """Plan the prefill operation with given parameters."""
         self._context.plan(
-            qo_indptr_cpu,
-            paged_kv_indptr_cpu,
-            paged_kv_indices,
-            paged_kv_last_page_len_cpu,
-            num_qo_heads * dcp_world_size,
-            num_kv_heads,
-            head_dim,
-            page_size,
+            qo_indptr=qo_indptr_cpu,
+            paged_kv_indptr=paged_kv_indptr_cpu,
+            paged_kv_indices=paged_kv_indices,
+            paged_kv_last_page_len=paged_kv_last_page_len_cpu,
+            num_qo_heads=num_qo_heads * dcp_world_size,
+            num_kv_heads=num_kv_heads,
+            head_dim_qk=head_dim,
+            page_size=page_size,
             causal=False,  # This is context run
             sm_scale=sm_scale,
             window_left=window_left,
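
The call itself passes the same data; every positional argument becomes a keyword, which also spells out the head_dim_qk parameter name. For readers unfamiliar with the shapes involved, here is a small, CPU-only illustration of how the indptr and last-page-length tensors passed above are derived from per-request lengths; the numbers are made up and nothing below is vLLM code:

import torch

page_size = 16
q_lens = torch.tensor([3, 1, 7])      # new query tokens per request
kv_lens = torch.tensor([19, 4, 33])   # cached KV tokens per request

# Exclusive prefix sums: request i owns queries qo_indptr[i]:qo_indptr[i + 1].
qo_indptr_cpu = torch.zeros(len(q_lens) + 1, dtype=torch.int32)
qo_indptr_cpu[1:] = torch.cumsum(q_lens, dim=0)

# Number of KV pages per request, then the same prefix-sum trick.
pages_per_req = (kv_lens + page_size - 1) // page_size
paged_kv_indptr_cpu = torch.zeros(len(kv_lens) + 1, dtype=torch.int32)
paged_kv_indptr_cpu[1:] = torch.cumsum(pages_per_req, dim=0)

# Tokens occupied in each request's final page, in 1..page_size.
paged_kv_last_page_len_cpu = ((kv_lens - 1) % page_size + 1).to(torch.int32)

print(qo_indptr_cpu)               # [0, 3, 4, 11]
print(paged_kv_indptr_cpu)         # [0, 2, 3, 6]
print(paged_kv_last_page_len_cpu)  # [3, 4, 1]
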
@@ -374,13 +374,13 @@ class FlashInferBackend(AttentionBackend):

     @classmethod
     def get_required_kv_cache_layout(cls) -> KVCacheLayoutType | None:
         from vllm.platforms import current_platform

         capability = current_platform.get_device_capability()
         if capability is not None and capability.major == 10:
             return "HND"
         return None

+    forward_includes_kv_cache_update: bool = False


 @dataclass
 class FIPrefill:
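
For context, get_required_kv_cache_layout pins the KV cache to FlashInfer's HND layout (heads ahead of the per-page token dimension) on devices reporting compute capability major 10 and leaves the layout unconstrained elsewhere; the new forward_includes_kv_cache_update flag pairs with the do_kv_cache_update method added further down. A standalone sketch of an equivalent capability gate using only torch; it is not the vLLM implementation:

import torch

def required_kv_cache_layout() -> str | None:
    """Return "HND" on compute-capability-10.x GPUs, else no requirement."""
    if not torch.cuda.is_available():
        return None
    major, _minor = torch.cuda.get_device_capability()
    return "HND" if major == 10 else None

print(required_kv_cache_layout())
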
@@ -573,20 +573,6 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
         # try to use fp8 q if kv cache is fp8, and will fall back to model dtype
         # if TRTLLM attention kernel is not used when building attn metadata
         can_use_trtllm = can_use_trtllm_attention(self.num_qo_heads, self.num_kv_heads)
-
-        # TRTLLM attention requires strictly contiguous KV cache tensors.
-        # When KV transfer (P/D disaggregation) is enabled, the KV cache may be
-        # permuted into non-contiguous views, which causes assertion failures.
-        self._kv_transfer_enabled = vllm_config.kv_transfer_config is not None
-        if can_use_trtllm and self._kv_transfer_enabled:
-            logger.info_once(
-                "TRTLLM attention is disabled because KV transfer "
-                "(P/D disaggregation) is enabled. TRTLLM attention requires "
-                "strictly contiguous KV cache tensors which may not be "
-                "guaranteed with KV transfer."
-            )
-            can_use_trtllm = False
-
         if (
             can_use_trtllm
             and not vllm_config.attention_config.disable_flashinfer_q_quantization
@@ -816,6 +802,9 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
             page_size,
             paged_kv_last_page_len_np,
         )
+        self.paged_kv_last_page_len.gpu[:num_reqs].copy_(
+            self.paged_kv_last_page_len.cpu[:num_reqs], non_blocking=True
+        )
         return paged_kv_indices

     def build(
@@ -860,9 +849,6 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
             has_sinks=self.has_sinks,
             has_spec=uses_spec_reorder,
         )
-        # KV transfer requires non-contiguous KV cache views, incompatible with TRTLLM
-        if self._kv_transfer_enabled:
-            prefill_use_trtllm = False
        decode_use_trtllm = (
             self.use_trtllm_decode_attention and self.dcp_world_size <= 1
         )
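
This hunk and the one at line 573 delete the overlay-local workaround that force-disabled TRTLLM attention whenever KV transfer (P/D disaggregation) was configured; _kv_transfer_enabled disappears along with both places that consumed it. The dropped logic boils down to the small guard below, shown only to document what was removed; the helper name is invented:

def trtllm_allowed(can_use_trtllm: bool, kv_transfer_config: object | None) -> bool:
    """Old overlay behaviour: never use TRTLLM kernels when KV transfer is on."""
    if can_use_trtllm and kv_transfer_config is not None:
        # Non-contiguous KV cache views produced by P/D disaggregation would
        # trip TRTLLM's contiguity assertions.
        return False
    return can_use_trtllm
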
@@ -997,14 +983,17 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):

             attn_metadata.cascade_wrapper = self._get_cascade_wrapper()
             attn_metadata.cascade_wrapper.plan(
-                [shared_qo_indptr_cpu, qo_indptr_cpu],
-                [shared_kv_page_indptr_cpu, paged_kv_indptr_cpu],
-                [shared_kv_page_indices_cpu, paged_kv_indices],
-                [shared_kv_last_page_len_cpu, paged_kv_last_page_len_cpu],
-                self.num_qo_heads,
-                self.num_kv_heads,
-                self.head_dim,
-                self.page_size,
+                qo_indptr_arr=[shared_qo_indptr_cpu, qo_indptr_cpu],
+                paged_kv_indptr_arr=[shared_kv_page_indptr_cpu, paged_kv_indptr_cpu],
+                paged_kv_indices_arr=[shared_kv_page_indices_cpu, paged_kv_indices],
+                paged_kv_last_page_len=[
+                    shared_kv_last_page_len_cpu,
+                    paged_kv_last_page_len_cpu,
+                ],
+                num_qo_heads=self.num_qo_heads,
+                num_kv_heads=self.num_kv_heads,
+                head_dim=self.head_dim,
+                page_size=self.page_size,
                 causal=True,
                 sm_scale=self.sm_scale,
                 window_left=self.window_left,
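
The cascade plan() call keeps the same data but names the two-level list parameters explicitly: index zero of each list describes the shared prefix addressed once for the whole batch, index one the per-request remainders. A toy, CPU-only illustration of that nesting, with invented values:

import torch

# Level 0: the shared prefix, addressed as a single range over the whole batch
# (5 query tokens in total).
shared_qo_indptr_cpu = torch.tensor([0, 5], dtype=torch.int32)
# Level 1: the same 5 query tokens split across 2 requests.
qo_indptr_cpu = torch.tensor([0, 2, 5], dtype=torch.int32)

qo_indptr_arr = [shared_qo_indptr_cpu, qo_indptr_cpu]
# paged_kv_indptr_arr, paged_kv_indices_arr and paged_kv_last_page_len use the
# same [shared_level, unique_level] nesting in the call above.
print(qo_indptr_arr)
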
@@ -1082,14 +1071,14 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
                 BatchPrefillWithPagedKVCacheWrapper,
             )
             prefill_wrapper.plan(
-                qo_indptr_prefill_cpu,
-                paged_kv_indptr_prefill_cpu,
-                paged_kv_indices,
-                paged_kv_last_page_len_prefill_cpu,
-                self.num_qo_heads,
-                self.num_kv_heads,
-                self.head_dim,
-                self.page_size,
+                qo_indptr=qo_indptr_prefill_cpu,
+                paged_kv_indptr=paged_kv_indptr_prefill_cpu,
+                paged_kv_indices=paged_kv_indices,
+                paged_kv_last_page_len=paged_kv_last_page_len_prefill_cpu,
+                num_qo_heads=self.num_qo_heads,
+                num_kv_heads=self.num_kv_heads,
+                head_dim_qk=self.head_dim,
+                page_size=self.page_size,
                 causal=True,
                 sm_scale=self.sm_scale,
                 window_left=self.window_left,
@@ -1130,14 +1119,15 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
             # in atten_metadata when using cudagraph.
             fast_plan_decode(
                 decode_wrapper,
-                self.paged_kv_indptr.cpu[: num_input_tokens + 1],
-                paged_kv_indices,
-                self.paged_kv_last_page_len.cpu[:num_input_tokens],
-                seq_lens_cpu[:num_input_tokens],
-                self.num_qo_heads * self.dcp_world_size,
-                self.num_kv_heads,
-                self.head_dim,
-                self.page_size,
+                indptr_cpu=self.paged_kv_indptr.cpu[: num_input_tokens + 1],
+                indices=paged_kv_indices,
+                last_page_len_cpu=self.paged_kv_last_page_len.cpu[
+                    :num_input_tokens
+                ],
+                num_qo_heads=self.num_qo_heads * self.dcp_world_size,
+                num_kv_heads=self.num_kv_heads,
+                head_dim=self.head_dim,
+                page_size=self.page_size,
                 # Disable flashinfer's pos encoding and use vllm's rope.
                 pos_encoding_mode="NONE",
                 sm_scale=self.sm_scale,
@@ -1330,32 +1320,15 @@ class FlashInferImpl(AttentionImpl):

         num_actual_tokens = attn_metadata.num_actual_tokens

-        if self.kv_sharing_target_layer_name is None:
-            # Reshape the input keys and values and store them in the cache.
-            # Skip this if sharing KV cache with an earlier attention layer.
-            # NOTE(woosuk): Here, key and value are padded while slot_mapping is
-            # not padded. However, we don't need to do key[:num_actual_tokens]
-            # and value[:num_actual_tokens] because the reshape_and_cache_flash
-            # op uses the slot_mapping's shape to determine the number of
-            # actual tokens.
-            torch.ops._C_cache_ops.reshape_and_cache_flash(
-                key,
-                value,
-                kv_cache[:, 0],
-                kv_cache[:, 1],
-                attn_metadata.slot_mapping,
-                self.kv_cache_dtype,
-                layer._k_scale,
-                layer._v_scale,
+        # The FlashInfer api requires data to be in fp8_e4m3 or fp8_e5m2
+        # to process the cache when the kv_cache_dtype is fp8
+        if self.kv_sharing_target_layer_name is None and self.kv_cache_dtype.startswith(
+            "fp8"
+        ):
+            torch_dtype = FlashInferBackend.get_fp8_dtype_for_flashinfer(
+                self.kv_cache_dtype
             )
-
-            # The FlashInfer api requires data to be in fp8_e4m3 or fp8_e5m2
-            # to process the cache when the kv_cache_dtype is fp8
-            if self.kv_cache_dtype.startswith("fp8"):
-                torch_dtype = FlashInferBackend.get_fp8_dtype_for_flashinfer(
-                    self.kv_cache_dtype
-                )
-                kv_cache = kv_cache.view(torch_dtype)
+            kv_cache = kv_cache.view(torch_dtype)

         # Inputs and outputs may be padded for CUDA graphs
         query = query[:num_actual_tokens]
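
After this change forward() no longer writes K/V into the paged cache at all; it only reinterprets the cache buffer as an fp8 dtype when the cache is quantized, and the write itself moves to the new do_kv_cache_update method below. The reinterpretation is a zero-copy dtype view, illustrated here with toy shapes; it assumes a torch build that ships float8 dtypes:

import torch

# A uint8-backed paged KV cache with made-up dimensions
# (num_blocks, 2, page_size, num_kv_heads, head_dim).
kv_cache = torch.zeros(4, 2, 16, 2, 64, dtype=torch.uint8)

# Same storage, reinterpreted element-for-element as fp8_e4m3; no copy is made.
kv_cache_fp8 = kv_cache.view(torch.float8_e4m3fn)
print(kv_cache_fp8.dtype, kv_cache_fp8.shape)
print(kv_cache_fp8.data_ptr() == kv_cache.data_ptr())  # True: shared storage
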
@@ -1599,13 +1572,39 @@ class FlashInferImpl(AttentionImpl):
         )
         return output_padded

+    def do_kv_cache_update(
+        self,
+        layer: torch.nn.Module,
+        key: torch.Tensor,
+        value: torch.Tensor,
+        kv_cache: torch.Tensor,
+        slot_mapping: torch.Tensor,
+    ) -> None:
+        if self.kv_sharing_target_layer_name is None:
+            # Reshape the input keys and values and store them in the cache.
+            # Skip this if sharing KV cache with an earlier attention layer.
+            # NOTE(woosuk): Here, key and value are padded while slot_mapping is
+            # not padded. However, we don't need to do key[:num_actual_tokens]
+            # and value[:num_actual_tokens] because the reshape_and_cache_flash
+            # op uses the slot_mapping's shape to determine the number of
+            # actual tokens.
+            torch.ops._C_cache_ops.reshape_and_cache_flash(
+                key,
+                value,
+                kv_cache[:, 0],
+                kv_cache[:, 1],
+                slot_mapping,
+                self.kv_cache_dtype,
+                layer._k_scale,
+                layer._v_scale,
+            )
+

 def fast_plan_decode(
     self,  # decode wrapper
     indptr_cpu: torch.Tensor,
     indices: torch.Tensor,
     last_page_len_cpu: torch.Tensor,
     seq_lens_cpu: torch.Tensor,
     num_qo_heads: int,
     num_kv_heads: int,
     head_dim: int,
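
do_kv_cache_update is essentially the cache-write block that used to sit at the top of forward(), extracted so it can be invoked on its own (paired with forward_includes_kv_cache_update = False further up). A hedged sketch of the call order this split implies for a caller; the function below is illustrative and not the actual vLLM runner loop:

def run_attention_layer(impl, layer, query, key, value, kv_cache, attn_metadata, output):
    # 1. Write the new K/V tokens into the paged cache first (skipped internally
    #    when the layer shares its cache with an earlier layer, see above).
    impl.do_kv_cache_update(layer, key, value, kv_cache, attn_metadata.slot_mapping)
    # 2. Run attention; with forward_includes_kv_cache_update = False the forward
    #    pass assumes the cache is already up to date.
    return impl.forward(layer, query, key, value, kv_cache, attn_metadata, output=output)
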
@@ -1642,110 +1641,56 @@ def fast_plan_decode(
     # this warm up is to generate the _cached_module for the decode wrapper.
     if not self.is_cuda_graph_enabled or getattr(self, "vllm_first_call", True):
         self.plan(
-            indptr_cpu,
-            indices,
-            last_page_len_cpu,
-            num_qo_heads,
-            num_kv_heads,
-            head_dim,
-            page_size,
-            pos_encoding_mode,
-            window_left,
-            logits_soft_cap,
-            q_data_type,
-            kv_data_type,
-            o_data_type,
-            data_type,
-            sm_scale,
-            rope_scale,
-            rope_theta,
-            non_blocking,
-            None,  # block_tables
-            None,  # seq_lens
-            fixed_split_size,
-            disable_split_kv,
+            indptr=indptr_cpu,
+            indices=indices,
+            last_page_len=last_page_len_cpu,
+            num_qo_heads=num_qo_heads,
+            num_kv_heads=num_kv_heads,
+            head_dim=head_dim,
+            page_size=page_size,
+            pos_encoding_mode=pos_encoding_mode,
+            window_left=window_left,
+            logits_soft_cap=logits_soft_cap,
+            q_data_type=q_data_type,
+            kv_data_type=kv_data_type,
+            o_data_type=o_data_type,
+            data_type=data_type,
+            sm_scale=sm_scale,
+            rope_scale=rope_scale,
+            rope_theta=rope_theta,
+            non_blocking=non_blocking,
+            block_tables=None,
+            seq_lens=None,
+            fixed_split_size=fixed_split_size,
+            disable_split_kv=disable_split_kv,
         )
         self.vllm_first_call = False
         return

     assert self.is_cuda_graph_enabled, "Should be cudagraph only here"

-    batch_size = len(last_page_len_cpu)
-    if logits_soft_cap is None:
-        logits_soft_cap = 0.0
-
-    # Handle data types consistently
-    if data_type is not None:
-        if q_data_type is None:
-            q_data_type = data_type
-        if kv_data_type is None:
-            kv_data_type = data_type
-    elif q_data_type is None:
-        q_data_type = "float16"
-
-    if kv_data_type is None:
-        kv_data_type = q_data_type
-    q_data_type = (
-        getattr(torch, q_data_type) if isinstance(q_data_type, str) else q_data_type
+    fast_decode_plan(
+        self,
+        indptr=indptr_cpu,
+        indices=indices,
+        last_page_len=last_page_len_cpu,
+        num_qo_heads=num_qo_heads,
+        num_kv_heads=num_kv_heads,
+        head_dim=head_dim,
+        page_size=page_size,
+        pos_encoding_mode=pos_encoding_mode,
+        window_left=window_left,
+        logits_soft_cap=logits_soft_cap,
+        q_data_type=q_data_type,
+        kv_data_type=kv_data_type,
+        data_type=data_type,
+        sm_scale=sm_scale,
+        rope_scale=rope_scale,
+        rope_theta=rope_theta,
+        non_blocking=non_blocking,
+        fixed_split_size=fixed_split_size,
+        disable_split_kv=disable_split_kv,
     )
-    kv_data_type = (
-        getattr(torch, kv_data_type) if isinstance(kv_data_type, str) else kv_data_type
-    )
-
-    if batch_size != self._fixed_batch_size:
-        raise ValueError(
-            "The batch size should be fixed in cudagraph mode, the runtime "
-            "batch size {} mismatches the batch size set during "
-            "initialization {}".format(batch_size, self._fixed_batch_size)
-        )
-    if len(indices) > len(self._paged_kv_indices_buf):
-        raise ValueError(
-            "The size of indices should be less than or equal to the allocated buffer"
-        )
-
-    # host-to-device copy for the indptr buffer
-    self._paged_kv_indptr_buf.copy_(indptr_cpu, non_blocking=True)
-    # host-to-device copy for the last_page_len buffer
-    self._paged_kv_last_page_len_buf.copy_(last_page_len_cpu, non_blocking=True)
-
-    qo_indptr_host = _get_range_buf(batch_size + 1, "cpu")
-
-    try:
-        # Make sure we pass exactly 19 arguments for tensor core version
-        args = [
-            self._float_workspace_buffer,
-            self._int_workspace_buffer,
-            self._pin_memory_int_workspace_buffer,
-            qo_indptr_host,
-            indptr_cpu,
-            seq_lens_cpu,
-            batch_size,  # total_num_rows
-            batch_size,
-            num_qo_heads,
-            num_kv_heads,
-            page_size,
-            self.is_cuda_graph_enabled,
-            head_dim,
-            head_dim,
-            False,  # causal
-            window_left,
-        ]
-        if self._backend == "fa2":
-            args.append(fixed_split_size)
-            args.append(disable_split_kv)
-            args.append(0)  # num_colocated_ctas
-        self._plan_info = self._cached_module.plan(
-            *args,
-        )
-    except Exception as e:
-        raise RuntimeError(f"Error in tensor core plan: {e}") from e
-
-    self._pos_encoding_mode = pos_encoding_mode
-    self._window_left = window_left
-    self._logits_soft_cap = logits_soft_cap
-    self._sm_scale = sm_scale
-    self._rope_scale = rope_scale
-    self._rope_theta = rope_theta


 @triton.jit
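
The rewritten body keeps the one-time eager plan() warm-up (gated by vllm_first_call) and then hands the CUDA-graph replan entirely to FlashInfer's fast_decode_plan, instead of duplicating its dtype handling, buffer copies and _cached_module.plan bookkeeping. The control flow reduces to the generic pattern below; only vllm_first_call comes from the code, everything else is illustrative:

def plan_or_replan(wrapper, full_plan, fast_replan, *args, **kwargs):
    # First call (or no CUDA graphs): run a full plan to build the cached
    # kernel module, then remember that the warm-up has happened.
    if not getattr(wrapper, "is_cuda_graph_enabled", False) or getattr(
        wrapper, "vllm_first_call", True
    ):
        full_plan(*args, **kwargs)
        wrapper.vllm_first_call = False
        return
    # Subsequent calls under CUDA graph capture/replay: cheap replan only.
    fast_replan(wrapper, *args, **kwargs)
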