fix: zero_init buffer (#9065)
Co-authored-by: Yineng Zhang <me@zhyncs.com>
@@ -63,7 +63,7 @@ srt = [
   "torchaudio==2.8.0",
   "torchvision",
   "cuda-python",
-  "flashinfer_python==0.2.11.post1",
+  "flashinfer_python==0.2.11.post3",
 ]
 
 blackwell = [
@@ -73,7 +73,7 @@ blackwell = [
   "torchaudio==2.8.0",
   "torchvision",
   "cuda-python",
-  "flashinfer_python==0.2.11.post1",
+  "flashinfer_python==0.2.11.post3",
 ]
 
 # HIP (Heterogeneous-computing Interface for Portability) for AMD
@@ -647,7 +647,7 @@ def _set_envs_and_config(server_args: ServerArgs):
     if server_args.attention_backend == "flashinfer":
         assert_pkg_version(
             "flashinfer_python",
-            "0.2.11.post1",
+            "0.2.11.post3",
             "Please uninstall the old version and "
             "reinstall the latest version by following the instructions "
             "at https://docs.flashinfer.ai/installation.html.",
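For readers unfamiliar with the gate updated above: assert_pkg_version refuses to start the server when flashinfer_python is missing or older than the pinned release. A minimal sketch of such a check using importlib.metadata and the packaging library; this is a hypothetical stand-in, not sglang's actual assert_pkg_version implementation, and it assumes packaging is installed:

from importlib.metadata import PackageNotFoundError, version

from packaging.version import parse


def check_min_version(pkg: str, min_version: str, hint: str) -> None:
    # Raise if the package is absent or older than the required version.
    try:
        installed = version(pkg)
    except PackageNotFoundError:
        raise RuntimeError(f"{pkg} is not installed. {hint}")
    if parse(installed) < parse(min_version):
        raise RuntimeError(
            f"{pkg}=={installed} is too old (need >= {min_version}). {hint}"
        )


check_min_version(
    "flashinfer_python",
    "0.2.11.post3",
    "Please reinstall per https://docs.flashinfer.ai/installation.html.",
)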
@@ -122,6 +122,7 @@ class FlashInferAttnBackend(AttentionBackend):
         # Allocate buffers
         global global_workspace_buffer
         if global_workspace_buffer is None:
+            # different from flashinfer zero_init_global_workspace_buffer
             global_workspace_buffer = torch.empty(
                 global_config.flashinfer_workspace_size,
                 dtype=torch.uint8,
@@ -81,6 +81,7 @@ class FlashInferMLAAttnBackend(AttentionBackend):
         # Allocate buffers
         global global_workspace_buffer
         if global_workspace_buffer is None:
+            # different from flashinfer zero_init_global_workspace_buffer
             global_workspace_buffer = torch.empty(
                 global_config.flashinfer_workspace_size,
                 dtype=torch.uint8,
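The two FlashInfer hunks above deliberately keep torch.empty and only document that this differs from flashinfer's zero_init_global_workspace_buffer, while the TRTLLM hunks below switch to torch.zeros. The distinction, sketched with plain PyTorch (size and device are illustrative, not taken from the codebase):

import torch

SIZE = 128 * 1024 * 1024  # illustrative workspace size in bytes

# torch.empty only reserves memory; the contents are stale garbage.
# That is safe when every kernel fully writes a workspace region
# before reading it back.
fast_ws = torch.empty(SIZE, dtype=torch.uint8, device="cuda")

# torch.zeros pays for an extra fill kernel but guarantees all-zero
# contents, which matters for kernels that treat workspace bytes as
# initial state (e.g., accumulators or flags expected to start at 0).
safe_ws = torch.zeros(SIZE, dtype=torch.uint8, device="cuda")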
@@ -23,10 +23,12 @@ if TYPE_CHECKING:
     from sglang.srt.speculative.spec_info import SpecInfo
 
 # Constants
-DEFAULT_WORKSPACE_SIZE_MB = 128  # Memory workspace size in MB
+DEFAULT_WORKSPACE_SIZE_MB = (
+    512  # Memory workspace size in MB, todo(Yingyi): read from config
+)
 
 # Reuse this workspace buffer across all TRTLLM MHA wrappers
-global_workspace_buffer = None
+global_zero_init_workspace_buffer = None
 
 
 @dataclass
@@ -73,14 +75,14 @@ class TRTLLMHAAttnBackend(FlashInferAttnBackend):
         # Workspace allocation
         self.workspace_size = DEFAULT_WORKSPACE_SIZE_MB * 1024 * 1024
         # Allocate buffers
-        global global_workspace_buffer
-        if global_workspace_buffer is None:
-            global_workspace_buffer = torch.empty(
+        global global_zero_init_workspace_buffer
+        if global_zero_init_workspace_buffer is None:
+            global_zero_init_workspace_buffer = torch.zeros(
                 self.workspace_size,
                 dtype=torch.uint8,
                 device=model_runner.device,
             )
-        self.workspace_buffer = global_workspace_buffer
+        self.workspace_buffer = global_zero_init_workspace_buffer
 
         # CUDA graph state
         self.decode_cuda_graph_metadata = {}
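The hunk above turns the per-wrapper workspace into a lazily created, zero-initialized singleton shared by all TRTLLM MHA wrappers. A self-contained sketch of that pattern; the names _WORKSPACE and get_zero_init_workspace are hypothetical and not from the sglang codebase:

from typing import Optional

import torch

# Hypothetical module-level singleton, mirroring
# global_zero_init_workspace_buffer in the diff above.
_WORKSPACE: Optional[torch.Tensor] = None


def get_zero_init_workspace(size_bytes: int, device: str) -> torch.Tensor:
    # Allocate once on first use; every later caller reuses the same
    # tensor, so all wrappers share a single zeroed allocation.
    global _WORKSPACE
    if _WORKSPACE is None:
        _WORKSPACE = torch.zeros(size_bytes, dtype=torch.uint8, device=device)
    return _WORKSPACE

The rename to global_zero_init_workspace_buffer also signals at call sites that, unlike the torch.empty buffer kept by the FlashInfer backends, this buffer is guaranteed zero-filled.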
@@ -39,6 +39,8 @@ DEFAULT_WORKSPACE_SIZE_MB = 128  # Memory workspace size in MB
 # compute the LCM with other padding constraints.
 TRTLLM_BLOCK_CONSTRAINT = 128
 
+global_zero_init_workspace_buffer = None
+
 
 @dataclass
 class TRTLLMMLADecodeMetadata:
@@ -83,9 +85,14 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
 
         # Workspace allocation
         self.workspace_size = DEFAULT_WORKSPACE_SIZE_MB * 1024 * 1024
-        self.workspace_buffer = torch.empty(
-            self.workspace_size, dtype=torch.int8, device=self.device
-        )
+        global global_zero_init_workspace_buffer
+        if global_zero_init_workspace_buffer is None:
+            global_zero_init_workspace_buffer = torch.zeros(
+                self.workspace_size,
+                dtype=torch.uint8,
+                device=model_runner.device,
+            )
+        self.workspace_buffer = global_zero_init_workspace_buffer
 
         # CUDA graph state
         self.decode_cuda_graph_metadata = {}
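Note that the MLA hunk does more than rename: previously each TRTLLMMLABackend instance allocated its own torch.empty workspace with dtype=torch.int8 on self.device; it now shares the module-level zero-initialized uint8 buffer allocated on model_runner.device, matching the MHA backend above.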
@@ -143,4 +143,4 @@
     "num_warps": 4,
     "num_stages": 3
   }
 }
@@ -143,4 +143,4 @@
     "num_warps": 4,
     "num_stages": 4
   }
 }
@@ -143,4 +143,4 @@
     "num_warps": 4,
     "num_stages": 3
   }
 }
@@ -143,4 +143,4 @@
     "num_warps": 4,
     "num_stages": 3
   }
 }