Fuse write kv buffer into rope for qwen3 moe & bailing moe (#10749)
Co-authored-by: luoyuan.luo <luoyuan.luo@antgroup.com>
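
In short: previously these models applied RoPE and then let the attention op write k/v into the KV cache as a separate step. With this change, when the fusion is enabled, the RoPE kernel also writes the rotated k and the raw v into the cache in the same launch, and the attention call skips its own cache write. A condensed sketch of the pattern (distilled from the hunks below; `self.rotary_emb` and `self.attn` are the model's rotary-embedding and RadixAttention modules, so this is not runnable on its own):

    # Unfused path: RoPE first, then the attention op stores k/v itself.
    q, k = self.rotary_emb(positions, q, k)
    out = self.attn(q, k, v, forward_batch)

    # Fused path: hand the RoPE kernel a FusedSetKVBufferArg so it writes
    # the rotated k and the raw v straight into the KV cache, then tell
    # the attention op not to store them again.
    q, k = self.rotary_emb(
        positions,
        q,
        k,
        fused_set_kv_buffer_arg=(
            create_fused_set_kv_buffer_arg(
                value=v, layer=self.attn, forward_batch=forward_batch
            )
            if enable_fused_set_kv_buffer(forward_batch)
            else None
        ),
    )
    out = self.attn(
        q, k, v, forward_batch,
        save_kv_cache=not enable_fused_set_kv_buffer(forward_batch),
    )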
python/sglang/srt/models/bailing_moe.py
@@ -72,6 +72,10 @@ from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
 from sglang.srt.model_loader.weight_utils import default_weight_loader
+from sglang.srt.models.utils import (
+    create_fused_set_kv_buffer_arg,
+    enable_fused_set_kv_buffer,
+)
 from sglang.srt.utils import add_prefix, is_cuda, is_non_idle_and_non_empty, make_layers
 
 LoraConfig = None
@@ -555,8 +559,27 @@ class BailingMoEAttention(nn.Module):
         q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
         if self.use_qk_norm:
             q, k = self._apply_qk_norm(q, k)
-        q, k = self.rotary_emb(positions, q, k)
-        context_layer = self.attn(q, k, v, forward_batch)
+        q, k = self.rotary_emb(
+            positions,
+            q,
+            k,
+            fused_set_kv_buffer_arg=(
+                create_fused_set_kv_buffer_arg(
+                    value=v,
+                    layer=self.attn,
+                    forward_batch=forward_batch,
+                )
+                if enable_fused_set_kv_buffer(forward_batch)
+                else None
+            ),
+        )
+        context_layer = self.attn(
+            q,
+            k,
+            v,
+            forward_batch,
+            save_kv_cache=not enable_fused_set_kv_buffer(forward_batch),
+        )
         attn_output, _ = self.dense(context_layer)
         return attn_output

python/sglang/srt/models/gpt_oss.py
@@ -66,6 +66,10 @@ from sglang.srt.layers.vocab_parallel_embedding import (
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
 from sglang.srt.model_loader.weight_utils import default_weight_loader
+from sglang.srt.models.utils import (
+    create_fused_set_kv_buffer_arg,
+    enable_fused_set_kv_buffer,
+)
 from sglang.srt.utils import (
     LazyValue,
     add_prefix,
@@ -193,33 +197,6 @@ class GptOssSparseMoeBlock(nn.Module):
         return ans
 
 
-def _enable_fused_set_kv_buffer(forward_batch: ForwardBatch):
-    """Enable fused set_kv_buffer only on CUDA with bfloat16 KV cache."""
-    return _is_cuda and forward_batch.token_to_kv_pool.dtype == torch.bfloat16
-
-
-# TODO maybe move to a model-common utils
-def _create_fused_set_kv_buffer_arg(
-    value: torch.Tensor,
-    layer: RadixAttention,
-    forward_batch: ForwardBatch,
-):
-    layer_id = layer.layer_id
-    token_to_kv_pool = forward_batch.token_to_kv_pool
-
-    k_buffer = token_to_kv_pool.get_key_buffer(layer_id)
-    v_buffer = token_to_kv_pool.get_value_buffer(layer_id)
-
-    return FusedSetKVBufferArg(
-        value=value,
-        k_buffer=k_buffer.view(k_buffer.shape[0], -1),
-        v_buffer=v_buffer.view(v_buffer.shape[0], -1),
-        k_scale=layer.k_scale,
-        v_scale=layer.v_scale,
-        cache_loc=forward_batch.out_cache_loc,
-    )
-
-
 class GptOssAttention(nn.Module):
     def __init__(
         self,
@@ -337,12 +314,12 @@ class GptOssAttention(nn.Module):
             q,
             k,
             fused_set_kv_buffer_arg=(
-                _create_fused_set_kv_buffer_arg(
+                create_fused_set_kv_buffer_arg(
                     value=v,
                     layer=self.attn,
                     forward_batch=forward_batch,
                 )
-                if _enable_fused_set_kv_buffer(forward_batch)
+                if enable_fused_set_kv_buffer(forward_batch)
                 else None
             ),
         )
@@ -356,7 +333,7 @@ class GptOssAttention(nn.Module):
         attn_output = self.attn(
             *inner_state,
             sinks=self.sinks,
-            save_kv_cache=not _enable_fused_set_kv_buffer(forward_batch),
+            save_kv_cache=not enable_fused_set_kv_buffer(forward_batch),
         )
         output, _ = self.o_proj(attn_output)
         return output

python/sglang/srt/models/qwen3_moe.py
@@ -60,6 +60,10 @@ from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
 from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.models.qwen2_moe import Qwen2MoeMLP as Qwen3MoeMLP
 from sglang.srt.models.qwen2_moe import Qwen2MoeModel
+from sglang.srt.models.utils import (
+    create_fused_set_kv_buffer_arg,
+    enable_fused_set_kv_buffer,
+)
 from sglang.srt.utils import (
     add_prefix,
     is_cuda,
@@ -412,7 +416,20 @@ class Qwen3MoeAttention(nn.Module):
         qkv, _ = self.qkv_proj(hidden_states)
         q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
         q, k = self._apply_qk_norm(q, k)
-        q, k = self.rotary_emb(positions, q, k)
+        q, k = self.rotary_emb(
+            positions,
+            q,
+            k,
+            fused_set_kv_buffer_arg=(
+                create_fused_set_kv_buffer_arg(
+                    value=v,
+                    layer=self.attn,
+                    forward_batch=forward_batch,
+                )
+                if enable_fused_set_kv_buffer(forward_batch)
+                else None
+            ),
+        )
         inner_state = q, k, v, forward_batch
         return None, forward_batch, inner_state
@@ -420,7 +437,10 @@ class Qwen3MoeAttention(nn.Module):
         hidden_states, forward_batch, inner_state = intermediate_state
         if inner_state is None:
             return hidden_states
-        attn_output = self.attn(*inner_state)
+        attn_output = self.attn(
+            *inner_state,
+            save_kv_cache=not enable_fused_set_kv_buffer(forward_batch),
+        )
         output, _ = self.o_proj(attn_output)
         return output

python/sglang/srt/models/utils.py (new file, 51 lines)
@@ -0,0 +1,51 @@
+# Copyright 2023-2025 SGLang Team
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+import torch
+
+from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.utils import is_cuda
+
+_is_cuda = is_cuda()
+
+
+if _is_cuda:
+    from sgl_kernel import FusedSetKVBufferArg
+
+
+def enable_fused_set_kv_buffer(forward_batch: ForwardBatch):
+    """Enable fused set_kv_buffer only on CUDA with bfloat16 KV cache."""
+    return _is_cuda and forward_batch.token_to_kv_pool.dtype == torch.bfloat16
+
+
+def create_fused_set_kv_buffer_arg(
+    value: torch.Tensor,
+    layer: RadixAttention,
+    forward_batch: ForwardBatch,
+):
+    layer_id = layer.layer_id
+    token_to_kv_pool = forward_batch.token_to_kv_pool
+
+    k_buffer = token_to_kv_pool.get_key_buffer(layer_id)
+    v_buffer = token_to_kv_pool.get_value_buffer(layer_id)
+
+    return FusedSetKVBufferArg(
+        value=value,
+        k_buffer=k_buffer.view(k_buffer.shape[0], -1),
+        v_buffer=v_buffer.view(v_buffer.shape[0], -1),
+        k_scale=layer.k_scale,
+        v_scale=layer.v_scale,
+        cache_loc=forward_batch.out_cache_loc,
+    )
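
A note on the gating in this module: `FusedSetKVBufferArg` is imported from `sgl_kernel` only when running on CUDA, and `enable_fused_set_kv_buffer` additionally requires a bfloat16 KV cache, so on any other platform or KV dtype the call sites above pass `fused_set_kv_buffer_arg=None` and keep the unfused path with `save_kv_cache=True`.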