Support penalty in overlap mode; return logprob with chunked prefill; improve benchmark scripts (#3988)

Co-authored-by: SangBin Cho <rkooo567@gmail.com>
Co-authored-by: dhou-xai <dhou@x.ai>
Co-authored-by: Hanming Lu <hanming_lu@berkeley.edu>
This commit is contained in:
Lianmin Zheng
2025-03-03 00:12:04 -08:00
parent 0194948fd9
commit ac2387279e
86 changed files with 4116 additions and 2015 deletions

View File

@@ -1,6 +1,21 @@
import torch
from __future__ import annotations
from sglang.srt.distributed import GroupCoordinator, get_tp_group
import functools
from typing import TYPE_CHECKING, Union
import torch
import triton
import triton.language as tl
from sglang.srt.distributed import (
GroupCoordinator,
get_tensor_model_parallel_world_size,
get_tp_group,
tensor_model_parallel_all_reduce,
)
if TYPE_CHECKING:
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
_ATTN_TP_GROUP = None
_ATTN_TP_RANK = None
@@ -69,3 +84,129 @@ def get_attention_dp_rank():
def get_attention_dp_size():
assert _DP_SIZE is not None, "dp attention not initialized!"
return _DP_SIZE
def get_dp_local_info(forward_batch: ForwardBatch):
dp_rank = get_attention_dp_rank()
if forward_batch.dp_local_start_pos is None:
cumtokens = torch.cumsum(forward_batch.global_num_tokens_gpu, dim=0)
if dp_rank == 0:
local_start_pos = torch.zeros_like(cumtokens[0])
else:
local_start_pos = cumtokens[dp_rank - 1]
local_num_tokens = forward_batch.global_num_tokens_gpu[dp_rank]
forward_batch.dp_local_start_pos = local_start_pos
forward_batch.dp_local_num_tokens = local_num_tokens
return forward_batch.dp_local_start_pos, forward_batch.dp_local_num_tokens
@triton.jit
def memcpy_triton_kernel(
dst_ptr,
src_ptr,
offset_ptr,
sz_ptr,
offset_src,
chunk_size, # multiplied for offset and sz
BLOCK_SIZE: tl.constexpr,
):
pid = tl.program_id(axis=0).to(tl.int64)
offset = tl.load(offset_ptr).to(tl.int64) * chunk_size
sz = tl.load(sz_ptr).to(tl.int64) * chunk_size
start_index = pid * BLOCK_SIZE
offs = tl.arange(0, BLOCK_SIZE)
mask = start_index + offs < sz
if offset_src:
data = tl.load(src_ptr + offset + start_index + offs, mask=mask)
tl.store(dst_ptr + start_index + offs, data, mask=mask)
else:
data = tl.load(src_ptr + start_index + offs, mask=mask)
tl.store(dst_ptr + offset + start_index + offs, data, mask=mask)
def prod(x):
return functools.reduce(lambda a, b: a * b, x, 1)
def memcpy_triton(dst, src, dim, offset, sz, offset_src):
max_size = min(src.numel(), dst.numel())
assert dim == 0, "dim != 0 unsupported"
assert src.shape[1:] == dst.shape[1:], "src and dst must have same shape"
chunk_size = prod(src.shape[1:])
BLOCK_SIZE = 8192
grid = (triton.cdiv(max_size, BLOCK_SIZE),)
memcpy_triton_kernel[grid](dst, src, offset, sz, offset_src, chunk_size, BLOCK_SIZE)
def dp_gather(
global_tokens: torch.Tensor,
local_tokens: torch.Tensor,
forward_batch: ForwardBatch,
layer_id: Union[str, int],
):
local_start_pos, local_num_tokens = get_dp_local_info(forward_batch)
global_tokens.fill_(0)
assert local_tokens.is_contiguous()
assert global_tokens.is_contiguous()
if local_tokens.shape[0] > 0 and (
layer_id != "embedding" or get_attention_tp_rank() == 0
):
assert (
global_tokens.storage().data_ptr() != local_tokens.storage().data_ptr()
), "aliasing between global_tokens and local_tokens not allowed"
memcpy_triton(
global_tokens, local_tokens, 0, local_start_pos, local_num_tokens, False
)
# Input IDs are in int 32. We should use inplace_all_reduce for local case becaues of custom all reduce.
NUM_GPUS_PER_NODE = 8
if (
not local_tokens.dtype.is_floating_point
and get_tensor_model_parallel_world_size() <= NUM_GPUS_PER_NODE
):
torch.ops.sglang.inplace_all_reduce(
global_tokens, group_name=get_tp_group().unique_name
)
else:
global_tokens = tensor_model_parallel_all_reduce(global_tokens)
def dp_scatter(
local_tokens: torch.Tensor, # output
global_tokens: torch.Tensor, # input
forward_batch: ForwardBatch,
):
# local_num_tokens is not necessarily the same as local_tokens.shape[0],
# since local_tokens may be padded for cuda graph
local_start_pos, local_num_tokens = get_dp_local_info(forward_batch)
local_tokens.fill_(0)
assert local_tokens.is_contiguous()
assert global_tokens.is_contiguous()
if local_tokens.shape[0] > 0:
assert (
local_tokens.untyped_storage().data_ptr()
!= global_tokens.untyped_storage().data_ptr()
), "aliasing between local_tokens and global_tokens not allowed"
memcpy_triton(
local_tokens, global_tokens, 0, local_start_pos, local_num_tokens, True
)
def get_do_logits_dp_scatter(forward_batch: ForwardBatch):
def do_logits_dp_scatter(logits: torch.Tensor):
local_logits = torch.empty(
(forward_batch.input_ids.shape[0], *logits.shape[1:]),
dtype=logits.dtype,
device=logits.device,
)
dp_scatter(local_logits, logits, forward_batch)
return local_logits
return do_logits_dp_scatter