适配qwen3-next
This commit is contained in:
@@ -396,7 +396,7 @@ class Qwen3GatedDeltaNet(nn.Module):
|
||||
def _forward_input_proj(self, hidden_states: torch.Tensor):
|
||||
DUAL_STREAM_TOKEN_THRESHOLD = 1024 if not _is_npu else 0
|
||||
seq_len, _ = hidden_states.shape
|
||||
if seq_len < DUAL_STREAM_TOKEN_THRESHOLD:
|
||||
if seq_len < DUAL_STREAM_TOKEN_THRESHOLD and self.alt_stream is not None:
|
||||
current_stream = torch.cuda.current_stream()
|
||||
self.alt_stream.wait_stream(current_stream)
|
||||
projected_states_qkvz, _ = self.in_proj_qkvz(hidden_states)
|
||||
|
||||
Reference in New Issue
Block a user