Sync from v0.13
This commit is contained in:
42
vllm/v1/worker/cp_utils.py
Normal file
42
vllm/v1/worker/cp_utils.py
Normal file
@@ -0,0 +1,42 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from typing import TYPE_CHECKING, Any, cast
|
||||
|
||||
from vllm.config import VllmConfig, get_layers_from_vllm_config
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
|
||||
else:
|
||||
AttentionLayerBase = object
|
||||
|
||||
|
||||
def check_attention_cp_compatibility(vllm_config: VllmConfig) -> None:
|
||||
pcp_size = vllm_config.parallel_config.prefill_context_parallel_size
|
||||
dcp_size = vllm_config.parallel_config.decode_context_parallel_size
|
||||
interleave_size = vllm_config.parallel_config.cp_kv_cache_interleave_size
|
||||
if pcp_size * dcp_size > 1:
|
||||
layer_type = cast(type[Any], AttentionLayerBase)
|
||||
layers = get_layers_from_vllm_config(vllm_config, layer_type)
|
||||
for layer in layers.values():
|
||||
layer_impl = getattr(layer, "impl", None)
|
||||
if layer_impl is None:
|
||||
continue
|
||||
if vllm_config.speculative_config is not None and interleave_size > 1:
|
||||
assert layer_impl.supports_mtp_with_cp_non_trivial_interleave_size, (
|
||||
"MTP with cp_kv_cache_interleave_size > 1 is not "
|
||||
f"supported in {layer_impl.__class__.__name__}."
|
||||
)
|
||||
if dcp_size > 1:
|
||||
assert layer_impl.need_to_return_lse_for_decode, (
|
||||
"DCP requires attention impls to return"
|
||||
" the softmax lse for decode, but the impl "
|
||||
f"{layer_impl.__class__.__name__} "
|
||||
"does not return the softmax lse for decode."
|
||||
)
|
||||
|
||||
if pcp_size > 1:
|
||||
assert layer_impl.supports_pcp, (
|
||||
"PCP requires attention impls' support, "
|
||||
f"but the impl {layer_impl.__class__.__name__} "
|
||||
"does not support PCP."
|
||||
)
|
||||
Reference in New Issue
Block a user