# Forked from EngineX-Cambricon/enginex-mlu370-vllm (121 lines, 4.5 KiB, Python)
from vllm.config import ModelConfig, VllmConfig
|
|
from vllm.engine.arg_utils import EngineArgs, AsyncEngineArgs
|
|
from vllm_mlu._mlu_utils import (BlockSizeInfo, USE_PAGED, get_device_name)
|
|
from vllm_mlu.mlu_hijack_utils import MluHijackObject
|
|
from vllm.logger import init_logger
|
|
from vllm.utils import FlexibleArgumentParser
|
|
|
|
# Module-level logger used by the hijacked methods in this file.
logger = init_logger(__name__)


# References to the original EngineArgs methods, captured at import time so
# that each replacement defined in this file can still delegate to the
# implementation it replaces.
vllm__engine__arg_utils__EngineArgs__create_model_config_org = (
    EngineArgs.create_model_config)
vllm__engine__arg_utils__EngineArgs__create_engine_config_org = (
    EngineArgs.create_engine_config)
vllm__engine__arg_utils__EngineArgs__add_cli_args_org = (
    EngineArgs.add_cli_args)
vllm_engine__arg_utils__EngineArgs____post_init__org = (
    EngineArgs.__post_init__)
|
|
|
|
|
|
def vllm_engine__arg_utils__EngineArgs____post_init__(self,) -> None:
    '''
    MLU replacement for ``EngineArgs.__post_init__``.

    =============================
    Add by vllm_mlu
    =============================
    @brief: 1. On a device whose name contains "3" (MLU3XX), when
               tensor_parallel_size > 1, enforce_eager is forced to True.
            2. For unpaged mode, a block_size of 16 is replaced by 2048.
    ==================
    End of MLU Hijack
    ==================

    After these adjustments the original ``__post_init__`` (saved at import
    time) is invoked on ``self``.
    '''
    # The device is treated as lacking graph support whenever its reported
    # name contains the character "3".
    unsupport_graph_device = "3" in get_device_name()
    if (unsupport_graph_device and self.tensor_parallel_size > 1
            and not self.enforce_eager):
        self.enforce_eager = True
        logger.warning("The current device only support eager mode, when the tensor_parallel_size > 1. "
                       "The param enforce_eager is forced to set True")

    # Only a block_size of exactly 16 is overridden; any other value is kept.
    # NOTE(review): 16 is presumably the upstream default -- confirm.
    if not USE_PAGED and self.block_size == 16:
        self.block_size = 2048

    vllm_engine__arg_utils__EngineArgs____post_init__org(self)
|
|
|
|
|
|
def vllm__engine__arg_utils__EngineArgs__create_model_config(self) -> ModelConfig:
    '''
    Build the model config with the original method, then attach the
    context-mlugraph settings read from ``self``.

    Returns:
        The ModelConfig produced by the original ``create_model_config``,
        after ``set_context_mlugraph_info`` has been called on it.
    '''
    config = vllm__engine__arg_utils__EngineArgs__create_model_config_org(self)

    # =============================
    # Modify by vllm_mlu
    # =============================
    # @brief: set context mlugraph info for model config
    #
    # The three attributes are optional on ``self``; missing ones fall back
    # to False / None / None respectively.
    enable_graph = getattr(self, "enable_context_mlugraph", False)
    capture_batch_size = getattr(self, "context_batch_size_to_capture", None)
    capture_seq_len = getattr(self, "context_seq_len_to_capture", None)
    config.set_context_mlugraph_info(enable_graph,
                                     capture_batch_size,
                                     capture_seq_len)
    # ==================
    # End of MLU Hijack
    # ==================

    return config
|
|
|
|
|
|
def vllm__engine__arg_utils__EngineArgs__create_engine_config(self) -> VllmConfig:
    '''
    MLU replacement for ``EngineArgs.create_engine_config``.

    =============================
    Modify by vllm_mlu
    =============================
    @brief: disable custom_all_reduce, re-set block_size to support paged
            and unpaged mode.
    ==================
    End of MLU Hijack
    ==================

    Returns:
        The config built by the original ``create_engine_config``, with
        ``cache_config.block_size`` overwritten by ``BlockSizeInfo.BLOCK_SIZE``.

    Raises:
        ValueError: if chunked prefill is enabled while USE_PAGED is false.
    '''
    # MLU not support custom all reduce
    self.disable_custom_all_reduce = True

    # Record the requested block size before the config is built; the value
    # written back below is read from BlockSizeInfo, not from self.
    BlockSizeInfo.set_block_size(self.block_size)

    if not USE_PAGED:
        if self.enable_chunked_prefill:
            raise ValueError("Not support chunked_prefill in unpaged mode.")

    vllm_config = vllm__engine__arg_utils__EngineArgs__create_engine_config_org(self)
    vllm_config.cache_config.block_size = BlockSizeInfo.BLOCK_SIZE
    return vllm_config
|
|
|
|
|
|
@staticmethod
def vllm__engine__arg_utils__EngineArgs__add_cli_args(
        parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
    '''
    Register the original EngineArgs CLI options, then adjust two of them
    for MLU.

    =============================
    Modify by vllm_mlu
    =============================
    @brief: 1. remove block_size choices, set default value to -1
            2. add kv_cache_dtype choices of 'int8'
    ==================
    End of MLU Hijack
    ==================

    Args:
        parser: the parser to populate.

    Returns:
        The parser returned by the original ``add_cli_args``, with the
        ``block_size`` and ``kv_cache_dtype`` actions patched in place.
    '''
    parser = vllm__engine__arg_utils__EngineArgs__add_cli_args_org(parser)

    # argparse exposes no public way to edit an already-registered option,
    # so the registered actions are patched directly.
    for action in parser._actions:
        if action.dest == "block_size":
            action.choices = None
            action.default = -1
        elif action.dest == "kv_cache_dtype":
            existing = action.choices
            # ``choices is None`` already accepts any value. Otherwise build
            # a new list instead of ``+=`` so a tuple does not raise, the
            # original container is not mutated, and 'int8' is never
            # duplicated.
            if existing is not None and 'int8' not in existing:
                action.choices = [*existing, 'int8']

    return parser
|
|
|
|
|
|
# Install the MLU replacements on EngineArgs. Each pair is
# (method currently on EngineArgs, replacement defined in this file).
_ENGINE_ARGS_HIJACKS = (
    (EngineArgs.__post_init__,
     vllm_engine__arg_utils__EngineArgs____post_init__),
    (EngineArgs.create_model_config,
     vllm__engine__arg_utils__EngineArgs__create_model_config),
    (EngineArgs.create_engine_config,
     vllm__engine__arg_utils__EngineArgs__create_engine_config),
    (EngineArgs.add_cli_args,
     vllm__engine__arg_utils__EngineArgs__add_cli_args),
)

for _original, _replacement in _ENGINE_ARGS_HIJACKS:
    MluHijackObject.apply_hijack(EngineArgs, _original, _replacement)
|