use default for torch.ops (#4835)
This commit is contained in:
@@ -12,49 +12,49 @@ if torch.version.hip is not None:
|
|||||||
rank: int,
|
rank: int,
|
||||||
full_nvlink: bool,
|
full_nvlink: bool,
|
||||||
) -> int:
|
) -> int:
|
||||||
return torch.ops.sgl_kernel.init_custom_ar(
|
return torch.ops.sgl_kernel.init_custom_ar.default(
|
||||||
meta, rank_data, handles, offsets, rank, full_nvlink
|
meta, rank_data, handles, offsets, rank, full_nvlink
|
||||||
)
|
)
|
||||||
|
|
||||||
def all_reduce_reg(fa: int, inp: torch.Tensor, out: torch.Tensor) -> None:
|
def all_reduce_reg(fa: int, inp: torch.Tensor, out: torch.Tensor) -> None:
|
||||||
torch.ops.sgl_kernel.all_reduce_reg(fa, inp, out)
|
torch.ops.sgl_kernel.all_reduce_reg.default(fa, inp, out)
|
||||||
|
|
||||||
def all_reduce_unreg(
|
def all_reduce_unreg(
|
||||||
fa: int, inp: torch.Tensor, reg_buffer: torch.Tensor, out: torch.Tensor
|
fa: int, inp: torch.Tensor, reg_buffer: torch.Tensor, out: torch.Tensor
|
||||||
) -> None:
|
) -> None:
|
||||||
torch.ops.sgl_kernel.all_reduce_unreg(fa, inp, reg_buffer, out)
|
torch.ops.sgl_kernel.all_reduce_unreg.default(fa, inp, reg_buffer, out)
|
||||||
|
|
||||||
def dispose(fa: int) -> None:
|
def dispose(fa: int) -> None:
|
||||||
torch.ops.sgl_kernel.dispose(fa)
|
torch.ops.sgl_kernel.dispose.default(fa)
|
||||||
|
|
||||||
def meta_size() -> int:
|
def meta_size() -> int:
|
||||||
return torch.ops.sgl_kernel.meta_size()
|
return torch.ops.sgl_kernel.meta_size.default()
|
||||||
|
|
||||||
def register_buffer(
|
def register_buffer(
|
||||||
fa: int, t: torch.Tensor, handles: List[str], offsets: List[int]
|
fa: int, t: torch.Tensor, handles: List[str], offsets: List[int]
|
||||||
) -> None:
|
) -> None:
|
||||||
return torch.ops.sgl_kernel.register_buffer(fa, t, handles, offsets)
|
return torch.ops.sgl_kernel.register_buffer.default(fa, t, handles, offsets)
|
||||||
|
|
||||||
def get_graph_buffer_ipc_meta(fa: int) -> Tuple[torch.Tensor, List[int]]:
|
def get_graph_buffer_ipc_meta(fa: int) -> Tuple[torch.Tensor, List[int]]:
|
||||||
return torch.ops.sgl_kernel.get_graph_buffer_ipc_meta(fa)
|
return torch.ops.sgl_kernel.get_graph_buffer_ipc_meta.default(fa)
|
||||||
|
|
||||||
def register_graph_buffers(
|
def register_graph_buffers(
|
||||||
fa: int, handles: List[str], offsets: List[List[int]]
|
fa: int, handles: List[str], offsets: List[List[int]]
|
||||||
) -> None:
|
) -> None:
|
||||||
torch.ops.sgl_kernel.register_graph_buffers(fa, handles, offsets)
|
torch.ops.sgl_kernel.register_graph_buffers.default(fa, handles, offsets)
|
||||||
|
|
||||||
def allocate_meta_buffer(size: int) -> torch.Tensor:
|
def allocate_meta_buffer(size: int) -> torch.Tensor:
|
||||||
return torch.ops.sgl_kernel.allocate_meta_buffer(size)
|
return torch.ops.sgl_kernel.allocate_meta_buffer.default(size)
|
||||||
|
|
||||||
def get_meta_buffer_ipc_handle(inp: torch.Tensor) -> torch.Tensor:
|
def get_meta_buffer_ipc_handle(inp: torch.Tensor) -> torch.Tensor:
|
||||||
return torch.ops.sgl_kernel.get_meta_buffer_ipc_handle(inp)
|
return torch.ops.sgl_kernel.get_meta_buffer_ipc_handle.default(inp)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
# TRTLLM custom allreduce
|
# TRTLLM custom allreduce
|
||||||
def init_custom_reduce(
|
def init_custom_reduce(
|
||||||
rank_id, num_devices, rank_data, buffers, tmp_buffers, barrier_in, barrier_out
|
rank_id, num_devices, rank_data, buffers, tmp_buffers, barrier_in, barrier_out
|
||||||
):
|
):
|
||||||
return torch.ops.sgl_kernel.init_custom_ar(
|
return torch.ops.sgl_kernel.init_custom_ar.default(
|
||||||
rank_id,
|
rank_id,
|
||||||
num_devices,
|
num_devices,
|
||||||
rank_data,
|
rank_data,
|
||||||
@@ -65,13 +65,13 @@ else:
|
|||||||
)
|
)
|
||||||
|
|
||||||
def custom_dispose(fa):
|
def custom_dispose(fa):
|
||||||
torch.ops.sgl_kernel.dispose(fa)
|
torch.ops.sgl_kernel.dispose.default(fa)
|
||||||
|
|
||||||
def custom_reduce(fa, inp, out):
|
def custom_reduce(fa, inp, out):
|
||||||
torch.ops.sgl_kernel.all_reduce(fa, inp, out)
|
torch.ops.sgl_kernel.all_reduce.default(fa, inp, out)
|
||||||
|
|
||||||
def get_graph_buffer_ipc_meta(fa):
|
def get_graph_buffer_ipc_meta(fa):
|
||||||
return torch.ops.sgl_kernel.get_graph_buffer_ipc_meta(fa)
|
return torch.ops.sgl_kernel.get_graph_buffer_ipc_meta.default(fa)
|
||||||
|
|
||||||
def register_graph_buffers(fa, handles, offsets):
|
def register_graph_buffers(fa, handles, offsets):
|
||||||
torch.ops.sgl_kernel.register_graph_buffers(fa, handles, offsets)
|
torch.ops.sgl_kernel.register_graph_buffers.default(fa, handles, offsets)
|
||||||
|
|||||||
@@ -2,6 +2,6 @@ import torch
|
|||||||
|
|
||||||
|
|
||||||
def lightning_attention_decode(q, k, v, past_kv, slope, output, new_kv):
|
def lightning_attention_decode(q, k, v, past_kv, slope, output, new_kv):
|
||||||
torch.ops.sgl_kernel.lightning_attention_decode(
|
torch.ops.sgl_kernel.lightning_attention_decode.default(
|
||||||
q, k, v, past_kv, slope, output, new_kv
|
q, k, v, past_kv, slope, output, new_kv
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -14,14 +14,14 @@ def rmsnorm(
|
|||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
if out is None:
|
if out is None:
|
||||||
out = torch.empty_like(input)
|
out = torch.empty_like(input)
|
||||||
torch.ops.sgl_kernel.rmsnorm(out, input, weight, eps, get_cuda_stream())
|
torch.ops.sgl_kernel.rmsnorm.default(out, input, weight, eps, get_cuda_stream())
|
||||||
return out
|
return out
|
||||||
|
|
||||||
|
|
||||||
def fused_add_rmsnorm(
|
def fused_add_rmsnorm(
|
||||||
input: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6
|
input: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6
|
||||||
) -> None:
|
) -> None:
|
||||||
torch.ops.sgl_kernel.fused_add_rmsnorm(input, residual, weight, eps)
|
torch.ops.sgl_kernel.fused_add_rmsnorm.default(input, residual, weight, eps)
|
||||||
|
|
||||||
|
|
||||||
def gemma_rmsnorm(
|
def gemma_rmsnorm(
|
||||||
@@ -32,14 +32,16 @@ def gemma_rmsnorm(
|
|||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
if out is None:
|
if out is None:
|
||||||
out = torch.empty_like(input)
|
out = torch.empty_like(input)
|
||||||
torch.ops.sgl_kernel.gemma_rmsnorm(out, input, weight, eps, get_cuda_stream())
|
torch.ops.sgl_kernel.gemma_rmsnorm.default(
|
||||||
|
out, input, weight, eps, get_cuda_stream()
|
||||||
|
)
|
||||||
return out
|
return out
|
||||||
|
|
||||||
|
|
||||||
def gemma_fused_add_rmsnorm(
|
def gemma_fused_add_rmsnorm(
|
||||||
input: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6
|
input: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6
|
||||||
) -> None:
|
) -> None:
|
||||||
torch.ops.sgl_kernel.gemma_fused_add_rmsnorm(
|
torch.ops.sgl_kernel.gemma_fused_add_rmsnorm.default(
|
||||||
input, residual, weight, eps, get_cuda_stream()
|
input, residual, weight, eps, get_cuda_stream()
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -65,7 +67,7 @@ def silu_and_mul(input: torch.Tensor, out: torch.Tensor = None) -> torch.Tensor:
|
|||||||
device=input.device,
|
device=input.device,
|
||||||
dtype=input.dtype,
|
dtype=input.dtype,
|
||||||
)
|
)
|
||||||
torch.ops.sgl_kernel.silu_and_mul(out, input, get_cuda_stream())
|
torch.ops.sgl_kernel.silu_and_mul.default(out, input, get_cuda_stream())
|
||||||
return out
|
return out
|
||||||
|
|
||||||
|
|
||||||
@@ -80,7 +82,7 @@ def gelu_tanh_and_mul(input: torch.Tensor, out: torch.Tensor = None) -> torch.Te
|
|||||||
device=input.device,
|
device=input.device,
|
||||||
dtype=input.dtype,
|
dtype=input.dtype,
|
||||||
)
|
)
|
||||||
torch.ops.sgl_kernel.gelu_tanh_and_mul(out, input, get_cuda_stream())
|
torch.ops.sgl_kernel.gelu_tanh_and_mul.default(out, input, get_cuda_stream())
|
||||||
return out
|
return out
|
||||||
|
|
||||||
|
|
||||||
@@ -95,7 +97,7 @@ def gelu_and_mul(input: torch.Tensor, out: torch.Tensor = None) -> torch.Tensor:
|
|||||||
device=input.device,
|
device=input.device,
|
||||||
dtype=input.dtype,
|
dtype=input.dtype,
|
||||||
)
|
)
|
||||||
torch.ops.sgl_kernel.gelu_and_mul(out, input, get_cuda_stream())
|
torch.ops.sgl_kernel.gelu_and_mul.default(out, input, get_cuda_stream())
|
||||||
return out
|
return out
|
||||||
|
|
||||||
|
|
||||||
@@ -139,7 +141,7 @@ def apply_rope_with_cos_sin_cache_inplace(
|
|||||||
if cos_sin_cache.dtype != torch.float32:
|
if cos_sin_cache.dtype != torch.float32:
|
||||||
raise ValueError("cos_sin_cache should be float32")
|
raise ValueError("cos_sin_cache should be float32")
|
||||||
|
|
||||||
torch.ops.sgl_kernel.apply_rope_pos_ids_cos_sin_cache(
|
torch.ops.sgl_kernel.apply_rope_pos_ids_cos_sin_cache.default(
|
||||||
q=query.view(query.shape[0], -1, head_size),
|
q=query.view(query.shape[0], -1, head_size),
|
||||||
k=key.view(key.shape[0], -1, head_size),
|
k=key.view(key.shape[0], -1, head_size),
|
||||||
q_rope=query.view(query.shape[0], -1, head_size),
|
q_rope=query.view(query.shape[0], -1, head_size),
|
||||||
|
|||||||
@@ -7,11 +7,11 @@ from sgl_kernel.utils import _get_cache_buf, get_cuda_stream
|
|||||||
def awq_dequantize(
|
def awq_dequantize(
|
||||||
qweight: torch.Tensor, scales: torch.Tensor, qzeros: torch.Tensor
|
qweight: torch.Tensor, scales: torch.Tensor, qzeros: torch.Tensor
|
||||||
) -> torch.ByteTensor:
|
) -> torch.ByteTensor:
|
||||||
return torch.ops.sgl_kernel.awq_dequantize(qweight, scales, qzeros)
|
return torch.ops.sgl_kernel.awq_dequantize.default(qweight, scales, qzeros)
|
||||||
|
|
||||||
|
|
||||||
def int8_scaled_mm(mat_a, mat_b, scales_a, scales_b, out_dtype, bias=None):
|
def int8_scaled_mm(mat_a, mat_b, scales_a, scales_b, out_dtype, bias=None):
|
||||||
return torch.ops.sgl_kernel.int8_scaled_mm(
|
return torch.ops.sgl_kernel.int8_scaled_mm.default(
|
||||||
mat_a,
|
mat_a,
|
||||||
mat_b,
|
mat_b,
|
||||||
scales_a,
|
scales_a,
|
||||||
@@ -22,7 +22,7 @@ def int8_scaled_mm(mat_a, mat_b, scales_a, scales_b, out_dtype, bias=None):
|
|||||||
|
|
||||||
|
|
||||||
def fp8_blockwise_scaled_mm(mat_a, mat_b, scales_a, scales_b, out_dtype):
|
def fp8_blockwise_scaled_mm(mat_a, mat_b, scales_a, scales_b, out_dtype):
|
||||||
return torch.ops.sgl_kernel.fp8_blockwise_scaled_mm(
|
return torch.ops.sgl_kernel.fp8_blockwise_scaled_mm.default(
|
||||||
mat_a,
|
mat_a,
|
||||||
mat_b,
|
mat_b,
|
||||||
scales_a,
|
scales_a,
|
||||||
@@ -32,7 +32,7 @@ def fp8_blockwise_scaled_mm(mat_a, mat_b, scales_a, scales_b, out_dtype):
|
|||||||
|
|
||||||
|
|
||||||
def fp8_scaled_mm(mat_a, mat_b, scales_a, scales_b, out_dtype, bias=None):
|
def fp8_scaled_mm(mat_a, mat_b, scales_a, scales_b, out_dtype, bias=None):
|
||||||
return torch.ops.sgl_kernel.fp8_scaled_mm(
|
return torch.ops.sgl_kernel.fp8_scaled_mm.default(
|
||||||
mat_a,
|
mat_a,
|
||||||
mat_b,
|
mat_b,
|
||||||
scales_a,
|
scales_a,
|
||||||
@@ -51,7 +51,7 @@ def _bmm_fp8_internal(
|
|||||||
B_scale: torch.Tensor,
|
B_scale: torch.Tensor,
|
||||||
) -> None:
|
) -> None:
|
||||||
cublas_handle = torch.cuda.current_blas_handle()
|
cublas_handle = torch.cuda.current_blas_handle()
|
||||||
torch.ops.sgl_kernel.bmm_fp8(
|
torch.ops.sgl_kernel.bmm_fp8.default(
|
||||||
A,
|
A,
|
||||||
B,
|
B,
|
||||||
D,
|
D,
|
||||||
@@ -91,7 +91,7 @@ def sgl_per_token_group_quant_fp8(
|
|||||||
fp8_min: float,
|
fp8_min: float,
|
||||||
fp8_max: float,
|
fp8_max: float,
|
||||||
) -> None:
|
) -> None:
|
||||||
torch.ops.sgl_kernel.sgl_per_token_group_quant_fp8(
|
torch.ops.sgl_kernel.sgl_per_token_group_quant_fp8.default(
|
||||||
input, output_q, output_s, group_size, eps, fp8_min, fp8_max
|
input, output_q, output_s, group_size, eps, fp8_min, fp8_max
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -105,7 +105,7 @@ def sgl_per_token_group_quant_int8(
|
|||||||
int8_min: float,
|
int8_min: float,
|
||||||
int8_max: float,
|
int8_max: float,
|
||||||
) -> None:
|
) -> None:
|
||||||
torch.ops.sgl_kernel.sgl_per_token_group_quant_int8(
|
torch.ops.sgl_kernel.sgl_per_token_group_quant_int8.default(
|
||||||
input, output_q, output_s, group_size, eps, int8_min, int8_max
|
input, output_q, output_s, group_size, eps, int8_min, int8_max
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -116,7 +116,9 @@ def sgl_per_tensor_quant_fp8(
|
|||||||
output_s: torch.Tensor,
|
output_s: torch.Tensor,
|
||||||
is_static: bool,
|
is_static: bool,
|
||||||
) -> None:
|
) -> None:
|
||||||
torch.ops.sgl_kernel.sgl_per_tensor_quant_fp8(input, output_q, output_s, is_static)
|
torch.ops.sgl_kernel.sgl_per_tensor_quant_fp8.default(
|
||||||
|
input, output_q, output_s, is_static
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def cublas_grouped_gemm(
|
def cublas_grouped_gemm(
|
||||||
@@ -129,7 +131,7 @@ def cublas_grouped_gemm(
|
|||||||
len(inputs) > 0 and len(weights) > 0 and len(outputs) > 0
|
len(inputs) > 0 and len(weights) > 0 and len(outputs) > 0
|
||||||
), "Inputs/weights/outputs should not be empty!"
|
), "Inputs/weights/outputs should not be empty!"
|
||||||
cublas_handle = torch.cuda.current_blas_handle()
|
cublas_handle = torch.cuda.current_blas_handle()
|
||||||
torch.ops.sgl_kernel.cublas_grouped_gemm(
|
torch.ops.sgl_kernel.cublas_grouped_gemm.default(
|
||||||
inputs,
|
inputs,
|
||||||
weights,
|
weights,
|
||||||
outputs,
|
outputs,
|
||||||
@@ -144,7 +146,7 @@ def sgl_per_token_quant_fp8(
|
|||||||
output_q: torch.Tensor,
|
output_q: torch.Tensor,
|
||||||
output_s: torch.Tensor,
|
output_s: torch.Tensor,
|
||||||
) -> None:
|
) -> None:
|
||||||
torch.ops.sgl_kernel.sgl_per_token_quant_fp8(input, output_q, output_s)
|
torch.ops.sgl_kernel.sgl_per_token_quant_fp8.default(input, output_q, output_s)
|
||||||
|
|
||||||
|
|
||||||
def cutlass_scaled_fp4_mm(
|
def cutlass_scaled_fp4_mm(
|
||||||
@@ -158,7 +160,7 @@ def cutlass_scaled_fp4_mm(
|
|||||||
assert a.ndim == 2 and b.ndim == 2
|
assert a.ndim == 2 and b.ndim == 2
|
||||||
m, n = a.shape[0], b.shape[0]
|
m, n = a.shape[0], b.shape[0]
|
||||||
out = torch.empty((m, n), dtype=out_dtype, device=a.device)
|
out = torch.empty((m, n), dtype=out_dtype, device=a.device)
|
||||||
torch.ops.sgl_kernels.cutlass_scaled_fp4_mm(
|
torch.ops.sgl_kernel.cutlass_scaled_fp4_mm.default(
|
||||||
out, a, b, block_scale_a, block_scale_b, alpha
|
out, a, b, block_scale_a, block_scale_b, alpha
|
||||||
)
|
)
|
||||||
return out
|
return out
|
||||||
@@ -210,7 +212,7 @@ def scaled_fp4_quant(
|
|||||||
(rounded_m, rounded_n // 4), device=device, dtype=torch.int32
|
(rounded_m, rounded_n // 4), device=device, dtype=torch.int32
|
||||||
)
|
)
|
||||||
|
|
||||||
torch.ops.sgl_kernels.scaled_fp4_quant(
|
torch.ops.sgl_kernel.scaled_fp4_quant.default(
|
||||||
output, input, output_scale, input_global_scale
|
output, input, output_scale, input_global_scale
|
||||||
)
|
)
|
||||||
output_scale = output_scale.view(torch.float8_e4m3fn)
|
output_scale = output_scale.view(torch.float8_e4m3fn)
|
||||||
|
|||||||
@@ -11,7 +11,7 @@ def moe_align_block_size(
|
|||||||
token_cnts_buffer,
|
token_cnts_buffer,
|
||||||
cumsum_buffer,
|
cumsum_buffer,
|
||||||
):
|
):
|
||||||
torch.ops.sgl_kernel.moe_align_block_size(
|
torch.ops.sgl_kernel.moe_align_block_size.default(
|
||||||
topk_ids,
|
topk_ids,
|
||||||
num_experts,
|
num_experts,
|
||||||
block_size,
|
block_size,
|
||||||
@@ -29,6 +29,6 @@ def topk_softmax(
|
|||||||
token_expert_indices: torch.Tensor,
|
token_expert_indices: torch.Tensor,
|
||||||
gating_output: float,
|
gating_output: float,
|
||||||
) -> None:
|
) -> None:
|
||||||
torch.ops.sgl_kernel.topk_softmax(
|
torch.ops.sgl_kernel.topk_softmax.default(
|
||||||
topk_weights, topk_ids, token_expert_indices, gating_output
|
topk_weights, topk_ids, token_expert_indices, gating_output
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -12,7 +12,7 @@ def _top_k_renorm_probs_internal(
|
|||||||
probs = probs.float()
|
probs = probs.float()
|
||||||
maybe_top_k_arr = maybe_top_k_arr.int() if maybe_top_k_arr is not None else None
|
maybe_top_k_arr = maybe_top_k_arr.int() if maybe_top_k_arr is not None else None
|
||||||
renorm_probs = torch.empty_like(probs)
|
renorm_probs = torch.empty_like(probs)
|
||||||
torch.ops.sgl_kernel.top_k_renorm_probs(
|
torch.ops.sgl_kernel.top_k_renorm_probs.default(
|
||||||
probs,
|
probs,
|
||||||
renorm_probs,
|
renorm_probs,
|
||||||
maybe_top_k_arr,
|
maybe_top_k_arr,
|
||||||
@@ -40,7 +40,7 @@ def _top_p_renorm_probs_internal(
|
|||||||
probs = probs.float()
|
probs = probs.float()
|
||||||
maybe_top_p_arr = maybe_top_p_arr.float() if maybe_top_p_arr is not None else None
|
maybe_top_p_arr = maybe_top_p_arr.float() if maybe_top_p_arr is not None else None
|
||||||
renorm_probs = torch.empty_like(probs)
|
renorm_probs = torch.empty_like(probs)
|
||||||
torch.ops.sgl_kernel.top_p_renorm_probs(
|
torch.ops.sgl_kernel.top_p_renorm_probs.default(
|
||||||
probs,
|
probs,
|
||||||
renorm_probs,
|
renorm_probs,
|
||||||
maybe_top_p_arr,
|
maybe_top_p_arr,
|
||||||
@@ -75,7 +75,7 @@ def _top_p_sampling_from_probs_internal(
|
|||||||
)
|
)
|
||||||
samples = torch.empty(probs.size(0), dtype=torch.int32, device=device)
|
samples = torch.empty(probs.size(0), dtype=torch.int32, device=device)
|
||||||
success = torch.empty(probs.size(0), dtype=torch.bool, device=device)
|
success = torch.empty(probs.size(0), dtype=torch.bool, device=device)
|
||||||
torch.ops.sgl_kernel.top_p_sampling_from_probs(
|
torch.ops.sgl_kernel.top_p_sampling_from_probs.default(
|
||||||
probs,
|
probs,
|
||||||
uniform_samples,
|
uniform_samples,
|
||||||
samples,
|
samples,
|
||||||
@@ -121,7 +121,7 @@ def _top_k_top_p_sampling_from_probs_internal(
|
|||||||
)
|
)
|
||||||
samples = torch.empty(probs.size(0), dtype=torch.int32, device=device)
|
samples = torch.empty(probs.size(0), dtype=torch.int32, device=device)
|
||||||
success = torch.empty(probs.size(0), dtype=torch.bool, device=device)
|
success = torch.empty(probs.size(0), dtype=torch.bool, device=device)
|
||||||
torch.ops.sgl_kernel.top_k_top_p_sampling_from_probs(
|
torch.ops.sgl_kernel.top_k_top_p_sampling_from_probs.default(
|
||||||
probs,
|
probs,
|
||||||
uniform_samples,
|
uniform_samples,
|
||||||
samples,
|
samples,
|
||||||
@@ -179,7 +179,7 @@ def _min_p_sampling_from_probs_internal(
|
|||||||
maybe_min_p_arr.float() if maybe_min_p_arr is not None else None
|
maybe_min_p_arr.float() if maybe_min_p_arr is not None else None
|
||||||
)
|
)
|
||||||
samples = torch.empty(probs.size(0), dtype=torch.int32, device=device)
|
samples = torch.empty(probs.size(0), dtype=torch.int32, device=device)
|
||||||
torch.ops.sgl_kernel.min_p_sampling_from_probs(
|
torch.ops.sgl_kernel.min_p_sampling_from_probs.default(
|
||||||
probs,
|
probs,
|
||||||
uniform_samples,
|
uniform_samples,
|
||||||
samples,
|
samples,
|
||||||
|
|||||||
@@ -17,7 +17,7 @@ def tree_speculative_sampling_target_only(
|
|||||||
threshold_acc: float = 1.0,
|
threshold_acc: float = 1.0,
|
||||||
deterministic: bool = True,
|
deterministic: bool = True,
|
||||||
) -> None:
|
) -> None:
|
||||||
torch.ops.sgl_kernel.tree_speculative_sampling_target_only(
|
torch.ops.sgl_kernel.tree_speculative_sampling_target_only.default(
|
||||||
predicts,
|
predicts,
|
||||||
accept_index,
|
accept_index,
|
||||||
accept_token_num,
|
accept_token_num,
|
||||||
@@ -45,7 +45,7 @@ def verify_tree_greedy(
|
|||||||
retrive_next_sibling: torch.Tensor,
|
retrive_next_sibling: torch.Tensor,
|
||||||
target_predict: torch.Tensor,
|
target_predict: torch.Tensor,
|
||||||
) -> None:
|
) -> None:
|
||||||
torch.ops.sgl_kernel.verify_tree_greedy(
|
torch.ops.sgl_kernel.verify_tree_greedy.default(
|
||||||
predicts,
|
predicts,
|
||||||
accept_index,
|
accept_index,
|
||||||
accept_token_num,
|
accept_token_num,
|
||||||
@@ -71,7 +71,7 @@ def build_tree_kernel_efficient(
|
|||||||
depth: int,
|
depth: int,
|
||||||
draft_token_num: int,
|
draft_token_num: int,
|
||||||
) -> None:
|
) -> None:
|
||||||
torch.ops.sgl_kernel.build_tree_kernel_efficient(
|
torch.ops.sgl_kernel.build_tree_kernel_efficient.default(
|
||||||
parent_list,
|
parent_list,
|
||||||
selected_index,
|
selected_index,
|
||||||
verified_seq_len,
|
verified_seq_len,
|
||||||
@@ -92,7 +92,7 @@ def segment_packbits(
|
|||||||
output_indptr: torch.Tensor,
|
output_indptr: torch.Tensor,
|
||||||
y: torch.Tensor,
|
y: torch.Tensor,
|
||||||
) -> None:
|
) -> None:
|
||||||
torch.ops.sgl_kernel.segment_packbits(
|
torch.ops.sgl_kernel.segment_packbits.default(
|
||||||
x,
|
x,
|
||||||
input_indptr,
|
input_indptr,
|
||||||
output_indptr,
|
output_indptr,
|
||||||
|
|||||||
Reference in New Issue
Block a user