Add get weights by parameter name for llama (#2266)

This commit is contained in:
Chayenne
2024-11-29 23:36:38 -08:00
committed by GitHub
parent 7d5d1d3d29
commit 7d1485d376
12 changed files with 337 additions and 17 deletions

View File

@@ -38,6 +38,8 @@ from sglang.srt.managers.io_struct import (
BatchTokenIDOut,
CloseSessionReqInput,
FlushCacheReq,
GetWeightsByNameReqInput,
GetWeightsByNameReqOutput,
OpenSessionReqInput,
OpenSessionReqOutput,
ProfileReq,
@@ -511,6 +513,9 @@ class Scheduler:
self.send_to_tokenizer.send_pyobj(
UpdateWeightFromDiskReqOutput(success, message)
)
elif isinstance(recv_req, GetWeightsByNameReqInput):
parameter = self.get_weights_by_name(recv_req)
self.send_to_tokenizer.send_pyobj(GetWeightsByNameReqOutput(parameter))
elif isinstance(recv_req, ProfileReq):
if recv_req == ProfileReq.START_PROFILE:
self.start_profile()
@@ -1373,6 +1378,10 @@ class Scheduler:
logger.error(message)
return success, message
def get_weights_by_name(self, recv_req: GetWeightsByNameReqInput):
parameter = self.tp_worker.get_weights_by_name(recv_req)
return parameter
def start_profile(self) -> None:
if self.profiler is None:
raise RuntimeError("Profiler is not enabled.")