diff --git a/python/sglang/srt/environ.py b/python/sglang/srt/environ.py index be9eba0eb..acc8b0e68 100644 --- a/python/sglang/srt/environ.py +++ b/python/sglang/srt/environ.py @@ -222,11 +222,6 @@ class Envs: SGLANG_TRITON_PREFILL_TRUNCATION_ALIGN_SIZE = EnvInt(4096) SGLANG_TRITON_DECODE_SPLIT_TILE_SIZE = EnvInt(256) - # RoPE cache configuration - SGLANG_SPEC_EXPANSION_SAFETY_FACTOR = EnvInt(2) - SGLANG_ROPE_CACHE_SAFETY_MARGIN = EnvInt(256) - SGLANG_ROPE_CACHE_ALIGN = EnvInt(128) - # Overlap Spec V2 SGLANG_ENABLE_OVERLAP_PLAN_STREAM = EnvBool(False) diff --git a/python/sglang/srt/layers/rotary_embedding.py b/python/sglang/srt/layers/rotary_embedding.py index ee6fe199c..e4ca62c86 100644 --- a/python/sglang/srt/layers/rotary_embedding.py +++ b/python/sglang/srt/layers/rotary_embedding.py @@ -147,36 +147,6 @@ class RotaryEmbedding(CustomOp): cache = torch.cat((cos, sin), dim=-1) return cache - def _ensure_cos_sin_cache_length(self, needed_max_pos: int): - """Ensure cos_sin_cache length > needed_max_pos.""" - cur_len = int(self.cos_sin_cache.shape[0]) - if needed_max_pos < cur_len: - return - - # Align to 128 to reduce realloc frequency - new_len = ((needed_max_pos + 128) // 128) * 128 - device = self.cos_sin_cache.device - dtype = self.cos_sin_cache.dtype - - # Compute inv_freq on same device - inv_freq = self._compute_inv_freq(self.base).to(device=device) - - # Incremental computation for new positions only - start = cur_len - t_new = torch.arange(start, new_len, dtype=inv_freq.dtype, device=device) - if t_new.numel() == 0: - return - - freqs_new = torch.einsum("i,j->ij", t_new, inv_freq) - cos_new = freqs_new.cos() - sin_new = freqs_new.sin() - new_rows = torch.cat((cos_new, sin_new), dim=-1).to(dtype=dtype) - - # Update cache with new rows - self.cos_sin_cache = torch.cat((self.cos_sin_cache, new_rows), dim=0).to( - device=device, dtype=dtype - ) - def forward_native( self, positions: torch.Tensor, diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index 7619faeb2..6fce4cda4 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -140,7 +140,6 @@ from sglang.srt.utils import ( log_info_on_rank0, monkey_patch_p2p_access_check, monkey_patch_vllm_gguf_config, - reserve_rope_cache_for_long_sequences, set_cuda_arch, slow_rank_detector, ) @@ -899,15 +898,6 @@ class ModelRunner: f"mem usage={self.weight_load_mem_usage:.2f} GB." ) - # Pre-expand RoPE cache before CUDA Graph capture - reserve_rope_cache_for_long_sequences( - self.model, - self.server_args, - self.model_config, - self.req_to_token_pool, - logger, - ) - if self.server_args.elastic_ep_backend == "mooncake": # Mooncake does not support `monitored_barrier` dist.barrier(group=get_tp_group().cpu_group) diff --git a/python/sglang/srt/utils/common.py b/python/sglang/srt/utils/common.py index 7d0be4e7d..3436b2682 100644 --- a/python/sglang/srt/utils/common.py +++ b/python/sglang/srt/utils/common.py @@ -3460,61 +3460,3 @@ def cached_triton_kernel(key_fn=None): return CachedKernel(fn, key_fn) return decorator - - -def reserve_rope_cache_for_long_sequences( - model, server_args, model_config, req_to_token_pool=None, logger=None -): - """Pre-expand RoPE cache for long sequences and speculative decoding.""" - from sglang.srt.environ import envs - - if logger is None: - import logging - - logger = logging.getLogger(__name__) - - SAFETY_FACTOR = envs.SGLANG_SPEC_EXPANSION_SAFETY_FACTOR.value - MARGIN = envs.SGLANG_ROPE_CACHE_SAFETY_MARGIN.value - ALIGN = envs.SGLANG_ROPE_CACHE_ALIGN.value - - # 1) Estimate base context upper bound - base_ctx = ( - getattr(server_args, "context_length", None) - or getattr(model_config, "context_len", None) - or getattr(model_config, "max_model_len", None) - or getattr(model_config.hf_text_config, "max_position_embeddings", None) - or 2048 - ) - - # 2) Runtime input capacity (including extra_len from req_to_token_pool) - inferred_cap = getattr(req_to_token_pool, "max_context_len", None) or base_ctx - - # 3) Speculative decoding expansion - steps = int(getattr(server_args, "speculative_num_steps", 0) or 0) - draft = int(getattr(server_args, "speculative_num_draft_tokens", 0) or 0) - reserve = inferred_cap + steps * draft * SAFETY_FACTOR + MARGIN - - # 4) Align to reduce reallocation frequency - reserve = (reserve + ALIGN - 1) // ALIGN * ALIGN - - logger.info( - f"RoPE cache reserve={reserve} (base={base_ctx}, cap={inferred_cap}, steps={steps}, draft={draft}, k={SAFETY_FACTOR}, margin={MARGIN})" - ) - - # Recursively expand all RoPE layers - def reserve_rope_cache_recursive(module): - for child in module.children(): - if hasattr(child, "_ensure_cos_sin_cache_length") and hasattr( - child, "cos_sin_cache" - ): - old_len = child.cos_sin_cache.shape[0] - child._ensure_cos_sin_cache_length(reserve - 1) - new_len = child.cos_sin_cache.shape[0] - if new_len > old_len: - logger.info( - f"Expanded RoPE cache from {old_len} to {new_len} positions" - ) - else: - reserve_rope_cache_recursive(child) - - reserve_rope_cache_recursive(model)