update

vllm/lora/layers/utils.py (new file, 112 lines)
@@ -0,0 +1,112 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from dataclasses import dataclass
from enum import Enum

import torch
import torch.nn as nn

from vllm.model_executor.layers.fused_moe.fused_moe import try_get_optimal_moe_config
from vllm.utils.math_utils import next_power_of_2


class LoRAMappingType(Enum):
    LANGUAGE = 1
    TOWER = 2
    CONNECTOR = 3


@dataclass
class LoRAMapping:
    index_mapping: tuple[int, ...]
    prompt_mapping: tuple[int, ...]
    is_prefill: bool = False
    type: LoRAMappingType = LoRAMappingType.LANGUAGE

    def __post_init__(self):
        # Coerce possibly list-valued inputs into tuples.
        self.index_mapping = tuple(self.index_mapping)
        self.prompt_mapping = tuple(self.prompt_mapping)


def _get_lora_device(base_layer: nn.Module) -> torch.device:
    # code borrowed from https://github.com/fmmoret/vllm/blob/fm-support-lora-on-quantized-models/vllm/lora/layers.py#L34
    """Returns the device on which to place the LoRA tensors."""
    # Unquantized linear
    if hasattr(base_layer, "weight"):
        return base_layer.weight.device
    # Compressed-tensors
    elif hasattr(base_layer, "weight_packed"):
        return base_layer.weight_packed.device
    # GPTQ/AWQ
    elif hasattr(base_layer, "qweight"):
        return base_layer.qweight.device
    # MoE layer
    elif hasattr(base_layer, "w2_weight"):
        return base_layer.w2_weight.device
    # MoE compressed-tensors
    elif hasattr(base_layer, "w2_weight_packed"):
        return base_layer.w2_weight_packed.device
    # MoE GPTQ/AWQ/GGUF
    elif hasattr(base_layer, "w2_qweight"):
        return base_layer.w2_qweight.device
    else:
        raise ValueError(f"Unsupported base layer: {base_layer}")
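A minimal usage sketch (illustrative, not part of this file), assuming a LoRA wrapper that allocates its A/B matrices next to the wrapped base layer's weights; base_layer, rank, in_features and out_features are placeholders:

# Usage sketch (illustrative): place the LoRA tensors on the base layer's device.
device = _get_lora_device(base_layer)
lora_a = torch.zeros(rank, in_features, dtype=torch.float16, device=device)
lora_b = torch.zeros(out_features, rank, dtype=torch.float16, device=device)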


def _not_fully_sharded_can_replace(can_replace):
    """
    Decorator which adds the condition that fully sharded LoRAs are not in use.

    Intended to wrap can_replace_layer().
    """

    def dec(*args, **kwargs):
        # Passing decorate=False skips the fully-sharded check entirely.
        decorate = kwargs.pop("decorate") if "decorate" in kwargs else True
        condition = not kwargs["lora_config"].fully_sharded_loras if decorate else True
        return can_replace(*args, **kwargs) and condition

    return dec


def _fully_sharded_can_replace(can_replace):
    """
    Decorator which adds the condition that fully sharded LoRAs are in use.

    Intended to wrap can_replace_layer().
    """

    def dec(*args, **kwargs):
        return (
            can_replace(*args, **kwargs) and kwargs["lora_config"].fully_sharded_loras
        )

    return dec
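A hedged sketch of how these decorators are meant to wrap can_replace_layer(); the layer class and argument names below are illustrative assumptions, not taken from this file. Note that both wrappers read kwargs["lora_config"], so lora_config must be passed by keyword:

# Usage sketch (illustrative): a non-fully-sharded LoRA layer opts in to
# replacement only when fully sharded LoRAs are disabled; its fully sharded
# counterpart would use @_fully_sharded_can_replace instead.
class MyLinearWithLoRA:  # hypothetical layer class
    @classmethod
    @_not_fully_sharded_can_replace
    def can_replace_layer(cls, source_layer, *, lora_config, packed_modules_list):
        return isinstance(source_layer, nn.Linear)

# MyLinearWithLoRA.can_replace_layer(layer, lora_config=cfg, packed_modules_list=[])
# is True only if the base check passes and cfg.fully_sharded_loras is False.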


def try_get_optimal_moe_lora_config(
    op_type: str,
    w1_shape: tuple[int, ...],
    w2_shape: tuple[int, ...],
    rank: int,
    top_k: int,
    dtype: str | None,
    M: int,
    block_shape: list[int] | None = None,
) -> dict[str, int | None]:
    """Derive a fused-MoE kernel config, adjusting block sizes for the LoRA rank."""
    config = try_get_optimal_moe_config(
        w1_shape, w2_shape, top_k, dtype, M, block_shape
    ).copy()
    if op_type in [
        "fused_moe_lora_w13_shrink",
        "fused_moe_lora_w2_shrink",
    ]:
        # Cap BLOCK_SIZE_N at the next power of two of the LoRA rank.
        config["BLOCK_SIZE_N"] = min(
            config.get("BLOCK_SIZE_N", 64), next_power_of_2(rank)
        )
    elif op_type in [
        "fused_moe_lora_w13_expand",
        "fused_moe_lora_w2_expand",
    ]:
        # Reduce BLOCK_SIZE_K toward the next power of two of the LoRA rank,
        # but never below 16.
        config["BLOCK_SIZE_K"] = max(
            16, min(config.get("BLOCK_SIZE_K", 32), next_power_of_2(rank))
        )
    return config
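A hedged usage sketch: for the LoRA shrink ops the tuned BLOCK_SIZE_N is capped at next_power_of_2(rank); for the expand ops BLOCK_SIZE_K is reduced to next_power_of_2(rank) when that is smaller, but never below 16. The expert shapes, rank and token count below are illustrative assumptions:

# Usage sketch (illustrative): with rank=16, next_power_of_2(16) == 16, so any
# tuned BLOCK_SIZE_N above 16 is reduced to 16 for the w13 shrink kernel.
config = try_get_optimal_moe_lora_config(
    op_type="fused_moe_lora_w13_shrink",
    w1_shape=(8, 4096, 2048),  # assumed expert weight shape
    w2_shape=(8, 2048, 2048),  # assumed expert weight shape
    rank=16,
    top_k=2,
    dtype=None,
    M=256,
)
# config["BLOCK_SIZE_N"] is now at most 16; the other tuned values are unchanged.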