Fix: Dynamic RoPE Cache Expansion to Prevent Position-ID Out-of-Bounds in EAGLE + Long-Sequence Workloads (#10788)

This commit is contained in:
YAMY
2025-10-18 20:37:43 -07:00
committed by GitHub
parent b288f4f440
commit 80407b0493
4 changed files with 103 additions and 0 deletions

View File

@@ -3460,3 +3460,61 @@ def cached_triton_kernel(key_fn=None):
return CachedKernel(fn, key_fn)
return decorator
def reserve_rope_cache_for_long_sequences(
model, server_args, model_config, req_to_token_pool=None, logger=None
):
"""Pre-expand RoPE cache for long sequences and speculative decoding."""
from sglang.srt.environ import envs
if logger is None:
import logging
logger = logging.getLogger(__name__)
SAFETY_FACTOR = envs.SGLANG_SPEC_EXPANSION_SAFETY_FACTOR.value
MARGIN = envs.SGLANG_ROPE_CACHE_SAFETY_MARGIN.value
ALIGN = envs.SGLANG_ROPE_CACHE_ALIGN.value
# 1) Estimate base context upper bound
base_ctx = (
getattr(server_args, "context_length", None)
or getattr(model_config, "context_len", None)
or getattr(model_config, "max_model_len", None)
or getattr(model_config.hf_text_config, "max_position_embeddings", None)
or 2048
)
# 2) Runtime input capacity (including extra_len from req_to_token_pool)
inferred_cap = getattr(req_to_token_pool, "max_context_len", None) or base_ctx
# 3) Speculative decoding expansion
steps = int(getattr(server_args, "speculative_num_steps", 0) or 0)
draft = int(getattr(server_args, "speculative_num_draft_tokens", 0) or 0)
reserve = inferred_cap + steps * draft * SAFETY_FACTOR + MARGIN
# 4) Align to reduce reallocation frequency
reserve = (reserve + ALIGN - 1) // ALIGN * ALIGN
logger.info(
f"RoPE cache reserve={reserve} (base={base_ctx}, cap={inferred_cap}, steps={steps}, draft={draft}, k={SAFETY_FACTOR}, margin={MARGIN})"
)
# Recursively expand all RoPE layers
def reserve_rope_cache_recursive(module):
for child in module.children():
if hasattr(child, "_ensure_cos_sin_cache_length") and hasattr(
child, "cos_sin_cache"
):
old_len = child.cos_sin_cache.shape[0]
child._ensure_cos_sin_cache_length(reserve - 1)
new_len = child.cos_sin_cache.shape[0]
if new_len > old_len:
logger.info(
f"Expanded RoPE cache from {old_len} to {new_len} positions"
)
else:
reserve_rope_cache_recursive(child)
reserve_rope_cache_recursive(model)