Files
enginex-mlu370-vllm/vllm-v0.6.2/vllm_mlu/vllm_mlu/engine/arg_utils.py
2026-02-04 17:22:39 +08:00

121 lines
4.5 KiB
Python

from vllm.config import ModelConfig, VllmConfig
from vllm.engine.arg_utils import EngineArgs, AsyncEngineArgs
from vllm_mlu._mlu_utils import (BlockSizeInfo, USE_PAGED, get_device_name)
from vllm_mlu.mlu_hijack_utils import MluHijackObject
from vllm.logger import init_logger
from vllm.utils import FlexibleArgumentParser
# Module-level logger; used by the hijacked __post_init__ below to emit its warning.
logger = init_logger(__name__)
# Keep references to the upstream EngineArgs methods before they are replaced
# by the apply_hijack calls at the end of this file, so that each MLU override
# can delegate to the original implementation.
vllm__engine__arg_utils__EngineArgs__create_model_config_org = EngineArgs.create_model_config
vllm__engine__arg_utils__EngineArgs__create_engine_config_org = EngineArgs.create_engine_config
vllm__engine__arg_utils__EngineArgs__add_cli_args_org = EngineArgs.add_cli_args
# NOTE(review): this alias is spelled "vllm_engine..." / "...__post_init__org",
# unlike the "vllm__engine..." / "..._org" pattern of the three names above.
# It is referenced under this exact spelling further down, so any rename must
# update both places together.
vllm_engine__arg_utils__EngineArgs____post_init__org = EngineArgs.__post_init__
def vllm_engine__arg_utils__EngineArgs____post_init__(self) -> None:
    """Apply MLU-specific argument fix-ups, then run the original ``__post_init__``.

    =============================
    Add by vllm_mlu
    =============================
    @brief: 1. On an MLU3XX device (detected as a device name containing "3"),
               when tensor_parallel_size > 1, enforce_eager is forced to True
               and a warning is logged.
            2. For unpaged mode (USE_PAGED is false), a block_size of 16 is
               replaced by 2048.

    The adjustments are made before delegating, so the original
    ``__post_init__`` sees the already adjusted values.
    """
    # NOTE(review): this is a plain substring test, so any device name that
    # contains the character "3" is treated as graph-unsupported -- confirm
    # this matches the intended set of devices.
    unsupport_graph_device = "3" in get_device_name()
    if (unsupport_graph_device
            and self.tensor_parallel_size > 1
            and not self.enforce_eager):
        self.enforce_eager = True
        logger.warning("The current device only support eager mode, when the tensor_parallel_size > 1. "
                       "The param enforce_eager is forced to set True")

    # Only the value 16 is rewritten (presumably the upstream default -- verify);
    # any other block_size is left untouched.
    if not USE_PAGED and self.block_size == 16:
        self.block_size = 2048
    # ==================
    # End of MLU Hijack
    # ==================

    vllm_engine__arg_utils__EngineArgs____post_init__org(self)
def vllm__engine__arg_utils__EngineArgs__create_model_config(self) -> ModelConfig:
    """Build the upstream ``ModelConfig`` and attach context-mlugraph settings.

    Returns:
        The object returned by the original ``EngineArgs.create_model_config``,
        after ``set_context_mlugraph_info`` has been called on it.
    """
    config = vllm__engine__arg_utils__EngineArgs__create_model_config_org(self)

    # =============================
    # Modify by vllm_mlu
    # =============================
    # @brief: set context mlugraph info for model config
    #
    # The three options are read with getattr and a fallback, so an instance
    # that lacks them yields (False, None, None).
    mlugraph_enabled = getattr(self, "enable_context_mlugraph", False)
    capture_batch_size = getattr(self, "context_batch_size_to_capture", None)
    capture_seq_len = getattr(self, "context_seq_len_to_capture", None)
    config.set_context_mlugraph_info(
        mlugraph_enabled,
        capture_batch_size,
        capture_seq_len,
    )
    # ==================
    # End of MLU Hijack
    # ==================

    return config
def vllm__engine__arg_utils__EngineArgs__create_engine_config(self) -> VllmConfig:
    """Create the engine config with the MLU adjustments applied.

    =============================
    Modify by vllm_mlu
    =============================
    @brief: disable custom_all_reduce, re-set block_size to support paged and
            unpaged mode.

    Returns:
        The config built by the original ``EngineArgs.create_engine_config``,
        with ``cache_config.block_size`` overwritten by
        ``BlockSizeInfo.BLOCK_SIZE``.

    Raises:
        ValueError: if chunked prefill is enabled while ``USE_PAGED`` is false.
    """
    # MLU not support custom all reduce
    self.disable_custom_all_reduce = True

    # Record the requested block size before the upstream config is built.
    BlockSizeInfo.set_block_size(self.block_size)

    chunked_prefill_in_unpaged_mode = not USE_PAGED and self.enable_chunked_prefill
    if chunked_prefill_in_unpaged_mode:
        raise ValueError("Not support chunked_prefill in unpaged mode.")

    vllm_config = vllm__engine__arg_utils__EngineArgs__create_engine_config_org(self)

    # The cache block size comes from BlockSizeInfo, not from self.block_size.
    vllm_config.cache_config.block_size = BlockSizeInfo.BLOCK_SIZE
    # ==================
    # End of MLU Hijack
    # ==================

    return vllm_config
@staticmethod
def vllm__engine__arg_utils__EngineArgs__add_cli_args(
        parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
    """Register the upstream CLI arguments, then patch two of them for MLU.

    Args:
        parser: The parser to populate; it is passed to the original
            ``EngineArgs.add_cli_args`` first.

    Returns:
        The parser returned by the original ``add_cli_args``, with the
        ``block_size`` and ``kv_cache_dtype`` actions adjusted.
    """
    parser = vllm__engine__arg_utils__EngineArgs__add_cli_args_org(parser)

    # =============================
    # Modify by vllm_mlu
    # =============================
    # @brief: 1. remove block_size choices, set default value to -1
    #         2. add kv_cache_dtype choices of 'int8'
    #
    # argparse has no public way to edit an already registered argument, so
    # the private ``_actions`` list is walked and the actions patched in place.
    for action in parser._actions:
        if action.dest == "block_size":
            action.choices = None
            action.default = -1
        elif action.dest == "kv_cache_dtype":
            # Build a fresh list instead of ``+=``: this works when choices is
            # None or a tuple, never mutates a list owned by upstream code,
            # and does not add 'int8' a second time.
            choices = list(action.choices or ())
            if 'int8' not in choices:
                choices.append('int8')
            action.choices = choices
    # ==================
    # End of MLU Hijack
    # ==================

    return parser
# Install every MLU override on EngineArgs. The attribute being replaced is
# looked up right before its own apply_hijack call, in the order listed here.
for _method_name, _mlu_override in (
    ("__post_init__",
     vllm_engine__arg_utils__EngineArgs____post_init__),
    ("create_model_config",
     vllm__engine__arg_utils__EngineArgs__create_model_config),
    ("create_engine_config",
     vllm__engine__arg_utils__EngineArgs__create_engine_config),
    ("add_cli_args",
     vllm__engine__arg_utils__EngineArgs__add_cli_args),
):
    MluHijackObject.apply_hijack(EngineArgs,
                                 getattr(EngineArgs, _method_name),
                                 _mlu_override)
# Do not leave the loop variables behind as module attributes.
del _method_name, _mlu_override