[Kimi K2] dsv3_router_gemm supports NUM_EXPERTS == 384 (#8013)
This commit is contained in:
@@ -184,6 +184,7 @@ void invokeRouterGemmFloatOutput(float* output, T const* mat_a, T const* mat_b,
|
||||
mat_b);
|
||||
}
|
||||
|
||||
// Template instantiations for DEFAULT_NUM_EXPERTS experts
|
||||
template void invokeRouterGemmFloatOutput<__nv_bfloat16, 1, 256, 7168>(
|
||||
float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
|
||||
|
||||
@@ -231,3 +232,52 @@ template void invokeRouterGemmFloatOutput<__nv_bfloat16, 15, 256, 7168>(
|
||||
|
||||
template void invokeRouterGemmFloatOutput<__nv_bfloat16, 16, 256, 7168>(
|
||||
float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
|
||||
|
||||
// Template instantiations for KIMI_K2_NUM_EXPERTS experts
|
||||
template void invokeRouterGemmFloatOutput<__nv_bfloat16, 1, 384, 7168>(
|
||||
float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
|
||||
|
||||
template void invokeRouterGemmFloatOutput<__nv_bfloat16, 2, 384, 7168>(
|
||||
float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
|
||||
|
||||
template void invokeRouterGemmFloatOutput<__nv_bfloat16, 3, 384, 7168>(
|
||||
float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
|
||||
|
||||
template void invokeRouterGemmFloatOutput<__nv_bfloat16, 4, 384, 7168>(
|
||||
float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
|
||||
|
||||
template void invokeRouterGemmFloatOutput<__nv_bfloat16, 5, 384, 7168>(
|
||||
float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
|
||||
|
||||
template void invokeRouterGemmFloatOutput<__nv_bfloat16, 6, 384, 7168>(
|
||||
float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
|
||||
|
||||
template void invokeRouterGemmFloatOutput<__nv_bfloat16, 7, 384, 7168>(
|
||||
float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
|
||||
|
||||
template void invokeRouterGemmFloatOutput<__nv_bfloat16, 8, 384, 7168>(
|
||||
float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
|
||||
|
||||
template void invokeRouterGemmFloatOutput<__nv_bfloat16, 9, 384, 7168>(
|
||||
float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
|
||||
|
||||
template void invokeRouterGemmFloatOutput<__nv_bfloat16, 10, 384, 7168>(
|
||||
float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
|
||||
|
||||
template void invokeRouterGemmFloatOutput<__nv_bfloat16, 11, 384, 7168>(
|
||||
float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
|
||||
|
||||
template void invokeRouterGemmFloatOutput<__nv_bfloat16, 12, 384, 7168>(
|
||||
float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
|
||||
|
||||
template void invokeRouterGemmFloatOutput<__nv_bfloat16, 13, 384, 7168>(
|
||||
float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
|
||||
|
||||
template void invokeRouterGemmFloatOutput<__nv_bfloat16, 14, 384, 7168>(
|
||||
float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
|
||||
|
||||
template void invokeRouterGemmFloatOutput<__nv_bfloat16, 15, 384, 7168>(
|
||||
float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
|
||||
|
||||
template void invokeRouterGemmFloatOutput<__nv_bfloat16, 16, 384, 7168>(
|
||||
float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
|
||||
|
||||
Reference in New Issue
Block a user