From 82653f66228efa1973205feb13fc6ecfc7a7ad7a Mon Sep 17 00:00:00 2001
From: DefTruth <31974251+DefTruth@users.noreply.github.com>
Date: Tue, 6 May 2025 01:32:33 +0800
Subject: [PATCH] feat: Add a unified merge_state API (#5428)

---
 .../srt/layers/attention/merge_state.py      | 46 +++++++++
 .../attention/triton_ops/merge_state.py      | 96 +++++++++++++++++++
 2 files changed, 142 insertions(+)
 create mode 100644 python/sglang/srt/layers/attention/merge_state.py
 create mode 100644 python/sglang/srt/layers/attention/triton_ops/merge_state.py

diff --git a/python/sglang/srt/layers/attention/merge_state.py b/python/sglang/srt/layers/attention/merge_state.py
new file mode 100644
index 000000000..245418a9d
--- /dev/null
+++ b/python/sglang/srt/layers/attention/merge_state.py
@@ -0,0 +1,46 @@
+from typing import Optional, Tuple
+
+import torch
+from sgl_kernel import merge_state_v2
+
+from sglang.srt.layers.attention.triton_ops.merge_state import merge_state_triton
+from sglang.srt.utils import is_cuda
+
+_is_cuda = is_cuda()
+
+
+# Automatically fall back to the Triton kernel in some cases
+# (e.g., for AMD GPUs, when the head dimension is not a multiple
+# of 4 or 8, and in FP8 precision)
+def _supported_dtypes(o: torch.Tensor) -> bool:
+    return o.dtype in [torch.float32, torch.half, torch.bfloat16]
+
+
+def _supported_headdim(o: torch.Tensor) -> bool:
+    headdim = o.shape[2]  # [NUM_TOKENS, NUM_HEADS, HEAD_SIZE]
+    if o.dtype == torch.float32:
+        return headdim % 4 == 0
+    return headdim % 8 == 0
+
+
+def merge_state(
+    prefix_output: torch.Tensor,
+    prefix_lse: torch.Tensor,
+    suffix_output: torch.Tensor,
+    suffix_lse: torch.Tensor,
+    output: Optional[torch.Tensor] = None,
+    output_lse: Optional[torch.Tensor] = None,
+) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
+    if (
+        _is_cuda
+        and _supported_dtypes(prefix_output)
+        and _supported_headdim(prefix_output)
+    ):
+        return merge_state_v2(
+            prefix_output, prefix_lse, suffix_output, suffix_lse, output, output_lse
+        )
+    else:
+        # Fall back to the Triton kernel
+        return merge_state_triton(
+            prefix_output, prefix_lse, suffix_output, suffix_lse, output, output_lse
+        )
diff --git a/python/sglang/srt/layers/attention/triton_ops/merge_state.py b/python/sglang/srt/layers/attention/triton_ops/merge_state.py
new file mode 100644
index 000000000..955097fc9
--- /dev/null
+++ b/python/sglang/srt/layers/attention/triton_ops/merge_state.py
@@ -0,0 +1,96 @@
+from typing import Optional, Tuple
+
+import torch
+import triton
+import triton.language as tl
+
+
+@triton.jit
+def merge_state_kernel(
+    output,  # [NUM_TOKENS, NUM_HEADS, HEAD_SIZE] v_merged
+    output_lse,  # [NUM_TOKENS, NUM_HEADS] s_merged
+    prefix_output,  # [NUM_TOKENS, NUM_HEADS, HEAD_SIZE] v_a
+    prefix_lse,  # [NUM_TOKENS, NUM_HEADS] s_a
+    suffix_output,  # [NUM_TOKENS, NUM_HEADS, HEAD_SIZE] v_b
+    suffix_lse,  # [NUM_TOKENS, NUM_HEADS] s_b
+    HEAD_SIZE: tl.constexpr,
+    PADDED_HEAD_SIZE: tl.constexpr,
+    OUTPUT_LSE: tl.constexpr,
+):
+    token_idx = tl.program_id(0)
+    num_tokens = tl.num_programs(0)
+    head_idx = tl.program_id(1)
+    num_heads = tl.num_programs(1)
+
+    p_lse = tl.load(prefix_lse + token_idx * num_heads + head_idx)
+    s_lse = tl.load(suffix_lse + token_idx * num_heads + head_idx)
+    p_lse = float("-inf") if p_lse == float("inf") else p_lse
+    s_lse = float("-inf") if s_lse == float("inf") else s_lse
+
+    max_lse = tl.maximum(p_lse, s_lse)
+    p_lse = p_lse - max_lse
+    s_lse = s_lse - max_lse
+    out_se = tl.exp(p_lse) + tl.exp(s_lse)
+
+    if OUTPUT_LSE:
+        out_lse = tl.log(out_se) + max_lse
+        tl.store(output_lse + token_idx * num_heads + head_idx, out_lse)
+
+    head_arange = tl.arange(0, PADDED_HEAD_SIZE)
+    head_mask = head_arange < HEAD_SIZE
+    p_out = tl.load(
+        prefix_output
+        + token_idx * num_heads * HEAD_SIZE
+        + head_idx * HEAD_SIZE
+        + head_arange,
+        mask=head_mask,
+    )
+    s_out = tl.load(
+        suffix_output
+        + token_idx * num_heads * HEAD_SIZE
+        + head_idx * HEAD_SIZE
+        + head_arange,
+        mask=head_mask,
+    )
+
+    p_scale = tl.exp(p_lse) / out_se
+    s_scale = tl.exp(s_lse) / out_se
+    out = p_out * p_scale + s_out * s_scale
+    tl.store(
+        output + token_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE + head_arange,
+        out,
+        mask=head_mask,
+    )
+
+
+def merge_state_triton(
+    prefix_output: torch.Tensor,
+    prefix_lse: torch.Tensor,
+    suffix_output: torch.Tensor,
+    suffix_lse: torch.Tensor,
+    output: Optional[torch.Tensor] = None,
+    output_lse: Optional[torch.Tensor] = None,
+) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
+    # Avoid creating new tensors if they are already provided
+    if output is None:
+        output = torch.empty_like(prefix_output)
+    if output_lse is None:
+        output_lse = torch.empty_like(prefix_lse)
+
+    num_tokens = output.shape[0]
+    num_query_heads = output.shape[1]
+    head_size = output.shape[2]
+    padded_head_size = triton.next_power_of_2(head_size)
+
+    merge_state_kernel[(num_tokens, num_query_heads)](
+        output,
+        output_lse,
+        prefix_output,
+        prefix_lse,
+        suffix_output,
+        suffix_lse,
+        head_size,
+        padded_head_size,
+        output_lse is not None,
+    )
+    return output, output_lse
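
Notes (not part of the patch):

Both the CUDA and Triton kernels compute the standard log-sum-exp merge of
two partial attention results: given partial outputs v_a, v_b and their
per-(token, head) log-sum-exps s_a, s_b, the merged output is
(exp(s_a) * v_a + exp(s_b) * v_b) / (exp(s_a) + exp(s_b)) and the merged
LSE is log(exp(s_a) + exp(s_b)), with max(s_a, s_b) subtracted first for
numerical stability, exactly as merge_state_kernel does above. A minimal
PyTorch reference sketch, assuming the [NUM_TOKENS, NUM_HEADS, HEAD_SIZE]
layout from the kernel comments (the function name is illustrative, not
part of the API):

    import torch

    def merge_state_reference(v_a, s_a, v_b, s_b):
        # v_a, v_b: [NUM_TOKENS, NUM_HEADS, HEAD_SIZE]
        # s_a, s_b: [NUM_TOKENS, NUM_HEADS]
        max_lse = torch.maximum(s_a, s_b)
        w_a = torch.exp(s_a - max_lse)  # rescaled softmax denominators
        w_b = torch.exp(s_b - max_lse)
        denom = (w_a + w_b).unsqueeze(-1)
        out = (v_a * w_a.unsqueeze(-1) + v_b * w_b.unsqueeze(-1)) / denom
        lse = torch.log(w_a + w_b) + max_lse  # merged log-sum-exp
        return out, lse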
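
A hedged usage sketch of the unified API (shapes, dtypes, and the CUDA
device are illustrative, not prescribed by the patch; with a supported
dtype and head dimension, merge_state should dispatch to the sgl_kernel
merge_state_v2 fast path, and otherwise fall back to merge_state_triton):

    import torch

    from sglang.srt.layers.attention.merge_state import merge_state

    NUM_TOKENS, NUM_HEADS, HEAD_SIZE = 4, 8, 128
    prefix_output = torch.randn(
        NUM_TOKENS, NUM_HEADS, HEAD_SIZE, dtype=torch.half, device="cuda"
    )
    suffix_output = torch.randn_like(prefix_output)
    prefix_lse = torch.randn(NUM_TOKENS, NUM_HEADS, dtype=torch.float32, device="cuda")
    suffix_lse = torch.randn_like(prefix_lse)

    # Merge two partial attention results (e.g., from the prefix and
    # suffix of a split KV cache) into one output plus its merged LSE.
    output, output_lse = merge_state(
        prefix_output, prefix_lse, suffix_output, suffix_lse
    )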