Sync from v0.13

commit 5aef6c175a
parent b2ef04d792
date   2026-01-19 10:38:50 +08:00
3714 changed files with 854317 additions and 89342 deletions

csrc/rocm/attention.cu (new file, 3715 lines)

File diff suppressed because it is too large

csrc/rocm/ops.h (new file, 26 lines)

@@ -0,0 +1,26 @@
#pragma once
#include <torch/all.h>

#include <optional>  // std::optional, used in the declarations below

torch::Tensor LLMM1(at::Tensor& in_a, at::Tensor& in_b,
                    const int64_t rows_per_block);

torch::Tensor wvSplitK(const at::Tensor& in_a, const at::Tensor& in_b,
                       const std::optional<at::Tensor>& in_bias,
                       const int64_t CuCount);

void wvSplitKQ(const at::Tensor& in_a, const at::Tensor& in_b,
               const std::optional<at::Tensor>& in_bias, at::Tensor& out_c,
               const at::Tensor& scale_a, const at::Tensor& scale_b,
               const int64_t CuCount);

void paged_attention(
    torch::Tensor& out, torch::Tensor& exp_sums, torch::Tensor& max_logits,
    torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache,
    torch::Tensor& value_cache, int64_t num_kv_heads, double scale,
    torch::Tensor& block_tables, torch::Tensor& seq_lens,
    const std::optional<torch::Tensor>& query_start_loc, int64_t block_size,
    int64_t max_seq_len, const std::optional<torch::Tensor>& alibi_slopes,
    const std::string& kv_cache_dtype, torch::Tensor& k_scale,
    torch::Tensor& v_scale, const std::optional<torch::Tensor>& fp8_out_scale,
    const std::string& mfma_type);
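
The header above is enough to call these kernels directly from C++ when linking against the extension. Below is a hypothetical smoke test, not part of this diff: the shape conventions for in_a/in_b and the semantics of rows_per_block and CuCount are assumptions inferred from the op names, so treat it as orientation only.

#include <optional>

#include <torch/all.h>

#include "rocm/ops.h"

int main() {
  // Assumed convention: in_a is a large weight matrix, in_b a skinny input.
  auto opts = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA);
  at::Tensor a = at::randn({4096, 4096}, opts);
  at::Tensor b = at::randn({4096, 1}, opts);

  // rows_per_block is a kernel tuning knob (assumed); 2 is only a placeholder.
  at::Tensor out1 = LLMM1(a, b, /*rows_per_block=*/2);

  // CuCount should match the device's compute-unit count (assumed value: 80).
  at::Tensor out2 = wvSplitK(a, b, /*in_bias=*/std::nullopt, /*CuCount=*/80);
  return 0;
}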

csrc/rocm/skinny_gemms.cu (new file, 1810 lines)

File diff suppressed because it is too large
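
The suppressed file holds the kernels behind LLMM1, wvSplitK and wvSplitKQ. For orientation only: "split-K" conventionally means partitioning the reduction (K) dimension across workgroups and summing the partial results afterwards, which keeps more compute units busy when the output matrix is too skinny to parallelize any other way. A toy CPU sketch of that idea, every name hypothetical:

#include <cstddef>
#include <vector>

float splitk_dot(const std::vector<float>& a, const std::vector<float>& b,
                 std::size_t num_splits) {
  const std::size_t k = a.size();
  std::vector<float> partial(num_splits, 0.0f);
  for (std::size_t s = 0; s < num_splits; ++s) {
    const std::size_t begin = s * k / num_splits;  // this split's K-slice
    const std::size_t end = (s + 1) * k / num_splits;
    for (std::size_t i = begin; i < end; ++i) partial[s] += a[i] * b[i];
  }
  float acc = 0.0f;  // second pass: reduce the per-split partial sums
  for (float p : partial) acc += p;
  return acc;
}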

csrc/rocm/torch_bindings.cpp (new file, 57 lines)

@@ -0,0 +1,57 @@
#include "core/registration.h"
#include "rocm/ops.h"
// Note on op signatures:
// The X_meta signatures are for the meta functions corresponding to op X.
// They must be kept in sync with the signature for X. Generally, only
// functions that return Tensors require a meta function.
//
// See the following links for detailed docs on op registration and function
// schemas.
// https://docs.google.com/document/d/1_W62p8WJOQQUzPsJYa7s701JXt0qf2OfLub2sbkHOaU/edit#heading=h.ptttacy8y1u9
// https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/README.md#annotations
TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, rocm_ops) {
  // vLLM custom ops for ROCm

  // Custom GEMM op for matrix-vector multiplication
  rocm_ops.def(
      "LLMM1(Tensor in_a, Tensor in_b, int rows_per_block) -> "
      "Tensor");
  rocm_ops.impl("LLMM1", torch::kCUDA, &LLMM1);

  // Custom GEMM op for skinny matrix-matrix multiplication
  rocm_ops.def(
      "wvSplitK(Tensor in_a, Tensor in_b, Tensor? in_bias, int CuCount) -> "
      "Tensor");
  rocm_ops.impl("wvSplitK", torch::kCUDA, &wvSplitK);

  // wvSplitK for fp8
  rocm_ops.def(
      "wvSplitKQ(Tensor in_a, Tensor in_b, Tensor? in_bias, Tensor! out_c, "
      "Tensor scale_a, Tensor scale_b, int CuCount) -> ()");
  rocm_ops.impl("wvSplitKQ", torch::kCUDA, &wvSplitKQ);

  // Custom attention op:
  // compute the attention between an input query and the cached
  // keys/values using PagedAttention.
  rocm_ops.def(
      "paged_attention(Tensor! out, Tensor exp_sums,"
      "                Tensor max_logits, Tensor tmp_out,"
      "                Tensor query, Tensor key_cache,"
      "                Tensor value_cache, int num_kv_heads,"
      "                float scale, Tensor block_tables,"
      "                Tensor seq_lens,"
      "                Tensor? query_start_loc,"
      "                int block_size,"
      "                int max_seq_len,"
      "                Tensor? alibi_slopes,"
      "                str kv_cache_dtype,"
      "                Tensor k_scale, Tensor v_scale,"
      "                Tensor? fp8_out_scale,"
      "                str mfma_type) -> ()");
  rocm_ops.impl("paged_attention", torch::kCUDA, &paged_attention);
}

REGISTER_EXTENSION(TORCH_EXTENSION_NAME)
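
Once REGISTER_EXTENSION has run, the ops above are reachable through the PyTorch dispatcher rather than by direct linkage. A minimal sketch of a dispatcher-side call, assuming the extension is built with TORCH_EXTENSION_NAME set to _rocm_C (a common vLLM setting, not confirmed by this diff):

#include <ATen/core/dispatch/Dispatcher.h>
#include <torch/all.h>

at::Tensor call_llmm1(const at::Tensor& in_a, const at::Tensor& in_b,
                      int64_t rows_per_block) {
  // "_rocm_C::LLMM1" pairs the assumed extension name with the schema name
  // registered in TORCH_LIBRARY_EXPAND above.
  static auto op =
      c10::Dispatcher::singleton()
          .findSchemaOrThrow("_rocm_C::LLMM1", /*overload_name=*/"")
          .typed<at::Tensor(const at::Tensor&, const at::Tensor&, int64_t)>();
  return op.call(in_a, in_b, rows_per_block);
}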