Cleaning codes for speculative attention mode (#10149)
This commit is contained in:
@@ -209,6 +209,7 @@ Please consult the documentation below and [server_args.py](https://github.com/s
|
||||
| `--speculative-accept-threshold-single` | Accept a draft token if its probability in the target model is greater than this threshold. | 1.0 |
|
||||
| `--speculative-accept-threshold-acc` | The accept probability of a draft token is raised from its target probability p to min(1, p / threshold_acc). | 1.0 |
|
||||
| `--speculative-token-map` | The path of the draft model's small vocab table. | None |
|
||||
| `--speculative-attention-mode` | Attention backend for speculative decoding operations (both target verify and draft extend). Can be one of 'prefill' (default) or 'decode'. | prefill |
|
||||
|
||||
## Expert parallelism
|
||||
|
||||
|
||||
@@ -34,7 +34,7 @@ class HybridAttnBackend(AttentionBackend):
|
||||
|
||||
Note:
|
||||
- decode_or_idle: Always uses decode backend
|
||||
- target_verify or draft_extend: Uses decode backend if speculative_attention_backend is "decode", otherwise prefill backend
|
||||
- target_verify or draft_extend: Uses decode backend if speculative_attention_mode is "decode", otherwise prefill backend
|
||||
- prefill: Always uses prefill backend
|
||||
"""
|
||||
if forward_mode.is_decode_or_idle():
|
||||
@@ -42,8 +42,7 @@ class HybridAttnBackend(AttentionBackend):
|
||||
elif forward_mode.is_target_verify() or forward_mode.is_draft_extend():
|
||||
return (
|
||||
self.decode_backend
|
||||
if self.model_runner.server_args.speculative_attention_backend
|
||||
== "decode"
|
||||
if self.model_runner.server_args.speculative_attention_mode == "decode"
|
||||
else self.prefill_backend
|
||||
)
|
||||
else:
|
||||
@@ -57,7 +56,7 @@ class HybridAttnBackend(AttentionBackend):
|
||||
self.decode_backend.init_cuda_graph_state(max_bs, max_num_tokens)
|
||||
if (
|
||||
self.model_runner.server_args.speculative_algorithm is not None
|
||||
and self.model_runner.server_args.speculative_attention_backend == "prefill"
|
||||
and self.model_runner.server_args.speculative_attention_mode == "prefill"
|
||||
):
|
||||
# When speculative decoding is enabled, we need to initialize the backend
|
||||
# that will be used for target_verify.
|
||||
|
||||
@@ -98,7 +98,7 @@ GLOBAL_SERVER_ARGS_KEYS = [
|
||||
"sampling_backend",
|
||||
"speculative_accept_threshold_single",
|
||||
"speculative_accept_threshold_acc",
|
||||
"speculative_attention_backend",
|
||||
"speculative_attention_mode",
|
||||
"torchao_config",
|
||||
"triton_attention_reduce_in_fp32",
|
||||
"num_reserved_decode_tokens",
|
||||
|
||||
@@ -1050,7 +1050,7 @@ class DeepseekV2AttentionMLA(nn.Module):
|
||||
or forward_batch.forward_mode.is_draft_extend()
|
||||
):
|
||||
# Use the specified backend for speculative operations (both verify and draft extend)
|
||||
if global_server_args_dict["speculative_attention_backend"] == "decode":
|
||||
if global_server_args_dict["speculative_attention_mode"] == "decode":
|
||||
attention_backend = global_server_args_dict["decode_attention_backend"]
|
||||
else: # default to prefill
|
||||
attention_backend = global_server_args_dict["prefill_attention_backend"]
|
||||
|
||||
@@ -262,7 +262,7 @@ class ServerArgs:
|
||||
speculative_accept_threshold_single: float = 1.0
|
||||
speculative_accept_threshold_acc: float = 1.0
|
||||
speculative_token_map: Optional[str] = None
|
||||
speculative_attention_backend: str = "prefill"
|
||||
speculative_attention_mode: str = "prefill"
|
||||
|
||||
# Expert parallelism
|
||||
ep_size: int = 1
|
||||
@@ -1563,11 +1563,11 @@ class ServerArgs:
|
||||
default=ServerArgs.speculative_token_map,
|
||||
)
|
||||
parser.add_argument(
|
||||
"--speculative-attention-backend",
|
||||
"--speculative-attention-mode",
|
||||
type=str,
|
||||
choices=["prefill", "decode"],
|
||||
help="Attention backend to use for speculative decoding operations (both target verify and draft extend). 'prefill' (default) or 'decode'.",
|
||||
default=ServerArgs.speculative_attention_backend,
|
||||
help="Attention backend for speculative decoding operations (both target verify and draft extend). Can be one of 'prefill' (default) or 'decode'.",
|
||||
default=ServerArgs.speculative_attention_mode,
|
||||
)
|
||||
|
||||
# Expert parallelism
|
||||
|
||||
@@ -191,7 +191,7 @@ class EAGLEWorker(TpModelWorker):
|
||||
# Initialize decode attention backend
|
||||
self.draft_attn_backend = self._create_decode_backend()
|
||||
|
||||
# Initialize draft extend attention backend (respects speculative_attention_backend setting)
|
||||
# Initialize draft extend attention backend (respects speculative_attention_mode setting)
|
||||
self.draft_extend_attn_backend = self._create_draft_extend_backend()
|
||||
|
||||
self.draft_model_runner.draft_attn_backend = self.draft_attn_backend
|
||||
@@ -236,7 +236,7 @@ class EAGLEWorker(TpModelWorker):
|
||||
}
|
||||
backend_name = (
|
||||
"decode_attention_backend"
|
||||
if self.server_args.speculative_attention_backend == "decode"
|
||||
if self.server_args.speculative_attention_mode == "decode"
|
||||
else "prefill_attention_backend"
|
||||
)
|
||||
return self._create_backend(
|
||||
|
||||
@@ -111,27 +111,6 @@ class TestHybridAttnBackendTorchCompile(TestHybridAttnBackendBase):
|
||||
return DEFAULT_SERVER_ARGS + ["--enable-torch-compile"]
|
||||
|
||||
|
||||
class TestHybridAttnBackendSpeculativeDecoding(TestHybridAttnBackendBase):
|
||||
speculative_decode = True
|
||||
# This eagle test uses a very small model, so the accuracy is low.
|
||||
accuracy_threshold = 0.2
|
||||
|
||||
@classmethod
|
||||
def get_server_args(cls):
|
||||
return DEFAULT_SERVER_ARGS + [
|
||||
"--speculative-algorithm",
|
||||
"EAGLE",
|
||||
"--speculative-draft-model-path",
|
||||
DEFAULT_EAGLE_DRAFT_MODEL_FOR_TEST,
|
||||
"--speculative-num-steps",
|
||||
"3",
|
||||
"--speculative-eagle-topk",
|
||||
"2",
|
||||
"--speculative-num-draft-tokens",
|
||||
"4",
|
||||
]
|
||||
|
||||
|
||||
class TestHybridAttnBackendSpeculativeDecodingPrefillBackend(TestHybridAttnBackendBase):
|
||||
speculative_decode = True
|
||||
# This eagle test uses a very small model, so the accuracy is low.
|
||||
@@ -150,7 +129,7 @@ class TestHybridAttnBackendSpeculativeDecodingPrefillBackend(TestHybridAttnBacke
|
||||
"2",
|
||||
"--speculative-num-draft-tokens",
|
||||
"4",
|
||||
"--speculative-attention-backend",
|
||||
"--speculative-attention-mode",
|
||||
"prefill",
|
||||
]
|
||||
|
||||
@@ -173,7 +152,7 @@ class TestHybridAttnBackendSpeculativeDecodingDecodeBackend(TestHybridAttnBacken
|
||||
"2",
|
||||
"--speculative-num-draft-tokens",
|
||||
"4",
|
||||
"--speculative-attention-backend",
|
||||
"--speculative-attention-mode",
|
||||
"decode",
|
||||
]
|
||||
|
||||
|
||||
Reference in New Issue
Block a user