init v0.11.0rc0
This commit is contained in:
@@ -40,7 +40,7 @@ std::tuple<at::Tensor, at::Tensor> rotary_embedding_meta(
|
||||
at::Tensor &positions,
|
||||
at::Tensor &query,
|
||||
at::Tensor &key,
|
||||
int64_t head_size,
|
||||
int64_t head_size,
|
||||
at::Tensor &cos_sin_cache,
|
||||
bool is_neox) {
|
||||
auto num_tokens = positions.sym_numel();
|
||||
@@ -86,9 +86,9 @@ at::Tensor sgmv_expand_meta(at::Tensor &x, at::Tensor &weight, at::Tensor &lora_
|
||||
} // namespace vllm_ascend
|
||||
|
||||
namespace {
|
||||
// Register the meta implementations of the custom kernels for symbolic tracing, this will also
|
||||
// Register the meta implementations of the custom kernels for symbolic tracing, this will also
|
||||
// the custom kernel been captured into aclgraph
|
||||
TORCH_LIBRARY_IMPL_EXPAND(_C, Meta, ops) {
|
||||
TORCH_LIBRARY_IMPL_EXPAND(CONCAT(_C, _ascend), Meta, ops) {
|
||||
// Rotary embedding meta implementation
|
||||
ops.impl("rotary_embedding", &vllm_ascend::meta::rotary_embedding_meta);
|
||||
// Masked input and mask meta implementation
|
||||
@@ -99,4 +99,4 @@ namespace {
|
||||
ops.impl("sgmv_expand", &vllm_ascend::meta::sgmv_expand_meta);
|
||||
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user