Return more infos for computing average acceptance length (#3152)
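This change threads a per-request spec_verify_ct counter (the number of speculative-decoding verification forward passes) from the scheduler through the detokenizer and tokenizer manager, and returns it to clients as meta_info["spec_verify_ct"] whenever a speculative algorithm is enabled. Below is a minimal client-side sketch of how the new field might be used to compute the average acceptance length for one request; it assumes a locally running server on the default port and relies on the existing completion_tokens field, so everything except spec_verify_ct is outside this diff.

import requests

# Hypothetical usage sketch, not part of this commit: query a running
# sglang server and derive the average acceptance length for one request.
resp = requests.post(
    "http://localhost:30000/generate",
    json={
        "text": "The capital of France is",
        "sampling_params": {"max_new_tokens": 64},
    },
).json()

meta = resp["meta_info"]
if meta.get("spec_verify_ct", 0) > 0:
    # Each verification forward pass accepts at least one token, so the
    # average acceptance length is generated tokens per verification pass.
    avg_accept_len = meta["completion_tokens"] / meta["spec_verify_ct"]
    print(f"average acceptance length: {avg_accept_len:.2f}")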
@@ -57,6 +57,7 @@ from sglang.srt.utils import (
    assert_pkg_version,
    configure_logger,
    kill_process_tree,
    launch_dummy_health_check_server,
    maybe_set_triton_cache_manager,
    prepare_model_and_tokenizer,
    set_prometheus_multiproc_dir,
@@ -400,14 +401,16 @@ def _launch_subprocesses(server_args: ServerArgs) -> Tuple[TokenizerManager, Dic

        if os.getenv("SGLANG_BLOCK_NONZERO_RANK_CHILDREN") == "0":
            # When using `Engine` as a Python API, we don't want to block here.
            return
            return None, None

        launch_dummy_health_check_server(server_args.host, server_args.port)

        for proc in scheduler_procs:
            proc.join()
            logger.error(
                f"Scheduler or DataParallelController {proc.pid} terminated with {proc.exitcode}"
            )
        return
        return None, None

    # Launch detokenizer process
    detoken_proc = mp.Process(
@@ -22,6 +22,8 @@ def compute_dp_attention_world_info(enable_dp_attention, tp_rank, tp_size, dp_si
def initialize_dp_attention(enable_dp_attention, tp_rank, tp_size, dp_size):
    global _ATTN_TP_GROUP, _ATTN_TP_RANK, _ATTN_TP_SIZE, _DP_RANK, _DP_SIZE

    from sglang.srt.layers.sampler import SYNC_TOKEN_IDS_ACROSS_TP

    _ATTN_TP_RANK, _ATTN_TP_SIZE, _DP_RANK = compute_dp_attention_world_info(
        enable_dp_attention, tp_rank, tp_size, dp_size
    )
@@ -35,7 +37,7 @@ def initialize_dp_attention(enable_dp_attention, tp_rank, tp_size, dp_size):
        ],
        tp_rank,
        torch.distributed.get_backend(tp_group.device_group),
        False,
        SYNC_TOKEN_IDS_ACROSS_TP,
        False,
        False,
        False,
@@ -201,6 +201,7 @@ class DetokenizerManager:
            prompt_tokens=recv_obj.prompt_tokens,
            completion_tokens=recv_obj.completion_tokens,
            cached_tokens=recv_obj.cached_tokens,
            spec_verify_ct=recv_obj.spec_verify_ct,
            input_token_logprobs_val=recv_obj.input_token_logprobs_val,
            input_token_logprobs_idx=recv_obj.input_token_logprobs_idx,
            output_token_logprobs_val=recv_obj.output_token_logprobs_val,
@@ -354,10 +354,13 @@ class BatchTokenIDOut:
    skip_special_tokens: List[bool]
    spaces_between_special_tokens: List[bool]
    no_stop_trim: List[bool]

    # Token counts
    prompt_tokens: List[int]
    completion_tokens: List[int]
    cached_tokens: List[int]
    spec_verify_ct: List[int]

    # Logprobs
    input_token_logprobs_val: List[float]
    input_token_logprobs_idx: List[int]
@@ -382,6 +385,7 @@ class BatchStrOut:
    prompt_tokens: List[int]
    completion_tokens: List[int]
    cached_tokens: List[int]
    spec_verify_ct: List[int]

    # Logprobs
    input_token_logprobs_val: List[float]
@@ -252,7 +252,6 @@ class Req:

        # Sampling info
        self.sampling_params = sampling_params
        self.lora_path = lora_path
        self.custom_logit_processor = custom_logit_processor

        # Memory pool info
@@ -300,7 +299,7 @@ class Req:
        self.logprob_start_len = 0
        self.top_logprobs_num = top_logprobs_num

        # Logprobs (return value)
        # Logprobs (return values)
        self.input_token_logprobs_val: Optional[List[float]] = None
        self.input_token_logprobs_idx: Optional[List[int]] = None
        self.input_top_logprobs_val: Optional[List[float]] = None
@@ -329,10 +328,15 @@ class Req:
        # Constrained decoding
        self.grammar: Optional[BaseGrammarObject] = None

        # The number of cached tokens, that were already cached in the KV cache
        # The number of cached tokens that were already cached in the KV cache
        self.cached_tokens = 0
        self.already_computed = 0

        # The number of verification forward passes in the speculative decoding.
        # This is used to compute the average acceptance length per request.
        self.spec_verify_ct = 0
        self.lora_path = lora_path

    def extend_image_inputs(self, image_inputs):
        if self.image_inputs is None:
            self.image_inputs = image_inputs
@@ -281,6 +281,7 @@ class Scheduler:
        # Print debug info
        logger.info(
            f"max_total_num_tokens={self.max_total_num_tokens}, "
            f"chunked_prefill_size={server_args.chunked_prefill_size}, "
            f"max_prefill_tokens={self.max_prefill_tokens}, "
            f"max_running_requests={self.max_running_requests}, "
            f"context_len={self.model_config.context_len}"
@@ -408,6 +409,11 @@ class Scheduler:
                },
            )

        # The largest prefill length of a single request
        self._largest_prefill_len: int = 0
        # The largest context length (prefill + generation) of a single request
        self._largest_prefill_decode_len: int = 0

        # Init request dispatcher
        self._request_dispatcher = TypeBasedDispatcher(
            [
@@ -1371,6 +1377,7 @@ class Scheduler:
        prompt_tokens = []
        completion_tokens = []
        cached_tokens = []
        spec_verify_ct = []

        if return_logprob:
            input_token_logprobs_val = []
@@ -1424,6 +1431,9 @@
            completion_tokens.append(len(req.output_ids))
            cached_tokens.append(req.cached_tokens)

            if not self.spec_algorithm.is_none():
                spec_verify_ct.append(req.spec_verify_ct)

            if return_logprob:
                input_token_logprobs_val.append(req.input_token_logprobs_val)
                input_token_logprobs_idx.append(req.input_token_logprobs_idx)
@@ -1451,6 +1461,7 @@
                prompt_tokens,
                completion_tokens,
                cached_tokens,
                spec_verify_ct,
                input_token_logprobs_val,
                input_token_logprobs_idx,
                output_token_logprobs_val,
@@ -785,6 +785,9 @@ class TokenizerManager:
                    i,
                )

            if self.server_args.speculative_algorithm:
                meta_info["spec_verify_ct"] = recv_obj.spec_verify_ct[i]

            if not isinstance(recv_obj, BatchEmbeddingOut):
                meta_info.update(
                    {
@@ -809,6 +812,7 @@
                "embedding": recv_obj.embeddings[i],
                "meta_info": meta_info,
            }

            state.out_list.append(out_dict)
            state.finished = recv_obj.finished_reasons[i] is not None
            state.event.set()
@@ -38,7 +38,7 @@ if TYPE_CHECKING:
    from sglang.srt.model_executor.model_runner import ModelRunner


def _to_torch(model: torch.nn.Module, reverse: bool, batch_size: int):
def _to_torch(model: torch.nn.Module, reverse: bool, num_tokens: int):
    for sub in model._modules.values():
        if isinstance(sub, CustomOp):
            if reverse:
@@ -47,7 +47,7 @@ def _to_torch(model: torch.nn.Module, reverse: bool, batch_size: int):
            else:
                # NOTE: Temporarily workaround MoE
                if "FusedMoE" in sub.__class__.__name__:
                    if batch_size == 1:
                    if num_tokens == 1:
                        # The performance of torch.compile on this layer is not always good when bs > 1,
                        # so we decide to only use torch.compile when bs =1
                        sub._forward_method = fused_moe_forward_native
@@ -55,14 +55,14 @@ def _to_torch(model: torch.nn.Module, reverse: bool, batch_size: int):
                    sub._forward_method = sub.forward_native
                setattr(sub, "is_torch_compile", True)
        if isinstance(sub, torch.nn.Module):
            _to_torch(sub, reverse, batch_size)
            _to_torch(sub, reverse, num_tokens)


@contextmanager
def patch_model(
    model: torch.nn.Module,
    enable_compile: bool,
    batch_size: int,
    num_tokens: int,
    tp_group: GroupCoordinator,
):
    """Patch the model to make it compatible with with torch.compile"""
@@ -70,7 +70,7 @@ def patch_model(

    try:
        if enable_compile:
            _to_torch(model, reverse=False, batch_size=batch_size)
            _to_torch(model, reverse=False, num_tokens=num_tokens)
            backup_ca_comm = tp_group.ca_comm
            # Use custom-allreduce here.
            # We found the custom allreduce is much faster than the built-in allreduce in torch,
@@ -85,7 +85,7 @@ def patch_model(
        yield model.forward
    finally:
        if enable_compile:
            _to_torch(model, reverse=True, batch_size=batch_size)
            _to_torch(model, reverse=True, num_tokens=num_tokens)
        tp_group.ca_comm = backup_ca_comm
@@ -283,8 +283,8 @@ class CudaGraphRunner:
            with patch_model(
                self.model_runner.model,
                bs in self.compile_bs,
                bs,
                self.model_runner.tp_group,
                num_tokens=bs * self.num_tokens_per_bs,
                tp_group=self.model_runner.tp_group,
            ) as forward:
                (
                    graph,
@@ -603,6 +603,7 @@ class EagleVerifyInput(SpecInfo):
            if not req.finished():
                new_accept_index.extend(new_accept_index_)
                unfinished_index.append(i)
            req.spec_verify_ct += 1
        accept_length = (accept_index != -1).sum(dim=1) - 1

        accept_index = accept_index[accept_index != -1]
@@ -14,6 +14,7 @@
"""Common utilities."""

import base64
import ctypes
import dataclasses
import io
import ipaddress
@@ -29,6 +30,7 @@ import shutil
import signal
import socket
import subprocess
import sys
import tempfile
import time
import warnings
@@ -59,7 +61,6 @@ from triton.runtime.cache import (
    default_dump_dir,
    default_override_dir,
)
from uvicorn.config import LOGGING_CONFIG

logger = logging.getLogger(__name__)
@@ -1366,7 +1367,33 @@ def nullable_str(val: str):
    return val


def pyspy_dump_schedulers():
    """py-spy dump on all scheduler in a local node."""
    try:
        pid = psutil.Process().pid
        # Command to run py-spy with the PID
        cmd = f"py-spy dump --pid {pid}"
        result = subprocess.run(
            cmd, shell=True, capture_output=True, text=True, check=True
        )
        logger.info(f"Profile for PID {pid}:\n{result.stdout}")
    except subprocess.CalledProcessError as e:
        logger.info(f"Failed to profile PID {pid}. Error: {e.stderr}")


def kill_itself_when_parent_died():
    if sys.platform == "linux":
        # sigkill this process when parent worker manager dies
        PR_SET_PDEATHSIG = 1
        libc = ctypes.CDLL("libc.so.6")
        libc.prctl(PR_SET_PDEATHSIG, signal.SIGKILL)
    else:
        logger.warning("kill_itself_when_parent_died is only supported in linux.")


def set_uvicorn_logging_configs():
    from uvicorn.config import LOGGING_CONFIG

    LOGGING_CONFIG["formatters"]["default"][
        "fmt"
    ] = "[%(asctime)s] %(levelprefix)s %(message)s"
@@ -1449,3 +1476,28 @@ def rank0_print(msg: str):

    if get_tensor_model_parallel_rank() == 0:
        print(msg, flush=True)


def launch_dummy_health_check_server(host, port):
    import uvicorn
    from fastapi import FastAPI, Response

    app = FastAPI()

    @app.get("/health")
    async def health():
        """Check the health of the http server."""
        return Response(status_code=200)

    @app.get("/health_generate")
    async def health_generate():
        """Check the health of the http server."""
        return Response(status_code=200)

    uvicorn.run(
        app,
        host=host,
        port=port,
        timeout_keep_alive=5,
        loop="uvloop",
    )