from dataclasses import dataclass
from typing import TYPE_CHECKING, NamedTuple, TypeVar

import numpy as np
import torch
import torch_npu
import vllm.envs as envs_vllm
from vllm.config import VllmConfig, get_current_vllm_config
from vllm.forward_context import ForwardContext, get_forward_context
from vllm.logger import logger
from vllm.model_executor.layers.attention.mla_attention import MLACommonMetadataBuilder
from vllm.model_executor.layers.linear import UnquantizedLinearMethod
from vllm.utils.math_utils import cdiv, round_down
from vllm.v1.attention.backend import AttentionBackend, AttentionCGSupport, MLAAttentionImpl  # type: ignore
from vllm.v1.attention.backends.utils import PAD_SLOT_ID  # type: ignore
from vllm.v1.kv_cache_interface import AttentionSpec, MLAAttentionSpec

from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.attention.attention_mask import AttentionMaskBuilder
from vllm_ascend.attention.attention_v1 import AscendAttentionState
from vllm_ascend.attention.context_parallel.common_cp import AscendPCPMetadata, CPChunkedContextMetadata
from vllm_ascend.attention.utils import (
    AscendCommonAttentionMetadata,
    ascend_chunked_prefill_workspace_size,
    enable_cp,
    enabling_mlapo,
    maybe_save_kv_layer_to_connector,
    split_decodes_and_prefills,
    trans_rope_weight,
    transdata,
    wait_for_kv_layer_from_connector,
)
from vllm_ascend.compilation.acl_graph import (
    get_draft_graph_params,
    get_graph_params,
    update_draft_graph_params_workspaces,
    update_graph_params_workspaces,
)
from vllm_ascend.ops.layer_shard_linear import (
    is_hidden_layer,
    post_process_after_loading_for_shard_weight_series,
    reach_layer_for_shard_weight_series,
    register_all_layers_to_shard_weight_series,
)
from vllm_ascend.ops.rotary_embedding import get_cos_and_sin_mla
from vllm_ascend.quantization.methods import AscendW8A8LinearMethod
from vllm_ascend.utils import (
    ACL_FORMAT_FRACTAL_ND,
    get_weight_prefetch_method,
    maybe_trans_nz,
    weak_ref_tensors,
)
from vllm_ascend.worker.npu_input_batch import NPUInputBatch

if TYPE_CHECKING:
    from vllm.v1.core.sched.output import SchedulerOutput


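# Module-level constants: MAX_O_PROJ_PREFETCH_SIZE caps the o_proj weight prefetch at
# 16 MiB, and the BUILD_METADATA_STEP_* values tell get_block_table_size() whether the
# builder is assembling the prefill or the decode portion of the attention metadata.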
MAX_O_PROJ_PREFETCH_SIZE = 16 * 1024 * 1024
BUILD_METADATA_STEP_PREFILL = 0
BUILD_METADATA_STEP_DECODE = 1
# token count limits within the mlapo operator
MLAPO_MAX_SUPPORTED_TOKENS = 1024


class AscendMLABackend(AttentionBackend):
    accept_output_buffer: bool = True

    @staticmethod
    def get_name() -> str:
        # HACK(Ronald1995): vllm's `initialize_kv_cache` method in model runner v2 asserts
        # on the attention backend name, so we report FLASH_ATTN there to avoid the
        # assertion error. Rectify this once vllm drops the assertion.
        return "ASCEND_MLA" if not envs_vllm.VLLM_USE_V2_MODEL_RUNNER else "FLASH_ATTN"

    @staticmethod
    def get_builder_cls():
        if enable_cp():
            from vllm_ascend.attention.context_parallel.mla_cp import AscendMlaCPMetadataBuilder

            return AscendMlaCPMetadataBuilder
        return AscendMLAMetadataBuilder

    @staticmethod
    def get_kv_cache_shape(num_blocks: int, block_size: int, num_kv_heads: int, head_size: int) -> tuple[int, ...]:
        return num_blocks, block_size, num_kv_heads, head_size

    @staticmethod
    def get_impl_cls() -> type["MLAAttentionImpl"]:
        if enable_cp():
            from vllm_ascend.attention.context_parallel.mla_cp import AscendMlaCPImpl

            return AscendMlaCPImpl
        return AscendMLAImpl
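
    # NOTE: the Ascend MLA kernels only support a 128-token KV-cache block size.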
    @staticmethod
    def get_supported_kernel_block_sizes() -> list[int]:
        return [128]


@dataclass
class ChunkedContextMetadata:
    """
    Metadata for chunked context handling in MLA attention.

    Manages sequence boundaries and workspace for chunked prefill processing.
    """

    cu_seq_lens: torch.Tensor
    starts: torch.Tensor
    seq_tot: list[int]
    max_seq_lens: list[int]
    workspace: torch.Tensor
    chunk_seq_lens: torch.Tensor
    chunk_seq_lens_npu: torch.Tensor


@dataclass
class AscendMLAPrefillMetadata:
    """Prefill-specific metadata for Ascend MLA attention."""

    attn_mask: torch.Tensor
    query_lens: torch.Tensor
    seq_lens: list[int]
    context_lens: torch.Tensor
    input_positions: torch.Tensor
    query_start_loc: torch.Tensor
    block_table: torch.Tensor
    max_query_len: int
    max_seq_lens: int
    chunked_context: ChunkedContextMetadata | CPChunkedContextMetadata | None = None
    sin: torch.Tensor = None
    cos: torch.Tensor = None
    pcp_metadata: AscendPCPMetadata | None = None


@dataclass
class AscendMLADecodeMetadata:
    """Decode-specific metadata for Ascend MLA attention."""

    # Input positions for rotary embeddings since for MLA the rotary
    # position embeddings are applied inside the attention backend
    input_positions: torch.Tensor
    block_table: torch.Tensor
    seq_lens: torch.Tensor
    max_seq_lens: int
    seq_lens_list: list[int]
    actual_seq_lengths_q: list[int] | None = None
    attn_mask: torch.Tensor | None = None
    sin: torch.Tensor = None
    cos: torch.Tensor = None
    cp_seq_len: torch.Tensor = None


@dataclass
class AscendMLAMetadata:
    """Metadata for MLACommon.

    NOTE: Please read the comment at the top of the file before trying to
    understand this class
    """

    # NOTE(sang): Definition of context_len, query_len, and seq_len.
    #   |---------- N-1 iteration --------|
    #   |---------------- N iteration ---------------------|
    #   |- tokenA -|......................|-- newTokens ---|
    #   |---------- context_len ----------|
    #   |-------------------- seq_len ---------------------|
    #                                     |-- query_len ---|

    num_actual_tokens_pcp_padded: int
    num_actual_tokens: int  # Number of tokens excluding padding.
    slot_mapping: torch.Tensor
    query_start_loc: torch.Tensor
    seq_lens: torch.Tensor
    block_tables: torch.Tensor

    # New for MLA (compared to FlashAttention)
    # For handling the prefill/decode split
    num_decodes: int
    num_decode_tokens: int
    num_prefills: int

    # For logging.
    num_input_tokens: int = 0  # Number of tokens including padding.

    query_lens: list[int] | None = None
    # The dimension of the attention heads
    head_dim: int | None = None
    attn_mask: torch.Tensor = None
    # chunked prefill by default if no attn_state is passed
    attn_state: AscendAttentionState = AscendAttentionState.ChunkedPrefill

    decode: AscendMLADecodeMetadata | None = None
    prefill: AscendMLAPrefillMetadata | None = None
    reshape_cache_event: torch.npu.Event = None

    def __post_init__(self):
        pass
        # supported_head_sizes = AscendMLABackend.get_supported_head_sizes()
        # if self.head_dim is not None and self.head_dim \
        #         not in supported_head_sizes:
        #     raise ValueError(
        #         f"Only {supported_head_sizes} are supported for head_dim,",
        #         f"received {self.head_dim}.")


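# Generic type variable for metadata classes derived from AscendMLAMetadata.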
M = TypeVar("M", bound=AscendMLAMetadata)


class AscendMLAMetadataBuilder(MLACommonMetadataBuilder[AscendMLAMetadata]):
    """
    NOTE: Please read the comment at the top of the file before trying to
    understand this class
    """

    def __init__(
        self,
        kv_cache_spec: MLAAttentionSpec,
        layer_names: list[str],
        vllm_config: VllmConfig,
        device: torch.device,
        metadata_cls: type[AscendMLAMetadata] | None = None,
        supports_dcp_with_varlen: bool = False,
    ):
        super().__init__(
            kv_cache_spec,
            layer_names,
            vllm_config,
            device,
            metadata_cls if metadata_cls is not None else AscendMLAMetadata,
            supports_dcp_with_varlen,
        )

        scheduler_config = vllm_config.scheduler_config
        self.block_size = vllm_config.cache_config.block_size
        self.max_blocks = (vllm_config.model_config.max_model_len + self.block_size - 1) // self.block_size
        self.chunked_prefill_enabled = scheduler_config.enable_chunked_prefill

        self.speculative_config = vllm_config.speculative_config
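        # A request counts as a "decode" when it has at most decode_threshold scheduled
        # tokens; with speculative decoding this becomes 1 + num_speculative_tokens so a
        # request that is verifying draft tokens is still classified as a decode.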
        self.decode_threshold = 1
        if self.speculative_config:
            spec_token_num = self.speculative_config.num_speculative_tokens
            self.decode_threshold += spec_token_num
            assert self.decode_threshold <= 16, (
                f"decode_threshold exceeded npu_fused_infer_attention_score "
                f"TND layout's limit of 16, got {self.decode_threshold}"
            )

        self.reorder_batch_threshold = self.decode_threshold
        self.rope_dim = self.model_config.hf_text_config.qk_rope_head_dim
        self.cos_cache = None
        self.sin_cache = None

        self.chunk_seq_lens: torch.Tensor = None
        self.cu_seq_lens_cpu: torch.Tensor = None
        self.num_chunks: torch.Tensor = None
        self.max_context_chunk = 0
        self.num_decodes = 0
        self.num_prefills = 0
        self.num_decode_tokens = 0
        self.num_prefill_tokens = 0
        self.context_lens_cpu: torch.Tensor = None
        self.num_actual_tokens: int | None = None
        self.block_table: torch.Tensor = None
        self.slot_mapping: torch.Tensor = None
        self.graph_pad_size = 0
        self.query_lens: torch.Tensor = None
        self.seq_lens: torch.Tensor = None
        self.attn_mask_builder = AttentionMaskBuilder(self.device)

    @staticmethod
    def determine_chunked_prefill_workspace_size(vllm_config: VllmConfig) -> int:
        return ascend_chunked_prefill_workspace_size(vllm_config)

    @classmethod
    def get_cudagraph_support(
        cls: type["AscendMLAMetadataBuilder"],
        vllm_config: VllmConfig,
        kv_cache_spec: AttentionSpec,
    ) -> AttentionCGSupport:
        # Explicit override in case the underlying builder specialized this getter.
        # @override omitted only because of mypy limitation due to type variable.
        return AttentionCGSupport.UNIFORM_BATCH

    def reorder_batch(self, input_batch: "NPUInputBatch", scheduler_output: "SchedulerOutput") -> bool:
        # We now want to reorder the batch so that the "decode" requests are at
        # the front and the "prefill" requests are at the back, using the least
        # amount of swaps possible. (NOTE for now we loosely use "decode" to mean
        # requests where attention is likely memory-bound and "prefill" to mean
        # requests where attention is likely compute-bound, TODO(lucas): figure
        # out a better naming here)
        decodes = []
        prefills = []

        for i, req_id in enumerate(input_batch.req_ids):
            num_tokens = scheduler_output.num_scheduled_tokens[req_id]
            if num_tokens <= self.decode_threshold:
                decodes.append(i)
            else:
                prefills.append(i)

        # We hope that this is fairly minimal since decodes
        # should be around for a number of iterations so hopefully they are
        # relatively stationary (and new requests are generally appended to the
        # persistent batch so they should already be at the back).
        # To achieve this we loop over the decodes in descending order and
        # the prefills in ascending order. We swap decodes from the "back",
        # i.e. past where the last decode should be in the reordered batch, with
        # prefills from the front of the batch.
        # `decodes` and `prefills` are already in ascending order just based on
        # the above loop.
        num_decodes = len(decodes)
        num_prefills = len(prefills)
        first_prefill = 0
        modified_batch = False

        for i in range(1, min(num_decodes, num_prefills) + 1):
            # If the decode is at the "back" of the batch, i.e. past where the last
            # decode should be, we can swap it with the prefill closest to the front
            # of the batch.
            if decodes[num_decodes - i] >= num_decodes:
                input_batch.swap_states(prefills[first_prefill], decodes[num_decodes - i])
                first_prefill += 1
                modified_batch = True
            else:
                break

        # Save for next `build` call
        # TODO(lucas): this is a bit of a hack, we should probably have a
        # better way of doing this
        return modified_batch

    def pad_actual_seq_len_q_mtp_enable_pad(
        self, num_reqs_pad_size, num_reqs, actual_seq_lengths_q, common_attn_metadata
    ):
        """
        Pad actual_seq_lengths_q evenly, with at most 16 tokens per request, to meet
        the requirement of npu_fused_infer_attention_score.

        In the Torchair scenario, the query lengths must be padded to the same length,
        and npu_fused_infer_attention_score requires that the last element equal
        batch_size (num_tokens).

        For example:
        batch_size=36, num_reqs_pad_size=2, num_reqs=16
        By default, each request should infer 2 tokens, which means actual_seq_lengths_q should be
        [2,4,6,8,10,12,14,16,18,20,22,24,26,28,30,32,34,36].

        However, in the MTP Torchair + PD scenario, actual_seq_lengths_q may be
        [1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16] before padding, since the first decode request only has 1 token.
        To meet the requirement of npu_fused_infer_attention_score, we need to pad actual_seq_lengths_q
        evenly so that no request exceeds 16 tokens;
        after padding, actual_seq_lengths_q should be similar to [1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,32,36].
        """
        FIA_SEQ_LEN_LIMIT = 16
        need_padding = (
            num_reqs_pad_size != 0
            and len(common_attn_metadata.actual_seq_lengths_q) > num_reqs
            and common_attn_metadata.actual_seq_lengths_q[num_reqs] - actual_seq_lengths_q[-1] > FIA_SEQ_LEN_LIMIT
        )
        if need_padding:
            padding_seq_len_q = common_attn_metadata.actual_seq_lengths_q[num_reqs : num_reqs + num_reqs_pad_size]
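            # Fill the padded tail with an evenly spaced, non-decreasing ramp from the last
            # real cumulative length up to the padded total, so the cumulative query lengths
            # stay monotonic for npu_fused_infer_attention_score.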
            start_val = actual_seq_lengths_q[-1]
            end_val = padding_seq_len_q[-1]

            num_step = len(padding_seq_len_q)
            interpolated = np.round(np.linspace(start_val, end_val, num_step + 1)[1:]).astype(int).tolist()
            assert interpolated[-1] == end_val
            assert len(interpolated) == len(padding_seq_len_q)
            actual_seq_lengths_q = actual_seq_lengths_q + interpolated
        else:
            actual_seq_lengths_q = (
                actual_seq_lengths_q
                + common_attn_metadata.actual_seq_lengths_q[num_reqs : num_reqs + num_reqs_pad_size]
            )

        return actual_seq_lengths_q

    def pad_actual_seq_len_q_mtp_disable_pad(self, num_reqs_pad_size, num_reqs, actual_seq_lengths_q):
        """
        Only used for ACL full graph mode.
        Pad actual_seq_lengths_q so that its last element equals the total token count
        (the T of the TND layout) and its length equals the batch size of the main model.

        For example:
        batch_size = 8, num_reqs = 4, num_speculative_tokens = 1
        input actual_seq_lengths_q = [1, 2, 4, 5] (the 3rd request accepted a token)
        After padding, actual_seq_lengths_q will be similar to [1, 2, 4, 5, 6, 6, 7, 8]
        """
        need_padding = num_reqs_pad_size > 0
        if need_padding:
            start_val = actual_seq_lengths_q[-1]
            end_val = num_reqs + num_reqs_pad_size
            num_step = num_reqs_pad_size
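            # Ramp the padded tail evenly from the last real cumulative length up to the
            # padded batch size, so the final element equals the total token count.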
            interpolated = np.round(np.linspace(start_val, end_val, num_step + 1)[1:]).astype(int).tolist()
            assert interpolated[-1] == end_val
            assert len(interpolated) == num_reqs_pad_size
            actual_seq_lengths_q = actual_seq_lengths_q + interpolated
        return actual_seq_lengths_q

    def set_num_actual_tokens(
        self,
        common_attn_metadata: AscendCommonAttentionMetadata,
    ):
        self.num_actual_tokens = common_attn_metadata.num_actual_tokens
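
    # build() runs once per scheduler step: it splits the (already reordered) batch into a
    # decode front and a prefill tail, slices slot_mapping / query_lens / seq_lens and the
    # block table accordingly, and then assembles the per-phase prefill and decode metadata.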
    def build(
        self,
        common_prefix_len: int,
        common_attn_metadata: AscendCommonAttentionMetadata,
        fast_build: bool = False,
    ) -> AscendMLAMetadata:
        num_reqs = common_attn_metadata.num_reqs
        query_start_loc = common_attn_metadata.query_start_loc
        query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu

        self.num_decodes, self.num_prefills, self.num_decode_tokens, self.num_prefill_tokens = (
            split_decodes_and_prefills(common_attn_metadata, decode_threshold=self.decode_threshold)
        )
        self.set_num_actual_tokens(common_attn_metadata)
        assert self.num_decodes + self.num_prefills == num_reqs
        assert self.num_decode_tokens + self.num_prefill_tokens == common_attn_metadata.num_actual_tokens

        # NOTE: Currently, MTP full-graph mode is incompatible with pcp.
        self.slot_mapping = common_attn_metadata.slot_mapping[: self.num_actual_tokens]

        query_seq_lens_cpu = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]
        self.query_lens = query_seq_lens_cpu[:num_reqs]
        self.seq_lens = common_attn_metadata.seq_lens_cpu[:num_reqs]

        self.graph_pad_size = common_attn_metadata.graph_pad_size
        block_table_size = self.get_block_table_size(common_attn_metadata, BUILD_METADATA_STEP_PREFILL)
        self.block_table = common_attn_metadata.block_table_tensor[:block_table_size]

        prefill_metadata = None
        if self.num_prefills > 0:
            prefill_metadata = self.build_prefill_metadata(common_prefix_len, common_attn_metadata)

        decode_metadata = None
        if self.num_decodes > 0:
            decode_metadata = self.build_decode_metadata(common_prefix_len, common_attn_metadata)
        return self.metadata_cls(  # type: ignore
            num_actual_tokens_pcp_padded=self.num_actual_tokens,
            num_input_tokens=common_attn_metadata.num_input_tokens,
            num_actual_tokens=self.num_actual_tokens,
            query_lens=self.query_lens.tolist(),
            slot_mapping=self.slot_mapping,
            head_dim=self.model_config.get_head_size(),
            num_decodes=self.num_decodes,
            num_decode_tokens=self.num_decode_tokens,
            num_prefills=self.num_prefills,
            attn_mask=self.attn_mask_builder.get_final_mla_mask(self.model_config),
            attn_state=common_attn_metadata.attn_state,
            prefill=prefill_metadata,
            decode=decode_metadata,
            query_start_loc=query_start_loc,
            block_tables=self.block_table,
            seq_lens=self.seq_lens,
        )
def build_chunked_metadata(
|
|
|
|
|
self,
|
|
|
|
|
common_prefix_len: int,
|
|
|
|
|
common_attn_metadata: AscendCommonAttentionMetadata,
|
|
|
|
|
):
|
|
|
|
|
if not self.chunked_prefill_enabled:
|
|
|
|
|
return None
|
|
|
|
|
num_reqs = common_attn_metadata.num_reqs
|
|
|
|
|
|
2026-01-24 22:10:18 +08:00
|
|
|
num_computed_tokens_cpu = self.seq_lens - self.query_lens
|
2025-12-24 10:25:19 +08:00
|
|
|
reqs_start = self.num_decodes # prefill_start
|
|
|
|
|
|
|
|
|
|
self.context_lens_cpu = num_computed_tokens_cpu[reqs_start:num_reqs]
|
|
|
|
|
max_context_len_cpu = self.context_lens_cpu.max().item()
|
|
|
|
|
if not max_context_len_cpu > 0:
|
|
|
|
|
return None
|
2026-01-24 22:10:18 +08:00
|
|
|
num_prefills_with_context_cpu = (self.context_lens_cpu > 0).sum().item()
|
|
|
|
|
self.max_context_chunk = self.chunked_prefill_workspace_size // num_prefills_with_context_cpu
|
|
|
|
|
self.max_context_chunk = round_down(self.max_context_chunk, self.block_size)
|
2025-12-24 10:25:19 +08:00
|
|
|
|
|
|
|
|
assert self.max_context_chunk > 0
|
|
|
|
|
self.num_chunks = cdiv(max_context_len_cpu, self.max_context_chunk)
|
2026-01-24 22:10:18 +08:00
|
|
|
chunk_starts = (
|
|
|
|
|
torch.arange(self.num_chunks, dtype=torch.int32).unsqueeze(1).expand(-1, self.num_prefills)
|
|
|
|
|
* self.max_context_chunk
|
|
|
|
|
)
|
|
|
|
|
chunk_ends = torch.min(self.context_lens_cpu.unsqueeze(0), chunk_starts + self.max_context_chunk)
|
2025-12-24 10:25:19 +08:00
|
|
|
self.chunk_seq_lens = (chunk_ends - chunk_starts).clamp(min=0)
|
2026-01-24 22:10:18 +08:00
|
|
|
self.cu_seq_lens_cpu = torch.zeros(self.num_chunks, self.num_prefills + 1, dtype=torch.int32, pin_memory=True)
|
|
|
|
|
torch.cumsum(self.chunk_seq_lens, dim=1, out=self.cu_seq_lens_cpu[:, 1:], dtype=torch.int32)
|
2025-12-24 10:25:19 +08:00
|
|
|
return ChunkedContextMetadata(
|
2026-01-24 22:10:18 +08:00
|
|
|
cu_seq_lens=self.cu_seq_lens_cpu.pin_memory().to(self.device, non_blocking=True),
|
|
|
|
|
starts=chunk_starts.pin_memory().to(self.device, non_blocking=True),
|
2025-12-24 10:25:19 +08:00
|
|
|
seq_tot=self.chunk_seq_lens.sum(dim=1).tolist(),
|
|
|
|
|
max_seq_lens=self.chunk_seq_lens.max(dim=1).values.tolist(),
|
|
|
|
|
chunk_seq_lens=self.chunk_seq_lens,
|
|
|
|
|
chunk_seq_lens_npu=self.chunk_seq_lens.npu(),
|
|
|
|
|
workspace=self.chunked_prefill_workspace,
|
|
|
|
|
)
|
|
|
|
|
|
2026-01-24 22:10:18 +08:00
|
|
|
def get_block_table_size(self, common_attn_metadata: AscendCommonAttentionMetadata, build_metadata_step: int):
|
2026-01-05 17:41:12 +08:00
|
|
|
if build_metadata_step == BUILD_METADATA_STEP_PREFILL:
|
|
|
|
|
# If graph_pad_size > -1, it means we are running in fullgraph mode.
|
|
|
|
|
# NOTE: Maybe this block_table change can be removed when graph_pad_size > 1.
|
2026-01-24 22:10:18 +08:00
|
|
|
if (
|
|
|
|
|
self.graph_pad_size > common_attn_metadata.num_reqs
|
|
|
|
|
and self.speculative_config.disable_padded_drafter_batch
|
|
|
|
|
):
|
2026-01-05 17:41:12 +08:00
|
|
|
return self.graph_pad_size
|
|
|
|
|
return common_attn_metadata.num_reqs
|
|
|
|
|
return self.num_decodes
|
2025-12-24 10:25:19 +08:00
|
|
|
|
|
|
|
|
def build_prefill_metadata(
|
|
|
|
|
self,
|
|
|
|
|
common_prefix_len: int,
|
|
|
|
|
common_attn_metadata: AscendCommonAttentionMetadata,
|
|
|
|
|
) -> AscendMLAPrefillMetadata:
|
|
|
|
|
query_start_loc = common_attn_metadata.query_start_loc
|
|
|
|
|
|
|
|
|
|
# NOTE: Currently, MTP-fullgraph is incompatible with pcp
|
2026-01-24 22:10:18 +08:00
|
|
|
input_positions = common_attn_metadata.positions[: self.num_actual_tokens].long()
|
2025-12-24 10:25:19 +08:00
|
|
|
|
2026-01-24 22:10:18 +08:00
|
|
|
chunked_context_metadata = self.build_chunked_metadata(common_prefix_len, common_attn_metadata)
|
2025-12-24 10:25:19 +08:00
|
|
|
reqs_start = self.num_decodes # prefill_start
|
|
|
|
|
tokens_start = self.num_decode_tokens
|
|
|
|
|
max_query_len = self.query_lens[reqs_start:].max().item()
|
|
|
|
|
max_seq_lens = self.seq_lens[reqs_start:].max().item()
|
2026-01-24 22:10:18 +08:00
|
|
|
prefill_query_start_loc = query_start_loc[reqs_start:] - query_start_loc[reqs_start]
|
2025-12-24 10:25:19 +08:00
|
|
|
|
|
|
|
|
prefill_input_positions = input_positions[tokens_start:]
|
2025-12-28 10:35:07 +08:00
|
|
|
cos, sin = get_cos_and_sin_mla(prefill_input_positions)
|
2025-12-24 10:25:19 +08:00
|
|
|
return AscendMLAPrefillMetadata(
|
2026-01-24 22:10:18 +08:00
|
|
|
attn_mask=self.attn_mask_builder.get_final_mla_mask(self.model_config),
|
2025-12-24 10:25:19 +08:00
|
|
|
query_lens=self.query_lens[reqs_start:].to(torch.int32),
|
|
|
|
|
seq_lens=self.seq_lens,
|
|
|
|
|
context_lens=self.seq_lens[reqs_start:],
|
|
|
|
|
input_positions=prefill_input_positions,
|
|
|
|
|
block_table=self.block_table[reqs_start:, ...],
|
|
|
|
|
max_query_len=max_query_len,
|
|
|
|
|
max_seq_lens=max_seq_lens,
|
|
|
|
|
query_start_loc=prefill_query_start_loc,
|
|
|
|
|
chunked_context=chunked_context_metadata,
|
|
|
|
|
sin=sin,
|
|
|
|
|
cos=cos,
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
def build_decode_metadata(
|
|
|
|
|
self,
|
|
|
|
|
common_prefix_len: int,
|
|
|
|
|
common_attn_metadata: AscendCommonAttentionMetadata,
|
|
|
|
|
) -> AscendMLADecodeMetadata:
|
|
|
|
|
num_reqs = common_attn_metadata.num_reqs
|
|
|
|
|
query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu
|
|
|
|
|
|
2026-01-24 22:10:18 +08:00
|
|
|
input_positions = common_attn_metadata.positions[: self.num_actual_tokens].long()
|
2025-12-24 10:25:19 +08:00
|
|
|
|
|
|
|
|
# Notice that num_decodes != num_decode_tokens in the SpecDecoding scenario
|
2026-01-24 22:10:18 +08:00
|
|
|
actual_seq_lengths_q = query_start_loc_cpu[1 : self.num_decodes + 1].tolist()
|
|
|
|
|
max_seq_lens = self.seq_lens[: self.num_decodes].max().item()
|
|
|
|
|
self.seq_lens = self.seq_lens[: self.num_decodes]
|
|
|
|
|
input_positions = input_positions[: self.num_decode_tokens]
|
|
|
|
|
|
|
|
|
|
block_table_size = self.get_block_table_size(common_attn_metadata, BUILD_METADATA_STEP_DECODE)
|
2026-01-05 17:41:12 +08:00
|
|
|
self.block_table = self.block_table[:block_table_size]
|
2025-12-24 10:25:19 +08:00
|
|
|
|
|
|
|
|
# NOTE: Currently, MTP-fullgraph is incompatible with pcp
|
|
|
|
|
# NOTE: Maybe this block_table change can be removed when graph_pad_size > 1.
|
2026-01-24 22:10:18 +08:00
|
|
|
if self.graph_pad_size > self.num_decodes and self.speculative_config.disable_padded_drafter_batch:
|
|
|
|
|
self.block_table = self.block_table[: self.graph_pad_size, ...]
|
2025-12-24 10:25:19 +08:00
|
|
|
seq_lens_list = self.seq_lens.tolist()
|
|
|
|
|
|
2026-01-23 14:13:12 +08:00
|
|
|
cp_seq_len = None
|
2025-12-24 10:25:19 +08:00
|
|
|
|
|
|
|
|
if self.graph_pad_size > num_reqs:
|
|
|
|
|
if self.speculative_config.disable_padded_drafter_batch:
|
|
|
|
|
num_reqs_pad_size = self.graph_pad_size - num_reqs
|
|
|
|
|
actual_seq_lengths_q = self.pad_actual_seq_len_q_mtp_disable_pad(
|
2026-01-24 22:10:18 +08:00
|
|
|
num_reqs_pad_size, num_reqs, actual_seq_lengths_q
|
|
|
|
|
)
|
|
|
|
|
seq_lens_list = seq_lens_list + [0] * (self.graph_pad_size - self.num_decodes)
|
|
|
|
|
num_block_pad_size = self.graph_pad_size - self.block_table.shape[0]
|
2025-12-24 10:25:19 +08:00
|
|
|
if num_block_pad_size > 0:
|
|
|
|
|
block_table_padding = torch.zeros(
|
2026-01-24 22:10:18 +08:00
|
|
|
(num_block_pad_size,) + self.block_table.shape[1:],
|
2025-12-24 10:25:19 +08:00
|
|
|
dtype=self.block_table.dtype,
|
2026-01-24 22:10:18 +08:00
|
|
|
device=self.block_table.device,
|
|
|
|
|
)
|
|
|
|
|
self.block_table = torch.cat([self.block_table, block_table_padding], dim=0)
|
2025-12-24 10:25:19 +08:00
|
|
|
else:
|
|
|
|
|
num_token_pad_size = self.graph_pad_size - self.num_decode_tokens
|
2026-01-24 22:10:18 +08:00
|
|
|
num_reqs_pad_size = self.graph_pad_size // common_attn_metadata.decode_token_per_req - num_reqs
|
2025-12-24 10:25:19 +08:00
|
|
|
num_block_table_pad_size = (
|
2026-01-24 22:10:18 +08:00
|
|
|
self.graph_pad_size // common_attn_metadata.decode_token_per_req - self.num_decodes
|
|
|
|
|
)
|
|
|
|
|
seq_lens_list = self.seq_lens.tolist() + [0] * num_reqs_pad_size
|
|
|
|
|
slot_padding = torch.full(
|
|
|
|
|
(num_token_pad_size,), PAD_SLOT_ID, dtype=self.slot_mapping.dtype, device=self.slot_mapping.device
|
|
|
|
|
)
|
|
|
|
|
self.slot_mapping = torch.cat([self.slot_mapping, slot_padding])
|
2025-12-24 10:25:19 +08:00
|
|
|
block_table_padding = torch.zeros(
|
2026-01-24 22:10:18 +08:00
|
|
|
(num_block_table_pad_size,) + self.block_table.shape[1:],
|
2025-12-24 10:25:19 +08:00
|
|
|
dtype=self.block_table.dtype,
|
2026-01-24 22:10:18 +08:00
|
|
|
device=self.block_table.device,
|
|
|
|
|
)
|
|
|
|
|
self.block_table = torch.cat([self.block_table, block_table_padding], dim=0)
|
|
|
|
|
position_padding = torch.zeros(
|
|
|
|
|
num_token_pad_size, dtype=input_positions.dtype, device=input_positions.device
|
|
|
|
|
)
|
|
|
|
|
input_positions = torch.cat([input_positions, position_padding])
|
2025-12-24 10:25:19 +08:00
|
|
|
actual_seq_lengths_q = self.pad_actual_seq_len_q_mtp_enable_pad(
|
2026-01-24 22:10:18 +08:00
|
|
|
num_reqs_pad_size, num_reqs, actual_seq_lengths_q, common_attn_metadata
|
|
|
|
|
)
|
2025-12-24 10:25:19 +08:00
|
|
|
|
2025-12-28 10:35:07 +08:00
|
|
|
cos, sin = get_cos_and_sin_mla(input_positions, use_cache=True)
|
|
|
|
|
decode_metadata = AscendMLADecodeMetadata(
|
|
|
|
|
input_positions=input_positions,
|
|
|
|
|
block_table=self.block_table,
|
|
|
|
|
seq_lens=self.seq_lens,
|
|
|
|
|
seq_lens_list=seq_lens_list,
|
|
|
|
|
max_seq_lens=max_seq_lens,
|
2026-01-07 17:09:52 +08:00
|
|
|
attn_mask=self.attn_mask_builder.get_splitfuse_attn_mask(),
|
2025-12-28 10:35:07 +08:00
|
|
|
actual_seq_lengths_q=actual_seq_lengths_q,
|
2026-01-24 22:10:18 +08:00
|
|
|
sin=sin[: self.num_decode_tokens, ...],
|
|
|
|
|
cos=cos[: self.num_decode_tokens, ...],
|
|
|
|
|
cp_seq_len=cp_seq_len,
|
|
|
|
|
)
|
2025-12-24 10:25:19 +08:00
|
|
|
return decode_metadata
|
|
|
|
|
|
2025-10-10 16:31:20 +08:00
|
|
|
def build_for_graph_capture(
|
|
|
|
|
self,
|
|
|
|
|
common_attn_metadata: AscendCommonAttentionMetadata,
|
|
|
|
|
attn_state: AscendAttentionState = AscendAttentionState.DecodeOnly,
|
|
|
|
|
):
|
2026-01-24 22:10:18 +08:00
|
|
|
if attn_state in {AscendAttentionState.DecodeOnly, AscendAttentionState.SpecDecoding}:
|
2025-10-10 16:31:20 +08:00
|
|
|
attn_metadata = self.build(
|
|
|
|
|
common_prefix_len=0,
|
|
|
|
|
common_attn_metadata=common_attn_metadata,
|
|
|
|
|
)
|
|
|
|
|
else:
|
|
|
|
|
raise NotImplementedError(
|
2025-10-17 20:19:56 +08:00
|
|
|
"Currently we only support building dummy metadata for DecodeOnly and SpecDecoding state"
|
2025-10-10 16:31:20 +08:00
|
|
|
)
|
|
|
|
|
|
|
|
|
|
attn_metadata.attn_state = attn_state
|
|
|
|
|
return attn_metadata
|
|
|
|
|
|
2025-04-19 17:38:18 +08:00
|
|
|
|
2025-08-28 10:35:57 +08:00
|
|
|
class DecodeMLAPreprocessResult(NamedTuple):
|
2026-01-24 22:10:18 +08:00
|
|
|
ql_nope: torch.Tensor | None = None
|
|
|
|
|
q_pe: torch.Tensor | None = None
|
|
|
|
|
k_nope: torch.Tensor | None = None
|
|
|
|
|
k_pe: torch.Tensor | None = None
|
|
|
|
|
decode_q_wo_k_up: torch.Tensor | None = None
|
2025-08-28 10:35:57 +08:00
|
|
|
|
|
|
|
|
|
|
|
|
|
class PrefillMLAPreprocessResult(NamedTuple):
|
2026-01-24 22:10:18 +08:00
|
|
|
q_nope: torch.Tensor | None = None
|
|
|
|
|
q_pe: torch.Tensor | None = None
|
|
|
|
|
k_nope: torch.Tensor | None = None
|
|
|
|
|
k_pe: torch.Tensor | None = None
|
|
|
|
|
value: torch.Tensor | None = None
|
2025-08-28 10:35:57 +08:00
|
|
|
|
|
|
|
|
|
2025-04-19 17:38:18 +08:00
|
|
|
class AscendMLAImpl(MLAAttentionImpl):
|
|
|
|
|
"""
|
|
|
|
|
NOTE: Please read the comment at the top of the file before trying to
|
|
|
|
|
understand this class
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
def __init__(
|
|
|
|
|
self,
|
|
|
|
|
num_heads: int,
|
|
|
|
|
head_size: int,
|
|
|
|
|
scale: float,
|
|
|
|
|
num_kv_heads: int,
|
2026-01-24 22:10:18 +08:00
|
|
|
alibi_slopes: list[float] | None,
|
|
|
|
|
sliding_window: int | None,
|
2025-04-19 17:38:18 +08:00
|
|
|
kv_cache_dtype: str,
|
2026-01-24 22:10:18 +08:00
|
|
|
logits_soft_cap: float | None,
|
2025-04-19 17:38:18 +08:00
|
|
|
attn_type: str,
|
2026-01-24 22:10:18 +08:00
|
|
|
kv_sharing_target_layer_name: str | None,
|
2025-04-19 17:38:18 +08:00
|
|
|
**kwargs,
|
2025-12-15 12:59:18 +08:00
|
|
|
):
|
2025-12-31 15:09:01 +08:00
|
|
|
self.vllm_config = get_current_vllm_config()
|
2025-04-19 17:38:18 +08:00
|
|
|
self.num_heads = num_heads
|
|
|
|
|
self.head_size = head_size
|
|
|
|
|
self.scale = float(scale)
|
|
|
|
|
self.num_kv_heads = num_kv_heads
|
|
|
|
|
self.kv_cache_dtype = kv_cache_dtype
|
|
|
|
|
|
2025-06-04 20:26:44 +08:00
|
|
|
# MLA Args
|
2026-01-24 22:10:18 +08:00
|
|
|
self.q_lora_rank = kwargs["q_lora_rank"]
|
|
|
|
|
self.kv_lora_rank = kwargs["kv_lora_rank"]
|
|
|
|
|
self.qk_nope_head_dim = kwargs["qk_nope_head_dim"]
|
|
|
|
|
self.qk_rope_head_dim = kwargs["qk_rope_head_dim"]
|
|
|
|
|
self.qk_head_dim = kwargs["qk_head_dim"]
|
|
|
|
|
self.v_head_dim = kwargs["v_head_dim"]
|
|
|
|
|
self.rotary_emb = kwargs["rotary_emb"]
|
|
|
|
|
self.fused_qkv_a_proj = kwargs.get("fused_qkv_a_proj")
|
|
|
|
|
self.q_proj = kwargs["q_proj"] if self.q_lora_rank is None else kwargs["q_b_proj"]
|
|
|
|
|
self.kv_b_proj = kwargs["kv_b_proj"]
|
|
|
|
|
self.o_proj = kwargs["o_proj"]
|
2025-12-11 12:43:04 +08:00
|
|
|
self.vllm_config = get_current_vllm_config()
|
2026-01-24 22:10:18 +08:00
|
|
|
self.kv_a_proj_with_mqa = kwargs.get("kv_a_proj_with_mqa")
|
|
|
|
|
self.kv_a_layernorm = kwargs.get("kv_a_layernorm")
|
|
|
|
|
self.q_a_layernorm = kwargs.get("q_a_layernorm")
|
2025-06-10 22:26:53 +08:00
|
|
|
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
|
2025-06-04 20:26:44 +08:00
|
|
|
|
2025-06-05 16:28:01 +08:00
|
|
|
ascend_config = get_ascend_config()
|
2025-08-12 14:12:12 +08:00
|
|
|
self.enable_shared_expert_dp = ascend_config.enable_shared_expert_dp
|
2025-12-31 14:24:04 +08:00
|
|
|
self.enable_kv_nz = ascend_config.enable_kv_nz
|
2025-08-28 10:35:57 +08:00
|
|
|
|
2025-09-04 10:22:46 +08:00
|
|
|
self.ring_mla_mask_size = 512
|
2025-06-12 21:42:09 +08:00
|
|
|
|
2025-12-11 12:43:04 +08:00
|
|
|
self.speculative_config = self.vllm_config.speculative_config
|
2026-01-31 22:44:56 +08:00
|
|
|
self.enable_mlapo = enabling_mlapo(self.vllm_config)
|
2025-05-12 19:14:07 +08:00
|
|
|
|
2026-01-24 22:10:18 +08:00
|
|
|
self.is_kv_producer = (
|
|
|
|
|
self.vllm_config.kv_transfer_config is not None and self.vllm_config.kv_transfer_config.is_kv_producer
|
|
|
|
|
)
|
2026-01-08 09:05:02 +08:00
|
|
|
self.layer_sharding_kwargs = []
|
2026-01-24 22:10:18 +08:00
|
|
|
for layer_name in get_ascend_config().layer_sharding or []:
|
2026-01-08 09:05:02 +08:00
|
|
|
if layer_name in kwargs:
|
|
|
|
|
self.layer_sharding_kwargs.append(kwargs[layer_name])
|
|
|
|
|
else:
|
|
|
|
|
logger.warning_once(
|
|
|
|
|
f"Layer '{layer_name}' not found in kwargs for layer sharding, skipping sharding configuration"
|
|
|
|
|
)
|
|
|
|
|
register_all_layers_to_shard_weight_series(self.layer_sharding_kwargs)
|
2025-12-31 15:09:01 +08:00
|
|
|
|
2026-01-26 09:04:54 +08:00
|
|
|
@staticmethod
|
|
|
|
|
def update_graph_params(
|
|
|
|
|
update_stream,
|
|
|
|
|
forward_context,
|
|
|
|
|
num_tokens,
|
|
|
|
|
vllm_config=None,
|
|
|
|
|
speculative_config=None,
|
|
|
|
|
num_dcp_pcp_tokens=None,
|
2026-01-28 14:41:18 +08:00
|
|
|
draft_attn_metadatas=None,
|
2026-01-26 09:04:54 +08:00
|
|
|
):
|
|
|
|
|
if forward_context.is_draft_model:
|
|
|
|
|
graph_params = get_draft_graph_params()
|
|
|
|
|
else:
|
|
|
|
|
graph_params = get_graph_params()
|
|
|
|
|
# FIXME: Behold! We are using a temporary hack here to update the args
|
|
|
|
|
# for each layer's attention op in the graph.
|
|
|
|
|
with torch.npu.stream(update_stream):
|
|
|
|
|
for key, param, handle, event in zip(
|
|
|
|
|
forward_context.attn_metadata,
|
|
|
|
|
graph_params.attn_params[num_tokens],
|
|
|
|
|
graph_params.handles[num_tokens],
|
|
|
|
|
graph_params.events[num_tokens],
|
|
|
|
|
):
|
|
|
|
|
(
|
|
|
|
|
q_nope,
|
|
|
|
|
k_nope,
|
|
|
|
|
q_pe,
|
|
|
|
|
k_pe,
|
|
|
|
|
num_heads,
|
|
|
|
|
num_kv_heads,
|
|
|
|
|
input_layout,
|
|
|
|
|
attn_mask,
|
|
|
|
|
sparse_mode,
|
|
|
|
|
scale,
|
|
|
|
|
block_table,
|
|
|
|
|
block_size,
|
|
|
|
|
seq_lens_list,
|
|
|
|
|
actual_seq_lengths,
|
|
|
|
|
attn_output,
|
|
|
|
|
softmax_lse,
|
|
|
|
|
) = param
|
|
|
|
|
seq_lens_list = forward_context.attn_metadata[key].decode.seq_lens_list
|
|
|
|
|
if speculative_config and speculative_config.method == "mtp" and not forward_context.is_draft_model:
|
|
|
|
|
actual_seq_lengths = forward_context.attn_metadata[key].decode.actual_seq_lengths_q
|
|
|
|
|
spec_multiple = speculative_config.num_speculative_tokens + 1
|
|
|
|
|
seq_lens_list = seq_lens_list + [0] * (num_tokens // spec_multiple - len(seq_lens_list))
|
|
|
|
|
actual_seq_lengths = [spec_multiple * (i + 1) for i in range(num_tokens // spec_multiple)]
|
|
|
|
|
elif forward_context.is_draft_model:
|
|
|
|
|
actual_seq_lengths = forward_context.attn_metadata[key].decode.actual_seq_lengths_q
|
|
|
|
|
block_table = forward_context.attn_metadata[key].decode.block_table
|
|
|
|
|
# TODO: This is a hack and should be fixed in the future.
|
|
|
|
|
if speculative_config.disable_padded_drafter_batch:
|
|
|
|
|
block_table = block_table[: len(actual_seq_lengths)]
|
|
|
|
|
seq_lens_list = seq_lens_list + [0] * (len(actual_seq_lengths) - len(seq_lens_list))
|
|
|
|
|
else:
|
|
|
|
|
seq_lens_list = seq_lens_list + [0] * (num_tokens - len(seq_lens_list))
|
|
|
|
|
torch.npu.graph_task_update_begin(update_stream, handle)
|
|
|
|
|
|
|
|
|
|
torch_npu.npu_fused_infer_attention_score.out(
|
|
|
|
|
q_nope,
|
|
|
|
|
k_nope,
|
|
|
|
|
k_nope,
|
|
|
|
|
query_rope=q_pe,
|
|
|
|
|
key_rope=k_pe,
|
|
|
|
|
num_heads=num_heads,
|
|
|
|
|
num_key_value_heads=num_kv_heads,
|
|
|
|
|
input_layout=input_layout,
|
|
|
|
|
atten_mask=attn_mask,
|
|
|
|
|
sparse_mode=sparse_mode,
|
|
|
|
|
scale=scale,
|
|
|
|
|
antiquant_mode=0,
|
|
|
|
|
antiquant_scale=None,
|
|
|
|
|
block_table=block_table,
|
|
|
|
|
block_size=block_size,
|
|
|
|
|
actual_seq_lengths_kv=seq_lens_list,
|
|
|
|
|
actual_seq_lengths=actual_seq_lengths,
|
|
|
|
|
workspace=graph_params.workspaces.get(num_tokens),
|
|
|
|
|
out=[attn_output, softmax_lse],
|
|
|
|
|
)
|
|
|
|
|
torch.npu.graph_task_update_end(update_stream)
|
|
|
|
|
|
|
|
|
|
event.record(update_stream)
|
|
|
|
|
|
2025-08-28 10:35:57 +08:00
|
|
|
def _v_up_proj(self, x):
|
2025-12-26 22:03:46 +08:00
|
|
|
# Convert from (N, B, L)/(N, B, 1, L) to (N, B, L)
|
|
|
|
|
x = x.view(self.num_heads, -1, self.kv_lora_rank)
|
|
|
|
|
# Multiply (N, B, L) x (N, L, V) -> (B, N, V)
|
|
|
|
|
x = torch_npu.npu_transpose_batchmatmul(x, self.W_UV, perm_y=(1, 0, 2))
|
|
|
|
|
# Convert from (B, N, V) to (B, N * V)
|
|
|
|
|
x = x.reshape(-1, self.num_heads * self.v_head_dim)
|
2025-08-21 14:02:30 +08:00
|
|
|
return x
|
2025-04-19 17:38:18 +08:00
|
|
|
|
|
|
|
|
# Return `ql_nope`, `q_pe`
|
|
|
|
|
def _q_proj_and_k_up_proj(self, x):
|
2026-01-24 22:10:18 +08:00
|
|
|
q_nope, q_pe = (
|
|
|
|
|
self.q_proj(x)[0]
|
|
|
|
|
.view(-1, self.num_heads, self.qk_head_dim)
|
2025-04-19 17:38:18 +08:00
|
|
|
.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
|
2026-01-24 22:10:18 +08:00
|
|
|
)
|
2025-04-19 17:38:18 +08:00
|
|
|
|
|
|
|
|
# Convert from (B, N, P) to (N, B, P)
|
|
|
|
|
q_nope = q_nope.transpose(0, 1)
|
|
|
|
|
# Multiply (N, B, P) x (N, P, L) -> (N, B, L)
|
|
|
|
|
ql_nope = torch.bmm(q_nope, self.W_UK_T)
|
|
|
|
|
# Convert from (N, B, L) to (B, N, L)
|
|
|
|
|
return ql_nope.transpose(0, 1), q_pe
|
|
|
|
|
|
|
|
|
|
def process_weights_after_loading(self, act_dtype: torch.dtype):
|
2025-12-19 14:27:24 +08:00
|
|
|
# NOTE: We currently do not support a quantized kv_b_proj.
|
|
|
|
|
assert isinstance(self.kv_b_proj.quant_method, UnquantizedLinearMethod)
|
|
|
|
|
# NOTE: The weight will be reshaped next; we need to revert and transpose it.
|
2026-01-24 22:10:18 +08:00
|
|
|
kv_b_proj_weight = torch_npu.npu_format_cast(self.kv_b_proj.weight.data, ACL_FORMAT_FRACTAL_ND).T
|
2025-04-19 17:38:18 +08:00
|
|
|
assert kv_b_proj_weight.shape == (
|
|
|
|
|
self.kv_lora_rank,
|
2026-01-24 22:10:18 +08:00
|
|
|
self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
|
|
|
|
|
), (
|
|
|
|
|
f"{kv_b_proj_weight.shape=}, "
|
|
|
|
|
f"{self.kv_lora_rank=}, "
|
|
|
|
|
f"{self.num_heads=}, "
|
|
|
|
|
f"{self.qk_nope_head_dim=}, "
|
|
|
|
|
f"{self.v_head_dim=}"
|
|
|
|
|
)
|
2025-04-19 17:38:18 +08:00
|
|
|
kv_b_proj_weight = kv_b_proj_weight.view(
|
|
|
|
|
self.kv_lora_rank,
|
|
|
|
|
self.num_heads,
|
|
|
|
|
self.qk_nope_head_dim + self.v_head_dim,
|
|
|
|
|
)
|
|
|
|
|
|
2026-01-24 22:10:18 +08:00
|
|
|
W_UK, W_UV = kv_b_proj_weight.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1)
|
2025-04-19 17:38:18 +08:00
|
|
|
|
|
|
|
|
# Convert from (L, N, V) to (N, L, V)
|
2025-05-23 10:18:10 +08:00
|
|
|
self.W_UV = W_UV.transpose(0, 1).contiguous()
|
2025-04-19 17:38:18 +08:00
|
|
|
# Convert from (L, N, P) to (N, P, L)
|
2025-05-23 10:18:10 +08:00
|
|
|
self.W_UK_T = W_UK.permute(1, 2, 0).contiguous()
|
2025-06-15 19:57:02 +08:00
|
|
|
|
2025-12-19 14:27:24 +08:00
|
|
|
# TODO(zzzzwwjj): Currently, torch.ops._C_ascend.batch_matmul_transpose cannot support weight nz
|
|
|
|
|
# self.W_UV = maybe_trans_nz(self.W_UV)
|
2025-04-19 17:38:18 +08:00
|
|
|
|
2025-10-30 17:06:38 +08:00
|
|
|
if self.enable_mlapo:
|
|
|
|
|
# Currently mlapo only supports W8A8 quantization in MLA scenario
|
|
|
|
|
# TODO(whx): modify this limitation when mlapo supports floating point
|
|
|
|
|
if self.fused_qkv_a_proj is None or not isinstance(
|
2026-01-24 22:10:18 +08:00
|
|
|
getattr(self.fused_qkv_a_proj.quant_method, "quant_method", None), AscendW8A8LinearMethod
|
|
|
|
|
):
|
2025-10-30 17:06:38 +08:00
|
|
|
self.enable_mlapo = False
|
|
|
|
|
logger.warning_once(
|
|
|
|
|
"Currently mlapo only supports W8A8 quantization in MLA scenario."
|
|
|
|
|
"Some layers in your model are not quantized with W8A8,"
|
2026-01-24 22:10:18 +08:00
|
|
|
"thus mlapo is disabled for these layers."
|
|
|
|
|
)
|
2025-10-21 20:17:09 +08:00
|
|
|
if self.enable_mlapo:
|
2025-10-15 10:34:25 +08:00
|
|
|
self._process_weights_for_fused_mlapo(act_dtype)
|
2025-12-19 14:27:24 +08:00
|
|
|
else:
|
|
|
|
|
# if mlapo is enabled, W_UK_T can't be converted to nz format
|
|
|
|
|
self.W_UK_T = maybe_trans_nz(self.W_UK_T)
|
2025-10-15 10:34:25 +08:00
|
|
|
|
2026-01-24 22:10:18 +08:00
|
|
|
for layer in self.layer_sharding_kwargs or []:
|
2026-01-08 09:05:02 +08:00
|
|
|
if is_hidden_layer(layer):
|
|
|
|
|
post_process_after_loading_for_shard_weight_series(layer)
|
2025-12-11 12:43:04 +08:00
|
|
|
|
2025-10-15 10:34:25 +08:00
|
|
|
def _process_weights_for_fused_mlapo(self, act_dtype: torch.dtype):
|
2026-01-24 22:10:18 +08:00
|
|
|
kv_a_proj_wt = self.fused_qkv_a_proj.weight.data[..., self.q_lora_rank :].contiguous() # type: ignore[union-attr]
|
|
|
|
|
q_a_proj_wt = self.fused_qkv_a_proj.weight.data[..., : self.q_lora_rank].contiguous() # type: ignore[union-attr]
|
2025-10-30 00:34:55 +08:00
|
|
|
kv_a_proj_wt = kv_a_proj_wt.t().contiguous()
|
2025-10-15 10:34:25 +08:00
|
|
|
kv_a_proj_wt = trans_rope_weight(kv_a_proj_wt, self.qk_rope_head_dim)
|
2025-10-30 00:34:55 +08:00
|
|
|
kv_a_proj_wt = kv_a_proj_wt.t().contiguous()
|
2025-10-20 15:31:34 +08:00
|
|
|
wd_qkv = torch.cat((kv_a_proj_wt, q_a_proj_wt), dim=-1)
|
2025-10-15 10:34:25 +08:00
|
|
|
wd_qkv = wd_qkv.t().contiguous()
|
2026-01-24 22:10:18 +08:00
|
|
|
wd_qkv = transdata(wd_qkv, block_size=(16, 32)).unsqueeze(0).contiguous()
|
2025-10-15 10:34:25 +08:00
|
|
|
self.wd_qkv = torch_npu.npu_format_cast(wd_qkv, 29)
|
|
|
|
|
|
2026-01-24 22:10:18 +08:00
|
|
|
kv_a_proj_deq_scl = self.fused_qkv_a_proj.deq_scale[self.q_lora_rank :].contiguous() # type: ignore[union-attr]
|
|
|
|
|
q_a_proj_deq_scl = self.fused_qkv_a_proj.deq_scale[: self.q_lora_rank].contiguous() # type: ignore[union-attr]
|
|
|
|
|
kv_a_proj_deq_scl = kv_a_proj_deq_scl.reshape(self.kv_lora_rank + self.qk_rope_head_dim, -1).contiguous()
|
|
|
|
|
kv_a_proj_deq_scl = trans_rope_weight(kv_a_proj_deq_scl, self.qk_rope_head_dim)
|
|
|
|
|
kv_a_proj_deq_scl = kv_a_proj_deq_scl.view(self.kv_lora_rank + self.qk_rope_head_dim).contiguous()
|
|
|
|
|
self.deq_scale_qkv = torch.cat((kv_a_proj_deq_scl, q_a_proj_deq_scl), dim=-1).contiguous()
|
|
|
|
|
|
|
|
|
|
kv_a_proj_qt_bias = self.fused_qkv_a_proj.quant_bias[self.q_lora_rank :].contiguous() # type: ignore[union-attr]
|
|
|
|
|
q_a_proj_qt_bias = self.fused_qkv_a_proj.quant_bias[: self.q_lora_rank].contiguous() # type: ignore[union-attr]
|
|
|
|
|
kv_a_proj_qt_bias = kv_a_proj_qt_bias.reshape(self.kv_lora_rank + self.qk_rope_head_dim, -1).contiguous()
|
|
|
|
|
kv_a_proj_qt_bias = trans_rope_weight(kv_a_proj_qt_bias, self.qk_rope_head_dim)
|
|
|
|
|
kv_a_proj_qt_bias = kv_a_proj_qt_bias.view(self.kv_lora_rank + self.qk_rope_head_dim).contiguous()
|
|
|
|
|
self.quant_bias_qkv = torch.cat((kv_a_proj_qt_bias, q_a_proj_qt_bias), dim=-1).contiguous()
|
2025-10-15 10:34:25 +08:00
|
|
|
|
|
|
|
|
wu_q = self.q_proj.weight.data
|
2026-01-24 22:10:18 +08:00
|
|
|
wu_q = wu_q.t().reshape(self.num_heads, self.qk_nope_head_dim + self.qk_rope_head_dim, -1)
|
2025-10-15 10:34:25 +08:00
|
|
|
wu_q = trans_rope_weight(wu_q, self.qk_rope_head_dim)
|
2026-01-24 22:10:18 +08:00
|
|
|
wu_q = wu_q.reshape(self.num_heads * (self.qk_nope_head_dim + self.qk_rope_head_dim), -1)
|
2025-10-15 10:34:25 +08:00
|
|
|
wu_q = transdata(wu_q, block_size=(16, 32)).unsqueeze(0).contiguous()
|
|
|
|
|
self.wu_q = torch_npu.npu_format_cast(wu_q, 29)
|
|
|
|
|
|
|
|
|
|
qb_deq_scl = self.q_proj.deq_scale.data
|
2026-01-24 22:10:18 +08:00
|
|
|
qb_deq_scl = qb_deq_scl.reshape(self.num_heads, self.qk_nope_head_dim + self.qk_rope_head_dim, -1)
|
2025-10-15 10:34:25 +08:00
|
|
|
qb_deq_scl = trans_rope_weight(qb_deq_scl, self.qk_rope_head_dim)
|
2026-01-24 22:10:18 +08:00
|
|
|
self.qb_deq_scl = qb_deq_scl.reshape(self.num_heads * (self.qk_nope_head_dim + self.qk_rope_head_dim))
|
2025-10-15 10:34:25 +08:00
|
|
|
|
|
|
|
|
qb_qt_bias = self.q_proj.quant_bias.data
|
2026-01-24 22:10:18 +08:00
|
|
|
qb_qt_bias = qb_qt_bias.reshape(self.num_heads, self.qk_nope_head_dim + self.qk_rope_head_dim, -1)
|
2025-10-15 10:34:25 +08:00
|
|
|
qb_qt_bias = trans_rope_weight(qb_qt_bias, self.qk_rope_head_dim)
|
2026-01-24 22:10:18 +08:00
|
|
|
self.qb_qt_bias = qb_qt_bias.reshape(self.num_heads * (self.qk_nope_head_dim + self.qk_rope_head_dim))
|
2025-10-15 10:34:25 +08:00
|
|
|
|
2025-10-23 09:12:50 +08:00
|
|
|
device = self.q_proj.weight.device
|
2026-01-24 22:10:18 +08:00
|
|
|
self.gamma1 = self.q_a_layernorm.weight.data # type: ignore[union-attr]
|
|
|
|
|
self.beta1 = torch.zeros_like(self.gamma1) if (_bias := self.q_a_layernorm.bias) is None else _bias.data # type: ignore[union-attr]
|
|
|
|
|
self.gamma2 = self.kv_a_layernorm.weight.data # type: ignore[union-attr]
|
|
|
|
|
self.quant_scale0 = self.fused_qkv_a_proj.input_scale.data # type: ignore[union-attr]
|
|
|
|
|
self.quant_offset0 = self.fused_qkv_a_proj.input_offset.data # type: ignore[union-attr]
|
2025-10-15 10:34:25 +08:00
|
|
|
self.quant_scale1 = self.q_proj.input_scale.data
|
|
|
|
|
self.quant_offset1 = self.q_proj.input_offset.data
|
|
|
|
|
self.ctkv_scale = torch.tensor([1], dtype=act_dtype, device=device)
|
|
|
|
|
self.q_nope_scale = torch.tensor([1], dtype=act_dtype, device=device)
|
|
|
|
|
|
2026-01-05 21:29:45 +08:00
|
|
|
# On KV consumers (decode-only) MLAPO uses the transformed weights built above;
|
|
|
|
|
# the original fused_qkv_a_proj/q_proj weights and quant params are no longer
|
|
|
|
|
# referenced, so drop them to save memory.
|
2026-01-24 22:10:18 +08:00
|
|
|
if (
|
|
|
|
|
self.vllm_config.kv_transfer_config is not None
|
|
|
|
|
and self.vllm_config.kv_transfer_config.is_kv_consumer
|
|
|
|
|
and self.vllm_config.scheduler_config.max_num_batched_tokens <= MLAPO_MAX_SUPPORTED_TOKENS
|
|
|
|
|
):
|
|
|
|
|
self.fused_qkv_a_proj.weight = None # type: ignore[union-attr]
|
|
|
|
|
self.fused_qkv_a_proj.deq_scale = None # type: ignore[union-attr]
|
|
|
|
|
self.fused_qkv_a_proj.quant_bias = None # type: ignore[union-attr]
|
2026-01-05 21:29:45 +08:00
|
|
|
self.q_proj.weight = None
|
|
|
|
|
self.q_proj.deq_scale = None
|
|
|
|
|
self.q_proj.quant_bias = None
|
|
|
|
|
torch.npu.empty_cache()
|
|
|
|
|
|
2026-01-24 22:10:18 +08:00
|
|
|
def get_context_seq_len_npu(self, index: int, attn_metadata: AscendMLAMetadata):
|
2025-12-28 10:40:45 +08:00
|
|
|
prefill_metadata = attn_metadata.prefill
|
|
|
|
|
assert prefill_metadata is not None
|
|
|
|
|
assert prefill_metadata.chunked_context is not None
|
|
|
|
|
assert prefill_metadata.chunked_context.chunk_seq_lens_npu is not None
|
|
|
|
|
iters = len(prefill_metadata.chunked_context.seq_tot)
|
|
|
|
|
assert 0 <= index < iters
|
|
|
|
|
return prefill_metadata.chunked_context.chunk_seq_lens_npu[index]
|
|
|
|
|
|
|
|
|
|
def _reorg_kvcache(
|
|
|
|
|
self,
|
|
|
|
|
kv_c_normed: torch.Tensor,
|
|
|
|
|
k_pe: torch.Tensor,
|
|
|
|
|
chunked_context: CPChunkedContextMetadata,
|
|
|
|
|
chunk_idx: int,
|
|
|
|
|
toks: int,
|
|
|
|
|
) -> tuple[torch.Tensor, torch.Tensor]:
|
|
|
|
|
return kv_c_normed, k_pe
|
|
|
|
|
|
2025-06-14 22:31:16 +08:00
|
|
|
def _compute_prefill_context(
|
|
|
|
|
self,
|
2025-08-28 10:35:57 +08:00
|
|
|
q_nope: torch.Tensor,
|
|
|
|
|
q_pe: torch.Tensor,
|
2026-01-24 22:10:18 +08:00
|
|
|
kv_c_and_k_pe_cache: tuple[torch.Tensor],
|
2025-06-14 22:31:16 +08:00
|
|
|
rope_dim: int,
|
|
|
|
|
attn_metadata: AscendMLAMetadata,
|
|
|
|
|
prefix_output: torch.Tensor,
|
|
|
|
|
prefix_lse: torch.Tensor,
|
|
|
|
|
):
|
2025-07-26 17:15:47 +08:00
|
|
|
assert len(kv_c_and_k_pe_cache) > 1
|
2025-06-14 22:31:16 +08:00
|
|
|
prefill_metadata = attn_metadata.prefill
|
|
|
|
|
if prefill_metadata is None or prefill_metadata.chunked_context is None:
|
|
|
|
|
return prefix_output, prefix_lse
|
|
|
|
|
|
|
|
|
|
iters = len(prefill_metadata.chunked_context.seq_tot)
|
|
|
|
|
|
2026-01-24 22:10:18 +08:00
|
|
|
current_seq_len = torch.tensor(prefill_metadata.query_lens, dtype=torch.int32)
|
2025-07-26 17:15:47 +08:00
|
|
|
cache_kv_c = kv_c_and_k_pe_cache[0]
|
|
|
|
|
cache_k_pe = kv_c_and_k_pe_cache[1]
|
|
|
|
|
num_heads = cache_k_pe.size(2)
|
|
|
|
|
latent_kv_dim = kv_c_and_k_pe_cache[0].size(-1)
|
2025-11-11 09:18:02 +08:00
|
|
|
for i in range(iters):
|
2025-11-14 08:43:37 +08:00
|
|
|
toks = prefill_metadata.chunked_context.seq_tot[i]
|
|
|
|
|
# chunk_seq_lens will be padded when pcp & dcp are enabled
|
2026-01-24 22:10:18 +08:00
|
|
|
context_seq_len = prefill_metadata.chunked_context.chunk_seq_lens[i]
|
2025-11-14 08:43:37 +08:00
|
|
|
seq_len = torch.stack([current_seq_len, context_seq_len])
|
2026-01-24 22:10:18 +08:00
|
|
|
context_seq_len_npu = self.get_context_seq_len_npu(i, attn_metadata)
|
|
|
|
|
kv_c_normed = torch.empty(toks, num_heads, latent_kv_dim, dtype=q_nope.dtype, device=q_nope.device)
|
|
|
|
|
k_pe = torch.empty(toks, num_heads, rope_dim, dtype=q_nope.dtype, device=q_nope.device)
|
2025-11-14 08:43:37 +08:00
|
|
|
|
|
|
|
|
torch_npu.atb.npu_paged_cache_load(
|
|
|
|
|
cache_kv_c,
|
|
|
|
|
cache_k_pe,
|
|
|
|
|
prefill_metadata.block_table,
|
|
|
|
|
context_seq_len_npu,
|
|
|
|
|
seq_starts=prefill_metadata.chunked_context.starts[i],
|
|
|
|
|
key=kv_c_normed,
|
|
|
|
|
value=k_pe,
|
|
|
|
|
)
|
2025-12-28 10:40:45 +08:00
|
|
|
kv_c_normed, k_pe = self._reorg_kvcache(
|
|
|
|
|
kv_c_normed,
|
|
|
|
|
k_pe,
|
|
|
|
|
chunked_context=prefill_metadata.chunked_context,
|
|
|
|
|
chunk_idx=i,
|
|
|
|
|
toks=toks,
|
|
|
|
|
)
|
2025-11-14 08:43:37 +08:00
|
|
|
kv_c_normed = kv_c_normed.squeeze()
|
2026-01-24 22:10:18 +08:00
|
|
|
kv_nope = self.kv_b_proj(kv_c_normed)[0].view(-1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim)
|
|
|
|
|
k_nope, v = kv_nope.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1)
|
2025-11-14 08:43:37 +08:00
|
|
|
k_pe = k_pe.expand((*k_nope.shape[:-1], -1))
|
|
|
|
|
|
2025-12-09 18:51:00 +08:00
|
|
|
mask = attn_metadata.attn_mask
|
2025-11-14 08:43:37 +08:00
|
|
|
torch_npu.atb.npu_ring_mla(
|
|
|
|
|
q_nope=q_nope,
|
|
|
|
|
q_rope=q_pe,
|
|
|
|
|
k_nope=k_nope,
|
|
|
|
|
k_rope=k_pe,
|
|
|
|
|
value=v,
|
|
|
|
|
mask=mask,
|
|
|
|
|
seqlen=seq_len,
|
|
|
|
|
head_num=self.num_heads,
|
|
|
|
|
kv_head_num=self.num_heads,
|
|
|
|
|
pre_out=prefix_output,
|
|
|
|
|
prev_lse=prefix_lse,
|
|
|
|
|
qk_scale=self.scale,
|
|
|
|
|
kernel_type="kernel_type_high_precision",
|
|
|
|
|
mask_type="no_mask",
|
|
|
|
|
input_layout="type_bsnd",
|
|
|
|
|
calc_type="calc_type_default",
|
|
|
|
|
output=prefix_output,
|
2026-01-24 22:10:18 +08:00
|
|
|
softmax_lse=prefix_lse,
|
|
|
|
|
)
|
2025-06-14 22:31:16 +08:00
|
|
|
return prefix_output, prefix_lse
|
|
|
|
|
|
2025-04-19 17:38:18 +08:00
|
|
|
def _forward_prefill(
|
|
|
|
|
self,
|
2025-08-28 10:35:57 +08:00
|
|
|
q_nope: torch.Tensor,
|
|
|
|
|
q_pe: torch.Tensor,
|
|
|
|
|
k_nope: torch.Tensor,
|
2025-04-19 17:38:18 +08:00
|
|
|
k_pe: torch.Tensor,
|
2025-08-28 10:35:57 +08:00
|
|
|
value: torch.Tensor,
|
2026-01-24 22:10:18 +08:00
|
|
|
kv_c_and_k_pe_cache: tuple[torch.Tensor],
|
2025-04-19 17:38:18 +08:00
|
|
|
attn_metadata: AscendMLAMetadata,
|
|
|
|
|
) -> torch.Tensor:
|
|
|
|
|
assert attn_metadata.prefill is not None
|
2025-07-26 17:15:47 +08:00
|
|
|
assert len(kv_c_and_k_pe_cache) > 1
|
2025-08-28 10:35:57 +08:00
|
|
|
num_tokens = q_nope.size(0)
|
2026-01-24 22:10:18 +08:00
|
|
|
attn_output = torch.empty(num_tokens, self.num_heads, self.v_head_dim, dtype=q_nope.dtype, device=q_nope.device)
|
|
|
|
|
attn_lse = torch.empty(self.num_heads, num_tokens, dtype=torch.float32, device=q_nope.device)
|
|
|
|
|
torch_npu.atb.npu_ring_mla(
|
|
|
|
|
q_nope=q_nope,
|
|
|
|
|
q_rope=q_pe,
|
|
|
|
|
k_nope=k_nope,
|
|
|
|
|
k_rope=k_pe,
|
|
|
|
|
value=value,
|
|
|
|
|
mask=attn_metadata.attn_mask,
|
|
|
|
|
seqlen=attn_metadata.prefill.query_lens,
|
|
|
|
|
head_num=self.num_heads,
|
|
|
|
|
kv_head_num=self.num_heads,
|
|
|
|
|
pre_out=None,
|
|
|
|
|
prev_lse=None,
|
|
|
|
|
qk_scale=self.scale,
|
|
|
|
|
kernel_type="kernel_type_high_precision",
|
|
|
|
|
mask_type="mask_type_triu",
|
|
|
|
|
input_layout="type_bsnd",
|
|
|
|
|
calc_type="calc_type_first_ring",
|
|
|
|
|
output=attn_output,
|
|
|
|
|
softmax_lse=attn_lse,
|
|
|
|
|
)
|
2025-12-15 12:59:18 +08:00
|
|
|
attn_output, attn_lse = self._compute_prefill_context(
|
2026-01-24 22:10:18 +08:00
|
|
|
q_nope, q_pe, kv_c_and_k_pe_cache, self.qk_rope_head_dim, attn_metadata, attn_output, attn_lse
|
|
|
|
|
)
|
2025-08-28 10:35:57 +08:00
|
|
|
|
2026-01-24 22:10:18 +08:00
|
|
|
attn_output = attn_output.reshape([num_tokens, self.num_heads * self.v_head_dim])
|
2025-08-12 14:12:12 +08:00
|
|
|
return attn_output
|
2025-04-19 17:38:18 +08:00
|
|
|
|
2025-08-28 10:35:57 +08:00
|
|
|
def exec_kv_decode(
|
|
|
|
|
self,
|
|
|
|
|
kv_no_split: torch.Tensor,
|
|
|
|
|
cos: torch.Tensor,
|
|
|
|
|
sin: torch.Tensor,
|
2026-01-24 22:10:18 +08:00
|
|
|
kv_cache: tuple,
|
2025-08-28 10:35:57 +08:00
|
|
|
slots: torch.Tensor,
|
|
|
|
|
):
|
|
|
|
|
B = kv_no_split.shape[0]
|
|
|
|
|
N = self.num_kv_heads
|
|
|
|
|
S = 1
|
|
|
|
|
# npu_kv_rmsnorm_rope_cache needs [B, N, S, D]
|
2026-01-24 22:10:18 +08:00
|
|
|
kv_no_split = kv_no_split.view(B, N, S, self.kv_lora_rank + self.qk_rope_head_dim)
|
2025-12-31 14:24:04 +08:00
|
|
|
cache_mode = "PA_NZ" if self.enable_kv_nz else "PA"
|
2025-08-28 10:35:57 +08:00
|
|
|
k_pe, k_nope, _, _ = torch_npu.npu_kv_rmsnorm_rope_cache(
|
|
|
|
|
kv_no_split,
|
2026-01-24 22:10:18 +08:00
|
|
|
self.kv_a_layernorm.weight, # type: ignore[union-attr]
|
2025-08-28 10:35:57 +08:00
|
|
|
cos,
|
|
|
|
|
sin,
|
|
|
|
|
slots.to(torch.int64),
|
|
|
|
|
kv_cache[1],
|
|
|
|
|
kv_cache[0],
|
2026-01-24 22:10:18 +08:00
|
|
|
epsilon=self.kv_a_layernorm.variance_epsilon, # type: ignore[union-attr]
|
2025-08-28 10:35:57 +08:00
|
|
|
cache_mode=cache_mode,
|
|
|
|
|
)
|
|
|
|
|
return k_pe, k_nope
|
|
|
|
|
|
|
|
|
|
def exec_kv_prefill(
|
|
|
|
|
self,
|
|
|
|
|
kv_no_split: torch.Tensor,
|
|
|
|
|
cos: torch.Tensor,
|
|
|
|
|
sin: torch.Tensor,
|
2026-01-24 22:10:18 +08:00
|
|
|
kv_cache: tuple,
|
2025-08-28 10:35:57 +08:00
|
|
|
slots: torch.Tensor,
|
|
|
|
|
):
|
|
|
|
|
B = kv_no_split.shape[0]
|
|
|
|
|
N = self.num_kv_heads
|
|
|
|
|
S = 1
|
|
|
|
|
# npu_kv_rmsnorm_rope_cache needs [B, N, S, D]
|
2026-01-24 22:10:18 +08:00
|
|
|
kv_no_split = kv_no_split.view(B, N, S, self.kv_lora_rank + self.qk_rope_head_dim)
|
2025-12-10 09:20:40 +08:00
|
|
|
cache_mode = "PA"
|
2025-08-28 10:35:57 +08:00
|
|
|
_, _, k_pe, k_nope = torch_npu.npu_kv_rmsnorm_rope_cache(
|
|
|
|
|
kv_no_split,
|
2026-01-24 22:10:18 +08:00
|
|
|
self.kv_a_layernorm.weight, # type: ignore[union-attr]
|
2025-08-28 10:35:57 +08:00
|
|
|
cos,
|
|
|
|
|
sin,
|
|
|
|
|
slots.to(torch.int64),
|
|
|
|
|
kv_cache[1],
|
|
|
|
|
kv_cache[0],
|
2026-01-24 22:10:18 +08:00
|
|
|
epsilon=self.kv_a_layernorm.variance_epsilon, # type: ignore[union-attr]
|
2025-08-28 10:35:57 +08:00
|
|
|
cache_mode=cache_mode,
|
|
|
|
|
is_output_kv=True,
|
|
|
|
|
)
|
|
|
|
|
return k_pe, k_nope
|
|
|
|
|
|
|
|
|
|
def rope_single(
|
|
|
|
|
self,
|
|
|
|
|
x: torch.Tensor,
|
|
|
|
|
cos: torch.Tensor,
|
|
|
|
|
sin: torch.Tensor,
|
|
|
|
|
) -> torch.Tensor:
|
|
|
|
|
B, N, D = x.shape
|
|
|
|
|
S = 1
|
|
|
|
|
x = x.view(B, N, S, D)
|
|
|
|
|
x = torch_npu.npu_interleave_rope(x, cos, sin)
|
|
|
|
|
return x.view(B, N, D)
|
|
|
|
|
|
2025-04-19 17:38:18 +08:00
|
|
|
def _forward_decode(
|
|
|
|
|
self,
|
|
|
|
|
q_nope: torch.Tensor,
|
|
|
|
|
q_pe: torch.Tensor,
|
2025-05-12 19:14:07 +08:00
|
|
|
k_nope: torch.Tensor,
|
|
|
|
|
k_pe: torch.Tensor,
|
2025-08-28 10:35:57 +08:00
|
|
|
block_size: int,
|
2025-04-19 17:38:18 +08:00
|
|
|
attn_metadata: AscendMLAMetadata,
|
|
|
|
|
) -> torch.Tensor:
|
|
|
|
|
decode_meta = attn_metadata.decode
|
|
|
|
|
assert decode_meta is not None
|
2025-07-26 17:15:47 +08:00
|
|
|
num_tokens = q_nope.size(0)
|
2025-08-28 10:35:57 +08:00
|
|
|
# shape of k_nope/k_pe for npu graph mode should be:
|
|
|
|
|
# [num_blocks, num_kv_heads, block_size, self.kv_lora_rank/self.qk_rope_head_dim]
|
|
|
|
|
actual_seq_lengths = None
|
2025-12-31 14:24:04 +08:00
|
|
|
if self.enable_kv_nz:
|
|
|
|
|
nz_fmt_last_dim = 16
|
2026-01-24 22:10:18 +08:00
|
|
|
k_nope = k_nope.view(
|
|
|
|
|
-1, self.num_kv_heads, self.kv_lora_rank // nz_fmt_last_dim, block_size, nz_fmt_last_dim
|
|
|
|
|
)
|
|
|
|
|
k_pe = k_pe.view(
|
|
|
|
|
-1, self.num_kv_heads, self.qk_rope_head_dim // nz_fmt_last_dim, block_size, nz_fmt_last_dim
|
|
|
|
|
)
|
2025-12-31 14:24:04 +08:00
|
|
|
else:
|
2026-01-24 22:10:18 +08:00
|
|
|
k_nope = k_nope.view(-1, self.num_kv_heads, block_size, self.kv_lora_rank)
|
|
|
|
|
k_pe = k_pe.view(-1, self.num_kv_heads, block_size, self.qk_rope_head_dim)
|
2025-08-28 10:35:57 +08:00
|
|
|
|
2025-12-31 14:24:04 +08:00
|
|
|
attn_output_shape: tuple | None = None
|
2026-01-24 22:10:18 +08:00
|
|
|
if (
|
|
|
|
|
attn_metadata.attn_state
|
|
|
|
|
in [
|
2025-09-19 14:05:08 +08:00
|
|
|
AscendAttentionState.SpecDecoding,
|
2025-10-17 20:19:56 +08:00
|
|
|
AscendAttentionState.ChunkedPrefill,
|
|
|
|
|
AscendAttentionState.DecodeOnly,
|
2026-01-24 22:10:18 +08:00
|
|
|
]
|
|
|
|
|
and self.speculative_config is not None
|
|
|
|
|
):
|
2025-12-26 22:03:46 +08:00
|
|
|
# The right part of the layout string indicates the layout of the attention
|
|
|
|
|
# output. It is set to NTD to avoid the need for a transpose
|
|
|
|
|
# operation after attention.
|
|
|
|
|
input_layout = "TND_NTD"
|
2025-10-20 20:04:04 +08:00
|
|
|
# TODO: If the driver is upgraded later, the contiguous function can be deleted.
|
2025-12-31 14:24:04 +08:00
|
|
|
# Input shape: [num_tokens, num_heads, dim]
|
2025-10-20 20:04:04 +08:00
|
|
|
q_nope = q_nope.view(num_tokens, self.num_heads, -1).contiguous()
|
2025-08-28 10:35:57 +08:00
|
|
|
q_pe = q_pe.view(num_tokens, self.num_heads, -1)
|
2025-12-31 14:24:04 +08:00
|
|
|
# Output shape: [num_heads, num_tokens, dim]
|
|
|
|
|
attn_output_shape = (self.num_heads, num_tokens, self.kv_lora_rank)
|
2025-08-28 10:35:57 +08:00
|
|
|
sparse_mode = 3
|
2026-01-07 17:09:52 +08:00
|
|
|
attn_mask = attn_metadata.decode.attn_mask # type:ignore
|
2025-08-28 10:35:57 +08:00
|
|
|
actual_seq_lengths = decode_meta.actual_seq_lengths_q
|
|
|
|
|
else:
|
2025-12-26 22:03:46 +08:00
|
|
|
# The output layout is set to NBSD to eliminate the need for a
|
|
|
|
|
# transpose operation after attention.
|
2025-12-31 14:24:04 +08:00
|
|
|
if self.enable_kv_nz:
|
|
|
|
|
# Input shape: [num_tokens, seq_len, num_heads, dim]
|
|
|
|
|
input_layout = "BSND_NBSD"
|
2026-01-24 22:10:18 +08:00
|
|
|
q_nope = q_nope.view(num_tokens, 1, self.num_heads, -1).contiguous()
|
2025-12-31 14:24:04 +08:00
|
|
|
q_pe = q_pe.view(num_tokens, 1, self.num_heads, -1)
|
|
|
|
|
else:
|
|
|
|
|
# Input shape: [num_tokens, num_heads, seq_len, dim]
|
|
|
|
|
input_layout = "BNSD_NBSD"
|
2026-01-24 22:10:18 +08:00
|
|
|
q_nope = q_nope.view(num_tokens, self.num_heads, 1, -1).contiguous()
|
2025-12-31 14:24:04 +08:00
|
|
|
q_pe = q_pe.view(num_tokens, self.num_heads, 1, -1)
|
|
|
|
|
# Output shape: [num_heads, num_tokens, seq_len, dim]
|
2026-01-24 22:10:18 +08:00
|
|
|
attn_output_shape = (self.num_heads, num_tokens, 1, self.kv_lora_rank)
|
2025-08-28 10:35:57 +08:00
|
|
|
sparse_mode = 0
|
2026-01-07 17:09:52 +08:00
|
|
|
attn_mask = None
|
2025-08-28 10:35:57 +08:00
|
|
|
|
2025-10-10 16:31:20 +08:00
|
|
|
common_kwargs = {
|
2026-01-24 22:10:18 +08:00
|
|
|
"query_rope": q_pe,
|
|
|
|
|
"key_rope": k_pe,
|
|
|
|
|
"num_heads": self.num_heads,
|
|
|
|
|
"num_key_value_heads": self.num_kv_heads,
|
|
|
|
|
"input_layout": input_layout,
|
|
|
|
|
"atten_mask": attn_mask,
|
|
|
|
|
"sparse_mode": sparse_mode,
|
|
|
|
|
"scale": self.scale,
|
|
|
|
|
"antiquant_mode": 0,
|
|
|
|
|
"antiquant_scale": None,
|
|
|
|
|
"block_table": decode_meta.block_table,
|
|
|
|
|
"block_size": block_size,
|
2025-10-10 16:31:20 +08:00
|
|
|
"actual_seq_lengths": actual_seq_lengths,
|
|
|
|
|
"actual_seq_lengths_kv": decode_meta.seq_lens_list,
|
|
|
|
|
}
|
|
|
|
|
forward_context: ForwardContext = get_forward_context()
|
2025-12-29 09:54:51 +08:00
|
|
|
if forward_context.is_draft_model:
|
|
|
|
|
graph_params = get_draft_graph_params()
|
[Feat] Support MTP to running in full graph mode (#3892)
### What this PR does / why we need it?
Currently, the MTP model still runs in eager mode when full graph mode is
enabled. This PR adapts MTP to full-graph capture and execution. When the
graph mode is set to "FULL_DECODE_ONLY", MTP will run in full-graph mode to
improve performance.
The changes, covering both the disable_padded_drafter_batch=True and False
cases, include:
1. Add _mtp_graph_params in acl_graph.py to isolate the data of the main
model from the data of MTP (see the sketch after this message).
2. Pad some metadata in mla_v1.py when running in fullgraph mode.
3. Fix the essential data addresses that will be used in model.forward.
4. Adapt to the aclgraph capture framework:
1). Rebuild the MTP model with ACLGraphWrapper.
2). Add common attn metadata when starting capture in the MTP dummy_run.
3). Add common attn metadata updates in MTP.
4). Adapt data updates when num_speculative_tokens > 1.
5. Add a patch for MTP to adapt to vllm v0.11.0.
Existing Issues:
1. When disable_padded_drafter_batch=True and running in FullGraph mode,
the data of the first-round requests in MTP is abnormal. We need to
identify the cause subsequently.
2. When disable_padded_drafter_batch=False and running in FullGraph
mode, the acceptance rate of the second and third tokens decreases (for
example, with num_speculative_tokens=3, the acceptance rate of the first
token is 90%, the second drops from 60% to only 50%, and the third drops
from 30% to only 20%). The root cause is that the data processed after the
model runs does not match. This is a problem from another PR. It works
fine in eager and PIECEWISE mode, but fails in FullGraph mode. Once we
have a solution, we will submit a bugfix.
### Does this PR introduce _any_ user-facing change?
### How was this patch tested?
- vLLM version: v0.11.0
- vLLM main:
https://github.com/vllm-project/vllm/commit/2918c1b49c88c29783c86f78d2c4221cb9622379
---------
Signed-off-by: anon189Ty <Stari_Falcon@outlook.com>
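As a rough sketch of change 1 above (illustrative only, not the actual acl_graph.py implementation): the main model and the MTP draft model each get their own parameter registry, and the forward-context flag decides which registry an attention layer reads or writes, mirroring the `get_graph_params()` / `get_draft_graph_params()` branches used in this file.
```python
# Illustrative sketch: isolate captured-graph parameters per model so the
# MTP draft model never overwrites the main model's entries.
_main_graph_params: dict[int, list] = {}   # keyed by padded num_tokens
_draft_graph_params: dict[int, list] = {}  # MTP draft model

def params_for(is_draft_model: bool, num_tokens: int) -> list:
    registry = _draft_graph_params if is_draft_model else _main_graph_params
    return registry.setdefault(num_tokens, [])

# Capture time: each attention layer appends the args it will need at replay.
params_for(is_draft_model=True, num_tokens=32).append(("layer-0 attn args",))
# Replay time: the same flag picks the matching registry.
assert params_for(True, 32) and not params_for(False, 32)
```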
2025-11-20 20:34:54 +08:00
|
|
|
else:
|
|
|
|
|
graph_params = get_graph_params()
|
2025-10-10 16:31:20 +08:00
|
|
|
if forward_context.capturing:
|
|
|
|
|
stream = torch_npu.npu.current_stream()
|
|
|
|
|
|
|
|
|
|
event = torch.npu.ExternalEvent()
|
|
|
|
|
event.wait(stream)
|
|
|
|
|
event.reset(stream)
|
|
|
|
|
graph_params.events[num_tokens].append(event)
|
|
|
|
|
|
|
|
|
|
workspace = graph_params.workspaces.get(num_tokens)
|
|
|
|
|
if workspace is None:
|
|
|
|
|
workspace = torch_npu._npu_fused_infer_attention_score_get_max_workspace(
|
2026-01-24 22:10:18 +08:00
|
|
|
q_nope, k_nope, k_nope, **common_kwargs
|
|
|
|
|
)
|
2025-12-29 09:54:51 +08:00
|
|
|
if forward_context.is_draft_model:
|
|
|
|
|
update_draft_graph_params_workspaces(num_tokens, workspace)
|
|
|
|
|
else:
|
|
|
|
|
update_graph_params_workspaces(num_tokens, workspace)
|
2025-10-10 16:31:20 +08:00
|
|
|
|
2026-01-24 22:10:18 +08:00
|
|
|
attn_output = torch.empty(attn_output_shape, dtype=q_nope.dtype, device=q_nope.device)
|
|
|
|
|
softmax_lse = torch.empty(num_tokens, dtype=q_nope.dtype, device=q_nope.device)
|
2025-10-10 16:31:20 +08:00
|
|
|
|
|
|
|
|
graph_params.attn_params[num_tokens].append(
|
2026-01-24 22:10:18 +08:00
|
|
|
(
|
|
|
|
|
weak_ref_tensors(q_nope),
|
|
|
|
|
weak_ref_tensors(k_nope),
|
|
|
|
|
weak_ref_tensors(q_pe),
|
|
|
|
|
weak_ref_tensors(k_pe),
|
|
|
|
|
self.num_heads,
|
|
|
|
|
self.num_kv_heads,
|
|
|
|
|
input_layout,
|
|
|
|
|
weak_ref_tensors(attn_mask) if attn_mask is not None else None,
|
|
|
|
|
sparse_mode,
|
|
|
|
|
self.scale,
|
|
|
|
|
decode_meta.block_table,
|
|
|
|
|
block_size,
|
|
|
|
|
decode_meta.seq_lens_list,
|
|
|
|
|
actual_seq_lengths,
|
|
|
|
|
weak_ref_tensors(attn_output),
|
|
|
|
|
weak_ref_tensors(softmax_lse),
|
|
|
|
|
)
|
|
|
|
|
)
|
2025-10-10 16:31:20 +08:00
|
|
|
|
|
|
|
|
torch.npu.graph_task_group_begin(stream)
|
|
|
|
|
torch_npu.npu_fused_infer_attention_score.out(
|
2026-01-24 22:10:18 +08:00
|
|
|
q_nope, k_nope, k_nope, **common_kwargs, workspace=workspace, out=[attn_output, softmax_lse]
|
|
|
|
|
)
|
2025-10-10 16:31:20 +08:00
|
|
|
handle = torch.npu.graph_task_group_end(stream)
|
|
|
|
|
graph_params.handles[num_tokens].append(handle)
|
|
|
|
|
else:
|
2026-01-24 22:10:18 +08:00
|
|
|
attn_output, _ = torch_npu.npu_fused_infer_attention_score(q_nope, k_nope, k_nope, **common_kwargs)
|
2025-10-25 15:53:01 +08:00
|
|
|
|
|
|
|
|
return self._v_up_proj(attn_output)
|
2025-08-28 10:35:57 +08:00
|
|
|
|
2025-12-28 10:40:45 +08:00
|
|
|
def reorg_decode_q(self, decode_q_nope, decode_q_pe):
|
|
|
|
|
return decode_q_nope, decode_q_pe
|
|
|
|
|
|
2026-01-24 22:10:18 +08:00
|
|
|
def _mla_preprocess_only_decode(self, hidden_states, kv_cache, attn_metadata):
|
2025-10-15 10:34:25 +08:00
|
|
|
bsz = attn_metadata.num_decode_tokens
|
|
|
|
|
hidden_states = hidden_states[:bsz]
|
|
|
|
|
|
|
|
|
|
cos_shape = attn_metadata.decode.cos.shape
|
|
|
|
|
cos = attn_metadata.decode.cos.view(cos_shape[0], cos_shape[-1])
|
|
|
|
|
sin = attn_metadata.decode.sin.view(cos_shape[0], cos_shape[-1])
|
|
|
|
|
|
|
|
|
|
decode_k_nope, decode_k_pe = kv_cache[0], kv_cache[1]
|
|
|
|
|
decode_q_nope = torch.empty(
|
2026-01-24 22:10:18 +08:00
|
|
|
(hidden_states.shape[0], self.W_UK_T.shape[0], decode_k_nope.shape[-1]),
|
2025-10-15 10:34:25 +08:00
|
|
|
dtype=hidden_states.dtype,
|
|
|
|
|
device=hidden_states.device,
|
|
|
|
|
)
|
|
|
|
|
decode_q_pe = torch.empty(
|
2026-01-24 22:10:18 +08:00
|
|
|
(hidden_states.shape[0], self.W_UK_T.shape[0], decode_k_pe.shape[-1]),
|
2025-10-15 10:34:25 +08:00
|
|
|
dtype=hidden_states.dtype,
|
|
|
|
|
device=hidden_states.device,
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
torch.ops._C_ascend.mla_preprocess(
|
|
|
|
|
hidden_states,
|
|
|
|
|
self.wd_qkv,
|
|
|
|
|
self.deq_scale_qkv,
|
|
|
|
|
self.gamma1,
|
|
|
|
|
self.beta1,
|
|
|
|
|
self.wu_q,
|
|
|
|
|
self.qb_deq_scl,
|
|
|
|
|
self.gamma2,
|
|
|
|
|
cos,
|
|
|
|
|
sin,
|
|
|
|
|
self.W_UK_T,
|
|
|
|
|
decode_k_nope,
|
|
|
|
|
decode_k_pe,
|
2026-01-08 23:49:23 +08:00
|
|
|
attn_metadata.slot_mapping[:bsz],
|
2025-10-15 10:34:25 +08:00
|
|
|
quant_scale0=self.quant_scale0,
|
|
|
|
|
quant_offset0=self.quant_offset0,
|
|
|
|
|
bias0=self.quant_bias_qkv,
|
|
|
|
|
quant_scale1=self.quant_scale1,
|
|
|
|
|
quant_offset1=self.quant_offset1,
|
|
|
|
|
bias1=self.qb_qt_bias,
|
|
|
|
|
ctkv_scale=self.ctkv_scale,
|
|
|
|
|
q_nope_scale=self.q_nope_scale,
|
2025-12-31 14:24:04 +08:00
|
|
|
cache_mode="nzcache" if self.enable_kv_nz else "krope_ctkv",
|
2025-10-15 10:34:25 +08:00
|
|
|
quant_mode="per_tensor_quant_asymm",
|
|
|
|
|
q_out0=decode_q_nope,
|
|
|
|
|
kv_cache_out0=decode_k_nope,
|
|
|
|
|
q_out1=decode_q_pe,
|
|
|
|
|
kv_cache_out1=decode_k_pe,
|
2025-12-10 20:45:07 +08:00
|
|
|
enable_inner_out=False,
|
2026-01-24 22:10:18 +08:00
|
|
|
inner_out=torch.tensor([], device=hidden_states.device),
|
|
|
|
|
)
|
|
|
|
|
decode_q_nope = decode_q_nope.view(bsz, self.num_heads, self.kv_lora_rank)
|
2025-10-15 10:34:25 +08:00
|
|
|
decode_q_pe = decode_q_pe.view(bsz, self.num_heads, -1)
|
|
|
|
|
|
2026-01-24 22:10:18 +08:00
|
|
|
decode_q_nope, decode_q_pe = self.reorg_decode_q(decode_q_nope, decode_q_pe)
|
2025-12-28 10:40:45 +08:00
|
|
|
|
2026-01-24 22:10:18 +08:00
|
|
|
decode_preprocess_res = DecodeMLAPreprocessResult(decode_q_nope, decode_q_pe, decode_k_nope, decode_k_pe)
|
2025-10-15 10:34:25 +08:00
|
|
|
return decode_preprocess_res, None
|
|
|
|
|
|
2026-01-24 22:10:18 +08:00
|
|
|
def mla_preprocess_prefill(self, q_c, kv_no_split, kv_cache, attn_metadata):
|
2025-12-28 10:40:45 +08:00
|
|
|
num_decode_tokens = attn_metadata.num_decode_tokens
|
|
|
|
|
num_actual_tokens = attn_metadata.num_actual_tokens
|
|
|
|
|
prefill_kv_no_split = kv_no_split[num_decode_tokens:num_actual_tokens]
|
|
|
|
|
prefill_q_c = q_c[num_decode_tokens:num_actual_tokens]
|
2026-01-24 22:10:18 +08:00
|
|
|
prefill_q = self.q_proj(prefill_q_c)[0].view(-1, self.num_heads, self.qk_head_dim)
|
|
|
|
|
prefill_q_pe = prefill_q[..., self.qk_nope_head_dim :]
|
|
|
|
|
prefill_q_nope = prefill_q[..., : self.qk_nope_head_dim]
|
2025-12-28 10:40:45 +08:00
|
|
|
cos = attn_metadata.prefill.cos
|
|
|
|
|
sin = attn_metadata.prefill.sin
|
2026-01-24 22:10:18 +08:00
|
|
|
prefill_slots = attn_metadata.slot_mapping[num_decode_tokens:num_actual_tokens]
|
2025-12-28 10:40:45 +08:00
|
|
|
prefill_q_pe = self.rope_single(prefill_q_pe, cos, sin)
|
2025-12-31 15:09:01 +08:00
|
|
|
if self.is_kv_producer:
|
|
|
|
|
attn_metadata.reshape_cache_event = torch.npu.Event()
|
2026-01-24 22:10:18 +08:00
|
|
|
prefill_k_pe, prefill_k_c_normed = self.exec_kv_prefill(prefill_kv_no_split, cos, sin, kv_cache, prefill_slots)
|
2025-12-31 15:09:01 +08:00
|
|
|
if self.is_kv_producer:
|
|
|
|
|
attn_metadata.reshape_cache_event.record()
|
2026-01-24 22:10:18 +08:00
|
|
|
prefill_k_nope, prefill_value = (
|
|
|
|
|
self.kv_b_proj(prefill_k_c_normed)[0]
|
|
|
|
|
.view(-1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim)
|
|
|
|
|
.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1)
|
|
|
|
|
)
|
|
|
|
|
prefill_k_pe = prefill_k_pe.view(prefill_q_c.shape[0], self.num_kv_heads, -1)
|
2025-12-28 10:40:45 +08:00
|
|
|
prefill_k_pe = prefill_k_pe.expand((*prefill_k_nope.shape[:-1], -1))
|
2026-01-24 22:10:18 +08:00
|
|
|
return PrefillMLAPreprocessResult(prefill_q_nope, prefill_q_pe, prefill_k_nope, prefill_k_pe, prefill_value)
|
2025-12-28 10:40:45 +08:00
|
|
|
|
|
|
|
|
def mla_preprocess_decode(self, q_c, kv_no_split, kv_cache, attn_metadata):
|
|
|
|
|
num_decode_tokens = attn_metadata.num_decode_tokens
|
|
|
|
|
decode_q_c = q_c[:num_decode_tokens]
|
|
|
|
|
cos = attn_metadata.decode.cos
|
|
|
|
|
sin = attn_metadata.decode.sin
|
2026-01-24 22:10:18 +08:00
|
|
|
decode_ql_nope, decode_q_pe = self._q_proj_and_k_up_proj(decode_q_c)
|
2025-12-28 10:40:45 +08:00
|
|
|
decode_q_pe = self.rope_single(decode_q_pe, cos, sin)
|
|
|
|
|
decode_slots = attn_metadata.slot_mapping[:num_decode_tokens:1]
|
|
|
|
|
decode_kv_no_split = kv_no_split[:num_decode_tokens]
|
2026-01-24 22:10:18 +08:00
|
|
|
decode_k_pe, decode_k_nope = self.exec_kv_decode(decode_kv_no_split, cos, sin, kv_cache, decode_slots)
|
|
|
|
|
return DecodeMLAPreprocessResult(decode_ql_nope, decode_q_pe, decode_k_nope, decode_k_pe)
|
2025-12-28 10:40:45 +08:00
|
|
|
|
2026-01-24 22:10:18 +08:00
|
|
|
def _mla_preprocess(self, layer_name, hidden_states, kv_cache, attn_metadata, need_gather_q_kv):
|
2025-08-28 10:35:57 +08:00
|
|
|
# MLA Preprocess:
|
2025-10-30 17:06:38 +08:00
|
|
|
# 1. Perform fused_qkv_a_proj and q_a_layernorm to obtain q_c and kv_no_split
|
|
|
|
|
# or
|
|
|
|
|
# Perform kv_a_proj_with_mqa to obtain kv_no_split
|
|
|
|
|
# 2. If need_gather_q_kv, perform all_gather.
|
|
|
|
|
# 3. Preprocess decode tokens, write kv cache and get:
|
2025-08-28 10:35:57 +08:00
|
|
|
# decode_ql_nope, decode_q_pe, decode_k_pe, decode_k_nope
|
2025-10-30 17:06:38 +08:00
|
|
|
# 4. Preprocess prefill tokens, write kv cache and get:
|
2025-08-28 10:35:57 +08:00
|
|
|
# prefill_q_nope, prefill_q_pe, prefill_k_nope, prefill_k_pe, prefill_value
|
|
|
|
|
has_decode = attn_metadata.num_decodes > 0
|
|
|
|
|
has_prefill = attn_metadata.num_prefills > 0
|
2025-10-20 15:31:34 +08:00
|
|
|
if self.fused_qkv_a_proj is not None:
|
2026-02-10 14:14:37 +08:00
|
|
|
weight_prefetch_method = get_weight_prefetch_method()
|
|
|
|
|
weight_prefetch_method.maybe_prefetch_mla_or_sla_weight_in_current_stream(
|
|
|
|
|
inputs=self.fused_qkv_a_proj.weight, dependency=hidden_states
|
2026-01-24 22:10:18 +08:00
|
|
|
)
|
2025-10-20 15:31:34 +08:00
|
|
|
qkv_lora = self.fused_qkv_a_proj(hidden_states)[0]
|
|
|
|
|
q_c, kv_no_split = qkv_lora.split(
|
|
|
|
|
[self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim],
|
|
|
|
|
dim=-1,
|
|
|
|
|
)
|
2026-01-24 22:10:18 +08:00
|
|
|
q_c = self.q_a_layernorm(q_c) # type: ignore[misc]
|
2025-11-13 11:28:09 +08:00
|
|
|
# allgather needs contiguous data
|
|
|
|
|
kv_no_split = kv_no_split.contiguous()
|
2025-08-28 10:35:57 +08:00
|
|
|
else:
|
|
|
|
|
q_c = hidden_states
|
2026-01-24 22:10:18 +08:00
|
|
|
kv_no_split = self.kv_a_proj_with_mqa(hidden_states)[0] # type: ignore[misc]
|
2025-08-28 10:35:57 +08:00
|
|
|
|
2025-10-15 19:36:32 +08:00
|
|
|
# Process for Flash Comm V1
|
2026-01-24 22:10:18 +08:00
|
|
|
q_c = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(q_c.contiguous(), need_gather_q_kv)
|
|
|
|
|
kv_no_split = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(kv_no_split.contiguous(), need_gather_q_kv)
|
2025-10-20 15:31:34 +08:00
|
|
|
|
2026-01-24 22:10:18 +08:00
|
|
|
for layer in self.layer_sharding_kwargs or []:
|
2026-01-08 09:05:02 +08:00
|
|
|
if is_hidden_layer(layer):
|
|
|
|
|
reach_layer_for_shard_weight_series(layer)
|
2025-12-11 12:43:04 +08:00
|
|
|
|
2025-08-28 10:35:57 +08:00
|
|
|
decode_preprocess_res = None
|
|
|
|
|
prefill_preprocess_res = None
|
2025-09-23 14:25:05 +08:00
|
|
|
if has_prefill:
|
|
|
|
|
wait_for_kv_layer_from_connector(layer_name)
|
2025-08-28 10:35:57 +08:00
|
|
|
# Preprocess for decode tokens
|
|
|
|
|
if has_decode:
|
2026-01-24 22:10:18 +08:00
|
|
|
decode_preprocess_res = self.mla_preprocess_decode(q_c, kv_no_split, kv_cache, attn_metadata)
|
2025-08-28 10:35:57 +08:00
|
|
|
# Preprocess for prefill tokens
|
|
|
|
|
if has_prefill:
|
2026-01-24 22:10:18 +08:00
|
|
|
prefill_preprocess_res = self.mla_preprocess_prefill(q_c, kv_no_split, kv_cache, attn_metadata)
|
2025-08-28 10:35:57 +08:00
|
|
|
return decode_preprocess_res, prefill_preprocess_res
|
2025-04-19 17:38:18 +08:00
|
|
|
|
2025-12-28 10:40:45 +08:00
|
|
|
def get_num_actual_tokens(self, attn_metadata: M):
|
|
|
|
|
return attn_metadata.num_actual_tokens

    def forward_mha(
        self,
        layer_name: str,
        hidden_states: torch.Tensor,
        kv_cache: tuple[torch.Tensor],
        attn_metadata: M,
        need_gather_q_kv: bool = False,
        output: torch.Tensor | None = None,
    ) -> torch.Tensor:
        raise NotImplementedError("forward_mha is not supported for MLA attention. Use forward() instead.")

    def forward_mqa(
        self,
        layer_name: str,
        hidden_states: torch.Tensor,
        kv_cache: tuple[torch.Tensor],
        attn_metadata: M,
        need_gather_q_kv: bool = False,
        output: torch.Tensor | None = None,
    ) -> torch.Tensor:
        raise NotImplementedError("forward_mqa is not supported for MLA attention. Use forward() instead.")

    def forward(
        self,
        layer_name,
        hidden_states: torch.Tensor,  # query in unified attn
        kv_cache: tuple[torch.Tensor],
        attn_metadata: M,
        need_gather_q_kv: bool = False,
        output: torch.Tensor | None = None,
    ) -> torch.Tensor:
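        """MLA attention forward pass (descriptive docstring added for clarity).

        Args:
            layer_name: Name used to look up this layer for KV-connector load/save.
            hidden_states: Input hidden states (the query in unified attention),
                typically shaped [num_tokens, hidden_size].
            kv_cache: Per-layer KV cache tensors for this attention layer.
            attn_metadata: Attention metadata for this step; ``None`` indicates a profiling run.
            need_gather_q_kv: Whether q/kv activations must be all-gathered (Flash Comm V1).
            output: Pre-allocated output buffer; must be provided and is written in place.

        Returns:
            The (possibly graph-padded) output tensor.
        """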
        assert output is not None, "Output tensor must be provided."
        if attn_metadata is None:
            # Profiling run.
            for layer in self.layer_sharding_kwargs or []:
                if is_hidden_layer(layer):
                    reach_layer_for_shard_weight_series(layer)
            return output.fill_(0)

        forward_context = get_forward_context()
        num_actual_tokens = self.get_num_actual_tokens(attn_metadata)
        assert (
            attn_metadata.num_decodes is not None
            and attn_metadata.num_prefills is not None
            and attn_metadata.num_decode_tokens is not None
        )

        has_prefill = attn_metadata.num_prefills > 0
        num_decode_tokens = attn_metadata.num_decode_tokens
        # Inputs and outputs may be padded for CUDA graphs
        output_padded = output
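        # Descriptive note: the o_proj input buffer is sized to forward_context.num_tokens (the
        # possibly graph-padded token count), so decode and prefill attention outputs can be
        # scattered into one buffer and projected by a single o_proj call.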
        o_proj_input_shape = (forward_context.num_tokens, self.num_heads * self.v_head_dim)
        o_proj_input = torch.empty(o_proj_input_shape, dtype=hidden_states.dtype, device=hidden_states.device)

        # MLA Preprocess
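        # Descriptive note: the fused MLA preprocess op (MLAPO) path is taken only when the
        # decode token count fits within MLAPO_MAX_SUPPORTED_TOKENS; otherwise the generic
        # _mla_preprocess path below is used.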
        if self.enable_mlapo and attn_metadata.num_decode_tokens <= MLAPO_MAX_SUPPORTED_TOKENS:
            hidden_states = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
                hidden_states.contiguous(), need_gather_q_kv
            )
            decode_preprocess_res, prefill_preprocess_res = self._mla_preprocess_only_decode(
                hidden_states, kv_cache, attn_metadata
            )
        else:
            decode_preprocess_res, prefill_preprocess_res = self._mla_preprocess(
                layer_name, hidden_states, kv_cache, attn_metadata, need_gather_q_kv
            )
        if decode_preprocess_res is not None:
            # MLA attention for decode tokens (using the preprocessed decode results)
            output_decode = self._forward_decode(
                decode_preprocess_res.ql_nope,
                decode_preprocess_res.q_pe,
                decode_preprocess_res.k_nope,
                decode_preprocess_res.k_pe,
                kv_cache[0].shape[1],
                attn_metadata,
            )

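            # Decode tokens come first in the token layout, so their outputs fill the leading
            # num_decode_tokens rows; prefill outputs are written after them below.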
            o_proj_input[:num_decode_tokens] = output_decode

        if prefill_preprocess_res is not None:
            # FIX: in dbo, the aicore move should also be placed on the comm stream,
            # otherwise it may affect accuracy.
            # TODO: use a more elegant way to overlap.
            output_prefill = self._forward_prefill(
                prefill_preprocess_res.q_nope,
                prefill_preprocess_res.q_pe,
                prefill_preprocess_res.k_nope,
                prefill_preprocess_res.k_pe,
                prefill_preprocess_res.value,
                kv_cache,
                attn_metadata,
            )

            o_proj_input[num_decode_tokens:num_actual_tokens] = output_prefill
        # O proj
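        # Descriptive note: best-effort prefetch of the o_proj weight, bounded by
        # MAX_O_PROJ_PREFETCH_SIZE, so the weight is resident before the projection below;
        # the "maybe_" prefix suggests this is a no-op when weight prefetching is disabled.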
        weight_prefetch_method = get_weight_prefetch_method()
        weight_prefetch_method.maybe_prefetch_mla_or_sla_weight_in_current_stream(
            inputs=self.o_proj.weight,
            dependency=o_proj_input,
            max_size=MAX_O_PROJ_PREFETCH_SIZE,
            linear_layer=self.o_proj,
        )
        output[...] = self.o_proj(o_proj_input, is_prefill=prefill_preprocess_res is not None)[0]

        del o_proj_input

        if has_prefill:
            maybe_save_kv_layer_to_connector(layer_name, list(kv_cache))
        return output_padded