diff --git a/python/sglang/srt/layers/logits_processor.py b/python/sglang/srt/layers/logits_processor.py
index 0dbbc31da..10c11f659 100644
--- a/python/sglang/srt/layers/logits_processor.py
+++ b/python/sglang/srt/layers/logits_processor.py
@@ -14,28 +14,11 @@ class LogitsProcessor(nn.Module):
         self.tp_size = get_tensor_model_parallel_world_size()
 
     def forward(self, input_ids, hidden_states, weight, input_metadata):
-        if not input_metadata.return_logprob:
-            if input_metadata.forward_mode == ForwardMode.DECODE:
-                last_hidden = hidden_states
-            else:
-                last_index = (
-                    torch.cumsum(
-                        input_metadata.seq_lens - input_metadata.prefix_lens,
-                        dim=0,
-                        dtype=torch.long,
-                    )
-                    - 1
-                )
-                last_hidden = hidden_states[last_index]
-                hidden_states = None
+        last_index = None
 
-            last_logits = torch.matmul(last_hidden, weight.T)
-            if self.tp_size > 1:
-                last_logits = tensor_model_parallel_all_gather(last_logits)
-            last_logits = last_logits[:, : self.config.vocab_size]
-            return last_logits, (None, None)
-        else:
-            assert input_metadata.forward_mode != ForwardMode.DECODE
+        # Compute the last index (the first decode token) of each request
+        # if we are in prefill or extend mode.
+        if input_metadata.forward_mode != ForwardMode.DECODE:
             last_index = (
                 torch.cumsum(
                     input_metadata.seq_lens - input_metadata.prefix_lens,
@@ -45,29 +28,54 @@ class LogitsProcessor(nn.Module):
                 - 1
             )
 
+        if not input_metadata.return_logprob:
+            # When logprob is not requested, only compute the last logits.
+            if input_metadata.forward_mode == ForwardMode.DECODE:
+                last_hidden = hidden_states
+            else:
+                last_hidden = hidden_states[last_index]
+                hidden_states = None
+
+            last_logits = torch.matmul(last_hidden, weight.T)
+            if self.tp_size > 1:
+                last_logits = tensor_model_parallel_all_gather(last_logits)
+            last_logits = last_logits[:, : self.config.vocab_size]
+            return last_logits, (None, None, None)
+        else:
+            # When logprob is requested, compute the logits for all tokens.
             logits = torch.matmul(hidden_states, weight.T)
             if self.tp_size > 1:
                 logits = tensor_model_parallel_all_gather(logits)
             logits = logits[:, : self.config.vocab_size]
             all_logprobs = torch.log(torch.softmax(logits.float(), dim=-1) + 1e-6)
 
-            logprobs = all_logprobs[
-                torch.arange(all_logprobs.shape[0], device="cuda"),
-                torch.cat([input_ids[1:], torch.tensor([0], device="cuda")]),
-            ]
-            logprobs_cumsum = torch.cumsum(logprobs, dim=0, dtype=torch.float32)
+            if input_metadata.forward_mode == ForwardMode.DECODE:
+                last_logits = logits
+                last_logprobs = all_logprobs
+                prefill_logprobs = normalized_logprobs = None
+            else:
+                # Compute the logprobs for the last token of each request.
+                last_logits = logits[last_index]
+                last_logprobs = all_logprobs[last_index]
 
-            start = input_metadata.extend_start_loc.clone()
-            end = start + input_metadata.extend_seq_lens - 2
-            start.clamp_(min=0, max=logprobs.shape[0] - 1)
-            end.clamp_(min=0, max=logprobs.shape[0] - 1)
-            sum_logp = logprobs_cumsum[end] - logprobs_cumsum[start] + logprobs[start]
-            normalized_logprobs = sum_logp / (
-                (input_metadata.extend_seq_lens - 1).clamp(min=1)
-            )
+                # Compute the logprobs and normalized logprobs for the prefill tokens.
+                # Note that we pad a zero at the end of each sequence for easy computation.
+                prefill_logprobs = all_logprobs[
+                    torch.arange(all_logprobs.shape[0], device="cuda"),
+                    torch.cat([input_ids[1:], torch.tensor([0], device="cuda")]),
+                ]
+                logprobs_cumsum = torch.cumsum(prefill_logprobs, dim=0, dtype=torch.float32)
 
-            last_logits = logits[last_index]
-            return last_logits, (logprobs, normalized_logprobs)
+                start = input_metadata.extend_start_loc.clone()
+                end = start + input_metadata.extend_seq_lens - 2
+                start.clamp_(min=0, max=prefill_logprobs.shape[0] - 1)
+                end.clamp_(min=0, max=prefill_logprobs.shape[0] - 1)
+                sum_logp = logprobs_cumsum[end] - logprobs_cumsum[start] + prefill_logprobs[start]
+                normalized_logprobs = sum_logp / (
+                    (input_metadata.extend_seq_lens - 1).clamp(min=1)
+                )
+
+            return last_logits, (prefill_logprobs, normalized_logprobs, last_logprobs)
 
 
 if __name__ == "__main__":
diff --git a/python/sglang/srt/managers/io_struct.py b/python/sglang/srt/managers/io_struct.py
index 4f2f4522a..39748c691 100644
--- a/python/sglang/srt/managers/io_struct.py
+++ b/python/sglang/srt/managers/io_struct.py
@@ -99,3 +99,7 @@ class BatchStrOut:
 @dataclass
 class FlushCacheReq:
     pass
+
+@dataclass
+class DetokenizeReqInput:
+    input_ids: List[int]
diff --git a/python/sglang/srt/managers/router/infer_batch.py b/python/sglang/srt/managers/router/infer_batch.py
index 339e003de..002a12927 100644
--- a/python/sglang/srt/managers/router/infer_batch.py
+++ b/python/sglang/srt/managers/router/infer_batch.py
@@ -48,6 +48,7 @@ class Req:
         self.last_node = None
 
         self.logprob = None
+        self.token_logprob = None
         self.normalized_logprob = None
 
         # For constrained decoding
diff --git a/python/sglang/srt/managers/router/model_rpc.py b/python/sglang/srt/managers/router/model_rpc.py
index 4700b6311..fe61ce8ca 100644
--- a/python/sglang/srt/managers/router/model_rpc.py
+++ b/python/sglang/srt/managers/router/model_rpc.py
@@ -388,24 +388,28 @@ class ModelRpcServer(rpyc.Service):
             self.model_config.vocab_size, self.int_token_logit_bias
         )
 
+        logprobs = None
         if batch.extend_num_tokens != 0:
             # Forward
-            logits, (logprobs, normalized_logprobs) = self.model_runner.forward(
-                batch, ForwardMode.EXTEND, batch.return_logprob
+            logits, (prefill_logprobs, normalized_logprobs, last_logprobs) = (
+                self.model_runner.forward(batch, ForwardMode.EXTEND, batch.return_logprob)
             )
-            # print("extend logits", logits)
-            if logprobs is not None:
-                logprobs = logprobs.cpu().tolist()
+            if prefill_logprobs is not None:
+                logprobs = prefill_logprobs.cpu().tolist()
                 normalized_logprobs = normalized_logprobs.cpu().tolist()
 
-            next_token_ids, next_token_probs = batch.sample(logits)
+            next_token_ids, _ = batch.sample(logits)
            next_token_ids = next_token_ids.cpu().tolist()
         else:
             next_token_ids = [self.tokenizer.eos_token_id] * len(batch.reqs)
-            logprobs = normalized_logprobs = None
+            logits = logprobs = normalized_logprobs = last_logprobs = None
+
+        # Only batch transfer the selected logprobs of the next token to CPU to reduce overhead.
+        reqs = batch.reqs
+        if last_logprobs is not None:
+            last_logprobs = last_logprobs[torch.arange(len(reqs)), next_token_ids].cpu().tolist()
 
         # Check finish condition
-        reqs = batch.reqs
         pt = 0
         for i, req in enumerate(reqs):
             req.output_ids = [next_token_ids[i]]
@@ -414,6 +418,10 @@ class ModelRpcServer(rpyc.Service):
             if logprobs is not None:
                 req.logprob = logprobs[pt : pt + req.extend_input_len - 1]
                 req.normalized_logprob = normalized_logprobs[i]
+
+                token_ids = req.input_ids + [next_token_ids[i]]
+                token_logprobs = [None] + req.logprob + [last_logprobs[i]]
+                req.token_logprob = list(zip(token_ids, token_logprobs))
             pt += req.extend_input_len
 
         self.handle_finished_requests(batch)
@@ -463,15 +471,26 @@ class ModelRpcServer(rpyc.Service):
         batch.prepare_for_decode()
 
         # Forward
-        logits = self.model_runner.forward(batch, ForwardMode.DECODE)
-        next_token_ids, next_token_probs = batch.sample(logits)
+        logits, (_, _, last_logprobs) = self.model_runner.forward(
+            batch,
+            ForwardMode.DECODE,
+            batch.return_logprob,
+        )
+        next_token_ids, _ = batch.sample(logits)
         next_token_ids = next_token_ids.cpu().tolist()
 
-        # Check finish condition
+        # Only batch transfer the selected logprobs of the next token to CPU to reduce overhead.
         reqs = batch.reqs
-        for i in range(len(reqs)):
-            reqs[i].output_ids.append(next_token_ids[i])
-            reqs[i].check_finished()
+        if last_logprobs is not None:
+            last_logprobs = last_logprobs[torch.arange(len(reqs)), next_token_ids].tolist()
+
+        # Check finish condition
+        for i, (req, next_tok_id) in enumerate(zip(reqs, next_token_ids)):
+            req.output_ids.append(next_tok_id)
+            req.check_finished()
+
+            if last_logprobs is not None:
+                req.token_logprob.append((next_tok_id, last_logprobs[i]))
 
         self.handle_finished_requests(batch)
@@ -513,6 +532,7 @@ class ModelRpcServer(rpyc.Service):
                 }
                 if req.return_logprob:
                     meta_info["prompt_logprob"] = req.logprob
+                    meta_info["token_logprob"] = req.token_logprob
                     meta_info["normalized_prompt_logprob"] = req.normalized_logprob
                 output_meta_info.append(meta_info)
                 output_finished.append(req.finished)
diff --git a/python/sglang/srt/managers/router/model_runner.py b/python/sglang/srt/managers/router/model_runner.py
index 93e99fe23..1dd2180e8 100644
--- a/python/sglang/srt/managers/router/model_runner.py
+++ b/python/sglang/srt/managers/router/model_runner.py
@@ -397,6 +397,7 @@ class ModelRunner:
         out_cache_loc,
         out_cache_cont_start,
         out_cache_cont_end,
+        return_logprob,
     ):
         input_metadata = InputMetadata.create(
             self,
@@ -409,10 +410,9 @@ class ModelRunner:
             out_cache_loc=out_cache_loc,
             out_cache_cont_start=out_cache_cont_start,
             out_cache_cont_end=out_cache_cont_end,
+            return_logprob=return_logprob,
         )
-        return self.model.forward(input_ids, input_metadata.positions, input_metadata)[
-            0
-        ]
+        return self.model.forward(input_ids, input_metadata.positions, input_metadata)
 
     @torch.inference_mode()
     def forward_extend_multi_modal(
@@ -460,8 +460,8 @@ class ModelRunner:
                 "prefix_lens": batch.prefix_lens,
                 "position_ids_offsets": batch.position_ids_offsets,
                 "out_cache_loc": batch.out_cache_loc,
+                "return_logprob": return_logprob,
             }
-            kwargs["return_logprob"] = return_logprob
             return self.forward_extend_multi_modal(**kwargs)
         else:
             kwargs = {
@@ -471,6 +471,7 @@ class ModelRunner:
                 "prefix_lens": batch.prefix_lens,
                 "position_ids_offsets": batch.position_ids_offsets,
                 "out_cache_loc": batch.out_cache_loc,
+                "return_logprob": return_logprob,
             }
 
         if forward_mode == ForwardMode.DECODE:
             kwargs["out_cache_cont_start"] = batch.out_cache_cont_start
             kwargs["out_cache_cont_end"] = batch.out_cache_cont_end
             return self.forward_decode(**kwargs)
         elif forward_mode == ForwardMode.EXTEND:
-            kwargs["return_logprob"] = return_logprob
             return self.forward_extend(**kwargs)
         elif forward_mode == ForwardMode.PREFILL:
-            kwargs["return_logprob"] = return_logprob
             return self.forward_prefill(**kwargs)
         else:
             raise ValueError(f"Invaid forward mode: {forward_mode}")
diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py
index d67cb49ea..084ac791e 100644
--- a/python/sglang/srt/managers/tokenizer_manager.py
+++ b/python/sglang/srt/managers/tokenizer_manager.py
@@ -18,6 +18,7 @@ from sglang.srt.hf_transformers_utils import (
 )
 from sglang.srt.managers.io_struct import (
     BatchStrOut,
+    DetokenizeReqInput,
     FlushCacheReq,
     GenerateReqInput,
     TokenizedGenerateReqInput,
@@ -234,6 +235,10 @@ class TokenizerManager:
 
             yield output_list
 
+    async def detokenize(self, obj: DetokenizeReqInput):
+        token_texts = self.tokenizer.convert_ids_to_tokens(obj.input_ids)
+        return [t.decode() if isinstance(t, bytes) else t for t in token_texts]
+
     async def flush_cache(self):
         flush_cache_req = FlushCacheReq()
         self.send_to_router.send_pyobj(flush_cache_req)
diff --git a/python/sglang/srt/server.py b/python/sglang/srt/server.py
index ce615d43a..d92ee54b2 100644
--- a/python/sglang/srt/server.py
+++ b/python/sglang/srt/server.py
@@ -30,7 +30,7 @@ from sglang.srt.conversation import (
 )
 from sglang.srt.hf_transformers_utils import get_tokenizer
 from sglang.srt.managers.detokenizer_manager import start_detokenizer_process
-from sglang.srt.managers.io_struct import GenerateReqInput
+from sglang.srt.managers.io_struct import DetokenizeReqInput, GenerateReqInput
 from sglang.srt.managers.openai_protocol import (
     ChatCompletionRequest,
     ChatCompletionResponse,
@@ -44,6 +44,7 @@ from sglang.srt.managers.openai_protocol import (
     CompletionResponseStreamChoice,
     CompletionStreamResponse,
     DeltaMessage,
+    LogProbs,
     UsageInfo,
 )
 from sglang.srt.managers.router.manager import start_router_process
@@ -97,6 +98,23 @@ async def stream_generator(obj):
         yield out
 
 
+async def make_openai_style_logprobs(token_logprobs):
+    ret_logprobs = LogProbs()
+
+    # Detokenize
+    token_ids = [tid for tid, _ in token_logprobs]
+    token_texts = await tokenizer_manager.detokenize(DetokenizeReqInput(token_ids))
+
+    for token_text, (_, token_logprob) in zip(token_texts, token_logprobs):
+        ret_logprobs.tokens.append(token_text)
+        ret_logprobs.token_logprobs.append(token_logprob)
+
+        # Not supported yet.
+        ret_logprobs.top_logprobs.append({})
+        ret_logprobs.text_offset.append(-1)
+    return ret_logprobs
+
+
 @app.post("/generate")
 async def generate_request(obj: GenerateReqInput):
     obj.post_init()
@@ -132,6 +150,7 @@ async def v1_completions(raw_request: Request):
             "presence_penalty": request.presence_penalty,
             "frequency_penalty": request.frequency_penalty,
         },
+        return_logprob=request.logprobs is not None,
         stream=request.stream,
     )
     adapted_request.post_init()
@@ -140,17 +159,34 @@ async def v1_completions(raw_request: Request):
 
     async def gnerate_stream_resp():
         stream_buffer = ""
+        n_prev_token = 0
         async for content in stream_generator(adapted_request):
             text = content["text"]
             prompt_tokens = content["meta_info"]["prompt_tokens"]
             completion_tokens = content["meta_info"]["completion_tokens"]
 
+            if not stream_buffer:  # The first chunk
+                if request.echo:
+                    # Prepend prompt in response text.
+                    text = request.prompt + text
+                else:
+                    # Skip prompt tokens if echo is disabled.
+                    n_prev_token = prompt_tokens
+
+            if request.logprobs is not None:
+                logprobs = await make_openai_style_logprobs(
+                    content["meta_info"]["token_logprob"][n_prev_token:]
+                )
+                n_prev_token = len(content["meta_info"]["token_logprob"])
+            else:
+                logprobs = None
+
             delta = text[len(stream_buffer) :]
-            stream_buffer = text
+            stream_buffer = content["text"]
             choice_data = CompletionResponseStreamChoice(
                 index=0,
                 text=delta,
-                logprobs=None,
+                logprobs=logprobs,
                 finish_reason=None,
             )
             chunk = CompletionStreamResponse(
@@ -172,15 +208,28 @@ async def v1_completions(raw_request: Request):
 
     # Non-streaming response.
     ret = await generate_request(adapted_request)
 
+    prompt_tokens = ret["meta_info"]["prompt_tokens"]
+    completion_tokens = ret["meta_info"]["completion_tokens"]
+    text = ret["text"]
+    token_logprob_pos = prompt_tokens
+    if request.echo:
+        token_logprob_pos = 0
+        text = request.prompt + text
+    else:
+        token_logprob_pos = prompt_tokens
+
+    logprobs = (
+        await make_openai_style_logprobs(ret["meta_info"]["token_logprob"][token_logprob_pos:])
+        if request.logprobs is not None
+        else None
+    )
     choice_data = CompletionResponseChoice(
         index=0,
-        text=ret["text"],
-        logprobs=None,
+        text=text,
+        logprobs=logprobs,
         finish_reason=None,  # TODO(comaniac): Add finish reason.
     )
-    prompt_tokens = ret["meta_info"]["prompt_tokens"]
-    completion_tokens = ret["meta_info"]["completion_tokens"]
     response = CompletionResponse(
         id=ret["meta_info"]["id"],
         model=request.model,
@@ -216,7 +265,9 @@ async def v1_chat_completions(raw_request: Request):
             if not isinstance(m.content, str):
                 raise HTTPException(
                     status_code=503,
-                    detail="Structured content requests not supported with HuggingFace Chat Templates. Make sure the server specifies a sglang chat template.",
+                    detail="Structured content requests not supported with "
+                    "HuggingFace Chat Templates. "
" + "Make sure the server specifies a sglang chat template.", ) prompt = tokenizer_manager.tokenizer.apply_chat_template( request.messages, tokenize=False, add_generation_prompt=True diff --git a/test/srt/test_httpserver_decode.py b/test/srt/test_httpserver_decode.py index b26eb030d..7462661b7 100644 --- a/test/srt/test_httpserver_decode.py +++ b/test/srt/test_httpserver_decode.py @@ -9,10 +9,24 @@ The capital of France is Paris.\nThe capital of the United States is Washington, """ import argparse -import time import requests +def test_decode(url, return_logprob): + response = requests.post( + url + "/generate", + json={ + "text": "The capital of France is", + "sampling_params": { + "temperature": 0, + "max_new_tokens": 32, + }, + "return_logprob": return_logprob, + "logprob_start_len": 0, + }, + ) + print(response.json()) + if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("--host", type=str, default="http://127.0.0.1") @@ -21,16 +35,5 @@ if __name__ == "__main__": url = f"{args.host}:{args.port}" - response = requests.post( - url + "/generate", - json={ - "text": "The capital of France is", - "sampling_params": { - "temperature": 0, - "max_new_tokens": 32, - }, - # "return_logprob": True, - # "logprob_start_len": 0, - }, - ) - print(response.json()) + test_decode(url, False) + test_decode(url, True) diff --git a/test/srt/test_httpserver_decode_stream.py b/test/srt/test_httpserver_decode_stream.py index 3d63e66cb..5713c6380 100644 --- a/test/srt/test_httpserver_decode_stream.py +++ b/test/srt/test_httpserver_decode_stream.py @@ -9,27 +9,20 @@ The capital of France is Paris.\nThe capital of the United States is Washington, import argparse import json -import time import requests -if __name__ == "__main__": - parser = argparse.ArgumentParser() - parser.add_argument("--host", type=str, default="http://127.0.0.1") - parser.add_argument("--port", type=int, default=30000) - args = parser.parse_args() - - url = f"{args.host}:{args.port}" - +def test_decode_stream(url, return_logprob): response = requests.post( url + "/generate", json={ "text": "The capital of France is", "sampling_params": { "temperature": 0, - "max_new_tokens": 512, + "max_new_tokens": 128, }, "stream": True, + "return_logprob": return_logprob, }, stream=True, ) @@ -41,7 +34,29 @@ if __name__ == "__main__": if chunk == "data: [DONE]": break data = json.loads(chunk[5:].strip("\n")) - output = data["text"].strip() - print(output[prev:], end="", flush=True) - prev = len(output) + + if return_logprob: + assert data["meta_info"]["prompt_logprob"] is not None + assert data["meta_info"]["token_logprob"] is not None + assert data["meta_info"]["normalized_prompt_logprob"] is not None + if prev == 0: # Skip prompt logprobs + prev = data["meta_info"]["prompt_tokens"] + for token_txt, _, logprob in data["meta_info"]["token_logprob"][prev:]: + print(f"{token_txt}\t{logprob}", flush=True) + prev = len(data["meta_info"]["token_logprob"]) + else: + output = data["text"].strip() + print(output[prev:], end="", flush=True) + prev = len(output) print("") + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--host", type=str, default="http://127.0.0.1") + parser.add_argument("--port", type=int, default=30000) + args = parser.parse_args() + + url = f"{args.host}:{args.port}" + + test_decode_stream(url, False) + test_decode_stream(url, True) diff --git a/test/srt/test_openai_server.py b/test/srt/test_openai_server.py index f0dc078e2..b1351360d 100644 --- 
--- a/test/srt/test_openai_server.py
+++ b/test/srt/test_openai_server.py
@@ -18,15 +18,26 @@ import argparse
 import openai
 
 
-def test_completion(args):
+def test_completion(args, echo, logprobs):
     client = openai.Client(api_key="EMPTY", base_url=args.base_url)
     response = client.completions.create(
         model="default",
         prompt="The capital of France is",
         temperature=0,
         max_tokens=32,
+        echo=echo,
+        logprobs=logprobs,
     )
+    text = response.choices[0].text
     print(response.choices[0].text)
+    if echo:
+        assert text.startswith("The capital of France is")
+    if logprobs:
+        assert response.choices[0].logprobs
+        if echo:
+            assert response.choices[0].logprobs.token_logprobs[0] == None
+        else:
+            assert response.choices[0].logprobs.token_logprobs[0] != None
     assert response.id
     assert response.created
     assert response.usage.prompt_tokens > 0
@@ -34,7 +45,7 @@ def test_completion(args):
     assert response.usage.total_tokens > 0
 
 
-def test_completion_stream(args):
+def test_completion_stream(args, echo, logprobs):
     client = openai.Client(api_key="EMPTY", base_url=args.base_url)
     response = client.completions.create(
         model="default",
@@ -42,9 +53,23 @@ def test_completion_stream(args):
         temperature=0,
         max_tokens=32,
         stream=True,
+        echo=echo,
+        logprobs=logprobs,
     )
+    first = True
     for r in response:
-        print(r.choices[0].text, end="", flush=True)
+        if first:
+            if echo:
+                assert r.choices[0].text.startswith("The capital of France is")
+            first = False
+        if logprobs:
+            print(
+                f"{r.choices[0].text:12s}\t"
+                f"{r.choices[0].logprobs.token_logprobs}",
+                flush=True
+            )
+        else:
+            print(r.choices[0].text, end="", flush=True)
     assert r.id
     assert r.usage.prompt_tokens > 0
     assert r.usage.completion_tokens > 0
@@ -135,8 +160,14 @@ if __name__ == "__main__":
     )
     args = parser.parse_args()
 
-    test_completion(args)
-    test_completion_stream(args)
+    test_completion(args, echo=False, logprobs=False)
+    test_completion(args, echo=True, logprobs=False)
+    test_completion(args, echo=False, logprobs=True)
+    test_completion(args, echo=True, logprobs=True)
+    test_completion_stream(args, echo=False, logprobs=False)
+    test_completion_stream(args, echo=True, logprobs=False)
+    test_completion_stream(args, echo=False, logprobs=True)
+    test_completion_stream(args, echo=True, logprobs=True)
    test_chat_completion(args)
    test_chat_completion_stream(args)
    if args.test_image:
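
Usage sketch (not part of the patch): a minimal illustration of how the new logprob paths can be exercised end to end, mirroring the tests above. The host/port (http://127.0.0.1:30000), the "/v1" base path, and the model name "default" are assumptions taken from the test defaults and may differ in a given deployment.

import openai
import requests

# Native /generate endpoint: ask the router to return per-token logprobs.
resp = requests.post(
    "http://127.0.0.1:30000/generate",
    json={
        "text": "The capital of France is",
        "sampling_params": {"temperature": 0, "max_new_tokens": 8},
        "return_logprob": True,
        "logprob_start_len": 0,
    },
)
meta = resp.json()["meta_info"]
# Keys set by model_rpc.handle_finished_requests in this patch.
print(meta["normalized_prompt_logprob"])
print(meta["token_logprob"])

# OpenAI-compatible completions endpoint: echo + logprobs flow through v1_completions.
client = openai.Client(api_key="EMPTY", base_url="http://127.0.0.1:30000/v1")
completion = client.completions.create(
    model="default",
    prompt="The capital of France is",
    temperature=0,
    max_tokens=8,
    echo=True,
    logprobs=1,
)
# With echo=True, the first prompt token has no logprob (None), as asserted in the tests.
print(completion.choices[0].logprobs.token_logprobs)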