Return more infos for computing average acceptance length (#3152)
This commit is contained in:
@@ -57,6 +57,7 @@ from sglang.srt.utils import (
|
|||||||
assert_pkg_version,
|
assert_pkg_version,
|
||||||
configure_logger,
|
configure_logger,
|
||||||
kill_process_tree,
|
kill_process_tree,
|
||||||
|
launch_dummy_health_check_server,
|
||||||
maybe_set_triton_cache_manager,
|
maybe_set_triton_cache_manager,
|
||||||
prepare_model_and_tokenizer,
|
prepare_model_and_tokenizer,
|
||||||
set_prometheus_multiproc_dir,
|
set_prometheus_multiproc_dir,
|
||||||
@@ -400,14 +401,16 @@ def _launch_subprocesses(server_args: ServerArgs) -> Tuple[TokenizerManager, Dic
|
|||||||
|
|
||||||
if os.getenv("SGLANG_BLOCK_NONZERO_RANK_CHILDREN") == "0":
|
if os.getenv("SGLANG_BLOCK_NONZERO_RANK_CHILDREN") == "0":
|
||||||
# When using `Engine` as a Python API, we don't want to block here.
|
# When using `Engine` as a Python API, we don't want to block here.
|
||||||
return
|
return None, None
|
||||||
|
|
||||||
|
launch_dummy_health_check_server(server_args.host, server_args.port)
|
||||||
|
|
||||||
for proc in scheduler_procs:
|
for proc in scheduler_procs:
|
||||||
proc.join()
|
proc.join()
|
||||||
logger.error(
|
logger.error(
|
||||||
f"Scheduler or DataParallelController {proc.pid} terminated with {proc.exitcode}"
|
f"Scheduler or DataParallelController {proc.pid} terminated with {proc.exitcode}"
|
||||||
)
|
)
|
||||||
return
|
return None, None
|
||||||
|
|
||||||
# Launch detokenizer process
|
# Launch detokenizer process
|
||||||
detoken_proc = mp.Process(
|
detoken_proc = mp.Process(
|
||||||
|
|||||||
@@ -22,6 +22,8 @@ def compute_dp_attention_world_info(enable_dp_attention, tp_rank, tp_size, dp_si
|
|||||||
def initialize_dp_attention(enable_dp_attention, tp_rank, tp_size, dp_size):
|
def initialize_dp_attention(enable_dp_attention, tp_rank, tp_size, dp_size):
|
||||||
global _ATTN_TP_GROUP, _ATTN_TP_RANK, _ATTN_TP_SIZE, _DP_RANK, _DP_SIZE
|
global _ATTN_TP_GROUP, _ATTN_TP_RANK, _ATTN_TP_SIZE, _DP_RANK, _DP_SIZE
|
||||||
|
|
||||||
|
from sglang.srt.layers.sampler import SYNC_TOKEN_IDS_ACROSS_TP
|
||||||
|
|
||||||
_ATTN_TP_RANK, _ATTN_TP_SIZE, _DP_RANK = compute_dp_attention_world_info(
|
_ATTN_TP_RANK, _ATTN_TP_SIZE, _DP_RANK = compute_dp_attention_world_info(
|
||||||
enable_dp_attention, tp_rank, tp_size, dp_size
|
enable_dp_attention, tp_rank, tp_size, dp_size
|
||||||
)
|
)
|
||||||
@@ -35,7 +37,7 @@ def initialize_dp_attention(enable_dp_attention, tp_rank, tp_size, dp_size):
|
|||||||
],
|
],
|
||||||
tp_rank,
|
tp_rank,
|
||||||
torch.distributed.get_backend(tp_group.device_group),
|
torch.distributed.get_backend(tp_group.device_group),
|
||||||
False,
|
SYNC_TOKEN_IDS_ACROSS_TP,
|
||||||
False,
|
False,
|
||||||
False,
|
False,
|
||||||
False,
|
False,
|
||||||
|
|||||||
@@ -201,6 +201,7 @@ class DetokenizerManager:
|
|||||||
prompt_tokens=recv_obj.prompt_tokens,
|
prompt_tokens=recv_obj.prompt_tokens,
|
||||||
completion_tokens=recv_obj.completion_tokens,
|
completion_tokens=recv_obj.completion_tokens,
|
||||||
cached_tokens=recv_obj.cached_tokens,
|
cached_tokens=recv_obj.cached_tokens,
|
||||||
|
spec_verify_ct=recv_obj.spec_verify_ct,
|
||||||
input_token_logprobs_val=recv_obj.input_token_logprobs_val,
|
input_token_logprobs_val=recv_obj.input_token_logprobs_val,
|
||||||
input_token_logprobs_idx=recv_obj.input_token_logprobs_idx,
|
input_token_logprobs_idx=recv_obj.input_token_logprobs_idx,
|
||||||
output_token_logprobs_val=recv_obj.output_token_logprobs_val,
|
output_token_logprobs_val=recv_obj.output_token_logprobs_val,
|
||||||
|
|||||||
@@ -354,10 +354,13 @@ class BatchTokenIDOut:
|
|||||||
skip_special_tokens: List[bool]
|
skip_special_tokens: List[bool]
|
||||||
spaces_between_special_tokens: List[bool]
|
spaces_between_special_tokens: List[bool]
|
||||||
no_stop_trim: List[bool]
|
no_stop_trim: List[bool]
|
||||||
|
|
||||||
# Token counts
|
# Token counts
|
||||||
prompt_tokens: List[int]
|
prompt_tokens: List[int]
|
||||||
completion_tokens: List[int]
|
completion_tokens: List[int]
|
||||||
cached_tokens: List[int]
|
cached_tokens: List[int]
|
||||||
|
spec_verify_ct: List[int]
|
||||||
|
|
||||||
# Logprobs
|
# Logprobs
|
||||||
input_token_logprobs_val: List[float]
|
input_token_logprobs_val: List[float]
|
||||||
input_token_logprobs_idx: List[int]
|
input_token_logprobs_idx: List[int]
|
||||||
@@ -382,6 +385,7 @@ class BatchStrOut:
|
|||||||
prompt_tokens: List[int]
|
prompt_tokens: List[int]
|
||||||
completion_tokens: List[int]
|
completion_tokens: List[int]
|
||||||
cached_tokens: List[int]
|
cached_tokens: List[int]
|
||||||
|
spec_verify_ct: List[int]
|
||||||
|
|
||||||
# Logprobs
|
# Logprobs
|
||||||
input_token_logprobs_val: List[float]
|
input_token_logprobs_val: List[float]
|
||||||
|
|||||||
@@ -252,7 +252,6 @@ class Req:
|
|||||||
|
|
||||||
# Sampling info
|
# Sampling info
|
||||||
self.sampling_params = sampling_params
|
self.sampling_params = sampling_params
|
||||||
self.lora_path = lora_path
|
|
||||||
self.custom_logit_processor = custom_logit_processor
|
self.custom_logit_processor = custom_logit_processor
|
||||||
|
|
||||||
# Memory pool info
|
# Memory pool info
|
||||||
@@ -300,7 +299,7 @@ class Req:
|
|||||||
self.logprob_start_len = 0
|
self.logprob_start_len = 0
|
||||||
self.top_logprobs_num = top_logprobs_num
|
self.top_logprobs_num = top_logprobs_num
|
||||||
|
|
||||||
# Logprobs (return value)
|
# Logprobs (return values)
|
||||||
self.input_token_logprobs_val: Optional[List[float]] = None
|
self.input_token_logprobs_val: Optional[List[float]] = None
|
||||||
self.input_token_logprobs_idx: Optional[List[int]] = None
|
self.input_token_logprobs_idx: Optional[List[int]] = None
|
||||||
self.input_top_logprobs_val: Optional[List[float]] = None
|
self.input_top_logprobs_val: Optional[List[float]] = None
|
||||||
@@ -329,10 +328,15 @@ class Req:
|
|||||||
# Constrained decoding
|
# Constrained decoding
|
||||||
self.grammar: Optional[BaseGrammarObject] = None
|
self.grammar: Optional[BaseGrammarObject] = None
|
||||||
|
|
||||||
# The number of cached tokens, that were already cached in the KV cache
|
# The number of cached tokens that were already cached in the KV cache
|
||||||
self.cached_tokens = 0
|
self.cached_tokens = 0
|
||||||
self.already_computed = 0
|
self.already_computed = 0
|
||||||
|
|
||||||
|
# The number of verification forward passes in the speculative decoding.
|
||||||
|
# This is used to compute the average acceptance length per request.
|
||||||
|
self.spec_verify_ct = 0
|
||||||
|
self.lora_path = lora_path
|
||||||
|
|
||||||
def extend_image_inputs(self, image_inputs):
|
def extend_image_inputs(self, image_inputs):
|
||||||
if self.image_inputs is None:
|
if self.image_inputs is None:
|
||||||
self.image_inputs = image_inputs
|
self.image_inputs = image_inputs
|
||||||
|
|||||||
@@ -281,6 +281,7 @@ class Scheduler:
|
|||||||
# Print debug info
|
# Print debug info
|
||||||
logger.info(
|
logger.info(
|
||||||
f"max_total_num_tokens={self.max_total_num_tokens}, "
|
f"max_total_num_tokens={self.max_total_num_tokens}, "
|
||||||
|
f"chunked_prefill_size={server_args.chunked_prefill_size}, "
|
||||||
f"max_prefill_tokens={self.max_prefill_tokens}, "
|
f"max_prefill_tokens={self.max_prefill_tokens}, "
|
||||||
f"max_running_requests={self.max_running_requests}, "
|
f"max_running_requests={self.max_running_requests}, "
|
||||||
f"context_len={self.model_config.context_len}"
|
f"context_len={self.model_config.context_len}"
|
||||||
@@ -408,6 +409,11 @@ class Scheduler:
|
|||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# The largest prefill length of a single request
|
||||||
|
self._largest_prefill_len: int = 0
|
||||||
|
# The largest context length (prefill + generation) of a single request
|
||||||
|
self._largest_prefill_decode_len: int = 0
|
||||||
|
|
||||||
# Init request dispatcher
|
# Init request dispatcher
|
||||||
self._request_dispatcher = TypeBasedDispatcher(
|
self._request_dispatcher = TypeBasedDispatcher(
|
||||||
[
|
[
|
||||||
@@ -1371,6 +1377,7 @@ class Scheduler:
|
|||||||
prompt_tokens = []
|
prompt_tokens = []
|
||||||
completion_tokens = []
|
completion_tokens = []
|
||||||
cached_tokens = []
|
cached_tokens = []
|
||||||
|
spec_verify_ct = []
|
||||||
|
|
||||||
if return_logprob:
|
if return_logprob:
|
||||||
input_token_logprobs_val = []
|
input_token_logprobs_val = []
|
||||||
@@ -1424,6 +1431,9 @@ class Scheduler:
|
|||||||
completion_tokens.append(len(req.output_ids))
|
completion_tokens.append(len(req.output_ids))
|
||||||
cached_tokens.append(req.cached_tokens)
|
cached_tokens.append(req.cached_tokens)
|
||||||
|
|
||||||
|
if not self.spec_algorithm.is_none():
|
||||||
|
spec_verify_ct.append(req.spec_verify_ct)
|
||||||
|
|
||||||
if return_logprob:
|
if return_logprob:
|
||||||
input_token_logprobs_val.append(req.input_token_logprobs_val)
|
input_token_logprobs_val.append(req.input_token_logprobs_val)
|
||||||
input_token_logprobs_idx.append(req.input_token_logprobs_idx)
|
input_token_logprobs_idx.append(req.input_token_logprobs_idx)
|
||||||
@@ -1451,6 +1461,7 @@ class Scheduler:
|
|||||||
prompt_tokens,
|
prompt_tokens,
|
||||||
completion_tokens,
|
completion_tokens,
|
||||||
cached_tokens,
|
cached_tokens,
|
||||||
|
spec_verify_ct,
|
||||||
input_token_logprobs_val,
|
input_token_logprobs_val,
|
||||||
input_token_logprobs_idx,
|
input_token_logprobs_idx,
|
||||||
output_token_logprobs_val,
|
output_token_logprobs_val,
|
||||||
|
|||||||
@@ -785,6 +785,9 @@ class TokenizerManager:
|
|||||||
i,
|
i,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if self.server_args.speculative_algorithm:
|
||||||
|
meta_info["spec_verify_ct"] = recv_obj.spec_verify_ct[i]
|
||||||
|
|
||||||
if not isinstance(recv_obj, BatchEmbeddingOut):
|
if not isinstance(recv_obj, BatchEmbeddingOut):
|
||||||
meta_info.update(
|
meta_info.update(
|
||||||
{
|
{
|
||||||
@@ -809,6 +812,7 @@ class TokenizerManager:
|
|||||||
"embedding": recv_obj.embeddings[i],
|
"embedding": recv_obj.embeddings[i],
|
||||||
"meta_info": meta_info,
|
"meta_info": meta_info,
|
||||||
}
|
}
|
||||||
|
|
||||||
state.out_list.append(out_dict)
|
state.out_list.append(out_dict)
|
||||||
state.finished = recv_obj.finished_reasons[i] is not None
|
state.finished = recv_obj.finished_reasons[i] is not None
|
||||||
state.event.set()
|
state.event.set()
|
||||||
|
|||||||
@@ -38,7 +38,7 @@ if TYPE_CHECKING:
|
|||||||
from sglang.srt.model_executor.model_runner import ModelRunner
|
from sglang.srt.model_executor.model_runner import ModelRunner
|
||||||
|
|
||||||
|
|
||||||
def _to_torch(model: torch.nn.Module, reverse: bool, batch_size: int):
|
def _to_torch(model: torch.nn.Module, reverse: bool, num_tokens: int):
|
||||||
for sub in model._modules.values():
|
for sub in model._modules.values():
|
||||||
if isinstance(sub, CustomOp):
|
if isinstance(sub, CustomOp):
|
||||||
if reverse:
|
if reverse:
|
||||||
@@ -47,7 +47,7 @@ def _to_torch(model: torch.nn.Module, reverse: bool, batch_size: int):
|
|||||||
else:
|
else:
|
||||||
# NOTE: Temporarily workaround MoE
|
# NOTE: Temporarily workaround MoE
|
||||||
if "FusedMoE" in sub.__class__.__name__:
|
if "FusedMoE" in sub.__class__.__name__:
|
||||||
if batch_size == 1:
|
if num_tokens == 1:
|
||||||
# The performance of torch.compile on this layer is not always good when bs > 1,
|
# The performance of torch.compile on this layer is not always good when bs > 1,
|
||||||
# so we decide to only use torch.compile when bs =1
|
# so we decide to only use torch.compile when bs =1
|
||||||
sub._forward_method = fused_moe_forward_native
|
sub._forward_method = fused_moe_forward_native
|
||||||
@@ -55,14 +55,14 @@ def _to_torch(model: torch.nn.Module, reverse: bool, batch_size: int):
|
|||||||
sub._forward_method = sub.forward_native
|
sub._forward_method = sub.forward_native
|
||||||
setattr(sub, "is_torch_compile", True)
|
setattr(sub, "is_torch_compile", True)
|
||||||
if isinstance(sub, torch.nn.Module):
|
if isinstance(sub, torch.nn.Module):
|
||||||
_to_torch(sub, reverse, batch_size)
|
_to_torch(sub, reverse, num_tokens)
|
||||||
|
|
||||||
|
|
||||||
@contextmanager
|
@contextmanager
|
||||||
def patch_model(
|
def patch_model(
|
||||||
model: torch.nn.Module,
|
model: torch.nn.Module,
|
||||||
enable_compile: bool,
|
enable_compile: bool,
|
||||||
batch_size: int,
|
num_tokens: int,
|
||||||
tp_group: GroupCoordinator,
|
tp_group: GroupCoordinator,
|
||||||
):
|
):
|
||||||
"""Patch the model to make it compatible with torch.compile"""
|
"""Patch the model to make it compatible with torch.compile"""
|
||||||
@@ -70,7 +70,7 @@ def patch_model(
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
if enable_compile:
|
if enable_compile:
|
||||||
_to_torch(model, reverse=False, batch_size=batch_size)
|
_to_torch(model, reverse=False, num_tokens=num_tokens)
|
||||||
backup_ca_comm = tp_group.ca_comm
|
backup_ca_comm = tp_group.ca_comm
|
||||||
# Use custom-allreduce here.
|
# Use custom-allreduce here.
|
||||||
# We found the custom allreduce is much faster than the built-in allreduce in torch,
|
# We found the custom allreduce is much faster than the built-in allreduce in torch,
|
||||||
@@ -85,7 +85,7 @@ def patch_model(
|
|||||||
yield model.forward
|
yield model.forward
|
||||||
finally:
|
finally:
|
||||||
if enable_compile:
|
if enable_compile:
|
||||||
_to_torch(model, reverse=True, batch_size=batch_size)
|
_to_torch(model, reverse=True, num_tokens=num_tokens)
|
||||||
tp_group.ca_comm = backup_ca_comm
|
tp_group.ca_comm = backup_ca_comm
|
||||||
|
|
||||||
|
|
||||||
@@ -283,8 +283,8 @@ class CudaGraphRunner:
|
|||||||
with patch_model(
|
with patch_model(
|
||||||
self.model_runner.model,
|
self.model_runner.model,
|
||||||
bs in self.compile_bs,
|
bs in self.compile_bs,
|
||||||
bs,
|
num_tokens=bs * self.num_tokens_per_bs,
|
||||||
self.model_runner.tp_group,
|
tp_group=self.model_runner.tp_group,
|
||||||
) as forward:
|
) as forward:
|
||||||
(
|
(
|
||||||
graph,
|
graph,
|
||||||
|
|||||||
@@ -603,6 +603,7 @@ class EagleVerifyInput(SpecInfo):
|
|||||||
if not req.finished():
|
if not req.finished():
|
||||||
new_accept_index.extend(new_accept_index_)
|
new_accept_index.extend(new_accept_index_)
|
||||||
unfinished_index.append(i)
|
unfinished_index.append(i)
|
||||||
|
req.spec_verify_ct += 1
|
||||||
accept_length = (accept_index != -1).sum(dim=1) - 1
|
accept_length = (accept_index != -1).sum(dim=1) - 1
|
||||||
|
|
||||||
accept_index = accept_index[accept_index != -1]
|
accept_index = accept_index[accept_index != -1]
|
||||||
|
|||||||
@@ -14,6 +14,7 @@
|
|||||||
"""Common utilities."""
|
"""Common utilities."""
|
||||||
|
|
||||||
import base64
|
import base64
|
||||||
|
import ctypes
|
||||||
import dataclasses
|
import dataclasses
|
||||||
import io
|
import io
|
||||||
import ipaddress
|
import ipaddress
|
||||||
@@ -29,6 +30,7 @@ import shutil
|
|||||||
import signal
|
import signal
|
||||||
import socket
|
import socket
|
||||||
import subprocess
|
import subprocess
|
||||||
|
import sys
|
||||||
import tempfile
|
import tempfile
|
||||||
import time
|
import time
|
||||||
import warnings
|
import warnings
|
||||||
@@ -59,7 +61,6 @@ from triton.runtime.cache import (
|
|||||||
default_dump_dir,
|
default_dump_dir,
|
||||||
default_override_dir,
|
default_override_dir,
|
||||||
)
|
)
|
||||||
from uvicorn.config import LOGGING_CONFIG
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -1366,7 +1367,33 @@ def nullable_str(val: str):
|
|||||||
return val
|
return val
|
||||||
|
|
||||||
|
|
||||||
|
def pyspy_dump_schedulers():
    """Log a py-spy stack dump of the current process.

    ``psutil.Process()`` is called without a pid, so only the calling process
    is dumped (presumably this is invoked from inside each scheduler process;
    confirm against the callers). The dump, or the reason it failed, is logged
    at INFO level. Nothing is raised when py-spy exits with a non-zero status
    or cannot be started.
    """
    pid = psutil.Process().pid
    # Pass an argument list instead of a shell string: no shell is spawned and
    # nothing needs quoting.
    cmd = ["py-spy", "dump", "--pid", str(pid)]
    try:
        result = subprocess.run(cmd, capture_output=True, text=True, check=True)
    except subprocess.CalledProcessError as e:
        logger.info(f"Failed to profile PID {pid}. Error: {e.stderr}")
    except OSError as e:
        # Without a shell, a missing or non-executable py-spy binary surfaces
        # as OSError rather than as a non-zero exit status.
        logger.info(f"Failed to profile PID {pid}. Error: {e}")
    else:
        logger.info(f"Profile for PID {pid}:\n{result.stdout}")
|
||||||
|
|
||||||
|
|
||||||
|
def kill_itself_when_parent_died():
    """Ask the kernel to SIGKILL this process when its parent dies.

    On Linux this calls ``prctl(PR_SET_PDEATHSIG, SIGKILL)`` through libc.
    On any other platform it only logs a warning and does nothing else.
    """
    if sys.platform == "linux":
        # sigkill this process when parent worker manager dies
        PR_SET_PDEATHSIG = 1
        libc = ctypes.CDLL("libc.so.6")
        libc.prctl(PR_SET_PDEATHSIG, signal.SIGKILL)
    else:
        # The original called the misspelled `logger.warninig`, which raised
        # AttributeError on every non-Linux platform instead of warning.
        logger.warning("kill_itself_when_parent_died is only supported in linux.")
|
||||||
|
|
||||||
|
|
||||||
def set_uvicorn_logging_configs():
|
def set_uvicorn_logging_configs():
|
||||||
|
from uvicorn.config import LOGGING_CONFIG
|
||||||
|
|
||||||
LOGGING_CONFIG["formatters"]["default"][
|
LOGGING_CONFIG["formatters"]["default"][
|
||||||
"fmt"
|
"fmt"
|
||||||
] = "[%(asctime)s] %(levelprefix)s %(message)s"
|
] = "[%(asctime)s] %(levelprefix)s %(message)s"
|
||||||
@@ -1449,3 +1476,28 @@ def rank0_print(msg: str):
|
|||||||
|
|
||||||
if get_tensor_model_parallel_rank() == 0:
|
if get_tensor_model_parallel_rank() == 0:
|
||||||
print(msg, flush=True)
|
print(msg, flush=True)
|
||||||
|
|
||||||
|
|
||||||
|
def launch_dummy_health_check_server(host, port):
    """Run a minimal HTTP server that answers health probes with status 200.

    Builds a FastAPI app exposing ``GET /health`` and ``GET /health_generate``
    and serves it with uvicorn on the given ``host`` and ``port``. Both routes
    return an empty 200 response unconditionally; no real check is performed.
    ``uvicorn.run`` is called directly, so this call blocks while the server
    is running.
    """
    # Imported inside the function rather than at module level.
    import uvicorn
    from fastapi import FastAPI, Response

    app = FastAPI()

    @app.get("/health")
    async def health():
        """Check the health of the http server."""
        return Response(status_code=200)

    @app.get("/health_generate")
    async def health_generate():
        """Check the health of the http server."""
        return Response(status_code=200)

    # Serve the app with a 5-second keep-alive timeout on the uvloop event loop.
    uvicorn.run(
        app,
        host=host,
        port=port,
        timeout_keep_alive=5,
        loop="uvloop",
    )
|
||||||
|
|||||||
Reference in New Issue
Block a user