62 lines
1.7 KiB
Python
62 lines
1.7 KiB
Python
# SPDX-License-Identifier: Apache-2.0
|
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
|
import torch
|
|
|
|
from vllm.triton_utils import tl, triton
|
|
|
|
|
|
def prepare_dcp_local_seq_lens(
    dcp_local_seq_lens: torch.Tensor,
    seq_lens: torch.Tensor,
    num_reqs: int,
    dcp_size: int,
    dcp_rank: int,
    cp_interleave: int,
) -> None:
    """Fill the persistent DCP local seq_lens buffer in place (CUDA graph safe).

    Launches ``_dcp_local_seq_lens_kernel`` to write, for each of the first
    ``num_reqs`` requests, the share of that request's sequence length that
    belongs to rank ``dcp_rank``. Entries of ``dcp_local_seq_lens`` at index
    ``num_reqs`` and beyond are written as 0.

    The launch grid is sized from the first dimension of
    ``dcp_local_seq_lens`` rather than from ``num_reqs``, so the same number
    of programs is launched regardless of how many requests are active.

    Args:
        dcp_local_seq_lens: Output buffer, overwritten along its first
            dimension.
        seq_lens: Per-request sequence lengths; only the first ``num_reqs``
            entries are read.
        num_reqs: Number of leading entries of ``seq_lens`` to process.
        dcp_size: Number of ranks the sequence is split across. When it is 1
            the function returns immediately and the buffer is left untouched.
        dcp_rank: Index of the rank whose local lengths are computed.
        cp_interleave: Number of consecutive tokens handed to one rank before
            moving on to the next rank.
    """
    # With a single rank there is nothing to split.
    if dcp_size == 1:
        return

    capacity = dcp_local_seq_lens.shape[0]
    block_size = 128
    # One program per chunk of `block_size` buffer slots.
    grid = (triton.cdiv(capacity, block_size),)
    _dcp_local_seq_lens_kernel[grid](
        dcp_local_seq_lens,
        seq_lens,
        dcp_size,
        dcp_rank,
        cp_interleave,
        num_reqs,
        capacity,
        BLOCK_SIZE=block_size,
    )
|
|
|
|
|
|
@triton.jit
def _dcp_local_seq_lens_kernel(
    out_ptr,
    seq_lens_ptr,
    dcp_size,
    dcp_rank,
    cp_interleave,
    num_reqs,
    max_num_reqs,
    BLOCK_SIZE: tl.constexpr,
):
    """Compute per-request local sequence lengths for one rank.

    Each program handles ``BLOCK_SIZE`` consecutive indices. For an index
    below ``num_reqs`` the stored value is the number of tokens of that
    request assigned to ``dcp_rank`` when tokens are dealt out to
    ``dcp_size`` ranks round-robin in chunks of ``cp_interleave``. Indices in
    ``[num_reqs, max_num_reqs)`` are stored as 0; indices at or beyond
    ``max_num_reqs`` are not written.
    """
    offsets = tl.program_id(0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    is_active = offsets < num_reqs

    # Lanes outside [0, num_reqs) are masked here and zeroed before the store.
    total_lens = tl.load(seq_lens_ptr + offsets, mask=is_active)

    # One full round hands `cp_interleave` tokens to every rank in turn.
    tokens_per_round = dcp_size * cp_interleave
    full_rounds = total_lens // tokens_per_round
    leftover = total_lens % tokens_per_round

    # Of the partial last round, this rank owns the tokens that fall inside
    # its own chunk: skip the chunks of lower ranks, then clamp the result to
    # the range [0, cp_interleave].
    own_leftover = tl.minimum(
        tl.maximum(leftover - dcp_rank * cp_interleave, 0),
        cp_interleave,
    )

    # Pad [num_reqs, max_num_reqs) with 0.
    local_lens = tl.where(
        is_active, full_rounds * cp_interleave + own_leftover, 0
    )
    tl.store(out_ptr + offsets, local_lens, mask=offsets < max_num_reqs)
|