[qwen3 next] add Ascend C causal_conv1d_fn (#6661)
### What this PR does / why we need it?
add Ascend C causal_conv1d_fn
- vLLM version: v0.15.0
- vLLM main:
13397841ab
---------
Signed-off-by: ZT-AIA <1028681969@qq.com>
Signed-off-by: ZT-AIA <63220130+ZT-AIA@users.noreply.github.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
This commit is contained in:
@@ -597,6 +597,44 @@ void transpose_kv_cache_by_block(
|
||||
|
||||
}
|
||||
|
||||
// Causal 1-D convolution for the Qwen3-Next linear-attention path, dispatched
// to the Ascend C kernel `aclnnCausalConv1d` via EXEC_NPU_CMD.
//
// Parameters:
//   mixed_qkv_non_spec_T           - input activations (already transposed by the
//                                    caller; no further transpose is needed here)
//   conv_weights                   - convolution weights (likewise pre-transposed)
//   bias_opt                       - optional convolution bias
//   activation                     - activation name; only empty vs. non-empty is
//                                    forwarded to the kernel (see note below)
//   conv_state                     - per-sequence convolution state cache
//   has_initial_state              - per-sequence flags: state cache is valid
//   non_spec_state_indices_tensor  - maps each sequence to its slot in conv_state
//   non_spec_query_start_loc       - cumulative query start offsets per sequence
//   pad_slot_id                    - sentinel slot id marking padded entries
//
// Returns: a new tensor with the same shape/dtype/device as the input,
// holding the convolution output.
at::Tensor causal_conv1d_fn(
    const at::Tensor& mixed_qkv_non_spec_T,
    const at::Tensor& conv_weights,
    const c10::optional<at::Tensor>& bias_opt,
    c10::string_view activation,
    const at::Tensor& conv_state,
    const at::Tensor& has_initial_state,
    const at::Tensor& non_spec_state_indices_tensor,
    const at::Tensor& non_spec_query_start_loc,
    int64_t pad_slot_id)
{
    // Bind the incoming const references to plain local lvalues for the
    // EXEC_NPU_CMD macro. Tensor "copies" are shallow (refcount bumps only),
    // so this costs nothing. No transpose is required for input or weights.
    at::Tensor input = mixed_qkv_non_spec_T;
    at::Tensor weight = conv_weights;
    c10::optional<at::Tensor> biasOptional = bias_opt;
    at::Tensor convStates = conv_state;
    at::Tensor queryStartLoc = non_spec_query_start_loc;
    at::Tensor cacheIndices = non_spec_state_indices_tensor;
    at::Tensor hasInitialState = has_initial_state;

    // The kernel receives only a binary flag: 0 = no activation, 1 = activation
    // enabled. The activation *name* is dropped here.
    // NOTE(review): presumably the kernel applies SiLU when the flag is 1 —
    // confirm against the aclnnCausalConv1d contract.
    int64_t activationMode = activation.empty() ? 0 : 1;
    int64_t padSlotId = pad_slot_id;

    // Output mirrors the input's shape, dtype and device.
    at::Tensor output = at::empty(mixed_qkv_non_spec_T.sizes(), mixed_qkv_non_spec_T.options());

    EXEC_NPU_CMD(aclnnCausalConv1d,
                 input,
                 weight,
                 biasOptional,
                 convStates,
                 queryStartLoc,
                 cacheIndices,
                 hasInitialState,
                 activationMode,
                 padSlotId,
                 output);

    return output;
}
|
||||
|
||||
// Further improvements are expected after this is incorporated into CANN on June 30th.
|
||||
std::vector<at::Tensor> moe_grouped_matmul(
|
||||
at::Tensor x,
|
||||
@@ -811,6 +849,18 @@ TORCH_LIBRARY_EXPAND(CONCAT(_C, _ascend), ops)
|
||||
"transpose_kv_cache_by_block(Tensor[] kCache, Tensor[] vCache, Tensor blockIDs, int blockSize, int headNum, int headDim, int splitNum, int layerNum) -> ()"
|
||||
);
|
||||
ops.impl("transpose_kv_cache_by_block", torch::kPrivateUse1, &vllm_ascend::transpose_kv_cache_by_block);
|
||||
// causal_conv1d_fn
|
||||
ops.def(
|
||||
"causal_conv1d_fn(Tensor mixed_qkv_non_spec_T, "
|
||||
" Tensor conv_weights, "
|
||||
" Tensor? bias_opt, "
|
||||
" str activation, "
|
||||
" Tensor conv_state, "
|
||||
" Tensor has_initial_state, "
|
||||
" Tensor non_spec_state_indices_tensor, "
|
||||
" Tensor non_spec_query_start_loc, "
|
||||
" int pad_slot_id) -> (Tensor output)");
|
||||
ops.impl("causal_conv1d_fn", torch::kPrivateUse1, &vllm_ascend::causal_conv1d_fn);
|
||||
ops.def(
|
||||
"moe_grouped_matmul("
|
||||
"Tensor x,"
|
||||
|
||||
Reference in New Issue
Block a user