From 9e2965bae2403d1b9a18fdcd03d94704e1f9f01c Mon Sep 17 00:00:00 2001
From: Cao Yi
Date: Sun, 22 Mar 2026 21:05:28 +0800
Subject: [PATCH] [Feature] Support Flash Comm V1 for VL models (with MLA) (#7390)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

## Summary

Flash Comm V1 (flashcomm1) was previously blocked for all VL models.

**Root cause:** For VL models, `inputs_embeds` at layer 0 originates from the
vision encoder as a full `[N, H]` tensor — it has **not** been reduce-scattered
across TP ranks. The original MLA forward path assumed inputs were already
scattered, producing wrong output shapes under TP > 1.

**Fix:**
- Detect at init time (statically, not via runtime shape checks) whether a
  layer is the first layer of a VL model (`is_vl_first_layer`) so dynamo treats
  the branch as a constant.
- In `AscendMultiHeadLatentAttention.forward`, when
  `flashcomm1 + TP > 1 + is_vl_first_layer`, set `need_gather_q_kv=False` and
  pre-allocate output as `[N//tp_size, H]`.
- Remove the platform-level assertion that prevented VL models from enabling
  Flash Comm V1.

**Other improvements:**
- `is_vl_model()` now uses vllm's canonical detection
  (`hf_config is not hf_text_config`) instead of fragile key-name checks, with
  the old checks kept as fallback.
- Added `parse_layer_idx(prefix)` utility.
- Added `maybe_chunk_residual` call in `AscendRMSNorm` before the add-rms-norm op.
- Removed unnecessary CPU/fp32 round-trip in
  `AscendLearnable2DInterpPosEmbDivided_fixed.forward()`.

- vLLM version: v0.17.0
- vLLM main: https://github.com/vllm-project/vllm/commit/4034c3d32e30d01639459edd3ab486f56993876d

---------

Signed-off-by: SlightwindSec
Co-authored-by: LoganJane
---
 vllm_ascend/ops/layernorm.py |  1 +
 vllm_ascend/ops/mla.py       | 28 ++++++++++++++++++++++------
 vllm_ascend/platform.py      |  6 ------
 vllm_ascend/utils.py         | 31 ++++++++++++++++++++++++++-----
 4 files changed, 49 insertions(+), 17 deletions(-)

diff --git a/vllm_ascend/ops/layernorm.py b/vllm_ascend/ops/layernorm.py
index d998ddab..71ba7e73 100644
--- a/vllm_ascend/ops/layernorm.py
+++ b/vllm_ascend/ops/layernorm.py
@@ -68,6 +68,7 @@ class AscendRMSNorm(RMSNorm):
         import torch_npu
 
         if residual is not None:
+            residual = torch.ops.vllm.maybe_chunk_residual(x, residual)
             if enable_custom_op():
                 x, _, residual = torch.ops._C_ascend.npu_add_rms_norm_bias(
                     x, residual, self.weight, self.bias, self.variance_epsilon
diff --git a/vllm_ascend/ops/mla.py b/vllm_ascend/ops/mla.py
index de2e2fa4..4da2507e 100644
--- a/vllm_ascend/ops/mla.py
+++ b/vllm_ascend/ops/mla.py
@@ -33,6 +33,7 @@ from vllm.v1.attention.backend import AttentionMetadata  # type: ignore
 
 from vllm_ascend.ascend_config import get_ascend_config
 from vllm_ascend.ascend_forward_context import _EXTRA_CTX
+from vllm_ascend.utils import is_vl_model, parse_layer_idx
 
 
 class IndexerWrapper(nn.Module):
@@ -134,7 +135,16 @@ class AscendMultiHeadLatentAttention(MultiHeadLatentAttentionWrapper):
 
         self.mla_attn.process_weights_after_loading = wrapped_process_weights
 
-        compilation_config = get_current_vllm_config().compilation_config
+        # For VL models (e.g. Kimi K2.5), inputs_embeds at layer 0 comes from
+        # the vision encoder as full [N, H] — it has NOT been reduce-scattered.
+        # We detect this statically at init time (not at runtime via shape checks,
+        # which break graph-mode compilation) so the branch is a constant to dynamo.
+        vllm_config = get_current_vllm_config()
+        _is_vl = is_vl_model(vllm_config)
+        _layer_idx = parse_layer_idx(prefix)
+        self.is_vl_first_layer = bool(_is_vl and _layer_idx == 0)
+
+        compilation_config = vllm_config.compilation_config
         if prefix in compilation_config.static_forward_context:
             raise ValueError(f"Duplicate layer name: {prefix}")
         compilation_config.static_forward_context[prefix] = self
@@ -146,12 +156,18 @@ class AscendMultiHeadLatentAttention(MultiHeadLatentAttentionWrapper):
         kv_cache: torch.Tensor | None = None,
         attn_metadata: AttentionMetadata | None = None,
     ) -> torch.Tensor:
-        need_gather_q_kv = _EXTRA_CTX.flash_comm_v1_enabled
-        output_shape = hidden_states.shape
-        # FIXME: This does not seem right, should make sure the buffer is fixed
-        output = torch.empty(output_shape, dtype=hidden_states.dtype, device=hidden_states.device)
+        hidden_dim = hidden_states.shape[-1]
+
+        if _EXTRA_CTX.flash_comm_v1_enabled and self.tp_size > 1 and self.is_vl_first_layer:
+            need_gather_q_kv = False
+            n_out = hidden_states.shape[0] // self.tp_size
+            output = torch.empty((n_out, hidden_dim), dtype=hidden_states.dtype, device=hidden_states.device)
+        else:
+            need_gather_q_kv = _EXTRA_CTX.flash_comm_v1_enabled
+            output = torch.empty(hidden_states.shape, dtype=hidden_states.dtype, device=hidden_states.device)
+
         torch.ops.vllm.mla_forward(hidden_states, need_gather_q_kv, output, self.prefix)
-        output = output.view(-1, output_shape[-1])
+        output = output.view(-1, hidden_dim)
         return output
 
 
diff --git a/vllm_ascend/platform.py b/vllm_ascend/platform.py
index 0dda6be2..bf460a83 100644
--- a/vllm_ascend/platform.py
+++ b/vllm_ascend/platform.py
@@ -42,7 +42,6 @@ from vllm_ascend.utils import (
     flashcomm2_enable,
     get_ascend_device_type,
     is_moe_model,
-    is_vl_model,
     refresh_block_size,
     update_aclgraph_sizes,
     update_cudagraph_capture_sizes,
@@ -420,11 +419,6 @@ class NPUPlatform(Platform):
             vllm_config.parallel_config.cp_kv_cache_interleave_size = cache_config.block_size
 
         if enable_sp(vllm_config):
-            assert not is_vl_model(vllm_config), """Flash Comm V1 is not supported for VL models. \
-                Please disable it by setting VLLM_ASCEND_ENABLE_FLASHCOMM1=0. \
-                For optimal performance with VL models, we recommend enabling Sequence Parallelism \
-                via --compilation-config '{"pass_config": {"enable_sp": true}}'."""
-
             assert vllm_config.parallel_config.tensor_parallel_size > 1, (
                 "Flash Comm v1 is only supported when tp_size > 1."
             )
diff --git a/vllm_ascend/utils.py b/vllm_ascend/utils.py
index 1d341899..0f310e36 100644
--- a/vllm_ascend/utils.py
+++ b/vllm_ascend/utils.py
@@ -23,6 +23,7 @@ import atexit
 import functools
 import math
 import os
+import re
 from contextlib import nullcontext
 from enum import Enum
 from functools import lru_cache
@@ -842,15 +843,29 @@ def _is_contain_expert(config: Any):
 
 
 def is_vl_model(vllm_config: VllmConfig):
-    """Checks if the model is a VL model by config"""
+    """Checks if the model is a VL model by config.
+
+    Uses the same criterion as vllm itself (model_config.py): a model is
+    multimodal when its top-level hf_config differs from its hf_text_config
+    (i.e. there is a separate vision sub-config). The legacy key-name checks
+    are kept as fallbacks for configs that override get_text_config() to return
+    self (rare but possible).
+    """
     global _IS_VL_MODEL
     if _IS_VL_MODEL is None and vllm_config and vllm_config.model_config:
-        hf_config = vllm_config.model_config.hf_config.to_dict()
-        if "thinker_config" in hf_config:
-            # Qwen-Omni-thinker models
+        model_config = vllm_config.model_config
+        # Primary: vllm's own VL detection — hf_config is the top-level
+        # (multimodal) config; hf_text_config is the language-model sub-config.
+        # They are the same object for pure-text models.
+        if model_config.hf_config is not model_config.hf_text_config:
             _IS_VL_MODEL = True
         else:
-            _IS_VL_MODEL = "vision_config" in hf_config
+            # Fallback: check well-known config keys
+            hf_config = model_config.hf_config.to_dict()
+            if "thinker_config" in hf_config or "vision_config" in hf_config:
+                _IS_VL_MODEL = True
+            else:
+                _IS_VL_MODEL = False
     return _IS_VL_MODEL
 
 
@@ -1244,3 +1259,9 @@ def trans_nd_to_nz(cache_tensor: torch.Tensor):
     cache_tensor = cache_tensor.reshape(nz_shape[:-4] + [m1, m0, n1, n0])
     cache_tensor = cache_tensor.permute(*array_trans)
     return cache_tensor
+
+
+def parse_layer_idx(prefix: str) -> int | None:
+    """Extract the layer index from a module prefix string like 'model.layers.0.self_attn'."""
+    match = re.search(r"layers\.(\d+)", prefix)
+    return int(match.group(1)) if match else None
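
For reviewers, a minimal standalone sketch of the buffer-allocation rule introduced in `AscendMultiHeadLatentAttention.forward`. The helper name `mla_output_shape` and the concrete numbers (32 tokens, hidden size 7168, TP = 4) are illustrative assumptions, not part of this patch; the branch condition mirrors the `flashcomm1 + TP > 1 + is_vl_first_layer` case described above.

```python
# Illustrative sketch only — mla_output_shape is a made-up helper mirroring the
# shape logic of the new forward branch; it is not an API added by this patch.

def mla_output_shape(n_tokens: int, hidden_dim: int, *,
                     flash_comm_v1: bool, tp_size: int,
                     is_vl_first_layer: bool) -> tuple[int, int]:
    if flash_comm_v1 and tp_size > 1 and is_vl_first_layer:
        # Layer 0 of a VL model still sees the full [N, H] inputs_embeds from
        # the vision encoder; the MLA op scatters internally, so the
        # pre-allocated output buffer must be [N // tp_size, H].
        return (n_tokens // tp_size, hidden_dim)
    # Under flashcomm1, every other layer already receives the reduce-scattered
    # residual, so the output keeps the incoming shape.
    return (n_tokens, hidden_dim)


# Example numbers: 32 tokens, hidden size 7168, TP = 4.
assert mla_output_shape(32, 7168, flash_comm_v1=True, tp_size=4,
                        is_vl_first_layer=True) == (8, 7168)
assert mla_output_shape(8, 7168, flash_comm_v1=True, tp_size=4,
                        is_vl_first_layer=False) == (8, 7168)
```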