support cuda 13.0 and trtllm kernel by Aug 25 2025 (#9495)
This commit is contained in:
@@ -9,6 +9,7 @@ import jinja2
|
||||
FILE_HEAD = """
|
||||
// auto generated by generate.py
|
||||
// clang-format off
|
||||
#pragma once
|
||||
|
||||
#include "kernel.h"
|
||||
#include "marlin_template.h"
|
||||
@@ -33,6 +34,17 @@ TEMPLATE = (
|
||||
"( MARLIN_KERNEL_PARAMS );"
|
||||
)
|
||||
|
||||
KERNEL_FILE_TEMPLATE = (
|
||||
"// auto generated by generate.py\n"
|
||||
"// clang-format off\n"
|
||||
"#pragma once\n\n"
|
||||
"{% for kernel_file in kernel_files %}"
|
||||
'#include "{{ kernel_file }}"\n'
|
||||
"{% endfor %}"
|
||||
)
|
||||
|
||||
KERNEL_FILE_NAME = "kernel_marlin.cuh"
|
||||
|
||||
# int8 with zero point case (sglang::kU8) is also supported,
|
||||
# we don't add it to reduce wheel size.
|
||||
SCALAR_TYPES = ["sglang::kU4", "sglang::kU4B8", "sglang::kU8B128"]
|
||||
@@ -48,11 +60,12 @@ DTYPES = ["fp16", "bf16"]
|
||||
|
||||
|
||||
def remove_old_kernels():
|
||||
for filename in glob.glob(os.path.dirname(__file__) + "/kernel_*.cu"):
|
||||
for filename in glob.glob(os.path.dirname(__file__) + "/kernel_*.cuh"):
|
||||
subprocess.call(["rm", "-f", filename])
|
||||
|
||||
|
||||
def generate_new_kernels():
|
||||
kernel_files = set()
|
||||
for scalar_type, dtype in itertools.product(SCALAR_TYPES, DTYPES):
|
||||
has_zp = "B" not in scalar_type
|
||||
all_template_str_list = []
|
||||
@@ -95,10 +108,20 @@ def generate_new_kernels():
|
||||
|
||||
file_content = FILE_HEAD + "\n\n"
|
||||
file_content += "\n\n".join(all_template_str_list) + "\n\n}\n"
|
||||
filename = f"kernel_{dtype}_{scalar_type[8:].lower()}.cu"
|
||||
filename = f"kernel_{dtype}_{scalar_type[8:].lower()}.cuh"
|
||||
|
||||
with open(os.path.join(os.path.dirname(__file__), filename), "w") as f:
|
||||
f.write(file_content)
|
||||
kernel_files.add(filename)
|
||||
|
||||
kernel_files = list(kernel_files)
|
||||
kernel_files.sort()
|
||||
|
||||
file_content = jinja2.Template(KERNEL_FILE_TEMPLATE).render(
|
||||
kernel_files=kernel_files
|
||||
)
|
||||
with open(os.path.join(os.path.dirname(__file__), KERNEL_FILE_NAME), "w") as f:
|
||||
f.write(file_content)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
#pragma once
|
||||
|
||||
#ifndef MARLIN_NAMESPACE_NAME
|
||||
#define MARLIN_NAMESPACE_NAME marlin_moe_wna16
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
// auto generated by generate.py
|
||||
// clang-format off
|
||||
#pragma once
|
||||
|
||||
#include "kernel.h"
|
||||
#include "marlin_template.h"
|
||||
@@ -1,5 +1,6 @@
|
||||
// auto generated by generate.py
|
||||
// clang-format off
|
||||
#pragma once
|
||||
|
||||
#include "kernel.h"
|
||||
#include "marlin_template.h"
|
||||
@@ -1,5 +1,6 @@
|
||||
// auto generated by generate.py
|
||||
// clang-format off
|
||||
#pragma once
|
||||
|
||||
#include "kernel.h"
|
||||
#include "marlin_template.h"
|
||||
@@ -1,5 +1,6 @@
|
||||
// auto generated by generate.py
|
||||
// clang-format off
|
||||
#pragma once
|
||||
|
||||
#include "kernel.h"
|
||||
#include "marlin_template.h"
|
||||
@@ -1,5 +1,6 @@
|
||||
// auto generated by generate.py
|
||||
// clang-format off
|
||||
#pragma once
|
||||
|
||||
#include "kernel.h"
|
||||
#include "marlin_template.h"
|
||||
@@ -1,5 +1,6 @@
|
||||
// auto generated by generate.py
|
||||
// clang-format off
|
||||
#pragma once
|
||||
|
||||
#include "kernel.h"
|
||||
#include "marlin_template.h"
|
||||
10
sgl-kernel/csrc/moe/marlin_moe_wna16/kernel_marlin.cuh
Normal file
10
sgl-kernel/csrc/moe/marlin_moe_wna16/kernel_marlin.cuh
Normal file
@@ -0,0 +1,10 @@
|
||||
// auto generated by generate.py
|
||||
// clang-format off
|
||||
#pragma once
|
||||
|
||||
#include "kernel_bf16_ku4.cuh"
|
||||
#include "kernel_bf16_ku4b8.cuh"
|
||||
#include "kernel_bf16_ku8b128.cuh"
|
||||
#include "kernel_fp16_ku4.cuh"
|
||||
#include "kernel_fp16_ku4b8.cuh"
|
||||
#include "kernel_fp16_ku8b128.cuh"
|
||||
@@ -18,6 +18,8 @@
|
||||
/*
|
||||
* Adapted from https://github.com/IST-DASLab/marlin
|
||||
*/
|
||||
#pragma once
|
||||
|
||||
#ifndef MARLIN_NAMESPACE_NAME
|
||||
#define MARLIN_NAMESPACE_NAME marlin_moe_wna16
|
||||
#endif
|
||||
|
||||
@@ -24,6 +24,7 @@
|
||||
#endif
|
||||
|
||||
#include "kernel.h"
|
||||
#include "kernel_marlin.cuh"
|
||||
|
||||
#define STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t) \
|
||||
static_assert( \
|
||||
|
||||
Reference in New Issue
Block a user