# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import TYPE_CHECKING, Any, cast

from vllm.config import VllmConfig, get_layers_from_vllm_config

if TYPE_CHECKING:
    from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
else:
    # At runtime the base class is only used as a layer-lookup key, so a
    # plain `object` stands in and avoids importing the real module.
    AttentionLayerBase = object


def _assert_impl_supports_cp(
    impl: Any,
    pcp_size: int,
    dcp_size: int,
    interleave_size: int,
    has_speculative_config: bool,
) -> None:
    """Assert that a single attention impl supports the active CP setup.

    Raises AssertionError with a descriptive message when the impl lacks a
    capability required by the configured prefill/decode context parallelism.
    """
    impl_name = impl.__class__.__name__
    # MTP (speculative decoding) combined with a non-trivial KV-cache
    # interleave size needs explicit backend support.
    if has_speculative_config and interleave_size > 1:
        assert impl.supports_mtp_with_cp_non_trivial_interleave_size, (
            f"MTP with cp_kv_cache_interleave_size > 1 is not supported in {impl_name}."
        )
    # Decode context parallelism needs the softmax lse returned for decode.
    if dcp_size > 1:
        assert impl.need_to_return_lse_for_decode, (
            "DCP requires attention impls to return the softmax lse for decode, "
            f"but the impl {impl_name} does not return the softmax lse for decode."
        )
    # Prefill context parallelism needs explicit backend support.
    if pcp_size > 1:
        assert impl.supports_pcp, (
            f"PCP requires attention impls' support, but the impl {impl_name} "
            "does not support PCP."
        )


def check_attention_cp_compatibility(vllm_config: VllmConfig) -> None:
    """Validate attention impls against the configured context parallelism.

    Walks every attention layer in `vllm_config` and asserts its impl
    supports the enabled prefill (PCP) / decode (DCP) context-parallel
    features. No-op when context parallelism is disabled.

    Raises:
        AssertionError: if any attention impl lacks a required capability.
    """
    parallel_cfg = vllm_config.parallel_config
    pcp_size = parallel_cfg.prefill_context_parallel_size
    dcp_size = parallel_cfg.decode_context_parallel_size
    interleave_size = parallel_cfg.cp_kv_cache_interleave_size

    # Guard clause: neither form of context parallelism is active.
    if pcp_size * dcp_size <= 1:
        return

    attn_layers = get_layers_from_vllm_config(
        vllm_config, cast(type[Any], AttentionLayerBase)
    )
    has_speculative_config = vllm_config.speculative_config is not None
    for attn_layer in attn_layers.values():
        impl = getattr(attn_layer, "impl", None)
        # Layers without an `impl` attribute carry nothing to validate.
        if impl is not None:
            _assert_impl_supports_cp(
                impl, pcp_size, dcp_size, interleave_size, has_speculative_config
            )