[Kernel] Replace native torch solve_tril by solve_tril_fwd kernel op
This commit is contained in:
@@ -88,7 +88,7 @@ pip install "https://cce-ai-models.bj.bcebos.com/v1/vllm-kunlun-0.11.0/triton-3.
|
||||
```
|
||||
## Install the AIAK custom ops library
|
||||
```
|
||||
pip install "https://cce-ai-models.bj.bcebos.com/v1/chenyili/xspeedgate_ops-0.0.0-cp310-cp310-linux_x86_64.whl"
|
||||
pip install "https://cce-ai-models.bj.bcebos.com/XSpeedGate-whl/release_merge/20251219_152418/xspeedgate_ops-0.0.0-cp310-cp310-linux_x86_64.whl"
|
||||
```
|
||||
## Quick Start
|
||||
|
||||
|
||||
@@ -24,6 +24,7 @@ from .solve_tril import solve_tril
|
||||
from .utils import SUPPRESS_LEVEL, input_guard
|
||||
from .wy_fast import recompute_w_u_fwd
|
||||
|
||||
import xspeedgate_ops
|
||||
|
||||
def torch_solve_tril(A: torch.Tensor, cu_seqlens: Optional[torch.LongTensor] = None, output_dtype: torch.dtype = torch.float,):
|
||||
chunk_size=64
|
||||
@@ -56,10 +57,8 @@ def chunk_gated_delta_rule_fwd(q: torch.Tensor,
|
||||
cu_seqlens=cu_seqlens,
|
||||
output_dtype=q.dtype)
|
||||
|
||||
#torch版
|
||||
for i in range(len(cu_seqlens)-1):
|
||||
A_i = A[:, cu_seqlens[i]:cu_seqlens[i+1], :, :]
|
||||
A[:, cu_seqlens[i]:cu_seqlens[i+1], :, :] = torch_solve_tril(A=A_i, cu_seqlens=torch.tensor([0, cu_seqlens[i+1]-cu_seqlens[i]], device=q.device), output_dtype=k.dtype)
|
||||
#kernel版
|
||||
torch.ops.xspeedgate_ops.solve_tril_fwd(A, cu_seqlens)
|
||||
w, u = recompute_w_u_fwd(
|
||||
k=k,
|
||||
v=v,
|
||||
|
||||
Reference in New Issue
Block a user