This PR upgrade CANN from 8.2rc1 to 8.3rc1 and remove the CANN version
check logic.
TODO: we notice that UT runs failed with CANN 8.3 image. So the base
image for UT is still 8.2. We'll fix it later.
- vLLM version: v0.11.0
- vLLM main:
83f478bb19
Signed-off-by: wangxiyuan <wangxiyuan1007@gmail.com>
This commit is contained in:
@@ -45,8 +45,8 @@ class AscendUnquantizedLinearMethod(UnquantizedLinearMethod):
|
||||
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||
super().process_weights_after_loading(layer)
|
||||
if (is_enable_nz() and torch.version.cann.startswith("8.3") and
|
||||
layer.weight.data.dtype in [torch.float16, torch.bfloat16]):
|
||||
if (is_enable_nz() and layer.weight.data.dtype
|
||||
in [torch.float16, torch.bfloat16]):
|
||||
layer.weight.data = torch_npu.npu_format_cast(
|
||||
layer.weight.data, ACL_FORMAT_FRACTAL_NZ)
|
||||
|
||||
|
||||
@@ -411,9 +411,8 @@ class SequenceRowParallelOp(CustomRowParallelOp):
|
||||
quant_per_tensor)
|
||||
|
||||
# For unquant
|
||||
if mmrs_fusion and isinstance(
|
||||
self.layer.quant_method, UnquantizedLinearMethod
|
||||
) and torch.version.cann.startswith("8.3"):
|
||||
if mmrs_fusion and isinstance(self.layer.quant_method,
|
||||
UnquantizedLinearMethod):
|
||||
output = torch_npu.npu_mm_reduce_scatter_base(
|
||||
x,
|
||||
self.layer.weight.t(),
|
||||
@@ -429,8 +428,7 @@ class SequenceRowParallelOp(CustomRowParallelOp):
|
||||
elif mmrs_fusion and (
|
||||
isinstance(self.layer.quant_method, AscendLinearMethod)
|
||||
and isinstance(self.layer.quant_method.quant_method,
|
||||
AscendW8A8LinearMethod)
|
||||
) and torch.version.cann.startswith("8.3"):
|
||||
AscendW8A8LinearMethod)):
|
||||
if x.dtype != torch.int8:
|
||||
x_quant = quant_per_tensor(
|
||||
x, self.layer.aclnn_input_scale_reciprocal,
|
||||
|
||||
Reference in New Issue
Block a user