Optional extension for green context (#9231)

This commit is contained in:
Liangsheng Yin
2025-08-15 21:33:52 +08:00
committed by GitHub
parent c186feed7f
commit 0c8594e67d
6 changed files with 73 additions and 20 deletions

View File

@@ -433,12 +433,6 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
"qserve_w4a8_per_group_gemm(Tensor _in_feats, Tensor _kernel, Tensor _zeros, Tensor _scales_i8, Tensor _wscales, "
"Tensor _ascales, Tensor! _out_feats) -> ()");
m.impl("qserve_w4a8_per_group_gemm", torch::kCUDA, &qserve_w4a8_per_group_gemm);
/*
* From csrc/spatial
*/
m.def("create_greenctx_stream_by_value(int smA, int smB, int device) -> int[]");
m.impl("create_greenctx_stream_by_value", &create_greenctx_stream_by_value);
}
REGISTER_EXTENSION(common_ops)

View File

@@ -42,6 +42,7 @@ static std::vector<int64_t> create_greenctx_stream_direct_dynamic(CUgreenCtx gct
// This symbol is introduced in CUDA 12.5
const static auto pfn = probe_cuGreenCtxStreamCreate();
if (!pfn) {
TORCH_WARN("cuGreenCtxStreamCreate(cuda>=12.5) is not available, using fallback");
return create_greenctx_stream_fallback(gctx);
}
@@ -55,17 +56,12 @@ static std::vector<int64_t> create_greenctx_stream_direct_dynamic(CUgreenCtx gct
std::vector<int64_t> create_greenctx_stream_by_value(int64_t smA, int64_t smB, int64_t device) {
CUDA_DRV(cuDriverGetVersion(&CUDA_DRIVER_VERSION));
if (CUDA_DRIVER_VERSION < 12040) {
TORCH_CHECK(false, "Green Contexts feature requires CUDA Toolkit 12.4 or newer.");
}
CUgreenCtx gctx[3];
CUdevResourceDesc desc[3];
CUdevResource input;
CUdevResource resources[4];
if (smA <= 0 || smB <= 0) {
TORCH_CHECK(false, "SM counts must be positive");
}
TORCH_CHECK(smA > 0 && smB > 0, "SM counts must be positive");
CUDA_DRV(cuDeviceGetDevResource((CUdevice)device, &input, CU_DEV_RESOURCE_TYPE_SM));

View File

@@ -0,0 +1,29 @@
/* Copyright 2025 SGLang Team. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <torch/all.h>
#include <torch/library.h>
#include "sgl_kernel_ops.h"
// Adds the green-context stream-creation operator to the `sgl_kernel` operator
// library as its own fragment, so it is registered from this translation unit.
TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
/*
* From csrc/spatial
*/
// Schema: three integer arguments (smA, smB, device) and an integer-list result.
// NOTE(review): the meaning of the returned integers is not visible here —
// confirm against the implementation in csrc/spatial.
m.def("create_greenctx_stream_by_value(int smA, int smB, int device) -> int[]");
// Bound without an explicit dispatch key: the op takes no Tensor arguments.
m.impl("create_greenctx_stream_by_value", &create_greenctx_stream_by_value);
}
// Invokes the project's REGISTER_EXTENSION macro with the name `spatial_ops`.
// NOTE(review): the macro is defined outside this file (presumably reachable via
// sgl_kernel_ops.h) — confirm the name matches the extension module name used
// by the build configuration.
REGISTER_EXTENSION(spatial_ops)