diff --git a/docs/backend/server_arguments.md b/docs/backend/server_arguments.md index 9ce110233..ef88f8ac3 100644 --- a/docs/backend/server_arguments.md +++ b/docs/backend/server_arguments.md @@ -116,7 +116,7 @@ Please consult the documentation below and [server_args.py](https://github.com/s | `--log-level` | The logging level of all loggers. | info | | `--log-level-http` | The logging level of HTTP server. If not set, reuse --log-level by default. | None | | `--log-requests` | Log metadata, inputs, outputs of all requests. The verbosity is decided by --log-requests-level. | False | -| `--log-requests-level` | 0: Log metadata. 1. Log metadata and partial input/output. 2. Log every input/output. | 0 | +| `--log-requests-level` | 0: Log metadata (no sampling parameters). 1: Log metadata and sampling parameters. 2: Log metadata, sampling parameters and partial input/output. 3: Log every input/output. | 0 | | `--show-time-cost` | Show time cost of custom marks. | False | | `--enable-metrics` | Enable log prometheus metrics. | False | | `--bucket-time-to-first-token` | The buckets of time to first token, specified as a list of floats. | None | diff --git a/python/sglang/bench_one_batch_server.py b/python/sglang/bench_one_batch_server.py index 29168e793..74b5a6711 100644 --- a/python/sglang/bench_one_batch_server.py +++ b/python/sglang/bench_one_batch_server.py @@ -38,6 +38,7 @@ class BenchArgs: output_len: Tuple[int] = (16,) temperature: float = 0.0 return_logprob: bool = False + client_stream_interval: int = 1 input_len_step_percentage: float = 0.0 result_filename: str = "result.jsonl" base_url: str = "" @@ -60,6 +61,11 @@ class BenchArgs: ) parser.add_argument("--temperature", type=float, default=BenchArgs.temperature) parser.add_argument("--return-logprob", action="store_true") + parser.add_argument( + "--client-stream-interval", + type=int, + default=BenchArgs.client_stream_interval, + ) parser.add_argument( "--input-len-step-percentage", type=float, @@ -120,6 +126,7 @@ def run_one_case( output_len: int, temperature: float, return_logprob: bool, + stream_interval: int, input_len_step_percentage: float, run_name: str, result_filename: str, @@ -168,6 +175,7 @@ def run_one_case( "max_new_tokens": output_len, "ignore_eos": True, "json_schema": json_schema, + "stream_interval": stream_interval, }, "return_logprob": return_logprob, "stream": True, @@ -245,8 +253,9 @@ def run_benchmark(server_args: ServerArgs, bench_args: BenchArgs): else: proc, base_url = launch_server_process(server_args) - tokenizer_id = server_args.tokenizer_path or server_args.model_path - tokenizer = get_tokenizer(tokenizer_id) + server_info = requests.get(base_url + "/get_server_info") + tokenizer_path = server_info.json()["tokenizer_path"] + tokenizer = get_tokenizer(tokenizer_path) # warmup if not bench_args.skip_warmup: @@ -258,6 +267,7 @@ def run_benchmark(server_args: ServerArgs, bench_args: BenchArgs): output_len=16, temperature=bench_args.temperature, return_logprob=bench_args.return_logprob, + stream_interval=bench_args.client_stream_interval, input_len_step_percentage=bench_args.input_len_step_percentage, run_name="", result_filename="", @@ -280,6 +290,7 @@ def run_benchmark(server_args: ServerArgs, bench_args: BenchArgs): ol, temperature=bench_args.temperature, return_logprob=bench_args.return_logprob, + stream_interval=bench_args.client_stream_interval, input_len_step_percentage=bench_args.input_len_step_percentage, run_name=bench_args.run_name, result_filename=bench_args.result_filename, @@ -301,6 +312,7 @@ def 
run_benchmark(server_args: ServerArgs, bench_args: BenchArgs): ol, temperature=bench_args.temperature, return_logprob=bench_args.return_logprob, + stream_interval=bench_args.client_stream_interval, input_len_step_percentage=bench_args.input_len_step_percentage, run_name=bench_args.run_name, result_filename=bench_args.result_filename, diff --git a/python/sglang/bench_serving.py b/python/sglang/bench_serving.py index 4e606dfa2..04c2202d2 100644 --- a/python/sglang/bench_serving.py +++ b/python/sglang/bench_serving.py @@ -1678,7 +1678,6 @@ def run_benchmark(args_: argparse.Namespace): if args.base_url else f"http://{args.host}:{args.port}/generate" ) - args.apply_chat_template = True elif args.backend in ["sglang-oai", "vllm", "lmdeploy"]: api_url = ( f"{args.base_url}/v1/completions" diff --git a/python/sglang/srt/configs/internvl.py b/python/sglang/srt/configs/internvl.py index f386c5c21..14b648252 100644 --- a/python/sglang/srt/configs/internvl.py +++ b/python/sglang/srt/configs/internvl.py @@ -147,12 +147,11 @@ class InternLM2Config(PretrainedConfig): ) if ( rope_scaling_factor is None - or not isinstance(rope_scaling_factor, float) - or not isinstance(rope_scaling_factor, int) + or not isinstance(rope_scaling_factor, (float, int)) or rope_scaling_factor < 1.0 ): raise ValueError( - f"`rope_scaling`'s factor field must be a float|int >= 1, got {rope_scaling_factor}" + f"`rope_scaling`'s factor field must be a float|int >= 1, got {rope_scaling_factor=}, {type(rope_scaling_factor)=}" ) if isinstance(rope_scaling_factor, int): rope_scaling_factor = float(rope_scaling_factor) diff --git a/python/sglang/srt/entrypoints/http_server.py b/python/sglang/srt/entrypoints/http_server.py index 61aec045e..b82b40d1d 100644 --- a/python/sglang/srt/entrypoints/http_server.py +++ b/python/sglang/srt/entrypoints/http_server.py @@ -126,8 +126,6 @@ def set_global_state(global_state: _GlobalState): @asynccontextmanager async def lifespan(fast_api_app: FastAPI): - server_args: ServerArgs = fast_api_app.server_args - # Initialize OpenAI serving handlers fast_api_app.state.openai_serving_completion = OpenAIServingCompletion( _global_state.tokenizer_manager, _global_state.template_manager @@ -145,9 +143,12 @@ async def lifespan(fast_api_app: FastAPI): _global_state.tokenizer_manager ) + server_args: ServerArgs = fast_api_app.server_args if server_args.warmups is not None: await execute_warmups( - server_args.warmups.split(","), _global_state.tokenizer_manager + server_args.disaggregation_mode, + server_args.warmups.split(","), + _global_state.tokenizer_manager, ) logger.info("Warmup ended") @@ -280,13 +281,17 @@ async def get_model_info(): "model_path": _global_state.tokenizer_manager.model_path, "tokenizer_path": _global_state.tokenizer_manager.server_args.tokenizer_path, "is_generation": _global_state.tokenizer_manager.is_generation, + "preferred_sampling_params": _global_state.tokenizer_manager.server_args.preferred_sampling_params, } return result @app.get("/get_server_info") async def get_server_info(): - internal_states = await _global_state.tokenizer_manager.get_internal_state() + # Returns internal states per DP.
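A note on the `InternLM2Config` hunk above: the old guard `not isinstance(rope_scaling_factor, float) or not isinstance(rope_scaling_factor, int)` is true for every value, since nothing is an instance of both types, so any explicitly supplied factor raised `ValueError`. A minimal standalone sketch of the corrected check (the `validate_rope_factor` helper is hypothetical, for illustration only):

```python
def validate_rope_factor(rope_scaling_factor) -> float:
    # Testing against the tuple (float, int) accepts either numeric type;
    # the old pair of negated single-type checks could never both be False.
    if (
        rope_scaling_factor is None
        or not isinstance(rope_scaling_factor, (float, int))
        or rope_scaling_factor < 1.0
    ):
        raise ValueError(
            f"`rope_scaling`'s factor field must be a float|int >= 1, "
            f"got {rope_scaling_factor=}, {type(rope_scaling_factor)=}"
        )
    # Normalize ints to float, as the config code does after validation.
    return float(rope_scaling_factor)


assert validate_rope_factor(2) == 2.0
assert validate_rope_factor(1.5) == 1.5
```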
+ internal_states: List[Dict[Any, Any]] = ( + await _global_state.tokenizer_manager.get_internal_state() + ) return { **dataclasses.asdict(_global_state.tokenizer_manager.server_args), **_global_state.scheduler_info, @@ -300,6 +305,8 @@ async def get_load(): return await _global_state.tokenizer_manager.get_load() +# example usage: +# curl -s -X POST http://localhost:30000/set_internal_state -H "Content-Type: application/json" -d '{"server_args": {"max_micro_batch_size": 8}}' @app.api_route("/set_internal_state", methods=["POST", "PUT"]) async def set_internal_state(obj: SetInternalStateReq, request: Request): res = await _global_state.tokenizer_manager.set_internal_state(obj) @@ -886,6 +893,15 @@ def launch_server( add_prometheus_middleware(app) enable_func_timer() + image_token_text = None + if ( + tokenizer_manager.image_token_id is not None + and not server_args.skip_tokenizer_init + ): + image_token_text = tokenizer_manager.tokenizer.decode( + [tokenizer_manager.image_token_id] + ) + # Send a warmup request - we will create the thread and launch it # in the lifespan after all other warmups have fired. warmup_thread = threading.Thread( @@ -893,7 +909,7 @@ def launch_server( args=( server_args, pipe_finish_writer, - _global_state.tokenizer_manager.image_token_id, + image_token_text, launch_callback, ), ) @@ -1022,9 +1038,10 @@ def _wait_and_warmup( return # Debug print - # logger.info(f"{res.json()=}") + # logger.info(f"warmup request returns: {res.json()=}") logger.info("The server is fired up and ready to roll!") + if pipe_finish_writer is not None: pipe_finish_writer.send("ready") diff --git a/python/sglang/srt/layers/elementwise.py b/python/sglang/srt/layers/elementwise.py index 87d46f6cf..3134e2bc1 100644 --- a/python/sglang/srt/layers/elementwise.py +++ b/python/sglang/srt/layers/elementwise.py @@ -8,6 +8,7 @@ from sglang.srt.utils import is_hip _is_hip = is_hip() + fused_softcap_autotune = triton.autotune( configs=[ triton.Config(kwargs={"BLOCK_SIZE": 128}, num_warps=4), @@ -189,21 +190,16 @@ def fused_dual_residual_rmsnorm(x, residual, weight1, weight2, eps, autotune=Fal assert x.shape == residual.shape and x.dtype == residual.dtype output, mid = torch.empty_like(x), torch.empty_like(x) bs, hidden_dim = x.shape - - min_num_warps = 16 if _is_hip else 32 - if autotune: fused_dual_residual_rmsnorm_kernel_autotune[(bs,)]( output, mid, x, residual, weight1, weight2, eps=eps, hidden_dim=hidden_dim ) else: + max_warps = 16 if _is_hip else 32 config = { "BLOCK_SIZE": triton.next_power_of_2(hidden_dim), "num_warps": max( - min( - triton.next_power_of_2(triton.cdiv(hidden_dim, 256)), min_num_warps - ), - 4, + min(triton.next_power_of_2(triton.cdiv(hidden_dim, 256)), max_warps), 4 ), } @@ -260,13 +256,11 @@ def fused_rmsnorm(x, weight, eps, autotune=False, inplace=False): else: output = torch.empty_like(x) bs, hidden_dim = x.shape - - min_num_warps = 16 if _is_hip else 32 - + max_warps = 16 if _is_hip else 32 config = { "BLOCK_SIZE": triton.next_power_of_2(hidden_dim), "num_warps": max( - min(triton.next_power_of_2(triton.cdiv(hidden_dim, 256)), min_num_warps), 4 + min(triton.next_power_of_2(triton.cdiv(hidden_dim, 256)), max_warps), 4 ), } @@ -331,6 +325,75 @@ class FusedDualResidualRMSNorm: return self.rmsnorm2.forward_native(residual), residual +@triton.jit +def experts_combine_kernel( + out_hidden_states, + moe_hidden_states, + mlp_hidden_states, + combine_k: tl.constexpr, + hidden_dim: tl.constexpr, + BLOCK_SIZE: tl.constexpr, +): + pid = tl.program_id(0) + start_index_mlp = pid *
hidden_dim + start_index_rmoe = pid * hidden_dim * combine_k + offsets = tl.arange(0, BLOCK_SIZE) + mask = offsets < hidden_dim + combine_k_offsets = tl.arange(0, combine_k) + + moe_x = tl.load( + moe_hidden_states + + start_index_rmoe + + combine_k_offsets[:, None] * hidden_dim + + offsets[None, :], + mask=mask[None, :], + other=0.0, + ) + moe_x = tl.sum(moe_x, axis=0) + mlp_x = tl.load(mlp_hidden_states + start_index_mlp + offsets, mask=mask, other=0.0) + combined_x = (moe_x + mlp_x) / 1.4142135623730951 + + tl.store(out_hidden_states + start_index_mlp + offsets, combined_x, mask=mask) + + +def experts_combine_triton(moe_hidden_states, mlp_hidden_states, output_buffer=None): + assert moe_hidden_states.is_contiguous() + assert mlp_hidden_states.is_contiguous() + + if len(moe_hidden_states.shape) == 2: + combine_k = 1 # pre-combined + else: + combine_k = moe_hidden_states.shape[1] + + if output_buffer is None: + out_hidden_states = torch.empty_like(mlp_hidden_states) + else: + flat_output_buffer = output_buffer.view(mlp_hidden_states.dtype).reshape(-1) + assert flat_output_buffer.numel() >= mlp_hidden_states.numel() + out_hidden_states = flat_output_buffer[: mlp_hidden_states.numel()].reshape( + mlp_hidden_states.shape + ) + + bs, hidden_dim = mlp_hidden_states.shape + + config = { + "BLOCK_SIZE": triton.next_power_of_2(hidden_dim), + "num_warps": max( + min(triton.next_power_of_2(triton.cdiv(hidden_dim, 1024)), 8), 4 + ), + } + + experts_combine_kernel[(bs,)]( + out_hidden_states, + moe_hidden_states, + mlp_hidden_states, + combine_k, + hidden_dim, + **config, + ) + return out_hidden_states + + # gelu on first half of vector @triton.jit def gelu_and_mul_kernel( @@ -400,10 +463,11 @@ def gelu_and_mul_triton( out_scales = scales static_scale = True + max_warps = 16 if _is_hip else 32 config = { # 8 ele per thread (not tuned) "num_warps": max( - min(triton.next_power_of_2(triton.cdiv(hidden_dim, 8 * 32)), 32), 4 + min(triton.next_power_of_2(triton.cdiv(hidden_dim, 8 * 32)), max_warps), 4 ), } diff --git a/python/sglang/srt/layers/moe/router.py b/python/sglang/srt/layers/moe/router.py index ffa120cad..d78437f7b 100644 --- a/python/sglang/srt/layers/moe/router.py +++ b/python/sglang/srt/layers/moe/router.py @@ -1,4 +1,4 @@ -from typing import Tuple +from typing import Optional, Tuple import torch import triton @@ -16,6 +16,8 @@ def fused_moe_router_kernel( moe_router_weight_ptr, # input (num_experts, hidden_dim) topk_weights_ptr, # output (bs, topk) topk_ids_ptr, # output (bs, topk) + correction_bias_ptr, + is_correction_bias: tl.constexpr, num_experts: tl.constexpr, topk: tl.constexpr, moe_softcapping: tl.constexpr, @@ -49,6 +51,11 @@ def fused_moe_router_kernel( bottom = exped + 1 logits_softcapped = top / bottom * moe_softcapping + # Add bias after softcapping + if is_correction_bias: + bias = tl.load(correction_bias_ptr + tl.arange(0, num_experts)) + logits_softcapped = logits_softcapped + bias + # topk # assert 1 <= topk <= num_experts @@ -109,6 +116,7 @@ def fused_moe_router_impl( router_weight: torch.Tensor, topk: int, moe_softcapping: float, + correction_bias: Optional[torch.Tensor] = None, ): assert len(x.shape) == 2 and x.shape[1] == router_weight.shape[1] bs, hidden_dim = x.shape @@ -117,23 +125,23 @@ def fused_moe_router_impl( # router_logits = torch.empty((bs, num_experts), dtype=torch.float32, device=x.device) topk_weights = torch.empty((bs, topk), dtype=torch.float32, device=x.device) topk_ids = torch.empty((bs, topk), dtype=torch.int32, device=x.device) + is_correction_bias = 
correction_bias is not None - grid = lambda meta: (bs,) - - min_num_warps = 16 if _is_hip else 32 - + max_warps = 16 if _is_hip else 32 config = { "BLOCK_SIZE": triton.next_power_of_2(hidden_dim), "num_warps": max( - min(triton.next_power_of_2(triton.cdiv(hidden_dim, 256)), min_num_warps), 4 + min(triton.next_power_of_2(triton.cdiv(hidden_dim, 256)), max_warps), 4 ), } - fused_moe_router_kernel[grid]( + fused_moe_router_kernel[(bs,)]( x, router_weight, topk_weights, topk_ids, + correction_bias, + is_correction_bias=is_correction_bias, num_experts=num_experts, topk=topk, moe_softcapping=moe_softcapping, @@ -153,7 +161,7 @@ def fused_moe_router_large_bs_kernel( topk_ids_ptr, # output (bs, topk) bs, num_experts: tl.constexpr, - topk: tl.constexpr, # only support topk == 1 + topk: tl.constexpr, # only support topk <= 2 moe_softcapping: tl.constexpr, moe_renormalize: tl.constexpr, # not supported K: tl.constexpr, @@ -204,25 +212,53 @@ def fused_moe_router_large_bs_kernel( logits_softcapped = (exped - 1) / (exped + 1) * moe_softcapping # 5. top1 - cond = tl.arange(0, BLOCK_SIZE_N)[None, :] < num_experts - top1 = tl.argmax(tl.where(cond, logits_softcapped, float("-inf")), axis=1) + arange_block_size_n = tl.arange(0, BLOCK_SIZE_N)[None, :] + cond_top1 = arange_block_size_n < num_experts + top1 = tl.argmax(tl.where(cond_top1, logits_softcapped, float("-inf")), axis=1) top1_v = tl.max( - tl.where(cond, logits_softcapped, float("-inf")), axis=1, keep_dims=True + tl.where(cond_top1, logits_softcapped, float("-inf")), axis=1, keep_dims=True ) - invsumexp = 1.0 / tl.sum( - tl.where(cond, tl.exp(logits_softcapped - top1_v), 0.0), axis=1 + top1_invsumexp = 1.0 / tl.sum( + tl.where(cond_top1, tl.exp(logits_softcapped - top1_v), 0.0), axis=1 ) - # 6. store to output - offs_topk = pid * topk * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) - topk_mask = offs_topk < bs - tl.store(topk_ids_ptr + offs_topk, top1, mask=topk_mask) + # 6. store top1 to output + offs_top1 = pid * topk * BLOCK_SIZE_M + topk * tl.arange(0, BLOCK_SIZE_M) + top1_mask = offs_top1 < bs * topk + tl.store(topk_ids_ptr + offs_top1, top1, mask=top1_mask) tl.store( - topk_weights_ptr + offs_topk, - invsumexp, - mask=topk_mask, + topk_weights_ptr + offs_top1, + top1_invsumexp, + mask=top1_mask, ) + # 7. 
handle topk == 2 + if topk == 2: + cond_top2 = (arange_block_size_n < num_experts) and ( + arange_block_size_n != top1[:, None] + ) + top2 = tl.argmax( + tl.where(cond_top2, logits_softcapped, float("-inf")), + axis=1, + keep_dims=True, + ) + top2_v = tl.sum( + logits_softcapped * (arange_block_size_n == top2), axis=1, keep_dims=True + ) + top2_invsumexp = tl.exp(top2_v - top1_v) * top1_invsumexp[:, None] + + # store top2 + offs_top2 = ( + pid * topk * BLOCK_SIZE_M + topk * tl.arange(0, BLOCK_SIZE_M)[:, None] + 1 + ) + top2_mask = offs_top2 < bs * topk + tl.store(topk_ids_ptr + offs_top2, top2, mask=top2_mask) + tl.store( + topk_weights_ptr + offs_top2, + top2_invsumexp, + mask=top2_mask, + ) + def fused_moe_router_large_bs_impl( x: torch.Tensor, @@ -239,7 +275,7 @@ def fused_moe_router_large_bs_impl( assert num_experts <= BLOCK_SIZE_N assert hidden_dim % BLOCK_SIZE_K == 0 - assert topk == 1 + assert topk <= 2 topk_weights = torch.empty((bs, topk), dtype=torch.float32, device=x.device) topk_ids = torch.empty((bs, topk), dtype=torch.int32, device=x.device) @@ -273,6 +309,7 @@ def fused_moe_router_shim( gating_output, topk, renormalize, + correction_bias: Optional[torch.Tensor] = None, ): assert not renormalize assert ( @@ -286,7 +323,7 @@ def fused_moe_router_shim( BLOCK_SIZE_K = 256 if ( bs >= 512 - and topk == 1 + and topk <= 2 and num_experts <= BLOCK_SIZE_N and hidden_dim % BLOCK_SIZE_K == 0 ): @@ -305,6 +342,7 @@ def fused_moe_router_shim( router_weight=gating_output, topk=topk, moe_softcapping=moe_softcapping, + correction_bias=correction_bias, ) diff --git a/python/sglang/srt/managers/configure_logging.py b/python/sglang/srt/managers/configure_logging.py index d331990ff..0dc78edfa 100644 --- a/python/sglang/srt/managers/configure_logging.py +++ b/python/sglang/srt/managers/configure_logging.py @@ -28,7 +28,7 @@ if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("--url", type=str, default="http://localhost:30000") parser.add_argument("--log-requests", action="store_true") - parser.add_argument("--log-requests-level", type=int, default=2) + parser.add_argument("--log-requests-level", type=int, default=3) parser.add_argument( "--dump-requests-folder", type=str, default="/tmp/sglang_request_dump" ) diff --git a/python/sglang/srt/managers/io_struct.py b/python/sglang/srt/managers/io_struct.py index 800dfc1fd..27b14a2ce 100644 --- a/python/sglang/srt/managers/io_struct.py +++ b/python/sglang/srt/managers/io_struct.py @@ -516,9 +516,6 @@ class EmbeddingReqInput: # For cross-encoder requests is_cross_encoder_request: bool = False - def contains_mm_input(self) -> bool: - return has_valid_data(self.image_data) or has_valid_data(self.audio_data) - def normalize_batch_and_arguments(self): # at least one of text, input_ids, or image should be provided if self.text is None and self.input_ids is None and self.image_data is None: @@ -572,6 +569,9 @@ class EmbeddingReqInput: self.rid = uuid.uuid4().hex return self.rid + def contains_mm_input(self) -> bool: + return has_valid_data(self.image_data) or has_valid_data(self.audio_data) + def __getitem__(self, i): if self.is_cross_encoder_request: return EmbeddingReqInput( diff --git a/python/sglang/srt/managers/mm_utils.py b/python/sglang/srt/managers/mm_utils.py index e421348b1..d05df897f 100644 --- a/python/sglang/srt/managers/mm_utils.py +++ b/python/sglang/srt/managers/mm_utils.py @@ -2,12 +2,15 @@ Multi-modality utils """ +import hashlib from abc import abstractmethod from typing import Callable, List, Optional, Tuple 
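To make the two Triton additions above easier to follow (`experts_combine_kernel` in elementwise.py, and the `correction_bias`/top-2 support in router.py), here is a dense PyTorch sketch of the math the kernels implement. This is my reference under the `renormalize=False` path that `fused_moe_router_shim` asserts, not code from the PR:

```python
import torch

def experts_combine_reference(moe_hidden_states, mlp_hidden_states):
    # Sum the combine_k expert outputs per token (a 2-D input is treated as
    # already combined), add the dense-MLP branch, rescale by 1/sqrt(2).
    if moe_hidden_states.dim() == 3:  # (bs, combine_k, hidden)
        moe_hidden_states = moe_hidden_states.sum(dim=1)
    return (moe_hidden_states + mlp_hidden_states) / 1.4142135623730951

def moe_router_reference(x, router_weight, topk, moe_softcapping, correction_bias=None):
    # The kernels' (exp(2y) - 1) / (exp(2y) + 1) softcap is exactly tanh;
    # the bias, when given, is added after softcapping, as in the new kernel.
    logits = x.float() @ router_weight.float().t()  # (bs, num_experts)
    logits = moe_softcapping * torch.tanh(logits / moe_softcapping)
    if correction_bias is not None:
        logits = logits + correction_bias
    _, topk_ids = torch.topk(logits, topk, dim=-1)
    # Weights are the full-softmax probabilities of the selected experts
    # (the kernels' invsumexp bookkeeping), not renormalized over the top-k.
    topk_weights = torch.softmax(logits, dim=-1).gather(-1, topk_ids)
    return topk_weights, topk_ids.to(torch.int32)
```

Note that only the small-batch kernel takes the bias; the large-batch kernel gains top-2 support here but still has no `correction_bias` path.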
+import numpy as np import torch from torch import nn +from sglang.srt.layers.multimodal import gpu_tensor_hash from sglang.srt.managers.schedule_batch import ( Modality, MultimodalDataItem, @@ -678,3 +681,52 @@ def get_multimodal_data_bounds( # Convert valid pairs to tensor valid_pairs_tensor = torch.tensor(valid_pairs, device=input_ids.device) return valid_pairs_tensor + + +def data_hash(data) -> int: + hash_bytes = hashlib.sha256(data).digest()[:8] + return int.from_bytes(hash_bytes, byteorder="big", signed=False) + + +def tensor_hash(tensor_list) -> int: + """ + hash a tensor or a tensor list + """ + tensor = tensor_list + if isinstance(tensor_list, list): + tensor_list = flatten_nested_list(tensor_list) + tensor_list = [ + x.flatten() if isinstance(x, torch.Tensor) else x for x in tensor_list + ] + tensor = torch.concat(tensor_list) + if tensor.is_cuda: + return gpu_tensor_hash(tensor) + tensor = tensor.detach().contiguous() + + if tensor.dtype == torch.bfloat16: + # memoryview() doesn't support PyTorch's BFloat16 dtype + tensor = tensor.float() + + assert isinstance(tensor, torch.Tensor) + if tensor.is_cuda: + # TODO: improve this + tensor_cpu = tensor.cpu() + else: + tensor_cpu = tensor + + mv = memoryview(tensor_cpu.numpy()) + return data_hash(mv.tobytes()) + + +def hash_feature(f): + if isinstance(f, list): + if isinstance(f[0], torch.Tensor): + return tensor_hash(f) + return data_hash(tuple(flatten_nested_list(f))) + elif isinstance(f, np.ndarray): + arr = np.ascontiguousarray(f) + arr_bytes = arr.tobytes() + return data_hash(arr_bytes) + elif isinstance(f, torch.Tensor): + return tensor_hash([f]) + return data_hash(f) diff --git a/python/sglang/srt/managers/multimodal_processor.py b/python/sglang/srt/managers/multimodal_processor.py index faf6576e6..76679358a 100644 --- a/python/sglang/srt/managers/multimodal_processor.py +++ b/python/sglang/srt/managers/multimodal_processor.py @@ -3,7 +3,6 @@ import importlib import inspect import logging import pkgutil -from functools import lru_cache from sglang.srt.multimodal.processors.base_processor import BaseMultimodalProcessor from sglang.srt.server_args import ServerArgs diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py index 8053c35da..117a82d8a 100644 --- a/python/sglang/srt/managers/schedule_batch.py +++ b/python/sglang/srt/managers/schedule_batch.py @@ -33,7 +33,6 @@ TODO(lmzheng): ModelWorkerBatch seems a bit redundant and we consider removing i import copy import dataclasses -import hashlib import logging import threading from enum import Enum, auto @@ -53,7 +52,6 @@ from sglang.srt.disaggregation.decode_schedule_batch_mixin import ( ScheduleBatchDisaggregationDecodeMixin, ) from sglang.srt.distributed.parallel_state import get_tensor_model_parallel_rank -from sglang.srt.layers.multimodal import gpu_tensor_hash from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache from sglang.srt.mem_cache.chunk_cache import ChunkCache, SWAChunkCache @@ -96,8 +94,8 @@ GLOBAL_SERVER_ARGS_KEYS = [ "max_micro_batch_size", "disable_shared_experts_fusion", "sampling_backend", - "speculative_accept_threshold_acc", "speculative_accept_threshold_single", + "speculative_accept_threshold_acc", "torchao_config", "triton_attention_reduce_in_fp32", "num_reserved_decode_tokens", @@ -180,7 +178,9 @@ class Modality(Enum): @dataclasses.dataclass class MultimodalDataItem: """ - A single multimodal data, from a single 
image/video/audio or others. + One MultimodalDataItem contains all inputs for one modality. + For example, if there are 3 images and 1 audio input, there will be 2 MultimodalDataItems: + one for the images and one for the audio. We put the common fields first and the model-specific fields last. """ @@ -232,53 +232,7 @@ class MultimodalDataItem: """ Set the pad value after first hashing the data """ - - def data_hash(data) -> int: - hash_bytes = hashlib.sha256(data).digest()[:8] - return int.from_bytes(hash_bytes, byteorder="big", signed=False) - - def tensor_hash(tensor_list) -> int: - """ - hash a tensor or a tensor list - """ - tensor = tensor_list - if isinstance(tensor_list, list): - tensor_list = flatten_nested_list(tensor_list) - tensor_list = [ - x.flatten() if isinstance(x, torch.Tensor) else x - for x in tensor_list - ] - tensor = torch.concat(tensor_list) - if tensor.is_cuda: - return gpu_tensor_hash(tensor) - tensor = tensor.detach().contiguous() - - if tensor.dtype == torch.bfloat16: - # memoryview() doesn't support PyTorch's BFloat16 dtype - tensor = tensor.float() - - assert isinstance(tensor, torch.Tensor) - if tensor.is_cuda: - # TODO: improve this - tensor_cpu = tensor.cpu() - else: - tensor_cpu = tensor - - mv = memoryview(tensor_cpu.numpy()) - return data_hash(mv.tobytes()) - - def hash_feature(f): - if isinstance(f, list): - if isinstance(f[0], torch.Tensor): - return tensor_hash(f) - return data_hash(tuple(flatten_nested_list(f))) - elif isinstance(f, np.ndarray): - arr = np.ascontiguousarray(f) - arr_bytes = arr.tobytes() - return data_hash(arr_bytes) - elif isinstance(f, torch.Tensor): - return tensor_hash([f]) - return data_hash(f) + from sglang.srt.managers.mm_utils import hash_feature if self.precomputed_features is not None: self.hash = hash_feature(self.precomputed_features) diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index e252f908c..630a9dd2d 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -418,14 +418,16 @@ class Scheduler( self.last_decode_stats_tic = time.perf_counter() self.last_prefill_stats_tic = time.perf_counter() self.return_health_check_ct = 0 + self.num_retracted_reqs: int = 0 + self.num_paused_reqs: int = 0 + self.kv_transfer_speed_gb_s: float = 0.0 + self.kv_transfer_latency_ms: float = 0.0 + self.sessions: Dict[str, Session] = {} self.current_stream = torch.get_device_module(self.device).current_stream() if self.device == "cpu": self.current_stream.synchronize = lambda: None # No-op for CPU self.forward_sleep_time = None - # Init session info - self.sessions: Dict[str, Session] = {} - # Init chunked prefill self.chunked_prefill_size = server_args.chunked_prefill_size if self.chunked_prefill_size <= 0: # -1 means disable @@ -473,26 +475,12 @@ class Scheduler( t = threading.Thread(target=self.watchdog_thread, daemon=True) t.start() self.parent_process = psutil.Process().parent() + + # Init memory saver, profiler and metric stats self.memory_saver_adapter = TorchMemorySaverAdapter.create( enable=server_args.enable_memory_saver ) - - # Init profiler - self.torch_profiler = None - self.torch_profiler_output_dir: Optional[str] = None - self.profiler_activities: Optional[List[str]] = None - self.profile_id: Optional[str] = None - self.profiler_target_forward_ct: Optional[int] = None - self.profiler_target_prefill_ct: Optional[int] = None - self.profiler_target_decode_ct: Optional[int] = None - self.profiler_prefill_ct: Optional[int] = None -
self.profiler_decode_ct: Optional[int] = None - self.profile_by_stage: bool = False - self.profile_steps: Optional[int] = None - self.profile_in_progress: bool = False - self.rpd_profiler = None - - # Init metrics stats + self.init_profiler() self.init_metrics() self.init_kv_events(server_args.kv_events_config) @@ -526,6 +514,7 @@ class Scheduler( ] ) + # Init disaggregation self.disaggregation_mode = DisaggregationMode( self.server_args.disaggregation_mode ) @@ -624,6 +613,21 @@ class Scheduler( ) ) + def init_profiler(self): + self.torch_profiler = None + self.torch_profiler_output_dir: Optional[str] = None + self.profiler_activities: Optional[List[str]] = None + self.profile_id: Optional[str] = None + self.profiler_target_forward_ct: Optional[int] = None + self.profiler_target_prefill_ct: Optional[int] = None + self.profiler_target_decode_ct: Optional[int] = None + self.profiler_prefill_ct: Optional[int] = None + self.profiler_decode_ct: Optional[int] = None + self.profile_by_stage: bool = False + self.profile_steps: Optional[int] = None + self.profile_in_progress: bool = False + self.rpd_profiler = None + def init_metrics(self): self.last_gen_throughput: float = 0.0 self.last_input_throughput: float = 0.0 @@ -2107,6 +2111,18 @@ class Scheduler( def get_internal_state(self, recv_req: GetInternalStateReq): ret = dict(global_server_args_dict) ret["last_gen_throughput"] = self.last_gen_throughput + ret["memory_usage"] = { + "weight": round( + self.tp_worker.worker.model_runner.weight_load_mem_usage, 2 + ), + "kvcache": round( + self.token_to_kv_pool_allocator.get_kvcache().mem_usage, 2 + ), + "cuda_graph": round( + self.tp_worker.worker.model_runner.cuda_graph_mem_usage, 2 + ), + "token_capacity": int(self.max_total_num_tokens), + } if not self.spec_algorithm.is_none() and self.cum_spec_accept_count > 0: ret["avg_spec_accept_length"] = ( self.cum_spec_accept_length / self.cum_spec_accept_count diff --git a/python/sglang/srt/managers/scheduler_output_processor_mixin.py b/python/sglang/srt/managers/scheduler_output_processor_mixin.py index d2a450aec..9ad9fdbfb 100644 --- a/python/sglang/srt/managers/scheduler_output_processor_mixin.py +++ b/python/sglang/srt/managers/scheduler_output_processor_mixin.py @@ -521,11 +521,17 @@ class SchedulerOutputProcessorMixin: stream_interval = ( req.sampling_params.stream_interval or self.stream_interval ) - should_output = len(req.output_ids) % stream_interval == 0 + should_output = ( + len(req.output_ids) % stream_interval == 1 + if not self.model_config.is_multimodal_gen + and stream_interval > 1 + else len(req.output_ids) % stream_interval == 0 + ) else: should_output = ( len(req.output_ids) % DEFAULT_FORCE_STREAM_INTERVAL == 0 - and not self.model_config.is_multimodal_gen + if not self.model_config.is_multimodal_gen + else False ) if should_output: diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py index aa61ad063..16ab5c10c 100644 --- a/python/sglang/srt/managers/tokenizer_manager.py +++ b/python/sglang/srt/managers/tokenizer_manager.py @@ -111,11 +111,7 @@ from sglang.srt.managers.io_struct import ( UpdateWeightsFromTensorReqInput, UpdateWeightsFromTensorReqOutput, ) -from sglang.srt.managers.multimodal_processor import ( - get_dummy_processor, - get_mm_processor, - import_processors, -) +from sglang.srt.managers.multimodal_processor import get_mm_processor, import_processors from sglang.srt.metrics.collector import TokenizerMetricsCollector from sglang.srt.sampling.sampling_params import
SamplingParams from sglang.srt.server_args import PortArgs, ServerArgs @@ -187,6 +183,8 @@ class TokenizerManager: if server_args.preferred_sampling_params else None ) + self.crash_dump_folder = server_args.crash_dump_folder + self.crash_dump_performed = False # Flag to ensure dump is only called once # Init inter-process communication context = zmq.asyncio.Context(2) @@ -251,10 +249,11 @@ class TokenizerManager: self.dump_requests_folder = "" # By default do not dump self.dump_requests_threshold = 1000 self.dump_request_list: List[Tuple] = [] + self.crash_dump_request_list: deque[Tuple] = deque() self.log_request_metadata = self.get_log_request_metadata() - self.asyncio_tasks = set() self.session_futures = {} # session_id -> asyncio event self.max_req_input_len = None + self.asyncio_tasks = set() # The event to notify the weight sync is finished. self.model_update_lock = RWLock() @@ -266,14 +265,14 @@ class TokenizerManager: self.disaggregation_mode = DisaggregationMode( self.server_args.disaggregation_mode ) - self.transfer_backend = TransferBackend( + self.disaggregation_transfer_backend = TransferBackend( self.server_args.disaggregation_transfer_backend ) # Start kv boostrap server on prefill if self.disaggregation_mode == DisaggregationMode.PREFILL: # only start bootstrap server on prefill tm kv_bootstrap_server_class = get_kv_class( - self.transfer_backend, KVClassType.BOOTSTRAP_SERVER + self.disaggregation_transfer_backend, KVClassType.BOOTSTRAP_SERVER ) self.bootstrap_server = kv_bootstrap_server_class( self.server_args.disaggregation_bootstrap_port @@ -324,7 +323,6 @@ class TokenizerManager: self.profile_communicator = _Communicator( self.send_to_scheduler, server_args.dp_size ) - self.health_check_communitcator = _Communicator(self.send_to_scheduler, 1) self.get_internal_state_communicator = _Communicator( self.send_to_scheduler, server_args.dp_size ) @@ -484,7 +482,7 @@ class TokenizerManager: token_type_ids = encoded.get("token_type_ids", [None])[0] if self.mm_processor and obj.contains_mm_input(): - image_inputs = await self.mm_processor.process_mm_data_async( + image_inputs: Dict = await self.mm_processor.process_mm_data_async( image_data=obj.image_data, input_text=input_text or input_ids, request_obj=obj, @@ -547,6 +545,14 @@ class TokenizerManager: "Please set `--enable-custom-logits-processor` to enable this feature." ) + def _validate_input_ids_in_vocab( + self, input_ids: List[int], vocab_size: int + ) -> None: + if any(id >= vocab_size for id in input_ids): + raise ValueError( + f"The input_ids {input_ids} contains values greater than the vocab size ({vocab_size})." 
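The `should_output` rewrite in scheduler_output_processor_mixin.py above changes when streamed chunks are flushed. A standalone sketch of the new policy (the helper name is mine, for illustration):

```python
def should_stream_now(num_output_ids: int, stream_interval: int,
                      is_multimodal_gen: bool) -> bool:
    # For text generation with an interval > 1, flush on len % interval == 1,
    # so the first token goes out immediately instead of after a full interval.
    if not is_multimodal_gen and stream_interval > 1:
        return num_output_ids % stream_interval == 1
    return num_output_ids % stream_interval == 0

# With interval 4, chunks are emitted at output lengths 1, 5, 9, ...
# (the old `== 0` rule emitted at 4, 8, 12, ...).
assert [n for n in range(1, 10) if should_stream_now(n, 4, False)] == [1, 5, 9]
```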
+ ) + def _create_tokenized_object( self, obj: Union[GenerateReqInput, EmbeddingReqInput], @@ -1096,12 +1102,36 @@ class TokenizerManager: "image_data", "audio_data", "lora_path", + "sampling_params", + ] + ) + out_skip_names = set( + [ + "text", + "output_ids", ] ) - out_skip_names = set(["text", "output_ids", "embedding"]) elif self.log_requests_level == 1: - max_length = 2048 + max_length = 1 << 30 + skip_names = set( + [ + "text", + "input_ids", + "input_embeds", + "image_data", + "audio_data", + "lora_path", + ] + ) + out_skip_names = set( + [ + "text", + "output_ids", + ] + ) elif self.log_requests_level == 2: + max_length = 2048 + elif self.log_requests_level == 3: max_length = 1 << 30 else: raise ValueError( @@ -1118,6 +1148,8 @@ class TokenizerManager: self.dump_requests_folder = obj.dump_requests_folder if obj.dump_requests_threshold is not None: self.dump_requests_threshold = obj.dump_requests_threshold + if obj.crash_dump_folder is not None: + self.crash_dump_folder = obj.crash_dump_folder logging.info(f"Config logging: {obj=}") self.log_request_metadata = self.get_log_request_metadata() @@ -1166,6 +1198,52 @@ class TokenizerManager: loop.create_task(print_exception_wrapper(self.sigterm_watchdog)) ) + def dump_requests_before_crash(self): + if self.crash_dump_performed: + logger.info( + "SIGTERM/SIGQUIT/Exception triggered, but crash dump already performed, skipping." + ) + return + logger.error(f"Dumping requests before crash. {self.crash_dump_folder=}") + self.crash_dump_performed = True + if not self.crash_dump_folder: + return + + data_to_dump = [] + if self.crash_dump_request_list: + data_to_dump.extend(self.crash_dump_request_list) + + # Add unfinished requests from rid_to_state + unfinished_requests = [] + for rid, state in self.rid_to_state.items(): + if not state.finished: + unfinished_requests.append( + (state.obj, {}, state.created_time, time.time()) + ) + if unfinished_requests: + data_to_dump.extend(unfinished_requests) + + if not data_to_dump: + return + + filename = os.path.join( + self.crash_dump_folder, + os.getenv("HOSTNAME", None), + f'crash_dump_{datetime.now().strftime("%Y-%m-%d_%H-%M-%S")}.pkl', + ) + + os.makedirs(os.path.dirname(filename), exist_ok=True) + # Include server_args in the dump + data_to_dump_with_server_args = { + "server_args": self.server_args, + "requests": data_to_dump, + } + with open(filename, "wb") as f: + pickle.dump(data_to_dump_with_server_args, f) + logger.error( + f"Dumped {len(self.crash_dump_request_list)} finished and {len(unfinished_requests)} unfinished requests before crash to {filename}" + ) + async def sigterm_watchdog(self): while not self.gracefully_exit: await asyncio.sleep(5) @@ -1175,11 +1253,12 @@ class TokenizerManager: remain_num_req = len(self.rid_to_state) if self.health_check_failed: - # if health check failed, exit immediately + # if health check failed, we should exit immediately logger.error( "Signal SIGTERM received while health check failed. Exiting... 
remaining number of requests: %d", remain_num_req, ) + self.dump_requests_before_crash() break elif get_bool_env_var("SGL_FORCE_SHUTDOWN"): @@ -1196,6 +1275,7 @@ class TokenizerManager: if remain_num_req > 0: await asyncio.sleep(5) else: + self.dump_requests_before_crash() break kill_process_tree(os.getpid(), include_parent=True) @@ -1273,16 +1353,7 @@ class TokenizerManager: "meta_info": meta_info, } elif isinstance(recv_obj, BatchMultimodalOut): - if isinstance(recv_obj.outputs[i], str): - out_dict = { - "text": recv_obj.outputs[i], - "meta_info": meta_info, - } - else: - out_dict = { - "outputs": json.dumps(recv_obj.outputs[i]), - "meta_info": meta_info, - } + raise NotImplementedError("BatchMultimodalOut not implemented") else: assert isinstance(recv_obj, BatchEmbeddingOut) out_dict = { @@ -1306,6 +1377,8 @@ class TokenizerManager: self.collect_metrics(state, recv_obj, i) if self.dump_requests_folder and state.finished and state.obj.log_metrics: self.dump_requests(state, out_dict) + if self.crash_dump_folder and state.finished and state.obj.log_metrics: + self.record_request_for_crash_dump(state, out_dict) def convert_logprob_style( self, @@ -1317,6 +1390,9 @@ class TokenizerManager: recv_obj: BatchStrOut, recv_obj_index: int, ): + if recv_obj.input_token_logprobs_val is None: + return + if len(recv_obj.input_token_logprobs_val) > 0: state.input_token_logprobs_val.extend( recv_obj.input_token_logprobs_val[recv_obj_index] @@ -1436,7 +1512,10 @@ class TokenizerManager: else 0 ) - if state.first_token_time == 0.0: + if ( + state.first_token_time == 0.0 + and self.disaggregation_mode != DisaggregationMode.PREFILL + ): state.first_token_time = state.last_time = time.time() state.last_completion_tokens = completion_tokens self.metrics_collector.observe_time_to_first_token( @@ -1484,14 +1563,31 @@ class TokenizerManager: to_dump = self.dump_request_list self.dump_request_list = [] + to_dump_with_server_args = { + "server_args": self.server_args, + "requests": to_dump, + } + def background_task(): os.makedirs(self.dump_requests_folder, exist_ok=True) with open(filename, "wb") as f: - pickle.dump(to_dump, f) + pickle.dump(to_dump_with_server_args, f) # Schedule the task to run in the background without awaiting it asyncio.create_task(asyncio.to_thread(background_task)) + def record_request_for_crash_dump(self, state: ReqState, out_dict: dict): + current_time = time.time() + self.crash_dump_request_list.append( + (state.obj, out_dict, state.created_time, current_time) + ) + # Remove requests older than 5 minutes based on finish time + while ( + self.crash_dump_request_list + and current_time - self.crash_dump_request_list[0][3] >= 300 + ): + self.crash_dump_request_list.popleft() + def _handle_abort_req(self, recv_obj): self.rid_to_state.pop(recv_obj.rid, None) @@ -1614,6 +1710,8 @@ async def print_exception_wrapper(func): except Exception: traceback = get_exception_traceback() logger.error(f"TokenizerManager hit an exception: {traceback}") + if hasattr(func, "__self__") and isinstance(func.__self__, TokenizerManager): + func.__self__.dump_requests_before_crash() kill_process_tree(os.getpid(), include_parent=True) sys.exit(1) @@ -1632,6 +1730,7 @@ class SignalHandler: logger.error( "Received sigquit from a child process. It usually means the child failed." 
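To make the crash-dump format above concrete: the pickle written by `dump_requests_before_crash` holds a dict with the `ServerArgs` and a list of `(request_obj, out_dict, created_time, finish_time)` tuples, where `out_dict` is empty for requests still in flight. A sketch of reading one back (path illustrative; unpickling requires `sglang` to be importable because the dump contains sglang objects):

```python
import pickle

# <crash_dump_folder>/<HOSTNAME>/crash_dump_<timestamp>.pkl
with open("crash_dump_2025-06-04_20-13-18.pkl", "rb") as f:
    dump = pickle.load(f)

print(dump["server_args"])  # the ServerArgs the crashed server ran with
for req, out, created_time, finish_time in dump["requests"]:
    # `out` is {} for requests that were unfinished when the dump fired
    tokens = out.get("meta_info", {}).get("completion_tokens")
    print(type(req).__name__, finish_time - created_time, tokens)
```

One caution on the writer side: `os.getenv("HOSTNAME", None)` feeds `os.path.join`, which raises `TypeError` when `HOSTNAME` is unset; a string default such as `"unknown"` would be safer.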
) + self.tokenizer_manager.dump_requests_before_crash() kill_process_tree(os.getpid()) diff --git a/python/sglang/srt/mem_cache/memory_pool.py b/python/sglang/srt/mem_cache/memory_pool.py index 14eef2043..4e3f40371 100644 --- a/python/sglang/srt/mem_cache/memory_pool.py +++ b/python/sglang/srt/mem_cache/memory_pool.py @@ -123,6 +123,7 @@ class KVCache(abc.ABC): self.memory_saver_adapter = TorchMemorySaverAdapter.create( enable=enable_memory_saver ) + self.mem_usage = 0 # used for chunked cpu-offloading self.cpu_offloading_chunk_size = 8192 @@ -219,6 +220,7 @@ class MHATokenToKVPool(KVCache): logger.info( f"KV Cache is allocated. #tokens: {size}, K size: {k_size / GB:.2f} GB, V size: {v_size / GB:.2f} GB" ) + self.mem_usage = (k_size + v_size) / GB def _create_buffers(self): with self.memory_saver_adapter.region(GPU_MEMORY_TYPE_KV_CACHE): @@ -695,6 +697,7 @@ class MLATokenToKVPool(KVCache): logger.info( f"KV Cache is allocated. #tokens: {size}, KV size: {kv_size / GB:.2f} GB" ) + self.mem_usage = kv_size / GB def get_kv_size_bytes(self): assert hasattr(self, "kv_buffer") diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index 8b9a367f4..ccb1cf08f 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -604,12 +604,13 @@ class ModelRunner: self.dtype = self.model_config.dtype after_avail_memory = get_available_gpu_memory(self.device, self.gpu_id) + self.weight_load_mem_usage = before_avail_memory - after_avail_memory logger.info( f"Load weight end. " f"type={type(self.model).__name__}, " f"dtype={self.dtype}, " f"avail mem={after_avail_memory:.2f} GB, " - f"mem usage={(before_avail_memory - after_avail_memory):.2f} GB." + f"mem usage={self.weight_load_mem_usage:.2f} GB." ) # Handle the case where some ranks do not finish loading. @@ -1250,6 +1251,7 @@ class ModelRunner: def init_cuda_graphs(self): """Capture cuda graphs.""" self.cuda_graph_runner = None + self.cuda_graph_mem_usage = 0 if not self.is_generation: # TODO: Currently, cuda graph only captures decode steps, which only exists for generation models @@ -1265,9 +1267,10 @@ class ModelRunner: ) self.cuda_graph_runner = CudaGraphRunner(self) after_mem = get_available_gpu_memory(self.device, self.gpu_id) + self.cuda_graph_mem_usage = before_mem - after_mem logger.info( f"Capture cuda graph end. Time elapsed: {time.perf_counter() - tic:.2f} s. " - f"mem usage={(before_mem - after_mem):.2f} GB. avail mem={after_mem:.2f} GB." + f"mem usage={self.cuda_graph_mem_usage:.2f} GB. avail mem={after_mem:.2f} GB." 
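The `mem_usage`, `weight_load_mem_usage`, and `cuda_graph_mem_usage` counters recorded above are what `Scheduler.get_internal_state` packs into its `memory_usage` dict, so they surface through `/get_server_info`. A sketch of querying them; the `internal_states` response key is my assumption, since the handler hunk earlier in this diff is elided before the key name appears:

```python
import requests

info = requests.get("http://localhost:30000/get_server_info").json()
for dp_state in info["internal_states"]:  # one entry per DP rank
    mem = dp_state["memory_usage"]
    # weight / kvcache / cuda_graph are in GB; token_capacity is in tokens
    print(mem["weight"], mem["kvcache"], mem["cuda_graph"], mem["token_capacity"])
```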
) def apply_torch_tp(self): diff --git a/python/sglang/srt/model_loader/loader.py b/python/sglang/srt/model_loader/loader.py index 29c82b084..7b1154c94 100644 --- a/python/sglang/srt/model_loader/loader.py +++ b/python/sglang/srt/model_loader/loader.py @@ -534,6 +534,12 @@ class DummyModelLoader(BaseModelLoader): model_config: ModelConfig, device_config: DeviceConfig, ) -> nn.Module: + + if get_bool_env_var("SGL_CPU_QUANTIZATION"): + return load_model_with_cpu_quantization( + self, model_config=model_config, device_config=device_config + ) + with set_default_torch_dtype(model_config.dtype): with torch.device(device_config.device): model = _initialize_model( @@ -1464,6 +1470,38 @@ class RemoteModelLoader(BaseModelLoader): return model.eval() +def load_model_with_cpu_quantization( + self, + *, + model_config: ModelConfig, + device_config: DeviceConfig, +) -> nn.Module: + target_device = torch.device(device_config.device) + with set_default_torch_dtype(model_config.dtype): + model = _initialize_model( + model_config, + self.load_config, + ) + + if not isinstance(self, DummyModelLoader): + model.load_weights(self._get_all_weights(model_config, model)) + + for _, module in model.named_modules(): + quant_method = getattr(module, "quant_method", None) + if quant_method is not None: + # When quant methods need to process weights after loading + # (for repacking, quantizing, etc), they expect parameters + # to be on the global target device. This scope is for the + # case where cpu offloading is used, where we will move the + # parameters onto device for processing and back off after. + with device_loading_context(module, target_device): + quant_method.process_weights_after_loading(module) + + model.to(target_device) + + return model.eval() + + def get_model_loader(load_config: LoadConfig) -> BaseModelLoader: """Get a model loader based on the load format.""" diff --git a/python/sglang/srt/models/mistral.py b/python/sglang/srt/models/mistral.py index b0c6beaab..d3d2efcae 100644 --- a/python/sglang/srt/models/mistral.py +++ b/python/sglang/srt/models/mistral.py @@ -13,7 +13,7 @@ # ============================================================================== """Inference-only Mistral model.""" -from typing import List, Union +from typing import List import torch from transformers.models.mistral3.modeling_mistral3 import Mistral3MultiModalProjector diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index 45b2e2d9e..050099a03 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -99,6 +99,7 @@ class ServerArgs: log_level_http: Optional[str] = None log_requests: bool = False log_requests_level: int = 0 + crash_dump_folder: Optional[str] = None show_time_cost: bool = False enable_metrics: bool = False bucket_time_to_first_token: Optional[List[float]] = None @@ -927,8 +928,14 @@ class ServerArgs: "--log-requests-level", type=int, default=0, - help="0: Log metadata. 1. Log metadata and partial input/output. 2. Log every input/output.", - choices=[0, 1, 2], + help="0: Log metadata (no sampling parameters). 1: Log metadata and sampling parameters. 2: Log metadata, sampling parameters and partial input/output. 3: Log every input/output.", + choices=[0, 1, 2, 3], + ) + parser.add_argument( + "--crash-dump-folder", + type=str, + default=ServerArgs.crash_dump_folder, + help="Folder path to dump requests from the last 5 min before a crash (if any). 
If not specified, crash dumping is disabled.", ) parser.add_argument( "--show-time-cost", diff --git a/python/sglang/srt/warmup.py b/python/sglang/srt/warmup.py index fc6d2202b..0bed9fb94 100644 --- a/python/sglang/srt/warmup.py +++ b/python/sglang/srt/warmup.py @@ -4,6 +4,7 @@ from typing import List import numpy as np import tqdm +from sglang.srt.disaggregation.utils import FAKE_BOOTSTRAP_HOST from sglang.srt.managers.io_struct import GenerateReqInput from sglang.srt.managers.tokenizer_manager import TokenizerManager @@ -20,17 +21,21 @@ def warmup(name: str) -> callable: return decorator -async def execute_warmups(warmup_names: List[str], tokenizer_manager: TokenizerManager): +async def execute_warmups( + disaggregation_mode: str, + warmup_names: List[str], + tokenizer_manager: TokenizerManager, +): for warmup_name in warmup_names: if warmup_name not in _warmup_registry: logger.warning(f"Could not find custom warmup {warmup_name}") continue logger.info(f"Running warmup {warmup_name}") - await _warmup_registry[warmup_name](tokenizer_manager) + await _warmup_registry[warmup_name](disaggregation_mode, tokenizer_manager) @warmup("voice_chat") -async def voice_chat(tokenizer_manager: TokenizerManager): +async def voice_chat(disaggregation_mode: str, tokenizer_manager: TokenizerManager): # this warms up the fused_moe triton kernels and caches them # if we don't do this we break real time inference for voice chat for i in tqdm.trange(1, 512): @@ -44,4 +49,8 @@ async def voice_chat(tokenizer_manager: TokenizerManager): "min_p": 0.0, }, ) + if disaggregation_mode != "null": + generate_req_input.bootstrap_room = 0 + generate_req_input.bootstrap_host = FAKE_BOOTSTRAP_HOST + await tokenizer_manager.generate_request(generate_req_input, None).__anext__() diff --git a/scripts/playground/replay_request_dump.py b/scripts/playground/replay_request_dump.py new file mode 100644 index 000000000..93d0d7d26 --- /dev/null +++ b/scripts/playground/replay_request_dump.py @@ -0,0 +1,150 @@ +""" +Usage: +# replay from a folder +python3 replay_request_dump.py --file-number 100 --parallel 512 --input-folder /data/lianmin/sglang_request_dump/grok-mini-0220-engine-5756f8f94-28bm6/ + +# replay from a single file +python3 replay_request_dump.py --parallel 512 --input-file /data/sglang_crash_dump/memx-cti-34-sr1.xpop.twttr.net/crash_dump_2025-06-04_20-13-18.pkl +""" + +import argparse +import glob +import json +import pickle +import time +from concurrent.futures import ThreadPoolExecutor +from dataclasses import asdict +from datetime import datetime + +import requests + +from sglang.bench_serving import set_ulimit +from sglang.utils import get_exception_traceback + + +def read_records(files): + records = [] + for f in files: + tmp = pickle.load(open(f, "rb")) + if isinstance(tmp, dict) and "requests" in tmp: + records.extend(tmp["requests"]) + else: + records.extend(tmp) + + return records + + +def run_one_request_internal(record): + (req, output, replay_init_time, start_time, end_time, idx) = record + time.sleep(max(0, start_time - (time.time() - replay_init_time))) + + if "completion_tokens" in output.get("meta_info", {}): + recorded_completion_tokens = output["meta_info"]["completion_tokens"] + else: + recorded_completion_tokens = "" + + json_data = asdict(req) + stream = json_data["stream"] + + if args.ignore_eos: + json_data["sampling_params"]["ignore_eos"] = True + if recorded_completion_tokens: + json_data["sampling_params"]["max_new_tokens"] = recorded_completion_tokens + + response = requests.post( + 
f"http://{args.host}:{args.port}/generate", + json=json_data, + stream=stream, + ) + + if stream: + for chunk in response.iter_lines(decode_unicode=False): + chunk = chunk.decode("utf-8") + if chunk and chunk.startswith("data:"): + if chunk == "data: [DONE]": + break + ret = json.loads(chunk[5:].strip("\n")) + else: + ret = response.json() + + prompt_tokens = ret["meta_info"]["prompt_tokens"] + completion_tokens = ret["meta_info"]["completion_tokens"] + print( + f"{idx=}, {start_time=:.2f}, {prompt_tokens=}, " + f"{completion_tokens=}, {recorded_completion_tokens=}" + ) + + +def run_one_request(record): + # global success_ct, error_ct + + try: + run_one_request_internal(record) + # success_ct += 1 + except Exception: + # error_ct += 1 + traceback = get_exception_traceback() + print(f"Hit an exception: {traceback}") + + +def main(records): + if len(records) == 0: + return + + base_time = records[0][-2] + base_time_str = datetime.fromtimestamp(base_time).strftime("%y-%m-%d %H:%M:%S") + print(f"{base_time_str=}") + replay_init_time = time.time() + + for i in range(len(records)): + req, output, start_time, end_time = records[i] + start_time -= base_time + records[i] = (req, output, replay_init_time, start_time, end_time, i) + + with ThreadPoolExecutor(args.parallel) as executor: + executor.map(run_one_request, records) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--host", type=str, default="localhost") + parser.add_argument("--port", type=int, default=30000) + parser.add_argument( + "--input-folder", type=str, default=None, help="Folder containing pickle files" + ) + parser.add_argument( + "--input-file", type=str, default=None, help="Single pickle file to process" + ) + parser.add_argument("--file-number", type=int, default=1) + parser.add_argument("--req-number", type=int, default=1000000) + parser.add_argument("--req-start", type=int, default=0) + parser.add_argument("--parallel", type=int, default=512) + parser.add_argument("--idx", type=int, default=None) + parser.add_argument("--ignore-eos", action="store_true") + args = parser.parse_args() + + set_ulimit() + + files = [] + if args.input_file: + files = [args.input_file] + if args.file_number > 1: + print("Warning: --file-number is ignored when --input-file is provided.") + elif args.input_folder: + files = glob.glob(f"{args.input_folder}/*.pkl") + files = files[: args.file_number] + else: + print("Error: Either --input-folder or --input-file must be provided.") + exit(1) + print(f"{files=}") + + records = read_records(files) + # Sort by the receive time, before filtering + records.sort(key=lambda x: x[-2]) + records = records[args.req_start :] + if args.idx: + records = [records[args.idx]] + print(f"testing {args.idx=}") + print(f"{records[0]}") + print(f"{len(records)=}") + main(records) diff --git a/test/srt/run_suite.py b/test/srt/run_suite.py index a9c72b952..34c190cf8 100644 --- a/test/srt/run_suite.py +++ b/test/srt/run_suite.py @@ -173,10 +173,11 @@ suites = { # TestFile("test_deepep_intranode.py", 50), # TestFile("test_deepep_low_latency.py", 50), # TestFile("test_moe_deepep_eval_accuracy_large.py", 250), + # Disabled because it hangs on the CI. 
+ # TestFile("test_moe_ep.py", 181), TestFile("test_disaggregation.py", 270), TestFile("test_disaggregation_different_tp.py", 155), TestFile("test_full_deepseek_v3.py", 463), - TestFile("test_moe_ep.py", 181), ], "per-commit-8-gpu-amd": [ TestFile("test_full_deepseek_v3.py", 250), diff --git a/test/srt/test_vision_chunked_prefill.py b/test/srt/test_vision_chunked_prefill.py index e41759e7b..3876e915b 100644 --- a/test/srt/test_vision_chunked_prefill.py +++ b/test/srt/test_vision_chunked_prefill.py @@ -178,7 +178,7 @@ class TestVisionChunkedPrefill(CustomTestCase): print(output_chunked) print("output without chunked prefill:") print(output_no_chunked) - assert output_chunked == output_no_chunked + self.assertEqual(output_chunked, output_no_chunked) def test_chunked_prefill(self): self._test_chunked_prefill(batches=[False, True], num_frames=[1, [2, 6, 8, 10]])
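Closing note on scripts/playground/replay_request_dump.py above: it preserves the trace's arrival pattern by having each worker sleep until its request's original offset from the first recorded request has elapsed since replay start, and with `--ignore-eos` plus the recorded `completion_tokens` pinned as `max_new_tokens`, decode lengths are reproduced as well. (One nit: `if args.idx:` skips index 0; `if args.idx is not None:` would be safer.) The pacing logic as a standalone sketch:

```python
import time

def wait_for_slot(replay_init_time: float, start_offset: float) -> None:
    # Equivalent to the script's
    #   time.sleep(max(0, start_time - (time.time() - replay_init_time)))
    # where start_offset is a request's recorded start time minus the
    # trace's base time (the first request's start time).
    time.sleep(max(0.0, start_offset - (time.time() - replay_init_time)))
```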