add custom ascendc kernel vocabparallelembedding (#796)
This PR adds custom AscendC kernel VocabParallelEmbedding support to vllm-ascend; the related CMakeLists and setuptools changes are included as well.

Run the benchmark and the tests with:

pytest -s benchmarks/ops/ben_vocabparallelembedding.py
pytest -s tests/ops/test_vocabparallelembedding.py

---------
Signed-off-by: ttanzhiqiang <389825161@qq.com>
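For illustration, a minimal Python sketch of how the new op can be invoked once the compiled extension is loaded. The import path and the torch.ops._C namespace are assumptions based on the TORCH_LIBRARY registration shown in this diff; the pytest files above remain the authoritative usage. The index arguments follow the TP1 example in the kernel's docstring:

    import torch
    import torch_npu    # noqa: F401  (assumed: provides the "npu" device)
    import vllm_ascend  # noqa: F401  (assumed: importing the package loads the _C extension)

    # Two base-vocab ids, one LoRA-added id, one out-of-range id.
    input_ids = torch.tensor([0, 1009, 1010, 4000], device="npu")
    masked_input, mask = torch.ops._C.get_masked_input_and_mask(
        input_ids,
        0,     # org_vocab_start_index
        1010,  # org_vocab_end_index
        14,    # num_org_vocab_padding (1010 padded up to 1024)
        1010,  # added_vocab_start_index
        1026)  # added_vocab_end_index
    # mask flags ids outside both vocab ranges; masked_input holds the
    # corresponding local indices into the padded embedding shard
    # (zero where masked, per the reference semantics).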
@@ -99,6 +99,112 @@ std::tuple<at::Tensor, at::Tensor> rotary_embedding(at::Tensor &positions, at::T
  return {query_dst, key_dst};
}

std::tuple<at::Tensor, at::Tensor> get_masked_input_and_mask(
    at::Tensor &input,
    const int64_t org_vocab_start_index,
    const int64_t org_vocab_end_index,
    const int64_t num_org_vocab_padding,
    const int64_t added_vocab_start_index,
    const int64_t added_vocab_end_index)
/*
  https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/vocab_parallel_embedding.py#L161-L198
  Embedding parallelized in the vocabulary dimension.

  Adapted from torch.nn.Embedding, note that we pad the vocabulary size to
  make sure it is divisible by the number of model parallel GPUs.

  In order to support various loading methods, we ensure that LoRA-added
  embeddings are always at the end of TP-sharded tensors. In other words,
  we shard base embeddings and LoRA embeddings separately (both padded),
  and place them in the same tensor.

  In this example, we will have the original vocab size = 1010,
  added vocab size = 16 and padding to 64. Therefore, the total
  vocab size with padding will be 1088 (because we first pad 1010 to
  1024, add 16, and then pad to 1088).
  Therefore, the tensor format looks like the following:

  TP1, rank 0 (no sharding):
                          |< --------BASE-------- >|< -BASE PADDING-- >|< -----LORA------ >|< -LORA PADDING-- >|
  corresponding token_id: |  0  |  1  | ... | 1009 |  -1  | ... |  -1  | 1010 | ... | 1015 |  -1  | ... |  -1  |
                   index: |  0  |  1  | ... | 1009 | 1010 | ... | 1023 | 1024 | ... | 1039 | 1040 | ... | 1087 |

  TP2, rank 0:
                          |< ------------------BASE------------------ >|< -----LORA------ >|< -LORA PADDING- >|
  corresponding token_id: |  0  |  1  |  2  | ... | 497 | 498 | ... | 511 | 1000 | ... | 1015 |  -1 | ... |  -1 |
                   index: |  0  |  1  |  2  | ... | 497 | 498 | ... | 511 |  512 | ... |  527 | 528 | ... | 543 |

  TP2, rank 1:
                          |< -----------BASE----------- >|< -BASE PADDING- >|< ---------LORA PADDING--------- >|
  corresponding token_id: | 512 | 513 | 514 | ... | 1009 |  -1 | ... |  -1  |  -1 | ... |  -1 | -1  | ... |  -1 |
                   index: |  0  |  1  |  2  | ... |  497 | 498 | ... |  511 | 512 | ... | 519 | 520 | ... | 543 |

  Parameters:
    org_vocab_start_index    // base embeddings start
    org_vocab_end_index      // base embeddings end
    num_org_vocab_padding    // base embeddings padding
    added_vocab_start_index  // LoRA embeddings start
    added_vocab_end_index    // LoRA embeddings end
*/
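/*
  Element-wise reference semantics (a sketch of what the AscendC kernel is
  expected to compute, mirroring the linked Python implementation; this
  comment is illustrative, not the kernel source):

    org_mask     = (token >= org_vocab_start_index)   && (token < org_vocab_end_index)
    added_mask   = (token >= added_vocab_start_index) && (token < added_vocab_end_index)
    added_offset = added_vocab_start_index
                   - (org_vocab_end_index - org_vocab_start_index)
                   - num_org_vocab_padding
    valid_offset = org_vocab_start_index * org_mask + added_offset * added_mask
    masked_input = (token - valid_offset) * (org_mask || added_mask)
    mask         = !(org_mask || added_mask)
*/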
{
  // Input validation
  TORCH_CHECK(input.dim() >= 1, "input must have at least 1 dimension");
  TORCH_CHECK(org_vocab_start_index >= 0, "org_vocab_start_index must be non-negative");
  TORCH_CHECK(org_vocab_end_index >= org_vocab_start_index, "org_vocab_end_index must be greater than or equal to org_vocab_start_index");
  TORCH_CHECK(num_org_vocab_padding >= 0, "num_org_vocab_padding must be non-negative");
  TORCH_CHECK(added_vocab_start_index >= org_vocab_end_index, "added_vocab_start_index must be greater than or equal to org_vocab_end_index");
  TORCH_CHECK(added_vocab_end_index >= added_vocab_start_index, "added_vocab_end_index must be greater than or equal to added_vocab_start_index");

  // Get total number of elements
  int64_t size = input.numel();

  // Create output tensors
  at::Tensor masked_input = at::empty_like(input);
  at::Tensor mask = at::empty_like(input).to(at::kBool);

  // Get data pointers
  void *input_ptr = input.data_ptr();
  void *masked_input_ptr = masked_input.data_ptr();
  void *mask_ptr = mask.data_ptr();

  // Get current stream
  aclrtStream stream = c10_npu::getCurrentNPUStream().stream();

  // Get scalar type
  at::ScalarType scalar_type = input.scalar_type();

  // Create and configure OpCommand
  at_npu::native::OpCommand cmd;
  cmd.Name("get_masked_input_and_mask");
  cmd.SetCustomHandler([scalar_type, size, stream,
                        input_ptr, masked_input_ptr, mask_ptr,
                        org_vocab_start_index, org_vocab_end_index,
                        num_org_vocab_padding, added_vocab_start_index,
                        added_vocab_end_index]() -> int {
    // Get platform info
    fe::PlatFormInfos platform_infos;
    int device_id = 0;
    fe::PlatformInfoManager::GeInstance().GetRuntimePlatformInfosByDevice(device_id, platform_infos);
    uint32_t aivNum = platform_infos.GetCoreNumByType("aiv");
    uint32_t loop_cnt = (size + aivNum - 1) / aivNum;
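    // Launch sizing: loop_cnt is ceil(size / aivNum), i.e. roughly how many
    // elements each AIV (vector) core covers; the exact tiling is decided by
    // the kernel implementation.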

    // Call implementation
    get_masked_input_and_mask_impl(
        stream,
        input_ptr,
        masked_input_ptr,
        mask_ptr,
        org_vocab_start_index,
        org_vocab_end_index,
        num_org_vocab_padding,
        added_vocab_start_index,
        added_vocab_end_index,
        size,
        loop_cnt,
        aivNum);

    return 0;
  });
  cmd.Run();
  return {masked_input, mask};
}

void verify_tensor(std::string const& name, at::Tensor const& t,
                   int64_t const size_0, int64_t const size_1,
                   c10::ScalarType const type) {
@@ -194,6 +300,16 @@ TORCH_LIBRARY_EXPAND(_C, ops)
      " Tensor! key, int head_size,"
      " Tensor cos_sin_cache, bool is_neox) -> (Tensor query, Tensor key)");
  ops.impl("rotary_embedding", torch::kPrivateUse1, &vllm_ascend::rotary_embedding);

  ops.def(
      "get_masked_input_and_mask(Tensor input, "
      " int org_vocab_start_index, "
      " int org_vocab_end_index, "
      " int num_org_vocab_padding, "
      " int added_vocab_start_index, "
      " int added_vocab_end_index) -> (Tensor masked_input, Tensor mask)");
  ops.impl("get_masked_input_and_mask", torch::kPrivateUse1, &vllm_ascend::get_masked_input_and_mask);

  ops.def(
      "advance_step_flashattn_ascendc(int num_seqs, int num_queries, int block_size,"
      " Tensor! input_tokens, Tensor! sampled_token_ids, Tensor! input_positions,"