Fix select and normalized logprobs (#67)
This commit is contained in:
@@ -1,79 +0,0 @@
|
|||||||
import torch
|
|
||||||
import triton
|
|
||||||
import triton.language as tl
|
|
||||||
from sglang.srt.utils import wrap_kernel_launcher
|
|
||||||
|
|
||||||
|
|
||||||
# Segmented gather kernel: one program instance per request (grid axis 0).
#
# For request r the segment length is len_add_1[r] and the segment end is
# cum_len[r], so the segment starts at row cum_len[r] - len_add_1[r].
# For each position p < len_add_1[r] - 1 the kernel reads the token id stored
# one slot later in input_ids, reads the entry of all_logits at
# (segment_start + p) * voc_size + token, and writes it to
# logprobs[segment_start - r + p].
# Assumes all_logits is a contiguous row-major (rows, voc_size) buffer --
# TODO confirm against the caller.
@triton.jit
def _fwd_segmented_gather(
    all_logits,
    len_add_1,
    cum_len,
    input_ids,
    logprobs,
    max_seq_len,
    voc_size: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    req = tl.program_id(0)
    seg_len = tl.load(len_add_1 + req)
    seg_end = tl.load(cum_len + req)
    seg_start = seg_end - seg_len

    # Every program walks ceil(max_seq_len / BLOCK_SIZE) blocks; the mask
    # disables lanes that fall outside this request's own segment.
    for blk in range(0, (max_seq_len + BLOCK_SIZE - 1) // BLOCK_SIZE):
        pos = blk * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
        valid = pos < seg_len - 1
        row = seg_start + pos

        # Token id stored one slot after the current row.
        tok = tl.load(input_ids + row + 1, mask=valid)
        val = tl.load(all_logits + row * voc_size + tok, mask=valid)
        # Each request contributes seg_len - 1 outputs, hence the "- req"
        # shift of the write position.
        tl.store(logprobs + seg_start - req + pos, val, mask=valid)
|
|
||||||
|
|
||||||
|
|
||||||
# Launcher returned by wrap_kernel_launcher; set on the first call to
# get_selected_logprob and reused by every later call.
cached_kernel = None


def get_selected_logprob(all_logits, len_add_1, input_ids, logprobs):
    """Launch the segmented-gather kernel, writing results into ``logprobs``.

    The grid has one program per entry of ``len_add_1``. The first call goes
    through the Triton JIT entry point and then stores a launcher built by
    ``wrap_kernel_launcher``; later calls use that stored launcher directly.

    Args:
        all_logits: 2-D tensor; its second dimension is passed as the
            vocabulary size.
        len_add_1: 1-D tensor of per-request segment lengths.
        input_ids: token ids read by the kernel.
        logprobs: output tensor, filled in place.

    Returns:
        None.
    """
    global cached_kernel

    cum_len = torch.cumsum(len_add_1, dtype=torch.int32, dim=0)
    voc_size = all_logits.shape[1]
    grid = (len_add_1.shape[0], 1, 1)
    max_seq_len = len_add_1.max().item()
    kernel_args = (all_logits, len_add_1, cum_len, input_ids, logprobs, max_seq_len)

    if not cached_kernel:
        # First call: JIT-compile and run, then keep a launcher for reuse.
        _fwd_segmented_gather[grid](*kernel_args, voc_size, BLOCK_SIZE=128)
        cached_kernel = wrap_kernel_launcher(_fwd_segmented_gather)
        return

    # NOTE(review): the literal 4 is presumably num_warps -- confirm against
    # wrap_kernel_launcher. voc_size and BLOCK_SIZE are not passed on this path.
    cached_kernel(grid, 4, *kernel_args)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
    # Manual smoke test (requires a CUDA device): two requests with
    # len_add_1 = [2, 3] packed into five rows of logits.
    all_logits = torch.tensor(
        [[0, 1, 2, 3], [1, 2, 3, 4], [2, 3, 4, 5], [3, 4, 5, 6], [4, 5, 6, 7]],
        dtype=torch.float32,
        device="cuda",
    )
    len_add_1 = torch.tensor([2, 3], dtype=torch.int32, device="cuda")
    input_ids = torch.tensor([1, 2, 3, 0, 1], dtype=torch.int32, device="cuda")
    # One output per request position except the last: (2 - 1) + (3 - 1) = 3.
    logprobs = torch.empty(3, dtype=torch.float32, device="cuda")

    # Fixed: the original called the undefined name `get_selected_logprobs`.
    get_selected_logprob(all_logits, len_add_1, input_ids, logprobs)
    print(logprobs)
    # Expected values: [2, 2, 4]
|
|
||||||
@@ -1,5 +1,4 @@
|
|||||||
import torch
|
import torch
|
||||||
from sglang.srt.layers.get_selected_logprob import get_selected_logprob
|
|
||||||
from sglang.srt.managers.router.model_runner import ForwardMode, InputMetadata
|
from sglang.srt.managers.router.model_runner import ForwardMode, InputMetadata
|
||||||
from torch import nn
|
from torch import nn
|
||||||
from vllm.model_executor.parallel_utils.communication_op import (
|
from vllm.model_executor.parallel_utils.communication_op import (
|
||||||
@@ -54,25 +53,56 @@ class LogitsProcessor(nn.Module):
|
|||||||
|
|
||||||
normalized_logprobs = compute_normalized_logprobs(
|
normalized_logprobs = compute_normalized_logprobs(
|
||||||
all_logprobs,
|
all_logprobs,
|
||||||
input_metadata.seq_lens - input_metadata.prefix_lens,
|
|
||||||
input_ids,
|
input_ids,
|
||||||
|
input_metadata.extend_seq_lens,
|
||||||
|
input_metadata.extend_start_loc,
|
||||||
)
|
)
|
||||||
|
|
||||||
last_logits = logits[last_index]
|
last_logits = logits[last_index]
|
||||||
return last_logits, normalized_logprobs
|
return last_logits, normalized_logprobs
|
||||||
|
|
||||||
|
|
||||||
def compute_normalized_logprobs(all_logprobs, input_ids, seq_lens, start_loc):
    """Return the length-normalized sum of next-token log-probs per sequence.

    Row ``i`` of ``all_logprobs`` is indexed at column ``input_ids[i + 1]``
    (the last row uses column 0 as padding). For each sequence the selected
    values are summed over rows ``start_loc`` .. ``start_loc + seq_lens - 2``
    inclusive and divided by ``seq_lens - 1`` (clamped to at least 1).

    Args:
        all_logprobs: 2-D tensor, one row per packed token.
        input_ids: 1-D tensor of token ids, one per row of ``all_logprobs``.
        seq_lens: 1-D integer tensor with the length of each sequence.
        start_loc: 1-D integer tensor with the first row of each sequence.

    Returns:
        1-D tensor with one value per sequence. Entries whose ``seq_lens`` is
        below 2 are computed from clamped indices and are not meaningful.
    """
    # Build helper tensors on the input's device instead of a hard-coded
    # "cuda", so CPU tensors and non-default CUDA devices work as well.
    device = all_logprobs.device
    num_rows = all_logprobs.shape[0]

    next_ids = torch.cat([input_ids[1:], torch.tensor([0], device=device)])
    logprobs = all_logprobs[torch.arange(num_rows, device=device), next_ids]
    logprobs_cumsum = torch.cumsum(logprobs, dim=0, dtype=torch.float32)

    start = start_loc.clone()  # clone so the caller's tensor is not clamped in place
    end = start + seq_lens - 2
    start.clamp_(min=0, max=logprobs.shape[0] - 1)
    end.clamp_(min=0, max=logprobs.shape[0] - 1)

    # Inclusive range sum from a prefix sum: cumsum[end] - cumsum[start]
    # leaves out the element at `start`, so it is added back.
    sum_logp = logprobs_cumsum[end] - logprobs_cumsum[start] + logprobs[start]
    return sum_logp / ((seq_lens - 1).clamp(min=1))
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
    # Manual walkthrough of the prefix-sum computation (requires a CUDA
    # device). seq_lens contains zero-length entries to exercise clamping.
    all_logprobs = torch.tensor(
        [[0, 1, 2, 3], [1, 2, 3, 4], [2, 3, 4, 5], [3, 4, 5, 6], [4, 5, 6, 7]],
        dtype=torch.float32,
        device="cuda",
    )
    seq_lens = torch.tensor([2, 0, 3, 0], dtype=torch.int32, device="cuda")
    input_ids = torch.tensor([1, 2, 3, 0, 1], dtype=torch.int32, device="cuda")

    # The original pre-allocated `logprobs` with torch.zeros(5) and then
    # immediately overwrote it; that dead store is removed.
    logprobs = all_logprobs[
        torch.arange(all_logprobs.shape[0], device="cuda"),
        torch.cat([input_ids[1:], torch.tensor([0], device="cuda")]),
    ]
    logprobs_cumsum = torch.cumsum(logprobs, dim=0, dtype=torch.float32)

    len_cumsum = torch.cumsum(seq_lens, dim=0)
    start = torch.cat((torch.tensor([0], device="cuda"), len_cumsum[:-1]), 0)
    end = start + seq_lens - 2
    start.clamp_(min=0, max=logprobs.shape[0] - 1)
    end.clamp_(min=0, max=logprobs.shape[0] - 1)
    sum_logp = logprobs_cumsum[end] - logprobs_cumsum[start] + logprobs[start]

    # Expected logprobs: [2, _, 2, 4, _] where "_" marks don't-care entries.
    print("logprobs", logprobs)
    print("start", start)
    print("end", end)
    print("sum_logp", sum_logp)
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
|
import logging
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from enum import Enum, auto
|
from enum import Enum, auto
|
||||||
from typing import List
|
from typing import List
|
||||||
import logging
|
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
@@ -13,7 +13,6 @@ from vllm.model_executor.layers.quantization.awq import AWQConfig
|
|||||||
from vllm.model_executor.model_loader import _set_default_torch_dtype
|
from vllm.model_executor.model_loader import _set_default_torch_dtype
|
||||||
from vllm.model_executor.parallel_utils.parallel_state import initialize_model_parallel
|
from vllm.model_executor.parallel_utils.parallel_state import initialize_model_parallel
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger("model_runner")
|
logger = logging.getLogger("model_runner")
|
||||||
|
|
||||||
|
|
||||||
@@ -112,7 +111,7 @@ class InputMetadata:
|
|||||||
def init_extend_args(self):
    """Derive the per-request extend lengths and offsets.

    Reads ``self.seq_lens`` and ``self.prefix_lens`` and sets:
        extend_seq_lens: ``seq_lens - prefix_lens``.
        extend_start_loc: exclusive prefix sum of ``extend_seq_lens``
            (first entry 0), with the same dtype and shape as ``seq_lens``.
        max_extend_len: largest entry of ``extend_seq_lens`` as a Python int,
            or 0 when the batch is empty.
    """
    self.extend_seq_lens = self.seq_lens - self.prefix_lens
    self.extend_start_loc = torch.zeros_like(self.seq_lens)
    self.extend_start_loc[1:] = torch.cumsum(self.extend_seq_lens[:-1], dim=0)
    # torch.max raises on an empty tensor, so an empty batch is handled
    # explicitly.
    if self.extend_seq_lens.numel() > 0:
        self.max_extend_len = int(torch.max(self.extend_seq_lens))
    else:
        self.max_extend_len = 0
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@@ -262,7 +261,7 @@ class ModelRunner:
|
|||||||
if model_class is None:
|
if model_class is None:
|
||||||
raise ValueError(f"Unsupported architectures: {architectures}")
|
raise ValueError(f"Unsupported architectures: {architectures}")
|
||||||
|
|
||||||
logger.info("load weight begin.")
|
logger.info(f"Rank {self.tp_rank}: load weight begin.")
|
||||||
|
|
||||||
# Load weights
|
# Load weights
|
||||||
linear_method = None
|
linear_method = None
|
||||||
@@ -287,7 +286,7 @@ class ModelRunner:
|
|||||||
)
|
)
|
||||||
self.model = model.eval()
|
self.model = model.eval()
|
||||||
|
|
||||||
logger.info("load weight end.")
|
logger.info(f"Rank {self.tp_rank}: load weight end.")
|
||||||
|
|
||||||
def profile_max_num_token(self, total_gpu_memory):
|
def profile_max_num_token(self, total_gpu_memory):
|
||||||
available_gpu_memory = get_available_gpu_memory(
|
available_gpu_memory = get_available_gpu_memory(
|
||||||
@@ -308,8 +307,9 @@ class ModelRunner:
|
|||||||
self.max_total_num_token = self.profile_max_num_token(total_gpu_memory)
|
self.max_total_num_token = self.profile_max_num_token(total_gpu_memory)
|
||||||
|
|
||||||
if self.max_total_num_token <= 0:
|
if self.max_total_num_token <= 0:
|
||||||
raise RuntimeError("Not enought memory. "
|
raise RuntimeError(
|
||||||
"Please try to increase --mem-fraction-static.")
|
"Not enought memory. " "Please try to increase --mem-fraction-static."
|
||||||
|
)
|
||||||
|
|
||||||
self.req_to_token_pool = ReqToTokenPool(
|
self.req_to_token_pool = ReqToTokenPool(
|
||||||
int(self.max_total_num_token / self.model_config.context_len * 256),
|
int(self.max_total_num_token / self.model_config.context_len * 256),
|
||||||
|
|||||||
Reference in New Issue
Block a user