Upgrade to vllm 0.17.0 corex v4.1 overlay
This commit is contained in:
@@ -24,7 +24,12 @@ from vllm.model_executor.layers.quantization.utils.ocp_mx_utils import (
|
||||
OCP_MX_BLOCK_SIZE,
|
||||
OCP_MX_Scheme,
|
||||
)
|
||||
from vllm.model_executor.parameter import GroupQuantScaleParameter, PackedvLLMParameter
|
||||
from vllm.model_executor.parameter import (
|
||||
GroupQuantScaleParameter,
|
||||
ModelWeightParameter,
|
||||
PackedvLLMParameter,
|
||||
)
|
||||
from vllm.model_executor.utils import set_weight_attrs
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
from .quark_scheme import QuarkScheme
|
||||
@@ -169,13 +174,16 @@ except (ImportError, AttributeError, RuntimeError):
|
||||
|
||||
class QuarkOCP_MX(QuarkScheme):
|
||||
def __init__(
|
||||
self, weight_quant_spec: dict[str, Any], input_quant_spec: dict[str, Any]
|
||||
self,
|
||||
weight_quant_spec: dict[str, Any],
|
||||
input_quant_spec: dict[str, Any],
|
||||
dynamic_mxfp4_quant: bool = False,
|
||||
):
|
||||
self.out_dtype = torch.get_default_dtype()
|
||||
self.qscheme = "per_group"
|
||||
self.weight_quant_spec = weight_quant_spec
|
||||
self.input_quant_spec = input_quant_spec
|
||||
|
||||
self.dynamic_mxfp4_quant = dynamic_mxfp4_quant
|
||||
self.weight_dtype = weight_quant_spec["dtype"].replace("fp", "mxfp")
|
||||
self.input_dtype = input_quant_spec["dtype"].replace("fp", "mxfp")
|
||||
|
||||
@@ -269,7 +277,13 @@ class QuarkOCP_MX(QuarkScheme):
|
||||
layer.weight_scale.data, requires_grad=False
|
||||
)
|
||||
else:
|
||||
if self.rocm_use_aiter_fp4_asm_gemm:
|
||||
if self.dynamic_mxfp4_quant:
|
||||
w_q, w_s = dynamic_mxfp4_quant(layer.weight)
|
||||
layer.weight_scale = torch.nn.Parameter(
|
||||
w_s.T.contiguous(), requires_grad=False
|
||||
)
|
||||
layer.weight = torch.nn.Parameter(w_q, requires_grad=False)
|
||||
elif self.rocm_use_aiter_fp4_asm_gemm:
|
||||
# shuffle weight scale
|
||||
weight_scale_shuffle = layer.weight_scale.data
|
||||
sm, sn = weight_scale_shuffle.shape
|
||||
@@ -302,36 +316,51 @@ class QuarkOCP_MX(QuarkScheme):
|
||||
weight_loader: Callable,
|
||||
**kwargs,
|
||||
):
|
||||
output_size_per_partition = sum(output_partition_sizes)
|
||||
layer.logical_widths = output_partition_sizes
|
||||
if self.dynamic_mxfp4_quant:
|
||||
weight = ModelWeightParameter(
|
||||
data=torch.empty(
|
||||
sum(output_partition_sizes),
|
||||
input_size_per_partition,
|
||||
dtype=params_dtype,
|
||||
),
|
||||
input_dim=1,
|
||||
output_dim=0,
|
||||
weight_loader=weight_loader,
|
||||
)
|
||||
|
||||
# WEIGHT
|
||||
weight = PackedvLLMParameter(
|
||||
data=torch.empty(
|
||||
output_size_per_partition,
|
||||
self.get_packed_dim(input_size_per_partition, self.weight_dtype),
|
||||
dtype=torch.uint8,
|
||||
),
|
||||
input_dim=1,
|
||||
output_dim=0,
|
||||
packed_dim=1,
|
||||
packed_factor=self.packed_factor,
|
||||
weight_loader=weight_loader,
|
||||
)
|
||||
layer.register_parameter("weight", weight)
|
||||
layer.register_parameter("weight", weight)
|
||||
set_weight_attrs(weight, kwargs)
|
||||
else:
|
||||
output_size_per_partition = sum(output_partition_sizes)
|
||||
layer.logical_widths = output_partition_sizes
|
||||
|
||||
# WEIGHT SCALE
|
||||
weight_scale = GroupQuantScaleParameter(
|
||||
data=torch.empty(
|
||||
output_size_per_partition,
|
||||
input_size_per_partition // OCP_MX_BLOCK_SIZE,
|
||||
dtype=torch.uint8,
|
||||
),
|
||||
input_dim=1,
|
||||
output_dim=0,
|
||||
weight_loader=weight_loader,
|
||||
)
|
||||
layer.register_parameter("weight_scale", weight_scale)
|
||||
# WEIGHT
|
||||
weight = PackedvLLMParameter(
|
||||
data=torch.empty(
|
||||
output_size_per_partition,
|
||||
self.get_packed_dim(input_size_per_partition, self.weight_dtype),
|
||||
dtype=torch.uint8,
|
||||
),
|
||||
input_dim=1,
|
||||
output_dim=0,
|
||||
packed_dim=1,
|
||||
packed_factor=self.packed_factor,
|
||||
weight_loader=weight_loader,
|
||||
)
|
||||
layer.register_parameter("weight", weight)
|
||||
|
||||
# WEIGHT SCALE
|
||||
weight_scale = GroupQuantScaleParameter(
|
||||
data=torch.empty(
|
||||
output_size_per_partition,
|
||||
input_size_per_partition // OCP_MX_BLOCK_SIZE,
|
||||
dtype=torch.uint8,
|
||||
),
|
||||
input_dim=1,
|
||||
output_dim=0,
|
||||
weight_loader=weight_loader,
|
||||
)
|
||||
layer.register_parameter("weight_scale", weight_scale)
|
||||
|
||||
def apply_weights(
|
||||
self,
|
||||
|
||||
Reference in New Issue
Block a user