Files
2026-03-10 13:31:25 +08:00

45 lines
1.6 KiB
Python

# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
from vllm.logger import init_logger
logger = init_logger(__name__)

# Every wrapper in this module calls into `ipex`, so the module cannot work
# without intel_extension_for_pytorch. The handler deliberately adds nothing:
# the very same ImportError instance is re-raised to whoever imported us.
try:
    import intel_extension_for_pytorch as ipex
except ImportError as import_error:
    raise import_error
def bgmv_shrink(
    inputs: torch.Tensor,
    lora_a_weights: torch.Tensor,
    output_tensor: torch.Tensor,
    lora_indices_tensor: torch.Tensor,
    scaling: float = 1.0,
) -> None:
    """Delegate to ``ipex.llm.functional.bgmv_shrink``.

    This is a pure pass-through: the five arguments are handed to the IPEX
    kernel positionally, in the same order as this signature, and whatever
    the kernel returns is discarded.

    Args:
        inputs: First positional argument forwarded to the kernel.
        lora_a_weights: Second positional argument forwarded to the kernel.
        output_tensor: Third positional argument forwarded to the kernel.
            NOTE(review): presumably the kernel writes its result here in
            place, since this wrapper returns nothing -- confirm against the
            IPEX documentation.
        lora_indices_tensor: Fourth positional argument forwarded to the
            kernel.
        scaling: Fifth positional argument forwarded to the kernel;
            defaults to ``1.0``.
    """
    shrink_kernel = ipex.llm.functional.bgmv_shrink
    shrink_kernel(
        inputs,
        lora_a_weights,
        output_tensor,
        lora_indices_tensor,
        scaling,
    )
def bgmv_expand(
    inputs: torch.Tensor,
    lora_b_weights: torch.Tensor,
    output_tensor: torch.Tensor,
    lora_indices_tensor: torch.Tensor,
    add_inputs: bool = True,
) -> None:
    """Delegate to ``ipex.llm.functional.bgmv_expand``.

    This is a pure pass-through: the five arguments are handed to the IPEX
    kernel positionally, in the same order as this signature, and whatever
    the kernel returns is discarded.

    Args:
        inputs: First positional argument forwarded to the kernel.
        lora_b_weights: Second positional argument forwarded to the kernel.
        output_tensor: Third positional argument forwarded to the kernel.
            NOTE(review): presumably the kernel writes its result here in
            place, since this wrapper returns nothing -- confirm against the
            IPEX documentation.
        lora_indices_tensor: Fourth positional argument forwarded to the
            kernel.
        add_inputs: Fifth positional argument forwarded to the kernel;
            defaults to ``True``. NOTE(review): looks like it selects
            accumulate-vs-overwrite on ``output_tensor`` -- confirm.
    """
    expand_kernel = ipex.llm.functional.bgmv_expand
    expand_kernel(
        inputs,
        lora_b_weights,
        output_tensor,
        lora_indices_tensor,
        add_inputs,
    )
def bgmv_expand_slice(
    inputs: torch.Tensor,
    lora_b_weights: torch.Tensor,
    output_tensor: torch.Tensor,
    lora_indices_tensor: torch.Tensor,
    slice_offset: int,
    slice_size: int,
    add_inputs: bool = True,
) -> None:
    """Delegate to ``ipex.llm.functional.bgmv_expand_slice``.

    This is a pure pass-through: the seven arguments are handed to the IPEX
    kernel positionally, in the same order as this signature, and whatever
    the kernel returns is discarded.

    Args:
        inputs: First positional argument forwarded to the kernel.
        lora_b_weights: Second positional argument forwarded to the kernel.
        output_tensor: Third positional argument forwarded to the kernel.
            NOTE(review): presumably the kernel writes its result here in
            place, since this wrapper returns nothing -- confirm against the
            IPEX documentation.
        lora_indices_tensor: Fourth positional argument forwarded to the
            kernel.
        slice_offset: Fifth positional argument forwarded to the kernel.
        slice_size: Sixth positional argument forwarded to the kernel.
            NOTE(review): offset/size presumably pick a window of
            ``output_tensor`` -- confirm which dimension and units.
        add_inputs: Seventh positional argument forwarded to the kernel;
            defaults to ``True``.
    """
    expand_slice_kernel = ipex.llm.functional.bgmv_expand_slice
    expand_slice_kernel(
        inputs,
        lora_b_weights,
        output_tensor,
        lora_indices_tensor,
        slice_offset,
        slice_size,
        add_inputs,
    )