diff --git a/README.md b/README.md
index 64e8424..26f500c 100644
--- a/README.md
+++ b/README.md
@@ -125,6 +125,30 @@ By utilizing the vLLM Kunlun plugin, popular open-source models, including Trans
|
|
+
+ | DeepSeek-R1 |
+ ✅ |
+ ✅ |
+ |
+ ✅ |
+ |
+
+
+ | DeepSeek-V3 |
+ ✅ |
+ ✅ |
+ |
+ ✅ |
+ |
+
+
+ | DeepSeek-V3.2 |
+ ✅ |
+ ✅ |
+ |
+ ✅ |
+ |
+
diff --git a/vllm_kunlun/__init__.py b/vllm_kunlun/__init__.py
index d0124d5..f4fbbbb 100644
--- a/vllm_kunlun/__init__.py
+++ b/vllm_kunlun/__init__.py
@@ -10,34 +10,15 @@ import vllm.envs as envs
OLD_IMPORT_HOOK = builtins.__import__
def _custom_import(module_name, globals=None, locals=None, fromlist=(), level=0):
try:
- start_time = time.time()
-
- # 模块映射表
module_mappings = {
- "vllm.model_executor.layers.fused_moe.layer": "vllm_kunlun.ops.fused_moe.layer",
- "vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors_moe": "vllm_kunlun.ops.quantization.compressed_tensors_moe",
"vllm.compilation.wrapper": "vllm_kunlun.compilation.wrapper",
- "vllm.v1.worker.gpu_model_runner": "vllm_kunlun.v1.worker.gpu_model_runner"
+ "vllm.v1.worker.utils": "vllm_kunlun.v1.worker.utils",
+ "vllm.model_executor.model_loader.bitsandbytes_loader": "vllm_kunlun.models.model_loader.bitsandbytes_loader",
+ "vllm.v1.sample.ops.topk_topp_sampler": "vllm_kunlun.v1.sample.ops.topk_topp_sampler",
+ "vllm.model_executor.layers.sampler": "vllm_kunlun.ops.sample.sampler",
+ # NOTE: duplicate "vllm.v1.sample.ops.topk_topp_sampler" mapping removed (already listed above)
}
- # 需要保持原始导入的模块
- original_imports = [
- "vllm.model_executor.layers.fused_moe.base",
- "vllm.model_executor.layers.fused_moe.config",
- "vllm.model_executor.layers.fused_moe.layer"
- ]
-
- if module_name in original_imports:
- if module_name == "vllm.model_executor.layers.fused_moe.layer" and fromlist:
- if "FusedMoEMethodBase" in fromlist:
- return OLD_IMPORT_HOOK(
- module_name,
- globals=globals,
- locals=locals,
- fromlist=fromlist,
- level=level
- )
-
if module_name in module_mappings:
if module_name in sys.modules:
return sys.modules[module_name]
@@ -45,25 +26,6 @@ def _custom_import(module_name, globals=None, locals=None, fromlist=(), level=0)
module = importlib.import_module(target_module)
sys.modules[module_name] = module
sys.modules[target_module] = module
- return module
-
- relative_mappings = {
- ("compressed_tensors_moe", "compressed_tensors"): "vllm_kunlun.ops.quantization.compressed_tensors_moe",
- ("layer", "fused_moe"): "vllm_kunlun.ops.fused_moe.layer",
- }
-
- if level == 1:
- parent = globals.get('__package__', '').split('.')[-1] if globals else ''
- key = (module_name, parent)
- if key in relative_mappings:
- if module_name in sys.modules:
- return sys.modules[module_name]
- target_module = relative_mappings[key]
- module = importlib.import_module(target_module)
- sys.modules[module_name] = module
- sys.modules[target_module] = module
- return module
-
except Exception:
pass
@@ -77,79 +39,16 @@ def _custom_import(module_name, globals=None, locals=None, fromlist=(), level=0)
def import_hook():
"""Apply import hook for VLLM Kunlun"""
- if not int(os.environ.get("DISABLE_KUNLUN_HOOK", "0")):
- builtins.__import__ = _custom_import
-
- try:
- modules_to_preload = [
- "vllm_kunlun.ops.quantization.compressed_tensors_moe",
- "vllm_kunlun.ops.fused_moe.custom_ops",
- "vllm_kunlun.ops.fused_moe.layer",
- "vllm_kunlun.ops.quantization.fp8",
- ]
- for module_name in modules_to_preload:
- importlib.import_module(module_name)
- except Exception:
- pass
+ builtins.__import__ = _custom_import
def register():
"""Register the Kunlun platform"""
from .utils import redirect_output
from .vllm_utils_wrapper import direct_register_custom_op, patch_annotations_for_schema
- patch_bitsandbytes_loader()
import_hook()
- if envs.VLLM_USE_V1:
- # patch_V1blockTable()
- patch_V1top_p_K()
- # TODO fixed fast top & k for vLLM 0.10.2,
- pass
- else:
- patch_sampler()
return "vllm_kunlun.platforms.kunlun.KunlunPlatform"
def register_model():
"""Register models for training and inference"""
from .models import register_model as _reg
- _reg()
-
-# [monkey patach sampler]
-import sys
-import sys, importlib, warnings
-
-def patch_bitsandbytes_loader():
- try:
- # 载入你插件里自定义的 direct_register_custom_op 实现
- custom_utils = importlib.import_module("vllm_kunlun.models.model_loader.bitsandbytes_loader")
- # 覆盖 vllm.utils
- sys.modules["vllm.model_executor.model_loader.bitsandbytes_loader"] = custom_utils
- print("[vllm_kunlun] bitsandbytes_loader patched ->", custom_utils.__file__)
- except Exception as e:
- warnings.warn(f"[vllm_kunlun] bitsandbytes_loader patch failed: {e!r}")
-
-def patch_sampler():
- try:
- custom_sampler = importlib.import_module("vllm_kunlun.ops.sample.sampler")
- sys.modules["vllm.model_executor.layers.sampler"] = custom_sampler
- print("[vllm_kunlun] sampler patched ->", custom_sampler.__file__)
- except Exception as e:
- warnings.warn(f"[vllm_kunlun] sampler patch failed: {e!r}")
-
-
-def patch_V1top_p_K():
- try:
- custom_sampler = importlib.import_module("vllm_kunlun.v1.sample.ops.topk_topp_sampler")
- sys.modules["vllm.v1.sample.ops.topk_topp_sampler"] = custom_sampler
- print("[vllm_kunlun] V1sampler top p & k patched ->", custom_sampler.__file__)
- except Exception as e:
- warnings.warn(f"[vllm_kunlun] V1 sampler top p & k patch failed: {e!r}")
-
-def patch_V1blockTable():
- try:
- custom_sampler = importlib.import_module("vllm_kunlun.v1.worker.block_table")
- sys.modules["vllm.v1.worker.block_table"] = custom_sampler
- print("[vllm_kunlun] V1 block table patched ->", custom_sampler.__file__)
- except Exception as e:
- warnings.warn(f"[vllm_kunlun] V1 block table patch failed: {e!r}")
-
-# 在模块导入时自动应用补丁
-import_hook()
+ _reg()
\ No newline at end of file
diff --git a/vllm_kunlun/models/__init__.py b/vllm_kunlun/models/__init__.py
index 5dd90ec..9fd12c5 100644
--- a/vllm_kunlun/models/__init__.py
+++ b/vllm_kunlun/models/__init__.py
@@ -80,6 +80,14 @@ def register_model():
ModelRegistry.register_model(
"GptOssForCausalLM",
"vllm_kunlun.models.gpt_oss:GptOssForCausalLM")
-
+
+ ModelRegistry.register_model(
+ "DeepseekV3ForCausalLM",
+ "vllm_kunlun.models.deepseek_v2:DeepseekV3ForCausalLM")
+
+ ModelRegistry.register_model(
+ "DeepseekV32ForCausalLM",
+ "vllm_kunlun.models.deepseek_v2:DeepseekV3ForCausalLM")
+
def register_quant_method():
"""to do"""
diff --git a/vllm_kunlun/models/deepseek_v2.py b/vllm_kunlun/models/deepseek_v2.py
new file mode 100644
index 0000000..1c3c11a
--- /dev/null
+++ b/vllm_kunlun/models/deepseek_v2.py
@@ -0,0 +1,1445 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+# Adapted from
+# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
+# Copyright 2023 The vLLM team.
+# Copyright 2023 DeepSeek-AI and the HuggingFace Inc. team. All rights reserved.
+#
+# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
+# and OPT implementations in this library. It has been modified from its
+# original forms to accommodate minor architectural differences compared
+# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Inference-only DeepseekV2/DeepseekV3 model."""
+import typing
+from collections.abc import Callable, Iterable
+from itertools import islice
+from typing import Any, Optional, Union
+
+import torch
+from torch.library import custom_op
+from torch import nn
+from transformers import DeepseekV2Config, DeepseekV3Config
+
+from vllm_kunlun.ops.attention.layer import Attention
+from vllm.attention.backends.abstract import AttentionBackend
+from vllm.attention.ops.common import pack_seq_triton, unpack_seq_triton
+from vllm.compilation.decorators import support_torch_compile
+from vllm.config import (CacheConfig, ParallelConfig, VllmConfig,
+ get_current_vllm_config)
+from vllm.distributed import (get_ep_group, get_pp_group,
+ get_tensor_model_parallel_rank,
+ get_tensor_model_parallel_world_size,
+ tensor_model_parallel_all_gather)
+from vllm.forward_context import get_forward_context
+from vllm.logger import init_logger
+from vllm_kunlun.ops.activation import SiluAndMul
+from vllm.model_executor.layers.fused_moe import FusedMoE
+from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
+from vllm.model_executor.layers.layernorm import LayerNorm, RMSNorm
+from vllm.model_executor.layers.linear import (ColumnParallelLinear,
+ MergedColumnParallelLinear,
+ RowParallelLinear)
+from vllm_kunlun.ops.linear import ReplicatedLinear
+from vllm.model_executor.layers.logits_processor import LogitsProcessor
+from vllm_kunlun.ops.attention.mla import MLAModules, MultiHeadLatentAttention
+from vllm.model_executor.layers.quantization import QuantizationConfig
+from vllm.model_executor.layers.rotary_embedding import get_rope
+from vllm.model_executor.layers.shared_fused_moe import SharedFusedMoE
+from vllm.model_executor.layers.vocab_parallel_embedding import (
+ ParallelLMHead, VocabParallelEmbedding)
+from vllm.model_executor.model_loader.weight_utils import (
+ default_weight_loader, maybe_remap_kv_scale_name)
+from vllm.model_executor.models.utils import sequence_parallel_chunk
+from vllm.platforms import current_platform
+from vllm.sequence import IntermediateTensors
+from vllm.utils import cdiv, direct_register_custom_op
+from vllm_kunlun.ops.deep_gemm import int8_mqa_logits, int8_paged_mqa_logits
+from vllm.v1.attention.backends.mla.indexer import DeepseekV32IndexerBackend
+from vllm_kunlun.v1.attention.backends.mla.indexer import DeepseekV32IndexerMetadata
+from vllm.v1.kv_cache_interface import KVCacheSpec, MLAAttentionSpec
+
+from vllm.model_executor.models.interfaces import MixtureOfExperts, SupportsLoRA, SupportsPP
+from vllm.model_executor.models.utils import (PPMissingLayer, is_pp_missing_parameter,
+ make_empty_intermediate_tensors_factory, make_layers,
+ maybe_prefix)
+from vllm.model_executor.models.deepseek_v2 import DeepseekV32IndexerCache
+
+if current_platform.is_cuda_alike():
+ from vllm import _custom_ops as ops
+elif current_platform.is_xpu():
+ from vllm._ipex_ops import ipex_ops as ops
+
+import xspeedgate_ops
+_is_kunlun = True
+logger = init_logger(__name__)
+
+class DeepseekV2MLP(nn.Module):
+
+ def __init__(
+ self,
+ hidden_size: int,
+ intermediate_size: int,
+ hidden_act: str,
+ quant_config: Optional[QuantizationConfig] = None,
+ reduce_results: bool = True,
+ is_sequence_parallel=False,
+ prefix: str = "",
+ ) -> None:
+ super().__init__()
+
+ # If is_sequence_parallel, the input and output tensors are sharded
+ # across the ranks within the tp_group. In this case the weights are
+ # replicated and no collective ops are needed.
+ # Otherwise we use standard TP with an allreduce at the end.
+ self.gate_up_proj = MergedColumnParallelLinear(
+ hidden_size, [intermediate_size] * 2,
+ bias=False,
+ quant_config=quant_config,
+ disable_tp=is_sequence_parallel,
+ prefix=f"{prefix}.gate_up_proj")
+ self.down_proj = RowParallelLinear(intermediate_size,
+ hidden_size,
+ bias=False,
+ quant_config=quant_config,
+ reduce_results=reduce_results,
+ disable_tp=is_sequence_parallel,
+ prefix=f"{prefix}.down_proj")
+ if hidden_act != "silu":
+ raise ValueError(f"Unsupported activation: {hidden_act}. "
+ "Only silu is supported for now.")
+ self.act_fn = SiluAndMul()
+
+ def forward(self, x):
+ gate_up, _ = self.gate_up_proj(x)
+ x = self.act_fn(gate_up)
+ x, _ = self.down_proj(x)
+ return x
+
+
+class DeepseekV2MoE(nn.Module):
+
+ def __init__(
+ self,
+ config: Union[DeepseekV2Config, DeepseekV3Config],
+ parallel_config: ParallelConfig,
+ quant_config: Optional[QuantizationConfig] = None,
+ prefix: str = "",
+ ):
+ super().__init__()
+ self.tp_size = get_tensor_model_parallel_world_size()
+ self.tp_rank = get_tensor_model_parallel_rank()
+
+ self.routed_scaling_factor = config.routed_scaling_factor
+
+ self.ep_group = get_ep_group().device_group
+ self.ep_rank = self.ep_group.rank()
+ self.ep_size = self.ep_group.size()
+ self.n_routed_experts: int = config.n_routed_experts
+ self.n_shared_experts: int = config.n_shared_experts
+
+ self.is_sequence_parallel = parallel_config.use_sequence_parallel_moe
+
+ if config.hidden_act != "silu":
+ raise ValueError(f"Unsupported activation: {config.hidden_act}. "
+ "Only silu is supported for now.")
+
+ self.gate = ReplicatedLinear(config.hidden_size,
+ config.n_routed_experts,
+ bias=False,
+ quant_config=None,
+ prefix=f"{prefix}.gate")
+ if config.topk_method == "noaux_tc":
+ self.gate.e_score_correction_bias = nn.Parameter(
+ torch.empty(config.n_routed_experts, dtype=torch.float32))
+ else:
+ self.gate.e_score_correction_bias = None
+
+ # Load balancing settings.
+ eplb_config = parallel_config.eplb_config
+ self.enable_eplb = parallel_config.enable_eplb
+
+ self.n_redundant_experts = eplb_config.num_redundant_experts
+ self.n_logical_experts = self.n_routed_experts
+ self.n_physical_experts = (self.n_logical_experts +
+ self.n_redundant_experts)
+ self.n_local_physical_experts = self.n_physical_experts // self.ep_size
+
+ self.physical_expert_start = (self.ep_rank *
+ self.n_local_physical_experts)
+ self.physical_expert_end = (self.physical_expert_start +
+ self.n_local_physical_experts)
+
+ if config.n_shared_experts is None:
+ self.experts = FusedMoE(
+ num_experts=config.n_routed_experts,
+ top_k=config.num_experts_per_tok,
+ hidden_size=config.hidden_size,
+ intermediate_size=config.moe_intermediate_size,
+ reduce_results=False,
+ renormalize=config.norm_topk_prob,
+ quant_config=quant_config,
+ use_grouped_topk=True,
+ num_expert_group=config.n_group,
+ topk_group=config.topk_group,
+ prefix=f"{prefix}.experts",
+ scoring_func=config.scoring_func,
+ # we do scaling outside, set factor to 1.0 to avoid double mul
+ routed_scaling_factor=1.0,
+ e_score_correction_bias=self.gate.e_score_correction_bias,
+ enable_eplb=self.enable_eplb,
+ num_redundant_experts=self.n_redundant_experts,
+ is_sequence_parallel=self.is_sequence_parallel,
+ )
+ self.shared_experts = None
+ else:
+ intermediate_size = (config.moe_intermediate_size *
+ config.n_shared_experts)
+
+ self.shared_experts = DeepseekV2MLP(
+ hidden_size=config.hidden_size,
+ intermediate_size=intermediate_size,
+ hidden_act=config.hidden_act,
+ quant_config=quant_config,
+ is_sequence_parallel=self.is_sequence_parallel,
+ reduce_results=False,
+ prefix=f"{prefix}.shared_experts",
+ )
+
+ self.experts = SharedFusedMoE(
+ shared_experts=self.shared_experts,
+ num_experts=config.n_routed_experts,
+ top_k=config.num_experts_per_tok,
+ hidden_size=config.hidden_size,
+ intermediate_size=config.moe_intermediate_size,
+ reduce_results=False,
+ renormalize=config.norm_topk_prob,
+ quant_config=quant_config,
+ use_grouped_topk=True,
+ num_expert_group=config.n_group,
+ topk_group=config.topk_group,
+ prefix=f"{prefix}.experts",
+ scoring_func=config.scoring_func,
+ # we do scaling outside, set factor to 1.0 to avoid double mul
+ routed_scaling_factor=1.0,
+ e_score_correction_bias=self.gate.e_score_correction_bias,
+ enable_eplb=self.enable_eplb,
+ num_redundant_experts=self.n_redundant_experts,
+ is_sequence_parallel=self.is_sequence_parallel,
+ )
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ num_tokens, hidden_dim = hidden_states.shape
+ hidden_states = hidden_states.view(-1, hidden_dim)
+
+ # Chunk the hidden states so they aren't replicated across TP ranks.
+ # This avoids duplicate computation in self.experts.
+ # TODO: We can replace the all_reduce at the end of attn with a
+ # reduce_scatter instead of chunking here.
+ if self.is_sequence_parallel:
+ hidden_states = sequence_parallel_chunk(hidden_states)
+
+ # router_logits: (num_tokens, n_experts)
+ router_logits, _ = self.gate(hidden_states)
+ fused_moe_out = self.experts(hidden_states=hidden_states,
+ router_logits=router_logits)
+
+ if self.shared_experts is not None:
+ shared_output, final_hidden_states = fused_moe_out
+ else:
+ shared_output = None
+ final_hidden_states = fused_moe_out
+
+ # Fix FP16 overflow
+ # See DeepseekV2DecoderLayer for more details.
+ if hidden_states.dtype != torch.float16:
+ final_hidden_states *= self.routed_scaling_factor
+ elif self.shared_experts is not None:
+ assert shared_output is not None
+ shared_output *= (1. / self.routed_scaling_factor)
+
+ if self.shared_experts is not None:
+ assert shared_output is not None
+ final_hidden_states += shared_output
+
+ if self.is_sequence_parallel:
+ final_hidden_states = tensor_model_parallel_all_gather(
+ final_hidden_states, 0)
+ final_hidden_states = final_hidden_states[:num_tokens]
+ elif self.tp_size > 1:
+ final_hidden_states = (
+ self.experts.maybe_all_reduce_tensor_model_parallel(
+ final_hidden_states))
+
+ return final_hidden_states.view(num_tokens, hidden_dim)
+
+
+def yarn_get_mscale(scale: float = 1, mscale: float = 1) -> float:
+ import math
+ if scale <= 1:
+ return 1.0
+ return 0.1 * mscale * math.log(scale) + 1.0
+
+
+class DeepseekV2Attention(nn.Module):
+
+ def __init__(
+ self,
+ vllm_config: VllmConfig,
+ config: Union[DeepseekV2Config, DeepseekV3Config],
+ hidden_size: int,
+ num_heads: int,
+ qk_nope_head_dim: int,
+ qk_rope_head_dim: int,
+ v_head_dim: int,
+ q_lora_rank: int,
+ kv_lora_rank: int,
+ rope_theta: float = 10000,
+ rope_scaling: Optional[dict[str, Any]] = None,
+ max_position_embeddings: int = 8192,
+ cache_config: Optional[CacheConfig] = None,
+ quant_config: Optional[QuantizationConfig] = None,
+ topk_indices_buffer: Optional[torch.Tensor] = None,
+ prefix: str = "",
+ ) -> None:
+ super().__init__()
+ self.hidden_size = hidden_size
+ self.qk_nope_head_dim = qk_nope_head_dim
+ self.qk_rope_head_dim = qk_rope_head_dim
+ self.qk_head_dim = qk_nope_head_dim + qk_rope_head_dim
+ self.v_head_dim = v_head_dim
+ self.q_lora_rank = q_lora_rank
+ self.kv_lora_rank = kv_lora_rank
+ self.num_heads = num_heads
+ tp_size = get_tensor_model_parallel_world_size()
+ assert num_heads % tp_size == 0
+ self.num_local_heads = num_heads // tp_size
+ self.scaling = self.qk_head_dim**-0.5
+ self.rope_theta = rope_theta
+ self.max_position_embeddings = max_position_embeddings
+ assert topk_indices_buffer is None, "topk_indices_buffer is not \
+ supported for DeepseekV2Attention"
+
+ if self.q_lora_rank is not None:
+ self.q_a_proj = ReplicatedLinear(self.hidden_size,
+ self.q_lora_rank,
+ bias=False,
+ quant_config=quant_config,
+ prefix=f"{prefix}.q_a_proj")
+ self.q_a_layernorm = RMSNorm(self.q_lora_rank,
+ eps=config.rms_norm_eps)
+ self.q_b_proj = ColumnParallelLinear(q_lora_rank,
+ self.num_heads *
+ self.qk_head_dim,
+ bias=False,
+ quant_config=quant_config,
+ prefix=f"{prefix}.q_b_proj")
+ else:
+ self.q_proj = ColumnParallelLinear(self.hidden_size,
+ self.num_heads *
+ self.qk_head_dim,
+ bias=False,
+ quant_config=quant_config,
+ prefix=f"{prefix}.q_proj")
+
+ self.kv_a_proj_with_mqa = ReplicatedLinear(
+ self.hidden_size,
+ self.kv_lora_rank + self.qk_rope_head_dim,
+ bias=False,
+ quant_config=quant_config,
+ prefix=f"{prefix}.kv_a_proj_with_mqa")
+ self.kv_a_layernorm = RMSNorm(self.kv_lora_rank,
+ eps=config.rms_norm_eps)
+ self.kv_b_proj = ColumnParallelLinear(
+ self.kv_lora_rank,
+ self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
+ bias=False,
+ quant_config=quant_config,
+ prefix=f"{prefix}.kv_b_proj")
+ # O projection.
+ self.o_proj = RowParallelLinear(self.num_heads * self.v_head_dim,
+ self.hidden_size,
+ bias=False,
+ quant_config=quant_config,
+ prefix=f"{prefix}.o_proj")
+ if rope_scaling:
+ rope_scaling["rope_type"] = 'deepseek_yarn'
+
+ self.rotary_emb = get_rope(qk_rope_head_dim,
+ rotary_dim=qk_rope_head_dim,
+ max_position=max_position_embeddings,
+ base=rope_theta,
+ rope_scaling=rope_scaling,
+ is_neox_style=False)
+
+ if rope_scaling:
+ mscale_all_dim = rope_scaling.get("mscale_all_dim", False)
+ scaling_factor = rope_scaling["factor"]
+ mscale = yarn_get_mscale(scaling_factor, float(mscale_all_dim))
+ self.scaling = self.scaling * mscale * mscale
+
+ self.attn = Attention(self.num_local_heads,
+ self.qk_head_dim,
+ self.scaling,
+ num_kv_heads=self.num_local_heads,
+ cache_config=cache_config,
+ quant_config=quant_config,
+ prefix=f"{prefix}.attn")
+
+ def forward(
+ self,
+ positions: torch.Tensor,
+ hidden_states: torch.Tensor,
+ ) -> torch.Tensor:
+ if self.q_lora_rank is not None:
+ q = self.q_a_proj(hidden_states)[0]
+ q = self.q_a_layernorm(q)
+ q = self.q_b_proj(q)[0].view(-1, self.num_local_heads,
+ self.qk_head_dim)
+ else:
+ q = self.q_proj(hidden_states)[0].view(-1, self.num_local_heads,
+ self.qk_head_dim)
+ q_nope, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim],
+ dim=-1)
+ latent_cache = self.kv_a_proj_with_mqa(hidden_states)[0]
+ kv_a, _ = latent_cache.split(
+ [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
+ latent_cache = latent_cache.unsqueeze(1)
+ kv_a = self.kv_a_layernorm(kv_a)
+ kv = self.kv_b_proj(kv_a)[0]
+ kv = kv.view(-1, self.num_local_heads,
+ self.qk_nope_head_dim + self.v_head_dim)
+ k_nope, v = kv.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1)
+ k_pe = latent_cache[:, :, self.kv_lora_rank:]
+
+ q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe)
+
+ q[..., self.qk_nope_head_dim:] = q_pe
+ k = torch.empty_like(q)
+ k[..., :self.qk_nope_head_dim] = k_nope
+ k[..., self.qk_nope_head_dim:] = k_pe
+ # padding value to qk_head_dim for alignment
+ v = torch.nn.functional.pad(
+ v, [0, self.qk_head_dim - self.v_head_dim],
+ value=0).view(-1, self.num_local_heads * self.qk_head_dim)
+ attn_output = self.attn(q, k, v)
+ attn_output = attn_output.view(
+ -1, self.num_local_heads,
+ self.qk_head_dim)[..., :self.v_head_dim].reshape(
+ -1, self.num_local_heads * self.v_head_dim)
+ output, _ = self.o_proj(attn_output)
+ return output
+
+@torch.inference_mode()
+def cp_gather_indexer_k_quant_cache(
+ kv_cache, # [num_blocks, block_size, head_dim + 1]
+ block_table, # [batch_size, num_blocks]
+ cu_seq_lens, # [batch_size + 1, ]
+ batch_size,
+ head_dim,
+):
+ num_blocks, block_size, _ = kv_cache.shape
+ kv_cache = kv_cache.view(num_blocks, -1)
+
+ expected_value = []
+ expected_scale = []
+ for b in range(batch_size):
+ s = cu_seq_lens[b + 1] - cu_seq_lens[b]
+ if s == 0:
+ continue
+ tot = cdiv(s, block_size)
+ blocks = block_table[b, :tot]
+
+ value = []
+ scale = []
+ full_block = torch.arange(tot - 1,
+ device=kv_cache.device,
+ dtype=torch.int32)
+ non_remaining_value = kv_cache[blocks[full_block], :block_size *
+ head_dim].view(-1, head_dim)
+ non_remaining_scale = kv_cache[blocks[full_block],
+ block_size * head_dim:].view(-1, 4)
+
+ remaining = s - (tot - 1) * block_size
+
+ value = torch.cat([
+ non_remaining_value,
+ kv_cache[blocks[-1], :remaining * head_dim].view(-1, head_dim)
+ ],
+ dim=0)
+ scale = torch.cat([
+ non_remaining_scale,
+ kv_cache[blocks[-1], block_size * head_dim:block_size * head_dim +
+ remaining * 4].view(-1, 4)
+ ],
+ dim=0)
+
+ expected_value.append(value)
+ expected_scale.append(scale)
+
+ gather_value = torch.cat(expected_value, dim=0).view(-1, head_dim)
+ gather_scale = torch.cat(expected_scale, dim=0).view(-1, 4)
+ gather_value = gather_value.view(torch.int8)
+ gather_scale = gather_scale.view(torch.float32)
+ return gather_value, gather_scale
+
+@torch.inference_mode()
+def kunlun_indexer_k_quant_cache(
+ k, #[num_tokens, head_dim]
+ kv_cache, # [num_blocks, cache_block_size, head_dim + 1]
+ slot_mapping, # [num_tokens]
+ quant_block_size,
+):
+ num_blocks, cache_block_size, cache_stride = kv_cache.shape
+ # num_tokens, head_dim = k.shape
+ head_dim = k.shape[1]
+ num_tokens = slot_mapping.shape[0]
+ assert head_dim % quant_block_size == 0
+ kv_cache = kv_cache.view(num_blocks, -1)
+
+ k_fp8 = torch.empty(
+ k.shape,
+ device=k.device,
+ dtype=torch.int8,
+ )
+ k_scale = torch.empty(
+ [k.shape[0], 1],
+ device=k.device,
+ dtype=torch.float32,
+ )
+
+ torch.ops._C.quant2d(k, k_fp8, k_scale, force_sdnn=True)
+ k_scale /= 127
+ for token_idx in range(num_tokens):
+ slot_idx = slot_mapping[token_idx]
+ if slot_idx < 0:
+ continue
+ block_idx = slot_idx // cache_block_size
+ block_offset = slot_idx % cache_block_size
+ v_offset = block_offset * head_dim
+ kv_cache[block_idx, v_offset:v_offset + head_dim] = k_fp8[token_idx, :].view(torch.uint8).contiguous()
+ s_offset = cache_block_size * head_dim + block_offset * 4
+ kv_cache[block_idx, s_offset:s_offset + 4] = k_scale[token_idx, :].view(torch.uint8).contiguous()
+ kv_cache = kv_cache.view(num_blocks, cache_block_size, cache_stride)
+
+@custom_op("vllm::sparse_attn_indexer_vllm_kunlun", mutates_args=("kv_cache", "topk_indices_buffer"))
+def sparse_attn_indexer_vllm_kunlun(
+ hidden_states: torch.Tensor,
+ k_cache_prefix: str,
+ kv_cache: torch.Tensor,
+ q_fp8: torch.Tensor,
+ k: torch.Tensor,
+ weights: torch.Tensor,
+ quant_block_size: int,
+ scale_fmt: Optional[str],
+ topk_tokens: int,
+ head_dim: int,
+ max_model_len: int,
+ total_seq_lens: int,
+ topk_indices_buffer: Optional[torch.Tensor],
+) -> None:
+
+ # careful! this will be None in dummy run
+ attn_metadata = get_forward_context().attn_metadata
+ # assert isinstance(attn_metadata, dict)
+ if not isinstance(attn_metadata, dict):
+ sparse_attn_indexer_vllm_kunlun_fake(
+ hidden_states,
+ k_cache_prefix,
+ kv_cache,
+ q_fp8,
+ k,
+ weights,
+ quant_block_size,
+ scale_fmt,
+ topk_tokens,
+ head_dim,
+ max_model_len,
+ total_seq_lens,
+ topk_indices_buffer,
+ )
+ return
+ attn_metadata = attn_metadata[k_cache_prefix]
+ assert isinstance(attn_metadata, DeepseekV32IndexerMetadata)
+ slot_mapping = attn_metadata.slot_mapping
+ has_decode = attn_metadata.num_decodes > 0
+ has_prefill = attn_metadata.num_prefills > 0
+ num_decode_tokens = attn_metadata.num_decode_tokens
+
+ # kunlun_indexer_k_quant_cache(
+ # k,
+ # kv_cache,
+ # slot_mapping,
+ # quant_block_size,
+ # )
+
+ torch.ops.xspeedgate_ops.indexer_k_quant_and_cache(
+ k,
+ kv_cache,
+ slot_mapping,
+ quant_block_size,
+ scale_fmt,
+ )
+ topk_indices_buffer[:hidden_states.shape[0]] = -1
+ if has_prefill:
+ prefill_metadata = attn_metadata.prefill
+ for chunk in prefill_metadata.chunks:
+ k_fp8, k_scale = cp_gather_indexer_k_quant_cache(
+ kv_cache,
+ chunk.block_table,
+ chunk.cu_seq_lens,
+ chunk.num_reqs,
+ head_dim,
+ )
+
+ logits = int8_mqa_logits(
+ q_fp8[chunk.token_start:chunk.token_end],
+ (k_fp8, k_scale),
+ weights[chunk.token_start:chunk.token_end],
+ chunk.cu_seqlen_ks,
+ chunk.cu_seqlen_ke,
+ )
+ topk_indices = logits.topk(min(topk_tokens, logits.shape[-1]),
+ dim=-1)[1]
+ topk_indices -= chunk.cu_seqlen_ks[:, None]
+ mask_lo = topk_indices >= 0
+ mask_hi = topk_indices - (chunk.cu_seqlen_ke -
+ chunk.cu_seqlen_ks)[:, None] < 0
+ mask = torch.full_like(topk_indices,
+ False,
+ dtype=torch.bool,
+ device=topk_indices.device)
+ mask = mask_lo & mask_hi
+ topk_indices = topk_indices.masked_fill(~mask, -1)
+ topk_indices_buffer[
+ chunk.token_start:chunk.token_end, :topk_indices.
+ shape[-1]] = topk_indices.to(dtype=torch.int32)
+
+ if has_decode:
+ decode_metadata = attn_metadata.decode
+ # kv_cache size requirement [num_block, block_size, n_head, head_dim],
+ # we only have [num_block, block_size, head_dim],
+ kv_cache = kv_cache.unsqueeze(-2)
+ decode_lens = decode_metadata.decode_lens
+ if decode_metadata.requires_padding:
+ # pad in edge case where we have short chunked prefill length <
+ # decode_threshold since we unstrictly split
+ # prefill and decode by decode_threshold
+ # (currently set to 1 + speculative tokens)
+ padded_q_fp8_decode_tokens = pack_seq_triton(
+ q_fp8[:num_decode_tokens], decode_lens)
+ else:
+ padded_q_fp8_decode_tokens = q_fp8[:num_decode_tokens].reshape(
+ decode_lens.shape[0], -1, *q_fp8.shape[1:])
+ # TODO: move and optimize below logic with triton kernels
+ batch_size = padded_q_fp8_decode_tokens.shape[0]
+ next_n = padded_q_fp8_decode_tokens.shape[1]
+ assert batch_size == decode_metadata.seq_lens.shape[0]
+ num_padded_tokens = batch_size * next_n
+ logits = int8_paged_mqa_logits(
+ padded_q_fp8_decode_tokens,
+ kv_cache,
+ weights[:num_padded_tokens],
+ decode_metadata.seq_lens,
+ decode_metadata.seq_lens_cpu,
+ decode_metadata.block_table,
+ decode_metadata.schedule_metadata,
+ max_model_len=max_model_len,
+ )
+ # padded query len
+ current_device = padded_q_fp8_decode_tokens.device
+ padded_num_tokens = batch_size * next_n
+ positions = torch.arange(max_model_len,
+ device=current_device).unsqueeze(0).expand(
+ batch_size * next_n, -1)
+ row_indices = torch.arange(padded_num_tokens,
+ device=current_device) // next_n
+ next_n_offset = torch.arange(
+ padded_num_tokens,
+ device=padded_q_fp8_decode_tokens.device) % next_n
+ index_end_pos = (decode_metadata.seq_lens[row_indices] - next_n +
+ next_n_offset).unsqueeze(1)
+ # index_end_pos: [B * N, 1]
+ mask = positions <= index_end_pos
+ # mask: [B * N, L]
+ logits = logits.masked_fill(~mask, float('-inf'))
+ topk_indices = logits.topk(topk_tokens,
+ dim=-1)[1].to(torch.int32) # [B * N, K]
+ # ensure we don't set indices for the top k
+ # that is out of range(masked already)
+ # this will happen if context length is shorter than K
+ topk_indices[topk_indices > index_end_pos] = -1
+ if decode_metadata.requires_padding:
+ # if padded, we need to unpack
+ # the topk indices removing padded tokens
+ topk_indices = unpack_seq_triton(
+ topk_indices.reshape(batch_size, -1, topk_indices.shape[-1]),
+ decode_lens)
+ topk_indices_buffer[:num_decode_tokens, :topk_indices.
+ shape[-1]] = topk_indices.to(dtype=torch.int32)
+
+ # return topk_indices_buffer
+
+
+def sparse_attn_indexer_vllm_kunlun_fake(
+ hidden_states: torch.Tensor,
+ k_cache_prefix: str,
+ kv_cache: torch.Tensor,
+ q_fp8: torch.Tensor,
+ k: torch.Tensor,
+ weights: torch.Tensor,
+ quant_block_size: int,
+ scale_fmt: Optional[str],
+ topk_tokens: int,
+ head_dim: int,
+ max_model_len: int,
+ total_seq_lens: int,
+ topk_indices_buffer: Optional[torch.Tensor],
+) -> None:
+ return
+
+sparse_attn_indexer_vllm_kunlun.register_fake(sparse_attn_indexer_vllm_kunlun_fake)
+
+class Indexer(nn.Module):  # DeepSeek-V3.2 lightning indexer: scores tokens and emits per-query top-k KV indices
+
+    def __init__(self,
+                 vllm_config: VllmConfig,
+                 config: Union[DeepseekV2Config, DeepseekV3Config],
+                 hidden_size: int,
+                 q_lora_rank: int,
+                 quant_config: Optional[QuantizationConfig],
+                 cache_config: Optional[CacheConfig],
+                 topk_indices_buffer: Optional[torch.Tensor],
+                 prefix: str = ""):
+        super().__init__()
+        self.vllm_config = vllm_config
+        self.config = config
+        # self.indexer_cfg = config.attn_module_list_cfg[0]["attn_index"]
+        self.topk_tokens = config.index_topk
+        self.n_head = config.index_n_heads  # 64
+        self.head_dim = config.index_head_dim  # 128
+        self.rope_dim = config.qk_rope_head_dim  # 64
+        self.q_lora_rank = q_lora_rank  # 1536
+        # no tensor parallel, just replicated
+        self.wq_b = ReplicatedLinear(self.q_lora_rank,
+                                     self.head_dim * self.n_head,
+                                     bias=False,
+                                     quant_config=quant_config,
+                                     prefix=f"{prefix}.wq_b")
+        self.wk = ReplicatedLinear(hidden_size,
+                                   self.head_dim,
+                                   bias=False,
+                                   quant_config=quant_config,
+                                   prefix=f"{prefix}.wk")
+        self.k_norm = LayerNorm(self.head_dim, eps=1e-6)
+        self.weights_proj = ReplicatedLinear(hidden_size,
+                                             self.n_head,
+                                             bias=False,
+                                             quant_config=None,
+                                             prefix=f"{prefix}.weights_proj")
+        self.softmax_scale = self.head_dim**-0.5
+        self.scale_fmt = "ue8m0"
+        self.quant_block_size = 128  # TODO: get from config
+        self.topk_indices_buffer = topk_indices_buffer
+
+        # NOTE: (zyongye) we use fp8 naive cache,
+        # where we store value in fp8 and scale in fp32
+        # per self.quant_block_size element
+        self.k_cache = DeepseekV32IndexerCache(
+            head_dim=self.head_dim +
+            self.head_dim // self.quant_block_size * 4,
+            dtype=torch.uint8,
+            prefix=f"{prefix}.k_cache",
+            cache_config=cache_config)
+        self.max_model_len = vllm_config.model_config.max_model_len
+        if self.max_model_len % cache_config.block_size != 0:  # I8_paged_mqa_logits requires max_model_len to be a multiple of block_size
+            self.max_model_len = self.max_model_len + cache_config.block_size - (self.max_model_len % cache_config.block_size)
+        self.prefix = prefix
+        from vllm.v1.attention.backends.mla.indexer import (
+            get_max_prefill_buffer_size)
+        self.max_total_seq_len = get_max_prefill_buffer_size(vllm_config)
+
+    def forward(self, hidden_states: torch.Tensor, qr: torch.Tensor, positions,
+                rotary_emb) -> torch.Tensor:
+        q, _ = self.wq_b(qr)
+        q = q.view(-1, self.n_head, self.head_dim)
+        q_pe, q_nope = torch.split(
+            q, [self.rope_dim, self.head_dim - self.rope_dim], dim=-1)
+
+        k, _ = self.wk(hidden_states)
+        k = self.k_norm(k)
+        k_pe, k_nope = torch.split(
+            k, [self.rope_dim, self.head_dim - self.rope_dim], dim=-1)
+
+        q_pe, k_pe = rotary_emb(positions, q_pe, k_pe.unsqueeze(1))
+        q = torch.cat([q_pe, q_nope], dim=-1)
+        k = torch.cat([k_pe.squeeze(1), k_nope], dim=-1)
+
+        # we only quant q here since k quant is fused with cache insertion
+        q = q.view(-1, self.head_dim)  # flatten heads for per-row int8 quantization
+        q_fp8 = torch.empty(
+            q.shape,
+            device=q.device,
+            dtype=torch.int8,
+        )
+        q_scale = torch.empty(
+            [q.shape[0], 1],
+            device=q.device,
+            dtype=torch.float32,
+        )
+        torch.ops._C.quant2d(q, q_fp8, q_scale, force_sdnn=True)
+        q_scale /= 127  # symmetric int8 quant: rescale back to the float range
+        q_fp8 = q_fp8.view(-1, self.n_head, self.head_dim)
+        q_scale = q_scale.view(-1, self.n_head)
+        weights, _ = self.weights_proj(hidden_states)
+        weights = weights * self.n_head**-0.5
+        weights = weights * q_scale * self.softmax_scale
+
+        torch.ops.vllm.sparse_attn_indexer_vllm_kunlun(
+            hidden_states,
+            self.k_cache.prefix,
+            self.k_cache.kv_cache[0],
+            q_fp8,
+            k,
+            weights,
+            self.quant_block_size,
+            self.scale_fmt,
+            self.topk_tokens,
+            self.head_dim,
+            self.max_model_len,
+            self.max_total_seq_len,
+            self.topk_indices_buffer,
+        )
+        return self.topk_indices_buffer
+
+
+class DeepseekV2MLAAttention(nn.Module):
+    """
+    Main reference: DeepseekV2 paper, and FlashInfer Implementation
+    (https://arxiv.org/abs/2405.04434 and https://github.com/flashinfer-ai/flashinfer/pull/551).
+
+    For more info see MLACommonImpl in:
+    vllm/v1/attention/backends/mla/utils.py
+    """
+
+    def __init__(
+        self,
+        vllm_config: VllmConfig,
+        config: Union[DeepseekV2Config, DeepseekV3Config],
+        hidden_size: int,
+        num_heads: int,
+        qk_nope_head_dim: int,
+        qk_rope_head_dim: int,
+        v_head_dim: int,
+        q_lora_rank: Optional[int],
+        kv_lora_rank: int,
+        rope_theta: float = 10000,
+        rope_scaling: Optional[dict[str, Any]] = None,
+        max_position_embeddings: int = 8192,
+        cache_config: Optional[CacheConfig] = None,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+        topk_indices_buffer: Optional[torch.Tensor] = None,
+    ) -> None:
+        super().__init__()
+        self.hidden_size = hidden_size
+        self.qk_nope_head_dim = qk_nope_head_dim
+        self.qk_rope_head_dim = qk_rope_head_dim
+        self.qk_head_dim = qk_nope_head_dim + qk_rope_head_dim
+        self.v_head_dim = v_head_dim
+
+        self.q_lora_rank = q_lora_rank
+        self.kv_lora_rank = kv_lora_rank
+
+        self.num_heads = num_heads
+        tp_size = get_tensor_model_parallel_world_size()
+        assert num_heads % tp_size == 0
+        self.num_local_heads = num_heads // tp_size
+
+        self.scaling = self.qk_head_dim**-0.5
+        self.rope_theta = rope_theta
+        self.max_position_embeddings = max_position_embeddings
+
+        if self.q_lora_rank is not None:
+            self.fused_qkv_a_proj = MergedColumnParallelLinear(
+                self.hidden_size,
+                [self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim],
+                bias=False,
+                quant_config=quant_config,
+                prefix=f"{prefix}.fused_qkv_a_proj",
+                disable_tp=True)
+        else:
+            self.kv_a_proj_with_mqa = ReplicatedLinear(
+                self.hidden_size,
+                self.kv_lora_rank + self.qk_rope_head_dim,
+                bias=False,
+                quant_config=quant_config,
+                prefix=f"{prefix}.kv_a_proj_with_mqa")
+
+        if self.q_lora_rank is not None:
+            self.q_a_layernorm = RMSNorm(self.q_lora_rank,
+                                         eps=config.rms_norm_eps)
+            self.q_b_proj = ColumnParallelLinear(self.q_lora_rank,
+                                                 self.num_heads *
+                                                 self.qk_head_dim,
+                                                 bias=False,
+                                                 quant_config=quant_config,
+                                                 prefix=f"{prefix}.q_b_proj")
+        else:
+            self.q_proj = ColumnParallelLinear(self.hidden_size,
+                                               self.num_heads *
+                                               self.qk_head_dim,
+                                               bias=False,
+                                               quant_config=quant_config,
+                                               prefix=f"{prefix}.q_proj")
+        self.kv_a_layernorm = RMSNorm(self.kv_lora_rank,
+                                      eps=config.rms_norm_eps)
+        self.kv_b_proj = ColumnParallelLinear(
+            self.kv_lora_rank,
+            self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
+            bias=False,
+            quant_config=quant_config,
+            prefix=f"{prefix}.kv_b_proj")
+        self.o_proj = RowParallelLinear(self.num_heads * self.v_head_dim,
+                                        self.hidden_size,
+                                        bias=False,
+                                        quant_config=quant_config,
+                                        prefix=f"{prefix}.o_proj")
+
+        if rope_scaling:
+            rope_scaling["rope_type"] = 'deepseek_yarn'  # DeepSeek checkpoints use YaRN-style rope scaling
+        self.rotary_emb = get_rope(qk_rope_head_dim,
+                                   rotary_dim=qk_rope_head_dim,
+                                   max_position=max_position_embeddings,
+                                   base=rope_theta,
+                                   rope_scaling=rope_scaling,
+                                   is_neox_style=False)
+        if rope_scaling:
+            mscale_all_dim = rope_scaling.get("mscale_all_dim", False)
+            scaling_factor = rope_scaling["factor"]
+            mscale = yarn_get_mscale(scaling_factor, float(mscale_all_dim))
+            self.scaling = self.scaling * mscale * mscale
+
+        self.is_v32 = hasattr(config, "index_topk")  # index_topk only exists on DeepSeek-V3.2 configs
+
+        if self.is_v32:
+            self.indexer = Indexer(vllm_config, config, hidden_size,
+                                   q_lora_rank, quant_config, cache_config,
+                                   topk_indices_buffer, f"{prefix}.indexer")
+        else:
+            self.indexer = None
+
+        mla_modules = MLAModules(
+            kv_a_layernorm=self.kv_a_layernorm,
+            kv_b_proj=self.kv_b_proj,
+            rotary_emb=self.rotary_emb,
+            o_proj=self.o_proj,
+            fused_qkv_a_proj=self.fused_qkv_a_proj
+            if self.q_lora_rank is not None else None,
+            kv_a_proj_with_mqa=self.kv_a_proj_with_mqa
+            if self.q_lora_rank is None else None,
+            q_a_layernorm=self.q_a_layernorm
+            if self.q_lora_rank is not None else None,
+            q_b_proj=self.q_b_proj if self.q_lora_rank is not None else None,
+            q_proj=self.q_proj if self.q_lora_rank is None else None,
+            indexer=self.indexer,
+            is_sparse=self.is_v32,
+            topk_indices_buffer=topk_indices_buffer,
+        )
+
+        self.mla_attn = MultiHeadLatentAttention(
+            self.hidden_size,
+            self.num_local_heads,
+            self.scaling,
+            self.qk_nope_head_dim,
+            self.qk_rope_head_dim,
+            self.v_head_dim,
+            self.q_lora_rank,
+            self.kv_lora_rank,
+            mla_modules,
+            cache_config,
+            quant_config,
+            prefix,
+        )
+
+    def forward(
+        self,
+        positions: torch.Tensor,
+        hidden_states: torch.Tensor,
+    ) -> torch.Tensor:
+        return self.mla_attn(positions, hidden_states)
+
+
+class DeepseekV2DecoderLayer(nn.Module):  # one transformer block: (MLA or MHA) attention + (MoE or dense) MLP
+
+    def __init__(self,
+                 vllm_config: VllmConfig,
+                 prefix: str,
+                 topk_indices_buffer: Optional[torch.Tensor] = None) -> None:
+        super().__init__()
+
+        config = vllm_config.model_config.hf_config
+        model_config = vllm_config.model_config
+        cache_config = vllm_config.cache_config
+        quant_config = vllm_config.quant_config
+        parallel_config = vllm_config.parallel_config
+
+        self.hidden_size = config.hidden_size
+        rope_theta = getattr(config, "rope_theta", 10000)
+        rope_scaling = getattr(config, "rope_scaling", None)
+        max_position_embeddings = getattr(config, "max_position_embeddings",
+                                          8192)
+        # DecoderLayers are created with `make_layers` which passes the prefix
+        # with the layer's index.
+        layer_idx = int(prefix.split(sep='.')[-1])
+        self.layer_idx = layer_idx
+        if model_config.use_mla:
+            attn_cls = DeepseekV2MLAAttention
+        else:
+            attn_cls = DeepseekV2Attention
+        self.self_attn = attn_cls(
+            vllm_config=vllm_config,
+            config=config,
+            hidden_size=self.hidden_size,
+            num_heads=config.num_attention_heads,
+            qk_nope_head_dim=config.qk_nope_head_dim,
+            qk_rope_head_dim=config.qk_rope_head_dim,
+            v_head_dim=config.v_head_dim,
+            q_lora_rank=config.q_lora_rank
+            if hasattr(config, "q_lora_rank") else None,
+            kv_lora_rank=config.kv_lora_rank,
+            rope_theta=rope_theta,
+            rope_scaling=rope_scaling,
+            max_position_embeddings=max_position_embeddings,
+            cache_config=cache_config,
+            quant_config=quant_config,
+            prefix=f"{prefix}.self_attn",
+            topk_indices_buffer=topk_indices_buffer,
+        )
+
+        if (config.n_routed_experts is not None
+                and layer_idx >= config.first_k_dense_replace
+                and layer_idx % config.moe_layer_freq == 0):  # MoE only past the first_k dense layers, at the configured frequency
+            self.mlp = DeepseekV2MoE(
+                config=config,
+                parallel_config=parallel_config,
+                quant_config=quant_config,
+                prefix=f"{prefix}.mlp",
+            )
+        else:
+            self.mlp = DeepseekV2MLP(
+                hidden_size=config.hidden_size,
+                intermediate_size=config.intermediate_size,
+                hidden_act=config.hidden_act,
+                quant_config=quant_config,
+                prefix=f"{prefix}.mlp",
+            )
+        self.input_layernorm = RMSNorm(config.hidden_size,
+                                       eps=config.rms_norm_eps)
+        self.post_attention_layernorm = RMSNorm(config.hidden_size,
+                                                eps=config.rms_norm_eps)
+        self.routed_scaling_factor = config.routed_scaling_factor
+
+    def forward(
+        self,
+        positions: torch.Tensor,
+        hidden_states: torch.Tensor,
+        residual: Optional[torch.Tensor],
+    ) -> torch.Tensor:
+        # Self Attention
+        if residual is None:
+            residual = hidden_states.clone()  # keep a pre-norm copy for the residual stream
+            hidden_states = self.input_layernorm(hidden_states)
+        else:
+            hidden_states, residual = self.input_layernorm(
+                hidden_states, residual)
+        hidden_states = self.self_attn(
+            positions=positions,
+            hidden_states=hidden_states,
+        )
+
+        if hidden_states.dtype == torch.float16:
+            # Fix FP16 overflow
+            # We scale both hidden_states and residual before
+            # rmsnorm, and rmsnorm result would not affect by scale.
+            hidden_states *= 1. / self.routed_scaling_factor
+            if self.layer_idx == 0:
+                # The residual is shared by all layers, we only scale it on
+                # first layer.
+                residual *= 1. / self.routed_scaling_factor
+
+        # Fully Connected
+        hidden_states, residual = self.post_attention_layernorm(
+            hidden_states, residual)
+        hidden_states = self.mlp(hidden_states)
+
+        if isinstance(self.mlp,
+                      DeepseekV2MLP) and hidden_states.dtype == torch.float16:
+            # Fix FP16 overflow
+            # Scaling the DeepseekV2MLP output, it is the input of
+            # input_layernorm of next decoder layer.
+            # The scaling of DeepseekV2MOE output would be done in the forward
+            # of DeepseekV2MOE
+            hidden_states *= 1. / self.routed_scaling_factor
+
+        return hidden_states, residual
+
+
+@support_torch_compile
+class DeepseekV2Model(nn.Module):  # embedding + decoder stack + final norm, pipeline-parallel aware
+
+    fall_back_to_pt_during_load = False
+
+    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
+        super().__init__()
+
+        config = vllm_config.model_config.hf_config
+        quant_config = vllm_config.quant_config
+        self.config = config
+
+        self.vocab_size = config.vocab_size
+        self.is_v32 = hasattr(config, "index_topk")  # index_topk only exists on DeepSeek-V3.2 configs
+        if self.is_v32:
+            topk_tokens = config.index_topk
+            topk_indices_buffer = torch.empty(
+                vllm_config.scheduler_config.max_num_batched_tokens,
+                topk_tokens,
+                dtype=torch.int32,
+                device="cuda")  # NOTE(review): hard-coded "cuda" in a Kunlun plugin — confirm this maps to the XPU device
+        else:
+            topk_indices_buffer = None
+
+        if get_pp_group().is_first_rank:
+            self.embed_tokens = VocabParallelEmbedding(
+                config.vocab_size,
+                config.hidden_size,
+                quant_config=quant_config,
+                prefix=f"{prefix}.embed_tokens")
+        else:
+            self.embed_tokens = PPMissingLayer()
+
+        self.start_layer, self.end_layer, self.layers = make_layers(
+            config.num_hidden_layers,
+            lambda prefix: DeepseekV2DecoderLayer(vllm_config, prefix,
+                                                  topk_indices_buffer),
+            prefix=f"{prefix}.layers")
+
+        if get_pp_group().is_last_rank:
+            self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        else:
+            self.norm = PPMissingLayer()
+        self.make_empty_intermediate_tensors = (
+            make_empty_intermediate_tensors_factory(
+                ["hidden_states", "residual"], config.hidden_size))
+
+    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
+        return self.embed_tokens(input_ids)
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        intermediate_tensors: Optional[IntermediateTensors],
+        inputs_embeds: Optional[torch.Tensor] = None,
+    ) -> Union[torch.Tensor, IntermediateTensors]:
+        if get_pp_group().is_first_rank:
+            if inputs_embeds is not None:
+                hidden_states = inputs_embeds
+            else:
+                hidden_states = self.get_input_embeddings(input_ids)
+            residual = None
+        else:
+            assert intermediate_tensors is not None
+            hidden_states = intermediate_tensors["hidden_states"]
+            residual = intermediate_tensors["residual"]
+
+        for layer in islice(self.layers, self.start_layer, self.end_layer):
+            hidden_states, residual = layer(positions, hidden_states, residual)
+
+        if not get_pp_group().is_last_rank:
+            return IntermediateTensors({
+                "hidden_states": hidden_states,
+                "residual": residual
+            })
+
+        hidden_states, _ = self.norm(hidden_states, residual)
+        return hidden_states
+
+
+class DeepseekV2ForCausalLM(nn.Module, SupportsPP, MixtureOfExperts,
+                            SupportsLoRA):  # top-level causal LM: model + lm_head + MoE/EPLB bookkeeping
+    packed_modules_mapping = {
+        "gate_up_proj": ["gate_proj", "up_proj"],
+    }
+
+    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
+        super().__init__()
+        config = vllm_config.model_config.hf_config
+        quant_config = vllm_config.quant_config
+        self.config = config
+        self.quant_config = quant_config
+
+        # `packed_modules_mapping` needs to be modified before
+        # initializing DeepseekV2Model, as it is passed inplace to
+        # quantization config init and may be used to select the
+        # quant_method for relevant layers during initialization.
+        self.fuse_qkv_a_proj = hasattr(
+            config, "q_lora_rank") and config.q_lora_rank is not None
+        if self.fuse_qkv_a_proj:
+            self.packed_modules_mapping["fused_qkv_a_proj"] = [
+                "q_a_proj",
+                "kv_a_proj_with_mqa",
+            ]
+
+        self.model = DeepseekV2Model(vllm_config=vllm_config,
+                                     prefix=maybe_prefix(prefix, "model"))
+        if get_pp_group().is_last_rank:
+            self.lm_head = ParallelLMHead(
+                config.vocab_size,
+                config.hidden_size,
+                quant_config=quant_config,
+                prefix=maybe_prefix(prefix, "lm_head"),
+            )
+        else:
+            self.lm_head = PPMissingLayer()
+        self.logits_processor = LogitsProcessor(config.vocab_size)
+        self.make_empty_intermediate_tensors = (
+            self.model.make_empty_intermediate_tensors)
+        self.expert_weights = []
+
+        # Set MoE hyperparameters
+        self.num_moe_layers = (config.num_hidden_layers -
+                               config.first_k_dense_replace)
+        self.num_expert_groups = config.n_group
+
+        self.moe_layers: list[FusedMoE] = []
+        example_moe = None
+        for layer in self.model.layers:
+            if isinstance(layer, PPMissingLayer):
+                continue
+
+            assert isinstance(layer, DeepseekV2DecoderLayer)
+            if isinstance(layer.mlp, DeepseekV2MoE):
+                # Pick last one layer since the first ones may be dense layers.
+                example_moe = layer.mlp
+                self.moe_layers.append(layer.mlp.experts)
+
+        if example_moe is None:
+            raise RuntimeError("No DeepseekV2MoE layer found in model.layers.")
+
+        self.num_logical_experts = example_moe.n_logical_experts
+        self.num_physical_experts = example_moe.n_physical_experts
+        self.num_local_physical_experts = example_moe.n_local_physical_experts
+        self.num_routed_experts = example_moe.n_routed_experts
+        self.num_shared_experts = example_moe.n_shared_experts
+        self.num_redundant_experts = example_moe.n_redundant_experts
+
+    def set_eplb_state(
+        self,
+        expert_load_view: torch.Tensor,
+        logical_to_physical_map: torch.Tensor,
+        logical_replica_count: torch.Tensor,
+    ) -> None:
+        for layer_idx, layer in enumerate(self.moe_layers):
+            # Register the expert weights.
+            self.expert_weights.append(layer.get_expert_weights())
+            layer.set_eplb_state(
+                moe_layer_idx=layer_idx,
+                expert_load_view=expert_load_view,
+                logical_to_physical_map=logical_to_physical_map,
+                logical_replica_count=logical_replica_count,
+            )
+
+    def update_physical_experts_metadata(
+        self,
+        num_physical_experts: int,
+        num_local_physical_experts: int,
+    ) -> None:
+        assert self.num_local_physical_experts == num_local_physical_experts
+        self.num_physical_experts = num_physical_experts
+        self.num_local_physical_experts = num_local_physical_experts
+        self.num_redundant_experts = (num_physical_experts -
+                                      self.num_logical_experts)
+        for layer in self.model.layers:
+            if isinstance(layer.mlp, DeepseekV2MoE):
+                moe = layer.mlp
+                moe.n_local_physical_experts = num_local_physical_experts
+                moe.n_physical_experts = num_physical_experts
+                moe.n_redundant_experts = self.num_redundant_experts
+                moe.experts.update_expert_map()
+
+    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
+        return self.model.get_input_embeddings(input_ids)
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        intermediate_tensors: Optional[IntermediateTensors] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
+    ) -> Union[torch.Tensor, IntermediateTensors]:
+        hidden_states = self.model(input_ids, positions, intermediate_tensors,
+                                   inputs_embeds)
+        return hidden_states
+
+    def compute_logits(
+        self,
+        hidden_states: torch.Tensor,
+    ) -> Optional[torch.Tensor]:
+        logits = self.logits_processor(self.lm_head, hidden_states)
+        return logits
+
+    def load_weights(self, weights: Iterable[tuple[str,
+                                                   torch.Tensor]]) -> set[str]:
+        stacked_params_mapping = [
+            # (param_name, shard_name, shard_id)
+            ("gate_up_proj", "gate_proj", 0),
+            ("gate_up_proj", "up_proj", 1),
+            ("fused_qkv_a_proj", "q_a_proj", 0),
+            ("fused_qkv_a_proj", "kv_a_proj_with_mqa", 1),
+        ]
+
+        # Params for weights, fp8 weight scales, fp8 activation scales
+        # (param_name, weight_name, expert_id, shard_id)
+        expert_params_mapping = FusedMoE.make_expert_params_mapping(
+            ckpt_gate_proj_name="gate_proj",
+            ckpt_down_proj_name="down_proj",
+            ckpt_up_proj_name="up_proj",
+            num_experts=self.config.n_routed_experts,
+            num_redundant_experts=self.num_redundant_experts)
+
+        params_dict = dict(self.named_parameters())
+        loaded_params: set[str] = set()
+        for name, loaded_weight in weights:
+            if "rotary_emb.inv_freq" in name:
+                continue
+
+            spec_layer = get_spec_layer_idx_from_weight_name(self.config, name)
+            if spec_layer is not None:
+                continue  # skip spec decode layers for main model
+
+            for (param_name, weight_name, shard_id) in stacked_params_mapping:
+                # Skip non-stacked layers and experts (experts handled below).
+                if weight_name not in name:
+                    continue
+                # We have mlp.experts[0].gate_proj in the checkpoint.
+                # Since we handle the experts below in expert_params_mapping,
+                # we need to skip here BEFORE we update the name, otherwise
+                # name will be updated to mlp.experts[0].gate_up_proj, which
+                # will then be updated below in expert_params_mapping
+                # for mlp.experts[0].gate_gate_up_proj, which breaks load.
+                if (("mlp.experts." in name) and name not in params_dict):
+                    continue
+                name_mapped = name.replace(weight_name, param_name)
+
+                # QKV fusion is optional, fall back to normal
+                # weight loading if it's not enabled
+                # if go with fusion option, then update name
+                if ((param_name == "fused_qkv_a_proj")
+                        and name_mapped not in params_dict):
+                    continue
+                else:
+                    name = name_mapped
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
+
+                if is_pp_missing_parameter(name, self):
+                    continue
+                if name not in params_dict:
+                    continue
+                param = params_dict[name]
+                weight_loader = param.weight_loader
+                weight_loader(param, loaded_weight, shard_id)
+                break  # handled by a stacked mapping; skip expert/default paths
+            else:  # for/else: no stacked mapping matched this weight
+                is_expert_weight = False
+                for mapping in expert_params_mapping:
+                    param_name, weight_name, expert_id, shard_id = mapping
+                    if weight_name not in name:
+                        continue
+
+                    # Anyway, this is an expert weight and should not be
+                    # attempted to load as other weights later
+                    is_expert_weight = True
+
+                    # Do not modify `name` since the loop may continue here
+                    # Instead, create a new variable
+                    name_mapped = name.replace(weight_name, param_name)
+
+                    if is_pp_missing_parameter(name_mapped, self):
+                        continue
+                    if name_mapped not in params_dict:
+                        continue
+                    param = params_dict[name_mapped]
+                    # We should ask the weight loader to return success or not
+                    # here since otherwise we may skip experts with other
+                    # available replicas.
+                    weight_loader = typing.cast(Callable[..., bool],
+                                                param.weight_loader)
+                    success = weight_loader(param,
+                                            loaded_weight,
+                                            name_mapped,
+                                            shard_id=shard_id,
+                                            expert_id=expert_id,
+                                            return_success=True)
+                    if success:
+                        name = name_mapped
+                        break
+                else:  # for/else: not an expert weight either — load as a plain parameter
+                    if is_expert_weight:
+                        # We've checked that this is an expert weight
+                        # However it's not mapped locally to this rank
+                        # So we simply skip it
+                        continue
+
+                    # Skip loading extra bias for GPTQ models.
+                    if name.endswith(".bias") and name not in params_dict:
+                        continue
+
+                    # Remapping the name of FP8 kv-scale.
+                    name = maybe_remap_kv_scale_name(name, params_dict)
+                    if name is None:
+                        continue
+
+                    if is_pp_missing_parameter(name, self):
+                        continue
+                    if name not in params_dict:
+                        continue
+                    param = params_dict[name]
+                    weight_loader = getattr(param, "weight_loader",
+                                            default_weight_loader)
+                    weight_loader(param, loaded_weight)
+            loaded_params.add(name)
+
+        return loaded_params
+
+
+class DeepseekV3ForCausalLM(DeepseekV2ForCausalLM):
+    pass  # V3/V3.2 reuse the V2 architecture; all differences are config-driven
+
+
+# Compatibility with
+# https://huggingface.co/deepseek-ai/DeepSeek-V3-Base/blob/main/configuration_deepseek.py
+def get_spec_layer_idx_from_weight_name(config: Union[DeepseekV2Config,
+                                                      DeepseekV3Config],
+                                        weight_name: str) -> Optional[int]:
+    # Returns the MTP (speculative) layer index a weight belongs to, else None.
+    if (hasattr(config, "num_nextn_predict_layers")
+            and config.num_nextn_predict_layers > 0):
+        layer_idx = config.num_hidden_layers  # MTP layers are appended after the main stack
+        for i in range(config.num_nextn_predict_layers):
+            if weight_name.startswith(f"model.layers.{layer_idx+i}."):
+                return layer_idx + i
+    return None
\ No newline at end of file
diff --git a/vllm_kunlun/models/gpt_oss.py b/vllm_kunlun/models/gpt_oss.py
index 532718d..33e97b9 100644
--- a/vllm_kunlun/models/gpt_oss.py
+++ b/vllm_kunlun/models/gpt_oss.py
@@ -16,7 +16,7 @@ from vllm.distributed import (get_ep_group, get_pp_group,
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
tensor_model_parallel_all_gather)
-from vllm_kunlun.ops.fused_moe.layer import FusedMoE
+from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (QKVParallelLinear,
RowParallelLinear)
@@ -176,7 +176,7 @@ class MLPBlock(torch.nn.Module):
x = sequence_parallel_chunk(x)
g = self.router(x)
- x = self.experts(hidden_states=x, router_logits=g, linear_weights=self.router.weight)
+ x = self.experts(hidden_states=x, router_logits=g)
if self.is_sequence_parallel:
x = tensor_model_parallel_all_gather(x.contiguous(), 0)
diff --git a/vllm_kunlun/models/mimo_v2_flash.py b/vllm_kunlun/models/mimo_v2_flash.py
index 0dfcb86..e381b24 100644
--- a/vllm_kunlun/models/mimo_v2_flash.py
+++ b/vllm_kunlun/models/mimo_v2_flash.py
@@ -21,7 +21,7 @@ from vllm.distributed import (
tensor_model_parallel_all_gather,
)
from vllm.logger import init_logger
-from vllm_kunlun.ops.fused_moe.layer import FusedMoE
+from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (
MergedColumnParallelLinear,
@@ -185,8 +185,7 @@ class MiMoV2MoE(nn.Module):
gate_input = hidden_states
router_logits = self.gate(gate_input)
final_hidden_states = self.experts(
- hidden_states=hidden_states, router_logits=router_logits, linear_weights=self.gate.weight
- )
+ hidden_states=hidden_states, router_logits=router_logits)
return final_hidden_states.squeeze(0) if is_input_1d else final_hidden_states
diff --git a/vllm_kunlun/ops/__init__.py b/vllm_kunlun/ops/__init__.py
index fec8bbb..d6da6ae 100644
--- a/vllm_kunlun/ops/__init__.py
+++ b/vllm_kunlun/ops/__init__.py
@@ -20,4 +20,8 @@ import vllm_kunlun.ops.layernorm
import vllm_kunlun.ops.quantization.awq
import vllm_kunlun.ops.quantization.gptq
import vllm_kunlun.ops.vocab_parallel_embedding
-import vllm_kunlun.ops.linear
\ No newline at end of file
+import vllm_kunlun.ops.linear
+import vllm_kunlun.ops.quantization.kernels.scaled_mm.cutlass
+import vllm_kunlun.ops.vocab_parallel_embedding
+import vllm_kunlun.ops.quantization.compressed_tensors_moe
+import vllm_kunlun.ops.fused_moe.layer
\ No newline at end of file
diff --git a/vllm_kunlun/ops/_kunlun_ops.py b/vllm_kunlun/ops/_kunlun_ops.py
index 6250964..94a875e 100644
--- a/vllm_kunlun/ops/_kunlun_ops.py
+++ b/vllm_kunlun/ops/_kunlun_ops.py
@@ -417,7 +417,6 @@ class KunlunOps:
w1: torch.Tensor,
w2: torch.Tensor,
router_logits: torch.Tensor,
- linear_weights: torch.Tensor,
ep_rank: int,
moe_top_k: int,
renormalize: bool,
diff --git a/vllm_kunlun/ops/activation.py b/vllm_kunlun/ops/activation.py
index 0526acf..6719808 100644
--- a/vllm_kunlun/ops/activation.py
+++ b/vllm_kunlun/ops/activation.py
@@ -108,7 +108,7 @@ class SiluAndMul(CustomOp):
d = x.shape[-1] // 2
output_shape = (x.shape[:-1] + (d, ))
out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
- xtorch_ops.swiglu(x, out)
+ torch.ops._C.silu_and_mul(out, x)
return out
def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
diff --git a/vllm_kunlun/ops/attention/flashmla.py b/vllm_kunlun/ops/attention/flashmla.py
new file mode 100644
index 0000000..f53dfcf
--- /dev/null
+++ b/vllm_kunlun/ops/attention/flashmla.py
@@ -0,0 +1,260 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+# adapted from: https://github.com/deepseek-ai/FlashMLA/blob/main/flash_mla/flash_mla_interface.py
+from typing import Optional, Tuple
+
+import torch
+
+from vllm.logger import init_logger
+from vllm.platforms import current_platform
+import xtorch_ops
+
+logger = init_logger(__name__)
+
+if current_platform.is_cuda():
+ try:
+ import vllm._flashmla_C # noqa: F401
+ _flashmla_C_AVAILABLE = True
+ except ImportError:
+ _flashmla_C_AVAILABLE = False
+else:
+ _flashmla_C_AVAILABLE = False
+
+if current_platform.is_cuda():
+ try:
+ import vllm._flashmla_extension_C # noqa: F401
+ _flashmla_extension_C_AVAILABLE = True
+ except ImportError:
+ _flashmla_extension_C_AVAILABLE = False
+else:
+ _flashmla_extension_C_AVAILABLE = False
+
+
+def is_flashmla_supported() -> Tuple[bool, Optional[str]]:
+    """
+    Return: is_supported_flag, unsupported_reason (optional).
+    """
+    return True, None  # the Kunlun backend always reports FlashMLA-style attention as supported
+
+def get_mla_metadata(
+    cache_seqlens: torch.Tensor,
+    num_heads_per_head_k: int = 1,
+    num_heads_k: int = 1,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    """
+    Arguments:
+        cache_seqlens: (batch_size), dtype torch.int32.
+        num_heads_per_head_k: Equals to seq_len_q * num_heads_q // num_heads_k. Unused on Kunlun.
+        num_heads_k: num_heads_k. Unused on Kunlun.
+
+    Returns (Kunlun): the CUDA (tile_scheduler_metadata, num_splits) slots are
+        repurposed to carry (cache_seqlens on CPU, cache_seqlens on device),
+        which flash_mla_with_kvcache forwards to the XPU paged-attention kernel.
+    """
+    # return flash_mla_cuda.get_mla_metadata(cache_seqlens, num_heads_per_head_k, num_heads_k)
+    cache_seqlens_cpu = cache_seqlens.cpu()  # host copy consumed by xtorch_ops.paged_attention
+    return cache_seqlens_cpu, cache_seqlens
+
+def flash_mla_with_kvcache(
+    q: torch.Tensor,
+    k_cache: torch.Tensor,
+    block_table: torch.Tensor,
+    cache_seqlens: torch.Tensor,
+    head_dim_v: int,
+    tile_scheduler_metadata: torch.Tensor,
+    num_splits: torch.Tensor,
+    softmax_scale: Optional[float] = None,
+    causal: bool = False,
+    descale_q: Optional[torch.Tensor] = None,
+    descale_k: Optional[torch.Tensor] = None,
+    is_fp8_kvcache: bool = False,
+    indices: Optional[torch.Tensor] = None,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    """
+    Arguments:
+        q: (batch_size, seq_len_q, num_heads_q, head_dim).
+        k_cache: (num_blocks, page_block_size, num_heads_k, head_dim).
+        block_table: (batch_size, max_num_blocks_per_seq), torch.int32.
+        cache_seqlens: (batch_size), torch.int32.
+        head_dim_v: Head dimension of v.
+        tile_scheduler_metadata: (num_sm_parts, TileSchedulerMetaDataSize), torch.int32, returned by get_mla_metadata.
+        num_splits: (batch_size + 1), torch.int32, returned by get_mla_metadata.
+        softmax_scale: float. The scale of QK^T before applying softmax. Default to 1 / sqrt(head_dim).
+        causal: bool. Whether to apply causal attention mask.
+
+    Returns:
+        out: (batch_size, seq_len_q, num_heads_q, head_dim_v).
+        softmax_lse: (batch_size, num_heads_q, seq_len_q), torch.float32.
+    """
+    if softmax_scale is None:
+        softmax_scale = q.shape[-1] ** (-0.5)
+
+    softmax_lse = None  # the Kunlun kernel does not produce a log-sum-exp; callers must tolerate None
+    out = torch.ones(q.size(0), q.size(1), q.size(2), head_dim_v, dtype= q.dtype, device=q.device)  # NOTE(review): torch.empty would avoid an init pass if the kernel overwrites every element — confirm
+    kv_lora_rank = head_dim_v
+    qk_rope_head_dim = q.size(3) - head_dim_v
+    head_dim = k_cache.shape[3]
+    page_block_size = k_cache.shape[1]
+    k_cache = k_cache.view(-1, 1, page_block_size, head_dim)
+
+    # todo: optimize memcp
+    # q_c = q[..., : kv_lora_rank].contiguous()
+    # q_r = q[..., kv_lora_rank :].contiguous()
+
+    is_context = False  # decode (non-prefill) path — TODO confirm kernel flag semantics
+    vo_head_dim = -1  # presumably -1 lets the kernel derive the v/o head dim — verify
+
+    xtorch_ops.paged_attention(out,
+                               q,
+                               k_cache, None,
+                               block_table,
+                               tile_scheduler_metadata, # context_lens_cpu
+                               num_splits, # context_lens_xpu
+                               is_context,
+                               causal,
+                               vo_head_dim,
+                               kv_lora_rank,
+                               qk_rope_head_dim,
+                               softmax_scale,
+                               q_r=q)
+    return out, softmax_lse
+
+def kunlun_flash_mla_with_kvcache(
+    q: torch.Tensor,
+    k_cache: torch.Tensor,
+    cache_seqlens: torch.Tensor,
+    cache_seqlens_cpu: torch.Tensor,
+    head_dim_v: int,
+    softmax_scale: Optional[float] = None,
+    causal: bool = False,
+    is_fp8_kvcache: bool = False,
+    indices: Optional[torch.Tensor] = None,
+    max_seq_kv: int = 1,
+) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+    """
+    Arguments:
+        q: (batch_size, seq_len_q, num_heads_q, head_dim).
+        k_cache: (num_tokens_kv, head_dim).
+        cache_seqlens: (batch_size), torch.int32.
+        head_dim_v: Head dimension of v.
+        softmax_scale: float. The scale of QK^T before applying softmax. Default to 1 / sqrt(head_dim).
+        causal: bool. Whether to apply causal attention mask.
+        is_fp8_kvcache: bool. Whether the k_cache and v_cache are in fp8 format.
+        indices: (batch_size, seq_len_q, topk), torch.int32. If not None, sparse attention will be enabled, and only tokens in the `indices` array will be attended to. Invalid indices should be set to -1 or numbers >= total_seq_len_kv.
+        max_seq_kv: maximum kv sequence length across the batch.
+
+    Returns:
+        out: (batch_size, seq_len_q, num_heads_q, head_dim_v).
+        max_logits: (batch_size, seq_len_q, num_heads_q), torch.float32.
+        p_sums: (batch_size, seq_len_q, num_heads_q), torch.float32.
+    """
+    assert not is_fp8_kvcache, "By now, the kernel does not support uint8 kv cache."
+    assert q.shape[1] <= 2, "xtorch_ops.fwd_kvcache_mla only support seq_len_q <= 2 for now."
+    if softmax_scale is None:
+        softmax_scale = q.shape[-1] ** (-0.5)
+    if indices is not None:
+        # NOTE (zyongye): sparse attention is also causal
+        # since it only attend to the tokens before
+        # but here `causal` should not be specified
+        assert not causal, \
+            "causal must be `false` if sparse attention is enabled."
+
+    q_r, pe_cache = None, None  # leaving both None selects the kernel's packed-q mode
+    batch_size, seq_len_q, num_heads_q, head_dim = q.shape
+    kv_lora_rank = head_dim_v
+    rope_head_dim = head_dim - kv_lora_rank
+
+    out = torch.zeros([batch_size, seq_len_q, num_heads_q, kv_lora_rank],
+                      dtype=q.dtype, device=q.device)
+    max_logits = torch.zeros([batch_size, seq_len_q, num_heads_q],
+                             dtype=torch.float32, device=q.device)
+    p_sums = torch.zeros([batch_size, seq_len_q, num_heads_q],
+                         dtype=torch.float32, device=q.device)
+
+    xtorch_ops.fwd_kvcache_mla(
+        q_c=q,
+        kv_cache=k_cache,
+        indices=indices,
+        kv_lod_cpu=cache_seqlens_cpu,
+        max_seq_kv=max_seq_kv,
+        softmax_scale=softmax_scale,
+        # q_r=q_r,
+        # pe_cache=pe_cache,
+        out=out,
+        max_logits=max_logits,
+        p_sums=p_sums,
+        kv_lod_xpu=cache_seqlens,
+    )
+
+    return out, max_logits, p_sums
+
+
+def flash_mla_sparse_prefill(
+    q: torch.Tensor,
+    kv: torch.Tensor,
+    indices: torch.Tensor,
+    sm_scale: float,
+    q_lod_xpu: torch.Tensor,
+    q_lod_cpu: torch.Tensor,
+    d_v: int = 512,
+) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+    """
+    Sparse attention prefill kernel
+
+    Args:
+        - q: [s_q, h_q, d_qk], bfloat16
+        - kv: [s_kv, d_qk], bfloat16
+        - indices: [s_q, h_kv, topk], int32.
+            Invalid indices should be set to -1 or numbers >= s_kv
+        - sm_scale: float
+        - q_lod_xpu: [batch+1], int32, cumulative per-sequence q lengths (batch_num + 1 entries); empty means fixed-length q.
+        - d_v: The dimension of value vectors. Can only be 512
+
+    Returns:
+        - (output, max_logits, lse)
+            About the definition of output,
+            max_logits and lse, please refer to README.md
+            - output: [s_q, h_q, d_v], bfloat16
+            - max_logits: [s_q, h_q], float
+            - lse: [s_q, h_q], float, 2-based log-sum-exp
+    """
+    s_q, h_q, d_qk = q.shape
+
+    out = torch.zeros([s_q, h_q, d_v], dtype=q.dtype, device=q.device)
+    max_logits = torch.zeros([s_q, h_q], dtype=torch.float32, device=q.device)
+    lse = torch.zeros([s_q, h_q], dtype=torch.float32, device=q.device)
+
+    xtorch_ops.sparse_prefill_fwd_opt(
+        q=q,
+        kv=kv,
+        indices=indices,
+        qlod_cpu=q_lod_cpu,
+        qlod_xpu=q_lod_xpu,
+        kvlod_cpu=q_lod_cpu,
+        kvlod_xpu=q_lod_xpu,
+        sm_scale=sm_scale,
+        d_v=d_v,
+        is_causal=True,  # NOTE(review): aiak also passes True here — confirm why causal is forced for sparse prefill
+        out=out,
+        max_logits=max_logits,
+        lse=lse,
+    )
+
+    # NOTE: Compared with torch.ops._flashmla_C.sparse_prefill_fwd,
+    # out_scale = 1 / math.log2(math.e)
+    # gpu_max_logits * out_scale = kunlun_lse
+    # gpu_lse * out_scale = kunlun_lse
+    return out, max_logits, lse
+
+
+#
+# TODO: Add fake functions
+#
+# @register_fake("_flashmla_C::get_mla_metadata")
+# def _get_mla_metadata_fake(....) -> Tuple[torch.Tensor, torch.Tensor]:
+# return ....
+#
+# @register_fake("_flashmla_C::fwd_kvcache_mla")
+# def _fwd_kvcache_mla_fake(....) -> Tuple[torch.Tensor, torch.Tensor]:
+# return ....
+#
\ No newline at end of file
diff --git a/vllm_kunlun/ops/attention/mla.py b/vllm_kunlun/ops/attention/mla.py
new file mode 100644
index 0000000..e50ac2f
--- /dev/null
+++ b/vllm_kunlun/ops/attention/mla.py
@@ -0,0 +1,180 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+from dataclasses import dataclass
+from typing import Optional
+
+import torch
+
+from vllm_kunlun.ops.attention.layer import Attention
+# from vllm.attention import Attention
+from vllm.config import CacheConfig
+from vllm.model_executor.custom_op import CustomOp
+from vllm.model_executor.layers.quantization import QuantizationConfig
+
+
+@dataclass
+class MLAModules:
+    """Modules used in MLA.
+
+    Groups the sub-modules an MLA layer needs so that
+    MultiHeadLatentAttention can be constructed from a single argument
+    instead of a dozen separate module parameters.
+    """
+    kv_a_layernorm: torch.nn.Module  # normalizes the compressed kv (kv_c)
+    kv_b_proj: torch.nn.Module  # kv up-projection; handed to the Attention backend
+    rotary_emb: torch.nn.Module  # applied to the rope slice of q and to k_pe
+    o_proj: torch.nn.Module  # output projection applied to the attention output
+    fused_qkv_a_proj: Optional[torch.nn.Module]  # required when q_lora_rank is set
+    kv_a_proj_with_mqa: Optional[torch.nn.Module]  # required when q_lora_rank is None
+    q_a_layernorm: Optional[torch.nn.Module]  # norm on q_c (q-LoRA path only)
+    q_b_proj: Optional[torch.nn.Module]  # q_c -> q projection (q-LoRA path only)
+    q_proj: Optional[torch.nn.Module]  # direct q projection (no q-LoRA path)
+    indexer: Optional[torch.nn.Module]  # sparse-attn indexer; must expose `topk_tokens`
+    is_sparse: bool  # whether sparse attention is enabled
+    topk_indices_buffer: Optional[torch.Tensor]  # presumably filled by the indexer -- TODO confirm
+
+
+@CustomOp.register("vllm_kunlun_multi_head_latent_attention")
+class MultiHeadLatentAttention(CustomOp):
+    """MLA layer registered as CustomOp.
+    Note that currently MLA ignores the enable/disable mechanism of CustomOp
+    because there is only one in-tree implementation in forward_native.
+    TODO: implement this with a new PluggableLayer mechanism.
+
+    This class takes positions and hidden_states as input.
+    The input tensors can either contain prefill tokens or decode tokens.
+    The class does the following:
+
+    1. MLA Preprocess.
+    2. Perform multi-head attention to prefill tokens and
+    multi-query attention to decode tokens separately.
+    3. Return the output tensor.
+    """
+
+    def __init__(
+        self,
+        hidden_size: int,
+        num_heads: int,
+        scale: float,
+        qk_nope_head_dim: int,
+        qk_rope_head_dim: int,
+        v_head_dim: int,
+        q_lora_rank: Optional[int],
+        kv_lora_rank: int,
+        mla_modules: MLAModules,
+        cache_config: Optional[CacheConfig] = None,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ) -> None:
+        """Build the MLA layer from pre-constructed sub-modules.
+
+        Args:
+            hidden_size: Model hidden dimension.
+            num_heads: Number of attention (query) heads.
+            scale: Softmax scaling factor forwarded to the Attention backend.
+            qk_nope_head_dim: Per-head dim of the non-rotary query/key part.
+            qk_rope_head_dim: Per-head dim of the rotary (positional) part.
+            v_head_dim: Per-head value dimension.
+            q_lora_rank: Rank of the query low-rank projection, or None when
+                queries are projected directly via ``q_proj``.
+            kv_lora_rank: Rank of the compressed KV representation (k_c).
+            mla_modules: Bundle of projection/norm/rotary sub-modules.
+            cache_config: Optional KV-cache configuration.
+            quant_config: Optional quantization configuration.
+            prefix: Layer-name prefix; the inner attention is ``{prefix}.attn``.
+        """
+        super().__init__()
+        self.hidden_size = hidden_size
+        self.qk_nope_head_dim = qk_nope_head_dim
+        self.qk_rope_head_dim = qk_rope_head_dim
+        self.qk_head_dim = qk_nope_head_dim + qk_rope_head_dim
+        self.v_head_dim = v_head_dim
+        self.q_lora_rank = q_lora_rank
+        self.kv_lora_rank = kv_lora_rank
+        self.num_heads = num_heads
+        self.fused_qkv_a_proj = mla_modules.fused_qkv_a_proj
+        self.kv_a_proj_with_mqa = mla_modules.kv_a_proj_with_mqa
+        self.q_a_layernorm = mla_modules.q_a_layernorm
+        self.q_b_proj = mla_modules.q_b_proj
+        self.q_proj = mla_modules.q_proj
+        self.kv_a_layernorm = mla_modules.kv_a_layernorm
+        self.kv_b_proj = mla_modules.kv_b_proj
+        self.rotary_emb = mla_modules.rotary_emb
+        self.o_proj = mla_modules.o_proj
+        self.indexer = mla_modules.indexer
+        self.is_sparse = mla_modules.is_sparse
+
+        # The sparse indexer (if any) must advertise how many tokens it keeps.
+        if self.indexer is not None:
+            assert hasattr(self.indexer, "topk_tokens")
+            self.topk_tokens = self.indexer.topk_tokens
+        self.topk_indices_buffer = mla_modules.topk_indices_buffer
+
+        # In the MLA backend, kv_cache includes both k_c and
+        # pe (i.e. decoupled position embeddings). In particular,
+        # the concat_and_cache_mla op requires
+        # k_c.size(1) + k_pe.size(1) == kv_cache.size(2)
+        # i.e.
+        # kv_lora_rank + qk_rope_head_dim == head_size
+        self.mla_attn = Attention(
+            num_heads=self.num_heads,
+            head_size=self.kv_lora_rank + self.qk_rope_head_dim,
+            scale=scale,
+            num_kv_heads=1,
+            cache_config=cache_config,
+            quant_config=quant_config,
+            prefix=f"{prefix}.attn",
+            use_mla=True,
+            use_sparse=mla_modules.is_sparse,
+            # MLA Args
+            q_lora_rank=self.q_lora_rank,
+            kv_lora_rank=self.kv_lora_rank,
+            qk_nope_head_dim=self.qk_nope_head_dim,
+            qk_rope_head_dim=self.qk_rope_head_dim,
+            qk_head_dim=self.qk_head_dim,
+            v_head_dim=self.v_head_dim,
+            kv_b_proj=self.kv_b_proj,
+            indexer=self.indexer,
+        )
+
+        self.prefix = prefix
+
+    def forward_native(
+        self,
+        positions: torch.Tensor,
+        hidden_states: torch.Tensor,
+    ) -> torch.Tensor:
+        """MLA preprocessing + attention; returns the o_proj output.
+
+        Projects hidden_states to queries (via the LoRA path when
+        q_lora_rank is set, otherwise via q_proj) and to the compressed
+        kv (kv_c) plus decoupled positional key (k_pe), applies rotary
+        embeddings, then runs the MLA attention backend.
+        """
+        q_c = None
+        kv_lora = None
+
+        if self.q_lora_rank is not None:
+            # q-LoRA path: one fused projection yields q_c and kv_lora.
+            assert self.fused_qkv_a_proj is not None, \
+                "fused_qkv_a_proj is required when q_lora_rank is not None"
+            assert self.q_a_layernorm is not None, \
+                "q_a_layernorm is required when q_lora_rank is not None"
+            assert self.q_b_proj is not None, \
+                "q_b_proj is required when q_lora_rank is not None"
+            qkv_lora = self.fused_qkv_a_proj(hidden_states)[0]
+            q_c, kv_lora = qkv_lora.split(
+                [self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim],
+                dim=-1,
+            )
+            q_c = self.q_a_layernorm(q_c)
+            q = self.q_b_proj(q_c)[0]
+        else:
+            # No q-LoRA: queries come straight from q_proj.
+            assert self.kv_a_proj_with_mqa is not None, \
+                "kv_a_proj_with_mqa is required when q_lora_rank is None"
+            assert self.q_proj is not None, \
+                "q_proj is required when q_lora_rank is None"
+            kv_lora = self.kv_a_proj_with_mqa(hidden_states)[0]
+            q = self.q_proj(hidden_states)[0]
+
+        # Split compressed kv from the decoupled positional key.
+        kv_c, k_pe = kv_lora.split([self.kv_lora_rank, self.qk_rope_head_dim],
+                                   dim=-1)
+        kv_c_normed = self.kv_a_layernorm(kv_c)
+
+        q = q.view(-1, self.num_heads, self.qk_head_dim)
+        # Add head dim of 1 to k_pe
+        k_pe = k_pe.unsqueeze(1)
+
+        # Rotary embedding only touches the rope slice of q (and k_pe).
+        q[..., self.qk_nope_head_dim:], k_pe = self.rotary_emb(
+            positions, q[..., self.qk_nope_head_dim:], k_pe)
+
+        if self.indexer and self.is_sparse:
+            # Return value is intentionally unused; the indexer presumably
+            # records its top-k selection in topk_indices_buffer as a side
+            # effect for the attention backend -- TODO confirm.
+            # NOTE(review): on the non-LoRA path q_c is None here; verify the
+            # indexer accepts that.
+            _topk_indices = self.indexer(hidden_states, q_c, positions,
+                                         self.rotary_emb)
+
+        # hidden_states may be a tuple -- presumably a quantized
+        # (values, scale) pair -- in which case the token count comes from
+        # the values tensor. TODO confirm against callers.
+        hidden_states_shape_0 = 0
+        if isinstance(hidden_states, tuple):
+            x_q, x_scale = hidden_states
+            hidden_states_shape_0 = x_q.shape[0]
+        else:
+            hidden_states_shape_0 = hidden_states.shape[0]
+        attn_out = self.mla_attn(
+            q,
+            kv_c_normed,
+            k_pe,
+            output_shape=(hidden_states_shape_0,
+                          self.num_heads * self.v_head_dim))
+        return self.o_proj(attn_out)[0]
+
+    def forward_cuda(self, *args, **kwargs):
+        # Single in-tree implementation: the CUDA/XPU path reuses the
+        # native one (see class docstring).
+        return self.forward_native(*args, **kwargs)
diff --git a/vllm_kunlun/ops/deep_gemm.py b/vllm_kunlun/ops/deep_gemm.py
new file mode 100644
index 0000000..af0640f
--- /dev/null
+++ b/vllm_kunlun/ops/deep_gemm.py
@@ -0,0 +1,114 @@
+import torch
+import xtorch_ops
+
+def int8_mqa_logits(
+ q: torch.Tensor,
+ kv: tuple[torch.Tensor, torch.Tensor],
+ weights: torch.Tensor,
+ cu_seqlen_ks: torch.Tensor,
+ cu_seqlen_ke: torch.Tensor,
+) -> torch.Tensor:
+ """Compute FP8 MQA logits for a single sequence without KV paging.
+
+ Args:
+ q: Query tensor of shape [M, H, D]. Casted to
+ `torch.float8_e4m3fn` by caller.
+ kv: Tuple `(k_fp8, k_scales)` where `k_fp8` has shape [N, D] with
+ dtype `torch.float8_e4m3fn` and `k_scales` has shape [N] (or
+ [N, 1]) with dtype `torch.float32`.
+ weights: weights of shape [M, H], dtype `torch.float32`.
+ cu_seqlen_ks: Start indices (inclusive) for valid K per query position,
+ shape [M], dtype int32.
+ cu_seqlen_ke: End indices (exclusive) for valid K per query position,
+ shape [M], dtype int32.
+
+ Returns:
+ Logits tensor of shape [M, N], dtype `torch.float32`.
+ """
+ logits = torch.empty((q.shape[0], kv[0].shape[0]), dtype=torch.float32, device=q.device)
+ context_q_lens_xpu = torch.tensor([0, q.shape[0]], dtype=torch.int32, device=cu_seqlen_ks.device)
+ context_k_lens_xpu = torch.tensor([0, kv[0].shape[0]], dtype=torch.int32, device=cu_seqlen_ks.device)
+
+ xtorch_ops.I8_mqa_logits(
+ q=q,
+ fused_kv_cache=kv,
+ weights=weights,
+ context_q_lens=(context_q_lens_xpu.cpu(), context_q_lens_xpu),
+ context_k_lens=(context_k_lens_xpu.cpu(), context_k_lens_xpu),
+ logits=logits,
+ clean_logits=True,
+ use_xfa_boost=False,
+ )
+ seq_len_kv = kv[0].shape[0]
+ # mask参考 https://github.com/vllm-project/vllm/blob/v0.11.0/tests/kernels/attention/test_deepgemm_attention.py 的_ref_fp8_mqa_logits函数的实现
+ mask_lo = (torch.arange(0, seq_len_kv, device=cu_seqlen_ks.device)[None, :]
+ >= cu_seqlen_ks[:, None])
+ mask_hi = (torch.arange(0, seq_len_kv, device=cu_seqlen_ke.device)[None, :]
+ < cu_seqlen_ke[:, None])
+ mask = mask_lo & mask_hi
+ logits = logits.masked_fill(~mask, float("-inf"))
+
+ return logits
+
+def int8_paged_mqa_logits(
+ q_fp8: torch.Tensor,
+ kv_cache_fp8: torch.Tensor,
+ weights: torch.Tensor,
+ context_lens: torch.Tensor,
+ context_lens_cpu: torch.Tensor,
+ block_tables: torch.Tensor,
+ schedule_metadata: torch.Tensor,
+ max_model_len: int,
+) -> torch.Tensor:
+ """Compute FP8 MQA logits using paged KV-cache.
+
+ Args:
+ q_fp8: Query tensor of shape [B, next_n, H, D]. Casted to
+ `torch.float8_e4m3fn` by caller.
+ kv_cache_fp8: Paged KV-cache in packed FP8+scale layout with shape
+ [num_blocks, block_size, 1, D+4], dtype `torch.uint8`. The last
+ 4 bytes per (block,pos) store the `float` dequant scale.
+ weights: Tensor of shape [B * next_n, H], dtype `torch.float32`.
+ context_lens: Tensor of shape [B], dtype int32; effective context length
+ for each batch element.
+ block_tables: Tensor of shape [B, max_blocks], dtype int32; maps logical
+ block indices to physical blocks in the paged cache.
+ schedule_metadata: Returned by `get_paged_mqa_logits_metadata`;
+ used to distribute work across SMs.
+ max_model_len: Maximum sequence length used to size the logits output.
+
+ Returns:
+ Logits tensor of shape [B * next_n, max_model_len], dtype
+ `torch.float32`.
+ """
+ batch_size, next_n, _, D = q_fp8.shape
+ num_blocks, block_size, _, _ = kv_cache_fp8.shape
+
+ kv_cache_fp8=kv_cache_fp8.view(num_blocks, -1)
+ k_val = kv_cache_fp8[:,:block_size*D].view(torch.int8)
+ k_val = k_val.view(-1,block_size, 1, D)
+ k_scale_list = []
+ for block_tables_idx in range(block_tables.shape[0]):
+ k_scale_item = kv_cache_fp8[block_tables[block_tables_idx], block_size *
+ D:].view(-1, 4)
+ k_scale_list.append(k_scale_item)
+ k_scale = torch.cat(k_scale_list, dim=0).view(torch.float32).view(-1,max_model_len)
+ kv_cache = [k_val, k_scale]
+
+ weights = weights.view(batch_size,next_n,-1)
+
+ logits = torch.empty((batch_size, next_n, max_model_len), dtype=torch.float32, device=q_fp8.device)
+
+ xtorch_ops.I8_paged_mqa_logits(
+ q=q_fp8,
+ fused_kv_cache=kv_cache,
+ weights=weights,
+ context_lens=[context_lens_cpu, context_lens],
+ block_table=block_tables,
+ max_context_len=max_model_len,
+ clean_logits=True,
+ out=logits,
+ use_xfa_boost=False
+ )
+ logits = logits.view(-1, max_model_len)
+ return logits
\ No newline at end of file
diff --git a/vllm_kunlun/ops/fused_moe/layer.py b/vllm_kunlun/ops/fused_moe/layer.py
index 772fe1a..953fbc0 100644
--- a/vllm_kunlun/ops/fused_moe/layer.py
+++ b/vllm_kunlun/ops/fused_moe/layer.py
@@ -1,37 +1,14 @@
"""layer.py"""
+
+from contextlib import nullcontext
+from typing import Callable, Optional, Union, get_args
+
import torch
-import os
-from typing import Callable, Optional
+from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
+from vllm.model_executor.layers.fused_moe import FusedMoE
+from vllm.model_executor.layers.fused_moe.layer import UnquantizedFusedMoEMethod
-import vllm.envs as envs
-from vllm.config import get_current_vllm_config
-from vllm.forward_context import ForwardContext, get_forward_context
-from vllm.distributed import get_ep_group
-from vllm.distributed.eplb.eplb_state import EplbState
-
-from vllm.model_executor.layers.fused_moe import FusedMoE as VllmFusedMoE
-from vllm.model_executor.layers.fused_moe import FusedMoEMethodBase as VllmFusedMoEMethodBase
-from vllm.model_executor.layers.fused_moe.layer import (
- UnquantizedFusedMoEMethod as VllmUnquantizedFusedMoEMethod)
-from vllm.model_executor.layers.quantization.base_config import (
- QuantizationConfig, QuantizeMethodBase)
-from vllm.model_executor.layers.fused_moe.config import (
- FusedMoEConfig, FusedMoEParallelConfig)
-
-from vllm.model_executor.custom_op import CustomOp
-from vllm.platforms import current_platform
-
-from vllm_kunlun.ops.quantization.compressed_tensors_moe import CompressedTensorsW8A8Int8MoEMethod
-
-
-class FusedMoEMethodBase(VllmFusedMoEMethodBase):
- """FusedMoEMethodBase"""
- moe: FusedMoEConfig
-
-@CustomOp.register("vllm_kunlun_unquantized_fused_moe")
-class UnquantizedFusedMoEMethod(VllmUnquantizedFusedMoEMethod):
- """UnquantizedFusedMoEMethod"""
- def apply(
+def apply(
self,
layer: torch.nn.Module,
x: torch.Tensor,
@@ -45,6 +22,7 @@ class UnquantizedFusedMoEMethod(VllmUnquantizedFusedMoEMethod):
expert_map: Optional[torch.Tensor] = None,
custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax",
+ routed_scaling_factor: float = 1.0,
e_score_correction_bias: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False,
activation: str = "silu",
@@ -52,40 +30,12 @@ class UnquantizedFusedMoEMethod(VllmUnquantizedFusedMoEMethod):
expert_load_view: Optional[torch.Tensor] = None,
logical_to_physical_map: Optional[torch.Tensor] = None,
logical_replica_count: Optional[torch.Tensor] = None,
- linear_weights: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""apply"""
if enable_eplb:
raise NotImplementedError(
"EPLB not supported for `UnquantizedFusedMoEMethod` yet.")
-
- return self.forward_kunlun(x=x,
- layer=layer,
- router_logits=router_logits,
- top_k=top_k,
- renormalize=renormalize,
- use_grouped_topk=use_grouped_topk,
- topk_group=topk_group,
- num_expert_group=num_expert_group,
- custom_routing_function=custom_routing_function,
- linear_weights=linear_weights,
- e_score_correction_bias=e_score_correction_bias)
-
- def forward_kunlun(
- self,
- layer: torch.nn.Module,
- x: torch.Tensor,
- use_grouped_topk: bool,
- top_k: int,
- router_logits: torch.Tensor,
- linear_weights: torch.Tensor,
- renormalize: bool,
- topk_group: Optional[int] = None,
- num_expert_group: Optional[int] = None,
- custom_routing_function: Optional[Callable] = None,
- scoring_func: str = "softmax",
- e_score_correction_bias: Optional[torch.Tensor] = None
- ) -> torch.Tensor:
+
"""forward_kunlun"""
from vllm_kunlun.ops._kunlun_ops import KunlunOps as ops
if self.moe.use_ep:
@@ -93,21 +43,18 @@ class UnquantizedFusedMoEMethod(VllmUnquantizedFusedMoEMethod):
layer.w13_weight,
layer.w2_weight,
router_logits,
- linear_weights,
self.moe.ep_rank,
top_k,
renormalize=renormalize,
inplace=True,
use_grouped_topk=use_grouped_topk,
num_expert_group=num_expert_group,
- topk_group=topk_group
- )
+ topk_group=topk_group)
else:
return ops.fused_moe(x,
layer.w13_weight,
layer.w2_weight,
router_logits,
- linear_weights,
self.moe.ep_rank,
top_k,
renormalize=renormalize,
@@ -118,12 +65,13 @@ class UnquantizedFusedMoEMethod(VllmUnquantizedFusedMoEMethod):
scoring_func=scoring_func,
e_score_correction_bias=e_score_correction_bias,
w1_bias = layer.w13_bias,
- w2_bias = layer.w2_bias,
- )
+ w2_bias = layer.w2_bias)
-class FusedMoE(VllmFusedMoE):
- """FusedMoE"""
- def __init__(self,
+UnquantizedFusedMoEMethod.apply = apply
+
+class VllmFusedMoE(FusedMoE):
+ def __init__(
+ self,
num_experts: int, # Global number of experts
top_k: int,
hidden_size: int,
@@ -141,198 +89,47 @@ class FusedMoE(VllmFusedMoE):
prefix: str = "",
custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax",
+ routed_scaling_factor: float = 1.0,
e_score_correction_bias: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False,
activation: str = "silu",
enable_eplb: bool = False,
num_redundant_experts: int = 0,
- is_sequence_parallel=False,
has_bias: bool = False,
+ is_sequence_parallel=False,
+ zero_expert_num: Optional[int] = 0,
+ zero_expert_type: Optional[str] = None,
):
super().__init__(
- num_experts=num_experts, # Global number of experts
- top_k=top_k,
- hidden_size=hidden_size,
- intermediate_size=intermediate_size,
- params_dtype=params_dtype,
- reduce_results=reduce_results,
- renormalize=renormalize,
- use_grouped_topk=use_grouped_topk,
- num_expert_group=num_expert_group,
- topk_group=topk_group,
- quant_config=quant_config,
- tp_size=tp_size,
- ep_size=ep_size,
- dp_size=dp_size,
- prefix=prefix,
- custom_routing_function=custom_routing_function,
- scoring_func=scoring_func,
- e_score_correction_bias=e_score_correction_bias,
- apply_router_weight_on_input=apply_router_weight_on_input,
- activation=activation,
- enable_eplb=enable_eplb,
- num_redundant_experts=num_redundant_experts,
- )
-
- vllm_config = get_current_vllm_config()
- if vllm_config.model_config is not None:
- model_dtype = vllm_config.model_config.dtype
- else:
- # TODO (bnell): This is a hack to get test_mixtral_moe to work
- # since model_config is not set in the pytest test.
- model_dtype = params_dtype
-
- moe = FusedMoEConfig(
- num_experts=self.global_num_experts,
- experts_per_token=top_k,
- hidden_dim=hidden_size,
- num_local_experts=self.local_num_experts,
- moe_parallel_config=self.moe_parallel_config,
- in_dtype=model_dtype,
- max_num_tokens=envs.VLLM_MOE_DP_CHUNK_SIZE,
+ num_experts=num_experts, # Global number of experts
+ top_k=top_k,
+ hidden_size=hidden_size,
+ intermediate_size=intermediate_size,
+ params_dtype=params_dtype,
+ reduce_results=reduce_results,
+ renormalize=renormalize,
+ use_grouped_topk=use_grouped_topk,
+ num_expert_group=num_expert_group,
+ topk_group=topk_group,
+ quant_config=quant_config,
+ tp_size=tp_size,
+ ep_size=ep_size,
+ dp_size=dp_size,
+ prefix=prefix,
+ custom_routing_function=custom_routing_function,
+ scoring_func=scoring_func,
+ routed_scaling_factor=routed_scaling_factor,
+ e_score_correction_bias=e_score_correction_bias,
+ apply_router_weight_on_input=apply_router_weight_on_input,
+ activation=activation,
+ enable_eplb=enable_eplb,
+ num_redundant_experts=num_redundant_experts,
has_bias=has_bias,
- # quant_config=quant_config,
- )
- self.moe_config = moe
- self.quant_config = quant_config
+ is_sequence_parallel=is_sequence_parallel,
+ zero_expert_num=zero_expert_num,
+ zero_expert_type=zero_expert_type)
self.has_bias=has_bias
self.register_parameter("w13_bias", None)
self.register_parameter("w2_bias", None)
-
- # Note: get_quant_method will look at the layer's local_num_experts
- # for heuristic purposes, so it must be initialized first.
- quant_method: Optional[QuantizeMethodBase] = None
- quant_method = (UnquantizedFusedMoEMethod(moe) if quant_config is None
- else quant_config.get_quant_method(self, prefix))
- assert quant_method is not None
- # assert isinstance(quant_method, FusedMoEMethodBase)
- self.quant_method = quant_method
-
- if self.enable_eplb:
- from vllm_kunlun.ops.quantization.fp8 import (
- Fp8MoEMethod)
- if not isinstance(quant_method, Fp8MoEMethod):
- # TODO: Add support for additional quantization methods.
- # The implementation for other quantization methods does not
- # contain essential differences, but the current quant API
- # design causes duplicated work when extending to new
- # quantization methods, so I'm leaving it for now.
- # If you plan to add support for more quantization methods,
- # please refer to the implementation in `Fp8MoEMethod`.
- raise NotImplementedError("EPLB is only supported for FP8 "
- "quantization for now.")
-
- moe_quant_params = {
- "num_experts": self.local_num_experts,
- "hidden_size": hidden_size,
- "intermediate_size_per_partition":
- self.intermediate_size_per_partition,
- "params_dtype": params_dtype,
- "weight_loader": self.weight_loader,
- }
- # need full intermediate size pre-sharding for WNA16 act order
- if (self.quant_method.__class__.__name__
- in ("GPTQMarlinMoEMethod",
- "CompressedTensorsWNA16MarlinMoEMethod",
- "CompressedTensorsWNA16MoEMethod")):
- moe_quant_params["intermediate_size_full"] = intermediate_size
-
- self.quant_method.create_weights(layer=self, **moe_quant_params)
-
- def forward(self, hidden_states: torch.Tensor,
- router_logits: torch.Tensor = None,
- linear_weights: torch.Tensor = None):
- """forward"""
- # TODO: Once the OOM issue for the TPU backend is resolved, we will
- # switch to using the moe_forward custom op.
- if current_platform.is_tpu():
- return self.forward_impl(hidden_states, router_logits)
- else:
- forward_context: ForwardContext = get_forward_context()
- self = forward_context.no_compile_layers[self.layer_name]
- assert self.quant_method is not None
- return self.forward_impl(hidden_states, router_logits, linear_weights)
- # return torch.ops.vllm.moe_forward(hidden_states, router_logits,
- # self.layer_name)
-
- def forward_impl(self, hidden_states: torch.Tensor,
- router_logits: torch.Tensor,
- linear_weights: torch.Tensor = None):
- """forward_impl"""
- assert self.quant_method is not None
- if (self.moe_parallel_config.use_pplx_kernels
- or self.moe_parallel_config.use_deepep_ll_kernels):
- return self.forward_impl_chunked(hidden_states, router_logits)
-
- do_naive_dispatch_combine: bool = (
- self.dp_size > 1
- and not self.moe_parallel_config.use_deepep_ht_kernels)
- if do_naive_dispatch_combine:
- hidden_states, router_logits = get_ep_group().dispatch(
- hidden_states, router_logits)
-
- # Matrix multiply.
- final_hidden_states = self.quant_method.apply(
- layer=self,
- x=hidden_states,
- router_logits=router_logits,
- top_k=self.top_k,
- renormalize=self.renormalize,
- use_grouped_topk=self.use_grouped_topk,
- global_num_experts=self.global_num_experts,
- expert_map=self.expert_map,
- topk_group=self.topk_group,
- num_expert_group=self.num_expert_group,
- custom_routing_function=self.custom_routing_function,
- scoring_func=self.scoring_func,
- e_score_correction_bias=self.e_score_correction_bias,
- activation=self.activation,
- apply_router_weight_on_input=self.apply_router_weight_on_input,
- enable_eplb=self.enable_eplb,
- expert_load_view=self.expert_load_view,
- logical_to_physical_map=self.logical_to_physical_map,
- logical_replica_count=self.logical_replica_count,
- linear_weights=linear_weights
- )
-
- if do_naive_dispatch_combine:
- final_hidden_states = get_ep_group().combine(final_hidden_states)
-
- if self.reduce_results and (self.tp_size > 1 or self.ep_size > 1):
- # Default set to False. (May have to add shared expert outputs.
- final_hidden_states = self.maybe_all_reduce_tensor_model_parallel(
- final_hidden_states)
-
- return final_hidden_states
- @classmethod
- def make_expert_params_mapping(
- cls,
- ckpt_gate_proj_name: str,
- ckpt_down_proj_name: str,
- ckpt_up_proj_name: str,
- num_experts: int,
- num_redundant_experts: int = 0) -> list[tuple[str, str, int, str]]:
-
- num_physical_experts = num_experts + num_redundant_experts
-
- # In the returned mapping:
- # - `expert_id` is the physical expert id
- # - `weight_name` contains the weight name of the logical expert
- # So that we should map the expert id to logical in `weight_name`
- physical_to_logical_map = \
- EplbState.build_initial_global_physical_to_logical_map(
- num_experts, num_redundant_experts)
-
- return [
- # (param_name, weight_name, expert_id, shard_id)
- ("experts.w13_" if weight_name
- in [ckpt_gate_proj_name, ckpt_up_proj_name] else "experts.w2_",
- f"experts.{physical_to_logical_map[expert_id]}.{weight_name}.",
- expert_id, shard_id) for expert_id in range(num_physical_experts)
- for shard_id, weight_name in [
- ("w1", ckpt_gate_proj_name),
- ("w2", ckpt_down_proj_name),
- ("w3", ckpt_up_proj_name),
- ]
- ]
+FusedMoE=VllmFusedMoE
\ No newline at end of file
diff --git a/vllm_kunlun/ops/quantization/compressed_tensors_moe.py b/vllm_kunlun/ops/quantization/compressed_tensors_moe.py
index 7b06bc5..7a73e8b 100644
--- a/vllm_kunlun/ops/quantization/compressed_tensors_moe.py
+++ b/vllm_kunlun/ops/quantization/compressed_tensors_moe.py
@@ -1,244 +1,169 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+import enum
+from enum import Enum
+from typing import Callable, Optional, Union
+
import torch
-from typing import Any, Literal, Optional, cast, Callable, Optional
+from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors_moe import CompressedTensorsW8A8Int8MoEMethod
-from compressed_tensors.config import (CompressionFormat,
- SparsityCompressionConfig,
- SparsityStructure)
-from compressed_tensors.quantization import (ActivationOrdering,
- QuantizationStrategy)
-from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEMethodBase,
- FusedMoeWeightScaleSupported)
-from vllm.model_executor.layers.quantization.utils import replace_parameter
-# TODO: import position will be changed after 0.9.0
-# vllm.model_executor.layers.fused_moe.fused_moe --> vllm.model_executor.layers.fused_moe
+def klx_process_weights_after_loading(layer: torch.nn.Module) -> None:
+ """modify scale -> abs max"""
+ layer.w13_weight = torch.nn.Parameter(layer.w13_weight, requires_grad=False)
+ layer.w2_weight = torch.nn.Parameter(layer.w2_weight, requires_grad=False)
+ layer.w13_weight_scale = torch.nn.Parameter(
+ layer.w13_weight_scale.data * 127, requires_grad=False
+ )
+ layer.w2_weight_scale = torch.nn.Parameter(
+ layer.w2_weight_scale.data * 127, requires_grad=False
+ )
-from vllm.model_executor.utils import set_weight_attrs
-import re
-import xtorch_ops
+def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
+ klx_process_weights_after_loading(layer)
+def apply(
+ self,
+ layer: torch.nn.Module,
+ x: torch.Tensor,
+ router_logits: torch.Tensor,
+ top_k: int,
+ renormalize: bool,
+ use_grouped_topk: bool = False,
+ topk_group: Optional[int] = None,
+ num_expert_group: Optional[int] = None,
+ global_num_experts: int = -1,
+ expert_map: Optional[torch.Tensor] = None,
+ custom_routing_function: Optional[Callable] = None,
+ scoring_func: str = "softmax",
+ routed_scaling_factor: float = 1.0,
+ e_score_correction_bias: Optional[torch.Tensor] = None,
+ apply_router_weight_on_input: bool = False,
+ activation: str = "silu",
+ enable_eplb: bool = False,
+ expert_load_view: Optional[torch.Tensor] = None,
+ logical_to_physical_map: Optional[torch.Tensor] = None,
+ logical_replica_count: Optional[torch.Tensor] = None,
+) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
+ hidden_states = x
+ global_num_experts, up_gate_size, _ = layer.w13_weight.shape
+ M, N = hidden_states.shape
+ hidden_dim = layer.w2_weight.shape[1]
+ normed_score = torch.empty(M,
+ top_k,
+ dtype=torch.float32,
+ device=hidden_states.device)
+ topk_ids = torch.empty(M,
+ top_k,
+ dtype=torch.int32,
+ device=hidden_states.device)
+ num_blocks = 12
+ block_statistic = torch.zeros(
+ num_blocks, global_num_experts, dtype=torch.int32, device=hidden_states.device
+ )
-from safetensors.torch import load_file as safe_load_file
-
-class CompressedTensorsMoEMethod(FusedMoEMethodBase):
-
- def get_moe_method(quant_config, layer) -> "CompressedTensorsMoEMethod":
- tsm = getattr(quant_config, "target_scheme_map", None) or {}
- linear_cfg = None
- for k in ("Linear", "FusedMoE", "MoE", "Moe", "Experts"):
- if k in tsm and isinstance(tsm[k], dict):
- linear_cfg = tsm[k]; break
- if not linear_cfg:
- # print("target_scheme_map missing; fallback to INT8(W8A8) method")
- return CompressedTensorsW8A8Int8MoEMethod(quant_config)
- wq = linear_cfg.get("weights"); aq = linear_cfg.get("input_activations")
- if not wq or not aq:
- # print("incomplete scheme; fallback to INT8(W8A8)")
- return CompressedTensorsW8A8Int8MoEMethod(quant_config)
- # 其它分流按需;默认回落:
- return CompressedTensorsW8A8Int8MoEMethod(quant_config)
-
-# copied from vllm 0.9.0
-class CompressedTensorsW8A8Int8MoEMethod(CompressedTensorsMoEMethod):
-
- def __init__(
- self,
- quant_config: "CompressedTensorsConfig" # type: ignore # noqa E501
- ):
- self.quant_config = quant_config
-
- # 直接创建默认的量化配置字典,避免 QuantizationArgs 的验证问题
- # print("Creating default INT8 quantization config for MoE")
-
- # 创建默认的权重量化配置字典
- self.weight_quant = type('WeightQuant', (), {
- 'type': 'int',
- 'num_bits': 8,
- 'strategy': 'channel',
- 'group_size': 128,
- 'symmetric': True,
- 'dynamic': False,
- 'actorder': 'none',
- 'observer': None,
- 'observer_kwargs': {},
- 'block_structure': None
- })()
-
- # 创建默认的输入激活量化配置字典
- self.input_quant = type('InputQuant', (), {
- 'type': 'int',
- 'num_bits': 8,
- 'strategy': 'token',
- 'group_size': 128,
- 'symmetric': True,
- 'dynamic': True,
- 'actorder': 'none',
- 'observer': None,
- 'observer_kwargs': {},
- 'block_structure': None
- })()
-
- # 修改比较方式,直接比较字符串
- per_channel = (
- self.weight_quant.strategy == "channel"
- and self.input_quant.strategy == "token")
- if not per_channel:
- raise ValueError(
- "For INT8 Fused MoE layers, we require channelwise, "
- "dynamic per token quantization. Found "
- f"{self.weight_quant}, {self.input_quant}")
-
- self.static_input_scales = not self.input_quant.dynamic
- if self.static_input_scales:
- raise ValueError(
- "For INT8 Fused MoE layers, we require channelwise, "
- "dynamic per token quantization. Found static input scales.")
-
- def create_weights1(self, layer: torch.nn.Module, num_experts: int, hidden_size: int, intermediate_size_per_partition: int, params_dtype: torch.dtype, **extra_weight_attrs):
- # 权重先用浮点占位,便于从 ckpt 加载原始权重
- w13_weight = torch.nn.Parameter(torch.empty(
- num_experts,
- 2 * intermediate_size_per_partition,
- hidden_size,
- dtype=params_dtype), # 通常是 torch.bfloat16
- requires_grad=False)
- layer.register_parameter("w13_weight", w13_weight)
- set_weight_attrs(w13_weight, extra_weight_attrs)
-
- w2_weight = torch.nn.Parameter(torch.empty(
- num_experts,
- hidden_size,
- intermediate_size_per_partition,
- dtype=params_dtype),
- requires_grad=False)
- layer.register_parameter("w2_weight", w2_weight)
- set_weight_attrs(w2_weight, extra_weight_attrs)
-
- # 通道 scale:float32 + 二维 [E, out](与 fused_moe/UT 对齐)
- w13_weight_scale = torch.nn.Parameter(
- torch.empty(num_experts, 2 * intermediate_size_per_partition, dtype=torch.float32),
- requires_grad=False)
- w2_weight_scale = torch.nn.Parameter(
- torch.empty(num_experts, hidden_size, dtype=torch.float32),
- requires_grad=False)
- layer.register_parameter("w13_weight_scale", w13_weight_scale)
- layer.register_parameter("w2_weight_scale", w2_weight_scale)
-
- # 输入 scale 动态计算即可
- layer.w13_input_scale = None
- layer.w2_input_scale = None
-
- def create_weights(self, layer: torch.nn.Module, num_experts: int, hidden_size: int, intermediate_size_per_partition: int, params_dtype: torch.dtype, **extra_weight_attrs):
- w13_weight = torch.nn.Parameter(torch.empty(
- num_experts,
- 2 * intermediate_size_per_partition,
- hidden_size,
- dtype=torch.int8), # 直接使用 int8
- requires_grad=False)
- layer.register_parameter("w13_weight", w13_weight)
- set_weight_attrs(w13_weight, extra_weight_attrs)
-
- w2_weight = torch.nn.Parameter(torch.empty(
- num_experts,
- hidden_size,
- intermediate_size_per_partition,
- dtype=torch.int8), # 直接使用 int8
- requires_grad=False)
- layer.register_parameter("w2_weight", w2_weight)
- set_weight_attrs(w2_weight, extra_weight_attrs)
-
- # 缩放因子
- w13_weight_scale = torch.nn.Parameter(
- torch.empty(num_experts, 2 * intermediate_size_per_partition, dtype=torch.float32),
- requires_grad=False)
- w2_weight_scale = torch.nn.Parameter(
- torch.empty(num_experts, hidden_size, dtype=torch.float32),
- requires_grad=False)
- layer.register_parameter("w13_weight_scale", w13_weight_scale)
- layer.register_parameter("w2_weight_scale", w2_weight_scale)
-
- # 输入 scale 动态计算
- layer.w13_input_scale = None
- layer.w2_input_scale = None
-
- @torch.no_grad()
- def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
- return
- #原始权重转 float32 做统计更稳健
- w13_f = layer.w13_weight.float()
- w2_f = layer.w2_weight.float()
-
- # 每列(abs_max) -> per-column scale(out 维在 dim=1,列在 dim=-1)
- qmax = 127.0
- w13_abs_max = torch.amax(torch.abs(w13_f), dim=-1) # [E, 2N]
- w2_abs_max = torch.amax(torch.abs(w2_f), dim=-1) # [E, H]
-
- w13_scale_2d = torch.clamp(w13_abs_max, min=1e-6) / qmax # [E, 2N], float32
- w2_scale_2d = torch.clamp(w2_abs_max, min=1e-6) / qmax # [E, H], float32
-
- # 量化:用 3D scale 广播,存回 2D scale
- w13_scale_3d = w13_scale_2d.unsqueeze(-1) # [E, 2N, 1]
- w2_scale_3d = w2_scale_2d.unsqueeze(-1) # [E, H, 1]
-
- w13_q = torch.round(w13_f / w13_scale_3d).clamp_(-128, 127).to(torch.int8)
- w2_q = torch.round(w2_f / w2_scale_3d ).clamp_(-128, 127).to(torch.int8)
-
- # 可选:若你的 fused/kernel 期望 scale 预乘 127(与某些 UT 后端一致),打开下面两行:
- w13_scale_2d = w13_scale_2d * 127.0
- w2_scale_2d = w2_scale_2d * 127.0
-
- # 回写参数:权重 int8;scale 用 float32 + 2D
- replace_parameter(layer, 'w13_weight', torch.nn.Parameter(w13_q, requires_grad=False))
- replace_parameter(layer, 'w2_weight', torch.nn.Parameter(w2_q, requires_grad=False))
- replace_parameter(layer, 'w13_weight_scale',
- torch.nn.Parameter(w13_scale_2d.contiguous(), requires_grad=False))
- replace_parameter(layer, 'w2_weight_scale',
- torch.nn.Parameter(w2_scale_2d.contiguous(), requires_grad=False))
-
- # 简要检查
- print(f"w13: {w13_q.shape}, w13_s: {w13_scale_2d.shape}, w2: {w2_q.shape}, w2_s: {w2_scale_2d.shape}")
-
- def apply(
- self,
- layer: torch.nn.Module,
- x: torch.Tensor,
- router_logits: torch.Tensor,
- top_k: int,
- renormalize: bool,
- use_grouped_topk: bool = False,
- topk_group: Optional[int] = None,
- num_expert_group: Optional[int] = None,
- global_num_experts: int = -1,
- expert_map: Optional[torch.Tensor] = None,
- custom_routing_function: Optional[Callable] = None,
- scoring_func: str = "softmax",
- e_score_correction_bias: Optional[torch.Tensor] = None,
- apply_router_weight_on_input: bool = False,
- activation: str = "silu",
- enable_eplb: bool = False, # 添加这个参数
- expert_load_view: Optional[torch.Tensor] = None, # 添加这个参数
- logical_to_physical_map: Optional[torch.Tensor] = None, # 添加这个参数
- logical_replica_count: Optional[torch.Tensor] = None, # 添加这个参数
- linear_weights: Optional[torch.Tensor] = None, # 添加这个参数
- ) -> torch.Tensor:
-
- output = torch.empty_like(x)
- torch.ops._C.moe_ffn_per_token_block(
- x=x,
- inter_weight=layer.w13_weight,
- inter_scale=layer.w13_weight_scale,
- outer_weight=layer.w2_weight,
- outer_scale=layer.w2_weight_scale,
- top_k=top_k,
- global_num_experts=global_num_experts,
- linear_weights=linear_weights,
- expert_map=expert_map,
- activation=activation,
- output=output,
- use_expert_parallel=expert_map is not None,
- ep_size=expert_map.size(0) if expert_map is not None else 1,
- ep_rank=0,
+ router_logits = router_logits.float()
+ if scoring_func == "softmax":
+ torch.ops._C.moe_softmax_topk_norm(
+ x=router_logits,
+ normed_score=normed_score,
+ topk_index=topk_ids,
+ block_statistic=None,
+ stable=True)
+ elif scoring_func == "sigmoid":
+ torch.ops._C.moe_sigmoid_group_topk_norm(
+ x=router_logits,
+ norm_score=normed_score,
+ topk_index=topk_ids,
+ block_static=block_statistic,
+ bias=e_score_correction_bias,
+ n_group=num_expert_group,
+ topk_group=topk_group,
+ scale=routed_scaling_factor,
)
- return output
-print("[Monkey Patch Applied] >>> vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors_moe.CompressedTensorsMoEMethod \
- --> vllm_xpu.model_executor.layers.quantization.compressed_tensors_moe.py:CompressedTensorsMoEMethod")
\ No newline at end of file
+    moe_expand = torch.empty((M * top_k, N), dtype=hidden_states.dtype, device=hidden_states.device) # [M * top_k, N], float
+ expert_m = torch.zeros(global_num_experts, dtype=torch.int32, device=hidden_states.device) # [E]
+ sorted_tokens_num_lod = torch.zeros(global_num_experts + 1, dtype=torch.int32, device=hidden_states.device) # [E+1]
+ sorted_tokens_idx = torch.zeros(M * top_k, dtype=torch.int32, device=hidden_states.device)
+
+ torch.ops._C.gen_block_statistic(topk_ids,block_statistic)
+
+ torch.ops._C.moe_pre_sorted(
+ x=hidden_states,
+ topk_index=topk_ids,
+ block_statistic=block_statistic,
+ moe_expand=moe_expand,
+ moe_index=sorted_tokens_idx,
+ expert_m=expert_m,
+ sorted_tokens_num_lod=sorted_tokens_num_lod)
+
+ y = torch.empty(M,top_k,
+ layer.w13_weight.shape[1],
+ dtype=hidden_states.dtype,
+ device=hidden_states.device)
+
+ moe_expand = moe_expand.view(M * top_k, hidden_dim)
+
+ x_shape = moe_expand.shape
+ x_q = torch.empty(x_shape, dtype=torch.int8, device=moe_expand.device)
+ x_scale = torch.empty((x_shape[0], 1), dtype=torch.float32, device=moe_expand.device)
+ torch.ops._C.quant2d(moe_expand, x_q, x_scale, force_sdnn=True)
+
+ torch.ops._C.moe_fc(
+ x=x_q,
+ x_perchannel_max=x_scale,
+ weight=layer.w13_weight,
+ w_perchannel_max=layer.w13_weight_scale,
+ sorted_tokens_num_lod=sorted_tokens_num_lod,
+ sorted_tokens_idx=sorted_tokens_idx,
+ moe_topk=top_k,
+ y=y,
+ topk_ids=topk_ids,
+ # sort_mode=False,
+ act=None)
+
+ d = y.shape[-1] // 2
+ output_shape = (y.shape[:-1] + (d, ))
+ out1 = torch.empty(output_shape, dtype=y.dtype, device=y.device)
+ torch.ops._C.silu_and_mul(out1, y)
+
+ out = torch.empty(M,top_k,
+ layer.w2_weight.shape[1],
+ dtype=hidden_states.dtype,
+ device=hidden_states.device)
+
+ out1 = out1.reshape(-1, out1.shape[-1])
+ x_shape = out1.shape
+ x_q = torch.empty(x_shape, dtype=torch.int8, device=moe_expand.device)
+ x_scale = torch.empty((x_shape[0], 1), dtype=torch.float32, device=moe_expand.device)
+ torch.ops._C.quant2d(out1, x_q, x_scale, force_sdnn=True)
+
+ torch.ops._C.moe_fc(
+ x=x_q,
+ x_perchannel_max=x_scale,
+ weight=layer.w2_weight,
+ w_perchannel_max=layer.w2_weight_scale,
+ sorted_tokens_num_lod=sorted_tokens_num_lod,
+ sorted_tokens_idx=sorted_tokens_idx,
+ moe_topk=top_k,
+ y=out,
+ topk_ids=topk_ids,
+ # sort_mode=False,
+ act=None)
+
+ dequant_scale = torch.ones([M, top_k], dtype = torch.float32, device=out.device)
+ output = torch.empty([M, N], dtype=hidden_states.dtype, device=hidden_states.device)
+ sorted_tokens_idx = sorted_tokens_idx.view(M, top_k)
+
+ torch.ops._C.moe_post(
+ x=out,
+ moe_index=sorted_tokens_idx,
+ normed_scale=normed_score,
+ dequant_scale=dequant_scale,
+ y=output
+ )
+ return output
+
+CompressedTensorsW8A8Int8MoEMethod.process_weights_after_loading = process_weights_after_loading
+CompressedTensorsW8A8Int8MoEMethod.apply = apply
\ No newline at end of file
diff --git a/vllm_kunlun/ops/quantization/kernels/__init__.py b/vllm_kunlun/ops/quantization/kernels/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/vllm_kunlun/ops/quantization/kernels/scaled_mm/__init__.py b/vllm_kunlun/ops/quantization/kernels/scaled_mm/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/vllm_kunlun/ops/quantization/kernels/scaled_mm/cutlass.py b/vllm_kunlun/ops/quantization/kernels/scaled_mm/cutlass.py
new file mode 100644
index 0000000..25a3add
--- /dev/null
+++ b/vllm_kunlun/ops/quantization/kernels/scaled_mm/cutlass.py
@@ -0,0 +1,122 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+from typing import Optional
+
+import torch
+
+from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKernel import ScaledMMLinearLayerConfig
+from vllm.model_executor.layers.quantization.kernels.scaled_mm.cutlass import CutlassScaledMMLinearKernel
+from vllm.model_executor.layers.quantization.utils import replace_parameter
+from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
+ convert_to_channelwise)
+
+def can_implement_kunlun(
+ cls, c: ScaledMMLinearLayerConfig=None) -> tuple[bool, Optional[str]]:
+ return True, None
+
+def klx_process_weights_after_loading(layer: torch.nn.Module) -> None:
+ """modify scale -> abs max"""
+ layer.weight = torch.nn.Parameter(layer.weight.data, requires_grad=False)
+ layer.weight_scale = torch.nn.Parameter(
+ layer.weight_scale.data * 127, requires_grad=False)
+
+def process_weights_after_loading_kunlun(self, layer: torch.nn.Module) -> None:
+ # WEIGHT
+ # Cutlass kernels need transposed weight.
+ weight = getattr(layer, self.w_q_name)
+ replace_parameter(
+ layer, self.w_q_name,
+ torch.nn.Parameter(weight.t().data, requires_grad=False))
+
+ # WEIGHT SCALE
+ # Cutlass kernels support only per-tensor and per-channel.
+ # If we have a fused module (QKV, MLP) with per tensor scales (thus N
+ # scales being passed to the kernel), convert to the per-channel case.
+ is_fused_module = len(layer.logical_widths) > 1
+ weight_scale = getattr(layer, self.w_s_name)
+ if is_fused_module and not self.config.is_channelwise:
+ weight_scale = convert_to_channelwise(weight_scale,
+ layer.logical_widths)
+ replace_parameter(
+ layer, self.w_s_name,
+ torch.nn.Parameter(weight_scale.data, requires_grad=False))
+
+ # INPUT SCALE
+ if self.config.is_static_input_scheme:
+ input_scale = getattr(layer, self.i_s_name)
+
+ if self.config.input_symmetric:
+ replace_parameter(
+ layer, self.i_s_name,
+ torch.nn.Parameter(input_scale.max(), requires_grad=False))
+ setattr(layer, self.i_zp_name, None)
+ else:
+ input_zero_point = getattr(layer, self.i_zp_name)
+
+ # reconstruct the ranges
+ int8_traits = torch.iinfo(torch.int8)
+ azps = input_zero_point.to(dtype=torch.int32)
+ range_max = (input_scale * (int8_traits.max - azps)).max()
+ range_min = (input_scale * (int8_traits.min - azps)).min()
+
+ scale = (range_max - range_min) / (int8_traits.max -
+ int8_traits.min)
+ replace_parameter(
+ layer, self.i_s_name,
+ torch.nn.Parameter(scale, requires_grad=False))
+
+ # AZP loaded as int8 but used as int32
+ azp = (int8_traits.min -
+ range_min / scale).to(dtype=torch.int32)
+ replace_parameter(layer, self.i_zp_name,
+ torch.nn.Parameter(azp, requires_grad=False))
+
+ else:
+ setattr(layer, self.i_s_name, None)
+ setattr(layer, self.i_zp_name, None)
+
+ # azp_adj is the AZP adjustment term, used to account for weights.
+ # It does not depend on scales or azp, so it is the same for
+ # static and dynamic quantization.
+ # For more details, see csrc/quantization/cutlass_w8a8/Epilogues.md
+ # https://github.com/vllm-project/vllm/blob/8d59dbb00044a588cab96bcdc028006ed922eb06/csrc/quantization/cutlass_w8a8/Epilogues.md
+ if not self.config.input_symmetric:
+ weight = getattr(layer, self.w_q_name)
+ azp_adj = weight.sum(dim=0, keepdim=True, dtype=torch.int32)
+ if self.config.is_static_input_scheme:
+ # cutlass_w8a8 requires azp to be folded into azp_adj
+ # in the per-tensor case
+ azp_adj = getattr(layer, self.i_zp_name) * azp_adj
+ setattr(layer, self.azp_adj_name,
+ torch.nn.Parameter(azp_adj, requires_grad=False))
+ else:
+ setattr(layer, self.azp_adj_name, None)
+
+ klx_process_weights_after_loading(layer)
+
+def apply_weights_kunlun(self,
+ layer: torch.nn.Module,
+ x: torch.Tensor,
+ bias: Optional[torch.Tensor] = None) -> torch.Tensor:
+ x_q, x_scale, out = None, None, None
+ w_t_shape = layer.weight.T.shape
+ if isinstance(x, tuple):
+ x_q, x_scale = x
+ out = torch.empty((x_q.shape[0], w_t_shape[0]),
+ dtype=torch.bfloat16,
+ device=x_q.device)
+ else:
+ x_shape = x.shape
+ x_q = torch.empty(x_shape, dtype=torch.int8, device=x.device)
+ x_scale = torch.empty((x_shape[0], 1), dtype=torch.float32, device=x.device)
+ out = torch.empty((x_shape[0], w_t_shape[0]),
+ dtype=x.dtype,
+ device=x.device)
+ torch.ops._C.quant2d(x, x_q, x_scale, force_sdnn=True)
+ torch.ops._C.gemm_I8_I8_bf16_nt(x_q, x_scale, layer.weight.T.data, layer.weight_scale.data, out)
+ return out
+
+CutlassScaledMMLinearKernel.apply_weights = apply_weights_kunlun
+CutlassScaledMMLinearKernel.can_implement = can_implement_kunlun
+CutlassScaledMMLinearKernel.process_weights_after_loading = process_weights_after_loading_kunlun
\ No newline at end of file
diff --git a/vllm_kunlun/ops/rotary_embedding.py b/vllm_kunlun/ops/rotary_embedding.py
index bbe32a6..a4c6289 100644
--- a/vllm_kunlun/ops/rotary_embedding.py
+++ b/vllm_kunlun/ops/rotary_embedding.py
@@ -19,7 +19,9 @@ import torch
import xspeedgate_ops
import os
from vllm.model_executor.layers.rotary_embedding import (
- RotaryEmbedding, YaRNScalingRotaryEmbedding, DynamicNTKScalingRotaryEmbedding, MRotaryEmbedding)
+ RotaryEmbedding, YaRNScalingRotaryEmbedding,
+ DynamicNTKScalingRotaryEmbedding, MRotaryEmbedding,
+ DeepseekScalingRotaryEmbedding)
from typing import Optional, Tuple
def vllm_kunlun_compute_cos_sin_cache(self) -> torch.Tensor:
@@ -143,12 +145,15 @@ def vllm_kunlun_mrope_forward_cuda(
return query, key
+DeepseekScalingRotaryEmbedding_forward = DeepseekScalingRotaryEmbedding.forward
+DeepseekScalingRotaryEmbedding_forward_cuda = DeepseekScalingRotaryEmbedding.forward_cuda
RotaryEmbedding.forward_cuda = vllm_kunlun_forward_cuda
RotaryEmbedding.forward = vllm_kunlun_forward_cuda
+DeepseekScalingRotaryEmbedding.forward = DeepseekScalingRotaryEmbedding_forward
+DeepseekScalingRotaryEmbedding.forward_cuda = DeepseekScalingRotaryEmbedding_forward_cuda
MRotaryEmbedding.forward_cuda = vllm_kunlun_mrope_forward_cuda
MRotaryEmbedding.forward = vllm_kunlun_mrope_forward_cuda
-
def Split_Norm_Rope(
qkv: torch.Tensor,
cos_sin_cache: torch.Tensor,
diff --git a/vllm_kunlun/platforms/kunlun.py b/vllm_kunlun/platforms/kunlun.py
index d52f294..b79cd62 100644
--- a/vllm_kunlun/platforms/kunlun.py
+++ b/vllm_kunlun/platforms/kunlun.py
@@ -177,6 +177,8 @@ class KunlunPlatform(Platform):
# if `VLLM_ATTENTION_BACKEND` is not set and we are using MLA, then
# we default to FlashMLA backend, so we need to force the blocksize
# here
+ use_sparse = hasattr(vllm_config.model_config.hf_config,
+ "index_topk")
use_flashmla = (envs.VLLM_ATTENTION_BACKEND is None \
or envs.VLLM_ATTENTION_BACKEND == "FLASHMLA")
from vllm.attention.ops.flashmla import is_flashmla_supported
@@ -185,6 +187,11 @@ class KunlunPlatform(Platform):
cache_config.block_size = 64
logger.info(
"Forcing kv cache block size to 64 for FlashMLA backend.")
+ if use_sparse and cache_config.block_size != 64:
+ cache_config.block_size = 64
+ logger.info(
+ "Forcing kv cache block size to 64 for FlashMLASparse "
+ "backend.")
if (envs.VLLM_ALL2ALL_BACKEND == "deepep_high_throughput"
and parallel_config.data_parallel_size > 1
@@ -224,6 +231,14 @@ class KunlunPlatform(Platform):
Returns:
str: Class name of the attention backend.
"""
+ if use_mla:
+ if use_sparse:
+ logger.info_once("Using Sparse MLA backend on V1 engine.")
+ # return ("vllm.v1.attention.backends.mla.flashmla_sparse."
+ # "FlashMLASparseBackend")
+ return ("vllm_kunlun.v1.attention.backends.mla.flashmla_sparse."
+ "FlashMLASparseBackend")
+ return "vllm_kunlun.v1.attention.backends.mla.flashmla.FlashMLABackend"
if use_v1:
return "vllm_kunlun.v1.attention.backends.kunlun_attn.KunlunAttentionBackend"
elif not use_mla:
diff --git a/vllm_kunlun/v1/attention/backends/mla/__init__.py b/vllm_kunlun/v1/attention/backends/mla/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/vllm_kunlun/v1/attention/backends/mla/common.py b/vllm_kunlun/v1/attention/backends/mla/common.py
new file mode 100644
index 0000000..d39b405
--- /dev/null
+++ b/vllm_kunlun/v1/attention/backends/mla/common.py
@@ -0,0 +1,1867 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+"""
+# MLA Common Components
+
+This file implements common components for MLA implementations.
+
+First we define:
+
+Sq as Q sequence length
+Skv as KV sequence length
+
+MLA has two possible ways of computing, a data-movement friendly approach and a
+compute friendly approach, we generally want to use the compute friendly
+approach for "prefill" (i.e. the ratio Sq / Skv is "small", is near 1)
+and the data-movement friendly approach for "decode" (i.e. the ratio
+Sq / Skv is "large").
+
+NOTE what we deem small and large is currently determined by if its labelled
+prefill or decode by the scheduler, but this is something we should probably
+tune.
+
+Main reference: DeepseekV2 paper, and FlashInfer Implementation
+(https://arxiv.org/abs/2405.04434 and https://github.com/flashinfer-ai/flashinfer/pull/551).
+
+Deepseek's MLA attention works the following way:
+* Use a single latent vector to represent the per-token entry of the KV cache.
+* For decode (i.e. the memory friendly approach) the attention "simulates" a
+multi-head attention, while the compute is similar to multi-query attention.
+
+Below is example of both paths assuming batchsize = 1
+
+## More Extent Definitions:
+
+C Context length, `Skv - Sq`
+H hidden size
+N number of attention heads
+Lq latent dimension for Q 1536 in DSV3
+Lkv latent dimension for K/V 512 in DSV3
+P nope dimension, no rope. 128 in DSV3
+R rope dimension, goes through rope. 64 in DSV3
+V V head dim. 128 in DSV3
+
+## Vector/Matrix Definitions
+
+h_t hidden states (input to attention) shape [Sq, H]
+q_c latent/compressed Q shape [Sq, Lq]
+q_nope uncompressed Q (no-rope) shape [Sq, N, P]
+q_pe uncompressed Q (rope) shape [Sq, N, R]
+kv_c latent/compressed KV shape [Skv, Lkv]
+k_pe decoupled k position embeddings shape [Skv, R]
+new_kv_c new kv_c from current iter shape [Sq, Lkv]
+new_k_pe new k_pe from current iter shape [Sq, R]
+cache_kv_c cached k_c from previous iters shape [C, Lkv]
+cache_k_pe cached k_pe from previous iters shape [C, R]
+W_DQ project h_t to q_c shape [H, Lq]
+W_UQ project q_c to q_nope shape [Lq, N * P]
+W_QR project q_c to q_pe shape [Lq, N * R]
+W_DKV project h_t to kv_c shape [H, Lkv]
+W_UK project kv_c to k_nope shape [Lkv, N, P]
+W_KR project h_t to k_pe shape [H, R]
+W_UV project kv_c to v shape [Lkv, N, V]
+W_O project v to h_t shape [N * V, H]
+
+
+## Compute Friendly Approach (i.e. "_forward_prefill"):
+
+q_c = h_t @ W_DQ
+q_nope = (q_c @ W_UQ).view(Sq, N, P)
+q_pe = RoPE(q_c @ W_QR).view(Sq, N, R)
+new_kv_c = h_t @ W_DKV
+new_k_pe = RoPE(h_t @ W_KR)
+kv_c = torch.cat([new_kv_c, cache_kv_c], dim=0)
+k_pe = torch.cat([new_k_pe, cache_k_pe], dim=0)
+k_nope = (kv_c @ W_UK.view(Lkv, N * P)).view(Skv, N, P)
+v = (kv_c @ W_UV.view(Lkv, N * V)).view(Skv, N, V)
+
+// MHA with QK headdim = P + R
+// V headdim = V
+// spda_o shape [Sq, N, V]
+spda_o = scaled_dot_product_attention(
+ torch.cat([q_nope, q_pe], dim=-1),
+ torch.cat([k_nope, k_pe.unsqueeze(1).expand(-1, N, -1)], dim=-1),
+ v
+)
+return spda_o @ W_O
+
+NOTE: in the actual code,
+ `kv_b_proj` is [W_UK; W_UV] concatenated per head
+ `q_b_proj` is [W_UQ; W_QR] concatenated per head
+ `out_proj` is W_O
+
+
+## Data-Movement Friendly Approach (i.e. "_forward_decode"):
+
+Runtime
+q_c = h_t @ W_DQ
+q_nope = (q_c @ W_UQ).view(-1, N, P)
+ql_nope = einsum("snh,lnh->snl", q, W_UK)
+q_pe = RoPE(q_c @ W_QR).view(Sq, N, R)
+new_kv_c = h_t @ W_DKV
+new_k_pe = RoPE(h_t @ W_KR)
+kv_c = torch.cat([new_kv_c, cache_kv_c], dim=0)
+k_pe = torch.cat([new_k_pe, cache_k_pe], dim=0)
+
+// MQA with QK headdim = Lkv + R
+// V headdim = Lkv
+// spda_o shape [Sq, N, Lkv]
+// NOTE: this is less compute-friendly since Lkv > P
+// but is more data-movement friendly since its MQA vs MHA
+spda_o = scaled_dot_product_attention(
+ torch.cat([ql_nope, q_pe], dim=-1),
+ torch.cat([kv_c, k_pe], dim=-1),
+ kv_c
+)
+
+o = einsum("snl,lnv->snv", spda_o.reshape(-1, N, Lkv), W_UV)
+return o.view(-1, N * V) @ W_O
+
+
+## Chunked Prefill
+
+For chunked prefill we want to use the compute friendly algorithm. We are
+assuming sufficiently large Sq / Skv ratio, in the future may want to switch to
+the data-movement friendly approach if the chunk (i.e. `Sq`) is small.
+
+However, the compute-friendly approach can potentially run out of memory if Skv
+is large due to: `k_nope = (kv_c @ W_UK).view(Skv, N, P)`
+
+To mitigate this, we chunk the computation of attention with respect to the
+current context (i.e. `cache_kv_c` and `cache_k_pe`) so that we can use a
+fixed workspace size.
+
+The chunked prefill approach is as follows:
+
+MCC Max chunk of context to process per iter, computed dynamically,
+ used to bound the memory usage
+
+q_c = h_t @ W_DQ
+q_nope = (q_c @ W_UQ).view(Sq, N, P)
+q_pe = RoPE(q_c @ W_QR).view(Sq, N, R)
+new_kv_c = h_t @ W_DKV
+new_k_pe = RoPE(h_t @ W_KR)
+new_k_nope = (new_kv_c @ W_UK.view(Lkv, N * P)).view(Sq, N, P)
+new_v = (new_kv_c @ W_UV.view(Lkv, N * V)).view(Sq, N, V)
+
+// MHA between queries and new KV
+// with QK headdim = P + R
+// V headdim = V
+// curr_o shape [Sq, N, V]
+// curr_lse shape [N, Sq], this is just order FA returns
+curr_o, curr_lse = scaled_dot_product_attention(
+ torch.cat([q_nope, q_pe], dim=-1),
+ torch.cat([new_k_nope, new_k_pe.unsqueeze(1).expand(-1, N, -1)], dim=-1),
+ new_v,
+    causal=True,
+ return_softmax_lse=True
+)
+
+// Compute attention with the already existing context
+for chunk_idx in range(cdiv(C, MCC)):
+ chunk_start = chunk_idx * MCC
+ chunk_end = min(chunk_start + MCC, C)
+ Sc = chunk_end - chunk_start
+ cache_kv_c_chunk = cache_kv_c[chunk_start:chunk_end]
+ cache_k_pe_chunk = cache_k_pe[chunk_start:chunk_end]
+ cache_k_nope_chunk = (cache_kv_c_chunk @ W_UK).view(-1, N, P)
+ cache_v_chunk = (cache_kv_c_chunk @ W_UV).view(-1, N, V)
+
+ chunk_o, chunk_lse = scaled_dot_product_attention(
+ torch.cat([q_nope, q_pe], dim=-1),
+ torch.cat([cache_k_nope_chunk,
+ cache_k_pe_chunk.unsqueeze(1).expand(-1, N, -1)],
+ dim=-1),
+ cache_v_chunk,
+        causal=False,
+ return_softmax_lse=True
+ )
+
+ curr_o, curr_lse = merge_attn_states(
+ suffix_output=curr_o,
+ suffix_lse=curr_lse,
+ prefix_output=chunk_o,
+ prefix_lse=chunk_lse,
+ )
+
+return curr_o @ W_O
+"""
+
+import functools
+from abc import abstractmethod
+from dataclasses import dataclass, field
+from typing import Generic, Optional, TypeVar, Union
+
+import torch
+from tqdm import tqdm
+
+import vllm.envs as envs
+from vllm import _custom_ops as ops
+from vllm.attention.backends.abstract import (AttentionBackend, AttentionLayer,
+ AttentionMetadata,
+ MLAAttentionImpl)
+from vllm.attention.backends.utils import get_mla_dims
+from vllm.attention.ops.common import cp_lse_ag_out_rs
+from vllm.attention.ops.merge_attn_states import merge_attn_states
+from vllm.attention.utils.fa_utils import get_flash_attn_version
+from vllm.config import VllmConfig, get_current_vllm_config
+from vllm.distributed.parallel_state import get_dcp_group, is_global_first_rank
+from vllm.distributed import get_tp_group
+from vllm.logger import init_logger
+from vllm.model_executor.layers.linear import (ColumnParallelLinear,
+ LinearBase,
+ UnquantizedLinearMethod)
+from vllm.platforms import current_platform
+from vllm.utils import cdiv, round_down
+from vllm.utils.flashinfer import has_nvidia_artifactory
+from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder,
+ CommonAttentionMetadata,
+ get_per_layer_parameters,
+ infer_global_hyperparameters,
+ split_decodes_and_prefills)
+from vllm.v1.kv_cache_interface import AttentionSpec
+import xtorch_ops
+
+try:
+ from vllm.vllm_flash_attn import flash_attn_varlen_func
+ is_vllm_fa = True
+except ImportError:
+ # For rocm use upstream flash attention
+ if current_platform.is_rocm():
+ from flash_attn import flash_attn_varlen_func
+ is_vllm_fa = False
+
+try:
+ from flashinfer import BatchPrefillWithRaggedKVCacheWrapper
+ from flashinfer.prefill import ( # noqa: F401
+ cudnn_batch_prefill_with_kv_cache)
+ flashinfer_available = True
+except ImportError:
+ flashinfer_available = False
+
+
+def is_rocm_aiter_fp8bmm_enabled() -> bool:
+ return current_platform.is_rocm() \
+ and envs.VLLM_ROCM_USE_AITER_FP8BMM \
+ and envs.VLLM_ROCM_USE_AITER
+
+
+if is_rocm_aiter_fp8bmm_enabled():
+ from aiter.ops.triton.batched_gemm_a8w8_a_per_token_group_prequant_w_per_batched_tensor_quant import ( # noqa: E501 # isort: skip
+ batched_gemm_a8w8_a_per_token_group_prequant_w_per_batched_tensor_quant
+ as aiter_triton_fp8_bmm)
+
+ def dynamic_per_batched_tensor_quant(
+ x: torch.Tensor, dtype: torch.dtype = torch.float8_e4m3fn):
+ DTYPE_MAX = torch.finfo(dtype).max
+ min_val, max_val = x.aminmax()
+ amax = torch.maximum(min_val.abs(), max_val.abs()).clamp(min=1e-10)
+ scale = DTYPE_MAX / amax
+ x_scl_sat = (x * scale).clamp(min=-DTYPE_MAX, max=DTYPE_MAX)
+ return x_scl_sat.to(dtype).contiguous(), scale.float().reciprocal()
+
+
+logger = init_logger(__name__)
+
+CUDNN_WORKSPACE_SIZE = 12800
+
+
+class MLACommonBackend(AttentionBackend):
+
+ accept_output_buffer: bool = True
+
+ @staticmethod
+ def get_name() -> str:
+ return "TRITON_MLA"
+
+ @staticmethod
+ def get_metadata_cls() -> type["AttentionMetadata"]:
+ return MLACommonMetadata
+
+ @staticmethod
+ def get_builder_cls() -> type["MLACommonMetadataBuilder"]:
+ return MLACommonMetadataBuilder
+
+ @staticmethod
+ def get_kv_cache_shape(
+ num_blocks: int,
+ block_size: int,
+ num_kv_heads: int, # assumed to be 1 for MLA
+ head_size: int,
+ cache_dtype_str: str = "auto",
+ ) -> tuple[int, ...]:
+ return (num_blocks, block_size, head_size)
+
+ @classmethod
+ def get_supported_dtypes(cls) -> list[torch.dtype]:
+ return [torch.float16, torch.bfloat16]
+
+ @classmethod
+ def get_supported_head_sizes(cls) -> list[int]:
+ return [576]
+
+ @classmethod
+ def validate_head_size(cls, head_size: int) -> None:
+ supported_head_sizes = cls.get_supported_head_sizes()
+ if head_size not in supported_head_sizes:
+ attn_type = cls.__name__.removesuffix("Backend")
+ raise ValueError(
+ f"Head size {head_size} is not supported by {attn_type}. "
+ f"Supported head sizes are: {supported_head_sizes}. "
+ "Set VLLM_ATTENTION_BACKEND=FLEX_ATTENTION to use "
+ "FlexAttention backend which supports all head sizes.")
+
+
+@dataclass
+class MLACommonPrefillMetadata:
+ """ Prefill Specific Metadata """
+
+ @dataclass
+ class ChunkedContextMetadata:
+ # New for MLA (compared to FlashAttention)
+ # For handling chunked prefill
+ cu_seq_lens: torch.Tensor
+ starts: torch.Tensor
+ seq_tot: list[int]
+ max_seq_lens: list[int]
+ seq_lens: torch.Tensor
+ workspace: torch.Tensor
+
+ # for mla DCP
+ cp_chunk_seq_lens: Optional[list[list[int]]] = None
+ origin_context_lens: Optional[list[int]] = None
+ cp_cu_seq_lens: Optional[torch.Tensor] = None
+ chunk_size: Optional[int] = None
+ cu_seq_lens_lst: Optional[list[list[int]]] = None
+
+ block_table: torch.Tensor
+ query_start_loc: torch.Tensor
+ query_start_loc_cpu: torch.Tensor
+ max_query_len: int
+ chunked_context: Optional[ChunkedContextMetadata] = None
+
+
+@dataclass
+class FlashInferPrefillMetadata(MLACommonPrefillMetadata):
+ prefill_main: Optional['BatchPrefillWithRaggedKVCacheWrapper'] = None
+ prefill_chunks: list['BatchPrefillWithRaggedKVCacheWrapper'] = field(
+ default_factory=list)
+
+
+@dataclass
+class CudnnPrefillMetadata(MLACommonPrefillMetadata):
+
+ class ChunkedContextMetadata(
+ MLACommonPrefillMetadata.ChunkedContextMetadata):
+ seq_lens: torch.Tensor
+
+ query_seq_lens: Optional[torch.Tensor] = None
+ cudnn_workspace: Optional[torch.Tensor] = None
+
+
+@dataclass
+class MLACommonDecodeMetadata:
+ block_table: torch.Tensor
+ seq_lens: torch.Tensor
+
+
+D = TypeVar("D", bound=MLACommonDecodeMetadata)
+
+
+@dataclass
+class MLACommonMetadata(Generic[D]):
+ """Metadata for MLACommon.
+
+ NOTE: Please read the comment at the top of the file before trying to
+ understand this class
+ """
+ # NOTE(sang): Definition of context_len, query_len, and seq_len.
+ # |---------- N-1 iteration --------|
+ # |---------------- N iteration ---------------------|
+ # |- tokenA -|......................|-- newTokens ---|
+ # |---------- context_len ----------|
+ # |-------------------- seq_len ---------------------|
+ # |-- query_len ---|
+
+ num_reqs: int
+ max_query_len: int
+ max_seq_len: int
+
+ num_actual_tokens: int # Number of tokens excluding padding.
+ query_start_loc: torch.Tensor
+ query_start_loc_cpu: torch.Tensor
+ slot_mapping: torch.Tensor
+
+ # New for MLA (compared to FlashAttention)
+ # For handling prefill decode split
+ num_decodes: int
+ num_decode_tokens: int
+ num_prefills: int
+
+ # The dimension of the attention heads
+ head_dim: Optional[int] = None
+
+ decode: Optional[D] = None
+ prefill: Optional[Union[MLACommonPrefillMetadata,
+ FlashInferPrefillMetadata,
+ CudnnPrefillMetadata]] = None
+
+ def __post_init__(self):
+ if self.head_dim is not None:
+ MLACommonBackend.validate_head_size(self.head_dim)
+
+
+M = TypeVar("M", bound=MLACommonMetadata)
+A = TypeVar("A")
+
+
+def use_flashinfer_prefill() -> bool:
+ # For blackwell default to flashinfer prefill if it's available since
+ # it is faster than FA2.
+ return (not envs.VLLM_DISABLE_FLASHINFER_PREFILL and flashinfer_available
+ and not envs.VLLM_USE_CUDNN_PREFILL
+ and current_platform.is_device_capability(100))
+
+
+def use_cudnn_prefill() -> bool:
+ return (flashinfer_available and envs.VLLM_USE_CUDNN_PREFILL
+ and current_platform.is_device_capability(100)
+ and has_nvidia_artifactory())
+
+
+# Currently 394MB, this can be tuned based on GEMM sizes used.
+# Chosen to be the same as sglang:
+# https://github.com/sgl-project/sglang/blob/766392c6bda2558b61ce6d1c1bfd8081a549e1f1/python/sglang/global_config.py#L37
+FLASHINFER_WORKSPACE_BUFFER_SIZE = 394 * 1024 * 1024
+
+
+class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
+ """
+ NOTE: Please read the comment at the top of the file before trying to
+ understand this class
+ """
+ reorder_batch_threshold: int = 1
+
+ @staticmethod
+ def determine_chunked_prefill_workspace_size(
+ vllm_config: VllmConfig) -> int:
+ scheduler_config = vllm_config.scheduler_config
+ cache_config = vllm_config.cache_config
+ model_config = vllm_config.model_config
+
+ chunked_prefill_workspace_size = min(
+ # Try for 8 full length request or at least 4 pages per-request
+ max(8 * model_config.max_model_len,
+ 4 * scheduler_config.max_num_seqs * cache_config.block_size),
+ # For long-context models try not to over-allocate limiting
+ # kv-cache space, limiting it to 64k tokens,
+ # which would result in the workspace being:
+ # 2*(576)*(64*1024) = 144mb
+ # (assuming 576 MLA head dim, and fp16)
+ # which would result in up-projected context being
+ # 2*(192*128)*(64*1024) = 3gb
+ # (assuming 192 QK head dim, 128 heads, and fp16)
+ 64 * 1024)
+
+        # Enforce that we have enough for at least 1 page per request
+ chunked_prefill_workspace_size = max(
+ chunked_prefill_workspace_size,
+ scheduler_config.max_num_seqs * cache_config.block_size)
+
+ return chunked_prefill_workspace_size
+
+ def __init__(self,
+ kv_cache_spec: AttentionSpec,
+ layer_names: list[str],
+ vllm_config: VllmConfig,
+ device: torch.device,
+ metadata_cls: Optional[type[M]] = None):
+ self.metadata_cls = metadata_cls \
+ if metadata_cls is not None else MLACommonMetadata
+ self.kv_cache_spec = kv_cache_spec
+ scheduler_config = vllm_config.scheduler_config
+ self.model_config = vllm_config.model_config
+ parallel_config = vllm_config.parallel_config
+ self.compilation_config = vllm_config.compilation_config
+ self.device = device
+
+ self.num_heads = self.model_config.get_num_attention_heads(
+ parallel_config)
+ self.mla_dims = get_mla_dims(self.model_config)
+ self.aot_schedule = current_platform.is_cuda()
+ try:
+ self.dcp_world_size = get_dcp_group().world_size
+ self.dcp_rank = get_dcp_group().rank_in_group
+ except AssertionError:
+ # DCP might not be initialized in testing
+ self.dcp_world_size = 1
+ self.dcp_rank = 0
+
+ # Don't try to access the runner on AMD
+ if self.aot_schedule:
+ self.page_size = self.kv_cache_spec.block_size
+
+ self.chunked_prefill_workspace_size = \
+ self.determine_chunked_prefill_workspace_size(vllm_config)
+
+ if self.dcp_world_size > 1:
+ # Note(hc): The local kvcache is incomplete when DCP is triggered,
+ # an additional kvcache allgather across the DCP group is therefore
+ # required, so the workspace has to be enlarged by 1/DCP relative
+ # to the original TP allocation.
+ assert self.chunked_prefill_workspace_size % \
+ self.dcp_world_size == 0
+ self.chunked_prefill_workspace = torch.empty(
+ (self.chunked_prefill_workspace_size +
+ self.chunked_prefill_workspace_size // self.dcp_world_size,
+ self.model_config.get_head_size()),
+ dtype=self.model_config.dtype,
+ device=device,
+ )
+ else:
+ self.chunked_prefill_workspace = torch.empty(
+ (self.chunked_prefill_workspace_size,
+ self.model_config.get_head_size()),
+ dtype=self.model_config.dtype,
+ device=device,
+ )
+
+ self._use_cudnn_prefill = use_cudnn_prefill()
+ self._use_fi_prefill = use_flashinfer_prefill()
+ self.prefill_metadata_cls = (
+ FlashInferPrefillMetadata
+ if self._use_fi_prefill else CudnnPrefillMetadata
+ if self._use_cudnn_prefill else MLACommonPrefillMetadata)
+
+ if self._use_fi_prefill:
+ self._workspace_buffer = torch.empty(
+ FLASHINFER_WORKSPACE_BUFFER_SIZE,
+ dtype=torch.uint8,
+ device=device)
+
+ self._fi_prefill_main: Optional[
+ BatchPrefillWithRaggedKVCacheWrapper] = None
+ self._fi_prefill_chunks: list[
+ BatchPrefillWithRaggedKVCacheWrapper] = []
+
+ self._global_hyperparameters = infer_global_hyperparameters(
+ get_per_layer_parameters(vllm_config, layer_names,
+ MLACommonImpl))
+
+ if self._use_cudnn_prefill:
+ self.cudnn_workspace = torch.empty(
+ CUDNN_WORKSPACE_SIZE * scheduler_config.max_num_seqs,
+ dtype=torch.int8,
+ device=device,
+ )
+
    def _build_fi_prefill_wrappers(self, prefill: FlashInferPrefillMetadata):
        """Plan FlashInfer ragged-KV prefill wrappers and attach them.

        Lazily creates one wrapper for the main (new-token) causal run and
        one wrapper per context chunk, plans each of them for the current
        batch shape, and stores them on ``prefill.prefill_main`` /
        ``prefill.prefill_chunks``. Wrappers are cached on ``self`` and
        re-planned every step; the chunk list only ever grows.
        """
        qo_indptr = prefill.query_start_loc

        has_context = False
        if prefill.chunked_context is not None:
            chunked_context = prefill.chunked_context
            has_context = True

        # Lazily allocate the main-run wrapper once; it is re-planned per step.
        if self._fi_prefill_main is None:
            self._fi_prefill_main = BatchPrefillWithRaggedKVCacheWrapper(
                self._workspace_buffer, "NHD", backend="cutlass")

        if has_context:
            num_chunks = chunked_context.cu_seq_lens.shape[0]
            # Allocate more prefill chunk wrappers if needed
            if len(self._fi_prefill_chunks) < num_chunks:
                for _ in range(len(self._fi_prefill_chunks), num_chunks):
                    self._fi_prefill_chunks.append(
                        BatchPrefillWithRaggedKVCacheWrapper(
                            self._workspace_buffer, "NHD", backend="cutlass"))
            assert num_chunks <= len(self._fi_prefill_chunks)

        # In MLA, the non-latent num_qo_heads == num_kv_heads
        num_qo_heads = self.num_heads
        num_kv_heads = num_qo_heads

        # Sanity: Verify that num_kv_heads == 1 since it is latent space
        assert self.kv_cache_spec.num_kv_heads == 1

        # Get non-latent head_dim_qk and head_dim_vo
        head_dim_qk = (self.mla_dims.qk_nope_head_dim +
                       self.mla_dims.qk_rope_head_dim)
        head_dim_vo = self.mla_dims.v_head_dim

        # For main run, qo_indptr == kv_indptr
        kv_indptr = qo_indptr.clone()

        # Prepare main prefill
        self._fi_prefill_main.plan(
            qo_indptr=qo_indptr,
            kv_indptr=kv_indptr,
            num_qo_heads=num_qo_heads,
            num_kv_heads=num_kv_heads,
            head_dim_qk=head_dim_qk,
            head_dim_vo=head_dim_vo,
            causal=True,  # This is main run
            sm_scale=self._global_hyperparameters.sm_scale,
            window_left=self._global_hyperparameters.window_left,
            logits_soft_cap=self._global_hyperparameters.logits_soft_cap,
            q_data_type=self.model_config.dtype,
        )

        # Prepare context prefills (non-causal: queries attend to the whole
        # gathered context chunk).
        if has_context:
            for i in range(num_chunks):
                kv_indptr_chunk = chunked_context.cu_seq_lens[i]

                self._fi_prefill_chunks[i].plan(
                    qo_indptr=qo_indptr,
                    kv_indptr=kv_indptr_chunk,
                    num_qo_heads=num_qo_heads,
                    num_kv_heads=num_kv_heads,
                    head_dim_qk=head_dim_qk,
                    head_dim_vo=head_dim_vo,
                    causal=False,  # This is context run
                    sm_scale=self._global_hyperparameters.sm_scale,
                    window_left=self._global_hyperparameters.window_left,
                    logits_soft_cap=self._global_hyperparameters.
                    logits_soft_cap,
                    q_data_type=self.model_config.dtype,
                )

        prefill.prefill_main = self._fi_prefill_main
        prefill.prefill_chunks = self._fi_prefill_chunks
+
+ def _build_decode(self, block_table_tensor: torch.Tensor,
+ seq_lens_cpu: torch.Tensor,
+ seq_lens_device: torch.Tensor,
+ query_start_loc_cpu: torch.Tensor,
+ query_start_loc_device: torch.Tensor,
+ num_decode_tokens: int) -> MLACommonDecodeMetadata:
+ return MLACommonDecodeMetadata(
+ block_table=block_table_tensor,
+ seq_lens=seq_lens_device,
+ )
+
+ def build_for_cudagraph_capture(
+ self, common_attn_metadata: CommonAttentionMetadata) -> M:
+ """
+ This method builds the metadata for full cudagraph capture.
+ Currently, only decode is supported for full cudagraphs with MLA.
+ """
+ m = common_attn_metadata
+ assert m.num_reqs <= (m.num_actual_tokens *
+ self.reorder_batch_threshold), \
+ "MLA only supports decode-only full CUDAGraph capture. " \
+ "Make sure all cudagraph capture sizes <= max_num_seq."
+
+ assert m.max_query_len <= self.reorder_batch_threshold # decode only
+
+ return self.build(0, m)
+
    def build(self,
              common_prefix_len: int,
              common_attn_metadata: CommonAttentionMetadata,
              fast_build: bool = False) -> M:
        """Build the per-step MLA attention metadata.

        The batch is assumed to be reorder-sorted so that decode requests
        come first and prefill requests follow (see
        ``split_decodes_and_prefills``). Builds decode metadata, prefill
        metadata — including the chunked-context bookkeeping used to page
        cached context through the shared workspace — and wraps both in
        ``self.metadata_cls``. Under decode-context-parallel (DCP > 1) the
        per-rank sequence lengths and chunk layouts are adjusted as well.
        """
        num_reqs = common_attn_metadata.num_reqs
        num_tokens = common_attn_metadata.num_actual_tokens
        max_query_len = common_attn_metadata.max_query_len
        max_seq_len = common_attn_metadata.max_seq_len

        # Note(simon): be careful about the CPU <> GPU memory movement in this
        # function. We should avoid GPU -> CPU sync as much as possible because
        # it blocks on all previous kernels.
        device = self.device
        block_table_tensor = common_attn_metadata.block_table_tensor
        slot_mapping = common_attn_metadata.slot_mapping

        query_start_loc = common_attn_metadata.query_start_loc
        query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu
        seq_lens = common_attn_metadata.seq_lens
        seq_lens_cpu = common_attn_metadata.seq_lens_cpu

        query_seq_lens_cpu = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]

        # Tokens already present in the KV cache for each request.
        num_computed_tokens_cpu = (common_attn_metadata.seq_lens_cpu -
                                   query_seq_lens_cpu)

        num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = \
            split_decodes_and_prefills(common_attn_metadata,
                decode_threshold=self.reorder_batch_threshold)

        # Note(hc): update seq_lens of decode reqs under DCP.
        if self.dcp_world_size > 1:
            seq_lens[:num_decodes] = seq_lens[:num_decodes] \
                // self.dcp_world_size + (self.dcp_rank <= \
                (seq_lens[:num_decodes] - 1) % self.dcp_world_size)

        assert num_decodes + num_prefills == num_reqs
        assert num_decode_tokens + num_prefill_tokens == num_tokens

        prefill_metadata = None
        if num_prefills > 0:
            reqs_start = num_decodes # prefill_start

            context_lens_cpu = num_computed_tokens_cpu[reqs_start:num_reqs]
            # Note(hc): The context lengths in the perspective of dcp rank0.
            cp_context_lens_cpu = torch.ceil(context_lens_cpu.float() /
                                             self.dcp_world_size).int()
            origin_context_lens = context_lens_cpu.tolist()
            max_context_len_cpu = context_lens_cpu.max().item()
            num_prefills_with_context_cpu = (context_lens_cpu > 0).sum().item()
            prefill_query_start_loc = query_start_loc[
                reqs_start:] - query_start_loc[reqs_start]

            chunked_context_metadata = None
            if max_context_len_cpu > 0:
                # NOTE: it is recommend you read the `Chunked Prefill` section
                # in the comment at the top of the file before trying to
                # understand the following code

                # currently we allocate an equal amount of workspace for each
                # prefill in the batch, we could probably use a more advanced
                # algorithm here and allocate more workspace to prefills with
                # longer context lengths
                max_context_chunk = (self.chunked_prefill_workspace_size //
                                     num_prefills_with_context_cpu)

                if self.aot_schedule:
                    # align max_context_chunk to page_size by rounding down,
                    # currently the `gather_and_maybe_dequant_cache` kernel
                    # cannot handle `context_chunk_starts` that are not aligned
                    # to page_size
                    max_context_chunk = round_down(max_context_chunk,
                                                   self.page_size)

                assert max_context_chunk > 0
                num_chunks = cdiv(max_context_len_cpu, max_context_chunk)

                # if `max_context_chunk = 256`, `num_chunks = 3`, and
                # `num_prefills_with_context = 4`, create a tensor that looks
                # like
                # [[0, 0, 0, 0], [256, 256, 256, 256], [512, 512, 512, 512]]
                # Note(simon): this is done in CPU because of downstream's
                # of `to_list`.
                chunk_starts = \
                    torch.arange(num_chunks, dtype=torch.int32) \
                    .unsqueeze(1).expand(-1, num_prefills) \
                    * max_context_chunk
                chunk_ends = torch.min(context_lens_cpu.unsqueeze(0),
                                       chunk_starts + max_context_chunk)
                chunk_seq_lens = (chunk_ends - chunk_starts).clamp(min=0)

                cu_seq_lens_cpu = torch.zeros(num_chunks,
                                              num_prefills + 1,
                                              dtype=torch.int32,
                                              pin_memory=True)
                torch.cumsum(chunk_seq_lens,
                             dim=1,
                             out=cu_seq_lens_cpu[:, 1:],
                             dtype=torch.int32)

                if self.dcp_world_size > 1:
                    # Note(hc): The above max_context_chunk already enforces
                    # block_size alignment, DCP just need the block_size can
                    # be divisible by dcp_world_size, because DCP use
                    # cp_gather_cache which not require `cp_chunk_starts`
                    # aligned to page_size.
                    assert max_context_chunk % self.dcp_world_size == 0
                    cp_max_context_chunk = max_context_chunk // \
                        self.dcp_world_size
                    cp_chunk_starts = \
                        torch.arange(num_chunks, dtype=torch.int32) \
                        .unsqueeze(1).expand(-1, num_prefills) \
                        * cp_max_context_chunk
                    cp_chunk_ends = torch.min(
                        cp_context_lens_cpu.unsqueeze(0),
                        cp_chunk_starts + cp_max_context_chunk)
                    cp_chunk_seq_lens = (cp_chunk_ends -
                                         cp_chunk_starts).clamp(min=0)

                    cp_cu_seq_lens_cpu = torch.zeros(num_chunks,
                                                     num_prefills + 1,
                                                     dtype=torch.int32,
                                                     pin_memory=True)
                    torch.cumsum(cp_chunk_seq_lens,
                                 dim=1,
                                 out=cp_cu_seq_lens_cpu[:, 1:],
                                 dtype=torch.int32)

                chunked_context_metadata_cls = \
                    CudnnPrefillMetadata.ChunkedContextMetadata \
                    if self._use_cudnn_prefill else \
                    MLACommonPrefillMetadata.ChunkedContextMetadata
                if self.dcp_world_size > 1:
                    chunked_context_metadata = \
                        chunked_context_metadata_cls(
                        cu_seq_lens=cu_seq_lens_cpu \
                            .to(device, non_blocking=True),
                        starts=cp_chunk_starts.to(device, non_blocking=True),
                        seq_tot=cp_chunk_seq_lens.sum(dim=1).tolist(),
                        max_seq_lens=chunk_seq_lens.max(dim=1).values.tolist(),
                        seq_lens=chunk_seq_lens,
                        workspace=self.chunked_prefill_workspace,
                        cp_chunk_seq_lens=cp_chunk_seq_lens.tolist(),
                        origin_context_lens=origin_context_lens,
                        cp_cu_seq_lens=cp_cu_seq_lens_cpu \
                            .to(device, non_blocking=True),
                        chunk_size=max_context_chunk,
                        cu_seq_lens_lst=cu_seq_lens_cpu.tolist(),
                    )
                else:
                    chunked_context_metadata = \
                        chunked_context_metadata_cls(
                        cu_seq_lens=cu_seq_lens_cpu \
                            .to(device, non_blocking=True),
                        starts=chunk_starts.to(device, non_blocking=True),
                        seq_tot=chunk_seq_lens.sum(dim=1).tolist(),
                        max_seq_lens=chunk_seq_lens.max(dim=1).values.tolist(),
                        seq_lens=chunk_seq_lens,
                        workspace=self.chunked_prefill_workspace,
                    )

                if self._use_cudnn_prefill:
                    chunked_context_metadata.seq_lens = chunk_seq_lens

                # Each chunk must fit in the pre-allocated workspace.
                assert max(chunked_context_metadata.max_seq_lens) <= \
                    self.chunked_prefill_workspace_size

            prefill_metadata = self.prefill_metadata_cls(
                block_table=block_table_tensor[reqs_start:, ...],
                query_start_loc=prefill_query_start_loc,
                query_start_loc_cpu=prefill_query_start_loc.cpu(),
                max_query_len=max_query_len,
                chunked_context=chunked_context_metadata,
            )

            if self._use_cudnn_prefill:
                assert isinstance(prefill_metadata, CudnnPrefillMetadata)
                prefill_metadata.query_seq_lens = prefill_query_start_loc[1:] \
                    - prefill_query_start_loc[:-1]
                prefill_metadata.cudnn_workspace = self.cudnn_workspace

        decode_metadata = None
        if num_decodes > 0:
            decode_metadata = self._build_decode(
                block_table_tensor=block_table_tensor[:num_decodes, ...],
                seq_lens_cpu=seq_lens_cpu[:num_decodes],
                seq_lens_device=seq_lens[:num_decodes],
                query_start_loc_cpu=query_start_loc_cpu[:num_decodes + 1],
                query_start_loc_device=query_start_loc[:num_decodes + 1],
                num_decode_tokens=num_decode_tokens,
            )

        attn_metadata = self.metadata_cls(
            num_reqs=common_attn_metadata.num_reqs,
            max_query_len=common_attn_metadata.max_query_len,
            max_seq_len=max_seq_len,
            num_actual_tokens=num_tokens,
            query_start_loc=query_start_loc,
            query_start_loc_cpu=query_start_loc_cpu,
            slot_mapping=slot_mapping,
            head_dim=self.model_config.get_head_size(),
            # MLACommonMetadata Chunk prefill specific
            num_decodes=num_decodes,
            num_decode_tokens=num_decode_tokens,
            num_prefills=num_prefills,
            prefill=prefill_metadata,
            decode=decode_metadata,
        )

        # FlashInfer wrappers must be (re)planned for this step's shapes.
        if self._use_fi_prefill and num_prefills > 0:
            assert isinstance(attn_metadata.prefill, FlashInferPrefillMetadata)
            self._build_fi_prefill_wrappers(attn_metadata.prefill)

        return attn_metadata
+
+
def reorg_kvcache(
    allgatered_kv_c_normed: torch.Tensor,
    allgatered_k_pe: torch.Tensor,
    cp_chunk_seq_lens_lst: list[int],
    origin_context_lens: list[int],
    cp_world_size: int,
    sum_seq_len: int,
    max_seq_len: int,
    chunk_size: int,
    chunk_idx: int,
    toks: int,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Reorganize an allgathered KV cache from CP-rank layout to TP layout.

    After the allgather across the CP group, each rank's local shard occupies
    a contiguous region of ``toks`` rows. This function interleaves those
    shards back into per-request contiguous sequences as the attention kernel
    expects.

    Args:
        allgatered_kv_c_normed / allgatered_k_pe: allgathered caches where
            rank ``r``'s shard lives at rows ``[r * toks, (r + 1) * toks)``.
        cp_chunk_seq_lens_lst: chunk context lengths under CP.
        origin_context_lens: origin full context lengths under CP.
        cp_world_size: CP size.
        sum_seq_len: the sum of cp_chunk_seq_lens_lst.
        max_seq_len: the max value of cp_chunk_seq_lens_lst.
        chunk_size: equals to max_context_chunk from
            chunked_context_metadata building.
        chunk_idx: chunk idx of chunked_prefill.
        toks: the number of tokens for local gather cache.

    Returns:
        The (kv_c, k_pe) tensors in contiguous per-request order; each has
        ``sum_seq_len`` rows.
    """
    kv_c_parts: list[torch.Tensor] = []
    k_pe_parts: list[torch.Tensor] = []
    shard_offset = 0  # read offset inside every rank's shard
    longest_seq = 0
    for chunk_len, full_context_len in zip(cp_chunk_seq_lens_lst,
                                           origin_context_lens):
        # Tokens of this request covered by the current context chunk.
        chunk_context_len = chunk_size
        if chunk_len != 0:
            chunk_context_len = min(
                chunk_context_len,
                full_context_len - chunk_size * chunk_idx)
        # Last CP rank that actually holds a token of this chunk; ranks
        # beyond it contribute one token fewer (round-robin distribution).
        last_rank_with_token = (chunk_context_len - 1) % cp_world_size
        seq_total = 0
        for rank in range(cp_world_size):
            if chunk_len and rank > last_rank_with_token:
                take = chunk_len - 1
            else:
                take = chunk_len
            if take:
                lo = rank * toks + shard_offset
                hi = lo + take
                kv_c_parts.append(allgatered_kv_c_normed[lo:hi])
                k_pe_parts.append(allgatered_k_pe[lo:hi])
                seq_total += take
        longest_seq = max(longest_seq, seq_total)
        shard_offset += chunk_len
    reorganized_kv_c_normed = torch.cat(kv_c_parts, dim=0)
    reorganized_k_pe = torch.cat(k_pe_parts, dim=0)
    assert reorganized_kv_c_normed.shape[0] == sum_seq_len
    assert reorganized_k_pe.shape[0] == sum_seq_len
    assert longest_seq == max_seq_len
    return reorganized_kv_c_normed, reorganized_k_pe
+
+
# TODO(Lucas): rename MLACommonBaseImpl -> MLACommonImpl,
# and MLACommonImpl -> MLACommonDenseImpl or something like that
class MLACommonBaseImpl(MLAAttentionImpl[A], Generic[A]):
    """
    Base implementation shared by all MLA attention backends: stores the MLA
    projection dimensions, derives the absorbed `W_UK`/`W_UV` weights after
    loading, and provides the latent -> value up-projection.

    NOTE: Please read the comment at the top of the file before trying to
    understand this class
    """

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: int,
        alibi_slopes: Optional[list[float]],
        sliding_window: Optional[int],
        kv_cache_dtype: str,
        logits_soft_cap: Optional[float],
        attn_type: str,
        kv_sharing_target_layer_name: Optional[str],
        # MLA Specific Arguments
        q_lora_rank: Optional[int],
        kv_lora_rank: int,
        qk_nope_head_dim: int,
        qk_rope_head_dim: int,
        qk_head_dim: int,
        v_head_dim: int,
        kv_b_proj: ColumnParallelLinear,
        indexer=None,
        q_pad_num_heads: Optional[int] = None,
    ) -> None:
        # MLA shares a single latent KV head; KV sharing across layers is a
        # separate mechanism that this implementation does not support.
        if kv_sharing_target_layer_name is not None:
            raise NotImplementedError("KV sharing is not supported for MLA")

        self.num_heads = num_heads
        self.head_size = head_size
        self.scale = float(scale)
        self.num_kv_heads = num_kv_heads
        self.kv_cache_dtype = kv_cache_dtype

        # MLA projection dimensions; kv_b_proj maps the latent kv_lora_rank
        # to per-head (qk_nope_head_dim + v_head_dim).
        self.q_lora_rank = q_lora_rank
        self.kv_lora_rank = kv_lora_rank
        self.qk_nope_head_dim = qk_nope_head_dim
        self.qk_rope_head_dim = qk_rope_head_dim
        self.qk_head_dim = qk_head_dim
        self.v_head_dim = v_head_dim
        self.kv_b_proj = kv_b_proj
        self.indexer = indexer
        self.q_pad_num_heads = q_pad_num_heads

    def process_weights_after_loading(self, act_dtype: torch.dtype):
        """Split (and, if quantized, dequantize) `kv_b_proj` into the
        per-head up-projection weights `W_UK` / `W_UV` used by the absorbed
        MLA matmuls. On ROCm with the aiter fp8 BMM enabled, the weights are
        instead re-quantized per batch and the triton kernel is pre-compiled
        for batch sizes 1..1024."""

        def get_layer_weight(layer):
            # Different quantization schemes store the weight under
            # different attribute names.
            WEIGHT_NAMES = ("weight", "qweight", "weight_packed")
            for attr in WEIGHT_NAMES:
                if hasattr(layer, attr):
                    return getattr(layer, attr)
            raise AttributeError(
                f"Layer '{layer}' has no recognized weight attribute:"
                f" {WEIGHT_NAMES}.")

        def get_and_maybe_dequant_weights(layer: LinearBase):
            if not isinstance(layer.quant_method, UnquantizedLinearMethod):
                # NOTE: This should only be used offline, since it's O(N^3)
                eye = torch.eye(layer.input_size_per_partition,
                                dtype=act_dtype,
                                device=get_layer_weight(layer).device)
                dequant_weights = layer.quant_method.apply(layer,
                                                           eye,
                                                           bias=None)
                del eye
                # standardize to (output, input)
                return dequant_weights.T
            return layer.weight

        # we currently do not have quantized bmm's which are needed for
        # `W_UV` and `W_UK_T`, we just store fp16/bf16 copies and perform
        # the bmm's in 16-bit, the extra memory overhead of this is fairly low
        kv_b_proj_weight = get_and_maybe_dequant_weights(self.kv_b_proj).T
        assert kv_b_proj_weight.shape == (
            self.kv_lora_rank,
            self.num_heads * (self.qk_nope_head_dim + self.v_head_dim)), (
                f"{kv_b_proj_weight.shape=}, "
                f"{self.kv_lora_rank=}, "
                f"{self.num_heads=}, "
                f"{self.qk_nope_head_dim=}, "
                f"{self.v_head_dim=}")
        kv_b_proj_weight = kv_b_proj_weight.view(
            self.kv_lora_rank,
            self.num_heads,
            self.qk_nope_head_dim + self.v_head_dim,
        )

        W_UK, W_UV = kv_b_proj_weight.split(
            [self.qk_nope_head_dim, self.v_head_dim], dim=-1)

        if is_rocm_aiter_fp8bmm_enabled():
            W_K = W_UK.transpose(0, 1)  # 16 512 128
            W_V = W_UV.permute(1, 2, 0)  # 16 128 512
            self.W_K, self.W_K_scale = dynamic_per_batched_tensor_quant(
                W_K, dtype=current_platform.fp8_dtype())
            self.W_V, self.W_V_scale = dynamic_per_batched_tensor_quant(
                W_V, dtype=current_platform.fp8_dtype())

            # The kernel operates on non-padded inputs. Hence, pre-compiling
            # triton kernel to avoid runtime compilation for unseen batch sizes
            # Pre-compile for batch sizes 1 to 1024 to cover most use-cases.
            # On DS-R1, this step adds roughly 50s to the model loading time.
            max_batch_size = 1024  # [ToDo] Find the optimal upper limit
            pre_compilation_list = list(range(1, max_batch_size + 1))
            if is_global_first_rank():
                pre_compilation_list = tqdm(
                    pre_compilation_list,
                    desc="[Aiter Triton] Pre-compiling fp8 BMM kernel",
                    total=max_batch_size,
                )

            for m in pre_compilation_list:
                x = torch.empty((self.W_K.shape[0], m, self.W_K.shape[2]),
                                dtype=torch.bfloat16,
                                device=self.W_K.device)
                aiter_triton_fp8_bmm(x,
                                     self.W_K,
                                     self.W_K_scale,
                                     group_size=128,
                                     transpose_bm=True)

                x = torch.empty((self.W_V.shape[0], m, self.W_V.shape[2]),
                                dtype=torch.bfloat16,
                                device=self.W_V.device)
                aiter_triton_fp8_bmm(x,
                                     self.W_V,
                                     self.W_V_scale,
                                     group_size=128,
                                     transpose_bm=True)
        else:
            # Convert from (L, N, V) to (N, L, V)
            self.W_UV = W_UV.transpose(0, 1)
            # Convert from (L, N, P) to (N, P, L)
            self.W_UK_T = W_UK.permute(1, 2, 0)

    def _v_up_proj(self, x: torch.Tensor, out: torch.Tensor):
        """Up-project attention output from latent space (B, N, L) into the
        value space and write it into `out` as (B, N * V)."""
        # Convert from (B, N, L) to (N, B, L)
        x = x.view(-1, self.num_heads, self.kv_lora_rank).transpose(0, 1)
        if is_rocm_aiter_fp8bmm_enabled():
            # Multiply + Transpose (N, B, L) x (N, L, V)->(N, B, V)->(B, N, V)
            x = aiter_triton_fp8_bmm(x,
                                     self.W_V,
                                     self.W_V_scale,
                                     group_size=128,
                                     transpose_bm=True)
            # Convert from (B, N, V) to (B, N * V)
            x = x.reshape(-1, self.num_heads * self.v_head_dim)
            # Copy result
            out.copy_(x)
        else:
            # Convert from (B, N * V) to (N, B, V)
            out = out.view(-1, self.num_heads, self.v_head_dim).transpose(0, 1)

            # Multiply (N, B, L) x (N, L, V) -> (N, B, V)
            torch.bmm(x, self.W_UV, out=out)  # Reuse "out" to make it "hot"

            # Convert from (N, B, V) to (B, N * V)
            out_new = out.transpose(0, 1).reshape(
                -1, self.num_heads * self.v_head_dim)

            # Adjust output buffer shape back to the original (B, N * V)
            # NOTE(review): `out` is a transposed view here; `resize_` on a
            # non-contiguous view has restrictive semantics — confirm this
            # path is exercised and behaves as intended.
            N, B, V = out.shape
            out.resize_((B, N * V))
            out.copy_(out_new)  # Copy result
+
+
+class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
+ """
+ NOTE: Please read the comment at the top of the file before trying to
+ understand this class
+ """
+
    def __init__(self, *args, **kwargs) -> None:
        """Select the prefill attention backend (FlashInfer, cuDNN, or the
        flash-attention/XPU path) and bind the corresponding
        `_run_prefill_*` method implementations."""
        super().__init__(*args, **kwargs)

        if use_flashinfer_prefill():
            logger.debug_once("Using FlashInfer prefill for MLA")
            self._run_prefill_context_chunk = self._run_prefill_context_chunk_fi
            self._run_prefill_new_tokens = self._run_prefill_new_tokens_fi
            # FlashInfer supports differing q/v head dims, so no v padding.
            self._pad_v = False
        elif use_cudnn_prefill():
            logger.debug_once("Using CUDNN prefill for MLA")
            self._run_prefill_context_chunk = \
                self._run_prefill_context_chunk_cudnn
            self._run_prefill_new_tokens = self._run_prefill_new_tokens_cudnn
            self._pad_v = False
        else: # Use FlashAttention
            logger.debug_once("Using FlashAttention prefill for MLA")
            self._run_prefill_context_chunk = self._run_prefill_context_chunk_fa
            self._run_prefill_new_tokens = self._run_prefill_new_tokens_fa

            # Handle the differences between the flash_attn_varlen from
            # flash_attn and the one from vllm_flash_attn. The former is used on
            # RoCM and the latter has an additional parameter to control
            # FA2 vs FA3
            self.flash_attn_varlen_func = flash_attn_varlen_func
            self.vllm_flash_attn_version = get_flash_attn_version()
            if self.vllm_flash_attn_version is not None:
                self.flash_attn_varlen_func = \
                    functools.partial(flash_attn_varlen_func,
                                      fa_version=self.vllm_flash_attn_version)

            # For MLA the v head dim is smaller than qk head dim so we pad out
            # v with 0s to match the qk head dim for attention backends that do
            # not support different headdims
            # We don't need to pad V if we are on a hopper system with FA3
            self._pad_v = self.vllm_flash_attn_version is None or not (
                self.vllm_flash_attn_version == 3
                and current_platform.get_device_capability()[0] == 9)

        # Populated lazily; None until the DCP group size is known.
        self.dcp_world_size: Optional[int] = None

        self.chunked_prefill_workspace_size = \
            MLACommonMetadataBuilder.determine_chunked_prefill_workspace_size(
                get_current_vllm_config())
+
+ def _flash_attn_varlen_diff_headdims(self,
+ q,
+ k,
+ v,
+ context_seq_lod_xpu=None,
+ context_seq_lod_cpu=None,
+ return_softmax_lse=False,
+ causal=True,
+ softmax_scale=None,
+ **kwargs):
+ maybe_padded_v = v
+ if self._pad_v:
+ maybe_padded_v = torch.nn.functional.pad(
+ v, [0, q.shape[-1] - v.shape[-1]], value=0)
+
+ if is_vllm_fa:
+ kwargs["return_softmax_lse"] = return_softmax_lse
+ else:
+ # ROCm leverages the upstream flash_attn, which takes a parameter
+ # called "return_attn_probs" instead of return_softmax_lse
+ kwargs["return_attn_probs"] = return_softmax_lse
+
+ # attn_out = self.flash_attn_varlen_func(
+ # q=q,
+ # k=k,
+ # v=maybe_padded_v,
+ # softmax_scale=softmax_scale,
+ # **kwargs,
+ # )
+ attn_out = torch.empty_like(q)
+ ds_alpha = 1.8738542070926265
+ tp_q_head_num=128
+ softmax_lse = torch.zeros(tp_q_head_num, q.size(0), dtype=torch.float32, device=q.device)
+ softmax_lse.fill_(float('-inf'))
+ xtorch_ops.attention(
+ q=q,
+ k_cache=k,
+ v_cache=maybe_padded_v,
+ out=attn_out,
+ is_causal=causal,
+ is_prefill=True,
+ prefill_len=0,
+ k_perchannel_scale=None,
+ v_perchannel_scale=None,
+ smooth=None,
+ context_seq_lod_cpu=context_seq_lod_cpu,
+ context_seq_lod_xpu=context_seq_lod_xpu,
+ slot_mapping_cpu=None,
+ slot_mapping_xpu=None,
+ v_trans=False,
+ v_trans_threshold=0,
+ alpha=ds_alpha,
+ softmax_lse=softmax_lse
+ )
+
+ # Unpack the output if there is multiple results
+ lse = None
+ if isinstance(attn_out, tuple):
+ attn_out, lse = attn_out[0], attn_out[1]
+
+ # Remain consistent with old `flash_attn_varlen_func` where there
+ # is only one output tensor if `return_softmax_lse` is False.
+ if return_softmax_lse:
+ return attn_out, lse
+ return attn_out
+
+ def _run_prefill_new_tokens_fa(self, prefill: MLACommonPrefillMetadata, q,
+ k, v, return_softmax_lse):
+ return self._flash_attn_varlen_diff_headdims(
+ q=q,
+ k=k,
+ v=v,
+ context_seq_lod_xpu=prefill.query_start_loc,
+ context_seq_lod_cpu=prefill.query_start_loc_cpu,
+ softmax_scale=self.scale,
+ causal=True,
+ return_softmax_lse=return_softmax_lse,
+ )
+
+ def _run_prefill_new_tokens_fi(self, prefill: MLACommonPrefillMetadata, q,
+ k, v, return_softmax_lse):
+ assert isinstance(prefill, FlashInferPrefillMetadata)
+ assert prefill.prefill_main is not None
+ ret = prefill.prefill_main.run(
+ q=q,
+ k=k,
+ v=v,
+ return_lse=return_softmax_lse,
+ )
+
+ if isinstance(ret, tuple):
+ # Convert from (q_len, num_heads) to (num_heads, q_len)
+ return ret[0], ret[1].transpose(0, 1).contiguous()
+ return ret
+
+ def _run_prefill_new_tokens_cudnn(self, prefill: MLACommonPrefillMetadata,
+ q, k, v, return_softmax_lse):
+ assert isinstance(prefill, CudnnPrefillMetadata)
+ assert prefill.query_seq_lens is not None
+ output, lse = cudnn_batch_prefill_with_kv_cache(
+ q=q,
+ k_cache=k,
+ v_cache=v,
+ scale=self.scale,
+ workspace_buffer=prefill.cudnn_workspace,
+ max_token_per_sequence=prefill.max_query_len,
+ max_sequence_kv=prefill.max_query_len,
+ actual_seq_lens_q=prefill.query_seq_lens.view(-1, 1, 1, 1),
+ actual_seq_lens_kv=prefill.query_seq_lens.view(-1, 1, 1, 1),
+ causal=True,
+ return_lse=True, # do not support False for now
+ is_cuda_graph_compatible=
+ True, #Indicates actual_seq_lens are on GPU or CPU.
+ )
+ if return_softmax_lse:
+ return output, lse
+ return output
+
+ def _run_prefill_context_chunk_fa(self, prefill: MLACommonPrefillMetadata,
+ chunk_idx: int, q, k, v):
+ assert prefill.chunked_context is not None
+ return self._flash_attn_varlen_diff_headdims(
+ q=q,
+ k=k,
+ v=v,
+ context_seq_lod_xpu=prefill.query_start_loc,
+ context_seq_lod_cpu=prefill.chunked_context.cu_seq_lens[chunk_idx],
+ softmax_scale=self.scale,
+ causal=False,
+ return_softmax_lse=True,
+ )
+
+ def _run_prefill_context_chunk_fi(self, prefill: MLACommonPrefillMetadata,
+ chunk_idx: int, q, k, v):
+ assert isinstance(prefill, FlashInferPrefillMetadata)
+ attn_out, lse = prefill.prefill_chunks[chunk_idx].run(
+ q=q,
+ k=k,
+ v=v,
+ return_lse=True,
+ )
+ # Convert from (q_len, num_heads) to (num_heads, q_len)
+ return attn_out, lse.transpose(0, 1).contiguous()
+
    def _run_prefill_context_chunk_cudnn(self,
                                         prefill: MLACommonPrefillMetadata,
                                         chunk_idx: int, q, k, v):
        """Non-causal attention of the queries against one gathered context
        chunk via cuDNN; returns (output, lse) for merging."""
        assert isinstance(prefill, CudnnPrefillMetadata)
        assert prefill.chunked_context is not None
        assert prefill.chunked_context.seq_lens[chunk_idx] is not None
        assert prefill.query_seq_lens is not None
        return cudnn_batch_prefill_with_kv_cache(
            q=q,
            k_cache=k,
            v_cache=v,
            scale=self.scale,
            workspace_buffer=prefill.cudnn_workspace,
            max_token_per_sequence=prefill.max_query_len,
            max_sequence_kv=prefill.chunked_context.max_seq_lens[chunk_idx],
            actual_seq_lens_q=prefill.query_seq_lens.view(-1, 1, 1, 1),
            actual_seq_lens_kv=prefill.chunked_context.seq_lens[chunk_idx].
            view(-1, 1, 1, 1),
            causal=False,
            return_lse=True,
            is_cuda_graph_compatible=
            True,  #Indicates actual_seq_lens are on GPU or CPU.
        )
+
+ def process_weights_after_loading(self, act_dtype: torch.dtype):
+
+ def get_layer_weight(layer):
+ WEIGHT_NAMES = ("weight", "qweight", "weight_packed")
+ for attr in WEIGHT_NAMES:
+ if hasattr(layer, attr):
+ return getattr(layer, attr)
+ raise AttributeError(
+ f"Layer '{layer}' has no recognized weight attribute:"
+ f" {WEIGHT_NAMES}.")
+
+ def get_and_maybe_dequant_weights(layer: LinearBase):
+ if not isinstance(layer.quant_method, UnquantizedLinearMethod):
+ # NOTE: This should only be used offline, since it's O(N^3)
+ eye = torch.eye(layer.input_size_per_partition,
+ dtype=act_dtype,
+ device=get_layer_weight(layer).device)
+ dequant_weights = layer.quant_method.apply(layer,
+ eye,
+ bias=None)
+ del eye
+ # standardize to (output, input)
+ return dequant_weights.T
+ return layer.weight
+
+ # we currently do not have quantized bmm's which are needed for
+ # `W_UV` and `W_UK_T`, we just store fp16/bf16 copies and perform
+ # the bmm's in 16-bit, the extra memory overhead of this is fairly low
+ kv_b_proj_weight = get_and_maybe_dequant_weights(self.kv_b_proj).T
+ assert kv_b_proj_weight.shape == (
+ self.kv_lora_rank,
+ self.num_heads * (self.qk_nope_head_dim + self.v_head_dim)), (
+ f"{kv_b_proj_weight.shape=}, "
+ f"{self.kv_lora_rank=}, "
+ f"{self.num_heads=}, "
+ f"{self.qk_nope_head_dim=}, "
+ f"{self.v_head_dim=}")
+ kv_b_proj_weight = kv_b_proj_weight.view(
+ self.kv_lora_rank,
+ self.num_heads,
+ self.qk_nope_head_dim + self.v_head_dim,
+ )
+
+ W_UK, W_UV = kv_b_proj_weight.split(
+ [self.qk_nope_head_dim, self.v_head_dim], dim=-1)
+
+ if is_rocm_aiter_fp8bmm_enabled():
+ W_K = W_UK.transpose(0, 1) # 16 512 128
+ W_V = W_UV.permute(1, 2, 0) # 16 128 512
+ self.W_K, self.W_K_scale = dynamic_per_batched_tensor_quant(
+ W_K, dtype=current_platform.fp8_dtype())
+ self.W_V, self.W_V_scale = dynamic_per_batched_tensor_quant(
+ W_V, dtype=current_platform.fp8_dtype())
+
+ # The kernel operates on non-padded inputs. Hence, pre-compiling
+ # triton kernel to avoid runtime compilation for unseen batch sizes
+ # Pre-compile for batch sizes 1 to 1024 to cover most use-cases.
+ # On DS-R1, this step adds roughly 50s to the model loading time.
+ max_batch_size = 1024 # [ToDo] Find the optimal upper limit
+ pre_compilation_list = list(range(1, max_batch_size + 1))
+ if is_global_first_rank():
+ pre_compilation_list = tqdm(
+ pre_compilation_list,
+ desc="[Aiter Triton] Pre-compiling fp8 BMM kernel",
+ total=max_batch_size,
+ )
+
+ for m in pre_compilation_list:
+ x = torch.empty((self.W_K.shape[0], m, self.W_K.shape[2]),
+ dtype=torch.bfloat16,
+ device=self.W_K.device)
+ aiter_triton_fp8_bmm(x,
+ self.W_K,
+ self.W_K_scale,
+ group_size=128,
+ transpose_bm=True)
+
+ x = torch.empty((self.W_V.shape[0], m, self.W_V.shape[2]),
+ dtype=torch.bfloat16,
+ device=self.W_V.device)
+ aiter_triton_fp8_bmm(x,
+ self.W_V,
+ self.W_V_scale,
+ group_size=128,
+ transpose_bm=True)
+ else:
+ # Convert from (L, N, V) to (N, L, V)
+ self.W_UV = W_UV.transpose(0, 1)
+ # Convert from (L, N, P) to (N, P, L)
+ self.W_UK_T = W_UK.permute(1, 2, 0)
+
    def gather_and_maybe_dequant_cache_py_optimized(
        self,
        src_cache: torch.Tensor,
        dst: torch.Tensor,
        block_table: torch.Tensor,
        cu_seq_lens: torch.Tensor,
        batch_size: int,
        kv_cache_dtype: str,
        scale: torch.Tensor,
        seq_starts: Optional[torch.Tensor] = None
    ) -> None:
        """Pure-PyTorch gather of paged KV-cache entries into a flat buffer.

        Copies (and, for fp8 caches, dequantizes) each request's cached
        tokens from the paged `src_cache` into contiguous rows of `dst`,
        laid out according to `cu_seq_lens`.

        Args:
            src_cache: paged cache of shape (num_blocks, block_size, head_dim).
            dst: flat destination of shape (tot_tokens, head_dim), written
                in place.
            block_table: (batch_size, max_blocks_per_seq) physical block ids;
                negative ids are treated as invalid and skipped.
            cu_seq_lens: (batch_size + 1,) cumulative token counts per request.
            batch_size: number of requests being gathered.
            kv_cache_dtype: 'auto' copies values as-is; fp8 variants are
                dequantized with `scale`.
            scale: scalar dequantization scale for fp8 caches.
            seq_starts: optional per-request start offset into the cached
                sequence (used when gathering a later context chunk).
        """
        device = src_cache.device
        num_blocks, block_size, head_dim = src_cache.shape
        tot_tokens = dst.shape[0]

        # Flatten pages so a single fancy-index gathers all tokens at once.
        src_cache_2d = src_cache.reshape(-1, head_dim)

        token_ids = torch.arange(tot_tokens, device=device, dtype=torch.int32)

        # Map every destination row to its request index.
        # NOTE(review): this mask loop is O(batch_size * tot_tokens); a
        # single torch.bucketize/searchsorted over cu_seq_lens would give the
        # same mapping for rows below cu_seq_lens[-1]. Rows at or beyond
        # cu_seq_lens[-1] (dst may be the full workspace) keep batch id 0
        # and can produce large offsets below — confirm callers only consume
        # dst up to cu_seq_lens[-1] and that the indexing stays in bounds.
        batch_ids = torch.zeros(tot_tokens, dtype=torch.int32, device=device)
        for i in range(batch_size):
            start = cu_seq_lens[i]
            end = cu_seq_lens[i + 1]
            mask = (token_ids >= start) & (token_ids < end)
            batch_ids[mask] = i

        # Position of each row inside its request's cached sequence.
        batch_starts = cu_seq_lens[batch_ids]
        batch_offsets = token_ids - batch_starts

        if seq_starts is not None:
            seq_start_offsets = seq_starts[batch_ids]
            batch_offsets = batch_offsets + seq_start_offsets

        # Translate sequence positions to (page, slot) coordinates.
        block_table_ids = batch_offsets // block_size
        slot_ids = batch_offsets % block_size

        block_table_flat = block_table.view(-1)
        block_table_stride = block_table.shape[1]
        block_table_indices = batch_ids * block_table_stride + block_table_ids
        block_ids = block_table_flat[block_table_indices]

        # Invalid (negative) block ids read page 0 but are masked out of the
        # final write.
        valid_mask = block_ids >= 0
        linear_indices = torch.where(
            valid_mask,
            block_ids * block_size + slot_ids,
            torch.zeros_like(block_ids, device=device)
        )

        if kv_cache_dtype == 'auto':
            gathered = src_cache_2d[linear_indices.long()]
            dst[valid_mask] = gathered[valid_mask].to(dst.dtype)
        else:
            # Quantized cache: widen to float before applying the scale.
            gathered = src_cache_2d[linear_indices.long()].float()
            if kv_cache_dtype in ['fp8', 'fp8_e4m3', 'fp8_e5m2']:
                gathered = gathered * scale.item()
            dst[valid_mask] = gathered[valid_mask].to(dst.dtype)
+
    def _compute_prefill_context(
        self,
        q: torch.Tensor,
        kv_c_and_k_pe_cache: torch.Tensor,
        attn_metadata: MLACommonMetadata,
        k_scale: torch.Tensor,
    ):
        """Attend the prefill queries to the cached (already computed)
        context, chunk by chunk.

        For each context chunk: gather the cached latent kv into the shared
        workspace, up-project it through `kv_b_proj` into per-head k/v, run
        non-causal attention for the chunk, and fold the result into the
        running output with `merge_attn_states` using the softmax LSE.

        Returns:
            (output, output_lse) for the full cached context.
        """
        assert attn_metadata.prefill is not None
        prefill_metadata = attn_metadata.prefill
        assert prefill_metadata.chunked_context is not None

        output = None
        iters = len(prefill_metadata.chunked_context.seq_tot)
        workspace = prefill_metadata.chunked_context.workspace

        for i in range(iters):
            toks = prefill_metadata.chunked_context.seq_tot[i]
            # Gather this chunk's cached latent tokens into the workspace.
            self.gather_and_maybe_dequant_cache_py_optimized(
                src_cache=kv_c_and_k_pe_cache,
                dst=workspace,
                block_table=prefill_metadata.block_table,
                cu_seq_lens=prefill_metadata.chunked_context.cu_seq_lens[i],
                batch_size=attn_metadata.num_prefills,
                kv_cache_dtype=self.kv_cache_dtype,
                scale=k_scale,
                seq_starts=prefill_metadata.chunked_context.starts[i],
            )

            # Split the gathered rows into the latent kv part and the
            # rope'd k part.
            kv_c_normed = workspace[:toks]\
                [..., :self.kv_lora_rank]
            k_pe = workspace[:toks]\
                [..., self.kv_lora_rank:].unsqueeze(1)

            # Up-project latent kv to per-head k_nope and v.
            kv_nope = self.kv_b_proj(kv_c_normed)[0].view( \
                -1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim)
            k_nope, v = kv_nope\
                .split([self.qk_nope_head_dim, self.v_head_dim], dim=-1)

            # k_pe is shared across heads; broadcast it to every head.
            k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))),
                          dim=-1)

            attn_output, attn_softmax_lse = self._run_prefill_context_chunk(
                prefill=prefill_metadata,
                chunk_idx=i,
                q=q,
                k=k,
                v=v,
            )

            if output is None:
                output = attn_output
                output_lse = attn_softmax_lse
            else:
                # Merge this chunk's result with the accumulated output
                # using the numerically stable LSE-weighted combination.
                output_tmp = torch.empty_like(output)
                output_lse_tmp = torch.empty_like(output_lse)
                merge_attn_states(
                    output=output_tmp,
                    output_lse=output_lse_tmp,
                    prefix_output=output,
                    prefix_lse=output_lse,
                    suffix_output=attn_output,
                    suffix_lse=attn_softmax_lse,
                )
                output = output_tmp
                output_lse = output_lse_tmp

        return output, output_lse
+
+    def _context_parallel_compute_prefill_context(
+        self,
+        q: torch.Tensor,
+        kv_c_and_k_pe_cache: torch.Tensor,
+        attn_metadata: MLACommonMetadata,
+        k_scale: torch.Tensor,
+        dcp_world_size: int,
+    ):
+        """Chunked-context prefill attention under decode context
+        parallelism (DCP).
+
+        Like ``_compute_prefill_context`` but each rank only holds a shard
+        of the context KV cache: each chunk is gathered locally, all-gathered
+        across the DCP group into the tail of the workspace, then reordered
+        by ``reorg_kvcache`` back into per-request sequence order before the
+        chunk attention kernel runs.
+
+        Returns:
+            Tuple ``(output, output_lse)`` of the merged attention output
+            and its LSE.
+        """
+        # Scaled (fp8) caches are unsupported in DCP mode, so no scale may
+        # be passed in.
+        assert k_scale is None, "DCP not support scaled kvcache now."
+        assert attn_metadata.prefill is not None
+        prefill_metadata = attn_metadata.prefill
+        assert prefill_metadata.chunked_context is not None
+        assert prefill_metadata.chunked_context.cp_chunk_seq_lens is not None
+        assert prefill_metadata.chunked_context.origin_context_lens is not None
+        assert prefill_metadata.chunked_context.cp_cu_seq_lens is not None
+        assert prefill_metadata.chunked_context.chunk_size is not None
+        assert prefill_metadata.chunked_context.cu_seq_lens_lst is not None
+
+        output = None
+        iters = len(prefill_metadata.chunked_context.seq_tot)
+        workspace = prefill_metadata.chunked_context.workspace
+
+        for i in range(iters):
+            toks = prefill_metadata.chunked_context.seq_tot[i]
+            ops.cp_gather_cache(
+                src_cache=kv_c_and_k_pe_cache,
+                dst=workspace,
+                block_table=prefill_metadata.block_table,
+                cu_seq_lens=prefill_metadata.chunked_context.cp_cu_seq_lens[i],
+                batch_size=attn_metadata.num_prefills,
+                seq_starts=prefill_metadata.chunked_context.starts[i],
+            )
+            # workspace
+            # |------- N tokens --------|--------- N*dcp_size tokens ----------|
+            # |<- use for loca_gather ->|<--------- use for allgather -------->|
+            allgather_offset = workspace.shape[0] // (dcp_world_size + 1)
+            assert allgather_offset * (dcp_world_size +
+                                       1) == workspace.shape[0]
+            assert toks <= allgather_offset
+            local_gathered_kvcache = workspace[:toks]
+            cur_allgather_workspace = workspace[
+                allgather_offset:allgather_offset * (1 + dcp_world_size)]
+            assert toks * dcp_world_size <= cur_allgather_workspace.shape[0]
+            cur_allgather_kvcache = cur_allgather_workspace[:toks *
+                                                            dcp_world_size]
+            # Collect every rank's local gather into the allgather region.
+            cur_allgather_kvcache.copy_(get_dcp_group().all_gather(
+                local_gathered_kvcache, dim=0))
+            assert cur_allgather_kvcache.shape[
+                -1] == self.kv_lora_rank + self.qk_rope_head_dim
+            allgatered_kv_c_normed, allgatered_k_pe = \
+                cur_allgather_kvcache.unsqueeze(
+                    1).split([self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
+
+            # Reorder the rank-interleaved all-gather output back into
+            # contiguous per-request sequence order.
+            kv_c_normed, k_pe = reorg_kvcache(
+                allgatered_kv_c_normed,
+                allgatered_k_pe,
+                cp_chunk_seq_lens_lst=prefill_metadata.chunked_context.
+                cp_chunk_seq_lens[i],
+                origin_context_lens=prefill_metadata.chunked_context.
+                origin_context_lens,
+                cp_world_size=dcp_world_size,
+                sum_seq_len=prefill_metadata.chunked_context.cu_seq_lens_lst[i]
+                [-1],
+                max_seq_len=prefill_metadata.chunked_context.max_seq_lens[i],
+                chunk_size=prefill_metadata.chunked_context.chunk_size,
+                chunk_idx=i,
+                toks=toks)
+
+            # Up-project latent vectors into per-head K (nope) and V.
+            kv_nope = self.kv_b_proj(kv_c_normed)[0].view( \
+                -1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim)
+            k_nope, v = kv_nope\
+                .split([self.qk_nope_head_dim, self.v_head_dim], dim=-1)
+            k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))),
+                          dim=-1)
+
+            attn_output, attn_softmax_lse = self._run_prefill_context_chunk(
+                prefill=prefill_metadata,
+                chunk_idx=i,
+                q=q,
+                k=k,
+                v=v,
+            )
+
+            if output is None:
+                output = attn_output
+                output_lse = attn_softmax_lse
+            else:
+                # Merge this chunk's partial attention state with the
+                # accumulated result using the LSEs.
+                output_tmp = torch.empty_like(output)
+                output_lse_tmp = torch.empty_like(output_lse)
+                merge_attn_states(
+                    output=output_tmp,
+                    output_lse=output_lse_tmp,
+                    prefix_output=output,
+                    prefix_lse=output_lse,
+                    suffix_output=attn_output,
+                    suffix_lse=attn_softmax_lse,
+                )
+                output = output_tmp
+                output_lse = output_lse_tmp
+
+        return output, output_lse
+
+    def _forward_prefill(
+        self,
+        q: torch.Tensor,
+        kv_c_normed: torch.Tensor,
+        k_pe: torch.Tensor,
+        kv_c_and_k_pe_cache: torch.Tensor,
+        attn_metadata: MLACommonMetadata,
+        k_scale: torch.Tensor,
+    ) -> torch.Tensor:
+        """Run MLA attention for the prefill tokens.
+
+        Computes attention of the new prefill tokens over themselves, and —
+        when chunked context exists — also over the cached context (via the
+        DCP-aware or regular context path), merging both results with their
+        softmax LSEs.  Returns the attention output flattened over the last
+        two (head, head_dim) dimensions.
+        """
+        # TODO (zyongye): Prefill function here
+        assert attn_metadata.prefill is not None
+        assert self.dcp_world_size is not None
+
+        has_context = attn_metadata.prefill.chunked_context is not None
+        # Up-project the latent vectors into per-head K (nope) and V.
+        kv_nope = self.kv_b_proj(kv_c_normed)[0].view(\
+            -1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim)
+        k_nope, v = kv_nope\
+            .split([self.qk_nope_head_dim, self.v_head_dim], dim=-1)
+
+        k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))), dim=-1)
+
+        # LSE is only needed when we must merge with context attention.
+        output = self._run_prefill_new_tokens(
+            prefill=attn_metadata.prefill,
+            q=q,
+            k=k,
+            v=v,
+            return_softmax_lse=has_context,
+        )
+
+        if has_context:
+            suffix_output, suffix_lse = output
+            if self.dcp_world_size > 1:
+                context_output, context_lse = \
+                    self._context_parallel_compute_prefill_context(
+                        q, kv_c_and_k_pe_cache, attn_metadata,
+                        k_scale=None, dcp_world_size=self.dcp_world_size)
+            else:
+                context_output, context_lse = \
+                    self._compute_prefill_context(
+                        q, kv_c_and_k_pe_cache, attn_metadata, k_scale)
+
+            # Merge context (prefix) and new-token (suffix) attention.
+            output = torch.empty_like(suffix_output)
+            merge_attn_states(
+                output=output,
+                prefix_output=context_output,
+                prefix_lse=context_lse,
+                suffix_output=suffix_output,
+                suffix_lse=suffix_lse,
+            )
+
+        # unpad if necessary
+        if self._pad_v:
+            output = output[..., :v.shape[-1]]
+
+        return output.flatten(start_dim=-2)
+
+    @abstractmethod
+    def _forward_decode(
+        self,
+        q: Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]],
+        kv_c_and_k_pe_cache: torch.Tensor,
+        attn_metadata: M,
+        layer: AttentionLayer,
+    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
+        """Run decode attention; implemented by each concrete backend.
+
+        ``q`` is either a single tensor or a ``(q_nope, q_pe)`` pair.
+        Returns ``(attn_out, lse)`` where ``lse`` may be ``None`` if the
+        backend cannot return it.
+        """
+        raise NotImplementedError
+
+    def forward(
+        self,
+        layer: AttentionLayer,
+        q: torch.Tensor,
+        k_c_normed: torch.Tensor,  # key in unified attn
+        k_pe: torch.Tensor,  # value in unified attn
+        kv_cache: torch.Tensor,
+        attn_metadata: M,
+        output: Optional[torch.Tensor] = None,
+        output_scale: Optional[torch.Tensor] = None,
+        output_block_scale: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        """MLA attention entry point.
+
+        Writes new latent/RoPE entries into the paged KV cache, then runs
+        the prefill path for prefill tokens and the (optionally DCP /
+        fp8-quantized) decode path for decode tokens, writing results into
+        ``output`` in place and returning it (padded view).
+        """
+        assert output is not None, "Output tensor must be provided."
+
+        if output_scale is not None or output_block_scale is not None:
+            raise NotImplementedError(
+                "fused output quantization is not yet supported"
+                " for MLACommonImpl")
+
+        if attn_metadata is None:
+            # During the profile run try to simulate to worse case output size
+            # for `self.kv_b_proj(kv_c_normed)` in `_compute_prefill_context`
+            # since this can be large
+            _ = torch.empty(
+                (self.chunked_prefill_workspace_size, self.num_heads,
+                 self.qk_nope_head_dim + self.v_head_dim),
+                device=k_c_normed.device,
+                dtype=k_c_normed.dtype,
+            )
+
+            # The zero fill is required when used with DP + EP
+            # to ensure all ranks within a DP group compute the
+            # same expert outputs.
+            return output.fill_(0)
+
+        # Lazily resolve the DCP world size on first real forward.
+        if self.dcp_world_size is None:
+            self.dcp_world_size = get_dcp_group().world_size
+
+        fp8_attention = self.kv_cache_dtype.startswith("fp8")
+
+        num_actual_toks = attn_metadata.num_actual_tokens
+
+        # Inputs and outputs may be padded for CUDA graphs
+        output_padded = output
+        output = output[:num_actual_toks, ...]
+        q = q[:num_actual_toks, ...]
+        k_c_normed = k_c_normed[:num_actual_toks, ...]
+        k_pe = k_pe[:num_actual_toks, ...]
+
+        assert attn_metadata.num_decodes is not None and \
+            attn_metadata.num_prefills is not None and \
+            attn_metadata.num_decode_tokens is not None
+
+        has_decode = attn_metadata.num_decodes > 0
+        has_prefill = attn_metadata.num_prefills > 0
+        num_decode_tokens = attn_metadata.num_decode_tokens
+
+        # Decode tokens are ordered before prefill tokens in the batch.
+        decode_q = q[:num_decode_tokens]
+
+        prefill_q = q[num_decode_tokens:]
+        prefill_k_pe = k_pe[num_decode_tokens:]
+        prefill_k_c_normed = k_c_normed[num_decode_tokens:]
+
+        # write the latent and rope to kv cache
+        if kv_cache.numel() > 0:
+            xtorch_ops.concat_and_cache_mla(
+                k_c_normed,
+                k_pe.squeeze(1),
+                attn_metadata.slot_mapping.flatten(),
+                kv_cache,
+            )
+
+        if fp8_attention:
+            kv_cache = kv_cache.view(current_platform.fp8_dtype())
+
+        if has_prefill:
+            output[num_decode_tokens:] = self._forward_prefill(
+                prefill_q, prefill_k_c_normed, prefill_k_pe, kv_cache,
+                attn_metadata, layer._k_scale)
+
+        if has_decode:
+            assert attn_metadata.decode is not None
+            decode_q_nope, decode_q_pe = decode_q.split(
+                [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
+            # Convert from (B, N, P) to (N, B, P)
+            decode_q_nope = decode_q_nope.transpose(0, 1)
+
+            # Pads the head_dim if necessary (for the underlying kernel)
+            # NOTE(review): allocate at padded head count, then resize the
+            # view back down so the kernel sees padded storage — presumably
+            # a kernel alignment requirement; confirm against the kernel.
+            if self.q_pad_num_heads is not None:
+                B, N, L = decode_q_pe.shape
+                decode_pe_padded = decode_q_pe.new_empty(
+                    (B, self.q_pad_num_heads, L))
+                decode_pe_padded.resize_((B, N, L))
+                decode_pe_padded.copy_(decode_q_pe)
+                decode_q_pe = decode_pe_padded
+
+            if is_rocm_aiter_fp8bmm_enabled():
+                # Multiply+Transpose (N, B, P)x(N, P, L)->(N, B, L)->(B, N, L)
+                decode_ql_nope = aiter_triton_fp8_bmm(decode_q_nope,
+                                                      self.W_K,
+                                                      self.W_K_scale,
+                                                      group_size=128,
+                                                      transpose_bm=True)
+            else:
+                # Pads the head_dim if necessary (for the underlying kernel)
+                N, B, P = decode_q_nope.shape
+                _, _, L = self.W_UK_T.shape
+                if self.q_pad_num_heads is not None:
+                    decode_ql_nope = decode_q_nope.new_empty(
+                        (self.q_pad_num_heads, B, L))
+                    decode_ql_nope.resize_((N, B, L))
+
+                else:
+                    decode_ql_nope = decode_q_nope.new_empty((N, B, L))
+
+                # Multiply (N, B, P) x (N, P, L) -> (N, B, L)
+                torch.bmm(decode_q_nope, self.W_UK_T, out=decode_ql_nope)
+                # Convert from (N, B, L) to (B, N, L)
+                decode_ql_nope = decode_ql_nope.transpose(0, 1)
+
+            if fp8_attention:
+                # Quantize both query parts to fp8 with the layer's q scale.
+                ql_nope_shape = decode_ql_nope.shape
+                decode_ql_nope, _ = ops.scaled_fp8_quant(
+                    decode_ql_nope.reshape([
+                        ql_nope_shape[0], ql_nope_shape[1] * ql_nope_shape[2]
+                    ]), layer._q_scale)
+                decode_ql_nope = decode_ql_nope.reshape(ql_nope_shape)
+                q_pe_shape = decode_q_pe.shape
+                decode_q_pe, _ = ops.scaled_fp8_quant(
+                    decode_q_pe.reshape(
+                        [q_pe_shape[0], q_pe_shape[1] * q_pe_shape[2]]),
+                    layer._q_scale)
+                decode_q_pe = decode_q_pe.reshape(q_pe_shape)
+
+            decode_q = (decode_ql_nope, decode_q_pe)
+            if self.dcp_world_size > 1:
+                assert not fp8_attention, "DCP not support fp8 kvcache now."
+                # concatenate decode_ql_nope and decode_q_pe -> (B, N, L + P)
+                decode_q = torch.cat(decode_q, dim=-1)
+                # decode_q do allgather in head dim.
+                decode_q = get_dcp_group().all_gather(decode_q, dim=1)
+
+            # call decode attn
+            attn_out, lse = self._forward_decode(decode_q, kv_cache,
+                                                 attn_metadata, layer)
+
+            # recorect dcp attn_out with lse.
+            if self.dcp_world_size > 1:
+                attn_out = cp_lse_ag_out_rs(attn_out, lse, get_dcp_group())
+
+            # v_up projection
+            self._v_up_proj(attn_out, out=output[:num_decode_tokens])
+        return output_padded
diff --git a/vllm_kunlun/v1/attention/backends/mla/flashmla.py b/vllm_kunlun/v1/attention/backends/mla/flashmla.py
new file mode 100644
index 0000000..46268eb
--- /dev/null
+++ b/vllm_kunlun/v1/attention/backends/mla/flashmla.py
@@ -0,0 +1,202 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+from dataclasses import dataclass
+from typing import ClassVar, Optional, Union
+
+import torch
+
+from vllm.attention.backends.abstract import AttentionLayer, AttentionType
+from vllm_kunlun.ops.attention.flashmla import (flash_mla_with_kvcache,
+ get_mla_metadata,
+ is_flashmla_supported)
+from vllm.config import VllmConfig
+from vllm.logger import init_logger
+from vllm_kunlun.v1.attention.backends.mla.common import (MLACommonBackend,
+ MLACommonDecodeMetadata,
+ MLACommonImpl,
+ MLACommonMetadata,
+ MLACommonMetadataBuilder)
+from vllm.v1.attention.backends.utils import AttentionCGSupport
+from vllm.v1.kv_cache_interface import AttentionSpec
+
+logger = init_logger(__name__)
+
+class FlashMLABackend(MLACommonBackend):
+    """MLA attention backend backed by the FlashMLA kernels.
+
+    Only wires up the backend name and the metadata / builder / impl
+    classes; all behavior lives in those classes.
+    """
+
+    @staticmethod
+    def get_name() -> str:
+        return "FLASHMLA"
+
+    @staticmethod
+    def get_metadata_cls() -> type["FlashMLAMetadata"]:
+        return FlashMLAMetadata
+
+    @staticmethod
+    def get_builder_cls() -> type["FlashMLAMetadataBuilder"]:
+        return FlashMLAMetadataBuilder
+
+    @staticmethod
+    def get_impl_cls() -> type["FlashMLAImpl"]:
+        return FlashMLAImpl
+
+
+@dataclass
+class FlashMLADecodeMetadata(MLACommonDecodeMetadata):
+    """Decode-path metadata specific to the FlashMLA kernel."""
+
+    # Per-SM tile scheduling metadata produced by `get_mla_metadata`.
+    tile_scheduler_metadata: torch.Tensor
+    # Per-batch split counts produced by `get_mla_metadata`.
+    num_splits: torch.Tensor
+
+
+@dataclass
+class FlashMLAMetadata(MLACommonMetadata[FlashMLADecodeMetadata]):
+    """FlashMLA attention metadata (common metadata specialized to the
+    FlashMLA decode metadata type; adds no extra fields)."""
+    pass
+
+
+class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]):
+    """Builds :class:`FlashMLAMetadata`, including the FlashMLA tile
+    scheduler metadata, with persistent buffers for full CUDA graphs."""
+
+    cudagraph_support: ClassVar[AttentionCGSupport] = \
+        AttentionCGSupport.UNIFORM_BATCH
+
+    def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
+                 vllm_config: VllmConfig, device: torch.device):
+        super().__init__(kv_cache_spec, layer_names, vllm_config, device,
+                         FlashMLAMetadata)
+
+        self.num_q_heads = vllm_config.model_config.get_num_attention_heads(
+            vllm_config.parallel_config)
+
+        # Persistent buffers for CUDA-graph capture; only allocated when
+        # full cudagraphs are enabled.
+        self.cg_buf_tile_scheduler_metadata = None
+        self.cg_buf_num_splits = None
+
+        device_properties = torch.cuda.get_device_properties(self.device)
+        num_sms = device_properties.multi_processor_count
+
+        if self.compilation_config.cudagraph_mode.has_full_cudagraphs():
+            self.cg_buf_tile_scheduler_metadata = torch.zeros(
+                # Upper bound on size (<= #SMs, TileSchedulerMetaDataSize)
+                # TileSchedulerMetaDataSize = 8
+                (num_sms, 8),
+                device=self.device,
+                dtype=torch.int32,
+            )
+            self.cg_buf_num_splits = torch.empty(
+                (vllm_config.scheduler_config.max_num_seqs + 1),
+                device=self.device,
+                dtype=torch.int32)
+
+    def _build_decode(self, block_table_tensor: torch.Tensor,
+                      seq_lens_cpu: torch.Tensor,
+                      seq_lens_device: torch.Tensor,
+                      query_start_loc_cpu: torch.Tensor,
+                      query_start_loc_device: torch.Tensor,
+                      num_decode_tokens: int) -> FlashMLADecodeMetadata:
+        """Compute FlashMLA scheduling metadata for the decode tokens,
+        copying it into the persistent buffers when full cudagraphs are on."""
+        tile_scheduler_metadata, num_splits = \
+            get_mla_metadata(
+                seq_lens_device,
+                self.num_q_heads,
+                1,  # MQA for the decode path
+            )
+
+        # TODO: we can disambiguate between decode and mixed-prefill decode here
+        # so we can only use the persistent buffer if a cudagraph is actually
+        # being used.
+        if self.compilation_config.cudagraph_mode.has_full_cudagraphs():
+            assert self.cg_buf_tile_scheduler_metadata is not None
+            assert self.cg_buf_num_splits is not None
+
+            sm_parts = tile_scheduler_metadata.size(0)
+            # Metadata per-SM, upper bound on size (<= #SMs, TileMetadataSize)
+            assert sm_parts <= self.cg_buf_tile_scheduler_metadata.size(0)
+            tile_scheduler_metadata_view = \
+                self.cg_buf_tile_scheduler_metadata[:sm_parts]
+            tile_scheduler_metadata_view.copy_(tile_scheduler_metadata)
+            tile_scheduler_metadata = tile_scheduler_metadata_view
+
+            # Num splits is per-batch, varying size (batch_size,)
+            n = num_splits.size(0)
+            # make sure static buffer is large enough
+            assert n <= self.cg_buf_num_splits.size(0)
+            num_splits_view = self.cg_buf_num_splits[:n]
+            num_splits_view.copy_(num_splits)
+            # Num splits needs to be monotonically increasing
+            # (with: https://github.com/vllm-project/FlashMLA/pull/3, otherwise
+            # it needs to increase monotonically by 1)
+            self.cg_buf_num_splits[n:].fill_(num_splits[-1])
+            num_splits = num_splits_view
+
+        return FlashMLADecodeMetadata(
+            block_table=block_table_tensor,
+            seq_lens=seq_lens_device,
+            tile_scheduler_metadata=tile_scheduler_metadata,
+            num_splits=num_splits,
+        )
+
+
+class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
+    """MLA attention implementation that dispatches decode attention to
+    the FlashMLA paged-KV kernel.
+
+    Prefill handling is inherited from :class:`MLACommonImpl`; only the
+    decode path is specialized here.
+    """
+
+    can_return_lse_for_decode: bool = True
+
+    def __init__(
+            self,
+            num_heads: int,
+            head_size: int,
+            scale: float,
+            num_kv_heads: int,
+            alibi_slopes: Optional[list[float]],
+            sliding_window: Optional[int],
+            kv_cache_dtype: str,
+            logits_soft_cap: Optional[float],
+            attn_type: str,
+            kv_sharing_target_layer_name: Optional[str],
+            # MLA Specific Arguments
+            **mla_args) -> None:
+        super().__init__(num_heads, head_size, scale, num_kv_heads,
+                         alibi_slopes, sliding_window, kv_cache_dtype,
+                         logits_soft_cap, attn_type,
+                         kv_sharing_target_layer_name, **mla_args)
+
+        # Fail fast with the library's reason if FlashMLA cannot run here.
+        is_supported, reason = is_flashmla_supported()
+        assert is_supported, reason
+
+        unsupported_features = [alibi_slopes, sliding_window, logits_soft_cap]
+        if any(unsupported_features):
+            raise NotImplementedError(
+                "FlashMLAImpl does not support one of the following: "
+                "alibi_slopes, sliding_window, logits_soft_cap")
+
+        if attn_type != AttentionType.DECODER:
+            raise NotImplementedError("Encoder self-attention and "
+                                      "encoder/decoder cross-attention "
+                                      "are not implemented for "
+                                      "FlashMLAImpl")
+
+    def _forward_decode(
+        self,
+        q: Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]],
+        kv_c_and_k_pe_cache: torch.Tensor,
+        attn_metadata: FlashMLAMetadata,
+        layer: AttentionLayer,
+    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
+        """Decode attention via ``flash_mla_with_kvcache``.
+
+        Accepts either a fused query tensor or a ``(q_nope, q_pe)`` pair
+        (concatenated here).  Returns ``(attn_out, lse)``.
+        """
+        # TODO: (zyongye) decode function for mla here
+        assert kv_c_and_k_pe_cache.numel() > 0
+        assert attn_metadata.decode is not None
+
+        if type(q) is tuple:
+            q = torch.cat(q, dim=-1)
+
+        assert isinstance(q, torch.Tensor)
+        o, lse = flash_mla_with_kvcache(
+            q=q.unsqueeze(1),  # Add seqlen dim of 1 (decode)
+            k_cache=kv_c_and_k_pe_cache.unsqueeze(-2),  # Add head dim of 1
+            block_table=attn_metadata.decode.block_table,
+            cache_seqlens=attn_metadata.decode.seq_lens,
+            head_dim_v=self.kv_lora_rank,
+            tile_scheduler_metadata=attn_metadata.decode.
+            tile_scheduler_metadata,
+            num_splits=attn_metadata.decode.num_splits,
+            softmax_scale=self.scale,
+            causal=True,
+            descale_q=layer._q_scale.reshape(1),
+            descale_k=layer._k_scale.reshape(1),
+        )
+
+        return o, lse
diff --git a/vllm_kunlun/v1/attention/backends/mla/flashmla_sparse.py b/vllm_kunlun/v1/attention/backends/mla/flashmla_sparse.py
new file mode 100644
index 0000000..3a3a4b5
--- /dev/null
+++ b/vllm_kunlun/v1/attention/backends/mla/flashmla_sparse.py
@@ -0,0 +1,752 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+import math
+from dataclasses import dataclass
+from typing import TYPE_CHECKING, ClassVar, Optional
+
+import numpy as np
+import torch
+
+from vllm.attention.backends.abstract import (AttentionBackend, AttentionLayer,
+ AttentionMetadata)
+from vllm.attention.backends.utils import get_mla_dims
+from vllm_kunlun.ops.attention.flashmla import (flash_mla_sparse_prefill,
+ flash_mla_with_kvcache,
+ get_mla_metadata,
+ kunlun_flash_mla_with_kvcache)
+from vllm.config import VllmConfig
+from vllm.logger import init_logger
+from vllm.platforms import current_platform
+from vllm.triton_utils import tl, triton
+from vllm.utils import cdiv
+from vllm.v1.attention.backends.mla.common import MLACommonBaseImpl
+from vllm.v1.attention.backends.utils import (AttentionCGSupport,
+ AttentionMetadataBuilder,
+ CommonAttentionMetadata,
+ reshape_attn_output_for_spec_decode,
+ reshape_query_for_spec_decode,
+ split_decodes_and_prefills)
+from vllm.v1.kv_cache_interface import AttentionSpec
+from vllm.distributed import get_tp_group
+
+if TYPE_CHECKING:
+ from vllm.model_executor.models.deepseek_v2 import Indexer
+
+logger = init_logger(__name__)
+"""
+NOTE: FlashMLA Sparse uses an fp8 cache with the following format
+
+In the "FP8 with scale" format, each token's KV cache is 656 Bytes,
+structured as:
+- **First 512 bytes:** The "quantized NoPE" part, containing 512
+ `float8_e4m3` values.
+- **Next 16 bytes:** Scale factors, containing 4 `float32` values.
+ The first `float32` is the scale for the first 128 `float8_e4m3` values,
+ the second for the next 128, and so on.
+- **Last 128 bytes:** The "RoPE" part, containing 64 `bfloat16` values. This
+ part is not quantized for accuracy.
+"""
+
+
+def _lse2_to_lse(lse_base2: torch.Tensor) -> torch.Tensor:
+ # Convert base-2 LSE to natural-log LSE
+ # Keep FP32 for numerical stability during the merge.
+ return (lse_base2.to(torch.float32) * math.log(2.0))
+
+
+class FlashMLASparseBackend(AttentionBackend):
+    """Sparse (top-k indexed) MLA attention backend using FlashMLA."""
+
+    accept_output_buffer: bool = True
+
+    @staticmethod
+    def get_name() -> str:
+        return "FLASHMLA_SPARSE_VLLM_V1"
+
+    @staticmethod
+    def get_metadata_cls() -> type[AttentionMetadata]:
+        return FlashMLASparseMetadata
+
+    @staticmethod
+    def get_builder_cls() -> type["FlashMLASparseMetadataBuilder"]:
+        return FlashMLASparseMetadataBuilder
+
+    @staticmethod
+    def get_impl_cls() -> type["FlashMLASparseImpl"]:
+        return FlashMLASparseImpl
+
+    @staticmethod
+    def get_kv_cache_shape(
+        num_blocks: int,
+        block_size: int,
+        num_kv_heads: int,  # assumed to be 1 for MLA
+        head_size: int,
+        cache_dtype_str: str = "auto",
+    ) -> tuple[int, ...]:
+        """Return the per-layer KV cache shape for this backend."""
+        if cache_dtype_str == "fp8_ds_mla":
+            # custom storage format is 656 bytes per token
+            # see FlashMLA readme.md for details
+            return (num_blocks, block_size, 656)
+        else:
+            return (num_blocks, block_size, head_size)
+
+    @classmethod
+    def get_supported_dtypes(cls) -> list[torch.dtype]:
+        return [torch.bfloat16]
+
+    @classmethod
+    def get_supported_head_sizes(cls) -> list[int]:
+        return [576]
+
+
+@dataclass
+class MLASparsePrefillMetadata:
+    """Prefill-side metadata for sparse MLA attention."""
+
+    # NOTE(Chen): not call it "FlashMLASparsePrefillMetadata" because
+    # the kernel is not from flashmla
+    block_table: Optional[torch.Tensor] = None
+    has_context: bool = False
+    context_lens: Optional[torch.Tensor] = None
+
+    # Sequence lengths (context + query) for prefill requests
+    # Shape: [num_prefill_reqs]
+    seq_lens: Optional[torch.Tensor] = None
+
+    # Request ID for each token: -1 for decode tokens, request index
+    # (0, 1, 2, ...) for prefill tokens.
+    # Shape: [num_actual_tokens]
+    request_ids: Optional[torch.Tensor] = None
+    query_start_loc: Optional[torch.Tensor] = None
+    query_start_loc_cpu: Optional[torch.Tensor] = None
+
+@dataclass
+class FlashMLASparseDecodeAndContextMetadata:
+    """Decode / context-chunk metadata for the sparse FlashMLA kernel."""
+
+    scheduler_metadata: Optional[torch.Tensor] = None
+    num_splits: Optional[torch.Tensor] = None
+    cache_lens: Optional[torch.Tensor] = None
+    prefill_context_lengths: Optional[torch.Tensor] = None
+    prefill_new_k_start_locs: Optional[torch.Tensor] = None
+    dummy_block_table: Optional[torch.Tensor] = None
+
+    seq_lens: Optional[torch.Tensor] = None
+    seq_lens_cpu: Optional[torch.Tensor] = None
+    max_seq_len: int = -1  # needed for reshape in spec decode
+
+    def filter_prefill_indices(
+            self, indices: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
+        """Split top-k indices into context vs. new-token indices.
+
+        Entries that fall on the other side of the context boundary are
+        replaced with -1; new-token indices are rebased to start at 0.
+        """
+        assert self.prefill_context_lengths is not None
+        prefill_context_lengths = self.prefill_context_lengths.unsqueeze(-1)
+        context_indices = torch.where(indices < prefill_context_lengths,
+                                      indices, -1)
+        new_token_indices = torch.where(indices >= prefill_context_lengths,
+                                        indices - prefill_context_lengths, -1)
+        return context_indices, new_token_indices
+
+
+@dataclass
+class FlashMLASparseMetadata:
+    """Top-level attention metadata for the sparse FlashMLA backend."""
+
+    num_reqs: int
+    max_query_len: int
+    max_seq_len: int
+
+    num_actual_tokens: int  # Number of tokens excluding padding.
+    query_start_loc: torch.Tensor
+    slot_mapping: torch.Tensor
+
+    block_table: torch.Tensor
+    # Request index for each token (-1 for padding; see the builder).
+    req_id_per_token: torch.Tensor
+    block_size: int = 64
+    # Number of top-k context tokens each query attends to.
+    topk_tokens: int = 2048
+
+    num_prefills: int = 0
+    num_decodes: int = 0
+    num_prefill_tokens: int = 0
+    num_decode_tokens: int = 0
+
+    decode_metadata: Optional[FlashMLASparseDecodeAndContextMetadata] = None
+    prefill_metadata: Optional[MLASparsePrefillMetadata] = None
+
+    @dataclass
+    class FP8KernelMetadata:
+        """Extra metadata needed by the fp8 FlashMLA kernel path."""
+        scheduler_metadata: Optional[torch.Tensor]
+        num_splits: torch.Tensor
+        dummy_block_table: torch.Tensor
+        cache_lens: torch.Tensor
+
+    fp8_extra_metadata: Optional[FP8KernelMetadata] = None
+
+
+@triton.jit
+def _convert_req_index_to_global_index_kernel(
+    req_id_ptr,  # int32 [num_tokens]
+    block_table_ptr,  # int32 [num_requests, max_num_blocks_per_req]
+    token_indices_ptr,  # int32 [num_tokens, NUM_TOPK_TOKENS]
+    out_ptr,  # int32 [num_tokens, NUM_TOPK_TOKENS]
+    # shapes (compile-time where possible)
+    max_num_blocks_per_req: tl.constexpr,
+    BLOCK_SIZE: tl.constexpr,
+    BLOCK_N: tl.constexpr,  # tile width along columns
+    # strides (in elements)
+    bt_stride0,
+    bt_stride1,
+    ti_stride0,
+    ti_stride1,
+    out_stride0,
+    out_stride1,
+):
+    """Triton kernel: map per-request token indices to global cache slots.
+
+    One program handles a BLOCK_N-wide tile of one token's top-k indices;
+    see `triton_convert_req_index_to_global_index` for the full formula.
+    """
+    # program_id(0) -> token_id (row)
+    # program_id(1) -> tile index along columns
+    token_id = tl.program_id(0)
+    tile_id = tl.program_id(1)
+
+    # Each program covers BLOCK_N consecutive columns
+    indice_id = tile_id * BLOCK_N + tl.arange(0, BLOCK_N)
+
+    # Load request id for this token (no mask: grid is exact)
+    req = tl.load(req_id_ptr + token_id)
+
+    # Load token indices for this tile
+    ti_ptr = token_indices_ptr + token_id * ti_stride0 + indice_id * ti_stride1
+    tok = tl.load(ti_ptr)  # int32
+
+    # Only token == -1 should propagate as -1
+    is_invalid_tok = tok < 0
+
+    # Compute block id and in-block offset
+    block_id = tok // BLOCK_SIZE
+    inblock_off = tok % BLOCK_SIZE
+
+    # Guard block_table access
+    valid_block = block_id < max_num_blocks_per_req
+    bt_ptr = block_table_ptr + req * bt_stride0 + block_id * bt_stride1
+    base = tl.load(bt_ptr, mask=valid_block, other=0)
+
+    # If token == -1 OR block_id OOB, output -1; else base * BLOCK_SIZE + offset
+    out_val = tl.where(is_invalid_tok | (~valid_block), -1,
+                       base * BLOCK_SIZE + inblock_off)
+
+    # Store results
+    out_ptr_ij = out_ptr + token_id * out_stride0 + indice_id * out_stride1
+    tl.store(out_ptr_ij, out_val)
+
+
+def triton_convert_req_index_to_global_index(
+    req_id: torch.Tensor,  # int32 [num_tokens]
+    block_table: torch.
+    Tensor,  # int32 [num_requests, max_num_blocks_per_req]
+    token_indices: torch.Tensor,  # int32 [num_tokens, NUM_TOPK_TOKENS]
+    BLOCK_SIZE: int = 64,
+    NUM_TOPK_TOKENS: int = 2048,
+    BLOCK_N: int = 128,  # tile width along columns
+):
+    """
+    out[token_id, indice_id] =
+        block_table[req_id[token_id],
+                    token_indices[token_id, indice_id] // BLOCK_SIZE] * BLOCK_SIZE
+        + token_indices[token_id, indice_id] % BLOCK_SIZE
+
+    Only when token_indices[token_id, indice_id] == -1 do we output -1.
+    For safety, we also output -1 if the derived block_id would be
+    out-of-bounds.
+    """
+    assert req_id.dtype == torch.int32
+    assert block_table.dtype == torch.int32
+    assert token_indices.dtype == torch.int32
+    assert token_indices.shape[1] == NUM_TOPK_TOKENS
+    # Columns are tiled BLOCK_N at a time, so the top-k width must divide.
+    assert NUM_TOPK_TOKENS % BLOCK_N == 0, \
+        f"NUM_TOPK_TOKENS ({NUM_TOPK_TOKENS}) must be divisible by" \
+        f"BLOCK_N ({BLOCK_N})"
+
+    num_tokens = req_id.shape[0]
+    num_requests, max_num_blocks_per_req = block_table.shape
+    tiles_per_row = NUM_TOPK_TOKENS // BLOCK_N
+
+    # Ensure contiguous tensors on the same device
+    req_id_c = req_id.contiguous()
+    block_table_c = block_table.contiguous()
+    token_indices_c = token_indices.contiguous()
+    out = torch.empty_like(token_indices_c)
+
+    # Strides in elements
+    bt_stride0, bt_stride1 = block_table_c.stride()
+    ti_stride0, ti_stride1 = token_indices_c.stride()
+    out_stride0, out_stride1 = out.stride()
+
+    # Exact 2D grid: tokens × column tiles
+    grid = (num_tokens, tiles_per_row)
+
+    _convert_req_index_to_global_index_kernel[grid](
+        req_id_c,
+        block_table_c,
+        token_indices_c,
+        out,
+        # shapes / constexprs
+        max_num_blocks_per_req,
+        BLOCK_SIZE,
+        BLOCK_N,
+        # strides
+        bt_stride0,
+        bt_stride1,
+        ti_stride0,
+        ti_stride1,
+        out_stride0,
+        out_stride1,
+    )
+    return out
+
+def kunlun_convert_req_index_to_global_index(
+ req_id: torch.Tensor, # int32 [num_tokens]
+ block_table: torch.Tensor, # int32 [num_requests, max_num_blocks_per_req]
+ token_indices: torch.Tensor, # int32 [num_tokens, NUM_TOPK_TOKENS]
+ BLOCK_SIZE: int = 64,
+ NUM_TOPK_TOKENS: int = 2048,
+):
+ assert req_id.dtype == torch.int32
+ assert block_table.dtype == torch.int32
+ assert token_indices.dtype == torch.int32
+ assert token_indices.shape[1] == NUM_TOPK_TOKENS
+
+ num_tokens = req_id.shape[0]
+ num_requests, max_num_blocks_per_req = block_table.shape
+
+ out = torch.zeros_like(token_indices)
+
+ # Compute block_id and inblock_off for all tokens at once
+ block_id = token_indices // BLOCK_SIZE
+ inblock_off = token_indices % BLOCK_SIZE
+
+ # Create mask for invalid tokens (tok < 0)
+ invalid_tok_mask = token_indices < 0
+
+ # Create mask for out-of-bounds block_id
+ oob_block_mask = block_id >= max_num_blocks_per_req
+
+ # Combine masks - output -1 for either condition
+ invalid_mask = invalid_tok_mask | oob_block_mask
+
+ # Get request IDs expanded to match token_indices shape
+ req_ids_expanded = req_id.unsqueeze(1).expand(-1, NUM_TOPK_TOKENS)
+
+ # Gather base addresses from block_table
+ # Clamp block_id to avoid index errors (we'll mask these out anyway)
+ block_id_clamped = torch.clamp(block_id, 0, max_num_blocks_per_req - 1)
+
+ # Use advanced indexing to get base addresses
+ base_addrs = block_table[req_ids_expanded, block_id_clamped]
+
+ # Compute the global indices
+ global_indices = base_addrs * BLOCK_SIZE + inblock_off
+
+ # Apply mask: set invalid positions to -1
+ out = torch.where(invalid_mask, torch.tensor(-1, dtype=torch.int32, device=token_indices.device), global_indices)
+
+ return out
+
+def kunlun_concat_and_cache_mla(
+    kv_c: torch.Tensor,  #[num_tokens, kv_lora_rank]
+    k_pe: torch.Tensor,  #[num_tokens, pe_dim]
+    kv_cache: torch.Tensor,  #[num_blocks, block_size, (kv_lora_rank + pe_dim)]
+    slot_mapping: torch.Tensor,  #[num_tokens] or [num_actual_tokens]
+    kv_cache_dtype: str,
+    scale: torch.Tensor
+):
+    """Write latent (kv_c) and RoPE (k_pe) vectors into the paged cache.
+
+    For ``fp8_ds_mla`` each 656-byte token entry is: 512 bytes of quantized
+    NoPE values, 16 bytes of per-128-element float32 scales, then 128 bytes
+    of unquantized bf16 RoPE values (see module docstring).  Otherwise the
+    latent and RoPE parts are stored unquantized back-to-back.
+
+    NOTE(review): this loops over tokens in Python with per-token .item()
+    sync — slow; presumably a reference/fallback path.  ``scale`` is
+    accepted but unused here.
+    """
+    num_tokens = slot_mapping.shape[0]
+    kv_lora_rank = kv_c.shape[1]
+    pe_dim = k_pe.shape[1]
+    block_size = kv_cache.shape[1]
+
+    def kunlun_fp8_ds_mla():
+        # Quantize each token's latent in 4 groups of kv_lora_rank//4 and
+        # store int8 data + scales + raw RoPE bytes into the cache entry.
+        for token_idx in range(num_tokens):
+            slot = slot_mapping[token_idx].item()
+            # Negative slot means the token has no cache slot (padding).
+            if slot < 0: continue
+            block_idx = slot // block_size
+            block_offset = slot % block_size
+            kv_c_i = kv_c[token_idx].view(4,kv_lora_rank//4).contiguous()
+            kv_c_i_int8 = torch.zeros(
+                kv_c_i.shape,
+                device=kv_c.device,
+                dtype=torch.int8,
+            )
+            kv_c_i_scale = torch.zeros(
+                [kv_c_i.shape[0], 1],
+                device=kv_c.device,
+                dtype=torch.float32,
+            )
+            # Kunlun custom op: rowwise int8 quantization.
+            torch.ops._C.quant2d(kv_c_i, kv_c_i_int8, kv_c_i_scale, force_sdnn=True)
+            # Convert the int8 range scale to an fp8-style dequant scale.
+            kv_c_i_scale /= 127
+            kv_cache[block_idx, block_offset, :kv_lora_rank] = kv_c_i_int8.view(-1).view(torch.uint8).contiguous()
+            kv_cache[block_idx, block_offset, kv_lora_rank:kv_lora_rank + 16] = kv_c_i_scale.view(-1).view(torch.uint8).contiguous()
+            kv_cache[block_idx, block_offset, kv_lora_rank+16:] = k_pe[token_idx, :].view(torch.uint8).contiguous()
+
+    def kunlun_mla():
+        # Unquantized path: store latent then RoPE contiguously.
+        for token_idx in range(num_tokens):
+            slot = slot_mapping[token_idx].item()
+            if slot < 0: continue
+            block_idx = slot // block_size
+            block_offset = slot % block_size
+            kv_cache[block_idx, block_offset, :kv_lora_rank] = kv_c[token_idx, :].contiguous()
+            kv_cache[block_idx, block_offset, kv_lora_rank:] = k_pe[token_idx, :].contiguous()
+
+    if (kv_cache_dtype == "fp8_ds_mla"):
+        assert kv_lora_rank == 512, "kv_lora_rank must be 512 for fp8_ds_mla"
+        assert pe_dim == 64, "pe_dim must be 64 for fp8_ds_mla"
+        assert kv_cache.shape[2] == 656 // kv_cache.element_size(), "kv_cache.shape[2] must be 656 bytes for fp8_ds_mla"
+        assert kv_c.element_size() == 2, "kv_c.element_size() must be 2 for fp8_ds_mla"
+        assert k_pe.element_size() == 2, "k_pe.element_size() must be 2 for fp8_ds_mla"
+        kunlun_fp8_ds_mla()
+    else:
+        assert kv_cache.shape[2] == kv_lora_rank + pe_dim
+        kunlun_mla()
+
+
+@dataclass
+class FlashMLASparseMetadataBuilder(
+ AttentionMetadataBuilder[FlashMLASparseMetadata]):
+ cudagraph_support: ClassVar[AttentionCGSupport] = \
+ AttentionCGSupport.UNIFORM_BATCH
+
+    def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
+                 vllm_config: VllmConfig, device: torch.device):
+        """Cache config handles and pre-allocate device buffers reused by
+        every `build()` call (scheduler metadata, split counts, per-token
+        request ids)."""
+        self.vllm_config = vllm_config
+        self.layer_names = layer_names
+        cache_config = vllm_config.cache_config
+        self.kv_cache_spec = kv_cache_spec
+        self.model_config = vllm_config.model_config
+        parallel_config = vllm_config.parallel_config
+        self.device = device
+
+        # Treat requests with query length <= 1 as decodes to match the
+        # DeepGEMM indexer constraint (fp8_paged_mqa_logits only supports next_n <= 2)
+        # (Introduced from the latest vLLM version.)
+        self._init_reorder_batch_threshold(1, supports_spec_as_decode=True)
+
+        props = torch.cuda.get_device_properties(device)
+        sm_count = props.multi_processor_count
+
+        self.num_heads = self.model_config.get_num_attention_heads(
+            parallel_config)
+        self.mla_dims = get_mla_dims(self.model_config)
+        # Top-k context tokens per query, taken from the model config.
+        self.topk_tokens = vllm_config.model_config.hf_config.index_topk
+        self.use_fp8_kv_cache = cache_config.cache_dtype == "fp8_ds_mla"
+
+        self.topk_tokens_tensor = torch.tensor([self.topk_tokens],
+                                               device=device,
+                                               dtype=torch.int32)
+        # self.max_model_len_tensor = torch.tensor(
+        #     [self.model_config.max_model_len],
+        #     device=device,
+        #     dtype=torch.int32)
+
+        # this is ignored by `flash_mla_with_kvcache` if indices not None
+        self.dummy_block_table = torch.empty((1, 1),
+                                             dtype=torch.int32,
+                                             device=self.device)
+
+        # Equation taken from FlashMLA/csrc/pybind.cpp
+        h_q, h_k = self.num_heads, 1
+        s_q = 1  # inversely proportional to s_q, so s_q = 1 is the largest
+        max_num_sm_parts = int(
+            max((sm_count // 2) / h_k // (cdiv(h_q // h_k, 2 * 64) * s_q), 1))
+        if current_platform.is_device_capability(100):
+            max_num_sm_parts *= 2
+        self.tile_scheduler_metadata_buffer = torch.zeros(
+            # TileSchedulerMetaDataSize = 8
+            # see: FlashMLA/csrc/params.h
+            (max_num_sm_parts, 8),
+            dtype=torch.int32,
+            device=device)
+        self.num_splits_buffer = torch.zeros(
+            # We pack all the tokens into one batch for sparse attention.
+            # Otherwise, we can exceed the sm of `get_mla_metadata`.
+            (
+                2, ),
+            dtype=torch.int32,
+            device=device)
+        self.req_id_per_token_buffer = torch.zeros(
+            (vllm_config.scheduler_config.max_num_batched_tokens, ),
+            dtype=torch.int32,
+            device=device)
+ def build(self,
+ common_prefix_len: int,
+ common_attn_metadata: CommonAttentionMetadata,
+ fast_build: bool = False) -> FlashMLASparseMetadata:
+
+ num_tokens = common_attn_metadata.num_actual_tokens
+ starts = np.asarray(common_attn_metadata.query_start_loc_cpu,
+ dtype=np.int32)
+ seg_lengths = np.diff(starts)
+ req_id_per_token = np.repeat(
+ np.arange(seg_lengths.shape[0], dtype=np.int32), seg_lengths)
+ # Zero-fill for cudagraphs
+ self.req_id_per_token_buffer.fill_(0)
+ self.req_id_per_token_buffer[:req_id_per_token.shape[0]]\
+ .copy_(torch.from_numpy(req_id_per_token), non_blocking=True)
+ req_id_per_token = self.req_id_per_token_buffer[:num_tokens]
+
+ fp8_extra_metadata = None
+
+ if self.use_fp8_kv_cache:
+ cache_seqlens_cpu, cache_seqlens = get_mla_metadata(
+ cache_seqlens=self.topk_tokens_tensor,
+ )
+ fp8_extra_metadata = FlashMLASparseMetadata.FP8KernelMetadata(
+ scheduler_metadata=None,
+ num_splits=None,
+ # cache_lens and block_table are basically unused in sparse case
+ # but the decode kernel will treat -1 and indices >= cache_lens
+ # as invalid so we make sure cache_lens is large enough to not
+ # accidentally mark indices invalid, we will use -1 exclusively
+ # to mark invalid indices
+ cache_lens=cache_seqlens_cpu,
+ dummy_block_table=self.dummy_block_table)
+
+ (num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens) = (
+ split_decodes_and_prefills(
+ common_attn_metadata,
+ decode_threshold=self.reorder_batch_threshold or 1,
+ require_uniform=True,
+ )
+ )
+
+ # For pure decode batches, prefill_request_id will be None
+ # For mixed batches, it will have -1 for decode and request_id for prefill
+ prefill_metadata = None
+ if num_prefills > 0:
+ prefill_metadata = MLASparsePrefillMetadata(
+ query_start_loc = common_attn_metadata.query_start_loc[num_decodes:] - common_attn_metadata.query_start_loc[num_decodes], #因为prefiil、decode请求是分离,所以需要对q进行切分,故需调整该值
+ query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu[num_decodes:] - common_attn_metadata.query_start_loc_cpu[num_decodes],
+ )
+
+ decode_metadata = None
+ if num_decodes > 0:
+ max_seq_len = int(common_attn_metadata.seq_lens_cpu[:num_decodes].max())
+
+ decode_metadata = FlashMLASparseDecodeAndContextMetadata(
+ max_seq_len=max_seq_len,
+ seq_lens=common_attn_metadata.seq_lens[:num_decodes],
+ seq_lens_cpu=common_attn_metadata.seq_lens_cpu[:num_decodes],
+ )
+
+
+ metadata = FlashMLASparseMetadata(
+ num_reqs=common_attn_metadata.num_reqs,
+ max_query_len=common_attn_metadata.max_query_len,
+ max_seq_len=common_attn_metadata.max_seq_len,
+ num_actual_tokens=common_attn_metadata.num_actual_tokens,
+ query_start_loc=common_attn_metadata.query_start_loc,
+ slot_mapping=common_attn_metadata.slot_mapping,
+ block_table=common_attn_metadata.block_table_tensor,
+ req_id_per_token=req_id_per_token,
+ block_size=self.kv_cache_spec.block_size,
+ topk_tokens=self.topk_tokens,
+ fp8_extra_metadata=fp8_extra_metadata,
+ num_prefills=num_prefills,
+ num_decodes=num_decodes,
+ num_prefill_tokens=num_prefill_tokens,
+ num_decode_tokens=num_decode_tokens,
+ decode_metadata=decode_metadata,
+ prefill_metadata=prefill_metadata
+ )
+ return metadata
+
+
+class FlashMLASparseImpl(MLACommonBaseImpl[FlashMLASparseMetadata]):
+
+    def __init__(
+            self,
+            num_heads: int,
+            head_size: int,
+            scale: float,
+            num_kv_heads: int,
+            alibi_slopes: Optional[list[float]],
+            sliding_window: Optional[int],
+            kv_cache_dtype: str,
+            logits_soft_cap: Optional[float],
+            attn_type: str,
+            kv_sharing_target_layer_name: Optional[str],
+            # MLA Specific Arguments
+            topk_indice_buffer: Optional[torch.Tensor] = None,
+            indexer: Optional["Indexer"] = None,
+            **mla_args) -> None:
+        super().__init__(num_heads, head_size, scale, num_kv_heads,
+                         alibi_slopes, sliding_window, kv_cache_dtype,
+                         logits_soft_cap, attn_type,
+                         kv_sharing_target_layer_name, **mla_args)
+        self.softmax_scale = scale
+        assert indexer is not None
+        self.topk_indices_buffer = indexer.topk_indices_buffer
+        self.padding = 128 if current_platform.is_device_capability(
+            100) else 64  # NOTE(review): presumably an alignment requirement of the kernels — confirm for kunlun hardware
+
+    def _forward_bf16_kv(
+            self, q: torch.Tensor, kv_c_and_k_pe_cache: torch.Tensor,
+            topk_indices: torch.Tensor,
+            attn_metadata: FlashMLASparseMetadata) -> torch.Tensor:
+
+        num_tokens = q.shape[0]
+        kv_c_and_k_pe_cache = kv_c_and_k_pe_cache.contiguous().view(
+            -1, kv_c_and_k_pe_cache.shape[-1])  # flatten (blocks, block_size) into one token axis
+
+        # num_decode_tokens = attn_metadata.num_decode_tokens
+        num_prefill_tokens = attn_metadata.num_prefill_tokens
+        num_decodes = attn_metadata.num_decodes
+
+        has_decode = attn_metadata.num_decodes > 0
+        has_prefill = attn_metadata.num_prefills > 0
+        num_decode_tokens = attn_metadata.num_decode_tokens
+
+        def _bf16_decode(q: torch.Tensor, topk_indices: torch.Tensor) -> torch.Tensor:
+            # Reshape q: (num_decode_tokens, num_heads, head_dim)
+            # -> (num_decodes, seq_len, num_heads, head_dim)
+            q = reshape_query_for_spec_decode(q, num_decodes)
+            seq_len = q.shape[1]
+            # Reshape topk_indices: (num_decode_tokens, topk)
+            # -> (num_decodes, seq_len, topk)
+            topk_indices = topk_indices.view(num_decodes, seq_len, -1)
+            decode_metadata = attn_metadata.decode_metadata
+            _attn_out, _, _ = kunlun_flash_mla_with_kvcache(
+                q=q,
+                k_cache=kv_c_and_k_pe_cache,
+                head_dim_v=512,
+                cache_seqlens=decode_metadata.seq_lens,
+                cache_seqlens_cpu=decode_metadata.seq_lens_cpu,
+                is_fp8_kvcache=False,
+                indices=topk_indices,
+                softmax_scale=self.softmax_scale,
+                max_seq_kv=decode_metadata.max_seq_len
+            )
+            # Reshape output: (num_decodes, seq_len, num_heads, head_dim_v)
+            # -> (num_decode_tokens, num_heads, head_dim_v)
+            return reshape_attn_output_for_spec_decode(_attn_out)
+
+        def _bf16_prefill(q: torch.Tensor, topk_indices: torch.Tensor) -> torch.Tensor:
+            prefill_metadata = attn_metadata.prefill_metadata
+            topk_indices = topk_indices.view(num_prefill_tokens, 1, -1)
+            # NOTE: only the prefill-phase attn_metadata.query_start_loc matches the layout the KLX kernel expects.
+            _attn_out = flash_mla_sparse_prefill(
+                q=q,
+                kv=kv_c_and_k_pe_cache,
+                indices=topk_indices,
+                sm_scale=self.softmax_scale,
+                q_lod_xpu=prefill_metadata.query_start_loc,
+                q_lod_cpu=prefill_metadata.query_start_loc_cpu
+            )[0]
+            return _attn_out
+
+        topk_indices_global = torch.ops.xspeedgate_ops.convert_req_index_to_global_index(
+            req_id=attn_metadata.req_id_per_token,
+            block_table=attn_metadata.block_table,
+            token_indices=topk_indices,
+            block_size=attn_metadata.block_size,
+            num_topk_tokens=attn_metadata.topk_tokens,
+        )
+
+        attn_out = torch.empty(
+            (num_tokens, self.num_heads, self.kv_lora_rank),
+            dtype=q.dtype,
+            device=q.device,
+        )  # NOTE(review): left uninitialized for rows covered by neither branch — assumed every token is prefill or decode
+        if has_prefill:
+            prefill_q = q[num_decode_tokens:]
+            prefill_topk_indices_global = topk_indices_global[num_decode_tokens:]
+            attn_out[num_decode_tokens:] = _bf16_prefill(prefill_q, prefill_topk_indices_global)
+
+        # Handle the decode portion; requires the correct block-table mapping.
+        if has_decode:
+            decode_q = q[:num_decode_tokens]
+            decode_topk_indices_global = topk_indices_global[:num_decode_tokens]
+            attn_out[:num_decode_tokens] = _bf16_decode(decode_q, decode_topk_indices_global)
+
+        return attn_out
+
+
+    def _forward_fp8_kv(self, q: torch.Tensor,
+                        kv_c_and_k_pe_cache: torch.Tensor,
+                        topk_indices: torch.Tensor,
+                        attn_metadata: FlashMLASparseMetadata) -> torch.Tensor:
+        # TODO: When fwd_kvcache_mla supports uint8 kv cache, execute this function.
+        assert attn_metadata.fp8_extra_metadata is not None
+        extra_metadata = attn_metadata.fp8_extra_metadata
+
+        _attn_out, _ = flash_mla_with_kvcache(
+            q=q.unsqueeze(0),  # unsqueeze to add batch_dim
+            k_cache=kv_c_and_k_pe_cache,
+            block_table=extra_metadata.dummy_block_table,
+            head_dim_v=512,
+            cache_seqlens=extra_metadata.cache_lens,
+            tile_scheduler_metadata=extra_metadata.scheduler_metadata,  # None
+            num_splits=extra_metadata.num_splits,  # None
+            is_fp8_kvcache=True,
+            indices=topk_indices.unsqueeze(0),  # unsqueeze to add batch_dim
+            softmax_scale=self.softmax_scale,
+            max_seq_kv=attn_metadata.max_seq_len
+        )
+
+        return _attn_out
+
+    def forward(
+        self,
+        layer: AttentionLayer,
+        q: torch.Tensor,
+        k_c_normed: torch.Tensor,  # key in unified attn
+        k_pe: torch.Tensor,  # value in unified attn
+        kv_cache: torch.Tensor,
+        attn_metadata: FlashMLASparseMetadata,
+        output: Optional[torch.Tensor] = None,
+        output_scale: Optional[torch.Tensor] = None,
+        output_block_scale: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        # NOTE(lucas): for the sparse FlashMLA kernels the kernels want to use
+        # MQA 576/512 approach for both prefill and decode
+
+        assert output is not None, "Output tensor must be provided."
+
+        if output_scale is not None or output_block_scale is not None:
+            raise NotImplementedError(
+                "fused output quantization is not yet supported"
+                " for MLACommonImpl")
+
+        if attn_metadata is None:
+            # The zero fill is required when used with DP + EP
+            # to ensure all ranks within a DP group compute the
+            # same expert outputs.
+            return output.fill_(0)
+
+        num_actual_toks = attn_metadata.num_actual_tokens
+
+        # Inputs and outputs may be padded for CUDA graphs
+
+        q = q[:num_actual_toks, ...]
+        k_c_normed = k_c_normed[:num_actual_toks, ...]
+        k_pe = k_pe[:num_actual_toks, ...]
+
+        q_nope, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim],
+                               dim=-1)
+        # Convert from (B, N, P) to (N, B, P)
+        q_nope = q_nope.transpose(0, 1)
+        # Multiply (N, B, P) x (N, P, L) -> (N, B, L)
+        ql_nope = torch.bmm(q_nope, self.W_UK_T)
+        # Convert from (N, B, L) to (B, N, L)
+        ql_nope = ql_nope.transpose(0, 1)
+
+        topk_indices = self.topk_indices_buffer[:num_actual_toks]
+
+        q = torch.cat([ql_nope, q_pe], dim=-1)  # MQA query: latent + rope halves
+
+        # write the latent and rope to kv cache
+        if kv_cache.numel() > 0:
+            torch.ops._C.concat_and_cache_mla(
+                kv_c=k_c_normed,
+                k_pe=k_pe.squeeze(1),
+                kv_cache=kv_cache,
+                slot_mapping=attn_metadata.slot_mapping.flatten(),
+            )
+
+        if self.kv_cache_dtype != "fp8_ds_mla":
+            attn_out = self._forward_bf16_kv(q, kv_cache, topk_indices,
+                                             attn_metadata)
+        else:
+            # attn_out = self._forward_fp8_kv(q, kv_cache, topk_indices_global,
+            #                                 attn_metadata)
+            raise NotImplementedError
+
+        self._v_up_proj(attn_out, out=output[:num_actual_toks])
+        return output
diff --git a/vllm_kunlun/v1/attention/backends/mla/indexer.py b/vllm_kunlun/v1/attention/backends/mla/indexer.py
new file mode 100644
index 0000000..67471be
--- /dev/null
+++ b/vllm_kunlun/v1/attention/backends/mla/indexer.py
@@ -0,0 +1,133 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+import torch
+from dataclasses import dataclass
+from typing import ClassVar, Optional
+from vllm.logger import init_logger
+from vllm.v1.attention.backends.utils import ( CommonAttentionMetadata,
+ split_decodes_and_prefills)
+from vllm.v1.attention.backends.mla.indexer import (split_prefill_chunks,
+ DeepseekV32IndexerMetadataBuilder,
+ DeepseekV32IndexerPrefillMetadata)
+
+logger = init_logger(__name__)
+
+@dataclass
+class DeepSeekV32IndexerDecodeMetadata:
+    block_table: torch.Tensor  # per-decode-request block table slice
+    seq_lens: torch.Tensor  # device-side sequence lengths for decode requests
+    seq_lens_cpu: torch.Tensor  # host copy of seq_lens
+    decode_lens: torch.Tensor  # tokens scheduled per decode request (diff of query_start_loc)
+    requires_padding: bool  # True when decode_lens are non-uniform across requests
+    schedule_metadata: torch.Tensor  # opaque scheduler buffer passed through from the builder
+
+
+@dataclass
+class DeepseekV32IndexerMetadata:
+
+    # FIXME (zyongye)
+    # hacky way to access the data now, need to be in chunked meta
+    seq_lens: torch.Tensor
+    seq_lens_cpu: torch.Tensor
+
+    num_reqs: int
+    max_query_len: int
+    max_seq_len: int
+
+    num_actual_tokens: int  # Number of tokens excluding padding.
+    query_start_loc: torch.Tensor  # cumulative per-request token offsets
+    slot_mapping: torch.Tensor
+    # The dimension of the attention heads
+    head_dim: int
+
+    # New for MLA (compared to FlashAttention)
+    # For handling prefill decode split
+    num_decodes: int
+    num_decode_tokens: int
+    num_prefills: int
+    num_prefill_tokens: int
+
+    decode: Optional[DeepSeekV32IndexerDecodeMetadata] = None
+    prefill: Optional[DeepseekV32IndexerPrefillMetadata] = None
+
+def kunlun_build(self,
+ common_prefix_len: int,
+ common_attn_metadata: CommonAttentionMetadata,
+ fast_build: bool = False) -> DeepseekV32IndexerMetadata:
+
+ num_reqs = common_attn_metadata.num_reqs
+ num_tokens = common_attn_metadata.num_actual_tokens
+
+ query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu
+ num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = \
+ split_decodes_and_prefills(
+ common_attn_metadata,
+ decode_threshold=self.reorder_batch_threshold)
+
+ assert num_decodes + num_prefills == num_reqs
+ assert num_decode_tokens + num_prefill_tokens == num_tokens
+
+ prefill_metadata = None
+ if num_prefills > 0:
+ chunk_seq_ids = split_prefill_chunks(
+ common_attn_metadata.seq_lens_cpu,
+ self.max_prefill_buffer_size,
+ num_decodes,
+ )
+ chunks = [
+ self.build_one_prefill_chunk(
+ reqs_start, reqs_end, query_start_loc_cpu,
+ common_attn_metadata.seq_lens_cpu,
+ common_attn_metadata.block_table_tensor)
+ for reqs_start, reqs_end in chunk_seq_ids
+ ]
+ prefill_metadata = DeepseekV32IndexerPrefillMetadata(
+ chunks=chunks, )
+
+ decode_metadata = None
+ if num_decodes > 0:
+ torch.diff(common_attn_metadata.query_start_loc[:num_decodes + 1],
+ out=self.decode_lens_buffer[:num_decodes])
+ decode_lens = self.decode_lens_buffer[:num_decodes]
+ decode_lens_cpu = torch.diff(
+ common_attn_metadata.query_start_loc_cpu[:num_decodes + 1])
+
+ # Use CPU to avoid GPU sync; breaking async scheduling
+ requires_padding = (decode_lens_cpu.max()
+ > decode_lens_cpu.min()).item()
+
+ seq_lens = common_attn_metadata.seq_lens[:num_decodes]
+
+ decode_metadata = DeepSeekV32IndexerDecodeMetadata(
+ block_table=common_attn_metadata.
+ block_table_tensor[:num_decodes, ...],
+ seq_lens=common_attn_metadata.seq_lens[:num_decodes],
+ seq_lens_cpu=common_attn_metadata.seq_lens[:num_decodes].cpu(),
+ decode_lens=decode_lens,
+ requires_padding=requires_padding,
+ schedule_metadata=self.scheduler_metadata_buffer,
+ )
+
+ attn_metadata = DeepseekV32IndexerMetadata(
+ seq_lens=common_attn_metadata.seq_lens,
+ seq_lens_cpu=common_attn_metadata.seq_lens.cpu(),
+ num_reqs=common_attn_metadata.num_reqs,
+ max_query_len=common_attn_metadata.max_query_len,
+ max_seq_len=common_attn_metadata.max_seq_len,
+ num_actual_tokens=common_attn_metadata.num_actual_tokens,
+ query_start_loc=common_attn_metadata.query_start_loc,
+ slot_mapping=common_attn_metadata.slot_mapping,
+ head_dim=128,
+ num_decodes=num_decodes,
+ num_decode_tokens=num_decode_tokens,
+ num_prefills=num_prefills,
+ num_prefill_tokens=num_prefill_tokens,
+ prefill=prefill_metadata,
+ decode=decode_metadata,
+ )
+
+ # if get_tensor_model_parallel_rank() == 0:
+ # logger.info(f"attn_metadata: {attn_metadata}")
+ return attn_metadata
+
+DeepseekV32IndexerMetadataBuilder.build = kunlun_build
\ No newline at end of file
diff --git a/vllm_kunlun/v1/sample/ops/topk_topp_sampler.py b/vllm_kunlun/v1/sample/ops/topk_topp_sampler.py
index 08a33f9..b080524 100644
--- a/vllm_kunlun/v1/sample/ops/topk_topp_sampler.py
+++ b/vllm_kunlun/v1/sample/ops/topk_topp_sampler.py
@@ -2,7 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Optional
-import os
+
import torch
import torch.nn as nn
from packaging import version
@@ -24,6 +24,7 @@ class TopKTopPSampler(nn.Module):
def __init__(self, logprobs_mode):
super().__init__()
+ self.logprobs_mode = logprobs_mode
logger.info_once(
"Using FlashInfer for top-p & top-k sampling.")
self.forward = self.forward_kunlun
@@ -40,9 +41,14 @@ class TopKTopPSampler(nn.Module):
The logits tensor may be updated in-place.
"""
- logits = apply_top_k_top_p(logits, k, p)
+ logits = self.apply_top_k_top_p(logits, k, p)
+ logits_to_return = None
+ if self.logprobs_mode == "processed_logits":
+ logits_to_return = logits
+ elif self.logprobs_mode == "processed_logprobs":
+ logits_to_return = logits.log_softmax(dim=-1, dtype=torch.float32)
probs = logits.softmax(dim=-1, dtype=torch.float32)
- return random_sample(probs, generators), None
+ return random_sample(probs, generators), logits_to_return
def forward_kunlun(
self,
@@ -52,16 +58,13 @@ class TopKTopPSampler(nn.Module):
p: Optional[torch.Tensor],
) -> torch.Tensor:
"""More optimized implementation for top-k and top-p sampling."""
- if k is None and p is None:
- # We prefer `random_sample` over `flashinfer_sample` when sorting is
- # not needed. This is because `random_sample` does not require
- # CPU-GPU synchronization while `flashinfer_sample` does.
- probs = logits.softmax(dim=-1, dtype=torch.float32)
- return random_sample(probs, generators), None
- if generators:
- logger.warning_once("FlashInfer 0.2.3+ does not support "
- "per-request generators. Falling back to "
- "PyTorch-native implementation.")
+ if (k is None and p is None) or generators:
+ if generators:
+ logger.debug_once(
+ "FlashInfer 0.2.3+ does not support "
+ "per-request generators. Falling back to "
+ "PyTorch-native implementation."
+ )
return self.forward_native(logits, generators, k, p)
# flashinfer sampling functions expect contiguous logits.
# In flex_attn/triton_attn fp32 inference, logits can be non-contiguous
@@ -196,6 +199,7 @@ def flashinfer_sample(
probs, top_k=k, deterministic=True)
else:
# Both top-k and top-p.
+ k = k.to(torch.int32)
next_token_ids = xtorch_ops.top_k_top_p_sampling_from_probs(
probs, top_k=k, top_p=p, deterministic=True)
diff --git a/vllm_kunlun/v1/worker/utils.py b/vllm_kunlun/v1/worker/utils.py
new file mode 100644
index 0000000..e31f887
--- /dev/null
+++ b/vllm_kunlun/v1/worker/utils.py
@@ -0,0 +1,344 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+from collections import defaultdict
+from dataclasses import dataclass
+from typing import TYPE_CHECKING, Optional
+
+import torch
+
+from vllm.attention.backends.abstract import AttentionBackend
+from vllm.config import ModelConfig, SchedulerConfig, VllmConfig
+from vllm.model_executor.models.interfaces import MultiModalEmbeddings
+from vllm.model_executor.models.utils import extract_layer_index
+from vllm.multimodal.cache import processor_only_cache_from_config
+from vllm.multimodal.registry import MultiModalRegistry
+from vllm.platforms import current_platform
+from vllm.v1.attention.backends.utils import AttentionMetadataBuilder
+from vllm.v1.core.encoder_cache_manager import compute_mm_encoder_budget
+from vllm.v1.kv_cache_interface import KVCacheGroupSpec, KVCacheSpec
+
+if TYPE_CHECKING:
+ from vllm.attention.layer import Attention
+
+
+class MultiModalBudget:
+    """Helper class to calculate budget information for multi-modal models."""
+
+    def __init__(
+        self,
+        model_config: ModelConfig,
+        scheduler_config: SchedulerConfig,
+        mm_registry: MultiModalRegistry,
+    ) -> None:
+        super().__init__()
+
+        self.model_config = model_config
+        self.scheduler_config = scheduler_config
+        self.mm_registry = mm_registry
+        self.cache = cache = processor_only_cache_from_config(
+            model_config, mm_registry)
+
+        self.max_model_len = model_config.max_model_len
+        self.max_num_reqs = scheduler_config.max_num_seqs
+
+        self.mm_limits = mm_registry.get_mm_limits_per_prompt(model_config,
+                                                              cache=cache)
+
+        max_tokens_by_modality = mm_registry \
+            .get_max_tokens_per_item_by_nonzero_modality(model_config,
+                                                         cache=cache)
+
+        encoder_compute_budget, encoder_cache_size = compute_mm_encoder_budget(
+            scheduler_config,
+            max_tokens_by_modality,
+        )
+
+        self.encoder_compute_budget = encoder_compute_budget
+        self.encoder_cache_size = encoder_cache_size
+
+        max_items_per_prompt_by_modality = dict[str, int]()
+        max_items_per_batch_by_modality = dict[str, int]()
+
+        for modality, max_tokens in max_tokens_by_modality.items():
+            (
+                max_items_per_prompt,
+                max_items_per_batch,
+            ) = self.get_max_items(modality, max_tokens)
+
+            max_items_per_prompt_by_modality[modality] = max_items_per_prompt
+            max_items_per_batch_by_modality[modality] = max_items_per_batch
+
+        self.max_tokens_by_modality = max_tokens_by_modality
+        self.max_items_per_prompt_by_modality = max_items_per_prompt_by_modality
+        self.max_items_per_batch_by_modality = max_items_per_batch_by_modality
+
+    def get_modality_with_max_tokens(self) -> str:
+        # Pick the modality whose single item costs the most tokens.
+        max_tokens_by_modality = self.max_tokens_by_modality
+        modality, _ = max(max_tokens_by_modality.items(), key=lambda x: x[1])
+
+        return modality
+
+    def get_encoder_budget(self) -> int:
+        return min(self.encoder_compute_budget, self.encoder_cache_size)  # tightest of compute vs cache limits
+
+    def get_max_items(
+        self,
+        modality: str,
+        max_tokens_per_item: int,
+    ) -> tuple[int, int]:
+        if max_tokens_per_item == 0:
+            return 0, 0
+
+        # Check how many items of this modality can be supported by
+        # the encoder budget.
+        encoder_budget = self.get_encoder_budget()
+
+        # TODO: handle encoder-decoder models once we support them.
+        if encoder_budget == 0:
+            return 0, 0
+
+        max_encoder_items_per_batch = encoder_budget // max_tokens_per_item
+
+        # Check how many items of this modality can be supported by
+        # the decoder budget.
+        mm_limit = self.mm_limits[modality]
+
+        max_items_per_prompt = max(
+            1,
+            min(mm_limit, self.max_model_len // max_tokens_per_item),
+        )
+
+        scheduler_config = self.scheduler_config
+        max_num_reqs = self.max_num_reqs
+
+        if not scheduler_config.enable_chunked_prefill:
+            max_num_reqs = min(
+                max_num_reqs,
+                scheduler_config.max_num_batched_tokens // max_tokens_per_item,
+            )
+
+        max_decoder_items_per_batch = max_num_reqs * max_items_per_prompt
+
+        max_items_per_batch = max(
+            1,
+            min(max_encoder_items_per_batch, max_decoder_items_per_batch),
+        )
+
+        return max_items_per_prompt, max_items_per_batch
+
+
+@dataclass
+class AttentionGroup:
+    backend: type[AttentionBackend]
+    # When ubatching is enabled we will have a metadata builder for each
+    # ubatch so that, if they use internal persistent buffers for cudagraphs,
+    # the builders won't conflict with the other ubatches.
+    metadata_builders: list[AttentionMetadataBuilder]
+    layer_names: list[str]
+    kv_cache_spec: KVCacheSpec
+
+    @staticmethod
+    def create_with_metadata_builders(
+        backend: type[AttentionBackend],
+        layer_names: list[str],
+        kv_cache_spec: KVCacheSpec,
+        vllm_config: VllmConfig,
+        device: torch.device,
+        num_metadata_builders: int = 1,
+    ) -> 'AttentionGroup':
+        # Instantiate one builder per ubatch from the backend's builder class.
+        metadata_builders = [
+            backend.get_builder_cls()(kv_cache_spec, layer_names, vllm_config,
+                                      device)
+            for _ in range(num_metadata_builders)
+        ]
+        return AttentionGroup(backend, metadata_builders, layer_names,
+                              kv_cache_spec)
+
+    def get_metadata_builder(self,
+                             ubatch_id: int = 0) -> AttentionMetadataBuilder:
+        assert len(self.metadata_builders) > ubatch_id
+        return self.metadata_builders[ubatch_id]
+
+
+def sanity_check_mm_encoder_outputs(
+    mm_embeddings: MultiModalEmbeddings,
+    expected_num_items: int,
+) -> None:
+    """
+    Validate the type, count, and shape of the result of
+    [`vllm.model_executor.models.SupportsMultiModal.get_multimodal_embeddings`][].
+    """
+    assert isinstance(mm_embeddings, (list, tuple, torch.Tensor)), (
+        "Expected multimodal embeddings to be a list/tuple of 2D tensors, "
+        f"or a single 3D tensor, but got {type(mm_embeddings)} "
+        "instead. This is most likely due to incorrect implementation "
+        "of the model's `get_multimodal_embeddings` method.")
+
+    assert len(mm_embeddings) == expected_num_items, (
+        "Expected number of multimodal embeddings to match number of "
+        f"input items: {expected_num_items}, but got {len(mm_embeddings)=} "
+        "instead. This is most likely due to incorrect implementation "
+        "of the model's `get_multimodal_embeddings` method.")
+
+    assert all(e.ndim == 2 for e in mm_embeddings), (
+        "Expected multimodal embeddings to be a sequence of 2D tensors, "
+        f"but got tensors with shapes {[e.shape for e in mm_embeddings]} "
+        "instead. This is most likely due to incorrect implementation "
+        "of the model's `get_multimodal_embeddings` method.")
+
+
+def scatter_mm_placeholders(
+    embeds: torch.Tensor,
+    is_embed: Optional[torch.Tensor],
+) -> torch.Tensor:
+    """
+    Scatter the multimodal embeddings into a contiguous tensor that represents
+    the placeholder tokens.
+
+    See [`vllm.multimodal.processing.PromptUpdateDetails.is_embed`][].
+
+    Args:
+        embeds: The multimodal embeddings.
+            Shape: `(num_embeds, embed_dim)`
+        is_embed: A boolean mask indicating which positions in the placeholder
+            tokens need to be filled with multimodal embeddings.
+            Shape: `(num_placeholders, num_embeds)`
+    """
+    if is_embed is None:
+        return embeds
+
+    placeholders = embeds.new_full(
+        (is_embed.shape[0], embeds.shape[-1]),
+        fill_value=torch.nan,  # positions not selected by is_embed remain NaN
+    )
+    placeholders[is_embed] = embeds
+    return placeholders
+
+
+def gather_mm_placeholders(
+    placeholders: torch.Tensor,
+    is_embed: Optional[torch.Tensor],
+) -> torch.Tensor:
+    """
+    Reconstructs the embeddings from the placeholder tokens.
+
+    This is the inverse operation of [`scatter_mm_placeholders`]
+    [vllm.v1.worker.utils.scatter_mm_placeholders].
+    """
+    if is_embed is None:
+        return placeholders
+
+    return placeholders[is_embed]
+
+
+def add_kv_sharing_layers_to_kv_cache_groups(
+    shared_kv_cache_layers: dict[str, str],
+    kv_cache_groups: list[KVCacheGroupSpec],
+    runner_only_attn_layers: Optional[set[str]] = None,
+) -> None:
+    """
+    Sets up KV cache sharing by reusing the allocated KV caches in `kv_caches`
+    for layers that do not allocate its own KV cache, based on the mapping in
+    `shared_kv_cache_layers`. Adds these layers to the corresponding KV cache
+    group, which is needed to ensure that attention metadata is assigned later.
+
+    Args:
+        shared_kv_cache_layers: Layer pairings for cross-layer KV sharing.
+            If an Attention layer `layer_name` is in the keys of this dict, it
+            means this layer will perform attention using the keys and values
+            from the KV cache of `shared_kv_cache_layers[layer_name]`.
+        kv_cache_groups: The KV cache groups of the model.
+    """
+    # Index each layer by the group that owns its KV cache.
+    layer_to_kv_cache_group: dict[str, KVCacheGroupSpec] = {}
+    for kv_cache_group in kv_cache_groups:
+        for layer_name in kv_cache_group.layer_names:
+            layer_to_kv_cache_group[layer_name] = kv_cache_group
+
+    for layer_name, target_layer_name in shared_kv_cache_layers.items():
+        tgt_kv_cache_group = layer_to_kv_cache_group[target_layer_name]
+        tgt_kv_cache_group.layer_names.append(layer_name)
+
+        if runner_only_attn_layers is not None:
+            runner_only_attn_layers.add(layer_name)  # also record the sharing layer for the runner
+
+
+def bind_kv_cache(
+    kv_caches: dict[str, torch.Tensor],
+    forward_context: dict[str, "Attention"],
+    runner_kv_caches: list[torch.Tensor],
+    num_attn_module: int = 1,  # passed to extract_layer_index; always an int, so Optional was dropped from the annotation
+) -> None:
+    """
+    Bind the allocated KV cache to both ModelRunner and forward context so
+    that the KV cache can be used in the forward pass.
+
+    This function:
+      1) Fills the ModelRunner's kv cache list (`runner_kv_caches`) with
+         kv_caches.
+      2) Associates each attention layer in the `forward_context` with its
+         corresponding KV cache in kv_caches.
+
+    Args:
+        kv_caches: The allocated kv_caches with layer names as keys.
+        forward_context: The global forward context containing all Attention
+            layers with layer names as keys.
+        runner_kv_caches: The kv_cache declared by ModelRunner.
+    """
+    # Bind kv_caches to ModelRunner
+    assert len(runner_kv_caches) == 0
+
+    # Convert kv_caches dict to a list of tensors in the order of layer_index.
+    index2name = defaultdict(list)
+    for layer_name in kv_caches:
+        index2name[extract_layer_index(layer_name,
+                                       num_attn_module)].append(layer_name)
+
+    for layer_index in sorted(index2name.keys()):
+        layer_names = index2name[layer_index]
+        if len(layer_names) > 1:
+            # One typical case is encoder-decoder model, e.g., bart.
+            # The cross attention and self attention in the same decoder layer
+            # has different layer_name but the same layer_index.
+
+            # TODO - analyze where runner_kv_caches is used and the right
+            # way to ensure it properly reflects multiple attention layers
+            # in the same decoder block.
+            if current_platform.is_kunlun() or current_platform.is_cuda() or current_platform.is_xpu():
+                # We know that the GPU runner is not impacted by this
+                # case. Some test code depends on runner_kv_caches, but
+                # not in a way that's impacted by ignoring this.
+                pass
+            else:
+                raise NotImplementedError
+        layer_name = layer_names[0]  # only the first layer's cache is exposed to the runner
+        runner_kv_caches.append(kv_caches[layer_name])
+
+    # Bind kv_caches to forward context
+    for layer_name, kv_cache in kv_caches.items():
+        # NOTE: Use list because of v0 PP virtual engine.
+        forward_context[layer_name].kv_cache = [kv_cache]
+
+
+def is_residual_scattered_for_sp(vllm_config: VllmConfig,
+                                 num_input_tokens: int) -> bool:
+    """Check if the residual tensor is scattered for sequence parallelism.
+
+    The residual tensor is scattered across tensor parallel ranks when sequence
+    parallelism and tensor parallelism is enabled, and the number of
+    input tokens is one of the compilation sizes.
+    """
+    if not vllm_config.compilation_config.pass_config.\
+        enable_sequence_parallelism:
+        return False
+
+    tp = vllm_config.parallel_config.tensor_parallel_size
+
+    if tp == 1:
+        return False  # nothing to scatter without tensor parallelism
+
+    # When sequence parallelism is enabled, we always pad num_input_tokens
+    # to be a multiple of tensor_parallel_size (tp) earlier.
+    assert num_input_tokens % tp == 0
+
+    # Currently, SP is only enabled for static size fx graphs.
+    return (num_input_tokens in vllm_config.compilation_config.compile_sizes)
diff --git a/vllm_kunlun/vllm_utils_wrapper.py b/vllm_kunlun/vllm_utils_wrapper.py
index 7425ac1..c323aad 100644
--- a/vllm_kunlun/vllm_utils_wrapper.py
+++ b/vllm_kunlun/vllm_utils_wrapper.py
@@ -1493,4 +1493,45 @@ def _fake_gptq_shuffle(
return None
-gptq_shuffle.register_fake(_fake_gptq_shuffle)
\ No newline at end of file
+gptq_shuffle.register_fake(_fake_gptq_shuffle)
+
+##################################################
+# ---------------- concat_and_cache_mla ------------------
+##################################################
+@custom_op("_C::concat_and_cache_mla", mutates_args=())
+def concat_and_cache_mla(
+ kv_c: torch.Tensor, #[num_tokens, kv_lora_rank]
+ k_pe: torch.Tensor, #[num_tokens, pe_dim]
+ kv_cache: torch.Tensor, #[num_blocks, block_size, (kv_lora_rank + pe_dim)]
+ slot_mapping: torch.Tensor, #[num_tokens] or [num_actual_tokens]
+) -> None:
+ xtorch_ops.concat_and_cache_mla(
+ kv_c=kv_c,
+ k_pe=k_pe,
+ slot_mapping=slot_mapping,
+ kv_cache=kv_cache,
+ )
+
+@impl("_C::concat_and_cache_mla", "CUDA")  # NOTE(review): body duplicates the default implementation above — confirm this extra device registration is required
+def concat_and_cache_mla_cuda(
+    kv_c: torch.Tensor, #[num_tokens, kv_lora_rank]
+    k_pe: torch.Tensor, #[num_tokens, pe_dim]
+    kv_cache: torch.Tensor, #[num_blocks, block_size, (kv_lora_rank + pe_dim)]
+    slot_mapping: torch.Tensor, #[num_tokens] or [num_actual_tokens]
+) -> None:
+    # Same vendor kernel as the default path: scatter latent + rope into kv_cache.
+    xtorch_ops.concat_and_cache_mla(
+        kv_c=kv_c,
+        k_pe=k_pe,
+        slot_mapping=slot_mapping,
+        kv_cache=kv_cache,
+    )
+
+def _fake_concat_and_cache_mla(
+    kv_c: torch.Tensor, #[num_tokens, kv_lora_rank]
+    k_pe: torch.Tensor, #[num_tokens, pe_dim]
+    kv_cache: torch.Tensor, #[num_blocks, block_size, (kv_lora_rank + pe_dim)]
+    slot_mapping: torch.Tensor, #[num_tokens] or [num_actual_tokens]
+) -> None:
+    return None  # fake/meta kernel for tracing: the real op returns nothing
+
+concat_and_cache_mla.register_fake(_fake_concat_and_cache_mla)  # register the shape-inference stub for torch.compile
\ No newline at end of file