Files
enginex-mthreads-vllm/vllm/lora/lora_model.py
2026-01-19 10:38:50 +08:00

247 lines
9.7 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import os
import safetensors
import torch
from vllm.logger import init_logger
from vllm.lora.lora_weights import LoRALayerWeights
from vllm.lora.peft_helper import PEFTHelper
from vllm.lora.utils import (
get_lora_id,
is_base_embeddding_weights,
is_regex_target_modules,
parse_fine_tuned_lora_name,
)
from vllm.model_executor.model_loader.tensorizer import TensorizerConfig
from vllm.model_executor.models.utils import WeightsMapper
from vllm.utils.platform_utils import is_pin_memory_available
# Module-level logger, named after this module.
logger = init_logger(__name__)
class LoRAModel:
    """A LoRA fine-tuned model.

    Holds the adapter's integer id, its rank, and a mapping from module
    name to the ``LoRALayerWeights`` that replace/augment that module.
    """

    def __init__(
        self,
        lora_model_id: int,
        rank: int,
        loras: dict[str, LoRALayerWeights],
    ) -> None:
        """
        Args:
            lora_model_id: The integer id for the lora model. Must be > 0
                (checked with an ``assert``).
            rank: lora rank.
            loras: module name -> weights for lora-replaced layers.
        """
        self.id = lora_model_id
        assert lora_model_id > 0, (
            f"a valid lora id should be greater than 0, got {self.id}"
        )
        self.rank = rank
        self.loras: dict[str, LoRALayerWeights] = loras

    def clone(self, lora_model_id: int) -> "LoRAModel":
        """Return a copy of the object with a different id.

        Only the ``loras`` dict itself is copied (shallow copy), so the clone
        shares the underlying ``LoRALayerWeights`` objects and their tensors
        with this instance.
        """
        return self.__class__(
            lora_model_id,
            rank=self.rank,
            loras=self.loras.copy(),
        )

    def get_lora(self, module_name: str) -> LoRALayerWeights | None:
        """Get the LoRA weights for ``module_name``, or ``None`` if absent."""
        return self.loras.get(module_name)

    def check_lora_name(self, lora_name: str) -> bool:
        """Return whether this model holds LoRA weights for ``lora_name``."""
        return lora_name in self.loras

    @classmethod
    def from_lora_tensors(
        cls,
        lora_model_id: int,
        tensors: dict[str, torch.Tensor],
        peft_helper: PEFTHelper,
        device: str = "cuda",
        dtype: torch.dtype | None = None,
        model_vocab_size: int | None = None,
        weights_mapper: WeightsMapper | None = None,
    ) -> "LoRAModel":
        """Create a LoRAModel from a dictionary of tensors.

        Args:
            lora_model_id: Integer id for the resulting model (must be > 0).
            tensors: Tensor name -> tensor, as read from an adapter checkpoint.
            peft_helper: Loaded LoRA configuration. Each module's weights are
                built with ``LoRALayerWeights.from_config`` and
                ``peft_helper.r`` becomes the model rank.
            device: Device passed to ``Tensor.to``. When it is ``"cpu"`` and
                pinned memory is available, tensors are also pinned.
            dtype: dtype passed to ``Tensor.to`` (``None`` leaves the dtype
                unchanged).
            model_vocab_size: Vocabulary size of the base model. If given,
                dim 1 of every ``lora_embedding_A`` tensor must equal it.
            weights_mapper: Optional mapper used when parsing tensor names.

        Returns:
            The assembled LoRAModel.

        Raises:
            RuntimeError: If an embedding LoRA's size does not match
                ``model_vocab_size``.
        """
        pin_memory = str(device) == "cpu" and is_pin_memory_available()
        loras: dict[str, LoRALayerWeights] = {}
        for tensor_name, tensor in tensors.items():
            # Skip tensors flagged by is_base_embeddding_weights (presumably
            # base-model embedding weights saved next to the adapter rather
            # than LoRA A/B matrices).
            if is_base_embeddding_weights(tensor_name):
                continue
            module_name, is_lora_a = parse_fine_tuned_lora_name(
                tensor_name, weights_mapper
            )
            if module_name not in loras:
                loras[module_name] = LoRALayerWeights.from_config(
                    module_name, peft_helper
                )
            layer = loras[module_name]

            if (
                is_lora_a
                and "lora_embedding_A" in tensor_name
                and model_vocab_size is not None
                and model_vocab_size != tensor.shape[1]
            ):
                raise RuntimeError(
                    f"The embedding LoRA size({tensor.shape[1]}) must be consistent"
                    f" with the base model's vocabulary size({model_vocab_size})."
                )

            # Move/cast once, optionally pin, then store as A or B.
            weight = tensor.to(device=device, dtype=dtype)
            if pin_memory:
                weight = weight.pin_memory()
            if is_lora_a:
                layer.lora_a = weight
            else:
                layer.lora_b = weight
        return cls(lora_model_id, peft_helper.r, loras)

    @classmethod
    def from_local_checkpoint(
        cls,
        lora_dir: str,
        expected_lora_modules: set[str],
        peft_helper: PEFTHelper,
        *,
        lora_model_id: int | None = None,
        device: str = "cuda",
        dtype: torch.dtype | None = None,
        model_vocab_size: int | None = None,
        weights_mapper: WeightsMapper | None = None,
        tensorizer_config_dict: dict | None = None,
    ) -> "LoRAModel":
        """Create a LoRAModel from a local checkpoint.

        Tensors are read from the first source that applies:

        1. ``<tensorizer_dir>/adapter_model.tensors`` when
           ``tensorizer_config_dict`` is non-empty;
        2. ``<lora_dir>/adapter_model.safetensors``;
        3. ``<lora_dir>/adapter_model.bin``, else ``<lora_dir>/adapter_model.pt``.

        For (1) and (2) the tensor names are validated against
        ``expected_lora_modules``. For (3) only ``peft_helper.target_modules``
        is validated, because the tensor names are not known before loading.

        Args:
            lora_dir: The local path that has lora data.
            expected_lora_modules: Name of modules that are expected to be
                replaced by lora.
            peft_helper: Loaded lora configuration information.
            lora_model_id: LoRA model id. If not given, automatically set by
                a global counter.
            device: Device where the lora model is loaded.
            dtype: dtype of the lora model weights.
            model_vocab_size: Vocabulary size of the base model, forwarded
                to ``from_lora_tensors`` for embedding-size validation.
            weights_mapper: Optional mapper used when parsing tensor names.
            tensorizer_config_dict: Keyword arguments for ``TensorizerConfig``.
                When non-empty, tensors are loaded with tensorizer instead.

        Returns:
            Loaded LoRA Model.

        Raises:
            ValueError: If the adapter targets modules outside
                ``expected_lora_modules``, or if no tensor file is found.
        """
        lora_tensor_path = os.path.join(lora_dir, "adapter_model.safetensors")
        lora_bin_file_path = os.path.join(lora_dir, "adapter_model.bin")
        lora_pt_file_path = os.path.join(lora_dir, "adapter_model.pt")
        tensors: dict[str, torch.Tensor] = {}

        def unexpected_modules_error(unexpected: list[str]) -> ValueError:
            """Build the error for adapter modules the model cannot replace."""
            return ValueError(
                f"While loading {lora_dir}, expected"
                f" target modules in {expected_lora_modules}"
                f" but received {unexpected}."
                f" Please verify that the loaded LoRA module is correct"
            )

        def check_unexpected_modules(modules) -> None:
            """Raise if any tensor name maps to an unexpected module.

            ``modules`` is any object exposing ``keys()`` that yields tensor
            names (a safetensors handle or a tensorizer deserializer).
            """
            unexpected: list[str] = []
            # .keys() is required: the safetensors handle is not a dict.
            for lora_module in modules.keys():  # noqa
                if is_base_embeddding_weights(lora_module):
                    continue
                # Handle PEFT file format where experts.base_layer is the
                # gate_up_proj and experts is the down_proj
                if "base_layer" in lora_module:
                    continue
                module_name, _ = parse_fine_tuned_lora_name(
                    lora_module, weights_mapper
                )
                if ".experts" in module_name:
                    # Expert weights are matched on everything from
                    # "experts" onward, not just the last component.
                    expert_idx = module_name.find(".experts")
                    suffix = module_name[expert_idx + 1 :]
                else:
                    suffix = module_name.rsplit(".", 1)[-1]
                if suffix not in expected_lora_modules:
                    unexpected.append(module_name)
            if unexpected:
                raise unexpected_modules_error(unexpected)

        if tensorizer_config_dict:
            from tensorizer import TensorDeserializer

            tensorizer_config = TensorizerConfig(**tensorizer_config_dict)
            lora_tensor_path = os.path.join(
                tensorizer_config.tensorizer_dir, "adapter_model.tensors"
            )
            tensorizer_args = tensorizer_config._construct_tensorizer_args()
            tensors = TensorDeserializer(
                lora_tensor_path,
                dtype=tensorizer_config.dtype,
                **tensorizer_args.deserialization_kwargs,
            )
            check_unexpected_modules(tensors)
        elif os.path.isfile(lora_tensor_path):
            # Use the safetensors keys as the source of truth for which
            # modules the adapter really contains. PEFT does not error when
            # a target module is missing from the model: with target_modules
            # A, B, C and no C in the model, only A and B get LoRA weights,
            # yet C still appears in adapter_config.json's target_modules.
            with safetensors.safe_open(lora_tensor_path, framework="pt") as f:  # type: ignore
                # Load tensors only after confirming every module is expected.
                check_unexpected_modules(f)
                for module in f.keys():  # noqa
                    tensors[module] = f.get_tensor(module)
        elif os.path.isfile(lora_bin_file_path) or os.path.isfile(lora_pt_file_path):
            # With a bin/pt file we rely on the config to find unexpected
            # modules.
            target_modules = peft_helper.target_modules
            if not isinstance(target_modules, list):
                target_modules = [target_modules]
            # Compare only the last dotted component so fully-qualified
            # names such as "layers.11.self_attn.k_proj" are accepted.
            unexpected_modules = [
                module
                for module in target_modules
                if module.split(".")[-1] not in expected_lora_modules
            ]
            # The loaded LoRA's target modules must be a subset of
            # expected_lora_modules. This is not reliable (see
            # https://github.com/vllm-project/vllm/pull/5909), but there is
            # no better mechanism.
            if unexpected_modules and not is_regex_target_modules(
                peft_helper.target_modules, expected_lora_modules
            ):
                raise unexpected_modules_error(unexpected_modules)
            lora_file_path = (
                lora_bin_file_path
                if os.path.isfile(lora_bin_file_path)
                else lora_pt_file_path
            )
            tensors = torch.load(
                lora_file_path, map_location=device, weights_only=True
            )
        else:
            raise ValueError(f"{lora_dir} doesn't contain tensors")

        return cls.from_lora_tensors(
            lora_model_id=get_lora_id() if lora_model_id is None else lora_model_id,
            tensors=tensors,
            peft_helper=peft_helper,
            device=device,
            dtype=dtype,
            model_vocab_size=model_vocab_size,
            weights_mapper=weights_mapper,
        )