Add get weights by parameter name for llama (#2266)
This commit is contained in:
@@ -20,13 +20,10 @@ import inspect
|
||||
import json
|
||||
import logging
|
||||
import pkgutil
|
||||
import time
|
||||
from functools import lru_cache
|
||||
from tokenize import tabsize
|
||||
from typing import Any, Optional, Type, Union
|
||||
from typing import Optional, Type
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch.nn as nn
|
||||
from vllm.config import DeviceConfig, LoadConfig
|
||||
from vllm.config import ModelConfig as VllmModelConfig
|
||||
@@ -403,6 +400,23 @@ class ModelRunner:
|
||||
logger.info("Update weights end.")
|
||||
return True, "Succeeded to update model weights."
|
||||
|
||||
def get_weights_by_name(
    self, name: str, truncate_size: int = 100
) -> Optional[torch.Tensor]:
    """Fetch a model parameter by name, mirroring Hugging Face's `get_parameter`.

    Intended for unit tests only; performance is not optimized.
    Prefer torch.save / torch.load when speed matters.
    """
    # TODO: (chenyang) Add support for Qwen models.
    try:
        # Delegate to the model, forwarding the tensor-parallel size so the
        # implementation can gather sharded weights.
        weights = self.model.get_weights_by_name(
            name, truncate_size, tp_size=self.tp_size
        )
    except Exception as e:
        # Best-effort helper: log and signal failure with None rather than raise.
        logger.error(f"Error when getting parameter {name}: {e}")
        return None
    return weights
|
||||
|
||||
def init_lora_manager(self):
|
||||
self.lora_manager = LoRAManager(
|
||||
base_model=self.model,
|
||||
|
||||
Reference in New Issue
Block a user