[API] add get memory pool size (#1760)
Co-authored-by: Byron Hsu <byronhsu1230@gmail.com>
This commit is contained in:
@@ -27,6 +27,7 @@ from sglang.srt.managers.io_struct import (
|
|||||||
BatchEmbeddingOut,
|
BatchEmbeddingOut,
|
||||||
BatchStrOut,
|
BatchStrOut,
|
||||||
BatchTokenIDOut,
|
BatchTokenIDOut,
|
||||||
|
GetMemPoolSizeReqOutput,
|
||||||
UpdateWeightReqOutput,
|
UpdateWeightReqOutput,
|
||||||
)
|
)
|
||||||
from sglang.srt.managers.schedule_batch import FINISH_MATCHED_STR, FINISH_MATCHED_TOKEN
|
from sglang.srt.managers.schedule_batch import FINISH_MATCHED_STR, FINISH_MATCHED_TOKEN
|
||||||
@@ -111,6 +112,9 @@ class DetokenizerManager:
|
|||||||
# If it is a weight update request, no detokenization is needed.
|
# If it is a weight update request, no detokenization is needed.
|
||||||
self.send_to_tokenizer.send_pyobj(recv_obj)
|
self.send_to_tokenizer.send_pyobj(recv_obj)
|
||||||
continue
|
continue
|
||||||
|
elif isinstance(recv_obj, GetMemPoolSizeReqOutput):
|
||||||
|
self.send_to_tokenizer.send_pyobj(recv_obj)
|
||||||
|
continue
|
||||||
elif self.tokenizer is None:
|
elif self.tokenizer is None:
|
||||||
# If the tokenizer is skipped, no detokenization is needed
|
# If the tokenizer is skipped, no detokenization is needed
|
||||||
self.send_to_tokenizer.send_pyobj(recv_obj)
|
self.send_to_tokenizer.send_pyobj(recv_obj)
|
||||||
|
|||||||
@@ -353,3 +353,13 @@ class AbortReq:
|
|||||||
class ProfileReq(Enum):
|
class ProfileReq(Enum):
|
||||||
START_PROFILE = 1
|
START_PROFILE = 1
|
||||||
STOP_PROFILE = 2
|
STOP_PROFILE = 2
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class GetMemPoolSizeReq:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class GetMemPoolSizeReqOutput:
|
||||||
|
size: int
|
||||||
|
|||||||
@@ -38,6 +38,8 @@ from sglang.srt.managers.io_struct import (
|
|||||||
BatchEmbeddingOut,
|
BatchEmbeddingOut,
|
||||||
BatchTokenIDOut,
|
BatchTokenIDOut,
|
||||||
FlushCacheReq,
|
FlushCacheReq,
|
||||||
|
GetMemPoolSizeReq,
|
||||||
|
GetMemPoolSizeReqOutput,
|
||||||
ProfileReq,
|
ProfileReq,
|
||||||
TokenizedEmbeddingReqInput,
|
TokenizedEmbeddingReqInput,
|
||||||
TokenizedGenerateReqInput,
|
TokenizedGenerateReqInput,
|
||||||
@@ -363,6 +365,10 @@ class Scheduler:
|
|||||||
self.start_profile()
|
self.start_profile()
|
||||||
else:
|
else:
|
||||||
self.stop_profile()
|
self.stop_profile()
|
||||||
|
elif isinstance(recv_req, GetMemPoolSizeReq):
|
||||||
|
self.send_to_detokenizer.send_pyobj(
|
||||||
|
GetMemPoolSizeReqOutput(self.max_total_num_tokens)
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Invalid request: {recv_req}")
|
raise ValueError(f"Invalid request: {recv_req}")
|
||||||
|
|
||||||
|
|||||||
@@ -46,6 +46,8 @@ from sglang.srt.managers.io_struct import (
|
|||||||
EmbeddingReqInput,
|
EmbeddingReqInput,
|
||||||
FlushCacheReq,
|
FlushCacheReq,
|
||||||
GenerateReqInput,
|
GenerateReqInput,
|
||||||
|
GetMemPoolSizeReq,
|
||||||
|
GetMemPoolSizeReqOutput,
|
||||||
ProfileReq,
|
ProfileReq,
|
||||||
RewardReqInput,
|
RewardReqInput,
|
||||||
TokenizedEmbeddingReqInput,
|
TokenizedEmbeddingReqInput,
|
||||||
@@ -531,6 +533,15 @@ class TokenizerManager:
|
|||||||
req = ProfileReq.STOP_PROFILE
|
req = ProfileReq.STOP_PROFILE
|
||||||
self.send_to_scheduler.send_pyobj(req)
|
self.send_to_scheduler.send_pyobj(req)
|
||||||
|
|
||||||
|
async def get_memory_pool_size(self):
|
||||||
|
if self.to_create_loop:
|
||||||
|
self.create_handle_loop()
|
||||||
|
|
||||||
|
req = GetMemPoolSizeReq()
|
||||||
|
self.send_to_scheduler.send_pyobj(req)
|
||||||
|
self.mem_pool_size = asyncio.Future()
|
||||||
|
return await self.mem_pool_size
|
||||||
|
|
||||||
async def update_weights(
|
async def update_weights(
|
||||||
self, obj: UpdateWeightReqInput, request: Optional[fastapi.Request] = None
|
self, obj: UpdateWeightReqInput, request: Optional[fastapi.Request] = None
|
||||||
):
|
):
|
||||||
@@ -590,6 +601,9 @@ class TokenizerManager:
|
|||||||
if isinstance(recv_obj, UpdateWeightReqOutput):
|
if isinstance(recv_obj, UpdateWeightReqOutput):
|
||||||
self.model_update_result.set_result(recv_obj)
|
self.model_update_result.set_result(recv_obj)
|
||||||
continue
|
continue
|
||||||
|
elif isinstance(recv_obj, GetMemPoolSizeReqOutput):
|
||||||
|
self.mem_pool_size.set_result(recv_obj)
|
||||||
|
continue
|
||||||
|
|
||||||
assert isinstance(
|
assert isinstance(
|
||||||
recv_obj, (BatchStrOut, BatchEmbeddingOut, BatchTokenIDOut)
|
recv_obj, (BatchStrOut, BatchEmbeddingOut, BatchTokenIDOut)
|
||||||
|
|||||||
@@ -172,6 +172,18 @@ async def stop_profile():
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@app.api_route("/get_memory_pool_size", methods=["GET", "POST"])
|
||||||
|
async def get_memory_pool_size():
|
||||||
|
"""Get the memory pool size in number of tokens"""
|
||||||
|
try:
|
||||||
|
ret = await tokenizer_manager.get_memory_pool_size()
|
||||||
|
return ret.size
|
||||||
|
except Exception as e:
|
||||||
|
return JSONResponse(
|
||||||
|
{"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@app.post("/update_weights")
|
@app.post("/update_weights")
|
||||||
async def update_weights(obj: UpdateWeightReqInput, request: Request):
|
async def update_weights(obj: UpdateWeightReqInput, request: Request):
|
||||||
"""Update the weights inplace without re-launching the server."""
|
"""Update the weights inplace without re-launching the server."""
|
||||||
|
|||||||
@@ -119,6 +119,10 @@ class TestSRTEndpoint(unittest.TestCase):
|
|||||||
[x[-1] for x in res["meta_info"]["output_token_logprobs"]]
|
[x[-1] for x in res["meta_info"]["output_token_logprobs"]]
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def test_get_memory_pool_size(self):
|
||||||
|
response = requests.post(self.base_url + "/get_memory_pool_size")
|
||||||
|
assert isinstance(response.json(), int)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
|||||||
Reference in New Issue
Block a user