Fuse writing KV buffer into rope kernel (part 2: srt) (#9014)
Co-authored-by: fzyzcjy <5236035+fzyzcjy@users.noreply.github.com>
This commit is contained in:
2
.github/workflows/pr-test-pd-router.yml
vendored
2
.github/workflows/pr-test-pd-router.yml
vendored
@@ -119,7 +119,7 @@ jobs:
|
||||
python3 -m pip --no-cache-dir install -e "python[all]" --break-system-packages
|
||||
python3 -m pip --no-cache-dir install mooncake-transfer-engine==0.3.5
|
||||
python3 -m pip --no-cache-dir install --user --force-reinstall genai-bench==0.0.1
|
||||
python3 -m pip --no-cache-dir install sgl-kernel==0.3.3
|
||||
python3 -m pip --no-cache-dir install sgl-kernel==0.3.4
|
||||
|
||||
- name: Build and install sgl-router
|
||||
run: |
|
||||
|
||||
@@ -58,7 +58,7 @@ runtime_common = [
|
||||
|
||||
srt = [
|
||||
"sglang[runtime_common]",
|
||||
"sgl-kernel==0.3.3",
|
||||
"sgl-kernel==0.3.4",
|
||||
"torch==2.8.0",
|
||||
"torchaudio==2.8.0",
|
||||
"torchvision",
|
||||
|
||||
@@ -655,7 +655,7 @@ def _set_envs_and_config(server_args: ServerArgs):
|
||||
if _is_cuda and not get_bool_env_var("SGLANG_SKIP_SGL_KERNEL_VERSION_CHECK"):
|
||||
assert_pkg_version(
|
||||
"sgl-kernel",
|
||||
"0.3.3",
|
||||
"0.3.4",
|
||||
"Please reinstall the latest version with `pip install sgl-kernel --force-reinstall`",
|
||||
)
|
||||
|
||||
|
||||
@@ -222,6 +222,7 @@ class RotaryEmbedding(CustomOp):
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
offsets: Optional[torch.Tensor] = None,
|
||||
fused_set_kv_buffer_arg=None, # Optional[FusedSetKVBufferArg]
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
if _is_cuda and (self.head_size in [64, 128, 256, 512]):
|
||||
apply_rope_with_cos_sin_cache_inplace(
|
||||
@@ -231,8 +232,17 @@ class RotaryEmbedding(CustomOp):
|
||||
head_size=self.head_size,
|
||||
cos_sin_cache=self.cos_sin_cache,
|
||||
is_neox=self.is_neox_style,
|
||||
# Compatible with old sgl-kernel
|
||||
**(
|
||||
dict(fused_set_kv_buffer_arg=fused_set_kv_buffer_arg)
|
||||
if fused_set_kv_buffer_arg is not None
|
||||
else {}
|
||||
),
|
||||
)
|
||||
else:
|
||||
assert (
|
||||
fused_set_kv_buffer_arg is None
|
||||
), "save kv cache is not supported for vllm_rotary_embedding."
|
||||
self.cos_sin_cache = self.cos_sin_cache.to(query.device, dtype=query.dtype)
|
||||
self.vllm_rotary_embedding(
|
||||
positions,
|
||||
|
||||
@@ -66,10 +66,15 @@ from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTe
|
||||
from sglang.srt.model_loader.weight_utils import default_weight_loader
|
||||
from sglang.srt.utils import add_prefix, is_cuda, is_flashinfer_available, make_layers
|
||||
|
||||
_is_cuda = is_cuda()
|
||||
_is_flashinfer_available = is_flashinfer_available()
|
||||
_is_sm100_supported = is_cuda() and is_sm100_supported()
|
||||
|
||||
|
||||
if _is_cuda:
|
||||
from sgl_kernel import FusedSetKVBufferArg
|
||||
|
||||
|
||||
class GptOssConfig(PretrainedConfig):
|
||||
model_type = "gpt_oss"
|
||||
|
||||
@@ -196,6 +201,32 @@ class GptOssSparseMoeBlock(nn.Module):
|
||||
return ans
|
||||
|
||||
|
||||
def _enable_fused_set_kv_buffer():
    """Return True when the fused rope + KV-cache-write path should be used.

    The fused kernel is provided by ``sgl_kernel``, which this module only
    imports on CUDA builds, so this simply mirrors the module-level
    ``_is_cuda`` flag.
    """
    return _is_cuda
|
||||
|
||||
|
||||
# TODO maybe move to a model-common utils
def _create_fused_set_kv_buffer_arg(
    value: torch.Tensor,
    layer: RadixAttention,
    forward_batch: ForwardBatch,
):
    """Build the ``FusedSetKVBufferArg`` for the fused rope kernel.

    Fetches this layer's key/value cache buffers from the forward batch's
    KV pool and flattens each to 2-D (first dim preserved, the rest
    collapsed) so the fused kernel can write K/V directly at
    ``forward_batch.out_cache_loc`` instead of a separate save step.

    Args:
        value: The freshly computed V tensor to store alongside K.
        layer: The attention layer whose cache (and k/v scales) is targeted.
        forward_batch: Batch metadata carrying the KV pool and cache slots.
    """
    pool = forward_batch.token_to_kv_pool
    key_cache = pool.get_key_buffer(layer.layer_id)
    value_cache = pool.get_value_buffer(layer.layer_id)

    return FusedSetKVBufferArg(
        value=value,
        k_buffer=key_cache.view(key_cache.shape[0], -1),
        v_buffer=value_cache.view(value_cache.shape[0], -1),
        k_scale=layer.k_scale,
        v_scale=layer.v_scale,
        cache_loc=forward_batch.out_cache_loc,
    )
|
||||
|
||||
|
||||
class GptOssAttention(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
@@ -303,7 +334,21 @@ class GptOssAttention(nn.Module):
|
||||
return hidden_states, forward_batch, None
|
||||
qkv, _ = self.qkv_proj(hidden_states)
|
||||
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
|
||||
q, k = self.rotary_emb(positions, q, k)
|
||||
|
||||
q, k = self.rotary_emb(
|
||||
positions,
|
||||
q,
|
||||
k,
|
||||
fused_set_kv_buffer_arg=(
|
||||
_create_fused_set_kv_buffer_arg(
|
||||
value=v,
|
||||
layer=self.attn,
|
||||
forward_batch=forward_batch,
|
||||
)
|
||||
if _enable_fused_set_kv_buffer()
|
||||
else None
|
||||
),
|
||||
)
|
||||
inner_state = q, k, v, forward_batch
|
||||
return None, forward_batch, inner_state
|
||||
|
||||
@@ -311,7 +356,11 @@ class GptOssAttention(nn.Module):
|
||||
hidden_states, forward_batch, inner_state = intermediate_state
|
||||
if inner_state is None:
|
||||
return hidden_states
|
||||
attn_output = self.attn(*inner_state, sinks=self.sinks)
|
||||
attn_output = self.attn(
|
||||
*inner_state,
|
||||
sinks=self.sinks,
|
||||
save_kv_cache=not _enable_fused_set_kv_buffer(),
|
||||
)
|
||||
output, _ = self.o_proj(attn_output)
|
||||
return output
|
||||
|
||||
|
||||
Reference in New Issue
Block a user