提交vllm0.11.0开发分支
This commit is contained in:
@@ -1,21 +1,3 @@
|
||||
#
|
||||
# Copyright (c) 2025 Baidu, Inc. All Rights Reserved.
|
||||
# Author: Xinyu Dong
|
||||
# Email: dongxinyu03@baidu.com
|
||||
# This file is a part of the vllm-kunlun project.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""vllm kunlun init"""
|
||||
from .platforms import current_platform
|
||||
import sys
|
||||
@@ -25,26 +7,24 @@ import builtins
|
||||
import os
|
||||
import time
|
||||
import vllm.envs as envs
|
||||
|
||||
OLD_IMPORT_HOOK = builtins.__import__
|
||||
|
||||
|
||||
def _custom_import(module_name, globals=None, locals=None, fromlist=(), level=0):
|
||||
try:
|
||||
start_time = time.time()
|
||||
|
||||
# Module mapping table
|
||||
# 模块映射表
|
||||
module_mappings = {
|
||||
"vllm.model_executor.layers.fused_moe.layer": "vllm_kunlun.ops.fused_moe.layer",
|
||||
"vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors_moe": "vllm_kunlun.ops.quantization.compressed_tensors_moe",
|
||||
"vllm.compilation.wrapper": "vllm_kunlun.compilation.wrapper",
|
||||
"vllm.v1.worker.gpu_model_runner": "vllm_kunlun.v1.worker.gpu_model_runner"
|
||||
}
|
||||
|
||||
# Keep the original imported modules
|
||||
# 需要保持原始导入的模块
|
||||
original_imports = [
|
||||
"vllm.model_executor.layers.fused_moe.base",
|
||||
"vllm.model_executor.layers.fused_moe.config",
|
||||
"vllm.model_executor.layers.fused_moe.layer",
|
||||
"vllm.model_executor.layers.fused_moe.layer"
|
||||
]
|
||||
|
||||
if module_name in original_imports:
|
||||
@@ -55,7 +35,7 @@ def _custom_import(module_name, globals=None, locals=None, fromlist=(), level=0)
|
||||
globals=globals,
|
||||
locals=locals,
|
||||
fromlist=fromlist,
|
||||
level=level,
|
||||
level=level
|
||||
)
|
||||
|
||||
if module_name in module_mappings:
|
||||
@@ -68,15 +48,12 @@ def _custom_import(module_name, globals=None, locals=None, fromlist=(), level=0)
|
||||
return module
|
||||
|
||||
relative_mappings = {
|
||||
(
|
||||
"compressed_tensors_moe",
|
||||
"compressed_tensors",
|
||||
): "vllm_kunlun.ops.quantization.compressed_tensors_moe",
|
||||
("compressed_tensors_moe", "compressed_tensors"): "vllm_kunlun.ops.quantization.compressed_tensors_moe",
|
||||
("layer", "fused_moe"): "vllm_kunlun.ops.fused_moe.layer",
|
||||
}
|
||||
|
||||
if level == 1:
|
||||
parent = globals.get("__package__", "").split(".")[-1] if globals else ""
|
||||
parent = globals.get('__package__', '').split('.')[-1] if globals else ''
|
||||
key = (module_name, parent)
|
||||
if key in relative_mappings:
|
||||
if module_name in sys.modules:
|
||||
@@ -91,15 +68,18 @@ def _custom_import(module_name, globals=None, locals=None, fromlist=(), level=0)
|
||||
pass
|
||||
|
||||
return OLD_IMPORT_HOOK(
|
||||
module_name, globals=globals, locals=locals, fromlist=fromlist, level=level
|
||||
module_name,
|
||||
globals=globals,
|
||||
locals=locals,
|
||||
fromlist=fromlist,
|
||||
level=level
|
||||
)
|
||||
|
||||
|
||||
def import_hook():
|
||||
"""Apply import hook for VLLM Kunlun"""
|
||||
if not int(os.environ.get("DISABLE_KUNLUN_HOOK", "0")):
|
||||
builtins.__import__ = _custom_import
|
||||
|
||||
|
||||
try:
|
||||
modules_to_preload = [
|
||||
"vllm_kunlun.ops.quantization.compressed_tensors_moe",
|
||||
@@ -112,31 +92,39 @@ def import_hook():
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
def register():
|
||||
"""Register the Kunlun platform"""
|
||||
from .utils import redirect_output
|
||||
from .vllm_utils_wrapper import (
|
||||
direct_register_custom_op,
|
||||
patch_annotations_for_schema,
|
||||
)
|
||||
|
||||
from .vllm_utils_wrapper import direct_register_custom_op, patch_annotations_for_schema
|
||||
patch_bitsandbytes_loader()
|
||||
import_hook()
|
||||
if envs.VLLM_USE_V1:
|
||||
patch_V1blockTable()
|
||||
# patch_V1blockTable()
|
||||
patch_V1top_p_K()
|
||||
patch_V1penalties()
|
||||
# TODO fixed fast top & k for vLLM 0.10.2,
|
||||
pass
|
||||
else:
|
||||
patch_sampler()
|
||||
return "vllm_kunlun.platforms.kunlun.KunlunPlatform"
|
||||
|
||||
|
||||
def register_model():
|
||||
"""Register models for training and inference"""
|
||||
from .models import register_model as _reg
|
||||
|
||||
_reg()
|
||||
|
||||
# [monkey patach sampler]
|
||||
import sys
|
||||
import sys, importlib, warnings
|
||||
|
||||
def patch_bitsandbytes_loader():
|
||||
try:
|
||||
# 载入你插件里自定义的 direct_register_custom_op 实现
|
||||
custom_utils = importlib.import_module("vllm_kunlun.models.model_loader.bitsandbytes_loader")
|
||||
# 覆盖 vllm.utils
|
||||
sys.modules["vllm.model_executor.model_loader.bitsandbytes_loader"] = custom_utils
|
||||
print("[vllm_kunlun] bitsandbytes_loader patched ->", custom_utils.__file__)
|
||||
except Exception as e:
|
||||
warnings.warn(f"[vllm_kunlun] bitsandbytes_loader patch failed: {e!r}")
|
||||
|
||||
def patch_sampler():
|
||||
try:
|
||||
@@ -149,24 +137,12 @@ def patch_sampler():
|
||||
|
||||
def patch_V1top_p_K():
|
||||
try:
|
||||
custom_sampler = importlib.import_module(
|
||||
"vllm_kunlun.v1.sample.ops.topk_topp_sampler"
|
||||
)
|
||||
custom_sampler = importlib.import_module("vllm_kunlun.v1.sample.ops.topk_topp_sampler")
|
||||
sys.modules["vllm.v1.sample.ops.topk_topp_sampler"] = custom_sampler
|
||||
print("[vllm_kunlun] V1sampler top p & k patched ->", custom_sampler.__file__)
|
||||
except Exception as e:
|
||||
warnings.warn(f"[vllm_kunlun] V1 sampler top p & k patch failed: {e!r}")
|
||||
|
||||
|
||||
def patch_V1penalties():
|
||||
try:
|
||||
custom_sampler = importlib.import_module("vllm_kunlun.v1.sample.ops.penalties")
|
||||
sys.modules["vllm.v1.sample.ops.penalties"] = custom_sampler
|
||||
print("[vllm_kunlun] V1sampler penalties patched ->", custom_sampler.__file__)
|
||||
except Exception as e:
|
||||
warnings.warn(f"[vllm_kunlun] V1 sampler penalties patch failed: {e!r}")
|
||||
|
||||
|
||||
def patch_V1blockTable():
|
||||
try:
|
||||
custom_sampler = importlib.import_module("vllm_kunlun.v1.worker.block_table")
|
||||
@@ -175,6 +151,5 @@ def patch_V1blockTable():
|
||||
except Exception as e:
|
||||
warnings.warn(f"[vllm_kunlun] V1 block table patch failed: {e!r}")
|
||||
|
||||
|
||||
# Automatically apply patches when modules are imported
|
||||
# 在模块导入时自动应用补丁
|
||||
import_hook()
|
||||
|
||||
@@ -1,20 +1,6 @@
|
||||
#
|
||||
# Copyright (c) 2025 Baidu, Inc. All Rights Reserved.
|
||||
# Author: Bao Qian
|
||||
# Email: baoqian@baidu.com
|
||||
# This file is a part of the vllm-kunlun project.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import os
|
||||
import sys
|
||||
from abc import abstractmethod
|
||||
@@ -46,7 +32,7 @@ class TorchCompileWrapperWithCustomDispatcher:
|
||||
def __init__(self,
|
||||
compiled_callable: Optional[Callable] = None,
|
||||
compilation_level: int = 0):
|
||||
from vllm.config import get_current_vllm_config
|
||||
from vllm.config import get_current_vllm_config, CUDAGraphMode
|
||||
vllm_config = get_current_vllm_config()
|
||||
self.vllm_config = vllm_config
|
||||
if compiled_callable is None:
|
||||
@@ -61,9 +47,13 @@ class TorchCompileWrapperWithCustomDispatcher:
|
||||
|
||||
compiled_callable = torch.compile(
|
||||
self.forward,
|
||||
fullgraph=envs.VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE,
|
||||
fullgraph=True, #envs.VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE,
|
||||
backend=backend,
|
||||
options=options)
|
||||
|
||||
# print(vllm_config.compilation_config)
|
||||
# vllm_config.compilation_config.cudagraph_mode = CUDAGraphMode.PIECEWISE
|
||||
# vllm_config.compilation_config.cudagraph_capture_sizes = [32768]
|
||||
|
||||
self.compiled_callable = compiled_callable
|
||||
self.original_code_object = self.__class__.forward.__code__
|
||||
@@ -126,7 +116,12 @@ class TorchCompileWrapperWithCustomDispatcher:
|
||||
decompiled_file)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# if self.vllm_config.compilation_config.use_cudagraph and \
|
||||
# "update" in new_code.co_names:
|
||||
# import depyf
|
||||
# src = depyf.decompile(new_code)
|
||||
# msg = "Assigning / modifying buffers of nn.Module during forward pass is not allowed when using cudagraph inside the compiler because it will cause silent errors. Please use eager mode or fix the code. The following code contains clues about which buffer is being modified (please search for the usage of the function `update`):\n" + src # noqa
|
||||
# raise RuntimeError(msg)
|
||||
|
||||
@contextmanager
|
||||
def dispatch_to_code(self, index: int):
|
||||
|
||||
@@ -1,20 +1,3 @@
|
||||
#
|
||||
# Copyright (c) 2025 Baidu, Inc. All Rights Reserved.
|
||||
# Author: Bao Qian, Dong Xinyu
|
||||
# Email: baoqian@baidu.com, dongxinyu03@baidu.com
|
||||
# This file is a part of the vllm-kunlun project.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""kunlun_communicator"""
|
||||
from contextlib import contextmanager
|
||||
from typing import Optional
|
||||
|
||||
@@ -1,16 +0,0 @@
|
||||
"""# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project"""
|
||||
|
||||
from vllm_kunlun.lora.ops.kunlun_ops.lora_ops import (bgmv_expand,bgmv_expand_slice, bgmv_shrink,
|
||||
sgmv_expand, sgmv_expand_slice,
|
||||
sgmv_shrink)
|
||||
|
||||
|
||||
__all__ = [
|
||||
"bgmv_expand",
|
||||
"bgmv_expand_slice",
|
||||
"bgmv_shrink",
|
||||
"sgmv_expand",
|
||||
"sgmv_expand_slice",
|
||||
"sgmv_shrink"
|
||||
]
|
||||
@@ -1,443 +0,0 @@
|
||||
#
|
||||
# Copyright (c) 2025 Baidu, Inc. All Rights Reserved.
|
||||
#
|
||||
# Author: Wang Hao
|
||||
# Email: wanghao129@baidu.com
|
||||
#
|
||||
# This file is a part of the vllm-kunlun project.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""kunlun_ops for lora"""
|
||||
|
||||
import torch
|
||||
from torch._C import dtype
|
||||
|
||||
|
||||
def sgmv_shrink(
|
||||
inputs: torch.Tensor,
|
||||
lora_a_weights: torch.Tensor,
|
||||
output_tensor: torch.Tensor,
|
||||
block_statistic: torch.Tensor,
|
||||
sorted_tokens_num_lod: torch.Tensor,
|
||||
moe_index: torch.Tensor,
|
||||
expert_m: torch.Tensor,
|
||||
b_seq_start_loc: torch.Tensor,
|
||||
seq_len_tensor: torch.Tensor,
|
||||
lora_indices_tensor: torch.Tensor,
|
||||
batches: int,
|
||||
max_seq_length: int,
|
||||
token_nums: int,
|
||||
scaling: float,
|
||||
):
|
||||
"""
|
||||
sgmv_shrink
|
||||
"""
|
||||
|
||||
expert_num = 9
|
||||
device = inputs.device
|
||||
|
||||
lora_ids = lora_indices_tensor.repeat_interleave(seq_len_tensor, dim=0).to(
|
||||
device=device, dtype=torch.int32
|
||||
)
|
||||
|
||||
lora_ids.masked_fill_(lora_ids < 0, expert_num - 1).unsqueeze_(1)
|
||||
|
||||
|
||||
|
||||
torch.ops._C.gen_block_statistic(lora_ids, block_statistic)
|
||||
|
||||
|
||||
inputs_sorted = torch.zeros_like(inputs, dtype=inputs.dtype, device=device)
|
||||
torch.ops._C.moe_pre_sorted(
|
||||
inputs,
|
||||
lora_ids,
|
||||
block_statistic,
|
||||
inputs_sorted,
|
||||
moe_index,
|
||||
expert_m,
|
||||
sorted_tokens_num_lod
|
||||
)
|
||||
|
||||
|
||||
output_tensor.unsqueeze_(1)
|
||||
|
||||
torch.ops._C.moe_fc(
|
||||
x=inputs_sorted,
|
||||
weight=lora_a_weights,
|
||||
sorted_tokens_num_lod=sorted_tokens_num_lod,
|
||||
sorted_tokens_idx=moe_index,
|
||||
moe_topk=1,
|
||||
y=output_tensor,
|
||||
act=None,
|
||||
x_perchannel_max=None,
|
||||
w_perchannel_max=None,
|
||||
topk_ids=None,
|
||||
topk_w=None,
|
||||
bias=None,
|
||||
tgemm_type=None,
|
||||
tweight_type=None,
|
||||
scale_n=0,
|
||||
scale_k=0,
|
||||
use_pack_int4=False
|
||||
)
|
||||
|
||||
output_tensor.squeeze_(1).mul_(scaling)
|
||||
|
||||
return output_tensor
|
||||
|
||||
|
||||
def sgmv_expand(inputs: torch.Tensor,
|
||||
lora_b_weights: torch.Tensor,
|
||||
output_tensor: torch.Tensor,
|
||||
block_statistic: torch.Tensor,
|
||||
sorted_tokens_num_lod: torch.Tensor,
|
||||
moe_index: torch.Tensor,
|
||||
b_seq_start_loc: torch.Tensor,
|
||||
seq_len_tensor: torch.Tensor,
|
||||
lora_indices_tensor: torch.Tensor,
|
||||
batches: int,
|
||||
max_seq_length: int,
|
||||
token_nums: int,
|
||||
add_inputs: bool = False):
|
||||
"""
|
||||
sgmv_expand
|
||||
"""
|
||||
|
||||
|
||||
expert_num = 9
|
||||
device = inputs.device
|
||||
|
||||
|
||||
lora_ids = lora_indices_tensor.repeat_interleave(seq_len_tensor, dim=0).to(
|
||||
device=device, dtype=torch.int32
|
||||
)
|
||||
|
||||
lora_ids.masked_fill_(lora_ids < 0, expert_num - 1).unsqueeze_(1)
|
||||
|
||||
out = torch.zeros((token_nums, 1, slice_size), dtype=inputs.dtype, device=device)
|
||||
|
||||
|
||||
torch.ops._C.moe_fc(
|
||||
x=inputs,
|
||||
weight=lora_b_weights,
|
||||
sorted_tokens_num_lod=sorted_tokens_num_lod,
|
||||
sorted_tokens_idx=moe_index,
|
||||
moe_topk=1,
|
||||
y=out,
|
||||
act=None,
|
||||
x_perchannel_max=None,
|
||||
w_perchannel_max=None,
|
||||
topk_ids=None,
|
||||
topk_w=None,
|
||||
bias=None,
|
||||
tgemm_type=None,
|
||||
tweight_type=None,
|
||||
scale_n=0,
|
||||
scale_k=0,
|
||||
use_pack_int4=False
|
||||
)
|
||||
|
||||
output_post = out.squeeze(1)
|
||||
torch.ops._C.moe_post(
|
||||
output_post,
|
||||
moe_index.unsqueeze(1),
|
||||
normed_scale,
|
||||
normed_scale,
|
||||
output_post
|
||||
)
|
||||
|
||||
|
||||
common_len = min(output_post.shape[1], output_tensor.shape[1])
|
||||
|
||||
limit = min(output_post.shape[0], output_tensor.shape[0])
|
||||
|
||||
|
||||
if add_inputs:
|
||||
output_tensor[:limit, :common_len] += output_post[:limit, :common_len]
|
||||
else:
|
||||
output_tensor[:limit, :common_len] = output_post[:limit, :common_len]
|
||||
|
||||
return output_tensor
|
||||
|
||||
|
||||
def sgmv_expand_slice(inputs: torch.Tensor,
|
||||
lora_b_weights: torch.Tensor,
|
||||
output_tensor: torch.Tensor,
|
||||
block_statistic: torch.Tensor,
|
||||
sorted_tokens_num_lod: torch.Tensor,
|
||||
moe_index: torch.Tensor,
|
||||
normed_scale: torch.Tensor,
|
||||
b_seq_start_loc: torch.Tensor,
|
||||
seq_len_tensor: torch.Tensor,
|
||||
lora_indices_tensor: torch.Tensor,
|
||||
batches: int,
|
||||
max_seq_length: int,
|
||||
token_nums: int,
|
||||
slice_offset: int,
|
||||
slice_size: int,
|
||||
add_inputs: bool = False):
|
||||
|
||||
"""
|
||||
sgmv_expand_slice
|
||||
"""
|
||||
|
||||
expert_num = 9
|
||||
device = inputs.device
|
||||
|
||||
lora_ids = lora_indices_tensor.repeat_interleave(seq_len_tensor, dim=0).to(
|
||||
device=device, dtype=torch.int32
|
||||
)
|
||||
|
||||
lora_ids.masked_fill_(lora_ids < 0, expert_num - 1).unsqueeze_(1)
|
||||
|
||||
|
||||
out = torch.zeros((token_nums, 1, slice_size), dtype=inputs.dtype, device=device)
|
||||
|
||||
|
||||
torch.ops._C.moe_fc(
|
||||
x=inputs,
|
||||
weight=lora_b_weights,
|
||||
sorted_tokens_num_lod=sorted_tokens_num_lod,
|
||||
sorted_tokens_idx=moe_index,
|
||||
moe_topk=1,
|
||||
y=out,
|
||||
act=None,
|
||||
x_perchannel_max=None,
|
||||
w_perchannel_max=None,
|
||||
topk_ids=None,
|
||||
topk_w=None,
|
||||
bias=None,
|
||||
tgemm_type=None,
|
||||
tweight_type=None,
|
||||
scale_n=0,
|
||||
scale_k=0,
|
||||
use_pack_int4=False
|
||||
)
|
||||
|
||||
output_post = out.squeeze(1)
|
||||
torch.ops._C.moe_post(
|
||||
output_post,
|
||||
moe_index.unsqueeze(1),
|
||||
normed_scale,
|
||||
normed_scale,
|
||||
output_post
|
||||
)
|
||||
|
||||
|
||||
slice_end = slice_offset + slice_size
|
||||
actual_slice_size = min(slice_size, output_tensor.shape[1] - slice_offset)
|
||||
|
||||
limit = min(output_post.shape[0], output_tensor.shape[0])
|
||||
|
||||
|
||||
if add_inputs:
|
||||
output_tensor[:limit, slice_offset:slice_end] += output_post[:limit, :actual_slice_size]
|
||||
else:
|
||||
output_tensor[:limit, slice_offset:slice_end] = output_post[:limit, :actual_slice_size]
|
||||
|
||||
return output_tensor
|
||||
|
||||
|
||||
def bgmv_shrink(
|
||||
inputs: torch.Tensor, # [m, hidden_dim]
|
||||
lora_a_weights: torch.Tensor, # [n, 1, r, hidden_dim]
|
||||
output_tensor: torch.Tensor, # [m, r]
|
||||
block_statistic: torch.Tensor,
|
||||
sorted_tokens_num_lod: torch.Tensor,
|
||||
moe_index: torch.Tensor,
|
||||
expert_m: torch.Tensor,
|
||||
lora_indices_tensor: torch.Tensor, # [m]
|
||||
scaling: float = 1.0
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
bgmv_shrink
|
||||
"""
|
||||
|
||||
expert_num = 9
|
||||
|
||||
lora_ids = lora_indices_tensor.to(dtype=torch.int32, device=inputs.device)
|
||||
lora_ids.masked_fill_(lora_ids < 0, expert_num - 1)
|
||||
|
||||
torch.ops._C.gen_block_statistic(lora_ids.unsqueeze(1), block_statistic)
|
||||
|
||||
inputs_sorted = torch.empty_like(inputs, dtype=inputs.dtype, device=inputs.device)
|
||||
|
||||
torch.ops._C.moe_pre_sorted(
|
||||
inputs,
|
||||
lora_ids.unsqueeze(1),
|
||||
block_statistic,
|
||||
inputs_sorted,
|
||||
moe_index,
|
||||
expert_m,
|
||||
sorted_tokens_num_lod
|
||||
)
|
||||
|
||||
output_tensor.unsqueeze_(1) # Change to [m, 1, r]
|
||||
torch.ops._C.moe_fc(
|
||||
x=inputs_sorted,
|
||||
weight=lora_a_weights,
|
||||
sorted_tokens_num_lod=sorted_tokens_num_lod,
|
||||
sorted_tokens_idx=moe_index,
|
||||
moe_topk=1,
|
||||
y=output_tensor,
|
||||
act=None,
|
||||
x_perchannel_max=None,
|
||||
w_perchannel_max=None,
|
||||
topk_ids=None,
|
||||
topk_w=None,
|
||||
bias=None,
|
||||
tgemm_type=None,
|
||||
tweight_type=None,
|
||||
scale_n=0,
|
||||
scale_k=0,
|
||||
use_pack_int4=False
|
||||
)
|
||||
|
||||
output_tensor.squeeze_(1).mul_(scaling)
|
||||
|
||||
return output_tensor
|
||||
|
||||
|
||||
def bgmv_expand(inputs: torch.Tensor,
|
||||
lora_b_weights: torch.Tensor,
|
||||
output_tensor: torch.Tensor,
|
||||
block_statistic: torch.Tensor,
|
||||
sorted_tokens_num_lod: torch.Tensor,
|
||||
moe_index: torch.Tensor,
|
||||
lora_indices_tensor: torch.Tensor,
|
||||
add_inputs: bool = True):
|
||||
""""
|
||||
bgmv_expand
|
||||
"""
|
||||
|
||||
|
||||
expert_num = 9
|
||||
device = inputs.device
|
||||
|
||||
|
||||
|
||||
|
||||
lora_ids = lora_indices_tensor.to(dtype=torch.int32, device=inputs.device)
|
||||
lora_ids.masked_fill_(lora_ids < 0, expert_num - 1)
|
||||
|
||||
out = torch.zeros((inputs.shape[0], 1, slice_size), dtype=inputs.dtype, device=device)
|
||||
|
||||
|
||||
torch.ops._C.moe_fc(
|
||||
x=inputs,
|
||||
weight=lora_b_weights,
|
||||
sorted_tokens_num_lod=sorted_tokens_num_lod,
|
||||
sorted_tokens_idx=moe_index,
|
||||
moe_topk=1,
|
||||
y=out,
|
||||
act=None,
|
||||
x_perchannel_max=None,
|
||||
w_perchannel_max=None,
|
||||
topk_ids=None,
|
||||
topk_w=None,
|
||||
bias=None,
|
||||
tgemm_type=None,
|
||||
tweight_type=None,
|
||||
scale_n=0,
|
||||
scale_k=0,
|
||||
use_pack_int4=False
|
||||
)
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
output_post = out.squeeze(1)
|
||||
torch.ops._C.moe_post(output_post, moe_index.unsqueeze(1), normed_scale, normed_scale, output_post)
|
||||
|
||||
|
||||
limit = output_tensor.shape[0]
|
||||
if output_post.shape[0] == 1 and output_tensor.shape[0] != 1:
|
||||
limit = 1
|
||||
|
||||
# LoRA adapter and model may add different amounts of padding to output
|
||||
common_len = min(output_post.shape[1], output_tensor.shape[1])
|
||||
|
||||
if add_inputs:
|
||||
output_tensor[:, :common_len] += output_post[:limit, :common_len]
|
||||
else:
|
||||
output_tensor[:, :common_len] = output_post[:limit, :common_len]
|
||||
|
||||
return output_tensor
|
||||
|
||||
|
||||
def bgmv_expand_slice(
|
||||
inputs: torch.Tensor,
|
||||
lora_b_weights: torch.Tensor,
|
||||
output_tensor: torch.Tensor,
|
||||
block_statistic: torch.Tensor,
|
||||
sorted_tokens_num_lod: torch.Tensor,
|
||||
moe_index: torch.Tensor,
|
||||
normed_scale: torch.Tensor,
|
||||
lora_indices_tensor: torch.Tensor,
|
||||
slice_offset: int,
|
||||
slice_size: int,
|
||||
add_inputs: bool = True
|
||||
):
|
||||
"""
|
||||
bgmv_expand_slice
|
||||
"""
|
||||
|
||||
expert_num = 9
|
||||
device = inputs.device
|
||||
|
||||
|
||||
|
||||
|
||||
lora_ids = lora_indices_tensor.to(dtype=torch.int32, device=inputs.device)
|
||||
lora_ids.masked_fill_(lora_ids < 0, expert_num - 1)
|
||||
|
||||
out = torch.zeros((inputs.shape[0], 1, slice_size), dtype=inputs.dtype, device=device)
|
||||
|
||||
torch.ops._C.moe_fc(
|
||||
x=inputs,
|
||||
weight=lora_b_weights,
|
||||
sorted_tokens_num_lod=sorted_tokens_num_lod,
|
||||
sorted_tokens_idx=moe_index,
|
||||
moe_topk=1,
|
||||
y=out,
|
||||
act=None,
|
||||
x_perchannel_max=None,
|
||||
w_perchannel_max=None,
|
||||
topk_ids=None,
|
||||
topk_w=None,
|
||||
bias=None,
|
||||
tgemm_type=None,
|
||||
tweight_type=None,
|
||||
scale_n=0,
|
||||
scale_k=0,
|
||||
use_pack_int4=False
|
||||
)
|
||||
|
||||
|
||||
output_post = out.squeeze(1)
|
||||
torch.ops._C.moe_post(output_post, moe_index.unsqueeze(1), normed_scale, normed_scale, output_post)
|
||||
|
||||
|
||||
slice_end = slice_offset + slice_size
|
||||
actual_slice_size = min(slice_size, output_tensor.shape[1] - slice_offset)
|
||||
limit = min(output_post.shape[0], output_tensor.shape[0])
|
||||
|
||||
|
||||
if add_inputs:
|
||||
output_tensor[:limit, slice_offset:slice_end] += output_post[:limit, :actual_slice_size]
|
||||
else:
|
||||
output_tensor[:limit, slice_offset:slice_end] = output_post[:limit, :actual_slice_size]
|
||||
|
||||
return output_tensor
|
||||
@@ -1,547 +0,0 @@
|
||||
#
|
||||
# Copyright (c) 2025 Baidu, Inc. All Rights Reserved.
|
||||
# Author: Wang Hao
|
||||
# Email: wanghao129@baidu.com
|
||||
# This file is a part of the vllm-kunlun project.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
Based on:
|
||||
Chen, L., Ye, Z., Wu, Y., Zhuo, D., Ceze, L., & Krishnamurthy, A. (2023).
|
||||
Punica: Multi-Tenant LoRA Serving.
|
||||
https://arxiv.org/abs/2310.18547
|
||||
"""
|
||||
|
||||
from typing import TYPE_CHECKING, Optional, Union, final
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
from typing import Callable, Optional, Tuple, Union
|
||||
|
||||
|
||||
from vllm_kunlun.lora.ops.kunlun_ops import (
|
||||
bgmv_expand,
|
||||
bgmv_expand_slice,
|
||||
bgmv_shrink,
|
||||
sgmv_expand,
|
||||
sgmv_expand_slice,
|
||||
sgmv_shrink,
|
||||
)
|
||||
|
||||
from vllm.lora.punica_wrapper.punica_base import PunicaWrapperBase
|
||||
import time
|
||||
|
||||
|
||||
# The platforms that are compatible with the PyTorch-native implementation can
|
||||
# inherit this class
|
||||
class PunicaWrapperKunlun(PunicaWrapperBase):
|
||||
"""
|
||||
PunicaWrapperKunlun with moe_fc
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
max_num_batched_tokens: int,
|
||||
max_batches: int,
|
||||
device: Union[torch.device, str],
|
||||
**kwargs,
|
||||
):
|
||||
PunicaWrapperBase.__init__(self, max_num_batched_tokens, max_batches, device)
|
||||
|
||||
def _shrink_prefill(
|
||||
self,
|
||||
y: torch.Tensor,
|
||||
x: torch.Tensor,
|
||||
w_t_all: torch.Tensor,
|
||||
block_statistic: torch.Tensor,
|
||||
sorted_tokens_num_lod: torch.Tensor,
|
||||
moe_index: torch.Tensor,
|
||||
scale: float,
|
||||
):
|
||||
|
||||
expert_m = torch.zeros(9, dtype=torch.int32, device=x.device)
|
||||
|
||||
sgmv_shrink(
|
||||
x,
|
||||
w_t_all,
|
||||
y,
|
||||
block_statistic,
|
||||
sorted_tokens_num_lod,
|
||||
moe_index,
|
||||
expert_m,
|
||||
*self.prefill_metadata,
|
||||
scale,
|
||||
)
|
||||
|
||||
def _shrink_decode(
|
||||
self,
|
||||
y: torch.Tensor,
|
||||
x: torch.Tensor,
|
||||
w_t_all: torch.Tensor,
|
||||
block_statistic: torch.Tensor,
|
||||
sorted_tokens_num_lod: torch.Tensor,
|
||||
moe_index: torch.Tensor,
|
||||
scale: float,
|
||||
):
|
||||
|
||||
expert_m = torch.zeros(9, dtype=torch.int32, device=x.device)
|
||||
bgmv_shrink(
|
||||
x,
|
||||
w_t_all,
|
||||
y,
|
||||
block_statistic,
|
||||
sorted_tokens_num_lod,
|
||||
moe_index,
|
||||
expert_m,
|
||||
self.token_lora_indices,
|
||||
scale,
|
||||
)
|
||||
|
||||
def _expand_prefill(
|
||||
self,
|
||||
y: torch.Tensor,
|
||||
x: torch.Tensor,
|
||||
w_t_all: torch.Tensor,
|
||||
block_statistic: torch.Tensor,
|
||||
sorted_tokens_num_lod: torch.Tensor,
|
||||
moe_index: torch.Tensor,
|
||||
add_inputs: bool,
|
||||
):
|
||||
|
||||
sgmv_expand(
|
||||
x,
|
||||
w_t_all,
|
||||
y,
|
||||
block_statistic,
|
||||
sorted_tokens_num_lod,
|
||||
moe_index,
|
||||
*self.prefill_metadata,
|
||||
add_inputs,
|
||||
)
|
||||
|
||||
def _expand_decode(
|
||||
self,
|
||||
y: torch.Tensor,
|
||||
x: torch.Tensor,
|
||||
w_t_all: torch.Tensor,
|
||||
block_statistic: torch.Tensor,
|
||||
sorted_tokens_num_lod: torch.Tensor,
|
||||
moe_index: torch.Tensor,
|
||||
add_inputs: bool,
|
||||
):
|
||||
bgmv_expand(
|
||||
x,
|
||||
w_t_all,
|
||||
y,
|
||||
block_statistic,
|
||||
sorted_tokens_num_lod,
|
||||
moe_index,
|
||||
self.token_lora_indices,
|
||||
add_inputs,
|
||||
)
|
||||
|
||||
def _expand_slice_prefill(
|
||||
self,
|
||||
y: torch.Tensor,
|
||||
x: torch.Tensor,
|
||||
w_t_all: torch.Tensor,
|
||||
block_statistic,
|
||||
sorted_tokens_num_lod: torch.Tensor,
|
||||
moe_index: torch.Tensor,
|
||||
y_offset: int,
|
||||
y_slice_size: int,
|
||||
add_inputs: bool,
|
||||
):
|
||||
|
||||
normed_scale = torch.ones([y.size(0), 1], dtype=torch.float32, device=x.device)
|
||||
|
||||
sgmv_expand_slice(
|
||||
x,
|
||||
w_t_all,
|
||||
y,
|
||||
block_statistic,
|
||||
sorted_tokens_num_lod,
|
||||
moe_index,
|
||||
normed_scale,
|
||||
*self.prefill_metadata,
|
||||
y_offset,
|
||||
y_slice_size,
|
||||
add_inputs,
|
||||
)
|
||||
|
||||
def _expand_slice_decode(
|
||||
self,
|
||||
y: torch.Tensor,
|
||||
x: torch.Tensor,
|
||||
w_t_all: torch.Tensor,
|
||||
block_statistic: torch.Tensor,
|
||||
sorted_tokens_num_lod: torch.Tensor,
|
||||
moe_index: torch.Tensor,
|
||||
y_offset: int,
|
||||
y_slice_size: int,
|
||||
add_inputs: bool,
|
||||
):
|
||||
|
||||
normed_scale = torch.ones([y.size(0), 1], dtype=torch.float32, device=x.device)
|
||||
|
||||
bgmv_expand_slice(
|
||||
x,
|
||||
w_t_all,
|
||||
y,
|
||||
block_statistic,
|
||||
sorted_tokens_num_lod,
|
||||
moe_index,
|
||||
normed_scale,
|
||||
self.token_lora_indices,
|
||||
y_offset,
|
||||
y_slice_size,
|
||||
add_inputs,
|
||||
)
|
||||
|
||||
def _apply_expand(
|
||||
self,
|
||||
y: torch.Tensor,
|
||||
x: torch.Tensor,
|
||||
w_t_all: torch.Tensor,
|
||||
block_statistic,
|
||||
sorted_tokens_num_lod: torch.Tensor,
|
||||
moe_index: torch.Tensor,
|
||||
y_offset: int,
|
||||
y_slice_size: int,
|
||||
add_inputs: bool = True,
|
||||
):
|
||||
"""
|
||||
Perform the ` y[:,y_offset:y_offset+y_slice_size]+=x@w_t_all`
|
||||
computation, which is suitable for the
|
||||
GEMM of lora'b.
|
||||
"""
|
||||
|
||||
expand_slice_fun: Callable = (
|
||||
self._expand_slice_prefill if self.is_prefill else self._expand_slice_decode
|
||||
)
|
||||
expand_slice_fun(
|
||||
y,
|
||||
x,
|
||||
w_t_all,
|
||||
block_statistic,
|
||||
sorted_tokens_num_lod,
|
||||
moe_index,
|
||||
y_offset,
|
||||
y_slice_size,
|
||||
add_inputs,
|
||||
)
|
||||
|
||||
def _apply_shrink(
|
||||
self,
|
||||
y: torch.Tensor,
|
||||
x: torch.Tensor,
|
||||
w_t_all: torch.Tensor,
|
||||
block_statistic: torch.Tensor,
|
||||
sorted_tokens_num_lod: torch.Tensor,
|
||||
moe_index: torch.Tensor,
|
||||
scale: float,
|
||||
):
|
||||
"""
|
||||
Perform the ` y+=x@w_t_all` computation, which is suitable for the
|
||||
GEMM of lora'a.
|
||||
When `is_prefill is` true, it indicates that it is currently the
|
||||
prefill stage, and the `_shrink_prefill` function should be called.
|
||||
Otherwise, it is the decode stage, and the _shrink_decode function
|
||||
should be called.
|
||||
"""
|
||||
y_org = y
|
||||
y = y.view(-1, y.shape[-1])
|
||||
|
||||
shrink_fun: Callable = (
|
||||
self._shrink_prefill if self.is_prefill else self._shrink_decode
|
||||
)
|
||||
|
||||
shrink_fun(
|
||||
y, x, w_t_all, block_statistic, sorted_tokens_num_lod, moe_index, scale
|
||||
)
|
||||
|
||||
y = y.view_as(y_org)
|
||||
|
||||
def add_shrink(
|
||||
self,
|
||||
y: Union[Tuple[torch.Tensor, ...], torch.Tensor],
|
||||
x: torch.Tensor,
|
||||
lora_a_stacked: Tuple[torch.Tensor, ...],
|
||||
block_statistic: torch.Tensor,
|
||||
sorted_tokens_num_lod: torch.Tensor,
|
||||
moe_index: torch.Tensor,
|
||||
scale: float,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Performs GEMM for multiple slices of lora_a.
|
||||
When `is_prefill is` true, it indicates that it is currently the
|
||||
prefill stage, and the `_shrink_prefill` function should be called.
|
||||
Otherwise, it is the decode stage, and the _shrink_decode function
|
||||
should be called.
|
||||
|
||||
Semantics:
|
||||
for i in range(len(lora_a_stacked)):
|
||||
y[i] += (x @ lora_a_stacked[i]) * scale
|
||||
|
||||
Args:
|
||||
y (Union[Tuple[torch.Tensor, ...], torch.Tensor]): Output tensors
|
||||
x (torch.Tensor): Input tensor
|
||||
lora_a_stacked (Tuple[torch.Tensor, ...]): lora_a's weights
|
||||
scale (float): Scaling factor for the operation
|
||||
"""
|
||||
|
||||
x = x.view(-1, x.shape[-1])
|
||||
|
||||
for slice_idx in range(len(lora_a_stacked)): # Each slice represents a layer
|
||||
|
||||
self._apply_shrink(
|
||||
y[slice_idx],
|
||||
x,
|
||||
lora_a_stacked[slice_idx],
|
||||
block_statistic,
|
||||
sorted_tokens_num_lod,
|
||||
moe_index,
|
||||
scale,
|
||||
)
|
||||
|
||||
def add_expand(
|
||||
self,
|
||||
y: torch.Tensor,
|
||||
x: Union[Tuple[torch.Tensor, ...], torch.Tensor],
|
||||
lora_b_stacked: Tuple[torch.Tensor, ...],
|
||||
block_statistic: torch.Tensor,
|
||||
sorted_tokens_num_lod: torch.Tensor,
|
||||
moe_index: torch.Tensor,
|
||||
lora_bias_stacked: Optional[Tuple[torch.Tensor, ...]],
|
||||
output_slices: Tuple[int, ...],
|
||||
offset_start: int = 0,
|
||||
add_inputs=True,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
"""
|
||||
Performs GEMM and bias addition for multiple slices of lora_b.
|
||||
|
||||
Semantics:
|
||||
for i in range(len(lora_b_stacked)):
|
||||
slice = output_slices[i]
|
||||
y[:, offset:offset+slice] += x[i] @ lora_b_stacked[i] +
|
||||
lora_bias_stacked[i]
|
||||
offset += slice
|
||||
|
||||
Args:
|
||||
y (torch.Tensor): Output tensor.
|
||||
x (Union[Tuple[torch.Tensor, ...], torch.Tensor]): Input tensors
|
||||
lora_b_stacked (Tuple[torch.Tensor, ...]): lora_b's weight
|
||||
lora_bias_stacked (Optional[Tuple[torch.Tensor, ...]]):
|
||||
bias's weight
|
||||
output_slices (Tuple[int, ...]): Every slice's size
|
||||
add_inputs (bool): Defaults to True.
|
||||
"""
|
||||
|
||||
y_org = y
|
||||
y = y.view(-1, y.shape[-1])
|
||||
offset_left = offset_start
|
||||
|
||||
if lora_bias_stacked is not None:
|
||||
self._apply_bias(
|
||||
self.token_lora_indices, y, output_slices, lora_bias_stacked
|
||||
)
|
||||
|
||||
for slice_idx in range(len(lora_b_stacked)):
|
||||
self._apply_expand(
|
||||
y,
|
||||
x[slice_idx],
|
||||
lora_b_stacked[slice_idx],
|
||||
block_statistic,
|
||||
sorted_tokens_num_lod,
|
||||
moe_index,
|
||||
offset_left,
|
||||
output_slices[slice_idx],
|
||||
add_inputs=add_inputs,
|
||||
)
|
||||
offset_left += output_slices[slice_idx]
|
||||
|
||||
y = y.view_as(y_org)
|
||||
|
||||
def add_lora_embedding(
|
||||
self,
|
||||
y: torch.Tensor,
|
||||
x: torch.Tensor,
|
||||
lora_b_stacked: torch.Tensor,
|
||||
add_inputs: bool = True,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
"""
|
||||
Applies lora specifically for VocabParallelEmbeddingWithLoRA.
|
||||
|
||||
Semantics:
|
||||
y += x @ lora_b_stacked
|
||||
|
||||
Args:
|
||||
y (torch.Tensor): Output tensor.
|
||||
x (torch.Tensor): Input tensor.
|
||||
lora_b_stacked (torch.Tensor): lora_b's weights.
|
||||
add_inputs (bool): Default to True.
|
||||
"""
|
||||
|
||||
expand_fun: Callable = (
|
||||
self._expand_prefill if self.is_prefill else self._expand_decode
|
||||
)
|
||||
expand_fun(y, x, lora_b_stacked, add_inputs)
|
||||
|
||||
def add_lora_linear(
|
||||
self,
|
||||
y: torch.Tensor,
|
||||
x: torch.Tensor,
|
||||
lora_a_stacked: Tuple[torch.Tensor, ...],
|
||||
lora_b_stacked: Tuple[torch.Tensor, ...],
|
||||
lora_bias_stacked: Optional[Tuple[torch.Tensor, ...]],
|
||||
scale: float,
|
||||
output_slices: Tuple[int, ...],
|
||||
*,
|
||||
buffer: Optional[Tuple[torch.Tensor, ...]] = None,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
"""
|
||||
Applicable to linear-related lora.
|
||||
|
||||
Semantics:
|
||||
for i in range(len(lora_a_stacked)):
|
||||
y[i] += (
|
||||
x[i].unsqueeze(0)
|
||||
@ lora_a_stacked[indices[i], layer_idx, :, :]
|
||||
@ lora_b_stacked[indices[i], layer_idx, :, :]
|
||||
* scale
|
||||
).squeeze(0)+lora_bias_stacked[i]
|
||||
|
||||
Args:
|
||||
y (torch.Tensor): Output tensor. Will be changed in-place.
|
||||
x (torch.Tensor): Input tensor
|
||||
lora_a_stacked (Tuple[torch.Tensor, ...]): lora_a's weight.
|
||||
lora_b_stacked (Tuple[torch.Tensor, ...]): lora_b's weight.
|
||||
lora_bias_stacked (Optional[Tuple[torch.Tensor, ...]]): lora's bias.
|
||||
scale (float): Scaling factor.
|
||||
output_slices (Tuple[int, ...]): Every slice's size.
|
||||
buffer (Optional[Tuple[torch.Tensor, ...]]): Defaults to None.
|
||||
"""
|
||||
|
||||
if self.no_lora:
|
||||
return
|
||||
|
||||
expert_num = 9
|
||||
block_statistic = torch.zeros(
|
||||
[12, expert_num], dtype=torch.int32, device=x.device
|
||||
)
|
||||
sorted_tokens_num_lod = torch.zeros(
|
||||
expert_num + 1, dtype=torch.int32, device=x.device
|
||||
)
|
||||
token_nums = x.size(0)
|
||||
moe_index = torch.zeros(token_nums, dtype=torch.int32, device=x.device)
|
||||
|
||||
assert len(lora_a_stacked) == len(lora_b_stacked) == len(output_slices)
|
||||
if lora_bias_stacked is not None:
|
||||
assert len(lora_bias_stacked) == len(output_slices)
|
||||
y = self._apply_bias(
|
||||
self.token_lora_indices, y, output_slices, lora_bias_stacked
|
||||
)
|
||||
|
||||
if buffer is None:
|
||||
r = lora_b_stacked[0].size(-1)
|
||||
buffer = tuple(
|
||||
torch.zeros((x.size(0), r), dtype=torch.float16, device=x.device)
|
||||
for _ in range(len(output_slices))
|
||||
)
|
||||
# [tensor.squeeze_(1) for tensor in lora_a_stacked]
|
||||
new_lora_a_stacked = tuple(lora_a.squeeze(1) for lora_a in lora_a_stacked)
|
||||
self.add_shrink(
|
||||
buffer,
|
||||
x,
|
||||
new_lora_a_stacked,
|
||||
block_statistic,
|
||||
sorted_tokens_num_lod,
|
||||
moe_index,
|
||||
scale,
|
||||
**kwargs,
|
||||
)
|
||||
# [tensor.unsqueeze_(1) for tensor in lora_a_stacked]
|
||||
|
||||
# [tensor.squeeze_(1) for tensor in lora_b_stacked]
|
||||
new_lora_b_stacked = tuple(lora_b.squeeze(1) for lora_b in lora_b_stacked)
|
||||
self.add_expand(
|
||||
y,
|
||||
buffer,
|
||||
new_lora_b_stacked,
|
||||
block_statistic,
|
||||
sorted_tokens_num_lod,
|
||||
moe_index,
|
||||
None,
|
||||
output_slices,
|
||||
add_inputs=True,
|
||||
**kwargs,
|
||||
)
|
||||
# [tensor.unsqueeze_(1) for tensor in lora_b_stacked]
|
||||
|
||||
def add_lora_logits(
|
||||
self,
|
||||
y: torch.Tensor,
|
||||
x: torch.Tensor,
|
||||
lora_a_stacked: torch.Tensor,
|
||||
lora_b_stacked: torch.Tensor,
|
||||
scale,
|
||||
*,
|
||||
buffer: Optional[torch.Tensor] = None,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
"""
|
||||
Applies lora specifically for LogitsProcessorWithLoRA.
|
||||
|
||||
Semantics:
|
||||
buffer = (x @ lora_a_stacked) * scale
|
||||
y += buffer @ lora_b_stacked
|
||||
|
||||
Args:
|
||||
y (torch.Tensor): Output tensor.
|
||||
x (torch.Tensor): Input tensor.
|
||||
lora_a_stacked (torch.Tensor): lora_a's weights.
|
||||
lora_b_stacked (torch.Tensor):lora_b's weights.
|
||||
scale (float): Scaling factor.
|
||||
buffer (Optional[torch.Tensor]):Default to None.
|
||||
"""
|
||||
y_org = y
|
||||
y = y.view(-1, y.shape[-1])
|
||||
x = x.view(-1, x.shape[-1])
|
||||
|
||||
if lora_a_stacked.dim() == 2:
|
||||
lora_a_stacked = lora_a_stacked.unsqueeze(0)
|
||||
if lora_b_stacked.dim() == 2:
|
||||
lora_b_stacked = lora_b_stacked.unsqueeze(0)
|
||||
|
||||
r = lora_a_stacked.size(-1)
|
||||
|
||||
if buffer is None:
|
||||
buffer = torch.zeros((x.size(0), r), dtype=torch.float32, device=x.device)
|
||||
|
||||
indices = self.sampler_indices
|
||||
if indices.max() >= lora_a_stacked.size(0):
|
||||
indices = torch.clamp(indices, 0, lora_a_stacked.size(0) - 1)
|
||||
|
||||
lora_a_reshaped = lora_a_stacked.transpose(1, 2)
|
||||
lora_b_reshaped = lora_b_stacked.transpose(1, 2)
|
||||
|
||||
bgmv_shrink(x, lora_a_reshaped, buffer, indices, scale)
|
||||
bgmv_expand(buffer, lora_b_reshaped, y, indices, add_inputs=True)
|
||||
|
||||
y = y.view_as(y_org)
|
||||
@@ -7,6 +7,12 @@ def register_model():
|
||||
from .qwen2_5_vl import Qwen2_5_VLForConditionalGeneration #noqa: F401
|
||||
from .qwen3 import Qwen3ForCausalLM #noqa: F401
|
||||
from .qwen3_moe import Qwen3MoeForCausalLM #noqa: F401
|
||||
from .qwen3_vl import Qwen3VLForConditionalGeneration
|
||||
from .qwen3_vl_moe import Qwen3VLMoeForConditionalGeneration
|
||||
from .qwen3_omni_moe_thinker import Qwen3OmniMoeThinkerForConditionalGeneration
|
||||
# from .llama4 import Llama4ForCausalLM #noqa: F401
|
||||
# from .mllama4 import Llama4ForConditionalGeneration #noqa: F401
|
||||
# from .deepseek_v2 import KunlunDeepseekV2MoE
|
||||
|
||||
# ModelRegistry.register_model(
|
||||
# "DemoModel",
|
||||
@@ -27,6 +33,10 @@ def register_model():
|
||||
ModelRegistry.register_model(
|
||||
"Qwen3MoeForCausalLM",
|
||||
"vllm_kunlun.models.qwen3_moe:Qwen3MoeForCausalLM")
|
||||
|
||||
ModelRegistry.register_model(
|
||||
"Qwen3NextForCausalLM",
|
||||
"vllm_kunlun.models.qwen3_next:Qwen3NextForCausalLM")
|
||||
|
||||
ModelRegistry.register_model(
|
||||
"GlmForCausalLM",
|
||||
@@ -34,7 +44,8 @@ def register_model():
|
||||
|
||||
ModelRegistry.register_model(
|
||||
"GptOssForCausalLM",
|
||||
"vllm_kunlun.models.gpt_oss:GptOssForCausalLM")
|
||||
"vllm_kunlun.models.gpt_oss:GptOssForCausalLM")
|
||||
|
||||
ModelRegistry.register_model(
|
||||
"InternLM2ForCausalLM",
|
||||
"vllm_kunlun.models.internlm2:InternLM2ForCausalLM")
|
||||
@@ -52,16 +63,20 @@ def register_model():
|
||||
"vllm_kunlun.models.interns1:InternS1ForConditionalGeneration")
|
||||
|
||||
ModelRegistry.register_model(
|
||||
"Glm4MoeForCausalLM",
|
||||
"vllm_kunlun.models.glm4_moe:Glm4MoeForCausalLM")
|
||||
"Qwen3VLForConditionalGeneration",
|
||||
"vllm_kunlun.models.qwen3_vl:Qwen3VLForConditionalGeneration")
|
||||
|
||||
ModelRegistry.register_model(
|
||||
"Glm4ForCausalLM",
|
||||
"vllm_kunlun.models.glm4:Glm4ForCausalLM")
|
||||
"Qwen3VLMoeForConditionalGeneration",
|
||||
"vllm_kunlun.models.qwen3_vl_moe:Qwen3VLMoeForConditionalGeneration")
|
||||
|
||||
ModelRegistry.register_model(
|
||||
"Glm4vForConditionalGeneration",
|
||||
"vllm_kunlun.models.glm4_1v:Glm4vForConditionalGeneration")
|
||||
"Qwen3OmniMoeForConditionalGeneration",
|
||||
"vllm_kunlun.models.qwen3_omni_moe_thinker:Qwen3OmniMoeThinkerForConditionalGeneration")
|
||||
|
||||
ModelRegistry.register_model(
|
||||
"SeedOssForCausalLM",
|
||||
"vllm_kunlun.models.seed_oss:SeedOssForCausalLM")
|
||||
|
||||
|
||||
def register_quant_method():
|
||||
|
||||
@@ -1,301 +0,0 @@
|
||||
#
|
||||
# Copyright (c) 2025 Baidu, Inc. All Rights Reserved.
|
||||
# Adapted from vllm/model_executor/models/glm4.py
|
||||
# Copyright 2023 The vLLM team.
|
||||
#
|
||||
# This file is a part of the vllm-kunlun project.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Inference-only GLM-4-0414 model compatible with HuggingFace weights."""
|
||||
from collections.abc import Iterable
|
||||
from typing import Optional, Union
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from transformers import Glm4Config
|
||||
|
||||
from vllm.attention import AttentionType
|
||||
from vllm_kunlun.ops.attention.layer import Attention
|
||||
from vllm.compilation.decorators import support_torch_compile
|
||||
from vllm.config import CacheConfig, VllmConfig
|
||||
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
from vllm.model_executor.layers.linear import (QKVParallelLinear,
|
||||
RowParallelLinear)
|
||||
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||
from vllm.model_executor.layers.rotary_embedding import get_rope
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
|
||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||
from vllm.sequence import IntermediateTensors
|
||||
|
||||
from vllm.model_executor.models.interfaces import SupportsLoRA, SupportsPP
|
||||
from vllm_kunlun.models.llama import LlamaMLP as Glm4MLP
|
||||
from vllm_kunlun.models.llama import LlamaModel
|
||||
from vllm.model_executor.models.utils import AutoWeightsLoader, PPMissingLayer, maybe_prefix
|
||||
|
||||
|
||||
class Glm4Attention(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
config: Glm4Config,
|
||||
hidden_size: int,
|
||||
num_heads: int,
|
||||
num_kv_heads: int,
|
||||
max_position: int = 4096 * 32,
|
||||
head_dim: Optional[int] = None,
|
||||
qkv_bias: bool = False,
|
||||
rope_theta: float = 10000,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
rope_scaling: Optional[tuple] = None,
|
||||
prefix: str = "",
|
||||
attn_type: str = AttentionType.DECODER) -> None:
|
||||
super().__init__()
|
||||
self.hidden_size = hidden_size
|
||||
tp_size = get_tensor_model_parallel_world_size()
|
||||
self.total_num_heads = num_heads
|
||||
assert self.total_num_heads % tp_size == 0
|
||||
self.num_heads = self.total_num_heads // tp_size
|
||||
self.total_num_kv_heads = num_kv_heads
|
||||
if self.total_num_kv_heads >= tp_size:
|
||||
# Number of KV heads is greater than TP size, so we partition
|
||||
# the KV heads across multiple tensor parallel GPUs.
|
||||
assert self.total_num_kv_heads % tp_size == 0
|
||||
else:
|
||||
# Number of KV heads is less than TP size, so we replicate
|
||||
# the KV heads across multiple tensor parallel GPUs.
|
||||
assert tp_size % self.total_num_kv_heads == 0
|
||||
partial_rotary_factor = getattr(config, "partial_rotary_factor", 0.5)
|
||||
self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
|
||||
self.head_dim = head_dim or hidden_size // self.total_num_heads
|
||||
self.rotary_dim = self.head_dim
|
||||
self.q_size = self.num_heads * self.head_dim
|
||||
self.kv_size = self.num_kv_heads * self.head_dim
|
||||
self.scaling = self.head_dim**-0.5
|
||||
self.rope_theta = rope_theta
|
||||
self.qkv_proj = QKVParallelLinear(
|
||||
hidden_size,
|
||||
self.head_dim,
|
||||
self.total_num_heads,
|
||||
self.total_num_kv_heads,
|
||||
bias=qkv_bias,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.qkv_proj",
|
||||
)
|
||||
self.o_proj = RowParallelLinear(
|
||||
self.total_num_heads * self.head_dim,
|
||||
hidden_size,
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.o_proj",
|
||||
)
|
||||
self.rotary_emb = get_rope(
|
||||
self.head_dim,
|
||||
rotary_dim=self.rotary_dim,
|
||||
max_position=max_position,
|
||||
base=self.rope_theta,
|
||||
rope_scaling=rope_scaling,
|
||||
partial_rotary_factor=partial_rotary_factor,
|
||||
is_neox_style=False,
|
||||
)
|
||||
self.attn = Attention(self.num_heads,
|
||||
self.head_dim,
|
||||
self.scaling,
|
||||
num_kv_heads=self.num_kv_heads,
|
||||
cache_config=cache_config,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.attn",
|
||||
attn_type=attn_type)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
qkv, _ = self.qkv_proj(hidden_states)
|
||||
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
|
||||
q, k = self.rotary_emb(positions, q, k)
|
||||
attn_output = self.attn(q, k, v)
|
||||
output, _ = self.o_proj(attn_output)
|
||||
return output
|
||||
|
||||
|
||||
class Glm4DecoderLayer(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: Glm4Config,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.hidden_size = config.hidden_size
|
||||
rope_theta = getattr(config, "rope_theta", 1000000)
|
||||
rope_scaling = getattr(config, "rope_scaling", None)
|
||||
|
||||
self.self_attn = Glm4Attention(
|
||||
config=config,
|
||||
hidden_size=self.hidden_size,
|
||||
num_heads=config.num_attention_heads,
|
||||
max_position=config.max_position_embeddings,
|
||||
num_kv_heads=config.num_key_value_heads,
|
||||
rope_theta=rope_theta,
|
||||
qkv_bias=getattr(config, 'attention_bias', False),
|
||||
head_dim=getattr(config, 'head_dim', None),
|
||||
cache_config=cache_config,
|
||||
quant_config=quant_config,
|
||||
rope_scaling=rope_scaling,
|
||||
prefix=f"{prefix}.self_attn",
|
||||
attn_type=AttentionType.DECODER,
|
||||
)
|
||||
self.mlp = Glm4MLP(
|
||||
hidden_size=self.hidden_size,
|
||||
intermediate_size=config.intermediate_size,
|
||||
hidden_act=config.hidden_act,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.mlp",
|
||||
)
|
||||
self.input_layernorm = RMSNorm(config.hidden_size,
|
||||
eps=config.rms_norm_eps)
|
||||
self.post_attention_layernorm = RMSNorm(config.hidden_size,
|
||||
eps=config.rms_norm_eps)
|
||||
self.post_self_attn_layernorm = RMSNorm(config.hidden_size,
|
||||
eps=config.rms_norm_eps)
|
||||
self.post_mlp_layernorm = RMSNorm(config.hidden_size,
|
||||
eps=config.rms_norm_eps)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
residual: Optional[torch.Tensor],
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
# Self Attention
|
||||
if residual is None:
|
||||
residual = hidden_states
|
||||
hidden_states = self.input_layernorm(hidden_states)
|
||||
else:
|
||||
hidden_states, residual = self.input_layernorm(
|
||||
hidden_states, residual)
|
||||
hidden_states = self.self_attn(
|
||||
positions=positions,
|
||||
hidden_states=hidden_states,
|
||||
)
|
||||
|
||||
hidden_states = self.post_self_attn_layernorm(hidden_states)
|
||||
|
||||
# Fully Connected
|
||||
hidden_states, residual = self.post_attention_layernorm(
|
||||
hidden_states, residual)
|
||||
hidden_states = self.mlp(hidden_states)
|
||||
hidden_states = self.post_mlp_layernorm(hidden_states)
|
||||
|
||||
return hidden_states, residual
|
||||
|
||||
|
||||
ALL_DECODER_LAYER_TYPES = {
|
||||
"attention": Glm4DecoderLayer,
|
||||
}
|
||||
|
||||
|
||||
@support_torch_compile(
|
||||
dynamic_arg_dims={
|
||||
"input_ids": 0,
|
||||
"positions": -1,
|
||||
"intermediate_tensors": 0,
|
||||
"inputs_embeds": 0,
|
||||
})
|
||||
class Glm4Model(LlamaModel):
|
||||
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
super().__init__(vllm_config=vllm_config,
|
||||
prefix=prefix,
|
||||
layer_type=Glm4DecoderLayer)
|
||||
|
||||
|
||||
class Glm4ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
|
||||
packed_modules_mapping = {
|
||||
"qkv_proj": [
|
||||
"q_proj",
|
||||
"k_proj",
|
||||
"v_proj",
|
||||
],
|
||||
"gate_up_proj": [
|
||||
"gate_proj",
|
||||
"up_proj",
|
||||
],
|
||||
}
|
||||
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
super().__init__()
|
||||
config = vllm_config.model_config.hf_config
|
||||
quant_config = vllm_config.quant_config
|
||||
lora_config = vllm_config.lora_config
|
||||
|
||||
self.config = config
|
||||
self.lora_config = lora_config
|
||||
|
||||
self.quant_config = quant_config
|
||||
self.model = Glm4Model(vllm_config=vllm_config,
|
||||
prefix=maybe_prefix(prefix, "model"))
|
||||
|
||||
if get_pp_group().is_last_rank:
|
||||
if config.tie_word_embeddings:
|
||||
self.lm_head = self.model.embed_tokens
|
||||
else:
|
||||
self.lm_head = ParallelLMHead(config.vocab_size,
|
||||
config.hidden_size,
|
||||
quant_config=quant_config,
|
||||
prefix=maybe_prefix(
|
||||
prefix, "lm_head"))
|
||||
else:
|
||||
self.lm_head = PPMissingLayer()
|
||||
|
||||
self.logits_processor = LogitsProcessor(config.vocab_size)
|
||||
|
||||
self.make_empty_intermediate_tensors = (
|
||||
self.model.make_empty_intermediate_tensors)
|
||||
|
||||
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.model.get_input_embeddings(input_ids)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
) -> Union[torch.Tensor, IntermediateTensors]:
|
||||
hidden_states = self.model(input_ids, positions, intermediate_tensors,
|
||||
inputs_embeds)
|
||||
return hidden_states
|
||||
|
||||
def compute_logits(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
sampling_metadata: SamplingMetadata,
|
||||
) -> Optional[torch.Tensor]:
|
||||
logits = self.logits_processor(self.lm_head, hidden_states,
|
||||
sampling_metadata)
|
||||
return logits
|
||||
|
||||
def load_weights(self, weights: Iterable[tuple[str,
|
||||
torch.Tensor]]) -> set[str]:
|
||||
loader = AutoWeightsLoader(
|
||||
self,
|
||||
skip_prefixes=(["lm_head."]
|
||||
if self.config.tie_word_embeddings else None),
|
||||
)
|
||||
return loader.load_weights(weights)
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,716 +0,0 @@
|
||||
#
|
||||
# Copyright (c) 2025 Baidu, Inc. All Rights Reserved.
|
||||
# Adapted from vllm/model_executor/models/glm4_moe.py
|
||||
# Copyright 2023 The vLLM team.
|
||||
#
|
||||
# This file is a part of the vllm-kunlun project.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Inference-only GLM-4.5 model compatible with HuggingFace weights."""
|
||||
import os
|
||||
import typing
|
||||
from collections.abc import Callable, Iterable
|
||||
from itertools import islice
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from transformers.models.glm4_moe import Glm4MoeConfig
|
||||
|
||||
from vllm_kunlun.ops.attention.layer import Attention
|
||||
from vllm.compilation.decorators import support_torch_compile
|
||||
from vllm.config import CacheConfig, VllmConfig, get_current_vllm_config
|
||||
from vllm.distributed import (get_ep_group, get_pp_group,get_dp_group,get_tp_group,
|
||||
get_tensor_model_parallel_world_size)
|
||||
from vllm.logger import init_logger
|
||||
from vllm_kunlun.ops.activation import SiluAndMul
|
||||
from vllm_kunlun.ops.fused_moe.layer import FusedMoE
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
|
||||
QKVParallelLinear,
|
||||
RowParallelLinear)
|
||||
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||
from vllm.model_executor.layers.rotary_embedding import get_rope
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
ParallelLMHead, VocabParallelEmbedding)
|
||||
from vllm.model_executor.model_loader.weight_utils import (
|
||||
default_weight_loader, maybe_remap_kv_scale_name)
|
||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||
from vllm.sequence import IntermediateTensors
|
||||
|
||||
from vllm.model_executor.models.interfaces import SupportsLoRA, SupportsPP
|
||||
from vllm.model_executor.models.utils import (AutoWeightsLoader, PPMissingLayer, is_pp_missing_parameter,
|
||||
make_empty_intermediate_tensors_factory, make_layers,
|
||||
maybe_prefix)
|
||||
from vllm_kunlun.ops.rotary_embedding import Split_Norm_Rope
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class Glm4MoeMLP(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size: int,
|
||||
intermediate_size: int,
|
||||
hidden_act: str,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
reduce_results: bool = True,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.gate_up_proj = MergedColumnParallelLinear(
|
||||
hidden_size, [intermediate_size] * 2,
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.gate_up_proj")
|
||||
self.down_proj = RowParallelLinear(intermediate_size,
|
||||
hidden_size,
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
reduce_results=reduce_results,
|
||||
prefix=f"{prefix}.down_proj")
|
||||
if hidden_act != "silu":
|
||||
raise ValueError(f"Unsupported activation: {hidden_act}. "
|
||||
"Only silu is supported for now.")
|
||||
self.act_fn = SiluAndMul()
|
||||
|
||||
def forward(self, x):
|
||||
gate_up, _ = self.gate_up_proj(x)
|
||||
x = self.act_fn(gate_up)
|
||||
x, _ = self.down_proj(x)
|
||||
return x
|
||||
|
||||
class Glm4MoE(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: Glm4MoeConfig,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
enable_eplb: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
self.tp_size = get_tensor_model_parallel_world_size()
|
||||
self.routed_scaling_factor = config.routed_scaling_factor
|
||||
|
||||
self.ep_group = get_ep_group().device_group
|
||||
self.ep_rank = self.ep_group.rank()
|
||||
self.ep_size = self.ep_group.size()
|
||||
self.n_routed_experts: int = config.n_routed_experts
|
||||
self.n_shared_experts: int = config.n_shared_experts
|
||||
|
||||
if config.hidden_act != "silu":
|
||||
raise ValueError(f"Unsupported activation: {config.hidden_act}. "
|
||||
"Only silu is supported for now.")
|
||||
# NOTE In the transformers implementation, the gate isn't an nn.Linear,
|
||||
# so we cannot use ReplicatedLinear here.
|
||||
# See: https://github.com/huggingface/transformers/blob/v4.55.1/src/transformers/models/glm4_moe/modeling_glm4_moe.py#L260
|
||||
self.gate = nn.Linear(
|
||||
config.hidden_size,
|
||||
config.n_routed_experts,
|
||||
bias=False,
|
||||
dtype=torch.float32,
|
||||
)
|
||||
self.gate.e_score_correction_bias = nn.Parameter(
|
||||
torch.empty(config.n_routed_experts, dtype=torch.float32))
|
||||
|
||||
# Load balancing settings.
|
||||
vllm_config = get_current_vllm_config()
|
||||
parallel_config = vllm_config.parallel_config
|
||||
self.enable_eplb = enable_eplb
|
||||
|
||||
self.n_redundant_experts = parallel_config.num_redundant_experts
|
||||
self.n_logical_experts = self.n_routed_experts
|
||||
self.n_physical_experts = (self.n_logical_experts +
|
||||
self.n_redundant_experts)
|
||||
self.n_local_physical_experts = self.n_physical_experts // self.ep_size
|
||||
|
||||
self.physical_expert_start = (self.ep_rank *
|
||||
self.n_local_physical_experts)
|
||||
self.physical_expert_end = (self.physical_expert_start +
|
||||
self.n_local_physical_experts)
|
||||
|
||||
self.experts = FusedMoE(
|
||||
num_experts=config.n_routed_experts,
|
||||
top_k=config.num_experts_per_tok,
|
||||
hidden_size=config.hidden_size,
|
||||
intermediate_size=config.moe_intermediate_size,
|
||||
reduce_results=False,
|
||||
renormalize=config.norm_topk_prob,
|
||||
quant_config=quant_config,
|
||||
use_grouped_topk=True,
|
||||
num_expert_group=config.n_group,
|
||||
topk_group=config.topk_group,
|
||||
prefix=f"{prefix}.experts",
|
||||
scoring_func="sigmoid",
|
||||
e_score_correction_bias=self.gate.e_score_correction_bias,
|
||||
enable_eplb=self.enable_eplb,
|
||||
num_redundant_experts=self.n_redundant_experts)
|
||||
|
||||
if config.n_shared_experts is not None:
|
||||
intermediate_size = (config.moe_intermediate_size *
|
||||
config.n_shared_experts)
|
||||
self.shared_experts = Glm4MoeMLP(
|
||||
hidden_size=config.hidden_size,
|
||||
intermediate_size=intermediate_size,
|
||||
hidden_act=config.hidden_act,
|
||||
quant_config=quant_config,
|
||||
reduce_results=self.experts.must_reduce_shared_expert_outputs(
|
||||
),
|
||||
prefix=f"{prefix}.shared_experts",
|
||||
)
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
num_tokens, hidden_dim = hidden_states.shape
|
||||
hidden_states = hidden_states.view(-1, hidden_dim)
|
||||
|
||||
if self.n_shared_experts is not None:
|
||||
shared_output = self.shared_experts(hidden_states)
|
||||
else:
|
||||
shared_output = None
|
||||
|
||||
router_logits = self.gate(hidden_states.to(dtype=torch.float32))
|
||||
kunlun_linear_weights = self.gate.weight
|
||||
final_hidden_states = self.experts(
|
||||
hidden_states=hidden_states,
|
||||
router_logits=router_logits,
|
||||
linear_weights=kunlun_linear_weights) * self.routed_scaling_factor
|
||||
if shared_output is not None:
|
||||
final_hidden_states = final_hidden_states + shared_output
|
||||
if self.tp_size > 1:
|
||||
final_hidden_states = (
|
||||
self.experts.maybe_all_reduce_tensor_model_parallel(
|
||||
final_hidden_states))
|
||||
return final_hidden_states.view(num_tokens, hidden_dim)
|
||||
|
||||
|
||||
class Glm4MoeAttention(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: Glm4MoeConfig,
|
||||
hidden_size: int,
|
||||
num_heads: int,
|
||||
num_kv_heads: int,
|
||||
rope_theta: float = 10000,
|
||||
rope_scaling: Optional[dict[str, Any]] = None,
|
||||
max_position_embeddings: int = 131072,
|
||||
head_dim: Optional[int] = None,
|
||||
rms_norm_eps: float = 1e-05,
|
||||
qkv_bias: bool = False,
|
||||
use_qk_norm: bool = False,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.hidden_size = hidden_size
|
||||
tp_size = get_tensor_model_parallel_world_size()
|
||||
self.total_num_heads = num_heads
|
||||
assert self.total_num_heads % tp_size == 0
|
||||
self.num_heads = self.total_num_heads // tp_size
|
||||
self.total_num_kv_heads = num_kv_heads
|
||||
if self.total_num_kv_heads >= tp_size:
|
||||
# Number of KV heads is greater than TP size, so we partition
|
||||
# the KV heads across multiple tensor parallel GPUs.
|
||||
assert self.total_num_kv_heads % tp_size == 0
|
||||
else:
|
||||
# Number of KV heads is less than TP size, so we replicate
|
||||
# the KV heads across multiple tensor parallel GPUs.
|
||||
assert tp_size % self.total_num_kv_heads == 0
|
||||
self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
|
||||
self.head_dim = head_dim or (hidden_size // self.total_num_heads)
|
||||
self.q_size = self.num_heads * self.head_dim
|
||||
self.kv_size = self.num_kv_heads * self.head_dim
|
||||
self.scaling = self.head_dim**-0.5
|
||||
self.rope_theta = rope_theta
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
self.use_qk_norm = use_qk_norm
|
||||
|
||||
self.qkv_proj = QKVParallelLinear(hidden_size,
|
||||
self.head_dim,
|
||||
self.total_num_heads,
|
||||
self.total_num_kv_heads,
|
||||
bias=qkv_bias,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.qkv_proj")
|
||||
|
||||
self.o_proj = RowParallelLinear(self.total_num_heads * self.head_dim,
|
||||
hidden_size,
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.o_proj")
|
||||
|
||||
self.partial_rotary_factor = getattr(config, "partial_rotary_factor", 0.5)
|
||||
self.rotary_emb = get_rope(
|
||||
self.head_dim,
|
||||
rotary_dim=self.head_dim,
|
||||
max_position=max_position_embeddings,
|
||||
base=rope_theta,
|
||||
rope_scaling=rope_scaling,
|
||||
partial_rotary_factor=self.partial_rotary_factor,
|
||||
)
|
||||
self.attn = Attention(
|
||||
self.num_heads,
|
||||
self.head_dim,
|
||||
self.scaling,
|
||||
num_kv_heads=self.num_kv_heads,
|
||||
cache_config=cache_config,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.attn",
|
||||
)
|
||||
|
||||
if self.use_qk_norm:
|
||||
self.q_norm = RMSNorm(self.head_dim, eps=rms_norm_eps)
|
||||
self.k_norm = RMSNorm(self.head_dim, eps=rms_norm_eps)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
qkv, _ = self.qkv_proj(hidden_states)
|
||||
|
||||
if os.getenv('USE_ORI_ROPE') == "1" or not self.use_qk_norm:
|
||||
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
|
||||
if self.use_qk_norm:
|
||||
q = self.q_norm(q.reshape(-1, self.num_heads,
|
||||
self.head_dim)).reshape(q.shape)
|
||||
k = self.k_norm(k.reshape(-1, self.num_kv_heads,
|
||||
self.head_dim)).reshape(k.shape)
|
||||
q, k = self.rotary_emb(positions, q, k)
|
||||
else:
|
||||
# Rope fusion operators
|
||||
q, k, v = Split_Norm_Rope(qkv,
|
||||
self.rotary_emb.cos_sin_cache,
|
||||
self.q_norm.weight,
|
||||
self.k_norm.weight,
|
||||
positions,
|
||||
self.max_position_embeddings,
|
||||
self.num_heads,
|
||||
self.num_kv_heads,
|
||||
self.head_dim,
|
||||
partial_rotary_factor=self.partial_rotary_factor,
|
||||
)
|
||||
|
||||
attn_output = self.attn(q, k, v)
|
||||
output, _ = self.o_proj(attn_output)
|
||||
return output
|
||||
|
||||
|
||||
class Glm4MoeDecoderLayer(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: Glm4MoeConfig,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
enable_eplb: bool = False,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.hidden_size = config.hidden_size
|
||||
rope_theta = getattr(config, "rope_theta", 10000)
|
||||
rope_scaling = getattr(config, "rope_scaling", None)
|
||||
max_position_embeddings = getattr(config, "max_position_embeddings",
|
||||
131072)
|
||||
# DecoderLayers are created with `make_layers` which passes the prefix
|
||||
# with the layer's index.
|
||||
layer_idx = int(prefix.split(sep='.')[-1])
|
||||
self.layer_idx = layer_idx
|
||||
|
||||
self.self_attn = Glm4MoeAttention(
|
||||
config=config,
|
||||
hidden_size=self.hidden_size,
|
||||
num_heads=config.num_attention_heads,
|
||||
num_kv_heads=config.num_key_value_heads,
|
||||
rope_theta=rope_theta,
|
||||
rope_scaling=rope_scaling,
|
||||
max_position_embeddings=max_position_embeddings,
|
||||
head_dim=config.head_dim,
|
||||
rms_norm_eps=config.rms_norm_eps,
|
||||
qkv_bias=config.attention_bias,
|
||||
cache_config=cache_config,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.self_attn",
|
||||
use_qk_norm=config.use_qk_norm,
|
||||
)
|
||||
|
||||
if (config.n_routed_experts is not None
|
||||
and layer_idx >= config.first_k_dense_replace):
|
||||
self.mlp = Glm4MoE(
|
||||
config=config,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.mlp",
|
||||
enable_eplb=enable_eplb,
|
||||
)
|
||||
else:
|
||||
self.mlp = Glm4MoeMLP(hidden_size=config.hidden_size,
|
||||
intermediate_size=config.intermediate_size,
|
||||
hidden_act=config.hidden_act,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.mlp")
|
||||
|
||||
self.input_layernorm = RMSNorm(config.hidden_size,
|
||||
eps=config.rms_norm_eps)
|
||||
self.post_attention_layernorm = RMSNorm(config.hidden_size,
|
||||
eps=config.rms_norm_eps)
|
||||
self.routed_scaling_factor = config.routed_scaling_factor
|
||||
|
||||
def forward(
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
residual: Optional[torch.Tensor],
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
if residual is None:
|
||||
residual = hidden_states
|
||||
hidden_states = self.input_layernorm(hidden_states)
|
||||
else:
|
||||
hidden_states, residual = self.input_layernorm(
|
||||
hidden_states, residual)
|
||||
hidden_states = self.self_attn(positions=positions,
|
||||
hidden_states=hidden_states)
|
||||
hidden_states, residual = self.post_attention_layernorm(
|
||||
hidden_states, residual)
|
||||
hidden_states = self.mlp(hidden_states)
|
||||
return hidden_states, residual
|
||||
|
||||
|
||||
@support_torch_compile(
|
||||
dynamic_arg_dims={
|
||||
"input_ids": 0,
|
||||
"positions": -1,
|
||||
"intermediate_tensors": 0,
|
||||
"inputs_embeds": 0,
|
||||
})
|
||||
class Glm4MoeModel(nn.Module):
|
||||
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
super().__init__()
|
||||
|
||||
config = vllm_config.model_config.hf_config
|
||||
cache_config = vllm_config.cache_config
|
||||
quant_config = vllm_config.quant_config
|
||||
enable_eplb = vllm_config.parallel_config.enable_eplb
|
||||
self.config = config
|
||||
|
||||
self.vocab_size = config.vocab_size
|
||||
|
||||
if get_pp_group().is_first_rank:
|
||||
self.embed_tokens = VocabParallelEmbedding(
|
||||
config.vocab_size,
|
||||
config.hidden_size,
|
||||
prefix=f"{prefix}.embed_tokens")
|
||||
else:
|
||||
self.embed_tokens = PPMissingLayer()
|
||||
|
||||
self.start_layer, self.end_layer, self.layers = make_layers(
|
||||
config.num_hidden_layers,
|
||||
lambda prefix: Glm4MoeDecoderLayer(
|
||||
config=config,
|
||||
cache_config=cache_config,
|
||||
quant_config=quant_config,
|
||||
prefix=prefix,
|
||||
enable_eplb=enable_eplb,
|
||||
),
|
||||
prefix=f"{prefix}.layers")
|
||||
|
||||
if get_pp_group().is_last_rank:
|
||||
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
else:
|
||||
self.norm = PPMissingLayer()
|
||||
self.make_empty_intermediate_tensors = (
|
||||
make_empty_intermediate_tensors_factory(
|
||||
["hidden_states", "residual"], config.hidden_size))
|
||||
|
||||
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.embed_tokens(input_ids)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
) -> Union[torch.Tensor, IntermediateTensors]:
|
||||
if get_pp_group().is_first_rank:
|
||||
if inputs_embeds is not None:
|
||||
hidden_states = inputs_embeds
|
||||
else:
|
||||
hidden_states = self.get_input_embeddings(input_ids)
|
||||
residual = None
|
||||
else:
|
||||
assert intermediate_tensors is not None
|
||||
hidden_states = intermediate_tensors["hidden_states"]
|
||||
residual = intermediate_tensors["residual"]
|
||||
|
||||
for i in range(self.start_layer, self.end_layer):
|
||||
layer = self.layers[i]
|
||||
hidden_states, residual = layer(positions, hidden_states, residual)
|
||||
|
||||
if not get_pp_group().is_last_rank:
|
||||
return IntermediateTensors({
|
||||
"hidden_states": hidden_states,
|
||||
"residual": residual
|
||||
})
|
||||
|
||||
hidden_states, _ = self.norm(hidden_states, residual)
|
||||
return hidden_states
|
||||
|
||||
def make_empty_intermediate_tensors(
|
||||
self, batch_size: int, dtype: torch.dtype,
|
||||
device: torch.device) -> IntermediateTensors:
|
||||
return IntermediateTensors({
|
||||
"hidden_states":
|
||||
torch.zeros((batch_size, self.config.hidden_size),
|
||||
dtype=dtype,
|
||||
device=device),
|
||||
"residual":
|
||||
torch.zeros((batch_size, self.config.hidden_size),
|
||||
dtype=dtype,
|
||||
device=device),
|
||||
})
|
||||
|
||||
def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
|
||||
# Params for weights, fp8 weight scales, fp8 activation scales
|
||||
# (param_name, weight_name, expert_id, shard_id)
|
||||
return FusedMoE.make_expert_params_mapping(
|
||||
ckpt_gate_proj_name="gate_proj",
|
||||
ckpt_down_proj_name="down_proj",
|
||||
ckpt_up_proj_name="up_proj",
|
||||
num_experts=self.config.n_routed_experts)
|
||||
|
||||
def load_weights(self, weights: Iterable[tuple[str,
|
||||
torch.Tensor]]) -> set[str]:
|
||||
stacked_params_mapping = [
|
||||
# (param_name, shard_name, shard_id)
|
||||
("qkv_proj", "q_proj", "q"),
|
||||
("qkv_proj", "k_proj", "k"),
|
||||
("qkv_proj", "v_proj", "v"),
|
||||
("gate_up_proj", "gate_proj", 0),
|
||||
("gate_up_proj", "up_proj", 1),
|
||||
]
|
||||
|
||||
params_dict = dict(self.named_parameters())
|
||||
loaded_params: set[str] = set()
|
||||
expert_params_mapping = self.get_expert_mapping()
|
||||
for name, loaded_weight in weights:
|
||||
spec_layer = get_spec_layer_idx_from_weight_name(self.config, name)
|
||||
if spec_layer is not None:
|
||||
continue
|
||||
for (param_name, weight_name, shard_id) in stacked_params_mapping:
|
||||
# Skip non-stacked layers and experts (experts handled below).
|
||||
if weight_name not in name:
|
||||
continue
|
||||
# We have mlp.experts[0].gate_proj in the checkpoint.
|
||||
# Since we handle the experts below in expert_params_mapping,
|
||||
# we need to skip here BEFORE we update the name, otherwise
|
||||
# name will be updated to mlp.experts[0].gate_up_proj, which
|
||||
# will then be updated below in expert_params_mapping
|
||||
# for mlp.experts[0].gate_gate_up_proj, which breaks load.
|
||||
if (("mlp.experts." in name) and name not in params_dict):
|
||||
continue
|
||||
name = name.replace(weight_name, param_name)
|
||||
# Skip loading extra bias for GPTQ models.
|
||||
if name.endswith(".bias") and name not in params_dict:
|
||||
continue
|
||||
if is_pp_missing_parameter(name, self):
|
||||
continue
|
||||
|
||||
param = params_dict[name]
|
||||
weight_loader = param.weight_loader
|
||||
weight_loader(param, loaded_weight, shard_id)
|
||||
break
|
||||
else:
|
||||
is_expert_weight = False
|
||||
for mapping in expert_params_mapping:
|
||||
param_name, weight_name, expert_id, shard_id = mapping
|
||||
if weight_name not in name:
|
||||
continue
|
||||
|
||||
# Anyway, this is an expert weight and should not be
|
||||
# attempted to load as other weights later
|
||||
is_expert_weight = True
|
||||
|
||||
# Do not modify `name` since the loop may continue here
|
||||
# Instead, create a new variable
|
||||
name_mapped = name.replace(weight_name, param_name)
|
||||
|
||||
if is_pp_missing_parameter(name_mapped, self):
|
||||
continue
|
||||
|
||||
param = params_dict[name_mapped]
|
||||
# We should ask the weight loader to return success or not
|
||||
# here since otherwise we may skip experts with other
|
||||
# available replicas.
|
||||
weight_loader = typing.cast(Callable[..., bool],
|
||||
param.weight_loader)
|
||||
success = weight_loader(param,
|
||||
loaded_weight,
|
||||
name_mapped,
|
||||
shard_id=shard_id,
|
||||
expert_id=expert_id,
|
||||
return_success=True)
|
||||
if success:
|
||||
name = name_mapped
|
||||
break
|
||||
else:
|
||||
if is_expert_weight:
|
||||
# We've checked that this is an expert weight
|
||||
# However it's not mapped locally to this rank
|
||||
# So we simply skip it
|
||||
continue
|
||||
|
||||
# Skip loading extra bias for GPTQ models.
|
||||
if name.endswith(".bias") and name not in params_dict:
|
||||
continue
|
||||
|
||||
# Remapping the name of FP8 kv-scale.
|
||||
name = maybe_remap_kv_scale_name(name, params_dict)
|
||||
if name is None:
|
||||
continue
|
||||
|
||||
if is_pp_missing_parameter(name, self):
|
||||
continue
|
||||
|
||||
param = params_dict[name]
|
||||
weight_loader = getattr(param, "weight_loader",
|
||||
default_weight_loader)
|
||||
weight_loader(param, loaded_weight)
|
||||
loaded_params.add(name)
|
||||
|
||||
return loaded_params
|
||||
|
||||
|
||||
class Glm4MoeForCausalLM(nn.Module, SupportsPP, SupportsLoRA):
|
||||
packed_modules_mapping = {
|
||||
"qkv_proj": [
|
||||
"q_proj",
|
||||
"k_proj",
|
||||
"v_proj",
|
||||
],
|
||||
"gate_up_proj": [
|
||||
"gate_proj",
|
||||
"up_proj",
|
||||
],
|
||||
}
|
||||
|
||||
fall_back_to_pt_during_load = False
|
||||
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
super().__init__()
|
||||
config = vllm_config.model_config.hf_config
|
||||
quant_config = vllm_config.quant_config
|
||||
self.config = config
|
||||
self.quant_config = quant_config
|
||||
self.model = Glm4MoeModel(vllm_config=vllm_config,
|
||||
prefix=maybe_prefix(prefix, "model"))
|
||||
if get_pp_group().is_last_rank:
|
||||
self.lm_head = ParallelLMHead(config.vocab_size,
|
||||
config.hidden_size,
|
||||
quant_config=quant_config)
|
||||
else:
|
||||
self.lm_head = PPMissingLayer()
|
||||
self.logits_processor = LogitsProcessor(config.vocab_size)
|
||||
self.make_empty_intermediate_tensors = (
|
||||
self.model.make_empty_intermediate_tensors)
|
||||
self.expert_weights = []
|
||||
|
||||
# Set MoE hyperparameters
|
||||
self.num_moe_layers = (config.num_hidden_layers -
|
||||
config.first_k_dense_replace)
|
||||
self.num_expert_groups = config.n_group
|
||||
|
||||
self.moe_layers: list[FusedMoE] = []
|
||||
example_moe = None
|
||||
for layer in self.model.layers:
|
||||
if isinstance(layer, PPMissingLayer):
|
||||
continue
|
||||
|
||||
assert isinstance(layer, Glm4MoeDecoderLayer)
|
||||
if isinstance(layer.mlp, Glm4MoE):
|
||||
# Pick last one layer since the first ones may be dense layers.
|
||||
example_moe = layer.mlp
|
||||
self.moe_layers.append(layer.mlp.experts)
|
||||
|
||||
if example_moe is None:
|
||||
raise RuntimeError("No Glm4MoE layer found in model.layers.")
|
||||
|
||||
self.num_logical_experts = example_moe.n_logical_experts
|
||||
self.num_physical_experts = example_moe.n_physical_experts
|
||||
self.num_local_physical_experts = example_moe.n_local_physical_experts
|
||||
self.num_routed_experts = example_moe.n_routed_experts
|
||||
self.num_shared_experts = example_moe.n_shared_experts
|
||||
self.num_redundant_experts = example_moe.n_redundant_experts
|
||||
|
||||
def set_eplb_state(
|
||||
self,
|
||||
expert_load_view: torch.Tensor,
|
||||
logical_to_physical_map: torch.Tensor,
|
||||
logical_replica_count: torch.Tensor,
|
||||
) -> None:
|
||||
for layer_idx, layer in enumerate(self.moe_layers):
|
||||
# Register the expert weights.
|
||||
self.expert_weights.append(layer.get_expert_weights())
|
||||
layer.set_eplb_state(
|
||||
moe_layer_idx=layer_idx,
|
||||
expert_load_view=expert_load_view,
|
||||
logical_to_physical_map=logical_to_physical_map,
|
||||
logical_replica_count=logical_replica_count,
|
||||
)
|
||||
|
||||
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.model.get_input_embeddings(input_ids)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
) -> Union[torch.Tensor, IntermediateTensors]:
|
||||
hidden_states = self.model(input_ids, positions, intermediate_tensors,
|
||||
inputs_embeds)
|
||||
return hidden_states
|
||||
|
||||
def compute_logits(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
sampling_metadata: SamplingMetadata,
|
||||
) -> Optional[torch.Tensor]:
|
||||
logits = self.logits_processor(self.lm_head, hidden_states,
|
||||
sampling_metadata)
|
||||
return logits
|
||||
|
||||
def load_weights(self, weights: Iterable[tuple[str,
|
||||
torch.Tensor]]) -> set[str]:
|
||||
loader = AutoWeightsLoader(self)
|
||||
return loader.load_weights(weights)
|
||||
|
||||
def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
|
||||
return self.model.get_expert_mapping()
|
||||
|
||||
|
||||
def get_spec_layer_idx_from_weight_name(config: Glm4MoeConfig,
|
||||
weight_name: str) -> Optional[int]:
|
||||
if hasattr(config,
|
||||
"num_nextn_predict_layers") and (config.num_nextn_predict_layers
|
||||
> 0):
|
||||
layer_idx = config.num_hidden_layers
|
||||
for i in range(config.num_nextn_predict_layers):
|
||||
if f"layers.{layer_idx+i}." in weight_name:
|
||||
return layer_idx + i
|
||||
return None
|
||||
@@ -1,21 +1,5 @@
|
||||
#
|
||||
# Copyright (c) 2025 Baidu, Inc. All Rights Reserved.
|
||||
# Adapted from vllm/model_executor/models/gpt_oss.py
|
||||
# Copyright 2023 The vLLM team.
|
||||
#
|
||||
# This file is a part of the vllm-kunlun project.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from collections.abc import Iterable
|
||||
from typing import Optional
|
||||
|
||||
|
||||
@@ -1,21 +1,11 @@
|
||||
#
|
||||
# Copyright (c) 2025 Baidu, Inc. All Rights Reserved.
|
||||
# Adapted from vllm/model_executor/models/interns1.py
|
||||
# Copyright 2023 The vLLM team.
|
||||
#
|
||||
# This file is a part of the vllm-kunlun project.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
# --------------------------------------------------------
|
||||
# InternS1
|
||||
# Copyright (c) 2025 Shanghai AI Lab
|
||||
# Licensed under The MIT License [see LICENSE for details]
|
||||
# --------------------------------------------------------
|
||||
from collections.abc import Iterable, Mapping, Sequence
|
||||
from typing import Literal, Optional, TypedDict, Union
|
||||
|
||||
@@ -258,33 +248,39 @@ class InternS1DummyInputsBuilder(BaseDummyInputsBuilder[InternS1ProcessingInfo]
|
||||
|
||||
return image_token * num_images + video_token * num_videos
|
||||
|
||||
# def get_dummy_mm_data(
|
||||
# self,
|
||||
# seq_len: int,
|
||||
# mm_counts: Mapping[str, int],
|
||||
# ) -> MultiModalDataDict:
|
||||
# target_width, target_height = \
|
||||
# self.info.get_image_size_with_most_features()
|
||||
# target_num_frames = \
|
||||
# self.info.get_num_frames_with_most_features(seq_len, mm_counts)
|
||||
# num_images = mm_counts.get("image", 0)
|
||||
# num_videos = mm_counts.get("video", 0)
|
||||
|
||||
# config = self.info.get_hf_config()
|
||||
# image_size_h, image_size_w = config.vision_config.image_size
|
||||
|
||||
# return {
|
||||
# "image":
|
||||
# self._get_dummy_images(width=target_width,
|
||||
# height=target_height,
|
||||
# num_images=num_images),
|
||||
# "video":
|
||||
# self._get_dummy_videos(width=image_size_w,
|
||||
# height=image_size_h,
|
||||
# num_frames=target_num_frames,
|
||||
# num_videos=num_videos),
|
||||
# }
|
||||
|
||||
def get_dummy_mm_data(
|
||||
self,
|
||||
seq_len: int,
|
||||
mm_counts: Mapping[str, int],
|
||||
) -> MultiModalDataDict:
|
||||
"""Generates dummy multimodal data on Kunlun3 platform for performance analysis and warmup.
|
||||
|
||||
Retrieves visual resolution based on configuration (defaulting to 224x224)
|
||||
and generates resized dummy data for images and videos.
|
||||
|
||||
Args:
|
||||
seq_len: Sequence length (unused).
|
||||
mm_counts: A mapping of multimodal type counts, containing "image"
|
||||
and "video" keys.
|
||||
|
||||
Returns:
|
||||
MultiModalDataDict: A dictionary containing the generated dummy image
|
||||
and video data, structured as:
|
||||
{
|
||||
"image": dummy_image_data,
|
||||
"video": dummy_video_data
|
||||
}
|
||||
|
||||
Author:
|
||||
Dong Xinyu
|
||||
"""
|
||||
# 读取配置里的视觉分辨率;若缺省则兜底 224×224
|
||||
config = self.info.get_hf_config()
|
||||
img_size = getattr(config.vision_config, "image_size", None)
|
||||
if isinstance(img_size, (tuple, list)) and len(img_size) == 2:
|
||||
@@ -292,13 +288,15 @@ class InternS1DummyInputsBuilder(BaseDummyInputsBuilder[InternS1ProcessingInfo]
|
||||
else:
|
||||
cfg_h, cfg_w = 224, 224
|
||||
|
||||
# 统一缩减:不再使用 “with_most_features”,而是选择较小的安全尺寸
|
||||
target_width = min(cfg_w, 224)
|
||||
target_height = min(cfg_h, 224)
|
||||
target_num_frames = 1
|
||||
target_num_frames = 1 # profile/warmup 只造 1 帧即可
|
||||
|
||||
num_images = mm_counts.get("image", 0)
|
||||
num_videos = mm_counts.get("video", 0)
|
||||
|
||||
# 统一让视频也按缩减后的分辨率生成
|
||||
return {
|
||||
"image": self._get_dummy_images(
|
||||
width=target_width,
|
||||
|
||||
@@ -1,21 +1,12 @@
|
||||
#
|
||||
# Copyright (c) 2025 Baidu, Inc. All Rights Reserved.
|
||||
# Adapted from vllm/model_executor/models/interns1_vit.py
|
||||
# Copyright 2023 The vLLM team.
|
||||
#
|
||||
# This file is a part of the vllm-kunlun project.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
# adapted from https://huggingface.co/OpenGVLab/InternVL2-4B/blob/main/modeling_intern_vit.py
|
||||
# --------------------------------------------------------
|
||||
# InternVL
|
||||
# Copyright (c) 2023 OpenGVLab
|
||||
# Licensed under The MIT License [see LICENSE for details]
|
||||
# --------------------------------------------------------
|
||||
from collections.abc import Iterable
|
||||
from typing import Optional
|
||||
|
||||
@@ -26,6 +17,7 @@ from transformers import PretrainedConfig
|
||||
from transformers.utils import torch_int
|
||||
|
||||
from vllm.model_executor.layers.activation import get_act_fn
|
||||
# from vllm_kunlun.ops.activation import GeluAndMul
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
||||
RowParallelLinear)
|
||||
@@ -253,6 +245,7 @@ class InternS1VisionMLP(nn.Module):
|
||||
|
||||
self.config = config
|
||||
self.activation_fn = get_act_fn(config.hidden_act)
|
||||
# self.activation_fn = GeluAndMul()
|
||||
self.fc1 = ColumnParallelLinear(config.hidden_size,
|
||||
config.intermediate_size,
|
||||
bias=True,
|
||||
|
||||
@@ -1,21 +1,12 @@
|
||||
#
|
||||
# Copyright (c) 2025 Baidu, Inc. All Rights Reserved.
|
||||
# Adapted from vllm/model_executor/models/internvl.py
|
||||
# Copyright 2023 The vLLM team.
|
||||
#
|
||||
# This file is a part of the vllm-kunlun project.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
# adapted from https://huggingface.co/OpenGVLab/InternVL2-4B/blob/main/modeling_internvl_chat.py
|
||||
# --------------------------------------------------------
|
||||
# InternVL
|
||||
# Copyright (c) 2023 OpenGVLab
|
||||
# Licensed under The MIT License [see LICENSE for details]
|
||||
# --------------------------------------------------------
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Iterable, Mapping, Sequence
|
||||
from typing import Annotated, Any, Literal, Optional, TypeVar, Union
|
||||
|
||||
@@ -1,9 +1,15 @@
|
||||
#
|
||||
# Copyright (c) 2025 Baidu, Inc. All Rights Reserved.
|
||||
# Adapted from vllm/model_executor/models/llama.py
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
# Adapted from
|
||||
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
|
||||
# Copyright 2023 The vLLM team.
|
||||
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# This file is a part of the vllm-kunlun project.
|
||||
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
|
||||
# and OPT implementations in this library. It has been modified from its
|
||||
# original forms to accommodate minor architectural differences compared
|
||||
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
@@ -38,8 +44,7 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||
from vllm.model_executor.layers.rotary_embedding import get_rope
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead)
|
||||
from vllm_kunlun.ops.vocab_parallel_embedding import VocabParallelEmbedding
|
||||
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
|
||||
from vllm.model_executor.model_loader.weight_utils import (
|
||||
default_weight_loader, maybe_remap_kv_scale_name)
|
||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||
|
||||
@@ -1,9 +1,16 @@
|
||||
#
|
||||
# Copyright (c) 2025 Baidu, Inc. All Rights Reserved.
|
||||
# Adapted from vllm/model_executor/models/qwen2.py
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
# Adapted from
|
||||
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/qwen2/modeling_qwen2.py
|
||||
# Copyright 2024 The Qwen team.
|
||||
# Copyright 2023 The vLLM team.
|
||||
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# This file is a part of the vllm-kunlun project.
|
||||
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
|
||||
# and OPT implementations in this library. It has been modified from its
|
||||
# original forms to accommodate minor architectural differences compared
|
||||
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
@@ -33,7 +40,7 @@ from vllm.config import CacheConfig, VllmConfig
|
||||
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
|
||||
from vllm_kunlun.ops.activation import SiluAndMul
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
|
||||
from vllm_kunlun.ops.linear import (MergedColumnParallelLinear,
|
||||
QKVParallelLinear,
|
||||
RowParallelLinear)
|
||||
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||
@@ -44,7 +51,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
from vllm_kunlun.ops.vocab_parallel_embedding import VocabParallelEmbedding
|
||||
from vllm.model_executor.model_loader.weight_utils import (
|
||||
default_weight_loader, maybe_remap_kv_scale_name)
|
||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||
# from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||
from vllm.sequence import IntermediateTensors
|
||||
|
||||
from vllm.model_executor.models.adapters import as_seq_cls_model
|
||||
@@ -177,7 +184,12 @@ class Qwen2Attention(nn.Module):
|
||||
) -> torch.Tensor:
|
||||
qkv, _ = self.qkv_proj(hidden_states)
|
||||
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
|
||||
q, k = self.rotary_emb(positions, q, k)
|
||||
# INTERNVL_3暂时使用环境变量来控制是否使用原生rotary_embedding
|
||||
# 若要修改,可尝试参考 qwen3.py
|
||||
if os.getenv('INTERNVL_3') == "1":
|
||||
q, k = self.rotary_emb.forward_native(positions, q, k)
|
||||
else:
|
||||
q, k = self.rotary_emb(positions, q, k)
|
||||
attn_output = self.attn(q, k, v)
|
||||
output, _ = self.o_proj(attn_output)
|
||||
return output
|
||||
@@ -295,6 +307,7 @@ class Qwen2Model(nn.Module):
|
||||
))
|
||||
|
||||
self.config = config
|
||||
config = config.get_text_config()
|
||||
self.quant_config = quant_config
|
||||
self.vocab_size = config.vocab_size
|
||||
|
||||
@@ -479,10 +492,10 @@ class Qwen2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
|
||||
def compute_logits(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
sampling_metadata: SamplingMetadata,
|
||||
# sampling_metadata: SamplingMetadata,
|
||||
) -> Optional[torch.Tensor]:
|
||||
logits = self.logits_processor(self.lm_head, hidden_states,
|
||||
sampling_metadata)
|
||||
logits = self.logits_processor(self.lm_head, hidden_states,)
|
||||
# sampling_metadata)
|
||||
return logits
|
||||
|
||||
def load_weights(self, weights: Iterable[tuple[str,
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,9 +1,16 @@
|
||||
#
|
||||
# Copyright (c) 2025 Baidu, Inc. All Rights Reserved.
|
||||
# Adapted from vllm/model_executor/models/qwen2vl.py
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
# Adapted from
|
||||
# https://github.com/huggingface/transformers/blob/19e6e80e10118f855137b90740936c0b11ac397f/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py
|
||||
# Copyright 2024 The Qwen team.
|
||||
# Copyright 2023 The vLLM team.
|
||||
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# This file is a part of the vllm-kunlun project.
|
||||
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
|
||||
# and OPT implementations in this library. It has been modified from its
|
||||
# original forms to accommodate minor architectural differences compared
|
||||
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
@@ -38,7 +45,7 @@ from vllm.config import VllmConfig
|
||||
from vllm.distributed import parallel_state, tensor_model_parallel_all_gather
|
||||
from vllm.distributed import utils as dist_utils
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor import SamplingMetadata
|
||||
# from vllm.model_executor import SamplingMetadata
|
||||
from vllm.model_executor.layers.activation import QuickGELU
|
||||
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
||||
RowParallelLinear)
|
||||
@@ -70,11 +77,12 @@ from vllm.model_executor.models.utils import (AutoWeightsLoader, WeightsMapper,
|
||||
init_vllm_registered_model, maybe_prefix,
|
||||
merge_multimodal_embeddings)
|
||||
from vllm.model_executor.models.vision import get_vit_attn_backend
|
||||
import xspeedgate_ops
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
# For profile run
|
||||
_MAX_FRAMES_PER_VIDEO = 16
|
||||
_MAX_FRAMES_PER_VIDEO = 14
|
||||
|
||||
# === Vision Inputs === #
|
||||
|
||||
@@ -226,13 +234,10 @@ def apply_rotary_emb_torch(x: torch.Tensor,
|
||||
def apply_rotary_pos_emb_vision(t: torch.Tensor,
|
||||
freqs: torch.Tensor) -> torch.Tensor:
|
||||
t_ = t.float()
|
||||
|
||||
|
||||
if freqs.dim() == 3 and freqs.shape[1] == 2:
|
||||
# freqs: (seq_len, 2, head_dim)
|
||||
# Call custom XPU Kernel version
|
||||
import xspeedgate_ops
|
||||
return torch.ops.xspeedgate_ops.rope_vit(t_, freqs, interleaved = False).type_as(t)
|
||||
|
||||
|
||||
cos = freqs.cos()
|
||||
sin = freqs.sin()
|
||||
apply_rotary_emb = apply_rotary_emb_torch
|
||||
@@ -922,10 +927,10 @@ class Qwen2VLProcessingInfo(BaseProcessingInfo):
|
||||
image_processor=None,
|
||||
)
|
||||
|
||||
def _get_max_video_frames(self, max_tokens: int) -> int:
|
||||
def _get_max_video_frames(self, max_tokens: int, start_num_frames: int = 1) -> int:
|
||||
target_width, target_height = self.get_image_size_with_most_features()
|
||||
|
||||
num_frames = 0
|
||||
num_frames = start_num_frames
|
||||
|
||||
while True:
|
||||
next_num_frames = num_frames + 1
|
||||
@@ -947,15 +952,23 @@ class Qwen2VLProcessingInfo(BaseProcessingInfo):
|
||||
self,
|
||||
seq_len: int,
|
||||
mm_counts: Mapping[str, int],
|
||||
max_frames_per_video: int = _MAX_FRAMES_PER_VIDEO,
|
||||
) -> int:
|
||||
max_images = mm_counts.get("image", 0)
|
||||
# max_images = mm_counts.get("image", 0)
|
||||
# max_videos = mm_counts.get("video", 0)
|
||||
|
||||
# max_image_tokens = self.get_max_image_tokens() * max_images
|
||||
# max_total_frames = self._get_max_video_frames(seq_len -
|
||||
# max_image_tokens)
|
||||
# max_frames_per_video = min(max_total_frames // max(max_videos, 1),
|
||||
# _MAX_FRAMES_PER_VIDEO)
|
||||
|
||||
# return max(max_frames_per_video, 1)
|
||||
max_videos = mm_counts.get("video", 0)
|
||||
|
||||
max_image_tokens = self.get_max_image_tokens() * max_images
|
||||
max_total_frames = self._get_max_video_frames(seq_len -
|
||||
max_image_tokens)
|
||||
max_total_frames = self._get_max_video_frames(seq_len)
|
||||
max_frames_per_video = min(max_total_frames // max(max_videos, 1),
|
||||
_MAX_FRAMES_PER_VIDEO)
|
||||
max_frames_per_video)
|
||||
|
||||
return max(max_frames_per_video, 1)
|
||||
|
||||
@@ -1404,10 +1417,10 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal,
|
||||
def compute_logits(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
sampling_metadata: SamplingMetadata,
|
||||
# sampling_metadata: SamplingMetadata,
|
||||
) -> Optional[torch.Tensor]:
|
||||
return self.language_model.compute_logits(hidden_states,
|
||||
sampling_metadata)
|
||||
return self.language_model.compute_logits(hidden_states)
|
||||
# sampling_metadata)
|
||||
|
||||
def load_weights(self, weights: Iterable[tuple[str,
|
||||
torch.Tensor]]) -> set[str]:
|
||||
@@ -1507,4 +1520,4 @@ class Tarsier2ForConditionalGeneration(Qwen2VLForConditionalGeneration):
|
||||
torch.Tensor]]) -> set[str]:
|
||||
|
||||
loader = AutoWeightsLoader(self)
|
||||
return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
|
||||
return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
|
||||
@@ -1,9 +1,14 @@
|
||||
#
|
||||
# Copyright (c) 2025 Baidu, Inc. All Rights Reserved.
|
||||
# Adapted from vllm/model_executor/models/qwen3.py
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
# Copyright 2024 The Qwen team.
|
||||
# Copyright 2023 The vLLM team.
|
||||
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# This file is a part of the vllm-kunlun project.
|
||||
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
|
||||
# and OPT implementations in this library. It has been modified from its
|
||||
# original forms to accommodate minor architectural differences compared
|
||||
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
@@ -18,65 +23,57 @@
|
||||
# limitations under the License.
|
||||
"""Inference-only Qwen3 model compatible with HuggingFace weights."""
|
||||
from collections.abc import Iterable
|
||||
from typing import Optional, Union
|
||||
import xtorch_ops
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
import torch
|
||||
import os
|
||||
from torch import nn
|
||||
from transformers import Qwen3Config
|
||||
|
||||
from vllm.attention import AttentionType, AttentionMetadata
|
||||
from vllm.attention import AttentionType
|
||||
from vllm_kunlun.ops.attention.layer import Attention
|
||||
|
||||
from vllm.compilation.decorators import support_torch_compile
|
||||
from vllm.config import CacheConfig, VllmConfig, get_current_vllm_config
|
||||
from vllm.config import CacheConfig, VllmConfig
|
||||
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
from vllm.model_executor.layers.linear import (QKVParallelLinear,
|
||||
|
||||
from vllm_kunlun.ops.linear import (QKVParallelLinear,
|
||||
RowParallelLinear)
|
||||
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||
from vllm.model_executor.layers.rotary_embedding import get_rope
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
|
||||
from vllm_kunlun.ops.vocab_parallel_embedding import VocabParallelEmbedding
|
||||
from vllm.model_executor.model_loader.weight_utils import (
|
||||
default_weight_loader, maybe_remap_kv_scale_name)
|
||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||
from vllm.sequence import IntermediateTensors
|
||||
|
||||
from vllm import envs
|
||||
|
||||
from vllm.model_executor.models.adapters import as_seq_cls_model
|
||||
from vllm.model_executor.models.interfaces import SupportsLoRA, SupportsPP
|
||||
from vllm.model_executor.models.interfaces import SupportsEagle3, SupportsLoRA, SupportsPP
|
||||
from .qwen2 import Qwen2MLP as Qwen3MLP
|
||||
from .qwen2 import Qwen2Model
|
||||
from vllm.model_executor.models.utils import (AutoWeightsLoader, PPMissingLayer, extract_layer_index,
|
||||
is_pp_missing_parameter,
|
||||
make_empty_intermediate_tensors_factory, make_layers,
|
||||
maybe_prefix)
|
||||
|
||||
from vllm.forward_context import ForwardContext, get_forward_context
|
||||
from vllm.platforms import current_platform
|
||||
from vllm_kunlun.ops.rotary_embedding import Split_Norm_Rope
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class Qwen3Attention(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
hidden_size: int,
|
||||
num_heads: int,
|
||||
num_kv_heads: int,
|
||||
max_position: int = 4096 * 32,
|
||||
head_dim: Optional[int] = None,
|
||||
rms_norm_eps: float = 1e-06,
|
||||
qkv_bias: bool = False,
|
||||
rope_theta: float = 10000,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
rope_scaling: Optional[tuple] = None,
|
||||
prefix: str = "",
|
||||
attn_type: str = AttentionType.DECODER) -> None:
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size: int,
|
||||
num_heads: int,
|
||||
num_kv_heads: int,
|
||||
max_position: int = 4096 * 32,
|
||||
head_dim: Optional[int] = None,
|
||||
rms_norm_eps: float = 1e-06,
|
||||
qkv_bias: bool = False,
|
||||
rope_theta: float = 10000,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
rope_scaling: Optional[tuple] = None,
|
||||
prefix: str = "",
|
||||
attn_type: str = AttentionType.DECODER,
|
||||
dual_chunk_attention_config: Optional[dict[str, Any]] = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.hidden_size = hidden_size
|
||||
tp_size = get_tensor_model_parallel_world_size()
|
||||
@@ -98,10 +95,7 @@ class Qwen3Attention(nn.Module):
|
||||
self.kv_size = self.num_kv_heads * self.head_dim
|
||||
self.scaling = self.head_dim**-0.5
|
||||
self.rope_theta = rope_theta
|
||||
self.max_position = max_position
|
||||
if rope_scaling is not None:
|
||||
scaling_factor = rope_scaling["factor"]
|
||||
self.max_position = int(self.max_position * scaling_factor)
|
||||
self.dual_chunk_attention_config = dual_chunk_attention_config
|
||||
|
||||
self.qkv_proj = QKVParallelLinear(
|
||||
hidden_size,
|
||||
@@ -123,18 +117,25 @@ class Qwen3Attention(nn.Module):
|
||||
self.rotary_emb = get_rope(
|
||||
self.head_dim,
|
||||
rotary_dim=self.head_dim,
|
||||
max_position=self.max_position,
|
||||
max_position=max_position,
|
||||
base=self.rope_theta,
|
||||
rope_scaling=rope_scaling,
|
||||
dual_chunk_attention_config=dual_chunk_attention_config,
|
||||
)
|
||||
self.attn = Attention(
|
||||
self.num_heads,
|
||||
self.head_dim,
|
||||
self.scaling,
|
||||
num_kv_heads=self.num_kv_heads,
|
||||
cache_config=cache_config,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.attn",
|
||||
attn_type=attn_type,
|
||||
**{
|
||||
"layer_idx": extract_layer_index(prefix),
|
||||
"dual_chunk_attention_config": dual_chunk_attention_config,
|
||||
} if dual_chunk_attention_config else {},
|
||||
)
|
||||
self.attn = Attention(self.num_heads,
|
||||
self.head_dim,
|
||||
self.scaling,
|
||||
num_kv_heads=self.num_kv_heads,
|
||||
cache_config=cache_config,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.attn",
|
||||
attn_type=attn_type)
|
||||
self.q_norm = RMSNorm(self.head_dim, eps=rms_norm_eps)
|
||||
self.k_norm = RMSNorm(self.head_dim, eps=rms_norm_eps)
|
||||
|
||||
@@ -142,35 +143,19 @@ class Qwen3Attention(nn.Module):
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
attn_metadata: AttentionMetadata,
|
||||
residual: Optional[torch.Tensor],
|
||||
) -> torch.Tensor:
|
||||
qkv, _ = self.qkv_proj(hidden_states)
|
||||
# TODO: Supports both original Rope and Kunlun Rope fusion operators
|
||||
if os.getenv('FUSED_QK_ROPE_OP') == "1":
|
||||
# Rope fusion operators
|
||||
q, k, v = Split_Norm_Rope(qkv,
|
||||
self.rotary_emb.cos_sin_cache,
|
||||
self.q_norm.weight,
|
||||
self.k_norm.weight,
|
||||
positions,
|
||||
self.max_position,
|
||||
self.num_heads,
|
||||
self.num_kv_heads,
|
||||
self.head_dim,
|
||||
)
|
||||
else:
|
||||
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
|
||||
# Add qk-norm
|
||||
q_by_head = q.view(*q.shape[:-1], q.shape[-1] // self.head_dim,
|
||||
self.head_dim)
|
||||
q_by_head = self.q_norm(q_by_head)
|
||||
q = q_by_head.view(q.shape)
|
||||
k_by_head = k.view(*k.shape[:-1], k.shape[-1] // self.head_dim,
|
||||
self.head_dim)
|
||||
k_by_head = self.k_norm(k_by_head)
|
||||
k = k_by_head.view(k.shape)
|
||||
q, k = self.rotary_emb(positions, q, k)
|
||||
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
|
||||
# Add qk-norm
|
||||
q_by_head = q.view(*q.shape[:-1], q.shape[-1] // self.head_dim,
|
||||
self.head_dim)
|
||||
q_by_head = self.q_norm(q_by_head)
|
||||
q = q_by_head.view(q.shape)
|
||||
k_by_head = k.view(*k.shape[:-1], k.shape[-1] // self.head_dim,
|
||||
self.head_dim)
|
||||
k_by_head = self.k_norm(k_by_head)
|
||||
k = k_by_head.view(k.shape)
|
||||
q, k = self.rotary_emb(positions, q, k)
|
||||
attn_output = self.attn(q, k, v)
|
||||
output, _ = self.o_proj(attn_output)
|
||||
return output
|
||||
@@ -190,6 +175,9 @@ class Qwen3DecoderLayer(nn.Module):
|
||||
# Requires transformers > 4.32.0
|
||||
rope_theta = getattr(config, "rope_theta", 1000000)
|
||||
rope_scaling = getattr(config, "rope_scaling", None)
|
||||
dual_chunk_attention_config = getattr(config,
|
||||
"dual_chunk_attention_config",
|
||||
None)
|
||||
|
||||
# By default, Qwen3 uses causal attention as it is a decoder-only model.
|
||||
# You can override the HF config with `is_causal=False` to enable
|
||||
@@ -214,6 +202,7 @@ class Qwen3DecoderLayer(nn.Module):
|
||||
rope_scaling=rope_scaling,
|
||||
prefix=f"{prefix}.self_attn",
|
||||
attn_type=attn_type,
|
||||
dual_chunk_attention_config=dual_chunk_attention_config,
|
||||
)
|
||||
self.mlp = Qwen3MLP(
|
||||
hidden_size=self.hidden_size,
|
||||
@@ -231,7 +220,6 @@ class Qwen3DecoderLayer(nn.Module):
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
attn_metadata: AttentionMetadata,
|
||||
residual: Optional[torch.Tensor],
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
# Self Attention
|
||||
@@ -244,8 +232,6 @@ class Qwen3DecoderLayer(nn.Module):
|
||||
hidden_states = self.self_attn(
|
||||
positions=positions,
|
||||
hidden_states=hidden_states,
|
||||
attn_metadata=attn_metadata,
|
||||
residual=residual,
|
||||
)
|
||||
|
||||
# Fully Connected
|
||||
@@ -259,6 +245,7 @@ ALL_DECODER_LAYER_TYPES = {
|
||||
"attention": Qwen3DecoderLayer,
|
||||
}
|
||||
|
||||
|
||||
@support_torch_compile(
|
||||
dynamic_arg_dims={
|
||||
"input_ids": 0,
|
||||
@@ -268,189 +255,15 @@ ALL_DECODER_LAYER_TYPES = {
|
||||
"intermediate_tensors": 0,
|
||||
"inputs_embeds": 0,
|
||||
})
|
||||
class Qwen3Model(nn.Module):
|
||||
"""Qwen3Model"""
|
||||
def __init__(self,
|
||||
*,
|
||||
vllm_config: VllmConfig,
|
||||
prefix: str = "",
|
||||
decoder_layer_type: type[nn.Module] = Qwen3DecoderLayer):
|
||||
super().__init__()
|
||||
class Qwen3Model(Qwen2Model):
|
||||
|
||||
config = vllm_config.model_config.hf_config
|
||||
cache_config = vllm_config.cache_config
|
||||
quant_config = vllm_config.quant_config
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
super().__init__(vllm_config=vllm_config,
|
||||
prefix=prefix,
|
||||
decoder_layer_type=Qwen3DecoderLayer)
|
||||
|
||||
# TODO (@robertgshaw2): see if this can be moved out
|
||||
if (cache_config.sliding_window is not None
|
||||
and hasattr(config, "max_window_layers")):
|
||||
assert config.max_window_layers == config.num_hidden_layers, (
|
||||
"Sliding window for some but all layers is not supported. "
|
||||
"This model uses sliding window but `max_window_layers` = {} "
|
||||
"is less than `num_hidden_layers` = {}. Please open an issue "
|
||||
"to discuss this feature.".format(
|
||||
config.max_window_layers,
|
||||
config.num_hidden_layers,
|
||||
))
|
||||
|
||||
self.config = config
|
||||
self.quant_config = quant_config
|
||||
self.vocab_size = config.vocab_size
|
||||
|
||||
if get_pp_group().is_first_rank or (config.tie_word_embeddings
|
||||
and get_pp_group().is_last_rank):
|
||||
self.embed_tokens = VocabParallelEmbedding(
|
||||
config.vocab_size,
|
||||
config.hidden_size,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.embed_tokens",
|
||||
)
|
||||
else:
|
||||
self.embed_tokens = PPMissingLayer()
|
||||
|
||||
# Use the provided decoder layer type or default to Qwen2DecoderLayer
|
||||
decoder_layer_type = decoder_layer_type or Qwen3DecoderLayer
|
||||
self.start_layer, self.end_layer, self.layers = make_layers(
|
||||
config.num_hidden_layers,
|
||||
lambda prefix: decoder_layer_type(config=config,
|
||||
cache_config=cache_config,
|
||||
quant_config=quant_config,
|
||||
prefix=prefix),
|
||||
prefix=f"{prefix}.layers",
|
||||
)
|
||||
|
||||
self.make_empty_intermediate_tensors = (
|
||||
make_empty_intermediate_tensors_factory(
|
||||
["hidden_states", "residual"], config.hidden_size))
|
||||
if get_pp_group().is_last_rank:
|
||||
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
else:
|
||||
self.norm = PPMissingLayer()
|
||||
|
||||
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
"""get_input_embeddings"""
|
||||
return self.embed_tokens(input_ids)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
) -> Union[torch.Tensor, IntermediateTensors]:
|
||||
"""
|
||||
Args:
|
||||
input_ids (torch.Tensor): Input sequence of shape `(batch, seq_len)`.
|
||||
Indices are expected to be in the range `[0, config.vocab_size]`.
|
||||
positions (torch.Tensor): Positional tensor of shape `(batch, seq_len)`.
|
||||
intermediate_tensors (Optional[IntermediateTensors], optional):
|
||||
Intermediate tensors from previous forward pass. Defaults to `None`.
|
||||
inputs_embeds (Optional[torch.Tensor], optional):
|
||||
Optionally, instead of positional embeddings, you can choose to
|
||||
provide your own embedding lookup matrix of shape `(batch, seq_len, emb_dim)`.
|
||||
If None, the model will create one on its own using the input ids.
|
||||
Defaults to `None`.
|
||||
|
||||
Returns:
|
||||
Union[torch.Tensor, IntermediateTensors]:
|
||||
If `intermediate_tensors` is not None, returns a IntermediateTensors object.
|
||||
Otherwise, returns a tensor of shape `(batch, seq_len, hidden_size)` representing
|
||||
the output of the last transformer encoder layer.
|
||||
"""
|
||||
forward_context: ForwardContext = get_forward_context()
|
||||
attn_metadata = forward_context.attn_metadata
|
||||
|
||||
if get_pp_group().is_first_rank:
|
||||
if inputs_embeds is not None:
|
||||
hidden_states = inputs_embeds
|
||||
else:
|
||||
hidden_states = self.get_input_embeddings(input_ids)
|
||||
residual = None
|
||||
else:
|
||||
assert intermediate_tensors is not None
|
||||
hidden_states = intermediate_tensors["hidden_states"]
|
||||
residual = intermediate_tensors["residual"]
|
||||
for i, layer in enumerate(self.layers[self.start_layer:self.end_layer], start=self.start_layer):
|
||||
hidden_states, residual = layer(
|
||||
positions,
|
||||
hidden_states,
|
||||
attn_metadata,
|
||||
residual,
|
||||
)
|
||||
if not get_pp_group().is_last_rank:
|
||||
return IntermediateTensors({
|
||||
"hidden_states": hidden_states,
|
||||
"residual": residual
|
||||
})
|
||||
|
||||
hidden_states, _ = self.norm(hidden_states, residual)
|
||||
|
||||
return hidden_states
|
||||
|
||||
def load_weights(self, weights: Iterable[tuple[str,
|
||||
torch.Tensor]]) -> set[str]:
|
||||
"""Load model weights.
|
||||
Args:
|
||||
weights (Iterable[tuple[str, torch.Tensor]]): An iterator containing weight names and their corresponding values.
|
||||
Returns (set[str]):
|
||||
A set of already loaded weight names.
|
||||
Exceptions:
|
||||
None.
|
||||
"""
|
||||
stacked_params_mapping = [
|
||||
# (param_name, shard_name, shard_id)
|
||||
("qkv_proj", "q_proj", "q"),
|
||||
("qkv_proj", "k_proj", "k"),
|
||||
("qkv_proj", "v_proj", "v"),
|
||||
("gate_up_proj", "gate_proj", 0),
|
||||
("gate_up_proj", "up_proj", 1),
|
||||
]
|
||||
params_dict = dict(self.named_parameters(remove_duplicate=False))
|
||||
loaded_params: set[str] = set()
|
||||
for name, loaded_weight in weights:
|
||||
if "rotary_emb.inv_freq" in name:
|
||||
continue
|
||||
if (self.quant_config is not None and
|
||||
(scale_name := self.quant_config.get_cache_scale(name))):
|
||||
# Loading kv cache quantization scales
|
||||
param = params_dict[scale_name]
|
||||
weight_loader = getattr(param, "weight_loader",
|
||||
default_weight_loader)
|
||||
loaded_weight = (loaded_weight if loaded_weight.dim() == 0 else
|
||||
loaded_weight[0])
|
||||
weight_loader(param, loaded_weight)
|
||||
loaded_params.add(scale_name)
|
||||
continue
|
||||
for (param_name, weight_name, shard_id) in stacked_params_mapping:
|
||||
if weight_name not in name:
|
||||
continue
|
||||
name = name.replace(weight_name, param_name)
|
||||
# Skip loading extra bias for GPTQ models.
|
||||
if name.endswith(".bias") and name not in params_dict:
|
||||
continue
|
||||
if is_pp_missing_parameter(name, self):
|
||||
continue
|
||||
param = params_dict[name]
|
||||
weight_loader = param.weight_loader
|
||||
weight_loader(param, loaded_weight, shard_id)
|
||||
break
|
||||
else:
|
||||
# Skip loading extra bias for GPTQ models.
|
||||
if name.endswith(".bias") and name not in params_dict:
|
||||
continue
|
||||
# Remapping the name of FP8 kv-scale.
|
||||
name = maybe_remap_kv_scale_name(name, params_dict)
|
||||
if name is None:
|
||||
continue
|
||||
if is_pp_missing_parameter(name, self):
|
||||
continue
|
||||
param = params_dict[name]
|
||||
weight_loader = getattr(param, "weight_loader", default_weight_loader)
|
||||
weight_loader(param, loaded_weight)
|
||||
loaded_params.add(name)
|
||||
return loaded_params
|
||||
|
||||
class Qwen3ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
|
||||
class Qwen3ForCausalLM(nn.Module, SupportsLoRA, SupportsPP, SupportsEagle3):
|
||||
packed_modules_mapping = {
|
||||
"qkv_proj": [
|
||||
"q_proj",
|
||||
@@ -493,6 +306,13 @@ class Qwen3ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
|
||||
self.make_empty_intermediate_tensors = (
|
||||
self.model.make_empty_intermediate_tensors)
|
||||
|
||||
def set_aux_hidden_state_layers(self, layers: tuple[int, ...]) -> None:
|
||||
self.model.aux_hidden_state_layers = layers
|
||||
|
||||
def get_eagle3_aux_hidden_state_layers(self) -> tuple[int, ...]:
|
||||
num_layers = len(self.model.layers)
|
||||
return (2, num_layers // 2, num_layers - 3)
|
||||
|
||||
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.model.get_input_embeddings(input_ids)
|
||||
|
||||
@@ -502,7 +322,6 @@ class Qwen3ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
|
||||
positions: torch.Tensor,
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
kv_caches: list[torch.Tensor] = None
|
||||
) -> Union[torch.Tensor, IntermediateTensors]:
|
||||
hidden_states = self.model(input_ids, positions, intermediate_tensors,
|
||||
inputs_embeds)
|
||||
@@ -511,10 +330,8 @@ class Qwen3ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
|
||||
def compute_logits(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
sampling_metadata: SamplingMetadata,
|
||||
) -> Optional[torch.Tensor]:
|
||||
logits = self.logits_processor(self.lm_head, hidden_states,
|
||||
sampling_metadata)
|
||||
logits = self.logits_processor(self.lm_head, hidden_states)
|
||||
return logits
|
||||
|
||||
def load_weights(self, weights: Iterable[tuple[str,
|
||||
@@ -525,6 +342,3 @@ class Qwen3ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
|
||||
if self.config.tie_word_embeddings else None),
|
||||
)
|
||||
return loader.load_weights(weights)
|
||||
|
||||
|
||||
Qwen3ForSequenceClassification = as_seq_cls_model(Qwen3ForCausalLM)
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
1335
vllm_kunlun/models/qwen3_next.py
Normal file
1335
vllm_kunlun/models/qwen3_next.py
Normal file
File diff suppressed because it is too large
Load Diff
1780
vllm_kunlun/models/qwen3_omni_moe_thinker.py
Normal file
1780
vllm_kunlun/models/qwen3_omni_moe_thinker.py
Normal file
File diff suppressed because it is too large
Load Diff
1636
vllm_kunlun/models/qwen3_vl.py
Normal file
1636
vllm_kunlun/models/qwen3_vl.py
Normal file
File diff suppressed because it is too large
Load Diff
358
vllm_kunlun/models/qwen3_vl_moe.py
Normal file
358
vllm_kunlun/models/qwen3_vl_moe.py
Normal file
@@ -0,0 +1,358 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
# Copyright 2025 The vLLM team.
|
||||
# Copyright 2025 The Qwen Team.
|
||||
# Copyright 2025 The HuggingFace Inc. team.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
|
||||
# and OPT implementations in this library. It has been modified from its
|
||||
# original forms to accommodate minor architectural differences compared
|
||||
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Inference-only Qwen3-VL-MoE model compatible with HuggingFace weights."""
|
||||
import typing
|
||||
from collections.abc import Iterable
|
||||
from typing import Callable, Optional, Union
|
||||
|
||||
import torch
|
||||
from transformers.models.qwen3_vl_moe.configuration_qwen3_vl_moe import (
|
||||
Qwen3VLMoeConfig)
|
||||
|
||||
from vllm.compilation.decorators import support_torch_compile
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.distributed import get_pp_group
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
|
||||
from vllm.model_executor.model_loader.weight_utils import (
|
||||
default_weight_loader, maybe_remap_kv_scale_name)
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY
|
||||
from vllm.sequence import IntermediateTensors
|
||||
|
||||
from .qwen3_moe import Qwen3MoeForCausalLM, Qwen3MoeModel
|
||||
from .qwen3_vl import (Qwen3_VisionTransformer, Qwen3VLDummyInputsBuilder,
|
||||
Qwen3VLForConditionalGeneration,
|
||||
Qwen3VLMultiModalProcessor, Qwen3VLProcessingInfo)
|
||||
from vllm.model_executor.models.utils import is_pp_missing_parameter, maybe_prefix
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class Qwen3VLMoeProcessingInfo(Qwen3VLProcessingInfo):
|
||||
|
||||
def get_hf_config(self):
|
||||
return self.ctx.get_hf_config(Qwen3VLMoeConfig)
|
||||
|
||||
|
||||
@support_torch_compile(
|
||||
dynamic_arg_dims={
|
||||
"input_ids": 0,
|
||||
# positions is of shape (3, seq_len) if mrope is enabled for qwen2-vl,
|
||||
# otherwise (seq_len, ).
|
||||
"positions": -1,
|
||||
"intermediate_tensors": 0,
|
||||
"inputs_embeds": 0,
|
||||
# the same shape as input_embeds
|
||||
"deepstack_input_embeds": 0
|
||||
})
|
||||
class Qwen3MoeLLMModel(Qwen3MoeModel):
|
||||
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
super().__init__(vllm_config=vllm_config, prefix=prefix)
|
||||
if not get_pp_group().is_first_rank:
|
||||
assert self.start_layer >= len(
|
||||
vllm_config.model_config.hf_config.vision_config.
|
||||
deepstack_visual_indexes), (
|
||||
"start_layer should be greater than or equal to "
|
||||
"len(deepstack_visual_indexes)")
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
deepstack_input_embeds: Optional[IntermediateTensors] = None,
|
||||
) -> Union[torch.Tensor, IntermediateTensors]:
|
||||
if get_pp_group().is_first_rank:
|
||||
if inputs_embeds is not None:
|
||||
hidden_states = inputs_embeds
|
||||
else:
|
||||
hidden_states = self.get_input_embeddings(input_ids)
|
||||
residual = None
|
||||
else:
|
||||
assert intermediate_tensors is not None
|
||||
hidden_states = intermediate_tensors["hidden_states"]
|
||||
residual = intermediate_tensors["residual"]
|
||||
for layer_idx, layer in enumerate(
|
||||
self.layers[self.start_layer:self.end_layer]):
|
||||
layer_idx = layer_idx + self.start_layer
|
||||
|
||||
hidden_states, residual = layer(
|
||||
positions,
|
||||
hidden_states,
|
||||
residual,
|
||||
)
|
||||
|
||||
if deepstack_input_embeds is not None and \
|
||||
layer_idx in range(0, len(deepstack_input_embeds)):
|
||||
hidden_states = hidden_states + deepstack_input_embeds[
|
||||
f"deepstack_input_embeds_{layer_idx}"]
|
||||
|
||||
if not get_pp_group().is_last_rank:
|
||||
return IntermediateTensors({
|
||||
"hidden_states": hidden_states,
|
||||
"residual": residual
|
||||
})
|
||||
hidden_states, _ = self.norm(hidden_states, residual)
|
||||
return hidden_states
|
||||
|
||||
def load_fused_expert_weights(self, name: str, params_dict: dict,
|
||||
loaded_weight: torch.Tensor, shard_id: str,
|
||||
num_experts: int) -> bool:
|
||||
param = params_dict[name]
|
||||
weight_loader = typing.cast(Callable[..., bool], param.weight_loader)
|
||||
loaded_local_expert = False
|
||||
for expert_id in range(num_experts):
|
||||
curr_expert_weight = loaded_weight[expert_id]
|
||||
success = weight_loader(param,
|
||||
curr_expert_weight,
|
||||
name,
|
||||
shard_id,
|
||||
expert_id,
|
||||
return_success=True)
|
||||
if success:
|
||||
loaded_local_expert = True
|
||||
|
||||
return loaded_local_expert
|
||||
|
||||
def load_weights(self, weights: Iterable[tuple[str,
|
||||
torch.Tensor]]) -> set[str]:
|
||||
stacked_params_mapping = [
|
||||
# (param_name, shard_name, shard_id)
|
||||
("qkv_proj", "q_proj", "q"),
|
||||
("qkv_proj", "k_proj", "k"),
|
||||
("qkv_proj", "v_proj", "v"),
|
||||
("gate_up_proj", "gate_proj", 0),
|
||||
("gate_up_proj", "up_proj", 1),
|
||||
]
|
||||
# Skip loading extra parameters for GPTQ/modelopt models.
|
||||
ignore_suffixes = (".bias", "_bias", ".k_scale", "_k_scale",
|
||||
".v_scale", "_v_scale", ".weight_scale",
|
||||
"_weight_scale", ".input_scale", "_input_scale")
|
||||
params_dict = dict(self.named_parameters())
|
||||
loaded_params: set[str] = set()
|
||||
expert_params_mapping = self.get_expert_mapping()
|
||||
is_fused_expert = False
|
||||
fused_expert_params_mapping = [
|
||||
("experts.w13_weight", "experts.gate_up_proj", 0, "w1"),
|
||||
("experts.w2_weight", "experts.down_proj", 0, "w2"),
|
||||
]
|
||||
num_experts = self.config.get_text_config().num_experts
|
||||
for name, loaded_weight in weights:
|
||||
for (param_name, weight_name, shard_id) in stacked_params_mapping:
|
||||
if ("experts.gate_up_proj" in name
|
||||
or "experts.down_proj" in name):
|
||||
is_fused_expert = True
|
||||
expert_params_mapping = fused_expert_params_mapping
|
||||
|
||||
# Skip non-stacked layers and experts (experts handled below).
|
||||
if weight_name not in name:
|
||||
continue
|
||||
# We have mlp.experts[0].gate_proj in the checkpoint.
|
||||
# Since we handle the experts below in expert_params_mapping,
|
||||
# we need to skip here BEFORE we update the name, otherwise
|
||||
# name will be updated to mlp.experts[0].gate_up_proj, which
|
||||
# will then be updated below in expert_params_mapping
|
||||
# for mlp.experts[0].gate_gate_up_proj, which breaks load.
|
||||
if "mlp.experts" in name:
|
||||
continue
|
||||
name = name.replace(weight_name, param_name)
|
||||
# Skip loading extra parameters for GPTQ/modelopt models.
|
||||
if name.endswith(ignore_suffixes) and name not in params_dict:
|
||||
continue
|
||||
# Skip layers on other devices.
|
||||
if is_pp_missing_parameter(name, self):
|
||||
continue
|
||||
if name.endswith("scale"):
|
||||
# Remapping the name of FP8 kv-scale.
|
||||
name = maybe_remap_kv_scale_name(name, params_dict)
|
||||
if name is None:
|
||||
continue
|
||||
if name not in params_dict:
|
||||
continue
|
||||
param = params_dict[name]
|
||||
weight_loader = getattr(param, "weight_loader",
|
||||
default_weight_loader)
|
||||
if weight_loader == default_weight_loader:
|
||||
weight_loader(param, loaded_weight)
|
||||
else:
|
||||
weight_loader(param, loaded_weight, shard_id)
|
||||
break
|
||||
else:
|
||||
is_expert_weight = False
|
||||
for mapping in expert_params_mapping:
|
||||
param_name, weight_name, expert_id, shard_id = mapping
|
||||
if weight_name not in name:
|
||||
continue
|
||||
# Anyway, this is an expert weight and should not be
|
||||
# attempted to load as other weights later
|
||||
is_expert_weight = True
|
||||
name_mapped = name.replace(weight_name, param_name)
|
||||
if is_pp_missing_parameter(name_mapped, self):
|
||||
continue
|
||||
if is_fused_expert:
|
||||
loaded_weight = loaded_weight.transpose(-1,
|
||||
-2) # no bias
|
||||
if "experts.gate_up_proj" in name:
|
||||
loaded_weight = loaded_weight.chunk(2, dim=-2)
|
||||
success_w1 = self.load_fused_expert_weights(
|
||||
name_mapped, params_dict, loaded_weight[0],
|
||||
"w1", num_experts)
|
||||
success_w3 = self.load_fused_expert_weights(
|
||||
name_mapped, params_dict, loaded_weight[1],
|
||||
"w3", num_experts)
|
||||
success = success_w1 and success_w3
|
||||
else:
|
||||
# down_proj
|
||||
success = self.load_fused_expert_weights(
|
||||
name_mapped, params_dict, loaded_weight,
|
||||
shard_id, num_experts)
|
||||
else:
|
||||
# Skip loading extra parameters for GPTQ/modelopt models
|
||||
if name_mapped.endswith(
|
||||
ignore_suffixes
|
||||
) and name_mapped not in params_dict:
|
||||
continue
|
||||
param = params_dict[name_mapped]
|
||||
# We should ask the weight loader to return success or
|
||||
# not here since otherwise we may skip experts with
|
||||
# other available replicas.
|
||||
weight_loader = typing.cast(Callable[..., bool],
|
||||
param.weight_loader)
|
||||
success = weight_loader(param,
|
||||
loaded_weight,
|
||||
name_mapped,
|
||||
shard_id=shard_id,
|
||||
expert_id=expert_id,
|
||||
return_success=True)
|
||||
if success:
|
||||
name = name_mapped
|
||||
break
|
||||
else:
|
||||
if is_expert_weight:
|
||||
# We've checked that this is an expert weight
|
||||
# However it's not mapped locally to this rank
|
||||
# So we simply skip it
|
||||
continue
|
||||
# Skip loading extra parameters for GPTQ/modelopt models.
|
||||
if name.endswith(
|
||||
ignore_suffixes) and name not in params_dict:
|
||||
continue
|
||||
# Skip layers on other devices.
|
||||
if is_pp_missing_parameter(name, self):
|
||||
continue
|
||||
# Remapping the name of FP8 kv-scale.
|
||||
if name.endswith("kv_scale"):
|
||||
remapped_kv_scale_name = name.replace(
|
||||
".kv_scale", ".attn.kv_scale")
|
||||
if remapped_kv_scale_name not in params_dict:
|
||||
logger.warning_once(
|
||||
"Found kv scale in the checkpoint (e.g. %s), but not found the expected name in the model (e.g. %s). kv-scale is not loaded.", # noqa: E501
|
||||
name,
|
||||
remapped_kv_scale_name,
|
||||
)
|
||||
continue
|
||||
else:
|
||||
name = remapped_kv_scale_name
|
||||
param = params_dict[name]
|
||||
weight_loader = getattr(param, "weight_loader",
|
||||
default_weight_loader)
|
||||
weight_loader(param, loaded_weight)
|
||||
loaded_params.add(name)
|
||||
return loaded_params
|
||||
|
||||
|
||||
class Qwen3MoeLLMForCausalLM(Qwen3MoeForCausalLM):
|
||||
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
super(Qwen3MoeForCausalLM, self).__init__()
|
||||
self.config = vllm_config.model_config.hf_config.text_config
|
||||
self.quant_config = vllm_config.quant_config
|
||||
self.model = Qwen3MoeLLMModel(vllm_config=vllm_config,
|
||||
prefix=maybe_prefix(prefix, "model"))
|
||||
self.lm_head = ParallelLMHead(self.config.vocab_size,
|
||||
self.config.hidden_size,
|
||||
quant_config=self.quant_config)
|
||||
if self.config.tie_word_embeddings:
|
||||
self.lm_head.weight = self.model.embed_tokens.weight
|
||||
self.logits_processor = LogitsProcessor(self.config.vocab_size)
|
||||
self.make_empty_intermediate_tensors = (
|
||||
self.model.make_empty_intermediate_tensors)
|
||||
|
||||
|
||||
@MULTIMODAL_REGISTRY.register_processor(Qwen3VLMultiModalProcessor,
|
||||
info=Qwen3VLMoeProcessingInfo,
|
||||
dummy_inputs=Qwen3VLDummyInputsBuilder)
|
||||
class Qwen3VLMoeForConditionalGeneration(Qwen3VLForConditionalGeneration):
|
||||
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
super(Qwen3VLForConditionalGeneration, self).__init__()
|
||||
config: Qwen3VLMoeConfig = vllm_config.model_config.hf_config
|
||||
quant_config = vllm_config.quant_config
|
||||
multimodal_config = vllm_config.model_config.multimodal_config
|
||||
|
||||
self.config = config
|
||||
self.multimodal_config = multimodal_config
|
||||
self.use_data_parallel = multimodal_config.mm_encoder_tp_mode == "data"
|
||||
|
||||
if not multimodal_config.get_limit_per_prompt("image") and \
|
||||
not multimodal_config.get_limit_per_prompt("video"):
|
||||
self.visual = None
|
||||
else:
|
||||
self.visual = Qwen3_VisionTransformer(
|
||||
config.vision_config,
|
||||
norm_eps=getattr(config, "rms_norm_eps", 1e-6),
|
||||
quant_config=quant_config,
|
||||
prefix=maybe_prefix(prefix, "visual"),
|
||||
use_data_parallel=self.use_data_parallel,
|
||||
)
|
||||
|
||||
self.language_model = Qwen3MoeLLMForCausalLM(vllm_config=vllm_config,
|
||||
prefix=maybe_prefix(
|
||||
prefix,
|
||||
"language_model"))
|
||||
|
||||
self.make_empty_intermediate_tensors = (
|
||||
self.language_model.make_empty_intermediate_tensors)
|
||||
|
||||
self.use_deepstack = hasattr(config.vision_config,
|
||||
'deepstack_visual_indexes')
|
||||
self.deepstack_num_level = len(
|
||||
config.vision_config.deepstack_visual_indexes
|
||||
) if self.use_deepstack else 0
|
||||
# register buffer for deepstack
|
||||
if self.use_deepstack and self.visual is not None:
|
||||
self.deepstack_input_embeds = [
|
||||
torch.zeros(
|
||||
vllm_config.scheduler_config.max_num_batched_tokens,
|
||||
config.text_config.hidden_size)
|
||||
for _ in range(self.deepstack_num_level)
|
||||
]
|
||||
else:
|
||||
self.deepstack_input_embeds = None
|
||||
self.visual_dim = config.vision_config.out_hidden_size
|
||||
self.multiscale_dim = self.visual_dim * self.deepstack_num_level
|
||||
500
vllm_kunlun/models/seed_oss.py
Normal file
500
vllm_kunlun/models/seed_oss.py
Normal file
@@ -0,0 +1,500 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
# Copyright 2025 The Seed team.
|
||||
# Copyright 2023 The vLLM team.
|
||||
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
|
||||
# and OPT implementations in this library. It has been modified from its
|
||||
# original forms to accommodate minor architectural differences compared
|
||||
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Inference-only SeedOss model compatible with HuggingFace weights."""
|
||||
|
||||
from collections.abc import Iterable
|
||||
from itertools import islice
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from transformers import PretrainedConfig as SeedOssConfig
|
||||
|
||||
from vllm.attention import AttentionType
|
||||
from vllm_kunlun.ops.attention.layer import Attention
|
||||
from vllm.compilation.decorators import support_torch_compile
|
||||
from vllm.config import CacheConfig, VllmConfig
|
||||
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
|
||||
from vllm.logger import init_logger
|
||||
from vllm_kunlun.ops.activation import SiluAndMul
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
from vllm.model_executor.layers.linear import (
|
||||
MergedColumnParallelLinear,
|
||||
QKVParallelLinear,
|
||||
RowParallelLinear,
|
||||
)
|
||||
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||
from vllm.model_executor.layers.rotary_embedding import get_rope
|
||||
from vllm_kunlun.ops.vocab_parallel_embedding import (
|
||||
ParallelLMHead,
|
||||
VocabParallelEmbedding,
|
||||
)
|
||||
from vllm.model_executor.model_loader.weight_utils import (
|
||||
default_weight_loader,
|
||||
maybe_remap_kv_scale_name,
|
||||
)
|
||||
from vllm.sequence import IntermediateTensors
|
||||
|
||||
from vllm.model_executor.models.interfaces import SupportsLoRA, SupportsPP
|
||||
from vllm.model_executor.models.utils import (
|
||||
AutoWeightsLoader,
|
||||
PPMissingLayer,
|
||||
is_pp_missing_parameter,
|
||||
make_empty_intermediate_tensors_factory,
|
||||
make_layers,
|
||||
maybe_prefix,
|
||||
)
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class SeedOssMLP(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size: int,
|
||||
intermediate_size: int,
|
||||
hidden_act: str,
|
||||
quant_config: QuantizationConfig | None = None,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.gate_up_proj = MergedColumnParallelLinear(
|
||||
hidden_size,
|
||||
[intermediate_size] * 2,
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.gate_up_proj",
|
||||
)
|
||||
self.down_proj = RowParallelLinear(
|
||||
intermediate_size,
|
||||
hidden_size,
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.down_proj",
|
||||
)
|
||||
if hidden_act != "silu":
|
||||
raise ValueError(
|
||||
f"Unsupported activation: {hidden_act}. Only silu is supported for now."
|
||||
)
|
||||
self.act_fn = SiluAndMul()
|
||||
|
||||
def forward(self, x):
|
||||
gate_up, _ = self.gate_up_proj(x)
|
||||
x = self.act_fn(gate_up)
|
||||
x, _ = self.down_proj(x)
|
||||
return x
|
||||
|
||||
|
||||
class SeedOssAttention(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size: int,
|
||||
num_heads: int,
|
||||
num_kv_heads: int,
|
||||
head_dim: int,
|
||||
max_position: int = 4096 * 32,
|
||||
rope_theta: float = 10000,
|
||||
cache_config: CacheConfig | None = None,
|
||||
quant_config: QuantizationConfig | None = None,
|
||||
rope_scaling: tuple | None = None,
|
||||
prefix: str = "",
|
||||
attn_type: str = AttentionType.DECODER,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.hidden_size = hidden_size
|
||||
tp_size = get_tensor_model_parallel_world_size()
|
||||
self.total_num_heads = num_heads
|
||||
assert self.total_num_heads % tp_size == 0
|
||||
self.num_heads = self.total_num_heads // tp_size
|
||||
self.total_num_kv_heads = num_kv_heads
|
||||
self.head_dim = head_dim
|
||||
if self.total_num_kv_heads >= tp_size:
|
||||
# Number of KV heads is greater than TP size, so we partition
|
||||
# the KV heads across multiple tensor parallel GPUs.
|
||||
assert self.total_num_kv_heads % tp_size == 0
|
||||
else:
|
||||
# Number of KV heads is less than TP size, so we replicate
|
||||
# the KV heads across multiple tensor parallel GPUs.
|
||||
assert tp_size % self.total_num_kv_heads == 0
|
||||
self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
|
||||
self.q_size = self.num_heads * self.head_dim
|
||||
self.kv_size = self.num_kv_heads * self.head_dim
|
||||
self.scaling = self.head_dim**-0.5
|
||||
self.rope_theta = rope_theta
|
||||
|
||||
self.qkv_proj = QKVParallelLinear(
|
||||
hidden_size,
|
||||
self.head_dim,
|
||||
self.total_num_heads,
|
||||
self.total_num_kv_heads,
|
||||
bias=True,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.qkv_proj",
|
||||
)
|
||||
self.o_proj = RowParallelLinear(
|
||||
self.total_num_heads * self.head_dim,
|
||||
hidden_size,
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.o_proj",
|
||||
)
|
||||
|
||||
self.rotary_emb = get_rope(
|
||||
self.head_dim,
|
||||
rotary_dim=self.head_dim,
|
||||
max_position=max_position,
|
||||
base=self.rope_theta,
|
||||
rope_scaling=rope_scaling,
|
||||
)
|
||||
self.attn = Attention(
|
||||
self.num_heads,
|
||||
self.head_dim,
|
||||
self.scaling,
|
||||
num_kv_heads=self.num_kv_heads,
|
||||
cache_config=cache_config,
|
||||
quant_config=quant_config,
|
||||
attn_type=attn_type,
|
||||
prefix=f"{prefix}.attn",
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
qkv, _ = self.qkv_proj(hidden_states)
|
||||
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
|
||||
q, k = self.rotary_emb(positions, q, k)
|
||||
attn_output = self.attn(q, k, v)
|
||||
output, _ = self.o_proj(attn_output)
|
||||
return output
|
||||
|
||||
|
||||
class SeedOssDecoderLayer(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
config: SeedOssConfig,
|
||||
cache_config: CacheConfig | None = None,
|
||||
quant_config: QuantizationConfig | None = None,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.hidden_size = config.hidden_size
|
||||
# Requires transformers > 4.32.0
|
||||
rope_theta = getattr(config, "rope_theta", 1000000)
|
||||
rope_scaling = getattr(config, "rope_scaling", None)
|
||||
|
||||
# By default, SeedOss uses causal attention as it is a
|
||||
# decoder-only model.
|
||||
# You can override the HF config with `is_causal=False` to enable
|
||||
# bidirectional attention, which is used in some embedding models
|
||||
if getattr(config, "is_causal", True):
|
||||
attn_type = AttentionType.DECODER
|
||||
else:
|
||||
attn_type = AttentionType.ENCODER_ONLY
|
||||
|
||||
self.self_attn = SeedOssAttention(
|
||||
hidden_size=self.hidden_size,
|
||||
num_heads=config.num_attention_heads,
|
||||
max_position=config.max_position_embeddings,
|
||||
num_kv_heads=config.num_key_value_heads,
|
||||
head_dim=config.head_dim,
|
||||
rope_theta=rope_theta,
|
||||
cache_config=cache_config,
|
||||
quant_config=quant_config,
|
||||
rope_scaling=rope_scaling,
|
||||
prefix=f"{prefix}.self_attn",
|
||||
attn_type=attn_type,
|
||||
)
|
||||
self.mlp = SeedOssMLP(
|
||||
hidden_size=self.hidden_size,
|
||||
intermediate_size=config.intermediate_size,
|
||||
hidden_act=config.hidden_act,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.mlp",
|
||||
)
|
||||
self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
self.post_attention_layernorm = RMSNorm(
|
||||
config.hidden_size, eps=config.rms_norm_eps
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
residual: torch.Tensor | None,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
# Self Attention
|
||||
if residual is None:
|
||||
residual = hidden_states
|
||||
hidden_states = self.input_layernorm(hidden_states)
|
||||
else:
|
||||
hidden_states, residual = self.input_layernorm(hidden_states, residual)
|
||||
hidden_states = self.self_attn(
|
||||
positions=positions,
|
||||
hidden_states=hidden_states,
|
||||
)
|
||||
|
||||
# Fully Connected
|
||||
hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
|
||||
hidden_states = self.mlp(hidden_states)
|
||||
return hidden_states, residual
|
||||
|
||||
|
||||
@support_torch_compile(
|
||||
dynamic_arg_dims={
|
||||
"input_ids": 0,
|
||||
"positions": -1,
|
||||
"intermediate_tensors": 0,
|
||||
"inputs_embeds": 0,
|
||||
}
|
||||
)
|
||||
class SeedOssModel(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
vllm_config: VllmConfig,
|
||||
prefix: str = "",
|
||||
decoder_layer_type: type[nn.Module] = SeedOssDecoderLayer,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
config = vllm_config.model_config.hf_config
|
||||
cache_config = vllm_config.cache_config
|
||||
quant_config = vllm_config.quant_config
|
||||
|
||||
# TODO (@robertgshaw2): see if this can be moved out
|
||||
if cache_config.sliding_window is not None and hasattr(
|
||||
config, "max_window_layers"
|
||||
):
|
||||
assert config.max_window_layers == config.num_hidden_layers, (
|
||||
"Sliding window for some but all layers is not supported. "
|
||||
"This model uses sliding window but `max_window_layers` = {} "
|
||||
"is less than `num_hidden_layers` = {}. Please open an issue "
|
||||
"to discuss this feature.".format(
|
||||
config.max_window_layers,
|
||||
config.num_hidden_layers,
|
||||
)
|
||||
)
|
||||
|
||||
self.config = config
|
||||
self.quant_config = quant_config
|
||||
self.vocab_size = config.vocab_size
|
||||
|
||||
if get_pp_group().is_first_rank or (
|
||||
config.tie_word_embeddings and get_pp_group().is_last_rank
|
||||
):
|
||||
self.embed_tokens = VocabParallelEmbedding(
|
||||
config.vocab_size,
|
||||
config.hidden_size,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.embed_tokens",
|
||||
)
|
||||
else:
|
||||
self.embed_tokens = PPMissingLayer()
|
||||
|
||||
# Use the provided decoder layer type or default to SeedDecoderLayer
|
||||
decoder_layer_type = decoder_layer_type or SeedOssDecoderLayer
|
||||
self.start_layer, self.end_layer, self.layers = make_layers(
|
||||
config.num_hidden_layers,
|
||||
lambda prefix: decoder_layer_type(
|
||||
config=config,
|
||||
cache_config=cache_config,
|
||||
quant_config=quant_config,
|
||||
prefix=prefix,
|
||||
),
|
||||
prefix=f"{prefix}.layers",
|
||||
)
|
||||
|
||||
self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory(
|
||||
["hidden_states", "residual"], config.hidden_size
|
||||
)
|
||||
if get_pp_group().is_last_rank:
|
||||
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
else:
|
||||
self.norm = PPMissingLayer()
|
||||
|
||||
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.embed_tokens(input_ids)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
intermediate_tensors: IntermediateTensors | None = None,
|
||||
inputs_embeds: torch.Tensor | None = None,
|
||||
) -> torch.Tensor | IntermediateTensors:
|
||||
if get_pp_group().is_first_rank:
|
||||
if inputs_embeds is not None:
|
||||
hidden_states = inputs_embeds
|
||||
else:
|
||||
hidden_states = self.get_input_embeddings(input_ids)
|
||||
residual = None
|
||||
else:
|
||||
assert intermediate_tensors is not None
|
||||
hidden_states = intermediate_tensors["hidden_states"]
|
||||
residual = intermediate_tensors["residual"]
|
||||
for layer in islice(self.layers, self.start_layer, self.end_layer):
|
||||
hidden_states, residual = layer(
|
||||
positions,
|
||||
hidden_states,
|
||||
residual,
|
||||
)
|
||||
if not get_pp_group().is_last_rank:
|
||||
return IntermediateTensors(
|
||||
{"hidden_states": hidden_states, "residual": residual}
|
||||
)
|
||||
hidden_states, _ = self.norm(hidden_states, residual)
|
||||
return hidden_states
|
||||
|
||||
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
|
||||
stacked_params_mapping = [
|
||||
# (param_name, shard_name, shard_id)
|
||||
("qkv_proj", "q_proj", "q"),
|
||||
("qkv_proj", "k_proj", "k"),
|
||||
("qkv_proj", "v_proj", "v"),
|
||||
("gate_up_proj", "gate_proj", 0),
|
||||
("gate_up_proj", "up_proj", 1),
|
||||
]
|
||||
params_dict = dict(self.named_parameters(remove_duplicate=False))
|
||||
loaded_params: set[str] = set()
|
||||
for name, loaded_weight in weights:
|
||||
if "rotary_emb.inv_freq" in name:
|
||||
continue
|
||||
if self.quant_config is not None and (
|
||||
scale_name := self.quant_config.get_cache_scale(name)
|
||||
):
|
||||
# Loading kv cache quantization scales
|
||||
param = params_dict[scale_name]
|
||||
weight_loader = getattr(param, "weight_loader", default_weight_loader)
|
||||
loaded_weight = (
|
||||
loaded_weight if loaded_weight.dim() == 0 else loaded_weight[0]
|
||||
)
|
||||
weight_loader(param, loaded_weight)
|
||||
loaded_params.add(scale_name)
|
||||
continue
|
||||
for param_name, weight_name, shard_id in stacked_params_mapping:
|
||||
if weight_name not in name:
|
||||
continue
|
||||
name = name.replace(weight_name, param_name)
|
||||
# Skip loading extra bias for GPTQ models.
|
||||
if name.endswith(".bias") and name not in params_dict:
|
||||
continue
|
||||
if is_pp_missing_parameter(name, self):
|
||||
continue
|
||||
param = params_dict[name]
|
||||
weight_loader = param.weight_loader
|
||||
weight_loader(param, loaded_weight, shard_id)
|
||||
break
|
||||
else:
|
||||
# Skip loading extra bias for GPTQ models.
|
||||
if name.endswith(".bias") and name not in params_dict:
|
||||
continue
|
||||
# Remapping the name of FP8 kv-scale.
|
||||
name = maybe_remap_kv_scale_name(name, params_dict)
|
||||
if name is None:
|
||||
continue
|
||||
if is_pp_missing_parameter(name, self):
|
||||
continue
|
||||
param = params_dict[name]
|
||||
weight_loader = getattr(param, "weight_loader", default_weight_loader)
|
||||
weight_loader(param, loaded_weight)
|
||||
loaded_params.add(name)
|
||||
return loaded_params
|
||||
|
||||
|
||||
class SeedOssForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
|
||||
packed_modules_mapping = {
|
||||
"qkv_proj": [
|
||||
"q_proj",
|
||||
"k_proj",
|
||||
"v_proj",
|
||||
],
|
||||
"gate_up_proj": [
|
||||
"gate_proj",
|
||||
"up_proj",
|
||||
],
|
||||
}
|
||||
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
super().__init__()
|
||||
config = vllm_config.model_config.hf_config
|
||||
quant_config = vllm_config.quant_config
|
||||
lora_config = vllm_config.lora_config
|
||||
|
||||
self.config = config
|
||||
self.lora_config = lora_config
|
||||
|
||||
self.quant_config = quant_config
|
||||
self.model = SeedOssModel(
|
||||
vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")
|
||||
)
|
||||
|
||||
if get_pp_group().is_last_rank:
|
||||
if config.tie_word_embeddings:
|
||||
self.lm_head = self.model.embed_tokens
|
||||
else:
|
||||
self.lm_head = ParallelLMHead(
|
||||
config.vocab_size,
|
||||
config.hidden_size,
|
||||
quant_config=quant_config,
|
||||
prefix=maybe_prefix(prefix, "lm_head"),
|
||||
)
|
||||
else:
|
||||
self.lm_head = PPMissingLayer()
|
||||
|
||||
self.logits_processor = LogitsProcessor(config.vocab_size)
|
||||
|
||||
self.make_empty_intermediate_tensors = (
|
||||
self.model.make_empty_intermediate_tensors
|
||||
)
|
||||
|
||||
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.model.get_input_embeddings(input_ids)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
intermediate_tensors: IntermediateTensors | None = None,
|
||||
inputs_embeds: torch.Tensor | None = None,
|
||||
) -> torch.Tensor | IntermediateTensors:
|
||||
hidden_states = self.model(
|
||||
input_ids, positions, intermediate_tensors, inputs_embeds
|
||||
)
|
||||
return hidden_states
|
||||
|
||||
def compute_logits(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
) -> torch.Tensor | None:
|
||||
logits = self.logits_processor(self.lm_head, hidden_states)
|
||||
return logits
|
||||
|
||||
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
|
||||
loader = AutoWeightsLoader(
|
||||
self,
|
||||
skip_prefixes=(["lm_head."] if self.config.tie_word_embeddings else None),
|
||||
)
|
||||
return loader.load_weights(weights)
|
||||
@@ -12,10 +12,8 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# This file is a part of the vllm-kunlun project.
|
||||
# This file is a part of the vllm-ascend project.
|
||||
#
|
||||
|
||||
import vllm_kunlun.ops.rotary_embedding
|
||||
import vllm_kunlun.ops.layernorm
|
||||
import vllm_kunlun.ops.quantization.awq
|
||||
import vllm_kunlun.ops.quantization.gptq
|
||||
import vllm_kunlun.ops.layernorm
|
||||
@@ -1,20 +1,3 @@
|
||||
#
|
||||
# Copyright (c) 2025 Baidu, Inc. All Rights Reserved.
|
||||
#
|
||||
# This file is a part of the vllm-kunlun project.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""kunlun custom op entry"""
|
||||
import torch_xmlir
|
||||
import torch
|
||||
@@ -29,7 +12,6 @@ logger = init_logger(__name__)
|
||||
|
||||
try:
|
||||
import xtorch_ops
|
||||
|
||||
logger.info(f"Load custom ops library success!")
|
||||
except ImportError as e:
|
||||
logger.warning("Import error msg: %s", e.msg)
|
||||
@@ -37,15 +19,13 @@ except ImportError as e:
|
||||
|
||||
_per_token_smooth_quant = True
|
||||
|
||||
|
||||
def is_per_token_smooth_quant():
|
||||
"""is per token smooth quant"""
|
||||
""" is per token smooth quant """
|
||||
return _per_token_smooth_quant
|
||||
|
||||
|
||||
class KunlunOps:
|
||||
"""KunlunOps"""
|
||||
|
||||
# Attention ops
|
||||
@staticmethod
|
||||
def paged_attention_v1(
|
||||
@@ -70,9 +50,10 @@ class KunlunOps:
|
||||
blocksparse_vert_stride,
|
||||
blocksparse_block_size,
|
||||
blocksparse_head_sliding_step,
|
||||
alibi_sqrt=False,
|
||||
):
|
||||
"""PagedAttentionV1"""
|
||||
alibi_sqrt=False
|
||||
):
|
||||
""" PagedAttentionV1 """
|
||||
# block_size = value_cache.shape[2]
|
||||
xtorch_ops.paged_attention(
|
||||
x=query,
|
||||
k_cache=key_cache,
|
||||
@@ -83,7 +64,7 @@ class KunlunOps:
|
||||
is_context=is_context,
|
||||
is_causal=True,
|
||||
out=output,
|
||||
vo_head_dim=128,
|
||||
vo_head_dim=128
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
@@ -112,9 +93,10 @@ class KunlunOps:
|
||||
blocksparse_vert_stride,
|
||||
blocksparse_block_size,
|
||||
blocksparse_head_sliding_step,
|
||||
alibi_sqrt=False,
|
||||
):
|
||||
"""PagedAttentionV2"""
|
||||
alibi_sqrt=False
|
||||
):
|
||||
""" PagedAttentionV2 """
|
||||
# block_size = value_cache.shape[2]
|
||||
xtorch_ops.paged_attention(
|
||||
x=query,
|
||||
k_cache=key_cache,
|
||||
@@ -125,28 +107,31 @@ class KunlunOps:
|
||||
is_context=is_context,
|
||||
is_causal=True,
|
||||
out=output,
|
||||
vo_head_dim=128,
|
||||
vo_head_dim=128
|
||||
)
|
||||
|
||||
|
||||
# Activation ops
|
||||
@staticmethod
|
||||
def silu_and_mul(out: torch.Tensor, x: torch.Tensor):
|
||||
"""silu and mul"""
|
||||
def silu_and_mul(out: torch.Tensor,
|
||||
x: torch.Tensor):
|
||||
""" silu and mul """
|
||||
xtorch_ops.silu_and_mul(
|
||||
x,
|
||||
axis=-1,
|
||||
turn=True,
|
||||
out=out,
|
||||
)
|
||||
)
|
||||
|
||||
# Activation ops
|
||||
@staticmethod
|
||||
def quick_gelu(out: torch.Tensor, x: torch.Tensor):
|
||||
"""quick gelu"""
|
||||
def quick_gelu(out: torch.Tensor,
|
||||
x: torch.Tensor):
|
||||
""" quick gelu """
|
||||
xtorch_ops.quick_gelu(
|
||||
x,
|
||||
out=out,
|
||||
)
|
||||
)
|
||||
|
||||
# Layernorm
|
||||
@staticmethod
|
||||
@@ -157,7 +142,9 @@ class KunlunOps:
|
||||
epsilon,
|
||||
):
|
||||
"""rms_norm"""
|
||||
xtorch_ops.rmsnorm(x, weight.to(torch.float32), epsilon, out=out)
|
||||
xtorch_ops.rmsnorm(
|
||||
x, weight.to(torch.float32), epsilon, out=out
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def fused_add_rms_norm(
|
||||
@@ -175,11 +162,16 @@ class KunlunOps:
|
||||
residual.copy_(fused_input, non_blocking=True)
|
||||
x.copy_(output)
|
||||
|
||||
|
||||
# Rotary embedding
|
||||
@staticmethod
|
||||
def rotary_embedding(
|
||||
positions, query, key, head_size, cos_sin_cache, is_neox_style
|
||||
):
|
||||
positions,
|
||||
query,
|
||||
key,
|
||||
head_size,
|
||||
cos_sin_cache,
|
||||
is_neox_style):
|
||||
"""
|
||||
refactor RotaryEmbedding forward function
|
||||
"""
|
||||
@@ -196,43 +188,66 @@ class KunlunOps:
|
||||
key_x = key_x.unsqueeze(0)
|
||||
|
||||
xtorch_ops.rotary_embedding_gptj(
|
||||
positions, query_x, key_x, head_size, cos_sin_cache
|
||||
)
|
||||
positions,
|
||||
query_x,
|
||||
key_x,
|
||||
head_size,
|
||||
cos_sin_cache)
|
||||
query.data = query_x
|
||||
key.data = key_x
|
||||
key.data = key_x
|
||||
if query_x_dim != query_x.dim():
|
||||
query_x = query_x.unsqueeze(0)
|
||||
key_x = key_x.unsqueeze(0)
|
||||
return query, key
|
||||
|
||||
|
||||
# TODO: need opt
|
||||
if cos_sin_cache.dim() == 4:
|
||||
max_seq_len = cos_sin_cache.shape[2]
|
||||
head_dim = cos_sin_cache.shape[3]
|
||||
cos_sin_cache = cos_sin_cache.squeeze(0).squeeze(
|
||||
0
|
||||
) # Remove the first two dimensions [1,1,L,D] -> [L,D]
|
||||
cos_sin_cache = cos_sin_cache.squeeze(0).squeeze(0) # 移除前两个维度 [1,1,L,D] -> [L,D]
|
||||
cos_sin_cache = cos_sin_cache.view(max_seq_len, 1, head_dim)
|
||||
|
||||
# Reshape query and key
|
||||
|
||||
# 重塑 query 和 key 的形状
|
||||
num_tokens = query_x.shape[0]
|
||||
num_heads = query_x.shape[1] // head_size
|
||||
num_kv_heads = key_x.shape[1] // head_size
|
||||
|
||||
# # [num_tokens, num_heads * head_size] -> [num_tokens, num_heads, head_size]
|
||||
# query_x = query_x.view(num_tokens, num_heads, head_size)
|
||||
# # [num_tokens, num_kv_heads * head_size] -> [num_tokens, num_kv_heads, head_size]
|
||||
# key_x = key_x.view(num_tokens, num_kv_heads, head_size)
|
||||
|
||||
# # 确保形状正确
|
||||
# assert query_x.shape == (num_tokens, num_heads, head_size), \
|
||||
# f"Expected query shape [{num_tokens}, {num_heads}, {head_size}], got {query_x.shape}"
|
||||
# assert key_x.shape == (num_tokens, num_kv_heads, head_size), \
|
||||
# f"Expected key shape [{num_tokens}, {num_kv_heads}, {head_size}], got {key_x.shape}"
|
||||
|
||||
torch.ops._C.rotary_embedding(
|
||||
positions, query_x, key_x, head_size, cos_sin_cache, is_neox_style
|
||||
)
|
||||
positions,
|
||||
query_x,
|
||||
key_x,
|
||||
head_size,
|
||||
cos_sin_cache,
|
||||
is_neox_style)
|
||||
|
||||
query_x = query_x.view(num_tokens, num_heads * head_size)
|
||||
key_x = key_x.view(num_tokens, num_kv_heads * head_size)
|
||||
|
||||
# query.data = query_x
|
||||
# key.data = key_x
|
||||
return query_x, key_x
|
||||
|
||||
# Rotary embedding
|
||||
@staticmethod
|
||||
def mrotary_embedding(
|
||||
positions, mrope_section, query, key, head_size, cos_sin_cache, is_neox_style
|
||||
):
|
||||
positions,
|
||||
mrope_section,
|
||||
query,
|
||||
key,
|
||||
head_size,
|
||||
cos_sin_cache,
|
||||
is_neox_style):
|
||||
"""
|
||||
refactor RotaryEmbedding forward function
|
||||
"""
|
||||
@@ -241,21 +256,35 @@ class KunlunOps:
|
||||
query_x_dim = query_x.dim()
|
||||
assert is_neox_style
|
||||
xtorch_ops.mrotary_embedding_neox(
|
||||
positions, query_x, key_x, head_size, cos_sin_cache, mrope_section
|
||||
)
|
||||
positions,
|
||||
query_x,
|
||||
key_x,
|
||||
head_size,
|
||||
cos_sin_cache,
|
||||
mrope_section)
|
||||
|
||||
query.data = query_x
|
||||
key.data = key_x
|
||||
key.data = key_x
|
||||
return query, key
|
||||
|
||||
@staticmethod
|
||||
def swap_blocks(src, dst, block_mapping):
|
||||
"""swap_blocks"""
|
||||
xtorch_ops.swap_blocks(src, dst, block_mapping)
|
||||
def swap_blocks(
|
||||
src,
|
||||
dst,
|
||||
block_mapping):
|
||||
""" swap_blocks """
|
||||
xtorch_ops.swap_blocks(
|
||||
src,
|
||||
dst,
|
||||
block_mapping
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def copy_blocks(key_caches, value_caches, block_mapping):
|
||||
"""copy_blocks"""
|
||||
def copy_blocks(
|
||||
key_caches,
|
||||
value_caches,
|
||||
block_mapping):
|
||||
""" copy_blocks """
|
||||
for i in range(len(key_caches)):
|
||||
key_caches[i] = key_caches[i].contiguous()
|
||||
value_caches[i] = value_caches[i].contiguous()
|
||||
@@ -273,9 +302,16 @@ class KunlunOps:
|
||||
value_cache,
|
||||
slot_mapping,
|
||||
kv_cache_dtype,
|
||||
):
|
||||
"""reshape_and_cache"""
|
||||
xtorch_ops.reshape_and_cache(key, value, key_cache, value_cache, slot_mapping)
|
||||
):
|
||||
""" reshape_and_cache """
|
||||
# slot_mapping_cast = slot_mapping.to(torch.int32)
|
||||
xtorch_ops.reshape_and_cache(
|
||||
key,
|
||||
value,
|
||||
key_cache,
|
||||
value_cache,
|
||||
slot_mapping
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def multi_query_kv_attention(
|
||||
@@ -284,7 +320,7 @@ class KunlunOps:
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
**kargs,
|
||||
**kargs
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
query: shape = [num_prompt_tokens, num_heads, head_size]
|
||||
@@ -303,7 +339,7 @@ class KunlunOps:
|
||||
KVh = key.size(2)
|
||||
if KVh != Qh:
|
||||
repeat = Qh // KVh
|
||||
key = key.repeat_interleave(repeat, dim=2) # [B, T, Qh, Hd]
|
||||
key = key.repeat_interleave(repeat, dim=2) # [B, T, Qh, Hd]
|
||||
value = value.repeat_interleave(repeat, dim=2)
|
||||
xtorch_ops.attention(
|
||||
q=query,
|
||||
@@ -318,132 +354,85 @@ class KunlunOps:
|
||||
return output
|
||||
|
||||
@staticmethod
|
||||
def quant_fusedresidual_rmsnorm_op(
|
||||
x, residual, weight, bias, scale_to_int, eps, dyn_scale: bool, type: int = 1
|
||||
):
|
||||
def quant_fusedresidual_rmsnorm_op(x,
|
||||
residual,
|
||||
weight,
|
||||
bias,
|
||||
scale_to_int,
|
||||
eps,
|
||||
dyn_scale: bool,
|
||||
type: int = 1):
|
||||
"""Quantized fused residual layer normalization"""
|
||||
out = torch.empty_like(x, dtype=torch.int8)
|
||||
|
||||
if is_per_token_smooth_quant():
|
||||
out_scale = torch.empty(
|
||||
x.shape[:-1], device=x.device, dtype=torch.float
|
||||
).unsqueeze(-1)
|
||||
out_scale = torch.empty(x.shape[:-1], device=x.device, dtype=torch.float).unsqueeze(-1)
|
||||
else:
|
||||
out_scale = torch.empty(12, device=x.device, dtype=torch.float)
|
||||
|
||||
xtorch_ops.quant_fusedresidual_rmsnorm(
|
||||
x,
|
||||
residual,
|
||||
weight,
|
||||
bias,
|
||||
eps,
|
||||
out=out,
|
||||
out_scale=out_scale,
|
||||
residual_tensor=residual,
|
||||
)
|
||||
xtorch_ops.quant_fusedresidual_rmsnorm(x, residual, weight, bias, eps,
|
||||
out=out, out_scale=out_scale , residual_tensor=residual)
|
||||
|
||||
if residual is None:
|
||||
return out, out_scale
|
||||
return out, out_scale, residual
|
||||
|
||||
@staticmethod
|
||||
def quant_rmsnorm_op(
|
||||
x, weight, bias, scale_to_int, eps, dyn_scale: bool, type: int = 1
|
||||
):
|
||||
def quant_rmsnorm_op(x,
|
||||
weight,
|
||||
bias,
|
||||
scale_to_int,
|
||||
eps,
|
||||
dyn_scale : bool,
|
||||
type: int = 1):
|
||||
"""Quantized RMSNorm"""
|
||||
|
||||
out = torch.empty_like(x, dtype=torch.int8)
|
||||
if is_per_token_smooth_quant():
|
||||
out_scale = torch.empty(
|
||||
x.shape[:-1], device=x.device, dtype=torch.float
|
||||
).unsqueeze(-1)
|
||||
out_scale = torch.empty(x.shape[:-1], device=x.device, dtype=torch.float).unsqueeze(-1)
|
||||
else:
|
||||
out_scale = torch.empty(12, device=x.device, dtype=torch.float)
|
||||
|
||||
xtorch_ops.quant_rmsnorm(x, weight, bias, eps, out=out, out_scale=out_scale)
|
||||
xtorch_ops.quant_rmsnorm(x, weight, bias, eps,
|
||||
out=out, out_scale=out_scale)
|
||||
return out, out_scale
|
||||
|
||||
@staticmethod
|
||||
def smooth_quant_matmul_column_row_kernels(
|
||||
input_tensor,
|
||||
weight,
|
||||
smoother,
|
||||
input_scale,
|
||||
weight_scale,
|
||||
perTokenScaling,
|
||||
perChannelScaling,
|
||||
otype,
|
||||
):
|
||||
def smooth_quant_matmul_column_row_kernels(input_tensor,
|
||||
weight,
|
||||
smoother,
|
||||
input_scale,
|
||||
weight_scale,
|
||||
perTokenScaling,
|
||||
perChannelScaling,
|
||||
otype):
|
||||
"""smooth_quant_matmul_column_row_kernels"""
|
||||
input_shape = input_tensor.shape
|
||||
weight_shape = weight.shape
|
||||
if input_tensor.dim() == 3:
|
||||
input_tensor = input_tensor.reshape(-1, input_shape[-1])
|
||||
out = torch.empty(
|
||||
(input_shape[0] * input_shape[1], weight_shape[0]),
|
||||
dtype=torch.float16,
|
||||
device=weight.device,
|
||||
)
|
||||
out = torch.empty((input_shape[0] * input_shape[1],
|
||||
weight_shape[0]),
|
||||
dtype=torch.float16,
|
||||
device=weight.device)
|
||||
output_bs_shape = [input_shape[0], input_shape[1]]
|
||||
elif input_tensor.dim() == 2:
|
||||
out = torch.empty(
|
||||
(input_shape[0], weight_shape[0]),
|
||||
dtype=torch.float16,
|
||||
device=weight.device,
|
||||
)
|
||||
out = torch.empty((input_shape[0], weight_shape[0]),
|
||||
dtype=torch.float16,
|
||||
device=weight.device)
|
||||
output_bs_shape = [-1]
|
||||
xtorch_ops.smooth_quant_matmul_column_row_kernels(
|
||||
input_tensor,
|
||||
weight,
|
||||
smoother,
|
||||
input_scale,
|
||||
weight_scale,
|
||||
perTokenScaling,
|
||||
perChannelScaling,
|
||||
out=out,
|
||||
)
|
||||
xtorch_ops.smooth_quant_matmul_column_row_kernels(input_tensor,
|
||||
weight, smoother,
|
||||
input_scale,
|
||||
weight_scale,
|
||||
perTokenScaling,
|
||||
perChannelScaling,
|
||||
out=out)
|
||||
|
||||
out = out.view(*output_bs_shape, weight_shape[0])
|
||||
|
||||
return out
|
||||
|
||||
@staticmethod
|
||||
def fused_moe(
|
||||
hidden_states: torch.Tensor,
|
||||
w1: torch.Tensor,
|
||||
w2: torch.Tensor,
|
||||
gating_output: torch.Tensor,
|
||||
linear_weights: torch.Tensor,
|
||||
topk: int,
|
||||
renormalize: bool,
|
||||
inplace: bool = False,
|
||||
use_grouped_topk: bool = False,
|
||||
num_expert_group: Optional[int] = None,
|
||||
topk_group: Optional[int] = None,
|
||||
w1_bias: Optional[torch.Tensor] = None,
|
||||
w2_bias: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
"""fused_moe"""
|
||||
output = torch.empty(
|
||||
hidden_states.shape, dtype=hidden_states.dtype, device=hidden_states.device
|
||||
)
|
||||
expert_num = linear_weights.shape[0]
|
||||
|
||||
torch.ops._C.moe_ffn_block(
|
||||
x=hidden_states,
|
||||
gate_w=linear_weights,
|
||||
inter_w=w1,
|
||||
output_w=w2,
|
||||
expert_num=expert_num,
|
||||
moe_top_k=topk,
|
||||
topk_group=topk_group,
|
||||
renormalize=renormalize,
|
||||
use_grouped_topk=use_grouped_topk,
|
||||
expert_group_num=num_expert_group,
|
||||
out=output,
|
||||
)
|
||||
return output
|
||||
|
||||
@staticmethod
|
||||
def fused_moe_ep(
|
||||
hidden_states: torch.Tensor,
|
||||
@@ -460,23 +449,23 @@ class KunlunOps:
|
||||
topk_group: Optional[int] = None,
|
||||
w1_bias: Optional[torch.Tensor] = None,
|
||||
w2_bias: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
) -> torch.Tensor:
|
||||
x = hidden_states
|
||||
batch, hidden_size = x.shape
|
||||
batch, hidden_size = x.shape
|
||||
num_local_experts, up_gate_size, _ = w13_weight.shape
|
||||
|
||||
router_logits = x.to(linear_weights.dtype) @ linear_weights.T
|
||||
|
||||
topk_weights = torch.empty(
|
||||
batch, top_k, dtype=router_logits.dtype, device=router_logits.device
|
||||
)
|
||||
topk_ids = torch.empty(
|
||||
batch, top_k, dtype=torch.int32, device=router_logits.device
|
||||
)
|
||||
block_static = torch.empty(0, dtype=torch.int32, device=router_logits.device)
|
||||
torch.ops._C.moe_softmax_topk(
|
||||
router_logits, topk_weights, topk_ids, block_static
|
||||
)
|
||||
router_logits = x.to(linear_weights.dtype)@linear_weights.T
|
||||
|
||||
topk_weights = torch.empty(batch,
|
||||
top_k,
|
||||
dtype=router_logits.dtype,
|
||||
device=router_logits.device)
|
||||
topk_ids = torch.empty(batch,
|
||||
top_k,
|
||||
dtype=torch.int32,
|
||||
device=router_logits.device)
|
||||
block_static = torch.empty(0, dtype=torch.int32,device=router_logits.device)
|
||||
torch.ops._C.moe_softmax_topk(router_logits, topk_weights, topk_ids, block_static)
|
||||
|
||||
if renormalize:
|
||||
topk_weights = topk_weights / topk_weights.sum(1, keepdim=True)
|
||||
@@ -490,22 +479,50 @@ class KunlunOps:
|
||||
selected_token = topk_ids_flat == experts_id
|
||||
if selected_token.sum():
|
||||
cur_token = repeat_x[selected_token]
|
||||
up_gate = torch.empty(
|
||||
selected_token.sum(),
|
||||
up_gate_size // 2,
|
||||
dtype=cur_token.dtype,
|
||||
device=cur_token.device,
|
||||
)
|
||||
torch.ops._C.swiglu(cur_token @ w13_weight[i].T, up_gate)
|
||||
up_gate = torch.empty(selected_token.sum(), up_gate_size//2,
|
||||
dtype=cur_token.dtype, device=cur_token.device)
|
||||
torch.ops._C.swiglu(cur_token@ w13_weight[i].T, up_gate)
|
||||
out[selected_token] = up_gate @ w2_weight[i].T
|
||||
output = (
|
||||
(out.view(batch, top_k, hidden_size) * topk_weights.unsqueeze(2))
|
||||
.sum(dim=1)
|
||||
.to(x.dtype)
|
||||
)
|
||||
output = (out.view(batch, top_k, hidden_size) * topk_weights.unsqueeze(2)).sum(dim=1).to(x.dtype)
|
||||
|
||||
return output
|
||||
|
||||
@staticmethod
|
||||
def fused_moe(
|
||||
hidden_states: torch.Tensor,
|
||||
w1: torch.Tensor,
|
||||
w2: torch.Tensor,
|
||||
gating_output: torch.Tensor,
|
||||
linear_weights: torch.Tensor,
|
||||
topk: int,
|
||||
renormalize: bool,
|
||||
inplace: bool = False,
|
||||
use_grouped_topk: bool = False,
|
||||
num_expert_group: Optional[int] = None,
|
||||
topk_group: Optional[int] = None,
|
||||
w1_bias: Optional[torch.Tensor] = None,
|
||||
w2_bias: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
"""fused_moe"""
|
||||
output = torch.empty(hidden_states.shape, dtype=hidden_states.dtype,
|
||||
device=hidden_states.device)
|
||||
expert_num = linear_weights.shape[0]
|
||||
|
||||
torch.ops._C.moe_ffn_block(
|
||||
x=hidden_states,
|
||||
gate_w=linear_weights,
|
||||
inter_w=w1,
|
||||
output_w=w2,
|
||||
expert_num=expert_num,
|
||||
moe_top_k=topk,
|
||||
topk_group=topk_group,
|
||||
renormalize=renormalize,
|
||||
use_grouped_topk=use_grouped_topk,
|
||||
expert_group_num=num_expert_group,
|
||||
out=output,
|
||||
)
|
||||
return output
|
||||
|
||||
@staticmethod
|
||||
def fused_multi_head_latent_page_attention(
|
||||
hidden_states: torch.Tensor,
|
||||
@@ -538,11 +555,10 @@ class KunlunOps:
|
||||
prompt_lods_cpu: torch.Tensor,
|
||||
k_cache: torch.Tensor,
|
||||
v_cache: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
) -> torch.Tensor:
|
||||
"""mla pa block"""
|
||||
output = torch.empty(
|
||||
hidden_states.shape, dtype=hidden_states.dtype, device=hidden_states.device
|
||||
)
|
||||
output = torch.empty(hidden_states.shape, dtype=hidden_states.dtype,
|
||||
device=hidden_states.device)
|
||||
xtorch_ops.xft_multi_head_latent_page_attention_block(
|
||||
hidden_states,
|
||||
q_lora_rank,
|
||||
@@ -579,3 +595,42 @@ class KunlunOps:
|
||||
v_cache=v_cache,
|
||||
)
|
||||
return output
|
||||
|
||||
|
||||
def fused_gdn_gating(
|
||||
A_log: torch.Tensor,
|
||||
a: torch.Tensor,
|
||||
dt_bias: torch.Tensor,
|
||||
beta: float = 1.0,
|
||||
threshold: float = 20.0,
|
||||
) -> torch.Tensor:
|
||||
"""fused_gdn_gating"""
|
||||
output = xtorch_ops.fused_gdn_gating(
|
||||
A_log,
|
||||
a,
|
||||
dt_bias,
|
||||
)
|
||||
return output
|
||||
|
||||
|
||||
def fused_recurrent_gated_delta_rule_fwd(
|
||||
q: torch.Tensor,
|
||||
k: torch.Tensor,
|
||||
v: torch.Tensor,
|
||||
g: torch.Tensor,
|
||||
beta: torch.Tensor,
|
||||
scale: float,
|
||||
h0_source: torch.Tensor,
|
||||
output_final_state: bool,
|
||||
use_qk_l2norm_in_kernel: bool,
|
||||
cu_seqlens: torch.Tensor = None) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
'''
|
||||
Qwen3-NEXT模型中 Gated DeltaNet的核心算子, 将做完sigmoid_gating和delta_rule_update融合在一起
|
||||
1. Sigmoid Gating: 对输入进行门控, 类似于 GLU (Gated Linear Unit)。
|
||||
2. Delta Rule Update: 执行一个并行的状态空间模型(SSM)的递归更新, 同时结合了一个局部的注意力机制。
|
||||
'''
|
||||
|
||||
o, final_state = xtorch_ops.fused_recurrent_gated_delta_rule_fwd(
|
||||
q, k, v, g, beta, scale, h0_source, output_final_state, use_qk_l2norm_in_kernel,
|
||||
cu_seqlens)
|
||||
return (o, final_state)
|
||||
@@ -1,8 +1,78 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
"""Custom activation functions."""
|
||||
import math
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from vllm.distributed import (divide, get_tensor_model_parallel_rank,
|
||||
get_tensor_model_parallel_world_size)
|
||||
from vllm.model_executor.custom_op import CustomOp
|
||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||
from vllm.model_executor.utils import set_weight_attrs
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils import LazyDict
|
||||
|
||||
|
||||
@CustomOp.register("kunlun_fatrelu_and_mul")
|
||||
class FatreluAndMul(CustomOp):
|
||||
"""An activation function for FATReLU.
|
||||
|
||||
The function computes x -> FATReLU(x[:d]) * x[d:] where
|
||||
d = x.shape[-1] // 2.
|
||||
This is used in openbmb/MiniCPM-S-1B-sft.
|
||||
|
||||
Shapes:
|
||||
x: (num_tokens, 2 * d) or (batch_size, seq_len, 2 * d)
|
||||
return: (num_tokens, d) or (batch_size, seq_len, d)
|
||||
"""
|
||||
|
||||
def __init__(self, threshold: float = 0.):
|
||||
"""
|
||||
Initializes the instance.
|
||||
|
||||
Args:
|
||||
threshold (float, optional): Threshold value for the filter. Defaults to 0..
|
||||
|
||||
Returns:
|
||||
None: This method does not return anything.
|
||||
"""
|
||||
super().__init__()
|
||||
self.threshold = threshold
|
||||
|
||||
def forward_native(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
计算输入张量的正向传播,并返回一个新的张量。
|
||||
该函数实现了原生的前向传播过程,即对输入张量进行阈值化处理后,将其乘以另一个张量。
|
||||
|
||||
Args:
|
||||
x (torch.Tensor, shape=[*, d]):
|
||||
输入张量,其中*表示任意维度,d为特征维度。
|
||||
|
||||
Returns:
|
||||
torch.Tensor, shape=[*, d]:
|
||||
返回一个新的张量,其形状与输入张量相同,除了最后一个维度被设置为d/2。
|
||||
如果输入张量的最后一个维度小于等于d/2,则返回的张量将保持不变;否则,将对输入张量进行阈值化处理。
|
||||
"""
|
||||
d = x.shape[-1] // 2
|
||||
x1 = x[..., :d]
|
||||
x2 = x[..., d:]
|
||||
x1 = F.threshold(x1, self.threshold, 0.0)
|
||||
return x1 * x2
|
||||
|
||||
def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
在CUDA设备上执行前向传播。
|
||||
|
||||
Args:
|
||||
x (torch.Tensor): 输入张量,形状为(N, C, H, W)。
|
||||
|
||||
Returns:
|
||||
torch.Tensor: 输出张量,形状为(N, C, H, W)。
|
||||
"""
|
||||
return self.forward_native(x)
|
||||
|
||||
|
||||
@CustomOp.register("kunlun_silu_and_mul")
|
||||
@@ -15,9 +85,532 @@ class SiluAndMul(CustomOp):
|
||||
x: (num_tokens, 2 * d) or (batch_size, seq_len, 2 * d)
|
||||
return: (num_tokens, d) or (batch_size, seq_len, d)
|
||||
"""
|
||||
|
||||
def forward_native(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""PyTorch-native implementation equivalent to forward()."""
|
||||
d = x.shape[-1] // 2
|
||||
return F.silu(x[..., :d]) * x[..., d:]
|
||||
|
||||
def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""forward_cuda"""
|
||||
import xtorch_ops
|
||||
|
||||
d = x.shape[-1] // 2
|
||||
output_shape = (x.shape[:-1] + (d, ))
|
||||
out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
|
||||
torch.ops._C.swiglu(x, out)
|
||||
return out
|
||||
return out
|
||||
|
||||
def forward_kunlun(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""forward_kunlun"""
|
||||
import xtorch_ops
|
||||
|
||||
d = x.shape[-1] // 2
|
||||
output_shape = (x.shape[:-1] + (d, ))
|
||||
out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
|
||||
xtorch_ops.swiglu(x, out)
|
||||
return out
|
||||
|
||||
def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Apply the function on `x` using XPU backend.
|
||||
|
||||
Args:
|
||||
x (torch.Tensor): Input tensor of any shape. Must be a floating point tensor.
|
||||
The number of channels should be even.
|
||||
|
||||
Returns:
|
||||
torch.Tensor: Output tensor with the same shape as input except the last dimension is reduced by half.
|
||||
It has the same dtype as the input and lives on the same device.
|
||||
|
||||
Raises:
|
||||
None
|
||||
"""
|
||||
from vllm._ipex_ops import ipex_ops as ops
|
||||
|
||||
d = x.shape[-1] // 2
|
||||
output_shape = (x.shape[:-1] + (d, ))
|
||||
out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
|
||||
ops.silu_and_mul(out, x)
|
||||
return out
|
||||
|
||||
def forward_neuron(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
前向传播一个神经元,计算输入的信号。
|
||||
参数:
|
||||
x (torch.Tensor): 形状为(-1, d)的张量,其中d是输入的维度。
|
||||
每个元素表示一个输入信号。
|
||||
返回值(torch.Tensor):
|
||||
形状为(-1, d)的张量,其中d是输出的维度。
|
||||
每个元素表示一个输出信号。
|
||||
"""
|
||||
d = x.shape[-1] // 2
|
||||
x_reshaped = x.view(-1, x.shape[-1])
|
||||
s = x_reshaped[:, :d] * F.sigmoid(x_reshaped[:, :d])
|
||||
result = s * x_reshaped[:, d:]
|
||||
return result.view(*x.shape[:-1], d)
|
||||
|
||||
|
||||
@CustomOp.register("kunlun_mul_and_silu")
|
||||
class MulAndSilu(CustomOp):
|
||||
"""An activation function for SwiGLU.
|
||||
|
||||
The function computes x -> x[:d] * silu(x[d:]) where d = x.shape[-1] // 2.
|
||||
|
||||
Shapes:
|
||||
x: (num_tokens, 2 * d) or (batch_size, seq_len, 2 * d)
|
||||
return: (num_tokens, d) or (batch_size, seq_len, d)
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
"""
|
||||
初始化函数,用于实例化类的对象。
|
||||
如果当前平台是 CUDA 或 XPU,则使用 torch.ops._C.mul_and_silu 进行操作;
|
||||
否则,如果当前平台是 CPU,则使用 forward_native 方法进行操作。
|
||||
"""
|
||||
super().__init__()
|
||||
if current_platform.is_cuda_alike():
|
||||
self.op = torch.ops._C.mul_and_silu
|
||||
elif current_platform.is_xpu():
|
||||
from vllm._ipex_ops import ipex_ops
|
||||
self.op = ipex_ops.silu_and_mul
|
||||
elif current_platform.is_cpu():
|
||||
self._forward_method = self.forward_native
|
||||
|
||||
def forward_native(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""PyTorch-native implementation equivalent to forward()."""
|
||||
d = x.shape[-1] // 2
|
||||
return x[..., :d] * F.silu(x[..., d:])
|
||||
|
||||
def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
在CUDA设备上执行前向传播操作。
|
||||
|
||||
Args:
|
||||
x (torch.Tensor): 输入张量,其形状应为(..., d),其中d是特征维度。
|
||||
|
||||
Returns:
|
||||
torch.Tensor: 输出张量,其形状与输入张量相同,但最后一个维度被替换为d/2。
|
||||
|
||||
Raises:
|
||||
无。
|
||||
"""
|
||||
d = x.shape[-1] // 2
|
||||
output_shape = (x.shape[:-1] + (d, ))
|
||||
out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
|
||||
self.op(out, x)
|
||||
return out
|
||||
|
||||
# TODO implement forward_xpu for MulAndSilu
|
||||
# def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
|
||||
|
||||
|
||||
@CustomOp.register("kunlun_gelu_and_mul")
|
||||
class GeluAndMul(CustomOp):
|
||||
"""An activation function for GeGLU.
|
||||
|
||||
The function computes x -> GELU(x[:d]) * x[d:] where d = x.shape[-1] // 2.
|
||||
|
||||
Shapes:
|
||||
x: (batch_size, seq_len, 2 * d) or (num_tokens, 2 * d)
|
||||
return: (batch_size, seq_len, d) or (num_tokens, d)
|
||||
"""
|
||||
|
||||
def __init__(self, approximate: str = "none"):
|
||||
"""
|
||||
Initializes the instance.
|
||||
|
||||
Args:
|
||||
approximate (str, optional): The approximation method to use. Defaults to "none".
|
||||
Can be one of "none", "tanh".
|
||||
|
||||
Raises:
|
||||
ValueError: If the `approximate` parameter is not one of "none", "tanh".
|
||||
"""
|
||||
super().__init__()
|
||||
self.approximate = approximate
|
||||
if approximate not in ("none", "tanh"):
|
||||
raise ValueError(f"Unknown approximate mode: {approximate}")
|
||||
|
||||
def forward_native(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""PyTorch-native implementation equivalent to forward()."""
|
||||
d = x.shape[-1] // 2
|
||||
return F.gelu(x[..., :d], approximate=self.approximate) * x[..., d:]
|
||||
|
||||
def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
在CUDA设备上进行前向传播。
|
||||
|
||||
Args:
|
||||
x (torch.Tensor): 输入张量,形状为(batch_size, ..., dim),其中dim是特征维度。
|
||||
|
||||
Returns:
|
||||
torch.Tensor: 输出张量,形状为(batch_size, ..., dim//2),其中dim是特征维度,除以2是因为GELU的输出是两个分量。
|
||||
|
||||
Raises:
|
||||
无。
|
||||
"""
|
||||
# from vllm import _custom_ops as ops
|
||||
import xtorch_ops
|
||||
# d = x.shape[-1] // 2
|
||||
# output_shape = (x.shape[:-1] + (d, ))
|
||||
out = torch.empty(x, dtype=x.dtype, device=x.device)
|
||||
if self.approximate == "none":
|
||||
# ops.gelu_and_mul(out, x)
|
||||
print(x,x.shape)
|
||||
xtorch_ops.gelu(x, out)
|
||||
elif self.approximate == "tanh":
|
||||
ops.gelu_tanh_and_mul(out, x)
|
||||
return out
|
||||
|
||||
def forward_native(self, x: torch.Tensor) -> torch.Tensor:
|
||||
d, _ = self._check_and_make_out(x)
|
||||
# 保守地用 contiguous,避免 view 相关坑
|
||||
x = x.contiguous()
|
||||
x1 = x[..., :d]
|
||||
x2 = x[..., d:]
|
||||
return F.gelu(x1, approximate=self.approximate) * x2
|
||||
|
||||
# def forward_native(self, x: torch.Tensor) -> torch.Tensor:
|
||||
# """PyTorch-native implementation equivalent to forward()."""
|
||||
# d = x.shape[-1] // 2
|
||||
# return F.gelu(x[..., :d], approximate=self.approximate) * x[..., d:]
|
||||
|
||||
def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Apply gelu activation function on input tensor using iPEX backend.
|
||||
|
||||
Args:
|
||||
x (torch.Tensor): Input tensor with shape (N, C, H, W).
|
||||
The data type can be float32 or float64.
|
||||
|
||||
Returns:
|
||||
torch.Tensor: Output tensor with the same shape and data type as input.
|
||||
The output will have a range of (-0.5, 0.5) for tanh approximation.
|
||||
"""
|
||||
from vllm._ipex_ops import ipex_ops as ops
|
||||
|
||||
d = x.shape[-1] // 2
|
||||
output_shape = (x.shape[:-1] + (d, ))
|
||||
out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
|
||||
if self.approximate == "none":
|
||||
ops.gelu_and_mul(out, x)
|
||||
elif self.approximate == "tanh":
|
||||
ops.gelu_tanh_and_mul(out, x)
|
||||
return out
|
||||
|
||||
def extra_repr(self) -> str:
|
||||
"""
|
||||
返回一个字符串,包含有关模型的额外信息。这个函数可以被用于打印出模型的概要信息。
|
||||
默认情况下,这个函数会返回一个包含模型是否使用近似值(approximate)的信息。
|
||||
|
||||
Returns:
|
||||
str (str): 一个字符串,包含有关模型的额外信息。
|
||||
"""
|
||||
return f'approximate={repr(self.approximate)}'
|
||||
|
||||
|
||||
@CustomOp.register("kunlun_gelu_new")
|
||||
class NewGELU(CustomOp):
|
||||
|
||||
def forward_native(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""PyTorch-native implementation equivalent to forward()."""
|
||||
c = math.sqrt(2.0 / math.pi)
|
||||
return 0.5 * x * (1.0 + torch.tanh(c *
|
||||
(x + 0.044715 * torch.pow(x, 3.0))))
|
||||
|
||||
def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
计算CUDA上的GELU函数。
|
||||
|
||||
Args:
|
||||
x (torch.Tensor): 输入张量,形状为(N, C, H, W)。
|
||||
|
||||
Returns:
|
||||
torch.Tensor: GELU函数的结果,形状与输入相同。
|
||||
|
||||
Raises:
|
||||
无。
|
||||
"""
|
||||
from vllm import _custom_ops as ops
|
||||
|
||||
out = torch.empty_like(x)
|
||||
ops.gelu_new(out, x)
|
||||
return out
|
||||
|
||||
def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Apply the GELU activation function element-wise.
|
||||
|
||||
Args:
|
||||
x (torch.Tensor): Input tensor with any shape. The data type is float32 or float64.
|
||||
|
||||
Returns:
|
||||
torch.Tensor: Output tensor with the same shape as input. The data type is the same as input.
|
||||
|
||||
Raises:
|
||||
None
|
||||
"""
|
||||
from vllm._ipex_ops import ipex_ops as ops
|
||||
|
||||
return ops.gelu_new(x)
|
||||
|
||||
|
||||
@CustomOp.register("kunlun_gelu_fast")
|
||||
class FastGELU(CustomOp):
|
||||
|
||||
def forward_native(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""PyTorch-native implementation equivalent to forward()."""
|
||||
return 0.5 * x * (1.0 + torch.tanh(x * 0.7978845608 *
|
||||
(1.0 + 0.044715 * x * x)))
|
||||
|
||||
def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
计算输入张量x的CUDA版本GELU(Gaussian Error Linear Unit)。
|
||||
该函数调用了vllm模块中的_custom_ops模块中的gelu_fast函数,完成GELU操作。
|
||||
|
||||
Args:
|
||||
x (torch.Tensor): 输入张量,形状为(N, C, H, W),类型为float32或float64。
|
||||
|
||||
Returns:
|
||||
torch.Tensor: GELU后的输出张量,形状与x相同,类型与x相同。
|
||||
|
||||
Raises:
|
||||
无。
|
||||
"""
|
||||
from vllm import _custom_ops as ops
|
||||
|
||||
out = torch.empty_like(x)
|
||||
ops.gelu_fast(out, x)
|
||||
return out
|
||||
|
||||
def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Apply the GELU function element-wise on input tensor ``x``.
|
||||
|
||||
Args:
|
||||
x (torch.Tensor): Input tensor with any shape. The data type can be float or half float.
|
||||
The range of the input values is expected to be -inf to inf.
|
||||
|
||||
Returns:
|
||||
torch.Tensor: Output tensor with the same shape and data type as input ``x``.
|
||||
The output values are in the range [-0.5, 0.5] for float dtype and [-15, 15] for half float dtype.
|
||||
|
||||
Raises:
|
||||
TypeError: If the input ``x`` is not a torch.Tensor.
|
||||
RuntimeError: If the input ``x`` contains non-finite numbers.
|
||||
"""
|
||||
from vllm._ipex_ops import ipex_ops as ops
|
||||
|
||||
return ops.gelu_fast(x)
|
||||
|
||||
|
||||
@CustomOp.register("kunlun_quick_gelu")
|
||||
class QuickGELU(CustomOp):
|
||||
# https://github.com/huggingface/transformers/blob/main/src/transformers/activations.py#L90
|
||||
def forward_native(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""PyTorch-native implementation equivalent to forward()."""
|
||||
return x * torch.sigmoid(1.702 * x)
|
||||
|
||||
def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
使用CUDA设备进行前向计算。
|
||||
|
||||
Args:
|
||||
x (torch.Tensor): 输入张量,形状为(N, C, H, W)。
|
||||
|
||||
Returns:
|
||||
torch.Tensor: 输出张量,形状与输入相同,值为GELU函数的结果。
|
||||
|
||||
Raises:
|
||||
无。
|
||||
"""
|
||||
from vllm import _custom_ops as ops
|
||||
|
||||
out = torch.empty_like(x)
|
||||
ops.gelu_quick(out, x)
|
||||
return out
|
||||
|
||||
def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Apply the GELU function element-wise on input tensor ``x``.
|
||||
|
||||
Args:
|
||||
x (torch.Tensor): Input tensor with any shape. The data type is float32 or float64.
|
||||
|
||||
Returns:
|
||||
torch.Tensor: Output tensor with the same shape and data type as input ``x``.
|
||||
|
||||
Raises:
|
||||
None
|
||||
"""
|
||||
from vllm._ipex_ops import ipex_ops as ops
|
||||
|
||||
out = torch.empty_like(x)
|
||||
ops.gelu_quick(out, x)
|
||||
return out
|
||||
|
||||
def forward_kunlun(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""forward_kunlun"""
|
||||
from vllm._kunlun_ops import KunlunOps as ops
|
||||
out = torch.empty_like(x)
|
||||
ops.quick_gelu(out, x)
|
||||
return out
|
||||
|
||||
|
||||
@CustomOp.register("kunlun_relu2")
|
||||
class ReLUSquaredActivation(CustomOp):
|
||||
"""
|
||||
Applies the relu^2 activation introduced in https://arxiv.org/abs/2109.08668v2
|
||||
"""
|
||||
|
||||
def forward_native(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""PyTorch-native implementation equivalent to forward()."""
|
||||
return torch.square(F.relu(x))
|
||||
|
||||
def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
在CUDA设备上执行前向传播。
|
||||
|
||||
Args:
|
||||
x (torch.Tensor): 输入张量,形状为(N, C, H, W),数据类型为float32或float64。
|
||||
|
||||
Returns:
|
||||
torch.Tensor: 输出张量,形状与输入相同,数据类型与输入一致。
|
||||
|
||||
Raises:
|
||||
无。
|
||||
"""
|
||||
return self.forward_native(x)
|
||||
|
||||
|
||||
class ScaledActivation(nn.Module):
|
||||
"""An activation function with post-scale parameters.
|
||||
|
||||
This is used for some quantization methods like AWQ.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
act_module: nn.Module,
|
||||
intermediate_size: int,
|
||||
input_is_parallel: bool = True,
|
||||
params_dtype: Optional[torch.dtype] = None,
|
||||
):
|
||||
"""
|
||||
Initializes the LayerNorm module.
|
||||
|
||||
Args:
|
||||
act_module (nn.Module): The activation function to use after layer norm.
|
||||
Default: nn.GELU()
|
||||
intermediate_size (int): The size of the intermediate representation.
|
||||
input_is_parallel (bool, optional): Whether the input is parallelly processed.
|
||||
Default: True
|
||||
params_dtype (Optional[torch.dtype], optional): The data type of parameters.
|
||||
If None, use the default data type. Default: None
|
||||
"""
|
||||
super().__init__()
|
||||
self.act = act_module
|
||||
self.input_is_parallel = input_is_parallel
|
||||
if input_is_parallel:
|
||||
tp_size = get_tensor_model_parallel_world_size()
|
||||
intermediate_size_per_partition = divide(intermediate_size,
|
||||
tp_size)
|
||||
else:
|
||||
intermediate_size_per_partition = intermediate_size
|
||||
if params_dtype is None:
|
||||
params_dtype = torch.get_default_dtype()
|
||||
self.scales = nn.Parameter(
|
||||
torch.empty(intermediate_size_per_partition, dtype=params_dtype))
|
||||
set_weight_attrs(self.scales, {"weight_loader": self.weight_loader})
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
前向传播函数,将输入的张量进行缩放和激活操作。
|
||||
|
||||
Args:
|
||||
x (torch.Tensor): 输入张量,形状为(N, C, H, W)或者(N, C, H, W, D)。
|
||||
|
||||
Returns:
|
||||
torch.Tensor: 返回处理后的张量,形状与输入相同。
|
||||
"""
|
||||
return self.act(x) / self.scales
|
||||
|
||||
def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor):
|
||||
"""
|
||||
加载权重,如果输入是并行的,则需要将其平均分配到每个模型参数中。
|
||||
参数:
|
||||
param (nn.Parameter): 需要加载权重的模型参数。
|
||||
loaded_weight (torch.Tensor): 加载的权重张量。
|
||||
返回值:
|
||||
无返回值,直接修改了param的数据。
|
||||
"""
|
||||
param_data = param.data
|
||||
if self.input_is_parallel:
|
||||
tp_rank = get_tensor_model_parallel_rank()
|
||||
shard_size = param_data.shape[0]
|
||||
start_idx = tp_rank * shard_size
|
||||
loaded_weight = loaded_weight.narrow(0, start_idx, shard_size)
|
||||
assert param_data.shape == loaded_weight.shape
|
||||
param_data.copy_(loaded_weight)
|
||||
|
||||
|
||||
_ACTIVATION_REGISTRY = LazyDict({
|
||||
"gelu":
|
||||
lambda: nn.GELU(),
|
||||
"gelu_fast":
|
||||
lambda: FastGELU(),
|
||||
"gelu_new":
|
||||
lambda: NewGELU(),
|
||||
"gelu_pytorch_tanh":
|
||||
lambda: nn.GELU(approximate="tanh"),
|
||||
"relu":
|
||||
lambda: nn.ReLU(),
|
||||
"relu2":
|
||||
lambda: ReLUSquaredActivation(),
|
||||
"silu":
|
||||
lambda: nn.SiLU(),
|
||||
"quick_gelu":
|
||||
lambda: QuickGELU(),
|
||||
})
|
||||
|
||||
|
||||
def get_act_fn(
|
||||
act_fn_name: str,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
intermediate_size: Optional[int] = None,
|
||||
input_is_parallel: bool = True,
|
||||
params_dtype: Optional[torch.dtype] = None,
|
||||
) -> nn.Module:
|
||||
"""Get an activation function by name."""
|
||||
act_fn_name = act_fn_name.lower()
|
||||
# print(f"activation function name: {act_fn_name}")
|
||||
if act_fn_name not in _ACTIVATION_REGISTRY:
|
||||
raise ValueError(
|
||||
f"Activation function {act_fn_name!r} is not supported.")
|
||||
|
||||
act_fn = _ACTIVATION_REGISTRY[act_fn_name]
|
||||
if (quant_config is not None
|
||||
and act_fn_name in quant_config.get_scaled_act_names()):
|
||||
if intermediate_size is None:
|
||||
raise ValueError("intermediate_size must be specified for scaled "
|
||||
"activation functions.")
|
||||
return ScaledActivation(act_fn, intermediate_size, input_is_parallel,
|
||||
params_dtype)
|
||||
return act_fn
|
||||
|
||||
_ACTIVATION_AND_MUL_REGISTRY = LazyDict({
|
||||
"gelu": lambda: GeluAndMul(),
|
||||
"silu": lambda: SiluAndMul(),
|
||||
"geglu": lambda: GeluAndMul(),
|
||||
})
|
||||
|
||||
|
||||
def get_act_and_mul_fn(act_fn_name: str) -> nn.Module:
|
||||
"""Get an activation-and-mul (i.e. SiluAndMul) function by name."""
|
||||
act_fn_name = act_fn_name.lower()
|
||||
if act_fn_name not in _ACTIVATION_AND_MUL_REGISTRY:
|
||||
raise ValueError(
|
||||
f"Activation function {act_fn_name!r} is not supported.")
|
||||
|
||||
return _ACTIVATION_AND_MUL_REGISTRY[act_fn_name]
|
||||
|
||||
@@ -1,55 +1,28 @@
|
||||
#
|
||||
# Copyright (c) 2025 Baidu, Inc. All Rights Reserved.
|
||||
# Author: Bao Qian, Dong Xinyu, Chen Zhennan, Ma Tianyu
|
||||
# Email: baoqian@baidu.com
|
||||
# This file is a part of the vllm-kunlun project.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""kunlun attention wrapper for context and decode"""
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, List, Optional, Tuple, Type, TYPE_CHECKING
|
||||
|
||||
import torch
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.worker.model_runner import ModelInputForGPUBuilder
|
||||
from itertools import accumulate
|
||||
from vllm.attention.backends.abstract import (
|
||||
AttentionBackend,
|
||||
AttentionImpl,
|
||||
AttentionMetadata,
|
||||
AttentionType,
|
||||
)
|
||||
from .utils import CommonAttentionState, CommonMetadataBuilder
|
||||
from vllm.attention.backends.utils import (
|
||||
is_block_tables_empty,
|
||||
compute_slot_mapping_start_idx,
|
||||
compute_slot_mapping,
|
||||
)
|
||||
from vllm_kunlun.ops.paged_attn import PagedAttention, PagedAttentionMetadata
|
||||
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
|
||||
AttentionMetadata, AttentionType)
|
||||
from .utils import (CommonAttentionState, CommonMetadataBuilder)
|
||||
from vllm.attention.backends.utils import (is_block_tables_empty,
|
||||
compute_slot_mapping_start_idx, compute_slot_mapping)
|
||||
from vllm_kunlun.ops.paged_attn import (PagedAttention,
|
||||
PagedAttentionMetadata)
|
||||
from vllm_kunlun.ops._kunlun_ops import KunlunOps
|
||||
from vllm.attention.backends.abstract import AttentionLayer
|
||||
from vllm.logger import init_logger
|
||||
from vllm.utils import async_tensor_h2d
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class KunlunAttentionBackend(AttentionBackend):
|
||||
"""KunlunAttentionBackend"""
|
||||
|
||||
accept_output_buffer = False
|
||||
|
||||
@staticmethod
|
||||
def get_name() -> str:
|
||||
return "KUNLUN_ATTENTION"
|
||||
@@ -80,9 +53,8 @@ class KunlunAttentionBackend(AttentionBackend):
|
||||
num_kv_heads: int,
|
||||
head_size: int,
|
||||
) -> Tuple[int, ...]:
|
||||
return PagedAttention.get_kv_cache_shape(
|
||||
num_blocks, block_size, num_kv_heads, head_size
|
||||
)
|
||||
return PagedAttention.get_kv_cache_shape(num_blocks, block_size,
|
||||
num_kv_heads, head_size)
|
||||
|
||||
@staticmethod
|
||||
def swap_blocks(
|
||||
@@ -182,6 +154,7 @@ class KunlunMetadata(AttentionMetadata, PagedAttentionMetadata):
|
||||
|
||||
seq_lens_tensor_cpu: Optional[torch.Tensor] = None
|
||||
|
||||
|
||||
def __post_init__(self):
|
||||
# Set during the execution of the first attention op.
|
||||
# It is a list because it is needed to set per prompt
|
||||
@@ -194,27 +167,23 @@ class KunlunMetadata(AttentionMetadata, PagedAttentionMetadata):
|
||||
|
||||
@property
|
||||
def is_all_encoder_attn_metadata_set(self):
|
||||
"""
|
||||
'''
|
||||
All attention metadata required for encoder attention is set.
|
||||
"""
|
||||
return (
|
||||
(self.encoder_seq_lens is not None)
|
||||
and (self.encoder_seq_lens_tensor is not None)
|
||||
and (self.max_encoder_seq_len is not None)
|
||||
)
|
||||
'''
|
||||
return ((self.encoder_seq_lens is not None)
|
||||
and (self.encoder_seq_lens_tensor is not None)
|
||||
and (self.max_encoder_seq_len is not None))
|
||||
|
||||
@property
|
||||
def is_all_cross_attn_metadata_set(self):
|
||||
"""
|
||||
'''
|
||||
All attention metadata required for enc/dec cross-attention is set.
|
||||
|
||||
Superset of encoder attention required metadata.
|
||||
"""
|
||||
return (
|
||||
self.is_all_encoder_attn_metadata_set
|
||||
and (self.cross_slot_mapping is not None)
|
||||
and (self.cross_block_tables is not None)
|
||||
)
|
||||
'''
|
||||
return (self.is_all_encoder_attn_metadata_set
|
||||
and (self.cross_slot_mapping is not None)
|
||||
and (self.cross_block_tables is not None))
|
||||
|
||||
@property
|
||||
def prefill_metadata(self) -> Optional["KunlunMetadata"]:
|
||||
@@ -227,60 +196,43 @@ class KunlunMetadata(AttentionMetadata, PagedAttentionMetadata):
|
||||
# metadata structure
|
||||
return self._cached_prefill_metadata
|
||||
|
||||
assert (self.seq_lens is not None) or (self.encoder_seq_lens is not None)
|
||||
assert (self.seq_lens_tensor is not None) or (
|
||||
self.encoder_seq_lens_tensor is not None
|
||||
)
|
||||
assert ((self.seq_lens is not None)
|
||||
or (self.encoder_seq_lens is not None))
|
||||
assert ((self.seq_lens_tensor is not None)
|
||||
or (self.encoder_seq_lens_tensor is not None))
|
||||
|
||||
# Compute some attn_metadata fields which default to None
|
||||
query_start_loc = (
|
||||
None
|
||||
if self.query_start_loc is None
|
||||
else self.query_start_loc[: self.num_prefills + 1]
|
||||
)
|
||||
query_start_loc = (None if self.query_start_loc is None else
|
||||
self.query_start_loc[:self.num_prefills + 1])
|
||||
# flash attention needs both lod information on host and device
|
||||
query_start_loc_host = (
|
||||
None
|
||||
if self.query_start_loc_host is None
|
||||
else self.query_start_loc_host[: self.num_prefills + 1]
|
||||
)
|
||||
kv_prefix_start_loc_host = (
|
||||
None
|
||||
if self.kv_prefix_start_loc_host is None
|
||||
else self.kv_prefix_start_loc_host[: self.num_prefills + 1]
|
||||
+ query_start_loc_host
|
||||
)
|
||||
kv_prefix_start_loc = (
|
||||
None
|
||||
if kv_prefix_start_loc_host is None
|
||||
else kv_prefix_start_loc_host.cuda()
|
||||
)
|
||||
slot_mapping = (
|
||||
None
|
||||
if self.slot_mapping is None
|
||||
else self.slot_mapping[: self.num_prefill_tokens]
|
||||
)
|
||||
seq_lens = None if self.seq_lens is None else self.seq_lens[: self.num_prefills]
|
||||
seq_lens_tensor = (
|
||||
None
|
||||
if self.seq_lens_tensor is None
|
||||
else self.seq_lens_tensor[: self.num_prefills]
|
||||
)
|
||||
context_lens_tensor = (
|
||||
None
|
||||
if self.context_lens_tensor is None
|
||||
else self.context_lens_tensor[: self.num_prefills]
|
||||
)
|
||||
query_start_loc_host = (None if self.query_start_loc_host is None else
|
||||
self.query_start_loc_host[:self.num_prefills + 1])
|
||||
kv_prefix_start_loc_host = (None if self.kv_prefix_start_loc_host is None else
|
||||
self.kv_prefix_start_loc_host[:self.num_prefills + 1] + query_start_loc_host)
|
||||
kv_prefix_start_loc = (None if kv_prefix_start_loc_host is None else kv_prefix_start_loc_host.cuda())
|
||||
slot_mapping = (None if self.slot_mapping is None else
|
||||
self.slot_mapping[:self.num_prefill_tokens])
|
||||
seq_lens = (None if self.seq_lens is None else
|
||||
self.seq_lens[:self.num_prefills])
|
||||
seq_lens_tensor = (None if self.seq_lens_tensor is None else
|
||||
self.seq_lens_tensor[:self.num_prefills])
|
||||
context_lens_tensor = (None if self.context_lens_tensor is None else
|
||||
self.context_lens_tensor[:self.num_prefills])
|
||||
# for prefix cache, block table only contains blocks that hit
|
||||
# if self.block_tables is None:
|
||||
# block_tables = None
|
||||
# elif self.block_tables.shape[1] == 0:
|
||||
# block_tables = self.block_tables[:self.num_prefills]
|
||||
# else:
|
||||
# block_tables = self.block_tables[:self.num_prefills][:, -1].clone()
|
||||
|
||||
block_tables = (
|
||||
None
|
||||
if self.block_tables is None
|
||||
else self.block_tables[: self.num_prefills]
|
||||
)
|
||||
block_tables = (None if self.block_tables is None else
|
||||
self.block_tables[:self.num_prefills])
|
||||
|
||||
# Construct & cache prefill-phase attention metadata structure
|
||||
self._cached_prefill_metadata = KunlunMetadata(
|
||||
multi_modal_placeholder_index_maps=self.multi_modal_placeholder_index_maps,
|
||||
multi_modal_placeholder_index_maps=self.
|
||||
multi_modal_placeholder_index_maps,
|
||||
num_prefills=self.num_prefills,
|
||||
num_prefill_tokens=self.num_prefill_tokens,
|
||||
num_decode_tokens=0,
|
||||
@@ -305,8 +257,7 @@ class KunlunMetadata(AttentionMetadata, PagedAttentionMetadata):
|
||||
cross_slot_mapping=self.cross_slot_mapping,
|
||||
cross_block_tables=self.cross_block_tables,
|
||||
enable_kv_scales_calculation=False,
|
||||
seq_start_loc=self.seq_start_loc,
|
||||
)
|
||||
seq_start_loc=self.seq_start_loc)
|
||||
return self._cached_prefill_metadata
|
||||
|
||||
@property
|
||||
@@ -319,35 +270,25 @@ class KunlunMetadata(AttentionMetadata, PagedAttentionMetadata):
|
||||
# Recover cached decode-phase attention
|
||||
# metadata structure
|
||||
return self._cached_decode_metadata
|
||||
assert (self.seq_lens_tensor is not None) or (
|
||||
self.encoder_seq_lens_tensor is not None
|
||||
)
|
||||
assert ((self.seq_lens_tensor is not None)
|
||||
or (self.encoder_seq_lens_tensor is not None))
|
||||
|
||||
# Compute some attn_metadata fields which default to None
|
||||
slot_mapping = (
|
||||
None
|
||||
if self.slot_mapping is None
|
||||
else self.slot_mapping[self.num_prefill_tokens :]
|
||||
)
|
||||
seq_lens_tensor = (
|
||||
None
|
||||
if self.seq_lens_tensor is None
|
||||
else self.seq_lens_tensor[self.num_prefills :]
|
||||
)
|
||||
seq_lens_tensor_cpu = (
|
||||
None
|
||||
if self.seq_lens_tensor_cpu is None
|
||||
else self.seq_lens_tensor_cpu[self.num_prefills :]
|
||||
)
|
||||
block_tables = (
|
||||
None
|
||||
if self.block_tables is None
|
||||
else self.block_tables[self.num_prefills :]
|
||||
)
|
||||
slot_mapping = (None if self.slot_mapping is None else
|
||||
self.slot_mapping[self.num_prefill_tokens:])
|
||||
seq_lens_tensor = (None if self.seq_lens_tensor is None else
|
||||
self.seq_lens_tensor[self.num_prefills:])
|
||||
seq_lens_tensor_cpu = (None if self.seq_lens_tensor_cpu is None else
|
||||
self.seq_lens_tensor_cpu[self.num_prefills:])
|
||||
block_tables = (None if self.block_tables is None else
|
||||
self.block_tables[self.num_prefills:])
|
||||
|
||||
|
||||
|
||||
# Construct & cache decode-phase attention metadata structure
|
||||
self._cached_decode_metadata = KunlunMetadata(
|
||||
multi_modal_placeholder_index_maps=self.multi_modal_placeholder_index_maps,
|
||||
multi_modal_placeholder_index_maps=self.
|
||||
multi_modal_placeholder_index_maps,
|
||||
num_prefills=0,
|
||||
num_prefill_tokens=0,
|
||||
num_decode_tokens=self.num_decode_tokens,
|
||||
@@ -364,16 +305,13 @@ class KunlunMetadata(AttentionMetadata, PagedAttentionMetadata):
|
||||
max_encoder_seq_len=self.max_encoder_seq_len,
|
||||
cross_slot_mapping=self.cross_slot_mapping,
|
||||
cross_block_tables=self.cross_block_tables,
|
||||
enable_kv_scales_calculation=False,
|
||||
)
|
||||
enable_kv_scales_calculation=False)
|
||||
return self._cached_decode_metadata
|
||||
|
||||
|
||||
class KunlunMetadataBuilder(CommonMetadataBuilder[KunlunMetadata]):
|
||||
"""KunlunMetadataBuilder"""
|
||||
|
||||
_metadata_cls = KunlunMetadata
|
||||
|
||||
def __init__(self, input_builder: "ModelInputForGPUBuilder"):
|
||||
super().__init__(input_builder)
|
||||
self.prefix_cache_kv_lens: List[int] = []
|
||||
@@ -382,120 +320,90 @@ class KunlunMetadataBuilder(CommonMetadataBuilder[KunlunMetadata]):
|
||||
"""prepare"""
|
||||
super().prepare()
|
||||
self.prefix_cache_kv_lens = list()
|
||||
|
||||
def _add_seq_group(
|
||||
self,
|
||||
inter_data: "ModelInputForGPUBuilder.InterDataForSeqGroup",
|
||||
chunked_prefill_enabled: bool,
|
||||
):
|
||||
self, inter_data: "ModelInputForGPUBuilder.InterDataForSeqGroup",
|
||||
chunked_prefill_enabled: bool):
|
||||
is_prompt = inter_data.is_prompt
|
||||
block_tables = inter_data.block_tables
|
||||
|
||||
for (
|
||||
seq_id,
|
||||
token_len,
|
||||
seq_len,
|
||||
curr_seq_len,
|
||||
query_len,
|
||||
context_len,
|
||||
curr_sliding_window_block,
|
||||
) in zip(
|
||||
inter_data.seq_ids,
|
||||
[len(t) for t in inter_data.input_tokens],
|
||||
inter_data.orig_seq_lens,
|
||||
inter_data.seq_lens,
|
||||
inter_data.query_lens,
|
||||
inter_data.context_lens,
|
||||
inter_data.curr_sliding_window_blocks,
|
||||
):
|
||||
for (seq_id, token_len, seq_len, curr_seq_len, query_len, context_len,
|
||||
curr_sliding_window_block) in zip(
|
||||
inter_data.seq_ids, [len(t) for t in inter_data.input_tokens],
|
||||
inter_data.orig_seq_lens, inter_data.seq_lens,
|
||||
inter_data.query_lens, inter_data.context_lens,
|
||||
inter_data.curr_sliding_window_blocks):
|
||||
self.context_lens.append(context_len)
|
||||
if is_prompt:
|
||||
mm_maps = inter_data.multi_modal_placeholder_maps
|
||||
if mm_maps:
|
||||
for modality, placeholders in mm_maps.items():
|
||||
self.multimodal_placeholder_maps[modality].extend(placeholders)
|
||||
self.multimodal_placeholder_maps[modality].extend(
|
||||
placeholders)
|
||||
|
||||
self.num_prefills += 1
|
||||
self.num_prefill_tokens += token_len
|
||||
self.prefill_seq_lens.append(seq_len)
|
||||
else:
|
||||
assert (
|
||||
query_len == 1
|
||||
), "seq_len: {}, context_len: {}, query_len: {}".format(
|
||||
seq_len, context_len, query_len
|
||||
)
|
||||
assert query_len == 1, (
|
||||
"seq_len: {}, context_len: {}, query_len: {}".format(
|
||||
seq_len, context_len, query_len))
|
||||
self.num_decode_tokens += query_len
|
||||
self.curr_seq_lens.append(curr_seq_len)
|
||||
|
||||
# Compute block table.
|
||||
block_table = []
|
||||
assert (
|
||||
not chunked_prefill_enabled
|
||||
), "chunk prefill not supported for kunlun attention"
|
||||
assert not chunked_prefill_enabled, "chunk prefill not supported for kunlun attention"
|
||||
if inter_data.prefix_cache_hit:
|
||||
assert context_len != 0
|
||||
assert context_len % self.block_size == 0
|
||||
block_table = block_tables[seq_id][: context_len // self.block_size]
|
||||
elif (not is_prompt) and block_tables is not None:
|
||||
# block_table = block_tables[seq_id]
|
||||
block_table = block_tables[seq_id][:context_len // self.block_size]
|
||||
elif ((not is_prompt)
|
||||
and block_tables is not None):
|
||||
if curr_sliding_window_block == 0:
|
||||
block_table = block_tables[seq_id]
|
||||
else:
|
||||
block_table = block_tables[seq_id][-curr_sliding_window_block:]
|
||||
block_table = block_tables[seq_id][
|
||||
-curr_sliding_window_block:]
|
||||
self.block_tables.append(block_table)
|
||||
if is_prompt:
|
||||
self.prefix_cache_kv_lens.append(context_len)
|
||||
|
||||
# Compute slot mapping.
|
||||
is_profile_run = is_block_tables_empty(block_tables)
|
||||
start_idx = compute_slot_mapping_start_idx(
|
||||
is_prompt, query_len, context_len, self.sliding_window
|
||||
)
|
||||
compute_slot_mapping(
|
||||
is_profile_run,
|
||||
self.slot_mapping,
|
||||
seq_id,
|
||||
seq_len,
|
||||
context_len,
|
||||
start_idx,
|
||||
self.block_size,
|
||||
inter_data.block_tables,
|
||||
)
|
||||
start_idx = compute_slot_mapping_start_idx(is_prompt, query_len,
|
||||
context_len,
|
||||
self.sliding_window)
|
||||
compute_slot_mapping(is_profile_run, self.slot_mapping, seq_id,
|
||||
seq_len, context_len, start_idx,
|
||||
self.block_size, inter_data.block_tables)
|
||||
|
||||
def build(
|
||||
self,
|
||||
seq_lens: List[int],
|
||||
query_lens: List[int],
|
||||
cuda_graph_pad_size: int,
|
||||
batch_size: int,
|
||||
):
|
||||
|
||||
def build(self, seq_lens: List[int], query_lens: List[int],
|
||||
cuda_graph_pad_size: int, batch_size: int):
|
||||
"""build"""
|
||||
attn_meta = super().build(seq_lens, query_lens, cuda_graph_pad_size, batch_size)
|
||||
query_start_loc = list(accumulate(query_lens, initial=0))
|
||||
query_start_loc_host = torch.tensor(
|
||||
query_start_loc, dtype=torch.int32, device="cpu"
|
||||
)
|
||||
query_start_loc_host = torch.tensor(query_start_loc, dtype=torch.int32, device='cpu')
|
||||
attn_meta.query_start_loc_host = query_start_loc_host
|
||||
# max_kv_len = max(query_lens + prefix_cache_kv_lens)
|
||||
attn_meta.max_kv_len = max(self.prefix_cache_kv_lens + attn_meta.seq_lens)
|
||||
|
||||
# If kv cache is included and there is a hit
|
||||
# 包含kv cache ,且存在命中的情况
|
||||
if len(self.prefix_cache_kv_lens) != 0 and max(self.prefix_cache_kv_lens) != 0:
|
||||
self.prefix_cache_kv_lens = list(
|
||||
accumulate(self.prefix_cache_kv_lens, initial=0)
|
||||
)
|
||||
prefix_cache_kv_lens_tensor = torch.tensor(
|
||||
self.prefix_cache_kv_lens, dtype=torch.int32, device="cpu"
|
||||
)
|
||||
self.prefix_cache_kv_lens = list(accumulate(self.prefix_cache_kv_lens, initial=0))
|
||||
prefix_cache_kv_lens_tensor = torch.tensor(self.prefix_cache_kv_lens, dtype=torch.int32, device="cpu")
|
||||
attn_meta.kv_prefix_start_loc_host = prefix_cache_kv_lens_tensor
|
||||
attn_meta.seq_lens_tensor_cpu = attn_meta.seq_lens_tensor.to("cpu")
|
||||
return attn_meta
|
||||
|
||||
|
||||
|
||||
def _get_seq_len_block_table_args(
|
||||
attn_metadata: KunlunMetadata,
|
||||
is_prompt: bool,
|
||||
attn_type: AttentionType,
|
||||
) -> tuple:
|
||||
"""
|
||||
'''
|
||||
The particular choice of sequence-length- and block-table-related
|
||||
attributes which should be extracted from attn_metadata is dependent
|
||||
on the type of attention operation.
|
||||
@@ -517,7 +425,7 @@ def _get_seq_len_block_table_args(
|
||||
* Appropriate sequence-lengths tensor
|
||||
* Appropriate max sequence-length scalar
|
||||
* Appropriate block tables (or None)
|
||||
"""
|
||||
'''
|
||||
|
||||
if attn_type == AttentionType.DECODER:
|
||||
# Decoder self-attention
|
||||
@@ -526,26 +434,23 @@ def _get_seq_len_block_table_args(
|
||||
max_seq_len = attn_metadata.max_prefill_seq_len
|
||||
else:
|
||||
max_seq_len = attn_metadata.max_decode_seq_len
|
||||
return (attn_metadata.seq_lens_tensor, max_seq_len, attn_metadata.block_tables)
|
||||
return (attn_metadata.seq_lens_tensor, max_seq_len,
|
||||
attn_metadata.block_tables)
|
||||
elif attn_type == AttentionType.ENCODER_DECODER:
|
||||
# Enc/dec cross-attention KVs match encoder sequence length;
|
||||
# cross-attention utilizes special "cross" block tables
|
||||
return (
|
||||
attn_metadata.encoder_seq_lens_tensor,
|
||||
attn_metadata.max_encoder_seq_len,
|
||||
attn_metadata.cross_block_tables,
|
||||
)
|
||||
return (attn_metadata.encoder_seq_lens_tensor,
|
||||
attn_metadata.max_encoder_seq_len,
|
||||
attn_metadata.cross_block_tables)
|
||||
elif attn_type == AttentionType.ENCODER:
|
||||
# No block tables associated with encoder attention
|
||||
return (
|
||||
attn_metadata.encoder_seq_lens_tensor,
|
||||
attn_metadata.max_encoder_seq_len,
|
||||
None,
|
||||
)
|
||||
return (attn_metadata.encoder_seq_lens_tensor,
|
||||
attn_metadata.max_encoder_seq_len, None)
|
||||
else:
|
||||
raise AttributeError(f"Invalid attention type {str(attn_type)}")
|
||||
|
||||
|
||||
|
||||
class KunlunAttentionImpl(AttentionImpl[KunlunMetadata]):
|
||||
"""KunlunAttentionImpl"""
|
||||
|
||||
@@ -564,7 +469,8 @@ class KunlunAttentionImpl(AttentionImpl[KunlunMetadata]):
|
||||
kv_sharing_target_layer_name: Optional[str] = None,
|
||||
) -> None:
|
||||
if blocksparse_params is not None:
|
||||
raise ValueError("kunlunAttention does not support block-sparse attention.")
|
||||
raise ValueError(
|
||||
"kunlunAttention does not support block-sparse attention.")
|
||||
# if logits_soft_cap is not None:
|
||||
# raise ValueError(
|
||||
# "kunlunAttention does not support attention logits soft capping.")
|
||||
@@ -585,8 +491,8 @@ class KunlunAttentionImpl(AttentionImpl[KunlunMetadata]):
|
||||
if head_size not in suppored_head_sizes:
|
||||
raise ValueError(
|
||||
f"Head size {head_size} is not supported by PagedAttention. "
|
||||
f"Supported head sizes are: {suppored_head_sizes}."
|
||||
)
|
||||
f"Supported head sizes are: {suppored_head_sizes}.")
|
||||
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@@ -654,21 +560,16 @@ class KunlunAttentionImpl(AttentionImpl[KunlunMetadata]):
|
||||
|
||||
# Check that appropriate attention metadata attributes are
|
||||
# selected for the desired attention type
|
||||
if attn_type == AttentionType.ENCODER and (
|
||||
not attn_metadata.is_all_encoder_attn_metadata_set
|
||||
):
|
||||
raise AttributeError(
|
||||
"Encoder attention requires setting " "encoder metadata attributes."
|
||||
)
|
||||
if (attn_type == AttentionType.ENCODER
|
||||
and (not attn_metadata.is_all_encoder_attn_metadata_set)):
|
||||
raise AttributeError("Encoder attention requires setting "
|
||||
"encoder metadata attributes.")
|
||||
|
||||
elif attn_type == AttentionType.ENCODER_DECODER and (
|
||||
not attn_metadata.is_all_cross_attn_metadata_set
|
||||
):
|
||||
raise AttributeError(
|
||||
"Encoder/decoder cross-attention "
|
||||
"requires setting cross-attention "
|
||||
"metadata attributes."
|
||||
)
|
||||
elif (attn_type == AttentionType.ENCODER_DECODER
|
||||
and (not attn_metadata.is_all_cross_attn_metadata_set)):
|
||||
raise AttributeError("Encoder/decoder cross-attention "
|
||||
"requires setting cross-attention "
|
||||
"metadata attributes.")
|
||||
|
||||
query = query.view(-1, self.num_heads, self.head_size)
|
||||
if key is not None:
|
||||
@@ -682,7 +583,7 @@ class KunlunAttentionImpl(AttentionImpl[KunlunMetadata]):
|
||||
# which KV cache memory-mapping & which
|
||||
# seqlen datastructures we utilize
|
||||
|
||||
if attn_type != AttentionType.ENCODER and kv_cache.numel() > 0:
|
||||
if (attn_type != AttentionType.ENCODER and kv_cache.numel() > 0):
|
||||
# KV-cache during decoder-self- or
|
||||
# encoder-decoder-cross-attention, but not
|
||||
# during encoder attention.
|
||||
@@ -691,8 +592,7 @@ class KunlunAttentionImpl(AttentionImpl[KunlunMetadata]):
|
||||
# we still need to break out key_cache and value_cache
|
||||
# i.e. for later use by paged attention
|
||||
key_cache, value_cache = PagedAttention.split_kv_cache(
|
||||
kv_cache, self.num_kv_heads, self.head_size
|
||||
)
|
||||
kv_cache, self.num_kv_heads, self.head_size)
|
||||
|
||||
if (key is not None) and (value is not None):
|
||||
|
||||
@@ -701,14 +601,10 @@ class KunlunAttentionImpl(AttentionImpl[KunlunMetadata]):
|
||||
else:
|
||||
updated_slot_mapping = attn_metadata.slot_mapping
|
||||
value = value.contiguous()
|
||||
KunlunOps.reshape_and_cache(
|
||||
key,
|
||||
value,
|
||||
key_cache,
|
||||
value_cache,
|
||||
updated_slot_mapping,
|
||||
self.kv_cache_dtype,
|
||||
)
|
||||
KunlunOps.reshape_and_cache(key, value, key_cache,
|
||||
value_cache,
|
||||
updated_slot_mapping,
|
||||
self.kv_cache_dtype)
|
||||
|
||||
if attn_type == AttentionType.ENCODER:
|
||||
# Encoder attention - chunked prefill is not applicable;
|
||||
@@ -753,20 +649,14 @@ class KunlunAttentionImpl(AttentionImpl[KunlunMetadata]):
|
||||
# Prompt run.
|
||||
if kv_cache.numel() == 0 or prefill_meta.block_tables.numel() == 0:
|
||||
out = KunlunOps.multi_query_kv_attention(
|
||||
prefill_meta.query_start_loc,
|
||||
prefill_meta.query_start_loc_host,
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
alibi_slopes=self.alibi_slopes,
|
||||
).view_as(query)
|
||||
prefill_meta.query_start_loc,prefill_meta.query_start_loc_host, query, key, value,
|
||||
alibi_slopes=self.alibi_slopes).view_as(query)
|
||||
assert output[:num_prefill_tokens].shape == out.shape
|
||||
output[:num_prefill_tokens] = out
|
||||
|
||||
if decode_meta := attn_metadata.decode_metadata:
|
||||
assert (
|
||||
attn_type != AttentionType.ENCODER_ONLY
|
||||
), "Encoder-only models should not have decode metadata."
|
||||
assert attn_type != AttentionType.ENCODER_ONLY, (
|
||||
"Encoder-only models should not have decode metadata.")
|
||||
(
|
||||
seq_lens_arg,
|
||||
max_seq_len_arg,
|
||||
@@ -791,4 +681,4 @@ class KunlunAttentionImpl(AttentionImpl[KunlunMetadata]):
|
||||
)
|
||||
|
||||
# Reshape the output tensor.
|
||||
return output.view(-1, self.num_heads * self.head_size)
|
||||
return output.view(-1, self.num_heads * self.head_size)
|
||||
@@ -4,13 +4,12 @@ import torch
|
||||
import torch.nn.functional as F
|
||||
from typing import Optional, List, Dict, Any
|
||||
from vllm.attention import AttentionType
|
||||
from vllm.distributed.kv_transfer import (
|
||||
get_kv_transfer_group,
|
||||
has_kv_transfer_group,
|
||||
is_v1_kv_transfer_group,
|
||||
)
|
||||
from vllm.distributed.kv_transfer import (get_kv_transfer_group,
|
||||
has_kv_transfer_group,
|
||||
is_v1_kv_transfer_group)
|
||||
from vllm.config import CacheConfig
|
||||
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
|
||||
from vllm.model_executor.layers.quantization.base_config import (
|
||||
QuantizationConfig)
|
||||
|
||||
from vllm.forward_context import ForwardContext, get_forward_context
|
||||
|
||||
@@ -20,10 +19,8 @@ from torch.library import custom_op, impl
|
||||
|
||||
from vllm.platforms import _Backend
|
||||
|
||||
|
||||
class Attention(VllmAttention):
|
||||
"""Attention"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
num_heads: int,
|
||||
@@ -75,8 +72,11 @@ class Attention(VllmAttention):
|
||||
if attn_metadata.enable_kv_scales_calculation:
|
||||
self.calc_kv_scales(query, key, value)
|
||||
if self.use_output:
|
||||
output_shape = output_shape if output_shape is not None else query.shape
|
||||
output = torch.zeros(output_shape, dtype=query.dtype, device=query.device)
|
||||
output_shape = (output_shape
|
||||
if output_shape is not None else query.shape)
|
||||
output = torch.zeros(output_shape,
|
||||
dtype=query.dtype,
|
||||
device=query.device)
|
||||
hidden_size = output_shape[-1]
|
||||
# We skip reshaping query, key and value tensors for the MLA
|
||||
# backend since these tensors have different semantics and are
|
||||
@@ -97,13 +97,16 @@ class Attention(VllmAttention):
|
||||
if isinstance(attn_metadata, dict):
|
||||
attn_metadata = attn_metadata[self.layer_name]
|
||||
self_kv_cache = self.kv_cache[forward_context.virtual_engine]
|
||||
self.impl.forward(
|
||||
self, query, key, value, self_kv_cache, attn_metadata, output=output
|
||||
)
|
||||
self.impl.forward(self,
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
self_kv_cache,
|
||||
attn_metadata,
|
||||
output=output)
|
||||
else:
|
||||
torch.ops.vllm.unified_attention_with_output_kunlun(
|
||||
query, key, value, output, self.layer_name
|
||||
)
|
||||
query, key, value, output, self.layer_name)
|
||||
return output.view(-1, hidden_size)
|
||||
else:
|
||||
if self.use_direct_call:
|
||||
@@ -112,15 +115,13 @@ class Attention(VllmAttention):
|
||||
if isinstance(attn_metadata, dict):
|
||||
attn_metadata = attn_metadata[self.layer_name]
|
||||
self_kv_cache = self.kv_cache[forward_context.virtual_engine]
|
||||
return self.impl.forward(
|
||||
self, query, key, value, self_kv_cache, attn_metadata
|
||||
)
|
||||
return self.impl.forward(self, query, key, value,
|
||||
self_kv_cache, attn_metadata)
|
||||
else:
|
||||
return unified_attention(query, key, value, self.layer_name)
|
||||
return unified_attention(
|
||||
query, key, value, self.layer_name)
|
||||
|
||||
|
||||
#
|
||||
# Rewritten from the MultiHeadAttention class in vllm.attention.layer
|
||||
# 重写自 vllm.attention.layer 中的 MultiHeadAttention 类
|
||||
class MultiHeadAttention(VllmMultiHeadAttention):
|
||||
def __init__(
|
||||
self,
|
||||
@@ -130,15 +131,14 @@ class MultiHeadAttention(VllmMultiHeadAttention):
|
||||
num_kv_heads: Optional[int] = None,
|
||||
):
|
||||
super().__init__(
|
||||
num_heads=num_heads,
|
||||
head_size=head_size,
|
||||
scale=scale,
|
||||
num_kv_heads=num_kv_heads,
|
||||
num_heads = num_heads,
|
||||
head_size = head_size,
|
||||
scale = scale,
|
||||
num_kv_heads = num_kv_heads,
|
||||
)
|
||||
|
||||
# kunlun only supports flash_attn
|
||||
# kunlun只支持flash_attn
|
||||
self.attn_backend = _Backend.FLASH_ATTN
|
||||
|
||||
|
||||
def forward(
|
||||
self,
|
||||
query: torch.Tensor,
|
||||
@@ -159,31 +159,34 @@ class MultiHeadAttention(VllmMultiHeadAttention):
|
||||
key = torch.repeat_interleave(key, num_repeat, dim=2)
|
||||
value = torch.repeat_interleave(value, num_repeat, dim=2)
|
||||
|
||||
# kunlun only supports flash_attn
|
||||
# kunlun只支持flash_attn
|
||||
if self.attn_backend == _Backend.FLASH_ATTN:
|
||||
from flash_attn import flash_attn_func
|
||||
|
||||
out = flash_attn_func(query, key, value, softmax_scale=self.scale)
|
||||
elif self.attn_backend == _Backend.XFORMERS:
|
||||
from xformers import ops as xops
|
||||
|
||||
out = xops.memory_efficient_attention_forward(
|
||||
query, key, value, scale=self.scale
|
||||
)
|
||||
out = xops.memory_efficient_attention_forward(query,
|
||||
key,
|
||||
value,
|
||||
scale=self.scale)
|
||||
elif self.attn_backend == _Backend.TORCH_SDPA:
|
||||
query, key, value = (x.transpose(1, 2) for x in (query, key, value))
|
||||
out = F.scaled_dot_product_attention(query, key, value, scale=self.scale)
|
||||
query, key, value = (x.transpose(1, 2)
|
||||
for x in (query, key, value))
|
||||
out = F.scaled_dot_product_attention(query,
|
||||
key,
|
||||
value,
|
||||
scale=self.scale)
|
||||
out = out.transpose(1, 2)
|
||||
elif self.attn_backend == _Backend.PALLAS_VLLM_V1:
|
||||
query, key, value = (x.transpose(1, 2) for x in (query, key, value))
|
||||
query, key, value = (x.transpose(1, 2)
|
||||
for x in (query, key, value))
|
||||
from torch_xla.experimental.custom_kernel import flash_attention
|
||||
|
||||
out = flash_attention(query, key, value, sm_scale=self.scale)
|
||||
out = out.transpose(1, 2)
|
||||
|
||||
return out.reshape(bsz, q_len, -1)
|
||||
|
||||
|
||||
def wait_for_kv_layer_from_connector(layer_name: str):
|
||||
"""wait_for_kv_layer_from_connector"""
|
||||
if not has_kv_transfer_group() or not is_v1_kv_transfer_group():
|
||||
@@ -198,10 +201,9 @@ def wait_for_kv_layer_from_connector(layer_name: str):
|
||||
assert isinstance(attn_metadata, dict)
|
||||
connector.wait_for_layer_load(layer_name)
|
||||
|
||||
|
||||
def maybe_save_kv_layer_to_connector(
|
||||
layer_name: str, kv_cache_layer: List[torch.Tensor]
|
||||
):
|
||||
layer_name: str,
|
||||
kv_cache_layer: List[torch.Tensor]):
|
||||
"""maybe_save_kv_layer_to_connector"""
|
||||
if not has_kv_transfer_group() or not is_v1_kv_transfer_group():
|
||||
return
|
||||
@@ -213,8 +215,8 @@ def maybe_save_kv_layer_to_connector(
|
||||
if attn_metadata is None:
|
||||
return
|
||||
assert isinstance(attn_metadata, dict)
|
||||
connector.save_kv_layer(layer_name, kv_cache_layer, attn_metadata[layer_name])
|
||||
|
||||
connector.save_kv_layer(layer_name, kv_cache_layer,
|
||||
attn_metadata[layer_name])
|
||||
|
||||
@custom_op("vllm::unified_attention_with_output_kunlun", mutates_args=())
|
||||
def unified_attention_with_output_kunlun(
|
||||
@@ -223,8 +225,7 @@ def unified_attention_with_output_kunlun(
|
||||
value: torch.Tensor,
|
||||
output: torch.Tensor,
|
||||
layer_name: str,
|
||||
output_scale: Optional[torch.Tensor] = None,
|
||||
) -> None:
|
||||
output_scale: Optional[torch.Tensor] = None,) -> None:
|
||||
wait_for_kv_layer_from_connector(layer_name)
|
||||
forward_context: ForwardContext = get_forward_context()
|
||||
attn_metadata = forward_context.attn_metadata
|
||||
@@ -232,26 +233,26 @@ def unified_attention_with_output_kunlun(
|
||||
attn_metadata = attn_metadata[layer_name]
|
||||
self = forward_context.no_compile_layers[layer_name]
|
||||
kv_cache = self.kv_cache[forward_context.virtual_engine]
|
||||
self.impl.forward(self, query, key, value, kv_cache, attn_metadata, output=output)
|
||||
self.impl.forward(self,
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
kv_cache,
|
||||
attn_metadata,
|
||||
output=output)
|
||||
|
||||
maybe_save_kv_layer_to_connector(layer_name, kv_cache)
|
||||
|
||||
|
||||
def _fake_unified_attention_with_output_kunlun(
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
output: torch.Tensor,
|
||||
layer_name: str,
|
||||
output_scale: Optional[torch.Tensor] = None,
|
||||
) -> None:
|
||||
output_scale: Optional[torch.Tensor] = None,) -> None:
|
||||
return None
|
||||
|
||||
|
||||
unified_attention_with_output_kunlun.register_fake(
|
||||
_fake_unified_attention_with_output_kunlun
|
||||
)
|
||||
|
||||
unified_attention_with_output_kunlun.register_fake(_fake_unified_attention_with_output_kunlun)
|
||||
|
||||
def unified_attention(
|
||||
query: torch.Tensor,
|
||||
@@ -268,7 +269,8 @@ def unified_attention(
|
||||
attn_metadata = attn_metadata[layer_name]
|
||||
self = forward_context.no_compile_layers[layer_name]
|
||||
kv_cache = self.kv_cache[forward_context.virtual_engine]
|
||||
output = self.impl.forward(self, query, key, value, kv_cache, attn_metadata)
|
||||
output = self.impl.forward(self, query, key, value, kv_cache,
|
||||
attn_metadata)
|
||||
|
||||
maybe_save_kv_layer_to_connector(layer_name, kv_cache)
|
||||
return output
|
||||
return output
|
||||
9
vllm_kunlun/ops/fla/__init__.py
Normal file
9
vllm_kunlun/ops/fla/__init__.py
Normal file
@@ -0,0 +1,9 @@
|
||||
from .chunk import chunk_gated_delta_rule
|
||||
from .fused_recurrent import fused_recurrent_gated_delta_rule
|
||||
from .layernorm_guard import RMSNormGated
|
||||
from .torch_fla import l2norm, torch_chunk_gated_delta_rule
|
||||
__all__ = [
|
||||
"RMSNormGated",
|
||||
"chunk_gated_delta_rule",
|
||||
"fused_recurrent_gated_delta_rule",
|
||||
]
|
||||
247
vllm_kunlun/ops/fla/chunk.py
Normal file
247
vllm_kunlun/ops/fla/chunk.py
Normal file
@@ -0,0 +1,247 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang
|
||||
#
|
||||
# This file contains code copied from the flash-linear-attention project.
|
||||
# The original source code was licensed under the MIT license and included
|
||||
# the following copyright notice:
|
||||
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
|
||||
# ruff: noqa: E501
|
||||
import warnings
|
||||
from typing import Optional
|
||||
import torch.nn.functional as F
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from einops import rearrange
|
||||
|
||||
from .chunk_delta_h import chunk_gated_delta_rule_fwd_h
|
||||
from .chunk_o import chunk_fwd_o
|
||||
from .chunk_scaled_dot_kkt import chunk_scaled_dot_kkt_fwd
|
||||
from .cumsum import chunk_local_cumsum
|
||||
from .l2norm import l2norm_fwd
|
||||
from .solve_tril import solve_tril
|
||||
from .utils import SUPPRESS_LEVEL, input_guard
|
||||
from .wy_fast import recompute_w_u_fwd
|
||||
|
||||
|
||||
def torch_solve_tril(A: torch.Tensor, cu_seqlens: Optional[torch.LongTensor] = None, output_dtype: torch.dtype = torch.float,):
|
||||
chunk_size=64
|
||||
A = A.transpose(1,2)
|
||||
sequence_length = A.shape[-2]
|
||||
pad_size = (chunk_size - sequence_length % chunk_size) % chunk_size
|
||||
A = F.pad(A, (0, 0, 0, pad_size))
|
||||
A = A.reshape(A.shape[0], A.shape[1], -1, chunk_size, A.shape[-1])
|
||||
mask = torch.triu(torch.ones(chunk_size, chunk_size, dtype=torch.bool, device=A.device), diagonal=0)
|
||||
|
||||
A = A.masked_fill(mask, 0)
|
||||
for i in range(1, chunk_size):
|
||||
row = A[..., i, :i].clone()
|
||||
sub = A[..., :i, :i].clone()
|
||||
A[..., i, :i] = row + (row.unsqueeze(-1) * sub).sum(-2)
|
||||
A = A + torch.eye(chunk_size, dtype=A.dtype, device=A.device)
|
||||
return A.reshape(A.shape[0], A.shape[1], -1, A.shape[-1])[:,:,:sequence_length,:].transpose(1,2)
|
||||
|
||||
def chunk_gated_delta_rule_fwd(q: torch.Tensor,
|
||||
k: torch.Tensor,
|
||||
v: torch.Tensor,
|
||||
g: torch.Tensor,
|
||||
beta: torch.Tensor,
|
||||
scale: float,
|
||||
initial_state: torch.Tensor,
|
||||
output_final_state: bool,
|
||||
cu_seqlens: Optional[torch.LongTensor] = None):
|
||||
g = chunk_local_cumsum(g, chunk_size=64, cu_seqlens=cu_seqlens)
|
||||
A = chunk_scaled_dot_kkt_fwd(k=k,
|
||||
beta=beta,
|
||||
g_cumsum=g,
|
||||
cu_seqlens=cu_seqlens,
|
||||
output_dtype=q.dtype)
|
||||
|
||||
#torch版
|
||||
for i in range(len(cu_seqlens)-1):
|
||||
A_i = A[:, cu_seqlens[i]:cu_seqlens[i+1], :, :]
|
||||
A[:, cu_seqlens[i]:cu_seqlens[i+1], :, :] = torch_solve_tril(A=A_i, cu_seqlens=torch.tensor([0, cu_seqlens[i+1]-cu_seqlens[i]], device=q.device), output_dtype=k.dtype)
|
||||
w, u = recompute_w_u_fwd(
|
||||
k=k,
|
||||
v=v,
|
||||
beta=beta,
|
||||
A=A,
|
||||
g_cumsum=g,
|
||||
cu_seqlens=cu_seqlens,
|
||||
)
|
||||
h, v_new, final_state = chunk_gated_delta_rule_fwd_h(
|
||||
k=k,
|
||||
w=w,
|
||||
u=u,
|
||||
g=g,
|
||||
initial_state=initial_state,
|
||||
output_final_state=output_final_state,
|
||||
cu_seqlens=cu_seqlens,
|
||||
)
|
||||
o = chunk_fwd_o(
|
||||
q=q,
|
||||
k=k,
|
||||
v=v_new,
|
||||
h=h,
|
||||
g=g,
|
||||
scale=scale,
|
||||
cu_seqlens=cu_seqlens,
|
||||
)
|
||||
if SUPPRESS_LEVEL < 3:
|
||||
return g, o, A, final_state, None, None, None
|
||||
elif SUPPRESS_LEVEL >= 3:
|
||||
return g, o, A, final_state, w, h, v_new
|
||||
|
||||
|
||||
class ChunkGatedDeltaRuleFunction(torch.autograd.Function):
|
||||
|
||||
@staticmethod
|
||||
@input_guard
|
||||
@torch.amp.custom_fwd(device_type='cuda')
|
||||
def forward(ctx,
|
||||
q: torch.Tensor,
|
||||
k: torch.Tensor,
|
||||
v: torch.Tensor,
|
||||
g: torch.Tensor,
|
||||
beta: torch.Tensor,
|
||||
scale: float,
|
||||
initial_state: torch.Tensor,
|
||||
output_final_state: bool,
|
||||
cu_seqlens: Optional[torch.LongTensor] = None,
|
||||
use_qk_l2norm_in_kernel: bool = False):
|
||||
if use_qk_l2norm_in_kernel:
|
||||
q = l2norm_fwd(q)
|
||||
k = l2norm_fwd(k)
|
||||
|
||||
g, o, A, final_state, w, h, v_new = chunk_gated_delta_rule_fwd(
|
||||
q=q,
|
||||
k=k,
|
||||
v=v,
|
||||
g=g,
|
||||
beta=beta,
|
||||
scale=scale,
|
||||
initial_state=initial_state,
|
||||
output_final_state=output_final_state,
|
||||
cu_seqlens=cu_seqlens,
|
||||
)
|
||||
ctx.scale = scale
|
||||
ctx.use_qk_l2norm_in_kernel = use_qk_l2norm_in_kernel
|
||||
return o.to(q.dtype), final_state
|
||||
|
||||
|
||||
@torch.compiler.disable
|
||||
def chunk_gated_delta_rule(q: torch.Tensor,
|
||||
k: torch.Tensor,
|
||||
v: torch.Tensor,
|
||||
g: torch.Tensor,
|
||||
beta: torch.Tensor,
|
||||
scale: float = None,
|
||||
initial_state: torch.Tensor = None,
|
||||
output_final_state: bool = False,
|
||||
cu_seqlens: Optional[torch.LongTensor] = None,
|
||||
head_first: bool = False,
|
||||
use_qk_l2norm_in_kernel: bool = False):
|
||||
r"""
|
||||
Args:
|
||||
q (torch.Tensor):
|
||||
queries of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`.
|
||||
k (torch.Tensor):
|
||||
keys of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`.
|
||||
v (torch.Tensor):
|
||||
values of shape `[B, T, H, V]` if `head_first=False` else `[B, H, T, V]`.
|
||||
g (torch.Tensor):
|
||||
(forget) gating tensor (in log space!) of shape `[B, T, H]` if `head_first=False` else `[B, H, T]`.
|
||||
beta (torch.Tensor):
|
||||
betas of shape `[B, T, H]` if `head_first=False` else `[B, H, T]`.
|
||||
scale (Optional[int]):
|
||||
Scale factor for the RetNet attention scores.
|
||||
If not provided, it will default to `1 / sqrt(K)`. Default: `None`.
|
||||
initial_state (Optional[torch.Tensor]):
|
||||
Initial state of shape `[N, H, K, V]` for `N` input sequences.
|
||||
For equal-length input sequences, `N` equals the batch size `B`.
|
||||
Default: `None`.
|
||||
output_final_state (Optional[bool]):
|
||||
Whether to output the final state of shape `[N, H, K, V]`. Default: `False`.
|
||||
cu_seqlens (torch.LongTensor):
|
||||
Cumulative sequence lengths of shape `[N+1]` used for variable-length training,
|
||||
consistent with the FlashAttention API.
|
||||
head_first (Optional[bool]):
|
||||
Whether the inputs are in the head-first format, which is not supported for variable-length inputs.
|
||||
Default: `False`.
|
||||
|
||||
Returns:
|
||||
o (torch.Tensor):
|
||||
Outputs of shape `[B, T, H, V]` if `head_first=False` else `[B, H, T, V]`.
|
||||
final_state (torch.Tensor):
|
||||
Final state of shape `[N, H, K, V]` if `output_final_state=True` else `None`.
|
||||
|
||||
Examples::
|
||||
>>> import torch
|
||||
>>> import torch.nn.functional as F
|
||||
>>> from einops import rearrange
|
||||
>>> from fla.ops.gated_delta_rule import chunk_gated_delta_rule
|
||||
# inputs with equal lengths
|
||||
>>> B, T, H, K, V = 4, 2048, 4, 512, 512
|
||||
>>> q = torch.randn(B, T, H, K, dtype=torch.bfloat16, device='cuda')
|
||||
>>> k = F.normalize(torch.randn(B, T, H, K, dtype=torch.bfloat16, device='cuda'), p=2, dim=-1)
|
||||
>>> v = torch.randn(B, T, H, V, dtype=torch.bfloat16, device='cuda')
|
||||
>>> beta = torch.rand(B, T, H, dtype=torch.bfloat16, device='cuda').sigmoid()
|
||||
>>> g = F.logsigmoid(torch.rand(B, T, H, dtype=torch.bfloat16, device='cuda'))
|
||||
>>> h0 = torch.randn(B, H, K, V, dtype=torch.bfloat16, device='cuda')
|
||||
>>> o, ht = chunk_gated_delta_rule(
|
||||
q, k, v, g, beta,
|
||||
initial_state=h0,
|
||||
output_final_state=True
|
||||
)
|
||||
# for variable-length inputs, the batch size `B` is expected to be 1 and `cu_seqlens` is required
|
||||
>>> q, k, v, beta, g = map(lambda x: rearrange(x, 'b t ... -> 1 (b t) ...'), (q, k, v, beta, g))
|
||||
# for a batch with 4 sequences, `cu_seqlens` with 5 start/end positions are expected
|
||||
>>> cu_seqlens = q.new_tensor([0, 2048, 4096, 6144, 8192], dtype=torch.long)
|
||||
>>> o_var, ht_var = chunk_gated_delta_rule(
|
||||
q, k, v, g, beta,
|
||||
initial_state=h0,
|
||||
output_final_state=True,
|
||||
cu_seqlens=cu_seqlens
|
||||
)
|
||||
"""
|
||||
assert q.dtype == k.dtype == v.dtype
|
||||
assert q.dtype != torch.float32, "ChunkGatedDeltaRuleFunction does not support float32. Please use bfloat16."
|
||||
assert len(
|
||||
beta.shape
|
||||
) == 3, "beta must be of shape [B, T, H] if head_first=False, or [B, H, T] otherwise."
|
||||
|
||||
if head_first:
|
||||
raise DeprecationWarning(
|
||||
"head_first is deprecated and will be removed in a future version. "
|
||||
"Please use head_first=False for now instead.",
|
||||
stacklevel=2)
|
||||
q, k, v, beta, g = map(
|
||||
lambda x: rearrange(x, 'b h t ... -> b t h ...'),
|
||||
(q, k, v, beta, g))
|
||||
if not head_first and q.shape[1] < q.shape[2]:
|
||||
warnings.warn(
|
||||
f"Input tensor shape suggests potential format mismatch: seq_len ({q.shape[1]}) < num_heads ({q.shape[2]}). "
|
||||
"This may indicate the inputs were passed in head-first format [B, H, T, ...] "
|
||||
"when head_first=False was specified. "
|
||||
"Please verify your input tensor format matches the expected shape [B, T, H, ...].",
|
||||
stacklevel=2)
|
||||
if cu_seqlens is not None:
|
||||
if q.shape[0] != 1:
|
||||
raise ValueError(
|
||||
f"The batch size is expected to be 1 rather than {q.shape[0]} when using `cu_seqlens`."
|
||||
f"Please flatten variable-length inputs before processing.")
|
||||
if initial_state is not None and initial_state.shape[0] != len(
|
||||
cu_seqlens) - 1:
|
||||
raise ValueError(
|
||||
f"The number of initial states is expected to be equal to the number of input sequences, "
|
||||
f"i.e., {len(cu_seqlens) - 1} rather than {initial_state.shape[0]}."
|
||||
)
|
||||
if scale is None:
|
||||
scale = k.shape[-1]**-0.5
|
||||
o, final_state = ChunkGatedDeltaRuleFunction.apply(
|
||||
q, k, v, g, beta, scale, initial_state, output_final_state, cu_seqlens,
|
||||
use_qk_l2norm_in_kernel)
|
||||
if head_first:
|
||||
o = rearrange(o, 'b t h ... -> b h t ...')
|
||||
return o, final_state
|
||||
251
vllm_kunlun/ops/fla/chunk_delta_h.py
Normal file
251
vllm_kunlun/ops/fla/chunk_delta_h.py
Normal file
@@ -0,0 +1,251 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang
|
||||
#
|
||||
# This file contains code copied from the flash-linear-attention project.
|
||||
# The original source code was licensed under the MIT license and included
|
||||
# the following copyright notice:
|
||||
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
|
||||
# ruff: noqa: E501
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.triton_utils import tl, triton
|
||||
|
||||
from .index import prepare_chunk_indices, prepare_chunk_offsets
|
||||
from .op import exp
|
||||
from .utils import is_nvidia_hopper, use_cuda_graph
|
||||
|
||||
NUM_WARPS = [2, 4] if is_nvidia_hopper else [2, 4, 8, 16]
|
||||
|
||||
|
||||
@triton.heuristics(
|
||||
{
|
||||
"USE_G": lambda args: args["g"] is not None,
|
||||
"USE_INITIAL_STATE": lambda args: args["h0"] is not None,
|
||||
"STORE_FINAL_STATE": lambda args: args["ht"] is not None,
|
||||
"SAVE_NEW_VALUE": lambda args: args["v_new"] is not None,
|
||||
"IS_VARLEN": lambda args: args["cu_seqlens"] is not None,
|
||||
}
|
||||
)
|
||||
@triton.jit(do_not_specialize=["T"])
|
||||
def chunk_gated_delta_rule_fwd_kernel_h_blockdim64(
|
||||
k,
|
||||
v,
|
||||
w,
|
||||
v_new,
|
||||
g,
|
||||
h,
|
||||
h0,
|
||||
ht,
|
||||
cu_seqlens,
|
||||
chunk_offsets,
|
||||
T,
|
||||
H: tl.constexpr,
|
||||
Hg: tl.constexpr,
|
||||
K: tl.constexpr,
|
||||
V: tl.constexpr,
|
||||
BT: tl.constexpr,
|
||||
BV: tl.constexpr,
|
||||
USE_G: tl.constexpr,
|
||||
USE_INITIAL_STATE: tl.constexpr,
|
||||
STORE_FINAL_STATE: tl.constexpr,
|
||||
SAVE_NEW_VALUE: tl.constexpr,
|
||||
IS_VARLEN: tl.constexpr,
|
||||
):
|
||||
i_v, i_nh = tl.program_id(0), tl.program_id(1)
|
||||
i_n, i_h = i_nh // H, i_nh % H
|
||||
|
||||
if IS_VARLEN:
|
||||
bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32)
|
||||
T = eos - bos
|
||||
NT = tl.cdiv(T, BT)
|
||||
boh = tl.load(chunk_offsets + i_n).to(tl.int32)
|
||||
else:
|
||||
bos, eos = i_n * T, i_n * T + T
|
||||
NT = tl.cdiv(T, BT)
|
||||
boh = i_n * NT
|
||||
|
||||
# [BK, BV]
|
||||
b_h1 = tl.zeros([64, BV], dtype=tl.float32)
|
||||
if K > 64:
|
||||
b_h2 = tl.zeros([64, BV], dtype=tl.float32)
|
||||
if K > 128:
|
||||
b_h3 = tl.zeros([64, BV], dtype=tl.float32)
|
||||
if K > 192:
|
||||
b_h4 = tl.zeros([64, BV], dtype=tl.float32)
|
||||
|
||||
# calculate offset
|
||||
h += (boh * H + i_h) * K * V
|
||||
v += (bos * H + i_h) * V
|
||||
k += (bos * Hg + i_h // (H // Hg)) * K
|
||||
w += (bos * H + i_h) * K
|
||||
if SAVE_NEW_VALUE:
|
||||
v_new += (bos * H + i_h) * V
|
||||
stride_v = H * V
|
||||
stride_h = H * K * V
|
||||
stride_k = Hg * K
|
||||
stride_w = H * K
|
||||
if USE_INITIAL_STATE:
|
||||
h0 = h0 + i_nh * K * V
|
||||
if STORE_FINAL_STATE:
|
||||
ht = ht + i_nh * K * V
|
||||
|
||||
# load initial state
|
||||
if USE_INITIAL_STATE:
|
||||
p_h0_1 = tl.make_block_ptr(h0, (K, V), (V, 1), (0, i_v * BV), (64, BV), (1, 0))
|
||||
b_h1 += tl.load(p_h0_1, boundary_check=(0, 1)).to(tl.float32)
|
||||
if K > 64:
|
||||
p_h0_2 = tl.make_block_ptr(h0, (K, V), (V, 1), (64, i_v * BV), (64, BV), (1, 0))
|
||||
b_h2 += tl.load(p_h0_2, boundary_check=(0, 1)).to(tl.float32)
|
||||
if K > 128:
|
||||
p_h0_3 = tl.make_block_ptr(h0, (K, V), (V, 1), (128, i_v * BV), (64, BV), (1, 0))
|
||||
b_h3 += tl.load(p_h0_3, boundary_check=(0, 1)).to(tl.float32)
|
||||
if K > 192:
|
||||
p_h0_4 = tl.make_block_ptr(h0, (K, V), (V, 1), (192, i_v * BV), (64, BV), (1, 0))
|
||||
b_h4 += tl.load(p_h0_4, boundary_check=(0, 1)).to(tl.float32)
|
||||
|
||||
# main recurrence
|
||||
for i_t in range(NT):
|
||||
p_h1 = tl.make_block_ptr(h + i_t * stride_h, (K, V), (V, 1), (0, i_v * BV), (64, BV), (1, 0))
|
||||
tl.store(p_h1, b_h1.to(p_h1.dtype.element_ty), boundary_check=(0, 1))
|
||||
if K > 64:
|
||||
p_h2 = tl.make_block_ptr(h + i_t * stride_h, (K, V), (V, 1), (64, i_v * BV), (64, BV), (1, 0))
|
||||
tl.store(p_h2, b_h2.to(p_h2.dtype.element_ty), boundary_check=(0, 1))
|
||||
if K > 128:
|
||||
p_h3 = tl.make_block_ptr(h + i_t * stride_h, (K, V), (V, 1), (128, i_v * BV), (64, BV), (1, 0))
|
||||
tl.store(p_h3, b_h3.to(p_h3.dtype.element_ty), boundary_check=(0, 1))
|
||||
if K > 192:
|
||||
p_h4 = tl.make_block_ptr(h + i_t * stride_h, (K, V), (V, 1), (192, i_v * BV), (64, BV), (1, 0))
|
||||
tl.store(p_h4, b_h4.to(p_h4.dtype.element_ty), boundary_check=(0, 1))
|
||||
|
||||
p_v = tl.make_block_ptr(v, (T, V), (stride_v, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
|
||||
p_v_new = (
|
||||
tl.make_block_ptr(v_new, (T, V), (stride_v, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
|
||||
if SAVE_NEW_VALUE
|
||||
else None
|
||||
)
|
||||
b_v_new = tl.zeros([BT, BV], dtype=tl.float32)
|
||||
p_w = tl.make_block_ptr(w, (T, K), (stride_w, 1), (i_t * BT, 0), (BT, 64), (1, 0))
|
||||
b_w = tl.load(p_w, boundary_check=(0, 1))
|
||||
b_v_new += tl.dot(b_w, b_h1.to(b_w.dtype))
|
||||
if K > 64:
|
||||
p_w = tl.make_block_ptr(w, (T, K), (stride_w, 1), (i_t * BT, 64), (BT, 64), (1, 0))
|
||||
b_w = tl.load(p_w, boundary_check=(0, 1))
|
||||
b_v_new += tl.dot(b_w, b_h2.to(b_w.dtype))
|
||||
if K > 128:
|
||||
p_w = tl.make_block_ptr(w, (T, K), (stride_w, 1), (i_t * BT, 128), (BT, 64), (1, 0))
|
||||
b_w = tl.load(p_w, boundary_check=(0, 1))
|
||||
b_v_new += tl.dot(b_w, b_h3.to(b_w.dtype))
|
||||
if K > 192:
|
||||
p_w = tl.make_block_ptr(w, (T, K), (stride_w, 1), (i_t * BT, 192), (BT, 64), (1, 0))
|
||||
b_w = tl.load(p_w, boundary_check=(0, 1))
|
||||
b_v_new += tl.dot(b_w, b_h4.to(b_w.dtype))
|
||||
|
||||
b_v_new = -b_v_new + tl.load(p_v, boundary_check=(0, 1))
|
||||
|
||||
if SAVE_NEW_VALUE:
|
||||
p_v_new = tl.make_block_ptr(v_new, (T, V), (stride_v, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
|
||||
tl.store(p_v_new, b_v_new.to(p_v_new.dtype.element_ty), boundary_check=(0, 1))
|
||||
|
||||
if USE_G:
|
||||
m_t = (i_t * BT + tl.arange(0, BT)) < T
|
||||
last_idx = min((i_t + 1) * BT, T) - 1
|
||||
b_g_last = tl.load(g + bos * H + last_idx * H + i_h)
|
||||
p_g = tl.make_block_ptr(g + bos * H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,))
|
||||
b_g = tl.load(p_g, boundary_check=(0,))
|
||||
b_v_new = b_v_new * tl.where(m_t, tl.exp(b_g_last - b_g), 0)[:, None]
|
||||
b_g_last = tl.exp(b_g_last)
|
||||
b_h1 = b_h1 * b_g_last
|
||||
if K > 64:
|
||||
b_h2 = b_h2 * b_g_last
|
||||
if K > 128:
|
||||
b_h3 = b_h3 * b_g_last
|
||||
if K > 192:
|
||||
b_h4 = b_h4 * b_g_last
|
||||
b_v_new = b_v_new.to(k.dtype.element_ty)
|
||||
p_k = tl.make_block_ptr(k, (K, T), (1, stride_k), (0, i_t * BT), (64, BT), (0, 1))
|
||||
b_k = tl.load(p_k, boundary_check=(0, 1))
|
||||
b_h1 += tl.dot(b_k, b_v_new)
|
||||
if K > 64:
|
||||
p_k = tl.make_block_ptr(k, (K, T), (1, stride_k), (64, i_t * BT), (64, BT), (0, 1))
|
||||
b_k = tl.load(p_k, boundary_check=(0, 1))
|
||||
b_h2 += tl.dot(b_k, b_v_new)
|
||||
if K > 128:
|
||||
p_k = tl.make_block_ptr(k, (K, T), (1, stride_k), (128, i_t * BT), (64, BT), (0, 1))
|
||||
b_k = tl.load(p_k, boundary_check=(0, 1))
|
||||
b_h3 += tl.dot(b_k, b_v_new)
|
||||
if K > 192:
|
||||
p_k = tl.make_block_ptr(k, (K, T), (1, stride_k), (192, i_t * BT), (64, BT), (0, 1))
|
||||
b_k = tl.load(p_k, boundary_check=(0, 1))
|
||||
b_h4 += tl.dot(b_k, b_v_new)
|
||||
|
||||
# epilogue
|
||||
if STORE_FINAL_STATE:
|
||||
p_ht = tl.make_block_ptr(ht, (K, V), (V, 1), (0, i_v * BV), (64, BV), (1, 0))
|
||||
tl.store(p_ht, b_h1.to(p_ht.dtype.element_ty), boundary_check=(0, 1))
|
||||
if K > 64:
|
||||
p_ht = tl.make_block_ptr(ht, (K, V), (V, 1), (64, i_v * BV), (64, BV), (1, 0))
|
||||
tl.store(p_ht, b_h2.to(p_ht.dtype.element_ty), boundary_check=(0, 1))
|
||||
if K > 128:
|
||||
p_ht = tl.make_block_ptr(ht, (K, V), (V, 1), (128, i_v * BV), (64, BV), (1, 0))
|
||||
tl.store(p_ht, b_h3.to(p_ht.dtype.element_ty), boundary_check=(0, 1))
|
||||
if K > 192:
|
||||
p_ht = tl.make_block_ptr(ht, (K, V), (V, 1), (192, i_v * BV), (64, BV), (1, 0))
|
||||
tl.store(p_ht, b_h4.to(p_ht.dtype.element_ty), boundary_check=(0, 1))
|
||||
|
||||
|
||||
def chunk_gated_delta_rule_fwd_h(
|
||||
k: torch.Tensor,
|
||||
w: torch.Tensor,
|
||||
u: torch.Tensor,
|
||||
g: Optional[torch.Tensor] = None,
|
||||
initial_state: Optional[torch.Tensor] = None,
|
||||
output_final_state: bool = False,
|
||||
chunk_size: int = 64, # SY: remove this argument and force chunk size 64?
|
||||
save_new_value: bool = True,
|
||||
cu_seqlens: Optional[torch.LongTensor] = None,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
B, T, Hg, K, V = *k.shape, u.shape[-1]
|
||||
H = u.shape[-2]
|
||||
BT = chunk_size
|
||||
|
||||
chunk_indices = prepare_chunk_indices(
|
||||
cu_seqlens, chunk_size) if cu_seqlens is not None else None
|
||||
# N: the actual number of sequences in the batch with either equal or variable lengths
|
||||
if cu_seqlens is None:
|
||||
N, NT, chunk_offsets = B, triton.cdiv(T, BT), None
|
||||
else:
|
||||
N, NT, chunk_offsets = len(cu_seqlens) - 1, len(
|
||||
chunk_indices), prepare_chunk_offsets(cu_seqlens, BT)
|
||||
assert K <= 256, "current kernel does not support head dimension larger than 256."
|
||||
|
||||
h = k.new_empty(B, NT, H, K, V)
|
||||
final_state = k.new_empty(
|
||||
N, H, K, V, dtype=torch.float32) if output_final_state else None
|
||||
|
||||
v_new = torch.empty_like(u) if save_new_value else None
|
||||
|
||||
def grid(meta):
|
||||
return (triton.cdiv(V, meta['BV']), N * H)
|
||||
chunk_gated_delta_rule_fwd_kernel_h_blockdim64[grid](
|
||||
k=k,
|
||||
v=u,
|
||||
w=w,
|
||||
v_new=v_new,
|
||||
g=g,
|
||||
h=h,
|
||||
h0=initial_state,
|
||||
ht=final_state,
|
||||
cu_seqlens=cu_seqlens,
|
||||
chunk_offsets=chunk_offsets,
|
||||
T=T,
|
||||
H=H,
|
||||
Hg=Hg,
|
||||
K=K,
|
||||
V=V,
|
||||
BT=BT,
|
||||
BV=64,
|
||||
)
|
||||
return h, v_new, final_state
|
||||
180
vllm_kunlun/ops/fla/chunk_o.py
Normal file
180
vllm_kunlun/ops/fla/chunk_o.py
Normal file
@@ -0,0 +1,180 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang
|
||||
#
|
||||
# This file contains code copied from the flash-linear-attention project.
|
||||
# The original source code was licensed under the MIT license and included
|
||||
# the following copyright notice:
|
||||
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
|
||||
|
||||
# ruff: noqa: E501
|
||||
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.triton_utils import tl, triton
|
||||
|
||||
from .index import prepare_chunk_indices
|
||||
from .op import exp
|
||||
from .utils import FLA_GDN_FIX_BT, check_shared_mem, is_nvidia_hopper
|
||||
|
||||
BKV_LIST = [64, 128] if check_shared_mem() else [32, 64]
|
||||
NUM_WARPS = [2, 4] if is_nvidia_hopper else [2, 4, 8]
|
||||
|
||||
|
||||
@triton.heuristics({
|
||||
'USE_G': lambda args: args['g'] is not None,
|
||||
'IS_VARLEN': lambda args: args['cu_seqlens'] is not None
|
||||
})
|
||||
# @triton.autotune(
|
||||
# configs=[
|
||||
# triton.Config({
|
||||
# 'BK': BK,
|
||||
# 'BV': BV
|
||||
# },
|
||||
# num_warps=num_warps,
|
||||
# num_stages=num_stages) for BK in BKV_LIST
|
||||
# for BV in BKV_LIST for num_warps in NUM_WARPS
|
||||
# for num_stages in [2, 3, 4]
|
||||
# ],
|
||||
# key=['H', 'K', 'V', 'BT'],
|
||||
# )
|
||||
@triton.jit(do_not_specialize=['T'])
|
||||
def chunk_fwd_kernel_o(
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
h,
|
||||
g,
|
||||
o,
|
||||
cu_seqlens,
|
||||
chunk_indices,
|
||||
scale,
|
||||
T,
|
||||
H: tl.constexpr,
|
||||
Hg: tl.constexpr,
|
||||
K: tl.constexpr,
|
||||
V: tl.constexpr,
|
||||
BT: tl.constexpr,
|
||||
BK: tl.constexpr,
|
||||
BV: tl.constexpr,
|
||||
USE_G: tl.constexpr,
|
||||
IS_VARLEN: tl.constexpr,
|
||||
):
|
||||
i_v, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
|
||||
i_b, i_h = i_bh // H, i_bh % H
|
||||
|
||||
if IS_VARLEN:
|
||||
i_tg = i_t
|
||||
i_n, i_t = tl.load(chunk_indices + i_t * 2).to(
|
||||
tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32)
|
||||
bos, eos = tl.load(cu_seqlens + i_n).to(
|
||||
tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32)
|
||||
T = eos - bos
|
||||
NT = tl.cdiv(T, BT)
|
||||
else:
|
||||
NT = tl.cdiv(T, BT)
|
||||
i_tg = i_b * NT + i_t
|
||||
bos, eos = i_b * T, i_b * T + T
|
||||
|
||||
# offset calculation
|
||||
q += (bos * Hg + i_h // (H // Hg)) * K
|
||||
k += (bos * Hg + i_h // (H // Hg)) * K
|
||||
v += (bos * H + i_h) * V
|
||||
o += (bos * H + i_h) * V
|
||||
h += (i_tg * H + i_h).to(tl.int64) * K * V
|
||||
|
||||
b_o = tl.zeros([BT, BV], dtype=tl.float32)
|
||||
b_A = tl.zeros([BT, BT], dtype=tl.float32)
|
||||
|
||||
for i_k in range(tl.cdiv(K, BK)):
|
||||
p_q = tl.make_block_ptr(q, (T, K), (Hg * K, 1), (i_t * BT, i_k * BK),
|
||||
(BT, BK), (1, 0))
|
||||
p_k = tl.make_block_ptr(k, (K, T), (1, Hg * K), (i_k * BK, i_t * BT),
|
||||
(BK, BT), (0, 1))
|
||||
p_h = tl.make_block_ptr(h, (K, V), (V, 1), (i_k * BK, i_v * BV),
|
||||
(BK, BV), (1, 0))
|
||||
# [BT, BK]
|
||||
b_q = tl.load(p_q, boundary_check=(0, 1))
|
||||
# [BK, BT]
|
||||
b_k = tl.load(p_k, boundary_check=(0, 1))
|
||||
# [BK, BV]
|
||||
b_h = tl.load(p_h, boundary_check=(0, 1))
|
||||
|
||||
# [BT, BK] @ [BK, BV] -> [BT, BV]
|
||||
b_o += tl.dot(b_q, b_h)
|
||||
# [BT, BK] @ [BK, BT] -> [BT, BT]
|
||||
b_A += tl.dot(b_q, b_k)
|
||||
|
||||
if USE_G:
|
||||
g += bos * H + i_h
|
||||
p_g = tl.make_block_ptr(g, (T, ), (H, ), (i_t * BT, ), (BT, ), (0, ))
|
||||
b_g = tl.load(p_g, boundary_check=(0, ))
|
||||
b_o = b_o * tl.exp(b_g)[:, None]
|
||||
b_A = b_A * tl.exp(b_g[:, None] - b_g[None, :])
|
||||
|
||||
o_t = i_t * BT + tl.arange(0, BT)
|
||||
# m_t = o_t < T
|
||||
# m_A = (o_t[:, None] >= o_t[None, :]) & (m_t[:, None] & m_t)
|
||||
# b_A = tl.where(m_A, b_A, 0)
|
||||
b_A = tl.where(o_t[:, None] >= o_t[None, :], b_A, 0)
|
||||
|
||||
p_v = tl.make_block_ptr(v, (T, V), (H * V, 1), (i_t * BT, i_v * BV),
|
||||
(BT, BV), (1, 0))
|
||||
p_o = tl.make_block_ptr(o, (T, V), (H * V, 1), (i_t * BT, i_v * BV),
|
||||
(BT, BV), (1, 0))
|
||||
b_v = tl.load(p_v, boundary_check=(0, 1))
|
||||
|
||||
# to fix mma -> mma layout conversion
|
||||
# already solved by triton v3.2 or higher
|
||||
b_o = b_o * scale + tl.dot(b_A.to(b_v.dtype), b_v) * scale
|
||||
tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))
|
||||
|
||||
|
||||
def chunk_fwd_o(
|
||||
q: torch.Tensor,
|
||||
k: torch.Tensor,
|
||||
v: torch.Tensor,
|
||||
h: torch.Tensor,
|
||||
g: Optional[torch.Tensor] = None, # cumsum of log decay
|
||||
scale: Optional[float] = None,
|
||||
cu_seqlens: Optional[torch.LongTensor] = None,
|
||||
chunk_size: int = 64) -> torch.Tensor:
|
||||
B, T, Hg, K, V = *q.shape, v.shape[-1]
|
||||
H = v.shape[-2]
|
||||
if FLA_GDN_FIX_BT:
|
||||
BT = 64
|
||||
else:
|
||||
BT = min(chunk_size, max(16, triton.next_power_of_2(T)))
|
||||
chunk_indices = prepare_chunk_indices(
|
||||
cu_seqlens, BT) if cu_seqlens is not None else None
|
||||
NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices)
|
||||
if scale is None:
|
||||
scale = k.shape[-1]**-0.5
|
||||
|
||||
o = torch.empty_like(v)
|
||||
|
||||
def grid(meta):
|
||||
return (triton.cdiv(V, meta['BV']), NT, B * H)
|
||||
|
||||
chunk_fwd_kernel_o[grid](
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
h,
|
||||
g,
|
||||
o,
|
||||
cu_seqlens,
|
||||
chunk_indices,
|
||||
scale,
|
||||
T=T,
|
||||
H=H,
|
||||
Hg=Hg,
|
||||
K=K,
|
||||
V=V,
|
||||
BT=BT,
|
||||
BK=64,
|
||||
BV=32
|
||||
)
|
||||
return o
|
||||
144
vllm_kunlun/ops/fla/chunk_scaled_dot_kkt.py
Normal file
144
vllm_kunlun/ops/fla/chunk_scaled_dot_kkt.py
Normal file
@@ -0,0 +1,144 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang
|
||||
#
|
||||
# This file contains code copied from the flash-linear-attention project.
|
||||
# The original source code was licensed under the MIT license and included
|
||||
# the following copyright notice:
|
||||
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
|
||||
# ruff: noqa: E501
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.triton_utils import tl, triton
|
||||
|
||||
from .index import prepare_chunk_indices
|
||||
from .op import exp
|
||||
|
||||
|
||||
|
||||
|
||||
@triton.heuristics({
|
||||
'IS_VARLEN': lambda args: args['cu_seqlens'] is not None,
|
||||
'USE_G': lambda args: args['g_cumsum'] is not None
|
||||
})
|
||||
# @triton.autotune(
|
||||
# configs=[
|
||||
# triton.Config({'BK': BK}, num_warps=num_warps, num_stages=num_stages)
|
||||
# for BK in [32, 64, 128] for num_warps in [2, 4, 8]
|
||||
# for num_stages in [2, 3, 4]
|
||||
# ],
|
||||
# key=['H', 'K', 'BT', 'IS_VARLEN'],
|
||||
# )
|
||||
@triton.jit(do_not_specialize=['T'])
|
||||
def chunk_scaled_dot_kkt_fwd_kernel(
|
||||
k,
|
||||
beta,
|
||||
g_cumsum,
|
||||
A,
|
||||
cu_seqlens,
|
||||
chunk_indices,
|
||||
T,
|
||||
H: tl.constexpr,
|
||||
Hg: tl.constexpr,
|
||||
K: tl.constexpr,
|
||||
BT: tl.constexpr,
|
||||
BK: tl.constexpr,
|
||||
IS_VARLEN: tl.constexpr,
|
||||
USE_G: tl.constexpr,
|
||||
):
|
||||
i_t, i_bh = tl.program_id(0), tl.program_id(1)
|
||||
i_b, i_h = i_bh // H, i_bh % H
|
||||
if IS_VARLEN:
|
||||
i_n, i_t = tl.load(chunk_indices + i_t * 2).to(
|
||||
tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32)
|
||||
bos, eos = tl.load(cu_seqlens + i_n).to(
|
||||
tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32)
|
||||
T = eos - bos
|
||||
else:
|
||||
bos, eos = i_b * T, i_b * T + T
|
||||
o_t = i_t * BT + tl.arange(0, BT)
|
||||
#m_t = o_t < T
|
||||
|
||||
p_beta = tl.make_block_ptr(beta + bos * H + i_h, (T, ), (H, ),
|
||||
(i_t * BT, ), (BT, ), (0, ))
|
||||
b_beta = tl.load(p_beta, boundary_check=(0, ))
|
||||
|
||||
b_A = tl.zeros([BT, BT], dtype=tl.float32)
|
||||
for i_k in range(tl.cdiv(K, BK)):
|
||||
p_k = tl.make_block_ptr(k + (bos * Hg + i_h // (H // Hg)) * K, (T, K),
|
||||
(Hg * K, 1), (i_t * BT, i_k * BK), (BT, BK),
|
||||
(1, 0))
|
||||
b_k = tl.load(p_k, boundary_check=(0, 1))
|
||||
b_kb = b_k * b_beta[:, None]
|
||||
b_A += tl.dot(b_kb.to(b_k.dtype), tl.trans(b_k))
|
||||
|
||||
if USE_G:
|
||||
p_g = tl.make_block_ptr(g_cumsum + bos * H + i_h, (T, ), (H, ),
|
||||
(i_t * BT, ), (BT, ), (0, ))
|
||||
b_g = tl.load(p_g, boundary_check=(0, ))
|
||||
b_g_diff = b_g[:, None] - b_g[None, :]
|
||||
b_A = b_A * tl.exp(b_g_diff) # 使用了triton而非vllm中的exp
|
||||
|
||||
#m_A = (o_t[:, None] > o_t[None, :]) & (m_t[:, None] & m_t)
|
||||
#b_A = tl.where(m_A, b_A, 0)
|
||||
b_A = tl.where(o_t[:, None] > o_t[None, :], b_A, 0)
|
||||
p_A = tl.make_block_ptr(A + (bos * H + i_h) * BT, (T, BT), (BT * H, 1),
|
||||
(i_t * BT, 0), (BT, BT), (1, 0))
|
||||
tl.store(p_A, b_A.to(p_A.dtype.element_ty), boundary_check=(0, 1))
|
||||
|
||||
|
||||
def chunk_scaled_dot_kkt_fwd(
|
||||
k: torch.Tensor,
|
||||
beta: torch.Tensor,
|
||||
g_cumsum: Optional[torch.Tensor] = None,
|
||||
cu_seqlens: Optional[torch.LongTensor] = None,
|
||||
chunk_size: int = 64,
|
||||
output_dtype: torch.dtype = torch.float32) -> torch.Tensor:
|
||||
r"""
|
||||
Compute beta * K * K^T.
|
||||
|
||||
Args:
|
||||
k (torch.Tensor):
|
||||
The key tensor of shape `[B, T, H, K]`.
|
||||
beta (torch.Tensor):
|
||||
The beta tensor of shape `[B, T, H]`.
|
||||
g_cumsum (torch.Tensor):
|
||||
The cumulative sum of the gate tensor of shape `[B, T, H]`.
|
||||
Default: None
|
||||
cu_seqlens (torch.LongTensor):
|
||||
The cumulative sequence lengths of the input tensor.
|
||||
Default: None
|
||||
chunk_size (int):
|
||||
The chunk size. Default: 64.
|
||||
output_dtype (torch.dtype):
|
||||
The dtype of the output tensor. Default: `torch.float32`
|
||||
|
||||
Returns:
|
||||
beta * K * K^T of shape `[B, T, H, BT]` where `BT` is the chunk size.
|
||||
"""
|
||||
|
||||
B, T, Hg, K = k.shape
|
||||
|
||||
H = beta.shape[-1]
|
||||
BT = chunk_size
|
||||
chunk_indices = prepare_chunk_indices(
|
||||
cu_seqlens, BT) if cu_seqlens is not None else None
|
||||
NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices)
|
||||
A = torch.empty(B, T, H, BT, device=k.device, dtype=output_dtype)
|
||||
chunk_scaled_dot_kkt_fwd_kernel[(NT, B * H)](
|
||||
k=k,
|
||||
beta=beta,
|
||||
g_cumsum=g_cumsum,
|
||||
A=A,
|
||||
cu_seqlens=cu_seqlens,
|
||||
chunk_indices=chunk_indices,
|
||||
T=T,
|
||||
H=H,
|
||||
Hg=Hg,
|
||||
K=K,
|
||||
BT=BT,
|
||||
BK=64,
|
||||
)
|
||||
return A
|
||||
229
vllm_kunlun/ops/fla/cumsum.py
Normal file
229
vllm_kunlun/ops/fla/cumsum.py
Normal file
@@ -0,0 +1,229 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang
|
||||
#
|
||||
# This file contains code copied from the flash-linear-attention project.
|
||||
# The original source code was licensed under the MIT license and included
|
||||
# the following copyright notice:
|
||||
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
|
||||
# ruff: noqa: E501
|
||||
import warnings
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.triton_utils import tl, triton
|
||||
|
||||
from .index import prepare_chunk_indices
|
||||
from .utils import check_shared_mem, input_guard
|
||||
|
||||
BS_LIST = [32, 64] if check_shared_mem() else [16, 32]
|
||||
|
||||
|
||||
@triton.heuristics({'IS_VARLEN': lambda args: args['cu_seqlens'] is not None})
|
||||
# @triton.autotune(configs=[
|
||||
# triton.Config({}, num_warps=num_warps) for num_warps in [1, 2, 4, 8]
|
||||
# ],
|
||||
# key=['B', 'H', 'BT', 'IS_VARLEN', 'REVERSE'])
|
||||
@triton.jit(do_not_specialize=['T'])
|
||||
def chunk_local_cumsum_scalar_kernel(
|
||||
s,
|
||||
o,
|
||||
cu_seqlens,
|
||||
chunk_indices,
|
||||
T,
|
||||
B: tl.constexpr,
|
||||
H: tl.constexpr,
|
||||
BT: tl.constexpr,
|
||||
REVERSE: tl.constexpr,
|
||||
IS_VARLEN: tl.constexpr,
|
||||
HEAD_FIRST: tl.constexpr,
|
||||
):
|
||||
i_t, i_bh = tl.program_id(0), tl.program_id(1)
|
||||
i_b, i_h = i_bh // H, i_bh % H
|
||||
if IS_VARLEN:
|
||||
i_n, i_t = tl.load(chunk_indices + i_t * 2).to(
|
||||
tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32)
|
||||
bos, eos = tl.load(cu_seqlens + i_n).to(
|
||||
tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32)
|
||||
T = eos - bos
|
||||
else:
|
||||
bos, eos = i_b * T, i_b * T + T
|
||||
|
||||
if HEAD_FIRST:
|
||||
p_s = tl.make_block_ptr(s + bos * H + i_h * T, (T, ), (1, ),
|
||||
(i_t * BT, ), (BT, ), (0, ))
|
||||
p_o = tl.make_block_ptr(o + bos * H + i_h * T, (T, ), (1, ),
|
||||
(i_t * BT, ), (BT, ), (0, ))
|
||||
else:
|
||||
p_s = tl.make_block_ptr(s + bos * H + i_h, (T, ), (H, ), (i_t * BT, ),
|
||||
(BT, ), (0, ))
|
||||
p_o = tl.make_block_ptr(o + bos * H + i_h, (T, ), (H, ), (i_t * BT, ),
|
||||
(BT, ), (0, ))
|
||||
# [BT]
|
||||
b_s = tl.load(p_s, boundary_check=(0, )).to(tl.float32)
|
||||
b_o = tl.cumsum(b_s, axis=0)
|
||||
if REVERSE:
|
||||
b_z = tl.sum(b_s, axis=0)
|
||||
b_o = -b_o + b_z[None] + b_s
|
||||
tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, ))
|
||||
|
||||
|
||||
@triton.heuristics({'IS_VARLEN': lambda args: args['cu_seqlens'] is not None})
|
||||
# @triton.autotune(configs=[
|
||||
# triton.Config({'BS': BS}, num_warps=num_warps) for BS in BS_LIST
|
||||
# for num_warps in [2, 4, 8]
|
||||
# ],
|
||||
# key=['B', 'H', 'S', 'BT', 'IS_VARLEN', 'REVERSE'])
|
||||
@triton.jit(do_not_specialize=['T'])
|
||||
def chunk_local_cumsum_vector_kernel(
|
||||
s,
|
||||
o,
|
||||
cu_seqlens,
|
||||
chunk_indices,
|
||||
T,
|
||||
B: tl.constexpr,
|
||||
H: tl.constexpr,
|
||||
S: tl.constexpr,
|
||||
BT: tl.constexpr,
|
||||
BS: tl.constexpr,
|
||||
REVERSE: tl.constexpr,
|
||||
IS_VARLEN: tl.constexpr,
|
||||
HEAD_FIRST: tl.constexpr,
|
||||
):
|
||||
i_s, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
|
||||
i_b, i_h = i_bh // H, i_bh % H
|
||||
if IS_VARLEN:
|
||||
i_n, i_t = tl.load(chunk_indices + i_t * 2).to(
|
||||
tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32)
|
||||
bos, eos = tl.load(cu_seqlens + i_n).to(
|
||||
tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32)
|
||||
T = eos - bos
|
||||
else:
|
||||
bos, eos = i_b * T, i_b * T + T
|
||||
|
||||
o_i = tl.arange(0, BT)
|
||||
if REVERSE:
|
||||
m_s = tl.where(o_i[:, None] <= o_i[None, :], 1., 0.)
|
||||
else:
|
||||
m_s = tl.where(o_i[:, None] >= o_i[None, :], 1., 0.)
|
||||
|
||||
if HEAD_FIRST:
|
||||
p_s = tl.make_block_ptr(s + (bos * H + i_h * T) * S, (T, S), (S, 1),
|
||||
(i_t * BT, i_s * BS), (BT, BS), (1, 0))
|
||||
p_o = tl.make_block_ptr(o + (bos * H + i_h * T) * S, (T, S), (S, 1),
|
||||
(i_t * BT, i_s * BS), (BT, BS), (1, 0))
|
||||
else:
|
||||
p_s = tl.make_block_ptr(s + (bos * H + i_h) * S, (T, S), (H * S, 1),
|
||||
(i_t * BT, i_s * BS), (BT, BS), (1, 0))
|
||||
p_o = tl.make_block_ptr(o + (bos * H + i_h) * S, (T, S), (H * S, 1),
|
||||
(i_t * BT, i_s * BS), (BT, BS), (1, 0))
|
||||
# [BT, BS]
|
||||
b_s = tl.load(p_s, boundary_check=(0, 1)).to(tl.float32)
|
||||
b_o = tl.dot(m_s, b_s, allow_tf32=False)
|
||||
tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))
|
||||
|
||||
|
||||
def chunk_local_cumsum_scalar(
|
||||
g: torch.Tensor,
|
||||
chunk_size: int,
|
||||
reverse: bool = False,
|
||||
cu_seqlens: Optional[torch.Tensor] = None,
|
||||
head_first: bool = False,
|
||||
output_dtype: Optional[torch.dtype] = torch.float) -> torch.Tensor:
|
||||
if head_first:
|
||||
B, H, T = g.shape
|
||||
else:
|
||||
B, T, H = g.shape
|
||||
assert chunk_size == 2**(chunk_size.bit_length() -
|
||||
1), "chunk_size must be a power of 2"
|
||||
BT = chunk_size
|
||||
chunk_indices = prepare_chunk_indices(
|
||||
cu_seqlens, BT) if cu_seqlens is not None else None
|
||||
NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices)
|
||||
g_org, g = g, torch.empty_like(g, dtype=output_dtype or g.dtype)
|
||||
grid = (NT, B * H)
|
||||
chunk_local_cumsum_scalar_kernel[grid](
|
||||
s=g_org,
|
||||
o=g,
|
||||
cu_seqlens=cu_seqlens,
|
||||
chunk_indices=chunk_indices,
|
||||
T=T,
|
||||
B=B,
|
||||
H=H,
|
||||
BT=BT,
|
||||
HEAD_FIRST=head_first,
|
||||
REVERSE=reverse,
|
||||
is_use_mask_zero = True
|
||||
)
|
||||
return g
|
||||
|
||||
|
||||
def chunk_local_cumsum_vector(
|
||||
g: torch.Tensor,
|
||||
chunk_size: int,
|
||||
reverse: bool = False,
|
||||
cu_seqlens: Optional[torch.Tensor] = None,
|
||||
head_first: bool = False,
|
||||
output_dtype: Optional[torch.dtype] = torch.float) -> torch.Tensor:
|
||||
if head_first:
|
||||
B, H, T, S = g.shape
|
||||
else:
|
||||
B, T, H, S = g.shape
|
||||
BT = chunk_size
|
||||
chunk_indices = prepare_chunk_indices(
|
||||
cu_seqlens, chunk_size) if cu_seqlens is not None else None
|
||||
NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices)
|
||||
assert chunk_size == 2**(chunk_size.bit_length() -
|
||||
1), "chunk_size must be a power of 2"
|
||||
|
||||
g_org, g = g, torch.empty_like(g, dtype=output_dtype or g.dtype)
|
||||
|
||||
def grid(meta):
|
||||
return (triton.cdiv(meta['S'], meta['BS']), NT, B * H)
|
||||
|
||||
# keep cumulative normalizer in fp32
|
||||
# this kernel is equivalent to
|
||||
# g = g.view(B, H, NT, BT, -1).cumsum(-2).view(B, H, T, -1)
|
||||
chunk_local_cumsum_vector_kernel[grid](g_org,
|
||||
g,
|
||||
cu_seqlens,
|
||||
chunk_indices,
|
||||
T=T,
|
||||
B=B,
|
||||
H=H,
|
||||
S=S,
|
||||
BT=BT,
|
||||
HEAD_FIRST=head_first,
|
||||
REVERSE=reverse)
|
||||
return g
|
||||
|
||||
|
||||
@input_guard
|
||||
def chunk_local_cumsum(g: torch.Tensor,
|
||||
chunk_size: int,
|
||||
reverse: bool = False,
|
||||
cu_seqlens: Optional[torch.Tensor] = None,
|
||||
head_first: bool = False,
|
||||
output_dtype: Optional[torch.dtype] = torch.float,
|
||||
**kwargs) -> torch.Tensor:
|
||||
if not head_first and g.shape[1] < g.shape[2]:
|
||||
warnings.warn(
|
||||
f"Input tensor shape suggests potential format mismatch: seq_len ({g.shape[1]}) < num_heads ({g.shape[2]}). "
|
||||
"This may indicate the inputs were passed in head-first format [B, H, T, ...] "
|
||||
"when head_first=False was specified. "
|
||||
"Please verify your input tensor format matches the expected shape [B, T, H, ...].",
|
||||
stacklevel=2)
|
||||
if cu_seqlens is not None:
|
||||
assert g.shape[
|
||||
0] == 1, "Only batch size 1 is supported when cu_seqlens are provided"
|
||||
if len(g.shape) == 3:
|
||||
return chunk_local_cumsum_scalar(g, chunk_size, reverse, cu_seqlens,
|
||||
head_first, output_dtype)
|
||||
elif len(g.shape) == 4:
|
||||
return chunk_local_cumsum_vector(g, chunk_size, reverse, cu_seqlens,
|
||||
head_first, output_dtype)
|
||||
else:
|
||||
raise ValueError(f"Unsupported input shape {g.shape}. "
|
||||
f"which should be (B, T, H, D) if `head_first=False` "
|
||||
f"or (B, H, T, D) otherwise")
|
||||
153
vllm_kunlun/ops/fla/fused_recurrent.py
Normal file
153
vllm_kunlun/ops/fla/fused_recurrent.py
Normal file
@@ -0,0 +1,153 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang
|
||||
#
|
||||
# This file contains code copied from the flash-linear-attention project.
|
||||
# The original source code was licensed under the MIT license and included
|
||||
# the following copyright notice:
|
||||
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
|
||||
# ruff: noqa: E501
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
import xtorch_ops
|
||||
|
||||
|
||||
class FusedRecurrentFunction(torch.autograd.Function):
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx,
|
||||
q: torch.Tensor,
|
||||
k: torch.Tensor,
|
||||
v: torch.Tensor,
|
||||
g: torch.Tensor,
|
||||
beta: torch.Tensor,
|
||||
scale: float,
|
||||
initial_state: torch.Tensor,
|
||||
inplace_final_state: bool = True,
|
||||
cu_seqlens: Optional[torch.LongTensor] = None,
|
||||
ssm_state_indices: Optional[torch.Tensor] = None,
|
||||
num_accepted_tokens: Optional[torch.Tensor] = None,
|
||||
use_qk_l2norm_in_kernel: bool = False):
|
||||
|
||||
o, final_state = xtorch_ops.fused_recurrent_gated_delta_rule_fwdv2(
|
||||
q.contiguous(),
|
||||
k.contiguous(),
|
||||
v.contiguous(),
|
||||
g.contiguous(),
|
||||
beta.contiguous(),
|
||||
scale,
|
||||
initial_state,
|
||||
inplace_final_state=inplace_final_state,
|
||||
cu_seqlens=cu_seqlens,
|
||||
h0_indices=ssm_state_indices,
|
||||
num_accepted_tokens=num_accepted_tokens,
|
||||
use_qk_l2norm_in_kernel=use_qk_l2norm_in_kernel,
|
||||
)
|
||||
return o, final_state
|
||||
|
||||
|
||||
def fused_recurrent_gated_delta_rule(
|
||||
q: torch.Tensor,
|
||||
k: torch.Tensor,
|
||||
v: torch.Tensor,
|
||||
g: torch.Tensor,
|
||||
beta: torch.Tensor = None,
|
||||
scale: float = None,
|
||||
initial_state: torch.Tensor = None,
|
||||
inplace_final_state: bool = True,
|
||||
cu_seqlens: Optional[torch.LongTensor] = None,
|
||||
ssm_state_indices: Optional[torch.Tensor] = None,
|
||||
num_accepted_tokens: Optional[torch.Tensor] = None,
|
||||
use_qk_l2norm_in_kernel: bool = False,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
r"""
|
||||
Args:
|
||||
q (torch.Tensor):
|
||||
queries of shape `[B, T, H, K]`.
|
||||
k (torch.Tensor):
|
||||
keys of shape `[B, T, H, K]`.
|
||||
v (torch.Tensor):
|
||||
values of shape `[B, T, HV, V]`.
|
||||
GVA is applied if `HV > H`.
|
||||
g (torch.Tensor):
|
||||
g (decays) of shape `[B, T, HV]`.
|
||||
beta (torch.Tensor):
|
||||
betas of shape `[B, T, HV]`.
|
||||
scale (Optional[int]):
|
||||
Scale factor for the RetNet attention scores.
|
||||
If not provided, it will default to `1 / sqrt(K)`. Default: `None`.
|
||||
initial_state (Optional[torch.Tensor]):
|
||||
Initial state of shape `[N, HV, K, V]` for `N` input sequences.
|
||||
For equal-length input sequences, `N` equals the batch size `B`.
|
||||
Default: `None`.
|
||||
inplace_final_state: bool:
|
||||
Whether to store the final state in-place to save memory.
|
||||
Default: `True`.
|
||||
cu_seqlens (torch.LongTensor):
|
||||
Cumulative sequence lengths of shape `[N+1]` used for variable-length training,
|
||||
consistent with the FlashAttention API.
|
||||
ssm_state_indices (Optional[torch.Tensor]):
|
||||
Indices to map the input sequences to the initial/final states.
|
||||
num_accepted_tokens (Optional[torch.Tensor]):
|
||||
Number of accepted tokens for each sequence during decoding.
|
||||
|
||||
Returns:
|
||||
o (torch.Tensor):
|
||||
Outputs of shape `[B, T, HV, V]`.
|
||||
final_state (torch.Tensor):
|
||||
Final state of shape `[N, HV, K, V]`.
|
||||
|
||||
Examples::
|
||||
>>> import torch
|
||||
>>> import torch.nn.functional as F
|
||||
>>> from einops import rearrange
|
||||
>>> from fla.ops.gated_delta_rule import fused_recurrent_gated_delta_rule
|
||||
# inputs with equal lengths
|
||||
>>> B, T, H, HV, K, V = 4, 2048, 4, 8, 512, 512
|
||||
>>> q = torch.randn(B, T, H, K, device='cuda')
|
||||
>>> k = F.normalize(torch.randn(B, T, H, K, device='cuda'), p=2, dim=-1)
|
||||
>>> v = torch.randn(B, T, HV, V, device='cuda')
|
||||
>>> g = F.logsigmoid(torch.rand(B, T, HV, device='cuda'))
|
||||
>>> beta = torch.rand(B, T, HV, device='cuda').sigmoid()
|
||||
>>> h0 = torch.randn(B, HV, K, V, device='cuda')
|
||||
>>> o, ht = fused_gated_recurrent_delta_rule(
|
||||
q, k, v, g, beta,
|
||||
initial_state=h0,
|
||||
)
|
||||
# for variable-length inputs, the batch size `B` is expected to be 1 and `cu_seqlens` is required
|
||||
>>> q, k, v, g, beta = map(lambda x: rearrange(x, 'b t ... -> 1 (b t) ...'), (q, k, v, g, beta))
|
||||
# for a batch with 4 sequences, `cu_seqlens` with 5 start/end positions are expected
|
||||
>>> cu_seqlens = q.new_tensor([0, 2048, 4096, 6144, 8192], dtype=torch.long)
|
||||
>>> o_var, ht_var = fused_gated_recurrent_delta_rule(
|
||||
q, k, v, g, beta,
|
||||
initial_state=h0,
|
||||
cu_seqlens=cu_seqlens
|
||||
)
|
||||
"""
|
||||
if cu_seqlens is not None and q.shape[0] != 1:
|
||||
raise ValueError(
|
||||
f"The batch size is expected to be 1 rather than {q.shape[0]} when using `cu_seqlens`."
|
||||
f"Please flatten variable-length inputs before processing.")
|
||||
if scale is None:
|
||||
scale = k.shape[-1]**-0.5
|
||||
else:
|
||||
assert scale > 0, "scale must be positive"
|
||||
if beta is None:
|
||||
beta = torch.ones_like(q[..., 0])
|
||||
o, final_state = FusedRecurrentFunction.apply(
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
g,
|
||||
beta,
|
||||
scale,
|
||||
initial_state,
|
||||
inplace_final_state,
|
||||
cu_seqlens,
|
||||
ssm_state_indices,
|
||||
num_accepted_tokens,
|
||||
use_qk_l2norm_in_kernel,
|
||||
)
|
||||
return o, final_state
|
||||
38
vllm_kunlun/ops/fla/index.py
Normal file
38
vllm_kunlun/ops/fla/index.py
Normal file
@@ -0,0 +1,38 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang
|
||||
#
|
||||
# This file contains code copied from the flash-linear-attention project.
|
||||
# The original source code was licensed under the MIT license and included
|
||||
# the following copyright notice:
|
||||
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
|
||||
# ruff: noqa: E501
|
||||
import torch
|
||||
|
||||
from vllm.triton_utils import triton
|
||||
|
||||
from .utils import tensor_cache
|
||||
|
||||
|
||||
@tensor_cache
|
||||
def prepare_lens(cu_seqlens: torch.LongTensor) -> torch.LongTensor:
|
||||
return cu_seqlens[1:] - cu_seqlens[:-1]
|
||||
|
||||
|
||||
@tensor_cache
|
||||
def prepare_chunk_indices(cu_seqlens: torch.LongTensor,
|
||||
chunk_size: int) -> torch.LongTensor:
|
||||
indices = torch.cat([
|
||||
torch.arange(n)
|
||||
for n in triton.cdiv(prepare_lens(cu_seqlens), chunk_size).tolist()
|
||||
])
|
||||
return torch.stack([indices.eq(0).cumsum(0) - 1, indices],
|
||||
1).to(cu_seqlens)
|
||||
|
||||
@tensor_cache
|
||||
def prepare_chunk_offsets(cu_seqlens: torch.LongTensor,
|
||||
chunk_size: int) -> torch.LongTensor:
|
||||
return torch.cat([
|
||||
cu_seqlens.new_tensor([0]),
|
||||
triton.cdiv(prepare_lens(cu_seqlens), chunk_size)
|
||||
]).cumsum(-1)
|
||||
143
vllm_kunlun/ops/fla/l2norm.py
Normal file
143
vllm_kunlun/ops/fla/l2norm.py
Normal file
@@ -0,0 +1,143 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang
|
||||
#
|
||||
# This file contains code copied from the flash-linear-attention project.
|
||||
# The original source code was licensed under the MIT license and included
|
||||
# the following copyright notice:
|
||||
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
|
||||
|
||||
import os
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.triton_utils import tl, triton
|
||||
|
||||
BT_LIST = [8, 16, 32, 64, 128]
|
||||
|
||||
USE_DEFAULT_FLA_NORM = int(os.getenv("USE_DEFAULT_FLA_NORM", "0"))
|
||||
|
||||
|
||||
@triton.autotune(configs=[
|
||||
triton.Config({}, num_warps=num_warps)
|
||||
for num_warps in [1, 2, 4, 8, 16, 32]
|
||||
],
|
||||
key=['D'])
|
||||
@triton.jit
|
||||
def l2norm_fwd_kernel1(
|
||||
x,
|
||||
y,
|
||||
D,
|
||||
BD: tl.constexpr,
|
||||
eps,
|
||||
):
|
||||
i_t = tl.program_id(0)
|
||||
x += i_t * D
|
||||
y += i_t * D
|
||||
# Compute mean and variance
|
||||
cols = tl.arange(0, BD)
|
||||
mask = cols < D
|
||||
b_x = tl.load(x + cols, mask=mask, other=0.0).to(tl.float32)
|
||||
b_var = tl.sum(b_x * b_x, axis=0)
|
||||
b_rstd = 1 / tl.sqrt(b_var + eps)
|
||||
# tl.store(Rstd + i_t, rstd)
|
||||
# Normalize and apply linear transformation
|
||||
b_y = b_x * b_rstd
|
||||
tl.store(y + cols, b_y, mask=mask)
|
||||
|
||||
|
||||
@triton.autotune(configs=[
|
||||
triton.Config({'BT': BT}, num_warps=num_warps)
|
||||
for num_warps in [1, 2, 4, 8, 16] for BT in BT_LIST
|
||||
],
|
||||
key=['D'])
|
||||
@triton.jit(do_not_specialize=["NB"])
|
||||
def l2norm_fwd_kernel(
|
||||
x,
|
||||
y,
|
||||
eps,
|
||||
NB,
|
||||
T,
|
||||
D: tl.constexpr,
|
||||
BT: tl.constexpr,
|
||||
BD: tl.constexpr,
|
||||
):
|
||||
i_t = tl.program_id(0)
|
||||
p_x = tl.make_block_ptr(x, (T, D), (D, 1), (i_t * BT, 0), (BT, BD), (1, 0))
|
||||
b_x = tl.load(p_x, boundary_check=(0, 1)).to(tl.float32)
|
||||
b_var = tl.sum(b_x * b_x, axis=1)
|
||||
b_y = b_x / tl.sqrt(b_var + eps)[:, None]
|
||||
p_y = tl.make_block_ptr(y, (T, D), (D, 1), (i_t * BT, 0), (BT, BD), (1, 0))
|
||||
tl.store(p_y, b_y.to(p_y.dtype.element_ty), boundary_check=(0, 1))
|
||||
|
||||
|
||||
@triton.jit
|
||||
def l2norm_fwd_kernel2(X, Y, eps, M, N: tl.constexpr, MBLOCK: tl.constexpr):
|
||||
xoffset = tl.program_id(0) * MBLOCK
|
||||
row_idx = xoffset + tl.arange(0, MBLOCK)[:, None]
|
||||
xmask = row_idx < M
|
||||
rindex = tl.arange(0, N)[None, :]
|
||||
xs = tl.load(X + (rindex + N * row_idx), xmask).to(tl.float32)
|
||||
square = tl.broadcast_to(xs * xs, [MBLOCK, N])
|
||||
square_sum = tl.sum(tl.where(xmask, square, 0), 1)[:, None]
|
||||
rsqrt = tl.rsqrt(square_sum + eps)
|
||||
tl.store(Y + (rindex + N * row_idx), xs * rsqrt, xmask)
|
||||
|
||||
|
||||
def l2norm_fwd(x: torch.Tensor,
|
||||
eps: float = 1e-6,
|
||||
output_dtype: Optional[torch.dtype] = None):
|
||||
x_shape_og = x.shape
|
||||
x = x.view(-1, x.shape[-1])
|
||||
# allocate output
|
||||
if output_dtype is None:
|
||||
y = torch.empty_like(x)
|
||||
else:
|
||||
y = torch.empty_like(x, dtype=output_dtype)
|
||||
assert y.stride(-1) == 1
|
||||
T, D = x.shape[0], x.shape[-1]
|
||||
# rstd = torch.empty((T,), dtype=torch.float32, device=x.device)
|
||||
# Less than 64KB per feature: enqueue fused kernel
|
||||
MAX_FUSED_SIZE = 65536 // x.element_size()
|
||||
BD = min(MAX_FUSED_SIZE, triton.next_power_of_2(D))
|
||||
if D > BD:
|
||||
raise RuntimeError("This layer doesn't support feature dim >= 64KB.")
|
||||
|
||||
if not USE_DEFAULT_FLA_NORM:
|
||||
MBLOCK = 32
|
||||
# M, N = x.shape
|
||||
l2norm_fwd_kernel2[(triton.cdiv(T, MBLOCK), )](
|
||||
x,
|
||||
y,
|
||||
eps,
|
||||
T,
|
||||
D,
|
||||
MBLOCK,
|
||||
)
|
||||
else:
|
||||
if D <= 512:
|
||||
NB = triton.cdiv(T, 2048)
|
||||
|
||||
def grid(meta):
|
||||
return (triton.cdiv(T, meta['BT']), )
|
||||
|
||||
l2norm_fwd_kernel[grid](
|
||||
x,
|
||||
y,
|
||||
eps,
|
||||
NB=NB,
|
||||
T=T,
|
||||
D=D,
|
||||
BD=BD,
|
||||
)
|
||||
else:
|
||||
l2norm_fwd_kernel1[(T, )](
|
||||
x,
|
||||
y,
|
||||
eps=eps,
|
||||
D=D,
|
||||
BD=BD,
|
||||
)
|
||||
|
||||
return y.view(x_shape_og)
|
||||
343
vllm_kunlun/ops/fla/layernorm_guard.py
Normal file
343
vllm_kunlun/ops/fla/layernorm_guard.py
Normal file
@@ -0,0 +1,343 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
# SPDX-FileCopyrightText: Tri Dao
|
||||
#
|
||||
# This file contains code copied from the flash-linear-attention project.
|
||||
# The original source code was licensed under the MIT license and included
|
||||
# the following copyright notice:
|
||||
# Copyright (c) 2024, Tri Dao.
|
||||
|
||||
# ruff: noqa: E501
|
||||
# Based on the Triton LayerNorm tutorial: https://triton-lang.org/main/getting-started/tutorials/05-layer-norm.html
|
||||
# For the backward pass, we keep weight_grad and bias_grad in registers and accumulate.
|
||||
# This backward pass is faster for dimensions up to 8k, but after that it's much slower due to register spilling.
|
||||
# The models we train have hidden dim up to 8k anyway (e.g. Llama 70B), so this is fine.
|
||||
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from einops import rearrange
|
||||
|
||||
from vllm.triton_utils import tl, triton
|
||||
|
||||
from .utils import input_guard
|
||||
|
||||
|
||||
def rms_norm_ref(x,
|
||||
weight,
|
||||
bias,
|
||||
z=None,
|
||||
eps=1e-6,
|
||||
group_size=None,
|
||||
norm_before_gate=True,
|
||||
upcast=True):
|
||||
dtype = x.dtype
|
||||
weight = weight.float()
|
||||
bias = bias.float() if bias is not None else None
|
||||
if upcast:
|
||||
x = x.float()
|
||||
z = z.float() if z is not None else z
|
||||
if z is not None and not norm_before_gate:
|
||||
x = x * F.silu(z)
|
||||
if group_size is None:
|
||||
rstd = 1 / torch.sqrt((x.square()).mean(dim=-1, keepdim=True) + eps)
|
||||
out = (x * rstd * weight) + bias if bias is not None else (x * rstd *
|
||||
weight)
|
||||
else:
|
||||
x_group = rearrange(x, "... (g d) -> ... g d", d=group_size)
|
||||
rstd = 1 / torch.sqrt((x_group.square()).mean(dim=-1, keepdim=True) +
|
||||
eps)
|
||||
out = rearrange(x_group * rstd, "... g d -> ... (g d)") * weight
|
||||
if bias is not None:
|
||||
out = out + bias
|
||||
if z is not None and norm_before_gate:
|
||||
out *= F.silu(z)
|
||||
return out.to(dtype)
|
||||
|
||||
|
||||
@triton.heuristics({
|
||||
"HAS_BIAS": lambda args: args["B"] is not None,
|
||||
"HAS_Z": lambda args: args["Z"] is not None,
|
||||
})
|
||||
@triton.jit
|
||||
def layer_norm_fwd_kernel(
|
||||
X, # pointer to the input
|
||||
Y, # pointer to the output
|
||||
W, # pointer to the weights
|
||||
B, # pointer to the biases
|
||||
Z, # pointer to the other branch
|
||||
Mean, # pointer to the mean
|
||||
Rstd, # pointer to the 1/std
|
||||
stride_x_row, # how much to increase the pointer when moving by 1 row
|
||||
stride_y_row,
|
||||
stride_z_row,
|
||||
M, # number of rows in X
|
||||
N, # number of columns in X
|
||||
eps, # epsilon to avoid division by zero
|
||||
BLOCK_N: tl.constexpr,
|
||||
HAS_BIAS: tl.constexpr,
|
||||
HAS_Z: tl.constexpr,
|
||||
NORM_BEFORE_GATE: tl.constexpr,
|
||||
IS_RMS_NORM: tl.constexpr,
|
||||
):
|
||||
# Map the program id to the row of X and Y it should compute.
|
||||
row = tl.program_id(0)
|
||||
group = tl.program_id(1)
|
||||
X += row * stride_x_row + group * N
|
||||
Y += row * stride_y_row + group * N
|
||||
if HAS_Z:
|
||||
Z += row * stride_z_row + group * N
|
||||
if not IS_RMS_NORM:
|
||||
Mean += group * M
|
||||
Rstd += group * M
|
||||
W += group * N
|
||||
if HAS_BIAS:
|
||||
B += group * N
|
||||
# Compute mean and variance
|
||||
cols = tl.arange(0, BLOCK_N)
|
||||
x = tl.load(X + cols, mask=cols < N, other=0.).to(tl.float32)
|
||||
if HAS_Z and not NORM_BEFORE_GATE:
|
||||
z = tl.load(Z + cols, mask=cols < N).to(tl.float32)
|
||||
x *= z * tl.sigmoid(z)
|
||||
if not IS_RMS_NORM:
|
||||
mean = tl.sum(x, axis=0) / N
|
||||
tl.store(Mean + row, mean)
|
||||
xbar = tl.where(cols < N, x - mean, 0.)
|
||||
var = tl.sum(xbar * xbar, axis=0) / N
|
||||
else:
|
||||
xbar = tl.where(cols < N, x, 0.)
|
||||
var = tl.sum(xbar * xbar, axis=0) / N
|
||||
rstd = 1 / tl.sqrt(var + eps)
|
||||
tl.store(Rstd + row, rstd)
|
||||
# Normalize and apply linear transformation
|
||||
mask = cols < N
|
||||
w = tl.load(W + cols, mask=mask).to(tl.float32)
|
||||
if HAS_BIAS:
|
||||
b = tl.load(B + cols, mask=mask).to(tl.float32)
|
||||
x_hat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd
|
||||
y = x_hat * w + b if HAS_BIAS else x_hat * w
|
||||
if HAS_Z and NORM_BEFORE_GATE:
|
||||
z = tl.load(Z + cols, mask=mask).to(tl.float32)
|
||||
y *= z * tl.sigmoid(z)
|
||||
# Write output
|
||||
tl.store(Y + cols, y, mask=mask)
|
||||
|
||||
|
||||
def layer_norm_fwd(
|
||||
x: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
bias: torch.Tensor,
|
||||
eps: float,
|
||||
z: torch.Tensor = None,
|
||||
out: torch.Tensor = None,
|
||||
group_size: int = None,
|
||||
norm_before_gate: bool = True,
|
||||
is_rms_norm: bool = False,
|
||||
):
|
||||
M, N = x.shape
|
||||
if group_size is None:
|
||||
group_size = N
|
||||
assert N % group_size == 0
|
||||
ngroups = N // group_size
|
||||
assert x.stride(-1) == 1
|
||||
if z is not None:
|
||||
assert z.stride(-1) == 1
|
||||
assert z.shape == (M, N)
|
||||
# if weight.shape != (N,):
|
||||
# weight = weight.reshape(N)
|
||||
# print("weight",weight.shape)
|
||||
# print("x",x.shape)
|
||||
assert weight.shape == (N, )
|
||||
assert weight.stride(-1) == 1
|
||||
if bias is not None:
|
||||
assert bias.stride(-1) == 1
|
||||
assert bias.shape == (N, )
|
||||
# allocate output
|
||||
if out is not None:
|
||||
assert out.shape == x.shape
|
||||
else:
|
||||
out = torch.empty_like(x)
|
||||
assert out.stride(-1) == 1
|
||||
mean = torch.empty((ngroups * M, ), dtype=torch.float32,
|
||||
device=x.device) if not is_rms_norm else None
|
||||
rstd = torch.empty((ngroups * M, ), dtype=torch.float32, device=x.device)
|
||||
# Less than 64KB per feature: enqueue fused kernel
|
||||
MAX_FUSED_SIZE = 65536 // x.element_size()
|
||||
BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(group_size))
|
||||
if group_size > BLOCK_N:
|
||||
raise RuntimeError(
|
||||
"This layer norm doesn't support feature dim >= 64KB.")
|
||||
# heuristics for number of warps
|
||||
num_warps = min(max(BLOCK_N // 256, 1), 8)
|
||||
grid = (M, ngroups)
|
||||
layer_norm_fwd_kernel[grid](x,
|
||||
out,
|
||||
weight,
|
||||
bias,
|
||||
z,
|
||||
mean,
|
||||
rstd,
|
||||
x.stride(0),
|
||||
out.stride(0),
|
||||
z.stride(0) if z is not None else 0,
|
||||
M,
|
||||
group_size,
|
||||
eps,
|
||||
BLOCK_N=BLOCK_N,
|
||||
NORM_BEFORE_GATE=norm_before_gate,
|
||||
IS_RMS_NORM=is_rms_norm,
|
||||
num_warps=num_warps)
|
||||
return out, mean, rstd
|
||||
|
||||
|
||||
class LayerNormFn(torch.autograd.Function):
|
||||
|
||||
@input_guard
|
||||
@staticmethod
|
||||
def forward(ctx,
|
||||
x,
|
||||
weight,
|
||||
bias,
|
||||
z=None,
|
||||
eps=1e-6,
|
||||
group_size=None,
|
||||
norm_before_gate=True,
|
||||
is_rms_norm=False):
|
||||
"""If z is not None, we do norm(x) * silu(z) if norm_before_gate, else norm(x * silu(z))
|
||||
"""
|
||||
|
||||
x_shape_og = x.shape
|
||||
# reshape input data into 2D tensor
|
||||
x = x.reshape(-1, x.shape[-1])
|
||||
if x.stride(-1) != 1:
|
||||
x = x.contiguous()
|
||||
if z is not None:
|
||||
# if z.shape != x_shape_og:
|
||||
# z = z.reshape(x_shape_og)
|
||||
assert z.shape == x_shape_og
|
||||
z = z.reshape(-1, z.shape[-1])
|
||||
if z.stride(-1) != 1:
|
||||
z = z.contiguous()
|
||||
weight = weight.contiguous()
|
||||
if bias is not None:
|
||||
bias = bias.contiguous()
|
||||
y, mean, rstd = layer_norm_fwd(
|
||||
x,
|
||||
weight,
|
||||
bias,
|
||||
eps,
|
||||
z=z,
|
||||
group_size=group_size,
|
||||
norm_before_gate=norm_before_gate,
|
||||
is_rms_norm=is_rms_norm,
|
||||
)
|
||||
ctx.save_for_backward(x, weight, bias, mean, rstd, z)
|
||||
ctx.x_shape_og = x_shape_og
|
||||
ctx.eps = eps
|
||||
ctx.group_size = group_size
|
||||
ctx.norm_before_gate = norm_before_gate
|
||||
ctx.is_rms_norm = is_rms_norm
|
||||
return y.reshape(x_shape_og)
|
||||
|
||||
|
||||
def layernorm_fn(x,
|
||||
weight,
|
||||
bias,
|
||||
z=None,
|
||||
eps=1e-6,
|
||||
group_size=None,
|
||||
norm_before_gate=True,
|
||||
is_rms_norm=False):
|
||||
return LayerNormFn.apply(x, weight, bias, z, eps, group_size,
|
||||
norm_before_gate, is_rms_norm)
|
||||
|
||||
|
||||
def rmsnorm_fn(x,
|
||||
weight,
|
||||
bias,
|
||||
z=None,
|
||||
eps=1e-6,
|
||||
group_size=None,
|
||||
norm_before_gate=True):
|
||||
return LayerNormFn.apply(x, weight, bias, z, eps, group_size,
|
||||
norm_before_gate, True)
|
||||
|
||||
|
||||
class LayerNormGated(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size,
|
||||
eps: float = 1e-5,
|
||||
group_size: Optional[int] = None,
|
||||
norm_before_gate: bool = True,
|
||||
device: Optional[torch.device] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
):
|
||||
"""If group_size is not None, we do GroupNorm with each group having group_size elements.
|
||||
group_size=None is equivalent to group_size=hidden_size (i.e. there's only 1 group).
|
||||
"""
|
||||
|
||||
factory_kwargs = {"device": device, "dtype": dtype}
|
||||
super().__init__()
|
||||
self.eps = eps
|
||||
self.weight = nn.Parameter(torch.empty(hidden_size, **factory_kwargs))
|
||||
self.bias = nn.Parameter(torch.empty(hidden_size, **factory_kwargs))
|
||||
self.group_size = group_size
|
||||
self.norm_before_gate = norm_before_gate
|
||||
self.reset_parameters()
|
||||
|
||||
def reset_parameters(self):
|
||||
torch.nn.init.ones_(self.weight)
|
||||
torch.nn.init.zeros_(self.bias)
|
||||
|
||||
def forward(self, x, z=None):
|
||||
"""If z is not None, we do norm(x) * silu(z) if norm_before_gate, else norm(x * silu(z))
|
||||
"""
|
||||
return layernorm_fn(x,
|
||||
self.weight,
|
||||
self.bias,
|
||||
z=z,
|
||||
group_size=self.group_size,
|
||||
eps=self.eps,
|
||||
norm_before_gate=self.norm_before_gate)
|
||||
|
||||
|
||||
class RMSNormGated(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size,
|
||||
eps: float = 1e-5,
|
||||
group_size: Optional[int] = None,
|
||||
norm_before_gate: bool = False,
|
||||
device: Optional[torch.device] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
):
|
||||
"""If group_size is not None, we do GroupNorm with each group having group_size elements.
|
||||
group_size=None is equivalent to group_size=hidden_size (i.e. there's only 1 group).
|
||||
"""
|
||||
factory_kwargs = {"device": device, "dtype": dtype}
|
||||
super().__init__()
|
||||
self.eps = eps
|
||||
self.weight = nn.Parameter(torch.empty(hidden_size, **factory_kwargs))
|
||||
self.register_parameter("bias", None)
|
||||
self.group_size = group_size
|
||||
self.norm_before_gate = norm_before_gate
|
||||
self.reset_parameters()
|
||||
|
||||
def reset_parameters(self):
|
||||
torch.nn.init.ones_(self.weight)
|
||||
|
||||
def forward(self, x, z=None):
|
||||
"""If z is not None, we do norm(x) * silu(z) if norm_before_gate, else norm(x * silu(z))
|
||||
"""
|
||||
return rmsnorm_fn(x,
|
||||
self.weight,
|
||||
self.bias,
|
||||
z=z,
|
||||
eps=self.eps,
|
||||
group_size=self.group_size,
|
||||
norm_before_gate=self.norm_before_gate)
|
||||
39
vllm_kunlun/ops/fla/op.py
Normal file
39
vllm_kunlun/ops/fla/op.py
Normal file
@@ -0,0 +1,39 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang
|
||||
#
|
||||
# This file contains code copied from the flash-linear-attention project.
|
||||
# The original source code was licensed under the MIT license and included
|
||||
# the following copyright notice:
|
||||
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
|
||||
|
||||
import os
|
||||
|
||||
from vllm.triton_utils import tl, tldevice, triton
|
||||
|
||||
if os.environ.get('FLA_USE_FAST_OPS', '0') == '1':
|
||||
div = tldevice.fast_dividef
|
||||
exp = tldevice.fast_expf
|
||||
log = tldevice.fast_logf
|
||||
log2 = tldevice.fast_log2f
|
||||
else:
|
||||
|
||||
@triton.jit
|
||||
def div_normal(x, y):
|
||||
return x / y
|
||||
|
||||
div = div_normal
|
||||
exp = tl.exp
|
||||
log = tl.log
|
||||
log2 = tl.log2
|
||||
|
||||
|
||||
if not hasattr(tl, 'gather'):
|
||||
|
||||
@triton.jit
|
||||
def gather(src, index, axis, _builder=None):
|
||||
# This is a fallback implementation when tl.gather is not supported
|
||||
# In order to pass triton compiler, there is no actual gather operation
|
||||
return src
|
||||
else:
|
||||
gather = tl.gather
|
||||
422
vllm_kunlun/ops/fla/solve_tril.py
Normal file
422
vllm_kunlun/ops/fla/solve_tril.py
Normal file
@@ -0,0 +1,422 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang
|
||||
#
|
||||
# This file contains code copied from the flash-linear-attention project.
|
||||
# The original source code was licensed under the MIT license and included
|
||||
# the following copyright notice:
|
||||
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
|
||||
# ruff: noqa: E501
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import os
|
||||
from vllm.triton_utils import tl, triton
|
||||
|
||||
from .index import prepare_chunk_indices
|
||||
from .utils import input_guard
|
||||
|
||||
base_dir = os.path.dirname(__file__)
|
||||
|
||||
def prepare_lens(cu_seqlens: torch.LongTensor) -> torch.LongTensor:
|
||||
return cu_seqlens[1:] - cu_seqlens[:-1]
|
||||
|
||||
def prepare_chunk_indices(
|
||||
cu_seqlens: torch.LongTensor, chunk_size: int
|
||||
) -> torch.LongTensor:
|
||||
indices = torch.cat(
|
||||
[
|
||||
torch.arange(n)
|
||||
for n in triton.cdiv(prepare_lens(cu_seqlens), chunk_size).tolist()
|
||||
]
|
||||
)
|
||||
return torch.stack([indices.eq(0).cumsum(0) - 1, indices], 1).to(cu_seqlens)
|
||||
|
||||
|
||||
@triton.heuristics({"IS_VARLEN": lambda args: args["cu_seqlens"] is not None})
|
||||
# @triton.autotune(
|
||||
# configs=[
|
||||
# triton.Config({}, num_warps=num_warps, num_stages=num_stages)
|
||||
# for num_warps in [1, 2, 4, 8]
|
||||
# for num_stages in [2, 3, 4, 5]
|
||||
# ],
|
||||
# key=["BT"],
|
||||
# )
|
||||
@triton.jit(do_not_specialize=["T"])
|
||||
def solve_tril_16x16_kernel(
|
||||
A,
|
||||
Ad,
|
||||
cu_seqlens,
|
||||
chunk_indices,
|
||||
T,
|
||||
H: tl.constexpr,
|
||||
BT: tl.constexpr,
|
||||
IS_VARLEN: tl.constexpr,
|
||||
):
|
||||
i_t, i_bh = tl.program_id(0), tl.program_id(1)
|
||||
i_b, i_h = i_bh // H, i_bh % H
|
||||
if IS_VARLEN:
|
||||
i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(
|
||||
chunk_indices + i_t * 2 + 1
|
||||
).to(tl.int32)
|
||||
bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(
|
||||
cu_seqlens + i_n + 1
|
||||
).to(tl.int32)
|
||||
T = eos - bos
|
||||
else:
|
||||
bos, eos = i_b * T, i_b * T + T
|
||||
|
||||
A = A + (bos * H + i_h) * BT
|
||||
Ad = Ad + (bos * H + i_h) * 16
|
||||
|
||||
offset = (i_t * 16) % BT
|
||||
p_A = tl.make_block_ptr(
|
||||
A, (T, BT), (H * BT, 1), (i_t * 16, offset), (16, 16), (1, 0)
|
||||
)
|
||||
p_Ai = tl.make_block_ptr(Ad, (T, 16), (H * 16, 1), (i_t * 16, 0), (16, 16), (1, 0))
|
||||
b_A = tl.load(p_A, boundary_check=(0, 1)).to(tl.float32)
|
||||
b_A = -tl.where(tl.arange(0, 16)[:, None] > tl.arange(0, 16)[None, :], b_A, 0)
|
||||
|
||||
o_i = tl.arange(0, 16)
|
||||
for i in range(1, min(16, T - i_t * 16)):
|
||||
b_a = -tl.load(A + (i_t * 16 + i) * H * BT + o_i + offset)
|
||||
b_a = b_a + tl.sum(b_a[:, None] * b_A, 0)
|
||||
mask = o_i == i
|
||||
b_A = tl.where(mask[:, None], b_a, b_A)
|
||||
b_A += o_i[:, None] == o_i[None, :]
|
||||
tl.store(
|
||||
p_Ai,
|
||||
b_A.to(p_Ai.dtype.element_ty, fp_downcast_rounding="rtne"),
|
||||
boundary_check=(0, 1),
|
||||
)
|
||||
|
||||
|
||||
@triton.heuristics({"IS_VARLEN": lambda args: args["cu_seqlens"] is not None})
|
||||
@triton.jit(do_not_specialize=["T"])
|
||||
def solve_tril_16x16_kernel_modified(
|
||||
i_t,
|
||||
i_bh,
|
||||
i_n,
|
||||
bos,
|
||||
i_b,
|
||||
i_h,
|
||||
subA,
|
||||
subAd,
|
||||
A,
|
||||
Ad,
|
||||
cu_seqlens,
|
||||
chunk_indices,
|
||||
T, # 32
|
||||
H: tl.constexpr, # 4
|
||||
BT: tl.constexpr, # 64
|
||||
IS_VARLEN: tl.constexpr,
|
||||
):
|
||||
A = A + (bos * H + i_h) * BT
|
||||
print("for A Base offset ", (bos * H + i_h) * BT)
|
||||
|
||||
offset = (i_t * 16) % BT
|
||||
|
||||
range16 = tl.arange(0, 16)
|
||||
newp_A = subA + range16[:, None] * 16 + range16[None, :]
|
||||
b_A = tl.load(newp_A).to(tl.float32)
|
||||
|
||||
o_i = tl.arange(0, 16)
|
||||
for i in range(1, min(16, T - i_t * 16)):
|
||||
print("[naive impl-0]loopIdx:", i)
|
||||
# print("for A start (i_t * 16 + i) * H * BT", (i_t * 16 + i) * H * BT)
|
||||
# print("for A start offset", offset)
|
||||
# print("for A start", (i_t * 16 + i) * H * BT + offset)
|
||||
print("[naive impl-1]b_A value in now loopIdx:", b_A)
|
||||
b_a = -tl.load(A + (i_t * 16 + i) * H * BT + o_i + offset)
|
||||
# print("[naive impl-2]b_a value in now loopIdx:", b_a)
|
||||
b_a = b_a + tl.sum(b_a[:, None] * b_A, 0)
|
||||
print("[naive impl-2-1]b_a value after reduce in now loopIdx:", b_a)
|
||||
mask = o_i == i
|
||||
b_A = tl.where(mask[:, None], b_a, b_A)
|
||||
print("[naive impl-2-2]b_A value after oimask in now loopIdx:", b_A)
|
||||
# print("[naive impl-3]b_A result in now loopIdx:", b_A)
|
||||
# print(f"[naive impl-4] b_A value after allLoop = {b_A}")
|
||||
b_A += o_i[:, None] == o_i[None, :]
|
||||
# print(f"[naive impl-5] b_A value after mask = {b_A}")
|
||||
|
||||
newp_Ad = subAd + range16[:, None] * 16 + range16[None, :]
|
||||
tl.store(
|
||||
newp_Ad,
|
||||
b_A.to(subAd.dtype.element_ty, fp_downcast_rounding="rtne"),
|
||||
)
|
||||
|
||||
|
||||
@triton.heuristics({"IS_VARLEN": lambda args: args["cu_seqlens"] is not None})
|
||||
# @triton.autotune(
|
||||
# configs=[
|
||||
# triton.Config({}, num_warps=num_warps, num_stages=num_stages)
|
||||
# for num_warps in [1, 2, 4, 8]
|
||||
# for num_stages in [2, 3, 4, 5]
|
||||
# ],
|
||||
# key=["BT"],
|
||||
# )
|
||||
@triton.jit(do_not_specialize=["T"])
|
||||
def solve_tril_16x16_kernel_modified_in_Loop(
|
||||
i_t,
|
||||
i_bh,
|
||||
i_n,
|
||||
bos,
|
||||
i_b,
|
||||
i_h,
|
||||
subA,
|
||||
subAd,
|
||||
AInLoop,
|
||||
ba_reduce,
|
||||
loopIdx,
|
||||
reduce_res,
|
||||
A,
|
||||
Ad,
|
||||
cu_seqlens,
|
||||
chunk_indices,
|
||||
T, # 32
|
||||
H: tl.constexpr, # 4
|
||||
BT: tl.constexpr, # 64
|
||||
IS_VARLEN: tl.constexpr,
|
||||
):
|
||||
range16 = tl.arange(0, 16)
|
||||
newp_A = subA + range16[:, None] * 16 + range16[None, :]
|
||||
b_A = tl.load(newp_A).to(tl.float32)
|
||||
# print("[loop impl-0]loopIdx:", loopIdx)
|
||||
# print("[loop impl-1]b_A value in now loopIdx:", b_A)
|
||||
|
||||
o_i = tl.arange(0, 16)
|
||||
i=loopIdx
|
||||
b_a = -tl.load(AInLoop + o_i)
|
||||
# print("[loop impl-2]b_a value in now loopIdx:", b_a)
|
||||
red_res = b_a[:, None] * b_A
|
||||
# print("[Triton]red_res=", red_res)
|
||||
tl.store(reduce_res + range16[:, None] * 16 + range16[None, :], red_res)
|
||||
# b_a = b_a + tl.sum(b_a[:, None] * b_A, 1) # TODO: revert to 0
|
||||
# # print("triton reduce b_a", b_a)
|
||||
# tl.store(ba_reduce + o_i, b_a)
|
||||
|
||||
# mask = o_i == i
|
||||
# # print("mask", mask[:, None])
|
||||
# # print("b_a", b_a)
|
||||
# # print("b_A", b_A)
|
||||
# print("before b_A", b_A)
|
||||
# b_A = tl.where(mask[:, None], b_a, b_A)
|
||||
# print("[loop impl-3]b_A result in now loopIdx:", b_A)
|
||||
|
||||
# tl.store(newp_A, b_A)
|
||||
|
||||
|
||||
def solve_tril_16x16_kernel_new(
|
||||
NT,
|
||||
B,
|
||||
A,
|
||||
Ad,
|
||||
cu_seqlens,
|
||||
chunk_indices,
|
||||
T,
|
||||
H,
|
||||
BT,
|
||||
IS_VARLEN,
|
||||
):
|
||||
Ad_modify = Ad
|
||||
for loopX in range(NT):
|
||||
# i_n, i_t = tl.load(chunk_indices ...
|
||||
chunk_indices_load_offset_1 = loopX * 2
|
||||
row_idx = chunk_indices_load_offset_1 // chunk_indices.shape[1]
|
||||
col_idx = chunk_indices_load_offset_1 % chunk_indices.shape[1]
|
||||
i_n = int(chunk_indices[row_idx, col_idx])
|
||||
chunk_indices_load_offset_2 = loopX * 2 + 1
|
||||
row_idx = chunk_indices_load_offset_2 // chunk_indices.shape[1]
|
||||
col_idx = chunk_indices_load_offset_2 % chunk_indices.shape[1]
|
||||
i_t = int(chunk_indices[row_idx, col_idx])
|
||||
|
||||
# bos, eos = tl.load(cu_seqlens ...
|
||||
cu_seqlens_load_offset_1 = i_n
|
||||
bos = int(cu_seqlens[cu_seqlens_load_offset_1])
|
||||
cu_seqlens_load_offset_2 = i_n + 1
|
||||
eos = int(cu_seqlens[cu_seqlens_load_offset_2])
|
||||
T = eos - bos
|
||||
|
||||
for loopY in range(B * H):
|
||||
i_b = loopY // H
|
||||
i_h = loopY % H
|
||||
|
||||
# get subA
|
||||
if (bos * H + i_h) < H:
|
||||
Tstart = loopX * 16 % BT
|
||||
Tend = Tstart + 16
|
||||
BTstart = loopX * 16 % BT
|
||||
BTend = BTstart + 16
|
||||
subA = A[0, Tstart:Tend, loopY, BTstart:BTend].contiguous().clone()
|
||||
# print(f"subA slice A dim[0, {Tstart}:{Tend}, {loopY}, {BTstart}:{BTend}]")
|
||||
if (Tend > T): # bondary check
|
||||
subA[T-16:, :] = 0
|
||||
|
||||
# subA.shape torch.Size([9, 16])
|
||||
# vvv
|
||||
# subA.shape torch.Size([16, 16]) 用0补齐
|
||||
if subA.shape[0] < 16:
|
||||
pad_rows = 16 - subA.shape[0]
|
||||
zeros = torch.zeros((pad_rows, subA.shape[1]), dtype=subA.dtype, device=subA.device)
|
||||
subA = torch.cat([subA, zeros], dim=0)
|
||||
else:
|
||||
assert(0) & "need deal this situation"
|
||||
|
||||
# get subAd
|
||||
if (bos * H + i_h) < H:
|
||||
Tstart = loopX * 16
|
||||
Tend = Tstart + 16
|
||||
BTstart = 0 * 16
|
||||
BTend = BTstart + 16
|
||||
subAd = Ad_modify[0, Tstart:Tend, loopY, BTstart:BTend].contiguous().clone()
|
||||
# print(f'T={T}, Tstart={Tstart}, Tend={Tend}, BTstart={BTstart}, BTend={BTend}')
|
||||
else:
|
||||
assert(0) & "need deal this situation"
|
||||
|
||||
mask = (torch.arange(16, device=subA.device)[:, None] > torch.arange(16, device=subA.device)[None, :])
|
||||
subA = -torch.where(mask, subA, torch.zeros_like(subA))
|
||||
|
||||
for inLoopIdx in range(1, min(16, T - i_t * 16)):
|
||||
# print(f"loopX={loopX}, loopY={loopY}, inLoopIdx={inLoopIdx}")
|
||||
offsetStart=loopX*16 % BT
|
||||
offsetEnd=offsetStart+16
|
||||
|
||||
AInLoop = A[0, (loopX * 16 + inLoopIdx), loopY, offsetStart:offsetEnd]
|
||||
# print(f"AInLoop slice A dim[0, {(loopX * 16 + inLoopIdx)}, {loopY}, {offsetStart}:{offsetEnd}")
|
||||
|
||||
ba_reduce = torch.empty_like(AInLoop)
|
||||
reduce_res = torch.empty_like(subA)
|
||||
solve_tril_16x16_kernel_modified_in_Loop[1, 1](
|
||||
i_t,
|
||||
loopY,
|
||||
i_n,
|
||||
bos,
|
||||
i_b,
|
||||
i_h,
|
||||
subA=subA,
|
||||
subAd=subAd,
|
||||
AInLoop=AInLoop,
|
||||
ba_reduce=ba_reduce,
|
||||
loopIdx=inLoopIdx,
|
||||
reduce_res=reduce_res,
|
||||
A=A,
|
||||
Ad=Ad_modify,
|
||||
cu_seqlens=cu_seqlens,
|
||||
chunk_indices=chunk_indices,
|
||||
T=T,
|
||||
H=H,
|
||||
BT=BT,
|
||||
num_warps=1,
|
||||
num_stages=4,
|
||||
)
|
||||
AInLoop = AInLoop.flatten()
|
||||
b_A = subA # [16x16]
|
||||
b_a = -AInLoop[0:16] # [16]
|
||||
b_a = b_a + torch.sum(reduce_res, 0)
|
||||
ba_reduce = b_a
|
||||
o_i = torch.arange(16, device=ba_reduce.device)
|
||||
mask = (o_i == inLoopIdx)
|
||||
mask_expand = mask[:, None]
|
||||
subA = torch.where(mask_expand, ba_reduce, subA)
|
||||
|
||||
subAd = subA + (torch.arange(16, device=subA.device)[:, None] == torch.arange(16, device=subA.device)[None, :])
|
||||
|
||||
# deal store mask
|
||||
Tstart = loopX * 16
|
||||
Tend = Tstart + 16
|
||||
BTstart = 0 * 16
|
||||
BTend = BTstart + 16
|
||||
# print(f"slice Ad_modify dim[0, {Tend-needMaskRow}:{Tend}, {loopY}, {BTstart}:{BTend}]")
|
||||
if (Tend > T): # bondary mask
|
||||
needMaskRow = Tend - T
|
||||
Ad_modify[0, Tstart:Tend, loopY, BTstart:BTend] = subAd[:T-Tstart, :]
|
||||
else:
|
||||
# assert (Ad_modify[0, Tstart:Tend, loopY, BTstart:BTend].shape == subAd.shape)
|
||||
Ad_modify[0, Tstart:Tend, loopY, BTstart:BTend] = subAd
|
||||
|
||||
# if BT == 16:
|
||||
# return Ad
|
||||
|
||||
return Ad_modify
|
||||
|
||||
# @input_guard
|
||||
def solve_tril(
|
||||
A: torch.Tensor,
|
||||
cu_seqlens: Optional[torch.Tensor] = None,
|
||||
output_dtype: torch.dtype = torch.float,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Compute the inverse of the lower triangular matrix
|
||||
A should be strictly lower triangular, i.e., A.triu() == 0.
|
||||
|
||||
Args:
|
||||
A (torch.Tensor):
|
||||
[B, T, H, K]
|
||||
cu_seqlens (torch.Tensor):
|
||||
The cumulative sequence lengths of the input tensor.
|
||||
Default: None.
|
||||
output_dtype (torch.dtype):
|
||||
The dtype of the output tensor. Default: `torch.float`
|
||||
|
||||
Returns:
|
||||
(I + A)^-1 with the same shape as A
|
||||
"""
|
||||
assert A.shape[-1] in [16, 32, 64]
|
||||
|
||||
B, T, H, BT = A.shape
|
||||
# cnt = 0
|
||||
# for b in range(B):
|
||||
# for t in range(T):
|
||||
# for h in range(H):
|
||||
# for d in range(BT):
|
||||
# A[b, t, h, d] = cnt
|
||||
# cnt += 1
|
||||
|
||||
Ad = -999 * torch.ones(
|
||||
B, T, H, 16, device=A.device, dtype=torch.float if BT != 16 else output_dtype
|
||||
)
|
||||
# cnt = 0
|
||||
# for b in range(B):
|
||||
# for t in range(T):
|
||||
# for h in range(H):
|
||||
# for d in range(16):
|
||||
# Ad[b, t, h, d] = cnt
|
||||
# cnt += 1
|
||||
|
||||
Ad_modify = Ad.clone()
|
||||
|
||||
chunk_indices = (
|
||||
prepare_chunk_indices(cu_seqlens, 16) if cu_seqlens is not None else None
|
||||
)
|
||||
NT = len(chunk_indices) if cu_seqlens is not None else triton.cdiv(T, 16)
|
||||
|
||||
import os
|
||||
if os.getenv("TRITON_INTERPRET", None) == "1":
|
||||
solve_tril_16x16_kernel[NT, B * H](
|
||||
A=A,
|
||||
Ad=Ad,
|
||||
cu_seqlens=cu_seqlens,
|
||||
chunk_indices=chunk_indices,
|
||||
T=T,
|
||||
H=H,
|
||||
BT=BT,
|
||||
num_warps=1,
|
||||
num_stages=4,
|
||||
)
|
||||
return Ad
|
||||
|
||||
Ad_modify = solve_tril_16x16_kernel_new(
|
||||
NT,
|
||||
B,
|
||||
A=A,
|
||||
Ad=Ad_modify,
|
||||
cu_seqlens=cu_seqlens,
|
||||
chunk_indices=chunk_indices,
|
||||
T=T,
|
||||
H=H,
|
||||
BT=BT,
|
||||
IS_VARLEN= True if cu_seqlens is not None else False,
|
||||
# num_warps=1,
|
||||
# num_stages=4,
|
||||
).to(A.dtype)
|
||||
return Ad_modify
|
||||
85
vllm_kunlun/ops/fla/torch_fla.py
Normal file
85
vllm_kunlun/ops/fla/torch_fla.py
Normal file
@@ -0,0 +1,85 @@
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
def l2norm(x: torch.FloatTensor, dim: int = -1, eps: float = 1e-6):
|
||||
"""This function is intended to align with the l2norm implementation in the FLA library."""
|
||||
inv_norm = torch.rsqrt((x * x).sum(dim=dim, keepdim=True) + eps)
|
||||
return x * inv_norm
|
||||
|
||||
def torch_chunk_gated_delta_rule(
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
g,
|
||||
beta,
|
||||
chunk_size=64,
|
||||
initial_state=None,
|
||||
output_final_state=False,
|
||||
use_qk_l2norm_in_kernel=False,
|
||||
):
|
||||
initial_dtype = query.dtype
|
||||
if use_qk_l2norm_in_kernel:
|
||||
query = l2norm(query, dim=-1, eps=1e-6)
|
||||
key = l2norm(key, dim=-1, eps=1e-6)
|
||||
query, key, value, beta, g = [
|
||||
x.transpose(1, 2).contiguous().to(torch.float32) for x in (query, key, value, beta, g)
|
||||
]
|
||||
|
||||
batch_size, num_heads, sequence_length, k_head_dim = key.shape
|
||||
v_head_dim = value.shape[-1]
|
||||
pad_size = (chunk_size - sequence_length % chunk_size) % chunk_size
|
||||
query = F.pad(query, (0, 0, 0, pad_size))
|
||||
key = F.pad(key, (0, 0, 0, pad_size))
|
||||
value = F.pad(value, (0, 0, 0, pad_size))
|
||||
beta = F.pad(beta, (0, pad_size))
|
||||
g = F.pad(g, (0, pad_size))
|
||||
total_sequence_length = sequence_length + pad_size
|
||||
scale = 1 / (query.shape[-1] ** 0.5)
|
||||
query = query * scale
|
||||
|
||||
v_beta = value * beta.unsqueeze(-1)
|
||||
k_beta = key * beta.unsqueeze(-1)
|
||||
# reshape to chunks
|
||||
query, key, value, k_beta, v_beta = [
|
||||
x.reshape(x.shape[0], x.shape[1], -1, chunk_size, x.shape[-1]) for x in (query, key, value, k_beta, v_beta)
|
||||
]
|
||||
g = g.reshape(g.shape[0], g.shape[1], -1, chunk_size)
|
||||
mask = torch.triu(torch.ones(chunk_size, chunk_size, dtype=torch.bool, device=query.device), diagonal=0)
|
||||
|
||||
# chunk decay
|
||||
g = g.cumsum(dim=-1)
|
||||
decay_mask = ((g.unsqueeze(-1) - g.unsqueeze(-2)).tril().exp().float()).tril()
|
||||
attn = -((k_beta @ key.transpose(-1, -2)) * decay_mask).masked_fill(mask, 0)
|
||||
for i in range(1, chunk_size):
|
||||
row = attn[..., i, :i].clone()
|
||||
sub = attn[..., :i, :i].clone()
|
||||
attn[..., i, :i] = row + (row.unsqueeze(-1) * sub).sum(-2)
|
||||
attn = attn + torch.eye(chunk_size, dtype=attn.dtype, device=attn.device)
|
||||
value = attn @ v_beta
|
||||
k_cumdecay = attn @ (k_beta * g.exp().unsqueeze(-1))
|
||||
last_recurrent_state = (
|
||||
torch.zeros(batch_size, num_heads, k_head_dim, v_head_dim).to(value)
|
||||
if initial_state is None
|
||||
else initial_state.to(value)
|
||||
)
|
||||
core_attn_out = torch.zeros_like(value)
|
||||
mask = torch.triu(torch.ones(chunk_size, chunk_size, dtype=torch.bool, device=query.device), diagonal=1)
|
||||
|
||||
# for each chunk
|
||||
for i in range(0, total_sequence_length // chunk_size):
|
||||
q_i, k_i, v_i = query[:, :, i], key[:, :, i], value[:, :, i]
|
||||
attn = (q_i @ k_i.transpose(-1, -2) * decay_mask[:, :, i]).masked_fill_(mask, 0)
|
||||
v_prime = (k_cumdecay[:, :, i]) @ last_recurrent_state
|
||||
v_new = v_i - v_prime
|
||||
attn_inter = (q_i * g[:, :, i, :, None].exp()) @ last_recurrent_state
|
||||
core_attn_out[:, :, i] = attn_inter + attn @ v_new
|
||||
last_recurrent_state = (
|
||||
last_recurrent_state * g[:, :, i, -1, None, None].exp()
|
||||
+ (k_i * (g[:, :, i, -1, None] - g[:, :, i]).exp()[..., None]).transpose(-1, -2) @ v_new
|
||||
)
|
||||
|
||||
if not output_final_state:
|
||||
last_recurrent_state = None
|
||||
core_attn_out = core_attn_out.reshape(core_attn_out.shape[0], core_attn_out.shape[1], -1, core_attn_out.shape[-1])
|
||||
core_attn_out = core_attn_out[:, :, :sequence_length]
|
||||
core_attn_out = core_attn_out.transpose(1, 2).contiguous().to(initial_dtype)
|
||||
return core_attn_out, last_recurrent_state
|
||||
180
vllm_kunlun/ops/fla/utils.py
Normal file
180
vllm_kunlun/ops/fla/utils.py
Normal file
@@ -0,0 +1,180 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang
|
||||
#
|
||||
# This file contains code copied from the flash-linear-attention project.
|
||||
# The original source code was licensed under the MIT license and included
|
||||
# the following copyright notice:
|
||||
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
|
||||
# ruff: noqa: E501
|
||||
import contextlib
|
||||
import functools
|
||||
import logging
|
||||
import os
|
||||
from enum import Enum
|
||||
from typing import Any, Callable, Literal, Optional
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.triton_utils import triton
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
COMPILER_MODE = os.getenv("FLA_COMPILER_MODE") == "1"
|
||||
FLA_CI_ENV = os.getenv("FLA_CI_ENV") == "1"
|
||||
FLA_GDN_FIX_BT = os.getenv("FLA_GDN_FIX_BT", "0") == "1"
|
||||
|
||||
SUPPRESS_LEVEL = int(os.getenv("GDN_RECOMPUTE_SUPPRESS_LEVEL", "0"))
|
||||
|
||||
|
||||
def tensor_cache(
|
||||
fn: Callable[..., torch.Tensor]) -> Callable[..., torch.Tensor]:
|
||||
"""
|
||||
A decorator that caches the most recent results of a function with tensor inputs.
|
||||
|
||||
This decorator will store the output of the decorated function for the most recent set of input tensors.
|
||||
The cache is limited to a fixed size (default is 4). When the cache is full, the oldest entry will be removed.
|
||||
|
||||
Args:
|
||||
fn (Callable[..., torch.Tensor]):
|
||||
The function to be decorated. It should take tensor inputs and return tensor outputs.
|
||||
|
||||
Returns:
|
||||
Callable[..., torch.Tensor]:
|
||||
A wrapped version of the input function with single-entry caching.
|
||||
"""
|
||||
|
||||
cache_entries: tuple[Optional[tuple], Optional[dict], Any] = []
|
||||
cache_size = 4
|
||||
|
||||
@functools.wraps(fn)
|
||||
def wrapper(*args: Any, **kwargs: Any) -> Any:
|
||||
nonlocal cache_entries, cache_size
|
||||
for i, entry in enumerate(cache_entries):
|
||||
last_args, last_kwargs, last_result = entry
|
||||
if len(args) == len(last_args) and len(kwargs) == len(last_kwargs) \
|
||||
and all(a is b for a, b in zip(args, last_args)) \
|
||||
and all(k in last_kwargs and v is last_kwargs[k] for k, v in kwargs.items()):
|
||||
cache_entries = cache_entries[:i] + cache_entries[i + 1:] + [
|
||||
(args, kwargs, last_result)
|
||||
]
|
||||
return last_result
|
||||
|
||||
result = fn(*args, **kwargs)
|
||||
|
||||
if len(cache_entries) >= cache_size:
|
||||
cache_entries = cache_entries[1:]
|
||||
cache_entries.append((args, kwargs, result))
|
||||
return result
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
def input_guard(
|
||||
fn: Callable[..., torch.Tensor]) -> Callable[..., torch.Tensor]:
|
||||
"""
|
||||
A decorator to make sure all input tensors are contiguous and set the device based on input tensors.
|
||||
"""
|
||||
|
||||
@functools.wraps(fn)
|
||||
def wrapper(*args, **kwargs):
|
||||
contiguous_args = (i if not isinstance(i, torch.Tensor) else
|
||||
i.contiguous() for i in args)
|
||||
contiguous_kwargs = {
|
||||
k: (v if not isinstance(v, torch.Tensor) else v.contiguous())
|
||||
for k, v in kwargs.items()
|
||||
}
|
||||
|
||||
tensor = None
|
||||
for arg in args:
|
||||
if isinstance(arg, torch.Tensor):
|
||||
tensor = arg
|
||||
break
|
||||
if tensor is None:
|
||||
for value in kwargs.values():
|
||||
if isinstance(value, torch.Tensor):
|
||||
tensor = value
|
||||
break
|
||||
|
||||
if tensor is not None:
|
||||
ctx = torch.cuda.device(tensor.device.index)
|
||||
else:
|
||||
ctx = contextlib.nullcontext()
|
||||
|
||||
with ctx:
|
||||
return fn(*contiguous_args, **contiguous_kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
@functools.cache
|
||||
def get_available_device() -> str:
|
||||
try:
|
||||
return triton.runtime.driver.active.get_current_target().backend
|
||||
except BaseException:
|
||||
return 'cpu'
|
||||
|
||||
|
||||
@functools.cache
|
||||
def _check_platform() -> Literal['nvidia', 'amd', 'intel', 'musa']:
|
||||
device = get_available_device()
|
||||
mapping = {
|
||||
"cuda": "nvidia",
|
||||
"hip": "amd",
|
||||
"xpu": "intel",
|
||||
}
|
||||
# return the mapped value, or the original if not found
|
||||
return mapping.get(device, device)
|
||||
|
||||
|
||||
# For AMD GPUs, the triton backend is 'hip', while for Nvidia GPUs, the triton backend is 'cuda'.
|
||||
# However, the torch backend is 'cuda' for both Nvidia and AMD GPUs.
|
||||
# Therefore, we need to check the triton backend to determine the actual GPU vendor.
|
||||
device = get_available_device() if get_available_device() != 'hip' else 'cuda'
|
||||
device_torch_lib = getattr(torch, device)
|
||||
device_platform = _check_platform()
|
||||
|
||||
is_amd = (device_platform == 'amd')
|
||||
is_intel = (device_platform == 'nvidia')
|
||||
is_nvidia = (device_platform == 'nvidia')
|
||||
is_intel_alchemist = (is_intel
|
||||
and 'Intel(R) Arc(TM) A' in torch.xpu.get_device_name(0))
|
||||
is_nvidia_hopper = (is_nvidia
|
||||
and ('NVIDIA H' in torch.cuda.get_device_name(0)
|
||||
or torch.cuda.get_device_capability()[0] >= 9))
|
||||
use_cuda_graph = (is_nvidia
|
||||
and os.environ.get('FLA_USE_CUDA_GRAPH', '0') == '1')
|
||||
|
||||
|
||||
def get_all_max_shared_mem():
|
||||
try:
|
||||
return [
|
||||
triton.runtime.driver.active.utils.get_device_properties(i)
|
||||
['max_shared_mem'] for i in range(device_torch_lib.device_count())
|
||||
]
|
||||
except BaseException:
|
||||
return [-1]
|
||||
|
||||
|
||||
class Backend(Enum):
|
||||
ADA = 101376 # RTX 4090
|
||||
AMPERE = 166912 # A100
|
||||
HOPPER = 232448 # H100
|
||||
DEFAULT = 102400 # Default
|
||||
|
||||
@classmethod
|
||||
def get_shared_memory(cls, arch: str) -> int:
|
||||
try:
|
||||
return cls[arch.upper()].value
|
||||
except KeyError:
|
||||
return cls.DEFAULT.value
|
||||
|
||||
|
||||
@functools.cache
|
||||
def check_shared_mem(arch: str = "none", tensor_idx: int = 0) -> bool:
|
||||
try:
|
||||
device_shared_mem_list = get_all_max_shared_mem()
|
||||
max_shared_memory = device_shared_mem_list[tensor_idx]
|
||||
return max_shared_memory >= Backend.get_shared_memory(arch)
|
||||
except Exception:
|
||||
return False
|
||||
247
vllm_kunlun/ops/fla/wy_fast.py
Normal file
247
vllm_kunlun/ops/fla/wy_fast.py
Normal file
@@ -0,0 +1,247 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang
|
||||
#
|
||||
# This file contains code copied from the flash-linear-attention project.
|
||||
# The original source code was licensed under the MIT license and included
|
||||
# the following copyright notice:
|
||||
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
|
||||
|
||||
# ruff: noqa: E501
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.triton_utils import tl, triton
|
||||
|
||||
from .index import prepare_chunk_indices
|
||||
|
||||
RESOLUTION = {
|
||||
torch.bool: 0,
|
||||
torch.int16: 0,
|
||||
torch.int32: 0,
|
||||
torch.int64: 0,
|
||||
torch.float16: 1e-3,
|
||||
torch.float32: 1.3e-6,
|
||||
torch.bfloat16: 0.016,
|
||||
torch.complex32: 1e-3,
|
||||
torch.complex64: 1.3e-6,
|
||||
}
|
||||
|
||||
def assert_close(res, ref, dtype, equal_nan=False, reduce_dim=1):
|
||||
assert res.dtype == dtype
|
||||
ref = ref.to(dtype)
|
||||
atol = 1e-3 * reduce_dim
|
||||
rtol = RESOLUTION[dtype]
|
||||
torch.testing.assert_close(res, ref, atol=atol, rtol=rtol, equal_nan=equal_nan)
|
||||
|
||||
@triton.heuristics({"IS_VARLEN": lambda args: args["cu_seqlens"] is not None})
|
||||
# @triton.autotune(
|
||||
# configs=[
|
||||
# triton.Config({}, num_warps=num_warps, num_stages=num_stages)
|
||||
# for num_warps in [2, 4, 8]
|
||||
# for num_stages in [2, 3, 4]
|
||||
# ],
|
||||
# key=["H", "K", "V", "BT", "BK", "BV", "IS_VARLEN"],
|
||||
# )
|
||||
@triton.jit(do_not_specialize=["T"])
|
||||
def recompute_u_fwd_kernel(
|
||||
k,
|
||||
v,
|
||||
beta,
|
||||
w,
|
||||
u,
|
||||
A,
|
||||
g,
|
||||
cu_seqlens,
|
||||
chunk_indices,
|
||||
T,
|
||||
H: tl.constexpr,
|
||||
Hg: tl.constexpr,
|
||||
K: tl.constexpr,
|
||||
V: tl.constexpr,
|
||||
BT: tl.constexpr,
|
||||
BK: tl.constexpr,
|
||||
BV: tl.constexpr,
|
||||
IS_VARLEN: tl.constexpr,
|
||||
):
|
||||
i_t, i_bh = tl.program_id(0), tl.program_id(1)
|
||||
i_b, i_h = i_bh // H, i_bh % H
|
||||
if IS_VARLEN:
|
||||
i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(
|
||||
chunk_indices + i_t * 2 + 1
|
||||
).to(tl.int32)
|
||||
bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(
|
||||
cu_seqlens + i_n + 1
|
||||
).to(tl.int32)
|
||||
T = eos - bos
|
||||
else:
|
||||
bos, eos = i_b * T, i_b * T + T
|
||||
p_beta = tl.make_block_ptr(
|
||||
beta + bos * H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,)
|
||||
)
|
||||
p_g = tl.make_block_ptr(g + (bos * H + i_h), (T,), (H,), (i_t * BT,), (BT,), (0,))
|
||||
p_A = tl.make_block_ptr(
|
||||
A + (bos * H + i_h) * BT, (T, BT), (H * BT, 1), (i_t * BT, 0), (BT, BT), (1, 0)
|
||||
)
|
||||
b_beta = tl.load(p_beta, boundary_check=(0,))
|
||||
b_A = tl.load(p_A, boundary_check=(0, 1))
|
||||
|
||||
for i_v in range(tl.cdiv(V, BV)):
|
||||
p_v = tl.make_block_ptr(
|
||||
v + (bos * H + i_h) * V,
|
||||
(T, V),
|
||||
(H * V, 1),
|
||||
(i_t * BT, i_v * BV),
|
||||
(BT, BV),
|
||||
(1, 0),
|
||||
)
|
||||
p_u = tl.make_block_ptr(
|
||||
u + (bos * H + i_h) * V,
|
||||
(T, V),
|
||||
(H * V, 1),
|
||||
(i_t * BT, i_v * BV),
|
||||
(BT, BV),
|
||||
(1, 0),
|
||||
)
|
||||
b_v = tl.load(p_v, boundary_check=(0, 1))
|
||||
b_vb = (b_v * b_beta[:, None]).to(b_v.dtype)
|
||||
b_u = tl.dot(b_A, b_vb, allow_tf32=False)
|
||||
tl.store(p_u, b_u.to(p_u.dtype.element_ty), boundary_check=(0, 1))
|
||||
|
||||
|
||||
|
||||
@triton.heuristics({"IS_VARLEN": lambda args: args["cu_seqlens"] is not None})
|
||||
# @triton.autotune(
|
||||
# configs=[
|
||||
# triton.Config({}, num_warps=num_warps, num_stages=num_stages)
|
||||
# for num_warps in [2, 4, 8]
|
||||
# for num_stages in [2, 3, 4]
|
||||
# ],
|
||||
# key=["H", "K", "V", "BT", "BK", "BV", "IS_VARLEN"],
|
||||
# )
|
||||
@triton.jit(do_not_specialize=["T"])
|
||||
def recompute_w_fwd_kernel(
|
||||
k,
|
||||
v,
|
||||
beta,
|
||||
w,
|
||||
u,
|
||||
A,
|
||||
g,
|
||||
cu_seqlens,
|
||||
chunk_indices,
|
||||
T,
|
||||
H: tl.constexpr,
|
||||
Hg: tl.constexpr,
|
||||
K: tl.constexpr,
|
||||
V: tl.constexpr,
|
||||
BT: tl.constexpr,
|
||||
BK: tl.constexpr,
|
||||
BV: tl.constexpr,
|
||||
IS_VARLEN: tl.constexpr,
|
||||
):
|
||||
i_t, i_bh = tl.program_id(0), tl.program_id(1)
|
||||
i_b, i_h = i_bh // H, i_bh % H
|
||||
if IS_VARLEN:
|
||||
i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(
|
||||
chunk_indices + i_t * 2 + 1
|
||||
).to(tl.int32)
|
||||
bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(
|
||||
cu_seqlens + i_n + 1
|
||||
).to(tl.int32)
|
||||
T = eos - bos
|
||||
else:
|
||||
bos, eos = i_b * T, i_b * T + T
|
||||
p_beta = tl.make_block_ptr(
|
||||
beta + bos * H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,)
|
||||
)
|
||||
p_g = tl.make_block_ptr(g + (bos * H + i_h), (T,), (H,), (i_t * BT,), (BT,), (0,))
|
||||
p_A = tl.make_block_ptr(
|
||||
A + (bos * H + i_h) * BT, (T, BT), (H * BT, 1), (i_t * BT, 0), (BT, BT), (1, 0)
|
||||
)
|
||||
b_beta = tl.load(p_beta, boundary_check=(0,))
|
||||
b_A = tl.load(p_A, boundary_check=(0, 1))
|
||||
b_g = tl.exp(tl.load(p_g, boundary_check=(0,)))
|
||||
|
||||
for i_k in range(tl.cdiv(K, BK)):
|
||||
p_k = tl.make_block_ptr(
|
||||
k + (bos * Hg + i_h // (H // Hg)) * K,
|
||||
(T, K),
|
||||
(Hg * K, 1),
|
||||
(i_t * BT, i_k * BK),
|
||||
(BT, BK),
|
||||
(1, 0),
|
||||
)
|
||||
p_w = tl.make_block_ptr(
|
||||
w + (bos * H + i_h) * K,
|
||||
(T, K),
|
||||
(H * K, 1),
|
||||
(i_t * BT, i_k * BK),
|
||||
(BT, BK),
|
||||
(1, 0),
|
||||
)
|
||||
b_k = tl.load(p_k, boundary_check=(0, 1))
|
||||
b_kb = (b_k * b_beta[:, None] * b_g[:, None]).to(b_k.dtype)
|
||||
b_w = tl.dot(b_A, b_kb)
|
||||
tl.store(p_w, b_w.to(p_w.dtype.element_ty), boundary_check=(0, 1))
|
||||
|
||||
|
||||
def recompute_w_u_fwd(
|
||||
k: torch.Tensor,
|
||||
v: torch.Tensor,
|
||||
beta: torch.Tensor,
|
||||
g_cumsum: torch.Tensor,
|
||||
A: torch.Tensor,
|
||||
cu_seqlens: Optional[torch.LongTensor],
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
B, T, Hg, K, V = *k.shape, v.shape[-1]
|
||||
H = v.shape[-2]
|
||||
BT = A.shape[-1]
|
||||
|
||||
chunk_indices = prepare_chunk_indices(
|
||||
cu_seqlens, BT) if cu_seqlens is not None else None
|
||||
NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices)
|
||||
BK = 64
|
||||
BV = 64
|
||||
u = torch.empty_like(v)
|
||||
w = k.new_empty(B, T, H, K)
|
||||
recompute_u_fwd_kernel[(NT, B * H)](
|
||||
k=k,
|
||||
v=v,
|
||||
beta=beta,
|
||||
w=w,
|
||||
u=u,
|
||||
A=A,
|
||||
g=g_cumsum,
|
||||
cu_seqlens=cu_seqlens,
|
||||
chunk_indices=chunk_indices,
|
||||
T=T,
|
||||
H=H,
|
||||
Hg=Hg,
|
||||
K=K,
|
||||
V=V,
|
||||
BT=BT,
|
||||
BK=BK,
|
||||
BV=BV,
|
||||
)
|
||||
recompute_w_fwd_kernel[(NT, B * H)](
|
||||
k=k,
|
||||
v=v,
|
||||
beta=beta,
|
||||
w=w,
|
||||
u=u,
|
||||
A=A,
|
||||
g=g_cumsum,
|
||||
cu_seqlens=cu_seqlens,
|
||||
chunk_indices=chunk_indices,
|
||||
T=T,
|
||||
H=H,
|
||||
Hg=Hg,
|
||||
K=K,
|
||||
V=V,
|
||||
BT=BT,
|
||||
BK=BK,
|
||||
BV=BV,
|
||||
)
|
||||
return w, u
|
||||
@@ -1,29 +1,13 @@
|
||||
#
|
||||
# Copyright (c) 2025 Baidu, Inc. All Rights Reserved.
|
||||
# Copyright 2023 The vLLM team.
|
||||
# Author: Dong Xinyu, Chen Zhennan, Bao Qian, Yuan Jizhong
|
||||
# Email: dongxinyu03@baidu.com
|
||||
# This file is a part of the vllm-kunlun project.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""layer.py"""
|
||||
import torch
|
||||
import os
|
||||
from typing import Callable, Optional
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.config import get_current_vllm_config
|
||||
from vllm.forward_context import ForwardContext, get_forward_context
|
||||
from vllm.distributed import get_ep_group
|
||||
from vllm.distributed.eplb.eplb_state import EplbState
|
||||
|
||||
from vllm.model_executor.layers.fused_moe import FusedMoE as VllmFusedMoE
|
||||
from vllm.model_executor.layers.fused_moe import FusedMoEMethodBase as VllmFusedMoEMethodBase
|
||||
@@ -101,7 +85,6 @@ class UnquantizedFusedMoEMethod(VllmUnquantizedFusedMoEMethod):
|
||||
) -> torch.Tensor:
|
||||
"""forward_kunlun"""
|
||||
from vllm_kunlun.ops._kunlun_ops import KunlunOps as ops
|
||||
|
||||
if self.moe.use_ep:
|
||||
return ops.fused_moe_ep(x,
|
||||
layer.w13_weight,
|
||||
@@ -116,6 +99,96 @@ class UnquantizedFusedMoEMethod(VllmUnquantizedFusedMoEMethod):
|
||||
num_expert_group=num_expert_group,
|
||||
topk_group=topk_group
|
||||
)
|
||||
# fused_moe do not support expert number > 400
|
||||
elif layer.local_num_experts > 400:
|
||||
hidden_states = x
|
||||
global_num_experts = linear_weights.shape[0]
|
||||
M, N = hidden_states.shape
|
||||
hidden_dim = layer.w2_weight.shape[1]
|
||||
normed_score = torch.empty(M,
|
||||
top_k,
|
||||
dtype=torch.float32,
|
||||
device=hidden_states.device)
|
||||
topk_ids = torch.empty(M,
|
||||
top_k,
|
||||
dtype=torch.int32,
|
||||
device=hidden_states.device)
|
||||
num_blocks = 12
|
||||
block_statistic = torch.zeros(
|
||||
num_blocks, global_num_experts, dtype=torch.int32, device=hidden_states.device
|
||||
)
|
||||
|
||||
router_logits = router_logits.float()
|
||||
torch.ops._C.moe_softmax_topk_norm(
|
||||
x=router_logits,
|
||||
normed_score=normed_score,
|
||||
topk_index=topk_ids,
|
||||
block_statistic=None,
|
||||
stable=True)
|
||||
|
||||
moe_expand = torch.empty((M * top_k, N), dtype=hidden_states.dtype, device=hidden_states.device) # [M, top_k, N], float
|
||||
expert_m = torch.zeros(global_num_experts, dtype=torch.int32, device=hidden_states.device) # [E]
|
||||
sorted_tokens_num_lod = torch.zeros(global_num_experts + 1, dtype=torch.int32, device=hidden_states.device) # [E+1]
|
||||
sorted_tokens_idx = torch.zeros(M * top_k, dtype=torch.int32, device=hidden_states.device)
|
||||
|
||||
torch.ops._C.gen_block_statistic(topk_ids,block_statistic)
|
||||
|
||||
torch.ops._C.moe_pre_sorted(
|
||||
x=hidden_states,
|
||||
topk_index=topk_ids,
|
||||
block_statistic=block_statistic,
|
||||
moe_expand=moe_expand,
|
||||
moe_index=sorted_tokens_idx,
|
||||
expert_m=expert_m,
|
||||
sorted_tokens_num_lod=sorted_tokens_num_lod)
|
||||
|
||||
y = torch.empty(M,top_k,
|
||||
layer.w13_weight.shape[1],
|
||||
dtype=hidden_states.dtype,
|
||||
device=hidden_states.device)
|
||||
|
||||
moe_expand = moe_expand.view(M * top_k, hidden_dim)
|
||||
|
||||
torch.ops._C.moe_fc(
|
||||
x=moe_expand,
|
||||
weight=layer.w13_weight,
|
||||
sorted_tokens_num_lod=sorted_tokens_num_lod,
|
||||
sorted_tokens_idx=sorted_tokens_idx,
|
||||
moe_topk=top_k,
|
||||
y=y)
|
||||
|
||||
d = y.shape[-1] // 2
|
||||
output_shape = (y.shape[:-1] + (d, ))
|
||||
out1 = torch.empty(output_shape, dtype=y.dtype, device=y.device)
|
||||
torch.ops._C.swiglu(y, out1)
|
||||
|
||||
out = torch.empty(M,top_k,
|
||||
layer.w2_weight.shape[1],
|
||||
dtype=hidden_states.dtype,
|
||||
device=hidden_states.device)
|
||||
|
||||
out1 = out1.reshape(-1, out1.shape[-1])
|
||||
|
||||
torch.ops._C.moe_fc(
|
||||
x=out1,
|
||||
weight=layer.w2_weight,
|
||||
sorted_tokens_num_lod=sorted_tokens_num_lod,
|
||||
sorted_tokens_idx=sorted_tokens_idx,
|
||||
moe_topk=top_k,
|
||||
y=out)
|
||||
|
||||
dequant_scale = torch.ones([M, top_k], dtype = torch.float32, device=out.device)
|
||||
output = torch.empty([M, N], dtype=hidden_states.dtype, device=hidden_states.device)
|
||||
sorted_tokens_idx = sorted_tokens_idx.view(M, top_k)
|
||||
|
||||
torch.ops._C.moe_post(
|
||||
x=out,
|
||||
moe_index=sorted_tokens_idx,
|
||||
normed_scale=normed_score,
|
||||
dequant_scale=dequant_scale,
|
||||
y=output
|
||||
)
|
||||
return output
|
||||
else:
|
||||
return ops.fused_moe(x,
|
||||
layer.w13_weight,
|
||||
@@ -155,6 +228,7 @@ class FusedMoE(VllmFusedMoE):
|
||||
activation: str = "silu",
|
||||
enable_eplb: bool = False,
|
||||
num_redundant_experts: int = 0,
|
||||
is_sequence_parallel=False,
|
||||
):
|
||||
super().__init__(
|
||||
num_experts=num_experts, # Global number of experts
|
||||
@@ -189,7 +263,7 @@ class FusedMoE(VllmFusedMoE):
|
||||
# since model_config is not set in the pytest test.
|
||||
model_dtype = params_dtype
|
||||
|
||||
moe = FusedMoEConfig.make(
|
||||
moe = FusedMoEConfig(
|
||||
num_experts=self.global_num_experts,
|
||||
experts_per_token=top_k,
|
||||
hidden_dim=hidden_size,
|
||||
@@ -197,7 +271,7 @@ class FusedMoE(VllmFusedMoE):
|
||||
moe_parallel_config=self.moe_parallel_config,
|
||||
in_dtype=model_dtype,
|
||||
max_num_tokens=envs.VLLM_MOE_DP_CHUNK_SIZE,
|
||||
quant_config=quant_config,
|
||||
# quant_config=quant_config,
|
||||
)
|
||||
self.moe_config = moe
|
||||
self.quant_config = quant_config
|
||||
@@ -307,4 +381,35 @@ class FusedMoE(VllmFusedMoE):
|
||||
final_hidden_states = self.maybe_all_reduce_tensor_model_parallel(
|
||||
final_hidden_states)
|
||||
|
||||
return final_hidden_states
|
||||
return final_hidden_states
|
||||
@classmethod
|
||||
def make_expert_params_mapping(
|
||||
cls,
|
||||
ckpt_gate_proj_name: str,
|
||||
ckpt_down_proj_name: str,
|
||||
ckpt_up_proj_name: str,
|
||||
num_experts: int,
|
||||
num_redundant_experts: int = 0) -> list[tuple[str, str, int, str]]:
|
||||
|
||||
num_physical_experts = num_experts + num_redundant_experts
|
||||
|
||||
# In the returned mapping:
|
||||
# - `expert_id` is the physical expert id
|
||||
# - `weight_name` contains the weight name of the logical expert
|
||||
# So that we should map the expert id to logical in `weight_name`
|
||||
physical_to_logical_map = \
|
||||
EplbState.build_initial_global_physical_to_logical_map(
|
||||
num_experts, num_redundant_experts)
|
||||
|
||||
return [
|
||||
# (param_name, weight_name, expert_id, shard_id)
|
||||
("experts.w13_" if weight_name
|
||||
in [ckpt_gate_proj_name, ckpt_up_proj_name] else "experts.w2_",
|
||||
f"experts.{physical_to_logical_map[expert_id]}.{weight_name}.",
|
||||
expert_id, shard_id) for expert_id in range(num_physical_experts)
|
||||
for shard_id, weight_name in [
|
||||
("w1", ckpt_gate_proj_name),
|
||||
("w2", ckpt_down_proj_name),
|
||||
("w3", ckpt_up_proj_name),
|
||||
]
|
||||
]
|
||||
|
||||
@@ -12,49 +12,101 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# This file is a part of the vllm-kunlun project.
|
||||
# This file is a part of the vllm-ascend project.
|
||||
#
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
from vllm.model_executor.layers.layernorm import GemmaRMSNorm as OriGemmaRMSNorm
|
||||
from vllm.model_executor.layers import layernorm
|
||||
from typing import Optional, Union
|
||||
import xtorch_ops
|
||||
|
||||
|
||||
def vllm_kunlun_forward_cuda(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
residual: Optional[torch.Tensor] = None,
|
||||
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
|
||||
"""forward_cuda"""
|
||||
if x.is_contiguous() == False:
|
||||
# kunlun does not support uncontiguous input and they do not think it is a bug
|
||||
# so we must make it contiguous() manually
|
||||
x = x.contiguous()
|
||||
if self.variance_size_override is not None:
|
||||
return self.forward_native(x, residual)
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
residual: Optional[torch.Tensor] = None,
|
||||
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
|
||||
"""forward_cuda"""
|
||||
if x.is_contiguous() == False:
|
||||
# kunlun does not support uncontiguous input and they do not think it is a bug
|
||||
# so we must make it contiguous() manually
|
||||
x = x.contiguous()
|
||||
if self.variance_size_override is not None:
|
||||
return self.forward_native(x, residual)
|
||||
|
||||
if residual is not None:
|
||||
# residual_output = torch.empty_like(residual)
|
||||
torch.ops._C.add_rmsnorm(
|
||||
|
||||
if residual is not None:
|
||||
# residual_output = torch.empty_like(residual)
|
||||
torch.ops._C.add_rmsnorm(
|
||||
x,
|
||||
residual,
|
||||
residual_output=residual,
|
||||
weight=self.weight.data,
|
||||
eps=self.variance_epsilon,
|
||||
output=x
|
||||
)
|
||||
return x, residual
|
||||
out = torch.empty_like(x)
|
||||
torch.ops._C.rmsnorm(
|
||||
x,
|
||||
residual,
|
||||
residual_output=residual,
|
||||
weight=self.weight.data,
|
||||
eps=self.variance_epsilon,
|
||||
output=x,
|
||||
self.weight.data,
|
||||
out,
|
||||
self.variance_epsilon,
|
||||
)
|
||||
return x, residual
|
||||
out = torch.empty_like(x)
|
||||
torch.ops._C.rmsnorm(
|
||||
x,
|
||||
self.weight.data,
|
||||
out,
|
||||
self.variance_epsilon,
|
||||
)
|
||||
return out
|
||||
return out
|
||||
|
||||
|
||||
class KunlunGemmaRMSNorm(OriGemmaRMSNorm):
|
||||
@staticmethod
|
||||
def forward_xpu(
|
||||
weight: torch.Tensor,
|
||||
variance_epsilon: float,
|
||||
x: torch.Tensor,
|
||||
residual: Optional[torch.Tensor],
|
||||
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
|
||||
if x.is_contiguous() == False:
|
||||
# kunlun does not support uncontiguous input and they do not think it is a bug
|
||||
# so we must make it contiguous() manually
|
||||
x = x.contiguous()
|
||||
|
||||
if residual is not None:
|
||||
torch.ops._C.add_rmsnorm(
|
||||
x,
|
||||
residual,
|
||||
residual_output=residual,
|
||||
weight=weight+1,
|
||||
eps=variance_epsilon,
|
||||
output=x
|
||||
)
|
||||
return x, residual
|
||||
|
||||
out = torch.empty_like(x)
|
||||
torch.ops._C.rmsnorm(
|
||||
x,
|
||||
weight+1,
|
||||
out,
|
||||
variance_epsilon,
|
||||
)
|
||||
return out
|
||||
|
||||
def forward_cuda(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
residual: Optional[torch.Tensor] = None,
|
||||
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
|
||||
if torch.compiler.is_compiling():
|
||||
self.forward_static = self.forward_xpu # only use in cudagraph
|
||||
return self.forward_native(x, residual)
|
||||
|
||||
if not getattr(self, "_is_compiled", False):
|
||||
self.forward_static = torch.compile( # type: ignore
|
||||
self.forward_static, backend="aot_eager")
|
||||
self._is_compiled = True
|
||||
return self.forward_native(x, residual)
|
||||
|
||||
|
||||
RMSNorm.forward_cuda = vllm_kunlun_forward_cuda
|
||||
RMSNorm.forward = vllm_kunlun_forward_cuda
|
||||
layernorm.GemmaRMSNorm = KunlunGemmaRMSNorm
|
||||
File diff suppressed because it is too large
Load Diff
1217
vllm_kunlun/ops/mamba/causal_conv1d.py
Normal file
1217
vllm_kunlun/ops/mamba/causal_conv1d.py
Normal file
File diff suppressed because it is too large
Load Diff
@@ -24,7 +24,6 @@ _PARTITION_SIZE = 512
|
||||
@dataclass
|
||||
class PagedAttentionMetadata:
|
||||
"""Metadata for PagedAttention."""
|
||||
|
||||
# (batch_size,). The length of sequences (entire tokens seen so far) per
|
||||
# sequence.
|
||||
seq_lens_tensor: Optional[torch.Tensor]
|
||||
@@ -53,18 +52,18 @@ class PagedAttention:
|
||||
head_size: int,
|
||||
) -> Tuple[int, ...]:
|
||||
"""
|
||||
Get the shape of the KV cache. Returns different shapes based on whether the computation is on-chip.
|
||||
If on-chip (is_kunlun() is True), returns shape (2, num_blocks, num_kv_heads, block_size, head_size);
|
||||
Otherwise, returns shape (2, num_blocks, block_size * num_kv_heads * head_size).
|
||||
|
||||
获取KV缓存的形状,根据是否在芯片上进行计算返回不同的形状。
|
||||
如果在芯片上(is_kunlun()为True),则返回形状(2, num_blocks, num_kv_heads, block_size, head_size);
|
||||
否则,返回形状(2, num_blocks, block_size * num_kv_heads * head_size)。
|
||||
|
||||
Args:
|
||||
num_blocks (int): The number of blocks.
|
||||
block_size (int): The size of each block.
|
||||
num_kv_heads (int): The number of KV heads.
|
||||
head_size (int): The size of each head.
|
||||
|
||||
num_blocks (int): 块数量。
|
||||
block_size (int): 每个块大小。
|
||||
num_kv_heads (int): KV头数量。
|
||||
head_size (int): 每个头大小。
|
||||
|
||||
Returns:
|
||||
Tuple[int, ...]: The shape of the KV cache, including two elements: the first element is 2, indicating the number of dimensions is 2; the second element is one of num_blocks, num_kv_heads, block_size, and head_size.
|
||||
Tuple[int, ...]: KV缓存的形状,包括两个元素:第一个元素为2,表示维度数量为2;第二个元素为num_blocks、num_kv_heads、block_size和head_size中的任意一个。
|
||||
"""
|
||||
if current_platform.is_kunlun():
|
||||
return (2, num_blocks, num_kv_heads, block_size, head_size)
|
||||
@@ -77,20 +76,20 @@ class PagedAttention:
|
||||
head_size: int,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Split a cached tensor (containing key and value) into two parts, each part is a tensor.
|
||||
If running on KUNLUN, the first returned tensor is the key cache, and the second tensor is the value cache.
|
||||
Otherwise, the first tensor is the key cache, and the second tensor is a view of the key cache with shape (num_blocks, num_kv_heads, head_size//x, -1, x),
|
||||
and the third tensor is the value cache with shape (num_blocks, num_kv_heads, head_size, -1).
|
||||
|
||||
将一个缓存张量(包含key和value)分成两部分,每个部分是一个张量。
|
||||
如果在KUNLUN上运行,则返回的第一个张量是key缓存,第二个张量是value缓存。
|
||||
否则,第一个张量是key缓存,第二个张量是key缓存的view,其形状为(num_blocks, num_kv_heads, head_size//x, -1, x),
|
||||
第三个张量是value缓存,其形状为(num_blocks, num_kv_heads, head_size, -1)。
|
||||
|
||||
Args:
|
||||
kv_cache (torch.Tensor): A tensor containing key and value, with shape (2, num_blocks, kv_cache_size).
|
||||
num_kv_heads (int): The number of heads in multi-head attention.
|
||||
head_size (int): The size of each head.
|
||||
|
||||
kv_cache (torch.Tensor): 包含key和value的张量,形状为(2, num_blocks, kv_cache_size)。
|
||||
num_kv_heads (int): 多头注意力中的头数。
|
||||
head_size (int): 每个头的大小。
|
||||
|
||||
Returns:
|
||||
Tuple[torch.Tensor, torch.Tensor]:
|
||||
- key_cache (torch.Tensor): A tensor containing the key cache, with shape (num_blocks, num_kv_heads, head_size//x, -1, x).
|
||||
- value_cache (torch.Tensor): A tensor containing the value cache, with shape (num_blocks, num_kv_heads, head_size, -1).
|
||||
- key_cache (torch.Tensor): 形状为(num_blocks, num_kv_heads, head_size//x, -1, x),包含key缓存。
|
||||
- value_cache (torch.Tensor): 形状为(num_blocks, num_kv_heads, head_size, -1),包含value缓存。
|
||||
"""
|
||||
x = 16 // kv_cache.element_size()
|
||||
num_blocks = kv_cache.shape[1]
|
||||
@@ -100,7 +99,8 @@ class PagedAttention:
|
||||
value_cache = kv_cache[1]
|
||||
else:
|
||||
key_cache = kv_cache[0]
|
||||
key_cache = key_cache.view(num_blocks, num_kv_heads, head_size // x, -1, x)
|
||||
key_cache = key_cache.view(num_blocks, num_kv_heads, head_size // x,
|
||||
-1, x)
|
||||
value_cache = kv_cache[1]
|
||||
value_cache = value_cache.view(num_blocks, num_kv_heads, head_size, -1)
|
||||
return key_cache, value_cache
|
||||
@@ -152,17 +152,16 @@ class PagedAttention:
|
||||
if blocksparse_vert_stride is not None and blocksparse_vert_stride > 1:
|
||||
# use blocksparse paged attention
|
||||
block_size = value_cache.size(-1)
|
||||
assert (
|
||||
blocksparse_block_size > 0 and blocksparse_block_size % block_size == 0
|
||||
), (
|
||||
f"{blocksparse_block_size=} needs to be a multiple of"
|
||||
f"{block_size=} used in block_tables."
|
||||
)
|
||||
assert (blocksparse_block_size > 0 and
|
||||
blocksparse_block_size % block_size == 0), \
|
||||
(f"{blocksparse_block_size=} needs to be a multiple of"
|
||||
f"{block_size=} used in block_tables.")
|
||||
|
||||
output = torch.empty_like(query)
|
||||
block_size = value_cache.shape[3]
|
||||
num_seqs, num_heads, head_size = query.shape
|
||||
max_num_partitions = (max_seq_len + _PARTITION_SIZE - 1) // _PARTITION_SIZE
|
||||
max_num_partitions = ((max_seq_len + _PARTITION_SIZE - 1) //
|
||||
_PARTITION_SIZE)
|
||||
# NOTE(woosuk): We use a simple heuristic to decide whether to use
|
||||
# PagedAttention V1 or V2. If the number of partitions is 1, we use
|
||||
# V1 to avoid the overhead of reduction. Also, if the number of
|
||||
@@ -170,10 +169,9 @@ class PagedAttention:
|
||||
# to parallelize.
|
||||
# TODO(woosuk): Tune this heuristic.
|
||||
# For context len > 8192, use V2 kernel to avoid shared memory shortage.
|
||||
use_v1 = max_seq_len <= 8192 and (
|
||||
max_num_partitions == 1 or num_seqs * num_heads > 512
|
||||
)
|
||||
|
||||
use_v1 = (max_seq_len <= 8192
|
||||
and (max_num_partitions == 1 or num_seqs * num_heads > 512))
|
||||
|
||||
if use_v1:
|
||||
# Run PagedAttention V1.
|
||||
ops.paged_attention_v1(
|
||||
@@ -302,4 +300,4 @@ class PagedAttention:
|
||||
) -> None:
|
||||
key_caches = [kv_cache[0] for kv_cache in kv_caches]
|
||||
value_caches = [kv_cache[1] for kv_cache in kv_caches]
|
||||
ops.copy_blocks(key_caches, value_caches, src_to_dists)
|
||||
ops.copy_blocks(key_caches, value_caches, src_to_dists)
|
||||
@@ -1,128 +0,0 @@
|
||||
#
|
||||
# Copyright (c) 2025 Baidu, Inc. All Rights Reserved.
|
||||
# Author: Li Wei, Pan Xiakai, You Zeyu
|
||||
# Email: liwei157@baidu.com
|
||||
# This file is a part of the vllm-kunlun project.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import torch
|
||||
|
||||
from typing import Optional
|
||||
from vllm.model_executor.layers.quantization.awq import AWQLinearMethod
|
||||
|
||||
|
||||
def repack_int4_for_kunlun(self, packed: torch.Tensor, num_bits: int = 4):
|
||||
"""Convert AWQ-packed int4 weights to Kunlun XPU format.
|
||||
Input: packed[N, K], dtype=int32, saved as AWQ order
|
||||
Output: packed_reordered[N, K], dtype=int32, saved as Kunlun order
|
||||
"""
|
||||
N, K = packed.shape
|
||||
self.align_type = 1 if K % 8 == 0 else 0
|
||||
assert num_bits == 4, "Only int4 supported now"
|
||||
shifts = torch.arange(0, 32, num_bits, device=packed.device, dtype=torch.int32)
|
||||
|
||||
if self.align_type == 0: # NORMAL MODE
|
||||
# Unpack AWQ order:[0, 2, 4, 6, 1, 3, 5, 7]
|
||||
unpacked_awq = (packed.unsqueeze(-1) >> shifts) & 0xF # [N, K, 8]
|
||||
|
||||
# Reverse AWQ order and convert to KUNLUN order
|
||||
AWQ_TO_KUNLUN_ORDER_NORMAL = [4, 0, 5, 1, 6, 2, 7, 3]
|
||||
# [0,2,4,6,1,3,5,7] --> [1, 0, 3, 2, 5, 4, 7, 6]
|
||||
unpacked_kunlun = unpacked_awq[..., AWQ_TO_KUNLUN_ORDER_NORMAL] # [N, K, 8]
|
||||
|
||||
# Pack to int32, order[6, 7, 4, 5, 2, 3, 0, 1]
|
||||
packed_kunlun = (unpacked_kunlun << shifts).sum(
|
||||
dim=-1, dtype=torch.int32
|
||||
) # [N, K]
|
||||
elif self.align_type == 1: # FAST MODEL
|
||||
# Unpack AWQ order
|
||||
unpacked_awq = (
|
||||
packed.view(N, K // 8, 8).unsqueeze(-1) >> shifts
|
||||
) & 0xF # [N, K//8, 8, 8]
|
||||
|
||||
# Reverse AWQ order and convert to KUNLUN order
|
||||
AWQ_TO_KUNLUN_ORDER_FAST = [
|
||||
32, 0, 36, 4, 33, 1, 37, 5,
|
||||
34, 2, 38, 6, 35, 3, 39, 7,
|
||||
40, 8, 44, 12, 41, 9, 45, 13,
|
||||
42, 10, 46, 14, 43, 11, 47, 15,
|
||||
48, 16, 52, 20, 49, 17, 53, 21,
|
||||
50, 18, 54, 22, 51, 19, 55, 23,
|
||||
56, 24, 60, 28, 57, 25, 61, 29,
|
||||
58, 26, 62, 30, 59, 27, 63, 31
|
||||
]
|
||||
unpacked_awq = unpacked_awq.reshape(N, K // 8, 64)
|
||||
unpacked_kunlun = unpacked_awq[..., AWQ_TO_KUNLUN_ORDER_FAST] # [N, K//8, 64]
|
||||
|
||||
# Pack to int32
|
||||
unpacked_kunlun = unpacked_kunlun.reshape(N, K // 8, 8, 8)
|
||||
packed_kunlun = (
|
||||
(unpacked_kunlun << shifts).sum(dim=-1, dtype=torch.int32).reshape(N, K)
|
||||
) # [N, K]
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
return packed_kunlun
|
||||
|
||||
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||
layer.qweight = torch.nn.Parameter(
|
||||
(
|
||||
self.repack_int4_for_kunlun(layer.qweight.data)
|
||||
if layer.qweight.data.dtype == torch.int32
|
||||
else layer.qweight.data
|
||||
),
|
||||
requires_grad=False,
|
||||
)
|
||||
layer.qzeros = torch.nn.Parameter(
|
||||
(
|
||||
self.repack_int4_for_kunlun(layer.qzeros.data)
|
||||
if layer.qzeros.data.dtype == torch.int32
|
||||
else layer.qzeros.data
|
||||
),
|
||||
requires_grad=False,
|
||||
)
|
||||
layer.scales = torch.nn.Parameter(layer.scales.data, requires_grad=False)
|
||||
|
||||
|
||||
def apply(
|
||||
self, layer: torch.nn.Module, x: torch.Tensor, bias: Optional[torch.Tensor] = None
|
||||
) -> torch.Tensor:
|
||||
qweight = layer.qweight
|
||||
scales = layer.scales
|
||||
qzeros = layer.qzeros
|
||||
pack_factor = self.quant_config.pack_factor
|
||||
out_shape = x.shape[:-1] + (qweight.shape[-1] * pack_factor,)
|
||||
reshaped_x = x.reshape(-1, x.shape[-1])
|
||||
|
||||
# num_tokens >= threshold
|
||||
FP16_MATMUL_HEURISTIC_CONDITION = x.shape[:-1].numel() >= 256
|
||||
|
||||
if FP16_MATMUL_HEURISTIC_CONDITION:
|
||||
out = torch.ops._C.awq_dequantize(
|
||||
qweight, scales, qzeros, quant_type=0, align_type=self.align_type
|
||||
)
|
||||
out = torch.matmul(reshaped_x, out)
|
||||
else:
|
||||
out = torch.ops._C.awq_gemm(
|
||||
reshaped_x, qweight, scales, qzeros, align_type=self.align_type
|
||||
)
|
||||
if bias is not None:
|
||||
out.add_(bias)
|
||||
return out.reshape(out_shape)
|
||||
|
||||
|
||||
AWQLinearMethod.repack_int4_for_kunlun = repack_int4_for_kunlun
|
||||
AWQLinearMethod.process_weights_after_loading = process_weights_after_loading
|
||||
AWQLinearMethod.apply = apply
|
||||
@@ -1,37 +1,14 @@
|
||||
#
|
||||
# Copyright (c) 2025 Baidu, Inc. All Rights Reserved.
|
||||
#
|
||||
# This file is a part of the vllm-kunlun project.
|
||||
# Author: Chen Zhennan, Dong Xinyu
|
||||
# Email: chenzhennan@baidu.com
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import torch
|
||||
from typing import Any, Literal, Optional, cast, Callable, Optional
|
||||
|
||||
from compressed_tensors.config import (
|
||||
CompressionFormat,
|
||||
SparsityCompressionConfig,
|
||||
SparsityStructure,
|
||||
)
|
||||
from compressed_tensors.quantization import ActivationOrdering, QuantizationStrategy
|
||||
from vllm.model_executor.layers.fused_moe import (
|
||||
FusedMoE,
|
||||
FusedMoEMethodBase,
|
||||
FusedMoeWeightScaleSupported,
|
||||
)
|
||||
from compressed_tensors.config import (CompressionFormat,
|
||||
SparsityCompressionConfig,
|
||||
SparsityStructure)
|
||||
from compressed_tensors.quantization import (ActivationOrdering,
|
||||
QuantizationStrategy)
|
||||
from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEMethodBase,
|
||||
FusedMoeWeightScaleSupported)
|
||||
from vllm.model_executor.layers.quantization.utils import replace_parameter
|
||||
|
||||
# TODO: import position will be changed after 0.9.0
|
||||
# vllm.model_executor.layers.fused_moe.fused_moe --> vllm.model_executor.layers.fused_moe
|
||||
|
||||
@@ -42,7 +19,6 @@ import xtorch_ops
|
||||
|
||||
from safetensors.torch import load_file as safe_load_file
|
||||
|
||||
|
||||
class CompressedTensorsMoEMethod(FusedMoEMethodBase):
|
||||
|
||||
def get_moe_method(quant_config, layer) -> "CompressedTensorsMoEMethod":
|
||||
@@ -50,239 +26,177 @@ class CompressedTensorsMoEMethod(FusedMoEMethodBase):
|
||||
linear_cfg = None
|
||||
for k in ("Linear", "FusedMoE", "MoE", "Moe", "Experts"):
|
||||
if k in tsm and isinstance(tsm[k], dict):
|
||||
linear_cfg = tsm[k]
|
||||
break
|
||||
linear_cfg = tsm[k]; break
|
||||
if not linear_cfg:
|
||||
# print("target_scheme_map missing; fallback to INT8(W8A8) method")
|
||||
return CompressedTensorsW8A8Int8MoEMethod(quant_config)
|
||||
wq = linear_cfg.get("weights")
|
||||
aq = linear_cfg.get("input_activations")
|
||||
wq = linear_cfg.get("weights"); aq = linear_cfg.get("input_activations")
|
||||
if not wq or not aq:
|
||||
# print("incomplete scheme; fallback to INT8(W8A8)")
|
||||
return CompressedTensorsW8A8Int8MoEMethod(quant_config)
|
||||
|
||||
# Other branches are handled as needed; default fallback:
|
||||
# 其它分流按需;默认回落:
|
||||
return CompressedTensorsW8A8Int8MoEMethod(quant_config)
|
||||
|
||||
|
||||
# copied from vllm 0.9.0
|
||||
class CompressedTensorsW8A8Int8MoEMethod(CompressedTensorsMoEMethod):
|
||||
|
||||
def __init__(
|
||||
self, quant_config: "CompressedTensorsConfig" # type: ignore # noqa E501
|
||||
self,
|
||||
quant_config: "CompressedTensorsConfig" # type: ignore # noqa E501
|
||||
):
|
||||
self.quant_config = quant_config
|
||||
|
||||
# Directly create a default quantization config dictionary to avoid validation issues with QuantizationArgs
|
||||
|
||||
# 直接创建默认的量化配置字典,避免 QuantizationArgs 的验证问题
|
||||
# print("Creating default INT8 quantization config for MoE")
|
||||
|
||||
# 创建默认的权重量化配置字典
|
||||
self.weight_quant = type('WeightQuant', (), {
|
||||
'type': 'int',
|
||||
'num_bits': 8,
|
||||
'strategy': 'channel',
|
||||
'group_size': 128,
|
||||
'symmetric': True,
|
||||
'dynamic': False,
|
||||
'actorder': 'none',
|
||||
'observer': None,
|
||||
'observer_kwargs': {},
|
||||
'block_structure': None
|
||||
})()
|
||||
|
||||
# 创建默认的输入激活量化配置字典
|
||||
self.input_quant = type('InputQuant', (), {
|
||||
'type': 'int',
|
||||
'num_bits': 8,
|
||||
'strategy': 'token',
|
||||
'group_size': 128,
|
||||
'symmetric': True,
|
||||
'dynamic': True,
|
||||
'actorder': 'none',
|
||||
'observer': None,
|
||||
'observer_kwargs': {},
|
||||
'block_structure': None
|
||||
})()
|
||||
|
||||
# Create a default weight quantization config dictionary
|
||||
self.weight_quant = type(
|
||||
"WeightQuant",
|
||||
(),
|
||||
{
|
||||
"type": "int",
|
||||
"num_bits": 8,
|
||||
"strategy": "channel",
|
||||
"group_size": 128,
|
||||
"symmetric": True,
|
||||
"dynamic": False,
|
||||
"actorder": "none",
|
||||
"observer": None,
|
||||
"observer_kwargs": {},
|
||||
"block_structure": None,
|
||||
},
|
||||
)()
|
||||
|
||||
# Create a default input activation quantization config dictionary
|
||||
self.input_quant = type(
|
||||
"InputQuant",
|
||||
(),
|
||||
{
|
||||
"type": "int",
|
||||
"num_bits": 8,
|
||||
"strategy": "token",
|
||||
"group_size": 128,
|
||||
"symmetric": True,
|
||||
"dynamic": True,
|
||||
"actorder": "none",
|
||||
"observer": None,
|
||||
"observer_kwargs": {},
|
||||
"block_structure": None,
|
||||
},
|
||||
)()
|
||||
|
||||
# Change comparison method to directly compare strings
|
||||
# 修改比较方式,直接比较字符串
|
||||
per_channel = (
|
||||
self.weight_quant.strategy == "channel"
|
||||
and self.input_quant.strategy == "token"
|
||||
)
|
||||
and self.input_quant.strategy == "token")
|
||||
if not per_channel:
|
||||
raise ValueError(
|
||||
"For INT8 Fused MoE layers, we require channelwise, "
|
||||
"dynamic per token quantization. Found "
|
||||
f"{self.weight_quant}, {self.input_quant}"
|
||||
)
|
||||
f"{self.weight_quant}, {self.input_quant}")
|
||||
|
||||
self.static_input_scales = not self.input_quant.dynamic
|
||||
if self.static_input_scales:
|
||||
raise ValueError(
|
||||
"For INT8 Fused MoE layers, we require channelwise, "
|
||||
"dynamic per token quantization. Found static input scales."
|
||||
)
|
||||
"dynamic per token quantization. Found static input scales.")
|
||||
|
||||
def create_weights1(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
num_experts: int,
|
||||
hidden_size: int,
|
||||
intermediate_size_per_partition: int,
|
||||
params_dtype: torch.dtype,
|
||||
**extra_weight_attrs,
|
||||
):
|
||||
# Use float32 as a placeholder for weights to facilitate loading original weights from ckpt
|
||||
w13_weight = torch.nn.Parameter(
|
||||
torch.empty(
|
||||
num_experts,
|
||||
2 * intermediate_size_per_partition,
|
||||
hidden_size,
|
||||
dtype=params_dtype,
|
||||
), # generally is torch.bfloat16
|
||||
requires_grad=False,
|
||||
)
|
||||
def create_weights1(self, layer: torch.nn.Module, num_experts: int, hidden_size: int, intermediate_size_per_partition: int, params_dtype: torch.dtype, **extra_weight_attrs):
|
||||
# 权重先用浮点占位,便于从 ckpt 加载原始权重
|
||||
w13_weight = torch.nn.Parameter(torch.empty(
|
||||
num_experts,
|
||||
2 * intermediate_size_per_partition,
|
||||
hidden_size,
|
||||
dtype=params_dtype), # 通常是 torch.bfloat16
|
||||
requires_grad=False)
|
||||
layer.register_parameter("w13_weight", w13_weight)
|
||||
set_weight_attrs(w13_weight, extra_weight_attrs)
|
||||
|
||||
w2_weight = torch.nn.Parameter(
|
||||
torch.empty(
|
||||
num_experts,
|
||||
hidden_size,
|
||||
intermediate_size_per_partition,
|
||||
dtype=params_dtype,
|
||||
),
|
||||
requires_grad=False,
|
||||
)
|
||||
w2_weight = torch.nn.Parameter(torch.empty(
|
||||
num_experts,
|
||||
hidden_size,
|
||||
intermediate_size_per_partition,
|
||||
dtype=params_dtype),
|
||||
requires_grad=False)
|
||||
layer.register_parameter("w2_weight", w2_weight)
|
||||
set_weight_attrs(w2_weight, extra_weight_attrs)
|
||||
|
||||
# Channel scale: float32 + 2D [E, out] (aligned with fused_moe/UT)
|
||||
# 通道 scale:float32 + 二维 [E, out](与 fused_moe/UT 对齐)
|
||||
w13_weight_scale = torch.nn.Parameter(
|
||||
torch.empty(
|
||||
num_experts, 2 * intermediate_size_per_partition, dtype=torch.float32
|
||||
),
|
||||
requires_grad=False,
|
||||
)
|
||||
torch.empty(num_experts, 2 * intermediate_size_per_partition, dtype=torch.float32),
|
||||
requires_grad=False)
|
||||
w2_weight_scale = torch.nn.Parameter(
|
||||
torch.empty(num_experts, hidden_size, dtype=torch.float32),
|
||||
requires_grad=False,
|
||||
)
|
||||
requires_grad=False)
|
||||
layer.register_parameter("w13_weight_scale", w13_weight_scale)
|
||||
layer.register_parameter("w2_weight_scale", w2_weight_scale)
|
||||
|
||||
# Input scale can be dynamically calculated
|
||||
# 输入 scale 动态计算即可
|
||||
layer.w13_input_scale = None
|
||||
layer.w2_input_scale = None
|
||||
|
||||
def create_weights(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
num_experts: int,
|
||||
hidden_size: int,
|
||||
intermediate_size_per_partition: int,
|
||||
params_dtype: torch.dtype,
|
||||
**extra_weight_attrs,
|
||||
):
|
||||
w13_weight = torch.nn.Parameter(
|
||||
torch.empty(
|
||||
num_experts,
|
||||
2 * intermediate_size_per_partition,
|
||||
hidden_size,
|
||||
dtype=torch.int8,
|
||||
), # directly use int8
|
||||
requires_grad=False,
|
||||
)
|
||||
def create_weights(self, layer: torch.nn.Module, num_experts: int, hidden_size: int, intermediate_size_per_partition: int, params_dtype: torch.dtype, **extra_weight_attrs):
|
||||
w13_weight = torch.nn.Parameter(torch.empty(
|
||||
num_experts,
|
||||
2 * intermediate_size_per_partition,
|
||||
hidden_size,
|
||||
dtype=torch.int8), # 直接使用 int8
|
||||
requires_grad=False)
|
||||
layer.register_parameter("w13_weight", w13_weight)
|
||||
set_weight_attrs(w13_weight, extra_weight_attrs)
|
||||
|
||||
w2_weight = torch.nn.Parameter(
|
||||
torch.empty(
|
||||
num_experts,
|
||||
hidden_size,
|
||||
intermediate_size_per_partition,
|
||||
dtype=torch.int8,
|
||||
), # directly use int8
|
||||
requires_grad=False,
|
||||
)
|
||||
w2_weight = torch.nn.Parameter(torch.empty(
|
||||
num_experts,
|
||||
hidden_size,
|
||||
intermediate_size_per_partition,
|
||||
dtype=torch.int8), # 直接使用 int8
|
||||
requires_grad=False)
|
||||
layer.register_parameter("w2_weight", w2_weight)
|
||||
set_weight_attrs(w2_weight, extra_weight_attrs)
|
||||
|
||||
# Scale factors
|
||||
# 缩放因子
|
||||
w13_weight_scale = torch.nn.Parameter(
|
||||
torch.empty(
|
||||
num_experts, 2 * intermediate_size_per_partition, dtype=torch.float32
|
||||
),
|
||||
requires_grad=False,
|
||||
)
|
||||
torch.empty(num_experts, 2 * intermediate_size_per_partition, dtype=torch.float32),
|
||||
requires_grad=False)
|
||||
w2_weight_scale = torch.nn.Parameter(
|
||||
torch.empty(num_experts, hidden_size, dtype=torch.float32),
|
||||
requires_grad=False,
|
||||
)
|
||||
requires_grad=False)
|
||||
layer.register_parameter("w13_weight_scale", w13_weight_scale)
|
||||
layer.register_parameter("w2_weight_scale", w2_weight_scale)
|
||||
|
||||
# Input scale can be dynamically calculated
|
||||
# 输入 scale 动态计算
|
||||
layer.w13_input_scale = None
|
||||
layer.w2_input_scale = None
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||
return
|
||||
# Convert original weights to float32 for more robust statistics
|
||||
#原始权重转 float32 做统计更稳健
|
||||
w13_f = layer.w13_weight.float()
|
||||
w2_f = layer.w2_weight.float()
|
||||
w2_f = layer.w2_weight.float()
|
||||
|
||||
# Each column (abs_max) -> per-column scale (out dimension is dim=1, column is dim=-1)
|
||||
# 每列(abs_max) -> per-column scale(out 维在 dim=1,列在 dim=-1)
|
||||
qmax = 127.0
|
||||
w13_abs_max = torch.amax(torch.abs(w13_f), dim=-1) # [E, 2N]
|
||||
w2_abs_max = torch.amax(torch.abs(w2_f), dim=-1) # [E, H]
|
||||
w2_abs_max = torch.amax(torch.abs(w2_f), dim=-1) # [E, H]
|
||||
|
||||
w13_scale_2d = torch.clamp(w13_abs_max, min=1e-6) / qmax # [E, 2N], float32
|
||||
w2_scale_2d = torch.clamp(w2_abs_max, min=1e-6) / qmax # [E, H], float32
|
||||
w2_scale_2d = torch.clamp(w2_abs_max, min=1e-6) / qmax # [E, H], float32
|
||||
|
||||
# Quantization: broadcast 3D scale and store back to 2D scale
|
||||
# 量化:用 3D scale 广播,存回 2D scale
|
||||
w13_scale_3d = w13_scale_2d.unsqueeze(-1) # [E, 2N, 1]
|
||||
w2_scale_3d = w2_scale_2d.unsqueeze(-1) # [E, H, 1]
|
||||
w2_scale_3d = w2_scale_2d.unsqueeze(-1) # [E, H, 1]
|
||||
|
||||
w13_q = torch.round(w13_f / w13_scale_3d).clamp_(-128, 127).to(torch.int8)
|
||||
w2_q = torch.round(w2_f / w2_scale_3d).clamp_(-128, 127).to(torch.int8)
|
||||
w2_q = torch.round(w2_f / w2_scale_3d ).clamp_(-128, 127).to(torch.int8)
|
||||
|
||||
# Optional: If your fused/kernel expects scale pre-multiplied by 127 (to be consistent with some UT backends), uncomment the following two lines:
|
||||
# 可选:若你的 fused/kernel 期望 scale 预乘 127(与某些 UT 后端一致),打开下面两行:
|
||||
w13_scale_2d = w13_scale_2d * 127.0
|
||||
w2_scale_2d = w2_scale_2d * 127.0
|
||||
w2_scale_2d = w2_scale_2d * 127.0
|
||||
|
||||
# Write back parameters: weight int8; scale uses float32 + 2D
|
||||
replace_parameter(
|
||||
layer, "w13_weight", torch.nn.Parameter(w13_q, requires_grad=False)
|
||||
)
|
||||
replace_parameter(
|
||||
layer, "w2_weight", torch.nn.Parameter(w2_q, requires_grad=False)
|
||||
)
|
||||
replace_parameter(
|
||||
layer,
|
||||
"w13_weight_scale",
|
||||
torch.nn.Parameter(w13_scale_2d.contiguous(), requires_grad=False),
|
||||
)
|
||||
replace_parameter(
|
||||
layer,
|
||||
"w2_weight_scale",
|
||||
torch.nn.Parameter(w2_scale_2d.contiguous(), requires_grad=False),
|
||||
)
|
||||
|
||||
# Brief check
|
||||
print(
|
||||
f"w13: {w13_q.shape}, w13_s: {w13_scale_2d.shape}, w2: {w2_q.shape}, w2_s: {w2_scale_2d.shape}"
|
||||
)
|
||||
# 回写参数:权重 int8;scale 用 float32 + 2D
|
||||
replace_parameter(layer, 'w13_weight', torch.nn.Parameter(w13_q, requires_grad=False))
|
||||
replace_parameter(layer, 'w2_weight', torch.nn.Parameter(w2_q, requires_grad=False))
|
||||
replace_parameter(layer, 'w13_weight_scale',
|
||||
torch.nn.Parameter(w13_scale_2d.contiguous(), requires_grad=False))
|
||||
replace_parameter(layer, 'w2_weight_scale',
|
||||
torch.nn.Parameter(w2_scale_2d.contiguous(), requires_grad=False))
|
||||
|
||||
# 简要检查
|
||||
print(f"w13: {w13_q.shape}, w13_s: {w13_scale_2d.shape}, w2: {w2_q.shape}, w2_s: {w2_scale_2d.shape}")
|
||||
|
||||
def apply(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
@@ -300,11 +214,11 @@ class CompressedTensorsW8A8Int8MoEMethod(CompressedTensorsMoEMethod):
|
||||
e_score_correction_bias: Optional[torch.Tensor] = None,
|
||||
apply_router_weight_on_input: bool = False,
|
||||
activation: str = "silu",
|
||||
enable_eplb: bool = False, # Add this parameter
|
||||
expert_load_view: Optional[torch.Tensor] = None, # Add this parameter
|
||||
logical_to_physical_map: Optional[torch.Tensor] = None, # Add this parameter
|
||||
logical_replica_count: Optional[torch.Tensor] = None, # Add this parameter
|
||||
linear_weights: Optional[torch.Tensor] = None, # Add this parameter
|
||||
enable_eplb: bool = False, # 添加这个参数
|
||||
expert_load_view: Optional[torch.Tensor] = None, # 添加这个参数
|
||||
logical_to_physical_map: Optional[torch.Tensor] = None, # 添加这个参数
|
||||
logical_replica_count: Optional[torch.Tensor] = None, # 添加这个参数
|
||||
linear_weights: Optional[torch.Tensor] = None, # 添加这个参数
|
||||
) -> torch.Tensor:
|
||||
|
||||
output = torch.empty_like(x)
|
||||
@@ -326,8 +240,5 @@ class CompressedTensorsW8A8Int8MoEMethod(CompressedTensorsMoEMethod):
|
||||
)
|
||||
return output
|
||||
|
||||
|
||||
print(
|
||||
"[Monkey Patch Applied] >>> vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors_moe.CompressedTensorsMoEMethod \
|
||||
--> vllm_xpu.model_executor.layers.quantization.compressed_tensors_moe.py:CompressedTensorsMoEMethod"
|
||||
)
|
||||
print("[Monkey Patch Applied] >>> vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors_moe.CompressedTensorsMoEMethod \
|
||||
--> vllm_xpu.model_executor.layers.quantization.compressed_tensors_moe.py:CompressedTensorsMoEMethod")
|
||||
@@ -1,108 +0,0 @@
|
||||
#
|
||||
# Copyright (c) 2025 Baidu, Inc. All Rights Reserved.
|
||||
# Author: Li Wei, You Zeyu
|
||||
# Email: liwei157@baidu.com, youzeyu@baidu.com
|
||||
# This file is a part of the vllm-kunlun project.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import torch
|
||||
|
||||
from torch.nn.parameter import Parameter
|
||||
from typing import Optional
|
||||
from vllm.model_executor.layers.quantization.gptq import GPTQLinearMethod, ExllamaState
|
||||
|
||||
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||
# for torch.compile
|
||||
layer.qzeros = Parameter(
|
||||
self.repack_int4_for_kunlun(layer.qzeros.data, self.quant_config.weight_bits)
|
||||
if self.quant_config.weight_bits == 4 else layer.qzeros.data,
|
||||
requires_grad=False
|
||||
)
|
||||
layer.qweight = Parameter(layer.qweight.data, requires_grad=False)
|
||||
layer.g_idx = Parameter(layer.g_idx.data, requires_grad=False)
|
||||
layer.scales = Parameter(layer.scales.data, requires_grad=False)
|
||||
|
||||
# exllama needs to shuffle the weight after the weight is loaded
|
||||
# here we do the shuffle on first forward pass
|
||||
if layer.exllama_state == ExllamaState.UNINITIALIZED:
|
||||
if self.quant_config.desc_act:
|
||||
layer.g_idx.data = torch.argsort(layer.g_idx).to(torch.int)
|
||||
else:
|
||||
layer.g_idx.data = torch.empty((0, ),
|
||||
dtype=torch.int,
|
||||
device=layer.g_idx.device)
|
||||
layer.exllama_state = ExllamaState.READY
|
||||
|
||||
# No need shuffle on xpu
|
||||
# ops.gptq_shuffle(layer.qweight, layer.g_idx,
|
||||
# self.quant_config.weight_bits)
|
||||
|
||||
|
||||
def repack_int4_for_kunlun(self, packed: torch.Tensor, num_bits: int = 4):
|
||||
N, K = packed.shape
|
||||
assert num_bits == 4, "Only int4 supported now"
|
||||
shifts = torch.arange(0, 32, num_bits, device=packed.device, dtype=torch.int32)
|
||||
|
||||
# Unpack int32 to int4 values
|
||||
unpacked_gptq = (
|
||||
packed.view(N, K // 8, 8).unsqueeze(-1) >> shifts
|
||||
) & 0xF # [N, K//8, 8, 8]
|
||||
|
||||
# Convert to KUNLUN order
|
||||
GPTQ_TO_KUNLUN_ORDER_FAST = [
|
||||
32, 0, 33, 1, 34, 2, 35, 3,
|
||||
36, 4, 37, 5, 38, 6, 39, 7,
|
||||
40, 8, 41, 9, 42, 10, 43, 11,
|
||||
44, 12, 45, 13, 46, 14, 47, 15,
|
||||
48, 16, 49, 17, 50, 18, 51, 19,
|
||||
52, 20, 53, 21, 54, 22, 55, 23,
|
||||
56, 24, 57, 25, 58, 26, 59, 27,
|
||||
60, 28, 61, 29, 62, 30, 63, 31,
|
||||
]
|
||||
unpacked_gptq = unpacked_gptq.reshape(N, K // 8, 64)
|
||||
unpacked_kunlun = unpacked_gptq[..., GPTQ_TO_KUNLUN_ORDER_FAST] # [N, K//8, 64]
|
||||
|
||||
# Pack to int32
|
||||
unpacked_kunlun = unpacked_kunlun.reshape(N, K // 8, 8, 8)
|
||||
packed_kunlun = (
|
||||
(unpacked_kunlun << shifts).sum(dim=-1, dtype=torch.int32).reshape(N, K)
|
||||
) # [N, K]
|
||||
|
||||
return packed_kunlun
|
||||
|
||||
|
||||
def apply(
|
||||
self, layer: torch.nn.Module, x: torch.Tensor, bias: Optional[torch.Tensor] = None
|
||||
) -> torch.Tensor:
|
||||
out_shape = x.shape[:-1] + (layer.qweight.shape[-1], )
|
||||
reshaped_x = x.reshape(-1, x.shape[-1])
|
||||
|
||||
output = torch.ops.xspeedgate_ops.gptq_gemm(
|
||||
reshaped_x,
|
||||
layer.qweight,
|
||||
layer.qzeros,
|
||||
layer.scales,
|
||||
layer.g_idx,
|
||||
layer.exllama_state == ExllamaState.READY,
|
||||
self.quant_config.weight_bits,
|
||||
)
|
||||
if bias is not None:
|
||||
output.add_(bias)
|
||||
return output.reshape(out_shape)
|
||||
|
||||
|
||||
GPTQLinearMethod.repack_int4_for_kunlun = repack_int4_for_kunlun
|
||||
GPTQLinearMethod.process_weights_after_loading = process_weights_after_loading
|
||||
GPTQLinearMethod.apply = apply
|
||||
@@ -12,35 +12,33 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# This file is a part of the vllm-kunlun project.
|
||||
# This file is a part of the vllm-ascend project.
|
||||
#
|
||||
|
||||
import torch
|
||||
import xspeedgate_ops
|
||||
import os
|
||||
from vllm.model_executor.layers.rotary_embedding import (
|
||||
RotaryEmbedding,
|
||||
YaRNScalingRotaryEmbedding,
|
||||
DynamicNTKScalingRotaryEmbedding,
|
||||
MRotaryEmbedding,
|
||||
)
|
||||
RotaryEmbedding, YaRNScalingRotaryEmbedding, DynamicNTKScalingRotaryEmbedding, MRotaryEmbedding)
|
||||
from typing import Optional, Tuple
|
||||
import xtorch_ops
|
||||
|
||||
|
||||
def vllm_kunlun_compute_cos_sin_cache(self) -> torch.Tensor:
|
||||
"""Compute the cos and sin cache."""
|
||||
inv_freq = self._compute_inv_freq(self.base)
|
||||
if hasattr(self, "scaling_factor"):
|
||||
self.max_position_embeddings = int(
|
||||
self.max_position_embeddings * self.scaling_factor
|
||||
)
|
||||
if hasattr(self, 'scaling_factor'):
|
||||
self.max_position_embeddings = int(self.max_position_embeddings * self.scaling_factor)
|
||||
t = torch.arange(self.max_position_embeddings, dtype=torch.float)
|
||||
|
||||
freqs = torch.einsum("i,j -> ij", t, inv_freq)
|
||||
cos = freqs.cos()
|
||||
sin = freqs.sin()
|
||||
if os.getenv("FUSED_QK_ROPE_OP") == "1":
|
||||
#对于glm4-9b-chat,rope跑forward_native,所以需要cache保持特定的形状,这里通过环境变量控制
|
||||
#对于qwen2.5-vl,rope跑mrope,也需要cache保持特定的形状
|
||||
#也就是说跑glm4-9b-chat、qwen2.5-vl,需要设置GLM4_CHAT环境变量为1
|
||||
if os.getenv('ROPE_NATIVE_2D') == "1":
|
||||
cache = torch.cat((cos, sin), dim=-1)
|
||||
return cache
|
||||
if os.getenv('USE_ORI_ROPE') == "0":
|
||||
cache_cos = torch.cat((cos, cos), dim=-1)
|
||||
cache_sin = torch.cat((sin, sin), dim=-1)
|
||||
# [2, self.max_position_embeddings, self.rotary_dim * 2]
|
||||
@@ -51,89 +49,108 @@ def vllm_kunlun_compute_cos_sin_cache(self) -> torch.Tensor:
|
||||
|
||||
|
||||
def vllm_kunlun_forward_cuda(
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
query: torch.Tensor,
|
||||
key: Optional[torch.Tensor] = None,
|
||||
offsets: Optional[torch.Tensor] = None,
|
||||
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
|
||||
"""forward_cuda"""
|
||||
from vllm_kunlun.ops._kunlun_ops import KunlunOps as ops
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
query: torch.Tensor,
|
||||
key: Optional[torch.Tensor] = None,
|
||||
offsets: Optional[torch.Tensor] = None,
|
||||
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
|
||||
"""forward_cuda"""
|
||||
from vllm_kunlun.ops._kunlun_ops import KunlunOps as ops
|
||||
|
||||
if (
|
||||
self.cos_sin_cache.device != query.device
|
||||
or self.cos_sin_cache.dtype != query.dtype
|
||||
):
|
||||
self.cos_sin_cache = self.cos_sin_cache.to(query.device, dtype=query.dtype)
|
||||
# ops.rotary_embedding()/batched_rotary_embedding()
|
||||
# are in-place operations that update the query and key tensors.
|
||||
if offsets is not None:
|
||||
ops.batched_rotary_embedding(
|
||||
positions,
|
||||
query,
|
||||
key,
|
||||
self.head_size,
|
||||
self.cos_sin_cache,
|
||||
self.is_neox_style,
|
||||
self.rotary_dim,
|
||||
offsets,
|
||||
)
|
||||
if self.cos_sin_cache.device != query.device or \
|
||||
self.cos_sin_cache.dtype != query.dtype:
|
||||
self.cos_sin_cache = self.cos_sin_cache.to(query.device,
|
||||
dtype=query.dtype)
|
||||
# ops.rotary_embedding()/batched_rotary_embedding()
|
||||
# are in-place operations that update the query and key tensors.
|
||||
if offsets is not None:
|
||||
ops.batched_rotary_embedding(positions, query, key, self.head_size,
|
||||
self.cos_sin_cache,
|
||||
self.is_neox_style, self.rotary_dim,
|
||||
offsets)
|
||||
else:
|
||||
ops.rotary_embedding(positions, query, key, self.head_size,
|
||||
self.cos_sin_cache, self.is_neox_style)
|
||||
return query, key
|
||||
|
||||
def apply_interleaved_rope(x: torch.Tensor,
|
||||
mrope_section: list[int]) -> torch.Tensor:
|
||||
"""Apply interleaved MRoPE to 3D rotary embeddings.
|
||||
Reorganizes frequency layout from chunked [TTT...HHH...WWW] to
|
||||
interleaved [THTHWHTHW...TT], preserving frequency continuity.
|
||||
"""
|
||||
x_t = x[0].clone()
|
||||
x_t[..., 1:mrope_section[1] * 3:3] = x[1, ..., 1:mrope_section[1] * 3:3]
|
||||
x_t[..., 2:mrope_section[2] * 3:3] = x[2, ..., 2:mrope_section[2] * 3:3]
|
||||
return x_t
|
||||
|
||||
def vllm_kunlun_apply_rotary_emb(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor,
|
||||
is_neox_style: bool) -> torch.Tensor:
|
||||
"""
|
||||
Args:
|
||||
x: [num_tokens, num_heads, head_size]
|
||||
cos: [num_tokens, head_size // 2]
|
||||
sin: [num_tokens, head_size // 2]
|
||||
is_neox_style: Whether to use the Neox-style or GPT-J-style rotary
|
||||
positional embeddings.
|
||||
"""
|
||||
cos = cos.unsqueeze(-2).to(x.dtype)
|
||||
sin = sin.unsqueeze(-2).to(x.dtype)
|
||||
if is_neox_style:
|
||||
x1, x2 = torch.chunk(x, 2, dim=-1)
|
||||
else:
|
||||
query, key = ops.rotary_embedding(
|
||||
positions,
|
||||
query,
|
||||
key,
|
||||
self.head_size,
|
||||
self.cos_sin_cache,
|
||||
self.is_neox_style,
|
||||
)
|
||||
return query, key
|
||||
|
||||
x1 = x[..., ::2]
|
||||
x2 = x[..., 1::2]
|
||||
o1 = x1 * cos - x2 * sin
|
||||
o2 = x2 * cos + x1 * sin
|
||||
if is_neox_style:
|
||||
return torch.cat((o1, o2), dim=-1)
|
||||
else:
|
||||
return torch.stack((o1, o2), dim=-1).flatten(-2)
|
||||
|
||||
def vllm_kunlun_mrope_forward_cuda(
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
query: torch.Tensor,
|
||||
key: Optional[torch.Tensor] = None,
|
||||
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
|
||||
"""PyTorch-native implementation equivalent to forward().
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
query: torch.Tensor,
|
||||
key: Optional[torch.Tensor] = None,
|
||||
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
|
||||
"""PyTorch-native implementation equivalent to forward().
|
||||
|
||||
Args:
|
||||
positions:
|
||||
[num_tokens,] (text only) or
|
||||
[3, num_tokens] (T/H/W positions with multimodal inputs)
|
||||
query: [num_tokens, num_heads * head_size]
|
||||
key: [num_tokens, num_kv_heads * head_size]
|
||||
"""
|
||||
Args:
|
||||
positions:
|
||||
[num_tokens,] (text only) or
|
||||
[3, num_tokens] (T/H/W positions with multimodal inputs)
|
||||
query: [num_tokens, num_heads * head_size]
|
||||
key: [num_tokens, num_kv_heads * head_size]
|
||||
"""
|
||||
assert positions.ndim == 2
|
||||
assert key is not None
|
||||
|
||||
query, key = torch.ops.xspeedgate_ops.mrotary_embedding_fwd_v0(
|
||||
query,
|
||||
key,
|
||||
positions.to(dtype=torch.int32),
|
||||
self.cos_sin_cache,
|
||||
self.mrope_interleaved,
|
||||
self.is_neox_style,
|
||||
self.head_size,
|
||||
self.rotary_dim,
|
||||
self.mrope_section[0],
|
||||
self.mrope_section[1],
|
||||
self.mrope_section[2]
|
||||
)
|
||||
|
||||
assert positions.ndim == 2
|
||||
assert key is not None
|
||||
return query, key
|
||||
|
||||
query, key = torch.ops.xspeedgate_ops.mrotary_embedding_fwd_v0(
|
||||
query,
|
||||
key,
|
||||
positions.to(dtype=torch.int32),
|
||||
self.cos_sin_cache,
|
||||
False, # self.mrope_interleaved,
|
||||
self.head_size,
|
||||
self.rotary_dim,
|
||||
self.mrope_section[0],
|
||||
self.mrope_section[1],
|
||||
self.mrope_section[2],
|
||||
)
|
||||
|
||||
return query, key
|
||||
|
||||
|
||||
RotaryEmbedding.forward_cuda = vllm_kunlun_forward_cuda
|
||||
RotaryEmbedding.forward = vllm_kunlun_forward_cuda
|
||||
if os.getenv("KUNLUN_ENABLE_MULTI_LORA") == "1" or os.getenv("FUSED_QK_ROPE_OP") == "1":
|
||||
RotaryEmbedding._compute_cos_sin_cache = vllm_kunlun_compute_cos_sin_cache
|
||||
else:
|
||||
pass
|
||||
# RotaryEmbedding.forward_cuda = vllm_kunlun_forward_cuda
|
||||
# RotaryEmbedding.forward = vllm_kunlun_forward_cuda
|
||||
# RotaryEmbedding._compute_cos_sin_cache = vllm_kunlun_compute_cos_sin_cache
|
||||
MRotaryEmbedding.forward_cuda = vllm_kunlun_mrope_forward_cuda
|
||||
MRotaryEmbedding.forward = vllm_kunlun_mrope_forward_cuda
|
||||
# MRotaryEmbedding._compute_cos_sin_cache = vllm_kunlun_compute_cos_sin_cache
|
||||
YaRNScalingRotaryEmbedding._compute_inv_freq = RotaryEmbedding._compute_inv_freq
|
||||
# YaRNScalingRotaryEmbedding._compute_cos_sin_cache = vllm_kunlun_compute_cos_sin_cache
|
||||
|
||||
|
||||
def Split_Norm_Rope(
|
||||
@@ -145,36 +162,27 @@ def Split_Norm_Rope(
|
||||
max_position_embeddings: int,
|
||||
q_head_num: int,
|
||||
kv_head_num: int,
|
||||
head_dim: int,
|
||||
partial_rotary_factor: float = 1.0,
|
||||
head_dim:int
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
num_tokens = qkv.shape[0]
|
||||
rotary_dim = head_dim
|
||||
if partial_rotary_factor < 1.0:
|
||||
rotary_dim = int(rotary_dim * partial_rotary_factor)
|
||||
q_emb_out = torch.empty(
|
||||
(num_tokens, q_head_num * head_dim), dtype=qkv.dtype, device=qkv.device
|
||||
)
|
||||
k_emb_out = torch.empty(
|
||||
(num_tokens, kv_head_num * head_dim), dtype=qkv.dtype, device=qkv.device
|
||||
)
|
||||
v_out = torch.empty(
|
||||
(num_tokens, kv_head_num * head_dim), dtype=qkv.dtype, device=qkv.device
|
||||
)
|
||||
rotary_dim=head_dim
|
||||
q_emb_out = torch.empty((num_tokens, q_head_num * head_dim), dtype=qkv.dtype, device=qkv.device)
|
||||
k_emb_out = torch.empty((num_tokens, kv_head_num * head_dim), dtype=qkv.dtype, device=qkv.device)
|
||||
v_out = torch.empty((num_tokens, kv_head_num * head_dim), dtype=qkv.dtype, device=qkv.device)
|
||||
torch.ops._C.split_norm_rope_neox(
|
||||
q_emb_out,
|
||||
k_emb_out,
|
||||
v_out,
|
||||
qkv,
|
||||
cos_sin_cache,
|
||||
q_norm_weight,
|
||||
k_norm_weight,
|
||||
positions,
|
||||
num_tokens,
|
||||
max_position_embeddings,
|
||||
q_head_num,
|
||||
kv_head_num,
|
||||
head_dim,
|
||||
rotary_dim,
|
||||
)
|
||||
return q_emb_out, k_emb_out, v_out
|
||||
q_emb_out,
|
||||
k_emb_out,
|
||||
v_out,
|
||||
qkv,
|
||||
cos_sin_cache,
|
||||
q_norm_weight,
|
||||
k_norm_weight,
|
||||
positions,
|
||||
num_tokens,
|
||||
max_position_embeddings,
|
||||
q_head_num,
|
||||
kv_head_num,
|
||||
head_dim,
|
||||
rotary_dim,
|
||||
)
|
||||
return q_emb_out, k_emb_out, v_out
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -98,12 +98,10 @@ cached_backends: Dict[int, CompilerFn] = {}
|
||||
unset = Unset.token
|
||||
|
||||
from torch._C._dynamo.eval_frame import set_eval_frame
|
||||
|
||||
|
||||
def _maybe_set_eval_frame(callback: DynamoCallback):
|
||||
# A wrapper on set_eval_frame that is guarded by a Justknob.
|
||||
# Users can disable torchDynamo by setting the JK to False.
|
||||
# from torch._C._dynamo.eval_frame import set_eval_frame
|
||||
#from torch._C._dynamo.eval_frame import set_eval_frame
|
||||
|
||||
if not justknobs_check("pytorch/compiler:enable_compiler_set_eval_frame"):
|
||||
torch._dynamo.utils.warn_once(
|
||||
@@ -130,7 +128,7 @@ DONT_WRAP_FILES = {
|
||||
|
||||
|
||||
def _debug_get_cache_entry_list(
|
||||
code: Union[types.CodeType, Callable[..., Any]],
|
||||
code: Union[types.CodeType, Callable[..., Any]]
|
||||
) -> List[CacheEntry]:
|
||||
"""
|
||||
Given a code object or a callable object, retrieve the cache entries
|
||||
@@ -373,9 +371,9 @@ class _TorchDynamoContext:
|
||||
# add context containing GraphModule to any GraphModule forward functions
|
||||
if isinstance(fn, GraphModule):
|
||||
# add context containing GraphModule to any GraphModule forward functions
|
||||
code_context.get_context(fn.forward.__code__)["orig_graphmodule"] = (
|
||||
weakref.ref(fn)
|
||||
)
|
||||
code_context.get_context(fn.forward.__code__)[
|
||||
"orig_graphmodule"
|
||||
] = weakref.ref(fn)
|
||||
|
||||
# Optimize the forward method of torch.nn.Module object
|
||||
if isinstance(fn, torch.nn.Module):
|
||||
@@ -789,11 +787,9 @@ def _optimize(
|
||||
hooks,
|
||||
backend_ctx_ctor,
|
||||
dynamic=dynamic,
|
||||
compiler_config=(
|
||||
backend.get_compiler_config()
|
||||
if hasattr(backend, "get_compiler_config")
|
||||
else None
|
||||
),
|
||||
compiler_config=backend.get_compiler_config()
|
||||
if hasattr(backend, "get_compiler_config")
|
||||
else None,
|
||||
rebuild_ctx=rebuild_ctx,
|
||||
)
|
||||
|
||||
@@ -907,11 +903,9 @@ class FlattenInputOutputSignature(torch.fx.interpreter.Transformer):
|
||||
flat_args[i],
|
||||
symbolic_context=StatelessSymbolicContext(
|
||||
dynamic_sizes=[
|
||||
(
|
||||
DimDynamic.DYNAMIC
|
||||
if d in flat_args_dynamic_dims[i]
|
||||
else DimDynamic.STATIC
|
||||
)
|
||||
DimDynamic.DYNAMIC
|
||||
if d in flat_args_dynamic_dims[i]
|
||||
else DimDynamic.STATIC
|
||||
for d in range(len(flat_args[i].shape))
|
||||
],
|
||||
constraint_sizes=[None] * len(flat_args[i].shape),
|
||||
|
||||
@@ -4,28 +4,26 @@ import os
|
||||
from typing import TYPE_CHECKING, Any, Callable, Optional
|
||||
|
||||
if TYPE_CHECKING:
|
||||
VLLM_MULTI_LOGPATH: str = ("./log",)
|
||||
ENABLE_VLLM_MULTI_LOG: bool = (False,)
|
||||
ENABLE_VLLM_INFER_HOOK: bool = (False,)
|
||||
ENABLE_VLLM_OPS_HOOK: bool = (False,)
|
||||
ENABLE_VLLM_MODULE_HOOK: bool = False
|
||||
|
||||
VLLM_MULTI_LOGPATH : str = "./log",
|
||||
ENABLE_VLLM_MULTI_LOG : bool = False,
|
||||
ENABLE_VLLM_INFER_HOOK : bool = False,
|
||||
ENABLE_VLLM_OPS_HOOK : bool = False,
|
||||
ENABLE_VLLM_MODULE_HOOK : bool = False
|
||||
|
||||
def maybe_convert_int(value: Optional[str]) -> Optional[int]:
|
||||
"""
|
||||
If the value is None, return None; otherwise, convert the string to an integer and return it.
|
||||
|
||||
如果值是None,则返回None;否则将字符串转换为整数并返回。
|
||||
|
||||
Args:
|
||||
value (Optional[str], optional): The optional string to convert. Defaults to None.
|
||||
|
||||
value (Optional[str], optional): 要转换的可选字符串. Defaults to None.
|
||||
|
||||
Returns:
|
||||
Optional[int]: If the value is None, return None; otherwise, convert the string to an integer and return it.
|
||||
Optional[int]: 如果值是None,则返回None;否则将字符串转换为整数并返回.
|
||||
"""
|
||||
if value is None:
|
||||
return None
|
||||
return int(value)
|
||||
|
||||
|
||||
# The begin-* and end* here are used by the documentation generator
|
||||
# to extract the used env vars.
|
||||
|
||||
@@ -33,56 +31,59 @@ def maybe_convert_int(value: Optional[str]) -> Optional[int]:
|
||||
|
||||
xvllm_environment_variables: dict[str, Callable[[], Any]] = {
|
||||
# path to the logs of redirect-output, abstrac of related are ok
|
||||
"VLLM_MULTI_LOGPATH": lambda: os.environ.get("VLLM_MULTI_LOGPATH", "./logs"),
|
||||
# turn on / off multi-log of multi nodes & multi cards
|
||||
"ENABLE_VLLM_MULTI_LOG": lambda: (
|
||||
os.environ.get("ENABLE_VLLM_MULTI_LOG", "False").lower() in ("true", "1")
|
||||
),
|
||||
# turn on / off XVLLM infer stage log ability
|
||||
"ENABLE_VLLM_INFER_HOOK": lambda: (
|
||||
os.environ.get("ENABLE_VLLM_INFER_HOOK", "False").lower() in ("true", "1")
|
||||
),
|
||||
# turn on / off XVLLM infer_ops log ability
|
||||
"ENABLE_VLLM_OPS_HOOK": lambda: (
|
||||
os.environ.get("ENABLE_VLLM_OPS_HOOK", "False").lower() in ("true", "1")
|
||||
),
|
||||
"ENABLE_VLLM_MODULE_HOOK": lambda: (
|
||||
os.environ.get("ENABLE_VLLM_MODULE_HOOK", "False").lower() in ("true", "1")
|
||||
),
|
||||
"VLLM_MULTI_LOGPATH":
|
||||
lambda: os.environ.get("VLLM_MULTI_LOGPATH", "./logs"),
|
||||
|
||||
# turn on / off multi-log of multi nodes & multi cards
|
||||
"ENABLE_VLLM_MULTI_LOG":
|
||||
lambda: (os.environ.get("ENABLE_VLLM_MULTI_LOG", "False").lower() in
|
||||
("true", "1")),
|
||||
|
||||
# turn on / off XVLLM infer stage log ability
|
||||
"ENABLE_VLLM_INFER_HOOK":
|
||||
lambda: (os.environ.get("ENABLE_VLLM_INFER_HOOK", "False").lower() in
|
||||
("true", "1")),
|
||||
|
||||
# turn on / off XVLLM infer_ops log ability
|
||||
"ENABLE_VLLM_OPS_HOOK":
|
||||
lambda: (os.environ.get("ENABLE_VLLM_OPS_HOOK", "False").lower() in
|
||||
("true", "1")),
|
||||
|
||||
"ENABLE_VLLM_MODULE_HOOK":
|
||||
lambda: (os.environ.get("ENABLE_VLLM_MODULE_HOOK", "False").lower() in
|
||||
("true", "1")),
|
||||
|
||||
# fuse sorted op with fused_moe kernel
|
||||
"ENABLE_VLLM_MOE_FC_SORTED": lambda: (
|
||||
os.environ.get("ENABLE_VLLM_MOE_FC_SORTED", "False").lower() in ("true", "1")
|
||||
),
|
||||
"ENABLE_VLLM_MOE_FC_SORTED":
|
||||
lambda: (os.environ.get("ENABLE_VLLM_MOE_FC_SORTED", "False").lower() in
|
||||
("true", "1")),
|
||||
|
||||
# enable custom dpsk scaling rope
|
||||
"ENABLE_CUSTOM_DPSK_SCALING_ROPE": lambda: (
|
||||
os.environ.get("ENABLE_CUSTOM_DPSK_SCALING_ROPE", "False").lower()
|
||||
in ("true", "1")
|
||||
),
|
||||
"ENABLE_CUSTOM_DPSK_SCALING_ROPE":
|
||||
lambda: (os.environ.get("ENABLE_CUSTOM_DPSK_SCALING_ROPE", "False").lower() in
|
||||
("true", "1")),
|
||||
|
||||
# fuse qkv split & qk norm & qk rope
|
||||
# only works for qwen3 dense and qwen3 moe models
|
||||
"ENABLE_VLLM_FUSED_QKV_SPLIT_NORM_ROPE": lambda: (
|
||||
os.environ.get("ENABLE_VLLM_FUSED_QKV_SPLIT_NORM_ROPE", "False").lower()
|
||||
in ("true", "1")
|
||||
),
|
||||
"ENABLE_VLLM_FUSED_QKV_SPLIT_NORM_ROPE":
|
||||
lambda: (os.environ.get("ENABLE_VLLM_FUSED_QKV_SPLIT_NORM_ROPE", "False").lower() in
|
||||
("true", "1")),
|
||||
}
|
||||
|
||||
# end-env-vars-definition
|
||||
|
||||
|
||||
def __getattr__(name: str):
|
||||
"""
|
||||
This function is called when an attribute that doesn't exist is accessed.
|
||||
If the attribute is one of the xvllm_environment_variables, return the corresponding value.
|
||||
Otherwise, raise an AttributeError.
|
||||
|
||||
当调用不存在的属性时,该函数被调用。如果属性是xvllm_environment_variables中的一个,则返回相应的值。否则引发AttributeError异常。
|
||||
|
||||
Args:
|
||||
name (str): The name of the attribute to retrieve.
|
||||
|
||||
name (str): 要获取的属性名称。
|
||||
|
||||
Raises:
|
||||
AttributeError (Exception): If the attribute is not one of xvllm_environment_variables, this exception is raised.
|
||||
|
||||
AttributeError (Exception): 如果属性不是xvllm_environment_variables中的一个,则会引发此异常。
|
||||
|
||||
Returns:
|
||||
Any, optional: If the attribute is one of xvllm_environment_variables, the corresponding value is returned; otherwise, None is returned.
|
||||
Any, optional: 如果属性是xvllm_environment_variables中的一个,则返回相应的值;否则返回None。
|
||||
"""
|
||||
# lazy evaluation of environment variables
|
||||
if name in xvllm_environment_variables:
|
||||
@@ -92,14 +93,13 @@ def __getattr__(name: str):
|
||||
|
||||
def __dir__():
|
||||
"""
|
||||
Returns a list of all visible variable names.
|
||||
|
||||
返回一个包含所有可见的变量名称的列表。
|
||||
|
||||
返回值(list):一个包含所有可见的变量名称的列表,这些变量是通过`xvllm_environment_variables`字典定义的。
|
||||
|
||||
Returns:
|
||||
list: A list of all visible variable names, which are defined through the `xvllm_environment_variables` dictionary.
|
||||
|
||||
Returns:
|
||||
List[str]: A list of all visible variable names.
|
||||
These variables are defined through the `xvllm_environment_variables` dictionary.
|
||||
List[str]: 一个包含所有可见的变量名称的列表。
|
||||
这些变量是通过`xvllm_environment_variables`字典定义的。
|
||||
"""
|
||||
return list(xvllm_environment_variables.keys())
|
||||
|
||||
|
||||
@@ -19,10 +19,11 @@ class KunlunPlatform(Platform):
|
||||
|
||||
@property
|
||||
def device_type(self):
|
||||
"""Returns the device type, which is fixed as 'cuda'.
|
||||
"""
|
||||
返回设备类型,固定为'cuda'。
|
||||
"""
|
||||
return "cuda"
|
||||
|
||||
|
||||
def is_kunlun(self) -> bool:
|
||||
"""is_kunlun"""
|
||||
return self._enum == PlatformEnum.CUDA
|
||||
@@ -69,13 +70,14 @@ class KunlunPlatform(Platform):
|
||||
|
||||
@classmethod
|
||||
def get_device_name(cls, device_id: int = 0) -> str:
|
||||
"""Returns the device name, which defaults to "kunlun".
|
||||
|
||||
"""
|
||||
获取设备名称,默认返回 "kunlun"。
|
||||
|
||||
Args:
|
||||
device_id (int, optional): The device ID, default is 0. Ignored in this method. Defaults to 0.
|
||||
|
||||
device_id (int, optional): 设备ID,默认为0. Ignored in this method. Defaults to 0.
|
||||
|
||||
Returns:
|
||||
str: The device name, which is fixed as "kunlun".
|
||||
str: 设备名称,固定返回 "kunlun".
|
||||
"""
|
||||
return "kunlun"
|
||||
|
||||
@@ -89,23 +91,26 @@ class KunlunPlatform(Platform):
|
||||
|
||||
@classmethod
|
||||
def get_device_total_memory(cls, device_id: int = 0) -> int:
|
||||
"""Returns the total memory size of the device in bytes (B). Defaults to the total memory size of the first device.
|
||||
If the `device_id` parameter is not an integer or exceeds the available device range, a ValueError will be raised.
|
||||
|
||||
"""
|
||||
获取设备总内存大小,单位为字节(B)。默认返回第一个设备的总内存大小。
|
||||
如果传入参数`device_id`不是整数或者超出了可用设备范围,将会引发ValueError异常。
|
||||
|
||||
Args:
|
||||
device_id (int, optional): The device ID, default is 0. Defaults to 0.
|
||||
|
||||
device_id (int, optional): 设备ID,默认为0. Defaults to 0.
|
||||
|
||||
Raises:
|
||||
ValueError: If the `device_id` parameter is not an integer or exceeds the available device range, this exception is raised.
|
||||
|
||||
ValueError: 当传入的`device_id`不是整数或者超出了可用设备范围时引发此异常。
|
||||
|
||||
Returns:
|
||||
int: The total memory size of the device in bytes (B).
|
||||
int: 设备总内存大小,单位为字节(B)。
|
||||
"""
|
||||
return psutil.virtual_memory().total
|
||||
|
||||
@classmethod
|
||||
def inference_mode(cls):
|
||||
"""Returns a context manager that disables gradient computation.
|
||||
"""
|
||||
进入推理模式,禁止计算梯度。
|
||||
返回:torch.no_grad(),一个上下文管理器,用于禁止计算梯度。
|
||||
"""
|
||||
return torch.no_grad()
|
||||
|
||||
@@ -114,29 +119,31 @@ class KunlunPlatform(Platform):
|
||||
"""get_device_capability"""
|
||||
major, minor = torch.cuda.get_device_capability()
|
||||
return DeviceCapability(major=major, minor=minor)
|
||||
|
||||
|
||||
@classmethod
|
||||
def check_and_update_config(cls, vllm_config: "VllmConfig") -> None:
|
||||
"""Updates the default values of various components based on the configuration.
|
||||
If not specified, automatically selects the worker class based on certain conditions.
|
||||
If the block size is not set in the cache configuration, it is set to 16.
|
||||
If using MLA and `VLLM_ATTENTION_BACKEND` is not set or is set to "FLASHMLA",
|
||||
the cache block size is set to 64.
|
||||
If running in DeepEP high throughput backend, data parallelism greater than 1, and CUDA graph mode,
|
||||
it forces the use of eager mode, as DP + DeepEP high throughput kernels are not CUDA graph compatible,
|
||||
and using DeepEP low latency kernels can resolve this issue.
|
||||
|
||||
"""
|
||||
根据配置更新各个部分的默认值。
|
||||
如果未指定,则根据某些条件自动选择worker类。
|
||||
如果缓存配置中没有设置块大小,则将其设置为16。
|
||||
如果使用MLA,并且`VLLM_ATTENTION_BACKEND`未设置或设置为"FLASHMLA",
|
||||
则将缓存块大小设置为64。
|
||||
如果在DeepEP高吞吐量后端、数据并行大于1和CUDA图形模式下运行,则强制
|
||||
强制执行即时模式,因为DP + DeepEP高吞吐量内核不是CUDA图形兼容的,而且
|
||||
使用DeepEP低延迟内核可以解决这个问题。
|
||||
|
||||
Args:
|
||||
vllm_config (VllmConfig): VLLM configuration object.
|
||||
|
||||
vllm_config (VllmConfig): VLLM配置对象。
|
||||
|
||||
Raises:
|
||||
NotImplementedError: If multi-step scheduling is used on vLLM V1, this exception is raised.
|
||||
Please remove the --num-scheduler-steps argument from the command line.
|
||||
NotImplementedError: If MLA is used on vLLM V1, this exception is raised.
|
||||
Please ensure that the `VLLM_ATTENTION_BACKEND` environment variable is set before using MLA.
|
||||
|
||||
NotImplementedError: 如果在vLLM V1上使用多步调度,则会引发NotImplementedError。
|
||||
请从命令行中删除--num-scheduler-steps参数。
|
||||
NotImplementedError: 如果在vLLM V1上使用MLA,则会引发NotImplementedError。
|
||||
请确保在使用MLA之前设置了`VLLM_ATTENTION_BACKEND`环境变量。
|
||||
|
||||
Returns:
|
||||
None: No return value.
|
||||
None: 无返回值。
|
||||
"""
|
||||
parallel_config = vllm_config.parallel_config
|
||||
scheduler_config = vllm_config.scheduler_config
|
||||
@@ -159,7 +166,7 @@ class KunlunPlatform(Platform):
|
||||
"vllm.v1.worker.gpu_worker.Worker"
|
||||
else:
|
||||
parallel_config.worker_cls = "vllm.worker.worker.Worker"
|
||||
|
||||
|
||||
cache_config = vllm_config.cache_config
|
||||
if cache_config and cache_config.block_size is None:
|
||||
cache_config.block_size = 16
|
||||
@@ -198,9 +205,10 @@ class KunlunPlatform(Platform):
|
||||
vllm_config.compilation_config.pass_config.enable_fusion = False
|
||||
vllm_config.compilation_config.use_inductor = False
|
||||
|
||||
|
||||
@classmethod
|
||||
def get_attn_backend_cls(cls, selected_backend, head_size, dtype,
|
||||
kv_cache_dtype, block_size, use_v1, use_mla,use_sink):
|
||||
kv_cache_dtype, block_size, use_v1, use_mla,use_sink, use_sparse=False):
|
||||
"""
|
||||
Returns the class of attention backend based on the selected backend and other parameters.
|
||||
|
||||
@@ -227,15 +235,16 @@ class KunlunPlatform(Platform):
|
||||
def get_current_memory_usage(cls,
|
||||
device: Optional[torch.types.Device] = None
|
||||
) -> float:
|
||||
"""Gets the current memory usage of the device, including allocated and max allocated.
|
||||
If no device is specified, defaults to the current context's device.
|
||||
|
||||
Args:
|
||||
device (Optional[torch.types.Device], optional): Optional device object, defaults to None. Defaults to the current context's device.
|
||||
|
||||
Returns:
|
||||
float: Returns a float representing the current memory usage of the device, in bytes.
|
||||
|
||||
"""
|
||||
获取当前设备的内存使用情况,包括已分配和最大分配。
|
||||
如果未指定设备,则默认为当前上下文中的设备。
|
||||
|
||||
Args:
|
||||
device (Optional[torch.types.Device], optional): 可选的设备对象,默认为None。默认为当前上下文中的设备。
|
||||
|
||||
Returns:
|
||||
float: 返回一个浮点数,表示当前设备的内存使用情况,单位是字节(bytes)。
|
||||
|
||||
Raises:
|
||||
None.
|
||||
"""
|
||||
@@ -244,17 +253,18 @@ class KunlunPlatform(Platform):
|
||||
|
||||
@classmethod
|
||||
def is_async_output_supported(cls, enforce_eager: Optional[bool]) -> bool:
|
||||
"""Checks if asynchronous output is supported.
|
||||
By default, Kunlun does not support asynchronous output.
|
||||
|
||||
Args:
|
||||
enforce_eager (Optional[bool], optional): Whether to enforce eager execution. Defaults to None.
|
||||
None means not to force eager execution, but to automatically select based on the current environment.
|
||||
|
||||
Returns:
|
||||
bool: True means asynchronous output is supported, False means asynchronous output is not supported.
|
||||
"""
|
||||
# Assume Kunlun does not support asynchronous output
|
||||
判断是否支持异步输出。
|
||||
默认情况下,Kunlun 不支持异步输出。
|
||||
|
||||
Args:
|
||||
enforce_eager (Optional[bool], optional): 是否强制使用 eager execution. Defaults to None.
|
||||
None 表示不强制使用 eager execution,而是根据当前环境自动选择。
|
||||
|
||||
Returns:
|
||||
bool: True 表示支持异步输出,False 表示不支持异步输出。
|
||||
"""
|
||||
# 假设 Kunlun 不支持异步输出
|
||||
return False
|
||||
|
||||
@classmethod
|
||||
@@ -279,11 +289,42 @@ class KunlunPlatform(Platform):
|
||||
|
||||
@classmethod
|
||||
def get_device_communicator_cls(cls) -> str:
|
||||
'''
|
||||
'''
|
||||
communicator
|
||||
'''
|
||||
return "vllm_kunlun.distributed.kunlun_communicator.KunlunCommunicator"
|
||||
return "vllm_kunlun.distributed.kunlun_communicator.KunlunCommunicator"
|
||||
|
||||
@classmethod
|
||||
def get_punica_wrapper(cls):
|
||||
return "vllm_kunlun.lora.punica_wrapper.punica_kunlun.PunicaWrapperKunlun"
|
||||
return "vllm.lora.punica_wrapper.punica_cpu.PunicaWrapperCPU"
|
||||
|
||||
@classmethod
|
||||
def check_if_supports_dtype(cls, torch_dtype: torch.dtype):
|
||||
'''
|
||||
Kunlun3平台支持的数据类型
|
||||
'''
|
||||
supported_dtypes = {
|
||||
torch.float32,
|
||||
torch.float16,
|
||||
torch.bfloat16,
|
||||
torch.int8,
|
||||
}
|
||||
if torch_dtype not in supported_dtypes:
|
||||
raise ValueError(
|
||||
f"Kunlun platform does not support dtype {torch_dtype}. "
|
||||
"Supported dtypes are: fp32, fp16, bf16, int8."
|
||||
)
|
||||
|
||||
def opaque_attention_op(cls) -> bool:
|
||||
'''
|
||||
确保V1 Graph在Kunlun3平台使用vllm.unified_attention_with_output_kunlun作为split ops
|
||||
'''
|
||||
return True
|
||||
|
||||
@classmethod
|
||||
def support_hybrid_kv_cache(cls) -> bool:
|
||||
return True
|
||||
|
||||
@classmethod
|
||||
def support_static_graph_mode(cls) -> bool:
|
||||
return True
|
||||
@@ -1,61 +1,28 @@
|
||||
#
|
||||
# Copyright (c) 2025 Baidu, Inc. All Rights Reserved.
|
||||
# This file is a part of the vllm-kunlun project.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import os, sys
|
||||
import vllm
|
||||
|
||||
from torch.utils._python_dispatch import TorchDispatchMode
|
||||
import vllm_kunlun.platforms.envs as xenvs
|
||||
import vllm_kunlun.platforms.envs as xenvs
|
||||
from vllm.utils import weak_ref_tensor
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
Callable,
|
||||
Generic,
|
||||
Literal,
|
||||
NamedTuple,
|
||||
Optional,
|
||||
Tuple,
|
||||
TypeVar,
|
||||
Union,
|
||||
cast,
|
||||
overload,
|
||||
get_origin,
|
||||
get_args,
|
||||
List,
|
||||
)
|
||||
from typing import (TYPE_CHECKING, Any, Callable, Generic, Literal, NamedTuple,
|
||||
Optional, Tuple, TypeVar, Union, cast, overload,
|
||||
get_origin, get_args, List)
|
||||
import torch
|
||||
from torch.library import Library
|
||||
import inspect
|
||||
import typing
|
||||
|
||||
|
||||
def redirect_output():
|
||||
"""
|
||||
Redirect output to a specified directory and name the log files as pp=0_rank=X or pp=1_rank=X.
|
||||
If it is the first process of the first process group, use pp=0; otherwise, use pp=1.
|
||||
|
||||
重定向输出到指定目录,并将日志文件命名为pp=0_rank=X或pp=1_rank=X。
|
||||
如果是第一个进程组的第一个进程,则使用pp=0;否则使用pp=1。
|
||||
|
||||
Args:
|
||||
No parameters.
|
||||
|
||||
无参数。
|
||||
|
||||
Returns:
|
||||
No return value, directly modify the file descriptors of sys.stdout and sys.stderr.
|
||||
无返回值,直接修改sys.stdout和sys.stderr的文件描述符。
|
||||
"""
|
||||
from vllm.distributed import get_tensor_model_parallel_rank, get_pp_group
|
||||
|
||||
rank = get_tensor_model_parallel_rank()
|
||||
dir_path = xenvs.VLLM_MULTI_LOGPATH
|
||||
os.makedirs(dir_path, exist_ok=True)
|
||||
@@ -63,54 +30,48 @@ def redirect_output():
|
||||
log_file = os.path.join(dir_path, f"pp=0_rank={rank}.log")
|
||||
else:
|
||||
log_file = os.path.join(dir_path, f"pp=1_rank={rank}.log")
|
||||
fd = os.open(log_file, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o644)
|
||||
fd = os.open(log_file, os.O_WRONLY | os.O_CREAT| os.O_TRUNC, 0o644)
|
||||
os.dup2(fd, sys.stdout.fileno())
|
||||
os.dup2(fd, sys.stderr.fileno())
|
||||
os.close(fd)
|
||||
|
||||
|
||||
def multi_log_monkey_patch(func):
|
||||
"""
|
||||
Monkey patch function for logging multiple times, used to test log redirection functionality.
|
||||
This function will print a log message each time the patched function is called.
|
||||
|
||||
多次打印日志的猴子补丁函数,用于测试日志重定向功能。
|
||||
该函数会在每次调用被补丁的函数时打印一条日志信息。
|
||||
|
||||
Args:
|
||||
func (function): The original function to be patched.
|
||||
|
||||
func (function): 需要被补丁的原始函数。
|
||||
|
||||
Returns:
|
||||
function: A wrapped new function that prints a log message each time it is called.
|
||||
function: 返回一个包装后的新函数,每次调用都会打印一条日志信息。
|
||||
"""
|
||||
|
||||
def wrapper(*args, **kwargs):
|
||||
print("[monkey patch] ensure_model_parallel_initialized")
|
||||
func(*args, **kwargs)
|
||||
redirect_output()
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
# if os.environ.get("VLLM_MULTI_LOG", "0") == "1":
|
||||
if xenvs.ENABLE_VLLM_MULTI_LOG:
|
||||
print("ENABLE_VLLM_MULTI_LOG monkey--------")
|
||||
vllm.distributed.ensure_model_parallel_initialized = multi_log_monkey_patch(
|
||||
vllm.distributed.ensure_model_parallel_initialized
|
||||
)
|
||||
|
||||
vllm.distributed.ensure_model_parallel_initialized)
|
||||
|
||||
class StageHookPre(object):
|
||||
def __call__(self, *args, **kwargs):
|
||||
"""
|
||||
This method will be automatically executed when the object is called.
|
||||
If the current attention metadata is not None and a token has been processed, print "Per Token Start"; otherwise, print "First Token Start".
|
||||
|
||||
在调用对象时,会自动执行此方法。
|
||||
如果当前的attention metadata不为None,并且已经处理了一个token,则打印"Per Token Start";否则打印"First Token Start"。
|
||||
|
||||
Args:
|
||||
args (tuple, optional): Variable length argument list, default is an empty tuple.
|
||||
kwargs (dict, optional): Keyword arguments, default is an empty dictionary.
|
||||
|
||||
args (tuple, optional): 可变参数,默认为空元组。
|
||||
kwargs (dict, optional): 关键字参数,默认为空字典。
|
||||
|
||||
Returns:
|
||||
None: No return value.
|
||||
None: 无返回值。
|
||||
"""
|
||||
from vllm.forward_context import get_forward_context
|
||||
|
||||
attn_metadata = get_forward_context().attn_metadata
|
||||
if attn_metadata is not None:
|
||||
if attn_metadata.num_decode_tokens == 0:
|
||||
@@ -118,22 +79,20 @@ class StageHookPre(object):
|
||||
else:
|
||||
print("Per Token Start", flush=True)
|
||||
|
||||
|
||||
class StageHookPost(object):
|
||||
def __call__(self, *args, **kwargs):
|
||||
"""
|
||||
If the current context's attention metadata is not None and num_decode_tokens equals 0, print "First Token End".
|
||||
Otherwise, print "Per Token End".
|
||||
|
||||
如果当前上下文中的attention metadata不为None,并且num_decode_tokens等于0,则打印"First Token End"。
|
||||
否则,打印"Per Token End"。
|
||||
|
||||
Args:
|
||||
args (Tuple[Any]): Variable length argument list, unused parameters are passed in.
|
||||
kwargs (Dict[str, Any]): Keyword arguments, unused parameters are passed in.
|
||||
|
||||
args (Tuple[Any]): 可变长度参数列表,无用参数传入。
|
||||
kwargs (Dict[str, Any]): 字典类型的关键字参数,无用参数传入。
|
||||
|
||||
Returns:
|
||||
None: No return value.
|
||||
None: 该函数没有返回值。
|
||||
"""
|
||||
from vllm.forward_context import get_forward_context
|
||||
|
||||
attn_metadata = get_forward_context().attn_metadata
|
||||
if attn_metadata is not None:
|
||||
if attn_metadata.num_decode_tokens == 0:
|
||||
@@ -145,24 +104,22 @@ class StageHookPost(object):
|
||||
class ModuleLoggingHookPre(object):
|
||||
def __init__(self):
|
||||
"""
|
||||
Initialization function to initialize the indentation list and name list.
|
||||
The indentation list is used to store the indentation information of each line,
|
||||
and the name list is used to store the name of each variable or function.
|
||||
初始化函数,用于初始化缩进列表和名称列表。
|
||||
缩进列表用于存储每一行的缩进信息,名称列表用于存储每一个变量或函数的名称。
|
||||
"""
|
||||
self.indent_list = list()
|
||||
self.indent_list.append("")
|
||||
self.name_list = list()
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
"""
|
||||
This method overrides the __call__ method and is used when the class is instantiated.
|
||||
It increases the current indentation by one Tab and records the current class name.
|
||||
It prints the start information, flush=True means it will be output to the console immediately.
|
||||
|
||||
重写了 __call__ 方法,用于在类实例化时调用。
|
||||
将当前缩进增加一个 Tab,并记录当前类名称。
|
||||
打印开始信息,flush=True 表示立即输出到控制台。
|
||||
|
||||
Args:
|
||||
args (tuple): Variable length argument list, default is an empty tuple.
|
||||
kwargs (dict): Keyword arguments, default is an empty dictionary.
|
||||
|
||||
args (tuple): 传入的参数列表,第一个元素是类实例。
|
||||
kwargs (dict): 传入的关键字参数列表,不使用。
|
||||
|
||||
Returns:
|
||||
None.
|
||||
"""
|
||||
@@ -174,70 +131,62 @@ class ModuleLoggingHookPre(object):
|
||||
class ModuleLoggingHookPost(object):
|
||||
def __init__(self, indent_list, name_list):
|
||||
"""
|
||||
Initialization function to set the indentation list and name list.
|
||||
|
||||
初始化函数,设置缩进列表和名称列表。
|
||||
|
||||
Args:
|
||||
indent_list (List[str]): A list of indentation strings for each node, indexed from 0.
|
||||
name_list (List[str]): A list of name strings for each node, indexed from 0.
|
||||
Note: The indentation list and name list should have the same length, otherwise it will cause an error.
|
||||
|
||||
indent_list (List[str]): 包含每个节点的缩进字符串的列表,索引从0开始。
|
||||
name_list (List[str]): 包含每个节点的名称字符串的列表,索引从0开始。
|
||||
注意:缩进列表和名称列表应该有相同长度,否则会导致错误。
|
||||
|
||||
Returns:
|
||||
None: No return value, directly modifies the instance's attributes.
|
||||
None. 无返回值,直接修改了类实例的属性。
|
||||
"""
|
||||
self.indent_list = indent_list
|
||||
self.name_list = name_list
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
"""
|
||||
This method is called when the object is invoked.
|
||||
Args:
|
||||
*args, **kwargs: Variable length argument list and keyword argument dictionary, unused.
|
||||
Returns:
|
||||
None: No return value.
|
||||
当调用对象时,输出模块结束信息。
|
||||
参数:*args、**kwargs - 可变长度的位置参数列表和关键字参数字典,未使用。
|
||||
返回值:None,无返回值。
|
||||
"""
|
||||
print(self.indent_list[-1] + self.name_list[-1] + " Module End", flush=True)
|
||||
self.indent_list.pop()
|
||||
self.name_list.pop()
|
||||
|
||||
|
||||
# if os.environ.get("ENABLE_VLLM_MODULE_HOOK", "0") == "1":
|
||||
if xenvs.ENABLE_VLLM_MODULE_HOOK:
|
||||
from torch.nn.modules.module import (
|
||||
register_module_forward_pre_hook,
|
||||
register_module_forward_hook,
|
||||
)
|
||||
|
||||
from torch.nn.modules.module import register_module_forward_pre_hook, register_module_forward_hook
|
||||
module_logging_hook_pre = ModuleLoggingHookPre()
|
||||
module_logging_hook_post = ModuleLoggingHookPost(
|
||||
module_logging_hook_pre.indent_list, module_logging_hook_pre.name_list
|
||||
)
|
||||
module_logging_hook_pre.indent_list, module_logging_hook_pre.name_list)
|
||||
register_module_forward_pre_hook(module_logging_hook_pre)
|
||||
register_module_forward_hook(module_logging_hook_post)
|
||||
else:
|
||||
module_logging_hook_pre = None
|
||||
module_logging_hook_post = None
|
||||
|
||||
|
||||
class LoggingDispatchMode(TorchDispatchMode):
|
||||
def __init__(self):
|
||||
"""
|
||||
Initialization function to initialize the attributes and methods of the class.
|
||||
Some initialization operations can be performed here, such as setting default values.
|
||||
初始化函数,用于初始化类的属性和方法。
|
||||
在此处可以进行一些初始化操作,例如设置默认值等。
|
||||
"""
|
||||
super().__init__()
|
||||
|
||||
|
||||
def __torch_dispatch__(self, func, types, args=(), kwargs=None):
|
||||
"""
|
||||
Override the default dispatch behavior of torch.nn.Module.
|
||||
This function will be called before and after each method call on this module.
|
||||
It can be used to log information about the method calls.
|
||||
|
||||
|
||||
Args:
|
||||
func (function): The function that is being called on this module.
|
||||
types (Tuple[str]): A tuple of strings representing the type signatures of the arguments.
|
||||
See torch.types for more details.
|
||||
args (Tuple[Any], optional): The positional arguments passed to the function. Defaults to ().
|
||||
kwargs (Dict[str, Any], optional): The keyword arguments passed to the function. Defaults to {}.
|
||||
|
||||
|
||||
Returns:
|
||||
Any: The result returned by the function.
|
||||
"""
|
||||
@@ -249,20 +198,19 @@ class LoggingDispatchMode(TorchDispatchMode):
|
||||
print(indent + "{} calling".format(func), flush=True)
|
||||
result = func(*args, **(kwargs or {}))
|
||||
print(indent + "{} called".format(func), flush=True)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
return result
|
||||
|
||||
class CUDAGraphInnerWatcher(TorchDispatchMode):
|
||||
|
||||
|
||||
def __init__(self, name_list):
|
||||
"""
|
||||
Initialization function to save the name list to the class attribute.
|
||||
It also creates a dictionary to keep track of the tensors that have been traced.
|
||||
|
||||
初始化函数,将传入的名称列表保存到类属性中。
|
||||
同时创建一个字典来记录已经追踪过的张量。
|
||||
|
||||
Args:
|
||||
name_list (List[str]): A list of names of tensors to be tracked.
|
||||
|
||||
name_list (List[str]): 包含需要追踪的张量名称的列表。
|
||||
|
||||
Returns:
|
||||
None.
|
||||
"""
|
||||
@@ -274,13 +222,13 @@ class CUDAGraphInnerWatcher(TorchDispatchMode):
|
||||
Override the default dispatch behavior of PyTorch tensors to track
|
||||
the tracing process. If the result of a function call is a tensor on CUDA,
|
||||
it will be added to the traced_tensor dictionary with the name of the function.
|
||||
|
||||
|
||||
Args:
|
||||
func (Callable): The function to be called.
|
||||
types (Tuple[Type]): The type hints of the function.
|
||||
args (Tuple[Any], optional): Positional arguments for the function. Defaults to ().
|
||||
kwargs (Optional[Dict[str, Any]], optional): Keyword arguments for the function. Defaults to None.
|
||||
|
||||
|
||||
Returns:
|
||||
Any: The result of the function call.
|
||||
"""
|
||||
@@ -292,13 +240,13 @@ class CUDAGraphInnerWatcher(TorchDispatchMode):
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
"""
|
||||
Clear the traced_tensor and name_list, and call the parent class's __exit__ method.
|
||||
|
||||
清空 traced_tensor 和 name_list,并调用父类的 __exit__ 方法。
|
||||
|
||||
Args:
|
||||
exc_type (Optional[Type[BaseException]]): The type of the exception, default is None.
|
||||
exc_val (Optional[BaseException]): The value of the exception, default is None.
|
||||
exc_tb (Optional[TracebackType]): he traceback object, default is None.
|
||||
|
||||
exc_type (Optional[Type[BaseException]]): 异常类型,默认为 None。
|
||||
exc_val (Optional[BaseException]): 异常值,默认为 None。
|
||||
exc_tb (Optional[TracebackType]): Traceback 对象,默认为 None。
|
||||
|
||||
Returns:
|
||||
None.
|
||||
"""
|
||||
@@ -308,11 +256,22 @@ class CUDAGraphInnerWatcher(TorchDispatchMode):
|
||||
self.name_list.clear()
|
||||
super(CUDAGraphInnerWatcher, self).__exit__(exc_type, exc_val, exc_tb)
|
||||
|
||||
# def patch_annotations_for_schema(func):
|
||||
# sig = inspect.signature(func)
|
||||
# new_params = []
|
||||
# for name, param in sig.parameters.items():
|
||||
# anno = param.annotation
|
||||
# if anno == list[int]:
|
||||
# anno = typing.List[int]
|
||||
# new_params.append(param.replace(annotation=anno))
|
||||
# new_sig = sig.replace(parameters=new_params)
|
||||
# func.__signature__ = new_sig
|
||||
# return func
|
||||
|
||||
def patch_annotations_for_schema(func):
|
||||
"""
|
||||
At runtime, replace list[int] and Optional[list[int]] in the function signature with typing.List[int] and Optional[typing.List[int]]
|
||||
so that torch.library.infer_schema can recognize it.
|
||||
运行时替换函数签名里的 list[int]、Optional[list[int]] 为 typing.List[int] / Optional[typing.List[int]]
|
||||
让 torch.library.infer_schema 能识别
|
||||
"""
|
||||
sig = inspect.signature(func)
|
||||
new_params = []
|
||||
@@ -320,7 +279,7 @@ def patch_annotations_for_schema(func):
|
||||
for name, param in sig.parameters.items():
|
||||
ann = param.annotation
|
||||
|
||||
# If it is Optional[T]
|
||||
# 如果是 Optional[T]
|
||||
if get_origin(ann) is typing.Union and type(None) in get_args(ann):
|
||||
inner_type = [a for a in get_args(ann) if a is not type(None)][0]
|
||||
if get_origin(inner_type) is list: # Optional[list[int]]
|
||||
@@ -328,7 +287,7 @@ def patch_annotations_for_schema(func):
|
||||
new_ann = Optional[List[inner_args[0] if inner_args else typing.Any]]
|
||||
param = param.replace(annotation=new_ann)
|
||||
|
||||
# If it is a direct list[int]
|
||||
# 如果是直接 list[int]
|
||||
elif get_origin(ann) is list:
|
||||
args = get_args(ann)
|
||||
new_ann = List[args[0] if args else typing.Any]
|
||||
@@ -339,23 +298,20 @@ def patch_annotations_for_schema(func):
|
||||
func.__signature__ = sig.replace(parameters=new_params)
|
||||
return func
|
||||
|
||||
|
||||
def supports_custom_op() -> bool:
|
||||
"""supports_custom_op"""
|
||||
return hasattr(torch.library, "custom_op")
|
||||
|
||||
|
||||
vllm_lib = Library("vllm", "FRAGMENT") # noqa
|
||||
|
||||
|
||||
def direct_register_custom_op(
|
||||
op_name: str,
|
||||
op_func: Callable,
|
||||
mutates_args: list[str],
|
||||
fake_impl: Optional[Callable] = None,
|
||||
target_lib: Optional[Library] = None,
|
||||
dispatch_key: str = "CUDA",
|
||||
tags: tuple[torch.Tag, ...] = (),
|
||||
op_name: str,
|
||||
op_func: Callable,
|
||||
mutates_args: list[str],
|
||||
fake_impl: Optional[Callable] = None,
|
||||
target_lib: Optional[Library] = None,
|
||||
dispatch_key: str = "CUDA",
|
||||
tags: tuple[torch.Tag, ...] = (),
|
||||
):
|
||||
"""
|
||||
`torch.library.custom_op` can have significant overhead because it
|
||||
@@ -374,28 +330,25 @@ def direct_register_custom_op(
|
||||
"""
|
||||
if not supports_custom_op():
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
assert not current_platform.is_cuda_alike(), (
|
||||
"cuda platform needs torch>=2.4 to support custom op, "
|
||||
"chances are you are using an old version of pytorch "
|
||||
"or a custom build of pytorch. It is recommended to "
|
||||
"use vLLM in a fresh new environment and let it install "
|
||||
"the required dependencies."
|
||||
)
|
||||
"the required dependencies.")
|
||||
return
|
||||
|
||||
import torch.library
|
||||
|
||||
if hasattr(torch.library, "infer_schema"):
|
||||
patched_func = patch_annotations_for_schema(op_func)
|
||||
schema_str = torch.library.infer_schema(op_func, mutates_args=mutates_args)
|
||||
schema_str = torch.library.infer_schema(op_func,
|
||||
mutates_args=mutates_args)
|
||||
else:
|
||||
# for pytorch 2.4
|
||||
import torch._custom_op.impl
|
||||
|
||||
schema_str = torch._custom_op.impl.infer_schema(op_func, mutates_args)
|
||||
my_lib = target_lib or vllm_lib
|
||||
my_lib.define(op_name + schema_str, tags=tags)
|
||||
my_lib.impl(op_name, op_func, dispatch_key=dispatch_key)
|
||||
if fake_impl is not None:
|
||||
my_lib._register_fake(op_name, fake_impl)
|
||||
my_lib._register_fake(op_name, fake_impl)
|
||||
@@ -1,8 +1,5 @@
|
||||
#
|
||||
# Copyright (c) 2025 Baidu, Inc. All Rights Reserved.
|
||||
# Author: Dong Xinyu, Bao Qian, Chen Zhennan, Ma Tianyu, Wang Haowen
|
||||
# Email: dongxinyu03@baidu.com
|
||||
# This file is a part of the vllm-kunlun project.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
@@ -15,6 +12,8 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# This file is a part of the vllm-kunlun project.
|
||||
#
|
||||
from vllm.config import VllmConfig, get_layers_from_vllm_config
|
||||
import xtorch_ops
|
||||
from dataclasses import dataclass
|
||||
@@ -24,8 +23,8 @@ import torch
|
||||
import numpy as np
|
||||
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
|
||||
AttentionMetadata, AttentionLayer, AttentionType)
|
||||
from vllm.attention.backends.utils import CommonAttentionState
|
||||
from vllm.attention.backends.utils import is_block_tables_empty, compute_slot_mapping_start_idx, compute_slot_mapping
|
||||
# from vllm.attention.backends.utils import CommonAttentionState
|
||||
# from vllm.attention.backends.utils import is_block_tables_empty, compute_slot_mapping_start_idx, compute_slot_mapping
|
||||
from vllm_kunlun.ops.paged_attn import (PagedAttention, PagedAttentionMetadata)
|
||||
from vllm_kunlun.ops._kunlun_ops import KunlunOps
|
||||
|
||||
@@ -45,6 +44,7 @@ from vllm.v1.worker.block_table import BlockTable
|
||||
|
||||
from vllm.config import VllmConfig, get_layers_from_vllm_config
|
||||
|
||||
|
||||
class KunlunAttentionBackend(AttentionBackend):
|
||||
"""KunlunAttentionBackend"""
|
||||
# crucial to cuda graph
|
||||
@@ -70,10 +70,10 @@ class KunlunAttentionBackend(AttentionBackend):
|
||||
"""get_builder_cls"""
|
||||
return KunlunAttentionMetadataBuilder
|
||||
|
||||
@staticmethod
|
||||
def get_state_cls() -> Type["CommonAttentionState"]:
|
||||
"""get_state_cls"""
|
||||
return CommonAttentionState
|
||||
# @staticmethod
|
||||
# def get_state_cls() -> Type["CommonAttentionState"]:
|
||||
# """get_state_cls"""
|
||||
# return CommonAttentionState
|
||||
|
||||
@staticmethod
|
||||
def get_kv_cache_shape(
|
||||
@@ -81,6 +81,7 @@ class KunlunAttentionBackend(AttentionBackend):
|
||||
block_size: int,
|
||||
num_kv_heads: int,
|
||||
head_size: int,
|
||||
cache_dtype_str: str = "auto"
|
||||
) -> Tuple[int, ...]:
|
||||
"""get_kv_cache_shape"""
|
||||
# return (2, num_blocks, block_size, num_kv_heads * head_size)
|
||||
@@ -132,7 +133,11 @@ class KunlunMetadata(AttentionMetadata, PagedAttentionMetadata):
|
||||
# Cuda-graph is currently enabled for decoding only.
|
||||
# TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention.
|
||||
use_cuda_graph: bool
|
||||
|
||||
slot_mapping: torch.Tensor
|
||||
block_tables: torch.Tensor
|
||||
|
||||
multi_modal_placeholder_index_maps: Optional[torch.Tensor] = None
|
||||
# (batch_size,). The sequence length per sequence. Sequence length means
|
||||
# the computed tokens + new tokens None if it is a decoding.
|
||||
seq_lens: Optional[List[int]] = None
|
||||
@@ -143,6 +148,11 @@ class KunlunMetadata(AttentionMetadata, PagedAttentionMetadata):
|
||||
# [4, 6], it is [0, 4, 10].
|
||||
seq_start_loc: Optional[torch.Tensor] = None
|
||||
|
||||
|
||||
# Prefix cache loc
|
||||
kv_lod_cpu: Optional[torch.Tensor] = None
|
||||
kv_lod_xpu: Optional[torch.Tensor] = None
|
||||
|
||||
# (batch_size,) A tensor of context lengths (tokens that are computed
|
||||
# so far).
|
||||
context_lens_tensor: Optional[torch.Tensor] = None
|
||||
@@ -181,6 +191,7 @@ class KunlunMetadata(AttentionMetadata, PagedAttentionMetadata):
|
||||
# Number of tokens input to encoder
|
||||
num_encoder_tokens: Optional[int] = None
|
||||
|
||||
enable_kv_scales_calculation: Optional[bool] = False
|
||||
# Cross-attention memory-mapping data structures: slot mapping
|
||||
# and block tables
|
||||
cross_slot_mapping: Optional[torch.Tensor] = None
|
||||
@@ -193,6 +204,11 @@ class KunlunMetadata(AttentionMetadata, PagedAttentionMetadata):
|
||||
use_cascade: Optional[bool] = False
|
||||
|
||||
seq_lens_tensor_cpu: Optional[torch.Tensor] = None
|
||||
|
||||
num_prefill_tokens: int = 0
|
||||
num_decode_tokens: int = 0
|
||||
num_prefills: int = 0
|
||||
num_decodes: int = 0
|
||||
|
||||
def __post_init__(self):
|
||||
"""__post_init__"""
|
||||
@@ -253,6 +269,19 @@ class KunlunMetadata(AttentionMetadata, PagedAttentionMetadata):
|
||||
input_positions = (None if self.input_positions is None else
|
||||
self.input_positions[-self.num_prefills:])
|
||||
|
||||
|
||||
if self.kv_lod_cpu is None:
|
||||
kv_lod_cpu = None
|
||||
kv_lod_xpu = None
|
||||
else:
|
||||
start = -(self.num_prefills + 1)
|
||||
base_cpu = self.kv_lod_cpu[start]
|
||||
kv_lod_cpu = self.kv_lod_cpu[start:] - base_cpu
|
||||
|
||||
base_xpu = self.kv_lod_xpu[start]
|
||||
kv_lod_xpu = self.kv_lod_xpu[start:] - base_xpu
|
||||
|
||||
|
||||
# Construct & cache prefill-phase attention metadata structure
|
||||
self._cached_prefill_metadata = KunlunMetadata(
|
||||
num_actual_tokens=self.num_actual_tokens,
|
||||
@@ -264,7 +293,9 @@ class KunlunMetadata(AttentionMetadata, PagedAttentionMetadata):
|
||||
slot_mapping=slot_mapping,
|
||||
seq_lens=seq_lens,
|
||||
seq_lens_tensor=seq_lens_tensor,
|
||||
seq_start_loc=None,
|
||||
seq_start_loc = None,
|
||||
kv_lod_cpu=kv_lod_cpu,
|
||||
kv_lod_xpu=kv_lod_xpu,
|
||||
max_query_len=self.max_query_len,
|
||||
max_kv_len=self.max_kv_len,
|
||||
max_prefill_seq_len=self.max_prefill_seq_len,
|
||||
@@ -413,18 +444,28 @@ class KunlunAttentionMetadataBuilder:
|
||||
self._num_decode_tokens = num_decode_tokens
|
||||
self._num_prefill_tokens = num_prefill_tokens
|
||||
return modified_batch
|
||||
|
||||
def build_for_cudagraph_capture(
|
||||
self, common_attn_metadata: CommonAttentionMetadata
|
||||
) -> KunlunMetadata:
|
||||
attn_metadata = self.build(0, common_attn_metadata)
|
||||
# When doing full graph capture, setting seq_lens to
|
||||
# max_model_len will cause graph capture to be extremely
|
||||
# slow, so here we set it to 1.
|
||||
attn_metadata.seq_lens_tensor.fill_(1)
|
||||
return attn_metadata
|
||||
|
||||
def build(self, common_prefix_len: int,
|
||||
common_attn_metadata: CommonAttentionMetadata):
|
||||
"""build"""
|
||||
num_reqs=common_attn_metadata.num_reqs
|
||||
num_actual_tokens=common_attn_metadata.num_actual_tokens
|
||||
max_query_len=common_attn_metadata.max_query_len
|
||||
common_prefix_len=common_prefix_len
|
||||
num_reqs = common_attn_metadata.num_reqs
|
||||
num_actual_tokens = common_attn_metadata.num_actual_tokens
|
||||
max_query_len = common_attn_metadata.max_query_len
|
||||
common_prefix_len = common_prefix_len
|
||||
block_table_tensor = common_attn_metadata.block_table_tensor
|
||||
slot_mapping = common_attn_metadata.slot_mapping
|
||||
|
||||
|
||||
|
||||
max_seq_len = int(common_attn_metadata.seq_lens_cpu.max())
|
||||
query_start_loc_host = common_attn_metadata.query_start_loc_cpu[:num_reqs + 1]
|
||||
query_start_loc = common_attn_metadata.query_start_loc_cpu[:num_reqs + 1].to(
|
||||
@@ -432,18 +473,18 @@ class KunlunAttentionMetadataBuilder:
|
||||
|
||||
seq_lens = common_attn_metadata.seq_lens
|
||||
seq_lens_cpu = common_attn_metadata.seq_lens_cpu
|
||||
|
||||
|
||||
seq_start_loc = list(accumulate(seq_lens, initial=0))
|
||||
|
||||
if len(seq_start_loc) != num_reqs + 1:
|
||||
seq_start_loc = query_start_loc_host.tolist()
|
||||
|
||||
if seq_start_loc[-1] != num_actual_tokens:
|
||||
seq_start_loc = query_start_loc_host.tolist()
|
||||
|
||||
|
||||
|
||||
|
||||
seq_start_loc_tensor = torch.empty(len(seq_start_loc), dtype=torch.int32, device=self.device)
|
||||
seq_start_loc_tensor.copy_(torch.as_tensor(seq_start_loc, dtype=torch.int32))
|
||||
|
||||
kv_lod_cpu = torch.zeros(num_reqs + 1, dtype=torch.int32, device="cpu")
|
||||
kv_lod_cpu[1:] = seq_lens_cpu.to(torch.int32).cumsum(dim=0)
|
||||
kv_lod_xpu = kv_lod_cpu.to(self.device)
|
||||
|
||||
num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens =\
|
||||
split_decodes_and_prefills(common_attn_metadata)
|
||||
|
||||
@@ -456,6 +497,7 @@ class KunlunAttentionMetadataBuilder:
|
||||
max_decode_seq_len = np.max(tmp_decode_scheduled_tokens)
|
||||
|
||||
tmp_prefill_scheduled_tokens = num_scheduled_tokens[num_decodes: num_reqs]
|
||||
|
||||
if num_prefill_tokens == 0:
|
||||
max_prefill_seq_len = 0
|
||||
else:
|
||||
@@ -473,6 +515,8 @@ class KunlunAttentionMetadataBuilder:
|
||||
num_decode_tokens=num_decode_tokens,
|
||||
seq_lens_tensor=seq_lens,
|
||||
seq_lens_tensor_cpu=seq_lens_cpu,
|
||||
kv_lod_xpu=kv_lod_xpu,
|
||||
kv_lod_cpu=kv_lod_cpu,
|
||||
max_query_len=max_prefill_seq_len,
|
||||
max_prefill_seq_len=max_prefill_seq_len,
|
||||
max_decode_seq_len=max_decode_seq_len,
|
||||
@@ -483,7 +527,6 @@ class KunlunAttentionMetadataBuilder:
|
||||
use_cuda_graph=False,
|
||||
use_cascade=use_cascade,
|
||||
)
|
||||
|
||||
return attn_metadata
|
||||
|
||||
def can_run_in_cudagraph(
|
||||
@@ -514,11 +557,15 @@ class KunlunAttentionImpl(AttentionImpl[KunlunMetadata]):
|
||||
attn_type: AttentionType = AttentionType.DECODER,
|
||||
use_irope: bool = False,
|
||||
sinks:Optional[torch.Tensor]= None,
|
||||
multi_modal_placeholder_index_maps:Optional[torch.Tensor]= None,
|
||||
) -> None:
|
||||
"""__init__"""
|
||||
if blocksparse_params is not None:
|
||||
raise ValueError(
|
||||
"kunlunAttention does not support block-sparse attention.")
|
||||
# if logits_soft_cap is not None:
|
||||
# raise ValueError(
|
||||
# "kunlunAttention does not support attention logits soft capping.")
|
||||
self.num_heads = num_heads
|
||||
self.head_size = head_size
|
||||
self.scale = float(scale)
|
||||
@@ -547,6 +594,7 @@ class KunlunAttentionImpl(AttentionImpl[KunlunMetadata]):
|
||||
"Sinks must have the same number of heads as the number of "
|
||||
f"heads in the layer. Sinks shape: {sinks.shape}, "
|
||||
f"num_heads: {num_heads}.")
|
||||
self.multi_modal_placeholder_index_maps = multi_modal_placeholder_index_maps
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@@ -560,7 +608,8 @@ class KunlunAttentionImpl(AttentionImpl[KunlunMetadata]):
|
||||
v_scale: float = 1.0,
|
||||
attn_type: AttentionType = AttentionType.DECODER,
|
||||
output: Optional[torch.Tensor] = None,
|
||||
output_scale: Optional[torch.Tensor] = None
|
||||
output_scale: Optional[torch.Tensor] = None,
|
||||
output_block_scale: Optional[torch.Tensor] = None
|
||||
) -> torch.Tensor:
|
||||
"""forward"""
|
||||
query = query.view(-1, self.num_heads, self.head_size)
|
||||
@@ -597,12 +646,22 @@ class KunlunAttentionImpl(AttentionImpl[KunlunMetadata]):
|
||||
# If kv_cache is not provided, the new key and value tensors are
|
||||
# not cached. This happens during the initial memory
|
||||
value = value.contiguous()
|
||||
xtorch_ops.reshape_and_cache(
|
||||
key,
|
||||
value,
|
||||
key_cache,
|
||||
value_cache,
|
||||
updated_slot_mapping)
|
||||
if key_cache.is_contiguous():
|
||||
xtorch_ops.reshape_and_cache(
|
||||
key,
|
||||
value,
|
||||
key_cache,
|
||||
value_cache,
|
||||
updated_slot_mapping)
|
||||
else:
|
||||
cast_key_cache = key_cache.squeeze(1).unsqueeze(-2)
|
||||
cast_value_cache = value_cache.squeeze(1).unsqueeze(-2)
|
||||
xtorch_ops.reshape_and_cache_flash(
|
||||
key,
|
||||
value,
|
||||
cast_key_cache,
|
||||
cast_value_cache,
|
||||
updated_slot_mapping)
|
||||
|
||||
assert attn_type == AttentionType.DECODER
|
||||
# Decoder self-attention supports chunked prefill.
|
||||
@@ -614,22 +673,38 @@ class KunlunAttentionImpl(AttentionImpl[KunlunMetadata]):
|
||||
if prefill_meta := attn_metadata.prefill_metadata:
|
||||
# Prompt run.
|
||||
prefill_query = query[num_decode_tokens:attn_metadata.num_actual_tokens]
|
||||
prefill_key = key[num_decode_tokens:attn_metadata.num_actual_tokens]
|
||||
prefill_value = value[num_decode_tokens:attn_metadata.num_actual_tokens]
|
||||
assert prefill_query.shape[0] == num_prefill_tokens
|
||||
output[num_decode_tokens:attn_metadata.num_actual_tokens] = KunlunOps.multi_query_kv_attention(
|
||||
prefill_meta.query_start_loc,prefill_meta.query_start_loc_host, prefill_query, prefill_key, prefill_value,
|
||||
alibi_slopes=self.alibi_slopes).view_as(prefill_query)
|
||||
xtorch_ops.prefill_attention(
|
||||
q=prefill_query,
|
||||
k=key_cache, # Key Cache (block_num, head, block_size, dim)
|
||||
v=value_cache,
|
||||
out=output[num_decode_tokens:attn_metadata.num_actual_tokens],
|
||||
is_causal=True,
|
||||
is_prefix_cache=True,
|
||||
block_table=prefill_meta.block_tables,
|
||||
context_qlen_lod_cpu=prefill_meta.query_start_loc_host,
|
||||
context_qlen_lod_xpu=prefill_meta.query_start_loc,
|
||||
context_kvlen_lod_cpu=prefill_meta.kv_lod_cpu,
|
||||
context_kvlen_lod_xpu=prefill_meta.kv_lod_xpu,
|
||||
alibi_slopes=self.alibi_slopes,
|
||||
softmax_lse=None,
|
||||
sink=self.sinks
|
||||
)
|
||||
|
||||
if decode_meta := attn_metadata.decode_metadata:
|
||||
assert attn_type != AttentionType.ENCODER_ONLY, (
|
||||
"Encoder-only models should not have decode metadata.")
|
||||
decode_query = query[:num_decode_tokens]
|
||||
|
||||
|
||||
if key_cache.is_contiguous():
|
||||
tmp_block_tables = decode_meta.block_tables
|
||||
else:
|
||||
tmp_block_tables = decode_meta.block_tables * 2 # only test in Qwen3-Next
|
||||
|
||||
xtorch_ops.paged_attention(
|
||||
x=decode_query,
|
||||
k_cache=key_cache,
|
||||
v_cache=value_cache,
|
||||
block_tables=decode_meta.block_tables,
|
||||
block_tables=tmp_block_tables,
|
||||
context_lens_cpu=decode_meta.seq_lens_tensor_cpu,
|
||||
context_lens_xpu=decode_meta.seq_lens_tensor,
|
||||
is_context=False,
|
||||
@@ -639,7 +714,6 @@ class KunlunAttentionImpl(AttentionImpl[KunlunMetadata]):
|
||||
)
|
||||
# Reshape the output tensor.
|
||||
return output.view(-1, self.num_heads * self.head_size)
|
||||
|
||||
def use_cascade_attention(
|
||||
common_prefix_len: int,
|
||||
query_lens: np.ndarray,
|
||||
@@ -650,13 +724,18 @@ def use_cascade_attention(
|
||||
num_sms: int,
|
||||
use_local_attention: bool = False,
|
||||
) -> bool:
|
||||
"""
|
||||
TODO: Not Yet Supported on Kunlun platform
|
||||
"""Decide whether to use cascade attention.
|
||||
|
||||
This function 1) checks whether cascade attention is supported with the
|
||||
given configuration, and 2) heuristically decides whether using cascade
|
||||
attention can improve performance.
|
||||
"""
|
||||
# Too short common prefix. Probably not worth using cascade attention.
|
||||
# We use an arbitrary threshold of 256 tokens. TODO: Tune this threshold.
|
||||
# NOTE(woosuk): This is the common case. We should return False as soon as
|
||||
# possible to avoid any unnecessary computation.
|
||||
return False
|
||||
|
||||
if common_prefix_len < 256:
|
||||
return False
|
||||
# Cascade attention is currently not supported with these variants.
|
||||
|
||||
@@ -1,91 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import torch
|
||||
from vllm.utils import is_pin_memory_available, make_tensor_with_pad
|
||||
|
||||
|
||||
def get_token_bin_counts_and_mask(
|
||||
tokens: torch.Tensor,
|
||||
vocab_size: int,
|
||||
num_seqs: int,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
# Compute the bin counts for the tokens.
|
||||
# vocab_size + 1 for padding.
|
||||
bin_counts = torch.zeros((num_seqs, vocab_size + 1),
|
||||
dtype=torch.long,
|
||||
device=tokens.device)
|
||||
bin_counts.scatter_add_(1, tokens, torch.ones_like(tokens))
|
||||
bin_counts = bin_counts[:, :vocab_size]
|
||||
mask = bin_counts > 0
|
||||
|
||||
return bin_counts, mask
|
||||
|
||||
def apply_penalties(logits: torch.Tensor, prompt_tokens_tensor: torch.Tensor,
|
||||
output_tokens_tensor: torch.Tensor,
|
||||
presence_penalties: torch.Tensor,
|
||||
frequency_penalties: torch.Tensor,
|
||||
repetition_penalties: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Applies penalties in place to the logits tensor
|
||||
logits : The input logits tensor of shape [num_seqs, vocab_size]
|
||||
prompt_tokens_tensor: A tensor containing the prompt tokens. The prompts
|
||||
are padded to the maximum prompt length within the batch using
|
||||
`vocab_size` as the padding value. The value `vocab_size` is used
|
||||
for padding because it does not correspond to any valid token ID
|
||||
in the vocabulary.
|
||||
output_tokens_tensor: The output tokens tensor.
|
||||
presence_penalties: The presence penalties of shape (num_seqs, )
|
||||
frequency_penalties: The frequency penalties of shape (num_seqs, )
|
||||
repetition_penalties: The repetition penalties of shape (num_seqs, )
|
||||
"""
|
||||
num_seqs, vocab_size = logits.shape
|
||||
_, prompt_mask = get_token_bin_counts_and_mask(prompt_tokens_tensor,
|
||||
vocab_size, num_seqs)
|
||||
output_bin_counts, output_mask = get_token_bin_counts_and_mask(
|
||||
output_tokens_tensor, vocab_size, num_seqs)
|
||||
|
||||
# Apply repetition penalties as a custom op
|
||||
from vllm._custom_ops import apply_repetition_penalties_torch
|
||||
apply_repetition_penalties_torch(logits, prompt_mask, output_mask,
|
||||
repetition_penalties)
|
||||
|
||||
# We follow the definition in OpenAI API.
|
||||
# Refer to https://platform.openai.com/docs/api-reference/parameter-details
|
||||
logits -= frequency_penalties.unsqueeze(dim=1) * output_bin_counts
|
||||
logits -= presence_penalties.unsqueeze(dim=1) * output_mask
|
||||
return logits
|
||||
|
||||
def apply_all_penalties(
|
||||
logits: torch.Tensor,
|
||||
prompt_token_ids: torch.Tensor,
|
||||
presence_penalties: torch.Tensor,
|
||||
frequency_penalties: torch.Tensor,
|
||||
repetition_penalties: torch.Tensor,
|
||||
output_token_ids: list[list[int]],
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Applies presence, frequency and repetition penalties to the logits.
|
||||
"""
|
||||
_, vocab_size = logits.shape
|
||||
output_tokens_t = _convert_to_tensors(output_token_ids, vocab_size,
|
||||
logits.device)
|
||||
return apply_penalties(logits, prompt_token_ids, output_tokens_t,
|
||||
presence_penalties, frequency_penalties,
|
||||
repetition_penalties)
|
||||
|
||||
def _convert_to_tensors(output_token_ids: list[list[int]], vocab_size: int,
|
||||
device: torch.device) -> torch.Tensor:
|
||||
"""
|
||||
Convert the different list data structures to tensors.
|
||||
"""
|
||||
output_tokens_tensor = make_tensor_with_pad(
|
||||
output_token_ids,
|
||||
# Use the value of vocab_size as a pad since we don't have a
|
||||
# token_id of this value.
|
||||
pad=vocab_size,
|
||||
device="cpu",
|
||||
dtype=torch.int64,
|
||||
pin_memory=is_pin_memory_available(),
|
||||
)
|
||||
return output_tokens_tensor.to(device, non_blocking=True)
|
||||
@@ -22,7 +22,7 @@ class TopKTopPSampler(nn.Module):
|
||||
Implementations may update the logits tensor in-place.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
def __init__(self, logprobs_mode):
|
||||
super().__init__()
|
||||
logger.info_once(
|
||||
"Using FlashInfer for top-p & top-k sampling.")
|
||||
@@ -57,7 +57,7 @@ class TopKTopPSampler(nn.Module):
|
||||
# not needed. This is because `random_sample` does not require
|
||||
# CPU-GPU synchronization while `flashinfer_sample` does.
|
||||
probs = logits.softmax(dim=-1, dtype=torch.float32)
|
||||
return random_sample(probs, generators)
|
||||
return random_sample(probs, generators), None
|
||||
if generators:
|
||||
logger.warning_once("FlashInfer 0.2.3+ does not support "
|
||||
"per-request generators. Falling back to "
|
||||
@@ -66,8 +66,7 @@ class TopKTopPSampler(nn.Module):
|
||||
# flashinfer sampling functions expect contiguous logits.
|
||||
# In flex_attn/triton_attn fp32 inference, logits can be non-contiguous
|
||||
# because of slicing operation in logits_processor.
|
||||
return flashinfer_sample(logits.contiguous(), k, p, generators)
|
||||
|
||||
return flashinfer_sample(logits.contiguous(), k, p, generators), None
|
||||
|
||||
|
||||
def apply_top_k_top_p(
|
||||
@@ -195,4 +194,4 @@ def flashinfer_sample(
|
||||
next_token_ids = xtorch_ops.top_k_top_p_sampling_from_probs(
|
||||
probs, top_k=k, top_p=p, deterministic=True)
|
||||
|
||||
return next_token_ids.view(-1)
|
||||
return next_token_ids.view(-1)
|
||||
4184
vllm_kunlun/v1/worker/gpu_model_runner.py
Normal file
4184
vllm_kunlun/v1/worker/gpu_model_runner.py
Normal file
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
2043
vllm_kunlun/worker/model_runner.py
Normal file
2043
vllm_kunlun/worker/model_runner.py
Normal file
File diff suppressed because it is too large
Load Diff
50
vllm_kunlun/worker/worker.py
Normal file
50
vllm_kunlun/worker/worker.py
Normal file
@@ -0,0 +1,50 @@
|
||||
"""worker"""
|
||||
from typing import Dict, List, Optional, Set, Tuple, Type, Union
|
||||
from vllm.v1.worker.gpu_worker import Worker, _check_if_gpu_supports_dtype, init_worker_distributed_environment
|
||||
from vllm.model_executor import set_random_seed
|
||||
from .model_runner import KunlunModelRunner
|
||||
from vllm.utils import MemorySnapshot
|
||||
import torch
|
||||
import os
|
||||
import gc
|
||||
|
||||
class KunlunWorker(Worker):
|
||||
"""Worker"""
|
||||
|
||||
def init_device(self):
|
||||
if self.device_config.device.type == "cuda":
|
||||
# torch.distributed.all_reduce does not free the input tensor until
|
||||
# the synchronization point. This causes the memory usage to grow
|
||||
# as the number of all_reduce calls increases. This env var disables
|
||||
# this behavior.
|
||||
# Related issue:
|
||||
# https://discuss.pytorch.org/t/cuda-allocation-lifetime-for-inputs-to-distributed-all-reduce/191573
|
||||
os.environ["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1"
|
||||
|
||||
# This env var set by Ray causes exceptions with graph building.
|
||||
os.environ.pop("NCCL_ASYNC_ERROR_HANDLING", None)
|
||||
self.device = torch.device(f"cuda:{self.local_rank}")
|
||||
torch.cuda.set_device(self.device)
|
||||
|
||||
_check_if_gpu_supports_dtype(self.model_config.dtype)
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
torch.cuda.reset_peak_memory_stats()
|
||||
self.init_snapshot = MemorySnapshot()
|
||||
free_memory, total = torch.cuda.mem_get_info()
|
||||
self.init_gpu_memory = free_memory
|
||||
# 设置一个合理的初始值,比如总内存的 80%
|
||||
self.requested_memory = int(total * 0.2) # 留出 20% 的余量
|
||||
else:
|
||||
raise RuntimeError(
|
||||
f"Not support device type: {self.device_config.device}")
|
||||
# Initialize the distributed environment.
|
||||
init_worker_distributed_environment(self.vllm_config,
|
||||
self.rank,
|
||||
self.distributed_init_method,
|
||||
self.local_rank)
|
||||
# Set random seed.
|
||||
set_random_seed(self.model_config.seed)
|
||||
# Construct the model runner
|
||||
self.model_runner: KunlunModelRunner = KunlunModelRunner(
|
||||
self.vllm_config, self.device)
|
||||
Reference in New Issue
Block a user