Misc fixes (#432)

This commit is contained in:
Lianmin Zheng
2024-05-12 15:05:40 -07:00
committed by GitHub
parent 72bb344388
commit 6e09cf6a15
6 changed files with 23 additions and 5 deletions

View File

@@ -426,7 +426,9 @@ class ModelRpcServer:
# Only transfer the selected logprobs of the next token to CPU to reduce overhead.
if last_logprobs is not None:
last_token_logprobs = (
last_logprobs[torch.arange(len(batch.reqs)), next_token_ids].tolist()
last_logprobs[
torch.arange(len(batch.reqs), device=next_token_ids.device),
next_token_ids].tolist()
)
next_token_ids = next_token_ids.tolist()
@@ -587,6 +589,7 @@ class ModelRpcServer:
- req.prompt_tokens,
"completion_tokens_wo_jump_forward": req.completion_tokens_wo_jump_forward,
"finish_reason": str(req.finish_reason), # FIXME: convert to the correct string
"hit_stop_str": req.hit_stop_str,
}
if req.return_logprob:
(

View File

@@ -110,8 +110,8 @@ class InputMetadata:
self.kv_last_page_len = torch.ones(
(self.batch_size,), dtype=torch.int32, device="cuda"
)
req_pool_indices_cpu = self.req_pool_indices.cpu().tolist()
seq_lens_cpu = self.seq_lens.tolist()
req_pool_indices_cpu = self.req_pool_indices.cpu().numpy()
seq_lens_cpu = self.seq_lens.cpu().numpy()
self.kv_indices = torch.cat(
[
self.req_to_token_pool.req_to_token[