Disable all two stream overlap on amd (#6475)
This commit is contained in:
@@ -38,11 +38,17 @@ import triton
|
|||||||
import triton.language as tl
|
import triton.language as tl
|
||||||
|
|
||||||
from sglang.srt.layers.radix_attention import RadixAttention
|
from sglang.srt.layers.radix_attention import RadixAttention
|
||||||
from sglang.srt.utils import debug_timing, get_compiler_backend
|
from sglang.srt.utils import (
|
||||||
|
debug_timing,
|
||||||
|
get_compiler_backend,
|
||||||
|
is_cuda,
|
||||||
|
next_power_of_2,
|
||||||
|
)
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
GB = 1024 * 1024 * 1024
|
GB = 1024 * 1024 * 1024
|
||||||
|
_is_cuda = is_cuda()
|
||||||
|
|
||||||
|
|
||||||
class ReqToTokenPool:
|
class ReqToTokenPool:
|
||||||
@@ -262,7 +268,7 @@ class MHATokenToKVPool(KVCache):
|
|||||||
self.layer_transfer_counter = None
|
self.layer_transfer_counter = None
|
||||||
self.capture_mode = False
|
self.capture_mode = False
|
||||||
self.device_module = torch.get_device_module(self.device)
|
self.device_module = torch.get_device_module(self.device)
|
||||||
self.alt_stream = self.device_module.Stream()
|
self.alt_stream = self.device_module.Stream() if is_cuda else None
|
||||||
|
|
||||||
k_size, v_size = self.get_kv_size_bytes()
|
k_size, v_size = self.get_kv_size_bytes()
|
||||||
logger.info(
|
logger.info(
|
||||||
@@ -392,7 +398,7 @@ class MHATokenToKVPool(KVCache):
|
|||||||
cache_k = cache_k.view(self.store_dtype)
|
cache_k = cache_k.view(self.store_dtype)
|
||||||
cache_v = cache_v.view(self.store_dtype)
|
cache_v = cache_v.view(self.store_dtype)
|
||||||
|
|
||||||
if self.capture_mode and cache_k.shape[0] < 4:
|
if self.capture_mode and self.alt_stream is not None:
|
||||||
# Overlap the copy of K and V cache for small batch size
|
# Overlap the copy of K and V cache for small batch size
|
||||||
current_stream = self.device_module.current_stream()
|
current_stream = self.device_module.current_stream()
|
||||||
self.alt_stream.wait_stream(current_stream)
|
self.alt_stream.wait_stream(current_stream)
|
||||||
|
|||||||
@@ -76,13 +76,12 @@ from sglang.srt.layers.vocab_parallel_embedding import (
|
|||||||
VocabParallelEmbedding,
|
VocabParallelEmbedding,
|
||||||
)
|
)
|
||||||
from sglang.srt.managers.expert_distribution import (
|
from sglang.srt.managers.expert_distribution import (
|
||||||
ExpertDistributionRecorder,
|
|
||||||
get_global_expert_distribution_recorder,
|
get_global_expert_distribution_recorder,
|
||||||
)
|
)
|
||||||
from sglang.srt.managers.expert_location import ModelConfigForExpertLocation
|
from sglang.srt.managers.expert_location import ModelConfigForExpertLocation
|
||||||
from sglang.srt.managers.expert_location_dispatch import ExpertLocationDispatchInfo
|
from sglang.srt.managers.expert_location_dispatch import ExpertLocationDispatchInfo
|
||||||
from sglang.srt.managers.schedule_batch import global_server_args_dict
|
from sglang.srt.managers.schedule_batch import global_server_args_dict
|
||||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
|
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
||||||
from sglang.srt.model_loader.weight_utils import default_weight_loader
|
from sglang.srt.model_loader.weight_utils import default_weight_loader
|
||||||
from sglang.srt.operations import execute_operations
|
from sglang.srt.operations import execute_operations
|
||||||
from sglang.srt.operations_strategy import compute_layer_operations
|
from sglang.srt.operations_strategy import compute_layer_operations
|
||||||
@@ -1321,8 +1320,7 @@ class DeepseekV2Model(nn.Module):
|
|||||||
config.hidden_size,
|
config.hidden_size,
|
||||||
enable_tp=not global_server_args_dict["enable_dp_attention"],
|
enable_tp=not global_server_args_dict["enable_dp_attention"],
|
||||||
)
|
)
|
||||||
# TODO(haishaw): multi-stream performance on ROCm
|
self.alt_stream = torch.cuda.Stream() if _is_cuda else None
|
||||||
self.alt_stream = None if _is_hip else torch.cuda.Stream()
|
|
||||||
self.layers = nn.ModuleList(
|
self.layers = nn.ModuleList(
|
||||||
[
|
[
|
||||||
DeepseekV2DecoderLayer(
|
DeepseekV2DecoderLayer(
|
||||||
|
|||||||
@@ -52,7 +52,15 @@ from sglang.srt.model_executor.forward_batch_info import (
|
|||||||
PPProxyTensors,
|
PPProxyTensors,
|
||||||
)
|
)
|
||||||
from sglang.srt.models.llama import LlamaForCausalLM, LlamaMLP
|
from sglang.srt.models.llama import LlamaForCausalLM, LlamaMLP
|
||||||
from sglang.srt.utils import add_prefix, fast_topk, get_compiler_backend, make_layers
|
from sglang.srt.utils import (
|
||||||
|
add_prefix,
|
||||||
|
fast_topk,
|
||||||
|
get_compiler_backend,
|
||||||
|
is_cuda,
|
||||||
|
make_layers,
|
||||||
|
)
|
||||||
|
|
||||||
|
_is_cuda = is_cuda()
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -131,7 +139,7 @@ class Llama4MoE(nn.Module):
|
|||||||
return out_aD
|
return out_aD
|
||||||
|
|
||||||
def _forward_core(self, hidden_states, forward_mode: ForwardMode):
|
def _forward_core(self, hidden_states, forward_mode: ForwardMode):
|
||||||
if hidden_states.shape[0] < 4:
|
if hidden_states.shape[0] < 4 and _is_cuda:
|
||||||
return self._forward_core_shared_routed_overlap(hidden_states)
|
return self._forward_core_shared_routed_overlap(hidden_states)
|
||||||
else:
|
else:
|
||||||
return self._forward_core_normal(hidden_states)
|
return self._forward_core_normal(hidden_states)
|
||||||
|
|||||||
Reference in New Issue
Block a user