Merge pull request #13 from ldh2020/v0.11.0dev

[Bugfix] Fix a bug in torch_solve_tril
This commit is contained in:
Xinyu Dong
2025-12-12 17:37:46 +08:00
committed by GitHub

View File

@@ -27,14 +27,12 @@ from .wy_fast import recompute_w_u_fwd
def torch_solve_tril(A: torch.Tensor, cu_seqlens: Optional[torch.LongTensor] = None, output_dtype: torch.dtype = torch.float,):
chunk_size=64
A = A.transpose(1,2)
A = -A.transpose(1,2)
sequence_length = A.shape[-2]
pad_size = (chunk_size - sequence_length % chunk_size) % chunk_size
A = F.pad(A, (0, 0, 0, pad_size))
A = A.reshape(A.shape[0], A.shape[1], -1, chunk_size, A.shape[-1])
mask = torch.triu(torch.ones(chunk_size, chunk_size, dtype=torch.bool, device=A.device), diagonal=0)
A = A.masked_fill(mask, 0)
for i in range(1, chunk_size):
row = A[..., i, :i].clone()
sub = A[..., :i, :i].clone()