26 lines
618 B
Python
26 lines
618 B
Python
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
|
from typing import Optional
|
|
|
|
import torch
|
|
import xtorch_ops
|
|
from vllm.platforms import current_platform
|
|
|
|
|
|
def merge_attn_states(
    output: torch.Tensor,
    prefix_output: torch.Tensor,
    prefix_lse: torch.Tensor,
    suffix_output: torch.Tensor,
    suffix_lse: torch.Tensor,
    output_lse: Optional[torch.Tensor] = None,
) -> None:
    """Forward prefix/suffix attention states to the xtorch_ops merge op.

    Thin wrapper around ``xtorch_ops.attention_merge_stage``. The argument
    order differs between the two: this function takes ``output`` first,
    while the underlying op receives the prefix and suffix tensors first,
    followed by ``output`` and ``output_lse``.

    Args:
        output: Forwarded as the op's fifth positional argument.
            Presumably the destination buffer for the merged attention
            output -- TODO confirm against the xtorch_ops kernel.
        prefix_output: Forwarded as the op's first positional argument.
        prefix_lse: Forwarded as the op's second positional argument
            (presumably the log-sum-exp values matching
            ``prefix_output`` -- verify against the kernel).
        suffix_output: Forwarded as the op's third positional argument.
        suffix_lse: Forwarded as the op's fourth positional argument
            (presumably the log-sum-exp values matching
            ``suffix_output`` -- verify against the kernel).
        output_lse: Optional tensor forwarded as the op's last positional
            argument; ``None`` is passed through when it is omitted.

    Returns:
        Whatever ``xtorch_ops.attention_merge_stage`` returns, passed
        through unchanged.

        NOTE(review): the signature is annotated ``-> None`` although the
        op's result is returned; confirm the op returns ``None`` or adjust
        the annotation.
    """
    # Positional order here is dictated by the op, not by this signature.
    merge_result = xtorch_ops.attention_merge_stage(
        prefix_output,
        prefix_lse,
        suffix_output,
        suffix_lse,
        output,
        output_lse,
    )
    return merge_result