diff --git a/vllm_ascend/ops/layernorm.py b/vllm_ascend/ops/layernorm.py
index d998ddab..71ba7e73 100644
--- a/vllm_ascend/ops/layernorm.py
+++ b/vllm_ascend/ops/layernorm.py
@@ -68,6 +68,7 @@ class AscendRMSNorm(RMSNorm):
 
         import torch_npu
         if residual is not None:
+            residual = torch.ops.vllm.maybe_chunk_residual(x, residual)
             if enable_custom_op():
                 x, _, residual = torch.ops._C_ascend.npu_add_rms_norm_bias(
                     x, residual, self.weight, self.bias, self.variance_epsilon
diff --git a/vllm_ascend/ops/mla.py b/vllm_ascend/ops/mla.py
index de2e2fa4..4da2507e 100644
--- a/vllm_ascend/ops/mla.py
+++ b/vllm_ascend/ops/mla.py
@@ -33,6 +33,7 @@
 from vllm.v1.attention.backend import AttentionMetadata  # type: ignore
 from vllm_ascend.ascend_config import get_ascend_config
 from vllm_ascend.ascend_forward_context import _EXTRA_CTX
+from vllm_ascend.utils import is_vl_model, parse_layer_idx
 
 
 class IndexerWrapper(nn.Module):
@@ -134,7 +135,16 @@ class AscendMultiHeadLatentAttention(MultiHeadLatentAttentionWrapper):
 
         self.mla_attn.process_weights_after_loading = wrapped_process_weights
 
-        compilation_config = get_current_vllm_config().compilation_config
+        # For VL models (e.g. Kimi K2.5), inputs_embeds at layer 0 comes from
+        # the vision encoder as full [N, H] — it has NOT been reduce-scattered.
+        # We detect this statically at init time (not at runtime via shape checks,
+        # which break graph-mode compilation) so the branch is a constant to dynamo.
+        vllm_config = get_current_vllm_config()
+        _is_vl = is_vl_model(vllm_config)
+        _layer_idx = parse_layer_idx(prefix)
+        self.is_vl_first_layer = bool(_is_vl and _layer_idx == 0)
+
+        compilation_config = vllm_config.compilation_config
         if prefix in compilation_config.static_forward_context:
             raise ValueError(f"Duplicate layer name: {prefix}")
         compilation_config.static_forward_context[prefix] = self
@@ -146,12 +156,18 @@ class AscendMultiHeadLatentAttention(MultiHeadLatentAttentionWrapper):
         kv_cache: torch.Tensor | None = None,
         attn_metadata: AttentionMetadata | None = None,
     ) -> torch.Tensor:
-        need_gather_q_kv = _EXTRA_CTX.flash_comm_v1_enabled
-        output_shape = hidden_states.shape
-        # FIXME: This does not seem right, should make sure the buffer is fixed
-        output = torch.empty(output_shape, dtype=hidden_states.dtype, device=hidden_states.device)
+        hidden_dim = hidden_states.shape[-1]
+
+        if _EXTRA_CTX.flash_comm_v1_enabled and self.tp_size > 1 and self.is_vl_first_layer:
+            need_gather_q_kv = False
+            n_out = hidden_states.shape[0] // self.tp_size
+            output = torch.empty((n_out, hidden_dim), dtype=hidden_states.dtype, device=hidden_states.device)
+        else:
+            need_gather_q_kv = _EXTRA_CTX.flash_comm_v1_enabled
+            output = torch.empty(hidden_states.shape, dtype=hidden_states.dtype, device=hidden_states.device)
+
         torch.ops.vllm.mla_forward(hidden_states, need_gather_q_kv, output, self.prefix)
-        output = output.view(-1, output_shape[-1])
+        output = output.view(-1, hidden_dim)
         return output
 
 
diff --git a/vllm_ascend/platform.py b/vllm_ascend/platform.py
index 0dda6be2..bf460a83 100644
--- a/vllm_ascend/platform.py
+++ b/vllm_ascend/platform.py
@@ -42,7 +42,6 @@ from vllm_ascend.utils import (
     flashcomm2_enable,
     get_ascend_device_type,
     is_moe_model,
-    is_vl_model,
     refresh_block_size,
     update_aclgraph_sizes,
     update_cudagraph_capture_sizes,
@@ -420,11 +419,6 @@ class NPUPlatform(Platform):
             vllm_config.parallel_config.cp_kv_cache_interleave_size = cache_config.block_size
         if enable_sp(vllm_config):
-            assert not is_vl_model(vllm_config), """Flash Comm V1 is not supported for VL models. \
-                Please disable it by setting VLLM_ASCEND_ENABLE_FLASHCOMM1=0. \
-                For optimal performance with VL models, we recommend enabling Sequence Parallelism \
-                via --compilation-config '{"pass_config": {"enable_sp": true}}'."""
-
             assert vllm_config.parallel_config.tensor_parallel_size > 1, (
                 "Flash Comm v1 is only supported when tp_size > 1."
             )
 
diff --git a/vllm_ascend/utils.py b/vllm_ascend/utils.py
index 1d341899..0f310e36 100644
--- a/vllm_ascend/utils.py
+++ b/vllm_ascend/utils.py
@@ -23,6 +23,7 @@ import atexit
 import functools
 import math
 import os
+import re
 from contextlib import nullcontext
 from enum import Enum
 from functools import lru_cache
@@ -842,15 +843,29 @@ def _is_contain_expert(config: Any):
 
 
 def is_vl_model(vllm_config: VllmConfig):
-    """Checks if the model is a VL model by config"""
+    """Checks if the model is a VL model by config.
+
+    Uses the same criterion as vllm itself (model_config.py): a model is
+    multimodal when its top-level hf_config differs from its hf_text_config
+    (i.e. there is a separate vision sub-config). The legacy key-name checks
+    are kept as fallbacks for configs that override get_text_config() to return
+    self (rare but possible).
+    """
     global _IS_VL_MODEL
     if _IS_VL_MODEL is None and vllm_config and vllm_config.model_config:
-        hf_config = vllm_config.model_config.hf_config.to_dict()
-        if "thinker_config" in hf_config:
-            # Qwen-Omni-thinker models
+        model_config = vllm_config.model_config
+        # Primary: vllm's own VL detection — hf_config is the top-level
+        # (multimodal) config; hf_text_config is the language-model sub-config.
+        # They are the same object for pure-text models.
+        if model_config.hf_config is not model_config.hf_text_config:
             _IS_VL_MODEL = True
         else:
-            _IS_VL_MODEL = "vision_config" in hf_config
+            # Fallback: check well-known config keys
+            hf_config = model_config.hf_config.to_dict()
+            if "thinker_config" in hf_config or "vision_config" in hf_config:
+                _IS_VL_MODEL = True
+            else:
+                _IS_VL_MODEL = False
     return _IS_VL_MODEL
 
 
@@ -1244,3 +1259,9 @@ def trans_nd_to_nz(cache_tensor: torch.Tensor):
     cache_tensor = cache_tensor.reshape(nz_shape[:-4] + [m1, m0, n1, n0])
     cache_tensor = cache_tensor.permute(*array_trans)
     return cache_tensor
+
+
+def parse_layer_idx(prefix: str) -> int | None:
+    """Extract the layer index from a module prefix string like 'model.layers.0.self_attn'."""
+    match = re.search(r"layers\.(\d+)", prefix)
+    return int(match.group(1)) if match else None