[Model] Support DeepSeek-V4
This commit is contained in:
66
vllm_mlu/config/speculative.py
Normal file
66
vllm_mlu/config/speculative.py
Normal file
@@ -0,0 +1,66 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
|
||||
|
||||
from vllm.config.parallel import ParallelConfig
|
||||
from vllm.config.speculative import SpeculativeConfig
|
||||
from vllm.logger import init_logger
|
||||
|
||||
from vllm_mlu.mlu_hijack_utils import MluHijackObject
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@staticmethod
def vllm__config__speculative__SpeculativeConfig__create_draft_parallel_config(
    target_parallel_config: ParallelConfig,
    speculative_draft_tensor_parallel_size: int,
) -> ParallelConfig:
    """Create a parallel config for use by the draft worker.

    This is mostly a copy of the target parallel config, except the
    tensor-parallel size, which is taken from
    ``speculative_draft_tensor_parallel_size``.
    """
    # =============================
    # Modify by vllm_mlu
    # @brief: add draft data parallel parameters
    # =============================
    target = target_parallel_config
    # Fields carried over unchanged from the target model's parallel config.
    inherited_fields = dict(
        pipeline_parallel_size=target.pipeline_parallel_size,
        distributed_executor_backend=target.distributed_executor_backend,
        max_parallel_loading_workers=target.max_parallel_loading_workers,
        disable_custom_all_reduce=target.disable_custom_all_reduce,
        ray_workers_use_nsight=target.ray_workers_use_nsight,
        placement_group=target.placement_group,
        # Draft data-parallel parameters mirror the target config
        # (vllm_mlu addition relative to upstream vLLM).
        data_parallel_size=target.data_parallel_size,
        data_parallel_size_local=target.data_parallel_size_local,
        data_parallel_master_ip=target.data_parallel_master_ip,
        data_parallel_rpc_port=target.data_parallel_rpc_port,
    )
    draft_parallel_config = ParallelConfig(
        tensor_parallel_size=speculative_draft_tensor_parallel_size,
        **inherited_fields,
    )
    # ==================
    # End of MLU Hijack
    # ==================
    return draft_parallel_config
|
||||
|
||||
|
||||
# Keep a handle to the upstream implementation so the hijacked
# __post_init__ can delegate to it after the MLU-specific defaulting.
vllm__config__speculative__SpeculativeConfig____post_init___org = SpeculativeConfig.__post_init__


def vllm__config__speculative__SpeculativeConfig____post_init__(self):
    """Hijacked ``SpeculativeConfig.__post_init__``.

    When only ``num_speculative_tokens`` was supplied (no draft model and
    no explicit method), default the speculative method to ``"mtp"``; then
    run the original upstream initialization.
    """
    no_draft_model = self.model is None
    tokens_requested = self.num_speculative_tokens is not None
    method_unset = self.method is None
    if no_draft_model and tokens_requested and method_unset:
        self.method = "mtp"
    vllm__config__speculative__SpeculativeConfig____post_init___org(self)
|
||||
|
||||
|
||||
# Register the MLU replacements on SpeculativeConfig. The (original,
# replacement) pairs are captured before any hijack is applied; the first
# hijack does not touch __post_init__, so the order is equivalent to the
# sequential apply_hijack calls it replaces.
for _original_attr, _mlu_replacement in (
    (
        SpeculativeConfig.create_draft_parallel_config,
        vllm__config__speculative__SpeculativeConfig__create_draft_parallel_config,
    ),
    (
        SpeculativeConfig.__post_init__,
        vllm__config__speculative__SpeculativeConfig____post_init__,
    ),
):
    MluHijackObject.apply_hijack(
        SpeculativeConfig,
        _original_attr,
        _mlu_replacement,
    )
|
||||
Reference in New Issue
Block a user