Revert "Fix: Dynamic RoPE Cache Expansion to Prevent Position-ID Out-of-Bounds in EAGLE + Long-Sequence Workloads" (#11827)
This commit is contained in:
@@ -222,11 +222,6 @@ class Envs:
|
|||||||
SGLANG_TRITON_PREFILL_TRUNCATION_ALIGN_SIZE = EnvInt(4096)
|
SGLANG_TRITON_PREFILL_TRUNCATION_ALIGN_SIZE = EnvInt(4096)
|
||||||
SGLANG_TRITON_DECODE_SPLIT_TILE_SIZE = EnvInt(256)
|
SGLANG_TRITON_DECODE_SPLIT_TILE_SIZE = EnvInt(256)
|
||||||
|
|
||||||
# RoPE cache configuration
|
|
||||||
SGLANG_SPEC_EXPANSION_SAFETY_FACTOR = EnvInt(2)
|
|
||||||
SGLANG_ROPE_CACHE_SAFETY_MARGIN = EnvInt(256)
|
|
||||||
SGLANG_ROPE_CACHE_ALIGN = EnvInt(128)
|
|
||||||
|
|
||||||
# Overlap Spec V2
|
# Overlap Spec V2
|
||||||
SGLANG_ENABLE_OVERLAP_PLAN_STREAM = EnvBool(False)
|
SGLANG_ENABLE_OVERLAP_PLAN_STREAM = EnvBool(False)
|
||||||
|
|
||||||
|
|||||||
@@ -147,36 +147,6 @@ class RotaryEmbedding(CustomOp):
|
|||||||
cache = torch.cat((cos, sin), dim=-1)
|
cache = torch.cat((cos, sin), dim=-1)
|
||||||
return cache
|
return cache
|
||||||
|
|
||||||
def _ensure_cos_sin_cache_length(
    self, needed_max_pos: int, *, align: int = 128
) -> None:
    """Grow ``self.cos_sin_cache`` until row ``needed_max_pos`` exists.

    If the cache already has more than ``needed_max_pos`` rows this is a
    no-op. Otherwise only the missing rows are computed and appended, and
    the resulting length is the smallest multiple of ``align`` that is
    strictly greater than ``needed_max_pos``.

    Args:
        needed_max_pos: Largest position index that must be addressable.
        align: Granularity of the new cache length. Growing in aligned
            steps reduces how often the cache has to be reallocated.

    Raises:
        ValueError: If ``align`` is not a positive integer.

    NOTE(review): new rows are derived from ``self._compute_inv_freq(self.base)``
    applied to plain integer positions; confirm this matches how the
    existing rows were built for any subclass that rescales frequencies.
    """
    if align <= 0:
        raise ValueError(f"align must be a positive integer, got {align}")

    cur_len = int(self.cos_sin_cache.shape[0])
    if needed_max_pos < cur_len:
        return

    # Smallest multiple of `align` strictly above needed_max_pos. Past the
    # early return above this is always > cur_len, so there is at least one
    # new row to compute.
    new_len = (needed_max_pos // align + 1) * align
    device = self.cos_sin_cache.device
    dtype = self.cos_sin_cache.dtype

    # Keep inv_freq on the cache's device so the new rows are built there.
    inv_freq = self._compute_inv_freq(self.base).to(device=device)

    # Incremental computation: only positions [cur_len, new_len).
    t_new = torch.arange(cur_len, new_len, dtype=inv_freq.dtype, device=device)
    freqs_new = torch.einsum("i,j->ij", t_new, inv_freq)
    new_rows = torch.cat((freqs_new.cos(), freqs_new.sin()), dim=-1).to(dtype=dtype)

    # Both operands already share device and dtype, so no further cast is needed.
    self.cos_sin_cache = torch.cat((self.cos_sin_cache, new_rows), dim=0)
|
|
||||||
|
|
||||||
def forward_native(
|
def forward_native(
|
||||||
self,
|
self,
|
||||||
positions: torch.Tensor,
|
positions: torch.Tensor,
|
||||||
|
|||||||
@@ -140,7 +140,6 @@ from sglang.srt.utils import (
|
|||||||
log_info_on_rank0,
|
log_info_on_rank0,
|
||||||
monkey_patch_p2p_access_check,
|
monkey_patch_p2p_access_check,
|
||||||
monkey_patch_vllm_gguf_config,
|
monkey_patch_vllm_gguf_config,
|
||||||
reserve_rope_cache_for_long_sequences,
|
|
||||||
set_cuda_arch,
|
set_cuda_arch,
|
||||||
slow_rank_detector,
|
slow_rank_detector,
|
||||||
)
|
)
|
||||||
@@ -899,15 +898,6 @@ class ModelRunner:
|
|||||||
f"mem usage={self.weight_load_mem_usage:.2f} GB."
|
f"mem usage={self.weight_load_mem_usage:.2f} GB."
|
||||||
)
|
)
|
||||||
|
|
||||||
# Pre-expand RoPE cache before CUDA Graph capture
|
|
||||||
reserve_rope_cache_for_long_sequences(
|
|
||||||
self.model,
|
|
||||||
self.server_args,
|
|
||||||
self.model_config,
|
|
||||||
self.req_to_token_pool,
|
|
||||||
logger,
|
|
||||||
)
|
|
||||||
|
|
||||||
if self.server_args.elastic_ep_backend == "mooncake":
|
if self.server_args.elastic_ep_backend == "mooncake":
|
||||||
# Mooncake does not support `monitored_barrier`
|
# Mooncake does not support `monitored_barrier`
|
||||||
dist.barrier(group=get_tp_group().cpu_group)
|
dist.barrier(group=get_tp_group().cpu_group)
|
||||||
|
|||||||
@@ -3460,61 +3460,3 @@ def cached_triton_kernel(key_fn=None):
|
|||||||
return CachedKernel(fn, key_fn)
|
return CachedKernel(fn, key_fn)
|
||||||
|
|
||||||
return decorator
|
return decorator
|
||||||
|
|
||||||
|
|
||||||
def reserve_rope_cache_for_long_sequences(
    model, server_args, model_config, req_to_token_pool=None, logger=None
):
    """Pre-expand RoPE caches for long sequences and speculative decoding.

    Computes an upper bound on the position index that may be requested and
    asks every module in ``model`` that exposes both
    ``_ensure_cos_sin_cache_length`` and ``cos_sin_cache`` to grow its cache
    to at least that many rows.

    Args:
        model: Root module whose module tree is searched for RoPE layers.
        server_args: Read for ``context_length``, ``speculative_num_steps``
            and ``speculative_num_draft_tokens`` (all optional).
        model_config: Read for ``context_len``, ``max_model_len`` and
            ``hf_text_config.max_position_embeddings`` (all optional).
        req_to_token_pool: Optional; its ``max_context_len`` overrides the
            base context estimate when present and truthy.
        logger: Logger for progress messages; defaults to this module's
            logger.
    """
    from sglang.srt.environ import envs

    if logger is None:
        import logging

        logger = logging.getLogger(__name__)

    SAFETY_FACTOR = envs.SGLANG_SPEC_EXPANSION_SAFETY_FACTOR.value
    MARGIN = envs.SGLANG_ROPE_CACHE_SAFETY_MARGIN.value
    # A zero or negative alignment would divide by zero (or round the wrong
    # way) below; fall back to "no alignment" instead.
    ALIGN = max(1, envs.SGLANG_ROPE_CACHE_ALIGN.value)

    # 1) Estimate base context upper bound. hf_text_config is looked up
    #    defensively so a config object without it falls through to 2048
    #    instead of raising AttributeError.
    hf_text_config = getattr(model_config, "hf_text_config", None)
    base_ctx = (
        getattr(server_args, "context_length", None)
        or getattr(model_config, "context_len", None)
        or getattr(model_config, "max_model_len", None)
        or getattr(hf_text_config, "max_position_embeddings", None)
        or 2048
    )

    # 2) Runtime input capacity (including extra_len from req_to_token_pool)
    inferred_cap = getattr(req_to_token_pool, "max_context_len", None) or base_ctx

    # 3) Speculative decoding expansion
    steps = int(getattr(server_args, "speculative_num_steps", 0) or 0)
    draft = int(getattr(server_args, "speculative_num_draft_tokens", 0) or 0)
    reserve = inferred_cap + steps * draft * SAFETY_FACTOR + MARGIN

    # 4) Round up to a multiple of ALIGN to reduce reallocation frequency
    reserve = (reserve + ALIGN - 1) // ALIGN * ALIGN

    logger.info(
        f"RoPE cache reserve={reserve} (base={base_ctx}, cap={inferred_cap}, steps={steps}, draft={draft}, k={SAFETY_FACTOR}, margin={MARGIN})"
    )

    # Visit the root module and every descendant; modules() yields each
    # distinct module once, so a shared RoPE layer is only expanded once.
    for module in model.modules():
        if not (
            hasattr(module, "_ensure_cos_sin_cache_length")
            and hasattr(module, "cos_sin_cache")
        ):
            continue
        old_len = module.cos_sin_cache.shape[0]
        # The callee guarantees length > argument, i.e. length >= reserve.
        module._ensure_cos_sin_cache_length(reserve - 1)
        new_len = module.cos_sin_cache.shape[0]
        if new_len > old_len:
            logger.info(f"Expanded RoPE cache from {old_len} to {new_len} positions")
|
|
||||||
|
|||||||
Reference in New Issue
Block a user