SWA Prefix Cache (#7367)

Co-authored-by: Ying Sheng <sqy1415@gmail.com>
This commit is contained in:
Hanming Lu
2025-07-13 12:31:07 -07:00
committed by GitHub
parent 0c55cbcfc5
commit 9379da77de
16 changed files with 1742 additions and 158 deletions

View File

@@ -63,6 +63,7 @@ class ServerArgs:
enable_multimodal: Optional[bool] = None
revision: Optional[str] = None
hybrid_kvcache_ratio: Optional[float] = None
swa_full_tokens_ratio: float = 0.8
impl: str = "auto"
# Port for the HTTP server
@@ -225,6 +226,7 @@ class ServerArgs:
enable_return_hidden_states: bool = False
enable_triton_kernel_moe: bool = False
warmups: Optional[str] = None
disable_hybrid_swa_memory: bool = False
# Debug tensor dumps
debug_tensor_dump_output_folder: Optional[str] = None
@@ -481,14 +483,22 @@ class ServerArgs:
model_arch = get_model_arch(self)
# Auto set draft_model_path DeepSeek-V3/R1
if model_arch == "DeepseekV3ForCausalLM":
# Auto set draft_model_path DeepSeek-V3/R1
if self.speculative_draft_model_path is None:
self.speculative_draft_model_path = self.model_path
else:
logger.warning(
"DeepSeek MTP does not require setting speculative_draft_model_path."
)
elif "Llama4" in model_arch:
# TODO: remove this after Llama4 supports in other backends
if self.attention_backend != "fa3":
self.attention_backend = "fa3"
logger.warning(
"Llama4 requires using fa3 attention backend. "
"Attention backend is automatically set to fa3."
)
# Auto choose parameters
if self.speculative_num_steps is None:
@@ -852,6 +862,18 @@ class ServerArgs:
"(1.0 = pure hybrid: swa_size / full_size = local_attention_size / context_length)"
),
)
parser.add_argument(
"--swa-full-tokens-ratio",
type=float,
default=ServerArgs.swa_full_tokens_ratio,
help="The ratio of SWA layer KV tokens / full layer KV tokens, regardless of the number of swa:full layers. It should be between 0 and 1. "
"E.g. 0.5 means if each swa layer has 50 tokens, then each full layer has 100 tokens.",
)
parser.add_argument(
"--disable-hybrid-swa-memory",
action="store_true",
help="Disable the hybrid SWA memory.",
)
# Other runtime options
parser.add_argument(
@@ -1730,10 +1752,6 @@ class ServerArgs:
else:
self.lora_paths[lora_path] = lora_path
model_arch = get_model_arch(self)
if "Llama4" in model_arch and self.hybrid_kvcache_ratio is not None:
assert self.attention_backend == "fa3"
def prepare_server_args(argv: List[str]) -> ServerArgs:
"""