Add get weights by parameter name for llama (#2266)
This commit is contained in:
@@ -38,6 +38,8 @@ from sglang.srt.managers.io_struct import (
|
||||
BatchTokenIDOut,
|
||||
CloseSessionReqInput,
|
||||
FlushCacheReq,
|
||||
GetWeightsByNameReqInput,
|
||||
GetWeightsByNameReqOutput,
|
||||
OpenSessionReqInput,
|
||||
OpenSessionReqOutput,
|
||||
ProfileReq,
|
||||
@@ -511,6 +513,9 @@ class Scheduler:
|
||||
self.send_to_tokenizer.send_pyobj(
|
||||
UpdateWeightFromDiskReqOutput(success, message)
|
||||
)
|
||||
elif isinstance(recv_req, GetWeightsByNameReqInput):
|
||||
parameter = self.get_weights_by_name(recv_req)
|
||||
self.send_to_tokenizer.send_pyobj(GetWeightsByNameReqOutput(parameter))
|
||||
elif isinstance(recv_req, ProfileReq):
|
||||
if recv_req == ProfileReq.START_PROFILE:
|
||||
self.start_profile()
|
||||
@@ -1373,6 +1378,10 @@ class Scheduler:
|
||||
logger.error(message)
|
||||
return success, message
|
||||
|
||||
def get_weights_by_name(self, recv_req: GetWeightsByNameReqInput):
|
||||
parameter = self.tp_worker.get_weights_by_name(recv_req)
|
||||
return parameter
|
||||
|
||||
def start_profile(self) -> None:
|
||||
if self.profiler is None:
|
||||
raise RuntimeError("Profiler is not enabled.")
|
||||
|
||||
Reference in New Issue
Block a user