update
This commit is contained in:
61
vllm/v1/worker/gpu/cp_utils.py
Normal file
61
vllm/v1/worker/gpu/cp_utils.py
Normal file
@@ -0,0 +1,61 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import torch
|
||||
|
||||
from vllm.triton_utils import tl, triton
|
||||
|
||||
|
||||
def prepare_dcp_local_seq_lens(
    dcp_local_seq_lens: torch.Tensor,
    seq_lens: torch.Tensor,
    num_reqs: int,
    dcp_size: int,
    dcp_rank: int,
    cp_interleave: int,
) -> None:
    """Write this rank's local sequence lengths into ``dcp_local_seq_lens``.

    The persistent DCP local seq_lens buffer is filled in place (CUDA graph
    safe): nothing is allocated or returned. One kernel launch covers the
    whole first dimension of the buffer, so slots at or beyond ``num_reqs``
    are overwritten with 0.

    Args:
        dcp_local_seq_lens: Output buffer; its first dimension sets how many
            slots are written.
        seq_lens: Full sequence lengths; only the first ``num_reqs`` entries
            are read.
        num_reqs: Number of leading entries of ``seq_lens`` that are valid.
        dcp_size: Number of ranks the tokens are split across. When it is 1
            the function returns immediately and leaves the buffer untouched.
        dcp_rank: Index of the rank whose share is computed.
        cp_interleave: Number of consecutive tokens handed to one rank before
            moving on to the next rank.
    """
    # With a single rank there is nothing to split; the buffer is not written.
    if dcp_size == 1:
        return

    capacity = dcp_local_seq_lens.shape[0]
    lanes_per_program = 128
    # One program per chunk of `lanes_per_program` slots, rounded up so the
    # tail of the buffer is covered as well.
    grid = (triton.cdiv(capacity, lanes_per_program),)
    _dcp_local_seq_lens_kernel[grid](
        dcp_local_seq_lens,
        seq_lens,
        dcp_size,
        dcp_rank,
        cp_interleave,
        num_reqs,
        capacity,
        BLOCK_SIZE=lanes_per_program,
    )
|
||||
|
||||
|
||||
@triton.jit
def _dcp_local_seq_lens_kernel(
    out_ptr,
    seq_lens_ptr,
    dcp_size,
    dcp_rank,
    cp_interleave,
    num_reqs,
    max_num_reqs,
    BLOCK_SIZE: tl.constexpr,
):
    # Each program covers BLOCK_SIZE consecutive slots of the output buffer.
    offsets = tl.program_id(0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    is_live = offsets < num_reqs

    # Only the first num_reqs entries are read; the remaining lanes are
    # overwritten with 0 before the store below.
    total_lens = tl.load(seq_lens_ptr + offsets, mask=is_live)

    # The KV cache is dealt out to the ranks round-robin: each full round
    # gives cp_interleave tokens to every one of the dcp_size ranks.
    full_rounds = total_lens // (dcp_size * cp_interleave)
    leftover = total_lens % (dcp_size * cp_interleave)

    # Share of the final, partial round: the tokens left once the lower ranks
    # have taken theirs, clamped to [0, cp_interleave].
    # E.g. seq_len=5, dcp_size=2, cp_interleave=1 -> rank 0 gets 3, rank 1 gets 2.
    partial = tl.minimum(
        tl.maximum(leftover - dcp_rank * cp_interleave, 0),
        cp_interleave,
    )
    local_lens = full_rounds * cp_interleave + partial

    # Slots in [num_reqs, max_num_reqs) are padded with 0.
    local_lens = tl.where(is_live, local_lens, 0)
    tl.store(out_ptr + offsets, local_lens, mask=offsets < max_num_reqs)
|
||||
Reference in New Issue
Block a user