[CUTLASS-FP4-MOE] Introduce CutlassMoEParams class for easy initialization of Cutlass Grouped Gems Metadata (#6887)
Signed-off-by: Pavani Majety <pmajety@nvidia.com>
This commit is contained in:
@@ -1,4 +1,4 @@
|
||||
from typing import Optional
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
import torch
|
||||
|
||||
@@ -184,13 +184,9 @@ def cutlass_fp4_group_mm(
|
||||
a_blockscale,
|
||||
b_blockscale,
|
||||
alphas,
|
||||
ab_strides,
|
||||
c_strides,
|
||||
problem_sizes,
|
||||
expert_offsets,
|
||||
blockscale_offsets,
|
||||
out_dtype,
|
||||
device,
|
||||
params: Dict[str, Any],
|
||||
):
|
||||
"""
|
||||
An FP4 Blockscaled Group Gemm that takes in a_tensors, b_tensors and runs
|
||||
@@ -220,10 +216,10 @@ def cutlass_fp4_group_mm(
|
||||
a_blockscale,
|
||||
b_blockscale,
|
||||
alphas,
|
||||
ab_strides,
|
||||
c_strides,
|
||||
problem_sizes,
|
||||
expert_offsets,
|
||||
blockscale_offsets,
|
||||
params["ab_strides"],
|
||||
params["c_strides"],
|
||||
params["problem_sizes"],
|
||||
params["expert_offsets"],
|
||||
params["blockscale_offsets"],
|
||||
)
|
||||
return c.to(dtype=out_dtype)
|
||||
|
||||
Reference in New Issue
Block a user