# Source: xc-llm-kunlun/vllm_kunlun/utils.py
import functools
import inspect
import os
import sys
import types
import typing
from typing import (TYPE_CHECKING, Any, Callable, Generic, Literal, NamedTuple,
                    Optional, Tuple, TypeVar, Union, cast, overload,
                    get_origin, get_args, List)

import torch
from torch.library import Library
from torch.utils._python_dispatch import TorchDispatchMode

import vllm
from vllm.utils import weak_ref_tensor

import vllm_kunlun.platforms.envs as xenvs
def redirect_output():
    """Redirect this process's stdout/stderr into a per-rank log file.

    The file lives under ``xenvs.VLLM_MULTI_LOGPATH`` and is named
    ``pp=0_rank=<tp_rank>.log`` for the first pipeline-parallel rank and
    ``pp=1_rank=<tp_rank>.log`` for every other pipeline rank.
    NOTE(review): the ``pp=`` tag only distinguishes first-vs-rest pipeline
    ranks, not the actual pipeline stage index — confirm with callers.

    Args:
        None.

    Returns:
        None. The file descriptors behind ``sys.stdout`` and ``sys.stderr``
        are replaced in place via ``os.dup2``.
    """
    from vllm.distributed import get_tensor_model_parallel_rank, get_pp_group

    tp_rank = get_tensor_model_parallel_rank()
    log_dir = xenvs.VLLM_MULTI_LOGPATH
    os.makedirs(log_dir, exist_ok=True)

    pp_tag = 0 if get_pp_group().is_first_rank else 1
    log_file = os.path.join(log_dir, f"pp={pp_tag}_rank={tp_rank}.log")

    fd = os.open(log_file, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o644)
    os.dup2(fd, sys.stdout.fileno())
    os.dup2(fd, sys.stderr.fileno())
    os.close(fd)
def multi_log_monkey_patch(func):
    """Wrap ``func`` so every call also redirects this process's output.

    Intended for ``vllm.distributed.ensure_model_parallel_initialized``:
    once the parallel groups exist, :func:`redirect_output` can query the
    rank and send stdout/stderr to a per-rank log file.

    Args:
        func: The callable to wrap.

    Returns:
        A wrapper that calls ``func``, then ``redirect_output()``, and
        returns ``func``'s result (the original wrapper silently dropped
        the return value and the wrapped function's metadata).
    """

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        print("[monkey patch] ensure_model_parallel_initialized")
        result = func(*args, **kwargs)
        redirect_output()
        return result

    return wrapper
# Opt-in via xenvs: once model-parallel groups are initialized, route each
# rank's output to its own log file by patching the initializer.
if xenvs.ENABLE_VLLM_MULTI_LOG:
    print("ENABLE_VLLM_MULTI_LOG monkey--------")
    _original_ensure_mp_init = vllm.distributed.ensure_model_parallel_initialized
    vllm.distributed.ensure_model_parallel_initialized = (
        multi_log_monkey_patch(_original_ensure_mp_init))
class StageHookPre(object):
    """Forward pre-hook that logs whether a prefill or decode step begins."""

    def __call__(self, *args, **kwargs):
        """Print a stage marker before the hooked forward runs.

        Reads the current forward context; when attention metadata is
        available, ``num_decode_tokens == 0`` is treated as the first
        (prefill) token, anything else as a per-token decode step.

        Args:
            *args: Hook arguments (ignored).
            **kwargs: Hook keyword arguments (ignored).

        Returns:
            None.
        """
        from vllm.forward_context import get_forward_context

        metadata = get_forward_context().attn_metadata
        if metadata is None:
            return
        marker = ("First Token Start"
                  if metadata.num_decode_tokens == 0 else "Per Token Start")
        print(marker, flush=True)
class StageHookPost(object):
    """Forward hook that logs whether a prefill or decode step finished."""

    def __call__(self, *args, **kwargs):
        """Print a stage marker after the hooked forward runs.

        Mirrors ``StageHookPre``: with attention metadata present,
        ``num_decode_tokens == 0`` means the first (prefill) token just
        finished; anything else means a per-token decode step finished.

        Args:
            *args: Hook arguments (ignored).
            **kwargs: Hook keyword arguments (ignored).

        Returns:
            None.
        """
        from vllm.forward_context import get_forward_context

        metadata = get_forward_context().attn_metadata
        if metadata is None:
            return
        marker = ("First Token End"
                  if metadata.num_decode_tokens == 0 else "Per Token End")
        print(marker, flush=True)
class ModuleLoggingHookPre(object):
    """Forward pre-hook that logs module entry with nesting-aware indentation.

    ``indent_list`` and ``name_list`` act as stacks shared with
    ``ModuleLoggingHookPost``, which pops them again on module exit.
    """

    def __init__(self):
        """Initialize the indent and name stacks.

        The indent stack starts with an empty-string sentinel so that
        ``indent_list[-1]`` is always valid.
        """
        self.indent_list = [""]
        self.name_list = []

    def __call__(self, *args, **kwargs):
        """Push one indent level and the module's qualified name, then log.

        Args:
            *args: Hook arguments; ``args[0]`` is the module instance being
                entered.
            **kwargs: Hook keyword arguments (ignored).

        Returns:
            None.
        """
        module = args[0]
        self.indent_list.append(self.indent_list[-1] + "\t")
        # Fix: join module path and class name with "." — the original
        # concatenation produced unreadable identifiers such as
        # "torch.nn.modules.linearLinear".
        qualified_name = (module.__class__.__module__ + "." +
                          module.__class__.__name__)
        self.name_list.append(qualified_name)
        print(self.indent_list[-1] + self.name_list[-1] + " Start", flush=True)
class ModuleLoggingHookPost(object):
    """Forward hook that logs module exit and unwinds the shared stacks."""

    def __init__(self, indent_list, name_list):
        """Bind the stacks shared with ``ModuleLoggingHookPre``.

        Args:
            indent_list (List[str]): Indent stack; must stay in lock-step
                with ``name_list`` or the pops below desynchronize.
            name_list (List[str]): Qualified-module-name stack.

        Returns:
            None. Only binds the given lists onto the instance.
        """
        self.indent_list = indent_list
        self.name_list = name_list

    def __call__(self, *args, **kwargs):
        """Log the exiting module, then pop one level off both stacks.

        Args:
            *args: Hook arguments (ignored).
            **kwargs: Hook keyword arguments (ignored).

        Returns:
            None.
        """
        message = self.indent_list[-1] + self.name_list[-1] + " Module End"
        print(message, flush=True)
        self.indent_list.pop()
        self.name_list.pop()
# Opt-in module-level forward hooks: log every nn.Module's entry/exit with
# nesting indentation. Pre and post hooks share the same two stacks.
if xenvs.ENABLE_VLLM_MODULE_HOOK:
    from torch.nn.modules.module import (register_module_forward_hook,
                                         register_module_forward_pre_hook)

    module_logging_hook_pre = ModuleLoggingHookPre()
    module_logging_hook_post = ModuleLoggingHookPost(
        module_logging_hook_pre.indent_list,
        module_logging_hook_pre.name_list,
    )
    register_module_forward_pre_hook(module_logging_hook_pre)
    register_module_forward_hook(module_logging_hook_post)
else:
    # Keep the names defined either way so LoggingDispatchMode can probe them.
    module_logging_hook_pre = None
    module_logging_hook_post = None
class LoggingDispatchMode(TorchDispatchMode):
    """TorchDispatchMode that prints every dispatched op before/after it runs.

    When the module-logging hooks are active, their current indent level is
    reused so dispatch logs line up with the module entry/exit logs.
    """

    def __init__(self):
        """Initialize the dispatch mode; no extra state beyond the base class."""
        super().__init__()

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        """Run ``func`` while logging "calling"/"called" markers around it.

        Args:
            func: The operator being dispatched.
            types: Tensor subclass types involved in the call (unused here).
            args: Positional arguments forwarded to ``func``.
            kwargs: Keyword arguments forwarded to ``func``; ``None`` means
                none were supplied.

        Returns:
            Whatever ``func`` returns.
        """
        global module_logging_hook_pre
        # Borrow the module hooks' indentation when they are installed.
        indent = ("\t" if module_logging_hook_pre is None
                  else module_logging_hook_pre.indent_list[-1])
        print(indent + "{} calling".format(func), flush=True)
        out = func(*args, **(kwargs or {}))
        print(indent + "{} called".format(func), flush=True)
        return out
class CUDAGraphInnerWatcher(TorchDispatchMode):
    """Dispatch mode that snapshots named op outputs that live on CUDA.

    While active, any dispatched op whose ``_name`` appears in ``name_list``
    and whose result is a CUDA tensor gets a weak reference to that result
    recorded, so intermediates can be inspected without extending their
    lifetime (important around CUDA-graph capture).
    """

    def __init__(self, name_list):
        """Store the op names to watch and prepare the trace dict.

        Args:
            name_list (List[str]): Op names (matched against ``func._name``)
                whose CUDA outputs should be traced.

        Returns:
            None.
        """
        # Fix: TorchDispatchMode.__init__ sets up internal mode bookkeeping
        # that __enter__/__exit__ rely on in recent torch versions; the
        # original skipped it, which can break entering the context manager.
        super().__init__()
        self.name_list = name_list
        self.traced_tensor = dict()

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        """Dispatch ``func`` and record watched CUDA tensor results.

        Args:
            func: The operator being dispatched.
            types: Tensor subclass types involved in the call (unused here).
            args: Positional arguments forwarded to ``func``.
            kwargs: Keyword arguments forwarded to ``func``; ``None`` means
                none were supplied.

        Returns:
            Whatever ``func`` returns.
        """
        result = func(*args, **(kwargs or {}))
        if isinstance(result, torch.Tensor) and result.is_cuda:
            if func._name in self.name_list:
                # Weak reference: observe the tensor without keeping it alive.
                self.traced_tensor[func._name] = weak_ref_tensor(result)
        return result

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Dump and clear the traced tensors, then pop the dispatch mode.

        Args:
            exc_type (Optional[Type[BaseException]]): Exception type, if any.
            exc_val (Optional[BaseException]): Exception value, if any.
            exc_tb (Optional[TracebackType]): Traceback, if any.

        Returns:
            None.
        """
        for name, value in self.traced_tensor.items():
            print(name, value)
        self.traced_tensor.clear()
        self.name_list.clear()
        super().__exit__(exc_type, exc_val, exc_tb)
def patch_annotations_for_schema(func):
    """Rewrite ``func``'s signature so torch schema inference accepts it.

    ``torch.library.infer_schema`` understands ``typing.List[int]`` but not
    the PEP 585 builtin form ``list[int]``, so this replaces ``list[T]``
    with ``typing.List[T]`` and ``Optional[list[T]]`` with
    ``Optional[typing.List[T]]`` in the signature. PEP 604 spellings
    (``list[T] | None``) are now handled too — the original only matched
    ``typing.Union``-based optionals.

    Args:
        func: The op implementation whose ``__signature__`` is patched in
            place (the same function object is also returned).

    Returns:
        ``func`` with ``__signature__`` rewritten.
    """
    union_origins = tuple(
        t for t in (typing.Union, getattr(types, "UnionType", None))
        if t is not None)
    sig = inspect.signature(func)
    new_params = []
    for name, param in sig.parameters.items():
        ann = param.annotation
        origin = get_origin(ann)

        # Optional[list[T]] — either typing.Optional or `list[T] | None`.
        if origin in union_origins and type(None) in get_args(ann):
            inner_type = [a for a in get_args(ann) if a is not type(None)][0]
            if get_origin(inner_type) is list:
                inner_args = get_args(inner_type)
                new_ann = Optional[List[inner_args[0] if inner_args
                                        else typing.Any]]
                param = param.replace(annotation=new_ann)

        # Bare list[T].
        elif origin is list:
            args = get_args(ann)
            new_ann = List[args[0] if args else typing.Any]
            param = param.replace(annotation=new_ann)

        new_params.append(param)

    func.__signature__ = sig.replace(parameters=new_params)
    return func
def supports_custom_op() -> bool:
    """Return True when this torch build exposes ``torch.library.custom_op``
    (present from torch >= 2.4)."""
    library_module = torch.library
    return hasattr(library_module, "custom_op")
# Module-level library handle; ops registered into it stay alive only as
# long as this object does.
vllm_lib = Library("vllm", "FRAGMENT") # noqa
def direct_register_custom_op(
    op_name: str,
    op_func: Callable,
    mutates_args: list[str],
    fake_impl: Optional[Callable] = None,
    target_lib: Optional[Library] = None,
    dispatch_key: str = "CUDA",
    tags: tuple[torch.Tag, ...] = (),
):
    """Register ``op_func`` as a custom op directly on one dispatch key.

    `torch.library.custom_op` can have significant overhead because it
    needs to consider complicated dispatching logic. This function
    directly registers a custom op and dispatches it to the CUDA backend.
    See https://gist.github.com/youkaichao/ecbea9ec9fc79a45d2adce1784d7a9a5
    for more details.

    By default, the custom op is registered to the vLLM library. If you
    want to register it to a different library, you can pass the library
    object to the `target_lib` argument.

    IMPORTANT: the lifetime of the operator is tied to the lifetime of the
    library object. If you want to bind the operator to a different library,
    make sure the library object is alive when the operator is used.

    Args:
        op_name: Name of the op inside the target library.
        op_func: The real implementation; its annotations drive schema
            inference.
        mutates_args: Names of arguments the op mutates in place.
        fake_impl: Optional fake/meta implementation for tracing.
        target_lib: Library to register into; defaults to ``vllm_lib``.
        dispatch_key: Backend dispatch key (default ``"CUDA"``).
        tags: Extra ``torch.Tag``s attached to the op definition.

    Returns:
        None.
    """
    if not supports_custom_op():
        from vllm.platforms import current_platform
        assert not current_platform.is_cuda_alike(), (
            "cuda platform needs torch>=2.4 to support custom op, "
            "chances are you are using an old version of pytorch "
            "or a custom build of pytorch. It is recommended to "
            "use vLLM in a fresh new environment and let it install "
            "the required dependencies.")
        return

    import torch.library
    if hasattr(torch.library, "infer_schema"):
        # Fix: feed the annotation-patched function to infer_schema. The
        # original assigned it to an unused variable and only worked because
        # the patch mutates op_func.__signature__ in place.
        patched_func = patch_annotations_for_schema(op_func)
        schema_str = torch.library.infer_schema(patched_func,
                                                mutates_args=mutates_args)
    else:
        # for pytorch 2.4
        import torch._custom_op.impl
        schema_str = torch._custom_op.impl.infer_schema(op_func, mutates_args)

    my_lib = target_lib or vllm_lib
    my_lib.define(op_name + schema_str, tags=tags)
    my_lib.impl(op_name, op_func, dispatch_key=dispatch_key)
    if fake_impl is not None:
        my_lib._register_fake(op_name, fake_impl)