# SPDX-License-Identifier: Apache-2.0
|
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
|
|
|
|
from typing import Callable, Any
|
|
|
|
import torch
|
|
|
|
from vllm.model_executor.parameter import BasevLLMParameter
|
|
from vllm.distributed import (
|
|
get_parallel_rank_with_group,
|
|
get_parallel_world_size_with_group,
|
|
)
|
|
|
|
from vllm_mlu.mlu_hijack_utils import MluHijackObject
|
|
|
|
|
|
# Keep a reference to the stock vLLM implementation so the hijacked
# __init__ defined in this module can delegate to it before layering
# on MLU-specific state.
vllm__model_executor__parameter__BasevLLMParameter____init__org = BasevLLMParameter.__init__
|
def vllm__model_executor__parameter__BasevLLMParameter____init__(
    self,
    data: torch.Tensor,
    weight_loader: Callable,
    tp_group: Any = None
):
    """MLU replacement for ``BasevLLMParameter.__init__``.

    Delegates to the original vLLM initializer, then — when a tensor
    parallel group is supplied — records the group and its size/rank
    on the parameter instance.

    Args:
        data: backing tensor for the parameter.
        weight_loader: callable used by vLLM to load the weight.
        tp_group: optional tensor-parallel process group; when ``None``
            no extra attributes are set (matches upstream behavior).
    """
    # Run the stock vLLM constructor first so all upstream state exists.
    vllm__model_executor__parameter__BasevLLMParameter____init__org(
        self, data, weight_loader
    )
    # =============================
    # Modify by vllm_mlu
    # =============================
    # @brief: add self.tp_group, world_size and tp_rank
    if tp_group is None:
        return
    self.tp_group = tp_group
    self.tp_world_size = get_parallel_world_size_with_group(tp_group)
    self.tp_rank = get_parallel_rank_with_group(tp_group)
    # =================
    # End of MLU Hijack
    # =================
|
# Install the hijack: replace BasevLLMParameter.__init__ with the
# MLU-aware wrapper so every parameter created afterwards can carry
# tp_group / tp_world_size / tp_rank when a group is provided.
MluHijackObject.apply_hijack(BasevLLMParameter,
                             BasevLLMParameter.__init__,
                             vllm__model_executor__parameter__BasevLLMParameter____init__)