diff --git a/README.md b/README.md index 64e8424..26f500c 100644 --- a/README.md +++ b/README.md @@ -125,6 +125,30 @@ By utilizing the vLLM Kunlun plugin, popular open-source models, including Trans + + DeepSeek-R1 + ✅ + ✅ + + ✅ + + + + DeepSeek-V3 + ✅ + ✅ + + ✅ + + + + DeepSeek-V3.2 + ✅ + ✅ + + ✅ + + diff --git a/vllm_kunlun/__init__.py b/vllm_kunlun/__init__.py index d0124d5..f4fbbbb 100644 --- a/vllm_kunlun/__init__.py +++ b/vllm_kunlun/__init__.py @@ -10,34 +10,15 @@ import vllm.envs as envs OLD_IMPORT_HOOK = builtins.__import__ def _custom_import(module_name, globals=None, locals=None, fromlist=(), level=0): try: - start_time = time.time() - - # 模块映射表 module_mappings = { - "vllm.model_executor.layers.fused_moe.layer": "vllm_kunlun.ops.fused_moe.layer", - "vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors_moe": "vllm_kunlun.ops.quantization.compressed_tensors_moe", "vllm.compilation.wrapper": "vllm_kunlun.compilation.wrapper", - "vllm.v1.worker.gpu_model_runner": "vllm_kunlun.v1.worker.gpu_model_runner" + "vllm.v1.worker.utils": "vllm_kunlun.v1.worker.utils", + "vllm.model_executor.model_loader.bitsandbytes_loader": "vllm_kunlun.models.model_loader.bitsandbytes_loader", + "vllm.v1.sample.ops.topk_topp_sampler": "vllm_kunlun.v1.sample.ops.topk_topp_sampler", + "vllm.model_executor.layers.sampler": "vllm_kunlun.ops.sample.sampler", + "vllm.v1.sample.ops.topk_topp_sampler": "vllm_kunlun.v1.sample.ops.topk_topp_sampler", } - # 需要保持原始导入的模块 - original_imports = [ - "vllm.model_executor.layers.fused_moe.base", - "vllm.model_executor.layers.fused_moe.config", - "vllm.model_executor.layers.fused_moe.layer" - ] - - if module_name in original_imports: - if module_name == "vllm.model_executor.layers.fused_moe.layer" and fromlist: - if "FusedMoEMethodBase" in fromlist: - return OLD_IMPORT_HOOK( - module_name, - globals=globals, - locals=locals, - fromlist=fromlist, - level=level - ) - if module_name in module_mappings: if module_name in 
sys.modules: return sys.modules[module_name] @@ -45,25 +26,6 @@ def _custom_import(module_name, globals=None, locals=None, fromlist=(), level=0) module = importlib.import_module(target_module) sys.modules[module_name] = module sys.modules[target_module] = module - return module - - relative_mappings = { - ("compressed_tensors_moe", "compressed_tensors"): "vllm_kunlun.ops.quantization.compressed_tensors_moe", - ("layer", "fused_moe"): "vllm_kunlun.ops.fused_moe.layer", - } - - if level == 1: - parent = globals.get('__package__', '').split('.')[-1] if globals else '' - key = (module_name, parent) - if key in relative_mappings: - if module_name in sys.modules: - return sys.modules[module_name] - target_module = relative_mappings[key] - module = importlib.import_module(target_module) - sys.modules[module_name] = module - sys.modules[target_module] = module - return module - except Exception: pass @@ -77,79 +39,16 @@ def _custom_import(module_name, globals=None, locals=None, fromlist=(), level=0) def import_hook(): """Apply import hook for VLLM Kunlun""" - if not int(os.environ.get("DISABLE_KUNLUN_HOOK", "0")): - builtins.__import__ = _custom_import - - try: - modules_to_preload = [ - "vllm_kunlun.ops.quantization.compressed_tensors_moe", - "vllm_kunlun.ops.fused_moe.custom_ops", - "vllm_kunlun.ops.fused_moe.layer", - "vllm_kunlun.ops.quantization.fp8", - ] - for module_name in modules_to_preload: - importlib.import_module(module_name) - except Exception: - pass + builtins.__import__ = _custom_import def register(): """Register the Kunlun platform""" from .utils import redirect_output from .vllm_utils_wrapper import direct_register_custom_op, patch_annotations_for_schema - patch_bitsandbytes_loader() import_hook() - if envs.VLLM_USE_V1: - # patch_V1blockTable() - patch_V1top_p_K() - # TODO fixed fast top & k for vLLM 0.10.2, - pass - else: - patch_sampler() return "vllm_kunlun.platforms.kunlun.KunlunPlatform" def register_model(): """Register models for training and 
inference""" from .models import register_model as _reg - _reg() - -# [monkey patach sampler] -import sys -import sys, importlib, warnings - -def patch_bitsandbytes_loader(): - try: - # 载入你插件里自定义的 direct_register_custom_op 实现 - custom_utils = importlib.import_module("vllm_kunlun.models.model_loader.bitsandbytes_loader") - # 覆盖 vllm.utils - sys.modules["vllm.model_executor.model_loader.bitsandbytes_loader"] = custom_utils - print("[vllm_kunlun] bitsandbytes_loader patched ->", custom_utils.__file__) - except Exception as e: - warnings.warn(f"[vllm_kunlun] bitsandbytes_loader patch failed: {e!r}") - -def patch_sampler(): - try: - custom_sampler = importlib.import_module("vllm_kunlun.ops.sample.sampler") - sys.modules["vllm.model_executor.layers.sampler"] = custom_sampler - print("[vllm_kunlun] sampler patched ->", custom_sampler.__file__) - except Exception as e: - warnings.warn(f"[vllm_kunlun] sampler patch failed: {e!r}") - - -def patch_V1top_p_K(): - try: - custom_sampler = importlib.import_module("vllm_kunlun.v1.sample.ops.topk_topp_sampler") - sys.modules["vllm.v1.sample.ops.topk_topp_sampler"] = custom_sampler - print("[vllm_kunlun] V1sampler top p & k patched ->", custom_sampler.__file__) - except Exception as e: - warnings.warn(f"[vllm_kunlun] V1 sampler top p & k patch failed: {e!r}") - -def patch_V1blockTable(): - try: - custom_sampler = importlib.import_module("vllm_kunlun.v1.worker.block_table") - sys.modules["vllm.v1.worker.block_table"] = custom_sampler - print("[vllm_kunlun] V1 block table patched ->", custom_sampler.__file__) - except Exception as e: - warnings.warn(f"[vllm_kunlun] V1 block table patch failed: {e!r}") - -# 在模块导入时自动应用补丁 -import_hook() + _reg() \ No newline at end of file diff --git a/vllm_kunlun/models/__init__.py b/vllm_kunlun/models/__init__.py index 5dd90ec..9fd12c5 100644 --- a/vllm_kunlun/models/__init__.py +++ b/vllm_kunlun/models/__init__.py @@ -80,6 +80,14 @@ def register_model(): ModelRegistry.register_model( 
"GptOssForCausalLM", "vllm_kunlun.models.gpt_oss:GptOssForCausalLM") - + + ModelRegistry.register_model( + "DeepseekV3ForCausalLM", + "vllm_kunlun.models.deepseek_v2:DeepseekV3ForCausalLM") + + ModelRegistry.register_model( + "DeepseekV32ForCausalLM", + "vllm_kunlun.models.deepseek_v2:DeepseekV3ForCausalLM") + def register_quant_method(): """to do""" diff --git a/vllm_kunlun/models/deepseek_v2.py b/vllm_kunlun/models/deepseek_v2.py new file mode 100644 index 0000000..1c3c11a --- /dev/null +++ b/vllm_kunlun/models/deepseek_v2.py @@ -0,0 +1,1445 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +# Adapted from +# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py +# Copyright 2023 The vLLM team. +# Copyright 2023 DeepSeek-AI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Inference-only DeepseekV2/DeepseekV3 model.""" +import typing +from collections.abc import Callable, Iterable +from itertools import islice +from typing import Any, Optional, Union + +import torch +from torch.library import custom_op +from torch import nn +from transformers import DeepseekV2Config, DeepseekV3Config + +from vllm_kunlun.ops.attention.layer import Attention +from vllm.attention.backends.abstract import AttentionBackend +from vllm.attention.ops.common import pack_seq_triton, unpack_seq_triton +from vllm.compilation.decorators import support_torch_compile +from vllm.config import (CacheConfig, ParallelConfig, VllmConfig, + get_current_vllm_config) +from vllm.distributed import (get_ep_group, get_pp_group, + get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, + tensor_model_parallel_all_gather) +from vllm.forward_context import get_forward_context +from vllm.logger import init_logger +from vllm_kunlun.ops.activation import SiluAndMul +from vllm.model_executor.layers.fused_moe import FusedMoE +from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase +from vllm.model_executor.layers.layernorm import LayerNorm, RMSNorm +from vllm.model_executor.layers.linear import (ColumnParallelLinear, + MergedColumnParallelLinear, + RowParallelLinear) +from vllm_kunlun.ops.linear import ReplicatedLinear +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm_kunlun.ops.attention.mla import MLAModules, MultiHeadLatentAttention +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.shared_fused_moe import SharedFusedMoE +from vllm.model_executor.layers.vocab_parallel_embedding import ( + ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import ( + default_weight_loader, maybe_remap_kv_scale_name) +from vllm.model_executor.models.utils import 
sequence_parallel_chunk +from vllm.platforms import current_platform +from vllm.sequence import IntermediateTensors +from vllm.utils import cdiv, direct_register_custom_op +from vllm_kunlun.ops.deep_gemm import int8_mqa_logits, int8_paged_mqa_logits +from vllm.v1.attention.backends.mla.indexer import DeepseekV32IndexerBackend +from vllm_kunlun.v1.attention.backends.mla.indexer import DeepseekV32IndexerMetadata +from vllm.v1.kv_cache_interface import KVCacheSpec, MLAAttentionSpec + +from vllm.model_executor.models.interfaces import MixtureOfExperts, SupportsLoRA, SupportsPP +from vllm.model_executor.models.utils import (PPMissingLayer, is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, make_layers, + maybe_prefix) +from vllm.model_executor.models.deepseek_v2 import DeepseekV32IndexerCache + +if current_platform.is_cuda_alike(): + from vllm import _custom_ops as ops +elif current_platform.is_xpu(): + from vllm._ipex_ops import ipex_ops as ops + +import xspeedgate_ops +_is_kunlun = True +logger = init_logger(__name__) + +class DeepseekV2MLP(nn.Module): + + def __init__( + self, + hidden_size: int, + intermediate_size: int, + hidden_act: str, + quant_config: Optional[QuantizationConfig] = None, + reduce_results: bool = True, + is_sequence_parallel=False, + prefix: str = "", + ) -> None: + super().__init__() + + # If is_sequence_parallel, the input and output tensors are sharded + # across the ranks within the tp_group. In this case the weights are + # replicated and no collective ops are needed. + # Otherwise we use standard TP with an allreduce at the end. 
+ self.gate_up_proj = MergedColumnParallelLinear( + hidden_size, [intermediate_size] * 2, + bias=False, + quant_config=quant_config, + disable_tp=is_sequence_parallel, + prefix=f"{prefix}.gate_up_proj") + self.down_proj = RowParallelLinear(intermediate_size, + hidden_size, + bias=False, + quant_config=quant_config, + reduce_results=reduce_results, + disable_tp=is_sequence_parallel, + prefix=f"{prefix}.down_proj") + if hidden_act != "silu": + raise ValueError(f"Unsupported activation: {hidden_act}. " + "Only silu is supported for now.") + self.act_fn = SiluAndMul() + + def forward(self, x): + gate_up, _ = self.gate_up_proj(x) + x = self.act_fn(gate_up) + x, _ = self.down_proj(x) + return x + + +class DeepseekV2MoE(nn.Module): + + def __init__( + self, + config: Union[DeepseekV2Config, DeepseekV3Config], + parallel_config: ParallelConfig, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): + super().__init__() + self.tp_size = get_tensor_model_parallel_world_size() + self.tp_rank = get_tensor_model_parallel_rank() + + self.routed_scaling_factor = config.routed_scaling_factor + + self.ep_group = get_ep_group().device_group + self.ep_rank = self.ep_group.rank() + self.ep_size = self.ep_group.size() + self.n_routed_experts: int = config.n_routed_experts + self.n_shared_experts: int = config.n_shared_experts + + self.is_sequence_parallel = parallel_config.use_sequence_parallel_moe + + if config.hidden_act != "silu": + raise ValueError(f"Unsupported activation: {config.hidden_act}. " + "Only silu is supported for now.") + + self.gate = ReplicatedLinear(config.hidden_size, + config.n_routed_experts, + bias=False, + quant_config=None, + prefix=f"{prefix}.gate") + if config.topk_method == "noaux_tc": + self.gate.e_score_correction_bias = nn.Parameter( + torch.empty(config.n_routed_experts, dtype=torch.float32)) + else: + self.gate.e_score_correction_bias = None + + # Load balancing settings. 
+ eplb_config = parallel_config.eplb_config + self.enable_eplb = parallel_config.enable_eplb + + self.n_redundant_experts = eplb_config.num_redundant_experts + self.n_logical_experts = self.n_routed_experts + self.n_physical_experts = (self.n_logical_experts + + self.n_redundant_experts) + self.n_local_physical_experts = self.n_physical_experts // self.ep_size + + self.physical_expert_start = (self.ep_rank * + self.n_local_physical_experts) + self.physical_expert_end = (self.physical_expert_start + + self.n_local_physical_experts) + + if config.n_shared_experts is None: + self.experts = FusedMoE( + num_experts=config.n_routed_experts, + top_k=config.num_experts_per_tok, + hidden_size=config.hidden_size, + intermediate_size=config.moe_intermediate_size, + reduce_results=False, + renormalize=config.norm_topk_prob, + quant_config=quant_config, + use_grouped_topk=True, + num_expert_group=config.n_group, + topk_group=config.topk_group, + prefix=f"{prefix}.experts", + scoring_func=config.scoring_func, + # we do scaling outside, set factor to 1.0 to avoid double mul + routed_scaling_factor=1.0, + e_score_correction_bias=self.gate.e_score_correction_bias, + enable_eplb=self.enable_eplb, + num_redundant_experts=self.n_redundant_experts, + is_sequence_parallel=self.is_sequence_parallel, + ) + self.shared_experts = None + else: + intermediate_size = (config.moe_intermediate_size * + config.n_shared_experts) + + self.shared_experts = DeepseekV2MLP( + hidden_size=config.hidden_size, + intermediate_size=intermediate_size, + hidden_act=config.hidden_act, + quant_config=quant_config, + is_sequence_parallel=self.is_sequence_parallel, + reduce_results=False, + prefix=f"{prefix}.shared_experts", + ) + + self.experts = SharedFusedMoE( + shared_experts=self.shared_experts, + num_experts=config.n_routed_experts, + top_k=config.num_experts_per_tok, + hidden_size=config.hidden_size, + intermediate_size=config.moe_intermediate_size, + reduce_results=False, + 
renormalize=config.norm_topk_prob, + quant_config=quant_config, + use_grouped_topk=True, + num_expert_group=config.n_group, + topk_group=config.topk_group, + prefix=f"{prefix}.experts", + scoring_func=config.scoring_func, + # we do scaling outside, set factor to 1.0 to avoid double mul + routed_scaling_factor=1.0, + e_score_correction_bias=self.gate.e_score_correction_bias, + enable_eplb=self.enable_eplb, + num_redundant_experts=self.n_redundant_experts, + is_sequence_parallel=self.is_sequence_parallel, + ) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + num_tokens, hidden_dim = hidden_states.shape + hidden_states = hidden_states.view(-1, hidden_dim) + + # Chunk the hidden states so they aren't replicated across TP ranks. + # This avoids duplicate computation in self.experts. + # TODO: We can replace the all_reduce at the end of attn with a + # reduce_scatter instead of chunking here. + if self.is_sequence_parallel: + hidden_states = sequence_parallel_chunk(hidden_states) + + # router_logits: (num_tokens, n_experts) + router_logits, _ = self.gate(hidden_states) + fused_moe_out = self.experts(hidden_states=hidden_states, + router_logits=router_logits) + + if self.shared_experts is not None: + shared_output, final_hidden_states = fused_moe_out + else: + shared_output = None + final_hidden_states = fused_moe_out + + # Fix FP16 overflow + # See DeepseekV2DecoderLayer for more details. + if hidden_states.dtype != torch.float16: + final_hidden_states *= self.routed_scaling_factor + elif self.shared_experts is not None: + assert shared_output is not None + shared_output *= (1. 
/ self.routed_scaling_factor) + + if self.shared_experts is not None: + assert shared_output is not None + final_hidden_states += shared_output + + if self.is_sequence_parallel: + final_hidden_states = tensor_model_parallel_all_gather( + final_hidden_states, 0) + final_hidden_states = final_hidden_states[:num_tokens] + elif self.tp_size > 1: + final_hidden_states = ( + self.experts.maybe_all_reduce_tensor_model_parallel( + final_hidden_states)) + + return final_hidden_states.view(num_tokens, hidden_dim) + + +def yarn_get_mscale(scale: float = 1, mscale: float = 1) -> float: + import math + if scale <= 1: + return 1.0 + return 0.1 * mscale * math.log(scale) + 1.0 + + +class DeepseekV2Attention(nn.Module): + + def __init__( + self, + vllm_config: VllmConfig, + config: Union[DeepseekV2Config, DeepseekV3Config], + hidden_size: int, + num_heads: int, + qk_nope_head_dim: int, + qk_rope_head_dim: int, + v_head_dim: int, + q_lora_rank: int, + kv_lora_rank: int, + rope_theta: float = 10000, + rope_scaling: Optional[dict[str, Any]] = None, + max_position_embeddings: int = 8192, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + topk_indices_buffer: Optional[torch.Tensor] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.hidden_size = hidden_size + self.qk_nope_head_dim = qk_nope_head_dim + self.qk_rope_head_dim = qk_rope_head_dim + self.qk_head_dim = qk_nope_head_dim + qk_rope_head_dim + self.v_head_dim = v_head_dim + self.q_lora_rank = q_lora_rank + self.kv_lora_rank = kv_lora_rank + self.num_heads = num_heads + tp_size = get_tensor_model_parallel_world_size() + assert num_heads % tp_size == 0 + self.num_local_heads = num_heads // tp_size + self.scaling = self.qk_head_dim**-0.5 + self.rope_theta = rope_theta + self.max_position_embeddings = max_position_embeddings + assert topk_indices_buffer is None, "topk_indices_buffer is not \ + supported for DeepseekV2Attention" + + if self.q_lora_rank is not None: + 
self.q_a_proj = ReplicatedLinear(self.hidden_size, + self.q_lora_rank, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.q_a_proj") + self.q_a_layernorm = RMSNorm(self.q_lora_rank, + eps=config.rms_norm_eps) + self.q_b_proj = ColumnParallelLinear(q_lora_rank, + self.num_heads * + self.qk_head_dim, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.q_b_proj") + else: + self.q_proj = ColumnParallelLinear(self.hidden_size, + self.num_heads * + self.qk_head_dim, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.q_proj") + + self.kv_a_proj_with_mqa = ReplicatedLinear( + self.hidden_size, + self.kv_lora_rank + self.qk_rope_head_dim, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.kv_a_proj_with_mqa") + self.kv_a_layernorm = RMSNorm(self.kv_lora_rank, + eps=config.rms_norm_eps) + self.kv_b_proj = ColumnParallelLinear( + self.kv_lora_rank, + self.num_heads * (self.qk_nope_head_dim + self.v_head_dim), + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.kv_b_proj") + # O projection. 
+ self.o_proj = RowParallelLinear(self.num_heads * self.v_head_dim, + self.hidden_size, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.o_proj") + if rope_scaling: + rope_scaling["rope_type"] = 'deepseek_yarn' + + self.rotary_emb = get_rope(qk_rope_head_dim, + rotary_dim=qk_rope_head_dim, + max_position=max_position_embeddings, + base=rope_theta, + rope_scaling=rope_scaling, + is_neox_style=False) + + if rope_scaling: + mscale_all_dim = rope_scaling.get("mscale_all_dim", False) + scaling_factor = rope_scaling["factor"] + mscale = yarn_get_mscale(scaling_factor, float(mscale_all_dim)) + self.scaling = self.scaling * mscale * mscale + + self.attn = Attention(self.num_local_heads, + self.qk_head_dim, + self.scaling, + num_kv_heads=self.num_local_heads, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.attn") + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + ) -> torch.Tensor: + if self.q_lora_rank is not None: + q = self.q_a_proj(hidden_states)[0] + q = self.q_a_layernorm(q) + q = self.q_b_proj(q)[0].view(-1, self.num_local_heads, + self.qk_head_dim) + else: + q = self.q_proj(hidden_states)[0].view(-1, self.num_local_heads, + self.qk_head_dim) + q_nope, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], + dim=-1) + latent_cache = self.kv_a_proj_with_mqa(hidden_states)[0] + kv_a, _ = latent_cache.split( + [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1) + latent_cache = latent_cache.unsqueeze(1) + kv_a = self.kv_a_layernorm(kv_a) + kv = self.kv_b_proj(kv_a)[0] + kv = kv.view(-1, self.num_local_heads, + self.qk_nope_head_dim + self.v_head_dim) + k_nope, v = kv.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1) + k_pe = latent_cache[:, :, self.kv_lora_rank:] + + q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe) + + q[..., self.qk_nope_head_dim:] = q_pe + k = torch.empty_like(q) + k[..., :self.qk_nope_head_dim] = k_nope + k[..., self.qk_nope_head_dim:] = k_pe + # padding 
value to qk_head_dim for alignment + v = torch.nn.functional.pad( + v, [0, self.qk_head_dim - self.v_head_dim], + value=0).view(-1, self.num_local_heads * self.qk_head_dim) + attn_output = self.attn(q, k, v) + attn_output = attn_output.view( + -1, self.num_local_heads, + self.qk_head_dim)[..., :self.v_head_dim].reshape( + -1, self.num_local_heads * self.v_head_dim) + output, _ = self.o_proj(attn_output) + return output + +@torch.inference_mode() +def cp_gather_indexer_k_quant_cache( + kv_cache, # [num_blocks, block_size, head_dim + 1] + block_table, # [batch_size, num_blocks] + cu_seq_lens, # [batch_size + 1, ] + batch_size, + head_dim, +): + num_blocks, block_size, _ = kv_cache.shape + kv_cache = kv_cache.view(num_blocks, -1) + + expected_value = [] + expected_scale = [] + for b in range(batch_size): + s = cu_seq_lens[b + 1] - cu_seq_lens[b] + if s == 0: + continue + tot = cdiv(s, block_size) + blocks = block_table[b, :tot] + + value = [] + scale = [] + full_block = torch.arange(tot - 1, + device=kv_cache.device, + dtype=torch.int32) + non_remaining_value = kv_cache[blocks[full_block], :block_size * + head_dim].view(-1, head_dim) + non_remaining_scale = kv_cache[blocks[full_block], + block_size * head_dim:].view(-1, 4) + + remaining = s - (tot - 1) * block_size + + value = torch.cat([ + non_remaining_value, + kv_cache[blocks[-1], :remaining * head_dim].view(-1, head_dim) + ], + dim=0) + scale = torch.cat([ + non_remaining_scale, + kv_cache[blocks[-1], block_size * head_dim:block_size * head_dim + + remaining * 4].view(-1, 4) + ], + dim=0) + + expected_value.append(value) + expected_scale.append(scale) + + gather_value = torch.cat(expected_value, dim=0).view(-1, head_dim) + gather_scale = torch.cat(expected_scale, dim=0).view(-1, 4) + gather_value = gather_value.view(torch.int8) + gather_scale = gather_scale.view(torch.float32) + return gather_value, gather_scale + +@torch.inference_mode() +def kunlun_indexer_k_quant_cache( + k, #[num_tokens, head_dim] + kv_cache, 
# [num_blocks, cache_block_size, head_dim + 1] + slot_mapping, # [num_tokens] + quant_block_size, +): + num_blocks, cache_block_size, cache_stride = kv_cache.shape + # num_tokens, head_dim = k.shape + head_dim = k.shape[1] + num_tokens = slot_mapping.shape[0] + assert head_dim % quant_block_size == 0 + kv_cache = kv_cache.view(num_blocks, -1) + + k_fp8 = torch.empty( + k.shape, + device=k.device, + dtype=torch.int8, + ) + k_scale = torch.empty( + [k.shape[0], 1], + device=k.device, + dtype=torch.float32, + ) + + torch.ops._C.quant2d(k, k_fp8, k_scale, force_sdnn=True) + k_scale /= 127 + for token_idx in range(num_tokens): + slot_idx = slot_mapping[token_idx] + if slot_idx < 0: + continue + block_idx = slot_idx // cache_block_size + block_offset = slot_idx % cache_block_size + v_offset = block_offset * head_dim + kv_cache[block_idx, v_offset:v_offset + head_dim] = k_fp8[token_idx, :].view(torch.uint8).contiguous() + s_offset = cache_block_size * head_dim + block_offset * 4 + kv_cache[block_idx, s_offset:s_offset + 4] = k_scale[token_idx, :].view(torch.uint8).contiguous() + kv_cache = kv_cache.view(num_blocks, cache_block_size, cache_stride) + +@custom_op("vllm::sparse_attn_indexer_vllm_kunlun", mutates_args=()) +def sparse_attn_indexer_vllm_kunlun( + hidden_states: torch.Tensor, + k_cache_prefix: str, + kv_cache: torch.Tensor, + q_fp8: torch.Tensor, + k: torch.Tensor, + weights: torch.Tensor, + quant_block_size: int, + scale_fmt: Optional[str], + topk_tokens: int, + head_dim: int, + max_model_len: int, + total_seq_lens: int, + topk_indices_buffer: Optional[torch.Tensor], +) -> None: + + # careful! 
this will be None in dummy run + attn_metadata = get_forward_context().attn_metadata + # assert isinstance(attn_metadata, dict) + if not isinstance(attn_metadata, dict): + sparse_attn_indexer_vllm_kunlun_fake( + hidden_states, + k_cache_prefix, + kv_cache, + q_fp8, + k, + weights, + quant_block_size, + scale_fmt, + topk_tokens, + head_dim, + max_model_len, + total_seq_lens, + topk_indices_buffer, + ) + return + attn_metadata = attn_metadata[k_cache_prefix] + assert isinstance(attn_metadata, DeepseekV32IndexerMetadata) + slot_mapping = attn_metadata.slot_mapping + has_decode = attn_metadata.num_decodes > 0 + has_prefill = attn_metadata.num_prefills > 0 + num_decode_tokens = attn_metadata.num_decode_tokens + + # kunlun_indexer_k_quant_cache( + # k, + # kv_cache, + # slot_mapping, + # quant_block_size, + # ) + + torch.ops.xspeedgate_ops.indexer_k_quant_and_cache( + k, + kv_cache, + slot_mapping, + quant_block_size, + scale_fmt, + ) + topk_indices_buffer[:hidden_states.shape[0]] = -1 + if has_prefill: + prefill_metadata = attn_metadata.prefill + for chunk in prefill_metadata.chunks: + k_fp8, k_scale = cp_gather_indexer_k_quant_cache( + kv_cache, + chunk.block_table, + chunk.cu_seq_lens, + chunk.num_reqs, + head_dim, + ) + + logits = int8_mqa_logits( + q_fp8[chunk.token_start:chunk.token_end], + (k_fp8, k_scale), + weights[chunk.token_start:chunk.token_end], + chunk.cu_seqlen_ks, + chunk.cu_seqlen_ke, + ) + topk_indices = logits.topk(min(topk_tokens, logits.shape[-1]), + dim=-1)[1] + topk_indices -= chunk.cu_seqlen_ks[:, None] + mask_lo = topk_indices >= 0 + mask_hi = topk_indices - (chunk.cu_seqlen_ke - + chunk.cu_seqlen_ks)[:, None] < 0 + mask = torch.full_like(topk_indices, + False, + dtype=torch.bool, + device=topk_indices.device) + mask = mask_lo & mask_hi + topk_indices = topk_indices.masked_fill(~mask, -1) + topk_indices_buffer[ + chunk.token_start:chunk.token_end, :topk_indices. 
+ shape[-1]] = topk_indices.to(dtype=torch.int32) + + if has_decode: + decode_metadata = attn_metadata.decode + # kv_cache size requirement [num_block, block_size, n_head, head_dim], + # we only have [num_block, block_size, head_dim], + kv_cache = kv_cache.unsqueeze(-2) + decode_lens = decode_metadata.decode_lens + if decode_metadata.requires_padding: + # pad in edge case where we have short chunked prefill length < + # decode_threshold since we unstrictly split + # prefill and decode by decode_threshold + # (currently set to 1 + speculative tokens) + padded_q_fp8_decode_tokens = pack_seq_triton( + q_fp8[:num_decode_tokens], decode_lens) + else: + padded_q_fp8_decode_tokens = q_fp8[:num_decode_tokens].reshape( + decode_lens.shape[0], -1, *q_fp8.shape[1:]) + # TODO: move and optimize below logic with triton kernels + batch_size = padded_q_fp8_decode_tokens.shape[0] + next_n = padded_q_fp8_decode_tokens.shape[1] + assert batch_size == decode_metadata.seq_lens.shape[0] + num_padded_tokens = batch_size * next_n + logits = int8_paged_mqa_logits( + padded_q_fp8_decode_tokens, + kv_cache, + weights[:num_padded_tokens], + decode_metadata.seq_lens, + decode_metadata.seq_lens_cpu, + decode_metadata.block_table, + decode_metadata.schedule_metadata, + max_model_len=max_model_len, + ) + # padded query len + current_device = padded_q_fp8_decode_tokens.device + padded_num_tokens = batch_size * next_n + positions = torch.arange(max_model_len, + device=current_device).unsqueeze(0).expand( + batch_size * next_n, -1) + row_indices = torch.arange(padded_num_tokens, + device=current_device) // next_n + next_n_offset = torch.arange( + padded_num_tokens, + device=padded_q_fp8_decode_tokens.device) % next_n + index_end_pos = (decode_metadata.seq_lens[row_indices] - next_n + + next_n_offset).unsqueeze(1) + # index_end_pos: [B * N, 1] + mask = positions <= index_end_pos + # mask: [B * N, L] + logits = logits.masked_fill(~mask, float('-inf')) + topk_indices = logits.topk(topk_tokens, + 
dim=-1)[1].to(torch.int32) # [B * N, K] + # ensure we don't set indices for the top k + # that is out of range(masked already) + # this will happen if context length is shorter than K + topk_indices[topk_indices > index_end_pos] = -1 + if decode_metadata.requires_padding: + # if padded, we need to unpack + # the topk indices removing padded tokens + topk_indices = unpack_seq_triton( + topk_indices.reshape(batch_size, -1, topk_indices.shape[-1]), + decode_lens) + topk_indices_buffer[:num_decode_tokens, :topk_indices. + shape[-1]] = topk_indices.to(dtype=torch.int32) + + # return topk_indices_buffer + + +def sparse_attn_indexer_vllm_kunlun_fake( + hidden_states: torch.Tensor, + k_cache_prefix: str, + kv_cache: torch.Tensor, + q_fp8: torch.Tensor, + k: torch.Tensor, + weights: torch.Tensor, + quant_block_size: int, + scale_fmt: Optional[str], + topk_tokens: int, + head_dim: int, + max_model_len: int, + total_seq_lens: int, + topk_indices_buffer: Optional[torch.Tensor], +) -> None: + return + +sparse_attn_indexer_vllm_kunlun.register_fake(sparse_attn_indexer_vllm_kunlun_fake) + +class Indexer(nn.Module): + + def __init__(self, + vllm_config: VllmConfig, + config: Union[DeepseekV2Config, DeepseekV3Config], + hidden_size: int, + q_lora_rank: int, + quant_config: Optional[QuantizationConfig], + cache_config: Optional[CacheConfig], + topk_indices_buffer: Optional[torch.Tensor], + prefix: str = ""): + super().__init__() + self.vllm_config = vllm_config + self.config = config + # self.indexer_cfg = config.attn_module_list_cfg[0]["attn_index"] + self.topk_tokens = config.index_topk + self.n_head = config.index_n_heads # 64 + self.head_dim = config.index_head_dim # 128 + self.rope_dim = config.qk_rope_head_dim # 64 + self.q_lora_rank = q_lora_rank # 1536 + # no tensor parallel, just replicated + self.wq_b = ReplicatedLinear(self.q_lora_rank, + self.head_dim * self.n_head, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.wq_b") + self.wk = 
ReplicatedLinear(hidden_size, + self.head_dim, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.wk") + self.k_norm = LayerNorm(self.head_dim, eps=1e-6) + self.weights_proj = ReplicatedLinear(hidden_size, + self.n_head, + bias=False, + quant_config=None, + prefix=f"{prefix}.weights_proj") + self.softmax_scale = self.head_dim**-0.5 + self.scale_fmt = "ue8m0" + self.quant_block_size = 128 # TODO: get from config + self.topk_indices_buffer = topk_indices_buffer + + # NOTE: (zyongye) we use fp8 naive cache, + # where we store value in fp8 and scale in fp32 + # per self.quant_block_size element + self.k_cache = DeepseekV32IndexerCache( + head_dim=self.head_dim + + self.head_dim // self.quant_block_size * 4, + dtype=torch.uint8, + prefix=f"{prefix}.k_cache", + cache_config=cache_config) + self.max_model_len = vllm_config.model_config.max_model_len + if self.max_model_len % cache_config.block_size != 0: #由于I8_paged_mqa_logits输入参数的限制,最大长度必须为block_zise的整数倍 + self.max_model_len = self.max_model_len + cache_config.block_size - (self.max_model_len % cache_config.block_size) + self.prefix = prefix + from vllm.v1.attention.backends.mla.indexer import ( + get_max_prefill_buffer_size) + self.max_total_seq_len = get_max_prefill_buffer_size(vllm_config) + + def forward(self, hidden_states: torch.Tensor, qr: torch.Tensor, positions, + rotary_emb) -> torch.Tensor: + q, _ = self.wq_b(qr) + q = q.view(-1, self.n_head, self.head_dim) + q_pe, q_nope = torch.split( + q, [self.rope_dim, self.head_dim - self.rope_dim], dim=-1) + + k, _ = self.wk(hidden_states) + k = self.k_norm(k) + k_pe, k_nope = torch.split( + k, [self.rope_dim, self.head_dim - self.rope_dim], dim=-1) + + q_pe, k_pe = rotary_emb(positions, q_pe, k_pe.unsqueeze(1)) + q = torch.cat([q_pe, q_nope], dim=-1) + k = torch.cat([k_pe.squeeze(1), k_nope], dim=-1) + + # we only quant q here since k quant is fused with cache insertion + q = q.view(-1, self.head_dim) + q_fp8 = torch.empty( + q.shape, + device=q.device, + 
dtype=torch.int8, + ) + q_scale = torch.empty( + [q.shape[0], 1], + device=q.device, + dtype=torch.float32, + ) + torch.ops._C.quant2d(q, q_fp8, q_scale, force_sdnn=True) + q_scale /= 127 + q_fp8 = q_fp8.view(-1, self.n_head, self.head_dim) + q_scale = q_scale.view(-1, self.n_head) + weights, _ = self.weights_proj(hidden_states) + weights = weights * self.n_head**-0.5 + weights = weights * q_scale * self.softmax_scale + + torch.ops.vllm.sparse_attn_indexer_vllm_kunlun( + hidden_states, + self.k_cache.prefix, + self.k_cache.kv_cache[0], + q_fp8, + k, + weights, + self.quant_block_size, + self.scale_fmt, + self.topk_tokens, + self.head_dim, + self.max_model_len, + self.max_total_seq_len, + self.topk_indices_buffer, + ) + return self.topk_indices_buffer + + +class DeepseekV2MLAAttention(nn.Module): + """ + Main reference: DeepseekV2 paper, and FlashInfer Implementation + (https://arxiv.org/abs/2405.04434 and https://github.com/flashinfer-ai/flashinfer/pull/551). + + For more info see MLACommonImpl in: + vllm/v1/attention/backends/mla/utils.py + """ + + def __init__( + self, + vllm_config: VllmConfig, + config: Union[DeepseekV2Config, DeepseekV3Config], + hidden_size: int, + num_heads: int, + qk_nope_head_dim: int, + qk_rope_head_dim: int, + v_head_dim: int, + q_lora_rank: Optional[int], + kv_lora_rank: int, + rope_theta: float = 10000, + rope_scaling: Optional[dict[str, Any]] = None, + max_position_embeddings: int = 8192, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + topk_indices_buffer: Optional[torch.Tensor] = None, + ) -> None: + super().__init__() + self.hidden_size = hidden_size + self.qk_nope_head_dim = qk_nope_head_dim + self.qk_rope_head_dim = qk_rope_head_dim + self.qk_head_dim = qk_nope_head_dim + qk_rope_head_dim + self.v_head_dim = v_head_dim + + self.q_lora_rank = q_lora_rank + self.kv_lora_rank = kv_lora_rank + + self.num_heads = num_heads + tp_size = 
get_tensor_model_parallel_world_size() + assert num_heads % tp_size == 0 + self.num_local_heads = num_heads // tp_size + + self.scaling = self.qk_head_dim**-0.5 + self.rope_theta = rope_theta + self.max_position_embeddings = max_position_embeddings + + if self.q_lora_rank is not None: + self.fused_qkv_a_proj = MergedColumnParallelLinear( + self.hidden_size, + [self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim], + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.fused_qkv_a_proj", + disable_tp=True) + else: + self.kv_a_proj_with_mqa = ReplicatedLinear( + self.hidden_size, + self.kv_lora_rank + self.qk_rope_head_dim, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.kv_a_proj_with_mqa") + + if self.q_lora_rank is not None: + self.q_a_layernorm = RMSNorm(self.q_lora_rank, + eps=config.rms_norm_eps) + self.q_b_proj = ColumnParallelLinear(self.q_lora_rank, + self.num_heads * + self.qk_head_dim, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.q_b_proj") + else: + self.q_proj = ColumnParallelLinear(self.hidden_size, + self.num_heads * + self.qk_head_dim, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.q_proj") + self.kv_a_layernorm = RMSNorm(self.kv_lora_rank, + eps=config.rms_norm_eps) + self.kv_b_proj = ColumnParallelLinear( + self.kv_lora_rank, + self.num_heads * (self.qk_nope_head_dim + self.v_head_dim), + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.kv_b_proj") + self.o_proj = RowParallelLinear(self.num_heads * self.v_head_dim, + self.hidden_size, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.o_proj") + + if rope_scaling: + rope_scaling["rope_type"] = 'deepseek_yarn' + self.rotary_emb = get_rope(qk_rope_head_dim, + rotary_dim=qk_rope_head_dim, + max_position=max_position_embeddings, + base=rope_theta, + rope_scaling=rope_scaling, + is_neox_style=False) + if rope_scaling: + mscale_all_dim = rope_scaling.get("mscale_all_dim", False) + scaling_factor = 
rope_scaling["factor"] + mscale = yarn_get_mscale(scaling_factor, float(mscale_all_dim)) + self.scaling = self.scaling * mscale * mscale + + self.is_v32 = hasattr(config, "index_topk") + + if self.is_v32: + self.indexer = Indexer(vllm_config, config, hidden_size, + q_lora_rank, quant_config, cache_config, + topk_indices_buffer, f"{prefix}.indexer") + else: + self.indexer = None + + mla_modules = MLAModules( + kv_a_layernorm=self.kv_a_layernorm, + kv_b_proj=self.kv_b_proj, + rotary_emb=self.rotary_emb, + o_proj=self.o_proj, + fused_qkv_a_proj=self.fused_qkv_a_proj + if self.q_lora_rank is not None else None, + kv_a_proj_with_mqa=self.kv_a_proj_with_mqa + if self.q_lora_rank is None else None, + q_a_layernorm=self.q_a_layernorm + if self.q_lora_rank is not None else None, + q_b_proj=self.q_b_proj if self.q_lora_rank is not None else None, + q_proj=self.q_proj if self.q_lora_rank is None else None, + indexer=self.indexer, + is_sparse=self.is_v32, + topk_indices_buffer=topk_indices_buffer, + ) + + self.mla_attn = MultiHeadLatentAttention( + self.hidden_size, + self.num_local_heads, + self.scaling, + self.qk_nope_head_dim, + self.qk_rope_head_dim, + self.v_head_dim, + self.q_lora_rank, + self.kv_lora_rank, + mla_modules, + cache_config, + quant_config, + prefix, + ) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + ) -> torch.Tensor: + return self.mla_attn(positions, hidden_states) + + +class DeepseekV2DecoderLayer(nn.Module): + + def __init__(self, + vllm_config: VllmConfig, + prefix: str, + topk_indices_buffer: Optional[torch.Tensor] = None) -> None: + super().__init__() + + config = vllm_config.model_config.hf_config + model_config = vllm_config.model_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config + parallel_config = vllm_config.parallel_config + + self.hidden_size = config.hidden_size + rope_theta = getattr(config, "rope_theta", 10000) + rope_scaling = getattr(config, "rope_scaling", 
None) + max_position_embeddings = getattr(config, "max_position_embeddings", + 8192) + # DecoderLayers are created with `make_layers` which passes the prefix + # with the layer's index. + layer_idx = int(prefix.split(sep='.')[-1]) + self.layer_idx = layer_idx + if model_config.use_mla: + attn_cls = DeepseekV2MLAAttention + else: + attn_cls = DeepseekV2Attention + self.self_attn = attn_cls( + vllm_config=vllm_config, + config=config, + hidden_size=self.hidden_size, + num_heads=config.num_attention_heads, + qk_nope_head_dim=config.qk_nope_head_dim, + qk_rope_head_dim=config.qk_rope_head_dim, + v_head_dim=config.v_head_dim, + q_lora_rank=config.q_lora_rank + if hasattr(config, "q_lora_rank") else None, + kv_lora_rank=config.kv_lora_rank, + rope_theta=rope_theta, + rope_scaling=rope_scaling, + max_position_embeddings=max_position_embeddings, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.self_attn", + topk_indices_buffer=topk_indices_buffer, + ) + + if (config.n_routed_experts is not None + and layer_idx >= config.first_k_dense_replace + and layer_idx % config.moe_layer_freq == 0): + self.mlp = DeepseekV2MoE( + config=config, + parallel_config=parallel_config, + quant_config=quant_config, + prefix=f"{prefix}.mlp", + ) + else: + self.mlp = DeepseekV2MLP( + hidden_size=config.hidden_size, + intermediate_size=config.intermediate_size, + hidden_act=config.hidden_act, + quant_config=quant_config, + prefix=f"{prefix}.mlp", + ) + self.input_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + self.post_attention_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + self.routed_scaling_factor = config.routed_scaling_factor + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + residual: Optional[torch.Tensor], + ) -> torch.Tensor: + # Self Attention + if residual is None: + residual = hidden_states.clone() + hidden_states = self.input_layernorm(hidden_states) + else: + hidden_states, 
residual = self.input_layernorm( + hidden_states, residual) + hidden_states = self.self_attn( + positions=positions, + hidden_states=hidden_states, + ) + + if hidden_states.dtype == torch.float16: + # Fix FP16 overflow + # We scale both hidden_states and residual before + # rmsnorm, and rmsnorm result would not affect by scale. + hidden_states *= 1. / self.routed_scaling_factor + if self.layer_idx == 0: + # The residual is shared by all layers, we only scale it on + # first layer. + residual *= 1. / self.routed_scaling_factor + + # Fully Connected + hidden_states, residual = self.post_attention_layernorm( + hidden_states, residual) + hidden_states = self.mlp(hidden_states) + + if isinstance(self.mlp, + DeepseekV2MLP) and hidden_states.dtype == torch.float16: + # Fix FP16 overflow + # Scaling the DeepseekV2MLP output, it is the input of + # input_layernorm of next decoder layer. + # The scaling of DeepseekV2MOE output would be done in the forward + # of DeepseekV2MOE + hidden_states *= 1. 
/ self.routed_scaling_factor + + return hidden_states, residual + + +@support_torch_compile +class DeepseekV2Model(nn.Module): + + fall_back_to_pt_during_load = False + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + + config = vllm_config.model_config.hf_config + quant_config = vllm_config.quant_config + self.config = config + + self.vocab_size = config.vocab_size + self.is_v32 = hasattr(config, "index_topk") + if self.is_v32: + topk_tokens = config.index_topk + topk_indices_buffer = torch.empty( + vllm_config.scheduler_config.max_num_batched_tokens, + topk_tokens, + dtype=torch.int32, + device="cuda") + else: + topk_indices_buffer = None + + if get_pp_group().is_first_rank: + self.embed_tokens = VocabParallelEmbedding( + config.vocab_size, + config.hidden_size, + quant_config=quant_config, + prefix=f"{prefix}.embed_tokens") + else: + self.embed_tokens = PPMissingLayer() + + self.start_layer, self.end_layer, self.layers = make_layers( + config.num_hidden_layers, + lambda prefix: DeepseekV2DecoderLayer(vllm_config, prefix, + topk_indices_buffer), + prefix=f"{prefix}.layers") + + if get_pp_group().is_last_rank: + self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + else: + self.norm = PPMissingLayer() + self.make_empty_intermediate_tensors = ( + make_empty_intermediate_tensors_factory( + ["hidden_states", "residual"], config.hidden_size)) + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.embed_tokens(input_ids) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors], + inputs_embeds: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + if get_pp_group().is_first_rank: + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states = self.get_input_embeddings(input_ids) + residual = None + else: + assert intermediate_tensors is not None + 
hidden_states = intermediate_tensors["hidden_states"] + residual = intermediate_tensors["residual"] + + for layer in islice(self.layers, self.start_layer, self.end_layer): + hidden_states, residual = layer(positions, hidden_states, residual) + + if not get_pp_group().is_last_rank: + return IntermediateTensors({ + "hidden_states": hidden_states, + "residual": residual + }) + + hidden_states, _ = self.norm(hidden_states, residual) + return hidden_states + + +class DeepseekV2ForCausalLM(nn.Module, SupportsPP, MixtureOfExperts, + SupportsLoRA): + packed_modules_mapping = { + "gate_up_proj": ["gate_proj", "up_proj"], + } + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + config = vllm_config.model_config.hf_config + quant_config = vllm_config.quant_config + self.config = config + self.quant_config = quant_config + + # `packed_modules_mapping` needs to be modified before + # initializing DeepseekV2Model, as it is passed inplace to + # quantization config init and may be used to select the + # quant_method for relevant layers during initialization. 
+ self.fuse_qkv_a_proj = hasattr( + config, "q_lora_rank") and config.q_lora_rank is not None + if self.fuse_qkv_a_proj: + self.packed_modules_mapping["fused_qkv_a_proj"] = [ + "q_a_proj", + "kv_a_proj_with_mqa", + ] + + self.model = DeepseekV2Model(vllm_config=vllm_config, + prefix=maybe_prefix(prefix, "model")) + if get_pp_group().is_last_rank: + self.lm_head = ParallelLMHead( + config.vocab_size, + config.hidden_size, + quant_config=quant_config, + prefix=maybe_prefix(prefix, "lm_head"), + ) + else: + self.lm_head = PPMissingLayer() + self.logits_processor = LogitsProcessor(config.vocab_size) + self.make_empty_intermediate_tensors = ( + self.model.make_empty_intermediate_tensors) + self.expert_weights = [] + + # Set MoE hyperparameters + self.num_moe_layers = (config.num_hidden_layers - + config.first_k_dense_replace) + self.num_expert_groups = config.n_group + + self.moe_layers: list[FusedMoE] = [] + example_moe = None + for layer in self.model.layers: + if isinstance(layer, PPMissingLayer): + continue + + assert isinstance(layer, DeepseekV2DecoderLayer) + if isinstance(layer.mlp, DeepseekV2MoE): + # Pick last one layer since the first ones may be dense layers. + example_moe = layer.mlp + self.moe_layers.append(layer.mlp.experts) + + if example_moe is None: + raise RuntimeError("No DeepseekV2MoE layer found in model.layers.") + + self.num_logical_experts = example_moe.n_logical_experts + self.num_physical_experts = example_moe.n_physical_experts + self.num_local_physical_experts = example_moe.n_local_physical_experts + self.num_routed_experts = example_moe.n_routed_experts + self.num_shared_experts = example_moe.n_shared_experts + self.num_redundant_experts = example_moe.n_redundant_experts + + def set_eplb_state( + self, + expert_load_view: torch.Tensor, + logical_to_physical_map: torch.Tensor, + logical_replica_count: torch.Tensor, + ) -> None: + for layer_idx, layer in enumerate(self.moe_layers): + # Register the expert weights. 
+ self.expert_weights.append(layer.get_expert_weights()) + layer.set_eplb_state( + moe_layer_idx=layer_idx, + expert_load_view=expert_load_view, + logical_to_physical_map=logical_to_physical_map, + logical_replica_count=logical_replica_count, + ) + + def update_physical_experts_metadata( + self, + num_physical_experts: int, + num_local_physical_experts: int, + ) -> None: + assert self.num_local_physical_experts == num_local_physical_experts + self.num_physical_experts = num_physical_experts + self.num_local_physical_experts = num_local_physical_experts + self.num_redundant_experts = (num_physical_experts - + self.num_logical_experts) + for layer in self.model.layers: + if isinstance(layer.mlp, DeepseekV2MoE): + moe = layer.mlp + moe.n_local_physical_experts = num_local_physical_experts + moe.n_physical_experts = num_physical_experts + moe.n_redundant_experts = self.num_redundant_experts + moe.experts.update_expert_map() + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.get_input_embeddings(input_ids) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + hidden_states = self.model(input_ids, positions, intermediate_tensors, + inputs_embeds) + return hidden_states + + def compute_logits( + self, + hidden_states: torch.Tensor, + ) -> Optional[torch.Tensor]: + logits = self.logits_processor(self.lm_head, hidden_states) + return logits + + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("gate_up_proj", "gate_proj", 0), + ("gate_up_proj", "up_proj", 1), + ("fused_qkv_a_proj", "q_a_proj", 0), + ("fused_qkv_a_proj", "kv_a_proj_with_mqa", 1), + ] + + # Params for weights, fp8 weight scales, fp8 activation scales + # (param_name, weight_name, 
expert_id, shard_id) + expert_params_mapping = FusedMoE.make_expert_params_mapping( + ckpt_gate_proj_name="gate_proj", + ckpt_down_proj_name="down_proj", + ckpt_up_proj_name="up_proj", + num_experts=self.config.n_routed_experts, + num_redundant_experts=self.num_redundant_experts) + + params_dict = dict(self.named_parameters()) + loaded_params: set[str] = set() + for name, loaded_weight in weights: + if "rotary_emb.inv_freq" in name: + continue + + spec_layer = get_spec_layer_idx_from_weight_name(self.config, name) + if spec_layer is not None: + continue # skip spec decode layers for main model + + for (param_name, weight_name, shard_id) in stacked_params_mapping: + # Skip non-stacked layers and experts (experts handled below). + if weight_name not in name: + continue + # We have mlp.experts[0].gate_proj in the checkpoint. + # Since we handle the experts below in expert_params_mapping, + # we need to skip here BEFORE we update the name, otherwise + # name will be updated to mlp.experts[0].gate_up_proj, which + # will then be updated below in expert_params_mapping + # for mlp.experts[0].gate_gate_up_proj, which breaks load. + if (("mlp.experts." in name) and name not in params_dict): + continue + name_mapped = name.replace(weight_name, param_name) + + # QKV fusion is optional, fall back to normal + # weight loading if it's not enabled + # if go with fusion option, then update name + if ((param_name == "fused_qkv_a_proj") + and name_mapped not in params_dict): + continue + else: + name = name_mapped + # Skip loading extra bias for GPTQ models. 
+ if name.endswith(".bias") and name not in params_dict: + continue + + if is_pp_missing_parameter(name, self): + continue + if name not in params_dict: + continue + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + is_expert_weight = False + for mapping in expert_params_mapping: + param_name, weight_name, expert_id, shard_id = mapping + if weight_name not in name: + continue + + # Anyway, this is an expert weight and should not be + # attempted to load as other weights later + is_expert_weight = True + + # Do not modify `name` since the loop may continue here + # Instead, create a new variable + name_mapped = name.replace(weight_name, param_name) + + if is_pp_missing_parameter(name_mapped, self): + continue + if name_mapped not in params_dict: + continue + param = params_dict[name_mapped] + # We should ask the weight loader to return success or not + # here since otherwise we may skip experts with other + # available replicas. + weight_loader = typing.cast(Callable[..., bool], + param.weight_loader) + success = weight_loader(param, + loaded_weight, + name_mapped, + shard_id=shard_id, + expert_id=expert_id, + return_success=True) + if success: + name = name_mapped + break + else: + if is_expert_weight: + # We've checked that this is an expert weight + # However it's not mapped locally to this rank + # So we simply skip it + continue + + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + + # Remapping the name of FP8 kv-scale. 
+ name = maybe_remap_kv_scale_name(name, params_dict) + if name is None: + continue + + if is_pp_missing_parameter(name, self): + continue + if name not in params_dict: + continue + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + + return loaded_params + + +class DeepseekV3ForCausalLM(DeepseekV2ForCausalLM): + pass + + +# Compatibility with +# https://huggingface.co/deepseek-ai/DeepSeek-V3-Base/blob/main/configuration_deepseek.py +def get_spec_layer_idx_from_weight_name(config: Union[DeepseekV2Config, + DeepseekV3Config], + weight_name: str) -> Optional[int]: + if (hasattr(config, "num_nextn_predict_layers") + and config.num_nextn_predict_layers > 0): + layer_idx = config.num_hidden_layers + for i in range(config.num_nextn_predict_layers): + if weight_name.startswith(f"model.layers.{layer_idx+i}."): + return layer_idx + i + return None \ No newline at end of file diff --git a/vllm_kunlun/models/gpt_oss.py b/vllm_kunlun/models/gpt_oss.py index 532718d..33e97b9 100644 --- a/vllm_kunlun/models/gpt_oss.py +++ b/vllm_kunlun/models/gpt_oss.py @@ -16,7 +16,7 @@ from vllm.distributed import (get_ep_group, get_pp_group, get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, tensor_model_parallel_all_gather) -from vllm_kunlun.ops.fused_moe.layer import FusedMoE +from vllm.model_executor.layers.fused_moe import FusedMoE from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import (QKVParallelLinear, RowParallelLinear) @@ -176,7 +176,7 @@ class MLPBlock(torch.nn.Module): x = sequence_parallel_chunk(x) g = self.router(x) - x = self.experts(hidden_states=x, router_logits=g, linear_weights=self.router.weight) + x = self.experts(hidden_states=x, router_logits=g) if self.is_sequence_parallel: x = tensor_model_parallel_all_gather(x.contiguous(), 0) diff --git a/vllm_kunlun/models/mimo_v2_flash.py 
b/vllm_kunlun/models/mimo_v2_flash.py index 0dfcb86..e381b24 100644 --- a/vllm_kunlun/models/mimo_v2_flash.py +++ b/vllm_kunlun/models/mimo_v2_flash.py @@ -21,7 +21,7 @@ from vllm.distributed import ( tensor_model_parallel_all_gather, ) from vllm.logger import init_logger -from vllm_kunlun.ops.fused_moe.layer import FusedMoE +from vllm.model_executor.layers.fused_moe import FusedMoE from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import ( MergedColumnParallelLinear, @@ -185,8 +185,7 @@ class MiMoV2MoE(nn.Module): gate_input = hidden_states router_logits = self.gate(gate_input) final_hidden_states = self.experts( - hidden_states=hidden_states, router_logits=router_logits, linear_weights=self.gate.weight - ) + hidden_states=hidden_states, router_logits=router_logits) return final_hidden_states.squeeze(0) if is_input_1d else final_hidden_states diff --git a/vllm_kunlun/ops/__init__.py b/vllm_kunlun/ops/__init__.py index fec8bbb..d6da6ae 100644 --- a/vllm_kunlun/ops/__init__.py +++ b/vllm_kunlun/ops/__init__.py @@ -20,4 +20,8 @@ import vllm_kunlun.ops.layernorm import vllm_kunlun.ops.quantization.awq import vllm_kunlun.ops.quantization.gptq import vllm_kunlun.ops.vocab_parallel_embedding -import vllm_kunlun.ops.linear \ No newline at end of file +import vllm_kunlun.ops.linear +import vllm_kunlun.ops.quantization.kernels.scaled_mm.cutlass +import vllm_kunlun.ops.vocab_parallel_embedding +import vllm_kunlun.ops.quantization.compressed_tensors_moe +import vllm_kunlun.ops.fused_moe.layer \ No newline at end of file diff --git a/vllm_kunlun/ops/_kunlun_ops.py b/vllm_kunlun/ops/_kunlun_ops.py index 6250964..94a875e 100644 --- a/vllm_kunlun/ops/_kunlun_ops.py +++ b/vllm_kunlun/ops/_kunlun_ops.py @@ -417,7 +417,6 @@ class KunlunOps: w1: torch.Tensor, w2: torch.Tensor, router_logits: torch.Tensor, - linear_weights: torch.Tensor, ep_rank: int, moe_top_k: int, renormalize: bool, diff --git a/vllm_kunlun/ops/activation.py 
b/vllm_kunlun/ops/activation.py index 0526acf..6719808 100644 --- a/vllm_kunlun/ops/activation.py +++ b/vllm_kunlun/ops/activation.py @@ -108,7 +108,7 @@ class SiluAndMul(CustomOp): d = x.shape[-1] // 2 output_shape = (x.shape[:-1] + (d, )) out = torch.empty(output_shape, dtype=x.dtype, device=x.device) - xtorch_ops.swiglu(x, out) + torch.ops._C.silu_and_mul(out, x) return out def forward_xpu(self, x: torch.Tensor) -> torch.Tensor: diff --git a/vllm_kunlun/ops/attention/flashmla.py b/vllm_kunlun/ops/attention/flashmla.py new file mode 100644 index 0000000..f53dfcf --- /dev/null +++ b/vllm_kunlun/ops/attention/flashmla.py @@ -0,0 +1,260 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# adapted from: https://github.com/deepseek-ai/FlashMLA/blob/main/flash_mla/flash_mla_interface.py +from typing import Optional, Tuple + +import torch + +from vllm.logger import init_logger +from vllm.platforms import current_platform +import xtorch_ops + +logger = init_logger(__name__) + +if current_platform.is_cuda(): + try: + import vllm._flashmla_C # noqa: F401 + _flashmla_C_AVAILABLE = True + except ImportError: + _flashmla_C_AVAILABLE = False +else: + _flashmla_C_AVAILABLE = False + +if current_platform.is_cuda(): + try: + import vllm._flashmla_extension_C # noqa: F401 + _flashmla_extension_C_AVAILABLE = True + except ImportError: + _flashmla_extension_C_AVAILABLE = False +else: + _flashmla_extension_C_AVAILABLE = False + + +def is_flashmla_supported() -> Tuple[bool, Optional[str]]: + """ + Return: is_supported_flag, unsupported_reason (optional). + """ + return True, None + +def get_mla_metadata( + cache_seqlens: torch.Tensor, + num_heads_per_head_k: int = 1, + num_heads_k: int = 1, +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Arguments: + cache_seqlens: (batch_size), dtype torch.int32. + num_heads_per_head_k: Equals to seq_len_q * num_heads_q // num_heads_k. + num_heads_k: num_heads_k. 
+ + Returns: + tile_scheduler_metadata: (num_sm_parts, TileSchedulerMetaDataSize), dtype torch.int32. + num_splits: (batch_size + 1), dtype torch.int32. + """ + # return flash_mla_cuda.get_mla_metadata(cache_seqlens, num_heads_per_head_k, num_heads_k) + cache_seqlens_cpu = cache_seqlens.cpu() + return cache_seqlens_cpu, cache_seqlens + +def flash_mla_with_kvcache( + q: torch.Tensor, + k_cache: torch.Tensor, + block_table: torch.Tensor, + cache_seqlens: torch.Tensor, + head_dim_v: int, + tile_scheduler_metadata: torch.Tensor, + num_splits: torch.Tensor, + softmax_scale: Optional[float] = None, + causal: bool = False, + descale_q: Optional[torch.Tensor] = None, + descale_k: Optional[torch.Tensor] = None, + is_fp8_kvcache: bool = False, + indices: Optional[torch.Tensor] = None, +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Arguments: + q: (batch_size, seq_len_q, num_heads_q, head_dim). + k_cache: (num_blocks, page_block_size, num_heads_k, head_dim). + block_table: (batch_size, max_num_blocks_per_seq), torch.int32. + cache_seqlens: (batch_size), torch.int32. + head_dim_v: Head dimension of v. + tile_scheduler_metadata: (num_sm_parts, TileSchedulerMetaDataSize), torch.int32, returned by get_mla_metadata. + num_splits: (batch_size + 1), torch.int32, returned by get_mla_metadata. + softmax_scale: float. The scale of QK^T before applying softmax. Default to 1 / sqrt(head_dim). + causal: bool. Whether to apply causal attention mask. + + Returns: + out: (batch_size, seq_len_q, num_heads_q, head_dim_v). + softmax_lse: (batch_size, num_heads_q, seq_len_q), torch.float32. 
+ """ + if softmax_scale is None: + softmax_scale = q.shape[-1] ** (-0.5) + + softmax_lse = None + out = torch.ones(q.size(0), q.size(1), q.size(2), head_dim_v, dtype= q.dtype, device=q.device) + kv_lora_rank = head_dim_v + qk_rope_head_dim = q.size(3) - head_dim_v + head_dim = k_cache.shape[3] + page_block_size = k_cache.shape[1] + k_cache = k_cache.view(-1, 1, page_block_size, head_dim) + + # todo: optimize memcp + # q_c = q[..., : kv_lora_rank].contiguous() + # q_r = q[..., kv_lora_rank :].contiguous() + + is_context = False + vo_head_dim = -1 + + xtorch_ops.paged_attention(out, + q, + k_cache, None, + block_table, + tile_scheduler_metadata, # context_lens_cpu + num_splits, # context_lens_xpu + is_context, + causal, + vo_head_dim, + kv_lora_rank, + qk_rope_head_dim, + softmax_scale, + q_r=q) + return out, softmax_lse + +def kunlun_flash_mla_with_kvcache( + q: torch.Tensor, + k_cache: torch.Tensor, + cache_seqlens: torch.Tensor, + cache_seqlens_cpu: torch.Tensor, + head_dim_v: int, + softmax_scale: Optional[float] = None, + causal: bool = False, + is_fp8_kvcache: bool = False, + indices: Optional[torch.Tensor] = None, + max_seq_kv: int = 1, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Arguments: + q: (batch_size, seq_len_q, num_heads_q, head_dim). + k_cache: (num_tokens_kv, head_dim). + cache_seqlens: (batch_size), torch.int32. + head_dim_v: Head dimension of v. + softmax_scale: float. The scale of QK^T before applying softmax. Default to 1 / sqrt(head_dim). + causal: bool. Whether to apply causal attention mask. + is_fp8_kvcache: bool. Whether the k_cache and v_cache are in fp8 format. + indices: (batch_size, seq_len_q, topk), torch.int32. If not None, sparse attention will be enabled, and only tokens in the `indices` array will be attended to. Invalid indices should be set to -1 or numbers >= total_seq_len_kv. + max_seq_kv: seq中最大的kv长度 + + Returns: + out: (batch_size, seq_len_q, num_heads_q, head_dim_v). 
+ max_logits: (batch_size, seq_len_q, num_heads_q), torch.float32. + p_sums: (batch_size, seq_len_q, num_heads_q), torch.float32. + """ + assert not is_fp8_kvcache, "By now, the kernel does not support uint8 kv cache." + assert q.shape[1] <= 2, "xtorch_ops.fwd_kvcache_mla only support seq_len_q <= 2 for now." + if softmax_scale is None: + softmax_scale = q.shape[-1] ** (-0.5) + if indices is not None: + # NOTE (zyongye): sparse attention is also causal + # since it only attend to the tokens before + # but here `causal` should not be specified + assert not causal, \ + "causal must be `false` if sparse attention is enabled." + + q_r, pe_cache = None, None # packed mode when both q_r and pe_cache are None + batch_size, seq_len_q, num_heads_q, head_dim = q.shape + kv_lora_rank = head_dim_v + rope_head_dim = head_dim - kv_lora_rank + + out = torch.zeros([batch_size, seq_len_q, num_heads_q, kv_lora_rank], + dtype=q.dtype, device=q.device) + max_logits = torch.zeros([batch_size, seq_len_q, num_heads_q], + dtype=torch.float32, device=q.device) + p_sums = torch.zeros([batch_size, seq_len_q, num_heads_q], + dtype=torch.float32, device=q.device) + + xtorch_ops.fwd_kvcache_mla( + q_c=q, + kv_cache=k_cache, + indices=indices, + kv_lod_cpu=cache_seqlens_cpu, + max_seq_kv=max_seq_kv, + softmax_scale=softmax_scale, + # q_r=q_r, + # pe_cache=pe_cache, + out=out, + max_logits=max_logits, + p_sums=p_sums, + kv_lod_xpu=cache_seqlens, + ) + + return out, max_logits, p_sums + + +def flash_mla_sparse_prefill( + q: torch.Tensor, + kv: torch.Tensor, + indices: torch.Tensor, + sm_scale: float, + q_lod_xpu: torch.Tensor, + q_lod_cpu: torch.Tensor, + d_v: int = 512, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Sparse attention prefill kernel + + Args: + - q: [s_q, h_q, d_qk], bfloat16 + - kv: [s_kv, d_qk], bfloat16 + - indices: [s_q, h_kv, topk], int32.
+ Invalid indices should be set to -1 or numbers >= s_kv + - sm_scale: float + - q_lod_xpu: [batch+1], int32, cumulative lengths of the q sequences; length is batch_num + 1 (None means fixed-length q). + - d_v: The dimension of value vectors. Can only be 512 + + Returns: + - (output, max_logits, lse) + About the definition of output, + max_logits and lse, please refer to README.md + - output: [s_q, h_q, d_v], bfloat16 + - max_logits: [s_q, h_q], float + - lse: [s_q, h_q], float, 2-based log-sum-exp + """ + s_q, h_q, d_qk = q.shape + + out = torch.zeros([s_q, h_q, d_v], dtype=q.dtype, device=q.device) + max_logits = torch.zeros([s_q, h_q], dtype=torch.float32, device=q.device) + lse = torch.zeros([s_q, h_q], dtype=torch.float32, device=q.device) + + xtorch_ops.sparse_prefill_fwd_opt( + q=q, + kv=kv, + indices=indices, + qlod_cpu=q_lod_cpu, + qlod_xpu=q_lod_xpu, + kvlod_cpu=q_lod_cpu, + kvlod_xpu=q_lod_xpu, + sm_scale=sm_scale, + d_v=d_v, + is_causal=True, # NOTE(review): aiak also sets this to true — confirm why + out=out, + max_logits=max_logits, + lse=lse, + ) + + # NOTE: Compared with torch.ops._flashmla_C.sparse_prefill_fwd, + # out_scale = 1 / math.log2(math.e) + # gpu_max_logits * out_scale = kunlun_lse + # gpu_lse * out_scale = kunlun_lse + return out, max_logits, lse + + +# +# TODO: Add fake functions +# +# @register_fake("_flashmla_C::get_mla_metadata") +# def _get_mla_metadata_fake(....) -> Tuple[torch.Tensor, torch.Tensor]: +# return .... +# +# @register_fake("_flashmla_C::fwd_kvcache_mla") +# def _fwd_kvcache_mla_fake(....) -> Tuple[torch.Tensor, torch.Tensor]: +# return ....
+# \ No newline at end of file diff --git a/vllm_kunlun/ops/attention/mla.py b/vllm_kunlun/ops/attention/mla.py new file mode 100644 index 0000000..e50ac2f --- /dev/null +++ b/vllm_kunlun/ops/attention/mla.py @@ -0,0 +1,180 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from dataclasses import dataclass +from typing import Optional + +import torch + +from vllm_kunlun.ops.attention.layer import Attention +# from vllm.attention import Attention +from vllm.config import CacheConfig +from vllm.model_executor.custom_op import CustomOp +from vllm.model_executor.layers.quantization import QuantizationConfig + + +@dataclass +class MLAModules: + """Modules used in MLA. + """ + kv_a_layernorm: torch.nn.Module + kv_b_proj: torch.nn.Module + rotary_emb: torch.nn.Module + o_proj: torch.nn.Module + fused_qkv_a_proj: Optional[torch.nn.Module] + kv_a_proj_with_mqa: Optional[torch.nn.Module] + q_a_layernorm: Optional[torch.nn.Module] + q_b_proj: Optional[torch.nn.Module] + q_proj: Optional[torch.nn.Module] + indexer: Optional[torch.nn.Module] + is_sparse: bool + topk_indices_buffer: Optional[torch.Tensor] + + +@CustomOp.register("vllm_kunlun_multi_head_latent_attention") +class MultiHeadLatentAttention(CustomOp): + """MLA layer registered as CustomOp. + Note that currently MLA ignores the enable/disable mechanism of CustomOp + because there is only one in-tree implementation in forward_native. + TODO: implement this with a new PluggableLayer mechanism. + + This class takes positions and hidden_states as input. + The input tensors can either contain prefill tokens or decode tokens. + The class does the following: + + 1. MLA Preprocess. + 2. Perform multi-head attention to prefill tokens and + multi-query attention to decode tokens separately. + 3. Return the output tensor. 
+ """ + + def __init__( + self, + hidden_size: int, + num_heads: int, + scale: float, + qk_nope_head_dim: int, + qk_rope_head_dim: int, + v_head_dim: int, + q_lora_rank: Optional[int], + kv_lora_rank: int, + mla_modules: MLAModules, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.hidden_size = hidden_size + self.qk_nope_head_dim = qk_nope_head_dim + self.qk_rope_head_dim = qk_rope_head_dim + self.qk_head_dim = qk_nope_head_dim + qk_rope_head_dim + self.v_head_dim = v_head_dim + self.q_lora_rank = q_lora_rank + self.kv_lora_rank = kv_lora_rank + self.num_heads = num_heads + self.fused_qkv_a_proj = mla_modules.fused_qkv_a_proj + self.kv_a_proj_with_mqa = mla_modules.kv_a_proj_with_mqa + self.q_a_layernorm = mla_modules.q_a_layernorm + self.q_b_proj = mla_modules.q_b_proj + self.q_proj = mla_modules.q_proj + self.kv_a_layernorm = mla_modules.kv_a_layernorm + self.kv_b_proj = mla_modules.kv_b_proj + self.rotary_emb = mla_modules.rotary_emb + self.o_proj = mla_modules.o_proj + self.indexer = mla_modules.indexer + self.is_sparse = mla_modules.is_sparse + + if self.indexer is not None: + assert hasattr(self.indexer, "topk_tokens") + self.topk_tokens = self.indexer.topk_tokens + self.topk_indices_buffer = mla_modules.topk_indices_buffer + + # In the MLA backend, kv_cache includes both k_c and + # pe (i.e. decoupled position embeddings). In particular, + # the concat_and_cache_mla op requires + # k_c.size(1) + k_pe.size(1) == kv_cache.size(2) + # i.e. 
+ # kv_lora_rank + qk_rope_head_dim == head_size + self.mla_attn = Attention( + num_heads=self.num_heads, + head_size=self.kv_lora_rank + self.qk_rope_head_dim, + scale=scale, + num_kv_heads=1, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.attn", + use_mla=True, + use_sparse=mla_modules.is_sparse, + # MLA Args + q_lora_rank=self.q_lora_rank, + kv_lora_rank=self.kv_lora_rank, + qk_nope_head_dim=self.qk_nope_head_dim, + qk_rope_head_dim=self.qk_rope_head_dim, + qk_head_dim=self.qk_head_dim, + v_head_dim=self.v_head_dim, + kv_b_proj=self.kv_b_proj, + indexer=self.indexer, + ) + + self.prefix = prefix + + def forward_native( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + ) -> torch.Tensor: + q_c = None + kv_lora = None + + if self.q_lora_rank is not None: + assert self.fused_qkv_a_proj is not None, \ + "fused_qkv_a_proj is required when q_lora_rank is not None" + assert self.q_a_layernorm is not None, \ + "q_a_layernorm is required when q_lora_rank is not None" + assert self.q_b_proj is not None, \ + "q_b_proj is required when q_lora_rank is not None" + qkv_lora = self.fused_qkv_a_proj(hidden_states)[0] + q_c, kv_lora = qkv_lora.split( + [self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim], + dim=-1, + ) + q_c = self.q_a_layernorm(q_c) + q = self.q_b_proj(q_c)[0] + else: + assert self.kv_a_proj_with_mqa is not None, \ + "kv_a_proj_with_mqa is required when q_lora_rank is None" + assert self.q_proj is not None, \ + "q_proj is required when q_lora_rank is None" + kv_lora = self.kv_a_proj_with_mqa(hidden_states)[0] + q = self.q_proj(hidden_states)[0] + + kv_c, k_pe = kv_lora.split([self.kv_lora_rank, self.qk_rope_head_dim], + dim=-1) + kv_c_normed = self.kv_a_layernorm(kv_c) + + q = q.view(-1, self.num_heads, self.qk_head_dim) + # Add head dim of 1 to k_pe + k_pe = k_pe.unsqueeze(1) + + q[..., self.qk_nope_head_dim:], k_pe = self.rotary_emb( + positions, q[..., self.qk_nope_head_dim:], k_pe) + + if self.indexer 
and self.is_sparse: + _topk_indices = self.indexer(hidden_states, q_c, positions, + self.rotary_emb) + + hidden_states_shape_0 = 0 + if isinstance(hidden_states, tuple): + x_q, x_scale = hidden_states + hidden_states_shape_0 = x_q.shape[0] + else: + hidden_states_shape_0 = hidden_states.shape[0] + attn_out = self.mla_attn( + q, + kv_c_normed, + k_pe, + output_shape=(hidden_states_shape_0, + self.num_heads * self.v_head_dim)) + return self.o_proj(attn_out)[0] + + def forward_cuda(self, *args, **kwargs): + return self.forward_native(*args, **kwargs) \ No newline at end of file diff --git a/vllm_kunlun/ops/deep_gemm.py b/vllm_kunlun/ops/deep_gemm.py new file mode 100644 index 0000000..af0640f --- /dev/null +++ b/vllm_kunlun/ops/deep_gemm.py @@ -0,0 +1,114 @@ +import torch +import xtorch_ops + +def int8_mqa_logits( + q: torch.Tensor, + kv: tuple[torch.Tensor, torch.Tensor], + weights: torch.Tensor, + cu_seqlen_ks: torch.Tensor, + cu_seqlen_ke: torch.Tensor, +) -> torch.Tensor: + """Compute FP8 MQA logits for a single sequence without KV paging. + + Args: + q: Query tensor of shape [M, H, D]. Casted to + `torch.float8_e4m3fn` by caller. + kv: Tuple `(k_fp8, k_scales)` where `k_fp8` has shape [N, D] with + dtype `torch.float8_e4m3fn` and `k_scales` has shape [N] (or + [N, 1]) with dtype `torch.float32`. + weights: weights of shape [M, H], dtype `torch.float32`. + cu_seqlen_ks: Start indices (inclusive) for valid K per query position, + shape [M], dtype int32. + cu_seqlen_ke: End indices (exclusive) for valid K per query position, + shape [M], dtype int32. + + Returns: + Logits tensor of shape [M, N], dtype `torch.float32`. 
+ """ + logits = torch.empty((q.shape[0], kv[0].shape[0]), dtype=torch.float32, device=q.device) + context_q_lens_xpu = torch.tensor([0, q.shape[0]], dtype=torch.int32, device=cu_seqlen_ks.device) + context_k_lens_xpu = torch.tensor([0, kv[0].shape[0]], dtype=torch.int32, device=cu_seqlen_ks.device) + + xtorch_ops.I8_mqa_logits( + q=q, + fused_kv_cache=kv, + weights=weights, + context_q_lens=(context_q_lens_xpu.cpu(), context_q_lens_xpu), + context_k_lens=(context_k_lens_xpu.cpu(), context_k_lens_xpu), + logits=logits, + clean_logits=True, + use_xfa_boost=False, + ) + seq_len_kv = kv[0].shape[0] + # mask参考 https://github.com/vllm-project/vllm/blob/v0.11.0/tests/kernels/attention/test_deepgemm_attention.py 的_ref_fp8_mqa_logits函数的实现 + mask_lo = (torch.arange(0, seq_len_kv, device=cu_seqlen_ks.device)[None, :] + >= cu_seqlen_ks[:, None]) + mask_hi = (torch.arange(0, seq_len_kv, device=cu_seqlen_ke.device)[None, :] + < cu_seqlen_ke[:, None]) + mask = mask_lo & mask_hi + logits = logits.masked_fill(~mask, float("-inf")) + + return logits + +def int8_paged_mqa_logits( + q_fp8: torch.Tensor, + kv_cache_fp8: torch.Tensor, + weights: torch.Tensor, + context_lens: torch.Tensor, + context_lens_cpu: torch.Tensor, + block_tables: torch.Tensor, + schedule_metadata: torch.Tensor, + max_model_len: int, +) -> torch.Tensor: + """Compute FP8 MQA logits using paged KV-cache. + + Args: + q_fp8: Query tensor of shape [B, next_n, H, D]. Casted to + `torch.float8_e4m3fn` by caller. + kv_cache_fp8: Paged KV-cache in packed FP8+scale layout with shape + [num_blocks, block_size, 1, D+4], dtype `torch.uint8`. The last + 4 bytes per (block,pos) store the `float` dequant scale. + weights: Tensor of shape [B * next_n, H], dtype `torch.float32`. + context_lens: Tensor of shape [B], dtype int32; effective context length + for each batch element. + block_tables: Tensor of shape [B, max_blocks], dtype int32; maps logical + block indices to physical blocks in the paged cache. 
+ schedule_metadata: Returned by `get_paged_mqa_logits_metadata`; + used to distribute work across SMs. + max_model_len: Maximum sequence length used to size the logits output. + + Returns: + Logits tensor of shape [B * next_n, max_model_len], dtype + `torch.float32`. + """ + batch_size, next_n, _, D = q_fp8.shape + num_blocks, block_size, _, _ = kv_cache_fp8.shape + + kv_cache_fp8=kv_cache_fp8.view(num_blocks, -1) + k_val = kv_cache_fp8[:,:block_size*D].view(torch.int8) + k_val = k_val.view(-1,block_size, 1, D) + k_scale_list = [] + for block_tables_idx in range(block_tables.shape[0]): + k_scale_item = kv_cache_fp8[block_tables[block_tables_idx], block_size * + D:].view(-1, 4) + k_scale_list.append(k_scale_item) + k_scale = torch.cat(k_scale_list, dim=0).view(torch.float32).view(-1,max_model_len) + kv_cache = [k_val, k_scale] + + weights = weights.view(batch_size,next_n,-1) + + logits = torch.empty((batch_size, next_n, max_model_len), dtype=torch.float32, device=q_fp8.device) + + xtorch_ops.I8_paged_mqa_logits( + q=q_fp8, + fused_kv_cache=kv_cache, + weights=weights, + context_lens=[context_lens_cpu, context_lens], + block_table=block_tables, + max_context_len=max_model_len, + clean_logits=True, + out=logits, + use_xfa_boost=False + ) + logits = logits.view(-1, max_model_len) + return logits \ No newline at end of file diff --git a/vllm_kunlun/ops/fused_moe/layer.py b/vllm_kunlun/ops/fused_moe/layer.py index 772fe1a..953fbc0 100644 --- a/vllm_kunlun/ops/fused_moe/layer.py +++ b/vllm_kunlun/ops/fused_moe/layer.py @@ -1,37 +1,14 @@ """layer.py""" + +from contextlib import nullcontext +from typing import Callable, Optional, Union, get_args + import torch -import os -from typing import Callable, Optional +from vllm.model_executor.layers.quantization.base_config import QuantizationConfig +from vllm.model_executor.layers.fused_moe import FusedMoE +from vllm.model_executor.layers.fused_moe.layer import UnquantizedFusedMoEMethod -import vllm.envs as envs -from vllm.config 
import get_current_vllm_config -from vllm.forward_context import ForwardContext, get_forward_context -from vllm.distributed import get_ep_group -from vllm.distributed.eplb.eplb_state import EplbState - -from vllm.model_executor.layers.fused_moe import FusedMoE as VllmFusedMoE -from vllm.model_executor.layers.fused_moe import FusedMoEMethodBase as VllmFusedMoEMethodBase -from vllm.model_executor.layers.fused_moe.layer import ( - UnquantizedFusedMoEMethod as VllmUnquantizedFusedMoEMethod) -from vllm.model_executor.layers.quantization.base_config import ( - QuantizationConfig, QuantizeMethodBase) -from vllm.model_executor.layers.fused_moe.config import ( - FusedMoEConfig, FusedMoEParallelConfig) - -from vllm.model_executor.custom_op import CustomOp -from vllm.platforms import current_platform - -from vllm_kunlun.ops.quantization.compressed_tensors_moe import CompressedTensorsW8A8Int8MoEMethod - - -class FusedMoEMethodBase(VllmFusedMoEMethodBase): - """FusedMoEMethodBase""" - moe: FusedMoEConfig - -@CustomOp.register("vllm_kunlun_unquantized_fused_moe") -class UnquantizedFusedMoEMethod(VllmUnquantizedFusedMoEMethod): - """UnquantizedFusedMoEMethod""" - def apply( +def apply( self, layer: torch.nn.Module, x: torch.Tensor, @@ -45,6 +22,7 @@ class UnquantizedFusedMoEMethod(VllmUnquantizedFusedMoEMethod): expert_map: Optional[torch.Tensor] = None, custom_routing_function: Optional[Callable] = None, scoring_func: str = "softmax", + routed_scaling_factor: float = 1.0, e_score_correction_bias: Optional[torch.Tensor] = None, apply_router_weight_on_input: bool = False, activation: str = "silu", @@ -52,40 +30,12 @@ class UnquantizedFusedMoEMethod(VllmUnquantizedFusedMoEMethod): expert_load_view: Optional[torch.Tensor] = None, logical_to_physical_map: Optional[torch.Tensor] = None, logical_replica_count: Optional[torch.Tensor] = None, - linear_weights: Optional[torch.Tensor] = None, ) -> torch.Tensor: """apply""" if enable_eplb: raise NotImplementedError( "EPLB not supported for 
`UnquantizedFusedMoEMethod` yet.") - - return self.forward_kunlun(x=x, - layer=layer, - router_logits=router_logits, - top_k=top_k, - renormalize=renormalize, - use_grouped_topk=use_grouped_topk, - topk_group=topk_group, - num_expert_group=num_expert_group, - custom_routing_function=custom_routing_function, - linear_weights=linear_weights, - e_score_correction_bias=e_score_correction_bias) - - def forward_kunlun( - self, - layer: torch.nn.Module, - x: torch.Tensor, - use_grouped_topk: bool, - top_k: int, - router_logits: torch.Tensor, - linear_weights: torch.Tensor, - renormalize: bool, - topk_group: Optional[int] = None, - num_expert_group: Optional[int] = None, - custom_routing_function: Optional[Callable] = None, - scoring_func: str = "softmax", - e_score_correction_bias: Optional[torch.Tensor] = None - ) -> torch.Tensor: + """forward_kunlun""" from vllm_kunlun.ops._kunlun_ops import KunlunOps as ops if self.moe.use_ep: @@ -93,21 +43,18 @@ class UnquantizedFusedMoEMethod(VllmUnquantizedFusedMoEMethod): layer.w13_weight, layer.w2_weight, router_logits, - linear_weights, self.moe.ep_rank, top_k, renormalize=renormalize, inplace=True, use_grouped_topk=use_grouped_topk, num_expert_group=num_expert_group, - topk_group=topk_group - ) + topk_group=topk_group) else: return ops.fused_moe(x, layer.w13_weight, layer.w2_weight, router_logits, - linear_weights, self.moe.ep_rank, top_k, renormalize=renormalize, @@ -118,12 +65,13 @@ class UnquantizedFusedMoEMethod(VllmUnquantizedFusedMoEMethod): scoring_func=scoring_func, e_score_correction_bias=e_score_correction_bias, w1_bias = layer.w13_bias, - w2_bias = layer.w2_bias, - ) + w2_bias = layer.w2_bias) -class FusedMoE(VllmFusedMoE): - """FusedMoE""" - def __init__(self, +UnquantizedFusedMoEMethod.apply = apply + +class VllmFusedMoE(FusedMoE): + def __init__( + self, num_experts: int, # Global number of experts top_k: int, hidden_size: int, @@ -141,198 +89,47 @@ class FusedMoE(VllmFusedMoE): prefix: str = "", 
custom_routing_function: Optional[Callable] = None, scoring_func: str = "softmax", + routed_scaling_factor: float = 1.0, e_score_correction_bias: Optional[torch.Tensor] = None, apply_router_weight_on_input: bool = False, activation: str = "silu", enable_eplb: bool = False, num_redundant_experts: int = 0, - is_sequence_parallel=False, has_bias: bool = False, + is_sequence_parallel=False, + zero_expert_num: Optional[int] = 0, + zero_expert_type: Optional[str] = None, ): super().__init__( - num_experts=num_experts, # Global number of experts - top_k=top_k, - hidden_size=hidden_size, - intermediate_size=intermediate_size, - params_dtype=params_dtype, - reduce_results=reduce_results, - renormalize=renormalize, - use_grouped_topk=use_grouped_topk, - num_expert_group=num_expert_group, - topk_group=topk_group, - quant_config=quant_config, - tp_size=tp_size, - ep_size=ep_size, - dp_size=dp_size, - prefix=prefix, - custom_routing_function=custom_routing_function, - scoring_func=scoring_func, - e_score_correction_bias=e_score_correction_bias, - apply_router_weight_on_input=apply_router_weight_on_input, - activation=activation, - enable_eplb=enable_eplb, - num_redundant_experts=num_redundant_experts, - ) - - vllm_config = get_current_vllm_config() - if vllm_config.model_config is not None: - model_dtype = vllm_config.model_config.dtype - else: - # TODO (bnell): This is a hack to get test_mixtral_moe to work - # since model_config is not set in the pytest test. 
- model_dtype = params_dtype - - moe = FusedMoEConfig( - num_experts=self.global_num_experts, - experts_per_token=top_k, - hidden_dim=hidden_size, - num_local_experts=self.local_num_experts, - moe_parallel_config=self.moe_parallel_config, - in_dtype=model_dtype, - max_num_tokens=envs.VLLM_MOE_DP_CHUNK_SIZE, + num_experts=num_experts, # Global number of experts + top_k=top_k, + hidden_size=hidden_size, + intermediate_size=intermediate_size, + params_dtype=params_dtype, + reduce_results=reduce_results, + renormalize=renormalize, + use_grouped_topk=use_grouped_topk, + num_expert_group=num_expert_group, + topk_group=topk_group, + quant_config=quant_config, + tp_size=tp_size, + ep_size=ep_size, + dp_size=dp_size, + prefix=prefix, + custom_routing_function=custom_routing_function, + scoring_func=scoring_func, + routed_scaling_factor=routed_scaling_factor, + e_score_correction_bias=e_score_correction_bias, + apply_router_weight_on_input=apply_router_weight_on_input, + activation=activation, + enable_eplb=enable_eplb, + num_redundant_experts=num_redundant_experts, has_bias=has_bias, - # quant_config=quant_config, - ) - self.moe_config = moe - self.quant_config = quant_config + is_sequence_parallel=is_sequence_parallel, + zero_expert_num=zero_expert_num, + zero_expert_type=zero_expert_type) self.has_bias=has_bias self.register_parameter("w13_bias", None) self.register_parameter("w2_bias", None) - - # Note: get_quant_method will look at the layer's local_num_experts - # for heuristic purposes, so it must be initialized first. 
- quant_method: Optional[QuantizeMethodBase] = None - quant_method = (UnquantizedFusedMoEMethod(moe) if quant_config is None - else quant_config.get_quant_method(self, prefix)) - assert quant_method is not None - # assert isinstance(quant_method, FusedMoEMethodBase) - self.quant_method = quant_method - - if self.enable_eplb: - from vllm_kunlun.ops.quantization.fp8 import ( - Fp8MoEMethod) - if not isinstance(quant_method, Fp8MoEMethod): - # TODO: Add support for additional quantization methods. - # The implementation for other quantization methods does not - # contain essential differences, but the current quant API - # design causes duplicated work when extending to new - # quantization methods, so I'm leaving it for now. - # If you plan to add support for more quantization methods, - # please refer to the implementation in `Fp8MoEMethod`. - raise NotImplementedError("EPLB is only supported for FP8 " - "quantization for now.") - - moe_quant_params = { - "num_experts": self.local_num_experts, - "hidden_size": hidden_size, - "intermediate_size_per_partition": - self.intermediate_size_per_partition, - "params_dtype": params_dtype, - "weight_loader": self.weight_loader, - } - # need full intermediate size pre-sharding for WNA16 act order - if (self.quant_method.__class__.__name__ - in ("GPTQMarlinMoEMethod", - "CompressedTensorsWNA16MarlinMoEMethod", - "CompressedTensorsWNA16MoEMethod")): - moe_quant_params["intermediate_size_full"] = intermediate_size - - self.quant_method.create_weights(layer=self, **moe_quant_params) - - def forward(self, hidden_states: torch.Tensor, - router_logits: torch.Tensor = None, - linear_weights: torch.Tensor = None): - """forward""" - # TODO: Once the OOM issue for the TPU backend is resolved, we will - # switch to using the moe_forward custom op. 
- if current_platform.is_tpu(): - return self.forward_impl(hidden_states, router_logits) - else: - forward_context: ForwardContext = get_forward_context() - self = forward_context.no_compile_layers[self.layer_name] - assert self.quant_method is not None - return self.forward_impl(hidden_states, router_logits, linear_weights) - # return torch.ops.vllm.moe_forward(hidden_states, router_logits, - # self.layer_name) - - def forward_impl(self, hidden_states: torch.Tensor, - router_logits: torch.Tensor, - linear_weights: torch.Tensor = None): - """forward_impl""" - assert self.quant_method is not None - if (self.moe_parallel_config.use_pplx_kernels - or self.moe_parallel_config.use_deepep_ll_kernels): - return self.forward_impl_chunked(hidden_states, router_logits) - - do_naive_dispatch_combine: bool = ( - self.dp_size > 1 - and not self.moe_parallel_config.use_deepep_ht_kernels) - if do_naive_dispatch_combine: - hidden_states, router_logits = get_ep_group().dispatch( - hidden_states, router_logits) - - # Matrix multiply. 
- final_hidden_states = self.quant_method.apply( - layer=self, - x=hidden_states, - router_logits=router_logits, - top_k=self.top_k, - renormalize=self.renormalize, - use_grouped_topk=self.use_grouped_topk, - global_num_experts=self.global_num_experts, - expert_map=self.expert_map, - topk_group=self.topk_group, - num_expert_group=self.num_expert_group, - custom_routing_function=self.custom_routing_function, - scoring_func=self.scoring_func, - e_score_correction_bias=self.e_score_correction_bias, - activation=self.activation, - apply_router_weight_on_input=self.apply_router_weight_on_input, - enable_eplb=self.enable_eplb, - expert_load_view=self.expert_load_view, - logical_to_physical_map=self.logical_to_physical_map, - logical_replica_count=self.logical_replica_count, - linear_weights=linear_weights - ) - - if do_naive_dispatch_combine: - final_hidden_states = get_ep_group().combine(final_hidden_states) - - if self.reduce_results and (self.tp_size > 1 or self.ep_size > 1): - # Default set to False. (May have to add shared expert outputs. 
- final_hidden_states = self.maybe_all_reduce_tensor_model_parallel( - final_hidden_states) - - return final_hidden_states - @classmethod - def make_expert_params_mapping( - cls, - ckpt_gate_proj_name: str, - ckpt_down_proj_name: str, - ckpt_up_proj_name: str, - num_experts: int, - num_redundant_experts: int = 0) -> list[tuple[str, str, int, str]]: - - num_physical_experts = num_experts + num_redundant_experts - - # In the returned mapping: - # - `expert_id` is the physical expert id - # - `weight_name` contains the weight name of the logical expert - # So that we should map the expert id to logical in `weight_name` - physical_to_logical_map = \ - EplbState.build_initial_global_physical_to_logical_map( - num_experts, num_redundant_experts) - - return [ - # (param_name, weight_name, expert_id, shard_id) - ("experts.w13_" if weight_name - in [ckpt_gate_proj_name, ckpt_up_proj_name] else "experts.w2_", - f"experts.{physical_to_logical_map[expert_id]}.{weight_name}.", - expert_id, shard_id) for expert_id in range(num_physical_experts) - for shard_id, weight_name in [ - ("w1", ckpt_gate_proj_name), - ("w2", ckpt_down_proj_name), - ("w3", ckpt_up_proj_name), - ] - ] +FusedMoE=VllmFusedMoE \ No newline at end of file diff --git a/vllm_kunlun/ops/quantization/compressed_tensors_moe.py b/vllm_kunlun/ops/quantization/compressed_tensors_moe.py index 7b06bc5..7a73e8b 100644 --- a/vllm_kunlun/ops/quantization/compressed_tensors_moe.py +++ b/vllm_kunlun/ops/quantization/compressed_tensors_moe.py @@ -1,244 +1,169 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import enum +from enum import Enum +from typing import Callable, Optional, Union + import torch -from typing import Any, Literal, Optional, cast, Callable, Optional +from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors_moe import CompressedTensorsW8A8Int8MoEMethod -from compressed_tensors.config import (CompressionFormat, - 
SparsityCompressionConfig, - SparsityStructure) -from compressed_tensors.quantization import (ActivationOrdering, - QuantizationStrategy) -from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEMethodBase, - FusedMoeWeightScaleSupported) -from vllm.model_executor.layers.quantization.utils import replace_parameter -# TODO: import position will be changed after 0.9.0 -# vllm.model_executor.layers.fused_moe.fused_moe --> vllm.model_executor.layers.fused_moe +def klx_process_weights_after_loading(layer: torch.nn.Module) -> None: + """modify scale -> abs max""" + layer.w13_weight = torch.nn.Parameter(layer.w13_weight, requires_grad=False) + layer.w2_weight = torch.nn.Parameter(layer.w2_weight, requires_grad=False) + layer.w13_weight_scale = torch.nn.Parameter( + layer.w13_weight_scale.data * 127, requires_grad=False + ) + layer.w2_weight_scale = torch.nn.Parameter( + layer.w2_weight_scale.data * 127, requires_grad=False + ) -from vllm.model_executor.utils import set_weight_attrs -import re -import xtorch_ops +def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + klx_process_weights_after_loading(layer) +def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + router_logits: torch.Tensor, + top_k: int, + renormalize: bool, + use_grouped_topk: bool = False, + topk_group: Optional[int] = None, + num_expert_group: Optional[int] = None, + global_num_experts: int = -1, + expert_map: Optional[torch.Tensor] = None, + custom_routing_function: Optional[Callable] = None, + scoring_func: str = "softmax", + routed_scaling_factor: float = 1.0, + e_score_correction_bias: Optional[torch.Tensor] = None, + apply_router_weight_on_input: bool = False, + activation: str = "silu", + enable_eplb: bool = False, + expert_load_view: Optional[torch.Tensor] = None, + logical_to_physical_map: Optional[torch.Tensor] = None, + logical_replica_count: Optional[torch.Tensor] = None, +) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: + hidden_states = 
x + global_num_experts, up_gate_size, _ = layer.w13_weight.shape + M, N = hidden_states.shape + hidden_dim = layer.w2_weight.shape[1] + normed_score = torch.empty(M, + top_k, + dtype=torch.float32, + device=hidden_states.device) + topk_ids = torch.empty(M, + top_k, + dtype=torch.int32, + device=hidden_states.device) + num_blocks = 12 + block_statistic = torch.zeros( + num_blocks, global_num_experts, dtype=torch.int32, device=hidden_states.device + ) -from safetensors.torch import load_file as safe_load_file - -class CompressedTensorsMoEMethod(FusedMoEMethodBase): - - def get_moe_method(quant_config, layer) -> "CompressedTensorsMoEMethod": - tsm = getattr(quant_config, "target_scheme_map", None) or {} - linear_cfg = None - for k in ("Linear", "FusedMoE", "MoE", "Moe", "Experts"): - if k in tsm and isinstance(tsm[k], dict): - linear_cfg = tsm[k]; break - if not linear_cfg: - # print("target_scheme_map missing; fallback to INT8(W8A8) method") - return CompressedTensorsW8A8Int8MoEMethod(quant_config) - wq = linear_cfg.get("weights"); aq = linear_cfg.get("input_activations") - if not wq or not aq: - # print("incomplete scheme; fallback to INT8(W8A8)") - return CompressedTensorsW8A8Int8MoEMethod(quant_config) - # 其它分流按需;默认回落: - return CompressedTensorsW8A8Int8MoEMethod(quant_config) - -# copied from vllm 0.9.0 -class CompressedTensorsW8A8Int8MoEMethod(CompressedTensorsMoEMethod): - - def __init__( - self, - quant_config: "CompressedTensorsConfig" # type: ignore # noqa E501 - ): - self.quant_config = quant_config - - # 直接创建默认的量化配置字典,避免 QuantizationArgs 的验证问题 - # print("Creating default INT8 quantization config for MoE") - - # 创建默认的权重量化配置字典 - self.weight_quant = type('WeightQuant', (), { - 'type': 'int', - 'num_bits': 8, - 'strategy': 'channel', - 'group_size': 128, - 'symmetric': True, - 'dynamic': False, - 'actorder': 'none', - 'observer': None, - 'observer_kwargs': {}, - 'block_structure': None - })() - - # 创建默认的输入激活量化配置字典 - self.input_quant = type('InputQuant', (), { - 
'type': 'int', - 'num_bits': 8, - 'strategy': 'token', - 'group_size': 128, - 'symmetric': True, - 'dynamic': True, - 'actorder': 'none', - 'observer': None, - 'observer_kwargs': {}, - 'block_structure': None - })() - - # 修改比较方式,直接比较字符串 - per_channel = ( - self.weight_quant.strategy == "channel" - and self.input_quant.strategy == "token") - if not per_channel: - raise ValueError( - "For INT8 Fused MoE layers, we require channelwise, " - "dynamic per token quantization. Found " - f"{self.weight_quant}, {self.input_quant}") - - self.static_input_scales = not self.input_quant.dynamic - if self.static_input_scales: - raise ValueError( - "For INT8 Fused MoE layers, we require channelwise, " - "dynamic per token quantization. Found static input scales.") - - def create_weights1(self, layer: torch.nn.Module, num_experts: int, hidden_size: int, intermediate_size_per_partition: int, params_dtype: torch.dtype, **extra_weight_attrs): - # 权重先用浮点占位,便于从 ckpt 加载原始权重 - w13_weight = torch.nn.Parameter(torch.empty( - num_experts, - 2 * intermediate_size_per_partition, - hidden_size, - dtype=params_dtype), # 通常是 torch.bfloat16 - requires_grad=False) - layer.register_parameter("w13_weight", w13_weight) - set_weight_attrs(w13_weight, extra_weight_attrs) - - w2_weight = torch.nn.Parameter(torch.empty( - num_experts, - hidden_size, - intermediate_size_per_partition, - dtype=params_dtype), - requires_grad=False) - layer.register_parameter("w2_weight", w2_weight) - set_weight_attrs(w2_weight, extra_weight_attrs) - - # 通道 scale:float32 + 二维 [E, out](与 fused_moe/UT 对齐) - w13_weight_scale = torch.nn.Parameter( - torch.empty(num_experts, 2 * intermediate_size_per_partition, dtype=torch.float32), - requires_grad=False) - w2_weight_scale = torch.nn.Parameter( - torch.empty(num_experts, hidden_size, dtype=torch.float32), - requires_grad=False) - layer.register_parameter("w13_weight_scale", w13_weight_scale) - layer.register_parameter("w2_weight_scale", w2_weight_scale) - - # 输入 scale 动态计算即可 - 
layer.w13_input_scale = None - layer.w2_input_scale = None - - def create_weights(self, layer: torch.nn.Module, num_experts: int, hidden_size: int, intermediate_size_per_partition: int, params_dtype: torch.dtype, **extra_weight_attrs): - w13_weight = torch.nn.Parameter(torch.empty( - num_experts, - 2 * intermediate_size_per_partition, - hidden_size, - dtype=torch.int8), # 直接使用 int8 - requires_grad=False) - layer.register_parameter("w13_weight", w13_weight) - set_weight_attrs(w13_weight, extra_weight_attrs) - - w2_weight = torch.nn.Parameter(torch.empty( - num_experts, - hidden_size, - intermediate_size_per_partition, - dtype=torch.int8), # 直接使用 int8 - requires_grad=False) - layer.register_parameter("w2_weight", w2_weight) - set_weight_attrs(w2_weight, extra_weight_attrs) - - # 缩放因子 - w13_weight_scale = torch.nn.Parameter( - torch.empty(num_experts, 2 * intermediate_size_per_partition, dtype=torch.float32), - requires_grad=False) - w2_weight_scale = torch.nn.Parameter( - torch.empty(num_experts, hidden_size, dtype=torch.float32), - requires_grad=False) - layer.register_parameter("w13_weight_scale", w13_weight_scale) - layer.register_parameter("w2_weight_scale", w2_weight_scale) - - # 输入 scale 动态计算 - layer.w13_input_scale = None - layer.w2_input_scale = None - - @torch.no_grad() - def process_weights_after_loading(self, layer: torch.nn.Module) -> None: - return - #原始权重转 float32 做统计更稳健 - w13_f = layer.w13_weight.float() - w2_f = layer.w2_weight.float() - - # 每列(abs_max) -> per-column scale(out 维在 dim=1,列在 dim=-1) - qmax = 127.0 - w13_abs_max = torch.amax(torch.abs(w13_f), dim=-1) # [E, 2N] - w2_abs_max = torch.amax(torch.abs(w2_f), dim=-1) # [E, H] - - w13_scale_2d = torch.clamp(w13_abs_max, min=1e-6) / qmax # [E, 2N], float32 - w2_scale_2d = torch.clamp(w2_abs_max, min=1e-6) / qmax # [E, H], float32 - - # 量化:用 3D scale 广播,存回 2D scale - w13_scale_3d = w13_scale_2d.unsqueeze(-1) # [E, 2N, 1] - w2_scale_3d = w2_scale_2d.unsqueeze(-1) # [E, H, 1] - - w13_q = 
torch.round(w13_f / w13_scale_3d).clamp_(-128, 127).to(torch.int8) - w2_q = torch.round(w2_f / w2_scale_3d ).clamp_(-128, 127).to(torch.int8) - - # 可选:若你的 fused/kernel 期望 scale 预乘 127(与某些 UT 后端一致),打开下面两行: - w13_scale_2d = w13_scale_2d * 127.0 - w2_scale_2d = w2_scale_2d * 127.0 - - # 回写参数:权重 int8;scale 用 float32 + 2D - replace_parameter(layer, 'w13_weight', torch.nn.Parameter(w13_q, requires_grad=False)) - replace_parameter(layer, 'w2_weight', torch.nn.Parameter(w2_q, requires_grad=False)) - replace_parameter(layer, 'w13_weight_scale', - torch.nn.Parameter(w13_scale_2d.contiguous(), requires_grad=False)) - replace_parameter(layer, 'w2_weight_scale', - torch.nn.Parameter(w2_scale_2d.contiguous(), requires_grad=False)) - - # 简要检查 - print(f"w13: {w13_q.shape}, w13_s: {w13_scale_2d.shape}, w2: {w2_q.shape}, w2_s: {w2_scale_2d.shape}") - - def apply( - self, - layer: torch.nn.Module, - x: torch.Tensor, - router_logits: torch.Tensor, - top_k: int, - renormalize: bool, - use_grouped_topk: bool = False, - topk_group: Optional[int] = None, - num_expert_group: Optional[int] = None, - global_num_experts: int = -1, - expert_map: Optional[torch.Tensor] = None, - custom_routing_function: Optional[Callable] = None, - scoring_func: str = "softmax", - e_score_correction_bias: Optional[torch.Tensor] = None, - apply_router_weight_on_input: bool = False, - activation: str = "silu", - enable_eplb: bool = False, # 添加这个参数 - expert_load_view: Optional[torch.Tensor] = None, # 添加这个参数 - logical_to_physical_map: Optional[torch.Tensor] = None, # 添加这个参数 - logical_replica_count: Optional[torch.Tensor] = None, # 添加这个参数 - linear_weights: Optional[torch.Tensor] = None, # 添加这个参数 - ) -> torch.Tensor: - - output = torch.empty_like(x) - torch.ops._C.moe_ffn_per_token_block( - x=x, - inter_weight=layer.w13_weight, - inter_scale=layer.w13_weight_scale, - outer_weight=layer.w2_weight, - outer_scale=layer.w2_weight_scale, - top_k=top_k, - global_num_experts=global_num_experts, - 
linear_weights=linear_weights, - expert_map=expert_map, - activation=activation, - output=output, - use_expert_parallel=expert_map is not None, - ep_size=expert_map.size(0) if expert_map is not None else 1, - ep_rank=0, + router_logits = router_logits.float() + if scoring_func == "softmax": + torch.ops._C.moe_softmax_topk_norm( + x=router_logits, + normed_score=normed_score, + topk_index=topk_ids, + block_statistic=None, + stable=True) + elif scoring_func == "sigmoid": + torch.ops._C.moe_sigmoid_group_topk_norm( + x=router_logits, + norm_score=normed_score, + topk_index=topk_ids, + block_static=block_statistic, + bias=e_score_correction_bias, + n_group=num_expert_group, + topk_group=topk_group, + scale=routed_scaling_factor, ) - return output -print("[Monkey Patch Applied] >>> vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors_moe.CompressedTensorsMoEMethod \ - --> vllm_xpu.model_executor.layers.quantization.compressed_tensors_moe.py:CompressedTensorsMoEMethod") \ No newline at end of file + moe_expand = torch.empty((M * top_k, N), dtype=hidden_states.dtype, device=hidden_states.device) # [M, top_k, N], float + expert_m = torch.zeros(global_num_experts, dtype=torch.int32, device=hidden_states.device) # [E] + sorted_tokens_num_lod = torch.zeros(global_num_experts + 1, dtype=torch.int32, device=hidden_states.device) # [E+1] + sorted_tokens_idx = torch.zeros(M * top_k, dtype=torch.int32, device=hidden_states.device) + + torch.ops._C.gen_block_statistic(topk_ids,block_statistic) + + torch.ops._C.moe_pre_sorted( + x=hidden_states, + topk_index=topk_ids, + block_statistic=block_statistic, + moe_expand=moe_expand, + moe_index=sorted_tokens_idx, + expert_m=expert_m, + sorted_tokens_num_lod=sorted_tokens_num_lod) + + y = torch.empty(M,top_k, + layer.w13_weight.shape[1], + dtype=hidden_states.dtype, + device=hidden_states.device) + + moe_expand = moe_expand.view(M * top_k, hidden_dim) + + x_shape = moe_expand.shape + x_q = torch.empty(x_shape, 
dtype=torch.int8, device=moe_expand.device) + x_scale = torch.empty((x_shape[0], 1), dtype=torch.float32, device=moe_expand.device) + torch.ops._C.quant2d(moe_expand, x_q, x_scale, force_sdnn=True) + + torch.ops._C.moe_fc( + x=x_q, + x_perchannel_max=x_scale, + weight=layer.w13_weight, + w_perchannel_max=layer.w13_weight_scale, + sorted_tokens_num_lod=sorted_tokens_num_lod, + sorted_tokens_idx=sorted_tokens_idx, + moe_topk=top_k, + y=y, + topk_ids=topk_ids, + # sort_mode=False, + act=None) + + d = y.shape[-1] // 2 + output_shape = (y.shape[:-1] + (d, )) + out1 = torch.empty(output_shape, dtype=y.dtype, device=y.device) + torch.ops._C.silu_and_mul(out1, y) + + out = torch.empty(M,top_k, + layer.w2_weight.shape[1], + dtype=hidden_states.dtype, + device=hidden_states.device) + + out1 = out1.reshape(-1, out1.shape[-1]) + x_shape = out1.shape + x_q = torch.empty(x_shape, dtype=torch.int8, device=moe_expand.device) + x_scale = torch.empty((x_shape[0], 1), dtype=torch.float32, device=moe_expand.device) + torch.ops._C.quant2d(out1, x_q, x_scale, force_sdnn=True) + + torch.ops._C.moe_fc( + x=x_q, + x_perchannel_max=x_scale, + weight=layer.w2_weight, + w_perchannel_max=layer.w2_weight_scale, + sorted_tokens_num_lod=sorted_tokens_num_lod, + sorted_tokens_idx=sorted_tokens_idx, + moe_topk=top_k, + y=out, + topk_ids=topk_ids, + # sort_mode=False, + act=None) + + dequant_scale = torch.ones([M, top_k], dtype = torch.float32, device=out.device) + output = torch.empty([M, N], dtype=hidden_states.dtype, device=hidden_states.device) + sorted_tokens_idx = sorted_tokens_idx.view(M, top_k) + + torch.ops._C.moe_post( + x=out, + moe_index=sorted_tokens_idx, + normed_scale=normed_score, + dequant_scale=dequant_scale, + y=output + ) + return output + +CompressedTensorsW8A8Int8MoEMethod.process_weights_after_loading = process_weights_after_loading +CompressedTensorsW8A8Int8MoEMethod.apply = apply \ No newline at end of file diff --git a/vllm_kunlun/ops/quantization/kernels/__init__.py 
b/vllm_kunlun/ops/quantization/kernels/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/vllm_kunlun/ops/quantization/kernels/scaled_mm/__init__.py b/vllm_kunlun/ops/quantization/kernels/scaled_mm/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/vllm_kunlun/ops/quantization/kernels/scaled_mm/cutlass.py b/vllm_kunlun/ops/quantization/kernels/scaled_mm/cutlass.py new file mode 100644 index 0000000..25a3add --- /dev/null +++ b/vllm_kunlun/ops/quantization/kernels/scaled_mm/cutlass.py @@ -0,0 +1,122 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from typing import Optional + +import torch + +from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKernel import ScaledMMLinearLayerConfig +from vllm.model_executor.layers.quantization.kernels.scaled_mm.cutlass import CutlassScaledMMLinearKernel +from vllm.model_executor.layers.quantization.utils import replace_parameter +from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( + convert_to_channelwise) + +def can_implement_kunlun( + cls, c: ScaledMMLinearLayerConfig=None) -> tuple[bool, Optional[str]]: + return True, None + +def klx_process_weights_after_loading(layer: torch.nn.Module) -> None: + """modify scale -> abs max""" + layer.weight = torch.nn.Parameter(layer.weight.data, requires_grad=False) + layer.weight_scale = torch.nn.Parameter( + layer.weight_scale.data * 127, requires_grad=False) + +def process_weights_after_loading_kunlun(self, layer: torch.nn.Module) -> None: + # WEIGHT + # Cutlass kernels need transposed weight. + weight = getattr(layer, self.w_q_name) + replace_parameter( + layer, self.w_q_name, + torch.nn.Parameter(weight.t().data, requires_grad=False)) + + # WEIGHT SCALE + # Cutlass kernels support only per-tensor and per-channel. 
+ # If we have a fused module (QKV, MLP) with per tensor scales (thus N + # scales being passed to the kernel), convert to the per-channel case. + is_fused_module = len(layer.logical_widths) > 1 + weight_scale = getattr(layer, self.w_s_name) + if is_fused_module and not self.config.is_channelwise: + weight_scale = convert_to_channelwise(weight_scale, + layer.logical_widths) + replace_parameter( + layer, self.w_s_name, + torch.nn.Parameter(weight_scale.data, requires_grad=False)) + + # INPUT SCALE + if self.config.is_static_input_scheme: + input_scale = getattr(layer, self.i_s_name) + + if self.config.input_symmetric: + replace_parameter( + layer, self.i_s_name, + torch.nn.Parameter(input_scale.max(), requires_grad=False)) + setattr(layer, self.i_zp_name, None) + else: + input_zero_point = getattr(layer, self.i_zp_name) + + # reconstruct the ranges + int8_traits = torch.iinfo(torch.int8) + azps = input_zero_point.to(dtype=torch.int32) + range_max = (input_scale * (int8_traits.max - azps)).max() + range_min = (input_scale * (int8_traits.min - azps)).min() + + scale = (range_max - range_min) / (int8_traits.max - + int8_traits.min) + replace_parameter( + layer, self.i_s_name, + torch.nn.Parameter(scale, requires_grad=False)) + + # AZP loaded as int8 but used as int32 + azp = (int8_traits.min - + range_min / scale).to(dtype=torch.int32) + replace_parameter(layer, self.i_zp_name, + torch.nn.Parameter(azp, requires_grad=False)) + + else: + setattr(layer, self.i_s_name, None) + setattr(layer, self.i_zp_name, None) + + # azp_adj is the AZP adjustment term, used to account for weights. + # It does not depend on scales or azp, so it is the same for + # static and dynamic quantization. 
+ # For more details, see csrc/quantization/cutlass_w8a8/Epilogues.md + # https://github.com/vllm-project/vllm/blob/8d59dbb00044a588cab96bcdc028006ed922eb06/csrc/quantization/cutlass_w8a8/Epilogues.md + if not self.config.input_symmetric: + weight = getattr(layer, self.w_q_name) + azp_adj = weight.sum(dim=0, keepdim=True, dtype=torch.int32) + if self.config.is_static_input_scheme: + # cutlass_w8a8 requires azp to be folded into azp_adj + # in the per-tensor case + azp_adj = getattr(layer, self.i_zp_name) * azp_adj + setattr(layer, self.azp_adj_name, + torch.nn.Parameter(azp_adj, requires_grad=False)) + else: + setattr(layer, self.azp_adj_name, None) + + klx_process_weights_after_loading(layer) + +def apply_weights_kunlun(self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None) -> torch.Tensor: + x_q, x_scale, out = None, None, None + w_t_shape = layer.weight.T.shape + if isinstance(x, tuple): + x_q, x_scale = x + out = torch.empty((x_q.shape[0], w_t_shape[0]), + dtype=torch.bfloat16, + device=x_q.device) + else: + x_shape = x.shape + x_q = torch.empty(x_shape, dtype=torch.int8, device=x.device) + x_scale = torch.empty((x_shape[0], 1), dtype=torch.float32, device=x.device) + out = torch.empty((x_shape[0], w_t_shape[0]), + dtype=x.dtype, + device=x.device) + torch.ops._C.quant2d(x, x_q, x_scale, force_sdnn=True) + torch.ops._C.gemm_I8_I8_bf16_nt(x_q, x_scale, layer.weight.T.data, layer.weight_scale.data, out) + return out + +CutlassScaledMMLinearKernel.apply_weights = apply_weights_kunlun +CutlassScaledMMLinearKernel.can_implement = can_implement_kunlun +CutlassScaledMMLinearKernel.process_weights_after_loading = process_weights_after_loading_kunlun \ No newline at end of file diff --git a/vllm_kunlun/ops/rotary_embedding.py b/vllm_kunlun/ops/rotary_embedding.py index bbe32a6..a4c6289 100644 --- a/vllm_kunlun/ops/rotary_embedding.py +++ b/vllm_kunlun/ops/rotary_embedding.py @@ -19,7 +19,9 @@ import torch import xspeedgate_ops import 
os from vllm.model_executor.layers.rotary_embedding import ( - RotaryEmbedding, YaRNScalingRotaryEmbedding, DynamicNTKScalingRotaryEmbedding, MRotaryEmbedding) + RotaryEmbedding, YaRNScalingRotaryEmbedding, + DynamicNTKScalingRotaryEmbedding, MRotaryEmbedding, + DeepseekScalingRotaryEmbedding) from typing import Optional, Tuple def vllm_kunlun_compute_cos_sin_cache(self) -> torch.Tensor: @@ -143,12 +145,15 @@ def vllm_kunlun_mrope_forward_cuda( return query, key +DeepseekScalingRotaryEmbedding_forward = DeepseekScalingRotaryEmbedding.forward +DeepseekScalingRotaryEmbedding_forward_cuda = DeepseekScalingRotaryEmbedding.forward_cuda RotaryEmbedding.forward_cuda = vllm_kunlun_forward_cuda RotaryEmbedding.forward = vllm_kunlun_forward_cuda +DeepseekScalingRotaryEmbedding.forward = DeepseekScalingRotaryEmbedding_forward +DeepseekScalingRotaryEmbedding.forward_cuda = DeepseekScalingRotaryEmbedding_forward_cuda MRotaryEmbedding.forward_cuda = vllm_kunlun_mrope_forward_cuda MRotaryEmbedding.forward = vllm_kunlun_mrope_forward_cuda - def Split_Norm_Rope( qkv: torch.Tensor, cos_sin_cache: torch.Tensor, diff --git a/vllm_kunlun/platforms/kunlun.py b/vllm_kunlun/platforms/kunlun.py index d52f294..b79cd62 100644 --- a/vllm_kunlun/platforms/kunlun.py +++ b/vllm_kunlun/platforms/kunlun.py @@ -177,6 +177,8 @@ class KunlunPlatform(Platform): # if `VLLM_ATTENTION_BACKEND` is not set and we are using MLA, then # we default to FlashMLA backend, so we need to force the blocksize # here + use_sparse = hasattr(vllm_config.model_config.hf_config, + "index_topk") use_flashmla = (envs.VLLM_ATTENTION_BACKEND is None \ or envs.VLLM_ATTENTION_BACKEND == "FLASHMLA") from vllm.attention.ops.flashmla import is_flashmla_supported @@ -185,6 +187,11 @@ class KunlunPlatform(Platform): cache_config.block_size = 64 logger.info( "Forcing kv cache block size to 64 for FlashMLA backend.") + if use_sparse and cache_config.block_size != 64: + cache_config.block_size = 64 + logger.info( + "Forcing kv cache 
block size to 64 for FlashMLASparse " + "backend.") if (envs.VLLM_ALL2ALL_BACKEND == "deepep_high_throughput" and parallel_config.data_parallel_size > 1 @@ -224,6 +231,14 @@ class KunlunPlatform(Platform): Returns: str: Class name of the attention backend. """ + if use_mla: + if use_sparse: + logger.info_once("Using Sparse MLA backend on V1 engine.") + # return ("vllm.v1.attention.backends.mla.flashmla_sparse." + # "FlashMLASparseBackend") + return ("vllm_kunlun.v1.attention.backends.mla.flashmla_sparse." + "FlashMLASparseBackend") + return "vllm_kunlun.v1.attention.backends.mla.flashmla.FlashMLABackend" if use_v1: return "vllm_kunlun.v1.attention.backends.kunlun_attn.KunlunAttentionBackend" elif not use_mla: diff --git a/vllm_kunlun/v1/attention/backends/mla/__init__.py b/vllm_kunlun/v1/attention/backends/mla/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/vllm_kunlun/v1/attention/backends/mla/common.py b/vllm_kunlun/v1/attention/backends/mla/common.py new file mode 100644 index 0000000..d39b405 --- /dev/null +++ b/vllm_kunlun/v1/attention/backends/mla/common.py @@ -0,0 +1,1867 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" +# MLA Common Components + +This file implements common components for MLA implementations. + +First we define: + +Sq as Q sequence length +Skv as KV sequence length + +MLA has two possible ways of computing, a data-movement friendly approach and a +compute friendly approach, we generally want to use the compute friendly +approach for "prefill" (i.e. the ratio Sq / Skv is "small", is near 1) +and the data-movement friendly approach for "decode" (i.e. the ratio +Sq / Skv is "large"). + +NOTE what we deem small and large is currently determined by if its labelled +prefill or decode by the scheduler, but this is something we should probably +tune. 
+ +Main reference: DeepseekV2 paper, and FlashInfer Implementation +(https://arxiv.org/abs/2405.04434 and https://github.com/flashinfer-ai/flashinfer/pull/551). + +Deepseek's MLA attention works the following way: +* Use a single latent vector to represent the per-token entry of the KV cache. +* For decode (i.e. the memory friendly approach) the attention "simulates" a +multi-head attention, while the compute is similar to multi-query attention. + +Below is example of both paths assuming batchsize = 1 + +## More Extent Definitions: + +C Context length, `Skv - Sq` +H hidden size +N number of attention heads +Lq latent dimension for Q 1536 in DSV3 +Lkv latent dimension for K/V 512 in DSV3 +P nope dimension, no rope. 128 in DSV3 +R rope dimension, goes through rope. 64 in DSV3 +V V head dim. 128 in DSV3 + +## Vector/Matrix Definitions + +h_t hidden states (input to attention) shape [Sq, H] +q_c latent/compressed Q shape [Sq, Lq] +q_nope uncompressed Q (no-rope) shape [Sq, N, P] +q_pe uncompressed Q (rope) shape [Sq, N, R] +kv_c latent/compressed KV shape [Skv, Lkv] +k_pe decoupled k position embeddings shape [Skv, R] +new_kv_c new kv_c from current iter shape [Sq, Lkv] +new_k_pe new k_pe from current iter shape [Sq, R] +cache_kv_c cached k_c from previous iters shape [C, Lkv] +cache_k_pe cached k_pe from previous iters shape [C, R] +W_DQ project h_t to q_c shape [H, Lq] +W_UQ project q_c to q_nope shape [Lq, N * P] +W_QR project q_c to q_pe shape [Lq, N * R] +W_DKV project h_t to kv_c shape [H, Lkv] +W_UK project kv_c to k_nope shape [Lkv, N, P] +W_KR project h_t to k_pe shape [H, R] +W_UV project kv_c to v shape [Lkv, N, V] +W_O project v to h_t shape [N * V, H] + + +## Compute Friendly Approach (i.e. 
"_forward_prefill"): + +q_c = h_t @ W_DQ +q_nope = (q_c @ W_UQ).view(Sq, N, P) +q_pe = RoPE(q_c @ W_QR).view(Sq, N, R) +new_kv_c = h_t @ W_DKV +new_k_pe = RoPE(h_t @ W_KR) +kv_c = torch.cat([new_kv_c, cache_kv_c], dim=0) +k_pe = torch.cat([new_k_pe, cache_k_pe], dim=0) +k_nope = (kv_c @ W_UK.view(Lkv, N * P)).view(Skv, N, P) +v = (kv_c @ W_UV.view(Lkv, N * V)).view(Skv, N, V) + +// MHA with QK headdim = P + R +// V headdim = V +// spda_o shape [Sq, N, V] +spda_o = scaled_dot_product_attention( + torch.cat([q_nope, q_pe], dim=-1), + torch.cat([k_nope, k_pe.unsqueeze(1).expand(-1, N, -1)], dim=-1), + v +) +return spda_o @ W_O + +NOTE: in the actual code, + `kv_b_proj` is [W_UK; W_UV] concatenated per head + `q_b_proj` is [W_UQ; W_QR] concatenated per head + `out_proj` is W_O + + +## Data-Movement Friendly Approach (i.e. "_forward_decode"): + +Runtime +q_c = h_t @ W_DQ +q_nope = (q_c @ W_UQ).view(-1, N, P) +ql_nope = einsum("snh,lnh->snl", q, W_UK) +q_pe = RoPE(q_c @ W_QR).view(Sq, N, R) +new_kv_c = h_t @ W_DKV +new_k_pe = RoPE(h_t @ W_KR) +kv_c = torch.cat([new_kv_c, cache_kv_c], dim=0) +k_pe = torch.cat([new_k_pe, cache_k_pe], dim=0) + +// MQA with QK headdim = Lkv + R +// V headdim = Lkv +// spda_o shape [Sq, N, Lkv] +// NOTE: this is less compute-friendly since Lkv > P +// but is more data-movement friendly since its MQA vs MHA +spda_o = scaled_dot_product_attention( + torch.cat([ql_nope, q_pe], dim=-1), + torch.cat([kv_c, k_pe], dim=-1), + kv_c +) + +o = einsum("snl,lnv->snv", spda_o.reshape(-1, N, Lkv), W_UV) +return o.view(-1, N * V) @ self.num_heads @ W_O + + +## Chunked Prefill + +For chunked prefill we want to use the compute friendly algorithm. We are +assuming sufficiently large Sq / Skv ratio, in the future may want to switch to +the data-movement friendly approach if the chunk (i.e. `Sq`) is small. 
+ +However, the compute-friendly approach can potentially run out of memory if Skv +is large due to: `k_nope = (kv_c @ W_UK).view(Skv, N, P)` + +To mitigate this, we chunk the computation of attention with respect to the +current context (i.e. `cache_kv_c` and `cache_k_pe`) so that we can use a +fixed workspace size. + +The chunked prefill approach is as follows: + +MCC Max chunk of context to process per iter, computed dynamically, + used to bound the memory usage + +q_c = h_t @ W_DQ +q_nope = (q_c @ W_UQ).view(Sq, N, P) +q_pe = RoPE(q_c @ W_QR).view(Sq, N, R) +new_kv_c = h_t @ W_DKV +new_k_pe = RoPE(h_t @ W_KR) +new_k_nope = (new_kv_c @ W_UK.view(Lkv, N * P)).view(Sq, N, P) +new_v = (new_kv_c @ W_UV.view(Lkv, N * V)).view(Sq, N, V) + +// MHA between queries and new KV +// with QK headdim = P + R +// V headdim = V +// curr_o shape [Sq, N, V] +// curr_lse shape [N, Sq], this is just the order FA returns +curr_o, curr_lse = scaled_dot_product_attention( + torch.cat([q_nope, q_pe], dim=-1), + torch.cat([new_k_nope, new_k_pe.unsqueeze(1).expand(-1, N, -1)], dim=-1), + new_v, + causal=True, + return_softmax_lse=True +) + +// Compute attention with the already existing context +for chunk_idx in range(cdiv(C, MCC)): + chunk_start = chunk_idx * MCC + chunk_end = min(chunk_start + MCC, C) + Sc = chunk_end - chunk_start + cache_kv_c_chunk = cache_kv_c[chunk_start:chunk_end] + cache_k_pe_chunk = cache_k_pe[chunk_start:chunk_end] + cache_k_nope_chunk = (cache_kv_c_chunk @ W_UK).view(-1, N, P) + cache_v_chunk = (cache_kv_c_chunk @ W_UV).view(-1, N, V) + + chunk_o, chunk_lse = scaled_dot_product_attention( + torch.cat([q_nope, q_pe], dim=-1), + torch.cat([cache_k_nope_chunk, + cache_k_pe_chunk.unsqueeze(1).expand(-1, N, -1)], + dim=-1), + cache_v_chunk, + causal=False, + return_softmax_lse=True + ) + + curr_o, curr_lse = merge_attn_states( + suffix_output=curr_o, + suffix_lse=curr_lse, + prefix_output=chunk_o, + prefix_lse=chunk_lse, + ) + +return curr_o @ W_O +""" + +import
functools +from abc import abstractmethod +from dataclasses import dataclass, field +from typing import Generic, Optional, TypeVar, Union + +import torch +from tqdm import tqdm + +import vllm.envs as envs +from vllm import _custom_ops as ops +from vllm.attention.backends.abstract import (AttentionBackend, AttentionLayer, + AttentionMetadata, + MLAAttentionImpl) +from vllm.attention.backends.utils import get_mla_dims +from vllm.attention.ops.common import cp_lse_ag_out_rs +from vllm.attention.ops.merge_attn_states import merge_attn_states +from vllm.attention.utils.fa_utils import get_flash_attn_version +from vllm.config import VllmConfig, get_current_vllm_config +from vllm.distributed.parallel_state import get_dcp_group, is_global_first_rank +from vllm.distributed import get_tp_group +from vllm.logger import init_logger +from vllm.model_executor.layers.linear import (ColumnParallelLinear, + LinearBase, + UnquantizedLinearMethod) +from vllm.platforms import current_platform +from vllm.utils import cdiv, round_down +from vllm.utils.flashinfer import has_nvidia_artifactory +from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder, + CommonAttentionMetadata, + get_per_layer_parameters, + infer_global_hyperparameters, + split_decodes_and_prefills) +from vllm.v1.kv_cache_interface import AttentionSpec +import xtorch_ops + +try: + from vllm.vllm_flash_attn import flash_attn_varlen_func + is_vllm_fa = True +except ImportError: + # For rocm use upstream flash attention + if current_platform.is_rocm(): + from flash_attn import flash_attn_varlen_func + is_vllm_fa = False + +try: + from flashinfer import BatchPrefillWithRaggedKVCacheWrapper + from flashinfer.prefill import ( # noqa: F401 + cudnn_batch_prefill_with_kv_cache) + flashinfer_available = True +except ImportError: + flashinfer_available = False + + +def is_rocm_aiter_fp8bmm_enabled() -> bool: + return current_platform.is_rocm() \ + and envs.VLLM_ROCM_USE_AITER_FP8BMM \ + and envs.VLLM_ROCM_USE_AITER + + 
+if is_rocm_aiter_fp8bmm_enabled(): + from aiter.ops.triton.batched_gemm_a8w8_a_per_token_group_prequant_w_per_batched_tensor_quant import ( # noqa: E501 # isort: skip + batched_gemm_a8w8_a_per_token_group_prequant_w_per_batched_tensor_quant + as aiter_triton_fp8_bmm) + + def dynamic_per_batched_tensor_quant( + x: torch.Tensor, dtype: torch.dtype = torch.float8_e4m3fn): + DTYPE_MAX = torch.finfo(dtype).max + min_val, max_val = x.aminmax() + amax = torch.maximum(min_val.abs(), max_val.abs()).clamp(min=1e-10) + scale = DTYPE_MAX / amax + x_scl_sat = (x * scale).clamp(min=-DTYPE_MAX, max=DTYPE_MAX) + return x_scl_sat.to(dtype).contiguous(), scale.float().reciprocal() + + +logger = init_logger(__name__) + +CUDNN_WORKSPACE_SIZE = 12800 + + +class MLACommonBackend(AttentionBackend): + + accept_output_buffer: bool = True + + @staticmethod + def get_name() -> str: + return "TRITON_MLA" + + @staticmethod + def get_metadata_cls() -> type["AttentionMetadata"]: + return MLACommonMetadata + + @staticmethod + def get_builder_cls() -> type["MLACommonMetadataBuilder"]: + return MLACommonMetadataBuilder + + @staticmethod + def get_kv_cache_shape( + num_blocks: int, + block_size: int, + num_kv_heads: int, # assumed to be 1 for MLA + head_size: int, + cache_dtype_str: str = "auto", + ) -> tuple[int, ...]: + return (num_blocks, block_size, head_size) + + @classmethod + def get_supported_dtypes(cls) -> list[torch.dtype]: + return [torch.float16, torch.bfloat16] + + @classmethod + def get_supported_head_sizes(cls) -> list[int]: + return [576] + + @classmethod + def validate_head_size(cls, head_size: int) -> None: + supported_head_sizes = cls.get_supported_head_sizes() + if head_size not in supported_head_sizes: + attn_type = cls.__name__.removesuffix("Backend") + raise ValueError( + f"Head size {head_size} is not supported by {attn_type}. " + f"Supported head sizes are: {supported_head_sizes}. 
" + "Set VLLM_ATTENTION_BACKEND=FLEX_ATTENTION to use " + "FlexAttention backend which supports all head sizes.") + + +@dataclass +class MLACommonPrefillMetadata: + """ Prefill Specific Metadata """ + + @dataclass + class ChunkedContextMetadata: + # New for MLA (compared to FlashAttention) + # For handling chunked prefill + cu_seq_lens: torch.Tensor + starts: torch.Tensor + seq_tot: list[int] + max_seq_lens: list[int] + seq_lens: torch.Tensor + workspace: torch.Tensor + + # for mla DCP + cp_chunk_seq_lens: Optional[list[list[int]]] = None + origin_context_lens: Optional[list[int]] = None + cp_cu_seq_lens: Optional[torch.Tensor] = None + chunk_size: Optional[int] = None + cu_seq_lens_lst: Optional[list[list[int]]] = None + + block_table: torch.Tensor + query_start_loc: torch.Tensor + query_start_loc_cpu: torch.Tensor + max_query_len: int + chunked_context: Optional[ChunkedContextMetadata] = None + + +@dataclass +class FlashInferPrefillMetadata(MLACommonPrefillMetadata): + prefill_main: Optional['BatchPrefillWithRaggedKVCacheWrapper'] = None + prefill_chunks: list['BatchPrefillWithRaggedKVCacheWrapper'] = field( + default_factory=list) + + +@dataclass +class CudnnPrefillMetadata(MLACommonPrefillMetadata): + + class ChunkedContextMetadata( + MLACommonPrefillMetadata.ChunkedContextMetadata): + seq_lens: torch.Tensor + + query_seq_lens: Optional[torch.Tensor] = None + cudnn_workspace: Optional[torch.Tensor] = None + + +@dataclass +class MLACommonDecodeMetadata: + block_table: torch.Tensor + seq_lens: torch.Tensor + + +D = TypeVar("D", bound=MLACommonDecodeMetadata) + + +@dataclass +class MLACommonMetadata(Generic[D]): + """Metadata for MLACommon. + + NOTE: Please read the comment at the top of the file before trying to + understand this class + """ + # NOTE(sang): Definition of context_len, query_len, and seq_len. 
+ # |---------- N-1 iteration --------| + # |---------------- N iteration ---------------------| + # |- tokenA -|......................|-- newTokens ---| + # |---------- context_len ----------| + # |-------------------- seq_len ---------------------| + # |-- query_len ---| + + num_reqs: int + max_query_len: int + max_seq_len: int + + num_actual_tokens: int # Number of tokens excluding padding. + query_start_loc: torch.Tensor + query_start_loc_cpu: torch.Tensor + slot_mapping: torch.Tensor + + # New for MLA (compared to FlashAttention) + # For handling prefill decode split + num_decodes: int + num_decode_tokens: int + num_prefills: int + + # The dimension of the attention heads + head_dim: Optional[int] = None + + decode: Optional[D] = None + prefill: Optional[Union[MLACommonPrefillMetadata, + FlashInferPrefillMetadata, + CudnnPrefillMetadata]] = None + + def __post_init__(self): + if self.head_dim is not None: + MLACommonBackend.validate_head_size(self.head_dim) + + +M = TypeVar("M", bound=MLACommonMetadata) +A = TypeVar("A") + + +def use_flashinfer_prefill() -> bool: + # For blackwell default to flashinfer prefill if it's available since + # it is faster than FA2. + return (not envs.VLLM_DISABLE_FLASHINFER_PREFILL and flashinfer_available + and not envs.VLLM_USE_CUDNN_PREFILL + and current_platform.is_device_capability(100)) + + +def use_cudnn_prefill() -> bool: + return (flashinfer_available and envs.VLLM_USE_CUDNN_PREFILL + and current_platform.is_device_capability(100) + and has_nvidia_artifactory()) + + +# Currently 394MB, this can be tuned based on GEMM sizes used. 
+# Chosen to be the same as sglang: +# https://github.com/sgl-project/sglang/blob/766392c6bda2558b61ce6d1c1bfd8081a549e1f1/python/sglang/global_config.py#L37 +FLASHINFER_WORKSPACE_BUFFER_SIZE = 394 * 1024 * 1024 + + +class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]): + """ + NOTE: Please read the comment at the top of the file before trying to + understand this class + """ + reorder_batch_threshold: int = 1 + + @staticmethod + def determine_chunked_prefill_workspace_size( + vllm_config: VllmConfig) -> int: + scheduler_config = vllm_config.scheduler_config + cache_config = vllm_config.cache_config + model_config = vllm_config.model_config + + chunked_prefill_workspace_size = min( + # Try for 8 full length request or at least 4 pages per-request + max(8 * model_config.max_model_len, + 4 * scheduler_config.max_num_seqs * cache_config.block_size), + # For long-context models try not to over-allocate limiting + # kv-cache space, limiting it to 64k tokens, + # which would result in the workspace being: + # 2*(576)*(64*1024) = 144mb + # (assuming 576 MLA head dim, and fp16) + # which would result in up-projected context being + # 2*(192*128)*(64*1024) = 3gb + # (assuming 192 QK head dim, 128 heads, and fp16) + 64 * 1024) + + # Enforce that we enough for at least 1 page per request + chunked_prefill_workspace_size = max( + chunked_prefill_workspace_size, + scheduler_config.max_num_seqs * cache_config.block_size) + + return chunked_prefill_workspace_size + + def __init__(self, + kv_cache_spec: AttentionSpec, + layer_names: list[str], + vllm_config: VllmConfig, + device: torch.device, + metadata_cls: Optional[type[M]] = None): + self.metadata_cls = metadata_cls \ + if metadata_cls is not None else MLACommonMetadata + self.kv_cache_spec = kv_cache_spec + scheduler_config = vllm_config.scheduler_config + self.model_config = vllm_config.model_config + parallel_config = vllm_config.parallel_config + self.compilation_config = vllm_config.compilation_config + self.device 
= device + + self.num_heads = self.model_config.get_num_attention_heads( + parallel_config) + self.mla_dims = get_mla_dims(self.model_config) + self.aot_schedule = current_platform.is_cuda() + try: + self.dcp_world_size = get_dcp_group().world_size + self.dcp_rank = get_dcp_group().rank_in_group + except AssertionError: + # DCP might not be initialized in testing + self.dcp_world_size = 1 + self.dcp_rank = 0 + + # Don't try to access the runner on AMD + if self.aot_schedule: + self.page_size = self.kv_cache_spec.block_size + + self.chunked_prefill_workspace_size = \ + self.determine_chunked_prefill_workspace_size(vllm_config) + + if self.dcp_world_size > 1: + # Note(hc): The local kvcache is incomplete when DCP is triggered, + # an additional kvcache allgather across the DCP group is therefore + # required, so the workspace has to be enlarged by 1/DCP relative + # to the original TP allocation. + assert self.chunked_prefill_workspace_size % \ + self.dcp_world_size == 0 + self.chunked_prefill_workspace = torch.empty( + (self.chunked_prefill_workspace_size + + self.chunked_prefill_workspace_size // self.dcp_world_size, + self.model_config.get_head_size()), + dtype=self.model_config.dtype, + device=device, + ) + else: + self.chunked_prefill_workspace = torch.empty( + (self.chunked_prefill_workspace_size, + self.model_config.get_head_size()), + dtype=self.model_config.dtype, + device=device, + ) + + self._use_cudnn_prefill = use_cudnn_prefill() + self._use_fi_prefill = use_flashinfer_prefill() + self.prefill_metadata_cls = ( + FlashInferPrefillMetadata + if self._use_fi_prefill else CudnnPrefillMetadata + if self._use_cudnn_prefill else MLACommonPrefillMetadata) + + if self._use_fi_prefill: + self._workspace_buffer = torch.empty( + FLASHINFER_WORKSPACE_BUFFER_SIZE, + dtype=torch.uint8, + device=device) + + self._fi_prefill_main: Optional[ + BatchPrefillWithRaggedKVCacheWrapper] = None + self._fi_prefill_chunks: list[ + BatchPrefillWithRaggedKVCacheWrapper] = [] + + 
self._global_hyperparameters = infer_global_hyperparameters( + get_per_layer_parameters(vllm_config, layer_names, + MLACommonImpl)) + + if self._use_cudnn_prefill: + self.cudnn_workspace = torch.empty( + CUDNN_WORKSPACE_SIZE * scheduler_config.max_num_seqs, + dtype=torch.int8, + device=device, + ) + + def _build_fi_prefill_wrappers(self, prefill: FlashInferPrefillMetadata): + qo_indptr = prefill.query_start_loc + + has_context = False + if prefill.chunked_context is not None: + chunked_context = prefill.chunked_context + has_context = True + + if self._fi_prefill_main is None: + self._fi_prefill_main = BatchPrefillWithRaggedKVCacheWrapper( + self._workspace_buffer, "NHD", backend="cutlass") + + if has_context: + num_chunks = chunked_context.cu_seq_lens.shape[0] + # Allocate more prefill chunk wrappers if needed + if len(self._fi_prefill_chunks) < num_chunks: + for _ in range(len(self._fi_prefill_chunks), num_chunks): + self._fi_prefill_chunks.append( + BatchPrefillWithRaggedKVCacheWrapper( + self._workspace_buffer, "NHD", backend="cutlass")) + assert num_chunks <= len(self._fi_prefill_chunks) + + # In MLA, the non-latent num_qo_heads == num_kv_heads + num_qo_heads = self.num_heads + num_kv_heads = num_qo_heads + + # Sanity: Verify that num_kv_heads == 1 since it is latent space + assert self.kv_cache_spec.num_kv_heads == 1 + + # Get non-latent head_dim_qk and head_dim_vo + head_dim_qk = (self.mla_dims.qk_nope_head_dim + + self.mla_dims.qk_rope_head_dim) + head_dim_vo = self.mla_dims.v_head_dim + + # For main run, qo_indptr == kv_indptr + kv_indptr = qo_indptr.clone() + + # Prepare main prefill + self._fi_prefill_main.plan( + qo_indptr=qo_indptr, + kv_indptr=kv_indptr, + num_qo_heads=num_qo_heads, + num_kv_heads=num_kv_heads, + head_dim_qk=head_dim_qk, + head_dim_vo=head_dim_vo, + causal=True, # This is main run + sm_scale=self._global_hyperparameters.sm_scale, + window_left=self._global_hyperparameters.window_left, + 
logits_soft_cap=self._global_hyperparameters.logits_soft_cap, + q_data_type=self.model_config.dtype, + ) + + # Prepare context prefills + if has_context: + for i in range(num_chunks): + kv_indptr_chunk = chunked_context.cu_seq_lens[i] + + self._fi_prefill_chunks[i].plan( + qo_indptr=qo_indptr, + kv_indptr=kv_indptr_chunk, + num_qo_heads=num_qo_heads, + num_kv_heads=num_kv_heads, + head_dim_qk=head_dim_qk, + head_dim_vo=head_dim_vo, + causal=False, # This is context run + sm_scale=self._global_hyperparameters.sm_scale, + window_left=self._global_hyperparameters.window_left, + logits_soft_cap=self._global_hyperparameters. + logits_soft_cap, + q_data_type=self.model_config.dtype, + ) + + prefill.prefill_main = self._fi_prefill_main + prefill.prefill_chunks = self._fi_prefill_chunks + + def _build_decode(self, block_table_tensor: torch.Tensor, + seq_lens_cpu: torch.Tensor, + seq_lens_device: torch.Tensor, + query_start_loc_cpu: torch.Tensor, + query_start_loc_device: torch.Tensor, + num_decode_tokens: int) -> MLACommonDecodeMetadata: + return MLACommonDecodeMetadata( + block_table=block_table_tensor, + seq_lens=seq_lens_device, + ) + + def build_for_cudagraph_capture( + self, common_attn_metadata: CommonAttentionMetadata) -> M: + """ + This method builds the metadata for full cudagraph capture. + Currently, only decode is supported for full cudagraphs with MLA. + """ + m = common_attn_metadata + assert m.num_reqs <= (m.num_actual_tokens * + self.reorder_batch_threshold), \ + "MLA only supports decode-only full CUDAGraph capture. " \ + "Make sure all cudagraph capture sizes <= max_num_seq." 
+ + assert m.max_query_len <= self.reorder_batch_threshold # decode only + + return self.build(0, m) + + def build(self, + common_prefix_len: int, + common_attn_metadata: CommonAttentionMetadata, + fast_build: bool = False) -> M: + num_reqs = common_attn_metadata.num_reqs + num_tokens = common_attn_metadata.num_actual_tokens + max_query_len = common_attn_metadata.max_query_len + max_seq_len = common_attn_metadata.max_seq_len + + # Note(simon): be careful about the CPU <> GPU memory movement in this + # function. We should avoid GPU -> CPU sync as much as possible because + # it blocks on all previous kernels. + device = self.device + block_table_tensor = common_attn_metadata.block_table_tensor + slot_mapping = common_attn_metadata.slot_mapping + + query_start_loc = common_attn_metadata.query_start_loc + query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu + seq_lens = common_attn_metadata.seq_lens + seq_lens_cpu = common_attn_metadata.seq_lens_cpu + + query_seq_lens_cpu = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1] + + num_computed_tokens_cpu = (common_attn_metadata.seq_lens_cpu - + query_seq_lens_cpu) + + num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = \ + split_decodes_and_prefills(common_attn_metadata, + decode_threshold=self.reorder_batch_threshold) + + # Note(hc): update seq_lens of decode reqs under DCP. + if self.dcp_world_size > 1: + seq_lens[:num_decodes] = seq_lens[:num_decodes] \ + // self.dcp_world_size + (self.dcp_rank <= \ + (seq_lens[:num_decodes] - 1) % self.dcp_world_size) + + assert num_decodes + num_prefills == num_reqs + assert num_decode_tokens + num_prefill_tokens == num_tokens + + prefill_metadata = None + if num_prefills > 0: + reqs_start = num_decodes # prefill_start + + context_lens_cpu = num_computed_tokens_cpu[reqs_start:num_reqs] + # Note(hc): The context lengths in the perspective of dcp rank0. 
+ cp_context_lens_cpu = torch.ceil(context_lens_cpu.float() / + self.dcp_world_size).int() + origin_context_lens = context_lens_cpu.tolist() + max_context_len_cpu = context_lens_cpu.max().item() + num_prefills_with_context_cpu = (context_lens_cpu > 0).sum().item() + prefill_query_start_loc = query_start_loc[ + reqs_start:] - query_start_loc[reqs_start] + + chunked_context_metadata = None + if max_context_len_cpu > 0: + # NOTE: it is recommend you read the `Chunked Prefill` section + # in the comment at the top of the file before trying to + # understand the following code + + # currently we allocate an equal amount of workspace for each + # prefill in the batch, we could probably use a more advanced + # algorithm here and allocate more workspace to prefills with + # longer context lengths + max_context_chunk = (self.chunked_prefill_workspace_size // + num_prefills_with_context_cpu) + + if self.aot_schedule: + # align max_context_chunk to page_size by rounding down, + # currently the `gather_and_maybe_dequant_cache` kernel + # cannot handle `context_chunk_starts` that are not aligned + # to page_size + max_context_chunk = round_down(max_context_chunk, + self.page_size) + + assert max_context_chunk > 0 + num_chunks = cdiv(max_context_len_cpu, max_context_chunk) + + # if `max_context_chunk = 256`, `num_chunks = 3`, and + # `num_prefills_with_context = 4`, create a tensor that looks + # like + # [[0, 0, 0, 0], [256, 256, 256, 256], [512, 512, 512, 512]] + # Note(simon): this is done in CPU because of downstream's + # of `to_list`. 
+ chunk_starts = \ + torch.arange(num_chunks, dtype=torch.int32) \ + .unsqueeze(1).expand(-1, num_prefills) \ + * max_context_chunk + chunk_ends = torch.min(context_lens_cpu.unsqueeze(0), + chunk_starts + max_context_chunk) + chunk_seq_lens = (chunk_ends - chunk_starts).clamp(min=0) + + cu_seq_lens_cpu = torch.zeros(num_chunks, + num_prefills + 1, + dtype=torch.int32, + pin_memory=True) + torch.cumsum(chunk_seq_lens, + dim=1, + out=cu_seq_lens_cpu[:, 1:], + dtype=torch.int32) + + if self.dcp_world_size > 1: + # Note(hc): The above max_context_chunk already enforces + # block_size alignment, DCP just need the block_size can + # be divisible by dcp_world_size, because DCP use + # cp_gather_cache which not require `cp_chunk_starts` + # aligned to page_size. + assert max_context_chunk % self.dcp_world_size == 0 + cp_max_context_chunk = max_context_chunk // \ + self.dcp_world_size + cp_chunk_starts = \ + torch.arange(num_chunks, dtype=torch.int32) \ + .unsqueeze(1).expand(-1, num_prefills) \ + * cp_max_context_chunk + cp_chunk_ends = torch.min( + cp_context_lens_cpu.unsqueeze(0), + cp_chunk_starts + cp_max_context_chunk) + cp_chunk_seq_lens = (cp_chunk_ends - + cp_chunk_starts).clamp(min=0) + + cp_cu_seq_lens_cpu = torch.zeros(num_chunks, + num_prefills + 1, + dtype=torch.int32, + pin_memory=True) + torch.cumsum(cp_chunk_seq_lens, + dim=1, + out=cp_cu_seq_lens_cpu[:, 1:], + dtype=torch.int32) + + chunked_context_metadata_cls = \ + CudnnPrefillMetadata.ChunkedContextMetadata \ + if self._use_cudnn_prefill else \ + MLACommonPrefillMetadata.ChunkedContextMetadata + if self.dcp_world_size > 1: + chunked_context_metadata = \ + chunked_context_metadata_cls( + cu_seq_lens=cu_seq_lens_cpu \ + .to(device, non_blocking=True), + starts=cp_chunk_starts.to(device, non_blocking=True), + seq_tot=cp_chunk_seq_lens.sum(dim=1).tolist(), + max_seq_lens=chunk_seq_lens.max(dim=1).values.tolist(), + seq_lens=chunk_seq_lens, + workspace=self.chunked_prefill_workspace, + 
cp_chunk_seq_lens=cp_chunk_seq_lens.tolist(), + origin_context_lens=origin_context_lens, + cp_cu_seq_lens=cp_cu_seq_lens_cpu \ + .to(device, non_blocking=True), + chunk_size=max_context_chunk, + cu_seq_lens_lst=cu_seq_lens_cpu.tolist(), + ) + else: + chunked_context_metadata = \ + chunked_context_metadata_cls( + cu_seq_lens=cu_seq_lens_cpu \ + .to(device, non_blocking=True), + starts=chunk_starts.to(device, non_blocking=True), + seq_tot=chunk_seq_lens.sum(dim=1).tolist(), + max_seq_lens=chunk_seq_lens.max(dim=1).values.tolist(), + seq_lens=chunk_seq_lens, + workspace=self.chunked_prefill_workspace, + ) + + if self._use_cudnn_prefill: + chunked_context_metadata.seq_lens = chunk_seq_lens + + assert max(chunked_context_metadata.max_seq_lens) <= \ + self.chunked_prefill_workspace_size + + prefill_metadata = self.prefill_metadata_cls( + block_table=block_table_tensor[reqs_start:, ...], + query_start_loc=prefill_query_start_loc, + query_start_loc_cpu=prefill_query_start_loc.cpu(), + max_query_len=max_query_len, + chunked_context=chunked_context_metadata, + ) + + if self._use_cudnn_prefill: + assert isinstance(prefill_metadata, CudnnPrefillMetadata) + prefill_metadata.query_seq_lens = prefill_query_start_loc[1:] \ + - prefill_query_start_loc[:-1] + prefill_metadata.cudnn_workspace = self.cudnn_workspace + + decode_metadata = None + if num_decodes > 0: + decode_metadata = self._build_decode( + block_table_tensor=block_table_tensor[:num_decodes, ...], + seq_lens_cpu=seq_lens_cpu[:num_decodes], + seq_lens_device=seq_lens[:num_decodes], + query_start_loc_cpu=query_start_loc_cpu[:num_decodes + 1], + query_start_loc_device=query_start_loc[:num_decodes + 1], + num_decode_tokens=num_decode_tokens, + ) + + attn_metadata = self.metadata_cls( + num_reqs=common_attn_metadata.num_reqs, + max_query_len=common_attn_metadata.max_query_len, + max_seq_len=max_seq_len, + num_actual_tokens=num_tokens, + query_start_loc=query_start_loc, + query_start_loc_cpu=query_start_loc_cpu, + 
slot_mapping=slot_mapping, + head_dim=self.model_config.get_head_size(), + # MLACommonMetadata Chunk prefill specific + num_decodes=num_decodes, + num_decode_tokens=num_decode_tokens, + num_prefills=num_prefills, + prefill=prefill_metadata, + decode=decode_metadata, + ) + + if self._use_fi_prefill and num_prefills > 0: + assert isinstance(attn_metadata.prefill, FlashInferPrefillMetadata) + self._build_fi_prefill_wrappers(attn_metadata.prefill) + + return attn_metadata + + +def reorg_kvcache( + allgatered_kv_c_normed: torch.Tensor, + allgatered_k_pe: torch.Tensor, + cp_chunk_seq_lens_lst: list[int], + origin_context_lens: list[int], + cp_world_size: int, + sum_seq_len: int, + max_seq_len: int, + chunk_size: int, + chunk_idx: int, + toks: int, +) -> tuple[torch.Tensor, torch.Tensor]: + """ + reorg kvcache after cp local gather to tp layout for attn kernel. + + Args: + cp_chunk_seq_lens_lst: chunk context lengths under CP. + origin_context_lens: origin full context lengths under CP. + cp_world_size: CP size. + sum_seq_len: the sum of cp_chunk_seq_lens_lst. + max_seq_len: the max value of cp_chunk_seq_lens_lst. + chunk_size: equals to max_context_chunk from + chunked_context_metadata building. + chunk_idx: chunk idx of chunked_prefill. + toks: the number of tokens for local gather cache. 
+ """ + kv_c_segments = [] + k_pe_segments = [] + src_token_idx = 0 + max_seq_len_check = 0 + for cp_chunk_seq_len, origin_context_len in zip(cp_chunk_seq_lens_lst, + origin_context_lens): + chunk_context_len = chunk_size + if cp_chunk_seq_len != 0: + chunk_context_len = min( + chunk_context_len, origin_context_len - chunk_size * chunk_idx) + cp_target_rank = (chunk_context_len - 1) % cp_world_size + cur_seq_len = 0 + for rank in range(cp_world_size): + if rank > cp_target_rank and cp_chunk_seq_len: + real_cp_chunk_seq_len = cp_chunk_seq_len - 1 + else: + real_cp_chunk_seq_len = cp_chunk_seq_len + if real_cp_chunk_seq_len: + kv_c_segment = allgatered_kv_c_normed[rank * toks + + src_token_idx:rank * + toks + src_token_idx + + real_cp_chunk_seq_len] + k_pe_segment = allgatered_k_pe[rank * toks + + src_token_idx:rank * toks + + src_token_idx + + real_cp_chunk_seq_len] + kv_c_segments.append(kv_c_segment) + k_pe_segments.append(k_pe_segment) + cur_seq_len += real_cp_chunk_seq_len + max_seq_len_check = max(max_seq_len_check, cur_seq_len) + src_token_idx += cp_chunk_seq_len + reorganized_kv_c_normed = torch.cat(kv_c_segments, dim=0) + reorganized_k_pe = torch.cat(k_pe_segments, dim=0) + assert reorganized_kv_c_normed.shape[0] == sum_seq_len + assert reorganized_k_pe.shape[0] == sum_seq_len + assert max_seq_len_check == max_seq_len + return reorganized_kv_c_normed, reorganized_k_pe + + +# TODO(Lucas): rename MLACommonBaseImpl -> MLACommonImpl, +# and MLACommonImpl -> MLACommonDenseImpl or somthing like that +class MLACommonBaseImpl(MLAAttentionImpl[A], Generic[A]): + """ + NOTE: Please read the comment at the top of the file before trying to + understand this class + """ + + def __init__( + self, + num_heads: int, + head_size: int, + scale: float, + num_kv_heads: int, + alibi_slopes: Optional[list[float]], + sliding_window: Optional[int], + kv_cache_dtype: str, + logits_soft_cap: Optional[float], + attn_type: str, + kv_sharing_target_layer_name: Optional[str], + # MLA 
Specific Arguments + q_lora_rank: Optional[int], + kv_lora_rank: int, + qk_nope_head_dim: int, + qk_rope_head_dim: int, + qk_head_dim: int, + v_head_dim: int, + kv_b_proj: ColumnParallelLinear, + indexer=None, + q_pad_num_heads: Optional[int] = None, + ) -> None: + if kv_sharing_target_layer_name is not None: + raise NotImplementedError("KV sharing is not supported for MLA") + + self.num_heads = num_heads + self.head_size = head_size + self.scale = float(scale) + self.num_kv_heads = num_kv_heads + self.kv_cache_dtype = kv_cache_dtype + + self.q_lora_rank = q_lora_rank + self.kv_lora_rank = kv_lora_rank + self.qk_nope_head_dim = qk_nope_head_dim + self.qk_rope_head_dim = qk_rope_head_dim + self.qk_head_dim = qk_head_dim + self.v_head_dim = v_head_dim + self.kv_b_proj = kv_b_proj + self.indexer = indexer + self.q_pad_num_heads = q_pad_num_heads + + def process_weights_after_loading(self, act_dtype: torch.dtype): + + def get_layer_weight(layer): + WEIGHT_NAMES = ("weight", "qweight", "weight_packed") + for attr in WEIGHT_NAMES: + if hasattr(layer, attr): + return getattr(layer, attr) + raise AttributeError( + f"Layer '{layer}' has no recognized weight attribute:" + f" {WEIGHT_NAMES}.") + + def get_and_maybe_dequant_weights(layer: LinearBase): + if not isinstance(layer.quant_method, UnquantizedLinearMethod): + # NOTE: This should only be used offline, since it's O(N^3) + eye = torch.eye(layer.input_size_per_partition, + dtype=act_dtype, + device=get_layer_weight(layer).device) + dequant_weights = layer.quant_method.apply(layer, + eye, + bias=None) + del eye + # standardize to (output, input) + return dequant_weights.T + return layer.weight + + # we currently do not have quantized bmm's which are needed for + # `W_UV` and `W_UK_T`, we just store fp16/bf16 copies and perform + # the bmm's in 16-bit, the extra memory overhead of this is fairly low + kv_b_proj_weight = get_and_maybe_dequant_weights(self.kv_b_proj).T + assert kv_b_proj_weight.shape == ( + self.kv_lora_rank, 
+ self.num_heads * (self.qk_nope_head_dim + self.v_head_dim)), ( + f"{kv_b_proj_weight.shape=}, " + f"{self.kv_lora_rank=}, " + f"{self.num_heads=}, " + f"{self.qk_nope_head_dim=}, " + f"{self.v_head_dim=}") + kv_b_proj_weight = kv_b_proj_weight.view( + self.kv_lora_rank, + self.num_heads, + self.qk_nope_head_dim + self.v_head_dim, + ) + + W_UK, W_UV = kv_b_proj_weight.split( + [self.qk_nope_head_dim, self.v_head_dim], dim=-1) + + if is_rocm_aiter_fp8bmm_enabled(): + W_K = W_UK.transpose(0, 1) # 16 512 128 + W_V = W_UV.permute(1, 2, 0) # 16 128 512 + self.W_K, self.W_K_scale = dynamic_per_batched_tensor_quant( + W_K, dtype=current_platform.fp8_dtype()) + self.W_V, self.W_V_scale = dynamic_per_batched_tensor_quant( + W_V, dtype=current_platform.fp8_dtype()) + + # The kernel operates on non-padded inputs. Hence, pre-compiling + # triton kernel to avoid runtime compilation for unseen batch sizes + # Pre-compile for batch sizes 1 to 1024 to cover most use-cases. + # On DS-R1, this step adds roughly 50s to the model loading time. 
+ max_batch_size = 1024 # [ToDo] Find the optimal upper limit + pre_compilation_list = list(range(1, max_batch_size + 1)) + if is_global_first_rank(): + pre_compilation_list = tqdm( + pre_compilation_list, + desc="[Aiter Triton] Pre-compiling fp8 BMM kernel", + total=max_batch_size, + ) + + for m in pre_compilation_list: + x = torch.empty((self.W_K.shape[0], m, self.W_K.shape[2]), + dtype=torch.bfloat16, + device=self.W_K.device) + aiter_triton_fp8_bmm(x, + self.W_K, + self.W_K_scale, + group_size=128, + transpose_bm=True) + + x = torch.empty((self.W_V.shape[0], m, self.W_V.shape[2]), + dtype=torch.bfloat16, + device=self.W_V.device) + aiter_triton_fp8_bmm(x, + self.W_V, + self.W_V_scale, + group_size=128, + transpose_bm=True) + else: + # Convert from (L, N, V) to (N, L, V) + self.W_UV = W_UV.transpose(0, 1) + # Convert from (L, N, P) to (N, P, L) + self.W_UK_T = W_UK.permute(1, 2, 0) + + def _v_up_proj(self, x: torch.Tensor, out: torch.Tensor): + # Convert from (B, N, L) to (N, B, L) + x = x.view(-1, self.num_heads, self.kv_lora_rank).transpose(0, 1) + if is_rocm_aiter_fp8bmm_enabled(): + # Multiply + Transpose (N, B, L) x (N, L, V)->(N, B, V)->(B, N, V) + x = aiter_triton_fp8_bmm(x, + self.W_V, + self.W_V_scale, + group_size=128, + transpose_bm=True) + # Convert from (B, N, V) to (B, N * V) + x = x.reshape(-1, self.num_heads * self.v_head_dim) + # Copy result + out.copy_(x) + else: + # Convert from (B, N * V) to (N, B, V) + out = out.view(-1, self.num_heads, self.v_head_dim).transpose(0, 1) + + # Multiply (N, B, L) x (N, L, V) -> (N, B, V) + torch.bmm(x, self.W_UV, out=out) # Reuse "out" to make it "hot" + + # Convert from (N, B, V) to (B, N * V) + out_new = out.transpose(0, 1).reshape( + -1, self.num_heads * self.v_head_dim) + + # Adjust output buffer shape back to the original (B, N * V) + N, B, V = out.shape + out.resize_((B, N * V)) + out.copy_(out_new) # Copy result + + +class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]): + """ + NOTE: Please read the 
comment at the top of the file before trying to + understand this class + """ + + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + + if use_flashinfer_prefill(): + logger.debug_once("Using FlashInfer prefill for MLA") + self._run_prefill_context_chunk = self._run_prefill_context_chunk_fi + self._run_prefill_new_tokens = self._run_prefill_new_tokens_fi + self._pad_v = False + elif use_cudnn_prefill(): + logger.debug_once("Using CUDNN prefill for MLA") + self._run_prefill_context_chunk = \ + self._run_prefill_context_chunk_cudnn + self._run_prefill_new_tokens = self._run_prefill_new_tokens_cudnn + self._pad_v = False + else: # Use FlashAttention + logger.debug_once("Using FlashAttention prefill for MLA") + self._run_prefill_context_chunk = self._run_prefill_context_chunk_fa + self._run_prefill_new_tokens = self._run_prefill_new_tokens_fa + + # Handle the differences between the flash_attn_varlen from + # flash_attn and the one from vllm_flash_attn. The former is used on + # RoCM and the latter has an additional parameter to control + # FA2 vs FA3 + self.flash_attn_varlen_func = flash_attn_varlen_func + self.vllm_flash_attn_version = get_flash_attn_version() + if self.vllm_flash_attn_version is not None: + self.flash_attn_varlen_func = \ + functools.partial(flash_attn_varlen_func, + fa_version=self.vllm_flash_attn_version) + + # For MLA the v head dim is smaller than qk head dim so we pad out + # v with 0s to match the qk head dim for attention backends that do + # not support different headdims + # We don't need to pad V if we are on a hopper system with FA3 + self._pad_v = self.vllm_flash_attn_version is None or not ( + self.vllm_flash_attn_version == 3 + and current_platform.get_device_capability()[0] == 9) + + self.dcp_world_size: Optional[int] = None + + self.chunked_prefill_workspace_size = \ + MLACommonMetadataBuilder.determine_chunked_prefill_workspace_size( + get_current_vllm_config()) + + def 
_flash_attn_varlen_diff_headdims(self, + q, + k, + v, + context_seq_lod_xpu=None, + context_seq_lod_cpu=None, + return_softmax_lse=False, + causal=True, + softmax_scale=None, + **kwargs): + maybe_padded_v = v + if self._pad_v: + maybe_padded_v = torch.nn.functional.pad( + v, [0, q.shape[-1] - v.shape[-1]], value=0) + + if is_vllm_fa: + kwargs["return_softmax_lse"] = return_softmax_lse + else: + # ROCm leverages the upstream flash_attn, which takes a parameter + # called "return_attn_probs" instead of return_softmax_lse + kwargs["return_attn_probs"] = return_softmax_lse + + # attn_out = self.flash_attn_varlen_func( + # q=q, + # k=k, + # v=maybe_padded_v, + # softmax_scale=softmax_scale, + # **kwargs, + # ) + attn_out = torch.empty_like(q) + ds_alpha = 1.8738542070926265 + tp_q_head_num=128 + softmax_lse = torch.zeros(tp_q_head_num, q.size(0), dtype=torch.float32, device=q.device) + softmax_lse.fill_(float('-inf')) + xtorch_ops.attention( + q=q, + k_cache=k, + v_cache=maybe_padded_v, + out=attn_out, + is_causal=causal, + is_prefill=True, + prefill_len=0, + k_perchannel_scale=None, + v_perchannel_scale=None, + smooth=None, + context_seq_lod_cpu=context_seq_lod_cpu, + context_seq_lod_xpu=context_seq_lod_xpu, + slot_mapping_cpu=None, + slot_mapping_xpu=None, + v_trans=False, + v_trans_threshold=0, + alpha=ds_alpha, + softmax_lse=softmax_lse + ) + + # Unpack the output if there is multiple results + lse = None + if isinstance(attn_out, tuple): + attn_out, lse = attn_out[0], attn_out[1] + + # Remain consistent with old `flash_attn_varlen_func` where there + # is only one output tensor if `return_softmax_lse` is False. 
+ if return_softmax_lse: + return attn_out, lse + return attn_out + + def _run_prefill_new_tokens_fa(self, prefill: MLACommonPrefillMetadata, q, + k, v, return_softmax_lse): + return self._flash_attn_varlen_diff_headdims( + q=q, + k=k, + v=v, + context_seq_lod_xpu=prefill.query_start_loc, + context_seq_lod_cpu=prefill.query_start_loc_cpu, + softmax_scale=self.scale, + causal=True, + return_softmax_lse=return_softmax_lse, + ) + + def _run_prefill_new_tokens_fi(self, prefill: MLACommonPrefillMetadata, q, + k, v, return_softmax_lse): + assert isinstance(prefill, FlashInferPrefillMetadata) + assert prefill.prefill_main is not None + ret = prefill.prefill_main.run( + q=q, + k=k, + v=v, + return_lse=return_softmax_lse, + ) + + if isinstance(ret, tuple): + # Convert from (q_len, num_heads) to (num_heads, q_len) + return ret[0], ret[1].transpose(0, 1).contiguous() + return ret + + def _run_prefill_new_tokens_cudnn(self, prefill: MLACommonPrefillMetadata, + q, k, v, return_softmax_lse): + assert isinstance(prefill, CudnnPrefillMetadata) + assert prefill.query_seq_lens is not None + output, lse = cudnn_batch_prefill_with_kv_cache( + q=q, + k_cache=k, + v_cache=v, + scale=self.scale, + workspace_buffer=prefill.cudnn_workspace, + max_token_per_sequence=prefill.max_query_len, + max_sequence_kv=prefill.max_query_len, + actual_seq_lens_q=prefill.query_seq_lens.view(-1, 1, 1, 1), + actual_seq_lens_kv=prefill.query_seq_lens.view(-1, 1, 1, 1), + causal=True, + return_lse=True, # do not support False for now + is_cuda_graph_compatible= + True, #Indicates actual_seq_lens are on GPU or CPU. 
+ ) + if return_softmax_lse: + return output, lse + return output + + def _run_prefill_context_chunk_fa(self, prefill: MLACommonPrefillMetadata, + chunk_idx: int, q, k, v): + assert prefill.chunked_context is not None + return self._flash_attn_varlen_diff_headdims( + q=q, + k=k, + v=v, + context_seq_lod_xpu=prefill.query_start_loc, + context_seq_lod_cpu=prefill.chunked_context.cu_seq_lens[chunk_idx], + softmax_scale=self.scale, + causal=False, + return_softmax_lse=True, + ) + + def _run_prefill_context_chunk_fi(self, prefill: MLACommonPrefillMetadata, + chunk_idx: int, q, k, v): + assert isinstance(prefill, FlashInferPrefillMetadata) + attn_out, lse = prefill.prefill_chunks[chunk_idx].run( + q=q, + k=k, + v=v, + return_lse=True, + ) + # Convert from (q_len, num_heads) to (num_heads, q_len) + return attn_out, lse.transpose(0, 1).contiguous() + + def _run_prefill_context_chunk_cudnn(self, + prefill: MLACommonPrefillMetadata, + chunk_idx: int, q, k, v): + assert isinstance(prefill, CudnnPrefillMetadata) + assert prefill.chunked_context is not None + assert prefill.chunked_context.seq_lens[chunk_idx] is not None + assert prefill.query_seq_lens is not None + return cudnn_batch_prefill_with_kv_cache( + q=q, + k_cache=k, + v_cache=v, + scale=self.scale, + workspace_buffer=prefill.cudnn_workspace, + max_token_per_sequence=prefill.max_query_len, + max_sequence_kv=prefill.chunked_context.max_seq_lens[chunk_idx], + actual_seq_lens_q=prefill.query_seq_lens.view(-1, 1, 1, 1), + actual_seq_lens_kv=prefill.chunked_context.seq_lens[chunk_idx]. + view(-1, 1, 1, 1), + causal=False, + return_lse=True, + is_cuda_graph_compatible= + True, #Indicates actual_seq_lens are on GPU or CPU. 
+ ) + + def process_weights_after_loading(self, act_dtype: torch.dtype): + + def get_layer_weight(layer): + WEIGHT_NAMES = ("weight", "qweight", "weight_packed") + for attr in WEIGHT_NAMES: + if hasattr(layer, attr): + return getattr(layer, attr) + raise AttributeError( + f"Layer '{layer}' has no recognized weight attribute:" + f" {WEIGHT_NAMES}.") + + def get_and_maybe_dequant_weights(layer: LinearBase): + if not isinstance(layer.quant_method, UnquantizedLinearMethod): + # NOTE: This should only be used offline, since it's O(N^3) + eye = torch.eye(layer.input_size_per_partition, + dtype=act_dtype, + device=get_layer_weight(layer).device) + dequant_weights = layer.quant_method.apply(layer, + eye, + bias=None) + del eye + # standardize to (output, input) + return dequant_weights.T + return layer.weight + + # we currently do not have quantized bmm's which are needed for + # `W_UV` and `W_UK_T`, we just store fp16/bf16 copies and perform + # the bmm's in 16-bit, the extra memory overhead of this is fairly low + kv_b_proj_weight = get_and_maybe_dequant_weights(self.kv_b_proj).T + assert kv_b_proj_weight.shape == ( + self.kv_lora_rank, + self.num_heads * (self.qk_nope_head_dim + self.v_head_dim)), ( + f"{kv_b_proj_weight.shape=}, " + f"{self.kv_lora_rank=}, " + f"{self.num_heads=}, " + f"{self.qk_nope_head_dim=}, " + f"{self.v_head_dim=}") + kv_b_proj_weight = kv_b_proj_weight.view( + self.kv_lora_rank, + self.num_heads, + self.qk_nope_head_dim + self.v_head_dim, + ) + + W_UK, W_UV = kv_b_proj_weight.split( + [self.qk_nope_head_dim, self.v_head_dim], dim=-1) + + if is_rocm_aiter_fp8bmm_enabled(): + W_K = W_UK.transpose(0, 1) # 16 512 128 + W_V = W_UV.permute(1, 2, 0) # 16 128 512 + self.W_K, self.W_K_scale = dynamic_per_batched_tensor_quant( + W_K, dtype=current_platform.fp8_dtype()) + self.W_V, self.W_V_scale = dynamic_per_batched_tensor_quant( + W_V, dtype=current_platform.fp8_dtype()) + + # The kernel operates on non-padded inputs. 
Hence, pre-compiling + # triton kernel to avoid runtime compilation for unseen batch sizes + # Pre-compile for batch sizes 1 to 1024 to cover most use-cases. + # On DS-R1, this step adds roughly 50s to the model loading time. + max_batch_size = 1024 # [ToDo] Find the optimal upper limit + pre_compilation_list = list(range(1, max_batch_size + 1)) + if is_global_first_rank(): + pre_compilation_list = tqdm( + pre_compilation_list, + desc="[Aiter Triton] Pre-compiling fp8 BMM kernel", + total=max_batch_size, + ) + + for m in pre_compilation_list: + x = torch.empty((self.W_K.shape[0], m, self.W_K.shape[2]), + dtype=torch.bfloat16, + device=self.W_K.device) + aiter_triton_fp8_bmm(x, + self.W_K, + self.W_K_scale, + group_size=128, + transpose_bm=True) + + x = torch.empty((self.W_V.shape[0], m, self.W_V.shape[2]), + dtype=torch.bfloat16, + device=self.W_V.device) + aiter_triton_fp8_bmm(x, + self.W_V, + self.W_V_scale, + group_size=128, + transpose_bm=True) + else: + # Convert from (L, N, V) to (N, L, V) + self.W_UV = W_UV.transpose(0, 1) + # Convert from (L, N, P) to (N, P, L) + self.W_UK_T = W_UK.permute(1, 2, 0) + + def gather_and_maybe_dequant_cache_py_optimized( + self, + src_cache: torch.Tensor, + dst: torch.Tensor, + block_table: torch.Tensor, + cu_seq_lens: torch.Tensor, + batch_size: int, + kv_cache_dtype: str, + scale: torch.Tensor, + seq_starts: Optional[torch.Tensor] = None + ) -> None: + device = src_cache.device + num_blocks, block_size, head_dim = src_cache.shape + tot_tokens = dst.shape[0] + + src_cache_2d = src_cache.reshape(-1, head_dim) + + token_ids = torch.arange(tot_tokens, device=device, dtype=torch.int32) + + batch_ids = torch.zeros(tot_tokens, dtype=torch.int32, device=device) + for i in range(batch_size): + start = cu_seq_lens[i] + end = cu_seq_lens[i + 1] + mask = (token_ids >= start) & (token_ids < end) + batch_ids[mask] = i + + batch_starts = cu_seq_lens[batch_ids] + batch_offsets = token_ids - batch_starts + + if seq_starts is not None: + 
seq_start_offsets = seq_starts[batch_ids] + batch_offsets = batch_offsets + seq_start_offsets + + block_table_ids = batch_offsets // block_size + slot_ids = batch_offsets % block_size + + block_table_flat = block_table.view(-1) + block_table_stride = block_table.shape[1] + block_table_indices = batch_ids * block_table_stride + block_table_ids + block_ids = block_table_flat[block_table_indices] + + valid_mask = block_ids >= 0 + linear_indices = torch.where( + valid_mask, + block_ids * block_size + slot_ids, + torch.zeros_like(block_ids, device=device) + ) + + if kv_cache_dtype == 'auto': + gathered = src_cache_2d[linear_indices.long()] + dst[valid_mask] = gathered[valid_mask].to(dst.dtype) + else: + gathered = src_cache_2d[linear_indices.long()].float() + if kv_cache_dtype in ['fp8', 'fp8_e4m3', 'fp8_e5m2']: + gathered = gathered * scale.item() + dst[valid_mask] = gathered[valid_mask].to(dst.dtype) + + def _compute_prefill_context( + self, + q: torch.Tensor, + kv_c_and_k_pe_cache: torch.Tensor, + attn_metadata: MLACommonMetadata, + k_scale: torch.Tensor, + ): + assert attn_metadata.prefill is not None + prefill_metadata = attn_metadata.prefill + assert prefill_metadata.chunked_context is not None + + output = None + iters = len(prefill_metadata.chunked_context.seq_tot) + workspace = prefill_metadata.chunked_context.workspace + + for i in range(iters): + toks = prefill_metadata.chunked_context.seq_tot[i] + self.gather_and_maybe_dequant_cache_py_optimized( + src_cache=kv_c_and_k_pe_cache, + dst=workspace, + block_table=prefill_metadata.block_table, + cu_seq_lens=prefill_metadata.chunked_context.cu_seq_lens[i], + batch_size=attn_metadata.num_prefills, + kv_cache_dtype=self.kv_cache_dtype, + scale=k_scale, + seq_starts=prefill_metadata.chunked_context.starts[i], + ) + + kv_c_normed = workspace[:toks]\ + [..., :self.kv_lora_rank] + k_pe = workspace[:toks]\ + [..., self.kv_lora_rank:].unsqueeze(1) + + kv_nope = self.kv_b_proj(kv_c_normed)[0].view( \ + -1, self.num_heads, 
self.qk_nope_head_dim + self.v_head_dim) + k_nope, v = kv_nope\ + .split([self.qk_nope_head_dim, self.v_head_dim], dim=-1) + + k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))), + dim=-1) + + attn_output, attn_softmax_lse = self._run_prefill_context_chunk( + prefill=prefill_metadata, + chunk_idx=i, + q=q, + k=k, + v=v, + ) + + if output is None: + output = attn_output + output_lse = attn_softmax_lse + else: + output_tmp = torch.empty_like(output) + output_lse_tmp = torch.empty_like(output_lse) + merge_attn_states( + output=output_tmp, + output_lse=output_lse_tmp, + prefix_output=output, + prefix_lse=output_lse, + suffix_output=attn_output, + suffix_lse=attn_softmax_lse, + ) + output = output_tmp + output_lse = output_lse_tmp + + return output, output_lse + + def _context_parallel_compute_prefill_context( + self, + q: torch.Tensor, + kv_c_and_k_pe_cache: torch.Tensor, + attn_metadata: MLACommonMetadata, + k_scale: torch.Tensor, + dcp_world_size: int, + ): + assert k_scale is None, "DCP not support scaled kvcache now." 
+ assert attn_metadata.prefill is not None + prefill_metadata = attn_metadata.prefill + assert prefill_metadata.chunked_context is not None + assert prefill_metadata.chunked_context.cp_chunk_seq_lens is not None + assert prefill_metadata.chunked_context.origin_context_lens is not None + assert prefill_metadata.chunked_context.cp_cu_seq_lens is not None + assert prefill_metadata.chunked_context.chunk_size is not None + assert prefill_metadata.chunked_context.cu_seq_lens_lst is not None + + output = None + iters = len(prefill_metadata.chunked_context.seq_tot) + workspace = prefill_metadata.chunked_context.workspace + + for i in range(iters): + toks = prefill_metadata.chunked_context.seq_tot[i] + ops.cp_gather_cache( + src_cache=kv_c_and_k_pe_cache, + dst=workspace, + block_table=prefill_metadata.block_table, + cu_seq_lens=prefill_metadata.chunked_context.cp_cu_seq_lens[i], + batch_size=attn_metadata.num_prefills, + seq_starts=prefill_metadata.chunked_context.starts[i], + ) + # workspace + # |------- N tokens --------|--------- N*dcp_size tokens ----------| + # |<- use for loca_gather ->|<--------- use for allgather -------->| + allgather_offset = workspace.shape[0] // (dcp_world_size + 1) + assert allgather_offset * (dcp_world_size + + 1) == workspace.shape[0] + assert toks <= allgather_offset + local_gathered_kvcache = workspace[:toks] + cur_allgather_workspace = workspace[ + allgather_offset:allgather_offset * (1 + dcp_world_size)] + assert toks * dcp_world_size <= cur_allgather_workspace.shape[0] + cur_allgather_kvcache = cur_allgather_workspace[:toks * + dcp_world_size] + cur_allgather_kvcache.copy_(get_dcp_group().all_gather( + local_gathered_kvcache, dim=0)) + assert cur_allgather_kvcache.shape[ + -1] == self.kv_lora_rank + self.qk_rope_head_dim + allgatered_kv_c_normed, allgatered_k_pe = \ + cur_allgather_kvcache.unsqueeze( + 1).split([self.kv_lora_rank, self.qk_rope_head_dim], dim=-1) + + kv_c_normed, k_pe = reorg_kvcache( + allgatered_kv_c_normed, + 
    def _forward_prefill(
        self,
        q: torch.Tensor,
        kv_c_normed: torch.Tensor,
        k_pe: torch.Tensor,
        kv_c_and_k_pe_cache: torch.Tensor,
        attn_metadata: MLACommonMetadata,
        k_scale: torch.Tensor,
    ) -> torch.Tensor:
        """Run MLA attention for the prefill tokens.

        Up-projects the latent KV (``kv_c_normed``) with ``kv_b_proj``,
        splits the result into the no-PE key part and the value,
        concatenates the RoPE key (``k_pe``, broadcast across heads), and
        attends over the newly scheduled tokens.  When the request has
        chunked context, the context attention is computed separately
        (via the DCP-aware path when ``dcp_world_size > 1``) and merged
        with the new-token attention using their softmax LSEs.

        Returns the attention output with the head dims flattened.
        """
        # TODO (zyongye): Prefill function here
        assert attn_metadata.prefill is not None
        assert self.dcp_world_size is not None

        has_context = attn_metadata.prefill.chunked_context is not None
        # Up-project latent KV to per-head [nope_k | v] and split.
        kv_nope = self.kv_b_proj(kv_c_normed)[0].view(\
            -1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim)
        k_nope, v = kv_nope\
            .split([self.qk_nope_head_dim, self.v_head_dim], dim=-1)

        # Full key = [nope part | rope part broadcast to every head].
        k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))), dim=-1)

        # When there is chunked context we also need the softmax LSE so the
        # two partial attentions can be merged below.
        output = self._run_prefill_new_tokens(
            prefill=attn_metadata.prefill,
            q=q,
            k=k,
            v=v,
            return_softmax_lse=has_context,
        )

        if has_context:
            suffix_output, suffix_lse = output
            if self.dcp_world_size > 1:
                # DCP path asserts k_scale is None (no scaled kv cache).
                context_output, context_lse = \
                    self._context_parallel_compute_prefill_context(
                        q, kv_c_and_k_pe_cache, attn_metadata,
                        k_scale=None, dcp_world_size=self.dcp_world_size)
            else:
                context_output, context_lse = \
                    self._compute_prefill_context(
                        q, kv_c_and_k_pe_cache, attn_metadata, k_scale)

            # Merge context ("prefix") and new-token ("suffix") attention
            # into a single output using the accumulated LSEs.
            output = torch.empty_like(suffix_output)
            merge_attn_states(
                output=output,
                prefix_output=context_output,
                prefix_lse=context_lse,
                suffix_output=suffix_output,
                suffix_lse=suffix_lse,
            )

        # unpad if necessary (v may have been padded for the kernel)
        if self._pad_v:
            output = output[..., :v.shape[-1]]

        return output.flatten(start_dim=-2)
+ + if output_scale is not None or output_block_scale is not None: + raise NotImplementedError( + "fused output quantization is not yet supported" + " for MLACommonImpl") + + if attn_metadata is None: + # During the profile run try to simulate to worse case output size + # for `self.kv_b_proj(kv_c_normed)` in `_compute_prefill_context` + # since this can be large + _ = torch.empty( + (self.chunked_prefill_workspace_size, self.num_heads, + self.qk_nope_head_dim + self.v_head_dim), + device=k_c_normed.device, + dtype=k_c_normed.dtype, + ) + + # The zero fill is required when used with DP + EP + # to ensure all ranks within a DP group compute the + # same expert outputs. + return output.fill_(0) + + if self.dcp_world_size is None: + self.dcp_world_size = get_dcp_group().world_size + + fp8_attention = self.kv_cache_dtype.startswith("fp8") + + num_actual_toks = attn_metadata.num_actual_tokens + + # Inputs and outputs may be padded for CUDA graphs + output_padded = output + output = output[:num_actual_toks, ...] + q = q[:num_actual_toks, ...] + k_c_normed = k_c_normed[:num_actual_toks, ...] + k_pe = k_pe[:num_actual_toks, ...] 
+ + assert attn_metadata.num_decodes is not None and \ + attn_metadata.num_prefills is not None and \ + attn_metadata.num_decode_tokens is not None + + has_decode = attn_metadata.num_decodes > 0 + has_prefill = attn_metadata.num_prefills > 0 + num_decode_tokens = attn_metadata.num_decode_tokens + + decode_q = q[:num_decode_tokens] + + prefill_q = q[num_decode_tokens:] + prefill_k_pe = k_pe[num_decode_tokens:] + prefill_k_c_normed = k_c_normed[num_decode_tokens:] + + # write the latent and rope to kv cache + if kv_cache.numel() > 0: + xtorch_ops.concat_and_cache_mla( + k_c_normed, + k_pe.squeeze(1), + attn_metadata.slot_mapping.flatten(), + kv_cache, + ) + + if fp8_attention: + kv_cache = kv_cache.view(current_platform.fp8_dtype()) + + if has_prefill: + output[num_decode_tokens:] = self._forward_prefill( + prefill_q, prefill_k_c_normed, prefill_k_pe, kv_cache, + attn_metadata, layer._k_scale) + + if has_decode: + assert attn_metadata.decode is not None + decode_q_nope, decode_q_pe = decode_q.split( + [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1) + # Convert from (B, N, P) to (N, B, P) + decode_q_nope = decode_q_nope.transpose(0, 1) + + # Pads the head_dim if necessary (for the underlying kernel) + if self.q_pad_num_heads is not None: + B, N, L = decode_q_pe.shape + decode_pe_padded = decode_q_pe.new_empty( + (B, self.q_pad_num_heads, L)) + decode_pe_padded.resize_((B, N, L)) + decode_pe_padded.copy_(decode_q_pe) + decode_q_pe = decode_pe_padded + + if is_rocm_aiter_fp8bmm_enabled(): + # Multiply+Transpose (N, B, P)x(N, P, L)->(N, B, L)->(B, N, L) + decode_ql_nope = aiter_triton_fp8_bmm(decode_q_nope, + self.W_K, + self.W_K_scale, + group_size=128, + transpose_bm=True) + else: + # Pads the head_dim if necessary (for the underlying kernel) + N, B, P = decode_q_nope.shape + _, _, L = self.W_UK_T.shape + if self.q_pad_num_heads is not None: + decode_ql_nope = decode_q_nope.new_empty( + (self.q_pad_num_heads, B, L)) + decode_ql_nope.resize_((N, B, L)) + + else: 
+ decode_ql_nope = decode_q_nope.new_empty((N, B, L)) + + # Multiply (N, B, P) x (N, P, L) -> (N, B, L) + torch.bmm(decode_q_nope, self.W_UK_T, out=decode_ql_nope) + # Convert from (N, B, L) to (B, N, L) + decode_ql_nope = decode_ql_nope.transpose(0, 1) + + if fp8_attention: + ql_nope_shape = decode_ql_nope.shape + decode_ql_nope, _ = ops.scaled_fp8_quant( + decode_ql_nope.reshape([ + ql_nope_shape[0], ql_nope_shape[1] * ql_nope_shape[2] + ]), layer._q_scale) + decode_ql_nope = decode_ql_nope.reshape(ql_nope_shape) + q_pe_shape = decode_q_pe.shape + decode_q_pe, _ = ops.scaled_fp8_quant( + decode_q_pe.reshape( + [q_pe_shape[0], q_pe_shape[1] * q_pe_shape[2]]), + layer._q_scale) + decode_q_pe = decode_q_pe.reshape(q_pe_shape) + + decode_q = (decode_ql_nope, decode_q_pe) + if self.dcp_world_size > 1: + assert not fp8_attention, "DCP not support fp8 kvcache now." + # concatenate decode_ql_nope and decode_q_pe -> (B, N, L + P) + decode_q = torch.cat(decode_q, dim=-1) + # decode_q do allgather in head dim. + decode_q = get_dcp_group().all_gather(decode_q, dim=1) + + # call decode attn + attn_out, lse = self._forward_decode(decode_q, kv_cache, + attn_metadata, layer) + + # recorect dcp attn_out with lse. 
+ if self.dcp_world_size > 1: + attn_out = cp_lse_ag_out_rs(attn_out, lse, get_dcp_group()) + + # v_up projection + self._v_up_proj(attn_out, out=output[:num_decode_tokens]) + return output_padded diff --git a/vllm_kunlun/v1/attention/backends/mla/flashmla.py b/vllm_kunlun/v1/attention/backends/mla/flashmla.py new file mode 100644 index 0000000..46268eb --- /dev/null +++ b/vllm_kunlun/v1/attention/backends/mla/flashmla.py @@ -0,0 +1,202 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from dataclasses import dataclass +from typing import ClassVar, Optional, Union + +import torch + +from vllm.attention.backends.abstract import AttentionLayer, AttentionType +from vllm_kunlun.ops.attention.flashmla import (flash_mla_with_kvcache, + get_mla_metadata, + is_flashmla_supported) +from vllm.config import VllmConfig +from vllm.logger import init_logger +from vllm_kunlun.v1.attention.backends.mla.common import (MLACommonBackend, + MLACommonDecodeMetadata, + MLACommonImpl, + MLACommonMetadata, + MLACommonMetadataBuilder) +from vllm.v1.attention.backends.utils import AttentionCGSupport +from vllm.v1.kv_cache_interface import AttentionSpec + +logger = init_logger(__name__) + +class FlashMLABackend(MLACommonBackend): + + @staticmethod + def get_name() -> str: + return "FLASHMLA" + + @staticmethod + def get_metadata_cls() -> type["FlashMLAMetadata"]: + return FlashMLAMetadata + + @staticmethod + def get_builder_cls() -> type["FlashMLAMetadataBuilder"]: + return FlashMLAMetadataBuilder + + @staticmethod + def get_impl_cls() -> type["FlashMLAImpl"]: + return FlashMLAImpl + + +@dataclass +class FlashMLADecodeMetadata(MLACommonDecodeMetadata): + tile_scheduler_metadata: torch.Tensor + num_splits: torch.Tensor + + +@dataclass +class FlashMLAMetadata(MLACommonMetadata[FlashMLADecodeMetadata]): + pass + + +class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]): + cudagraph_support: 
ClassVar[AttentionCGSupport] = \ + AttentionCGSupport.UNIFORM_BATCH + + def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str], + vllm_config: VllmConfig, device: torch.device): + super().__init__(kv_cache_spec, layer_names, vllm_config, device, + FlashMLAMetadata) + + self.num_q_heads = vllm_config.model_config.get_num_attention_heads( + vllm_config.parallel_config) + + self.cg_buf_tile_scheduler_metadata = None + self.cg_buf_num_splits = None + + device_properties = torch.cuda.get_device_properties(self.device) + num_sms = device_properties.multi_processor_count + + if self.compilation_config.cudagraph_mode.has_full_cudagraphs(): + self.cg_buf_tile_scheduler_metadata = torch.zeros( + # Upper bound on size (<= #SMs, TileSchedulerMetaDataSize) + # TileSchedulerMetaDataSize = 8 + (num_sms, 8), + device=self.device, + dtype=torch.int32, + ) + self.cg_buf_num_splits = torch.empty( + (vllm_config.scheduler_config.max_num_seqs + 1), + device=self.device, + dtype=torch.int32) + + def _build_decode(self, block_table_tensor: torch.Tensor, + seq_lens_cpu: torch.Tensor, + seq_lens_device: torch.Tensor, + query_start_loc_cpu: torch.Tensor, + query_start_loc_device: torch.Tensor, + num_decode_tokens: int) -> FlashMLADecodeMetadata: + tile_scheduler_metadata, num_splits = \ + get_mla_metadata( + seq_lens_device, + self.num_q_heads, + 1, # MQA for the decode path + ) + + # TODO: we can disambiguate between decode and mixed-prefill decode here + # so we can only use the persistent buffer if a cudagraph is actually + # being used. 
+ if self.compilation_config.cudagraph_mode.has_full_cudagraphs(): + assert self.cg_buf_tile_scheduler_metadata is not None + assert self.cg_buf_num_splits is not None + + sm_parts = tile_scheduler_metadata.size(0) + # Metadata per-SM, upper bound on size (<= #SMs, TileMetadataSize) + assert sm_parts <= self.cg_buf_tile_scheduler_metadata.size(0) + tile_scheduler_metadata_view = \ + self.cg_buf_tile_scheduler_metadata[:sm_parts] + tile_scheduler_metadata_view.copy_(tile_scheduler_metadata) + tile_scheduler_metadata = tile_scheduler_metadata_view + + # Num splits is per-batch, varying size (batch_size,) + n = num_splits.size(0) + # make sure static buffer is large enough + assert n <= self.cg_buf_num_splits.size(0) + num_splits_view = self.cg_buf_num_splits[:n] + num_splits_view.copy_(num_splits) + # Num splits needs to monotonically increasing + # (with: https://github.com/vllm-project/FlashMLA/pull/3, otherwise + # it needs to monotonically increasing by 1) + self.cg_buf_num_splits[n:].fill_(num_splits[-1]) + num_splits = num_splits_view + + return FlashMLADecodeMetadata( + block_table=block_table_tensor, + seq_lens=seq_lens_device, + tile_scheduler_metadata=tile_scheduler_metadata, + num_splits=num_splits, + ) + + +class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]): + + can_return_lse_for_decode: bool = True + + def __init__( + self, + num_heads: int, + head_size: int, + scale: float, + num_kv_heads: int, + alibi_slopes: Optional[list[float]], + sliding_window: Optional[int], + kv_cache_dtype: str, + logits_soft_cap: Optional[float], + attn_type: str, + kv_sharing_target_layer_name: Optional[str], + # MLA Specific Arguments + **mla_args) -> None: + super().__init__(num_heads, head_size, scale, num_kv_heads, + alibi_slopes, sliding_window, kv_cache_dtype, + logits_soft_cap, attn_type, + kv_sharing_target_layer_name, **mla_args) + + is_supported, reason = is_flashmla_supported() + assert is_supported, reason + + unsupported_features = [alibi_slopes, 
    def _forward_decode(
        self,
        q: Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]],
        kv_c_and_k_pe_cache: torch.Tensor,
        attn_metadata: FlashMLAMetadata,
        layer: AttentionLayer,
    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Decode-path MLA attention via the FlashMLA paged-KV kernel.

        ``q`` may arrive as a ``(ql_nope, q_pe)`` pair; it is concatenated
        along the last dim before the kernel call.  Returns the attention
        output and its softmax LSE.
        """
        # TODO: (zyongye) decode function for mla here
        assert kv_c_and_k_pe_cache.numel() > 0
        assert attn_metadata.decode is not None

        if type(q) is tuple:
            q = torch.cat(q, dim=-1)

        assert isinstance(q, torch.Tensor)
        o, lse = flash_mla_with_kvcache(
            q=q.unsqueeze(1),  # Add seqlen dim of 1 (decode)
            k_cache=kv_c_and_k_pe_cache.unsqueeze(-2),  # Add head dim of 1
            block_table=attn_metadata.decode.block_table,
            cache_seqlens=attn_metadata.decode.seq_lens,
            head_dim_v=self.kv_lora_rank,  # value part = latent rank only
            tile_scheduler_metadata=attn_metadata.decode.
            tile_scheduler_metadata,
            num_splits=attn_metadata.decode.num_splits,
            softmax_scale=self.scale,
            causal=True,
            # NOTE(review): descale args are passed unconditionally;
            # presumably ignored by the kernel for non-fp8 caches — confirm.
            descale_q=layer._q_scale.reshape(1),
            descale_k=layer._k_scale.reshape(1),
        )

        return o, lse
+- **Next 16 bytes:** Scale factors, containing 4 `float32` values. + The first `float32` is the scale for the first 128 `float8_e4m3` values, + the second for the next 128, and so on. +- **Last 128 bytes:** The "RoPE" part, containing 64 `bfloat16` values. This + part is not quantized for accuracy. +""" + + +def _lse2_to_lse(lse_base2: torch.Tensor) -> torch.Tensor: + # Convert base-2 LSE to natural-log LSE + # Keep FP32 for numerical stability during the merge. + return (lse_base2.to(torch.float32) * math.log(2.0)) + + +class FlashMLASparseBackend(AttentionBackend): + + accept_output_buffer: bool = True + + @staticmethod + def get_name() -> str: + return "FLASHMLA_SPARSE_VLLM_V1" + + @staticmethod + def get_metadata_cls() -> type[AttentionMetadata]: + return FlashMLASparseMetadata + + @staticmethod + def get_builder_cls() -> type["FlashMLASparseMetadataBuilder"]: + return FlashMLASparseMetadataBuilder + + @staticmethod + def get_impl_cls() -> type["FlashMLASparseImpl"]: + return FlashMLASparseImpl + + @staticmethod + def get_kv_cache_shape( + num_blocks: int, + block_size: int, + num_kv_heads: int, # assumed to be 1 for MLA + head_size: int, + cache_dtype_str: str = "auto", + ) -> tuple[int, ...]: + if cache_dtype_str == "fp8_ds_mla": + # custom storage fromat is 656 bytes + # see FlashMLA readme.md for details + return (num_blocks, block_size, 656) + else: + return (num_blocks, block_size, head_size) + + @classmethod + def get_supported_dtypes(cls) -> list[torch.dtype]: + return [torch.bfloat16] + + @classmethod + def get_supported_head_sizes(cls) -> list[int]: + return [576] + + +@dataclass +class MLASparsePrefillMetadata: + # NOTE(Chen): not call it "FlashMLASparsePrefillMetadata" because + # the kernel is not from flashmla + block_table: torch.Tensor = None + has_context: bool = False + context_lens: Optional[torch.Tensor] = None + + # Sequence lengths (context + query) for prefill requests + # Shape: [num_prefill_reqs] + seq_lens: torch.Tensor = None + + # 
Request ID for each token: -1 for decode tokens, request index + # (0, 1, 2, ...) for prefill tokens. + # Shape: [num_actual_tokens] + request_ids: torch.Tensor = None + query_start_loc: torch.Tensor = None + query_start_loc_cpu: torch.Tensor = None + +@dataclass +class FlashMLASparseDecodeAndContextMetadata: + scheduler_metadata: torch.Tensor = None + num_splits: torch.Tensor = None + cache_lens: torch.Tensor = None + prefill_context_lengths: Optional[torch.Tensor] = None + prefill_new_k_start_locs: Optional[torch.Tensor] = None + dummy_block_table: torch.Tensor = None + + seq_lens: torch.Tensor = None + seq_lens_cpu: torch.Tensor = None + max_seq_len: int = -1 # needed for reshape in spec decode + + def filter_prefill_indices( + self, indices: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + assert self.prefill_context_lengths is not None + prefill_context_lengths = self.prefill_context_lengths.unsqueeze(-1) + context_indices = torch.where(indices < prefill_context_lengths, + indices, -1) + new_token_indices = torch.where(indices >= prefill_context_lengths, + indices - prefill_context_lengths, -1) + return context_indices, new_token_indices + + +@dataclass +class FlashMLASparseMetadata: + num_reqs: int + max_query_len: int + max_seq_len: int + + num_actual_tokens: int # Number of tokens excluding padding. 
+ query_start_loc: torch.Tensor + slot_mapping: torch.Tensor + + block_table: torch.Tensor + req_id_per_token: torch.Tensor + block_size: int = 64 + topk_tokens: int = 2048 + + num_prefills: int = 0 + num_decodes: int = 0 + num_prefill_tokens: int = 0 + num_decode_tokens: int = 0 + + decode_metadata: Optional[FlashMLASparseDecodeAndContextMetadata] = None + prefill_metadata: Optional[MLASparsePrefillMetadata] = None + + @dataclass + class FP8KernelMetadata: + scheduler_metadata: Optional[torch.Tensor] + num_splits: torch.Tensor + dummy_block_table: torch.Tensor + cache_lens: torch.Tensor + + fp8_extra_metadata: Optional[FP8KernelMetadata] = None + + +@triton.jit +def _convert_req_index_to_global_index_kernel( + req_id_ptr, # int32 [num_tokens] + block_table_ptr, # int32 [num_requests, max_num_blocks_per_req] + token_indices_ptr, # int32 [num_tokens, NUM_TOPK_TOKENS] + out_ptr, # int32 [num_tokens, NUM_TOPK_TOKENS] + # shapes (compile-time where possible) + max_num_blocks_per_req: tl.constexpr, + BLOCK_SIZE: tl.constexpr, + BLOCK_N: tl.constexpr, # tile width along columns + # strides (in elements) + bt_stride0, + bt_stride1, + ti_stride0, + ti_stride1, + out_stride0, + out_stride1, +): + # program_id(0) -> token_id (row) + # program_id(1) -> tile index along columns + token_id = tl.program_id(0) + tile_id = tl.program_id(1) + + # Each program covers BLOCK_N consecutive columns + indice_id = tile_id * BLOCK_N + tl.arange(0, BLOCK_N) + + # Load request id for this token (no mask: grid is exact) + req = tl.load(req_id_ptr + token_id) + + # Load token indices for this tile + ti_ptr = token_indices_ptr + token_id * ti_stride0 + indice_id * ti_stride1 + tok = tl.load(ti_ptr) # int32 + + # Only token == -1 should propagate as -1 + is_invalid_tok = tok < 0 + + # Compute block id and in-block offset + block_id = tok // BLOCK_SIZE + inblock_off = tok % BLOCK_SIZE + + # Guard block_table access + valid_block = block_id < max_num_blocks_per_req + bt_ptr = block_table_ptr + req 
* bt_stride0 + block_id * bt_stride1 + base = tl.load(bt_ptr, mask=valid_block, other=0) + + # If token == -1 OR block_id OOB, output -1; else base * BLOCK_SIZE + offset + out_val = tl.where(is_invalid_tok | (~valid_block), -1, + base * BLOCK_SIZE + inblock_off) + + # Store results + out_ptr_ij = out_ptr + token_id * out_stride0 + indice_id * out_stride1 + tl.store(out_ptr_ij, out_val) + + +def triton_convert_req_index_to_global_index( + req_id: torch.Tensor, # int32 [num_tokens] + block_table: torch. + Tensor, # int32 [num_requests, max_num_blocks_per_req] + token_indices: torch.Tensor, # int32 [num_tokens, NUM_TOPK_TOKENS] + BLOCK_SIZE: int = 64, + NUM_TOPK_TOKENS: int = 2048, + BLOCK_N: int = 128, # tile width along columns +): + """ + out[token_id, indice_id] = + block_table[req_id[token_id], + token_indices[token_id, indice_id] // BLOCK_SIZE] * BLOCK_SIZE + + token_indices[token_id, indice_id] % BLOCK_SIZE + + Only when token_indices[token_id, indice_id] == -1 do we output -1. + For safety, we also output -1 if the derived block_id would be + out-of-bounds. 
+ """ + assert req_id.dtype == torch.int32 + assert block_table.dtype == torch.int32 + assert token_indices.dtype == torch.int32 + assert token_indices.shape[1] == NUM_TOPK_TOKENS + assert NUM_TOPK_TOKENS % BLOCK_N == 0, \ + f"NUM_TOPK_TOKENS ({NUM_TOPK_TOKENS}) must be divisible by" \ + f"BLOCK_N ({BLOCK_N})" + + num_tokens = req_id.shape[0] + num_requests, max_num_blocks_per_req = block_table.shape + tiles_per_row = NUM_TOPK_TOKENS // BLOCK_N + + # Ensure contiguous tensors on the same device + req_id_c = req_id.contiguous() + block_table_c = block_table.contiguous() + token_indices_c = token_indices.contiguous() + out = torch.empty_like(token_indices_c) + + # Strides in elements + bt_stride0, bt_stride1 = block_table_c.stride() + ti_stride0, ti_stride1 = token_indices_c.stride() + out_stride0, out_stride1 = out.stride() + + # Exact 2D grid: tokens × column tiles + grid = (num_tokens, tiles_per_row) + + _convert_req_index_to_global_index_kernel[grid]( + req_id_c, + block_table_c, + token_indices_c, + out, + # shapes / constexprs + max_num_blocks_per_req, + BLOCK_SIZE, + BLOCK_N, + # strides + bt_stride0, + bt_stride1, + ti_stride0, + ti_stride1, + out_stride0, + out_stride1, + ) + return out + +def kunlun_convert_req_index_to_global_index( + req_id: torch.Tensor, # int32 [num_tokens] + block_table: torch.Tensor, # int32 [num_requests, max_num_blocks_per_req] + token_indices: torch.Tensor, # int32 [num_tokens, NUM_TOPK_TOKENS] + BLOCK_SIZE: int = 64, + NUM_TOPK_TOKENS: int = 2048, +): + assert req_id.dtype == torch.int32 + assert block_table.dtype == torch.int32 + assert token_indices.dtype == torch.int32 + assert token_indices.shape[1] == NUM_TOPK_TOKENS + + num_tokens = req_id.shape[0] + num_requests, max_num_blocks_per_req = block_table.shape + + out = torch.zeros_like(token_indices) + + # Compute block_id and inblock_off for all tokens at once + block_id = token_indices // BLOCK_SIZE + inblock_off = token_indices % BLOCK_SIZE + + # Create mask for invalid 
tokens (tok < 0) + invalid_tok_mask = token_indices < 0 + + # Create mask for out-of-bounds block_id + oob_block_mask = block_id >= max_num_blocks_per_req + + # Combine masks - output -1 for either condition + invalid_mask = invalid_tok_mask | oob_block_mask + + # Get request IDs expanded to match token_indices shape + req_ids_expanded = req_id.unsqueeze(1).expand(-1, NUM_TOPK_TOKENS) + + # Gather base addresses from block_table + # Clamp block_id to avoid index errors (we'll mask these out anyway) + block_id_clamped = torch.clamp(block_id, 0, max_num_blocks_per_req - 1) + + # Use advanced indexing to get base addresses + base_addrs = block_table[req_ids_expanded, block_id_clamped] + + # Compute the global indices + global_indices = base_addrs * BLOCK_SIZE + inblock_off + + # Apply mask: set invalid positions to -1 + out = torch.where(invalid_mask, torch.tensor(-1, dtype=torch.int32, device=token_indices.device), global_indices) + + return out + +def kunlun_concat_and_cache_mla( + kv_c: torch.Tensor, #[num_tokens, kv_lora_rank] + k_pe: torch.Tensor, #[num_tokens, pe_dim] + kv_cache: torch.Tensor, #[num_blocks, block_size, (kv_lora_rank + pe_dim)] + slot_mapping: torch.Tensor, #[num_tokens] or [num_actual_tokens] + kv_cache_dtype: str, + scale: torch.Tensor +): + num_tokens = slot_mapping.shape[0] + kv_lora_rank = kv_c.shape[1] + pe_dim = k_pe.shape[1] + block_size = kv_cache.shape[1] + + def kunlun_fp8_ds_mla(): + for token_idx in range(num_tokens): + slot = slot_mapping[token_idx].item() + if slot < 0: continue + block_idx = slot // block_size + block_offset = slot % block_size + kv_c_i = kv_c[token_idx].view(4,kv_lora_rank//4).contiguous() + kv_c_i_int8 = torch.zeros( + kv_c_i.shape, + device=kv_c.device, + dtype=torch.int8, + ) + kv_c_i_scale = torch.zeros( + [kv_c_i.shape[0], 1], + device=kv_c.device, + dtype=torch.float32, + ) + torch.ops._C.quant2d(kv_c_i, kv_c_i_int8, kv_c_i_scale, force_sdnn=True) + kv_c_i_scale /= 127 + kv_cache[block_idx, block_offset, 
:kv_lora_rank] = kv_c_i_int8.view(-1).view(torch.uint8).contiguous() + kv_cache[block_idx, block_offset, kv_lora_rank:kv_lora_rank + 16] = kv_c_i_scale.view(-1).view(torch.uint8).contiguous() + kv_cache[block_idx, block_offset, kv_lora_rank+16:] = k_pe[token_idx, :].view(torch.uint8).contiguous() + + def kunlun_mla(): + for token_idx in range(num_tokens): + slot = slot_mapping[token_idx].item() + if slot < 0: continue + block_idx = slot // block_size + block_offset = slot % block_size + kv_cache[block_idx, block_offset, :kv_lora_rank] = kv_c[token_idx, :].contiguous() + kv_cache[block_idx, block_offset, kv_lora_rank:] = k_pe[token_idx, :].contiguous() + + if (kv_cache_dtype == "fp8_ds_mla"): + assert kv_lora_rank == 512, "kv_lora_rank must be 512 for fp8_ds_mla" + assert pe_dim == 64, "pe_dim must be 64 for fp8_ds_mla" + assert kv_cache.shape[2] == 656 // kv_cache.element_size(), "kv_cache.shape[2] must be 656 bytes for fp8_ds_mla" + assert kv_c.element_size() == 2, "kv_c.element_size() must be 2 for fp8_ds_mla" + assert k_pe.element_size() == 2, "k_pe.element_size() must be 2 for fp8_ds_mla" + kunlun_fp8_ds_mla() + else: + assert kv_cache.shape[2] == kv_lora_rank + pe_dim + kunlun_mla() + + +@dataclass +class FlashMLASparseMetadataBuilder( + AttentionMetadataBuilder[FlashMLASparseMetadata]): + cudagraph_support: ClassVar[AttentionCGSupport] = \ + AttentionCGSupport.UNIFORM_BATCH + + def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str], + vllm_config: VllmConfig, device: torch.device): + self.vllm_config = vllm_config + self.layer_names = layer_names + cache_config = vllm_config.cache_config + self.kv_cache_spec = kv_cache_spec + self.model_config = vllm_config.model_config + parallel_config = vllm_config.parallel_config + self.device = device + + # Treat requests with query length <= 1 as decodes to match the + # DeepGEMM indexer constraint (fp8_paged_mqa_logits only supports next_n <= 2) + # 从最新版本vllm中引入的 + self._init_reorder_batch_threshold(1, 
supports_spec_as_decode=True) + + props = torch.cuda.get_device_properties(device) + sm_count = props.multi_processor_count + + self.num_heads = self.model_config.get_num_attention_heads( + parallel_config) + self.mla_dims = get_mla_dims(self.model_config) + self.topk_tokens = vllm_config.model_config.hf_config.index_topk + self.use_fp8_kv_cache = cache_config.cache_dtype == "fp8_ds_mla" + + self.topk_tokens_tensor = torch.tensor([self.topk_tokens], + device=device, + dtype=torch.int32) + # self.max_model_len_tensor = torch.tensor( + # [self.model_config.max_model_len], + # device=device, + # dtype=torch.int32) + + # this is ignored by `flash_mla_with_kvcache` if indices not None + self.dummy_block_table = torch.empty((1, 1), + dtype=torch.int32, + device=self.device) + + # Equation taken from FlashMLA/csrc/pybind.cpp + h_q, h_k = self.num_heads, 1 + s_q = 1 # inversely proportional to s_q, so s_q = 1 is the largest + max_num_sm_parts = int( + max((sm_count // 2) / h_k // (cdiv(h_q // h_k, 2 * 64) * s_q), 1)) + if current_platform.is_device_capability(100): + max_num_sm_parts *= 2 + self.tile_scheduler_metadata_buffer = torch.zeros( + # TileSchedulerMetaDataSize = 8 + # see: FlashMLA/csrc/params.h + (max_num_sm_parts, 8), + dtype=torch.int32, + device=device) + self.num_splits_buffer = torch.zeros( + # We pack all the tokens into one batch for sparse attention. + # Otherwise, we can exceed the sm of `get_mla_metadata`. 
+ ( + 2, ), + dtype=torch.int32, + device=device) + self.req_id_per_token_buffer = torch.zeros( + (vllm_config.scheduler_config.max_num_batched_tokens, ), + dtype=torch.int32, + device=device) + def build(self, + common_prefix_len: int, + common_attn_metadata: CommonAttentionMetadata, + fast_build: bool = False) -> FlashMLASparseMetadata: + + num_tokens = common_attn_metadata.num_actual_tokens + starts = np.asarray(common_attn_metadata.query_start_loc_cpu, + dtype=np.int32) + seg_lengths = np.diff(starts) + req_id_per_token = np.repeat( + np.arange(seg_lengths.shape[0], dtype=np.int32), seg_lengths) + # Zero-fill for cudagraphs + self.req_id_per_token_buffer.fill_(0) + self.req_id_per_token_buffer[:req_id_per_token.shape[0]]\ + .copy_(torch.from_numpy(req_id_per_token), non_blocking=True) + req_id_per_token = self.req_id_per_token_buffer[:num_tokens] + + fp8_extra_metadata = None + + if self.use_fp8_kv_cache: + cache_seqlens_cpu, cache_seqlens = get_mla_metadata( + cache_seqlens=self.topk_tokens_tensor, + ) + fp8_extra_metadata = FlashMLASparseMetadata.FP8KernelMetadata( + scheduler_metadata=None, + num_splits=None, + # cache_lens and block_table are basically unused in sparse case + # but the decode kernel will treat -1 and indices >= cache_lens + # as invalid so we make sure cache_lens is large enough to not + # accidentally mark indices invalid, we will use -1 exclusively + # to mark invalid indices + cache_lens=cache_seqlens_cpu, + dummy_block_table=self.dummy_block_table) + + (num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens) = ( + split_decodes_and_prefills( + common_attn_metadata, + decode_threshold=self.reorder_batch_threshold or 1, + require_uniform=True, + ) + ) + + # For pure decode batches, prefill_request_id will be None + # For mixed batches, it will have -1 for decode and request_id for prefill + prefill_metadata = None + if num_prefills > 0: + prefill_metadata = MLASparsePrefillMetadata( + query_start_loc = 
common_attn_metadata.query_start_loc[num_decodes:] - common_attn_metadata.query_start_loc[num_decodes], # Because prefill and decode requests are separated, q must be split, so this value needs adjusting + query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu[num_decodes:] - common_attn_metadata.query_start_loc_cpu[num_decodes], + ) + + decode_metadata = None + if num_decodes > 0: + max_seq_len = int(common_attn_metadata.seq_lens_cpu[:num_decodes].max()) + + decode_metadata = FlashMLASparseDecodeAndContextMetadata( + max_seq_len=max_seq_len, + seq_lens=common_attn_metadata.seq_lens[:num_decodes], + seq_lens_cpu=common_attn_metadata.seq_lens_cpu[:num_decodes], + ) + + + metadata = FlashMLASparseMetadata( + num_reqs=common_attn_metadata.num_reqs, + max_query_len=common_attn_metadata.max_query_len, + max_seq_len=common_attn_metadata.max_seq_len, + num_actual_tokens=common_attn_metadata.num_actual_tokens, + query_start_loc=common_attn_metadata.query_start_loc, + slot_mapping=common_attn_metadata.slot_mapping, + block_table=common_attn_metadata.block_table_tensor, + req_id_per_token=req_id_per_token, + block_size=self.kv_cache_spec.block_size, + topk_tokens=self.topk_tokens, + fp8_extra_metadata=fp8_extra_metadata, + num_prefills=num_prefills, + num_decodes=num_decodes, + num_prefill_tokens=num_prefill_tokens, + num_decode_tokens=num_decode_tokens, + decode_metadata=decode_metadata, + prefill_metadata=prefill_metadata + ) + return metadata + + +class FlashMLASparseImpl(MLACommonBaseImpl[FlashMLASparseMetadata]): + + def __init__( + self, + num_heads: int, + head_size: int, + scale: float, + num_kv_heads: int, + alibi_slopes: Optional[list[float]], + sliding_window: Optional[int], + kv_cache_dtype: str, + logits_soft_cap: Optional[float], + attn_type: str, + kv_sharing_target_layer_name: Optional[str], + # MLA Specific Arguments + topk_indice_buffer: Optional[torch.Tensor] = None, + indexer: Optional["Indexer"] = None, + **mla_args) -> None: + super().__init__(num_heads, head_size, scale, num_kv_heads, + alibi_slopes, 
sliding_window, kv_cache_dtype, + logits_soft_cap, attn_type, + kv_sharing_target_layer_name, **mla_args) + self.softmax_scale = scale + assert indexer is not None + self.topk_indices_buffer = indexer.topk_indices_buffer + self.padding = 128 if current_platform.is_device_capability( + 100) else 64 + + def _forward_bf16_kv( + self, q: torch.Tensor, kv_c_and_k_pe_cache: torch.Tensor, + topk_indices: torch.Tensor, + attn_metadata: FlashMLASparseMetadata) -> torch.Tensor: + + num_tokens = q.shape[0] + kv_c_and_k_pe_cache = kv_c_and_k_pe_cache.contiguous().view( + -1, kv_c_and_k_pe_cache.shape[-1]) + + # num_decode_tokens = attn_metadata.num_decode_tokens + num_prefill_tokens = attn_metadata.num_prefill_tokens + num_decodes = attn_metadata.num_decodes + + has_decode = attn_metadata.num_decodes > 0 + has_prefill = attn_metadata.num_prefills > 0 + num_decode_tokens = attn_metadata.num_decode_tokens + + def _bf16_decode(q: torch.Tensor, topk_indices: torch.Tensor) -> torch.Tensor: + # Reshape q: (num_decode_tokens, num_heads, head_dim) + # -> (num_decodes, seq_len, num_heads, head_dim) + q = reshape_query_for_spec_decode(q, num_decodes) + seq_len = q.shape[1] + # Reshape topk_indices: (num_decode_tokens, topk) + # -> (num_decodes, seq_len, topk) + topk_indices = topk_indices.view(num_decodes, seq_len, -1) + decode_metadata = attn_metadata.decode_metadata + _attn_out, _, _ = kunlun_flash_mla_with_kvcache( + q=q, + k_cache=kv_c_and_k_pe_cache, + head_dim_v=512, + cache_seqlens=decode_metadata.seq_lens, + cache_seqlens_cpu=decode_metadata.seq_lens_cpu, + is_fp8_kvcache=False, + indices=topk_indices, + softmax_scale=self.softmax_scale, + max_seq_kv=decode_metadata.max_seq_len + ) + # Reshape output: (num_decodes, seq_len, num_heads, head_dim_v) + # -> (num_decode_tokens, num_heads, head_dim_v) + return reshape_attn_output_for_spec_decode(_attn_out) + + def _bf16_prefill(q: torch.Tensor, topk_indices: torch.Tensor) -> torch.Tensor: + prefill_metadata = 
attn_metadata.prefill_metadata + topk_indices = topk_indices.view(num_prefill_tokens, 1, -1) + # NOTE: only during the prefill phase does attn_metadata.query_start_loc satisfy the KLX kernel's requirements + _attn_out = flash_mla_sparse_prefill( + q=q, + kv=kv_c_and_k_pe_cache, + indices=topk_indices, + sm_scale=self.softmax_scale, + q_lod_xpu=prefill_metadata.query_start_loc, + q_lod_cpu=prefill_metadata.query_start_loc_cpu + )[0] + return _attn_out + + topk_indices_global = torch.ops.xspeedgate_ops.convert_req_index_to_global_index( + req_id=attn_metadata.req_id_per_token, + block_table=attn_metadata.block_table, + token_indices=topk_indices, + block_size=attn_metadata.block_size, + num_topk_tokens=attn_metadata.topk_tokens, + ) + + attn_out = torch.empty( + (num_tokens, self.num_heads, self.kv_lora_rank), + dtype=q.dtype, + device=q.device, + ) + if has_prefill: + prefill_q = q[num_decode_tokens:] + prefill_topk_indices_global = topk_indices_global[num_decode_tokens:] + attn_out[num_decode_tokens:] = _bf16_prefill(prefill_q, prefill_topk_indices_global) + + # Handle the decode portion - the correct block-table mapping is required + if has_decode: + decode_q = q[:num_decode_tokens] + decode_topk_indices_global = topk_indices_global[:num_decode_tokens] + attn_out[:num_decode_tokens] = _bf16_decode(decode_q, decode_topk_indices_global) + + return attn_out + + + def _forward_fp8_kv(self, q: torch.Tensor, + kv_c_and_k_pe_cache: torch.Tensor, + topk_indices: torch.Tensor, + attn_metadata: FlashMLASparseMetadata) -> torch.Tensor: + # TODO: When fwd_kvcache_mla supports uint8 kv cache, execute this function. 
+ assert attn_metadata.fp8_extra_metadata is not None + extra_metadata = attn_metadata.fp8_extra_metadata + + _attn_out, _ = flash_mla_with_kvcache( + q=q.unsqueeze(0), # unsqueeze to add batch_dim + k_cache=kv_c_and_k_pe_cache, + block_table=extra_metadata.dummy_block_table, + head_dim_v=512, + cache_seqlens=extra_metadata.cache_lens, + tile_scheduler_metadata=extra_metadata.scheduler_metadata, # None + num_splits=extra_metadata.num_splits, # None + is_fp8_kvcache=True, + indices=topk_indices.unsqueeze(0), # unsqueeze to add batch_dim + softmax_scale=self.softmax_scale, + max_seq_kv=attn_metadata.max_seq_len + ) + + return _attn_out + + def forward( + self, + layer: AttentionLayer, + q: torch.Tensor, + k_c_normed: torch.Tensor, # key in unified attn + k_pe: torch.Tensor, # value in unified attn + kv_cache: torch.Tensor, + attn_metadata: FlashMLASparseMetadata, + output: Optional[torch.Tensor] = None, + output_scale: Optional[torch.Tensor] = None, + output_block_scale: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + # NOTE(lucas): for the sparse FlashMLA kernels the kernels want to use + # MQA 576/512 approach for both prefill and decode + + assert output is not None, "Output tensor must be provided." + + if output_scale is not None or output_block_scale is not None: + raise NotImplementedError( + "fused output quantization is not yet supported" + " for MLACommonImpl") + + if attn_metadata is None: + # The zero fill is required when used with DP + EP + # to ensure all ranks within a DP group compute the + # same expert outputs. + return output.fill_(0) + + num_actual_toks = attn_metadata.num_actual_tokens + + # Inputs and outputs may be padded for CUDA graphs + + q = q[:num_actual_toks, ...] + k_c_normed = k_c_normed[:num_actual_toks, ...] + k_pe = k_pe[:num_actual_toks, ...] 
+ + q_nope, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], + dim=-1) + # Convert from (B, N, P) to (N, B, P) + q_nope = q_nope.transpose(0, 1) + # Multiply (N, B, P) x (N, P, L) -> (N, B, L) + ql_nope = torch.bmm(q_nope, self.W_UK_T) + # Convert from (N, B, L) to (B, N, L) + ql_nope = ql_nope.transpose(0, 1) + + topk_indices = self.topk_indices_buffer[:num_actual_toks] + + q = torch.cat([ql_nope, q_pe], dim=-1) + + # write the latent and rope to kv cache + if kv_cache.numel() > 0: + torch.ops._C.concat_and_cache_mla( + kv_c=k_c_normed, + k_pe=k_pe.squeeze(1), + kv_cache=kv_cache, + slot_mapping=attn_metadata.slot_mapping.flatten(), + ) + + if self.kv_cache_dtype != "fp8_ds_mla": + attn_out = self._forward_bf16_kv(q, kv_cache, topk_indices, + attn_metadata) + else: + # attn_out = self._forward_fp8_kv(q, kv_cache, topk_indices_global, + # attn_metadata) + raise NotImplementedError + + self._v_up_proj(attn_out, out=output[:num_actual_toks]) + return output diff --git a/vllm_kunlun/v1/attention/backends/mla/indexer.py b/vllm_kunlun/v1/attention/backends/mla/indexer.py new file mode 100644 index 0000000..67471be --- /dev/null +++ b/vllm_kunlun/v1/attention/backends/mla/indexer.py @@ -0,0 +1,133 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import torch +from dataclasses import dataclass +from typing import ClassVar, Optional +from vllm.logger import init_logger +from vllm.v1.attention.backends.utils import ( CommonAttentionMetadata, + split_decodes_and_prefills) +from vllm.v1.attention.backends.mla.indexer import (split_prefill_chunks, + DeepseekV32IndexerMetadataBuilder, + DeepseekV32IndexerPrefillMetadata) + +logger = init_logger(__name__) + +@dataclass +class DeepSeekV32IndexerDecodeMetadata: + block_table: torch.Tensor + seq_lens: torch.Tensor + seq_lens_cpu: torch.Tensor + decode_lens: torch.Tensor + requires_padding: bool + schedule_metadata: torch.Tensor + + +@dataclass +class 
DeepseekV32IndexerMetadata: + + # FIXME (zyongye) + # hacky way to access the data now, need to be in chunked meta + seq_lens: torch.Tensor + seq_lens_cpu: torch.Tensor + + num_reqs: int + max_query_len: int + max_seq_len: int + + num_actual_tokens: int # Number of tokens excluding padding. + query_start_loc: torch.Tensor + slot_mapping: torch.Tensor + # The dimension of the attention heads + head_dim: int + + # New for MLA (compared to FlashAttention) + # For handling prefill decode split + num_decodes: int + num_decode_tokens: int + num_prefills: int + num_prefill_tokens: int + + decode: Optional[DeepSeekV32IndexerDecodeMetadata] = None + prefill: Optional[DeepseekV32IndexerPrefillMetadata] = None + +def kunlun_build(self, + common_prefix_len: int, + common_attn_metadata: CommonAttentionMetadata, + fast_build: bool = False) -> DeepseekV32IndexerMetadata: + + num_reqs = common_attn_metadata.num_reqs + num_tokens = common_attn_metadata.num_actual_tokens + + query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu + num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = \ + split_decodes_and_prefills( + common_attn_metadata, + decode_threshold=self.reorder_batch_threshold) + + assert num_decodes + num_prefills == num_reqs + assert num_decode_tokens + num_prefill_tokens == num_tokens + + prefill_metadata = None + if num_prefills > 0: + chunk_seq_ids = split_prefill_chunks( + common_attn_metadata.seq_lens_cpu, + self.max_prefill_buffer_size, + num_decodes, + ) + chunks = [ + self.build_one_prefill_chunk( + reqs_start, reqs_end, query_start_loc_cpu, + common_attn_metadata.seq_lens_cpu, + common_attn_metadata.block_table_tensor) + for reqs_start, reqs_end in chunk_seq_ids + ] + prefill_metadata = DeepseekV32IndexerPrefillMetadata( + chunks=chunks, ) + + decode_metadata = None + if num_decodes > 0: + torch.diff(common_attn_metadata.query_start_loc[:num_decodes + 1], + out=self.decode_lens_buffer[:num_decodes]) + decode_lens = 
self.decode_lens_buffer[:num_decodes] + decode_lens_cpu = torch.diff( + common_attn_metadata.query_start_loc_cpu[:num_decodes + 1]) + + # Use CPU to avoid GPU sync; breaking async scheduling + requires_padding = (decode_lens_cpu.max() + > decode_lens_cpu.min()).item() + + seq_lens = common_attn_metadata.seq_lens[:num_decodes] + + decode_metadata = DeepSeekV32IndexerDecodeMetadata( + block_table=common_attn_metadata. + block_table_tensor[:num_decodes, ...], + seq_lens=common_attn_metadata.seq_lens[:num_decodes], + seq_lens_cpu=common_attn_metadata.seq_lens[:num_decodes].cpu(), + decode_lens=decode_lens, + requires_padding=requires_padding, + schedule_metadata=self.scheduler_metadata_buffer, + ) + + attn_metadata = DeepseekV32IndexerMetadata( + seq_lens=common_attn_metadata.seq_lens, + seq_lens_cpu=common_attn_metadata.seq_lens.cpu(), + num_reqs=common_attn_metadata.num_reqs, + max_query_len=common_attn_metadata.max_query_len, + max_seq_len=common_attn_metadata.max_seq_len, + num_actual_tokens=common_attn_metadata.num_actual_tokens, + query_start_loc=common_attn_metadata.query_start_loc, + slot_mapping=common_attn_metadata.slot_mapping, + head_dim=128, + num_decodes=num_decodes, + num_decode_tokens=num_decode_tokens, + num_prefills=num_prefills, + num_prefill_tokens=num_prefill_tokens, + prefill=prefill_metadata, + decode=decode_metadata, + ) + + # if get_tensor_model_parallel_rank() == 0: + # logger.info(f"attn_metadata: {attn_metadata}") + return attn_metadata + +DeepseekV32IndexerMetadataBuilder.build = kunlun_build \ No newline at end of file diff --git a/vllm_kunlun/v1/sample/ops/topk_topp_sampler.py b/vllm_kunlun/v1/sample/ops/topk_topp_sampler.py index 08a33f9..b080524 100644 --- a/vllm_kunlun/v1/sample/ops/topk_topp_sampler.py +++ b/vllm_kunlun/v1/sample/ops/topk_topp_sampler.py @@ -2,7 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from typing import Optional -import os + import torch import torch.nn as nn from packaging import 
version @@ -24,6 +24,7 @@ class TopKTopPSampler(nn.Module): def __init__(self, logprobs_mode): super().__init__() + self.logprobs_mode = logprobs_mode logger.info_once( "Using FlashInfer for top-p & top-k sampling.") self.forward = self.forward_kunlun @@ -40,9 +41,14 @@ class TopKTopPSampler(nn.Module): The logits tensor may be updated in-place. """ - logits = apply_top_k_top_p(logits, k, p) + logits = self.apply_top_k_top_p(logits, k, p) + logits_to_return = None + if self.logprobs_mode == "processed_logits": + logits_to_return = logits + elif self.logprobs_mode == "processed_logprobs": + logits_to_return = logits.log_softmax(dim=-1, dtype=torch.float32) probs = logits.softmax(dim=-1, dtype=torch.float32) - return random_sample(probs, generators), None + return random_sample(probs, generators), logits_to_return def forward_kunlun( self, @@ -52,16 +58,13 @@ class TopKTopPSampler(nn.Module): p: Optional[torch.Tensor], ) -> torch.Tensor: """More optimized implementation for top-k and top-p sampling.""" - if k is None and p is None: - # We prefer `random_sample` over `flashinfer_sample` when sorting is - # not needed. This is because `random_sample` does not require - # CPU-GPU synchronization while `flashinfer_sample` does. - probs = logits.softmax(dim=-1, dtype=torch.float32) - return random_sample(probs, generators), None - if generators: - logger.warning_once("FlashInfer 0.2.3+ does not support " - "per-request generators. Falling back to " - "PyTorch-native implementation.") + if (k is None and p is None) or generators: + if generators: + logger.debug_once( + "FlashInfer 0.2.3+ does not support " + "per-request generators. Falling back to " + "PyTorch-native implementation." + ) return self.forward_native(logits, generators, k, p) # flashinfer sampling functions expect contiguous logits. 
# In flex_attn/triton_attn fp32 inference, logits can be non-contiguous @@ -196,6 +199,7 @@ def flashinfer_sample( probs, top_k=k, deterministic=True) else: # Both top-k and top-p. + k = k.to(torch.int32) next_token_ids = xtorch_ops.top_k_top_p_sampling_from_probs( probs, top_k=k, top_p=p, deterministic=True) diff --git a/vllm_kunlun/v1/worker/utils.py b/vllm_kunlun/v1/worker/utils.py new file mode 100644 index 0000000..e31f887 --- /dev/null +++ b/vllm_kunlun/v1/worker/utils.py @@ -0,0 +1,344 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from collections import defaultdict +from dataclasses import dataclass +from typing import TYPE_CHECKING, Optional + +import torch + +from vllm.attention.backends.abstract import AttentionBackend +from vllm.config import ModelConfig, SchedulerConfig, VllmConfig +from vllm.model_executor.models.interfaces import MultiModalEmbeddings +from vllm.model_executor.models.utils import extract_layer_index +from vllm.multimodal.cache import processor_only_cache_from_config +from vllm.multimodal.registry import MultiModalRegistry +from vllm.platforms import current_platform +from vllm.v1.attention.backends.utils import AttentionMetadataBuilder +from vllm.v1.core.encoder_cache_manager import compute_mm_encoder_budget +from vllm.v1.kv_cache_interface import KVCacheGroupSpec, KVCacheSpec + +if TYPE_CHECKING: + from vllm.attention.layer import Attention + + +class MultiModalBudget: + """Helper class to calculate budget information for multi-modal models.""" + + def __init__( + self, + model_config: ModelConfig, + scheduler_config: SchedulerConfig, + mm_registry: MultiModalRegistry, + ) -> None: + super().__init__() + + self.model_config = model_config + self.scheduler_config = scheduler_config + self.mm_registry = mm_registry + self.cache = cache = processor_only_cache_from_config( + model_config, mm_registry) + + self.max_model_len = model_config.max_model_len + self.max_num_reqs 
= scheduler_config.max_num_seqs + + self.mm_limits = mm_registry.get_mm_limits_per_prompt(model_config, + cache=cache) + + max_tokens_by_modality = mm_registry \ + .get_max_tokens_per_item_by_nonzero_modality(model_config, + cache=cache) + + encoder_compute_budget, encoder_cache_size = compute_mm_encoder_budget( + scheduler_config, + max_tokens_by_modality, + ) + + self.encoder_compute_budget = encoder_compute_budget + self.encoder_cache_size = encoder_cache_size + + max_items_per_prompt_by_modality = dict[str, int]() + max_items_per_batch_by_modality = dict[str, int]() + + for modality, max_tokens in max_tokens_by_modality.items(): + ( + max_items_per_prompt, + max_items_per_batch, + ) = self.get_max_items(modality, max_tokens) + + max_items_per_prompt_by_modality[modality] = max_items_per_prompt + max_items_per_batch_by_modality[modality] = max_items_per_batch + + self.max_tokens_by_modality = max_tokens_by_modality + self.max_items_per_prompt_by_modality = max_items_per_prompt_by_modality + self.max_items_per_batch_by_modality = max_items_per_batch_by_modality + + def get_modality_with_max_tokens(self) -> str: + max_tokens_by_modality = self.max_tokens_by_modality + modality, _ = max(max_tokens_by_modality.items(), key=lambda x: x[1]) + + return modality + + def get_encoder_budget(self) -> int: + return min(self.encoder_compute_budget, self.encoder_cache_size) + + def get_max_items( + self, + modality: str, + max_tokens_per_item: int, + ) -> tuple[int, int]: + if max_tokens_per_item == 0: + return 0, 0 + + # Check how many items of this modality can be supported by + # the encoder budget. + encoder_budget = self.get_encoder_budget() + + # TODO: handle encoder-decoder models once we support them. + if encoder_budget == 0: + return 0, 0 + + max_encoder_items_per_batch = encoder_budget // max_tokens_per_item + + # Check how many items of this modality can be supported by + # the decoder budget. 
+ mm_limit = self.mm_limits[modality] + + max_items_per_prompt = max( + 1, + min(mm_limit, self.max_model_len // max_tokens_per_item), + ) + + scheduler_config = self.scheduler_config + max_num_reqs = self.max_num_reqs + + if not scheduler_config.enable_chunked_prefill: + max_num_reqs = min( + max_num_reqs, + scheduler_config.max_num_batched_tokens // max_tokens_per_item, + ) + + max_decoder_items_per_batch = max_num_reqs * max_items_per_prompt + + max_items_per_batch = max( + 1, + min(max_encoder_items_per_batch, max_decoder_items_per_batch), + ) + + return max_items_per_prompt, max_items_per_batch + + +@dataclass +class AttentionGroup: + backend: type[AttentionBackend] + # When ubatching is enabled we will have a metadata builder for each ubatch + # so that if they use internal persistent buffers for cudagraphs, they + # won't have to worry about conflicting with the other ubatches. + metadata_builders: list[AttentionMetadataBuilder] + layer_names: list[str] + kv_cache_spec: KVCacheSpec + + @staticmethod + def create_with_metadata_builders( + backend: type[AttentionBackend], + layer_names: list[str], + kv_cache_spec: KVCacheSpec, + vllm_config: VllmConfig, + device: torch.device, + num_metadata_builders: int = 1, + ) -> 'AttentionGroup': + metadata_builders = [ + backend.get_builder_cls()(kv_cache_spec, layer_names, vllm_config, + device) + for _ in range(num_metadata_builders) + ] + return AttentionGroup(backend, metadata_builders, layer_names, + kv_cache_spec) + + def get_metadata_builder(self, + ubatch_id: int = 0) -> AttentionMetadataBuilder: + assert len(self.metadata_builders) > ubatch_id + return self.metadata_builders[ubatch_id] + + +def sanity_check_mm_encoder_outputs( + mm_embeddings: MultiModalEmbeddings, + expected_num_items: int, +) -> None: + """ + Perform sanity checks for the result of + [`vllm.model_executor.models.SupportsMultiModal.get_multimodal_embeddings`][]. 
+ """ + assert isinstance(mm_embeddings, (list, tuple, torch.Tensor)), ( + "Expected multimodal embeddings to be a list/tuple of 2D tensors, " + f"or a single 3D tensor, but got {type(mm_embeddings)} " + "instead. This is most likely due to incorrect implementation " + "of the model's `get_multimodal_embeddings` method.") + + assert len(mm_embeddings) == expected_num_items, ( + "Expected number of multimodal embeddings to match number of " + f"input items: {expected_num_items}, but got {len(mm_embeddings)=} " + "instead. This is most likely due to incorrect implementation " + "of the model's `get_multimodal_embeddings` method.") + + assert all(e.ndim == 2 for e in mm_embeddings), ( + "Expected multimodal embeddings to be a sequence of 2D tensors, " + f"but got tensors with shapes {[e.shape for e in mm_embeddings]} " + "instead. This is most likely due to incorrect implementation " + "of the model's `get_multimodal_embeddings` method.") + + +def scatter_mm_placeholders( + embeds: torch.Tensor, + is_embed: Optional[torch.Tensor], +) -> torch.Tensor: + """ + Scatter the multimodal embeddings into a contiguous tensor that represents + the placeholder tokens. + + [`vllm.multimodal.processing.PromptUpdateDetails.is_embed`][]. + + Args: + embeds: The multimodal embeddings. + Shape: `(num_embeds, embed_dim)` + is_embed: A boolean mask indicating which positions in the placeholder + tokens need to be filled with multimodal embeddings. + Shape: `(num_placeholders, num_embeds)` + """ + if is_embed is None: + return embeds + + placeholders = embeds.new_full( + (is_embed.shape[0], embeds.shape[-1]), + fill_value=torch.nan, + ) + placeholders[is_embed] = embeds + return placeholders + + +def gather_mm_placeholders( + placeholders: torch.Tensor, + is_embed: Optional[torch.Tensor], +) -> torch.Tensor: + """ + Reconstructs the embeddings from the placeholder tokens. + + This is the operation of [`scatter_mm_placeholders`] + [vllm.v1.worker.utils.scatter_mm_placeholders]. 
+ """ + if is_embed is None: + return placeholders + + return placeholders[is_embed] + + +def add_kv_sharing_layers_to_kv_cache_groups( + shared_kv_cache_layers: dict[str, str], + kv_cache_groups: list[KVCacheGroupSpec], + runner_only_attn_layers: Optional[set[str]] = None, +) -> None: + """ + Sets up KV cache sharing by reusing the allocated KV caches in `kv_caches` + for layers that do not allocate its own KV cache, based on the mapping in + `shared_kv_cache_layers`. Adds these layers to the corresponding KV cache + group, which is needed to ensure that attention metadata is assigned later. + + Args: + shared_kv_cache_layers: Layer pairings for cross-layer KV sharing. + If an Attention layer `layer_name` is in the keys of this dict, it + means this layer will perform attention using the keys and values + from the KV cache of `shared_kv_cache_layers[layer_name]`. + kv_cache_groups: The KV cache groups of the model. + """ + layer_to_kv_cache_group: dict[str, KVCacheGroupSpec] = {} + for kv_cache_group in kv_cache_groups: + for layer_name in kv_cache_group.layer_names: + layer_to_kv_cache_group[layer_name] = kv_cache_group + + for layer_name, target_layer_name in shared_kv_cache_layers.items(): + tgt_kv_cache_group = layer_to_kv_cache_group[target_layer_name] + tgt_kv_cache_group.layer_names.append(layer_name) + + if runner_only_attn_layers is not None: + runner_only_attn_layers.add(layer_name) + + +def bind_kv_cache( + kv_caches: dict[str, torch.Tensor], + forward_context: dict[str, "Attention"], + runner_kv_caches: list[torch.Tensor], + num_attn_module: Optional[int] = 1, +) -> None: + """ + Bind the allocated KV cache to both ModelRunner and forward context so + that the KV cache can be used in the forward pass. + + This function: + 1) Fills the ModelRunner's kv cache list (`runner_kv_caches`) with + kv_caches. + 2) Associates each attention layer in the `forward_context` with its + corresponding KV cache in kv_caches. 
+ + Args: + kv_caches: The allocated kv_caches with layer names as keys. + forward_context: The global forward context containing all Attention + layers with layer names as keys. + runner_kv_caches: The kv_cache declared by ModelRunner. + """ + # Bind kv_caches to ModelRunner + assert len(runner_kv_caches) == 0 + + # Convert kv_caches dict to a list of tensors in the order of layer_index. + index2name = defaultdict(list) + for layer_name in kv_caches: + index2name[extract_layer_index(layer_name, + num_attn_module)].append(layer_name) + + for layer_index in sorted(index2name.keys()): + layer_names = index2name[layer_index] + if len(layer_names) > 1: + # One typical case is encoder-decoder model, e.g., bart. + # The cross attention and self attention in the same decoder layer + # has different layer_name but the same layer_index. + + # TODO - analyze where runner_kv_caches is used and the right + # way to ensure it properly reflects multiple attention layers + # in the same decoder block. + if current_platform.is_kunlun() or current_platform.is_cuda() or current_platform.is_xpu(): + # We know that the GPU runner is not impacted by this + # case. Some test code depends on runner_kv_caches, but + # not in a way that's impacted by ignoring this. + pass + else: + raise NotImplementedError + layer_name = layer_names[0] + runner_kv_caches.append(kv_caches[layer_name]) + + # Bind kv_caches to forward context + for layer_name, kv_cache in kv_caches.items(): + # NOTE: Use list because of v0 PP virtual engine. + forward_context[layer_name].kv_cache = [kv_cache] + + +def is_residual_scattered_for_sp(vllm_config: VllmConfig, + num_input_tokens: int) -> bool: + """Check if the residual tensor is scattered for sequence parallelism. + + The residual tensor is scattered across tensor parallel ranks when sequence + parallelism and tensor parallelism is enabled, and the number of + input tokens is one of the compilation sizes. 
+ """ + if not vllm_config.compilation_config.pass_config.\ + enable_sequence_parallelism: + return False + + tp = vllm_config.parallel_config.tensor_parallel_size + + if tp == 1: + return False + + # When sequence parallelism is enabled, we always pad num_input_tokens + # to be a multiple of tensor_parallel_size (tp) earlier. + assert num_input_tokens % tp == 0 + + # Currently, SP is only enabled for static size fx graphs. + return (num_input_tokens in vllm_config.compilation_config.compile_sizes) diff --git a/vllm_kunlun/vllm_utils_wrapper.py b/vllm_kunlun/vllm_utils_wrapper.py index 7425ac1..c323aad 100644 --- a/vllm_kunlun/vllm_utils_wrapper.py +++ b/vllm_kunlun/vllm_utils_wrapper.py @@ -1493,4 +1493,45 @@ def _fake_gptq_shuffle( return None -gptq_shuffle.register_fake(_fake_gptq_shuffle) \ No newline at end of file +gptq_shuffle.register_fake(_fake_gptq_shuffle) + +################################################## +# ---------------- concat_and_cache_mla ------------------ +################################################## +@custom_op("_C::concat_and_cache_mla", mutates_args=()) +def concat_and_cache_mla( + kv_c: torch.Tensor, #[num_tokens, kv_lora_rank] + k_pe: torch.Tensor, #[num_tokens, pe_dim] + kv_cache: torch.Tensor, #[num_blocks, block_size, (kv_lora_rank + pe_dim)] + slot_mapping: torch.Tensor, #[num_tokens] or [num_actual_tokens] +) -> None: + xtorch_ops.concat_and_cache_mla( + kv_c=kv_c, + k_pe=k_pe, + slot_mapping=slot_mapping, + kv_cache=kv_cache, + ) + +@impl("_C::concat_and_cache_mla", "CUDA") +def concat_and_cache_mla_cuda( + kv_c: torch.Tensor, #[num_tokens, kv_lora_rank] + k_pe: torch.Tensor, #[num_tokens, pe_dim] + kv_cache: torch.Tensor, #[num_blocks, block_size, (kv_lora_rank + pe_dim)] + slot_mapping: torch.Tensor, #[num_tokens] or [num_actual_tokens] +) -> None: + xtorch_ops.concat_and_cache_mla( + kv_c=kv_c, + k_pe=k_pe, + slot_mapping=slot_mapping, + kv_cache=kv_cache, + ) + +def _fake_concat_and_cache_mla( + kv_c: torch.Tensor, 
#[num_tokens, kv_lora_rank] + k_pe: torch.Tensor, #[num_tokens, pe_dim] + kv_cache: torch.Tensor, #[num_blocks, block_size, (kv_lora_rank + pe_dim)] + slot_mapping: torch.Tensor, #[num_tokens] or [num_actual_tokens] +) -> None: + return None + +concat_and_cache_mla.register_fake(_fake_concat_and_cache_mla) \ No newline at end of file