// xc-llm-kunlun/vllm_kunlun/csrc/xops.h
#ifndef OPS_H
#define OPS_H
#include <torch/extension.h>
#include <c10/cuda/CUDAStream.h>
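
// RMS normalization over the last dimension:
//   output = input / sqrt(mean(input^2) + eps) * weight
// A minimal usage sketch, assuming the tensors already live on the XPU device
// (shapes follow the fused_add_rms_norm_xpu declaration below):
//
//   auto output = torch::empty_like(input);  // input: [..., hidden_size]
//   rms_norm_xpu(output, input, weight, /*eps=*/1e-6);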
void rms_norm_xpu(torch::Tensor& output,
                  torch::Tensor& input,
                  torch::Tensor& weight,
                  double eps);
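
// Fused residual-add + RMS norm. Upstream vLLM computes
//   residual += input; input = rms_norm(residual, weight, epsilon)
// with both tensors updated in place; the XPU kernel presumably follows the
// same contract.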
// in-place: both input and residual are updated
void fused_add_rms_norm_xpu(torch::Tensor& input,     // [..., hidden_size]
                            torch::Tensor& residual,  // [..., hidden_size]
                            torch::Tensor& weight,    // [hidden_size]
                            double epsilon);
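
// Gated activation for LLaMA-style MLPs: the last dimension of input is
// split in half and output = silu(x1) * x2, so output has half the trailing
// dimension of input (the standard vLLM contract, assumed to hold here):
//
//   // input: [num_tokens, 2 * d]  ->  output: [num_tokens, d]
//   auto output = torch::empty({input.size(0), input.size(1) / 2}, input.options());
//   silu_and_mul_xpu(output, input);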
void silu_and_mul_xpu(torch::Tensor& output,
                      torch::Tensor& input);
void quick_gelu_xpu(torch::Tensor& output,
                    torch::Tensor& input);
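
// Rotary position embeddings, applied in place to query and key.
// cos_sin_cache holds precomputed cos/sin values per position
// ([max_position, rot_dim] in upstream vLLM; assumed to match here).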
// Supports both the GPT-NeoX and GPT-J rotary styles, selected via is_neox.
void rotary_embedding(torch::Tensor& positions,
                      torch::Tensor& query,
                      torch::Tensor& key,
                      int64_t head_size,
                      torch::Tensor& cos_sin_cache,
                      bool is_neox);
void batched_rotary_embedding(torch::Tensor& positions,
                              torch::Tensor& query,
                              torch::Tensor& key,
                              int64_t head_size,
                              torch::Tensor& cos_sin_cache,
                              bool is_neox,
                              int64_t rot_dim,
                              torch::Tensor& offsets);
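
// Single-pass paged attention over the block-based KV cache (decode phase).
// A hedged calling sketch; seq_lens_host presumably mirrors seq_lens on the
// CPU side, and the trailing blocksparse arguments are ignored:
//
//   paged_attention_v1_xpu(out, query, key_cache, value_cache, num_kv_heads,
//                          scale, block_tables, seq_lens, seq_lens_host,
//                          block_size, max_seq_len,
//                          /*alibi_slopes=*/c10::nullopt,
//                          /*kv_cache_dtype=*/"auto",
//                          /*k_scale=*/1.0, /*v_scale=*/1.0,
//                          /*tp_rank=*/0, 0, 0, 0, 0);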
// x = 16 / sizeof(cache dtype); x appears in the key_cache layout of
// reshape_and_cache below.
void paged_attention_v1_xpu(
    torch::Tensor& out,          // [num_seqs, num_heads, head_size]
    torch::Tensor& query,        // [num_seqs, num_heads, head_size]
    torch::Tensor& key_cache,    // [num_blocks, num_kv_heads, block_size, head_size]
    torch::Tensor& value_cache,  // [num_blocks, num_kv_heads, block_size, head_size]
    int64_t num_kv_heads,
    double scale,
    torch::Tensor& block_tables,  // [num_seqs, max_num_blocks_per_seq]
    torch::Tensor& seq_lens,      // [num_seqs]
    torch::Tensor& seq_lens_host, // [num_seqs]
    int64_t block_size,
    int64_t max_seq_len,
    const c10::optional<torch::Tensor>& alibi_slopes,  // [num_heads]
    const std::string& kv_cache_dtype,
    double k_scale,
    double v_scale,
    // The blocksparse parameters below are unused here; they are kept only to
    // match the official vLLM signature.
    int64_t tp_rank,
    int64_t blocksparse_local_blocks,
    int64_t blocksparse_vert_stride,
    int64_t blocksparse_block_size,
    int64_t blocksparse_head_sliding_step);
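
// Scatters new per-token key/value states into the paged KV cache at the
// positions given by slot_mapping (slot = block_id * block_size + offset in
// upstream vLLM; assumed to match here). Note the 5-D key_cache layout using
// the x factor defined above.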
void reshape_and_cache(
    torch::Tensor& key,          // [num_tokens, num_heads, head_size]
    torch::Tensor& value,        // [num_tokens, num_heads, head_size]
    torch::Tensor& key_cache,    // [num_blocks, num_heads, head_size/x, block_size, x]
    torch::Tensor& value_cache,  // [num_blocks, num_heads, head_size, block_size]
    torch::Tensor& slot_mapping, // [num_tokens]
    const std::string& kv_cache_dtype,
    const double k_scale,
    const double v_scale);
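
// Variable-length (prefill/context) attention. seq_lod is a CSR-style
// offsets array over the flattened token dimension, so the tokens of
// sequence i occupy [seq_lod[i], seq_lod[i+1]). The optional cache,
// block-table, and prefix arguments appear to enable prefix caching, where a
// previously cached KV prefix is attended to in addition to the in-batch
// keys and values.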
void flash_attention_context_vllm_xpu(
    torch::Tensor& query,        // [num_tokens, num_heads, head_size]
    torch::Tensor& key,          // [num_tokens, num_kv_heads, head_size]
    torch::Tensor& value,        // [num_tokens, num_kv_heads, head_size]
    torch::Tensor& out,          // [num_tokens, num_heads, head_size]
    torch::Tensor& seq_lod,      // [batch_size + 1]
    torch::Tensor& seq_lod_host, // [batch_size + 1]
    int64_t max_seq_len,
    int64_t max_kv_len,
    double scale,
    const c10::optional<torch::Tensor>& alibi_slopes, // [num_heads]
    const c10::optional<torch::Tensor>& key_cache,    // [num_blocks, num_kv_heads, block_size, head_size]
    const c10::optional<torch::Tensor>& value_cache,  // [num_blocks, num_kv_heads, block_size, head_size]
    const c10::optional<torch::Tensor>& block_tables, // [num_seqs, max_num_blocks_per_seq]
    const c10::optional<torch::Tensor>& kv_prefix_start_loc,      // [lod of prefix]
    const c10::optional<torch::Tensor>& kv_prefix_start_loc_host, // [lod of prefix]
    const c10::optional<bool> is_causal  // whether to apply a causal mask; defaults to true
);
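
// Two-pass paged attention for long sequences: per-partition partial results
// are written to tmp_out along with exp_sums/max_logits, then reduced into
// out (the upstream vLLM v2 scheme; assumed to carry over to this kernel).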
void paged_attention_v2_xpu(
    torch::Tensor& out,
    torch::Tensor& exp_sums,
    torch::Tensor& max_logits,
    torch::Tensor& tmp_out,
    torch::Tensor& query,        // [num_seqs, num_heads, head_size]
    torch::Tensor& key_cache,    // [num_blocks, num_kv_heads, block_size, head_size]
    torch::Tensor& value_cache,  // [num_blocks, num_kv_heads, block_size, head_size]
    int64_t num_kv_heads,
    double scale,
    torch::Tensor& block_tables,  // [num_seqs, max_num_blocks_per_seq]
    torch::Tensor& seq_lens,      // [num_seqs]
    torch::Tensor& seq_lens_host, // [num_seqs]
    int64_t block_size,
    int64_t max_seq_len,
    const c10::optional<torch::Tensor>& alibi_slopes,  // [num_heads]
    const std::string& kv_cache_dtype,
    double k_scale,
    double v_scale,
    // The blocksparse parameters below are unused here; they are kept only to
    // match the official vLLM signature.
    int64_t tp_rank,
    int64_t blocksparse_local_blocks,
    int64_t blocksparse_vert_stride,
    int64_t blocksparse_block_size,
    int64_t blocksparse_head_sliding_step);
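
// Weight-only quantized matmul; presumably out = x @ dequant(qweight, qscale).
// The packing of qweight and the granularity of qscale are kernel-specific
// and not documented here.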
void weight_only_quant_matmul_xpu(torch::Tensor& x,
                                  torch::Tensor& out,
                                  torch::Tensor& qweight,
                                  torch::Tensor& qscale);
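
// Multi-head latent attention (MLA, as introduced by DeepSeek-V2) over a
// cache that presumably stores the compressed latent plus the RoPE portion
// of the keys; block_tables/seq_lens address the paged cache as in the
// paged-attention kernels above.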
void multi_latent_attention_xpu(torch::Tensor q,
                                torch::Tensor kv_rope_cache,
                                torch::Tensor out,
                                torch::Tensor block_tables,
                                torch::Tensor seq_lens,
                                double scale,
                                int64_t max_seq_len);
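
// Fused MoE expert computation: for each token, the experts selected in
// topk_ids are applied (w1 -> activation -> w2) and the results combined
// with topk_weights. The _sorted variant presumably expects token/expert
// pairs pre-sorted by expert id for better memory locality.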
void outplace_fused_experts_xpu(torch::Tensor& hidden_states,
                                torch::Tensor& output,
                                torch::Tensor& w1,
                                torch::Tensor& w2,
                                torch::Tensor& topk_weights,
                                torch::Tensor& topk_ids);
void outplace_fused_experts_sorted_xpu(torch::Tensor& hidden_states,
                                       torch::Tensor& output,
                                       torch::Tensor& w1,
                                       torch::Tensor& w2,
                                       torch::Tensor& topk_weights,
                                       torch::Tensor& topk_ids);
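
// MoE routing helpers. grouped_topk_xpu implements grouped top-k routing in
// the DeepSeek style (select moe_topk_group groups out of expert_group_num,
// then moe_top_k experts within them, with score_bias added to the scores);
// topk_softmax_xpu is the plain softmax-then-top-k router. A minimal sketch
// for the latter; the int32 index dtype and topk value are assumptions:
//
//   int64_t m = gating_output.size(0), topk = 2;
//   auto topk_weights = torch::empty({m, topk}, gating_output.options());
//   auto topk_indices = torch::empty({m, topk},
//                                    gating_output.options().dtype(torch::kInt32));
//   auto token_expert_indices = torch::empty_like(topk_indices);  // unused on XPU
//   topk_softmax_xpu(topk_weights, topk_indices, token_expert_indices,
//                    gating_output);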
void grouped_topk_xpu(torch::Tensor& router_logits,
                      torch::Tensor& score_bias,
                      torch::Tensor& topk_weight,
                      torch::Tensor& topk_ids,
                      double scale,
                      int64_t expert_group_num,
                      int64_t moe_topk_group,
                      int64_t moe_top_k);
void topk_softmax_xpu(torch::Tensor& topk_weights,         /* [m, topk] */
                      torch::Tensor& topk_indices,         /* [m, topk] */
                      torch::Tensor& token_expert_indices, /* unused on XPU */
                      torch::Tensor& gating_output         /* [m, n] */
);
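
// Returns a tensor aliasing the same storage without owning it, mirroring
// upstream vLLM's weak_ref_tensor helper (used there for CUDA-graph-safe
// buffer reuse).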
torch::Tensor weak_ref_tensor(torch::Tensor& tensor);
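
// Dynamic INT8 activation quantization: scales are presumably computed from
// x on the fly and written to input_scale, with input_azp as the optional
// asymmetric zero point. cutlass_scaled_mm_xpu mirrors the upstream vLLM
// cutlass_scaled_mm interface: out = (a * a_scales) @ (b * b_scales) + bias,
// with per-tensor or per-row/column scales.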
void dynamic_scaled_int8_quant_xpu(torch::Tensor& out,
                                   torch::Tensor& x,
                                   torch::Tensor& input_scale,
                                   const c10::optional<torch::Tensor>& input_azp);
void cutlass_scaled_mm_xpu(torch::Tensor& out,
                           torch::Tensor const& a,
                           torch::Tensor const& b,
                           torch::Tensor const& a_scales,
                           torch::Tensor const& b_scales,
                           std::optional<torch::Tensor> const& bias);
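
// The "castte" ops appear to be scaled casts of activations into the
// quantized compute dtype: castte_xpu uses a single per-tensor scale,
// castte_per_token_xpu one scale per token row. The target dtype is
// determined by the output tensor and is not documented here.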
void castte_xpu(torch::Tensor& input,   // [num_tokens, hidden_dim]
                torch::Tensor& output,  // [num_tokens, hidden_dim]
                torch::Tensor& scale    // [1]
);
void castte_per_token_xpu(torch::Tensor& input,   // [num_tokens, hidden_dim]
                          torch::Tensor& output,  // [num_tokens, hidden_dim]
                          torch::Tensor& scale    // [num_tokens]
);
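
// Fully connected layers fused with the scaled cast above: x is presumably
// quantized with x_scale, multiplied by qweight (dequantized via qscale),
// and an optional bias is added; the per-token variant takes one x_scale
// entry per row.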
void fc_fusion_castte_xpu(
    torch::Tensor& x,        // [num_tokens, in_dim]
    torch::Tensor& output,   // [num_tokens, out_dim]
    torch::Tensor& x_scale,  // [1]
    torch::Tensor& qweight,  // [out_dim, in_dim]
    torch::Tensor& qscale,   // [1]
    const c10::optional<torch::Tensor>& bias  // [out_dim]
);
void fc_fusion_castte_per_token_xpu(
    torch::Tensor& x,        // [num_tokens, in_dim]
    torch::Tensor& output,   // [num_tokens, out_dim]
    torch::Tensor& x_scale,  // [num_tokens]
    torch::Tensor& qweight,  // [out_dim, in_dim]
    torch::Tensor& qscale,   // [1]
    const c10::optional<torch::Tensor>& bias  // [out_dim]
);
// Trivial CUTLASS capability checks; the cuda_device_capability argument is
// kept to match the upstream vLLM interface.
bool cutlass_scaled_mm_supports_fp8_xpu(int64_t cuda_device_capability);
bool cutlass_scaled_mm_supports_block_fp8_xpu(int64_t cuda_device_capability);
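
// Fused QKV split + norm + rotary embedding: slices qkv into Q/K/V, applies
// the q_weight/k_weight norms (presumably RMS norm over head_dim, as in
// models with q/k normalization) and rotary embeddings, and writes the three
// outputs separately.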
void outplace_split_norm_rope_xpu(torch::Tensor& qkv,
                                  torch::Tensor& cos_sin_cache,
                                  torch::Tensor& q_weight,
                                  torch::Tensor& k_weight,
                                  torch::Tensor& positions,
                                  torch::Tensor& q_emb_out,
                                  torch::Tensor& k_emb_out,
                                  torch::Tensor& v_out,
                                  const int64_t emb_batch_size,
                                  const int64_t max_seqlen,
                                  const int64_t head_num,
                                  const int64_t kv_head_num,
                                  const int64_t head_dim,
                                  const int64_t rotary_dim);
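
// INT8 weight-quantized variant of the fused-experts op above: w1/w2 are
// INT8 expert weights with scales w1_scale/w2_scale (granularity is
// kernel-specific); hidden_states is bfloat16 as noted on the parameter.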
void moe_fc_int8(torch::Tensor& hidden_states,  // dtype: bfloat16
                 torch::Tensor& output,
                 torch::Tensor& w1,
                 torch::Tensor& w1_scale,
                 torch::Tensor& w2,
                 torch::Tensor& w2_scale,
                 torch::Tensor& topk_weights,
                 torch::Tensor& topk_ids);
#endif // OPS_H