Fix: Dynamic RoPE Cache Expansion to Prevent Position-ID Out-of-Bounds in EAGLE + Long-Sequence Workloads (#10788)
This commit is contained in:
@@ -222,6 +222,11 @@ class Envs:
|
|||||||
SGLANG_TRITON_PREFILL_TRUNCATION_ALIGN_SIZE = EnvInt(4096)
|
SGLANG_TRITON_PREFILL_TRUNCATION_ALIGN_SIZE = EnvInt(4096)
|
||||||
SGLANG_TRITON_DECODE_SPLIT_TILE_SIZE = EnvInt(256)
|
SGLANG_TRITON_DECODE_SPLIT_TILE_SIZE = EnvInt(256)
|
||||||
|
|
||||||
|
# RoPE cache configuration
|
||||||
|
SGLANG_SPEC_EXPANSION_SAFETY_FACTOR = EnvInt(2)
|
||||||
|
SGLANG_ROPE_CACHE_SAFETY_MARGIN = EnvInt(256)
|
||||||
|
SGLANG_ROPE_CACHE_ALIGN = EnvInt(128)
|
||||||
|
|
||||||
# Overlap Spec V2
|
# Overlap Spec V2
|
||||||
SGLANG_ENABLE_OVERLAP_PLAN_STREAM = EnvBool(False)
|
SGLANG_ENABLE_OVERLAP_PLAN_STREAM = EnvBool(False)
|
||||||
|
|
||||||
|
|||||||
@@ -147,6 +147,36 @@ class RotaryEmbedding(CustomOp):
|
|||||||
cache = torch.cat((cos, sin), dim=-1)
|
cache = torch.cat((cos, sin), dim=-1)
|
||||||
return cache
|
return cache
|
||||||
|
|
||||||
|
def _ensure_cos_sin_cache_length(self, needed_max_pos: int):
|
||||||
|
"""Ensure cos_sin_cache length > needed_max_pos."""
|
||||||
|
cur_len = int(self.cos_sin_cache.shape[0])
|
||||||
|
if needed_max_pos < cur_len:
|
||||||
|
return
|
||||||
|
|
||||||
|
# Align to 128 to reduce realloc frequency
|
||||||
|
new_len = ((needed_max_pos + 128) // 128) * 128
|
||||||
|
device = self.cos_sin_cache.device
|
||||||
|
dtype = self.cos_sin_cache.dtype
|
||||||
|
|
||||||
|
# Compute inv_freq on same device
|
||||||
|
inv_freq = self._compute_inv_freq(self.base).to(device=device)
|
||||||
|
|
||||||
|
# Incremental computation for new positions only
|
||||||
|
start = cur_len
|
||||||
|
t_new = torch.arange(start, new_len, dtype=inv_freq.dtype, device=device)
|
||||||
|
if t_new.numel() == 0:
|
||||||
|
return
|
||||||
|
|
||||||
|
freqs_new = torch.einsum("i,j->ij", t_new, inv_freq)
|
||||||
|
cos_new = freqs_new.cos()
|
||||||
|
sin_new = freqs_new.sin()
|
||||||
|
new_rows = torch.cat((cos_new, sin_new), dim=-1).to(dtype=dtype)
|
||||||
|
|
||||||
|
# Update cache with new rows
|
||||||
|
self.cos_sin_cache = torch.cat((self.cos_sin_cache, new_rows), dim=0).to(
|
||||||
|
device=device, dtype=dtype
|
||||||
|
)
|
||||||
|
|
||||||
def forward_native(
|
def forward_native(
|
||||||
self,
|
self,
|
||||||
positions: torch.Tensor,
|
positions: torch.Tensor,
|
||||||
|
|||||||
@@ -140,6 +140,7 @@ from sglang.srt.utils import (
|
|||||||
log_info_on_rank0,
|
log_info_on_rank0,
|
||||||
monkey_patch_p2p_access_check,
|
monkey_patch_p2p_access_check,
|
||||||
monkey_patch_vllm_gguf_config,
|
monkey_patch_vllm_gguf_config,
|
||||||
|
reserve_rope_cache_for_long_sequences,
|
||||||
set_cuda_arch,
|
set_cuda_arch,
|
||||||
slow_rank_detector,
|
slow_rank_detector,
|
||||||
)
|
)
|
||||||
@@ -898,6 +899,15 @@ class ModelRunner:
|
|||||||
f"mem usage={self.weight_load_mem_usage:.2f} GB."
|
f"mem usage={self.weight_load_mem_usage:.2f} GB."
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Pre-expand RoPE cache before CUDA Graph capture
|
||||||
|
reserve_rope_cache_for_long_sequences(
|
||||||
|
self.model,
|
||||||
|
self.server_args,
|
||||||
|
self.model_config,
|
||||||
|
self.req_to_token_pool,
|
||||||
|
logger,
|
||||||
|
)
|
||||||
|
|
||||||
if self.server_args.elastic_ep_backend == "mooncake":
|
if self.server_args.elastic_ep_backend == "mooncake":
|
||||||
# Mooncake does not support `monitored_barrier`
|
# Mooncake does not support `monitored_barrier`
|
||||||
dist.barrier(group=get_tp_group().cpu_group)
|
dist.barrier(group=get_tp_group().cpu_group)
|
||||||
|
|||||||
@@ -3460,3 +3460,61 @@ def cached_triton_kernel(key_fn=None):
|
|||||||
return CachedKernel(fn, key_fn)
|
return CachedKernel(fn, key_fn)
|
||||||
|
|
||||||
return decorator
|
return decorator
|
||||||
|
|
||||||
|
|
||||||
|
def reserve_rope_cache_for_long_sequences(
|
||||||
|
model, server_args, model_config, req_to_token_pool=None, logger=None
|
||||||
|
):
|
||||||
|
"""Pre-expand RoPE cache for long sequences and speculative decoding."""
|
||||||
|
from sglang.srt.environ import envs
|
||||||
|
|
||||||
|
if logger is None:
|
||||||
|
import logging
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
SAFETY_FACTOR = envs.SGLANG_SPEC_EXPANSION_SAFETY_FACTOR.value
|
||||||
|
MARGIN = envs.SGLANG_ROPE_CACHE_SAFETY_MARGIN.value
|
||||||
|
ALIGN = envs.SGLANG_ROPE_CACHE_ALIGN.value
|
||||||
|
|
||||||
|
# 1) Estimate base context upper bound
|
||||||
|
base_ctx = (
|
||||||
|
getattr(server_args, "context_length", None)
|
||||||
|
or getattr(model_config, "context_len", None)
|
||||||
|
or getattr(model_config, "max_model_len", None)
|
||||||
|
or getattr(model_config.hf_text_config, "max_position_embeddings", None)
|
||||||
|
or 2048
|
||||||
|
)
|
||||||
|
|
||||||
|
# 2) Runtime input capacity (including extra_len from req_to_token_pool)
|
||||||
|
inferred_cap = getattr(req_to_token_pool, "max_context_len", None) or base_ctx
|
||||||
|
|
||||||
|
# 3) Speculative decoding expansion
|
||||||
|
steps = int(getattr(server_args, "speculative_num_steps", 0) or 0)
|
||||||
|
draft = int(getattr(server_args, "speculative_num_draft_tokens", 0) or 0)
|
||||||
|
reserve = inferred_cap + steps * draft * SAFETY_FACTOR + MARGIN
|
||||||
|
|
||||||
|
# 4) Align to reduce reallocation frequency
|
||||||
|
reserve = (reserve + ALIGN - 1) // ALIGN * ALIGN
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"RoPE cache reserve={reserve} (base={base_ctx}, cap={inferred_cap}, steps={steps}, draft={draft}, k={SAFETY_FACTOR}, margin={MARGIN})"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Recursively expand all RoPE layers
|
||||||
|
def reserve_rope_cache_recursive(module):
|
||||||
|
for child in module.children():
|
||||||
|
if hasattr(child, "_ensure_cos_sin_cache_length") and hasattr(
|
||||||
|
child, "cos_sin_cache"
|
||||||
|
):
|
||||||
|
old_len = child.cos_sin_cache.shape[0]
|
||||||
|
child._ensure_cos_sin_cache_length(reserve - 1)
|
||||||
|
new_len = child.cos_sin_cache.shape[0]
|
||||||
|
if new_len > old_len:
|
||||||
|
logger.info(
|
||||||
|
f"Expanded RoPE cache from {old_len} to {new_len} positions"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
reserve_rope_cache_recursive(child)
|
||||||
|
|
||||||
|
reserve_rope_cache_recursive(model)
|
||||||
|
|||||||
Reference in New Issue
Block a user