2026-01-17 18:38:23 +08:00
|
|
|
# SPDX-License-Identifier: Apache-2.0
|
|
|
|
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
|
|
|
|
from typing import Optional
|
|
|
|
|
|
|
|
|
|
import torch
|
2026-02-12 18:13:00 +08:00
|
|
|
import kunlun_ops
|
2026-01-17 18:38:23 +08:00
|
|
|
from vllm.platforms import current_platform
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def merge_attn_states(
    output: torch.Tensor,
    prefix_output: torch.Tensor,
    prefix_lse: torch.Tensor,
    suffix_output: torch.Tensor,
    suffix_lse: torch.Tensor,
    output_lse: Optional[torch.Tensor] = None,
) -> None:
    """Merge prefix and suffix partial attention states via the Kunlun kernel.

    Thin wrapper around ``kunlun_ops.attention_merge_stage``. The public
    signature keeps vLLM's ``merge_attn_states`` parameter order (destination
    ``output`` first, optional ``output_lse`` last). The Kunlun kernel takes
    the destination tensors after the inputs, so the arguments are reordered
    before the call.

    Args:
        output: Destination tensor for the merged attention output.
            NOTE(review): presumably filled in place by the kernel -- confirm
            against the kunlun_ops documentation.
        prefix_output: Prefix-side partial attention output.
        prefix_lse: Prefix-side LSE tensor (presumably log-sum-exp values
            paired with ``prefix_output``).
        suffix_output: Suffix-side partial attention output.
        suffix_lse: Suffix-side LSE tensor (presumably log-sum-exp values
            paired with ``suffix_output``).
        output_lse: Optional destination for the merged LSE. It is forwarded
            to the kernel unchanged, including when it is ``None``.

    Returns:
        None. Results are delivered through ``output`` / ``output_lse``. Any
        value the kernel itself returns is deliberately discarded so that
        this function honours its declared ``-> None`` contract instead of
        leaking a backend-specific return value.
    """
    # kunlun_ops uses (inputs..., output, output_lse) ordering, whereas this
    # vLLM-facing API uses (output, inputs..., output_lse).
    kunlun_ops.attention_merge_stage(
        prefix_output,
        prefix_lse,
        suffix_output,
        suffix_lse,
        output,
        output_lse,
    )
|