[Feature] Support Flash Comm V1 for VL models (with MLA) (#7390)

## Summary

Flash Comm V1 (flashcomm1) was previously blocked for all VL models.

**Root cause:** For VL models, `inputs_embeds` at layer 0 originates
from the vision encoder as a full `[N, H]` tensor — it has **not** been
reduce-scattered across TP ranks. The original MLA forward path assumed
inputs were already scattered, producing wrong output shapes under TP >
1.

**Fix:**
- Detect at init time (statically, not via runtime shape checks) whether
a layer is the first layer of a VL model (`is_vl_first_layer`) so dynamo
treats the branch as a constant.
- In `AscendMultiHeadLatentAttention.forward`, when `flashcomm1 + TP > 1
+ is_vl_first_layer`, set `need_gather_q_kv=False` and pre-allocate
output as `[N//tp_size, H]`.
- Remove the platform-level assertion that prevented VL models from
enabling Flash Comm V1.

**Other improvements:**
- `is_vl_model()` now uses vllm's canonical detection (`hf_config is not
hf_text_config`) instead of fragile key-name checks, with the old checks
kept as fallback.
- Added `parse_layer_idx(prefix)` utility.
- Added `maybe_chunk_residual` call in `AscendRMSNorm` before the
add-rms-norm op.
- Removed unnecessary CPU/fp32 round-trip in
  `AscendLearnable2DInterpPosEmbDivided_fixed.forward()`.
- vLLM version: v0.17.0
- vLLM main:
4034c3d32e

---------

Signed-off-by: SlightwindSec <slightwindsec@gmail.com>
Co-authored-by: LoganJane <loganJane73@hotmail.com>
This commit is contained in:
Cao Yi
2026-03-22 21:05:28 +08:00
committed by GitHub
parent 9d0b7c8e98
commit 9e2965bae2
4 changed files with 49 additions and 17 deletions

View File

@@ -33,6 +33,7 @@ from vllm.v1.attention.backend import AttentionMetadata # type: ignore
from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.ascend_forward_context import _EXTRA_CTX
from vllm_ascend.utils import is_vl_model, parse_layer_idx
class IndexerWrapper(nn.Module):
@@ -134,7 +135,16 @@ class AscendMultiHeadLatentAttention(MultiHeadLatentAttentionWrapper):
self.mla_attn.process_weights_after_loading = wrapped_process_weights
compilation_config = get_current_vllm_config().compilation_config
# For VL models (e.g. Kimi K2.5), inputs_embeds at layer 0 comes from
# the vision encoder as full [N, H] — it has NOT been reduce-scattered.
# We detect this statically at init time (not at runtime via shape checks,
# which break graph-mode compilation) so the branch is a constant to dynamo.
vllm_config = get_current_vllm_config()
_is_vl = is_vl_model(vllm_config)
_layer_idx = parse_layer_idx(prefix)
self.is_vl_first_layer = bool(_is_vl and _layer_idx == 0)
compilation_config = vllm_config.compilation_config
if prefix in compilation_config.static_forward_context:
raise ValueError(f"Duplicate layer name: {prefix}")
compilation_config.static_forward_context[prefix] = self
@@ -146,12 +156,18 @@ class AscendMultiHeadLatentAttention(MultiHeadLatentAttentionWrapper):
kv_cache: torch.Tensor | None = None,
attn_metadata: AttentionMetadata | None = None,
) -> torch.Tensor:
need_gather_q_kv = _EXTRA_CTX.flash_comm_v1_enabled
output_shape = hidden_states.shape
# FIXME: This does not seem right, should make sure the buffer is fixed
output = torch.empty(output_shape, dtype=hidden_states.dtype, device=hidden_states.device)
hidden_dim = hidden_states.shape[-1]
if _EXTRA_CTX.flash_comm_v1_enabled and self.tp_size > 1 and self.is_vl_first_layer:
need_gather_q_kv = False
n_out = hidden_states.shape[0] // self.tp_size
output = torch.empty((n_out, hidden_dim), dtype=hidden_states.dtype, device=hidden_states.device)
else:
need_gather_q_kv = _EXTRA_CTX.flash_comm_v1_enabled
output = torch.empty(hidden_states.shape, dtype=hidden_states.dtype, device=hidden_states.device)
torch.ops.vllm.mla_forward(hidden_states, need_gather_q_kv, output, self.prefix)
output = output.view(-1, output_shape[-1])
output = output.view(-1, hidden_dim)
return output