# vllm_mlu hijack patches for vllm.config (ModelConfig / CacheConfig / LoRAConfig)
from typing import Tuple

from vllm.config import CacheConfig, LoRAConfig, ModelConfig
from vllm.logger import init_logger

from vllm_mlu.mlu_hijack_utils import MluHijackObject

logger = init_logger(__name__)

# Keep a handle to the original implementation so the hijacked
# LoRAConfig.verify_with_model_config below can delegate to it.
vllm__config__LoRAConfig__verify_with_model_config_org = LoRAConfig.verify_with_model_config
def vllm__config__CacheConfig___verify_cache_dtype(self) -> None:
    '''
    =============================
    Modify by vllm_mlu
    =============================
    @brief: add kv_cache_dtype int8 support
    '''
    # "auto" defers dtype selection elsewhere; nothing to validate here.
    if self.cache_dtype == "auto":
        return
    if self.cache_dtype in ("fp8", "fp8_e4m3", "fp8_e5m2"):
        logger.info(
            "Using fp8 data type to store kv cache. It reduces the GPU "
            "memory footprint and boosts the performance. "
            "Meanwhile, it may cause accuracy drop without a proper "
            "scaling factor")
    elif self.cache_dtype == 'int8':
        # MLU-specific addition: int8 KV cache is accepted as well.
        logger.info(
            "Using int8 data type to store kv cache. It reduces the MLU "
            "memory footprint and boosts the performance. ")
    else:
        # Anything else is a configuration error.
        raise ValueError(f"Unknown kv cache dtype: {self.cache_dtype}")
    '''
    ==================
    End of MLU Hijack
    ==================
    '''
def vllm__config__ModelConfig__get_num_kv_heads(self, parallel_config: "ParallelConfig") -> int:
    """Returns the number of KV heads per GPU."""
    # DeepSeek-family models take the MLA path and expose a single
    # latent KV head regardless of parallelism.
    model_type = getattr(self.hf_text_config, "model_type", None)
    if model_type in ('deepseek_v2', 'deepseek_v3', 'deepseek_mtp'):
        # feature flag MLA
        return 1
    # With tensor parallelism the KV heads are sharded across ranks.
    # When there are fewer heads than ranks, heads are replicated so
    # that every device keeps at least one.
    total_kv_heads = self.get_total_num_kv_heads()
    per_rank = total_kv_heads // parallel_config.tensor_parallel_size
    return max(per_rank, 1)
def vllm__config__ModelConfig__get_head_size(self) -> int:
    """Return the per-head size used for the KV cache.

    =============================
    Modify by vllm_mlu
    =============================
    @brief: DeepSeek MLA models use a fixed head size of 576
            (presumably kv_lora_rank + qk_rope_head_dim = 512 + 64 —
            confirm against the model's HF config).
    ==================
    End of MLU Hijack
    ==================
    """
    # TODO remove hard code
    if hasattr(self.hf_text_config, "model_type") and \
            self.hf_text_config.model_type in ('deepseek_v2', 'deepseek_v3',
                                               'deepseek_mtp'):
        # NOTE(review): the original banner claimed "replace 256 to 192"
        # but the code returns 576; comment corrected to match the code.
        return 576

    # Attention-free models keep no KV cache at all.
    if self.is_attention_free:
        return 0

    # Prefer an explicit head_dim from the HF config when present.
    if hasattr(self.hf_text_config, "head_dim"):
        return self.hf_text_config.head_dim
    # FIXME(woosuk): This may not be true for all models.
    return (self.hf_text_config.hidden_size //
            self.hf_text_config.num_attention_heads)
def vllm__config__ModelConfig__set_context_mlugraph_info(
        self, enable_context_mlugraph: bool, batch_size: int, seq_len: int) -> None:
    """Record the context-MLUGraph capture settings on this config object."""
    (self.enable_context_mlugraph,
     self.context_batch_size_to_capture,
     self.context_seq_len_to_capture) = (enable_context_mlugraph,
                                         batch_size,
                                         seq_len)
def vllm__config__ModelConfig__use_context_mlugraph(self) -> bool:
    """Whether context-MLUGraph has been enabled via set_context_mlugraph_info."""
    # Attribute may be absent when the setter was never called.
    return getattr(self, "enable_context_mlugraph", False)
def vllm__config__ModelConfig__get_context_mlugraph_bs_and_seq(self) -> Tuple[int, int]:
    """Return the (batch_size, seq_len) pair configured for context MLUGraph."""
    batch_size = self.context_batch_size_to_capture
    seq_len = self.context_seq_len_to_capture
    return batch_size, seq_len
def vllm__config__LoRAConfig__verify_with_model_config(self, model_config: ModelConfig):
|
|
'''
|
|
=============================
|
|
Modify by vllm_mlu
|
|
=============================
|
|
@brief: do not support quantization with lora for now
|
|
'''
|
|
if model_config.quantization:
|
|
raise ValueError("vllm mlu does not support quantization with lora for now")
|
|
'''
|
|
==================
|
|
End of MLU Hijack
|
|
==================
|
|
'''
|
|
vllm__config__LoRAConfig__verify_with_model_config_org(self, model_config)
|
|
|
|
@property
def vllm__config__ModelConfig__is_deepseek_v2(self) -> bool:
    """True when the HF config identifies a DeepSeek-family model."""
    model_type = getattr(self.hf_text_config, "model_type", None)
    return model_type in ('deepseek_v2', 'deepseek_v3', 'deepseek_mtp')
# Install the MLU overrides onto vLLM's config classes.
# NOTE(review): some registrations pass the target as a string name
# (attributes that are new on the class) while others pass the original
# bound function — presumably MluHijackObject.apply_hijack accepts both
# forms; confirm against vllm_mlu.mlu_hijack_utils.
MluHijackObject.apply_hijack(
    ModelConfig, "is_deepseek_v2",
    vllm__config__ModelConfig__is_deepseek_v2)
MluHijackObject.apply_hijack(
    ModelConfig, "set_context_mlugraph_info",
    vllm__config__ModelConfig__set_context_mlugraph_info)
MluHijackObject.apply_hijack(
    ModelConfig, "use_context_mlugraph",
    vllm__config__ModelConfig__use_context_mlugraph)
MluHijackObject.apply_hijack(
    ModelConfig, "get_context_mlugraph_bs_and_seq",
    vllm__config__ModelConfig__get_context_mlugraph_bs_and_seq)
MluHijackObject.apply_hijack(
    CacheConfig, CacheConfig._verify_cache_dtype,
    vllm__config__CacheConfig___verify_cache_dtype)
MluHijackObject.apply_hijack(
    ModelConfig, ModelConfig.get_head_size,
    vllm__config__ModelConfig__get_head_size)
MluHijackObject.apply_hijack(
    ModelConfig, ModelConfig.get_num_kv_heads,
    vllm__config__ModelConfig__get_num_kv_heads)
MluHijackObject.apply_hijack(
    LoRAConfig, LoRAConfig.verify_with_model_config,
    vllm__config__LoRAConfig__verify_with_model_config)