feat: add trtllm-gen mha from direct call (#8782)
Co-authored-by: Baizhou Zhang <sobereddiezhang@gmail.com>
This commit is contained in:
@@ -441,6 +441,23 @@ class ServerArgs:
|
||||
"trtllm_mla backend does not support speculative decoding yet."
|
||||
)
|
||||
|
||||
if self.attention_backend == "trtllm_mha":
|
||||
if not is_sm100_supported():
|
||||
raise ValueError(
|
||||
"TRTLLM MHA backend is only supported on Blackwell GPUs (SM100). Please use a different backend."
|
||||
)
|
||||
|
||||
if self.page_size not in [16, 32, 64]:
|
||||
logger.warning(
|
||||
f"TensorRT-LLM MHA only supports page_size of 16, 32 or 64, changing page_size from {self.page_size} to 64."
|
||||
)
|
||||
self.page_size = 64
|
||||
|
||||
if self.speculative_algorithm is not None:
|
||||
raise ValueError(
|
||||
"trtllm_mla backend does not support speculative decoding yet."
|
||||
)
|
||||
|
||||
# Set page size
|
||||
if self.page_size is None:
|
||||
self.page_size = 1
|
||||
@@ -1275,6 +1292,7 @@ class ServerArgs:
|
||||
"ascend",
|
||||
"triton",
|
||||
"trtllm_mla",
|
||||
"trtllm_mha",
|
||||
],
|
||||
default=ServerArgs.attention_backend,
|
||||
help="Choose the kernels for attention layers.",
|
||||
|
||||
Reference in New Issue
Block a user