Files
2026-04-02 04:55:00 +00:00

300 lines
10 KiB
Python

import importlib
import sys
import pkgutil
import types
from typing import List
# Package prefix under which replacement code lives: the patch for a dotted
# path "x.y" is looked up at PATCH_ROOT + "x.y" (see PatchManager).
PATCH_ROOT = "vllm_vacc."
def get_func_name(func):
    """Return the fully qualified ``module.qualname`` name of *func*.

    A string argument is assumed to already be such a name and is returned
    unchanged.
    """
    if not isinstance(func, str):
        name_parts = [func.__module__, func.__qualname__]
        return ".".join(name_parts)
    return func
def dummy_function_wrapper(func_name):
    """Build a placeholder callable that stands in for a missing function.

    The returned callable (named ``dummy_function``) accepts any arguments
    and always raises ``RuntimeError`` mentioning *func_name*.
    """

    def dummy_function(*_ignored_args, **_ignored_kwargs):
        raise RuntimeError(f"function {func_name} no exist")

    return dummy_function
def dummy_jit(fn):
    """Decorator that returns a plain forwarding wrapper around *fn*.

    Presumably a no-op stand-in for a real JIT decorator. The returned
    callable is named ``wrapper``. It passes every positional and keyword
    argument straight through to *fn* and returns the result.
    """

    def wrapper(*call_args, **call_kwargs):
        result = fn(*call_args, **call_kwargs)
        return result

    return wrapper
class Patch:
    """One registered monkey-patch for a module attribute (or a module).

    ``orig_func_name`` is a dotted path. Its last component is the attribute
    to replace and the rest is the owning module path. A path without a dot
    names a module and no attribute.

    The replacement is either a plain object (``patch_func``) or a stack of
    decorators (``wrappers``), recognised by a ``__name__`` ending in
    ``"wrapper"`` or ``"decorator"``. Wrappers are applied in registration
    order on top of ``patch_func``. When no plain replacement was registered,
    they are applied on top of the original attribute.
    """

    def __init__(self, orig_func_name, new_func, create_dummy):
        split_name = orig_func_name.rsplit(".", 1)
        if len(split_name) == 1:
            self.orig_module_name, self.orig_func_name = orig_func_name, None
        else:
            self.orig_module_name, self.orig_func_name = split_name
        self.orig_module = None
        self.orig_func = None
        self.patch_func = None
        self.wrappers = []
        # Number of leading entries of ``self.wrappers`` already folded into
        # ``self.patch_func``. It lets ``apply_patch`` run again after new
        # wrappers are registered without wrapping the same function twice.
        self._num_wrapped = 0
        if new_func is None:
            new_func = dummy_function_wrapper(orig_func_name)
        self.set_patch_func(new_func)
        self.is_applied = False
        self.create_dummy = create_dummy

    @property
    def orig_func_id(self):
        """``id()`` of the object most recently resolved as the original."""
        return id(self.orig_func)

    @property
    def patch_func_id(self):
        """``id()`` of the current replacement object."""
        return id(self.patch_func)

    def set_patch_func(self, new_func, force_patch=False):
        """Register *new_func* either as a wrapper or as the replacement.

        A callable whose ``__name__`` ends with ``"wrapper"`` or
        ``"decorator"`` is queued as a wrapper. Anything else becomes the
        replacement ``patch_func``.

        Raises:
            RuntimeError: a (truthy) replacement is already set and
                *force_patch* is false.

        In every case the patch is marked as not yet applied.
        """
        if hasattr(new_func, "__name__") and new_func.__name__.endswith(
            ("wrapper", "decorator")
        ):
            self.wrappers.append(new_func)
        else:
            if self.patch_func and not force_patch:
                raise RuntimeError(
                    f"The patch of '{self.orig_func_name}' ('{self.patch_func}') exist!"
                )
            self.patch_func = new_func
            # New base object: every queued wrapper must be re-applied to it.
            self._num_wrapped = 0
        self.is_applied = False

    def apply_patch(self):
        """Install the patch and rebind stale references in loaded modules.

        Steps:

        1. Resolve the target via ``parse_path``; placeholders are created
           when ``create_dummy`` is set.
        2. Fold any not-yet-applied wrappers into ``patch_func``.
        3. Set the attribute on the owning module.
        4. In every module in ``sys.modules``, replace an attribute of the
           same name whose value is the very object just replaced (for
           example, copies made by ``from x import name``).

        Calling it again after ``set_patch_func`` applies only the newly
        registered wrappers. Does nothing if already applied.
        """
        if self.is_applied:
            return
        self.orig_module, self.orig_func = Patch.parse_path(
            self.orig_module_name, self.orig_func_name, self.create_dummy
        )
        patched = self.orig_func if self.patch_func is None else self.patch_func
        for wrapper in self.wrappers[self._num_wrapped :]:
            patched = wrapper(patched)
        # Commit only after every wrapper succeeded.
        self.patch_func = patched
        self._num_wrapped = len(self.wrappers)
        if self.orig_func_name is not None:
            setattr(self.orig_module, self.orig_func_name, self.patch_func)
            # NOTE: this walks every loaded module, not only vllm-related ones;
            # restricting the walk to vllm modules could be an optimization.
            # Matching is by identity. If the replaced object is ``None``, any
            # module attribute with this name whose value is ``None`` is
            # rebound too.
            orig_func_id = self.orig_func_id
            for value in list(sys.modules.values()):
                try:
                    if (
                        hasattr(value, self.orig_func_name)
                        and id(getattr(value, self.orig_func_name)) == orig_func_id
                    ):
                        setattr(value, self.orig_func_name, self.patch_func)
                except Exception:
                    # Some modules reject attribute access or assignment;
                    # skip them rather than abort the whole scan.
                    continue
        self.is_applied = True

    @staticmethod
    def parse_function(function_path: str, create_dummy):
        """Resolve dotted *function_path* and return only the attribute.

        The path is split like ``__init__`` does and passed to
        ``parse_path``. The result is ``None`` for a path without a dot.
        """
        split_name = function_path.rsplit(".", 1)
        if len(split_name) == 1:
            orig_module_name, orig_func_name = function_path, None
        else:
            orig_module_name, orig_func_name = split_name
        return Patch.parse_path(orig_module_name, orig_func_name, create_dummy)[1]

    @staticmethod
    def parse_path(module_path, function_name, create_dummy):
        """Import *module_path* and return ``(owner, attribute)``.

        Each dotted prefix of *module_path* is imported in turn. When a
        prefix cannot be imported (``ModuleNotFoundError``):

        * If it exists as an attribute of its parent (e.g. a class),
          resolution stops and ``(that_object, that_object.function_name)``
          is returned. When the attribute is missing, a dummy function is
          returned if *create_dummy* is set; otherwise ``RuntimeError`` is
          raised.
        * Otherwise, if *create_dummy* is set, an empty placeholder module is
          registered in ``sys.modules`` and attached to its parent. If
          *create_dummy* is not set, ``ModuleNotFoundError`` is raised.

        If *function_name* is not ``None`` and is missing from the resolved
        module, it is created there with value ``None``. The returned
        attribute is ``None`` when *function_name* is ``None``.
        """
        from importlib.machinery import ModuleSpec

        modules = module_path.split(".")
        for i in range(1, len(modules) + 1):
            parent = ".".join(modules[: i - 1])
            path = ".".join(modules[:i])
            try:
                importlib.import_module(path)
            except ModuleNotFoundError as e:
                if not parent or not hasattr(
                    importlib.import_module(parent), modules[i - 1]
                ):
                    if not create_dummy:
                        raise ModuleNotFoundError(e) from e
                    # Register an empty stand-in so importing ``path``
                    # afterwards returns it from sys.modules.
                    sys.modules[path] = types.ModuleType(path)
                    sys.modules[path].__file__ = "patch_tools.dummy_module.py"
                    sys.modules[path].__spec__ = ModuleSpec(path, None)
                    if parent:
                        setattr(
                            importlib.import_module(parent),
                            modules[i - 1],
                            sys.modules[path],
                        )
                else:
                    # ``path`` is a non-module attribute of ``parent`` (e.g. a
                    # class): resolve ``function_name`` on it and stop.
                    # NOTE(review): components of module_path after ``path``
                    # are not traversed here.
                    module = getattr(importlib.import_module(parent), modules[i - 1])
                    if hasattr(module, function_name):
                        return module, getattr(module, function_name)
                    elif create_dummy:
                        return module, dummy_function_wrapper(function_name)
                    else:
                        raise RuntimeError(
                            f"Function '{function_name}' of '{module}' does not exist."
                        ) from e
        if function_name is not None and not hasattr(
            sys.modules[module_path], function_name
        ):
            # Missing attribute: create it as None so the lookup below succeeds.
            setattr(sys.modules[module_path], function_name, None)
        return sys.modules[module_path], (
            getattr(sys.modules[module_path], function_name)
            if function_name is not None
            else None
        )
class PatchManager:
    """Process-wide registry of ``Patch`` objects keyed by dotted path.

    State is stored on the class itself (``patches_info`` / ``patched``), so
    all callers share a single registry; every method is a classmethod.
    """

    # Shared {dotted_path: Patch} registry (intentionally class-level).
    patches_info: dict = {}
    # Set to True by apply_patches() after all registered patches are applied.
    patched: bool = False

    @classmethod
    def get_patch_info(cls):
        """Return the shared ``{dotted_path: Patch}`` registry itself (not a copy)."""
        return cls.patches_info

    @classmethod
    def register_patch(
        cls,
        orig_func_name,
        new_func=None,
        force_patch=False,
        create_dummy=False,
        allow_create=False,
    ):
        """Register *new_func* as the patch for dotted path *orig_func_name*.

        Wrapper-style patches are decorators, built with ``functools.wraps``,
        that take the original function and return a new one. They must have
        a ``__name__`` ending in ``wrapper`` or ``decorator``, and any
        callable with such a name is treated as a wrapper. Any other
        *new_func* replaces the original outright. On first registration of
        a path, ``None`` installs a stub that raises ``RuntimeError`` when
        called.

        Args:
            orig_func_name: dotted path of the module or attribute to patch.
            new_func: replacement object or wrapper (see above).
            force_patch: allow overriding an already-registered replacement.
            create_dummy: when the patch is applied, create placeholder
                modules/functions for missing parts of the path.
            allow_create: accept *orig_func_name* even if it does not
                resolve right now.

        Raises:
            ValueError: the path does not resolve and *allow_create* is false.
            RuntimeError: a replacement is already registered and
                *force_patch* is false.
        """
        if not cls._path_valid(orig_func_name) and not allow_create:
            raise ValueError(
                f"Module/function path '{orig_func_name}' does not exist, and allow_create=False."
            )
        patch_info = cls.get_patch_info()
        if orig_func_name not in patch_info:
            patch_info[orig_func_name] = Patch(orig_func_name, new_func, create_dummy)
        else:
            patch_info[orig_func_name].set_patch_func(new_func, force_patch)

    @classmethod
    def _is_module(cls, module_path):
        """Return whether *module_path* can be imported.

        Returns False if the import raises ``ModuleNotFoundError``, including
        one raised for a missing dependency of that module. Other import
        errors propagate.
        """
        try:
            importlib.import_module(module_path)
        except ModuleNotFoundError:
            return False
        return True

    @classmethod
    def recursive_register_module(cls, module_path, allow_create=False):
        """Replace *module_path* and all its submodules with patched copies.

        For each module ``m`` in ``[module_path]`` plus every submodule found
        under ``PATCH_ROOT + module_path``, ``sys.modules[m]`` is pointed at
        the imported ``PATCH_ROOT + m``. The original ``m`` is imported first
        when it exists; a missing original is tolerated. ``allow_create`` is
        accepted for API symmetry but is currently unused.
        """
        assert cls._is_module(
            PATCH_ROOT + module_path
        ), f'"{PATCH_ROOT}{module_path}" is not a valid module path. Only use this function to register module patches (not functions or classes).'
        patched_names = cls.enumerate_submodules(PATCH_ROOT + module_path)
        all_submodules = [module_path] + [
            name[len(PATCH_ROOT) :] for name in patched_names
        ]
        for module in all_submodules:
            try:
                # Presumably imported so the original (and its parent
                # packages) is initialised before its entry is overridden.
                importlib.import_module(module)
            except ModuleNotFoundError:
                pass
            sys.modules[module] = importlib.import_module(PATCH_ROOT + module)

    @classmethod
    def batch_recursive_register_module(cls, module_paths, allow_create=False):
        """Call ``recursive_register_module`` for each path in *module_paths*."""
        for module_path in module_paths:
            cls.recursive_register_module(module_path, allow_create=allow_create)

    @classmethod
    def enumerate_submodules(cls, module_name):
        """Return the dotted names of every submodule below *module_name*.

        Recurses into subpackages via ``pkgutil.walk_packages``, which imports
        the packages it visits. Returns ``[]`` in two cases:

        * *module_name* cannot be imported (the error is printed).
        * It is a plain module rather than a package, so it has no
          submodules.
        """
        try:
            module = importlib.import_module(module_name)
        except ImportError as e:
            print(f"Error importing module {module_name}: {e}")
            return []
        search_path = getattr(module, "__path__", None)
        if search_path is None:
            # Non-package modules have no __path__ to walk.
            return []
        return [
            info.name
            for info in pkgutil.walk_packages(search_path, module.__name__ + ".")
        ]

    @classmethod
    def batch_register_patch(
        cls,
        orig_func_names: List,
        force_patch=False,
        create_dummy=False,
        allow_create=False,
    ):
        """Register patches for many dotted paths by naming convention.

        The replacement for ``path`` must live at ``PATCH_ROOT + path``, i.e.
        the same path prefixed with the patch package.
        ``PATCH_ROOT + path + "_wrapper"`` is preferred when it exists. If the
        resolved attribute is ``None``, the module at that path is used
        instead.

        Invalid ``path`` entries are skipped with a warning unless
        *allow_create* is set. An ``AssertionError`` is raised when no
        replacement exists for a path that is processed.
        """
        for orig_func_name in orig_func_names:
            if not cls._path_valid(orig_func_name) and not allow_create:
                print(f"WARNING: path '{orig_func_name}' is not valid, skipped.")
                continue
            wrapper_name = orig_func_name
            if cls._path_valid(PATCH_ROOT + wrapper_name + "_wrapper"):
                wrapper_name = wrapper_name + "_wrapper"
            assert cls._path_valid(
                PATCH_ROOT + wrapper_name
            ), f"'{PATCH_ROOT}{wrapper_name}' or '{PATCH_ROOT}{wrapper_name}_wrapper' must be a valid module/function path. Try import {PATCH_ROOT}{wrapper_name} to see if other errors exist."
            new_func = Patch.parse_function(
                PATCH_ROOT + wrapper_name, create_dummy=False
            )
            if new_func is None:
                new_func = importlib.import_module(PATCH_ROOT + wrapper_name)
            cls.register_patch(
                orig_func_name,
                new_func,
                force_patch=force_patch,
                create_dummy=create_dummy,
                allow_create=allow_create,
            )

    @classmethod
    def _path_valid(cls, path):
        """Return whether dotted *path* resolves to a module or attribute.

        The longest importable prefix is imported, then the remaining
        components are resolved with ``getattr``. Errors other than
        ``ImportError`` raised while importing propagate.
        """
        components = path.split(".")
        # Import the longest importable prefix of the path ...
        for i in range(len(components), 0, -1):
            try:
                obj = importlib.import_module(".".join(components[:i]))
            except ImportError:
                continue
            break
        else:
            return False
        # ... then resolve the remaining components as attributes.
        for component in components[i:]:
            if not hasattr(obj, component):
                return False
            obj = getattr(obj, component)
        return True

    @classmethod
    def apply_patches(cls):
        """Apply every registered patch, then set ``patched`` to True."""
        patch_info = cls.get_patch_info()
        for patch in patch_info.values():
            patch.apply_patch()
        cls.patched = True