Fix select and normalized logprobs (#67)
This commit is contained in:
@@ -1,5 +1,4 @@
|
||||
import torch
|
||||
from sglang.srt.layers.get_selected_logprob import get_selected_logprob
|
||||
from sglang.srt.managers.router.model_runner import ForwardMode, InputMetadata
|
||||
from torch import nn
|
||||
from vllm.model_executor.parallel_utils.communication_op import (
|
||||
@@ -54,25 +53,56 @@ class LogitsProcessor(nn.Module):
|
||||
|
||||
normalized_logprobs = compute_normalized_logprobs(
|
||||
all_logprobs,
|
||||
input_metadata.seq_lens - input_metadata.prefix_lens,
|
||||
input_ids,
|
||||
input_metadata.extend_seq_lens,
|
||||
input_metadata.extend_start_loc,
|
||||
)
|
||||
|
||||
last_logits = logits[last_index]
|
||||
return last_logits, normalized_logprobs
|
||||
|
||||
|
||||
def compute_normalized_logprobs(all_logprobs, len_add_1, input_ids):
|
||||
# assert all_logprobs.shape[0] == input_ids.shape[0] == torch.sum(len_add_1)
|
||||
logprobs = torch.zeros(
|
||||
(all_logprobs.shape[0] - len_add_1.shape[0]), dtype=torch.float32, device="cuda"
|
||||
def compute_normalized_logprobs(all_logprobs, input_ids, seq_lens, start_loc):
|
||||
logprobs = all_logprobs[
|
||||
torch.arange(all_logprobs.shape[0], device="cuda"),
|
||||
torch.cat([input_ids[1:], torch.tensor([0], device="cuda")]),
|
||||
]
|
||||
logprobs_cumsum = torch.cumsum(logprobs, dim=0, dtype=torch.float32)
|
||||
|
||||
start = start_loc.clone()
|
||||
end = start + seq_lens - 2
|
||||
start.clamp_(min=0, max=logprobs.shape[0] - 1)
|
||||
end.clamp_(min=0, max=logprobs.shape[0] - 1)
|
||||
sum_logp = logprobs_cumsum[end] - logprobs_cumsum[start] + logprobs[start]
|
||||
return sum_logp / ((seq_lens - 1).clamp(min=1))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
all_logprobs = torch.tensor(
|
||||
# s s s
|
||||
[[0, 1, 2, 3], [1, 2, 3, 4], [2, 3, 4, 5], [3, 4, 5, 6], [4, 5, 6, 7]],
|
||||
dtype=torch.float32,
|
||||
device="cuda",
|
||||
)
|
||||
get_selected_logprob(all_logprobs, len_add_1, input_ids, logprobs)
|
||||
cumsum = torch.cumsum(logprobs, dim=0, dtype=torch.float32)
|
||||
end = torch.cumsum(len_add_1.sub_(1), dim=0)
|
||||
start = torch.cat((torch.tensor([0], device="cuda"), end[:-1]), 0)
|
||||
end.sub_(1)
|
||||
torch.cuda.synchronize()
|
||||
sum_logp = cumsum[end] - cumsum[start] + logprobs[start]
|
||||
res = sum_logp / len_add_1
|
||||
return res
|
||||
seq_lens = torch.tensor([2, 0, 3, 0], dtype=torch.int32, device="cuda")
|
||||
input_ids = torch.tensor([1, 2, 3, 0, 1], dtype=torch.int32, device="cuda")
|
||||
logprobs = torch.zeros(5, dtype=torch.float32, device="cuda")
|
||||
|
||||
logprobs = all_logprobs[
|
||||
torch.arange(all_logprobs.shape[0], device="cuda"),
|
||||
torch.cat([input_ids[1:], torch.tensor([0], device="cuda")]),
|
||||
]
|
||||
logprobs_cumsum = torch.cumsum(logprobs, dim=0, dtype=torch.float32)
|
||||
|
||||
len_cumsum = torch.cumsum(seq_lens, dim=0)
|
||||
start = torch.cat((torch.tensor([0], device="cuda"), len_cumsum[:-1]), 0)
|
||||
end = start + seq_lens - 2
|
||||
start.clamp_(min=0, max=logprobs.shape[0] - 1)
|
||||
end.clamp_(min=0, max=logprobs.shape[0] - 1)
|
||||
sum_logp = logprobs_cumsum[end] - logprobs_cumsum[start] + logprobs[start]
|
||||
|
||||
# assert logprobs == [2, _, 2, 4, _]
|
||||
print("logprobs", logprobs)
|
||||
print("start", start)
|
||||
print("end", end)
|
||||
print("sum_logp", sum_logp)
|
||||
|
||||
Reference in New Issue
Block a user