enginex-mlu590-vllm/vllm_mlu/model_executor/parameter.py

# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
from typing import Callable, Any
import torch
from vllm.model_executor.parameter import BasevLLMParameter
from vllm.distributed import (
    get_parallel_rank_with_group,
    get_parallel_world_size_with_group,
)
from vllm_mlu.mlu_hijack_utils import MluHijackObject
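
# Keep a reference to the upstream constructor so the patched version can
# delegate to it.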
vllm__model_executor__parameter__BasevLLMParameter____init__org = BasevLLMParameter.__init__


def vllm__model_executor__parameter__BasevLLMParameter____init__(
    self,
    data: torch.Tensor,
    weight_loader: Callable,
    tp_group: Any = None,
):
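    # Run the original vLLM constructor before attaching MLU-specific state.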
    vllm__model_executor__parameter__BasevLLMParameter____init__org(
        self, data, weight_loader
    )
    '''
    =============================
    Modify by vllm_mlu
    =============================
    @brief: add self.tp_group, self.tp_world_size and self.tp_rank
    '''
    if tp_group is not None:
        self.tp_group = tp_group
        self.tp_world_size = get_parallel_world_size_with_group(self.tp_group)
        self.tp_rank = get_parallel_rank_with_group(self.tp_group)
    '''
    =================
    End of MLU Hijack
    =================
    '''
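

# Register the patched constructor so that BasevLLMParameter.__init__
# resolves to the MLU-aware version defined above.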
MluHijackObject.apply_hijack(
    BasevLLMParameter,
    BasevLLMParameter.__init__,
    vllm__model_executor__parameter__BasevLLMParameter____init__,
)
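
# ---------------------------------------------------------------------------
# Illustrative usage sketch (not part of the original file). It assumes the
# distributed environment and a tensor-parallel process group have already
# been initialized by vLLM; `example_tp_group` and the toy weight_loader
# below are hypothetical names used only for illustration.
#
#     weight = BasevLLMParameter(
#         data=torch.empty(1024, 1024),
#         weight_loader=lambda param, loaded: param.data.copy_(loaded),
#         tp_group=example_tp_group,
#     )
#     # With the hijack applied, the parameter also records its TP topology:
#     #   weight.tp_group, weight.tp_world_size, weight.tp_rank
# ---------------------------------------------------------------------------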