[Fix] Remove unused import from triton_kernels_moe.py (#11967)
Co-authored-by: Shangming Cai <171321666+shangmingcai@users.noreply.github.com>
This commit is contained in:
@@ -13,8 +13,7 @@ from triton_kernels.matmul_ogs import (
|
||||
PrecisionConfig,
|
||||
matmul_ogs,
|
||||
)
|
||||
from triton_kernels.numerics import InFlexData, MicroscalingCtx
|
||||
from triton_kernels.quantization import downcast_to_mxfp
|
||||
from triton_kernels.numerics import InFlexData
|
||||
from triton_kernels.routing import GatherIndx, RoutingData, ScatterIndx
|
||||
from triton_kernels.swiglu import swiglu_fn
|
||||
|
||||
@@ -26,30 +25,6 @@ if TYPE_CHECKING:
|
||||
def quantize(w, dtype, dev, **opt):
|
||||
if dtype == "bf16":
|
||||
return w.to(torch.bfloat16), InFlexData()
|
||||
elif dtype == "fp8":
|
||||
wq = w.to(torch.float8_e4m3fn).transpose(-1, -2).contiguous().transpose(-1, -2)
|
||||
return (
|
||||
wq,
|
||||
InFlexData(dtype=wq.dtype, scale=w.abs().max().unsqueeze(0)),
|
||||
MicroscalingCtx(),
|
||||
)
|
||||
else:
|
||||
assert dtype == "mx4", f"{dtype=}"
|
||||
swizzle_mx_scale = opt["swizzle_mx_scale"]
|
||||
swizzle_axis = 2 if swizzle_mx_scale else None
|
||||
w = w.to(torch.bfloat16)
|
||||
w, mx_scales, weight_scale_shape = downcast_to_mxfp(
|
||||
w, torch.uint8, axis=1, swizzle_axis=swizzle_axis
|
||||
)
|
||||
return (
|
||||
w,
|
||||
InFlexData(),
|
||||
MicroscalingCtx(
|
||||
weight_scale=mx_scales,
|
||||
swizzle_mx=swizzle_mx_scale,
|
||||
actual_weight_scale_shape=weight_scale_shape,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def triton_kernel_moe_forward(
|
||||
|
||||
Reference in New Issue
Block a user