# xc-llm-ascend/vllm_ascend/attention/mla_v1.py
from dataclasses import dataclass
from typing import TYPE_CHECKING, NamedTuple, TypeVar
import numpy as np
import torch
import torch_npu
import vllm.envs as envs_vllm
from vllm.config import VllmConfig, get_current_vllm_config
from vllm.forward_context import ForwardContext, get_forward_context
from vllm.logger import logger
from vllm.model_executor.layers.attention.mla_attention import MLACommonMetadataBuilder
from vllm.model_executor.layers.linear import UnquantizedLinearMethod
from vllm.utils.math_utils import cdiv, round_down
from vllm.v1.attention.backend import AttentionBackend, AttentionCGSupport, MLAAttentionImpl # type: ignore
from vllm.v1.attention.backends.utils import PAD_SLOT_ID # type: ignore
from vllm.v1.kv_cache_interface import AttentionSpec, MLAAttentionSpec
from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.attention.attention_mask import AttentionMaskBuilder
from vllm_ascend.attention.attention_v1 import AscendAttentionState
from vllm_ascend.attention.context_parallel.common_cp import AscendPCPMetadata, CPChunkedContextMetadata
from vllm_ascend.attention.utils import (
AscendCommonAttentionMetadata,
ascend_chunked_prefill_workspace_size,
enable_cp,
enabling_mlapo,
maybe_save_kv_layer_to_connector,
split_decodes_and_prefills,
trans_rope_weight,
transdata,
wait_for_kv_layer_from_connector,
)
from vllm_ascend.compilation.acl_graph import (
get_draft_graph_params,
get_graph_params,
update_draft_graph_params_workspaces,
update_graph_params_workspaces,
)
from vllm_ascend.ops.layer_shard_linear import (
is_hidden_layer,
post_process_after_loading_for_shard_weight_series,
reach_layer_for_shard_weight_series,
register_all_layers_to_shard_weight_series,
)
from vllm_ascend.ops.rotary_embedding import get_cos_and_sin_mla
from vllm_ascend.quantization.methods import AscendW8A8LinearMethod
from vllm_ascend.utils import (
ACL_FORMAT_FRACTAL_ND,
get_weight_prefetch_method,
maybe_trans_nz,
weak_ref_tensors,
)
from vllm_ascend.worker.npu_input_batch import NPUInputBatch
if TYPE_CHECKING:
from vllm.v1.core.sched.output import SchedulerOutput
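# Assumed from the constant names and their usage below: MAX_O_PROJ_PREFETCH_SIZE
# caps the o_proj weight prefetch (16 MiB), while the BUILD_METADATA_STEP_* values
# mark which phase get_block_table_size() is building metadata for.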
MAX_O_PROJ_PREFETCH_SIZE = 16 * 1024 * 1024
BUILD_METADATA_STEP_PREFILL = 0
BUILD_METADATA_STEP_DECODE = 1
# Maximum number of tokens supported by the mlapo operator
MLAPO_MAX_SUPPORTED_TOKENS = 1024
class AscendMLABackend(AttentionBackend):
accept_output_buffer: bool = True
@staticmethod
def get_name() -> str:
        # HACK(Ronald1995): vllm's `initialize_kv_cache` in model runner v2
        # asserts on the attention backend name, so we return FLASH_ATTN to
        # avoid the assertion error. Rectify this once vllm drops the assertion.
return "ASCEND_MLA" if not envs_vllm.VLLM_USE_V2_MODEL_RUNNER else "FLASH_ATTN"
@staticmethod
def get_builder_cls():
if enable_cp():
from vllm_ascend.attention.context_parallel.mla_cp import AscendMlaCPMetadataBuilder
return AscendMlaCPMetadataBuilder
return AscendMLAMetadataBuilder
@staticmethod
def get_kv_cache_shape(num_blocks: int, block_size: int, num_kv_heads: int, head_size: int) -> tuple[int, ...]:
return num_blocks, block_size, num_kv_heads, head_size
@staticmethod
def get_impl_cls() -> type["MLAAttentionImpl"]:
if enable_cp():
from vllm_ascend.attention.context_parallel.mla_cp import AscendMlaCPImpl
return AscendMlaCPImpl
return AscendMLAImpl
@staticmethod
def get_supported_kernel_block_sizes() -> list[int]:
return [128]
@dataclass
class ChunkedContextMetadata:
"""
Metadata for chunked context handling in MLA attention.
Manages sequence boundaries and workspace for chunked prefill processing.
"""
cu_seq_lens: torch.Tensor
starts: torch.Tensor
seq_tot: list[int]
max_seq_lens: list[int]
workspace: torch.Tensor
chunk_seq_lens: torch.Tensor
chunk_seq_lens_npu: torch.Tensor
@dataclass
class AscendMLAPrefillMetadata:
"""Prefill Specific Metadata for Ascend"""
attn_mask: torch.Tensor
query_lens: torch.Tensor
seq_lens: list[int]
context_lens: torch.Tensor
input_positions: torch.Tensor
query_start_loc: torch.Tensor
block_table: torch.Tensor
max_query_len: int
max_seq_lens: int
chunked_context: ChunkedContextMetadata | CPChunkedContextMetadata | None = None
    sin: torch.Tensor | None = None
    cos: torch.Tensor | None = None
pcp_metadata: AscendPCPMetadata | None = None
@dataclass
class AscendMLADecodeMetadata:
"""Decode-specific metadata for Ascend MLA attention."""
# Input positions for rotary embeddings since for MLA the rotary
# position embeddings are applied inside the attention backend
input_positions: torch.Tensor
block_table: torch.Tensor
seq_lens: torch.Tensor
max_seq_lens: int
seq_lens_list: list[int]
actual_seq_lengths_q: list[int] | None = None
attn_mask: torch.Tensor | None = None
    sin: torch.Tensor | None = None
    cos: torch.Tensor | None = None
    cp_seq_len: torch.Tensor | None = None
@dataclass
class AscendMLAMetadata:
"""Metadata for MLACommon.
NOTE: Please read the comment at the top of the file before trying to
understand this class
"""
    # NOTE(sang): Definition of context_len, query_len, and seq_len.
    #   |---------- N-1 iteration --------|
    #   |---------------- N iteration ---------------------|
    #   |- tokenA -|......................|-- newTokens ---|
    #   |---------- context_len ----------|
    #   |-------------------- seq_len ---------------------|
    #                                     |-- query_len ---|
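    # For instance (illustrative values): a request with 13 tokens already in
    # the KV cache that is scheduled with 2 new tokens this step has
    # context_len=13, query_len=2, and seq_len=15.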
num_actual_tokens_pcp_padded: int
num_actual_tokens: int # Number of tokens excluding padding.
slot_mapping: torch.Tensor
query_start_loc: torch.Tensor
seq_lens: torch.Tensor
block_tables: torch.Tensor
# New for MLA (compared to FlashAttention)
# For handling prefill decode split
num_decodes: int
num_decode_tokens: int
num_prefills: int
# For logging.
num_input_tokens: int = 0 # Number of tokens including padding.
query_lens: list[int] | None = None
# The dimension of the attention heads
head_dim: int | None = None
    attn_mask: torch.Tensor | None = None
    # chunked prefill by default if no attn_state is passed
attn_state: AscendAttentionState = AscendAttentionState.ChunkedPrefill
decode: AscendMLADecodeMetadata | None = None
prefill: AscendMLAPrefillMetadata | None = None
    reshape_cache_event: torch.npu.Event | None = None
def __post_init__(self):
pass
# supported_head_sizes = AscendMLABackend.get_supported_head_sizes()
# if self.head_dim is not None and self.head_dim \
# not in supported_head_sizes:
# raise ValueError(
# f"Only {supported_head_sizes} are supported for head_dim,",
# f"received {self.head_dim}.")
M = TypeVar("M", bound=AscendMLAMetadata)
class AscendMLAMetadataBuilder(MLACommonMetadataBuilder[AscendMLAMetadata]):
"""
NOTE: Please read the comment at the top of the file before trying to
understand this class
"""
def __init__(
self,
kv_cache_spec: MLAAttentionSpec,
layer_names: list[str],
vllm_config: VllmConfig,
device: torch.device,
metadata_cls: type[AscendMLAMetadata] | None = None,
supports_dcp_with_varlen: bool = False,
):
super().__init__(
kv_cache_spec,
layer_names,
vllm_config,
device,
metadata_cls if metadata_cls is not None else AscendMLAMetadata,
supports_dcp_with_varlen,
)
scheduler_config = vllm_config.scheduler_config
self.block_size = vllm_config.cache_config.block_size
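        # ceil(max_model_len / block_size): the most KV-cache blocks a single
        # request can occupy.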
self.max_blocks = (vllm_config.model_config.max_model_len + self.block_size - 1) // self.block_size
self.chunked_prefill_enabled = scheduler_config.enable_chunked_prefill
self.speculative_config = vllm_config.speculative_config
self.decode_threshold = 1
if self.speculative_config:
spec_token_num = self.speculative_config.num_speculative_tokens
self.decode_threshold += spec_token_num
            assert self.decode_threshold <= 16, (
                f"decode_threshold exceeds npu_fused_infer_attention_score TND "
                f"layout's limit of 16, got {self.decode_threshold}"
            )
self.reorder_batch_threshold = self.decode_threshold
self.rope_dim = self.model_config.hf_text_config.qk_rope_head_dim
self.cos_cache = None
self.sin_cache = None
self.chunk_seq_lens: torch.Tensor = None
self.cu_seq_lens_cpu: torch.Tensor = None
self.num_chunks: torch.Tensor = None
self.max_context_chunk = 0
self.num_decodes = 0
self.num_prefills = 0
self.num_decode_tokens = 0
self.num_prefill_tokens = 0
self.context_lens_cpu: torch.Tensor = None
self.num_actual_tokens: int | None = None
self.block_table: torch.Tensor = None
self.slot_mapping: torch.Tensor = None
self.graph_pad_size = 0
self.query_lens: torch.Tensor = None
self.seq_lens: torch.Tensor = None
self.attn_mask_builder = AttentionMaskBuilder(self.device)
@staticmethod
def determine_chunked_prefill_workspace_size(vllm_config: VllmConfig) -> int:
return ascend_chunked_prefill_workspace_size(vllm_config)
@classmethod
def get_cudagraph_support(
cls: type["AscendMLAMetadataBuilder"],
vllm_config: VllmConfig,
kv_cache_spec: AttentionSpec,
) -> AttentionCGSupport:
# Explicit override in case the underlying builder specialized this getter.
# @override omitted only because of mypy limitation due to type variable.
return AttentionCGSupport.UNIFORM_BATCH
def reorder_batch(self, input_batch: "NPUInputBatch", scheduler_output: "SchedulerOutput") -> bool:
        # We now want to reorder the batch so that the "decode" requests are at
        # the front and the "prefill" requests are at the back, using the least
        # amount of swaps possible. (NOTE: for now we loosely use "decode" to
        # mean requests where attention is likely memory-bound and "prefill" to
        # mean requests where attention is likely compute-bound,
        # TODO(lucas): figure out a better naming here)
decodes = []
prefills = []
for i, req_id in enumerate(input_batch.req_ids):
num_tokens = scheduler_output.num_scheduled_tokens[req_id]
if num_tokens <= self.decode_threshold:
decodes.append(i)
else:
prefills.append(i)
        # We hope that this is fairly minimal since decodes
        # should be around for a number of iterations so hopefully they are
        # relatively stationary (and new requests are generally appended to the
        # persistent batch, so they should already be at the back).
        # To achieve this we loop over the decodes in descending order and
        # the prefills in ascending order. We swap decodes from the "back",
        # i.e. past where the last decode should be in the reordered batch,
        # with prefills from the front of the batch.
        # `decodes` and `prefills` are already in ascending order just based on
        # the above loop.
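        # Illustrative example (hypothetical values): with scheduled token
        # counts [8, 1, 1, 9, 1] and decode_threshold=1, decodes=[1, 2, 4] and
        # prefills=[0, 3]; decode index 4 (>= num_decodes=3) is swapped with
        # prefill index 0, yielding [decode, decode, decode, prefill, prefill]
        # after a single swap.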
num_decodes = len(decodes)
num_prefills = len(prefills)
first_prefill = 0
modified_batch = False
for i in range(1, min(num_decodes, num_prefills) + 1):
# If the decode is at the "back" of the batch, i, we can swap it
# with the prefill closest to the front of the batch
if decodes[num_decodes - i] >= num_decodes:
input_batch.swap_states(prefills[first_prefill], decodes[num_decodes - i])
first_prefill += 1
modified_batch = True
else:
break
# Save for next `build` call
# TODO(lucas): this is a bit of a hack, we should probably have a
# better way of doing this
return modified_batch
def pad_actual_seq_len_q_mtp_enable_pad(
self, num_reqs_pad_size, num_reqs, actual_seq_lengths_q, common_attn_metadata
):
"""
Pads actual_seq_lengths_q evenly to not exceed 16 tokens per request
[Feat] Support MTP to running in full graph mode (#3892) ### What this PR does / why we need it? Currently, the MTP model still runs in eager in full graph mode. This PR adapts the MTP with the full graph capture and execution. When the graph mode is set to "FULL_DECODE_ONLY", the MTP will run in full-graph to improve the performance. The change in both disable_padded_drafter_batch is True and False case include: 1. Add _mtp_graph_params in acl_graph.py to isolate the data of main model and the data of MTP. 2. Padding some metadata in mla_v1.py when in fullgraph mode. 3. Fixed the essential data address that will be used in model.forward. 4. Adapted according to the aclgraph capture framwork: 1). Rebuild MTP model with ACLGraphWrapper. 2). Add common attn metadata when start capture in MTP dummy_run. 3). Add common attn metadata update in MTP. 4). Addapted data update when num_speculative_tokens > 1. 5. Add a patch of MTP to adapt vllm v0.11.0. Existing Issues: 1. When disable_padded_drafter_batch=True and running in FullGraph mode, the data of the first-round requests in MTP is abnormal. We need to identify the cause subsequently. 2. When disable_padded_drafter_batch=False and running in FullGraph mode, the acceptance rate of the second and third tokens will decrease (For example, if we set the num_speculative_tokens=3, the acceptance rate of first token is 90%, the second is only 50% lower than 60%, the third is only 20% lower than 30%). The reason is that the data processed after the model runs does not match. This is a problem from another PR. It works fine in eager and PIECEWISE mode, but has problem in FullGraph mode. Once we have a solution, we will submit a bugfix. ### Does this PR introduce _any_ user-facing change? ### How was this patch tested? - vLLM version: v0.11.0 - vLLM main: https://github.com/vllm-project/vllm/commit/2918c1b49c88c29783c86f78d2c4221cb9622379 --------- Signed-off-by: anon189Ty <Stari_Falcon@outlook.com>
2025-11-20 20:34:54 +08:00
in order to meet the requirement of npu_fused_infer_attention_score.
In Torchair scenario, the lengths of the queries must be padded to the same length.
And npu_fused_infer_attention_score constraint requires the last element must equal to batch_size(num_tokens).
For example:
batch_size=36, num_reqs_pad_size=2, num_reqs=16
By default, each request should have inference 2 token, which means actual_seq_lengths_q should be
[Feat] Support MTP to running in full graph mode (#3892) ### What this PR does / why we need it? Currently, the MTP model still runs in eager in full graph mode. This PR adapts the MTP with the full graph capture and execution. When the graph mode is set to "FULL_DECODE_ONLY", the MTP will run in full-graph to improve the performance. The change in both disable_padded_drafter_batch is True and False case include: 1. Add _mtp_graph_params in acl_graph.py to isolate the data of main model and the data of MTP. 2. Padding some metadata in mla_v1.py when in fullgraph mode. 3. Fixed the essential data address that will be used in model.forward. 4. Adapted according to the aclgraph capture framwork: 1). Rebuild MTP model with ACLGraphWrapper. 2). Add common attn metadata when start capture in MTP dummy_run. 3). Add common attn metadata update in MTP. 4). Addapted data update when num_speculative_tokens > 1. 5. Add a patch of MTP to adapt vllm v0.11.0. Existing Issues: 1. When disable_padded_drafter_batch=True and running in FullGraph mode, the data of the first-round requests in MTP is abnormal. We need to identify the cause subsequently. 2. When disable_padded_drafter_batch=False and running in FullGraph mode, the acceptance rate of the second and third tokens will decrease (For example, if we set the num_speculative_tokens=3, the acceptance rate of first token is 90%, the second is only 50% lower than 60%, the third is only 20% lower than 30%). The reason is that the data processed after the model runs does not match. This is a problem from another PR. It works fine in eager and PIECEWISE mode, but has problem in FullGraph mode. Once we have a solution, we will submit a bugfix. ### Does this PR introduce _any_ user-facing change? ### How was this patch tested? - vLLM version: v0.11.0 - vLLM main: https://github.com/vllm-project/vllm/commit/2918c1b49c88c29783c86f78d2c4221cb9622379 --------- Signed-off-by: anon189Ty <Stari_Falcon@outlook.com>
2025-11-20 20:34:54 +08:00
[2,4,6,8,10,12,14,16,18,20,22,24,26,28,30,32,34,36].
However, mtp torchair + PD scenario, the actual_seq_lengths_q may be
[Feat] Support MTP to running in full graph mode (#3892) ### What this PR does / why we need it? Currently, the MTP model still runs in eager in full graph mode. This PR adapts the MTP with the full graph capture and execution. When the graph mode is set to "FULL_DECODE_ONLY", the MTP will run in full-graph to improve the performance. The change in both disable_padded_drafter_batch is True and False case include: 1. Add _mtp_graph_params in acl_graph.py to isolate the data of main model and the data of MTP. 2. Padding some metadata in mla_v1.py when in fullgraph mode. 3. Fixed the essential data address that will be used in model.forward. 4. Adapted according to the aclgraph capture framwork: 1). Rebuild MTP model with ACLGraphWrapper. 2). Add common attn metadata when start capture in MTP dummy_run. 3). Add common attn metadata update in MTP. 4). Addapted data update when num_speculative_tokens > 1. 5. Add a patch of MTP to adapt vllm v0.11.0. Existing Issues: 1. When disable_padded_drafter_batch=True and running in FullGraph mode, the data of the first-round requests in MTP is abnormal. We need to identify the cause subsequently. 2. When disable_padded_drafter_batch=False and running in FullGraph mode, the acceptance rate of the second and third tokens will decrease (For example, if we set the num_speculative_tokens=3, the acceptance rate of first token is 90%, the second is only 50% lower than 60%, the third is only 20% lower than 30%). The reason is that the data processed after the model runs does not match. This is a problem from another PR. It works fine in eager and PIECEWISE mode, but has problem in FullGraph mode. Once we have a solution, we will submit a bugfix. ### Does this PR introduce _any_ user-facing change? ### How was this patch tested? - vLLM version: v0.11.0 - vLLM main: https://github.com/vllm-project/vllm/commit/2918c1b49c88c29783c86f78d2c4221cb9622379 --------- Signed-off-by: anon189Ty <Stari_Falcon@outlook.com>
2025-11-20 20:34:54 +08:00
[1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16] before padding, since the first decode request only has 1 token.
In order to meet the requirement of npu_fused_infer_attention_score, we need to pad actual_seq_lengths_q
evenly to not exceed 16 tokens per request.
[Feat] Support MTP to running in full graph mode (#3892) ### What this PR does / why we need it? Currently, the MTP model still runs in eager in full graph mode. This PR adapts the MTP with the full graph capture and execution. When the graph mode is set to "FULL_DECODE_ONLY", the MTP will run in full-graph to improve the performance. The change in both disable_padded_drafter_batch is True and False case include: 1. Add _mtp_graph_params in acl_graph.py to isolate the data of main model and the data of MTP. 2. Padding some metadata in mla_v1.py when in fullgraph mode. 3. Fixed the essential data address that will be used in model.forward. 4. Adapted according to the aclgraph capture framwork: 1). Rebuild MTP model with ACLGraphWrapper. 2). Add common attn metadata when start capture in MTP dummy_run. 3). Add common attn metadata update in MTP. 4). Addapted data update when num_speculative_tokens > 1. 5. Add a patch of MTP to adapt vllm v0.11.0. Existing Issues: 1. When disable_padded_drafter_batch=True and running in FullGraph mode, the data of the first-round requests in MTP is abnormal. We need to identify the cause subsequently. 2. When disable_padded_drafter_batch=False and running in FullGraph mode, the acceptance rate of the second and third tokens will decrease (For example, if we set the num_speculative_tokens=3, the acceptance rate of first token is 90%, the second is only 50% lower than 60%, the third is only 20% lower than 30%). The reason is that the data processed after the model runs does not match. This is a problem from another PR. It works fine in eager and PIECEWISE mode, but has problem in FullGraph mode. Once we have a solution, we will submit a bugfix. ### Does this PR introduce _any_ user-facing change? ### How was this patch tested? - vLLM version: v0.11.0 - vLLM main: https://github.com/vllm-project/vllm/commit/2918c1b49c88c29783c86f78d2c4221cb9622379 --------- Signed-off-by: anon189Ty <Stari_Falcon@outlook.com>
2025-11-20 20:34:54 +08:00
after padding actual_seq_lengths_q should be similar to [1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,32,36]
"""
FIA_SEQ_LEN_LIMIT = 16
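        # Re-interpolation is only needed when appending the precomputed padded
        # tail directly would make the first padded entry jump by more than
        # FIA_SEQ_LEN_LIMIT tokens relative to the last real entry; otherwise
        # the padded tail is appended as-is (see the else branch below).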
need_padding = (
num_reqs_pad_size != 0
and len(common_attn_metadata.actual_seq_lengths_q) > num_reqs
and common_attn_metadata.actual_seq_lengths_q[num_reqs] - actual_seq_lengths_q[-1] > FIA_SEQ_LEN_LIMIT
)
if need_padding:
padding_seq_len_q = common_attn_metadata.actual_seq_lengths_q[num_reqs : num_reqs + num_reqs_pad_size]
start_val = actual_seq_lengths_q[-1]
end_val = padding_seq_len_q[-1]
num_step = len(padding_seq_len_q)
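            # Spread the padded cumulative lengths evenly between the last real
            # value and the padded total, e.g. np.round(np.linspace(16, 36, 3)[1:])
            # gives [26, 36] (illustrative values).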
interpolated = np.round(np.linspace(start_val, end_val, num_step + 1)[1:]).astype(int).tolist()
assert interpolated[-1] == end_val
assert len(interpolated) == len(padding_seq_len_q)
actual_seq_lengths_q = actual_seq_lengths_q + interpolated
else:
actual_seq_lengths_q = (
actual_seq_lengths_q
+ common_attn_metadata.actual_seq_lengths_q[num_reqs : num_reqs + num_reqs_pad_size]
)
return actual_seq_lengths_q
def pad_actual_seq_len_q_mtp_disable_pad(self, num_reqs_pad_size, num_reqs, actual_seq_lengths_q):
"""
Only use for acl full graph mode.
Pad the last element of the actual_seq_lengths_q equal to the TND(T) and
the num of dimensions equal to the batch_size of main model.
[Feat] Support MTP to running in full graph mode (#3892) ### What this PR does / why we need it? Currently, the MTP model still runs in eager in full graph mode. This PR adapts the MTP with the full graph capture and execution. When the graph mode is set to "FULL_DECODE_ONLY", the MTP will run in full-graph to improve the performance. The change in both disable_padded_drafter_batch is True and False case include: 1. Add _mtp_graph_params in acl_graph.py to isolate the data of main model and the data of MTP. 2. Padding some metadata in mla_v1.py when in fullgraph mode. 3. Fixed the essential data address that will be used in model.forward. 4. Adapted according to the aclgraph capture framwork: 1). Rebuild MTP model with ACLGraphWrapper. 2). Add common attn metadata when start capture in MTP dummy_run. 3). Add common attn metadata update in MTP. 4). Addapted data update when num_speculative_tokens > 1. 5. Add a patch of MTP to adapt vllm v0.11.0. Existing Issues: 1. When disable_padded_drafter_batch=True and running in FullGraph mode, the data of the first-round requests in MTP is abnormal. We need to identify the cause subsequently. 2. When disable_padded_drafter_batch=False and running in FullGraph mode, the acceptance rate of the second and third tokens will decrease (For example, if we set the num_speculative_tokens=3, the acceptance rate of first token is 90%, the second is only 50% lower than 60%, the third is only 20% lower than 30%). The reason is that the data processed after the model runs does not match. This is a problem from another PR. It works fine in eager and PIECEWISE mode, but has problem in FullGraph mode. Once we have a solution, we will submit a bugfix. ### Does this PR introduce _any_ user-facing change? ### How was this patch tested? - vLLM version: v0.11.0 - vLLM main: https://github.com/vllm-project/vllm/commit/2918c1b49c88c29783c86f78d2c4221cb9622379 --------- Signed-off-by: anon189Ty <Stari_Falcon@outlook.com>
2025-11-20 20:34:54 +08:00
For example:
batch_size = 8, num_reqs = 4, num_speculative_tokens = 1
input actual_seq_lengths_q = [1, 2, 4, 5] (the 3rd req was accept a token)
After padding the actual_seq_lengths_q will be similar to [1, 2, 4, 5, 6, 6, 7, 8]
"""
need_padding = num_reqs_pad_size > 0
if need_padding:
start_val = actual_seq_lengths_q[-1]
end_val = num_reqs + num_reqs_pad_size
num_step = num_reqs_pad_size
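            # With the docstring example (start_val=5, end_val=8, num_step=4):
            # np.round(np.linspace(5, 8, 5)[1:]) -> [6, 6, 7, 8].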
interpolated = np.round(np.linspace(start_val, end_val, num_step + 1)[1:]).astype(int).tolist()
assert interpolated[-1] == end_val
assert len(interpolated) == num_reqs_pad_size
actual_seq_lengths_q = actual_seq_lengths_q + interpolated
return actual_seq_lengths_q
def set_num_actual_tokens(
self,
common_attn_metadata: AscendCommonAttentionMetadata,
):
self.num_actual_tokens = common_attn_metadata.num_actual_tokens
def build(
self,
common_prefix_len: int,
common_attn_metadata: AscendCommonAttentionMetadata,
fast_build: bool = False,
) -> AscendMLAMetadata:
num_reqs = common_attn_metadata.num_reqs
query_start_loc = common_attn_metadata.query_start_loc
query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu
self.num_decodes, self.num_prefills, self.num_decode_tokens, self.num_prefill_tokens = (
split_decodes_and_prefills(common_attn_metadata, decode_threshold=self.decode_threshold)
)
self.set_num_actual_tokens(common_attn_metadata)
assert self.num_decodes + self.num_prefills == num_reqs
assert self.num_decode_tokens + self.num_prefill_tokens == common_attn_metadata.num_actual_tokens
        # NOTE: Currently, MTP full-graph mode is incompatible with PCP.
self.slot_mapping = common_attn_metadata.slot_mapping[: self.num_actual_tokens]
query_seq_lens_cpu = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]
self.query_lens = query_seq_lens_cpu[:num_reqs]
self.seq_lens = common_attn_metadata.seq_lens_cpu[:num_reqs]
self.graph_pad_size = common_attn_metadata.graph_pad_size
block_table_size = self.get_block_table_size(common_attn_metadata, BUILD_METADATA_STEP_PREFILL)
self.block_table = common_attn_metadata.block_table_tensor[:block_table_size]
prefill_metadata = None
if self.num_prefills > 0:
prefill_metadata = self.build_prefill_metadata(common_prefix_len, common_attn_metadata)
decode_metadata = None
if self.num_decodes > 0:
decode_metadata = self.build_decode_metadata(common_prefix_len, common_attn_metadata)
return self.metadata_cls( # type: ignore
num_actual_tokens_pcp_padded=self.num_actual_tokens,
num_input_tokens=common_attn_metadata.num_input_tokens,
num_actual_tokens=self.num_actual_tokens,
query_lens=self.query_lens.tolist(),
slot_mapping=self.slot_mapping,
head_dim=self.model_config.get_head_size(),
num_decodes=self.num_decodes,
num_decode_tokens=self.num_decode_tokens,
num_prefills=self.num_prefills,
attn_mask=self.attn_mask_builder.get_final_mla_mask(self.model_config),
attn_state=common_attn_metadata.attn_state,
prefill=prefill_metadata,
decode=decode_metadata,
query_start_loc=query_start_loc,
block_tables=self.block_table,
seq_lens=self.seq_lens,
)
def build_chunked_metadata(
self,
common_prefix_len: int,
common_attn_metadata: AscendCommonAttentionMetadata,
):
if not self.chunked_prefill_enabled:
return None
num_reqs = common_attn_metadata.num_reqs
num_computed_tokens_cpu = self.seq_lens - self.query_lens
reqs_start = self.num_decodes # prefill_start
self.context_lens_cpu = num_computed_tokens_cpu[reqs_start:num_reqs]
max_context_len_cpu = self.context_lens_cpu.max().item()
if not max_context_len_cpu > 0:
return None
num_prefills_with_context_cpu = (self.context_lens_cpu > 0).sum().item()
self.max_context_chunk = self.chunked_prefill_workspace_size // num_prefills_with_context_cpu
self.max_context_chunk = round_down(self.max_context_chunk, self.block_size)
assert self.max_context_chunk > 0
self.num_chunks = cdiv(max_context_len_cpu, self.max_context_chunk)
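# Split each prefill request's cached context into num_chunks chunks of at
# most max_context_chunk tokens: chunk i of request j covers
# [chunk_starts[i, j], chunk_ends[i, j]) of that request's context.
# Illustrative (hypothetical) example: context lens [300, 50], workspace 256,
# block_size 64 -> max_context_chunk = 128, num_chunks = 3,
# chunk_seq_lens = [[128, 50], [128, 0], [44, 0]].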
chunk_starts = (
torch.arange(self.num_chunks, dtype=torch.int32).unsqueeze(1).expand(-1, self.num_prefills)
* self.max_context_chunk
)
chunk_ends = torch.min(self.context_lens_cpu.unsqueeze(0), chunk_starts + self.max_context_chunk)
self.chunk_seq_lens = (chunk_ends - chunk_starts).clamp(min=0)
self.cu_seq_lens_cpu = torch.zeros(self.num_chunks, self.num_prefills + 1, dtype=torch.int32, pin_memory=True)
torch.cumsum(self.chunk_seq_lens, dim=1, out=self.cu_seq_lens_cpu[:, 1:], dtype=torch.int32)
return ChunkedContextMetadata(
cu_seq_lens=self.cu_seq_lens_cpu.pin_memory().to(self.device, non_blocking=True),
starts=chunk_starts.pin_memory().to(self.device, non_blocking=True),
seq_tot=self.chunk_seq_lens.sum(dim=1).tolist(),
max_seq_lens=self.chunk_seq_lens.max(dim=1).values.tolist(),
chunk_seq_lens=self.chunk_seq_lens,
chunk_seq_lens_npu=self.chunk_seq_lens.npu(),
workspace=self.chunked_prefill_workspace,
)
def get_block_table_size(self, common_attn_metadata: AscendCommonAttentionMetadata, build_metadata_step: int):
if build_metadata_step == BUILD_METADATA_STEP_PREFILL:
# If graph_pad_size > -1, it means we are running in full-graph mode.
# NOTE: Maybe this block_table change can be removed when graph_pad_size > 1.
if (
self.graph_pad_size > common_attn_metadata.num_reqs
and self.speculative_config.disable_padded_drafter_batch
):
return self.graph_pad_size
return common_attn_metadata.num_reqs
return self.num_decodes
def build_prefill_metadata(
self,
common_prefix_len: int,
common_attn_metadata: AscendCommonAttentionMetadata,
) -> AscendMLAPrefillMetadata:
query_start_loc = common_attn_metadata.query_start_loc
# NOTE: Currently, MTP full-graph mode is incompatible with PCP.
input_positions = common_attn_metadata.positions[: self.num_actual_tokens].long()
chunked_context_metadata = self.build_chunked_metadata(common_prefix_len, common_attn_metadata)
reqs_start = self.num_decodes # prefill_start
tokens_start = self.num_decode_tokens
max_query_len = self.query_lens[reqs_start:].max().item()
max_seq_lens = self.seq_lens[reqs_start:].max().item()
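# Rebase query_start_loc so the prefill requests start at offset 0
# (the first `reqs_start` entries belong to decode requests).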
prefill_query_start_loc = query_start_loc[reqs_start:] - query_start_loc[reqs_start]
prefill_input_positions = input_positions[tokens_start:]
cos, sin = get_cos_and_sin_mla(prefill_input_positions)
return AscendMLAPrefillMetadata(
attn_mask=self.attn_mask_builder.get_final_mla_mask(self.model_config),
query_lens=self.query_lens[reqs_start:].to(torch.int32),
seq_lens=self.seq_lens,
context_lens=self.seq_lens[reqs_start:],
input_positions=prefill_input_positions,
block_table=self.block_table[reqs_start:, ...],
max_query_len=max_query_len,
max_seq_lens=max_seq_lens,
query_start_loc=prefill_query_start_loc,
chunked_context=chunked_context_metadata,
sin=sin,
cos=cos,
)
def build_decode_metadata(
self,
common_prefix_len: int,
common_attn_metadata: AscendCommonAttentionMetadata,
) -> AscendMLADecodeMetadata:
num_reqs = common_attn_metadata.num_reqs
query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu
input_positions = common_attn_metadata.positions[: self.num_actual_tokens].long()
# Notice that num_decodes != num_decode_tokens in the SpecDecoding scenario.
actual_seq_lengths_q = query_start_loc_cpu[1 : self.num_decodes + 1].tolist()
max_seq_lens = self.seq_lens[: self.num_decodes].max().item()
self.seq_lens = self.seq_lens[: self.num_decodes]
input_positions = input_positions[: self.num_decode_tokens]
block_table_size = self.get_block_table_size(common_attn_metadata, BUILD_METADATA_STEP_DECODE)
self.block_table = self.block_table[:block_table_size]
# NOTE: Currently, MTP full-graph mode is incompatible with PCP.
# NOTE: Maybe this block_table change can be removed when graph_pad_size > 1.
if self.graph_pad_size > self.num_decodes and self.speculative_config.disable_padded_drafter_batch:
self.block_table = self.block_table[: self.graph_pad_size, ...]
seq_lens_list = self.seq_lens.tolist()
cp_seq_len = None
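# Full-graph capture path: when graph_pad_size exceeds the actual request
# count, pad the decode metadata up to the fixed capture size; exactly which
# tensors are padded depends on whether padded drafter batches are disabled.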
if self.graph_pad_size > num_reqs:
if self.speculative_config.disable_padded_drafter_batch:
num_reqs_pad_size = self.graph_pad_size - num_reqs
actual_seq_lengths_q = self.pad_actual_seq_len_q_mtp_disable_pad(
num_reqs_pad_size, num_reqs, actual_seq_lengths_q
)
seq_lens_list = seq_lens_list + [0] * (self.graph_pad_size - self.num_decodes)
num_block_pad_size = self.graph_pad_size - self.block_table.shape[0]
if num_block_pad_size > 0:
block_table_padding = torch.zeros(
(num_block_pad_size,) + self.block_table.shape[1:],
dtype=self.block_table.dtype,
device=self.block_table.device,
)
self.block_table = torch.cat([self.block_table, block_table_padding], dim=0)
else:
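# Padded drafter batch path: pad tokens, slot mapping, block table and
# positions up to the capture size; padded slots are filled with
# PAD_SLOT_ID as a dummy slot index.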
num_token_pad_size = self.graph_pad_size - self.num_decode_tokens
num_reqs_pad_size = self.graph_pad_size // common_attn_metadata.decode_token_per_req - num_reqs
num_block_table_pad_size = (
self.graph_pad_size // common_attn_metadata.decode_token_per_req - self.num_decodes
)
seq_lens_list = self.seq_lens.tolist() + [0] * num_reqs_pad_size
slot_padding = torch.full(
(num_token_pad_size,), PAD_SLOT_ID, dtype=self.slot_mapping.dtype, device=self.slot_mapping.device
)
self.slot_mapping = torch.cat([self.slot_mapping, slot_padding])
block_table_padding = torch.zeros(
(num_block_table_pad_size,) + self.block_table.shape[1:],
dtype=self.block_table.dtype,
device=self.block_table.device,
)
self.block_table = torch.cat([self.block_table, block_table_padding], dim=0)
position_padding = torch.zeros(
num_token_pad_size, dtype=input_positions.dtype, device=input_positions.device
)
input_positions = torch.cat([input_positions, position_padding])
actual_seq_lengths_q = self.pad_actual_seq_len_q_mtp_enable_pad(
num_reqs_pad_size, num_reqs, actual_seq_lengths_q, common_attn_metadata
)
cos, sin = get_cos_and_sin_mla(input_positions, use_cache=True)
decode_metadata = AscendMLADecodeMetadata(
input_positions=input_positions,
block_table=self.block_table,
seq_lens=self.seq_lens,
seq_lens_list=seq_lens_list,
max_seq_lens=max_seq_lens,
attn_mask=self.attn_mask_builder.get_splitfuse_attn_mask(),
actual_seq_lengths_q=actual_seq_lengths_q,
sin=sin[: self.num_decode_tokens, ...],
cos=cos[: self.num_decode_tokens, ...],
cp_seq_len=cp_seq_len,
)
return decode_metadata
def build_for_graph_capture(
self,
common_attn_metadata: AscendCommonAttentionMetadata,
attn_state: AscendAttentionState = AscendAttentionState.DecodeOnly,
):
if attn_state in {AscendAttentionState.DecodeOnly, AscendAttentionState.SpecDecoding}:
attn_metadata = self.build(
common_prefix_len=0,
common_attn_metadata=common_attn_metadata,
)
else:
raise NotImplementedError(
"Currently we only support building dummy metadata for DecodeOnly and SpecDecoding state"
)
attn_metadata.attn_state = attn_state
return attn_metadata
class DecodeMLAPreprocessResult(NamedTuple):
ql_nope: torch.Tensor | None = None
q_pe: torch.Tensor | None = None
k_nope: torch.Tensor | None = None
k_pe: torch.Tensor | None = None
decode_q_wo_k_up: torch.Tensor | None = None
class PrefillMLAPreprocessResult(NamedTuple):
q_nope: torch.Tensor | None = None
q_pe: torch.Tensor | None = None
k_nope: torch.Tensor | None = None
k_pe: torch.Tensor | None = None
value: torch.Tensor | None = None
class AscendMLAImpl(MLAAttentionImpl):
"""
NOTE: Please read the comment at the top of the file before trying to
understand this class
"""
def __init__(
self,
num_heads: int,
head_size: int,
scale: float,
num_kv_heads: int,
alibi_slopes: list[float] | None,
sliding_window: int | None,
kv_cache_dtype: str,
logits_soft_cap: float | None,
attn_type: str,
kv_sharing_target_layer_name: str | None,
**kwargs,
):
self.vllm_config = get_current_vllm_config()
self.num_heads = num_heads
self.head_size = head_size
self.scale = float(scale)
self.num_kv_heads = num_kv_heads
self.kv_cache_dtype = kv_cache_dtype
# MLA Args
self.q_lora_rank = kwargs["q_lora_rank"]
self.kv_lora_rank = kwargs["kv_lora_rank"]
self.qk_nope_head_dim = kwargs["qk_nope_head_dim"]
self.qk_rope_head_dim = kwargs["qk_rope_head_dim"]
self.qk_head_dim = kwargs["qk_head_dim"]
self.v_head_dim = kwargs["v_head_dim"]
self.rotary_emb = kwargs["rotary_emb"]
self.fused_qkv_a_proj = kwargs.get("fused_qkv_a_proj")
self.q_proj = kwargs["q_proj"] if self.q_lora_rank is None else kwargs["q_b_proj"]
self.kv_b_proj = kwargs["kv_b_proj"]
self.o_proj = kwargs["o_proj"]
self.kv_a_proj_with_mqa = kwargs.get("kv_a_proj_with_mqa")
self.kv_a_layernorm = kwargs.get("kv_a_layernorm")
self.q_a_layernorm = kwargs.get("q_a_layernorm")
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
ascend_config = get_ascend_config()
self.enable_shared_expert_dp = ascend_config.enable_shared_expert_dp
self.enable_kv_nz = ascend_config.enable_kv_nz
self.ring_mla_mask_size = 512
self.speculative_config = self.vllm_config.speculative_config
self.enable_mlapo = enabling_mlapo(self.vllm_config)
self.is_kv_producer = (
self.vllm_config.kv_transfer_config is not None and self.vllm_config.kv_transfer_config.is_kv_producer
)
self.layer_sharding_kwargs = []
for layer_name in get_ascend_config().layer_sharding or []:
if layer_name in kwargs:
self.layer_sharding_kwargs.append(kwargs[layer_name])
else:
logger.warning_once(
f"Layer '{layer_name}' not found in kwargs for layer sharding, skipping sharding configuration"
)
register_all_layers_to_shard_weight_series(self.layer_sharding_kwargs)
@staticmethod
def update_graph_params(
update_stream,
forward_context,
num_tokens,
vllm_config=None,
speculative_config=None,
num_dcp_pcp_tokens=None,
draft_attn_metadatas=None,
):
if forward_context.is_draft_model:
graph_params = get_draft_graph_params()
else:
graph_params = get_graph_params()
# FIXME: Behold! We are using a temporary hack here to update the args
# for each layer's attention op in the graph.
with torch.npu.stream(update_stream):
for key, param, handle, event in zip(
forward_context.attn_metadata,
graph_params.attn_params[num_tokens],
graph_params.handles[num_tokens],
graph_params.events[num_tokens],
):
(
q_nope,
k_nope,
q_pe,
k_pe,
num_heads,
num_kv_heads,
input_layout,
attn_mask,
sparse_mode,
scale,
block_table,
block_size,
seq_lens_list,
actual_seq_lengths,
attn_output,
softmax_lse,
) = param
seq_lens_list = forward_context.attn_metadata[key].decode.seq_lens_list
if speculative_config and speculative_config.method == "mtp" and not forward_context.is_draft_model:
actual_seq_lengths = forward_context.attn_metadata[key].decode.actual_seq_lengths_q
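# With MTP, each decode request occupies (num_speculative_tokens + 1)
# query positions, so pad seq_lens_list to the captured size and rebuild
# actual_seq_lengths as cumulative multiples of spec_multiple.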
spec_multiple = speculative_config.num_speculative_tokens + 1
seq_lens_list = seq_lens_list + [0] * (num_tokens // spec_multiple - len(seq_lens_list))
actual_seq_lengths = [spec_multiple * (i + 1) for i in range(num_tokens // spec_multiple)]
elif forward_context.is_draft_model:
actual_seq_lengths = forward_context.attn_metadata[key].decode.actual_seq_lengths_q
block_table = forward_context.attn_metadata[key].decode.block_table
# TODO: This is a hack and should be fixed in the future.
if speculative_config.disable_padded_drafter_batch:
block_table = block_table[: len(actual_seq_lengths)]
seq_lens_list = seq_lens_list + [0] * (len(actual_seq_lengths) - len(seq_lens_list))
else:
seq_lens_list = seq_lens_list + [0] * (num_tokens - len(seq_lens_list))
torch.npu.graph_task_update_begin(update_stream, handle)
torch_npu.npu_fused_infer_attention_score.out(
q_nope,
k_nope,
k_nope,
query_rope=q_pe,
key_rope=k_pe,
num_heads=num_heads,
num_key_value_heads=num_kv_heads,
input_layout=input_layout,
atten_mask=attn_mask,
sparse_mode=sparse_mode,
scale=scale,
antiquant_mode=0,
antiquant_scale=None,
block_table=block_table,
block_size=block_size,
actual_seq_lengths_kv=seq_lens_list,
actual_seq_lengths=actual_seq_lengths,
workspace=graph_params.workspaces.get(num_tokens),
out=[attn_output, softmax_lse],
)
torch.npu.graph_task_update_end(update_stream)
event.record(update_stream)
def _v_up_proj(self, x):
# Convert from (N, B, L)/(N, B, 1, L) to (N, B, L)
x = x.view(self.num_heads, -1, self.kv_lora_rank)
# Multiply (N, B, L) x (N, L, V) -> (B, N, V)
x = torch_npu.npu_transpose_batchmatmul(x, self.W_UV, perm_y=(1, 0, 2))
# Convert from (B, N, V) to (B, N * V)
x = x.reshape(-1, self.num_heads * self.v_head_dim)
return x
# Return `ql_nope`, `q_pe`
def _q_proj_and_k_up_proj(self, x):
q_nope, q_pe = (
self.q_proj(x)[0]
.view(-1, self.num_heads, self.qk_head_dim)
.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
)
# Convert from (B, N, P) to (N, B, P)
q_nope = q_nope.transpose(0, 1)
# Multiply (N, B, P) x (N, P, L) -> (N, B, L)
ql_nope = torch.bmm(q_nope, self.W_UK_T)
# Convert from (N, B, L) to (B, N, L)
return ql_nope.transpose(0, 1), q_pe
def process_weights_after_loading(self, act_dtype: torch.dtype):
# NOTE: We currently do not support a quantized kv_b_proj.
assert isinstance(self.kv_b_proj.quant_method, UnquantizedLinearMethod)
# NOTE: The weight will be reshaped next, so we need to revert its format and transpose it.
kv_b_proj_weight = torch_npu.npu_format_cast(self.kv_b_proj.weight.data, ACL_FORMAT_FRACTAL_ND).T
assert kv_b_proj_weight.shape == (
self.kv_lora_rank,
self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
), (
f"{kv_b_proj_weight.shape=}, "
f"{self.kv_lora_rank=}, "
f"{self.num_heads=}, "
f"{self.qk_nope_head_dim=}, "
f"{self.v_head_dim=}"
)
kv_b_proj_weight = kv_b_proj_weight.view(
self.kv_lora_rank,
self.num_heads,
self.qk_nope_head_dim + self.v_head_dim,
)
W_UK, W_UV = kv_b_proj_weight.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1)
# Convert from (L, N, V) to (N, L, V)
self.W_UV = W_UV.transpose(0, 1).contiguous()
# Convert from (L, N, P) to (N, P, L)
self.W_UK_T = W_UK.permute(1, 2, 0).contiguous()
# TODO(zzzzwwjj): Currently, torch.ops._C_ascend.batch_matmul_transpose cannot support weight nz
# self.W_UV = maybe_trans_nz(self.W_UV)
if self.enable_mlapo:
# Currently mlapo only supports W8A8 quantization in the MLA scenario.
# TODO(whx): lift this limitation when mlapo supports floating point.
if self.fused_qkv_a_proj is None or not isinstance(
getattr(self.fused_qkv_a_proj.quant_method, "quant_method", None), AscendW8A8LinearMethod
):
self.enable_mlapo = False
logger.warning_once(
"Currently mlapo only supports W8A8 quantization in MLA scenario."
"Some layers in your model are not quantized with W8A8,"
"thus mlapo is disabled for these layers."
)
if self.enable_mlapo:
self._process_weights_for_fused_mlapo(act_dtype)
else:
# When mlapo is enabled, W_UK_T cannot be converted to NZ, so only do the conversion here (non-mlapo path).
self.W_UK_T = maybe_trans_nz(self.W_UK_T)
for layer in self.layer_sharding_kwargs or []:
if is_hidden_layer(layer):
post_process_after_loading_for_shard_weight_series(layer)
def _process_weights_for_fused_mlapo(self, act_dtype: torch.dtype):
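# Build the packed weights consumed by torch.ops._C_ascend.mla_preprocess:
# concatenate the q/kv down-projection weights (with their rope columns
# re-ordered by trans_rope_weight), tile them with transdata and cast them
# with npu_format_cast, and prepare the matching dequant scales, quant
# biases, layernorm gammas and per-tensor quantization parameters.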
kv_a_proj_wt = self.fused_qkv_a_proj.weight.data[..., self.q_lora_rank :].contiguous() # type: ignore[union-attr]
q_a_proj_wt = self.fused_qkv_a_proj.weight.data[..., : self.q_lora_rank].contiguous() # type: ignore[union-attr]
kv_a_proj_wt = kv_a_proj_wt.t().contiguous()
kv_a_proj_wt = trans_rope_weight(kv_a_proj_wt, self.qk_rope_head_dim)
kv_a_proj_wt = kv_a_proj_wt.t().contiguous()
wd_qkv = torch.cat((kv_a_proj_wt, q_a_proj_wt), dim=-1)
wd_qkv = wd_qkv.t().contiguous()
wd_qkv = transdata(wd_qkv, block_size=(16, 32)).unsqueeze(0).contiguous()
self.wd_qkv = torch_npu.npu_format_cast(wd_qkv, 29)
kv_a_proj_deq_scl = self.fused_qkv_a_proj.deq_scale[self.q_lora_rank :].contiguous() # type: ignore[union-attr]
q_a_proj_deq_scl = self.fused_qkv_a_proj.deq_scale[: self.q_lora_rank].contiguous() # type: ignore[union-attr]
kv_a_proj_deq_scl = kv_a_proj_deq_scl.reshape(self.kv_lora_rank + self.qk_rope_head_dim, -1).contiguous()
kv_a_proj_deq_scl = trans_rope_weight(kv_a_proj_deq_scl, self.qk_rope_head_dim)
kv_a_proj_deq_scl = kv_a_proj_deq_scl.view(self.kv_lora_rank + self.qk_rope_head_dim).contiguous()
self.deq_scale_qkv = torch.cat((kv_a_proj_deq_scl, q_a_proj_deq_scl), dim=-1).contiguous()
kv_a_proj_qt_bias = self.fused_qkv_a_proj.quant_bias[self.q_lora_rank :].contiguous() # type: ignore[union-attr]
q_a_proj_qt_bias = self.fused_qkv_a_proj.quant_bias[: self.q_lora_rank].contiguous() # type: ignore[union-attr]
kv_a_proj_qt_bias = kv_a_proj_qt_bias.reshape(self.kv_lora_rank + self.qk_rope_head_dim, -1).contiguous()
kv_a_proj_qt_bias = trans_rope_weight(kv_a_proj_qt_bias, self.qk_rope_head_dim)
kv_a_proj_qt_bias = kv_a_proj_qt_bias.view(self.kv_lora_rank + self.qk_rope_head_dim).contiguous()
self.quant_bias_qkv = torch.cat((kv_a_proj_qt_bias, q_a_proj_qt_bias), dim=-1).contiguous()
wu_q = self.q_proj.weight.data
wu_q = wu_q.t().reshape(self.num_heads, self.qk_nope_head_dim + self.qk_rope_head_dim, -1)
wu_q = trans_rope_weight(wu_q, self.qk_rope_head_dim)
wu_q = wu_q.reshape(self.num_heads * (self.qk_nope_head_dim + self.qk_rope_head_dim), -1)
wu_q = transdata(wu_q, block_size=(16, 32)).unsqueeze(0).contiguous()
self.wu_q = torch_npu.npu_format_cast(wu_q, 29)
qb_deq_scl = self.q_proj.deq_scale.data
qb_deq_scl = qb_deq_scl.reshape(self.num_heads, self.qk_nope_head_dim + self.qk_rope_head_dim, -1)
qb_deq_scl = trans_rope_weight(qb_deq_scl, self.qk_rope_head_dim)
self.qb_deq_scl = qb_deq_scl.reshape(self.num_heads * (self.qk_nope_head_dim + self.qk_rope_head_dim))
qb_qt_bias = self.q_proj.quant_bias.data
qb_qt_bias = qb_qt_bias.reshape(self.num_heads, self.qk_nope_head_dim + self.qk_rope_head_dim, -1)
qb_qt_bias = trans_rope_weight(qb_qt_bias, self.qk_rope_head_dim)
self.qb_qt_bias = qb_qt_bias.reshape(self.num_heads * (self.qk_nope_head_dim + self.qk_rope_head_dim))
device = self.q_proj.weight.device
self.gamma1 = self.q_a_layernorm.weight.data # type: ignore[union-attr]
self.beta1 = torch.zeros_like(self.gamma1) if (_bias := self.q_a_layernorm.bias) is None else _bias.data # type: ignore[union-attr]
self.gamma2 = self.kv_a_layernorm.weight.data # type: ignore[union-attr]
self.quant_scale0 = self.fused_qkv_a_proj.input_scale.data # type: ignore[union-attr]
self.quant_offset0 = self.fused_qkv_a_proj.input_offset.data # type: ignore[union-attr]
self.quant_scale1 = self.q_proj.input_scale.data
self.quant_offset1 = self.q_proj.input_offset.data
self.ctkv_scale = torch.tensor([1], dtype=act_dtype, device=device)
self.q_nope_scale = torch.tensor([1], dtype=act_dtype, device=device)
# On KV consumers (decode-only) MLAPO uses the transformed weights built above;
# the original fused_qkv_a_proj/q_proj weights and quant params are no longer
# referenced, so drop them to save memory.
if (
self.vllm_config.kv_transfer_config is not None
and self.vllm_config.kv_transfer_config.is_kv_consumer
and self.vllm_config.scheduler_config.max_num_batched_tokens <= MLAPO_MAX_SUPPORTED_TOKENS
):
self.fused_qkv_a_proj.weight = None # type: ignore[union-attr]
self.fused_qkv_a_proj.deq_scale = None # type: ignore[union-attr]
self.fused_qkv_a_proj.quant_bias = None # type: ignore[union-attr]
self.q_proj.weight = None
self.q_proj.deq_scale = None
self.q_proj.quant_bias = None
torch.npu.empty_cache()
def get_context_seq_len_npu(self, index: int, attn_metadata: AscendMLAMetadata):
prefill_metadata = attn_metadata.prefill
assert prefill_metadata is not None
assert prefill_metadata.chunked_context is not None
assert prefill_metadata.chunked_context.chunk_seq_lens_npu is not None
iters = len(prefill_metadata.chunked_context.seq_tot)
assert 0 <= index < iters
return prefill_metadata.chunked_context.chunk_seq_lens_npu[index]
def _reorg_kvcache(
self,
kv_c_normed: torch.Tensor,
k_pe: torch.Tensor,
chunked_context: CPChunkedContextMetadata,
chunk_idx: int,
toks: int,
) -> tuple[torch.Tensor, torch.Tensor]:
return kv_c_normed, k_pe
def _compute_prefill_context(
self,
q_nope: torch.Tensor,
q_pe: torch.Tensor,
kv_c_and_k_pe_cache: tuple[torch.Tensor],
rope_dim: int,
attn_metadata: AscendMLAMetadata,
prefix_output: torch.Tensor,
prefix_lse: torch.Tensor,
):
assert len(kv_c_and_k_pe_cache) > 1
prefill_metadata = attn_metadata.prefill
if prefill_metadata is None or prefill_metadata.chunked_context is None:
return prefix_output, prefix_lse
iters = len(prefill_metadata.chunked_context.seq_tot)
current_seq_len = torch.tensor(prefill_metadata.query_lens, dtype=torch.int32)
cache_kv_c = kv_c_and_k_pe_cache[0]
cache_k_pe = kv_c_and_k_pe_cache[1]
num_heads = cache_k_pe.size(2)
latent_kv_dim = kv_c_and_k_pe_cache[0].size(-1)
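# For each context chunk: load the cached latent KV for that chunk with
# npu_paged_cache_load, up-project it through kv_b_proj, and merge it into
# the running attention output/LSE with npu_ring_mla.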
for i in range(iters):
toks = prefill_metadata.chunked_context.seq_tot[i]
# chunk_seq_lens will be padded when PCP/DCP is enabled
context_seq_len = prefill_metadata.chunked_context.chunk_seq_lens[i]
seq_len = torch.stack([current_seq_len, context_seq_len])
context_seq_len_npu = self.get_context_seq_len_npu(i, attn_metadata)
kv_c_normed = torch.empty(toks, num_heads, latent_kv_dim, dtype=q_nope.dtype, device=q_nope.device)
k_pe = torch.empty(toks, num_heads, rope_dim, dtype=q_nope.dtype, device=q_nope.device)
torch_npu.atb.npu_paged_cache_load(
cache_kv_c,
cache_k_pe,
prefill_metadata.block_table,
context_seq_len_npu,
seq_starts=prefill_metadata.chunked_context.starts[i],
key=kv_c_normed,
value=k_pe,
)
kv_c_normed, k_pe = self._reorg_kvcache(
kv_c_normed,
k_pe,
chunked_context=prefill_metadata.chunked_context,
chunk_idx=i,
toks=toks,
)
kv_c_normed = kv_c_normed.squeeze()
kv_nope = self.kv_b_proj(kv_c_normed)[0].view(-1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim)
k_nope, v = kv_nope.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1)
k_pe = k_pe.expand((*k_nope.shape[:-1], -1))
mask = attn_metadata.attn_mask
torch_npu.atb.npu_ring_mla(
q_nope=q_nope,
q_rope=q_pe,
k_nope=k_nope,
k_rope=k_pe,
value=v,
mask=mask,
seqlen=seq_len,
head_num=self.num_heads,
kv_head_num=self.num_heads,
pre_out=prefix_output,
prev_lse=prefix_lse,
qk_scale=self.scale,
kernel_type="kernel_type_high_precision",
mask_type="no_mask",
input_layout="type_bsnd",
calc_type="calc_type_default",
output=prefix_output,
softmax_lse=prefix_lse,
)
return prefix_output, prefix_lse
def _forward_prefill(
self,
q_nope: torch.Tensor,
q_pe: torch.Tensor,
k_nope: torch.Tensor,
k_pe: torch.Tensor,
value: torch.Tensor,
kv_c_and_k_pe_cache: tuple[torch.Tensor],
attn_metadata: AscendMLAMetadata,
) -> torch.Tensor:
assert attn_metadata.prefill is not None
assert len(kv_c_and_k_pe_cache) > 1
num_tokens = q_nope.size(0)
attn_output = torch.empty(num_tokens, self.num_heads, self.v_head_dim, dtype=q_nope.dtype, device=q_nope.device)
attn_lse = torch.empty(self.num_heads, num_tokens, dtype=torch.float32, device=q_nope.device)
torch_npu.atb.npu_ring_mla(
q_nope=q_nope,
q_rope=q_pe,
k_nope=k_nope,
k_rope=k_pe,
value=value,
mask=attn_metadata.attn_mask,
seqlen=attn_metadata.prefill.query_lens,
head_num=self.num_heads,
kv_head_num=self.num_heads,
pre_out=None,
prev_lse=None,
qk_scale=self.scale,
kernel_type="kernel_type_high_precision",
mask_type="mask_type_triu",
input_layout="type_bsnd",
calc_type="calc_type_first_ring",
output=attn_output,
softmax_lse=attn_lse,
)
attn_output, attn_lse = self._compute_prefill_context(
q_nope, q_pe, kv_c_and_k_pe_cache, self.qk_rope_head_dim, attn_metadata, attn_output, attn_lse
)
attn_output = attn_output.reshape([num_tokens, self.num_heads * self.v_head_dim])
return attn_output
def exec_kv_decode(
self,
kv_no_split: torch.Tensor,
cos: torch.Tensor,
sin: torch.Tensor,
kv_cache: tuple,
slots: torch.Tensor,
):
B = kv_no_split.shape[0]
N = self.num_kv_heads
S = 1
# npu_kv_rmsnorm_rope_cache needs [B, N, S, D]
kv_no_split = kv_no_split.view(B, N, S, self.kv_lora_rank + self.qk_rope_head_dim)
cache_mode = "PA_NZ" if self.enable_kv_nz else "PA"
k_pe, k_nope, _, _ = torch_npu.npu_kv_rmsnorm_rope_cache(
kv_no_split,
self.kv_a_layernorm.weight, # type: ignore[union-attr]
cos,
sin,
slots.to(torch.int64),
kv_cache[1],
kv_cache[0],
epsilon=self.kv_a_layernorm.variance_epsilon, # type: ignore[union-attr]
cache_mode=cache_mode,
)
return k_pe, k_nope
def exec_kv_prefill(
self,
kv_no_split: torch.Tensor,
cos: torch.Tensor,
sin: torch.Tensor,
kv_cache: tuple,
slots: torch.Tensor,
):
B = kv_no_split.shape[0]
N = self.num_kv_heads
S = 1
# npu_kv_rmsnorm_rope_cache needs [B, N, S, D]
kv_no_split = kv_no_split.view(B, N, S, self.kv_lora_rank + self.qk_rope_head_dim)
cache_mode = "PA"
_, _, k_pe, k_nope = torch_npu.npu_kv_rmsnorm_rope_cache(
kv_no_split,
self.kv_a_layernorm.weight, # type: ignore[union-attr]
cos,
sin,
slots.to(torch.int64),
kv_cache[1],
kv_cache[0],
epsilon=self.kv_a_layernorm.variance_epsilon, # type: ignore[union-attr]
cache_mode=cache_mode,
is_output_kv=True,
)
return k_pe, k_nope
def rope_single(
self,
x: torch.Tensor,
cos: torch.Tensor,
sin: torch.Tensor,
) -> torch.Tensor:
B, N, D = x.shape
S = 1
x = x.view(B, N, S, D)
x = torch_npu.npu_interleave_rope(x, cos, sin)
return x.view(B, N, D)
def _forward_decode(
self,
q_nope: torch.Tensor,
q_pe: torch.Tensor,
k_nope: torch.Tensor,
k_pe: torch.Tensor,
block_size: int,
attn_metadata: AscendMLAMetadata,
) -> torch.Tensor:
decode_meta = attn_metadata.decode
assert decode_meta is not None
num_tokens = q_nope.size(0)
# shape of knope/k_pe for npu graph mode should be:
# [num_blocks, num_kv_heads, block_size, self.kv_lora_rank/self.qk_rope_head_dim]
actual_seq_lengths = None
if self.enable_kv_nz:
nz_fmt_last_dim = 16
k_nope = k_nope.view(
-1, self.num_kv_heads, self.kv_lora_rank // nz_fmt_last_dim, block_size, nz_fmt_last_dim
)
k_pe = k_pe.view(
-1, self.num_kv_heads, self.qk_rope_head_dim // nz_fmt_last_dim, block_size, nz_fmt_last_dim
)
else:
k_nope = k_nope.view(-1, self.num_kv_heads, block_size, self.kv_lora_rank)
k_pe = k_pe.view(-1, self.num_kv_heads, block_size, self.qk_rope_head_dim)
attn_output_shape: tuple | None = None
if (
attn_metadata.attn_state
in [
AscendAttentionState.SpecDecoding,
AscendAttentionState.ChunkedPrefill,
AscendAttentionState.DecodeOnly,
]
and self.speculative_config is not None
):
# The right part layout indicates the layout of the attention
# output. It is set to NTD to avoid the need for a transpose
# operation after attention.
input_layout = "TND_NTD"
# TODO: If the driver is upgraded later, the contiguous function can be deleted.
# Input shape: [num_tokens, num_heads, dim]
q_nope = q_nope.view(num_tokens, self.num_heads, -1).contiguous()
q_pe = q_pe.view(num_tokens, self.num_heads, -1)
# Output shape: [num_heads, num_tokens, dim]
attn_output_shape = (self.num_heads, num_tokens, self.kv_lora_rank)
sparse_mode = 3
attn_mask = attn_metadata.decode.attn_mask # type:ignore
actual_seq_lengths = decode_meta.actual_seq_lengths_q
else:
# The output layout is set to NBSD to eliminate the need for a
# transpose operation after attention.
if self.enable_kv_nz:
# Input shape: [num_tokens, seq_len, num_heads, dim]
input_layout = "BSND_NBSD"
q_nope = q_nope.view(num_tokens, 1, self.num_heads, -1).contiguous()
q_pe = q_pe.view(num_tokens, 1, self.num_heads, -1)
else:
# Input shape: [num_tokens, num_heads, seq_len, dim]
input_layout = "BNSD_NBSD"
q_nope = q_nope.view(num_tokens, self.num_heads, 1, -1).contiguous()
q_pe = q_pe.view(num_tokens, self.num_heads, 1, -1)
# Output shape: [num_heads, num_tokens, seq_len, dim]
attn_output_shape = (self.num_heads, num_tokens, 1, self.kv_lora_rank)
sparse_mode = 0
attn_mask = None
common_kwargs = {
"query_rope": q_pe,
"key_rope": k_pe,
"num_heads": self.num_heads,
"num_key_value_heads": self.num_kv_heads,
"input_layout": input_layout,
"atten_mask": attn_mask,
"sparse_mode": sparse_mode,
"scale": self.scale,
"antiquant_mode": 0,
"antiquant_scale": None,
"block_table": decode_meta.block_table,
"block_size": block_size,
"actual_seq_lengths": actual_seq_lengths,
"actual_seq_lengths_kv": decode_meta.seq_lens_list,
}
forward_context: ForwardContext = get_forward_context()
if forward_context.is_draft_model:
graph_params = get_draft_graph_params()
else:
graph_params = get_graph_params()
if forward_context.capturing:
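# ACL graph capture path: run the attention op inside a task group and
# record its handle, workspace, weak tensor refs and an external event so
# update_graph_params can patch the arguments at replay time.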
stream = torch_npu.npu.current_stream()
event = torch.npu.ExternalEvent()
event.wait(stream)
event.reset(stream)
graph_params.events[num_tokens].append(event)
workspace = graph_params.workspaces.get(num_tokens)
if workspace is None:
workspace = torch_npu._npu_fused_infer_attention_score_get_max_workspace(
q_nope, k_nope, k_nope, **common_kwargs
)
if forward_context.is_draft_model:
update_draft_graph_params_workspaces(num_tokens, workspace)
else:
update_graph_params_workspaces(num_tokens, workspace)
attn_output = torch.empty(attn_output_shape, dtype=q_nope.dtype, device=q_nope.device)
softmax_lse = torch.empty(num_tokens, dtype=q_nope.dtype, device=q_nope.device)
graph_params.attn_params[num_tokens].append(
(
weak_ref_tensors(q_nope),
weak_ref_tensors(k_nope),
weak_ref_tensors(q_pe),
weak_ref_tensors(k_pe),
self.num_heads,
self.num_kv_heads,
input_layout,
weak_ref_tensors(attn_mask) if attn_mask is not None else None,
sparse_mode,
self.scale,
decode_meta.block_table,
block_size,
decode_meta.seq_lens_list,
actual_seq_lengths,
weak_ref_tensors(attn_output),
weak_ref_tensors(softmax_lse),
)
)
torch.npu.graph_task_group_begin(stream)
torch_npu.npu_fused_infer_attention_score.out(
q_nope, k_nope, k_nope, **common_kwargs, workspace=workspace, out=[attn_output, softmax_lse]
)
handle = torch.npu.graph_task_group_end(stream)
graph_params.handles[num_tokens].append(handle)
else:
attn_output, _ = torch_npu.npu_fused_infer_attention_score(q_nope, k_nope, k_nope, **common_kwargs)
return self._v_up_proj(attn_output)
def reorg_decode_q(self, decode_q_nope, decode_q_pe):
return decode_q_nope, decode_q_pe
def _mla_preprocess_only_decode(self, hidden_states, kv_cache, attn_metadata):
bsz = attn_metadata.num_decode_tokens
hidden_states = hidden_states[:bsz]
cos_shape = attn_metadata.decode.cos.shape
cos = attn_metadata.decode.cos.view(cos_shape[0], cos_shape[-1])
sin = attn_metadata.decode.sin.view(cos_shape[0], cos_shape[-1])
decode_k_nope, decode_k_pe = kv_cache[0], kv_cache[1]
decode_q_nope = torch.empty(
(hidden_states.shape[0], self.W_UK_T.shape[0], decode_k_nope.shape[-1]),
dtype=hidden_states.dtype,
device=hidden_states.device,
)
decode_q_pe = torch.empty(
(hidden_states.shape[0], self.W_UK_T.shape[0], decode_k_pe.shape[-1]),
dtype=hidden_states.dtype,
device=hidden_states.device,
)
torch.ops._C_ascend.mla_preprocess(
hidden_states,
self.wd_qkv,
self.deq_scale_qkv,
self.gamma1,
self.beta1,
self.wu_q,
self.qb_deq_scl,
self.gamma2,
cos,
sin,
self.W_UK_T,
decode_k_nope,
decode_k_pe,
attn_metadata.slot_mapping[:bsz],
quant_scale0=self.quant_scale0,
quant_offset0=self.quant_offset0,
bias0=self.quant_bias_qkv,
quant_scale1=self.quant_scale1,
quant_offset1=self.quant_offset1,
bias1=self.qb_qt_bias,
ctkv_scale=self.ctkv_scale,
q_nope_scale=self.q_nope_scale,
cache_mode="nzcache" if self.enable_kv_nz else "krope_ctkv",
quant_mode="per_tensor_quant_asymm",
q_out0=decode_q_nope,
kv_cache_out0=decode_k_nope,
q_out1=decode_q_pe,
kv_cache_out1=decode_k_pe,
enable_inner_out=False,
inner_out=torch.tensor([], device=hidden_states.device),
)
decode_q_nope = decode_q_nope.view(bsz, self.num_heads, self.kv_lora_rank)
decode_q_pe = decode_q_pe.view(bsz, self.num_heads, -1)
decode_q_nope, decode_q_pe = self.reorg_decode_q(decode_q_nope, decode_q_pe)
decode_preprocess_res = DecodeMLAPreprocessResult(decode_q_nope, decode_q_pe, decode_k_nope, decode_k_pe)
return decode_preprocess_res, None
def mla_preprocess_prefill(self, q_c, kv_no_split, kv_cache, attn_metadata):
num_decode_tokens = attn_metadata.num_decode_tokens
num_actual_tokens = attn_metadata.num_actual_tokens
prefill_kv_no_split = kv_no_split[num_decode_tokens:num_actual_tokens]
prefill_q_c = q_c[num_decode_tokens:num_actual_tokens]
prefill_q = self.q_proj(prefill_q_c)[0].view(-1, self.num_heads, self.qk_head_dim)
prefill_q_pe = prefill_q[..., self.qk_nope_head_dim :]
prefill_q_nope = prefill_q[..., : self.qk_nope_head_dim]
cos = attn_metadata.prefill.cos
sin = attn_metadata.prefill.sin
prefill_slots = attn_metadata.slot_mapping[num_decode_tokens:num_actual_tokens]
prefill_q_pe = self.rope_single(prefill_q_pe, cos, sin)
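# On KV-producer nodes, record an event once the prefill KV has been
# written into the cache (used by the KV-transfer connector to tell when
# this layer's cache is ready to send).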
if self.is_kv_producer:
attn_metadata.reshape_cache_event = torch.npu.Event()
prefill_k_pe, prefill_k_c_normed = self.exec_kv_prefill(prefill_kv_no_split, cos, sin, kv_cache, prefill_slots)
if self.is_kv_producer:
attn_metadata.reshape_cache_event.record()
prefill_k_nope, prefill_value = (
self.kv_b_proj(prefill_k_c_normed)[0]
.view(-1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim)
.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1)
)
prefill_k_pe = prefill_k_pe.view(prefill_q_c.shape[0], self.num_kv_heads, -1)
prefill_k_pe = prefill_k_pe.expand((*prefill_k_nope.shape[:-1], -1))
return PrefillMLAPreprocessResult(prefill_q_nope, prefill_q_pe, prefill_k_nope, prefill_k_pe, prefill_value)
def mla_preprocess_decode(self, q_c, kv_no_split, kv_cache, attn_metadata):
num_decode_tokens = attn_metadata.num_decode_tokens
decode_q_c = q_c[:num_decode_tokens]
cos = attn_metadata.decode.cos
sin = attn_metadata.decode.sin
decode_ql_nope, decode_q_pe = self._q_proj_and_k_up_proj(decode_q_c)
decode_q_pe = self.rope_single(decode_q_pe, cos, sin)
decode_slots = attn_metadata.slot_mapping[:num_decode_tokens:1]
decode_kv_no_split = kv_no_split[:num_decode_tokens]
decode_k_pe, decode_k_nope = self.exec_kv_decode(decode_kv_no_split, cos, sin, kv_cache, decode_slots)
return DecodeMLAPreprocessResult(decode_ql_nope, decode_q_pe, decode_k_nope, decode_k_pe)
def _mla_preprocess(self, layer_name, hidden_states, kv_cache, attn_metadata, need_gather_q_kv):
# MLA Preprocess:
# 1. Perform fused_qkv_a_proj and q_a_layernorm to obtain q_c and kv_no_split
# or
# Perform kv_a_proj_with_mqa to obtain kv_no_split
# 2. If need_gather_q_kv, perform all_gather.
# 3. Preprocess decode tokens, write kv cache and get:
# decode_ql_nope, decode_q_pe, decode_k_pe, decode_k_nope
# 4. Preprocess prefill tokens, write kv cache and get:
# prefill_q_nope, prefill_q_pe, prefill_k_nope, prefill_k_pe, prefill_value
has_decode = attn_metadata.num_decodes > 0
has_prefill = attn_metadata.num_prefills > 0
if self.fused_qkv_a_proj is not None:
weight_prefetch_method = get_weight_prefetch_method()
weight_prefetch_method.maybe_prefetch_mla_or_sla_weight_in_current_stream(
inputs=self.fused_qkv_a_proj.weight, dependency=hidden_states
)
qkv_lora = self.fused_qkv_a_proj(hidden_states)[0]
q_c, kv_no_split = qkv_lora.split(
[self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim],
dim=-1,
)
q_c = self.q_a_layernorm(q_c) # type: ignore[misc]
# all_gather needs contiguous data
kv_no_split = kv_no_split.contiguous()
else:
q_c = hidden_states
kv_no_split = self.kv_a_proj_with_mqa(hidden_states)[0] # type: ignore[misc]
# Process for Flash Comm V1
q_c = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(q_c.contiguous(), need_gather_q_kv)
kv_no_split = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(kv_no_split.contiguous(), need_gather_q_kv)
for layer in self.layer_sharding_kwargs or []:
if is_hidden_layer(layer):
reach_layer_for_shard_weight_series(layer)
decode_preprocess_res = None
prefill_preprocess_res = None
if has_prefill:
wait_for_kv_layer_from_connector(layer_name)
# Preprocess for decode tokens
if has_decode:
decode_preprocess_res = self.mla_preprocess_decode(q_c, kv_no_split, kv_cache, attn_metadata)
# Preprocess for prefill tokens
if has_prefill:
prefill_preprocess_res = self.mla_preprocess_prefill(q_c, kv_no_split, kv_cache, attn_metadata)
return decode_preprocess_res, prefill_preprocess_res

    def get_num_actual_tokens(self, attn_metadata: M):
return attn_metadata.num_actual_tokens
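
    # MLA does not provide separate MHA/MQA entry points; the two stubs below are
    # kept to satisfy the attention backend interface and always raise. Callers
    # should use the unified forward() instead.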
def forward_mha(
self,
layer_name: str,
hidden_states: torch.Tensor,
kv_cache: tuple[torch.Tensor],
attn_metadata: M,
need_gather_q_kv: bool = False,
output: torch.Tensor | None = None,
) -> torch.Tensor:
raise NotImplementedError("forward_mha is not supported for MLA attention. Use forward() instead.")

    def forward_mqa(
self,
layer_name: str,
hidden_states: torch.Tensor,
kv_cache: tuple[torch.Tensor],
attn_metadata: M,
need_gather_q_kv: bool = False,
output: torch.Tensor | None = None,
) -> torch.Tensor:
raise NotImplementedError("forward_mqa is not supported for MLA attention. Use forward() instead.")

    def forward(
self,
layer_name,
hidden_states: torch.Tensor, # query in unified attn
kv_cache: tuple[torch.Tensor],
attn_metadata: M,
need_gather_q_kv: bool = False,
output: torch.Tensor | None = None,
) -> torch.Tensor:
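        """Unified MLA attention forward.

        Preprocesses q/kv for decode and prefill tokens (optionally through the
        MLAPO path), runs decode and prefill attention separately, then applies
        the output projection and writes the result into ``output``.
        """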
assert output is not None, "Output tensor must be provided."
if attn_metadata is None:
# Profiling run.
for layer in self.layer_sharding_kwargs or []:
if is_hidden_layer(layer):
reach_layer_for_shard_weight_series(layer)
return output.fill_(0)
forward_context = get_forward_context()
num_actual_tokens = self.get_num_actual_tokens(attn_metadata)
assert (
attn_metadata.num_decodes is not None
and attn_metadata.num_prefills is not None
and attn_metadata.num_decode_tokens is not None
)
has_prefill = attn_metadata.num_prefills > 0
num_decode_tokens = attn_metadata.num_decode_tokens
# Inputs and outputs may be padded for CUDA graphs
output_padded = output
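        # Staging buffer for this step's attention output: decode tokens fill the
        # leading rows and prefill tokens the rest, before the output projection.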
o_proj_input_shape = (forward_context.num_tokens, self.num_heads * self.v_head_dim)
o_proj_input = torch.empty(o_proj_input_shape, dtype=hidden_states.dtype, device=hidden_states.device)
# MLA Preprocess
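        # Take the MLAPO preprocessing path when it is enabled and the decode token
        # count fits within MLAPO_MAX_SUPPORTED_TOKENS; otherwise fall back to the
        # generic decode/prefill preprocessing.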
if self.enable_mlapo and attn_metadata.num_decode_tokens <= MLAPO_MAX_SUPPORTED_TOKENS:
hidden_states = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
hidden_states.contiguous(), need_gather_q_kv
)
decode_preprocess_res, prefill_preprocess_res = self._mla_preprocess_only_decode(
hidden_states, kv_cache, attn_metadata
)
else:
decode_preprocess_res, prefill_preprocess_res = self._mla_preprocess(
layer_name, hidden_states, kv_cache, attn_metadata, need_gather_q_kv
)
if decode_preprocess_res is not None:
            # MLA attention for decode tokens
output_decode = self._forward_decode(
decode_preprocess_res.ql_nope,
decode_preprocess_res.q_pe,
decode_preprocess_res.k_nope,
decode_preprocess_res.k_pe,
kv_cache[0].shape[1],
attn_metadata,
)
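            # Decode attention output occupies the first num_decode_tokens rows.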
o_proj_input[:num_decode_tokens] = output_decode
if prefill_preprocess_res is not None:
            # FIX: in DBO, the aicore move should also be placed on the comm stream,
            # otherwise it may affect accuracy.
            # TODO: use a more elegant way to overlap.
output_prefill = self._forward_prefill(
prefill_preprocess_res.q_nope,
prefill_preprocess_res.q_pe,
prefill_preprocess_res.k_nope,
prefill_preprocess_res.k_pe,
prefill_preprocess_res.value,
kv_cache,
attn_metadata,
)
o_proj_input[num_decode_tokens:num_actual_tokens] = output_prefill
        # Output projection: prefetch the o_proj weight, then project the combined
        # decode/prefill attention output back to the hidden dimension.
weight_prefetch_method = get_weight_prefetch_method()
weight_prefetch_method.maybe_prefetch_mla_or_sla_weight_in_current_stream(
inputs=self.o_proj.weight,
dependency=o_proj_input,
max_size=MAX_O_PROJ_PREFETCH_SIZE,
linear_layer=self.o_proj,
)
output[...] = self.o_proj(o_proj_input, is_prefill=prefill_preprocess_res is not None)[0]
del o_proj_input
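        # If prefill tokens wrote new entries this step, hand the layer's KV cache
        # to the KV connector.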
if has_prefill:
maybe_save_kv_layer_to_connector(layer_name, list(kv_cache))
return output_padded