From 80407b0493846dc498a51ae110ced4f53344bb37 Mon Sep 17 00:00:00 2001
From: YAMY <74099316+YAMY1234@users.noreply.github.com>
Date: Sat, 18 Oct 2025 20:37:43 -0700
Subject: [PATCH] Fix: Dynamic RoPE Cache Expansion to Prevent Position-ID
 Out-of-Bounds in EAGLE + Long-Sequence Workloads (#10788)

---
 python/sglang/srt/environ.py                   |  5 ++
 python/sglang/srt/layers/rotary_embedding.py   | 30 ++++++++++
 .../sglang/srt/model_executor/model_runner.py  | 10 ++++
 python/sglang/srt/utils/common.py              | 58 +++++++++++++++++++
 4 files changed, 103 insertions(+)

diff --git a/python/sglang/srt/environ.py b/python/sglang/srt/environ.py
index acc8b0e68..be9eba0eb 100644
--- a/python/sglang/srt/environ.py
+++ b/python/sglang/srt/environ.py
@@ -222,6 +222,11 @@ class Envs:
     SGLANG_TRITON_PREFILL_TRUNCATION_ALIGN_SIZE = EnvInt(4096)
     SGLANG_TRITON_DECODE_SPLIT_TILE_SIZE = EnvInt(256)
 
+    # RoPE cache configuration
+    SGLANG_SPEC_EXPANSION_SAFETY_FACTOR = EnvInt(2)
+    SGLANG_ROPE_CACHE_SAFETY_MARGIN = EnvInt(256)
+    SGLANG_ROPE_CACHE_ALIGN = EnvInt(128)
+
     # Overlap Spec V2
     SGLANG_ENABLE_OVERLAP_PLAN_STREAM = EnvBool(False)
 
diff --git a/python/sglang/srt/layers/rotary_embedding.py b/python/sglang/srt/layers/rotary_embedding.py
index e4ca62c86..ee6fe199c 100644
--- a/python/sglang/srt/layers/rotary_embedding.py
+++ b/python/sglang/srt/layers/rotary_embedding.py
@@ -147,6 +147,36 @@ class RotaryEmbedding(CustomOp):
         cache = torch.cat((cos, sin), dim=-1)
         return cache
 
+    def _ensure_cos_sin_cache_length(self, needed_max_pos: int):
+        """Ensure cos_sin_cache length > needed_max_pos."""
+        cur_len = int(self.cos_sin_cache.shape[0])
+        if needed_max_pos < cur_len:
+            return
+
+        # Align to 128 to reduce realloc frequency
+        new_len = ((needed_max_pos + 128) // 128) * 128
+        device = self.cos_sin_cache.device
+        dtype = self.cos_sin_cache.dtype
+
+        # Compute inv_freq on same device
+        inv_freq = self._compute_inv_freq(self.base).to(device=device)
+
+        # Incremental computation for new positions only
+        start = cur_len
+        t_new = torch.arange(start, new_len, dtype=inv_freq.dtype, device=device)
+        if t_new.numel() == 0:
+            return
+
+        freqs_new = torch.einsum("i,j->ij", t_new, inv_freq)
+        cos_new = freqs_new.cos()
+        sin_new = freqs_new.sin()
+        new_rows = torch.cat((cos_new, sin_new), dim=-1).to(dtype=dtype)
+
+        # Update cache with new rows
+        self.cos_sin_cache = torch.cat((self.cos_sin_cache, new_rows), dim=0).to(
+            device=device, dtype=dtype
+        )
+
     def forward_native(
         self,
         positions: torch.Tensor,
diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py
index 6fce4cda4..7619faeb2 100644
--- a/python/sglang/srt/model_executor/model_runner.py
+++ b/python/sglang/srt/model_executor/model_runner.py
@@ -140,6 +140,7 @@ from sglang.srt.utils import (
     log_info_on_rank0,
     monkey_patch_p2p_access_check,
     monkey_patch_vllm_gguf_config,
+    reserve_rope_cache_for_long_sequences,
     set_cuda_arch,
     slow_rank_detector,
 )
@@ -898,6 +899,15 @@ class ModelRunner:
             f"mem usage={self.weight_load_mem_usage:.2f} GB."
         )
 
+        # Pre-expand RoPE cache before CUDA Graph capture
+        reserve_rope_cache_for_long_sequences(
+            self.model,
+            self.server_args,
+            self.model_config,
+            self.req_to_token_pool,
+            logger,
+        )
+
         if self.server_args.elastic_ep_backend == "mooncake":
             # Mooncake does not support `monitored_barrier`
             dist.barrier(group=get_tp_group().cpu_group)
diff --git a/python/sglang/srt/utils/common.py b/python/sglang/srt/utils/common.py
index 3436b2682..7d0be4e7d 100644
--- a/python/sglang/srt/utils/common.py
+++ b/python/sglang/srt/utils/common.py
@@ -3460,3 +3460,61 @@ def cached_triton_kernel(key_fn=None):
         return CachedKernel(fn, key_fn)
 
     return decorator
+
+
+def reserve_rope_cache_for_long_sequences(
+    model, server_args, model_config, req_to_token_pool=None, logger=None
+):
+    """Pre-expand RoPE cache for long sequences and speculative decoding."""
+    from sglang.srt.environ import envs
+
+    if logger is None:
+        import logging
+
+        logger = logging.getLogger(__name__)
+
+    SAFETY_FACTOR = envs.SGLANG_SPEC_EXPANSION_SAFETY_FACTOR.value
+    MARGIN = envs.SGLANG_ROPE_CACHE_SAFETY_MARGIN.value
+    ALIGN = envs.SGLANG_ROPE_CACHE_ALIGN.value
+
+    # 1) Estimate base context upper bound
+    base_ctx = (
+        getattr(server_args, "context_length", None)
+        or getattr(model_config, "context_len", None)
+        or getattr(model_config, "max_model_len", None)
+        or getattr(model_config.hf_text_config, "max_position_embeddings", None)
+        or 2048
+    )
+
+    # 2) Runtime input capacity (including extra_len from req_to_token_pool)
+    inferred_cap = getattr(req_to_token_pool, "max_context_len", None) or base_ctx
+
+    # 3) Speculative decoding expansion
+    steps = int(getattr(server_args, "speculative_num_steps", 0) or 0)
+    draft = int(getattr(server_args, "speculative_num_draft_tokens", 0) or 0)
+    reserve = inferred_cap + steps * draft * SAFETY_FACTOR + MARGIN
+
+    # 4) Align to reduce reallocation frequency
+    reserve = (reserve + ALIGN - 1) // ALIGN * ALIGN
+
+    logger.info(
+        f"RoPE cache reserve={reserve} (base={base_ctx}, cap={inferred_cap}, steps={steps}, draft={draft}, k={SAFETY_FACTOR}, margin={MARGIN})"
+    )
+
+    # Recursively expand all RoPE layers
+    def reserve_rope_cache_recursive(module):
+        for child in module.children():
+            if hasattr(child, "_ensure_cos_sin_cache_length") and hasattr(
+                child, "cos_sin_cache"
+            ):
+                old_len = child.cos_sin_cache.shape[0]
+                child._ensure_cos_sin_cache_length(reserve - 1)
+                new_len = child.cos_sin_cache.shape[0]
+                if new_len > old_len:
+                    logger.info(
+                        f"Expanded RoPE cache from {old_len} to {new_len} positions"
+                    )
+            else:
+                reserve_rope_cache_recursive(child)
+
+    reserve_rope_cache_recursive(model)
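
Notes on the sizing logic: the reserved length is inferred_cap + speculative_num_steps * speculative_num_draft_tokens * SAFETY_FACTOR + MARGIN, rounded up to a multiple of ALIGN. A minimal sketch of that arithmetic, assuming the default env values introduced by this patch (factor 2, margin 256, align 128) and an illustrative 32768-token context with 3 speculative steps and 4 draft tokens:

# Mirrors the formula in reserve_rope_cache_for_long_sequences; the concrete
# numbers (32768 tokens, 3 steps, 4 draft tokens) are illustrative assumptions.
def compute_reserve(inferred_cap, steps, draft, safety_factor=2, margin=256, align=128):
    reserve = inferred_cap + steps * draft * safety_factor + margin
    return (reserve + align - 1) // align * align  # round up to a multiple of align

print(compute_reserve(32768, steps=3, draft=4))  # 33152 (= 33048 rounded up to 128)

With speculation disabled (steps = draft = 0), the reserve degenerates to the runtime context capacity plus the safety margin, aligned.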
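
The expansion path in _ensure_cos_sin_cache_length recomputes cos/sin only for the missing positions and concatenates the new rows onto the existing cache, so the expanded cache should match one built from scratch at the larger length. A self-contained sketch of that equivalence using plain torch (the rotary dimension, base, and lengths below are illustrative, not values from the patch):

import torch

def build_cos_sin_cache(max_pos, rotary_dim, base=10000.0):
    # Full recomputation: row i is concat(cos(i * inv_freq), sin(i * inv_freq)).
    inv_freq = 1.0 / (base ** (torch.arange(0, rotary_dim, 2, dtype=torch.float32) / rotary_dim))
    t = torch.arange(max_pos, dtype=torch.float32)
    freqs = torch.einsum("i,j->ij", t, inv_freq)
    return torch.cat((freqs.cos(), freqs.sin()), dim=-1)

def extend_cos_sin_cache(cache, new_len, rotary_dim, base=10000.0):
    # Incremental path: compute rows only for positions [cur_len, new_len) and append.
    cur_len = cache.shape[0]
    inv_freq = 1.0 / (base ** (torch.arange(0, rotary_dim, 2, dtype=torch.float32) / rotary_dim))
    t_new = torch.arange(cur_len, new_len, dtype=torch.float32)
    freqs_new = torch.einsum("i,j->ij", t_new, inv_freq)
    new_rows = torch.cat((freqs_new.cos(), freqs_new.sin()), dim=-1)
    return torch.cat((cache, new_rows), dim=0)

rotary_dim = 128
small = build_cos_sin_cache(2048, rotary_dim)
extended = extend_cos_sin_cache(small, 4096, rotary_dim)
full = build_cos_sin_cache(4096, rotary_dim)
assert torch.allclose(extended, full)  # incremental expansion matches a full rebuild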
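
As a quick sanity check against the patched class itself, something along these lines could be used; the RotaryEmbedding constructor arguments are assumed from the existing sglang signature (head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype) and may need adjusting:

import torch
from sglang.srt.layers.rotary_embedding import RotaryEmbedding

# Assumed constructor signature; adjust if the class takes different arguments.
rope = RotaryEmbedding(
    head_size=128,
    rotary_dim=128,
    max_position_embeddings=2048,
    base=10000,
    is_neox_style=True,
    dtype=torch.float32,
)

old_len = rope.cos_sin_cache.shape[0]
rope._ensure_cos_sin_cache_length(4095)        # request coverage up to position 4095
assert rope.cos_sin_cache.shape[0] > 4095      # cache now covers the requested position
assert rope.cos_sin_cache.shape[0] % 128 == 0  # expansion is aligned to 128
assert rope.cos_sin_cache.shape[0] >= old_len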