From a5095d62623f661fc1f9b72a1f0bb59f0328ab90 Mon Sep 17 00:00:00 2001
From: Yuan Luo
Date: Fri, 26 Sep 2025 15:18:41 +0800
Subject: [PATCH] Fuse write kv buffer into rope for qwen3 moe & bailing moe
 (#10749)

Co-authored-by: luoyuan.luo
---
 python/sglang/srt/models/bailing_moe.py | 27 ++++++++++++-
 python/sglang/srt/models/gpt_oss.py     | 37 ++++--------------
 python/sglang/srt/models/qwen3_moe.py   | 24 +++++++++++-
 python/sglang/srt/models/utils.py       | 51 +++++++++++++++++++++++++
 4 files changed, 105 insertions(+), 34 deletions(-)
 create mode 100644 python/sglang/srt/models/utils.py

diff --git a/python/sglang/srt/models/bailing_moe.py b/python/sglang/srt/models/bailing_moe.py
index 2d1929ead..b6063aa2c 100644
--- a/python/sglang/srt/models/bailing_moe.py
+++ b/python/sglang/srt/models/bailing_moe.py
@@ -72,6 +72,10 @@ from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
 from sglang.srt.model_loader.weight_utils import default_weight_loader
+from sglang.srt.models.utils import (
+    create_fused_set_kv_buffer_arg,
+    enable_fused_set_kv_buffer,
+)
 from sglang.srt.utils import add_prefix, is_cuda, is_non_idle_and_non_empty, make_layers
 
 LoraConfig = None
@@ -555,8 +559,27 @@ class BailingMoEAttention(nn.Module):
         q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
         if self.use_qk_norm:
             q, k = self._apply_qk_norm(q, k)
-        q, k = self.rotary_emb(positions, q, k)
-        context_layer = self.attn(q, k, v, forward_batch)
+        q, k = self.rotary_emb(
+            positions,
+            q,
+            k,
+            fused_set_kv_buffer_arg=(
+                create_fused_set_kv_buffer_arg(
+                    value=v,
+                    layer=self.attn,
+                    forward_batch=forward_batch,
+                )
+                if enable_fused_set_kv_buffer(forward_batch)
+                else None
+            ),
+        )
+        context_layer = self.attn(
+            q,
+            k,
+            v,
+            forward_batch,
+            save_kv_cache=not enable_fused_set_kv_buffer(forward_batch),
+        )
         attn_output, _ = self.dense(context_layer)
         return attn_output
 
diff --git a/python/sglang/srt/models/gpt_oss.py b/python/sglang/srt/models/gpt_oss.py
index 7231a5d75..982400514 100644
--- a/python/sglang/srt/models/gpt_oss.py
+++ b/python/sglang/srt/models/gpt_oss.py
@@ -66,6 +66,10 @@ from sglang.srt.layers.vocab_parallel_embedding import (
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
 from sglang.srt.model_loader.weight_utils import default_weight_loader
+from sglang.srt.models.utils import (
+    create_fused_set_kv_buffer_arg,
+    enable_fused_set_kv_buffer,
+)
 from sglang.srt.utils import (
     LazyValue,
     add_prefix,
@@ -193,33 +197,6 @@ class GptOssSparseMoeBlock(nn.Module):
         return ans
 
 
-def _enable_fused_set_kv_buffer(forward_batch: ForwardBatch):
-    """Enable fused set_kv_buffer only on CUDA with bfloat16 KV cache."""
-    return _is_cuda and forward_batch.token_to_kv_pool.dtype == torch.bfloat16
-
-
-# TODO maybe move to a model-common utils
-def _create_fused_set_kv_buffer_arg(
-    value: torch.Tensor,
-    layer: RadixAttention,
-    forward_batch: ForwardBatch,
-):
-    layer_id = layer.layer_id
-    token_to_kv_pool = forward_batch.token_to_kv_pool
-
-    k_buffer = token_to_kv_pool.get_key_buffer(layer_id)
-    v_buffer = token_to_kv_pool.get_value_buffer(layer_id)
-
-    return FusedSetKVBufferArg(
-        value=value,
-        k_buffer=k_buffer.view(k_buffer.shape[0], -1),
-        v_buffer=v_buffer.view(v_buffer.shape[0], -1),
-        k_scale=layer.k_scale,
-        v_scale=layer.v_scale,
-        cache_loc=forward_batch.out_cache_loc,
-    )
-
-
 class GptOssAttention(nn.Module):
     def __init__(
         self,
@@ -337,12 +314,12 @@ class GptOssAttention(nn.Module):
             q,
             k,
             fused_set_kv_buffer_arg=(
-                _create_fused_set_kv_buffer_arg(
+                create_fused_set_kv_buffer_arg(
                     value=v,
                     layer=self.attn,
                     forward_batch=forward_batch,
                 )
-                if _enable_fused_set_kv_buffer(forward_batch)
+                if enable_fused_set_kv_buffer(forward_batch)
                 else None
             ),
         )
@@ -356,7 +333,7 @@ class GptOssAttention(nn.Module):
         attn_output = self.attn(
             *inner_state,
             sinks=self.sinks,
-            save_kv_cache=not _enable_fused_set_kv_buffer(forward_batch),
+            save_kv_cache=not enable_fused_set_kv_buffer(forward_batch),
         )
         output, _ = self.o_proj(attn_output)
         return output
diff --git a/python/sglang/srt/models/qwen3_moe.py b/python/sglang/srt/models/qwen3_moe.py
index 9d92a3b13..d9ac4684e 100644
--- a/python/sglang/srt/models/qwen3_moe.py
+++ b/python/sglang/srt/models/qwen3_moe.py
@@ -60,6 +60,10 @@ from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTe
 from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.models.qwen2_moe import Qwen2MoeMLP as Qwen3MoeMLP
 from sglang.srt.models.qwen2_moe import Qwen2MoeModel
+from sglang.srt.models.utils import (
+    create_fused_set_kv_buffer_arg,
+    enable_fused_set_kv_buffer,
+)
 from sglang.srt.utils import (
     add_prefix,
     is_cuda,
@@ -412,7 +416,20 @@ class Qwen3MoeAttention(nn.Module):
         qkv, _ = self.qkv_proj(hidden_states)
         q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
         q, k = self._apply_qk_norm(q, k)
-        q, k = self.rotary_emb(positions, q, k)
+        q, k = self.rotary_emb(
+            positions,
+            q,
+            k,
+            fused_set_kv_buffer_arg=(
+                create_fused_set_kv_buffer_arg(
+                    value=v,
+                    layer=self.attn,
+                    forward_batch=forward_batch,
+                )
+                if enable_fused_set_kv_buffer(forward_batch)
+                else None
+            ),
+        )
         inner_state = q, k, v, forward_batch
         return None, forward_batch, inner_state
 
@@ -420,7 +437,10 @@ class Qwen3MoeAttention(nn.Module):
         hidden_states, forward_batch, inner_state = intermediate_state
         if inner_state is None:
             return hidden_states
-        attn_output = self.attn(*inner_state)
+        attn_output = self.attn(
+            *inner_state,
+            save_kv_cache=not enable_fused_set_kv_buffer(forward_batch),
+        )
         output, _ = self.o_proj(attn_output)
         return output
 
diff --git a/python/sglang/srt/models/utils.py b/python/sglang/srt/models/utils.py
new file mode 100644
index 000000000..f4c2a0e3e
--- /dev/null
+++ b/python/sglang/srt/models/utils.py
@@ -0,0 +1,51 @@
+# Copyright 2023-2025 SGLang Team
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+import torch
+
+from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.utils import is_cuda
+
+_is_cuda = is_cuda()
+
+
+if _is_cuda:
+    from sgl_kernel import FusedSetKVBufferArg
+
+
+def enable_fused_set_kv_buffer(forward_batch: ForwardBatch):
+    """Enable fused set_kv_buffer only on CUDA with bfloat16 KV cache."""
+    return _is_cuda and forward_batch.token_to_kv_pool.dtype == torch.bfloat16
+
+
+def create_fused_set_kv_buffer_arg(
+    value: torch.Tensor,
+    layer: RadixAttention,
+    forward_batch: ForwardBatch,
+):
+    layer_id = layer.layer_id
+    token_to_kv_pool = forward_batch.token_to_kv_pool
+
+    k_buffer = token_to_kv_pool.get_key_buffer(layer_id)
+    v_buffer = token_to_kv_pool.get_value_buffer(layer_id)
+
+    return FusedSetKVBufferArg(
+        value=value,
+        k_buffer=k_buffer.view(k_buffer.shape[0], -1),
+        v_buffer=v_buffer.view(v_buffer.shape[0], -1),
+        k_scale=layer.k_scale,
+        v_scale=layer.v_scale,
+        cache_loc=forward_batch.out_cache_loc,
+    )
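
For reference, the three call sites above all follow the same two-step
contract; the sketch below collects it in one place. This is a minimal
illustration assuming the helpers from python/sglang/srt/models/utils.py
added by this patch. The FusedRopeAttention class and its attribute setup
are hypothetical, not code from the patch; self.qkv_proj, self.rotary_emb,
self.attn (a RadixAttention), self.o_proj, and forward_batch (a ForwardBatch)
mirror the attention modules in the diff.

    import torch

    from sglang.srt.models.utils import (
        create_fused_set_kv_buffer_arg,
        enable_fused_set_kv_buffer,
    )

    class FusedRopeAttention(torch.nn.Module):
        # Hypothetical module: qkv_proj, rotary_emb, attn (RadixAttention),
        # o_proj, and the q/kv split sizes are assumed to be initialized as
        # in the attention classes touched by this patch.

        def forward(self, positions, hidden_states, forward_batch):
            qkv, _ = self.qkv_proj(hidden_states)
            q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)

            # Fusion is gated to CUDA with a bfloat16 KV cache.
            fused = enable_fused_set_kv_buffer(forward_batch)

            # When fused, the rope kernel also writes the rotated k and the
            # raw v into this layer's KV buffers; FusedSetKVBufferArg carries
            # the buffer views, k/v scales, and cache locations it needs.
            q, k = self.rotary_emb(
                positions,
                q,
                k,
                fused_set_kv_buffer_arg=(
                    create_fused_set_kv_buffer_arg(
                        value=v, layer=self.attn, forward_batch=forward_batch
                    )
                    if fused
                    else None
                ),
            )

            # The attention call must then skip its own cache write, or the
            # same tokens would be stored twice.
            attn_output = self.attn(q, k, v, forward_batch, save_kv_cache=not fused)
            output, _ = self.o_proj(attn_output)
            return output

The invariant to keep when porting this to another model: a non-None
fused_set_kv_buffer_arg must always pair with save_kv_cache=False. Branching
both sites on the same enable_fused_set_kv_buffer(forward_batch) predicate,
as every model in this patch does, keeps the two in lockstep.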