Cleaning codes for speculative attention mode (#10149)

This commit is contained in:
Baizhou Zhang
2025-09-08 17:38:06 -07:00
committed by GitHub
parent 148022fc36
commit 8ad700f735
7 changed files with 14 additions and 35 deletions

View File

@@ -34,7 +34,7 @@ class HybridAttnBackend(AttentionBackend):
Note:
- decode_or_idle: Always uses decode backend
- target_verify or draft_extend: Uses decode backend if speculative_attention_backend is "decode", otherwise prefill backend
- target_verify or draft_extend: Uses decode backend if speculative_attention_mode is "decode", otherwise prefill backend
- prefill: Always uses prefill backend
"""
if forward_mode.is_decode_or_idle():
@@ -42,8 +42,7 @@ class HybridAttnBackend(AttentionBackend):
elif forward_mode.is_target_verify() or forward_mode.is_draft_extend():
return (
self.decode_backend
if self.model_runner.server_args.speculative_attention_backend
== "decode"
if self.model_runner.server_args.speculative_attention_mode == "decode"
else self.prefill_backend
)
else:
@@ -57,7 +56,7 @@ class HybridAttnBackend(AttentionBackend):
self.decode_backend.init_cuda_graph_state(max_bs, max_num_tokens)
if (
self.model_runner.server_args.speculative_algorithm is not None
and self.model_runner.server_args.speculative_attention_backend == "prefill"
and self.model_runner.server_args.speculative_attention_mode == "prefill"
):
# When speculative decoding is enabled, we need to initialize the backend
# that will be used for target_verify.

View File

@@ -98,7 +98,7 @@ GLOBAL_SERVER_ARGS_KEYS = [
"sampling_backend",
"speculative_accept_threshold_single",
"speculative_accept_threshold_acc",
"speculative_attention_backend",
"speculative_attention_mode",
"torchao_config",
"triton_attention_reduce_in_fp32",
"num_reserved_decode_tokens",

View File

@@ -1050,7 +1050,7 @@ class DeepseekV2AttentionMLA(nn.Module):
or forward_batch.forward_mode.is_draft_extend()
):
# Use the specified backend for speculative operations (both verify and draft extend)
if global_server_args_dict["speculative_attention_backend"] == "decode":
if global_server_args_dict["speculative_attention_mode"] == "decode":
attention_backend = global_server_args_dict["decode_attention_backend"]
else: # default to prefill
attention_backend = global_server_args_dict["prefill_attention_backend"]

View File

@@ -262,7 +262,7 @@ class ServerArgs:
speculative_accept_threshold_single: float = 1.0
speculative_accept_threshold_acc: float = 1.0
speculative_token_map: Optional[str] = None
speculative_attention_backend: str = "prefill"
speculative_attention_mode: str = "prefill"
# Expert parallelism
ep_size: int = 1
@@ -1563,11 +1563,11 @@ class ServerArgs:
default=ServerArgs.speculative_token_map,
)
parser.add_argument(
"--speculative-attention-backend",
"--speculative-attention-mode",
type=str,
choices=["prefill", "decode"],
help="Attention backend to use for speculative decoding operations (both target verify and draft extend). 'prefill' (default) or 'decode'.",
default=ServerArgs.speculative_attention_backend,
help="Attention backend for speculative decoding operations (both target verify and draft extend). Can be one of 'prefill' (default) or 'decode'.",
default=ServerArgs.speculative_attention_mode,
)
# Expert parallelism

View File

@@ -191,7 +191,7 @@ class EAGLEWorker(TpModelWorker):
# Initialize decode attention backend
self.draft_attn_backend = self._create_decode_backend()
# Initialize draft extend attention backend (respects speculative_attention_backend setting)
# Initialize draft extend attention backend (respects speculative_attention_mode setting)
self.draft_extend_attn_backend = self._create_draft_extend_backend()
self.draft_model_runner.draft_attn_backend = self.draft_attn_backend
@@ -236,7 +236,7 @@ class EAGLEWorker(TpModelWorker):
}
backend_name = (
"decode_attention_backend"
if self.server_args.speculative_attention_backend == "decode"
if self.server_args.speculative_attention_mode == "decode"
else "prefill_attention_backend"
)
return self._create_backend(