### What this PR does / why we need it?
**Scope of Changes**:
| File Path |
| :--- |
|` vllm_ascend/quantization/compressed_tensors/compressed_tensors.py`|
|` vllm_ascend/quantization/quant_config.py`|
|` vllm_ascend/quantization/utils.py`|
|` vllm_ascend/quantization/w4a16.py`|
|` vllm_ascend/quantization/w4a4_flatquant_dynamic.py`|
|` vllm_ascend/quantization/w4a8_dynamic.py`|
|` vllm_ascend/quantization/w8a16.py`|
|` vllm_ascend/quantization/w8a8.py`|
|` vllm_ascend/quantization/w8a8_dynamic.py`|
|` vllm_ascend/quantization/w8a8_pdmix.py`|
|` vllm_ascend/quantization/w8a8mxfp8.py`|
|` vllm_ascend/sample/rejection_sampler.py`|
|` vllm_ascend/sample/sampler.py`|
|` vllm_ascend/worker/block_table.py`|
### Does this PR introduce _any_ user-facing change?
### How was this patch tested?
- vLLM version: v0.13.0
- vLLM main:
2c24bc6996
Signed-off-by: MrZ20 <2609716663@qq.com>
This commit is contained in:
@@ -15,14 +15,16 @@
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
from typing import Any, Dict, Optional
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
import torch_npu
|
||||
|
||||
from vllm_ascend.utils import (COMPRESSED_TENSORS_METHOD, AscendDeviceType,
|
||||
get_ascend_device_type,
|
||||
get_weight_prefetch_method, maybe_trans_nz)
|
||||
from vllm_ascend.utils import (
|
||||
COMPRESSED_TENSORS_METHOD,
|
||||
get_weight_prefetch_method,
|
||||
maybe_trans_nz,
|
||||
)
|
||||
|
||||
from .base import AscendLinearScheme
|
||||
from .registry import register_scheme
|
||||
@@ -44,13 +46,11 @@ class AscendW8A8LinearMethod(AscendLinearScheme):
|
||||
input_size: int,
|
||||
output_size: int,
|
||||
params_dtype: torch.dtype = torch.bfloat16,
|
||||
) -> Dict[str, Any]:
|
||||
params_dict = {
|
||||
"weight": torch.empty(output_size, input_size, dtype=torch.int8)
|
||||
}
|
||||
) -> dict[str, Any]:
|
||||
params_dict = {"weight": torch.empty(output_size, input_size, dtype=torch.int8)}
|
||||
return params_dict
|
||||
|
||||
def get_pertensor_param(self, params_dtype: torch.dtype) -> Dict[str, Any]:
|
||||
def get_pertensor_param(self, params_dtype: torch.dtype) -> dict[str, Any]:
|
||||
params_dict = {}
|
||||
params_dict["input_scale"] = torch.empty(1, dtype=params_dtype)
|
||||
params_dict["input_offset"] = torch.empty(1, dtype=torch.int8)
|
||||
@@ -60,29 +60,23 @@ class AscendW8A8LinearMethod(AscendLinearScheme):
|
||||
self,
|
||||
output_size: int,
|
||||
params_dtype: torch.dtype,
|
||||
) -> Dict[str, Any]:
|
||||
) -> dict[str, Any]:
|
||||
params_dict = {}
|
||||
params_dict["quant_bias"] = torch.empty(output_size, dtype=torch.int32)
|
||||
if params_dtype == torch.bfloat16:
|
||||
params_dict["deq_scale"] = torch.empty(output_size,
|
||||
dtype=torch.float32)
|
||||
params_dict["deq_scale"] = torch.empty(output_size, dtype=torch.float32)
|
||||
elif params_dtype == torch.float16:
|
||||
params_dict["deq_scale"] = torch.empty(output_size,
|
||||
dtype=torch.int64)
|
||||
params_dict["weight_scale"] = torch.empty(output_size,
|
||||
1,
|
||||
dtype=params_dtype)
|
||||
params_dict["weight_offset"] = torch.empty(output_size,
|
||||
1,
|
||||
dtype=params_dtype)
|
||||
params_dict["deq_scale"] = torch.empty(output_size, dtype=torch.int64)
|
||||
params_dict["weight_scale"] = torch.empty(output_size, 1, dtype=params_dtype)
|
||||
params_dict["weight_offset"] = torch.empty(output_size, 1, dtype=params_dtype)
|
||||
return params_dict
|
||||
|
||||
def apply(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
bias: Optional[torch.Tensor] = None,
|
||||
tp_rank: Optional[int] = 0,
|
||||
bias: torch.Tensor | None = None,
|
||||
tp_rank: int | None = 0,
|
||||
) -> torch.Tensor:
|
||||
if x.dtype != torch.int8:
|
||||
layer_cls_name = layer.__class__.__name__
|
||||
@@ -95,15 +89,15 @@ class AscendW8A8LinearMethod(AscendLinearScheme):
|
||||
start_flag=x,
|
||||
)
|
||||
try:
|
||||
quant_comm_config = getattr(layer, "_quant_comm_config")
|
||||
quant_comm_config = layer._quant_comm_config
|
||||
except AttributeError:
|
||||
quant_comm_config = {}
|
||||
comm_fn = quant_comm_config.get("communication_fn")
|
||||
enable_flashcomm2_quant_comm = comm_fn is not None and (
|
||||
"o_proj" in layer.prefix or "out_proj" in layer.prefix)
|
||||
"o_proj" in layer.prefix or "out_proj" in layer.prefix
|
||||
)
|
||||
if enable_flashcomm2_quant_comm:
|
||||
quant_input_x = x.contiguous().view(
|
||||
-1, layer.aclnn_input_scale_reciprocal.size(0))
|
||||
quant_input_x = x.contiguous().view(-1, layer.aclnn_input_scale_reciprocal.size(0))
|
||||
quant_x = torch.ops.vllm.quantize(
|
||||
quant_input_x,
|
||||
layer.aclnn_input_scale,
|
||||
@@ -132,7 +126,7 @@ class AscendW8A8LinearMethod(AscendLinearScheme):
|
||||
quant_bias = layer.quant_bias if tp_rank == 0 else None
|
||||
|
||||
try:
|
||||
ascend_quant_method = getattr(layer, "ascend_quant_method")
|
||||
ascend_quant_method = layer.ascend_quant_method
|
||||
except AttributeError:
|
||||
ascend_quant_method = ""
|
||||
if ascend_quant_method == COMPRESSED_TENSORS_METHOD:
|
||||
@@ -150,14 +144,14 @@ class AscendW8A8LinearMethod(AscendLinearScheme):
|
||||
def process_weights_after_loading(self, layer):
|
||||
expanding_factor = layer.weight.data.shape[1]
|
||||
layer.aclnn_input_scale = torch.nn.Parameter(
|
||||
layer.input_scale.data.repeat(expanding_factor),
|
||||
requires_grad=False)
|
||||
layer.input_scale.data.repeat(expanding_factor), requires_grad=False
|
||||
)
|
||||
layer.aclnn_input_scale_reciprocal = 1 / torch.nn.Parameter(
|
||||
layer.input_scale.data.repeat(expanding_factor),
|
||||
requires_grad=False)
|
||||
layer.input_scale.data.repeat(expanding_factor), requires_grad=False
|
||||
)
|
||||
layer.aclnn_input_offset = torch.nn.Parameter(
|
||||
layer.input_offset.data.repeat(expanding_factor),
|
||||
requires_grad=False).to(layer.aclnn_input_scale.dtype)
|
||||
layer.input_offset.data.repeat(expanding_factor), requires_grad=False
|
||||
).to(layer.aclnn_input_scale.dtype)
|
||||
|
||||
layer.weight.data = layer.weight.data.transpose(0, 1).contiguous()
|
||||
layer.weight.data = maybe_trans_nz(layer.weight.data)
|
||||
@@ -166,5 +160,4 @@ class AscendW8A8LinearMethod(AscendLinearScheme):
|
||||
ascend_quant_method = getattr(layer, "ascend_quant_method", "")
|
||||
if ascend_quant_method == COMPRESSED_TENSORS_METHOD:
|
||||
deq_scale = layer.input_scale.data * layer.weight_scale.data
|
||||
layer.deq_scale = torch.nn.Parameter(deq_scale,
|
||||
requires_grad=False)
|
||||
layer.deq_scale = torch.nn.Parameter(deq_scale, requires_grad=False)
|
||||
|
||||
Reference in New Issue
Block a user