diff --git a/vllm_ascend/ops/attention.py b/vllm_ascend/ops/attention.py index f21c03e..c703947 100644 --- a/vllm_ascend/ops/attention.py +++ b/vllm_ascend/ops/attention.py @@ -131,7 +131,6 @@ def vanilla_chunked_prefill( attn_output = (attn_output[q_mask].view([-1, num_query_heads, head_dim]).to(output.dtype)) - output = output.view_as(attn_output) output.copy_(attn_output) return attn_output @@ -248,6 +247,7 @@ def vanilla_chunked_prefill_mla( attn_output = (attn_output[q_mask].view([-1, num_heads, v_head_dim]).to(output.dtype)) + output = output.view_as(attn_output) output.copy_(attn_output) return attn_output diff --git a/vllm_ascend/quantization/w8a8.py b/vllm_ascend/quantization/w8a8.py index ae9dd46..f740d8f 100644 --- a/vllm_ascend/quantization/w8a8.py +++ b/vllm_ascend/quantization/w8a8.py @@ -24,7 +24,7 @@ import torch_npu def quant_per_tensor(in_tensor: torch.Tensor, input_scale: torch.Tensor, input_offset: torch.Tensor): return torch_npu.npu_quantize(in_tensor, input_scale, input_offset, - torch.qint8, -1, True) + torch.qint8, -1, False) class AscendW8A8LinearMethod: @@ -102,12 +102,12 @@ class AscendW8A8LinearMethod: def process_weights_after_loading(self, layer): expanding_factor = layer.weight.data.shape[1] - layer.aclnn_input_scale = torch.nn.Parameter( + layer.aclnn_input_scale = 1 / torch.nn.Parameter( layer.input_scale.data.repeat(expanding_factor), requires_grad=False) layer.aclnn_input_offset = torch.nn.Parameter( layer.input_offset.data.repeat(expanding_factor), - requires_grad=False) + requires_grad=False).to(layer.aclnn_input_scale.dtype) if self.transpose_weight: layer.weight.data = layer.weight.data.transpose(0, 1).contiguous() layer.weight_scale.data = torch.flatten(layer.weight_scale.data)