Add Custom Kernels For LoRA Performance (#2325)
### What this PR does / why we need it?
Add two custom operators (sgmv_shrink and sgmv_expand) to address the
performance issues of LoRA. Meanwhile, enable the graph mode for LoRA
operators to enter ACL, so as to improve the model inference
performance.
### Does this PR introduce _any_ user-facing change?
no user-facing change
### How was this patch tested?
Tested with the Qwen2.5 7B model on vllm-ascend v0.9.2.rc1 in ACL graph
mode: TTFT and TPOT improved and overall throughput increased by roughly
100%.
Signed-off-by: liuchn <909698896@qq.com>
- vLLM version: v0.10.0
- vLLM main:
1f83e7d849
---------
Signed-off-by: liuchn <909698896@qq.com>
Co-authored-by: liuchn <909698896@qq.com>
This commit is contained in:
@@ -294,6 +294,87 @@ at::Tensor bgmv_expand(at::Tensor &x, at::Tensor &weight, at::Tensor &indices, a
|
||||
cmd.Run();
|
||||
return y_out;
|
||||
}
|
||||
|
||||
// Shrink (down-projection) step of segmented LoRA matmul on NPU:
// y[i] = scale * (x[i] @ weight[lora_indices[seg]]^T), where each request
// segment's token count is given by seq_len and segments share one LoRA index.
//
// Args:
//   x            - [batch_size, hidden_in] input activations (half/bf16).
//   weight       - [num_loras, hidden_out, hidden_in] or
//                  [num_loras, 1, hidden_out, hidden_in] LoRA A weights.
//   lora_indices - per-segment LoRA adapter indices.
//   seq_len      - per-segment token counts.
//   y            - [batch_size, hidden_out] output buffer, written in place.
//   scale        - scalar multiplier applied to the result.
void sgmv_shrink(at::Tensor &x, at::Tensor &weight, at::Tensor &lora_indices, at::Tensor &seq_len,
                 at::Tensor &y, double scale)
{
    at::ScalarType scalar_type = x.scalar_type();
    TORCH_CHECK(scalar_type == torch::kHalf || scalar_type == torch::kBFloat16, "only support half and bf16");
    TORCH_CHECK(x.dim() == 2, "x should be [batch_size, hidden_in]");
    TORCH_CHECK(weight.dim() == 3 || weight.dim() == 4,
                "weight should be [num_loras, hidden_out, hidden_in] or [num_loras, 1, hidden_out, hidden_in]");
    TORCH_CHECK(y.dim() == 2, "y should be [batch_size, hidden_out]");
    // Shrink projects down: the output (LoRA rank) must be narrower than the input.
    TORCH_CHECK(x.size(1) > y.size(1), "hidden in should be greater than hidden out");

    void *x_ptr = x.data_ptr();
    void *weight_ptr = weight.data_ptr();
    void *lora_indices_ptr = lora_indices.data_ptr();
    void *seq_len_ptr = seq_len.data_ptr();
    void *y_ptr = y.data_ptr();
    int batch_size = x.size(0);
    int input_hidden_token = x.size(1);
    uint32_t lora_rank = y.size(1);
    float scale_f = static_cast<float>(scale);

    aclrtStream stream = c10_npu::getCurrentNPUStream().stream();
    at_npu::native::OpCommand cmd;
    cmd.Name("sgmv_shrink");
    // The handler runs when the op is dequeued (ACL graph / task queue), so it
    // captures raw pointers and scalars by value; the caller keeps the tensors alive.
    cmd.SetCustomHandler([scalar_type, stream, x_ptr, weight_ptr, lora_indices_ptr, seq_len_ptr, y_ptr,
                          batch_size, input_hidden_token, lora_rank, scale_f]() -> int {
        auto dtype = get_dtype_from_torch(scalar_type);
        int device_id = 0;
        int64_t aiv_num = 0;
        TORCH_CHECK(aclGetDeviceCapability(device_id, ACL_DEVICE_INFO_VECTOR_CORE_NUM, &aiv_num) == ACL_SUCCESS);
        // Guard the division below against a zero/negative core count.
        TORCH_CHECK(aiv_num > 0, "vector core number should be positive");
        // Ceiling division: spread the batch across the available vector cores.
        int num_tokens_per_core = (batch_size + aiv_num - 1) / aiv_num;
        // Fixed: the original passed the string literal "num_tokens_per_core != 0"
        // as the condition, which is always truthy, so the check never fired.
        TORCH_CHECK(num_tokens_per_core != 0, "num_tokens_per_core should not be 0");
        sgmv_shrink_impl(dtype, stream, x_ptr, weight_ptr, lora_indices_ptr, seq_len_ptr, y_ptr, batch_size,
                         num_tokens_per_core, input_hidden_token, lora_rank, scale_f);
        return 0;
    });
    cmd.Run();
}
|
||||
|
||||
// Expand (up-projection) step of segmented LoRA matmul on NPU: accumulates
// x @ weight[lora_indices[seg]]^T into the column slice
// y[:, slice_offset : slice_offset + slice_size], per request segment.
//
// Args:
//   x            - [batch_size, lora_rank] shrink-step output.
//   weight       - [num_loras, hidden_out, hidden_in] or
//                  [num_loras, 1, hidden_out, hidden_in] LoRA B weights.
//   lora_indices - per-segment LoRA adapter indices.
//   seq_len      - per-segment token counts.
//   y            - [batch_size, hidden_out] tensor updated in place.
//   slice_offset - start column of the output slice within y.
//   slice_size   - width of the output slice.
// Returns: y (aliases the input tensor; updated in place).
at::Tensor sgmv_expand(at::Tensor &x, at::Tensor &weight, at::Tensor &lora_indices, at::Tensor &seq_len,
                       at::Tensor &y, int64_t slice_offset, int64_t slice_size)
{
    at::ScalarType scalar_type = y.scalar_type();
    TORCH_CHECK(scalar_type == torch::kHalf || scalar_type == torch::kBFloat16, "only support half and bf16");
    TORCH_CHECK(x.dim() == 2, "x should be [batch_size, hidden_in]");
    TORCH_CHECK(weight.dim() == 3 || weight.dim() == 4,
                "weight should be [num_loras, hidden_out, hidden_in] or [num_loras, 1, hidden_out, hidden_in]");
    TORCH_CHECK(y.dim() == 2, "y should be [batch_size, hidden_out]");
    // Expand projects up: the LoRA rank must not exceed the output slice width.
    TORCH_CHECK(x.size(1) <= slice_size, "hidden in should be smaller than hidden out");
    TORCH_CHECK(slice_offset >= 0, "slice offset should be no smaller than 0");
    // Fixed: the original statement was missing its terminating semicolon.
    TORCH_CHECK((slice_size + slice_offset) <= y.size(1),
                "slice_size + slice_offset should be smaller than the second dimension of y");

    at::Tensor y_out = y;
    void *x_ptr = x.data_ptr();
    void *weight_ptr = weight.data_ptr();
    void *lora_indices_ptr = lora_indices.data_ptr();
    void *seq_len_ptr = seq_len.data_ptr();
    void *y_ptr = y.data_ptr();
    void *y_out_ptr = y_out.data_ptr();
    int batch_size = x.size(0);
    int lora_rank = x.size(1);
    int output_full_dim = y.size(1);

    aclrtStream stream = c10_npu::getCurrentNPUStream().stream();
    at_npu::native::OpCommand cmd;
    cmd.Name("sgmv_expand");
    // The handler runs when the op is dequeued (ACL graph / task queue), so it
    // captures raw pointers and scalars by value; the caller keeps the tensors alive.
    cmd.SetCustomHandler([scalar_type, stream, x_ptr, weight_ptr, lora_indices_ptr, seq_len_ptr, y_ptr, y_out_ptr,
                          batch_size, lora_rank, slice_offset, slice_size, output_full_dim]() -> int {
        auto dtype = get_dtype_from_torch(scalar_type);
        int device_id = 0;
        int64_t aiv_num = 0;
        TORCH_CHECK(aclGetDeviceCapability(device_id, ACL_DEVICE_INFO_VECTOR_CORE_NUM, &aiv_num) == ACL_SUCCESS);
        // Guard the division below against a zero/negative core count.
        TORCH_CHECK(aiv_num > 0, "vector core number should be positive");
        // Ceiling division: spread the batch across the available vector cores.
        int num_tokens_per_core = (batch_size + aiv_num - 1) / aiv_num;
        // Fixed: the original passed the string literal "num_tokens_per_core != 0"
        // as the condition, which is always truthy, so the check never fired.
        TORCH_CHECK(num_tokens_per_core != 0, "num_tokens_per_core should not be 0");
        sgmv_expand_impl(dtype, stream, x_ptr, weight_ptr, lora_indices_ptr, seq_len_ptr, y_ptr, y_out_ptr,
                         batch_size, num_tokens_per_core, lora_rank, slice_size, slice_offset, output_full_dim);
        return 0;
    });
    cmd.Run();
    return y_out;
}
|
||||
} // namespace vllm_ascend
|
||||
|
||||
TORCH_LIBRARY_EXPAND(_C, ops)
|
||||
@@ -326,6 +407,14 @@ TORCH_LIBRARY_EXPAND(_C, ops)
|
||||
"bgmv_expand(Tensor! x, Tensor! weight, Tensor! indices, Tensor! y,"
|
||||
" int slice_offset, int slice_size) -> Tensor");
|
||||
ops.impl("bgmv_expand", torch::kPrivateUse1, &vllm_ascend::bgmv_expand);
|
||||
|
||||
ops.def("sgmv_shrink(Tensor! x, Tensor! weight, Tensor! lora_indices, Tensor! seq_len, Tensor! y, float scale) -> ()");
|
||||
ops.impl("sgmv_shrink", torch::kPrivateUse1, &vllm_ascend::sgmv_shrink);
|
||||
|
||||
ops.def(
|
||||
"sgmv_expand(Tensor! x, Tensor! weight, Tensor! lora_indices, Tensor! seq_len, Tensor! y,"
|
||||
" int slice_offset, int slice_size) -> Tensor");
|
||||
ops.impl("sgmv_expand", torch::kPrivateUse1, &vllm_ascend::sgmv_expand);
|
||||
}
|
||||
|
||||
REGISTER_EXTENSION(_C)
|
||||
|
||||
Reference in New Issue
Block a user