upgrade torch npu version (#4433)
The vLLM graph feature now relies on torch >= 2.8, so we need to upgrade the torch version for graph mode to work. Upgrading to a newer torch release is also the right move for long-term support. Related vLLM change: https://github.com/vllm-project/vllm/pull/25110
- vLLM version: v0.11.2
- vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.2
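For context, graph mode here means the torch.compile-based path in vLLM; below is a minimal sketch of the kind of startup guard this requirement implies. The helper name and error message are illustrative, not code from this commit:

# Illustrative sketch: fail fast when the installed torch is too old for
# vLLM's graph (torch.compile) path. The ">= 2.8" floor comes from the
# commit message above; the helper name is hypothetical.
from packaging import version

import torch


def assert_torch_supports_graph_mode(minimum: str = "2.8") -> None:
    installed = version.parse(torch.__version__.split("+")[0])
    if installed < version.parse(minimum):
        raise RuntimeError(
            f"vLLM graph mode requires torch >= {minimum}, "
            f"found {torch.__version__}.")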
@@ -924,8 +924,10 @@ class AscendMLAImpl(MLAAttentionImpl):
         def get_layer_weight(layer):
             WEIGHT_NAMES = ("weight", "qweight", "weight_packed")
             for attr in WEIGHT_NAMES:
-                if hasattr(layer, attr):
+                try:
                     return getattr(layer, attr)
+                except AttributeError:
+                    pass
             raise AttributeError(
                 f"Layer '{layer}' has no recognized weight attribute:"
                 f" {WEIGHT_NAMES}.")
@@ -273,8 +273,10 @@ class AscendSFAImpl(MLAAttentionImpl):
         def get_layer_weight(layer):
             WEIGHT_NAMES = ("weight", "qweight", "weight_packed")
             for attr in WEIGHT_NAMES:
-                if hasattr(layer, attr):
+                try:
                     return getattr(layer, attr)
+                except AttributeError:
+                    pass
             raise AttributeError(
                 f"Layer '{layer}' has no recognized weight attribute:"
                 f" {WEIGHT_NAMES}.")
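These two hunks make the same change in AscendMLAImpl and AscendSFAImpl: the hasattr probe is replaced with an EAFP-style try/except AttributeError, so each candidate name is resolved through nn.Module.__getattr__ once instead of twice (hasattr performs a full getattr internally). A standalone sketch of the resulting helper; the type hints and module-level framing are added here for illustration:

import torch

WEIGHT_NAMES = ("weight", "qweight", "weight_packed")


def get_layer_weight(layer: torch.nn.Module) -> torch.Tensor:
    # Try each candidate weight name directly; a miss raises
    # AttributeError, which we treat as "try the next name".
    for attr in WEIGHT_NAMES:
        try:
            return getattr(layer, attr)
        except AttributeError:
            pass
    raise AttributeError(
        f"Layer '{layer}' has no recognized weight attribute: {WEIGHT_NAMES}.")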
@@ -18,7 +18,6 @@ import os
 
 import vllm_ascend.patch.platform.patch_config  # noqa
 import vllm_ascend.patch.platform.patch_distributed  # noqa
-import vllm_ascend.patch.platform.patch_dynamo_vllm_backend  # noqa
 import vllm_ascend.patch.platform.patch_mamba_config  # noqa
 import vllm_ascend.patch.platform.patch_sched_yield  # noqa
 
@@ -1,16 +0,0 @@
-# mypy: ignore-errors
-from typing import Any, Dict
-
-import torch.fx as fx
-from vllm.compilation.backends import VllmBackend
-from vllm.compilation.caching import VllmSerializableFunction
-
-_original_vllmbackend_call = VllmBackend.__call__
-
-
-def __patch_call__(self, graph: fx.GraphModule, example_inputs,
-                   options: Dict[str, Any]) -> VllmSerializableFunction:
-    return _original_vllmbackend_call(self, graph, example_inputs)
-
-
-VllmBackend.__call__ = __patch_call__
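The deleted module above monkey-patched VllmBackend.__call__ to swallow the options argument that torch Dynamo passes to compile backends; with the upgraded torch and vLLM v0.11.2 the signatures evidently line up, so the shim and its import (removed in the previous hunk) can go. A hypothetical sketch of how such a shim is commonly gated so that an upgrade retires it automatically; the version cutoff here is an assumption, not code from this repo:

# Hypothetical: install the compatibility shim only when the installed
# torch still needs it, so a torch upgrade retires the patch by itself.
from packaging import version

import torch

if version.parse(torch.__version__.split("+")[0]) < version.parse("2.8"):
    from vllm.compilation.backends import VllmBackend

    _original_call = VllmBackend.__call__

    def _patched_call(self, graph, example_inputs, options=None):
        # Drop `options`, which the older VllmBackend.__call__ did not accept.
        return _original_call(self, graph, example_inputs)

    VllmBackend.__call__ = _patched_call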
@@ -119,8 +119,10 @@ class AscendW8A8LinearMethod:
                 weight=layer.weight,
                 start_flag=x,
             )
 
-        quant_comm_config = getattr(layer, "_quant_comm_config", {})
+        try:
+            quant_comm_config = getattr(layer, "_quant_comm_config")
+        except AttributeError:
+            quant_comm_config = {}
         comm_fn = quant_comm_config.get("communication_fn")
         enable_flashcomm2_quant_comm = comm_fn is not None and (
             "o_proj" in layer.prefix or "out_proj" in layer.prefix)
@@ -151,8 +153,12 @@ class AscendW8A8LinearMethod:
         )
 
         quant_bias = layer.quant_bias if tp_rank == 0 else None
-        if getattr(layer, "ascend_quant_method",
-                   "") == COMPRESSED_TENSORS_METHOD:
+
+        try:
+            ascend_quant_method = getattr(layer, "ascend_quant_method")
+        except AttributeError:
+            ascend_quant_method = ""
+        if ascend_quant_method == COMPRESSED_TENSORS_METHOD:
             quant_bias = bias
 
         if get_ascend_device_type() == AscendDeviceType._310P:
@@ -194,8 +200,13 @@ class AscendW8A8LinearMethod:
         layer.weight_scale.data = torch.flatten(layer.weight_scale.data)
         layer.weight_offset.data = torch.flatten(layer.weight_offset.data)
         layer.bias.data = layer.bias.data.to(layer.weight_scale.data.dtype)
-        if getattr(layer, "ascend_quant_method",
-                   "") == COMPRESSED_TENSORS_METHOD:
+
+        try:
+            ascend_quant_method = getattr(layer, "ascend_quant_method")
+        except AttributeError:
+            ascend_quant_method = ""
+
+        if ascend_quant_method == COMPRESSED_TENSORS_METHOD:
             deq_scale = layer.input_scale.data * layer.weight_scale.data
             layer.deq_scale = torch.nn.Parameter(deq_scale,
                                                  requires_grad=False)
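The three hunks above all expand getattr(layer, name, default) into an explicit try/except AttributeError with the same fallback; the two spellings are behaviorally equivalent, since getattr's default form suppresses exactly that exception. Because the pattern now appears at several call sites in AscendW8A8LinearMethod, a small hypothetical helper (not part of this commit) would keep each site to one line:

from typing import Any

import torch


def attr_or(layer: torch.nn.Module, name: str, default: Any) -> Any:
    # Explicit spelling of getattr(layer, name, default).
    try:
        return getattr(layer, name)
    except AttributeError:
        return default


# The call sites above would then read:
#   quant_comm_config = attr_or(layer, "_quant_comm_config", {})
#   ascend_quant_method = attr_or(layer, "ascend_quant_method", "")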
@@ -15,7 +15,7 @@
 # limitations under the License.
 #
 
-from typing import Any, Callable, Dict, Optional, Tuple, Union
+from typing import Any, Callable, Dict, Optional
 
 import torch
 import torch_npu
@@ -73,33 +73,20 @@ class AscendW8A8DynamicLinearMethod:
     @staticmethod
     def apply(
         layer: torch.nn.Module,
-        x: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
+        x: torch.Tensor,
         bias: Optional[torch.Tensor] = None,
         tp_rank: Optional[int] = 0,
     ) -> torch.Tensor:
-        config = getattr(layer, "_ascend_quant_config", {})
-        if not isinstance(x, tuple):
-            output_dtype = config.get("output_dtype", x.dtype)
-            quantized_x, dynamic_scale = torch_npu.npu_dynamic_quant(x)
-        else:
-            assert "output_dtype" in config.keys(), (
-                f"DynamicLinearMethod needs explicitly specified `output_dtype`"
-                f"for pre-quantized input, got config [{config}]")
-            output_dtype = config["output_dtype"]
-            quantized_x, dynamic_scale = x
-        pertoken_scale = (dynamic_scale
-                          if config.get("pertoken_scale", True) else None)
-
+        quantized_x, pertoken_scale = torch_npu.npu_dynamic_quant(x)
         output = torch_npu.npu_quant_matmul(
             quantized_x,
             layer.weight,
             layer.weight_scale,
             pertoken_scale=pertoken_scale,
             bias=bias,
-            output_dtype=output_dtype,
+            output_dtype=x.dtype,
         )
-        return ((output, dynamic_scale)
-                if config.get("return_scale", False) else output)
+        return output
 
     def process_weights_after_loading(self, layer):
         if self.transpose_weight:
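After this hunk, apply always receives a plain floating-point activation, quantizes it per token on the fly, and returns the output in x.dtype; the pre-quantized tuple input, _ascend_quant_config lookup, and return_scale path are gone, which is also why Tuple and Union left the typing import two hunks up. As a rough reference for what torch_npu.npu_dynamic_quant is understood to compute, here is a plain-torch sketch of symmetric per-token int8 quantization; this illustrates the semantics only, not the NPU kernel:

import torch


def dynamic_quant_reference(x: torch.Tensor):
    # One symmetric int8 scale per token (row): scale = row_amax / 127.
    scale = x.abs().amax(dim=-1, keepdim=True).float() / 127.0
    scale = scale.clamp_min(1e-9)  # guard all-zero rows
    quantized = torch.round(x.float() / scale).clamp(-127, 127).to(torch.int8)
    # Mirrors the (quantized_x, pertoken_scale) pair fed to
    # npu_quant_matmul in the hunk above.
    return quantized, scale.squeeze(-1)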
@@ -948,7 +948,7 @@ def get_flashcomm2_oproj_tp_size_and_validate_config(ascend_config,
     global_tp_size = vllm_config.parallel_config.tensor_parallel_size
 
     if not flashcomm2_enable():
-        logger.info("FLASHCOMM2 not enable.")
+        logger.debug("FLASHCOMM2 not enable.")
         return flashcomm2_oproj_tp_size
 
     logger.info(