Add graph runner support with torch compile on CPU (#7843)

This commit is contained in:
Cao E
2025-09-08 12:33:58 +08:00
committed by GitHub
parent 8cda5a622c
commit 7577f0e40f
16 changed files with 820 additions and 48 deletions

View File

@@ -352,6 +352,9 @@ class Fp8LinearMethod(LinearMethodBase):
_is_cpu_amx_available
), "Fp8LinearMethod on CPU requires that CPU has AMX support"
_amx_process_weight_after_loading(layer, ["weight"])
+layer.weight_scale_inv = torch.nn.Parameter(
+layer.weight_scale_inv.data, requires_grad=False
+)
return
else:
weight, weight_scale = layer.weight.data, layer.weight_scale_inv.data

View File

@@ -343,9 +343,8 @@ class W8A8Int8LinearMethod(LinearMethodBase):
_is_cpu_amx_available
), "W8A8Int8LinearMethod on CPU requires that CPU has AMX support"
_amx_process_weight_after_loading(layer, ["weight"])
-return
-layer.weight = Parameter(layer.weight.t(), requires_grad=False)
+else:
+layer.weight = Parameter(layer.weight.t(), requires_grad=False)
layer.weight_scale = Parameter(layer.weight_scale.data, requires_grad=False)
def create_weights(
@@ -486,10 +485,9 @@ class W8A8Int8MoEMethod(FusedMoEMethodBase):
_is_cpu_amx_available
), "W8A8Int8MoEMethod on CPU requires that CPU has AMX support"
_amx_process_weight_after_loading(layer, ["w13_weight", "w2_weight"])
-return
-layer.w13_weight = Parameter(layer.w13_weight, requires_grad=False)
-layer.w2_weight = Parameter(layer.w2_weight, requires_grad=False)
+else:
+layer.w13_weight = Parameter(layer.w13_weight, requires_grad=False)
+layer.w2_weight = Parameter(layer.w2_weight, requires_grad=False)
layer.w13_weight_scale = Parameter(
layer.w13_weight_scale.data, requires_grad=False
)