Improve streaming, log_level, memory report, weight loading, and benchmark script (#7632)
Co-authored-by: Kan Wu <wukanustc@gmail.com>
@@ -38,6 +38,7 @@ class BenchArgs:
output_len: Tuple[int] = (16,)
temperature: float = 0.0
return_logprob: bool = False
client_stream_interval: int = 1
input_len_step_percentage: float = 0.0
result_filename: str = "result.jsonl"
base_url: str = ""
@@ -60,6 +61,11 @@ class BenchArgs:
)
parser.add_argument("--temperature", type=float, default=BenchArgs.temperature)
parser.add_argument("--return-logprob", action="store_true")
parser.add_argument(
"--client-stream-interval",
type=int,
default=BenchArgs.client_stream_interval,
)
parser.add_argument(
"--input-len-step-percentage",
type=float,
@@ -120,6 +126,7 @@ def run_one_case(
output_len: int,
temperature: float,
return_logprob: bool,
stream_interval: int,
input_len_step_percentage: float,
run_name: str,
result_filename: str,
@@ -168,6 +175,7 @@ def run_one_case(
"max_new_tokens": output_len,
"ignore_eos": True,
"json_schema": json_schema,
"stream_interval": stream_interval,
},
"return_logprob": return_logprob,
"stream": True,
@@ -245,8 +253,9 @@ def run_benchmark(server_args: ServerArgs, bench_args: BenchArgs):
else:
proc, base_url = launch_server_process(server_args)

tokenizer_id = server_args.tokenizer_path or server_args.model_path
tokenizer = get_tokenizer(tokenizer_id)
server_info = requests.get(base_url + "/get_server_info")
tokenizer_path = server_info.json()["tokenizer_path"]
tokenizer = get_tokenizer(tokenizer_path)

# warmup
if not bench_args.skip_warmup:
@@ -258,6 +267,7 @@ def run_benchmark(server_args: ServerArgs, bench_args: BenchArgs):
output_len=16,
temperature=bench_args.temperature,
return_logprob=bench_args.return_logprob,
stream_interval=bench_args.client_stream_interval,
input_len_step_percentage=bench_args.input_len_step_percentage,
run_name="",
result_filename="",
@@ -280,6 +290,7 @@ def run_benchmark(server_args: ServerArgs, bench_args: BenchArgs):
ol,
temperature=bench_args.temperature,
return_logprob=bench_args.return_logprob,
stream_interval=bench_args.client_stream_interval,
input_len_step_percentage=bench_args.input_len_step_percentage,
run_name=bench_args.run_name,
result_filename=bench_args.result_filename,
@@ -301,6 +312,7 @@ def run_benchmark(server_args: ServerArgs, bench_args: BenchArgs):
ol,
temperature=bench_args.temperature,
return_logprob=bench_args.return_logprob,
stream_interval=bench_args.client_stream_interval,
input_len_step_percentage=bench_args.input_len_step_percentage,
run_name=bench_args.run_name,
result_filename=bench_args.result_filename,
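For reference, the new --client-stream-interval knob is forwarded as a per-request sampling parameter. A minimal standalone sketch of an equivalent streaming request (not part of the diff; the prompt and URL are placeholders, the payload keys mirror the hunk above):

import requests

payload = {
    "text": "The capital of France is",  # placeholder prompt
    "sampling_params": {
        "max_new_tokens": 16,
        "ignore_eos": True,
        "temperature": 0.0,
        "stream_interval": 4,  # per-request override of the server-side stream interval
    },
    "return_logprob": False,
    "stream": True,
}
# Consume the stream chunk by chunk, as the benchmark script does.
with requests.post("http://localhost:30000/generate", json=payload, stream=True) as resp:
    for chunk in resp.iter_lines():
        if chunk:
            print(chunk)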
@@ -1678,7 +1678,6 @@ def run_benchmark(args_: argparse.Namespace):
if args.base_url
else f"http://{args.host}:{args.port}/generate"
)
args.apply_chat_template = True
elif args.backend in ["sglang-oai", "vllm", "lmdeploy"]:
api_url = (
f"{args.base_url}/v1/completions"
@@ -147,12 +147,11 @@ class InternLM2Config(PretrainedConfig):
)
if (
rope_scaling_factor is None
or not isinstance(rope_scaling_factor, float)
or not isinstance(rope_scaling_factor, int)
or not isinstance(rope_scaling_factor, (float, int))
or rope_scaling_factor < 1.0
):
raise ValueError(
f"`rope_scaling`'s factor field must be a float|int >= 1, got {rope_scaling_factor}"
f"`rope_scaling`'s factor field must be a float|int >= 1, got {rope_scaling_factor=}, {type(rope_scaling_factor)=}"
)
if isinstance(rope_scaling_factor, int):
rope_scaling_factor = float(rope_scaling_factor)
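The old guard was unsatisfiable: a float fails isinstance(x, int) and an int fails isinstance(x, float), so one of the negated checks was always true and every factor raised. A tiny illustration (not from the diff):

rope_scaling_factor = 2.0
old_rejects = (
    rope_scaling_factor is None
    or not isinstance(rope_scaling_factor, float)
    or not isinstance(rope_scaling_factor, int)
    or rope_scaling_factor < 1.0
)
new_rejects = (
    rope_scaling_factor is None
    or not isinstance(rope_scaling_factor, (float, int))
    or rope_scaling_factor < 1.0
)
print(old_rejects, new_rejects)  # True False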
@@ -126,8 +126,6 @@ def set_global_state(global_state: _GlobalState):

@asynccontextmanager
async def lifespan(fast_api_app: FastAPI):
server_args: ServerArgs = fast_api_app.server_args

# Initialize OpenAI serving handlers
fast_api_app.state.openai_serving_completion = OpenAIServingCompletion(
_global_state.tokenizer_manager, _global_state.template_manager
@@ -145,9 +143,12 @@ async def lifespan(fast_api_app: FastAPI):
_global_state.tokenizer_manager
)

server_args: ServerArgs = fast_api_app.server_args
if server_args.warmups is not None:
await execute_warmups(
server_args.warmups.split(","), _global_state.tokenizer_manager
server_args.disaggregation_mode,
server_args.warmups.split(","),
_global_state.tokenizer_manager,
)
logger.info("Warmup ended")

@@ -280,13 +281,17 @@ async def get_model_info():
"model_path": _global_state.tokenizer_manager.model_path,
"tokenizer_path": _global_state.tokenizer_manager.server_args.tokenizer_path,
"is_generation": _global_state.tokenizer_manager.is_generation,
"preferred_sampling_params": _global_state.tokenizer_manager.server_args.preferred_sampling_params,
}
return result


@app.get("/get_server_info")
async def get_server_info():
internal_states = await _global_state.tokenizer_manager.get_internal_state()
# Returns internal states per DP.
internal_states: List[Dict[Any, Any]] = (
await _global_state.tokenizer_manager.get_internal_state()
)
return {
**dataclasses.asdict(_global_state.tokenizer_manager.server_args),
**_global_state.scheduler_info,
@@ -300,6 +305,8 @@ async def get_load():
return await _global_state.tokenizer_manager.get_load()


# example usage:
# curl -s -X POST http://localhost:30000/set_internal_state -H "Content-Type: application/json" -d '{"server_args": {"max_micro_batch_size": 8}}'
@app.api_route("/set_internal_state", methods=["POST", "PUT"])
async def set_internal_state(obj: SetInternalStateReq, request: Request):
res = await _global_state.tokenizer_manager.set_internal_state(obj)
@@ -886,6 +893,15 @@ def launch_server(
add_prometheus_middleware(app)
enable_func_timer()

image_token_text = None
if (
tokenizer_manager.image_token_id is not None
and not server_args.skip_tokenizer_init
):
image_token_text = tokenizer_manager.tokenizer.decode(
[tokenizer_manager.image_token_id]
)
# Send a warmup request - we will create the thread and launch it
# in the lifespan after all other warmups have fired.
warmup_thread = threading.Thread(
@@ -893,7 +909,7 @@ def launch_server(
args=(
server_args,
pipe_finish_writer,
_global_state.tokenizer_manager.image_token_id,
image_token_text,
launch_callback,
),
)
@@ -1022,9 +1038,10 @@ def _wait_and_warmup(
return

# Debug print
# logger.info(f"{res.json()=}")
# logger.info(f"warmup request returns: {res.json()=}")

logger.info("The server is fired up and ready to roll!")

if pipe_finish_writer is not None:
pipe_finish_writer.send("ready")
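A quick client-side check of the reworked endpoint (a sketch, not part of the diff; host and port are placeholders):

import requests

# The benchmark script above now reads the tokenizer path from the running server
# instead of deriving it from the local model path.
info = requests.get("http://localhost:30000/get_server_info").json()
print(info["tokenizer_path"])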
@@ -8,6 +8,7 @@ from sglang.srt.utils import is_hip

_is_hip = is_hip()


fused_softcap_autotune = triton.autotune(
configs=[
triton.Config(kwargs={"BLOCK_SIZE": 128}, num_warps=4),
@@ -189,21 +190,16 @@ def fused_dual_residual_rmsnorm(x, residual, weight1, weight2, eps, autotune=Fal
assert x.shape == residual.shape and x.dtype == residual.dtype
output, mid = torch.empty_like(x), torch.empty_like(x)
bs, hidden_dim = x.shape

min_num_warps = 16 if _is_hip else 32

if autotune:
fused_dual_residual_rmsnorm_kernel_autotune[(bs,)](
output, mid, x, residual, weight1, weight2, eps=eps, hidden_dim=hidden_dim
)
else:
max_warps = 16 if _is_hip else 32
config = {
"BLOCK_SIZE": triton.next_power_of_2(hidden_dim),
"num_warps": max(
min(
triton.next_power_of_2(triton.cdiv(hidden_dim, 256)), min_num_warps
),
4,
min(triton.next_power_of_2(triton.cdiv(hidden_dim, 256)), max_warps), 4
),
}

@@ -260,13 +256,11 @@ def fused_rmsnorm(x, weight, eps, autotune=False, inplace=False):
else:
output = torch.empty_like(x)
bs, hidden_dim = x.shape

min_num_warps = 16 if _is_hip else 32

max_warps = 16 if _is_hip else 32
config = {
"BLOCK_SIZE": triton.next_power_of_2(hidden_dim),
"num_warps": max(
min(triton.next_power_of_2(triton.cdiv(hidden_dim, 256)), min_num_warps), 4
min(triton.next_power_of_2(triton.cdiv(hidden_dim, 256)), max_warps), 4
),
}

@@ -331,6 +325,75 @@ class FusedDualResidualRMSNorm:
return self.rmsnorm2.forward_native(residual), residual


@triton.jit
def experts_combine_kernel(
out_hidden_states,
moe_hidden_states,
mlp_hidden_states,
combine_k: tl.constexpr,
hidden_dim: tl.constexpr,
BLOCK_SIZE: tl.constexpr,
):
pid = tl.program_id(0)
start_index_mlp = pid * hidden_dim
start_index_rmoe = pid * hidden_dim * combine_k
offsets = tl.arange(0, BLOCK_SIZE)
mask = offsets < hidden_dim
combine_k_offsets = tl.arange(0, combine_k)

moe_x = tl.load(
moe_hidden_states
+ start_index_rmoe
+ combine_k_offsets[:, None] * hidden_dim
+ offsets[None, :],
mask=mask[None, :],
other=0.0,
)
moe_x = tl.sum(moe_x, axis=0)
mlp_x = tl.load(mlp_hidden_states + start_index_mlp + offsets, mask=mask, other=0.0)
combined_x = (moe_x + mlp_x) / 1.4142135623730951

tl.store(out_hidden_states + start_index_mlp + offsets, combined_x, mask=mask)


def experts_combine_triton(moe_hidden_states, mlp_hidden_states, output_buffer=None):
assert moe_hidden_states.is_contiguous()
assert mlp_hidden_states.is_contiguous()

if len(moe_hidden_states.shape) == 2:
combine_k = 1 # pre-combined
else:
combine_k = moe_hidden_states.shape[1]

if output_buffer is None:
out_hidden_states = torch.empty_like(mlp_hidden_states)
else:
flat_output_buffer = output_buffer.view(mlp_hidden_states.dtype).reshape(-1)
assert flat_output_buffer.numel() >= mlp_hidden_states.numel()
out_hidden_states = flat_output_buffer[: mlp_hidden_states.numel()].reshape(
mlp_hidden_states.shape
)

bs, hidden_dim = mlp_hidden_states.shape

config = {
"BLOCK_SIZE": triton.next_power_of_2(hidden_dim),
"num_warps": max(
min(triton.next_power_of_2(triton.cdiv(hidden_dim, 1024)), 8), 4
),
}

experts_combine_kernel[(bs,)](
out_hidden_states,
moe_hidden_states,
mlp_hidden_states,
combine_k,
hidden_dim,
**config,
)
return out_hidden_states


# gelu on first half of vector
@triton.jit
def gelu_and_mul_kernel(
@@ -400,10 +463,11 @@ def gelu_and_mul_triton(
out_scales = scales
static_scale = True

max_warps = 16 if _is_hip else 32
config = {
# 8 ele per thread (not tuned)
"num_warps": max(
min(triton.next_power_of_2(triton.cdiv(hidden_dim, 8 * 32)), 32), 4
min(triton.next_power_of_2(triton.cdiv(hidden_dim, 8 * 32)), max_warps), 4
),
}
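For intuition, the new experts_combine kernel sums the combine_k expert outputs per token, adds the MLP branch, and scales by 1/sqrt(2). A plain-PyTorch sketch of the same arithmetic (illustrative only, not the kernel itself):

import math
import torch

def experts_combine_reference(moe_hidden_states: torch.Tensor,
                              mlp_hidden_states: torch.Tensor) -> torch.Tensor:
    # moe_hidden_states: (bs, combine_k, hidden_dim), or (bs, hidden_dim) if pre-combined
    if moe_hidden_states.dim() == 3:
        moe_hidden_states = moe_hidden_states.sum(dim=1)
    return (moe_hidden_states + mlp_hidden_states) / math.sqrt(2.0)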
@@ -1,4 +1,4 @@
from typing import Tuple
from typing import Optional, Tuple

import torch
import triton
@@ -16,6 +16,8 @@ def fused_moe_router_kernel(
moe_router_weight_ptr, # input (num_experts, hidden_dim)
topk_weights_ptr, # output (bs, topk)
topk_ids_ptr, # output (bs, topk)
correction_bias_ptr,
is_correction_bias: tl.constexpr,
num_experts: tl.constexpr,
topk: tl.constexpr,
moe_softcapping: tl.constexpr,
@@ -49,6 +51,11 @@ def fused_moe_router_kernel(
bottom = exped + 1
logits_softcapped = top / bottom * moe_softcapping

# Add bias after softcapping
if is_correction_bias:
bias = tl.load(correction_bias_ptr + tl.arange(0, num_experts))
logits_softcapped = logits_softcapped + bias

# topk
# assert 1 <= topk <= num_experts

@@ -109,6 +116,7 @@ def fused_moe_router_impl(
router_weight: torch.Tensor,
topk: int,
moe_softcapping: float,
correction_bias: Optional[torch.Tensor] = None,
):
assert len(x.shape) == 2 and x.shape[1] == router_weight.shape[1]
bs, hidden_dim = x.shape
@@ -117,23 +125,23 @@ def fused_moe_router_impl(
# router_logits = torch.empty((bs, num_experts), dtype=torch.float32, device=x.device)
topk_weights = torch.empty((bs, topk), dtype=torch.float32, device=x.device)
topk_ids = torch.empty((bs, topk), dtype=torch.int32, device=x.device)
is_correction_bias = correction_bias is not None

grid = lambda meta: (bs,)

min_num_warps = 16 if _is_hip else 32

max_warps = 16 if _is_hip else 32
config = {
"BLOCK_SIZE": triton.next_power_of_2(hidden_dim),
"num_warps": max(
min(triton.next_power_of_2(triton.cdiv(hidden_dim, 256)), min_num_warps), 4
min(triton.next_power_of_2(triton.cdiv(hidden_dim, 256)), max_warps), 4
),
}

fused_moe_router_kernel[grid](
fused_moe_router_kernel[(bs,)](
x,
router_weight,
topk_weights,
topk_ids,
correction_bias,
is_correction_bias=is_correction_bias,
num_experts=num_experts,
topk=topk,
moe_softcapping=moe_softcapping,
@@ -153,7 +161,7 @@ def fused_moe_router_large_bs_kernel(
topk_ids_ptr, # output (bs, topk)
bs,
num_experts: tl.constexpr,
topk: tl.constexpr, # only support topk == 1
topk: tl.constexpr, # only support topk <= 2
moe_softcapping: tl.constexpr,
moe_renormalize: tl.constexpr, # not supported
K: tl.constexpr,
@@ -204,25 +212,53 @@ def fused_moe_router_large_bs_kernel(
logits_softcapped = (exped - 1) / (exped + 1) * moe_softcapping

# 5. top1
cond = tl.arange(0, BLOCK_SIZE_N)[None, :] < num_experts
top1 = tl.argmax(tl.where(cond, logits_softcapped, float("-inf")), axis=1)
arange_block_size_n = tl.arange(0, BLOCK_SIZE_N)[None, :]
cond_top1 = arange_block_size_n < num_experts
top1 = tl.argmax(tl.where(cond_top1, logits_softcapped, float("-inf")), axis=1)
top1_v = tl.max(
tl.where(cond, logits_softcapped, float("-inf")), axis=1, keep_dims=True
tl.where(cond_top1, logits_softcapped, float("-inf")), axis=1, keep_dims=True
)
invsumexp = 1.0 / tl.sum(
tl.where(cond, tl.exp(logits_softcapped - top1_v), 0.0), axis=1
top1_invsumexp = 1.0 / tl.sum(
tl.where(cond_top1, tl.exp(logits_softcapped - top1_v), 0.0), axis=1
)

# 6. store to output
offs_topk = pid * topk * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
topk_mask = offs_topk < bs
tl.store(topk_ids_ptr + offs_topk, top1, mask=topk_mask)
# 6. store top1 to output
offs_top1 = pid * topk * BLOCK_SIZE_M + topk * tl.arange(0, BLOCK_SIZE_M)
top1_mask = offs_top1 < bs * topk
tl.store(topk_ids_ptr + offs_top1, top1, mask=top1_mask)
tl.store(
topk_weights_ptr + offs_topk,
invsumexp,
mask=topk_mask,
topk_weights_ptr + offs_top1,
top1_invsumexp,
mask=top1_mask,
)

# 7. handle topk == 2
if topk == 2:
cond_top2 = (arange_block_size_n < num_experts) and (
arange_block_size_n != top1[:, None]
)
top2 = tl.argmax(
tl.where(cond_top2, logits_softcapped, float("-inf")),
axis=1,
keep_dims=True,
)
top2_v = tl.sum(
logits_softcapped * (arange_block_size_n == top2), axis=1, keep_dims=True
)
top2_invsumexp = tl.exp(top2_v - top1_v) * top1_invsumexp[:, None]

# store top2
offs_top2 = (
pid * topk * BLOCK_SIZE_M + topk * tl.arange(0, BLOCK_SIZE_M)[:, None] + 1
)
top2_mask = offs_top2 < bs * topk
tl.store(topk_ids_ptr + offs_top2, top2, mask=top2_mask)
tl.store(
topk_weights_ptr + offs_top2,
top2_invsumexp,
mask=top2_mask,
)


def fused_moe_router_large_bs_impl(
x: torch.Tensor,
@@ -239,7 +275,7 @@ def fused_moe_router_large_bs_impl(

assert num_experts <= BLOCK_SIZE_N
assert hidden_dim % BLOCK_SIZE_K == 0
assert topk == 1
assert topk <= 2

topk_weights = torch.empty((bs, topk), dtype=torch.float32, device=x.device)
topk_ids = torch.empty((bs, topk), dtype=torch.int32, device=x.device)
@@ -273,6 +309,7 @@ def fused_moe_router_shim(
gating_output,
topk,
renormalize,
correction_bias: Optional[torch.Tensor] = None,
):
assert not renormalize
assert (
@@ -286,7 +323,7 @@ def fused_moe_router_shim(
BLOCK_SIZE_K = 256
if (
bs >= 512
and topk == 1
and topk <= 2
and num_experts <= BLOCK_SIZE_N
and hidden_dim % BLOCK_SIZE_K == 0
):
@@ -305,6 +342,7 @@ def fused_moe_router_shim(
router_weight=gating_output,
topk=topk,
moe_softcapping=moe_softcapping,
correction_bias=correction_bias,
)
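The large-batch router now stores softmax weights for the top-2 experts instead of only the top-1. A rough PyTorch sketch of what the stored weights correspond to (this assumes Gemma-2-style tanh softcapping and ignores the correction bias; it is a reference for the math, not a drop-in replacement for the kernel):

import torch

def router_topk_reference(logits: torch.Tensor, moe_softcapping: float, topk: int = 2):
    # Softcap the router logits, take a softmax over all experts, keep the top-k entries.
    capped = torch.tanh(logits / moe_softcapping) * moe_softcapping
    probs = torch.softmax(capped, dim=-1)
    topk_weights, topk_ids = probs.topk(topk, dim=-1)
    return topk_weights, topk_ids.to(torch.int32)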
@@ -28,7 +28,7 @@ if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--url", type=str, default="http://localhost:30000")
parser.add_argument("--log-requests", action="store_true")
parser.add_argument("--log-requests-level", type=int, default=2)
parser.add_argument("--log-requests-level", type=int, default=3)
parser.add_argument(
"--dump-requests-folder", type=str, default="/tmp/sglang_request_dump"
)
@@ -516,9 +516,6 @@ class EmbeddingReqInput:
# For cross-encoder requests
is_cross_encoder_request: bool = False

def contains_mm_input(self) -> bool:
return has_valid_data(self.image_data) or has_valid_data(self.audio_data)

def normalize_batch_and_arguments(self):
# at least one of text, input_ids, or image should be provided
if self.text is None and self.input_ids is None and self.image_data is None:
@@ -572,6 +569,9 @@ class EmbeddingReqInput:
self.rid = uuid.uuid4().hex
return self.rid

def contains_mm_input(self) -> bool:
return has_valid_data(self.image_data) or has_valid_data(self.audio_data)

def __getitem__(self, i):
if self.is_cross_encoder_request:
return EmbeddingReqInput(
@@ -2,12 +2,15 @@
Multi-modality utils
"""

import hashlib
from abc import abstractmethod
from typing import Callable, List, Optional, Tuple

import numpy as np
import torch
from torch import nn

from sglang.srt.layers.multimodal import gpu_tensor_hash
from sglang.srt.managers.schedule_batch import (
Modality,
MultimodalDataItem,
@@ -678,3 +681,52 @@ def get_multimodal_data_bounds(
# Convert valid pairs to tensor
valid_pairs_tensor = torch.tensor(valid_pairs, device=input_ids.device)
return valid_pairs_tensor


def data_hash(data) -> int:
hash_bytes = hashlib.sha256(data).digest()[:8]
return int.from_bytes(hash_bytes, byteorder="big", signed=False)


def tensor_hash(tensor_list) -> int:
"""
hash a tensor or a tensor list
"""
tensor = tensor_list
if isinstance(tensor_list, list):
tensor_list = flatten_nested_list(tensor_list)
tensor_list = [
x.flatten() if isinstance(x, torch.Tensor) else x for x in tensor_list
]
tensor = torch.concat(tensor_list)
if tensor.is_cuda:
return gpu_tensor_hash(tensor)
tensor = tensor.detach().contiguous()

if tensor.dtype == torch.bfloat16:
# memoryview() doesn't support PyTorch's BFloat16 dtype
tensor = tensor.float()

assert isinstance(tensor, torch.Tensor)
if tensor.is_cuda:
# TODO: improve this
tensor_cpu = tensor.cpu()
else:
tensor_cpu = tensor

mv = memoryview(tensor_cpu.numpy())
return data_hash(mv.tobytes())


def hash_feature(f):
if isinstance(f, list):
if isinstance(f[0], torch.Tensor):
return tensor_hash(f)
return data_hash(tuple(flatten_nested_list(f)))
elif isinstance(f, np.ndarray):
arr = np.ascontiguousarray(f)
arr_bytes = arr.tobytes()
return data_hash(arr_bytes)
elif isinstance(f, torch.Tensor):
return tensor_hash([f])
return data_hash(f)
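The hashing helpers now live in mm_utils so schedule_batch can import them instead of defining them inline. A minimal usage sketch (illustrative only):

import torch
from sglang.srt.managers.mm_utils import hash_feature

# CPU tensors are hashed through sha256 of their raw bytes; CUDA tensors go through gpu_tensor_hash.
features = torch.randn(3, 8)
print(hash_feature(features))
print(hash_feature([features, features + 1]))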
@@ -3,7 +3,6 @@ import importlib
import inspect
import logging
import pkgutil
from functools import lru_cache

from sglang.srt.multimodal.processors.base_processor import BaseMultimodalProcessor
from sglang.srt.server_args import ServerArgs
@@ -33,7 +33,6 @@ TODO(lmzheng): ModelWorkerBatch seems a bit redundant and we consider removing i

import copy
import dataclasses
import hashlib
import logging
import threading
from enum import Enum, auto
@@ -53,7 +52,6 @@ from sglang.srt.disaggregation.decode_schedule_batch_mixin import (
ScheduleBatchDisaggregationDecodeMixin,
)
from sglang.srt.distributed.parallel_state import get_tensor_model_parallel_rank
from sglang.srt.layers.multimodal import gpu_tensor_hash
from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
from sglang.srt.mem_cache.chunk_cache import ChunkCache, SWAChunkCache
@@ -96,8 +94,8 @@ GLOBAL_SERVER_ARGS_KEYS = [
"max_micro_batch_size",
"disable_shared_experts_fusion",
"sampling_backend",
"speculative_accept_threshold_acc",
"speculative_accept_threshold_single",
"speculative_accept_threshold_acc",
"torchao_config",
"triton_attention_reduce_in_fp32",
"num_reserved_decode_tokens",
@@ -180,7 +178,9 @@ class Modality(Enum):
@dataclasses.dataclass
class MultimodalDataItem:
"""
A single multimodal data, from a single image/video/audio or others.
One MultimodalDataItem contains all inputs for one modality.
For example, if there are 3 images and 1 audio inputs, there will be 2 MultimodalDataItem.
One for images and one for audio.

We put the common fields first and the model-specific fields last.
"""
@@ -232,53 +232,7 @@ class MultimodalDataItem:
"""
Set the pad value after first hashing the data
"""

def data_hash(data) -> int:
hash_bytes = hashlib.sha256(data).digest()[:8]
return int.from_bytes(hash_bytes, byteorder="big", signed=False)

def tensor_hash(tensor_list) -> int:
"""
hash a tensor or a tensor list
"""
tensor = tensor_list
if isinstance(tensor_list, list):
tensor_list = flatten_nested_list(tensor_list)
tensor_list = [
x.flatten() if isinstance(x, torch.Tensor) else x
for x in tensor_list
]
tensor = torch.concat(tensor_list)
if tensor.is_cuda:
return gpu_tensor_hash(tensor)
tensor = tensor.detach().contiguous()

if tensor.dtype == torch.bfloat16:
# memoryview() doesn't support PyTorch's BFloat16 dtype
tensor = tensor.float()

assert isinstance(tensor, torch.Tensor)
if tensor.is_cuda:
# TODO: improve this
tensor_cpu = tensor.cpu()
else:
tensor_cpu = tensor

mv = memoryview(tensor_cpu.numpy())
return data_hash(mv.tobytes())

def hash_feature(f):
if isinstance(f, list):
if isinstance(f[0], torch.Tensor):
return tensor_hash(f)
return data_hash(tuple(flatten_nested_list(f)))
elif isinstance(f, np.ndarray):
arr = np.ascontiguousarray(f)
arr_bytes = arr.tobytes()
return data_hash(arr_bytes)
elif isinstance(f, torch.Tensor):
return tensor_hash([f])
return data_hash(f)
from sglang.srt.managers.mm_utils import hash_feature

if self.precomputed_features is not None:
self.hash = hash_feature(self.precomputed_features)
@@ -418,14 +418,16 @@ class Scheduler(
self.last_decode_stats_tic = time.perf_counter()
self.last_prefill_stats_tic = time.perf_counter()
self.return_health_check_ct = 0
self.num_retracted_reqs: int = 0
self.num_paused_reqs: int = 0
self.kv_transfer_speed_gb_s: float = 0.0
self.kv_transfer_latency_ms: float = 0.0
self.sessions: Dict[str, Session] = {}
self.current_stream = torch.get_device_module(self.device).current_stream()
if self.device == "cpu":
self.current_stream.synchronize = lambda: None # No-op for CPU
self.forward_sleep_time = None

# Init session info
self.sessions: Dict[str, Session] = {}

# Init chunked prefill
self.chunked_prefill_size = server_args.chunked_prefill_size
if self.chunked_prefill_size <= 0: # -1 means disable
@@ -473,26 +475,12 @@ class Scheduler(
t = threading.Thread(target=self.watchdog_thread, daemon=True)
t.start()
self.parent_process = psutil.Process().parent()

# Init memory saver, profiler and metric stats
self.memory_saver_adapter = TorchMemorySaverAdapter.create(
enable=server_args.enable_memory_saver
)

# Init profiler
self.torch_profiler = None
self.torch_profiler_output_dir: Optional[str] = None
self.profiler_activities: Optional[List[str]] = None
self.profile_id: Optional[str] = None
self.profiler_target_forward_ct: Optional[int] = None
self.profiler_target_prefill_ct: Optional[int] = None
self.profiler_target_decode_ct: Optional[int] = None
self.profiler_prefill_ct: Optional[int] = None
self.profiler_decode_ct: Optional[int] = None
self.profile_by_stage: bool = False
self.profile_steps: Optional[int] = None
self.profile_in_progress: bool = False
self.rpd_profiler = None

# Init metrics stats
self.init_profier()
self.init_metrics()
self.init_kv_events(server_args.kv_events_config)

@@ -526,6 +514,7 @@ class Scheduler(
]
)

# Init disaggregation
self.disaggregation_mode = DisaggregationMode(
self.server_args.disaggregation_mode
)
@@ -624,6 +613,21 @@ class Scheduler(
)
)

def init_profier(self):
self.torch_profiler = None
self.torch_profiler_output_dir: Optional[str] = None
self.profiler_activities: Optional[List[str]] = None
self.profile_id: Optional[str] = None
self.profiler_target_forward_ct: Optional[int] = None
self.profiler_target_prefill_ct: Optional[int] = None
self.profiler_target_decode_ct: Optional[int] = None
self.profiler_prefill_ct: Optional[int] = None
self.profiler_decode_ct: Optional[int] = None
self.profile_by_stage: bool = False
self.profile_steps: Optional[int] = None
self.profile_in_progress: bool = False
self.rpd_profiler = None

def init_metrics(self):
self.last_gen_throughput: float = 0.0
self.last_input_throughput: float = 0.0
@@ -2107,6 +2111,18 @@ class Scheduler(
def get_internal_state(self, recv_req: GetInternalStateReq):
ret = dict(global_server_args_dict)
ret["last_gen_throughput"] = self.last_gen_throughput
ret["memory_usage"] = {
"weight": round(
self.tp_worker.worker.model_runner.weight_load_mem_usage, 2
),
"kvcache": round(
self.token_to_kv_pool_allocator.get_kvcache().mem_usage, 2
),
"cuda_graph": round(
self.tp_worker.worker.model_runner.cuda_graph_mem_usage, 2
),
"token_capacity": int(self.max_total_num_tokens),
}
if not self.spec_algorithm.is_none() and self.cum_spec_accept_count > 0:
ret["avg_spec_accept_length"] = (
self.cum_spec_accept_length / self.cum_spec_accept_count
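The new memory report travels back to clients through the server info endpoint shown earlier. A sketch of reading it (the "internal_states" key is an assumption; the hunk above only shows the fields inside memory_usage):

import requests

info = requests.get("http://localhost:30000/get_server_info").json()
for state in info.get("internal_states", []):  # one entry per DP rank
    usage = state["memory_usage"]
    print(usage["weight"], usage["kvcache"], usage["cuda_graph"], usage["token_capacity"])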
@@ -521,11 +521,17 @@ class SchedulerOutputProcessorMixin:
stream_interval = (
req.sampling_params.stream_interval or self.stream_interval
)
should_output = len(req.output_ids) % stream_interval == 0
should_output = (
len(req.output_ids) % stream_interval == 1
if not self.model_config.is_multimodal_gen
and stream_interval > 1
else len(req.output_ids) % stream_interval == 0
)
else:
should_output = (
len(req.output_ids) % DEFAULT_FORCE_STREAM_INTERVAL == 0
and not self.model_config.is_multimodal_gen
if not self.model_config.is_multimodal_gen
else False
)

if should_output:
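With this change, a non-multimodal request whose stream interval is greater than 1 flushes its first chunk right after the first decoded token (count % interval == 1) instead of waiting for a full interval. A tiny illustration of the emission points (not from the diff):

stream_interval = 4
old_points = [n for n in range(1, 13) if n % stream_interval == 0]  # [4, 8, 12]
new_points = [n for n in range(1, 13) if n % stream_interval == 1]  # [1, 5, 9]
print(old_points, new_points)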
@@ -111,11 +111,7 @@ from sglang.srt.managers.io_struct import (
UpdateWeightsFromTensorReqInput,
UpdateWeightsFromTensorReqOutput,
)
from sglang.srt.managers.multimodal_processor import (
get_dummy_processor,
get_mm_processor,
import_processors,
)
from sglang.srt.managers.multimodal_processor import get_mm_processor, import_processors
from sglang.srt.metrics.collector import TokenizerMetricsCollector
from sglang.srt.sampling.sampling_params import SamplingParams
from sglang.srt.server_args import PortArgs, ServerArgs
@@ -187,6 +183,8 @@ class TokenizerManager:
if server_args.preferred_sampling_params
else None
)
self.crash_dump_folder = server_args.crash_dump_folder
self.crash_dump_performed = False # Flag to ensure dump is only called once

# Init inter-process communication
context = zmq.asyncio.Context(2)
@@ -251,10 +249,11 @@ class TokenizerManager:
self.dump_requests_folder = "" # By default do not dump
self.dump_requests_threshold = 1000
self.dump_request_list: List[Tuple] = []
self.crash_dump_request_list: deque[Tuple] = deque()
self.log_request_metadata = self.get_log_request_metadata()
self.asyncio_tasks = set()
self.session_futures = {} # session_id -> asyncio event
self.max_req_input_len = None
self.asyncio_tasks = set()

# The event to notify the weight sync is finished.
self.model_update_lock = RWLock()
@@ -266,14 +265,14 @@ class TokenizerManager:
self.disaggregation_mode = DisaggregationMode(
self.server_args.disaggregation_mode
)
self.transfer_backend = TransferBackend(
self.disaggregation_transfer_backend = TransferBackend(
self.server_args.disaggregation_transfer_backend
)
# Start kv boostrap server on prefill
if self.disaggregation_mode == DisaggregationMode.PREFILL:
# only start bootstrap server on prefill tm
kv_bootstrap_server_class = get_kv_class(
self.transfer_backend, KVClassType.BOOTSTRAP_SERVER
self.disaggregation_transfer_backend, KVClassType.BOOTSTRAP_SERVER
)
self.bootstrap_server = kv_bootstrap_server_class(
self.server_args.disaggregation_bootstrap_port
@@ -324,7 +323,6 @@ class TokenizerManager:
self.profile_communicator = _Communicator(
self.send_to_scheduler, server_args.dp_size
)
self.health_check_communitcator = _Communicator(self.send_to_scheduler, 1)
self.get_internal_state_communicator = _Communicator(
self.send_to_scheduler, server_args.dp_size
)
@@ -484,7 +482,7 @@ class TokenizerManager:
token_type_ids = encoded.get("token_type_ids", [None])[0]

if self.mm_processor and obj.contains_mm_input():
image_inputs = await self.mm_processor.process_mm_data_async(
image_inputs: Dict = await self.mm_processor.process_mm_data_async(
image_data=obj.image_data,
input_text=input_text or input_ids,
request_obj=obj,
@@ -547,6 +545,14 @@ class TokenizerManager:
"Please set `--enable-custom-logits-processor` to enable this feature."
)

def _validate_input_ids_in_vocab(
self, input_ids: List[int], vocab_size: int
) -> None:
if any(id >= vocab_size for id in input_ids):
raise ValueError(
f"The input_ids {input_ids} contains values greater than the vocab size ({vocab_size})."
)

def _create_tokenized_object(
self,
obj: Union[GenerateReqInput, EmbeddingReqInput],
@@ -1096,12 +1102,36 @@ class TokenizerManager:
"image_data",
"audio_data",
"lora_path",
"sampling_params",
]
)
out_skip_names = set(
[
"text",
"output_ids",
]
)
out_skip_names = set(["text", "output_ids", "embedding"])
elif self.log_requests_level == 1:
max_length = 2048
max_length = 1 << 30
skip_names = set(
[
"text",
"input_ids",
"input_embeds",
"image_data",
"audio_data",
"lora_path",
]
)
out_skip_names = set(
[
"text",
"output_ids",
]
)
elif self.log_requests_level == 2:
max_length = 2048
elif self.log_requests_level == 3:
max_length = 1 << 30
else:
raise ValueError(
@@ -1118,6 +1148,8 @@ class TokenizerManager:
self.dump_requests_folder = obj.dump_requests_folder
if obj.dump_requests_threshold is not None:
self.dump_requests_threshold = obj.dump_requests_threshold
if obj.crash_dump_folder is not None:
self.crash_dump_folder = obj.crash_dump_folder
logging.info(f"Config logging: {obj=}")
self.log_request_metadata = self.get_log_request_metadata()

@@ -1166,6 +1198,52 @@ class TokenizerManager:
loop.create_task(print_exception_wrapper(self.sigterm_watchdog))
)

def dump_requests_before_crash(self):
if self.crash_dump_performed:
logger.info(
"SIGTERM/SIGQUIT/Exception triggered, but crash dump already performed, skipping."
)
return
logger.error(f"Dumping requests before crash. {self.crash_dump_folder=}")
self.crash_dump_performed = True
if not self.crash_dump_folder:
return

data_to_dump = []
if self.crash_dump_request_list:
data_to_dump.extend(self.crash_dump_request_list)

# Add unfinished requests from rid_to_state
unfinished_requests = []
for rid, state in self.rid_to_state.items():
if not state.finished:
unfinished_requests.append(
(state.obj, {}, state.created_time, time.time())
)
if unfinished_requests:
data_to_dump.extend(unfinished_requests)

if not data_to_dump:
return

filename = os.path.join(
self.crash_dump_folder,
os.getenv("HOSTNAME", None),
f'crash_dump_{datetime.now().strftime("%Y-%m-%d_%H-%M-%S")}.pkl',
)

os.makedirs(os.path.dirname(filename), exist_ok=True)
# Include server_args in the dump
data_to_dump_with_server_args = {
"server_args": self.server_args,
"requests": data_to_dump,
}
with open(filename, "wb") as f:
pickle.dump(data_to_dump_with_server_args, f)
logger.error(
f"Dumped {len(self.crash_dump_request_list)} finished and {len(unfinished_requests)} unfinished requests before crash to {filename}"
)
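The crash dump is a pickle containing the server args plus the recent finished and unfinished requests. A sketch of inspecting one offline (the file path is a placeholder):

import pickle

with open("/tmp/crash_dump/myhost/crash_dump_2025-01-01_00-00-00.pkl", "rb") as f:  # placeholder path
    dump = pickle.load(f)
print(dump["server_args"])
for obj, out_dict, created_time, finished_time in dump["requests"]:
    print(type(obj).__name__, round(finished_time - created_time, 3))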
async def sigterm_watchdog(self):
while not self.gracefully_exit:
await asyncio.sleep(5)
@@ -1175,11 +1253,12 @@ class TokenizerManager:
remain_num_req = len(self.rid_to_state)

if self.health_check_failed:
# if health check failed, exit immediately
# if health check failed, we should exit immediately
logger.error(
"Signal SIGTERM received while health check failed. Exiting... remaining number of requests: %d",
remain_num_req,
)
self.dump_requests_before_crash()
break

elif get_bool_env_var("SGL_FORCE_SHUTDOWN"):
@@ -1196,6 +1275,7 @@ class TokenizerManager:
if remain_num_req > 0:
await asyncio.sleep(5)
else:
self.dump_requests_before_crash()
break

kill_process_tree(os.getpid(), include_parent=True)
@@ -1273,16 +1353,7 @@ class TokenizerManager:
"meta_info": meta_info,
}
elif isinstance(recv_obj, BatchMultimodalOut):
if isinstance(recv_obj.outputs[i], str):
out_dict = {
"text": recv_obj.outputs[i],
"meta_info": meta_info,
}
else:
out_dict = {
"outputs": json.dumps(recv_obj.outputs[i]),
"meta_info": meta_info,
}
raise NotImplementedError("BatchMultimodalOut not implemented")
else:
assert isinstance(recv_obj, BatchEmbeddingOut)
out_dict = {
@@ -1306,6 +1377,8 @@ class TokenizerManager:
self.collect_metrics(state, recv_obj, i)
if self.dump_requests_folder and state.finished and state.obj.log_metrics:
self.dump_requests(state, out_dict)
if self.crash_dump_folder and state.finished and state.obj.log_metrics:
self.record_request_for_crash_dump(state, out_dict)

def convert_logprob_style(
self,
@@ -1317,6 +1390,9 @@ class TokenizerManager:
recv_obj: BatchStrOut,
recv_obj_index: int,
):
if recv_obj.input_token_logprobs_val is None:
return

if len(recv_obj.input_token_logprobs_val) > 0:
state.input_token_logprobs_val.extend(
recv_obj.input_token_logprobs_val[recv_obj_index]
@@ -1436,7 +1512,10 @@ class TokenizerManager:
else 0
)

if state.first_token_time == 0.0:
if (
state.first_token_time == 0.0
and self.disaggregation_mode != DisaggregationMode.PREFILL
):
state.first_token_time = state.last_time = time.time()
state.last_completion_tokens = completion_tokens
self.metrics_collector.observe_time_to_first_token(
@@ -1484,14 +1563,31 @@ class TokenizerManager:
to_dump = self.dump_request_list
self.dump_request_list = []

to_dump_with_server_args = {
"server_args": self.server_args,
"requests": to_dump,
}

def background_task():
os.makedirs(self.dump_requests_folder, exist_ok=True)
with open(filename, "wb") as f:
pickle.dump(to_dump, f)
pickle.dump(to_dump_with_server_args, f)

# Schedule the task to run in the background without awaiting it
asyncio.create_task(asyncio.to_thread(background_task))

def record_request_for_crash_dump(self, state: ReqState, out_dict: dict):
current_time = time.time()
self.crash_dump_request_list.append(
(state.obj, out_dict, state.created_time, current_time)
)
# Remove requests older than 5 minutes based on finish time
while (
self.crash_dump_request_list
and current_time - self.crash_dump_request_list[0][3] >= 300
):
self.crash_dump_request_list.popleft()

def _handle_abort_req(self, recv_obj):
self.rid_to_state.pop(recv_obj.rid, None)

@@ -1614,6 +1710,8 @@ async def print_exception_wrapper(func):
except Exception:
traceback = get_exception_traceback()
logger.error(f"TokenizerManager hit an exception: {traceback}")
if hasattr(func, "__self__") and isinstance(func.__self__, TokenizerManager):
func.__self__.dump_requests_before_crash()
kill_process_tree(os.getpid(), include_parent=True)
sys.exit(1)

@@ -1632,6 +1730,7 @@ class SignalHandler:
logger.error(
"Received sigquit from a child process. It usually means the child failed."
)
self.tokenizer_manager.dump_requests_before_crash()
kill_process_tree(os.getpid())
@@ -123,6 +123,7 @@ class KVCache(abc.ABC):
self.memory_saver_adapter = TorchMemorySaverAdapter.create(
enable=enable_memory_saver
)
self.mem_usage = 0

# used for chunked cpu-offloading
self.cpu_offloading_chunk_size = 8192
@@ -219,6 +220,7 @@ class MHATokenToKVPool(KVCache):
logger.info(
f"KV Cache is allocated. #tokens: {size}, K size: {k_size / GB:.2f} GB, V size: {v_size / GB:.2f} GB"
)
self.mem_usage = (k_size + v_size) / GB

def _create_buffers(self):
with self.memory_saver_adapter.region(GPU_MEMORY_TYPE_KV_CACHE):
@@ -695,6 +697,7 @@ class MLATokenToKVPool(KVCache):
logger.info(
f"KV Cache is allocated. #tokens: {size}, KV size: {kv_size / GB:.2f} GB"
)
self.mem_usage = kv_size / GB

def get_kv_size_bytes(self):
assert hasattr(self, "kv_buffer")
@@ -604,12 +604,13 @@ class ModelRunner:
self.dtype = self.model_config.dtype

after_avail_memory = get_available_gpu_memory(self.device, self.gpu_id)
self.weight_load_mem_usage = before_avail_memory - after_avail_memory
logger.info(
f"Load weight end. "
f"type={type(self.model).__name__}, "
f"dtype={self.dtype}, "
f"avail mem={after_avail_memory:.2f} GB, "
f"mem usage={(before_avail_memory - after_avail_memory):.2f} GB."
f"mem usage={self.weight_load_mem_usage:.2f} GB."
)

# Handle the case where some ranks do not finish loading.
@@ -1250,6 +1251,7 @@ class ModelRunner:
def init_cuda_graphs(self):
"""Capture cuda graphs."""
self.cuda_graph_runner = None
self.cuda_graph_mem_usage = 0

if not self.is_generation:
# TODO: Currently, cuda graph only captures decode steps, which only exists for generation models
@@ -1265,9 +1267,10 @@ class ModelRunner:
)
self.cuda_graph_runner = CudaGraphRunner(self)
after_mem = get_available_gpu_memory(self.device, self.gpu_id)
self.cuda_graph_mem_usage = before_mem - after_mem
logger.info(
f"Capture cuda graph end. Time elapsed: {time.perf_counter() - tic:.2f} s. "
f"mem usage={(before_mem - after_mem):.2f} GB. avail mem={after_mem:.2f} GB."
f"mem usage={self.cuda_graph_mem_usage:.2f} GB. avail mem={after_mem:.2f} GB."
)

def apply_torch_tp(self):
@@ -534,6 +534,12 @@ class DummyModelLoader(BaseModelLoader):
model_config: ModelConfig,
device_config: DeviceConfig,
) -> nn.Module:

if get_bool_env_var("SGL_CPU_QUANTIZATION"):
return load_model_with_cpu_quantization(
self, model_config=model_config, device_config=device_config
)

with set_default_torch_dtype(model_config.dtype):
with torch.device(device_config.device):
model = _initialize_model(
@@ -1464,6 +1470,38 @@ class RemoteModelLoader(BaseModelLoader):
return model.eval()


def load_model_with_cpu_quantization(
self,
*,
model_config: ModelConfig,
device_config: DeviceConfig,
) -> nn.Module:
target_device = torch.device(device_config.device)
with set_default_torch_dtype(model_config.dtype):
model = _initialize_model(
model_config,
self.load_config,
)

if not isinstance(self, DummyModelLoader):
model.load_weights(self._get_all_weights(model_config, model))

for _, module in model.named_modules():
quant_method = getattr(module, "quant_method", None)
if quant_method is not None:
# When quant methods need to process weights after loading
# (for repacking, quantizing, etc), they expect parameters
# to be on the global target device. This scope is for the
# case where cpu offloading is used, where we will move the
# parameters onto device for processing and back off after.
with device_loading_context(module, target_device):
quant_method.process_weights_after_loading(module)

model.to(target_device)

return model.eval()


def get_model_loader(load_config: LoadConfig) -> BaseModelLoader:
"""Get a model loader based on the load format."""
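The new loader path is gated by an environment variable rather than a flag. A sketch of enabling it (set in the environment before the server or engine process starts, so child processes inherit it):

import os

# Initialize and quantize weights on CPU first, then move the model to the target device.
os.environ["SGL_CPU_QUANTIZATION"] = "true"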
@@ -13,7 +13,7 @@
# ==============================================================================
"""Inference-only Mistral model."""

from typing import List, Union
from typing import List

import torch
from transformers.models.mistral3.modeling_mistral3 import Mistral3MultiModalProjector
@@ -99,6 +99,7 @@ class ServerArgs:
log_level_http: Optional[str] = None
log_requests: bool = False
log_requests_level: int = 0
crash_dump_folder: Optional[str] = None
show_time_cost: bool = False
enable_metrics: bool = False
bucket_time_to_first_token: Optional[List[float]] = None
@@ -927,8 +928,14 @@ class ServerArgs:
"--log-requests-level",
type=int,
default=0,
help="0: Log metadata. 1. Log metadata and partial input/output. 2. Log every input/output.",
choices=[0, 1, 2],
help="0: Log metadata (no sampling parameters). 1: Log metadata and sampling parameters. 2: Log metadata, sampling parameters and partial input/output. 3: Log every input/output.",
choices=[0, 1, 2, 3],
)
parser.add_argument(
"--crash-dump-folder",
type=str,
default=ServerArgs.crash_dump_folder,
help="Folder path to dump requests from the last 5 min before a crash (if any). If not specified, crash dumping is disabled.",
)
parser.add_argument(
"--show-time-cost",
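Putting the new logging and crash-dump flags together, a launch sketch (the model path is a placeholder; the flag names come from the argparse definitions above):

import subprocess

subprocess.run([
    "python", "-m", "sglang.launch_server",
    "--model-path", "meta-llama/Llama-3.1-8B-Instruct",  # placeholder model
    "--log-requests",
    "--log-requests-level", "2",  # metadata, sampling params, and partial input/output
    "--crash-dump-folder", "/tmp/sglang_crash_dumps",
])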
@@ -4,6 +4,7 @@ from typing import List
import numpy as np
import tqdm

from sglang.srt.disaggregation.utils import FAKE_BOOTSTRAP_HOST
from sglang.srt.managers.io_struct import GenerateReqInput
from sglang.srt.managers.tokenizer_manager import TokenizerManager

@@ -20,17 +21,21 @@ def warmup(name: str) -> callable:
return decorator


async def execute_warmups(warmup_names: List[str], tokenizer_manager: TokenizerManager):
async def execute_warmups(
disaggregation_mode: str,
warmup_names: List[str],
tokenizer_manager: TokenizerManager,
):
for warmup_name in warmup_names:
if warmup_name not in _warmup_registry:
logger.warning(f"Could not find custom warmup {warmup_name}")
continue
logger.info(f"Running warmup {warmup_name}")
await _warmup_registry[warmup_name](tokenizer_manager)
await _warmup_registry[warmup_name](disaggregation_mode, tokenizer_manager)


@warmup("voice_chat")
async def voice_chat(tokenizer_manager: TokenizerManager):
async def voice_chat(disaggregation_mode: str, tokenizer_manager: TokenizerManager):
# this warms up the fused_moe triton kernels and caches them
# if we don't do this we break real time inference for voice chat
for i in tqdm.trange(1, 512):
@@ -44,4 +49,8 @@ async def voice_chat(tokenizer_manager: TokenizerManager):
"min_p": 0.0,
},
)
if disaggregation_mode != "null":
generate_req_input.bootstrap_room = 0
generate_req_input.bootstrap_host = FAKE_BOOTSTRAP_HOST

await tokenizer_manager.generate_request(generate_req_input, None).__anext__()