[fix]matmul not support cuda graph
This commit is contained in:
@@ -1616,17 +1616,21 @@ def scaled_int8_quant_cuda(
|
||||
return x_q, scale, azp, static
|
||||
|
||||
|
||||
def fake_scaled_int8_quant(
|
||||
def _fake_scaled_int8_quant(
|
||||
x: torch.Tensor,
|
||||
scale: torch.Tensor,
|
||||
azp: Optional[torch.Tensor] = None,
|
||||
symmetric: bool = True,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, bool]:
|
||||
x_q = torch.ones(x.shape, dtype=torch.int8, device=x.device)
|
||||
x_q = torch.empty_like(x, dtype=torch.int8, device=x.device)
|
||||
scale = torch.empty(
|
||||
(x.numel() // x.shape[-1], 1), device=x.device, dtype=torch.float32
|
||||
)
|
||||
azp = None if symmetric else torch.empty_like(scale, dtype=torch.int32)
|
||||
return x_q, scale, azp, False
|
||||
|
||||
|
||||
scaled_int8_quant.register_fake(fake_scaled_int8_quant)
|
||||
scaled_int8_quant.register_fake(_fake_scaled_int8_quant)
|
||||
|
||||
|
||||
######################################################
|
||||
@@ -1821,7 +1825,7 @@ def _fake_matmul(
|
||||
w_pc_max: torch.Tensor = None,
|
||||
) -> torch.Tensor:
|
||||
return torch.empty(
|
||||
(x.shape[0], w.shape[0]),
|
||||
(x.shape[0], w.shape[0] if w_trans else w.shape[1]),
|
||||
dtype=out_dtype,
|
||||
device=x.device,
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user