Add update_weights_from_tensor (#2631)

This commit is contained in:
fzyzcjy
2024-12-29 05:30:27 +08:00
committed by GitHub
parent 7863e4368a
commit fd28640dc5
10 changed files with 120 additions and 1 deletion

View File

@@ -22,7 +22,7 @@ import warnings
from collections import deque
from concurrent import futures
from types import SimpleNamespace
from typing import Callable, Dict, List, Optional, Tuple
from typing import Dict, List, Optional
import psutil
import setproctitle
@@ -52,6 +52,8 @@ from sglang.srt.managers.io_struct import (
UpdateWeightFromDiskReqOutput,
UpdateWeightsFromDistributedReqInput,
UpdateWeightsFromDistributedReqOutput,
UpdateWeightsFromTensorReqInput,
UpdateWeightsFromTensorReqOutput,
)
from sglang.srt.managers.schedule_batch import (
FINISH_ABORT,
@@ -478,6 +480,11 @@ class Scheduler:
self.send_to_tokenizer.send_pyobj(
UpdateWeightsFromDistributedReqOutput(success, message)
)
elif isinstance(recv_req, UpdateWeightsFromTensorReqInput):
success, message = self.update_weights_from_tensor(recv_req)
self.send_to_tokenizer.send_pyobj(
UpdateWeightsFromTensorReqOutput(success, message)
)
elif isinstance(recv_req, GetWeightsByNameReqInput):
parameter = self.get_weights_by_name(recv_req)
self.send_to_tokenizer.send_pyobj(GetWeightsByNameReqOutput(parameter))
@@ -1458,6 +1465,17 @@ class Scheduler:
logger.error(message)
return success, message
def update_weights_from_tensor(self, recv_req: UpdateWeightsFromTensorReqInput):
    """Update the online model parameters from tensors.

    Forwards the request to ``self.tp_worker.update_weights_from_tensor`` and,
    when the worker reports success, flushes the cache via
    ``self.flush_cache()``. When the worker reports failure, the returned
    message is logged as an error and the cache is left untouched.

    Args:
        recv_req: The tensor weight-update request; passed unchanged to the
            worker.

    Returns:
        The ``(success, message)`` pair reported by the worker.

    Raises:
        AssertionError: If the worker update succeeded but ``flush_cache()``
            returned a falsy value.
    """
    success, message = self.tp_worker.update_weights_from_tensor(recv_req)
    # TODO extract common code b/t update_weights_from_distributed and update_weights_from_tensor later
    if not success:
        logger.error(message)
        return success, message

    flush_cache_success = self.flush_cache()
    # Raise explicitly instead of using `assert`: assert statements are removed
    # under `python -O`, which would let a failed flush pass unnoticed.
    if not flush_cache_success:
        raise AssertionError("Cache flush failed after updating weights")
    return success, message
def get_weights_by_name(self, recv_req: GetWeightsByNameReqInput):
    """Return the result of the TP worker's ``get_weights_by_name`` for ``recv_req``.

    The request is forwarded unchanged and the worker's return value is
    handed back to the caller as-is.
    """
    return self.tp_worker.get_weights_by_name(recv_req)