34 lines
857 B
Python
34 lines
857 B
Python
# SPDX-License-Identifier: Apache-2.0
|
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
|
"""Attention backend utils"""
|
|
|
|
from dataclasses import dataclass
|
|
|
|
from vllm.config import ModelConfig
|
|
from vllm.logger import init_logger
|
|
|
|
# Module-level logger, created through vLLM's init_logger with this module's name.
logger = init_logger(__name__)
|
|
|
|
# Sentinel slot index (-1).
# NOTE(review): presumably marks padded token positions that have no real
# KV-cache slot assigned; confirm against the attention backends that use it.
PAD_SLOT_ID = -1
|
|
|
|
|
|
@dataclass
|
|
class MLADims:
|
|
q_lora_rank: int | None
|
|
kv_lora_rank: int
|
|
qk_nope_head_dim: int
|
|
qk_rope_head_dim: int
|
|
v_head_dim: int
|
|
|
|
|
|
def get_mla_dims(model_config: ModelConfig) -> MLADims:
    """Collect the MLA dimensions from ``model_config.hf_text_config``.

    ``q_lora_rank`` is optional and falls back to ``None`` when the text
    config lacks it. The remaining dimensions are read with plain attribute
    access, so a config missing any of them is expected to raise
    ``AttributeError``.

    Args:
        model_config: vLLM model configuration whose ``hf_text_config``
            carries the MLA attributes.

    Returns:
        An ``MLADims`` populated from the text config.
    """
    text_cfg = model_config.hf_text_config

    # Optional: not every MLA config defines a query low-rank projection.
    q_rank = getattr(text_cfg, "q_lora_rank", None)

    return MLADims(
        q_lora_rank=q_rank,
        kv_lora_rank=text_cfg.kv_lora_rank,
        qk_nope_head_dim=text_cfg.qk_nope_head_dim,
        qk_rope_head_dim=text_cfg.qk_rope_head_dim,
        v_head_dim=text_cfg.v_head_dim,
    )
|