diff --git a/README.md b/README.md index 2fb0f1c..577ee39 100644 --- a/README.md +++ b/README.md @@ -6,7 +6,18 @@ Arxiv Link: https://arxiv.org/abs/2112.10961 Project Page: https://semcomm.github.io/ntscc/ -## Evaluation +## Usage + +## Pretrained Models + +Pretrained models (optimized for MSE) trained from scratch using randomly chose 300k images from the OpenImages dataset. + +Other pretrained models will be released successively. + +Note: We reorganize code and the performances are slightly different from the paper's. + +### RD curves + ## Citation If you find the code helpful in your resarch or work, please cite: diff --git a/channel/channel.py b/channel/channel.py index e622714..b2e67d3 100644 --- a/channel/channel.py +++ b/channel/channel.py @@ -22,20 +22,25 @@ def gaussian_noise_layer(self, input_layer, std): noise = noise_real + 1j * noise_imag return input_layer + noise - def forward(self, input, avg_pwr, power=1): - channel_tx = np.sqrt(power) * input / torch.sqrt(avg_pwr * 2) + def forward(self, input, avg_pwr=None, power=1): + if avg_pwr is None: + avg_pwr = torch.mean(input ** 2) + channel_tx = np.sqrt(power) * input / torch.sqrt(avg_pwr * 2) + else: + channel_tx = np.sqrt(power) * input / torch.sqrt(avg_pwr * 2) input_shape = channel_tx.shape channel_in = channel_tx.reshape(-1) channel_in = channel_in[::2] + channel_in[1::2] * 1j + channel_usage = channel_in.numel() channel_output = self.channel_forward(channel_in) channel_rx = torch.zeros_like(channel_tx.reshape(-1)) channel_rx[::2] = torch.real(channel_output) channel_rx[1::2] = torch.imag(channel_output) channel_rx = channel_rx.reshape(input_shape) - return channel_rx * torch.sqrt(avg_pwr * 2) + return channel_rx * torch.sqrt(avg_pwr * 2), channel_usage def channel_forward(self, channel_in): - if self.chan_type == 0 or self.chan_type == 'none': + if self.chan_type == 0 or self.chan_type == 'noiseless': return channel_in elif self.chan_type == 1 or self.chan_type == 'awgn': diff --git a/layer/jscc_encoder.py b/layer/jscc_encoder.py index ec686ed..4594702 100644 --- a/layer/jscc_encoder.py +++ b/layer/jscc_encoder.py @@ -37,7 +37,7 @@ def forward(self, x, indexes): x_BLC_masked = (torch.matmul(x_BLC.unsqueeze(2), w).squeeze() + b) * mask x_masked = x_BLC_masked.reshape(B, H, W, -1).permute(0, 3, 1, 2) mask_BCHW = mask.reshape(B, H, W, -1).permute(0, 3, 1, 2) - return x_masked, mask_BCHW, indexes + return x_masked, mask_BCHW def update_resolution(self, H, W, device): self.H = H @@ -70,8 +70,23 @@ def __init__(self, embed_dim=256, depths=[1, 1, 1], input_resolution=(16, 16), self.refine = Mlp(embed_dim * 2, embed_dim * 8, embed_dim) self.norm = norm_layer(embed_dim) - def forward(self, x, px, hx, eta): + def forward(self, x, px, eta): + """ + JSCCEncoder encodes latent representations to variable length channel-input vector. + + Arguments: + x: Latent representation (patch embeddings), shape of BxCxHxW, also viewed as Bx(HxW)xC. + px: Estimated probability of x, shape of BxCxHxW, also viewed as Bx(HxW)xC. + eta: Scaling factor from entropy to channel bandwidth cost. + + Returns: + s_masked: Channel-input vector. + indexes: The length of each patch embedding, shape of BxHxW. + mask: Binary mask, shape of BxCxHxW. + """ + B, C, H, W = x.size() + hx = torch.clamp_min(-torch.log(px) / math.log(2), 0) symbol_num = torch.sum(hx, dim=1).flatten(0) * eta x_BLC = x.flatten(2).permute(0, 2, 1) px_BLC = px.flatten(2).permute(0, 2, 1) @@ -84,8 +99,8 @@ def forward(self, x, px, hx, eta): x_BLC = layer(x_BLC.contiguous()) x_BLC = self.norm(x_BLC) x_BCHW = x_BLC.reshape(B, H, W, C).permute(0, 3, 1, 2) - x_masked, mask, indexes = self.rate_adaption(x_BCHW, indexes) - return x_masked, mask, indexes + s_masked, mask = self.rate_adaption(x_BCHW, indexes) + return s_masked, mask, indexes def update_resolution(self, H, W): self.input_resolution = (H, W) diff --git a/main.py b/main.py index 3817243..b4fde0e 100644 --- a/main.py +++ b/main.py @@ -1,6 +1,6 @@ import torch import os -os.environ["CUDA_VISIBLE_DEVICES"] = "5" +os.environ["CUDA_VISIBLE_DEVICES"] = "7" from net.NTSCC_Hyperior import NTSCC_Hyperprior import torch.optim as optim from utils import * @@ -45,14 +45,11 @@ def train_one_epoch(): losses.update(loss.item()) bppys.update(bpp_y.item()) bppzs.update(bpp_z.item()) - if mse_loss_ntc.item() > 0: - psnr_jscc = 10 * (torch.log(255. * 255. / mse_loss_ntscc) / np.log(10)) - psnr_jsccs.update(psnr_jscc.item()) - psnr = 10 * (torch.log(255. * 255. / mse_loss_ntc) / np.log(10)) - psnrs.update(psnr.item()) - else: - psnrs.update(100) - psnr_jsccs.update(100) + + psnr_jscc = 10 * (torch.log(255. * 255. / mse_loss_ntscc) / np.log(10)) + psnr_jsccs.update(psnr_jscc.item()) + psnr = 10 * (torch.log(255. * 255. / mse_loss_ntc) / np.log(10)) + psnrs.update(psnr.item()) if (global_step % config.print_step) == 0: process = (global_step % train_loader.__len__()) / (train_loader.__len__()) * 100.0 @@ -60,9 +57,9 @@ def train_one_epoch(): f'Step [{global_step % train_loader.__len__()}/{train_loader.__len__()}={process:.2f}%]', f'Loss {losses.val:.3f} ({losses.avg:.3f})', f'Time {elapsed.avg:.2f}', - f'PSNR1 {psnr_jsccs.val:.2f} ({psnr_jsccs.avg:.2f})', + f'PSNR_JSCC {psnr_jsccs.val:.2f} ({psnr_jsccs.avg:.2f})', f'CBR {cbrs.val:.4f} ({cbrs.avg:.4f})', - f'PSNR2 {psnrs.val:.2f} ({psnrs.avg:.2f})', + f'PSNR_NTC {psnrs.val:.2f} ({psnrs.avg:.2f})', f'Bpp_y {bppys.val:.2f} ({bppys.avg:.2f})', f'Bpp_z {bppzs.val:.4f} ({bppzs.avg:.4f})', f'Epoch {epoch}', @@ -83,7 +80,6 @@ def test(): start_time = time.time() input_image = input_image.cuda() mse_loss_ntc, bpp_y, bpp_z, mse_loss_ntscc, cbr_y, x_hat_ntc, x_hat_ntscc = net(input_image) - # mse_loss_ntc, bpp_y, bpp_z, x_hat_ntc = net.forward_NTC(input_image) if config.use_side_info: cbr_z = bpp_snr_to_kdivn(bpp_z, 10) ntc_loss = mse_loss_ntc + config.train_lambda * (bpp_y + bpp_z) @@ -114,31 +110,13 @@ def test(): ])) logger.info(log) PSNR_list.append(psnr_jscc) - CBR_list.append(cbr_y.item()) - # channel bandwidth cost of side info \bar{k} - # 4 / (16*16*3) / 2.667 = 0.001953125 - # filename = "/media/Dataset/video_test/VSD_4K/city/city_45s_1/NTSCC_480p/{}_cbr={:.4f}_psnr={:.2f}.png".format((batch_idx+1).__str__().zfill(5), - # cbrs.val.item() + 0.001953125, - # psnrs.val) - # torchvision.utils.save_image(x_hat_ntscc[0], filename) - - logger.info(f'Finish test! Average PSNR={psnr_jsccs.avg:.4f}dB, CBR={cbrs.avg + 0.001953125:.4f}') - + CBR_list.append(cbr_y) -def test_image(net, input_image): - with torch.no_grad(): - net.eval() - input_image = input_image.cuda() - # mse_loss_ntc, bpp_y, bpp_z, mse_loss_ntscc, cbr_y, x_hat_ntc, x_hat_ntscc = net(input_image) - y = net.ga(input_image) - mse_loss_ntscc, cbr_y, bpp_z, x_hat_ntscc = net.forward_NTSCC(input_image, y) - if config.use_side_info: - cbr_z = bpp_snr_to_kdivn(bpp_z, 10) - cbr = cbr_y + cbr_z - else: - cbr = cbr_y - psnr_jscc = CalcuPSNR_int(input_image, x_hat_ntscc).mean() - logger.info(f'Finish test! Average PSNR={psnr_jscc:.4f}dB, CBR={cbr + 0.001953125:.4f}') + # Here, the channel bandwidth cost of side info \bar{k} is transmitted by a capacity-achieving channel code. Note + # that, the side info should be transmitted through entropy coding and channel coding, which will be addressed in + # future releases. + cbr_sideinfo = np.log2(config.multiple_rate.__len__()) / (16*16*3) / np.log2(1 + 10 ** (net.channel.chan_param / 10)) + logger.info(f'Finish test! Average PSNR={psnr_jsccs.avg:.4f}dB, CBR={cbrs.avg + cbr_sideinfo:.4f}') if __name__ == '__main__': diff --git a/net/NTSCC_Hyperior.py b/net/NTSCC_Hyperior.py index a08d52f..2a4771b 100644 --- a/net/NTSCC_Hyperior.py +++ b/net/NTSCC_Hyperior.py @@ -48,7 +48,7 @@ def update_resolution(self, H, W): self.H = H self.W = W - def forward(self, input_image): + def forward(self, input_image, require_probs=False): B, C, H, W = input_image.shape self.update_resolution(H, W) y = self.ga(input_image) @@ -66,7 +66,10 @@ def forward(self, input_image): mse_loss = self.distortion(input_image, x_hat) bpp_y = torch.log(y_likelihoods).sum() / (-math.log(2) * H * W) / B bpp_z = torch.log(z_likelihoods).sum() / (-math.log(2) * H * W) / B - return mse_loss, bpp_y, bpp_z, x_hat + if require_probs: + return mse_loss, bpp_y, bpp_z, x_hat, y, y_likelihoods, scales_hat, means_hat + else: + return mse_loss, bpp_y, bpp_z, x_hat def aux_loss(self): """Return the aggregated loss over the auxiliary entropy bottleneck @@ -86,7 +89,6 @@ def __init__(self, config): self.fe = JSCCEncoder(**config.fe_kwargs) self.fd = JSCCDecoder(**config.fd_kwargs) if config.use_side_info: - # hyperprior-aided decoder refinement embed_dim = config.fe_kwargs['embed_dim'] self.hyprior_refinement = Mlp(embed_dim * 3, embed_dim * 6, embed_dim) self.eta = config.eta @@ -95,13 +97,12 @@ def feature_probs_based_Gaussian(self, feature, mean, sigma): sigma = sigma.clamp(1e-10, 1e10) if sigma.dtype == torch.float32 else sigma.clamp(1e-10, 1e4) gaussian = torch.distributions.normal.Normal(mean, sigma) prob = gaussian.cdf(feature + 0.5) - gaussian.cdf(feature - 0.5) - likelihoods = torch.clamp(prob, 1e-10, 1e10) # BCHW - # likelihoods = -1.0 * torch.log(probs) / math.log(2.0) + likelihoods = torch.clamp(prob, 1e-10, 1e10) # B C H W entropy = torch.clamp_min(-torch.log(likelihoods) / math.log(2), 0) # B H W return likelihoods, entropy - def forward(self, input_image): - B, C, H, W = input_image.shape + def update_resolution(self, H, W): + # Update attention mask for W-MSA and SW-MSA if H != self.H or W != self.W: self.ga.update_resolution(H, W) self.fe.update_resolution(H // 16, W // 16) @@ -110,84 +111,43 @@ def forward(self, input_image): self.H = H self.W = W - # forward NTC - y = self.ga(input_image) - z = self.ha(y) - _, z_likelihoods = self.entropy_bottleneck(z) - z_offset = self.entropy_bottleneck._get_medians() - z_tmp = z - z_offset - z_hat = ste_round(z_tmp) + z_offset - - gaussian_params = self.hs(z_hat) - scales_hat, means_hat = gaussian_params.chunk(2, 1) - y_hat = ste_round(y - means_hat) + means_hat - y_likelihoods, hy = self.feature_probs_based_Gaussian(y, means_hat, scales_hat) - hy = torch.clamp_min(-torch.log(y_likelihoods) / math.log(2), 0) - x_hat_ntc = self.gs(y_hat) - mse_loss_ntc = self.distortion(input_image, x_hat_ntc) - bpp_y = torch.log(y_likelihoods).sum() / (-math.log(2) * H * W) / B - bpp_z = torch.log(z_likelihoods).sum() / (-math.log(2) * H * W) / B - - # forward NTSCC - s_masked, mask_BCHW, indexes = self.fe(y, y_likelihoods.detach(), hy, eta=self.eta) - avg_pwr = torch.sum(s_masked ** 2) / mask_BCHW.sum() - s_hat = self.channel.forward(s_masked, avg_pwr) * mask_BCHW - # indexes - y_hat = self.fd(s_hat, indexes) - - if self.config.use_side_info: - y_combine = torch.cat([BCHW2BLN(y_hat), BCHW2BLN(means_hat), BCHW2BLN(scales_hat)], dim=-1) - y_hat = BLN2BCHW(BCHW2BLN(y_hat) + self.hyprior_refinement(y_combine), H // 16, W // 16) + def forward(self, input_image, **kwargs): + B, C, H, W = input_image.shape + num_pixels = H * W * 3 + self.update_resolution(H, W) - x_hat_ntscc = self.gs(y_hat).clip(0, 1) - mse_loss_ntscc = self.distortion(input_image, x_hat_ntscc) - cbr_y = mask_BCHW.sum() / (B * H * W * 3 * 2) - return mse_loss_ntc, bpp_y, bpp_z, mse_loss_ntscc, cbr_y, x_hat_ntc, x_hat_ntscc + # NTC forward + mse_loss_ntc, bpp_y, bpp_z, x_hat_ntc, y, y_likelihoods, scales_hat, means_hat = \ + self.forward_NTC(input_image, require_probs=True) - def forward_NTC(self, input_image): - return super(NTSCC_Hyperprior, self).forward_NTC(input_image) + # DJSCC forward + s_masked, mask_BCHW, indexes = self.fe(y, y_likelihoods.detach(), eta=self.eta) - def forward_NTSCC(self, input_image): - B, C, H, W = input_image.shape - if H != self.H or W != self.W: - self.ga.update_resolution(H, W) - self.fe.update_resolution(H // 16, W // 16) - self.gs.update_resolution(H // 16, W // 16) - self.fd.update_resolution(H // 16, W // 16) - self.H = H - self.W = W + # Pass through the channel. + mask_BCHW = mask_BCHW.byte() + channel_input = torch.masked_select(s_masked, mask_BCHW) + channel_output, channel_usage = self.channel.forward(channel_input) + s_hat = torch.zeros_like(s_masked) + s_hat[mask_BCHW] = channel_output + cbr_y = channel_usage / num_pixels - # forward NTC - y = self.ga(input_image) - z = self.ha(y) - _, z_likelihoods = self.entropy_bottleneck(z) - z_offset = self.entropy_bottleneck._get_medians() - z_tmp = z - z_offset - z_hat = ste_round(z_tmp) + z_offset - bpp_z = torch.log(z_likelihoods).sum() / (-math.log(2) * H * W) / B + # Another realization of channel. + # avg_pwr = torch.sum(s_masked ** 2) / mask_BCHW.sum() + # s_hat, _ = self.channel.forward(s_masked, avg_pwr) + # s_hat = s_hat * mask_BCHW + # cbr_y = mask_BCHW.sum() / (B * num_pixels * 2) - gaussian_params = self.hs(z_hat) - scales_hat, means_hat = gaussian_params.chunk(2, 1) - y_likelihoods, hy = self.feature_probs_based_Gaussian(y, means_hat, scales_hat) - hy = torch.clamp_min(-torch.log(y_likelihoods) / math.log(2), 0) - s_masked, mask_BCHW, indexes = self.fe(y, y_likelihoods.detach(), hy, eta=self.eta) - avg_pwr = torch.sum(s_masked ** 2) / mask_BCHW.sum() - s_hat = self.channel.forward(s_masked, avg_pwr) * mask_BCHW - y_hat = self.fd(s_hat, indexes, eta=self.eta) + y_hat = self.fd(s_hat, indexes) + # hyperprior-aided decoder refinement (optional) if self.config.use_side_info: y_combine = torch.cat([BCHW2BLN(y_hat), BCHW2BLN(means_hat), BCHW2BLN(scales_hat)], dim=-1) y_hat = BLN2BCHW(BCHW2BLN(y_hat) + self.hyprior_refinement(y_combine), H // 16, W // 16) x_hat_ntscc = self.gs(y_hat).clip(0, 1) mse_loss_ntscc = self.distortion(input_image, x_hat_ntscc) - cbr_y = mask_BCHW.sum() / (B * H * W * 3 * 2) - return mse_loss_ntscc, cbr_y, bpp_z, x_hat_ntscc - def update_resolution(self, H, W): - self.ga.update_resolution(H, W) - self.fe.update_resolution(H // 16, W // 16) - self.gs.update_resolution(H // 16, W // 16) - self.fd.update_resolution(H // 16, W // 16) - self.H = H - self.W = W + return mse_loss_ntc, bpp_y, bpp_z, mse_loss_ntscc, cbr_y, x_hat_ntc, x_hat_ntscc + + def forward_NTC(self, input_image, **kwargs): + return super(NTSCC_Hyperprior, self).forward(input_image, **kwargs)