Files
sglang/python/sglang/srt/layers/attention/merge_state.py
2025-05-05 10:32:33 -07:00

47 lines
1.4 KiB
Python

from typing import Optional, Tuple
import torch
from sgl_kernel import merge_state_v2
from sglang.srt.layers.attention.triton_ops.merge_state import merge_state_triton
from sglang.srt.utils import is_cuda
# Evaluated once at import time; merge_state() reads this flag when
# choosing between the sgl_kernel and Triton implementations.
_is_cuda = is_cuda()
# Automatically fall back to the Triton kernel in some cases
# (e.g., on AMD GPUs, when the head dimension is not a multiple
# of 4 or 8, or in FP8 precision).
def _supported_dtypes(o: torch.Tensor) -> bool:
    """Return True when ``o`` has dtype float32, float16, or bfloat16.

    Any other dtype yields False; merge_state() then uses the Triton
    kernel instead of ``merge_state_v2``.
    """
    accepted = (torch.float32, torch.float16, torch.bfloat16)
    return o.dtype in accepted
def _supported_headdim(o: torch.Tensor) -> bool:
    """Return True when the size of ``o``'s third dimension is aligned.

    The third dimension is treated as the head size, assuming the layout
    [NUM_TOKENS, NUM_HEADS, HEAD_SIZE]. It must be a multiple of 4 for
    float32 tensors and a multiple of 8 for every other dtype.
    """
    head_size = o.shape[2]  # layout: [NUM_TOKENS, NUM_HEADS, HEAD_SIZE]
    required_multiple = 4 if o.dtype == torch.float32 else 8
    return head_size % required_multiple == 0
def merge_state(
    prefix_output: torch.Tensor,
    prefix_lse: torch.Tensor,
    suffix_output: torch.Tensor,
    suffix_lse: torch.Tensor,
    output: Optional[torch.Tensor] = None,
    output_lse: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
    """Forward the prefix/suffix merge to the appropriate kernel.

    ``merge_state_v2`` from ``sgl_kernel`` is used only when running on
    CUDA and ``prefix_output`` passes both the dtype and head-dimension
    checks; in every other case the call goes to ``merge_state_triton``.
    Only ``prefix_output`` is inspected when making this decision.

    All six arguments are passed through positionally and unchanged, and
    the selected kernel's return value is returned as-is.
    NOTE(review): the ``*_lse`` tensors are presumably log-sum-exp values
    and ``output``/``output_lse`` optional preallocated buffers; confirm
    against the kernel implementations.
    """
    # Same short-circuit order as before: the tensor checks run only on CUDA.
    use_cuda_kernel = (
        _is_cuda
        and _supported_dtypes(prefix_output)
        and _supported_headdim(prefix_output)
    )
    # Otherwise fall back to the Triton kernel.
    kernel = merge_state_v2 if use_cuda_kernel else merge_state_triton
    return kernel(
        prefix_output, prefix_lse, suffix_output, suffix_lse, output, output_lse
    )