rollback causal_conv1d_fn to torch ops & update qwen3Next doc (#5391)

### What this PR does / why we need it?
Roll back the causal_conv1d_fn op from the Triton version to the torch version to fix
hanging issues; meanwhile, update the Qwen3-Next doc.

- vLLM version: release/v0.13.0
- vLLM main: 254f6b9867
---------
Signed-off-by: SunnyLee219 <3294305115@qq.com>
Author: LeeWenquan, 2025-12-26 19:57:38 +08:00, committed by GitHub
parent 48854aef5c
commit 7685d0c239
2 changed files with 109 additions and 405 deletions


@@ -92,10 +92,8 @@ source /usr/local/Ascend/ascend-toolkit/8.3.RC2/bisheng_toolkit/set_env.sh
Run the following script to start the vLLM server on multiple NPUs.
For an Atlas A2 with 64 GB of memory per NPU card, `tensor-parallel-size` should be at least 4; with 32 GB per card, it should be at least 8.
```bash
vllm serve Qwen/Qwen3-Next-80B-A3B-Instruct --tensor-parallel-size 4 --max-model-len 4096 --gpu-memory-utilization 0.7 --compilation-config '{"cudagraph_mode":"FULL_DECODE_ONLY"}'
vllm serve Qwen/Qwen3-Next-80B-A3B-Instruct --tensor-parallel-size 4 --max-model-len 32768 --gpu-memory-utilization 0.8 --max-num-batched-tokens 4096 --compilation-config '{"cudagraph_mode":"FULL_DECODE_ONLY"}'
```
Once your server is started, you can query the model with input prompts.
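For example, a minimal client-side sketch (assuming the server listens on the default `127.0.0.1:8000` and the `requests` package is installed) could look like:
```python
# Query the OpenAI-compatible chat endpoint exposed by `vllm serve`.
# Host and port are assumptions; adjust them to your deployment.
import requests

resp = requests.post(
    "http://127.0.0.1:8000/v1/chat/completions",
    json={
        "model": "Qwen/Qwen3-Next-80B-A3B-Instruct",
        "messages": [{"role": "user", "content": "Who are you?"}],
        "max_tokens": 64,
    },
    timeout=60,
)
print(resp.json()["choices"][0]["message"]["content"])
```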
@@ -170,11 +168,11 @@ Prompt: 'Who are you?', Generated text: ' What do you know about me?\n\nHello! I
1. Refer to [Using AISBench](../developer_guide/evaluation/using_ais_bench.md) for details.
2. After execution, you can get the result; here is the result of `Qwen3-Next-80B-A3B-Instruct` in `vllm-ascend:0.11.0rc3`, for reference only.
2. After execution, you can get the result; here is the result of `Qwen3-Next-80B-A3B-Instruct` in `vllm-ascend:0.13.0rc1`, for reference only.
| dataset | version | metric | mode | vllm-api-general-chat |
|----- | ----- | ----- | ----- | -----|
| gsm8k | - | accuracy | gen | 96.3 |
| gsm8k | - | accuracy | gen | 95.53 |
## Performance
@@ -201,3 +199,15 @@ vllm bench serve --model Qwen/Qwen3-Next-80B-A3B-Instruct --dataset-name random
```
After a few minutes, you can get the performance evaluation result.
The performance result is:
**Hardware**: A3-752T, 2 nodes
**Deployment**: TP4 + Full Decode Only
**Input/Output**: 2k/2k
**Concurrency**: 32
**Performance**: 580 tps, TPOT 54 ms
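As a rough, illustrative sanity check (not part of the measured report), steady-state decode throughput is approximately concurrency divided by TPOT, which lines up with the figures above:
```python
# Illustrative back-of-the-envelope check of the reported numbers.
concurrency = 32   # concurrent requests
tpot_s = 0.054     # time per output token: 54 ms

approx_tps = concurrency / tpot_s
print(f"~{approx_tps:.0f} tokens/s")  # ~593 tok/s, close to the measured 580 tps
```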


@@ -7,292 +7,82 @@
# and https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/mamba/ops/causal_conv1d.py
# mypy: ignore-errors
from typing import Any, Optional, Union
from typing import Any, Optional
import torch
import torch.nn.functional as F
import triton
import triton.language as tl
PAD_SLOT_ID = -1
@triton.jit()
def _causal_conv1d_fwd_kernel( # continuous batching
# Pointers to matrices
x_ptr, # (dim, cu_seqlen) holding `batch` of actual sequences + padded sequences
w_ptr, # (dim, width)
bias_ptr,
conv_states_ptr,
conv_state_indices_ptr,
has_initial_states_ptr,
query_start_loc_ptr,
batch_ptr,
token_chunk_offset_ptr,
o_ptr, # (dim, seqlen)
# Matrix dimensions
dim: tl.constexpr,
state_len: int,
num_cache_lines: tl.constexpr, # added to support vLLM larger cache lines
# Strides
stride_x_dim: tl.constexpr, # stride to get to next feature-value,
stride_x_token: tl.constexpr, # stride to get to next token
stride_w_dim: tl.constexpr, # stride to get to next dim-axis value
stride_w_width: tl.constexpr, # stride to get to next width-axis value
stride_conv_state_seq: tl.constexpr,
stride_conv_state_dim: tl.constexpr,
stride_conv_state_tok: tl.constexpr,
stride_cache_indices: tl.constexpr,
stride_o_dim: tl.constexpr,
stride_o_token: tl.constexpr,
# others
pad_slot_id: tl.constexpr,
# Meta-parameters
HAS_BIAS: tl.constexpr,
KERNEL_WIDTH: tl.constexpr,
SILU_ACTIVATION: tl.constexpr,
HAS_INITIAL_STATES: tl.constexpr,
IS_CONTINUOUS_BATCHING: tl.constexpr,
USE_PAD_SLOT: tl.constexpr,
NP2_STATELEN: tl.constexpr,
BLOCK_M: tl.constexpr,
BLOCK_N: tl.constexpr,
def causal_conv1d_ref(
x: torch.Tensor,
weight: torch.Tensor,
bias: Optional[torch.Tensor] = None,
initial_states: Optional[torch.Tensor] = None,
return_final_states: bool = False,
final_states_out: Optional[torch.Tensor] = None,
activation: Optional[str] = "silu",
):
# single-sequence id
idx_seq = tl.load(batch_ptr + tl.program_id(0)).to(tl.int64)
chunk_offset = tl.load(token_chunk_offset_ptr + tl.program_id(0))
"""
x: (batch, dim, seqlen)
weight: (dim, width)
bias: (dim,)
initial_states: (batch, dim, width - 1)
final_states_out: (batch, dim, width - 1)
out: (batch, dim, seqlen)
"""
if activation not in [None, "silu", "swish"]:
raise NotImplementedError("activation must be None, silu, or swish")
dtype_in = x.dtype
x = x.to(weight.dtype)
seqlen = x.shape[-1]
dim, width = weight.shape
# BLOCK_N elements along the feature-dimension (channel)
idx_feats = tl.program_id(1) * BLOCK_N + tl.arange(0, BLOCK_N)
sequence_start_index = tl.load(query_start_loc_ptr + idx_seq)
sequence_end_index = tl.load(query_start_loc_ptr + idx_seq + 1)
# find the actual sequence length
seqlen = sequence_end_index - sequence_start_index
token_offset = BLOCK_M * chunk_offset
segment_len = min(BLOCK_M, seqlen - token_offset)
# base of the sequence
x_base = x_ptr + sequence_start_index * stride_x_token + idx_feats * stride_x_dim # [BLOCK_N,]
if IS_CONTINUOUS_BATCHING:
# cache_idx
conv_state_batch_coord = tl.load(conv_state_indices_ptr +
idx_seq * stride_cache_indices).to(
tl.int64)
if initial_states is None:
out = F.conv1d(x,
weight.unsqueeze(1),
bias,
padding=width - 1,
groups=dim)
else:
# cache_idx
conv_state_batch_coord = idx_seq
x = torch.cat([initial_states, x], dim=-1)
out = F.conv1d(x, weight.unsqueeze(1), bias, padding=0, groups=dim)
out = out[..., :seqlen]
if USE_PAD_SLOT: # noqa
if conv_state_batch_coord == pad_slot_id:
# not processing as this is not the actual sequence
return
conv_states_base = conv_states_ptr + (
conv_state_batch_coord * stride_conv_state_seq) + (
idx_feats * stride_conv_state_dim) # [BLOCK_N,]
w_base = w_ptr + (idx_feats * stride_w_dim) # [BLOCK_N,]
load_init_state = False
if HAS_INITIAL_STATES: # the new HAS_INITIAL_STATES
load_init_state = tl.load(has_initial_states_ptr + idx_seq)
mask_dim = idx_feats < dim
# read prior-token data from `x`
offset_x = token_offset - KERNEL_WIDTH + 1
if KERNEL_WIDTH >= 2:
x0_ptrs = x_base + offset_x * stride_x_token
x0 = tl.load(x0_ptrs, mask_dim & (offset_x > 0))
if KERNEL_WIDTH >= 3:
x1_ptrs = x0_ptrs + 1 * stride_x_token
x1 = tl.load(x1_ptrs, mask_dim & (offset_x + 1 > 0))
if KERNEL_WIDTH >= 4:
x2_ptrs = x1_ptrs + 1 * stride_x_token
x2 = tl.load(x2_ptrs, mask_dim & (offset_x + 2 > 0))
if load_init_state & (chunk_offset == 0):
# load from conv_states
offset_conv_state = state_len - KERNEL_WIDTH + 1
if KERNEL_WIDTH >= 2:
x0_ptrs = conv_states_base + offset_conv_state * stride_conv_state_tok
x0 = tl.load(x0_ptrs, mask_dim, 0.0)
if KERNEL_WIDTH >= 3:
x1_ptrs = x0_ptrs + 1 * stride_conv_state_tok
x1 = tl.load(x1_ptrs, mask_dim)
if KERNEL_WIDTH >= 4:
x2_ptrs = x1_ptrs + 1 * stride_conv_state_tok
x2 = tl.load(x2_ptrs, mask_dim)
if HAS_BIAS:
bias = bias_ptr + idx_feats
mask_bias = idx_feats < dim
acc_preload = tl.load(bias, mask=mask_bias,
other=0.0).to(tl.float32) # [BLOCK_N]
else:
acc_preload = tl.zeros((BLOCK_N, ), dtype=tl.float32)
x_base_1d = x_base + token_offset * stride_x_token # starting of chunk
# PRE-LOAD WEIGHTS
mask_dim = idx_feats < dim
if KERNEL_WIDTH >= 2:
w_ptrs = w_base + (0 * stride_w_width) # [BLOCK_N] tensor
w0 = tl.load(w_ptrs, mask_dim, other=0.0)
w_ptrs = w_base + (1 * stride_w_width) # [BLOCK_N] tensor
w1 = tl.load(w_ptrs, mask_dim, other=0.0)
if KERNEL_WIDTH >= 3:
w_ptrs = w_base + (2 * stride_w_width) # [BLOCK_N] tensor
w2 = tl.load(w_ptrs, mask_dim, other=0.0)
if KERNEL_WIDTH >= 4:
w_ptrs = w_base + (3 * stride_w_width) # [BLOCK_N] tensor
w3 = tl.load(w_ptrs, mask_dim, other=0.0)
for idx_token in tl.static_range(BLOCK_M):
acc = acc_preload
mask_1d = (idx_token
< segment_len) & mask_dim # token-index # feature-index
x_ptrs_1d = x_base_1d + idx_token * stride_x_token # [BLOCK_N]
x = tl.load(x_ptrs_1d, mask=mask_1d)
if KERNEL_WIDTH == 2:
acc += x0 * w0 + x * w1
x0 = x
elif KERNEL_WIDTH == 3:
acc += x0 * w0 + x1 * w1 + x * w2
x0 = x1
x1 = x
elif KERNEL_WIDTH == 4:
acc += x0 * w0 + x1 * w1 + x2 * w2 + x * w3
x0 = x1
x1 = x2
x2 = x
if SILU_ACTIVATION:
acc = acc / (1 + tl.exp(-acc))
o_ptrs = o_ptr + (sequence_start_index + token_offset + idx_token
) * stride_o_token + (idx_feats * stride_o_dim)
tl.store(o_ptrs, acc, mask=mask_1d)
# update conv_state with new data [only by the Triton program handles chunk_offset=0]
if chunk_offset == 0:
if state_len <= seqlen: # SMALL_CACHE=True (only move part of 'x' into conv_state cache)
# just read from 'x'
# copy 'x' data to conv_state
# load only 'x' data (and set 0 before 'x' if seqlen < state_len)
idx_tokens_last = (seqlen - state_len) + tl.arange(
0, NP2_STATELEN) # [BLOCK_M]
x_ptrs = x_ptr + (
(sequence_start_index + idx_tokens_last) *
stride_x_token)[:, None] + (
idx_feats * stride_x_dim)[None, :] # [BLOCK_M,BLOCK_N,]
mask_x = ((idx_tokens_last >= 0)[:, None] &
(idx_tokens_last < seqlen)[:, None] &
(idx_feats < dim)[None, :]
) # token-index # token-index # feature-index
new_conv_state = tl.load(x_ptrs, mask_x, 0.0)
idx_tokens_conv = tl.arange(0, NP2_STATELEN) # [BLOCK_M]
conv_states_ptrs_target = conv_states_base[None, :] + (
idx_tokens_conv * stride_conv_state_tok)[:, None]
mask = (idx_tokens_conv < state_len)[:, None] & (idx_feats
< dim)[None, :]
tl.debug_barrier()
tl.store(conv_states_ptrs_target, new_conv_state, mask)
elif load_init_state:
# update conv_state by shifting left, i.e. take last few cols from conv_state + cols from 'x'
idx_tokens_conv = tl.arange(0, NP2_STATELEN) # [BLOCK_M]
conv_states_ptrs_source = (
conv_states_ptr +
(conv_state_batch_coord * stride_conv_state_seq) +
(idx_feats * stride_conv_state_dim)[None, :] +
((idx_tokens_conv + seqlen) * stride_conv_state_tok)[:, None]
) # [BLOCK_M, BLOCK_N]
mask = ((conv_state_batch_coord < num_cache_lines)
& ((idx_tokens_conv + seqlen) < state_len)[:, None]
& (idx_feats < dim)[None, :])
conv_state = tl.load(conv_states_ptrs_source, mask, other=0.0)
VAL = state_len - seqlen
x_ptrs = x_base[None, :] + (
(idx_tokens_conv - VAL) *
stride_x_token)[:, None] # [BLOCK_M, BLOCK_N]
mask_x = ((idx_tokens_conv - VAL >= 0)[:, None] &
(idx_tokens_conv - VAL < seqlen)[:, None] &
(idx_feats < dim)[None, :]
) # token-index # token-index # feature-index
loaded_x = tl.load(x_ptrs, mask_x, 0.0)
tl.debug_barrier()
new_conv_state = tl.where(
mask, conv_state, loaded_x
) # BUG in 'tl.where' which requires a barrier before this
conv_states_ptrs_target = conv_states_base + (
idx_tokens_conv *
stride_conv_state_tok)[:, None] # [BLOCK_M, BLOCK_N]
mask = (idx_tokens_conv < state_len)[:, None] & (idx_feats
< dim)[None, :]
tl.store(conv_states_ptrs_target, new_conv_state, mask)
if return_final_states:
final_states = F.pad(x, (width - 1 - x.shape[-1], 0)).to(
dtype_in) # (batch, dim, width - 1)
if final_states_out is not None:
final_states_out.copy_(final_states)
else:
# update conv_state by shifting left, BUT
# set cols prior to 'x' as zeros + cols from 'x'
idx_tokens_conv = tl.arange(0, NP2_STATELEN) # [BLOCK_M]
VAL = state_len - seqlen
x_ptrs = x_base[None, :] + (
(idx_tokens_conv - VAL) *
stride_x_token)[:, None] # [BLOCK_M, BLOCK_N]
mask_x = ((idx_tokens_conv - VAL >= 0)[:, None] &
(idx_tokens_conv - VAL < seqlen)[:, None] &
(idx_feats < dim)[None, :]
) # token-index # token-index # feature-index
new_conv_state = tl.load(x_ptrs, mask_x, 0.0)
conv_states_ptrs_target = conv_states_base + (
idx_tokens_conv *
stride_conv_state_tok)[:, None] # [BLOCK_M, BLOCK_N]
mask = (idx_tokens_conv < state_len)[:, None] & (idx_feats
< dim)[None, :]
tl.debug_barrier()
tl.store(conv_states_ptrs_target, new_conv_state, mask)
final_states_out = final_states
out = (out if activation is None else F.silu(out)).to(dtype=dtype_in)
return (out, None) if not return_final_states else (out, final_states_out)
def causal_conv1d_fn(x: torch.Tensor,
weight: torch.Tensor,
bias: Union[torch.Tensor, None],
conv_states: torch.Tensor,
query_start_loc: torch.Tensor,
cache_indices: Optional[torch.Tensor] = None,
has_initial_state: Optional[torch.Tensor] = None,
activation: Optional[str] = "silu",
pad_slot_id: int = PAD_SLOT_ID,
metadata: Optional[Any] = None,
validate_data=False):
"""support varlen + continuous batching when x is 2D tensor
x: (dim,cu_seq_len)
cu_seq_len = total tokens of all seqs in that batch
def causal_conv1d_fn(
x: torch.Tensor,
weight: torch.Tensor,
bias: Optional[torch.Tensor] = None,
activation: Optional[str] = "silu",
conv_states: Optional[torch.Tensor] = None,
has_initial_state: Optional[torch.Tensor] = None,
cache_indices: Optional[torch.Tensor] = None,
query_start_loc: Optional[torch.Tensor] = None,
metadata: Optional[Any] = None,
pad_slot_id: int = PAD_SLOT_ID,
):
"""
x: (batch, dim, seqlen) or (dim,cu_seq_len) for varlen
sequences are concatenated from left to right for varlen
weight: (dim, width)
conv_states: (...,dim,width - 1) itype
updated inplace if provided
[it uses `cache_indices` to get the index into the conv_state cache for that sequence:
conv_state[cache_indices[i]] for seq-i is used as the initial_state when has_initial_state[i] = True,
and after that conv_state[cache_indices[i]] needs to be shifted left and updated with values from 'x'
]
bias: (dim,)
query_start_loc: (batch + 1) int32
The cumulative sequence lengths of the sequences in
the batch, used to index into the sequences, prepended by 0.
if
x = [5, 1, 1, 1] <- continuous batching (batch=4)
then
query_start_loc = [0, 5, 6, 7, 8] <- the starting index of each sequence; the last value is
the ending index of the last sequence
[length(query_start_loc) - 1 == batch]
for example: query_start_loc = torch.Tensor([0,10,16,17]),
x.shape=(dim,17)
cache_indices: (batch) int32
@@ -301,144 +91,48 @@ def causal_conv1d_fn(x: torch.Tensor,
has_initial_state: (batch) bool
indicates whether the kernel should take the current state as the initial
state for the calculations
[single boolean for each sequence in the batch: True or False]
bias: (dim,)
activation: either None or "silu" or "swish" or True
conv_states: (...,dim,width - 1) itype
updated inplace if provided
activation: either None or "silu" or "swish"
pad_slot_id: int
if cache_indices is passed, lets the kernel identify padded
entries that will not be processed,
for example: cache_indices = [pad_slot_id, 1, 20, pad_slot_id]
in this case, the kernel will not process entries at
indices 0 and 3
out: same shape as `x`
if cache_indices is passed, lets the kernel identify padded
entries that will not be processed,
for example: cache_indices = [pad_slot_id, 1, 20, pad_slot_id]
in this case, the kernel will not process entries at
indices 0 and 3
out: (batch, dim, seqlen)
"""
if isinstance(activation, bool) and activation:
activation = "silu"
if activation not in [None, "silu", "swish"]:
raise NotImplementedError("activation must be None, silu, or swish")
if x.stride(-1) != 1:
x = x.contiguous()
bias = bias.contiguous() if bias is not None else None
# Store original dtype to cast back at the end
out = torch.empty_strided(x.size(),
x.stride(),
dtype=x.dtype,
device=x.device)
out_ref = []
out_ref_b = []
seqlens = query_start_loc[1:] - query_start_loc[:-1]
seqlens = seqlens.tolist()
splits = torch.split(x, seqlens, dim=-1)
width = weight.shape[1]
dim, _ = x.shape
_, width = weight.shape
state_len = width - 1
np2_statelen = triton.next_power_of_2(state_len)
padded_batch = query_start_loc.size(0) - 1
stride_x_dim = x.stride(0)
stride_x_token = x.stride(1)
stride_w_dim = weight.stride(0)
stride_w_width = weight.stride(1)
stride_istate_seq = 0
stride_istate_dim = 0
stride_istate_token = 0
stride_o_dim = out.stride(0)
stride_o_token = out.stride(1)
num_cache_lines = 0
if conv_states is not None:
# extensions to support vLLM:
# 1. conv_states is used to replace initial_states
# 2. conv_states serves as a cache; the number of cache lines can be larger than the batch size
# 3. mapping from sequence x[idx] to a cache line at index as specified via cache_indices[idx]
# 4. computation can be skipped if cache_indices[idx] == pad_slot_id
num_cache_lines = conv_states.size(0)
stride_istate_seq = conv_states.stride(0)
stride_istate_dim = conv_states.stride(1)
stride_istate_token = conv_states.stride(2)
stride_cache_indices = cache_indices.stride(
0) if cache_indices is not None else 0
if validate_data:
is_channel_last = (x.stride(0) == 1) & (x.stride(1) > 1)
assert x.dim() == 2
assert width in [2, 3, 4]
assert query_start_loc is not None
assert query_start_loc.dim() == 1
assert x.stride(0) == 1 or x.stride(1) == 1
if bias is not None:
assert bias.dim() == 1
assert dim == bias.size(0)
if conv_states is not None:
assert (num_cache_lines == conv_states.shape[0]
and dim == conv_states.shape[1]
and conv_states.shape[2] >= width - 1)
assert stride_istate_dim == 1
if cache_indices is not None:
assert cache_indices.dim() == 1
assert padded_batch == cache_indices.size(0)
if has_initial_state is not None:
assert has_initial_state.size() == (padded_batch, )
assert conv_states is not None, "ERROR: `has_initial_state` is used, which needs also `conv_states`"
assert weight.stride(1) == 1
assert (dim, width) == weight.shape
assert is_channel_last, "Need to run in channel-last layout"
BLOCK_M = 64
seqlens = query_start_loc.diff()
seq_blocks = -(-seqlens // BLOCK_M)
total_seq_blocks = seq_blocks.sum().item()
# tracking which seq-idx the Triton program is handling
batch_ptr = torch.repeat_interleave(
torch.arange(len(seq_blocks), device=x.device),
seq_blocks).to(torch.int32)
# tracking BLOCK_M-based index in the sequence the Triton program is handling
max_blocks = seq_blocks.max().item() if len(seq_blocks) > 0 else 0
arange = torch.arange(max_blocks, device=x.device)
mask = arange.unsqueeze(0) < seq_blocks.unsqueeze(1)
token_chunk_offset_ptr = arange.repeat(len(seq_blocks),
1)[mask].to(torch.int32)
BLOCK_N = 256
grid = (total_seq_blocks, triton.cdiv(dim, BLOCK_N))
with torch.npu.device(x.device.index):
_causal_conv1d_fwd_kernel[grid](
# Pointers to matrices
x,
weight,
bias,
conv_states,
cache_indices,
has_initial_state,
query_start_loc,
batch_ptr,
token_chunk_offset_ptr,
out,
# Matrix dimensions
dim,
state_len,
num_cache_lines,
# stride
stride_x_dim,
stride_x_token,
stride_w_dim,
stride_w_width,
stride_istate_seq,
stride_istate_dim,
stride_istate_token,
stride_cache_indices,
stride_o_dim,
stride_o_token,
# others
pad_slot_id,
# META
HAS_BIAS=bias is not None,
KERNEL_WIDTH=width,
SILU_ACTIVATION=activation in ["silu", "swish"],
HAS_INITIAL_STATES=has_initial_state is not None,
IS_CONTINUOUS_BATCHING=cache_indices is not None,
USE_PAD_SLOT=pad_slot_id is not None,
NP2_STATELEN=np2_statelen,
BLOCK_M=BLOCK_M,
BLOCK_N=BLOCK_N)
return out
for i in range(len(seqlens)):
x_s = splits[i]
if cache_indices[i] == PAD_SLOT_ID:
continue
out_ref_b.append(
causal_conv1d_ref(
x_s,
weight,
bias,
activation=activation,
return_final_states=True,
final_states_out=conv_states[cache_indices[i]][..., :(
width - 1)].unsqueeze(0),
initial_states=conv_states[cache_indices[i]][..., :(width - 1)]
if has_initial_state[i] else None))
out_ref.append(torch.cat([t[0] for t in out_ref_b], dim=-1))
out_ref_tensor = torch.cat(out_ref, dim=0)
return out_ref_tensor
@triton.jit
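To make the varlen layout described in the `causal_conv1d_fn` docstring concrete, here is a standalone, illustrative sketch (it does not import the repo module and omits the conv_states cache handling) of the per-sequence torch path that the rolled-back implementation follows: split the flat `(dim, cu_seq_len)` tensor via the per-sequence lengths, run a causal depthwise conv1d per sequence, and re-concatenate.
```python
# Standalone sketch of the torch-based causal conv1d path (illustration only).
import torch
import torch.nn.functional as F

torch.manual_seed(0)
dim, width = 8, 4
seqlens = [5, 1, 1, 1]                            # as in the docstring example
query_start_loc = torch.tensor([0, 5, 6, 7, 8])   # cumulative lengths, prepended by 0
x = torch.randn(dim, int(query_start_loc[-1]))    # (dim, cu_seq_len), sequences concatenated
weight = torch.randn(dim, width)                  # one depthwise filter per channel
bias = torch.randn(dim)

outs = []
for x_s in torch.split(x, seqlens, dim=-1):       # x_s: (dim, seqlen_i)
    # padding=width-1 plus trimming to the original length makes the convolution
    # causal; groups=dim makes it depthwise.
    y = F.conv1d(x_s.unsqueeze(0), weight.unsqueeze(1), bias,
                 padding=width - 1, groups=dim)
    outs.append(F.silu(y[..., :x_s.shape[-1]]).squeeze(0))

out = torch.cat(outs, dim=-1)                     # same (dim, cu_seq_len) layout as x
print(out.shape)                                  # torch.Size([8, 8])
```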