[Feat]Add support for optional start len of logprobs (#1035)
Co-authored-by: Ying Sheng <sqy1415@gmail.com> Co-authored-by: Yineng Zhang <me@zhyncs.com> Co-authored-by: Lianmin Zheng <lianminzheng@gmail.com> Co-authored-by: Liangsheng Yin <hnyls2002@gmail.com>
This commit is contained in:
@@ -55,6 +55,9 @@ class LogitsMetadata:
|
||||
extend_start_loc: Optional[torch.Tensor] = None
|
||||
top_logprobs_nums: Optional[List[int]] = None
|
||||
|
||||
extend_seq_lens_cpu: List[int] = None
|
||||
logprob_start_lens_cpu: List[int] = None
|
||||
|
||||
@classmethod
|
||||
def from_input_metadata(cls, input_metadata: InputMetadata):
|
||||
return cls(
|
||||
@@ -63,6 +66,8 @@ class LogitsMetadata:
|
||||
extend_start_loc=input_metadata.extend_start_loc,
|
||||
return_logprob=input_metadata.return_logprob,
|
||||
top_logprobs_nums=input_metadata.top_logprobs_nums,
|
||||
extend_seq_lens_cpu=input_metadata.extend_seq_lens_cpu,
|
||||
logprob_start_lens_cpu=input_metadata.logprob_start_lens_cpu,
|
||||
)
|
||||
|
||||
|
||||
@@ -75,12 +80,16 @@ class LogitsProcessor(nn.Module):
|
||||
)
|
||||
|
||||
def _get_normalized_prompt_logprobs(
|
||||
self, input_token_logprobs, logits_metadata: LogitsMetadata
|
||||
self,
|
||||
input_token_logprobs: torch.Tensor,
|
||||
cum_start_len0: torch.Tensor,
|
||||
cum_start_len1: torch.Tensor,
|
||||
logits_metadata: LogitsMetadata,
|
||||
):
|
||||
logprobs_cumsum = torch.cumsum(input_token_logprobs, dim=0, dtype=torch.float32)
|
||||
|
||||
start = logits_metadata.extend_start_loc.clone()
|
||||
end = start + logits_metadata.extend_seq_lens - 2
|
||||
start = logits_metadata.extend_start_loc.clone() - cum_start_len0
|
||||
end = start + logits_metadata.extend_seq_lens - 2 - cum_start_len1
|
||||
start.clamp_(min=0, max=input_token_logprobs.shape[0] - 1)
|
||||
end.clamp_(min=0, max=input_token_logprobs.shape[0] - 1)
|
||||
sum_logp = (
|
||||
@@ -93,7 +102,7 @@ class LogitsProcessor(nn.Module):
|
||||
return normalized_prompt_logprobs
|
||||
|
||||
@staticmethod
|
||||
def get_top_logprobs(all_logprobs, logits_metadata: LogitsMetadata):
|
||||
def get_top_logprobs(all_logprobs: torch.Tensor, logits_metadata: LogitsMetadata):
|
||||
if logits_metadata.forward_mode == ForwardMode.DECODE:
|
||||
output_top_logprobs = []
|
||||
max_k = max(logits_metadata.top_logprobs_nums)
|
||||
@@ -107,7 +116,7 @@ class LogitsProcessor(nn.Module):
|
||||
# TODO: vectorize the code below
|
||||
input_top_logprobs, output_top_logprobs = [], []
|
||||
pt = 0
|
||||
extend_seq_lens_cpu = logits_metadata.extend_seq_lens.tolist()
|
||||
extend_seq_lens_cpu = logits_metadata.extend_seq_lens_cpu
|
||||
|
||||
max_k = max(logits_metadata.top_logprobs_nums)
|
||||
ret = all_logprobs.topk(max_k, dim=1)
|
||||
@@ -115,26 +124,30 @@ class LogitsProcessor(nn.Module):
|
||||
indices = ret.indices.tolist()
|
||||
|
||||
for i, extend_seq_len in enumerate(extend_seq_lens_cpu):
|
||||
start_len = logits_metadata.logprob_start_lens_cpu[i]
|
||||
pruned_len = extend_seq_len - start_len
|
||||
|
||||
if extend_seq_len == 0:
|
||||
input_top_logprobs.append([])
|
||||
output_top_logprobs.append([])
|
||||
continue
|
||||
|
||||
k = logits_metadata.top_logprobs_nums[i]
|
||||
input_top_logprobs.append(
|
||||
[
|
||||
list(zip(values[pt + j][:k], indices[pt + j][:k]))
|
||||
for j in range(extend_seq_len - 1)
|
||||
for j in range(pruned_len - 1)
|
||||
]
|
||||
)
|
||||
output_top_logprobs.append(
|
||||
list(
|
||||
zip(
|
||||
values[pt + extend_seq_len - 1][:k],
|
||||
indices[pt + extend_seq_len - 1][:k],
|
||||
values[pt + pruned_len - 1][:k],
|
||||
indices[pt + pruned_len - 1][:k],
|
||||
)
|
||||
)
|
||||
)
|
||||
pt += extend_seq_len
|
||||
pt += pruned_len
|
||||
|
||||
return input_top_logprobs, output_top_logprobs
|
||||
|
||||
@@ -205,7 +218,23 @@ class LogitsProcessor(nn.Module):
|
||||
output_top_logprobs=output_top_logprobs,
|
||||
)
|
||||
else:
|
||||
all_logits = torch.matmul(hidden_states, weight.T)
|
||||
pt, states, pruned_input_ids = 0, [], []
|
||||
for i, extend_len in enumerate(logits_metadata.extend_seq_lens_cpu):
|
||||
start_len = logits_metadata.logprob_start_lens_cpu[i]
|
||||
states.append(hidden_states[pt + start_len : pt + extend_len])
|
||||
pruned_input_ids.append(input_ids[pt + start_len : pt + extend_len])
|
||||
pt += extend_len
|
||||
|
||||
states = torch.cat(states, dim=0)
|
||||
pruned_input_ids = torch.cat(pruned_input_ids, dim=0)
|
||||
|
||||
cum_start_len1 = torch.tensor(
|
||||
logits_metadata.logprob_start_lens_cpu, device="cuda"
|
||||
).cumsum(0)
|
||||
cum_start_len0 = torch.zeros_like(cum_start_len1)
|
||||
cum_start_len0[1:] = cum_start_len1[:-1]
|
||||
|
||||
all_logits = torch.matmul(states, weight.T)
|
||||
if self.do_tensor_parallel_all_gather:
|
||||
all_logits = tensor_model_parallel_all_gather(all_logits)
|
||||
all_logits = all_logits[:, : self.config.vocab_size].float()
|
||||
@@ -230,19 +259,25 @@ class LogitsProcessor(nn.Module):
|
||||
else:
|
||||
input_top_logprobs = output_top_logprobs = None
|
||||
|
||||
last_logprobs = all_logprobs[last_index]
|
||||
last_logprobs = all_logprobs[last_index - cum_start_len1]
|
||||
|
||||
# Compute the logprobs and normalized logprobs for the prefill tokens.
|
||||
# Note that we pad a zero at the end of each sequence for easy computation.
|
||||
input_token_logprobs = all_logprobs[
|
||||
torch.arange(all_logprobs.shape[0], device="cuda"),
|
||||
torch.cat([input_ids[1:], torch.tensor([0], device="cuda")]),
|
||||
torch.cat([pruned_input_ids[1:], torch.tensor([0], device="cuda")]),
|
||||
]
|
||||
|
||||
normalized_prompt_logprobs = self._get_normalized_prompt_logprobs(
|
||||
input_token_logprobs, logits_metadata
|
||||
input_token_logprobs,
|
||||
cum_start_len0,
|
||||
cum_start_len1,
|
||||
logits_metadata,
|
||||
)
|
||||
|
||||
# Remove the last token logprob for the prefill tokens.
|
||||
input_token_logprobs = input_token_logprobs[:-1]
|
||||
|
||||
return LogitProcessorOutput(
|
||||
next_token_logits=last_logits,
|
||||
next_token_logprobs=last_logprobs,
|
||||
|
||||
@@ -75,7 +75,7 @@ class GenerateReqInput:
|
||||
if self.return_logprob is None:
|
||||
self.return_logprob = False
|
||||
if self.logprob_start_len is None:
|
||||
self.logprob_start_len = 0
|
||||
self.logprob_start_len = -1
|
||||
if self.top_logprobs_num is None:
|
||||
self.top_logprobs_num = 0
|
||||
else:
|
||||
@@ -141,7 +141,7 @@ class GenerateReqInput:
|
||||
self.return_logprob = [self.return_logprob] * num
|
||||
|
||||
if self.logprob_start_len is None:
|
||||
self.logprob_start_len = [0] * num
|
||||
self.logprob_start_len = [-1] * num
|
||||
elif not isinstance(self.logprob_start_len, list):
|
||||
self.logprob_start_len = [self.logprob_start_len] * num
|
||||
|
||||
|
||||
@@ -195,6 +195,9 @@ class TokenizerManager:
|
||||
if not_use_index
|
||||
else obj.logprob_start_len[index]
|
||||
)
|
||||
if return_logprob and logprob_start_len == -1:
|
||||
logprob_start_len = len(input_ids) - 1
|
||||
|
||||
top_logprobs_num = (
|
||||
obj.top_logprobs_num
|
||||
if not_use_index
|
||||
@@ -245,6 +248,8 @@ class TokenizerManager:
|
||||
top_logprobs_num = obj.top_logprobs_num[0]
|
||||
|
||||
if self.is_generation:
|
||||
if return_logprob and logprob_start_len == -1:
|
||||
logprob_start_len = len(input_ids) - 1
|
||||
tokenized_obj = TokenizedGenerateReqInput(
|
||||
rid,
|
||||
input_text,
|
||||
@@ -334,6 +339,8 @@ class TokenizerManager:
|
||||
sampling_params = self._get_sampling_params(obj.sampling_params[index])
|
||||
|
||||
if self.is_generation:
|
||||
if obj.return_logprob[index] and obj.logprob_start_len[index] == -1:
|
||||
obj.logprob_start_len[index] = len(input_ids) - 1
|
||||
pixel_values, image_hash, image_size = await self._get_pixel_values(
|
||||
obj.image_data[index]
|
||||
)
|
||||
|
||||
@@ -61,9 +61,11 @@ class InputMetadata:
|
||||
extend_start_loc: torch.Tensor = None
|
||||
extend_no_prefix: bool = None
|
||||
|
||||
# Output options
|
||||
# For logprob
|
||||
return_logprob: bool = False
|
||||
top_logprobs_nums: List[int] = None
|
||||
extend_seq_lens_cpu: List[int] = None
|
||||
logprob_start_lens_cpu: List[int] = None
|
||||
|
||||
# For multimodal
|
||||
pixel_values: List[torch.Tensor] = None
|
||||
@@ -139,6 +141,7 @@ class InputMetadata:
|
||||
def compute_extend_infos(self, batch: ScheduleBatch):
|
||||
if self.forward_mode == ForwardMode.DECODE:
|
||||
self.extend_seq_lens = self.extend_start_loc = self.extend_no_prefix = None
|
||||
self.extend_seq_lens_cpu = self.logprob_start_lens_cpu = None
|
||||
else:
|
||||
extend_lens_cpu = [
|
||||
len(r.fill_ids) - batch.prefix_lens_cpu[i]
|
||||
@@ -149,6 +152,19 @@ class InputMetadata:
|
||||
self.extend_start_loc[1:] = torch.cumsum(self.extend_seq_lens[:-1], dim=0)
|
||||
self.extend_no_prefix = all(l == 0 for l in batch.prefix_lens_cpu)
|
||||
|
||||
self.extend_seq_lens_cpu = extend_lens_cpu
|
||||
self.logprob_start_lens_cpu = [
|
||||
(
|
||||
min(
|
||||
req.logprob_start_len - batch.prefix_lens_cpu[i],
|
||||
extend_lens_cpu[i] - 1,
|
||||
)
|
||||
if req.logprob_start_len >= batch.prefix_lens_cpu[i]
|
||||
else extend_lens_cpu[i] - 1 # Fake extend, actually decode
|
||||
)
|
||||
for i, req in enumerate(batch.reqs)
|
||||
]
|
||||
|
||||
@classmethod
|
||||
def from_schedule_batch(
|
||||
cls,
|
||||
|
||||
@@ -20,6 +20,7 @@ import json
|
||||
import os
|
||||
import time
|
||||
import uuid
|
||||
import warnings
|
||||
from http import HTTPStatus
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
@@ -383,20 +384,33 @@ async def v1_retrieve_file_content(file_id: str):
|
||||
return StreamingResponse(iter_file(), media_type="application/octet-stream")
|
||||
|
||||
|
||||
def v1_generate_request(all_requests):
|
||||
def v1_generate_request(all_requests: List[CompletionRequest]):
|
||||
prompts = []
|
||||
sampling_params_list = []
|
||||
return_logprobs = []
|
||||
logprob_start_lens = []
|
||||
top_logprobs_nums = []
|
||||
|
||||
# NOTE: with openai API, the prompt's logprobs are always not computed
|
||||
first_prompt_type = type(all_requests[0].prompt)
|
||||
for request in all_requests:
|
||||
assert (
|
||||
type(request.prompt) == first_prompt_type
|
||||
), "All prompts must be of the same type in file input settings"
|
||||
if len(all_requests) > 1 and request.n > 1:
|
||||
raise ValueError(
|
||||
"Parallel sampling is not supported for completions from files"
|
||||
)
|
||||
if request.echo and request.logprobs:
|
||||
warnings.warn(
|
||||
"Echo is not compatible with logprobs. "
|
||||
"To compute logprobs of input prompt, please use SGLang /request API."
|
||||
)
|
||||
|
||||
for request in all_requests:
|
||||
prompt = request.prompt
|
||||
assert (
|
||||
type(prompt) == first_prompt_type
|
||||
), "All prompts must be of the same type in file input settings"
|
||||
prompts.append(prompt)
|
||||
prompts.append(request.prompt)
|
||||
return_logprobs.append(request.logprobs is not None and request.logprobs > 0)
|
||||
logprob_start_lens.append(-1)
|
||||
top_logprobs_nums.append(
|
||||
request.logprobs if request.logprobs is not None else 0
|
||||
)
|
||||
@@ -416,14 +430,11 @@ def v1_generate_request(all_requests):
|
||||
"ignore_eos": request.ignore_eos,
|
||||
}
|
||||
)
|
||||
if len(all_requests) > 1 and request.n > 1:
|
||||
raise ValueError(
|
||||
"Parallel sampling is not supported for completions from files"
|
||||
)
|
||||
|
||||
if len(all_requests) == 1:
|
||||
prompt = prompts[0]
|
||||
sampling_params_list = sampling_params_list[0]
|
||||
logprob_start_lens = logprob_start_lens[0]
|
||||
return_logprobs = return_logprobs[0]
|
||||
top_logprobs_nums = top_logprobs_nums[0]
|
||||
if isinstance(prompt, str) or isinstance(prompt[0], str):
|
||||
@@ -441,6 +452,7 @@ def v1_generate_request(all_requests):
|
||||
sampling_params=sampling_params_list,
|
||||
return_logprob=return_logprobs,
|
||||
top_logprobs_num=top_logprobs_nums,
|
||||
logprob_start_len=logprob_start_lens,
|
||||
return_text_in_logprobs=True,
|
||||
stream=all_requests[0].stream,
|
||||
)
|
||||
@@ -694,12 +706,18 @@ async def v1_completions(tokenizer_manager, raw_request: Request):
|
||||
return response
|
||||
|
||||
|
||||
def v1_chat_generate_request(all_requests, tokenizer_manager):
|
||||
def v1_chat_generate_request(
|
||||
all_requests: List[ChatCompletionRequest], tokenizer_manager
|
||||
):
|
||||
input_ids = []
|
||||
sampling_params_list = []
|
||||
image_data_list = []
|
||||
return_logprobs = []
|
||||
logprob_start_lens = []
|
||||
top_logprobs_nums = []
|
||||
|
||||
# NOTE: with openai API, the prompt's logprobs are always not computed
|
||||
|
||||
for request in all_requests:
|
||||
# Prep the data needed for the underlying GenerateReqInput:
|
||||
# - prompt: The full prompt string.
|
||||
@@ -732,6 +750,7 @@ def v1_chat_generate_request(all_requests, tokenizer_manager):
|
||||
image_data = None
|
||||
input_ids.append(prompt_ids)
|
||||
return_logprobs.append(request.logprobs)
|
||||
logprob_start_lens.append(-1)
|
||||
top_logprobs_nums.append(request.top_logprobs)
|
||||
sampling_params_list.append(
|
||||
{
|
||||
@@ -758,17 +777,20 @@ def v1_chat_generate_request(all_requests, tokenizer_manager):
|
||||
sampling_params_list = sampling_params_list[0]
|
||||
image_data = image_data_list[0]
|
||||
return_logprobs = return_logprobs[0]
|
||||
logprob_start_lens = logprob_start_lens[0]
|
||||
top_logprobs_nums = top_logprobs_nums[0]
|
||||
else:
|
||||
if isinstance(input_ids[0], str):
|
||||
prompt_kwargs = {"text": input_ids}
|
||||
else:
|
||||
prompt_kwargs = {"input_ids": input_ids}
|
||||
|
||||
adapted_request = GenerateReqInput(
|
||||
**prompt_kwargs,
|
||||
image_data=image_data,
|
||||
sampling_params=sampling_params_list,
|
||||
return_logprob=return_logprobs,
|
||||
logprob_start_len=logprob_start_lens,
|
||||
top_logprobs_num=top_logprobs_nums,
|
||||
stream=all_requests[0].stream,
|
||||
return_text_in_logprobs=True,
|
||||
|
||||
@@ -559,12 +559,14 @@ class Runtime:
|
||||
prompt: str,
|
||||
sampling_params: Optional[Dict] = None,
|
||||
return_logprob: Optional[Union[List[bool], bool]] = False,
|
||||
logprob_start_len: Optional[Union[List[int], int]] = None,
|
||||
top_logprobs_num: Optional[Union[List[int], int]] = None,
|
||||
):
|
||||
json_data = {
|
||||
"text": prompt,
|
||||
"sampling_params": sampling_params,
|
||||
"return_logprob": return_logprob,
|
||||
"logprob_start_len": logprob_start_len,
|
||||
"top_logprobs_num": top_logprobs_num,
|
||||
}
|
||||
response = requests.post(
|
||||
|
||||
@@ -209,6 +209,7 @@ class SRTRunner:
|
||||
prompt,
|
||||
sampling_params=sampling_params,
|
||||
return_logprob=True,
|
||||
logprob_start_len=0,
|
||||
top_logprobs_num=NUM_TOP_LOGPROBS,
|
||||
)
|
||||
response = json.loads(response)
|
||||
|
||||
Reference in New Issue
Block a user