Support EAGLE draft extend CUDA graph (#6606)
Co-authored-by: Sehoon Kim <sehoonkim@berkeley.edu>
@@ -1268,6 +1268,29 @@ class FlashAttentionBackend(AttentionBackend):
                 ),
             }
 
+            self.draft_extend_metadata = {
+                "cache_seqlens": torch.zeros(
+                    max_bs, dtype=torch.int32, device=self.device
+                ),
+                "cu_seqlens_q": torch.zeros(
+                    max_bs + 1,
+                    dtype=torch.int32,
+                    device=self.device,
+                ),
+                "cu_seqlens_k": torch.zeros(
+                    max_bs + 1, dtype=torch.int32, device=self.device
+                ),
+                "page_table": torch.zeros(
+                    max_bs,
+                    (self.max_context_len + self.page_size - 1) // self.page_size,
+                    dtype=torch.int32,
+                    device=self.device,
+                ),
+                "strided_indices": torch.arange(
+                    0, self.max_context_len, self.page_size, device=self.device
+                ),
+            }
+
         if self.topk > 1:
             self.target_verify_metadata_topk_normal = {
                 "cache_seqlens": torch.zeros(
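CUDA graphs replay fixed memory addresses, so all attention metadata the draft-extend path will ever need is preallocated above at the maximum batch size and then updated in place. A minimal standalone sketch of that pattern (illustrative names and sizes, not the sglang code):

import torch

# Hypothetical sizes; the real backend uses max_bs, page_size, and
# max_context_len from the model runner and allocates on CUDA.
max_bs, page_size, max_context_len = 8, 16, 4096
device = "cpu"

buffers = {
    "cache_seqlens": torch.zeros(max_bs, dtype=torch.int32, device=device),
    "cu_seqlens_k": torch.zeros(max_bs + 1, dtype=torch.int32, device=device),
    "page_table": torch.zeros(
        max_bs,
        (max_context_len + page_size - 1) // page_size,
        dtype=torch.int32,
        device=device,
    ),
}

# At replay time, write into slices instead of allocating new tensors,
# so the addresses recorded in the graph stay valid.
bs = 3
seq_lens = torch.tensor([5, 9, 2], dtype=torch.int32)
buffers["cache_seqlens"][:bs].copy_(seq_lens)
buffers["cu_seqlens_k"][1 : bs + 1].copy_(
    torch.cumsum(seq_lens, dim=0, dtype=torch.int32)
)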
@@ -1508,6 +1531,32 @@ class FlashAttentionBackend(AttentionBackend):
 
             self.target_verify_metadata_topk_normal[bs] = metadata
             self.target_verify_metadata_topk_expand[bs] = metadata_expand
+        elif forward_mode.is_draft_extend():
+            metadata.cache_seqlens_int32 = self.draft_extend_metadata["cache_seqlens"][
+                :bs
+            ]
+            metadata.cache_seqlens_int32.copy_(seq_lens.to(torch.int32))
+
+            num_tokens_per_bs = num_tokens // bs
+            metadata.max_seq_len_q = num_tokens_per_bs
+            metadata.max_seq_len_k = seq_lens.max().item()
+
+            metadata.cu_seqlens_q = torch.arange(
+                0,
+                bs * num_tokens_per_bs + 1,
+                num_tokens_per_bs,
+                dtype=torch.int32,
+                device=device,
+            )
+
+            metadata.cu_seqlens_k = self.draft_extend_metadata["cu_seqlens_k"][
+                : (bs + 1)
+            ]
+            metadata.page_table = self.draft_extend_metadata["page_table"][
+                req_pool_indices, :
+            ]
+
+            self.draft_extend_metadata[bs] = metadata
 
         if encoder_lens is not None:
             encoder_bs = encoder_lens.numel()
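During capture, every request contributes exactly num_tokens_per_bs query tokens, which is why cu_seqlens_q can be built with a single torch.arange instead of a cumulative sum. A toy illustration of the values this produces (hypothetical shapes):

import torch

bs, num_tokens_per_bs = 4, 3
cu_seqlens_q = torch.arange(
    0, bs * num_tokens_per_bs + 1, num_tokens_per_bs, dtype=torch.int32
)
print(cu_seqlens_q)  # tensor([ 0,  3,  6,  9, 12], dtype=torch.int32)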
@@ -1732,6 +1781,29 @@ class FlashAttentionBackend(AttentionBackend):
             metadata_expand.max_seq_len_k = (
                 metadata_expand.cache_seqlens_int32.max().item()
             )
+        elif forward_mode.is_draft_extend():
+            metadata = self.draft_extend_metadata[bs]
+            metadata.cache_seqlens_int32.copy_(seq_lens.to(torch.int32))
+
+            metadata.max_seq_len_k = seq_lens_cpu.max().item()
+            metadata.cu_seqlens_k[1:].copy_(
+                torch.cumsum(metadata.cache_seqlens_int32, dim=0, dtype=torch.int32)
+            )
+            accept_length = spec_info.accept_length[:bs]
+            metadata.max_seq_len_q = accept_length.max().item()
+            metadata.cu_seqlens_q[1:].copy_(
+                torch.cumsum(accept_length, dim=0, dtype=torch.int32)
+            )
+
+            max_seq_pages = (
+                metadata.max_seq_len_k + self.page_size - 1
+            ) // self.page_size
+            page_indices = self.req_to_token[
+                req_pool_indices[:, None],
+                self.draft_extend_metadata["strided_indices"][:max_seq_pages],
+            ]
+            page_indices //= self.page_size
+            metadata.page_table[:, :max_seq_pages].copy_(page_indices)
 
         if encoder_lens is not None:
             # Only support encoder size 1 for now
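The page-table update relies on strided_indices selecting the first KV slot of each page in req_to_token; integer-dividing that slot by page_size yields the page id. A self-contained sketch with a toy req_to_token layout (contiguous slots per request, purely for illustration):

import torch

page_size, max_context_len = 4, 32
# Toy mapping: request i owns slots [i * 32, (i + 1) * 32).
req_to_token = torch.arange(2 * max_context_len).reshape(2, max_context_len)
strided_indices = torch.arange(0, max_context_len, page_size)

max_seq_len_k = 10
max_seq_pages = (max_seq_len_k + page_size - 1) // page_size  # ceil(10 / 4) = 3
page_indices = req_to_token[
    torch.tensor([0, 1])[:, None], strided_indices[:max_seq_pages]
]
page_indices = page_indices // page_size
print(page_indices)  # [[0, 1, 2], [8, 9, 10]]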
@@ -262,10 +262,14 @@ class ServerArgs:
                 self.mem_fraction_static = 0.88
             if gpu_mem is not None and gpu_mem > 96 * 1024:
                 mem_fraction = self.mem_fraction_static
+                # 15 GB + an additional 3 GB for the CUDA graph
+                reserve_mem = 1024 * 18
+                # Reserve more memory for the speculative decoding CUDA graphs
+                if self.speculative_algorithm is not None:
+                    reserve_mem = 1024 * 20
                 self.mem_fraction_static = min(
                     mem_fraction + 48 * 1024 * (1 - mem_fraction) / gpu_mem,
-                    (gpu_mem - 1024 * 18)
-                    / gpu_mem,  # 15 GB + additional 3GB for cuda graph
+                    (gpu_mem - reserve_mem) / gpu_mem,
                 )
 
         # Set chunked prefill size, which depends on the gpu memory capacity
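To make the effect of reserve_mem concrete, here is the arithmetic for a hypothetical 140 GiB GPU (values in MiB) with speculative decoding enabled; the larger reservation becomes the binding constraint:

gpu_mem = 140 * 1024          # hypothetical GPU, MiB
mem_fraction = 0.88
reserve_mem = 1024 * 20       # speculative algorithm set -> 20 GB reserved

bound_a = mem_fraction + 48 * 1024 * (1 - mem_fraction) / gpu_mem
bound_b = (gpu_mem - reserve_mem) / gpu_mem
print(round(bound_a, 3), round(bound_b, 3))  # 0.921 0.857
print(min(bound_a, bound_b))                 # the 20 GB reserve wins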
@@ -0,0 +1,249 @@
from __future__ import annotations

import bisect
from typing import TYPE_CHECKING, Callable

import torch

from sglang.srt.model_executor.cuda_graph_runner import (
    CudaGraphRunner,
    LogitsProcessorOutput,
    get_batch_sizes_to_capture,
    get_global_graph_memory_pool,
    set_global_graph_memory_pool,
    set_torch_compile_config,
)
from sglang.srt.model_executor.forward_batch_info import (
    CaptureHiddenMode,
    ForwardBatch,
    ForwardMode,
)
from sglang.srt.speculative.eagle_utils import EagleDraftInput

if TYPE_CHECKING:
    from sglang.srt.speculative.eagle_worker import EAGLEWorker


class EAGLEDraftExtendCudaGraphRunner:
    def __init__(self, eagle_worker: EAGLEWorker):
        # Parse args
        self.eagle_worker = eagle_worker
        self.model_runner = model_runner = eagle_worker.model_runner
        self.graphs = {}
        self.output_buffers = {}
        self.enable_torch_compile = model_runner.server_args.enable_torch_compile
        self.disable_padding = model_runner.server_args.disable_cuda_graph_padding
        self.tp_size = self.model_runner.tp_size
        self.dp_size = model_runner.server_args.dp_size
        self.speculative_num_steps = model_runner.server_args.speculative_num_steps
        self.capture_bs, self.compile_bs = get_batch_sizes_to_capture(model_runner)
        self.padded_static_len = -1

        # Attention backend
        self.num_tokens_per_bs = self.speculative_num_steps + 1
        self.max_bs = max(self.capture_bs)
        self.max_num_token = self.max_bs * self.num_tokens_per_bs

        self.eagle_worker.draft_extend_attn_backend.init_cuda_graph_state(
            self.max_num_token
        )
        self.seq_len_fill_value = (
            self.eagle_worker.draft_extend_attn_backend.get_cuda_graph_seq_len_fill_value()
        )
        self.seq_lens_cpu = torch.full(
            (self.max_bs,), self.seq_len_fill_value, dtype=torch.int32
        )

        if self.enable_torch_compile:
            set_torch_compile_config()

        # Graph inputs
        with torch.device("cuda"):
            self.input_ids = torch.zeros((self.max_num_token,), dtype=torch.int64)
            self.req_pool_indices = torch.zeros((self.max_bs,), dtype=torch.int32)
            self.out_cache_loc = torch.ones((self.max_num_token,), dtype=torch.int64)
            self.positions = torch.zeros((self.max_num_token,), dtype=torch.int64)

            if self.eagle_worker.speculative_algorithm.is_eagle3():
                self.hidden_states = torch.zeros(
                    (
                        self.max_num_token,
                        self.model_runner.model_config.hidden_size * 3,
                    ),
                    dtype=self.model_runner.dtype,
                )
            else:
                self.hidden_states = torch.zeros(
                    (self.max_num_token, self.model_runner.model_config.hidden_size),
                    dtype=self.model_runner.dtype,
                )

            self.seq_lens = torch.ones((self.max_bs,), dtype=torch.int32)
            self.extend_seq_lens = torch.ones((self.max_bs,), dtype=torch.int32)
            self.accept_length = torch.ones((self.max_bs,), dtype=torch.int32)

        # Capture
        try:
            self.capture()
        except RuntimeError as e:
            raise Exception(
                f"Capture CUDA graph failed: {e}\n"
                "Possible solutions:\n"
                "1. set --mem-fraction-static to a smaller value (e.g., 0.8 or 0.7)\n"
                "2. set --cuda-graph-max-bs to a smaller value (e.g., 16)\n"
                "3. disable torch compile by not using --enable-torch-compile\n"
                "4. disable CUDA graph by --disable-cuda-graph. (Not recommended. Huge performance loss)\n"
                "Open an issue on GitHub https://github.com/sgl-project/sglang/issues/new/choose \n"
            )
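Buffer sizing in __init__ follows from the draft-extend shape: after each verify round a request re-extends by at most speculative_num_steps + 1 tokens (its accepted draft tokens plus the bonus token), so every graph input is sized for max_bs requests times that constant. A small sketch of the arithmetic with hypothetical capture buckets:

speculative_num_steps = 5
capture_bs = [1, 2, 4, 8, 16]  # hypothetical buckets from get_batch_sizes_to_capture

num_tokens_per_bs = speculative_num_steps + 1  # 6 tokens per request
max_bs = max(capture_bs)                       # 16
max_num_token = max_bs * num_tokens_per_bs     # 96; sizes input_ids, positions, ...
print(num_tokens_per_bs, max_num_token)        # 6 96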
    def can_run(self, forward_batch: ForwardBatch):
        batch_size = forward_batch.seq_lens.numel()

        is_bs_supported = (
            batch_size in self.graphs
            if self.disable_padding
            else batch_size <= self.max_bs
        )

        return is_bs_supported

    def capture(self):
        CudaGraphRunner.capture(self)
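capture() reuses the generic capture loop by calling CudaGraphRunner.capture as a plain function with this runner passed as self; the class is duck-typed rather than a subclass, so the borrowed method only needs the attributes it touches (graphs, output_buffers, capture_bs, ...) to exist. A minimal sketch of the idiom, not the sglang classes:

class Generic:
    def run(self):
        # Only relies on attributes of `self`, whatever its class is.
        return [step(self) for step in self.steps]

class Standalone:
    def __init__(self):
        self.steps = [lambda s: type(s).__name__]

    def run(self):
        return Generic.run(self)  # unbound call, duck-typed self

print(Standalone().run())  # ['Standalone']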
    def capture_one_batch_size(self, bs: int, forward: Callable):
        graph = torch.cuda.CUDAGraph()
        stream = self.stream
        num_tokens = bs * self.num_tokens_per_bs

        # Graph inputs
        input_ids = self.input_ids[:num_tokens]
        req_pool_indices = self.req_pool_indices[:bs]
        seq_lens = self.seq_lens[:bs]
        extend_seq_lens = self.extend_seq_lens[:bs]
        accept_length = self.accept_length[:bs]
        out_cache_loc = self.out_cache_loc[:num_tokens]
        positions = self.positions[:num_tokens]
        hidden_states = self.hidden_states[:num_tokens]

        spec_info = EagleDraftInput(
            hidden_states=hidden_states,
            accept_length=accept_length,
        )
        spec_info.positions = None

        # Forward batch
        forward_batch = ForwardBatch(
            forward_mode=ForwardMode.DRAFT_EXTEND,
            batch_size=bs,
            input_ids=input_ids,
            req_pool_indices=req_pool_indices,
            seq_lens=seq_lens,
            req_to_token_pool=self.model_runner.req_to_token_pool,
            token_to_kv_pool=self.model_runner.token_to_kv_pool,
            out_cache_loc=out_cache_loc,
            seq_lens_sum=seq_lens.sum(),
            return_logprob=False,
            positions=positions,
            spec_algorithm=self.model_runner.spec_algorithm,
            spec_info=spec_info,
            capture_hidden_mode=CaptureHiddenMode.LAST,
            attn_backend=self.eagle_worker.draft_extend_attn_backend,
            extend_seq_lens=extend_seq_lens,
            padded_static_len=self.padded_static_len,
        )

        self.eagle_worker.draft_extend_attn_backend.init_forward_metadata_capture_cuda_graph(
            bs=bs,
            num_tokens=num_tokens,
            req_pool_indices=req_pool_indices,
            seq_lens=seq_lens,
            encoder_lens=None,
            forward_mode=ForwardMode.DRAFT_EXTEND,
            spec_info=spec_info,
        )

        # Run and capture
        def run_once():
            # Back up two fields that are modified in place in `draft_forward`.
            output_cache_loc_backup = forward_batch.out_cache_loc
            hidden_states_backup = forward_batch.spec_info.hidden_states

            ret = self.eagle_worker.draft_model_runner.model.forward(
                forward_batch.input_ids,
                forward_batch.positions,
                forward_batch,
            )

            forward_batch.out_cache_loc = output_cache_loc_backup
            forward_batch.spec_info.hidden_states = hidden_states_backup
            return ret

        for _ in range(2):
            torch.cuda.synchronize()
            self.model_runner.tp_group.barrier()

            run_once()

        with torch.cuda.graph(
            graph, pool=get_global_graph_memory_pool(), stream=stream
        ):
            out = run_once()

        set_global_graph_memory_pool(graph.pool())
        return graph, out
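capture_one_batch_size follows the standard PyTorch warm-up/capture/replay protocol: a couple of eager iterations first (synchronized across the TP group), then recording into a shared memory pool. A minimal standalone version of that protocol (sketch only; requires a CUDA device):

import torch

if torch.cuda.is_available():
    static_x = torch.zeros(8, device="cuda")  # fixed-address graph input

    def run_once():
        return static_x * 2 + 1

    for _ in range(2):           # warm-up outside the graph
        torch.cuda.synchronize()
        run_once()

    g = torch.cuda.CUDAGraph()
    with torch.cuda.graph(g):    # record kernels into the graph
        static_out = run_once()

    static_x.copy_(torch.arange(8.0, device="cuda"))
    g.replay()                   # re-run recorded kernels on updated inputs
    print(static_out)            # 1, 3, 5, ..., 15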

    def replay(self, forward_batch: ForwardBatch):
        assert forward_batch.out_cache_loc is not None
        # batch_size and num_seqs can differ when the batch contains finished
        # requests, which are not counted in num_seqs.
        raw_bs = forward_batch.batch_size
        num_tokens = forward_batch.input_ids.shape[0]
        assert raw_bs * self.num_tokens_per_bs == num_tokens

        index = bisect.bisect_left(self.capture_bs, raw_bs)
        bs = self.capture_bs[index]
        if bs != raw_bs:
            self.seq_lens.fill_(1)
            self.accept_length.fill_(1)
            self.out_cache_loc.zero_()

        # Common inputs
        self.input_ids[:num_tokens].copy_(forward_batch.input_ids)
        self.seq_lens[:raw_bs].copy_(forward_batch.seq_lens)
        self.extend_seq_lens[:raw_bs].copy_(forward_batch.extend_seq_lens)
        self.out_cache_loc[:num_tokens].copy_(forward_batch.out_cache_loc)
        self.positions[:num_tokens].copy_(forward_batch.positions)
        self.hidden_states[:num_tokens].copy_(forward_batch.spec_info.hidden_states)
        self.accept_length[:raw_bs].copy_(forward_batch.spec_info.accept_length)
        self.req_pool_indices[:raw_bs].copy_(forward_batch.req_pool_indices)

        if forward_batch.seq_lens_cpu is not None:
            if bs != raw_bs:
                self.seq_lens_cpu.fill_(1)
            self.seq_lens_cpu[:raw_bs].copy_(forward_batch.seq_lens_cpu)

        forward_batch.spec_info.positions = None
        if bs != raw_bs:
            forward_batch.spec_info.accept_length = self.accept_length[:bs]

        self.eagle_worker.draft_extend_attn_backend.init_forward_metadata_replay_cuda_graph(
            bs=bs,
            req_pool_indices=self.req_pool_indices,
            seq_lens=self.seq_lens,
            seq_lens_sum=forward_batch.seq_lens_sum + (bs - raw_bs),
            encoder_lens=None,
            forward_mode=ForwardMode.DRAFT_EXTEND,
            spec_info=forward_batch.spec_info,
            seq_lens_cpu=self.seq_lens_cpu,
        )

        # Replay
        self.graphs[bs].replay()
        out = self.output_buffers[bs]
        if bs != raw_bs:
            forward_batch.spec_info.accept_length = self.accept_length[:raw_bs]
            out = LogitsProcessorOutput(
                next_token_logits=out.next_token_logits[:raw_bs],
                hidden_states=out.hidden_states[:raw_bs],
            )
        return out
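Unless padding is disabled, replay rounds a raw batch size up to the nearest captured bucket with bisect_left and fills the dummy slots with length-1 sequences, trimming the outputs back to raw_bs afterwards. The bucket lookup in isolation (hypothetical bucket list):

import bisect

capture_bs = [1, 2, 4, 8, 16]
for raw_bs in (1, 3, 5, 16):
    bs = capture_bs[bisect.bisect_left(capture_bs, raw_bs)]
    print(raw_bs, "->", bs)  # 1 -> 1, 3 -> 4, 5 -> 8, 16 -> 16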
@@ -84,6 +84,7 @@ class EagleDraftInput:
         self,
         batch: ScheduleBatch,
         speculative_num_steps: int,
+        pad_input: bool = False,
     ):
         assert len(self.verified_id) == len(batch.out_cache_loc)
         accept_length_cpu = batch.spec_info.accept_length_cpu
@@ -111,6 +112,50 @@ class EagleDraftInput:
         batch.input_ids = self.verified_id
         self.verified_id = new_verified_id
 
+        if pad_input:
+            batch_size = sum(not req.finished() for req in batch.reqs)
+            # Total constant input length after padding
+            static_len = speculative_num_steps + 1
+            # Total size after padding
+            padded_input_size = batch_size * static_len
+
+            padded_len = padded_input_size - batch.input_ids.shape[0]
+            if padded_len > 0:
+                new_input_ids = torch.nn.functional.pad(
+                    batch.input_ids, (0, padded_len), value=0
+                )
+                position_padding = torch.arange(
+                    padded_len, device=self.positions.device
+                )
+                new_positions = torch.cat([self.positions, position_padding])
+
+                # Dummy hidden states are needed for the padded positions
+                hidden_states_dim = self.hidden_states.shape[-1]
+                new_hidden_states = torch.cat(
+                    [
+                        self.hidden_states,
+                        torch.zeros(
+                            (padded_len, hidden_states_dim),
+                            dtype=self.hidden_states.dtype,
+                            device=self.hidden_states.device,
+                        ),
+                    ],
+                    dim=0,
+                )
+
+                # Allocate KV cache locations for the padded tokens
+                padded_cache_loc = torch.zeros(
+                    padded_len,
+                    dtype=batch.out_cache_loc.dtype,
+                    device=batch.out_cache_loc.device,
+                )
+                new_out_cache_loc = torch.cat([batch.out_cache_loc, padded_cache_loc])
+
+                batch.input_ids = new_input_ids
+                self.hidden_states = new_hidden_states
+                self.positions = new_positions
+                batch.out_cache_loc = new_out_cache_loc
+
     def generate_attn_arg_prefill(
         self,
         req_pool_indices: torch.Tensor,
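The effect of pad_input is to make every draft-extend batch occupy a constant batch_size * static_len tokens so it can match a captured graph shape. A toy run of the same arithmetic (hypothetical token ids and accept lengths):

import torch

speculative_num_steps = 3
static_len = speculative_num_steps + 1               # 4 slots per request
input_ids = torch.tensor([11, 12, 21, 22, 23, 31])   # accept lengths 2, 3, 1
batch_size = 3

padded_input_size = batch_size * static_len          # 12
padded_len = padded_input_size - input_ids.shape[0]  # 6 dummy tokens
new_input_ids = torch.nn.functional.pad(input_ids, (0, padded_len), value=0)
print(new_input_ids)  # tensor([11, 12, 21, 22, 23, 31,  0,  0,  0,  0,  0,  0])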
@@ -26,6 +26,9 @@ from sglang.srt.server_args import ServerArgs
 from sglang.srt.speculative.eagle_draft_cuda_graph_runner import (
     EAGLEDraftCudaGraphRunner,
 )
+from sglang.srt.speculative.eagle_draft_extend_cuda_graph_runner import (
+    EAGLEDraftExtendCudaGraphRunner,
+)
 from sglang.srt.speculative.eagle_utils import (
     EagleDraftInput,
     EagleVerifyInput,
@@ -189,6 +192,7 @@ class EAGLEWorker(TpModelWorker):
             self.has_prefill_wrapper_verify = False
         elif self.server_args.attention_backend == "fa3":
             from sglang.srt.layers.attention.flashattention_backend import (
+                FlashAttentionBackend,
                 FlashAttentionMultiStepBackend,
             )
 
@@ -197,7 +201,10 @@ class EAGLEWorker(TpModelWorker):
                 self.topk,
                 self.speculative_num_steps,
             )
-            self.draft_extend_attn_backend = None
+            self.draft_extend_attn_backend = FlashAttentionBackend(
+                self.draft_model_runner,
+                skip_prefill=False,
+            )
+            self.padded_static_len = self.speculative_num_steps + 1
             self.has_prefill_wrapper_verify = False
         elif self.server_args.attention_backend == "flashmla":
@@ -242,7 +249,18 @@ class EAGLEWorker(TpModelWorker):
 
         # Capture extend
         if self.draft_extend_attn_backend:
-            raise NotImplementedError()
+            tic = time.perf_counter()
+            before_mem = get_available_gpu_memory(self.device, self.gpu_id)
+            logger.info(
+                f"Capture draft extend cuda graph begin. This can take up to several minutes. avail mem={before_mem:.2f} GB"
+            )
+            self.cuda_graph_runner_for_draft_extend = EAGLEDraftExtendCudaGraphRunner(
+                self
+            )
+            after_mem = get_available_gpu_memory(self.device, self.gpu_id)
+            logger.info(
+                f"Capture draft extend cuda graph end. Time elapsed: {time.perf_counter() - tic:.2f} s. avail mem={after_mem:.2f} GB. mem usage={(before_mem - after_mem):.2f} GB."
+            )
 
     @property
     def draft_model_runner(self):
@@ -656,6 +674,7 @@ class EAGLEWorker(TpModelWorker):
         batch.spec_info.prepare_extend_after_decode(
             batch,
             self.speculative_num_steps,
+            pad_input=self.cuda_graph_runner_for_draft_extend is not None,
         )
         batch.spec_info.capture_hidden_mode = CaptureHiddenMode.LAST
         batch.return_logprob = False
@@ -665,7 +684,19 @@ class EAGLEWorker(TpModelWorker):
         )
 
         # Run
-        logits_output, _ = self.draft_model_runner.forward(forward_batch)
+        can_cuda_graph = (
+            self.cuda_graph_runner_for_draft_extend
+            and self.cuda_graph_runner_for_draft_extend.can_run(forward_batch)
+        )
+        if can_cuda_graph:
+            logits_output = self.cuda_graph_runner_for_draft_extend.replay(
+                forward_batch
+            )
+        else:
+            self.draft_model_runner.attn_backend.init_forward_metadata(forward_batch)
+            logits_output = self.draft_model_runner.model.forward(
+                forward_batch.input_ids, forward_batch.positions, forward_batch
+            )
 
         self._detect_nan_if_needed(logits_output)
         self.capture_for_decode(logits_output, forward_batch.spec_info)
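Taken together, the worker-side dispatch reduces to: replay the captured graph when the batch fits a captured bucket, otherwise fall back to an eager forward that must first rebuild the attention metadata (the graph path handles this inside replay). A condensed sketch of that control flow (hypothetical free function, not the sglang API):

def run_draft_extend(runner, draft_model_runner, forward_batch):
    # Sketch only: mirrors the dispatch in the diff above.
    if runner is not None and runner.can_run(forward_batch):
        return runner.replay(forward_batch)
    draft_model_runner.attn_backend.init_forward_metadata(forward_batch)
    return draft_model_runner.model.forward(
        forward_batch.input_ids, forward_batch.positions, forward_batch
    )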