diff --git a/vllm_kunlun/ops/fla/chunk.py b/vllm_kunlun/ops/fla/chunk.py index 01e074d..90dbd0d 100644 --- a/vllm_kunlun/ops/fla/chunk.py +++ b/vllm_kunlun/ops/fla/chunk.py @@ -27,14 +27,12 @@ from .wy_fast import recompute_w_u_fwd def torch_solve_tril(A: torch.Tensor, cu_seqlens: Optional[torch.LongTensor] = None, output_dtype: torch.dtype = torch.float,): chunk_size=64 - A = A.transpose(1,2) + A = -A.transpose(1,2) sequence_length = A.shape[-2] pad_size = (chunk_size - sequence_length % chunk_size) % chunk_size A = F.pad(A, (0, 0, 0, pad_size)) A = A.reshape(A.shape[0], A.shape[1], -1, chunk_size, A.shape[-1]) - mask = torch.triu(torch.ones(chunk_size, chunk_size, dtype=torch.bool, device=A.device), diagonal=0) - A = A.masked_fill(mask, 0) for i in range(1, chunk_size): row = A[..., i, :i].clone() sub = A[..., :i, :i].clone()