forked from EngineX-Cambricon/enginex-mlu370-vllm
add qwen3
This commit is contained in:
138
vllm-v0.6.2/vllm_mlu/vllm_mlu/config.py
Normal file
138
vllm-v0.6.2/vllm_mlu/vllm_mlu/config.py
Normal file
@@ -0,0 +1,138 @@
|
||||
from typing import Tuple
|
||||
from vllm.logger import init_logger
|
||||
from vllm.config import ModelConfig, CacheConfig, LoRAConfig
|
||||
from vllm_mlu.mlu_hijack_utils import MluHijackObject
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
vllm__config__LoRAConfig__verify_with_model_config_org = LoRAConfig.verify_with_model_config
|
||||
|
||||
|
||||
def vllm__config__CacheConfig___verify_cache_dtype(self) -> None:
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: add kv_cache_dtype int8 support
|
||||
'''
|
||||
if self.cache_dtype == "auto":
|
||||
pass
|
||||
elif self.cache_dtype in ("fp8", "fp8_e4m3", "fp8_e5m2"):
|
||||
logger.info(
|
||||
"Using fp8 data type to store kv cache. It reduces the GPU "
|
||||
"memory footprint and boosts the performance. "
|
||||
"Meanwhile, it may cause accuracy drop without a proper "
|
||||
"scaling factor")
|
||||
elif self.cache_dtype == 'int8':
|
||||
logger.info(
|
||||
"Using int8 data type to store kv cache. It reduces the MLU "
|
||||
"memory footprint and boosts the performance. ")
|
||||
else:
|
||||
raise ValueError(f"Unknown kv cache dtype: {self.cache_dtype}")
|
||||
'''
|
||||
==================
|
||||
End of MLU Hijack
|
||||
==================
|
||||
'''
|
||||
|
||||
def vllm__config__ModelConfig__get_num_kv_heads(self, parallel_config: "ParallelConfig") -> int:
|
||||
"""Returns the number of KV heads per GPU."""
|
||||
if hasattr(self.hf_text_config,"model_type") and self.hf_text_config.model_type == 'deepseek_v2':
|
||||
# feature flag MLA
|
||||
return 1
|
||||
total_num_kv_heads = self.get_total_num_kv_heads()
|
||||
# If tensor parallelism is used, we divide the number of KV heads by
|
||||
# the tensor parallel size. We will replicate the KV heads in the
|
||||
# case where the number of KV heads is smaller than the tensor
|
||||
# parallel size so each GPU has at least one KV head.
|
||||
return max(1,
|
||||
total_num_kv_heads // parallel_config.tensor_parallel_size)
|
||||
|
||||
def vllm__config__ModelConfig__get_head_size(self) -> int:
|
||||
# TODO remove hard code
|
||||
if hasattr(self.hf_text_config, "model_type"
|
||||
) and self.hf_text_config.model_type == 'deepseek_v2':
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: replace 256 to 192.
|
||||
'''
|
||||
return 576
|
||||
'''
|
||||
==================
|
||||
End of MLU Hijack
|
||||
==================
|
||||
'''
|
||||
|
||||
if self.is_attention_free:
|
||||
return 0
|
||||
|
||||
if hasattr(self.hf_text_config, "head_dim"):
|
||||
return self.hf_text_config.head_dim
|
||||
# FIXME(woosuk): This may not be true for all models.
|
||||
return (self.hf_text_config.hidden_size //
|
||||
self.hf_text_config.num_attention_heads)
|
||||
|
||||
def vllm__config__ModelConfig__set_context_mlugraph_info(
|
||||
self, enable_context_mlugraph: bool, batch_size: int, seq_len: int) -> None:
|
||||
self.enable_context_mlugraph = enable_context_mlugraph
|
||||
self.context_batch_size_to_capture = batch_size
|
||||
self.context_seq_len_to_capture = seq_len
|
||||
|
||||
|
||||
def vllm__config__ModelConfig__use_context_mlugraph(self) -> bool:
|
||||
return hasattr(self, "enable_context_mlugraph") and self.enable_context_mlugraph
|
||||
|
||||
|
||||
def vllm__config__ModelConfig__get_context_mlugraph_bs_and_seq(self) -> Tuple[int, int]:
|
||||
return self.context_batch_size_to_capture, self.context_seq_len_to_capture
|
||||
|
||||
|
||||
def vllm__config__LoRAConfig__verify_with_model_config(self, model_config: ModelConfig):
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: do not support quantization with lora for now
|
||||
'''
|
||||
if model_config.quantization:
|
||||
raise ValueError("vllm mlu does not support quantization with lora for now")
|
||||
'''
|
||||
==================
|
||||
End of MLU Hijack
|
||||
==================
|
||||
'''
|
||||
vllm__config__LoRAConfig__verify_with_model_config_org(self, model_config)
|
||||
|
||||
@property
|
||||
def vllm__config__ModelConfig__is_deepseek_v2(self) -> bool:
|
||||
result = hasattr(
|
||||
self.hf_text_config,
|
||||
"model_type") and self.hf_text_config.model_type == 'deepseek_v2'
|
||||
return result
|
||||
|
||||
MluHijackObject.apply_hijack(ModelConfig,
|
||||
"is_deepseek_v2",
|
||||
vllm__config__ModelConfig__is_deepseek_v2)
|
||||
MluHijackObject.apply_hijack(ModelConfig,
|
||||
"set_context_mlugraph_info",
|
||||
vllm__config__ModelConfig__set_context_mlugraph_info)
|
||||
MluHijackObject.apply_hijack(ModelConfig,
|
||||
"use_context_mlugraph",
|
||||
vllm__config__ModelConfig__use_context_mlugraph)
|
||||
MluHijackObject.apply_hijack(ModelConfig,
|
||||
"get_context_mlugraph_bs_and_seq",
|
||||
vllm__config__ModelConfig__get_context_mlugraph_bs_and_seq)
|
||||
MluHijackObject.apply_hijack(CacheConfig,
|
||||
CacheConfig._verify_cache_dtype,
|
||||
vllm__config__CacheConfig___verify_cache_dtype)
|
||||
MluHijackObject.apply_hijack(ModelConfig,
|
||||
ModelConfig.get_head_size,
|
||||
vllm__config__ModelConfig__get_head_size)
|
||||
MluHijackObject.apply_hijack(ModelConfig,
|
||||
ModelConfig.get_num_kv_heads,
|
||||
vllm__config__ModelConfig__get_num_kv_heads)
|
||||
MluHijackObject.apply_hijack(LoRAConfig,
|
||||
LoRAConfig.verify_with_model_config,
|
||||
vllm__config__LoRAConfig__verify_with_model_config)
|
||||
Reference in New Issue
Block a user