Upgrade to vLLM 0.17.0 (CoreX v4.1 overlay)
@@ -15,6 +15,7 @@ from vllm.model_executor.layers.attention.mla_attention import (
 from vllm.model_executor.layers.batch_invariant import (
     vllm_is_batch_invariant,
 )
+from vllm.distributed.parallel_state import get_dcp_group
 from vllm.platforms.interface import DeviceCapability
 from vllm.v1.attention.backend import (
     AttentionLayer,
@@ -22,20 +23,19 @@ from vllm.v1.attention.backend import (
     is_quantized_kv_cache,
 )
 from vllm.v1.attention.ops.triton_decode_attention import decode_attention_fwd

 import ixformer.inference.functions as ixf_ops
 import vllm.envs as envs
 from vllm import _custom_ops as ops
-from vllm.distributed.parallel_state import get_dcp_group

 logger = init_logger(__name__)


 class TritonMLABackend(MLACommonBackend):
-    # supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16]
-    # supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = [
-    #     "auto",
-    #     "bfloat16",
-    # ]
+    supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16]
+    supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = [
+        "auto",
+        "bfloat16",
+    ]

     @staticmethod
     def get_name() -> str:
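Review note: un-commenting these ClassVars makes the dtype restrictions live on the backend class. A minimal sketch of how such declarations can gate backend choice (illustrative; `backend_supports` is a hypothetical helper, not a vLLM API):

```python
import torch

def backend_supports(backend_cls, dtype: torch.dtype, kv_cache_dtype: str) -> bool:
    # Plain membership tests against the class-level declarations above.
    return (dtype in backend_cls.supported_dtypes
            and kv_cache_dtype in backend_cls.supported_kv_cache_dtypes)

# e.g. backend_supports(TritonMLABackend, torch.float16, "auto") would be True
# given the lists declared in this hunk.
```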
@@ -120,10 +120,9 @@ class TritonMLAImpl(MLACommonImpl[MLACommonMetadata]):
         q_pe: torch.Tensor,
         kv_c_and_k_pe_cache: torch.Tensor,
         attn_metadata: MLACommonMetadata,
-        # layer: AttentionLayer,
-        k_c_normed: torch.Tensor | None = None,
-        k_pe: torch.Tensor | None = None,
-        kv_c_and_k_pe_cache_scale: torch.Tensor | None = None,
+        k_c_normed: torch.Tensor | None,
+        k_pe: torch.Tensor | None,
+        kv_c_and_k_pe_cache_scale: torch.Tensor | None,
     ) -> tuple[torch.Tensor, torch.Tensor | None]:
         assert kv_c_and_k_pe_cache.numel() > 0
         assert attn_metadata.decode is not None
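Review note: dropping the `= None` defaults makes all three KV arguments required at every call site, so a caller that forgets to thread them through now fails loudly instead of silently defaulting. A hypothetical call site (the diff does not show the enclosing method's name; `_forward_decode` and the variable names here are assumptions):

```python
attn_out, lse = impl._forward_decode(
    q_nope,
    q_pe,
    kv_c_and_k_pe_cache,
    attn_metadata,
    k_c_normed,                 # may still be None (non-fused path)
    k_pe,                       # may still be None (non-fused path)
    kv_c_and_k_pe_cache_scale,  # may still be None (unquantized cache)
)
```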
@@ -136,7 +135,7 @@ class TritonMLAImpl(MLACommonImpl[MLACommonMetadata]):
         q_nope = q_nope.view(-1, self.num_heads, self.kv_lora_rank)

         B = q_nope.shape[0]
-
+
         if self.dcp_world_size > 1:
             q = torch.cat([q_nope, q_pe], dim=-1)
             q = get_dcp_group().all_gather(q, dim=1)
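Review note on the DCP branch: `all_gather(q, dim=1)` widens the head dimension by the group size, which is why `q_nope` and `q_pe` are concatenated into a single query first. A shape-level sketch (shapes are my reading, not stated in the diff):

```python
# q_nope: [B, H, kv_lora_rank]    q_pe: [B, H, rope_dim]
q = torch.cat([q_nope, q_pe], dim=-1)      # [B, H, kv_lora_rank + rope_dim]
q = get_dcp_group().all_gather(q, dim=1)   # [B, H * dcp_world_size, ...]
```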
@@ -147,7 +146,7 @@ class TritonMLAImpl(MLACommonImpl[MLACommonMetadata]):
                             device=q_nope.device)
             if envs.VLLM_USE_INT8_MLA:
                 q_int8, q_scale = ops.quant_kv(q)
-                attn_out, softmax_lse = ixf_ops.ref_vllm_paged_attention_mla_int8(
+                attn_out, softmax_lse = ixf_ops.vllm_paged_attention_mla_int8(
                     o,
                     q_int8,
                     q_scale,
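Review note: the int8 path quantizes the query before the kernel launch. `ops.quant_kv` is opaque here; the sketch below shows symmetric per-tensor int8 quantization as a conceptual stand-in (the real op may quantize per token or per channel and return differently shaped scales):

```python
import torch

def quant_int8_per_tensor(x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    # Map the max magnitude to 127; the clamp avoids division by zero
    # on an all-zero input.
    scale = x.abs().amax().clamp(min=1e-8) / 127.0
    q = torch.clamp((x / scale).round(), -128, 127).to(torch.int8)
    return q, scale
```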
@@ -160,7 +159,7 @@ class TritonMLAImpl(MLACommonImpl[MLACommonMetadata]):
                     return_softmax_lse=True
                 )
             else:
-                attn_out, softmax_lse = ixf_ops.ref_vllm_paged_attention_mla(
+                attn_out, softmax_lse = ixf_ops.vllm_paged_attention_mla(
                     output=o,
                     query=q,
                     kv_cache=kv_c_and_k_pe_cache,
@@ -170,12 +169,12 @@ class TritonMLAImpl(MLACommonImpl[MLACommonMetadata]):
                     max_context_len=decode_meta.max_decode_seq_len,
                     return_softmax_lse=True)
             return attn_out, softmax_lse
-
+
         o = torch.empty(B,
                         self.num_heads,
                         self.kv_lora_rank,
                         dtype=q_nope.dtype,
-                        device=q_nope.device)
+                        device=q_nope.device)

         if envs.VLLM_USE_INT8_MLA:
             q = torch.cat([q_nope, q_pe], dim=-1)
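Review note: the kernel variants in the next hunk fill this preallocated `o` through their `output=` argument rather than returning a fresh tensor; the buffer is sized in the compressed KV space (`kv_lora_rank` wide) and expanded by `_v_up_proj` at the end. A minimal sketch of the out-parameter pattern (sizes are examples, not from the diff):

```python
import torch

B, num_heads, kv_lora_rank = 4, 16, 512    # example sizes
o = torch.empty(B, num_heads, kv_lora_rank, dtype=torch.float16)
# A kernel invoked as kernel(..., output=o) writes its results into `o`
# in place, so no second result tensor is allocated inside the kernel.
```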
@@ -193,18 +192,30 @@ class TritonMLAImpl(MLACommonImpl[MLACommonMetadata]):
                 attn_metadata.decode.use_cuda_graph
             )
         else:
-            # fused q concat & cache write
-            ixf_ops.vllm_paged_attention_mla_fused(
-                output=o,
-                q_nope=q_nope,
-                q_pe=q_pe.contiguous(),
-                kv_cache=kv_c_and_k_pe_cache,
-                scale=self.scale,
-                block_tables=attn_metadata.decode.block_table,
-                context_lens=attn_metadata.decode.seq_lens,
-                max_context_len=decode_meta.max_decode_seq_len,
-                k_c_normed=k_c_normed,
-                k_pe=k_pe,
-                use_cuda_graph=decode_meta.use_cuda_graph
-            )
+            if k_c_normed is None:
+                q = torch.cat([q_nope, q_pe.contiguous()], dim=-1)
+                ixf_ops.vllm_paged_attention_mla(
+                    output=o,
+                    query=q,
+                    kv_cache=kv_c_and_k_pe_cache,
+                    scale=self.scale,
+                    block_tables=attn_metadata.decode.block_table,
+                    context_lens=attn_metadata.decode.seq_lens,
+                    max_context_len=decode_meta.max_decode_seq_len,
+                    use_cuda_graph=decode_meta.use_cuda_graph,
+                )
+            else:
+                ixf_ops.vllm_paged_attention_mla_fused(
+                    output=o,
+                    q_nope=q_nope.contiguous(),
+                    q_pe=q_pe.contiguous(),
+                    kv_cache=kv_c_and_k_pe_cache,
+                    scale=self.scale,
+                    block_tables=attn_metadata.decode.block_table,
+                    context_lens=attn_metadata.decode.seq_lens,
+                    max_context_len=decode_meta.max_decode_seq_len,
+                    k_c_normed=k_c_normed,
+                    k_pe=k_pe,
+                    use_cuda_graph=decode_meta.use_cuda_graph,
+                )
         return self._v_up_proj(o), None
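Review note: the net effect of this last hunk is a three-way decode dispatch on the non-DCP path; a condensed, runnable summary (bodies elided, kernel names mirror the diff):

```python
def choose_decode_kernel(use_int8_mla: bool, k_c_normed) -> str:
    if use_int8_mla:
        return "vllm_paged_attention_mla_int8"   # quantized-q path
    if k_c_normed is None:
        return "vllm_paged_attention_mla"        # plain paged MLA
    return "vllm_paged_attention_mla_fused"      # fused q concat + cache write
```

Per the removed `# fused q concat & cache write` comment, the fused variant folds the query concat and the `k_c_normed`/`k_pe` cache write into the attention kernel itself; the `k_c_normed is None` branch presumably covers the case where the cache has already been populated elsewhere.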