### What this PR does / why we need it?
This PR implements Context Parallelism (CP) support for the Qwen3-Next
model, including PCP (Prefill Context Parallelism) and DCP
(Decode Context Parallelism).
- vLLM version: v0.15.0
- vLLM main: f176443446
---------
Signed-off-by: SunnyLee219 <3294305115@qq.com>
Signed-off-by: Jingchun Gao <gaojingchun1@huawei.com>
Signed-off-by: 白永斌 <baiyongbin3@h-partners.com>
Signed-off-by: Bai Yongbin <845473182@qq.com>
Co-authored-by: SunnyLee219 <3294305115@qq.com>
Co-authored-by: Jingchun Gao <gaojingchun1@huawei.com>
Co-authored-by: 白永斌 <baiyongbin3@h-partners.com>
Co-authored-by: Mengqing Cao <cmq0113@163.com>
75 lines
2.7 KiB
Python
75 lines
2.7 KiB
Python
# SPDX-License-Identifier: Apache-2.0
|
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
|
# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang
|
|
#
|
|
# This file contains code copied from the flash-linear-attention project.
|
|
# The original source code was licensed under the MIT license and included
|
|
# the following copyright notice:
|
|
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
|
|
# ruff: noqa: E501
|
|
import contextlib
|
|
import functools
|
|
from collections.abc import Callable
|
|
|
|
import torch
|
|
from vllm.triton_utils import tl, triton
|
|
|
|
|
|
def prepare_lens(cu_seqlens: torch.LongTensor) -> torch.LongTensor:
    """Return the length of every sequence described by cumulative boundaries.

    Each output element is the difference between two consecutive entries of
    ``cu_seqlens``, so the result has one element fewer than the input along
    the first dimension.
    """
    seq_starts = cu_seqlens[:-1]
    seq_ends = cu_seqlens[1:]
    return seq_ends - seq_starts
|
|
|
|
|
|
def prepare_chunk_indices(cu_seqlens: torch.LongTensor, chunk_size: int) -> torch.LongTensor:
    """Enumerate every chunk of every sequence as a ``(sequence, chunk)`` row.

    Args:
        cu_seqlens: 1-D cumulative sequence boundaries; sequence ``i`` covers
            ``cu_seqlens[i]`` up to ``cu_seqlens[i + 1]``.
        chunk_size: Number of tokens per chunk.

    Returns:
        A contiguous tensor of shape ``(total_chunks, 2)`` with the dtype and
        device of ``cu_seqlens``. Column 0 is the index of the owning sequence,
        column 1 the chunk's position inside that sequence. Zero-length
        sequences contribute no rows but still consume a sequence index, and a
        batch without any chunk yields an empty ``(0, 2)`` tensor.
    """
    # One host transfer; everything below is plain integer arithmetic.
    bounds = cu_seqlens.tolist()

    rows = []
    for seq_idx, (start, end) in enumerate(zip(bounds, bounds[1:])):
        # Ceil division: a partially filled trailing chunk still counts.
        num_chunks = (end - start + chunk_size - 1) // chunk_size
        # The sequence index comes from the enumeration rather than from
        # counting chunk-0 rows, so an empty sequence cannot shift the index
        # of the sequences that follow it.
        rows.extend((seq_idx, chunk_idx) for chunk_idx in range(num_chunks))

    # Explicit shape keeps the result 2-D even when there are no rows at all.
    return cu_seqlens.new_tensor(rows).reshape(len(rows), 2)
|
|
|
|
|
|
def prepare_final_chunk_indices(cu_seqlens: torch.LongTensor, chunk_size: int) -> torch.LongTensor:
    """Return, per sequence, the index of its last slot in a flattened per-chunk layout.

    Every sequence is allotted its chunk count plus one extra slot; the value
    returned for a sequence is the inclusive running total of those slot
    counts, minus one.
    """
    seq_lens = prepare_lens(cu_seqlens)
    slots_per_seq = triton.cdiv(seq_lens, chunk_size) + 1
    running_total = slots_per_seq.cumsum(0)
    return running_total - 1
|
|
|
|
|
|
def prepare_chunk_offsets(cu_seqlens: torch.LongTensor, chunk_size: int) -> torch.LongTensor:
    """Return the cumulative chunk count in front of each sequence.

    The result starts with 0 and has one entry more than there are sequences;
    entry ``i`` is the number of chunks owned by sequences ``0 .. i - 1``.
    """
    leading_zero = cu_seqlens.new_tensor([0])
    chunks_per_seq = triton.cdiv(prepare_lens(cu_seqlens), chunk_size)
    padded_counts = torch.cat([leading_zero, chunks_per_seq])
    return torch.cumsum(padded_counts, dim=-1)
|
|
|
|
|
|
def prepare_update_chunk_offsets(cu_seqlens: torch.LongTensor, chunk_size: int) -> torch.LongTensor:
    """Return cumulative slot offsets where each sequence owns its chunk count plus one.

    The result starts with 0 and has one entry more than there are sequences;
    entry ``i`` is the total number of slots owned by sequences ``0 .. i - 1``.
    """
    leading_zero = cu_seqlens.new_tensor([0])
    slots_per_seq = triton.cdiv(prepare_lens(cu_seqlens), chunk_size) + 1
    padded_counts = torch.cat([leading_zero, slots_per_seq])
    return torch.cumsum(padded_counts, dim=-1)
|
|
|
|
|
|
def input_guard(fn: Callable[..., torch.Tensor]) -> Callable[..., torch.Tensor]:
    """
    Decorator that hands ``fn`` contiguous tensors and runs it under the device of its first tensor input.

    Non-tensor arguments are forwarded untouched. The device is taken from the
    first tensor among the positional arguments, falling back to the first
    tensor among the keyword arguments; when no tensor is passed at all, ``fn``
    runs without any device context.
    """

    def _contiguous(obj):
        # Only tensors have a memory layout to normalise.
        return obj.contiguous() if isinstance(obj, torch.Tensor) else obj

    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        # ``map`` is lazy: positional tensors are only made contiguous when the
        # iterator is unpacked in the call to ``fn`` below.
        packed_args = map(_contiguous, args)
        packed_kwargs = {name: _contiguous(val) for name, val in kwargs.items()}

        # Positional arguments take precedence over keyword arguments.
        anchor = next(
            (item for item in (*args, *kwargs.values()) if isinstance(item, torch.Tensor)),
            None,
        )
        if anchor is None:
            device_ctx = contextlib.nullcontext()
        else:
            device_ctx = torch.npu.device(anchor.device.index)

        with device_ctx:
            return fn(*packed_args, **packed_kwargs)

    return wrapper
|
|
|
|
|
|
@triton.jit
def safe_exp(x):
    # Exponential restricted to non-positive inputs: every positive entry is
    # replaced by -inf before exponentiation, so it contributes exp(-inf) = 0
    # instead of a potentially huge value.
    non_positive = tl.where(x <= 0, x, float("-inf"))
    return tl.exp(non_positive)
|