[fix] matmul does not support CUDA graph

This commit is contained in:
Li Wei
2026-01-06 16:07:29 +08:00
parent 515a4eeda9
commit 9533f68e99
3 changed files with 12 additions and 7 deletions

View File

@@ -1616,17 +1616,21 @@ def scaled_int8_quant_cuda(
return x_q, scale, azp, static
def fake_scaled_int8_quant(
def _fake_scaled_int8_quant(
x: torch.Tensor,
scale: torch.Tensor,
azp: Optional[torch.Tensor] = None,
symmetric: bool = True,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, bool]:
x_q = torch.ones(x.shape, dtype=torch.int8, device=x.device)
x_q = torch.empty_like(x, dtype=torch.int8, device=x.device)
scale = torch.empty(
(x.numel() // x.shape[-1], 1), device=x.device, dtype=torch.float32
)
azp = None if symmetric else torch.empty_like(scale, dtype=torch.int32)
return x_q, scale, azp, False
scaled_int8_quant.register_fake(fake_scaled_int8_quant)
scaled_int8_quant.register_fake(_fake_scaled_int8_quant)
######################################################
@@ -1821,7 +1825,7 @@ def _fake_matmul(
w_pc_max: torch.Tensor = None,
) -> torch.Tensor:
return torch.empty(
(x.shape[0], w.shape[0]),
(x.shape[0], w.shape[0] if w_trans else w.shape[1]),
dtype=out_dtype,
device=x.device,
)