Fix logprob in the overlapped mode (#1795)

This commit is contained in:
Lianmin Zheng
2024-10-25 11:06:57 -07:00
committed by GitHub
parent c555ce2ca2
commit e646c5901e
7 changed files with 62 additions and 29 deletions

View File

@@ -263,7 +263,8 @@ class CudaGraphRunner:
positions=clamp_position(seq_lens),
mrope_positions=mrope_positions,
)
return forward(input_ids, forward_batch.positions, forward_batch)
logits_output = forward(input_ids, forward_batch.positions, forward_batch)
return logits_output.next_token_logits
for _ in range(2):
torch.cuda.synchronize()
@@ -318,23 +319,16 @@ class CudaGraphRunner:
# Replay
self.graphs[bs].replay()
logits_output = self.output_buffers[bs]
# Unpad
if bs != raw_bs:
logits_output = LogitsProcessorOutput(
next_token_logits=logits_output.next_token_logits[:raw_bs],
next_token_logprobs=None,
normalized_prompt_logprobs=None,
input_token_logprobs=None,
input_top_logprobs=None,
output_top_logprobs=None,
)
next_token_logits = self.output_buffers[bs][:raw_bs]
# Extract logprobs
if forward_batch.return_logprob:
logits_output.next_token_logprobs = torch.nn.functional.log_softmax(
logits_output.next_token_logits, dim=-1
next_token_logprobs = torch.nn.functional.log_softmax(
next_token_logits, dim=-1
)
logits_output = LogitsProcessorOutput(
next_token_logits=next_token_logits,
next_token_logprobs=next_token_logprobs,
)
return_top_logprob = any(x > 0 for x in forward_batch.top_logprobs_nums)
if return_top_logprob:
@@ -343,7 +337,11 @@ class CudaGraphRunner:
top_logprobs_nums=forward_batch.top_logprobs_nums,
)
logits_output.output_top_logprobs = LogitsProcessor.get_top_logprobs(
logits_output.next_token_logprobs, logits_metadata
next_token_logprobs, logits_metadata
)[1]
else:
logits_output = LogitsProcessorOutput(
next_token_logits=next_token_logits,
)
return logits_output