From 9ba1f0976035fe7212002cac3b2b9df9f0685334 Mon Sep 17 00:00:00 2001 From: Lianmin Zheng Date: Sun, 15 Sep 2024 06:36:06 -0700 Subject: [PATCH] [Fix] Fix logprob and normalized_logprob (#1428) --- .github/workflows/pr-test.yml | 25 +++- docs/en/backend.md | 2 +- python/sglang/bench_latency.py | 3 + .../sglang/lang/backend/runtime_endpoint.py | 7 +- python/sglang/lang/interpreter.py | 2 +- python/sglang/srt/layers/attention_backend.py | 4 + python/sglang/srt/layers/logits_processor.py | 137 ++++++++---------- python/sglang/srt/managers/io_struct.py | 3 +- python/sglang/srt/managers/schedule_batch.py | 94 +++++++----- .../sglang/srt/managers/tokenizer_manager.py | 6 - python/sglang/srt/managers/tp_worker.py | 68 ++++++--- .../srt/model_executor/cuda_graph_runner.py | 2 +- .../srt/model_executor/forward_batch_info.py | 26 +--- python/sglang/srt/openai_api/adapter.py | 6 +- .../srt/sampling/sampling_batch_info.py | 32 ++-- test/srt/run_suite.py | 12 +- test/srt/test_chunked_prefill.py | 4 +- test/srt/test_json_constrained.py | 9 +- test/srt/test_openai_server.py | 24 ++- test/srt/test_srt_endpoint.py | 60 ++++++-- test/srt/test_torchao.py | 2 +- test/srt/test_triton_attention_kernels.py | 1 - 22 files changed, 314 insertions(+), 215 deletions(-) diff --git a/.github/workflows/pr-test.yml b/.github/workflows/pr-test.yml index a6602ea8b..fa0fefdaf 100644 --- a/.github/workflows/pr-test.yml +++ b/.github/workflows/pr-test.yml @@ -54,7 +54,7 @@ jobs: timeout-minutes: 20 run: | cd test/srt - python3 run_suite.py --suite minimal --range-begin 0 --range-end 8 + python3 run_suite.py --suite minimal --range-begin 0 --range-end 7 unit-test-backend-part-2: if: github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request' @@ -73,7 +73,26 @@ jobs: timeout-minutes: 20 run: | cd test/srt - python3 run_suite.py --suite minimal --range-begin 8 + python3 run_suite.py --suite minimal --range-begin 7 --range-end 14 + + unit-test-backend-part-3: + if: github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request' + runs-on: 1-gpu-runner + steps: + - name: Checkout code + uses: actions/checkout@v3 + + - name: Install dependencies + run: | + pip install --upgrade pip + pip install -e "python[dev]" + pip install flashinfer -i https://flashinfer.ai/whl/cu121/torch2.4/ --force-reinstall + + - name: Run test + timeout-minutes: 20 + run: | + cd test/srt + python3 run_suite.py --suite minimal --range-begin 14 performance-test-1-gpu-part-1: if: github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request' @@ -217,7 +236,7 @@ jobs: finish: needs: [ - unit-test-frontend, unit-test-backend-part-1, unit-test-backend-part-2, + unit-test-frontend, unit-test-backend-part-1, unit-test-backend-part-2, unit-test-backend-part-3, performance-test-1-gpu-part-1, performance-test-1-gpu-part-2, performance-test-2-gpu, accuracy-test-1-gpu, accuracy-test-2-gpu ] diff --git a/docs/en/backend.md b/docs/en/backend.md index e93ce5543..d8c3c7fb1 100644 --- a/docs/en/backend.md +++ b/docs/en/backend.md @@ -91,7 +91,7 @@ python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3-8B-Instruct # Node 1 python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3-8B-Instruct --tp 4 --nccl-init sgl-dev-0:50000 --nnodes 2 --node-rank 1 ``` - + ### Supported Models **Generative Models** diff --git a/python/sglang/bench_latency.py b/python/sglang/bench_latency.py index 93fcc0115..670d85f48 100644 --- a/python/sglang/bench_latency.py +++ b/python/sglang/bench_latency.py @@ -164,6 
+164,7 @@ def prepare_inputs_for_correctness_test(bench_args, tokenizer): req.prefix_indices = [] req.sampling_params = sampling_params req.fill_ids = req.origin_input_ids + req.extend_input_len = len(req.fill_ids) - len(req.prefix_indices) reqs.append(req) return input_ids, reqs @@ -178,6 +179,7 @@ def prepare_extend_inputs_for_correctness_test( req.prefix_indices = model_runner.req_to_token_pool.req_to_token[ i, : bench_args.cut_len ] + req.extend_input_len = len(req.fill_ids) - len(req.prefix_indices) return reqs @@ -194,6 +196,7 @@ def prepare_synthetic_inputs_for_latency_test(batch_size, input_len): req.prefix_indices = [] req.sampling_params = sampling_params req.fill_ids = req.origin_input_ids + req.extend_input_len = len(req.fill_ids) - len(req.prefix_indices) reqs.append(req) return reqs diff --git a/python/sglang/lang/backend/runtime_endpoint.py b/python/sglang/lang/backend/runtime_endpoint.py index 344b51d2d..e1194b6cf 100644 --- a/python/sglang/lang/backend/runtime_endpoint.py +++ b/python/sglang/lang/backend/runtime_endpoint.py @@ -239,9 +239,12 @@ class RuntimeEndpoint(BaseBackend): # Compute logprob data = { "text": [s.text_ + c for c in choices], - "sampling_params": {"max_new_tokens": 0}, + "sampling_params": { + "max_new_tokens": 0, + "temperature": 0, + }, "return_logprob": True, - "logprob_start_len": max(prompt_len - 2, 0), + "logprob_start_len": max(prompt_len - 2, 0), # for token healing } obj = self._generate_http_request(s, data) diff --git a/python/sglang/lang/interpreter.py b/python/sglang/lang/interpreter.py index 2f8ea7e78..0f1a5c9f2 100644 --- a/python/sglang/lang/interpreter.py +++ b/python/sglang/lang/interpreter.py @@ -9,7 +9,7 @@ import uuid import warnings from concurrent.futures import ThreadPoolExecutor from contextlib import contextmanager -from typing import Any, Callable, Dict, List, Optional, Union +from typing import Any, Callable, Dict, List, Optional import tqdm diff --git a/python/sglang/srt/layers/attention_backend.py b/python/sglang/srt/layers/attention_backend.py index 835664fb6..b66fc9342 100644 --- a/python/sglang/srt/layers/attention_backend.py +++ b/python/sglang/srt/layers/attention_backend.py @@ -56,6 +56,7 @@ class AttentionBackend(ABC): raise NotImplementedError() def get_cuda_graph_seq_len_fill_value(self): + """Get the fill value for padded seq lens. Typically, it is 0 or 1.""" raise NotImplementedError() def forward(self, q, k, v, layer: nn.Module, input_metadata: InputMetadata): @@ -66,9 +67,11 @@ class AttentionBackend(ABC): return self.forward_extend(q, k, v, layer, input_metadata) def forward_decode(self, q, k, v, layer: nn.Module, input_metadata: InputMetadata): + """Run a forward for decode.""" raise NotImplementedError() def forward_extend(self, q, k, v, layer: nn.Module, input_metadata: InputMetadata): + """Run a forward for extend.""" raise NotImplementedError() @@ -299,6 +302,7 @@ class FlashInferAttnBackend(AttentionBackend): ) if total_num_tokens >= global_config.layer_sync_threshold: + # TODO: Revisit this. Why is this synchronize needed? torch.cuda.synchronize() return o.view(-1, layer.tp_q_head_num * layer.head_dim) diff --git a/python/sglang/srt/layers/logits_processor.py b/python/sglang/srt/layers/logits_processor.py index 72a926cab..440c96392 100644 --- a/python/sglang/srt/layers/logits_processor.py +++ b/python/sglang/srt/layers/logits_processor.py @@ -37,7 +37,7 @@ class LogitsProcessorOutput: # The normlaized logprobs of prompts. 
shape: [#seq] normalized_prompt_logprobs: torch.Tensor - # The logprobs of input tokens. shape: [#token, vocab_size] + # The logprobs of input tokens. shape: [#token, vocab_size] input_token_logprobs: torch.Tensor # The logprob and id of the top-k tokens in input positions. shape [#seq, #token, k] of Tuple(logprob, token_id) @@ -49,25 +49,39 @@ class LogitsProcessorOutput: @dataclasses.dataclass class LogitsMetadata: forward_mode: ForwardMode + top_logprobs_nums: Optional[List[int]] + return_logprob: bool = False + return_top_logprob: bool = False extend_seq_lens: Optional[torch.Tensor] = None - extend_start_loc: Optional[torch.Tensor] = None - top_logprobs_nums: Optional[List[int]] = None + extend_seq_lens_cpu: Optional[List[int]] = None - extend_seq_lens_cpu: List[int] = None - logprob_start_lens_cpu: List[int] = None + extend_logprob_start_lens_cpu: Optional[List[int]] = None + extend_logprob_pruned_lens_cpu: Optional[List[int]] = None @classmethod def from_input_metadata(cls, input_metadata: InputMetadata): + return_top_logprob = any(x > 0 for x in input_metadata.top_logprobs_nums) + if input_metadata.forward_mode.is_extend(): + extend_logprob_pruned_lens_cpu = [ + extend_len - start_len + for extend_len, start_len in zip( + input_metadata.extend_seq_lens, + input_metadata.extend_logprob_start_lens_cpu, + ) + ] + else: + extend_logprob_pruned_lens_cpu = None return cls( forward_mode=input_metadata.forward_mode, - extend_seq_lens=input_metadata.extend_seq_lens, - extend_start_loc=input_metadata.extend_start_loc, - return_logprob=input_metadata.return_logprob, top_logprobs_nums=input_metadata.top_logprobs_nums, + return_logprob=input_metadata.return_logprob, + return_top_logprob=return_top_logprob, + extend_seq_lens=input_metadata.extend_seq_lens, extend_seq_lens_cpu=input_metadata.extend_seq_lens_cpu, - logprob_start_lens_cpu=input_metadata.logprob_start_lens_cpu, + extend_logprob_start_lens_cpu=input_metadata.extend_logprob_start_lens_cpu, + extend_logprob_pruned_lens_cpu=extend_logprob_pruned_lens_cpu, ) @@ -82,57 +96,49 @@ class LogitsProcessor(nn.Module): def _get_normalized_prompt_logprobs( self, input_token_logprobs: torch.Tensor, - cum_start_len0: torch.Tensor, - cum_start_len1: torch.Tensor, logits_metadata: LogitsMetadata, ): logprobs_cumsum = torch.cumsum(input_token_logprobs, dim=0, dtype=torch.float32) + pruned_lens = torch.tensor( + logits_metadata.extend_logprob_pruned_lens_cpu, device="cuda" + ) - start = logits_metadata.extend_start_loc.clone() - cum_start_len0 - end = start + logits_metadata.extend_seq_lens - 2 - cum_start_len1 - start.clamp_(min=0, max=input_token_logprobs.shape[0] - 1) - end.clamp_(min=0, max=input_token_logprobs.shape[0] - 1) + start = torch.zeros_like(pruned_lens) + start[1:] = torch.cumsum(pruned_lens[:-1], dim=0) + end = torch.clamp( + start + pruned_lens - 2, min=0, max=logprobs_cumsum.shape[0] - 1 + ) sum_logp = ( logprobs_cumsum[end] - logprobs_cumsum[start] + input_token_logprobs[start] ) - normalized_prompt_logprobs = sum_logp / ( - (logits_metadata.extend_seq_lens - 1).clamp(min=1) - ) - + normalized_prompt_logprobs = sum_logp / (pruned_lens - 1).clamp(min=1) return normalized_prompt_logprobs @staticmethod def get_top_logprobs(all_logprobs: torch.Tensor, logits_metadata: LogitsMetadata): + max_k = max(logits_metadata.top_logprobs_nums) + ret = all_logprobs.topk(max_k, dim=1) + values = ret.values.tolist() + indices = ret.indices.tolist() + if logits_metadata.forward_mode.is_decode(): output_top_logprobs = [] - max_k = 
max(logits_metadata.top_logprobs_nums) - ret = all_logprobs.topk(max_k, dim=1) - values = ret.values.tolist() - indices = ret.indices.tolist() for i, k in enumerate(logits_metadata.top_logprobs_nums): output_top_logprobs.append(list(zip(values[i][:k], indices[i][:k]))) return None, output_top_logprobs else: - # TODO: vectorize the code below input_top_logprobs, output_top_logprobs = [], [] + pt = 0 - extend_seq_lens_cpu = logits_metadata.extend_seq_lens_cpu - - max_k = max(logits_metadata.top_logprobs_nums) - ret = all_logprobs.topk(max_k, dim=1) - values = ret.values.tolist() - indices = ret.indices.tolist() - - for i, extend_seq_len in enumerate(extend_seq_lens_cpu): - start_len = logits_metadata.logprob_start_lens_cpu[i] - pruned_len = extend_seq_len - start_len - - if extend_seq_len == 0: + for k, pruned_len in zip( + logits_metadata.top_logprobs_nums, + logits_metadata.extend_logprob_pruned_lens_cpu, + ): + if pruned_len <= 0: input_top_logprobs.append([]) output_top_logprobs.append([]) continue - k = logits_metadata.top_logprobs_nums[i] input_top_logprobs.append( [ list(zip(values[pt + j][:k], indices[pt + j][:k])) @@ -167,10 +173,7 @@ class LogitsProcessor(nn.Module): last_index = None last_hidden = hidden_states else: - last_index = ( - torch.cumsum(logits_metadata.extend_seq_lens, dim=0, dtype=torch.long) - - 1 - ) + last_index = torch.cumsum(logits_metadata.extend_seq_lens, dim=0) - 1 last_hidden = hidden_states[last_index] last_logits = torch.matmul(last_hidden, weight.T) @@ -194,21 +197,15 @@ class LogitsProcessor(nn.Module): output_top_logprobs=None, ) else: - # When logprob is requested, compute the logits for all tokens. - if logits_metadata.forward_mode.is_decode(): - last_logprobs = torch.nn.functional.log_softmax(last_logits, dim=-1) + last_logprobs = torch.nn.functional.log_softmax(last_logits, dim=-1) - # Get the logprob of top-k tokens - return_top_logprob = any( - x > 0 for x in logits_metadata.top_logprobs_nums - ) - if return_top_logprob: + if logits_metadata.forward_mode.is_decode(): + if logits_metadata.return_top_logprob: output_top_logprobs = self.get_top_logprobs( last_logprobs, logits_metadata )[1] else: output_top_logprobs = None - return LogitsProcessorOutput( next_token_logits=last_logits, next_token_logprobs=last_logprobs, @@ -218,22 +215,18 @@ class LogitsProcessor(nn.Module): output_top_logprobs=output_top_logprobs, ) else: + # Slice the requested tokens to compute logprob pt, states, pruned_input_ids = 0, [], [] - for i, extend_len in enumerate(logits_metadata.extend_seq_lens_cpu): - start_len = logits_metadata.logprob_start_lens_cpu[i] + for start_len, extend_len in zip( + logits_metadata.extend_logprob_start_lens_cpu, + logits_metadata.extend_seq_lens_cpu, + ): states.append(hidden_states[pt + start_len : pt + extend_len]) pruned_input_ids.append(input_ids[pt + start_len : pt + extend_len]) pt += extend_len + # Compute the logits and logprobs for all required tokens states = torch.cat(states, dim=0) - pruned_input_ids = torch.cat(pruned_input_ids, dim=0) - - cum_start_len1 = torch.tensor( - logits_metadata.logprob_start_lens_cpu, device="cuda" - ).cumsum(0) - cum_start_len0 = torch.zeros_like(cum_start_len1) - cum_start_len0[1:] = cum_start_len1[:-1] - all_logits = torch.matmul(states, weight.T) if self.do_tensor_parallel_all_gather: all_logits = tensor_model_parallel_all_gather(all_logits) @@ -249,35 +242,29 @@ class LogitsProcessor(nn.Module): all_logprobs[:] = torch.nn.functional.log_softmax(all_logprobs, dim=-1) # Get the logprob of top-k tokens - 
return_top_logprob = any( - x > 0 for x in logits_metadata.top_logprobs_nums - ) - if return_top_logprob: + if logits_metadata.return_top_logprob: input_top_logprobs, output_top_logprobs = self.get_top_logprobs( all_logprobs, logits_metadata ) else: input_top_logprobs = output_top_logprobs = None - last_logprobs = all_logprobs[last_index - cum_start_len1] - - # Compute the logprobs and normalized logprobs for the prefill tokens. - # Note that we pad a zero at the end of each sequence for easy computation. + # Compute the normalized logprobs for the requested tokens. + # Note that we pad a zero at the end for easy batching. input_token_logprobs = all_logprobs[ torch.arange(all_logprobs.shape[0], device="cuda"), - torch.cat([pruned_input_ids[1:], torch.tensor([0], device="cuda")]), + torch.cat( + [ + torch.cat(pruned_input_ids)[1:], + torch.tensor([0], device="cuda"), + ] + ), ] - normalized_prompt_logprobs = self._get_normalized_prompt_logprobs( input_token_logprobs, - cum_start_len0, - cum_start_len1, logits_metadata, ) - # Remove the last token logprob for the prefill tokens. - input_token_logprobs = input_token_logprobs[:-1] - return LogitsProcessorOutput( next_token_logits=last_logits, next_token_logprobs=last_logprobs, diff --git a/python/sglang/srt/managers/io_struct.py b/python/sglang/srt/managers/io_struct.py index abd10a9f1..08e43ea08 100644 --- a/python/sglang/srt/managers/io_struct.py +++ b/python/sglang/srt/managers/io_struct.py @@ -20,7 +20,7 @@ processes (TokenizerManager, DetokenizerManager, Controller). import copy import uuid -from dataclasses import dataclass, field +from dataclasses import dataclass from typing import Dict, List, Optional, Union from sglang.srt.managers.schedule_batch import BaseFinishReason @@ -43,6 +43,7 @@ class GenerateReqInput: # Whether to return logprobs. return_logprob: Optional[Union[List[bool], bool]] = None # If return logprobs, the start location in the prompt for returning logprobs. + # By default, this value is "-1", which means it will only return logprobs for output tokens. logprob_start_len: Optional[Union[List[int], int]] = None # If return logprobs, the number of top logprobs to return at each position. top_logprobs_num: Optional[Union[List[int], int]] = None diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py index 13339cddc..cd6d148a7 100644 --- a/python/sglang/srt/managers/schedule_batch.py +++ b/python/sglang/srt/managers/schedule_batch.py @@ -19,7 +19,7 @@ limitations under the License. 
import logging from dataclasses import dataclass -from typing import TYPE_CHECKING, List, Optional, Union +from typing import List, Optional, Tuple, Union import torch @@ -53,7 +53,7 @@ class BaseFinishReason: self.is_error = is_error def to_json(self): - raise NotImplementedError("Subclasses must implement this method") + raise NotImplementedError() class FINISH_MATCHED_TOKEN(BaseFinishReason): @@ -105,7 +105,13 @@ class FINISH_ABORT(BaseFinishReason): class Req: """Store all inforamtion of a request.""" - def __init__(self, rid, origin_input_text, origin_input_ids, lora_path=None): + def __init__( + self, + rid: str, + origin_input_text: str, + origin_input_ids: Tuple[int], + lora_path: Optional[str] = None, + ): # Input and output info self.rid = rid self.origin_input_text = origin_input_text @@ -118,6 +124,10 @@ class Req: # Memory info self.req_pool_idx = None + # Check finish + self.tokenizer = None + self.finished_reason = None + # For incremental decoding # ----- | --------- read_ids -------| # ----- | surr_ids | @@ -136,7 +146,7 @@ class Req: # this does not include the jump forward tokens. self.completion_tokens_wo_jump_forward = 0 - # For vision input + # For vision inputs self.pixel_values = None self.image_sizes = None self.image_offsets = None @@ -144,31 +154,35 @@ class Req: self.modalities = None # Prefix info - self.extend_input_len = 0 self.prefix_indices = [] + self.extend_input_len = 0 self.last_node = None # Sampling parameters self.sampling_params = None self.stream = False - # Check finish - self.tokenizer = None - self.finished_reason = None - - # Logprobs + # Logprobs (arguments) self.return_logprob = False - self.embedding = None self.logprob_start_len = 0 self.top_logprobs_num = 0 + + # Logprobs (return value) self.normalized_prompt_logprob = None self.input_token_logprobs = None self.input_top_logprobs = None self.output_token_logprobs = [] self.output_top_logprobs = [] + + # Logprobs (internal values) # The tokens is prefilled but need to be considered as decode tokens # and should be updated for the decode logprobs self.last_update_decode_tokens = 0 + # The relative logprob_start_len in an extend batch + self.extend_logprob_start_len = 0 + + # Embedding + self.embedding = None # Constrained decoding self.regex_fsm: RegexGuide = None @@ -363,9 +377,13 @@ class ScheduleBatch: return_logprob: bool = False top_logprobs_nums: List[int] = None + # Stream + has_stream: bool = False + @classmethod def init_new(cls, reqs, req_to_token_pool, token_to_kv_pool, tree_cache): return_logprob = any(req.return_logprob for req in reqs) + has_stream = any(req.stream for req in reqs) return cls( reqs=reqs, @@ -373,18 +391,15 @@ class ScheduleBatch: token_to_kv_pool=token_to_kv_pool, tree_cache=tree_cache, return_logprob=return_logprob, + has_stream=has_stream, ) def batch_size(self): - return len(self.reqs) if self.reqs else 0 + return len(self.reqs) def is_empty(self): return len(self.reqs) == 0 - def has_stream(self) -> bool: - # Return whether batch has at least 1 streaming request - return any(r.stream for r in self.reqs) - def alloc_req_slots(self, num_reqs): req_pool_indices = self.req_to_token_pool.alloc(num_reqs) if req_pool_indices is None: @@ -427,8 +442,8 @@ class ScheduleBatch: for i, req in enumerate(reqs): req.req_pool_idx = req_pool_indices_cpu[i] pre_len, seq_len = len(req.prefix_indices), len(req.fill_ids) - ext_len = seq_len - pre_len seq_lens.append(seq_len) + assert seq_len - pre_len == req.extend_input_len if pre_len > 0: 
self.req_to_token_pool.req_to_token[req.req_pool_idx][ @@ -436,9 +451,19 @@ class ScheduleBatch: ] = req.prefix_indices self.req_to_token_pool.req_to_token[req.req_pool_idx][pre_len:seq_len] = ( - out_cache_loc[pt : pt + ext_len] + out_cache_loc[pt : pt + req.extend_input_len] ) - pt += ext_len + + # Compute the relative logprob_start_len in an extend batch + if req.logprob_start_len >= pre_len: + extend_logprob_start_len = min( + req.logprob_start_len - pre_len, req.extend_input_len - 1 + ) + else: + extend_logprob_start_len = req.extend_input_len - 1 + + req.extend_logprob_start_len = extend_logprob_start_len + pt += req.extend_input_len # Set fields with torch.device("cuda"): @@ -451,21 +476,13 @@ class ScheduleBatch: self.out_cache_loc = out_cache_loc self.top_logprobs_nums = [r.top_logprobs_num for r in reqs] self.prefix_lens_cpu = [len(r.prefix_indices) for r in reqs] - + self.extend_lens_cpu = [r.extend_input_len for r in reqs] + self.extend_logprob_start_lens_cpu = [r.extend_logprob_start_len for r in reqs] self.sampling_info = SamplingBatchInfo.from_schedule_batch(self, vocab_size) def mix_with_running(self, running_batch: "ScheduleBatch"): self.forward_mode = ForwardMode.MIXED - self.running_bs = running_batch.batch_size() - - # NOTE: prefix_indices is what has been cached, but we don't cache each decode step - prefix_lens_cpu = [len(r.prefix_indices) for r in self.reqs] - prefix_lens_cpu.extend( - [ - len(r.origin_input_ids) + len(r.output_ids) - 1 - for r in running_batch.reqs - ] - ) + running_bs = running_batch.batch_size() for req in running_batch.reqs: req.fill_ids = req.origin_input_ids + req.output_ids @@ -473,12 +490,22 @@ class ScheduleBatch: input_ids = torch.cat([self.input_ids, running_batch.input_ids]) out_cache_loc = torch.cat([self.out_cache_loc, running_batch.out_cache_loc]) - extend_num_tokens = self.extend_num_tokens + running_batch.batch_size() + extend_num_tokens = self.extend_num_tokens + running_bs + self.merge(running_batch) self.input_ids = input_ids self.out_cache_loc = out_cache_loc self.extend_num_tokens = extend_num_tokens - self.prefix_lens_cpu = prefix_lens_cpu + + # NOTE: prefix_indices is what has been cached, but we don't cache each decode step + self.prefix_lens_cpu.extend( + [ + len(r.origin_input_ids) + len(r.output_ids) - 1 + for r in running_batch.reqs + ] + ) + self.extend_lens_cpu.extend([1] * running_bs) + self.extend_logprob_start_lens_cpu.extend([0] * running_bs) def check_decode_mem(self): bs = self.batch_size() @@ -685,6 +712,7 @@ class ScheduleBatch: self.out_cache_loc = None self.top_logprobs_nums = [self.top_logprobs_nums[i] for i in unfinished_indices] self.return_logprob = any(req.return_logprob for req in self.reqs) + self.has_stream = any(req.stream for req in self.reqs) self.sampling_info.filter(unfinished_indices, new_indices) @@ -695,7 +723,6 @@ class ScheduleBatch: self.sampling_info.merge(other.sampling_info) self.reqs.extend(other.reqs) - self.req_pool_indices = torch.concat( [self.req_pool_indices, other.req_pool_indices] ) @@ -706,3 +733,4 @@ class ScheduleBatch: self.out_cache_loc = None self.top_logprobs_nums.extend(other.top_logprobs_nums) self.return_logprob = any(req.return_logprob for req in self.reqs) + self.has_stream = any(req.stream for req in self.reqs) diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py index 54b08b337..84cd39469 100644 --- a/python/sglang/srt/managers/tokenizer_manager.py +++ b/python/sglang/srt/managers/tokenizer_manager.py @@ 
-197,8 +197,6 @@ class TokenizerManager: if not_use_index else obj.logprob_start_len[index] ) - if return_logprob and logprob_start_len == -1: - logprob_start_len = len(input_ids) - 1 top_logprobs_num = ( obj.top_logprobs_num if not_use_index @@ -251,8 +249,6 @@ class TokenizerManager: # Send to the controller if self.is_generation: - if return_logprob and logprob_start_len == -1: - logprob_start_len = len(input_ids) - 1 tokenized_obj = TokenizedGenerateReqInput( rid, input_text, @@ -349,8 +345,6 @@ class TokenizerManager: sampling_params = self._get_sampling_params(obj.sampling_params[index]) if self.is_generation: - if obj.return_logprob[index] and obj.logprob_start_len[index] == -1: - obj.logprob_start_len[index] = len(input_ids) - 1 pixel_values, image_hashes, image_sizes = ( await self._get_pixel_values(obj.image_data[index]) ) diff --git a/python/sglang/srt/managers/tp_worker.py b/python/sglang/srt/managers/tp_worker.py index 3952e081a..09a2ede21 100644 --- a/python/sglang/srt/managers/tp_worker.py +++ b/python/sglang/srt/managers/tp_worker.py @@ -278,7 +278,7 @@ class ModelTpServer: self.running_batch = None break - if self.out_pyobjs and self.running_batch.has_stream(): + if self.out_pyobjs and self.running_batch.has_stream: break else: self.check_memory() @@ -360,9 +360,13 @@ class ModelTpServer: # Only when pixel values is not None we have modalities req.modalities = recv_req.modalites req.return_logprob = recv_req.return_logprob - req.logprob_start_len = recv_req.logprob_start_len req.top_logprobs_num = recv_req.top_logprobs_num req.stream = recv_req.stream + req.logprob_start_len = recv_req.logprob_start_len + + if req.logprob_start_len == -1: + # By default, only return the logprobs for output tokens + req.logprob_start_len = len(recv_req.input_ids) - 1 # Init regex FSM if ( @@ -384,7 +388,7 @@ class ModelTpServer: # Truncate prompts that are too long if len(req.origin_input_ids) >= self.max_req_input_len: - logger.warn( + logger.warning( "Request length is longer than the KV cache pool size or " "the max context length. Truncated!!!" ) @@ -583,7 +587,7 @@ class ModelTpServer: next_token_ids = [self.tokenizer.eos_token_id] * len(batch.reqs) # Check finish conditions - pt = 0 + logprob_pt = 0 for i, req in enumerate(batch.reqs): if req is not self.current_inflight_req: # Inflight reqs' prefill is not finished @@ -607,10 +611,9 @@ class ModelTpServer: self.req_to_token_pool.free(req.req_pool_idx) if req.return_logprob: - self.add_logprob_return_values( - i, req, pt, next_token_ids, logits_output + logprob_pt += self.add_logprob_return_values( + i, req, logprob_pt, next_token_ids, logits_output ) - pt += req.extend_input_len else: assert batch.extend_num_tokens != 0 logits_output = self.model_runner.forward(batch) @@ -638,48 +641,63 @@ class ModelTpServer: def add_logprob_return_values( self, - i, + i: int, req: Req, pt: int, next_token_ids: List[int], output: LogitsProcessorOutput, ): + """Attach logprobs to the return values.""" + req.output_token_logprobs.append( + (output.next_token_logprobs[i], next_token_ids[i]) + ) + + # If logprob_start_len > 0, then first logprob_start_len prompt tokens will be ignored. + num_input_logprobs = req.extend_input_len - req.extend_logprob_start_len + if req.normalized_prompt_logprob is None: req.normalized_prompt_logprob = output.normalized_prompt_logprobs[i] if req.input_token_logprobs is None: - # If logprob_start_len > 0, then first logprob_start_len prompt tokens will be ignored. 
- req.input_token_logprobs = list( - zip( - output.input_token_logprobs[pt : pt + req.extend_input_len - 1], - req.fill_ids[-req.extend_input_len + 1 :], - ) - ) - if req.logprob_start_len == 0: + input_token_logprobs = output.input_token_logprobs[ + pt : pt + num_input_logprobs - 1 - req.last_update_decode_tokens + ] + input_token_ids = req.fill_ids[ + len(req.fill_ids) + - num_input_logprobs + + 1 : len(req.fill_ids) + - req.last_update_decode_tokens + ] + req.input_token_logprobs = list(zip(input_token_logprobs, input_token_ids)) + + if ( + req.logprob_start_len == 0 + ): # The first token does not have logprob, pad it. req.input_token_logprobs = [ (None, req.fill_ids[0]) ] + req.input_token_logprobs if req.last_update_decode_tokens != 0: + # Some decode tokens are re-computed in an extend batch req.output_token_logprobs.extend( list( zip( output.input_token_logprobs[ pt - + req.extend_input_len + + num_input_logprobs + - 1 - req.last_update_decode_tokens : pt - + req.extend_input_len + + num_input_logprobs - 1 ], - req.fill_ids[-req.last_update_decode_tokens + 1 :], + req.fill_ids[ + len(req.fill_ids) + - req.last_update_decode_tokens : len(req.fill_ids) + ], ) ) ) - req.output_token_logprobs.append( - (output.next_token_logprobs[i], next_token_ids[i]) - ) - if req.top_logprobs_num > 0: if req.input_top_logprobs is None: req.input_top_logprobs = output.input_top_logprobs[i] @@ -688,10 +706,12 @@ class ModelTpServer: if req.last_update_decode_tokens != 0: req.output_top_logprobs.extend( - output.input_top_logprobs[i][-req.last_update_decode_tokens + 1 :] + output.input_top_logprobs[i][-req.last_update_decode_tokens :] ) req.output_top_logprobs.append(output.output_top_logprobs[i]) + return num_input_logprobs + def forward_decode_batch(self, batch: ScheduleBatch): # Check if decode out of memory if not batch.check_decode_mem(): diff --git a/python/sglang/srt/model_executor/cuda_graph_runner.py b/python/sglang/srt/model_executor/cuda_graph_runner.py index d25694329..66d0b5c53 100644 --- a/python/sglang/srt/model_executor/cuda_graph_runner.py +++ b/python/sglang/srt/model_executor/cuda_graph_runner.py @@ -193,7 +193,7 @@ class CudaGraphRunner: attn_backend=self.model_runner.attn_backend, out_cache_loc=out_cache_loc, return_logprob=False, - top_logprobs_nums=0, + top_logprobs_nums=[0] * bs, positions=(seq_lens - 1 + position_ids_offsets).to(torch.int64), ) return forward(input_ids, input_metadata.positions, input_metadata) diff --git a/python/sglang/srt/model_executor/forward_batch_info.py b/python/sglang/srt/model_executor/forward_batch_info.py index 0ad568860..4815fbc56 100644 --- a/python/sglang/srt/model_executor/forward_batch_info.py +++ b/python/sglang/srt/model_executor/forward_batch_info.py @@ -81,7 +81,7 @@ class InputMetadata: return_logprob: bool = False top_logprobs_nums: List[int] = None extend_seq_lens_cpu: List[int] = None - logprob_start_lens_cpu: List[int] = None + extend_logprob_start_lens_cpu: List[int] = None # For multimodal pixel_values: List[torch.Tensor] = None @@ -138,27 +138,13 @@ class InputMetadata: self.positions = self.positions.to(torch.int64) def compute_extend_infos(self, batch: ScheduleBatch): - extend_lens_cpu = [ - len(r.fill_ids) - batch.prefix_lens_cpu[i] for i, r in enumerate(batch.reqs) - ] - self.extend_seq_lens = torch.tensor(extend_lens_cpu, device="cuda") + self.extend_seq_lens = torch.tensor(batch.extend_lens_cpu, device="cuda") self.extend_prefix_lens = torch.tensor(batch.prefix_lens_cpu, device="cuda") - self.extend_start_loc = 
torch.zeros_like(self.seq_lens) + self.extend_start_loc = torch.zeros_like(self.extend_seq_lens) self.extend_start_loc[1:] = torch.cumsum(self.extend_seq_lens[:-1], dim=0) - self.extend_no_prefix = all(l == 0 for l in batch.prefix_lens_cpu) - - self.extend_seq_lens_cpu = extend_lens_cpu - self.logprob_start_lens_cpu = [ - ( - min( - req.logprob_start_len - batch.prefix_lens_cpu[i], - extend_lens_cpu[i] - 1, - ) - if req.logprob_start_len >= batch.prefix_lens_cpu[i] - else extend_lens_cpu[i] - 1 # Fake extend, actually decode - ) - for i, req in enumerate(batch.reqs) - ] + self.extend_no_prefix = all(x == 0 for x in batch.prefix_lens_cpu) + self.extend_seq_lens_cpu = batch.extend_lens_cpu + self.extend_logprob_start_lens_cpu = batch.extend_logprob_start_lens_cpu @classmethod def from_schedule_batch( diff --git a/python/sglang/srt/openai_api/adapter.py b/python/sglang/srt/openai_api/adapter.py index 1b8169af6..a4869f5cc 100644 --- a/python/sglang/srt/openai_api/adapter.py +++ b/python/sglang/srt/openai_api/adapter.py @@ -22,7 +22,7 @@ import os import time import uuid from http import HTTPStatus -from typing import Dict, List, Optional +from typing import Dict, List from fastapi import HTTPException, Request, UploadFile from fastapi.responses import JSONResponse, StreamingResponse @@ -472,7 +472,7 @@ def v1_generate_request( first_prompt_type = type(all_requests[0].prompt) for request in all_requests: assert ( - type(request.prompt) == first_prompt_type + type(request.prompt) is first_prompt_type ), "All prompts must be of the same type in file input settings" if len(all_requests) > 1 and request.n > 1: raise ValueError( @@ -887,7 +887,7 @@ def v1_chat_generate_request( input_ids.append(prompt_ids) return_logprobs.append(request.logprobs) logprob_start_lens.append(-1) - top_logprobs_nums.append(request.top_logprobs) + top_logprobs_nums.append(request.top_logprobs or 0) sampling_params = { "temperature": request.temperature, diff --git a/python/sglang/srt/sampling/sampling_batch_info.py b/python/sglang/srt/sampling/sampling_batch_info.py index 8eb8e0882..b73a15a91 100644 --- a/python/sglang/srt/sampling/sampling_batch_info.py +++ b/python/sglang/srt/sampling/sampling_batch_info.py @@ -86,24 +86,24 @@ class SamplingBatchInfo: @classmethod def from_schedule_batch(cls, batch: ScheduleBatch, vocab_size: int): - device = "cuda" reqs = batch.reqs ret = cls(vocab_size=vocab_size) - ret.temperatures = torch.tensor( - [r.sampling_params.temperature for r in reqs], - dtype=torch.float, - device=device, - ).view(-1, 1) - ret.top_ps = torch.tensor( - [r.sampling_params.top_p for r in reqs], dtype=torch.float, device=device - ) - ret.top_ks = torch.tensor( - [r.sampling_params.top_k for r in reqs], dtype=torch.int, device=device - ) - ret.min_ps = torch.tensor( - [r.sampling_params.min_p for r in reqs], dtype=torch.float, device=device - ) + with torch.device("cuda"): + ret.temperatures = torch.tensor( + [r.sampling_params.temperature for r in reqs], + dtype=torch.float, + ).view(-1, 1) + ret.top_ps = torch.tensor( + [r.sampling_params.top_p for r in reqs], dtype=torch.float + ) + ret.top_ks = torch.tensor( + [r.sampling_params.top_k for r in reqs], dtype=torch.int + ) + ret.min_ps = torch.tensor( + [r.sampling_params.min_p for r in reqs], dtype=torch.float + ) + ret.need_min_p_sampling = any(r.sampling_params.min_p > 0 for r in reqs) # Each penalizers will do nothing if they evaluate themselves as not required by looking at @@ -116,7 +116,7 @@ class SamplingBatchInfo: ret.penalizer_orchestrator = 
penaltylib.BatchedPenalizerOrchestrator( vocab_size=vocab_size, batch=batch, - device=device, + device="cuda", Penalizers={ penaltylib.BatchedFrequencyPenalizer, penaltylib.BatchedMinNewTokensPenalizer, diff --git a/test/srt/run_suite.py b/test/srt/run_suite.py index 943c50144..9210be948 100644 --- a/test/srt/run_suite.py +++ b/test/srt/run_suite.py @@ -11,16 +11,18 @@ suites = { "test_chunked_prefill.py", "test_embedding_openai_server.py", "test_eval_accuracy_mini.py", + "test_json_constrained.py", "test_large_max_new_tokens.py", "test_openai_server.py", - "test_json_constrained.py", - "test_skip_tokenizer_init.py", - "test_torch_compile.py", - "test_triton_attn_backend.py", "test_pytorch_sampling_backend.py", + "test_server_args.py", + "test_skip_tokenizer_init.py", + "test_srt_endpoint.py", + "test_torch_compile.py", + "test_torchao.py", + "test_triton_attn_backend.py", "test_update_weights.py", "test_vision_openai_server.py", - "test_server_args.py", ], "sampling/penaltylib": glob.glob( "sampling/penaltylib/**/test_*.py", recursive=True diff --git a/test/srt/test_chunked_prefill.py b/test/srt/test_chunked_prefill.py index 2eb704dc9..057f42f05 100644 --- a/test/srt/test_chunked_prefill.py +++ b/test/srt/test_chunked_prefill.py @@ -33,13 +33,13 @@ class TestChunkedPrefill(unittest.TestCase): base_url=base_url, model=model, eval_name="mmlu", - num_examples=32, + num_examples=64, num_threads=32, ) try: metrics = run_eval(args) - assert metrics["score"] >= 0.6 + assert metrics["score"] >= 0.65 finally: kill_child_process(process.pid) diff --git a/test/srt/test_json_constrained.py b/test/srt/test_json_constrained.py index 122d79968..d3abc70a4 100644 --- a/test/srt/test_json_constrained.py +++ b/test/srt/test_json_constrained.py @@ -17,7 +17,6 @@ class TestJSONConstrained(unittest.TestCase): def setUpClass(cls): cls.model = DEFAULT_MODEL_NAME_FOR_TEST cls.base_url = DEFAULT_URL_FOR_TEST - cls.api_key = "sk-123456" cls.json_schema = json.dumps( { "type": "object", @@ -28,16 +27,13 @@ class TestJSONConstrained(unittest.TestCase): "required": ["name", "population"], } ) - cls.process = popen_launch_server( - cls.model, cls.base_url, timeout=300, api_key=cls.api_key - ) + cls.process = popen_launch_server(cls.model, cls.base_url, timeout=300) @classmethod def tearDownClass(cls): kill_child_process(cls.process.pid) def run_decode(self, return_logprob=False, top_logprobs_num=0, n=1): - headers = {"Authorization": f"Bearer {self.api_key}"} response = requests.post( self.base_url + "/generate", json={ @@ -54,7 +50,6 @@ class TestJSONConstrained(unittest.TestCase): "top_logprobs_num": top_logprobs_num, "logprob_start_len": 0, }, - headers=headers, ) print(json.dumps(response.json())) print("=" * 100) @@ -69,7 +64,7 @@ class TestJSONConstrained(unittest.TestCase): self.run_decode() def test_json_openai(self): - client = openai.Client(api_key=self.api_key, base_url=f"{self.base_url}/v1") + client = openai.Client(api_key="EMPTY", base_url=f"{self.base_url}/v1") response = client.chat.completions.create( model=self.model, diff --git a/test/srt/test_openai_server.py b/test/srt/test_openai_server.py index 3fc578551..87d85c0cd 100644 --- a/test/srt/test_openai_server.py +++ b/test/srt/test_openai_server.py @@ -75,11 +75,11 @@ class TestOpenAIServer(unittest.TestCase): assert isinstance(response.choices[0].logprobs.top_logprobs[1], dict) ret_num_top_logprobs = len(response.choices[0].logprobs.top_logprobs[1]) - # FIXME: Sometimes, some top_logprobs are missing in the return value. 
The reason is that some out_put id maps to the same output token and duplicate in the map + # FIXME: Sometimes, some top_logprobs are missing in the return value. The reason is that some output id maps to the same output token and duplicate in the map # assert ret_num_top_logprobs == logprobs, f"{ret_num_top_logprobs} vs {logprobs}" - assert ret_num_top_logprobs > 0 - assert response.choices[0].logprobs.token_logprobs[0] != None + + assert response.choices[0].logprobs.token_logprobs[0] assert response.id assert response.created @@ -143,7 +143,7 @@ class TestOpenAIServer(unittest.TestCase): ret_num_top_logprobs = len( response.choices[0].logprobs.top_logprobs[0] ) - # FIXME: Sometimes, some top_logprobs are missing in the return value. The reason is that some out_put id maps to the same output token and duplicate in the map + # FIXME: Sometimes, some top_logprobs are missing in the return value. The reason is that some output id maps to the same output token and duplicate in the map # assert ret_num_top_logprobs == logprobs, f"{ret_num_top_logprobs} vs {logprobs}" assert ret_num_top_logprobs > 0 @@ -479,6 +479,22 @@ class TestOpenAIServer(unittest.TestCase): assert isinstance(js_obj["name"], str) assert isinstance(js_obj["population"], int) + def test_penalty(self): + client = openai.Client(api_key=self.api_key, base_url=self.base_url) + + response = client.chat.completions.create( + model=self.model, + messages=[ + {"role": "system", "content": "You are a helpful AI assistant"}, + {"role": "user", "content": "Introduce the capital of France."}, + ], + temperature=0, + max_tokens=32, + frequency_penalty=1.0, + ) + text = response.choices[0].message.content + assert isinstance(text, str) + if __name__ == "__main__": unittest.main() diff --git a/test/srt/test_srt_endpoint.py b/test/srt/test_srt_endpoint.py index 818aae215..9a0a37c60 100644 --- a/test/srt/test_srt_endpoint.py +++ b/test/srt/test_srt_endpoint.py @@ -1,3 +1,7 @@ +""" +python3 -m unittest test_srt_endpoint.TestSRTEndpoint.test_simple_decode +""" + import json import unittest @@ -39,7 +43,7 @@ class TestSRTEndpoint(unittest.TestCase): "text": "The capital of France is", "sampling_params": { "temperature": 0 if n == 1 else 0.5, - "max_new_tokens": 32, + "max_new_tokens": 16, "n": n, }, "stream": stream, @@ -56,7 +60,8 @@ class TestSRTEndpoint(unittest.TestCase): for line in response.iter_lines(): if line.startswith(b"data: ") and line[6:] != b"[DONE]": response_json.append(json.loads(line[6:])) - print(json.dumps(response_json)) + + print(json.dumps(response_json, indent=2)) print("=" * 100) def test_simple_decode(self): @@ -69,13 +74,50 @@ class TestSRTEndpoint(unittest.TestCase): self.run_decode(n=3, stream=True) def test_logprob(self): - for top_logprobs_num in [0, 3]: - for return_text in [True, False]: - self.run_decode( - return_logprob=True, - top_logprobs_num=top_logprobs_num, - return_text=return_text, - ) + self.run_decode( + return_logprob=True, + top_logprobs_num=5, + return_text=True, + ) + + def test_logprob_start_len(self): + logprob_start_len = 4 + new_tokens = 4 + prompts = [ + "I have a very good idea on", + "Today is a sunndy day and", + ] + + response = requests.post( + self.base_url + "/generate", + json={ + "text": prompts, + "sampling_params": { + "temperature": 0, + "max_new_tokens": new_tokens, + }, + "return_logprob": True, + "top_logprobs_num": 5, + "return_text_in_logprobs": True, + "logprob_start_len": logprob_start_len, + }, + ) + response_json = response.json() + print(json.dumps(response_json, 
indent=2))
+
+        for i, res in enumerate(response_json):
+            assert res["meta_info"]["prompt_tokens"] == logprob_start_len + 1 + len(
+                res["meta_info"]["input_token_logprobs"]
+            )
+            assert prompts[i].endswith(
+                "".join([x[-1] for x in res["meta_info"]["input_token_logprobs"]])
+            )
+
+            assert res["meta_info"]["completion_tokens"] == new_tokens
+            assert len(res["meta_info"]["output_token_logprobs"]) == new_tokens
+            assert res["text"] == "".join(
+                [x[-1] for x in res["meta_info"]["output_token_logprobs"]]
+            )
 
 
 if __name__ == "__main__":
diff --git a/test/srt/test_torchao.py b/test/srt/test_torchao.py
index d2084e7d5..8b5ce58ed 100644
--- a/test/srt/test_torchao.py
+++ b/test/srt/test_torchao.py
@@ -39,7 +39,7 @@ class TestTorchCompile(unittest.TestCase):
         )
 
         metrics = run_eval(args)
-        assert metrics["score"] >= 0.65
+        assert metrics["score"] >= 0.60
 
     def run_decode(self, max_new_tokens):
         response = requests.post(
diff --git a/test/srt/test_triton_attention_kernels.py b/test/srt/test_triton_attention_kernels.py
index 79b26f67a..b312a8c30 100644
--- a/test/srt/test_triton_attention_kernels.py
+++ b/test/srt/test_triton_attention_kernels.py
@@ -127,7 +127,6 @@ class TestExtendAttention(unittest.TestCase):
 
     def _test_context_attention_once(self, head_dim):
         # Set up a simple test case
-        batch_size = 2
         num_heads = 4
         seq_lens = [8, 12]
         max_seq_len = max(seq_lens)