Fix: Dynamic RoPE Cache Expansion to Prevent Position-ID Out-of-Bounds in EAGLE + Long-Sequence Workloads (#10788)
This commit is contained in:
@@ -222,6 +222,11 @@ class Envs:
|
||||
SGLANG_TRITON_PREFILL_TRUNCATION_ALIGN_SIZE = EnvInt(4096)
|
||||
SGLANG_TRITON_DECODE_SPLIT_TILE_SIZE = EnvInt(256)
|
||||
|
||||
# RoPE cache configuration
|
||||
SGLANG_SPEC_EXPANSION_SAFETY_FACTOR = EnvInt(2)
|
||||
SGLANG_ROPE_CACHE_SAFETY_MARGIN = EnvInt(256)
|
||||
SGLANG_ROPE_CACHE_ALIGN = EnvInt(128)
|
||||
|
||||
# Overlap Spec V2
|
||||
SGLANG_ENABLE_OVERLAP_PLAN_STREAM = EnvBool(False)
|
||||
|
||||
|
||||
@@ -147,6 +147,36 @@ class RotaryEmbedding(CustomOp):
|
||||
cache = torch.cat((cos, sin), dim=-1)
|
||||
return cache
|
||||
|
||||
def _ensure_cos_sin_cache_length(self, needed_max_pos: int):
|
||||
"""Ensure cos_sin_cache length > needed_max_pos."""
|
||||
cur_len = int(self.cos_sin_cache.shape[0])
|
||||
if needed_max_pos < cur_len:
|
||||
return
|
||||
|
||||
# Align to 128 to reduce realloc frequency
|
||||
new_len = ((needed_max_pos + 128) // 128) * 128
|
||||
device = self.cos_sin_cache.device
|
||||
dtype = self.cos_sin_cache.dtype
|
||||
|
||||
# Compute inv_freq on same device
|
||||
inv_freq = self._compute_inv_freq(self.base).to(device=device)
|
||||
|
||||
# Incremental computation for new positions only
|
||||
start = cur_len
|
||||
t_new = torch.arange(start, new_len, dtype=inv_freq.dtype, device=device)
|
||||
if t_new.numel() == 0:
|
||||
return
|
||||
|
||||
freqs_new = torch.einsum("i,j->ij", t_new, inv_freq)
|
||||
cos_new = freqs_new.cos()
|
||||
sin_new = freqs_new.sin()
|
||||
new_rows = torch.cat((cos_new, sin_new), dim=-1).to(dtype=dtype)
|
||||
|
||||
# Update cache with new rows
|
||||
self.cos_sin_cache = torch.cat((self.cos_sin_cache, new_rows), dim=0).to(
|
||||
device=device, dtype=dtype
|
||||
)
|
||||
|
||||
def forward_native(
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
|
||||
@@ -140,6 +140,7 @@ from sglang.srt.utils import (
|
||||
log_info_on_rank0,
|
||||
monkey_patch_p2p_access_check,
|
||||
monkey_patch_vllm_gguf_config,
|
||||
reserve_rope_cache_for_long_sequences,
|
||||
set_cuda_arch,
|
||||
slow_rank_detector,
|
||||
)
|
||||
@@ -898,6 +899,15 @@ class ModelRunner:
|
||||
f"mem usage={self.weight_load_mem_usage:.2f} GB."
|
||||
)
|
||||
|
||||
# Pre-expand RoPE cache before CUDA Graph capture
|
||||
reserve_rope_cache_for_long_sequences(
|
||||
self.model,
|
||||
self.server_args,
|
||||
self.model_config,
|
||||
self.req_to_token_pool,
|
||||
logger,
|
||||
)
|
||||
|
||||
if self.server_args.elastic_ep_backend == "mooncake":
|
||||
# Mooncake does not support `monitored_barrier`
|
||||
dist.barrier(group=get_tp_group().cpu_group)
|
||||
|
||||
@@ -3460,3 +3460,61 @@ def cached_triton_kernel(key_fn=None):
|
||||
return CachedKernel(fn, key_fn)
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
def reserve_rope_cache_for_long_sequences(
|
||||
model, server_args, model_config, req_to_token_pool=None, logger=None
|
||||
):
|
||||
"""Pre-expand RoPE cache for long sequences and speculative decoding."""
|
||||
from sglang.srt.environ import envs
|
||||
|
||||
if logger is None:
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
SAFETY_FACTOR = envs.SGLANG_SPEC_EXPANSION_SAFETY_FACTOR.value
|
||||
MARGIN = envs.SGLANG_ROPE_CACHE_SAFETY_MARGIN.value
|
||||
ALIGN = envs.SGLANG_ROPE_CACHE_ALIGN.value
|
||||
|
||||
# 1) Estimate base context upper bound
|
||||
base_ctx = (
|
||||
getattr(server_args, "context_length", None)
|
||||
or getattr(model_config, "context_len", None)
|
||||
or getattr(model_config, "max_model_len", None)
|
||||
or getattr(model_config.hf_text_config, "max_position_embeddings", None)
|
||||
or 2048
|
||||
)
|
||||
|
||||
# 2) Runtime input capacity (including extra_len from req_to_token_pool)
|
||||
inferred_cap = getattr(req_to_token_pool, "max_context_len", None) or base_ctx
|
||||
|
||||
# 3) Speculative decoding expansion
|
||||
steps = int(getattr(server_args, "speculative_num_steps", 0) or 0)
|
||||
draft = int(getattr(server_args, "speculative_num_draft_tokens", 0) or 0)
|
||||
reserve = inferred_cap + steps * draft * SAFETY_FACTOR + MARGIN
|
||||
|
||||
# 4) Align to reduce reallocation frequency
|
||||
reserve = (reserve + ALIGN - 1) // ALIGN * ALIGN
|
||||
|
||||
logger.info(
|
||||
f"RoPE cache reserve={reserve} (base={base_ctx}, cap={inferred_cap}, steps={steps}, draft={draft}, k={SAFETY_FACTOR}, margin={MARGIN})"
|
||||
)
|
||||
|
||||
# Recursively expand all RoPE layers
|
||||
def reserve_rope_cache_recursive(module):
|
||||
for child in module.children():
|
||||
if hasattr(child, "_ensure_cos_sin_cache_length") and hasattr(
|
||||
child, "cos_sin_cache"
|
||||
):
|
||||
old_len = child.cos_sin_cache.shape[0]
|
||||
child._ensure_cos_sin_cache_length(reserve - 1)
|
||||
new_len = child.cos_sin_cache.shape[0]
|
||||
if new_len > old_len:
|
||||
logger.info(
|
||||
f"Expanded RoPE cache from {old_len} to {new_len} positions"
|
||||
)
|
||||
else:
|
||||
reserve_rope_cache_recursive(child)
|
||||
|
||||
reserve_rope_cache_recursive(model)
|
||||
|
||||
Reference in New Issue
Block a user