[Feature] Support Flash Comm V1 for VL models (with MLA) (#7390)

## Summary

Flash Comm V1 (flashcomm1) was previously blocked for all VL models.

**Root cause:** For VL models, `inputs_embeds` at layer 0 originates
from the vision encoder as a full `[N, H]` tensor — it has **not** been
reduce-scattered across TP ranks. The original MLA forward path assumed
inputs were already scattered, producing wrong output shapes under TP >
1.

**Fix:**
- Detect at init time (statically, not via runtime shape checks) whether
a layer is the first layer of a VL model (`is_vl_first_layer`) so dynamo
treats the branch as a constant.
- In `AscendMultiHeadLatentAttention.forward`, when Flash Comm V1 is enabled,
TP > 1, and `is_vl_first_layer` is true, set `need_gather_q_kv=False` and
pre-allocate the output as `[N//tp_size, H]`.
- Remove the platform-level assertion that prevented VL models from
enabling Flash Comm V1.

**Other improvements:**
- `is_vl_model()` now uses vllm's canonical detection (`hf_config is not
hf_text_config`) instead of fragile key-name checks, with the old checks
kept as fallback.
- Added `parse_layer_idx(prefix)` utility.
- Added `maybe_chunk_residual` call in `AscendRMSNorm` before the
add-rms-norm op.
- Removed unnecessary CPU/fp32 round-trip in
  `AscendLearnable2DInterpPosEmbDivided_fixed.forward()`.
- vLLM version: v0.17.0
- vLLM main:
4034c3d32e

---------

Signed-off-by: SlightwindSec <slightwindsec@gmail.com>
Co-authored-by: LoganJane <loganJane73@hotmail.com>
This commit is contained in:
Cao Yi
2026-03-22 21:05:28 +08:00
committed by GitHub
parent 9d0b7c8e98
commit 9e2965bae2
4 changed files with 49 additions and 17 deletions

View File

@@ -68,6 +68,7 @@ class AscendRMSNorm(RMSNorm):
import torch_npu import torch_npu
if residual is not None: if residual is not None:
residual = torch.ops.vllm.maybe_chunk_residual(x, residual)
if enable_custom_op(): if enable_custom_op():
x, _, residual = torch.ops._C_ascend.npu_add_rms_norm_bias( x, _, residual = torch.ops._C_ascend.npu_add_rms_norm_bias(
x, residual, self.weight, self.bias, self.variance_epsilon x, residual, self.weight, self.bias, self.variance_epsilon

View File

@@ -33,6 +33,7 @@ from vllm.v1.attention.backend import AttentionMetadata # type: ignore
from vllm_ascend.ascend_config import get_ascend_config from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.ascend_forward_context import _EXTRA_CTX from vllm_ascend.ascend_forward_context import _EXTRA_CTX
from vllm_ascend.utils import is_vl_model, parse_layer_idx
class IndexerWrapper(nn.Module): class IndexerWrapper(nn.Module):
@@ -134,7 +135,16 @@ class AscendMultiHeadLatentAttention(MultiHeadLatentAttentionWrapper):
self.mla_attn.process_weights_after_loading = wrapped_process_weights self.mla_attn.process_weights_after_loading = wrapped_process_weights
compilation_config = get_current_vllm_config().compilation_config # For VL models (e.g. Kimi K2.5), inputs_embeds at layer 0 comes from
# the vision encoder as full [N, H] — it has NOT been reduce-scattered.
# We detect this statically at init time (not at runtime via shape checks,
# which break graph-mode compilation) so the branch is a constant to dynamo.
vllm_config = get_current_vllm_config()
_is_vl = is_vl_model(vllm_config)
_layer_idx = parse_layer_idx(prefix)
self.is_vl_first_layer = bool(_is_vl and _layer_idx == 0)
compilation_config = vllm_config.compilation_config
if prefix in compilation_config.static_forward_context: if prefix in compilation_config.static_forward_context:
raise ValueError(f"Duplicate layer name: {prefix}") raise ValueError(f"Duplicate layer name: {prefix}")
compilation_config.static_forward_context[prefix] = self compilation_config.static_forward_context[prefix] = self
@@ -146,12 +156,18 @@ class AscendMultiHeadLatentAttention(MultiHeadLatentAttentionWrapper):
kv_cache: torch.Tensor | None = None, kv_cache: torch.Tensor | None = None,
attn_metadata: AttentionMetadata | None = None, attn_metadata: AttentionMetadata | None = None,
) -> torch.Tensor: ) -> torch.Tensor:
hidden_dim = hidden_states.shape[-1]
if _EXTRA_CTX.flash_comm_v1_enabled and self.tp_size > 1 and self.is_vl_first_layer:
need_gather_q_kv = False
n_out = hidden_states.shape[0] // self.tp_size
output = torch.empty((n_out, hidden_dim), dtype=hidden_states.dtype, device=hidden_states.device)
else:
need_gather_q_kv = _EXTRA_CTX.flash_comm_v1_enabled need_gather_q_kv = _EXTRA_CTX.flash_comm_v1_enabled
output_shape = hidden_states.shape output = torch.empty(hidden_states.shape, dtype=hidden_states.dtype, device=hidden_states.device)
# FIXME: This does not seem right, should make sure the buffer is fixed
output = torch.empty(output_shape, dtype=hidden_states.dtype, device=hidden_states.device)
torch.ops.vllm.mla_forward(hidden_states, need_gather_q_kv, output, self.prefix) torch.ops.vllm.mla_forward(hidden_states, need_gather_q_kv, output, self.prefix)
output = output.view(-1, output_shape[-1]) output = output.view(-1, hidden_dim)
return output return output

View File

@@ -42,7 +42,6 @@ from vllm_ascend.utils import (
flashcomm2_enable, flashcomm2_enable,
get_ascend_device_type, get_ascend_device_type,
is_moe_model, is_moe_model,
is_vl_model,
refresh_block_size, refresh_block_size,
update_aclgraph_sizes, update_aclgraph_sizes,
update_cudagraph_capture_sizes, update_cudagraph_capture_sizes,
@@ -420,11 +419,6 @@ class NPUPlatform(Platform):
vllm_config.parallel_config.cp_kv_cache_interleave_size = cache_config.block_size vllm_config.parallel_config.cp_kv_cache_interleave_size = cache_config.block_size
if enable_sp(vllm_config): if enable_sp(vllm_config):
assert not is_vl_model(vllm_config), """Flash Comm V1 is not supported for VL models. \
Please disable it by setting VLLM_ASCEND_ENABLE_FLASHCOMM1=0. \
For optimal performance with VL models, we recommend enabling Sequence Parallelism \
via --compilation-config '{"pass_config": {"enable_sp": true}}'."""
assert vllm_config.parallel_config.tensor_parallel_size > 1, ( assert vllm_config.parallel_config.tensor_parallel_size > 1, (
"Flash Comm v1 is only supported when tp_size > 1." "Flash Comm v1 is only supported when tp_size > 1."
) )

View File

@@ -23,6 +23,7 @@ import atexit
import functools import functools
import math import math
import os import os
import re
from contextlib import nullcontext from contextlib import nullcontext
from enum import Enum from enum import Enum
from functools import lru_cache from functools import lru_cache
@@ -842,15 +843,29 @@ def _is_contain_expert(config: Any):
def is_vl_model(vllm_config: VllmConfig): def is_vl_model(vllm_config: VllmConfig):
"""Checks if the model is a VL model by config""" """Checks if the model is a VL model by config.
Uses the same criterion as vllm itself (model_config.py): a model is
multimodal when its top-level hf_config differs from its hf_text_config
(i.e. there is a separate vision sub-config). The legacy key-name checks
are kept as fallbacks for configs that override get_text_config() to return
self (rare but possible).
"""
global _IS_VL_MODEL global _IS_VL_MODEL
if _IS_VL_MODEL is None and vllm_config and vllm_config.model_config: if _IS_VL_MODEL is None and vllm_config and vllm_config.model_config:
hf_config = vllm_config.model_config.hf_config.to_dict() model_config = vllm_config.model_config
if "thinker_config" in hf_config: # Primary: vllm's own VL detection — hf_config is the top-level
# Qwen-Omni-thinker models # (multimodal) config; hf_text_config is the language-model sub-config.
# They are the same object for pure-text models.
if model_config.hf_config is not model_config.hf_text_config:
_IS_VL_MODEL = True _IS_VL_MODEL = True
else: else:
_IS_VL_MODEL = "vision_config" in hf_config # Fallback: check well-known config keys
hf_config = model_config.hf_config.to_dict()
if "thinker_config" in hf_config or "vision_config" in hf_config:
_IS_VL_MODEL = True
else:
_IS_VL_MODEL = False
return _IS_VL_MODEL return _IS_VL_MODEL
@@ -1244,3 +1259,9 @@ def trans_nd_to_nz(cache_tensor: torch.Tensor):
cache_tensor = cache_tensor.reshape(nz_shape[:-4] + [m1, m0, n1, n0]) cache_tensor = cache_tensor.reshape(nz_shape[:-4] + [m1, m0, n1, n0])
cache_tensor = cache_tensor.permute(*array_trans) cache_tensor = cache_tensor.permute(*array_trans)
return cache_tensor return cache_tensor
def parse_layer_idx(prefix: str) -> int | None:
"""Extract the layer index from a module prefix string like 'model.layers.0.self_attn'."""
match = re.search(r"layers\.(\d+)", prefix)
return int(match.group(1)) if match else None