diff --git a/vllm_ascend/ops/triton/__init__.py b/vllm_ascend/ops/triton/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/vllm_ascend/ops/triton/fla/__init__.py b/vllm_ascend/ops/triton/fla/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/vllm_ascend/ops/fla.py b/vllm_ascend/ops/triton/fla/fla.py similarity index 100% rename from vllm_ascend/ops/fla.py rename to vllm_ascend/ops/triton/fla/fla.py diff --git a/vllm_ascend/ops/sigmoid_gating.py b/vllm_ascend/ops/triton/fla/sigmoid_gating.py similarity index 100% rename from vllm_ascend/ops/sigmoid_gating.py rename to vllm_ascend/ops/triton/fla/sigmoid_gating.py diff --git a/vllm_ascend/ops/triton/mamba/__init__.py b/vllm_ascend/ops/triton/mamba/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/vllm_ascend/ops/casual_conv1d.py b/vllm_ascend/ops/triton/mamba/casual_conv1d.py similarity index 100% rename from vllm_ascend/ops/casual_conv1d.py rename to vllm_ascend/ops/triton/mamba/casual_conv1d.py diff --git a/vllm_ascend/patch/worker/patch_triton.py b/vllm_ascend/patch/worker/patch_triton.py index cc550ccc..eb3f300b 100644 --- a/vllm_ascend/patch/worker/patch_triton.py +++ b/vllm_ascend/patch/worker/patch_triton.py @@ -3,11 +3,12 @@ import vllm.model_executor.layers.fla.ops.fused_recurrent import vllm.model_executor.layers.fla.ops.layernorm_guard import vllm.model_executor.layers.mamba.ops.causal_conv1d -from vllm_ascend.ops.casual_conv1d import (causal_conv1d_fn, - causal_conv1d_update_npu) -from vllm_ascend.ops.fla import LayerNormFn, torch_chunk_gated_delta_rule -from vllm_ascend.ops.sigmoid_gating import \ +from vllm_ascend.ops.triton.fla.fla import (LayerNormFn, + torch_chunk_gated_delta_rule) +from vllm_ascend.ops.triton.fla.sigmoid_gating import \ fused_recurrent_gated_delta_rule_fwd_kernel +from vllm_ascend.ops.triton.mamba.casual_conv1d import ( + causal_conv1d_fn, causal_conv1d_update_npu) vllm.model_executor.layers.mamba.ops.causal_conv1d.causal_conv1d_update = causal_conv1d_update_npu vllm.model_executor.layers.mamba.ops.causal_conv1d.causal_conv1d_fn = causal_conv1d_fn