Eagle speculative decoding part 2: Fix cuda graph + DP attention hanging (#2684)
Co-authored-by: yukavio <kavioyu@gmail.com>
@@ -146,7 +146,10 @@ class LogitsProcessor(nn.Module):
 
         # Compute logits
         last_logits = self._get_logits(last_hidden, lm_head)
-        if not logits_metadata.extend_return_logprob:
+        if (
+            not logits_metadata.extend_return_logprob
+            or logits_metadata.capture_hidden_mode.need_capture()
+        ):
             # Decode mode or extend mode without return_logprob.
             return LogitsProcessorOutput(
                 next_token_logits=last_logits,
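Note: the new branch also takes the early-return path whenever hidden states must be captured for the draft model. A minimal sketch of the capture-mode helper this check assumes (names and members are illustrative, not the exact SGLang definition):

from enum import IntEnum, auto

class CaptureHiddenMode(IntEnum):
    NULL = auto()   # no hidden-state capture requested
    FULL = auto()   # capture hidden states for every position
    LAST = auto()   # capture only the last position

    def need_capture(self) -> bool:
        # Any mode other than NULL asks LogitsProcessor to skip the
        # extend-logprob path and return right away.
        return self != CaptureHiddenMode.NULL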
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 # Copyright 2023-2024 SGLang Team
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -29,7 +31,7 @@ ScheduleBatch -> ModelWorkerBatch -> ForwardBatch
 
 import dataclasses
 import logging
-from typing import List, Optional, Set, Tuple, Union
+from typing import TYPE_CHECKING, List, Optional, Set, Tuple, Union
 
 import numpy as np
 import torch
@@ -47,6 +49,10 @@ from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
 from sglang.srt.sampling.sampling_params import SamplingParams
 from sglang.srt.server_args import ServerArgs
 
+if TYPE_CHECKING:
+    from sglang.srt.speculative.spec_info import SpecInfo, SpeculativeAlgorithm
+
+
 INIT_INCREMENTAL_DETOKENIZATION_OFFSET = 5
 
 # Put some global args for easy access
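Note: `from __future__ import annotations` plus the TYPE_CHECKING guard lets the batch classes annotate fields with SpecInfo and SpeculativeAlgorithm without importing the speculative module at runtime, presumably to avoid a circular import. A minimal sketch of the pattern:

from __future__ import annotations  # annotations stay as strings at runtime

from typing import TYPE_CHECKING, Optional

if TYPE_CHECKING:
    # Seen only by type checkers; never executed at runtime.
    from sglang.srt.speculative.spec_info import SpecInfo

def describe(spec_info: Optional[SpecInfo]) -> str:
    # The annotation above is never evaluated, so no runtime import is needed.
    return "speculative" if spec_info is not None else "normal"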
@@ -565,9 +571,13 @@ class ScheduleBatch:
     # Has grammar
     has_grammar: bool = False
 
-    # device
+    # Device
     device: str = "cuda"
 
+    # Speculative decoding
+    spec_info: Optional[SpecInfo] = None
+    spec_algorithm: Optional[SpeculativeAlgorithm] = None
+
     @classmethod
     def init_new(
         cls,
@@ -577,6 +587,7 @@ class ScheduleBatch:
         tree_cache: BasePrefixCache,
         model_config: ModelConfig,
         enable_overlap: bool,
+        speculative_algorithm: Optional[SpeculativeAlgorithm] = None,
     ):
         return cls(
             reqs=reqs,
@@ -589,6 +600,7 @@ class ScheduleBatch:
             has_stream=any(req.stream for req in reqs),
             has_grammar=any(req.grammar for req in reqs),
             device=req_to_token_pool.device,
+            spec_algorithm=speculative_algorithm,
         )
 
     def batch_size(self):
@@ -1103,6 +1115,9 @@ class ScheduleBatch:
         self.has_stream |= other.has_stream
         self.has_grammar |= other.has_grammar
 
+        if self.spec_info:
+            self.spec_info.merge_batch(other.spec_info)
+
     def get_model_worker_batch(self):
         if self.forward_mode.is_decode() or self.forward_mode.is_idle():
             extend_seq_lens = extend_prefix_lens = extend_logprob_start_lens = None
@@ -1144,6 +1159,8 @@ class ScheduleBatch:
             lora_paths=[req.lora_path for req in self.reqs],
             sampling_info=self.sampling_info,
             input_embeds=self.input_embeds,
+            spec_algorithm=self.spec_algorithm,
+            spec_info=self.spec_info,
         )
 
     def copy(self):
@@ -1214,6 +1231,10 @@ class ModelWorkerBatch:
     # The input Embeds
     input_embeds: Optional[torch.tensor] = None
 
+    # Speculative decoding
+    spec_info: Optional[SpecInfo] = None
+    spec_algorithm: Optional[SpeculativeAlgorithm] = None
+
 
 @triton.jit
 def write_req_to_token_pool_triton(
@@ -150,12 +150,18 @@ class TpModelWorker:
         self,
         model_worker_batch: ModelWorkerBatch,
         launch_done: Optional[threading.Event] = None,
+        skip_sample: bool = False,
     ):
         forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
         logits_output = self.model_runner.forward(forward_batch)
         if launch_done:
             launch_done.set()
-        next_token_ids = self.model_runner.sample(logits_output, model_worker_batch)
+
+        if skip_sample:
+            next_token_ids = None
+        else:
+            next_token_ids = self.model_runner.sample(logits_output, model_worker_batch)
+
         return logits_output, next_token_ids
 
     def forward_batch_embedding(self, model_worker_batch: ModelWorkerBatch):
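Note: `skip_sample` lets a caller take the target model's logits without sampling, which is what a speculative verify pass typically wants because it chooses tokens through its own accept/reject step. A hypothetical call-site sketch (the method name forward_batch_generation and the surrounding objects are assumed, not shown in this hunk):

# Assuming `worker` is the TpModelWorker above and `model_worker_batch` is a
# prepared ModelWorkerBatch for a target-verify pass.
logits_output, next_token_ids = worker.forward_batch_generation(
    model_worker_batch,
    skip_sample=True,  # the speculative path samples on its own
)
assert next_token_ids is None  # sampling was skipped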
@@ -375,9 +375,7 @@ class CudaGraphRunner:
     def replay(self, forward_batch: ForwardBatch):
         assert forward_batch.out_cache_loc is not None
         raw_bs = forward_batch.batch_size
-        # In normal decoding case, raw_bs == raw_num_token
-        # But in speculative decoding, raw_num_token is raw_bs * self.num_tokens_per_bs
-        raw_num_token = forward_batch.input_ids.numel()
+        raw_num_token = raw_bs * self.num_tokens_per_bs
 
         # Pad
         if self.enable_dp_attention:
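Note: with speculative decoding a captured graph processes several tokens per request, so replay sizes its padding and buffers in tokens rather than requests. Illustrative arithmetic only (the numbers are examples, not values from the diff):

raw_bs = 3                 # requests in the incoming batch
num_tokens_per_bs = 4      # tokens fed through the graph per request, e.g. draft tokens to verify
raw_num_token = raw_bs * num_tokens_per_bs
assert raw_num_token == 12  # buffers must cover 12 tokens, not 3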
@@ -96,7 +96,11 @@ class ForwardMode(IntEnum):
         return self == ForwardMode.DRAFT_EXTEND
 
     def is_cuda_graph(self):
-        return self == ForwardMode.DECODE or self == ForwardMode.TARGET_VERIFY
+        return (
+            self == ForwardMode.DECODE
+            or self == ForwardMode.TARGET_VERIFY
+            or self == ForwardMode.IDLE
+        )
 
     def is_dummy_first(self):
         return self == ForwardMode.DUMMY_FIRST
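Note: treating IDLE as a CUDA-graph mode is presumably the DP-attention half of the fix named in the title: with data-parallel attention, a rank whose local batch is empty still runs an idle forward so collective ops stay aligned, and it now replays a graph like its peers instead of falling out of the graph path and stalling them. Trivial usage sketch:

# After this change, an idle DP rank takes the same graph-replay path as
# decode and target-verify ranks.
for mode in (ForwardMode.DECODE, ForwardMode.TARGET_VERIFY, ForwardMode.IDLE):
    assert mode.is_cuda_graph()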
@@ -161,15 +165,15 @@ class ForwardBatch:
     token_to_kv_pool: BaseTokenToKVPool = None
     attn_backend: AttentionBackend = None
 
-    # Speculative decoding
-    spec_info: SpecInfo = None
-    spec_algorithm: SpeculativeAlgorithm = None
-
     # For DP attention
     global_num_tokens: Optional[List[int]] = None
     gathered_buffer: Optional[torch.Tensor] = None
     can_run_dp_cuda_graph: bool = False
 
+    # Speculative decoding
+    spec_info: SpecInfo = None
+    spec_algorithm: SpeculativeAlgorithm = None
+
     # For Qwen2-VL
     mrope_positions: torch.Tensor = None
 
@@ -258,6 +262,8 @@ class ForwardBatch:
             can_run_dp_cuda_graph=batch.can_run_dp_cuda_graph,
             lora_paths=batch.lora_paths,
             sampling_info=batch.sampling_info,
+            spec_algorithm=batch.spec_algorithm,
+            spec_info=batch.spec_info,
             input_embeds=batch.input_embeds,
         )
 
@@ -108,14 +108,6 @@ class ServerArgs:
     # Model override args in JSON
     json_model_override_args: str = "{}"
 
-    # Double Sparsity
-    enable_double_sparsity: bool = False
-    ds_channel_config_path: str = None
-    ds_heavy_channel_num: int = 32
-    ds_heavy_token_num: int = 256
-    ds_heavy_channel_type: str = "qk"
-    ds_sparse_decode_threshold: int = 4096
-
     # LoRA
     lora_paths: Optional[List[str]] = None
     max_loras_per_batch: int = 8
@@ -125,6 +117,21 @@ class ServerArgs:
     sampling_backend: Optional[str] = None
     grammar_backend: Optional[str] = "outlines"
 
+    # Speculative decoding
+    speculative_draft_model_path: Optional[str] = None
+    speculative_algorithm: Optional[str] = None
+    speculative_num_steps: int = 5
+    speculative_num_draft_tokens: int = 64
+    speculative_eagle_topk: int = 8
+
+    # Double Sparsity
+    enable_double_sparsity: bool = False
+    ds_channel_config_path: str = None
+    ds_heavy_channel_num: int = 32
+    ds_heavy_token_num: int = 256
+    ds_heavy_channel_type: str = "qk"
+    ds_sparse_decode_threshold: int = 4096
+
     # Optimization/debug options
     disable_radix_cache: bool = False
     disable_jump_forward: bool = False
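Note: a hypothetical way to turn the new knobs on, either through these dataclass fields or the matching CLI flags added further below. The model and draft-model paths are example values, not part of this commit:

args = ServerArgs(
    model_path="meta-llama/Llama-2-7b-chat-hf",       # example target model
    speculative_algorithm="EAGLE",
    speculative_draft_model_path="lmzheng/sglang-EAGLE-llama2-chat-7B",  # example draft model
    speculative_num_steps=5,          # draft autoregressive steps per round
    speculative_eagle_topk=8,         # branches kept per step (choices are 1/2/4/8)
    speculative_num_draft_tokens=64,  # total draft tokens sent to verification
)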
@@ -602,43 +609,6 @@ class ServerArgs:
             default=ServerArgs.json_model_override_args,
         )
 
-        # Double Sparsity
-        parser.add_argument(
-            "--enable-double-sparsity",
-            action="store_true",
-            help="Enable double sparsity attention",
-        )
-        parser.add_argument(
-            "--ds-channel-config-path",
-            type=str,
-            default=ServerArgs.ds_channel_config_path,
-            help="The path of the double sparsity channel config",
-        )
-        parser.add_argument(
-            "--ds-heavy-channel-num",
-            type=int,
-            default=ServerArgs.ds_heavy_channel_num,
-            help="The number of heavy channels in double sparsity attention",
-        )
-        parser.add_argument(
-            "--ds-heavy-token-num",
-            type=int,
-            default=ServerArgs.ds_heavy_token_num,
-            help="The number of heavy tokens in double sparsity attention",
-        )
-        parser.add_argument(
-            "--ds-heavy-channel-type",
-            type=str,
-            default=ServerArgs.ds_heavy_channel_type,
-            help="The type of heavy channels in double sparsity attention",
-        )
-        parser.add_argument(
-            "--ds-sparse-decode-threshold",
-            type=int,
-            default=ServerArgs.ds_sparse_decode_threshold,
-            help="The type of heavy channels in double sparsity attention",
-        )
-
         # LoRA
         parser.add_argument(
             "--lora-paths",
@@ -678,6 +648,75 @@ class ServerArgs:
             help="Choose the backend for grammar-guided decoding.",
         )
 
+        # Speculative decoding
+        parser.add_argument(
+            "--speculative-algorithm",
+            type=str,
+            choices=["EAGLE"],
+            help="Speculative algorithm.",
+        )
+        parser.add_argument(
+            "--speculative-draft-model-path",
+            type=str,
+            help="The path of the draft model weights. This can be a local folder or a Hugging Face repo ID.",
+        )
+        parser.add_argument(
+            "--speculative-num-steps",
+            type=int,
+            help="The number of steps sampled from draft model in Speculative Decoding.",
+            default=ServerArgs.speculative_num_steps,
+        )
+        parser.add_argument(
+            "--speculative-num-draft-tokens",
+            type=int,
+            help="The number of token sampled from draft model in Speculative Decoding.",
+            default=ServerArgs.speculative_num_draft_tokens,
+        )
+        parser.add_argument(
+            "--speculative-eagle-topk",
+            type=int,
+            help="The number of token sampled from draft model in eagle2 each step.",
+            choices=[1, 2, 4, 8],
+            default=ServerArgs.speculative_eagle_topk,
+        )
+
+        # Double Sparsity
+        parser.add_argument(
+            "--enable-double-sparsity",
+            action="store_true",
+            help="Enable double sparsity attention",
+        )
+        parser.add_argument(
+            "--ds-channel-config-path",
+            type=str,
+            default=ServerArgs.ds_channel_config_path,
+            help="The path of the double sparsity channel config",
+        )
+        parser.add_argument(
+            "--ds-heavy-channel-num",
+            type=int,
+            default=ServerArgs.ds_heavy_channel_num,
+            help="The number of heavy channels in double sparsity attention",
+        )
+        parser.add_argument(
+            "--ds-heavy-token-num",
+            type=int,
+            default=ServerArgs.ds_heavy_token_num,
+            help="The number of heavy tokens in double sparsity attention",
+        )
+        parser.add_argument(
+            "--ds-heavy-channel-type",
+            type=str,
+            default=ServerArgs.ds_heavy_channel_type,
+            help="The type of heavy channels in double sparsity attention",
+        )
+        parser.add_argument(
+            "--ds-sparse-decode-threshold",
+            type=int,
+            default=ServerArgs.ds_sparse_decode_threshold,
+            help="The type of heavy channels in double sparsity attention",
+        )
+
         # Optimization/debug options
         parser.add_argument(
             "--disable-radix-cache",