Fix select and normalized logprobs (#67)
This commit is contained in:
@@ -1,79 +0,0 @@
|
||||
import torch
|
||||
import triton
|
||||
import triton.language as tl
|
||||
from sglang.srt.utils import wrap_kernel_launcher
|
||||
|
||||
|
||||
@triton.jit
|
||||
def _fwd_segmented_gather(
|
||||
all_logits,
|
||||
len_add_1,
|
||||
cum_len,
|
||||
input_ids,
|
||||
logprobs,
|
||||
max_seq_len,
|
||||
voc_size: tl.constexpr,
|
||||
BLOCK_SIZE: tl.constexpr,
|
||||
):
|
||||
cur_req = tl.program_id(0)
|
||||
cur_l = tl.load(len_add_1 + cur_req)
|
||||
cum_l = tl.load(cum_len + cur_req)
|
||||
|
||||
for i in range(0, (max_seq_len + BLOCK_SIZE - 1) // BLOCK_SIZE):
|
||||
off = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
|
||||
mask = off < cur_l - 1
|
||||
|
||||
idx = tl.load(input_ids + cum_l - cur_l + off + 1, mask=mask)
|
||||
data = tl.load(all_logits + (cum_l - cur_l + off) * voc_size + idx, mask=mask)
|
||||
tl.store(logprobs + cum_l - cur_l - cur_req + off, data, mask=mask)
|
||||
|
||||
|
||||
cached_kernel = None
|
||||
|
||||
|
||||
def get_selected_logprob(all_logits, len_add_1, input_ids, logprobs):
|
||||
cum_len = torch.cumsum(len_add_1, dtype=torch.int32, dim=0)
|
||||
voc_size = all_logits.shape[1]
|
||||
grid = (len_add_1.shape[0], 1, 1)
|
||||
max_seq_len = len_add_1.max().item()
|
||||
|
||||
global cached_kernel
|
||||
if cached_kernel:
|
||||
cached_kernel(
|
||||
grid,
|
||||
4,
|
||||
all_logits,
|
||||
len_add_1,
|
||||
cum_len,
|
||||
input_ids,
|
||||
logprobs,
|
||||
max_seq_len,
|
||||
)
|
||||
return
|
||||
|
||||
_fwd_segmented_gather[grid](
|
||||
all_logits,
|
||||
len_add_1,
|
||||
cum_len,
|
||||
input_ids,
|
||||
logprobs,
|
||||
max_seq_len,
|
||||
voc_size,
|
||||
BLOCK_SIZE=128,
|
||||
)
|
||||
cached_kernel = wrap_kernel_launcher(_fwd_segmented_gather)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
all_logits = torch.tensor(
|
||||
# s s s
|
||||
[[0, 1, 2, 3], [1, 2, 3, 4], [2, 3, 4, 5], [3, 4, 5, 6], [4, 5, 6, 7]],
|
||||
dtype=torch.float32,
|
||||
device="cuda",
|
||||
)
|
||||
len_add_1 = torch.tensor([2, 3], dtype=torch.int32, device="cuda")
|
||||
input_ids = torch.tensor([1, 2, 3, 0, 1], dtype=torch.int32, device="cuda")
|
||||
logprobs = torch.empty((3), dtype=torch.float32, device="cuda")
|
||||
get_selected_logprobs(all_logits, len_add_1, input_ids, logprobs)
|
||||
print(logprobs)
|
||||
# assert logprobs == [2, 2, 4]
|
||||
@@ -1,5 +1,4 @@
|
||||
import torch
|
||||
from sglang.srt.layers.get_selected_logprob import get_selected_logprob
|
||||
from sglang.srt.managers.router.model_runner import ForwardMode, InputMetadata
|
||||
from torch import nn
|
||||
from vllm.model_executor.parallel_utils.communication_op import (
|
||||
@@ -54,25 +53,56 @@ class LogitsProcessor(nn.Module):
|
||||
|
||||
normalized_logprobs = compute_normalized_logprobs(
|
||||
all_logprobs,
|
||||
input_metadata.seq_lens - input_metadata.prefix_lens,
|
||||
input_ids,
|
||||
input_metadata.extend_seq_lens,
|
||||
input_metadata.extend_start_loc,
|
||||
)
|
||||
|
||||
last_logits = logits[last_index]
|
||||
return last_logits, normalized_logprobs
|
||||
|
||||
|
||||
def compute_normalized_logprobs(all_logprobs, len_add_1, input_ids):
|
||||
# assert all_logprobs.shape[0] == input_ids.shape[0] == torch.sum(len_add_1)
|
||||
logprobs = torch.zeros(
|
||||
(all_logprobs.shape[0] - len_add_1.shape[0]), dtype=torch.float32, device="cuda"
|
||||
def compute_normalized_logprobs(all_logprobs, input_ids, seq_lens, start_loc):
|
||||
logprobs = all_logprobs[
|
||||
torch.arange(all_logprobs.shape[0], device="cuda"),
|
||||
torch.cat([input_ids[1:], torch.tensor([0], device="cuda")]),
|
||||
]
|
||||
logprobs_cumsum = torch.cumsum(logprobs, dim=0, dtype=torch.float32)
|
||||
|
||||
start = start_loc.clone()
|
||||
end = start + seq_lens - 2
|
||||
start.clamp_(min=0, max=logprobs.shape[0] - 1)
|
||||
end.clamp_(min=0, max=logprobs.shape[0] - 1)
|
||||
sum_logp = logprobs_cumsum[end] - logprobs_cumsum[start] + logprobs[start]
|
||||
return sum_logp / ((seq_lens - 1).clamp(min=1))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
all_logprobs = torch.tensor(
|
||||
# s s s
|
||||
[[0, 1, 2, 3], [1, 2, 3, 4], [2, 3, 4, 5], [3, 4, 5, 6], [4, 5, 6, 7]],
|
||||
dtype=torch.float32,
|
||||
device="cuda",
|
||||
)
|
||||
get_selected_logprob(all_logprobs, len_add_1, input_ids, logprobs)
|
||||
cumsum = torch.cumsum(logprobs, dim=0, dtype=torch.float32)
|
||||
end = torch.cumsum(len_add_1.sub_(1), dim=0)
|
||||
start = torch.cat((torch.tensor([0], device="cuda"), end[:-1]), 0)
|
||||
end.sub_(1)
|
||||
torch.cuda.synchronize()
|
||||
sum_logp = cumsum[end] - cumsum[start] + logprobs[start]
|
||||
res = sum_logp / len_add_1
|
||||
return res
|
||||
seq_lens = torch.tensor([2, 0, 3, 0], dtype=torch.int32, device="cuda")
|
||||
input_ids = torch.tensor([1, 2, 3, 0, 1], dtype=torch.int32, device="cuda")
|
||||
logprobs = torch.zeros(5, dtype=torch.float32, device="cuda")
|
||||
|
||||
logprobs = all_logprobs[
|
||||
torch.arange(all_logprobs.shape[0], device="cuda"),
|
||||
torch.cat([input_ids[1:], torch.tensor([0], device="cuda")]),
|
||||
]
|
||||
logprobs_cumsum = torch.cumsum(logprobs, dim=0, dtype=torch.float32)
|
||||
|
||||
len_cumsum = torch.cumsum(seq_lens, dim=0)
|
||||
start = torch.cat((torch.tensor([0], device="cuda"), len_cumsum[:-1]), 0)
|
||||
end = start + seq_lens - 2
|
||||
start.clamp_(min=0, max=logprobs.shape[0] - 1)
|
||||
end.clamp_(min=0, max=logprobs.shape[0] - 1)
|
||||
sum_logp = logprobs_cumsum[end] - logprobs_cumsum[start] + logprobs[start]
|
||||
|
||||
# assert logprobs == [2, _, 2, 4, _]
|
||||
print("logprobs", logprobs)
|
||||
print("start", start)
|
||||
print("end", end)
|
||||
print("sum_logp", sum_logp)
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum, auto
|
||||
from typing import List
|
||||
import logging
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
@@ -13,7 +13,6 @@ from vllm.model_executor.layers.quantization.awq import AWQConfig
|
||||
from vllm.model_executor.model_loader import _set_default_torch_dtype
|
||||
from vllm.model_executor.parallel_utils.parallel_state import initialize_model_parallel
|
||||
|
||||
|
||||
logger = logging.getLogger("model_runner")
|
||||
|
||||
|
||||
@@ -112,7 +111,7 @@ class InputMetadata:
|
||||
def init_extend_args(self):
|
||||
self.extend_seq_lens = self.seq_lens - self.prefix_lens
|
||||
self.extend_start_loc = torch.zeros_like(self.seq_lens)
|
||||
self.extend_start_loc[1:] = torch.cumsum(self.extend_seq_lens[:-1], 0)
|
||||
self.extend_start_loc[1:] = torch.cumsum(self.extend_seq_lens[:-1], dim=0)
|
||||
self.max_extend_len = int(torch.max(self.extend_seq_lens))
|
||||
|
||||
@classmethod
|
||||
@@ -262,7 +261,7 @@ class ModelRunner:
|
||||
if model_class is None:
|
||||
raise ValueError(f"Unsupported architectures: {architectures}")
|
||||
|
||||
logger.info("load weight begin.")
|
||||
logger.info(f"Rank {self.tp_rank}: load weight begin.")
|
||||
|
||||
# Load weights
|
||||
linear_method = None
|
||||
@@ -287,7 +286,7 @@ class ModelRunner:
|
||||
)
|
||||
self.model = model.eval()
|
||||
|
||||
logger.info("load weight end.")
|
||||
logger.info(f"Rank {self.tp_rank}: load weight end.")
|
||||
|
||||
def profile_max_num_token(self, total_gpu_memory):
|
||||
available_gpu_memory = get_available_gpu_memory(
|
||||
@@ -308,8 +307,9 @@ class ModelRunner:
|
||||
self.max_total_num_token = self.profile_max_num_token(total_gpu_memory)
|
||||
|
||||
if self.max_total_num_token <= 0:
|
||||
raise RuntimeError("Not enought memory. "
|
||||
"Please try to increase --mem-fraction-static.")
|
||||
raise RuntimeError(
|
||||
"Not enought memory. " "Please try to increase --mem-fraction-static."
|
||||
)
|
||||
|
||||
self.req_to_token_pool = ReqToTokenPool(
|
||||
int(self.max_total_num_token / self.model_config.context_len * 256),
|
||||
|
||||
Reference in New Issue
Block a user