300 lines
10 KiB
Python
300 lines
10 KiB
Python
import functools
import importlib
import pkgutil
import sys
import types
from typing import List
|
|
|
|
# Package prefix under which the replacement for an original dotted path ``P``
# is looked up, i.e. at ``PATCH_ROOT + P``. The trailing dot is part of the
# prefix so plain string concatenation yields a valid dotted path.
PATCH_ROOT: str = "vllm_vacc."
|
|
|
|
|
|
def get_func_name(func):
    """Return the fully-qualified dotted name of *func*.

    A string is taken to already be such a name and is passed through
    unchanged; any other object is rendered as
    ``"<__module__>.<__qualname__>"``.
    """
    if not isinstance(func, str):
        parts = [func.__module__, func.__qualname__]
        return ".".join(parts)
    return func
|
|
|
|
|
|
def dummy_function_wrapper(func_name):
    """Build a placeholder callable for a function that does not exist.

    The returned callable accepts any arguments and always raises
    ``RuntimeError`` naming *func_name*. The inner function keeps the name
    ``dummy_function`` (it must not end in ``wrapper``/``decorator``, or
    ``Patch.set_patch_func`` would treat it as a decorator).
    """

    def dummy_function(*args, **kwargs):
        # Fail at call time, not when the placeholder is created/installed.
        message = f"function {func_name} no exist"
        raise RuntimeError(message)

    return dummy_function
|
|
|
|
|
|
def dummy_jit(fn):
    """No-op stand-in for a JIT decorator: return a pass-through wrapper of *fn*.

    ``functools.wraps`` copies ``fn``'s ``__name__``, ``__qualname__`` and
    ``__doc__`` and sets ``__wrapped__``, so the decorated function keeps its
    identity for introspection. Without it, every decorated function would be
    named ``"wrapper"``, which ``Patch.set_patch_func`` classifies as a
    decorator to apply rather than as a replacement function.
    """

    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        return fn(*args, **kwargs)

    return wrapper
|
|
|
|
|
|
class Patch:
    """One registered replacement for the attribute at a dotted path.

    The path ``"pkg.mod.name"`` is split into a module path (``"pkg.mod"``)
    and an attribute name (``"name"``); a path without a dot is a bare module
    name with no attribute. The replacement is either a plain object, which is
    installed as-is, or a decorator whose ``__name__`` ends with ``"wrapper"``
    or ``"decorator"``, which is applied on top of the base function when the
    patch is applied.
    """

    def __init__(self, orig_func_name, new_func, create_dummy):
        """
        Args:
            orig_func_name: dotted path of the attribute to replace.
            new_func: replacement object or wrapper. ``None`` installs a dummy
                function that raises ``RuntimeError`` when called.
            create_dummy: if true, missing modules/attributes along the path
                are stubbed out by ``parse_path`` instead of raising.
        """
        self.orig_module_name, self.orig_func_name = Patch._split_path(orig_func_name)
        # Resolved lazily by apply_patch().
        self.orig_module = None
        self.orig_func = None

        # Replacement object (None when only wrappers have been registered).
        self.patch_func = None
        # Decorators applied, in registration order, on apply_patch().
        self.wrappers = []
        if new_func is None:
            new_func = dummy_function_wrapper(orig_func_name)
        self.set_patch_func(new_func)
        self.is_applied = False
        self.create_dummy = create_dummy

    @staticmethod
    def _split_path(path):
        """Split ``"a.b.c"`` into ``("a.b", "c")``; a dot-free path gives ``(path, None)``."""
        module_name, sep, attr_name = path.rpartition(".")
        if not sep:
            return path, None
        return module_name, attr_name

    @property
    def orig_func_id(self):
        """``id()`` of the resolved original object (``id(None)`` before apply)."""
        return id(self.orig_func)

    @property
    def patch_func_id(self):
        """``id()`` of the current replacement object."""
        return id(self.patch_func)

    def set_patch_func(self, new_func, force_patch=False):
        """Record *new_func* either as a wrapper or as the replacement itself.

        Objects whose ``__name__`` ends with ``"wrapper"`` or ``"decorator"``
        are queued in ``self.wrappers``; anything else becomes
        ``self.patch_func``. Either way the patch is marked as not applied.

        Raises:
            RuntimeError: a replacement is already set and *force_patch* is false.
        """
        if hasattr(new_func, "__name__") and new_func.__name__.endswith(
            ("wrapper", "decorator")
        ):
            self.wrappers.append(new_func)
        else:
            if self.patch_func and not force_patch:
                raise RuntimeError(
                    f"The patch of '{self.orig_func_name}' ('{self.patch_func}') exist!"
                )
            self.patch_func = new_func
        self.is_applied = False

    def apply_patch(self):
        """Install the replacement; a no-op if already applied.

        Resolves the original via ``parse_path``, builds the final replacement
        (the registered object, or the original when only wrappers were given,
        then every wrapper applied in order), sets it on the original module,
        and rebinds any other loaded module whose attribute of the same name is
        the very same original object (e.g. ``from pkg.mod import name``).
        """
        if self.is_applied:
            return

        self.orig_module, self.orig_func = Patch.parse_path(
            self.orig_module_name, self.orig_func_name, self.create_dummy
        )
        if self.patch_func is None:
            # Only wrappers were registered: decorate the original itself.
            self.patch_func = self.orig_func

        for wrapper in self.wrappers:
            self.patch_func = wrapper(self.patch_func)

        if self.orig_func_name is not None:
            setattr(self.orig_module, self.orig_func_name, self.patch_func)

            # Walk every loaded module (all installed libraries, not only vllm)
            # and rebind references to the original object. Some third-party
            # modules may misbehave on attribute access, so failures on a single
            # module are skipped. TODO: consider limiting this to vllm-related
            # modules.
            for module in list(sys.modules.values()):
                try:
                    if (
                        hasattr(module, self.orig_func_name)
                        and getattr(module, self.orig_func_name) is self.orig_func
                    ):
                        setattr(module, self.orig_func_name, self.patch_func)
                except Exception:
                    # Best effort: one odd module must not abort the patch, but
                    # KeyboardInterrupt/SystemExit are no longer swallowed.
                    continue

        self.is_applied = True

    @staticmethod
    def parse_function(function_path: str, create_dummy):
        """Return the object at dotted *function_path* (see ``parse_path``)."""
        module_name, func_name = Patch._split_path(function_path)
        return Patch.parse_path(module_name, func_name, create_dummy)[1]

    @staticmethod
    def parse_path(module_path, function_name, create_dummy):
        """Import *module_path* prefix by prefix and look up *function_name*.

        Returns:
            ``(module, attr)``. ``attr`` is ``None`` when *function_name* is
            ``None``. If the final module exists but lacks *function_name*, the
            attribute is first created on it with value ``None``.

        Raises:
            ModuleNotFoundError: a path component cannot be imported, is not an
                attribute of its parent, and *create_dummy* is false.
            RuntimeError: a component is a non-module attribute of its parent
                lacking *function_name*, and *create_dummy* is false.
        """
        from importlib.machinery import ModuleSpec

        modules = module_path.split(".")
        for i in range(1, len(modules) + 1):
            parent = ".".join(modules[: i - 1])
            path = ".".join(modules[:i])
            try:
                importlib.import_module(path)
            except ModuleNotFoundError as e:
                if not parent or not hasattr(
                    importlib.import_module(parent), modules[i - 1]
                ):
                    if not create_dummy:
                        raise ModuleNotFoundError(e) from e
                    # Stub the missing module so the rest of the path resolves.
                    sys.modules[path] = types.ModuleType(path)
                    sys.modules[path].__file__ = "patch_tools.dummy_module.py"
                    sys.modules[path].__spec__ = ModuleSpec(path, None)
                    if parent:
                        setattr(
                            importlib.import_module(parent),
                            modules[i - 1],
                            sys.modules[path],
                        )
                else:
                    # The component is an attribute of its parent (e.g. a
                    # class), not an importable module: look the name up on it.
                    # NOTE(review): any path components after this one are
                    # ignored here.
                    module = getattr(importlib.import_module(parent), modules[i - 1])
                    if hasattr(module, function_name):
                        return module, getattr(module, function_name)
                    elif create_dummy:
                        return module, dummy_function_wrapper(function_name)
                    else:
                        raise RuntimeError(
                            f"Function '{function_name}' of '{module}' does not exist."
                        ) from e

        target_module = sys.modules[module_path]
        if function_name is None:
            return target_module, None
        if not hasattr(target_module, function_name):
            setattr(target_module, function_name, None)
        return target_module, getattr(target_module, function_name)
|
|
|
|
|
|
class PatchManager:
    """Class-level registry of ``Patch`` objects keyed by original dotted path.

    Function/attribute patches are registered first (``register_patch`` /
    ``batch_register_patch``) and installed together by ``apply_patches``.
    Whole-module replacement (``recursive_register_module``) takes effect
    immediately by swapping entries in ``sys.modules``.
    """

    # Shared registry: original dotted path -> Patch. Deliberately class-level
    # so every caller registers into, and applies from, the same dict.
    patches_info: dict = {}
    # Set to True by apply_patches().
    patched: bool = False

    @classmethod
    def get_patch_info(cls):
        """Return the shared registry dict itself (not a copy)."""
        return cls.patches_info

    @classmethod
    def register_patch(
        cls,
        orig_func_name,
        new_func=None,
        force_patch=False,
        create_dummy=False,
        allow_create=False,
    ):
        """Register *new_func* as the replacement for *orig_func_name*.

        If new_func is written via @wraps (i.e. it is a decorator), its name must
        end with `wrapper` or `decorator`; conversely, anything whose name ends
        with `wrapper` or `decorator` is applied as a decorator, so it must be
        written via `@wraps`.

        Args:
            orig_func_name: dotted path of the attribute/module to patch.
            new_func: replacement or wrapper; ``None`` means a raising dummy
                (see ``Patch``).
            force_patch: overwrite an already-registered replacement.
            create_dummy: let ``Patch`` stub out missing path components.
            allow_create: accept *orig_func_name* even if it does not resolve now.

        Raises:
            ValueError: the path does not resolve and *allow_create* is false.
            RuntimeError: a replacement is already registered and
                *force_patch* is false.
        """
        if not cls._path_valid(orig_func_name) and not allow_create:
            raise ValueError(
                f"Module/function path '{orig_func_name}' does not exist, and allow_create=False."
            )

        patch_info = cls.get_patch_info()
        if orig_func_name not in patch_info:
            patch_info[orig_func_name] = Patch(orig_func_name, new_func, create_dummy)
        else:
            patch_info[orig_func_name].set_patch_func(new_func, force_patch)

    @classmethod
    def _is_module(cls, module_path):
        """Return True if *module_path* imports, False on ``ModuleNotFoundError``.

        Other import-time errors propagate to the caller.
        """
        try:
            importlib.import_module(module_path)
            return True
        except ModuleNotFoundError:
            return False

    @classmethod
    def recursive_register_module(cls, module_path, allow_create=False):
        """Replace module *module_path* and all its submodules by their
        counterparts under ``PATCH_ROOT``.

        For each module, the original is imported first if it exists
        (presumably so its import-time side effects still run — NOTE(review):
        confirm). Then ``sys.modules`` is pointed at the patched module, and the
        attribute on the parent package is rebound to it too.

        Args:
            module_path: dotted path of the original module;
                ``PATCH_ROOT + module_path`` must be importable.
            allow_create: accepted for signature compatibility; unused.
        """
        assert cls._is_module(
            PATCH_ROOT + module_path
        ), f'"{PATCH_ROOT}{module_path}" is not a valid module path. Only use this function to register module patches (not functions or classes).'

        submodules = [
            name[len(PATCH_ROOT) :]
            for name in cls.enumerate_submodules(PATCH_ROOT + module_path)
        ]
        for module in [module_path] + submodules:
            try:
                importlib.import_module(module)
            except ModuleNotFoundError:
                # The original may be absent; install the patched module anyway.
                pass
            patched_module = importlib.import_module(PATCH_ROOT + module)
            sys.modules[module] = patched_module
            # Swapping sys.modules alone is not enough: ``import a.b`` followed
            # by ``a.b.x``, and ``from a import b``, read the attribute on the
            # parent package, which would still point at the original module.
            parent_name, _, child_name = module.rpartition(".")
            parent = sys.modules.get(parent_name) if parent_name else None
            if parent is not None:
                setattr(parent, child_name, patched_module)

    @classmethod
    def batch_recursive_register_module(cls, module_paths, allow_create=False):
        """Call ``recursive_register_module`` for each path in *module_paths*."""
        for module_path in module_paths:
            cls.recursive_register_module(module_path, allow_create=allow_create)

    @classmethod
    def enumerate_submodules(cls, module_name):
        """Return the dotted names of all submodules below package *module_name*.

        Returns ``[]`` (after printing the error) if the module cannot be
        imported, and ``[]`` for a plain module, which has no ``__path__`` and
        hence no submodules.
        """
        try:
            module = importlib.import_module(module_name)
        except ImportError as e:
            print(f"Error importing module {module_name}: {e}")
            return []

        search_path = getattr(module, "__path__", None)
        if search_path is None:
            return []
        return [
            submodule_name
            for _, submodule_name, _ in pkgutil.walk_packages(
                search_path, module.__name__ + "."
            )
        ]

    @classmethod
    def batch_register_patch(
        cls,
        orig_func_names: List,
        force_patch=False,
        create_dummy=False,
        allow_create=False,
    ):
        """Register patches for many paths, locating each replacement by convention.

        For every name, the replacement lives at the same dotted path prefixed
        with ``PATCH_ROOT``, with ``PATCH_ROOT + name + "_wrapper"`` preferred
        when it exists. If the resolved attribute is ``None``, the replacement
        falls back to importing that path as a module. Names whose original path
        does not resolve are skipped with a warning unless *allow_create* is
        true.
        """
        for orig_func_name in orig_func_names:
            if not cls._path_valid(orig_func_name) and not allow_create:
                print(f"WARNING: path '{orig_func_name}' is not valid, skipped.")
                continue

            wrapper_name = orig_func_name
            if cls._path_valid(PATCH_ROOT + wrapper_name + "_wrapper"):
                wrapper_name = wrapper_name + "_wrapper"
            assert cls._path_valid(
                PATCH_ROOT + wrapper_name
            ), f"'{PATCH_ROOT}{wrapper_name}' or '{PATCH_ROOT}{wrapper_name}_wrapper' must be a valid module/function path. Try import {PATCH_ROOT}{wrapper_name} to see if other errors exist."
            new_func = Patch.parse_function(
                PATCH_ROOT + wrapper_name, create_dummy=False
            )
            if new_func is None:
                new_func = importlib.import_module(PATCH_ROOT + wrapper_name)
            cls.register_patch(
                orig_func_name,
                new_func,
                force_patch=force_patch,
                create_dummy=create_dummy,
                allow_create=allow_create,
            )

    @classmethod
    def _path_valid(cls, path):
        """Return True if *path* names an importable module or an attribute chain under one.

        The longest importable prefix is imported (a side effect), then every
        remaining component must exist as an attribute.
        """
        components = path.split(".")
        for i in range(len(components), 0, -1):
            module_name = ".".join(components[:i])
            try:
                module = importlib.import_module(module_name)
                break
            except ImportError:
                continue
        else:
            return False

        for component in components[i:]:
            if hasattr(module, component):
                module = getattr(module, component)
            else:
                return False
        return True

    @classmethod
    def apply_patches(cls):
        """Apply every registered patch and mark the manager as patched."""
        patch_info = cls.get_patch_info()
        for patch in patch_info.values():
            patch.apply_patch()
        cls.patched = True
|