Revert "Fix: Dynamic RoPE Cache Expansion to Prevent Position-ID Out-of-Bounds in EAGLE + Long-Sequence Workloads" (#11827)
This commit is contained in:
@@ -3460,61 +3460,3 @@ def cached_triton_kernel(key_fn=None):
|
||||
return CachedKernel(fn, key_fn)
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
def reserve_rope_cache_for_long_sequences(
    model, server_args, model_config, req_to_token_pool=None, logger=None
):
    """Pre-expand RoPE cache for long sequences and speculative decoding.

    The reserved length is computed as
    ``capacity + steps * draft * safety_factor + margin`` and then rounded up
    to a multiple of the configured alignment. Every sub-module of ``model``
    exposing both ``_ensure_cos_sin_cache_length`` and ``cos_sin_cache`` is
    asked to grow its cache to that length.

    Args:
        model: Root module; walked through ``children()``.
        server_args: Read for ``context_length``, ``speculative_num_steps``
            and ``speculative_num_draft_tokens`` (all optional).
        model_config: Read for ``context_len``, ``max_model_len`` and
            ``hf_text_config.max_position_embeddings`` (all optional).
        req_to_token_pool: Optional object whose ``max_context_len``, when
            truthy, overrides the base context length as the capacity.
        logger: Logger used for progress messages; defaults to this module's
            logger.

    Returns:
        None. Matching sub-modules are updated in place.
    """
    import math

    from sglang.srt.environ import envs

    if logger is None:
        import logging

        logger = logging.getLogger(__name__)

    safety_factor = envs.SGLANG_SPEC_EXPANSION_SAFETY_FACTOR.value
    margin = envs.SGLANG_ROPE_CACHE_SAFETY_MARGIN.value
    align = envs.SGLANG_ROPE_CACHE_ALIGN.value

    # 1) Estimate base context upper bound.
    # hf_text_config is looked up defensively so that a config object lacking
    # it falls through to the 2048 default instead of raising AttributeError.
    hf_text_config = getattr(model_config, "hf_text_config", None)
    base_ctx = (
        getattr(server_args, "context_length", None)
        or getattr(model_config, "context_len", None)
        or getattr(model_config, "max_model_len", None)
        or getattr(hf_text_config, "max_position_embeddings", None)
        or 2048
    )

    # 2) Runtime input capacity (including extra_len from req_to_token_pool)
    inferred_cap = getattr(req_to_token_pool, "max_context_len", None) or base_ctx

    # 3) Speculative decoding expansion
    steps = int(getattr(server_args, "speculative_num_steps", 0) or 0)
    draft = int(getattr(server_args, "speculative_num_draft_tokens", 0) or 0)
    # math.ceil keeps the length an integer even if a fractional safety
    # factor or margin is configured.
    reserve = math.ceil(inferred_cap + steps * draft * safety_factor + margin)

    # 4) Align to reduce reallocation frequency.
    # A non-positive alignment is treated as "no alignment" rather than
    # dividing by zero.
    align = int(align) if align else 0
    if align > 0:
        reserve = (reserve + align - 1) // align * align

    logger.info(
        f"RoPE cache reserve={reserve} (base={base_ctx}, cap={inferred_cap}, steps={steps}, draft={draft}, k={safety_factor}, margin={margin})"
    )

    # Recursively expand all RoPE layers
    def reserve_rope_cache_recursive(module):
        for child in module.children():
            if hasattr(child, "_ensure_cos_sin_cache_length") and hasattr(
                child, "cos_sin_cache"
            ):
                old_len = child.cos_sin_cache.shape[0]
                # reserve - 1 is presumably the highest position index that
                # must be addressable -- confirm against the helper.
                child._ensure_cos_sin_cache_length(reserve - 1)
                new_len = child.cos_sin_cache.shape[0]
                if new_len > old_len:
                    logger.info(
                        f"Expanded RoPE cache from {old_len} to {new_len} positions"
                    )
            else:
                # Only modules that are not RoPE layers are descended into.
                reserve_rope_cache_recursive(child)

    reserve_rope_cache_recursive(model)
|
||||
|
||||
Reference in New Issue
Block a user