diff --git a/docs/source/installation.md b/docs/source/installation.md index f71e2c7..0ee1d93 100644 --- a/docs/source/installation.md +++ b/docs/source/installation.md @@ -88,7 +88,7 @@ pip install "https://cce-ai-models.bj.bcebos.com/v1/vllm-kunlun-0.11.0/triton-3. ``` ## Install the AIAK custom ops library ``` -pip install "https://cce-ai-models.bj.bcebos.com/v1/chenyili/xspeedgate_ops-0.0.0-cp310-cp310-linux_x86_64.whl" +pip install "https://cce-ai-models.bj.bcebos.com/XSpeedGate-whl/release_merge/20251219_152418/xspeedgate_ops-0.0.0-cp310-cp310-linux_x86_64.whl" ``` ## Quick Start diff --git a/vllm_kunlun/ops/fla/chunk.py b/vllm_kunlun/ops/fla/chunk.py index 90dbd0d..2da888e 100644 --- a/vllm_kunlun/ops/fla/chunk.py +++ b/vllm_kunlun/ops/fla/chunk.py @@ -24,6 +24,7 @@ from .solve_tril import solve_tril from .utils import SUPPRESS_LEVEL, input_guard from .wy_fast import recompute_w_u_fwd +import xspeedgate_ops def torch_solve_tril(A: torch.Tensor, cu_seqlens: Optional[torch.LongTensor] = None, output_dtype: torch.dtype = torch.float,): chunk_size=64 @@ -56,10 +57,8 @@ def chunk_gated_delta_rule_fwd(q: torch.Tensor, cu_seqlens=cu_seqlens, output_dtype=q.dtype) - #torch版 - for i in range(len(cu_seqlens)-1): - A_i = A[:, cu_seqlens[i]:cu_seqlens[i+1], :, :] - A[:, cu_seqlens[i]:cu_seqlens[i+1], :, :] = torch_solve_tril(A=A_i, cu_seqlens=torch.tensor([0, cu_seqlens[i+1]-cu_seqlens[i]], device=q.device), output_dtype=k.dtype) + #kernel版 + torch.ops.xspeedgate_ops.solve_tril_fwd(A, cu_seqlens) w, u = recompute_w_u_fwd( k=k, v=v,