Files
2025-12-10 17:51:24 +08:00

354 lines
14 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import os, sys
import vllm
from torch.utils._python_dispatch import TorchDispatchMode
import vllm_kunlun.platforms.envs as xenvs
from vllm.utils import weak_ref_tensor
from typing import (TYPE_CHECKING, Any, Callable, Generic, Literal, NamedTuple,
Optional, Tuple, TypeVar, Union, cast, overload,
get_origin, get_args, List)
import torch
from torch.library import Library
import inspect
import typing
def redirect_output():
    """Redirect this process's stdout/stderr to a per-rank log file.

    The log directory comes from ``xenvs.VLLM_MULTI_LOGPATH`` and is created
    if missing.  The file name encodes the pipeline stage and the
    tensor-parallel rank: ``pp=0_rank=<r>.log`` when this process is the
    first pipeline-parallel rank, ``pp=1_rank=<r>.log`` otherwise.
    Redirection happens at the file-descriptor level (``os.dup2``) so output
    from C extensions is captured as well.

    Returns:
        None.  Mutates the fds backing ``sys.stdout`` / ``sys.stderr``.
    """
    from vllm.distributed import get_tensor_model_parallel_rank, get_pp_group

    tp_rank = get_tensor_model_parallel_rank()
    log_dir = xenvs.VLLM_MULTI_LOGPATH
    os.makedirs(log_dir, exist_ok=True)

    if get_pp_group().is_first_rank:
        log_path = os.path.join(log_dir, f"pp=0_rank={tp_rank}.log")
    else:
        log_path = os.path.join(log_dir, f"pp=1_rank={tp_rank}.log")

    fd = os.open(log_path, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o644)
    os.dup2(fd, sys.stdout.fileno())
    os.dup2(fd, sys.stderr.fileno())
    os.close(fd)
def multi_log_monkey_patch(func):
    """Decorator for ``ensure_model_parallel_initialized`` that redirects
    this process's output to a per-rank log file right after the parallel
    groups are initialized (see ``redirect_output``).

    Args:
        func (Callable): The original ``ensure_model_parallel_initialized``.

    Returns:
        Callable: A wrapper that forwards all arguments to ``func``, then
        redirects stdout/stderr, and returns whatever ``func`` returned.
    """
    import functools

    @functools.wraps(func)  # preserve name/docstring of the patched function
    def wrapper(*args, **kwargs):
        print("[monkey patch] ensure_model_parallel_initialized")
        # Fix: propagate the wrapped function's return value (the original
        # wrapper silently dropped it).
        result = func(*args, **kwargs)
        redirect_output()
        return result

    return wrapper
# Legacy form of this switch: os.environ.get("VLLM_MULTI_LOG", "0") == "1"
if xenvs.ENABLE_VLLM_MULTI_LOG:
    # Wrap ensure_model_parallel_initialized so that every worker redirects
    # its stdout/stderr to a per-rank log file once parallel groups exist.
    print("ENABLE_VLLM_MULTI_LOG monkey--------")
    vllm.distributed.ensure_model_parallel_initialized = multi_log_monkey_patch(
        vllm.distributed.ensure_model_parallel_initialized)
class StageHookPre(object):
    """Forward pre-hook that logs the start of a model forward pass.

    Prints ``First Token Start`` when the current batch has no decode tokens
    (i.e. a prefill step) and ``Per Token Start`` otherwise.  Does nothing
    when no attention metadata is present in the forward context.
    """

    def __call__(self, *args, **kwargs):
        """Log the step-start marker; ``args``/``kwargs`` are ignored."""
        from vllm.forward_context import get_forward_context

        meta = get_forward_context().attn_metadata
        if meta is None:
            return
        if meta.num_decode_tokens == 0:
            print("First Token Start", flush=True)
        else:
            print("Per Token Start", flush=True)
class StageHookPost(object):
    """Forward post-hook that logs the end of a model forward pass.

    Prints ``First Token End`` when the current batch has no decode tokens
    (i.e. a prefill step) and ``Per Token End`` otherwise.  Does nothing
    when no attention metadata is present in the forward context.
    """

    def __call__(self, *args, **kwargs):
        """Log the step-end marker; ``args``/``kwargs`` are ignored."""
        from vllm.forward_context import get_forward_context

        meta = get_forward_context().attn_metadata
        if meta is None:
            return
        if meta.num_decode_tokens == 0:
            print("First Token End", flush=True)
        else:
            print("Per Token End", flush=True)
class ModuleLoggingHookPre(object):
    """Forward pre-hook that prints an indented ``<module> Start`` line.

    Nesting depth is tracked with two stacks that are shared (aliased) with
    ``ModuleLoggingHookPost``:

    - ``indent_list``: current indentation strings, seeded with ``""`` so the
      outermost module prints one tab deep.
    - ``name_list``: qualified names of the modules currently executing.
    """

    def __init__(self):
        self.indent_list = [""]
        self.name_list = []

    def __call__(self, *args, **kwargs):
        """Push one nesting level and print the Start marker.

        Args:
            args: ``args[0]`` is the module instance being entered.
            kwargs: Unused.
        """
        module = args[0]
        self.indent_list.append(self.indent_list[-1] + "\t")
        # NOTE: module path and class name are concatenated with no
        # separator, matching the original log format exactly.
        qualified = module.__class__.__module__ + module.__class__.__name__
        self.name_list.append(qualified)
        print(self.indent_list[-1] + qualified + " Start", flush=True)
class ModuleLoggingHookPost(object):
    """Forward post-hook paired with ``ModuleLoggingHookPre``.

    Prints ``<module> Module End`` at the current indentation, then pops one
    level from the shared indent/name stacks.
    """

    def __init__(self, indent_list, name_list):
        """Store the two stacks shared with the pre-hook.

        Args:
            indent_list (List[str]): Indentation stack (aliased with the
                pre-hook's; both hooks mutate the same list).
            name_list (List[str]): Stack of currently-executing module names.
        """
        self.indent_list = indent_list
        self.name_list = name_list

    def __call__(self, *args, **kwargs):
        """Print the End marker and unwind one nesting level."""
        marker = self.indent_list[-1] + self.name_list[-1] + " Module End"
        print(marker, flush=True)
        self.indent_list.pop()
        self.name_list.pop()
# Legacy form of this switch: os.environ.get("ENABLE_VLLM_MODULE_HOOK", "0") == "1"
if xenvs.ENABLE_VLLM_MODULE_HOOK:
    from torch.nn.modules.module import register_module_forward_pre_hook, register_module_forward_hook
    # Install a global pre/post forward hook pair; they share the same two
    # stacks so the Start/End log lines nest correctly.
    module_logging_hook_pre = ModuleLoggingHookPre()
    module_logging_hook_post = ModuleLoggingHookPost(
        module_logging_hook_pre.indent_list, module_logging_hook_pre.name_list)
    register_module_forward_pre_hook(module_logging_hook_pre)
    register_module_forward_hook(module_logging_hook_post)
else:
    # Sentinels checked by LoggingDispatchMode to pick its indentation.
    module_logging_hook_pre = None
    module_logging_hook_post = None
class LoggingDispatchMode(TorchDispatchMode):
    """TorchDispatchMode that logs every dispatched ATen op.

    When the module logging hooks are installed, op log lines are indented
    to match the currently executing module's nesting level; otherwise a
    single tab is used.
    """

    def __init__(self):
        super().__init__()

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        """Print ``<op> calling`` / ``<op> called`` around the dispatched op.

        Args:
            func: The ATen operator being dispatched.
            types: Tensor subclass types involved in this call.
            args: Positional arguments forwarded to ``func``.
            kwargs: Keyword arguments forwarded to ``func`` (may be None).

        Returns:
            Whatever ``func`` returns.
        """
        global module_logging_hook_pre
        if module_logging_hook_pre is None:
            indent = "\t"
        else:
            indent = module_logging_hook_pre.indent_list[-1]
        print(f"{indent}{func} calling", flush=True)
        out = func(*args, **(kwargs or {}))
        print(f"{indent}{func} called", flush=True)
        return out
class CUDAGraphInnerWatcher(TorchDispatchMode):
    """TorchDispatchMode that records named CUDA tensors produced inside a
    traced region (e.g. during CUDA graph capture).

    Whenever a dispatched op whose qualified name (``func._name``) is listed
    in ``name_list`` returns a CUDA tensor, a weak reference to that tensor
    is stored under the op name, so it can be inspected later without
    extending the tensor's lifetime.  On context exit all traced entries are
    printed and both containers are cleared.
    """

    def __init__(self, name_list):
        """
        Args:
            name_list (List[str]): Qualified ATen op names to trace.
        """
        # Fix: the original never called TorchDispatchMode.__init__ even
        # though __exit__ relies on the base class's context-manager
        # protocol (mode push/pop).
        super().__init__()
        self.name_list = name_list
        # Maps op name -> weak reference to the CUDA tensor it produced.
        self.traced_tensor = dict()

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        """Run ``func`` and record its result if it is a watched CUDA tensor.

        Args:
            func: The ATen operator being dispatched.
            types: Tensor subclass types involved in this call.
            args: Positional arguments forwarded to ``func``.
            kwargs: Keyword arguments forwarded to ``func`` (may be None).

        Returns:
            The result of the op, unchanged.
        """
        result = func(*args, **(kwargs or {}))
        if isinstance(result, torch.Tensor) and result.is_cuda:
            if func._name in self.name_list:
                # weak_ref_tensor avoids keeping the captured tensor alive
                # beyond its normal lifetime.
                self.traced_tensor[func._name] = weak_ref_tensor(result)
        return result

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Print traced tensors, clear all state, and pop the dispatch mode.

        Args:
            exc_type: Exception type, or None.
            exc_val: Exception value, or None.
            exc_tb: Traceback, or None.
        """
        for name, value in self.traced_tensor.items():
            print(name, value)
        self.traced_tensor.clear()
        self.name_list.clear()
        super().__exit__(exc_type, exc_val, exc_tb)
# def patch_annotations_for_schema(func):
# sig = inspect.signature(func)
# new_params = []
# for name, param in sig.parameters.items():
# anno = param.annotation
# if anno == list[int]:
# anno = typing.List[int]
# new_params.append(param.replace(annotation=anno))
# new_sig = sig.replace(parameters=new_params)
# func.__signature__ = new_sig
# return func
def patch_annotations_for_schema(func):
    """Rewrite PEP 585 annotations so ``torch.library.infer_schema`` accepts
    them.

    Some torch versions understand ``typing.List[int]`` but not the
    builtin-generic spelling ``list[int]``.  This helper replaces
    ``func.__signature__`` in place, converting:

      * ``list[T]``                 -> ``typing.List[T]``
      * ``Optional[list[T]]``       -> ``Optional[typing.List[T]]``
      * ``list[T] | None`` (3.10+)  -> ``Optional[typing.List[T]]``
        (fix: the original only recognized ``typing.Optional`` /
        ``typing.Union`` spellings, not PEP 604 unions)

    All other annotations are left untouched.

    Args:
        func (Callable): The callable whose signature is patched in place.

    Returns:
        Callable: The same ``func`` object, with ``__signature__`` replaced.
    """
    import types as _types

    def _as_typing_list(ann):
        """Return typing.List[...] for a list[...] annotation, else None."""
        if get_origin(ann) is list:
            inner = get_args(ann)
            return List[inner[0] if inner else typing.Any]
        return None

    # PEP 604 unions (``X | None``) have origin types.UnionType, not
    # typing.Union; treat both uniformly.  getattr handles Python < 3.10.
    union_origins = {typing.Union, getattr(_types, "UnionType", typing.Union)}

    sig = inspect.signature(func)
    new_params = []
    for param in sig.parameters.values():
        ann = param.annotation
        if get_origin(ann) in union_origins and type(None) in get_args(ann):
            inner_type = [a for a in get_args(ann) if a is not type(None)][0]
            rewritten = _as_typing_list(inner_type)
            if rewritten is not None:  # Optional[list[T]] or list[T] | None
                param = param.replace(annotation=Optional[rewritten])
        else:
            rewritten = _as_typing_list(ann)
            if rewritten is not None:  # bare list[T]
                param = param.replace(annotation=rewritten)
        new_params.append(param)
    func.__signature__ = sig.replace(parameters=new_params)
    return func
def supports_custom_op() -> bool:
    """Return True when this torch build exposes ``torch.library.custom_op``,
    i.e. custom-op registration via the new API is available."""
    return getattr(torch.library, "custom_op", None) is not None
# Library handle for ops registered into the "vllm" namespace.  "FRAGMENT"
# appends definitions to the namespace rather than (re)defining it.  Ops
# registered here live only as long as this object does (see the note in
# direct_register_custom_op's docstring).
vllm_lib = Library("vllm", "FRAGMENT") # noqa
def direct_register_custom_op(
    op_name: str,
    op_func: Callable,
    mutates_args: list[str],
    fake_impl: Optional[Callable] = None,
    target_lib: Optional[Library] = None,
    dispatch_key: str = "CUDA",
    tags: tuple[torch.Tag, ...] = (),
):
    """
    `torch.library.custom_op` can have significant overhead because it
    needs to consider complicated dispatching logic. This function
    directly registers a custom op and dispatches it to the CUDA backend.
    See https://gist.github.com/youkaichao/ecbea9ec9fc79a45d2adce1784d7a9a5
    for more details.

    By default, the custom op is registered to the vLLM library. If you
    want to register it to a different library, you can pass the library
    object to the `target_lib` argument.

    IMPORTANT: the lifetime of the operator is tied to the lifetime of the
    library object. If you want to bind the operator to a different library,
    make sure the library object is alive when the operator is used.

    Args:
        op_name: Name of the op inside the target library's namespace.
        op_func: The real (device) implementation; its annotations are used
            to infer the op schema.
        mutates_args: Names of parameters that ``op_func`` mutates in place.
        fake_impl: Optional meta/fake implementation for tracing.
        target_lib: Library to register into; defaults to ``vllm_lib``.
        dispatch_key: Backend dispatch key (default "CUDA").
        tags: Optional ``torch.Tag`` tuple attached to the op definition.
    """
    if not supports_custom_op():
        from vllm.platforms import current_platform
        assert not current_platform.is_cuda_alike(), (
            "cuda platform needs torch>=2.4 to support custom op, "
            "chances are you are using an old version of pytorch "
            "or a custom build of pytorch. It is recommended to "
            "use vLLM in a fresh new environment and let it install "
            "the required dependencies.")
        return

    import torch.library
    if hasattr(torch.library, "infer_schema"):
        # Fix: pass the patched function to infer_schema.  The original
        # computed `patched_func` but then passed the unpatched name; that
        # only worked because patch_annotations_for_schema mutates
        # `op_func.__signature__` in place.
        patched_func = patch_annotations_for_schema(op_func)
        schema_str = torch.library.infer_schema(patched_func,
                                                mutates_args=mutates_args)
    else:
        # Fallback for pytorch 2.4's private schema-inference helper.
        import torch._custom_op.impl
        schema_str = torch._custom_op.impl.infer_schema(op_func, mutates_args)

    my_lib = target_lib or vllm_lib
    my_lib.define(op_name + schema_str, tags=tags)
    my_lib.impl(op_name, op_func, dispatch_key=dispatch_key)
    if fake_impl is not None:
        my_lib._register_fake(op_name, fake_impl)