Merge pull request #39 from LiangYC1021/v0.11.0dev
[Kernel] Replace the native torch solve_tril with the solve_tril_fwd kernel op
This commit is contained in:
@@ -88,7 +88,7 @@ pip install "https://cce-ai-models.bj.bcebos.com/v1/vllm-kunlun-0.11.0/triton-3.
|
|||||||
```
|
```
|
||||||
## Install the AIAK custom ops library
|
## Install the AIAK custom ops library
|
||||||
```
|
```
|
||||||
pip install "https://cce-ai-models.bj.bcebos.com/v1/chenyili/xspeedgate_ops-0.0.0-cp310-cp310-linux_x86_64.whl"
|
pip install "https://cce-ai-models.bj.bcebos.com/XSpeedGate-whl/release_merge/20251219_152418/xspeedgate_ops-0.0.0-cp310-cp310-linux_x86_64.whl"
|
||||||
```
|
```
|
||||||
## Quick Start
|
## Quick Start
|
||||||
|
|
||||||
|
|||||||
@@ -24,6 +24,7 @@ from .solve_tril import solve_tril
|
|||||||
from .utils import SUPPRESS_LEVEL, input_guard
|
from .utils import SUPPRESS_LEVEL, input_guard
|
||||||
from .wy_fast import recompute_w_u_fwd
|
from .wy_fast import recompute_w_u_fwd
|
||||||
|
|
||||||
|
import xspeedgate_ops
|
||||||
|
|
||||||
def torch_solve_tril(A: torch.Tensor, cu_seqlens: Optional[torch.LongTensor] = None, output_dtype: torch.dtype = torch.float,):
|
def torch_solve_tril(A: torch.Tensor, cu_seqlens: Optional[torch.LongTensor] = None, output_dtype: torch.dtype = torch.float,):
|
||||||
chunk_size=64
|
chunk_size=64
|
||||||
@@ -56,10 +57,8 @@ def chunk_gated_delta_rule_fwd(q: torch.Tensor,
|
|||||||
cu_seqlens=cu_seqlens,
|
cu_seqlens=cu_seqlens,
|
||||||
output_dtype=q.dtype)
|
output_dtype=q.dtype)
|
||||||
|
|
||||||
# torch version
|
# kernel version
|
||||||
for i in range(len(cu_seqlens)-1):
|
torch.ops.xspeedgate_ops.solve_tril_fwd(A, cu_seqlens)
|
||||||
A_i = A[:, cu_seqlens[i]:cu_seqlens[i+1], :, :]
|
|
||||||
A[:, cu_seqlens[i]:cu_seqlens[i+1], :, :] = torch_solve_tril(A=A_i, cu_seqlens=torch.tensor([0, cu_seqlens[i+1]-cu_seqlens[i]], device=q.device), output_dtype=k.dtype)
|
|
||||||
w, u = recompute_w_u_fwd(
|
w, u = recompute_w_u_fwd(
|
||||||
k=k,
|
k=k,
|
||||||
v=v,
|
v=v,
|
||||||
|
|||||||
Reference in New Issue
Block a user