From 87ebaef4e4e519988f27a6aa378f614642202ecf Mon Sep 17 00:00:00 2001
From: zxdukki
Date: Sat, 7 Jun 2025 16:46:58 +0800
Subject: [PATCH] [perf]: support dual-batch overlap(dbo) for deepseek (#941)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

### What this PR does / why we need it?
Based on the dual-batch overlap (DBO) design proposed by the DeepSeek team, and on the fused MoE implementation in the vLLM project, we implement multi-stream (also known as dual-batch) overlap for DeepSeek + MLA on Ascend NPU. We split the model's input batch into two micro-batches and then overlap the computation/communication ops of the attention and MoE layers across two streams to improve performance. Our approach can be easily extended when adding dispatch/combine communications for the MoE layer.

Compared with the previously proposed [draft](https://github.com/vllm-project/vllm-ascend/pull/842), we use one stream for computation ops and the other for communication ops. In our opinion, this makes it easier to arrange the execution order of the different ops and thus avoids contention between computation and communication resources.

ref: [overlap for llama](https://github.com/vllm-project/vllm/pull/15787/files)
ref: [dbo in sglang](https://github.com/sgl-project/sglang/pull/4068/files#diff-b4937569fc71f6ad215181b633b2f89c7183a2b4ac39e41fc22635599a9be7de)

### Does this PR introduce _any_ user-facing change?
Adds the env variable "VLLM_ASCEND_ENABLE_DBO". Users can enable DBO by setting "VLLM_ASCEND_ENABLE_DBO=1".

See /examples/offline_dualbatch_overlap_npu.py for more info.

### How was this patch tested?
This patch can be tested with vllm-0.9.0 using its online service with benchmark tests. We have decoupled the DBO functionality from vllm, so it should run without any modification to the vllm code (although some of the modifications would be better implemented inside vllm). Any advice/discussion is welcome.

### Performance Benchmark
We ran vllm's benchmark_serving script to test the performance with dual-batch overlap enabled:

`python -m vllm.entrypoints.openai.api_server \
 --model=DeepSeek-R1-W8A8 \
 --trust-remote-code \
 --distributed-executor-backend=mp \
 -tp=16 \
 --port 8006 \
 --max-num-seqs 390 \
 --max-model-len 32768 \
 --max-num-batched-tokens 65536 \
 --block-size 128 \
 --compilation_config 0 \
 --gpu-memory-utilization 0.90 \
 --disable-log-requests \
 --additional-config '{"expert_tensor_parallel_size":1,"enable_inter_dp_scheduling":true,"init_torchair_graph_batch_sizes":true,"trace_recompiles":true,"ascend_scheduler_config":{},"enable_graph_mode":false}'`

and ran the benchmark with the parameters:
`--dataset-name random --random-input-len 4096 --random-output-len 1 --num-prompts 200 --max-concurrency 8 --request-rate 5 --metric-percentiles 90`

1. Test with the version using allgather+allreduce on Ascend 910B (tp16 ep16 + DeepSeek R1 W8A8).
2. Test with the version using alltoall: prefill qps: 0.90 -> 1.01, mean TTFT: 8226 ms -> 7432 ms.

The overlap approach when using alltoall communication can be further optimized by overlapping micro-batch 1's MoE computation with micro-batch 2's dispatch (all-to-all) communication.
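A minimal sketch of the two-stream pattern this PR applies (illustration only, not part of the patch): an event recorded on the compute stream gates the work submitted to the communication stream, so one micro-batch's communication can run while the other micro-batch's computation is being issued. The helper names (`dbo_layer`, `compute_op`, `comm_op`) are hypothetical; the stream/event calls mirror the `torch.npu` APIs used in `vllm_ascend/multistream/metadata.py`.

```python
import torch
import torch_npu  # noqa: F401  # assumed: Ascend adapter providing torch.npu

def dbo_layer(batch, compute_op, comm_op, comm_stream):
    """Hypothetical helper: split `batch` into two micro-batches and overlap
    one micro-batch's communication with the other's computation."""
    outputs = [None, None]
    for i, mb in enumerate(torch.tensor_split(batch, 2, dim=0)):
        out = compute_op(mb)       # issued on the current (compute) stream
        before_comm = torch.npu.Event()
        before_comm.record()       # marks the end of this micro-batch's compute
        with torch.npu.stream(comm_stream):
            before_comm.wait()     # comm waits only on its own micro-batch
            outputs[i] = comm_op(out)  # e.g. a tensor-parallel all-reduce
        # the loop then issues the next micro-batch's compute_op on the
        # compute stream while comm_op above runs on comm_stream
    torch.npu.current_stream().wait_stream(comm_stream)
    return torch.cat(outputs, dim=0)
```

In the patch itself, `comm_op` corresponds to ops such as `tensor_model_parallel_all_reduce`, `comm_stream` to the `communicate_stream` created in `make_multistream_metadata_ds`, and ordering is tracked per layer and micro-batch via `MSEventKey` events rather than ad-hoc ones.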
---------

Signed-off-by: zhuohuan
---
 examples/offline_dualbatch_overlap_npu.py     |   51 +
 .../test_offline_inference_distributed.py     |   14 +
 vllm_ascend/attention/mla_v1.py               |   69 +-
 vllm_ascend/envs.py                           |    2 +
 vllm_ascend/models/__init__.py                |   14 +-
 vllm_ascend/models/deepseek_dbo.py            | 1118 +++++++++++++++++
 vllm_ascend/multistream/__init__.py           |    0
 vllm_ascend/multistream/base.py               |   29 +
 vllm_ascend/multistream/context.py            |   67 +
 vllm_ascend/multistream/decorator.py          |   26 +
 vllm_ascend/multistream/layers.py             |   61 +
 vllm_ascend/multistream/metadata.py           |  182 +++
 vllm_ascend/multistream/ms_split.py           |  245 ++++
 vllm_ascend/ops/fused_moe.py                  |   29 +
 14 files changed, 1896 insertions(+), 11 deletions(-)
 create mode 100644 examples/offline_dualbatch_overlap_npu.py
 create mode 100644 vllm_ascend/models/deepseek_dbo.py
 create mode 100644 vllm_ascend/multistream/__init__.py
 create mode 100644 vllm_ascend/multistream/base.py
 create mode 100644 vllm_ascend/multistream/context.py
 create mode 100644 vllm_ascend/multistream/decorator.py
 create mode 100644 vllm_ascend/multistream/layers.py
 create mode 100644 vllm_ascend/multistream/metadata.py
 create mode 100644 vllm_ascend/multistream/ms_split.py

diff --git a/examples/offline_dualbatch_overlap_npu.py b/examples/offline_dualbatch_overlap_npu.py
new file mode 100644
index 0000000..d8153e3
--- /dev/null
+++ b/examples/offline_dualbatch_overlap_npu.py
@@ -0,0 +1,51 @@
+import os
+import time
+
+from vllm import LLM, SamplingParams
+
+# enable dual-batch overlap for vllm ascend
+os.environ["VLLM_ASCEND_ENABLE_DBO"] = "1"
+os.environ["VLLM_USE_V1"] = "1"
+
+# Sample prompts.
+prompts = ["The president of the United States is"] * 41
+# Create a sampling params object.
+sampling_params = SamplingParams(max_tokens=100, temperature=0.0)
+
+
+def main():
+    # Create an LLM.
+    llm = LLM(model="deepseek-ai/DeepSeek-V3-Lite-base-latest-w8a8-dynamic",
+              enforce_eager=True,
+              tensor_parallel_size=2,
+              max_model_len=4096,
+              trust_remote_code=True,
+              additional_config={
+                  "torchair_graph_config": {
+                      "enabled": False
+                  },
+                  "ascend_scheduler_config": {
+                      "enabled": True
+                  },
+                  "expert_tensor_parallel_size": 1
+              })
+
+    # Generate texts from the prompts. The output is a list of RequestOutput
+    # objects that contain the prompt, generated text, and other information.
+    outputs = llm.generate(prompts, sampling_params)
+
+    # Print the outputs.
+    print("-" * 50)
+    for output in outputs:
+        prompt = output.prompt
+        generated_text = output.outputs[0].text
+        print(f"Prompt: {prompt!r}\nGenerated text: {generated_text!r}")
+        print("-" * 50)
+
+    # Add a buffer to wait for profiler in the background process
+    # (in case MP is on) to finish writing profiling output.
+ time.sleep(10) + + +if __name__ == "__main__": + main() diff --git a/tests/multicard/test_offline_inference_distributed.py b/tests/multicard/test_offline_inference_distributed.py index dd8da8c..50675cf 100644 --- a/tests/multicard/test_offline_inference_distributed.py +++ b/tests/multicard/test_offline_inference_distributed.py @@ -81,3 +81,17 @@ def test_models_distributed_topk() -> None: distributed_executor_backend="mp", ) as vllm_model: vllm_model.generate(example_prompts, sampling_params) + + +@patch.dict(os.environ, {"VLLM_ASCEND_ENABLE_DBO": "1"}) +def test_models_distributed_DeepSeek_dbo(): + example_prompts = ["The president of the United States is"] * 41 + dtype = "half" + sampling_params = SamplingParams(max_tokens=100, temperature=0.0) + with VllmRunner( + "deepseek-ai/DeepSeek-V2-Lite", + dtype=dtype, + tensor_parallel_size=4, + distributed_executor_backend="mp", + ) as vllm_model: + vllm_model.generate(example_prompts, sampling_params) diff --git a/vllm_ascend/attention/mla_v1.py b/vllm_ascend/attention/mla_v1.py index ae3dd62..91ddf43 100644 --- a/vllm_ascend/attention/mla_v1.py +++ b/vllm_ascend/attention/mla_v1.py @@ -13,6 +13,9 @@ from vllm.model_executor.layers.linear import (LinearBase, from vllm_ascend.ascend_config import get_ascend_config from vllm_ascend.attention.attention_v1 import AscendAttentionState +from vllm_ascend.multistream.base import MSAttentionMetadataSplitConfig +from vllm_ascend.multistream.context import get_multistream_comm_context +from vllm_ascend.multistream.ms_split import model_input_split_v1_mla_attn from vllm_ascend.ops.attention import vanilla_chunked_prefill_mla if TYPE_CHECKING: @@ -117,6 +120,7 @@ class AscendMLAMetadata: with_prefill_across_dp: bool = False + query_lens: Optional[list[int]] = None # The dimension of the attention heads head_dim: Optional[int] = None attn_mask: torch.Tensor = None @@ -135,6 +139,17 @@ class AscendMLAMetadata: # f"Only {supported_head_sizes} are supported for head_dim,", # f"received {self.head_dim}.") + def split_metadata_for_multistream( + self, + ms_split_config: MSAttentionMetadataSplitConfig, + ) -> list["AscendMLAMetadata"]: + """Split metadata for multi-stream with AscendMLAMetadata""" + return model_input_split_v1_mla_attn( + ms_split_config=ms_split_config, + attn_metadata=self, + _metadata_cls=AscendMLAMetadata, + ) + M = TypeVar("M", bound=AscendMLAMetadata) @@ -386,6 +401,7 @@ class AscendMLAMetadataBuilder: return self.metadata_cls( # type: ignore num_actual_tokens=num_actual_tokens, + query_lens=query_lens.tolist(), slot_mapping=slot_mapping, head_dim=self.runner.model_config.get_head_size(), num_decodes=self._num_decodes, @@ -585,7 +601,15 @@ class AscendMLAImpl(MLAAttentionImpl): ) attn_output = attn_output.reshape( [num_tokens, self.num_heads * self.v_head_dim]) - return self.o_proj(attn_output)[0] + + current_ms_metadata = get_multistream_comm_context() + if current_ms_metadata is None: + return self.o_proj(attn_output)[0] + else: + current_ms_metadata.before_comm_event.record() + with torch.npu.stream(current_ms_metadata.comm_stream): + current_ms_metadata.before_comm_event.wait() + return self.o_proj(attn_output)[0] def exec_kv( self, @@ -685,7 +709,14 @@ class AscendMLAImpl(MLAAttentionImpl): context_lens=attn_metadata.decode.seq_lens, # type:ignore mla_vheadsize=self.kv_lora_rank, out=attn_output) - return self._v_up_proj_and_o_proj(attn_output) + current_ms_metadata = get_multistream_comm_context() + if current_ms_metadata is None: + return self._v_up_proj_and_o_proj(attn_output) + 
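+            # When a multi-stream (DBO) context is active, record an event on
+            # the compute stream and make the comm stream wait on it, then run
+            # the output projection on the comm stream so that its
+            # tensor-parallel all-reduce overlaps with the other micro-batch's
+            # computation.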
else: + current_ms_metadata.before_comm_event.record() + with torch.npu.stream(current_ms_metadata.comm_stream): + current_ms_metadata.before_comm_event.wait() + return self._v_up_proj_and_o_proj(attn_output) def forward( self, @@ -811,16 +842,38 @@ class AscendMLAImpl(MLAAttentionImpl): key_cache=kv_cache, slot_indices=attn_metadata.slot_mapping.flatten()) if has_prefill: - output[num_decode_tokens:] = self._forward_prefill( - prefill_q, prefill_k_c_normed, prefill_k_pe, kv_cache, - attn_metadata) + # FIX: aicore move should be also placed on the comm stream in dbo, + # otherwise it may affect the accuracy + # TODO: use an elegant way to overlap + output_prefill = self._forward_prefill(prefill_q, + prefill_k_c_normed, + prefill_k_pe, kv_cache, + attn_metadata) + current_ms_metadata = get_multistream_comm_context() + if current_ms_metadata is not None: + with torch.npu.stream(current_ms_metadata.comm_stream): + output[num_decode_tokens:] = output_prefill + current_ms_metadata.after_comm_event.record() + else: + output[num_decode_tokens:] = output_prefill + if has_decode: if self.running_in_graph: return self._forward_decode(decode_ql_nope, decode_q_pe, decode_k_nope, decode_k_pe, kv_cache, attn_metadata) else: - output[:num_decode_tokens] = self._forward_decode( - decode_ql_nope, decode_q_pe, decode_k_nope, decode_k_pe, - kv_cache, attn_metadata) + output_decode = self._forward_decode(decode_ql_nope, + decode_q_pe, + decode_k_nope, + decode_k_pe, kv_cache, + attn_metadata) + current_ms_metadata = get_multistream_comm_context() + if current_ms_metadata is not None: + with torch.npu.stream(current_ms_metadata.comm_stream): + output[:num_decode_tokens] = output_decode + current_ms_metadata.after_comm_event.record() + else: + output[:num_decode_tokens] = output_decode + return output_padded diff --git a/vllm_ascend/envs.py b/vllm_ascend/envs.py index 23557f1..9aa2d70 100644 --- a/vllm_ascend/envs.py +++ b/vllm_ascend/envs.py @@ -107,6 +107,8 @@ env_variables: Dict[str, Callable[[], Any]] = { # Whether to enable the trace recompiles from pytorch. "VLLM_ASCEND_TRACE_RECOMPILES": lambda: bool(int(os.getenv("VLLM_ASCEND_TRACE_RECOMPILES", '0'))), + "VLLM_ASCEND_ENABLE_DBO": + lambda: bool(int(os.getenv("VLLM_ASCEND_ENABLE_DBO", '0'))), # Whether to enable the model execute time observe profile. Disable it when # running vllm ascend in production environment. 
"VLLM_ASCEND_MODEL_EXECUTE_TIME_OBSERVE": diff --git a/vllm_ascend/models/__init__.py b/vllm_ascend/models/__init__.py index e7f021f..4357787 100644 --- a/vllm_ascend/models/__init__.py +++ b/vllm_ascend/models/__init__.py @@ -1,7 +1,10 @@ from vllm import ModelRegistry +import vllm_ascend.envs as envs + def register_model(): + from .deepseek_dbo import CustomDeepseekDBOForCausalLM # noqa: F401 from .deepseek_mtp import CustomDeepSeekMTP # noqa: F401 from .deepseek_v2 import CustomDeepseekV2ForCausalLM # noqa: F401 from .deepseek_v2 import CustomDeepseekV3ForCausalLM # noqa: F401 @@ -22,9 +25,14 @@ def register_model(): "vllm_ascend.models.qwen2_5_vl:AscendQwen2_5_VLForConditionalGeneration" ) - ModelRegistry.register_model( - "DeepseekV2ForCausalLM", - "vllm_ascend.models.deepseek_v2:CustomDeepseekV2ForCausalLM") + if envs.VLLM_ASCEND_ENABLE_DBO: + ModelRegistry.register_model( + "DeepseekV2ForCausalLM", + "vllm_ascend.models.deepseek_dbo:CustomDeepseekDBOForCausalLM") + else: + ModelRegistry.register_model( + "DeepseekV2ForCausalLM", + "vllm_ascend.models.deepseek_v2:CustomDeepseekV2ForCausalLM") ModelRegistry.register_model( "DeepseekV3ForCausalLM", diff --git a/vllm_ascend/models/deepseek_dbo.py b/vllm_ascend/models/deepseek_dbo.py new file mode 100644 index 0000000..c3de6ae --- /dev/null +++ b/vllm_ascend/models/deepseek_dbo.py @@ -0,0 +1,1118 @@ +# SPDX-License-Identifier: Apache-2.0 +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# Copyright 2023 The vLLM team. +# Copyright 2023 DeepSeek-AI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# # Adapted from +# # vllm-project/vllm/blob/main/vllm/model_executor/models/deepseek_v2.py +# # https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py +# # vllm-project/vllm/vllm/model_executor/models/deepseek_v2.py +# """Inference-only DeepseekV2/DeepseekV3 model.""" + +from typing import Any, Dict, List, Optional, Union + +import torch +import torch.distributed as dist +import torch_npu +import vllm.envs as envs +from torch import nn +from transformers import PretrainedConfig +from vllm.attention import Attention, AttentionMetadata +from vllm.config import CacheConfig, ModelConfig, VllmConfig +from vllm.distributed import (get_pp_group, + get_tensor_model_parallel_world_size, + get_tp_group, tensor_model_parallel_all_reduce) +from vllm.distributed.parallel_state import get_dp_group +from vllm.forward_context import get_forward_context +from vllm.model_executor.layers.activation import SiluAndMul +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.linear import (ColumnParallelLinear, + MergedColumnParallelLinear, + ReplicatedLinear, + RowParallelLinear, + UnquantizedLinearMethod) +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.sampler import get_sampler +from vllm.model_executor.layers.vocab_parallel_embedding import ( + ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.models.deepseek_v2 import \ + DeepseekV2ForCausalLM # ruff: noqa: E501 +from vllm.model_executor.models.deepseek_v2 import \ + yarn_get_mscale # ruff: noqa: E501 +from vllm.model_executor.models.deepseek_v2 import (DeepseekV2Attention, + DeepseekV2DecoderLayer, + DeepseekV2MLAAttention) +from vllm.model_executor.models.utils import ( + PPMissingLayer, make_empty_intermediate_tensors_factory, make_layers, + maybe_prefix) +from vllm.sequence import IntermediateTensors + +import vllm_ascend.envs as envs_ascend +from vllm_ascend.ascend_config import get_ascend_config +from vllm_ascend.multistream.base import MSEventKey +from vllm_ascend.multistream.context import ( + advance_step_multistream_layer_context, get_multistream_comm_context, + get_multistream_layer_context, set_multistream_context) +from vllm_ascend.multistream.layers import (MultiStreamPostTransformerLayer, + MultiStreamPreTransformerLayer) +from vllm_ascend.multistream.metadata import (MultiStreamConfig, + MultiStreamStepMetadata, + make_multistream_metadata_ds) +from vllm_ascend.multistream.ms_split import compute_split_seq_index +from vllm_ascend.ops.fused_moe import AscendFusedMoE +from vllm_ascend.quantization.w8a8_dynamic import AscendW8A8DynamicLinearMethod +from vllm_ascend.utils import dispose_tensor + +VLLM_ASCEND_ENABLE_DBO: bool = envs_ascend.VLLM_ASCEND_ENABLE_DBO +VLLM_ENABLE_MC2: bool = envs_ascend.VLLM_ENABLE_MC2 + + +class CustomDeepseekDBOMLP(nn.Module): + + def __init__( + self, + hidden_size: int, + intermediate_size: int, + hidden_act: str, + quant_config: Optional[QuantizationConfig] = None, + reduce_results: bool = True, + prefix: str = "", + ) -> None: + super().__init__() + self.gate_up_proj = MergedColumnParallelLinear( + hidden_size, [intermediate_size] * 2, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.gate_up_proj") + self.down_proj = RowParallelLinear(intermediate_size, + hidden_size, + bias=False, + 
quant_config=quant_config, + reduce_results=reduce_results, + prefix=f"{prefix}.down_proj") + if hidden_act != "silu": + raise ValueError(f"Unsupported activation: {hidden_act}. " + "Only silu is supported for now.") + self.act_fn = SiluAndMul() + + # NOTE: `torch_npu.npu_dequant_swiglu_quant` can only be enabled in dynamic quant + self.is_dynamic_quant = not isinstance( + self.gate_up_proj.quant_method, + UnquantizedLinearMethod) and isinstance( + self.gate_up_proj.quant_method.quant_method, + AscendW8A8DynamicLinearMethod) + + def forward(self, x): + if self.is_dynamic_quant: + x, dynamic_scale = torch_npu.npu_dynamic_quant(x) + x = torch_npu.npu_quant_matmul( + x, + self.gate_up_proj.weight, + self.gate_up_proj.weight_scale, + output_dtype=torch.int32, + ) + x, dynamic_scale = torch_npu.npu_dequant_swiglu_quant( + x=x, + weight_scale=self.gate_up_proj.weight_scale_fp32, + activation_scale=dynamic_scale, + bias=None, + quant_scale=None, + quant_offset=None, + group_index=None, + activate_left=True, + quant_mode=1) + x = torch_npu.npu_quant_matmul( + x, + self.down_proj.weight, + self.down_proj.weight_scale, + pertoken_scale=dynamic_scale, + output_dtype=torch.bfloat16, + ) + if self.down_proj.reduce_results and self.down_proj.tp_size > 1: + x = tensor_model_parallel_all_reduce(x) + return x + gate_up, _ = self.gate_up_proj(x) + x = self.act_fn(gate_up) + x, _ = self.down_proj(x) + return x + + def _forward_ms_mlp(self, x): + current_ms_metadata = get_multistream_comm_context() + assert current_ms_metadata is not None + if self.is_dynamic_quant: + x, dynamic_scale = torch_npu.npu_dynamic_quant(x) + x = torch_npu.npu_quant_matmul( + x, + self.gate_up_proj.weight, + self.gate_up_proj.weight_scale, + output_dtype=torch.int32, + ) + x, dynamic_scale = torch_npu.npu_dequant_swiglu_quant( + x=x, + weight_scale=self.gate_up_proj.weight_scale_fp32, + activation_scale=dynamic_scale, + bias=None, + quant_scale=None, + quant_offset=None, + group_index=None, + activate_left=True, + quant_mode=1) + x = torch_npu.npu_quant_matmul( + x, + self.down_proj.weight, + self.down_proj.weight_scale, + pertoken_scale=dynamic_scale, + output_dtype=torch.bfloat16, + ) + if self.down_proj.reduce_results and self.down_proj.tp_size > 1: + current_ms_metadata.before_comm_event.record() + with torch.npu.stream(current_ms_metadata.comm_stream): + current_ms_metadata.before_comm_event.wait() + x = tensor_model_parallel_all_reduce(x) + current_ms_metadata.after_comm_event.record() + return x + gate_up, _ = self.gate_up_proj(x) + x = self.act_fn(gate_up) + current_ms_metadata.before_comm_event.record() + with torch.npu.stream(current_ms_metadata.comm_stream): + current_ms_metadata.before_comm_event.wait() + x, _ = self.down_proj(x) + current_ms_metadata.after_comm_event.record() + return x + + +class CustomDeepseekDBOMoE(nn.Module): + + top_k: int + + def __init__( + self, + config: PretrainedConfig, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): + super().__init__() + self.tp_size = get_tensor_model_parallel_world_size() + self.routed_scaling_factor = config.routed_scaling_factor + self.n_shared_experts = config.n_shared_experts + self.routed_scaling_factor = config.routed_scaling_factor + if self.tp_size > config.n_routed_experts: + raise ValueError( + f"Tensor parallel size {self.tp_size} is greater than " + f"the number of experts {config.n_routed_experts}.") + + if config.hidden_act != "silu": + raise ValueError(f"Unsupported activation: {config.hidden_act}. 
" + "Only silu is supported for now.") + + self.gate = ReplicatedLinear(config.hidden_size, + config.n_routed_experts, + bias=False, + quant_config=None, + prefix=f"{prefix}.gate") + if config.topk_method == "noaux_tc": + self.gate.e_score_correction_bias = nn.Parameter( + torch.empty(config.n_routed_experts)) + else: + self.gate.e_score_correction_bias = None + + self.experts = AscendFusedMoE( + num_experts=config.n_routed_experts, + top_k=config.num_experts_per_tok, + hidden_size=config.hidden_size, + intermediate_size=config.moe_intermediate_size, + reduce_results=False, + renormalize=config.norm_topk_prob, + quant_config=quant_config, + use_grouped_topk=True, + num_expert_group=config.n_group, + topk_group=config.topk_group, + prefix=f"{prefix}.experts", + scoring_func=config.scoring_func, + e_score_correction_bias=self.gate.e_score_correction_bias) + + if config.n_shared_experts is not None: + intermediate_size = (config.moe_intermediate_size * + config.n_shared_experts) + self.shared_experts = CustomDeepseekDBOMLP( + hidden_size=config.hidden_size, + intermediate_size=intermediate_size, + hidden_act=config.hidden_act, + quant_config=quant_config, + reduce_results=True, + prefix=f"{prefix}.shared_experts", + ) + CustomDeepseekDBOMoE.top_k = config.num_experts_per_tok + + self.dp_size = get_dp_group().world_size + + self.tp_group = get_tp_group().device_group + self.tp_rank = get_tp_group().rank_in_group + + self.params_dtype = torch.get_default_dtype() + + ascend_config = get_ascend_config() + self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled + + def forward( + self, + hidden_states: torch.Tensor, + attn_metadata: Optional[AttentionMetadata] = None) -> torch.Tensor: + if attn_metadata is None: + attn_metadata = get_forward_context().attn_metadata + # when profile runs, force experts to load balanced tokens + # to avoid high memory consumption on a single rank. + # TODO: need a better flag to indicate whether in profile run or not. 
+ if attn_metadata is None: + # for profile run + is_prefill = True + enable_force_load_balance = True + else: + is_prefill = attn_metadata.num_prefills > 0 + enable_force_load_balance = False + if hasattr(attn_metadata, 'with_prefill_across_dp'): + is_prefill = is_prefill or attn_metadata.with_prefill_across_dp + + num_tokens, hidden_size = hidden_states.shape + + old_hidden_states = hidden_states.clone() + + if self.tp_size > 1: + if envs_ascend.VLLM_ENABLE_MC2 and not is_prefill: + chunks = torch.chunk(hidden_states, self.tp_size, dim=0) + hidden_states = chunks[self.tp_rank] + elif not self.torchair_graph_enabled: + num_padding_tokens = (self.tp_size - + num_tokens % self.tp_size) % self.tp_size + # Pad hidden_states to make it divisible by tp_size to avoid cross-ring AllGatherV on 910B2C + if num_padding_tokens > 0: + hidden_states = nn.functional.pad( + hidden_states, (0, 0, 0, num_padding_tokens)) + chunk_hidden_states = torch.tensor_split(hidden_states, + self.tp_size, + dim=0) + hidden_states = chunk_hidden_states[self.tp_rank] + + # router_logits: (num_tokens, n_experts) + router_logits, _ = self.gate(hidden_states) + + hidden_states = self.experts( + hidden_states=hidden_states, + router_logits=router_logits, + is_prefill=is_prefill, + top_k=CustomDeepseekDBOMoE.top_k, + enable_force_load_balance=enable_force_load_balance, + ) * self.routed_scaling_factor + + if self.tp_size > 1: + if self.torchair_graph_enabled: + if envs_ascend.VLLM_ENABLE_MC2 and not is_prefill: + final_hidden_states = torch.zeros( + [num_tokens, hidden_size], + dtype=self.params_dtype, + device="npu") + dist.all_gather_into_tensor(final_hidden_states, + hidden_states, self.tp_group) + hidden_states = final_hidden_states + else: + hidden_states = tensor_model_parallel_all_reduce( + hidden_states) + else: + dist.all_gather(list(chunk_hidden_states), hidden_states, + self.tp_group) + hidden_states = torch.cat(chunk_hidden_states, dim=0) + if num_padding_tokens > 0: + hidden_states = hidden_states[:-num_padding_tokens] + + if self.n_shared_experts is not None: + shared_output = self.shared_experts(old_hidden_states) + + if shared_output is not None: + hidden_states = hidden_states + shared_output + + return hidden_states.view(num_tokens, hidden_size) + + # ----------------------------------------- TBO-related -------------------------------------------- + def _forward_ms_op_shared_expert( + self, + hidden_states: torch.Tensor, + ): + shared_output = self.shared_experts._forward_ms_mlp(hidden_states) + return shared_output + + def _forward_ms_op_gate( + self, + hidden_states: torch.Tensor, + ): + # router_logits: (num_tokens, n_experts) + router_logits, _ = self.gate(hidden_states) + return router_logits + + def _forward_ms_op_tp_allgather( + self, + hidden_states: torch.Tensor, + chunk_hidden_states: torch.Tensor, + num_tokens: int = 0, + ): + current_ms_metadata = get_multistream_comm_context() + if current_ms_metadata is None: + dist.all_gather(list(chunk_hidden_states), hidden_states, + self.tp_group) + final_hidden_states = torch.cat(chunk_hidden_states, dim=0) + if num_tokens > 0: + final_hidden_states = final_hidden_states[:-num_tokens] + else: + current_ms_metadata.before_comm_event.record() + with torch.npu.stream(current_ms_metadata.comm_stream): + current_ms_metadata.before_comm_event.wait() + dist.all_gather(list(chunk_hidden_states), hidden_states, + self.tp_group) + final_hidden_states = torch.cat(chunk_hidden_states, dim=0) + if num_tokens > 0: + final_hidden_states = 
final_hidden_states[:-num_tokens] + current_ms_metadata.after_comm_event.record() + return final_hidden_states + + +class CustomDeepseekDBOMLAAttention(DeepseekV2MLAAttention): + + def __init__( + self, + config: PretrainedConfig, + hidden_size: int, + num_heads: int, + qk_nope_head_dim: int, + qk_rope_head_dim: int, + v_head_dim: int, + q_lora_rank: Optional[int], + kv_lora_rank: int, + rope_theta: float = 10000, + rope_scaling: Optional[Dict[str, Any]] = None, + max_position_embeddings: int = 8192, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + nn.Module.__init__(self) + self.hidden_size = hidden_size + self.qk_nope_head_dim = qk_nope_head_dim + self.qk_rope_head_dim = qk_rope_head_dim + self.qk_head_dim = qk_nope_head_dim + qk_rope_head_dim + self.v_head_dim = v_head_dim + + self.q_lora_rank = q_lora_rank + self.kv_lora_rank = kv_lora_rank + + self.num_heads = num_heads + tp_size = get_tensor_model_parallel_world_size() + assert num_heads % tp_size == 0 + self.num_local_heads = num_heads // tp_size + + self.scaling = self.qk_head_dim**-0.5 + self.rope_theta = rope_theta + self.max_position_embeddings = max_position_embeddings + + if self.q_lora_rank is not None: + self.q_a_proj = ReplicatedLinear(self.hidden_size, + self.q_lora_rank, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.q_a_proj") + self.q_a_layernorm = RMSNorm(self.q_lora_rank, + eps=config.rms_norm_eps) + self.q_b_proj = ColumnParallelLinear(q_lora_rank, + self.num_heads * + self.qk_head_dim, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.q_b_proj") + else: + self.q_proj = ColumnParallelLinear(self.hidden_size, + self.num_heads * + self.qk_head_dim, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.q_proj") + + self.kv_a_proj_with_mqa = ReplicatedLinear( + self.hidden_size, + self.kv_lora_rank + self.qk_rope_head_dim, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.kv_a_proj_with_mqa") + self.kv_a_layernorm = RMSNorm(self.kv_lora_rank, + eps=config.rms_norm_eps) + self.kv_b_proj = ColumnParallelLinear( + self.kv_lora_rank, + self.num_heads * (self.qk_nope_head_dim + self.v_head_dim), + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.kv_b_proj") + self.o_proj = RowParallelLinear(self.num_heads * self.v_head_dim, + self.hidden_size, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.o_proj") + + if rope_scaling: + rope_scaling["rope_type"] = 'deepseek_yarn' + self.rotary_emb = get_rope(qk_rope_head_dim, + rotary_dim=qk_rope_head_dim, + max_position=max_position_embeddings, + base=rope_theta, + rope_scaling=rope_scaling, + is_neox_style=False) + if rope_scaling: + mscale_all_dim = rope_scaling.get("mscale_all_dim", False) + scaling_factor = rope_scaling["factor"] + mscale = yarn_get_mscale(scaling_factor, float(mscale_all_dim)) + self.scaling = self.scaling * mscale * mscale + + # In the MLA backend, kv_cache includes both k_c and + # pe (i.e. decoupled position embeddings). In particular, + # the concat_and_cache_mla op requires + # k_c.size(1) + k_pe.size(1) == kv_cache.size(2) + # i.e. 
+ # kv_lora_rank + qk_rope_head_dim == head_size + self.mla_attn = Attention( + num_heads=self.num_local_heads, + head_size=self.kv_lora_rank + self.qk_rope_head_dim, + scale=self.scaling, + num_kv_heads=1, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.attn", + use_mla=True, + # MLA Args + q_lora_rank=self.q_lora_rank, + kv_lora_rank=self.kv_lora_rank, + qk_nope_head_dim=self.qk_nope_head_dim, + qk_rope_head_dim=self.qk_rope_head_dim, + qk_head_dim=self.qk_head_dim, + v_head_dim=self.v_head_dim, + rotary_emb=self.rotary_emb, + q_proj=self.q_proj if self.q_lora_rank is None else self.q_b_proj, + kv_a_proj_with_mqa=self.kv_a_proj_with_mqa, + kv_a_layernorm=self.kv_a_layernorm, + kv_b_proj=self.kv_b_proj, + o_proj=self.o_proj, + ) + + self.prefix = prefix + self.debug_layer_idx = int(self.prefix.split(".")[-2]) + + ascend_config = get_ascend_config() + self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: Optional[torch.Tensor] = None, + attn_metadata: Optional[AttentionMetadata] = None) -> torch.Tensor: + if self.q_lora_rank is not None: + ckq = self.q_a_proj(hidden_states)[0] + hidden_states_or_q_c = self.q_a_layernorm(ckq) + else: + hidden_states_or_q_c = hidden_states + if self.torchair_graph_enabled: + forward_kwargs = {} + if envs.VLLM_USE_V1: + output_shape = hidden_states.shape + output = torch.empty(output_shape, + dtype=hidden_states_or_q_c.dtype, + device=hidden_states_or_q_c.device) + forward_kwargs['output'] = output + + output = self.mla_attn.impl.forward(self.mla_attn, + hidden_states_or_q_c, + hidden_states, None, kv_cache, + attn_metadata, + **forward_kwargs) + if envs.VLLM_USE_V1: + output = output.view(-1, output_shape[-1]) + return output + else: + kv_c, k_pe = self.kv_a_proj_with_mqa(hidden_states)[0].split( + [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1) + kv_c_normed = self.kv_a_layernorm(kv_c.contiguous()) + return self.mla_attn(hidden_states_or_q_c, + kv_c_normed, + k_pe, + output_shape=hidden_states.shape) + + +class CustomDeepseekDBODecoderLayer(DeepseekV2DecoderLayer): + + def __init__( + self, + config: PretrainedConfig, + prefix: str, + model_config: ModelConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + ) -> None: + nn.Module.__init__(self) + self.hidden_size = config.hidden_size + rope_theta = getattr(config, "rope_theta", 10000) + rope_scaling = getattr(config, "rope_scaling", None) + max_position_embeddings = getattr(config, "max_position_embeddings", + 8192) + # DecoderLayers are created with `make_layers` which passes the prefix + # with the layer's index. 
+ layer_idx = int(prefix.split(sep='.')[-1]) + self.layer_idx = layer_idx + # TODO: enable mla in vllm-ascend + if model_config.use_mla: + attn_cls = CustomDeepseekDBOMLAAttention + else: + attn_cls = DeepseekV2Attention + self.self_attn = attn_cls( + config=config, + hidden_size=self.hidden_size, + num_heads=config.num_attention_heads, + qk_nope_head_dim=config.qk_nope_head_dim, + qk_rope_head_dim=config.qk_rope_head_dim, + v_head_dim=config.v_head_dim, + q_lora_rank=config.q_lora_rank + if hasattr(config, "q_lora_rank") else None, + kv_lora_rank=config.kv_lora_rank, + rope_theta=rope_theta, + rope_scaling=rope_scaling, + max_position_embeddings=max_position_embeddings, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.self_attn", + ) + + if (config.n_routed_experts is not None + and layer_idx >= config.first_k_dense_replace + and layer_idx % config.moe_layer_freq == 0): + self.mlp = CustomDeepseekDBOMoE( + config=config, + quant_config=quant_config, + prefix=f"{prefix}.mlp", + ) + else: + self.mlp = CustomDeepseekDBOMLP( + hidden_size=config.hidden_size, + intermediate_size=config.intermediate_size, + hidden_act=config.hidden_act, + quant_config=quant_config, + prefix=f"{prefix}.mlp", + ) + self.input_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + self.post_attention_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + self.routed_scaling_factor = config.routed_scaling_factor + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + residual: Optional[torch.Tensor], + kv_cache: Optional[torch.Tensor] = None, + attn_metadata: Optional[AttentionMetadata] = None, + ) -> torch.Tensor: + # Self Attention + if residual is None: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + else: + previous_hidden_states, previous_residual = hidden_states, residual + hidden_states, residual = self.input_layernorm( + hidden_states, residual) + # Dispose hidden_states and residual from the previous layer + # to save npu memory because they're no longer used. + dispose_tensor(previous_hidden_states) + dispose_tensor(previous_residual) + + hidden_states = self.self_attn( + positions=positions, + hidden_states=hidden_states, + kv_cache=kv_cache, + attn_metadata=attn_metadata, + ) + + if hidden_states.dtype == torch.float16: + # Fix FP16 overflow + # We scale both hidden_states and residual before + # rmsnorm, and rmsnorm result would not affect by scale. + hidden_states *= 1. / self.routed_scaling_factor + if self.layer_idx == 0: + # The residual is shared by all layers, we only scale it on + # first layer. + residual *= 1. / self.routed_scaling_factor + + # Fully Connected + hidden_states, residual = self.post_attention_layernorm( + hidden_states, residual) + + if isinstance(self.mlp, CustomDeepseekDBOMoE): + hidden_states = self.mlp(hidden_states, attn_metadata) + else: + hidden_states = self.mlp(hidden_states) + + if isinstance( + self.mlp, + CustomDeepseekDBOMLP) and hidden_states.dtype == torch.float16: + # Fix FP16 overflow + # Scaling the DeepseekV2MLP output, it is the input of + # input_layernorm of next decoder layer. + # The scaling of DeepseekV2MOE output would be done in the forward + # of DeepseekV2MOE + hidden_states *= 1. 
/ self.routed_scaling_factor + + return hidden_states, residual + + # ----------------------------------------- TBO-related -------------------------------------------- + def _forward_ms_layer( + self, + positions: List[torch.Tensor], + hidden_states: List[torch.Tensor], + residual: List[torch.Tensor], + attn_metadata: List[AttentionMetadata], + kv_cache: Optional[torch.Tensor] = None, + is_prefill: bool = False, + ) -> tuple[List[torch.Tensor], List[torch.Tensor]]: + layer_index, ms_metadata, _ = get_multistream_layer_context() + assert layer_index >= 0 and ms_metadata is not None + num_micro_batchs = ms_metadata.ms_config.num_micro_batches + assert isinstance(self.mlp, CustomDeepseekDBOMoE) + assert len(positions) == num_micro_batchs + assert len(hidden_states) == num_micro_batchs + assert residual is not None + assert attn_metadata is not None + num_tokens = [] + hidden_dims = [] + shared_outputs = [] + router_logits = [] + chunk_hidden_states = [] + + # block 1 : attention + # block 2 : attn tp communication + # the attn computation of microbatch 1 can be overlapped with the moe + # communication in the previous layer, and the attn computation of microbatch 2 + # can be overlapped with the attn communication of microbatch 1 + for i in range(num_micro_batchs): + # wait last layer moe finishing communication + ms_metadata.try_wait_event(layer_index - 1, i, + MSEventKey.FFN_AR_FINISH) + context = MultiStreamStepMetadata( + comm_stream=ms_metadata.communicate_stream, + before_comm_event=ms_metadata.ms_events[layer_index][i][ + MSEventKey.ATTN_COM_FINISH], + after_comm_event=ms_metadata.ms_events[layer_index][i][ + MSEventKey.ATTN_AR_FINISH], + ) + + with set_multistream_context(context, i): + forward_context = get_forward_context() + forward_context.attn_metadata = attn_metadata[i] + + # input layernorm + hidden_states[i], residual[ + i] = self._forward_ms_op_input_layernorm( + hidden_states[i], residual[i]) + # attention and tp allreduce + hidden_states[i], residual[i] = self._forward_ms_op_attn( + positions[i], hidden_states[i], residual[i], kv_cache, + attn_metadata[i]) + + # block 3 : shared experts + # if there is an allreduce ops in shared expert, we can overlap it with the computation of the + # shared expert for next microbatch or moe gating + for i in range(num_micro_batchs): + ms_metadata.try_wait_event(layer_index, i, + MSEventKey.ATTN_AR_FINISH) + context = MultiStreamStepMetadata( + comm_stream=ms_metadata.communicate_stream, + before_comm_event=ms_metadata.ms_events[layer_index][i][ + MSEventKey.MOE_SE_COMP_FINISH], + after_comm_event=ms_metadata.ms_events[layer_index][i][ + MSEventKey.MOE_SE_COMM_FINISH], + ) + with set_multistream_context(context, i): + # compute shared expert after finishing ATTN AR + hidden_states[i], residual[ + i] = self._forward_ms_op_post_attn_layernorm( + hidden_states[i], residual[i]) + + num_token, hidden_dim = hidden_states[i].shape + hidden_states[i] = hidden_states[i].view(-1, hidden_dim) + num_tokens.append(num_token) + hidden_dims.append(hidden_dim) + if self.mlp.n_shared_experts is not None: + # TODO: we can move shared expert computation into next block if reduce results is false + shared_output = self.mlp._forward_ms_op_shared_expert( + hidden_states[i]) + shared_outputs.append(shared_output) + + # block 4 : moe + for i in range(num_micro_batchs): + # when profile runs, force experts to load balanced tokens + # to avoid high memory consumption on a single rank. + # TODO: need a better flag to indicate whether in profile run or not. 
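+            # Block 4 per micro-batch: the gate and expert computation are
+            # issued on the compute stream, while the subsequent all-reduce /
+            # all-gather of the routed experts is submitted to the comm
+            # stream, overlapping the next micro-batch's MoE computation.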
+ if attn_metadata[i] is None: + # for profile run + is_prefill = True + enable_force_load_balance = True + else: + is_prefill = attn_metadata[i].num_prefills > 0 + enable_force_load_balance = False + + if self.mlp.tp_size > 1: + num_token, _ = hidden_states[i].shape + padded_num_tokens = (self.mlp.tp_size - num_token % + self.mlp.tp_size) % self.mlp.tp_size + if padded_num_tokens > 0: + hidden_states[i] = nn.functional.pad( + hidden_states[i], (0, 0, 0, padded_num_tokens)) + chunk_hidden_state = torch.tensor_split(hidden_states[i], + self.mlp.tp_size, + dim=0) + chunk_hidden_states.append(chunk_hidden_state) + local_hidden_states = chunk_hidden_state[self.mlp.tp_rank] + else: + local_hidden_states = hidden_states[i] + + router_logit = self.mlp._forward_ms_op_gate(local_hidden_states) + router_logits.append(router_logit) + + if CustomDeepseekDBOMoE.top_k: + real_top_k = CustomDeepseekDBOMoE.top_k + else: + real_top_k = self.mlp.experts.top_k + + hidden_states[i] = self.mlp.experts._forward_ms_fused_moe_comp( + local_hidden_states, router_logits[i], is_prefill, real_top_k, + enable_force_load_balance) + + # the following kernels will be submitted to the comm stream to overlap the computation of the + # moe computation of next microbatch and the attn computation of next layer + context = MultiStreamStepMetadata( + comm_stream=ms_metadata.communicate_stream, + before_comm_event=ms_metadata.ms_events[layer_index][i][ + MSEventKey.FFN_COM_FINISH], + after_comm_event=ms_metadata.ms_events[layer_index][i][ + MSEventKey.MOE_AFTER_COMM], + ) + context.before_comm_event.record() + with torch.npu.stream(ms_metadata.communicate_stream): + context.before_comm_event.wait() + if self.mlp.experts.reduce_results and ( + self.mlp.experts.tp_size > 1 + or self.mlp.experts.ep_size > 1): + hidden_states[i] = tensor_model_parallel_all_reduce( + hidden_states[i]) + hidden_states[ + i] = hidden_states[i] * self.mlp.routed_scaling_factor + context.after_comm_event.record() + + context = MultiStreamStepMetadata( + comm_stream=ms_metadata.communicate_stream, + before_comm_event=ms_metadata.ms_events[layer_index][i][ + MSEventKey.MOE_AFTER_COMM], + after_comm_event=ms_metadata.ms_events[layer_index][i][ + MSEventKey.FFN_AR_FINISH], + ) + with set_multistream_context(context, i): + if self.mlp.tp_size > 1: + hidden_states[i] = self.mlp._forward_ms_op_tp_allgather( + hidden_states[i], chunk_hidden_states[i], + padded_num_tokens) + with torch.npu.stream(ms_metadata.communicate_stream): + # last + if shared_outputs[i] is not None: + hidden_states[i] = hidden_states[i] + shared_outputs[i] + hidden_states[i] = hidden_states[i].view( + num_tokens[i], hidden_dims[i]) + if isinstance(self.mlp, CustomDeepseekDBOMLP + ) and hidden_states[i].dtype == torch.float16: + # Fix FP16 overflow + # Scaling the DeepseekV2MLP output, it is the input of + # input_layernorm of next decoder layer. + # The scaling of DeepseekV2MOE output would be done in the forward + # of DeepseekV2MOE + hidden_states[i] *= 1. 
/ self.routed_scaling_factor + context.after_comm_event.record() + return hidden_states, residual + + # should split ops in Decoder Layer + def _forward_ms_op_input_layernorm( + self, + hidden_states: torch.Tensor, + residual: Optional[torch.Tensor], + ) -> tuple[torch.Tensor, torch.Tensor]: + if residual is None: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + else: + hidden_states, residual = self.input_layernorm( + hidden_states, residual) + return hidden_states, residual + + def _forward_ms_op_attn( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + residual: torch.Tensor, + kv_cache: Optional[torch.Tensor] = None, + attn_metadata: Optional[AttentionMetadata] = None, + ) -> tuple[torch.Tensor, torch.Tensor]: + hidden_states = self.self_attn( + positions=positions, + hidden_states=hidden_states, + kv_cache=kv_cache, + attn_metadata=attn_metadata, + ) + if hidden_states.dtype == torch.float16: + # Fix FP16 overflow + # We scale both hidden_states and residual before + # rmsnorm, and rmsnorm result would not affect by scale. + hidden_states *= 1. / self.routed_scaling_factor + if self.layer_idx == 0: + # The residual is shared by all layers, we only scale it on + # first layer. + residual *= 1. / self.routed_scaling_factor + return hidden_states, residual + + def _forward_ms_op_post_attn_layernorm( + self, + hidden_states: torch.Tensor, + residual: Optional[torch.Tensor], + ): + hidden_states, residual = self.post_attention_layernorm( + hidden_states, residual) + return hidden_states, residual + + +class CustomDeepseekDBOModel(nn.Module): + + fall_back_to_pt_during_load = False + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + + config = vllm_config.model_config.hf_config + model_config = vllm_config.model_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config + + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + self.first_k_dense_replace = config.first_k_dense_replace + + if get_pp_group().is_first_rank: + self.embed_tokens = VocabParallelEmbedding( + config.vocab_size, + config.hidden_size, + quant_config=quant_config, + prefix=f"{prefix}.embed_tokens") + else: + self.embed_tokens = PPMissingLayer() + + self.start_layer, self.end_layer, self.layers = make_layers( + config.num_hidden_layers, + lambda prefix: CustomDeepseekDBODecoderLayer( + config, + prefix, + model_config=model_config, + cache_config=cache_config, + quant_config=quant_config, + ), + prefix=f"{prefix}.layers") + + if get_pp_group().is_last_rank: + self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + else: + self.norm = PPMissingLayer() + self.make_empty_intermediate_tensors = ( + make_empty_intermediate_tensors_factory( + ["hidden_states", "residual"], config.hidden_size)) + + # tbo related members + if VLLM_ASCEND_ENABLE_DBO: + self.use_mla = model_config.use_mla + self.multistream_config = MultiStreamConfig() + multistream_metadata = make_multistream_metadata_ds( + start_layer=self.start_layer + self.first_k_dense_replace, + end_layer=self.end_layer, + causal_lm=getattr(config, "causal_lm", True), + multistream_config=self.multistream_config, + ) + self.ms_pre_layer = MultiStreamPreTransformerLayer( + multistream_metadata) + self.ms_post_layer = MultiStreamPostTransformerLayer( + multistream_metadata) + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.embed_tokens(input_ids) + + def forward( + self, + 
input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: Optional[List[torch.Tensor]] = None, + attn_metadata: Optional[AttentionMetadata] = None, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + if get_pp_group().is_first_rank: + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states = self.get_input_embeddings(input_ids) + residual = None + else: + assert intermediate_tensors is not None + hidden_states = intermediate_tensors["hidden_states"] + residual = intermediate_tensors["residual"] + + num_normal_layers = (self.first_k_dense_replace + if VLLM_ASCEND_ENABLE_DBO and self.can_run_ms() + else self.end_layer - self.start_layer) + + for i in range(self.start_layer, self.start_layer + num_normal_layers): + layer = self.layers[i] + hidden_states, residual = layer( + positions, hidden_states, residual, + kv_caches[i - + self.start_layer] if kv_caches is not None else None, + attn_metadata) + + moe_start_layer = self.start_layer + num_normal_layers + if moe_start_layer != self.end_layer: + # if we enable multistream/dbo, process sparse layers here + hidden_states, residual = self._forward_ms_layers( + positions=positions, + hidden_states=hidden_states, + residual=residual, + moe_start_layer=moe_start_layer, + kv_caches=kv_caches, + ) + + if not get_pp_group().is_last_rank: + return IntermediateTensors({ + "hidden_states": hidden_states, + "residual": residual + }) + + hidden_states, _ = self.norm(hidden_states, residual) + return hidden_states + + def can_run_ms(self): + attn_metadata = get_forward_context().attn_metadata + # support mla attention and V1 engine at present + if not self.use_mla or not envs.VLLM_USE_V1: + return False + # enable prefill overlap + if attn_metadata is None or attn_metadata.num_prefills == 0: + return False + else: + [token_index, seq_index + ] = compute_split_seq_index(attn_metadata.query_lens, + attn_metadata.attn_state, + attn_metadata.num_decode_tokens) + if token_index == 0 or seq_index == 0 or seq_index == len( + attn_metadata.query_lens): + return False + # check whether the total tokens exceed the threshold + if self.multistream_config is None or attn_metadata.num_actual_tokens < self.multistream_config.min_total_tokens_to_split: + return False + return True + + def _forward_ms_layers( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + residual: torch.Tensor, + moe_start_layer: int, + kv_caches: Optional[List[torch.Tensor]] = None, + is_prefill: bool = False, + ): + + if moe_start_layer == self.end_layer: + return hidden_states, residual + + attn_metadata, [positions, hidden_states, + residual] = self.ms_pre_layer( + [positions, hidden_states, residual], ) + # the rest layers + for i in range(moe_start_layer, self.end_layer): + layer = self.layers[i] + hidden_states, residual = layer._forward_ms_layer( + positions=positions, + hidden_states=hidden_states, + residual=residual, + attn_metadata=attn_metadata, + kv_cache=kv_caches[i - self.start_layer] + if kv_caches is not None else None, + is_prefill=is_prefill) + advance_step_multistream_layer_context() + + [hidden_states, + residual] = self.ms_post_layer([hidden_states, residual], ) + return hidden_states, residual + + +class CustomDeepseekDBOForCausalLM(DeepseekV2ForCausalLM): + # add `packed_modules_mapping` in `DeepseekV2ForCausalLM` to support weight merging + packed_modules_mapping = { + "gate_up_proj": ["gate_proj", "up_proj"], + 
"experts": + ["experts.0.gate_proj", "experts.0.up_proj", "experts.0.down_proj"] + } + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + nn.Module.__init__(self) + config = vllm_config.model_config.hf_config + quant_config = vllm_config.quant_config + self.config = config + self.quant_config = quant_config + self.model = CustomDeepseekDBOModel(vllm_config=vllm_config, + prefix=maybe_prefix( + prefix, "model")) + if get_pp_group().is_last_rank: + self.lm_head = ParallelLMHead(config.vocab_size, + config.hidden_size, + quant_config=quant_config) + else: + self.lm_head = PPMissingLayer() + self.logits_processor = LogitsProcessor(config.vocab_size) + self.sampler = get_sampler() + self.make_empty_intermediate_tensors = ( + self.model.make_empty_intermediate_tensors) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: Optional[List[torch.Tensor]] = None, + attn_metadata: Optional[AttentionMetadata] = None, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + hidden_states = self.model(input_ids, positions, kv_caches, + attn_metadata, intermediate_tensors, + inputs_embeds) + return hidden_states diff --git a/vllm_ascend/multistream/__init__.py b/vllm_ascend/multistream/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/vllm_ascend/multistream/base.py b/vllm_ascend/multistream/base.py new file mode 100644 index 0000000..fba58b4 --- /dev/null +++ b/vllm_ascend/multistream/base.py @@ -0,0 +1,29 @@ +from dataclasses import dataclass +from enum import Enum + + +class MSEventKey(Enum): + ATTN_COM_FINISH = 0 + ATTN_AR_FINISH = 1 + FFN_COM_FINISH = 2 + FFN_AR_FINISH = 3 + # events for MOE dispatch and combine + MOE_BEFORE_COMM = 4 + MOE_AFTER_COMM = 5 + # events for shared expert + MOE_SE_COMM_FINISH = 6 + MOE_SE_COMP_FINISH = 7 + MOE_GATE_FINISH = 8 + + +@dataclass +class MSAttentionMetadataSplitConfig: + """ + micro batch split config for split attention metadata + """ + # micro batch num + num_micro_batches: int = 2 + # split micro batches only when total tokens >= min_total_tokens_to_split + min_total_tokens_to_split: int = 256 + # split micro batches only when prefill tokens >= min_prefill_tokens_to_split + min_prefill_tokens_to_split: int = 64 diff --git a/vllm_ascend/multistream/context.py b/vllm_ascend/multistream/context.py new file mode 100644 index 0000000..a1684f2 --- /dev/null +++ b/vllm_ascend/multistream/context.py @@ -0,0 +1,67 @@ +from contextlib import contextmanager +from typing import Any + +_ms_comm_context: Any = None +_cur_micro_batch_num: int = -1 +_ms_layer_index_context: int = -1 +_ms_metadata_context: Any = None +_ms_attn_metadata_context: Any = None + + +def set_multistream_layer_context(start_layer: int, ms_metadata: Any, + attn_metadata: Any): + """ + set multistream layer context before transformer layers + """ + global _ms_layer_index_context, _ms_metadata_context, _ms_attn_metadata_context + _ms_layer_index_context = start_layer + _ms_metadata_context = ms_metadata + _ms_attn_metadata_context = attn_metadata + + +def reset_multistream_layer_context(): + """ + reset multistream layer context + """ + global _ms_layer_index_context, _ms_metadata_context, _ms_attn_metadata_context + _ms_layer_index_context = -1 + _ms_metadata_context = None + _ms_attn_metadata_context = None + + +def get_multistream_layer_context(): + """ + get multistream layer context + """ + return _ms_layer_index_context, 
_ms_metadata_context, _ms_attn_metadata_context + + +def advance_step_multistream_layer_context(): + """ + advance multistream layer index context + """ + global _ms_layer_index_context + _ms_layer_index_context += 1 + + +def get_multistream_comm_context() -> Any: + """Get the current comm forward context.""" + return _ms_comm_context + + +def get_multistream_microbatch_context() -> int: + return _cur_micro_batch_num + + +@contextmanager +def set_multistream_context(context: Any, micro_batch_num: int): + """A context manager that stores the current comm forward context, + can be attention metadata, etc.""" + global _ms_comm_context, _cur_micro_batch_num + _ms_comm_context = context + _cur_micro_batch_num = micro_batch_num + try: + yield + finally: + _ms_comm_context = None + _cur_micro_batch_num = -1 diff --git a/vllm_ascend/multistream/decorator.py b/vllm_ascend/multistream/decorator.py new file mode 100644 index 0000000..6c7f16a --- /dev/null +++ b/vllm_ascend/multistream/decorator.py @@ -0,0 +1,26 @@ +from vllm.logger import init_logger + +from .context import (get_multistream_layer_context, + get_multistream_microbatch_context) + +logger = init_logger(__name__) + + +# vllm v1 use get_forward_context to get the attn_metadata, +# we can use this decorator to update the attn metadata +def set_multistream_support(): + + def decorator(func): + + def wrapper(): + context = func() + layer_index, ms_metadata, attn_metadata = get_multistream_layer_context( + ) + micro_batch_num = get_multistream_microbatch_context() + if layer_index != -1 and micro_batch_num != -1: + context.attn_metadata = attn_metadata[micro_batch_num] + return context + + return wrapper + + return decorator diff --git a/vllm_ascend/multistream/layers.py b/vllm_ascend/multistream/layers.py new file mode 100644 index 0000000..c5273bc --- /dev/null +++ b/vllm_ascend/multistream/layers.py @@ -0,0 +1,61 @@ +from typing import List, Optional, Tuple, Union + +import torch +from vllm.forward_context import get_forward_context + +from .base import MSEventKey +from .context import (get_multistream_layer_context, + reset_multistream_layer_context, + set_multistream_layer_context) +from .metadata import MultiStreamMetadata + + +class MultiStreamPreTransformerLayer(torch.nn.Module): + + def __init__(self, multistream_metadata: MultiStreamMetadata): + super().__init__() + self.multistream_metadata = multistream_metadata + + def forward( + self, + intput_tensors: List[torch.Tensor], + ): + attn_metadata = get_forward_context().attn_metadata + if self.multistream_metadata is None or attn_metadata is None: + set_multistream_layer_context(-1, None, None) + return attn_metadata, intput_tensors + # TODO add attn_metadata management + do_ms, attn_metadata, intput_tensors, _ = self.multistream_metadata.split_micro_batch( + attn_metadata, intput_tensors) + if do_ms: + set_multistream_layer_context( + self.multistream_metadata.start_layer, + self.multistream_metadata, attn_metadata) + else: + set_multistream_layer_context(-1, None, None) + return attn_metadata, intput_tensors + + +class MultiStreamPostTransformerLayer(torch.nn.Module): + + def __init__(self, multistream_metadata: MultiStreamMetadata): + super().__init__() + self.multistream_metadata = multistream_metadata + + def forward(self, + input_tensors: Union[List[Tuple[torch.Tensor]], + List[torch.Tensor], + List[List[torch.Tensor]]], + wait_layer_index: Optional[int] = None): + if self.multistream_metadata is None or self.multistream_metadata.ms_config is None: + return input_tensors + 
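+        # If a multi-stream layer context is active, wait for the last
+        # micro-batch's final FFN all-reduce event before resetting the
+        # context and merging the micro-batch outputs back into one batch.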
layer_index, ms_metadata, ms_attn_metadata = get_multistream_layer_context( + ) + if layer_index >= 0: + true_wait_layer = self.multistream_metadata.end_layer - 1 if wait_layer_index is None else wait_layer_index + self.multistream_metadata.try_wait_event( + true_wait_layer, + self.multistream_metadata.ms_config.num_micro_batches - 1, + MSEventKey.FFN_AR_FINISH) + reset_multistream_layer_context() + return self.multistream_metadata.merge_micro_batches(input_tensors) diff --git a/vllm_ascend/multistream/metadata.py b/vllm_ascend/multistream/metadata.py new file mode 100644 index 0000000..b521d3f --- /dev/null +++ b/vllm_ascend/multistream/metadata.py @@ -0,0 +1,182 @@ +from dataclasses import dataclass +from typing import Dict, List, Optional, Tuple, Union + +import torch +from vllm.sequence import IntermediateTensors + +from vllm_ascend.attention.mla_v1 import AscendMLAMetadata + +from .base import MSAttentionMetadataSplitConfig, MSEventKey + + +def split_micro_batches_tensors(input_tensors, + split_index: int, + keys: Optional[List[str]] = None): + if isinstance(input_tensors, list): + micro_batches = [] + for tensor in input_tensors: + if tensor is None: + micro_batches.append([None, None]) + else: + micro_batches.append( + [tensor[:split_index], tensor[split_index:]]) + return micro_batches + elif isinstance(input_tensors, torch.Tensor): + return [input_tensors[:split_index], input_tensors[split_index:]] + elif input_tensors is None: + return [None, None] + elif isinstance(input_tensors, Dict): + assert keys is not None + micro_batches_pre = {} + for key in keys: + micro_batches_pre[key] = input_tensors[key][:split_index] + micro_batches_post = {} + for key in keys: + micro_batches_post[key] = input_tensors[key][split_index:] + return [micro_batches_pre, micro_batches_post] + else: + raise NotImplementedError + + +@dataclass +class MultiStreamStepMetadata: + comm_stream: torch.npu.Stream = None + before_comm_event: torch.npu.Event = None + after_comm_event: torch.npu.Event = None + + +@dataclass +class MultiStreamConfig: + """Controls the behavior of multi-stream models.""" + min_total_tokens_to_split: int = 256 + min_prefill_tokens_to_split: int = 64 + num_micro_batches: int = 2 + imbalance_ratio: float = 0.1 + + +class MultiStreamMetadata: + # direct stream + calculate_stream = None + # delay stream + communicate_stream = None + # events + ms_events: Dict[int, Dict[int, Dict[MSEventKey, torch.npu.Event]]] = {} + # multi-stream-flag + enable_multi_stream: bool = False + + def __init__( + self, + calculate_stream: torch.npu.Stream, + communicate_stream: torch.npu.Stream, + start_layer: int, + end_layer: int, + event_keys: List[MSEventKey], + multistream_config: Optional[MultiStreamConfig], + causal_lm: bool = True, + ): + self.calculate_stream = calculate_stream + self.communicate_stream = communicate_stream + self.start_layer = start_layer + self.end_layer = end_layer + self.ms_config = multistream_config + self.causal_lm = causal_lm + self._build_events(event_keys) + self._build_ms_split_config() + + def _build_events(self, event_keys): + if self.ms_config is not None: + for i in range(self.start_layer - 1, self.end_layer): + self.ms_events[i] = {} + for j in range(self.ms_config.num_micro_batches): + self.ms_events[i][j] = {} + for key in event_keys: + self.ms_events[i][j][key] = torch.npu.Event() + + def _build_ms_split_config(self): + if self.ms_config is not None: + self.ms_split_config = MSAttentionMetadataSplitConfig( + num_micro_batches=self.ms_config.num_micro_batches, + 
+                min_total_tokens_to_split=self.ms_config.min_total_tokens_to_split,
+                min_prefill_tokens_to_split=self.ms_config.min_prefill_tokens_to_split,
+            )
+
+    def try_wait_event(self, layer_index: int, micro_batch_index: int,
+                       event_key: MSEventKey):
+        self.ms_events[layer_index][micro_batch_index][event_key].wait()
+
+    def try_record_event(self, layer_index: int, micro_batch_index: int,
+                         event_key: MSEventKey):
+        self.ms_events[layer_index][micro_batch_index][event_key].record()
+
+    def split_micro_batch(
+        self,
+        attn_metadata: "AscendMLAMetadata",
+        input_tensors: List[torch.Tensor],
+        intermediate_tensors: Optional[IntermediateTensors] = None,
+        intermediate_tensors_keys: Optional[List[str]] = None,
+    ) -> Tuple[bool, Union[AscendMLAMetadata, List[AscendMLAMetadata]], Union[
+            List[torch.Tensor], List[List[torch.Tensor]]], Union[
+                IntermediateTensors, List[IntermediateTensors]]]:
+        attn_metadata_list = attn_metadata.split_metadata_for_multistream(
+            self.ms_split_config)
+        if len(attn_metadata_list) == 1:
+            return False, attn_metadata_list[
+                0], input_tensors, intermediate_tensors
+        split_index = attn_metadata_list[0].slot_mapping.shape[0]
+        input_tensors = split_micro_batches_tensors(input_tensors,
+                                                    split_index)
+        if intermediate_tensors is not None:
+            inter_tensors_list = split_micro_batches_tensors(
+                intermediate_tensors.tensors, split_index,
+                intermediate_tensors_keys)
+            intermediate_tensors = [
+                IntermediateTensors(inter_tensors)
+                for inter_tensors in inter_tensors_list
+            ]
+        return True, attn_metadata_list, input_tensors, intermediate_tensors
+
+    def merge_micro_batches(
+        self, input_tensors: Union[List[torch.Tensor],
+                                   List[List[torch.Tensor]]]
+    ) -> List[torch.Tensor]:
+        if input_tensors is None or isinstance(input_tensors[0], torch.Tensor):
+            return input_tensors
+        batch: List[Optional[torch.Tensor]] = []
+        for tensors in input_tensors:
+            if tensors is None or tensors[0] is None:
+                batch.append(None)
+            else:
+                batch.append(torch.cat(tensors, dim=0))
+        return batch
+
+
+def make_multistream_metadata_ds(
+    start_layer: int,
+    end_layer: int,
+    causal_lm: bool = True,
+    multistream_config: Optional[MultiStreamConfig] = None,
+):
+    if multistream_config is None:
+        return None
+    event_keylist = [
+        MSEventKey.ATTN_COM_FINISH,
+        MSEventKey.ATTN_AR_FINISH,
+        MSEventKey.FFN_COM_FINISH,
+        MSEventKey.FFN_AR_FINISH,
+        MSEventKey.MOE_BEFORE_COMM,
+        MSEventKey.MOE_AFTER_COMM,
+        MSEventKey.MOE_SE_COMM_FINISH,
+        MSEventKey.MOE_SE_COMP_FINISH,
+        MSEventKey.MOE_GATE_FINISH,
+    ]
+    return MultiStreamMetadata(
+        calculate_stream=torch.npu.current_stream(),
+        communicate_stream=torch.npu.Stream(),
+        start_layer=start_layer,
+        end_layer=end_layer,
+        multistream_config=multistream_config,
+        event_keys=event_keylist,
+        causal_lm=causal_lm,
+    )
diff --git a/vllm_ascend/multistream/ms_split.py b/vllm_ascend/multistream/ms_split.py
new file mode 100644
index 0000000..430f57b
--- /dev/null
+++ b/vllm_ascend/multistream/ms_split.py
@@ -0,0 +1,245 @@
+from copy import deepcopy
+from typing import Any, List, Optional
+
+import numpy as np
+import torch
+
+from vllm_ascend.attention.attention_v1 import AscendAttentionState
+
+from .base import MSAttentionMetadataSplitConfig
+
+
+def compute_split_seq_index(
+    query_lens: Optional[list[int]],
+    attn_state: AscendAttentionState,
+    num_tokens: int,
+    imbalance_ratio: float = 0.1,
+) -> list[int]:
+    if attn_state != AscendAttentionState.DecodeOnly:
+        assert query_lens is not None
+        total_tokens = sum(query_lens)
+        # returns [token split point, index of the first sequence in the
+        # second micro-batch]
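+        # Greedily accumulate whole sequences until about half of the total
+        # tokens are covered; accept the current (or previous) boundary only
+        # if it lies within imbalance_ratio * total_tokens of an even split.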
+        tokens, split_index = 0, 0
+        for value in query_lens:
+            tokens += value
+            split_index += 1
+            if tokens >= total_tokens // 2:
+                # check the current split index
+                if abs(tokens -
+                       total_tokens // 2) < total_tokens * imbalance_ratio:
+                    return [tokens, split_index]
+                # check the previous split index
+                elif abs(tokens - total_tokens // 2 -
+                         value) < total_tokens * imbalance_ratio:
+                    return [tokens - value, split_index - 1]
+                # fail to split if the result would be too imbalanced
+                # TODO: split tokens within a sequence
+                else:
+                    return [0, 0]
+    else:
+        tokens = num_tokens // 2
+        return [tokens, tokens]
+    return [0, 0]
+
+
+def split_attn_tensor_type(
+    input_tensor: torch.Tensor,
+    index: int,
+) -> List[torch.Tensor]:
+    return [input_tensor[:index], input_tensor[index:]]
+
+
+def split_attn_int_type(
+    var: int,
+    index: int,
+) -> List[int]:
+    return [min(var, index), max(var - index, 0)]
+
+
+def model_input_split_v1_mla_attn(
+    attn_metadata,
+    _metadata_cls,
+    ms_split_config: MSAttentionMetadataSplitConfig,
+) -> List[Any]:
+    assert 0 < ms_split_config.num_micro_batches < 3
+    if attn_metadata is None:
+        return [attn_metadata]
+    [token_index,
+     seq_index] = compute_split_seq_index(attn_metadata.query_lens,
+                                          attn_metadata.attn_state,
+                                          attn_metadata.num_decode_tokens)
+    if token_index == 0 or seq_index == 0 or seq_index == len(
+            attn_metadata.query_lens):
+        return [attn_metadata]
+
+    query_start_loc_cpu = np.zeros(shape=(len(attn_metadata.query_lens) + 1, ),
+                                   dtype=int)
+    np.cumsum(attn_metadata.query_lens, out=query_start_loc_cpu[1:])
+    if attn_metadata.num_prefills > 0:
+        prefill_query_start_loc = np.zeros(
+            shape=(len(attn_metadata.prefill.query_lens) + 1, ), dtype=int)
+        np.cumsum(attn_metadata.prefill.query_lens,
+                  out=prefill_query_start_loc[1:])
+
+    # split attn metadata
+    [slot_mapping_pre,
+     slot_mapping_post] = split_attn_tensor_type(attn_metadata.slot_mapping,
+                                                 token_index)
+    [num_decodes_pre,
+     num_decodes_post] = split_attn_int_type(attn_metadata.num_decodes,
+                                             seq_index)
+    [num_decode_tokens_pre, num_decode_tokens_post
+     ] = split_attn_int_type(attn_metadata.num_decode_tokens, token_index)
+    [num_prefills_pre, num_prefills_post
+     ] = split_attn_int_type(attn_metadata.num_prefills,
+                             max(0, seq_index - attn_metadata.num_decodes))
+    seq_lens = attn_metadata.prefill.seq_lens if attn_metadata.num_prefills > 0 else attn_metadata.decode.seq_lens
+    [seq_lens_pre, seq_lens_post] = split_attn_tensor_type(seq_lens, seq_index)
+
+    query_start_loc_pre = attn_metadata.query_start_loc[:seq_index + 1]
+    query_start_loc_post = deepcopy(
+        attn_metadata.query_start_loc[seq_index:]
+    ) - attn_metadata.query_start_loc[seq_index]
+    [block_table_pre,
+     block_table_post] = split_attn_tensor_type(attn_metadata.block_tables,
+                                                seq_index)
+
+    if attn_metadata.attn_state in (AscendAttentionState.PrefillNoCache,
+                                    AscendAttentionState.PrefillCacheHit):
+        # the attn_mla kernel in torch npu only accepts a 128*128 attn mask
+        attn_mask_pre = attn_mask_post = attn_metadata.attn_mask
+        attn_state_pre = attn_state_post = attn_metadata.attn_state
+    elif attn_metadata.attn_state == AscendAttentionState.DecodeOnly:
+        # attn_mask should be None in the decode-only state
+        attn_mask_pre = attn_mask_post = attn_metadata.attn_mask
+        attn_state_pre = attn_state_post = AscendAttentionState.DecodeOnly
+    else:
+        # chunked prefill
+        if num_prefills_pre > 0:
+            # both micro-batches still contain prefill requests
+            attn_state_pre = attn_state_post = AscendAttentionState.ChunkedPrefill
+            attn_mask_pre = attn_metadata.attn_mask[:token_index, :max(
+                seq_lens_pre)].contiguous()
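+            # slice the shared mask by token rows and by the longest
+            # remaining sequence length of each micro-batch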
+            attn_mask_post = attn_metadata.attn_mask[
+                token_index:, :max(seq_lens_post)].contiguous()
+        else:
+            attn_state_pre = AscendAttentionState.DecodeOnly
+            attn_mask_pre = None
+            attn_state_post = AscendAttentionState.ChunkedPrefill
+            attn_mask_post = attn_metadata.attn_mask[
+                token_index:, :max(seq_lens_post)].contiguous()
+    from vllm_ascend.attention.mla_v1 import (AscendMLADecodeMetadata,
+                                              AscendMLAPrefillMetadata)
+    if num_prefills_pre > 0:
+        # split metadata.prefill
+        [input_positions_pre, input_positions_post] = split_attn_tensor_type(
+            attn_metadata.prefill.input_positions,
+            token_index - attn_metadata.num_decode_tokens)
+        [block_tables_pre, block_tables_post
+         ] = split_attn_tensor_type(attn_metadata.prefill.block_table,
+                                    seq_index - attn_metadata.num_decodes)
+        [prefill_query_lens_pre, prefill_query_lens_post
+         ] = split_attn_tensor_type(attn_metadata.prefill.query_lens,
+                                    seq_index - attn_metadata.num_decodes)
+        prefill_query_start_loc_pre = attn_metadata.prefill.query_start_loc[
+            :seq_index + 1 - attn_metadata.num_decodes]
+        prefill_query_start_loc_post = deepcopy(
+            attn_metadata.prefill.query_start_loc[seq_index -
+                                                  attn_metadata.num_decodes:]
+        ) - attn_metadata.prefill.query_start_loc[seq_index -
+                                                  attn_metadata.num_decodes]
+        context_len_pre = seq_lens_pre[attn_metadata.num_decodes:]
+        context_len_post = seq_lens_post
+        prefill_max_query_len_pre = max(prefill_query_lens_pre)
+        prefill_max_query_len_post = max(prefill_query_lens_post)
+        prefill_pre = AscendMLAPrefillMetadata(
+            attn_mask=attn_mask_pre,
+            query_lens=prefill_query_lens_pre,
+            seq_lens=seq_lens_pre,
+            query_start_loc=prefill_query_start_loc_pre,
+            input_positions=input_positions_pre,
+            context_lens=context_len_pre,
+            block_table=block_tables_pre,
+            max_query_len=prefill_max_query_len_pre,
+            max_seq_lens=context_len_pre.max().item(),
+        )
+        prefill_post = AscendMLAPrefillMetadata(
+            attn_mask=attn_mask_post,
+            query_lens=prefill_query_lens_post,
+            seq_lens=seq_lens_post,
+            query_start_loc=prefill_query_start_loc_post,
+            input_positions=input_positions_post,
+            context_lens=context_len_post,
+            block_table=block_tables_post,
+            max_query_len=prefill_max_query_len_post,
+            max_seq_lens=context_len_post.max().item(),
+        )
+        decode_pre = attn_metadata.decode
+        decode_post = None
+    else:
+        # no prefill in the pre micro-batch: split metadata.decode and leave
+        # all prefills to the post micro-batch
+        [input_positions_pre, input_positions_post
+         ] = split_attn_tensor_type(attn_metadata.decode.input_positions,
+                                    token_index)
+        [block_tables_pre, block_tables_post
+         ] = split_attn_tensor_type(attn_metadata.decode.block_table,
+                                    seq_index)
+        [decode_seq_lens_pre,
+         decode_seq_lens_post] = split_attn_tensor_type(seq_lens, seq_index)
+        decode_pre = AscendMLADecodeMetadata(
+            input_positions=input_positions_pre,
+            block_table=block_tables_pre,
+            seq_lens=decode_seq_lens_pre,
+            max_seq_lens=max(decode_seq_lens_pre),
+            seq_lens_list=decode_seq_lens_pre.tolist(),
+        )
+        decode_post = AscendMLADecodeMetadata(
+            input_positions=input_positions_post,
+            block_table=block_tables_post,
+            seq_lens=decode_seq_lens_post,
+            max_seq_lens=max(decode_seq_lens_post),
+            seq_lens_list=decode_seq_lens_post.tolist(),
+        )
+        prefill_pre = None
+        prefill_post = attn_metadata.prefill
+    # construct the two micro-batch metadata objects
+    attention_metadata_pre = _metadata_cls(
+        num_actual_tokens=token_index,
+        num_input_tokens=token_index,
+        head_dim=attn_metadata.head_dim,
+        slot_mapping=slot_mapping_pre,
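+        # per-micro-batch views of the original batch-level tensors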
+        seq_lens=seq_lens_pre,
+        query_start_loc=query_start_loc_pre,
+        block_tables=block_table_pre,
+        num_decodes=num_decodes_pre,
+        num_prefills=num_prefills_pre,
+        num_decode_tokens=num_decode_tokens_pre,
+        attn_state=attn_state_pre,
+        attn_mask=attn_mask_pre,
+        prefill=prefill_pre,
+        decode=decode_pre,
+        with_prefill_across_dp=attn_metadata.with_prefill_across_dp,
+    )
+    attention_metadata_post = _metadata_cls(
+        num_actual_tokens=attn_metadata.num_actual_tokens - token_index,
+        num_input_tokens=attn_metadata.num_input_tokens - token_index,
+        head_dim=attn_metadata.head_dim,
+        slot_mapping=slot_mapping_post,
+        seq_lens=seq_lens_post,
+        query_start_loc=query_start_loc_post,
+        block_tables=block_table_post,
+        num_decodes=num_decodes_post,
+        num_prefills=num_prefills_post,
+        num_decode_tokens=num_decode_tokens_post,
+        attn_mask=attn_mask_post,
+        attn_state=attn_state_post,
+        prefill=prefill_post,
+        decode=decode_post,
+        with_prefill_across_dp=attn_metadata.with_prefill_across_dp,
+    )
+    return [attention_metadata_pre, attention_metadata_post]
diff --git a/vllm_ascend/ops/fused_moe.py b/vllm_ascend/ops/fused_moe.py
index fd06d90..56df04e 100644
--- a/vllm_ascend/ops/fused_moe.py
+++ b/vllm_ascend/ops/fused_moe.py
@@ -1151,3 +1151,32 @@ class AscendFusedMoE(FusedMoE):
         if self.enable_multistream_shared_expert and not is_prefill:
             return hidden_states, shared_output
         return hidden_states
+
+    # ------------------------------- DBO (dual-batch overlap) -------------------------------
+
+    def _forward_ms_fused_moe_comp(
+        self,
+        hidden_states: torch.Tensor,
+        router_logits: torch.Tensor,
+        is_prefill: bool,
+        real_top_k: int,
+        enable_force_load_balance: bool = False,
+    ):
+        # computation half of the micro-batched MoE forward; dispatch/combine
+        # communication is scheduled separately on the communication stream
+        hidden_states = self.quant_method.apply(
+            layer=self,
+            x=hidden_states,
+            router_logits=router_logits,
+            top_k=real_top_k,
+            renormalize=self.renormalize,
+            use_grouped_topk=self.use_grouped_topk,
+            global_num_experts=self.global_num_experts,
+            expert_map=self.expert_map,
+            topk_group=self.topk_group,
+            num_expert_group=self.num_expert_group,
+            custom_routing_function=self.custom_routing_function,
+            scoring_func=self.scoring_func,
+            e_score_correction_bias=self.e_score_correction_bias,
+            is_prefill=is_prefill,
+            enable_force_load_balance=enable_force_load_balance)
+
+        return hidden_states
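
For reviewers, below is a minimal wiring sketch (not part of the patch) of how the new multistream building blocks are expected to compose inside a model's forward pass, roughly following the new `deepseek_dbo.py`. `num_layers`, `hidden_states` and `residual` are hypothetical stand-ins for a real model's values; the sketch assumes an Ascend environment with `torch_npu` available and a vLLM forward context active:

```python
# Hypothetical usage sketch; the names flagged below are assumptions.
from vllm_ascend.multistream.layers import (MultiStreamPostTransformerLayer,
                                            MultiStreamPreTransformerLayer)
from vllm_ascend.multistream.metadata import (MultiStreamConfig,
                                              make_multistream_metadata_ds)

num_layers = 32  # assumption: the model's decoder layer count

# Built once at model init: a computation stream, a communication stream and
# per-layer/per-micro-batch events (see make_multistream_metadata_ds above).
ms_metadata = make_multistream_metadata_ds(
    start_layer=0,
    end_layer=num_layers,
    multistream_config=MultiStreamConfig(num_micro_batches=2),
)
pre_layer = MultiStreamPreTransformerLayer(ms_metadata)
post_layer = MultiStreamPostTransformerLayer(ms_metadata)

# Inside forward(); hidden_states/residual are the model's usual tensors.
# Split into two micro-batches when the batch is large and balanced enough,
# otherwise fall through with the original tensors.
attn_metadata, [hidden_states, residual] = pre_layer([hidden_states, residual])
# ... decoder layers run here, alternating micro-batches so that one batch's
# computation overlaps the other's communication ...
hidden_states, residual = post_layer([hidden_states, residual])
```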