[Feat] QWen-1M context support[2/2]: Update block sparse attention backend (#5949)
@@ -107,6 +107,7 @@ class Qwen2Attention(nn.Module):
         rope_scaling: Optional[Dict[str, Any]] = None,
         max_position_embeddings: int = 32768,
         quant_config: Optional[QuantizationConfig] = None,
+        dual_chunk_attention_config: Optional[dict[str, Any]] = None,
         prefix: str = "",
     ) -> None:
         super().__init__()
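The hunk above adds an optional dual_chunk_attention_config argument to Qwen2Attention.__init__, defaulting to None so existing checkpoints are unaffected. For orientation only, such a dict might look like the sketch below; the key names and values are assumptions based on Qwen's dual chunk attention scheme, not something this commit defines.

    # Hypothetical example; every key and value below is an assumption, not part of this diff.
    dual_chunk_attention_config = {
        "chunk_size": 262144,      # assumed: span of each query chunk
        "local_size": 8192,        # assumed: window attended with local positions
        "original_max_position_embeddings": 262144,  # assumed: pretraining context length
    }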
@@ -158,6 +159,7 @@ class Qwen2Attention(nn.Module):
             max_position=max_position_embeddings,
             base=rope_theta,
             rope_scaling=rope_scaling,
+            dual_chunk_attention_config=dual_chunk_attention_config,
         )
         self.attn = RadixAttention(
             self.num_heads,
@@ -198,6 +200,9 @@ class Qwen2DecoderLayer(nn.Module):
         rope_scaling = getattr(config, "rope_scaling", None)
         max_position_embeddings = getattr(config, "max_position_embeddings", 32768)
         head_dim = getattr(config, "head_dim", None)
+        dual_chunk_attention_config = getattr(
+            config, "dual_chunk_attention_config", None
+        )
         self.self_attn = Qwen2Attention(
             hidden_size=self.hidden_size,
             num_heads=config.num_attention_heads,
@@ -208,6 +213,7 @@ class Qwen2DecoderLayer(nn.Module):
             rope_scaling=rope_scaling,
             max_position_embeddings=max_position_embeddings,
             quant_config=quant_config,
+            dual_chunk_attention_config=dual_chunk_attention_config,
             prefix=add_prefix("self_attn", prefix),
         )
         self.mlp = Qwen2MLP(
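Read together, the hunks thread the setting in one direction: Qwen2DecoderLayer pulls it off the model config with a guarded getattr, hands it to Qwen2Attention, which forwards it to get_rope alongside the other RoPE arguments. A minimal, self-contained sketch of that lookup-and-forward pattern follows; SimpleNamespace stands in for the real Hugging Face config object, and the dict contents reuse the assumed keys from the earlier sketch.

    from types import SimpleNamespace
    from typing import Any, Optional

    # Stand-in for the model config; a real one would come from transformers.
    config = SimpleNamespace(
        dual_chunk_attention_config={"chunk_size": 262144, "local_size": 8192}  # assumed keys
    )

    # Same guarded lookup the diff adds in Qwen2DecoderLayer.__init__:
    # a missing attribute falls back to None, so non-1M checkpoints behave as before.
    dual_chunk_attention_config: Optional[dict[str, Any]] = getattr(
        config, "dual_chunk_attention_config", None
    )

    # The value is then passed along as a keyword argument, mirroring the diff:
    #   Qwen2Attention(..., dual_chunk_attention_config=dual_chunk_attention_config)
    #   get_rope(..., dual_chunk_attention_config=dual_chunk_attention_config)
    print(dual_chunk_attention_config)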