Files
xc-llm-ascend/vllm_ascend/patch/worker/patch_weight_utils.py
Angazenn ec34bf0062 [Misc] Fix logger that does not take effect in patches (#7402)
### What this PR does / why we need it?
This PR fixes the logger initialization in patches so that the log info
can be displayed as expected.

### Does this PR introduce _any_ user-facing change?
No.

- vLLM version: v0.17.0
- vLLM main:
4497431df6

---------

Signed-off-by: Angazenn <supperccell@163.com>
2026-03-18 17:13:12 +08:00

85 lines
2.7 KiB
Python

import builtins
import sys
from typing import Any

from vllm.logger import logger
from vllm.model_executor.model_loader.weight_utils import maybe_remap_kv_scale_name
class ImportPatchDecorator:
    """Registry that maps a dotted module name to the function that patches it."""

    # Class-level table shared by every user of the registry:
    # module name -> patch callable. Registering the same module name a
    # second time replaces the earlier entry.
    _patches: dict[str, Any] = {}

    @classmethod
    def register(cls, module_name):
        """Build a decorator that files the decorated function under ``module_name``.

        The decorated function is returned unchanged, so it stays callable
        under its own name after registration.
        """

        def _remember(patch_func):
            cls._patches[module_name] = patch_func
            return patch_func

        return _remember

    @classmethod
    def apply_patches(cls):
        """Run each registered patch whose target module is already imported.

        Targets missing from ``sys.modules`` are skipped. A patch that raises
        is logged and does not stop the remaining patches from running.
        """
        for module_name, patch_func in cls._patches.items():
            if module_name not in sys.modules:
                continue
            loaded = sys.modules[module_name]
            try:
                patch_func(loaded)
            except Exception as e:
                logger.error(f"Patch application failed {module_name}: {e}")
@ImportPatchDecorator.register("vllm.model_executor.models.deepseek_v2")
def patch_deepseek(module):
    """Replace ``module.maybe_remap_kv_scale_name`` with an FA-scale-aware wrapper.

    The wrapper first delegates to the ``maybe_remap_kv_scale_name`` imported
    at the top of this file, then redirects names ending in one of the
    ``fa_{q,k,v}.{scale,offset}`` suffixes to their ``mla_attn.mla_attn.``
    location when that name exists in ``params_dict``.

    Args:
        module: The module object to patch. Nothing happens when it has no
            ``maybe_remap_kv_scale_name`` attribute.
    """
    ori_maybe_remap_kv_scale_name = maybe_remap_kv_scale_name
    replace_scale_names = (
        "fa_q.scale",
        "fa_k.scale",
        "fa_v.scale",
        "fa_q.offset",
        "fa_k.offset",
        "fa_v.offset",
    )

    def new_remap(name: str, params_dict: dict):
        """Remap ``name`` via the original helper, then apply the FA redirect."""
        name = ori_maybe_remap_kv_scale_name(name, params_dict)
        if name is None:
            # NOTE: the upstream helper is documented to return ``None`` when
            # its remapped name is absent from ``params_dict``. Pass that
            # through instead of failing on ``None.endswith``.
            return None
        for scale_name in replace_scale_names:
            if name.endswith(scale_name):
                remap_name = name.replace(scale_name, f"mla_attn.mla_attn.{scale_name}")
                if remap_name in params_dict:
                    return remap_name
                # Redirected name is unknown: strip the inserted segments.
                return remap_name.replace(".mla_attn", "")
        return name

    if hasattr(module, "maybe_remap_kv_scale_name"):
        # This patch can run more than once on the same module. Save the
        # original only the first time, so a later run does not overwrite it
        # with an already-patched wrapper.
        if not hasattr(module, "_original_maybe_remap_kv_scale_name"):
            module._original_maybe_remap_kv_scale_name = module.maybe_remap_kv_scale_name
        module.maybe_remap_kv_scale_name = new_remap
@ImportPatchDecorator.register("vllm.model_executor.model_loader.weight_utils")
def patch_weight_utils(module):
    """Copy the deepseek_v2 module's ``maybe_remap_kv_scale_name`` onto ``module``.

    Does nothing when ``vllm.model_executor.models.deepseek_v2`` is not in
    ``sys.modules`` or does not expose that attribute.
    """
    deepseek = sys.modules.get("vllm.model_executor.models.deepseek_v2")
    try:
        remap = deepseek.maybe_remap_kv_scale_name
    except AttributeError:
        # Covers both "module not loaded" (None) and "attribute missing".
        return
    module.maybe_remap_kv_scale_name = remap
# Keep a handle on the import function that was installed before this hook.
# ``builtins.__import__`` is used instead of ``__builtins__["__import__"]``
# because ``__builtins__`` is a dict only in imported modules (a CPython
# detail); in ``__main__`` it is the module object and is not subscriptable.
original_import = builtins.__import__


def patched_import(name, globals=None, locals=None, fromlist=(), level=0):
    """Import as usual, then run the patch registered for ``name``, if any.

    The signature mirrors the built-in ``__import__`` so this function can
    stand in for it. The value returned is always exactly what the original
    import function returned.

    NOTE: only imports routed through ``builtins.__import__`` (i.e. import
    statements) reach this hook; ``importlib.import_module()`` does not.
    """
    module = original_import(name, globals, locals, fromlist, level)
    # Registry keys are absolute dotted names, so a relative import
    # (level > 0) cannot be matched by ``name``.
    if level == 0 and name in ImportPatchDecorator._patches:
        # For ``import a.b.c`` (empty fromlist) ``__import__`` returns the
        # top-level package ``a``, not ``a.b.c``. Fetch the module actually
        # named ``name`` so the patch runs against the right object.
        target = sys.modules.get(name, module)
        try:
            ImportPatchDecorator._patches[name](target)
        except Exception as e:
            logger.error(f"Patch application failed during import {name}: {e}")
    return module


# Install the hook for future imports, then patch modules already loaded.
builtins.__import__ = patched_import
ImportPatchDecorator.apply_patches()