rollback causal_conv1d_fn to torch ops & update qwen3Next doc (#5391)

### What this PR does / why we need it?
Rollback causal_conv1d_fn ops from triton to torch version to fix
hanging issues,meanwhile update Qwen3Next doc

- vLLM version: release/v0.13.0
- vLLM main:
254f6b9867
---------
Signed-off-by: SunnyLee219 <3294305115@qq.com>
This commit is contained in:
LeeWenquan
2025-12-26 19:57:38 +08:00
committed by GitHub
parent 48854aef5c
commit 7685d0c239
2 changed files with 109 additions and 405 deletions

View File

@@ -92,10 +92,8 @@ source /usr/local/Ascend/ascend-toolkit/8.3.RC2/bisheng_toolkit/set_env.sh
Run the following script to start the vLLM server on multi-NPU: Run the following script to start the vLLM server on multi-NPU:
For an Atlas A2 with 64 GB of NPU card memory, tensor-parallel-size should be at least 4, and for 32 GB of memory, tensor-parallel-size should be at least 8.
```bash ```bash
vllm serve Qwen/Qwen3-Next-80B-A3B-Instruct --tensor-parallel-size 4 --max-model-len 4096 --gpu-memory-utilization 0.7 --compilation-config '{"cudagraph_mode":"FULL_DECODE_ONLY"}' vllm serve Qwen/Qwen3-Next-80B-A3B-Instruct --tensor-parallel-size 4 --max-model-len 32768 --gpu-memory-utilization 0.8 --max-num-batched-tokens 4096 --compilation-config '{"cudagraph_mode":"FULL_DECODE_ONLY"}'
``` ```
Once your server is started, you can query the model with input prompts. Once your server is started, you can query the model with input prompts.
@@ -170,11 +168,11 @@ Prompt: 'Who are you?', Generated text: ' What do you know about me?\n\nHello! I
1. Refer to [Using AISBench](../developer_guide/evaluation/using_ais_bench.md) for details. 1. Refer to [Using AISBench](../developer_guide/evaluation/using_ais_bench.md) for details.
2. After execution, you can get the result, here is the result of `Qwen3-Next-80B-A3B-Instruct` in `vllm-ascend:0.11.0rc3` for reference only. 2. After execution, you can get the result, here is the result of `Qwen3-Next-80B-A3B-Instruct` in `vllm-ascend:0.13.0rc1` for reference only.
| dataset | version | metric | mode | vllm-api-general-chat | | dataset | version | metric | mode | vllm-api-general-chat |
|----- | ----- | ----- | ----- | -----| |----- | ----- | ----- | ----- | -----|
| gsm8k | - | accuracy | gen | 96.3 | | gsm8k | - | accuracy | gen | 95.53 |
## Performance ## Performance
@@ -201,3 +199,15 @@ vllm bench serve --model Qwen/Qwen3-Next-80B-A3B-Instruct --dataset-name random
``` ```
After about several minutes, you can get the performance evaluation result. After about several minutes, you can get the performance evaluation result.
The performance result is:
**Hardware**: A3-752T, 2 node
**Deployment**: TP4 + Full Decode Only
**Input/Output**: 2k/2k
**Concurrency**: 32
**Performance**: 580tps, TPOT 54ms

View File

@@ -7,292 +7,82 @@
# and https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/mamba/ops/causal_conv1d.py # and https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/mamba/ops/causal_conv1d.py
# mypy: ignore-errors # mypy: ignore-errors
from typing import Any, Optional, Union from typing import Any, Optional
import torch import torch
import torch.nn.functional as F
import triton import triton
import triton.language as tl import triton.language as tl
PAD_SLOT_ID = -1 PAD_SLOT_ID = -1
@triton.jit() def causal_conv1d_ref(
def _causal_conv1d_fwd_kernel( # continuous batching x: torch.Tensor,
# Pointers to matrices
x_ptr, # (dim, cu_seqlen) holding `batch` of actual sequences + padded sequences
w_ptr, # (dim, width)
bias_ptr,
conv_states_ptr,
conv_state_indices_ptr,
has_initial_states_ptr,
query_start_loc_ptr,
batch_ptr,
token_chunk_offset_ptr,
o_ptr, # (dim, seqlen)
# Matrix dimensions
dim: tl.constexpr,
state_len: int,
num_cache_lines: tl.constexpr, # added to support vLLM larger cache lines
# Strides
stride_x_dim: tl.constexpr, # stride to get to next feature-value,
stride_x_token: tl.constexpr, # stride to get to next token
stride_w_dim: tl.constexpr, # stride to get to next dim-axis value
stride_w_width: tl.constexpr, # stride to get to next width-axis value
stride_conv_state_seq: tl.constexpr,
stride_conv_state_dim: tl.constexpr,
stride_conv_state_tok: tl.constexpr,
stride_cache_indices: tl.constexpr,
stride_o_dim: tl.constexpr,
stride_o_token: tl.constexpr,
# others
pad_slot_id: tl.constexpr,
# Meta-parameters
HAS_BIAS: tl.constexpr,
KERNEL_WIDTH: tl.constexpr,
SILU_ACTIVATION: tl.constexpr,
HAS_INITIAL_STATES: tl.constexpr,
IS_CONTINUOUS_BATCHING: tl.constexpr,
USE_PAD_SLOT: tl.constexpr,
NP2_STATELEN: tl.constexpr,
BLOCK_M: tl.constexpr,
BLOCK_N: tl.constexpr,
):
# single-sequence id
idx_seq = tl.load(batch_ptr + tl.program_id(0)).to(tl.int64)
chunk_offset = tl.load(token_chunk_offset_ptr + tl.program_id(0))
# BLOCK_N elements along the feature-dimension (channel)
idx_feats = tl.program_id(1) * BLOCK_N + tl.arange(0, BLOCK_N)
sequence_start_index = tl.load(query_start_loc_ptr + idx_seq)
sequence_end_index = tl.load(query_start_loc_ptr + idx_seq + 1)
# find the actual sequence length
seqlen = sequence_end_index - sequence_start_index
token_offset = BLOCK_M * chunk_offset
segment_len = min(BLOCK_M, seqlen - token_offset)
# base of the sequence
x_base = x_ptr + sequence_start_index * stride_x_token + idx_feats * stride_x_dim # [BLOCK_N,]
if IS_CONTINUOUS_BATCHING:
# cache_idx
conv_state_batch_coord = tl.load(conv_state_indices_ptr +
idx_seq * stride_cache_indices).to(
tl.int64)
else:
# cache_idx
conv_state_batch_coord = idx_seq
if USE_PAD_SLOT: # noqa
if conv_state_batch_coord == pad_slot_id:
# not processing as this is not the actual sequence
return
conv_states_base = conv_states_ptr + (
conv_state_batch_coord * stride_conv_state_seq) + (
idx_feats * stride_conv_state_dim) # [BLOCK_N,]
w_base = w_ptr + (idx_feats * stride_w_dim) # [BLOCK_N,]
load_init_state = False
if HAS_INITIAL_STATES: # the new HAS_INITIAL_STATES
load_init_state = tl.load(has_initial_states_ptr + idx_seq)
mask_dim = idx_feats < dim
# read prior-token data from `x`
offset_x = token_offset - KERNEL_WIDTH + 1
if KERNEL_WIDTH >= 2:
x0_ptrs = x_base + offset_x * stride_x_token
x0 = tl.load(x0_ptrs, mask_dim & (offset_x > 0))
if KERNEL_WIDTH >= 3:
x1_ptrs = x0_ptrs + 1 * stride_x_token
x1 = tl.load(x1_ptrs, mask_dim & (offset_x + 1 > 0))
if KERNEL_WIDTH >= 4:
x2_ptrs = x1_ptrs + 1 * stride_x_token
x2 = tl.load(x2_ptrs, mask_dim & (offset_x + 2 > 0))
if load_init_state & (chunk_offset == 0):
# load from conv_states
offset_conv_state = state_len - KERNEL_WIDTH + 1
if KERNEL_WIDTH >= 2:
x0_ptrs = conv_states_base + offset_conv_state * stride_conv_state_tok
x0 = tl.load(x0_ptrs, mask_dim, 0.0)
if KERNEL_WIDTH >= 3:
x1_ptrs = x0_ptrs + 1 * stride_conv_state_tok
x1 = tl.load(x1_ptrs, mask_dim)
if KERNEL_WIDTH >= 4:
x2_ptrs = x1_ptrs + 1 * stride_conv_state_tok
x2 = tl.load(x2_ptrs, mask_dim)
if HAS_BIAS:
bias = bias_ptr + idx_feats
mask_bias = idx_feats < dim
acc_preload = tl.load(bias, mask=mask_bias,
other=0.0).to(tl.float32) # [BLOCK_N]
else:
acc_preload = tl.zeros((BLOCK_N, ), dtype=tl.float32)
x_base_1d = x_base + token_offset * stride_x_token # starting of chunk
# PRE-LOAD WEIGHTS
mask_dim = idx_feats < dim
if KERNEL_WIDTH >= 2:
w_ptrs = w_base + (0 * stride_w_width) # [BLOCK_N] tensor
w0 = tl.load(w_ptrs, mask_dim, other=0.0)
w_ptrs = w_base + (1 * stride_w_width) # [BLOCK_N] tensor
w1 = tl.load(w_ptrs, mask_dim, other=0.0)
if KERNEL_WIDTH >= 3:
w_ptrs = w_base + (2 * stride_w_width) # [BLOCK_N] tensor
w2 = tl.load(w_ptrs, mask_dim, other=0.0)
if KERNEL_WIDTH >= 4:
w_ptrs = w_base + (3 * stride_w_width) # [BLOCK_N] tensor
w3 = tl.load(w_ptrs, mask_dim, other=0.0)
for idx_token in tl.static_range(BLOCK_M):
acc = acc_preload
mask_1d = (idx_token
< segment_len) & mask_dim # token-index # feature-index
x_ptrs_1d = x_base_1d + idx_token * stride_x_token # [BLOCK_N]
x = tl.load(x_ptrs_1d, mask=mask_1d)
if KERNEL_WIDTH == 2:
acc += x0 * w0 + x * w1
x0 = x
elif KERNEL_WIDTH == 3:
acc += x0 * w0 + x1 * w1 + x * w2
x0 = x1
x1 = x
elif KERNEL_WIDTH == 4:
acc += x0 * w0 + x1 * w1 + x2 * w2 + x * w3
x0 = x1
x1 = x2
x2 = x
if SILU_ACTIVATION:
acc = acc / (1 + tl.exp(-acc))
o_ptrs = o_ptr + (sequence_start_index + token_offset + idx_token
) * stride_o_token + (idx_feats * stride_o_dim)
tl.store(o_ptrs, acc, mask=mask_1d)
# update conv_state with new data [only by the Triton program handles chunk_offset=0]
if chunk_offset == 0:
if state_len <= seqlen: # SMALL_CACHE=True (only move part of 'x' into conv_state cache)
# just read from 'x'
# copy 'x' data to conv_state
# load only 'x' data (and set 0 before 'x' if seqlen < state_len)
idx_tokens_last = (seqlen - state_len) + tl.arange(
0, NP2_STATELEN) # [BLOCK_M]
x_ptrs = x_ptr + (
(sequence_start_index + idx_tokens_last) *
stride_x_token)[:, None] + (
idx_feats * stride_x_dim)[None, :] # [BLOCK_M,BLOCK_N,]
mask_x = ((idx_tokens_last >= 0)[:, None] &
(idx_tokens_last < seqlen)[:, None] &
(idx_feats < dim)[None, :]
) # token-index # token-index # feature-index
new_conv_state = tl.load(x_ptrs, mask_x, 0.0)
idx_tokens_conv = tl.arange(0, NP2_STATELEN) # [BLOCK_M]
conv_states_ptrs_target = conv_states_base[None, :] + (
idx_tokens_conv * stride_conv_state_tok)[:, None]
mask = (idx_tokens_conv < state_len)[:, None] & (idx_feats
< dim)[None, :]
tl.debug_barrier()
tl.store(conv_states_ptrs_target, new_conv_state, mask)
elif load_init_state:
# update conv_state by shifting left, i.e. take last few cols from conv_state + cols from 'x'
idx_tokens_conv = tl.arange(0, NP2_STATELEN) # [BLOCK_M]
conv_states_ptrs_source = (
conv_states_ptr +
(conv_state_batch_coord * stride_conv_state_seq) +
(idx_feats * stride_conv_state_dim)[None, :] +
((idx_tokens_conv + seqlen) * stride_conv_state_tok)[:, None]
) # [BLOCK_M, BLOCK_N]
mask = ((conv_state_batch_coord < num_cache_lines)
& ((idx_tokens_conv + seqlen) < state_len)[:, None]
& (idx_feats < dim)[None, :])
conv_state = tl.load(conv_states_ptrs_source, mask, other=0.0)
VAL = state_len - seqlen
x_ptrs = x_base[None, :] + (
(idx_tokens_conv - VAL) *
stride_x_token)[:, None] # [BLOCK_M, BLOCK_N]
mask_x = ((idx_tokens_conv - VAL >= 0)[:, None] &
(idx_tokens_conv - VAL < seqlen)[:, None] &
(idx_feats < dim)[None, :]
) # token-index # token-index # feature-index
loaded_x = tl.load(x_ptrs, mask_x, 0.0)
tl.debug_barrier()
new_conv_state = tl.where(
mask, conv_state, loaded_x
) # BUG in 'tl.where' which requires a barrier before this
conv_states_ptrs_target = conv_states_base + (
idx_tokens_conv *
stride_conv_state_tok)[:, None] # [BLOCK_M, BLOCK_N]
mask = (idx_tokens_conv < state_len)[:, None] & (idx_feats
< dim)[None, :]
tl.store(conv_states_ptrs_target, new_conv_state, mask)
else:
# update conv_state by shifting left, BUT
# set cols prior to 'x' as zeros + cols from 'x'
idx_tokens_conv = tl.arange(0, NP2_STATELEN) # [BLOCK_M]
VAL = state_len - seqlen
x_ptrs = x_base[None, :] + (
(idx_tokens_conv - VAL) *
stride_x_token)[:, None] # [BLOCK_M, BLOCK_N]
mask_x = ((idx_tokens_conv - VAL >= 0)[:, None] &
(idx_tokens_conv - VAL < seqlen)[:, None] &
(idx_feats < dim)[None, :]
) # token-index # token-index # feature-index
new_conv_state = tl.load(x_ptrs, mask_x, 0.0)
conv_states_ptrs_target = conv_states_base + (
idx_tokens_conv *
stride_conv_state_tok)[:, None] # [BLOCK_M, BLOCK_N]
mask = (idx_tokens_conv < state_len)[:, None] & (idx_feats
< dim)[None, :]
tl.debug_barrier()
tl.store(conv_states_ptrs_target, new_conv_state, mask)
def causal_conv1d_fn(x: torch.Tensor,
weight: torch.Tensor, weight: torch.Tensor,
bias: Union[torch.Tensor, None], bias: Optional[torch.Tensor] = None,
conv_states: torch.Tensor, initial_states: Optional[torch.Tensor] = None,
query_start_loc: torch.Tensor, return_final_states: bool = False,
cache_indices: Optional[torch.Tensor] = None, final_states_out: Optional[torch.Tensor] = None,
has_initial_state: Optional[torch.Tensor] = None,
activation: Optional[str] = "silu", activation: Optional[str] = "silu",
pad_slot_id: int = PAD_SLOT_ID, ):
"""
x: (batch, dim, seqlen)
weight: (dim, width)
bias: (dim,)
initial_states: (batch, dim, width - 1)
final_states_out: (batch, dim, width - 1)
out: (batch, dim, seqlen)
"""
if activation not in [None, "silu", "swish"]:
raise NotImplementedError("activation must be None, silu, or swish")
dtype_in = x.dtype
x = x.to(weight.dtype)
seqlen = x.shape[-1]
dim, width = weight.shape
if initial_states is None:
out = F.conv1d(x,
weight.unsqueeze(1),
bias,
padding=width - 1,
groups=dim)
else:
x = torch.cat([initial_states, x], dim=-1)
out = F.conv1d(x, weight.unsqueeze(1), bias, padding=0, groups=dim)
out = out[..., :seqlen]
if return_final_states:
final_states = F.pad(x, (width - 1 - x.shape[-1], 0)).to(
dtype_in) # (batch, dim, width - 1)
if final_states_out is not None:
final_states_out.copy_(final_states)
else:
final_states_out = final_states
out = (out if activation is None else F.silu(out)).to(dtype=dtype_in)
return (out, None) if not return_final_states else (out, final_states_out)
def causal_conv1d_fn(
x: torch.Tensor,
weight: torch.Tensor,
bias: Optional[torch.Tensor] = None,
activation: Optional[str] = "silu",
conv_states: Optional[torch.Tensor] = None,
has_initial_state: Optional[torch.Tensor] = None,
cache_indices: Optional[torch.Tensor] = None,
query_start_loc: Optional[torch.Tensor] = None,
metadata: Optional[Any] = None, metadata: Optional[Any] = None,
validate_data=False): pad_slot_id: int = PAD_SLOT_ID,
"""support varlen + continuous batching when x is 2D tensor ):
x: (dim,cu_seq_len) """
cu_seq_len = total tokens of all seqs in that batch x: (batch, dim, seqlen) or (dim,cu_seq_len) for varlen
sequences are concatenated from left to right for varlen sequences are concatenated from left to right for varlen
weight: (dim, width) weight: (dim, width)
conv_states: (...,dim,width - 1) itype bias: (dim,)
updated inplace if provided
[it use `cache_indices` to get the index to the cache of conv_state for that sequence
conv_state[cache_indices[i]] for seq-i - to be used as initial_state when has_initial_state[i] = True
and after that conv_state[cache_indices[i]] need to be shift-left and updated with values from 'x'
]
query_start_loc: (batch + 1) int32 query_start_loc: (batch + 1) int32
The cumulative sequence lengths of the sequences in The cumulative sequence lengths of the sequences in
the batch, used to index into sequence. prepended by 0. the batch, used to index into sequence. prepended by 0.
if
x = [5, 1, 1, 1] <- continuous batching (batch=4)
then
query_start_loc = [0, 5, 6, 7, 8] <- the starting index of the next sequence; while the last value is
the ending index of the last sequence
[length(query_start_loc)-1 == batch]
for example: query_start_loc = torch.Tensor([0,10,16,17]), for example: query_start_loc = torch.Tensor([0,10,16,17]),
x.shape=(dim,17) x.shape=(dim,17)
cache_indices: (batch) int32 cache_indices: (batch) int32
@@ -301,144 +91,48 @@ def causal_conv1d_fn(x: torch.Tensor,
has_initial_state: (batch) bool has_initial_state: (batch) bool
indicates whether should the kernel take the current state as initial indicates whether should the kernel take the current state as initial
state for the calculations state for the calculations
[single boolean for each sequence in the batch: True or False] conv_states: (...,dim,width - 1) itype
bias: (dim,) updated inplace if provided
activation: either None or "silu" or "swish" or True activation: either None or "silu" or "swish"
pad_slot_id: int pad_slot_id: int
if cache_indices is passed, lets the kernel identify padded if cache_indices is passed, lets the kernel identify padded
entries that will not be processed, entries that will not be processed,
for example: cache_indices = [pad_slot_id, 1, 20, pad_slot_id] for example: cache_indices = [pad_slot_id, 1, 20, pad_slot_id]
in this case, the kernel will not process entries at in this case, the kernel will not process entries at
indices 0 and 3 indices 0 and 3
out: same shape as `x` out: (batch, dim, seqlen)
""" """
if isinstance(activation, bool) and activation: if activation not in [None, "silu", "swish"]:
activation = "silu" raise NotImplementedError("activation must be None, silu, or swish")
if x.stride(-1) != 1:
x = x.contiguous()
bias = bias.contiguous() if bias is not None else None
# Store original dtype to cast back at the end out_ref = []
out = torch.empty_strided(x.size(), out_ref_b = []
x.stride(), seqlens = query_start_loc[1:] - query_start_loc[:-1]
dtype=x.dtype, seqlens = seqlens.tolist()
device=x.device) splits = torch.split(x, seqlens, dim=-1)
width = weight.shape[1]
dim, _ = x.shape for i in range(len(seqlens)):
_, width = weight.shape x_s = splits[i]
if cache_indices[i] == PAD_SLOT_ID:
state_len = width - 1 continue
np2_statelen = triton.next_power_of_2(state_len) out_ref_b.append(
causal_conv1d_ref(
padded_batch = query_start_loc.size(0) - 1 x_s,
stride_x_dim = x.stride(0)
stride_x_token = x.stride(1)
stride_w_dim = weight.stride(0)
stride_w_width = weight.stride(1)
stride_istate_seq = 0
stride_istate_dim = 0
stride_istate_token = 0
stride_o_dim = out.stride(0)
stride_o_token = out.stride(1)
num_cache_lines = 0
if conv_states is not None:
# extensions to support vLLM:
# 1. conv_states is used to replaced initial_states
# 2. conv_states serve as a cache with num cache lines can be larger than batch size
# 3. mapping from sequence x[idx] to a cache line at index as specified via cache_indices[idx]
# 4. computation can be skipped if cache_indices[idx] == pad_slot_id
num_cache_lines = conv_states.size(0)
stride_istate_seq = conv_states.stride(0)
stride_istate_dim = conv_states.stride(1)
stride_istate_token = conv_states.stride(2)
stride_cache_indices = cache_indices.stride(
0) if cache_indices is not None else 0
if validate_data:
is_channel_last = (x.stride(0) == 1) & (x.stride(1) > 1)
assert x.dim() == 2
assert width in [2, 3, 4]
assert query_start_loc is not None
assert query_start_loc.dim() == 1
assert x.stride(0) == 1 or x.stride(1) == 1
if bias is not None:
assert bias.dim() == 1
assert dim == bias.size(0)
if conv_states is not None:
assert (num_cache_lines == conv_states.shape[0]
and dim == conv_states.shape[1]
and conv_states.shape[2] >= width - 1)
assert stride_istate_dim == 1
if cache_indices is not None:
assert cache_indices.dim() == 1
assert padded_batch == cache_indices.size(0)
if has_initial_state is not None:
assert has_initial_state.size() == (padded_batch, )
assert conv_states is not None, "ERROR: `has_initial_state` is used, which needs also `conv_states`"
assert weight.stride(1) == 1
assert (dim, width) == weight.shape
assert is_channel_last, "Need to run in channel-last layout"
BLOCK_M = 64
seqlens = query_start_loc.diff()
seq_blocks = -(-seqlens // BLOCK_M)
total_seq_blocks = seq_blocks.sum().item()
# tracking which seq-idx the Triton program is handling
batch_ptr = torch.repeat_interleave(
torch.arange(len(seq_blocks), device=x.device),
seq_blocks).to(torch.int32)
# tracking BLOCK_M-based index in the sequence the Triton program is handling
max_blocks = seq_blocks.max().item() if len(seq_blocks) > 0 else 0
arange = torch.arange(max_blocks, device=x.device)
mask = arange.unsqueeze(0) < seq_blocks.unsqueeze(1)
token_chunk_offset_ptr = arange.repeat(len(seq_blocks),
1)[mask].to(torch.int32)
BLOCK_N = 256
grid = (total_seq_blocks, triton.cdiv(dim, BLOCK_N))
with torch.npu.device(x.device.index):
_causal_conv1d_fwd_kernel[grid](
# Pointers to matrices
x,
weight, weight,
bias, bias,
conv_states, activation=activation,
cache_indices, return_final_states=True,
has_initial_state, final_states_out=conv_states[cache_indices[i]][..., :(
query_start_loc, width - 1)].unsqueeze(0),
batch_ptr, initial_states=conv_states[cache_indices[i]][..., :(width - 1)]
token_chunk_offset_ptr, if has_initial_state[i] else None))
out, out_ref.append(torch.cat([t[0] for t in out_ref_b], dim=-1))
# Matrix dimensions out_ref_tensor = torch.cat(out_ref, dim=0)
dim, return out_ref_tensor
state_len,
num_cache_lines,
# stride
stride_x_dim,
stride_x_token,
stride_w_dim,
stride_w_width,
stride_istate_seq,
stride_istate_dim,
stride_istate_token,
stride_cache_indices,
stride_o_dim,
stride_o_token,
# others
pad_slot_id,
# META
HAS_BIAS=bias is not None,
KERNEL_WIDTH=width,
SILU_ACTIVATION=activation in ["silu", "swish"],
HAS_INITIAL_STATES=has_initial_state is not None,
IS_CONTINUOUS_BATCHING=cache_indices is not None,
USE_PAD_SLOT=pad_slot_id is not None,
NP2_STATELEN=np2_statelen,
BLOCK_M=BLOCK_M,
BLOCK_N=BLOCK_N)
return out
@triton.jit @triton.jit