diff --git a/docs/source/tutorials/Qwen3-Next.md b/docs/source/tutorials/Qwen3-Next.md index a0a7a6d1..3304751e 100644 --- a/docs/source/tutorials/Qwen3-Next.md +++ b/docs/source/tutorials/Qwen3-Next.md @@ -92,10 +92,8 @@ source /usr/local/Ascend/ascend-toolkit/8.3.RC2/bisheng_toolkit/set_env.sh Run the following script to start the vLLM server on multi-NPU: -For an Atlas A2 with 64 GB of NPU card memory, tensor-parallel-size should be at least 4, and for 32 GB of memory, tensor-parallel-size should be at least 8. - ```bash -vllm serve Qwen/Qwen3-Next-80B-A3B-Instruct --tensor-parallel-size 4 --max-model-len 4096 --gpu-memory-utilization 0.7 --compilation-config '{"cudagraph_mode":"FULL_DECODE_ONLY"}' +vllm serve Qwen/Qwen3-Next-80B-A3B-Instruct --tensor-parallel-size 4 --max-model-len 32768 --gpu-memory-utilization 0.8 --max-num-batched-tokens 4096 --compilation-config '{"cudagraph_mode":"FULL_DECODE_ONLY"}' ``` Once your server is started, you can query the model with input prompts. @@ -170,11 +168,11 @@ Prompt: 'Who are you?', Generated text: ' What do you know about me?\n\nHello! I 1. Refer to [Using AISBench](../developer_guide/evaluation/using_ais_bench.md) for details. -2. After execution, you can get the result, here is the result of `Qwen3-Next-80B-A3B-Instruct` in `vllm-ascend:0.11.0rc3` for reference only. +2. After execution, you can get the result, here is the result of `Qwen3-Next-80B-A3B-Instruct` in `vllm-ascend:0.13.0rc1` for reference only. | dataset | version | metric | mode | vllm-api-general-chat | |----- | ----- | ----- | ----- | -----| -| gsm8k | - | accuracy | gen | 96.3 | +| gsm8k | - | accuracy | gen | 95.53 | ## Performance @@ -201,3 +199,15 @@ vllm bench serve --model Qwen/Qwen3-Next-80B-A3B-Instruct --dataset-name random ``` After about several minutes, you can get the performance evaluation result. + +The performance result is: + +**Hardware**: A3-752T, 2 node + +**Deployment**: TP4 + Full Decode Only + +**Input/Output**: 2k/2k + +**Concurrency**: 32 + +**Performance**: 580tps, TPOT 54ms diff --git a/vllm_ascend/ops/triton/mamba/causal_conv1d.py b/vllm_ascend/ops/triton/mamba/causal_conv1d.py index 38b838d8..118de8aa 100644 --- a/vllm_ascend/ops/triton/mamba/causal_conv1d.py +++ b/vllm_ascend/ops/triton/mamba/causal_conv1d.py @@ -7,292 +7,82 @@ # and https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/mamba/ops/causal_conv1d.py # mypy: ignore-errors -from typing import Any, Optional, Union +from typing import Any, Optional import torch +import torch.nn.functional as F import triton import triton.language as tl PAD_SLOT_ID = -1 -@triton.jit() -def _causal_conv1d_fwd_kernel( # continuous batching - # Pointers to matrices - x_ptr, # (dim, cu_seqlen) holding `batch` of actual sequences + padded sequences - w_ptr, # (dim, width) - bias_ptr, - conv_states_ptr, - conv_state_indices_ptr, - has_initial_states_ptr, - query_start_loc_ptr, - batch_ptr, - token_chunk_offset_ptr, - o_ptr, # (dim, seqlen) - # Matrix dimensions - dim: tl.constexpr, - state_len: int, - num_cache_lines: tl.constexpr, # added to support vLLM larger cache lines - # Strides - stride_x_dim: tl.constexpr, # stride to get to next feature-value, - stride_x_token: tl.constexpr, # stride to get to next token - stride_w_dim: tl.constexpr, # stride to get to next dim-axis value - stride_w_width: tl.constexpr, # stride to get to next width-axis value - stride_conv_state_seq: tl.constexpr, - stride_conv_state_dim: tl.constexpr, - stride_conv_state_tok: tl.constexpr, - stride_cache_indices: tl.constexpr, - stride_o_dim: tl.constexpr, - stride_o_token: tl.constexpr, - # others - pad_slot_id: tl.constexpr, - # Meta-parameters - HAS_BIAS: tl.constexpr, - KERNEL_WIDTH: tl.constexpr, - SILU_ACTIVATION: tl.constexpr, - HAS_INITIAL_STATES: tl.constexpr, - IS_CONTINUOUS_BATCHING: tl.constexpr, - USE_PAD_SLOT: tl.constexpr, - NP2_STATELEN: tl.constexpr, - BLOCK_M: tl.constexpr, - BLOCK_N: tl.constexpr, +def causal_conv1d_ref( + x: torch.Tensor, + weight: torch.Tensor, + bias: Optional[torch.Tensor] = None, + initial_states: Optional[torch.Tensor] = None, + return_final_states: bool = False, + final_states_out: Optional[torch.Tensor] = None, + activation: Optional[str] = "silu", ): - # single-sequence id - idx_seq = tl.load(batch_ptr + tl.program_id(0)).to(tl.int64) - chunk_offset = tl.load(token_chunk_offset_ptr + tl.program_id(0)) + """ + x: (batch, dim, seqlen) + weight: (dim, width) + bias: (dim,) + initial_states: (batch, dim, width - 1) + final_states_out: (batch, dim, width - 1) + out: (batch, dim, seqlen) + """ + if activation not in [None, "silu", "swish"]: + raise NotImplementedError("activation must be None, silu, or swish") + dtype_in = x.dtype + x = x.to(weight.dtype) + seqlen = x.shape[-1] + dim, width = weight.shape - # BLOCK_N elements along the feature-dimension (channel) - idx_feats = tl.program_id(1) * BLOCK_N + tl.arange(0, BLOCK_N) - - sequence_start_index = tl.load(query_start_loc_ptr + idx_seq) - sequence_end_index = tl.load(query_start_loc_ptr + idx_seq + 1) - # find the actual sequence length - seqlen = sequence_end_index - sequence_start_index - - token_offset = BLOCK_M * chunk_offset - segment_len = min(BLOCK_M, seqlen - token_offset) - - # base of the sequence - x_base = x_ptr + sequence_start_index * stride_x_token + idx_feats * stride_x_dim # [BLOCK_N,] - - if IS_CONTINUOUS_BATCHING: - # cache_idx - conv_state_batch_coord = tl.load(conv_state_indices_ptr + - idx_seq * stride_cache_indices).to( - tl.int64) + if initial_states is None: + out = F.conv1d(x, + weight.unsqueeze(1), + bias, + padding=width - 1, + groups=dim) else: - # cache_idx - conv_state_batch_coord = idx_seq + x = torch.cat([initial_states, x], dim=-1) + out = F.conv1d(x, weight.unsqueeze(1), bias, padding=0, groups=dim) + out = out[..., :seqlen] - if USE_PAD_SLOT: # noqa - if conv_state_batch_coord == pad_slot_id: - # not processing as this is not the actual sequence - return - conv_states_base = conv_states_ptr + ( - conv_state_batch_coord * stride_conv_state_seq) + ( - idx_feats * stride_conv_state_dim) # [BLOCK_N,] - w_base = w_ptr + (idx_feats * stride_w_dim) # [BLOCK_N,] - - load_init_state = False - if HAS_INITIAL_STATES: # the new HAS_INITIAL_STATES - load_init_state = tl.load(has_initial_states_ptr + idx_seq) - - mask_dim = idx_feats < dim - - # read prior-token data from `x` - offset_x = token_offset - KERNEL_WIDTH + 1 - if KERNEL_WIDTH >= 2: - x0_ptrs = x_base + offset_x * stride_x_token - x0 = tl.load(x0_ptrs, mask_dim & (offset_x > 0)) - if KERNEL_WIDTH >= 3: - x1_ptrs = x0_ptrs + 1 * stride_x_token - x1 = tl.load(x1_ptrs, mask_dim & (offset_x + 1 > 0)) - if KERNEL_WIDTH >= 4: - x2_ptrs = x1_ptrs + 1 * stride_x_token - x2 = tl.load(x2_ptrs, mask_dim & (offset_x + 2 > 0)) - - if load_init_state & (chunk_offset == 0): - # load from conv_states - offset_conv_state = state_len - KERNEL_WIDTH + 1 - if KERNEL_WIDTH >= 2: - x0_ptrs = conv_states_base + offset_conv_state * stride_conv_state_tok - x0 = tl.load(x0_ptrs, mask_dim, 0.0) - if KERNEL_WIDTH >= 3: - x1_ptrs = x0_ptrs + 1 * stride_conv_state_tok - x1 = tl.load(x1_ptrs, mask_dim) - if KERNEL_WIDTH >= 4: - x2_ptrs = x1_ptrs + 1 * stride_conv_state_tok - x2 = tl.load(x2_ptrs, mask_dim) - - if HAS_BIAS: - bias = bias_ptr + idx_feats - mask_bias = idx_feats < dim - acc_preload = tl.load(bias, mask=mask_bias, - other=0.0).to(tl.float32) # [BLOCK_N] - else: - acc_preload = tl.zeros((BLOCK_N, ), dtype=tl.float32) - - x_base_1d = x_base + token_offset * stride_x_token # starting of chunk - - # PRE-LOAD WEIGHTS - mask_dim = idx_feats < dim - if KERNEL_WIDTH >= 2: - w_ptrs = w_base + (0 * stride_w_width) # [BLOCK_N] tensor - w0 = tl.load(w_ptrs, mask_dim, other=0.0) - w_ptrs = w_base + (1 * stride_w_width) # [BLOCK_N] tensor - w1 = tl.load(w_ptrs, mask_dim, other=0.0) - if KERNEL_WIDTH >= 3: - w_ptrs = w_base + (2 * stride_w_width) # [BLOCK_N] tensor - w2 = tl.load(w_ptrs, mask_dim, other=0.0) - if KERNEL_WIDTH >= 4: - w_ptrs = w_base + (3 * stride_w_width) # [BLOCK_N] tensor - w3 = tl.load(w_ptrs, mask_dim, other=0.0) - - for idx_token in tl.static_range(BLOCK_M): - acc = acc_preload - mask_1d = (idx_token - < segment_len) & mask_dim # token-index # feature-index - x_ptrs_1d = x_base_1d + idx_token * stride_x_token # [BLOCK_N] - x = tl.load(x_ptrs_1d, mask=mask_1d) - - if KERNEL_WIDTH == 2: - acc += x0 * w0 + x * w1 - x0 = x - elif KERNEL_WIDTH == 3: - acc += x0 * w0 + x1 * w1 + x * w2 - x0 = x1 - x1 = x - elif KERNEL_WIDTH == 4: - acc += x0 * w0 + x1 * w1 + x2 * w2 + x * w3 - x0 = x1 - x1 = x2 - x2 = x - - if SILU_ACTIVATION: - acc = acc / (1 + tl.exp(-acc)) - - o_ptrs = o_ptr + (sequence_start_index + token_offset + idx_token - ) * stride_o_token + (idx_feats * stride_o_dim) - tl.store(o_ptrs, acc, mask=mask_1d) - - # update conv_state with new data [only by the Triton program handles chunk_offset=0] - if chunk_offset == 0: - if state_len <= seqlen: # SMALL_CACHE=True (only move part of 'x' into conv_state cache) - # just read from 'x' - # copy 'x' data to conv_state - # load only 'x' data (and set 0 before 'x' if seqlen < state_len) - idx_tokens_last = (seqlen - state_len) + tl.arange( - 0, NP2_STATELEN) # [BLOCK_M] - x_ptrs = x_ptr + ( - (sequence_start_index + idx_tokens_last) * - stride_x_token)[:, None] + ( - idx_feats * stride_x_dim)[None, :] # [BLOCK_M,BLOCK_N,] - mask_x = ((idx_tokens_last >= 0)[:, None] & - (idx_tokens_last < seqlen)[:, None] & - (idx_feats < dim)[None, :] - ) # token-index # token-index # feature-index - new_conv_state = tl.load(x_ptrs, mask_x, 0.0) - idx_tokens_conv = tl.arange(0, NP2_STATELEN) # [BLOCK_M] - conv_states_ptrs_target = conv_states_base[None, :] + ( - idx_tokens_conv * stride_conv_state_tok)[:, None] - - mask = (idx_tokens_conv < state_len)[:, None] & (idx_feats - < dim)[None, :] - tl.debug_barrier() - tl.store(conv_states_ptrs_target, new_conv_state, mask) - elif load_init_state: - # update conv_state by shifting left, i.e. take last few cols from conv_state + cols from 'x' - idx_tokens_conv = tl.arange(0, NP2_STATELEN) # [BLOCK_M] - - conv_states_ptrs_source = ( - conv_states_ptr + - (conv_state_batch_coord * stride_conv_state_seq) + - (idx_feats * stride_conv_state_dim)[None, :] + - ((idx_tokens_conv + seqlen) * stride_conv_state_tok)[:, None] - ) # [BLOCK_M, BLOCK_N] - mask = ((conv_state_batch_coord < num_cache_lines) - & ((idx_tokens_conv + seqlen) < state_len)[:, None] - & (idx_feats < dim)[None, :]) - conv_state = tl.load(conv_states_ptrs_source, mask, other=0.0) - - VAL = state_len - seqlen - - x_ptrs = x_base[None, :] + ( - (idx_tokens_conv - VAL) * - stride_x_token)[:, None] # [BLOCK_M, BLOCK_N] - - mask_x = ((idx_tokens_conv - VAL >= 0)[:, None] & - (idx_tokens_conv - VAL < seqlen)[:, None] & - (idx_feats < dim)[None, :] - ) # token-index # token-index # feature-index - loaded_x = tl.load(x_ptrs, mask_x, 0.0) - tl.debug_barrier() - new_conv_state = tl.where( - mask, conv_state, loaded_x - ) # BUG in 'tl.where' which requires a barrier before this - conv_states_ptrs_target = conv_states_base + ( - idx_tokens_conv * - stride_conv_state_tok)[:, None] # [BLOCK_M, BLOCK_N] - mask = (idx_tokens_conv < state_len)[:, None] & (idx_feats - < dim)[None, :] - tl.store(conv_states_ptrs_target, new_conv_state, mask) + if return_final_states: + final_states = F.pad(x, (width - 1 - x.shape[-1], 0)).to( + dtype_in) # (batch, dim, width - 1) + if final_states_out is not None: + final_states_out.copy_(final_states) else: - # update conv_state by shifting left, BUT - # set cols prior to 'x' as zeros + cols from 'x' - idx_tokens_conv = tl.arange(0, NP2_STATELEN) # [BLOCK_M] - - VAL = state_len - seqlen - - x_ptrs = x_base[None, :] + ( - (idx_tokens_conv - VAL) * - stride_x_token)[:, None] # [BLOCK_M, BLOCK_N] - - mask_x = ((idx_tokens_conv - VAL >= 0)[:, None] & - (idx_tokens_conv - VAL < seqlen)[:, None] & - (idx_feats < dim)[None, :] - ) # token-index # token-index # feature-index - new_conv_state = tl.load(x_ptrs, mask_x, 0.0) - - conv_states_ptrs_target = conv_states_base + ( - idx_tokens_conv * - stride_conv_state_tok)[:, None] # [BLOCK_M, BLOCK_N] - mask = (idx_tokens_conv < state_len)[:, None] & (idx_feats - < dim)[None, :] - tl.debug_barrier() - tl.store(conv_states_ptrs_target, new_conv_state, mask) + final_states_out = final_states + out = (out if activation is None else F.silu(out)).to(dtype=dtype_in) + return (out, None) if not return_final_states else (out, final_states_out) -def causal_conv1d_fn(x: torch.Tensor, - weight: torch.Tensor, - bias: Union[torch.Tensor, None], - conv_states: torch.Tensor, - query_start_loc: torch.Tensor, - cache_indices: Optional[torch.Tensor] = None, - has_initial_state: Optional[torch.Tensor] = None, - activation: Optional[str] = "silu", - pad_slot_id: int = PAD_SLOT_ID, - metadata: Optional[Any] = None, - validate_data=False): - """support varlen + continuous batching when x is 2D tensor - x: (dim,cu_seq_len) - cu_seq_len = total tokens of all seqs in that batch +def causal_conv1d_fn( + x: torch.Tensor, + weight: torch.Tensor, + bias: Optional[torch.Tensor] = None, + activation: Optional[str] = "silu", + conv_states: Optional[torch.Tensor] = None, + has_initial_state: Optional[torch.Tensor] = None, + cache_indices: Optional[torch.Tensor] = None, + query_start_loc: Optional[torch.Tensor] = None, + metadata: Optional[Any] = None, + pad_slot_id: int = PAD_SLOT_ID, +): + """ + x: (batch, dim, seqlen) or (dim,cu_seq_len) for varlen sequences are concatenated from left to right for varlen weight: (dim, width) - conv_states: (...,dim,width - 1) itype - updated inplace if provided - [it use `cache_indices` to get the index to the cache of conv_state for that sequence - conv_state[cache_indices[i]] for seq-i - to be used as initial_state when has_initial_state[i] = True - and after that conv_state[cache_indices[i]] need to be shift-left and updated with values from 'x' - ] + bias: (dim,) query_start_loc: (batch + 1) int32 The cumulative sequence lengths of the sequences in the batch, used to index into sequence. prepended by 0. - if - x = [5, 1, 1, 1] <- continuous batching (batch=4) - then - query_start_loc = [0, 5, 6, 7, 8] <- the starting index of the next sequence; while the last value is - the ending index of the last sequence - [length(query_start_loc)-1 == batch] for example: query_start_loc = torch.Tensor([0,10,16,17]), x.shape=(dim,17) cache_indices: (batch) int32 @@ -301,144 +91,48 @@ def causal_conv1d_fn(x: torch.Tensor, has_initial_state: (batch) bool indicates whether should the kernel take the current state as initial state for the calculations - [single boolean for each sequence in the batch: True or False] - bias: (dim,) - activation: either None or "silu" or "swish" or True + conv_states: (...,dim,width - 1) itype + updated inplace if provided + activation: either None or "silu" or "swish" pad_slot_id: int - if cache_indices is passed, lets the kernel identify padded - entries that will not be processed, - for example: cache_indices = [pad_slot_id, 1, 20, pad_slot_id] - in this case, the kernel will not process entries at - indices 0 and 3 - out: same shape as `x` + if cache_indices is passed, lets the kernel identify padded + entries that will not be processed, + for example: cache_indices = [pad_slot_id, 1, 20, pad_slot_id] + in this case, the kernel will not process entries at + indices 0 and 3 + out: (batch, dim, seqlen) """ - if isinstance(activation, bool) and activation: - activation = "silu" + if activation not in [None, "silu", "swish"]: + raise NotImplementedError("activation must be None, silu, or swish") + if x.stride(-1) != 1: + x = x.contiguous() + bias = bias.contiguous() if bias is not None else None - # Store original dtype to cast back at the end - out = torch.empty_strided(x.size(), - x.stride(), - dtype=x.dtype, - device=x.device) + out_ref = [] + out_ref_b = [] + seqlens = query_start_loc[1:] - query_start_loc[:-1] + seqlens = seqlens.tolist() + splits = torch.split(x, seqlens, dim=-1) + width = weight.shape[1] - dim, _ = x.shape - _, width = weight.shape - - state_len = width - 1 - np2_statelen = triton.next_power_of_2(state_len) - - padded_batch = query_start_loc.size(0) - 1 - stride_x_dim = x.stride(0) - stride_x_token = x.stride(1) - stride_w_dim = weight.stride(0) - stride_w_width = weight.stride(1) - stride_istate_seq = 0 - stride_istate_dim = 0 - stride_istate_token = 0 - stride_o_dim = out.stride(0) - stride_o_token = out.stride(1) - - num_cache_lines = 0 - if conv_states is not None: - # extensions to support vLLM: - # 1. conv_states is used to replaced initial_states - # 2. conv_states serve as a cache with num cache lines can be larger than batch size - # 3. mapping from sequence x[idx] to a cache line at index as specified via cache_indices[idx] - # 4. computation can be skipped if cache_indices[idx] == pad_slot_id - num_cache_lines = conv_states.size(0) - stride_istate_seq = conv_states.stride(0) - stride_istate_dim = conv_states.stride(1) - stride_istate_token = conv_states.stride(2) - - stride_cache_indices = cache_indices.stride( - 0) if cache_indices is not None else 0 - - if validate_data: - is_channel_last = (x.stride(0) == 1) & (x.stride(1) > 1) - assert x.dim() == 2 - assert width in [2, 3, 4] - assert query_start_loc is not None - assert query_start_loc.dim() == 1 - assert x.stride(0) == 1 or x.stride(1) == 1 - if bias is not None: - assert bias.dim() == 1 - assert dim == bias.size(0) - if conv_states is not None: - assert (num_cache_lines == conv_states.shape[0] - and dim == conv_states.shape[1] - and conv_states.shape[2] >= width - 1) - assert stride_istate_dim == 1 - if cache_indices is not None: - assert cache_indices.dim() == 1 - assert padded_batch == cache_indices.size(0) - if has_initial_state is not None: - assert has_initial_state.size() == (padded_batch, ) - assert conv_states is not None, "ERROR: `has_initial_state` is used, which needs also `conv_states`" - assert weight.stride(1) == 1 - assert (dim, width) == weight.shape - assert is_channel_last, "Need to run in channel-last layout" - - BLOCK_M = 64 - seqlens = query_start_loc.diff() - seq_blocks = -(-seqlens // BLOCK_M) - total_seq_blocks = seq_blocks.sum().item() - # tracking which seq-idx the Triton program is handling - batch_ptr = torch.repeat_interleave( - torch.arange(len(seq_blocks), device=x.device), - seq_blocks).to(torch.int32) - - # tracking BLOCK_M-based index in the sequence the Triton program is handling - max_blocks = seq_blocks.max().item() if len(seq_blocks) > 0 else 0 - arange = torch.arange(max_blocks, device=x.device) - mask = arange.unsqueeze(0) < seq_blocks.unsqueeze(1) - token_chunk_offset_ptr = arange.repeat(len(seq_blocks), - 1)[mask].to(torch.int32) - - BLOCK_N = 256 - grid = (total_seq_blocks, triton.cdiv(dim, BLOCK_N)) - - with torch.npu.device(x.device.index): - _causal_conv1d_fwd_kernel[grid]( - # Pointers to matrices - x, - weight, - bias, - conv_states, - cache_indices, - has_initial_state, - query_start_loc, - batch_ptr, - token_chunk_offset_ptr, - out, - # Matrix dimensions - dim, - state_len, - num_cache_lines, - # stride - stride_x_dim, - stride_x_token, - stride_w_dim, - stride_w_width, - stride_istate_seq, - stride_istate_dim, - stride_istate_token, - stride_cache_indices, - stride_o_dim, - stride_o_token, - # others - pad_slot_id, - # META - HAS_BIAS=bias is not None, - KERNEL_WIDTH=width, - SILU_ACTIVATION=activation in ["silu", "swish"], - HAS_INITIAL_STATES=has_initial_state is not None, - IS_CONTINUOUS_BATCHING=cache_indices is not None, - USE_PAD_SLOT=pad_slot_id is not None, - NP2_STATELEN=np2_statelen, - BLOCK_M=BLOCK_M, - BLOCK_N=BLOCK_N) - - return out + for i in range(len(seqlens)): + x_s = splits[i] + if cache_indices[i] == PAD_SLOT_ID: + continue + out_ref_b.append( + causal_conv1d_ref( + x_s, + weight, + bias, + activation=activation, + return_final_states=True, + final_states_out=conv_states[cache_indices[i]][..., :( + width - 1)].unsqueeze(0), + initial_states=conv_states[cache_indices[i]][..., :(width - 1)] + if has_initial_state[i] else None)) + out_ref.append(torch.cat([t[0] for t in out_ref_b], dim=-1)) + out_ref_tensor = torch.cat(out_ref, dim=0) + return out_ref_tensor @triton.jit