Upgrade to vllm 0.17.0 corex v4.1 overlay

2026-04-29 19:38:22 +08:00
parent 8fac6062e4
commit 938d0854a5
430 changed files with 35969 additions and 14511 deletions


@@ -15,6 +15,7 @@ from vllm.model_executor.layers.attention.mla_attention import (
from vllm.model_executor.layers.batch_invariant import (
vllm_is_batch_invariant,
)
from vllm.distributed.parallel_state import get_dcp_group
from vllm.platforms.interface import DeviceCapability
from vllm.v1.attention.backend import (
AttentionLayer,
@@ -22,20 +23,19 @@ from vllm.v1.attention.backend import (
is_quantized_kv_cache,
)
from vllm.v1.attention.ops.triton_decode_attention import decode_attention_fwd
import ixformer.inference.functions as ixf_ops
import vllm.envs as envs
from vllm import _custom_ops as ops
from vllm.distributed.parallel_state import get_dcp_group
logger = init_logger(__name__)
class TritonMLABackend(MLACommonBackend):
# supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16]
# supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = [
# "auto",
# "bfloat16",
# ]
supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16]
supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = [
"auto",
"bfloat16",
]
@staticmethod
def get_name() -> str:
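
For context, the hunk above switches the supported_dtypes and supported_kv_cache_dtypes declarations on TritonMLABackend from commented-out placeholders to active ClassVars. A minimal sketch of how such declarations could be consulted before picking a backend; the backend_supports helper is illustrative only, not vLLM's actual selection logic.

import torch

# Illustrative only: mirrors the ClassVar declarations above; vLLM's real
# backend-selection checks live elsewhere.
SUPPORTED_DTYPES = [torch.float16, torch.bfloat16]
SUPPORTED_KV_CACHE_DTYPES = ["auto", "bfloat16"]

def backend_supports(model_dtype: torch.dtype, kv_cache_dtype: str) -> bool:
    """Hypothetical helper: accept a config only if both dtypes are declared."""
    return (model_dtype in SUPPORTED_DTYPES
            and kv_cache_dtype in SUPPORTED_KV_CACHE_DTYPES)

assert backend_supports(torch.bfloat16, "auto")
assert not backend_supports(torch.float32, "fp8")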
@@ -120,10 +120,9 @@ class TritonMLAImpl(MLACommonImpl[MLACommonMetadata]):
q_pe: torch.Tensor,
kv_c_and_k_pe_cache: torch.Tensor,
attn_metadata: MLACommonMetadata,
# layer: AttentionLayer,
k_c_normed: torch.Tensor | None = None,
k_pe: torch.Tensor | None = None,
kv_c_and_k_pe_cache_scale: torch.Tensor | None = None,
k_c_normed: torch.Tensor | None,
k_pe: torch.Tensor | None,
kv_c_and_k_pe_cache_scale: torch.Tensor | None,
) -> tuple[torch.Tensor, torch.Tensor | None]:
assert kv_c_and_k_pe_cache.numel() > 0
assert attn_metadata.decode is not None
@@ -136,7 +135,7 @@ class TritonMLAImpl(MLACommonImpl[MLACommonMetadata]):
q_nope = q_nope.view(-1, self.num_heads, self.kv_lora_rank)
B = q_nope.shape[0]
if self.dcp_world_size > 1:
q = torch.cat([q_nope, q_pe], dim=-1)
q = get_dcp_group().all_gather(q, dim=1)
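
In the decode-context-parallel (DCP) branch above, the no-rope and rope query parts are concatenated and then all-gathered across the DCP group along the head dimension. A shape-only sketch, assuming the group's all_gather concatenates each rank's tensor along the given dim (simulated here with torch.cat instead of a real process group):

import torch

# Shape sketch only: stands in for get_dcp_group().all_gather(q, dim=1) with a
# plain torch.cat over per-rank copies; no process group is created.
B, num_heads, kv_lora_rank, rope_dim = 4, 16, 512, 64  # example sizes
dcp_world_size = 2

q_nope = torch.randn(B, num_heads, kv_lora_rank)
q_pe = torch.randn(B, num_heads, rope_dim)
q_local = torch.cat([q_nope, q_pe], dim=-1)  # [B, num_heads, kv_lora_rank + rope_dim]

# Concatenating along dim=1 grows the head dimension by dcp_world_size,
# which is the shape the subsequent paged-attention call sees.
q_gathered = torch.cat([q_local] * dcp_world_size, dim=1)
assert q_gathered.shape == (B, num_heads * dcp_world_size, kv_lora_rank + rope_dim)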
@@ -147,7 +146,7 @@ class TritonMLAImpl(MLACommonImpl[MLACommonMetadata]):
device=q_nope.device)
if envs.VLLM_USE_INT8_MLA:
q_int8, q_scale = ops.quant_kv(q)
attn_out, softmax_lse = ixf_ops.ref_vllm_paged_attention_mla_int8(
attn_out, softmax_lse = ixf_ops.vllm_paged_attention_mla_int8(
o,
q_int8,
q_scale,
@@ -160,7 +159,7 @@ class TritonMLAImpl(MLACommonImpl[MLACommonMetadata]):
return_softmax_lse=True
)
else:
attn_out, softmax_lse = ixf_ops.ref_vllm_paged_attention_mla(
attn_out, softmax_lse = ixf_ops.vllm_paged_attention_mla(
output=o,
query=q,
kv_cache=kv_c_and_k_pe_cache,
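
The two renamed calls above select between an int8 and a regular paged MLA kernel; in the int8 path, ops.quant_kv first turns the query into int8 values plus a scale. A hedged sketch of what such a quantization step produces; the symmetric per-tensor scheme below is an assumption, since quant_kv's actual granularity is not visible in this diff.

import torch

def quantize_int8_sketch(x: torch.Tensor):
    # Assumed per-tensor symmetric quantization; ops.quant_kv may scale
    # per token or per head instead.
    x_fp32 = x.float()
    scale = x_fp32.abs().amax() / 127.0
    x_int8 = (x_fp32 / scale).round().clamp(-127, 127).to(torch.int8)
    return x_int8, scale

q = torch.randn(4, 16, 576, dtype=torch.float16)
q_int8, q_scale = quantize_int8_sketch(q)
# The kernel then receives both tensors, as in
# ixf_ops.vllm_paged_attention_mla_int8(o, q_int8, q_scale, ...) above.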
@@ -170,12 +169,12 @@ class TritonMLAImpl(MLACommonImpl[MLACommonMetadata]):
max_context_len=decode_meta.max_decode_seq_len,
return_softmax_lse=True)
return attn_out, softmax_lse
o = torch.empty(B,
self.num_heads,
self.kv_lora_rank,
dtype=q_nope.dtype,
device=q_nope.device)
if envs.VLLM_USE_INT8_MLA:
q = torch.cat([q_nope, q_pe], dim=-1)
@@ -193,18 +192,30 @@ class TritonMLAImpl(MLACommonImpl[MLACommonMetadata]):
attn_metadata.decode.use_cuda_graph
)
else:
# fused q concat & cache write
ixf_ops.vllm_paged_attention_mla_fused(
output=o,
q_nope=q_nope,
q_pe=q_pe.contiguous(),
kv_cache=kv_c_and_k_pe_cache,
scale=self.scale,
block_tables=attn_metadata.decode.block_table,
context_lens=attn_metadata.decode.seq_lens,
max_context_len=decode_meta.max_decode_seq_len,
k_c_normed=k_c_normed,
k_pe=k_pe,
use_cuda_graph=decode_meta.use_cuda_graph
)
if k_c_normed is None:
q = torch.cat([q_nope, q_pe.contiguous()], dim=-1)
ixf_ops.vllm_paged_attention_mla(
output=o,
query=q,
kv_cache=kv_c_and_k_pe_cache,
scale=self.scale,
block_tables=attn_metadata.decode.block_table,
context_lens=attn_metadata.decode.seq_lens,
max_context_len=decode_meta.max_decode_seq_len,
use_cuda_graph=decode_meta.use_cuda_graph,
)
else:
ixf_ops.vllm_paged_attention_mla_fused(
output=o,
q_nope=q_nope.contiguous(),
q_pe=q_pe.contiguous(),
kv_cache=kv_c_and_k_pe_cache,
scale=self.scale,
block_tables=attn_metadata.decode.block_table,
context_lens=attn_metadata.decode.seq_lens,
max_context_len=decode_meta.max_decode_seq_len,
k_c_normed=k_c_normed,
k_pe=k_pe,
use_cuda_graph=decode_meta.use_cuda_graph,
)
return self._v_up_proj(o), None
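
Taken together, the added decode path now dispatches between two ixformer kernels: when no fresh k_c_normed / k_pe tensors are passed in, it concatenates the query halves and calls vllm_paged_attention_mla; otherwise it calls vllm_paged_attention_mla_fused, which (per the replaced comment) fuses the query concat and the cache write. A condensed sketch of that dispatch with the kernels stubbed out, since their exact semantics live in ixformer:

import torch

def run_plain_paged_mla(o, q, kv_cache, scale, decode_meta):
    ...  # stand-in for ixf_ops.vllm_paged_attention_mla(...)

def run_fused_paged_mla(o, q_nope, q_pe, kv_cache, scale, decode_meta,
                        k_c_normed, k_pe):
    ...  # stand-in for ixf_ops.vllm_paged_attention_mla_fused(...)

def mla_decode_dispatch(o, q_nope, q_pe, kv_cache, scale, decode_meta,
                        k_c_normed=None, k_pe=None):
    """Condensed sketch of the branch added above (kernels stubbed)."""
    if k_c_normed is None:
        # No new latent KV supplied: just concatenate the query halves and
        # run attention over the existing paged cache.
        q = torch.cat([q_nope, q_pe.contiguous()], dim=-1)
        run_plain_paged_mla(o, q, kv_cache, scale, decode_meta)
    else:
        # Fresh k_c_normed / k_pe supplied: the fused kernel concatenates the
        # query internally and writes the new entries into the cache in the
        # same launch.
        run_fused_paged_mla(o, q_nope.contiguous(), q_pe.contiguous(),
                            kv_cache, scale, decode_meta, k_c_normed, k_pe)
    return o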