Fix bugs in operator registration with the PyTorch Dispatcher (#2786)
**Background:**
There are two principles for operator registration in PyTorch:
- The same namespace can only be registered once via `TORCH_LIBRARY`.
- The same operator signature can only be registered once via `def`.

Ideally, vLLM would define a common operator schema and every accelerator would implement it for its own hardware, which is conducive to functional abstraction. However, since all custom operators defined in the current repo are used only by Ascend, we can simply rename the operator registration namespace to an Ascend-specific one (**_C_ascend**), as sketched below.
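For illustration, a minimal sketch of the conflict and the fix, assuming a hypothetical op name `my_op` (only `TORCH_LIBRARY` and the `_C` / `_C_ascend` namespaces come from the actual change):

```cpp
#include <torch/library.h>

// vLLM itself already creates the "_C" namespace, roughly:
//   TORCH_LIBRARY(_C, ops) { ops.def("..."); }
// A second TORCH_LIBRARY(_C, ops) block in vllm-ascend would make the
// dispatcher error out at import time, because a namespace may only be
// created once (and each schema may only be `def`-ed once).

// The fix: put the Ascend custom ops in their own namespace.
TORCH_LIBRARY(_C_ascend, ops) {
  ops.def("my_op(Tensor input) -> Tensor");  // hypothetical schema
}
// (Extending an existing namespace from another file would instead use
// TORCH_LIBRARY_FRAGMENT.)
```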
Related issue: https://github.com/vllm-project/vllm-ascend/issues/2742
- vLLM version: main
- vLLM main: f592b3174b
Signed-off-by: FFFrog <ljw1101.vip@gmail.com>
@@ -141,7 +141,7 @@ std::tuple<at::Tensor, at::Tensor> get_masked_input_and_mask(
    TP2, rank 1:
    |< -----------BASE----------- >|< -BASE PADDING- >|< -----------LORA PADDING----------- >|
    corresponding token_id: | 512 | 513 | 514 | ... | 1009 | -1 | ... | -1 | -1 | ... | -1 | -1 | ... | -1 |
                     index: |  0  |  1  |  2  | ... | 497  | 498 | ... | 511 | 512 | ... | 519 | 520 | ... | 543 |
    Parameters:
        org_vocab_start_index //base embeddings start
        org_vocab_end_index //base embeddings end
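The diagram above describes how each rank remaps global token ids into its local embedding rows and masks out the tokens it does not own. As a rough scalar reference for that mapping (an approximation of the upstream vLLM logic for this op, not the AscendC kernel invoked below):

```cpp
#include <cstdint>

// Rough per-element reference for the layout above; names mirror the
// parameters listed in the comment, the formula approximates the upstream
// vLLM implementation of get_masked_input_and_mask.
struct MaskedToken { int64_t masked_input; bool mask; };

inline MaskedToken mask_one_token(int64_t token_id,
                                  int64_t org_vocab_start_index, int64_t org_vocab_end_index,
                                  int64_t num_org_vocab_padding,
                                  int64_t added_vocab_start_index, int64_t added_vocab_end_index) {
  const bool in_org = token_id >= org_vocab_start_index && token_id < org_vocab_end_index;
  const bool in_added = token_id >= added_vocab_start_index && token_id < added_vocab_end_index;
  // Added (LoRA) vocab rows start after the base rows plus the base padding.
  const int64_t added_offset = added_vocab_start_index -
      (org_vocab_end_index - org_vocab_start_index) - num_org_vocab_padding;
  const int64_t offset = in_org ? org_vocab_start_index : (in_added ? added_offset : 0);
  const bool owned = in_org || in_added;
  // Owned tokens become local row indices; everything else maps to row 0 and is
  // flagged in the mask so its embedding output can be zeroed afterwards.
  return {owned ? token_id - offset : 0, !owned};
}
```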
@@ -164,22 +164,22 @@ std::tuple<at::Tensor, at::Tensor> get_masked_input_and_mask(
    // Create output tensors
    at::Tensor masked_input = at::empty_like(input);
    at::Tensor mask = at::empty_like(input).to(at::kBool);

    // Get data pointers
    void *input_ptr = input.data_ptr();
    void *masked_input_ptr = masked_input.data_ptr();
    void *mask_ptr = mask.data_ptr();

    // Get current stream
    aclrtStream stream = c10_npu::getCurrentNPUStream().stream();

    // Get scalar type
    at::ScalarType scalar_type = input.scalar_type();

    // Create and configure OpCommand
    at_npu::native::OpCommand cmd;
    cmd.Name("get_masked_input_and_mask");
    cmd.SetCustomHandler([scalar_type, size, stream,
                          input_ptr, masked_input_ptr, mask_ptr,
                          org_vocab_start_index, org_vocab_end_index,
                          num_org_vocab_padding, added_vocab_start_index,
@@ -193,7 +193,7 @@ std::tuple<at::Tensor, at::Tensor> get_masked_input_and_mask(
        get_masked_input_and_mask_impl(
            stream,
            input_ptr,
            masked_input_ptr,
            mask_ptr,
            org_vocab_start_index,
            org_vocab_end_index,
@@ -203,7 +203,7 @@ std::tuple<at::Tensor, at::Tensor> get_masked_input_and_mask(
            size,
            loop_cnt,
            aiv_num);

        return 0;
    });
    cmd.Run();
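The surrounding code (context in these hunks, unchanged by the commit) follows the usual torch-npu custom-op pattern: collect raw data pointers, build an `at_npu::native::OpCommand`, and hand it a host lambda that launches the AscendC kernel on the current NPU stream. A condensed sketch of that pattern, with a hypothetical op `my_op` and launcher `my_op_impl` (headers as already included by this file):

```cpp
// Condensed sketch of the OpCommand dispatch pattern used above; `my_op` and
// `my_op_impl` are placeholders, the structure mirrors the real call sites.
void my_op_impl(aclrtStream stream, void *input, void *output);  // hypothetical AscendC launcher

void launch_my_op(const at::Tensor &input, at::Tensor &output) {
  void *input_ptr = input.data_ptr();
  void *output_ptr = output.data_ptr();
  aclrtStream stream = c10_npu::getCurrentNPUStream().stream();

  at_npu::native::OpCommand cmd;
  cmd.Name("my_op");
  cmd.SetCustomHandler([stream, input_ptr, output_ptr]() -> int {
    my_op_impl(stream, input_ptr, output_ptr);
    return 0;  // non-zero would be reported as a failure by the runtime
  });
  cmd.Run();  // enqueues the handler; execution is asynchronous on the NPU stream
}
```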
@@ -320,8 +320,8 @@ void sgmv_shrink(at::Tensor &x, at::Tensor &weight, at::Tensor &lora_indices, at
    aclrtStream stream = c10_npu::getCurrentNPUStream().stream();
    at_npu::native::OpCommand cmd;
    cmd.Name("sgmv_shrink");
    cmd.SetCustomHandler([scalar_type, stream, x_ptr, weight_ptr, lora_indices_ptr, lora_indices_size,
                          seq_len_ptr, seq_len_size, y_ptr,
                          batch_size, input_hidden_token, lora_rank, scale_f]() -> int {
        auto dtype = get_dtype_from_torch(scalar_type);
        int device_id = 0;
@@ -330,7 +330,7 @@ void sgmv_shrink(at::Tensor &x, at::Tensor &weight, at::Tensor &lora_indices, at
        int num_tokens_per_core = (batch_size + aiv_num - 1) / aiv_num;
        TORCH_CHECK("num_tokens_per_core != 0", "num_tokens_per_core should not be 0");
        sgmv_shrink_impl(dtype, stream, x_ptr, weight_ptr, lora_indices_ptr, lora_indices_size, seq_len_ptr, seq_len_size,
                         y_ptr, batch_size,
                         num_tokens_per_core, input_hidden_token, lora_rank, scale_f);
        return 0;
    });
@@ -367,7 +367,7 @@ at::Tensor sgmv_expand(at::Tensor &x, at::Tensor &weight, at::Tensor &lora_indic
    aclrtStream stream = c10_npu::getCurrentNPUStream().stream();
    at_npu::native::OpCommand cmd;
    cmd.Name("sgmv_expand");
    cmd.SetCustomHandler([scalar_type, stream, x_ptr, weight_ptr, lora_indices_ptr, lora_indices_size, seq_len_ptr, seq_len_size, y_ptr, y_out_ptr,
                          batch_size, lora_rank, slice_offset, slice_size, output_full_dim]() -> int {
        auto dtype = get_dtype_from_torch(scalar_type);
        int device_id = 0;
@@ -375,7 +375,7 @@ at::Tensor sgmv_expand(at::Tensor &x, at::Tensor &weight, at::Tensor &lora_indic
        TORCH_CHECK(aclGetDeviceCapability(device_id, ACL_DEVICE_INFO_VECTOR_CORE_NUM, &aiv_num) == ACL_SUCCESS);
        int num_tokens_per_core = (batch_size + aiv_num - 1) / aiv_num;
        TORCH_CHECK("num_tokens_per_core != 0", "num_tokens_per_core should not be 0");
        sgmv_expand_impl(dtype, stream, x_ptr, weight_ptr, lora_indices_ptr, lora_indices_size, seq_len_ptr, seq_len_size, y_ptr, y_out_ptr,
                         batch_size, num_tokens_per_core, lora_rank, slice_size, slice_offset, output_full_dim);
        return 0;
    });
@@ -384,7 +384,7 @@ at::Tensor sgmv_expand(at::Tensor &x, at::Tensor &weight, at::Tensor &lora_indic
 }
 } // namespace vllm_ascend

-TORCH_LIBRARY_EXPAND(_C, ops)
+TORCH_LIBRARY_EXPAND(CONCAT(_C, _ascend), ops)
 {
     // vLLM-Ascend custom ops
     ops.def("weak_ref_tensor(Tensor input) -> Tensor");
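The only functional change in this binding file is the namespace passed to the registration macro. A sketch of how the macros are assumed to expand (the authoritative definitions live in the repo's binding headers):

```cpp
// Assumed shape of the helper macros used above.
#define CONCAT_IMPL(A, B) A##B
#define CONCAT(A, B) CONCAT_IMPL(A, B)
#define TORCH_LIBRARY_EXPAND(NAME, MODULE) TORCH_LIBRARY(NAME, MODULE)

// Before: TORCH_LIBRARY_EXPAND(_C, ops)                  -> TORCH_LIBRARY(_C, ops)
// After:  TORCH_LIBRARY_EXPAND(CONCAT(_C, _ascend), ops) -> TORCH_LIBRARY(_C_ascend, ops)
//
// The ops are therefore reachable from Python as torch.ops._C_ascend.<op>
// instead of clashing with vLLM's own torch.ops._C namespace.
```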
@@ -40,7 +40,7 @@ std::tuple<at::Tensor, at::Tensor> rotary_embedding_meta(
    at::Tensor &positions,
    at::Tensor &query,
    at::Tensor &key,
    int64_t head_size,
    at::Tensor &cos_sin_cache,
    bool is_neox) {
    auto num_tokens = positions.sym_numel();
@@ -86,9 +86,9 @@ at::Tensor sgmv_expand_meta(at::Tensor &x, at::Tensor &weight, at::Tensor &lora_
} // namespace vllm_ascend

namespace {
// Register the meta implementations of the custom kernels for symbolic tracing, this will also
// the custom kernel been captured into aclgraph
-TORCH_LIBRARY_IMPL_EXPAND(_C, Meta, ops) {
+TORCH_LIBRARY_IMPL_EXPAND(CONCAT(_C, _ascend), Meta, ops) {
    // Rotary embedding meta implementation
    ops.impl("rotary_embedding", &vllm_ascend::meta::rotary_embedding_meta);
    // Masked input and mask meta implementation
@@ -99,4 +99,4 @@ namespace {
    ops.impl("sgmv_expand", &vllm_ascend::meta::sgmv_expand_meta);

}
}
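The `Meta` registrations above give each custom op a shape-only ("fake") implementation, which is what lets symbolic tracing (and hence aclgraph capture) run through the ops without touching the NPU. A minimal sketch of such a meta kernel and its registration, using a hypothetical op in the renamed namespace and the plain PyTorch macro rather than the repo's `TORCH_LIBRARY_IMPL_EXPAND` wrapper:

```cpp
#include <ATen/ATen.h>
#include <torch/library.h>

namespace {

// A meta ("fake") kernel only propagates dtype and (symbolic) shape; it never
// allocates or reads device memory. Hypothetical op for illustration.
at::Tensor my_op_meta(at::Tensor &input) {
  return at::empty_like(input);  // same shape/dtype, allocated on the meta device
}

TORCH_LIBRARY_IMPL(_C_ascend, Meta, ops) {
  ops.impl("my_op", &my_op_meta);  // pairs with the `def` in the _C_ascend library
}

}  // namespace
```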