Add get weights by parameter name for llama (#2266)

This commit is contained in:
Chayenne
2024-11-29 23:36:38 -08:00
committed by GitHub
parent 7d5d1d3d29
commit 7d1485d376
12 changed files with 337 additions and 17 deletions

View File

@@ -23,7 +23,10 @@ from typing import Optional
import psutil
import torch
from sglang.srt.managers.io_struct import UpdateWeightFromDiskReqInput
from sglang.srt.managers.io_struct import (
GetWeightsByNameReqInput,
UpdateWeightFromDiskReqInput,
)
from sglang.srt.managers.schedule_batch import ModelWorkerBatch
from sglang.srt.managers.tp_worker import TpModelWorker
from sglang.srt.server_args import ServerArgs
@@ -208,6 +211,9 @@ class TpModelWorkerClient:
success, message = self.worker.update_weights_from_disk(recv_req)
return success, message
def get_weights_by_name(self, recv_req: GetWeightsByNameReqInput):
return self.worker.get_weights_by_name(recv_req)
def __delete__(self):
self.input_queue.put((None, None))
self.copy_queue.put((None, None, None))