init
This commit is contained in:
299
vllm_vacc/patch_util.py
Normal file
299
vllm_vacc/patch_util.py
Normal file
@@ -0,0 +1,299 @@
|
||||
import importlib
|
||||
import sys
|
||||
import pkgutil
|
||||
import types
|
||||
from typing import List
|
||||
|
||||
PATCH_ROOT = "vllm_vacc."
|
||||
|
||||
|
||||
def get_func_name(func):
|
||||
if isinstance(func, str):
|
||||
return func
|
||||
return ".".join((func.__module__, func.__qualname__))
|
||||
|
||||
|
||||
def dummy_function_wrapper(func_name):
|
||||
def dummy_function(*args, **kwargs):
|
||||
raise RuntimeError(f"function {func_name} no exist")
|
||||
|
||||
return dummy_function
|
||||
|
||||
|
||||
def dummy_jit(fn):
|
||||
def wrapper(*args, **kwargs):
|
||||
return fn(*args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
class Patch:
|
||||
def __init__(self, orig_func_name, new_func, create_dummy):
|
||||
split_name = orig_func_name.rsplit(".", 1)
|
||||
if len(split_name) == 1:
|
||||
self.orig_module_name, self.orig_func_name = orig_func_name, None
|
||||
else:
|
||||
self.orig_module_name, self.orig_func_name = split_name
|
||||
self.orig_module = None
|
||||
self.orig_func = None
|
||||
|
||||
self.patch_func = None
|
||||
self.wrappers = []
|
||||
if new_func is None:
|
||||
new_func = dummy_function_wrapper(orig_func_name)
|
||||
self.set_patch_func(new_func)
|
||||
self.is_applied = False
|
||||
self.create_dummy = create_dummy
|
||||
|
||||
@property
|
||||
def orig_func_id(self):
|
||||
return id(self.orig_func)
|
||||
|
||||
@property
|
||||
def patch_func_id(self):
|
||||
return id(self.patch_func)
|
||||
|
||||
def set_patch_func(self, new_func, force_patch=False):
|
||||
if hasattr(new_func, "__name__") and new_func.__name__.endswith(
|
||||
("wrapper", "decorator")
|
||||
):
|
||||
self.wrappers.append(new_func)
|
||||
else:
|
||||
if self.patch_func and not force_patch:
|
||||
raise RuntimeError(
|
||||
f"The patch of '{self.orig_func_name}' ('{self.patch_func}') exist!"
|
||||
)
|
||||
self.patch_func = new_func
|
||||
self.is_applied = False
|
||||
|
||||
def apply_patch(self):
|
||||
if self.is_applied:
|
||||
return
|
||||
|
||||
self.orig_module, self.orig_func = Patch.parse_path(
|
||||
self.orig_module_name, self.orig_func_name, self.create_dummy
|
||||
)
|
||||
if self.patch_func is None:
|
||||
self.patch_func = self.orig_func
|
||||
|
||||
for wrapper in self.wrappers:
|
||||
self.patch_func = wrapper(self.patch_func)
|
||||
|
||||
if self.orig_func_name is not None:
|
||||
setattr(self.orig_module, self.orig_func_name, self.patch_func)
|
||||
for key, value in sys.modules.copy().items():
|
||||
# 遍历 pip 所有库, 然后 setattr, 有些库不匹配 可能会有问题, 这里是否可以优化 只遍历vllm相关
|
||||
try:
|
||||
if (
|
||||
self.orig_func_name is not None
|
||||
and hasattr(value, self.orig_func_name)
|
||||
and id(getattr(value, self.orig_func_name)) == self.orig_func_id
|
||||
):
|
||||
setattr(value, self.orig_func_name, self.patch_func)
|
||||
except:
|
||||
continue
|
||||
|
||||
self.is_applied = True
|
||||
|
||||
@staticmethod
|
||||
def parse_function(function_path: str, create_dummy):
|
||||
split_name = function_path.rsplit(".", 1)
|
||||
if len(split_name) == 1:
|
||||
orig_module_name, orig_func_name = function_path, None
|
||||
else:
|
||||
orig_module_name, orig_func_name = split_name
|
||||
return Patch.parse_path(orig_module_name, orig_func_name, create_dummy)[1]
|
||||
|
||||
@staticmethod
|
||||
def parse_path(module_path, function_name, create_dummy):
|
||||
from importlib.machinery import ModuleSpec
|
||||
|
||||
modules = module_path.split(".")
|
||||
for i in range(1, len(modules) + 1):
|
||||
parent = ".".join(modules[: i - 1])
|
||||
path = ".".join(modules[:i])
|
||||
try:
|
||||
importlib.import_module(path)
|
||||
except ModuleNotFoundError as e:
|
||||
if not parent or not hasattr(
|
||||
importlib.import_module(parent), modules[i - 1]
|
||||
):
|
||||
if not create_dummy:
|
||||
raise ModuleNotFoundError(e) from e
|
||||
sys.modules[path] = types.ModuleType(path)
|
||||
sys.modules[path].__file__ = "patch_tools.dummy_module.py"
|
||||
sys.modules[path].__spec__ = ModuleSpec(path, None)
|
||||
if parent:
|
||||
setattr(
|
||||
importlib.import_module(parent),
|
||||
modules[i - 1],
|
||||
sys.modules[path],
|
||||
)
|
||||
else:
|
||||
module = getattr(importlib.import_module(parent), modules[i - 1])
|
||||
if hasattr(module, function_name):
|
||||
return module, getattr(module, function_name)
|
||||
elif create_dummy:
|
||||
return module, dummy_function_wrapper(function_name)
|
||||
else:
|
||||
raise RuntimeError(
|
||||
f"Function '{function_name}' of '{module}' does not exist."
|
||||
) from e
|
||||
|
||||
if function_name is not None and not hasattr(
|
||||
sys.modules[module_path], function_name
|
||||
):
|
||||
setattr(sys.modules[module_path], function_name, None)
|
||||
return sys.modules[module_path], (
|
||||
getattr(sys.modules[module_path], function_name)
|
||||
if function_name is not None
|
||||
else None
|
||||
)
|
||||
|
||||
|
||||
class PatchManager:
|
||||
patches_info: dict = {}
|
||||
patched: bool = False
|
||||
|
||||
@classmethod
|
||||
def get_patch_info(cls):
|
||||
return cls.patches_info
|
||||
|
||||
@classmethod
|
||||
def register_patch(
|
||||
cls,
|
||||
orig_func_name,
|
||||
new_func=None,
|
||||
force_patch=False,
|
||||
create_dummy=False,
|
||||
allow_create=False,
|
||||
):
|
||||
"""
|
||||
if new_func is written via @wraps, its name must be ended with `wrapper` or `decorator`,
|
||||
also if it ends with `wrapper` or `decorator`, it must be written via `@wraps`
|
||||
"""
|
||||
if not cls._path_valid(orig_func_name):
|
||||
if not allow_create:
|
||||
raise ValueError(
|
||||
f"Module/function path '{orig_func_name}' does not exist, and allow_create=False."
|
||||
)
|
||||
|
||||
# if not create_dummy and not cls._path_valid(orig_func_name):
|
||||
# print(f"WARNING: path '{orig_func_name}' is not valid, skipped.")
|
||||
# return
|
||||
|
||||
patch_info = cls.get_patch_info()
|
||||
if orig_func_name not in patch_info:
|
||||
patch_info[orig_func_name] = Patch(orig_func_name, new_func, create_dummy)
|
||||
else:
|
||||
patch_info[orig_func_name].set_patch_func(new_func, force_patch)
|
||||
|
||||
@classmethod
|
||||
def _is_module(self, module_path):
|
||||
try:
|
||||
importlib.import_module(module_path)
|
||||
return True
|
||||
except ModuleNotFoundError:
|
||||
return False
|
||||
|
||||
@classmethod
|
||||
def recursive_register_module(cls, module_path, allow_create=False):
|
||||
# replace whole submodule in a module
|
||||
assert cls._is_module(
|
||||
PATCH_ROOT + module_path
|
||||
), f'"{PATCH_ROOT}{module_path}" is not a valid module path. Only use this function to register module patches (not functions or classes).'
|
||||
|
||||
all_submodules = cls.enumerate_submodules(PATCH_ROOT + module_path)
|
||||
all_submodules = [submodule[len(PATCH_ROOT) :] for submodule in all_submodules]
|
||||
all_submodules = [module_path] + all_submodules
|
||||
|
||||
for module in all_submodules:
|
||||
try:
|
||||
importlib.import_module(module)
|
||||
except ModuleNotFoundError as e:
|
||||
pass
|
||||
sys.modules[module] = importlib.import_module(PATCH_ROOT + module)
|
||||
|
||||
@classmethod
|
||||
def batch_recursive_register_module(cls, module_paths, allow_create=False):
|
||||
for module_path in module_paths:
|
||||
cls.recursive_register_module(module_path, allow_create=allow_create)
|
||||
|
||||
@classmethod
|
||||
def enumerate_submodules(cls, module_name):
|
||||
try:
|
||||
module = importlib.import_module(module_name)
|
||||
except ImportError as e:
|
||||
print(f"Error importing module {module_name}: {e}")
|
||||
return []
|
||||
|
||||
submodules = []
|
||||
for loader, submodule_name, is_pkg in pkgutil.walk_packages(
|
||||
module.__path__, module.__name__ + "."
|
||||
):
|
||||
submodules.append(submodule_name)
|
||||
|
||||
return submodules
|
||||
|
||||
@classmethod
|
||||
def batch_register_patch(
|
||||
cls,
|
||||
orig_func_names: List,
|
||||
force_patch=False,
|
||||
create_dummy=False,
|
||||
allow_create=False,
|
||||
):
|
||||
"""
|
||||
This function assumes all new_func are organized in same path like orig_func_names except prefixed with 'vastext.'
|
||||
"""
|
||||
for orig_func_name in orig_func_names:
|
||||
if not cls._path_valid(orig_func_name) and not allow_create:
|
||||
print(f"WARNING: path '{orig_func_name}' is not valid, skipped.")
|
||||
continue
|
||||
|
||||
wrapper_name = orig_func_name
|
||||
if cls._path_valid(PATCH_ROOT + wrapper_name + "_wrapper"):
|
||||
wrapper_name = wrapper_name + "_wrapper"
|
||||
assert cls._path_valid(
|
||||
PATCH_ROOT + wrapper_name
|
||||
), f"'{PATCH_ROOT}{wrapper_name}' or '{PATCH_ROOT}{wrapper_name}_wrapper' must be a valid module/function path. Try import {PATCH_ROOT}{wrapper_name} to see if other errors exist."
|
||||
new_func = Patch.parse_function(
|
||||
PATCH_ROOT + wrapper_name, create_dummy=False
|
||||
)
|
||||
if new_func is None:
|
||||
new_func = importlib.import_module(PATCH_ROOT + wrapper_name)
|
||||
# print(f">>> Register patch or function: '{orig_func_name}' -> '{new_func}'")
|
||||
cls.register_patch(
|
||||
orig_func_name,
|
||||
new_func,
|
||||
force_patch=force_patch,
|
||||
create_dummy=create_dummy,
|
||||
allow_create=allow_create,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _path_valid(cls, path):
|
||||
components = path.split(".")
|
||||
for i in range(len(components), 0, -1):
|
||||
module_name = ".".join(components[:i])
|
||||
try:
|
||||
module = importlib.import_module(module_name)
|
||||
break
|
||||
except ImportError:
|
||||
continue
|
||||
else:
|
||||
return False
|
||||
|
||||
for component in components[i:]:
|
||||
if hasattr(module, component):
|
||||
module = getattr(module, component)
|
||||
else:
|
||||
return False
|
||||
return True
|
||||
|
||||
@classmethod
|
||||
def apply_patches(cls):
|
||||
patch_info = cls.get_patch_info()
|
||||
for patch in patch_info.values():
|
||||
patch.apply_patch()
|
||||
cls.patched = True
|
||||
Reference in New Issue
Block a user