| import torch |
| import torch.nn as nn |
|
|
| |
| from unet_transformer import RMSNorm, TransformerBlock |
|
|
|
|
| class BaselineTransformer(nn.Module): |
| """ |
| A vanilla stacked transformer baseline for comparison with UNetTransformer. |
| |
| Uses the exact same building blocks (RMSNorm, rotary MHA, SwiGLU MLP) as |
| unet_transformer.py, but without the downsample/upsample path. |
| """ |
|
|
| def __init__(self, vocab_size, dim=512, num_heads=8, mlp_ratio=4.0, dropout=0.1, num_layers=12): |
| super().__init__() |
|
|
| self.token_emb = nn.Embedding(vocab_size, dim) |
| self.dropout = nn.Dropout(dropout) |
|
|
| self.layers = nn.ModuleList( |
| [TransformerBlock(dim, num_heads, mlp_ratio, dropout) for _ in range(num_layers)] |
| ) |
|
|
| self.norm = RMSNorm(dim) |
| self.head = nn.Linear(dim, vocab_size, bias=False) |
|
|
| |
| self.head.weight = self.token_emb.weight |
|
|
| self._init_weights() |
|
|
| def _init_weights(self): |
| for module in self.modules(): |
| if isinstance(module, nn.Linear): |
| nn.init.normal_(module.weight, mean=0.0, std=0.02) |
| if module.bias is not None: |
| nn.init.zeros_(module.bias) |
| elif isinstance(module, nn.Embedding): |
| nn.init.normal_(module.weight, mean=0.0, std=0.02) |
|
|
| def forward(self, x, mask=None): |
| |
| x = self.token_emb(x) |
| x = self.dropout(x) |
|
|
| for block in self.layers: |
| x = block(x, mask) |
|
|
| x = self.norm(x) |
| logits = self.head(x) |
| return logits |
|
|
|
|
if __name__ == "__main__":
    # Smoke test: build the default-sized model, run one random batch
    # through it and print the shape of the result.
    vocab, batch, seq_len = 256, 2, 1024

    net = BaselineTransformer(
        vocab_size=vocab,
        dim=512,
        num_heads=8,
        mlp_ratio=4.0,
        dropout=0.1,
        num_layers=12,
    )

    tokens = torch.randint(0, vocab, (batch, seq_len))

    # Lower-triangular 0/1 pattern (uint8) with two leading singleton axes.
    # Presumably a causal mask broadcast over batch and heads -- how it is
    # consumed is decided inside TransformerBlock.
    causal = torch.ones(1, 1, seq_len, seq_len, dtype=torch.uint8).tril()

    print(net(tokens, mask=causal).shape)
|
|