Use deepgemm instead of triton for fused_qkv_a_proj_with_mqa (#6890)
This commit is contained in:
@@ -227,8 +227,8 @@ def deepgemm_w8a8_block_fp8_linear_with_fallback(
|
|||||||
output_dtype = input.dtype
|
output_dtype = input.dtype
|
||||||
dtype_supported = output_dtype == torch.bfloat16
|
dtype_supported = output_dtype == torch.bfloat16
|
||||||
|
|
||||||
# TODO: add more robust shape check here
|
# TODO: https://github.com/sgl-project/sglang/pull/6890#issuecomment-2943395737
|
||||||
shape_supported = weight.shape[0] % 128 == 0 and weight.shape[1] % 128 == 0
|
shape_supported = weight.shape[0] % 64 == 0 and weight.shape[1] % 128 == 0
|
||||||
|
|
||||||
if not (shape_supported and dtype_supported):
|
if not (shape_supported and dtype_supported):
|
||||||
# fall back to triton
|
# fall back to triton
|
||||||
|
|||||||
Reference in New Issue
Block a user