Support DeepSeek W8A8 quantization & mixed parallelism with graph mode (#585)
### What this PR does / why we need it: 1. Supports DeepSeek with W8A8 quantization; 2. supports DeepSeek with mixed parallelism (multiple DP groups, EP + TP); 3. supports DeepSeek with graph mode. --------- Signed-off-by: wen-jie666 <wenjie39@huawei.com> Signed-off-by: Yizhou Liu <liuyizhou5@h-partners.com> Signed-off-by: libaokui <libaokui@huawei.com> Signed-off-by: linfeng-yuan <1102311262@qq.com> Co-authored-by: wen-jie666 <wenjie39@huawei.com>
This commit is contained in:
@@ -23,10 +23,8 @@ import torch_npu
|
||||
|
||||
def quant_per_tensor(in_tensor: torch.Tensor, input_scale: torch.Tensor,
                     input_offset: torch.Tensor):
    """Quantize a floating-point tensor to int8 with a per-tensor scale/offset.

    Thin wrapper around ``torch_npu.npu_quantize`` producing a ``qint8``
    result (axis ``-1``, ``div_mode=True``).

    NOTE(review): whether the scale divides or multiplies the input follows
    ``npu_quantize``'s ``div_mode=True`` path — confirm against torch_npu docs.
    """
    quantized = torch_npu.npu_quantize(in_tensor, input_scale, input_offset,
                                       torch.qint8, -1, True)
    return quantized
|
||||
|
||||
|
||||
class AscendW8A8LinearMethod:
|
||||
@@ -88,7 +86,11 @@ class AscendW8A8LinearMethod:
|
||||
) -> torch.Tensor:
|
||||
original_dtype = x.dtype
|
||||
if original_dtype != torch.int8:
|
||||
x = quant_per_tensor(x, layer.input_scale, layer.input_offset)
|
||||
x = quant_per_tensor(
|
||||
x,
|
||||
layer.aclnn_input_scale,
|
||||
layer.aclnn_input_offset,
|
||||
)
|
||||
quant_bias = layer.quant_bias if tp_rank == 0 else None
|
||||
return torch_npu.npu_quant_matmul(
|
||||
x,
|
||||
@@ -99,6 +101,13 @@ class AscendW8A8LinearMethod:
|
||||
)
|
||||
|
||||
def process_weights_after_loading(self, layer):
    """Post-process W8A8 linear weights once the checkpoint is loaded.

    Expands the per-tensor input scale/offset into per-column vectors of
    length ``weight.shape[1]`` (presumably the layout the aclnn quant
    kernels expect — confirm against the NPU kernel docs), optionally
    transposes the weight matrix in place, and flattens the weight scale.
    """
    # One scale/offset entry per column of the (possibly pre-transpose) weight.
    num_cols = layer.weight.data.shape[1]
    layer.aclnn_input_scale = torch.nn.Parameter(
        layer.input_scale.data.repeat(num_cols), requires_grad=False)
    layer.aclnn_input_offset = torch.nn.Parameter(
        layer.input_offset.data.repeat(num_cols), requires_grad=False)
    if self.transpose_weight:
        # In-place layout swap; .contiguous() materializes the transposed storage.
        layer.weight.data = layer.weight.data.transpose(0, 1).contiguous()
    layer.weight_scale.data = torch.flatten(layer.weight_scale.data)
|
||||
|
||||
Reference in New Issue
Block a user