# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import functools
from dataclasses import dataclass

import torch
import torch.nn as nn
@dataclass
class LoRAMapping:
    """Per-batch LoRA index bookkeeping.

    Both mapping fields are coerced to tuples after construction, so callers
    may pass any iterable of ints (e.g. a list) and still get immutable,
    hashable-friendly tuples back.
    """

    index_mapping: tuple[int, ...]
    prompt_mapping: tuple[int, ...]
    is_prefill: bool = False

    def __post_init__(self):
        # Normalize every mapping field to an immutable tuple.
        for name in ("index_mapping", "prompt_mapping"):
            setattr(self, name, tuple(getattr(self, name)))
|
def _get_lora_device(base_layer: nn.Module) -> torch.device:
|
|
# code borrowed from https://github.com/fmmoret/vllm/blob/fm-support-lora-on-quantized-models/vllm/lora/layers.py#L34
|
|
"""Returns the device for where to place the LoRA tensors."""
|
|
# unquantizedLinear
|
|
if hasattr(base_layer, "weight"):
|
|
return base_layer.weight.device
|
|
# Compressed Tensor
|
|
elif hasattr(base_layer, "weight_packed"):
|
|
return base_layer.weight_packed.device
|
|
# GPTQ/AWQ
|
|
elif hasattr(base_layer, "qweight"):
|
|
return base_layer.qweight.device
|
|
# HQQ marlin
|
|
elif hasattr(base_layer, "W_q"):
|
|
return base_layer.W_q.device
|
|
# MoE layer
|
|
elif hasattr(base_layer, "w2_weight"):
|
|
return base_layer.w2_weight.device
|
|
# MoE Compressed Tensor
|
|
elif hasattr(base_layer, "w2_weight_packed"):
|
|
return base_layer.w2_weight_packed.device
|
|
# MoE GPTQ/AWQ/GGUF
|
|
elif hasattr(base_layer, "w2_qweight"):
|
|
return base_layer.w2_qweight.device
|
|
else:
|
|
raise ValueError(f"Unsupported base layer: {base_layer}")
|
|
|
|
|
|
def _not_fully_sharded_can_replace(can_replace):
|
|
"""
|
|
decorator which adds the condition of not using fully sharded loras
|
|
intended to wrap can_replace_layer()
|
|
"""
|
|
|
|
def dec(*args, **kwargs):
|
|
decorate = kwargs.pop("decorate") if "decorate" in kwargs else True
|
|
condition = not kwargs["lora_config"].fully_sharded_loras if decorate else True
|
|
return can_replace(*args, **kwargs) and condition
|
|
|
|
return dec
|
|
|
|
|
|
def _fully_sharded_can_replace(can_replace):
|
|
"""
|
|
decorator which adds the condition of fully sharded loras
|
|
intended to wrap can_replace_layer()
|
|
"""
|
|
|
|
def dec(*args, **kwargs):
|
|
return (
|
|
can_replace(*args, **kwargs) and kwargs["lora_config"].fully_sharded_loras
|
|
)
|
|
|
|
return dec
|