Add greedy verification kernel (#4383)

This commit is contained in:
Ying Sheng
2025-03-16 00:58:26 -07:00
committed by GitHub
parent 06d12b39d3
commit 52a34d7448
11 changed files with 394 additions and 153 deletions

View File

@@ -14,7 +14,6 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "pytorch_extension_utils.h"
#include "speculative_sampling.cuh"
@@ -40,7 +39,9 @@ void tree_speculative_sampling_target_only(
at::Tensor uniform_samples,
at::Tensor target_probs,
at::Tensor draft_probs,
bool deterministic,
double threshold_single,
double threshold_acc,
bool deterministic = true,
int64_t cuda_stream = 0) {
CHECK_INPUT(candidates);
CHECK_INPUT(retrive_index);
@@ -112,6 +113,10 @@ void tree_speculative_sampling_target_only(
if (draft_probs.scalar_type() != at::kFloat) {
throw std::runtime_error("Expected 'target_probs' to be of type float (torch.float32).");
}
CHECK_GE(threshold_single, 0);
CHECK_GE(1, threshold_single);
CHECK_GE(threshold_acc, 0);
CHECK_GE(1, threshold_acc);
cudaStream_t stream = reinterpret_cast<cudaStream_t>(cuda_stream);
cudaError_t status = sampling::TreeSpeculativeSamplingTargetOnly<float, int>(
@@ -129,6 +134,8 @@ void tree_speculative_sampling_target_only(
num_spec_step,
num_draft_tokens,
vocab_size,
static_cast<float>(threshold_single),
static_cast<float>(threshold_acc),
deterministic,
stream);