add qwen3
This commit is contained in:
@@ -0,0 +1 @@
|
||||
from . import arg_utils
|
||||
@@ -0,0 +1,141 @@
|
||||
import argparse
|
||||
import torch
|
||||
from vllm.config import VllmConfig, ParallelConfig
|
||||
from vllm.engine.arg_utils import EngineArgs, AsyncEngineArgs
|
||||
from vllm_mlu._mlu_utils import *
|
||||
from vllm_mlu.mlu_hijack_utils import MluHijackObject
|
||||
from vllm.logger import init_logger
|
||||
from vllm.utils import FlexibleArgumentParser
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
# Capture the current EngineArgs / AsyncEngineArgs attributes at import time,
# before any hijack is applied, so the MLU replacements in this module can
# delegate to them.
vllm__engine__arg_utils__EngineArgs__create_engine_config_org = (
    EngineArgs.create_engine_config)
vllm__engine__arg_utils__EngineArgs__add_cli_args_org = (
    EngineArgs.add_cli_args)
# The two from_cli_args references are looked up on their own classes and
# kept separately; the replacement picks one based on ``cls``.
vllm__engine__arg_utils__EngineArgs__from_cli_args_org = (
    EngineArgs.from_cli_args)
vllm__engine__arg_utils__AsyncEngineArgs__from_cli_args_org = (
    AsyncEngineArgs.from_cli_args)
|
||||
|
||||
|
||||
def _resolve_moe_parallel_sizes(tensor_parallel_size: int,
                                moe_tp_size: int,
                                moe_ep_size: int) -> tuple:
    """Resolve the MoE tensor/expert parallel sizes for a tensor-parallel group.

    A requested size below 1 is treated as "not specified":

    * neither specified: MoE TP spans the whole group and MoE EP is 1;
    * only one specified: the other is ``tensor_parallel_size // given``;
    * both specified: they are used as given.

    Args:
        tensor_parallel_size: Total tensor parallel size.
        moe_tp_size: Requested MoE tensor parallel size (< 1 means unset).
        moe_ep_size: Requested MoE expert parallel size (< 1 means unset).

    Returns:
        A ``(moe_tp_size, moe_ep_size)`` tuple with both values resolved.

    Raises:
        ValueError: If the resolved sizes do not multiply to
            ``tensor_parallel_size``.
    """
    tp_given = moe_tp_size >= 1
    ep_given = moe_ep_size >= 1

    if tp_given and ep_given:
        resolved_tp, resolved_ep = moe_tp_size, moe_ep_size
    elif tp_given:
        resolved_tp = moe_tp_size
        resolved_ep = tensor_parallel_size // moe_tp_size
    elif ep_given:
        resolved_tp = tensor_parallel_size // moe_ep_size
        resolved_ep = moe_ep_size
    else:
        resolved_tp, resolved_ep = tensor_parallel_size, 1

    # Raise explicitly instead of using ``assert``: assertions are removed
    # when Python runs with -O, which would silently skip this validation.
    if resolved_tp * resolved_ep != tensor_parallel_size:
        raise ValueError(
            f"tensor_parallel_size ({tensor_parallel_size}) is not equal to "
            f"moe_tp_size ({moe_tp_size}) x moe_ep_size ({moe_ep_size}) "
            "or moe_tp_size and moe_ep_size should be -1 or one of them should be -1")
    return resolved_tp, resolved_ep


def vllm__engine__arg_utils__EngineArgs__create_engine_config(self) -> VllmConfig:
    """MLU replacement for ``EngineArgs.create_engine_config``.

    Adjusts and validates the engine arguments, publishes the context / MoE
    parallel sizes on ``ParallelConfig``, delegates to the original
    ``create_engine_config`` and finally overrides the cache block size with
    ``BlockSizeInfo.BLOCK_SIZE``.

    Returns:
        The ``VllmConfig`` built by the original implementation, with
        ``cache_config.block_size`` replaced.

    Raises:
        ValueError: If chunked prefill is requested in unpaged mode, if
            context or expert parallelism is requested while unsupported or
            disabled, or if the MoE parallel sizes are inconsistent with
            ``tensor_parallel_size``.
    """
    # =============================
    # Modify by vllm_mlu
    # =============================
    # @brief: chunked parallel pipeline only support batch size = 1 yet.
    if CHUNKED_PIPELINE_PARALLEL_EN:
        self.max_num_seqs = 1
        logger.info("Reset max_num_seqs to 1 as the chunked parallel pipeline mode "
                    "only supports batch size to 1.")

    # @brief: disable custom_all_reduce, re-set block_size to support paged
    # and unpaged mode.
    # MLU not support custom all reduce
    self.disable_custom_all_reduce = True
    BlockSizeInfo.set_block_size(self.block_size)
    if not USE_PAGED and self.enable_chunked_prefill:
        raise ValueError("Not support chunked_prefill in unpaged mode.")

    # These attributes are attached by the hijacked from_cli_args; fall back
    # to defaults when the instance does not carry them.
    self.context_parallel_size = getattr(self, "context_parallel_size", 1)
    self.moe_tp_size = getattr(self, "moe_tp_size", -1)
    self.moe_ep_size = getattr(self, "moe_ep_size", -1)

    # check context parallel whether supported or not
    if CONTEXT_PARALLEL_EN:
        if self.context_parallel_size > 1 and get_device_major_capability() == 3:
            raise ValueError('Context parallel does not support MLU370.')
    elif self.context_parallel_size > 1:
        raise ValueError('Context parallel does not support when CONTEXT_PARALLEL_EN=False')

    # check expert parallel whether supported or not
    if not EXPERT_PARALLEL_EN and (self.moe_tp_size > 1 or self.moe_ep_size > 1):
        raise ValueError('Expert parallel does not support when EXPERT_PARALLEL_EN=False')

    # NOTE: the sizes are stored as attributes on the ParallelConfig class
    # itself, not on an instance.
    ParallelConfig.context_parallel_size = self.context_parallel_size

    # set parallel_config moe_tp_size and moe_ep_size
    moe_tp_size, moe_ep_size = _resolve_moe_parallel_sizes(
        self.tensor_parallel_size, self.moe_tp_size, self.moe_ep_size)
    ParallelConfig.moe_tp_size = moe_tp_size
    ParallelConfig.moe_ep_size = moe_ep_size

    engine_config = vllm__engine__arg_utils__EngineArgs__create_engine_config_org(self)
    # Override the block size chosen upstream with the value held by
    # BlockSizeInfo after set_block_size() above.
    engine_config.cache_config.block_size = BlockSizeInfo.BLOCK_SIZE
    # ==================
    # End of MLU Hijack
    # ==================
    return engine_config
|
||||
|
||||
|
||||
@staticmethod
def vllm__engine__arg_utils__EngineArgs__add_cli_args(
        parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
    """MLU replacement for ``EngineArgs.add_cli_args``.

    Lets the original implementation register its options first, then adds
    the integer options ``--context-parallel-size`` (``-cp``),
    ``--moe-tp-size`` and ``--moe-ep-size``.

    Args:
        parser: Parser to extend.

    Returns:
        The parser returned by the original ``add_cli_args``, with the three
        extra options registered.
    """
    parser = vllm__engine__arg_utils__EngineArgs__add_cli_args_org(parser)
    # =============================
    # Modify by vllm_mlu
    # =============================
    # @brief: add --context-parallel-size, --moe-tp-size and --moe-ep-size
    # Each entry: (option strings, default value, help text).
    mlu_int_options = (
        (('--context-parallel-size', '-cp'), 1,
         'number of context parallel replicas'),
        (('--moe-tp-size',), -1,
         'Number of moe tensor parallel replicas'),
        (('--moe-ep-size',), -1,
         'Number of moe expert parallel replicas'),
    )
    for option_strings, default_value, help_text in mlu_int_options:
        parser.add_argument(*option_strings,
                            type=int,
                            default=default_value,
                            help=help_text)
    # ==================
    # End of MLU Hijack
    # ==================
    return parser
|
||||
|
||||
|
||||
@classmethod
def vllm__engine__arg_utils__EngineArgs__from_cli_args(cls, args: argparse.Namespace) -> 'EngineArgs':
    """MLU replacement for ``from_cli_args`` on both engine-args classes.

    Builds the engine args with the saved original constructor that matches
    ``cls`` and then copies the MLU-specific options from ``args`` onto the
    result.

    Args:
        args: Parsed command-line namespace; it must provide
            ``context_parallel_size``, ``moe_tp_size`` and ``moe_ep_size``.

    Returns:
        The engine-args object with the three MLU attributes set.

    Raises:
        AttributeError: If ``args`` lacks one of the three MLU options.
    """
    original_from_cli_args = (
        vllm__engine__arg_utils__AsyncEngineArgs__from_cli_args_org
        if cls == AsyncEngineArgs
        else vllm__engine__arg_utils__EngineArgs__from_cli_args_org
    )
    engine_args = original_from_cli_args(args)

    # Carry the MLU-only options over from the namespace, in a fixed order.
    for mlu_option in ('context_parallel_size', 'moe_tp_size', 'moe_ep_size'):
        setattr(engine_args, mlu_option, getattr(args, mlu_option))
    return engine_args
|
||||
|
||||
|
||||
def _apply_engine_args_hijacks():
    """Register the MLU replacements for EngineArgs / AsyncEngineArgs."""
    # Each entry: (class to patch, attribute name, replacement).
    hijack_plan = (
        (EngineArgs, "create_engine_config",
         vllm__engine__arg_utils__EngineArgs__create_engine_config),
        (EngineArgs, "add_cli_args",
         vllm__engine__arg_utils__EngineArgs__add_cli_args),
        (EngineArgs, "from_cli_args",
         vllm__engine__arg_utils__EngineArgs__from_cli_args),
        (AsyncEngineArgs, "from_cli_args",
         vllm__engine__arg_utils__EngineArgs__from_cli_args),
    )
    for patched_cls, attr_name, replacement in hijack_plan:
        # Look the current attribute up right before each call, so every
        # lookup sees the effect of the hijacks already applied, exactly as
        # a sequence of direct calls would.
        MluHijackObject.apply_hijack(patched_cls,
                                     getattr(patched_cls, attr_name),
                                     replacement)


_apply_engine_args_hijacks()
|
||||
Reference in New Issue
Block a user