[Bugfix] Fix w8a8_int8 import error on NPU (#8147)
This commit is contained in:
@@ -754,6 +754,8 @@ class NPU_W8A8LinearMethod(LinearMethodBase):
|
|||||||
x: torch.Tensor,
|
x: torch.Tensor,
|
||||||
bias: Optional[torch.Tensor] = None,
|
bias: Optional[torch.Tensor] = None,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
|
from sglang.srt.layers.linear import RowParallelLinear
|
||||||
|
|
||||||
if isinstance(layer, RowParallelLinear):
|
if isinstance(layer, RowParallelLinear):
|
||||||
tp_rank = get_tensor_model_parallel_rank()
|
tp_rank = get_tensor_model_parallel_rank()
|
||||||
return self.quant_method.apply(layer, x, bias, tp_rank)
|
return self.quant_method.apply(layer, x, bias, tp_rank)
|
||||||
|
|||||||
Reference in New Issue
Block a user