Add get weights by parameter name for Llama (#2266)

This commit is contained in:
Chayenne
2024-11-29 23:36:38 -08:00
committed by GitHub
parent 7d5d1d3d29
commit 7d1485d376
12 changed files with 337 additions and 17 deletions

View File

@@ -20,13 +20,10 @@ import inspect
import json
import logging
import pkgutil
import time
from functools import lru_cache
from tokenize import tabsize
from typing import Any, Optional, Type, Union
from typing import Optional, Type
import torch
import torch.distributed as dist
import torch.nn as nn
from vllm.config import DeviceConfig, LoadConfig
from vllm.config import ModelConfig as VllmModelConfig
@@ -403,6 +400,23 @@ class ModelRunner:
logger.info("Update weights end.")
return True, "Succeeded to update model weights."
def get_weights_by_name(
    self, name: str, truncate_size: int = 100
) -> Optional[torch.Tensor]:
    """Get the weights of the parameter by its name. Similar to `get_parameter` in Hugging Face.

    Only used for unit test with an unoptimized performance.
    For optimized performance, please use torch.save and torch.load.

    Args:
        name: Fully qualified parameter name to look up on the wrapped model.
        truncate_size: Upper bound on how much of the weight to return
            (keeps test payloads small); defaults to 100.

    Returns:
        Whatever the model's `get_weights_by_name` returns for the parameter,
        or None if the lookup failed for any reason.
    """
    # TODO: (chenyang) Add support for Qwen models.
    try:
        # Delegate to the model; tp_size is forwarded so the model can
        # account for tensor parallelism — presumably to gather sharded
        # weights, confirm against the model implementation.
        return self.model.get_weights_by_name(
            name, truncate_size, tp_size=self.tp_size
        )
    except Exception as e:
        # Best-effort API: any failure is logged and signalled via None
        # rather than propagated to the caller.
        # Lazy %-style args avoid formatting when the log level is disabled.
        logger.error("Error when getting parameter %s: %s", name, e)
        return None
def init_lora_manager(self):
self.lora_manager = LoRAManager(
base_model=self.model,