Fuse write kv buffer into rope for qwen3 moe & bailing moe (#10749)

Co-authored-by: luoyuan.luo <luoyuan.luo@antgroup.com>
Author: Yuan Luo
Date: 2025-09-26 15:18:41 +08:00
Committed by: GitHub
Parent: 6c2c467d77
Commit: a5095d6262
4 changed files with 105 additions and 34 deletions
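The change extends the fused "write KV buffer inside RoPE" path that GptOssAttention already used to Qwen3MoeAttention and BailingMoEAttention: the rotary embedding receives a FusedSetKVBufferArg so the RoPE kernel stores K/V into the KV cache in the same launch, and the attention call is then told not to write the cache a second time. The two helpers GptOssAttention had kept private are promoted to a shared module, sglang/srt/models/utils.py. The sketch below restates that shared pattern outside any one model; the wrapper function itself is purely illustrative (the imported names and the rotary_emb/attn/forward_batch arguments follow the diff, everything else is an assumption), not code from this commit.

    from sglang.srt.models.utils import (
        create_fused_set_kv_buffer_arg,
        enable_fused_set_kv_buffer,
    )

    def rope_with_fused_kv_write(rotary_emb, attn, positions, q, k, v, forward_batch):
        # Hypothetical wrapper: build the fused argument only when the gate allows it
        # (CUDA with a bfloat16 KV cache); otherwise fall back to the unfused path.
        fused_arg = (
            create_fused_set_kv_buffer_arg(value=v, layer=attn, forward_batch=forward_batch)
            if enable_fused_set_kv_buffer(forward_batch)
            else None
        )
        # With fused_arg set, the RoPE kernel also writes K/V into the cache...
        q, k = rotary_emb(positions, q, k, fused_set_kv_buffer_arg=fused_arg)
        # ...so attention must skip its own KV-cache write to avoid storing it twice.
        return attn(q, k, v, forward_batch, save_kv_cache=fused_arg is None)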

File: sglang/srt/models/bailing_moe.py

@@ -72,6 +72,10 @@ from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
 from sglang.srt.model_loader.weight_utils import default_weight_loader
+from sglang.srt.models.utils import (
+    create_fused_set_kv_buffer_arg,
+    enable_fused_set_kv_buffer,
+)
 from sglang.srt.utils import add_prefix, is_cuda, is_non_idle_and_non_empty, make_layers
 
 LoraConfig = None
@@ -555,8 +559,27 @@ class BailingMoEAttention(nn.Module):
         q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
         if self.use_qk_norm:
             q, k = self._apply_qk_norm(q, k)
-        q, k = self.rotary_emb(positions, q, k)
-        context_layer = self.attn(q, k, v, forward_batch)
+        q, k = self.rotary_emb(
+            positions,
+            q,
+            k,
+            fused_set_kv_buffer_arg=(
+                create_fused_set_kv_buffer_arg(
+                    value=v,
+                    layer=self.attn,
+                    forward_batch=forward_batch,
+                )
+                if enable_fused_set_kv_buffer(forward_batch)
+                else None
+            ),
+        )
+        context_layer = self.attn(
+            q,
+            k,
+            v,
+            forward_batch,
+            save_kv_cache=not enable_fused_set_kv_buffer(forward_batch),
+        )
         attn_output, _ = self.dense(context_layer)
         return attn_output

File: sglang/srt/models/gpt_oss.py

@@ -66,6 +66,10 @@ from sglang.srt.layers.vocab_parallel_embedding import (
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
 from sglang.srt.model_loader.weight_utils import default_weight_loader
+from sglang.srt.models.utils import (
+    create_fused_set_kv_buffer_arg,
+    enable_fused_set_kv_buffer,
+)
 from sglang.srt.utils import (
     LazyValue,
     add_prefix,
@@ -193,33 +197,6 @@ class GptOssSparseMoeBlock(nn.Module):
         return ans
 
 
-def _enable_fused_set_kv_buffer(forward_batch: ForwardBatch):
-    """Enable fused set_kv_buffer only on CUDA with bfloat16 KV cache."""
-    return _is_cuda and forward_batch.token_to_kv_pool.dtype == torch.bfloat16
-
-
-# TODO maybe move to a model-common utils
-def _create_fused_set_kv_buffer_arg(
-    value: torch.Tensor,
-    layer: RadixAttention,
-    forward_batch: ForwardBatch,
-):
-    layer_id = layer.layer_id
-    token_to_kv_pool = forward_batch.token_to_kv_pool
-
-    k_buffer = token_to_kv_pool.get_key_buffer(layer_id)
-    v_buffer = token_to_kv_pool.get_value_buffer(layer_id)
-
-    return FusedSetKVBufferArg(
-        value=value,
-        k_buffer=k_buffer.view(k_buffer.shape[0], -1),
-        v_buffer=v_buffer.view(v_buffer.shape[0], -1),
-        k_scale=layer.k_scale,
-        v_scale=layer.v_scale,
-        cache_loc=forward_batch.out_cache_loc,
-    )
-
-
 class GptOssAttention(nn.Module):
     def __init__(
         self,
@@ -337,12 +314,12 @@ class GptOssAttention(nn.Module):
             q,
             k,
             fused_set_kv_buffer_arg=(
-                _create_fused_set_kv_buffer_arg(
+                create_fused_set_kv_buffer_arg(
                     value=v,
                     layer=self.attn,
                     forward_batch=forward_batch,
                 )
-                if _enable_fused_set_kv_buffer(forward_batch)
+                if enable_fused_set_kv_buffer(forward_batch)
                 else None
             ),
         )
@@ -356,7 +333,7 @@ class GptOssAttention(nn.Module):
         attn_output = self.attn(
             *inner_state,
             sinks=self.sinks,
-            save_kv_cache=not _enable_fused_set_kv_buffer(forward_batch),
+            save_kv_cache=not enable_fused_set_kv_buffer(forward_batch),
         )
         output, _ = self.o_proj(attn_output)
         return output

File: sglang/srt/models/qwen3_moe.py

@@ -60,6 +60,10 @@ from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTe
 from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.models.qwen2_moe import Qwen2MoeMLP as Qwen3MoeMLP
 from sglang.srt.models.qwen2_moe import Qwen2MoeModel
+from sglang.srt.models.utils import (
+    create_fused_set_kv_buffer_arg,
+    enable_fused_set_kv_buffer,
+)
 from sglang.srt.utils import (
     add_prefix,
     is_cuda,
@@ -412,7 +416,20 @@ class Qwen3MoeAttention(nn.Module):
         qkv, _ = self.qkv_proj(hidden_states)
         q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
         q, k = self._apply_qk_norm(q, k)
-        q, k = self.rotary_emb(positions, q, k)
+        q, k = self.rotary_emb(
+            positions,
+            q,
+            k,
+            fused_set_kv_buffer_arg=(
+                create_fused_set_kv_buffer_arg(
+                    value=v,
+                    layer=self.attn,
+                    forward_batch=forward_batch,
+                )
+                if enable_fused_set_kv_buffer(forward_batch)
+                else None
+            ),
+        )
         inner_state = q, k, v, forward_batch
         return None, forward_batch, inner_state
@@ -420,7 +437,10 @@ class Qwen3MoeAttention(nn.Module):
         hidden_states, forward_batch, inner_state = intermediate_state
         if inner_state is None:
             return hidden_states
-        attn_output = self.attn(*inner_state)
+        attn_output = self.attn(
+            *inner_state,
+            save_kv_cache=not enable_fused_set_kv_buffer(forward_batch),
+        )
         output, _ = self.o_proj(attn_output)
         return output

File: sglang/srt/models/utils.py (new file)

@@ -0,0 +1,51 @@
+# Copyright 2023-2025 SGLang Team
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+import torch
+
+from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.utils import is_cuda
+
+_is_cuda = is_cuda()
+
+if _is_cuda:
+    from sgl_kernel import FusedSetKVBufferArg
+
+
+def enable_fused_set_kv_buffer(forward_batch: ForwardBatch):
+    """Enable fused set_kv_buffer only on CUDA with bfloat16 KV cache."""
+    return _is_cuda and forward_batch.token_to_kv_pool.dtype == torch.bfloat16
+
+
+def create_fused_set_kv_buffer_arg(
+    value: torch.Tensor,
+    layer: RadixAttention,
+    forward_batch: ForwardBatch,
+):
+    layer_id = layer.layer_id
+    token_to_kv_pool = forward_batch.token_to_kv_pool
+
+    k_buffer = token_to_kv_pool.get_key_buffer(layer_id)
+    v_buffer = token_to_kv_pool.get_value_buffer(layer_id)
+
+    return FusedSetKVBufferArg(
+        value=value,
+        k_buffer=k_buffer.view(k_buffer.shape[0], -1),
+        v_buffer=v_buffer.view(v_buffer.shape[0], -1),
+        k_scale=layer.k_scale,
+        v_scale=layer.v_scale,
+        cache_loc=forward_batch.out_cache_loc,
+    )
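
For orientation, the FusedSetKVBufferArg fields map directly onto the existing KV-cache layout: k_buffer/v_buffer are the layer's cache tensors flattened to (num_slots, heads * head_dim), cache_loc is forward_batch.out_cache_loc (the slot reserved per token in this batch), and k_scale/v_scale are passed through even though the bfloat16-only gate means they are normally unused on this path. A rough eager-mode reference for how the fused kernel consumes the argument is sketched below; rotate_fn stands in for the kernel's RoPE math, and the whole function is an assumed illustration, not the sgl_kernel implementation.

    def fused_rope_and_set_kv_reference(q, k, v, rotate_fn, arg):
        # rotate_fn is assumed to apply the model's rotary embedding to (q, k);
        # the real kernel does this and the cache write in a single launch.
        q, k = rotate_fn(q, k)
        # Scatter the rotated K and the untouched V into the per-layer buffers at
        # the slots reserved for this batch (arg.cache_loc == out_cache_loc).
        arg.k_buffer[arg.cache_loc] = k.reshape(k.shape[0], -1).to(arg.k_buffer.dtype)
        arg.v_buffer[arg.cache_loc] = v.reshape(v.shape[0], -1).to(arg.v_buffer.dtype)
        return q, k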