Support penalty in overlap mode; return logprob with chunked prefill; improve benchmark scripts (#3988)
Co-authored-by: SangBin Cho <rkooo567@gmail.com> Co-authored-by: dhou-xai <dhou@x.ai> Co-authored-by: Hanming Lu <hanming_lu@berkeley.edu>
This commit is contained in:
@@ -1,6 +1,21 @@
|
||||
import torch
|
||||
from __future__ import annotations
|
||||
|
||||
from sglang.srt.distributed import GroupCoordinator, get_tp_group
|
||||
import functools
|
||||
from typing import TYPE_CHECKING, Union
|
||||
|
||||
import torch
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
from sglang.srt.distributed import (
|
||||
GroupCoordinator,
|
||||
get_tensor_model_parallel_world_size,
|
||||
get_tp_group,
|
||||
tensor_model_parallel_all_reduce,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
||||
|
||||
_ATTN_TP_GROUP = None
|
||||
_ATTN_TP_RANK = None
|
||||
@@ -69,3 +84,129 @@ def get_attention_dp_rank():
|
||||
def get_attention_dp_size():
|
||||
assert _DP_SIZE is not None, "dp attention not initialized!"
|
||||
return _DP_SIZE
|
||||
|
||||
|
||||
def get_dp_local_info(forward_batch: ForwardBatch):
|
||||
dp_rank = get_attention_dp_rank()
|
||||
|
||||
if forward_batch.dp_local_start_pos is None:
|
||||
cumtokens = torch.cumsum(forward_batch.global_num_tokens_gpu, dim=0)
|
||||
if dp_rank == 0:
|
||||
local_start_pos = torch.zeros_like(cumtokens[0])
|
||||
else:
|
||||
local_start_pos = cumtokens[dp_rank - 1]
|
||||
local_num_tokens = forward_batch.global_num_tokens_gpu[dp_rank]
|
||||
|
||||
forward_batch.dp_local_start_pos = local_start_pos
|
||||
forward_batch.dp_local_num_tokens = local_num_tokens
|
||||
|
||||
return forward_batch.dp_local_start_pos, forward_batch.dp_local_num_tokens
|
||||
|
||||
|
||||
@triton.jit
|
||||
def memcpy_triton_kernel(
|
||||
dst_ptr,
|
||||
src_ptr,
|
||||
offset_ptr,
|
||||
sz_ptr,
|
||||
offset_src,
|
||||
chunk_size, # multiplied for offset and sz
|
||||
BLOCK_SIZE: tl.constexpr,
|
||||
):
|
||||
pid = tl.program_id(axis=0).to(tl.int64)
|
||||
offset = tl.load(offset_ptr).to(tl.int64) * chunk_size
|
||||
sz = tl.load(sz_ptr).to(tl.int64) * chunk_size
|
||||
|
||||
start_index = pid * BLOCK_SIZE
|
||||
offs = tl.arange(0, BLOCK_SIZE)
|
||||
mask = start_index + offs < sz
|
||||
|
||||
if offset_src:
|
||||
data = tl.load(src_ptr + offset + start_index + offs, mask=mask)
|
||||
tl.store(dst_ptr + start_index + offs, data, mask=mask)
|
||||
else:
|
||||
data = tl.load(src_ptr + start_index + offs, mask=mask)
|
||||
tl.store(dst_ptr + offset + start_index + offs, data, mask=mask)
|
||||
|
||||
|
||||
def prod(x):
|
||||
return functools.reduce(lambda a, b: a * b, x, 1)
|
||||
|
||||
|
||||
def memcpy_triton(dst, src, dim, offset, sz, offset_src):
|
||||
max_size = min(src.numel(), dst.numel())
|
||||
assert dim == 0, "dim != 0 unsupported"
|
||||
assert src.shape[1:] == dst.shape[1:], "src and dst must have same shape"
|
||||
chunk_size = prod(src.shape[1:])
|
||||
BLOCK_SIZE = 8192
|
||||
grid = (triton.cdiv(max_size, BLOCK_SIZE),)
|
||||
|
||||
memcpy_triton_kernel[grid](dst, src, offset, sz, offset_src, chunk_size, BLOCK_SIZE)
|
||||
|
||||
|
||||
def dp_gather(
|
||||
global_tokens: torch.Tensor,
|
||||
local_tokens: torch.Tensor,
|
||||
forward_batch: ForwardBatch,
|
||||
layer_id: Union[str, int],
|
||||
):
|
||||
local_start_pos, local_num_tokens = get_dp_local_info(forward_batch)
|
||||
|
||||
global_tokens.fill_(0)
|
||||
assert local_tokens.is_contiguous()
|
||||
assert global_tokens.is_contiguous()
|
||||
if local_tokens.shape[0] > 0 and (
|
||||
layer_id != "embedding" or get_attention_tp_rank() == 0
|
||||
):
|
||||
assert (
|
||||
global_tokens.storage().data_ptr() != local_tokens.storage().data_ptr()
|
||||
), "aliasing between global_tokens and local_tokens not allowed"
|
||||
memcpy_triton(
|
||||
global_tokens, local_tokens, 0, local_start_pos, local_num_tokens, False
|
||||
)
|
||||
|
||||
# Input IDs are in int 32. We should use inplace_all_reduce for local case becaues of custom all reduce.
|
||||
NUM_GPUS_PER_NODE = 8
|
||||
if (
|
||||
not local_tokens.dtype.is_floating_point
|
||||
and get_tensor_model_parallel_world_size() <= NUM_GPUS_PER_NODE
|
||||
):
|
||||
torch.ops.sglang.inplace_all_reduce(
|
||||
global_tokens, group_name=get_tp_group().unique_name
|
||||
)
|
||||
else:
|
||||
global_tokens = tensor_model_parallel_all_reduce(global_tokens)
|
||||
|
||||
|
||||
def dp_scatter(
|
||||
local_tokens: torch.Tensor, # output
|
||||
global_tokens: torch.Tensor, # input
|
||||
forward_batch: ForwardBatch,
|
||||
):
|
||||
# local_num_tokens is not necessarily the same as local_tokens.shape[0],
|
||||
# since local_tokens may be padded for cuda graph
|
||||
local_start_pos, local_num_tokens = get_dp_local_info(forward_batch)
|
||||
local_tokens.fill_(0)
|
||||
assert local_tokens.is_contiguous()
|
||||
assert global_tokens.is_contiguous()
|
||||
if local_tokens.shape[0] > 0:
|
||||
assert (
|
||||
local_tokens.untyped_storage().data_ptr()
|
||||
!= global_tokens.untyped_storage().data_ptr()
|
||||
), "aliasing between local_tokens and global_tokens not allowed"
|
||||
memcpy_triton(
|
||||
local_tokens, global_tokens, 0, local_start_pos, local_num_tokens, True
|
||||
)
|
||||
|
||||
|
||||
def get_do_logits_dp_scatter(forward_batch: ForwardBatch):
|
||||
def do_logits_dp_scatter(logits: torch.Tensor):
|
||||
local_logits = torch.empty(
|
||||
(forward_batch.input_ids.shape[0], *logits.shape[1:]),
|
||||
dtype=logits.dtype,
|
||||
device=logits.device,
|
||||
)
|
||||
dp_scatter(local_logits, logits, forward_batch)
|
||||
return local_logits
|
||||
|
||||
return do_logits_dp_scatter
|
||||
|
||||
Reference in New Issue
Block a user