Misc clean up; Remove the support of jump forward (#4032)
@@ -52,7 +52,7 @@ srt = [
# HIP (Heterogeneous-computing Interface for Portability) for AMD
# => base docker rocm/vllm-dev:20241022, not from public vllm whl
srt_hip = ["sglang[runtime_common]", "sgl-kernel>=0.0.3.post1", "torch", "vllm==0.6.7.dev2", "outlines==0.1.11"]
srt_hip = ["sglang[runtime_common]", "sgl-kernel==0.0.3.post6", "torch", "vllm==0.6.7.dev2", "outlines==0.1.11"]

# xpu is not enabled in public vllm and torch whl,
# need to follow https://docs.vllm.ai/en/latest/getting_started/xpu-installation.html to install vllm

@@ -12,6 +12,5 @@
- `global_config.py`: The global configs and constants.
- `launch_server.py`: The entry point for launching the local server.
- `llama3_eval.py`: Evaluation of Llama 3 using the Meta Llama dataset.
- `profiler.py`: Profile a running server.
- `utils.py`: Common utilities.
- `version.py`: Version info.

@@ -1 +0,0 @@
raise ValueError("bench_latency.py has been renamed to bench_one_batch.py")

@@ -4,6 +4,13 @@ import os
class GlobalConfig:
"""
Store some global constants.

See also python/sglang/srt/managers/schedule_batch.py::global_server_args_dict, which stores
many global runtime arguments as well.
"""

def __init__(self):
# Verbosity level
# 0: do not output anything

@@ -80,7 +80,6 @@ def create_grammar_backend(server_args: ServerArgs, tokenizer, vocab_size):
grammar_backend = OutlinesGrammarBackend(
tokenizer,
whitespace_pattern=server_args.constrained_json_whitespace_pattern,
allow_jump_forward=not server_args.disable_jump_forward,
)
elif server_args.grammar_backend == "xgrammar":
from sglang.srt.constrained.xgrammar_backend import XGrammarGrammarBackend

@@ -115,7 +115,6 @@ class OutlinesGrammarBackend(BaseGrammarBackend):
self,
tokenizer,
whitespace_pattern: bool,
allow_jump_forward: bool,
):
super().__init__()

@@ -140,7 +139,6 @@ class OutlinesGrammarBackend(BaseGrammarBackend):
self.outlines_tokenizer.vocabulary = (
self.outlines_tokenizer.tokenizer.get_vocab()
)
self.allow_jump_forward = allow_jump_forward
self.whitespace_pattern = whitespace_pattern

def init_value_impl(self, key: Tuple[str, str]) -> OutlinesGrammar:
@@ -172,10 +170,7 @@ class OutlinesGrammarBackend(BaseGrammarBackend):
logger.warning(f"skip invalid regex schema: {regex=}, {e=}")
return None

if self.allow_jump_forward:
jump_forward_map = OutlinesJumpForwardMap(regex)
else:
jump_forward_map = None
jump_forward_map = None
return OutlinesGrammar(guide, jump_forward_map)
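For illustration, a minimal sketch of how the Outlines backend ends up being built once jump-forward support is gone, based only on the call shown in the hunk above; the import path and the `tokenizer`/`server_args` objects are assumptions for the sketch, not part of the commit:

```python
# Hypothetical post-change construction; `tokenizer` and `server_args` are
# assumed to come from the surrounding server setup code.
from sglang.srt.constrained.outlines_backend import OutlinesGrammarBackend

grammar_backend = OutlinesGrammarBackend(
    tokenizer,
    whitespace_pattern=server_args.constrained_json_whitespace_pattern,
)
```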

@@ -438,8 +438,8 @@ async def configure_logging(obj: ConfigureLoggingReq, request: Request):
return Response(status_code=200)

@app.post("/function_call")
async def function_call_request(obj: ParseFunctionCallReq, request: Request):
@app.post("/parse_function_call")
async def parse_function_call_request(obj: ParseFunctionCallReq, request: Request):
"""
A native API endpoint to parse function calls from a text.
"""
@@ -492,7 +492,7 @@ def available_models():
@app.post("/v1/files")
async def openai_v1_files(file: UploadFile = File(...), purpose: str = Form("batch")):
return await v1_files_create(
file, purpose, _global_state.tokenizer_manager.server_args.file_storage_pth
file, purpose, _global_state.tokenizer_manager.server_args.file_storage_path
)

@@ -4,7 +4,7 @@ from typing import TYPE_CHECKING

import torch

from sglang.srt.layers.attention import AttentionBackend
from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import ForwardBatch

@@ -19,9 +19,8 @@ import triton
import triton.language as tl

from sglang.global_config import global_config
from sglang.srt.layers.attention import AttentionBackend
from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
from sglang.srt.layers.dp_attention import get_attention_tp_size
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
from sglang.srt.utils import is_flashinfer_available

@@ -15,7 +15,7 @@ from typing import TYPE_CHECKING, Optional, Union
import torch

from sglang.global_config import global_config
from sglang.srt.layers.attention import AttentionBackend
from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
from sglang.srt.layers.attention.flashinfer_backend import (
create_flashinfer_kv_indices_triton,
)
@@ -34,7 +34,6 @@ if is_flashinfer_available():
BatchMLAPagedAttentionWrapper,
BatchPrefillWithRaggedKVCacheWrapper,
)
from flashinfer.cascade import merge_state

@dataclass

@@ -5,7 +5,7 @@ from typing import TYPE_CHECKING
import torch
from torch.nn.functional import scaled_dot_product_attention

from sglang.srt.layers.attention import AttentionBackend
from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
from sglang.srt.model_executor.forward_batch_info import ForwardBatch

if TYPE_CHECKING:

@@ -5,7 +5,7 @@ from typing import TYPE_CHECKING, Optional, Union
import torch
import triton

from sglang.srt.layers.attention import AttentionBackend
from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
from sglang.srt.layers.attention.flashinfer_backend import (
create_flashinfer_kv_indices_triton,
)

@@ -12,7 +12,7 @@ from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
requantize_with_max_scale,
)

from sglang.srt.layers.attention import AttentionBackend
from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
from sglang.srt.layers.linear import LinearBase, LinearMethodBase
from sglang.srt.layers.parameter import ModelWeightParameter, PerTensorScaleParameter
from sglang.srt.layers.quantization.base_config import (

@@ -57,7 +57,6 @@ DETOKENIZER_MAX_STATES = int(os.environ.get("SGLANG_DETOKENIZER_MAX_STATES", 1 <
class DecodeStatus:
"""Store the status of incremental decoding."""

vid: int
decoded_text: str
decode_ids: List[int]
surr_offset: int
@@ -143,10 +142,8 @@ class DetokenizerManager:
read_ids, surr_ids = [], []
for i in range(bs):
rid = recv_obj.rids[i]
vid = recv_obj.vids[i]
if rid not in self.decode_status or self.decode_status[rid].vid != vid:
if rid not in self.decode_status:
s = DecodeStatus(
vid=vid,
decoded_text=recv_obj.decoded_texts[i],
decode_ids=recv_obj.decode_ids[i],
surr_offset=0,

@@ -376,8 +376,6 @@ class BatchTokenIDOut:
# The finish reason
finished_reasons: List[BaseFinishReason]
# For incremental decoding
# The version id to sync decode status with in detokenizer_manager
vids: List[int]
decoded_texts: List[str]
decode_ids: List[int]
read_offsets: List[int]

@@ -296,7 +296,6 @@ class Req:
# 1: surr_offset
# 2: read_offset
# 3: last token
self.vid = 0 # version id to sync decode status with in detokenizer_manager
self.surr_offset = None # Surrounding offset to defeat the cleanup algorithm
self.read_offset = None
self.decoded_text = ""
@@ -357,11 +356,6 @@ class Req:
) = None
self.hidden_states = []

# Logprobs (internal values)
# The tokens is prefilled but need to be considered as decode tokens
# and should be updated for the decode logprobs
self.last_update_decode_tokens = 0

# Embedding (return values)
self.embedding = None

@@ -500,68 +494,6 @@ class Req:
self.finished_reason = FINISH_MATCHED_STR(matched=stop_str)
return

def jump_forward_and_retokenize(self, jump_forward_str, next_state):
if self.origin_input_text is None:
# Recovering text can only use unpadded ids
self.origin_input_text = self.tokenizer.decode(
self.origin_input_ids_unpadded
)

all_text = self.origin_input_text + self.decoded_text + jump_forward_str
all_ids = self.tokenizer.encode(all_text)
if not all_ids:
logger.warning("Encoded all_text resulted in empty all_ids")
return False

prompt_tokens = len(self.origin_input_ids_unpadded)
if prompt_tokens > len(all_ids):
logger.warning("prompt_tokens is larger than encoded all_ids")
return False

if all_ids[prompt_tokens - 1] != self.origin_input_ids_unpadded[-1]:
# TODO(lsyin): fix token fusion
logger.warning(
"Token fusion between input and output, try to avoid this by removing the space at the end of the input."
)
return False

old_output_ids = self.output_ids
self.output_ids = all_ids[prompt_tokens:]
self.decoded_text = self.decoded_text + jump_forward_str
self.surr_offset = prompt_tokens
self.read_offset = len(all_ids)

# NOTE: A trick to reduce the surrouding tokens decoding overhead
for i in range(0, INIT_INCREMENTAL_DETOKENIZATION_OFFSET):
surr_text_ = self.tokenizer.decode(
all_ids[self.read_offset - i : self.read_offset]
)
if not surr_text_.endswith("�"):
self.surr_offset = self.read_offset - i
break

# update the inner state of the grammar
self.grammar.jump_and_retokenize(old_output_ids, self.output_ids, next_state)

if self.return_logprob:
# For fast-forward part's logprobs
k = 0
for i, old_id in enumerate(old_output_ids):
if old_id == self.output_ids[i]:
k = k + 1
else:
break
self.output_token_logprobs_val = self.output_token_logprobs_val[:k]
self.output_token_logprobs_idx = self.output_token_logprobs_idx[:k]
self.output_top_logprobs_val = self.output_top_logprobs_val[:k]
self.output_top_logprobs_idx = self.output_top_logprobs_idx[:k]
self.output_token_ids_logprobs_val = self.output_token_ids_logprobs_val[:k]
self.output_token_ids_logprobs_idx = self.output_token_ids_logprobs_idx[:k]
self.logprob_start_len = prompt_tokens + k
self.last_update_decode_tokens = len(self.output_ids) - k

return True

def reset_for_retract(self):
self.prefix_indices = []
self.last_node = None
@@ -574,8 +506,6 @@ class Req:
self.is_chunked = 0
self.req_pool_idx = None

self.last_update_decode_tokens = 0

def __repr__(self):
return (
f"Req(rid={self.rid}, "
@@ -672,7 +602,6 @@ class ScheduleBatch:
enable_overlap: bool,
spec_algorithm: SpeculativeAlgorithm,
enable_custom_logit_processor: bool,
return_hidden_states: bool = False,
):
return cls(
reqs=reqs,
@@ -687,7 +616,7 @@ class ScheduleBatch:
device=req_to_token_pool.device,
spec_algorithm=spec_algorithm,
enable_custom_logit_processor=enable_custom_logit_processor,
return_hidden_states=return_hidden_states,
return_hidden_states=any(req.return_hidden_states for req in reqs),
)

def batch_size(self):
@@ -1091,59 +1020,6 @@ class ScheduleBatch:

return retracted_reqs, new_estimate_ratio

def check_for_jump_forward(self, pad_input_ids_func):
jump_forward_reqs = []
keep_indices = set(i for i in range(len(self.reqs)))

for i, req in enumerate(self.reqs):
if req.grammar is not None:
jump_helper = req.grammar.try_jump_forward(req.tokenizer)
if jump_helper:
suffix_ids, _ = jump_helper

# Current ids, for cache and revert
cur_all_ids = tuple(req.origin_input_ids + req.output_ids)[:-1]
cur_output_ids = req.output_ids

req.output_ids.extend(suffix_ids)
decode_res, new_text = req.get_next_inc_detokenization()
if not decode_res:
req.output_ids = cur_output_ids
continue

(
jump_forward_str,
next_state,
) = req.grammar.jump_forward_str_state(jump_helper)

# Make the incrementally decoded text part of jump_forward_str
# so that the UTF-8 will not corrupt
jump_forward_str = new_text + jump_forward_str
if not req.jump_forward_and_retokenize(
jump_forward_str, next_state
):
req.output_ids = cur_output_ids
continue

# The decode status has diverged from detokenizer_manager
req.vid += 1

# insert the old request into tree_cache
self.tree_cache.cache_finished_req(req, cur_all_ids)

# re-applying image padding
if req.image_inputs is not None:
req.origin_input_ids = pad_input_ids_func(
req.origin_input_ids_unpadded, req.image_inputs
)

jump_forward_reqs.append(req)
keep_indices.remove(i)

self.filter_batch(keep_indices=list(keep_indices))

return jump_forward_reqs

def prepare_encoder_info_decode(self):
# Reset the encoder cached status
self.encoder_cached = [True] * len(self.reqs)

@@ -150,7 +150,6 @@ class Scheduler:
self.tp_rank = tp_rank
self.tp_size = server_args.tp_size
self.schedule_policy = server_args.schedule_policy
self.disable_jump_forward = server_args.disable_jump_forward
self.lora_paths = server_args.lora_paths
self.max_loras_per_batch = server_args.max_loras_per_batch
self.enable_overlap = not server_args.disable_overlap_schedule
@@ -251,9 +250,6 @@ class Scheduler:
self.enable_overlap = False
logger.info("Overlap scheduler is disabled for multimodal models.")

if self.enable_overlap:
self.disable_jump_forward = True

# Launch a tensor parallel worker
if self.enable_overlap:
TpWorkerClass = TpModelWorkerClient
@@ -1024,11 +1020,8 @@ class Scheduler:
if self.running_batch is not None
else set([])
)
return_hidden_states = False
# Get requests from the waiting queue to a new prefill batch
for req in self.waiting_queue:
if req.return_hidden_states:
return_hidden_states = True
if (
self.lora_paths
and len(
@@ -1114,7 +1107,6 @@ class Scheduler:
self.enable_overlap,
self.spec_algorithm,
self.server_args.enable_custom_logit_processor,
return_hidden_states,
)
new_batch.prepare_for_extend()

@@ -1168,14 +1160,6 @@ class Scheduler:
self.min_new_token_ratio,
)

# Check for jump-forward
if not self.disable_jump_forward and batch.has_grammar:
jump_forward_reqs = batch.check_for_jump_forward(self.pad_input_ids_func)
self._extend_requests_to_queue(jump_forward_reqs)
if batch.is_empty():
self.batch_is_full = False
return None

if batch.batch_size() < initial_bs:
self.batch_is_full = False

@@ -1530,8 +1514,6 @@ class Scheduler:
prefill (e.g., computing input token logprobs).
"""
assert output.input_token_logprobs is not None
# It is for jump decoding that will be deprecated.
assert req.last_update_decode_tokens == 0
if req.input_token_logprobs is None:
req.input_token_logprobs = []
if req.temp_input_top_logprobs_val is None:
@@ -1658,50 +1640,12 @@ class Scheduler:
self.add_input_logprob_return_values(
i, req, output, pt, num_input_logprobs, last_prefill_chunk=True
)
if req.last_update_decode_tokens != 0:
# Some decode tokens are re-computed in an extend batch
req.output_token_logprobs_val.extend(
output.input_token_logprobs[
pt
+ num_input_logprobs
- 1
- req.last_update_decode_tokens : pt
+ num_input_logprobs
- 1
],
)
req.output_token_logprobs_idx.extend(
req.fill_ids[
len(req.fill_ids)
- req.last_update_decode_tokens : len(req.fill_ids)
]
)

if req.top_logprobs_num > 0:
if req.last_update_decode_tokens != 0:
req.output_top_logprobs_val.extend(
output.input_top_logprobs_val[i][-req.last_update_decode_tokens :]
)
req.output_top_logprobs_idx.extend(
output.input_top_logprobs_idx[i][-req.last_update_decode_tokens :]
)

req.output_top_logprobs_val.append(output.next_token_top_logprobs_val[i])
req.output_top_logprobs_idx.append(output.next_token_top_logprobs_idx[i])

if req.token_ids_logprob is not None:
if req.last_update_decode_tokens != 0:
req.output_token_ids_logprobs_val.extend(
output.input_token_ids_logprobs_val[i][
-req.last_update_decode_tokens :
]
)
req.output_token_ids_logprobs_idx.extend(
output.input_token_ids_logprobs_idx[i][
-req.last_update_decode_tokens :
]
)

req.output_token_ids_logprobs_val.append(
output.next_token_token_ids_logprobs_val[i]
)
@@ -1719,7 +1663,6 @@ class Scheduler:
finished_reasons: List[BaseFinishReason] = []

if self.is_generation:
vids = []
decoded_texts = []
decode_ids_list = []
read_offsets = []
@@ -1786,7 +1729,6 @@ class Scheduler:
finished_reasons.append(
req.finished_reason.to_json() if req.finished_reason else None
)
vids.append(req.vid)
decoded_texts.append(req.decoded_text)
decode_ids, read_offset = req.init_incremental_detokenize()
decode_ids_list.append(decode_ids)
@@ -1842,7 +1784,6 @@ class Scheduler:
BatchTokenIDOut(
rids,
finished_reasons,
vids,
decoded_texts,
decode_ids_list,
read_offsets,

@@ -41,7 +41,7 @@ from sglang.srt.layers.rotary_embedding import MRotaryEmbedding
from sglang.srt.utils import get_compiler_backend

if TYPE_CHECKING:
from sglang.srt.layers.attention import AttentionBackend
from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
from sglang.srt.managers.schedule_batch import ImageInputs, ModelWorkerBatch
from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool
from sglang.srt.model_executor.model_runner import ModelRunner

@@ -26,8 +26,6 @@ from fastapi import HTTPException, Request, UploadFile
from fastapi.responses import ORJSONResponse, StreamingResponse
from pydantic import ValidationError

from sglang.lang.chat_template import get_chat_template_by_model_path

try:
from outlines.fsm.json_schema import convert_json_schema_to_str
except ImportError:
@@ -165,24 +163,19 @@ def load_chat_template_for_openai_api(tokenizer_manager, chat_template_arg, mode
else:
chat_template_name = chat_template_arg

# check chat-template
chat_template = get_chat_template_by_model_path(model_path)
if chat_template is not None:
official_chat_template = chat_template.name
used_chat_template = chat_template_name
if official_chat_template != used_chat_template:
logger.warning(
f"Using a chat_template: '{used_chat_template}', "
f"which is different from official chat template: '{official_chat_template}', "
f"This discrepancy may lead to performance degradation."
)
# Check chat-template
# TODO:
# 1. Do not import any code from sglang.lang
# 2. For VLM, when chat_template_arg is None, set it automatically by guessing from model_path.

async def v1_files_create(file: UploadFile, purpose: str, file_storage_pth: str = None):
async def v1_files_create(
file: UploadFile, purpose: str, file_storage_path: str = None
):
try:
global storage_dir
if file_storage_pth:
storage_dir = file_storage_pth
if file_storage_path:
storage_dir = file_storage_path
# Read the file content
file_content = await file.read()

@@ -40,17 +40,23 @@ class SamplingParams:
presence_penalty: float = 0.0,
repetition_penalty: float = 1.0,
min_new_tokens: int = 0,
spaces_between_special_tokens: bool = True,
n: int = 1,
json_schema: Optional[str] = None,
regex: Optional[str] = None,
ebnf: Optional[str] = None,
structural_tag: Optional[str] = None,
no_stop_trim: bool = False,
ignore_eos: bool = False,
skip_special_tokens: bool = True,
spaces_between_special_tokens: bool = True,
no_stop_trim: bool = False,
custom_params: Optional[Dict[str, Any]] = None,
) -> None:
self.max_new_tokens = max_new_tokens
self.stop_strs = stop
if stop_token_ids:
self.stop_token_ids = set(stop_token_ids)
else:
self.stop_token_ids = None
self.temperature = temperature
self.top_p = top_p
self.top_k = top_k
@@ -58,26 +64,21 @@ class SamplingParams:
self.frequency_penalty = frequency_penalty
self.presence_penalty = presence_penalty
self.repetition_penalty = repetition_penalty
self.stop_strs = stop
if stop_token_ids:
self.stop_token_ids = set(stop_token_ids)
else:
self.stop_token_ids = None
self.max_new_tokens = max_new_tokens
self.min_new_tokens = min_new_tokens
self.ignore_eos = ignore_eos
self.skip_special_tokens = skip_special_tokens
self.spaces_between_special_tokens = spaces_between_special_tokens
self.regex = regex
self.n = n
self.json_schema = json_schema
self.ebnf = ebnf
self.structural_tag = structural_tag
self.ignore_eos = ignore_eos
self.skip_special_tokens = skip_special_tokens
self.spaces_between_special_tokens = spaces_between_special_tokens
self.no_stop_trim = no_stop_trim
self.custom_params = custom_params

# Process some special cases
if self.temperature < _SAMPLING_EPS:
# top_k = 1 means greedy sampling
self.temperature = 1.0
self.top_k = 1
if self.top_k == -1:
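Since this hunk only reorders and regroups keyword parameters, keyword-based call sites keep working unchanged. A small sketch of constructing the class with the parameters listed above, assuming the usual module path; the concrete values are placeholders for illustration:

```python
from sglang.srt.sampling.sampling_params import SamplingParams

# Keyword arguments mirror the reordered signature shown above.
params = SamplingParams(
    max_new_tokens=128,
    temperature=0.7,
    top_p=0.9,
    min_new_tokens=0,
    n=1,
    regex=r"(yes|no)",
    ignore_eos=False,
    skip_special_tokens=True,
    spaces_between_special_tokens=True,
    no_stop_trim=False,
)
```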

@@ -15,21 +15,15 @@

import argparse
import dataclasses
import json
import logging
import os
import random
import subprocess
import tempfile
import uuid
from pathlib import Path
from typing import List, Optional

import torch

from sglang.srt.hf_transformers_utils import check_gguf_file
from sglang.srt.utils import (
create_checksum,
get_amdgpu_memory_capacity,
get_hpu_memory_capacity,
get_nvgpu_memory_capacity,
@@ -101,7 +95,7 @@ class ServerArgs:

# API related
api_key: Optional[str] = None
file_storage_pth: str = "sglang_storage"
file_storage_path: str = "sglang_storage"
enable_cache_report: bool = False

# Data parallelism
@@ -149,7 +143,6 @@ class ServerArgs:

# Optimization/debug options
disable_radix_cache: bool = False
disable_jump_forward: bool = False
disable_cuda_graph: bool = False
disable_cuda_graph_padding: bool = False
enable_nccl_nvls: bool = False
@@ -627,9 +620,9 @@ class ServerArgs:
help="Set API key of the server. It is also used in the OpenAI API compatible server.",
)
parser.add_argument(
"--file-storage-pth",
"--file-storage-path",
type=str,
default=ServerArgs.file_storage_pth,
default=ServerArgs.file_storage_path,
help="The path of the file storage in backend.",
)
parser.add_argument(
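Scripts that passed the old spelling of this flag need the new one. A hypothetical launch sketch using the renamed flag; the model name and storage path are placeholders, and `--model-path` is the usual SGLang flag rather than part of this hunk:

```python
import subprocess

# Hypothetical invocation of the local server entry point with the renamed flag.
subprocess.run([
    "python", "-m", "sglang.launch_server",
    "--model-path", "meta-llama/Llama-3.1-8B-Instruct",
    "--file-storage-path", "/tmp/sglang_storage",
])
```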

@@ -836,11 +829,6 @@ class ServerArgs:
action="store_true",
help="Disable RadixAttention for prefix caching.",
)
parser.add_argument(
"--disable-jump-forward",
action="store_true",
help="Disable jump-forward for grammar-guided decoding.",
)
parser.add_argument(
"--disable-cuda-graph",
action="store_true",

@@ -44,7 +44,7 @@ DEFAULT_MODEL_NAME_FOR_NIGHTLY_EVAL_QUANT_TP1 = "hugging-quants/Meta-Llama-3.1-8
DEFAULT_SMALL_MODEL_NAME_FOR_TEST_QWEN = "Qwen/Qwen2.5-1.5B-Instruct"

DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST = "meta-llama/Llama-2-7b-chat-hf"
DEFAULT_EAGLE_DRAFT_MODEL_FOR_TEST = "lmzheng/sglang-EAGLE-llama2-chat-7B"
DEFAULT_EAGLE_DRAFT_MODEL_FOR_TEST = "lmsys/sglang-EAGLE-llama2-chat-7B"

def is_in_ci():