Eagle speculative decoding part 2: Fix cuda graph + DP attention hanging (#2684)
Co-authored-by: yukavio <kavioyu@gmail.com>
This commit is contained in:
2
.github/workflows/pr-test.yml
vendored
2
.github/workflows/pr-test.yml
vendored
@@ -92,7 +92,7 @@ jobs:
|
|||||||
python3 test_data_parallelism.py
|
python3 test_data_parallelism.py
|
||||||
|
|
||||||
- name: Evaluate MLA accuracy (TP=2)
|
- name: Evaluate MLA accuracy (TP=2)
|
||||||
timeout-minutes: 20
|
timeout-minutes: 10
|
||||||
run: |
|
run: |
|
||||||
cd test/srt
|
cd test/srt
|
||||||
python3 test_mla.py
|
python3 test_mla.py
|
||||||
|
|||||||
@@ -146,7 +146,10 @@ class LogitsProcessor(nn.Module):
|
|||||||
|
|
||||||
# Compute logits
|
# Compute logits
|
||||||
last_logits = self._get_logits(last_hidden, lm_head)
|
last_logits = self._get_logits(last_hidden, lm_head)
|
||||||
if not logits_metadata.extend_return_logprob:
|
if (
|
||||||
|
not logits_metadata.extend_return_logprob
|
||||||
|
or logits_metadata.capture_hidden_mode.need_capture()
|
||||||
|
):
|
||||||
# Decode mode or extend mode without return_logprob.
|
# Decode mode or extend mode without return_logprob.
|
||||||
return LogitsProcessorOutput(
|
return LogitsProcessorOutput(
|
||||||
next_token_logits=last_logits,
|
next_token_logits=last_logits,
|
||||||
|
|||||||
@@ -1,3 +1,5 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
# Copyright 2023-2024 SGLang Team
|
# Copyright 2023-2024 SGLang Team
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
# you may not use this file except in compliance with the License.
|
# you may not use this file except in compliance with the License.
|
||||||
@@ -29,7 +31,7 @@ ScheduleBatch -> ModelWorkerBatch -> ForwardBatch
|
|||||||
|
|
||||||
import dataclasses
|
import dataclasses
|
||||||
import logging
|
import logging
|
||||||
from typing import List, Optional, Set, Tuple, Union
|
from typing import TYPE_CHECKING, List, Optional, Set, Tuple, Union
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
@@ -47,6 +49,10 @@ from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
|
|||||||
from sglang.srt.sampling.sampling_params import SamplingParams
|
from sglang.srt.sampling.sampling_params import SamplingParams
|
||||||
from sglang.srt.server_args import ServerArgs
|
from sglang.srt.server_args import ServerArgs
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from sglang.srt.speculative.spec_info import SpecInfo, SpeculativeAlgorithm
|
||||||
|
|
||||||
|
|
||||||
INIT_INCREMENTAL_DETOKENIZATION_OFFSET = 5
|
INIT_INCREMENTAL_DETOKENIZATION_OFFSET = 5
|
||||||
|
|
||||||
# Put some global args for easy access
|
# Put some global args for easy access
|
||||||
@@ -565,9 +571,13 @@ class ScheduleBatch:
|
|||||||
# Has grammar
|
# Has grammar
|
||||||
has_grammar: bool = False
|
has_grammar: bool = False
|
||||||
|
|
||||||
# device
|
# Device
|
||||||
device: str = "cuda"
|
device: str = "cuda"
|
||||||
|
|
||||||
|
# Speculative decoding
|
||||||
|
spec_info: Optional[SpecInfo] = None
|
||||||
|
spec_algorithm: Optional[SpeculativeAlgorithm] = None
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def init_new(
|
def init_new(
|
||||||
cls,
|
cls,
|
||||||
@@ -577,6 +587,7 @@ class ScheduleBatch:
|
|||||||
tree_cache: BasePrefixCache,
|
tree_cache: BasePrefixCache,
|
||||||
model_config: ModelConfig,
|
model_config: ModelConfig,
|
||||||
enable_overlap: bool,
|
enable_overlap: bool,
|
||||||
|
speculative_algorithm: Optional[SpeculativeAlgorithm] = None,
|
||||||
):
|
):
|
||||||
return cls(
|
return cls(
|
||||||
reqs=reqs,
|
reqs=reqs,
|
||||||
@@ -589,6 +600,7 @@ class ScheduleBatch:
|
|||||||
has_stream=any(req.stream for req in reqs),
|
has_stream=any(req.stream for req in reqs),
|
||||||
has_grammar=any(req.grammar for req in reqs),
|
has_grammar=any(req.grammar for req in reqs),
|
||||||
device=req_to_token_pool.device,
|
device=req_to_token_pool.device,
|
||||||
|
spec_algorithm=speculative_algorithm,
|
||||||
)
|
)
|
||||||
|
|
||||||
def batch_size(self):
|
def batch_size(self):
|
||||||
@@ -1103,6 +1115,9 @@ class ScheduleBatch:
|
|||||||
self.has_stream |= other.has_stream
|
self.has_stream |= other.has_stream
|
||||||
self.has_grammar |= other.has_grammar
|
self.has_grammar |= other.has_grammar
|
||||||
|
|
||||||
|
if self.spec_info:
|
||||||
|
self.spec_info.merge_batch(other.spec_info)
|
||||||
|
|
||||||
def get_model_worker_batch(self):
|
def get_model_worker_batch(self):
|
||||||
if self.forward_mode.is_decode() or self.forward_mode.is_idle():
|
if self.forward_mode.is_decode() or self.forward_mode.is_idle():
|
||||||
extend_seq_lens = extend_prefix_lens = extend_logprob_start_lens = None
|
extend_seq_lens = extend_prefix_lens = extend_logprob_start_lens = None
|
||||||
@@ -1144,6 +1159,8 @@ class ScheduleBatch:
|
|||||||
lora_paths=[req.lora_path for req in self.reqs],
|
lora_paths=[req.lora_path for req in self.reqs],
|
||||||
sampling_info=self.sampling_info,
|
sampling_info=self.sampling_info,
|
||||||
input_embeds=self.input_embeds,
|
input_embeds=self.input_embeds,
|
||||||
|
spec_algorithm=self.spec_algorithm,
|
||||||
|
spec_info=self.spec_info,
|
||||||
)
|
)
|
||||||
|
|
||||||
def copy(self):
|
def copy(self):
|
||||||
@@ -1214,6 +1231,10 @@ class ModelWorkerBatch:
|
|||||||
# The input Embeds
|
# The input Embeds
|
||||||
input_embeds: Optional[torch.tensor] = None
|
input_embeds: Optional[torch.tensor] = None
|
||||||
|
|
||||||
|
# Speculative decoding
|
||||||
|
spec_info: Optional[SpecInfo] = None
|
||||||
|
spec_algorithm: Optional[SpeculativeAlgorithm] = None
|
||||||
|
|
||||||
|
|
||||||
@triton.jit
|
@triton.jit
|
||||||
def write_req_to_token_pool_triton(
|
def write_req_to_token_pool_triton(
|
||||||
|
|||||||
@@ -150,12 +150,18 @@ class TpModelWorker:
|
|||||||
self,
|
self,
|
||||||
model_worker_batch: ModelWorkerBatch,
|
model_worker_batch: ModelWorkerBatch,
|
||||||
launch_done: Optional[threading.Event] = None,
|
launch_done: Optional[threading.Event] = None,
|
||||||
|
skip_sample: bool = False,
|
||||||
):
|
):
|
||||||
forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
|
forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
|
||||||
logits_output = self.model_runner.forward(forward_batch)
|
logits_output = self.model_runner.forward(forward_batch)
|
||||||
if launch_done:
|
if launch_done:
|
||||||
launch_done.set()
|
launch_done.set()
|
||||||
next_token_ids = self.model_runner.sample(logits_output, model_worker_batch)
|
|
||||||
|
if skip_sample:
|
||||||
|
next_token_ids = None
|
||||||
|
else:
|
||||||
|
next_token_ids = self.model_runner.sample(logits_output, model_worker_batch)
|
||||||
|
|
||||||
return logits_output, next_token_ids
|
return logits_output, next_token_ids
|
||||||
|
|
||||||
def forward_batch_embedding(self, model_worker_batch: ModelWorkerBatch):
|
def forward_batch_embedding(self, model_worker_batch: ModelWorkerBatch):
|
||||||
|
|||||||
@@ -375,9 +375,7 @@ class CudaGraphRunner:
|
|||||||
def replay(self, forward_batch: ForwardBatch):
|
def replay(self, forward_batch: ForwardBatch):
|
||||||
assert forward_batch.out_cache_loc is not None
|
assert forward_batch.out_cache_loc is not None
|
||||||
raw_bs = forward_batch.batch_size
|
raw_bs = forward_batch.batch_size
|
||||||
# In normal decoding case, raw_bs == raw_num_token
|
raw_num_token = raw_bs * self.num_tokens_per_bs
|
||||||
# But in speculative decoding, raw_num_token is raw_bs * self.num_tokens_per_bs
|
|
||||||
raw_num_token = forward_batch.input_ids.numel()
|
|
||||||
|
|
||||||
# Pad
|
# Pad
|
||||||
if self.enable_dp_attention:
|
if self.enable_dp_attention:
|
||||||
|
|||||||
@@ -96,7 +96,11 @@ class ForwardMode(IntEnum):
|
|||||||
return self == ForwardMode.DRAFT_EXTEND
|
return self == ForwardMode.DRAFT_EXTEND
|
||||||
|
|
||||||
def is_cuda_graph(self):
|
def is_cuda_graph(self):
|
||||||
return self == ForwardMode.DECODE or self == ForwardMode.TARGET_VERIFY
|
return (
|
||||||
|
self == ForwardMode.DECODE
|
||||||
|
or self == ForwardMode.TARGET_VERIFY
|
||||||
|
or self == ForwardMode.IDLE
|
||||||
|
)
|
||||||
|
|
||||||
def is_dummy_first(self):
|
def is_dummy_first(self):
|
||||||
return self == ForwardMode.DUMMY_FIRST
|
return self == ForwardMode.DUMMY_FIRST
|
||||||
@@ -161,15 +165,15 @@ class ForwardBatch:
|
|||||||
token_to_kv_pool: BaseTokenToKVPool = None
|
token_to_kv_pool: BaseTokenToKVPool = None
|
||||||
attn_backend: AttentionBackend = None
|
attn_backend: AttentionBackend = None
|
||||||
|
|
||||||
# Speculative decoding
|
|
||||||
spec_info: SpecInfo = None
|
|
||||||
spec_algorithm: SpeculativeAlgorithm = None
|
|
||||||
|
|
||||||
# For DP attention
|
# For DP attention
|
||||||
global_num_tokens: Optional[List[int]] = None
|
global_num_tokens: Optional[List[int]] = None
|
||||||
gathered_buffer: Optional[torch.Tensor] = None
|
gathered_buffer: Optional[torch.Tensor] = None
|
||||||
can_run_dp_cuda_graph: bool = False
|
can_run_dp_cuda_graph: bool = False
|
||||||
|
|
||||||
|
# Speculative decoding
|
||||||
|
spec_info: SpecInfo = None
|
||||||
|
spec_algorithm: SpeculativeAlgorithm = None
|
||||||
|
|
||||||
# For Qwen2-VL
|
# For Qwen2-VL
|
||||||
mrope_positions: torch.Tensor = None
|
mrope_positions: torch.Tensor = None
|
||||||
|
|
||||||
@@ -258,6 +262,8 @@ class ForwardBatch:
|
|||||||
can_run_dp_cuda_graph=batch.can_run_dp_cuda_graph,
|
can_run_dp_cuda_graph=batch.can_run_dp_cuda_graph,
|
||||||
lora_paths=batch.lora_paths,
|
lora_paths=batch.lora_paths,
|
||||||
sampling_info=batch.sampling_info,
|
sampling_info=batch.sampling_info,
|
||||||
|
spec_algorithm=batch.spec_algorithm,
|
||||||
|
spec_info=batch.spec_info,
|
||||||
input_embeds=batch.input_embeds,
|
input_embeds=batch.input_embeds,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -108,14 +108,6 @@ class ServerArgs:
|
|||||||
# Model override args in JSON
|
# Model override args in JSON
|
||||||
json_model_override_args: str = "{}"
|
json_model_override_args: str = "{}"
|
||||||
|
|
||||||
# Double Sparsity
|
|
||||||
enable_double_sparsity: bool = False
|
|
||||||
ds_channel_config_path: str = None
|
|
||||||
ds_heavy_channel_num: int = 32
|
|
||||||
ds_heavy_token_num: int = 256
|
|
||||||
ds_heavy_channel_type: str = "qk"
|
|
||||||
ds_sparse_decode_threshold: int = 4096
|
|
||||||
|
|
||||||
# LoRA
|
# LoRA
|
||||||
lora_paths: Optional[List[str]] = None
|
lora_paths: Optional[List[str]] = None
|
||||||
max_loras_per_batch: int = 8
|
max_loras_per_batch: int = 8
|
||||||
@@ -125,6 +117,21 @@ class ServerArgs:
|
|||||||
sampling_backend: Optional[str] = None
|
sampling_backend: Optional[str] = None
|
||||||
grammar_backend: Optional[str] = "outlines"
|
grammar_backend: Optional[str] = "outlines"
|
||||||
|
|
||||||
|
# Speculative decoding
|
||||||
|
speculative_draft_model_path: Optional[str] = None
|
||||||
|
speculative_algorithm: Optional[str] = None
|
||||||
|
speculative_num_steps: int = 5
|
||||||
|
speculative_num_draft_tokens: int = 64
|
||||||
|
speculative_eagle_topk: int = 8
|
||||||
|
|
||||||
|
# Double Sparsity
|
||||||
|
enable_double_sparsity: bool = False
|
||||||
|
ds_channel_config_path: str = None
|
||||||
|
ds_heavy_channel_num: int = 32
|
||||||
|
ds_heavy_token_num: int = 256
|
||||||
|
ds_heavy_channel_type: str = "qk"
|
||||||
|
ds_sparse_decode_threshold: int = 4096
|
||||||
|
|
||||||
# Optimization/debug options
|
# Optimization/debug options
|
||||||
disable_radix_cache: bool = False
|
disable_radix_cache: bool = False
|
||||||
disable_jump_forward: bool = False
|
disable_jump_forward: bool = False
|
||||||
@@ -602,43 +609,6 @@ class ServerArgs:
|
|||||||
default=ServerArgs.json_model_override_args,
|
default=ServerArgs.json_model_override_args,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Double Sparsity
|
|
||||||
parser.add_argument(
|
|
||||||
"--enable-double-sparsity",
|
|
||||||
action="store_true",
|
|
||||||
help="Enable double sparsity attention",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--ds-channel-config-path",
|
|
||||||
type=str,
|
|
||||||
default=ServerArgs.ds_channel_config_path,
|
|
||||||
help="The path of the double sparsity channel config",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--ds-heavy-channel-num",
|
|
||||||
type=int,
|
|
||||||
default=ServerArgs.ds_heavy_channel_num,
|
|
||||||
help="The number of heavy channels in double sparsity attention",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--ds-heavy-token-num",
|
|
||||||
type=int,
|
|
||||||
default=ServerArgs.ds_heavy_token_num,
|
|
||||||
help="The number of heavy tokens in double sparsity attention",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--ds-heavy-channel-type",
|
|
||||||
type=str,
|
|
||||||
default=ServerArgs.ds_heavy_channel_type,
|
|
||||||
help="The type of heavy channels in double sparsity attention",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--ds-sparse-decode-threshold",
|
|
||||||
type=int,
|
|
||||||
default=ServerArgs.ds_sparse_decode_threshold,
|
|
||||||
help="The type of heavy channels in double sparsity attention",
|
|
||||||
)
|
|
||||||
|
|
||||||
# LoRA
|
# LoRA
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--lora-paths",
|
"--lora-paths",
|
||||||
@@ -678,6 +648,75 @@ class ServerArgs:
|
|||||||
help="Choose the backend for grammar-guided decoding.",
|
help="Choose the backend for grammar-guided decoding.",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Speculative decoding
|
||||||
|
parser.add_argument(
|
||||||
|
"--speculative-algorithm",
|
||||||
|
type=str,
|
||||||
|
choices=["EAGLE"],
|
||||||
|
help="Speculative algorithm.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--speculative-draft-model-path",
|
||||||
|
type=str,
|
||||||
|
help="The path of the draft model weights. This can be a local folder or a Hugging Face repo ID.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--speculative-num-steps",
|
||||||
|
type=int,
|
||||||
|
help="The number of steps sampled from draft model in Speculative Decoding.",
|
||||||
|
default=ServerArgs.speculative_num_steps,
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--speculative-num-draft-tokens",
|
||||||
|
type=int,
|
||||||
|
help="The number of token sampled from draft model in Speculative Decoding.",
|
||||||
|
default=ServerArgs.speculative_num_draft_tokens,
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--speculative-eagle-topk",
|
||||||
|
type=int,
|
||||||
|
help="The number of token sampled from draft model in eagle2 each step.",
|
||||||
|
choices=[1, 2, 4, 8],
|
||||||
|
default=ServerArgs.speculative_eagle_topk,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Double Sparsity
|
||||||
|
parser.add_argument(
|
||||||
|
"--enable-double-sparsity",
|
||||||
|
action="store_true",
|
||||||
|
help="Enable double sparsity attention",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--ds-channel-config-path",
|
||||||
|
type=str,
|
||||||
|
default=ServerArgs.ds_channel_config_path,
|
||||||
|
help="The path of the double sparsity channel config",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--ds-heavy-channel-num",
|
||||||
|
type=int,
|
||||||
|
default=ServerArgs.ds_heavy_channel_num,
|
||||||
|
help="The number of heavy channels in double sparsity attention",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--ds-heavy-token-num",
|
||||||
|
type=int,
|
||||||
|
default=ServerArgs.ds_heavy_token_num,
|
||||||
|
help="The number of heavy tokens in double sparsity attention",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--ds-heavy-channel-type",
|
||||||
|
type=str,
|
||||||
|
default=ServerArgs.ds_heavy_channel_type,
|
||||||
|
help="The type of heavy channels in double sparsity attention",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--ds-sparse-decode-threshold",
|
||||||
|
type=int,
|
||||||
|
default=ServerArgs.ds_sparse_decode_threshold,
|
||||||
|
help="The type of heavy channels in double sparsity attention",
|
||||||
|
)
|
||||||
|
|
||||||
# Optimization/debug options
|
# Optimization/debug options
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--disable-radix-cache",
|
"--disable-radix-cache",
|
||||||
|
|||||||
Reference in New Issue
Block a user