[Feature] Support Flash Comm V1 for VL models (with MLA) (#7390)

## Summary

Flash Comm V1 (flashcomm1) was previously blocked for all VL models.

**Root cause:** For VL models, `inputs_embeds` at layer 0 originates
from the vision encoder as a full `[N, H]` tensor; it has **not** been
reduce-scattered across TP ranks. The original MLA forward path assumed
inputs were already scattered, which produced wrong output shapes when
TP > 1.

**Fix:**
- Detect at init time (statically, not via runtime shape checks) whether
a layer is the first layer of a VL model (`is_vl_first_layer`), so that
Dynamo treats the branch as a constant.
- In `AscendMultiHeadLatentAttention.forward`, when flashcomm1 is
enabled, TP > 1, and `is_vl_first_layer` is set, use
`need_gather_q_kv=False` and pre-allocate the output as
`[N//tp_size, H]`.
- Remove the platform-level assertion that prevented VL models from
enabling Flash Comm V1.
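
The branch condition and output shape described in the fix can be sketched in plain Python. This is an illustration only, not the code in `AscendMultiHeadLatentAttention`: `LayerFlags`, `takes_vl_first_layer_path`, and `scattered_output_shape` are hypothetical names, and the sketch assumes `N` is divisible by `tp_size`.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class LayerFlags:
    """Hypothetical bundle of values fixed when the layer is constructed."""

    flashcomm1_enabled: bool
    tp_size: int
    is_vl_first_layer: bool


def takes_vl_first_layer_path(flags: LayerFlags) -> bool:
    """True when the layer receives a full, not yet scattered, [N, H] input."""
    return flags.flashcomm1_enabled and flags.tp_size > 1 and flags.is_vl_first_layer


def scattered_output_shape(num_tokens: int, hidden_size: int, tp_size: int) -> tuple[int, int]:
    """Per-rank output shape [N // tp_size, H], as described in the fix."""
    return (num_tokens // tp_size, hidden_size)


flags = LayerFlags(flashcomm1_enabled=True, tp_size=4, is_vl_first_layer=True)
if takes_vl_first_layer_path(flags):
    # Set to False per the fix: the input is the full [N, H] tensor.
    need_gather_q_kv = False
    output_shape = scattered_output_shape(num_tokens=8, hidden_size=16, tp_size=flags.tp_size)
    print(need_gather_q_kv, output_shape)
```

Because every field of `LayerFlags` is known at construction time, the condition never depends on a tensor shape seen at runtime, which is the property the fix relies on for tracing.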

**Other improvements:**
- `is_vl_model()` now uses vLLM's canonical detection (`hf_config is not
hf_text_config`) instead of fragile key-name checks; the old checks are
kept as a fallback.
- Added `parse_layer_idx(prefix)` utility.
- Added `maybe_chunk_residual` call in `AscendRMSNorm` before the
add-rms-norm op.
- Removed unnecessary CPU/fp32 round-trip in
  `AscendLearnable2DInterpPosEmbDivided_fixed.forward()`.
- vLLM version: v0.17.0
- vLLM main: 4034c3d32e

---------

Signed-off-by: SlightwindSec <slightwindsec@gmail.com>
Co-authored-by: LoganJane <loganJane73@hotmail.com>
This commit is contained in:
Cao Yi
2026-03-22 21:05:28 +08:00
committed by GitHub
parent 9d0b7c8e98
commit 9e2965bae2
4 changed files with 49 additions and 17 deletions


@@ -42,7 +42,6 @@ from vllm_ascend.utils import (
flashcomm2_enable,
get_ascend_device_type,
is_moe_model,
-is_vl_model,
refresh_block_size,
update_aclgraph_sizes,
update_cudagraph_capture_sizes,
@@ -420,11 +419,6 @@ class NPUPlatform(Platform):
vllm_config.parallel_config.cp_kv_cache_interleave_size = cache_config.block_size
if enable_sp(vllm_config):
-assert not is_vl_model(vllm_config), """Flash Comm V1 is not supported for VL models. \
-Please disable it by setting VLLM_ASCEND_ENABLE_FLASHCOMM1=0. \
-For optimal performance with VL models, we recommend enabling Sequence Parallelism \
-via --compilation-config '{"pass_config": {"enable_sp": true}}'."""
assert vllm_config.parallel_config.tensor_parallel_size > 1, (
"Flash Comm v1 is only supported when tp_size > 1."
)