Fix logprob in the overlapped mode (#1795)
This commit is contained in:
@@ -263,7 +263,8 @@ class CudaGraphRunner:
|
||||
positions=clamp_position(seq_lens),
|
||||
mrope_positions=mrope_positions,
|
||||
)
|
||||
return forward(input_ids, forward_batch.positions, forward_batch)
|
||||
logits_output = forward(input_ids, forward_batch.positions, forward_batch)
|
||||
return logits_output.next_token_logits
|
||||
|
||||
for _ in range(2):
|
||||
torch.cuda.synchronize()
|
||||
@@ -318,23 +319,16 @@ class CudaGraphRunner:
|
||||
|
||||
# Replay
|
||||
self.graphs[bs].replay()
|
||||
logits_output = self.output_buffers[bs]
|
||||
|
||||
# Unpad
|
||||
if bs != raw_bs:
|
||||
logits_output = LogitsProcessorOutput(
|
||||
next_token_logits=logits_output.next_token_logits[:raw_bs],
|
||||
next_token_logprobs=None,
|
||||
normalized_prompt_logprobs=None,
|
||||
input_token_logprobs=None,
|
||||
input_top_logprobs=None,
|
||||
output_top_logprobs=None,
|
||||
)
|
||||
next_token_logits = self.output_buffers[bs][:raw_bs]
|
||||
|
||||
# Extract logprobs
|
||||
if forward_batch.return_logprob:
|
||||
logits_output.next_token_logprobs = torch.nn.functional.log_softmax(
|
||||
logits_output.next_token_logits, dim=-1
|
||||
next_token_logprobs = torch.nn.functional.log_softmax(
|
||||
next_token_logits, dim=-1
|
||||
)
|
||||
logits_output = LogitsProcessorOutput(
|
||||
next_token_logits=next_token_logits,
|
||||
next_token_logprobs=next_token_logprobs,
|
||||
)
|
||||
return_top_logprob = any(x > 0 for x in forward_batch.top_logprobs_nums)
|
||||
if return_top_logprob:
|
||||
@@ -343,7 +337,11 @@ class CudaGraphRunner:
|
||||
top_logprobs_nums=forward_batch.top_logprobs_nums,
|
||||
)
|
||||
logits_output.output_top_logprobs = LogitsProcessor.get_top_logprobs(
|
||||
logits_output.next_token_logprobs, logits_metadata
|
||||
next_token_logprobs, logits_metadata
|
||||
)[1]
|
||||
else:
|
||||
logits_output = LogitsProcessorOutput(
|
||||
next_token_logits=next_token_logits,
|
||||
)
|
||||
|
||||
return logits_output
|
||||
|
||||
Reference in New Issue
Block a user