refactor EAGLE 2 (#3269)
Co-authored-by: Ying Sheng <sqy1415@gmail.com>
Co-authored-by: merrymercy <lianminzheng@gmail.com>
Co-authored-by: Ying1123 <sqy1415@gmail.com>
@@ -103,69 +103,75 @@ def set_torch_compile_config():
    torch._dynamo.config.cache_size_limit = 1024


def get_batch_sizes_to_capture(model_runner: ModelRunner):
    server_args = model_runner.server_args
    capture_bs = server_args.cuda_graph_bs
    if capture_bs is None:
        if server_args.disable_cuda_graph_padding:
            capture_bs = list(range(1, 33)) + [64, 128]
        else:
            capture_bs = [1, 2, 4] + [i * 8 for i in range(1, 21)]
    if max(capture_bs) > model_runner.req_to_token_pool.size:
        # In some cases (e.g., with a small GPU or --max-running-requests), the #max-running-requests
        # is very small. We add more values here to make sure we capture the maximum bs.
        capture_bs = list(
            sorted(
                set(
                    capture_bs
                    + [model_runner.req_to_token_pool.size - 1]
                    + [model_runner.req_to_token_pool.size]
                )
            )
        )
    capture_bs = [
        bs
        for bs in capture_bs
        if bs <= model_runner.req_to_token_pool.size
        and bs <= server_args.cuda_graph_max_bs
    ]
    compile_bs = (
        [bs for bs in capture_bs if bs <= server_args.torch_compile_max_bs]
        if server_args.enable_torch_compile
        else []
    )
    return capture_bs, compile_bs

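For orientation, a minimal sketch of the batch-size selection above, run with made-up limits (the real values come from ServerArgs and the request-to-token pool; the numbers below are purely illustrative):

# Illustrative only: mimic the selection logic above with hypothetical limits.
req_to_token_pool_size = 48      # e.g., --max-running-requests is small
cuda_graph_max_bs = 160          # hypothetical server_args.cuda_graph_max_bs
torch_compile_max_bs = 32        # hypothetical server_args.torch_compile_max_bs

capture_bs = [1, 2, 4] + [i * 8 for i in range(1, 21)]  # default list, up to 160
if max(capture_bs) > req_to_token_pool_size:
    # Make sure the largest runnable batch size is also captured.
    capture_bs = sorted(set(capture_bs + [req_to_token_pool_size - 1, req_to_token_pool_size]))
capture_bs = [bs for bs in capture_bs if bs <= req_to_token_pool_size and bs <= cuda_graph_max_bs]
compile_bs = [bs for bs in capture_bs if bs <= torch_compile_max_bs]

print(capture_bs)  # [1, 2, 4, 8, 16, 24, 32, 40, 47, 48]
print(compile_bs)  # [1, 2, 4, 8, 16, 24, 32]
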
# Reuse this memory pool across all cuda graph runners.
global_graph_memory_pool = None


def get_global_graph_memory_pool():
    return global_graph_memory_pool


def set_global_graph_memory_pool(val):
    global global_graph_memory_pool
    global_graph_memory_pool = val

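The getter/setter pair exists so that every graph runner in the process captures into one shared memory pool. A minimal sketch of the intended reuse pattern, assuming a hypothetical run_once closure like the one used in capture_one_batch_size further down:

import torch

def capture_with_shared_pool(run_once):
    # Capture one graph, reusing the process-wide pool if a previous runner set it.
    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph, pool=get_global_graph_memory_pool()):
        out = run_once()
    # Publish the pool so later runners (e.g., the EAGLE draft runner) share memory.
    set_global_graph_memory_pool(graph.pool())
    return graph, out
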
class CudaGraphRunner:
    """A CudaGraphRunner runs the forward pass of a model with cuda graph and torch.compile."""

    def __init__(self, model_runner: "ModelRunner"):
    def __init__(self, model_runner: ModelRunner):
        # Parse args
        self.model_runner = model_runner
        self.graphs = {}
        self.input_buffers = {}
        self.output_buffers = {}
        self.flashinfer_handlers = {}
        self.graph_memory_pool = None
        self.use_torch_compile = model_runner.server_args.enable_torch_compile
        self.enable_torch_compile = model_runner.server_args.enable_torch_compile
        self.disable_padding = model_runner.server_args.disable_cuda_graph_padding
        self.is_encoder_decoder = self.model_runner.model_config.is_encoder_decoder
        self.enable_dp_attention = self.model_runner.server_args.enable_dp_attention
        self.tp_size = self.model_runner.tp_size
        self.dp_size = self.model_runner.server_args.dp_size
        self.is_encoder_decoder = model_runner.model_config.is_encoder_decoder
        self.enable_dp_attention = model_runner.server_args.enable_dp_attention
        self.tp_size = model_runner.server_args.tp_size
        self.dp_size = model_runner.server_args.dp_size

        # Batch sizes to capture
        self.capture_bs = self.model_runner.server_args.cuda_graph_bs
        if self.capture_bs is None:
            if model_runner.server_args.disable_cuda_graph_padding:
                self.capture_bs = list(range(1, 33)) + [64, 128]
            else:
                self.capture_bs = [1, 2, 4] + [i * 8 for i in range(1, 21)]

        if max(self.capture_bs) > model_runner.req_to_token_pool.size:
            # In some cases (e.g., with a small GPU or --max-running-requests), the #max-running-requests
            # is very small. We add more values here to make sure we capture the maximum bs.
            self.capture_bs = list(
                sorted(
                    set(
                        self.capture_bs
                        + [model_runner.req_to_token_pool.size - 1]
                        + [model_runner.req_to_token_pool.size]
                    )
                )
            )

        self.capture_bs = [
            bs
            for bs in self.capture_bs
            if bs <= model_runner.req_to_token_pool.size
            and bs <= model_runner.server_args.cuda_graph_max_bs
        ]

        self.compile_bs = (
            [
                bs
                for bs in self.capture_bs
                if bs <= self.model_runner.server_args.torch_compile_max_bs
            ]
            if self.use_torch_compile
            else []
        )

        self.capture_bs, self.compile_bs = get_batch_sizes_to_capture(model_runner)
        self.capture_forward_mode = ForwardMode.DECODE
        self.num_tokens_per_bs = 1
        if model_runner.spec_algorithm.is_eagle():
            if self.model_runner.is_draft_worker:
                self.num_tokens_per_bs = (
                    self.model_runner.server_args.speculative_eagle_topk
                )
                raise RuntimeError("This should not happen")
            else:
                self.capture_forward_mode = ForwardMode.TARGET_VERIFY
                self.num_tokens_per_bs = (
@@ -182,10 +188,10 @@ class CudaGraphRunner:
        # FIXME(lsyin): leave it here for now, I don't know whether it is necessary
        self.encoder_len_fill_value = 0

        if self.use_torch_compile:
        if self.enable_torch_compile:
            set_torch_compile_config()

        # Common inputs
        # Graph inputs
        with torch.device("cuda"):
            self.input_ids = torch.zeros((self.max_num_token,), dtype=torch.int64)
            self.req_pool_indices = torch.zeros((self.max_bs,), dtype=torch.int32)
@@ -301,7 +307,7 @@ class CudaGraphRunner:
        stream = self.stream
        num_tokens = bs * self.num_tokens_per_bs

        # Common inputs
        # Graph inputs
        input_ids = self.input_ids[:num_tokens]
        req_pool_indices = self.req_pool_indices[:bs]
        seq_lens = self.seq_lens[:bs]
@@ -320,7 +326,7 @@ class CudaGraphRunner:
        global_num_tokens = None
        gathered_buffer = None

        spec_info = self.get_spec_info(num_tokens, positions)
        spec_info = self.get_spec_info(num_tokens)

        forward_batch = ForwardBatch(
            forward_mode=self.capture_forward_mode,
@@ -335,7 +341,6 @@ class CudaGraphRunner:
            seq_lens_sum=seq_lens.sum(),
            encoder_lens=encoder_lens,
            return_logprob=False,
            top_logprobs_nums=[0] * bs,
            positions=positions,
            global_num_tokens=global_num_tokens,
            gathered_buffer=gathered_buffer,
@@ -375,13 +380,14 @@ class CudaGraphRunner:
        torch.cuda.synchronize()
        self.model_runner.tp_group.barrier()

        with torch.cuda.graph(graph, pool=self.graph_memory_pool, stream=stream):
        global global_graph_memory_pool
        with torch.cuda.graph(graph, pool=global_graph_memory_pool, stream=stream):
            out = run_once()

        torch.cuda.synchronize()
        self.model_runner.tp_group.barrier()

        self.graph_memory_pool = graph.pool()
        global_graph_memory_pool = graph.pool()
        return graph, out

    def replay(self, forward_batch: ForwardBatch):
@@ -439,35 +445,26 @@ class CudaGraphRunner:
            )
        return logits_output

    def get_spec_info(self, num_tokens: int, positions: torch.Tensor):
    def get_spec_info(self, num_tokens: int):
        spec_info = None
        if self.model_runner.spec_algorithm.is_eagle():
            from sglang.srt.speculative.eagle_utils import (
                EAGLEDraftInput,
                EagleVerifyInput,
            )
            from sglang.srt.speculative.eagle_utils import EagleVerifyInput

            if self.model_runner.is_draft_worker:
                spec_info = EAGLEDraftInput()
                spec_info.load_server_args(self.model_runner.server_args)
                spec_info.hidden_states = self.hidden_states[:num_tokens]
                spec_info.positions = positions
                spec_info.capture_hidden_mode = CaptureHiddenMode.FULL
                raise RuntimeError("This should not happen.")
            else:
                spec_info = EagleVerifyInput(
                    None,
                    None,
                    None,
                    None,
                    None,
                    None,
                    self.model_runner.server_args.speculative_num_draft_tokens,
                    draft_token=None,
                    custom_mask=torch.zeros(
                        (num_tokens * self.model_runner.model_config.context_len),
                        dtype=torch.bool,
                        device="cuda",
                    ),
                    positions=None,
                    retrive_index=None,
                    retrive_cum_len=None,
                    draft_token_num=self.model_runner.server_args.speculative_num_draft_tokens,
                    capture_hidden_mode=CaptureHiddenMode.FULL,
                )
                spec_info.custom_mask = torch.zeros(
                    (num_tokens * self.model_runner.model_config.context_len),
                    dtype=torch.bool,
                    device="cuda",
                )
                spec_info.capture_hidden_mode = CaptureHiddenMode.FULL

        return spec_info

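For reference, the shape arithmetic behind the dummy EagleVerifyInput built for TARGET_VERIFY capture, with hypothetical numbers (num_tokens_per_bs is assumed to equal speculative_num_draft_tokens for the target worker, as set in __init__ above):

# Hypothetical numbers, mirroring the capture path above.
speculative_num_draft_tokens = 4              # draft tokens verified per request
bs = 8                                        # one of the captured batch sizes
context_len = 2048                            # model_config.context_len

num_tokens = bs * speculative_num_draft_tokens    # tokens per graph replay
custom_mask_numel = num_tokens * context_len      # elements in the boolean verify mask
print(num_tokens, custom_mask_numel)              # 32 65536
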
@@ -197,64 +197,6 @@ class ForwardBatch:
    # For Qwen2-VL
    mrope_positions: torch.Tensor = None

    def compute_mrope_positions(
        self, model_runner: ModelRunner, batch: ModelWorkerBatch
    ):
        device = model_runner.device
        hf_config = model_runner.model_config.hf_config
        mrope_positions_list = [None] * self.seq_lens.shape[0]
        if self.forward_mode.is_decode():
            for i, _ in enumerate(mrope_positions_list):
                mrope_position_delta = (
                    0
                    if batch.image_inputs[i] is None
                    else batch.image_inputs[i].mrope_position_delta
                )
                mrope_positions_list[i] = MRotaryEmbedding.get_next_input_positions(
                    mrope_position_delta,
                    int(self.seq_lens[i]) - 1,
                    int(self.seq_lens[i]),
                )
        elif self.forward_mode.is_extend():
            extend_start_loc_cpu = self.extend_start_loc.cpu().numpy()
            for i, image_inputs in enumerate(batch.image_inputs):
                extend_start_loc, extend_seq_len, extend_prefix_len = (
                    extend_start_loc_cpu[i],
                    batch.extend_seq_lens[i],
                    batch.extend_prefix_lens[i],
                )
                if image_inputs is None:
                    # text only
                    mrope_positions = [
                        [
                            pos
                            for pos in range(
                                extend_prefix_len, extend_prefix_len + extend_seq_len
                            )
                        ]
                    ] * 3
                else:
                    # TODO: current qwen2-vl does not support radix cache since mrope position calculation
                    mrope_positions, mrope_position_delta = (
                        MRotaryEmbedding.get_input_positions(
                            input_tokens=self.input_ids[
                                extend_start_loc : extend_start_loc + extend_seq_len
                            ],
                            image_grid_thw=image_inputs.image_grid_thws,
                            vision_start_token_id=hf_config.vision_start_token_id,
                            spatial_merge_size=hf_config.vision_config.spatial_merge_size,
                            context_len=0,
                        )
                    )
                    batch.image_inputs[i].mrope_position_delta = mrope_position_delta
                mrope_positions_list[i] = mrope_positions

        self.mrope_positions = torch.concat(
            [torch.tensor(pos, device=device) for pos in mrope_positions_list],
            axis=1,
        )
        self.mrope_positions = self.mrope_positions.to(torch.int64)

    @classmethod
    def init_new(
        cls,
@@ -337,7 +279,7 @@ class ForwardBatch:
        ret.extend_logprob_start_lens_cpu = batch.extend_logprob_start_lens

        if model_runner.model_is_mrope:
            ret.compute_mrope_positions(model_runner, batch)
            ret._compute_mrope_positions(model_runner, batch)

        # Init lora information
        if model_runner.server_args.lora_paths is not None:
@@ -345,6 +287,63 @@ class ForwardBatch:

        return ret

    def _compute_mrope_positions(
        self, model_runner: ModelRunner, batch: ModelWorkerBatch
    ):
        device = model_runner.device
        hf_config = model_runner.model_config.hf_config
        mrope_positions_list = [None] * self.seq_lens.shape[0]
        if self.forward_mode.is_decode():
            for i, _ in enumerate(mrope_positions_list):
                mrope_position_delta = (
                    0
                    if batch.image_inputs[i] is None
                    else batch.image_inputs[i].mrope_position_delta
                )
                mrope_positions_list[i] = MRotaryEmbedding.get_next_input_positions(
                    mrope_position_delta,
                    int(self.seq_lens[i]) - 1,
                    int(self.seq_lens[i]),
                )
        elif self.forward_mode.is_extend():
            extend_start_loc_cpu = self.extend_start_loc.cpu().numpy()
            for i, image_inputs in enumerate(batch.image_inputs):
                extend_start_loc, extend_seq_len, extend_prefix_len = (
                    extend_start_loc_cpu[i],
                    batch.extend_seq_lens[i],
                    batch.extend_prefix_lens[i],
                )
                if image_inputs is None:
                    # text only
                    mrope_positions = [
                        [
                            pos
                            for pos in range(
                                extend_prefix_len, extend_prefix_len + extend_seq_len
                            )
                        ]
                    ] * 3
                else:
                    # TODO: current qwen2-vl does not support radix cache since mrope position calculation
                    mrope_positions, mrope_position_delta = (
                        MRotaryEmbedding.get_input_positions(
                            input_tokens=self.input_ids[
                                extend_start_loc : extend_start_loc + extend_seq_len
                            ],
                            image_grid_thw=image_inputs.image_grid_thws,
                            vision_start_token_id=hf_config.vision_start_token_id,
                            spatial_merge_size=hf_config.vision_config.spatial_merge_size,
                            context_len=0,
                        )
                    )
                    batch.image_inputs[i].mrope_position_delta = mrope_position_delta
                mrope_positions_list[i] = mrope_positions
        self.mrope_positions = torch.concat(
            [torch.tensor(pos, device=device) for pos in mrope_positions_list],
            axis=1,
        )
        self.mrope_positions = self.mrope_positions.to(torch.int64)

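A rough illustration of what the decode branch produces per request, assuming MRotaryEmbedding.get_next_input_positions follows the usual M-RoPE convention (next text position = previous length plus the per-image delta, replicated across the three rotary sections); the helper and numbers below are hypothetical:

import torch

def next_mrope_positions(seq_len: int, mrope_position_delta: int) -> torch.Tensor:
    # One new token per request in decode mode -> one position per rotary section.
    pos = seq_len - 1 + mrope_position_delta
    return torch.tensor([[pos], [pos], [pos]], dtype=torch.int64)  # shape (3, 1)

# Two hypothetical requests: one text-only (delta 0), one whose image shrank its positions.
batch_positions = torch.concat(
    [next_mrope_positions(12, 0), next_mrope_positions(30, -4)], axis=1
)
print(batch_positions)  # tensor([[11, 25], [11, 25], [11, 25]])
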
def compute_position_triton(
    extend_prefix_lens: torch.Tensor, extend_seq_lens: torch.Tensor, extend_seq_lens_sum

@@ -52,6 +52,7 @@ from sglang.srt.mem_cache.memory_pool import (
    MLATokenToKVPool,
    ReqToTokenPool,
)
from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader import get_model
from sglang.srt.server_args import ServerArgs
@@ -714,8 +715,6 @@ class ModelRunner:

    def init_cuda_graphs(self):
        """Capture cuda graphs."""
        from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner

        self.cuda_graph_runner = None

        if not self.is_generation: