Add update_weights_from_tensor (#2631)

This commit is contained in:
fzyzcjy
2024-12-29 05:30:27 +08:00
committed by GitHub
parent 7863e4368a
commit fd28640dc5
10 changed files with 120 additions and 1 deletions

View File

@@ -57,6 +57,7 @@ from sglang.srt.managers.io_struct import (
OpenSessionReqInput,
UpdateWeightFromDiskReqInput,
UpdateWeightsFromDistributedReqInput,
UpdateWeightsFromTensorReqInput,
)
from sglang.srt.managers.scheduler import run_scheduler_process
from sglang.srt.managers.tokenizer_manager import TokenizerManager
@@ -109,6 +110,7 @@ app.add_middleware(
# Module-level handle that the Engine methods in this file call directly
# (e.g. update_weights_from_tensor). Starts as None.
# NOTE(review): presumably assigned during server/engine launch -- confirm.
tokenizer_manager: TokenizerManager = None
# Module-level dict, also None at import time.
# NOTE(review): its contents and where it is populated are not visible here.
scheduler_info: Dict = None
##### Native API endpoints #####
@@ -866,6 +868,14 @@ class Engine:
tokenizer_manager.update_weights_from_distributed(obj, None)
)
def update_weights_from_tensor(self, name, tensor):
    """Update weights from a tensor.

    Wraps ``name`` and ``tensor`` in an ``UpdateWeightsFromTensorReqInput``
    and blocks until the module-level ``tokenizer_manager`` has finished
    processing that request.

    Args:
        name: Forwarded as the request's ``name`` field (presumably the
            name of the parameter to overwrite -- confirm against the
            request type).
        tensor: Forwarded as the request's ``tensor`` field.

    Returns:
        Whatever ``tokenizer_manager.update_weights_from_tensor`` resolves to.
    """
    request = UpdateWeightsFromTensorReqInput(name=name, tensor=tensor)
    event_loop = asyncio.get_event_loop()
    # Synchronous facade over the awaitable tokenizer-manager call; the
    # second argument is None, as in the neighbouring Engine methods.
    pending = tokenizer_manager.update_weights_from_tensor(request, None)
    return event_loop.run_until_complete(pending)
def get_weights_by_name(self, name, truncate_size=100):
"""Get weights by parameter name."""
obj = GetWeightsByNameReqInput(name=name, truncate_size=truncate_size)