From 9bb2ee06a4c5198d370eb89232d12223492f04d4 Mon Sep 17 00:00:00 2001 From: ldh2020 <62470572+ldh2020@users.noreply.github.com> Date: Fri, 12 Dec 2025 17:01:50 +0800 Subject: [PATCH] [Bugfix] fix the bug of torch_solve_tril --- vllm_kunlun/ops/fla/chunk.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/vllm_kunlun/ops/fla/chunk.py b/vllm_kunlun/ops/fla/chunk.py index 01e074d..90dbd0d 100644 --- a/vllm_kunlun/ops/fla/chunk.py +++ b/vllm_kunlun/ops/fla/chunk.py @@ -27,14 +27,12 @@ from .wy_fast import recompute_w_u_fwd def torch_solve_tril(A: torch.Tensor, cu_seqlens: Optional[torch.LongTensor] = None, output_dtype: torch.dtype = torch.float,): chunk_size=64 - A = A.transpose(1,2) + A = -A.transpose(1,2) sequence_length = A.shape[-2] pad_size = (chunk_size - sequence_length % chunk_size) % chunk_size A = F.pad(A, (0, 0, 0, pad_size)) A = A.reshape(A.shape[0], A.shape[1], -1, chunk_size, A.shape[-1]) - mask = torch.triu(torch.ones(chunk_size, chunk_size, dtype=torch.bool, device=A.device), diagonal=0) - A = A.masked_fill(mask, 0) for i in range(1, chunk_size): row = A[..., i, :i].clone() sub = A[..., :i, :i].clone()