Fuse writing KV buffer into rope kernel (part 2: srt) (#9014)
Co-authored-by: fzyzcjy <5236035+fzyzcjy@users.noreply.github.com>
.github/workflows/pr-test-pd-router.yml
@@ -119,7 +119,7 @@ jobs:
           python3 -m pip --no-cache-dir install -e "python[all]" --break-system-packages
           python3 -m pip --no-cache-dir install mooncake-transfer-engine==0.3.5
           python3 -m pip --no-cache-dir install --user --force-reinstall genai-bench==0.0.1
-          python3 -m pip --no-cache-dir install sgl-kernel==0.3.3
+          python3 -m pip --no-cache-dir install sgl-kernel==0.3.4
 
       - name: Build and install sgl-router
         run: |
@@ -58,7 +58,7 @@ runtime_common = [
 
 srt = [
     "sglang[runtime_common]",
-    "sgl-kernel==0.3.3",
+    "sgl-kernel==0.3.4",
     "torch==2.8.0",
     "torchaudio==2.8.0",
     "torchvision",
@@ -655,7 +655,7 @@ def _set_envs_and_config(server_args: ServerArgs):
     if _is_cuda and not get_bool_env_var("SGLANG_SKIP_SGL_KERNEL_VERSION_CHECK"):
         assert_pkg_version(
             "sgl-kernel",
-            "0.3.3",
+            "0.3.4",
             "Please reinstall the latest version with `pip install sgl-kernel --force-reinstall`",
         )
 
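The three version bumps above keep the CI job, the `srt` extra, and the runtime check in sync: the server refuses to start when the installed sgl-kernel is older than the pinned release, unless `SGLANG_SKIP_SGL_KERNEL_VERSION_CHECK` is set. Below is a minimal standalone sketch of such a gate using only `importlib.metadata`; it is not sglang's `assert_pkg_version` helper, whose exact behavior is not shown in this diff, and it assumes plain X.Y.Z version strings.

# Hedged sketch of a version gate equivalent in spirit to the check above.
from importlib.metadata import PackageNotFoundError, version


def check_sgl_kernel(required: str = "0.3.4") -> None:
    try:
        installed = version("sgl-kernel")
    except PackageNotFoundError:
        raise RuntimeError("sgl-kernel is not installed")
    if tuple(int(p) for p in installed.split(".")[:3]) < tuple(
        int(p) for p in required.split(".")
    ):
        raise RuntimeError(
            f"sgl-kernel {installed} < {required}; please reinstall with "
            "`pip install sgl-kernel --force-reinstall`"
        )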
@@ -222,6 +222,7 @@ class RotaryEmbedding(CustomOp):
         query: torch.Tensor,
         key: torch.Tensor,
         offsets: Optional[torch.Tensor] = None,
+        fused_set_kv_buffer_arg=None,  # Optional[FusedSetKVBufferArg]
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         if _is_cuda and (self.head_size in [64, 128, 256, 512]):
             apply_rope_with_cos_sin_cache_inplace(
@@ -231,8 +232,17 @@ class RotaryEmbedding(CustomOp):
                 head_size=self.head_size,
                 cos_sin_cache=self.cos_sin_cache,
                 is_neox=self.is_neox_style,
+                # Compatible with old sgl-kernel
+                **(
+                    dict(fused_set_kv_buffer_arg=fused_set_kv_buffer_arg)
+                    if fused_set_kv_buffer_arg is not None
+                    else {}
+                ),
             )
         else:
+            assert (
+                fused_set_kv_buffer_arg is None
+            ), "save kv cache is not supported for vllm_rotary_embedding."
             self.cos_sin_cache = self.cos_sin_cache.to(query.device, dtype=query.dtype)
             self.vllm_rotary_embedding(
                 positions,
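The `**( ... if ... else {} )` expansion above lets the same Python code run against both old and new sgl-kernel builds: the new keyword is only passed when a fused-KV-buffer request is actually present, so an older `apply_rope_with_cos_sin_cache_inplace` whose signature lacks `fused_set_kv_buffer_arg` still binds. A self-contained sketch of the pattern with stand-in functions follows; the names are illustrative, not sgl-kernel APIs.

# Stand-in for a rope kernel from an older sgl-kernel build that does not
# know about fused KV-buffer writes.
def legacy_rope(query, key, head_size, cos_sin_cache, is_neox):
    pass  # would rotate query/key in place


def call_rope(kernel, query, key, head_size, cos_sin_cache, is_neox,
              fused_set_kv_buffer_arg=None):
    kernel(
        query=query,
        key=key,
        head_size=head_size,
        cos_sin_cache=cos_sin_cache,
        is_neox=is_neox,
        # Forward the new kwarg only when it is set, so the call still binds
        # against the legacy signature above when fusion is not requested.
        **(
            dict(fused_set_kv_buffer_arg=fused_set_kv_buffer_arg)
            if fused_set_kv_buffer_arg is not None
            else {}
        ),
    )


call_rope(legacy_rope, query=None, key=None, head_size=64,
          cos_sin_cache=None, is_neox=True)  # works without the new kwarg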
@@ -66,10 +66,15 @@ from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
 from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.utils import add_prefix, is_cuda, is_flashinfer_available, make_layers
 
+_is_cuda = is_cuda()
 _is_flashinfer_available = is_flashinfer_available()
 _is_sm100_supported = is_cuda() and is_sm100_supported()
 
 
+if _is_cuda:
+    from sgl_kernel import FusedSetKVBufferArg
+
+
 class GptOssConfig(PretrainedConfig):
     model_type = "gpt_oss"
 
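`FusedSetKVBufferArg` ships only with the CUDA build of sgl-kernel, hence the `if _is_cuda:` guard around the import. A try/except variant would also degrade gracefully when the symbol is missing entirely; this is an alternative sketch, not what the PR does.

# Alternative guarded-import sketch (assumption, not the PR's approach):
# fall back to None when sgl_kernel is absent or too old to export the class.
try:
    from sgl_kernel import FusedSetKVBufferArg
except ImportError:
    FusedSetKVBufferArg = None  # callers must skip the fused path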
@@ -196,6 +201,32 @@ class GptOssSparseMoeBlock(nn.Module):
         return ans
 
 
+def _enable_fused_set_kv_buffer():
+    return _is_cuda
+
+
+# TODO maybe move to a model-common utils
+def _create_fused_set_kv_buffer_arg(
+    value: torch.Tensor,
+    layer: RadixAttention,
+    forward_batch: ForwardBatch,
+):
+    layer_id = layer.layer_id
+    token_to_kv_pool = forward_batch.token_to_kv_pool
+
+    k_buffer = token_to_kv_pool.get_key_buffer(layer_id)
+    v_buffer = token_to_kv_pool.get_value_buffer(layer_id)
+
+    return FusedSetKVBufferArg(
+        value=value,
+        k_buffer=k_buffer.view(k_buffer.shape[0], -1),
+        v_buffer=v_buffer.view(v_buffer.shape[0], -1),
+        k_scale=layer.k_scale,
+        v_scale=layer.v_scale,
+        cache_loc=forward_batch.out_cache_loc,
+    )
+
+
 class GptOssAttention(nn.Module):
     def __init__(
         self,
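`_create_fused_set_kv_buffer_arg` packages everything the rope kernel needs to store k/v directly into the layer's paged KV pool: the flattened per-layer key/value buffers, the optional FP8 scales, and the per-token slot indices in `forward_batch.out_cache_loc`. Without the fusion, that store is a separate scatter pass after rope. A plain-PyTorch sketch of that unfused equivalent is below; shapes and names are illustrative assumptions, not sglang internals.

import torch


def unfused_set_kv_buffer(
    k: torch.Tensor,          # [num_tokens, num_kv_heads * head_dim], post-rope keys
    v: torch.Tensor,          # [num_tokens, num_kv_heads * head_dim]
    k_buffer: torch.Tensor,   # [pool_size, num_kv_heads * head_dim], layer key pool
    v_buffer: torch.Tensor,   # [pool_size, num_kv_heads * head_dim], layer value pool
    cache_loc: torch.Tensor,  # [num_tokens], int64 slot index per token
) -> None:
    # The extra global-memory pass that the fused rope kernel folds into the
    # store of its rotated keys.
    k_buffer[cache_loc] = k
    v_buffer[cache_loc] = v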
@@ -303,7 +334,21 @@ class GptOssAttention(nn.Module):
             return hidden_states, forward_batch, None
         qkv, _ = self.qkv_proj(hidden_states)
         q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
-        q, k = self.rotary_emb(positions, q, k)
+        q, k = self.rotary_emb(
+            positions,
+            q,
+            k,
+            fused_set_kv_buffer_arg=(
+                _create_fused_set_kv_buffer_arg(
+                    value=v,
+                    layer=self.attn,
+                    forward_batch=forward_batch,
+                )
+                if _enable_fused_set_kv_buffer()
+                else None
+            ),
+        )
         inner_state = q, k, v, forward_batch
         return None, forward_batch, inner_state
 
@@ -311,7 +356,11 @@ class GptOssAttention(nn.Module):
         hidden_states, forward_batch, inner_state = intermediate_state
         if inner_state is None:
             return hidden_states
-        attn_output = self.attn(*inner_state, sinks=self.sinks)
+        attn_output = self.attn(
+            *inner_state,
+            sinks=self.sinks,
+            save_kv_cache=not _enable_fused_set_kv_buffer(),
+        )
         output, _ = self.o_proj(attn_output)
         return output
 
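The two edits in `GptOssAttention` encode a write-once invariant: when the fused path is enabled, the rope kernel has already written k/v into the pool, so the attention backend is told `save_kv_cache=False`; when it is disabled, the backend keeps doing the write itself. A tiny sketch of that toggle, using a hypothetical helper that mirrors the flags above rather than sglang code:

def plan_kv_write(fused_enabled: bool) -> dict:
    # Exactly one of the two sites writes the KV cache for a token.
    return {
        "rope_writes_kv": fused_enabled,     # via fused_set_kv_buffer_arg
        "attn_saves_kv": not fused_enabled,  # via save_kv_cache
    }


assert plan_kv_write(True) == {"rope_writes_kv": True, "attn_saves_kv": False}
assert plan_kv_write(False) == {"rope_writes_kv": False, "attn_saves_kv": True}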