Adjust flashinfer workspace size for Qwen2 models (#2879)

This commit is contained in:
Ke Bao
2025-01-14 13:34:22 +08:00
committed by GitHub
parent 80002562a8
commit c19d84829c

View File

@@ -84,6 +84,10 @@ class FlashInferAttnBackend(AttentionBackend):
self.num_wrappers = 1
self.dispatch_reason = None
# Qwen2 models require a larger flashinfer workspace size
if "Qwen2ForCausalLM" in model_runner.model_config.hf_config.architectures:
global_config.flashinfer_workspace_size = 512 * 1024 * 1024
# Allocate buffers
self.workspace_buffer = torch.empty(
global_config.flashinfer_workspace_size,