Files
enginex-mlu370-vllm/vllm-v0.6.2/vllm_mlu/vllm_mlu/mlu_hijack.py
2026-02-04 17:22:39 +08:00

95 lines
3.1 KiB
Python

import logging
from logging import Logger
from transformers import AutoConfig
from vllm.model_executor.models import ModelRegistry
from vllm_mlu.model_executor.custom_model.custom import CustomForCausalLM
from vllm_mlu.transformers_utils.configs import CustomConfig
from vllm_mlu._mlu_utils import *
def mlu_init_logger(name: str) -> Logger:
    """Return the logger ``name`` configured like the ``vllm`` logger.

    This module installs this function as ``vllm.logger.init_logger`` so that
    loggers created afterwards (including ``vllm_mlu`` loggers, which live
    outside the ``vllm`` logger hierarchy) get the same level, propagation
    flag and handlers as the ``vllm`` logger.

    Args:
        name: Name of the logger to fetch, typically ``__name__``.

    Returns:
        The ``logging.Logger`` for ``name``. If no real ``vllm`` logger exists
        yet, the logger is returned without copying any configuration.
    """
    mlu_logger = logging.getLogger(name)
    vllm_logger = logging.Logger.manager.loggerDict.get("vllm")
    # ``loggerDict`` holds a ``logging.PlaceHolder`` (not a Logger) for "vllm"
    # when only child loggers such as "vllm.foo" have been created. A
    # PlaceHolder has no level/propagate/handlers, so copying from it would
    # raise AttributeError -- only mirror a real Logger.
    if isinstance(vllm_logger, logging.Logger):
        mlu_logger.setLevel(vllm_logger.level)
        mlu_logger.propagate = vllm_logger.propagate
        # NOTE: this shares the vllm logger's handler list object (it is not a
        # copy), so handler changes made through either logger affect both.
        mlu_logger.handlers = vllm_logger.handlers
    return mlu_logger
# Patch ``vllm.logger.init_logger`` so every logger created through it from
# here on is configured by ``mlu_init_logger`` (mirroring the ``vllm``
# logger). Modules that already bound ``init_logger`` via
# ``from vllm.logger import init_logger`` before this point keep the original.
# The module is imported under a private alias so the name ``logger`` is only
# ever bound to this module's Logger instance.
from vllm import logger as _vllm_logger_module

_vllm_logger_module.init_logger = mlu_init_logger

from vllm.logger import init_logger  # resolves to the patched function

logger = init_logger(__name__)

# USE_PAGED is presumably provided by the ``vllm_mlu._mlu_utils`` star import.
if USE_PAGED:
    logger.info("Run vLLM in paged mode, Apply MLU optimization !")
else:
    logger.info("Run vLLM in unpaged mode, Apply MLU optimization")
# Importing these vllm_mlu sub-packages is how this module applies its MLU
# changes to vLLM. NOTE(review): the patches presumably take effect as import
# side effects, so the order below likely matters -- do not reorder without
# checking each module.
import vllm_mlu.config
import vllm_mlu.utils
import vllm_mlu.attention
import vllm_mlu.core
# Imported only when the VLLM_SCHEDULER_PROFILE flag (presumably from
# vllm_mlu._mlu_utils) is truthy.
if VLLM_SCHEDULER_PROFILE:
    import vllm_mlu.core.scheduler
    import vllm_mlu.engine.async_llm_engine
    import vllm_mlu.engine.multiprocessing.client
    import vllm_mlu.engine.multiprocessing.engine
    import vllm_mlu.entrypoints.openai.serving_engine
# Unconditional sub-package imports.
import vllm_mlu.distributed
import vllm_mlu.engine
import vllm_mlu.entrypoints
import vllm_mlu.executor
import vllm_mlu.lora
import vllm_mlu.model_executor
import vllm_mlu.worker
# Optional features, each gated on a setting presumably defined in
# vllm_mlu._mlu_utils. Each is logged before its module is imported.
if VLLM_PRELOAD_SIZE > 0:
    logger.info("Apply feature -> Preload Weight !")
    import vllm_mlu.mlu_custom.preload
    import vllm_mlu.mlu_custom.common
if check_context_comm_cmpt_parallel():
    logger.info("Apply feature -> Context Communication Computation Parallel !")
    import vllm_mlu.mlu_custom.context_comm_cmpt_parallel
    # Also imported by the preload branch; Python's module cache makes a
    # repeated import a no-op.
    import vllm_mlu.mlu_custom.common
# Register the "custom" model type with transformers' AutoConfig and the
# CustomForCausalLM architecture with vLLM's ModelRegistry.
AutoConfig.register("custom", CustomConfig)
ModelRegistry.register_model("CustomForCausalLM", CustomForCausalLM)
def import_cambricon_custom_func(extra_module_path: str):
    """Import an optional Cambricon custom-function module by dotted path.

    The directory three levels above this file is placed at the front of
    ``sys.path`` so that top-level packages located there (such as the
    ``examples`` package used by this module's callers) are importable.
    Then ``extra_module_path`` is imported.

    Args:
        extra_module_path: Dotted module path to import, e.g.
            ``"examples.cambricon_custom_func.vllm.mlu_hijack.mlu_hijack"``.

    Returns:
        The imported module object.

    Raises:
        Any exception raised while importing ``extra_module_path`` (e.g.
        ``ModuleNotFoundError``) propagates unchanged.
    """
    import importlib
    import os
    import sys
    from pathlib import Path

    file_path = Path(os.path.abspath(__file__))
    vllm_dir = str(file_path.parent.parent.parent)
    # This is called once per enabled feature. Move the directory to the front
    # instead of inserting it again, so sys.path does not accumulate duplicate
    # entries. Lookup precedence is the same as a plain insert(0, ...).
    while vllm_dir in sys.path:
        sys.path.remove(vllm_dir)
    sys.path.insert(0, vllm_dir)
    return importlib.import_module(extra_module_path)
# The optional Cambricon custom-function hijacks below are imported here, at
# module import time, so that every worker that imports this module also
# picks them up. Each is gated on a flag presumably defined in
# vllm_mlu._mlu_utils and is loaded through import_cambricon_custom_func,
# which first puts the directory containing the ``examples`` package on
# sys.path.

# Custom vLLM hijack.
if CUSTOM_VLLM_HIJACK_EN:
    import_cambricon_custom_func("examples.cambricon_custom_func.vllm.mlu_hijack.mlu_hijack")
# Chunked pipeline parallel hijack.
if CHUNKED_PIPELINE_PARALLEL_EN:
    import_cambricon_custom_func("examples.cambricon_custom_func.chunked_pipeline_parallel.mlu_hijack.mlu_hijack")
# Context parallel hijack.
if CONTEXT_PARALLEL_EN:
    import_cambricon_custom_func("examples.cambricon_custom_func.context_parallel.mlu_hijack.mlu_hijack")
# Expert parallel hijack.
if EXPERT_PARALLEL_EN:
    import_cambricon_custom_func("examples.cambricon_custom_func.expert_parallel.mlu_hijack.mlu_hijack")