feat: Add a unified merge_state API (#5428)
This commit is contained in:
46
python/sglang/srt/layers/attention/merge_state.py
Normal file
46
python/sglang/srt/layers/attention/merge_state.py
Normal file
@@ -0,0 +1,46 @@
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import torch
|
||||
from sgl_kernel import merge_state_v2
|
||||
|
||||
from sglang.srt.layers.attention.triton_ops.merge_state import merge_state_triton
|
||||
from sglang.srt.utils import is_cuda
|
||||
|
||||
_is_cuda = is_cuda()
|
||||
|
||||
|
||||
# Automatically fall back to the Triton kernel in some cases
# (e.g., on AMD GPUs, when the head dimension is not a multiple
# of 4 or 8, or in FP8 precision).
|
||||
def _supported_dtypes(o: torch.Tensor) -> bool:
|
||||
return o.dtype in [torch.float32, torch.half, torch.bfloat16]
|
||||
|
||||
|
||||
def _supported_headdim(o: torch.Tensor) -> bool:
|
||||
headdim = o.shape[2] # [NUM_TOKENS, NUM_HEADS, HEAD_SIZE]
|
||||
if o.dtype == torch.float32:
|
||||
return headdim % 4 == 0
|
||||
return headdim % 8 == 0
|
||||
|
||||
|
||||
def merge_state(
    prefix_output: torch.Tensor,
    prefix_lse: torch.Tensor,
    suffix_output: torch.Tensor,
    suffix_lse: torch.Tensor,
    output: Optional[torch.Tensor] = None,
    output_lse: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
    """Merge a prefix and a suffix (output, lse) pair using the best backend.

    ``merge_state_v2`` is used when the module-level ``_is_cuda`` flag is set
    and ``prefix_output`` passes both the dtype and the head-size checks;
    otherwise the call goes to ``merge_state_triton``. All six arguments are
    forwarded positionally and unchanged.

    Returns:
        Whatever the selected backend returns.
    """
    # `and` short-circuits, so the tensor checks only run on CUDA builds.
    use_cuda_kernel = (
        _is_cuda
        and _supported_dtypes(prefix_output)
        and _supported_headdim(prefix_output)
    )
    # The Triton implementation is the fallback for everything else.
    backend = merge_state_v2 if use_cuda_kernel else merge_state_triton
    return backend(
        prefix_output, prefix_lse, suffix_output, suffix_lse, output, output_lse
    )
|
||||
96
python/sglang/srt/layers/attention/triton_ops/merge_state.py
Normal file
96
python/sglang/srt/layers/attention/triton_ops/merge_state.py
Normal file
@@ -0,0 +1,96 @@
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import torch
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
|
||||
@triton.jit
def merge_state_kernel(
    output,  # [NUM_TOKENS, NUM_HEADS, HEAD_SIZE] v_merged
    output_lse,  # [NUM_TOKENS, NUM_HEADS] s_merged
    prefix_output,  # [NUM_TOKENS, NUM_HEADS, HEAD_SIZE] v_a
    prefix_lse,  # [NUM_TOKENS, NUM_HEADS] s_a
    suffix_output,  # [NUM_TOKENS, NUM_HEADS, HEAD_SIZE] v_b
    suffix_lse,  # [NUM_TOKENS, NUM_HEADS] s_b
    HEAD_SIZE: tl.constexpr,
    PADDED_HEAD_SIZE: tl.constexpr,
    OUTPUT_LSE: tl.constexpr,
):
    """Merge the prefix and suffix states of one (token, head) pair.

    Grid axis 0 indexes tokens and grid axis 1 indexes heads; the number of
    heads is taken from the grid size of axis 1. All tensors are addressed
    as densely packed row-major buffers with the shapes noted on the
    parameters.

    The merged output is the weighted sum of the two partial outputs with
    weights ``exp(lse_x - max_lse) / (exp(p_lse - max_lse) + exp(s_lse - max_lse))``.
    When ``OUTPUT_LSE`` is true, ``log(sum of exps) + max_lse`` is also
    written to ``output_lse``.

    ``PADDED_HEAD_SIZE`` is the length of the ``tl.arange`` used for the
    head dimension; lanes at or beyond ``HEAD_SIZE`` are masked out on
    every load and store.
    """
    # Promote to int64 before building element offsets: program ids are
    # int32, and token_idx * num_heads * HEAD_SIZE would wrap around once a
    # tensor holds 2**31 or more elements.
    token_idx = tl.program_id(0).to(tl.int64)
    head_idx = tl.program_id(1)
    num_heads = tl.num_programs(1)

    # Flat index of this (token, head) pair; shared by both LSE tensors and,
    # scaled by HEAD_SIZE, by the three output tensors.
    pair_offset = token_idx * num_heads + head_idx

    p_lse = tl.load(prefix_lse + pair_offset)
    s_lse = tl.load(suffix_lse + pair_offset)
    # A +inf LSE is mapped to -inf so that this side gets zero weight
    # (exp(-inf) == 0) when the other side is finite.
    # NOTE(review): presumably some attention backends report +inf for an
    # empty sequence - confirm against the callers.
    p_lse = float("-inf") if p_lse == float("inf") else p_lse
    s_lse = float("-inf") if s_lse == float("inf") else s_lse

    # Subtract the maximum before exponentiating for numerical stability.
    # NOTE(review): if both LSEs are -inf here, -inf - (-inf) is NaN and the
    # NaN propagates to both outputs.
    max_lse = tl.maximum(p_lse, s_lse)
    p_lse = p_lse - max_lse
    s_lse = s_lse - max_lse
    # Computed once and reused for the merged LSE and the scale factors.
    p_se = tl.exp(p_lse)
    s_se = tl.exp(s_lse)
    out_se = p_se + s_se

    if OUTPUT_LSE:
        out_lse = tl.log(out_se) + max_lse
        tl.store(output_lse + pair_offset, out_lse)

    head_arange = tl.arange(0, PADDED_HEAD_SIZE)
    head_mask = head_arange < HEAD_SIZE
    # Same element offsets for prefix_output, suffix_output and output.
    elem_offsets = pair_offset * HEAD_SIZE + head_arange

    p_out = tl.load(prefix_output + elem_offsets, mask=head_mask)
    s_out = tl.load(suffix_output + elem_offsets, mask=head_mask)

    p_scale = p_se / out_se
    s_scale = s_se / out_se
    out = p_out * p_scale + s_out * s_scale
    tl.store(output + elem_offsets, out, mask=head_mask)
|
||||
|
||||
|
||||
def merge_state_triton(
    prefix_output: torch.Tensor,
    prefix_lse: torch.Tensor,
    suffix_output: torch.Tensor,
    suffix_lse: torch.Tensor,
    output: Optional[torch.Tensor] = None,
    output_lse: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
    """Launch ``merge_state_kernel`` over every (token, head) pair.

    Args:
        prefix_output: First partial output; the template for ``output``
            when that buffer is not supplied.
        prefix_lse: LSE values for ``prefix_output``; the template for
            ``output_lse`` when that buffer is not supplied.
        suffix_output: Second partial output.
        suffix_lse: LSE values for ``suffix_output``.
        output: Optional destination for the merged output.
        output_lse: Optional destination for the merged LSE.

    Returns:
        The ``(output, output_lse)`` pair written by the kernel; these are
        the caller's buffers when they were provided.
    """
    # Reuse caller-supplied buffers; allocate only the missing ones.
    output = torch.empty_like(prefix_output) if output is None else output
    output_lse = torch.empty_like(prefix_lse) if output_lse is None else output_lse

    num_tokens, num_query_heads, head_size = (
        output.shape[0],
        output.shape[1],
        output.shape[2],
    )
    # tl.arange in the kernel runs over this padded length; the extra lanes
    # are masked off inside the kernel.
    padded_head_size = triton.next_power_of_2(head_size)

    # One kernel program per (token, head) pair.
    launch_grid = (num_tokens, num_query_heads)
    # Always true at this point because output_lse was allocated above when
    # it was missing.
    write_lse = output_lse is not None

    merge_state_kernel[launch_grid](
        output,
        output_lse,
        prefix_output,
        prefix_lse,
        suffix_output,
        suffix_lse,
        head_size,
        padded_head_size,
        write_lse,
    )
    return output, output_lse
|
||||
Reference in New Issue
Block a user