SWA Prefix Cache (#7367)
Co-authored-by: Ying Sheng <sqy1415@gmail.com>
This commit is contained in:
@@ -63,6 +63,7 @@ class ServerArgs:
|
||||
enable_multimodal: Optional[bool] = None
|
||||
revision: Optional[str] = None
|
||||
hybrid_kvcache_ratio: Optional[float] = None
|
||||
swa_full_tokens_ratio: float = 0.8
|
||||
impl: str = "auto"
|
||||
|
||||
# Port for the HTTP server
|
||||
@@ -225,6 +226,7 @@ class ServerArgs:
|
||||
enable_return_hidden_states: bool = False
|
||||
enable_triton_kernel_moe: bool = False
|
||||
warmups: Optional[str] = None
|
||||
disable_hybrid_swa_memory: bool = False
|
||||
|
||||
# Debug tensor dumps
|
||||
debug_tensor_dump_output_folder: Optional[str] = None
|
||||
@@ -481,14 +483,22 @@ class ServerArgs:
|
||||
|
||||
model_arch = get_model_arch(self)
|
||||
|
||||
# Auto set draft_model_path DeepSeek-V3/R1
|
||||
if model_arch == "DeepseekV3ForCausalLM":
|
||||
# Auto set draft_model_path DeepSeek-V3/R1
|
||||
if self.speculative_draft_model_path is None:
|
||||
self.speculative_draft_model_path = self.model_path
|
||||
else:
|
||||
logger.warning(
|
||||
"DeepSeek MTP does not require setting speculative_draft_model_path."
|
||||
)
|
||||
elif "Llama4" in model_arch:
|
||||
# TODO: remove this after Llama4 supports in other backends
|
||||
if self.attention_backend != "fa3":
|
||||
self.attention_backend = "fa3"
|
||||
logger.warning(
|
||||
"Llama4 requires using fa3 attention backend. "
|
||||
"Attention backend is automatically set to fa3."
|
||||
)
|
||||
|
||||
# Auto choose parameters
|
||||
if self.speculative_num_steps is None:
|
||||
@@ -852,6 +862,18 @@ class ServerArgs:
|
||||
"(1.0 = pure hybrid: swa_size / full_size = local_attention_size / context_length)"
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--swa-full-tokens-ratio",
|
||||
type=float,
|
||||
default=ServerArgs.swa_full_tokens_ratio,
|
||||
help="The ratio of SWA layer KV tokens / full layer KV tokens, regardless of the number of swa:full layers. It should be between 0 and 1. "
|
||||
"E.g. 0.5 means if each swa layer has 50 tokens, then each full layer has 100 tokens.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--disable-hybrid-swa-memory",
|
||||
action="store_true",
|
||||
help="Disable the hybrid SWA memory.",
|
||||
)
|
||||
|
||||
# Other runtime options
|
||||
parser.add_argument(
|
||||
@@ -1730,10 +1752,6 @@ class ServerArgs:
|
||||
else:
|
||||
self.lora_paths[lora_path] = lora_path
|
||||
|
||||
model_arch = get_model_arch(self)
|
||||
if "Llama4" in model_arch and self.hybrid_kvcache_ratio is not None:
|
||||
assert self.attention_backend == "fa3"
|
||||
|
||||
|
||||
def prepare_server_args(argv: List[str]) -> ServerArgs:
|
||||
"""
|
||||
|
||||
Reference in New Issue
Block a user