Add get weights by parameter name for llama (#2266)
This commit is contained in:
@@ -365,6 +365,17 @@ class UpdateWeightFromDiskReqOutput:
|
||||
message: str
|
||||
|
||||
|
||||
@dataclass
|
||||
class GetWeightsByNameReqInput:
|
||||
name: str
|
||||
truncate_size: int = 100
|
||||
|
||||
|
||||
@dataclass
|
||||
class GetWeightsByNameReqOutput:
|
||||
parameter: list
|
||||
|
||||
|
||||
@dataclass
|
||||
class AbortReq:
|
||||
# The request id
|
||||
|
||||
@@ -38,6 +38,8 @@ from sglang.srt.managers.io_struct import (
|
||||
BatchTokenIDOut,
|
||||
CloseSessionReqInput,
|
||||
FlushCacheReq,
|
||||
GetWeightsByNameReqInput,
|
||||
GetWeightsByNameReqOutput,
|
||||
OpenSessionReqInput,
|
||||
OpenSessionReqOutput,
|
||||
ProfileReq,
|
||||
@@ -511,6 +513,9 @@ class Scheduler:
|
||||
self.send_to_tokenizer.send_pyobj(
|
||||
UpdateWeightFromDiskReqOutput(success, message)
|
||||
)
|
||||
elif isinstance(recv_req, GetWeightsByNameReqInput):
|
||||
parameter = self.get_weights_by_name(recv_req)
|
||||
self.send_to_tokenizer.send_pyobj(GetWeightsByNameReqOutput(parameter))
|
||||
elif isinstance(recv_req, ProfileReq):
|
||||
if recv_req == ProfileReq.START_PROFILE:
|
||||
self.start_profile()
|
||||
@@ -1373,6 +1378,10 @@ class Scheduler:
|
||||
logger.error(message)
|
||||
return success, message
|
||||
|
||||
def get_weights_by_name(self, recv_req: GetWeightsByNameReqInput):
|
||||
parameter = self.tp_worker.get_weights_by_name(recv_req)
|
||||
return parameter
|
||||
|
||||
def start_profile(self) -> None:
|
||||
if self.profiler is None:
|
||||
raise RuntimeError("Profiler is not enabled.")
|
||||
|
||||
@@ -46,6 +46,8 @@ from sglang.srt.managers.io_struct import (
|
||||
EmbeddingReqInput,
|
||||
FlushCacheReq,
|
||||
GenerateReqInput,
|
||||
GetWeightsByNameReqInput,
|
||||
GetWeightsByNameReqOutput,
|
||||
OpenSessionReqInput,
|
||||
OpenSessionReqOutput,
|
||||
ProfileReq,
|
||||
@@ -454,6 +456,23 @@ class TokenizerManager:
|
||||
else:
|
||||
return False, "Another update is in progress. Please try again later."
|
||||
|
||||
async def get_weights_by_name(
|
||||
self, obj: GetWeightsByNameReqInput, request: Optional[fastapi.Request] = None
|
||||
):
|
||||
if self.to_create_loop:
|
||||
self.create_handle_loop()
|
||||
|
||||
self.send_to_scheduler.send_pyobj(obj)
|
||||
self.get_weights_by_name_result = asyncio.Future()
|
||||
if self.server_args.dp_size == 1:
|
||||
result = await self.get_weights_by_name_result
|
||||
return result.parameter
|
||||
else:
|
||||
self.get_weights_by_name_tmp = []
|
||||
result = await self.get_weights_by_name_result
|
||||
all_parameters = [r.parameter for r in result]
|
||||
return all_parameters
|
||||
|
||||
async def open_session(
|
||||
self, obj: OpenSessionReqInput, request: Optional[fastapi.Request] = None
|
||||
):
|
||||
@@ -527,6 +546,7 @@ class TokenizerManager:
|
||||
BatchEmbeddingOut,
|
||||
BatchTokenIDOut,
|
||||
UpdateWeightFromDiskReqOutput,
|
||||
GetWeightsByNameReqOutput,
|
||||
] = await self.recv_from_detokenizer.recv_pyobj()
|
||||
|
||||
if isinstance(recv_obj, UpdateWeightFromDiskReqOutput):
|
||||
@@ -538,6 +558,16 @@ class TokenizerManager:
|
||||
if len(self.model_update_tmp) == self.server_args.dp_size:
|
||||
self.model_update_result.set_result(self.model_update_tmp)
|
||||
continue
|
||||
elif isinstance(recv_obj, GetWeightsByNameReqOutput):
|
||||
if self.server_args.dp_size == 1:
|
||||
self.get_weights_by_name_result.set_result(recv_obj)
|
||||
else:
|
||||
self.get_weights_by_name_tmp.append(recv_obj)
|
||||
if len(self.get_weights_by_name_tmp) == self.server_args.dp_size:
|
||||
self.get_weights_by_name_result.set_result(
|
||||
self.get_weights_by_name_tmp
|
||||
)
|
||||
continue
|
||||
elif isinstance(recv_obj, OpenSessionReqOutput):
|
||||
self.session_futures[recv_obj.session_id].set_result(
|
||||
recv_obj.session_id
|
||||
|
||||
@@ -19,7 +19,10 @@ from typing import Optional
|
||||
|
||||
from sglang.srt.configs.model_config import ModelConfig
|
||||
from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
|
||||
from sglang.srt.managers.io_struct import UpdateWeightFromDiskReqInput
|
||||
from sglang.srt.managers.io_struct import (
|
||||
GetWeightsByNameReqInput,
|
||||
UpdateWeightFromDiskReqInput,
|
||||
)
|
||||
from sglang.srt.managers.schedule_batch import ModelWorkerBatch, global_server_args_dict
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
||||
from sglang.srt.model_executor.model_runner import ModelRunner
|
||||
@@ -160,3 +163,9 @@ class TpModelWorker:
|
||||
recv_req.model_path, recv_req.load_format
|
||||
)
|
||||
return success, message
|
||||
|
||||
def get_weights_by_name(self, recv_req: GetWeightsByNameReqInput):
|
||||
parameter = self.model_runner.get_weights_by_name(
|
||||
recv_req.name, recv_req.truncate_size
|
||||
)
|
||||
return parameter
|
||||
|
||||
@@ -23,7 +23,10 @@ from typing import Optional
|
||||
import psutil
|
||||
import torch
|
||||
|
||||
from sglang.srt.managers.io_struct import UpdateWeightFromDiskReqInput
|
||||
from sglang.srt.managers.io_struct import (
|
||||
GetWeightsByNameReqInput,
|
||||
UpdateWeightFromDiskReqInput,
|
||||
)
|
||||
from sglang.srt.managers.schedule_batch import ModelWorkerBatch
|
||||
from sglang.srt.managers.tp_worker import TpModelWorker
|
||||
from sglang.srt.server_args import ServerArgs
|
||||
@@ -208,6 +211,9 @@ class TpModelWorkerClient:
|
||||
success, message = self.worker.update_weights_from_disk(recv_req)
|
||||
return success, message
|
||||
|
||||
def get_weights_by_name(self, recv_req: GetWeightsByNameReqInput):
|
||||
return self.worker.get_weights_by_name(recv_req)
|
||||
|
||||
def __delete__(self):
|
||||
self.input_queue.put((None, None))
|
||||
self.copy_queue.put((None, None, None))
|
||||
|
||||
@@ -20,13 +20,10 @@ import inspect
|
||||
import json
|
||||
import logging
|
||||
import pkgutil
|
||||
import time
|
||||
from functools import lru_cache
|
||||
from tokenize import tabsize
|
||||
from typing import Any, Optional, Type, Union
|
||||
from typing import Optional, Type
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch.nn as nn
|
||||
from vllm.config import DeviceConfig, LoadConfig
|
||||
from vllm.config import ModelConfig as VllmModelConfig
|
||||
@@ -403,6 +400,23 @@ class ModelRunner:
|
||||
logger.info("Update weights end.")
|
||||
return True, "Succeeded to update model weights."
|
||||
|
||||
def get_weights_by_name(
|
||||
self, name: str, truncate_size: int = 100
|
||||
) -> Optional[torch.Tensor]:
|
||||
"""Get the weights of the parameter by its name. Similar to `get_parameter` in Hugging Face.
|
||||
|
||||
Only used for unit test with an unoptimized performance.
|
||||
For optimized performance, please use torch.save and torch.load.
|
||||
"""
|
||||
# TODO: (chenyang) Add support for Qwen models.
|
||||
try:
|
||||
return self.model.get_weights_by_name(
|
||||
name, truncate_size, tp_size=self.tp_size
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error when getting parameter {name}: {e}")
|
||||
return None
|
||||
|
||||
def init_lora_manager(self):
|
||||
self.lora_manager = LoRAManager(
|
||||
base_model=self.model,
|
||||
|
||||
@@ -16,6 +16,7 @@
|
||||
# https://github.com/vllm-project/vllm/blob/c7f2cf2b7f67bce5842fedfdba508440fe257375/vllm/model_executor/models/llama.py#L1
|
||||
"""Inference-only LLaMA model compatible with HuggingFace weights."""
|
||||
|
||||
import logging
|
||||
from typing import Any, Dict, Iterable, Optional, Tuple
|
||||
|
||||
import torch
|
||||
@@ -45,6 +46,8 @@ from sglang.srt.managers.schedule_batch import global_server_args_dict
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
||||
from sglang.srt.utils import make_layers
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class LlamaMLP(nn.Module):
|
||||
def __init__(
|
||||
@@ -305,6 +308,14 @@ class LlamaForCausalLM(nn.Module):
|
||||
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
|
||||
self.logits_processor = LogitsProcessor(config)
|
||||
self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
|
||||
self.stacked_params_mapping = [
|
||||
# (param_name, shard_name, shard_id)
|
||||
(".qkv_proj", ".q_proj", "q"),
|
||||
(".qkv_proj", ".k_proj", "k"),
|
||||
(".qkv_proj", ".v_proj", "v"),
|
||||
(".gate_up_proj", ".gate_proj", 0),
|
||||
(".gate_up_proj", ".up_proj", 1),
|
||||
]
|
||||
|
||||
@torch.no_grad()
|
||||
def forward(
|
||||
@@ -349,15 +360,7 @@ class LlamaForCausalLM(nn.Module):
|
||||
return params_mapping.get(name, name)
|
||||
|
||||
def get_module_name_from_weight_name(self, name):
|
||||
stacked_params_mapping = [
|
||||
# (param_name, shard_name, shard_id, num_shard)
|
||||
("qkv_proj", "q_proj", "q", 3),
|
||||
("qkv_proj", "k_proj", "k", 3),
|
||||
("qkv_proj", "v_proj", "v", 3),
|
||||
("gate_up_proj", "gate_proj", 0, 2),
|
||||
("gate_up_proj", "up_proj", 1, 2),
|
||||
]
|
||||
for param_name, weight_name, shard_id, num_shard in stacked_params_mapping:
|
||||
for param_name, weight_name, shard_id, num_shard in self.stacked_params_mapping:
|
||||
if weight_name in name:
|
||||
return (
|
||||
name.replace(weight_name, param_name)[: -len(".weight")],
|
||||
@@ -370,6 +373,7 @@ class LlamaForCausalLM(nn.Module):
|
||||
return len(params_dict)
|
||||
|
||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||
embed_tokens_weight = None
|
||||
stacked_params_mapping = [
|
||||
# (param_name, shard_name, shard_id)
|
||||
(".qkv_proj", ".q_proj", "q"),
|
||||
@@ -378,6 +382,7 @@ class LlamaForCausalLM(nn.Module):
|
||||
(".gate_up_proj", ".gate_proj", 0),
|
||||
(".gate_up_proj", ".up_proj", 1),
|
||||
]
|
||||
|
||||
params_dict = dict(self.named_parameters())
|
||||
|
||||
load_tie_word_embeddings = (
|
||||
@@ -425,10 +430,79 @@ class LlamaForCausalLM(nn.Module):
|
||||
# Tie output embedding layer to input embedding layer, to solve issues where lm_head.weight is missing
|
||||
param = self.lm_head.weight
|
||||
weight_loader = getattr(param, "weight_loader", default_weight_loader)
|
||||
weight_loader(param, embed_tokens_weight)
|
||||
if embed_tokens_weight is not None:
|
||||
weight_loader(param, embed_tokens_weight)
|
||||
|
||||
apply_torchao_config_(self, params_dict, set(["proj.weight"]))
|
||||
|
||||
def get_weights_by_name(
|
||||
self, name: str, truncate_size: int = 100, tp_size: int = 1
|
||||
) -> Optional[torch.Tensor]:
|
||||
"""Get the weights of the parameter by its name. Similar to `get_parameter` in Hugging Face.
|
||||
|
||||
Only used for unit test with an unoptimized performance.
|
||||
For optimized performance, please use torch.save and torch.load.
|
||||
"""
|
||||
try:
|
||||
mapped_name = name
|
||||
mapped_shard_id = None
|
||||
for param_name, weight_name, shard_id in self.stacked_params_mapping:
|
||||
if weight_name in name:
|
||||
mapped_name = name.replace(weight_name, param_name)
|
||||
mapped_shard_id = shard_id
|
||||
break
|
||||
params_dict = dict(self.named_parameters())
|
||||
if mapped_name in params_dict:
|
||||
param = params_dict[mapped_name]
|
||||
if mapped_shard_id is not None:
|
||||
if mapped_shard_id in ["q", "k", "v"]:
|
||||
num_heads = self.config.num_attention_heads // tp_size
|
||||
num_kv_heads = self.config.num_key_value_heads // tp_size
|
||||
head_dim = (
|
||||
self.config.hidden_size // self.config.num_attention_heads
|
||||
)
|
||||
if mapped_shard_id == "q":
|
||||
offset = 0
|
||||
size = num_heads * head_dim
|
||||
elif mapped_shard_id == "k":
|
||||
offset = num_heads * head_dim
|
||||
size = num_kv_heads * head_dim
|
||||
elif mapped_shard_id == "v":
|
||||
offset = (num_heads + num_kv_heads) * head_dim
|
||||
size = num_kv_heads * head_dim
|
||||
weight = param.data.narrow(0, offset, size)
|
||||
elif mapped_shard_id in [0, 1]:
|
||||
intermediate_size = self.config.intermediate_size
|
||||
hidden_size = self.config.hidden_size
|
||||
slice_size = intermediate_size // tp_size
|
||||
if mapped_shard_id == 0: # gate_proj
|
||||
offset = 0
|
||||
size = slice_size
|
||||
elif mapped_shard_id == 1: # up_proj
|
||||
offset = slice_size
|
||||
size = slice_size
|
||||
|
||||
weight = param.data.narrow(0, offset, size)
|
||||
else:
|
||||
weight = param.data
|
||||
else:
|
||||
weight = param.data
|
||||
if tp_size > 1 and ("o_proj" in name or "down_proj" in name):
|
||||
gathered_weights = [
|
||||
torch.zeros_like(weight) for _ in range(tp_size)
|
||||
]
|
||||
torch.distributed.all_gather(gathered_weights, weight)
|
||||
weight = torch.cat(gathered_weights, dim=1)
|
||||
return weight.cpu().to(torch.float32).numpy().tolist()[:truncate_size]
|
||||
else:
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Error getting weights by name {name} in LlamaForCausalLM: {e}"
|
||||
)
|
||||
return None
|
||||
|
||||
|
||||
class Phi3ForCausalLM(LlamaForCausalLM):
|
||||
pass
|
||||
|
||||
@@ -52,6 +52,7 @@ from sglang.srt.managers.io_struct import (
|
||||
CloseSessionReqInput,
|
||||
EmbeddingReqInput,
|
||||
GenerateReqInput,
|
||||
GetWeightsByNameReqInput,
|
||||
OpenSessionReqInput,
|
||||
UpdateWeightFromDiskReqInput,
|
||||
)
|
||||
@@ -210,6 +211,24 @@ async def update_weights_from_disk(obj: UpdateWeightFromDiskReqInput, request: R
|
||||
)
|
||||
|
||||
|
||||
@app.api_route("/get_weights_by_name", methods=["GET", "POST"])
|
||||
async def get_weights_by_name(obj: GetWeightsByNameReqInput, request: Request):
|
||||
"""Get model parameter by name."""
|
||||
try:
|
||||
ret = await tokenizer_manager.get_weights_by_name(obj, request)
|
||||
if ret is None:
|
||||
return ORJSONResponse(
|
||||
{"error": {"message": "Get parameter by name failed"}},
|
||||
status_code=HTTPStatus.BAD_REQUEST,
|
||||
)
|
||||
else:
|
||||
return ORJSONResponse(ret, status_code=200)
|
||||
except Exception as e:
|
||||
return ORJSONResponse(
|
||||
{"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST
|
||||
)
|
||||
|
||||
|
||||
@app.api_route("/open_session", methods=["GET", "POST"])
|
||||
async def open_session(obj: OpenSessionReqInput, request: Request):
|
||||
"""Open a session, and return its unique session id."""
|
||||
@@ -269,6 +288,18 @@ async def generate_request(obj: GenerateReqInput, request: Request):
|
||||
)
|
||||
|
||||
|
||||
@time_func_latency
|
||||
async def get_weights_by_name_request(obj: GetWeightsByNameReqInput, request: Request):
|
||||
"""Handle a get parameter by name request."""
|
||||
try:
|
||||
ret = await tokenizer_manager.get_weights_by_name(obj, request)
|
||||
return ret
|
||||
except ValueError as e:
|
||||
return ORJSONResponse(
|
||||
{"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST
|
||||
)
|
||||
|
||||
|
||||
@app.api_route("/encode", methods=["POST", "PUT"])
|
||||
@time_func_latency
|
||||
async def encode_request(obj: EmbeddingReqInput, request: Request):
|
||||
@@ -938,3 +969,8 @@ class Engine:
|
||||
|
||||
async def get_server_info(self):
|
||||
return await _get_server_info()
|
||||
|
||||
def get_weights_by_name(self, name, truncate_size=100):
|
||||
obj = GetWeightsByNameReqInput(name=name, truncate_size=truncate_size)
|
||||
loop = asyncio.get_event_loop()
|
||||
return loop.run_until_complete(get_weights_by_name_request(obj, None))
|
||||
|
||||
Reference in New Issue
Block a user