Add Custom Kernels For LoRA Performance (#1884)
### What this PR does / why we need it?
Add two custom kernels (bgmv_shrink and bgmv_expand) to improve the
performance of LoRA.
### Does this PR introduce _any_ user-facing change?
no user-facing change
### How was this patch tested?
We added unit test files to exercise the custom AscendC kernels. See
vllm-ascend/tests/e2e/singlecard/ops/test_bgmv_shrink.py and
vllm-ascend/tests/e2e/singlecard/ops/test_bgmv_expand.py
Based on an actual test of the Qwen2.5 7B model using vllm-ascend
version v0.9.2.rc1, TTFT, TPOT, and throughput have all improved by
about 70%.
- vLLM version: v0.9.2
- vLLM main:
40d86ee412
---------
Signed-off-by: taoxudonghaha <justsheldon@163.com>
This commit is contained in:
@@ -199,6 +199,90 @@ std::tuple<at::Tensor, at::Tensor> get_masked_input_and_mask(
|
||||
cmd.Run();
|
||||
return {masked_input, mask};
|
||||
}
|
||||
|
||||
// BGMV (batched grouped matrix-vector) "shrink" step for LoRA:
// projects x [batch_size, hidden_in] down to y [batch_size, lora_rank]
// using the per-token LoRA weight selected by `indices`, scaled by `scale`.
// y is written in place (registered as an out-parameter with no return).
void bgmv_shrink(at::Tensor &x, at::Tensor &weight, at::Tensor &indices, at::Tensor &y, double scale)
{
    at::ScalarType scalar_type = x.scalar_type();
    // The AscendC kernel only implements fp16 / bf16 paths.
    TORCH_CHECK(scalar_type == torch::kHalf || scalar_type == torch::kBFloat16, "only support half and bf16");
    TORCH_CHECK(x.dim() == 2, "x should be [batch_size, hidden_in]");
    TORCH_CHECK(weight.dim() == 3 || weight.dim() == 4,
                "weight should be [num_loras, hidden_out, hidden_in] or [num_loras, 1, hidden_out, hidden_in]");
    TORCH_CHECK(y.dim() == 2, "y should be [batch_size, hidden_out]");
    TORCH_CHECK(indices.dim() == 1, "indices should be [batch_size]");
    TORCH_CHECK(x.size(0) == y.size(0) && x.size(0) == indices.size(0),
                "the first dimension of x, y, indices should be same");
    // "Shrink" means the output dim (lora_rank) is strictly smaller than the input dim.
    TORCH_CHECK(x.size(1) > y.size(1), "hidden in should be greater than hidden out");

    void* x_ptr = x.data_ptr();
    void* weight_ptr = weight.data_ptr();
    void* indices_ptr = indices.data_ptr();
    void* y_ptr = y.data_ptr();
    int batch_size = x.size(0);
    int input_hidden_token = x.size(1);
    uint32_t lora_rank = y.size(1);
    float scale_f = static_cast<float>(scale);
    aclrtStream stream = c10_npu::getCurrentNPUStream().stream();

    at_npu::native::OpCommand cmd;
    cmd.Name("bgmv_shrink");
    // The handler runs when the op is dispatched on the NPU stream; capture raw
    // pointers and scalars by value so the lambda is self-contained.
    cmd.SetCustomHandler([scalar_type, stream, x_ptr, weight_ptr, indices_ptr, y_ptr, batch_size, input_hidden_token,
                          lora_rank, scale_f]() -> int {
        auto dtype = get_dtype_from_torch(scalar_type);
        int device_id = 0;
        int64_t aiv_num = 0;
        TORCH_CHECK(aclGetDeviceCapability(device_id, ACL_DEVICE_INFO_VECTOR_CORE_NUM, &aiv_num) == ACL_SUCCESS);
        // Split the batch across the available AI vector cores (ceiling division).
        int num_tokens_per_core = (batch_size + aiv_num - 1) / aiv_num;
        // BUGFIX: the condition was previously the string literal
        // "num_tokens_per_core != 0" (always truthy), so the check never fired.
        TORCH_CHECK(num_tokens_per_core != 0, "num_tokens_per_core should not be 0");
        bgmv_shrink_impl(dtype, stream, x_ptr, weight_ptr, indices_ptr, y_ptr, batch_size, num_tokens_per_core,
                         input_hidden_token, lora_rank, scale_f);
        return 0;
    });
    cmd.Run();
}
|
||||
|
||||
// BGMV (batched grouped matrix-vector) "expand" step for LoRA:
// projects x [batch_size, lora_rank] up into a slice
// [slice_offset, slice_offset + slice_size) of y's second dimension, using the
// per-token LoRA weight selected by `indices`. Accumulates into y and returns it.
at::Tensor bgmv_expand(at::Tensor &x, at::Tensor &weight, at::Tensor &indices, at::Tensor &y,
                       int64_t slice_offset, int64_t slice_size)
{
    at::ScalarType scalar_type = y.scalar_type();
    // The AscendC kernel only implements fp16 / bf16 paths.
    TORCH_CHECK(scalar_type == torch::kHalf || scalar_type == torch::kBFloat16, "only support half and bf16");
    TORCH_CHECK(x.dim() == 2, "x should be [batch_size, hidden_in]");
    TORCH_CHECK(weight.dim() == 3 || weight.dim() == 4,
                "weight should be [num_loras, hidden_out, hidden_in] or [num_loras, 1, hidden_out, hidden_in]");
    TORCH_CHECK(y.dim() == 2, "y should be [batch_size, hidden_out]");
    TORCH_CHECK(indices.dim() == 1, "indices should be [batch_size]");
    TORCH_CHECK(x.size(0) == y.size(0) && x.size(0) == indices.size(0),
                "the first dimension of x, y, indices should be same");
    // "Expand" means the input dim (lora_rank) is no larger than the output slice.
    TORCH_CHECK(x.size(1) <= slice_size, "hidden in should be smaller than hidden out");
    TORCH_CHECK(slice_offset >= 0, "slice offset should be no smaller than 0");
    // BUGFIX: this statement was previously missing its terminating semicolon.
    TORCH_CHECK((slice_size + slice_offset) <= y.size(1),
                "slice_size + slice_offset should be smaller than the second dimension of y");

    at::Tensor y_out = y;
    void* x_ptr = x.data_ptr();
    void* weight_ptr = weight.data_ptr();
    void* indices_ptr = indices.data_ptr();
    void* y_ptr = y.data_ptr();
    void* y_out_ptr = y_out.data_ptr();
    int batch_size = x.size(0);
    int lora_rank = x.size(1);
    int output_full_dim = y.size(1);
    aclrtStream stream = c10_npu::getCurrentNPUStream().stream();

    at_npu::native::OpCommand cmd;
    cmd.Name("bgmv_expand");
    // The handler runs when the op is dispatched on the NPU stream; capture raw
    // pointers and scalars by value so the lambda is self-contained.
    cmd.SetCustomHandler([scalar_type, stream, x_ptr, weight_ptr, indices_ptr, y_ptr, y_out_ptr, batch_size, lora_rank,
                          slice_offset, slice_size, output_full_dim]() -> int {
        auto dtype = get_dtype_from_torch(scalar_type);
        int device_id = 0;
        int64_t aiv_num = 0;
        TORCH_CHECK(aclGetDeviceCapability(device_id, ACL_DEVICE_INFO_VECTOR_CORE_NUM, &aiv_num) == ACL_SUCCESS);
        // Split the batch across the available AI vector cores (ceiling division).
        int num_tokens_per_core = (batch_size + aiv_num - 1) / aiv_num;
        // BUGFIX: the condition was previously the string literal
        // "num_tokens_per_core != 0" (always truthy), so the check never fired.
        TORCH_CHECK(num_tokens_per_core != 0, "num_tokens_per_core should not be 0");
        bgmv_expand_impl(dtype, stream, x_ptr, weight_ptr, indices_ptr, y_ptr, y_out_ptr, batch_size,
                         num_tokens_per_core, lora_rank, slice_size, slice_offset, output_full_dim);
        return 0;
    });
    cmd.Run();
    return y_out;
}
|
||||
} // namespace vllm_ascend
|
||||
|
||||
TORCH_LIBRARY_EXPAND(_C, ops)
|
||||
@@ -223,6 +307,14 @@ TORCH_LIBRARY_EXPAND(_C, ops)
|
||||
" int added_vocab_start_index, "
|
||||
" int added_vocab_end_index) -> (Tensor masked_input, Tensor mask)");
|
||||
ops.impl("get_masked_input_and_mask", torch::kPrivateUse1, &vllm_ascend::get_masked_input_and_mask);
|
||||
|
||||
ops.def("bgmv_shrink(Tensor! x, Tensor! weight, Tensor! indices, Tensor! y, float scale) -> ()");
|
||||
ops.impl("bgmv_shrink", torch::kPrivateUse1, &vllm_ascend::bgmv_shrink);
|
||||
|
||||
ops.def(
|
||||
"bgmv_expand(Tensor! x, Tensor! weight, Tensor! indices, Tensor! y,"
|
||||
" int slice_offset, int slice_size) -> Tensor");
|
||||
ops.impl("bgmv_expand", torch::kPrivateUse1, &vllm_ascend::bgmv_expand);
|
||||
}
|
||||
|
||||
REGISTER_EXTENSION(_C)
|
||||
|
||||
Reference in New Issue
Block a user