[Feature] Speculative decoding support lookahead (#9873)
Co-authored-by: a4zhangfei <a4zhangfei@qq.com> Co-authored-by: Qiaolin-Yu <liin1211@outlook.com>
This commit is contained in:
@@ -286,6 +286,14 @@ class ServerArgs:
|
||||
speculative_accept_threshold_acc: float = 1.0
|
||||
speculative_token_map: Optional[str] = None
|
||||
speculative_attention_mode: str = "prefill"
|
||||
# For lookahead only
|
||||
speculative_lookahead_min_match_window_size: int = 1
|
||||
speculative_lookahead_max_match_window_size: int = 12
|
||||
speculative_lookahead_min_bfs_breadth: int = 1
|
||||
speculative_lookahead_max_bfs_breadth: int = 10
|
||||
speculative_lookahead_match_type: Literal["BFS", "PROB"] = "BFS"
|
||||
speculative_lookahead_branch_length: int = 18
|
||||
speculative_lookahead_capacity: int = 10 * 1000 * 1000
|
||||
|
||||
# Expert parallelism
|
||||
ep_size: int = 1
|
||||
@@ -529,7 +537,7 @@ class ServerArgs:
|
||||
# Standalone speculative decoding needs more memory than other speculative
|
||||
# decoding algorithms since the draft model is typically larger.
|
||||
reserved_mem += 6 * 1024
|
||||
else:
|
||||
elif self.speculative_algorithm != "LOOKAHEAD":
|
||||
reserved_mem += 2 * 1024
|
||||
if self.enable_dp_attention:
|
||||
reserved_mem += 4 * 1024
|
||||
@@ -780,11 +788,11 @@ class ServerArgs:
|
||||
self.speculative_algorithm = "EAGLE"
|
||||
|
||||
if self.speculative_algorithm in ("EAGLE", "EAGLE3", "STANDALONE"):
|
||||
if self.speculative_algorithm == "STANDALONE":
|
||||
if self.speculative_algorithm == "STANDALONE" and self.enable_dp_attention:
|
||||
# TODO: support dp attention for standalone speculative decoding
|
||||
assert (
|
||||
self.enable_dp_attention is False
|
||||
), "Currently standalone speculative decoding does not support dp attention."
|
||||
raise ValueError(
|
||||
"Currently standalone speculative decoding does not support dp attention."
|
||||
)
|
||||
if self.max_running_requests is None:
|
||||
self.max_running_requests = 48
|
||||
self.disable_overlap_schedule = True
|
||||
@@ -858,6 +866,39 @@ class ServerArgs:
|
||||
# If sepculative_num_steps >= speculative_num_draft_tokens, the additional tokens will definitely be discarded.
|
||||
# assert self.speculative_num_steps < self.speculative_num_draft_tokens
|
||||
|
||||
if self.speculative_algorithm == "LOOKAHEAD":
|
||||
if not self.device.startswith("cuda"):
|
||||
raise ValueError(
|
||||
"Lookahead speculative decoding only supports CUDA device."
|
||||
)
|
||||
if self.max_running_requests is None:
|
||||
self.max_running_requests = 48
|
||||
self.disable_overlap_schedule = True
|
||||
self.enable_mixed_chunk = False
|
||||
self.speculative_eagle_topk = self.speculative_lookahead_max_bfs_breadth
|
||||
if self.speculative_num_draft_tokens is None:
|
||||
# TODO: Do better auto choose in the future
|
||||
self.speculative_num_draft_tokens = (
|
||||
self.speculative_lookahead_max_match_window_size
|
||||
)
|
||||
logger.warning(
|
||||
"The overlap scheduler and mixed chunked prefill are disabled because of "
|
||||
"using lookahead speculative decoding."
|
||||
)
|
||||
if (
|
||||
self.speculative_eagle_topk > 1
|
||||
and self.page_size > 1
|
||||
and self.attention_backend != "flashinfer"
|
||||
):
|
||||
raise ValueError(
|
||||
"speculative_eagle_topk > 1 with page_size > 1 is unstable and produces incorrect results for paged attention backends. This combination is only supported for the 'flashinfer' backend."
|
||||
)
|
||||
|
||||
if self.enable_dp_attention:
|
||||
# TODO: support dp attention for lookahead speculative decoding
|
||||
raise ValueError(
|
||||
"Currently lookahead speculative decoding does not support dp attention."
|
||||
)
|
||||
# GGUF
|
||||
if (
|
||||
self.load_format == "auto" or self.load_format == "gguf"
|
||||
@@ -1690,7 +1731,7 @@ class ServerArgs:
|
||||
parser.add_argument(
|
||||
"--speculative-algorithm",
|
||||
type=str,
|
||||
choices=["EAGLE", "EAGLE3", "NEXTN", "STANDALONE"],
|
||||
choices=["EAGLE", "EAGLE3", "NEXTN", "STANDALONE", "LOOKAHEAD"],
|
||||
help="Speculative algorithm.",
|
||||
)
|
||||
parser.add_argument(
|
||||
@@ -1750,6 +1791,50 @@ class ServerArgs:
|
||||
help="Attention backend for speculative decoding operations (both target verify and draft extend). Can be one of 'prefill' (default) or 'decode'.",
|
||||
default=ServerArgs.speculative_attention_mode,
|
||||
)
|
||||
# Lookahead speculative decoding
|
||||
parser.add_argument(
|
||||
"--speculative-lookahead-min-match-window-size",
|
||||
type=int,
|
||||
default=ServerArgs.speculative_lookahead_min_match_window_size,
|
||||
help="The minimum window size for pattern matching in lookahead speculative decoding.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--speculative-lookahead-max-match-window-size",
|
||||
type=int,
|
||||
default=ServerArgs.speculative_lookahead_max_match_window_size,
|
||||
help="The maximum window size for pattern matching in lookahead speculative decoding.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--speculative-lookahead-min-bfs-breadth",
|
||||
type=int,
|
||||
default=ServerArgs.speculative_lookahead_min_bfs_breadth,
|
||||
help="The minimum breadth for BFS (Breadth-First Search) in lookahead speculative decoding.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--speculative-lookahead-max-bfs-breadth",
|
||||
type=int,
|
||||
default=ServerArgs.speculative_lookahead_max_bfs_breadth,
|
||||
help="The maximum breadth for BFS (Breadth-First Search) in lookahead speculative decoding.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--speculative-lookahead-match-type",
|
||||
type=str,
|
||||
choices=["BFS", "PROB"],
|
||||
default=ServerArgs.speculative_lookahead_match_type,
|
||||
help="The match type for cache tree.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--speculative-lookahead-branch-length",
|
||||
type=int,
|
||||
default=ServerArgs.speculative_lookahead_branch_length,
|
||||
help="The branch length for lookahead speculative decoding.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--speculative-lookahead-capacity",
|
||||
type=int,
|
||||
default=ServerArgs.speculative_lookahead_capacity,
|
||||
help="The cache capacity for lookahead speculative decoding.",
|
||||
)
|
||||
|
||||
# Expert parallelism
|
||||
parser.add_argument(
|
||||
|
||||
Reference in New Issue
Block a user