| | import torch |
| | import torch.nn.functional as F |
| | from torch import nn |
| | from modules import Encoder, Decoder |
| | from modules import Codebook |
| |
|
| |
|
class VQBASE(nn.Module):
    """Vector-quantised autoencoder skeleton.

    Data flow, as wired in this class::

        Encoder -> 1x1 Conv2d + SyncBatchNorm -> Codebook
                -> 1x1 Conv2d -> Decoder

    ``Encoder``, ``Decoder`` and ``Codebook`` come from the project-local
    ``modules`` package; their internals are not visible from here.
    """

    def __init__(self, ddconfig, n_embed, embed_dim, init_steps, reservoir_size):
        """Build the encoder, decoder, codebook and the two 1x1 projections.

        Args:
            ddconfig: Mapping expanded as keyword arguments into both
                ``Encoder`` and ``Decoder``. Must contain ``"z_channels"``,
                which is used as the channel count on the latent side of
                the two 1x1 convolutions.
            n_embed: First positional argument of ``Codebook``
                (presumably the number of codebook entries -- confirm).
            embed_dim: Channel count on the codebook side of the 1x1
                convolutions; also passed to ``Codebook``.
            init_steps: Forwarded unchanged to ``Codebook``.
            reservoir_size: Forwarded unchanged to ``Codebook``.
        """
        super().__init__()
        self.encoder = Encoder(**ddconfig)
        self.decoder = Decoder(**ddconfig)
        # NOTE(review): beta is hard-coded to 0.25; presumably the
        # commitment-loss weight -- confirm against Codebook.
        self.quantize = Codebook(
            n_embed,
            embed_dim,
            beta=0.25,
            init_steps=init_steps,
            reservoir_size=reservoir_size,
        )
        # Project encoder output (z_channels) to the codebook width and
        # normalise it before quantisation.
        self.quant_conv = nn.Sequential(
            nn.Conv2d(ddconfig["z_channels"], embed_dim, 1),
            nn.SyncBatchNorm(embed_dim),
        )
        # Project quantised features back to the width the decoder expects.
        self.post_quant_conv = nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)

    def encode(self, x):
        """Encode ``x`` and quantise the projected features.

        Returns:
            The 3-tuple ``(quant, emb_loss, info)`` exactly as produced by
            ``self.quantize``.
        """
        h = self.encoder(x)
        h = self.quant_conv(h)
        quant, emb_loss, info = self.quantize(h)
        return quant, emb_loss, info

    def decode(self, quant):
        """Map quantised features back through the 1x1 conv and the decoder."""
        quant = self.post_quant_conv(quant)
        dec = self.decoder(quant)
        return dec

    def decode_code(self, code_b):
        """Decode from codebook codes.

        ``code_b`` is passed to ``self.quantize.embed_code``; the result is
        then decoded with :meth:`decode`. (Presumably ``code_b`` holds
        codebook indices -- confirm against ``Codebook.embed_code``.)
        """
        quant_b = self.quantize.embed_code(code_b)
        dec = self.decode(quant_b)
        return dec

    def forward(self, input):
        """Run a full encode -> quantise -> decode pass.

        Returns:
            ``(dec, diff)`` where ``dec`` is the decoder output and ``diff``
            is the ``emb_loss`` value returned by :meth:`encode`.
        """
        # encode() returns three values; the third (info) is not needed
        # here. Unpacking only two raised ValueError on every call.
        quant, diff, _ = self.encode(input)
        dec = self.decode(quant)
        return dec, diff
| |
|
| |
|