[Feature] Support Flash Comm V1 for VL models (with MLA) (#7390)
## Summary
Flash Comm V1 (flashcomm1) was previously blocked for all VL models.
**Root cause:** For VL models, `inputs_embeds` at layer 0 originates
from the vision encoder as a full `[N, H]` tensor — it has **not** been
reduce-scattered across TP ranks. The original MLA forward path assumed
inputs were already scattered, producing wrong output shapes under TP >
1.
**Fix:**
- Detect at init time (statically, not via runtime shape checks) whether
a layer is the first layer of a VL model (`is_vl_first_layer`) so dynamo
treats the branch as a constant.
- In `AscendMultiHeadLatentAttention.forward`, when `flashcomm1 + TP > 1
+ is_vl_first_layer`, set `need_gather_q_kv=False` and pre-allocate
output as `[N//tp_size, H]`.
- Remove the platform-level assertion that prevented VL models from
enabling Flash Comm V1.
**Other improvements:**
- `is_vl_model()` now uses vllm's canonical detection (`hf_config is not
hf_text_config`) instead of fragile key-name checks, with the old checks
kept as fallback.
- Added `parse_layer_idx(prefix)` utility.
- Added `maybe_chunk_residual` call in `AscendRMSNorm` before the
add-rms-norm op.
- Removed unnecessary CPU/fp32 round-trip in
`AscendLearnable2DInterpPosEmbDivided_fixed.forward()`.
- vLLM version: v0.17.0
- vLLM main:
4034c3d32e
---------
Signed-off-by: SlightwindSec <slightwindsec@gmail.com>
Co-authored-by: LoganJane <loganJane73@hotmail.com>
This commit is contained in:
@@ -68,6 +68,7 @@ class AscendRMSNorm(RMSNorm):
|
|||||||
import torch_npu
|
import torch_npu
|
||||||
|
|
||||||
if residual is not None:
|
if residual is not None:
|
||||||
|
residual = torch.ops.vllm.maybe_chunk_residual(x, residual)
|
||||||
if enable_custom_op():
|
if enable_custom_op():
|
||||||
x, _, residual = torch.ops._C_ascend.npu_add_rms_norm_bias(
|
x, _, residual = torch.ops._C_ascend.npu_add_rms_norm_bias(
|
||||||
x, residual, self.weight, self.bias, self.variance_epsilon
|
x, residual, self.weight, self.bias, self.variance_epsilon
|
||||||
|
|||||||
@@ -33,6 +33,7 @@ from vllm.v1.attention.backend import AttentionMetadata # type: ignore
|
|||||||
|
|
||||||
from vllm_ascend.ascend_config import get_ascend_config
|
from vllm_ascend.ascend_config import get_ascend_config
|
||||||
from vllm_ascend.ascend_forward_context import _EXTRA_CTX
|
from vllm_ascend.ascend_forward_context import _EXTRA_CTX
|
||||||
|
from vllm_ascend.utils import is_vl_model, parse_layer_idx
|
||||||
|
|
||||||
|
|
||||||
class IndexerWrapper(nn.Module):
|
class IndexerWrapper(nn.Module):
|
||||||
@@ -134,7 +135,16 @@ class AscendMultiHeadLatentAttention(MultiHeadLatentAttentionWrapper):
|
|||||||
|
|
||||||
self.mla_attn.process_weights_after_loading = wrapped_process_weights
|
self.mla_attn.process_weights_after_loading = wrapped_process_weights
|
||||||
|
|
||||||
compilation_config = get_current_vllm_config().compilation_config
|
# For VL models (e.g. Kimi K2.5), inputs_embeds at layer 0 comes from
|
||||||
|
# the vision encoder as full [N, H] — it has NOT been reduce-scattered.
|
||||||
|
# We detect this statically at init time (not at runtime via shape checks,
|
||||||
|
# which break graph-mode compilation) so the branch is a constant to dynamo.
|
||||||
|
vllm_config = get_current_vllm_config()
|
||||||
|
_is_vl = is_vl_model(vllm_config)
|
||||||
|
_layer_idx = parse_layer_idx(prefix)
|
||||||
|
self.is_vl_first_layer = bool(_is_vl and _layer_idx == 0)
|
||||||
|
|
||||||
|
compilation_config = vllm_config.compilation_config
|
||||||
if prefix in compilation_config.static_forward_context:
|
if prefix in compilation_config.static_forward_context:
|
||||||
raise ValueError(f"Duplicate layer name: {prefix}")
|
raise ValueError(f"Duplicate layer name: {prefix}")
|
||||||
compilation_config.static_forward_context[prefix] = self
|
compilation_config.static_forward_context[prefix] = self
|
||||||
@@ -146,12 +156,18 @@ class AscendMultiHeadLatentAttention(MultiHeadLatentAttentionWrapper):
|
|||||||
kv_cache: torch.Tensor | None = None,
|
kv_cache: torch.Tensor | None = None,
|
||||||
attn_metadata: AttentionMetadata | None = None,
|
attn_metadata: AttentionMetadata | None = None,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
|
hidden_dim = hidden_states.shape[-1]
|
||||||
|
|
||||||
|
if _EXTRA_CTX.flash_comm_v1_enabled and self.tp_size > 1 and self.is_vl_first_layer:
|
||||||
|
need_gather_q_kv = False
|
||||||
|
n_out = hidden_states.shape[0] // self.tp_size
|
||||||
|
output = torch.empty((n_out, hidden_dim), dtype=hidden_states.dtype, device=hidden_states.device)
|
||||||
|
else:
|
||||||
need_gather_q_kv = _EXTRA_CTX.flash_comm_v1_enabled
|
need_gather_q_kv = _EXTRA_CTX.flash_comm_v1_enabled
|
||||||
output_shape = hidden_states.shape
|
output = torch.empty(hidden_states.shape, dtype=hidden_states.dtype, device=hidden_states.device)
|
||||||
# FIXME: This does not seem right, should make sure the buffer is fixed
|
|
||||||
output = torch.empty(output_shape, dtype=hidden_states.dtype, device=hidden_states.device)
|
|
||||||
torch.ops.vllm.mla_forward(hidden_states, need_gather_q_kv, output, self.prefix)
|
torch.ops.vllm.mla_forward(hidden_states, need_gather_q_kv, output, self.prefix)
|
||||||
output = output.view(-1, output_shape[-1])
|
output = output.view(-1, hidden_dim)
|
||||||
return output
|
return output
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -42,7 +42,6 @@ from vllm_ascend.utils import (
|
|||||||
flashcomm2_enable,
|
flashcomm2_enable,
|
||||||
get_ascend_device_type,
|
get_ascend_device_type,
|
||||||
is_moe_model,
|
is_moe_model,
|
||||||
is_vl_model,
|
|
||||||
refresh_block_size,
|
refresh_block_size,
|
||||||
update_aclgraph_sizes,
|
update_aclgraph_sizes,
|
||||||
update_cudagraph_capture_sizes,
|
update_cudagraph_capture_sizes,
|
||||||
@@ -420,11 +419,6 @@ class NPUPlatform(Platform):
|
|||||||
vllm_config.parallel_config.cp_kv_cache_interleave_size = cache_config.block_size
|
vllm_config.parallel_config.cp_kv_cache_interleave_size = cache_config.block_size
|
||||||
|
|
||||||
if enable_sp(vllm_config):
|
if enable_sp(vllm_config):
|
||||||
assert not is_vl_model(vllm_config), """Flash Comm V1 is not supported for VL models. \
|
|
||||||
Please disable it by setting VLLM_ASCEND_ENABLE_FLASHCOMM1=0. \
|
|
||||||
For optimal performance with VL models, we recommend enabling Sequence Parallelism \
|
|
||||||
via --compilation-config '{"pass_config": {"enable_sp": true}}'."""
|
|
||||||
|
|
||||||
assert vllm_config.parallel_config.tensor_parallel_size > 1, (
|
assert vllm_config.parallel_config.tensor_parallel_size > 1, (
|
||||||
"Flash Comm v1 is only supported when tp_size > 1."
|
"Flash Comm v1 is only supported when tp_size > 1."
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -23,6 +23,7 @@ import atexit
|
|||||||
import functools
|
import functools
|
||||||
import math
|
import math
|
||||||
import os
|
import os
|
||||||
|
import re
|
||||||
from contextlib import nullcontext
|
from contextlib import nullcontext
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from functools import lru_cache
|
from functools import lru_cache
|
||||||
@@ -842,15 +843,29 @@ def _is_contain_expert(config: Any):
|
|||||||
|
|
||||||
|
|
||||||
def is_vl_model(vllm_config: VllmConfig):
|
def is_vl_model(vllm_config: VllmConfig):
|
||||||
"""Checks if the model is a VL model by config"""
|
"""Checks if the model is a VL model by config.
|
||||||
|
|
||||||
|
Uses the same criterion as vllm itself (model_config.py): a model is
|
||||||
|
multimodal when its top-level hf_config differs from its hf_text_config
|
||||||
|
(i.e. there is a separate vision sub-config). The legacy key-name checks
|
||||||
|
are kept as fallbacks for configs that override get_text_config() to return
|
||||||
|
self (rare but possible).
|
||||||
|
"""
|
||||||
global _IS_VL_MODEL
|
global _IS_VL_MODEL
|
||||||
if _IS_VL_MODEL is None and vllm_config and vllm_config.model_config:
|
if _IS_VL_MODEL is None and vllm_config and vllm_config.model_config:
|
||||||
hf_config = vllm_config.model_config.hf_config.to_dict()
|
model_config = vllm_config.model_config
|
||||||
if "thinker_config" in hf_config:
|
# Primary: vllm's own VL detection — hf_config is the top-level
|
||||||
# Qwen-Omni-thinker models
|
# (multimodal) config; hf_text_config is the language-model sub-config.
|
||||||
|
# They are the same object for pure-text models.
|
||||||
|
if model_config.hf_config is not model_config.hf_text_config:
|
||||||
_IS_VL_MODEL = True
|
_IS_VL_MODEL = True
|
||||||
else:
|
else:
|
||||||
_IS_VL_MODEL = "vision_config" in hf_config
|
# Fallback: check well-known config keys
|
||||||
|
hf_config = model_config.hf_config.to_dict()
|
||||||
|
if "thinker_config" in hf_config or "vision_config" in hf_config:
|
||||||
|
_IS_VL_MODEL = True
|
||||||
|
else:
|
||||||
|
_IS_VL_MODEL = False
|
||||||
return _IS_VL_MODEL
|
return _IS_VL_MODEL
|
||||||
|
|
||||||
|
|
||||||
@@ -1244,3 +1259,9 @@ def trans_nd_to_nz(cache_tensor: torch.Tensor):
|
|||||||
cache_tensor = cache_tensor.reshape(nz_shape[:-4] + [m1, m0, n1, n0])
|
cache_tensor = cache_tensor.reshape(nz_shape[:-4] + [m1, m0, n1, n0])
|
||||||
cache_tensor = cache_tensor.permute(*array_trans)
|
cache_tensor = cache_tensor.permute(*array_trans)
|
||||||
return cache_tensor
|
return cache_tensor
|
||||||
|
|
||||||
|
|
||||||
|
def parse_layer_idx(prefix: str) -> int | None:
|
||||||
|
"""Extract the layer index from a module prefix string like 'model.layers.0.self_attn'."""
|
||||||
|
match = re.search(r"layers\.(\d+)", prefix)
|
||||||
|
return int(match.group(1)) if match else None
|
||||||
|
|||||||
Reference in New Issue
Block a user