efittschen committed on
Commit 1205cc7 · verified · 1 Parent(s): b83ee0f

Upload MuonGPTForCausalLM

Files changed (5)
  1. README.md +199 -0
  2. config.json +18 -0
  3. generation_config.json +4 -0
  4. model.safetensors +3 -0
  5. modeling_nano_gpt.py +353 -0
README.md ADDED
@@ -0,0 +1,199 @@
+---
+library_name: transformers
+tags: []
+---
+
+# Model Card for Model ID
+
+<!-- Provide a quick summary of what the model is/does. -->
+
+
+
+## Model Details
+
+### Model Description
+
+<!-- Provide a longer summary of what this model is. -->
+
+This is the model card of a 🤗 transformers model that has been pushed on the Hub. This model card has been automatically generated.
+
+- **Developed by:** [More Information Needed]
+- **Funded by [optional]:** [More Information Needed]
+- **Shared by [optional]:** [More Information Needed]
+- **Model type:** [More Information Needed]
+- **Language(s) (NLP):** [More Information Needed]
+- **License:** [More Information Needed]
+- **Finetuned from model [optional]:** [More Information Needed]
+
+### Model Sources [optional]
+
+<!-- Provide the basic links for the model. -->
+
+- **Repository:** [More Information Needed]
+- **Paper [optional]:** [More Information Needed]
+- **Demo [optional]:** [More Information Needed]
+
+## Uses
+
+<!-- Address questions around how the model is intended to be used, including the foreseeable users of the model and those affected by the model. -->
+
+### Direct Use
+
+<!-- This section is for the model use without fine-tuning or plugging into a larger ecosystem/app. -->
+
+[More Information Needed]
+
+### Downstream Use [optional]
+
+<!-- This section is for the model use when fine-tuned for a task, or when plugged into a larger ecosystem/app -->
+
+[More Information Needed]
+
+### Out-of-Scope Use
+
+<!-- This section addresses misuse, malicious use, and uses that the model will not work well for. -->
+
+[More Information Needed]
+
+## Bias, Risks, and Limitations
+
+<!-- This section is meant to convey both technical and sociotechnical limitations. -->
+
+[More Information Needed]
+
+### Recommendations
+
+<!-- This section is meant to convey recommendations with respect to the bias, risk, and technical limitations. -->
+
+Users (both direct and downstream) should be made aware of the risks, biases and limitations of the model. More information needed for further recommendations.
+
+## How to Get Started with the Model
+
+Use the code below to get started with the model.
+
+[More Information Needed]
+
+## Training Details
+
+### Training Data
+
+<!-- This should link to a Dataset Card, perhaps with a short stub of information on what the training data is all about as well as documentation related to data pre-processing or additional filtering. -->
+
+[More Information Needed]
+
+### Training Procedure
+
+<!-- This relates heavily to the Technical Specifications. Content here should link to that section when it is relevant to the training procedure. -->
+
+#### Preprocessing [optional]
+
+[More Information Needed]
+
+
+#### Training Hyperparameters
+
+- **Training regime:** [More Information Needed] <!--fp32, fp16 mixed precision, bf16 mixed precision, bf16 non-mixed precision, fp16 non-mixed precision, fp8 mixed precision -->
+
+#### Speeds, Sizes, Times [optional]
+
+<!-- This section provides information about throughput, start/end time, checkpoint size if relevant, etc. -->
+
+[More Information Needed]
+
+## Evaluation
+
+<!-- This section describes the evaluation protocols and provides the results. -->
+
+### Testing Data, Factors & Metrics
+
+#### Testing Data
+
+<!-- This should link to a Dataset Card if possible. -->
+
+[More Information Needed]
+
+#### Factors
+
+<!-- These are the things the evaluation is disaggregating by, e.g., subpopulations or domains. -->
+
+[More Information Needed]
+
+#### Metrics
+
+<!-- These are the evaluation metrics being used, ideally with a description of why. -->
+
+[More Information Needed]
+
+### Results
+
+[More Information Needed]
+
+#### Summary
+
+
+
+## Model Examination [optional]
+
+<!-- Relevant interpretability work for the model goes here -->
+
+[More Information Needed]
+
+## Environmental Impact
+
+<!-- Total emissions (in grams of CO2eq) and additional considerations, such as electricity usage, go here. Edit the suggested text below accordingly -->
+
+Carbon emissions can be estimated using the [Machine Learning Impact calculator](https://mlco2.github.io/impact#compute) presented in [Lacoste et al. (2019)](https://arxiv.org/abs/1910.09700).
+
+- **Hardware Type:** [More Information Needed]
+- **Hours used:** [More Information Needed]
+- **Cloud Provider:** [More Information Needed]
+- **Compute Region:** [More Information Needed]
+- **Carbon Emitted:** [More Information Needed]
+
+## Technical Specifications [optional]
+
+### Model Architecture and Objective
+
+[More Information Needed]
+
+### Compute Infrastructure
+
+[More Information Needed]
+
+#### Hardware
+
+[More Information Needed]
+
+#### Software
+
+[More Information Needed]
+
+## Citation [optional]
+
+<!-- If there is a paper or blog post introducing the model, the APA and Bibtex information for that should go in this section. -->
+
+**BibTeX:**
+
+[More Information Needed]
+
+**APA:**
+
+[More Information Needed]
+
+## Glossary [optional]
+
+<!-- If relevant, include terms and calculations in this section that can help readers understand the model or model card. -->
+
+[More Information Needed]
+
+## More Information [optional]
+
+[More Information Needed]
+
+## Model Card Authors [optional]
+
+[More Information Needed]
+
+## Model Card Contact
+
+[More Information Needed]
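
The "How to Get Started with the Model" section of this README is still the template placeholder. As a minimal sketch only: because `config.json` in this commit maps `AutoModelForCausalLM` to the custom class in `modeling_nano_gpt.py`, loading through `transformers` would presumably need `trust_remote_code=True`. The helper name `load_muon_gpt` and the `repo_id` argument are hypothetical; the actual repository ID is not shown in this commit.

```python
def load_muon_gpt(repo_id: str):
    """Load the custom MuonGPTForCausalLM class from a Hub repository.

    `repo_id` is a placeholder supplied by the caller; this commit does not
    state the repository name.
    """
    # Imported lazily so that defining the helper does not require transformers.
    from transformers import AutoModelForCausalLM

    # trust_remote_code=True lets transformers execute modeling_nano_gpt.py,
    # which the auto_map entry in config.json points to.
    return AutoModelForCausalLM.from_pretrained(repo_id, trust_remote_code=True)
```

Only enable `trust_remote_code` for repositories whose code you have read, since it runs the repository's Python file locally.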
config.json ADDED
@@ -0,0 +1,18 @@
+{
+  "architectures": [
+    "MuonGPTForCausalLM"
+  ],
+  "auto_map": {
+    "AutoConfig": "modeling_nano_gpt.MuonGPTConfig",
+    "AutoModelForCausalLM": "modeling_nano_gpt.MuonGPTForCausalLM"
+  },
+  "block_size": 128,
+  "eos_token_id": 50256,
+  "model_dim": 768,
+  "model_type": "muon-gpt",
+  "num_heads": 6,
+  "num_layers": 12,
+  "torch_dtype": "float32",
+  "transformers_version": "4.51.3",
+  "vocab_size": 16000
+}
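
A few quantities follow directly from these config values, and a quick check makes them explicit. The sketch below copies `next_multiple_of_n` from `modeling_nano_gpt.py` in this commit; the `config` dict and the variable names are illustrative only.

```python
# Values copied from config.json above.
config = {
    "block_size": 128,
    "eos_token_id": 50256,
    "model_dim": 768,
    "num_heads": 6,
    "num_layers": 12,
    "vocab_size": 16000,
}


def next_multiple_of_n(v, *, n):
    # Same helper as in modeling_nano_gpt.py: smallest multiple of n that is >= v.
    return next(x for x in range(n, int(v) + 1 + n, n) if x >= v)


# Per-head width used by the rotary embedding (model_dim // num_heads).
head_dim = config["model_dim"] // config["num_heads"]

# Number of output rows in lm_head after rounding the vocabulary up to a multiple of 128.
padded_vocab = next_multiple_of_n(config["vocab_size"], n=128)

# Whether the EOS id is a valid row of the token embedding table.
eos_in_vocab = 0 <= config["eos_token_id"] < config["vocab_size"]

print(head_dim, padded_vocab, eos_in_vocab)
```

With these values `head_dim` is 128, `padded_vocab` stays at 16000, and `eos_in_vocab` is `False`. The last point looks worth a review: the wrapper in `modeling_nano_gpt.py` pads flattened inputs with `eos_token_id`, so any input whose length is not a multiple of `block_size` would appear to index outside the 16000-row embedding table.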
generation_config.json ADDED
@@ -0,0 +1,4 @@
+{
+  "_from_model_config": true,
+  "transformers_version": "4.51.3"
+}
model.safetensors ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:fe82bfa4d18b52e61775f74b670aee2913e815648fe02373087de180ec905cfa
+size 576069056
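
The LFS pointer size can be cross-checked against the architecture. The sketch below counts the parameters that `GPT.__init__` in `modeling_nano_gpt.py` creates for the values in `config.json`, assuming every parameter is stored exactly once in float32 and that the non-persistent rotary buffers are not saved. The function name and its arguments are illustrative, not part of the commit.

```python
def count_parameters(vocab_size: int, num_layers: int, model_dim: int,
                     attn_free_layer: int = 7, num_value_tables: int = 3) -> int:
    """Count parameters created by GPT.__init__ in modeling_nano_gpt.py."""
    padded_vocab = -(-vocab_size // 128) * 128           # lm_head rows, rounded up to 128
    embed = vocab_size * model_dim                       # token embedding
    value_embeds = num_value_tables * vocab_size * model_dim
    lm_head = padded_vocab * model_dim
    mlp = 2 * (model_dim * 4 * model_dim)                # c_fc + c_proj, no biases
    attn = 3 * model_dim * model_dim + model_dim * model_dim + 2  # qkv_w + c_proj + lambdas
    block_lambdas = 2                                    # per-block residual mixing weights
    attn_layers = num_layers - (1 if 0 <= attn_free_layer < num_layers else 0)
    skip_weights = num_layers - num_layers // 2          # one weight per decoder layer
    return (embed + value_embeds + lm_head
            + num_layers * (mlp + block_lambdas)
            + attn_layers * attn
            + skip_weights)


params = count_parameters(vocab_size=16000, num_layers=12, model_dim=768)
tensor_bytes = 4 * params        # float32, per "torch_dtype" in config.json
file_bytes = 576_069_056         # size from the LFS pointer above
print(params, tensor_bytes, file_bytes - tensor_bytes)
```

Under these assumptions the count comes to 144,015,412 parameters, or 576,061,648 bytes of tensor data. The remaining 7,408 bytes would be consistent with the safetensors header, though that attribution is an inference rather than something the commit states.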
modeling_nano_gpt.py ADDED
@@ -0,0 +1,353 @@
+import torch, torch.nn as nn, torch.nn.functional as F
+from dataclasses import dataclass
+from torch import Tensor, nn
+from torch.nn.attention.flex_attention import BlockMask, flex_attention
+
+def lm_head_plain(x: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
+    return F.linear(x.to(torch.bfloat16), w.to(torch.bfloat16))
+
+def norm(x):
+    return F.rms_norm(x, (x.size(-1),))
+
+
+class CastedLinear(nn.Linear):
+    def __init__(self, in_features: int, out_features: int):
+        super().__init__(in_features, out_features, bias=False)
+
+    def reset_parameters(self) -> None:
+        std = 0.5 * (self.in_features ** -0.5) # 0.5 is a bit better than the default 1/sqrt(3)
+        bound = (3 ** 0.5) * std
+        with torch.no_grad():
+            self.weight.uniform_(-bound, bound)
+
+    def forward(self, x):
+        return F.linear(x, self.weight.type_as(x))
+
+class Rotary(nn.Module):
+    def __init__(self, dim: int, max_seq_len=65536):
+        super().__init__()
+        # half-truncate RoPE by @YouJiacheng (w/ base freq tuning)
+        angular_freq = (1 / 1024) ** torch.linspace(0, 1, steps=dim//4, dtype=torch.float32)
+        angular_freq = torch.cat([angular_freq, angular_freq.new_zeros(dim//4)])
+        t = torch.arange(max_seq_len, dtype=torch.float32)
+        theta = torch.einsum("i,j -> ij", t, angular_freq)
+        self.cos = nn.Buffer(theta.cos(), persistent=False)
+        self.sin = nn.Buffer(theta.sin(), persistent=False)
+
+    def forward(self, x_BTHD: Tensor):
+        assert self.cos.size(0) >= x_BTHD.size(-3)
+        cos, sin = self.cos[None, :x_BTHD.size(-3), None, :], self.sin[None, :x_BTHD.size(-3), None, :]
+        x1, x2 = x_BTHD.to(dtype=torch.float32).chunk(2, dim=-1)
+        y1 = x1 * cos + x2 * sin
+        y2 = x1 * (-sin) + x2 * cos
+        return torch.cat((y1, y2), 3).type_as(x_BTHD)
+
+class CausalSelfAttention(nn.Module):
+    def __init__(self, dim: int, num_heads: int, layer_idx: int):
+        super().__init__()
+        assert dim % num_heads == 0
+        self.num_heads = num_heads
+        std = 0.5 * (dim ** -0.5)
+        bound = (3 ** 0.5) * std # improved init scale by @YouJiacheng
+        # merged QKV weights: suggested by many, implemented by @fernbear.bsky.social, and further improved by @YouJiacheng
+        # https://x.com/hi_tysam/status/1879699187107033311
+        self.qkv_w = nn.Parameter(torch.empty(3, dim, dim).uniform_(-bound, bound))
+        self.lambdas = nn.Parameter(torch.tensor([0.5, 0.5]))
+        self.rotary = Rotary(dim // num_heads) # dim // num_heads = head_dim
+        self.c_proj = CastedLinear(dim, dim)
+        self.c_proj.weight.detach().zero_() # zero init suggested by @Grad62304977
+        # scale the attention logits by given constant, instead of the default head_dim**-0.5, by @leloykun
+        # inspired by learnable scalars used by @brendanh0gan https://x.com/hi_tysam/status/1879693583898591283
+        self.attn_scale = 0.12
+
+    def forward(self, x: Tensor, ve: Tensor | None, block_mask: BlockMask):
+        B, T = x.size(0), x.size(1) # batch size, sequence length
+        assert B == 1, "Must use batch size = 1 for FlexAttention"
+        q, k, v = F.linear(x, self.qkv_w.flatten(end_dim=1).type_as(x)).view(B, T, 3*self.num_heads, -1).chunk(3, dim=-2)
+        if ve is not None:
+            v = self.lambdas[0] * v + self.lambdas[1] * ve.view_as(v) # @KoszarskyB & @Grad62304977
+        else: # skip mid-layers token value embeddings by @YouJiacheng
+            v = self.lambdas[0] * v
+        q, k = norm(q), norm(k) # QK norm @Grad62304977
+        q, k = self.rotary(q), self.rotary(k)
+        y = flex_attention(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), block_mask=block_mask, scale=self.attn_scale)
+        y = y.transpose(1, 2).contiguous().view_as(x) # re-assemble all head outputs side by side
+        y = self.c_proj(y)
+        return y
+
+class MLP(nn.Module):
+    def __init__(self, dim):
+        super().__init__()
+        self.c_fc = CastedLinear(dim, 4 * dim)
+        self.c_proj = CastedLinear(4 * dim, dim)
+        self.c_proj.weight.detach().zero_() # zero init suggested by @Grad62304977
+
+    def forward(self, x):
+        x = self.c_fc(x)
+        x = F.relu(x).square() # https://arxiv.org/abs/2109.08668v2; ~1-2% better than GELU; suggested by @SKYLINEZ007 and @Grad62304977
+        x = self.c_proj(x)
+        return x
+
+class Block(nn.Module):
+    def __init__(self, model_dim: int, num_heads: int, layer_idx: int):
+        super().__init__()
+        # skip attention of blocks.7 (the 8th layer) by @YouJiacheng
+        self.attn = CausalSelfAttention(model_dim, num_heads, layer_idx) if layer_idx != 7 else None
+        self.mlp = MLP(model_dim)
+        self.lambdas = nn.Parameter(torch.tensor([1., 0.]))
+
+    def forward(self, x, ve, x0, block_mask):
+        x = self.lambdas[0] * x + self.lambdas[1] * x0
+        if self.attn is not None:
+            x = x + self.attn(norm(x), ve, block_mask)
+        x = x + self.mlp(norm(x))
+        return x
+
+class ValueEmbedding(nn.Module):
+    def __init__(self, num_embeddings: int, embedding_dim: int, layer_count: int = 12):
+        super().__init__()
+        self.layer_count = layer_count
+        self.embed = nn.ModuleList([nn.Embedding(num_embeddings, embedding_dim) for _ in range(3)])
+
+    def forward(self, input_seq) -> list[Tensor | None]:
+        ve = [emb(input_seq) for emb in self.embed]
+        # 012 ... 012 structure on token value embeddings by @YouJiacheng, improved on @leloykun's U-net structure
+        new_ve = [None for _ in range(self.layer_count)]
+        new_ve[0] = ve[0]
+        new_ve[1] = ve[1]
+        new_ve[2] = ve[2]
+        new_ve[-1] = ve[2]
+        new_ve[-2] = ve[1]
+        new_ve[-3] = ve[0]
+        #ve = [ve[0], ve[1], ve[2], None, None, None, None, None, None, ve[0], ve[1], ve[2]]
+        return new_ve
+
+# -----------------------------------------------------------------------------
+# The main model
+
+def next_multiple_of_n(v: float | int, *, n: int):
+    return next(x for x in range(n, int(v) + 1 + n, n) if x >= v)
+
+class GPT(nn.Module):
+    def __init__(self, vocab_size: int, num_layers: int, num_heads: int, model_dim: int, eos_token_id: int = 50256, block_size: int = 128):
+        super().__init__()
+        self.eos_token_id = eos_token_id
+        self.block_size = block_size
+        self.embed = nn.Embedding(vocab_size, model_dim)
+        # token value embeddings by @KoszarskyB - inspired by @Grad62304977's value residual implementation following https://arxiv.org/abs/2410.17897
+        self.value_embeds = ValueEmbedding(vocab_size, model_dim, layer_count=num_layers)
+        self.blocks = nn.ModuleList([Block(model_dim, num_heads, layer_idx) for layer_idx in range(num_layers)])
+        # U-net design by @brendanh0gan
+        self.num_encoder_layers = num_layers // 2 # Half of the layers for encoder
+        self.num_decoder_layers = num_layers - self.num_encoder_layers # Remaining for decoder
+        # Add learnable skip connection weights for decoder layers
+        self.skip_weights = nn.Parameter(torch.ones(self.num_decoder_layers))
+        # there are only 50257 unique GPT-2 tokens; we extend to nearest multiple of 128 for efficiency.
+        # suggested to me by @Grad62304977. this originates from Karpathy's experiments.
+        self.lm_head = CastedLinear(model_dim, next_multiple_of_n(vocab_size, n=128))
+        self.lm_head.weight.detach().zero_() # @Grad62304977
+
+    def forward(self, input_seq: Tensor, target_seq: Tensor, sliding_window_num_blocks: Tensor):
+        BLOCK_SIZE = self.block_size
+        assert input_seq.ndim == 1
+        assert len(input_seq) % BLOCK_SIZE == 0
+        NUM_BLOCKS = len(input_seq) // BLOCK_SIZE
+        docs = (input_seq == self.eos_token_id).cumsum(0)
+        docs_low = docs.view(-1, BLOCK_SIZE)[:, 0].contiguous()
+        docs_high = docs.view(-1, BLOCK_SIZE)[:, -1].contiguous()
+
+        def document_causal(b, h, q_idx, kv_idx):
+            causal_mask = q_idx >= kv_idx
+            document_mask = docs[q_idx] == docs[kv_idx]
+            return causal_mask & document_mask
+
+        def dense_to_ordered(dense_mask: Tensor):
+            num_blocks = dense_mask.sum(dim=-1, dtype=torch.int32)
+            indices = dense_mask.argsort(dim=-1, descending=False, stable=True).flip(-1).to(torch.int32)
+            return num_blocks[None, None].contiguous(), indices[None, None].contiguous()
+
+        # manual block mask creation by @YouJiacheng
+        def create_doc_swc_block_masks(sliding_window_num_blocks: Tensor):
+            kv_idx = block_idx = torch.arange(NUM_BLOCKS, dtype=torch.int32, device=input_seq.device)  # follow the input's device instead of hard-coding "cuda"
+            q_idx = block_idx[:, None]
+            causal_bm = q_idx >= kv_idx
+            causal_full_bm = q_idx > kv_idx
+            document_bm = (docs_low[:, None] <= docs_high) & (docs_low <= docs_high[:, None])
+            document_full_bm = (docs_low[:, None] == docs_high) & (docs_low == docs_high[:, None])
+            nonzero_bm = causal_bm & document_bm
+            full_bm = causal_full_bm & document_full_bm
+            kv_num_blocks, kv_indices = dense_to_ordered(nonzero_bm & ~full_bm)
+            full_kv_num_blocks, full_kv_indices = dense_to_ordered(full_bm)
+            def build_bm(sw_num_blocks: Tensor) -> BlockMask:
+                return BlockMask.from_kv_blocks(
+                    torch.clamp_max(kv_num_blocks, torch.clamp_min(sw_num_blocks - full_kv_num_blocks, 1)),
+                    kv_indices,
+                    torch.clamp_max(full_kv_num_blocks, sw_num_blocks - 1),
+                    full_kv_indices,
+                    BLOCK_SIZE=BLOCK_SIZE,
+                    mask_mod=document_causal,
+                )
+            return build_bm(sliding_window_num_blocks), build_bm(sliding_window_num_blocks // 2)
+
+        # Long-short SWA block masks by @leloykun & @YouJiacheng, adapted from suggestion by @Grad62304977, following Gemma 2 paper
+        long_bm, short_bm = create_doc_swc_block_masks(sliding_window_num_blocks)
+
+        x = x0 = norm(self.embed(input_seq)[None]) # use of norm here by @Grad62304977
+        ve = self.value_embeds(input_seq)
+        assert len(ve) == len(self.blocks), f"expected {len(self.blocks)} value embeddings, got {len(ve)}"
+        ve_enc, ve_dec = ve[:self.num_encoder_layers], ve[self.num_encoder_layers:]
+        assert len(ve_enc) == self.num_encoder_layers and len(ve_dec) == self.num_decoder_layers
+
+        # Store outputs for U-Net skip connections
+        skip_connections = []
+        # Encoder pass - process only the first half of the blocks
+        block_masks = [long_bm if i % 2 == 0 else short_bm for i in range(self.num_encoder_layers)]
+        for i in range(self.num_encoder_layers):
+            x = self.blocks[i](x, ve_enc[i], x0, block_masks[i])
+            skip_connections.append(x)
+        # Decoder pass - process the remaining blocks with weighted skip connections
+        block_masks.reverse()
+        for i in range(self.num_decoder_layers):
+            x = x + self.skip_weights[i] * skip_connections.pop()
+            x = self.blocks[self.num_encoder_layers + i](x, ve_dec[i], x0, block_masks[i])
+        x = norm(x)
+        logits = lm_head_plain(x, self.lm_head.weight) if self.training else self.lm_head(x)
+        # @Grad62304977 added tanh softcapping following Gemma 2 paper, @KoszarskyB reduced it from 30 to 15, @YouJiacheng shifted it by +15 (2*sigmoid(2*x)=tanh(x)+1)
+        logits = 30 * torch.sigmoid(logits.float() / 7.5)
+        loss = F.cross_entropy(logits.view(-1, logits.size(-1)), target_seq)
+        return loss, logits
+
+def load_from_checkpoint(weights, **config):
+    model = GPT(**config)
+    model.load_state_dict(weights, strict=True)
+    return model
+
+
+from transformers import PretrainedConfig
+
+class MuonGPTConfig(PretrainedConfig):
+    model_type = "muon-gpt"
+    auto_map = {
+        "AutoConfig" : "modeling_nano_gpt.MuonGPTConfig",
+        "AutoModelForCausalLM": "modeling_nano_gpt.MuonGPTForCausalLM"
+    }
+
+
+    def __init__(self,
+                 vocab_size=50257,
+                 num_layers=12,
+                 num_heads=6,
+                 model_dim=768,
+                 eos_token_id=50256,
+                 block_size=128,
+                 **kwargs):
+        super().__init__(**kwargs)
+        self.vocab_size = vocab_size
+        self.num_layers = num_layers
+        self.num_heads = num_heads
+        self.model_dim = model_dim
+        self.eos_token_id = eos_token_id
+        self.block_size = block_size
+
+
+import torch, torch.nn.functional as F
+from torch import nn
+from transformers import PreTrainedModel, GenerationMixin
+from transformers.modeling_outputs import CausalLMOutput
+
+
+from typing import Optional, Tuple
+BLOCK_SIZE = 128
+PAD_TOKEN_ID = 50256 # GPT-2 <|endoftext|>
+
+def _pad_to_multiple(x: torch.Tensor, multiple: int, value: int) -> Tuple[torch.Tensor, int]:
+    """Pad 1-D tensor on the right so that len(x) is a multiple of `multiple`."""
+    pad_len = (-x.size(0)) % multiple
+    if pad_len:
+        pad = x.new_full((pad_len,), value)
+        x = torch.cat([x, pad], dim=0)
+    return x, pad_len
+
+class MuonGPTForCausalLM(PreTrainedModel, GenerationMixin):
+    config_class = MuonGPTConfig
+    supports_gradient_checkpointing = False
+
+
+    def __init__(self, config: MuonGPTConfig):
+        super().__init__(config)
+        self.gpt = GPT(
+            vocab_size = config.vocab_size,
+            num_layers = config.num_layers,
+            num_heads = config.num_heads,
+            model_dim = config.model_dim,
+            eos_token_id = config.eos_token_id,
+            block_size = config.block_size,
+        )
+        self.post_init() # HF helper
+
+    # ---------------------------------------------------------------------
+    # GenerationMixin helpers
+    # ---------------------------------------------------------------------
+    def get_input_embeddings(self):
+        return self.gpt.embed
+    def set_input_embeddings(self, new_emb):
+        self.gpt.embed = new_emb
+    def prepare_inputs_for_generation(self, input_ids, **kwargs):
+        return {"input_ids": input_ids}
+
+    # ---------------------------------------------------------------------
+    # Forward = pad → flatten → call GPT → reshape back
+    # ---------------------------------------------------------------------
+    def forward(
+        self,
+        input_ids: torch.Tensor, # (B, T)
+        attention_mask: Optional[torch.Tensor] = None,
+        labels: Optional[torch.Tensor] = None,
+        **kwargs
+    ) -> CausalLMOutput:
+
+        B, T = input_ids.shape
+        orig_tokens = B * T
+        device = input_ids.device
+
+        BLOCK_SIZE = self.gpt.block_size
+        PAD_TOKEN_ID = self.gpt.eos_token_id
+
+        # flatten & pad
+        flat_inp = input_ids.view(-1) # (B*T,)
+        flat_inp, pad_len = _pad_to_multiple(flat_inp, BLOCK_SIZE, PAD_TOKEN_ID)
+
+
+        if labels is None:
+            flat_lbl = flat_inp.clone()
+        else:
+            flat_lbl = labels.view(-1)
+            flat_lbl, _ = _pad_to_multiple(flat_lbl, BLOCK_SIZE, PAD_TOKEN_ID)
+
+        # dummy sliding-window argument (you can do better if you want)
+        sw_num_blocks = torch.tensor( flat_inp.size(0) // BLOCK_SIZE,
+                                      dtype=torch.int32, device=device )
+
+        # call the original training-time model
+        _, logits = self.gpt(flat_inp, flat_lbl, sw_num_blocks) # shape: (1, N_padded, padded_vocab)
+
+        logits = logits[:, :orig_tokens]
+
+        vocab = self.config.vocab_size
+        if logits.size(-1) != vocab:
+            logits = logits[:, :, :vocab]
+        logits = logits.view(B, T, -1)
+
+        loss = None
+        if labels is not None:
+            loss = F.cross_entropy(
+                logits.view(-1, logits.size(-1)),
+                labels.view(-1),
+                ignore_index=PAD_TOKEN_ID,
+                reduction="mean",
+            )
+
+        return CausalLMOutput(
+            loss = loss,
+            logits = logits,
+        )
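
The soft-capping comment in `GPT.forward` relies on the identity `2*sigmoid(2*x) = tanh(x) + 1`, which means `30 * sigmoid(z / 7.5)` equals `15 * (tanh(z / 15) + 1)`: a tanh cap of 15, shifted up by 15 so logits land in (0, 30). The sketch below checks this numerically with the standard library only; the function names are illustrative and do not appear in the commit.

```python
import math


def softcap_sigmoid(z: float) -> float:
    # Form used in GPT.forward: 30 * sigmoid(z / 7.5).
    return 30.0 / (1.0 + math.exp(-z / 7.5))


def softcap_tanh(z: float) -> float:
    # Equivalent form described in the comment: tanh cap of 15, shifted by +15.
    return 15.0 * (math.tanh(z / 15.0) + 1.0)


# Compare the two forms over a range of raw logit values.
samples = [x / 2.0 for x in range(-100, 101)]  # -50.0 ... 50.0
max_gap = max(abs(softcap_sigmoid(z) - softcap_tanh(z)) for z in samples)
print(max_gap)
```

Both forms map a raw logit of 0 to 15 and saturate toward 0 and 30. Because the constant shift is applied to every logit equally, it does not change the softmax probabilities.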