forked from EngineX-MetaX/enginex-c_series-vllm
[gpt-oss] Add gpt-oss bf16 support
This commit is contained in:
16
vllm/lora/ops/torch_ops/__init__.py
Normal file
16
vllm/lora/ops/torch_ops/__init__.py
Normal file
@@ -0,0 +1,16 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from vllm.lora.ops.torch_ops.lora_ops import bgmv_expand # noqa: F401
|
||||
from vllm.lora.ops.torch_ops.lora_ops import (bgmv_expand_slice, bgmv_shrink,
|
||||
sgmv_expand, sgmv_expand_slice,
|
||||
sgmv_shrink)
|
||||
|
||||
__all__ = [
|
||||
"bgmv_expand",
|
||||
"bgmv_expand_slice",
|
||||
"bgmv_shrink",
|
||||
"sgmv_expand",
|
||||
"sgmv_expand_slice",
|
||||
"sgmv_shrink",
|
||||
]
|
||||
119
vllm/lora/ops/torch_ops/lora_ops.py
Normal file
119
vllm/lora/ops/torch_ops/lora_ops.py
Normal file
@@ -0,0 +1,119 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
def sgmv_expand(inputs: torch.Tensor,
|
||||
lora_b_weights: torch.Tensor,
|
||||
output_tensor: torch.Tensor,
|
||||
b_seq_start_loc: torch.Tensor,
|
||||
seq_len_tensor: torch.Tensor,
|
||||
lora_indices_tensor: torch.Tensor,
|
||||
batches: int,
|
||||
max_seq_length: int,
|
||||
token_nums: int,
|
||||
add_inputs: bool = False):
|
||||
exploded_indices = torch.repeat_interleave(lora_indices_tensor,
|
||||
seq_len_tensor)
|
||||
|
||||
bgmv_expand(inputs, lora_b_weights, output_tensor, exploded_indices,
|
||||
add_inputs)
|
||||
|
||||
|
||||
def bgmv_expand(inputs: torch.Tensor,
|
||||
lora_b_weights: torch.Tensor,
|
||||
output_tensor: torch.Tensor,
|
||||
lora_indices_tensor: torch.Tensor,
|
||||
add_inputs: bool = True):
|
||||
selected_loras = lora_b_weights[lora_indices_tensor].to(
|
||||
dtype=output_tensor.dtype)
|
||||
if len(selected_loras.shape) == 4:
|
||||
selected_loras = selected_loras.squeeze(dim=1)
|
||||
inputs = inputs.to(dtype=output_tensor.dtype)
|
||||
outputs = torch.einsum("bi, boi -> bo", inputs, selected_loras)
|
||||
|
||||
limit = output_tensor.shape[0]
|
||||
if outputs.shape[0] == 1 and output_tensor.shape[0] != 1:
|
||||
limit = 1
|
||||
|
||||
# LoRA adapter and model may add different amounts of padding to output
|
||||
common_len = min(outputs.shape[1], output_tensor.shape[1])
|
||||
|
||||
if add_inputs:
|
||||
output_tensor[:, :common_len] += outputs[:limit, :common_len]
|
||||
else:
|
||||
output_tensor[:, :common_len] = outputs[:limit, :common_len]
|
||||
|
||||
|
||||
def sgmv_shrink(
|
||||
inputs: torch.Tensor,
|
||||
lora_a_weights: torch.Tensor,
|
||||
output_tensor: torch.Tensor,
|
||||
b_seq_start_loc: torch.Tensor,
|
||||
seq_len_tensor: torch.Tensor,
|
||||
lora_indices_tensor: torch.Tensor,
|
||||
batches: int,
|
||||
max_seq_length: int,
|
||||
token_nums: int,
|
||||
scaling: float,
|
||||
):
|
||||
exploded_indices = torch.repeat_interleave(lora_indices_tensor,
|
||||
seq_len_tensor)
|
||||
|
||||
bgmv_shrink(inputs, lora_a_weights, output_tensor, exploded_indices,
|
||||
scaling)
|
||||
|
||||
|
||||
def bgmv_shrink(inputs: torch.Tensor,
|
||||
lora_b_weights: torch.Tensor,
|
||||
output_tensor: torch.Tensor,
|
||||
lora_indices_tensor: torch.Tensor,
|
||||
scaling: float = 1.0):
|
||||
selected_loras = lora_b_weights[lora_indices_tensor].to(
|
||||
dtype=output_tensor.dtype)
|
||||
if len(selected_loras.shape) == 4:
|
||||
selected_loras = selected_loras.squeeze(dim=1)
|
||||
inputs = inputs.to(dtype=output_tensor.dtype)
|
||||
outputs = torch.einsum("bi, boi -> bo", inputs, selected_loras)
|
||||
|
||||
output_tensor[:, :outputs.shape[1]] = scaling * outputs[:]
|
||||
|
||||
|
||||
def sgmv_expand_slice(inputs: torch.Tensor,
|
||||
lora_b_weights: torch.Tensor,
|
||||
output_tensor: torch.Tensor,
|
||||
b_seq_start_loc: torch.Tensor,
|
||||
seq_len_tensor: torch.Tensor,
|
||||
lora_indices_tensor: torch.Tensor,
|
||||
batches: int,
|
||||
max_seq_length: int,
|
||||
token_nums: int,
|
||||
slice_offset: int,
|
||||
slice_size: int,
|
||||
add_inputs: bool = False):
|
||||
exploded_indices = torch.repeat_interleave(lora_indices_tensor,
|
||||
seq_len_tensor)
|
||||
|
||||
bgmv_expand_slice(inputs, lora_b_weights, output_tensor, exploded_indices,
|
||||
slice_offset, slice_size, add_inputs)
|
||||
|
||||
|
||||
def bgmv_expand_slice(inputs: torch.Tensor,
|
||||
lora_b_weights: torch.Tensor,
|
||||
output_tensor: torch.Tensor,
|
||||
lora_indices_tensor: torch.Tensor,
|
||||
slice_offset: int,
|
||||
slice_size: int,
|
||||
add_inputs: bool = True):
|
||||
selected_loras = lora_b_weights[lora_indices_tensor].to(
|
||||
dtype=output_tensor.dtype)
|
||||
inputs = inputs.to(dtype=output_tensor.dtype)
|
||||
if len(selected_loras.shape) == 4:
|
||||
selected_loras = selected_loras.squeeze(dim=1)
|
||||
outputs = torch.einsum("bi, boi -> bo", inputs, selected_loras)
|
||||
|
||||
if add_inputs:
|
||||
output_tensor[:, slice_offset:slice_offset + slice_size] += outputs[:]
|
||||
else:
|
||||
output_tensor[:, slice_offset:slice_offset + slice_size] = outputs[:]
|
||||
Reference in New Issue
Block a user