support deepseek quant & mix-parallel with graphmode (#585)

### What this PR does / why we need it?
1. support deepseek with w8a8 quant;
2. support deepseek with mix-parallel(multi-DP, EP+TP);
3. support deepseek with graphmode.
---------

Signed-off-by: wen-jie666 <wenjie39@huawei.com>
Signed-off-by: Yizhou Liu <liuyizhou5@h-partners.com>
Signed-off-by: libaokui <libaokui@huawei.com>
Signed-off-by: linfeng-yuan <1102311262@qq.com>
Co-authored-by: wen-jie666 <wenjie39@huawei.com>
This commit is contained in:
zzzzwwjj
2025-04-23 16:23:25 +08:00
committed by GitHub
parent e74331a1ed
commit 5c6d05a59e
13 changed files with 520 additions and 221 deletions

View File

@@ -23,10 +23,8 @@ import torch_npu
def quant_per_tensor(in_tensor: torch.Tensor, input_scale: torch.Tensor,
input_offset: torch.Tensor):
out = torch.empty_like(in_tensor, dtype=torch.int8)
torch_npu._npu_quantize_per_tensor(in_tensor, input_scale, input_offset,
out)
return out
return torch_npu.npu_quantize(in_tensor, input_scale, input_offset,
torch.qint8, -1, True)
class AscendW8A8LinearMethod:
@@ -88,7 +86,11 @@ class AscendW8A8LinearMethod:
) -> torch.Tensor:
original_dtype = x.dtype
if original_dtype != torch.int8:
x = quant_per_tensor(x, layer.input_scale, layer.input_offset)
x = quant_per_tensor(
x,
layer.aclnn_input_scale,
layer.aclnn_input_offset,
)
quant_bias = layer.quant_bias if tp_rank == 0 else None
return torch_npu.npu_quant_matmul(
x,
@@ -99,6 +101,13 @@ class AscendW8A8LinearMethod:
)
def process_weights_after_loading(self, layer):
expanding_factor = layer.weight.data.shape[1]
layer.aclnn_input_scale = torch.nn.Parameter(
layer.input_scale.data.repeat(expanding_factor),
requires_grad=False)
layer.aclnn_input_offset = torch.nn.Parameter(
layer.input_offset.data.repeat(expanding_factor),
requires_grad=False)
if self.transpose_weight:
layer.weight.data = layer.weight.data.transpose(0, 1).contiguous()
layer.weight_scale.data = torch.flatten(layer.weight_scale.data)