Add greedy verification kernel (#4383)
This commit is contained in:
@@ -14,7 +14,6 @@
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "pytorch_extension_utils.h"
|
||||
#include "speculative_sampling.cuh"
|
||||
|
||||
@@ -40,7 +39,9 @@ void tree_speculative_sampling_target_only(
|
||||
at::Tensor uniform_samples,
|
||||
at::Tensor target_probs,
|
||||
at::Tensor draft_probs,
|
||||
bool deterministic,
|
||||
double threshold_single,
|
||||
double threshold_acc,
|
||||
bool deterministic = true,
|
||||
int64_t cuda_stream = 0) {
|
||||
CHECK_INPUT(candidates);
|
||||
CHECK_INPUT(retrive_index);
|
||||
@@ -112,6 +113,10 @@ void tree_speculative_sampling_target_only(
|
||||
if (draft_probs.scalar_type() != at::kFloat) {
|
||||
throw std::runtime_error("Expected 'target_probs' to be of type float (torch.float32).");
|
||||
}
|
||||
CHECK_GE(threshold_single, 0);
|
||||
CHECK_GE(1, threshold_single);
|
||||
CHECK_GE(threshold_acc, 0);
|
||||
CHECK_GE(1, threshold_acc);
|
||||
|
||||
cudaStream_t stream = reinterpret_cast<cudaStream_t>(cuda_stream);
|
||||
cudaError_t status = sampling::TreeSpeculativeSamplingTargetOnly<float, int>(
|
||||
@@ -129,6 +134,8 @@ void tree_speculative_sampling_target_only(
|
||||
num_spec_step,
|
||||
num_draft_tokens,
|
||||
vocab_size,
|
||||
static_cast<float>(threshold_single),
|
||||
static_cast<float>(threshold_acc),
|
||||
deterministic,
|
||||
stream);
|
||||
|
||||
|
||||
Reference in New Issue
Block a user