Misc clean up; Remove the support of jump forward (#4032)

Author: Lianmin Zheng (committed via GitHub)
Date: 2025-03-03 07:02:14 -08:00
parent 110e006673
commit 935cda944b
41 changed files with 396 additions and 426 deletions

View File

@@ -52,7 +52,7 @@ srt = [
# HIP (Heterogeneous-computing Interface for Portability) for AMD
# => base docker rocm/vllm-dev:20241022, not from public vllm whl
srt_hip = ["sglang[runtime_common]", "sgl-kernel>=0.0.3.post1", "torch", "vllm==0.6.7.dev2", "outlines==0.1.11"]
srt_hip = ["sglang[runtime_common]", "sgl-kernel==0.0.3.post6", "torch", "vllm==0.6.7.dev2", "outlines==0.1.11"]
# xpu is not enabled in public vllm and torch whl,
# need to follow https://docs.vllm.ai/en/latest/getting_started/xpu-installation.html to install vllm

View File

@@ -12,6 +12,5 @@
- `global_config.py`: The global configs and constants.
- `launch_server.py`: The entry point for launching the local server.
- `llama3_eval.py`: Evaluation of Llama 3 using the Meta Llama dataset.
- `profiler.py`: Profile a running server.
- `utils.py`: Common utilities.
- `version.py`: Version info.

View File

@@ -1 +0,0 @@
raise ValueError("bench_latency.py has been renamed to bench_one_batch.py")

View File

@@ -4,6 +4,13 @@ import os
class GlobalConfig:
"""
Store some global constants.
See also python/sglang/srt/managers/schedule_batch.py::global_server_args_dict, which stores
many global runtime arguments as well.
"""
def __init__(self):
# Verbosity level
# 0: do not output anything
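
Note: the docstring added above cross-references `global_server_args_dict` in `schedule_batch.py`. A minimal sketch of how the two kinds of global state are typically reached, assuming the import paths shown elsewhere in this commit; the attribute and key accessed below are illustrative assumptions, not guaranteed names:

    # Sketch only: the import paths appear in this commit; the attribute/key below are assumed.
    from sglang.global_config import global_config
    from sglang.srt.managers.schedule_batch import global_server_args_dict

    # GlobalConfig stores process-wide constants; global_server_args_dict mirrors
    # runtime server arguments as a plain dict.
    print(global_config.verbosity)                           # hypothetical attribute
    print(global_server_args_dict.get("attention_backend"))  # hypothetical key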

View File

@@ -80,7 +80,6 @@ def create_grammar_backend(server_args: ServerArgs, tokenizer, vocab_size):
grammar_backend = OutlinesGrammarBackend(
tokenizer,
whitespace_pattern=server_args.constrained_json_whitespace_pattern,
allow_jump_forward=not server_args.disable_jump_forward,
)
elif server_args.grammar_backend == "xgrammar":
from sglang.srt.constrained.xgrammar_backend import XGrammarGrammarBackend

View File

@@ -115,7 +115,6 @@ class OutlinesGrammarBackend(BaseGrammarBackend):
self,
tokenizer,
whitespace_pattern: bool,
allow_jump_forward: bool,
):
super().__init__()
@@ -140,7 +139,6 @@ class OutlinesGrammarBackend(BaseGrammarBackend):
self.outlines_tokenizer.vocabulary = (
self.outlines_tokenizer.tokenizer.get_vocab()
)
self.allow_jump_forward = allow_jump_forward
self.whitespace_pattern = whitespace_pattern
def init_value_impl(self, key: Tuple[str, str]) -> OutlinesGrammar:
@@ -172,10 +170,7 @@ class OutlinesGrammarBackend(BaseGrammarBackend):
logger.warning(f"skip invalid regex schema: {regex=}, {e=}")
return None
if self.allow_jump_forward:
jump_forward_map = OutlinesJumpForwardMap(regex)
else:
jump_forward_map = None
jump_forward_map = None
return OutlinesGrammar(guide, jump_forward_map)

View File

@@ -438,8 +438,8 @@ async def configure_logging(obj: ConfigureLoggingReq, request: Request):
return Response(status_code=200)
@app.post("/function_call")
async def function_call_request(obj: ParseFunctionCallReq, request: Request):
@app.post("/parse_function_call")
async def parse_function_call_request(obj: ParseFunctionCallReq, request: Request):
"""
A native API endpoint to parse function calls from a text.
"""
@@ -492,7 +492,7 @@ def available_models():
@app.post("/v1/files")
async def openai_v1_files(file: UploadFile = File(...), purpose: str = Form("batch")):
return await v1_files_create(
file, purpose, _global_state.tokenizer_manager.server_args.file_storage_pth
file, purpose, _global_state.tokenizer_manager.server_args.file_storage_path
)
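
Note: the native endpoint moves from `/function_call` to `/parse_function_call` (same `ParseFunctionCallReq` request model), and `/v1/files` now reads `server_args.file_storage_path`. A minimal client sketch against the renamed route; the port and JSON fields are assumptions and may not match the actual request schema:

    # Sketch only: the route name comes from this commit; the default port (30000)
    # and the payload fields ("text", "tools") are assumptions.
    import requests

    resp = requests.post(
        "http://localhost:30000/parse_function_call",
        json={"text": "model output containing a tool call", "tools": []},
    )
    print(resp.json())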

View File

@@ -4,7 +4,7 @@ from typing import TYPE_CHECKING
import torch
from sglang.srt.layers.attention import AttentionBackend
from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import ForwardBatch

View File

@@ -19,9 +19,8 @@ import triton
import triton.language as tl
from sglang.global_config import global_config
from sglang.srt.layers.attention import AttentionBackend
from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
from sglang.srt.layers.dp_attention import get_attention_tp_size
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
from sglang.srt.utils import is_flashinfer_available

View File

@@ -15,7 +15,7 @@ from typing import TYPE_CHECKING, Optional, Union
import torch
from sglang.global_config import global_config
from sglang.srt.layers.attention import AttentionBackend
from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
from sglang.srt.layers.attention.flashinfer_backend import (
create_flashinfer_kv_indices_triton,
)
@@ -34,7 +34,6 @@ if is_flashinfer_available():
BatchMLAPagedAttentionWrapper,
BatchPrefillWithRaggedKVCacheWrapper,
)
from flashinfer.cascade import merge_state
@dataclass

View File

@@ -5,7 +5,7 @@ from typing import TYPE_CHECKING
import torch
from torch.nn.functional import scaled_dot_product_attention
from sglang.srt.layers.attention import AttentionBackend
from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
if TYPE_CHECKING:

View File

@@ -5,7 +5,7 @@ from typing import TYPE_CHECKING, Optional, Union
import torch
import triton
from sglang.srt.layers.attention import AttentionBackend
from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
from sglang.srt.layers.attention.flashinfer_backend import (
create_flashinfer_kv_indices_triton,
)

View File

@@ -12,7 +12,7 @@ from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
requantize_with_max_scale,
)
from sglang.srt.layers.attention import AttentionBackend
from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
from sglang.srt.layers.linear import LinearBase, LinearMethodBase
from sglang.srt.layers.parameter import ModelWeightParameter, PerTensorScaleParameter
from sglang.srt.layers.quantization.base_config import (

View File

@@ -57,7 +57,6 @@ DETOKENIZER_MAX_STATES = int(os.environ.get("SGLANG_DETOKENIZER_MAX_STATES", 1 << 15))
class DecodeStatus:
"""Store the status of incremental decoding."""
vid: int
decoded_text: str
decode_ids: List[int]
surr_offset: int
@@ -143,10 +142,8 @@ class DetokenizerManager:
read_ids, surr_ids = [], []
for i in range(bs):
rid = recv_obj.rids[i]
vid = recv_obj.vids[i]
if rid not in self.decode_status or self.decode_status[rid].vid != vid:
if rid not in self.decode_status:
s = DecodeStatus(
vid=vid,
decoded_text=recv_obj.decoded_texts[i],
decode_ids=recv_obj.decode_ids[i],
surr_offset=0,

View File

@@ -376,8 +376,6 @@ class BatchTokenIDOut:
# The finish reason
finished_reasons: List[BaseFinishReason]
# For incremental decoding
# The version id to sync decode status with in detokenizer_manager
vids: List[int]
decoded_texts: List[str]
decode_ids: List[int]
read_offsets: List[int]

View File

@@ -296,7 +296,6 @@ class Req:
# 1: surr_offset
# 2: read_offset
# 3: last token
self.vid = 0 # version id to sync decode status with in detokenizer_manager
self.surr_offset = None # Surrounding offset to defeat the cleanup algorithm
self.read_offset = None
self.decoded_text = ""
@@ -357,11 +356,6 @@ class Req:
) = None
self.hidden_states = []
# Logprobs (internal values)
# The tokens is prefilled but need to be considered as decode tokens
# and should be updated for the decode logprobs
self.last_update_decode_tokens = 0
# Embedding (return values)
self.embedding = None
@@ -500,68 +494,6 @@ class Req:
self.finished_reason = FINISH_MATCHED_STR(matched=stop_str)
return
def jump_forward_and_retokenize(self, jump_forward_str, next_state):
if self.origin_input_text is None:
# Recovering text can only use unpadded ids
self.origin_input_text = self.tokenizer.decode(
self.origin_input_ids_unpadded
)
all_text = self.origin_input_text + self.decoded_text + jump_forward_str
all_ids = self.tokenizer.encode(all_text)
if not all_ids:
logger.warning("Encoded all_text resulted in empty all_ids")
return False
prompt_tokens = len(self.origin_input_ids_unpadded)
if prompt_tokens > len(all_ids):
logger.warning("prompt_tokens is larger than encoded all_ids")
return False
if all_ids[prompt_tokens - 1] != self.origin_input_ids_unpadded[-1]:
# TODO(lsyin): fix token fusion
logger.warning(
"Token fusion between input and output, try to avoid this by removing the space at the end of the input."
)
return False
old_output_ids = self.output_ids
self.output_ids = all_ids[prompt_tokens:]
self.decoded_text = self.decoded_text + jump_forward_str
self.surr_offset = prompt_tokens
self.read_offset = len(all_ids)
# NOTE: A trick to reduce the surrounding tokens decoding overhead
for i in range(0, INIT_INCREMENTAL_DETOKENIZATION_OFFSET):
surr_text_ = self.tokenizer.decode(
all_ids[self.read_offset - i : self.read_offset]
)
if not surr_text_.endswith("�"):
self.surr_offset = self.read_offset - i
break
# update the inner state of the grammar
self.grammar.jump_and_retokenize(old_output_ids, self.output_ids, next_state)
if self.return_logprob:
# For fast-forward part's logprobs
k = 0
for i, old_id in enumerate(old_output_ids):
if old_id == self.output_ids[i]:
k = k + 1
else:
break
self.output_token_logprobs_val = self.output_token_logprobs_val[:k]
self.output_token_logprobs_idx = self.output_token_logprobs_idx[:k]
self.output_top_logprobs_val = self.output_top_logprobs_val[:k]
self.output_top_logprobs_idx = self.output_top_logprobs_idx[:k]
self.output_token_ids_logprobs_val = self.output_token_ids_logprobs_val[:k]
self.output_token_ids_logprobs_idx = self.output_token_ids_logprobs_idx[:k]
self.logprob_start_len = prompt_tokens + k
self.last_update_decode_tokens = len(self.output_ids) - k
return True
def reset_for_retract(self):
self.prefix_indices = []
self.last_node = None
@@ -574,8 +506,6 @@ class Req:
self.is_chunked = 0
self.req_pool_idx = None
self.last_update_decode_tokens = 0
def __repr__(self):
return (
f"Req(rid={self.rid}, "
@@ -672,7 +602,6 @@ class ScheduleBatch:
enable_overlap: bool,
spec_algorithm: SpeculativeAlgorithm,
enable_custom_logit_processor: bool,
return_hidden_states: bool = False,
):
return cls(
reqs=reqs,
@@ -687,7 +616,7 @@ class ScheduleBatch:
device=req_to_token_pool.device,
spec_algorithm=spec_algorithm,
enable_custom_logit_processor=enable_custom_logit_processor,
return_hidden_states=return_hidden_states,
return_hidden_states=any(req.return_hidden_states for req in reqs),
)
def batch_size(self):
@@ -1091,59 +1020,6 @@ class ScheduleBatch:
return retracted_reqs, new_estimate_ratio
def check_for_jump_forward(self, pad_input_ids_func):
jump_forward_reqs = []
keep_indices = set(i for i in range(len(self.reqs)))
for i, req in enumerate(self.reqs):
if req.grammar is not None:
jump_helper = req.grammar.try_jump_forward(req.tokenizer)
if jump_helper:
suffix_ids, _ = jump_helper
# Current ids, for cache and revert
cur_all_ids = tuple(req.origin_input_ids + req.output_ids)[:-1]
cur_output_ids = req.output_ids
req.output_ids.extend(suffix_ids)
decode_res, new_text = req.get_next_inc_detokenization()
if not decode_res:
req.output_ids = cur_output_ids
continue
(
jump_forward_str,
next_state,
) = req.grammar.jump_forward_str_state(jump_helper)
# Make the incrementally decoded text part of jump_forward_str
# so that the UTF-8 will not corrupt
jump_forward_str = new_text + jump_forward_str
if not req.jump_forward_and_retokenize(
jump_forward_str, next_state
):
req.output_ids = cur_output_ids
continue
# The decode status has diverged from detokenizer_manager
req.vid += 1
# insert the old request into tree_cache
self.tree_cache.cache_finished_req(req, cur_all_ids)
# re-applying image padding
if req.image_inputs is not None:
req.origin_input_ids = pad_input_ids_func(
req.origin_input_ids_unpadded, req.image_inputs
)
jump_forward_reqs.append(req)
keep_indices.remove(i)
self.filter_batch(keep_indices=list(keep_indices))
return jump_forward_reqs
def prepare_encoder_info_decode(self):
# Reset the encoder cached status
self.encoder_cached = [True] * len(self.reqs)
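
Note: with the batch-level flag removed from `init_new`, whether hidden states are returned is now derived from the requests themselves via `any(req.return_hidden_states for req in reqs)`. A tiny sketch of that aggregation, using stand-in objects rather than the real `Req` class:

    # Sketch only: demonstrates the any(...) aggregation added above, with stand-ins.
    from types import SimpleNamespace

    reqs = [
        SimpleNamespace(return_hidden_states=False),
        SimpleNamespace(return_hidden_states=True),
    ]
    # The batch returns hidden states if any request in it asks for them.
    return_hidden_states = any(r.return_hidden_states for r in reqs)
    assert return_hidden_states is True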

View File

@@ -150,7 +150,6 @@ class Scheduler:
self.tp_rank = tp_rank
self.tp_size = server_args.tp_size
self.schedule_policy = server_args.schedule_policy
self.disable_jump_forward = server_args.disable_jump_forward
self.lora_paths = server_args.lora_paths
self.max_loras_per_batch = server_args.max_loras_per_batch
self.enable_overlap = not server_args.disable_overlap_schedule
@@ -251,9 +250,6 @@ class Scheduler:
self.enable_overlap = False
logger.info("Overlap scheduler is disabled for multimodal models.")
if self.enable_overlap:
self.disable_jump_forward = True
# Launch a tensor parallel worker
if self.enable_overlap:
TpWorkerClass = TpModelWorkerClient
@@ -1024,11 +1020,8 @@ class Scheduler:
if self.running_batch is not None
else set([])
)
return_hidden_states = False
# Get requests from the waiting queue to a new prefill batch
for req in self.waiting_queue:
if req.return_hidden_states:
return_hidden_states = True
if (
self.lora_paths
and len(
@@ -1114,7 +1107,6 @@ class Scheduler:
self.enable_overlap,
self.spec_algorithm,
self.server_args.enable_custom_logit_processor,
return_hidden_states,
)
new_batch.prepare_for_extend()
@@ -1168,14 +1160,6 @@ class Scheduler:
self.min_new_token_ratio,
)
# Check for jump-forward
if not self.disable_jump_forward and batch.has_grammar:
jump_forward_reqs = batch.check_for_jump_forward(self.pad_input_ids_func)
self._extend_requests_to_queue(jump_forward_reqs)
if batch.is_empty():
self.batch_is_full = False
return None
if batch.batch_size() < initial_bs:
self.batch_is_full = False
@@ -1530,8 +1514,6 @@ class Scheduler:
prefill (e.g., computing input token logprobs).
"""
assert output.input_token_logprobs is not None
# It is for jump decoding that will be deprecated.
assert req.last_update_decode_tokens == 0
if req.input_token_logprobs is None:
req.input_token_logprobs = []
if req.temp_input_top_logprobs_val is None:
@@ -1658,50 +1640,12 @@ class Scheduler:
self.add_input_logprob_return_values(
i, req, output, pt, num_input_logprobs, last_prefill_chunk=True
)
if req.last_update_decode_tokens != 0:
# Some decode tokens are re-computed in an extend batch
req.output_token_logprobs_val.extend(
output.input_token_logprobs[
pt
+ num_input_logprobs
- 1
- req.last_update_decode_tokens : pt
+ num_input_logprobs
- 1
],
)
req.output_token_logprobs_idx.extend(
req.fill_ids[
len(req.fill_ids)
- req.last_update_decode_tokens : len(req.fill_ids)
]
)
if req.top_logprobs_num > 0:
if req.last_update_decode_tokens != 0:
req.output_top_logprobs_val.extend(
output.input_top_logprobs_val[i][-req.last_update_decode_tokens :]
)
req.output_top_logprobs_idx.extend(
output.input_top_logprobs_idx[i][-req.last_update_decode_tokens :]
)
req.output_top_logprobs_val.append(output.next_token_top_logprobs_val[i])
req.output_top_logprobs_idx.append(output.next_token_top_logprobs_idx[i])
if req.token_ids_logprob is not None:
if req.last_update_decode_tokens != 0:
req.output_token_ids_logprobs_val.extend(
output.input_token_ids_logprobs_val[i][
-req.last_update_decode_tokens :
]
)
req.output_token_ids_logprobs_idx.extend(
output.input_token_ids_logprobs_idx[i][
-req.last_update_decode_tokens :
]
)
req.output_token_ids_logprobs_val.append(
output.next_token_token_ids_logprobs_val[i]
)
@@ -1719,7 +1663,6 @@ class Scheduler:
finished_reasons: List[BaseFinishReason] = []
if self.is_generation:
vids = []
decoded_texts = []
decode_ids_list = []
read_offsets = []
@@ -1786,7 +1729,6 @@ class Scheduler:
finished_reasons.append(
req.finished_reason.to_json() if req.finished_reason else None
)
vids.append(req.vid)
decoded_texts.append(req.decoded_text)
decode_ids, read_offset = req.init_incremental_detokenize()
decode_ids_list.append(decode_ids)
@@ -1842,7 +1784,6 @@ class Scheduler:
BatchTokenIDOut(
rids,
finished_reasons,
vids,
decoded_texts,
decode_ids_list,
read_offsets,

View File

@@ -41,7 +41,7 @@ from sglang.srt.layers.rotary_embedding import MRotaryEmbedding
from sglang.srt.utils import get_compiler_backend
if TYPE_CHECKING:
from sglang.srt.layers.attention import AttentionBackend
from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
from sglang.srt.managers.schedule_batch import ImageInputs, ModelWorkerBatch
from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool
from sglang.srt.model_executor.model_runner import ModelRunner

View File

@@ -26,8 +26,6 @@ from fastapi import HTTPException, Request, UploadFile
from fastapi.responses import ORJSONResponse, StreamingResponse
from pydantic import ValidationError
from sglang.lang.chat_template import get_chat_template_by_model_path
try:
from outlines.fsm.json_schema import convert_json_schema_to_str
except ImportError:
@@ -165,24 +163,19 @@ def load_chat_template_for_openai_api(tokenizer_manager, chat_template_arg, mode
else:
chat_template_name = chat_template_arg
# check chat-template
chat_template = get_chat_template_by_model_path(model_path)
if chat_template is not None:
official_chat_template = chat_template.name
used_chat_template = chat_template_name
if official_chat_template != used_chat_template:
logger.warning(
f"Using a chat_template: '{used_chat_template}', "
f"which is different from official chat template: '{official_chat_template}', "
f"This discrepancy may lead to performance degradation."
)
# Check chat-template
# TODO:
# 1. Do not import any code from sglang.lang
# 2. For VLM, when chat_template_arg is None, set it automatically by guessing from model_path.
async def v1_files_create(file: UploadFile, purpose: str, file_storage_pth: str = None):
async def v1_files_create(
file: UploadFile, purpose: str, file_storage_path: str = None
):
try:
global storage_dir
if file_storage_pth:
storage_dir = file_storage_pth
if file_storage_path:
storage_dir = file_storage_path
# Read the file content
file_content = await file.read()

View File

@@ -40,17 +40,23 @@ class SamplingParams:
presence_penalty: float = 0.0,
repetition_penalty: float = 1.0,
min_new_tokens: int = 0,
spaces_between_special_tokens: bool = True,
n: int = 1,
json_schema: Optional[str] = None,
regex: Optional[str] = None,
ebnf: Optional[str] = None,
structural_tag: Optional[str] = None,
no_stop_trim: bool = False,
ignore_eos: bool = False,
skip_special_tokens: bool = True,
spaces_between_special_tokens: bool = True,
no_stop_trim: bool = False,
custom_params: Optional[Dict[str, Any]] = None,
) -> None:
self.max_new_tokens = max_new_tokens
self.stop_strs = stop
if stop_token_ids:
self.stop_token_ids = set(stop_token_ids)
else:
self.stop_token_ids = None
self.temperature = temperature
self.top_p = top_p
self.top_k = top_k
@@ -58,26 +64,21 @@ class SamplingParams:
self.frequency_penalty = frequency_penalty
self.presence_penalty = presence_penalty
self.repetition_penalty = repetition_penalty
self.stop_strs = stop
if stop_token_ids:
self.stop_token_ids = set(stop_token_ids)
else:
self.stop_token_ids = None
self.max_new_tokens = max_new_tokens
self.min_new_tokens = min_new_tokens
self.ignore_eos = ignore_eos
self.skip_special_tokens = skip_special_tokens
self.spaces_between_special_tokens = spaces_between_special_tokens
self.regex = regex
self.n = n
self.json_schema = json_schema
self.ebnf = ebnf
self.structural_tag = structural_tag
self.ignore_eos = ignore_eos
self.skip_special_tokens = skip_special_tokens
self.spaces_between_special_tokens = spaces_between_special_tokens
self.no_stop_trim = no_stop_trim
self.custom_params = custom_params
# Process some special cases
if self.temperature < _SAMPLING_EPS:
# top_k = 1 means greedy sampling
self.temperature = 1.0
self.top_k = 1
if self.top_k == -1:
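
Note: this hunk only reorders the `SamplingParams` constructor parameters and the corresponding assignments; nothing is added or removed, so callers that pass keyword arguments are unaffected. A hedged construction example using only parameter names visible in the diff above; the import path is assumed to be the file's usual location and the values are placeholders:

    # Sketch only: parameter names are taken from the diff above; values are arbitrary.
    from sglang.srt.sampling.sampling_params import SamplingParams

    params = SamplingParams(
        max_new_tokens=128,
        stop=["</s>"],
        temperature=0.7,
        top_p=0.9,
        json_schema=None,          # structured-output options are unchanged
        skip_special_tokens=True,
        no_stop_trim=False,
    )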

View File

@@ -15,21 +15,15 @@
import argparse
import dataclasses
import json
import logging
import os
import random
import subprocess
import tempfile
import uuid
from pathlib import Path
from typing import List, Optional
import torch
from sglang.srt.hf_transformers_utils import check_gguf_file
from sglang.srt.utils import (
create_checksum,
get_amdgpu_memory_capacity,
get_hpu_memory_capacity,
get_nvgpu_memory_capacity,
@@ -101,7 +95,7 @@ class ServerArgs:
# API related
api_key: Optional[str] = None
file_storage_pth: str = "sglang_storage"
file_storage_path: str = "sglang_storage"
enable_cache_report: bool = False
# Data parallelism
@@ -149,7 +143,6 @@ class ServerArgs:
# Optimization/debug options
disable_radix_cache: bool = False
disable_jump_forward: bool = False
disable_cuda_graph: bool = False
disable_cuda_graph_padding: bool = False
enable_nccl_nvls: bool = False
@@ -627,9 +620,9 @@ class ServerArgs:
help="Set API key of the server. It is also used in the OpenAI API compatible server.",
)
parser.add_argument(
"--file-storage-pth",
"--file-storage-path",
type=str,
default=ServerArgs.file_storage_pth,
default=ServerArgs.file_storage_path,
help="The path of the file storage in backend.",
)
parser.add_argument(
@@ -836,11 +829,6 @@ class ServerArgs:
action="store_true",
help="Disable RadixAttention for prefix caching.",
)
parser.add_argument(
"--disable-jump-forward",
action="store_true",
help="Disable jump-forward for grammar-guided decoding.",
)
parser.add_argument(
"--disable-cuda-graph",
action="store_true",
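
Note: the storage option is renamed from `file_storage_pth` / `--file-storage-pth` to `file_storage_path` / `--file-storage-path`, and `--disable-jump-forward` is removed along with the feature. A hedged sketch of the renamed field on the dataclass; the model path is a placeholder:

    # Sketch only: field names follow the diff above; the model path is a placeholder.
    from sglang.srt.server_args import ServerArgs

    args = ServerArgs(
        model_path="meta-llama/Llama-3.1-8B-Instruct",
        file_storage_path="sglang_storage",  # formerly file_storage_pth
    )
    # Equivalent CLI flag after this commit: --file-storage-path sglang_storage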

View File

@@ -44,7 +44,7 @@ DEFAULT_MODEL_NAME_FOR_NIGHTLY_EVAL_QUANT_TP1 = "hugging-quants/Meta-Llama-3.1-8
DEFAULT_SMALL_MODEL_NAME_FOR_TEST_QWEN = "Qwen/Qwen2.5-1.5B-Instruct"
DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST = "meta-llama/Llama-2-7b-chat-hf"
DEFAULT_EAGLE_DRAFT_MODEL_FOR_TEST = "lmzheng/sglang-EAGLE-llama2-chat-7B"
DEFAULT_EAGLE_DRAFT_MODEL_FOR_TEST = "lmsys/sglang-EAGLE-llama2-chat-7B"
def is_in_ci():