Sync from v0.13
This commit is contained in:
74
vllm/lora/layers/utils.py
Normal file
74
vllm/lora/layers/utils.py
Normal file
@@ -0,0 +1,74 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from dataclasses import dataclass
from functools import wraps

import torch
import torch.nn as nn
|
||||
|
||||
|
||||
@dataclass
class LoRAMapping:
    """Mapping of request tokens/prompts onto LoRA adapter indices.

    Whatever iterables the caller provides for the two mappings are
    normalized to tuples on construction.
    """

    # Per-token adapter index mapping.
    index_mapping: tuple[int, ...]
    # Per-prompt adapter index mapping.
    prompt_mapping: tuple[int, ...]
    # Whether this mapping was built for a prefill step.
    is_prefill: bool = False

    def __post_init__(self):
        # Coerce both mappings to tuples so they are immutable regardless
        # of what sequence type the caller passed in.
        self.prompt_mapping = tuple(self.prompt_mapping)
        self.index_mapping = tuple(self.index_mapping)
|
||||
|
||||
|
||||
def _get_lora_device(base_layer: nn.Module) -> torch.device:
|
||||
# code borrowed from https://github.com/fmmoret/vllm/blob/fm-support-lora-on-quantized-models/vllm/lora/layers.py#L34
|
||||
"""Returns the device for where to place the LoRA tensors."""
|
||||
# unquantizedLinear
|
||||
if hasattr(base_layer, "weight"):
|
||||
return base_layer.weight.device
|
||||
# Compressed Tensor
|
||||
elif hasattr(base_layer, "weight_packed"):
|
||||
return base_layer.weight_packed.device
|
||||
# GPTQ/AWQ
|
||||
elif hasattr(base_layer, "qweight"):
|
||||
return base_layer.qweight.device
|
||||
# HQQ marlin
|
||||
elif hasattr(base_layer, "W_q"):
|
||||
return base_layer.W_q.device
|
||||
# MoE layer
|
||||
elif hasattr(base_layer, "w2_weight"):
|
||||
return base_layer.w2_weight.device
|
||||
# MoE Compressed Tensor
|
||||
elif hasattr(base_layer, "w2_weight_packed"):
|
||||
return base_layer.w2_weight_packed.device
|
||||
# MoE GPTQ/AWQ/GGUF
|
||||
elif hasattr(base_layer, "w2_qweight"):
|
||||
return base_layer.w2_qweight.device
|
||||
else:
|
||||
raise ValueError(f"Unsupported base layer: {base_layer}")
|
||||
|
||||
|
||||
def _not_fully_sharded_can_replace(can_replace):
|
||||
"""
|
||||
decorator which adds the condition of not using fully sharded loras
|
||||
intended to wrap can_replace_layer()
|
||||
"""
|
||||
|
||||
def dec(*args, **kwargs):
|
||||
decorate = kwargs.pop("decorate") if "decorate" in kwargs else True
|
||||
condition = not kwargs["lora_config"].fully_sharded_loras if decorate else True
|
||||
return can_replace(*args, **kwargs) and condition
|
||||
|
||||
return dec
|
||||
|
||||
|
||||
def _fully_sharded_can_replace(can_replace):
|
||||
"""
|
||||
decorator which adds the condition of fully sharded loras
|
||||
intended to wrap can_replace_layer()
|
||||
"""
|
||||
|
||||
def dec(*args, **kwargs):
|
||||
return (
|
||||
can_replace(*args, **kwargs) and kwargs["lora_config"].fully_sharded_loras
|
||||
)
|
||||
|
||||
return dec
|
||||
Reference in New Issue
Block a user