提交vllm0.11.0开发分支

This commit is contained in:
chenyili
2025-12-10 17:51:24 +08:00
parent deab7dd0b6
commit 7c22d621fb
175 changed files with 31856 additions and 8683 deletions

View File

@@ -1,21 +1,3 @@
#
# Copyright (c) 2025 Baidu, Inc. All Rights Reserved.
# Author: Xinyu Dong
# Email: dongxinyu03@baidu.com
# This file is a part of the vllm-kunlun project.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""vllm kunlun init"""
from .platforms import current_platform
import sys
@@ -25,26 +7,24 @@ import builtins
import os
import time
import vllm.envs as envs
OLD_IMPORT_HOOK = builtins.__import__
def _custom_import(module_name, globals=None, locals=None, fromlist=(), level=0):
try:
start_time = time.time()
# Module mapping table
# 模块映射表
module_mappings = {
"vllm.model_executor.layers.fused_moe.layer": "vllm_kunlun.ops.fused_moe.layer",
"vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors_moe": "vllm_kunlun.ops.quantization.compressed_tensors_moe",
"vllm.compilation.wrapper": "vllm_kunlun.compilation.wrapper",
"vllm.v1.worker.gpu_model_runner": "vllm_kunlun.v1.worker.gpu_model_runner"
}
# Keep the original imported modules
# 需要保持原始导入的模块
original_imports = [
"vllm.model_executor.layers.fused_moe.base",
"vllm.model_executor.layers.fused_moe.config",
"vllm.model_executor.layers.fused_moe.layer",
"vllm.model_executor.layers.fused_moe.layer"
]
if module_name in original_imports:
@@ -55,7 +35,7 @@ def _custom_import(module_name, globals=None, locals=None, fromlist=(), level=0)
globals=globals,
locals=locals,
fromlist=fromlist,
level=level,
level=level
)
if module_name in module_mappings:
@@ -68,15 +48,12 @@ def _custom_import(module_name, globals=None, locals=None, fromlist=(), level=0)
return module
relative_mappings = {
(
"compressed_tensors_moe",
"compressed_tensors",
): "vllm_kunlun.ops.quantization.compressed_tensors_moe",
("compressed_tensors_moe", "compressed_tensors"): "vllm_kunlun.ops.quantization.compressed_tensors_moe",
("layer", "fused_moe"): "vllm_kunlun.ops.fused_moe.layer",
}
if level == 1:
parent = globals.get("__package__", "").split(".")[-1] if globals else ""
parent = globals.get('__package__', '').split('.')[-1] if globals else ''
key = (module_name, parent)
if key in relative_mappings:
if module_name in sys.modules:
@@ -91,15 +68,18 @@ def _custom_import(module_name, globals=None, locals=None, fromlist=(), level=0)
pass
return OLD_IMPORT_HOOK(
module_name, globals=globals, locals=locals, fromlist=fromlist, level=level
module_name,
globals=globals,
locals=locals,
fromlist=fromlist,
level=level
)
def import_hook():
"""Apply import hook for VLLM Kunlun"""
if not int(os.environ.get("DISABLE_KUNLUN_HOOK", "0")):
builtins.__import__ = _custom_import
try:
modules_to_preload = [
"vllm_kunlun.ops.quantization.compressed_tensors_moe",
@@ -112,31 +92,39 @@ def import_hook():
except Exception:
pass
def register():
"""Register the Kunlun platform"""
from .utils import redirect_output
from .vllm_utils_wrapper import (
direct_register_custom_op,
patch_annotations_for_schema,
)
from .vllm_utils_wrapper import direct_register_custom_op, patch_annotations_for_schema
patch_bitsandbytes_loader()
import_hook()
if envs.VLLM_USE_V1:
patch_V1blockTable()
# patch_V1blockTable()
patch_V1top_p_K()
patch_V1penalties()
# TODO fixed fast top & k for vLLM 0.10.2,
pass
else:
patch_sampler()
return "vllm_kunlun.platforms.kunlun.KunlunPlatform"
def register_model():
"""Register models for training and inference"""
from .models import register_model as _reg
_reg()
# [monkey patach sampler]
import sys
import sys, importlib, warnings
def patch_bitsandbytes_loader():
    """Swap vLLM's bitsandbytes loader module for the Kunlun implementation.

    Best-effort: any failure is reported as a warning and never raised, so a
    missing plugin module cannot break startup.
    """
    target = "vllm.model_executor.model_loader.bitsandbytes_loader"
    try:
        # Load the plugin's own bitsandbytes loader implementation.
        replacement = importlib.import_module(
            "vllm_kunlun.models.model_loader.bitsandbytes_loader"
        )
        # Shadow the upstream module so later imports resolve to ours.
        sys.modules[target] = replacement
        print("[vllm_kunlun] bitsandbytes_loader patched ->", replacement.__file__)
    except Exception as e:
        warnings.warn(f"[vllm_kunlun] bitsandbytes_loader patch failed: {e!r}")
def patch_sampler():
try:
@@ -149,24 +137,12 @@ def patch_sampler():
def patch_V1top_p_K():
    """Install the Kunlun top-p/top-k sampler over vLLM's V1 sampler ops.

    Overwrites ``sys.modules`` so any later import of
    ``vllm.v1.sample.ops.topk_topp_sampler`` resolves to the Kunlun module.
    Best-effort: failures are reported as warnings, never raised.
    """
    try:
        # Single import (the block previously assigned custom_sampler twice
        # back-to-back with identical calls; the duplicate was redundant).
        custom_sampler = importlib.import_module(
            "vllm_kunlun.v1.sample.ops.topk_topp_sampler"
        )
        sys.modules["vllm.v1.sample.ops.topk_topp_sampler"] = custom_sampler
        print("[vllm_kunlun] V1sampler top p & k patched ->", custom_sampler.__file__)
    except Exception as e:
        warnings.warn(f"[vllm_kunlun] V1 sampler top p & k patch failed: {e!r}")
def patch_V1penalties():
    """Redirect vLLM's V1 penalties ops to the Kunlun implementation.

    Best-effort: a failed import only emits a warning.
    """
    target = "vllm.v1.sample.ops.penalties"
    try:
        kunlun_mod = importlib.import_module("vllm_kunlun.v1.sample.ops.penalties")
        # Later imports of the upstream path now resolve to the Kunlun module.
        sys.modules[target] = kunlun_mod
        print("[vllm_kunlun] V1sampler penalties patched ->", kunlun_mod.__file__)
    except Exception as e:
        warnings.warn(f"[vllm_kunlun] V1 sampler penalties patch failed: {e!r}")
def patch_V1blockTable():
try:
custom_sampler = importlib.import_module("vllm_kunlun.v1.worker.block_table")
@@ -175,6 +151,5 @@ def patch_V1blockTable():
except Exception as e:
warnings.warn(f"[vllm_kunlun] V1 block table patch failed: {e!r}")
# Automatically apply patches when modules are imported
# 在模块导入时自动应用补丁
import_hook()

View File

@@ -1,20 +1,6 @@
#
# Copyright (c) 2025 Baidu, Inc. All Rights Reserved.
# Author: Bao Qian
# Email: baoqian@baidu.com
# This file is a part of the vllm-kunlun project.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import os
import sys
from abc import abstractmethod
@@ -46,7 +32,7 @@ class TorchCompileWrapperWithCustomDispatcher:
def __init__(self,
compiled_callable: Optional[Callable] = None,
compilation_level: int = 0):
from vllm.config import get_current_vllm_config
from vllm.config import get_current_vllm_config, CUDAGraphMode
vllm_config = get_current_vllm_config()
self.vllm_config = vllm_config
if compiled_callable is None:
@@ -61,9 +47,13 @@ class TorchCompileWrapperWithCustomDispatcher:
compiled_callable = torch.compile(
self.forward,
fullgraph=envs.VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE,
fullgraph=True, #envs.VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE,
backend=backend,
options=options)
# print(vllm_config.compilation_config)
# vllm_config.compilation_config.cudagraph_mode = CUDAGraphMode.PIECEWISE
# vllm_config.compilation_config.cudagraph_capture_sizes = [32768]
self.compiled_callable = compiled_callable
self.original_code_object = self.__class__.forward.__code__
@@ -126,7 +116,12 @@ class TorchCompileWrapperWithCustomDispatcher:
decompiled_file)
except Exception:
pass
# if self.vllm_config.compilation_config.use_cudagraph and \
# "update" in new_code.co_names:
# import depyf
# src = depyf.decompile(new_code)
# msg = "Assigning / modifying buffers of nn.Module during forward pass is not allowed when using cudagraph inside the compiler because it will cause silent errors. Please use eager mode or fix the code. The following code contains clues about which buffer is being modified (please search for the usage of the function `update`):\n" + src # noqa
# raise RuntimeError(msg)
@contextmanager
def dispatch_to_code(self, index: int):

View File

@@ -1,20 +1,3 @@
#
# Copyright (c) 2025 Baidu, Inc. All Rights Reserved.
# Author: Bao Qian, Dong Xinyu
# Email: baoqian@baidu.com, dongxinyu03@baidu.com
# This file is a part of the vllm-kunlun project.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""kunlun_communicator"""
from contextlib import contextmanager
from typing import Optional

View File

@@ -1,16 +0,0 @@
"""# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project"""
from vllm_kunlun.lora.ops.kunlun_ops.lora_ops import (bgmv_expand,bgmv_expand_slice, bgmv_shrink,
sgmv_expand, sgmv_expand_slice,
sgmv_shrink)
__all__ = [
"bgmv_expand",
"bgmv_expand_slice",
"bgmv_shrink",
"sgmv_expand",
"sgmv_expand_slice",
"sgmv_shrink"
]

View File

@@ -1,443 +0,0 @@
#
# Copyright (c) 2025 Baidu, Inc. All Rights Reserved.
#
# Author: Wang Hao
# Email: wanghao129@baidu.com
#
# This file is a part of the vllm-kunlun project.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""kunlun_ops for lora"""
import torch
from torch._C import dtype
def sgmv_shrink(
    inputs: torch.Tensor,
    lora_a_weights: torch.Tensor,
    output_tensor: torch.Tensor,
    block_statistic: torch.Tensor,
    sorted_tokens_num_lod: torch.Tensor,
    moe_index: torch.Tensor,
    expert_m: torch.Tensor,
    b_seq_start_loc: torch.Tensor,
    seq_len_tensor: torch.Tensor,
    lora_indices_tensor: torch.Tensor,
    batches: int,
    max_seq_length: int,
    token_nums: int,
    scaling: float,
):
    """LoRA "shrink" GEMM (x @ lora_A) for the prefill path via MoE kernels.

    Each LoRA adapter is treated as one MoE "expert": tokens are bucketed by
    adapter id, sorted, pushed through ``lora_a_weights`` with the fused
    ``moe_fc`` kernel, and scaled in place by ``scaling``.

    Mutates ``output_tensor``, ``block_statistic``, ``sorted_tokens_num_lod``,
    ``moe_index`` and ``expert_m`` in place and also returns ``output_tensor``.
    ``b_seq_start_loc``, ``batches`` and ``max_seq_length`` are accepted for
    interface compatibility but unused here.
    """
    # NOTE(review): 9 expert slots appear hard-coded (last slot = "no LoRA",
    # see masked_fill_ below) — confirm against the serving configuration.
    expert_num = 9
    device = inputs.device
    # Per-token adapter id: expand each request's id across its sequence.
    lora_ids = lora_indices_tensor.repeat_interleave(seq_len_tensor, dim=0).to(
        device=device, dtype=torch.int32
    )
    # Tokens without an adapter (id < 0) are routed to the last expert slot.
    lora_ids.masked_fill_(lora_ids < 0, expert_num - 1).unsqueeze_(1)
    torch.ops._C.gen_block_statistic(lora_ids, block_statistic)
    inputs_sorted = torch.zeros_like(inputs, dtype=inputs.dtype, device=device)
    # Group tokens by expert so moe_fc can run one GEMM per adapter.
    torch.ops._C.moe_pre_sorted(
        inputs,
        lora_ids,
        block_statistic,
        inputs_sorted,
        moe_index,
        expert_m,
        sorted_tokens_num_lod
    )
    # moe_fc expects a 3-D activation; lift [m, r] to [m, 1, r] in place.
    output_tensor.unsqueeze_(1)
    torch.ops._C.moe_fc(
        x=inputs_sorted,
        weight=lora_a_weights,
        sorted_tokens_num_lod=sorted_tokens_num_lod,
        sorted_tokens_idx=moe_index,
        moe_topk=1,
        y=output_tensor,
        act=None,
        x_perchannel_max=None,
        w_perchannel_max=None,
        topk_ids=None,
        topk_w=None,
        bias=None,
        tgemm_type=None,
        tweight_type=None,
        scale_n=0,
        scale_k=0,
        use_pack_int4=False
    )
    # Restore [m, r] and apply the LoRA scaling factor in place.
    output_tensor.squeeze_(1).mul_(scaling)
    return output_tensor
def sgmv_expand(inputs: torch.Tensor,
                lora_b_weights: torch.Tensor,
                output_tensor: torch.Tensor,
                block_statistic: torch.Tensor,
                sorted_tokens_num_lod: torch.Tensor,
                moe_index: torch.Tensor,
                b_seq_start_loc: torch.Tensor,
                seq_len_tensor: torch.Tensor,
                lora_indices_tensor: torch.Tensor,
                batches: int,
                max_seq_length: int,
                token_nums: int,
                add_inputs: bool = False):
    """LoRA "expand" GEMM (buffer @ lora_B) for the prefill path.

    Runs the fused ``moe_fc`` kernel over the pre-sorted routing produced by
    :func:`sgmv_shrink`, scatters rows back to token order with ``moe_post``,
    and writes (or accumulates, when ``add_inputs``) into ``output_tensor``.
    Returns ``output_tensor``.

    Fix: the original body referenced two undefined names — ``slice_size``
    and ``normed_scale`` — which made any call raise NameError. They are now
    derived locally, mirroring ``sgmv_expand_slice`` and the wrapper's
    ``_expand_slice_prefill``.
    """
    expert_num = 9
    device = inputs.device
    lora_ids = lora_indices_tensor.repeat_interleave(seq_len_tensor, dim=0).to(
        device=device, dtype=torch.int32
    )
    lora_ids.masked_fill_(lora_ids < 0, expert_num - 1).unsqueeze_(1)
    # Output width of lora_B. Assumes stacked weights are laid out
    # [num_loras, out_dim, rank] after the wrapper squeezes dim 1 —
    # TODO(review): confirm weight layout against the moe_fc kernel.
    slice_size = lora_b_weights.size(-2)
    out = torch.zeros((token_nums, 1, slice_size), dtype=inputs.dtype, device=device)
    torch.ops._C.moe_fc(
        x=inputs,
        weight=lora_b_weights,
        sorted_tokens_num_lod=sorted_tokens_num_lod,
        sorted_tokens_idx=moe_index,
        moe_topk=1,
        y=out,
        act=None,
        x_perchannel_max=None,
        w_perchannel_max=None,
        topk_ids=None,
        topk_w=None,
        bias=None,
        tgemm_type=None,
        tweight_type=None,
        scale_n=0,
        scale_k=0,
        use_pack_int4=False
    )
    output_post = out.squeeze(1)
    # Unit scales: moe_post only un-sorts rows back to original token order
    # (same construction as the wrapper's _expand_slice_prefill).
    normed_scale = torch.ones([token_nums, 1], dtype=torch.float32, device=device)
    torch.ops._C.moe_post(
        output_post,
        moe_index.unsqueeze(1),
        normed_scale,
        normed_scale,
        output_post
    )
    # LoRA adapter and model may pad their outputs differently; clip both dims.
    common_len = min(output_post.shape[1], output_tensor.shape[1])
    limit = min(output_post.shape[0], output_tensor.shape[0])
    if add_inputs:
        output_tensor[:limit, :common_len] += output_post[:limit, :common_len]
    else:
        output_tensor[:limit, :common_len] = output_post[:limit, :common_len]
    return output_tensor
def sgmv_expand_slice(inputs: torch.Tensor,
                      lora_b_weights: torch.Tensor,
                      output_tensor: torch.Tensor,
                      block_statistic: torch.Tensor,
                      sorted_tokens_num_lod: torch.Tensor,
                      moe_index: torch.Tensor,
                      normed_scale: torch.Tensor,
                      b_seq_start_loc: torch.Tensor,
                      seq_len_tensor: torch.Tensor,
                      lora_indices_tensor: torch.Tensor,
                      batches: int,
                      max_seq_length: int,
                      token_nums: int,
                      slice_offset: int,
                      slice_size: int,
                      add_inputs: bool = False):
    """LoRA "expand" GEMM writing one column slice of ``output_tensor``.

    Prefill-path variant of :func:`bgmv_expand_slice`: multiplies the shrink
    buffer through ``lora_b_weights`` with ``moe_fc``, restores token order
    with ``moe_post`` (``normed_scale`` of ones leaves values unscaled), then
    writes/accumulates into columns ``[slice_offset, slice_offset+slice_size)``
    of ``output_tensor``. Returns ``output_tensor``.

    ``b_seq_start_loc``, ``batches`` and ``max_seq_length`` are accepted for
    interface compatibility but unused here.
    """
    expert_num = 9
    device = inputs.device
    # Per-token adapter id; tokens without an adapter go to the last slot.
    lora_ids = lora_indices_tensor.repeat_interleave(seq_len_tensor, dim=0).to(
        device=device, dtype=torch.int32
    )
    lora_ids.masked_fill_(lora_ids < 0, expert_num - 1).unsqueeze_(1)
    out = torch.zeros((token_nums, 1, slice_size), dtype=inputs.dtype, device=device)
    torch.ops._C.moe_fc(
        x=inputs,
        weight=lora_b_weights,
        sorted_tokens_num_lod=sorted_tokens_num_lod,
        sorted_tokens_idx=moe_index,
        moe_topk=1,
        y=out,
        act=None,
        x_perchannel_max=None,
        w_perchannel_max=None,
        topk_ids=None,
        topk_w=None,
        bias=None,
        tgemm_type=None,
        tweight_type=None,
        scale_n=0,
        scale_k=0,
        use_pack_int4=False
    )
    output_post = out.squeeze(1)
    torch.ops._C.moe_post(
        output_post,
        moe_index.unsqueeze(1),
        normed_scale,
        normed_scale,
        output_post
    )
    slice_end = slice_offset + slice_size
    # Clip to the real output width; torch slicing clamps slice_end, so both
    # sides of the assignment end up actual_slice_size columns wide.
    actual_slice_size = min(slice_size, output_tensor.shape[1] - slice_offset)
    limit = min(output_post.shape[0], output_tensor.shape[0])
    if add_inputs:
        output_tensor[:limit, slice_offset:slice_end] += output_post[:limit, :actual_slice_size]
    else:
        output_tensor[:limit, slice_offset:slice_end] = output_post[:limit, :actual_slice_size]
    return output_tensor
def bgmv_shrink(
    inputs: torch.Tensor,  # [m, hidden_dim]
    lora_a_weights: torch.Tensor,  # [n, 1, r, hidden_dim]
    output_tensor: torch.Tensor,  # [m, r]
    block_statistic: torch.Tensor,
    sorted_tokens_num_lod: torch.Tensor,
    moe_index: torch.Tensor,
    expert_m: torch.Tensor,
    lora_indices_tensor: torch.Tensor,  # [m]
    scaling: float = 1.0
) -> torch.Tensor:
    """LoRA "shrink" GEMM (x @ lora_A) for the decode path (one token/seq).

    Decode-path sibling of :func:`sgmv_shrink`: adapter ids come directly per
    token (no repeat_interleave). Routing buffers (``block_statistic``,
    ``sorted_tokens_num_lod``, ``moe_index``, ``expert_m``) are filled in
    place by the kernels; ``output_tensor`` is written in place and returned.
    """
    # NOTE(review): 9 expert slots hard-coded; slot 8 = "no LoRA".
    expert_num = 9
    lora_ids = lora_indices_tensor.to(dtype=torch.int32, device=inputs.device)
    # Route adapter-less tokens (id < 0) to the last expert slot.
    lora_ids.masked_fill_(lora_ids < 0, expert_num - 1)
    torch.ops._C.gen_block_statistic(lora_ids.unsqueeze(1), block_statistic)
    inputs_sorted = torch.empty_like(inputs, dtype=inputs.dtype, device=inputs.device)
    # Group tokens by adapter so moe_fc can run one GEMM per adapter.
    torch.ops._C.moe_pre_sorted(
        inputs,
        lora_ids.unsqueeze(1),
        block_statistic,
        inputs_sorted,
        moe_index,
        expert_m,
        sorted_tokens_num_lod
    )
    output_tensor.unsqueeze_(1)  # Change to [m, 1, r]
    torch.ops._C.moe_fc(
        x=inputs_sorted,
        weight=lora_a_weights,
        sorted_tokens_num_lod=sorted_tokens_num_lod,
        sorted_tokens_idx=moe_index,
        moe_topk=1,
        y=output_tensor,
        act=None,
        x_perchannel_max=None,
        w_perchannel_max=None,
        topk_ids=None,
        topk_w=None,
        bias=None,
        tgemm_type=None,
        tweight_type=None,
        scale_n=0,
        scale_k=0,
        use_pack_int4=False
    )
    # Back to [m, r]; apply the LoRA scaling factor in place.
    output_tensor.squeeze_(1).mul_(scaling)
    return output_tensor
def bgmv_expand(inputs: torch.Tensor,
lora_b_weights: torch.Tensor,
output_tensor: torch.Tensor,
block_statistic: torch.Tensor,
sorted_tokens_num_lod: torch.Tensor,
moe_index: torch.Tensor,
lora_indices_tensor: torch.Tensor,
add_inputs: bool = True):
""""
bgmv_expand
"""
expert_num = 9
device = inputs.device
lora_ids = lora_indices_tensor.to(dtype=torch.int32, device=inputs.device)
lora_ids.masked_fill_(lora_ids < 0, expert_num - 1)
out = torch.zeros((inputs.shape[0], 1, slice_size), dtype=inputs.dtype, device=device)
torch.ops._C.moe_fc(
x=inputs,
weight=lora_b_weights,
sorted_tokens_num_lod=sorted_tokens_num_lod,
sorted_tokens_idx=moe_index,
moe_topk=1,
y=out,
act=None,
x_perchannel_max=None,
w_perchannel_max=None,
topk_ids=None,
topk_w=None,
bias=None,
tgemm_type=None,
tweight_type=None,
scale_n=0,
scale_k=0,
use_pack_int4=False
)
output_post = out.squeeze(1)
torch.ops._C.moe_post(output_post, moe_index.unsqueeze(1), normed_scale, normed_scale, output_post)
limit = output_tensor.shape[0]
if output_post.shape[0] == 1 and output_tensor.shape[0] != 1:
limit = 1
# LoRA adapter and model may add different amounts of padding to output
common_len = min(output_post.shape[1], output_tensor.shape[1])
if add_inputs:
output_tensor[:, :common_len] += output_post[:limit, :common_len]
else:
output_tensor[:, :common_len] = output_post[:limit, :common_len]
return output_tensor
def bgmv_expand_slice(
    inputs: torch.Tensor,
    lora_b_weights: torch.Tensor,
    output_tensor: torch.Tensor,
    block_statistic: torch.Tensor,
    sorted_tokens_num_lod: torch.Tensor,
    moe_index: torch.Tensor,
    normed_scale: torch.Tensor,
    lora_indices_tensor: torch.Tensor,
    slice_offset: int,
    slice_size: int,
    add_inputs: bool = True
):
    """LoRA "expand" GEMM writing one column slice, decode path.

    Decode-path sibling of :func:`sgmv_expand_slice`: multiplies the shrink
    buffer through ``lora_b_weights`` with ``moe_fc``, restores token order
    with ``moe_post`` (``normed_scale`` of ones leaves values unscaled), then
    writes/accumulates into columns ``[slice_offset, slice_offset+slice_size)``
    of ``output_tensor``. Returns ``output_tensor``.
    """
    # NOTE(review): 9 expert slots hard-coded; slot 8 = "no LoRA".
    expert_num = 9
    device = inputs.device
    lora_ids = lora_indices_tensor.to(dtype=torch.int32, device=inputs.device)
    # Route adapter-less tokens (id < 0) to the last expert slot.
    lora_ids.masked_fill_(lora_ids < 0, expert_num - 1)
    out = torch.zeros((inputs.shape[0], 1, slice_size), dtype=inputs.dtype, device=device)
    torch.ops._C.moe_fc(
        x=inputs,
        weight=lora_b_weights,
        sorted_tokens_num_lod=sorted_tokens_num_lod,
        sorted_tokens_idx=moe_index,
        moe_topk=1,
        y=out,
        act=None,
        x_perchannel_max=None,
        w_perchannel_max=None,
        topk_ids=None,
        topk_w=None,
        bias=None,
        tgemm_type=None,
        tweight_type=None,
        scale_n=0,
        scale_k=0,
        use_pack_int4=False
    )
    output_post = out.squeeze(1)
    torch.ops._C.moe_post(output_post, moe_index.unsqueeze(1), normed_scale, normed_scale, output_post)
    slice_end = slice_offset + slice_size
    # Clip to the real output width; torch slicing clamps slice_end, so both
    # sides of the assignment end up actual_slice_size columns wide.
    actual_slice_size = min(slice_size, output_tensor.shape[1] - slice_offset)
    limit = min(output_post.shape[0], output_tensor.shape[0])
    if add_inputs:
        output_tensor[:limit, slice_offset:slice_end] += output_post[:limit, :actual_slice_size]
    else:
        output_tensor[:limit, slice_offset:slice_end] = output_post[:limit, :actual_slice_size]
    return output_tensor

View File

@@ -1,547 +0,0 @@
#
# Copyright (c) 2025 Baidu, Inc. All Rights Reserved.
# Author: Wang Hao
# Email: wanghao129@baidu.com
# This file is a part of the vllm-kunlun project.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Based on:
Chen, L., Ye, Z., Wu, Y., Zhuo, D., Ceze, L., & Krishnamurthy, A. (2023).
Punica: Multi-Tenant LoRA Serving.
https://arxiv.org/abs/2310.18547
"""
from typing import TYPE_CHECKING, Optional, Union, final
import torch
# SPDX-License-Identifier: Apache-2.0
from typing import Callable, Optional, Tuple, Union
from vllm_kunlun.lora.ops.kunlun_ops import (
bgmv_expand,
bgmv_expand_slice,
bgmv_shrink,
sgmv_expand,
sgmv_expand_slice,
sgmv_shrink,
)
from vllm.lora.punica_wrapper.punica_base import PunicaWrapperBase
import time
# The platforms that are compatible with the PyTorch-native implementation can
# inherit this class
class PunicaWrapperKunlun(PunicaWrapperBase):
    """Punica LoRA wrapper backed by Kunlun fused-MoE kernels.

    LoRA adapters are mapped onto MoE "expert" slots so that LoRA's shrink
    (x @ A) and expand (buffer @ B) GEMMs run through the fused ``moe_fc``
    kernel. The ``sgmv_*`` ops serve the prefill path and the ``bgmv_*`` ops
    serve the decode path.
    """

    def __init__(
        self,
        max_num_batched_tokens: int,
        max_batches: int,
        device: Union[torch.device, str],
        **kwargs,
    ):
        PunicaWrapperBase.__init__(self, max_num_batched_tokens, max_batches, device)

    def _make_routing_buffers(self, x: torch.Tensor):
        """Allocate zeroed token->expert routing buffers for the MoE kernels.

        The buffers are filled in place by the shrink kernels
        (``gen_block_statistic`` / ``moe_pre_sorted``) and then reused by the
        expand kernels. Returns ``(block_statistic, sorted_tokens_num_lod,
        moe_index)``.
        """
        # NOTE(review): 9 expert slots (last slot = "no LoRA") and the leading
        # dimension 12 of block_statistic mirror hard-coded values in
        # kunlun_ops.lora_ops — keep both sides in sync.
        expert_num = 9
        block_statistic = torch.zeros(
            [12, expert_num], dtype=torch.int32, device=x.device
        )
        sorted_tokens_num_lod = torch.zeros(
            expert_num + 1, dtype=torch.int32, device=x.device
        )
        moe_index = torch.zeros(x.size(0), dtype=torch.int32, device=x.device)
        return block_statistic, sorted_tokens_num_lod, moe_index

    def _shrink_prefill(
        self,
        y: torch.Tensor,
        x: torch.Tensor,
        w_t_all: torch.Tensor,
        block_statistic: torch.Tensor,
        sorted_tokens_num_lod: torch.Tensor,
        moe_index: torch.Tensor,
        scale: float,
    ):
        # Per-expert token counts, filled in place by the kernel.
        expert_m = torch.zeros(9, dtype=torch.int32, device=x.device)
        sgmv_shrink(
            x,
            w_t_all,
            y,
            block_statistic,
            sorted_tokens_num_lod,
            moe_index,
            expert_m,
            *self.prefill_metadata,
            scale,
        )

    def _shrink_decode(
        self,
        y: torch.Tensor,
        x: torch.Tensor,
        w_t_all: torch.Tensor,
        block_statistic: torch.Tensor,
        sorted_tokens_num_lod: torch.Tensor,
        moe_index: torch.Tensor,
        scale: float,
    ):
        # Per-expert token counts, filled in place by the kernel.
        expert_m = torch.zeros(9, dtype=torch.int32, device=x.device)
        bgmv_shrink(
            x,
            w_t_all,
            y,
            block_statistic,
            sorted_tokens_num_lod,
            moe_index,
            expert_m,
            self.token_lora_indices,
            scale,
        )

    def _expand_prefill(
        self,
        y: torch.Tensor,
        x: torch.Tensor,
        w_t_all: torch.Tensor,
        block_statistic: torch.Tensor,
        sorted_tokens_num_lod: torch.Tensor,
        moe_index: torch.Tensor,
        add_inputs: bool,
    ):
        sgmv_expand(
            x,
            w_t_all,
            y,
            block_statistic,
            sorted_tokens_num_lod,
            moe_index,
            *self.prefill_metadata,
            add_inputs,
        )

    def _expand_decode(
        self,
        y: torch.Tensor,
        x: torch.Tensor,
        w_t_all: torch.Tensor,
        block_statistic: torch.Tensor,
        sorted_tokens_num_lod: torch.Tensor,
        moe_index: torch.Tensor,
        add_inputs: bool,
    ):
        bgmv_expand(
            x,
            w_t_all,
            y,
            block_statistic,
            sorted_tokens_num_lod,
            moe_index,
            self.token_lora_indices,
            add_inputs,
        )

    def _expand_slice_prefill(
        self,
        y: torch.Tensor,
        x: torch.Tensor,
        w_t_all: torch.Tensor,
        block_statistic,
        sorted_tokens_num_lod: torch.Tensor,
        moe_index: torch.Tensor,
        y_offset: int,
        y_slice_size: int,
        add_inputs: bool,
    ):
        # Unit scales: moe_post only restores original token order.
        normed_scale = torch.ones([y.size(0), 1], dtype=torch.float32, device=x.device)
        sgmv_expand_slice(
            x,
            w_t_all,
            y,
            block_statistic,
            sorted_tokens_num_lod,
            moe_index,
            normed_scale,
            *self.prefill_metadata,
            y_offset,
            y_slice_size,
            add_inputs,
        )

    def _expand_slice_decode(
        self,
        y: torch.Tensor,
        x: torch.Tensor,
        w_t_all: torch.Tensor,
        block_statistic: torch.Tensor,
        sorted_tokens_num_lod: torch.Tensor,
        moe_index: torch.Tensor,
        y_offset: int,
        y_slice_size: int,
        add_inputs: bool,
    ):
        # Unit scales: moe_post only restores original token order.
        normed_scale = torch.ones([y.size(0), 1], dtype=torch.float32, device=x.device)
        bgmv_expand_slice(
            x,
            w_t_all,
            y,
            block_statistic,
            sorted_tokens_num_lod,
            moe_index,
            normed_scale,
            self.token_lora_indices,
            y_offset,
            y_slice_size,
            add_inputs,
        )

    def _apply_expand(
        self,
        y: torch.Tensor,
        x: torch.Tensor,
        w_t_all: torch.Tensor,
        block_statistic,
        sorted_tokens_num_lod: torch.Tensor,
        moe_index: torch.Tensor,
        y_offset: int,
        y_slice_size: int,
        add_inputs: bool = True,
    ):
        """
        Perform the ` y[:,y_offset:y_offset+y_slice_size]+=x@w_t_all`
        computation, which is suitable for the
        GEMM of lora'b.
        """
        expand_slice_fun: Callable = (
            self._expand_slice_prefill if self.is_prefill else self._expand_slice_decode
        )
        expand_slice_fun(
            y,
            x,
            w_t_all,
            block_statistic,
            sorted_tokens_num_lod,
            moe_index,
            y_offset,
            y_slice_size,
            add_inputs,
        )

    def _apply_shrink(
        self,
        y: torch.Tensor,
        x: torch.Tensor,
        w_t_all: torch.Tensor,
        block_statistic: torch.Tensor,
        sorted_tokens_num_lod: torch.Tensor,
        moe_index: torch.Tensor,
        scale: float,
    ):
        """
        Perform the ` y+=x@w_t_all` computation, which is suitable for the
        GEMM of lora'a.
        When `is_prefill is` true, it indicates that it is currently the
        prefill stage, and the `_shrink_prefill` function should be called.
        Otherwise, it is the decode stage, and the _shrink_decode function
        should be called.
        """
        y_org = y
        y = y.view(-1, y.shape[-1])
        shrink_fun: Callable = (
            self._shrink_prefill if self.is_prefill else self._shrink_decode
        )
        shrink_fun(
            y, x, w_t_all, block_statistic, sorted_tokens_num_lod, moe_index, scale
        )
        y = y.view_as(y_org)

    def add_shrink(
        self,
        y: Union[Tuple[torch.Tensor, ...], torch.Tensor],
        x: torch.Tensor,
        lora_a_stacked: Tuple[torch.Tensor, ...],
        block_statistic: torch.Tensor,
        sorted_tokens_num_lod: torch.Tensor,
        moe_index: torch.Tensor,
        scale: float,
        **kwargs,
    ):
        """
        Performs GEMM for multiple slices of lora_a.

        Semantics:
            for i in range(len(lora_a_stacked)):
                y[i] += (x @ lora_a_stacked[i]) * scale

        Args:
            y (Union[Tuple[torch.Tensor, ...], torch.Tensor]): Output tensors
            x (torch.Tensor): Input tensor
            lora_a_stacked (Tuple[torch.Tensor, ...]): lora_a's weights
            scale (float): Scaling factor for the operation
        """
        x = x.view(-1, x.shape[-1])
        for slice_idx in range(len(lora_a_stacked)):  # Each slice represents a layer
            self._apply_shrink(
                y[slice_idx],
                x,
                lora_a_stacked[slice_idx],
                block_statistic,
                sorted_tokens_num_lod,
                moe_index,
                scale,
            )

    def add_expand(
        self,
        y: torch.Tensor,
        x: Union[Tuple[torch.Tensor, ...], torch.Tensor],
        lora_b_stacked: Tuple[torch.Tensor, ...],
        block_statistic: torch.Tensor,
        sorted_tokens_num_lod: torch.Tensor,
        moe_index: torch.Tensor,
        lora_bias_stacked: Optional[Tuple[torch.Tensor, ...]],
        output_slices: Tuple[int, ...],
        offset_start: int = 0,
        add_inputs=True,
        **kwargs,
    ) -> None:
        """
        Performs GEMM and bias addition for multiple slices of lora_b.

        Semantics:
            for i in range(len(lora_b_stacked)):
                slice = output_slices[i]
                y[:, offset:offset+slice] += x[i] @ lora_b_stacked[i] +
                    lora_bias_stacked[i]
                offset += slice

        Args:
            y (torch.Tensor): Output tensor.
            x (Union[Tuple[torch.Tensor, ...], torch.Tensor]): Input tensors
            lora_b_stacked (Tuple[torch.Tensor, ...]): lora_b's weight
            lora_bias_stacked (Optional[Tuple[torch.Tensor, ...]]):
                bias's weight
            output_slices (Tuple[int, ...]): Every slice's size
            add_inputs (bool): Defaults to True.
        """
        y_org = y
        y = y.view(-1, y.shape[-1])
        offset_left = offset_start
        if lora_bias_stacked is not None:
            self._apply_bias(
                self.token_lora_indices, y, output_slices, lora_bias_stacked
            )
        for slice_idx in range(len(lora_b_stacked)):
            self._apply_expand(
                y,
                x[slice_idx],
                lora_b_stacked[slice_idx],
                block_statistic,
                sorted_tokens_num_lod,
                moe_index,
                offset_left,
                output_slices[slice_idx],
                add_inputs=add_inputs,
            )
            offset_left += output_slices[slice_idx]
        y = y.view_as(y_org)

    def add_lora_embedding(
        self,
        y: torch.Tensor,
        x: torch.Tensor,
        lora_b_stacked: torch.Tensor,
        add_inputs: bool = True,
        **kwargs,
    ) -> None:
        """
        Applies lora specifically for VocabParallelEmbeddingWithLoRA.

        Semantics:
            y += x @ lora_b_stacked

        Args:
            y (torch.Tensor): Output tensor.
            x (torch.Tensor): Input tensor.
            lora_b_stacked (torch.Tensor): lora_b's weights.
            add_inputs (bool): Default to True.
        """
        # Fix: the expand kernels take the MoE routing buffers; the previous
        # call expand_fun(y, x, lora_b_stacked, add_inputs) was missing them
        # and raised TypeError.
        # NOTE(review): buffers are zero-initialized here (no shrink pass ran
        # to populate them) — confirm the kernels treat this as identity
        # routing for the embedding path.
        block_statistic, sorted_tokens_num_lod, moe_index = self._make_routing_buffers(
            x.view(-1, x.shape[-1])
        )
        expand_fun: Callable = (
            self._expand_prefill if self.is_prefill else self._expand_decode
        )
        expand_fun(
            y,
            x,
            lora_b_stacked,
            block_statistic,
            sorted_tokens_num_lod,
            moe_index,
            add_inputs,
        )

    def add_lora_linear(
        self,
        y: torch.Tensor,
        x: torch.Tensor,
        lora_a_stacked: Tuple[torch.Tensor, ...],
        lora_b_stacked: Tuple[torch.Tensor, ...],
        lora_bias_stacked: Optional[Tuple[torch.Tensor, ...]],
        scale: float,
        output_slices: Tuple[int, ...],
        *,
        buffer: Optional[Tuple[torch.Tensor, ...]] = None,
        **kwargs,
    ) -> None:
        """
        Applicable to linear-related lora.

        Semantics:
            for i in range(len(lora_a_stacked)):
                y[i] += (
                    x[i].unsqueeze(0)
                    @ lora_a_stacked[indices[i], layer_idx, :, :]
                    @ lora_b_stacked[indices[i], layer_idx, :, :]
                    * scale
                ).squeeze(0)+lora_bias_stacked[i]

        Args:
            y (torch.Tensor): Output tensor. Will be changed in-place.
            x (torch.Tensor): Input tensor
            lora_a_stacked (Tuple[torch.Tensor, ...]): lora_a's weight.
            lora_b_stacked (Tuple[torch.Tensor, ...]): lora_b's weight.
            lora_bias_stacked (Optional[Tuple[torch.Tensor, ...]]): lora's bias.
            scale (float): Scaling factor.
            output_slices (Tuple[int, ...]): Every slice's size.
            buffer (Optional[Tuple[torch.Tensor, ...]]): Defaults to None.
        """
        if self.no_lora:
            return
        # Routing buffers are shared between the shrink and expand passes:
        # the shrink kernels fill them, the expand kernels consume them.
        block_statistic, sorted_tokens_num_lod, moe_index = self._make_routing_buffers(x)
        assert len(lora_a_stacked) == len(lora_b_stacked) == len(output_slices)
        if lora_bias_stacked is not None:
            assert len(lora_bias_stacked) == len(output_slices)
            y = self._apply_bias(
                self.token_lora_indices, y, output_slices, lora_bias_stacked
            )
        if buffer is None:
            r = lora_b_stacked[0].size(-1)
            buffer = tuple(
                torch.zeros((x.size(0), r), dtype=torch.float16, device=x.device)
                for _ in range(len(output_slices))
            )
        # moe_fc expects 3-D stacked weights; drop the per-layer dim of 1.
        new_lora_a_stacked = tuple(lora_a.squeeze(1) for lora_a in lora_a_stacked)
        self.add_shrink(
            buffer,
            x,
            new_lora_a_stacked,
            block_statistic,
            sorted_tokens_num_lod,
            moe_index,
            scale,
            **kwargs,
        )
        new_lora_b_stacked = tuple(lora_b.squeeze(1) for lora_b in lora_b_stacked)
        self.add_expand(
            y,
            buffer,
            new_lora_b_stacked,
            block_statistic,
            sorted_tokens_num_lod,
            moe_index,
            None,
            output_slices,
            add_inputs=True,
            **kwargs,
        )

    def add_lora_logits(
        self,
        y: torch.Tensor,
        x: torch.Tensor,
        lora_a_stacked: torch.Tensor,
        lora_b_stacked: torch.Tensor,
        scale,
        *,
        buffer: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> None:
        """
        Applies lora specifically for LogitsProcessorWithLoRA.

        Semantics:
            buffer = (x @ lora_a_stacked) * scale
            y += buffer @ lora_b_stacked

        Args:
            y (torch.Tensor): Output tensor.
            x (torch.Tensor): Input tensor.
            lora_a_stacked (torch.Tensor): lora_a's weights.
            lora_b_stacked (torch.Tensor):lora_b's weights.
            scale (float): Scaling factor.
            buffer (Optional[torch.Tensor]):Default to None.
        """
        y_org = y
        y = y.view(-1, y.shape[-1])
        x = x.view(-1, x.shape[-1])
        if lora_a_stacked.dim() == 2:
            lora_a_stacked = lora_a_stacked.unsqueeze(0)
        if lora_b_stacked.dim() == 2:
            lora_b_stacked = lora_b_stacked.unsqueeze(0)
        r = lora_a_stacked.size(-1)
        if buffer is None:
            buffer = torch.zeros((x.size(0), r), dtype=torch.float32, device=x.device)
        indices = self.sampler_indices
        if indices.max() >= lora_a_stacked.size(0):
            indices = torch.clamp(indices, 0, lora_a_stacked.size(0) - 1)
        lora_a_reshaped = lora_a_stacked.transpose(1, 2)
        lora_b_reshaped = lora_b_stacked.transpose(1, 2)
        # Fix: bgmv_shrink/bgmv_expand require the MoE routing buffers (and,
        # for shrink, the per-expert count tensor); the previous calls passed
        # only (x, w, out, indices, scale) and raised TypeError.
        block_statistic, sorted_tokens_num_lod, moe_index = self._make_routing_buffers(x)
        expert_m = torch.zeros(9, dtype=torch.int32, device=x.device)
        bgmv_shrink(
            x,
            lora_a_reshaped,
            buffer,
            block_statistic,
            sorted_tokens_num_lod,
            moe_index,
            expert_m,
            indices,
            scale,
        )
        bgmv_expand(
            buffer,
            lora_b_reshaped,
            y,
            block_statistic,
            sorted_tokens_num_lod,
            moe_index,
            indices,
            add_inputs=True,
        )
        y = y.view_as(y_org)

View File

@@ -7,6 +7,12 @@ def register_model():
from .qwen2_5_vl import Qwen2_5_VLForConditionalGeneration #noqa: F401
from .qwen3 import Qwen3ForCausalLM #noqa: F401
from .qwen3_moe import Qwen3MoeForCausalLM #noqa: F401
from .qwen3_vl import Qwen3VLForConditionalGeneration
from .qwen3_vl_moe import Qwen3VLMoeForConditionalGeneration
from .qwen3_omni_moe_thinker import Qwen3OmniMoeThinkerForConditionalGeneration
# from .llama4 import Llama4ForCausalLM #noqa: F401
# from .mllama4 import Llama4ForConditionalGeneration #noqa: F401
# from .deepseek_v2 import KunlunDeepseekV2MoE
# ModelRegistry.register_model(
# "DemoModel",
@@ -27,6 +33,10 @@ def register_model():
ModelRegistry.register_model(
"Qwen3MoeForCausalLM",
"vllm_kunlun.models.qwen3_moe:Qwen3MoeForCausalLM")
ModelRegistry.register_model(
"Qwen3NextForCausalLM",
"vllm_kunlun.models.qwen3_next:Qwen3NextForCausalLM")
ModelRegistry.register_model(
"GlmForCausalLM",
@@ -34,7 +44,8 @@ def register_model():
ModelRegistry.register_model(
"GptOssForCausalLM",
"vllm_kunlun.models.gpt_oss:GptOssForCausalLM")
"vllm_kunlun.models.gpt_oss:GptOssForCausalLM")
ModelRegistry.register_model(
"InternLM2ForCausalLM",
"vllm_kunlun.models.internlm2:InternLM2ForCausalLM")
@@ -52,16 +63,20 @@ def register_model():
"vllm_kunlun.models.interns1:InternS1ForConditionalGeneration")
ModelRegistry.register_model(
"Glm4MoeForCausalLM",
"vllm_kunlun.models.glm4_moe:Glm4MoeForCausalLM")
"Qwen3VLForConditionalGeneration",
"vllm_kunlun.models.qwen3_vl:Qwen3VLForConditionalGeneration")
ModelRegistry.register_model(
"Glm4ForCausalLM",
"vllm_kunlun.models.glm4:Glm4ForCausalLM")
"Qwen3VLMoeForConditionalGeneration",
"vllm_kunlun.models.qwen3_vl_moe:Qwen3VLMoeForConditionalGeneration")
ModelRegistry.register_model(
"Glm4vForConditionalGeneration",
"vllm_kunlun.models.glm4_1v:Glm4vForConditionalGeneration")
"Qwen3OmniMoeForConditionalGeneration",
"vllm_kunlun.models.qwen3_omni_moe_thinker:Qwen3OmniMoeThinkerForConditionalGeneration")
ModelRegistry.register_model(
"SeedOssForCausalLM",
"vllm_kunlun.models.seed_oss:SeedOssForCausalLM")
def register_quant_method():

View File

@@ -1,301 +0,0 @@
#
# Copyright (c) 2025 Baidu, Inc. All Rights Reserved.
# Adapted from vllm/model_executor/models/glm4.py
# Copyright 2023 The vLLM team.
#
# This file is a part of the vllm-kunlun project.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only GLM-4-0414 model compatible with HuggingFace weights."""
from collections.abc import Iterable
from typing import Optional, Union
import torch
from torch import nn
from transformers import Glm4Config
from vllm.attention import AttentionType
from vllm_kunlun.ops.attention.layer import Attention
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors
from vllm.model_executor.models.interfaces import SupportsLoRA, SupportsPP
from vllm_kunlun.models.llama import LlamaMLP as Glm4MLP
from vllm_kunlun.models.llama import LlamaModel
from vllm.model_executor.models.utils import AutoWeightsLoader, PPMissingLayer, maybe_prefix
class Glm4Attention(nn.Module):
    """Multi-head self-attention for GLM-4-0414.

    Q/K/V heads are split across tensor-parallel ranks; a *partial* rotary
    embedding (non-neox style) is applied to q/k before the paged-attention
    kernel, and the result is merged back through a row-parallel projection.
    """
    def __init__(self,
                 config: Glm4Config,
                 hidden_size: int,
                 num_heads: int,
                 num_kv_heads: int,
                 max_position: int = 4096 * 32,
                 head_dim: Optional[int] = None,
                 qkv_bias: bool = False,
                 rope_theta: float = 10000,
                 cache_config: Optional[CacheConfig] = None,
                 quant_config: Optional[QuantizationConfig] = None,
                 rope_scaling: Optional[tuple] = None,
                 prefix: str = "",
                 attn_type: str = AttentionType.DECODER) -> None:
        super().__init__()
        self.hidden_size = hidden_size
        tp_size = get_tensor_model_parallel_world_size()
        self.total_num_heads = num_heads
        assert self.total_num_heads % tp_size == 0
        self.num_heads = self.total_num_heads // tp_size
        self.total_num_kv_heads = num_kv_heads
        if self.total_num_kv_heads >= tp_size:
            # Number of KV heads is greater than TP size, so we partition
            # the KV heads across multiple tensor parallel GPUs.
            assert self.total_num_kv_heads % tp_size == 0
        else:
            # Number of KV heads is less than TP size, so we replicate
            # the KV heads across multiple tensor parallel GPUs.
            assert tp_size % self.total_num_kv_heads == 0
        # GLM-4 applies rope to only part of each head; defaults to half.
        partial_rotary_factor = getattr(config, "partial_rotary_factor", 0.5)
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
        self.head_dim = head_dim or hidden_size // self.total_num_heads
        self.rotary_dim = self.head_dim
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        # Standard 1/sqrt(head_dim) attention scale.
        self.scaling = self.head_dim**-0.5
        self.rope_theta = rope_theta
        self.qkv_proj = QKVParallelLinear(
            hidden_size,
            self.head_dim,
            self.total_num_heads,
            self.total_num_kv_heads,
            bias=qkv_bias,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv_proj",
        )
        self.o_proj = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            hidden_size,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.o_proj",
        )
        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.rotary_dim,
            max_position=max_position,
            base=self.rope_theta,
            rope_scaling=rope_scaling,
            partial_rotary_factor=partial_rotary_factor,
            # GLM-4 uses interleaved (GPT-J style) rotation, not neox.
            is_neox_style=False,
        )
        self.attn = Attention(self.num_heads,
                              self.head_dim,
                              self.scaling,
                              num_kv_heads=self.num_kv_heads,
                              cache_config=cache_config,
                              quant_config=quant_config,
                              prefix=f"{prefix}.attn",
                              attn_type=attn_type)
    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Fused-QKV projection -> rope -> attention -> output projection."""
        qkv, _ = self.qkv_proj(hidden_states)
        # Split the fused projection into per-rank q/k/v slices.
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        q, k = self.rotary_emb(positions, q, k)
        attn_output = self.attn(q, k, v)
        output, _ = self.o_proj(attn_output)
        return output
class Glm4DecoderLayer(nn.Module):
    """One GLM-4 transformer layer.

    Uses four RMSNorms ("sandwich" normalization): pre-attention,
    post-attention-output, pre-MLP, and post-MLP-output.  The extra
    post_self_attn/post_mlp norms are applied to the sublayer outputs
    *before* they are merged with the residual stream.
    """
    def __init__(
        self,
        config: Glm4Config,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.hidden_size = config.hidden_size
        rope_theta = getattr(config, "rope_theta", 1000000)
        rope_scaling = getattr(config, "rope_scaling", None)
        self.self_attn = Glm4Attention(
            config=config,
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            max_position=config.max_position_embeddings,
            num_kv_heads=config.num_key_value_heads,
            rope_theta=rope_theta,
            qkv_bias=getattr(config, 'attention_bias', False),
            head_dim=getattr(config, 'head_dim', None),
            cache_config=cache_config,
            quant_config=quant_config,
            rope_scaling=rope_scaling,
            prefix=f"{prefix}.self_attn",
            attn_type=AttentionType.DECODER,
        )
        # Reuses the Llama gate/up + SiLU MLP (aliased as Glm4MLP above).
        self.mlp = Glm4MLP(
            hidden_size=self.hidden_size,
            intermediate_size=config.intermediate_size,
            hidden_act=config.hidden_act,
            quant_config=quant_config,
            prefix=f"{prefix}.mlp",
        )
        self.input_layernorm = RMSNorm(config.hidden_size,
                                       eps=config.rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(config.hidden_size,
                                                eps=config.rms_norm_eps)
        self.post_self_attn_layernorm = RMSNorm(config.hidden_size,
                                                eps=config.rms_norm_eps)
        self.post_mlp_layernorm = RMSNorm(config.hidden_size,
                                          eps=config.rms_norm_eps)
    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        residual: Optional[torch.Tensor],
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Run one decoder layer; returns (hidden_states, residual).

        ``residual`` is None only for the first layer; afterwards RMSNorm's
        fused add-and-norm path carries the residual stream.
        """
        # Self Attention
        if residual is None:
            residual = hidden_states
            hidden_states = self.input_layernorm(hidden_states)
        else:
            hidden_states, residual = self.input_layernorm(
                hidden_states, residual)
        hidden_states = self.self_attn(
            positions=positions,
            hidden_states=hidden_states,
        )
        # Normalize the attention output before it joins the residual.
        hidden_states = self.post_self_attn_layernorm(hidden_states)
        # Fully Connected
        hidden_states, residual = self.post_attention_layernorm(
            hidden_states, residual)
        hidden_states = self.mlp(hidden_states)
        hidden_states = self.post_mlp_layernorm(hidden_states)
        return hidden_states, residual
# Maps the config's layer-type tag to the layer class used to build the model.
# GLM-4-0414 has a single (attention) layer type.
ALL_DECODER_LAYER_TYPES = {
    "attention": Glm4DecoderLayer,
}
@support_torch_compile(
    dynamic_arg_dims={
        "input_ids": 0,
        "positions": -1,
        "intermediate_tensors": 0,
        "inputs_embeds": 0,
    })
class Glm4Model(LlamaModel):
    """GLM-4 backbone: a LlamaModel stack built from Glm4DecoderLayer.

    The decorator marks which argument dims are dynamic for torch.compile.
    """
    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__(vllm_config=vllm_config,
                         prefix=prefix,
                         layer_type=Glm4DecoderLayer)
class Glm4ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
    """GLM-4-0414 causal LM head on top of Glm4Model.

    Supports LoRA and pipeline parallelism.  ``packed_modules_mapping``
    tells the weight loader which checkpoint shards fuse into the stacked
    qkv_proj / gate_up_proj parameters.
    """
    packed_modules_mapping = {
        "qkv_proj": [
            "q_proj",
            "k_proj",
            "v_proj",
        ],
        "gate_up_proj": [
            "gate_proj",
            "up_proj",
        ],
    }
    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        lora_config = vllm_config.lora_config
        self.config = config
        self.lora_config = lora_config
        self.quant_config = quant_config
        self.model = Glm4Model(vllm_config=vllm_config,
                               prefix=maybe_prefix(prefix, "model"))
        # Only the last PP rank materializes the LM head / logits path.
        if get_pp_group().is_last_rank:
            if config.tie_word_embeddings:
                # Share weights with the input embedding table.
                self.lm_head = self.model.embed_tokens
            else:
                self.lm_head = ParallelLMHead(config.vocab_size,
                                              config.hidden_size,
                                              quant_config=quant_config,
                                              prefix=maybe_prefix(
                                                  prefix, "lm_head"))
        else:
            self.lm_head = PPMissingLayer()
        self.logits_processor = LogitsProcessor(config.vocab_size)
        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors)
    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up input embeddings from the backbone's embedding table."""
        return self.model.get_input_embeddings(input_ids)
    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run the backbone; returns hidden states (or PP tensors mid-stage)."""
        hidden_states = self.model(input_ids, positions, intermediate_tensors,
                                   inputs_embeds)
        return hidden_states
    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        """Project hidden states to vocab logits (None on non-last PP ranks)."""
        logits = self.logits_processor(self.lm_head, hidden_states,
                                       sampling_metadata)
        return logits
    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        """Load checkpoint weights; skips lm_head when embeddings are tied."""
        loader = AutoWeightsLoader(
            self,
            skip_prefixes=(["lm_head."]
                           if self.config.tie_word_embeddings else None),
        )
        return loader.load_weights(weights)

File diff suppressed because it is too large Load Diff

View File

@@ -1,716 +0,0 @@
#
# Copyright (c) 2025 Baidu, Inc. All Rights Reserved.
# Adapted from vllm/model_executor/models/glm4_moe.py
# Copyright 2023 The vLLM team.
#
# This file is a part of the vllm-kunlun project.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only GLM-4.5 model compatible with HuggingFace weights."""
import os
import typing
from collections.abc import Callable, Iterable
from itertools import islice
from typing import Any, Optional, Union
import torch
from torch import nn
from transformers.models.glm4_moe import Glm4MoeConfig
from vllm_kunlun.ops.attention.layer import Attention
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig, get_current_vllm_config
from vllm.distributed import (get_ep_group, get_pp_group,get_dp_group,get_tp_group,
get_tensor_model_parallel_world_size)
from vllm.logger import init_logger
from vllm_kunlun.ops.activation import SiluAndMul
from vllm_kunlun.ops.fused_moe.layer import FusedMoE
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import (
default_weight_loader, maybe_remap_kv_scale_name)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors
from vllm.model_executor.models.interfaces import SupportsLoRA, SupportsPP
from vllm.model_executor.models.utils import (AutoWeightsLoader, PPMissingLayer, is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers,
maybe_prefix)
from vllm_kunlun.ops.rotary_embedding import Split_Norm_Rope
logger = init_logger(__name__)
class Glm4MoeMLP(nn.Module):
    """Dense SwiGLU MLP used for GLM-4.5's dense layers and shared experts.

    gate_proj and up_proj are fused into one column-parallel matmul;
    ``reduce_results=False`` lets the caller defer the TP all-reduce
    (used by the shared-expert path inside Glm4MoE).
    """
    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
        quant_config: Optional[QuantizationConfig] = None,
        reduce_results: bool = True,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.gate_up_proj = MergedColumnParallelLinear(
            hidden_size, [intermediate_size] * 2,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.gate_up_proj")
        self.down_proj = RowParallelLinear(intermediate_size,
                                           hidden_size,
                                           bias=False,
                                           quant_config=quant_config,
                                           reduce_results=reduce_results,
                                           prefix=f"{prefix}.down_proj")
        if hidden_act != "silu":
            raise ValueError(f"Unsupported activation: {hidden_act}. "
                             "Only silu is supported for now.")
        # SiluAndMul consumes the concatenated [gate, up] projection.
        self.act_fn = SiluAndMul()
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """x -> silu(gate(x)) * up(x) -> down projection."""
        gate_up, _ = self.gate_up_proj(x)
        x = self.act_fn(gate_up)
        x, _ = self.down_proj(x)
        return x
class Glm4MoE(nn.Module):
    """GLM-4.5 sparse MoE block: sigmoid-scored grouped top-k routing with
    optional shared experts and EPLB (expert-parallel load balancing).
    """
    def __init__(
        self,
        config: Glm4MoeConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
        enable_eplb: bool = False,
    ):
        super().__init__()
        self.tp_size = get_tensor_model_parallel_world_size()
        self.routed_scaling_factor = config.routed_scaling_factor
        self.ep_group = get_ep_group().device_group
        self.ep_rank = self.ep_group.rank()
        self.ep_size = self.ep_group.size()
        self.n_routed_experts: int = config.n_routed_experts
        self.n_shared_experts: int = config.n_shared_experts
        if config.hidden_act != "silu":
            raise ValueError(f"Unsupported activation: {config.hidden_act}. "
                             "Only silu is supported for now.")
        # NOTE In the transformers implementation, the gate isn't an nn.Linear,
        # so we cannot use ReplicatedLinear here.
        # See: https://github.com/huggingface/transformers/blob/v4.55.1/src/transformers/models/glm4_moe/modeling_glm4_moe.py#L260
        # Router gate is kept in fp32 for routing-score stability.
        self.gate = nn.Linear(
            config.hidden_size,
            config.n_routed_experts,
            bias=False,
            dtype=torch.float32,
        )
        # Per-expert bias added to routing scores (loaded from checkpoint).
        self.gate.e_score_correction_bias = nn.Parameter(
            torch.empty(config.n_routed_experts, dtype=torch.float32))
        # Load balancing settings.
        vllm_config = get_current_vllm_config()
        parallel_config = vllm_config.parallel_config
        self.enable_eplb = enable_eplb
        self.n_redundant_experts = parallel_config.num_redundant_experts
        self.n_logical_experts = self.n_routed_experts
        # Physical experts = logical + redundant replicas, split evenly
        # across the expert-parallel group.
        self.n_physical_experts = (self.n_logical_experts +
                                   self.n_redundant_experts)
        self.n_local_physical_experts = self.n_physical_experts // self.ep_size
        self.physical_expert_start = (self.ep_rank *
                                      self.n_local_physical_experts)
        self.physical_expert_end = (self.physical_expert_start +
                                    self.n_local_physical_experts)
        self.experts = FusedMoE(
            num_experts=config.n_routed_experts,
            top_k=config.num_experts_per_tok,
            hidden_size=config.hidden_size,
            intermediate_size=config.moe_intermediate_size,
            reduce_results=False,
            renormalize=config.norm_topk_prob,
            quant_config=quant_config,
            use_grouped_topk=True,
            num_expert_group=config.n_group,
            topk_group=config.topk_group,
            prefix=f"{prefix}.experts",
            scoring_func="sigmoid",
            e_score_correction_bias=self.gate.e_score_correction_bias,
            enable_eplb=self.enable_eplb,
            num_redundant_experts=self.n_redundant_experts)
        if config.n_shared_experts is not None:
            intermediate_size = (config.moe_intermediate_size *
                                 config.n_shared_experts)
            self.shared_experts = Glm4MoeMLP(
                hidden_size=config.hidden_size,
                intermediate_size=intermediate_size,
                hidden_act=config.hidden_act,
                quant_config=quant_config,
                reduce_results=self.experts.must_reduce_shared_expert_outputs(
                ),
                prefix=f"{prefix}.shared_experts",
            )
    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Route tokens to experts; add shared-expert output if configured.

        Assumes ``hidden_states`` is 2-D (num_tokens, hidden_dim).
        """
        num_tokens, hidden_dim = hidden_states.shape
        hidden_states = hidden_states.view(-1, hidden_dim)
        if self.n_shared_experts is not None:
            shared_output = self.shared_experts(hidden_states)
        else:
            shared_output = None
        # Routing logits are computed in fp32 to match the fp32 gate.
        router_logits = self.gate(hidden_states.to(dtype=torch.float32))
        # The raw gate weight is also handed to the Kunlun FusedMoE kernel
        # (presumably for a fused routing path — vendor-specific extension).
        kunlun_linear_weights = self.gate.weight
        final_hidden_states = self.experts(
            hidden_states=hidden_states,
            router_logits=router_logits,
            linear_weights=kunlun_linear_weights) * self.routed_scaling_factor
        if shared_output is not None:
            final_hidden_states = final_hidden_states + shared_output
        if self.tp_size > 1:
            # Deferred TP reduction (experts were built with
            # reduce_results=False).
            final_hidden_states = (
                self.experts.maybe_all_reduce_tensor_model_parallel(
                    final_hidden_states))
        return final_hidden_states.view(num_tokens, hidden_dim)
class Glm4MoeAttention(nn.Module):
    """Multi-head self-attention for GLM-4.5 (MoE variant).

    Optionally applies per-head RMSNorm to q/k (``use_qk_norm``) and, on the
    Kunlun backend, a fused split+norm+rope kernel (Split_Norm_Rope) unless
    the USE_ORI_ROPE=1 env var forces the reference path.
    """
    def __init__(
        self,
        config: Glm4MoeConfig,
        hidden_size: int,
        num_heads: int,
        num_kv_heads: int,
        rope_theta: float = 10000,
        rope_scaling: Optional[dict[str, Any]] = None,
        max_position_embeddings: int = 131072,
        head_dim: Optional[int] = None,
        rms_norm_eps: float = 1e-05,
        qkv_bias: bool = False,
        use_qk_norm: bool = False,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.hidden_size = hidden_size
        tp_size = get_tensor_model_parallel_world_size()
        self.total_num_heads = num_heads
        assert self.total_num_heads % tp_size == 0
        self.num_heads = self.total_num_heads // tp_size
        self.total_num_kv_heads = num_kv_heads
        if self.total_num_kv_heads >= tp_size:
            # Number of KV heads is greater than TP size, so we partition
            # the KV heads across multiple tensor parallel GPUs.
            assert self.total_num_kv_heads % tp_size == 0
        else:
            # Number of KV heads is less than TP size, so we replicate
            # the KV heads across multiple tensor parallel GPUs.
            assert tp_size % self.total_num_kv_heads == 0
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
        self.head_dim = head_dim or (hidden_size // self.total_num_heads)
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        # Standard 1/sqrt(head_dim) attention scale.
        self.scaling = self.head_dim**-0.5
        self.rope_theta = rope_theta
        self.max_position_embeddings = max_position_embeddings
        self.use_qk_norm = use_qk_norm
        self.qkv_proj = QKVParallelLinear(hidden_size,
                                          self.head_dim,
                                          self.total_num_heads,
                                          self.total_num_kv_heads,
                                          bias=qkv_bias,
                                          quant_config=quant_config,
                                          prefix=f"{prefix}.qkv_proj")
        self.o_proj = RowParallelLinear(self.total_num_heads * self.head_dim,
                                        hidden_size,
                                        bias=False,
                                        quant_config=quant_config,
                                        prefix=f"{prefix}.o_proj")
        # GLM-4.5 ropes only part of each head; defaults to half.
        self.partial_rotary_factor = getattr(config, "partial_rotary_factor", 0.5)
        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=max_position_embeddings,
            base=rope_theta,
            rope_scaling=rope_scaling,
            partial_rotary_factor=self.partial_rotary_factor,
        )
        self.attn = Attention(
            self.num_heads,
            self.head_dim,
            self.scaling,
            num_kv_heads=self.num_kv_heads,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.attn",
        )
        if self.use_qk_norm:
            # Per-head RMSNorm applied to q and k before rope.
            self.q_norm = RMSNorm(self.head_dim, eps=rms_norm_eps)
            self.k_norm = RMSNorm(self.head_dim, eps=rms_norm_eps)
    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """QKV projection -> (qk-norm +) rope -> attention -> output proj."""
        qkv, _ = self.qkv_proj(hidden_states)
        # Reference path: explicit split / norm / rope.  Taken when the fused
        # kernel is disabled via USE_ORI_ROPE=1 or when qk-norm is off.
        if os.getenv('USE_ORI_ROPE') == "1" or not self.use_qk_norm:
            q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
            if self.use_qk_norm:
                # Normalize per head, then restore the flat layout.
                q = self.q_norm(q.reshape(-1, self.num_heads,
                                          self.head_dim)).reshape(q.shape)
                k = self.k_norm(k.reshape(-1, self.num_kv_heads,
                                          self.head_dim)).reshape(k.shape)
            q, k = self.rotary_emb(positions, q, k)
        else:
            # Rope fusion operators
            # Kunlun fused kernel: split + qk-norm + partial rope in one op.
            q, k, v = Split_Norm_Rope(qkv,
                                self.rotary_emb.cos_sin_cache,
                                self.q_norm.weight,
                                self.k_norm.weight,
                                positions,
                                self.max_position_embeddings,
                                self.num_heads,
                                self.num_kv_heads,
                                self.head_dim,
                                partial_rotary_factor=self.partial_rotary_factor,
                                )
        attn_output = self.attn(q, k, v)
        output, _ = self.o_proj(attn_output)
        return output
class Glm4MoeDecoderLayer(nn.Module):
    """One GLM-4.5 transformer layer.

    The first ``first_k_dense_replace`` layers use a dense MLP; the
    remaining layers use the sparse Glm4MoE block.
    """
    def __init__(
        self,
        config: Glm4MoeConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
        enable_eplb: bool = False,
    ) -> None:
        super().__init__()
        self.hidden_size = config.hidden_size
        rope_theta = getattr(config, "rope_theta", 10000)
        rope_scaling = getattr(config, "rope_scaling", None)
        max_position_embeddings = getattr(config, "max_position_embeddings",
                                          131072)
        # DecoderLayers are created with `make_layers` which passes the prefix
        # with the layer's index.
        layer_idx = int(prefix.split(sep='.')[-1])
        self.layer_idx = layer_idx
        self.self_attn = Glm4MoeAttention(
            config=config,
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            num_kv_heads=config.num_key_value_heads,
            rope_theta=rope_theta,
            rope_scaling=rope_scaling,
            max_position_embeddings=max_position_embeddings,
            head_dim=config.head_dim,
            rms_norm_eps=config.rms_norm_eps,
            qkv_bias=config.attention_bias,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.self_attn",
            use_qk_norm=config.use_qk_norm,
        )
        # MoE for layers >= first_k_dense_replace; dense MLP otherwise.
        if (config.n_routed_experts is not None
                and layer_idx >= config.first_k_dense_replace):
            self.mlp = Glm4MoE(
                config=config,
                quant_config=quant_config,
                prefix=f"{prefix}.mlp",
                enable_eplb=enable_eplb,
            )
        else:
            self.mlp = Glm4MoeMLP(hidden_size=config.hidden_size,
                                  intermediate_size=config.intermediate_size,
                                  hidden_act=config.hidden_act,
                                  quant_config=quant_config,
                                  prefix=f"{prefix}.mlp")
        self.input_layernorm = RMSNorm(config.hidden_size,
                                       eps=config.rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(config.hidden_size,
                                                eps=config.rms_norm_eps)
        self.routed_scaling_factor = config.routed_scaling_factor
    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        residual: Optional[torch.Tensor],
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Pre-norm attention + MLP; returns (hidden_states, residual)."""
        # ``residual`` is None only at the first layer of a PP stage.
        if residual is None:
            residual = hidden_states
            hidden_states = self.input_layernorm(hidden_states)
        else:
            hidden_states, residual = self.input_layernorm(
                hidden_states, residual)
        hidden_states = self.self_attn(positions=positions,
                                       hidden_states=hidden_states)
        hidden_states, residual = self.post_attention_layernorm(
            hidden_states, residual)
        hidden_states = self.mlp(hidden_states)
        return hidden_states, residual
@support_torch_compile(
    dynamic_arg_dims={
        "input_ids": 0,
        "positions": -1,
        "intermediate_tensors": 0,
        "inputs_embeds": 0,
    })
class Glm4MoeModel(nn.Module):
    """GLM-4.5 backbone: embeddings, decoder layers, and final norm,
    partitioned for pipeline parallelism via make_layers.
    """
    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config
        enable_eplb = vllm_config.parallel_config.enable_eplb
        self.config = config
        self.vocab_size = config.vocab_size
        # Embeddings live only on the first PP rank.
        if get_pp_group().is_first_rank:
            self.embed_tokens = VocabParallelEmbedding(
                config.vocab_size,
                config.hidden_size,
                prefix=f"{prefix}.embed_tokens")
        else:
            self.embed_tokens = PPMissingLayer()
        self.start_layer, self.end_layer, self.layers = make_layers(
            config.num_hidden_layers,
            lambda prefix: Glm4MoeDecoderLayer(
                config=config,
                cache_config=cache_config,
                quant_config=quant_config,
                prefix=prefix,
                enable_eplb=enable_eplb,
            ),
            prefix=f"{prefix}.layers")
        # Final norm lives only on the last PP rank.
        if get_pp_group().is_last_rank:
            self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        else:
            self.norm = PPMissingLayer()
        # NOTE(review): this instance attribute shadows the
        # make_empty_intermediate_tensors *method* defined below, which is
        # therefore dead code on instances.
        self.make_empty_intermediate_tensors = (
            make_empty_intermediate_tensors_factory(
                ["hidden_states", "residual"], config.hidden_size))
    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up token embeddings (valid on the first PP rank only)."""
        return self.embed_tokens(input_ids)
    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run this rank's slice of layers.

        Returns IntermediateTensors on non-final PP ranks and the normed
        hidden states on the final rank.
        """
        if get_pp_group().is_first_rank:
            if inputs_embeds is not None:
                hidden_states = inputs_embeds
            else:
                hidden_states = self.get_input_embeddings(input_ids)
            residual = None
        else:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]
            residual = intermediate_tensors["residual"]
        for i in range(self.start_layer, self.end_layer):
            layer = self.layers[i]
            hidden_states, residual = layer(positions, hidden_states, residual)
        if not get_pp_group().is_last_rank:
            return IntermediateTensors({
                "hidden_states": hidden_states,
                "residual": residual
            })
        hidden_states, _ = self.norm(hidden_states, residual)
        return hidden_states
    def make_empty_intermediate_tensors(
            self, batch_size: int, dtype: torch.dtype,
            device: torch.device) -> IntermediateTensors:
        # NOTE(review): shadowed by the factory-built attribute assigned in
        # __init__; kept for interface compatibility.
        return IntermediateTensors({
            "hidden_states":
            torch.zeros((batch_size, self.config.hidden_size),
                        dtype=dtype,
                        device=device),
            "residual":
            torch.zeros((batch_size, self.config.hidden_size),
                        dtype=dtype,
                        device=device),
        })
    def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
        # Params for weights, fp8 weight scales, fp8 activation scales
        # (param_name, weight_name, expert_id, shard_id)
        return FusedMoE.make_expert_params_mapping(
            ckpt_gate_proj_name="gate_proj",
            ckpt_down_proj_name="down_proj",
            ckpt_up_proj_name="up_proj",
            num_experts=self.config.n_routed_experts)
    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        """Load checkpoint weights into stacked / expert / plain params.

        Order matters: stacked (qkv, gate_up) shards are tried first, then
        per-expert shards, then everything else.  Speculative (MTP) layer
        weights are skipped entirely.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        ]
        params_dict = dict(self.named_parameters())
        loaded_params: set[str] = set()
        expert_params_mapping = self.get_expert_mapping()
        for name, loaded_weight in weights:
            # Skip weights that belong to MTP/speculative layers.
            spec_layer = get_spec_layer_idx_from_weight_name(self.config, name)
            if spec_layer is not None:
                continue
            for (param_name, weight_name, shard_id) in stacked_params_mapping:
                # Skip non-stacked layers and experts (experts handled below).
                if weight_name not in name:
                    continue
                # We have mlp.experts[0].gate_proj in the checkpoint.
                # Since we handle the experts below in expert_params_mapping,
                # we need to skip here BEFORE we update the name, otherwise
                # name will be updated to mlp.experts[0].gate_up_proj, which
                # will then be updated below in expert_params_mapping
                # for mlp.experts[0].gate_gate_up_proj, which breaks load.
                if (("mlp.experts." in name) and name not in params_dict):
                    continue
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                if is_pp_missing_parameter(name, self):
                    continue
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                is_expert_weight = False
                for mapping in expert_params_mapping:
                    param_name, weight_name, expert_id, shard_id = mapping
                    if weight_name not in name:
                        continue
                    # Anyway, this is an expert weight and should not be
                    # attempted to load as other weights later
                    is_expert_weight = True
                    # Do not modify `name` since the loop may continue here
                    # Instead, create a new variable
                    name_mapped = name.replace(weight_name, param_name)
                    if is_pp_missing_parameter(name_mapped, self):
                        continue
                    param = params_dict[name_mapped]
                    # We should ask the weight loader to return success or not
                    # here since otherwise we may skip experts with other
                    # available replicas.
                    weight_loader = typing.cast(Callable[..., bool],
                                                param.weight_loader)
                    success = weight_loader(param,
                                            loaded_weight,
                                            name_mapped,
                                            shard_id=shard_id,
                                            expert_id=expert_id,
                                            return_success=True)
                    if success:
                        name = name_mapped
                        break
                else:
                    if is_expert_weight:
                        # We've checked that this is an expert weight
                        # However it's not mapped locally to this rank
                        # So we simply skip it
                        continue
                    # Skip loading extra bias for GPTQ models.
                    if name.endswith(".bias") and name not in params_dict:
                        continue
                    # Remapping the name of FP8 kv-scale.
                    name = maybe_remap_kv_scale_name(name, params_dict)
                    if name is None:
                        continue
                    if is_pp_missing_parameter(name, self):
                        continue
                    param = params_dict[name]
                    weight_loader = getattr(param, "weight_loader",
                                            default_weight_loader)
                    weight_loader(param, loaded_weight)
            loaded_params.add(name)
        return loaded_params
class Glm4MoeForCausalLM(nn.Module, SupportsPP, SupportsLoRA):
    """GLM-4.5 MoE causal LM head on top of Glm4MoeModel.

    Also records per-layer MoE hyperparameters so EPLB can register and
    rebalance expert weights at runtime.
    """
    packed_modules_mapping = {
        "qkv_proj": [
            "q_proj",
            "k_proj",
            "v_proj",
        ],
        "gate_up_proj": [
            "gate_proj",
            "up_proj",
        ],
    }
    fall_back_to_pt_during_load = False
    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        self.config = config
        self.quant_config = quant_config
        self.model = Glm4MoeModel(vllm_config=vllm_config,
                                  prefix=maybe_prefix(prefix, "model"))
        # LM head exists only on the last PP rank.
        if get_pp_group().is_last_rank:
            self.lm_head = ParallelLMHead(config.vocab_size,
                                          config.hidden_size,
                                          quant_config=quant_config)
        else:
            self.lm_head = PPMissingLayer()
        self.logits_processor = LogitsProcessor(config.vocab_size)
        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors)
        self.expert_weights = []
        # Set MoE hyperparameters
        self.num_moe_layers = (config.num_hidden_layers -
                               config.first_k_dense_replace)
        self.num_expert_groups = config.n_group
        self.moe_layers: list[FusedMoE] = []
        example_moe = None
        for layer in self.model.layers:
            if isinstance(layer, PPMissingLayer):
                continue
            assert isinstance(layer, Glm4MoeDecoderLayer)
            if isinstance(layer.mlp, Glm4MoE):
                # Pick last one layer since the first ones may be dense layers.
                example_moe = layer.mlp
                self.moe_layers.append(layer.mlp.experts)
        if example_moe is None:
            raise RuntimeError("No Glm4MoE layer found in model.layers.")
        # Copy expert counts from a representative MoE layer; all MoE layers
        # share the same configuration.
        self.num_logical_experts = example_moe.n_logical_experts
        self.num_physical_experts = example_moe.n_physical_experts
        self.num_local_physical_experts = example_moe.n_local_physical_experts
        self.num_routed_experts = example_moe.n_routed_experts
        self.num_shared_experts = example_moe.n_shared_experts
        self.num_redundant_experts = example_moe.n_redundant_experts
    def set_eplb_state(
        self,
        expert_load_view: torch.Tensor,
        logical_to_physical_map: torch.Tensor,
        logical_replica_count: torch.Tensor,
    ) -> None:
        """Wire shared EPLB state tensors into every FusedMoE layer."""
        for layer_idx, layer in enumerate(self.moe_layers):
            # Register the expert weights.
            self.expert_weights.append(layer.get_expert_weights())
            layer.set_eplb_state(
                moe_layer_idx=layer_idx,
                expert_load_view=expert_load_view,
                logical_to_physical_map=logical_to_physical_map,
                logical_replica_count=logical_replica_count,
            )
    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up input embeddings from the backbone's embedding table."""
        return self.model.get_input_embeddings(input_ids)
    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run the backbone; returns hidden states (or PP tensors mid-stage)."""
        hidden_states = self.model(input_ids, positions, intermediate_tensors,
                                   inputs_embeds)
        return hidden_states
    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        """Project hidden states to vocab logits (None on non-last PP ranks)."""
        logits = self.logits_processor(self.lm_head, hidden_states,
                                       sampling_metadata)
        return logits
    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        """Delegate weight loading to the backbone via AutoWeightsLoader."""
        loader = AutoWeightsLoader(self)
        return loader.load_weights(weights)
    def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
        """Expose the backbone's expert parameter mapping."""
        return self.model.get_expert_mapping()
def get_spec_layer_idx_from_weight_name(config: "Glm4MoeConfig",
                                        weight_name: str) -> Optional[int]:
    """Map a checkpoint weight name to its speculative (MTP) layer index.

    GLM-4.5 checkpoints may append ``num_nextn_predict_layers`` extra
    next-token-prediction layers after the ``num_hidden_layers`` decoder
    layers.  If *weight_name* belongs to one of those extra layers, return
    that layer's absolute index; otherwise return ``None``.
    """
    num_spec_layers = getattr(config, "num_nextn_predict_layers", 0)
    if num_spec_layers <= 0:
        return None
    base = config.num_hidden_layers
    # Spec layers occupy absolute indices [base, base + num_spec_layers).
    for offset in range(num_spec_layers):
        if f"layers.{base + offset}." in weight_name:
            return base + offset
    return None

View File

@@ -1,21 +1,5 @@
#
# Copyright (c) 2025 Baidu, Inc. All Rights Reserved.
# Adapted from vllm/model_executor/models/gpt_oss.py
# Copyright 2023 The vLLM team.
#
# This file is a part of the vllm-kunlun project.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Iterable
from typing import Optional

View File

@@ -1,21 +1,11 @@
#
# Copyright (c) 2025 Baidu, Inc. All Rights Reserved.
# Adapted from vllm/model_executor/models/interns1.py
# Copyright 2023 The vLLM team.
#
# This file is a part of the vllm-kunlun project.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# --------------------------------------------------------
# InternS1
# Copyright (c) 2025 Shanghai AI Lab
# Licensed under The MIT License [see LICENSE for details]
# --------------------------------------------------------
from collections.abc import Iterable, Mapping, Sequence
from typing import Literal, Optional, TypedDict, Union
@@ -258,33 +248,39 @@ class InternS1DummyInputsBuilder(BaseDummyInputsBuilder[InternS1ProcessingInfo]
return image_token * num_images + video_token * num_videos
# def get_dummy_mm_data(
# self,
# seq_len: int,
# mm_counts: Mapping[str, int],
# ) -> MultiModalDataDict:
# target_width, target_height = \
# self.info.get_image_size_with_most_features()
# target_num_frames = \
# self.info.get_num_frames_with_most_features(seq_len, mm_counts)
# num_images = mm_counts.get("image", 0)
# num_videos = mm_counts.get("video", 0)
# config = self.info.get_hf_config()
# image_size_h, image_size_w = config.vision_config.image_size
# return {
# "image":
# self._get_dummy_images(width=target_width,
# height=target_height,
# num_images=num_images),
# "video":
# self._get_dummy_videos(width=image_size_w,
# height=image_size_h,
# num_frames=target_num_frames,
# num_videos=num_videos),
# }
def get_dummy_mm_data(
self,
seq_len: int,
mm_counts: Mapping[str, int],
) -> MultiModalDataDict:
"""Generates dummy multimodal data on Kunlun3 platform for performance analysis and warmup.
Retrieves visual resolution based on configuration (defaulting to 224x224)
and generates resized dummy data for images and videos.
Args:
seq_len: Sequence length (unused).
mm_counts: A mapping of multimodal type counts, containing "image"
and "video" keys.
Returns:
MultiModalDataDict: A dictionary containing the generated dummy image
and video data, structured as:
{
"image": dummy_image_data,
"video": dummy_video_data
}
Author:
Dong Xinyu
"""
# 读取配置里的视觉分辨率;若缺省则兜底 224×224
config = self.info.get_hf_config()
img_size = getattr(config.vision_config, "image_size", None)
if isinstance(img_size, (tuple, list)) and len(img_size) == 2:
@@ -292,13 +288,15 @@ class InternS1DummyInputsBuilder(BaseDummyInputsBuilder[InternS1ProcessingInfo]
else:
cfg_h, cfg_w = 224, 224
# 统一缩减:不再使用 “with_most_features”而是选择较小的安全尺寸
target_width = min(cfg_w, 224)
target_height = min(cfg_h, 224)
target_num_frames = 1
target_num_frames = 1 # profile/warmup 只造 1 帧即可
num_images = mm_counts.get("image", 0)
num_videos = mm_counts.get("video", 0)
# 统一让视频也按缩减后的分辨率生成
return {
"image": self._get_dummy_images(
width=target_width,

View File

@@ -1,21 +1,12 @@
#
# Copyright (c) 2025 Baidu, Inc. All Rights Reserved.
# Adapted from vllm/model_executor/models/interns1_vit.py
# Copyright 2023 The vLLM team.
#
# This file is a part of the vllm-kunlun project.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# adapted from https://huggingface.co/OpenGVLab/InternVL2-4B/blob/main/modeling_intern_vit.py
# --------------------------------------------------------
# InternVL
# Copyright (c) 2023 OpenGVLab
# Licensed under The MIT License [see LICENSE for details]
# --------------------------------------------------------
from collections.abc import Iterable
from typing import Optional
@@ -26,6 +17,7 @@ from transformers import PretrainedConfig
from transformers.utils import torch_int
from vllm.model_executor.layers.activation import get_act_fn
# from vllm_kunlun.ops.activation import GeluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
RowParallelLinear)
@@ -253,6 +245,7 @@ class InternS1VisionMLP(nn.Module):
self.config = config
self.activation_fn = get_act_fn(config.hidden_act)
# self.activation_fn = GeluAndMul()
self.fc1 = ColumnParallelLinear(config.hidden_size,
config.intermediate_size,
bias=True,

View File

@@ -1,21 +1,12 @@
#
# Copyright (c) 2025 Baidu, Inc. All Rights Reserved.
# Adapted from vllm/model_executor/models/internvl.py
# Copyright 2023 The vLLM team.
#
# This file is a part of the vllm-kunlun project.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# adapted from https://huggingface.co/OpenGVLab/InternVL2-4B/blob/main/modeling_internvl_chat.py
# --------------------------------------------------------
# InternVL
# Copyright (c) 2023 OpenGVLab
# Licensed under The MIT License [see LICENSE for details]
# --------------------------------------------------------
from abc import ABC, abstractmethod
from collections.abc import Iterable, Mapping, Sequence
from typing import Annotated, Any, Literal, Optional, TypeVar, Union

View File

@@ -1,9 +1,15 @@
#
# Copyright (c) 2025 Baidu, Inc. All Rights Reserved.
# Adapted from vllm/model_executor/models/llama.py
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
# Copyright 2023 The vLLM team.
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This file is a part of the vllm-kunlun project.
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -38,8 +44,7 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import (
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead)
from vllm_kunlun.ops.vocab_parallel_embedding import VocabParallelEmbedding
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import (
default_weight_loader, maybe_remap_kv_scale_name)
from vllm.model_executor.sampling_metadata import SamplingMetadata

View File

@@ -1,9 +1,16 @@
#
# Copyright (c) 2025 Baidu, Inc. All Rights Reserved.
# Adapted from vllm/model_executor/models/qwen2.py
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/qwen2/modeling_qwen2.py
# Copyright 2024 The Qwen team.
# Copyright 2023 The vLLM team.
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This file is a part of the vllm-kunlun project.
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -33,7 +40,7 @@ from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm_kunlun.ops.activation import SiluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
from vllm_kunlun.ops.linear import (MergedColumnParallelLinear,
QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
@@ -44,7 +51,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
from vllm_kunlun.ops.vocab_parallel_embedding import VocabParallelEmbedding
from vllm.model_executor.model_loader.weight_utils import (
default_weight_loader, maybe_remap_kv_scale_name)
from vllm.model_executor.sampling_metadata import SamplingMetadata
# from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors
from vllm.model_executor.models.adapters import as_seq_cls_model
@@ -177,7 +184,12 @@ class Qwen2Attention(nn.Module):
) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
q, k = self.rotary_emb(positions, q, k)
# INTERNVL_3暂时使用环境变量来控制是否使用原生rotary_embedding
# 若要修改,可尝试参考 qwen3.py
if os.getenv('INTERNVL_3') == "1":
q, k = self.rotary_emb.forward_native(positions, q, k)
else:
q, k = self.rotary_emb(positions, q, k)
attn_output = self.attn(q, k, v)
output, _ = self.o_proj(attn_output)
return output
@@ -295,6 +307,7 @@ class Qwen2Model(nn.Module):
))
self.config = config
config = config.get_text_config()
self.quant_config = quant_config
self.vocab_size = config.vocab_size
@@ -479,10 +492,10 @@ class Qwen2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
def compute_logits(
self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
# sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]:
logits = self.logits_processor(self.lm_head, hidden_states,
sampling_metadata)
logits = self.logits_processor(self.lm_head, hidden_states,)
# sampling_metadata)
return logits
def load_weights(self, weights: Iterable[tuple[str,

File diff suppressed because it is too large Load Diff

View File

@@ -1,9 +1,16 @@
#
# Copyright (c) 2025 Baidu, Inc. All Rights Reserved.
# Adapted from vllm/model_executor/models/qwen2vl.py
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Adapted from
# https://github.com/huggingface/transformers/blob/19e6e80e10118f855137b90740936c0b11ac397f/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py
# Copyright 2024 The Qwen team.
# Copyright 2023 The vLLM team.
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This file is a part of the vllm-kunlun project.
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -38,7 +45,7 @@ from vllm.config import VllmConfig
from vllm.distributed import parallel_state, tensor_model_parallel_all_gather
from vllm.distributed import utils as dist_utils
from vllm.logger import init_logger
from vllm.model_executor import SamplingMetadata
# from vllm.model_executor import SamplingMetadata
from vllm.model_executor.layers.activation import QuickGELU
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
RowParallelLinear)
@@ -70,11 +77,12 @@ from vllm.model_executor.models.utils import (AutoWeightsLoader, WeightsMapper,
init_vllm_registered_model, maybe_prefix,
merge_multimodal_embeddings)
from vllm.model_executor.models.vision import get_vit_attn_backend
import xspeedgate_ops
logger = init_logger(__name__)
# For profile run
_MAX_FRAMES_PER_VIDEO = 16
_MAX_FRAMES_PER_VIDEO = 14
# === Vision Inputs === #
@@ -226,13 +234,10 @@ def apply_rotary_emb_torch(x: torch.Tensor,
def apply_rotary_pos_emb_vision(t: torch.Tensor,
freqs: torch.Tensor) -> torch.Tensor:
t_ = t.float()
if freqs.dim() == 3 and freqs.shape[1] == 2:
# freqs: (seq_len, 2, head_dim)
# Call custom XPU Kernel version
import xspeedgate_ops
return torch.ops.xspeedgate_ops.rope_vit(t_, freqs, interleaved = False).type_as(t)
cos = freqs.cos()
sin = freqs.sin()
apply_rotary_emb = apply_rotary_emb_torch
@@ -922,10 +927,10 @@ class Qwen2VLProcessingInfo(BaseProcessingInfo):
image_processor=None,
)
def _get_max_video_frames(self, max_tokens: int) -> int:
def _get_max_video_frames(self, max_tokens: int, start_num_frames: int = 1) -> int:
target_width, target_height = self.get_image_size_with_most_features()
num_frames = 0
num_frames = start_num_frames
while True:
next_num_frames = num_frames + 1
@@ -947,15 +952,23 @@ class Qwen2VLProcessingInfo(BaseProcessingInfo):
self,
seq_len: int,
mm_counts: Mapping[str, int],
max_frames_per_video: int = _MAX_FRAMES_PER_VIDEO,
) -> int:
max_images = mm_counts.get("image", 0)
# max_images = mm_counts.get("image", 0)
# max_videos = mm_counts.get("video", 0)
# max_image_tokens = self.get_max_image_tokens() * max_images
# max_total_frames = self._get_max_video_frames(seq_len -
# max_image_tokens)
# max_frames_per_video = min(max_total_frames // max(max_videos, 1),
# _MAX_FRAMES_PER_VIDEO)
# return max(max_frames_per_video, 1)
max_videos = mm_counts.get("video", 0)
max_image_tokens = self.get_max_image_tokens() * max_images
max_total_frames = self._get_max_video_frames(seq_len -
max_image_tokens)
max_total_frames = self._get_max_video_frames(seq_len)
max_frames_per_video = min(max_total_frames // max(max_videos, 1),
_MAX_FRAMES_PER_VIDEO)
max_frames_per_video)
return max(max_frames_per_video, 1)
@@ -1404,10 +1417,10 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal,
def compute_logits(
self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
# sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]:
return self.language_model.compute_logits(hidden_states,
sampling_metadata)
return self.language_model.compute_logits(hidden_states)
# sampling_metadata)
def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
@@ -1507,4 +1520,4 @@ class Tarsier2ForConditionalGeneration(Qwen2VLForConditionalGeneration):
torch.Tensor]]) -> set[str]:
loader = AutoWeightsLoader(self)
return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)

View File

@@ -1,9 +1,14 @@
#
# Copyright (c) 2025 Baidu, Inc. All Rights Reserved.
# Adapted from vllm/model_executor/models/qwen3.py
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Copyright 2024 The Qwen team.
# Copyright 2023 The vLLM team.
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This file is a part of the vllm-kunlun project.
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -18,65 +23,57 @@
# limitations under the License.
"""Inference-only Qwen3 model compatible with HuggingFace weights."""
from collections.abc import Iterable
from typing import Optional, Union
import xtorch_ops
from typing import Any, Optional, Union
import torch
import os
from torch import nn
from transformers import Qwen3Config
from vllm.attention import AttentionType, AttentionMetadata
from vllm.attention import AttentionType
from vllm_kunlun.ops.attention.layer import Attention
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig, get_current_vllm_config
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.logger import init_logger
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (QKVParallelLinear,
from vllm_kunlun.ops.linear import (QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm_kunlun.ops.vocab_parallel_embedding import VocabParallelEmbedding
from vllm.model_executor.model_loader.weight_utils import (
default_weight_loader, maybe_remap_kv_scale_name)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors
from vllm import envs
from vllm.model_executor.models.adapters import as_seq_cls_model
from vllm.model_executor.models.interfaces import SupportsLoRA, SupportsPP
from vllm.model_executor.models.interfaces import SupportsEagle3, SupportsLoRA, SupportsPP
from .qwen2 import Qwen2MLP as Qwen3MLP
from .qwen2 import Qwen2Model
from vllm.model_executor.models.utils import (AutoWeightsLoader, PPMissingLayer, extract_layer_index,
is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers,
maybe_prefix)
from vllm.forward_context import ForwardContext, get_forward_context
from vllm.platforms import current_platform
from vllm_kunlun.ops.rotary_embedding import Split_Norm_Rope
logger = init_logger(__name__)
class Qwen3Attention(nn.Module):
def __init__(self,
hidden_size: int,
num_heads: int,
num_kv_heads: int,
max_position: int = 4096 * 32,
head_dim: Optional[int] = None,
rms_norm_eps: float = 1e-06,
qkv_bias: bool = False,
rope_theta: float = 10000,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
rope_scaling: Optional[tuple] = None,
prefix: str = "",
attn_type: str = AttentionType.DECODER) -> None:
def __init__(
self,
hidden_size: int,
num_heads: int,
num_kv_heads: int,
max_position: int = 4096 * 32,
head_dim: Optional[int] = None,
rms_norm_eps: float = 1e-06,
qkv_bias: bool = False,
rope_theta: float = 10000,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
rope_scaling: Optional[tuple] = None,
prefix: str = "",
attn_type: str = AttentionType.DECODER,
dual_chunk_attention_config: Optional[dict[str, Any]] = None,
) -> None:
super().__init__()
self.hidden_size = hidden_size
tp_size = get_tensor_model_parallel_world_size()
@@ -98,10 +95,7 @@ class Qwen3Attention(nn.Module):
self.kv_size = self.num_kv_heads * self.head_dim
self.scaling = self.head_dim**-0.5
self.rope_theta = rope_theta
self.max_position = max_position
if rope_scaling is not None:
scaling_factor = rope_scaling["factor"]
self.max_position = int(self.max_position * scaling_factor)
self.dual_chunk_attention_config = dual_chunk_attention_config
self.qkv_proj = QKVParallelLinear(
hidden_size,
@@ -123,18 +117,25 @@ class Qwen3Attention(nn.Module):
self.rotary_emb = get_rope(
self.head_dim,
rotary_dim=self.head_dim,
max_position=self.max_position,
max_position=max_position,
base=self.rope_theta,
rope_scaling=rope_scaling,
dual_chunk_attention_config=dual_chunk_attention_config,
)
self.attn = Attention(
self.num_heads,
self.head_dim,
self.scaling,
num_kv_heads=self.num_kv_heads,
cache_config=cache_config,
quant_config=quant_config,
prefix=f"{prefix}.attn",
attn_type=attn_type,
**{
"layer_idx": extract_layer_index(prefix),
"dual_chunk_attention_config": dual_chunk_attention_config,
} if dual_chunk_attention_config else {},
)
self.attn = Attention(self.num_heads,
self.head_dim,
self.scaling,
num_kv_heads=self.num_kv_heads,
cache_config=cache_config,
quant_config=quant_config,
prefix=f"{prefix}.attn",
attn_type=attn_type)
self.q_norm = RMSNorm(self.head_dim, eps=rms_norm_eps)
self.k_norm = RMSNorm(self.head_dim, eps=rms_norm_eps)
@@ -142,35 +143,19 @@ class Qwen3Attention(nn.Module):
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
attn_metadata: AttentionMetadata,
residual: Optional[torch.Tensor],
) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states)
# TODO: Supports both original Rope and Kunlun Rope fusion operators
if os.getenv('FUSED_QK_ROPE_OP') == "1":
# Rope fusion operators
q, k, v = Split_Norm_Rope(qkv,
self.rotary_emb.cos_sin_cache,
self.q_norm.weight,
self.k_norm.weight,
positions,
self.max_position,
self.num_heads,
self.num_kv_heads,
self.head_dim,
)
else:
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
# Add qk-norm
q_by_head = q.view(*q.shape[:-1], q.shape[-1] // self.head_dim,
self.head_dim)
q_by_head = self.q_norm(q_by_head)
q = q_by_head.view(q.shape)
k_by_head = k.view(*k.shape[:-1], k.shape[-1] // self.head_dim,
self.head_dim)
k_by_head = self.k_norm(k_by_head)
k = k_by_head.view(k.shape)
q, k = self.rotary_emb(positions, q, k)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
# Add qk-norm
q_by_head = q.view(*q.shape[:-1], q.shape[-1] // self.head_dim,
self.head_dim)
q_by_head = self.q_norm(q_by_head)
q = q_by_head.view(q.shape)
k_by_head = k.view(*k.shape[:-1], k.shape[-1] // self.head_dim,
self.head_dim)
k_by_head = self.k_norm(k_by_head)
k = k_by_head.view(k.shape)
q, k = self.rotary_emb(positions, q, k)
attn_output = self.attn(q, k, v)
output, _ = self.o_proj(attn_output)
return output
@@ -190,6 +175,9 @@ class Qwen3DecoderLayer(nn.Module):
# Requires transformers > 4.32.0
rope_theta = getattr(config, "rope_theta", 1000000)
rope_scaling = getattr(config, "rope_scaling", None)
dual_chunk_attention_config = getattr(config,
"dual_chunk_attention_config",
None)
# By default, Qwen3 uses causal attention as it is a decoder-only model.
# You can override the HF config with `is_causal=False` to enable
@@ -214,6 +202,7 @@ class Qwen3DecoderLayer(nn.Module):
rope_scaling=rope_scaling,
prefix=f"{prefix}.self_attn",
attn_type=attn_type,
dual_chunk_attention_config=dual_chunk_attention_config,
)
self.mlp = Qwen3MLP(
hidden_size=self.hidden_size,
@@ -231,7 +220,6 @@ class Qwen3DecoderLayer(nn.Module):
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
attn_metadata: AttentionMetadata,
residual: Optional[torch.Tensor],
) -> tuple[torch.Tensor, torch.Tensor]:
# Self Attention
@@ -244,8 +232,6 @@ class Qwen3DecoderLayer(nn.Module):
hidden_states = self.self_attn(
positions=positions,
hidden_states=hidden_states,
attn_metadata=attn_metadata,
residual=residual,
)
# Fully Connected
@@ -259,6 +245,7 @@ ALL_DECODER_LAYER_TYPES = {
"attention": Qwen3DecoderLayer,
}
@support_torch_compile(
dynamic_arg_dims={
"input_ids": 0,
@@ -268,189 +255,15 @@ ALL_DECODER_LAYER_TYPES = {
"intermediate_tensors": 0,
"inputs_embeds": 0,
})
class Qwen3Model(nn.Module):
"""Qwen3Model"""
def __init__(self,
*,
vllm_config: VllmConfig,
prefix: str = "",
decoder_layer_type: type[nn.Module] = Qwen3DecoderLayer):
super().__init__()
class Qwen3Model(Qwen2Model):
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__(vllm_config=vllm_config,
prefix=prefix,
decoder_layer_type=Qwen3DecoderLayer)
# TODO (@robertgshaw2): see if this can be moved out
if (cache_config.sliding_window is not None
and hasattr(config, "max_window_layers")):
assert config.max_window_layers == config.num_hidden_layers, (
"Sliding window for some but all layers is not supported. "
"This model uses sliding window but `max_window_layers` = {} "
"is less than `num_hidden_layers` = {}. Please open an issue "
"to discuss this feature.".format(
config.max_window_layers,
config.num_hidden_layers,
))
self.config = config
self.quant_config = quant_config
self.vocab_size = config.vocab_size
if get_pp_group().is_first_rank or (config.tie_word_embeddings
and get_pp_group().is_last_rank):
self.embed_tokens = VocabParallelEmbedding(
config.vocab_size,
config.hidden_size,
quant_config=quant_config,
prefix=f"{prefix}.embed_tokens",
)
else:
self.embed_tokens = PPMissingLayer()
# Use the provided decoder layer type or default to Qwen2DecoderLayer
decoder_layer_type = decoder_layer_type or Qwen3DecoderLayer
self.start_layer, self.end_layer, self.layers = make_layers(
config.num_hidden_layers,
lambda prefix: decoder_layer_type(config=config,
cache_config=cache_config,
quant_config=quant_config,
prefix=prefix),
prefix=f"{prefix}.layers",
)
self.make_empty_intermediate_tensors = (
make_empty_intermediate_tensors_factory(
["hidden_states", "residual"], config.hidden_size))
if get_pp_group().is_last_rank:
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
else:
self.norm = PPMissingLayer()
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
"""get_input_embeddings"""
return self.embed_tokens(input_ids)
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]:
"""
Args:
input_ids (torch.Tensor): Input sequence of shape `(batch, seq_len)`.
Indices are expected to be in the range `[0, config.vocab_size]`.
positions (torch.Tensor): Positional tensor of shape `(batch, seq_len)`.
intermediate_tensors (Optional[IntermediateTensors], optional):
Intermediate tensors from previous forward pass. Defaults to `None`.
inputs_embeds (Optional[torch.Tensor], optional):
Optionally, instead of positional embeddings, you can choose to
provide your own embedding lookup matrix of shape `(batch, seq_len, emb_dim)`.
If None, the model will create one on its own using the input ids.
Defaults to `None`.
Returns:
Union[torch.Tensor, IntermediateTensors]:
If `intermediate_tensors` is not None, returns a IntermediateTensors object.
Otherwise, returns a tensor of shape `(batch, seq_len, hidden_size)` representing
the output of the last transformer encoder layer.
"""
forward_context: ForwardContext = get_forward_context()
attn_metadata = forward_context.attn_metadata
if get_pp_group().is_first_rank:
if inputs_embeds is not None:
hidden_states = inputs_embeds
else:
hidden_states = self.get_input_embeddings(input_ids)
residual = None
else:
assert intermediate_tensors is not None
hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"]
for i, layer in enumerate(self.layers[self.start_layer:self.end_layer], start=self.start_layer):
hidden_states, residual = layer(
positions,
hidden_states,
attn_metadata,
residual,
)
if not get_pp_group().is_last_rank:
return IntermediateTensors({
"hidden_states": hidden_states,
"residual": residual
})
hidden_states, _ = self.norm(hidden_states, residual)
return hidden_states
def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
"""Load model weights.
Args:
weights (Iterable[tuple[str, torch.Tensor]]): An iterator containing weight names and their corresponding values.
Returns (set[str]):
A set of already loaded weight names.
Exceptions:
None.
"""
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"),
("qkv_proj", "k_proj", "k"),
("qkv_proj", "v_proj", "v"),
("gate_up_proj", "gate_proj", 0),
("gate_up_proj", "up_proj", 1),
]
params_dict = dict(self.named_parameters(remove_duplicate=False))
loaded_params: set[str] = set()
for name, loaded_weight in weights:
if "rotary_emb.inv_freq" in name:
continue
if (self.quant_config is not None and
(scale_name := self.quant_config.get_cache_scale(name))):
# Loading kv cache quantization scales
param = params_dict[scale_name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
loaded_weight = (loaded_weight if loaded_weight.dim() == 0 else
loaded_weight[0])
weight_loader(param, loaded_weight)
loaded_params.add(scale_name)
continue
for (param_name, weight_name, shard_id) in stacked_params_mapping:
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
break
else:
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
# Remapping the name of FP8 kv-scale.
name = maybe_remap_kv_scale_name(name, params_dict)
if name is None:
continue
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
weight_loader = getattr(param, "weight_loader", default_weight_loader)
weight_loader(param, loaded_weight)
loaded_params.add(name)
return loaded_params
class Qwen3ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
class Qwen3ForCausalLM(nn.Module, SupportsLoRA, SupportsPP, SupportsEagle3):
packed_modules_mapping = {
"qkv_proj": [
"q_proj",
@@ -493,6 +306,13 @@ class Qwen3ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors)
def set_aux_hidden_state_layers(self, layers: tuple[int, ...]) -> None:
self.model.aux_hidden_state_layers = layers
def get_eagle3_aux_hidden_state_layers(self) -> tuple[int, ...]:
num_layers = len(self.model.layers)
return (2, num_layers // 2, num_layers - 3)
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.get_input_embeddings(input_ids)
@@ -502,7 +322,6 @@ class Qwen3ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
positions: torch.Tensor,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
kv_caches: list[torch.Tensor] = None
) -> Union[torch.Tensor, IntermediateTensors]:
hidden_states = self.model(input_ids, positions, intermediate_tensors,
inputs_embeds)
@@ -511,10 +330,8 @@ class Qwen3ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
def compute_logits(
self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]:
logits = self.logits_processor(self.lm_head, hidden_states,
sampling_metadata)
logits = self.logits_processor(self.lm_head, hidden_states)
return logits
def load_weights(self, weights: Iterable[tuple[str,
@@ -525,6 +342,3 @@ class Qwen3ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
if self.config.tie_word_embeddings else None),
)
return loader.load_weights(weights)
Qwen3ForSequenceClassification = as_seq_cls_model(Qwen3ForCausalLM)

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,358 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Copyright 2025 The vLLM team.
# Copyright 2025 The Qwen Team.
# Copyright 2025 The HuggingFace Inc. team.
# All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only Qwen3-VL-MoE model compatible with HuggingFace weights."""
import typing
from collections.abc import Iterable
from typing import Callable, Optional, Union
import torch
from transformers.models.qwen3_vl_moe.configuration_qwen3_vl_moe import (
Qwen3VLMoeConfig)
from vllm.compilation.decorators import support_torch_compile
from vllm.config import VllmConfig
from vllm.distributed import get_pp_group
from vllm.logger import init_logger
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.model_loader.weight_utils import (
default_weight_loader, maybe_remap_kv_scale_name)
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.sequence import IntermediateTensors
from .qwen3_moe import Qwen3MoeForCausalLM, Qwen3MoeModel
from .qwen3_vl import (Qwen3_VisionTransformer, Qwen3VLDummyInputsBuilder,
Qwen3VLForConditionalGeneration,
Qwen3VLMultiModalProcessor, Qwen3VLProcessingInfo)
from vllm.model_executor.models.utils import is_pp_missing_parameter, maybe_prefix
logger = init_logger(__name__)
class Qwen3VLMoeProcessingInfo(Qwen3VLProcessingInfo):
    """Processing info for Qwen3-VL-MoE.

    Identical to the dense Qwen3-VL variant except that the HF config is
    parsed as :class:`Qwen3VLMoeConfig`.
    """

    def get_hf_config(self):
        """Return the HuggingFace config typed as ``Qwen3VLMoeConfig``."""
        return self.ctx.get_hf_config(Qwen3VLMoeConfig)
@support_torch_compile(
    dynamic_arg_dims={
        "input_ids": 0,
        # positions is of shape (3, seq_len) if mrope is enabled for qwen2-vl,
        # otherwise (seq_len, ).
        "positions": -1,
        "intermediate_tensors": 0,
        "inputs_embeds": 0,
        # the same shape as input_embeds
        "deepstack_input_embeds": 0
    })
class Qwen3MoeLLMModel(Qwen3MoeModel):
    """Qwen3-MoE text backbone extended with "deepstack" visual embeddings.

    Selected vision-encoder outputs are added onto the hidden states of the
    first ``len(deepstack_input_embeds)`` decoder layers (see ``forward``).
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__(vllm_config=vllm_config, prefix=prefix)
        # Deepstack embeddings are only injected into the first
        # len(deepstack_visual_indexes) layers, so all of those layers must
        # live on the first pipeline-parallel rank.
        if not get_pp_group().is_first_rank:
            assert self.start_layer >= len(
                vllm_config.model_config.hf_config.vision_config.
                deepstack_visual_indexes), (
                    "start_layer should be greater than or equal to "
                    "len(deepstack_visual_indexes)")

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        deepstack_input_embeds: Optional[IntermediateTensors] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run the decoder layers owned by this pipeline-parallel rank.

        Returns the final hidden states on the last PP rank; otherwise an
        ``IntermediateTensors`` carrying ``hidden_states``/``residual`` for
        the next rank.
        """
        if get_pp_group().is_first_rank:
            if inputs_embeds is not None:
                hidden_states = inputs_embeds
            else:
                hidden_states = self.get_input_embeddings(input_ids)
            residual = None
        else:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]
            residual = intermediate_tensors["residual"]
        for layer_idx, layer in enumerate(
                self.layers[self.start_layer:self.end_layer]):
            # Convert the enumeration index (local to this rank's slice)
            # into the global layer index.
            layer_idx = layer_idx + self.start_layer
            hidden_states, residual = layer(
                positions,
                hidden_states,
                residual,
            )
            # Add the deepstack visual embedding for this layer, if one
            # exists (only the first len(deepstack_input_embeds) layers).
            if deepstack_input_embeds is not None and \
                layer_idx in range(0, len(deepstack_input_embeds)):
                hidden_states = hidden_states + deepstack_input_embeds[
                    f"deepstack_input_embeds_{layer_idx}"]
        if not get_pp_group().is_last_rank:
            return IntermediateTensors({
                "hidden_states": hidden_states,
                "residual": residual
            })
        hidden_states, _ = self.norm(hidden_states, residual)
        return hidden_states

    def load_fused_expert_weights(self, name: str, params_dict: dict,
                                  loaded_weight: torch.Tensor, shard_id: str,
                                  num_experts: int) -> bool:
        """Load one fused (all-experts-in-one-tensor) checkpoint weight.

        Iterates over every expert slice; the parameter's ``weight_loader``
        reports (via ``return_success=True``) whether that expert is local
        to this rank. Returns True if at least one local expert loaded.
        """
        param = params_dict[name]
        weight_loader = typing.cast(Callable[..., bool], param.weight_loader)
        loaded_local_expert = False
        for expert_id in range(num_experts):
            curr_expert_weight = loaded_weight[expert_id]
            success = weight_loader(param,
                                    curr_expert_weight,
                                    name,
                                    shard_id,
                                    expert_id,
                                    return_success=True)
            if success:
                loaded_local_expert = True
        return loaded_local_expert

    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        """Load checkpoint weights.

        Handles stacked QKV / gate-up projections and both per-expert and
        fused MoE checkpoint layouts. Returns the set of parameter names
        actually loaded on this rank.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        ]
        # Skip loading extra parameters for GPTQ/modelopt models.
        ignore_suffixes = (".bias", "_bias", ".k_scale", "_k_scale",
                           ".v_scale", "_v_scale", ".weight_scale",
                           "_weight_scale", ".input_scale", "_input_scale")
        params_dict = dict(self.named_parameters())
        loaded_params: set[str] = set()
        expert_params_mapping = self.get_expert_mapping()
        is_fused_expert = False
        fused_expert_params_mapping = [
            ("experts.w13_weight", "experts.gate_up_proj", 0, "w1"),
            ("experts.w2_weight", "experts.down_proj", 0, "w2"),
        ]
        num_experts = self.config.get_text_config().num_experts
        for name, loaded_weight in weights:
            for (param_name, weight_name, shard_id) in stacked_params_mapping:
                # Detect the fused-expert checkpoint layout on the fly and
                # switch the expert mapping accordingly.
                if ("experts.gate_up_proj" in name
                        or "experts.down_proj" in name):
                    is_fused_expert = True
                    expert_params_mapping = fused_expert_params_mapping
                # Skip non-stacked layers and experts (experts handled below).
                if weight_name not in name:
                    continue
                # We have mlp.experts[0].gate_proj in the checkpoint.
                # Since we handle the experts below in expert_params_mapping,
                # we need to skip here BEFORE we update the name, otherwise
                # name will be updated to mlp.experts[0].gate_up_proj, which
                # will then be updated below in expert_params_mapping
                # for mlp.experts[0].gate_gate_up_proj, which breaks load.
                if "mlp.experts" in name:
                    continue
                name = name.replace(weight_name, param_name)
                # Skip loading extra parameters for GPTQ/modelopt models.
                if name.endswith(ignore_suffixes) and name not in params_dict:
                    continue
                # Skip layers on other devices.
                if is_pp_missing_parameter(name, self):
                    continue
                if name.endswith("scale"):
                    # Remapping the name of FP8 kv-scale.
                    name = maybe_remap_kv_scale_name(name, params_dict)
                    if name is None:
                        continue
                if name not in params_dict:
                    continue
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                if weight_loader == default_weight_loader:
                    weight_loader(param, loaded_weight)
                else:
                    weight_loader(param, loaded_weight, shard_id)
                break
            else:
                is_expert_weight = False
                for mapping in expert_params_mapping:
                    param_name, weight_name, expert_id, shard_id = mapping
                    if weight_name not in name:
                        continue
                    # Anyway, this is an expert weight and should not be
                    # attempted to load as other weights later
                    is_expert_weight = True
                    name_mapped = name.replace(weight_name, param_name)
                    if is_pp_missing_parameter(name_mapped, self):
                        continue
                    if is_fused_expert:
                        loaded_weight = loaded_weight.transpose(-1,
                                                                -2)  # no bias
                        if "experts.gate_up_proj" in name:
                            # Fused gate_up holds w1 and w3 stacked on the
                            # second-to-last dim; split and load separately.
                            loaded_weight = loaded_weight.chunk(2, dim=-2)
                            success_w1 = self.load_fused_expert_weights(
                                name_mapped, params_dict, loaded_weight[0],
                                "w1", num_experts)
                            success_w3 = self.load_fused_expert_weights(
                                name_mapped, params_dict, loaded_weight[1],
                                "w3", num_experts)
                            success = success_w1 and success_w3
                        else:
                            # down_proj
                            success = self.load_fused_expert_weights(
                                name_mapped, params_dict, loaded_weight,
                                shard_id, num_experts)
                    else:
                        # Skip loading extra parameters for GPTQ/modelopt models
                        if name_mapped.endswith(
                                ignore_suffixes
                        ) and name_mapped not in params_dict:
                            continue
                        param = params_dict[name_mapped]
                        # We should ask the weight loader to return success or
                        # not here since otherwise we may skip experts with
                        # other available replicas.
                        weight_loader = typing.cast(Callable[..., bool],
                                                    param.weight_loader)
                        success = weight_loader(param,
                                                loaded_weight,
                                                name_mapped,
                                                shard_id=shard_id,
                                                expert_id=expert_id,
                                                return_success=True)
                    if success:
                        name = name_mapped
                        break
                else:
                    if is_expert_weight:
                        # We've checked that this is an expert weight
                        # However it's not mapped locally to this rank
                        # So we simply skip it
                        continue
                    # Skip loading extra parameters for GPTQ/modelopt models.
                    if name.endswith(
                            ignore_suffixes) and name not in params_dict:
                        continue
                    # Skip layers on other devices.
                    if is_pp_missing_parameter(name, self):
                        continue
                    # Remapping the name of FP8 kv-scale.
                    if name.endswith("kv_scale"):
                        remapped_kv_scale_name = name.replace(
                            ".kv_scale", ".attn.kv_scale")
                        if remapped_kv_scale_name not in params_dict:
                            logger.warning_once(
                                "Found kv scale in the checkpoint (e.g. %s), but not found the expected name in the model (e.g. %s). kv-scale is not loaded.",  # noqa: E501
                                name,
                                remapped_kv_scale_name,
                            )
                            continue
                        else:
                            name = remapped_kv_scale_name
                    param = params_dict[name]
                    weight_loader = getattr(param, "weight_loader",
                                            default_weight_loader)
                    weight_loader(param, loaded_weight)
            loaded_params.add(name)
        return loaded_params
class Qwen3MoeLLMForCausalLM(Qwen3MoeForCausalLM):
    """Causal-LM head over the deepstack-aware ``Qwen3MoeLLMModel`` backbone.

    Reuses all other methods of ``Qwen3MoeForCausalLM`` but replaces its
    backbone construction.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        # Deliberately skip Qwen3MoeForCausalLM.__init__ (which would build
        # the plain Qwen3MoeModel backbone) and call the next class in the
        # MRO instead; this __init__ then rebuilds the equivalent attributes
        # around Qwen3MoeLLMModel.
        super(Qwen3MoeForCausalLM, self).__init__()
        self.config = vllm_config.model_config.hf_config.text_config
        self.quant_config = vllm_config.quant_config
        self.model = Qwen3MoeLLMModel(vllm_config=vllm_config,
                                      prefix=maybe_prefix(prefix, "model"))
        self.lm_head = ParallelLMHead(self.config.vocab_size,
                                      self.config.hidden_size,
                                      quant_config=self.quant_config)
        if self.config.tie_word_embeddings:
            # Share the LM head weights with the input embedding table.
            self.lm_head.weight = self.model.embed_tokens.weight
        self.logits_processor = LogitsProcessor(self.config.vocab_size)
        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors)
@MULTIMODAL_REGISTRY.register_processor(Qwen3VLMultiModalProcessor,
                                        info=Qwen3VLMoeProcessingInfo,
                                        dummy_inputs=Qwen3VLDummyInputsBuilder)
class Qwen3VLMoeForConditionalGeneration(Qwen3VLForConditionalGeneration):
    """Qwen3-VL-MoE: optional vision transformer plus a MoE language model
    with deepstack visual-embedding injection."""

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        # Skip the dense-variant __init__ on purpose; this constructor
        # rebuilds all of its attributes with MoE-specific components.
        super(Qwen3VLForConditionalGeneration, self).__init__()
        hf_config: Qwen3VLMoeConfig = vllm_config.model_config.hf_config
        mm_config = vllm_config.model_config.multimodal_config
        self.config = hf_config
        self.multimodal_config = mm_config
        self.use_data_parallel = mm_config.mm_encoder_tp_mode == "data"

        # Only build the vision tower when the prompt limits allow any
        # image or video input at all.
        wants_images = bool(mm_config.get_limit_per_prompt("image"))
        wants_videos = bool(mm_config.get_limit_per_prompt("video"))
        if wants_images or wants_videos:
            self.visual = Qwen3_VisionTransformer(
                hf_config.vision_config,
                norm_eps=getattr(hf_config, "rms_norm_eps", 1e-6),
                quant_config=vllm_config.quant_config,
                prefix=maybe_prefix(prefix, "visual"),
                use_data_parallel=self.use_data_parallel,
            )
        else:
            self.visual = None

        self.language_model = Qwen3MoeLLMForCausalLM(
            vllm_config=vllm_config,
            prefix=maybe_prefix(prefix, "language_model"))
        self.make_empty_intermediate_tensors = (
            self.language_model.make_empty_intermediate_tensors)

        self.use_deepstack = hasattr(hf_config.vision_config,
                                     'deepstack_visual_indexes')
        if self.use_deepstack:
            self.deepstack_num_level = len(
                hf_config.vision_config.deepstack_visual_indexes)
        else:
            self.deepstack_num_level = 0
        # register buffer for deepstack
        if self.use_deepstack and self.visual is not None:
            max_tokens = vllm_config.scheduler_config.max_num_batched_tokens
            self.deepstack_input_embeds = [
                torch.zeros(max_tokens, hf_config.text_config.hidden_size)
                for _ in range(self.deepstack_num_level)
            ]
        else:
            self.deepstack_input_embeds = None
        self.visual_dim = hf_config.vision_config.out_hidden_size
        self.multiscale_dim = self.visual_dim * self.deepstack_num_level

View File

@@ -0,0 +1,500 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Copyright 2025 The Seed team.
# Copyright 2023 The vLLM team.
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only SeedOss model compatible with HuggingFace weights."""
from collections.abc import Iterable
from itertools import islice
import torch
from torch import nn
from transformers import PretrainedConfig as SeedOssConfig
from vllm.attention import AttentionType
from vllm_kunlun.ops.attention.layer import Attention
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.logger import init_logger
from vllm_kunlun.ops.activation import SiluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (
MergedColumnParallelLinear,
QKVParallelLinear,
RowParallelLinear,
)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm_kunlun.ops.vocab_parallel_embedding import (
ParallelLMHead,
VocabParallelEmbedding,
)
from vllm.model_executor.model_loader.weight_utils import (
default_weight_loader,
maybe_remap_kv_scale_name,
)
from vllm.sequence import IntermediateTensors
from vllm.model_executor.models.interfaces import SupportsLoRA, SupportsPP
from vllm.model_executor.models.utils import (
AutoWeightsLoader,
PPMissingLayer,
is_pp_missing_parameter,
make_empty_intermediate_tensors_factory,
make_layers,
maybe_prefix,
)
logger = init_logger(__name__)
class SeedOssMLP(nn.Module):
    """SeedOss feed-forward block.

    Applies a merged gate/up projection, SiLU-and-mul activation, then a
    down projection back to ``hidden_size``.

    Raises:
        ValueError: if ``hidden_act`` is not ``"silu"``.
    """

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        # Validate the activation BEFORE building the (potentially huge)
        # projection layers, so a misconfigured model fails fast instead of
        # allocating weights it will immediately discard.
        if hidden_act != "silu":
            raise ValueError(
                f"Unsupported activation: {hidden_act}. Only silu is supported for now."
            )
        self.gate_up_proj = MergedColumnParallelLinear(
            hidden_size,
            [intermediate_size] * 2,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.gate_up_proj",
        )
        self.down_proj = RowParallelLinear(
            intermediate_size,
            hidden_size,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.down_proj",
        )
        self.act_fn = SiluAndMul()

    def forward(self, x):
        """gate_up projection -> SiluAndMul -> down projection."""
        gate_up, _ = self.gate_up_proj(x)
        x = self.act_fn(gate_up)
        x, _ = self.down_proj(x)
        return x
class SeedOssAttention(nn.Module):
    """SeedOss multi-head attention with rotary embeddings and GQA.

    Query heads are partitioned evenly across tensor-parallel ranks; KV
    heads are partitioned when there are at least as many as TP ranks,
    otherwise replicated.
    """

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        num_kv_heads: int,
        head_dim: int,
        max_position: int = 4096 * 32,
        rope_theta: float = 10000,
        cache_config: CacheConfig | None = None,
        quant_config: QuantizationConfig | None = None,
        rope_scaling: tuple | None = None,
        prefix: str = "",
        attn_type: str = AttentionType.DECODER,
    ) -> None:
        super().__init__()
        self.hidden_size = hidden_size
        tp_size = get_tensor_model_parallel_world_size()
        self.total_num_heads = num_heads
        assert self.total_num_heads % tp_size == 0
        self.num_heads = self.total_num_heads // tp_size
        self.total_num_kv_heads = num_kv_heads
        self.head_dim = head_dim
        if self.total_num_kv_heads >= tp_size:
            # Number of KV heads is greater than TP size, so we partition
            # the KV heads across multiple tensor parallel GPUs.
            assert self.total_num_kv_heads % tp_size == 0
        else:
            # Number of KV heads is less than TP size, so we replicate
            # the KV heads across multiple tensor parallel GPUs.
            assert tp_size % self.total_num_kv_heads == 0
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
        # Per-rank flattened projection sizes for splitting the fused QKV.
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        # Standard 1/sqrt(head_dim) attention scaling.
        self.scaling = self.head_dim**-0.5
        self.rope_theta = rope_theta
        self.qkv_proj = QKVParallelLinear(
            hidden_size,
            self.head_dim,
            self.total_num_heads,
            self.total_num_kv_heads,
            bias=True,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv_proj",
        )
        self.o_proj = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            hidden_size,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.o_proj",
        )
        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=max_position,
            base=self.rope_theta,
            rope_scaling=rope_scaling,
        )
        self.attn = Attention(
            self.num_heads,
            self.head_dim,
            self.scaling,
            num_kv_heads=self.num_kv_heads,
            cache_config=cache_config,
            quant_config=quant_config,
            attn_type=attn_type,
            prefix=f"{prefix}.attn",
        )

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Fused QKV projection -> RoPE -> attention -> output projection."""
        qkv, _ = self.qkv_proj(hidden_states)
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        q, k = self.rotary_emb(positions, q, k)
        attn_output = self.attn(q, k, v)
        output, _ = self.o_proj(attn_output)
        return output
class SeedOssDecoderLayer(nn.Module):
    """One SeedOss transformer layer: pre-norm attention followed by a
    pre-norm MLP, with the residual stream threaded through the fused
    RMSNorm calls."""

    def __init__(
        self,
        config: SeedOssConfig,
        cache_config: CacheConfig | None = None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.hidden_size = config.hidden_size
        # Requires transformers > 4.32.0
        theta = getattr(config, "rope_theta", 1000000)
        scaling = getattr(config, "rope_scaling", None)
        # SeedOss is decoder-only, so causal attention is the default;
        # setting `is_causal=False` in the HF config switches to
        # bidirectional attention (used by some embedding models).
        attn_type = (AttentionType.DECODER
                     if getattr(config, "is_causal", True) else
                     AttentionType.ENCODER_ONLY)
        self.self_attn = SeedOssAttention(
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            max_position=config.max_position_embeddings,
            num_kv_heads=config.num_key_value_heads,
            head_dim=config.head_dim,
            rope_theta=theta,
            cache_config=cache_config,
            quant_config=quant_config,
            rope_scaling=scaling,
            prefix=f"{prefix}.self_attn",
            attn_type=attn_type,
        )
        self.mlp = SeedOssMLP(
            hidden_size=self.hidden_size,
            intermediate_size=config.intermediate_size,
            hidden_act=config.hidden_act,
            quant_config=quant_config,
            prefix=f"{prefix}.mlp",
        )
        self.input_layernorm = RMSNorm(config.hidden_size,
                                       eps=config.rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(config.hidden_size,
                                                eps=config.rms_norm_eps)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        residual: torch.Tensor | None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Run the layer; returns (mlp_output, residual)."""
        if residual is None:
            # First layer on this rank: seed the residual stream.
            residual = hidden_states
            hidden_states = self.input_layernorm(hidden_states)
        else:
            hidden_states, residual = self.input_layernorm(
                hidden_states, residual)
        hidden_states = self.self_attn(positions=positions,
                                       hidden_states=hidden_states)
        hidden_states, residual = self.post_attention_layernorm(
            hidden_states, residual)
        return self.mlp(hidden_states), residual
@support_torch_compile(
    dynamic_arg_dims={
        "input_ids": 0,
        "positions": -1,
        "intermediate_tensors": 0,
        "inputs_embeds": 0,
    }
)
class SeedOssModel(nn.Module):
    """SeedOss transformer backbone with pipeline-parallel support.

    Owns the embedding table (on the first PP rank, and on the last rank
    when embeddings are tied), a slice of decoder layers, and the final
    RMSNorm (last PP rank only).
    """

    def __init__(
        self,
        *,
        vllm_config: VllmConfig,
        prefix: str = "",
        decoder_layer_type: type[nn.Module] = SeedOssDecoderLayer,
    ):
        super().__init__()
        config = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config

        # TODO (@robertgshaw2): see if this can be moved out
        # Partial sliding-window support is unimplemented; reject configs
        # where only some layers would use it.
        if cache_config.sliding_window is not None and hasattr(
            config, "max_window_layers"
        ):
            assert config.max_window_layers == config.num_hidden_layers, (
                "Sliding window for some but all layers is not supported. "
                "This model uses sliding window but `max_window_layers` = {} "
                "is less than `num_hidden_layers` = {}. Please open an issue "
                "to discuss this feature.".format(
                    config.max_window_layers,
                    config.num_hidden_layers,
                )
            )

        self.config = config
        self.quant_config = quant_config
        self.vocab_size = config.vocab_size

        # The last rank also needs the embedding table when the LM head is
        # tied to it.
        if get_pp_group().is_first_rank or (
            config.tie_word_embeddings and get_pp_group().is_last_rank
        ):
            self.embed_tokens = VocabParallelEmbedding(
                config.vocab_size,
                config.hidden_size,
                quant_config=quant_config,
                prefix=f"{prefix}.embed_tokens",
            )
        else:
            self.embed_tokens = PPMissingLayer()

        # Use the provided decoder layer type or default to SeedDecoderLayer
        decoder_layer_type = decoder_layer_type or SeedOssDecoderLayer
        self.start_layer, self.end_layer, self.layers = make_layers(
            config.num_hidden_layers,
            lambda prefix: decoder_layer_type(
                config=config,
                cache_config=cache_config,
                quant_config=quant_config,
                prefix=prefix,
            ),
            prefix=f"{prefix}.layers",
        )

        self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory(
            ["hidden_states", "residual"], config.hidden_size
        )
        if get_pp_group().is_last_rank:
            self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        else:
            self.norm = PPMissingLayer()

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up token embeddings (first / tied-last PP rank only)."""
        return self.embed_tokens(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
    ) -> torch.Tensor | IntermediateTensors:
        """Run this rank's decoder layers.

        Returns the normalized hidden states on the last PP rank, otherwise
        an ``IntermediateTensors`` for the next rank.
        """
        if get_pp_group().is_first_rank:
            if inputs_embeds is not None:
                hidden_states = inputs_embeds
            else:
                hidden_states = self.get_input_embeddings(input_ids)
            residual = None
        else:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]
            residual = intermediate_tensors["residual"]
        for layer in islice(self.layers, self.start_layer, self.end_layer):
            hidden_states, residual = layer(
                positions,
                hidden_states,
                residual,
            )
        if not get_pp_group().is_last_rank:
            return IntermediateTensors(
                {"hidden_states": hidden_states, "residual": residual}
            )
        hidden_states, _ = self.norm(hidden_states, residual)
        return hidden_states

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint weights, fusing QKV / gate-up shards and
        remapping KV-cache quantization scales. Returns the names loaded
        on this rank."""
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        ]
        params_dict = dict(self.named_parameters(remove_duplicate=False))
        loaded_params: set[str] = set()
        for name, loaded_weight in weights:
            # RoPE inverse frequencies are recomputed, never loaded.
            if "rotary_emb.inv_freq" in name:
                continue
            if self.quant_config is not None and (
                scale_name := self.quant_config.get_cache_scale(name)
            ):
                # Loading kv cache quantization scales
                param = params_dict[scale_name]
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                loaded_weight = (
                    loaded_weight if loaded_weight.dim() == 0 else loaded_weight[0]
                )
                weight_loader(param, loaded_weight)
                loaded_params.add(scale_name)
                continue
            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                if is_pp_missing_parameter(name, self):
                    continue
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                # Remapping the name of FP8 kv-scale.
                name = maybe_remap_kv_scale_name(name, params_dict)
                if name is None:
                    continue
                if is_pp_missing_parameter(name, self):
                    continue
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                weight_loader(param, loaded_weight)
            loaded_params.add(name)
        return loaded_params
class SeedOssForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
    """SeedOss causal language model: backbone, LM head (tied or parallel),
    and logits processor."""

    packed_modules_mapping = {
        "qkv_proj": [
            "q_proj",
            "k_proj",
            "v_proj",
        ],
        "gate_up_proj": [
            "gate_proj",
            "up_proj",
        ],
    }

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        hf_config = vllm_config.model_config.hf_config
        self.config = hf_config
        self.lora_config = vllm_config.lora_config
        self.quant_config = vllm_config.quant_config
        self.model = SeedOssModel(vllm_config=vllm_config,
                                  prefix=maybe_prefix(prefix, "model"))
        if not get_pp_group().is_last_rank:
            # Intermediate PP ranks never compute logits.
            self.lm_head = PPMissingLayer()
        elif hf_config.tie_word_embeddings:
            # Reuse the input embedding table as the output projection.
            self.lm_head = self.model.embed_tokens
        else:
            self.lm_head = ParallelLMHead(
                hf_config.vocab_size,
                hf_config.hidden_size,
                quant_config=self.quant_config,
                prefix=maybe_prefix(prefix, "lm_head"),
            )
        self.logits_processor = LogitsProcessor(hf_config.vocab_size)
        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors)

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up token embeddings via the backbone."""
        return self.model.get_input_embeddings(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
    ) -> torch.Tensor | IntermediateTensors:
        """Delegate the forward pass to the backbone model."""
        return self.model(input_ids, positions, intermediate_tensors,
                          inputs_embeds)

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor | None:
        """Project hidden states to vocabulary logits via the LM head."""
        return self.logits_processor(self.lm_head, hidden_states)

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load weights, skipping the LM head when it is tied to the
        embedding table."""
        skip = ["lm_head."] if self.config.tie_word_embeddings else None
        loader = AutoWeightsLoader(self, skip_prefixes=skip)
        return loader.load_weights(weights)

View File

@@ -12,10 +12,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-kunlun project.
# This file is a part of the vllm-ascend project.
#
import vllm_kunlun.ops.rotary_embedding
import vllm_kunlun.ops.layernorm
import vllm_kunlun.ops.quantization.awq
import vllm_kunlun.ops.quantization.gptq
import vllm_kunlun.ops.layernorm

View File

@@ -1,20 +1,3 @@
#
# Copyright (c) 2025 Baidu, Inc. All Rights Reserved.
#
# This file is a part of the vllm-kunlun project.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""kunlun custom op entry"""
import torch_xmlir
import torch
@@ -29,7 +12,6 @@ logger = init_logger(__name__)
try:
import xtorch_ops
logger.info(f"Load custom ops library success!")
except ImportError as e:
logger.warning("Import error msg: %s", e.msg)
@@ -37,15 +19,13 @@ except ImportError as e:
_per_token_smooth_quant = True
def is_per_token_smooth_quant():
"""is per token smooth quant"""
""" is per token smooth quant """
return _per_token_smooth_quant
class KunlunOps:
"""KunlunOps"""
# Attention ops
@staticmethod
def paged_attention_v1(
@@ -70,9 +50,10 @@ class KunlunOps:
blocksparse_vert_stride,
blocksparse_block_size,
blocksparse_head_sliding_step,
alibi_sqrt=False,
):
"""PagedAttentionV1"""
alibi_sqrt=False
):
""" PagedAttentionV1 """
# block_size = value_cache.shape[2]
xtorch_ops.paged_attention(
x=query,
k_cache=key_cache,
@@ -83,7 +64,7 @@ class KunlunOps:
is_context=is_context,
is_causal=True,
out=output,
vo_head_dim=128,
vo_head_dim=128
)
@staticmethod
@@ -112,9 +93,10 @@ class KunlunOps:
blocksparse_vert_stride,
blocksparse_block_size,
blocksparse_head_sliding_step,
alibi_sqrt=False,
):
"""PagedAttentionV2"""
alibi_sqrt=False
):
""" PagedAttentionV2 """
# block_size = value_cache.shape[2]
xtorch_ops.paged_attention(
x=query,
k_cache=key_cache,
@@ -125,28 +107,31 @@ class KunlunOps:
is_context=is_context,
is_causal=True,
out=output,
vo_head_dim=128,
vo_head_dim=128
)
# Activation ops
@staticmethod
def silu_and_mul(out: torch.Tensor, x: torch.Tensor):
"""silu and mul"""
def silu_and_mul(out: torch.Tensor,
x: torch.Tensor):
""" silu and mul """
xtorch_ops.silu_and_mul(
x,
axis=-1,
turn=True,
out=out,
)
)
# Activation ops
@staticmethod
def quick_gelu(out: torch.Tensor, x: torch.Tensor):
"""quick gelu"""
def quick_gelu(out: torch.Tensor,
x: torch.Tensor):
""" quick gelu """
xtorch_ops.quick_gelu(
x,
out=out,
)
)
# Layernorm
@staticmethod
@@ -157,7 +142,9 @@ class KunlunOps:
epsilon,
):
"""rms_norm"""
xtorch_ops.rmsnorm(x, weight.to(torch.float32), epsilon, out=out)
xtorch_ops.rmsnorm(
x, weight.to(torch.float32), epsilon, out=out
)
@staticmethod
def fused_add_rms_norm(
@@ -175,11 +162,16 @@ class KunlunOps:
residual.copy_(fused_input, non_blocking=True)
x.copy_(output)
# Rotary embedding
@staticmethod
def rotary_embedding(
positions, query, key, head_size, cos_sin_cache, is_neox_style
):
positions,
query,
key,
head_size,
cos_sin_cache,
is_neox_style):
"""
refactor RotaryEmbedding forward function
"""
@@ -196,43 +188,66 @@ class KunlunOps:
key_x = key_x.unsqueeze(0)
xtorch_ops.rotary_embedding_gptj(
positions, query_x, key_x, head_size, cos_sin_cache
)
positions,
query_x,
key_x,
head_size,
cos_sin_cache)
query.data = query_x
key.data = key_x
key.data = key_x
if query_x_dim != query_x.dim():
query_x = query_x.unsqueeze(0)
key_x = key_x.unsqueeze(0)
return query, key
# TODO: need opt
if cos_sin_cache.dim() == 4:
max_seq_len = cos_sin_cache.shape[2]
head_dim = cos_sin_cache.shape[3]
cos_sin_cache = cos_sin_cache.squeeze(0).squeeze(
0
) # Remove the first two dimensions [1,1,L,D] -> [L,D]
cos_sin_cache = cos_sin_cache.squeeze(0).squeeze(0) # 移除前两个维度 [1,1,L,D] -> [L,D]
cos_sin_cache = cos_sin_cache.view(max_seq_len, 1, head_dim)
# Reshape query and key
# 重塑 query key 的形状
num_tokens = query_x.shape[0]
num_heads = query_x.shape[1] // head_size
num_kv_heads = key_x.shape[1] // head_size
# # [num_tokens, num_heads * head_size] -> [num_tokens, num_heads, head_size]
# query_x = query_x.view(num_tokens, num_heads, head_size)
# # [num_tokens, num_kv_heads * head_size] -> [num_tokens, num_kv_heads, head_size]
# key_x = key_x.view(num_tokens, num_kv_heads, head_size)
# # 确保形状正确
# assert query_x.shape == (num_tokens, num_heads, head_size), \
# f"Expected query shape [{num_tokens}, {num_heads}, {head_size}], got {query_x.shape}"
# assert key_x.shape == (num_tokens, num_kv_heads, head_size), \
# f"Expected key shape [{num_tokens}, {num_kv_heads}, {head_size}], got {key_x.shape}"
torch.ops._C.rotary_embedding(
positions, query_x, key_x, head_size, cos_sin_cache, is_neox_style
)
positions,
query_x,
key_x,
head_size,
cos_sin_cache,
is_neox_style)
query_x = query_x.view(num_tokens, num_heads * head_size)
key_x = key_x.view(num_tokens, num_kv_heads * head_size)
# query.data = query_x
# key.data = key_x
return query_x, key_x
# Rotary embedding
@staticmethod
def mrotary_embedding(
positions, mrope_section, query, key, head_size, cos_sin_cache, is_neox_style
):
positions,
mrope_section,
query,
key,
head_size,
cos_sin_cache,
is_neox_style):
"""
refactor RotaryEmbedding forward function
"""
@@ -241,21 +256,35 @@ class KunlunOps:
query_x_dim = query_x.dim()
assert is_neox_style
xtorch_ops.mrotary_embedding_neox(
positions, query_x, key_x, head_size, cos_sin_cache, mrope_section
)
positions,
query_x,
key_x,
head_size,
cos_sin_cache,
mrope_section)
query.data = query_x
key.data = key_x
key.data = key_x
return query, key
@staticmethod
def swap_blocks(src, dst, block_mapping):
"""swap_blocks"""
xtorch_ops.swap_blocks(src, dst, block_mapping)
def swap_blocks(
src,
dst,
block_mapping):
""" swap_blocks """
xtorch_ops.swap_blocks(
src,
dst,
block_mapping
)
@staticmethod
def copy_blocks(key_caches, value_caches, block_mapping):
"""copy_blocks"""
def copy_blocks(
key_caches,
value_caches,
block_mapping):
""" copy_blocks """
for i in range(len(key_caches)):
key_caches[i] = key_caches[i].contiguous()
value_caches[i] = value_caches[i].contiguous()
@@ -273,9 +302,16 @@ class KunlunOps:
value_cache,
slot_mapping,
kv_cache_dtype,
):
"""reshape_and_cache"""
xtorch_ops.reshape_and_cache(key, value, key_cache, value_cache, slot_mapping)
):
""" reshape_and_cache """
# slot_mapping_cast = slot_mapping.to(torch.int32)
xtorch_ops.reshape_and_cache(
key,
value,
key_cache,
value_cache,
slot_mapping
)
@staticmethod
def multi_query_kv_attention(
@@ -284,7 +320,7 @@ class KunlunOps:
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
**kargs,
**kargs
) -> torch.Tensor:
"""
query: shape = [num_prompt_tokens, num_heads, head_size]
@@ -303,7 +339,7 @@ class KunlunOps:
KVh = key.size(2)
if KVh != Qh:
repeat = Qh // KVh
key = key.repeat_interleave(repeat, dim=2) # [B, T, Qh, Hd]
key = key.repeat_interleave(repeat, dim=2) # [B, T, Qh, Hd]
value = value.repeat_interleave(repeat, dim=2)
xtorch_ops.attention(
q=query,
@@ -318,132 +354,85 @@ class KunlunOps:
return output
@staticmethod
def quant_fusedresidual_rmsnorm_op(
x, residual, weight, bias, scale_to_int, eps, dyn_scale: bool, type: int = 1
):
def quant_fusedresidual_rmsnorm_op(x,
residual,
weight,
bias,
scale_to_int,
eps,
dyn_scale: bool,
type: int = 1):
"""Quantized fused residual layer normalization"""
out = torch.empty_like(x, dtype=torch.int8)
if is_per_token_smooth_quant():
out_scale = torch.empty(
x.shape[:-1], device=x.device, dtype=torch.float
).unsqueeze(-1)
out_scale = torch.empty(x.shape[:-1], device=x.device, dtype=torch.float).unsqueeze(-1)
else:
out_scale = torch.empty(12, device=x.device, dtype=torch.float)
xtorch_ops.quant_fusedresidual_rmsnorm(
x,
residual,
weight,
bias,
eps,
out=out,
out_scale=out_scale,
residual_tensor=residual,
)
xtorch_ops.quant_fusedresidual_rmsnorm(x, residual, weight, bias, eps,
out=out, out_scale=out_scale , residual_tensor=residual)
if residual is None:
return out, out_scale
return out, out_scale, residual
@staticmethod
def quant_rmsnorm_op(
x, weight, bias, scale_to_int, eps, dyn_scale: bool, type: int = 1
):
def quant_rmsnorm_op(x,
weight,
bias,
scale_to_int,
eps,
dyn_scale : bool,
type: int = 1):
"""Quantized RMSNorm"""
out = torch.empty_like(x, dtype=torch.int8)
if is_per_token_smooth_quant():
out_scale = torch.empty(
x.shape[:-1], device=x.device, dtype=torch.float
).unsqueeze(-1)
out_scale = torch.empty(x.shape[:-1], device=x.device, dtype=torch.float).unsqueeze(-1)
else:
out_scale = torch.empty(12, device=x.device, dtype=torch.float)
xtorch_ops.quant_rmsnorm(x, weight, bias, eps, out=out, out_scale=out_scale)
xtorch_ops.quant_rmsnorm(x, weight, bias, eps,
out=out, out_scale=out_scale)
return out, out_scale
@staticmethod
def smooth_quant_matmul_column_row_kernels(
input_tensor,
weight,
smoother,
input_scale,
weight_scale,
perTokenScaling,
perChannelScaling,
otype,
):
def smooth_quant_matmul_column_row_kernels(input_tensor,
weight,
smoother,
input_scale,
weight_scale,
perTokenScaling,
perChannelScaling,
otype):
"""smooth_quant_matmul_column_row_kernels"""
input_shape = input_tensor.shape
weight_shape = weight.shape
if input_tensor.dim() == 3:
input_tensor = input_tensor.reshape(-1, input_shape[-1])
out = torch.empty(
(input_shape[0] * input_shape[1], weight_shape[0]),
dtype=torch.float16,
device=weight.device,
)
out = torch.empty((input_shape[0] * input_shape[1],
weight_shape[0]),
dtype=torch.float16,
device=weight.device)
output_bs_shape = [input_shape[0], input_shape[1]]
elif input_tensor.dim() == 2:
out = torch.empty(
(input_shape[0], weight_shape[0]),
dtype=torch.float16,
device=weight.device,
)
out = torch.empty((input_shape[0], weight_shape[0]),
dtype=torch.float16,
device=weight.device)
output_bs_shape = [-1]
xtorch_ops.smooth_quant_matmul_column_row_kernels(
input_tensor,
weight,
smoother,
input_scale,
weight_scale,
perTokenScaling,
perChannelScaling,
out=out,
)
xtorch_ops.smooth_quant_matmul_column_row_kernels(input_tensor,
weight, smoother,
input_scale,
weight_scale,
perTokenScaling,
perChannelScaling,
out=out)
out = out.view(*output_bs_shape, weight_shape[0])
return out
@staticmethod
def fused_moe(
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
gating_output: torch.Tensor,
linear_weights: torch.Tensor,
topk: int,
renormalize: bool,
inplace: bool = False,
use_grouped_topk: bool = False,
num_expert_group: Optional[int] = None,
topk_group: Optional[int] = None,
w1_bias: Optional[torch.Tensor] = None,
w2_bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""fused_moe"""
output = torch.empty(
hidden_states.shape, dtype=hidden_states.dtype, device=hidden_states.device
)
expert_num = linear_weights.shape[0]
torch.ops._C.moe_ffn_block(
x=hidden_states,
gate_w=linear_weights,
inter_w=w1,
output_w=w2,
expert_num=expert_num,
moe_top_k=topk,
topk_group=topk_group,
renormalize=renormalize,
use_grouped_topk=use_grouped_topk,
expert_group_num=num_expert_group,
out=output,
)
return output
@staticmethod
def fused_moe_ep(
hidden_states: torch.Tensor,
@@ -460,23 +449,23 @@ class KunlunOps:
topk_group: Optional[int] = None,
w1_bias: Optional[torch.Tensor] = None,
w2_bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
) -> torch.Tensor:
x = hidden_states
batch, hidden_size = x.shape
batch, hidden_size = x.shape
num_local_experts, up_gate_size, _ = w13_weight.shape
router_logits = x.to(linear_weights.dtype) @ linear_weights.T
topk_weights = torch.empty(
batch, top_k, dtype=router_logits.dtype, device=router_logits.device
)
topk_ids = torch.empty(
batch, top_k, dtype=torch.int32, device=router_logits.device
)
block_static = torch.empty(0, dtype=torch.int32, device=router_logits.device)
torch.ops._C.moe_softmax_topk(
router_logits, topk_weights, topk_ids, block_static
)
router_logits = x.to(linear_weights.dtype)@linear_weights.T
topk_weights = torch.empty(batch,
top_k,
dtype=router_logits.dtype,
device=router_logits.device)
topk_ids = torch.empty(batch,
top_k,
dtype=torch.int32,
device=router_logits.device)
block_static = torch.empty(0, dtype=torch.int32,device=router_logits.device)
torch.ops._C.moe_softmax_topk(router_logits, topk_weights, topk_ids, block_static)
if renormalize:
topk_weights = topk_weights / topk_weights.sum(1, keepdim=True)
@@ -490,22 +479,50 @@ class KunlunOps:
selected_token = topk_ids_flat == experts_id
if selected_token.sum():
cur_token = repeat_x[selected_token]
up_gate = torch.empty(
selected_token.sum(),
up_gate_size // 2,
dtype=cur_token.dtype,
device=cur_token.device,
)
torch.ops._C.swiglu(cur_token @ w13_weight[i].T, up_gate)
up_gate = torch.empty(selected_token.sum(), up_gate_size//2,
dtype=cur_token.dtype, device=cur_token.device)
torch.ops._C.swiglu(cur_token@ w13_weight[i].T, up_gate)
out[selected_token] = up_gate @ w2_weight[i].T
output = (
(out.view(batch, top_k, hidden_size) * topk_weights.unsqueeze(2))
.sum(dim=1)
.to(x.dtype)
)
output = (out.view(batch, top_k, hidden_size) * topk_weights.unsqueeze(2)).sum(dim=1).to(x.dtype)
return output
@staticmethod
def fused_moe(
    hidden_states: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    gating_output: torch.Tensor,
    linear_weights: torch.Tensor,
    topk: int,
    renormalize: bool,
    inplace: bool = False,
    use_grouped_topk: bool = False,
    num_expert_group: Optional[int] = None,
    topk_group: Optional[int] = None,
    w1_bias: Optional[torch.Tensor] = None,
    w2_bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Fused mixture-of-experts forward via the moe_ffn_block custom kernel.

    Args:
        hidden_states: input activations; the output has the same shape,
            dtype and device.
        w1/w2: expert up-projection and down-projection weights.
        linear_weights: router (gate) weights; its leading dim is the
            number of experts.
        topk: experts selected per token.
        renormalize: whether the kernel renormalizes the top-k weights.

    NOTE(review): `gating_output`, `inplace`, `w1_bias` and `w2_bias` are
    accepted for interface parity but are not forwarded to the kernel —
    presumably routing is recomputed from `linear_weights` inside the op;
    confirm against the kernel implementation.
    """
    result = torch.empty(hidden_states.shape,
                         dtype=hidden_states.dtype,
                         device=hidden_states.device)
    torch.ops._C.moe_ffn_block(
        x=hidden_states,
        gate_w=linear_weights,
        inter_w=w1,
        output_w=w2,
        expert_num=linear_weights.shape[0],
        moe_top_k=topk,
        topk_group=topk_group,
        renormalize=renormalize,
        use_grouped_topk=use_grouped_topk,
        expert_group_num=num_expert_group,
        out=result,
    )
    return result
@staticmethod
def fused_multi_head_latent_page_attention(
hidden_states: torch.Tensor,
@@ -538,11 +555,10 @@ class KunlunOps:
prompt_lods_cpu: torch.Tensor,
k_cache: torch.Tensor,
v_cache: torch.Tensor,
) -> torch.Tensor:
) -> torch.Tensor:
"""mla pa block"""
output = torch.empty(
hidden_states.shape, dtype=hidden_states.dtype, device=hidden_states.device
)
output = torch.empty(hidden_states.shape, dtype=hidden_states.dtype,
device=hidden_states.device)
xtorch_ops.xft_multi_head_latent_page_attention_block(
hidden_states,
q_lora_rank,
@@ -579,3 +595,42 @@ class KunlunOps:
v_cache=v_cache,
)
return output
def fused_gdn_gating(
    A_log: torch.Tensor,
    a: torch.Tensor,
    dt_bias: torch.Tensor,
    beta: float = 1.0,
    threshold: float = 20.0,
) -> torch.Tensor:
    """Fused gating computation for Gated DeltaNet (Qwen3-Next).

    NOTE(review): `beta` and `threshold` are accepted for interface parity
    but are not forwarded to the kernel — presumably the kernel hard-codes
    these defaults; confirm against xtorch_ops.
    """
    return xtorch_ops.fused_gdn_gating(A_log, a, dt_bias)
def fused_recurrent_gated_delta_rule_fwd(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    g: torch.Tensor,
    beta: torch.Tensor,
    scale: float,
    h0_source: torch.Tensor,
    output_final_state: bool,
    use_qk_l2norm_in_kernel: bool,
    cu_seqlens: torch.Tensor = None) -> tuple[torch.Tensor, torch.Tensor]:
    """Core Gated DeltaNet operator for the Qwen3-Next model.

    Fuses the sigmoid gating and the delta-rule update into one kernel:
      1. Sigmoid gating of the input, similar to a GLU (Gated Linear Unit).
      2. Delta-rule update: a recurrent state-space (SSM) update combined
         with a local attention mechanism.

    Returns:
        A ``(output, final_state)`` tuple from the fused kernel.
    """
    out, final_state = xtorch_ops.fused_recurrent_gated_delta_rule_fwd(
        q, k, v, g, beta, scale, h0_source,
        output_final_state, use_qk_l2norm_in_kernel, cu_seqlens)
    return (out, final_state)

View File

@@ -1,8 +1,78 @@
# SPDX-License-Identifier: Apache-2.0
"""Custom activation functions."""
import math
from typing import Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
from vllm.distributed import (divide, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size)
from vllm.model_executor.custom_op import CustomOp
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform
from vllm.utils import LazyDict
@CustomOp.register("kunlun_fatrelu_and_mul")
class FatreluAndMul(CustomOp):
    """An activation function for FATReLU.

    Computes x -> FATReLU(x[:d]) * x[d:] where d = x.shape[-1] // 2.
    Used in openbmb/MiniCPM-S-1B-sft.

    Shapes:
        x: (num_tokens, 2 * d) or (batch_size, seq_len, 2 * d)
        return: (num_tokens, d) or (batch_size, seq_len, d)
    """

    def __init__(self, threshold: float = 0.):
        """Store the FATReLU cutoff; gate values <= threshold become 0."""
        super().__init__()
        self.threshold = threshold

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        """Reference implementation using torch.nn.functional only."""
        split = x.shape[-1] // 2
        gate = F.threshold(x[..., :split], self.threshold, 0.0)
        return gate * x[..., split:]

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        """No fused kernel is wired up; defer to the native path."""
        return self.forward_native(x)
@CustomOp.register("kunlun_silu_and_mul")
@@ -15,9 +85,532 @@ class SiluAndMul(CustomOp):
x: (num_tokens, 2 * d) or (batch_size, seq_len, 2 * d)
return: (num_tokens, d) or (batch_size, seq_len, d)
"""
def forward_native(self, x: torch.Tensor) -> torch.Tensor:
    """Reference SwiGLU: silu(x[..., :d]) * x[..., d:], d = last dim // 2."""
    split = x.shape[-1] // 2
    gate, up = x[..., :split], x[..., split:]
    return F.silu(gate) * up
def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
    """Fused SwiGLU via the custom _C.swiglu kernel."""
    # Importing xtorch_ops presumably registers the _C custom ops — confirm.
    import xtorch_ops
    half = x.shape[-1] // 2
    result = torch.empty(x.shape[:-1] + (half, ),
                         dtype=x.dtype, device=x.device)
    torch.ops._C.swiglu(x, result)
    return result
def forward_kunlun(self, x: torch.Tensor) -> torch.Tensor:
    """Fused SwiGLU via the xtorch_ops kernel on Kunlun devices."""
    import xtorch_ops
    half = x.shape[-1] // 2
    result = torch.empty(x.shape[:-1] + (half, ),
                         dtype=x.dtype, device=x.device)
    xtorch_ops.swiglu(x, result)
    return result
def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
    """Fused SwiGLU via the IPEX silu_and_mul kernel.

    Args:
        x: floating-point tensor whose last dimension is even (2 * d).

    Returns:
        Tensor shaped like ``x`` with the last dimension halved, same
        dtype and device.
    """
    from vllm._ipex_ops import ipex_ops as ops
    half = x.shape[-1] // 2
    result = torch.empty(x.shape[:-1] + (half, ),
                         dtype=x.dtype, device=x.device)
    ops.silu_and_mul(result, x)
    return result
def forward_neuron(self, x: torch.Tensor) -> torch.Tensor:
    """SwiGLU for the Neuron backend, computed through a 2-D view.

    Args:
        x: tensor of shape (..., 2 * d).

    Returns:
        Tensor of shape (..., d): silu(x[..., :d]) * x[..., d:].
    """
    d = x.shape[-1] // 2
    flat = x.view(-1, x.shape[-1])
    gated = flat[:, :d] * torch.sigmoid(flat[:, :d])
    return (gated * flat[:, d:]).view(*x.shape[:-1], d)
@CustomOp.register("kunlun_mul_and_silu")
class MulAndSilu(CustomOp):
    """SwiGLU variant computing x[:d] * silu(x[d:]), d = x.shape[-1] // 2.

    Shapes:
        x: (num_tokens, 2 * d) or (batch_size, seq_len, 2 * d)
        return: (num_tokens, d) or (batch_size, seq_len, d)
    """

    def __init__(self):
        """Select the backend kernel based on the current platform.

        CUDA-like platforms use torch.ops._C.mul_and_silu; XPU uses the
        IPEX kernel; CPU falls back to the native implementation.
        """
        super().__init__()
        if current_platform.is_cuda_alike():
            self.op = torch.ops._C.mul_and_silu
        elif current_platform.is_xpu():
            # NOTE(review): ipex_ops.silu_and_mul computes silu(x[:d]) * x[d:],
            # the mirrored operand order of this class — confirm intended.
            from vllm._ipex_ops import ipex_ops
            self.op = ipex_ops.silu_and_mul
        elif current_platform.is_cpu():
            self._forward_method = self.forward_native

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        """PyTorch-native implementation equivalent to forward()."""
        split = x.shape[-1] // 2
        return x[..., :split] * F.silu(x[..., split:])

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        """Run the selected fused kernel into a freshly allocated output."""
        split = x.shape[-1] // 2
        result = torch.empty(x.shape[:-1] + (split, ),
                             dtype=x.dtype, device=x.device)
        self.op(result, x)
        return result

    # TODO implement forward_xpu for MulAndSilu
    # def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
@CustomOp.register("kunlun_gelu_and_mul")
class GeluAndMul(CustomOp):
    """An activation function for GeGLU.

    The function computes x -> GELU(x[:d]) * x[d:] where d = x.shape[-1] // 2.

    Shapes:
        x: (batch_size, seq_len, 2 * d) or (num_tokens, 2 * d)
        return: (batch_size, seq_len, d) or (num_tokens, d)
    """

    def __init__(self, approximate: str = "none"):
        """Initialize the GeGLU op.

        Args:
            approximate: GELU approximation mode, "none" or "tanh".

        Raises:
            ValueError: if `approximate` is not one of "none", "tanh".
        """
        super().__init__()
        self.approximate = approximate
        if approximate not in ("none", "tanh"):
            raise ValueError(f"Unknown approximate mode: {approximate}")

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        """PyTorch-native implementation equivalent to forward().

        NOTE(review): the original class defined forward_native twice; the
        second copy called an undefined helper `_check_and_make_out`. This
        single definition keeps that copy's conservative `contiguous()`.
        """
        d = x.shape[-1] // 2
        x = x.contiguous()  # avoid view-related pitfalls on sliced inputs
        return F.gelu(x[..., :d], approximate=self.approximate) * x[..., d:]

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        """GeGLU on CUDA-like devices.

        Fixes vs. original: `torch.empty(x, ...)` passed a tensor where a
        size is expected (TypeError) and now allocates the halved output
        shape; the stray debug print is removed; the "tanh" branch
        referenced an undefined `ops` and now falls back to the native
        implementation.
        """
        import xtorch_ops
        d = x.shape[-1] // 2
        if self.approximate == "none":
            out = torch.empty(x.shape[:-1] + (d, ),
                              dtype=x.dtype, device=x.device)
            # NOTE(review): assumes xtorch_ops.gelu implements the fused
            # GELU(x[:d]) * x[d:] writing into `out` — confirm kernel
            # semantics against xtorch_ops.
            xtorch_ops.gelu(x, out)
            return out
        # No fused tanh-approximation kernel is wired up here.
        return self.forward_native(x)

    def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
        """GeGLU via the IPEX fused kernels.

        Args:
            x: input tensor of shape (..., 2 * d).

        Returns:
            Tensor of shape (..., d), same dtype and device as the input.
        """
        from vllm._ipex_ops import ipex_ops as ops
        d = x.shape[-1] // 2
        out = torch.empty(x.shape[:-1] + (d, ), dtype=x.dtype, device=x.device)
        if self.approximate == "none":
            ops.gelu_and_mul(out, x)
        elif self.approximate == "tanh":
            ops.gelu_tanh_and_mul(out, x)
        return out

    def extra_repr(self) -> str:
        """Extra information shown in the module repr."""
        return f'approximate={repr(self.approximate)}'
@CustomOp.register("kunlun_gelu_new")
class NewGELU(CustomOp):
    """Tanh-approximated GELU ("gelu_new" in HF transformers)."""

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        """PyTorch-native implementation equivalent to forward()."""
        coeff = math.sqrt(2.0 / math.pi)
        inner = coeff * (x + 0.044715 * torch.pow(x, 3.0))
        return 0.5 * x * (1.0 + torch.tanh(inner))

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        """Dispatch to the vLLM custom gelu_new kernel."""
        from vllm import _custom_ops as ops
        result = torch.empty_like(x)
        ops.gelu_new(result, x)
        return result

    def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
        """Dispatch to the IPEX gelu_new kernel."""
        from vllm._ipex_ops import ipex_ops as ops
        return ops.gelu_new(x)
@CustomOp.register("kunlun_gelu_fast")
class FastGELU(CustomOp):
    """Fast tanh-based GELU approximation."""

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        """PyTorch-native implementation equivalent to forward()."""
        inner = x * 0.7978845608 * (1.0 + 0.044715 * x * x)
        return 0.5 * x * (1.0 + torch.tanh(inner))

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        """Dispatch to the vLLM custom gelu_fast kernel."""
        from vllm import _custom_ops as ops
        result = torch.empty_like(x)
        ops.gelu_fast(result, x)
        return result

    def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
        """Dispatch to the IPEX gelu_fast kernel."""
        from vllm._ipex_ops import ipex_ops as ops
        return ops.gelu_fast(x)
@CustomOp.register("kunlun_quick_gelu")
class QuickGELU(CustomOp):
    """Quick GELU: x * sigmoid(1.702 * x).

    See https://github.com/huggingface/transformers/blob/main/src/transformers/activations.py#L90
    """

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        """PyTorch-native implementation equivalent to forward()."""
        return x * torch.sigmoid(1.702 * x)

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        """Dispatch to the vLLM custom gelu_quick kernel."""
        from vllm import _custom_ops as ops
        result = torch.empty_like(x)
        ops.gelu_quick(result, x)
        return result

    def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
        """Dispatch to the IPEX gelu_quick kernel."""
        from vllm._ipex_ops import ipex_ops as ops
        result = torch.empty_like(x)
        ops.gelu_quick(result, x)
        return result

    def forward_kunlun(self, x: torch.Tensor) -> torch.Tensor:
        """Dispatch to the Kunlun quick_gelu kernel."""
        from vllm._kunlun_ops import KunlunOps as ops
        result = torch.empty_like(x)
        ops.quick_gelu(result, x)
        return result
@CustomOp.register("kunlun_relu2")
class ReLUSquaredActivation(CustomOp):
    """relu(x) squared, introduced in https://arxiv.org/abs/2109.08668v2."""

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        """PyTorch-native implementation equivalent to forward()."""
        rectified = F.relu(x)
        return rectified * rectified

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        """No fused kernel here; defer to the native path."""
        return self.forward_native(x)
class ScaledActivation(nn.Module):
    """An activation function with learned post-scale parameters.

    Used by some quantization methods such as AWQ: the wrapped activation
    is applied, then the result is divided element-wise by `self.scales`.
    """

    def __init__(
        self,
        act_module: nn.Module,
        intermediate_size: int,
        input_is_parallel: bool = True,
        params_dtype: Optional[torch.dtype] = None,
    ):
        """Build the scaled activation.

        Args:
            act_module: the activation applied before scaling.
            intermediate_size: full (unsharded) size of the scale vector.
            input_is_parallel: if True, the scale vector is sharded across
                the tensor-parallel group.
            params_dtype: dtype for the scales; defaults to
                torch.get_default_dtype() when None.
        """
        super().__init__()
        self.act = act_module
        self.input_is_parallel = input_is_parallel
        if input_is_parallel:
            partition_size = divide(intermediate_size,
                                    get_tensor_model_parallel_world_size())
        else:
            partition_size = intermediate_size
        if params_dtype is None:
            params_dtype = torch.get_default_dtype()
        self.scales = nn.Parameter(
            torch.empty(partition_size, dtype=params_dtype))
        set_weight_attrs(self.scales, {"weight_loader": self.weight_loader})

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the wrapped activation, then divide by the scales."""
        return self.act(x) / self.scales

    def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor):
        """Copy this rank's shard of the checkpoint scales into `param`.

        When the input is tensor-parallel, narrows the loaded weight to
        the shard owned by this rank before copying.
        """
        param_data = param.data
        if self.input_is_parallel:
            shard_size = param_data.shape[0]
            start = get_tensor_model_parallel_rank() * shard_size
            loaded_weight = loaded_weight.narrow(0, start, shard_size)
        assert param_data.shape == loaded_weight.shape
        param_data.copy_(loaded_weight)
# Lazily-constructed activation modules, keyed by lowercase name.
_ACTIVATION_REGISTRY = LazyDict({
    "gelu": lambda: nn.GELU(),
    "gelu_fast": lambda: FastGELU(),
    "gelu_new": lambda: NewGELU(),
    "gelu_pytorch_tanh": lambda: nn.GELU(approximate="tanh"),
    "relu": lambda: nn.ReLU(),
    "relu2": lambda: ReLUSquaredActivation(),
    "silu": lambda: nn.SiLU(),
    "quick_gelu": lambda: QuickGELU(),
})
def get_act_fn(
    act_fn_name: str,
    quant_config: Optional[QuantizationConfig] = None,
    intermediate_size: Optional[int] = None,
    input_is_parallel: bool = True,
    params_dtype: Optional[torch.dtype] = None,
) -> nn.Module:
    """Get an activation function by name.

    Wraps the activation in ScaledActivation when the quantization config
    marks this name as requiring scaled activations.

    Raises:
        ValueError: if the name is unknown, or if scaling is required but
            `intermediate_size` is not given.
    """
    key = act_fn_name.lower()
    if key not in _ACTIVATION_REGISTRY:
        raise ValueError(
            f"Activation function {key!r} is not supported.")
    act_fn = _ACTIVATION_REGISTRY[key]
    needs_scaling = (quant_config is not None
                     and key in quant_config.get_scaled_act_names())
    if not needs_scaling:
        return act_fn
    if intermediate_size is None:
        raise ValueError("intermediate_size must be specified for scaled "
                         "activation functions.")
    return ScaledActivation(act_fn, intermediate_size, input_is_parallel,
                            params_dtype)
# Fused activation-and-multiply modules, keyed by lowercase name.
_ACTIVATION_AND_MUL_REGISTRY = LazyDict({
    "gelu": lambda: GeluAndMul(),
    "silu": lambda: SiluAndMul(),
    "geglu": lambda: GeluAndMul(),
})
def get_act_and_mul_fn(act_fn_name: str) -> nn.Module:
    """Get an activation-and-mul (i.e. SiluAndMul) function by name.

    Raises:
        ValueError: if the name is not in the registry.
    """
    key = act_fn_name.lower()
    if key in _ACTIVATION_AND_MUL_REGISTRY:
        return _ACTIVATION_AND_MUL_REGISTRY[key]
    raise ValueError(
        f"Activation function {key!r} is not supported.")

View File

@@ -1,55 +1,28 @@
#
# Copyright (c) 2025 Baidu, Inc. All Rights Reserved.
# Author: Bao Qian, Dong Xinyu, Chen Zhennan, Ma Tianyu
# Email: baoqian@baidu.com
# This file is a part of the vllm-kunlun project.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""kunlun attention wrapper for context and decode"""
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple, Type, TYPE_CHECKING
import torch
if TYPE_CHECKING:
from vllm.worker.model_runner import ModelInputForGPUBuilder
from itertools import accumulate
from vllm.attention.backends.abstract import (
AttentionBackend,
AttentionImpl,
AttentionMetadata,
AttentionType,
)
from .utils import CommonAttentionState, CommonMetadataBuilder
from vllm.attention.backends.utils import (
is_block_tables_empty,
compute_slot_mapping_start_idx,
compute_slot_mapping,
)
from vllm_kunlun.ops.paged_attn import PagedAttention, PagedAttentionMetadata
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionMetadata, AttentionType)
from .utils import (CommonAttentionState, CommonMetadataBuilder)
from vllm.attention.backends.utils import (is_block_tables_empty,
compute_slot_mapping_start_idx, compute_slot_mapping)
from vllm_kunlun.ops.paged_attn import (PagedAttention,
PagedAttentionMetadata)
from vllm_kunlun.ops._kunlun_ops import KunlunOps
from vllm.attention.backends.abstract import AttentionLayer
from vllm.logger import init_logger
from vllm.utils import async_tensor_h2d
logger = init_logger(__name__)
class KunlunAttentionBackend(AttentionBackend):
"""KunlunAttentionBackend"""
accept_output_buffer = False
@staticmethod
def get_name() -> str:
return "KUNLUN_ATTENTION"
@@ -80,9 +53,8 @@ class KunlunAttentionBackend(AttentionBackend):
num_kv_heads: int,
head_size: int,
) -> Tuple[int, ...]:
return PagedAttention.get_kv_cache_shape(
num_blocks, block_size, num_kv_heads, head_size
)
return PagedAttention.get_kv_cache_shape(num_blocks, block_size,
num_kv_heads, head_size)
@staticmethod
def swap_blocks(
@@ -182,6 +154,7 @@ class KunlunMetadata(AttentionMetadata, PagedAttentionMetadata):
seq_lens_tensor_cpu: Optional[torch.Tensor] = None
def __post_init__(self):
# Set during the execution of the first attention op.
# It is a list because it is needed to set per prompt
@@ -194,27 +167,23 @@ class KunlunMetadata(AttentionMetadata, PagedAttentionMetadata):
@property
def is_all_encoder_attn_metadata_set(self):
"""
'''
All attention metadata required for encoder attention is set.
"""
return (
(self.encoder_seq_lens is not None)
and (self.encoder_seq_lens_tensor is not None)
and (self.max_encoder_seq_len is not None)
)
'''
return ((self.encoder_seq_lens is not None)
and (self.encoder_seq_lens_tensor is not None)
and (self.max_encoder_seq_len is not None))
@property
def is_all_cross_attn_metadata_set(self):
"""
'''
All attention metadata required for enc/dec cross-attention is set.
Superset of encoder attention required metadata.
"""
return (
self.is_all_encoder_attn_metadata_set
and (self.cross_slot_mapping is not None)
and (self.cross_block_tables is not None)
)
'''
return (self.is_all_encoder_attn_metadata_set
and (self.cross_slot_mapping is not None)
and (self.cross_block_tables is not None))
@property
def prefill_metadata(self) -> Optional["KunlunMetadata"]:
@@ -227,60 +196,43 @@ class KunlunMetadata(AttentionMetadata, PagedAttentionMetadata):
# metadata structure
return self._cached_prefill_metadata
assert (self.seq_lens is not None) or (self.encoder_seq_lens is not None)
assert (self.seq_lens_tensor is not None) or (
self.encoder_seq_lens_tensor is not None
)
assert ((self.seq_lens is not None)
or (self.encoder_seq_lens is not None))
assert ((self.seq_lens_tensor is not None)
or (self.encoder_seq_lens_tensor is not None))
# Compute some attn_metadata fields which default to None
query_start_loc = (
None
if self.query_start_loc is None
else self.query_start_loc[: self.num_prefills + 1]
)
query_start_loc = (None if self.query_start_loc is None else
self.query_start_loc[:self.num_prefills + 1])
# flash attention needs both lod information on host and device
query_start_loc_host = (
None
if self.query_start_loc_host is None
else self.query_start_loc_host[: self.num_prefills + 1]
)
kv_prefix_start_loc_host = (
None
if self.kv_prefix_start_loc_host is None
else self.kv_prefix_start_loc_host[: self.num_prefills + 1]
+ query_start_loc_host
)
kv_prefix_start_loc = (
None
if kv_prefix_start_loc_host is None
else kv_prefix_start_loc_host.cuda()
)
slot_mapping = (
None
if self.slot_mapping is None
else self.slot_mapping[: self.num_prefill_tokens]
)
seq_lens = None if self.seq_lens is None else self.seq_lens[: self.num_prefills]
seq_lens_tensor = (
None
if self.seq_lens_tensor is None
else self.seq_lens_tensor[: self.num_prefills]
)
context_lens_tensor = (
None
if self.context_lens_tensor is None
else self.context_lens_tensor[: self.num_prefills]
)
query_start_loc_host = (None if self.query_start_loc_host is None else
self.query_start_loc_host[:self.num_prefills + 1])
kv_prefix_start_loc_host = (None if self.kv_prefix_start_loc_host is None else
self.kv_prefix_start_loc_host[:self.num_prefills + 1] + query_start_loc_host)
kv_prefix_start_loc = (None if kv_prefix_start_loc_host is None else kv_prefix_start_loc_host.cuda())
slot_mapping = (None if self.slot_mapping is None else
self.slot_mapping[:self.num_prefill_tokens])
seq_lens = (None if self.seq_lens is None else
self.seq_lens[:self.num_prefills])
seq_lens_tensor = (None if self.seq_lens_tensor is None else
self.seq_lens_tensor[:self.num_prefills])
context_lens_tensor = (None if self.context_lens_tensor is None else
self.context_lens_tensor[:self.num_prefills])
# for prefix cache, block table only contains blocks that hit
# if self.block_tables is None:
# block_tables = None
# elif self.block_tables.shape[1] == 0:
# block_tables = self.block_tables[:self.num_prefills]
# else:
# block_tables = self.block_tables[:self.num_prefills][:, -1].clone()
block_tables = (
None
if self.block_tables is None
else self.block_tables[: self.num_prefills]
)
block_tables = (None if self.block_tables is None else
self.block_tables[:self.num_prefills])
# Construct & cache prefill-phase attention metadata structure
self._cached_prefill_metadata = KunlunMetadata(
multi_modal_placeholder_index_maps=self.multi_modal_placeholder_index_maps,
multi_modal_placeholder_index_maps=self.
multi_modal_placeholder_index_maps,
num_prefills=self.num_prefills,
num_prefill_tokens=self.num_prefill_tokens,
num_decode_tokens=0,
@@ -305,8 +257,7 @@ class KunlunMetadata(AttentionMetadata, PagedAttentionMetadata):
cross_slot_mapping=self.cross_slot_mapping,
cross_block_tables=self.cross_block_tables,
enable_kv_scales_calculation=False,
seq_start_loc=self.seq_start_loc,
)
seq_start_loc=self.seq_start_loc)
return self._cached_prefill_metadata
@property
@@ -319,35 +270,25 @@ class KunlunMetadata(AttentionMetadata, PagedAttentionMetadata):
# Recover cached decode-phase attention
# metadata structure
return self._cached_decode_metadata
assert (self.seq_lens_tensor is not None) or (
self.encoder_seq_lens_tensor is not None
)
assert ((self.seq_lens_tensor is not None)
or (self.encoder_seq_lens_tensor is not None))
# Compute some attn_metadata fields which default to None
slot_mapping = (
None
if self.slot_mapping is None
else self.slot_mapping[self.num_prefill_tokens :]
)
seq_lens_tensor = (
None
if self.seq_lens_tensor is None
else self.seq_lens_tensor[self.num_prefills :]
)
seq_lens_tensor_cpu = (
None
if self.seq_lens_tensor_cpu is None
else self.seq_lens_tensor_cpu[self.num_prefills :]
)
block_tables = (
None
if self.block_tables is None
else self.block_tables[self.num_prefills :]
)
slot_mapping = (None if self.slot_mapping is None else
self.slot_mapping[self.num_prefill_tokens:])
seq_lens_tensor = (None if self.seq_lens_tensor is None else
self.seq_lens_tensor[self.num_prefills:])
seq_lens_tensor_cpu = (None if self.seq_lens_tensor_cpu is None else
self.seq_lens_tensor_cpu[self.num_prefills:])
block_tables = (None if self.block_tables is None else
self.block_tables[self.num_prefills:])
# Construct & cache decode-phase attention metadata structure
self._cached_decode_metadata = KunlunMetadata(
multi_modal_placeholder_index_maps=self.multi_modal_placeholder_index_maps,
multi_modal_placeholder_index_maps=self.
multi_modal_placeholder_index_maps,
num_prefills=0,
num_prefill_tokens=0,
num_decode_tokens=self.num_decode_tokens,
@@ -364,16 +305,13 @@ class KunlunMetadata(AttentionMetadata, PagedAttentionMetadata):
max_encoder_seq_len=self.max_encoder_seq_len,
cross_slot_mapping=self.cross_slot_mapping,
cross_block_tables=self.cross_block_tables,
enable_kv_scales_calculation=False,
)
enable_kv_scales_calculation=False)
return self._cached_decode_metadata
class KunlunMetadataBuilder(CommonMetadataBuilder[KunlunMetadata]):
"""KunlunMetadataBuilder"""
_metadata_cls = KunlunMetadata
def __init__(self, input_builder: "ModelInputForGPUBuilder"):
super().__init__(input_builder)
self.prefix_cache_kv_lens: List[int] = []
@@ -382,120 +320,90 @@ class KunlunMetadataBuilder(CommonMetadataBuilder[KunlunMetadata]):
"""prepare"""
super().prepare()
self.prefix_cache_kv_lens = list()
def _add_seq_group(
self,
inter_data: "ModelInputForGPUBuilder.InterDataForSeqGroup",
chunked_prefill_enabled: bool,
):
self, inter_data: "ModelInputForGPUBuilder.InterDataForSeqGroup",
chunked_prefill_enabled: bool):
is_prompt = inter_data.is_prompt
block_tables = inter_data.block_tables
for (
seq_id,
token_len,
seq_len,
curr_seq_len,
query_len,
context_len,
curr_sliding_window_block,
) in zip(
inter_data.seq_ids,
[len(t) for t in inter_data.input_tokens],
inter_data.orig_seq_lens,
inter_data.seq_lens,
inter_data.query_lens,
inter_data.context_lens,
inter_data.curr_sliding_window_blocks,
):
for (seq_id, token_len, seq_len, curr_seq_len, query_len, context_len,
curr_sliding_window_block) in zip(
inter_data.seq_ids, [len(t) for t in inter_data.input_tokens],
inter_data.orig_seq_lens, inter_data.seq_lens,
inter_data.query_lens, inter_data.context_lens,
inter_data.curr_sliding_window_blocks):
self.context_lens.append(context_len)
if is_prompt:
mm_maps = inter_data.multi_modal_placeholder_maps
if mm_maps:
for modality, placeholders in mm_maps.items():
self.multimodal_placeholder_maps[modality].extend(placeholders)
self.multimodal_placeholder_maps[modality].extend(
placeholders)
self.num_prefills += 1
self.num_prefill_tokens += token_len
self.prefill_seq_lens.append(seq_len)
else:
assert (
query_len == 1
), "seq_len: {}, context_len: {}, query_len: {}".format(
seq_len, context_len, query_len
)
assert query_len == 1, (
"seq_len: {}, context_len: {}, query_len: {}".format(
seq_len, context_len, query_len))
self.num_decode_tokens += query_len
self.curr_seq_lens.append(curr_seq_len)
# Compute block table.
block_table = []
assert (
not chunked_prefill_enabled
), "chunk prefill not supported for kunlun attention"
assert not chunked_prefill_enabled, "chunk prefill not supported for kunlun attention"
if inter_data.prefix_cache_hit:
assert context_len != 0
assert context_len % self.block_size == 0
block_table = block_tables[seq_id][: context_len // self.block_size]
elif (not is_prompt) and block_tables is not None:
# block_table = block_tables[seq_id]
block_table = block_tables[seq_id][:context_len // self.block_size]
elif ((not is_prompt)
and block_tables is not None):
if curr_sliding_window_block == 0:
block_table = block_tables[seq_id]
else:
block_table = block_tables[seq_id][-curr_sliding_window_block:]
block_table = block_tables[seq_id][
-curr_sliding_window_block:]
self.block_tables.append(block_table)
if is_prompt:
self.prefix_cache_kv_lens.append(context_len)
# Compute slot mapping.
is_profile_run = is_block_tables_empty(block_tables)
start_idx = compute_slot_mapping_start_idx(
is_prompt, query_len, context_len, self.sliding_window
)
compute_slot_mapping(
is_profile_run,
self.slot_mapping,
seq_id,
seq_len,
context_len,
start_idx,
self.block_size,
inter_data.block_tables,
)
start_idx = compute_slot_mapping_start_idx(is_prompt, query_len,
context_len,
self.sliding_window)
compute_slot_mapping(is_profile_run, self.slot_mapping, seq_id,
seq_len, context_len, start_idx,
self.block_size, inter_data.block_tables)
def build(
self,
seq_lens: List[int],
query_lens: List[int],
cuda_graph_pad_size: int,
batch_size: int,
):
def build(self, seq_lens: List[int], query_lens: List[int],
cuda_graph_pad_size: int, batch_size: int):
"""build"""
attn_meta = super().build(seq_lens, query_lens, cuda_graph_pad_size, batch_size)
query_start_loc = list(accumulate(query_lens, initial=0))
query_start_loc_host = torch.tensor(
query_start_loc, dtype=torch.int32, device="cpu"
)
query_start_loc_host = torch.tensor(query_start_loc, dtype=torch.int32, device='cpu')
attn_meta.query_start_loc_host = query_start_loc_host
# max_kv_len = max(query_lens + prefix_cache_kv_lens)
attn_meta.max_kv_len = max(self.prefix_cache_kv_lens + attn_meta.seq_lens)
# If kv cache is included and there is a hit
# 包含kv cache ,且存在命中的情况
if len(self.prefix_cache_kv_lens) != 0 and max(self.prefix_cache_kv_lens) != 0:
self.prefix_cache_kv_lens = list(
accumulate(self.prefix_cache_kv_lens, initial=0)
)
prefix_cache_kv_lens_tensor = torch.tensor(
self.prefix_cache_kv_lens, dtype=torch.int32, device="cpu"
)
self.prefix_cache_kv_lens = list(accumulate(self.prefix_cache_kv_lens, initial=0))
prefix_cache_kv_lens_tensor = torch.tensor(self.prefix_cache_kv_lens, dtype=torch.int32, device="cpu")
attn_meta.kv_prefix_start_loc_host = prefix_cache_kv_lens_tensor
attn_meta.seq_lens_tensor_cpu = attn_meta.seq_lens_tensor.to("cpu")
return attn_meta
def _get_seq_len_block_table_args(
attn_metadata: KunlunMetadata,
is_prompt: bool,
attn_type: AttentionType,
) -> tuple:
"""
'''
The particular choice of sequence-length- and block-table-related
attributes which should be extracted from attn_metadata is dependent
on the type of attention operation.
@@ -517,7 +425,7 @@ def _get_seq_len_block_table_args(
* Appropriate sequence-lengths tensor
* Appropriate max sequence-length scalar
* Appropriate block tables (or None)
"""
'''
if attn_type == AttentionType.DECODER:
# Decoder self-attention
@@ -526,26 +434,23 @@ def _get_seq_len_block_table_args(
max_seq_len = attn_metadata.max_prefill_seq_len
else:
max_seq_len = attn_metadata.max_decode_seq_len
return (attn_metadata.seq_lens_tensor, max_seq_len, attn_metadata.block_tables)
return (attn_metadata.seq_lens_tensor, max_seq_len,
attn_metadata.block_tables)
elif attn_type == AttentionType.ENCODER_DECODER:
# Enc/dec cross-attention KVs match encoder sequence length;
# cross-attention utilizes special "cross" block tables
return (
attn_metadata.encoder_seq_lens_tensor,
attn_metadata.max_encoder_seq_len,
attn_metadata.cross_block_tables,
)
return (attn_metadata.encoder_seq_lens_tensor,
attn_metadata.max_encoder_seq_len,
attn_metadata.cross_block_tables)
elif attn_type == AttentionType.ENCODER:
# No block tables associated with encoder attention
return (
attn_metadata.encoder_seq_lens_tensor,
attn_metadata.max_encoder_seq_len,
None,
)
return (attn_metadata.encoder_seq_lens_tensor,
attn_metadata.max_encoder_seq_len, None)
else:
raise AttributeError(f"Invalid attention type {str(attn_type)}")
class KunlunAttentionImpl(AttentionImpl[KunlunMetadata]):
"""KunlunAttentionImpl"""
@@ -564,7 +469,8 @@ class KunlunAttentionImpl(AttentionImpl[KunlunMetadata]):
kv_sharing_target_layer_name: Optional[str] = None,
) -> None:
if blocksparse_params is not None:
raise ValueError("kunlunAttention does not support block-sparse attention.")
raise ValueError(
"kunlunAttention does not support block-sparse attention.")
# if logits_soft_cap is not None:
# raise ValueError(
# "kunlunAttention does not support attention logits soft capping.")
@@ -585,8 +491,8 @@ class KunlunAttentionImpl(AttentionImpl[KunlunMetadata]):
if head_size not in suppored_head_sizes:
raise ValueError(
f"Head size {head_size} is not supported by PagedAttention. "
f"Supported head sizes are: {suppored_head_sizes}."
)
f"Supported head sizes are: {suppored_head_sizes}.")
def forward(
self,
@@ -654,21 +560,16 @@ class KunlunAttentionImpl(AttentionImpl[KunlunMetadata]):
# Check that appropriate attention metadata attributes are
# selected for the desired attention type
if attn_type == AttentionType.ENCODER and (
not attn_metadata.is_all_encoder_attn_metadata_set
):
raise AttributeError(
"Encoder attention requires setting " "encoder metadata attributes."
)
if (attn_type == AttentionType.ENCODER
and (not attn_metadata.is_all_encoder_attn_metadata_set)):
raise AttributeError("Encoder attention requires setting "
"encoder metadata attributes.")
elif attn_type == AttentionType.ENCODER_DECODER and (
not attn_metadata.is_all_cross_attn_metadata_set
):
raise AttributeError(
"Encoder/decoder cross-attention "
"requires setting cross-attention "
"metadata attributes."
)
elif (attn_type == AttentionType.ENCODER_DECODER
and (not attn_metadata.is_all_cross_attn_metadata_set)):
raise AttributeError("Encoder/decoder cross-attention "
"requires setting cross-attention "
"metadata attributes.")
query = query.view(-1, self.num_heads, self.head_size)
if key is not None:
@@ -682,7 +583,7 @@ class KunlunAttentionImpl(AttentionImpl[KunlunMetadata]):
# which KV cache memory-mapping & which
# seqlen datastructures we utilize
if attn_type != AttentionType.ENCODER and kv_cache.numel() > 0:
if (attn_type != AttentionType.ENCODER and kv_cache.numel() > 0):
# KV-cache during decoder-self- or
# encoder-decoder-cross-attention, but not
# during encoder attention.
@@ -691,8 +592,7 @@ class KunlunAttentionImpl(AttentionImpl[KunlunMetadata]):
# we still need to break out key_cache and value_cache
# i.e. for later use by paged attention
key_cache, value_cache = PagedAttention.split_kv_cache(
kv_cache, self.num_kv_heads, self.head_size
)
kv_cache, self.num_kv_heads, self.head_size)
if (key is not None) and (value is not None):
@@ -701,14 +601,10 @@ class KunlunAttentionImpl(AttentionImpl[KunlunMetadata]):
else:
updated_slot_mapping = attn_metadata.slot_mapping
value = value.contiguous()
KunlunOps.reshape_and_cache(
key,
value,
key_cache,
value_cache,
updated_slot_mapping,
self.kv_cache_dtype,
)
KunlunOps.reshape_and_cache(key, value, key_cache,
value_cache,
updated_slot_mapping,
self.kv_cache_dtype)
if attn_type == AttentionType.ENCODER:
# Encoder attention - chunked prefill is not applicable;
@@ -753,20 +649,14 @@ class KunlunAttentionImpl(AttentionImpl[KunlunMetadata]):
# Prompt run.
if kv_cache.numel() == 0 or prefill_meta.block_tables.numel() == 0:
out = KunlunOps.multi_query_kv_attention(
prefill_meta.query_start_loc,
prefill_meta.query_start_loc_host,
query,
key,
value,
alibi_slopes=self.alibi_slopes,
).view_as(query)
prefill_meta.query_start_loc,prefill_meta.query_start_loc_host, query, key, value,
alibi_slopes=self.alibi_slopes).view_as(query)
assert output[:num_prefill_tokens].shape == out.shape
output[:num_prefill_tokens] = out
if decode_meta := attn_metadata.decode_metadata:
assert (
attn_type != AttentionType.ENCODER_ONLY
), "Encoder-only models should not have decode metadata."
assert attn_type != AttentionType.ENCODER_ONLY, (
"Encoder-only models should not have decode metadata.")
(
seq_lens_arg,
max_seq_len_arg,
@@ -791,4 +681,4 @@ class KunlunAttentionImpl(AttentionImpl[KunlunMetadata]):
)
# Reshape the output tensor.
return output.view(-1, self.num_heads * self.head_size)
return output.view(-1, self.num_heads * self.head_size)

View File

@@ -4,13 +4,12 @@ import torch
import torch.nn.functional as F
from typing import Optional, List, Dict, Any
from vllm.attention import AttentionType
from vllm.distributed.kv_transfer import (
get_kv_transfer_group,
has_kv_transfer_group,
is_v1_kv_transfer_group,
)
from vllm.distributed.kv_transfer import (get_kv_transfer_group,
has_kv_transfer_group,
is_v1_kv_transfer_group)
from vllm.config import CacheConfig
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.forward_context import ForwardContext, get_forward_context
@@ -20,10 +19,8 @@ from torch.library import custom_op, impl
from vllm.platforms import _Backend
class Attention(VllmAttention):
"""Attention"""
def __init__(
self,
num_heads: int,
@@ -75,8 +72,11 @@ class Attention(VllmAttention):
if attn_metadata.enable_kv_scales_calculation:
self.calc_kv_scales(query, key, value)
if self.use_output:
output_shape = output_shape if output_shape is not None else query.shape
output = torch.zeros(output_shape, dtype=query.dtype, device=query.device)
output_shape = (output_shape
if output_shape is not None else query.shape)
output = torch.zeros(output_shape,
dtype=query.dtype,
device=query.device)
hidden_size = output_shape[-1]
# We skip reshaping query, key and value tensors for the MLA
# backend since these tensors have different semantics and are
@@ -97,13 +97,16 @@ class Attention(VllmAttention):
if isinstance(attn_metadata, dict):
attn_metadata = attn_metadata[self.layer_name]
self_kv_cache = self.kv_cache[forward_context.virtual_engine]
self.impl.forward(
self, query, key, value, self_kv_cache, attn_metadata, output=output
)
self.impl.forward(self,
query,
key,
value,
self_kv_cache,
attn_metadata,
output=output)
else:
torch.ops.vllm.unified_attention_with_output_kunlun(
query, key, value, output, self.layer_name
)
query, key, value, output, self.layer_name)
return output.view(-1, hidden_size)
else:
if self.use_direct_call:
@@ -112,15 +115,13 @@ class Attention(VllmAttention):
if isinstance(attn_metadata, dict):
attn_metadata = attn_metadata[self.layer_name]
self_kv_cache = self.kv_cache[forward_context.virtual_engine]
return self.impl.forward(
self, query, key, value, self_kv_cache, attn_metadata
)
return self.impl.forward(self, query, key, value,
self_kv_cache, attn_metadata)
else:
return unified_attention(query, key, value, self.layer_name)
return unified_attention(
query, key, value, self.layer_name)
#
# Rewritten from the MultiHeadAttention class in vllm.attention.layer
# 重写自 vllm.attention.layer 中的 MultiHeadAttention 类
class MultiHeadAttention(VllmMultiHeadAttention):
def __init__(
self,
@@ -130,15 +131,14 @@ class MultiHeadAttention(VllmMultiHeadAttention):
num_kv_heads: Optional[int] = None,
):
super().__init__(
num_heads=num_heads,
head_size=head_size,
scale=scale,
num_kv_heads=num_kv_heads,
num_heads = num_heads,
head_size = head_size,
scale = scale,
num_kv_heads = num_kv_heads,
)
# kunlun only supports flash_attn
# kunlun只支持flash_attn
self.attn_backend = _Backend.FLASH_ATTN
def forward(
self,
query: torch.Tensor,
@@ -159,31 +159,34 @@ class MultiHeadAttention(VllmMultiHeadAttention):
key = torch.repeat_interleave(key, num_repeat, dim=2)
value = torch.repeat_interleave(value, num_repeat, dim=2)
# kunlun only supports flash_attn
# kunlun只支持flash_attn
if self.attn_backend == _Backend.FLASH_ATTN:
from flash_attn import flash_attn_func
out = flash_attn_func(query, key, value, softmax_scale=self.scale)
elif self.attn_backend == _Backend.XFORMERS:
from xformers import ops as xops
out = xops.memory_efficient_attention_forward(
query, key, value, scale=self.scale
)
out = xops.memory_efficient_attention_forward(query,
key,
value,
scale=self.scale)
elif self.attn_backend == _Backend.TORCH_SDPA:
query, key, value = (x.transpose(1, 2) for x in (query, key, value))
out = F.scaled_dot_product_attention(query, key, value, scale=self.scale)
query, key, value = (x.transpose(1, 2)
for x in (query, key, value))
out = F.scaled_dot_product_attention(query,
key,
value,
scale=self.scale)
out = out.transpose(1, 2)
elif self.attn_backend == _Backend.PALLAS_VLLM_V1:
query, key, value = (x.transpose(1, 2) for x in (query, key, value))
query, key, value = (x.transpose(1, 2)
for x in (query, key, value))
from torch_xla.experimental.custom_kernel import flash_attention
out = flash_attention(query, key, value, sm_scale=self.scale)
out = out.transpose(1, 2)
return out.reshape(bsz, q_len, -1)
def wait_for_kv_layer_from_connector(layer_name: str):
"""wait_for_kv_layer_from_connector"""
if not has_kv_transfer_group() or not is_v1_kv_transfer_group():
@@ -198,10 +201,9 @@ def wait_for_kv_layer_from_connector(layer_name: str):
assert isinstance(attn_metadata, dict)
connector.wait_for_layer_load(layer_name)
def maybe_save_kv_layer_to_connector(
layer_name: str, kv_cache_layer: List[torch.Tensor]
):
layer_name: str,
kv_cache_layer: List[torch.Tensor]):
"""maybe_save_kv_layer_to_connector"""
if not has_kv_transfer_group() or not is_v1_kv_transfer_group():
return
@@ -213,8 +215,8 @@ def maybe_save_kv_layer_to_connector(
if attn_metadata is None:
return
assert isinstance(attn_metadata, dict)
connector.save_kv_layer(layer_name, kv_cache_layer, attn_metadata[layer_name])
connector.save_kv_layer(layer_name, kv_cache_layer,
attn_metadata[layer_name])
@custom_op("vllm::unified_attention_with_output_kunlun", mutates_args=())
def unified_attention_with_output_kunlun(
@@ -223,8 +225,7 @@ def unified_attention_with_output_kunlun(
value: torch.Tensor,
output: torch.Tensor,
layer_name: str,
output_scale: Optional[torch.Tensor] = None,
) -> None:
output_scale: Optional[torch.Tensor] = None,) -> None:
wait_for_kv_layer_from_connector(layer_name)
forward_context: ForwardContext = get_forward_context()
attn_metadata = forward_context.attn_metadata
@@ -232,26 +233,26 @@ def unified_attention_with_output_kunlun(
attn_metadata = attn_metadata[layer_name]
self = forward_context.no_compile_layers[layer_name]
kv_cache = self.kv_cache[forward_context.virtual_engine]
self.impl.forward(self, query, key, value, kv_cache, attn_metadata, output=output)
self.impl.forward(self,
query,
key,
value,
kv_cache,
attn_metadata,
output=output)
maybe_save_kv_layer_to_connector(layer_name, kv_cache)
def _fake_unified_attention_with_output_kunlun(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
output: torch.Tensor,
layer_name: str,
output_scale: Optional[torch.Tensor] = None,
) -> None:
output_scale: Optional[torch.Tensor] = None,) -> None:
return None
unified_attention_with_output_kunlun.register_fake(
_fake_unified_attention_with_output_kunlun
)
unified_attention_with_output_kunlun.register_fake(_fake_unified_attention_with_output_kunlun)
def unified_attention(
query: torch.Tensor,
@@ -268,7 +269,8 @@ def unified_attention(
attn_metadata = attn_metadata[layer_name]
self = forward_context.no_compile_layers[layer_name]
kv_cache = self.kv_cache[forward_context.virtual_engine]
output = self.impl.forward(self, query, key, value, kv_cache, attn_metadata)
output = self.impl.forward(self, query, key, value, kv_cache,
attn_metadata)
maybe_save_kv_layer_to_connector(layer_name, kv_cache)
return output
return output

View File

@@ -0,0 +1,9 @@
from .chunk import chunk_gated_delta_rule
from .fused_recurrent import fused_recurrent_gated_delta_rule
from .layernorm_guard import RMSNormGated
from .torch_fla import l2norm, torch_chunk_gated_delta_rule
# Public API of the gated-delta-rule ops package. `l2norm` and
# `torch_chunk_gated_delta_rule` are imported above for re-export, so they
# are listed here as well (otherwise `from ... import *` would miss them).
__all__ = [
    "RMSNormGated",
    "chunk_gated_delta_rule",
    "fused_recurrent_gated_delta_rule",
    "l2norm",
    "torch_chunk_gated_delta_rule",
]

View File

@@ -0,0 +1,247 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang
#
# This file contains code copied from the flash-linear-attention project.
# The original source code was licensed under the MIT license and included
# the following copyright notice:
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
# ruff: noqa: E501
import warnings
from typing import Optional
import torch.nn.functional as F
import torch
import torch.distributed as dist
from einops import rearrange
from .chunk_delta_h import chunk_gated_delta_rule_fwd_h
from .chunk_o import chunk_fwd_o
from .chunk_scaled_dot_kkt import chunk_scaled_dot_kkt_fwd
from .cumsum import chunk_local_cumsum
from .l2norm import l2norm_fwd
from .solve_tril import solve_tril
from .utils import SUPPRESS_LEVEL, input_guard
from .wy_fast import recompute_w_u_fwd
def torch_solve_tril(A: torch.Tensor, cu_seqlens: Optional[torch.LongTensor] = None, output_dtype: torch.dtype = torch.float,):
    """Invert ``(I - L)`` chunk-wise, where ``L`` is the strictly lower
    triangular part of each 64x64 chunk of ``A``.

    Pure-torch replacement for the triton ``solve_tril`` kernel, using
    forward substitution: row ``i`` of the result depends only on already
    solved rows ``< i``, so a single in-place sweep over the rows suffices.

    Args:
        A: tensor of shape ``[B, T, H, 64]`` — the last dim must equal the
            chunk size (64) so the mask/identity broadcasts line up.
        cu_seqlens: unused; kept for signature compatibility with the
            triton kernel.
        output_dtype: unused; kept for signature compatibility (the result
            keeps ``A``'s dtype).

    Returns:
        Tensor of shape ``[B, T, H, 64]`` holding ``(I - tril(A, -1))^-1``
        for every 64-row chunk.
    """
    block = 64
    # Work in [B, H, T, C] layout so chunks are contiguous along dim -2.
    mat = A.transpose(1, 2)
    seq_len = mat.shape[-2]
    tail = (block - seq_len % block) % block
    mat = F.pad(mat, (0, 0, 0, tail))
    mat = mat.reshape(mat.shape[0], mat.shape[1], -1, block, mat.shape[-1])
    # Zero the diagonal and upper triangle: keep only the strictly lower part.
    upper = torch.triu(torch.ones(block, block, dtype=torch.bool, device=mat.device), diagonal=0)
    mat = mat.masked_fill(upper, 0)
    # Forward substitution: new_row[i] = L[i] + L[i] @ solved_rows[:i].
    for row_idx in range(1, block):
        lead = mat[..., row_idx, :row_idx].clone()
        solved = mat[..., :row_idx, :row_idx].clone()
        mat[..., row_idx, :row_idx] = lead + (lead.unsqueeze(-1) * solved).sum(-2)
    mat = mat + torch.eye(block, dtype=mat.dtype, device=mat.device)
    flat = mat.reshape(mat.shape[0], mat.shape[1], -1, mat.shape[-1])
    # Drop the padded rows and restore the [B, T, H, C] layout.
    return flat[:, :, :seq_len, :].transpose(1, 2)
def chunk_gated_delta_rule_fwd(q: torch.Tensor,
                               k: torch.Tensor,
                               v: torch.Tensor,
                               g: torch.Tensor,
                               beta: torch.Tensor,
                               scale: float,
                               initial_state: torch.Tensor,
                               output_final_state: bool,
                               cu_seqlens: Optional[torch.LongTensor] = None):
    """Chunked forward pass of the gated delta rule.

    Pipeline: local gate cumsum -> chunk-wise K.K^T score matrix ->
    triangular solve (torch fallback) -> w/u recomputation -> hidden-state
    recurrence -> output projection.

    Args:
        q, k, v: query/key/value tensors (assumed `[B, T, H, ...]` layout,
            as documented on `chunk_gated_delta_rule` — TODO confirm).
        g: (forget) gating tensor in log space.
        beta: per-token beta coefficients.
        scale: attention score scale factor.
        initial_state: optional initial recurrent state, or None.
        output_final_state: whether the final recurrent state is produced.
        cu_seqlens: optional cumulative sequence lengths `[N+1]` for
            variable-length inputs; None means a fixed-length batch.

    Returns:
        `(g, o, A, final_state, None, None, None)` when SUPPRESS_LEVEL < 3,
        else `(g, o, A, final_state, w, h, v_new)`.
    """
    g = chunk_local_cumsum(g, chunk_size=64, cu_seqlens=cu_seqlens)
    A = chunk_scaled_dot_kkt_fwd(k=k,
                                 beta=beta,
                                 g_cumsum=g,
                                 cu_seqlens=cu_seqlens,
                                 output_dtype=q.dtype)
    # Torch fallback for the triton `solve_tril` kernel.
    if cu_seqlens is None:
        # Fixed-length batch: every 64-token chunk is independent, so the
        # whole tensor can be solved in one call. (The previous code
        # unconditionally did `len(cu_seqlens)` and crashed on None.)
        A = torch_solve_tril(A=A, cu_seqlens=None, output_dtype=k.dtype)
    else:
        # Variable-length batch: solve each sequence separately so chunk
        # boundaries are aligned to each sequence start.
        for i in range(len(cu_seqlens) - 1):
            A_i = A[:, cu_seqlens[i]:cu_seqlens[i + 1], :, :]
            A[:, cu_seqlens[i]:cu_seqlens[i + 1], :, :] = torch_solve_tril(
                A=A_i,
                cu_seqlens=torch.tensor(
                    [0, cu_seqlens[i + 1] - cu_seqlens[i]], device=q.device),
                output_dtype=k.dtype)
    w, u = recompute_w_u_fwd(
        k=k,
        v=v,
        beta=beta,
        A=A,
        g_cumsum=g,
        cu_seqlens=cu_seqlens,
    )
    h, v_new, final_state = chunk_gated_delta_rule_fwd_h(
        k=k,
        w=w,
        u=u,
        g=g,
        initial_state=initial_state,
        output_final_state=output_final_state,
        cu_seqlens=cu_seqlens,
    )
    o = chunk_fwd_o(
        q=q,
        k=k,
        v=v_new,
        h=h,
        g=g,
        scale=scale,
        cu_seqlens=cu_seqlens,
    )
    # Higher suppress levels additionally expose the intermediates needed
    # for recomputation-free backward passes.
    if SUPPRESS_LEVEL < 3:
        return g, o, A, final_state, None, None, None
    return g, o, A, final_state, w, h, v_new
class ChunkGatedDeltaRuleFunction(torch.autograd.Function):
    """Autograd entry point for the chunked gated delta rule.

    Only the forward pass is defined here; calling backward through this
    Function is not supported in this block.
    """

    @staticmethod
    @input_guard
    @torch.amp.custom_fwd(device_type='cuda')
    def forward(ctx,
                q: torch.Tensor,
                k: torch.Tensor,
                v: torch.Tensor,
                g: torch.Tensor,
                beta: torch.Tensor,
                scale: float,
                initial_state: torch.Tensor,
                output_final_state: bool,
                cu_seqlens: Optional[torch.LongTensor] = None,
                use_qk_l2norm_in_kernel: bool = False):
        # Optionally fold the q/k L2 normalization into this call instead
        # of requiring the caller to pre-normalize.
        if use_qk_l2norm_in_kernel:
            q, k = l2norm_fwd(q), l2norm_fwd(k)
        results = chunk_gated_delta_rule_fwd(
            q=q,
            k=k,
            v=v,
            g=g,
            beta=beta,
            scale=scale,
            initial_state=initial_state,
            output_final_state=output_final_state,
            cu_seqlens=cu_seqlens,
        )
        # fwd returns (g, o, A, final_state, w, h, v_new); only the output
        # and the final recurrent state are consumed here.
        o, final_state = results[1], results[3]
        ctx.scale = scale
        ctx.use_qk_l2norm_in_kernel = use_qk_l2norm_in_kernel
        return o.to(q.dtype), final_state
@torch.compiler.disable
def chunk_gated_delta_rule(q: torch.Tensor,
                           k: torch.Tensor,
                           v: torch.Tensor,
                           g: torch.Tensor,
                           beta: torch.Tensor,
                           scale: Optional[float] = None,
                           initial_state: Optional[torch.Tensor] = None,
                           output_final_state: bool = False,
                           cu_seqlens: Optional[torch.LongTensor] = None,
                           head_first: bool = False,
                           use_qk_l2norm_in_kernel: bool = False):
    r"""
    Args:
        q (torch.Tensor):
            queries of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`.
        k (torch.Tensor):
            keys of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`.
        v (torch.Tensor):
            values of shape `[B, T, H, V]` if `head_first=False` else `[B, H, T, V]`.
        g (torch.Tensor):
            (forget) gating tensor (in log space!) of shape `[B, T, H]` if `head_first=False` else `[B, H, T]`.
        beta (torch.Tensor):
            betas of shape `[B, T, H]` if `head_first=False` else `[B, H, T]`.
        scale (Optional[int]):
            Scale factor for the RetNet attention scores.
            If not provided, it will default to `1 / sqrt(K)`. Default: `None`.
        initial_state (Optional[torch.Tensor]):
            Initial state of shape `[N, H, K, V]` for `N` input sequences.
            For equal-length input sequences, `N` equals the batch size `B`.
            Default: `None`.
        output_final_state (Optional[bool]):
            Whether to output the final state of shape `[N, H, K, V]`. Default: `False`.
        cu_seqlens (torch.LongTensor):
            Cumulative sequence lengths of shape `[N+1]` used for variable-length training,
            consistent with the FlashAttention API.
        head_first (Optional[bool]):
            Whether the inputs are in the head-first format, which is not supported for variable-length inputs.
            Default: `False`.

    Returns:
        o (torch.Tensor):
            Outputs of shape `[B, T, H, V]` if `head_first=False` else `[B, H, T, V]`.
        final_state (torch.Tensor):
            Final state of shape `[N, H, K, V]` if `output_final_state=True` else `None`.

    Examples::
        >>> import torch
        >>> import torch.nn.functional as F
        >>> from einops import rearrange
        >>> from fla.ops.gated_delta_rule import chunk_gated_delta_rule
        # inputs with equal lengths
        >>> B, T, H, K, V = 4, 2048, 4, 512, 512
        >>> q = torch.randn(B, T, H, K, dtype=torch.bfloat16, device='cuda')
        >>> k = F.normalize(torch.randn(B, T, H, K, dtype=torch.bfloat16, device='cuda'), p=2, dim=-1)
        >>> v = torch.randn(B, T, H, V, dtype=torch.bfloat16, device='cuda')
        >>> beta = torch.rand(B, T, H, dtype=torch.bfloat16, device='cuda').sigmoid()
        >>> g = F.logsigmoid(torch.rand(B, T, H, dtype=torch.bfloat16, device='cuda'))
        >>> h0 = torch.randn(B, H, K, V, dtype=torch.bfloat16, device='cuda')
        >>> o, ht = chunk_gated_delta_rule(
            q, k, v, g, beta,
            initial_state=h0,
            output_final_state=True
        )
        # for variable-length inputs, the batch size `B` is expected to be 1 and `cu_seqlens` is required
        >>> q, k, v, beta, g = map(lambda x: rearrange(x, 'b t ... -> 1 (b t) ...'), (q, k, v, beta, g))
        # for a batch with 4 sequences, `cu_seqlens` with 5 start/end positions are expected
        >>> cu_seqlens = q.new_tensor([0, 2048, 4096, 6144, 8192], dtype=torch.long)
        >>> o_var, ht_var = chunk_gated_delta_rule(
            q, k, v, g, beta,
            initial_state=h0,
            output_final_state=True,
            cu_seqlens=cu_seqlens
        )
    """
    assert q.dtype == k.dtype == v.dtype
    assert q.dtype != torch.float32, "ChunkGatedDeltaRuleFunction does not support float32. Please use bfloat16."
    assert len(
        beta.shape
    ) == 3, "beta must be of shape [B, T, H] if head_first=False, or [B, H, T] otherwise."

    if head_first:
        # Deprecated layout: emit a warning (the previous code *raised*
        # DeprecationWarning, aborting the call and making the layout
        # conversion below unreachable) and convert to [B, T, H, ...] so
        # the kernels always see the head-last layout.
        warnings.warn(
            "head_first is deprecated and will be removed in a future version. "
            "Please use head_first=False for now instead.",
            DeprecationWarning,
            stacklevel=2)
        q, k, v, beta, g = map(
            lambda x: rearrange(x, 'b h t ... -> b t h ...'),
            (q, k, v, beta, g))
    if not head_first and q.shape[1] < q.shape[2]:
        warnings.warn(
            f"Input tensor shape suggests potential format mismatch: seq_len ({q.shape[1]}) < num_heads ({q.shape[2]}). "
            "This may indicate the inputs were passed in head-first format [B, H, T, ...] "
            "when head_first=False was specified. "
            "Please verify your input tensor format matches the expected shape [B, T, H, ...].",
            stacklevel=2)
    if cu_seqlens is not None:
        if q.shape[0] != 1:
            raise ValueError(
                f"The batch size is expected to be 1 rather than {q.shape[0]} when using `cu_seqlens`. "
                f"Please flatten variable-length inputs before processing.")
        if initial_state is not None and initial_state.shape[0] != len(
                cu_seqlens) - 1:
            raise ValueError(
                f"The number of initial states is expected to be equal to the number of input sequences, "
                f"i.e., {len(cu_seqlens) - 1} rather than {initial_state.shape[0]}."
            )
    if scale is None:
        # Default scale follows the standard 1/sqrt(head_dim) convention.
        scale = k.shape[-1]**-0.5
    o, final_state = ChunkGatedDeltaRuleFunction.apply(
        q, k, v, g, beta, scale, initial_state, output_final_state, cu_seqlens,
        use_qk_l2norm_in_kernel)
    if head_first:
        # Convert the output back to the caller's [B, H, T, ...] layout.
        o = rearrange(o, 'b t h ... -> b h t ...')
    return o, final_state

View File

@@ -0,0 +1,251 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang
#
# This file contains code copied from the flash-linear-attention project.
# The original source code was licensed under the MIT license and included
# the following copyright notice:
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
# ruff: noqa: E501
from typing import Optional
import torch
from vllm.triton_utils import tl, triton
from .index import prepare_chunk_indices, prepare_chunk_offsets
from .op import exp
from .utils import is_nvidia_hopper, use_cuda_graph
NUM_WARPS = [2, 4] if is_nvidia_hopper else [2, 4, 8, 16]
@triton.heuristics(
    {
        # Feature flags derived from which optional pointers were passed.
        "USE_G": lambda args: args["g"] is not None,
        "USE_INITIAL_STATE": lambda args: args["h0"] is not None,
        "STORE_FINAL_STATE": lambda args: args["ht"] is not None,
        "SAVE_NEW_VALUE": lambda args: args["v_new"] is not None,
        "IS_VARLEN": lambda args: args["cu_seqlens"] is not None,
    }
)
@triton.jit(do_not_specialize=["T"])
def chunk_gated_delta_rule_fwd_kernel_h_blockdim64(
    # Chunkwise forward recurrence for the gated delta rule. One program
    # handles one (value-block, batch*head) pair and carries the running
    # [K, BV] state in registers as up to four fixed 64-row tiles along K
    # (b_h1..b_h4), so K may be at most 256.
    k,              # keys, flattened [B*T, Hg, K]
    v,              # input values, flattened [B*T, H, V]
    w,              # per-token update directions, flattened [B*T, H, K]
    v_new,          # optional output: corrected values v - w @ h
    g,              # optional cumulative log-decay gates, flattened [B*T, H]
    h,              # output per-chunk states, flattened [NT_total, H, K, V]
    h0,             # optional initial states, [N, H, K, V]
    ht,             # optional final-state output, [N, H, K, V]
    cu_seqlens,     # varlen cumulative sequence lengths (or None)
    chunk_offsets,  # per-sequence chunk start offsets (varlen only)
    T,
    H: tl.constexpr,
    Hg: tl.constexpr,   # number of key heads (H may be a multiple of Hg)
    K: tl.constexpr,
    V: tl.constexpr,
    BT: tl.constexpr,   # time-chunk size
    BV: tl.constexpr,   # value-block size
    USE_G: tl.constexpr,
    USE_INITIAL_STATE: tl.constexpr,
    STORE_FINAL_STATE: tl.constexpr,
    SAVE_NEW_VALUE: tl.constexpr,
    IS_VARLEN: tl.constexpr,
):
    i_v, i_nh = tl.program_id(0), tl.program_id(1)
    i_n, i_h = i_nh // H, i_nh % H
    if IS_VARLEN:
        # Sequence boundaries and this sequence's first-chunk offset in `h`.
        bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32)
        T = eos - bos
        NT = tl.cdiv(T, BT)
        boh = tl.load(chunk_offsets + i_n).to(tl.int32)
    else:
        bos, eos = i_n * T, i_n * T + T
        NT = tl.cdiv(T, BT)
        boh = i_n * NT
    # [BK, BV]
    # Running state tiles (fp32 accumulators), one per 64 rows of K.
    b_h1 = tl.zeros([64, BV], dtype=tl.float32)
    if K > 64:
        b_h2 = tl.zeros([64, BV], dtype=tl.float32)
    if K > 128:
        b_h3 = tl.zeros([64, BV], dtype=tl.float32)
    if K > 192:
        b_h4 = tl.zeros([64, BV], dtype=tl.float32)
    # calculate offset
    h += (boh * H + i_h) * K * V
    v += (bos * H + i_h) * V
    k += (bos * Hg + i_h // (H // Hg)) * K
    w += (bos * H + i_h) * K
    if SAVE_NEW_VALUE:
        v_new += (bos * H + i_h) * V
    stride_v = H * V
    stride_h = H * K * V
    stride_k = Hg * K
    stride_w = H * K
    if USE_INITIAL_STATE:
        h0 = h0 + i_nh * K * V
    if STORE_FINAL_STATE:
        ht = ht + i_nh * K * V
    # load initial state
    if USE_INITIAL_STATE:
        p_h0_1 = tl.make_block_ptr(h0, (K, V), (V, 1), (0, i_v * BV), (64, BV), (1, 0))
        b_h1 += tl.load(p_h0_1, boundary_check=(0, 1)).to(tl.float32)
        if K > 64:
            p_h0_2 = tl.make_block_ptr(h0, (K, V), (V, 1), (64, i_v * BV), (64, BV), (1, 0))
            b_h2 += tl.load(p_h0_2, boundary_check=(0, 1)).to(tl.float32)
        if K > 128:
            p_h0_3 = tl.make_block_ptr(h0, (K, V), (V, 1), (128, i_v * BV), (64, BV), (1, 0))
            b_h3 += tl.load(p_h0_3, boundary_check=(0, 1)).to(tl.float32)
        if K > 192:
            p_h0_4 = tl.make_block_ptr(h0, (K, V), (V, 1), (192, i_v * BV), (64, BV), (1, 0))
            b_h4 += tl.load(p_h0_4, boundary_check=(0, 1)).to(tl.float32)
    # main recurrence
    for i_t in range(NT):
        # Store the state *entering* chunk i_t before updating it.
        p_h1 = tl.make_block_ptr(h + i_t * stride_h, (K, V), (V, 1), (0, i_v * BV), (64, BV), (1, 0))
        tl.store(p_h1, b_h1.to(p_h1.dtype.element_ty), boundary_check=(0, 1))
        if K > 64:
            p_h2 = tl.make_block_ptr(h + i_t * stride_h, (K, V), (V, 1), (64, i_v * BV), (64, BV), (1, 0))
            tl.store(p_h2, b_h2.to(p_h2.dtype.element_ty), boundary_check=(0, 1))
        if K > 128:
            p_h3 = tl.make_block_ptr(h + i_t * stride_h, (K, V), (V, 1), (128, i_v * BV), (64, BV), (1, 0))
            tl.store(p_h3, b_h3.to(p_h3.dtype.element_ty), boundary_check=(0, 1))
        if K > 192:
            p_h4 = tl.make_block_ptr(h + i_t * stride_h, (K, V), (V, 1), (192, i_v * BV), (64, BV), (1, 0))
            tl.store(p_h4, b_h4.to(p_h4.dtype.element_ty), boundary_check=(0, 1))
        p_v = tl.make_block_ptr(v, (T, V), (stride_v, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
        p_v_new = (
            tl.make_block_ptr(v_new, (T, V), (stride_v, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
            if SAVE_NEW_VALUE
            else None
        )
        # Accumulate w @ h over the 64-row K tiles; combined with the load of
        # v below this yields v_new = v - w @ h.
        b_v_new = tl.zeros([BT, BV], dtype=tl.float32)
        p_w = tl.make_block_ptr(w, (T, K), (stride_w, 1), (i_t * BT, 0), (BT, 64), (1, 0))
        b_w = tl.load(p_w, boundary_check=(0, 1))
        b_v_new += tl.dot(b_w, b_h1.to(b_w.dtype))
        if K > 64:
            p_w = tl.make_block_ptr(w, (T, K), (stride_w, 1), (i_t * BT, 64), (BT, 64), (1, 0))
            b_w = tl.load(p_w, boundary_check=(0, 1))
            b_v_new += tl.dot(b_w, b_h2.to(b_w.dtype))
        if K > 128:
            p_w = tl.make_block_ptr(w, (T, K), (stride_w, 1), (i_t * BT, 128), (BT, 64), (1, 0))
            b_w = tl.load(p_w, boundary_check=(0, 1))
            b_v_new += tl.dot(b_w, b_h3.to(b_w.dtype))
        if K > 192:
            p_w = tl.make_block_ptr(w, (T, K), (stride_w, 1), (i_t * BT, 192), (BT, 64), (1, 0))
            b_w = tl.load(p_w, boundary_check=(0, 1))
            b_v_new += tl.dot(b_w, b_h4.to(b_w.dtype))
        b_v_new = -b_v_new + tl.load(p_v, boundary_check=(0, 1))
        if SAVE_NEW_VALUE:
            p_v_new = tl.make_block_ptr(v_new, (T, V), (stride_v, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
            tl.store(p_v_new, b_v_new.to(p_v_new.dtype.element_ty), boundary_check=(0, 1))
        if USE_G:
            # Gated decay: the fresh update is scaled by exp(g_last - g_t)
            # (masked to valid positions) and the carried state by
            # exp(g_last), where g_last is the gate at the chunk's last
            # valid position. NOTE: uses tl.exp, not the `.op` exp helper.
            m_t = (i_t * BT + tl.arange(0, BT)) < T
            last_idx = min((i_t + 1) * BT, T) - 1
            b_g_last = tl.load(g + bos * H + last_idx * H + i_h)
            p_g = tl.make_block_ptr(g + bos * H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,))
            b_g = tl.load(p_g, boundary_check=(0,))
            b_v_new = b_v_new * tl.where(m_t, tl.exp(b_g_last - b_g), 0)[:, None]
            b_g_last = tl.exp(b_g_last)
            b_h1 = b_h1 * b_g_last
            if K > 64:
                b_h2 = b_h2 * b_g_last
            if K > 128:
                b_h3 = b_h3 * b_g_last
            if K > 192:
                b_h4 = b_h4 * b_g_last
        b_v_new = b_v_new.to(k.dtype.element_ty)
        # State update: h += k^T @ v_new, one 64-row K tile at a time.
        p_k = tl.make_block_ptr(k, (K, T), (1, stride_k), (0, i_t * BT), (64, BT), (0, 1))
        b_k = tl.load(p_k, boundary_check=(0, 1))
        b_h1 += tl.dot(b_k, b_v_new)
        if K > 64:
            p_k = tl.make_block_ptr(k, (K, T), (1, stride_k), (64, i_t * BT), (64, BT), (0, 1))
            b_k = tl.load(p_k, boundary_check=(0, 1))
            b_h2 += tl.dot(b_k, b_v_new)
        if K > 128:
            p_k = tl.make_block_ptr(k, (K, T), (1, stride_k), (128, i_t * BT), (64, BT), (0, 1))
            b_k = tl.load(p_k, boundary_check=(0, 1))
            b_h3 += tl.dot(b_k, b_v_new)
        if K > 192:
            p_k = tl.make_block_ptr(k, (K, T), (1, stride_k), (192, i_t * BT), (64, BT), (0, 1))
            b_k = tl.load(p_k, boundary_check=(0, 1))
            b_h4 += tl.dot(b_k, b_v_new)
    # epilogue
    if STORE_FINAL_STATE:
        # Write the state after the last chunk back to global memory.
        p_ht = tl.make_block_ptr(ht, (K, V), (V, 1), (0, i_v * BV), (64, BV), (1, 0))
        tl.store(p_ht, b_h1.to(p_ht.dtype.element_ty), boundary_check=(0, 1))
        if K > 64:
            p_ht = tl.make_block_ptr(ht, (K, V), (V, 1), (64, i_v * BV), (64, BV), (1, 0))
            tl.store(p_ht, b_h2.to(p_ht.dtype.element_ty), boundary_check=(0, 1))
        if K > 128:
            p_ht = tl.make_block_ptr(ht, (K, V), (V, 1), (128, i_v * BV), (64, BV), (1, 0))
            tl.store(p_ht, b_h3.to(p_ht.dtype.element_ty), boundary_check=(0, 1))
        if K > 192:
            p_ht = tl.make_block_ptr(ht, (K, V), (V, 1), (192, i_v * BV), (64, BV), (1, 0))
            tl.store(p_ht, b_h4.to(p_ht.dtype.element_ty), boundary_check=(0, 1))
def chunk_gated_delta_rule_fwd_h(
    k: torch.Tensor,
    w: torch.Tensor,
    u: torch.Tensor,
    g: Optional[torch.Tensor] = None,
    initial_state: Optional[torch.Tensor] = None,
    output_final_state: bool = False,
    chunk_size: int = 64,  # SY: remove this argument and force chunk size 64?
    save_new_value: bool = True,
    cu_seqlens: Optional[torch.LongTensor] = None,
) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
    """Launch the chunkwise state-recurrence kernel of the gated delta rule.

    Args:
        k: keys of shape `[B, T, Hg, K]`.
        w: per-token update directions of shape `[B, T, H, K]`.
        u: values of shape `[B, T, H, V]`.
        g: optional cumulative log-decay gates (presumably `[B, T, H]` —
            matches the kernel's stride-H indexing; confirm against callers).
        initial_state: optional initial states of shape `[N, H, K, V]`.
        output_final_state: if True, also materialize the final states.
        chunk_size: time-chunk length `BT`.
        save_new_value: if True, also materialize the corrected values.
        cu_seqlens: optional cumulative sequence lengths for varlen inputs.

    Returns:
        ``(h, v_new, final_state)``: per-chunk states `[B, NT, H, K, V]`,
        the corrected values (or None), and the final states in fp32
        (or None). Note: this returns THREE values.
    """
    B, T, Hg, K, V = *k.shape, u.shape[-1]
    H = u.shape[-2]
    BT = chunk_size
    chunk_indices = prepare_chunk_indices(
        cu_seqlens, chunk_size) if cu_seqlens is not None else None
    # N: the actual number of sequences in the batch with either equal or variable lengths
    if cu_seqlens is None:
        N, NT, chunk_offsets = B, triton.cdiv(T, BT), None
    else:
        N, NT, chunk_offsets = len(cu_seqlens) - 1, len(
            chunk_indices), prepare_chunk_offsets(cu_seqlens, BT)
    assert K <= 256, "current kernel does not support head dimension larger than 256."
    h = k.new_empty(B, NT, H, K, V)
    final_state = k.new_empty(
        N, H, K, V, dtype=torch.float32) if output_final_state else None
    v_new = torch.empty_like(u) if save_new_value else None

    def grid(meta):
        return (triton.cdiv(V, meta['BV']), N * H)

    # Autotuning is not used; the value-block size is fixed at BV=64.
    chunk_gated_delta_rule_fwd_kernel_h_blockdim64[grid](
        k=k,
        v=u,
        w=w,
        v_new=v_new,
        g=g,
        h=h,
        h0=initial_state,
        ht=final_state,
        cu_seqlens=cu_seqlens,
        chunk_offsets=chunk_offsets,
        T=T,
        H=H,
        Hg=Hg,
        K=K,
        V=V,
        BT=BT,
        BV=64,
    )
    return h, v_new, final_state

View File

@@ -0,0 +1,180 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang
#
# This file contains code copied from the flash-linear-attention project.
# The original source code was licensed under the MIT license and included
# the following copyright notice:
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
# ruff: noqa: E501
from typing import Optional
import torch
from vllm.triton_utils import tl, triton
from .index import prepare_chunk_indices
from .op import exp
from .utils import FLA_GDN_FIX_BT, check_shared_mem, is_nvidia_hopper
BKV_LIST = [64, 128] if check_shared_mem() else [32, 64]
NUM_WARPS = [2, 4] if is_nvidia_hopper else [2, 4, 8]
@triton.heuristics({
    'USE_G': lambda args: args['g'] is not None,
    'IS_VARLEN': lambda args: args['cu_seqlens'] is not None
})
# @triton.autotune(
#     configs=[
#         triton.Config({
#             'BK': BK,
#             'BV': BV
#         },
#                       num_warps=num_warps,
#                       num_stages=num_stages) for BK in BKV_LIST
#         for BV in BKV_LIST for num_warps in NUM_WARPS
#         for num_stages in [2, 3, 4]
#     ],
#     key=['H', 'K', 'V', 'BT'],
# )
@triton.jit(do_not_specialize=['T'])
def chunk_fwd_kernel_o(
    # Combine the inter-chunk contribution (q @ h) with the masked
    # intra-chunk attention (q @ k^T @ v) to produce the output `o`.
    # One program handles one (value-block, chunk, batch*head) triple.
    q,
    k,
    v,
    h,              # per-chunk hidden states, indexed by global chunk id i_tg
    g,              # optional cumulative log-decay gates
    o,
    cu_seqlens,
    chunk_indices,
    scale,          # score scaling factor (typically K**-0.5)
    T,
    H: tl.constexpr,
    Hg: tl.constexpr,   # number of q/k heads (H may be a multiple of Hg)
    K: tl.constexpr,
    V: tl.constexpr,
    BT: tl.constexpr,
    BK: tl.constexpr,
    BV: tl.constexpr,
    USE_G: tl.constexpr,
    IS_VARLEN: tl.constexpr,
):
    i_v, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    i_b, i_h = i_bh // H, i_bh % H
    if IS_VARLEN:
        # chunk_indices maps the flattened chunk id to (sequence, chunk).
        i_tg = i_t
        i_n, i_t = tl.load(chunk_indices + i_t * 2).to(
            tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32)
        bos, eos = tl.load(cu_seqlens + i_n).to(
            tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32)
        T = eos - bos
        NT = tl.cdiv(T, BT)
    else:
        NT = tl.cdiv(T, BT)
        i_tg = i_b * NT + i_t
        bos, eos = i_b * T, i_b * T + T
    # offset calculation
    q += (bos * Hg + i_h // (H // Hg)) * K
    k += (bos * Hg + i_h // (H // Hg)) * K
    v += (bos * H + i_h) * V
    o += (bos * H + i_h) * V
    h += (i_tg * H + i_h).to(tl.int64) * K * V
    b_o = tl.zeros([BT, BV], dtype=tl.float32)  # inter-chunk output block
    b_A = tl.zeros([BT, BT], dtype=tl.float32)  # intra-chunk score block
    for i_k in range(tl.cdiv(K, BK)):
        p_q = tl.make_block_ptr(q, (T, K), (Hg * K, 1), (i_t * BT, i_k * BK),
                                (BT, BK), (1, 0))
        p_k = tl.make_block_ptr(k, (K, T), (1, Hg * K), (i_k * BK, i_t * BT),
                                (BK, BT), (0, 1))
        p_h = tl.make_block_ptr(h, (K, V), (V, 1), (i_k * BK, i_v * BV),
                                (BK, BV), (1, 0))
        # [BT, BK]
        b_q = tl.load(p_q, boundary_check=(0, 1))
        # [BK, BT]
        b_k = tl.load(p_k, boundary_check=(0, 1))
        # [BK, BV]
        b_h = tl.load(p_h, boundary_check=(0, 1))
        # [BT, BK] @ [BK, BV] -> [BT, BV]
        b_o += tl.dot(b_q, b_h)
        # [BT, BK] @ [BK, BT] -> [BT, BT]
        b_A += tl.dot(b_q, b_k)
    if USE_G:
        # Decay: inter-chunk part scaled by exp(g_t), intra-chunk scores by
        # exp(g_i - g_j). NOTE: uses tl.exp, not the `.op` exp helper.
        g += bos * H + i_h
        p_g = tl.make_block_ptr(g, (T, ), (H, ), (i_t * BT, ), (BT, ), (0, ))
        b_g = tl.load(p_g, boundary_check=(0, ))
        b_o = b_o * tl.exp(b_g)[:, None]
        b_A = b_A * tl.exp(b_g[:, None] - b_g[None, :])
    o_t = i_t * BT + tl.arange(0, BT)
    # m_t = o_t < T
    # m_A = (o_t[:, None] >= o_t[None, :]) & (m_t[:, None] & m_t)
    # b_A = tl.where(m_A, b_A, 0)
    # Causal mask: keep only scores where query position >= key position.
    b_A = tl.where(o_t[:, None] >= o_t[None, :], b_A, 0)
    p_v = tl.make_block_ptr(v, (T, V), (H * V, 1), (i_t * BT, i_v * BV),
                            (BT, BV), (1, 0))
    p_o = tl.make_block_ptr(o, (T, V), (H * V, 1), (i_t * BT, i_v * BV),
                            (BT, BV), (1, 0))
    b_v = tl.load(p_v, boundary_check=(0, 1))
    # to fix mma -> mma layout conversion
    # already solved by triton v3.2 or higher
    b_o = b_o * scale + tl.dot(b_A.to(b_v.dtype), b_v) * scale
    tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))
def chunk_fwd_o(
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        h: torch.Tensor,
        g: Optional[torch.Tensor] = None,  # cumsum of log decay
        scale: Optional[float] = None,
        cu_seqlens: Optional[torch.LongTensor] = None,
        chunk_size: int = 64) -> torch.Tensor:
    """Compute the chunkwise output ``o`` from q/k/v and the per-chunk
    hidden states ``h`` by launching ``chunk_fwd_kernel_o``.

    ``scale`` defaults to ``K**-0.5``; ``cu_seqlens`` enables varlen mode.
    Returns a tensor with the same shape and dtype as ``v``.
    """
    B, T, Hg, K = q.shape
    V = v.shape[-1]
    H = v.shape[-2]
    # Optionally pin the chunk size for GDN; otherwise clamp it to a
    # power-of-two block no smaller than 16 and no larger than T.
    if FLA_GDN_FIX_BT:
        BT = 64
    else:
        BT = min(chunk_size, max(16, triton.next_power_of_2(T)))
    if cu_seqlens is None:
        chunk_indices = None
        NT = triton.cdiv(T, BT)
    else:
        chunk_indices = prepare_chunk_indices(cu_seqlens, BT)
        NT = len(chunk_indices)
    if scale is None:
        scale = k.shape[-1]**-0.5
    out = torch.empty_like(v)

    grid = lambda meta: (triton.cdiv(V, meta['BV']), NT, B * H)
    chunk_fwd_kernel_o[grid](
        q=q,
        k=k,
        v=v,
        h=h,
        g=g,
        o=out,
        cu_seqlens=cu_seqlens,
        chunk_indices=chunk_indices,
        scale=scale,
        T=T,
        H=H,
        Hg=Hg,
        K=K,
        V=V,
        BT=BT,
        BK=64,
        BV=32,
    )
    return out

View File

@@ -0,0 +1,144 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang
#
# This file contains code copied from the flash-linear-attention project.
# The original source code was licensed under the MIT license and included
# the following copyright notice:
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
# ruff: noqa: E501
from typing import Optional
import torch
from vllm.triton_utils import tl, triton
from .index import prepare_chunk_indices
from .op import exp
@triton.heuristics({
    'IS_VARLEN': lambda args: args['cu_seqlens'] is not None,
    'USE_G': lambda args: args['g_cumsum'] is not None
})
# @triton.autotune(
#     configs=[
#         triton.Config({'BK': BK}, num_warps=num_warps, num_stages=num_stages)
#         for BK in [32, 64, 128] for num_warps in [2, 4, 8]
#         for num_stages in [2, 3, 4]
#     ],
#     key=['H', 'K', 'BT', 'IS_VARLEN'],
# )
@triton.jit(do_not_specialize=['T'])
def chunk_scaled_dot_kkt_fwd_kernel(
    # Per chunk, compute A = (beta * k) @ k^T, optionally scaled by
    # exp(g_i - g_j), and keep only the strictly lower triangle.
    # One program handles one (chunk, batch*head) pair.
    k,
    beta,
    g_cumsum,
    A,
    cu_seqlens,
    chunk_indices,
    T,
    H: tl.constexpr,
    Hg: tl.constexpr,   # number of key heads (H may be a multiple of Hg)
    K: tl.constexpr,
    BT: tl.constexpr,
    BK: tl.constexpr,
    IS_VARLEN: tl.constexpr,
    USE_G: tl.constexpr,
):
    i_t, i_bh = tl.program_id(0), tl.program_id(1)
    i_b, i_h = i_bh // H, i_bh % H
    if IS_VARLEN:
        # chunk_indices maps the flattened chunk id to (sequence, chunk).
        i_n, i_t = tl.load(chunk_indices + i_t * 2).to(
            tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32)
        bos, eos = tl.load(cu_seqlens + i_n).to(
            tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32)
        T = eos - bos
    else:
        bos, eos = i_b * T, i_b * T + T
    o_t = i_t * BT + tl.arange(0, BT)
    #m_t = o_t < T
    p_beta = tl.make_block_ptr(beta + bos * H + i_h, (T, ), (H, ),
                               (i_t * BT, ), (BT, ), (0, ))
    b_beta = tl.load(p_beta, boundary_check=(0, ))
    b_A = tl.zeros([BT, BT], dtype=tl.float32)
    # Accumulate the Gram matrix over BK-wide slices of the K dimension.
    for i_k in range(tl.cdiv(K, BK)):
        p_k = tl.make_block_ptr(k + (bos * Hg + i_h // (H // Hg)) * K, (T, K),
                                (Hg * K, 1), (i_t * BT, i_k * BK), (BT, BK),
                                (1, 0))
        b_k = tl.load(p_k, boundary_check=(0, 1))
        b_kb = b_k * b_beta[:, None]
        b_A += tl.dot(b_kb.to(b_k.dtype), tl.trans(b_k))
    if USE_G:
        p_g = tl.make_block_ptr(g_cumsum + bos * H + i_h, (T, ), (H, ),
                                (i_t * BT, ), (BT, ), (0, ))
        b_g = tl.load(p_g, boundary_check=(0, ))
        b_g_diff = b_g[:, None] - b_g[None, :]
        b_A = b_A * tl.exp(b_g_diff)  # uses triton's exp rather than the vllm `exp` helper
    #m_A = (o_t[:, None] > o_t[None, :]) & (m_t[:, None] & m_t)
    #b_A = tl.where(m_A, b_A, 0)
    # Strict lower-triangular mask: zero the diagonal and upper triangle.
    b_A = tl.where(o_t[:, None] > o_t[None, :], b_A, 0)
    p_A = tl.make_block_ptr(A + (bos * H + i_h) * BT, (T, BT), (BT * H, 1),
                            (i_t * BT, 0), (BT, BT), (1, 0))
    tl.store(p_A, b_A.to(p_A.dtype.element_ty), boundary_check=(0, 1))
def chunk_scaled_dot_kkt_fwd(
        k: torch.Tensor,
        beta: torch.Tensor,
        g_cumsum: Optional[torch.Tensor] = None,
        cu_seqlens: Optional[torch.LongTensor] = None,
        chunk_size: int = 64,
        output_dtype: torch.dtype = torch.float32) -> torch.Tensor:
    r"""Chunkwise ``diag(beta) @ K @ K^T`` with a strict lower-triangular mask.

    Args:
        k (torch.Tensor):
            The key tensor of shape `[B, T, H, K]`.
        beta (torch.Tensor):
            The beta tensor of shape `[B, T, H]`.
        g_cumsum (torch.Tensor):
            The cumulative sum of the gate tensor of shape `[B, T, H]`;
            when given, entry (i, j) is additionally scaled by
            ``exp(g_i - g_j)``. Default: None
        cu_seqlens (torch.LongTensor):
            The cumulative sequence lengths of the input tensor.
            Default: None
        chunk_size (int):
            The chunk size. Default: 64.
        output_dtype (torch.dtype):
            The dtype of the output tensor. Default: `torch.float32`

    Returns:
        beta * K * K^T of shape `[B, T, H, BT]` where `BT` is the chunk size.
    """
    B, T, Hg, K = k.shape
    H = beta.shape[-1]
    BT = chunk_size
    if cu_seqlens is None:
        chunk_indices = None
        num_chunks = triton.cdiv(T, BT)
    else:
        chunk_indices = prepare_chunk_indices(cu_seqlens, BT)
        num_chunks = len(chunk_indices)
    A = torch.empty(B, T, H, BT, device=k.device, dtype=output_dtype)
    # One program per (chunk, batch*head); BK fixed since autotune is off.
    chunk_scaled_dot_kkt_fwd_kernel[(num_chunks, B * H)](
        k=k,
        beta=beta,
        g_cumsum=g_cumsum,
        A=A,
        cu_seqlens=cu_seqlens,
        chunk_indices=chunk_indices,
        T=T,
        H=H,
        Hg=Hg,
        K=K,
        BT=BT,
        BK=64,
    )
    return A

View File

@@ -0,0 +1,229 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang
#
# This file contains code copied from the flash-linear-attention project.
# The original source code was licensed under the MIT license and included
# the following copyright notice:
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
# ruff: noqa: E501
import warnings
from typing import Optional
import torch
from vllm.triton_utils import tl, triton
from .index import prepare_chunk_indices
from .utils import check_shared_mem, input_guard
BS_LIST = [32, 64] if check_shared_mem() else [16, 32]
@triton.heuristics({'IS_VARLEN': lambda args: args['cu_seqlens'] is not None})
# @triton.autotune(configs=[
#     triton.Config({}, num_warps=num_warps) for num_warps in [1, 2, 4, 8]
# ],
#                  key=['B', 'H', 'BT', 'IS_VARLEN', 'REVERSE'])
@triton.jit(do_not_specialize=['T'])
def chunk_local_cumsum_scalar_kernel(
    # Inclusive cumulative sum of scalar gates `s` within each length-BT
    # chunk, written to `o`. One program per (chunk, batch*head) pair.
    s,
    o,
    cu_seqlens,
    chunk_indices,
    T,
    B: tl.constexpr,
    H: tl.constexpr,
    BT: tl.constexpr,
    REVERSE: tl.constexpr,     # if set, compute the reversed (suffix) cumsum
    IS_VARLEN: tl.constexpr,
    HEAD_FIRST: tl.constexpr,  # input laid out [B, H, T] instead of [B, T, H]
):
    i_t, i_bh = tl.program_id(0), tl.program_id(1)
    i_b, i_h = i_bh // H, i_bh % H
    if IS_VARLEN:
        # chunk_indices maps the flattened chunk id to (sequence, chunk).
        i_n, i_t = tl.load(chunk_indices + i_t * 2).to(
            tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32)
        bos, eos = tl.load(cu_seqlens + i_n).to(
            tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32)
        T = eos - bos
    else:
        bos, eos = i_b * T, i_b * T + T
    if HEAD_FIRST:
        p_s = tl.make_block_ptr(s + bos * H + i_h * T, (T, ), (1, ),
                                (i_t * BT, ), (BT, ), (0, ))
        p_o = tl.make_block_ptr(o + bos * H + i_h * T, (T, ), (1, ),
                                (i_t * BT, ), (BT, ), (0, ))
    else:
        p_s = tl.make_block_ptr(s + bos * H + i_h, (T, ), (H, ), (i_t * BT, ),
                                (BT, ), (0, ))
        p_o = tl.make_block_ptr(o + bos * H + i_h, (T, ), (H, ), (i_t * BT, ),
                                (BT, ), (0, ))
    # [BT]
    b_s = tl.load(p_s, boundary_check=(0, )).to(tl.float32)
    b_o = tl.cumsum(b_s, axis=0)
    if REVERSE:
        # Suffix sum: total - prefix + current element.
        b_z = tl.sum(b_s, axis=0)
        b_o = -b_o + b_z[None] + b_s
    tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, ))
@triton.heuristics({'IS_VARLEN': lambda args: args['cu_seqlens'] is not None})
# @triton.autotune(configs=[
#     triton.Config({'BS': BS}, num_warps=num_warps) for BS in BS_LIST
#     for num_warps in [2, 4, 8]
# ],
#                  key=['B', 'H', 'S', 'BT', 'IS_VARLEN', 'REVERSE'])
@triton.jit(do_not_specialize=['T'])
def chunk_local_cumsum_vector_kernel(
    # Inclusive cumulative sum of vector-valued gates `s` (feature dim S)
    # within each length-BT chunk, written to `o`. The cumsum is realized
    # as a matmul with a triangular 0/1 mask. One program per
    # (feature-block, chunk, batch*head) triple.
    s,
    o,
    cu_seqlens,
    chunk_indices,
    T,
    B: tl.constexpr,
    H: tl.constexpr,
    S: tl.constexpr,
    BT: tl.constexpr,
    BS: tl.constexpr,          # feature-block size
    REVERSE: tl.constexpr,     # if set, compute the reversed (suffix) cumsum
    IS_VARLEN: tl.constexpr,
    HEAD_FIRST: tl.constexpr,  # input laid out [B, H, T, S] instead of [B, T, H, S]
):
    i_s, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    i_b, i_h = i_bh // H, i_bh % H
    if IS_VARLEN:
        # chunk_indices maps the flattened chunk id to (sequence, chunk).
        i_n, i_t = tl.load(chunk_indices + i_t * 2).to(
            tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32)
        bos, eos = tl.load(cu_seqlens + i_n).to(
            tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32)
        T = eos - bos
    else:
        bos, eos = i_b * T, i_b * T + T
    o_i = tl.arange(0, BT)
    # Lower-triangular mask gives a prefix sum; upper-triangular a suffix sum.
    if REVERSE:
        m_s = tl.where(o_i[:, None] <= o_i[None, :], 1., 0.)
    else:
        m_s = tl.where(o_i[:, None] >= o_i[None, :], 1., 0.)
    if HEAD_FIRST:
        p_s = tl.make_block_ptr(s + (bos * H + i_h * T) * S, (T, S), (S, 1),
                                (i_t * BT, i_s * BS), (BT, BS), (1, 0))
        p_o = tl.make_block_ptr(o + (bos * H + i_h * T) * S, (T, S), (S, 1),
                                (i_t * BT, i_s * BS), (BT, BS), (1, 0))
    else:
        p_s = tl.make_block_ptr(s + (bos * H + i_h) * S, (T, S), (H * S, 1),
                                (i_t * BT, i_s * BS), (BT, BS), (1, 0))
        p_o = tl.make_block_ptr(o + (bos * H + i_h) * S, (T, S), (H * S, 1),
                                (i_t * BT, i_s * BS), (BT, BS), (1, 0))
    # [BT, BS]
    b_s = tl.load(p_s, boundary_check=(0, 1)).to(tl.float32)
    b_o = tl.dot(m_s, b_s, allow_tf32=False)
    tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))
def chunk_local_cumsum_scalar(
        g: torch.Tensor,
        chunk_size: int,
        reverse: bool = False,
        cu_seqlens: Optional[torch.Tensor] = None,
        head_first: bool = False,
        output_dtype: Optional[torch.dtype] = torch.float) -> torch.Tensor:
    """Chunk-local (intra-chunk) cumulative sum over scalar gates.

    Args:
        g: gate tensor of shape ``[B, T, H]`` (``[B, H, T]`` if ``head_first``).
        chunk_size: chunk length; must be a power of 2.
        reverse: if True, compute the reversed (suffix) cumsum per chunk.
        cu_seqlens: optional cumulative sequence lengths for varlen inputs.
        head_first: layout flag for ``g``.
        output_dtype: dtype of the result (defaults to float32).

    Returns:
        A new tensor of the same shape as ``g`` with per-chunk cumsums.
    """
    if head_first:
        B, H, T = g.shape
    else:
        B, T, H = g.shape
    assert chunk_size == 2**(chunk_size.bit_length() -
                             1), "chunk_size must be a power of 2"
    BT = chunk_size
    chunk_indices = prepare_chunk_indices(
        cu_seqlens, BT) if cu_seqlens is not None else None
    NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices)
    g_org, g = g, torch.empty_like(g, dtype=output_dtype or g.dtype)
    grid = (NT, B * H)
    # BUG FIX: a stray `is_use_mask_zero=True` keyword was previously passed
    # here, but `chunk_local_cumsum_scalar_kernel` declares no such parameter
    # and standard Triton rejects unknown launch keywords; it has been removed.
    chunk_local_cumsum_scalar_kernel[grid](
        s=g_org,
        o=g,
        cu_seqlens=cu_seqlens,
        chunk_indices=chunk_indices,
        T=T,
        B=B,
        H=H,
        BT=BT,
        HEAD_FIRST=head_first,
        REVERSE=reverse,
    )
    return g
def chunk_local_cumsum_vector(
        g: torch.Tensor,
        chunk_size: int,
        reverse: bool = False,
        cu_seqlens: Optional[torch.Tensor] = None,
        head_first: bool = False,
        output_dtype: Optional[torch.dtype] = torch.float) -> torch.Tensor:
    """Chunk-local (intra-chunk) cumulative sum over vector-valued gates.

    Args:
        g: gate tensor of shape ``[B, T, H, S]`` (``[B, H, T, S]`` if
            ``head_first``).
        chunk_size: chunk length; must be a power of 2.
        reverse: if True, compute the reversed (suffix) cumsum per chunk.
        cu_seqlens: optional cumulative sequence lengths for varlen inputs.
        head_first: layout flag for ``g``.
        output_dtype: dtype of the result (defaults to float32).

    Returns:
        A new tensor of the same shape as ``g`` with per-chunk cumsums.
    """
    if head_first:
        B, H, T, S = g.shape
    else:
        B, T, H, S = g.shape
    BT = chunk_size
    chunk_indices = prepare_chunk_indices(
        cu_seqlens, chunk_size) if cu_seqlens is not None else None
    NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices)
    assert chunk_size == 2**(chunk_size.bit_length() -
                             1), "chunk_size must be a power of 2"
    g_org, g = g, torch.empty_like(g, dtype=output_dtype or g.dtype)

    def grid(meta):
        return (triton.cdiv(meta['S'], meta['BS']), NT, B * H)

    # keep cumulative normalizer in fp32
    # this kernel is equivalent to
    # g = g.view(B, H, NT, BT, -1).cumsum(-2).view(B, H, T, -1)
    # BUG FIX: the kernel's `BS` constexpr used to be supplied by the (now
    # commented-out) autotuner; with autotuning disabled it must be passed
    # explicitly, otherwise the launch fails (and `grid` reads meta['BS']).
    chunk_local_cumsum_vector_kernel[grid](g_org,
                                           g,
                                           cu_seqlens,
                                           chunk_indices,
                                           T=T,
                                           B=B,
                                           H=H,
                                           S=S,
                                           BT=BT,
                                           BS=min(32,
                                                  triton.next_power_of_2(S)),
                                           HEAD_FIRST=head_first,
                                           REVERSE=reverse)
    return g
@input_guard
def chunk_local_cumsum(g: torch.Tensor,
                       chunk_size: int,
                       reverse: bool = False,
                       cu_seqlens: Optional[torch.Tensor] = None,
                       head_first: bool = False,
                       output_dtype: Optional[torch.dtype] = torch.float,
                       **kwargs) -> torch.Tensor:
    """Dispatch a chunk-local cumsum to the scalar (3-D) or vector (4-D)
    implementation, warning when the layout looks inconsistent with
    ``head_first`` and enforcing batch size 1 for varlen inputs."""
    seems_head_first = (not head_first) and g.shape[1] < g.shape[2]
    if seems_head_first:
        warnings.warn(
            f"Input tensor shape suggests potential format mismatch: seq_len ({g.shape[1]}) < num_heads ({g.shape[2]}). "
            "This may indicate the inputs were passed in head-first format [B, H, T, ...] "
            "when head_first=False was specified. "
            "Please verify your input tensor format matches the expected shape [B, T, H, ...].",
            stacklevel=2)
    if cu_seqlens is not None:
        assert g.shape[0] == 1, "Only batch size 1 is supported when cu_seqlens are provided"
    rank = g.dim()
    if rank == 3:
        return chunk_local_cumsum_scalar(g, chunk_size, reverse, cu_seqlens,
                                         head_first, output_dtype)
    if rank == 4:
        return chunk_local_cumsum_vector(g, chunk_size, reverse, cu_seqlens,
                                         head_first, output_dtype)
    raise ValueError(f"Unsupported input shape {g.shape}. "
                     f"which should be (B, T, H, D) if `head_first=False` "
                     f"or (B, H, T, D) otherwise")

View File

@@ -0,0 +1,153 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang
#
# This file contains code copied from the flash-linear-attention project.
# The original source code was licensed under the MIT license and included
# the following copyright notice:
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
# ruff: noqa: E501
from typing import Optional
import torch
import xtorch_ops
class FusedRecurrentFunction(torch.autograd.Function):
    """Thin autograd wrapper around the fused recurrent gated-delta-rule op.

    Only ``forward`` is implemented — it delegates to the custom op
    ``xtorch_ops.fused_recurrent_gated_delta_rule_fwdv2`` — and no
    ``backward`` is defined, so this path is inference-only.
    """

    @staticmethod
    def forward(ctx,
                q: torch.Tensor,
                k: torch.Tensor,
                v: torch.Tensor,
                g: torch.Tensor,
                beta: torch.Tensor,
                scale: float,
                initial_state: torch.Tensor,
                inplace_final_state: bool = True,
                cu_seqlens: Optional[torch.LongTensor] = None,
                ssm_state_indices: Optional[torch.Tensor] = None,
                num_accepted_tokens: Optional[torch.Tensor] = None,
                use_qk_l2norm_in_kernel: bool = False):
        # The custom op presumably requires contiguous buffers — every dense
        # input is made contiguous before the call.
        o, final_state = xtorch_ops.fused_recurrent_gated_delta_rule_fwdv2(
            q.contiguous(),
            k.contiguous(),
            v.contiguous(),
            g.contiguous(),
            beta.contiguous(),
            scale,
            initial_state,
            inplace_final_state=inplace_final_state,
            cu_seqlens=cu_seqlens,
            # ssm_state_indices is forwarded under the op's `h0_indices` name.
            h0_indices=ssm_state_indices,
            num_accepted_tokens=num_accepted_tokens,
            use_qk_l2norm_in_kernel=use_qk_l2norm_in_kernel,
        )
        return o, final_state
def fused_recurrent_gated_delta_rule(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    g: torch.Tensor,
    beta: Optional[torch.Tensor] = None,
    scale: Optional[float] = None,
    initial_state: Optional[torch.Tensor] = None,
    inplace_final_state: bool = True,
    cu_seqlens: Optional[torch.LongTensor] = None,
    ssm_state_indices: Optional[torch.Tensor] = None,
    num_accepted_tokens: Optional[torch.Tensor] = None,
    use_qk_l2norm_in_kernel: bool = False,
) -> tuple[torch.Tensor, torch.Tensor]:
    r"""
    Args:
        q (torch.Tensor):
            queries of shape `[B, T, H, K]`.
        k (torch.Tensor):
            keys of shape `[B, T, H, K]`.
        v (torch.Tensor):
            values of shape `[B, T, HV, V]`.
            GVA is applied if `HV > H`.
        g (torch.Tensor):
            g (decays) of shape `[B, T, HV]`.
        beta (torch.Tensor):
            betas of shape `[B, T, HV]`.
        scale (Optional[int]):
            Scale factor for the RetNet attention scores.
            If not provided, it will default to `1 / sqrt(K)`. Default: `None`.
        initial_state (Optional[torch.Tensor]):
            Initial state of shape `[N, HV, K, V]` for `N` input sequences.
            For equal-length input sequences, `N` equals the batch size `B`.
            Default: `None`.
        inplace_final_state: bool:
            Whether to store the final state in-place to save memory.
            Default: `True`.
        cu_seqlens (torch.LongTensor):
            Cumulative sequence lengths of shape `[N+1]` used for variable-length training,
            consistent with the FlashAttention API.
        ssm_state_indices (Optional[torch.Tensor]):
            Indices to map the input sequences to the initial/final states.
        num_accepted_tokens (Optional[torch.Tensor]):
            Number of accepted tokens for each sequence during decoding.
    Returns:
        o (torch.Tensor):
            Outputs of shape `[B, T, HV, V]`.
        final_state (torch.Tensor):
            Final state of shape `[N, HV, K, V]`.
    Examples::
        >>> import torch
        >>> import torch.nn.functional as F
        >>> from einops import rearrange
        >>> from fla.ops.gated_delta_rule import fused_recurrent_gated_delta_rule
        # inputs with equal lengths
        >>> B, T, H, HV, K, V = 4, 2048, 4, 8, 512, 512
        >>> q = torch.randn(B, T, H, K, device='cuda')
        >>> k = F.normalize(torch.randn(B, T, H, K, device='cuda'), p=2, dim=-1)
        >>> v = torch.randn(B, T, HV, V, device='cuda')
        >>> g = F.logsigmoid(torch.rand(B, T, HV, device='cuda'))
        >>> beta = torch.rand(B, T, HV, device='cuda').sigmoid()
        >>> h0 = torch.randn(B, HV, K, V, device='cuda')
        >>> o, ht = fused_recurrent_gated_delta_rule(
            q, k, v, g, beta,
            initial_state=h0,
        )
        # for variable-length inputs, the batch size `B` is expected to be 1 and `cu_seqlens` is required
        >>> q, k, v, g, beta = map(lambda x: rearrange(x, 'b t ... -> 1 (b t) ...'), (q, k, v, g, beta))
        # for a batch with 4 sequences, `cu_seqlens` with 5 start/end positions are expected
        >>> cu_seqlens = q.new_tensor([0, 2048, 4096, 6144, 8192], dtype=torch.long)
        >>> o_var, ht_var = fused_recurrent_gated_delta_rule(
            q, k, v, g, beta,
            initial_state=h0,
            cu_seqlens=cu_seqlens
        )
    """
    if cu_seqlens is not None and q.shape[0] != 1:
        raise ValueError(
            f"The batch size is expected to be 1 rather than {q.shape[0]} when using `cu_seqlens`."
            f"Please flatten variable-length inputs before processing.")
    if scale is None:
        scale = k.shape[-1]**-0.5
    else:
        assert scale > 0, "scale must be positive"
    if beta is None:
        # Default beta of all ones, shaped like q's head layout [B, T, H].
        beta = torch.ones_like(q[..., 0])
    o, final_state = FusedRecurrentFunction.apply(
        q,
        k,
        v,
        g,
        beta,
        scale,
        initial_state,
        inplace_final_state,
        cu_seqlens,
        ssm_state_indices,
        num_accepted_tokens,
        use_qk_l2norm_in_kernel,
    )
    return o, final_state

View File

@@ -0,0 +1,38 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang
#
# This file contains code copied from the flash-linear-attention project.
# The original source code was licensed under the MIT license and included
# the following copyright notice:
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
# ruff: noqa: E501
import torch
from vllm.triton_utils import triton
from .utils import tensor_cache
@tensor_cache
def prepare_lens(cu_seqlens: torch.LongTensor) -> torch.LongTensor:
    """Per-sequence lengths recovered from cumulative sequence offsets."""
    return torch.diff(cu_seqlens)
@tensor_cache
def prepare_chunk_indices(cu_seqlens: torch.LongTensor,
                          chunk_size: int) -> torch.LongTensor:
    """For every chunk across all sequences, return a `[num_chunks, 2]`
    tensor of (sequence index, chunk index within its sequence)."""
    counts = triton.cdiv(prepare_lens(cu_seqlens), chunk_size).tolist()
    pos = torch.cat([torch.arange(c) for c in counts])
    # A zero marks a sequence's first chunk; a running count of those zeros
    # (minus one) recovers each chunk's sequence index.
    seq = (pos == 0).cumsum(0) - 1
    return torch.stack([seq, pos], 1).to(cu_seqlens)
@tensor_cache
def prepare_chunk_offsets(cu_seqlens: torch.LongTensor,
                          chunk_size: int) -> torch.LongTensor:
    """Cumulative chunk counts: the offset of each sequence's first chunk
    (with a trailing total), analogous to cu_seqlens but in chunk units."""
    num_chunks = triton.cdiv(prepare_lens(cu_seqlens), chunk_size)
    return torch.cat((cu_seqlens.new_zeros(1), num_chunks)).cumsum(-1)

View File

@@ -0,0 +1,143 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang
#
# This file contains code copied from the flash-linear-attention project.
# The original source code was licensed under the MIT license and included
# the following copyright notice:
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
import os
from typing import Optional
import torch
from vllm.triton_utils import tl, triton
BT_LIST = [8, 16, 32, 64, 128]
USE_DEFAULT_FLA_NORM = int(os.getenv("USE_DEFAULT_FLA_NORM", "0"))
@triton.autotune(configs=[
    triton.Config({}, num_warps=num_warps)
    for num_warps in [1, 2, 4, 8, 16, 32]
],
                 key=['D'])
@triton.jit
def l2norm_fwd_kernel1(
    # L2-normalize one row of length D per program; the whole row must fit
    # in a single block of size BD (BD >= D).
    x,
    y,
    D,
    BD: tl.constexpr,
    eps,
):
    i_t = tl.program_id(0)
    x += i_t * D
    y += i_t * D
    # Compute mean and variance
    cols = tl.arange(0, BD)
    mask = cols < D
    b_x = tl.load(x + cols, mask=mask, other=0.0).to(tl.float32)
    b_var = tl.sum(b_x * b_x, axis=0)
    b_rstd = 1 / tl.sqrt(b_var + eps)
    # tl.store(Rstd + i_t, rstd)
    # Normalize and apply linear transformation
    b_y = b_x * b_rstd
    tl.store(y + cols, b_y, mask=mask)
@triton.autotune(configs=[
    triton.Config({'BT': BT}, num_warps=num_warps)
    for num_warps in [1, 2, 4, 8, 16] for BT in BT_LIST
],
                 key=['D'])
@triton.jit(do_not_specialize=["NB"])
def l2norm_fwd_kernel(
    # L2-normalize a [BT, D] tile of rows per program (for small D).
    x,
    y,
    eps,
    NB,  # not read in the kernel body; kept so the launch signature is stable
    T,
    D: tl.constexpr,
    BT: tl.constexpr,
    BD: tl.constexpr,
):
    i_t = tl.program_id(0)
    p_x = tl.make_block_ptr(x, (T, D), (D, 1), (i_t * BT, 0), (BT, BD), (1, 0))
    b_x = tl.load(p_x, boundary_check=(0, 1)).to(tl.float32)
    # Row-wise sum of squares, then scale each row by 1/sqrt(sumsq + eps).
    b_var = tl.sum(b_x * b_x, axis=1)
    b_y = b_x / tl.sqrt(b_var + eps)[:, None]
    p_y = tl.make_block_ptr(y, (T, D), (D, 1), (i_t * BT, 0), (BT, BD), (1, 0))
    tl.store(p_y, b_y.to(p_y.dtype.element_ty), boundary_check=(0, 1))
@triton.jit
def l2norm_fwd_kernel2(X, Y, eps, M, N: tl.constexpr, MBLOCK: tl.constexpr):
    # L2-normalize MBLOCK full rows (each of length N) per program; rows past
    # M are masked out. N is a constexpr, so each row fits one block.
    xoffset = tl.program_id(0) * MBLOCK
    row_idx = xoffset + tl.arange(0, MBLOCK)[:, None]
    xmask = row_idx < M
    rindex = tl.arange(0, N)[None, :]
    xs = tl.load(X + (rindex + N * row_idx), xmask).to(tl.float32)
    square = tl.broadcast_to(xs * xs, [MBLOCK, N])
    square_sum = tl.sum(tl.where(xmask, square, 0), 1)[:, None]
    rsqrt = tl.rsqrt(square_sum + eps)
    tl.store(Y + (rindex + N * row_idx), xs * rsqrt, xmask)
def l2norm_fwd(x: torch.Tensor,
               eps: float = 1e-6,
               output_dtype: Optional[torch.dtype] = None):
    """L2-normalize ``x`` along its last dimension.

    The input is flattened to 2-D, normalized row-by-row on device, and
    reshaped back. ``output_dtype`` overrides the output dtype; the feature
    dimension must fit in a single 64 KB block.
    """
    shape_in = x.shape
    x = x.view(-1, x.shape[-1])
    # Output buffer, optionally in a different dtype.
    if output_dtype is None:
        y = torch.empty_like(x)
    else:
        y = torch.empty_like(x, dtype=output_dtype)
    assert y.stride(-1) == 1
    T, D = x.shape
    # A block must cover one full feature row: cap BD at 64KB worth of elems.
    MAX_FUSED_SIZE = 65536 // x.element_size()
    BD = min(MAX_FUSED_SIZE, triton.next_power_of_2(D))
    if D > BD:
        raise RuntimeError("This layer doesn't support feature dim >= 64KB.")
    if not USE_DEFAULT_FLA_NORM:
        # Default path: fixed 32-row blocks, one full row per block.
        MBLOCK = 32
        l2norm_fwd_kernel2[(triton.cdiv(T, MBLOCK), )](x, y, eps, T, D,
                                                       MBLOCK)
    elif D <= 512:
        # Small feature dim: tile several rows per program (autotuned BT).
        NB = triton.cdiv(T, 2048)
        grid = lambda meta: (triton.cdiv(T, meta['BT']), )
        l2norm_fwd_kernel[grid](x, y, eps, NB=NB, T=T, D=D, BD=BD)
    else:
        # Large feature dim: one row per program.
        l2norm_fwd_kernel1[(T, )](x, y, eps=eps, D=D, BD=BD)
    return y.view(shape_in)

View File

@@ -0,0 +1,343 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Tri Dao
#
# This file contains code copied from the flash-linear-attention project.
# The original source code was licensed under the MIT license and included
# the following copyright notice:
# Copyright (c) 2024, Tri Dao.
# ruff: noqa: E501
# Based on the Triton LayerNorm tutorial: https://triton-lang.org/main/getting-started/tutorials/05-layer-norm.html
# For the backward pass, we keep weight_grad and bias_grad in registers and accumulate.
# This backward pass is faster for dimensions up to 8k, but after that it's much slower due to register spilling.
# The models we train have hidden dim up to 8k anyway (e.g. Llama 70B), so this is fine.
from typing import Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from vllm.triton_utils import tl, triton
from .utils import input_guard
def rms_norm_ref(x,
                 weight,
                 bias,
                 z=None,
                 eps=1e-6,
                 group_size=None,
                 norm_before_gate=True,
                 upcast=True):
    """Pure-PyTorch reference for (optionally gated / grouped) RMSNorm.

    If ``z`` is given, the gate ``silu(z)`` is applied either to the
    normalized output (``norm_before_gate=True``) or to the input before
    normalization. With ``group_size`` set, statistics are computed per
    channel group. Result is cast back to the input dtype.
    """
    out_dtype = x.dtype
    weight = weight.float()
    if bias is not None:
        bias = bias.float()
    if upcast:
        # Do the math in fp32 for a numerically faithful reference.
        x = x.float()
        if z is not None:
            z = z.float()
    if z is not None and not norm_before_gate:
        x = x * F.silu(z)
    if group_size is None:
        rstd = 1 / torch.sqrt(x.square().mean(dim=-1, keepdim=True) + eps)
        out = x * rstd * weight
        if bias is not None:
            out = out + bias
    else:
        grouped = rearrange(x, "... (g d) -> ... g d", d=group_size)
        rstd = 1 / torch.sqrt(grouped.square().mean(dim=-1, keepdim=True) +
                              eps)
        out = rearrange(grouped * rstd, "... g d -> ... (g d)") * weight
        if bias is not None:
            out = out + bias
    if z is not None and norm_before_gate:
        out = out * F.silu(z)
    return out.to(out_dtype)
@triton.heuristics({
    "HAS_BIAS": lambda args: args["B"] is not None,
    "HAS_Z": lambda args: args["Z"] is not None,
})
@triton.jit
def layer_norm_fwd_kernel(
    X,  # pointer to the input
    Y,  # pointer to the output
    W,  # pointer to the weights
    B,  # pointer to the biases
    Z,  # pointer to the other branch
    Mean,  # pointer to the mean
    Rstd,  # pointer to the 1/std
    stride_x_row,  # how much to increase the pointer when moving by 1 row
    stride_y_row,
    stride_z_row,
    M,  # number of rows in X
    N,  # number of columns in X
    eps,  # epsilon to avoid division by zero
    BLOCK_N: tl.constexpr,
    HAS_BIAS: tl.constexpr,
    HAS_Z: tl.constexpr,
    NORM_BEFORE_GATE: tl.constexpr,
    IS_RMS_NORM: tl.constexpr,
):
    """Fused (optionally gated, grouped) LayerNorm / RMSNorm forward.

    One program normalizes one (row, group) pair: program id 0 selects the
    row, program id 1 the feature group of width N. If Z is present, the
    gate ``silu(z)`` is applied to the input (NORM_BEFORE_GATE=False) or to
    the normalized output (NORM_BEFORE_GATE=True). Mean is only written for
    plain LayerNorm; Rstd is always written (both in fp32).
    """
    # Map the program id to the row of X and Y it should compute.
    row = tl.program_id(0)
    group = tl.program_id(1)
    X += row * stride_x_row + group * N
    Y += row * stride_y_row + group * N
    if HAS_Z:
        Z += row * stride_z_row + group * N
    if not IS_RMS_NORM:
        # Statistics buffers are laid out group-major: one M-slot band per group.
        Mean += group * M
    Rstd += group * M
    W += group * N
    if HAS_BIAS:
        B += group * N
    # Compute mean and variance
    cols = tl.arange(0, BLOCK_N)
    x = tl.load(X + cols, mask=cols < N, other=0.).to(tl.float32)
    if HAS_Z and not NORM_BEFORE_GATE:
        # Gate the input first: x <- x * silu(z).
        z = tl.load(Z + cols, mask=cols < N).to(tl.float32)
        x *= z * tl.sigmoid(z)
    if not IS_RMS_NORM:
        mean = tl.sum(x, axis=0) / N
        tl.store(Mean + row, mean)
        xbar = tl.where(cols < N, x - mean, 0.)
        var = tl.sum(xbar * xbar, axis=0) / N
    else:
        # RMSNorm: second moment around zero, no mean subtraction.
        xbar = tl.where(cols < N, x, 0.)
        var = tl.sum(xbar * xbar, axis=0) / N
    rstd = 1 / tl.sqrt(var + eps)
    tl.store(Rstd + row, rstd)
    # Normalize and apply linear transformation
    mask = cols < N
    w = tl.load(W + cols, mask=mask).to(tl.float32)
    if HAS_BIAS:
        b = tl.load(B + cols, mask=mask).to(tl.float32)
    x_hat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd
    y = x_hat * w + b if HAS_BIAS else x_hat * w
    if HAS_Z and NORM_BEFORE_GATE:
        # Gate the normalized output: y <- y * silu(z).
        z = tl.load(Z + cols, mask=mask).to(tl.float32)
        y *= z * tl.sigmoid(z)
    # Write output
    tl.store(Y + cols, y, mask=mask)
def layer_norm_fwd(
    x: torch.Tensor,
    weight: torch.Tensor,
    bias: torch.Tensor,
    eps: float,
    z: torch.Tensor = None,
    out: torch.Tensor = None,
    group_size: int = None,
    norm_before_gate: bool = True,
    is_rms_norm: bool = False,
):
    """Validate shapes/strides and launch the fused norm forward kernel.

    Args:
        x: 2D input of shape (M, N) with contiguous last dim.
        weight, bias: affine parameters of shape (N,); ``bias`` may be None.
        eps: numerical-stability epsilon.
        z: optional gate tensor, same shape as ``x``.
        out: optional preallocated output buffer.
        group_size: channel-group width; defaults to the full width N.
        norm_before_gate: apply silu(z) after (True) or before (False) norm.
        is_rms_norm: RMSNorm instead of LayerNorm (no mean).

    Returns:
        (out, mean, rstd) — mean is None for RMSNorm; stats are fp32.
    """
    M, N = x.shape
    group_size = N if group_size is None else group_size
    assert N % group_size == 0
    ngroups = N // group_size
    assert x.stride(-1) == 1
    if z is not None:
        assert z.stride(-1) == 1
        assert z.shape == (M, N)
    assert weight.shape == (N, )
    assert weight.stride(-1) == 1
    if bias is not None:
        assert bias.stride(-1) == 1
        assert bias.shape == (N, )
    # Allocate (or validate) the output buffer.
    if out is None:
        out = torch.empty_like(x)
    else:
        assert out.shape == x.shape
    assert out.stride(-1) == 1
    # Per-(row, group) statistics kept in fp32.
    mean = (None if is_rms_norm else torch.empty(
        (ngroups * M, ), dtype=torch.float32, device=x.device))
    rstd = torch.empty((ngroups * M, ), dtype=torch.float32, device=x.device)
    # One program must fit a whole group of features: less than 64KB each.
    MAX_FUSED_SIZE = 65536 // x.element_size()
    BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(group_size))
    if group_size > BLOCK_N:
        raise RuntimeError(
            "This layer norm doesn't support feature dim >= 64KB.")
    # Heuristic: scale warps with block width, clamped to [1, 8].
    num_warps = min(max(BLOCK_N // 256, 1), 8)
    layer_norm_fwd_kernel[(M, ngroups)](x,
                                        out,
                                        weight,
                                        bias,
                                        z,
                                        mean,
                                        rstd,
                                        x.stride(0),
                                        out.stride(0),
                                        z.stride(0) if z is not None else 0,
                                        M,
                                        group_size,
                                        eps,
                                        BLOCK_N=BLOCK_N,
                                        NORM_BEFORE_GATE=norm_before_gate,
                                        IS_RMS_NORM=is_rms_norm,
                                        num_warps=num_warps)
    return out, mean, rstd
class LayerNormFn(torch.autograd.Function):
    """Autograd entry point for the fused (gated) LayerNorm / RMSNorm.

    NOTE(review): only ``forward`` is implemented — no ``backward`` is
    defined here, even though tensors are saved for backward; gradients
    presumably come from elsewhere or are unsupported. TODO confirm.
    """

    # Bug fix: decorator order. Previously `@input_guard` was applied *above*
    # `@staticmethod`, so input_guard wrapped the staticmethod descriptor
    # (functools.wraps on a descriptor; not even callable before Python 3.10)
    # and the class attribute ended up as a bare function. `@staticmethod`
    # must be outermost, wrapping the guarded plain function.
    @staticmethod
    @input_guard
    def forward(ctx,
                x,
                weight,
                bias,
                z=None,
                eps=1e-6,
                group_size=None,
                norm_before_gate=True,
                is_rms_norm=False):
        """If z is not None, we do norm(x) * silu(z) if norm_before_gate, else norm(x * silu(z))
        """
        x_shape_og = x.shape
        # reshape input data into 2D tensor
        x = x.reshape(-1, x.shape[-1])
        if x.stride(-1) != 1:
            x = x.contiguous()
        if z is not None:
            assert z.shape == x_shape_og
            z = z.reshape(-1, z.shape[-1])
            if z.stride(-1) != 1:
                z = z.contiguous()
        weight = weight.contiguous()
        if bias is not None:
            bias = bias.contiguous()
        y, mean, rstd = layer_norm_fwd(
            x,
            weight,
            bias,
            eps,
            z=z,
            group_size=group_size,
            norm_before_gate=norm_before_gate,
            is_rms_norm=is_rms_norm,
        )
        # Saved for a (currently absent) backward pass.
        ctx.save_for_backward(x, weight, bias, mean, rstd, z)
        ctx.x_shape_og = x_shape_og
        ctx.eps = eps
        ctx.group_size = group_size
        ctx.norm_before_gate = norm_before_gate
        ctx.is_rms_norm = is_rms_norm
        return y.reshape(x_shape_og)
def layernorm_fn(x, weight, bias, z=None, eps=1e-6, group_size=None,
                 norm_before_gate=True, is_rms_norm=False):
    """Functional wrapper: fused (gated) LayerNorm/RMSNorm via autograd op."""
    return LayerNormFn.apply(x, weight, bias, z, eps, group_size,
                             norm_before_gate, is_rms_norm)
def rmsnorm_fn(x, weight, bias, z=None, eps=1e-6, group_size=None,
               norm_before_gate=True):
    """Functional wrapper: fused (gated) RMSNorm (is_rms_norm fixed to True)."""
    return LayerNormFn.apply(x, weight, bias, z, eps, group_size,
                             norm_before_gate, True)
class LayerNormGated(nn.Module):
    """LayerNorm with optional SiLU gating and channel-group support.

    With ``group_size`` set, normalization statistics are computed per group
    of ``group_size`` channels; ``group_size=None`` is equivalent to a single
    group spanning ``hidden_size``.
    """

    def __init__(
        self,
        hidden_size,
        eps: float = 1e-5,
        group_size: Optional[int] = None,
        norm_before_gate: bool = True,
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
    ):
        super().__init__()
        factory_kwargs = {"device": device, "dtype": dtype}
        self.eps = eps
        self.group_size = group_size
        self.norm_before_gate = norm_before_gate
        self.weight = nn.Parameter(torch.empty(hidden_size, **factory_kwargs))
        self.bias = nn.Parameter(torch.empty(hidden_size, **factory_kwargs))
        self.reset_parameters()

    def reset_parameters(self):
        """Initialize weight to ones and bias to zeros."""
        torch.nn.init.ones_(self.weight)
        torch.nn.init.zeros_(self.bias)

    def forward(self, x, z=None):
        """If z is not None, we do norm(x) * silu(z) if norm_before_gate, else norm(x * silu(z))
        """
        return layernorm_fn(x,
                            self.weight,
                            self.bias,
                            z=z,
                            group_size=self.group_size,
                            eps=self.eps,
                            norm_before_gate=self.norm_before_gate)
class RMSNormGated(nn.Module):
    """RMSNorm with optional SiLU gating and channel-group support.

    Same contract as :class:`LayerNormGated` but without a bias parameter and
    using RMS statistics; ``group_size=None`` means one group of size
    ``hidden_size``.
    """

    def __init__(
        self,
        hidden_size,
        eps: float = 1e-5,
        group_size: Optional[int] = None,
        norm_before_gate: bool = False,
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
    ):
        super().__init__()
        factory_kwargs = {"device": device, "dtype": dtype}
        self.eps = eps
        self.group_size = group_size
        self.norm_before_gate = norm_before_gate
        self.weight = nn.Parameter(torch.empty(hidden_size, **factory_kwargs))
        # RMSNorm has no bias; register None so `.bias` still exists.
        self.register_parameter("bias", None)
        self.reset_parameters()

    def reset_parameters(self):
        """Initialize weight to ones."""
        torch.nn.init.ones_(self.weight)

    def forward(self, x, z=None):
        """If z is not None, we do norm(x) * silu(z) if norm_before_gate, else norm(x * silu(z))
        """
        return rmsnorm_fn(x,
                          self.weight,
                          self.bias,
                          z=z,
                          eps=self.eps,
                          group_size=self.group_size,
                          norm_before_gate=self.norm_before_gate)

39
vllm_kunlun/ops/fla/op.py Normal file
View File

@@ -0,0 +1,39 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang
#
# This file contains code copied from the flash-linear-attention project.
# The original source code was licensed under the MIT license and included
# the following copyright notice:
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
import os
from vllm.triton_utils import tl, tldevice, triton
# Select the elementwise math primitives used inside Triton kernels.
# FLA_USE_FAST_OPS=1 picks the device's fast (lower-precision) intrinsics;
# otherwise the standard tl.* ops are used.
if os.environ.get('FLA_USE_FAST_OPS', '0') == '1':
    div = tldevice.fast_dividef
    exp = tldevice.fast_expf
    log = tldevice.fast_logf
    log2 = tldevice.fast_log2f
else:
    @triton.jit
    def div_normal(x, y):
        # Plain division, JIT-compiled so it is callable from inside kernels.
        return x / y
    div = div_normal
    exp = tl.exp
    log = tl.log
    log2 = tl.log2
# Older Triton versions lack tl.gather; provide a compile-only stand-in.
if not hasattr(tl, 'gather'):
    @triton.jit
    def gather(src, index, axis, _builder=None):
        # This is a fallback implementation when tl.gather is not supported
        # In order to pass triton compiler, there is no actual gather operation
        return src
else:
    gather = tl.gather

View File

@@ -0,0 +1,422 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang
#
# This file contains code copied from the flash-linear-attention project.
# The original source code was licensed under the MIT license and included
# the following copyright notice:
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
# ruff: noqa: E501
from typing import Optional
import torch
import os
from vllm.triton_utils import tl, triton
from .index import prepare_chunk_indices
from .utils import input_guard
base_dir = os.path.dirname(__file__)
def prepare_lens(cu_seqlens: torch.LongTensor) -> torch.LongTensor:
    """Recover per-sequence lengths from cumulative sequence offsets."""
    return torch.diff(cu_seqlens)
def prepare_chunk_indices(
    cu_seqlens: torch.LongTensor, chunk_size: int
) -> torch.LongTensor:
    """Build (sequence_index, chunk_index) pairs for every chunk.

    For each sequence, one row is produced per ``chunk_size`` block: column 0
    is the ordinal of the sequence, column 1 the chunk's position within it.
    Result is cast to ``cu_seqlens``'s dtype/device.

    NOTE(review): this shadows the ``prepare_chunk_indices`` imported from
    ``.index`` at the top of the file — confirm which one is intended.
    """
    chunks_per_seq = triton.cdiv(prepare_lens(cu_seqlens), chunk_size).tolist()
    within = torch.cat([torch.arange(n) for n in chunks_per_seq])
    # A zero marks the first chunk of a sequence; cumsum numbers the sequences.
    seq_ids = within.eq(0).cumsum(0) - 1
    return torch.stack([seq_ids, within], 1).to(cu_seqlens)
@triton.heuristics({"IS_VARLEN": lambda args: args["cu_seqlens"] is not None})
# @triton.autotune(
#     configs=[
#         triton.Config({}, num_warps=num_warps, num_stages=num_stages)
#         for num_warps in [1, 2, 4, 8]
#         for num_stages in [2, 3, 4, 5]
#     ],
#     key=["BT"],
# )
@triton.jit(do_not_specialize=["T"])
def solve_tril_16x16_kernel(
    A,
    Ad,
    cu_seqlens,
    chunk_indices,
    T,
    H: tl.constexpr,
    BT: tl.constexpr,
    IS_VARLEN: tl.constexpr,
):
    """Invert (I + A) for each strictly-lower-triangular 16x16 tile of A.

    One program handles one 16-row tile of one (batch, head) pair: it negates
    the strictly lower part of the tile, runs 16 steps of forward
    substitution, and adds the identity before storing into ``Ad``
    (rounded to the output dtype with round-to-nearest-even).
    """
    i_t, i_bh = tl.program_id(0), tl.program_id(1)
    i_b, i_h = i_bh // H, i_bh % H
    if IS_VARLEN:
        # Variable-length batch: chunk_indices maps program id -> (seq, chunk),
        # and cu_seqlens gives that sequence's [bos, eos) token range.
        i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(
            chunk_indices + i_t * 2 + 1
        ).to(tl.int32)
        bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(
            cu_seqlens + i_n + 1
        ).to(tl.int32)
        T = eos - bos
    else:
        bos, eos = i_b * T, i_b * T + T
    A = A + (bos * H + i_h) * BT
    Ad = Ad + (bos * H + i_h) * 16
    # Column offset of this 16x16 tile inside the BT-wide block of A.
    offset = (i_t * 16) % BT
    p_A = tl.make_block_ptr(
        A, (T, BT), (H * BT, 1), (i_t * 16, offset), (16, 16), (1, 0)
    )
    p_Ai = tl.make_block_ptr(Ad, (T, 16), (H * 16, 1), (i_t * 16, 0), (16, 16), (1, 0))
    b_A = tl.load(p_A, boundary_check=(0, 1)).to(tl.float32)
    # Keep only the strictly lower triangle, negated (first inversion step).
    b_A = -tl.where(tl.arange(0, 16)[:, None] > tl.arange(0, 16)[None, :], b_A, 0)
    o_i = tl.arange(0, 16)
    # Forward substitution: row i of the inverse depends only on rows < i.
    for i in range(1, min(16, T - i_t * 16)):
        b_a = -tl.load(A + (i_t * 16 + i) * H * BT + o_i + offset)
        b_a = b_a + tl.sum(b_a[:, None] * b_A, 0)
        mask = o_i == i
        b_A = tl.where(mask[:, None], b_a, b_A)
    # (I + A)^-1 = I + inverted strictly-lower part.
    b_A += o_i[:, None] == o_i[None, :]
    tl.store(
        p_Ai,
        b_A.to(p_Ai.dtype.element_ty, fp_downcast_rounding="rtne"),
        boundary_check=(0, 1),
    )
@triton.heuristics({"IS_VARLEN": lambda args: args["cu_seqlens"] is not None})
@triton.jit(do_not_specialize=["T"])
def solve_tril_16x16_kernel_modified(
    i_t,
    i_bh,
    i_n,
    bos,
    i_b,
    i_h,
    subA,
    subAd,
    A,
    Ad,
    cu_seqlens,
    chunk_indices,
    T,  # 32
    H: tl.constexpr,  # 4
    BT: tl.constexpr,  # 64
    IS_VARLEN: tl.constexpr,
):
    """Debug variant of ``solve_tril_16x16_kernel``.

    Operates on a pre-extracted 16x16 tile ``subA`` (grid indices are passed
    in as arguments rather than derived from program ids) and writes the
    inverted tile to ``subAd``. Contains device prints for tracing the
    forward-substitution loop — presumably kept only for debugging the
    host-side emulation path, not for production use. TODO confirm.
    """
    A = A + (bos * H + i_h) * BT
    print("for A Base offset ", (bos * H + i_h) * BT)
    offset = (i_t * 16) % BT
    range16 = tl.arange(0, 16)
    # subA is a dense 16x16 buffer; build element pointers row-major.
    newp_A = subA + range16[:, None] * 16 + range16[None, :]
    b_A = tl.load(newp_A).to(tl.float32)
    o_i = tl.arange(0, 16)
    # Same forward-substitution recurrence as the original kernel.
    for i in range(1, min(16, T - i_t * 16)):
        print("[naive impl-0]loopIdx:", i)
        # print("for A start (i_t * 16 + i) * H * BT", (i_t * 16 + i) * H * BT)
        # print("for A start offset", offset)
        # print("for A start", (i_t * 16 + i) * H * BT + offset)
        print("[naive impl-1]b_A value in now loopIdx:", b_A)
        b_a = -tl.load(A + (i_t * 16 + i) * H * BT + o_i + offset)
        # print("[naive impl-2]b_a value in now loopIdx:", b_a)
        b_a = b_a + tl.sum(b_a[:, None] * b_A, 0)
        print("[naive impl-2-1]b_a value after reduce in now loopIdx:", b_a)
        mask = o_i == i
        b_A = tl.where(mask[:, None], b_a, b_A)
        print("[naive impl-2-2]b_A value after oimask in now loopIdx:", b_A)
        # print("[naive impl-3]b_A result in now loopIdx:", b_A)
    # print(f"[naive impl-4] b_A value after allLoop = {b_A}")
    # Add the identity to complete (I + A)^-1.
    b_A += o_i[:, None] == o_i[None, :]
    # print(f"[naive impl-5] b_A value after mask = {b_A}")
    newp_Ad = subAd + range16[:, None] * 16 + range16[None, :]
    tl.store(
        newp_Ad,
        b_A.to(subAd.dtype.element_ty, fp_downcast_rounding="rtne"),
    )
@triton.heuristics({"IS_VARLEN": lambda args: args["cu_seqlens"] is not None})
# @triton.autotune(
#     configs=[
#         triton.Config({}, num_warps=num_warps, num_stages=num_stages)
#         for num_warps in [1, 2, 4, 8]
#         for num_stages in [2, 3, 4, 5]
#     ],
#     key=["BT"],
# )
@triton.jit(do_not_specialize=["T"])
def solve_tril_16x16_kernel_modified_in_Loop(
    i_t,
    i_bh,
    i_n,
    bos,
    i_b,
    i_h,
    subA,
    subAd,
    AInLoop,
    ba_reduce,
    loopIdx,
    reduce_res,
    A,
    Ad,
    cu_seqlens,
    chunk_indices,
    T,  # 32
    H: tl.constexpr,  # 4
    BT: tl.constexpr,  # 64
    IS_VARLEN: tl.constexpr,
):
    """Single substitution step of the 16x16 triangular solve.

    Loads the current 16x16 accumulator ``subA`` and the row ``AInLoop`` for
    iteration ``loopIdx``, computes the rank-1 outer product
    ``(-AInLoop)[:, None] * subA`` and stores it to ``reduce_res``; the
    reduction and row update are finished on the host (see
    ``solve_tril_16x16_kernel_new``). The remaining commented-out code is
    the in-kernel version of that update, kept for debugging.
    """
    range16 = tl.arange(0, 16)
    newp_A = subA + range16[:, None] * 16 + range16[None, :]
    b_A = tl.load(newp_A).to(tl.float32)
    # print("[loop impl-0]loopIdx:", loopIdx)
    # print("[loop impl-1]b_A value in now loopIdx:", b_A)
    o_i = tl.arange(0, 16)
    i=loopIdx
    b_a = -tl.load(AInLoop + o_i)
    # print("[loop impl-2]b_a value in now loopIdx:", b_a)
    red_res = b_a[:, None] * b_A
    # print("[Triton]red_res=", red_res)
    tl.store(reduce_res + range16[:, None] * 16 + range16[None, :], red_res)
    # b_a = b_a + tl.sum(b_a[:, None] * b_A, 1)  # TODO: revert to 0
    # # print("triton reduce b_a", b_a)
    # tl.store(ba_reduce + o_i, b_a)
    # mask = o_i == i
    # # print("mask", mask[:, None])
    # # print("b_a", b_a)
    # # print("b_A", b_A)
    # print("before b_A", b_A)
    # b_A = tl.where(mask[:, None], b_a, b_A)
    # print("[loop impl-3]b_A result in now loopIdx:", b_A)
    # tl.store(newp_A, b_A)
def solve_tril_16x16_kernel_new(
    NT,
    B,
    A,
    Ad,
    cu_seqlens,
    chunk_indices,
    T,
    H,
    BT,
    IS_VARLEN,
):
    """Host-side emulation of the 16x16 lower-triangular solve.

    Iterates over all (chunk, batch*head) pairs, extracts each 16x16 tile of
    ``A``, performs the forward-substitution loop (delegating the rank-1
    outer product to a tiny Triton helper per step), adds the identity, and
    writes the inverted tile into ``Ad``.

    NOTE(review): this path always dereferences ``chunk_indices`` /
    ``cu_seqlens`` (i.e. it assumes the varlen layout regardless of
    ``IS_VARLEN``) and only handles ``bos * H + i_h < H`` — effectively
    batch slice 0; other cases raise. TODO confirm intended coverage.
    """
    Ad_modify = Ad
    for loopX in range(NT):
        # Host-side equivalent of the kernel's tl.load of (i_n, i_t).
        chunk_indices_load_offset_1 = loopX * 2
        row_idx = chunk_indices_load_offset_1 // chunk_indices.shape[1]
        col_idx = chunk_indices_load_offset_1 % chunk_indices.shape[1]
        i_n = int(chunk_indices[row_idx, col_idx])
        chunk_indices_load_offset_2 = loopX * 2 + 1
        row_idx = chunk_indices_load_offset_2 // chunk_indices.shape[1]
        col_idx = chunk_indices_load_offset_2 % chunk_indices.shape[1]
        i_t = int(chunk_indices[row_idx, col_idx])
        # Sequence boundaries [bos, eos) for this chunk's sequence.
        cu_seqlens_load_offset_1 = i_n
        bos = int(cu_seqlens[cu_seqlens_load_offset_1])
        cu_seqlens_load_offset_2 = i_n + 1
        eos = int(cu_seqlens[cu_seqlens_load_offset_2])
        T = eos - bos
        for loopY in range(B * H):
            i_b = loopY // H
            i_h = loopY % H
            # Extract the 16x16 input tile subA.
            if (bos * H + i_h) < H:
                Tstart = loopX * 16 % BT
                Tend = Tstart + 16
                BTstart = loopX * 16 % BT
                BTend = BTstart + 16
                subA = A[0, Tstart:Tend, loopY, BTstart:BTend].contiguous().clone()
                # print(f"subA slice A dim[0, {Tstart}:{Tend}, {loopY}, {BTstart}:{BTend}]")
                if (Tend > T):  # boundary check
                    subA[T-16:, :] = 0
                # Pad short tiles with zero rows up to a full 16x16.
                if subA.shape[0] < 16:
                    pad_rows = 16 - subA.shape[0]
                    zeros = torch.zeros((pad_rows, subA.shape[1]), dtype=subA.dtype, device=subA.device)
                    subA = torch.cat([subA, zeros], dim=0)
            else:
                # Bug fix: the original `assert(0) & "msg"` raised TypeError
                # (int & str) and was stripped entirely under `python -O`.
                raise AssertionError("need deal this situation")
            # Extract the 16x16 output tile subAd.
            if (bos * H + i_h) < H:
                Tstart = loopX * 16
                Tend = Tstart + 16
                BTstart = 0 * 16
                BTend = BTstart + 16
                subAd = Ad_modify[0, Tstart:Tend, loopY, BTstart:BTend].contiguous().clone()
                # print(f'T={T}, Tstart={Tstart}, Tend={Tend}, BTstart={BTstart}, BTend={BTend}')
            else:
                raise AssertionError("need deal this situation")
            # Negate the strictly lower triangle; zero everything else.
            mask = (torch.arange(16, device=subA.device)[:, None] > torch.arange(16, device=subA.device)[None, :])
            subA = -torch.where(mask, subA, torch.zeros_like(subA))
            # Forward substitution, one row per iteration.
            for inLoopIdx in range(1, min(16, T - i_t * 16)):
                # print(f"loopX={loopX}, loopY={loopY}, inLoopIdx={inLoopIdx}")
                offsetStart = loopX * 16 % BT
                offsetEnd = offsetStart + 16
                AInLoop = A[0, (loopX * 16 + inLoopIdx), loopY, offsetStart:offsetEnd]
                # print(f"AInLoop slice A dim[0, {(loopX * 16 + inLoopIdx)}, {loopY}, {offsetStart}:{offsetEnd}")
                ba_reduce = torch.empty_like(AInLoop)
                reduce_res = torch.empty_like(subA)
                # The Triton helper only computes the rank-1 outer product
                # into reduce_res; the reduction happens below on the host.
                solve_tril_16x16_kernel_modified_in_Loop[1, 1](
                    i_t,
                    loopY,
                    i_n,
                    bos,
                    i_b,
                    i_h,
                    subA=subA,
                    subAd=subAd,
                    AInLoop=AInLoop,
                    ba_reduce=ba_reduce,
                    loopIdx=inLoopIdx,
                    reduce_res=reduce_res,
                    A=A,
                    Ad=Ad_modify,
                    cu_seqlens=cu_seqlens,
                    chunk_indices=chunk_indices,
                    T=T,
                    H=H,
                    BT=BT,
                    num_warps=1,
                    num_stages=4,
                )
                AInLoop = AInLoop.flatten()
                b_a = -AInLoop[0:16]  # [16]
                b_a = b_a + torch.sum(reduce_res, 0)
                ba_reduce = b_a
                o_i = torch.arange(16, device=ba_reduce.device)
                # Write the finished row inLoopIdx back into the accumulator.
                mask = (o_i == inLoopIdx)
                mask_expand = mask[:, None]
                subA = torch.where(mask_expand, ba_reduce, subA)
            # (I + A)^-1 = I + inverted strictly-lower part.
            subAd = subA + (torch.arange(16, device=subA.device)[:, None] == torch.arange(16, device=subA.device)[None, :])
            # Store the tile, masking rows past the sequence end.
            Tstart = loopX * 16
            Tend = Tstart + 16
            BTstart = 0 * 16
            BTend = BTstart + 16
            if (Tend > T):  # boundary mask
                Ad_modify[0, Tstart:Tend, loopY, BTstart:BTend] = subAd[:T-Tstart, :]
            else:
                Ad_modify[0, Tstart:Tend, loopY, BTstart:BTend] = subAd
    return Ad_modify
# @input_guard
def solve_tril(
    A: torch.Tensor,
    cu_seqlens: Optional[torch.Tensor] = None,
    output_dtype: torch.dtype = torch.float,
) -> torch.Tensor:
    """
    Compute the inverse of the lower triangular matrix
    A should be strictly lower triangular, i.e., A.triu() == 0.
    Args:
        A (torch.Tensor):
            [B, T, H, K]
        cu_seqlens (torch.Tensor):
            The cumulative sequence lengths of the input tensor.
            Default: None.
        output_dtype (torch.dtype):
            The dtype of the output tensor. Default: `torch.float`
    Returns:
        (I + A)^-1 with the same shape as A
    """
    assert A.shape[-1] in [16, 32, 64]
    B, T, H, BT = A.shape
    # Output holds one 16-wide inverse tile per 16 rows; the -999 sentinel
    # makes unwritten entries obvious while debugging. Stats dtype is fp32
    # unless BT == 16, where output_dtype is used directly — presumably
    # because no further combining pass is needed then. TODO confirm.
    Ad = -999 * torch.ones(
        B, T, H, 16, device=A.device, dtype=torch.float if BT != 16 else output_dtype
    )
    Ad_modify = Ad.clone()
    chunk_indices = (
        prepare_chunk_indices(cu_seqlens, 16) if cu_seqlens is not None else None
    )
    # Number of 16-row tiles across the (possibly variable-length) batch.
    NT = len(chunk_indices) if cu_seqlens is not None else triton.cdiv(T, 16)
    import os
    if os.getenv("TRITON_INTERPRET", None) == "1":
        # Interpreter mode: run the original Triton kernel directly.
        solve_tril_16x16_kernel[NT, B * H](
            A=A,
            Ad=Ad,
            cu_seqlens=cu_seqlens,
            chunk_indices=chunk_indices,
            T=T,
            H=H,
            BT=BT,
            num_warps=1,
            num_stages=4,
        )
        return Ad
    # Otherwise use the host-side emulation path.
    Ad_modify = solve_tril_16x16_kernel_new(
        NT,
        B,
        A=A,
        Ad=Ad_modify,
        cu_seqlens=cu_seqlens,
        chunk_indices=chunk_indices,
        T=T,
        H=H,
        BT=BT,
        IS_VARLEN= True if cu_seqlens is not None else False,
        # num_warps=1,
        # num_stages=4,
    ).to(A.dtype)
    return Ad_modify

View File

@@ -0,0 +1,85 @@
import torch
import torch.nn.functional as F
def l2norm(x: torch.FloatTensor, dim: int = -1, eps: float = 1e-6):
    """Unit-L2 normalization along ``dim``, matching the FLA library's l2norm."""
    squared_sum = (x * x).sum(dim=dim, keepdim=True)
    return x * torch.rsqrt(squared_sum + eps)
def torch_chunk_gated_delta_rule(
    query,
    key,
    value,
    g,
    beta,
    chunk_size=64,
    initial_state=None,
    output_final_state=False,
    use_qk_l2norm_in_kernel=False,
):
    """Pure-PyTorch reference of the chunked gated delta rule.

    Inputs are (batch, seq, heads, dim) tensors; they are transposed to
    (batch, heads, seq, dim), upcast to fp32, padded to a multiple of
    ``chunk_size``, processed chunk-by-chunk with a recurrent state carried
    between chunks, and the output is cast back to the input dtype.

    Returns:
        (core_attn_out, last_recurrent_state) — the state is None unless
        ``output_final_state`` is True.
    """
    initial_dtype = query.dtype
    if use_qk_l2norm_in_kernel:
        query = l2norm(query, dim=-1, eps=1e-6)
        key = l2norm(key, dim=-1, eps=1e-6)
    # Move heads before sequence and compute everything in fp32.
    query, key, value, beta, g = [
        x.transpose(1, 2).contiguous().to(torch.float32) for x in (query, key, value, beta, g)
    ]

    batch_size, num_heads, sequence_length, k_head_dim = key.shape
    v_head_dim = value.shape[-1]
    # Pad the sequence to a whole number of chunks.
    pad_size = (chunk_size - sequence_length % chunk_size) % chunk_size
    query = F.pad(query, (0, 0, 0, pad_size))
    key = F.pad(key, (0, 0, 0, pad_size))
    value = F.pad(value, (0, 0, 0, pad_size))
    beta = F.pad(beta, (0, pad_size))
    g = F.pad(g, (0, pad_size))
    total_sequence_length = sequence_length + pad_size
    scale = 1 / (query.shape[-1] ** 0.5)
    query = query * scale

    v_beta = value * beta.unsqueeze(-1)
    k_beta = key * beta.unsqueeze(-1)
    # reshape to chunks
    query, key, value, k_beta, v_beta = [
        x.reshape(x.shape[0], x.shape[1], -1, chunk_size, x.shape[-1]) for x in (query, key, value, k_beta, v_beta)
    ]
    g = g.reshape(g.shape[0], g.shape[1], -1, chunk_size)
    mask = torch.triu(torch.ones(chunk_size, chunk_size, dtype=torch.bool, device=query.device), diagonal=0)

    # chunk decay: cumulative log-gates within each chunk.
    g = g.cumsum(dim=-1)
    decay_mask = ((g.unsqueeze(-1) - g.unsqueeze(-2)).tril().exp().float()).tril()
    attn = -((k_beta @ key.transpose(-1, -2)) * decay_mask).masked_fill(mask, 0)
    # Forward substitution: invert (I + lower-triangular attn) row by row.
    for i in range(1, chunk_size):
        row = attn[..., i, :i].clone()
        sub = attn[..., :i, :i].clone()
        attn[..., i, :i] = row + (row.unsqueeze(-1) * sub).sum(-2)
    attn = attn + torch.eye(chunk_size, dtype=attn.dtype, device=attn.device)
    value = attn @ v_beta
    k_cumdecay = attn @ (k_beta * g.exp().unsqueeze(-1))
    last_recurrent_state = (
        torch.zeros(batch_size, num_heads, k_head_dim, v_head_dim).to(value)
        if initial_state is None
        else initial_state.to(value)
    )
    core_attn_out = torch.zeros_like(value)
    mask = torch.triu(torch.ones(chunk_size, chunk_size, dtype=torch.bool, device=query.device), diagonal=1)

    # for each chunk: intra-chunk attention + inter-chunk recurrent update.
    for i in range(0, total_sequence_length // chunk_size):
        q_i, k_i, v_i = query[:, :, i], key[:, :, i], value[:, :, i]
        attn = (q_i @ k_i.transpose(-1, -2) * decay_mask[:, :, i]).masked_fill_(mask, 0)
        v_prime = (k_cumdecay[:, :, i]) @ last_recurrent_state
        v_new = v_i - v_prime
        attn_inter = (q_i * g[:, :, i, :, None].exp()) @ last_recurrent_state
        core_attn_out[:, :, i] = attn_inter + attn @ v_new
        last_recurrent_state = (
            last_recurrent_state * g[:, :, i, -1, None, None].exp()
            + (k_i * (g[:, :, i, -1, None] - g[:, :, i]).exp()[..., None]).transpose(-1, -2) @ v_new
        )

    if not output_final_state:
        last_recurrent_state = None
    # Drop padding and restore the (batch, seq, heads, dim) layout/dtype.
    core_attn_out = core_attn_out.reshape(core_attn_out.shape[0], core_attn_out.shape[1], -1, core_attn_out.shape[-1])
    core_attn_out = core_attn_out[:, :, :sequence_length]
    core_attn_out = core_attn_out.transpose(1, 2).contiguous().to(initial_dtype)
    return core_attn_out, last_recurrent_state

View File

@@ -0,0 +1,180 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang
#
# This file contains code copied from the flash-linear-attention project.
# The original source code was licensed under the MIT license and included
# the following copyright notice:
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
# ruff: noqa: E501
import contextlib
import functools
import logging
import os
from enum import Enum
from typing import Any, Callable, Literal, Optional
import torch
from vllm.triton_utils import triton
logger = logging.getLogger(__name__)
# Feature flags, read once from the environment at import time.
COMPILER_MODE = os.getenv("FLA_COMPILER_MODE") == "1"  # compiler-friendly mode
FLA_CI_ENV = os.getenv("FLA_CI_ENV") == "1"  # running under CI
FLA_GDN_FIX_BT = os.getenv("FLA_GDN_FIX_BT", "0") == "1"  # presumably forces a fixed BT chunk size in GDN kernels — TODO confirm
SUPPRESS_LEVEL = int(os.getenv("GDN_RECOMPUTE_SUPPRESS_LEVEL", "0"))  # recompute-suppression level for GDN
def tensor_cache(
        fn: Callable[..., torch.Tensor]) -> Callable[..., torch.Tensor]:
    """
    A decorator that caches the most recent results of a function with tensor inputs.

    Up to ``cache_size`` (4) recent call signatures are kept. Hits are matched
    by argument *identity* (``is``), not equality, and are moved to the
    most-recently-used position; on overflow the oldest entry is evicted.

    Args:
        fn (Callable[..., torch.Tensor]):
            The function to be decorated. It should take tensor inputs and return tensor outputs.

    Returns:
        Callable[..., torch.Tensor]:
            A wrapped version of the input function with small-LRU caching.
    """
    # Bug fix: the cache is a list of (args, kwargs, result) triples — the
    # previous annotation (`tuple[Optional[tuple], Optional[dict], Any]`)
    # described a single triple, not the container.
    cache_entries: list[tuple[tuple, dict, Any]] = []
    cache_size = 4

    @functools.wraps(fn)
    def wrapper(*args: Any, **kwargs: Any) -> Any:
        nonlocal cache_entries
        for i, (last_args, last_kwargs, last_result) in enumerate(cache_entries):
            # Identity comparison: cheap and safe for tensors (no data compare).
            if len(args) == len(last_args) and len(kwargs) == len(last_kwargs) \
                    and all(a is b for a, b in zip(args, last_args)) \
                    and all(k in last_kwargs and v is last_kwargs[k]
                            for k, v in kwargs.items()):
                # Move the hit to the most-recently-used slot.
                cache_entries = cache_entries[:i] + cache_entries[i + 1:] + [
                    (args, kwargs, last_result)
                ]
                return last_result

        result = fn(*args, **kwargs)
        if len(cache_entries) >= cache_size:
            cache_entries = cache_entries[1:]  # evict the oldest entry
        cache_entries.append((args, kwargs, result))
        return result

    return wrapper
def input_guard(
fn: Callable[..., torch.Tensor]) -> Callable[..., torch.Tensor]:
"""
A decorator to make sure all input tensors are contiguous and set the device based on input tensors.
"""
@functools.wraps(fn)
def wrapper(*args, **kwargs):
contiguous_args = (i if not isinstance(i, torch.Tensor) else
i.contiguous() for i in args)
contiguous_kwargs = {
k: (v if not isinstance(v, torch.Tensor) else v.contiguous())
for k, v in kwargs.items()
}
tensor = None
for arg in args:
if isinstance(arg, torch.Tensor):
tensor = arg
break
if tensor is None:
for value in kwargs.values():
if isinstance(value, torch.Tensor):
tensor = value
break
if tensor is not None:
ctx = torch.cuda.device(tensor.device.index)
else:
ctx = contextlib.nullcontext()
with ctx:
return fn(*contiguous_args, **contiguous_kwargs)
return wrapper
@functools.cache
def get_available_device() -> str:
    """Best-effort detection of the active Triton backend ('cpu' on failure)."""
    try:
        active_driver = triton.runtime.driver.active
        return active_driver.get_current_target().backend
    except BaseException:
        # Any failure (no driver, no device, triton missing) means CPU.
        return 'cpu'
@functools.cache
def _check_platform() -> Literal['nvidia', 'amd', 'intel', 'musa']:
    """Map the Triton backend name onto a GPU vendor identifier."""
    backend = get_available_device()
    if backend == 'cuda':
        return 'nvidia'
    if backend == 'hip':
        return 'amd'
    if backend == 'xpu':
        return 'intel'
    # Unknown backends (e.g. 'musa', 'cpu') are passed through unchanged.
    return backend
# For AMD GPUs, the triton backend is 'hip', while for Nvidia GPUs, the triton backend is 'cuda'.
# However, the torch backend is 'cuda' for both Nvidia and AMD GPUs.
# Therefore, we need to check the triton backend to determine the actual GPU vendor.
device = get_available_device() if get_available_device() != 'hip' else 'cuda'
device_torch_lib = getattr(torch, device)
device_platform = _check_platform()

is_amd = (device_platform == 'amd')
# Bug fix: this previously compared against 'nvidia', which labeled NVIDIA
# machines as Intel and made the is_intel_alchemist check below probe
# torch.xpu.get_device_name on them.
is_intel = (device_platform == 'intel')
is_nvidia = (device_platform == 'nvidia')

is_intel_alchemist = (is_intel
                      and 'Intel(R) Arc(TM) A' in torch.xpu.get_device_name(0))
is_nvidia_hopper = (is_nvidia
                    and ('NVIDIA H' in torch.cuda.get_device_name(0)
                         or torch.cuda.get_device_capability()[0] >= 9))
use_cuda_graph = (is_nvidia
                  and os.environ.get('FLA_USE_CUDA_GRAPH', '0') == '1')
def get_all_max_shared_mem():
    """Per-device max shared memory in bytes; [-1] if it cannot be queried."""
    try:
        get_props = triton.runtime.driver.active.utils.get_device_properties
        return [
            get_props(i)['max_shared_mem']
            for i in range(device_torch_lib.device_count())
        ]
    except BaseException:
        return [-1]
class Backend(Enum):
    """Known GPU architectures keyed by their shared-memory capacity (bytes)."""

    ADA = 101376  # RTX 4090
    AMPERE = 166912  # A100
    HOPPER = 232448  # H100
    DEFAULT = 102400  # Default

    @classmethod
    def get_shared_memory(cls, arch: str) -> int:
        """Shared-memory size for ``arch`` (case-insensitive); DEFAULT if unknown."""
        member = cls.__members__.get(arch.upper())
        return (cls.DEFAULT if member is None else member).value
@functools.cache
def check_shared_mem(arch: str = "none", tensor_idx: int = 0) -> bool:
    """True iff device ``tensor_idx`` has at least ``arch``'s shared memory."""
    try:
        available = get_all_max_shared_mem()[tensor_idx]
        return available >= Backend.get_shared_memory(arch)
    except Exception:
        # Any lookup failure (bad index, no devices) counts as "not enough".
        return False

View File

@@ -0,0 +1,247 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang
#
# This file contains code copied from the flash-linear-attention project.
# The original source code was licensed under the MIT license and included
# the following copyright notice:
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
# ruff: noqa: E501
from typing import Optional
import torch
from vllm.triton_utils import tl, triton
from .index import prepare_chunk_indices
# Per-dtype relative tolerances used by assert_close.
RESOLUTION = {
    torch.bool: 0,
    torch.int16: 0,
    torch.int32: 0,
    torch.int64: 0,
    torch.float16: 1e-3,
    torch.float32: 1.3e-6,
    torch.bfloat16: 0.016,
    torch.complex32: 1e-3,
    torch.complex64: 1.3e-6,
}


def assert_close(res, ref, dtype, equal_nan=False, reduce_dim=1):
    """Assert ``res`` approximately equals ``ref`` under dtype-aware tolerances.

    ``atol`` scales linearly with ``reduce_dim`` (the reduced-dimension size)
    to allow for error accumulation in reductions; ``rtol`` comes from the
    per-dtype RESOLUTION table.
    """
    assert res.dtype == dtype
    ref = ref.to(dtype)
    torch.testing.assert_close(res,
                               ref,
                               atol=1e-3 * reduce_dim,
                               rtol=RESOLUTION[dtype],
                               equal_nan=equal_nan)
@triton.heuristics({"IS_VARLEN": lambda args: args["cu_seqlens"] is not None})
# @triton.autotune(
#     configs=[
#         triton.Config({}, num_warps=num_warps, num_stages=num_stages)
#         for num_warps in [2, 4, 8]
#         for num_stages in [2, 3, 4]
#     ],
#     key=["H", "K", "V", "BT", "BK", "BV", "IS_VARLEN"],
# )
@triton.jit(do_not_specialize=["T"])
def recompute_u_fwd_kernel(
    k,
    v,
    beta,
    w,
    u,
    A,
    g,
    cu_seqlens,
    chunk_indices,
    T,
    H: tl.constexpr,
    Hg: tl.constexpr,
    K: tl.constexpr,
    V: tl.constexpr,
    BT: tl.constexpr,
    BK: tl.constexpr,
    BV: tl.constexpr,
    IS_VARLEN: tl.constexpr,
):
    """Recompute u = A @ (v * beta) for one (chunk, batch*head) pair.

    Iterates over V in BV-wide slabs; ``k``, ``w``, ``Hg``, ``K`` and ``BK``
    are unused here (the launch shares one argument list with the w-kernel).
    NOTE(review): ``p_g`` is built but never loaded — apparently dead code
    kept for symmetry with recompute_w_fwd_kernel.
    """
    i_t, i_bh = tl.program_id(0), tl.program_id(1)
    i_b, i_h = i_bh // H, i_bh % H
    if IS_VARLEN:
        # Variable-length batch: resolve (sequence, chunk) and its [bos, eos).
        i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(
            chunk_indices + i_t * 2 + 1
        ).to(tl.int32)
        bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(
            cu_seqlens + i_n + 1
        ).to(tl.int32)
        T = eos - bos
    else:
        bos, eos = i_b * T, i_b * T + T
    p_beta = tl.make_block_ptr(
        beta + bos * H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,)
    )
    p_g = tl.make_block_ptr(g + (bos * H + i_h), (T,), (H,), (i_t * BT,), (BT,), (0,))
    p_A = tl.make_block_ptr(
        A + (bos * H + i_h) * BT, (T, BT), (H * BT, 1), (i_t * BT, 0), (BT, BT), (1, 0)
    )
    b_beta = tl.load(p_beta, boundary_check=(0,))
    b_A = tl.load(p_A, boundary_check=(0, 1))
    # u[:, slab] = A @ (v[:, slab] * beta), one BV-wide slab at a time.
    for i_v in range(tl.cdiv(V, BV)):
        p_v = tl.make_block_ptr(
            v + (bos * H + i_h) * V,
            (T, V),
            (H * V, 1),
            (i_t * BT, i_v * BV),
            (BT, BV),
            (1, 0),
        )
        p_u = tl.make_block_ptr(
            u + (bos * H + i_h) * V,
            (T, V),
            (H * V, 1),
            (i_t * BT, i_v * BV),
            (BT, BV),
            (1, 0),
        )
        b_v = tl.load(p_v, boundary_check=(0, 1))
        b_vb = (b_v * b_beta[:, None]).to(b_v.dtype)
        b_u = tl.dot(b_A, b_vb, allow_tf32=False)
        tl.store(p_u, b_u.to(p_u.dtype.element_ty), boundary_check=(0, 1))
@triton.heuristics({"IS_VARLEN": lambda args: args["cu_seqlens"] is not None})
# @triton.autotune(
#     configs=[
#         triton.Config({}, num_warps=num_warps, num_stages=num_stages)
#         for num_warps in [2, 4, 8]
#         for num_stages in [2, 3, 4]
#     ],
#     key=["H", "K", "V", "BT", "BK", "BV", "IS_VARLEN"],
# )
@triton.jit(do_not_specialize=["T"])
def recompute_w_fwd_kernel(
    k,
    v,
    beta,
    w,
    u,
    A,
    g,
    cu_seqlens,
    chunk_indices,
    T,
    H: tl.constexpr,
    Hg: tl.constexpr,
    K: tl.constexpr,
    V: tl.constexpr,
    BT: tl.constexpr,
    BK: tl.constexpr,
    BV: tl.constexpr,
    IS_VARLEN: tl.constexpr,
):
    """Recompute w = A @ (k * beta * exp(g)) for one (chunk, batch*head) pair.

    Iterates over K in BK-wide slabs. ``k`` has Hg (grouped/shared) heads —
    head ``i_h`` reads key head ``i_h // (H // Hg)``. ``v``, ``u``, ``V`` and
    ``BV`` are unused here (shared argument list with the u-kernel).
    """
    i_t, i_bh = tl.program_id(0), tl.program_id(1)
    i_b, i_h = i_bh // H, i_bh % H
    if IS_VARLEN:
        # Variable-length batch: resolve (sequence, chunk) and its [bos, eos).
        i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(
            chunk_indices + i_t * 2 + 1
        ).to(tl.int32)
        bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(
            cu_seqlens + i_n + 1
        ).to(tl.int32)
        T = eos - bos
    else:
        bos, eos = i_b * T, i_b * T + T
    p_beta = tl.make_block_ptr(
        beta + bos * H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,)
    )
    p_g = tl.make_block_ptr(g + (bos * H + i_h), (T,), (H,), (i_t * BT,), (BT,), (0,))
    p_A = tl.make_block_ptr(
        A + (bos * H + i_h) * BT, (T, BT), (H * BT, 1), (i_t * BT, 0), (BT, BT), (1, 0)
    )
    b_beta = tl.load(p_beta, boundary_check=(0,))
    b_A = tl.load(p_A, boundary_check=(0, 1))
    # Per-row decay factor exp(g) applied to the keys.
    b_g = tl.exp(tl.load(p_g, boundary_check=(0,)))
    for i_k in range(tl.cdiv(K, BK)):
        p_k = tl.make_block_ptr(
            k + (bos * Hg + i_h // (H // Hg)) * K,
            (T, K),
            (Hg * K, 1),
            (i_t * BT, i_k * BK),
            (BT, BK),
            (1, 0),
        )
        p_w = tl.make_block_ptr(
            w + (bos * H + i_h) * K,
            (T, K),
            (H * K, 1),
            (i_t * BT, i_k * BK),
            (BT, BK),
            (1, 0),
        )
        b_k = tl.load(p_k, boundary_check=(0, 1))
        b_kb = (b_k * b_beta[:, None] * b_g[:, None]).to(b_k.dtype)
        b_w = tl.dot(b_A, b_kb)
        tl.store(p_w, b_w.to(p_w.dtype.element_ty), boundary_check=(0, 1))
def recompute_w_u_fwd(
    k: torch.Tensor,
    v: torch.Tensor,
    beta: torch.Tensor,
    g_cumsum: torch.Tensor,
    A: torch.Tensor,
    cu_seqlens: Optional[torch.LongTensor],
) -> tuple[torch.Tensor, torch.Tensor]:
    """Launch the u- and w-recompute kernels and return ``(w, u)``.

    Both kernels share one argument list and one launch grid of
    ``(num_chunks, B * H)``; ``u`` matches ``v`` in shape/dtype while ``w``
    is a fresh ``[B, T, H, K]`` tensor allocated on ``k``'s device.
    ``cu_seqlens`` selects the variable-length path (chunk indices are
    precomputed per sequence); ``None`` means fixed-length batches.
    """
    B, T, Hg, K = k.shape
    V = v.shape[-1]
    H = v.shape[-2]
    # Chunk (time-block) size is dictated by the last dimension of A.
    BT = A.shape[-1]
    if cu_seqlens is None:
        chunk_indices = None
        NT = triton.cdiv(T, BT)
    else:
        chunk_indices = prepare_chunk_indices(cu_seqlens, BT)
        NT = len(chunk_indices)
    # Fixed tile widths along the key / value dimensions.
    BK = 64
    BV = 64
    u = torch.empty_like(v)
    w = k.new_empty(B, T, H, K)
    grid = (NT, B * H)
    # Identical keyword sets: each kernel simply ignores the operands it
    # does not need.
    launch_kwargs = dict(
        k=k,
        v=v,
        beta=beta,
        w=w,
        u=u,
        A=A,
        g=g_cumsum,
        cu_seqlens=cu_seqlens,
        chunk_indices=chunk_indices,
        T=T,
        H=H,
        Hg=Hg,
        K=K,
        V=V,
        BT=BT,
        BK=BK,
        BV=BV,
    )
    recompute_u_fwd_kernel[grid](**launch_kwargs)
    recompute_w_fwd_kernel[grid](**launch_kwargs)
    return w, u

View File

@@ -1,29 +1,13 @@
#
# Copyright (c) 2025 Baidu, Inc. All Rights Reserved.
# Copyright 2023 The vLLM team.
# Author: Dong Xinyu, Chen Zhennan, Bao Qian, Yuan Jizhong
# Email: dongxinyu03@baidu.com
# This file is a part of the vllm-kunlun project.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""layer.py"""
import torch
import os
from typing import Callable, Optional
import vllm.envs as envs
from vllm.config import get_current_vllm_config
from vllm.forward_context import ForwardContext, get_forward_context
from vllm.distributed import get_ep_group
from vllm.distributed.eplb.eplb_state import EplbState
from vllm.model_executor.layers.fused_moe import FusedMoE as VllmFusedMoE
from vllm.model_executor.layers.fused_moe import FusedMoEMethodBase as VllmFusedMoEMethodBase
@@ -101,7 +85,6 @@ class UnquantizedFusedMoEMethod(VllmUnquantizedFusedMoEMethod):
) -> torch.Tensor:
"""forward_kunlun"""
from vllm_kunlun.ops._kunlun_ops import KunlunOps as ops
if self.moe.use_ep:
return ops.fused_moe_ep(x,
layer.w13_weight,
@@ -116,6 +99,96 @@ class UnquantizedFusedMoEMethod(VllmUnquantizedFusedMoEMethod):
num_expert_group=num_expert_group,
topk_group=topk_group
)
# fused_moe does not support expert counts > 400
elif layer.local_num_experts > 400:
hidden_states = x
global_num_experts = linear_weights.shape[0]
M, N = hidden_states.shape
hidden_dim = layer.w2_weight.shape[1]
normed_score = torch.empty(M,
top_k,
dtype=torch.float32,
device=hidden_states.device)
topk_ids = torch.empty(M,
top_k,
dtype=torch.int32,
device=hidden_states.device)
num_blocks = 12
block_statistic = torch.zeros(
num_blocks, global_num_experts, dtype=torch.int32, device=hidden_states.device
)
router_logits = router_logits.float()
torch.ops._C.moe_softmax_topk_norm(
x=router_logits,
normed_score=normed_score,
topk_index=topk_ids,
block_statistic=None,
stable=True)
moe_expand = torch.empty((M * top_k, N), dtype=hidden_states.dtype, device=hidden_states.device) # [M, top_k, N], float
expert_m = torch.zeros(global_num_experts, dtype=torch.int32, device=hidden_states.device) # [E]
sorted_tokens_num_lod = torch.zeros(global_num_experts + 1, dtype=torch.int32, device=hidden_states.device) # [E+1]
sorted_tokens_idx = torch.zeros(M * top_k, dtype=torch.int32, device=hidden_states.device)
torch.ops._C.gen_block_statistic(topk_ids,block_statistic)
torch.ops._C.moe_pre_sorted(
x=hidden_states,
topk_index=topk_ids,
block_statistic=block_statistic,
moe_expand=moe_expand,
moe_index=sorted_tokens_idx,
expert_m=expert_m,
sorted_tokens_num_lod=sorted_tokens_num_lod)
y = torch.empty(M,top_k,
layer.w13_weight.shape[1],
dtype=hidden_states.dtype,
device=hidden_states.device)
moe_expand = moe_expand.view(M * top_k, hidden_dim)
torch.ops._C.moe_fc(
x=moe_expand,
weight=layer.w13_weight,
sorted_tokens_num_lod=sorted_tokens_num_lod,
sorted_tokens_idx=sorted_tokens_idx,
moe_topk=top_k,
y=y)
d = y.shape[-1] // 2
output_shape = (y.shape[:-1] + (d, ))
out1 = torch.empty(output_shape, dtype=y.dtype, device=y.device)
torch.ops._C.swiglu(y, out1)
out = torch.empty(M,top_k,
layer.w2_weight.shape[1],
dtype=hidden_states.dtype,
device=hidden_states.device)
out1 = out1.reshape(-1, out1.shape[-1])
torch.ops._C.moe_fc(
x=out1,
weight=layer.w2_weight,
sorted_tokens_num_lod=sorted_tokens_num_lod,
sorted_tokens_idx=sorted_tokens_idx,
moe_topk=top_k,
y=out)
dequant_scale = torch.ones([M, top_k], dtype = torch.float32, device=out.device)
output = torch.empty([M, N], dtype=hidden_states.dtype, device=hidden_states.device)
sorted_tokens_idx = sorted_tokens_idx.view(M, top_k)
torch.ops._C.moe_post(
x=out,
moe_index=sorted_tokens_idx,
normed_scale=normed_score,
dequant_scale=dequant_scale,
y=output
)
return output
else:
return ops.fused_moe(x,
layer.w13_weight,
@@ -155,6 +228,7 @@ class FusedMoE(VllmFusedMoE):
activation: str = "silu",
enable_eplb: bool = False,
num_redundant_experts: int = 0,
is_sequence_parallel=False,
):
super().__init__(
num_experts=num_experts, # Global number of experts
@@ -189,7 +263,7 @@ class FusedMoE(VllmFusedMoE):
# since model_config is not set in the pytest test.
model_dtype = params_dtype
moe = FusedMoEConfig.make(
moe = FusedMoEConfig(
num_experts=self.global_num_experts,
experts_per_token=top_k,
hidden_dim=hidden_size,
@@ -197,7 +271,7 @@ class FusedMoE(VllmFusedMoE):
moe_parallel_config=self.moe_parallel_config,
in_dtype=model_dtype,
max_num_tokens=envs.VLLM_MOE_DP_CHUNK_SIZE,
quant_config=quant_config,
# quant_config=quant_config,
)
self.moe_config = moe
self.quant_config = quant_config
@@ -307,4 +381,35 @@ class FusedMoE(VllmFusedMoE):
final_hidden_states = self.maybe_all_reduce_tensor_model_parallel(
final_hidden_states)
return final_hidden_states
return final_hidden_states
@classmethod
def make_expert_params_mapping(
cls,
ckpt_gate_proj_name: str,
ckpt_down_proj_name: str,
ckpt_up_proj_name: str,
num_experts: int,
num_redundant_experts: int = 0) -> list[tuple[str, str, int, str]]:
num_physical_experts = num_experts + num_redundant_experts
# In the returned mapping:
# - `expert_id` is the physical expert id
# - `weight_name` contains the weight name of the logical expert
# So that we should map the expert id to logical in `weight_name`
physical_to_logical_map = \
EplbState.build_initial_global_physical_to_logical_map(
num_experts, num_redundant_experts)
return [
# (param_name, weight_name, expert_id, shard_id)
("experts.w13_" if weight_name
in [ckpt_gate_proj_name, ckpt_up_proj_name] else "experts.w2_",
f"experts.{physical_to_logical_map[expert_id]}.{weight_name}.",
expert_id, shard_id) for expert_id in range(num_physical_experts)
for shard_id, weight_name in [
("w1", ckpt_gate_proj_name),
("w2", ckpt_down_proj_name),
("w3", ckpt_up_proj_name),
]
]

View File

@@ -12,49 +12,101 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-kunlun project.
# This file is a part of the vllm-ascend project.
#
import torch
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.layernorm import GemmaRMSNorm as OriGemmaRMSNorm
from vllm.model_executor.layers import layernorm
from typing import Optional, Union
import xtorch_ops
def vllm_kunlun_forward_cuda(
self,
x: torch.Tensor,
residual: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
"""forward_cuda"""
if x.is_contiguous() == False:
# kunlun does not support uncontiguous input and they do not think it is a bug
# so we must make it contiguous() manually
x = x.contiguous()
if self.variance_size_override is not None:
return self.forward_native(x, residual)
self,
x: torch.Tensor,
residual: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
"""forward_cuda"""
if x.is_contiguous() == False:
# kunlun does not support uncontiguous input and they do not think it is a bug
# so we must make it contiguous() manually
x = x.contiguous()
if self.variance_size_override is not None:
return self.forward_native(x, residual)
if residual is not None:
# residual_output = torch.empty_like(residual)
torch.ops._C.add_rmsnorm(
if residual is not None:
# residual_output = torch.empty_like(residual)
torch.ops._C.add_rmsnorm(
x,
residual,
residual_output=residual,
weight=self.weight.data,
eps=self.variance_epsilon,
output=x
)
return x, residual
out = torch.empty_like(x)
torch.ops._C.rmsnorm(
x,
residual,
residual_output=residual,
weight=self.weight.data,
eps=self.variance_epsilon,
output=x,
self.weight.data,
out,
self.variance_epsilon,
)
return x, residual
out = torch.empty_like(x)
torch.ops._C.rmsnorm(
x,
self.weight.data,
out,
self.variance_epsilon,
)
return out
return out
class KunlunGemmaRMSNorm(OriGemmaRMSNorm):
@staticmethod
def forward_xpu(
weight: torch.Tensor,
variance_epsilon: float,
x: torch.Tensor,
residual: Optional[torch.Tensor],
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
if x.is_contiguous() == False:
# kunlun does not support uncontiguous input and they do not think it is a bug
# so we must make it contiguous() manually
x = x.contiguous()
if residual is not None:
torch.ops._C.add_rmsnorm(
x,
residual,
residual_output=residual,
weight=weight+1,
eps=variance_epsilon,
output=x
)
return x, residual
out = torch.empty_like(x)
torch.ops._C.rmsnorm(
x,
weight+1,
out,
variance_epsilon,
)
return out
def forward_cuda(
self,
x: torch.Tensor,
residual: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
if torch.compiler.is_compiling():
self.forward_static = self.forward_xpu # only use in cudagraph
return self.forward_native(x, residual)
if not getattr(self, "_is_compiled", False):
self.forward_static = torch.compile( # type: ignore
self.forward_static, backend="aot_eager")
self._is_compiled = True
return self.forward_native(x, residual)
RMSNorm.forward_cuda = vllm_kunlun_forward_cuda
RMSNorm.forward = vllm_kunlun_forward_cuda
layernorm.GemmaRMSNorm = KunlunGemmaRMSNorm

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@@ -24,7 +24,6 @@ _PARTITION_SIZE = 512
@dataclass
class PagedAttentionMetadata:
"""Metadata for PagedAttention."""
# (batch_size,). The length of sequences (entire tokens seen so far) per
# sequence.
seq_lens_tensor: Optional[torch.Tensor]
@@ -53,18 +52,18 @@ class PagedAttention:
head_size: int,
) -> Tuple[int, ...]:
"""
Get the shape of the KV cache. Returns different shapes based on whether the computation is on-chip.
If on-chip (is_kunlun() is True), returns shape (2, num_blocks, num_kv_heads, block_size, head_size);
Otherwise, returns shape (2, num_blocks, block_size * num_kv_heads * head_size).
获取KV缓存的形状根据是否在芯片上进行计算返回不同的形状。
如果在芯片上is_kunlun()为True则返回形状(2, num_blocks, num_kv_heads, block_size, head_size)
否则,返回形状(2, num_blocks, block_size * num_kv_heads * head_size)
Args:
num_blocks (int): The number of blocks.
block_size (int): The size of each block.
num_kv_heads (int): The number of KV heads.
head_size (int): The size of each head.
num_blocks (int): 块数量。
block_size (int): 每个块大小。
num_kv_heads (int): KV头数量。
head_size (int): 每个头大小。
Returns:
Tuple[int, ...]: The shape of the KV cache, including two elements: the first element is 2, indicating the number of dimensions is 2; the second element is one of num_blocks, num_kv_heads, block_size, and head_size.
Tuple[int, ...]: KV缓存的形状包括两个元素第一个元素为2表示维度数量为2第二个元素为num_blocksnum_kv_headsblock_sizehead_size中的任意一个。
"""
if current_platform.is_kunlun():
return (2, num_blocks, num_kv_heads, block_size, head_size)
@@ -77,20 +76,20 @@ class PagedAttention:
head_size: int,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Split a cached tensor (containing key and value) into two parts, each part is a tensor.
If running on KUNLUN, the first returned tensor is the key cache, and the second tensor is the value cache.
Otherwise, the first tensor is the key cache, and the second tensor is a view of the key cache with shape (num_blocks, num_kv_heads, head_size//x, -1, x),
and the third tensor is the value cache with shape (num_blocks, num_kv_heads, head_size, -1).
将一个缓存张量包含key和value分成两部分每个部分是一个张量。
如果在KUNLUN上运行则返回的第一个张量是key缓存第二个张量是value缓存。
否则第一个张量是key缓存第二个张量是key缓存的view其形状为(num_blocks, num_kv_heads, head_size//x, -1, x)
第三个张量是value缓存其形状为(num_blocks, num_kv_heads, head_size, -1)
Args:
kv_cache (torch.Tensor): A tensor containing key and value, with shape (2, num_blocks, kv_cache_size).
num_kv_heads (int): The number of heads in multi-head attention.
head_size (int): The size of each head.
kv_cache (torch.Tensor): 包含key和value的张量形状为(2, num_blocks, kv_cache_size)
num_kv_heads (int): 多头注意力中的头数。
head_size (int): 每个头的大小。
Returns:
Tuple[torch.Tensor, torch.Tensor]:
- key_cache (torch.Tensor): A tensor containing the key cache, with shape (num_blocks, num_kv_heads, head_size//x, -1, x).
- value_cache (torch.Tensor): A tensor containing the value cache, with shape (num_blocks, num_kv_heads, head_size, -1).
- key_cache (torch.Tensor): 形状为(num_blocks, num_kv_heads, head_size//x, -1, x)包含key缓存。
- value_cache (torch.Tensor): 形状为(num_blocks, num_kv_heads, head_size, -1)包含value缓存。
"""
x = 16 // kv_cache.element_size()
num_blocks = kv_cache.shape[1]
@@ -100,7 +99,8 @@ class PagedAttention:
value_cache = kv_cache[1]
else:
key_cache = kv_cache[0]
key_cache = key_cache.view(num_blocks, num_kv_heads, head_size // x, -1, x)
key_cache = key_cache.view(num_blocks, num_kv_heads, head_size // x,
-1, x)
value_cache = kv_cache[1]
value_cache = value_cache.view(num_blocks, num_kv_heads, head_size, -1)
return key_cache, value_cache
@@ -152,17 +152,16 @@ class PagedAttention:
if blocksparse_vert_stride is not None and blocksparse_vert_stride > 1:
# use blocksparse paged attention
block_size = value_cache.size(-1)
assert (
blocksparse_block_size > 0 and blocksparse_block_size % block_size == 0
), (
f"{blocksparse_block_size=} needs to be a multiple of"
f"{block_size=} used in block_tables."
)
assert (blocksparse_block_size > 0 and
blocksparse_block_size % block_size == 0), \
(f"{blocksparse_block_size=} needs to be a multiple of"
f"{block_size=} used in block_tables.")
output = torch.empty_like(query)
block_size = value_cache.shape[3]
num_seqs, num_heads, head_size = query.shape
max_num_partitions = (max_seq_len + _PARTITION_SIZE - 1) // _PARTITION_SIZE
max_num_partitions = ((max_seq_len + _PARTITION_SIZE - 1) //
_PARTITION_SIZE)
# NOTE(woosuk): We use a simple heuristic to decide whether to use
# PagedAttention V1 or V2. If the number of partitions is 1, we use
# V1 to avoid the overhead of reduction. Also, if the number of
@@ -170,10 +169,9 @@ class PagedAttention:
# to parallelize.
# TODO(woosuk): Tune this heuristic.
# For context len > 8192, use V2 kernel to avoid shared memory shortage.
use_v1 = max_seq_len <= 8192 and (
max_num_partitions == 1 or num_seqs * num_heads > 512
)
use_v1 = (max_seq_len <= 8192
and (max_num_partitions == 1 or num_seqs * num_heads > 512))
if use_v1:
# Run PagedAttention V1.
ops.paged_attention_v1(
@@ -302,4 +300,4 @@ class PagedAttention:
) -> None:
key_caches = [kv_cache[0] for kv_cache in kv_caches]
value_caches = [kv_cache[1] for kv_cache in kv_caches]
ops.copy_blocks(key_caches, value_caches, src_to_dists)
ops.copy_blocks(key_caches, value_caches, src_to_dists)

View File

@@ -1,128 +0,0 @@
#
# Copyright (c) 2025 Baidu, Inc. All Rights Reserved.
# Author: Li Wei, Pan Xiakai, You Zeyu
# Email: liwei157@baidu.com
# This file is a part of the vllm-kunlun project.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from typing import Optional
from vllm.model_executor.layers.quantization.awq import AWQLinearMethod
def repack_int4_for_kunlun(self, packed: torch.Tensor, num_bits: int = 4):
"""Convert AWQ-packed int4 weights to Kunlun XPU format.
Input: packed[N, K], dtype=int32, saved as AWQ order
Output: packed_reordered[N, K], dtype=int32, saved as Kunlun order
"""
N, K = packed.shape
self.align_type = 1 if K % 8 == 0 else 0
assert num_bits == 4, "Only int4 supported now"
shifts = torch.arange(0, 32, num_bits, device=packed.device, dtype=torch.int32)
if self.align_type == 0: # NORMAL MODE
# Unpack AWQ order:[0, 2, 4, 6, 1, 3, 5, 7]
unpacked_awq = (packed.unsqueeze(-1) >> shifts) & 0xF # [N, K, 8]
# Reverse AWQ order and convert to KUNLUN order
AWQ_TO_KUNLUN_ORDER_NORMAL = [4, 0, 5, 1, 6, 2, 7, 3]
# [0,2,4,6,1,3,5,7] --> [1, 0, 3, 2, 5, 4, 7, 6]
unpacked_kunlun = unpacked_awq[..., AWQ_TO_KUNLUN_ORDER_NORMAL] # [N, K, 8]
# Pack to int32, order[6, 7, 4, 5, 2, 3, 0, 1]
packed_kunlun = (unpacked_kunlun << shifts).sum(
dim=-1, dtype=torch.int32
) # [N, K]
elif self.align_type == 1: # FAST MODEL
# Unpack AWQ order
unpacked_awq = (
packed.view(N, K // 8, 8).unsqueeze(-1) >> shifts
) & 0xF # [N, K//8, 8, 8]
# Reverse AWQ order and convert to KUNLUN order
AWQ_TO_KUNLUN_ORDER_FAST = [
32, 0, 36, 4, 33, 1, 37, 5,
34, 2, 38, 6, 35, 3, 39, 7,
40, 8, 44, 12, 41, 9, 45, 13,
42, 10, 46, 14, 43, 11, 47, 15,
48, 16, 52, 20, 49, 17, 53, 21,
50, 18, 54, 22, 51, 19, 55, 23,
56, 24, 60, 28, 57, 25, 61, 29,
58, 26, 62, 30, 59, 27, 63, 31
]
unpacked_awq = unpacked_awq.reshape(N, K // 8, 64)
unpacked_kunlun = unpacked_awq[..., AWQ_TO_KUNLUN_ORDER_FAST] # [N, K//8, 64]
# Pack to int32
unpacked_kunlun = unpacked_kunlun.reshape(N, K // 8, 8, 8)
packed_kunlun = (
(unpacked_kunlun << shifts).sum(dim=-1, dtype=torch.int32).reshape(N, K)
) # [N, K]
else:
raise NotImplementedError
return packed_kunlun
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
layer.qweight = torch.nn.Parameter(
(
self.repack_int4_for_kunlun(layer.qweight.data)
if layer.qweight.data.dtype == torch.int32
else layer.qweight.data
),
requires_grad=False,
)
layer.qzeros = torch.nn.Parameter(
(
self.repack_int4_for_kunlun(layer.qzeros.data)
if layer.qzeros.data.dtype == torch.int32
else layer.qzeros.data
),
requires_grad=False,
)
layer.scales = torch.nn.Parameter(layer.scales.data, requires_grad=False)
def apply(
self, layer: torch.nn.Module, x: torch.Tensor, bias: Optional[torch.Tensor] = None
) -> torch.Tensor:
qweight = layer.qweight
scales = layer.scales
qzeros = layer.qzeros
pack_factor = self.quant_config.pack_factor
out_shape = x.shape[:-1] + (qweight.shape[-1] * pack_factor,)
reshaped_x = x.reshape(-1, x.shape[-1])
# num_tokens >= threshold
FP16_MATMUL_HEURISTIC_CONDITION = x.shape[:-1].numel() >= 256
if FP16_MATMUL_HEURISTIC_CONDITION:
out = torch.ops._C.awq_dequantize(
qweight, scales, qzeros, quant_type=0, align_type=self.align_type
)
out = torch.matmul(reshaped_x, out)
else:
out = torch.ops._C.awq_gemm(
reshaped_x, qweight, scales, qzeros, align_type=self.align_type
)
if bias is not None:
out.add_(bias)
return out.reshape(out_shape)
AWQLinearMethod.repack_int4_for_kunlun = repack_int4_for_kunlun
AWQLinearMethod.process_weights_after_loading = process_weights_after_loading
AWQLinearMethod.apply = apply

View File

@@ -1,37 +1,14 @@
#
# Copyright (c) 2025 Baidu, Inc. All Rights Reserved.
#
# This file is a part of the vllm-kunlun project.
# Author: Chen Zhennan, Dong Xinyu
# Email: chenzhennan@baidu.com
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from typing import Any, Literal, Optional, cast, Callable, Optional
from compressed_tensors.config import (
CompressionFormat,
SparsityCompressionConfig,
SparsityStructure,
)
from compressed_tensors.quantization import ActivationOrdering, QuantizationStrategy
from vllm.model_executor.layers.fused_moe import (
FusedMoE,
FusedMoEMethodBase,
FusedMoeWeightScaleSupported,
)
from compressed_tensors.config import (CompressionFormat,
SparsityCompressionConfig,
SparsityStructure)
from compressed_tensors.quantization import (ActivationOrdering,
QuantizationStrategy)
from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEMethodBase,
FusedMoeWeightScaleSupported)
from vllm.model_executor.layers.quantization.utils import replace_parameter
# TODO: import position will be changed after 0.9.0
# vllm.model_executor.layers.fused_moe.fused_moe --> vllm.model_executor.layers.fused_moe
@@ -42,7 +19,6 @@ import xtorch_ops
from safetensors.torch import load_file as safe_load_file
class CompressedTensorsMoEMethod(FusedMoEMethodBase):
def get_moe_method(quant_config, layer) -> "CompressedTensorsMoEMethod":
@@ -50,239 +26,177 @@ class CompressedTensorsMoEMethod(FusedMoEMethodBase):
linear_cfg = None
for k in ("Linear", "FusedMoE", "MoE", "Moe", "Experts"):
if k in tsm and isinstance(tsm[k], dict):
linear_cfg = tsm[k]
break
linear_cfg = tsm[k]; break
if not linear_cfg:
# print("target_scheme_map missing; fallback to INT8(W8A8) method")
return CompressedTensorsW8A8Int8MoEMethod(quant_config)
wq = linear_cfg.get("weights")
aq = linear_cfg.get("input_activations")
wq = linear_cfg.get("weights"); aq = linear_cfg.get("input_activations")
if not wq or not aq:
# print("incomplete scheme; fallback to INT8(W8A8)")
return CompressedTensorsW8A8Int8MoEMethod(quant_config)
# Other branches are handled as needed; default fallback:
# 其它分流按需;默认回落:
return CompressedTensorsW8A8Int8MoEMethod(quant_config)
# copied from vllm 0.9.0
class CompressedTensorsW8A8Int8MoEMethod(CompressedTensorsMoEMethod):
def __init__(
self, quant_config: "CompressedTensorsConfig" # type: ignore # noqa E501
self,
quant_config: "CompressedTensorsConfig" # type: ignore # noqa E501
):
self.quant_config = quant_config
# Directly create a default quantization config dictionary to avoid validation issues with QuantizationArgs
# 直接创建默认的量化配置字典,避免 QuantizationArgs 的验证问题
# print("Creating default INT8 quantization config for MoE")
# 创建默认的权重量化配置字典
self.weight_quant = type('WeightQuant', (), {
'type': 'int',
'num_bits': 8,
'strategy': 'channel',
'group_size': 128,
'symmetric': True,
'dynamic': False,
'actorder': 'none',
'observer': None,
'observer_kwargs': {},
'block_structure': None
})()
# 创建默认的输入激活量化配置字典
self.input_quant = type('InputQuant', (), {
'type': 'int',
'num_bits': 8,
'strategy': 'token',
'group_size': 128,
'symmetric': True,
'dynamic': True,
'actorder': 'none',
'observer': None,
'observer_kwargs': {},
'block_structure': None
})()
# Create a default weight quantization config dictionary
self.weight_quant = type(
"WeightQuant",
(),
{
"type": "int",
"num_bits": 8,
"strategy": "channel",
"group_size": 128,
"symmetric": True,
"dynamic": False,
"actorder": "none",
"observer": None,
"observer_kwargs": {},
"block_structure": None,
},
)()
# Create a default input activation quantization config dictionary
self.input_quant = type(
"InputQuant",
(),
{
"type": "int",
"num_bits": 8,
"strategy": "token",
"group_size": 128,
"symmetric": True,
"dynamic": True,
"actorder": "none",
"observer": None,
"observer_kwargs": {},
"block_structure": None,
},
)()
# Change comparison method to directly compare strings
# 修改比较方式,直接比较字符串
per_channel = (
self.weight_quant.strategy == "channel"
and self.input_quant.strategy == "token"
)
and self.input_quant.strategy == "token")
if not per_channel:
raise ValueError(
"For INT8 Fused MoE layers, we require channelwise, "
"dynamic per token quantization. Found "
f"{self.weight_quant}, {self.input_quant}"
)
f"{self.weight_quant}, {self.input_quant}")
self.static_input_scales = not self.input_quant.dynamic
if self.static_input_scales:
raise ValueError(
"For INT8 Fused MoE layers, we require channelwise, "
"dynamic per token quantization. Found static input scales."
)
"dynamic per token quantization. Found static input scales.")
def create_weights1(
self,
layer: torch.nn.Module,
num_experts: int,
hidden_size: int,
intermediate_size_per_partition: int,
params_dtype: torch.dtype,
**extra_weight_attrs,
):
# Use float32 as a placeholder for weights to facilitate loading original weights from ckpt
w13_weight = torch.nn.Parameter(
torch.empty(
num_experts,
2 * intermediate_size_per_partition,
hidden_size,
dtype=params_dtype,
), # generally is torch.bfloat16
requires_grad=False,
)
def create_weights1(self, layer: torch.nn.Module, num_experts: int, hidden_size: int, intermediate_size_per_partition: int, params_dtype: torch.dtype, **extra_weight_attrs):
# 权重先用浮点占位,便于从 ckpt 加载原始权重
w13_weight = torch.nn.Parameter(torch.empty(
num_experts,
2 * intermediate_size_per_partition,
hidden_size,
dtype=params_dtype), # 通常是 torch.bfloat16
requires_grad=False)
layer.register_parameter("w13_weight", w13_weight)
set_weight_attrs(w13_weight, extra_weight_attrs)
w2_weight = torch.nn.Parameter(
torch.empty(
num_experts,
hidden_size,
intermediate_size_per_partition,
dtype=params_dtype,
),
requires_grad=False,
)
w2_weight = torch.nn.Parameter(torch.empty(
num_experts,
hidden_size,
intermediate_size_per_partition,
dtype=params_dtype),
requires_grad=False)
layer.register_parameter("w2_weight", w2_weight)
set_weight_attrs(w2_weight, extra_weight_attrs)
# Channel scale: float32 + 2D [E, out] (aligned with fused_moe/UT)
# 通道 scalefloat32 + 二维 [E, out](与 fused_moe/UT 对齐)
w13_weight_scale = torch.nn.Parameter(
torch.empty(
num_experts, 2 * intermediate_size_per_partition, dtype=torch.float32
),
requires_grad=False,
)
torch.empty(num_experts, 2 * intermediate_size_per_partition, dtype=torch.float32),
requires_grad=False)
w2_weight_scale = torch.nn.Parameter(
torch.empty(num_experts, hidden_size, dtype=torch.float32),
requires_grad=False,
)
requires_grad=False)
layer.register_parameter("w13_weight_scale", w13_weight_scale)
layer.register_parameter("w2_weight_scale", w2_weight_scale)
# Input scale can be dynamically calculated
# 输入 scale 动态计算即可
layer.w13_input_scale = None
layer.w2_input_scale = None
def create_weights(
self,
layer: torch.nn.Module,
num_experts: int,
hidden_size: int,
intermediate_size_per_partition: int,
params_dtype: torch.dtype,
**extra_weight_attrs,
):
w13_weight = torch.nn.Parameter(
torch.empty(
num_experts,
2 * intermediate_size_per_partition,
hidden_size,
dtype=torch.int8,
), # directly use int8
requires_grad=False,
)
def create_weights(self, layer: torch.nn.Module, num_experts: int, hidden_size: int, intermediate_size_per_partition: int, params_dtype: torch.dtype, **extra_weight_attrs):
w13_weight = torch.nn.Parameter(torch.empty(
num_experts,
2 * intermediate_size_per_partition,
hidden_size,
dtype=torch.int8), # 直接使用 int8
requires_grad=False)
layer.register_parameter("w13_weight", w13_weight)
set_weight_attrs(w13_weight, extra_weight_attrs)
w2_weight = torch.nn.Parameter(
torch.empty(
num_experts,
hidden_size,
intermediate_size_per_partition,
dtype=torch.int8,
), # directly use int8
requires_grad=False,
)
w2_weight = torch.nn.Parameter(torch.empty(
num_experts,
hidden_size,
intermediate_size_per_partition,
dtype=torch.int8), # 直接使用 int8
requires_grad=False)
layer.register_parameter("w2_weight", w2_weight)
set_weight_attrs(w2_weight, extra_weight_attrs)
# Scale factors
# 缩放因子
w13_weight_scale = torch.nn.Parameter(
torch.empty(
num_experts, 2 * intermediate_size_per_partition, dtype=torch.float32
),
requires_grad=False,
)
torch.empty(num_experts, 2 * intermediate_size_per_partition, dtype=torch.float32),
requires_grad=False)
w2_weight_scale = torch.nn.Parameter(
torch.empty(num_experts, hidden_size, dtype=torch.float32),
requires_grad=False,
)
requires_grad=False)
layer.register_parameter("w13_weight_scale", w13_weight_scale)
layer.register_parameter("w2_weight_scale", w2_weight_scale)
# Input scale can be dynamically calculated
# 输入 scale 动态计算
layer.w13_input_scale = None
layer.w2_input_scale = None
@torch.no_grad()
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
return
# Convert original weights to float32 for more robust statistics
#原始权重转 float32 做统计更稳健
w13_f = layer.w13_weight.float()
w2_f = layer.w2_weight.float()
w2_f = layer.w2_weight.float()
# Each column (abs_max) -> per-column scale (out dimension is dim=1, column is dim=-1)
# 每列(abs_max) -> per-column scaleout 维在 dim=1列在 dim=-1
qmax = 127.0
w13_abs_max = torch.amax(torch.abs(w13_f), dim=-1) # [E, 2N]
w2_abs_max = torch.amax(torch.abs(w2_f), dim=-1) # [E, H]
w2_abs_max = torch.amax(torch.abs(w2_f), dim=-1) # [E, H]
w13_scale_2d = torch.clamp(w13_abs_max, min=1e-6) / qmax # [E, 2N], float32
w2_scale_2d = torch.clamp(w2_abs_max, min=1e-6) / qmax # [E, H], float32
w2_scale_2d = torch.clamp(w2_abs_max, min=1e-6) / qmax # [E, H], float32
# Quantization: broadcast 3D scale and store back to 2D scale
# 量化:用 3D scale 广播,存回 2D scale
w13_scale_3d = w13_scale_2d.unsqueeze(-1) # [E, 2N, 1]
w2_scale_3d = w2_scale_2d.unsqueeze(-1) # [E, H, 1]
w2_scale_3d = w2_scale_2d.unsqueeze(-1) # [E, H, 1]
w13_q = torch.round(w13_f / w13_scale_3d).clamp_(-128, 127).to(torch.int8)
w2_q = torch.round(w2_f / w2_scale_3d).clamp_(-128, 127).to(torch.int8)
w2_q = torch.round(w2_f / w2_scale_3d ).clamp_(-128, 127).to(torch.int8)
# Optional: If your fused/kernel expects scale pre-multiplied by 127 (to be consistent with some UT backends), uncomment the following two lines:
# 可选:若你的 fused/kernel 期望 scale 预乘 127与某些 UT 后端一致),打开下面两行:
w13_scale_2d = w13_scale_2d * 127.0
w2_scale_2d = w2_scale_2d * 127.0
w2_scale_2d = w2_scale_2d * 127.0
# Write back parameters: weight int8; scale uses float32 + 2D
replace_parameter(
layer, "w13_weight", torch.nn.Parameter(w13_q, requires_grad=False)
)
replace_parameter(
layer, "w2_weight", torch.nn.Parameter(w2_q, requires_grad=False)
)
replace_parameter(
layer,
"w13_weight_scale",
torch.nn.Parameter(w13_scale_2d.contiguous(), requires_grad=False),
)
replace_parameter(
layer,
"w2_weight_scale",
torch.nn.Parameter(w2_scale_2d.contiguous(), requires_grad=False),
)
# Brief check
print(
f"w13: {w13_q.shape}, w13_s: {w13_scale_2d.shape}, w2: {w2_q.shape}, w2_s: {w2_scale_2d.shape}"
)
# 回写参数:权重 int8scale float32 + 2D
replace_parameter(layer, 'w13_weight', torch.nn.Parameter(w13_q, requires_grad=False))
replace_parameter(layer, 'w2_weight', torch.nn.Parameter(w2_q, requires_grad=False))
replace_parameter(layer, 'w13_weight_scale',
torch.nn.Parameter(w13_scale_2d.contiguous(), requires_grad=False))
replace_parameter(layer, 'w2_weight_scale',
torch.nn.Parameter(w2_scale_2d.contiguous(), requires_grad=False))
# 简要检查
print(f"w13: {w13_q.shape}, w13_s: {w13_scale_2d.shape}, w2: {w2_q.shape}, w2_s: {w2_scale_2d.shape}")
def apply(
self,
layer: torch.nn.Module,
@@ -300,11 +214,11 @@ class CompressedTensorsW8A8Int8MoEMethod(CompressedTensorsMoEMethod):
e_score_correction_bias: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False,
activation: str = "silu",
enable_eplb: bool = False, # Add this parameter
expert_load_view: Optional[torch.Tensor] = None, # Add this parameter
logical_to_physical_map: Optional[torch.Tensor] = None, # Add this parameter
logical_replica_count: Optional[torch.Tensor] = None, # Add this parameter
linear_weights: Optional[torch.Tensor] = None, # Add this parameter
enable_eplb: bool = False, # 添加这个参数
expert_load_view: Optional[torch.Tensor] = None, # 添加这个参数
logical_to_physical_map: Optional[torch.Tensor] = None, # 添加这个参数
logical_replica_count: Optional[torch.Tensor] = None, # 添加这个参数
linear_weights: Optional[torch.Tensor] = None, # 添加这个参数
) -> torch.Tensor:
output = torch.empty_like(x)
@@ -326,8 +240,5 @@ class CompressedTensorsW8A8Int8MoEMethod(CompressedTensorsMoEMethod):
)
return output
print(
"[Monkey Patch Applied] >>> vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors_moe.CompressedTensorsMoEMethod \
--> vllm_xpu.model_executor.layers.quantization.compressed_tensors_moe.py:CompressedTensorsMoEMethod"
)
print("[Monkey Patch Applied] >>> vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors_moe.CompressedTensorsMoEMethod \
--> vllm_xpu.model_executor.layers.quantization.compressed_tensors_moe.py:CompressedTensorsMoEMethod")

View File

@@ -1,108 +0,0 @@
#
# Copyright (c) 2025 Baidu, Inc. All Rights Reserved.
# Author: Li Wei, You Zeyu
# Email: liwei157@baidu.com, youzeyu@baidu.com
# This file is a part of the vllm-kunlun project.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from torch.nn.parameter import Parameter
from typing import Optional
from vllm.model_executor.layers.quantization.gptq import GPTQLinearMethod, ExllamaState
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
# for torch.compile
layer.qzeros = Parameter(
self.repack_int4_for_kunlun(layer.qzeros.data, self.quant_config.weight_bits)
if self.quant_config.weight_bits == 4 else layer.qzeros.data,
requires_grad=False
)
layer.qweight = Parameter(layer.qweight.data, requires_grad=False)
layer.g_idx = Parameter(layer.g_idx.data, requires_grad=False)
layer.scales = Parameter(layer.scales.data, requires_grad=False)
# exllama needs to shuffle the weight after the weight is loaded
# here we do the shuffle on first forward pass
if layer.exllama_state == ExllamaState.UNINITIALIZED:
if self.quant_config.desc_act:
layer.g_idx.data = torch.argsort(layer.g_idx).to(torch.int)
else:
layer.g_idx.data = torch.empty((0, ),
dtype=torch.int,
device=layer.g_idx.device)
layer.exllama_state = ExllamaState.READY
# No need shuffle on xpu
# ops.gptq_shuffle(layer.qweight, layer.g_idx,
# self.quant_config.weight_bits)
def repack_int4_for_kunlun(self, packed: torch.Tensor, num_bits: int = 4):
N, K = packed.shape
assert num_bits == 4, "Only int4 supported now"
shifts = torch.arange(0, 32, num_bits, device=packed.device, dtype=torch.int32)
# Unpack int32 to int4 values
unpacked_gptq = (
packed.view(N, K // 8, 8).unsqueeze(-1) >> shifts
) & 0xF # [N, K//8, 8, 8]
# Convert to KUNLUN order
GPTQ_TO_KUNLUN_ORDER_FAST = [
32, 0, 33, 1, 34, 2, 35, 3,
36, 4, 37, 5, 38, 6, 39, 7,
40, 8, 41, 9, 42, 10, 43, 11,
44, 12, 45, 13, 46, 14, 47, 15,
48, 16, 49, 17, 50, 18, 51, 19,
52, 20, 53, 21, 54, 22, 55, 23,
56, 24, 57, 25, 58, 26, 59, 27,
60, 28, 61, 29, 62, 30, 63, 31,
]
unpacked_gptq = unpacked_gptq.reshape(N, K // 8, 64)
unpacked_kunlun = unpacked_gptq[..., GPTQ_TO_KUNLUN_ORDER_FAST] # [N, K//8, 64]
# Pack to int32
unpacked_kunlun = unpacked_kunlun.reshape(N, K // 8, 8, 8)
packed_kunlun = (
(unpacked_kunlun << shifts).sum(dim=-1, dtype=torch.int32).reshape(N, K)
) # [N, K]
return packed_kunlun
def apply(
self, layer: torch.nn.Module, x: torch.Tensor, bias: Optional[torch.Tensor] = None
) -> torch.Tensor:
out_shape = x.shape[:-1] + (layer.qweight.shape[-1], )
reshaped_x = x.reshape(-1, x.shape[-1])
output = torch.ops.xspeedgate_ops.gptq_gemm(
reshaped_x,
layer.qweight,
layer.qzeros,
layer.scales,
layer.g_idx,
layer.exllama_state == ExllamaState.READY,
self.quant_config.weight_bits,
)
if bias is not None:
output.add_(bias)
return output.reshape(out_shape)
GPTQLinearMethod.repack_int4_for_kunlun = repack_int4_for_kunlun
GPTQLinearMethod.process_weights_after_loading = process_weights_after_loading
GPTQLinearMethod.apply = apply

View File

@@ -12,35 +12,33 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-kunlun project.
# This file is a part of the vllm-ascend project.
#
import torch
import xspeedgate_ops
import os
from vllm.model_executor.layers.rotary_embedding import (
RotaryEmbedding,
YaRNScalingRotaryEmbedding,
DynamicNTKScalingRotaryEmbedding,
MRotaryEmbedding,
)
RotaryEmbedding, YaRNScalingRotaryEmbedding, DynamicNTKScalingRotaryEmbedding, MRotaryEmbedding)
from typing import Optional, Tuple
import xtorch_ops
def vllm_kunlun_compute_cos_sin_cache(self) -> torch.Tensor:
"""Compute the cos and sin cache."""
inv_freq = self._compute_inv_freq(self.base)
if hasattr(self, "scaling_factor"):
self.max_position_embeddings = int(
self.max_position_embeddings * self.scaling_factor
)
if hasattr(self, 'scaling_factor'):
self.max_position_embeddings = int(self.max_position_embeddings * self.scaling_factor)
t = torch.arange(self.max_position_embeddings, dtype=torch.float)
freqs = torch.einsum("i,j -> ij", t, inv_freq)
cos = freqs.cos()
sin = freqs.sin()
if os.getenv("FUSED_QK_ROPE_OP") == "1":
#对于glm4-9b-chatrope跑forward_native所以需要cache保持特定的形状这里通过环境变量控制
#对于qwen2.5-vlrope跑mrope也需要cache保持特定的形状
#也就是说跑glm4-9b-chat、qwen2.5-vl需要设置GLM4_CHAT环境变量为1
if os.getenv('ROPE_NATIVE_2D') == "1":
cache = torch.cat((cos, sin), dim=-1)
return cache
if os.getenv('USE_ORI_ROPE') == "0":
cache_cos = torch.cat((cos, cos), dim=-1)
cache_sin = torch.cat((sin, sin), dim=-1)
# [2, self.max_position_embeddings, self.rotary_dim * 2]
@@ -51,89 +49,108 @@ def vllm_kunlun_compute_cos_sin_cache(self) -> torch.Tensor:
def vllm_kunlun_forward_cuda(
self,
positions: torch.Tensor,
query: torch.Tensor,
key: Optional[torch.Tensor] = None,
offsets: Optional[torch.Tensor] = None,
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
"""forward_cuda"""
from vllm_kunlun.ops._kunlun_ops import KunlunOps as ops
self,
positions: torch.Tensor,
query: torch.Tensor,
key: Optional[torch.Tensor] = None,
offsets: Optional[torch.Tensor] = None,
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
"""forward_cuda"""
from vllm_kunlun.ops._kunlun_ops import KunlunOps as ops
if (
self.cos_sin_cache.device != query.device
or self.cos_sin_cache.dtype != query.dtype
):
self.cos_sin_cache = self.cos_sin_cache.to(query.device, dtype=query.dtype)
# ops.rotary_embedding()/batched_rotary_embedding()
# are in-place operations that update the query and key tensors.
if offsets is not None:
ops.batched_rotary_embedding(
positions,
query,
key,
self.head_size,
self.cos_sin_cache,
self.is_neox_style,
self.rotary_dim,
offsets,
)
if self.cos_sin_cache.device != query.device or \
self.cos_sin_cache.dtype != query.dtype:
self.cos_sin_cache = self.cos_sin_cache.to(query.device,
dtype=query.dtype)
# ops.rotary_embedding()/batched_rotary_embedding()
# are in-place operations that update the query and key tensors.
if offsets is not None:
ops.batched_rotary_embedding(positions, query, key, self.head_size,
self.cos_sin_cache,
self.is_neox_style, self.rotary_dim,
offsets)
else:
ops.rotary_embedding(positions, query, key, self.head_size,
self.cos_sin_cache, self.is_neox_style)
return query, key
def apply_interleaved_rope(x: torch.Tensor,
mrope_section: list[int]) -> torch.Tensor:
"""Apply interleaved MRoPE to 3D rotary embeddings.
Reorganizes frequency layout from chunked [TTT...HHH...WWW] to
interleaved [THTHWHTHW...TT], preserving frequency continuity.
"""
x_t = x[0].clone()
x_t[..., 1:mrope_section[1] * 3:3] = x[1, ..., 1:mrope_section[1] * 3:3]
x_t[..., 2:mrope_section[2] * 3:3] = x[2, ..., 2:mrope_section[2] * 3:3]
return x_t
def vllm_kunlun_apply_rotary_emb(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor,
is_neox_style: bool) -> torch.Tensor:
"""
Args:
x: [num_tokens, num_heads, head_size]
cos: [num_tokens, head_size // 2]
sin: [num_tokens, head_size // 2]
is_neox_style: Whether to use the Neox-style or GPT-J-style rotary
positional embeddings.
"""
cos = cos.unsqueeze(-2).to(x.dtype)
sin = sin.unsqueeze(-2).to(x.dtype)
if is_neox_style:
x1, x2 = torch.chunk(x, 2, dim=-1)
else:
query, key = ops.rotary_embedding(
positions,
query,
key,
self.head_size,
self.cos_sin_cache,
self.is_neox_style,
)
return query, key
x1 = x[..., ::2]
x2 = x[..., 1::2]
o1 = x1 * cos - x2 * sin
o2 = x2 * cos + x1 * sin
if is_neox_style:
return torch.cat((o1, o2), dim=-1)
else:
return torch.stack((o1, o2), dim=-1).flatten(-2)
def vllm_kunlun_mrope_forward_cuda(
self,
positions: torch.Tensor,
query: torch.Tensor,
key: Optional[torch.Tensor] = None,
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
"""PyTorch-native implementation equivalent to forward().
self,
positions: torch.Tensor,
query: torch.Tensor,
key: Optional[torch.Tensor] = None,
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
"""PyTorch-native implementation equivalent to forward().
Args:
positions:
[num_tokens,] (text only) or
[3, num_tokens] (T/H/W positions with multimodal inputs)
query: [num_tokens, num_heads * head_size]
key: [num_tokens, num_kv_heads * head_size]
"""
Args:
positions:
[num_tokens,] (text only) or
[3, num_tokens] (T/H/W positions with multimodal inputs)
query: [num_tokens, num_heads * head_size]
key: [num_tokens, num_kv_heads * head_size]
"""
assert positions.ndim == 2
assert key is not None
query, key = torch.ops.xspeedgate_ops.mrotary_embedding_fwd_v0(
query,
key,
positions.to(dtype=torch.int32),
self.cos_sin_cache,
self.mrope_interleaved,
self.is_neox_style,
self.head_size,
self.rotary_dim,
self.mrope_section[0],
self.mrope_section[1],
self.mrope_section[2]
)
assert positions.ndim == 2
assert key is not None
return query, key
query, key = torch.ops.xspeedgate_ops.mrotary_embedding_fwd_v0(
query,
key,
positions.to(dtype=torch.int32),
self.cos_sin_cache,
False, # self.mrope_interleaved,
self.head_size,
self.rotary_dim,
self.mrope_section[0],
self.mrope_section[1],
self.mrope_section[2],
)
return query, key
RotaryEmbedding.forward_cuda = vllm_kunlun_forward_cuda
RotaryEmbedding.forward = vllm_kunlun_forward_cuda
if os.getenv("KUNLUN_ENABLE_MULTI_LORA") == "1" or os.getenv("FUSED_QK_ROPE_OP") == "1":
RotaryEmbedding._compute_cos_sin_cache = vllm_kunlun_compute_cos_sin_cache
else:
pass
# RotaryEmbedding.forward_cuda = vllm_kunlun_forward_cuda
# RotaryEmbedding.forward = vllm_kunlun_forward_cuda
# RotaryEmbedding._compute_cos_sin_cache = vllm_kunlun_compute_cos_sin_cache
MRotaryEmbedding.forward_cuda = vllm_kunlun_mrope_forward_cuda
MRotaryEmbedding.forward = vllm_kunlun_mrope_forward_cuda
# MRotaryEmbedding._compute_cos_sin_cache = vllm_kunlun_compute_cos_sin_cache
YaRNScalingRotaryEmbedding._compute_inv_freq = RotaryEmbedding._compute_inv_freq
# YaRNScalingRotaryEmbedding._compute_cos_sin_cache = vllm_kunlun_compute_cos_sin_cache
def Split_Norm_Rope(
@@ -145,36 +162,27 @@ def Split_Norm_Rope(
max_position_embeddings: int,
q_head_num: int,
kv_head_num: int,
head_dim: int,
partial_rotary_factor: float = 1.0,
head_dim:int
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
num_tokens = qkv.shape[0]
rotary_dim = head_dim
if partial_rotary_factor < 1.0:
rotary_dim = int(rotary_dim * partial_rotary_factor)
q_emb_out = torch.empty(
(num_tokens, q_head_num * head_dim), dtype=qkv.dtype, device=qkv.device
)
k_emb_out = torch.empty(
(num_tokens, kv_head_num * head_dim), dtype=qkv.dtype, device=qkv.device
)
v_out = torch.empty(
(num_tokens, kv_head_num * head_dim), dtype=qkv.dtype, device=qkv.device
)
rotary_dim=head_dim
q_emb_out = torch.empty((num_tokens, q_head_num * head_dim), dtype=qkv.dtype, device=qkv.device)
k_emb_out = torch.empty((num_tokens, kv_head_num * head_dim), dtype=qkv.dtype, device=qkv.device)
v_out = torch.empty((num_tokens, kv_head_num * head_dim), dtype=qkv.dtype, device=qkv.device)
torch.ops._C.split_norm_rope_neox(
q_emb_out,
k_emb_out,
v_out,
qkv,
cos_sin_cache,
q_norm_weight,
k_norm_weight,
positions,
num_tokens,
max_position_embeddings,
q_head_num,
kv_head_num,
head_dim,
rotary_dim,
)
return q_emb_out, k_emb_out, v_out
q_emb_out,
k_emb_out,
v_out,
qkv,
cos_sin_cache,
q_norm_weight,
k_norm_weight,
positions,
num_tokens,
max_position_embeddings,
q_head_num,
kv_head_num,
head_dim,
rotary_dim,
)
return q_emb_out, k_emb_out, v_out

File diff suppressed because it is too large Load Diff

View File

@@ -98,12 +98,10 @@ cached_backends: Dict[int, CompilerFn] = {}
unset = Unset.token
from torch._C._dynamo.eval_frame import set_eval_frame
def _maybe_set_eval_frame(callback: DynamoCallback):
# A wrapper on set_eval_frame that is guarded by a Justknob.
# Users can disable torchDynamo by setting the JK to False.
# from torch._C._dynamo.eval_frame import set_eval_frame
#from torch._C._dynamo.eval_frame import set_eval_frame
if not justknobs_check("pytorch/compiler:enable_compiler_set_eval_frame"):
torch._dynamo.utils.warn_once(
@@ -130,7 +128,7 @@ DONT_WRAP_FILES = {
def _debug_get_cache_entry_list(
code: Union[types.CodeType, Callable[..., Any]],
code: Union[types.CodeType, Callable[..., Any]]
) -> List[CacheEntry]:
"""
Given a code object or a callable object, retrieve the cache entries
@@ -373,9 +371,9 @@ class _TorchDynamoContext:
# add context containing GraphModule to any GraphModule forward functions
if isinstance(fn, GraphModule):
# add context containing GraphModule to any GraphModule forward functions
code_context.get_context(fn.forward.__code__)["orig_graphmodule"] = (
weakref.ref(fn)
)
code_context.get_context(fn.forward.__code__)[
"orig_graphmodule"
] = weakref.ref(fn)
# Optimize the forward method of torch.nn.Module object
if isinstance(fn, torch.nn.Module):
@@ -789,11 +787,9 @@ def _optimize(
hooks,
backend_ctx_ctor,
dynamic=dynamic,
compiler_config=(
backend.get_compiler_config()
if hasattr(backend, "get_compiler_config")
else None
),
compiler_config=backend.get_compiler_config()
if hasattr(backend, "get_compiler_config")
else None,
rebuild_ctx=rebuild_ctx,
)
@@ -907,11 +903,9 @@ class FlattenInputOutputSignature(torch.fx.interpreter.Transformer):
flat_args[i],
symbolic_context=StatelessSymbolicContext(
dynamic_sizes=[
(
DimDynamic.DYNAMIC
if d in flat_args_dynamic_dims[i]
else DimDynamic.STATIC
)
DimDynamic.DYNAMIC
if d in flat_args_dynamic_dims[i]
else DimDynamic.STATIC
for d in range(len(flat_args[i].shape))
],
constraint_sizes=[None] * len(flat_args[i].shape),

View File

@@ -4,28 +4,26 @@ import os
from typing import TYPE_CHECKING, Any, Callable, Optional
if TYPE_CHECKING:
VLLM_MULTI_LOGPATH: str = ("./log",)
ENABLE_VLLM_MULTI_LOG: bool = (False,)
ENABLE_VLLM_INFER_HOOK: bool = (False,)
ENABLE_VLLM_OPS_HOOK: bool = (False,)
ENABLE_VLLM_MODULE_HOOK: bool = False
VLLM_MULTI_LOGPATH : str = "./log",
ENABLE_VLLM_MULTI_LOG : bool = False,
ENABLE_VLLM_INFER_HOOK : bool = False,
ENABLE_VLLM_OPS_HOOK : bool = False,
ENABLE_VLLM_MODULE_HOOK : bool = False
def maybe_convert_int(value: Optional[str]) -> Optional[int]:
"""
If the value is None, return None; otherwise, convert the string to an integer and return it.
如果值是None则返回None否则将字符串转换为整数并返回。
Args:
value (Optional[str], optional): The optional string to convert. Defaults to None.
value (Optional[str], optional): 要转换的可选字符串. Defaults to None.
Returns:
Optional[int]: If the value is None, return None; otherwise, convert the string to an integer and return it.
Optional[int]: 如果值是None则返回None否则将字符串转换为整数并返回.
"""
if value is None:
return None
return int(value)
# The begin-* and end* here are used by the documentation generator
# to extract the used env vars.
@@ -33,56 +31,59 @@ def maybe_convert_int(value: Optional[str]) -> Optional[int]:
xvllm_environment_variables: dict[str, Callable[[], Any]] = {
# path to the logs of redirect-output, abstrac of related are ok
"VLLM_MULTI_LOGPATH": lambda: os.environ.get("VLLM_MULTI_LOGPATH", "./logs"),
# turn on / off multi-log of multi nodes & multi cards
"ENABLE_VLLM_MULTI_LOG": lambda: (
os.environ.get("ENABLE_VLLM_MULTI_LOG", "False").lower() in ("true", "1")
),
# turn on / off XVLLM infer stage log ability
"ENABLE_VLLM_INFER_HOOK": lambda: (
os.environ.get("ENABLE_VLLM_INFER_HOOK", "False").lower() in ("true", "1")
),
# turn on / off XVLLM infer_ops log ability
"ENABLE_VLLM_OPS_HOOK": lambda: (
os.environ.get("ENABLE_VLLM_OPS_HOOK", "False").lower() in ("true", "1")
),
"ENABLE_VLLM_MODULE_HOOK": lambda: (
os.environ.get("ENABLE_VLLM_MODULE_HOOK", "False").lower() in ("true", "1")
),
"VLLM_MULTI_LOGPATH":
lambda: os.environ.get("VLLM_MULTI_LOGPATH", "./logs"),
# turn on / off multi-log of multi nodes & multi cards
"ENABLE_VLLM_MULTI_LOG":
lambda: (os.environ.get("ENABLE_VLLM_MULTI_LOG", "False").lower() in
("true", "1")),
# turn on / off XVLLM infer stage log ability
"ENABLE_VLLM_INFER_HOOK":
lambda: (os.environ.get("ENABLE_VLLM_INFER_HOOK", "False").lower() in
("true", "1")),
# turn on / off XVLLM infer_ops log ability
"ENABLE_VLLM_OPS_HOOK":
lambda: (os.environ.get("ENABLE_VLLM_OPS_HOOK", "False").lower() in
("true", "1")),
"ENABLE_VLLM_MODULE_HOOK":
lambda: (os.environ.get("ENABLE_VLLM_MODULE_HOOK", "False").lower() in
("true", "1")),
# fuse sorted op with fused_moe kernel
"ENABLE_VLLM_MOE_FC_SORTED": lambda: (
os.environ.get("ENABLE_VLLM_MOE_FC_SORTED", "False").lower() in ("true", "1")
),
"ENABLE_VLLM_MOE_FC_SORTED":
lambda: (os.environ.get("ENABLE_VLLM_MOE_FC_SORTED", "False").lower() in
("true", "1")),
# enable custom dpsk scaling rope
"ENABLE_CUSTOM_DPSK_SCALING_ROPE": lambda: (
os.environ.get("ENABLE_CUSTOM_DPSK_SCALING_ROPE", "False").lower()
in ("true", "1")
),
"ENABLE_CUSTOM_DPSK_SCALING_ROPE":
lambda: (os.environ.get("ENABLE_CUSTOM_DPSK_SCALING_ROPE", "False").lower() in
("true", "1")),
# fuse qkv split & qk norm & qk rope
# only works for qwen3 dense and qwen3 moe models
"ENABLE_VLLM_FUSED_QKV_SPLIT_NORM_ROPE": lambda: (
os.environ.get("ENABLE_VLLM_FUSED_QKV_SPLIT_NORM_ROPE", "False").lower()
in ("true", "1")
),
"ENABLE_VLLM_FUSED_QKV_SPLIT_NORM_ROPE":
lambda: (os.environ.get("ENABLE_VLLM_FUSED_QKV_SPLIT_NORM_ROPE", "False").lower() in
("true", "1")),
}
# end-env-vars-definition
def __getattr__(name: str):
"""
This function is called when an attribute that doesn't exist is accessed.
If the attribute is one of the xvllm_environment_variables, return the corresponding value.
Otherwise, raise an AttributeError.
当调用不存在的属性时该函数被调用。如果属性是xvllm_environment_variables中的一个则返回相应的值。否则引发AttributeError异常。
Args:
name (str): The name of the attribute to retrieve.
name (str): 要获取的属性名称。
Raises:
AttributeError (Exception): If the attribute is not one of xvllm_environment_variables, this exception is raised.
AttributeError (Exception): 如果属性不是xvllm_environment_variables中的一个则会引发此异常。
Returns:
Any, optional: If the attribute is one of xvllm_environment_variables, the corresponding value is returned; otherwise, None is returned.
Any, optional: 如果属性是xvllm_environment_variables中的一个则返回相应的值否则返回None。
"""
# lazy evaluation of environment variables
if name in xvllm_environment_variables:
@@ -92,14 +93,13 @@ def __getattr__(name: str):
def __dir__():
"""
Returns a list of all visible variable names.
返回一个包含所有可见的变量名称的列表。
返回值list一个包含所有可见的变量名称的列表这些变量是通过`xvllm_environment_variables`字典定义的。
Returns:
list: A list of all visible variable names, which are defined through the `xvllm_environment_variables` dictionary.
Returns:
List[str]: A list of all visible variable names.
These variables are defined through the `xvllm_environment_variables` dictionary.
List[str]: 一个包含所有可见的变量名称的列表。
这些变量是通过`xvllm_environment_variables`字典定义的。
"""
return list(xvllm_environment_variables.keys())

View File

@@ -19,10 +19,11 @@ class KunlunPlatform(Platform):
@property
def device_type(self):
"""Returns the device type, which is fixed as 'cuda'.
"""
返回设备类型,固定为'cuda'
"""
return "cuda"
def is_kunlun(self) -> bool:
"""is_kunlun"""
return self._enum == PlatformEnum.CUDA
@@ -69,13 +70,14 @@ class KunlunPlatform(Platform):
@classmethod
def get_device_name(cls, device_id: int = 0) -> str:
"""Returns the device name, which defaults to "kunlun".
"""
获取设备名称,默认返回 "kunlun"
Args:
device_id (int, optional): The device ID, default is 0. Ignored in this method. Defaults to 0.
device_id (int, optional): 设备ID默认为0. Ignored in this method. Defaults to 0.
Returns:
str: The device name, which is fixed as "kunlun".
str: 设备名称,固定返回 "kunlun".
"""
return "kunlun"
@@ -89,23 +91,26 @@ class KunlunPlatform(Platform):
@classmethod
def get_device_total_memory(cls, device_id: int = 0) -> int:
"""Returns the total memory size of the device in bytes (B). Defaults to the total memory size of the first device.
If the `device_id` parameter is not an integer or exceeds the available device range, a ValueError will be raised.
"""
获取设备总内存大小单位为字节B。默认返回第一个设备的总内存大小。
如果传入参数`device_id`不是整数或者超出了可用设备范围将会引发ValueError异常。
Args:
device_id (int, optional): The device ID, default is 0. Defaults to 0.
device_id (int, optional): 设备ID默认为0. Defaults to 0.
Raises:
ValueError: If the `device_id` parameter is not an integer or exceeds the available device range, this exception is raised.
ValueError: 当传入的`device_id`不是整数或者超出了可用设备范围时引发此异常。
Returns:
int: The total memory size of the device in bytes (B).
int: 设备总内存大小单位为字节B
"""
return psutil.virtual_memory().total
@classmethod
def inference_mode(cls):
"""Returns a context manager that disables gradient computation.
"""
进入推理模式,禁止计算梯度。
返回torch.no_grad(),一个上下文管理器,用于禁止计算梯度。
"""
return torch.no_grad()
@@ -114,29 +119,31 @@ class KunlunPlatform(Platform):
"""get_device_capability"""
major, minor = torch.cuda.get_device_capability()
return DeviceCapability(major=major, minor=minor)
@classmethod
def check_and_update_config(cls, vllm_config: "VllmConfig") -> None:
"""Updates the default values of various components based on the configuration.
If not specified, automatically selects the worker class based on certain conditions.
If the block size is not set in the cache configuration, it is set to 16.
If using MLA and `VLLM_ATTENTION_BACKEND` is not set or is set to "FLASHMLA",
the cache block size is set to 64.
If running in DeepEP high throughput backend, data parallelism greater than 1, and CUDA graph mode,
it forces the use of eager mode, as DP + DeepEP high throughput kernels are not CUDA graph compatible,
and using DeepEP low latency kernels can resolve this issue.
"""
根据配置更新各个部分的默认值。
如果未指定则根据某些条件自动选择worker类。
如果缓存配置中没有设置块大小则将其设置为16。
如果使用MLA并且`VLLM_ATTENTION_BACKEND`未设置或设置为"FLASHMLA"
则将缓存块大小设置为64。
如果在DeepEP高吞吐量后端、数据并行大于1和CUDA图形模式下运行则强制
强制执行即时模式因为DP + DeepEP高吞吐量内核不是CUDA图形兼容的而且
使用DeepEP低延迟内核可以解决这个问题。
Args:
vllm_config (VllmConfig): VLLM configuration object.
vllm_config (VllmConfig): VLLM配置对象。
Raises:
NotImplementedError: If multi-step scheduling is used on vLLM V1, this exception is raised.
Please remove the --num-scheduler-steps argument from the command line.
NotImplementedError: If MLA is used on vLLM V1, this exception is raised.
Please ensure that the `VLLM_ATTENTION_BACKEND` environment variable is set before using MLA.
NotImplementedError: 如果在vLLM V1上使用多步调度则会引发NotImplementedError。
请从命令行中删除--num-scheduler-steps参数。
NotImplementedError: 如果在vLLM V1上使用MLA则会引发NotImplementedError。
请确保在使用MLA之前设置了`VLLM_ATTENTION_BACKEND`环境变量。
Returns:
None: No return value.
None: 无返回值。
"""
parallel_config = vllm_config.parallel_config
scheduler_config = vllm_config.scheduler_config
@@ -159,7 +166,7 @@ class KunlunPlatform(Platform):
"vllm.v1.worker.gpu_worker.Worker"
else:
parallel_config.worker_cls = "vllm.worker.worker.Worker"
cache_config = vllm_config.cache_config
if cache_config and cache_config.block_size is None:
cache_config.block_size = 16
@@ -198,9 +205,10 @@ class KunlunPlatform(Platform):
vllm_config.compilation_config.pass_config.enable_fusion = False
vllm_config.compilation_config.use_inductor = False
@classmethod
def get_attn_backend_cls(cls, selected_backend, head_size, dtype,
kv_cache_dtype, block_size, use_v1, use_mla,use_sink):
kv_cache_dtype, block_size, use_v1, use_mla,use_sink, use_sparse=False):
"""
Returns the class of attention backend based on the selected backend and other parameters.
@@ -227,15 +235,16 @@ class KunlunPlatform(Platform):
def get_current_memory_usage(cls,
device: Optional[torch.types.Device] = None
) -> float:
"""Gets the current memory usage of the device, including allocated and max allocated.
If no device is specified, defaults to the current context's device.
Args:
device (Optional[torch.types.Device], optional): Optional device object, defaults to None. Defaults to the current context's device.
Returns:
float: Returns a float representing the current memory usage of the device, in bytes.
"""
获取当前设备的内存使用情况,包括已分配和最大分配。
如果未指定设备,则默认为当前上下文中的设备。
Args:
device (Optional[torch.types.Device], optional): 可选的设备对象默认为None。默认为当前上下文中的设备。
Returns:
float: 返回一个浮点数表示当前设备的内存使用情况单位是字节bytes
Raises:
None.
"""
@@ -244,17 +253,18 @@ class KunlunPlatform(Platform):
@classmethod
def is_async_output_supported(cls, enforce_eager: Optional[bool]) -> bool:
"""Checks if asynchronous output is supported.
By default, Kunlun does not support asynchronous output.
Args:
enforce_eager (Optional[bool], optional): Whether to enforce eager execution. Defaults to None.
None means not to force eager execution, but to automatically select based on the current environment.
Returns:
bool: True means asynchronous output is supported, False means asynchronous output is not supported.
"""
# Assume Kunlun does not support asynchronous output
判断是否支持异步输出。
默认情况下Kunlun 不支持异步输出。
Args:
enforce_eager (Optional[bool], optional): 是否强制使用 eager execution. Defaults to None.
None 表示不强制使用 eager execution而是根据当前环境自动选择。
Returns:
bool: True 表示支持异步输出False 表示不支持异步输出。
"""
# 假设 Kunlun 不支持异步输出
return False
@classmethod
@@ -279,11 +289,42 @@ class KunlunPlatform(Platform):
@classmethod
def get_device_communicator_cls(cls) -> str:
'''
'''
communicator
'''
return "vllm_kunlun.distributed.kunlun_communicator.KunlunCommunicator"
return "vllm_kunlun.distributed.kunlun_communicator.KunlunCommunicator"
@classmethod
def get_punica_wrapper(cls):
return "vllm_kunlun.lora.punica_wrapper.punica_kunlun.PunicaWrapperKunlun"
return "vllm.lora.punica_wrapper.punica_cpu.PunicaWrapperCPU"
@classmethod
def check_if_supports_dtype(cls, torch_dtype: torch.dtype):
'''
Kunlun3平台支持的数据类型
'''
supported_dtypes = {
torch.float32,
torch.float16,
torch.bfloat16,
torch.int8,
}
if torch_dtype not in supported_dtypes:
raise ValueError(
f"Kunlun platform does not support dtype {torch_dtype}. "
"Supported dtypes are: fp32, fp16, bf16, int8."
)
def opaque_attention_op(cls) -> bool:
'''
确保V1 Graph在Kunlun3平台使用vllm.unified_attention_with_output_kunlun作为split ops
'''
return True
@classmethod
def support_hybrid_kv_cache(cls) -> bool:
return True
@classmethod
def support_static_graph_mode(cls) -> bool:
return True

View File

@@ -1,61 +1,28 @@
#
# Copyright (c) 2025 Baidu, Inc. All Rights Reserved.
# This file is a part of the vllm-kunlun project.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os, sys
import vllm
from torch.utils._python_dispatch import TorchDispatchMode
import vllm_kunlun.platforms.envs as xenvs
import vllm_kunlun.platforms.envs as xenvs
from vllm.utils import weak_ref_tensor
from typing import (
TYPE_CHECKING,
Any,
Callable,
Generic,
Literal,
NamedTuple,
Optional,
Tuple,
TypeVar,
Union,
cast,
overload,
get_origin,
get_args,
List,
)
from typing import (TYPE_CHECKING, Any, Callable, Generic, Literal, NamedTuple,
Optional, Tuple, TypeVar, Union, cast, overload,
get_origin, get_args, List)
import torch
from torch.library import Library
import inspect
import typing
def redirect_output():
"""
Redirect output to a specified directory and name the log files as pp=0_rank=X or pp=1_rank=X.
If it is the first process of the first process group, use pp=0; otherwise, use pp=1.
重定向输出到指定目录,并将日志文件命名为pp=0_rank=Xpp=1_rank=X
如果是第一个进程组的第一个进程则使用pp=0否则使用pp=1
Args:
No parameters.
无参数。
Returns:
No return value, directly modify the file descriptors of sys.stdout and sys.stderr.
无返回值,直接修改sys.stdoutsys.stderr的文件描述符。
"""
from vllm.distributed import get_tensor_model_parallel_rank, get_pp_group
rank = get_tensor_model_parallel_rank()
dir_path = xenvs.VLLM_MULTI_LOGPATH
os.makedirs(dir_path, exist_ok=True)
@@ -63,54 +30,48 @@ def redirect_output():
log_file = os.path.join(dir_path, f"pp=0_rank={rank}.log")
else:
log_file = os.path.join(dir_path, f"pp=1_rank={rank}.log")
fd = os.open(log_file, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o644)
fd = os.open(log_file, os.O_WRONLY | os.O_CREAT| os.O_TRUNC, 0o644)
os.dup2(fd, sys.stdout.fileno())
os.dup2(fd, sys.stderr.fileno())
os.close(fd)
def multi_log_monkey_patch(func):
"""
Monkey patch function for logging multiple times, used to test log redirection functionality.
This function will print a log message each time the patched function is called.
多次打印日志的猴子补丁函数,用于测试日志重定向功能。
该函数会在每次调用被补丁的函数时打印一条日志信息。
Args:
func (function): The original function to be patched.
func (function): 需要被补丁的原始函数。
Returns:
function: A wrapped new function that prints a log message each time it is called.
function: 返回一个包装后的新函数,每次调用都会打印一条日志信息。
"""
def wrapper(*args, **kwargs):
print("[monkey patch] ensure_model_parallel_initialized")
func(*args, **kwargs)
redirect_output()
return wrapper
# if os.environ.get("VLLM_MULTI_LOG", "0") == "1":
if xenvs.ENABLE_VLLM_MULTI_LOG:
print("ENABLE_VLLM_MULTI_LOG monkey--------")
vllm.distributed.ensure_model_parallel_initialized = multi_log_monkey_patch(
vllm.distributed.ensure_model_parallel_initialized
)
vllm.distributed.ensure_model_parallel_initialized)
class StageHookPre(object):
def __call__(self, *args, **kwargs):
"""
This method will be automatically executed when the object is called.
If the current attention metadata is not None and a token has been processed, print "Per Token Start"; otherwise, print "First Token Start".
在调用对象时,会自动执行此方法。
如果当前的attention metadata不为None并且已经处理了一个token则打印"Per Token Start";否则打印"First Token Start"
Args:
args (tuple, optional): Variable length argument list, default is an empty tuple.
kwargs (dict, optional): Keyword arguments, default is an empty dictionary.
args (tuple, optional): 可变参数,默认为空元组。
kwargs (dict, optional): 关键字参数,默认为空字典。
Returns:
None: No return value.
None: 无返回值。
"""
from vllm.forward_context import get_forward_context
attn_metadata = get_forward_context().attn_metadata
if attn_metadata is not None:
if attn_metadata.num_decode_tokens == 0:
@@ -118,22 +79,20 @@ class StageHookPre(object):
else:
print("Per Token Start", flush=True)
class StageHookPost(object):
def __call__(self, *args, **kwargs):
"""
If the current context's attention metadata is not None and num_decode_tokens equals 0, print "First Token End".
Otherwise, print "Per Token End".
如果当前上下文中的attention metadata不为None并且num_decode_tokens等于0则打印"First Token End"
否则,打印"Per Token End"
Args:
args (Tuple[Any]): Variable length argument list, unused parameters are passed in.
kwargs (Dict[str, Any]): Keyword arguments, unused parameters are passed in.
args (Tuple[Any]): 可变长度参数列表,无用参数传入。
kwargs (Dict[str, Any]): 字典类型的关键字参数,无用参数传入。
Returns:
None: No return value.
None: 该函数没有返回值。
"""
from vllm.forward_context import get_forward_context
attn_metadata = get_forward_context().attn_metadata
if attn_metadata is not None:
if attn_metadata.num_decode_tokens == 0:
@@ -145,24 +104,22 @@ class StageHookPost(object):
class ModuleLoggingHookPre(object):
def __init__(self):
"""
Initialization function to initialize the indentation list and name list.
The indentation list is used to store the indentation information of each line,
and the name list is used to store the name of each variable or function.
初始化函数,用于初始化缩进列表和名称列表。
缩进列表用于存储每一行的缩进信息,名称列表用于存储每一个变量或函数的名称。
"""
self.indent_list = list()
self.indent_list.append("")
self.name_list = list()
def __call__(self, *args, **kwargs):
"""
This method overrides the __call__ method and is used when the class is instantiated.
It increases the current indentation by one Tab and records the current class name.
It prints the start information, flush=True means it will be output to the console immediately.
重写了 __call__ 方法,用于在类实例化时调用。
将当前缩进增加一个 Tab并记录当前类名称。
打印开始信息flush=True 表示立即输出到控制台。
Args:
args (tuple): Variable length argument list, default is an empty tuple.
kwargs (dict): Keyword arguments, default is an empty dictionary.
args (tuple): 传入的参数列表,第一个元素是类实例。
kwargs (dict): 传入的关键字参数列表,不使用。
Returns:
None.
"""
@@ -174,70 +131,62 @@ class ModuleLoggingHookPre(object):
class ModuleLoggingHookPost(object):
    """Post-forward hook that logs module exit and unwinds the call stacks."""

    def __init__(self, indent_list, name_list):
        """Create a post-hook sharing state with the matching pre-hook.

        Args:
            indent_list (List[str]): Stack of indentation strings, shared
                with the corresponding ``ModuleLoggingHookPre`` instance.
            name_list (List[str]): Stack of module-name strings, shared
                with the corresponding ``ModuleLoggingHookPre`` instance.
                Both stacks must stay the same length or logging breaks.
        """
        self.indent_list = indent_list
        self.name_list = name_list

    def __call__(self, *args, **kwargs):
        """Log that the innermost module finished and pop one stack level.

        Positional and keyword arguments (the module, its inputs/outputs
        as passed by torch's forward-hook machinery) are ignored.
        """
        message = self.indent_list[-1] + self.name_list[-1] + " Module End"
        print(message, flush=True)
        del self.indent_list[-1]
        del self.name_list[-1]
# if os.environ.get("ENABLE_VLLM_MODULE_HOOK", "0") == "1":
if xenvs.ENABLE_VLLM_MODULE_HOOK:
    from torch.nn.modules.module import (
        register_module_forward_pre_hook,
        register_module_forward_hook,
    )

    # Pre/post hooks share the same two stacks so module enter/exit
    # logging stays balanced and correctly indented.
    module_logging_hook_pre = ModuleLoggingHookPre()
    module_logging_hook_post = ModuleLoggingHookPost(
        module_logging_hook_pre.indent_list, module_logging_hook_pre.name_list
    )
    register_module_forward_pre_hook(module_logging_hook_pre)
    register_module_forward_hook(module_logging_hook_post)
else:
    module_logging_hook_pre = None
    module_logging_hook_post = None
class LoggingDispatchMode(TorchDispatchMode):
def __init__(self):
    """Initialize the logging dispatch mode.

    No state of its own is set up here; this simply delegates to the
    ``TorchDispatchMode`` base initializer so the instance can be used
    as a dispatch-mode context manager.
    """
    super().__init__()
def __torch_dispatch__(self, func, types, args=(), kwargs=None):
"""
Override the default dispatch behavior of torch.nn.Module.
This function will be called before and after each method call on this module.
It can be used to log information about the method calls.
Args:
func (function): The function that is being called on this module.
types (Tuple[str]): A tuple of strings representing the type signatures of the arguments.
See torch.types for more details.
args (Tuple[Any], optional): The positional arguments passed to the function. Defaults to ().
kwargs (Dict[str, Any], optional): The keyword arguments passed to the function. Defaults to {}.
Returns:
Any: The result returned by the function.
"""
@@ -249,20 +198,19 @@ class LoggingDispatchMode(TorchDispatchMode):
print(indent + "{} calling".format(func), flush=True)
result = func(*args, **(kwargs or {}))
print(indent + "{} called".format(func), flush=True)
return result
return result
class CUDAGraphInnerWatcher(TorchDispatchMode):
def __init__(self, name_list):
"""
Initialization function to save the name list to the class attribute.
It also creates a dictionary to keep track of the tensors that have been traced.
初始化函数,将传入的名称列表保存到类属性中。
同时创建一个字典来记录已经追踪过的张量。
Args:
name_list (List[str]): A list of names of tensors to be tracked.
name_list (List[str]): 包含需要追踪的张量名称的列表。
Returns:
None.
"""
@@ -274,13 +222,13 @@ class CUDAGraphInnerWatcher(TorchDispatchMode):
Override the default dispatch behavior of PyTorch tensors to track
the tracing process. If the result of a function call is a tensor on CUDA,
it will be added to the traced_tensor dictionary with the name of the function.
Args:
func (Callable): The function to be called.
types (Tuple[Type]): The type hints of the function.
args (Tuple[Any], optional): Positional arguments for the function. Defaults to ().
kwargs (Optional[Dict[str, Any]], optional): Keyword arguments for the function. Defaults to None.
Returns:
Any: The result of the function call.
"""
@@ -292,13 +240,13 @@ class CUDAGraphInnerWatcher(TorchDispatchMode):
def __exit__(self, exc_type, exc_val, exc_tb):
"""
Clear the traced_tensor and name_list, and call the parent class's __exit__ method.
清空 traced_tensor name_list,并调用父类的 __exit__ 方法。
Args:
exc_type (Optional[Type[BaseException]]): The type of the exception, default is None.
exc_val (Optional[BaseException]): The value of the exception, default is None.
exc_tb (Optional[TracebackType]): he traceback object, default is None.
exc_type (Optional[Type[BaseException]]): 异常类型,默认为 None
exc_val (Optional[BaseException]): 异常值,默认为 None
exc_tb (Optional[TracebackType]): Traceback 对象,默认为 None
Returns:
None.
"""
@@ -308,11 +256,22 @@ class CUDAGraphInnerWatcher(TorchDispatchMode):
self.name_list.clear()
super(CUDAGraphInnerWatcher, self).__exit__(exc_type, exc_val, exc_tb)
# def patch_annotations_for_schema(func):
# sig = inspect.signature(func)
# new_params = []
# for name, param in sig.parameters.items():
# anno = param.annotation
# if anno == list[int]:
# anno = typing.List[int]
# new_params.append(param.replace(annotation=anno))
# new_sig = sig.replace(parameters=new_params)
# func.__signature__ = new_sig
# return func
def patch_annotations_for_schema(func):
"""
At runtime, replace list[int] and Optional[list[int]] in the function signature with typing.List[int] and Optional[typing.List[int]]
so that torch.library.infer_schema can recognize it.
运行时替换函数签名里的 list[int]Optional[list[int]] typing.List[int] / Optional[typing.List[int]]
torch.library.infer_schema 能识别
"""
sig = inspect.signature(func)
new_params = []
@@ -320,7 +279,7 @@ def patch_annotations_for_schema(func):
for name, param in sig.parameters.items():
ann = param.annotation
# If it is Optional[T]
# 如果是 Optional[T]
if get_origin(ann) is typing.Union and type(None) in get_args(ann):
inner_type = [a for a in get_args(ann) if a is not type(None)][0]
if get_origin(inner_type) is list: # Optional[list[int]]
@@ -328,7 +287,7 @@ def patch_annotations_for_schema(func):
new_ann = Optional[List[inner_args[0] if inner_args else typing.Any]]
param = param.replace(annotation=new_ann)
# If it is a direct list[int]
# 如果是直接 list[int]
elif get_origin(ann) is list:
args = get_args(ann)
new_ann = List[args[0] if args else typing.Any]
@@ -339,23 +298,20 @@ def patch_annotations_for_schema(func):
func.__signature__ = sig.replace(parameters=new_params)
return func
def supports_custom_op() -> bool:
    """Return True when this torch build exposes ``torch.library.custom_op``.

    Older PyTorch builds lack this API and must register custom ops
    through a different path.
    """
    library_module = torch.library
    return hasattr(library_module, "custom_op")
vllm_lib = Library("vllm", "FRAGMENT") # noqa
def direct_register_custom_op(
op_name: str,
op_func: Callable,
mutates_args: list[str],
fake_impl: Optional[Callable] = None,
target_lib: Optional[Library] = None,
dispatch_key: str = "CUDA",
tags: tuple[torch.Tag, ...] = (),
op_name: str,
op_func: Callable,
mutates_args: list[str],
fake_impl: Optional[Callable] = None,
target_lib: Optional[Library] = None,
dispatch_key: str = "CUDA",
tags: tuple[torch.Tag, ...] = (),
):
"""
`torch.library.custom_op` can have significant overhead because it
@@ -374,28 +330,25 @@ def direct_register_custom_op(
"""
if not supports_custom_op():
from vllm.platforms import current_platform
assert not current_platform.is_cuda_alike(), (
"cuda platform needs torch>=2.4 to support custom op, "
"chances are you are using an old version of pytorch "
"or a custom build of pytorch. It is recommended to "
"use vLLM in a fresh new environment and let it install "
"the required dependencies."
)
"the required dependencies.")
return
import torch.library
if hasattr(torch.library, "infer_schema"):
patched_func = patch_annotations_for_schema(op_func)
schema_str = torch.library.infer_schema(op_func, mutates_args=mutates_args)
schema_str = torch.library.infer_schema(op_func,
mutates_args=mutates_args)
else:
# for pytorch 2.4
import torch._custom_op.impl
schema_str = torch._custom_op.impl.infer_schema(op_func, mutates_args)
my_lib = target_lib or vllm_lib
my_lib.define(op_name + schema_str, tags=tags)
my_lib.impl(op_name, op_func, dispatch_key=dispatch_key)
if fake_impl is not None:
my_lib._register_fake(op_name, fake_impl)
my_lib._register_fake(op_name, fake_impl)

View File

@@ -1,8 +1,5 @@
#
# Copyright (c) 2025 Baidu, Inc. All Rights Reserved.
# Author: Dong Xinyu, Bao Qian, Chen Zhennan, Ma Tianyu, Wang Haowen
# Email: dongxinyu03@baidu.com
# This file is a part of the vllm-kunlun project.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -15,6 +12,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-kunlun project.
#
from vllm.config import VllmConfig, get_layers_from_vllm_config
import xtorch_ops
from dataclasses import dataclass
@@ -24,8 +23,8 @@ import torch
import numpy as np
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionMetadata, AttentionLayer, AttentionType)
from vllm.attention.backends.utils import CommonAttentionState
from vllm.attention.backends.utils import is_block_tables_empty, compute_slot_mapping_start_idx, compute_slot_mapping
# from vllm.attention.backends.utils import CommonAttentionState
# from vllm.attention.backends.utils import is_block_tables_empty, compute_slot_mapping_start_idx, compute_slot_mapping
from vllm_kunlun.ops.paged_attn import (PagedAttention, PagedAttentionMetadata)
from vllm_kunlun.ops._kunlun_ops import KunlunOps
@@ -45,6 +44,7 @@ from vllm.v1.worker.block_table import BlockTable
from vllm.config import VllmConfig, get_layers_from_vllm_config
class KunlunAttentionBackend(AttentionBackend):
"""KunlunAttentionBackend"""
# crucial to cuda graph
@@ -70,10 +70,10 @@ class KunlunAttentionBackend(AttentionBackend):
"""get_builder_cls"""
return KunlunAttentionMetadataBuilder
@staticmethod
def get_state_cls() -> Type["CommonAttentionState"]:
"""get_state_cls"""
return CommonAttentionState
# @staticmethod
# def get_state_cls() -> Type["CommonAttentionState"]:
# """get_state_cls"""
# return CommonAttentionState
@staticmethod
def get_kv_cache_shape(
@@ -81,6 +81,7 @@ class KunlunAttentionBackend(AttentionBackend):
block_size: int,
num_kv_heads: int,
head_size: int,
cache_dtype_str: str = "auto"
) -> Tuple[int, ...]:
"""get_kv_cache_shape"""
# return (2, num_blocks, block_size, num_kv_heads * head_size)
@@ -132,7 +133,11 @@ class KunlunMetadata(AttentionMetadata, PagedAttentionMetadata):
# Cuda-graph is currently enabled for decoding only.
# TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention.
use_cuda_graph: bool
slot_mapping: torch.Tensor
block_tables: torch.Tensor
multi_modal_placeholder_index_maps: Optional[torch.Tensor] = None
# (batch_size,). The sequence length per sequence. Sequence length means
# the computed tokens + new tokens None if it is a decoding.
seq_lens: Optional[List[int]] = None
@@ -143,6 +148,11 @@ class KunlunMetadata(AttentionMetadata, PagedAttentionMetadata):
# [4, 6], it is [0, 4, 10].
seq_start_loc: Optional[torch.Tensor] = None
# Prefix cache loc
kv_lod_cpu: Optional[torch.Tensor] = None
kv_lod_xpu: Optional[torch.Tensor] = None
# (batch_size,) A tensor of context lengths (tokens that are computed
# so far).
context_lens_tensor: Optional[torch.Tensor] = None
@@ -181,6 +191,7 @@ class KunlunMetadata(AttentionMetadata, PagedAttentionMetadata):
# Number of tokens input to encoder
num_encoder_tokens: Optional[int] = None
enable_kv_scales_calculation: Optional[bool] = False
# Cross-attention memory-mapping data structures: slot mapping
# and block tables
cross_slot_mapping: Optional[torch.Tensor] = None
@@ -193,6 +204,11 @@ class KunlunMetadata(AttentionMetadata, PagedAttentionMetadata):
use_cascade: Optional[bool] = False
seq_lens_tensor_cpu: Optional[torch.Tensor] = None
num_prefill_tokens: int = 0
num_decode_tokens: int = 0
num_prefills: int = 0
num_decodes: int = 0
def __post_init__(self):
"""__post_init__"""
@@ -253,6 +269,19 @@ class KunlunMetadata(AttentionMetadata, PagedAttentionMetadata):
input_positions = (None if self.input_positions is None else
self.input_positions[-self.num_prefills:])
if self.kv_lod_cpu is None:
kv_lod_cpu = None
kv_lod_xpu = None
else:
start = -(self.num_prefills + 1)
base_cpu = self.kv_lod_cpu[start]
kv_lod_cpu = self.kv_lod_cpu[start:] - base_cpu
base_xpu = self.kv_lod_xpu[start]
kv_lod_xpu = self.kv_lod_xpu[start:] - base_xpu
# Construct & cache prefill-phase attention metadata structure
self._cached_prefill_metadata = KunlunMetadata(
num_actual_tokens=self.num_actual_tokens,
@@ -264,7 +293,9 @@ class KunlunMetadata(AttentionMetadata, PagedAttentionMetadata):
slot_mapping=slot_mapping,
seq_lens=seq_lens,
seq_lens_tensor=seq_lens_tensor,
seq_start_loc=None,
seq_start_loc = None,
kv_lod_cpu=kv_lod_cpu,
kv_lod_xpu=kv_lod_xpu,
max_query_len=self.max_query_len,
max_kv_len=self.max_kv_len,
max_prefill_seq_len=self.max_prefill_seq_len,
@@ -413,18 +444,28 @@ class KunlunAttentionMetadataBuilder:
self._num_decode_tokens = num_decode_tokens
self._num_prefill_tokens = num_prefill_tokens
return modified_batch
def build_for_cudagraph_capture(
self, common_attn_metadata: CommonAttentionMetadata
) -> KunlunMetadata:
attn_metadata = self.build(0, common_attn_metadata)
# When doing full graph capture, setting seq_lens to
# max_model_len will cause graph capture to be extremely
# slow, so here we set it to 1.
attn_metadata.seq_lens_tensor.fill_(1)
return attn_metadata
def build(self, common_prefix_len: int,
common_attn_metadata: CommonAttentionMetadata):
"""build"""
num_reqs=common_attn_metadata.num_reqs
num_actual_tokens=common_attn_metadata.num_actual_tokens
max_query_len=common_attn_metadata.max_query_len
common_prefix_len=common_prefix_len
num_reqs = common_attn_metadata.num_reqs
num_actual_tokens = common_attn_metadata.num_actual_tokens
max_query_len = common_attn_metadata.max_query_len
common_prefix_len = common_prefix_len
block_table_tensor = common_attn_metadata.block_table_tensor
slot_mapping = common_attn_metadata.slot_mapping
max_seq_len = int(common_attn_metadata.seq_lens_cpu.max())
query_start_loc_host = common_attn_metadata.query_start_loc_cpu[:num_reqs + 1]
query_start_loc = common_attn_metadata.query_start_loc_cpu[:num_reqs + 1].to(
@@ -432,18 +473,18 @@ class KunlunAttentionMetadataBuilder:
seq_lens = common_attn_metadata.seq_lens
seq_lens_cpu = common_attn_metadata.seq_lens_cpu
seq_start_loc = list(accumulate(seq_lens, initial=0))
if len(seq_start_loc) != num_reqs + 1:
seq_start_loc = query_start_loc_host.tolist()
if seq_start_loc[-1] != num_actual_tokens:
seq_start_loc = query_start_loc_host.tolist()
seq_start_loc_tensor = torch.empty(len(seq_start_loc), dtype=torch.int32, device=self.device)
seq_start_loc_tensor.copy_(torch.as_tensor(seq_start_loc, dtype=torch.int32))
kv_lod_cpu = torch.zeros(num_reqs + 1, dtype=torch.int32, device="cpu")
kv_lod_cpu[1:] = seq_lens_cpu.to(torch.int32).cumsum(dim=0)
kv_lod_xpu = kv_lod_cpu.to(self.device)
num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens =\
split_decodes_and_prefills(common_attn_metadata)
@@ -456,6 +497,7 @@ class KunlunAttentionMetadataBuilder:
max_decode_seq_len = np.max(tmp_decode_scheduled_tokens)
tmp_prefill_scheduled_tokens = num_scheduled_tokens[num_decodes: num_reqs]
if num_prefill_tokens == 0:
max_prefill_seq_len = 0
else:
@@ -473,6 +515,8 @@ class KunlunAttentionMetadataBuilder:
num_decode_tokens=num_decode_tokens,
seq_lens_tensor=seq_lens,
seq_lens_tensor_cpu=seq_lens_cpu,
kv_lod_xpu=kv_lod_xpu,
kv_lod_cpu=kv_lod_cpu,
max_query_len=max_prefill_seq_len,
max_prefill_seq_len=max_prefill_seq_len,
max_decode_seq_len=max_decode_seq_len,
@@ -483,7 +527,6 @@ class KunlunAttentionMetadataBuilder:
use_cuda_graph=False,
use_cascade=use_cascade,
)
return attn_metadata
def can_run_in_cudagraph(
@@ -514,11 +557,15 @@ class KunlunAttentionImpl(AttentionImpl[KunlunMetadata]):
attn_type: AttentionType = AttentionType.DECODER,
use_irope: bool = False,
sinks:Optional[torch.Tensor]= None,
multi_modal_placeholder_index_maps:Optional[torch.Tensor]= None,
) -> None:
"""__init__"""
if blocksparse_params is not None:
raise ValueError(
"kunlunAttention does not support block-sparse attention.")
# if logits_soft_cap is not None:
# raise ValueError(
# "kunlunAttention does not support attention logits soft capping.")
self.num_heads = num_heads
self.head_size = head_size
self.scale = float(scale)
@@ -547,6 +594,7 @@ class KunlunAttentionImpl(AttentionImpl[KunlunMetadata]):
"Sinks must have the same number of heads as the number of "
f"heads in the layer. Sinks shape: {sinks.shape}, "
f"num_heads: {num_heads}.")
self.multi_modal_placeholder_index_maps = multi_modal_placeholder_index_maps
def forward(
self,
@@ -560,7 +608,8 @@ class KunlunAttentionImpl(AttentionImpl[KunlunMetadata]):
v_scale: float = 1.0,
attn_type: AttentionType = AttentionType.DECODER,
output: Optional[torch.Tensor] = None,
output_scale: Optional[torch.Tensor] = None
output_scale: Optional[torch.Tensor] = None,
output_block_scale: Optional[torch.Tensor] = None
) -> torch.Tensor:
"""forward"""
query = query.view(-1, self.num_heads, self.head_size)
@@ -597,12 +646,22 @@ class KunlunAttentionImpl(AttentionImpl[KunlunMetadata]):
# If kv_cache is not provided, the new key and value tensors are
# not cached. This happens during the initial memory
value = value.contiguous()
xtorch_ops.reshape_and_cache(
key,
value,
key_cache,
value_cache,
updated_slot_mapping)
if key_cache.is_contiguous():
xtorch_ops.reshape_and_cache(
key,
value,
key_cache,
value_cache,
updated_slot_mapping)
else:
cast_key_cache = key_cache.squeeze(1).unsqueeze(-2)
cast_value_cache = value_cache.squeeze(1).unsqueeze(-2)
xtorch_ops.reshape_and_cache_flash(
key,
value,
cast_key_cache,
cast_value_cache,
updated_slot_mapping)
assert attn_type == AttentionType.DECODER
# Decoder self-attention supports chunked prefill.
@@ -614,22 +673,38 @@ class KunlunAttentionImpl(AttentionImpl[KunlunMetadata]):
if prefill_meta := attn_metadata.prefill_metadata:
# Prompt run.
prefill_query = query[num_decode_tokens:attn_metadata.num_actual_tokens]
prefill_key = key[num_decode_tokens:attn_metadata.num_actual_tokens]
prefill_value = value[num_decode_tokens:attn_metadata.num_actual_tokens]
assert prefill_query.shape[0] == num_prefill_tokens
output[num_decode_tokens:attn_metadata.num_actual_tokens] = KunlunOps.multi_query_kv_attention(
prefill_meta.query_start_loc,prefill_meta.query_start_loc_host, prefill_query, prefill_key, prefill_value,
alibi_slopes=self.alibi_slopes).view_as(prefill_query)
xtorch_ops.prefill_attention(
q=prefill_query,
k=key_cache, # Key Cache (block_num, head, block_size, dim)
v=value_cache,
out=output[num_decode_tokens:attn_metadata.num_actual_tokens],
is_causal=True,
is_prefix_cache=True,
block_table=prefill_meta.block_tables,
context_qlen_lod_cpu=prefill_meta.query_start_loc_host,
context_qlen_lod_xpu=prefill_meta.query_start_loc,
context_kvlen_lod_cpu=prefill_meta.kv_lod_cpu,
context_kvlen_lod_xpu=prefill_meta.kv_lod_xpu,
alibi_slopes=self.alibi_slopes,
softmax_lse=None,
sink=self.sinks
)
if decode_meta := attn_metadata.decode_metadata:
assert attn_type != AttentionType.ENCODER_ONLY, (
"Encoder-only models should not have decode metadata.")
decode_query = query[:num_decode_tokens]
if key_cache.is_contiguous():
tmp_block_tables = decode_meta.block_tables
else:
tmp_block_tables = decode_meta.block_tables * 2 # only test in Qwen3-Next
xtorch_ops.paged_attention(
x=decode_query,
k_cache=key_cache,
v_cache=value_cache,
block_tables=decode_meta.block_tables,
block_tables=tmp_block_tables,
context_lens_cpu=decode_meta.seq_lens_tensor_cpu,
context_lens_xpu=decode_meta.seq_lens_tensor,
is_context=False,
@@ -639,7 +714,6 @@ class KunlunAttentionImpl(AttentionImpl[KunlunMetadata]):
)
# Reshape the output tensor.
return output.view(-1, self.num_heads * self.head_size)
def use_cascade_attention(
common_prefix_len: int,
query_lens: np.ndarray,
@@ -650,13 +724,18 @@ def use_cascade_attention(
num_sms: int,
use_local_attention: bool = False,
) -> bool:
"""
TODO: Not Yet Supported on Kunlun platform
"""Decide whether to use cascade attention.
This function 1) checks whether cascade attention is supported with the
given configuration, and 2) heuristically decides whether using cascade
attention can improve performance.
"""
# Too short common prefix. Probably not worth using cascade attention.
# We use an arbitrary threshold of 256 tokens. TODO: Tune this threshold.
# NOTE(woosuk): This is the common case. We should return False as soon as
# possible to avoid any unnecessary computation.
return False
if common_prefix_len < 256:
return False
# Cascade attention is currently not supported with these variants.

View File

@@ -1,91 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
from vllm.utils import is_pin_memory_available, make_tensor_with_pad
def get_token_bin_counts_and_mask(
    tokens: torch.Tensor,
    vocab_size: int,
    num_seqs: int,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Count token occurrences per sequence and build a presence mask.

    Args:
        tokens: Integer tensor of shape (num_seqs, max_len). Entries equal
            to ``vocab_size`` are padding; they land in one extra bin that
            is sliced away before returning.
        vocab_size: Size of the vocabulary.
        num_seqs: Number of sequences (rows) in ``tokens``.

    Returns:
        ``(bin_counts, mask)``, both of shape (num_seqs, vocab_size):
        per-token occurrence counts, and a boolean mask of which tokens
        appear at least once.
    """
    # One extra column absorbs the `vocab_size` padding id.
    histogram = torch.zeros(
        (num_seqs, vocab_size + 1), dtype=torch.long, device=tokens.device
    )
    histogram.scatter_add_(1, tokens, torch.ones_like(tokens))
    counts = histogram[:, :vocab_size]
    return counts, counts > 0
def apply_penalties(logits: torch.Tensor, prompt_tokens_tensor: torch.Tensor,
                    output_tokens_tensor: torch.Tensor,
                    presence_penalties: torch.Tensor,
                    frequency_penalties: torch.Tensor,
                    repetition_penalties: torch.Tensor) -> torch.Tensor:
    """Apply repetition, frequency and presence penalties to ``logits``.

    The tensor is modified in place and also returned.

    Args:
        logits: Input logits of shape (num_seqs, vocab_size).
        prompt_tokens_tensor: Prompt tokens, padded to the batch's max
            prompt length with ``vocab_size`` (an id outside the
            vocabulary, so padding never matches a real token).
        output_tokens_tensor: Generated tokens, padded the same way.
        presence_penalties: Shape (num_seqs,).
        frequency_penalties: Shape (num_seqs,).
        repetition_penalties: Shape (num_seqs,).

    Returns:
        The (in-place penalized) ``logits`` tensor.
    """
    num_seqs, vocab_size = logits.shape
    _, seen_in_prompt = get_token_bin_counts_and_mask(prompt_tokens_tensor,
                                                      vocab_size, num_seqs)
    out_counts, seen_in_output = get_token_bin_counts_and_mask(
        output_tokens_tensor, vocab_size, num_seqs)

    # Repetition penalties are applied by a custom op.
    from vllm._custom_ops import apply_repetition_penalties_torch
    apply_repetition_penalties_torch(logits, seen_in_prompt, seen_in_output,
                                     repetition_penalties)

    # Frequency/presence penalties follow the OpenAI API definition:
    # https://platform.openai.com/docs/api-reference/parameter-details
    logits -= frequency_penalties.unsqueeze(dim=1) * out_counts
    logits -= presence_penalties.unsqueeze(dim=1) * seen_in_output
    return logits
def apply_all_penalties(
    logits: torch.Tensor,
    prompt_token_ids: torch.Tensor,
    presence_penalties: torch.Tensor,
    frequency_penalties: torch.Tensor,
    repetition_penalties: torch.Tensor,
    output_token_ids: list[list[int]],
) -> torch.Tensor:
    """Apply presence, frequency and repetition penalties to ``logits``.

    Converts the ragged ``output_token_ids`` lists into a padded tensor on
    the logits device, then delegates to :func:`apply_penalties` (which
    mutates ``logits`` in place and returns it).
    """
    _, vocab_size = logits.shape
    padded_output = _convert_to_tensors(output_token_ids, vocab_size,
                                        logits.device)
    return apply_penalties(logits, prompt_token_ids, padded_output,
                           presence_penalties, frequency_penalties,
                           repetition_penalties)
def _convert_to_tensors(output_token_ids: list[list[int]], vocab_size: int,
                        device: torch.device) -> torch.Tensor:
    """Pad the ragged token-id lists into a single int64 tensor on ``device``.

    ``vocab_size`` is used as the padding value because it can never be a
    valid token id.
    """
    padded = make_tensor_with_pad(
        output_token_ids,
        pad=vocab_size,
        device="cpu",
        dtype=torch.int64,
        # Pinned host memory lets the H2D copy below be truly non-blocking.
        pin_memory=is_pin_memory_available(),
    )
    return padded.to(device, non_blocking=True)

View File

@@ -22,7 +22,7 @@ class TopKTopPSampler(nn.Module):
Implementations may update the logits tensor in-place.
"""
def __init__(self):
def __init__(self, logprobs_mode):
super().__init__()
logger.info_once(
"Using FlashInfer for top-p & top-k sampling.")
@@ -57,7 +57,7 @@ class TopKTopPSampler(nn.Module):
# not needed. This is because `random_sample` does not require
# CPU-GPU synchronization while `flashinfer_sample` does.
probs = logits.softmax(dim=-1, dtype=torch.float32)
return random_sample(probs, generators)
return random_sample(probs, generators), None
if generators:
logger.warning_once("FlashInfer 0.2.3+ does not support "
"per-request generators. Falling back to "
@@ -66,8 +66,7 @@ class TopKTopPSampler(nn.Module):
# flashinfer sampling functions expect contiguous logits.
# In flex_attn/triton_attn fp32 inference, logits can be non-contiguous
# because of slicing operation in logits_processor.
return flashinfer_sample(logits.contiguous(), k, p, generators)
return flashinfer_sample(logits.contiguous(), k, p, generators), None
def apply_top_k_top_p(
@@ -195,4 +194,4 @@ def flashinfer_sample(
next_token_ids = xtorch_ops.top_k_top_p_sampling_from_probs(
probs, top_k=k, top_p=p, deterministic=True)
return next_token_ids.view(-1)
return next_token_ids.view(-1)

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,50 @@
"""worker"""
from typing import Dict, List, Optional, Set, Tuple, Type, Union
from vllm.v1.worker.gpu_worker import Worker, _check_if_gpu_supports_dtype, init_worker_distributed_environment
from vllm.model_executor import set_random_seed
from .model_runner import KunlunModelRunner
from vllm.utils import MemorySnapshot
import torch
import os
import gc
class KunlunWorker(Worker):
"""Worker"""
def init_device(self):
if self.device_config.device.type == "cuda":
# torch.distributed.all_reduce does not free the input tensor until
# the synchronization point. This causes the memory usage to grow
# as the number of all_reduce calls increases. This env var disables
# this behavior.
# Related issue:
# https://discuss.pytorch.org/t/cuda-allocation-lifetime-for-inputs-to-distributed-all-reduce/191573
os.environ["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1"
# This env var set by Ray causes exceptions with graph building.
os.environ.pop("NCCL_ASYNC_ERROR_HANDLING", None)
self.device = torch.device(f"cuda:{self.local_rank}")
torch.cuda.set_device(self.device)
_check_if_gpu_supports_dtype(self.model_config.dtype)
gc.collect()
torch.cuda.empty_cache()
torch.cuda.reset_peak_memory_stats()
self.init_snapshot = MemorySnapshot()
free_memory, total = torch.cuda.mem_get_info()
self.init_gpu_memory = free_memory
# 设置一个合理的初始值,比如总内存的 80%
self.requested_memory = int(total * 0.2) # 留出 20% 的余量
else:
raise RuntimeError(
f"Not support device type: {self.device_config.device}")
# Initialize the distributed environment.
init_worker_distributed_environment(self.vllm_config,
self.rank,
self.distributed_init_method,
self.local_rank)
# Set random seed.
set_random_seed(self.model_config.seed)
# Construct the model runner
self.model_runner: KunlunModelRunner = KunlunModelRunner(
self.vllm_config, self.device)