[sgl-kernel] misc: update deepgemm version for sgl-kernel (#9340)

Co-authored-by: Yineng Zhang <me@zhyncs.com>
Co-authored-by: fzyzcjy <ch271828n@outlook.com>
This commit is contained in:
PGFLMG
2025-08-28 03:01:30 +08:00
committed by GitHub
parent 07ee0ab750
commit aa3eba8eb4
25 changed files with 210 additions and 383 deletions

View File

@@ -9,7 +9,6 @@ import jinja2
FILE_HEAD = """
// auto generated by generate.py
// clang-format off
#pragma once
#include "kernel.h"
#include "marlin_template.h"
@@ -34,17 +33,6 @@ TEMPLATE = (
"( MARLIN_KERNEL_PARAMS );"
)
KERNEL_FILE_TEMPLATE = (
"// auto generated by generate.py\n"
"// clang-format off\n"
"#pragma once\n\n"
"{% for kernel_file in kernel_files %}"
'#include "{{ kernel_file }}"\n'
"{% endfor %}"
)
KERNEL_FILE_NAME = "kernel_marlin.cuh"
# The int8 with zero point case (sglang::kU8) is also supported,
# but we don't add it, to reduce wheel size.
SCALAR_TYPES = ["sglang::kU4", "sglang::kU4B8", "sglang::kU8B128"]
@@ -60,12 +48,11 @@ DTYPES = ["fp16", "bf16"]
# Delete kernel files matching the "kernel_*" glob in the directory that
# contains this script. Each match is removed with a separate `rm -f` call.
def remove_old_kernels():
# NOTE(review): the next two `for` lines differ only in the glob suffix
# (".cuh" vs ".cu"). They look like the before/after sides of one changed
# line in a diff whose +/- markers were lost — confirm which one is live.
for filename in glob.glob(os.path.dirname(__file__) + "/kernel_*.cuh"):
for filename in glob.glob(os.path.dirname(__file__) + "/kernel_*.cu"):
# The return code of `rm` is not checked.
subprocess.call(["rm", "-f", filename])
def generate_new_kernels():
kernel_files = set()
for scalar_type, dtype in itertools.product(SCALAR_TYPES, DTYPES):
has_zp = "B" not in scalar_type
all_template_str_list = []
@@ -108,20 +95,10 @@ def generate_new_kernels():
file_content = FILE_HEAD + "\n\n"
file_content += "\n\n".join(all_template_str_list) + "\n\n}\n"
filename = f"kernel_{dtype}_{scalar_type[8:].lower()}.cuh"
filename = f"kernel_{dtype}_{scalar_type[8:].lower()}.cu"
with open(os.path.join(os.path.dirname(__file__), filename), "w") as f:
f.write(file_content)
kernel_files.add(filename)
kernel_files = list(kernel_files)
kernel_files.sort()
file_content = jinja2.Template(KERNEL_FILE_TEMPLATE).render(
kernel_files=kernel_files
)
with open(os.path.join(os.path.dirname(__file__), KERNEL_FILE_NAME), "w") as f:
f.write(file_content)
if __name__ == "__main__":