[hotfix] fix mixtral with tensor-level compressed-tensor quantization (#8721)
This commit is contained in:
@@ -23,6 +23,7 @@ from sglang.srt.layers.quantization.utils import (
|
|||||||
from sglang.srt.utils import is_cpu, is_cuda, is_hip, is_npu, set_weight_attrs
|
from sglang.srt.utils import is_cpu, is_cuda, is_hip, is_npu, set_weight_attrs
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
|
from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
|
||||||
from sglang.srt.layers.moe.topk import TopKOutput
|
from sglang.srt.layers.moe.topk import TopKOutput
|
||||||
from sglang.srt.layers.quantization.compressed_tensors.compressed_tensors import (
|
from sglang.srt.layers.quantization.compressed_tensors.compressed_tensors import (
|
||||||
CompressedTensorsConfig,
|
CompressedTensorsConfig,
|
||||||
@@ -189,7 +190,7 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
|
|||||||
layer.w13_input_scale = None
|
layer.w13_input_scale = None
|
||||||
layer.w2_input_scale = None
|
layer.w2_input_scale = None
|
||||||
|
|
||||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
def process_weights_after_loading(self, layer: FusedMoE) -> None:
|
||||||
# Fp8 moe kernels require a single activation scale.
|
# Fp8 moe kernels require a single activation scale.
|
||||||
# We take the max of all the scales in case they differ.
|
# We take the max of all the scales in case they differ.
|
||||||
if self.static_input_scales:
|
if self.static_input_scales:
|
||||||
@@ -246,7 +247,7 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
|
|||||||
assert layer.w13_weight_scale is not None
|
assert layer.w13_weight_scale is not None
|
||||||
shard_size = layer.intermediate_size_per_partition
|
shard_size = layer.intermediate_size_per_partition
|
||||||
max_w13_scales = layer.w13_weight_scale.max(dim=1).values
|
max_w13_scales = layer.w13_weight_scale.max(dim=1).values
|
||||||
for expert_id in range(layer.local_num_experts):
|
for expert_id in range(layer.num_local_experts):
|
||||||
start = 0
|
start = 0
|
||||||
for shard_id in range(2):
|
for shard_id in range(2):
|
||||||
dq_weight = per_tensor_dequantize(
|
dq_weight = per_tensor_dequantize(
|
||||||
|
|||||||
Reference in New Issue
Block a user