[Feature] Support Flash Comm V1 for VL models (with MLA) (#7390)
## Summary
Flash Comm V1 (flashcomm1) was previously blocked for all VL models.
**Root cause:** For VL models, `inputs_embeds` at layer 0 originates
from the vision encoder as a full `[N, H]` tensor — it has **not** been
reduce-scattered across TP ranks. The original MLA forward path assumed
inputs were already scattered, producing wrong output shapes under TP >
1.
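To make the shape mismatch concrete, here is a toy illustration of the two layouts (the numbers are arbitrary and this is not code from the patch):

```python
import torch

N, H, tp_size = 8, 4096, 4  # arbitrary token count, hidden size, TP degree

# Text-only path under flashcomm1: layer 0 already sees the
# reduce-scattered slice of the embeddings on each rank.
scattered = torch.empty(N // tp_size, H)   # [2, 4096] per rank

# VL path: layer 0 instead receives the full vision-encoder output.
full = torch.empty(N, H)                   # [8, 4096] on every rank

# The old MLA forward assumed the scattered layout, so feeding `full`
# under TP > 1 produced outputs with the wrong leading dimension.
assert full.shape[0] == scattered.shape[0] * tp_size
```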
**Fix:**
- Detect at init time (statically, not via runtime shape checks) whether
a layer is the first layer of a VL model (`is_vl_first_layer`) so dynamo
treats the branch as a constant.
- In `AscendMultiHeadLatentAttention.forward`, when `flashcomm1 + TP > 1
+ is_vl_first_layer`, set `need_gather_q_kv=False` and pre-allocate the
output as `[N//tp_size, H]` (see the sketch after this list).
- Remove the platform-level assertion that prevented VL models from
enabling Flash Comm V1.
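A minimal sketch of that branch. Everything except `need_gather_q_kv`, `is_vl_first_layer`, and the shapes is illustrative; the real `AscendMultiHeadLatentAttention.forward` has a different signature and fills the buffer via the attention kernels:

```python
import torch

def choose_mla_output_buffer(hidden_states: torch.Tensor, tp_size: int,
                             flashcomm1_enabled: bool,
                             is_vl_first_layer: bool):
    """Hypothetical helper showing only the gather/buffer decision."""
    num_tokens, hidden_size = hidden_states.shape
    if flashcomm1_enabled and tp_size > 1 and is_vl_first_layer:
        # Layer 0 of a VL model: `hidden_states` is the full [N, H] embedding
        # straight from the vision encoder, so no q/kv gather is needed and
        # the attention output is pre-allocated at the per-rank size.
        need_gather_q_kv = False
        output = hidden_states.new_empty(num_tokens // tp_size, hidden_size)
    else:
        # Every other layer keeps the existing flashcomm1 behaviour
        # (inputs already reduce-scattered across TP ranks); the sizes
        # shown here are only for contrast.
        need_gather_q_kv = True
        output = hidden_states.new_empty(num_tokens, hidden_size)
    return need_gather_q_kv, output
```

Because `is_vl_first_layer` is computed once at init, dynamo sees this branch as a constant instead of a data-dependent check.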
**Other improvements:**
- `is_vl_model()` now uses vllm's canonical detection (`hf_config is not
hf_text_config`) instead of fragile key-name checks, with the old checks
kept as fallback.
- Added `parse_layer_idx(prefix)` utility.
- Added a `maybe_chunk_residual` call in `AscendRMSNorm` before the
add-rms-norm op (a rough sketch of the idea follows this list).
- Removed unnecessary CPU/fp32 round-trip in
`AscendLearnable2DInterpPosEmbDivided_fixed.forward()`.
- vLLM version: v0.17.0
- vLLM main: 4034c3d32e
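The `maybe_chunk_residual` change keeps the residual's leading dimension in sync with hidden states that flashcomm1 has already reduce-scattered. The sketch below shows only the idea; the signature, rank handling, and placement are assumptions, not the actual vllm-ascend helper:

```python
import torch

def maybe_chunk_residual_sketch(x: torch.Tensor, residual: torch.Tensor,
                                tp_rank: int) -> torch.Tensor:
    """Assumed behaviour: if `x` has been reduce-scattered to
    [N // tp_size, H] while `residual` is still the full [N, H] tensor,
    take this rank's chunk so the fused add-rms-norm sees matching shapes."""
    if residual.size(0) != x.size(0) and residual.size(0) % x.size(0) == 0:
        tp_size = residual.size(0) // x.size(0)
        residual = residual.chunk(tp_size, dim=0)[tp_rank]
    return residual
```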
---------
Signed-off-by: SlightwindSec <slightwindsec@gmail.com>
Co-authored-by: LoganJane <loganJane73@hotmail.com>
@@ -23,6 +23,7 @@ import atexit
 import functools
 import math
 import os
+import re
 from contextlib import nullcontext
 from enum import Enum
 from functools import lru_cache
@@ -842,15 +843,29 @@ def _is_contain_expert(config: Any):
 
 
 def is_vl_model(vllm_config: VllmConfig):
-    """Checks if the model is a VL model by config"""
+    """Checks if the model is a VL model by config.
+
+    Uses the same criterion as vllm itself (model_config.py): a model is
+    multimodal when its top-level hf_config differs from its hf_text_config
+    (i.e. there is a separate vision sub-config). The legacy key-name checks
+    are kept as fallbacks for configs that override get_text_config() to return
+    self (rare but possible).
+    """
     global _IS_VL_MODEL
     if _IS_VL_MODEL is None and vllm_config and vllm_config.model_config:
-        hf_config = vllm_config.model_config.hf_config.to_dict()
-        if "thinker_config" in hf_config:
-            # Qwen-Omni-thinker models
+        model_config = vllm_config.model_config
+        # Primary: vllm's own VL detection — hf_config is the top-level
+        # (multimodal) config; hf_text_config is the language-model sub-config.
+        # They are the same object for pure-text models.
+        if model_config.hf_config is not model_config.hf_text_config:
             _IS_VL_MODEL = True
         else:
-            _IS_VL_MODEL = "vision_config" in hf_config
+            # Fallback: check well-known config keys
+            hf_config = model_config.hf_config.to_dict()
+            if "thinker_config" in hf_config or "vision_config" in hf_config:
+                _IS_VL_MODEL = True
+            else:
+                _IS_VL_MODEL = False
     return _IS_VL_MODEL
 
 
@@ -1244,3 +1259,9 @@ def trans_nd_to_nz(cache_tensor: torch.Tensor):
     cache_tensor = cache_tensor.reshape(nz_shape[:-4] + [m1, m0, n1, n0])
     cache_tensor = cache_tensor.permute(*array_trans)
     return cache_tensor
+
+
+def parse_layer_idx(prefix: str) -> int | None:
+    """Extract the layer index from a module prefix string like 'model.layers.0.self_attn'."""
+    match = re.search(r"layers\.(\d+)", prefix)
+    return int(match.group(1)) if match else None
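For example, `parse_layer_idx("model.layers.12.self_attn")` returns `12`, while a prefix with no `layers.<idx>` segment returns `None`; this supports resolving `is_vl_first_layer` once at module init rather than checking shapes at runtime.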