Add get weights by parameter name for llama (#2266)

Chayenne authored on 2024-11-29 23:36:38 -08:00
committed by GitHub
parent 7d5d1d3d29
commit 7d1485d376
12 changed files with 337 additions and 17 deletions

View File

@@ -365,6 +365,17 @@ class UpdateWeightFromDiskReqOutput:
message: str
@dataclass
class GetWeightsByNameReqInput:
name: str
truncate_size: int = 100
@dataclass
class GetWeightsByNameReqOutput:
parameter: list
@dataclass
class AbortReq:
# The request id
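
These two dataclasses are the request/response pair that travels between the server processes. They are shipped with pyzmq's pickle helpers (send_pyobj/recv_pyobj), as the scheduler and tokenizer-manager hunks below show. A minimal standalone sketch of that transport; the PAIR/inproc wiring here is illustrative only, not the server's real tokenizer-to-scheduler sockets:

import zmq

from sglang.srt.managers.io_struct import GetWeightsByNameReqInput

ctx = zmq.Context.instance()
a = ctx.socket(zmq.PAIR)
b = ctx.socket(zmq.PAIR)
a.bind("inproc://weights-demo")   # in-process endpoint, demo only
b.connect("inproc://weights-demo")

# send_pyobj pickles the dataclass; recv_pyobj unpickles it on the other side.
a.send_pyobj(GetWeightsByNameReqInput(name="lm_head.weight", truncate_size=10))
req = b.recv_pyobj()
assert req.name == "lm_head.weight" and req.truncate_size == 10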

View File

@@ -38,6 +38,8 @@ from sglang.srt.managers.io_struct import (
BatchTokenIDOut,
CloseSessionReqInput,
FlushCacheReq,
GetWeightsByNameReqInput,
GetWeightsByNameReqOutput,
OpenSessionReqInput,
OpenSessionReqOutput,
ProfileReq,
@@ -511,6 +513,9 @@ class Scheduler:
self.send_to_tokenizer.send_pyobj(
UpdateWeightFromDiskReqOutput(success, message)
)
elif isinstance(recv_req, GetWeightsByNameReqInput):
parameter = self.get_weights_by_name(recv_req)
self.send_to_tokenizer.send_pyobj(GetWeightsByNameReqOutput(parameter))
elif isinstance(recv_req, ProfileReq):
if recv_req == ProfileReq.START_PROFILE:
self.start_profile()
@@ -1373,6 +1378,10 @@ class Scheduler:
logger.error(message)
return success, message
def get_weights_by_name(self, recv_req: GetWeightsByNameReqInput):
parameter = self.tp_worker.get_weights_by_name(recv_req)
return parameter
def start_profile(self) -> None:
if self.profiler is None:
raise RuntimeError("Profiler is not enabled.")

View File

@@ -46,6 +46,8 @@ from sglang.srt.managers.io_struct import (
EmbeddingReqInput,
FlushCacheReq,
GenerateReqInput,
GetWeightsByNameReqInput,
GetWeightsByNameReqOutput,
OpenSessionReqInput,
OpenSessionReqOutput,
ProfileReq,
@@ -454,6 +456,23 @@ class TokenizerManager:
else:
return False, "Another update is in progress. Please try again later."
async def get_weights_by_name(
self, obj: GetWeightsByNameReqInput, request: Optional[fastapi.Request] = None
):
if self.to_create_loop:
self.create_handle_loop()
self.send_to_scheduler.send_pyobj(obj)
self.get_weights_by_name_result = asyncio.Future()
if self.server_args.dp_size == 1:
result = await self.get_weights_by_name_result
return result.parameter
else:
self.get_weights_by_name_tmp = []
result = await self.get_weights_by_name_result
all_parameters = [r.parameter for r in result]
return all_parameters
async def open_session(
self, obj: OpenSessionReqInput, request: Optional[fastapi.Request] = None
):
@@ -527,6 +546,7 @@ class TokenizerManager:
BatchEmbeddingOut,
BatchTokenIDOut,
UpdateWeightFromDiskReqOutput,
GetWeightsByNameReqOutput,
] = await self.recv_from_detokenizer.recv_pyobj()
if isinstance(recv_obj, UpdateWeightFromDiskReqOutput):
@@ -538,6 +558,16 @@ class TokenizerManager:
if len(self.model_update_tmp) == self.server_args.dp_size:
self.model_update_result.set_result(self.model_update_tmp)
continue
elif isinstance(recv_obj, GetWeightsByNameReqOutput):
if self.server_args.dp_size == 1:
self.get_weights_by_name_result.set_result(recv_obj)
else:
self.get_weights_by_name_tmp.append(recv_obj)
if len(self.get_weights_by_name_tmp) == self.server_args.dp_size:
self.get_weights_by_name_result.set_result(
self.get_weights_by_name_tmp
)
continue
elif isinstance(recv_obj, OpenSessionReqOutput):
self.session_futures[recv_obj.session_id].set_result(
recv_obj.session_id
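
In the dp_size > 1 branch above, a single asyncio.Future acts as a rendezvous point: the receive loop buffers each scheduler replica's reply in a temporary list and only resolves the future once all dp_size replies have arrived. A minimal sketch of that fan-in pattern, with illustrative names rather than the manager's real attributes:

import asyncio

async def fan_in_demo(dp_size: int = 4):
    result_future = asyncio.Future()
    partial = []

    def on_reply(reply):
        # Mirrors the recv loop: buffer until every replica has reported.
        partial.append(reply)
        if len(partial) == dp_size:
            result_future.set_result(partial)

    for rank in range(dp_size):  # stand-in for dp_size scheduler replies
        on_reply({"rank": rank, "parameter": [0.0]})

    replies = await result_future
    return [r["parameter"] for r in replies]

print(asyncio.run(fan_in_demo()))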

View File

@@ -19,7 +19,10 @@ from typing import Optional
from sglang.srt.configs.model_config import ModelConfig
from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
-from sglang.srt.managers.io_struct import UpdateWeightFromDiskReqInput
+from sglang.srt.managers.io_struct import (
+    GetWeightsByNameReqInput,
+    UpdateWeightFromDiskReqInput,
+)
from sglang.srt.managers.schedule_batch import ModelWorkerBatch, global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_executor.model_runner import ModelRunner
@@ -160,3 +163,9 @@ class TpModelWorker:
recv_req.model_path, recv_req.load_format
)
return success, message
def get_weights_by_name(self, recv_req: GetWeightsByNameReqInput):
parameter = self.model_runner.get_weights_by_name(
recv_req.name, recv_req.truncate_size
)
return parameter

View File

@@ -23,7 +23,10 @@ from typing import Optional
import psutil
import torch
-from sglang.srt.managers.io_struct import UpdateWeightFromDiskReqInput
+from sglang.srt.managers.io_struct import (
+    GetWeightsByNameReqInput,
+    UpdateWeightFromDiskReqInput,
+)
from sglang.srt.managers.schedule_batch import ModelWorkerBatch
from sglang.srt.managers.tp_worker import TpModelWorker
from sglang.srt.server_args import ServerArgs
@@ -208,6 +211,9 @@ class TpModelWorkerClient:
success, message = self.worker.update_weights_from_disk(recv_req)
return success, message
def get_weights_by_name(self, recv_req: GetWeightsByNameReqInput):
return self.worker.get_weights_by_name(recv_req)
def __delete__(self):
self.input_queue.put((None, None))
self.copy_queue.put((None, None, None))

View File

@@ -20,13 +20,10 @@ import inspect
import json
import logging
import pkgutil
-import time
from functools import lru_cache
-from tokenize import tabsize
-from typing import Any, Optional, Type, Union
+from typing import Optional, Type
import torch
import torch.distributed as dist
import torch.nn as nn
from vllm.config import DeviceConfig, LoadConfig
from vllm.config import ModelConfig as VllmModelConfig
@@ -403,6 +400,23 @@ class ModelRunner:
logger.info("Update weights end.")
return True, "Succeeded to update model weights."
def get_weights_by_name(
self, name: str, truncate_size: int = 100
) -> Optional[torch.Tensor]:
"""Get the weights of the parameter by its name. Similar to `get_parameter` in Hugging Face.
Only used for unit test with an unoptimized performance.
For optimized performance, please use torch.save and torch.load.
"""
# TODO: (chenyang) Add support for Qwen models.
try:
return self.model.get_weights_by_name(
name, truncate_size, tp_size=self.tp_size
)
except Exception as e:
logger.error(f"Error when getting parameter {name}: {e}")
return None
def init_lora_manager(self):
self.lora_manager = LoRAManager(
base_model=self.model,
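
The docstring above is explicit that this RPC path is a unit-test convenience; for real checkpoint traffic it recommends torch.save/torch.load. For contrast, a hedged sketch of that recommended route (a toy nn.Linear stands in for the real model):

import torch
import torch.nn as nn

model = nn.Linear(16, 16)  # stand-in for the real model
torch.save(model.state_dict(), "/tmp/weights.pt")

state = torch.load("/tmp/weights.pt", map_location="cpu")
print(state["weight"].shape)  # full tensors, no truncation or RPC hop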

View File

@@ -16,6 +16,7 @@
# https://github.com/vllm-project/vllm/blob/c7f2cf2b7f67bce5842fedfdba508440fe257375/vllm/model_executor/models/llama.py#L1
"""Inference-only LLaMA model compatible with HuggingFace weights."""
import logging
from typing import Any, Dict, Iterable, Optional, Tuple
import torch
@@ -45,6 +46,8 @@ from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.utils import make_layers
logger = logging.getLogger(__name__)
class LlamaMLP(nn.Module):
def __init__(
@@ -305,6 +308,14 @@ class LlamaForCausalLM(nn.Module):
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
self.logits_processor = LogitsProcessor(config)
self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
self.stacked_params_mapping = [
# (param_name, shard_name, shard_id)
(".qkv_proj", ".q_proj", "q"),
(".qkv_proj", ".k_proj", "k"),
(".qkv_proj", ".v_proj", "v"),
(".gate_up_proj", ".gate_proj", 0),
(".gate_up_proj", ".up_proj", 1),
]
@torch.no_grad()
def forward(
@@ -349,15 +360,7 @@ class LlamaForCausalLM(nn.Module):
return params_mapping.get(name, name)
def get_module_name_from_weight_name(self, name):
-stacked_params_mapping = [
-    # (param_name, shard_name, shard_id, num_shard)
-    ("qkv_proj", "q_proj", "q", 3),
-    ("qkv_proj", "k_proj", "k", 3),
-    ("qkv_proj", "v_proj", "v", 3),
-    ("gate_up_proj", "gate_proj", 0, 2),
-    ("gate_up_proj", "up_proj", 1, 2),
-]
-for param_name, weight_name, shard_id, num_shard in stacked_params_mapping:
+for param_name, weight_name, shard_id, num_shard in self.stacked_params_mapping:
if weight_name in name:
return (
name.replace(weight_name, param_name)[: -len(".weight")],
@@ -370,6 +373,7 @@ class LlamaForCausalLM(nn.Module):
return len(params_dict)
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
embed_tokens_weight = None
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
(".qkv_proj", ".q_proj", "q"),
@@ -378,6 +382,7 @@ class LlamaForCausalLM(nn.Module):
(".gate_up_proj", ".gate_proj", 0),
(".gate_up_proj", ".up_proj", 1),
]
params_dict = dict(self.named_parameters())
load_tie_word_embeddings = (
@@ -425,10 +430,79 @@ class LlamaForCausalLM(nn.Module):
# Tie output embedding layer to input embedding layer, to solve issues where lm_head.weight is missing
param = self.lm_head.weight
weight_loader = getattr(param, "weight_loader", default_weight_loader)
-weight_loader(param, embed_tokens_weight)
+if embed_tokens_weight is not None:
+    weight_loader(param, embed_tokens_weight)
apply_torchao_config_(self, params_dict, set(["proj.weight"]))
def get_weights_by_name(
self, name: str, truncate_size: int = 100, tp_size: int = 1
) -> Optional[torch.Tensor]:
"""Get the weights of the parameter by its name. Similar to `get_parameter` in Hugging Face.
Only used for unit test with an unoptimized performance.
For optimized performance, please use torch.save and torch.load.
"""
try:
mapped_name = name
mapped_shard_id = None
for param_name, weight_name, shard_id in self.stacked_params_mapping:
if weight_name in name:
mapped_name = name.replace(weight_name, param_name)
mapped_shard_id = shard_id
break
params_dict = dict(self.named_parameters())
if mapped_name in params_dict:
param = params_dict[mapped_name]
if mapped_shard_id is not None:
if mapped_shard_id in ["q", "k", "v"]:
num_heads = self.config.num_attention_heads // tp_size
num_kv_heads = self.config.num_key_value_heads // tp_size
head_dim = (
self.config.hidden_size // self.config.num_attention_heads
)
if mapped_shard_id == "q":
offset = 0
size = num_heads * head_dim
elif mapped_shard_id == "k":
offset = num_heads * head_dim
size = num_kv_heads * head_dim
elif mapped_shard_id == "v":
offset = (num_heads + num_kv_heads) * head_dim
size = num_kv_heads * head_dim
weight = param.data.narrow(0, offset, size)
elif mapped_shard_id in [0, 1]:
intermediate_size = self.config.intermediate_size
hidden_size = self.config.hidden_size
slice_size = intermediate_size // tp_size
if mapped_shard_id == 0: # gate_proj
offset = 0
size = slice_size
elif mapped_shard_id == 1: # up_proj
offset = slice_size
size = slice_size
weight = param.data.narrow(0, offset, size)
else:
weight = param.data
else:
weight = param.data
if tp_size > 1 and ("o_proj" in name or "down_proj" in name):
gathered_weights = [
torch.zeros_like(weight) for _ in range(tp_size)
]
torch.distributed.all_gather(gathered_weights, weight)
weight = torch.cat(gathered_weights, dim=1)
return weight.cpu().to(torch.float32).numpy().tolist()[:truncate_size]
else:
return None
except Exception as e:
logger.error(
f"Error getting weights by name {name} in LlamaForCausalLM: {e}"
)
return None
class Phi3ForCausalLM(LlamaForCausalLM):
pass
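
The narrow() offsets in get_weights_by_name follow from how the fused qkv_proj weight is packed: Q rows first, then K, then V, each already split across tensor-parallel ranks. A worked check of that arithmetic, assuming Llama-3-8B-style shapes (32 query heads, 8 KV heads, head_dim 128; the figures are illustrative) and tp_size=2:

num_attention_heads, num_key_value_heads, head_dim = 32, 8, 128
tp_size = 2

num_heads = num_attention_heads // tp_size     # 16 query heads per rank
num_kv_heads = num_key_value_heads // tp_size  # 4 KV heads per rank

q_off, q_size = 0, num_heads * head_dim                  # rows [0, 2048)
k_off, k_size = q_size, num_kv_heads * head_dim          # rows [2048, 2560)
v_off, v_size = k_off + k_size, num_kv_heads * head_dim  # rows [2560, 3072)

# The per-rank qkv_proj weight has q_size + k_size + v_size = 3072 rows;
# param.data.narrow(0, offset, size) slices the requested shard back out.
print((q_off, q_size), (k_off, k_size), (v_off, v_size))

Column-parallel layers (qkv_proj, gate_up_proj) shard along dim 0, so one rank's slice is already a complete fragment; row-parallel layers (o_proj, down_proj) shard along dim 1, which is why the code all_gathers across ranks and concatenates on dim=1 before returning.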

View File

@@ -52,6 +52,7 @@ from sglang.srt.managers.io_struct import (
CloseSessionReqInput,
EmbeddingReqInput,
GenerateReqInput,
GetWeightsByNameReqInput,
OpenSessionReqInput,
UpdateWeightFromDiskReqInput,
)
@@ -210,6 +211,24 @@ async def update_weights_from_disk(obj: UpdateWeightFromDiskReqInput, request: R
)
@app.api_route("/get_weights_by_name", methods=["GET", "POST"])
async def get_weights_by_name(obj: GetWeightsByNameReqInput, request: Request):
"""Get model parameter by name."""
try:
ret = await tokenizer_manager.get_weights_by_name(obj, request)
if ret is None:
return ORJSONResponse(
{"error": {"message": "Get parameter by name failed"}},
status_code=HTTPStatus.BAD_REQUEST,
)
else:
return ORJSONResponse(ret, status_code=200)
except Exception as e:
return ORJSONResponse(
{"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST
)
@app.api_route("/open_session", methods=["GET", "POST"])
async def open_session(obj: OpenSessionReqInput, request: Request):
"""Open a session, and return its unique session id."""
@@ -269,6 +288,18 @@ async def generate_request(obj: GenerateReqInput, request: Request):
)
@time_func_latency
async def get_weights_by_name_request(obj: GetWeightsByNameReqInput, request: Request):
"""Handle a get parameter by name request."""
try:
ret = await tokenizer_manager.get_weights_by_name(obj, request)
return ret
except ValueError as e:
return ORJSONResponse(
{"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST
)
@app.api_route("/encode", methods=["POST", "PUT"])
@time_func_latency
async def encode_request(obj: EmbeddingReqInput, request: Request):
@@ -938,3 +969,8 @@ class Engine:
async def get_server_info(self):
return await _get_server_info()
def get_weights_by_name(self, name, truncate_size=100):
obj = GetWeightsByNameReqInput(name=name, truncate_size=truncate_size)
loop = asyncio.get_event_loop()
return loop.run_until_complete(get_weights_by_name_request(obj, None))
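
With the HTTP route and the Engine wrapper in place, the feature can be exercised either over the wire or in-process. A usage sketch, assuming a server already running on localhost:30000 (host and port are illustrative):

import requests

# The JSON body mirrors GetWeightsByNameReqInput.
resp = requests.post(
    "http://localhost:30000/get_weights_by_name",
    json={"name": "model.embed_tokens.weight", "truncate_size": 5},
)
print(resp.json())  # first five values of the requested parameter

# In-process alternative via the Engine method added above:
#   engine.get_weights_by_name("model.embed_tokens.weight", truncate_size=5)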