Fix logprob in the overlapped mode (#1795)
This commit is contained in:
@@ -33,17 +33,17 @@ class LogitsProcessorOutput:
|
||||
# The logits of the next tokens. shape: [#seq, vocab_size]
|
||||
next_token_logits: torch.Tensor
|
||||
# The logprobs of the next tokens. shape: [#seq, vocab_size]
|
||||
next_token_logprobs: torch.Tensor
|
||||
next_token_logprobs: torch.Tensor = None
|
||||
|
||||
# The normlaized logprobs of prompts. shape: [#seq]
|
||||
normalized_prompt_logprobs: torch.Tensor
|
||||
normalized_prompt_logprobs: torch.Tensor = None
|
||||
# The logprobs of input tokens. shape: [#token, vocab_size]
|
||||
input_token_logprobs: torch.Tensor
|
||||
input_token_logprobs: torch.Tensor = None
|
||||
|
||||
# The logprob and id of the top-k tokens in input positions. shape [#seq, #token, k] of Tuple(logprob, token_id)
|
||||
input_top_logprobs: List
|
||||
input_top_logprobs: List = None
|
||||
# The logprob and id of the top-k tokens in output positions. shape [#seq, #token, k] of Tuple(logprob, token_id)
|
||||
output_top_logprobs: List
|
||||
output_top_logprobs: List = None
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
|
||||
@@ -833,6 +833,7 @@ class Scheduler:
|
||||
|
||||
if self.enable_overlap:
|
||||
logits_output, next_token_ids = self.tp_worker.resulve_batch_result(bid)
|
||||
next_token_logprobs = logits_output.next_token_logprobs
|
||||
else:
|
||||
# Move next_token_ids and logprobs to cpu
|
||||
if batch.return_logprob:
|
||||
|
||||
@@ -103,6 +103,8 @@ class TpModelWorkerClient:
|
||||
while True:
|
||||
self.has_inflight_batch = False
|
||||
model_worker_batch, future_token_ids_ct = self.input_queue.get()
|
||||
if not model_worker_batch:
|
||||
break
|
||||
self.has_inflight_batch = True
|
||||
self.launch_event = threading.Event()
|
||||
|
||||
@@ -122,19 +124,48 @@ class TpModelWorkerClient:
|
||||
] = next_token_ids
|
||||
|
||||
# Copy results to the CPU
|
||||
if model_worker_batch.return_logprob:
|
||||
logits_output.next_token_logprobs = logits_output.next_token_logprobs[
|
||||
torch.arange(len(next_token_ids), device=self.device),
|
||||
next_token_ids,
|
||||
].to("cpu", non_blocking=True)
|
||||
if logits_output.input_token_logprobs is not None:
|
||||
logits_output.input_token_logprobs = (
|
||||
logits_output.input_token_logprobs.to("cpu", non_blocking=True)
|
||||
)
|
||||
logits_output.normalized_prompt_logprobs = (
|
||||
logits_output.normalized_prompt_logprobs.to(
|
||||
"cpu", non_blocking=True
|
||||
)
|
||||
)
|
||||
next_token_ids = next_token_ids.to("cpu", non_blocking=True)
|
||||
copy_event = torch.cuda.Event(blocking=True)
|
||||
copy_event.record()
|
||||
|
||||
self.launch_event.set()
|
||||
self.copy_queue.put((copy_event, next_token_ids))
|
||||
self.copy_queue.put((copy_event, logits_output, next_token_ids))
|
||||
|
||||
def copy_thread_func(self):
|
||||
while True:
|
||||
copy_event, next_token_ids = self.copy_queue.get()
|
||||
copy_event, logits_output, next_token_ids = self.copy_queue.get()
|
||||
if not copy_event:
|
||||
break
|
||||
while not copy_event.query():
|
||||
time.sleep(1e-5)
|
||||
self.output_queue.put((None, next_token_ids.tolist()))
|
||||
|
||||
if logits_output.next_token_logprobs is not None:
|
||||
logits_output.next_token_logprobs = (
|
||||
logits_output.next_token_logprobs.tolist()
|
||||
)
|
||||
if logits_output.input_token_logprobs is not None:
|
||||
logits_output.input_token_logprobs = (
|
||||
logits_output.input_token_logprobs.tolist()
|
||||
)
|
||||
logits_output.normalized_prompt_logprobs = (
|
||||
logits_output.normalized_prompt_logprobs.tolist()
|
||||
)
|
||||
|
||||
self.output_queue.put((logits_output, next_token_ids.tolist()))
|
||||
|
||||
def resulve_batch_result(self, bid: int):
|
||||
logits_output, next_token_ids = self.output_queue.get()
|
||||
@@ -172,3 +203,7 @@ class TpModelWorkerClient:
|
||||
recv_req.model_path, recv_req.load_format
|
||||
)
|
||||
return success, message
|
||||
|
||||
def __delete__(self):
|
||||
self.input_queue.put((None, None))
|
||||
self.copy_queue.put((None, None, None))
|
||||
|
||||
@@ -263,7 +263,8 @@ class CudaGraphRunner:
|
||||
positions=clamp_position(seq_lens),
|
||||
mrope_positions=mrope_positions,
|
||||
)
|
||||
return forward(input_ids, forward_batch.positions, forward_batch)
|
||||
logits_output = forward(input_ids, forward_batch.positions, forward_batch)
|
||||
return logits_output.next_token_logits
|
||||
|
||||
for _ in range(2):
|
||||
torch.cuda.synchronize()
|
||||
@@ -318,23 +319,16 @@ class CudaGraphRunner:
|
||||
|
||||
# Replay
|
||||
self.graphs[bs].replay()
|
||||
logits_output = self.output_buffers[bs]
|
||||
|
||||
# Unpad
|
||||
if bs != raw_bs:
|
||||
logits_output = LogitsProcessorOutput(
|
||||
next_token_logits=logits_output.next_token_logits[:raw_bs],
|
||||
next_token_logprobs=None,
|
||||
normalized_prompt_logprobs=None,
|
||||
input_token_logprobs=None,
|
||||
input_top_logprobs=None,
|
||||
output_top_logprobs=None,
|
||||
)
|
||||
next_token_logits = self.output_buffers[bs][:raw_bs]
|
||||
|
||||
# Extract logprobs
|
||||
if forward_batch.return_logprob:
|
||||
logits_output.next_token_logprobs = torch.nn.functional.log_softmax(
|
||||
logits_output.next_token_logits, dim=-1
|
||||
next_token_logprobs = torch.nn.functional.log_softmax(
|
||||
next_token_logits, dim=-1
|
||||
)
|
||||
logits_output = LogitsProcessorOutput(
|
||||
next_token_logits=next_token_logits,
|
||||
next_token_logprobs=next_token_logprobs,
|
||||
)
|
||||
return_top_logprob = any(x > 0 for x in forward_batch.top_logprobs_nums)
|
||||
if return_top_logprob:
|
||||
@@ -343,7 +337,11 @@ class CudaGraphRunner:
|
||||
top_logprobs_nums=forward_batch.top_logprobs_nums,
|
||||
)
|
||||
logits_output.output_top_logprobs = LogitsProcessor.get_top_logprobs(
|
||||
logits_output.next_token_logprobs, logits_metadata
|
||||
next_token_logprobs, logits_metadata
|
||||
)[1]
|
||||
else:
|
||||
logits_output = LogitsProcessorOutput(
|
||||
next_token_logits=next_token_logits,
|
||||
)
|
||||
|
||||
return logits_output
|
||||
|
||||
Reference in New Issue
Block a user