diff --git a/vllm_ascend/ops/triton/mamba/causal_conv1d.py b/vllm_ascend/ops/triton/mamba/causal_conv1d.py index 9fb9465b..84c330b5 100644 --- a/vllm_ascend/ops/triton/mamba/causal_conv1d.py +++ b/vllm_ascend/ops/triton/mamba/causal_conv1d.py @@ -595,6 +595,8 @@ def causal_conv1d_update_npu( indices 0 and 3 out: (batch, dim) or (batch, dim, seqlen) or (num_tokens, dim), same shape as `x` """ + weight = weight.transpose(0, 1).contiguous() + conv_state = conv_state.transpose(1, 2).contiguous() if validate_data: assert pad_slot_id is not None assert x.stride(1) == 1 @@ -608,40 +610,33 @@ def causal_conv1d_update_npu( unsqueeze = query_start_loc is None and x.dim() == 2 if unsqueeze: # make it (batch, dim, seqlen) with seqlen == 1 - x = x.unsqueeze(-1) + x = x.unsqueeze(1) if query_start_loc is None: - batch, dim, seqlen = x.shape + batch, seqlen, dim = x.shape else: assert conv_state_indices is not None batch = conv_state_indices.size(0) dim = x.size(1) seqlen = max_query_len - _, width = weight.shape - num_cache_lines, _, state_len_total = conv_state.size() - - if validate_data: - assert dim == weight.size(0) - assert conv_state.stride(-2) == 1 - assert state_len_total >= width - 1 - assert num_cache_lines >= batch - assert weight.stride(1) == 1 + width, _ = weight.shape + num_cache_lines, state_len_total,_ = conv_state.size() # overwrite-on-x strategy same as original out = x - stride_w_dim, stride_w_width = weight.stride() + stride_w_width, stride_w_dim = weight.stride() if query_start_loc is None: - stride_x_seq, stride_x_dim, stride_x_token = x.stride() - stride_o_seq, stride_o_dim, stride_o_token = out.stride() + stride_x_seq, stride_x_token,stride_x_dim = x.stride() + stride_o_seq, stride_o_token, stride_o_dim = out.stride() else: stride_x_token, stride_x_dim = x.stride() stride_x_seq = 0 stride_o_token, stride_o_dim = out.stride() stride_o_seq = 0 - stride_istate_seq, stride_istate_dim, stride_istate_token = conv_state.stride( + stride_istate_seq, stride_istate_token, stride_istate_dim = conv_state.stride( ) stride_state_indices = conv_state_indices.stride( 0) if conv_state_indices is not None else 0 @@ -657,7 +652,7 @@ def causal_conv1d_update_npu( #keep program count around ~[80..160] # vector core 40 # TODO: use driver to get the vector core num - CORE_HINT = 40 + CORE_HINT = 40 # channel tile: 512 when dim large (reduce tasks), else 256 block_n = 512 if dim >= 512 else 256 g = triton.cdiv(dim, block_n) @@ -674,14 +669,13 @@ def causal_conv1d_update_npu( b_tile = 8 # token chunk based on block_n (32KB UB idea); conservative - t_chunk = 20 if block_n == 512 else 48 + t_chunk = 1 if block_n == 512 else 48 def grid(META): return ( triton.cdiv(batch, META["B_TILE"]), triton.cdiv(dim, META["BLOCK_N"]), ) - _causal_conv1d_update_kernel_npu_tiled[grid]( x, weight, @@ -725,5 +719,5 @@ def causal_conv1d_update_npu( ) if unsqueeze: - out = out.squeeze(-1) + out = out.squeeze(1) return out.to(original_x_dtype)