fix: resolve flashinfer 0.4.1 import (#11940)
This commit is contained in:
@@ -1061,8 +1061,8 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
|
|||||||
):
|
):
|
||||||
from flashinfer import nvfp4_block_scale_interleave
|
from flashinfer import nvfp4_block_scale_interleave
|
||||||
from flashinfer.fused_moe.core import (
|
from flashinfer.fused_moe.core import (
|
||||||
_maybe_get_cached_w2_permute_indices,
|
|
||||||
_maybe_get_cached_w3_w1_permute_indices,
|
_maybe_get_cached_w3_w1_permute_indices,
|
||||||
|
get_w2_permute_indices_with_cache,
|
||||||
)
|
)
|
||||||
|
|
||||||
"""Prepare quantized weights for kernel (done offline with weights)."""
|
"""Prepare quantized weights for kernel (done offline with weights)."""
|
||||||
@@ -1123,7 +1123,7 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
permute_indices = _maybe_get_cached_w2_permute_indices(
|
permute_indices = get_w2_permute_indices_with_cache(
|
||||||
self._cache_permute_indices,
|
self._cache_permute_indices,
|
||||||
gemm2_weights_fp4[i].view(torch.uint8),
|
gemm2_weights_fp4[i].view(torch.uint8),
|
||||||
epilogue_tile_m,
|
epilogue_tile_m,
|
||||||
@@ -1134,7 +1134,7 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
|
|||||||
.contiguous()
|
.contiguous()
|
||||||
)
|
)
|
||||||
|
|
||||||
permute_sf_indices = _maybe_get_cached_w2_permute_indices(
|
permute_sf_indices = get_w2_permute_indices_with_cache(
|
||||||
self._cache_permute_indices,
|
self._cache_permute_indices,
|
||||||
gemm2_scales_linear_fp4[i].view(torch.uint8),
|
gemm2_scales_linear_fp4[i].view(torch.uint8),
|
||||||
epilogue_tile_m,
|
epilogue_tile_m,
|
||||||
|
|||||||
Reference in New Issue
Block a user