Sync from v0.13
vllm/attention/ops/merge_attn_states.py (new file, 47 lines)
@@ -0,0 +1,47 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import torch

from vllm.platforms import current_platform


def merge_attn_states(
    output: torch.Tensor,
    prefix_output: torch.Tensor,
    prefix_lse: torch.Tensor,
    suffix_output: torch.Tensor,
    suffix_lse: torch.Tensor,
    output_lse: torch.Tensor | None = None,
) -> None:
    # NOTE(DefTruth): Currently, the custom merge_attn_states CUDA kernel
    # does not support FP8 dtypes, so we fall back to the Triton kernel.
    def supported_dtypes(o: torch.Tensor) -> bool:
        return o.dtype in [torch.float32, torch.half, torch.bfloat16]

    # NOTE(DefTruth): Currently, the custom merge_attn_states CUDA
    # kernel loads/stores 128b (16 bytes) per memory access within a
    # thread, so the head size (headdim) must be a multiple of the
    # pack size (float32 -> 4, half/bfloat16 -> 8).
    def supported_headdim(o: torch.Tensor) -> bool:
        headdim = o.shape[2]  # [NUM_TOKENS, NUM_HEADS, HEAD_SIZE]
        if o.dtype == torch.float32:
            return headdim % 4 == 0
        return headdim % 8 == 0

    if (
        current_platform.is_cuda()
        and supported_dtypes(output)
        and supported_headdim(output)
    ):
        from vllm._custom_ops import merge_attn_states

        return merge_attn_states(
            output, prefix_output, prefix_lse, suffix_output, suffix_lse, output_lse
        )
    else:
        from vllm.attention.ops.triton_merge_attn_states import merge_attn_states

        return merge_attn_states(
            output, prefix_output, prefix_lse, suffix_output, suffix_lse, output_lse
        )
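
For context, the merge that both kernels implement is the standard log-sum-exp combination of two partial attention results computed over disjoint KV ranges (FlashAttention-style split-KV). Below is a minimal pure-PyTorch sketch of the semantics, not the actual kernel; it assumes the [NUM_TOKENS, NUM_HEADS, HEAD_SIZE] output layout noted in the comment above and an LSE layout of [NUM_HEADS, NUM_TOKENS], and it omits the guards against +/-inf LSEs (fully masked splits) that a production kernel would need:

import torch

def merge_attn_states_ref(
    prefix_output: torch.Tensor,  # [NUM_TOKENS, NUM_HEADS, HEAD_SIZE]
    prefix_lse: torch.Tensor,     # [NUM_HEADS, NUM_TOKENS] (assumed layout)
    suffix_output: torch.Tensor,  # [NUM_TOKENS, NUM_HEADS, HEAD_SIZE]
    suffix_lse: torch.Tensor,     # [NUM_HEADS, NUM_TOKENS] (assumed layout)
) -> tuple[torch.Tensor, torch.Tensor]:
    # Numerically stable log-sum-exp of the two partial LSEs.
    max_lse = torch.maximum(prefix_lse, suffix_lse)
    merged_lse = max_lse + torch.log(
        torch.exp(prefix_lse - max_lse) + torch.exp(suffix_lse - max_lse)
    )
    # Per-(head, token) blend weights; transpose to [NUM_TOKENS, NUM_HEADS, 1]
    # so they broadcast over HEAD_SIZE.
    w_prefix = torch.exp(prefix_lse - merged_lse).t().unsqueeze(-1)
    w_suffix = torch.exp(suffix_lse - merged_lse).t().unsqueeze(-1)
    merged_output = w_prefix * prefix_output + w_suffix * suffix_output
    return merged_output, merged_lse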
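And a hypothetical call site with made-up sizes (HEAD_SIZE = 128 satisfies the pack-size check for bfloat16), again assuming the [NUM_HEADS, NUM_TOKENS] LSE layout; both dispatch targets run on GPU, so this needs a CUDA-enabled vLLM build:

import torch
from vllm.attention.ops.merge_attn_states import merge_attn_states

num_tokens, num_heads, head_size = 16, 8, 128  # illustrative sizes only
output = torch.empty(num_tokens, num_heads, head_size,
                     dtype=torch.bfloat16, device="cuda")
prefix_output = torch.randn_like(output)
suffix_output = torch.randn_like(output)
prefix_lse = torch.randn(num_heads, num_tokens, device="cuda")  # float32
suffix_lse = torch.randn(num_heads, num_tokens, device="cuda")  # float32

# The merged result is written into `output` in place; pass output_lse
# as well if the merged LSE is needed downstream.
merge_attn_states(output, prefix_output, prefix_lse,
                  suffix_output, suffix_lse)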