# Files
# enginex-mlu370-vllm/vllm-v0.6.2/vllm_mlu/vllm_mlu/config.py
# 2026-02-11 17:47:15 +08:00
#
# 139 lines
# 5.5 KiB
# Python
from typing import Tuple
from vllm.logger import init_logger
from vllm.config import ModelConfig, CacheConfig, LoRAConfig
from vllm_mlu.mlu_hijack_utils import MluHijackObject
# Module-level logger, following vllm's logging convention.
logger = init_logger(__name__)
# Keep a reference to the original LoRAConfig.verify_with_model_config so the
# hijacked replacement defined below can delegate to upstream after its
# MLU-specific checks.
vllm__config__LoRAConfig__verify_with_model_config_org = LoRAConfig.verify_with_model_config
def vllm__config__CacheConfig___verify_cache_dtype(self) -> None:
    """MLU hijack of ``CacheConfig._verify_cache_dtype``.

    Same as upstream but additionally accepts ``int8`` as a KV-cache dtype.

    Raises:
        ValueError: if ``self.cache_dtype`` is not one of the supported values.
    """
    dtype = self.cache_dtype
    if dtype == "auto":
        return
    if dtype in ("fp8", "fp8_e4m3", "fp8_e5m2"):
        logger.info(
            "Using fp8 data type to store kv cache. It reduces the GPU "
            "memory footprint and boosts the performance. "
            "Meanwhile, it may cause accuracy drop without a proper "
            "scaling factor")
        return
    if dtype == 'int8':
        # MLU-specific addition: int8 KV cache.
        logger.info(
            "Using int8 data type to store kv cache. It reduces the MLU "
            "memory footprint and boosts the performance. ")
        return
    raise ValueError(f"Unknown kv cache dtype: {dtype}")
def vllm__config__ModelConfig__get_num_kv_heads(self, parallel_config: "ParallelConfig") -> int:
    """Returns the number of KV heads per GPU."""
    # DeepSeek-family models are special-cased (MLA feature flag): the KV
    # cache behaves as a single head regardless of tensor parallelism.
    if getattr(self.hf_text_config, "model_type", None) in (
            'deepseek_v2', 'deepseek_v3', 'deepseek_mtp'):
        return 1
    # KV heads are sharded across tensor-parallel ranks; when there are
    # fewer heads than ranks they are replicated, so every device keeps
    # at least one.
    tp_size = parallel_config.tensor_parallel_size
    return max(1, self.get_total_num_kv_heads() // tp_size)
def vllm__config__ModelConfig__get_head_size(self) -> int:
    """MLU hijack of ``ModelConfig.get_head_size``.

    Returns the per-head size used for KV-cache sizing.  For the DeepSeek
    family (MLA) a hard-coded value of 576 is returned instead of upstream's
    value; otherwise this mirrors the upstream logic.
    """
    # TODO remove hard code
    if getattr(self.hf_text_config, "model_type", None) in (
            'deepseek_v2', 'deepseek_v3', 'deepseek_mtp'):
        # NOTE(review): the original hijack comment said "replace 256 to
        # 192" but the code returns 576 — confirm the intended MLA value.
        return 576
    if self.is_attention_free:
        return 0
    # Prefer an explicit head_dim from the HF config when present.
    _MISSING = object()
    head_dim = getattr(self.hf_text_config, "head_dim", _MISSING)
    if head_dim is not _MISSING:
        return head_dim
    # FIXME(woosuk): This may not be true for all models.
    return (self.hf_text_config.hidden_size //
            self.hf_text_config.num_attention_heads)
def vllm__config__ModelConfig__set_context_mlugraph_info(
        self, enable_context_mlugraph: bool, batch_size: int, seq_len: int) -> None:
    """Store context-MLUgraph capture settings on this model config.

    The stored attributes are read back by ``use_context_mlugraph`` and
    ``get_context_mlugraph_bs_and_seq``.
    """
    # Plain attribute writes; no validation is performed here.
    self.enable_context_mlugraph = enable_context_mlugraph
    self.context_batch_size_to_capture = batch_size
    self.context_seq_len_to_capture = seq_len
def vllm__config__ModelConfig__use_context_mlugraph(self) -> bool:
    """Whether context-MLUgraph capture has been enabled on this config.

    Returns False when ``set_context_mlugraph_info`` was never called.
    """
    # getattr with a False default is equivalent to the original
    # ``hasattr(...) and self.enable_context_mlugraph`` expression.
    return getattr(self, "enable_context_mlugraph", False)
def vllm__config__ModelConfig__get_context_mlugraph_bs_and_seq(self) -> Tuple[int, int]:
    """Return the (batch_size, seq_len) pair stored by
    ``set_context_mlugraph_info``.

    Assumes ``set_context_mlugraph_info`` ran first; raises AttributeError
    otherwise.
    """
    batch = self.context_batch_size_to_capture
    seq = self.context_seq_len_to_capture
    return batch, seq
def vllm__config__LoRAConfig__verify_with_model_config(self, model_config: ModelConfig):
    """MLU hijack of ``LoRAConfig.verify_with_model_config``.

    Rejects the LoRA + quantization combination (unsupported on MLU for
    now), then defers to the original upstream verification.

    Raises:
        ValueError: if ``model_config.quantization`` is set.
    """
    if model_config.quantization:
        raise ValueError("vllm mlu does not support quantization with lora for now")
    # Delegate all remaining checks to the saved upstream implementation.
    vllm__config__LoRAConfig__verify_with_model_config_org(self, model_config)
@property
def vllm__config__ModelConfig__is_deepseek_v2(self) -> bool:
    """True when the HF text config identifies a DeepSeek-family model
    (deepseek_v2 / deepseek_v3 / deepseek_mtp)."""
    model_type = getattr(self.hf_text_config, "model_type", None)
    return model_type in ('deepseek_v2', 'deepseek_v3', 'deepseek_mtp')
# Install the MLU replacements on the upstream vllm config classes.
# NOTE(review): new attributes/properties appear to be registered by string
# name, while replacements of existing methods pass the original function
# object — confirm against MluHijackObject.apply_hijack's signature.
MluHijackObject.apply_hijack(ModelConfig,
"is_deepseek_v2",
vllm__config__ModelConfig__is_deepseek_v2)
MluHijackObject.apply_hijack(ModelConfig,
"set_context_mlugraph_info",
vllm__config__ModelConfig__set_context_mlugraph_info)
MluHijackObject.apply_hijack(ModelConfig,
"use_context_mlugraph",
vllm__config__ModelConfig__use_context_mlugraph)
MluHijackObject.apply_hijack(ModelConfig,
"get_context_mlugraph_bs_and_seq",
vllm__config__ModelConfig__get_context_mlugraph_bs_and_seq)
MluHijackObject.apply_hijack(CacheConfig,
CacheConfig._verify_cache_dtype,
vllm__config__CacheConfig___verify_cache_dtype)
MluHijackObject.apply_hijack(ModelConfig,
ModelConfig.get_head_size,
vllm__config__ModelConfig__get_head_size)
MluHijackObject.apply_hijack(ModelConfig,
ModelConfig.get_num_kv_heads,
vllm__config__ModelConfig__get_num_kv_heads)
MluHijackObject.apply_hijack(LoRAConfig,
LoRAConfig.verify_with_model_config,
vllm__config__LoRAConfig__verify_with_model_config)