sglang/python/sglang/srt/warmup.py

from __future__ import annotations

import logging
from typing import TYPE_CHECKING, List

import numpy as np
import tqdm

from sglang.srt.disaggregation.utils import FAKE_BOOTSTRAP_HOST
from sglang.srt.managers.io_struct import GenerateReqInput

if TYPE_CHECKING:
    from sglang.srt.managers.tokenizer_manager import TokenizerManager

logger = logging.getLogger(__file__)
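
# Registry mapping a warmup name to its async warmup function; populated via @warmup.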
_warmup_registry = {}


def warmup(name: str):
    def decorator(fn):
        _warmup_registry[name] = fn
        return fn

    return decorator


async def execute_warmups(
    disaggregation_mode: str,
    warmup_names: List[str],
    tokenizer_manager: TokenizerManager,
):
    for warmup_name in warmup_names:
        if warmup_name not in _warmup_registry:
            logger.warning(f"Could not find custom warmup {warmup_name}")
            continue
        logger.info(f"Running warmup {warmup_name}")
        await _warmup_registry[warmup_name](disaggregation_mode, tokenizer_manager)
@warmup("voice_chat")
async def voice_chat(disaggregation_mode: str, tokenizer_manager: TokenizerManager):
# this warms up the fused_moe triton kernels and caches them
# if we don't do this we break real time inference for voice chat
for i in tqdm.trange(1, 512):
size = i * 4
generate_req_input = GenerateReqInput(
input_ids=(np.random.randint(2**16, size=[size])).tolist(),
sampling_params={
"max_new_tokens": 30,
"temperature": 0.8,
"stop_token_ids": [1],
"min_p": 0.0,
},
)
if disaggregation_mode != "null":
generate_req_input.bootstrap_room = 0
generate_req_input.bootstrap_host = FAKE_BOOTSTRAP_HOST
await tokenizer_manager.generate_request(generate_req_input, None).__anext__()
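
For illustration only, a minimal sketch of how a custom warmup could be registered against this registry and executed. The `short_prefill` name, its request parameters, and the assumption that a live TokenizerManager instance is already available are hypothetical and not part of warmup.py; the sketch only exercises the decorator and `execute_warmups` as defined above.

from sglang.srt.managers.io_struct import GenerateReqInput
from sglang.srt.warmup import execute_warmups, warmup


@warmup("short_prefill")
async def short_prefill(disaggregation_mode: str, tokenizer_manager):
    # One small request to exercise the tokenizer -> scheduler path once.
    req = GenerateReqInput(
        input_ids=[1, 2, 3, 4],
        sampling_params={"max_new_tokens": 8, "temperature": 0.0},
    )
    await tokenizer_manager.generate_request(req, None).__anext__()


async def run_startup_warmups(tokenizer_manager):
    # Unregistered names are skipped with a warning; registered names run in order.
    await execute_warmups("null", ["short_prefill", "does_not_exist"], tokenizer_manager)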