提交vllm0.11.0开发分支

This commit is contained in:
chenyili
2025-12-10 17:51:24 +08:00
parent deab7dd0b6
commit 7c22d621fb
175 changed files with 31856 additions and 8683 deletions

View File

@@ -1,61 +1,28 @@
#
# Copyright (c) 2025 Baidu, Inc. All Rights Reserved.
# This file is a part of the vllm-kunlun project.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os, sys
import vllm
from torch.utils._python_dispatch import TorchDispatchMode
import vllm_kunlun.platforms.envs as xenvs
import vllm_kunlun.platforms.envs as xenvs
from vllm.utils import weak_ref_tensor
from typing import (
TYPE_CHECKING,
Any,
Callable,
Generic,
Literal,
NamedTuple,
Optional,
Tuple,
TypeVar,
Union,
cast,
overload,
get_origin,
get_args,
List,
)
from typing import (TYPE_CHECKING, Any, Callable, Generic, Literal, NamedTuple,
Optional, Tuple, TypeVar, Union, cast, overload,
get_origin, get_args, List)
import torch
from torch.library import Library
import inspect
import typing
def redirect_output():
"""
Redirect output to a specified directory and name the log files as pp=0_rank=X or pp=1_rank=X.
If it is the first process of the first process group, use pp=0; otherwise, use pp=1.
重定向输出到指定目录,并将日志文件命名为pp=0_rank=Xpp=1_rank=X
如果是第一个进程组的第一个进程则使用pp=0否则使用pp=1
Args:
No parameters.
无参数。
Returns:
No return value, directly modify the file descriptors of sys.stdout and sys.stderr.
无返回值,直接修改sys.stdoutsys.stderr的文件描述符。
"""
from vllm.distributed import get_tensor_model_parallel_rank, get_pp_group
rank = get_tensor_model_parallel_rank()
dir_path = xenvs.VLLM_MULTI_LOGPATH
os.makedirs(dir_path, exist_ok=True)
@@ -63,54 +30,48 @@ def redirect_output():
log_file = os.path.join(dir_path, f"pp=0_rank={rank}.log")
else:
log_file = os.path.join(dir_path, f"pp=1_rank={rank}.log")
fd = os.open(log_file, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o644)
fd = os.open(log_file, os.O_WRONLY | os.O_CREAT| os.O_TRUNC, 0o644)
os.dup2(fd, sys.stdout.fileno())
os.dup2(fd, sys.stderr.fileno())
os.close(fd)
def multi_log_monkey_patch(func):
    """Wrap ``func`` so that each call also redirects this process's output.

    Intended for patching ``vllm.distributed.ensure_model_parallel_initialized``:
    after the original function runs, ``redirect_output()`` repoints
    stdout/stderr at per-rank log files.

    Args:
        func (Callable): The original function to wrap.

    Returns:
        Callable: A wrapper that logs, calls ``func``, redirects output,
        and propagates ``func``'s return value.
    """
    def wrapper(*args, **kwargs):
        print("[monkey patch] ensure_model_parallel_initialized")
        # Bug fix: the original wrapper discarded func's return value and
        # always returned None; propagate it so wrapping stays transparent.
        result = func(*args, **kwargs)
        redirect_output()
        return result
    return wrapper
# Optionally install the multi-log monkey patch (controlled via
# xenvs.ENABLE_VLLM_MULTI_LOG rather than reading os.environ directly).
if xenvs.ENABLE_VLLM_MULTI_LOG:
    print("ENABLE_VLLM_MULTI_LOG monkey--------")
    vllm.distributed.ensure_model_parallel_initialized = multi_log_monkey_patch(
        vllm.distributed.ensure_model_parallel_initialized
    )
class StageHookPre(object):
def __call__(self, *args, **kwargs):
"""
This method will be automatically executed when the object is called.
If the current attention metadata is not None and a token has been processed, print "Per Token Start"; otherwise, print "First Token Start".
在调用对象时,会自动执行此方法。
如果当前的attention metadata不为None并且已经处理了一个token则打印"Per Token Start";否则打印"First Token Start"
Args:
args (tuple, optional): Variable length argument list, default is an empty tuple.
kwargs (dict, optional): Keyword arguments, default is an empty dictionary.
args (tuple, optional): 可变参数,默认为空元组。
kwargs (dict, optional): 关键字参数,默认为空字典。
Returns:
None: No return value.
None: 无返回值。
"""
from vllm.forward_context import get_forward_context
attn_metadata = get_forward_context().attn_metadata
if attn_metadata is not None:
if attn_metadata.num_decode_tokens == 0:
@@ -118,22 +79,20 @@ class StageHookPre(object):
else:
print("Per Token Start", flush=True)
class StageHookPost(object):
def __call__(self, *args, **kwargs):
"""
If the current context's attention metadata is not None and num_decode_tokens equals 0, print "First Token End".
Otherwise, print "Per Token End".
如果当前上下文中的attention metadata不为None并且num_decode_tokens等于0则打印"First Token End"
否则,打印"Per Token End"
Args:
args (Tuple[Any]): Variable length argument list, unused parameters are passed in.
kwargs (Dict[str, Any]): Keyword arguments, unused parameters are passed in.
args (Tuple[Any]): 可变长度参数列表,无用参数传入。
kwargs (Dict[str, Any]): 字典类型的关键字参数,无用参数传入。
Returns:
None: No return value.
None: 该函数没有返回值。
"""
from vllm.forward_context import get_forward_context
attn_metadata = get_forward_context().attn_metadata
if attn_metadata is not None:
if attn_metadata.num_decode_tokens == 0:
@@ -145,24 +104,22 @@ class StageHookPost(object):
class ModuleLoggingHookPre(object):
def __init__(self):
"""
Initialization function to initialize the indentation list and name list.
The indentation list is used to store the indentation information of each line,
and the name list is used to store the name of each variable or function.
初始化函数,用于初始化缩进列表和名称列表。
缩进列表用于存储每一行的缩进信息,名称列表用于存储每一个变量或函数的名称。
"""
self.indent_list = list()
self.indent_list.append("")
self.name_list = list()
def __call__(self, *args, **kwargs):
"""
This method overrides the __call__ method and is used when the class is instantiated.
It increases the current indentation by one Tab and records the current class name.
It prints the start information, flush=True means it will be output to the console immediately.
重写了 __call__ 方法,用于在类实例化时调用。
将当前缩进增加一个 Tab并记录当前类名称。
打印开始信息flush=True 表示立即输出到控制台。
Args:
args (tuple): Variable length argument list, default is an empty tuple.
kwargs (dict): Keyword arguments, default is an empty dictionary.
args (tuple): 传入的参数列表,第一个元素是类实例。
kwargs (dict): 传入的关键字参数列表,不使用。
Returns:
None.
"""
@@ -174,70 +131,62 @@ class ModuleLoggingHookPre(object):
class ModuleLoggingHookPost(object):
    """Forward post-hook that closes the logging scope opened by the
    matching pre-hook: prints a "Module End" line at the current indent
    and pops one level off the shared indent/name stacks.
    """

    def __init__(self, indent_list, name_list):
        """Bind the bookkeeping stacks shared with the pre-hook.

        Args:
            indent_list (List[str]): Stack of indent strings; the last
                entry is the indent for the module currently executing.
            name_list (List[str]): Stack of module names, kept in lockstep
                with ``indent_list`` (same length at all times).
        """
        self.indent_list = indent_list
        self.name_list = name_list

    def __call__(self, *args, **kwargs):
        """Emit the end-of-module line and unwind one nesting level.

        The positional/keyword arguments supplied by torch's hook
        machinery are accepted but unused.
        """
        message = self.indent_list[-1] + self.name_list[-1] + " Module End"
        print(message, flush=True)
        self.indent_list.pop()
        self.name_list.pop()
# Optionally install global forward pre/post hooks that trace nested module
# execution (controlled via xenvs.ENABLE_VLLM_MODULE_HOOK rather than
# reading os.environ directly). The two hooks share the indent/name stacks.
if xenvs.ENABLE_VLLM_MODULE_HOOK:
    from torch.nn.modules.module import (
        register_module_forward_pre_hook,
        register_module_forward_hook,
    )
    module_logging_hook_pre = ModuleLoggingHookPre()
    module_logging_hook_post = ModuleLoggingHookPost(
        module_logging_hook_pre.indent_list, module_logging_hook_pre.name_list
    )
    register_module_forward_pre_hook(module_logging_hook_pre)
    register_module_forward_hook(module_logging_hook_post)
else:
    module_logging_hook_pre = None
    module_logging_hook_post = None
class LoggingDispatchMode(TorchDispatchMode):
def __init__(self):
    """Initialize the logging dispatch mode.

    No state of its own; all setup is delegated to TorchDispatchMode.
    """
    super().__init__()
def __torch_dispatch__(self, func, types, args=(), kwargs=None):
"""
Override the default dispatch behavior of torch.nn.Module.
This function will be called before and after each method call on this module.
It can be used to log information about the method calls.
Args:
func (function): The function that is being called on this module.
types (Tuple[str]): A tuple of strings representing the type signatures of the arguments.
See torch.types for more details.
args (Tuple[Any], optional): The positional arguments passed to the function. Defaults to ().
kwargs (Dict[str, Any], optional): The keyword arguments passed to the function. Defaults to {}.
Returns:
Any: The result returned by the function.
"""
@@ -249,20 +198,19 @@ class LoggingDispatchMode(TorchDispatchMode):
print(indent + "{} calling".format(func), flush=True)
result = func(*args, **(kwargs or {}))
print(indent + "{} called".format(func), flush=True)
return result
return result
class CUDAGraphInnerWatcher(TorchDispatchMode):
def __init__(self, name_list):
"""
Initialization function to save the name list to the class attribute.
It also creates a dictionary to keep track of the tensors that have been traced.
初始化函数,将传入的名称列表保存到类属性中。
同时创建一个字典来记录已经追踪过的张量。
Args:
name_list (List[str]): A list of names of tensors to be tracked.
name_list (List[str]): 包含需要追踪的张量名称的列表。
Returns:
None.
"""
@@ -274,13 +222,13 @@ class CUDAGraphInnerWatcher(TorchDispatchMode):
Override the default dispatch behavior of PyTorch tensors to track
the tracing process. If the result of a function call is a tensor on CUDA,
it will be added to the traced_tensor dictionary with the name of the function.
Args:
func (Callable): The function to be called.
types (Tuple[Type]): The type hints of the function.
args (Tuple[Any], optional): Positional arguments for the function. Defaults to ().
kwargs (Optional[Dict[str, Any]], optional): Keyword arguments for the function. Defaults to None.
Returns:
Any: The result of the function call.
"""
@@ -292,13 +240,13 @@ class CUDAGraphInnerWatcher(TorchDispatchMode):
def __exit__(self, exc_type, exc_val, exc_tb):
"""
Clear the traced_tensor and name_list, and call the parent class's __exit__ method.
清空 traced_tensor name_list,并调用父类的 __exit__ 方法。
Args:
exc_type (Optional[Type[BaseException]]): The type of the exception, default is None.
exc_val (Optional[BaseException]): The value of the exception, default is None.
exc_tb (Optional[TracebackType]): he traceback object, default is None.
exc_type (Optional[Type[BaseException]]): 异常类型,默认为 None
exc_val (Optional[BaseException]): 异常值,默认为 None
exc_tb (Optional[TracebackType]): Traceback 对象,默认为 None
Returns:
None.
"""
@@ -308,11 +256,22 @@ class CUDAGraphInnerWatcher(TorchDispatchMode):
self.name_list.clear()
super(CUDAGraphInnerWatcher, self).__exit__(exc_type, exc_val, exc_tb)
# def patch_annotations_for_schema(func):
# sig = inspect.signature(func)
# new_params = []
# for name, param in sig.parameters.items():
# anno = param.annotation
# if anno == list[int]:
# anno = typing.List[int]
# new_params.append(param.replace(annotation=anno))
# new_sig = sig.replace(parameters=new_params)
# func.__signature__ = new_sig
# return func
def patch_annotations_for_schema(func):
    """Rewrite PEP 585 annotations so torch can infer an op schema.

    At runtime, replaces ``list[int]`` and ``Optional[list[int]]`` in the
    function signature with ``typing.List[int]`` / ``Optional[typing.List[int]]``
    so that ``torch.library.infer_schema`` can recognize them. The function
    is mutated in place by assigning ``func.__signature__``.

    Args:
        func (Callable): Function whose annotations should be normalized.

    Returns:
        Callable: The same ``func`` object, with ``__signature__`` replaced.
    """
    sig = inspect.signature(func)
    new_params = []
    for name, param in sig.parameters.items():
        ann = param.annotation
        # Optional[T] is represented as Union[T, None].
        # NOTE(review): PEP 604 unions (``list[int] | None``) have a
        # different origin (types.UnionType) and are not handled here —
        # confirm callers only use typing.Optional.
        if get_origin(ann) is typing.Union and type(None) in get_args(ann):
            inner_type = [a for a in get_args(ann) if a is not type(None)][0]
            if get_origin(inner_type) is list:  # Optional[list[int]]
                inner_args = get_args(inner_type)
                new_ann = Optional[List[inner_args[0] if inner_args else typing.Any]]
                param = param.replace(annotation=new_ann)
        # A bare builtin-generic list, e.g. list[int].
        elif get_origin(ann) is list:
            args = get_args(ann)
            new_ann = List[args[0] if args else typing.Any]
            param = param.replace(annotation=new_ann)
        new_params.append(param)
    func.__signature__ = sig.replace(parameters=new_params)
    return func
def supports_custom_op() -> bool:
    """Return True if this torch build exposes ``torch.library.custom_op``.

    Used to decide between the modern schema-inference path and the
    older ``torch._custom_op.impl`` fallback when registering custom ops.
    """
    return hasattr(torch.library, "custom_op")
vllm_lib = Library("vllm", "FRAGMENT") # noqa
def direct_register_custom_op(
    op_name: str,
    op_func: Callable,
    mutates_args: list[str],
    fake_impl: Optional[Callable] = None,
    target_lib: Optional[Library] = None,
    dispatch_key: str = "CUDA",
    tags: tuple[torch.Tag, ...] = (),
):
    """Register a custom op directly through a ``torch.library.Library``.

    ``torch.library.custom_op`` can have significant overhead because it
    wraps the op in extra layers; defining the schema and implementation
    directly on the library avoids that cost.
    (NOTE(review): the middle of the original docstring was elided in this
    diff hunk — restored from its visible first line; confirm upstream.)

    Args:
        op_name: Name to register the op under.
        op_func: Real implementation, registered for ``dispatch_key``.
        mutates_args: Names of arguments the op mutates (for schema
            inference).
        fake_impl: Optional fake/meta implementation for tracing.
        target_lib: Library to register into; defaults to ``vllm_lib``.
        dispatch_key: Dispatch key for the real implementation.
        tags: ``torch.Tag`` tuple forwarded to ``Library.define``.

    Returns:
        None.
    """
    if not supports_custom_op():
        from vllm.platforms import current_platform
        assert not current_platform.is_cuda_alike(), (
            "cuda platform needs torch>=2.4 to support custom op, "
            "chances are you are using an old version of pytorch "
            "or a custom build of pytorch. It is recommended to "
            "use vLLM in a fresh new environment and let it install "
            "the required dependencies.")
        return

    import torch.library
    if hasattr(torch.library, "infer_schema"):
        # Normalize PEP 585 annotations (list[int] -> typing.List[int]) so
        # infer_schema can parse them. The call mutates op_func.__signature__
        # in place, so its return value is deliberately unused (the original
        # bound it to an unused `patched_func` local).
        patch_annotations_for_schema(op_func)
        schema_str = torch.library.infer_schema(op_func,
                                                mutates_args=mutates_args)
    else:
        # Fallback for pytorch 2.4, where schema inference lived in a
        # private module.
        import torch._custom_op.impl
        schema_str = torch._custom_op.impl.infer_schema(op_func, mutates_args)

    my_lib = target_lib or vllm_lib
    my_lib.define(op_name + schema_str, tags=tags)
    my_lib.impl(op_name, op_func, dispatch_key=dispatch_key)
    if fake_impl is not None:
        my_lib._register_fake(op_name, fake_impl)