# mypy: ignore-errors import math import vllm.model_executor.models.config from vllm.logger import logger from vllm.model_executor.models import ModelRegistry from vllm.model_executor.models.config import MambaModelConfig from vllm.utils.math_utils import cdiv from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE, get_dtype_size @classmethod def verify_and_update_config(cls, vllm_config) -> None: """ Ensure that page size of attention layers is greater than or equal to the mamba layers. If not, automatically set the attention block size to ensure that it is. If the attention page size is strictly greater than the mamba page size, we pad the mamba page size to make them equal. Args: vllm_config: vLLM Config """ # Enable FULL_AND_PIECEWISE by default MambaModelConfig.verify_and_update_config(vllm_config) cache_config = vllm_config.cache_config model_config = vllm_config.model_config parallel_config = vllm_config.parallel_config if cache_config.cache_dtype == "auto": kv_cache_dtype = model_config.dtype else: kv_cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_config.cache_dtype] kernel_block_size = 128 # get attention block size attn_num_kv_heads = model_config.get_num_kv_heads(parallel_config) attn_head_size = model_config.get_head_size() attn_single_token_k_page_size = attn_head_size * attn_num_kv_heads * get_dtype_size(kv_cache_dtype) model_cls, _ = ModelRegistry.resolve_model_cls( model_config.architecture, model_config=model_config, ) # get mamba block size mamba_shapes = model_cls.get_mamba_state_shape_from_config(vllm_config) mamba_dtypes = model_cls.get_mamba_state_dtype_from_config(vllm_config) mamba_sizes = [] for shape, dtype in zip(mamba_shapes, mamba_dtypes): mamba_sizes.append(math.prod(shape) * get_dtype_size(dtype)) ssm_block_page_size, conv_block_page_size = max(mamba_sizes), min(mamba_sizes) # NOTE(zxr): because of the limit of Ascend Hardware, we need to keep # all cache tensors contiguous, so we align the page size of ssm_block # and single attn_block attn_block_size = kernel_block_size * cdiv(ssm_block_page_size, kernel_block_size * attn_single_token_k_page_size) assert attn_single_token_k_page_size * attn_block_size == ssm_block_page_size, ( "Cannot align ssm_page_size and attn_page_size." ) # override attention block size if either (a) the # user has not set it or (b) the user has set it # too small. if cache_config.block_size is None or cache_config.block_size < attn_block_size: cache_config.block_size = attn_block_size logger.info( "Setting attention block size to %d tokens to ensure that attention page size is >= mamba page size.", attn_block_size, ) # compute new attention page size attn_page_size = cache_config.block_size * 2 * attn_head_size * attn_num_kv_heads * get_dtype_size(kv_cache_dtype) # pad mamba page size for conv_blocks if ( cache_config.mamba_page_size_padded is None or cache_config.mamba_page_size_padded != attn_page_size + conv_block_page_size ): cache_config.mamba_page_size_padded = attn_page_size + conv_block_page_size mamba_padding_pct = 100 * conv_block_page_size / cache_config.mamba_page_size_padded logger.info( "Padding mamba page size by %.2f%% to ensure " "that mamba page size and attention page size are " "exactly equal.", mamba_padding_pct, ) if cache_config.enable_prefix_caching and cache_config.mamba_cache_mode == "align": cache_config.mamba_block_size = cache_config.block_size else: cache_config.mamba_block_size = model_config.max_model_len vllm.model_executor.models.config.HybridAttentionMambaModelConfig.verify_and_update_config = verify_and_update_config