refactor EAGLE 2 (#3269)

Co-authored-by: Ying Sheng <sqy1415@gmail.com>
Co-authored-by: merrymercy <lianminzheng@gmail.com>
Co-authored-by: Ying1123 <sqy1415@gmail.com>
This commit is contained in:
Yineng Zhang
2025-02-03 20:52:30 +08:00
committed by GitHub
parent 3c8ac78dc1
commit 013021b6a1
9 changed files with 1271 additions and 687 deletions

View File

@@ -103,69 +103,75 @@ def set_torch_compile_config():
torch._dynamo.config.cache_size_limit = 1024
def get_batch_sizes_to_capture(model_runner: "ModelRunner"):
    """Decide which batch sizes get CUDA graphs and which also get torch.compile.

    Args:
        model_runner: Provides ``server_args`` (capture/compile limits) and
            ``req_to_token_pool.size`` (the max number of running requests).

    Returns:
        A tuple ``(capture_bs, compile_bs)``:
        - ``capture_bs``: sorted batch sizes to capture as CUDA graphs, clamped
          to both the request-pool size and ``cuda_graph_max_bs``.
        - ``compile_bs``: the subset of ``capture_bs`` (up to
          ``torch_compile_max_bs``) to run under torch.compile; empty when
          torch.compile is disabled.
    """
    server_args = model_runner.server_args
    capture_bs = server_args.cuda_graph_bs
    if capture_bs is None:
        if server_args.disable_cuda_graph_padding:
            # Without padding, a request batch must match a captured size
            # exactly, so capture every small size densely.
            capture_bs = list(range(1, 33)) + [64, 128]
        else:
            capture_bs = [1, 2, 4] + [i * 8 for i in range(1, 21)]

    if max(capture_bs) > model_runner.req_to_token_pool.size:
        # In some cases (e.g., with a small GPU or --max-running-requests), the
        # max number of running requests is very small. Add the pool-size
        # values so we still capture the maximum reachable batch size.
        capture_bs = sorted(
            set(
                capture_bs
                + [model_runner.req_to_token_pool.size - 1]
                + [model_runner.req_to_token_pool.size]
            )
        )

    # Drop anything larger than the request pool or the configured cap.
    capture_bs = [
        bs
        for bs in capture_bs
        if bs <= model_runner.req_to_token_pool.size
        and bs <= server_args.cuda_graph_max_bs
    ]

    compile_bs = (
        [bs for bs in capture_bs if bs <= server_args.torch_compile_max_bs]
        if server_args.enable_torch_compile
        else []
    )
    return capture_bs, compile_bs
# One memory pool shared by every cuda graph runner in this process.
global_graph_memory_pool = None


def get_global_graph_memory_pool():
    """Return the process-wide cuda-graph memory pool (``None`` if unset)."""
    return global_graph_memory_pool


def set_global_graph_memory_pool(val):
    """Install *val* as the shared cuda-graph memory pool."""
    global global_graph_memory_pool
    global_graph_memory_pool = val
class CudaGraphRunner:
"""A CudaGraphRunner runs the forward pass of a model with cuda graph and torch.compile."""
def __init__(self, model_runner: "ModelRunner"):
def __init__(self, model_runner: ModelRunner):
# Parse args
self.model_runner = model_runner
self.graphs = {}
self.input_buffers = {}
self.output_buffers = {}
self.flashinfer_handlers = {}
self.graph_memory_pool = None
self.use_torch_compile = model_runner.server_args.enable_torch_compile
self.enable_torch_compile = model_runner.server_args.enable_torch_compile
self.disable_padding = model_runner.server_args.disable_cuda_graph_padding
self.is_encoder_decoder = self.model_runner.model_config.is_encoder_decoder
self.enable_dp_attention = self.model_runner.server_args.enable_dp_attention
self.tp_size = self.model_runner.tp_size
self.dp_size = self.model_runner.server_args.dp_size
self.is_encoder_decoder = model_runner.model_config.is_encoder_decoder
self.enable_dp_attention = model_runner.server_args.enable_dp_attention
self.tp_size = model_runner.server_args.tp_size
self.dp_size = model_runner.server_args.dp_size
# Batch sizes to capture
self.capture_bs = self.model_runner.server_args.cuda_graph_bs
if self.capture_bs is None:
if model_runner.server_args.disable_cuda_graph_padding:
self.capture_bs = list(range(1, 33)) + [64, 128]
else:
self.capture_bs = [1, 2, 4] + [i * 8 for i in range(1, 21)]
if max(self.capture_bs) > model_runner.req_to_token_pool.size:
# In some case (e.g., with a small GPU or --max-running-requests), the #max-running-requests
# is very small. We add more values here to make sure we capture the maximum bs.
self.capture_bs = list(
sorted(
set(
self.capture_bs
+ [model_runner.req_to_token_pool.size - 1]
+ [model_runner.req_to_token_pool.size]
)
)
)
self.capture_bs = [
bs
for bs in self.capture_bs
if bs <= model_runner.req_to_token_pool.size
and bs <= model_runner.server_args.cuda_graph_max_bs
]
self.compile_bs = (
[
bs
for bs in self.capture_bs
if bs <= self.model_runner.server_args.torch_compile_max_bs
]
if self.use_torch_compile
else []
)
self.capture_bs, self.compile_bs = get_batch_sizes_to_capture(model_runner)
self.capture_forward_mode = ForwardMode.DECODE
self.num_tokens_per_bs = 1
if model_runner.spec_algorithm.is_eagle():
if self.model_runner.is_draft_worker:
self.num_tokens_per_bs = (
self.model_runner.server_args.speculative_eagle_topk
)
raise RuntimeError("This should not happen")
else:
self.capture_forward_mode = ForwardMode.TARGET_VERIFY
self.num_tokens_per_bs = (
@@ -182,10 +188,10 @@ class CudaGraphRunner:
# FIXME(lsyin): leave it here for now, I don't know whether it is necessary
self.encoder_len_fill_value = 0
if self.use_torch_compile:
if self.enable_torch_compile:
set_torch_compile_config()
# Common inputs
# Graph inputs
with torch.device("cuda"):
self.input_ids = torch.zeros((self.max_num_token,), dtype=torch.int64)
self.req_pool_indices = torch.zeros((self.max_bs,), dtype=torch.int32)
@@ -301,7 +307,7 @@ class CudaGraphRunner:
stream = self.stream
num_tokens = bs * self.num_tokens_per_bs
# Common inputs
# Graph inputs
input_ids = self.input_ids[:num_tokens]
req_pool_indices = self.req_pool_indices[:bs]
seq_lens = self.seq_lens[:bs]
@@ -320,7 +326,7 @@ class CudaGraphRunner:
global_num_tokens = None
gathered_buffer = None
spec_info = self.get_spec_info(num_tokens, positions)
spec_info = self.get_spec_info(num_tokens)
forward_batch = ForwardBatch(
forward_mode=self.capture_forward_mode,
@@ -335,7 +341,6 @@ class CudaGraphRunner:
seq_lens_sum=seq_lens.sum(),
encoder_lens=encoder_lens,
return_logprob=False,
top_logprobs_nums=[0] * bs,
positions=positions,
global_num_tokens=global_num_tokens,
gathered_buffer=gathered_buffer,
@@ -375,13 +380,14 @@ class CudaGraphRunner:
torch.cuda.synchronize()
self.model_runner.tp_group.barrier()
with torch.cuda.graph(graph, pool=self.graph_memory_pool, stream=stream):
global global_graph_memory_pool
with torch.cuda.graph(graph, pool=global_graph_memory_pool, stream=stream):
out = run_once()
torch.cuda.synchronize()
self.model_runner.tp_group.barrier()
self.graph_memory_pool = graph.pool()
global_graph_memory_pool = graph.pool()
return graph, out
def replay(self, forward_batch: ForwardBatch):
@@ -439,35 +445,26 @@ class CudaGraphRunner:
)
return logits_output
def get_spec_info(self, num_tokens: int, positions: torch.Tensor):
def get_spec_info(self, num_tokens: int):
spec_info = None
if self.model_runner.spec_algorithm.is_eagle():
from sglang.srt.speculative.eagle_utils import (
EAGLEDraftInput,
EagleVerifyInput,
)
from sglang.srt.speculative.eagle_utils import EagleVerifyInput
if self.model_runner.is_draft_worker:
spec_info = EAGLEDraftInput()
spec_info.load_server_args(self.model_runner.server_args)
spec_info.hidden_states = self.hidden_states[:num_tokens]
spec_info.positions = positions
spec_info.capture_hidden_mode = CaptureHiddenMode.FULL
raise RuntimeError("This should not happen.")
else:
spec_info = EagleVerifyInput(
None,
None,
None,
None,
None,
None,
self.model_runner.server_args.speculative_num_draft_tokens,
draft_token=None,
custom_mask=torch.zeros(
(num_tokens * self.model_runner.model_config.context_len),
dtype=torch.bool,
device="cuda",
),
positions=None,
retrive_index=None,
retrive_cum_len=None,
draft_token_num=self.model_runner.server_args.speculative_num_draft_tokens,
capture_hidden_mode=CaptureHiddenMode.FULL,
)
spec_info.custom_mask = torch.zeros(
(num_tokens * self.model_runner.model_config.context_len),
dtype=torch.bool,
device="cuda",
)
spec_info.capture_hidden_mode = CaptureHiddenMode.FULL
return spec_info

View File

@@ -197,64 +197,6 @@ class ForwardBatch:
# For Qwen2-VL
mrope_positions: torch.Tensor = None
def compute_mrope_positions(
    self, model_runner: ModelRunner, batch: ModelWorkerBatch
):
    """Fill ``self.mrope_positions`` with multimodal rotary positions (Qwen2-VL).

    Decode mode: derives the single next position per sequence from each
    request's stored ``mrope_position_delta`` (0 when there is no image input).
    Extend mode: computes positions for every newly extended token; image
    requests delegate to ``MRotaryEmbedding.get_input_positions`` and write the
    resulting delta back onto ``batch.image_inputs[i]``.

    The per-sequence results are concatenated along axis 1 into one int64
    tensor on ``model_runner.device``.
    """
    device = model_runner.device
    hf_config = model_runner.model_config.hf_config
    # One entry per sequence in the batch.
    mrope_positions_list = [None] * self.seq_lens.shape[0]
    if self.forward_mode.is_decode():
        for i, _ in enumerate(mrope_positions_list):
            mrope_position_delta = (
                0
                if batch.image_inputs[i] is None
                else batch.image_inputs[i].mrope_position_delta
            )
            mrope_positions_list[i] = MRotaryEmbedding.get_next_input_positions(
                mrope_position_delta,
                int(self.seq_lens[i]) - 1,
                int(self.seq_lens[i]),
            )
    elif self.forward_mode.is_extend():
        extend_start_loc_cpu = self.extend_start_loc.cpu().numpy()
        for i, image_inputs in enumerate(batch.image_inputs):
            extend_start_loc, extend_seq_len, extend_prefix_len = (
                extend_start_loc_cpu[i],
                batch.extend_seq_lens[i],
                batch.extend_prefix_lens[i],
            )
            if image_inputs is None:
                # Text only: positions are a plain range over the extended
                # span, replicated 3x (one list per mrope section).
                mrope_positions = [
                    [
                        pos
                        for pos in range(
                            extend_prefix_len, extend_prefix_len + extend_seq_len
                        )
                    ]
                ] * 3
            else:
                # TODO: currently qwen2-vl does not support radix cache since
                # mrope position calculation needs the full token span
                # (context_len=0 below).
                mrope_positions, mrope_position_delta = (
                    MRotaryEmbedding.get_input_positions(
                        input_tokens=self.input_ids[
                            extend_start_loc : extend_start_loc + extend_seq_len
                        ],
                        image_grid_thw=image_inputs.image_grid_thws,
                        vision_start_token_id=hf_config.vision_start_token_id,
                        spatial_merge_size=hf_config.vision_config.spatial_merge_size,
                        context_len=0,
                    )
                )
                # Persist the delta so later decode steps can reuse it.
                batch.image_inputs[i].mrope_position_delta = mrope_position_delta
            mrope_positions_list[i] = mrope_positions
    self.mrope_positions = torch.concat(
        [torch.tensor(pos, device=device) for pos in mrope_positions_list],
        axis=1,
    )
    self.mrope_positions = self.mrope_positions.to(torch.int64)
@classmethod
def init_new(
cls,
@@ -337,7 +279,7 @@ class ForwardBatch:
ret.extend_logprob_start_lens_cpu = batch.extend_logprob_start_lens
if model_runner.model_is_mrope:
ret.compute_mrope_positions(model_runner, batch)
ret._compute_mrope_positions(model_runner, batch)
# Init lora information
if model_runner.server_args.lora_paths is not None:
@@ -345,6 +287,63 @@ class ForwardBatch:
return ret
def _compute_mrope_positions(
    self, model_runner: ModelRunner, batch: ModelWorkerBatch
):
    """Fill ``self.mrope_positions`` with multimodal rotary positions (Qwen2-VL).

    Decode mode: derives the single next position per sequence from each
    request's stored ``mrope_position_delta`` (0 when there is no image input).
    Extend mode: computes positions for every newly extended token; image
    requests delegate to ``MRotaryEmbedding.get_input_positions`` and write the
    resulting delta back onto ``batch.image_inputs[i]``.

    The per-sequence results are concatenated along axis 1 into one int64
    tensor on ``model_runner.device``.
    """
    device = model_runner.device
    hf_config = model_runner.model_config.hf_config
    # One entry per sequence in the batch.
    mrope_positions_list = [None] * self.seq_lens.shape[0]
    if self.forward_mode.is_decode():
        for i, _ in enumerate(mrope_positions_list):
            mrope_position_delta = (
                0
                if batch.image_inputs[i] is None
                else batch.image_inputs[i].mrope_position_delta
            )
            mrope_positions_list[i] = MRotaryEmbedding.get_next_input_positions(
                mrope_position_delta,
                int(self.seq_lens[i]) - 1,
                int(self.seq_lens[i]),
            )
    elif self.forward_mode.is_extend():
        extend_start_loc_cpu = self.extend_start_loc.cpu().numpy()
        for i, image_inputs in enumerate(batch.image_inputs):
            extend_start_loc, extend_seq_len, extend_prefix_len = (
                extend_start_loc_cpu[i],
                batch.extend_seq_lens[i],
                batch.extend_prefix_lens[i],
            )
            if image_inputs is None:
                # Text only: positions are a plain range over the extended
                # span, replicated 3x (one list per mrope section).
                mrope_positions = [
                    [
                        pos
                        for pos in range(
                            extend_prefix_len, extend_prefix_len + extend_seq_len
                        )
                    ]
                ] * 3
            else:
                # TODO: currently qwen2-vl does not support radix cache since
                # mrope position calculation needs the full token span
                # (context_len=0 below).
                mrope_positions, mrope_position_delta = (
                    MRotaryEmbedding.get_input_positions(
                        input_tokens=self.input_ids[
                            extend_start_loc : extend_start_loc + extend_seq_len
                        ],
                        image_grid_thw=image_inputs.image_grid_thws,
                        vision_start_token_id=hf_config.vision_start_token_id,
                        spatial_merge_size=hf_config.vision_config.spatial_merge_size,
                        context_len=0,
                    )
                )
                # Persist the delta so later decode steps can reuse it.
                batch.image_inputs[i].mrope_position_delta = mrope_position_delta
            mrope_positions_list[i] = mrope_positions
    self.mrope_positions = torch.concat(
        [torch.tensor(pos, device=device) for pos in mrope_positions_list],
        axis=1,
    )
    self.mrope_positions = self.mrope_positions.to(torch.int64)
def compute_position_triton(
extend_prefix_lens: torch.Tensor, extend_seq_lens: torch.Tensor, extend_seq_lens_sum

View File

@@ -52,6 +52,7 @@ from sglang.srt.mem_cache.memory_pool import (
MLATokenToKVPool,
ReqToTokenPool,
)
from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader import get_model
from sglang.srt.server_args import ServerArgs
@@ -714,8 +715,6 @@ class ModelRunner:
def init_cuda_graphs(self):
"""Capture cuda graphs."""
from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner
self.cuda_graph_runner = None
if not self.is_generation: