# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from typing import Optional

import torch

from vllm.platforms import current_platform


def merge_attn_states(
    output: torch.Tensor,
    prefix_output: torch.Tensor,
    prefix_lse: torch.Tensor,
    suffix_output: torch.Tensor,
    suffix_lse: torch.Tensor,
    output_lse: Optional[torch.Tensor] = None,
) -> None:
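    """Merge two chunks of attention states into ``output`` in place.

    ``prefix_output``/``suffix_output`` are partial attention outputs
    computed over disjoint ranges of the same keys/values, and
    ``prefix_lse``/``suffix_lse`` hold their log-sum-exp values. Each
    partial output is rescaled by the softmax weight implied by its LSE
    (see the reference sketch at the end of this file); the merged LSE
    is also written to ``output_lse`` when it is provided. Dispatches
    to the custom CUDA kernel when supported, otherwise to the Triton
    kernel.
    """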
    # NOTE(DefTruth): Currently, the custom merge_attn_states CUDA kernel
    # does not support FP8 dtypes; fall back to the Triton kernel.
    def supported_dtypes(o: torch.Tensor) -> bool:
        return o.dtype in [torch.float32, torch.half, torch.bfloat16]
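    # For example, FP8 tensors (torch.float8_e4m3fn, torch.float8_e5m2)
    # fail this check and take the Triton path below.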

    # NOTE(DefTruth): Currently, the custom merge_attn_states CUDA kernel
    # loads/stores 128 bits (16 bytes) per memory transaction within each
    # thread. Hence, the head size (headdim) must be a multiple of the
    # pack size (float32 -> 4 elements, half/bfloat16 -> 8 elements).
    def supported_headdim(o: torch.Tensor) -> bool:
        headdim = o.shape[2]  # [NUM_TOKENS, NUM_HEADS, HEAD_SIZE]
        if o.dtype == torch.float32:
            return headdim % 4 == 0
        return headdim % 8 == 0
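    # For example, headdim 128 with bfloat16 satisfies 128 % 8 == 0 and
    # can use the CUDA kernel, while headdim 100 would fall back to Triton.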

    if (current_platform.is_cuda() and supported_dtypes(output)
            and supported_headdim(output)):
        from vllm._custom_ops import merge_attn_states
        return merge_attn_states(output, prefix_output, prefix_lse,
                                 suffix_output, suffix_lse, output_lse)
    else:
        from vllm.attention.ops.triton_merge_attn_states import (
            merge_attn_states)
        return merge_attn_states(output, prefix_output, prefix_lse,
                                 suffix_output, suffix_lse, output_lse)
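

# A minimal pure-PyTorch sketch of the merge semantics shared by both
# kernels, for illustration only: the helper name and the assumed
# [NUM_HEADS, NUM_TOKENS] LSE layout are not part of the vLLM API, and
# the real kernels write into ``output`` in place and may handle edge
# cases (e.g. infinite LSEs) that this sketch omits.
def _merge_attn_states_reference(
    prefix_output: torch.Tensor,  # [NUM_TOKENS, NUM_HEADS, HEAD_SIZE]
    prefix_lse: torch.Tensor,  # [NUM_HEADS, NUM_TOKENS] (assumed)
    suffix_output: torch.Tensor,  # [NUM_TOKENS, NUM_HEADS, HEAD_SIZE]
    suffix_lse: torch.Tensor,  # [NUM_HEADS, NUM_TOKENS] (assumed)
) -> tuple[torch.Tensor, torch.Tensor]:
    # Combined LSE: log(exp(prefix_lse) + exp(suffix_lse)), computed
    # stably by factoring out the elementwise maximum.
    max_lse = torch.maximum(prefix_lse, suffix_lse)
    p = torch.exp(prefix_lse - max_lse)
    s = torch.exp(suffix_lse - max_lse)
    merged_lse = max_lse + torch.log(p + s)
    # Per-(head, token) softmax weights, transposed and unsqueezed to
    # [NUM_TOKENS, NUM_HEADS, 1] so they broadcast over HEAD_SIZE.
    w_prefix = (p / (p + s)).transpose(0, 1).unsqueeze(-1)
    w_suffix = (s / (p + s)).transpose(0, 1).unsqueeze(-1)
    merged = w_prefix * prefix_output + w_suffix * suffix_output
    return merged, merged_lse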