# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project

from vllm.config.parallel import ParallelConfig
|
|
from vllm.config.speculative import SpeculativeConfig
|
|
from vllm.logger import init_logger
|
|
|
|
from vllm_mlu.mlu_hijack_utils import MluHijackObject
|
|
|
|
logger = init_logger(__name__)
|
|
|
|
@staticmethod
def vllm__config__speculative__SpeculativeConfig__create_draft_parallel_config(
    target_parallel_config: ParallelConfig,
    speculative_draft_tensor_parallel_size: int,
) -> ParallelConfig:
    """Build the ``ParallelConfig`` handed to the draft worker.

    Every setting passed to the new config is copied from
    ``target_parallel_config``, with one exception: ``tensor_parallel_size``
    is taken from ``speculative_draft_tensor_parallel_size``.

    Args:
        target_parallel_config: Parallel config of the target model; the
            source of all copied settings.
        speculative_draft_tensor_parallel_size: Tensor-parallel size to use
            for the draft model.

    Returns:
        A newly constructed ``ParallelConfig`` for the draft worker.
    """
    source = target_parallel_config

    # Settings mirrored one-to-one from the target configuration.
    mirrored_settings = dict(
        pipeline_parallel_size=source.pipeline_parallel_size,
        distributed_executor_backend=source.distributed_executor_backend,
        max_parallel_loading_workers=source.max_parallel_loading_workers,
        disable_custom_all_reduce=source.disable_custom_all_reduce,
        ray_workers_use_nsight=source.ray_workers_use_nsight,
        placement_group=source.placement_group,
    )

    '''
    =============================
    Modify by vllm_mlu
    @brief: add draft data parallel parameters
    =============================
    '''
    # The draft config also carries the target's data-parallel settings.
    data_parallel_settings = dict(
        data_parallel_size=source.data_parallel_size,
        data_parallel_size_local=source.data_parallel_size_local,
        data_parallel_master_ip=source.data_parallel_master_ip,
        data_parallel_rpc_port=source.data_parallel_rpc_port,
    )
    '''
    ==================
    End of MLU Hijack
    ==================
    '''

    return ParallelConfig(
        tensor_parallel_size=speculative_draft_tensor_parallel_size,
        **mirrored_settings,
        **data_parallel_settings,
    )
|
|
|
|
|
|
# Keep a handle on the current ``__post_init__`` so the wrapper below can
# delegate to it after applying its default.
vllm__config__speculative__SpeculativeConfig____post_init___org = (
    SpeculativeConfig.__post_init__
)


def vllm__config__speculative__SpeculativeConfig____post_init__(self):
    """Default ``method`` to ``"mtp"`` when applicable, then delegate.

    When ``model`` is ``None``, ``num_speculative_tokens`` is set, and
    ``method`` is ``None``, ``method`` is set to ``"mtp"``. The saved
    original ``__post_init__`` is then called in every case.
    """
    method_is_implied = (
        self.model is None
        and self.num_speculative_tokens is not None
        and self.method is None
    )
    if method_is_implied:
        self.method = "mtp"

    vllm__config__speculative__SpeculativeConfig____post_init___org(self)
|
|
|
|
|
|
# Register the MLU replacements on SpeculativeConfig. Each original attribute
# is looked up immediately before its own registration, in the order listed.
for _attr_name, _replacement in (
    (
        "create_draft_parallel_config",
        vllm__config__speculative__SpeculativeConfig__create_draft_parallel_config,
    ),
    (
        "__post_init__",
        vllm__config__speculative__SpeculativeConfig____post_init__,
    ),
):
    MluHijackObject.apply_hijack(
        SpeculativeConfig,
        getattr(SpeculativeConfig, _attr_name),
        _replacement,
    )