Piecewise CUDA Graph Support & Torch Compile Backend (#10062)

commit 4ac8e09df0 (parent 20a6c0a63d)
Author: Yuwei An, committed by GitHub
Date: 2025-10-11 20:55:57 -07:00
Signed-off-by: Oasis-Git <ayw.sirius19@gmail.com>

21 changed files with 2706 additions and 19 deletions


@@ -43,11 +43,16 @@ _is_cpu = is_cpu()
 _is_xpu = is_xpu()
 
 if _is_cuda:
-    if _is_flashinfer_available:
-        from flashinfer.norm import fused_add_rmsnorm
-    else:
-        from sgl_kernel import fused_add_rmsnorm
-    from sgl_kernel import gemma_fused_add_rmsnorm, gemma_rmsnorm, rmsnorm
+    # if _is_flashinfer_available:
+    #     from flashinfer.norm import fused_add_rmsnorm
+    # else:
+    from sgl_kernel import (
+        fused_add_rmsnorm,
+        gemma_fused_add_rmsnorm,
+        gemma_rmsnorm,
+        rmsnorm,
+    )
+
 
 if _use_aiter:
     from aiter import rmsnorm2d_fwd as rms_norm

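The sgl_kernel fused_add_rmsnorm this hunk standardizes on was previously the fallback for the flashinfer op of the same name, so call sites need no changes. A minimal usage sketch (shapes and the eps value below are illustrative assumptions, not taken from this commit):

import torch
from sgl_kernel import fused_add_rmsnorm

# Illustrative shapes; the op fuses the residual add with RMSNorm and
# mutates both tensors in place rather than returning a new tensor.
hidden = torch.randn(8, 4096, device="cuda", dtype=torch.float16)
residual = torch.randn_like(hidden)
weight = torch.ones(4096, device="cuda", dtype=torch.float16)

fused_add_rmsnorm(hidden, residual, weight, 1e-6)  # eps is an assumed value
# hidden   -> rmsnorm(hidden + residual) * weight
# residual -> hidden + residual (the pre-norm sum, kept for the next layer)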

@@ -17,12 +17,18 @@ from __future__ import annotations
 from enum import Enum
 from typing import TYPE_CHECKING, Optional
 
 import torch
 from torch import nn
 
 if TYPE_CHECKING:
     from sglang.srt.layers.quantization.base_config import QuantizationConfig
     from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 
+from sglang.srt.model_executor.compilation.piecewise_context_manager import (
+    get_forward_context,
+)
+from sglang.srt.utils import direct_register_custom_op
+
+
 class AttentionType(Enum):
     """
@@ -105,12 +111,58 @@ class RadixAttention(nn.Module):
         else:
             k = k.view(-1, self.tp_k_head_num, self.v_head_dim)
 
-        return forward_batch.attn_backend.forward(
-            q,
-            k,
-            v,
-            self,
-            forward_batch,
-            save_kv_cache,
-            **kwargs,
-        )
+        if forward_batch.forward_mode.is_extend() and get_forward_context() is not None:
+            output = torch.zeros_like(q)
+            torch.ops.sglang.unified_attention_with_output(
+                q, k, v, output, save_kv_cache, self.layer_id
+            )
+            return output
+        else:
+            return forward_batch.attn_backend.forward(
+                q,
+                k,
+                v,
+                self,
+                forward_batch,
+                save_kv_cache,
+                **kwargs,
+            )
+
+
+def unified_attention_with_output(
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    output: torch.Tensor,
+    save_kv_cache: bool,
+    layer_id: int,
+) -> None:
+    context = get_forward_context()
+    forward_batch = context.forward_batch
+    attention_layers = context.attention_layers
+    attention_layer = attention_layers[layer_id]
+    ret = forward_batch.attn_backend.forward(
+        query, key, value, attention_layer, forward_batch, save_kv_cache
+    )
+    assert output.shape == ret.shape
+    output.copy_(ret)
+    return
+
+
+def unified_attention_with_output_fake(
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    output: torch.Tensor,
+    save_kv_cache: bool,
+    layer_id: int,
+) -> None:
+    return
+
+
+direct_register_custom_op(
+    op_name="unified_attention_with_output",
+    op_func=unified_attention_with_output,
+    mutates_args=["output"],
+    fake_impl=unified_attention_with_output_fake,
+)
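direct_register_custom_op itself is not shown in this diff (it lives in sglang.srt.utils). A rough sketch of what such a helper does, assuming it follows the common torch.library pattern for this idiom — the body below is an assumption; only the helper name, the "sglang" namespace, and the call signature come from the hunk above:

import torch
from torch.library import Library

# Assumed sketch: register ops into the "sglang" namespace so they are
# reachable as torch.ops.sglang.<op_name>, matching the call site above.
sglang_lib = Library("sglang", "FRAGMENT")

def direct_register_custom_op(op_name, op_func, mutates_args, fake_impl=None):
    # Infer the op schema from the Python signature; args in mutates_args
    # are marked as mutated so torch.compile neither reorders nor
    # dead-code-eliminates the call.
    schema = torch.library.infer_schema(op_func, mutates_args=mutates_args)
    sglang_lib.define(op_name + schema)
    sglang_lib.impl(op_name, op_func, "CUDA")
    if fake_impl is not None:
        # The fake (meta) impl gives torch.compile output shapes without
        # running the kernel; it can be a no-op here because the real op
        # writes into the caller-provided `output` buffer.
        sglang_lib._register_fake(op_name, fake_impl)

Once registered, the op is invoked as torch.ops.sglang.unified_attention_with_output(...), exactly as RadixAttention.forward does above. Wrapping attention in a single opaque, mutating custom op is what lets the torch.compile backend split the model graph at attention boundaries and capture the surrounding pieces as CUDA graphs, which is the piecewise scheme this commit introduces.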