Fix select and normalized logprobs (#67)

This commit is contained in:
Lianmin Zheng
2024-01-21 01:39:23 -08:00
committed by GitHub
parent 11f3cca64f
commit a837166e6f
3 changed files with 52 additions and 101 deletions

View File

@@ -1,79 +0,0 @@
import torch
import triton
import triton.language as tl
from sglang.srt.utils import wrap_kernel_launcher
@triton.jit
def _fwd_segmented_gather(
    all_logits,
    len_add_1,
    cum_len,
    input_ids,
    logprobs,
    max_seq_len,
    voc_size: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    # Per-request gather of the logit selected by the *next* input token.
    #
    # One program instance handles one request (indexed by program_id(0)).
    # For request r with length L = len_add_1[r] and inclusive running total
    # C = cum_len[r], the request's rows start at flat position C - L.
    # For each position p in [0, L - 1) the kernel reads
    #     all_logits[(C - L + p) * voc_size + input_ids[C - L + p + 1]]
    # and writes it to logprobs[C - L - r + p]; i.e. each request contributes
    # L - 1 values and the output is packed without gaps between requests.
    #
    # max_seq_len only bounds the number of BLOCK_SIZE-wide chunks iterated;
    # positions beyond the request's own length are masked off.
    #
    # NOTE(review): the row offset `(cum_l - cur_l + off) * voc_size + idx`
    # looks like it is evaluated in the integer width of the loaded lengths
    # (the caller builds cum_len as int32) -- confirm this cannot overflow for
    # large token_count * voc_size.
    cur_req = tl.program_id(0)
    cur_l = tl.load(len_add_1 + cur_req)
    cum_l = tl.load(cum_len + cur_req)
    for i in range(0, (max_seq_len + BLOCK_SIZE - 1) // BLOCK_SIZE):
        off = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
        # The last position of a request has no "next token", hence `- 1`.
        mask = off < cur_l - 1
        # Token id that follows position `off` inside this request.
        idx = tl.load(input_ids + cum_l - cur_l + off + 1, mask=mask)
        # all_logits is addressed as a flat row-major buffer with row stride voc_size.
        data = tl.load(all_logits + (cum_l - cur_l + off) * voc_size + idx, mask=mask)
        # Subtracting cur_req removes the one dropped slot per preceding request.
        tl.store(logprobs + cum_l - cur_l - cur_req + off, data, mask=mask)
# Launcher returned by wrap_kernel_launcher after the first JIT launch.
# Stays None until get_selected_logprob has launched the kernel once.
cached_kernel = None


def get_selected_logprob(all_logits, len_add_1, input_ids, logprobs):
    """Fill ``logprobs`` in place with the logit of each next input token.

    Launches ``_fwd_segmented_gather`` with one program per request. The
    first call goes through the Triton JIT entry point and then stores a
    launcher from ``wrap_kernel_launcher`` in ``cached_kernel``; later calls
    use that launcher directly.

    Args:
        all_logits: 2-D tensor; ``shape[1]`` is used as the vocabulary size.
        len_add_1: 1-D tensor with one length per request.
        input_ids: Flat token ids for all requests, concatenated.
        logprobs: Output buffer written by the kernel (each request
            contributes ``length - 1`` values).

    Returns:
        None. The result is written into ``logprobs``.
    """
    global cached_kernel

    num_reqs = len_add_1.shape[0]
    if num_reqs == 0:
        # Nothing to gather. Returning early also avoids calling .max() on an
        # empty tensor (which raises) and launching a zero-sized grid.
        return

    cum_len = torch.cumsum(len_add_1, dtype=torch.int32, dim=0)
    voc_size = all_logits.shape[1]
    grid = (num_reqs, 1, 1)
    # .item() copies the value to the host; the kernel uses it as a loop bound.
    max_seq_len = len_add_1.max().item()

    if cached_kernel is not None:
        # Cached launcher path: constexpr arguments (voc_size, BLOCK_SIZE) are
        # not passed here. The literal 4 is presumably num_warps -- confirm
        # against wrap_kernel_launcher.
        cached_kernel(
            grid,
            4,
            all_logits,
            len_add_1,
            cum_len,
            input_ids,
            logprobs,
            max_seq_len,
        )
        return

    _fwd_segmented_gather[grid](
        all_logits,
        len_add_1,
        cum_len,
        input_ids,
        logprobs,
        max_seq_len,
        voc_size,
        BLOCK_SIZE=128,
    )
    cached_kernel = wrap_kernel_launcher(_fwd_segmented_gather)
if __name__ == "__main__":
    # Manual smoke test; needs a CUDA device.
    # Two requests of lengths 2 and 3 share the flat buffers below. The
    # entries picked are row 0 / col 2, row 2 / col 0 and row 3 / col 1.
    all_logits = torch.tensor(
        [[0, 1, 2, 3], [1, 2, 3, 4], [2, 3, 4, 5], [3, 4, 5, 6], [4, 5, 6, 7]],
        dtype=torch.float32,
        device="cuda",
    )
    len_add_1 = torch.tensor([2, 3], dtype=torch.int32, device="cuda")
    input_ids = torch.tensor([1, 2, 3, 0, 1], dtype=torch.int32, device="cuda")
    logprobs = torch.empty((3), dtype=torch.float32, device="cuda")

    # Fixed: the function defined in this module is `get_selected_logprob`;
    # the previous call used a non-existent plural name and raised NameError.
    get_selected_logprob(all_logits, len_add_1, input_ids, logprobs)
    print(logprobs)
    # Expected values: [2, 2, 4]

View File

@@ -1,5 +1,4 @@
import torch import torch
from sglang.srt.layers.get_selected_logprob import get_selected_logprob
from sglang.srt.managers.router.model_runner import ForwardMode, InputMetadata from sglang.srt.managers.router.model_runner import ForwardMode, InputMetadata
from torch import nn from torch import nn
from vllm.model_executor.parallel_utils.communication_op import ( from vllm.model_executor.parallel_utils.communication_op import (
@@ -54,25 +53,56 @@ class LogitsProcessor(nn.Module):
normalized_logprobs = compute_normalized_logprobs( normalized_logprobs = compute_normalized_logprobs(
all_logprobs, all_logprobs,
input_metadata.seq_lens - input_metadata.prefix_lens,
input_ids, input_ids,
input_metadata.extend_seq_lens,
input_metadata.extend_start_loc,
) )
last_logits = logits[last_index] last_logits = logits[last_index]
return last_logits, normalized_logprobs return last_logits, normalized_logprobs
def compute_normalized_logprobs(all_logprobs, len_add_1, input_ids): def compute_normalized_logprobs(all_logprobs, input_ids, seq_lens, start_loc):
# assert all_logprobs.shape[0] == input_ids.shape[0] == torch.sum(len_add_1) logprobs = all_logprobs[
logprobs = torch.zeros( torch.arange(all_logprobs.shape[0], device="cuda"),
(all_logprobs.shape[0] - len_add_1.shape[0]), dtype=torch.float32, device="cuda" torch.cat([input_ids[1:], torch.tensor([0], device="cuda")]),
]
logprobs_cumsum = torch.cumsum(logprobs, dim=0, dtype=torch.float32)
start = start_loc.clone()
end = start + seq_lens - 2
start.clamp_(min=0, max=logprobs.shape[0] - 1)
end.clamp_(min=0, max=logprobs.shape[0] - 1)
sum_logp = logprobs_cumsum[end] - logprobs_cumsum[start] + logprobs[start]
return sum_logp / ((seq_lens - 1).clamp(min=1))
if __name__ == "__main__":
all_logprobs = torch.tensor(
# s s s
[[0, 1, 2, 3], [1, 2, 3, 4], [2, 3, 4, 5], [3, 4, 5, 6], [4, 5, 6, 7]],
dtype=torch.float32,
device="cuda",
) )
get_selected_logprob(all_logprobs, len_add_1, input_ids, logprobs) seq_lens = torch.tensor([2, 0, 3, 0], dtype=torch.int32, device="cuda")
cumsum = torch.cumsum(logprobs, dim=0, dtype=torch.float32) input_ids = torch.tensor([1, 2, 3, 0, 1], dtype=torch.int32, device="cuda")
end = torch.cumsum(len_add_1.sub_(1), dim=0) logprobs = torch.zeros(5, dtype=torch.float32, device="cuda")
start = torch.cat((torch.tensor([0], device="cuda"), end[:-1]), 0)
end.sub_(1) logprobs = all_logprobs[
torch.cuda.synchronize() torch.arange(all_logprobs.shape[0], device="cuda"),
sum_logp = cumsum[end] - cumsum[start] + logprobs[start] torch.cat([input_ids[1:], torch.tensor([0], device="cuda")]),
res = sum_logp / len_add_1 ]
return res logprobs_cumsum = torch.cumsum(logprobs, dim=0, dtype=torch.float32)
len_cumsum = torch.cumsum(seq_lens, dim=0)
start = torch.cat((torch.tensor([0], device="cuda"), len_cumsum[:-1]), 0)
end = start + seq_lens - 2
start.clamp_(min=0, max=logprobs.shape[0] - 1)
end.clamp_(min=0, max=logprobs.shape[0] - 1)
sum_logp = logprobs_cumsum[end] - logprobs_cumsum[start] + logprobs[start]
# assert logprobs == [2, _, 2, 4, _]
print("logprobs", logprobs)
print("start", start)
print("end", end)
print("sum_logp", sum_logp)

View File

@@ -1,7 +1,7 @@
import logging
from dataclasses import dataclass from dataclasses import dataclass
from enum import Enum, auto from enum import Enum, auto
from typing import List from typing import List
import logging
import numpy as np import numpy as np
import torch import torch
@@ -13,7 +13,6 @@ from vllm.model_executor.layers.quantization.awq import AWQConfig
from vllm.model_executor.model_loader import _set_default_torch_dtype from vllm.model_executor.model_loader import _set_default_torch_dtype
from vllm.model_executor.parallel_utils.parallel_state import initialize_model_parallel from vllm.model_executor.parallel_utils.parallel_state import initialize_model_parallel
logger = logging.getLogger("model_runner") logger = logging.getLogger("model_runner")
@@ -112,7 +111,7 @@ class InputMetadata:
def init_extend_args(self): def init_extend_args(self):
self.extend_seq_lens = self.seq_lens - self.prefix_lens self.extend_seq_lens = self.seq_lens - self.prefix_lens
self.extend_start_loc = torch.zeros_like(self.seq_lens) self.extend_start_loc = torch.zeros_like(self.seq_lens)
self.extend_start_loc[1:] = torch.cumsum(self.extend_seq_lens[:-1], 0) self.extend_start_loc[1:] = torch.cumsum(self.extend_seq_lens[:-1], dim=0)
self.max_extend_len = int(torch.max(self.extend_seq_lens)) self.max_extend_len = int(torch.max(self.extend_seq_lens))
@classmethod @classmethod
@@ -262,7 +261,7 @@ class ModelRunner:
if model_class is None: if model_class is None:
raise ValueError(f"Unsupported architectures: {architectures}") raise ValueError(f"Unsupported architectures: {architectures}")
logger.info("load weight begin.") logger.info(f"Rank {self.tp_rank}: load weight begin.")
# Load weights # Load weights
linear_method = None linear_method = None
@@ -287,7 +286,7 @@ class ModelRunner:
) )
self.model = model.eval() self.model = model.eval()
logger.info("load weight end.") logger.info(f"Rank {self.tp_rank}: load weight end.")
def profile_max_num_token(self, total_gpu_memory): def profile_max_num_token(self, total_gpu_memory):
available_gpu_memory = get_available_gpu_memory( available_gpu_memory = get_available_gpu_memory(
@@ -308,8 +307,9 @@ class ModelRunner:
self.max_total_num_token = self.profile_max_num_token(total_gpu_memory) self.max_total_num_token = self.profile_max_num_token(total_gpu_memory)
if self.max_total_num_token <= 0: if self.max_total_num_token <= 0:
raise RuntimeError("Not enought memory. " raise RuntimeError(
"Please try to increase --mem-fraction-static.") "Not enought memory. " "Please try to increase --mem-fraction-static."
)
self.req_to_token_pool = ReqToTokenPool( self.req_to_token_pool = ReqToTokenPool(
int(self.max_total_num_token / self.model_config.context_len * 256), int(self.max_total_num_token / self.model_config.context_len * 256),