[Feature] Support Flash Comm V1 for VL models (with MLA) (#7390)
## Summary
Flash Comm V1 (flashcomm1) was previously blocked for all VL models.
**Root cause:** For VL models, `inputs_embeds` at layer 0 originates
from the vision encoder as a full `[N, H]` tensor — it has **not** been
reduce-scattered across TP ranks. The original MLA forward path assumed
inputs were already scattered, producing wrong output shapes under TP >
1.
**Fix:**
- Detect at init time (statically, not via runtime shape checks) whether
a layer is the first layer of a VL model (`is_vl_first_layer`) so dynamo
treats the branch as a constant.
- In `AscendMultiHeadLatentAttention.forward`, when `flashcomm1 + TP > 1
+ is_vl_first_layer`, set `need_gather_q_kv=False` and pre-allocate
output as `[N//tp_size, H]`.
- Remove the platform-level assertion that prevented VL models from
enabling Flash Comm V1.
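The branch logic above can be sketched as a small shape/communication planner. This is an illustrative sketch only: `mla_forward_plan` and its parameters are hypothetical names, not the actual `AscendMultiHeadLatentAttention` API, and the non-VL branch shapes are assumptions inferred from the description.

```python
def mla_forward_plan(n_rows: int, hidden: int, tp_size: int,
                     flashcomm1: bool, is_vl_first_layer: bool):
    """Return (need_gather_q_kv, out_shape) for the MLA forward pass.

    Hypothetical sketch. Key case from the fix: a VL model's first layer
    receives the full [N, H] vision-encoder output (NOT yet
    reduce-scattered), so q/kv must not be gathered again and the output
    is pre-allocated as [N // tp_size, H].
    """
    if flashcomm1 and tp_size > 1:
        if is_vl_first_layer:
            # Input is the full-length tensor; skip the gather and
            # allocate the scattered (per-rank) output shape.
            return False, (n_rows // tp_size, hidden)
        # Assumed default path: later layers already hold scattered
        # [N // tp, H] activations and gather q/kv before attention;
        # the output stays at the scattered per-rank length.
        return True, (n_rows, hidden)
    # Flash Comm V1 disabled (or tp == 1): plain [N, H] path.
    return False, (n_rows, hidden)
```

Because `is_vl_first_layer` is computed once at init time, dynamo sees the branch condition as a constant and the compiled graph contains only one path.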
**Other improvements:**
- `is_vl_model()` now uses vllm's canonical detection (`hf_config is not
hf_text_config`) instead of fragile key-name checks, with the old checks
kept as fallback.
- Added `parse_layer_idx(prefix)` utility.
- Added `maybe_chunk_residual` call in `AscendRMSNorm` before the
add-rms-norm op.
- Removed unnecessary CPU/fp32 round-trip in
`AscendLearnable2DInterpPosEmbDivided_fixed.forward()`.
- vLLM version: v0.17.0
- vLLM main: 4034c3d32e
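The two helpers mentioned above can be sketched as follows. These are hypothetical minimal versions for illustration; the real vllm-ascend implementations (signatures included) may differ, and the fallback key-name checks kept by the actual `is_vl_model()` are omitted here.

```python
import re


def is_vl_model(hf_config, hf_text_config) -> bool:
    # vLLM's canonical multimodal check: a VL model wraps a separate text
    # config inside its top-level config, so the two objects differ; a
    # text-only model reuses the same object for both.
    return hf_config is not hf_text_config


def parse_layer_idx(prefix: str) -> int:
    # Extract the decoder-layer index from a module prefix such as
    # "model.layers.3.self_attn".
    match = re.search(r"\.layers\.(\d+)", prefix)
    if match is None:
        raise ValueError(f"no layer index found in prefix: {prefix!r}")
    return int(match.group(1))
```

An identity check on config objects is more robust than inspecting key names, which break whenever a new model family renames its config fields.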
---------
Signed-off-by: SlightwindSec <slightwindsec@gmail.com>
Co-authored-by: LoganJane <loganJane73@hotmail.com>
```diff
@@ -42,7 +42,6 @@ from vllm_ascend.utils import (
     flashcomm2_enable,
     get_ascend_device_type,
     is_moe_model,
-    is_vl_model,
     refresh_block_size,
     update_aclgraph_sizes,
     update_cudagraph_capture_sizes,
```
```diff
@@ -420,11 +419,6 @@ class NPUPlatform(Platform):
         vllm_config.parallel_config.cp_kv_cache_interleave_size = cache_config.block_size

         if enable_sp(vllm_config):
-            assert not is_vl_model(vllm_config), """Flash Comm V1 is not supported for VL models. \
-Please disable it by setting VLLM_ASCEND_ENABLE_FLASHCOMM1=0. \
-For optimal performance with VL models, we recommend enabling Sequence Parallelism \
-via --compilation-config '{"pass_config": {"enable_sp": true}}'."""
-
             assert vllm_config.parallel_config.tensor_parallel_size > 1, (
                 "Flash Comm v1 is only supported when tp_size > 1."
             )
```