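Fix dual-stream bug in Qwen2MoeSparseMoeBlock: pass a clone of hidden_states to the shared experts and take the dual-stream path only in CUDA graph capture mode (#10352)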
This commit is contained in:
@@ -62,6 +62,7 @@ from sglang.srt.layers.vocab_parallel_embedding import (
|
|||||||
VocabParallelEmbedding,
|
VocabParallelEmbedding,
|
||||||
)
|
)
|
||||||
from sglang.srt.managers.schedule_batch import global_server_args_dict
|
from sglang.srt.managers.schedule_batch import global_server_args_dict
|
||||||
|
from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode
|
||||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
|
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
|
||||||
from sglang.srt.model_loader.weight_utils import default_weight_loader
|
from sglang.srt.model_loader.weight_utils import default_weight_loader
|
||||||
from sglang.srt.two_batch_overlap import model_forward_maybe_tbo
|
from sglang.srt.two_batch_overlap import model_forward_maybe_tbo
|
||||||
@@ -194,7 +195,7 @@ class Qwen2MoeSparseMoeBlock(nn.Module):
|
|||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
current_stream = torch.cuda.current_stream()
|
current_stream = torch.cuda.current_stream()
|
||||||
self.alt_stream.wait_stream(current_stream)
|
self.alt_stream.wait_stream(current_stream)
|
||||||
shared_output = self._forward_shared_experts(hidden_states)
|
shared_output = self._forward_shared_experts(hidden_states.clone())
|
||||||
|
|
||||||
with torch.cuda.stream(self.alt_stream):
|
with torch.cuda.stream(self.alt_stream):
|
||||||
router_output = self._forward_router_experts(hidden_states)
|
router_output = self._forward_router_experts(hidden_states)
|
||||||
@@ -217,6 +218,7 @@ class Qwen2MoeSparseMoeBlock(nn.Module):
|
|||||||
self.alt_stream is not None
|
self.alt_stream is not None
|
||||||
and hidden_states.shape[0] > 0
|
and hidden_states.shape[0] > 0
|
||||||
and hidden_states.shape[0] <= DUAL_STREAM_TOKEN_THRESHOLD
|
and hidden_states.shape[0] <= DUAL_STREAM_TOKEN_THRESHOLD
|
||||||
|
and get_is_capture_mode()
|
||||||
):
|
):
|
||||||
final_hidden_states, shared_output = self.forward_normal_dual_stream(
|
final_hidden_states, shared_output = self.forward_normal_dual_stream(
|
||||||
hidden_states
|
hidden_states
|
||||||
|
|||||||
Reference in New Issue
Block a user