
Commit

add more descriptions
wsxtyrdd committed Aug 9, 2022
1 parent 293d18c commit e519f70
Showing 5 changed files with 89 additions and 120 deletions.
13 changes: 12 additions & 1 deletion README.md
@@ -6,7 +6,18 @@ Arxiv Link: https://arxiv.org/abs/2112.10961

Project Page: https://semcomm.github.io/ntscc/

## Evaluation
## Usage

## Pretrained Models

Pretrained models (optimized for MSE) are trained from scratch on 300k images randomly chosen from the OpenImages dataset.

Other pretrained models will be released in subsequent updates.

Note: The code has been reorganized, so the performance differs slightly from that reported in the paper.

### RD curves


## Citation
If you find the code helpful in your research or work, please cite:
13 changes: 9 additions & 4 deletions channel/channel.py
@@ -22,20 +22,25 @@ def gaussian_noise_layer(self, input_layer, std):
noise = noise_real + 1j * noise_imag
return input_layer + noise

def forward(self, input, avg_pwr, power=1):
channel_tx = np.sqrt(power) * input / torch.sqrt(avg_pwr * 2)
def forward(self, input, avg_pwr=None, power=1):
if avg_pwr is None:
avg_pwr = torch.mean(input ** 2)
channel_tx = np.sqrt(power) * input / torch.sqrt(avg_pwr * 2)
else:
channel_tx = np.sqrt(power) * input / torch.sqrt(avg_pwr * 2)
input_shape = channel_tx.shape
channel_in = channel_tx.reshape(-1)
channel_in = channel_in[::2] + channel_in[1::2] * 1j
channel_usage = channel_in.numel()
channel_output = self.channel_forward(channel_in)
channel_rx = torch.zeros_like(channel_tx.reshape(-1))
channel_rx[::2] = torch.real(channel_output)
channel_rx[1::2] = torch.imag(channel_output)
channel_rx = channel_rx.reshape(input_shape)
return channel_rx * torch.sqrt(avg_pwr * 2)
return channel_rx * torch.sqrt(avg_pwr * 2), channel_usage

def channel_forward(self, channel_in):
if self.chan_type == 0 or self.chan_type == 'none':
if self.chan_type == 0 or self.chan_type == 'noiseless':
return channel_in

elif self.chan_type == 1 or self.chan_type == 'awgn':
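For reference, a minimal standalone sketch of the power normalization and real-to-complex packing performed by the updated `Channel.forward` above. The helper name `pack_and_normalize` and the toy input are illustrative assumptions, and only the default branch (where `avg_pwr` is estimated from the input) is mirrored here:

```python
import torch

def pack_and_normalize(x, power=1.0):
    # Mirror of Channel.forward's default branch: estimate the per-dimension
    # power, scale so each complex symbol has average power `power`, then pack
    # interleaved (real, imag) pairs into complex channel symbols.
    avg_pwr = torch.mean(x ** 2)
    tx = (power ** 0.5) * x / torch.sqrt(avg_pwr * 2)
    flat = tx.reshape(-1)
    symbols = flat[0::2] + 1j * flat[1::2]
    return symbols, symbols.numel()  # numel() is the channel usage in complex symbols

x = torch.randn(2, 4, 8, 8)                     # 512 real values -> 256 complex symbols
symbols, usage = pack_and_normalize(x)
print(usage, torch.mean(symbols.abs() ** 2))    # average symbol power close to 1.0
```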
23 changes: 19 additions & 4 deletions layer/jscc_encoder.py
@@ -37,7 +37,7 @@ def forward(self, x, indexes):
x_BLC_masked = (torch.matmul(x_BLC.unsqueeze(2), w).squeeze() + b) * mask
x_masked = x_BLC_masked.reshape(B, H, W, -1).permute(0, 3, 1, 2)
mask_BCHW = mask.reshape(B, H, W, -1).permute(0, 3, 1, 2)
return x_masked, mask_BCHW, indexes
return x_masked, mask_BCHW

def update_resolution(self, H, W, device):
self.H = H
@@ -70,8 +70,23 @@ def __init__(self, embed_dim=256, depths=[1, 1, 1], input_resolution=(16, 16),
self.refine = Mlp(embed_dim * 2, embed_dim * 8, embed_dim)
self.norm = norm_layer(embed_dim)

def forward(self, x, px, hx, eta):
def forward(self, x, px, eta):
"""
JSCCEncoder encodes the latent representation into a variable-length channel-input vector.
Arguments:
x: Latent representation (patch embeddings), shape BxCxHxW, also viewed as Bx(HxW)xC.
px: Estimated probability of x, shape BxCxHxW, also viewed as Bx(HxW)xC.
eta: Scaling factor from entropy to channel bandwidth cost.
Returns:
s_masked: Masked channel-input vector, shape BxCxHxW.
mask: Binary mask, shape BxCxHxW.
indexes: The channel-input length assigned to each patch embedding, shape BxHxW.
"""

B, C, H, W = x.size()
hx = torch.clamp_min(-torch.log(px) / math.log(2), 0)
symbol_num = torch.sum(hx, dim=1).flatten(0) * eta
x_BLC = x.flatten(2).permute(0, 2, 1)
px_BLC = px.flatten(2).permute(0, 2, 1)
@@ -84,8 +99,8 @@ def forward(self, x, px, hx, eta):
x_BLC = layer(x_BLC.contiguous())
x_BLC = self.norm(x_BLC)
x_BCHW = x_BLC.reshape(B, H, W, C).permute(0, 3, 1, 2)
x_masked, mask, indexes = self.rate_adaption(x_BCHW, indexes)
return x_masked, mask, indexes
s_masked, mask = self.rate_adaption(x_BCHW, indexes)
return s_masked, mask, indexes

def update_resolution(self, H, W):
self.input_resolution = (H, W)
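The docstring above introduces `eta` as the scaling factor from entropy to channel bandwidth cost. The toy snippet below illustrates how the per-patch symbol budget is derived from the likelihoods `px` inside the new `forward`; the subsequent quantization of `symbol_num` into the discrete rate `indexes` is collapsed in this diff, so it is not reproduced here and the shapes are illustrative only:

```python
import math
import torch

# Toy shapes: one image whose latent is a 2x2 grid of patches with C channels each.
B, C, H, W = 1, 4, 2, 2
px = torch.rand(B, C, H, W).clamp_(1e-3, 1.0)  # stand-in for y_likelihoods
eta = 0.2                                      # entropy -> bandwidth scaling factor

hx = torch.clamp_min(-torch.log(px) / math.log(2), 0)  # per-element entropy in bits
symbol_num = torch.sum(hx, dim=1).flatten(0) * eta     # one symbol budget per patch
print(symbol_num.shape)                                # torch.Size([4]) == B*H*W
```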
50 changes: 14 additions & 36 deletions main.py
@@ -1,6 +1,6 @@
import torch
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "5"
os.environ["CUDA_VISIBLE_DEVICES"] = "7"
from net.NTSCC_Hyperior import NTSCC_Hyperprior
import torch.optim as optim
from utils import *
@@ -45,24 +45,21 @@ def train_one_epoch():
losses.update(loss.item())
bppys.update(bpp_y.item())
bppzs.update(bpp_z.item())
if mse_loss_ntc.item() > 0:
psnr_jscc = 10 * (torch.log(255. * 255. / mse_loss_ntscc) / np.log(10))
psnr_jsccs.update(psnr_jscc.item())
psnr = 10 * (torch.log(255. * 255. / mse_loss_ntc) / np.log(10))
psnrs.update(psnr.item())
else:
psnrs.update(100)
psnr_jsccs.update(100)

psnr_jscc = 10 * (torch.log(255. * 255. / mse_loss_ntscc) / np.log(10))
psnr_jsccs.update(psnr_jscc.item())
psnr = 10 * (torch.log(255. * 255. / mse_loss_ntc) / np.log(10))
psnrs.update(psnr.item())

if (global_step % config.print_step) == 0:
process = (global_step % train_loader.__len__()) / (train_loader.__len__()) * 100.0
log = (' | '.join([
f'Step [{global_step % train_loader.__len__()}/{train_loader.__len__()}={process:.2f}%]',
f'Loss {losses.val:.3f} ({losses.avg:.3f})',
f'Time {elapsed.avg:.2f}',
f'PSNR1 {psnr_jsccs.val:.2f} ({psnr_jsccs.avg:.2f})',
f'PSNR_JSCC {psnr_jsccs.val:.2f} ({psnr_jsccs.avg:.2f})',
f'CBR {cbrs.val:.4f} ({cbrs.avg:.4f})',
f'PSNR2 {psnrs.val:.2f} ({psnrs.avg:.2f})',
f'PSNR_NTC {psnrs.val:.2f} ({psnrs.avg:.2f})',
f'Bpp_y {bppys.val:.2f} ({bppys.avg:.2f})',
f'Bpp_z {bppzs.val:.4f} ({bppzs.avg:.4f})',
f'Epoch {epoch}',
@@ -83,7 +80,6 @@ def test():
start_time = time.time()
input_image = input_image.cuda()
mse_loss_ntc, bpp_y, bpp_z, mse_loss_ntscc, cbr_y, x_hat_ntc, x_hat_ntscc = net(input_image)
# mse_loss_ntc, bpp_y, bpp_z, x_hat_ntc = net.forward_NTC(input_image)
if config.use_side_info:
cbr_z = bpp_snr_to_kdivn(bpp_z, 10)
ntc_loss = mse_loss_ntc + config.train_lambda * (bpp_y + bpp_z)
@@ -114,31 +110,13 @@ def test():
]))
logger.info(log)
PSNR_list.append(psnr_jscc)
CBR_list.append(cbr_y.item())
# channel bandwidth cost of side info \bar{k}
# 4 / (16*16*3) / 2.667 = 0.001953125
# filename = "/media/Dataset/video_test/VSD_4K/city/city_45s_1/NTSCC_480p/{}_cbr={:.4f}_psnr={:.2f}.png".format((batch_idx+1).__str__().zfill(5),
# cbrs.val.item() + 0.001953125,
# psnrs.val)
# torchvision.utils.save_image(x_hat_ntscc[0], filename)

logger.info(f'Finish test! Average PSNR={psnr_jsccs.avg:.4f}dB, CBR={cbrs.avg + 0.001953125:.4f}')

CBR_list.append(cbr_y)

def test_image(net, input_image):
with torch.no_grad():
net.eval()
input_image = input_image.cuda()
# mse_loss_ntc, bpp_y, bpp_z, mse_loss_ntscc, cbr_y, x_hat_ntc, x_hat_ntscc = net(input_image)
y = net.ga(input_image)
mse_loss_ntscc, cbr_y, bpp_z, x_hat_ntscc = net.forward_NTSCC(input_image, y)
if config.use_side_info:
cbr_z = bpp_snr_to_kdivn(bpp_z, 10)
cbr = cbr_y + cbr_z
else:
cbr = cbr_y
psnr_jscc = CalcuPSNR_int(input_image, x_hat_ntscc).mean()
logger.info(f'Finish test! Average PSNR={psnr_jscc:.4f}dB, CBR={cbr + 0.001953125:.4f}')
# Here, the side info \bar{k} is assumed to be transmitted over a capacity-achieving channel code, and its channel
# bandwidth cost is added to the reported CBR. In practice, the side info should be entropy coded and channel coded,
# which will be addressed in future releases.
cbr_sideinfo = np.log2(config.multiple_rate.__len__()) / (16*16*3) / np.log2(1 + 10 ** (net.channel.chan_param / 10))
logger.info(f'Finish test! Average PSNR={psnr_jsccs.avg:.4f}dB, CBR={cbrs.avg + cbr_sideinfo:.4f}')


if __name__ == '__main__':
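The new `cbr_sideinfo` term above replaces the hard-coded 0.001953125 constant of the old code: each 16x16 patch spends log2(len(config.multiple_rate)) bits to signal its selected rate index, normalized by the 16*16*3 source dimensions of the patch and by the capacity of the assumed capacity-achieving code. A worked instance, where the 16 candidate rates and the 10 dB SNR are assumptions for illustration only:

```python
import numpy as np

# Worked instance of the cbr_sideinfo formula in test() above.
num_rates = 16          # stand-in for len(config.multiple_rate)
snr_db = 10.0           # stand-in for net.channel.chan_param

bits_per_patch = np.log2(num_rates)            # 4 bits to signal the rate index
capacity = np.log2(1 + 10 ** (snr_db / 10))    # ~3.46 bits per complex channel use
cbr_sideinfo = bits_per_patch / (16 * 16 * 3) / capacity
print(f"{cbr_sideinfo:.6f}")                   # ~0.001506 channel uses per source dimension
```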
110 changes: 35 additions & 75 deletions net/NTSCC_Hyperior.py
@@ -48,7 +48,7 @@ def update_resolution(self, H, W):
self.H = H
self.W = W

def forward(self, input_image):
def forward(self, input_image, require_probs=False):
B, C, H, W = input_image.shape
self.update_resolution(H, W)
y = self.ga(input_image)
@@ -66,7 +66,10 @@ def forward(self, input_image):
mse_loss = self.distortion(input_image, x_hat)
bpp_y = torch.log(y_likelihoods).sum() / (-math.log(2) * H * W) / B
bpp_z = torch.log(z_likelihoods).sum() / (-math.log(2) * H * W) / B
return mse_loss, bpp_y, bpp_z, x_hat
if require_probs:
return mse_loss, bpp_y, bpp_z, x_hat, y, y_likelihoods, scales_hat, means_hat
else:
return mse_loss, bpp_y, bpp_z, x_hat

def aux_loss(self):
"""Return the aggregated loss over the auxiliary entropy bottleneck
@@ -86,7 +89,6 @@ def __init__(self, config):
self.fe = JSCCEncoder(**config.fe_kwargs)
self.fd = JSCCDecoder(**config.fd_kwargs)
if config.use_side_info:
# hyperprior-aided decoder refinement
embed_dim = config.fe_kwargs['embed_dim']
self.hyprior_refinement = Mlp(embed_dim * 3, embed_dim * 6, embed_dim)
self.eta = config.eta
@@ -95,13 +97,12 @@ def feature_probs_based_Gaussian(self, feature, mean, sigma):
sigma = sigma.clamp(1e-10, 1e10) if sigma.dtype == torch.float32 else sigma.clamp(1e-10, 1e4)
gaussian = torch.distributions.normal.Normal(mean, sigma)
prob = gaussian.cdf(feature + 0.5) - gaussian.cdf(feature - 0.5)
likelihoods = torch.clamp(prob, 1e-10, 1e10) # BCHW
# likelihoods = -1.0 * torch.log(probs) / math.log(2.0)
likelihoods = torch.clamp(prob, 1e-10, 1e10) # B C H W
entropy = torch.clamp_min(-torch.log(likelihoods) / math.log(2), 0) # B C H W
return likelihoods, entropy

def forward(self, input_image):
B, C, H, W = input_image.shape
def update_resolution(self, H, W):
# Update attention mask for W-MSA and SW-MSA
if H != self.H or W != self.W:
self.ga.update_resolution(H, W)
self.fe.update_resolution(H // 16, W // 16)
@@ -110,84 +111,43 @@ def forward(self, input_image):
self.H = H
self.W = W

# forward NTC
y = self.ga(input_image)
z = self.ha(y)
_, z_likelihoods = self.entropy_bottleneck(z)
z_offset = self.entropy_bottleneck._get_medians()
z_tmp = z - z_offset
z_hat = ste_round(z_tmp) + z_offset

gaussian_params = self.hs(z_hat)
scales_hat, means_hat = gaussian_params.chunk(2, 1)
y_hat = ste_round(y - means_hat) + means_hat
y_likelihoods, hy = self.feature_probs_based_Gaussian(y, means_hat, scales_hat)
hy = torch.clamp_min(-torch.log(y_likelihoods) / math.log(2), 0)
x_hat_ntc = self.gs(y_hat)
mse_loss_ntc = self.distortion(input_image, x_hat_ntc)
bpp_y = torch.log(y_likelihoods).sum() / (-math.log(2) * H * W) / B
bpp_z = torch.log(z_likelihoods).sum() / (-math.log(2) * H * W) / B

# forward NTSCC
s_masked, mask_BCHW, indexes = self.fe(y, y_likelihoods.detach(), hy, eta=self.eta)
avg_pwr = torch.sum(s_masked ** 2) / mask_BCHW.sum()
s_hat = self.channel.forward(s_masked, avg_pwr) * mask_BCHW
# indexes
y_hat = self.fd(s_hat, indexes)

if self.config.use_side_info:
y_combine = torch.cat([BCHW2BLN(y_hat), BCHW2BLN(means_hat), BCHW2BLN(scales_hat)], dim=-1)
y_hat = BLN2BCHW(BCHW2BLN(y_hat) + self.hyprior_refinement(y_combine), H // 16, W // 16)
def forward(self, input_image, **kwargs):
B, C, H, W = input_image.shape
num_pixels = H * W * 3
self.update_resolution(H, W)

x_hat_ntscc = self.gs(y_hat).clip(0, 1)
mse_loss_ntscc = self.distortion(input_image, x_hat_ntscc)
cbr_y = mask_BCHW.sum() / (B * H * W * 3 * 2)
return mse_loss_ntc, bpp_y, bpp_z, mse_loss_ntscc, cbr_y, x_hat_ntc, x_hat_ntscc
# NTC forward
mse_loss_ntc, bpp_y, bpp_z, x_hat_ntc, y, y_likelihoods, scales_hat, means_hat = \
self.forward_NTC(input_image, require_probs=True)

def forward_NTC(self, input_image):
return super(NTSCC_Hyperprior, self).forward_NTC(input_image)
# DJSCC forward
s_masked, mask_BCHW, indexes = self.fe(y, y_likelihoods.detach(), eta=self.eta)

def forward_NTSCC(self, input_image):
B, C, H, W = input_image.shape
if H != self.H or W != self.W:
self.ga.update_resolution(H, W)
self.fe.update_resolution(H // 16, W // 16)
self.gs.update_resolution(H // 16, W // 16)
self.fd.update_resolution(H // 16, W // 16)
self.H = H
self.W = W
# Pass through the channel.
mask_BCHW = mask_BCHW.byte()
channel_input = torch.masked_select(s_masked, mask_BCHW)
channel_output, channel_usage = self.channel.forward(channel_input)
s_hat = torch.zeros_like(s_masked)
s_hat[mask_BCHW] = channel_output
cbr_y = channel_usage / num_pixels

# forward NTC
y = self.ga(input_image)
z = self.ha(y)
_, z_likelihoods = self.entropy_bottleneck(z)
z_offset = self.entropy_bottleneck._get_medians()
z_tmp = z - z_offset
z_hat = ste_round(z_tmp) + z_offset
bpp_z = torch.log(z_likelihoods).sum() / (-math.log(2) * H * W) / B
# An alternative realization of the channel:
# avg_pwr = torch.sum(s_masked ** 2) / mask_BCHW.sum()
# s_hat, _ = self.channel.forward(s_masked, avg_pwr)
# s_hat = s_hat * mask_BCHW
# cbr_y = mask_BCHW.sum() / (B * num_pixels * 2)

gaussian_params = self.hs(z_hat)
scales_hat, means_hat = gaussian_params.chunk(2, 1)
y_likelihoods, hy = self.feature_probs_based_Gaussian(y, means_hat, scales_hat)
hy = torch.clamp_min(-torch.log(y_likelihoods) / math.log(2), 0)
s_masked, mask_BCHW, indexes = self.fe(y, y_likelihoods.detach(), hy, eta=self.eta)
avg_pwr = torch.sum(s_masked ** 2) / mask_BCHW.sum()
s_hat = self.channel.forward(s_masked, avg_pwr) * mask_BCHW
y_hat = self.fd(s_hat, indexes, eta=self.eta)

y_hat = self.fd(s_hat, indexes)
# hyperprior-aided decoder refinement (optional)
if self.config.use_side_info:
y_combine = torch.cat([BCHW2BLN(y_hat), BCHW2BLN(means_hat), BCHW2BLN(scales_hat)], dim=-1)
y_hat = BLN2BCHW(BCHW2BLN(y_hat) + self.hyprior_refinement(y_combine), H // 16, W // 16)

x_hat_ntscc = self.gs(y_hat).clip(0, 1)
mse_loss_ntscc = self.distortion(input_image, x_hat_ntscc)
cbr_y = mask_BCHW.sum() / (B * H * W * 3 * 2)
return mse_loss_ntscc, cbr_y, bpp_z, x_hat_ntscc

def update_resolution(self, H, W):
self.ga.update_resolution(H, W)
self.fe.update_resolution(H // 16, W // 16)
self.gs.update_resolution(H // 16, W // 16)
self.fd.update_resolution(H // 16, W // 16)
self.H = H
self.W = W
return mse_loss_ntc, bpp_y, bpp_z, mse_loss_ntscc, cbr_y, x_hat_ntc, x_hat_ntscc

def forward_NTC(self, input_image, **kwargs):
return super(NTSCC_Hyperprior, self).forward(input_image, **kwargs)
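The refactored `forward` above transmits only the unmasked symbols via `torch.masked_select` and derives the channel bandwidth ratio from the returned `channel_usage`, while the commented-out alternative multiplies the full tensor by the mask instead. A minimal sketch of the masked-transmission pattern; the identity "channel", helper name, and toy shapes are assumptions (the real model uses `self.channel.forward`, which counts one complex symbol per pair of reals):

```python
import torch

def transmit_masked(s_masked, mask_BCHW, num_pixels):
    # Send only the symbols selected by the mask, then scatter the received
    # symbols back into a zero tensor of the original shape.
    mask = mask_BCHW.bool()
    channel_input = torch.masked_select(s_masked, mask)
    channel_output = channel_input             # identity channel stands in for self.channel.forward
    channel_usage = channel_input.numel() / 2  # two reals per complex symbol, as in Channel.forward
    s_hat = torch.zeros_like(s_masked)
    s_hat[mask] = channel_output
    return s_hat, channel_usage / num_pixels   # cbr_y

B, C, H, W = 1, 8, 4, 4                        # latent grid; image is 16x larger per side
s = torch.randn(B, C, H, W)
mask = torch.rand(B, C, H, W) > 0.5
s_hat, cbr_y = transmit_masked(s * mask, mask, num_pixels=(H * 16) * (W * 16) * 3)
print(cbr_y)
```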
