import functools
import importlib
import pkgutil
import sys
import types
from typing import List

# Package that holds the replacement implementations: for a target
# ``pkg.mod.func``, ``PatchManager.batch_register_patch`` looks the
# replacement up at ``vllm_vacc.pkg.mod.func`` (or ``..._wrapper``).
PATCH_ROOT = "vllm_vacc."


def get_func_name(func):
    """Return the dotted ``module.qualname`` of *func*; a string is returned unchanged."""
    if isinstance(func, str):
        return func
    return ".".join((func.__module__, func.__qualname__))


def dummy_function_wrapper(func_name):
    """Return a placeholder callable that raises ``RuntimeError`` when invoked.

    Used when a patch is registered without a replacement, and by
    ``Patch.parse_path`` when ``create_dummy`` asks for a stand-in for a
    missing function.
    """

    def dummy_function(*args, **kwargs):
        raise RuntimeError(f"function {func_name} no exist")

    return dummy_function


def dummy_jit(fn=None, *args, **kwargs):
    """No-op stand-in for a JIT decorator: the decorated callable runs as-is.

    Supports both the bare form (``@dummy_jit``) and the called form
    (``@dummy_jit(option=...)``); any extra arguments are ignored.
    ``functools.wraps`` keeps the decorated function's ``__name__``; without
    it every jitted function would be named ``wrapper`` and
    ``Patch.set_patch_func`` (which classifies callables by a
    ``wrapper``/``decorator`` name suffix) would mistake it for a patch wrapper.
    """
    if fn is None:
        # Called with options only, e.g. ``@dummy_jit(debug=True)``:
        # return a decorator that accepts the function next.
        return dummy_jit

    @functools.wraps(fn)
    def wrapper(*inner_args, **inner_kwargs):
        return fn(*inner_args, **inner_kwargs)

    return wrapper


class Patch:
    """One monkey-patch of the attribute ``orig_func_name`` (``"pkg.mod.attr"``).

    A patch has an optional replacement (``patch_func``) plus a list of
    wrappers. Callables whose ``__name__`` ends with ``wrapper`` or
    ``decorator`` are treated as wrappers and applied, in registration order,
    on top of the replacement (or on top of the original when no replacement
    was registered).
    """

    def __init__(self, orig_func_name, new_func, create_dummy):
        self.orig_module_name, self.orig_func_name = Patch._split_path(orig_func_name)
        self.orig_module = None
        self.orig_func = None
        self.patch_func = None
        self.wrappers = []
        # Number of leading entries of ``self.wrappers`` already folded into
        # ``self.patch_func``; a re-apply only adds the wrappers after these.
        self._applied_wrappers = 0
        if new_func is None:
            new_func = dummy_function_wrapper(orig_func_name)
        self.set_patch_func(new_func)
        self.is_applied = False
        self.create_dummy = create_dummy

    @staticmethod
    def _split_path(path):
        """Split ``"pkg.mod.attr"`` into ``("pkg.mod", "attr")``; a dot-free path gives ``(path, None)``."""
        module_name, sep, attr_name = path.rpartition(".")
        if not sep:
            return path, None
        return module_name, attr_name

    @property
    def orig_func_id(self):
        return id(self.orig_func)

    @property
    def patch_func_id(self):
        return id(self.patch_func)

    def set_patch_func(self, new_func, force_patch=False):
        """Register *new_func* as a wrapper or as the replacement function.

        Raises:
            RuntimeError: a replacement is already set and *force_patch* is False.
        """
        if hasattr(new_func, "__name__") and new_func.__name__.endswith(
            ("wrapper", "decorator")
        ):
            self.wrappers.append(new_func)
        else:
            if self.patch_func and not force_patch:
                raise RuntimeError(
                    f"The patch of '{self.orig_func_name}' ('{self.patch_func}') exist!"
                )
            self.patch_func = new_func
            # A new base function has none of the wrappers applied yet.
            self._applied_wrappers = 0
        self.is_applied = False

    def apply_patch(self):
        """Install the patch on the target and on modules that imported the original.

        No-op when already applied. After a new registration (which clears
        ``is_applied``) only wrappers not yet applied are added, so earlier
        wrappers are not stacked twice.
        """
        if self.is_applied:
            return
        self.orig_module, self.orig_func = Patch.parse_path(
            self.orig_module_name, self.orig_func_name, self.create_dummy
        )
        if self.patch_func is None:
            self.patch_func = self.orig_func
        for wrapper in self.wrappers[self._applied_wrappers:]:
            self.patch_func = wrapper(self.patch_func)
        self._applied_wrappers = len(self.wrappers)

        if self.orig_func_name is not None:
            setattr(self.orig_module, self.orig_func_name, self.patch_func)
            self._rebind_imported_references()
        self.is_applied = True

    def _rebind_imported_references(self):
        """Replace ``self.orig_func`` by ``self.patch_func`` in every loaded module that holds it.

        Modules that did ``from x import f`` keep their own reference, so
        patching ``x.f`` alone would not reach them.
        NOTE: this walks every loaded module (all installed libraries), not only
        vllm-related ones; setattr on unrelated libraries could cause problems,
        and restricting the scan to vllm modules may be worth considering.
        """
        missing = object()
        name = self.orig_func_name
        for module in list(sys.modules.values()):
            try:
                if getattr(module, name, missing) is self.orig_func:
                    setattr(module, name, self.patch_func)
            except Exception:
                # Best effort: some modules (lazy loaders, extension modules,
                # proxies) fail on attribute access; skip them.
                continue

    @staticmethod
    def parse_function(function_path: str, create_dummy):
        """Resolve a dotted path and return only the attribute it names."""
        module_name, attr_name = Patch._split_path(function_path)
        return Patch.parse_path(module_name, attr_name, create_dummy)[1]

    @staticmethod
    def parse_path(module_path, function_name, create_dummy):
        """Resolve *module_path* and return ``(container, container.function_name)``.

        Each dotted prefix of *module_path* is imported in turn. When a prefix
        cannot be imported:

        * if its parent has no attribute of that name: with *create_dummy* an
          empty placeholder module is put in ``sys.modules`` (and attached to
          the parent); otherwise ``ModuleNotFoundError`` is raised;
        * if the parent exposes it as an attribute (e.g. a class), the
          remaining components are resolved with ``getattr`` and
          *function_name* is looked up on the resulting object; if it is
          missing a dummy is returned (*create_dummy*) or ``RuntimeError``
          is raised.

        On an importable module, a missing *function_name* is created as
        ``None`` so it can be patched later. The second element is ``None``
        when *function_name* is ``None``.
        """
        from importlib.machinery import ModuleSpec

        modules = module_path.split(".")
        for i in range(1, len(modules) + 1):
            parent = ".".join(modules[: i - 1])
            path = ".".join(modules[:i])
            try:
                importlib.import_module(path)
            except ModuleNotFoundError as e:
                if not parent or not hasattr(
                    importlib.import_module(parent), modules[i - 1]
                ):
                    if not create_dummy:
                        raise ModuleNotFoundError(e) from e
                    sys.modules[path] = types.ModuleType(path)
                    sys.modules[path].__file__ = "patch_tools.dummy_module.py"
                    sys.modules[path].__spec__ = ModuleSpec(path, None)
                    if parent:
                        setattr(
                            importlib.import_module(parent),
                            modules[i - 1],
                            sys.modules[path],
                        )
                else:
                    module = getattr(importlib.import_module(parent), modules[i - 1])
                    # Follow any remaining components (e.g. nested classes)
                    # instead of stopping at the first non-module attribute.
                    try:
                        for attr in modules[i:]:
                            module = getattr(module, attr)
                    except AttributeError:
                        raise RuntimeError(
                            f"Attribute path '{module_path}' does not exist."
                        ) from e
                    if hasattr(module, function_name):
                        return module, getattr(module, function_name)
                    elif create_dummy:
                        return module, dummy_function_wrapper(function_name)
                    else:
                        raise RuntimeError(
                            f"Function '{function_name}' of '{module}' does not exist."
                        ) from e
        target_module = sys.modules[module_path]
        if function_name is None:
            return target_module, None
        if not hasattr(target_module, function_name):
            setattr(target_module, function_name, None)
        return target_module, getattr(target_module, function_name)


class PatchManager:
    """Process-wide registry of ``Patch`` objects keyed by their dotted target path."""

    # Class-level on purpose: every caller shares one registry.
    patches_info: dict = {}
    patched: bool = False

    @classmethod
    def get_patch_info(cls):
        return cls.patches_info

    @classmethod
    def register_patch(
        cls,
        orig_func_name,
        new_func=None,
        force_patch=False,
        create_dummy=False,
        allow_create=False,
    ):
        """Register *new_func* as a patch (or wrapper) for *orig_func_name*.

        A callable whose ``__name__`` ends with ``wrapper`` or ``decorator`` is
        treated as a wrapper around the target; such wrappers should use
        ``functools.wraps``. Conversely, a callable written with ``@wraps``
        must have a name ending in ``wrapper`` or ``decorator``.

        Raises:
            ValueError: *orig_func_name* does not resolve and *allow_create* is False.
        """
        if not cls._path_valid(orig_func_name) and not allow_create:
            raise ValueError(
                f"Module/function path '{orig_func_name}' does not exist, and allow_create=False."
            )
        patch_info = cls.get_patch_info()
        if orig_func_name not in patch_info:
            patch_info[orig_func_name] = Patch(orig_func_name, new_func, create_dummy)
        else:
            patch_info[orig_func_name].set_patch_func(new_func, force_patch)

    @classmethod
    def _is_module(cls, module_path):
        """Return True when *module_path* can be imported as a module."""
        try:
            importlib.import_module(module_path)
            return True
        except ModuleNotFoundError:
            return False

    @classmethod
    def recursive_register_module(cls, module_path, allow_create=False):
        """Install ``PATCH_ROOT + module_path`` and its submodules under the un-prefixed names.

        Each ``sys.modules[<name>]`` entry is replaced by the corresponding
        module from ``PATCH_ROOT``. Only module paths are accepted (not
        functions or classes). *allow_create* is accepted for signature
        compatibility and is not used.
        """
        assert cls._is_module(
            PATCH_ROOT + module_path
        ), f'"{PATCH_ROOT}{module_path}" is not a valid module path. Only use this function to register module patches (not functions or classes).'
        submodules = [
            name[len(PATCH_ROOT):]
            for name in cls.enumerate_submodules(PATCH_ROOT + module_path)
        ]
        for module in [module_path] + submodules:
            try:
                # Import the original first when it exists (presumably so its
                # parent packages are loaded before the entry is replaced).
                importlib.import_module(module)
            except ModuleNotFoundError:
                # No original module: the patched one is installed anyway.
                pass
            sys.modules[module] = importlib.import_module(PATCH_ROOT + module)

    @classmethod
    def batch_recursive_register_module(cls, module_paths, allow_create=False):
        """Call ``recursive_register_module`` for each path in *module_paths*."""
        for module_path in module_paths:
            cls.recursive_register_module(module_path, allow_create=allow_create)

    @classmethod
    def enumerate_submodules(cls, module_name):
        """Return the dotted names of all submodules below *module_name*.

        Returns ``[]`` (after printing the error) when the import fails, and
        ``[]`` for a plain, non-package module, which has no ``__path__``.
        """
        try:
            module = importlib.import_module(module_name)
        except ImportError as e:
            print(f"Error importing module {module_name}: {e}")
            return []
        search_path = getattr(module, "__path__", None)
        if search_path is None:
            return []
        return [
            submodule_name
            for _, submodule_name, _ in pkgutil.walk_packages(
                search_path, module.__name__ + "."
            )
        ]

    @classmethod
    def batch_register_patch(
        cls,
        orig_func_names: List,
        force_patch=False,
        create_dummy=False,
        allow_create=False,
    ):
        """Register patches whose replacements mirror the target paths under ``PATCH_ROOT``.

        For each ``pkg.mod.func`` the replacement is ``PATCH_ROOT + "pkg.mod.func_wrapper"``
        if that path exists, else ``PATCH_ROOT + "pkg.mod.func"``. If that path is a
        module rather than an attribute, the module itself is used. Invalid
        targets are skipped with a warning unless *allow_create* is set.
        """
        for orig_func_name in orig_func_names:
            if not cls._path_valid(orig_func_name) and not allow_create:
                print(f"WARNING: path '{orig_func_name}' is not valid, skipped.")
                continue
            wrapper_name = orig_func_name
            if cls._path_valid(PATCH_ROOT + wrapper_name + "_wrapper"):
                wrapper_name = wrapper_name + "_wrapper"
            assert cls._path_valid(
                PATCH_ROOT + wrapper_name
            ), f"'{PATCH_ROOT}{wrapper_name}' or '{PATCH_ROOT}{wrapper_name}_wrapper' must be a valid module/function path. Try import {PATCH_ROOT}{wrapper_name} to see if other errors exist."
            new_func = Patch.parse_function(PATCH_ROOT + wrapper_name, create_dummy=False)
            if new_func is None:
                new_func = importlib.import_module(PATCH_ROOT + wrapper_name)
            cls.register_patch(
                orig_func_name,
                new_func,
                force_patch=force_patch,
                create_dummy=create_dummy,
                allow_create=allow_create,
            )

    @classmethod
    def _path_valid(cls, path):
        """Return True if *path* is an importable module optionally followed by attributes."""
        components = path.split(".")
        for split in range(len(components), 0, -1):
            try:
                obj = importlib.import_module(".".join(components[:split]))
            except ImportError:
                continue
            break
        else:
            return False
        for name in components[split:]:
            if not hasattr(obj, name):
                return False
            obj = getattr(obj, name)
        return True

    @classmethod
    def apply_patches(cls):
        """Apply every registered patch and mark the manager as patched."""
        for patch in cls.get_patch_info().values():
            patch.apply_patch()
        cls.patched = True