Add graph runner support with torch compile on CPU (#7843)
This commit is contained in:
@@ -352,6 +352,9 @@ class Fp8LinearMethod(LinearMethodBase):
|
||||
_is_cpu_amx_available
|
||||
), "Fp8LinearMethod on CPU requires that CPU has AMX support"
|
||||
_amx_process_weight_after_loading(layer, ["weight"])
|
||||
layer.weight_scale_inv = torch.nn.Parameter(
|
||||
layer.weight_scale_inv.data, requires_grad=False
|
||||
)
|
||||
return
|
||||
else:
|
||||
weight, weight_scale = layer.weight.data, layer.weight_scale_inv.data
|
||||
|
||||
@@ -343,9 +343,8 @@ class W8A8Int8LinearMethod(LinearMethodBase):
|
||||
_is_cpu_amx_available
|
||||
), "W8A8Int8LinearMethod on CPU requires that CPU has AMX support"
|
||||
_amx_process_weight_after_loading(layer, ["weight"])
|
||||
return
|
||||
|
||||
layer.weight = Parameter(layer.weight.t(), requires_grad=False)
|
||||
else:
|
||||
layer.weight = Parameter(layer.weight.t(), requires_grad=False)
|
||||
layer.weight_scale = Parameter(layer.weight_scale.data, requires_grad=False)
|
||||
|
||||
def create_weights(
|
||||
@@ -486,10 +485,9 @@ class W8A8Int8MoEMethod(FusedMoEMethodBase):
|
||||
_is_cpu_amx_available
|
||||
), "W8A8Int8MoEMethod on CPU requires that CPU has AMX support"
|
||||
_amx_process_weight_after_loading(layer, ["w13_weight", "w2_weight"])
|
||||
return
|
||||
|
||||
layer.w13_weight = Parameter(layer.w13_weight, requires_grad=False)
|
||||
layer.w2_weight = Parameter(layer.w2_weight, requires_grad=False)
|
||||
else:
|
||||
layer.w13_weight = Parameter(layer.w13_weight, requires_grad=False)
|
||||
layer.w2_weight = Parameter(layer.w2_weight, requires_grad=False)
|
||||
layer.w13_weight_scale = Parameter(
|
||||
layer.w13_weight_scale.data, requires_grad=False
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user