Support decode token logprobs (#130)
@@ -14,28 +14,11 @@ class LogitsProcessor(nn.Module):
         self.tp_size = get_tensor_model_parallel_world_size()
 
     def forward(self, input_ids, hidden_states, weight, input_metadata):
-        if not input_metadata.return_logprob:
-            if input_metadata.forward_mode == ForwardMode.DECODE:
-                last_hidden = hidden_states
-            else:
-                last_index = (
-                    torch.cumsum(
-                        input_metadata.seq_lens - input_metadata.prefix_lens,
-                        dim=0,
-                        dtype=torch.long,
-                    )
-                    - 1
-                )
-                last_hidden = hidden_states[last_index]
-                hidden_states = None
-                last_index = None
-
-            last_logits = torch.matmul(last_hidden, weight.T)
-            if self.tp_size > 1:
-                last_logits = tensor_model_parallel_all_gather(last_logits)
-            last_logits = last_logits[:, : self.config.vocab_size]
-            return last_logits, (None, None)
-        else:
-            assert input_metadata.forward_mode != ForwardMode.DECODE
+        # Compute the last index (the first decode token) of each request
+        # if we are in prefill or extend mode.
+        if input_metadata.forward_mode != ForwardMode.DECODE:
             last_index = (
                 torch.cumsum(
                     input_metadata.seq_lens - input_metadata.prefix_lens,
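The restructured code path above keys everything off `last_index`, the position of each request's last prompt token inside the packed batch. A minimal standalone sketch of that computation (the lengths here are made up for illustration, not taken from the diff):

```python
# Sequences are packed back to back, so request i ends at cumsum(extend_lens)[i] - 1.
import torch

seq_lens = torch.tensor([5, 8, 3])      # total tokens per request
prefix_lens = torch.tensor([2, 0, 1])   # tokens already cached (prefix)
extend_lens = seq_lens - prefix_lens    # tokens processed in this forward pass
last_index = torch.cumsum(extend_lens, dim=0, dtype=torch.long) - 1
print(last_index)  # tensor([ 2, 10, 12])
```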
@@ -45,29 +28,54 @@ class LogitsProcessor(nn.Module):
                 - 1
             )
 
+        if not input_metadata.return_logprob:
+            # When logprob is not requested, only compute the last logits.
+            if input_metadata.forward_mode == ForwardMode.DECODE:
+                last_hidden = hidden_states
+            else:
+                last_hidden = hidden_states[last_index]
+                hidden_states = None
+
+            last_logits = torch.matmul(last_hidden, weight.T)
+            if self.tp_size > 1:
+                last_logits = tensor_model_parallel_all_gather(last_logits)
+            last_logits = last_logits[:, : self.config.vocab_size]
+            return last_logits, (None, None, None)
+        else:
+            # When logprob is requested, compute the logits for all tokens.
             logits = torch.matmul(hidden_states, weight.T)
             if self.tp_size > 1:
                 logits = tensor_model_parallel_all_gather(logits)
             logits = logits[:, : self.config.vocab_size]
             all_logprobs = torch.log(torch.softmax(logits.float(), dim=-1) + 1e-6)
 
-            logprobs = all_logprobs[
-                torch.arange(all_logprobs.shape[0], device="cuda"),
-                torch.cat([input_ids[1:], torch.tensor([0], device="cuda")]),
-            ]
-            logprobs_cumsum = torch.cumsum(logprobs, dim=0, dtype=torch.float32)
+            if input_metadata.forward_mode == ForwardMode.DECODE:
+                last_logits = logits
+                last_logprobs = all_logprobs
+                prefill_logprobs = normalized_logprobs = None
+            else:
+                # Compute the logprobs for the last token of each request.
+                last_logits = logits[last_index]
+                last_logprobs = all_logprobs[last_index]
 
-            start = input_metadata.extend_start_loc.clone()
-            end = start + input_metadata.extend_seq_lens - 2
-            start.clamp_(min=0, max=logprobs.shape[0] - 1)
-            end.clamp_(min=0, max=logprobs.shape[0] - 1)
-            sum_logp = logprobs_cumsum[end] - logprobs_cumsum[start] + logprobs[start]
-            normalized_logprobs = sum_logp / (
-                (input_metadata.extend_seq_lens - 1).clamp(min=1)
-            )
+                # Compute the logprobs and normalized logprobs for the prefill tokens.
+                # Note that we pad a zero at the end of each sequence for easy computation.
+                prefill_logprobs = all_logprobs[
+                    torch.arange(all_logprobs.shape[0], device="cuda"),
+                    torch.cat([input_ids[1:], torch.tensor([0], device="cuda")]),
+                ]
+                logprobs_cumsum = torch.cumsum(prefill_logprobs, dim=0, dtype=torch.float32)
 
-            last_logits = logits[last_index]
-            return last_logits, (logprobs, normalized_logprobs)
+                start = input_metadata.extend_start_loc.clone()
+                end = start + input_metadata.extend_seq_lens - 2
+                start.clamp_(min=0, max=prefill_logprobs.shape[0] - 1)
+                end.clamp_(min=0, max=prefill_logprobs.shape[0] - 1)
+                sum_logp = logprobs_cumsum[end] - logprobs_cumsum[start] + prefill_logprobs[start]
+                normalized_logprobs = sum_logp / (
+                    (input_metadata.extend_seq_lens - 1).clamp(min=1)
+                )
+
+            return last_logits, (prefill_logprobs, normalized_logprobs, last_logprobs)
 
 
 if __name__ == "__main__":
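The normalized prompt logprob above uses a prefix-sum trick: each request's token logprobs are summed via differences of one cumulative sum instead of a per-request Python loop. A self-contained sketch with fabricated numbers (variable names mirror the diff but the values are illustrative; the `-9.9` entries stand in for the padded boundary positions that get skipped):

```python
import torch

# Two packed requests with extend lengths 3 and 4; one logprob per extend token.
token_logprobs = torch.tensor([-0.5, -1.0, -9.9, -0.3, -0.8, -0.1, -9.9])
extend_seq_lens = torch.tensor([3, 4])
extend_start_loc = torch.tensor([0, 3])

cumsum = torch.cumsum(token_logprobs, dim=0, dtype=torch.float32)
start = extend_start_loc.clone()
end = start + extend_seq_lens - 2          # last valid entry; skips the padding
start.clamp_(min=0, max=token_logprobs.shape[0] - 1)
end.clamp_(min=0, max=token_logprobs.shape[0] - 1)

# cumsum[end] - cumsum[start] drops the start element, so add it back.
sum_logp = cumsum[end] - cumsum[start] + token_logprobs[start]
normalized = sum_logp / (extend_seq_lens - 1).clamp(min=1)
print(normalized)  # tensor([-0.7500, -0.4000]): mean logprob per prompt token
```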
@@ -99,3 +99,7 @@ class BatchStrOut:
 @dataclass
 class FlushCacheReq:
     pass
+
+
+@dataclass
+class DetokenizeReqInput:
+    input_ids: List[int]
@@ -48,6 +48,7 @@ class Req:
         self.last_node = None
 
         self.logprob = None
+        self.token_logprob = None
         self.normalized_logprob = None
 
         # For constrained decoding
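The new `token_logprob` field stores one `(token_id, logprob)` pair per token of the request. A hypothetical example of its shape right after prefill (ids and values are made up):

```python
# The first prompt token has no logprob because nothing predicts it, hence None.
token_logprob = [
    (1, None),       # first prompt token
    (3923, -2.31),   # prompt token, scored given the prefix
    (374, -0.87),    # prompt token
    (264, -1.42),    # first decoded token, from the same extend forward pass
]
```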
@@ -388,24 +388,28 @@ class ModelRpcServer(rpyc.Service):
             self.model_config.vocab_size, self.int_token_logit_bias
         )
 
-        logprobs = None
         if batch.extend_num_tokens != 0:
             # Forward
-            logits, (logprobs, normalized_logprobs) = self.model_runner.forward(
-                batch, ForwardMode.EXTEND, batch.return_logprob
+            logits, (prefill_logprobs, normalized_logprobs, last_logprobs) = (
+                self.model_runner.forward(batch, ForwardMode.EXTEND, batch.return_logprob)
             )
             # print("extend logits", logits)
-            if logprobs is not None:
-                logprobs = logprobs.cpu().tolist()
+            if prefill_logprobs is not None:
+                logprobs = prefill_logprobs.cpu().tolist()
                 normalized_logprobs = normalized_logprobs.cpu().tolist()
 
-            next_token_ids, next_token_probs = batch.sample(logits)
+            next_token_ids, _ = batch.sample(logits)
             next_token_ids = next_token_ids.cpu().tolist()
         else:
             next_token_ids = [self.tokenizer.eos_token_id] * len(batch.reqs)
-            logprobs = normalized_logprobs = None
+            logits = logprobs = normalized_logprobs = last_logprobs = None
 
+        # Only batch transfer the selected logprobs of the next token to CPU to reduce overhead.
+        reqs = batch.reqs
+        if last_logprobs is not None:
+            last_logprobs = last_logprobs[torch.arange(len(reqs)), next_token_ids].cpu().tolist()
+
         # Check finish condition
-        reqs = batch.reqs
         pt = 0
         for i, req in enumerate(reqs):
             req.output_ids = [next_token_ids[i]]
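The "batch transfer" comment above refers to gathering the chosen tokens' logprobs with one fancy-indexing op and moving them to the host in a single copy, rather than one transfer per request. A device-agnostic sketch with illustrative shapes (names are mine, not from the diff):

```python
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
num_reqs, vocab_size = 4, 32000
last_logprobs = torch.randn(num_reqs, vocab_size, device=device)  # one row per request
next_token_ids = torch.randint(vocab_size, (num_reqs,)).tolist()

# One gather + one device-to-host copy for the whole batch,
# instead of num_reqs separate .item() round trips.
selected = last_logprobs[torch.arange(num_reqs, device=device), next_token_ids].cpu().tolist()
print(selected)
```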
@@ -414,6 +418,10 @@ class ModelRpcServer(rpyc.Service):
             if logprobs is not None:
                 req.logprob = logprobs[pt : pt + req.extend_input_len - 1]
                 req.normalized_logprob = normalized_logprobs[i]
+
+                token_ids = req.input_ids + [next_token_ids[i]]
+                token_logprobs = [None] + req.logprob + [last_logprobs[i]]
+                req.token_logprob = list(zip(token_ids, token_logprobs))
             pt += req.extend_input_len
 
         self.handle_finished_requests(batch)
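The `pt` pointer walks the flat `logprobs` list, which holds `extend_input_len` entries per request; each request keeps only the first `extend_input_len - 1` of them because the final entry at each boundary is padding. A plain-Python illustration with fabricated values, consistent with the sketch shown earlier:

```python
logprobs = [-0.5, -1.0, -9.9, -0.3, -0.8, -0.1, -9.9]  # two packed requests
extend_input_lens = [3, 4]
pt = 0
per_req = []
for n in extend_input_lens:
    per_req.append(logprobs[pt : pt + n - 1])  # drop the boundary entry
    pt += n
print(per_req)  # [[-0.5, -1.0], [-0.3, -0.8, -0.1]]
```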
@@ -463,15 +471,26 @@ class ModelRpcServer(rpyc.Service):
         batch.prepare_for_decode()
 
         # Forward
-        logits = self.model_runner.forward(batch, ForwardMode.DECODE)
-        next_token_ids, next_token_probs = batch.sample(logits)
+        logits, (_, _, last_logprobs) = self.model_runner.forward(
+            batch,
+            ForwardMode.DECODE,
+            batch.return_logprob,
+        )
+        next_token_ids, _ = batch.sample(logits)
         next_token_ids = next_token_ids.cpu().tolist()
 
-        # Check finish condition
+        # Only batch transfer the selected logprobs of the next token to CPU to reduce overhead.
         reqs = batch.reqs
-        for i in range(len(reqs)):
-            reqs[i].output_ids.append(next_token_ids[i])
-            reqs[i].check_finished()
+        if last_logprobs is not None:
+            last_logprobs = last_logprobs[torch.arange(len(reqs)), next_token_ids].tolist()
+
+        # Check finish condition
+        for i, (req, next_tok_id) in enumerate(zip(reqs, next_token_ids)):
+            req.output_ids.append(next_tok_id)
+            req.check_finished()
+
+            if last_logprobs is not None:
+                req.token_logprob.append((next_tok_id, last_logprobs[i]))
 
         self.handle_finished_requests(batch)
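With this change, every decode step extends the pairs built at prefill time, so a finished request carries logprobs for its full token stream. A hypothetical trace (ids and values made up):

```python
req_token_logprob = [(1, None), (3923, -2.31), (264, -1.42)]  # from prefill
req_token_logprob.append((502, -0.57))  # decode step 1
req_token_logprob.append((13, -0.09))   # decode step 2
```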
@@ -513,6 +532,7 @@ class ModelRpcServer(rpyc.Service):
             }
             if req.return_logprob:
                 meta_info["prompt_logprob"] = req.logprob
+                meta_info["token_logprob"] = req.token_logprob
                 meta_info["normalized_prompt_logprob"] = req.normalized_logprob
             output_meta_info.append(meta_info)
             output_finished.append(req.finished)
@@ -397,6 +397,7 @@ class ModelRunner:
         out_cache_loc,
         out_cache_cont_start,
         out_cache_cont_end,
+        return_logprob,
     ):
         input_metadata = InputMetadata.create(
             self,
@@ -409,10 +410,9 @@ class ModelRunner:
             out_cache_loc=out_cache_loc,
             out_cache_cont_start=out_cache_cont_start,
             out_cache_cont_end=out_cache_cont_end,
+            return_logprob=return_logprob,
         )
-        return self.model.forward(input_ids, input_metadata.positions, input_metadata)[
-            0
-        ]
+        return self.model.forward(input_ids, input_metadata.positions, input_metadata)
 
     @torch.inference_mode()
     def forward_extend_multi_modal(
@@ -460,8 +460,8 @@ class ModelRunner:
                 "prefix_lens": batch.prefix_lens,
                 "position_ids_offsets": batch.position_ids_offsets,
                 "out_cache_loc": batch.out_cache_loc,
+                "return_logprob": return_logprob,
             }
-            kwargs["return_logprob"] = return_logprob
             return self.forward_extend_multi_modal(**kwargs)
         else:
             kwargs = {
@@ -471,6 +471,7 @@ class ModelRunner:
                 "prefix_lens": batch.prefix_lens,
                 "position_ids_offsets": batch.position_ids_offsets,
                 "out_cache_loc": batch.out_cache_loc,
+                "return_logprob": return_logprob,
             }
 
         if forward_mode == ForwardMode.DECODE:
@@ -478,10 +479,8 @@ class ModelRunner:
             kwargs["out_cache_cont_end"] = batch.out_cache_cont_end
             return self.forward_decode(**kwargs)
         elif forward_mode == ForwardMode.EXTEND:
-            kwargs["return_logprob"] = return_logprob
             return self.forward_extend(**kwargs)
         elif forward_mode == ForwardMode.PREFILL:
-            kwargs["return_logprob"] = return_logprob
             return self.forward_prefill(**kwargs)
         else:
             raise ValueError(f"Invaid forward mode: {forward_mode}")
@@ -18,6 +18,7 @@ from sglang.srt.hf_transformers_utils import (
 )
 from sglang.srt.managers.io_struct import (
     BatchStrOut,
+    DetokenizeReqInput,
     FlushCacheReq,
     GenerateReqInput,
     TokenizedGenerateReqInput,
@@ -234,6 +235,10 @@ class TokenizerManager:
 
         yield output_list
 
+    async def detokenize(self, obj: DetokenizeReqInput):
+        token_texts = self.tokenizer.convert_ids_to_tokens(obj.input_ids)
+        return [t.decode() if isinstance(t, bytes) else t for t in token_texts]
+
     async def flush_cache(self):
         flush_cache_req = FlushCacheReq()
         self.send_to_router.send_pyobj(flush_cache_req)
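`convert_ids_to_tokens` returns raw subword pieces (and `bytes` objects for some tokenizers), which is why the method decodes bytes before returning; the pieces are not whitespace-joined text. A rough usage sketch, assuming a running `tokenizer_manager` (the token ids and outputs are made up):

```python
import asyncio

async def demo():
    # Hypothetical call site; ids and outputs are illustrative only.
    texts = await tokenizer_manager.detokenize(DetokenizeReqInput(input_ids=[9906, 1917]))
    print(texts)  # e.g. ["Hello", "Ġworld"] -- "Ġ" marks a leading space in BPE vocabs

asyncio.run(demo())
```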
@@ -30,7 +30,7 @@ from sglang.srt.conversation import (
 )
 from sglang.srt.hf_transformers_utils import get_tokenizer
 from sglang.srt.managers.detokenizer_manager import start_detokenizer_process
-from sglang.srt.managers.io_struct import GenerateReqInput
+from sglang.srt.managers.io_struct import DetokenizeReqInput, GenerateReqInput
 from sglang.srt.managers.openai_protocol import (
     ChatCompletionRequest,
     ChatCompletionResponse,
@@ -44,6 +44,7 @@ from sglang.srt.managers.openai_protocol import (
     CompletionResponseStreamChoice,
     CompletionStreamResponse,
     DeltaMessage,
+    LogProbs,
     UsageInfo,
 )
 from sglang.srt.managers.router.manager import start_router_process
@@ -97,6 +98,23 @@ async def stream_generator(obj):
         yield out
 
 
+async def make_openai_style_logprobs(token_logprobs):
+    ret_logprobs = LogProbs()
+
+    # Detokenize
+    token_ids = [tid for tid, _ in token_logprobs]
+    token_texts = await tokenizer_manager.detokenize(DetokenizeReqInput(token_ids))
+
+    for token_text, (_, token_logprob) in zip(token_texts, token_logprobs):
+        ret_logprobs.tokens.append(token_text)
+        ret_logprobs.token_logprobs.append(token_logprob)
+
+        # Not supported yet.
+        ret_logprobs.top_logprobs.append({})
+        ret_logprobs.text_offset.append(-1)
+    return ret_logprobs
+
+
 @app.post("/generate")
 async def generate_request(obj: GenerateReqInput):
     obj.post_init()
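For reference, the four parallel lists the helper fills mirror the `logprobs` object of the OpenAI completions API; in this commit `top_logprobs` and `text_offset` are placeholders. An illustrative result (all values fabricated):

```python
example_logprobs = {
    "tokens": ["Hello", ",", " world"],
    "token_logprobs": [None, -1.2, -0.4],  # None for the first prompt token
    "top_logprobs": [{}, {}, {}],          # not supported yet
    "text_offset": [-1, -1, -1],           # not supported yet
}
```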
@@ -132,6 +150,7 @@ async def v1_completions(raw_request: Request):
             "presence_penalty": request.presence_penalty,
             "frequency_penalty": request.frequency_penalty,
         },
+        return_logprob=request.logprobs is not None,
         stream=request.stream,
     )
     adapted_request.post_init()
@@ -140,17 +159,34 @@ async def v1_completions(raw_request: Request):
 
     async def gnerate_stream_resp():
         stream_buffer = ""
+        n_prev_token = 0
         async for content in stream_generator(adapted_request):
             text = content["text"]
             prompt_tokens = content["meta_info"]["prompt_tokens"]
             completion_tokens = content["meta_info"]["completion_tokens"]
+
+            if not stream_buffer:  # The first chunk
+                if request.echo:
+                    # Prepend prompt in response text.
+                    text = request.prompt + text
+                else:
+                    # Skip prompt tokens if echo is disabled.
+                    n_prev_token = prompt_tokens
+
+            if request.logprobs is not None:
+                logprobs = await make_openai_style_logprobs(
+                    content["meta_info"]["token_logprob"][n_prev_token:]
+                )
+                n_prev_token = len(content["meta_info"]["token_logprob"])
+            else:
+                logprobs = None
+
             delta = text[len(stream_buffer) :]
-            stream_buffer = content["text"]
+            stream_buffer = text
             choice_data = CompletionResponseStreamChoice(
                 index=0,
                 text=delta,
-                logprobs=None,
+                logprobs=logprobs,
                 finish_reason=None,
             )
             chunk = CompletionStreamResponse(
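In the streaming path, each chunk's `meta_info["token_logprob"]` is the full list accumulated so far, and `n_prev_token` is a cursor over it so each chunk only emits new entries. A small sketch with made-up data:

```python
token_logprob = [(1, None), (42, -0.5), (7, -1.1), (99, -0.3)]  # grows each chunk
n_prev_token = 2                            # entries already emitted (or skipped prompt)
new_entries = token_logprob[n_prev_token:]  # unsent tail -> [(7, -1.1), (99, -0.3)]
n_prev_token = len(token_logprob)           # advance the cursor for the next chunk
```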
@@ -172,15 +208,28 @@ async def v1_completions(raw_request: Request):
     # Non-streaming response.
     ret = await generate_request(adapted_request)
 
+    prompt_tokens = ret["meta_info"]["prompt_tokens"]
+    completion_tokens = ret["meta_info"]["completion_tokens"]
+    text = ret["text"]
+    if request.echo:
+        token_logprob_pos = 0
+        text = request.prompt + text
+    else:
+        token_logprob_pos = prompt_tokens
+
+    logprobs = (
+        await make_openai_style_logprobs(ret["meta_info"]["token_logprob"][token_logprob_pos:])
+        if request.logprobs is not None
+        else None
+    )
     choice_data = CompletionResponseChoice(
         index=0,
-        text=ret["text"],
-        logprobs=None,
+        text=text,
+        logprobs=logprobs,
         finish_reason=None,  # TODO(comaniac): Add finish reason.
     )
 
-    prompt_tokens = ret["meta_info"]["prompt_tokens"]
-    completion_tokens = ret["meta_info"]["completion_tokens"]
     response = CompletionResponse(
         id=ret["meta_info"]["id"],
         model=request.model,
@@ -216,7 +265,9 @@ async def v1_chat_completions(raw_request: Request):
             if not isinstance(m.content, str):
                 raise HTTPException(
                     status_code=503,
-                    detail="Structured content requests not supported with HuggingFace Chat Templates. Make sure the server specifies a sglang chat template.",
+                    detail="Structured content requests not supported with "
+                    "HuggingFace Chat Templates. "
+                    "Make sure the server specifies a sglang chat template.",
                 )
         prompt = tokenizer_manager.tokenizer.apply_chat_template(
             request.messages, tokenize=False, add_generation_prompt=True