diff --git a/README.md b/README.md
index ecbdbbfde..b1e372ae8 100644
--- a/README.md
+++ b/README.md
@@ -297,7 +297,6 @@ curl http://localhost:30000/generate \
 Learn more about the argument format [here](docs/sampling_params.md).
 
 ### OpenAI Compatible API
-
 In addition, the server supports an experimental OpenAI-compatible API.
 
 ```python
@@ -386,7 +385,6 @@ python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port
 Instructions for supporting a new model are [here](https://github.com/sgl-project/sglang/blob/main/docs/model_support.md).
 
 ## Benchmark And Performance
-
 - Llama-7B on NVIDIA A10G, FP16, Tensor Parallelism=1
 ![llama_7b](assets/llama_7b.jpg)
 
@@ -410,7 +408,4 @@ https://github.com/sgl-project/sglang/issues/157
 }
 ```
 
-[![Paper page](https://huggingface.co/datasets/huggingface/badges/resolve/main/paper-page-md.svg)](https://huggingface.co/papers/2312.07104)
-
-
 We learned from the design and reused some code of the following projects: [Guidance](https://github.com/guidance-ai/guidance), [vLLM](https://github.com/vllm-project/vllm), [LightLLM](https://github.com/ModelTC/lightllm), [FlashInfer](https://github.com/flashinfer-ai/flashinfer), [Outlines](https://github.com/outlines-dev/outlines), [LMQL](https://github.com/eth-sri/lmql).
diff --git a/python/sglang/api.py b/python/sglang/api.py
index 087bac031..f2b92a960 100644
--- a/python/sglang/api.py
+++ b/python/sglang/api.py
@@ -1,5 +1,6 @@
 """Some Public API Definitions"""
 
+import os
 import re
 from typing import Callable, List, Optional, Union
 
@@ -31,6 +32,7 @@ def function(
 
 def Runtime(*args, **kwargs):
     # Avoid importing unnecessary dependency
+    os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
     from sglang.srt.server import Runtime
 
     return Runtime(*args, **kwargs)
diff --git a/python/sglang/backend/anthropic.py b/python/sglang/backend/anthropic.py
index 851bc176a..330b2a412 100644
--- a/python/sglang/backend/anthropic.py
+++ b/python/sglang/backend/anthropic.py
@@ -14,7 +14,7 @@ except ImportError as e:
 
 
 class Anthropic(BaseBackend):
-    def __init__(self, model_name):
+    def __init__(self, model_name, *args, **kwargs):
         super().__init__()
 
         if isinstance(anthropic, Exception):
@@ -22,6 +22,7 @@ class Anthropic(BaseBackend):
 
         self.model_name = model_name
         self.chat_template = get_chat_template("claude")
+        self.client = anthropic.Anthropic(*args, **kwargs)
 
     def get_chat_template(self):
         return self.chat_template
@@ -41,7 +42,7 @@ class Anthropic(BaseBackend):
         else:
             system = ""
 
-        ret = anthropic.Anthropic().messages.create(
+        ret = self.client.messages.create(
             model=self.model_name,
             system=system,
             messages=messages,
@@ -66,11 +67,11 @@ class Anthropic(BaseBackend):
         else:
             system = ""
 
-        with anthropic.Anthropic().messages.stream(
+        with self.client.messages.stream(
             model=self.model_name,
             system=system,
             messages=messages,
             **sampling_params.to_anthropic_kwargs(),
         ) as stream:
             for text in stream.text_stream:
-                yield text, {}
+                yield text, {}
\ No newline at end of file
diff --git a/python/sglang/backend/openai.py b/python/sglang/backend/openai.py
index 3c0210975..5ac1d9447 100644
--- a/python/sglang/backend/openai.py
+++ b/python/sglang/backend/openai.py
@@ -228,7 +228,7 @@ class OpenAI(BaseBackend):
             prompt_tokens.append(ret_token)
 
         decision = choices[np.argmax(scores)]
-        return decision, scores, scores
+        return decision, scores, None, None
 
 
 def openai_completion(client, retries=3, is_chat=None, prompt=None, **kwargs):
diff --git a/python/sglang/backend/runtime_endpoint.py b/python/sglang/backend/runtime_endpoint.py
index c4adf6987..efa0cd5a4 100644
--- a/python/sglang/backend/runtime_endpoint.py
+++ b/python/sglang/backend/runtime_endpoint.py
@@ -220,7 +220,6 @@ class RuntimeEndpoint(BaseBackend):
             "sampling_params": {"max_new_tokens": 0},
             "return_logprob": True,
             "logprob_start_len": max(prompt_len - 2, 0),
-            "return_text_in_logprobs": True,
         }
         self._add_images(s, data)
         res = http_request(
diff --git a/python/sglang/srt/layers/logits_processor.py b/python/sglang/srt/layers/logits_processor.py
index 668cd3390..53d6620e9 100644
--- a/python/sglang/srt/layers/logits_processor.py
+++ b/python/sglang/srt/layers/logits_processor.py
@@ -42,26 +42,29 @@ class LogitsProcessor(nn.Module):
             for i in range(all_logprobs.shape[0]):
                 k = input_metadata.top_logprobs_nums[i]
                 t = all_logprobs[i].topk(k)
-                v_cpu = t.values.cpu().tolist()
-                p_cpu = t.indices.cpu().tolist()
+                v_cpu = t.values.tolist()
+                p_cpu = t.indices.tolist()
                 decode_top_logprobs.append(list(zip(v_cpu, p_cpu)))
             return None, decode_top_logprobs
         else:
             prefill_top_logprobs, decode_top_logprobs = [], []
             pt = 0
             # NOTE: the GPU-CPU overhead can be reduced
-            extend_seq_lens_cpu = input_metadata.extend_seq_lens
-            for i in range(len(input_metadata.extend_seq_lens)):
+            extend_seq_lens_cpu = input_metadata.extend_seq_lens.cpu().numpy()
+            for i in range(len(extend_seq_lens_cpu)):
                 if extend_seq_lens_cpu[i] == 0:
+                    prefill_top_logprobs.append([])
+                    decode_top_logprobs.append([])
                     continue
                 k = input_metadata.top_logprobs_nums[i]
                 t = all_logprobs[pt : pt + extend_seq_lens_cpu[i]].topk(k)
-                vs_cpu = t.values.cpu().tolist()
-                ps_cpu = t.indices.cpu().tolist()
+                vs_cpu = t.values.tolist()
+                ps_cpu = t.indices.tolist()
                 prefill_top_logprobs.append(
                     [list(zip(vs_cpu[j], ps_cpu[j])) for j in range(len(vs_cpu) - 1)]
                 )
                 decode_top_logprobs.append(list(zip(vs_cpu[-1], ps_cpu[-1])))
+                pt += extend_seq_lens_cpu[i]
             return prefill_top_logprobs, decode_top_logprobs
 
     def forward(self, input_ids, hidden_states, weight, input_metadata: InputMetadata):
@@ -99,20 +102,24 @@ class LogitsProcessor(nn.Module):
             all_logits = all_logits[:, : self.config.vocab_size]
 
         all_logprobs = all_logits.float()
-        all_logits = None
+        del all_logits
         all_logprobs[:] = torch.nn.functional.log_softmax(all_logprobs, dim=-1)
 
-        prefill_top_logprobs, decode_top_logprobs = self._get_top_logprobs(
-            all_logprobs, input_metadata
-        )
+        return_top_logprob = any(x > 0 for x in input_metadata.top_logprobs_nums)
+        if return_top_logprob:
+            prefill_top_logprobs, decode_top_logprobs = self._get_top_logprobs(
+                all_logprobs, input_metadata
+            )
+        else:
+            prefill_top_logprobs = decode_top_logprobs = None
 
         if input_metadata.forward_mode == ForwardMode.DECODE:
             last_logprobs = all_logprobs
             return last_logits, (
                 None,
                 None,
-                decode_top_logprobs,
                 None,
+                decode_top_logprobs,
                 last_logprobs,
             )
         else:
@@ -131,9 +138,9 @@ class LogitsProcessor(nn.Module):
             )
             return last_logits, (
                 prefill_token_logprobs,
+                normalized_prompt_logprobs,
                 prefill_top_logprobs,
                 decode_top_logprobs,
-                normalized_prompt_logprobs,
                 last_logprobs,
             )
 
diff --git a/python/sglang/srt/managers/io_struct.py b/python/sglang/srt/managers/io_struct.py
index 6e64380c9..a99498949 100644
--- a/python/sglang/srt/managers/io_struct.py
+++ b/python/sglang/srt/managers/io_struct.py
@@ -25,7 +25,6 @@ class GenerateReqInput:
     return_text_in_logprobs: bool = False
     # Whether to stream output
     stream: bool = False
-    # TODO: make all parameters a Union[List[T], T] to allow for batched requests
 
     def post_init(self):
         is_single = isinstance(self.text, str)
diff --git a/python/sglang/srt/managers/router/infer_batch.py b/python/sglang/srt/managers/router/infer_batch.py
index 3920fe039..1f655513a 100644
--- a/python/sglang/srt/managers/router/infer_batch.py
+++ b/python/sglang/srt/managers/router/infer_batch.py
@@ -1,5 +1,5 @@
 from dataclasses import dataclass
-from enum import Enum, auto
+from enum import IntEnum, auto
 from typing import List
 
 import numpy as np
@@ -9,15 +9,15 @@ from sglang.srt.managers.router.radix_cache import RadixCache
 from sglang.srt.memory_pool import ReqToTokenPool, TokenToKVPool
 
 
-class ForwardMode(Enum):
+class ForwardMode(IntEnum):
     PREFILL = auto()
     EXTEND = auto()
     DECODE = auto()
 
 
-class FinishReason(Enum):
-    LENGTH = auto()
+class FinishReason(IntEnum):
     EOS_TOKEN = auto()
+    LENGTH = auto()
     STOP_STR = auto()
 
 
@@ -31,6 +31,7 @@ class Req:
         # Since jump forward may retokenize the prompt with partial outputs,
         # we maintain the original prompt length to report the correct usage.
         self.prompt_tokens = len(input_ids)
+
         # The number of decoded tokens for token usage report. Note that
         # this does not include the jump forward tokens.
         self.completion_tokens_wo_jump_forward = 0
@@ -41,12 +42,11 @@ class Req:
         self.image_offset = 0
         self.pad_value = None
 
+        # Sampling parameters
         self.sampling_params = None
-        self.return_logprob = False
-        self.logprob_start_len = 0
-        self.top_logprobs_num = 0
         self.stream = False
 
+        # Check finish
         self.tokenizer = None
         self.finished = False
         self.finish_reason = None
@@ -56,13 +56,17 @@ class Req:
         self.prefix_indices = []
         self.last_node = None
 
+        # Logprobs
+        self.return_logprob = False
+        self.logprob_start_len = 0
+        self.top_logprobs_num = 0
+        self.normalized_prompt_logprob = None
         self.prefill_token_logprobs = None
         self.decode_token_logprobs = None
-        self.normalized_prompt_logprob = None
         self.prefill_top_logprobs = None
         self.decode_top_logprobs = None
 
-        # For constrained decoding
+        # Constrained decoding
         self.regex_fsm = None
         self.regex_fsm_state = 0
         self.jump_forward_map = None
@@ -165,8 +169,8 @@ class Batch:
     out_cache_cont_end: torch.Tensor = None
 
     # for processing logprobs
-    top_logprobs_nums: List[int] = None
     return_logprob: bool = False
+    top_logprobs_nums: List[int] = None
 
     # for multimodal
     pixel_values: List[torch.Tensor] = None
@@ -321,8 +325,8 @@ class Batch:
         )
 
         retracted_reqs = []
-        seq_lens_np = self.seq_lens.cpu().numpy()
-        req_pool_indices_np = self.req_pool_indices.cpu().numpy()
+        seq_lens_cpu = self.seq_lens.cpu().numpy()
+        req_pool_indices_cpu = self.req_pool_indices.cpu().numpy()
         while self.token_to_kv_pool.available_size() < len(self.reqs):
             idx = sorted_indices.pop()
             req = self.reqs[idx]
@@ -338,8 +342,8 @@ class Batch:
 
             # TODO: apply more fine-grained retraction
             token_indices = self.req_to_token_pool.req_to_token[
-                req_pool_indices_np[idx]
-            ][: seq_lens_np[idx]]
+                req_pool_indices_cpu[idx]
+            ][: seq_lens_cpu[idx]]
             self.token_to_kv_pool.dec_refs(token_indices)
 
         self.filter_batch(sorted_indices)
@@ -363,7 +367,7 @@ class Batch:
                     # insert the old request into tree_cache
                     token_ids_in_memory = tuple(req.input_ids + req.output_ids)[:-1]
                     if req_pool_indices_cpu is None:
-                        req_pool_indices_cpu = self.req_pool_indices.cpu().tolist()
+                        req_pool_indices_cpu = self.req_pool_indices.tolist()
                     req_pool_idx = req_pool_indices_cpu[i]
                     indices = self.req_to_token_pool.req_to_token[
                         req_pool_idx, : len(token_ids_in_memory)
diff --git a/python/sglang/srt/managers/router/model_rpc.py b/python/sglang/srt/managers/router/model_rpc.py
index f837d9029..44b9d0210 100644
--- a/python/sglang/srt/managers/router/model_rpc.py
+++ b/python/sglang/srt/managers/router/model_rpc.py
@@ -36,7 +36,9 @@ from sglang.srt.utils import (
     set_random_seed,
 )
 
+
 logger = logging.getLogger("model_rpc")
+vllm_default_logger.setLevel(logging.WARN)
 logging.getLogger("vllm.utils").setLevel(logging.WARN)
 
 
@@ -54,9 +56,6 @@ class ModelRpcServer:
         self.tp_size = server_args.tp_size
         self.schedule_heuristic = server_args.schedule_heuristic
         self.disable_regex_jump_forward = server_args.disable_regex_jump_forward
-        vllm_default_logger.setLevel(
-            level=getattr(logging, server_args.log_level.upper())
-        )
 
         # Init model and tokenizer
         self.model_config = ModelConfig(
@@ -65,7 +64,7 @@ class ModelRpcServer:
             context_length=server_args.context_length,
         )
 
-        # for model end global settings
+        # For model end global settings
         server_args_dict = {
             "enable_flashinfer": server_args.enable_flashinfer,
             "attention_reduce_in_fp32": server_args.attention_reduce_in_fp32,
@@ -164,7 +163,7 @@ class ModelRpcServer:
             logger.info("Cache flushed successfully!")
         else:
             warnings.warn(
-                "Cache not flushed because there are pending requests. "
+                f"Cache not flushed because there are pending requests. "
                 f"#queue-req: {len(self.forward_queue)}, "
                 f"#running-req: {0 if self.running_batch is None else len(self.running_batch.reqs)}"
             )
@@ -386,12 +385,12 @@ class ModelRpcServer:
                 f"#running_req: {running_req}. "
                 f"tree_cache_hit_rate: {100.0 * tree_cache_hit_rate:.2f}%."
             )
-            logger.debug(
-                f"fsm_cache_hit_rate: {100.0 * self.regex_fsm_cache.get_cache_hit_rate():.2f}%. "
-                f"fsm_cache_avg_init_time: {self.regex_fsm_cache.get_avg_init_time():.2f}s. "
-                f"ff_cache_hit_rate: {100.0 * self.jump_forward_cache.get_cache_hit_rate():.2f}%. "
-                f"ff_cache_avg_init_time: {self.jump_forward_cache.get_avg_init_time():.2f}s. "
-            )
+            #logger.debug(
+            #    f"fsm_cache_hit_rate: {100.0 * self.regex_fsm_cache.get_cache_hit_rate():.2f}%. "
+            #    f"fsm_cache_avg_init_time: {self.regex_fsm_cache.get_avg_init_time():.2f}s. "
+            #    f"ff_cache_hit_rate: {100.0 * self.jump_forward_cache.get_cache_hit_rate():.2f}%. "
+            #    f"ff_cache_avg_init_time: {self.jump_forward_cache.get_avg_init_time():.2f}s. "
+            #)
 
             new_batch = Batch.init_new(
                 can_run_list,
@@ -408,47 +407,41 @@ class ModelRpcServer:
             self.model_config.vocab_size, self.int_token_logit_bias
         )
 
-        prefill_token_logprobs = None
         if batch.extend_num_tokens != 0:
             # Forward
             logits, (
                 prefill_token_logprobs,
+                normalized_prompt_logprobs,
                 prefill_top_logprobs,
                 decode_top_logprobs,
-                normalized_prompt_logprobs,
                 last_logprobs,
             ) = self.model_runner.forward(batch, ForwardMode.EXTEND)
             if prefill_token_logprobs is not None:
-                prefill_token_logprobs = prefill_token_logprobs.cpu().tolist()
-                normalized_prompt_logprobs = normalized_prompt_logprobs.cpu().tolist()
+                prefill_token_logprobs = prefill_token_logprobs.tolist()
+                normalized_prompt_logprobs = normalized_prompt_logprobs.tolist()
 
             next_token_ids, _ = batch.sample(logits)
-            next_token_ids = next_token_ids.cpu().tolist()
+
+            # Only transfer the selected logprobs of the next token to CPU to reduce overhead.
+            if last_logprobs is not None:
+                last_token_logprobs = (
+                    last_logprobs[torch.arange(len(batch.reqs)), next_token_ids].tolist()
+                )
+
+            next_token_ids = next_token_ids.tolist()
         else:
             next_token_ids = [self.tokenizer.eos_token_id] * len(batch.reqs)
-            (
-                logits,
-                prefill_token_logprobs,
-                normalized_prompt_logprobs,
-                last_logprobs,
-            ) = (None,) * 4
-
-        # Only batch transfer the selected logprobs of the next token to CPU to reduce overhead.
-        reqs = batch.reqs
-        last_token_logprobs = None
-        if last_logprobs is not None:
-            last_token_logprobs = (
-                last_logprobs[torch.arange(len(reqs)), next_token_ids].cpu().tolist()
-            )
 
         # Check finish condition
         pt = 0
-        for i, req in enumerate(reqs):
+        for i, req in enumerate(batch.reqs):
             req.completion_tokens_wo_jump_forward += 1
             req.output_ids = [next_token_ids[i]]
             req.check_finished()
 
-            if prefill_token_logprobs is not None:
+            if req.return_logprob:
+                req.normalized_prompt_logprob = normalized_prompt_logprobs[i]
+
                 # If logprob_start_len > 0, then first logprob_start_len prompt tokens will be ignored.
                 req.prefill_token_logprobs = list(
                     zip(
@@ -463,12 +456,14 @@ class ModelRpcServer:
                 req.decode_token_logprobs = [
                     (last_token_logprobs[i], next_token_ids[i])
                 ]
+
+            if req.top_logprobs_num > 0:
                 req.prefill_top_logprobs = prefill_top_logprobs[i]
                 if req.logprob_start_len == 0:
                     req.prefill_top_logprobs = [None] + req.prefill_top_logprobs
                 req.decode_top_logprobs = [decode_top_logprobs[i]]
-                req.normalized_prompt_logprob = normalized_prompt_logprobs[i]
 
-            pt += req.extend_input_len
+
+            pt += req.extend_input_len
 
         self.handle_finished_requests(batch)
@@ -520,29 +515,29 @@ class ModelRpcServer:
         logits, (
             _,
             _,
-            decode_top_logprobs,
             _,
+            decode_top_logprobs,
             last_logprobs,
         ) = self.model_runner.forward(batch, ForwardMode.DECODE)
         next_token_ids, _ = batch.sample(logits)
-        next_token_ids = next_token_ids.cpu().tolist()
+        next_token_ids = next_token_ids.tolist()
 
         # Only batch transfer the selected logprobs of the next token to CPU to reduce overhead.
-        reqs = batch.reqs
-        new_token_logprobs = None
         if last_logprobs is not None:
             new_token_logprobs = last_logprobs[
-                torch.arange(len(reqs)), next_token_ids
+                torch.arange(len(batch.reqs)), next_token_ids
             ].tolist()
 
         # Check finish condition
-        for i, (req, next_token_id) in enumerate(zip(reqs, next_token_ids)):
+        for i, (req, next_token_id) in enumerate(zip(batch.reqs, next_token_ids)):
             req.completion_tokens_wo_jump_forward += 1
             req.output_ids.append(next_token_id)
             req.check_finished()
 
-            if new_token_logprobs is not None:
+            if req.return_logprob:
                 req.decode_token_logprobs.append((new_token_logprobs[i], next_token_id))
+
+            if req.top_logprobs_num > 0:
                 req.decode_top_logprobs.append(decode_top_logprobs[i])
 
         self.handle_finished_requests(batch)
@@ -590,8 +585,7 @@ class ModelRpcServer:
                     + len(req.output_ids)
                     - req.prompt_tokens,
                     "completion_tokens_wo_jump_forward": req.completion_tokens_wo_jump_forward,
-                    "finish_reason": str(req.finish_reason),
-                    "hit_stop_str": req.hit_stop_str,
+                    "finish_reason": str(req.finish_reason),  # FIXME: convert to the correct string
                 }
                 if req.return_logprob:
                     (
@@ -628,7 +622,7 @@ class ModelRpcServer:
         # Remove finished reqs
         if finished_indices:
             # Update radix cache
-            req_pool_indices_cpu = batch.req_pool_indices.cpu().tolist()
+            req_pool_indices_cpu = batch.req_pool_indices.tolist()
             for i in finished_indices:
                 req = batch.reqs[i]
                 req_pool_idx = req_pool_indices_cpu[i]
diff --git a/python/sglang/srt/managers/router/model_runner.py b/python/sglang/srt/managers/router/model_runner.py
index 8d6851caf..d24db31c7 100644
--- a/python/sglang/srt/managers/router/model_runner.py
+++ b/python/sglang/srt/managers/router/model_runner.py
@@ -29,7 +29,7 @@ QUANTIZATION_CONFIG_MAPPING = {
 logger = logging.getLogger("model_runner")
 
 # for server args in model endpoints
-global_server_args_dict: dict = None
+global_server_args_dict = {}
 
 
 @lru_cache()
@@ -86,8 +86,8 @@ class InputMetadata:
     out_cache_cont_end: torch.Tensor = None
 
     other_kv_index: torch.Tensor = None
-    top_logprobs_nums: List[int] = None
    return_logprob: bool = False
+    top_logprobs_nums: List[int] = None
 
     # for flashinfer
     qo_indptr: torch.Tensor = None
@@ -107,18 +107,20 @@
             (self.batch_size + 1,), dtype=torch.int32, device="cuda"
         )
         self.kv_indptr[1:] = torch.cumsum(self.seq_lens, dim=0)
+        self.kv_last_page_len = torch.ones(
+            (self.batch_size,), dtype=torch.int32, device="cuda"
+        )
+        req_pool_indices_cpu = self.req_pool_indices.cpu().numpy()
+        seq_lens_cpu = self.seq_lens.cpu().numpy()
         self.kv_indices = torch.cat(
             [
                 self.req_to_token_pool.req_to_token[
-                    self.req_pool_indices[i].item(), : self.seq_lens[i].item()
+                    req_pool_indices_cpu[i], : seq_lens_cpu[i]
                 ]
                 for i in range(self.batch_size)
             ],
             dim=0,
         ).contiguous()
-        self.kv_last_page_len = torch.ones(
-            (self.batch_size,), dtype=torch.int32, device="cuda"
-        )
 
         workspace_buffer = torch.empty(
             32 * 1024 * 1024, dtype=torch.int8, device="cuda"
@@ -195,15 +197,15 @@ class InputMetadata:
                 req_pool_indices[0], seq_lens[0] - 1
             ].item()
         else:
-            seq_lens_np = seq_lens.cpu().numpy()
-            prefix_lens_np = prefix_lens.cpu().numpy()
-            position_ids_offsets_np = position_ids_offsets.cpu().numpy()
+            seq_lens_cpu = seq_lens.cpu().numpy()
+            prefix_lens_cpu = prefix_lens.cpu().numpy()
+            position_ids_offsets_cpu = position_ids_offsets.cpu().numpy()
             positions = torch.tensor(
                 np.concatenate(
                     [
                         np.arange(
-                            prefix_lens_np[i] + position_ids_offsets_np[i],
-                            seq_lens_np[i] + position_ids_offsets_np[i],
+                            prefix_lens_cpu[i] + position_ids_offsets_cpu[i],
+                            seq_lens_cpu[i] + position_ids_offsets_cpu[i],
                         )
                         for i in range(batch_size)
                     ],
@@ -229,9 +231,9 @@ class InputMetadata:
             out_cache_loc=out_cache_loc,
             out_cache_cont_start=out_cache_cont_start,
             out_cache_cont_end=out_cache_cont_end,
-            top_logprobs_nums=top_logprobs_nums,
-            return_logprob=return_logprob,
             other_kv_index=other_kv_index,
+            return_logprob=return_logprob,
+            top_logprobs_nums=top_logprobs_nums,
         )
 
         if forward_mode == ForwardMode.EXTEND:
diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py
index 78241b74b..a95606de1 100644
--- a/python/sglang/srt/managers/tokenizer_manager.py
+++ b/python/sglang/srt/managers/tokenizer_manager.py
@@ -185,7 +185,10 @@ class TokenizerManager:
 
             while True:
                 await event.wait()
-                yield state.out_list[-1]
+                yield self.convert_logprob_style(state.out_list[-1],
+                    obj.return_logprob,
+                    obj.top_logprobs_num,
+                    obj.return_text_in_logprobs)
                 state.out_list = []
                 if state.finished:
                     del self.rid_to_state[rid]
@@ -231,16 +234,16 @@ class TokenizerManager:
                 rid = obj.rid[i]
                 state = self.rid_to_state[rid]
                 await state.event.wait()
-                output_list.append(state.out_list[-1])
+                output_list.append(
+                    self.convert_logprob_style(state.out_list[-1],
+                        obj.return_logprob[i],
+                        obj.top_logprobs_num[i],
+                        obj.return_text_in_logprobs))
                 assert state.finished
                 del self.rid_to_state[rid]
 
             yield output_list
 
-    async def detokenize(self, obj: DetokenizeReqInput):
-        token_texts = self.tokenizer.convert_ids_to_tokens(obj.input_ids)
-        return [t.decode() if isinstance(t, bytes) else t for t in token_texts]
-
     async def flush_cache(self):
         flush_cache_req = FlushCacheReq()
         self.send_to_router.send_pyobj(flush_cache_req)
@@ -267,3 +270,37 @@ class TokenizerManager:
                 state.event.set()
             else:
                 raise ValueError(f"Invalid object: {recv_obj}")
+
+    def convert_logprob_style(self, ret, return_logprob, top_logprobs_num, return_text_in_logprobs):
+        if return_logprob:
+            ret["meta_info"]["prefill_token_logprobs"] = self.detokenize_logprob_tokens(
+                ret["meta_info"]["prefill_token_logprobs"], return_text_in_logprobs
+            )
+            ret["meta_info"]["decode_token_logprobs"] = self.detokenize_logprob_tokens(
+                ret["meta_info"]["decode_token_logprobs"], return_text_in_logprobs
+            )
+        if top_logprobs_num > 0:
+            ret["meta_info"]["prefill_top_logprobs"] = self.detokenize_top_logprobs_tokens(
+                ret["meta_info"]["prefill_top_logprobs"], return_text_in_logprobs
+            )
+            ret["meta_info"]["decode_top_logprobs"] = self.detokenize_top_logprobs_tokens(
+                ret["meta_info"]["decode_top_logprobs"], return_text_in_logprobs
+            )
+        return ret
+
+    def detokenize_logprob_tokens(self, token_logprobs, decode_to_text):
+        if not decode_to_text:
+            return [(logprob, token_id, None) for logprob, token_id in token_logprobs]
+
+        token_ids = [tid for _, tid in token_logprobs]
+        token_texts = self.tokenizer.batch_decode(token_ids)
+        return [
+            (logprob, token_id, token_text)
+            for (logprob, token_id), token_text, in zip(token_logprobs, token_texts)
+        ]
+
+    def detokenize_top_logprobs_tokens(self, top_logprobs, decode_to_text):
+        for i, t in enumerate(top_logprobs):
+            if t:
+                top_logprobs[i] = self.detokenize_logprob_tokens(t, decode_to_text)
+        return top_logprobs
diff --git a/python/sglang/srt/openai_protocol.py b/python/sglang/srt/openai_protocol.py
index 1cf1fed73..95b92948d 100644
--- a/python/sglang/srt/openai_protocol.py
+++ b/python/sglang/srt/openai_protocol.py
@@ -1,3 +1,4 @@
+"""pydantic models for OpenAI API protocol"""
 import time
 from typing import Dict, List, Optional, Union
 
diff --git a/python/sglang/srt/server.py b/python/sglang/srt/server.py
index c1b7780cc..7a22632f4 100644
--- a/python/sglang/srt/server.py
+++ b/python/sglang/srt/server.py
@@ -10,7 +10,7 @@ import threading
 import time
 from typing import List, Optional, Union
 
-# Fix a Python bug
+# Fix a bug of Python threading
 setattr(threading, "_register_atexit", lambda *args, **kwargs: None)
 
 import aiohttp
@@ -53,10 +53,10 @@ from sglang.srt.managers.router.manager import start_router_process
 from sglang.srt.managers.tokenizer_manager import TokenizerManager
 from sglang.srt.server_args import PortArgs, ServerArgs
 from sglang.srt.utils import (
-    enable_show_time_cost,
     allocate_init_ports,
-    jsonify_pydantic_model,
     assert_pkg_version,
+    enable_show_time_cost,
+    jsonify_pydantic_model,
     get_exception_traceback,
     API_KEY_HEADER_NAME,
     APIKeyValidatorMiddleware
@@ -99,12 +99,6 @@ async def flush_cache():
     )
 
 
-async def stream_generator(obj: GenerateReqInput):
-    async for out in tokenizer_manager.generate_request(obj):
-        await handle_token_logprobs_results(obj, out)
-        yield out
-
-
 @app.post("/generate")
 async def generate_request(obj: GenerateReqInput):
     obj.post_init()
@@ -112,69 +106,16 @@ async def generate_request(obj: GenerateReqInput):
     if obj.stream:
 
         async def stream_results():
-            async for out in stream_generator(obj):
+            async for out in tokenizer_manager.generate_request(obj):
                 yield f"data: {json.dumps(out, ensure_ascii=False)}\n\n"
             yield "data: [DONE]\n\n"
 
         return StreamingResponse(stream_results(), media_type="text/event-stream")
 
     ret = await tokenizer_manager.generate_request(obj).__anext__()
-    await handle_token_logprobs_results(obj, ret)
-
     return ret
 
 
-async def detokenize_logprob_tokens(token_logprobs, decode_to_text):
-    if not decode_to_text:
-        return [(logprob, token_id, None) for logprob, token_id in token_logprobs]
-
-    token_ids = [tid for _, tid in token_logprobs]
-    token_texts = await tokenizer_manager.detokenize(DetokenizeReqInput(token_ids))
-    return [
-        (logprob, token_id, token_text)
-        for (logprob, token_id), token_text, in zip(token_logprobs, token_texts)
-    ]
-
-
-async def detokenize_top_logprobs_tokens(top_logprobs, decode_to_text):
-    for i, t in enumerate(top_logprobs):
-        if top_logprobs[i] is not None:
-            top_logprobs[i] = await detokenize_logprob_tokens(t, decode_to_text)
-    return top_logprobs
-
-
-async def handle_token_logprobs_results(obj: GenerateReqInput, ret):
-    """Handle the token logprobs results, convert token ids to text if needed.
-
-    Args:
-        obj (GenerateReqInput): The request object.
-        ret (Union[Dict, List[Dict]]): The response object.
-    """
-    # NOTE: This is because the multiple requests in one http request.
-
-    async def convert_style(r, return_text):
-        r["meta_info"]["prefill_token_logprobs"] = await detokenize_logprob_tokens(
-            r["meta_info"]["prefill_token_logprobs"], return_text
-        )
-        r["meta_info"]["decode_token_logprobs"] = await detokenize_logprob_tokens(
-            r["meta_info"]["decode_token_logprobs"], return_text
-        )
-        r["meta_info"]["prefill_top_logprobs"] = await detokenize_top_logprobs_tokens(
-            r["meta_info"]["prefill_top_logprobs"], return_text
-        )
-        r["meta_info"]["decode_top_logprobs"] = await detokenize_top_logprobs_tokens(
-            r["meta_info"]["decode_top_logprobs"], return_text
-        )
-
-    if isinstance(obj.text, str):
-        if obj.return_logprob:
-            await convert_style(ret, obj.return_text_in_logprobs)
-    else:
-        for i, r in enumerate(ret):
-            if obj.return_logprob[i]:
-                await convert_style(r, obj.return_text_in_logprobs)
-
-
 @app.post("/v1/completions")
 async def v1_completions(raw_request: Request):
     request_json = await raw_request.json()
@@ -203,10 +144,10 @@ async def v1_completions(raw_request: Request):
 
     if adapted_request.stream:
 
-        async def gnerate_stream_resp():
+        async def generate_stream_resp():
             stream_buffer = ""
             n_prev_token = 0
-            async for content in stream_generator(adapted_request):
+            async for content in tokenizer_manager.generate_request(adapted_request):
                 text = content["text"]
                 prompt_tokens = content["meta_info"]["prompt_tokens"]
                 completion_tokens = content["meta_info"]["completion_tokens"]
@@ -266,7 +207,7 @@ async def v1_completions(raw_request: Request):
             yield f"data: {jsonify_pydantic_model(chunk)}\n\n"
             yield "data: [DONE]\n\n"
 
-        return StreamingResponse(gnerate_stream_resp(), media_type="text/event-stream")
+        return StreamingResponse(generate_stream_resp(), media_type="text/event-stream")
 
     # Non-streaming response.
     ret = await generate_request(adapted_request)
@@ -384,7 +325,7 @@ async def v1_chat_completions(raw_request: Request):
 
         is_first = True
         stream_buffer = ""
-        async for content in stream_generator(adapted_request):
+        async for content in tokenizer_manager.generate_request(adapted_request):
             if is_first:
                 # First chunk with role
                 is_first = False
diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py
index 78b8537c4..0bbbb3e3c 100644
--- a/python/sglang/srt/server_args.py
+++ b/python/sglang/srt/server_args.py
@@ -241,7 +241,7 @@ class ServerArgs:
     def print_mode_args(self):
         return (
             f"enable_flashinfer={self.enable_flashinfer}, "
-            f"attention_reduce_in_fp32={self.attention_reduce_in_fp32}"
+            f"attention_reduce_in_fp32={self.attention_reduce_in_fp32}, "
             f"disable_radix_cache={self.disable_radix_cache}, "
             f"disable_regex_jump_forward={self.disable_regex_jump_forward}, "
             f"disable_disk_cache={self.disable_disk_cache}, "
diff --git a/python/sglang/srt/utils.py b/python/sglang/srt/utils.py
index 774bdf2c9..56f408db0 100644
--- a/python/sglang/srt/utils.py
+++ b/python/sglang/srt/utils.py
@@ -1,3 +1,5 @@
+"""Common utilities."""
+
 import base64
 import os
 import random
@@ -13,6 +15,7 @@ import numpy as np
 import pydantic
 import requests
 import torch
+from fastapi.responses import JSONResponse
 from packaging import version as pkg_version
 from pydantic import BaseModel
 from starlette.middleware.base import BaseHTTPMiddleware
@@ -303,6 +306,7 @@ class APIKeyValidatorMiddleware(BaseHTTPMiddleware):
         response = await call_next(request)
         return response
 
+
 # FIXME: Remove this once we drop support for pydantic 1.x
 IS_PYDANTIC_1 = int(pydantic.VERSION.split(".")[0]) == 1
 
@@ -310,4 +314,4 @@ IS_PYDANTIC_1 = int(pydantic.VERSION.split(".")[0]) == 1
 def jsonify_pydantic_model(obj: BaseModel):
     if IS_PYDANTIC_1:
         return obj.json(ensure_ascii=False)
-    return obj.model_dump_json()
+    return obj.model_dump_json()
\ No newline at end of file
diff --git a/python/sglang/test/test_programs.py b/python/sglang/test/test_programs.py
index 32c319166..853ed9846 100644
--- a/python/sglang/test/test_programs.py
+++ b/python/sglang/test/test_programs.py
@@ -296,7 +296,7 @@ def test_parallel_encoding(check_answer=True):
 def test_image_qa():
     @sgl.function
     def image_qa(s, question):
-        s += sgl.user(sgl.image("test_image.png") + question)
+        s += sgl.user(sgl.image("example_image.png") + question)
         s += sgl.assistant(sgl.gen("answer"))
 
     state = image_qa.run(
diff --git a/test/killall_python.sh b/test/killall_sglang.sh
similarity index 100%
rename from test/killall_python.sh
rename to test/killall_sglang.sh
diff --git a/test/lang/example_image.png b/test/lang/example_image.png
new file mode 100644
index 000000000..851d08560
Binary files /dev/null and b/test/lang/example_image.png differ
diff --git a/test/lang/test_openai_backend.py b/test/lang/test_openai_backend.py
index bb8b9f77e..90d097956 100644
--- a/test/lang/test_openai_backend.py
+++ b/test/lang/test_openai_backend.py
@@ -28,7 +28,7 @@ class TestOpenAIBackend(unittest.TestCase):
         if cls.backend is None:
             cls.backend = OpenAI("gpt-3.5-turbo-instruct")
             cls.chat_backend = OpenAI("gpt-3.5-turbo")
-            cls.chat_vision_backend = OpenAI("gpt-4-vision-preview")
+            cls.chat_vision_backend = OpenAI("gpt-4-turbo")
 
     def test_few_shot_qa(self):
         set_default_backend(self.backend)
@@ -88,14 +88,3 @@ if __name__ == "__main__":
     # t = TestOpenAIBackend()
     # t.setUp()
     # t.test_few_shot_qa()
-    # t.test_mt_bench()
-    # t.test_select()
-    # t.test_decode_int()
-    # t.test_decode_json()
-    # t.test_expert_answer()
-    # t.test_tool_use()
-    # t.test_react()
-    # t.test_parallel_decoding()
-    # t.test_parallel_encoding()
-    # t.test_image_qa()
-    # t.test_stream()
diff --git a/test/lang/test_openai_spec.py b/test/lang/test_openai_spec.py
deleted file mode 100644
index fdda19127..000000000
--- a/test/lang/test_openai_spec.py
+++ /dev/null
@@ -1,68 +0,0 @@
-from sglang import OpenAI, function, gen, set_default_backend
-
-
-@function()
-def gen_character_default(s):
-    s += "Construct a character within the following format:\n"
-    s += "Name: Steve Jobs.\nBirthday: February 24, 1955.\nJob: Apple CEO.\nWelcome.\n"
-    s += "\nPlease generate new Name, Birthday and Job.\n"
-    s += "Name:" + gen("name", stop="\n") + "\nBirthday:" + gen("birthday", stop="\n")
-    s += "\nJob:" + gen("job", stop="\n") + "\nWelcome.\n"
-
-
-@function(api_num_spec_tokens=512)
-def gen_character_spec(s):
-    s += "Construct a character within the following format:\n"
-    s += "Name: Steve Jobs.\nBirthday: February 24, 1955.\nJob: Apple CEO.\nWelcome.\n"
-    s += "\nPlease generate new Name, Birthday and Job.\n"
-    s += "Name:" + gen("name", stop="\n") + "\nBirthday:" + gen("birthday", stop="\n")
-    s += "\nJob:" + gen("job", stop="\n") + "\nWelcome.\n"
-
-
-@function(api_num_spec_tokens=512)
-def gen_character_no_stop(s):
-    s += "Construct a character within the following format:\n"
-    s += "Name: Steve Jobs.\nBirthday: February 24, 1955.\nJob: Apple CEO.\nWelcome.\n"
-    s += "\nPlease generate new Name, Birthday and Job.\n"
-    s += "Name:" + gen("name") + "\nBirthday:" + gen("birthday")
-    s += "\nJob:" + gen("job") + "\nWelcome.\n"
-
-
-@function(api_num_spec_tokens=512)
-def gen_character_multi_stop(s):
-    s += "Construct a character within the following format:\n"
-    s += (
-        "Name: Steve Jobs.###Birthday: February 24, 1955.###Job: Apple CEO.\nWelcome.\n"
-    )
-    s += "\nPlease generate new Name, Birthday and Job.\n"
-    s += "Name:" + gen("name", stop=["\n", "###"])
-    s += "###Birthday:" + gen("birthday", stop=["\n", "###"])
-    s += "###Job:" + gen("job", stop=["\n", "###"]) + "\nWelcome.\n"
-
-
-set_default_backend(OpenAI("gpt-3.5-turbo-instruct"))
-
-state = gen_character_default.run()
-print(state.text())
-
-print("=" * 60)
-
-state = gen_character_no_stop.run()
-
-print("name###", state["name"])
-print("birthday###:", state["birthday"])
-print("job###", state["job"])
-
-print("=" * 60)
-
-state = gen_character_multi_stop.run()
-print(state.text())
-
-print("=" * 60)
-
-state = gen_character_spec.run()
-print(state.text())
-
-print("name###", state["name"])
-print("birthday###", state["birthday"])
-print("job###", state["job"])
diff --git a/test/srt/example_image.png b/test/srt/example_image.png
new file mode 120000
index 000000000..c8a970edd
--- /dev/null
+++ b/test/srt/example_image.png
@@ -0,0 +1 @@
+../lang/example_image.png
\ No newline at end of file
diff --git a/test/srt/test_httpserver_decode.py b/test/srt/test_httpserver_decode.py
index 04897b398..7e169f3e4 100644
--- a/test/srt/test_httpserver_decode.py
+++ b/test/srt/test_httpserver_decode.py
@@ -3,7 +3,6 @@ Usage:
 python3 -m sglang.launch_server --model-path TinyLlama/TinyLlama-1.1B-Chat-v0.4 --port 30000
 
 python3 test_httpserver_decode.py
-
 Output:
 The capital of France is Paris.\nThe capital of the United States is Washington, D.C.\nThe capital of Canada is Ottawa.\nThe capital of Japan is Tokyo
 """
@@ -23,6 +22,7 @@ def test_decode(url, return_logprob, top_logprobs_num, return_text):
             "temperature": 0,
             "max_new_tokens": 32,
         },
+        "stream": False,
        "return_logprob": return_logprob,
        "top_logprobs_num": top_logprobs_num,
        "return_text_in_logprobs": return_text,
diff --git a/test/srt/test_httpserver_decode_stream.py b/test/srt/test_httpserver_decode_stream.py
index 7c2b5da1e..38f090b7d 100644
--- a/test/srt/test_httpserver_decode_stream.py
+++ b/test/srt/test_httpserver_decode_stream.py
@@ -26,6 +26,7 @@ def test_decode_stream(url, return_logprob, top_logprobs_num):
             "return_logprob": return_logprob,
             "top_logprobs_num": top_logprobs_num,
             "return_text_in_logprobs": True,
+            "logprob_start_len": 0,
         },
         stream=True,
     )
diff --git a/test/srt/test_httpserver_llava.py b/test/srt/test_httpserver_llava.py
index 0f6571b45..e3cf1b799 100644
--- a/test/srt/test_httpserver_llava.py
+++ b/test/srt/test_httpserver_llava.py
@@ -34,7 +34,7 @@ async def test_concurrent(args):
         url + "/generate",
         {
             "text": "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions. USER: <image>\nDescribe this picture ASSISTANT:",
-            "image_data": "test_image.png",
+            "image_data": "example_image.png",
             "sampling_params": {
                 "temperature": 0,
                 "max_new_tokens": 16,
@@ -55,7 +55,7 @@ def test_streaming(args):
         url + "/generate",
         json={
             "text": "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions. USER: <image>\nDescribe this picture ASSISTANT:",
-            "image_data": "test_image.png",
+            "image_data": "example_image.png",
             "sampling_params": {
                 "temperature": 0,
                 "max_new_tokens": 128,
diff --git a/test/srt/test_httpserver_reuse.py b/test/srt/test_httpserver_reuse.py
index c3f786589..36804e4b7 100644
--- a/test/srt/test_httpserver_reuse.py
+++ b/test/srt/test_httpserver_reuse.py
@@ -6,10 +6,10 @@ The capital of France is Paris.\nThe capital of the United States is Washington,
 """
 
 import argparse
-import time
 
 import requests
 
+
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
     parser.add_argument("--host", type=str, default="http://127.0.0.1")
diff --git a/test/srt/test_openai_server.py b/test/srt/test_openai_server.py
index 4cc50af85..569eaccb7 100644
--- a/test/srt/test_openai_server.py
+++ b/test/srt/test_openai_server.py
@@ -163,7 +163,7 @@ def test_regex(args):
 
     regex = (
         r"""\{\n"""
         + r"""    "name": "[\w]+",\n"""
-        + r"""    "population": "[\w\d\s]+"\n"""
+        + r"""    "population": [\w\d\s]+\n"""
         + r"""\}"""
     )