Sync from v0.13
This commit is contained in:
20
vllm/lora/ops/torch_ops/__init__.py
Normal file
20
vllm/lora/ops/torch_ops/__init__.py
Normal file
@@ -0,0 +1,20 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from vllm.lora.ops.torch_ops.lora_ops import (
|
||||
bgmv_expand, # noqa: F401
|
||||
bgmv_expand_slice,
|
||||
bgmv_shrink,
|
||||
sgmv_expand,
|
||||
sgmv_expand_slice,
|
||||
sgmv_shrink,
|
||||
)
|
||||
|
||||
# Names exported by this package; each one is imported above from
# vllm.lora.ops.torch_ops.lora_ops.
__all__ = [
    "bgmv_expand",
    "bgmv_expand_slice",
    "bgmv_shrink",
    "sgmv_expand",
    "sgmv_expand_slice",
    "sgmv_shrink",
]
|
||||
128
vllm/lora/ops/torch_ops/lora_ops.py
Normal file
128
vllm/lora/ops/torch_ops/lora_ops.py
Normal file
@@ -0,0 +1,128 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
def sgmv_expand(
    inputs: torch.Tensor,
    lora_b_weights: torch.Tensor,
    output_tensor: torch.Tensor,
    b_seq_start_loc: torch.Tensor,
    seq_len_tensor: torch.Tensor,
    lora_indices_tensor: torch.Tensor,
    batches: int,
    max_seq_length: int,
    token_nums: int,
    add_inputs: bool = False,
) -> None:
    """Run ``bgmv_expand`` with one LoRA index per row of ``inputs``.

    Each entry of ``lora_indices_tensor`` is repeated by the matching count
    in ``seq_len_tensor``; the resulting flat index tensor is handed to
    ``bgmv_expand`` together with ``inputs``, ``lora_b_weights``,
    ``output_tensor`` and ``add_inputs``. ``output_tensor`` is updated in
    place and nothing is returned.

    ``b_seq_start_loc``, ``batches``, ``max_seq_length`` and ``token_nums``
    are accepted but not read by this implementation (presumably kept so the
    signature matches other backends -- TODO confirm against callers).
    """
    # One index per sequence -> one index per token of that sequence.
    per_token_indices = lora_indices_tensor.repeat_interleave(seq_len_tensor)
    bgmv_expand(
        inputs,
        lora_b_weights,
        output_tensor,
        per_token_indices,
        add_inputs=add_inputs,
    )
|
||||
|
||||
|
||||
def bgmv_expand(
    inputs: torch.Tensor,
    lora_b_weights: torch.Tensor,
    output_tensor: torch.Tensor,
    lora_indices_tensor: torch.Tensor,
    add_inputs: bool = True,
) -> None:
    """Apply a per-row selected weight matrix to ``inputs`` and write the
    result into the leading columns of ``output_tensor`` (in place).

    ``lora_b_weights`` is indexed along its first dimension with
    ``lora_indices_tensor``. The selection must be 3-D ``(b, o, i)``, or 4-D
    with a dimension at position 1 that is squeezed away. ``inputs`` is 2-D
    ``(b, i)``. Both are cast to ``output_tensor.dtype`` before the batched
    product ``"bi, boi -> bo"``.

    Args:
        inputs: Row vectors to project.
        lora_b_weights: Stack of weight matrices to select from.
        output_tensor: Destination, modified in place.
        lora_indices_tensor: Index into ``lora_b_weights`` for each row.
        add_inputs: If true, accumulate into ``output_tensor``; otherwise
            overwrite the affected columns.
    """
    target_dtype = output_tensor.dtype

    per_row_weights = lora_b_weights[lora_indices_tensor].to(dtype=target_dtype)
    if per_row_weights.dim() == 4:
        per_row_weights = per_row_weights.squeeze(dim=1)

    projected = torch.einsum(
        "bi, boi -> bo", inputs.to(dtype=target_dtype), per_row_weights
    )

    # A single projected row is broadcast over every row of a taller output;
    # otherwise take at most as many rows as the output has.
    out_rows = output_tensor.shape[0]
    row_limit = 1 if (projected.shape[0] == 1 and out_rows != 1) else out_rows

    # LoRA adapter and model may add different amounts of padding to output
    width = min(projected.shape[1], output_tensor.shape[1])
    update = projected[:row_limit, :width]

    if add_inputs:
        output_tensor[:, :width] += update
    else:
        output_tensor[:, :width] = update
|
||||
|
||||
|
||||
def sgmv_shrink(
    inputs: torch.Tensor,
    lora_a_weights: torch.Tensor,
    output_tensor: torch.Tensor,
    b_seq_start_loc: torch.Tensor,
    seq_len_tensor: torch.Tensor,
    lora_indices_tensor: torch.Tensor,
    batches: int,
    max_seq_length: int,
    token_nums: int,
    scaling: float,
) -> None:
    """Run ``bgmv_shrink`` with one LoRA index per row of ``inputs``.

    Each entry of ``lora_indices_tensor`` is repeated by the matching count
    in ``seq_len_tensor``; the resulting flat index tensor is handed to
    ``bgmv_shrink`` together with ``inputs``, ``lora_a_weights``,
    ``output_tensor`` and ``scaling``. ``output_tensor`` is updated in place
    and nothing is returned.

    ``b_seq_start_loc``, ``batches``, ``max_seq_length`` and ``token_nums``
    are accepted but not read by this implementation (presumably kept so the
    signature matches other backends -- TODO confirm against callers).
    """
    # One index per sequence -> one index per token of that sequence.
    per_token_indices = lora_indices_tensor.repeat_interleave(seq_len_tensor)
    bgmv_shrink(
        inputs,
        lora_a_weights,
        output_tensor,
        per_token_indices,
        scaling=scaling,
    )
|
||||
|
||||
|
||||
def bgmv_shrink(
    inputs: torch.Tensor,
    lora_b_weights: torch.Tensor,
    output_tensor: torch.Tensor,
    lora_indices_tensor: torch.Tensor,
    scaling: float = 1.0,
) -> None:
    """Apply a per-row selected weight matrix to ``inputs``, scale the
    result, and overwrite the leading columns of ``output_tensor`` in place.

    ``lora_b_weights`` is indexed along its first dimension with
    ``lora_indices_tensor``. The selection must be 3-D ``(b, o, i)``, or 4-D
    with a dimension at position 1 that is squeezed away. ``inputs`` is 2-D
    ``(b, i)``. Both are cast to ``output_tensor.dtype`` before the batched
    product ``"bi, boi -> bo"``.

    Args:
        inputs: Row vectors to project.
        lora_b_weights: Stack of weight matrices to select from. (The name
            is kept as-is for keyword callers; ``sgmv_shrink`` passes its
            ``lora_a_weights`` argument here.)
        output_tensor: Destination; its first ``o`` columns are overwritten.
        lora_indices_tensor: Index into ``lora_b_weights`` for each row.
        scaling: Multiplier applied to the projected values before storing.
    """
    target_dtype = output_tensor.dtype

    per_row_weights = lora_b_weights[lora_indices_tensor].to(dtype=target_dtype)
    if per_row_weights.dim() == 4:
        per_row_weights = per_row_weights.squeeze(dim=1)

    projected = torch.einsum(
        "bi, boi -> bo", inputs.to(dtype=target_dtype), per_row_weights
    )

    # Only the first `o` columns are touched; any remaining columns keep
    # whatever they held before the call.
    out_width = projected.shape[1]
    output_tensor[:, :out_width] = scaling * projected
|
||||
|
||||
|
||||
def sgmv_expand_slice(
    inputs: torch.Tensor,
    lora_b_weights: torch.Tensor,
    output_tensor: torch.Tensor,
    b_seq_start_loc: torch.Tensor,
    seq_len_tensor: torch.Tensor,
    lora_indices_tensor: torch.Tensor,
    batches: int,
    max_seq_length: int,
    token_nums: int,
    slice_offset: int,
    slice_size: int,
    add_inputs: bool = False,
) -> None:
    """Run ``bgmv_expand_slice`` with one LoRA index per row of ``inputs``.

    Each entry of ``lora_indices_tensor`` is repeated by the matching count
    in ``seq_len_tensor``; the resulting flat index tensor is handed to
    ``bgmv_expand_slice`` together with ``inputs``, ``lora_b_weights``,
    ``output_tensor``, ``slice_offset``, ``slice_size`` and ``add_inputs``.
    ``output_tensor`` is updated in place and nothing is returned.

    ``b_seq_start_loc``, ``batches``, ``max_seq_length`` and ``token_nums``
    are accepted but not read by this implementation (presumably kept so the
    signature matches other backends -- TODO confirm against callers).
    """
    # One index per sequence -> one index per token of that sequence.
    per_token_indices = lora_indices_tensor.repeat_interleave(seq_len_tensor)
    bgmv_expand_slice(
        inputs,
        lora_b_weights,
        output_tensor,
        per_token_indices,
        slice_offset,
        slice_size,
        add_inputs=add_inputs,
    )
|
||||
|
||||
|
||||
def bgmv_expand_slice(
    inputs: torch.Tensor,
    lora_b_weights: torch.Tensor,
    output_tensor: torch.Tensor,
    lora_indices_tensor: torch.Tensor,
    slice_offset: int,
    slice_size: int,
    add_inputs: bool = True,
) -> None:
    """Apply a per-row selected weight matrix to ``inputs`` and write the
    result into a column window of ``output_tensor`` (in place).

    ``lora_b_weights`` is indexed along its first dimension with
    ``lora_indices_tensor``. The selection must be 3-D ``(b, o, i)``, or 4-D
    with a dimension at position 1 that is squeezed away. ``inputs`` is 2-D
    ``(b, i)``. Both are cast to ``output_tensor.dtype`` before the batched
    product ``"bi, boi -> bo"``.

    Args:
        inputs: Row vectors to project.
        lora_b_weights: Stack of weight matrices to select from.
        output_tensor: Destination, modified in place.
        lora_indices_tensor: Index into ``lora_b_weights`` for each row.
        slice_offset: First destination column of the window.
        slice_size: Number of destination columns in the window.
        add_inputs: If true, accumulate into the window; otherwise
            overwrite it.
    """
    target_dtype = output_tensor.dtype

    per_row_weights = lora_b_weights[lora_indices_tensor].to(dtype=target_dtype)
    if per_row_weights.dim() == 4:
        per_row_weights = per_row_weights.squeeze(dim=1)

    projected = torch.einsum(
        "bi, boi -> bo", inputs.to(dtype=target_dtype), per_row_weights
    )

    slice_end = slice_offset + slice_size
    if not add_inputs:
        output_tensor[:, slice_offset:slice_end] = projected
        return
    output_tensor[:, slice_offset:slice_end] += projected
|
||||
Reference in New Issue
Block a user