feat: Add a unified merge_state API (#5428)
This commit is contained in:
46
python/sglang/srt/layers/attention/merge_state.py
Normal file
46
python/sglang/srt/layers/attention/merge_state.py
Normal file
@@ -0,0 +1,46 @@
|
|||||||
|
from typing import Optional, Tuple
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from sgl_kernel import merge_state_v2
|
||||||
|
|
||||||
|
from sglang.srt.layers.attention.triton_ops.merge_state import merge_state_triton
|
||||||
|
from sglang.srt.utils import is_cuda
|
||||||
|
|
||||||
|
_is_cuda = is_cuda()
|
||||||
|
|
||||||
|
|
||||||
|
# Automatically fallback to the Triton kernel in some cases
|
||||||
|
# (e.g., for AMD GPUs, when the head dimension is not a multiple
|
||||||
|
# of 4 or 8, and in FP8 precision)
|
||||||
|
def _supported_dtypes(o: torch.Tensor) -> bool:
|
||||||
|
return o.dtype in [torch.float32, torch.half, torch.bfloat16]
|
||||||
|
|
||||||
|
|
||||||
|
def _supported_headdim(o: torch.Tensor) -> bool:
|
||||||
|
headdim = o.shape[2] # [NUM_TOKENS, NUM_HEADS, HEAD_SIZE]
|
||||||
|
if o.dtype == torch.float32:
|
||||||
|
return headdim % 4 == 0
|
||||||
|
return headdim % 8 == 0
|
||||||
|
|
||||||
|
|
||||||
|
def merge_state(
    prefix_output: torch.Tensor,
    prefix_lse: torch.Tensor,
    suffix_output: torch.Tensor,
    suffix_lse: torch.Tensor,
    output: Optional[torch.Tensor] = None,
    output_lse: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
    """Merge two partial attention states into one.

    Dispatches to the CUDA ``merge_state_v2`` kernel when running on CUDA
    with a supported dtype and head dimension; otherwise falls back to the
    Triton implementation (e.g. on AMD GPUs, for FP8, or when the head
    dimension is not a multiple of 4/8).

    Args:
        prefix_output: [NUM_TOKENS, NUM_HEADS, HEAD_SIZE] partial outputs.
        prefix_lse: [NUM_TOKENS, NUM_HEADS] log-sum-exp of the prefix state.
        suffix_output: [NUM_TOKENS, NUM_HEADS, HEAD_SIZE] partial outputs.
        suffix_lse: [NUM_TOKENS, NUM_HEADS] log-sum-exp of the suffix state.
        output: optional pre-allocated buffer for the merged output.
        output_lse: optional pre-allocated buffer for the merged LSE.

    Returns:
        Tuple of (merged output, merged LSE).
    """
    can_use_cuda_kernel = (
        _is_cuda
        and _supported_dtypes(prefix_output)
        and _supported_headdim(prefix_output)
    )
    # Automatically fall back to the Triton kernel when the CUDA kernel
    # cannot handle this dtype/head-dim combination.
    merge_fn = merge_state_v2 if can_use_cuda_kernel else merge_state_triton
    return merge_fn(
        prefix_output, prefix_lse, suffix_output, suffix_lse, output, output_lse
    )
|
||||||
96
python/sglang/srt/layers/attention/triton_ops/merge_state.py
Normal file
96
python/sglang/srt/layers/attention/triton_ops/merge_state.py
Normal file
@@ -0,0 +1,96 @@
|
|||||||
|
from typing import Optional, Tuple
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import triton
|
||||||
|
import triton.language as tl
|
||||||
|
|
||||||
|
|
||||||
|
@triton.jit
def merge_state_kernel(
    output,  # [NUM_TOKENS, NUM_HEADS, HEAD_SIZE] v_merged
    output_lse,  # [NUM_TOKENS, NUM_HEADS] s_merged
    prefix_output,  # [NUM_TOKENS, NUM_HEADS, HEAD_SIZE] v_a
    prefix_lse,  # [NUM_TOKENS, NUM_HEADS] s_a
    suffix_output,  # [NUM_TOKENS, NUM_HEADS, HEAD_SIZE] v_b
    suffix_lse,  # [NUM_TOKENS, NUM_HEADS] s_b
    HEAD_SIZE: tl.constexpr,
    PADDED_HEAD_SIZE: tl.constexpr,  # power of two >= HEAD_SIZE, for tl.arange
    OUTPUT_LSE: tl.constexpr,  # also write the merged LSE when True
):
    """Merge two partial attention states for one (token, head) pair.

    Each program instance handles a single (token, head) pair: it combines
    the prefix and suffix partial outputs with softmax-style weights derived
    from their log-sum-exp values, and optionally stores the merged LSE.
    """
    # Grid is (num_tokens, num_heads): one program per (token, head) pair.
    token_idx = tl.program_id(0)
    num_tokens = tl.num_programs(0)  # NOTE(review): currently unused
    head_idx = tl.program_id(1)
    num_heads = tl.num_programs(1)

    p_lse = tl.load(prefix_lse + token_idx * num_heads + head_idx)
    s_lse = tl.load(suffix_lse + token_idx * num_heads + head_idx)
    # Map +inf LSE to -inf so its exp() weight becomes 0 below instead of
    # producing inf/NaN — presumably +inf marks a state with no contribution
    # (TODO confirm against callers).
    p_lse = float("-inf") if p_lse == float("inf") else p_lse
    s_lse = float("-inf") if s_lse == float("inf") else s_lse

    # Numerically stable log-sum-exp: shift by the max before exponentiating.
    max_lse = tl.maximum(p_lse, s_lse)
    p_lse = p_lse - max_lse
    s_lse = s_lse - max_lse
    out_se = tl.exp(p_lse) + tl.exp(s_lse)

    if OUTPUT_LSE:
        # Undo the shift: log(exp(p) + exp(s)) + max_lse is the merged LSE.
        out_lse = tl.log(out_se) + max_lse
        tl.store(output_lse + token_idx * num_heads + head_idx, out_lse)

    # Mask the padded tail so only the real HEAD_SIZE lanes are touched.
    head_arange = tl.arange(0, PADDED_HEAD_SIZE)
    head_mask = head_arange < HEAD_SIZE
    p_out = tl.load(
        prefix_output
        + token_idx * num_heads * HEAD_SIZE
        + head_idx * HEAD_SIZE
        + head_arange,
        mask=head_mask,
    )
    s_out = tl.load(
        suffix_output
        + token_idx * num_heads * HEAD_SIZE
        + head_idx * HEAD_SIZE
        + head_arange,
        mask=head_mask,
    )

    # Blend the two partial outputs with their normalized softmax weights.
    p_scale = tl.exp(p_lse) / out_se
    s_scale = tl.exp(s_lse) / out_se
    out = p_out * p_scale + s_out * s_scale
    tl.store(
        output + token_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE + head_arange,
        out,
        mask=head_mask,
    )
|
||||||
|
|
||||||
|
|
||||||
|
def merge_state_triton(
    prefix_output: torch.Tensor,
    prefix_lse: torch.Tensor,
    suffix_output: torch.Tensor,
    suffix_lse: torch.Tensor,
    output: Optional[torch.Tensor] = None,
    output_lse: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
    """Merge two partial attention states using the Triton kernel.

    Args:
        prefix_output: [NUM_TOKENS, NUM_HEADS, HEAD_SIZE] partial outputs.
        prefix_lse: [NUM_TOKENS, NUM_HEADS] log-sum-exp of the prefix state.
        suffix_output: [NUM_TOKENS, NUM_HEADS, HEAD_SIZE] partial outputs.
        suffix_lse: [NUM_TOKENS, NUM_HEADS] log-sum-exp of the suffix state.
        output: optional pre-allocated buffer for the merged output.
        output_lse: optional pre-allocated buffer for the merged LSE.

    Returns:
        Tuple of (merged output, merged LSE).
    """
    # Reuse caller-provided buffers; only allocate when none were given.
    if output is None:
        output = torch.empty_like(prefix_output)
    if output_lse is None:
        output_lse = torch.empty_like(prefix_lse)

    num_tokens, num_query_heads, head_size = output.shape
    # tl.arange requires a power-of-two extent; pad and mask inside the kernel.
    padded_head_size = triton.next_power_of_2(head_size)

    grid = (num_tokens, num_query_heads)
    merge_state_kernel[grid](
        output,
        output_lse,
        prefix_output,
        prefix_lse,
        suffix_output,
        suffix_lse,
        head_size,
        padded_head_size,
        output_lse is not None,
    )
    return output, output_lse
|
||||||
Reference in New Issue
Block a user