Enable cuda graph by default (#612)
@@ -30,7 +30,6 @@ import argparse
 import dataclasses
 import logging
 import multiprocessing
 import os
 import time
 
@@ -8,36 +8,40 @@ class GlobalConfig:
         # 2: output final text after every run
         self.verbosity = 0
 
         # Default backend of the language
         self.default_backend = None
 
-        # Output configs
-        self.skip_special_tokens_in_output = True
-        self.spaces_between_special_tokens_in_out = True
-
-        # Optimization configs
-        self.eager_fill_image = False
-        self.enable_precache_with_tracing = True
-        self.enable_parallel_encoding = True
-        self.enable_parallel_decoding = True
-
-        # Choices: ["no_adjust", "adjust_cache"]
-        # no_adjust: Do not adjust the position embedding of KV cache.
-        # adjust_cache: Adjust the position embedding of KV cache.
-        self.concate_and_append_mode = "no_adjust"
-
-        # Request dependency time due to network delay
+        # Runtime constants: Request dependency time due to network delay
         self.request_dependency_delay = 0.02
         self.wait_for_new_request_delay = 0.0006
 
-        # New generation token ratio estimation
+        # Runtime constants: New generation token ratio estimation
         self.base_new_token_ratio = 0.4
         self.base_min_new_token_ratio = 0.2
         self.new_token_ratio_decay = 0.0001
         self.new_token_ratio_recovery = 0.05
 
-        # The threshold (number of tokens) to trigger layer-wise cuda sync.
+        # Runtime constants: The threshold (number of tokens) to trigger layer-wise cuda sync.
         # This can improve the speed for large batch sizes during prefill.
         self.layer_sync_threshold = 8192
 
+        # Runtime constants: Flashinfer
+        self.flashinfer_workspace_size = 192 * 1024 * 1024
+
+        # Output tokenization configs
+        self.skip_special_tokens_in_output = True
+        self.spaces_between_special_tokens_in_out = True
+
+        # Interpreter optimization configs
+        self.eager_fill_image = False
+        self.enable_precache_with_tracing = True
+        self.enable_parallel_encoding = True
+        self.enable_parallel_decoding = True
+
+        # Deprecated
+        # Choices: ["no_adjust", "adjust_cache"]
+        # no_adjust: Do not adjust the position embedding of KV cache.
+        # adjust_cache: Adjust the position embedding of KV cache.
+        self.concate_and_append_mode = "no_adjust"
 
 
 global_config = GlobalConfig()
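Note: the new flashinfer_workspace_size constant is consumed by ModelRunner.init_flash_infer further down, which replaces three 96 MB workspaces with two 192 MB ones (the decode wrapper now shares buffer 0 with the ragged prefill wrapper, as does CudaGraphRunner). A minimal sketch of that allocation, assuming a CUDA device is available:

    import torch
    from sglang.global_config import global_config

    # Two shared 192 MiB workspaces (192 * 1024 * 1024 bytes each).
    flashinfer_workspace_buffers = torch.empty(
        2, global_config.flashinfer_workspace_size, dtype=torch.uint8, device="cuda"
    )
    assert flashinfer_workspace_buffers[0].numel() == 192 * 1024 * 1024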
python/sglang/srt/managers/controller/cuda_graph_runner.py (new file, 173 lines)
@@ -0,0 +1,173 @@
+"""Run the model with cuda graph."""
+
+import bisect
+
+import torch
+from vllm.distributed.parallel_state import graph_capture
+
+from sglang.global_config import global_config
+from sglang.srt.layers.logits_processor import LogitProcessorOutput
+from sglang.srt.managers.controller.infer_batch import (
+    Batch, ForwardMode, InputMetadata, init_flashinfer_args
+)
+
+
+class CudaGraphRunner:
+    def __init__(self, model_runner, max_batch_size_to_capture):
+        self.model_runner = model_runner
+        self.graphs = {}
+        self.input_buffers = {}
+        self.output_buffers = {}
+        self.flashinfer_handlers = {}
+        self.graph_memory_pool = None
+
+        # Common inputs
+        self.max_bs = max_batch_size_to_capture
+        self.input_ids = torch.zeros((self.max_bs,), dtype=torch.int32, device="cuda")
+        self.req_pool_indices = torch.zeros((self.max_bs,), dtype=torch.int32, device="cuda")
+        self.seq_lens = torch.ones((self.max_bs,), dtype=torch.int32, device="cuda")
+        self.position_ids_offsets = torch.zeros((self.max_bs,), dtype=torch.int32, device="cuda")
+        self.out_cache_loc = torch.zeros((self.max_bs,), dtype=torch.int32, device="cuda")
+
+        # FlashInfer inputs
+        self.flashinfer_workspace_buffer = self.model_runner.flashinfer_workspace_buffers[0]
+        self.flashinfer_kv_indptr = torch.zeros(
+            (self.max_bs + 1,), dtype=torch.int32, device="cuda"
+        )
+        self.flashinfer_kv_indices = torch.zeros(
+            (self.max_bs * model_runner.model_config.context_len,), dtype=torch.int32, device="cuda"
+        )
+        self.flashinfer_kv_last_page_len = torch.ones(
+            (self.max_bs,), dtype=torch.int32, device="cuda"
+        )
+
+    def can_run(self, batch_size):
+        return batch_size < self.max_bs
+
+    def capture(self, batch_size_list):
+        self.batch_size_list = batch_size_list
+        with graph_capture() as graph_capture_context:
+            self.stream = graph_capture_context.stream
+            for bs in batch_size_list:
+                graph, input_buffers, output_buffers, flashinfer_handler = self.capture_one_batch_size(bs)
+                self.graphs[bs] = graph
+                self.input_buffers[bs] = input_buffers
+                self.output_buffers[bs] = output_buffers
+                self.flashinfer_handlers[bs] = flashinfer_handler
+
+    def capture_one_batch_size(self, bs):
+        from flashinfer import BatchDecodeWithPagedKVCacheWrapper
+        from flashinfer.decode import _grouped_size_compiled_for_decode_kernels
+
+        graph = torch.cuda.CUDAGraph()
+        stream = self.stream
+
+        # Common inputs
+        input_ids = self.input_ids[:bs]
+        req_pool_indices = self.req_pool_indices[:bs]
+        seq_lens = self.seq_lens[:bs]
+        position_ids_offsets = self.position_ids_offsets[:bs]
+        out_cache_loc = self.out_cache_loc[:bs]
+
+        # FlashInfer inputs
+        if not _grouped_size_compiled_for_decode_kernels(
+            self.model_runner.model_config.num_attention_heads // self.model_runner.tp_size,
+            self.model_runner.model_config.get_num_kv_heads(self.model_runner.tp_size),
+        ):
+            use_tensor_cores = True
+        else:
+            use_tensor_cores = False
+        flashinfer_decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
+            self.flashinfer_workspace_buffer, "NHD",
+            use_cuda_graph=True,
+            use_tensor_cores=use_tensor_cores,
+            paged_kv_indptr_buffer=self.flashinfer_kv_indptr[:bs+1],
+            paged_kv_indices_buffer=self.flashinfer_kv_indices,
+            paged_kv_last_page_len_buffer=self.flashinfer_kv_last_page_len[:bs],
+        )
+        init_flashinfer_args(
+            ForwardMode.DECODE,
+            self.model_runner,
+            req_pool_indices,
+            seq_lens,
+            None,
+            flashinfer_decode_wrapper,
+        )
+
+        # Run and capture
+        def run_once():
+            input_metadata = InputMetadata.create(
+                self.model_runner,
+                forward_mode=ForwardMode.DECODE,
+                req_pool_indices=req_pool_indices,
+                seq_lens=seq_lens,
+                prefix_lens=None,
+                position_ids_offsets=position_ids_offsets,
+                out_cache_loc=out_cache_loc,
+                out_cache_cont_start=None,
+                out_cache_cont_end=None,
+                return_logprob=False,
+                top_logprobs_nums=0,
+                skip_flashinfer_init=True,
+            )
+            input_metadata.flashinfer_decode_wrapper = flashinfer_decode_wrapper
+            return self.model_runner.model.forward(
+                input_ids, input_metadata.positions, input_metadata
+            )
+
+        for _ in range(2):
+            run_once()
+
+        torch.cuda.synchronize()
+        with torch.cuda.graph(graph, pool=self.graph_memory_pool, stream=stream):
+            out = run_once()
+        torch.cuda.synchronize()
+        self.graph_memory_pool = graph.pool()
+        return graph, None, out, flashinfer_decode_wrapper
+
+    def replay(self, batch: Batch):
+        assert batch.out_cache_loc is not None
+        assert not batch.return_logprob
+        raw_bs = len(batch.reqs)
+
+        # Pad
+        index = bisect.bisect_left(self.batch_size_list, raw_bs)
+        bs = self.batch_size_list[index]
+        if bs != raw_bs:
+            self.seq_lens.fill_(1)
+            self.out_cache_loc.zero_()
+
+        # Common inputs
+        self.input_ids[:raw_bs] = batch.input_ids
+        self.req_pool_indices[:raw_bs] = batch.req_pool_indices
+        self.seq_lens[:raw_bs] = batch.seq_lens
+        self.position_ids_offsets[:raw_bs] = batch.position_ids_offsets
+        self.out_cache_loc[:raw_bs] = batch.out_cache_loc
+
+        # FlashInfer inputs
+        init_flashinfer_args(
+            ForwardMode.DECODE,
+            self.model_runner,
+            self.req_pool_indices[:bs],
+            self.seq_lens[:bs],
+            None,
+            self.flashinfer_handlers[bs],
+        )
+
+        # Replay
+        self.graphs[bs].replay()
+        output = self.output_buffers[bs]
+
+        # Unpad
+        if bs == raw_bs:
+            return output
+        else:
+            output = LogitProcessorOutput(
+                next_token_logits=output.next_token_logits[:raw_bs],
+                next_token_logprobs=output.next_token_logprobs[:raw_bs] if output.next_token_logprobs is not None else None,
+                normalized_prompt_logprobs=None,
+                prefill_token_logprobs=None,
+                prefill_top_logprobs=None,
+                decode_top_logprobs=output.decode_top_logprobs[:raw_bs] if output.decode_top_logprobs is not None else None,
+            )
+            return output
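Note: CudaGraphRunner follows the standard torch.cuda.CUDAGraph pattern: warm up eagerly, capture one decode step whose inputs and outputs live in fixed buffers, then serve later batches by copying into those buffers and replaying. A minimal self-contained sketch of the idiom (illustrative names, not sglang APIs):

    import bisect
    import torch

    layer = torch.nn.Linear(16, 16).cuda()
    static_in = torch.zeros(8, 16, device="cuda")  # fixed input buffer

    # Warm up outside the graph so one-time setup work is not captured.
    for _ in range(2):
        layer(static_in)
    torch.cuda.synchronize()

    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph):
        static_out = layer(static_in)  # capture one fixed-shape step

    # Replay: refill the same buffer, then launch all captured kernels at once.
    static_in.copy_(torch.randn(8, 16, device="cuda"))
    graph.replay()
    result = static_out.clone()

    # replay() above pads a batch to the nearest captured size the same way:
    batch_size_list = [1, 2, 4] + [i * 8 for i in range(1, 16)]
    bs = batch_size_list[bisect.bisect_left(batch_size_list, 5)]
    assert bs == 8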
@@ -675,7 +675,11 @@ class Batch:
         # TODO(lmzheng): apply penalty
         probs = torch.softmax(logits, dim=-1)
         probs_sort, probs_idx = _top_p_top_k(probs, self.top_ps, self.top_ks)
-        sampled_index = torch.multinomial(probs_sort, num_samples=1)
+        try:
+            sampled_index = torch.multinomial(probs_sort, num_samples=1)
+        except RuntimeError as e:
+            warnings.warn(f"Ignore errors in sampling: {e}")
+            sampled_index = torch.ones(probs_sort.shape[:-1] + (1,), dtype=torch.int64, device=probs.device)
         batch_next_token_ids = torch.gather(probs_idx, dim=1, index=sampled_index).view(
             -1
         )
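Note: torch.multinomial raises a RuntimeError when a probability row contains inf/nan or negative entries, which can happen after extreme logits. The patch catches that and substitutes index 1 (via torch.ones), i.e. the second entry of the top-p/top-k-sorted probs. A small repro of the fallback path on a deliberately invalid distribution (assumed illustration, not from the commit):

    import warnings
    import torch

    probs_sort = torch.tensor([[float("nan"), 0.5, 0.5]])  # invalid on purpose
    try:
        sampled_index = torch.multinomial(probs_sort, num_samples=1)
    except RuntimeError as e:
        warnings.warn(f"Ignore errors in sampling: {e}")
        sampled_index = torch.ones(probs_sort.shape[:-1] + (1,), dtype=torch.int64)
    # sampled_index == tensor([[1]]): falls back to the second-ranked candidate.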
@@ -757,9 +761,11 @@ class InputMetadata:
         out_cache_cont_end=None,
         top_logprobs_nums=None,
         return_logprob=False,
+        skip_flashinfer_init=False,
     ):
-        if not model_runner.server_args.disable_flashinfer:
-            init_flashinfer_args(forward_mode, model_runner, req_pool_indices, seq_lens, prefix_lens)
+        if not skip_flashinfer_init and not model_runner.server_args.disable_flashinfer:
+            init_flashinfer_args(forward_mode, model_runner, req_pool_indices, seq_lens, prefix_lens,
+                                 model_runner.flashinfer_decode_wrapper)
 
         batch_size = len(req_pool_indices)
 
@@ -826,7 +832,8 @@ class InputMetadata:
         return ret
 
 
-def init_flashinfer_args(forward_mode, model_runner, req_pool_indices, seq_lens, prefix_lens):
+def init_flashinfer_args(forward_mode, model_runner, req_pool_indices, seq_lens, prefix_lens,
+                         flashinfer_decode_wrapper):
     num_qo_heads = model_runner.model_config.num_attention_heads // model_runner.tp_size
     num_kv_heads = model_runner.model_config.get_num_kv_heads(model_runner.tp_size)
     head_dim = model_runner.model_config.head_dim
@@ -857,8 +864,8 @@ def init_flashinfer_args(forward_mode, model_runner, req_pool_indices, seq_lens,
         )
 
     if forward_mode == ForwardMode.DECODE:
-        model_runner.flashinfer_decode_wrapper.end_forward()
-        model_runner.flashinfer_decode_wrapper.begin_forward(
+        flashinfer_decode_wrapper.end_forward()
+        flashinfer_decode_wrapper.begin_forward(
             kv_indptr,
             kv_indices,
             kv_last_page_len,
@@ -15,6 +15,7 @@ from vllm.distributed import init_distributed_environment, initialize_model_para
 from vllm.model_executor.model_loader import get_model
 from vllm.model_executor.models import ModelRegistry
 
+from sglang.global_config import global_config
 from sglang.srt.managers.controller.infer_batch import Batch, ForwardMode, InputMetadata, global_server_args_dict
 from sglang.srt.memory_pool import ReqToTokenPool, TokenToKVPool
 from sglang.srt.server_args import ServerArgs
@@ -90,6 +91,9 @@ class ModelRunner:
         self.init_cublas()
         self.init_flash_infer()
 
+        # Capture cuda graphs
+        self.init_cuda_graphs()
+
     def load_model(self):
         logger.info(
             f"[gpu_id={self.gpu_id}] Load weight begin. "
@@ -203,17 +207,51 @@ class ModelRunner:
         else:
             use_tensor_cores = False
 
-        workspace_buffers = torch.empty(
-            3, 96 * 1024 * 1024, dtype=torch.uint8, device="cuda"
+        self.flashinfer_workspace_buffers = torch.empty(
+            2, global_config.flashinfer_workspace_size, dtype=torch.uint8, device="cuda"
         )
         self.flashinfer_prefill_wrapper_ragged = BatchPrefillWithRaggedKVCacheWrapper(
-            workspace_buffers[0], "NHD"
+            self.flashinfer_workspace_buffers[0], "NHD"
         )
         self.flashinfer_prefill_wrapper_paged = BatchPrefillWithPagedKVCacheWrapper(
-            workspace_buffers[1], "NHD"
+            self.flashinfer_workspace_buffers[1], "NHD"
         )
         self.flashinfer_decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
-            workspace_buffers[2], "NHD", use_tensor_cores=use_tensor_cores
+            self.flashinfer_workspace_buffers[0], "NHD", use_tensor_cores=use_tensor_cores
         )
 
+    def init_cuda_graphs(self):
+        from sglang.srt.managers.controller.cuda_graph_runner import CudaGraphRunner
+
+        if self.server_args.disable_cuda_graph or self.server_args.disable_flashinfer:
+            self.cuda_graph_runner = None
+            return
+
+        logger.info(f"[gpu_id={self.gpu_id}] Capture cuda graph begin.")
+        batch_size_list = [1, 2, 4] + [i * 8 for i in range(1, 16)]
+        self.cuda_graph_runner = CudaGraphRunner(self, max_batch_size_to_capture=max(batch_size_list))
+        self.cuda_graph_runner.capture(batch_size_list)
+
+    @torch.inference_mode()
+    def forward_decode(self, batch: Batch):
+        if self.cuda_graph_runner and self.cuda_graph_runner.can_run(len(batch.reqs)):
+            return self.cuda_graph_runner.replay(batch)
+
+        input_metadata = InputMetadata.create(
+            self,
+            forward_mode=ForwardMode.DECODE,
+            req_pool_indices=batch.req_pool_indices,
+            seq_lens=batch.seq_lens,
+            prefix_lens=batch.prefix_lens,
+            position_ids_offsets=batch.position_ids_offsets,
+            out_cache_loc=batch.out_cache_loc,
+            out_cache_cont_start=batch.out_cache_cont_start,
+            out_cache_cont_end=batch.out_cache_cont_end,
+            top_logprobs_nums=batch.top_logprobs_nums,
+            return_logprob=batch.return_logprob,
+        )
+        return self.model.forward(
+            batch.input_ids, input_metadata.positions, input_metadata
+        )
+
     @torch.inference_mode()
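Note: with the defaults above, graphs are captured for 18 batch sizes (1, 2, 4, then multiples of 8 up to 120), and forward_decode falls back to the eager path whenever can_run fails; since can_run requires batch_size < max_bs, only batches strictly smaller than 120 are replayed from a graph. A quick check of the expression:

    batch_size_list = [1, 2, 4] + [i * 8 for i in range(1, 16)]
    assert len(batch_size_list) == 18 and max(batch_size_list) == 120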
@@ -233,25 +271,6 @@ class ModelRunner:
             batch.input_ids, input_metadata.positions, input_metadata
         )
 
-    @torch.inference_mode()
-    def forward_decode(self, batch: Batch):
-        input_metadata = InputMetadata.create(
-            self,
-            forward_mode=ForwardMode.DECODE,
-            req_pool_indices=batch.req_pool_indices,
-            seq_lens=batch.seq_lens,
-            prefix_lens=batch.prefix_lens,
-            position_ids_offsets=batch.position_ids_offsets,
-            out_cache_loc=batch.out_cache_loc,
-            out_cache_cont_start=batch.out_cache_cont_start,
-            out_cache_cont_end=batch.out_cache_cont_end,
-            top_logprobs_nums=batch.top_logprobs_nums,
-            return_logprob=batch.return_logprob,
-        )
-        return self.model.forward(
-            batch.input_ids, input_metadata.positions, input_metadata
-        )
-
     @torch.inference_mode()
     def forward_extend_multi_modal(self, batch: Batch):
         input_metadata = InputMetadata.create(
@@ -98,7 +98,7 @@ class ModelTpServer:
         )
         self.max_total_num_tokens = self.model_runner.max_total_num_tokens
         self.max_prefill_tokens = (
-            4096
+            8192
             if server_args.max_prefill_tokens is None
             else server_args.max_prefill_tokens
         )
@@ -314,11 +314,9 @@ class ModelTpServer:
             self.forward_queue.append(req)
 
     def get_new_fill_batch(self) -> Optional[Batch]:
-        if (
-            self.running_batch is not None
-            and len(self.running_batch.reqs) > self.max_running_requests
-        ):
-            return None
+        running_bs = len(self.running_batch.reqs) if self.running_batch is not None else 0
+        if running_bs > self.max_running_requests:
+            return
 
         # Compute matched prefix length
         for req in self.forward_queue:
@@ -394,6 +392,10 @@ class ModelTpServer:
                 new_batch_input_tokens += req.extend_input_len
             else:
                 break
+
+            if running_bs + len(can_run_list) > self.max_running_requests:
+                break
+
         if len(can_run_list) == 0:
             return None
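Note: the rewritten check computes running_bs once, and the second hunk also caps how many new prefill requests are admitted so that running plus admitted never exceeds max_running_requests. A standalone sketch of the combined logic, with names assumed from the hunks:

    def admit(forward_queue, running_bs, max_running_requests):
        if running_bs > max_running_requests:
            return None
        can_run_list = []
        for req in forward_queue:
            if running_bs + len(can_run_list) > max_running_requests:
                break
            can_run_list.append(req)
        return can_run_list or None

    assert admit(list(range(10)), running_bs=6, max_running_requests=8) == [0, 1, 2]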
@@ -38,7 +38,10 @@ class ReqToTokenPool:
 
 class TokenToKVPool:
     def __init__(self, size, dtype, head_num, head_dim, layer_num):
-        self.mem_state = torch.zeros((size,), dtype=torch.int16, device="cuda")
         self.size = size
+
+        # mem_state is the reference counter.
+        # We also add one slot. This slot is used for writing dummy output from padded tokens.
+        self.mem_state = torch.zeros((self.size + 1,), dtype=torch.int16, device="cuda")
         self.total_ref_ct = 0
 
         # [size, key/value, head_num, head_dim] for each layer
@@ -47,6 +50,8 @@ class TokenToKVPool:
             for _ in range(layer_num)
         ]
 
+        self.clear()
+
     def get_key_buffer(self, layer_id):
         return self.kv_data[layer_id][:, 0]
 
@@ -101,3 +106,6 @@ class TokenToKVPool:
     def clear(self):
         self.mem_state.fill_(0)
         self.total_ref_ct = 0
+
+        # We also add one slot. This slot is used for writing dummy output from padded tokens.
+        self.add_refs(torch.tensor([0], dtype=torch.int32))
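Note: mem_state acts as a per-slot reference counter, and the pool now reserves one extra slot (index 0, pinned by add_refs in clear) that padded decode tokens write their dummy KV output into; CudaGraphRunner.replay zeroes out_cache_loc for exactly this reason. A toy sketch of the trick, simplified and with hypothetical names:

    import torch

    class ToyKVPool:
        def __init__(self, size):
            # size + 1 slots; slot 0 is permanently referenced as a dummy
            # target for padded tokens, so it is never handed to a request.
            self.mem_state = torch.zeros((size + 1,), dtype=torch.int16)
            self.mem_state[0] += 1

        def alloc(self, need):
            free = torch.nonzero(self.mem_state == 0).flatten()[:need]
            self.mem_state[free] += 1
            return free

    pool = ToyKVPool(4)
    assert 0 not in pool.alloc(4).tolist()  # slot 0 stays reserved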
@@ -146,6 +146,7 @@ def launch_server(server_args: ServerArgs, pipe_finish_writer, model_overide_arg
 
     # Set global environments
     os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
+    os.environ["NCCL_CUMEM_ENABLE"] = "0"
     if server_args.show_time_cost:
         enable_show_time_cost()
     if server_args.disable_disk_cache:
@@ -29,7 +29,7 @@ class ServerArgs:
     max_prefill_tokens: Optional[int] = None
     max_running_requests: Optional[int] = None
     schedule_heuristic: str = "lpm"
-    schedule_conservativeness: float = 1.0
+    schedule_conservativeness: float = 0.8
 
     # Other runtime options
     tp_size: int = 1
@@ -68,13 +68,13 @@ class ServerArgs:
             self.tokenizer_path = self.model_path
         if self.mem_fraction_static is None:
             if self.tp_size >= 8:
-                self.mem_fraction_static = 0.80
+                self.mem_fraction_static = 0.78
             elif self.tp_size >= 4:
-                self.mem_fraction_static = 0.82
+                self.mem_fraction_static = 0.80
             elif self.tp_size >= 2:
                 self.mem_fraction_static = 0.85
             else:
-                self.mem_fraction_static = 0.90
+                self.mem_fraction_static = 0.88
         if isinstance(self.additional_ports, int):
             self.additional_ports = [self.additional_ports]
         elif self.additional_ports is None:
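Note: the default static memory fractions drop by 0.02 for the tp_size >= 8, tp_size >= 4, and single-GPU tiers (the tp_size >= 2 tier keeps 0.85), presumably to leave headroom for the memory that cuda graph capture reserves now that it is on by default; the rationale is an assumption, the values come from the hunk above:

    def default_mem_fraction_static(tp_size):
        if tp_size >= 8:
            return 0.78
        elif tp_size >= 4:
            return 0.80
        elif tp_size >= 2:
            return 0.85
        else:
            return 0.88

    assert default_mem_fraction_static(1) == 0.88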