diff --git a/examples/offline_dualbatch_overlap_npu.py b/examples/offline_dualbatch_overlap_npu.py new file mode 100644 index 0000000..d8153e3 --- /dev/null +++ b/examples/offline_dualbatch_overlap_npu.py @@ -0,0 +1,51 @@ +import os +import time + +from vllm import LLM, SamplingParams + +# enable dual-batch overlap for vllm ascend +os.environ["VLLM_ASCEND_ENABLE_DBO"] = "1" +os.environ["VLLM_USE_V1"] = "1" + +# Sample prompts. +prompts = ["The president of the United States is"] * 41 +# Create a sampling params object. +sampling_params = SamplingParams(max_tokens=100, temperature=0.0) + + +def main(): + # Create an LLM. + llm = LLM(model="deepseek-ai/DeepSeek-V3-Lite-base-latest-w8a8-dynamic", + enforce_eager=True, + tensor_parallel_size=2, + max_model_len=4096, + trust_remote_code=True, + additional_config={ + "torchair_graph_config": { + "enabled": False + }, + "ascend_scheduler_config": { + "enabled": True + }, + "expert_tensor_parallel_size": 1 + }) + + # Generate texts from the prompts. The output is a list of RequestOutput + # objects that contain the prompt, generated text, and other information. + outputs = llm.generate(prompts, sampling_params) + + # Print the outputs. + print("-" * 50) + for output in outputs: + prompt = output.prompt + generated_text = output.outputs[0].text + print(f"Prompt: {prompt!r}\nGenerated text: {generated_text!r}") + print("-" * 50) + + # Add a buffer to wait for profiler in the background process + # (in case MP is on) to finish writing profiling output. + time.sleep(10) + + +if __name__ == "__main__": + main() diff --git a/tests/multicard/test_offline_inference_distributed.py b/tests/multicard/test_offline_inference_distributed.py index dd8da8c..50675cf 100644 --- a/tests/multicard/test_offline_inference_distributed.py +++ b/tests/multicard/test_offline_inference_distributed.py @@ -81,3 +81,17 @@ def test_models_distributed_topk() -> None: distributed_executor_backend="mp", ) as vllm_model: vllm_model.generate(example_prompts, sampling_params) + + +@patch.dict(os.environ, {"VLLM_ASCEND_ENABLE_DBO": "1"}) +def test_models_distributed_DeepSeek_dbo(): + example_prompts = ["The president of the United States is"] * 41 + dtype = "half" + sampling_params = SamplingParams(max_tokens=100, temperature=0.0) + with VllmRunner( + "deepseek-ai/DeepSeek-V2-Lite", + dtype=dtype, + tensor_parallel_size=4, + distributed_executor_backend="mp", + ) as vllm_model: + vllm_model.generate(example_prompts, sampling_params) diff --git a/vllm_ascend/attention/mla_v1.py b/vllm_ascend/attention/mla_v1.py index ae3dd62..91ddf43 100644 --- a/vllm_ascend/attention/mla_v1.py +++ b/vllm_ascend/attention/mla_v1.py @@ -13,6 +13,9 @@ from vllm.model_executor.layers.linear import (LinearBase, from vllm_ascend.ascend_config import get_ascend_config from vllm_ascend.attention.attention_v1 import AscendAttentionState +from vllm_ascend.multistream.base import MSAttentionMetadataSplitConfig +from vllm_ascend.multistream.context import get_multistream_comm_context +from vllm_ascend.multistream.ms_split import model_input_split_v1_mla_attn from vllm_ascend.ops.attention import vanilla_chunked_prefill_mla if TYPE_CHECKING: @@ -117,6 +120,7 @@ class AscendMLAMetadata: with_prefill_across_dp: bool = False + query_lens: Optional[list[int]] = None # The dimension of the attention heads head_dim: Optional[int] = None attn_mask: torch.Tensor = None @@ -135,6 +139,17 @@ class AscendMLAMetadata: # f"Only {supported_head_sizes} are supported for head_dim,", # f"received {self.head_dim}.") + def split_metadata_for_multistream( + self, + ms_split_config: MSAttentionMetadataSplitConfig, + ) -> list["AscendMLAMetadata"]: + """Split metadata for multi-stream with AscendMLAMetadata""" + return model_input_split_v1_mla_attn( + ms_split_config=ms_split_config, + attn_metadata=self, + _metadata_cls=AscendMLAMetadata, + ) + M = TypeVar("M", bound=AscendMLAMetadata) @@ -386,6 +401,7 @@ class AscendMLAMetadataBuilder: return self.metadata_cls( # type: ignore num_actual_tokens=num_actual_tokens, + query_lens=query_lens.tolist(), slot_mapping=slot_mapping, head_dim=self.runner.model_config.get_head_size(), num_decodes=self._num_decodes, @@ -585,7 +601,15 @@ class AscendMLAImpl(MLAAttentionImpl): ) attn_output = attn_output.reshape( [num_tokens, self.num_heads * self.v_head_dim]) - return self.o_proj(attn_output)[0] + + current_ms_metadata = get_multistream_comm_context() + if current_ms_metadata is None: + return self.o_proj(attn_output)[0] + else: + current_ms_metadata.before_comm_event.record() + with torch.npu.stream(current_ms_metadata.comm_stream): + current_ms_metadata.before_comm_event.wait() + return self.o_proj(attn_output)[0] def exec_kv( self, @@ -685,7 +709,14 @@ class AscendMLAImpl(MLAAttentionImpl): context_lens=attn_metadata.decode.seq_lens, # type:ignore mla_vheadsize=self.kv_lora_rank, out=attn_output) - return self._v_up_proj_and_o_proj(attn_output) + current_ms_metadata = get_multistream_comm_context() + if current_ms_metadata is None: + return self._v_up_proj_and_o_proj(attn_output) + else: + current_ms_metadata.before_comm_event.record() + with torch.npu.stream(current_ms_metadata.comm_stream): + current_ms_metadata.before_comm_event.wait() + return self._v_up_proj_and_o_proj(attn_output) def forward( self, @@ -811,16 +842,38 @@ class AscendMLAImpl(MLAAttentionImpl): key_cache=kv_cache, slot_indices=attn_metadata.slot_mapping.flatten()) if has_prefill: - output[num_decode_tokens:] = self._forward_prefill( - prefill_q, prefill_k_c_normed, prefill_k_pe, kv_cache, - attn_metadata) + # FIX: aicore move should be also placed on the comm stream in dbo, + # otherwise it may affect the accuracy + # TODO: use an elegant way to overlap + output_prefill = self._forward_prefill(prefill_q, + prefill_k_c_normed, + prefill_k_pe, kv_cache, + attn_metadata) + current_ms_metadata = get_multistream_comm_context() + if current_ms_metadata is not None: + with torch.npu.stream(current_ms_metadata.comm_stream): + output[num_decode_tokens:] = output_prefill + current_ms_metadata.after_comm_event.record() + else: + output[num_decode_tokens:] = output_prefill + if has_decode: if self.running_in_graph: return self._forward_decode(decode_ql_nope, decode_q_pe, decode_k_nope, decode_k_pe, kv_cache, attn_metadata) else: - output[:num_decode_tokens] = self._forward_decode( - decode_ql_nope, decode_q_pe, decode_k_nope, decode_k_pe, - kv_cache, attn_metadata) + output_decode = self._forward_decode(decode_ql_nope, + decode_q_pe, + decode_k_nope, + decode_k_pe, kv_cache, + attn_metadata) + current_ms_metadata = get_multistream_comm_context() + if current_ms_metadata is not None: + with torch.npu.stream(current_ms_metadata.comm_stream): + output[:num_decode_tokens] = output_decode + current_ms_metadata.after_comm_event.record() + else: + output[:num_decode_tokens] = output_decode + return output_padded diff --git a/vllm_ascend/envs.py b/vllm_ascend/envs.py index 23557f1..9aa2d70 100644 --- a/vllm_ascend/envs.py +++ b/vllm_ascend/envs.py @@ -107,6 +107,8 @@ env_variables: Dict[str, Callable[[], Any]] = { # Whether to enable the trace recompiles from pytorch. "VLLM_ASCEND_TRACE_RECOMPILES": lambda: bool(int(os.getenv("VLLM_ASCEND_TRACE_RECOMPILES", '0'))), + "VLLM_ASCEND_ENABLE_DBO": + lambda: bool(int(os.getenv("VLLM_ASCEND_ENABLE_DBO", '0'))), # Whether to enable the model execute time observe profile. Disable it when # running vllm ascend in production environment. "VLLM_ASCEND_MODEL_EXECUTE_TIME_OBSERVE": diff --git a/vllm_ascend/models/__init__.py b/vllm_ascend/models/__init__.py index e7f021f..4357787 100644 --- a/vllm_ascend/models/__init__.py +++ b/vllm_ascend/models/__init__.py @@ -1,7 +1,10 @@ from vllm import ModelRegistry +import vllm_ascend.envs as envs + def register_model(): + from .deepseek_dbo import CustomDeepseekDBOForCausalLM # noqa: F401 from .deepseek_mtp import CustomDeepSeekMTP # noqa: F401 from .deepseek_v2 import CustomDeepseekV2ForCausalLM # noqa: F401 from .deepseek_v2 import CustomDeepseekV3ForCausalLM # noqa: F401 @@ -22,9 +25,14 @@ def register_model(): "vllm_ascend.models.qwen2_5_vl:AscendQwen2_5_VLForConditionalGeneration" ) - ModelRegistry.register_model( - "DeepseekV2ForCausalLM", - "vllm_ascend.models.deepseek_v2:CustomDeepseekV2ForCausalLM") + if envs.VLLM_ASCEND_ENABLE_DBO: + ModelRegistry.register_model( + "DeepseekV2ForCausalLM", + "vllm_ascend.models.deepseek_dbo:CustomDeepseekDBOForCausalLM") + else: + ModelRegistry.register_model( + "DeepseekV2ForCausalLM", + "vllm_ascend.models.deepseek_v2:CustomDeepseekV2ForCausalLM") ModelRegistry.register_model( "DeepseekV3ForCausalLM", diff --git a/vllm_ascend/models/deepseek_dbo.py b/vllm_ascend/models/deepseek_dbo.py new file mode 100644 index 0000000..c3de6ae --- /dev/null +++ b/vllm_ascend/models/deepseek_dbo.py @@ -0,0 +1,1118 @@ +# SPDX-License-Identifier: Apache-2.0 +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# Copyright 2023 The vLLM team. +# Copyright 2023 DeepSeek-AI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# # Adapted from +# # vllm-project/vllm/blob/main/vllm/model_executor/models/deepseek_v2.py +# # https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py +# # vllm-project/vllm/vllm/model_executor/models/deepseek_v2.py +# """Inference-only DeepseekV2/DeepseekV3 model.""" + +from typing import Any, Dict, List, Optional, Union + +import torch +import torch.distributed as dist +import torch_npu +import vllm.envs as envs +from torch import nn +from transformers import PretrainedConfig +from vllm.attention import Attention, AttentionMetadata +from vllm.config import CacheConfig, ModelConfig, VllmConfig +from vllm.distributed import (get_pp_group, + get_tensor_model_parallel_world_size, + get_tp_group, tensor_model_parallel_all_reduce) +from vllm.distributed.parallel_state import get_dp_group +from vllm.forward_context import get_forward_context +from vllm.model_executor.layers.activation import SiluAndMul +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.linear import (ColumnParallelLinear, + MergedColumnParallelLinear, + ReplicatedLinear, + RowParallelLinear, + UnquantizedLinearMethod) +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.sampler import get_sampler +from vllm.model_executor.layers.vocab_parallel_embedding import ( + ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.models.deepseek_v2 import \ + DeepseekV2ForCausalLM # ruff: noqa: E501 +from vllm.model_executor.models.deepseek_v2 import \ + yarn_get_mscale # ruff: noqa: E501 +from vllm.model_executor.models.deepseek_v2 import (DeepseekV2Attention, + DeepseekV2DecoderLayer, + DeepseekV2MLAAttention) +from vllm.model_executor.models.utils import ( + PPMissingLayer, make_empty_intermediate_tensors_factory, make_layers, + maybe_prefix) +from vllm.sequence import IntermediateTensors + +import vllm_ascend.envs as envs_ascend +from vllm_ascend.ascend_config import get_ascend_config +from vllm_ascend.multistream.base import MSEventKey +from vllm_ascend.multistream.context import ( + advance_step_multistream_layer_context, get_multistream_comm_context, + get_multistream_layer_context, set_multistream_context) +from vllm_ascend.multistream.layers import (MultiStreamPostTransformerLayer, + MultiStreamPreTransformerLayer) +from vllm_ascend.multistream.metadata import (MultiStreamConfig, + MultiStreamStepMetadata, + make_multistream_metadata_ds) +from vllm_ascend.multistream.ms_split import compute_split_seq_index +from vllm_ascend.ops.fused_moe import AscendFusedMoE +from vllm_ascend.quantization.w8a8_dynamic import AscendW8A8DynamicLinearMethod +from vllm_ascend.utils import dispose_tensor + +VLLM_ASCEND_ENABLE_DBO: bool = envs_ascend.VLLM_ASCEND_ENABLE_DBO +VLLM_ENABLE_MC2: bool = envs_ascend.VLLM_ENABLE_MC2 + + +class CustomDeepseekDBOMLP(nn.Module): + + def __init__( + self, + hidden_size: int, + intermediate_size: int, + hidden_act: str, + quant_config: Optional[QuantizationConfig] = None, + reduce_results: bool = True, + prefix: str = "", + ) -> None: + super().__init__() + self.gate_up_proj = MergedColumnParallelLinear( + hidden_size, [intermediate_size] * 2, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.gate_up_proj") + self.down_proj = RowParallelLinear(intermediate_size, + hidden_size, + bias=False, + quant_config=quant_config, + reduce_results=reduce_results, + prefix=f"{prefix}.down_proj") + if hidden_act != "silu": + raise ValueError(f"Unsupported activation: {hidden_act}. " + "Only silu is supported for now.") + self.act_fn = SiluAndMul() + + # NOTE: `torch_npu.npu_dequant_swiglu_quant` can only be enabled in dynamic quant + self.is_dynamic_quant = not isinstance( + self.gate_up_proj.quant_method, + UnquantizedLinearMethod) and isinstance( + self.gate_up_proj.quant_method.quant_method, + AscendW8A8DynamicLinearMethod) + + def forward(self, x): + if self.is_dynamic_quant: + x, dynamic_scale = torch_npu.npu_dynamic_quant(x) + x = torch_npu.npu_quant_matmul( + x, + self.gate_up_proj.weight, + self.gate_up_proj.weight_scale, + output_dtype=torch.int32, + ) + x, dynamic_scale = torch_npu.npu_dequant_swiglu_quant( + x=x, + weight_scale=self.gate_up_proj.weight_scale_fp32, + activation_scale=dynamic_scale, + bias=None, + quant_scale=None, + quant_offset=None, + group_index=None, + activate_left=True, + quant_mode=1) + x = torch_npu.npu_quant_matmul( + x, + self.down_proj.weight, + self.down_proj.weight_scale, + pertoken_scale=dynamic_scale, + output_dtype=torch.bfloat16, + ) + if self.down_proj.reduce_results and self.down_proj.tp_size > 1: + x = tensor_model_parallel_all_reduce(x) + return x + gate_up, _ = self.gate_up_proj(x) + x = self.act_fn(gate_up) + x, _ = self.down_proj(x) + return x + + def _forward_ms_mlp(self, x): + current_ms_metadata = get_multistream_comm_context() + assert current_ms_metadata is not None + if self.is_dynamic_quant: + x, dynamic_scale = torch_npu.npu_dynamic_quant(x) + x = torch_npu.npu_quant_matmul( + x, + self.gate_up_proj.weight, + self.gate_up_proj.weight_scale, + output_dtype=torch.int32, + ) + x, dynamic_scale = torch_npu.npu_dequant_swiglu_quant( + x=x, + weight_scale=self.gate_up_proj.weight_scale_fp32, + activation_scale=dynamic_scale, + bias=None, + quant_scale=None, + quant_offset=None, + group_index=None, + activate_left=True, + quant_mode=1) + x = torch_npu.npu_quant_matmul( + x, + self.down_proj.weight, + self.down_proj.weight_scale, + pertoken_scale=dynamic_scale, + output_dtype=torch.bfloat16, + ) + if self.down_proj.reduce_results and self.down_proj.tp_size > 1: + current_ms_metadata.before_comm_event.record() + with torch.npu.stream(current_ms_metadata.comm_stream): + current_ms_metadata.before_comm_event.wait() + x = tensor_model_parallel_all_reduce(x) + current_ms_metadata.after_comm_event.record() + return x + gate_up, _ = self.gate_up_proj(x) + x = self.act_fn(gate_up) + current_ms_metadata.before_comm_event.record() + with torch.npu.stream(current_ms_metadata.comm_stream): + current_ms_metadata.before_comm_event.wait() + x, _ = self.down_proj(x) + current_ms_metadata.after_comm_event.record() + return x + + +class CustomDeepseekDBOMoE(nn.Module): + + top_k: int + + def __init__( + self, + config: PretrainedConfig, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): + super().__init__() + self.tp_size = get_tensor_model_parallel_world_size() + self.routed_scaling_factor = config.routed_scaling_factor + self.n_shared_experts = config.n_shared_experts + self.routed_scaling_factor = config.routed_scaling_factor + if self.tp_size > config.n_routed_experts: + raise ValueError( + f"Tensor parallel size {self.tp_size} is greater than " + f"the number of experts {config.n_routed_experts}.") + + if config.hidden_act != "silu": + raise ValueError(f"Unsupported activation: {config.hidden_act}. " + "Only silu is supported for now.") + + self.gate = ReplicatedLinear(config.hidden_size, + config.n_routed_experts, + bias=False, + quant_config=None, + prefix=f"{prefix}.gate") + if config.topk_method == "noaux_tc": + self.gate.e_score_correction_bias = nn.Parameter( + torch.empty(config.n_routed_experts)) + else: + self.gate.e_score_correction_bias = None + + self.experts = AscendFusedMoE( + num_experts=config.n_routed_experts, + top_k=config.num_experts_per_tok, + hidden_size=config.hidden_size, + intermediate_size=config.moe_intermediate_size, + reduce_results=False, + renormalize=config.norm_topk_prob, + quant_config=quant_config, + use_grouped_topk=True, + num_expert_group=config.n_group, + topk_group=config.topk_group, + prefix=f"{prefix}.experts", + scoring_func=config.scoring_func, + e_score_correction_bias=self.gate.e_score_correction_bias) + + if config.n_shared_experts is not None: + intermediate_size = (config.moe_intermediate_size * + config.n_shared_experts) + self.shared_experts = CustomDeepseekDBOMLP( + hidden_size=config.hidden_size, + intermediate_size=intermediate_size, + hidden_act=config.hidden_act, + quant_config=quant_config, + reduce_results=True, + prefix=f"{prefix}.shared_experts", + ) + CustomDeepseekDBOMoE.top_k = config.num_experts_per_tok + + self.dp_size = get_dp_group().world_size + + self.tp_group = get_tp_group().device_group + self.tp_rank = get_tp_group().rank_in_group + + self.params_dtype = torch.get_default_dtype() + + ascend_config = get_ascend_config() + self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled + + def forward( + self, + hidden_states: torch.Tensor, + attn_metadata: Optional[AttentionMetadata] = None) -> torch.Tensor: + if attn_metadata is None: + attn_metadata = get_forward_context().attn_metadata + # when profile runs, force experts to load balanced tokens + # to avoid high memory consumption on a single rank. + # TODO: need a better flag to indicate whether in profile run or not. + if attn_metadata is None: + # for profile run + is_prefill = True + enable_force_load_balance = True + else: + is_prefill = attn_metadata.num_prefills > 0 + enable_force_load_balance = False + if hasattr(attn_metadata, 'with_prefill_across_dp'): + is_prefill = is_prefill or attn_metadata.with_prefill_across_dp + + num_tokens, hidden_size = hidden_states.shape + + old_hidden_states = hidden_states.clone() + + if self.tp_size > 1: + if envs_ascend.VLLM_ENABLE_MC2 and not is_prefill: + chunks = torch.chunk(hidden_states, self.tp_size, dim=0) + hidden_states = chunks[self.tp_rank] + elif not self.torchair_graph_enabled: + num_padding_tokens = (self.tp_size - + num_tokens % self.tp_size) % self.tp_size + # Pad hidden_states to make it divisible by tp_size to avoid cross-ring AllGatherV on 910B2C + if num_padding_tokens > 0: + hidden_states = nn.functional.pad( + hidden_states, (0, 0, 0, num_padding_tokens)) + chunk_hidden_states = torch.tensor_split(hidden_states, + self.tp_size, + dim=0) + hidden_states = chunk_hidden_states[self.tp_rank] + + # router_logits: (num_tokens, n_experts) + router_logits, _ = self.gate(hidden_states) + + hidden_states = self.experts( + hidden_states=hidden_states, + router_logits=router_logits, + is_prefill=is_prefill, + top_k=CustomDeepseekDBOMoE.top_k, + enable_force_load_balance=enable_force_load_balance, + ) * self.routed_scaling_factor + + if self.tp_size > 1: + if self.torchair_graph_enabled: + if envs_ascend.VLLM_ENABLE_MC2 and not is_prefill: + final_hidden_states = torch.zeros( + [num_tokens, hidden_size], + dtype=self.params_dtype, + device="npu") + dist.all_gather_into_tensor(final_hidden_states, + hidden_states, self.tp_group) + hidden_states = final_hidden_states + else: + hidden_states = tensor_model_parallel_all_reduce( + hidden_states) + else: + dist.all_gather(list(chunk_hidden_states), hidden_states, + self.tp_group) + hidden_states = torch.cat(chunk_hidden_states, dim=0) + if num_padding_tokens > 0: + hidden_states = hidden_states[:-num_padding_tokens] + + if self.n_shared_experts is not None: + shared_output = self.shared_experts(old_hidden_states) + + if shared_output is not None: + hidden_states = hidden_states + shared_output + + return hidden_states.view(num_tokens, hidden_size) + + # ----------------------------------------- TBO-related -------------------------------------------- + def _forward_ms_op_shared_expert( + self, + hidden_states: torch.Tensor, + ): + shared_output = self.shared_experts._forward_ms_mlp(hidden_states) + return shared_output + + def _forward_ms_op_gate( + self, + hidden_states: torch.Tensor, + ): + # router_logits: (num_tokens, n_experts) + router_logits, _ = self.gate(hidden_states) + return router_logits + + def _forward_ms_op_tp_allgather( + self, + hidden_states: torch.Tensor, + chunk_hidden_states: torch.Tensor, + num_tokens: int = 0, + ): + current_ms_metadata = get_multistream_comm_context() + if current_ms_metadata is None: + dist.all_gather(list(chunk_hidden_states), hidden_states, + self.tp_group) + final_hidden_states = torch.cat(chunk_hidden_states, dim=0) + if num_tokens > 0: + final_hidden_states = final_hidden_states[:-num_tokens] + else: + current_ms_metadata.before_comm_event.record() + with torch.npu.stream(current_ms_metadata.comm_stream): + current_ms_metadata.before_comm_event.wait() + dist.all_gather(list(chunk_hidden_states), hidden_states, + self.tp_group) + final_hidden_states = torch.cat(chunk_hidden_states, dim=0) + if num_tokens > 0: + final_hidden_states = final_hidden_states[:-num_tokens] + current_ms_metadata.after_comm_event.record() + return final_hidden_states + + +class CustomDeepseekDBOMLAAttention(DeepseekV2MLAAttention): + + def __init__( + self, + config: PretrainedConfig, + hidden_size: int, + num_heads: int, + qk_nope_head_dim: int, + qk_rope_head_dim: int, + v_head_dim: int, + q_lora_rank: Optional[int], + kv_lora_rank: int, + rope_theta: float = 10000, + rope_scaling: Optional[Dict[str, Any]] = None, + max_position_embeddings: int = 8192, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + nn.Module.__init__(self) + self.hidden_size = hidden_size + self.qk_nope_head_dim = qk_nope_head_dim + self.qk_rope_head_dim = qk_rope_head_dim + self.qk_head_dim = qk_nope_head_dim + qk_rope_head_dim + self.v_head_dim = v_head_dim + + self.q_lora_rank = q_lora_rank + self.kv_lora_rank = kv_lora_rank + + self.num_heads = num_heads + tp_size = get_tensor_model_parallel_world_size() + assert num_heads % tp_size == 0 + self.num_local_heads = num_heads // tp_size + + self.scaling = self.qk_head_dim**-0.5 + self.rope_theta = rope_theta + self.max_position_embeddings = max_position_embeddings + + if self.q_lora_rank is not None: + self.q_a_proj = ReplicatedLinear(self.hidden_size, + self.q_lora_rank, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.q_a_proj") + self.q_a_layernorm = RMSNorm(self.q_lora_rank, + eps=config.rms_norm_eps) + self.q_b_proj = ColumnParallelLinear(q_lora_rank, + self.num_heads * + self.qk_head_dim, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.q_b_proj") + else: + self.q_proj = ColumnParallelLinear(self.hidden_size, + self.num_heads * + self.qk_head_dim, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.q_proj") + + self.kv_a_proj_with_mqa = ReplicatedLinear( + self.hidden_size, + self.kv_lora_rank + self.qk_rope_head_dim, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.kv_a_proj_with_mqa") + self.kv_a_layernorm = RMSNorm(self.kv_lora_rank, + eps=config.rms_norm_eps) + self.kv_b_proj = ColumnParallelLinear( + self.kv_lora_rank, + self.num_heads * (self.qk_nope_head_dim + self.v_head_dim), + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.kv_b_proj") + self.o_proj = RowParallelLinear(self.num_heads * self.v_head_dim, + self.hidden_size, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.o_proj") + + if rope_scaling: + rope_scaling["rope_type"] = 'deepseek_yarn' + self.rotary_emb = get_rope(qk_rope_head_dim, + rotary_dim=qk_rope_head_dim, + max_position=max_position_embeddings, + base=rope_theta, + rope_scaling=rope_scaling, + is_neox_style=False) + if rope_scaling: + mscale_all_dim = rope_scaling.get("mscale_all_dim", False) + scaling_factor = rope_scaling["factor"] + mscale = yarn_get_mscale(scaling_factor, float(mscale_all_dim)) + self.scaling = self.scaling * mscale * mscale + + # In the MLA backend, kv_cache includes both k_c and + # pe (i.e. decoupled position embeddings). In particular, + # the concat_and_cache_mla op requires + # k_c.size(1) + k_pe.size(1) == kv_cache.size(2) + # i.e. + # kv_lora_rank + qk_rope_head_dim == head_size + self.mla_attn = Attention( + num_heads=self.num_local_heads, + head_size=self.kv_lora_rank + self.qk_rope_head_dim, + scale=self.scaling, + num_kv_heads=1, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.attn", + use_mla=True, + # MLA Args + q_lora_rank=self.q_lora_rank, + kv_lora_rank=self.kv_lora_rank, + qk_nope_head_dim=self.qk_nope_head_dim, + qk_rope_head_dim=self.qk_rope_head_dim, + qk_head_dim=self.qk_head_dim, + v_head_dim=self.v_head_dim, + rotary_emb=self.rotary_emb, + q_proj=self.q_proj if self.q_lora_rank is None else self.q_b_proj, + kv_a_proj_with_mqa=self.kv_a_proj_with_mqa, + kv_a_layernorm=self.kv_a_layernorm, + kv_b_proj=self.kv_b_proj, + o_proj=self.o_proj, + ) + + self.prefix = prefix + self.debug_layer_idx = int(self.prefix.split(".")[-2]) + + ascend_config = get_ascend_config() + self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: Optional[torch.Tensor] = None, + attn_metadata: Optional[AttentionMetadata] = None) -> torch.Tensor: + if self.q_lora_rank is not None: + ckq = self.q_a_proj(hidden_states)[0] + hidden_states_or_q_c = self.q_a_layernorm(ckq) + else: + hidden_states_or_q_c = hidden_states + if self.torchair_graph_enabled: + forward_kwargs = {} + if envs.VLLM_USE_V1: + output_shape = hidden_states.shape + output = torch.empty(output_shape, + dtype=hidden_states_or_q_c.dtype, + device=hidden_states_or_q_c.device) + forward_kwargs['output'] = output + + output = self.mla_attn.impl.forward(self.mla_attn, + hidden_states_or_q_c, + hidden_states, None, kv_cache, + attn_metadata, + **forward_kwargs) + if envs.VLLM_USE_V1: + output = output.view(-1, output_shape[-1]) + return output + else: + kv_c, k_pe = self.kv_a_proj_with_mqa(hidden_states)[0].split( + [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1) + kv_c_normed = self.kv_a_layernorm(kv_c.contiguous()) + return self.mla_attn(hidden_states_or_q_c, + kv_c_normed, + k_pe, + output_shape=hidden_states.shape) + + +class CustomDeepseekDBODecoderLayer(DeepseekV2DecoderLayer): + + def __init__( + self, + config: PretrainedConfig, + prefix: str, + model_config: ModelConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + ) -> None: + nn.Module.__init__(self) + self.hidden_size = config.hidden_size + rope_theta = getattr(config, "rope_theta", 10000) + rope_scaling = getattr(config, "rope_scaling", None) + max_position_embeddings = getattr(config, "max_position_embeddings", + 8192) + # DecoderLayers are created with `make_layers` which passes the prefix + # with the layer's index. + layer_idx = int(prefix.split(sep='.')[-1]) + self.layer_idx = layer_idx + # TODO: enable mla in vllm-ascend + if model_config.use_mla: + attn_cls = CustomDeepseekDBOMLAAttention + else: + attn_cls = DeepseekV2Attention + self.self_attn = attn_cls( + config=config, + hidden_size=self.hidden_size, + num_heads=config.num_attention_heads, + qk_nope_head_dim=config.qk_nope_head_dim, + qk_rope_head_dim=config.qk_rope_head_dim, + v_head_dim=config.v_head_dim, + q_lora_rank=config.q_lora_rank + if hasattr(config, "q_lora_rank") else None, + kv_lora_rank=config.kv_lora_rank, + rope_theta=rope_theta, + rope_scaling=rope_scaling, + max_position_embeddings=max_position_embeddings, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.self_attn", + ) + + if (config.n_routed_experts is not None + and layer_idx >= config.first_k_dense_replace + and layer_idx % config.moe_layer_freq == 0): + self.mlp = CustomDeepseekDBOMoE( + config=config, + quant_config=quant_config, + prefix=f"{prefix}.mlp", + ) + else: + self.mlp = CustomDeepseekDBOMLP( + hidden_size=config.hidden_size, + intermediate_size=config.intermediate_size, + hidden_act=config.hidden_act, + quant_config=quant_config, + prefix=f"{prefix}.mlp", + ) + self.input_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + self.post_attention_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + self.routed_scaling_factor = config.routed_scaling_factor + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + residual: Optional[torch.Tensor], + kv_cache: Optional[torch.Tensor] = None, + attn_metadata: Optional[AttentionMetadata] = None, + ) -> torch.Tensor: + # Self Attention + if residual is None: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + else: + previous_hidden_states, previous_residual = hidden_states, residual + hidden_states, residual = self.input_layernorm( + hidden_states, residual) + # Dispose hidden_states and residual from the previous layer + # to save npu memory because they're no longer used. + dispose_tensor(previous_hidden_states) + dispose_tensor(previous_residual) + + hidden_states = self.self_attn( + positions=positions, + hidden_states=hidden_states, + kv_cache=kv_cache, + attn_metadata=attn_metadata, + ) + + if hidden_states.dtype == torch.float16: + # Fix FP16 overflow + # We scale both hidden_states and residual before + # rmsnorm, and rmsnorm result would not affect by scale. + hidden_states *= 1. / self.routed_scaling_factor + if self.layer_idx == 0: + # The residual is shared by all layers, we only scale it on + # first layer. + residual *= 1. / self.routed_scaling_factor + + # Fully Connected + hidden_states, residual = self.post_attention_layernorm( + hidden_states, residual) + + if isinstance(self.mlp, CustomDeepseekDBOMoE): + hidden_states = self.mlp(hidden_states, attn_metadata) + else: + hidden_states = self.mlp(hidden_states) + + if isinstance( + self.mlp, + CustomDeepseekDBOMLP) and hidden_states.dtype == torch.float16: + # Fix FP16 overflow + # Scaling the DeepseekV2MLP output, it is the input of + # input_layernorm of next decoder layer. + # The scaling of DeepseekV2MOE output would be done in the forward + # of DeepseekV2MOE + hidden_states *= 1. / self.routed_scaling_factor + + return hidden_states, residual + + # ----------------------------------------- TBO-related -------------------------------------------- + def _forward_ms_layer( + self, + positions: List[torch.Tensor], + hidden_states: List[torch.Tensor], + residual: List[torch.Tensor], + attn_metadata: List[AttentionMetadata], + kv_cache: Optional[torch.Tensor] = None, + is_prefill: bool = False, + ) -> tuple[List[torch.Tensor], List[torch.Tensor]]: + layer_index, ms_metadata, _ = get_multistream_layer_context() + assert layer_index >= 0 and ms_metadata is not None + num_micro_batchs = ms_metadata.ms_config.num_micro_batches + assert isinstance(self.mlp, CustomDeepseekDBOMoE) + assert len(positions) == num_micro_batchs + assert len(hidden_states) == num_micro_batchs + assert residual is not None + assert attn_metadata is not None + num_tokens = [] + hidden_dims = [] + shared_outputs = [] + router_logits = [] + chunk_hidden_states = [] + + # block 1 : attention + # block 2 : attn tp communication + # the attn computation of microbatch 1 can be overlapped with the moe + # communication in the previous layer, and the attn computation of microbatch 2 + # can be overlapped with the attn communication of microbatch 1 + for i in range(num_micro_batchs): + # wait last layer moe finishing communication + ms_metadata.try_wait_event(layer_index - 1, i, + MSEventKey.FFN_AR_FINISH) + context = MultiStreamStepMetadata( + comm_stream=ms_metadata.communicate_stream, + before_comm_event=ms_metadata.ms_events[layer_index][i][ + MSEventKey.ATTN_COM_FINISH], + after_comm_event=ms_metadata.ms_events[layer_index][i][ + MSEventKey.ATTN_AR_FINISH], + ) + + with set_multistream_context(context, i): + forward_context = get_forward_context() + forward_context.attn_metadata = attn_metadata[i] + + # input layernorm + hidden_states[i], residual[ + i] = self._forward_ms_op_input_layernorm( + hidden_states[i], residual[i]) + # attention and tp allreduce + hidden_states[i], residual[i] = self._forward_ms_op_attn( + positions[i], hidden_states[i], residual[i], kv_cache, + attn_metadata[i]) + + # block 3 : shared experts + # if there is an allreduce ops in shared expert, we can overlap it with the computation of the + # shared expert for next microbatch or moe gating + for i in range(num_micro_batchs): + ms_metadata.try_wait_event(layer_index, i, + MSEventKey.ATTN_AR_FINISH) + context = MultiStreamStepMetadata( + comm_stream=ms_metadata.communicate_stream, + before_comm_event=ms_metadata.ms_events[layer_index][i][ + MSEventKey.MOE_SE_COMP_FINISH], + after_comm_event=ms_metadata.ms_events[layer_index][i][ + MSEventKey.MOE_SE_COMM_FINISH], + ) + with set_multistream_context(context, i): + # compute shared expert after finishing ATTN AR + hidden_states[i], residual[ + i] = self._forward_ms_op_post_attn_layernorm( + hidden_states[i], residual[i]) + + num_token, hidden_dim = hidden_states[i].shape + hidden_states[i] = hidden_states[i].view(-1, hidden_dim) + num_tokens.append(num_token) + hidden_dims.append(hidden_dim) + if self.mlp.n_shared_experts is not None: + # TODO: we can move shared expert computation into next block if reduce results is false + shared_output = self.mlp._forward_ms_op_shared_expert( + hidden_states[i]) + shared_outputs.append(shared_output) + + # block 4 : moe + for i in range(num_micro_batchs): + # when profile runs, force experts to load balanced tokens + # to avoid high memory consumption on a single rank. + # TODO: need a better flag to indicate whether in profile run or not. + if attn_metadata[i] is None: + # for profile run + is_prefill = True + enable_force_load_balance = True + else: + is_prefill = attn_metadata[i].num_prefills > 0 + enable_force_load_balance = False + + if self.mlp.tp_size > 1: + num_token, _ = hidden_states[i].shape + padded_num_tokens = (self.mlp.tp_size - num_token % + self.mlp.tp_size) % self.mlp.tp_size + if padded_num_tokens > 0: + hidden_states[i] = nn.functional.pad( + hidden_states[i], (0, 0, 0, padded_num_tokens)) + chunk_hidden_state = torch.tensor_split(hidden_states[i], + self.mlp.tp_size, + dim=0) + chunk_hidden_states.append(chunk_hidden_state) + local_hidden_states = chunk_hidden_state[self.mlp.tp_rank] + else: + local_hidden_states = hidden_states[i] + + router_logit = self.mlp._forward_ms_op_gate(local_hidden_states) + router_logits.append(router_logit) + + if CustomDeepseekDBOMoE.top_k: + real_top_k = CustomDeepseekDBOMoE.top_k + else: + real_top_k = self.mlp.experts.top_k + + hidden_states[i] = self.mlp.experts._forward_ms_fused_moe_comp( + local_hidden_states, router_logits[i], is_prefill, real_top_k, + enable_force_load_balance) + + # the following kernels will be submitted to the comm stream to overlap the computation of the + # moe computation of next microbatch and the attn computation of next layer + context = MultiStreamStepMetadata( + comm_stream=ms_metadata.communicate_stream, + before_comm_event=ms_metadata.ms_events[layer_index][i][ + MSEventKey.FFN_COM_FINISH], + after_comm_event=ms_metadata.ms_events[layer_index][i][ + MSEventKey.MOE_AFTER_COMM], + ) + context.before_comm_event.record() + with torch.npu.stream(ms_metadata.communicate_stream): + context.before_comm_event.wait() + if self.mlp.experts.reduce_results and ( + self.mlp.experts.tp_size > 1 + or self.mlp.experts.ep_size > 1): + hidden_states[i] = tensor_model_parallel_all_reduce( + hidden_states[i]) + hidden_states[ + i] = hidden_states[i] * self.mlp.routed_scaling_factor + context.after_comm_event.record() + + context = MultiStreamStepMetadata( + comm_stream=ms_metadata.communicate_stream, + before_comm_event=ms_metadata.ms_events[layer_index][i][ + MSEventKey.MOE_AFTER_COMM], + after_comm_event=ms_metadata.ms_events[layer_index][i][ + MSEventKey.FFN_AR_FINISH], + ) + with set_multistream_context(context, i): + if self.mlp.tp_size > 1: + hidden_states[i] = self.mlp._forward_ms_op_tp_allgather( + hidden_states[i], chunk_hidden_states[i], + padded_num_tokens) + with torch.npu.stream(ms_metadata.communicate_stream): + # last + if shared_outputs[i] is not None: + hidden_states[i] = hidden_states[i] + shared_outputs[i] + hidden_states[i] = hidden_states[i].view( + num_tokens[i], hidden_dims[i]) + if isinstance(self.mlp, CustomDeepseekDBOMLP + ) and hidden_states[i].dtype == torch.float16: + # Fix FP16 overflow + # Scaling the DeepseekV2MLP output, it is the input of + # input_layernorm of next decoder layer. + # The scaling of DeepseekV2MOE output would be done in the forward + # of DeepseekV2MOE + hidden_states[i] *= 1. / self.routed_scaling_factor + context.after_comm_event.record() + return hidden_states, residual + + # should split ops in Decoder Layer + def _forward_ms_op_input_layernorm( + self, + hidden_states: torch.Tensor, + residual: Optional[torch.Tensor], + ) -> tuple[torch.Tensor, torch.Tensor]: + if residual is None: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + else: + hidden_states, residual = self.input_layernorm( + hidden_states, residual) + return hidden_states, residual + + def _forward_ms_op_attn( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + residual: torch.Tensor, + kv_cache: Optional[torch.Tensor] = None, + attn_metadata: Optional[AttentionMetadata] = None, + ) -> tuple[torch.Tensor, torch.Tensor]: + hidden_states = self.self_attn( + positions=positions, + hidden_states=hidden_states, + kv_cache=kv_cache, + attn_metadata=attn_metadata, + ) + if hidden_states.dtype == torch.float16: + # Fix FP16 overflow + # We scale both hidden_states and residual before + # rmsnorm, and rmsnorm result would not affect by scale. + hidden_states *= 1. / self.routed_scaling_factor + if self.layer_idx == 0: + # The residual is shared by all layers, we only scale it on + # first layer. + residual *= 1. / self.routed_scaling_factor + return hidden_states, residual + + def _forward_ms_op_post_attn_layernorm( + self, + hidden_states: torch.Tensor, + residual: Optional[torch.Tensor], + ): + hidden_states, residual = self.post_attention_layernorm( + hidden_states, residual) + return hidden_states, residual + + +class CustomDeepseekDBOModel(nn.Module): + + fall_back_to_pt_during_load = False + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + + config = vllm_config.model_config.hf_config + model_config = vllm_config.model_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config + + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + self.first_k_dense_replace = config.first_k_dense_replace + + if get_pp_group().is_first_rank: + self.embed_tokens = VocabParallelEmbedding( + config.vocab_size, + config.hidden_size, + quant_config=quant_config, + prefix=f"{prefix}.embed_tokens") + else: + self.embed_tokens = PPMissingLayer() + + self.start_layer, self.end_layer, self.layers = make_layers( + config.num_hidden_layers, + lambda prefix: CustomDeepseekDBODecoderLayer( + config, + prefix, + model_config=model_config, + cache_config=cache_config, + quant_config=quant_config, + ), + prefix=f"{prefix}.layers") + + if get_pp_group().is_last_rank: + self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + else: + self.norm = PPMissingLayer() + self.make_empty_intermediate_tensors = ( + make_empty_intermediate_tensors_factory( + ["hidden_states", "residual"], config.hidden_size)) + + # tbo related members + if VLLM_ASCEND_ENABLE_DBO: + self.use_mla = model_config.use_mla + self.multistream_config = MultiStreamConfig() + multistream_metadata = make_multistream_metadata_ds( + start_layer=self.start_layer + self.first_k_dense_replace, + end_layer=self.end_layer, + causal_lm=getattr(config, "causal_lm", True), + multistream_config=self.multistream_config, + ) + self.ms_pre_layer = MultiStreamPreTransformerLayer( + multistream_metadata) + self.ms_post_layer = MultiStreamPostTransformerLayer( + multistream_metadata) + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.embed_tokens(input_ids) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: Optional[List[torch.Tensor]] = None, + attn_metadata: Optional[AttentionMetadata] = None, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + if get_pp_group().is_first_rank: + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states = self.get_input_embeddings(input_ids) + residual = None + else: + assert intermediate_tensors is not None + hidden_states = intermediate_tensors["hidden_states"] + residual = intermediate_tensors["residual"] + + num_normal_layers = (self.first_k_dense_replace + if VLLM_ASCEND_ENABLE_DBO and self.can_run_ms() + else self.end_layer - self.start_layer) + + for i in range(self.start_layer, self.start_layer + num_normal_layers): + layer = self.layers[i] + hidden_states, residual = layer( + positions, hidden_states, residual, + kv_caches[i - + self.start_layer] if kv_caches is not None else None, + attn_metadata) + + moe_start_layer = self.start_layer + num_normal_layers + if moe_start_layer != self.end_layer: + # if we enable multistream/dbo, process sparse layers here + hidden_states, residual = self._forward_ms_layers( + positions=positions, + hidden_states=hidden_states, + residual=residual, + moe_start_layer=moe_start_layer, + kv_caches=kv_caches, + ) + + if not get_pp_group().is_last_rank: + return IntermediateTensors({ + "hidden_states": hidden_states, + "residual": residual + }) + + hidden_states, _ = self.norm(hidden_states, residual) + return hidden_states + + def can_run_ms(self): + attn_metadata = get_forward_context().attn_metadata + # support mla attention and V1 engine at present + if not self.use_mla or not envs.VLLM_USE_V1: + return False + # enable prefill overlap + if attn_metadata is None or attn_metadata.num_prefills == 0: + return False + else: + [token_index, seq_index + ] = compute_split_seq_index(attn_metadata.query_lens, + attn_metadata.attn_state, + attn_metadata.num_decode_tokens) + if token_index == 0 or seq_index == 0 or seq_index == len( + attn_metadata.query_lens): + return False + # check whether the total tokens exceed the threshold + if self.multistream_config is None or attn_metadata.num_actual_tokens < self.multistream_config.min_total_tokens_to_split: + return False + return True + + def _forward_ms_layers( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + residual: torch.Tensor, + moe_start_layer: int, + kv_caches: Optional[List[torch.Tensor]] = None, + is_prefill: bool = False, + ): + + if moe_start_layer == self.end_layer: + return hidden_states, residual + + attn_metadata, [positions, hidden_states, + residual] = self.ms_pre_layer( + [positions, hidden_states, residual], ) + # the rest layers + for i in range(moe_start_layer, self.end_layer): + layer = self.layers[i] + hidden_states, residual = layer._forward_ms_layer( + positions=positions, + hidden_states=hidden_states, + residual=residual, + attn_metadata=attn_metadata, + kv_cache=kv_caches[i - self.start_layer] + if kv_caches is not None else None, + is_prefill=is_prefill) + advance_step_multistream_layer_context() + + [hidden_states, + residual] = self.ms_post_layer([hidden_states, residual], ) + return hidden_states, residual + + +class CustomDeepseekDBOForCausalLM(DeepseekV2ForCausalLM): + # add `packed_modules_mapping` in `DeepseekV2ForCausalLM` to support weight merging + packed_modules_mapping = { + "gate_up_proj": ["gate_proj", "up_proj"], + "experts": + ["experts.0.gate_proj", "experts.0.up_proj", "experts.0.down_proj"] + } + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + nn.Module.__init__(self) + config = vllm_config.model_config.hf_config + quant_config = vllm_config.quant_config + self.config = config + self.quant_config = quant_config + self.model = CustomDeepseekDBOModel(vllm_config=vllm_config, + prefix=maybe_prefix( + prefix, "model")) + if get_pp_group().is_last_rank: + self.lm_head = ParallelLMHead(config.vocab_size, + config.hidden_size, + quant_config=quant_config) + else: + self.lm_head = PPMissingLayer() + self.logits_processor = LogitsProcessor(config.vocab_size) + self.sampler = get_sampler() + self.make_empty_intermediate_tensors = ( + self.model.make_empty_intermediate_tensors) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: Optional[List[torch.Tensor]] = None, + attn_metadata: Optional[AttentionMetadata] = None, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + hidden_states = self.model(input_ids, positions, kv_caches, + attn_metadata, intermediate_tensors, + inputs_embeds) + return hidden_states diff --git a/vllm_ascend/multistream/__init__.py b/vllm_ascend/multistream/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/vllm_ascend/multistream/base.py b/vllm_ascend/multistream/base.py new file mode 100644 index 0000000..fba58b4 --- /dev/null +++ b/vllm_ascend/multistream/base.py @@ -0,0 +1,29 @@ +from dataclasses import dataclass +from enum import Enum + + +class MSEventKey(Enum): + ATTN_COM_FINISH = 0 + ATTN_AR_FINISH = 1 + FFN_COM_FINISH = 2 + FFN_AR_FINISH = 3 + # events for MOE dispatch and combine + MOE_BEFORE_COMM = 4 + MOE_AFTER_COMM = 5 + # events for shared expert + MOE_SE_COMM_FINISH = 6 + MOE_SE_COMP_FINISH = 7 + MOE_GATE_FINISH = 8 + + +@dataclass +class MSAttentionMetadataSplitConfig: + """ + micro batch split config for split attention metadata + """ + # micro batch num + num_micro_batches: int = 2 + # split micro batches only when total tokens >= min_total_tokens_to_split + min_total_tokens_to_split: int = 256 + # split micro batches only when prefill tokens >= min_prefill_tokens_to_split + min_prefill_tokens_to_split: int = 64 diff --git a/vllm_ascend/multistream/context.py b/vllm_ascend/multistream/context.py new file mode 100644 index 0000000..a1684f2 --- /dev/null +++ b/vllm_ascend/multistream/context.py @@ -0,0 +1,67 @@ +from contextlib import contextmanager +from typing import Any + +_ms_comm_context: Any = None +_cur_micro_batch_num: int = -1 +_ms_layer_index_context: int = -1 +_ms_metadata_context: Any = None +_ms_attn_metadata_context: Any = None + + +def set_multistream_layer_context(start_layer: int, ms_metadata: Any, + attn_metadata: Any): + """ + set multistream layer context before transformer layers + """ + global _ms_layer_index_context, _ms_metadata_context, _ms_attn_metadata_context + _ms_layer_index_context = start_layer + _ms_metadata_context = ms_metadata + _ms_attn_metadata_context = attn_metadata + + +def reset_multistream_layer_context(): + """ + reset multistream layer context + """ + global _ms_layer_index_context, _ms_metadata_context, _ms_attn_metadata_context + _ms_layer_index_context = -1 + _ms_metadata_context = None + _ms_attn_metadata_context = None + + +def get_multistream_layer_context(): + """ + get multistream layer context + """ + return _ms_layer_index_context, _ms_metadata_context, _ms_attn_metadata_context + + +def advance_step_multistream_layer_context(): + """ + advance multistream layer index context + """ + global _ms_layer_index_context + _ms_layer_index_context += 1 + + +def get_multistream_comm_context() -> Any: + """Get the current comm forward context.""" + return _ms_comm_context + + +def get_multistream_microbatch_context() -> int: + return _cur_micro_batch_num + + +@contextmanager +def set_multistream_context(context: Any, micro_batch_num: int): + """A context manager that stores the current comm forward context, + can be attention metadata, etc.""" + global _ms_comm_context, _cur_micro_batch_num + _ms_comm_context = context + _cur_micro_batch_num = micro_batch_num + try: + yield + finally: + _ms_comm_context = None + _cur_micro_batch_num = -1 diff --git a/vllm_ascend/multistream/decorator.py b/vllm_ascend/multistream/decorator.py new file mode 100644 index 0000000..6c7f16a --- /dev/null +++ b/vllm_ascend/multistream/decorator.py @@ -0,0 +1,26 @@ +from vllm.logger import init_logger + +from .context import (get_multistream_layer_context, + get_multistream_microbatch_context) + +logger = init_logger(__name__) + + +# vllm v1 use get_forward_context to get the attn_metadata, +# we can use this decorator to update the attn metadata +def set_multistream_support(): + + def decorator(func): + + def wrapper(): + context = func() + layer_index, ms_metadata, attn_metadata = get_multistream_layer_context( + ) + micro_batch_num = get_multistream_microbatch_context() + if layer_index != -1 and micro_batch_num != -1: + context.attn_metadata = attn_metadata[micro_batch_num] + return context + + return wrapper + + return decorator diff --git a/vllm_ascend/multistream/layers.py b/vllm_ascend/multistream/layers.py new file mode 100644 index 0000000..c5273bc --- /dev/null +++ b/vllm_ascend/multistream/layers.py @@ -0,0 +1,61 @@ +from typing import List, Optional, Tuple, Union + +import torch +from vllm.forward_context import get_forward_context + +from .base import MSEventKey +from .context import (get_multistream_layer_context, + reset_multistream_layer_context, + set_multistream_layer_context) +from .metadata import MultiStreamMetadata + + +class MultiStreamPreTransformerLayer(torch.nn.Module): + + def __init__(self, multistream_metadata: MultiStreamMetadata): + super().__init__() + self.multistream_metadata = multistream_metadata + + def forward( + self, + intput_tensors: List[torch.Tensor], + ): + attn_metadata = get_forward_context().attn_metadata + if self.multistream_metadata is None or attn_metadata is None: + set_multistream_layer_context(-1, None, None) + return attn_metadata, intput_tensors + # TODO add attn_metadata management + do_ms, attn_metadata, intput_tensors, _ = self.multistream_metadata.split_micro_batch( + attn_metadata, intput_tensors) + if do_ms: + set_multistream_layer_context( + self.multistream_metadata.start_layer, + self.multistream_metadata, attn_metadata) + else: + set_multistream_layer_context(-1, None, None) + return attn_metadata, intput_tensors + + +class MultiStreamPostTransformerLayer(torch.nn.Module): + + def __init__(self, multistream_metadata: MultiStreamMetadata): + super().__init__() + self.multistream_metadata = multistream_metadata + + def forward(self, + input_tensors: Union[List[Tuple[torch.Tensor]], + List[torch.Tensor], + List[List[torch.Tensor]]], + wait_layer_index: Optional[int] = None): + if self.multistream_metadata is None or self.multistream_metadata.ms_config is None: + return input_tensors + layer_index, ms_metadata, ms_attn_metadata = get_multistream_layer_context( + ) + if layer_index >= 0: + true_wait_layer = self.multistream_metadata.end_layer - 1 if wait_layer_index is None else wait_layer_index + self.multistream_metadata.try_wait_event( + true_wait_layer, + self.multistream_metadata.ms_config.num_micro_batches - 1, + MSEventKey.FFN_AR_FINISH) + reset_multistream_layer_context() + return self.multistream_metadata.merge_micro_batches(input_tensors) diff --git a/vllm_ascend/multistream/metadata.py b/vllm_ascend/multistream/metadata.py new file mode 100644 index 0000000..b521d3f --- /dev/null +++ b/vllm_ascend/multistream/metadata.py @@ -0,0 +1,182 @@ +from dataclasses import dataclass +from typing import Dict, List, Optional, Tuple, Union + +import torch +from vllm.sequence import IntermediateTensors + +from vllm_ascend.attention.mla_v1 import AscendMLAMetadata + +from .base import MSAttentionMetadataSplitConfig, MSEventKey + + +def split_micro_batches_tensors(input_tensors, + split_index: int, + keys: Optional[List[str]] = None): + if isinstance(input_tensors, list): + micro_batches = [] + for tensor in input_tensors: + if tensor is None: + micro_batches.append([None, None]) + else: + micro_batches.append( + [tensor[:split_index], tensor[split_index:]]) + return micro_batches + elif isinstance(input_tensors, torch.Tensor): + return [input_tensors[:split_index], input_tensors[split_index:]] + elif input_tensors is None: + return [None, None] + elif isinstance(input_tensors, Dict): + assert keys is not None + micro_batches_pre = {} + for key in keys: + micro_batches_pre[key] = input_tensors[key][:split_index] + micro_batches_post = {} + for key in keys: + micro_batches_post[key] = input_tensors[key][split_index:] + return [micro_batches_pre, micro_batches_post] + else: + raise NotImplementedError + + +@dataclass +class MultiStreamStepMetadata: + comm_stream: torch.npu.Stream = None + before_comm_event: torch.npu.Event = None + after_comm_event: torch.npu.Event = None + + +@dataclass +class MultiStreamConfig: + """Controls the behavior of multi-stream models.""" + min_total_tokens_to_split: int = 256 + min_prefill_tokens_to_split: int = 64 + num_micro_batches: int = 2 + imbalance_ratio: float = 0.1 + + +class MultiStreamMetadata: + # direct stream + calculate_stream = None + # delay stream + communicate_stream = None + # events + ms_events: Dict[int, Dict[int, Dict[MSEventKey, torch.npu.Event]]] = {} + # multi-stream-flag + enable_multi_stream: bool = False + + def __init__( + self, + calculate_stream: torch.npu.Stream, + communicate_stream: torch.npu.Stream, + start_layer: int, + end_layer: int, + event_keys: List[MSEventKey], + multistream_config: Optional[MultiStreamConfig], + causal_lm: bool = True, + ): + self.calculate_stream = calculate_stream + self.communicate_stream = communicate_stream + self.start_layer = start_layer + self.end_layer = end_layer + self.ms_config = multistream_config + self.causal_lm = causal_lm + self._build_events(event_keys) + self._build_ms_split_config() + + def _build_events(self, event_keys): + if self.ms_config is not None: + for i in range(self.start_layer - 1, self.end_layer): + self.ms_events[i] = {} + for j in range(self.ms_config.num_micro_batches): + self.ms_events[i][j] = {} + for key in event_keys: + self.ms_events[i][j][key] = torch.npu.Event() + + def _build_ms_split_config(self): + if self.ms_config is not None: + self.ms_split_config = MSAttentionMetadataSplitConfig( + num_micro_batches=self.ms_config.num_micro_batches, + min_total_tokens_to_split=self.ms_config. + min_total_tokens_to_split, + min_prefill_tokens_to_split=self.ms_config. + min_prefill_tokens_to_split, + ) + + def try_wait_event(self, layer_index: int, micro_batch_index: int, + event_key: MSEventKey): + self.ms_events[layer_index][micro_batch_index][event_key].wait() + + def try_record_event(self, layer_index: int, micro_batch_index: int, + event_key: MSEventKey): + self.ms_events[layer_index][micro_batch_index][event_key].record() + + def split_micro_batch( + self, + attn_metadata: "AscendMLAMetadata", + intput_tensors: List[torch.Tensor], + intermediate_tensors: Optional[IntermediateTensors] = None, + intermediate_tensors_keys: Optional[List[str]] = None, + ) -> Tuple[bool, Union[AscendMLAMetadata, List[AscendMLAMetadata]], Union[ + List[torch.Tensor], List[List[torch.Tensor]]], Union[ + IntermediateTensors, List[IntermediateTensors]]]: + attn_metadata_list = attn_metadata.split_metadata_for_multistream( + self.ms_split_config) + if len(attn_metadata_list) == 1: + return False, attn_metadata_list[ + 0], intput_tensors, intermediate_tensors + split_index = attn_metadata_list[0].slot_mapping.shape[0] + input_tensors = split_micro_batches_tensors(intput_tensors, + split_index) + if intermediate_tensors is not None: + inter_tensors_list = split_micro_batches_tensors( + intermediate_tensors.tensors, split_index, + intermediate_tensors_keys) + intermediate_tensors = [ + IntermediateTensors(inter_tensors) + for inter_tensors in inter_tensors_list + ] + return True, attn_metadata_list, input_tensors, intermediate_tensors + + def merge_micro_batches( + self, input_tensors: Union[List[torch.Tensor], + List[List[torch.Tensor]]] + ) -> List[torch.Tensor]: + if input_tensors is None or isinstance(input_tensors[0], torch.Tensor): + return input_tensors + batch: List[Optional[torch.Tensor]] = [] + for tensors in input_tensors: + if tensors is None or tensors[0] is None: + batch.append(None) + else: + batch.append(torch.cat(tensors, dim=0)) + return batch + + +def make_multistream_metadata_ds( + start_layer: int, + end_layer: int, + causal_lm: bool = True, + multistream_config: Optional[MultiStreamConfig] = None, +): + if multistream_config is None: + return None + event_keylist = [ + MSEventKey.ATTN_COM_FINISH, + MSEventKey.ATTN_AR_FINISH, + MSEventKey.FFN_COM_FINISH, + MSEventKey.FFN_AR_FINISH, + MSEventKey.MOE_BEFORE_COMM, + MSEventKey.MOE_AFTER_COMM, + MSEventKey.MOE_SE_COMM_FINISH, + MSEventKey.MOE_SE_COMP_FINISH, + MSEventKey.MOE_GATE_FINISH, + ] + return MultiStreamMetadata( + calculate_stream=torch.npu.current_stream(), + communicate_stream=torch.npu.Stream(), + start_layer=start_layer, + end_layer=end_layer, + multistream_config=multistream_config, + event_keys=event_keylist, + causal_lm=causal_lm, + ) diff --git a/vllm_ascend/multistream/ms_split.py b/vllm_ascend/multistream/ms_split.py new file mode 100644 index 0000000..430f57b --- /dev/null +++ b/vllm_ascend/multistream/ms_split.py @@ -0,0 +1,245 @@ +from copy import deepcopy +from typing import Any, List, Optional + +import numpy as np +import torch + +from vllm_ascend.attention.attention_v1 import AscendAttentionState + +from .base import MSAttentionMetadataSplitConfig + + +def compute_split_seq_index( + query_lens: Optional[list[int]], + attn_state: AscendAttentionState, + num_tokens: int, + imbalance_ratio: float = 0.1, +) -> list[int]: + if attn_state != AscendAttentionState.DecodeOnly: + assert query_lens is not None + total_tokens = sum(query_lens) + # the first index in last split + tokens, split_index = 0, 0 + for value in query_lens: + tokens += value + split_index += 1 + if tokens >= total_tokens // 2: + # check the current split index + if abs(tokens - + total_tokens // 2) < total_tokens * imbalance_ratio: + return [tokens, split_index] + # check the previous split index + elif abs(tokens - total_tokens // 2 - + value) < total_tokens * imbalance_ratio: + return [tokens - value, split_index - 1] + # fail to split if it is imbalanced + # TODO: split tokens in seq + else: + return [0, 0] + else: + tokens = num_tokens // 2 + return [tokens, tokens] + return [0, 0] + + +def split_attn_tensor_type( + input_tensor: torch.Tensor, + index: int, +) -> List[torch.Tensor]: + return [input_tensor[:index], input_tensor[index:]] + + +def split_attn_int_type( + var: int, + index: int, +) -> List[torch.Tensor]: + return [min(var, index), max(var - index, 0)] + + +def model_input_split_v1_mla_attn( + attn_metadata, + _metadata_cls, + ms_split_config: MSAttentionMetadataSplitConfig, +) -> List[Any]: + assert 0 < ms_split_config.num_micro_batches < 3 + if attn_metadata is None: + return [attn_metadata] + [token_index, + seq_index] = compute_split_seq_index(attn_metadata.query_lens, + attn_metadata.attn_state, + attn_metadata.num_decode_tokens) + if token_index == 0 or seq_index == 0 or seq_index == len( + attn_metadata.query_lens): + return [attn_metadata] + + query_start_loc_cpu = np.zeros(shape=(len(attn_metadata.query_lens) + 1, ), + dtype=int) + np.cumsum(attn_metadata.query_lens, out=query_start_loc_cpu[1:]) + if attn_metadata.num_prefills > 0: + prefill_query_start_loc = np.zeros( + shape=(len(attn_metadata.prefill.query_lens) + 1, ), dtype=int) + np.cumsum(attn_metadata.prefill.query_lens, + out=prefill_query_start_loc[1:]) + + # split attn metadata + [slot_mapping_pre, + slot_mapping_post] = split_attn_tensor_type(attn_metadata.slot_mapping, + token_index) + [num_decodes_pre, + num_decodes_post] = split_attn_int_type(attn_metadata.num_decodes, + seq_index) + [num_decode_tokens_pre, num_decode_tokens_post + ] = split_attn_int_type(attn_metadata.num_decode_tokens, token_index) + [num_prefills_pre, num_prefills_post + ] = split_attn_int_type(attn_metadata.num_prefills, + max(0, seq_index - attn_metadata.num_decodes)) + seq_lens = attn_metadata.prefill.seq_lens if attn_metadata.num_prefills > 0 else attn_metadata.decode.seq_lens + [seq_lens_pre, seq_lens_post] = split_attn_tensor_type(seq_lens, seq_index) + + query_start_loc_pre = attn_metadata.query_start_loc[:seq_index + 1] + query_start_loc_post = deepcopy( + attn_metadata.query_start_loc[seq_index:] + ) - attn_metadata.query_start_loc[seq_index] + [block_table_pre, + block_table_post] = split_attn_tensor_type(attn_metadata.block_tables, + seq_index) + + if attn_metadata.attn_state == AscendAttentionState.PrefillNoCache or attn_metadata.attn_state == AscendAttentionState.PrefillCacheHit: + # the attn_mla kernel in torch npu only accept 128*128 attn mask + attn_mask_pre = attn_mask_post = attn_metadata.attn_mask + attn_state_pre = attn_state_post = attn_metadata.attn_state + elif attn_metadata.attn_state == AscendAttentionState.DecodeOnly: + # should be none in decode only state + attn_mask_pre = attn_mask_post = attn_metadata.attn_mask + attn_state_pre = attn_state_post = AscendAttentionState.DecodeOnly + else: + # chunked prefill + if num_prefills_pre > 0: + attn_state_pre = attn_state_post = AscendAttentionState.ChunkedPrefill + attn_mask_pre = attn_metadata.attn_mask[:token_index, :max( + seq_lens_pre)].contiguous() + attn_state_post = AscendAttentionState.ChunkedPrefill + attn_mask_post = attn_metadata.attn_mask[ + token_index:, :max(seq_lens_post)].contiguous() + else: + attn_state_pre = AscendAttentionState.DecodeOnly + attn_mask_pre = None + attn_state_post = AscendAttentionState.ChunkedPrefill + attn_mask_post = attn_metadata.attn_mask[ + token_index:, :max(seq_lens_post)].contiguous() + from vllm_ascend.attention.mla_v1 import (AscendMLADecodeMetadata, + AscendMLAPrefillMetadata) + if num_prefills_pre > 0: + # split metadata.prefill + [input_positions_pre, input_positions_post] = split_attn_tensor_type( + attn_metadata.prefill.input_positions, + token_index - attn_metadata.num_decode_tokens) + [block_tables_pre, block_tables_post + ] = split_attn_tensor_type(attn_metadata.prefill.block_table, + seq_index - attn_metadata.num_decodes) + [prefill_query_lens_pre, prefill_query_lens_post + ] = split_attn_tensor_type(attn_metadata.prefill.query_lens, + seq_index - attn_metadata.num_decodes) + prefill_query_start_loc_pre = attn_metadata.prefill.query_start_loc[: + seq_index + + + 1 - + attn_metadata + . + num_decodes] + prefill_query_start_loc_post = deepcopy( + attn_metadata.prefill.query_start_loc[seq_index - + attn_metadata.num_decodes:] + ) - attn_metadata.prefill.query_start_loc[seq_index - + attn_metadata.num_decodes] + context_len_pre = seq_lens_pre[attn_metadata.num_decodes:] + context_len_post = seq_lens_post + prefill_max_query_len_pre = max(prefill_query_lens_pre) + prefill_max_query_len_post = max(prefill_query_lens_post) + prefill_pre = AscendMLAPrefillMetadata( + attn_mask=attn_mask_pre, + query_lens=prefill_query_lens_pre, + seq_lens=seq_lens_pre, + query_start_loc=prefill_query_start_loc_pre, + input_positions=input_positions_pre, + context_lens=context_len_pre, + block_table=block_tables_pre, + max_query_len=prefill_max_query_len_pre, + max_seq_lens=context_len_pre.max().item(), + ) + prefill_post = AscendMLAPrefillMetadata( + attn_mask=attn_mask_post, + query_lens=prefill_query_lens_post, + seq_lens=seq_lens_post, + query_start_loc=prefill_query_start_loc_post, + input_positions=input_positions_post, + context_lens=context_len_post, + block_table=block_tables_post, + max_query_len=prefill_max_query_len_post, + max_seq_lens=context_len_post.max().item(), + ) + decode_pre = attn_metadata.decode + decode_post = None + else: + # prefill is None, split metadata.decode + [input_positions_pre, input_positions_post + ] = split_attn_tensor_type(attn_metadata.decode.input_positions, + token_index) + [block_tables_pre, block_tables_post + ] = split_attn_tensor_type(attn_metadata.decode.block_table, + seq_index) + [decode_seq_lens_pre, + decode_seq_lens_post] = split_attn_tensor_type(seq_lens, seq_index) + decode_pre = AscendMLADecodeMetadata( + input_positions=input_positions_pre, + block_table=block_tables_pre, + seq_lens=decode_seq_lens_pre, + max_seq_lens=max(decode_seq_lens_pre), + seq_lens_list=decode_seq_lens_pre.tolist(), + ) + decode_post = AscendMLADecodeMetadata( + input_positions=input_positions_post, + block_table=block_tables_post, + seq_lens=decode_seq_lens_post, + max_seq_lens=max(decode_seq_lens_post), + seq_lens_list=decode_seq_lens_post.tolist(), + ) + prefill_pre = None + prefill_post = attn_metadata.prefill + # construct metadata + from vllm_ascend.attention.mla_v1 import AscendMLAPrefillMetadata + attention_metadata_pre = _metadata_cls( + num_actual_tokens=token_index, + num_input_tokens=token_index, + head_dim=attn_metadata.head_dim, + slot_mapping=slot_mapping_pre, + seq_lens=seq_lens_pre, + query_start_loc=query_start_loc_pre, + block_tables=block_table_pre, + num_decodes=num_decodes_pre, + num_prefills=num_prefills_pre, + num_decode_tokens=num_decode_tokens_pre, + attn_state=attn_state_pre, + attn_mask=attn_mask_pre, + prefill=prefill_pre, + decode=decode_pre, + with_prefill_across_dp=attn_metadata.with_prefill_across_dp, + ) + attention_metadata_post = _metadata_cls( + num_actual_tokens=attn_metadata.num_actual_tokens - token_index, + num_input_tokens=attn_metadata.num_input_tokens - token_index, + head_dim=attn_metadata.head_dim, + slot_mapping=slot_mapping_post, + seq_lens=seq_lens_post, + query_start_loc=query_start_loc_post, + block_tables=block_table_post, + num_decodes=num_decodes_post, + num_prefills=num_prefills_post, + num_decode_tokens=num_decode_tokens_post, + attn_mask=attn_mask_post, + attn_state=attn_state_post, + prefill=prefill_post, + decode=decode_post, + with_prefill_across_dp=attn_metadata.with_prefill_across_dp, + ) + return [attention_metadata_pre, attention_metadata_post] diff --git a/vllm_ascend/ops/fused_moe.py b/vllm_ascend/ops/fused_moe.py index fd06d90..56df04e 100644 --- a/vllm_ascend/ops/fused_moe.py +++ b/vllm_ascend/ops/fused_moe.py @@ -1151,3 +1151,32 @@ class AscendFusedMoE(FusedMoE): if self.enable_multistream_shared_expert and not is_prefill: return hidden_states, shared_output return hidden_states + + # ----------------------------------------- TBO-related -------------------------------------------- + + def _forward_ms_fused_moe_comp( + self, + hidden_states: torch.Tensor, + router_logits: torch.Tensor, + is_prefill: bool, + real_top_k, + enable_force_load_balance: bool = False, + ): + hidden_states = self.quant_method.apply( + layer=self, + x=hidden_states, + router_logits=router_logits, + top_k=real_top_k, + renormalize=self.renormalize, + use_grouped_topk=self.use_grouped_topk, + global_num_experts=self.global_num_experts, + expert_map=self.expert_map, + topk_group=self.topk_group, + num_expert_group=self.num_expert_group, + custom_routing_function=self.custom_routing_function, + scoring_func=self.scoring_func, + e_score_correction_bias=self.e_score_correction_bias, + is_prefill=is_prefill, + enable_force_load_balance=enable_force_load_balance) + + return hidden_states