Skip to content

Commit

Permalink
Optimize stack (#828)
Browse files Browse the repository at this point in the history
* Increase encryption performance

* Release bufs when decrypting fails

* Use correct exceptions

* Improve compression performance

* Rename as requested by @Redned235

* Address review

* var name consistency

---------

Co-authored-by: Konicai <[email protected]>
  • Loading branch information
AlexProgrammerDE and Konicai committed Jul 27, 2024
1 parent ab9bcbe commit bb38c8a
Show file tree
Hide file tree
Showing 7 changed files with 183 additions and 115 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
package org.geysermc.mcprotocollib.network.compression;

import io.netty.buffer.ByteBuf;

public interface PacketCompression {
void inflate(ByteBuf source, ByteBuf destination, int uncompressedSize) throws Exception;

void deflate(ByteBuf source, ByteBuf destination);

void close();
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
package org.geysermc.mcprotocollib.network.compression;

import io.netty.buffer.ByteBuf;

import java.nio.ByteBuffer;
import java.util.zip.DataFormatException;
import java.util.zip.Deflater;
import java.util.zip.Inflater;

public class ZlibCompression implements PacketCompression {
private static final int ZLIB_BUFFER_SIZE = 8192;
private final Deflater deflater;
private final Inflater inflater;

public ZlibCompression() {
this(Deflater.DEFAULT_COMPRESSION);
}

public ZlibCompression(int level) {
this.deflater = new Deflater(level);
this.inflater = new Inflater();
}

@Override
public void inflate(ByteBuf source, ByteBuf destination, int uncompressedSize) throws DataFormatException {
final int originalIndex = source.readerIndex();
inflater.setInput(source.nioBuffer());

try {
while (!inflater.finished() && inflater.getBytesWritten() < uncompressedSize) {
if (!destination.isWritable()) {
destination.ensureWritable(ZLIB_BUFFER_SIZE);
}

ByteBuffer destNioBuf = destination.nioBuffer(destination.writerIndex(),
destination.writableBytes());
int produced = inflater.inflate(destNioBuf);
destination.writerIndex(destination.writerIndex() + produced);
}

if (!inflater.finished()) {
throw new DataFormatException("Received a deflate stream that was too large, wanted " + uncompressedSize);
}

source.readerIndex(originalIndex + inflater.getTotalIn());
} finally {
inflater.reset();
}
}

@Override
public void deflate(ByteBuf source, ByteBuf destination) {
final int originalIndex = source.readerIndex();
deflater.setInput(source.nioBuffer());
deflater.finish();

while (!deflater.finished()) {
if (!destination.isWritable()) {
destination.ensureWritable(ZLIB_BUFFER_SIZE);
}

ByteBuffer destNioBuf = destination.nioBuffer(destination.writerIndex(),
destination.writableBytes());
int produced = deflater.deflate(destNioBuf);
destination.writerIndex(destination.writerIndex() + produced);
}

source.readerIndex(originalIndex + deflater.getTotalIn());
deflater.reset();
}

@Override
public void close() {
deflater.end();
inflater.end();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -21,27 +21,18 @@ public class AESEncryption implements PacketEncryption {
public AESEncryption(Key key) throws GeneralSecurityException {
this.inCipher = Cipher.getInstance("AES/CFB8/NoPadding");
this.inCipher.init(Cipher.DECRYPT_MODE, key, new IvParameterSpec(key.getEncoded()));

this.outCipher = Cipher.getInstance("AES/CFB8/NoPadding");
this.outCipher.init(Cipher.ENCRYPT_MODE, key, new IvParameterSpec(key.getEncoded()));
}

@Override
public int getDecryptOutputSize(int length) {
return this.inCipher.getOutputSize(length);
}

@Override
public int getEncryptOutputSize(int length) {
return this.outCipher.getOutputSize(length);
}

@Override
public int decrypt(byte[] input, int inputOffset, int inputLength, byte[] output, int outputOffset) throws Exception {
return this.inCipher.update(input, inputOffset, inputLength, output, outputOffset);
public void decrypt(byte[] input, int inputOffset, int inputLength, byte[] output, int outputOffset) throws Exception {
this.inCipher.update(input, inputOffset, inputLength, output, outputOffset);
}

@Override
public int encrypt(byte[] input, int inputOffset, int inputLength, byte[] output, int outputOffset) throws Exception {
return this.outCipher.update(input, inputOffset, inputLength, output, outputOffset);
public void encrypt(byte[] input, int inputOffset, int inputLength, byte[] output, int outputOffset) throws Exception {
this.outCipher.update(input, inputOffset, inputLength, output, outputOffset);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -2,47 +2,33 @@

/**
* An interface for encrypting packets.
* The outputLength should always be the same as the inputLength.
* This is because that's what the Minecraft vanilla protocol does.
*/
public interface PacketEncryption {
/**
* Gets the output size from decrypting.
*
* @param length Length of the data being decrypted.
* @return The output size from decrypting.
*/
int getDecryptOutputSize(int length);

/**
* Gets the output size from encrypting.
*
* @param length Length of the data being encrypted.
* @return The output size from encrypting.
*/
int getEncryptOutputSize(int length);

/**
* Decrypts the given data.
* Input and output arrays can be the same.
*
* @param input Input data to decrypt.
* @param inputOffset Offset of the data to start decrypting at.
* @param inputLength Length of the data to be decrypted.
* @param output Array to output decrypted data to.
* @param outputOffset Offset of the output array to start at.
* @return The number of bytes stored in the output array.
* @throws Exception If an error occurs.
*/
int decrypt(byte[] input, int inputOffset, int inputLength, byte[] output, int outputOffset) throws Exception;
void decrypt(byte[] input, int inputOffset, int inputLength, byte[] output, int outputOffset) throws Exception;

/**
* Encrypts the given data.
* Input and output arrays can be the same.
*
* @param input Input data to encrypt.
* @param inputOffset Offset of the data to start encrypting at.
* @param inputLength Length of the data to be encrypted.
* @param output Array to output encrypted data to.
* @param outputOffset Offset of the output array to start at.
* @return The number of bytes stored in the output array.
* @throws Exception If an error occurs.
*/
int encrypt(byte[] input, int inputOffset, int inputLength, byte[] output, int outputOffset) throws Exception;
void encrypt(byte[] input, int inputOffset, int inputLength, byte[] output, int outputOffset) throws Exception;
}
Original file line number Diff line number Diff line change
@@ -1,88 +1,77 @@
package org.geysermc.mcprotocollib.network.tcp;

import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled;
import io.netty.channel.ChannelHandlerContext;
import io.netty.handler.codec.ByteToMessageCodec;
import io.netty.handler.codec.DecoderException;
import io.netty.handler.codec.EncoderException;
import io.netty.handler.codec.MessageToMessageCodec;
import org.geysermc.mcprotocollib.network.Session;
import org.geysermc.mcprotocollib.network.compression.PacketCompression;

import java.util.List;
import java.util.zip.Deflater;
import java.util.zip.Inflater;

public class TcpPacketCompression extends ByteToMessageCodec<ByteBuf> {
private static final int MAX_UNCOMPRESSED_SIZE = 8388608;
public class TcpPacketCompression extends MessageToMessageCodec<ByteBuf, ByteBuf> {
private static final int MAX_UNCOMPRESSED_SIZE = 8 * 1024 * 1024; // 8MiB

private final Session session;
private final Deflater deflater = new Deflater();
private final Inflater inflater = new Inflater();
private final byte[] buf = new byte[8192];
private final PacketCompression compression;
private final boolean validateDecompression;

public TcpPacketCompression(Session session, boolean validateDecompression) {
public TcpPacketCompression(Session session, PacketCompression compression, boolean validateDecompression) {
this.session = session;
this.compression = compression;
this.validateDecompression = validateDecompression;
}

@Override
public void handlerRemoved(ChannelHandlerContext ctx) throws Exception {
super.handlerRemoved(ctx);

this.deflater.end();
this.inflater.end();
public void handlerRemoved(ChannelHandlerContext ctx) {
this.compression.close();
}

@Override
public void encode(ChannelHandlerContext ctx, ByteBuf in, ByteBuf out) {
int readable = in.readableBytes();
if (readable > MAX_UNCOMPRESSED_SIZE) {
throw new EncoderException("Packet too big: size of " + readable + " is larger than the protocol maximum of " + MAX_UNCOMPRESSED_SIZE + ".");
public void encode(ChannelHandlerContext ctx, ByteBuf msg, List<Object> out) {
int uncompressed = msg.readableBytes();
if (uncompressed > MAX_UNCOMPRESSED_SIZE) {
throw new IllegalArgumentException("Packet too big (is " + uncompressed + ", should be less than " + MAX_UNCOMPRESSED_SIZE + ")");
}
if (readable < this.session.getCompressionThreshold()) {
this.session.getCodecHelper().writeVarInt(out, 0);
out.writeBytes(in);
} else {
byte[] bytes = new byte[readable];
in.readBytes(bytes);
this.session.getCodecHelper().writeVarInt(out, bytes.length);
this.deflater.setInput(bytes, 0, readable);
this.deflater.finish();
while (!this.deflater.finished()) {
int length = this.deflater.deflate(this.buf);
out.writeBytes(this.buf, 0, length);
}

this.deflater.reset();
ByteBuf outBuf = ctx.alloc().directBuffer(uncompressed);
if (uncompressed < this.session.getCompressionThreshold()) {
// Under the threshold, there is nothing to do.
this.session.getCodecHelper().writeVarInt(outBuf, 0);
outBuf.writeBytes(msg);
} else {
this.session.getCodecHelper().writeVarInt(outBuf, uncompressed);
compression.deflate(msg, outBuf);
}

out.add(outBuf);
}

@Override
protected void decode(ChannelHandlerContext ctx, ByteBuf buf, List<Object> out) throws Exception {
if (buf.readableBytes() != 0) {
int size = this.session.getCodecHelper().readVarInt(buf);
if (size == 0) {
out.add(buf.readBytes(buf.readableBytes()));
} else {
if (validateDecompression) { // This is sectioned off as of at least Java Edition 1.18
if (size < this.session.getCompressionThreshold()) {
throw new DecoderException("Badly compressed packet: size of " + size + " is below threshold of " + this.session.getCompressionThreshold() + ".");
}
protected void decode(ChannelHandlerContext ctx, ByteBuf in, List<Object> out) {
int claimedUncompressedSize = this.session.getCodecHelper().readVarInt(in);
if (claimedUncompressedSize == 0) {
out.add(in.retain());
return;
}

if (size > MAX_UNCOMPRESSED_SIZE) {
throw new DecoderException("Badly compressed packet: size of " + size + " is larger than protocol maximum of " + MAX_UNCOMPRESSED_SIZE + ".");
}
}
if (validateDecompression) {
if (claimedUncompressedSize < this.session.getCompressionThreshold()) {
throw new DecoderException("Badly compressed packet - size of " + claimedUncompressedSize + " is below server threshold of " + this.session.getCompressionThreshold());
}

byte[] bytes = new byte[buf.readableBytes()];
buf.readBytes(bytes);
this.inflater.setInput(bytes);
byte[] inflated = new byte[size];
this.inflater.inflate(inflated);
out.add(Unpooled.wrappedBuffer(inflated));
this.inflater.reset();
if (claimedUncompressedSize > MAX_UNCOMPRESSED_SIZE) {
throw new DecoderException("Badly compressed packet - size of " + claimedUncompressedSize + " is larger than protocol maximum of " + MAX_UNCOMPRESSED_SIZE);
}
}

ByteBuf uncompressed = ctx.alloc().directBuffer(claimedUncompressedSize);
try {
compression.inflate(in, uncompressed, claimedUncompressedSize);
out.add(uncompressed);
} catch (Exception e) {
uncompressed.release();
throw new DecoderException("Failed to decompress packet", e);
}
}
}
Original file line number Diff line number Diff line change
@@ -1,49 +1,61 @@
package org.geysermc.mcprotocollib.network.tcp;

import io.netty.buffer.ByteBuf;
import io.netty.buffer.ByteBufAllocator;
import io.netty.channel.ChannelHandlerContext;
import io.netty.handler.codec.ByteToMessageCodec;
import io.netty.handler.codec.DecoderException;
import io.netty.handler.codec.EncoderException;
import io.netty.handler.codec.MessageToMessageCodec;
import org.geysermc.mcprotocollib.network.crypt.PacketEncryption;

import java.util.List;

public class TcpPacketEncryptor extends ByteToMessageCodec<ByteBuf> {
public class TcpPacketEncryptor extends MessageToMessageCodec<ByteBuf, ByteBuf> {
private final PacketEncryption encryption;
private byte[] decryptedArray = new byte[0];
private byte[] encryptedArray = new byte[0];

public TcpPacketEncryptor(PacketEncryption encryption) {
this.encryption = encryption;
}

@Override
public void encode(ChannelHandlerContext ctx, ByteBuf in, ByteBuf out) throws Exception {
int length = in.readableBytes();
byte[] bytes = this.getBytes(in);
int outLength = this.encryption.getEncryptOutputSize(length);
if (this.encryptedArray.length < outLength) {
this.encryptedArray = new byte[outLength];
public void encode(ChannelHandlerContext ctx, ByteBuf msg, List<Object> out) {
ByteBuf heapBuf = this.ensureHeapBuffer(ctx.alloc(), msg);

int inBytes = heapBuf.readableBytes();
int baseOffset = heapBuf.arrayOffset() + heapBuf.readerIndex();

try {
encryption.encrypt(heapBuf.array(), baseOffset, inBytes, heapBuf.array(), baseOffset);
out.add(heapBuf);
} catch (Exception e) {
heapBuf.release();
throw new EncoderException("Error encrypting packet", e);
}

out.writeBytes(this.encryptedArray, 0, this.encryption.encrypt(bytes, 0, length, this.encryptedArray, 0));
}

@Override
protected void decode(ChannelHandlerContext ctx, ByteBuf buf, List<Object> out) throws Exception {
int length = buf.readableBytes();
byte[] bytes = this.getBytes(buf);
ByteBuf result = ctx.alloc().heapBuffer(this.encryption.getDecryptOutputSize(length));
result.writerIndex(this.encryption.decrypt(bytes, 0, length, result.array(), result.arrayOffset()));
out.add(result);
protected void decode(ChannelHandlerContext ctx, ByteBuf in, List<Object> out) {
ByteBuf heapBuf = this.ensureHeapBuffer(ctx.alloc(), in).slice();

int inBytes = heapBuf.readableBytes();
int baseOffset = heapBuf.arrayOffset() + heapBuf.readerIndex();

try {
encryption.decrypt(heapBuf.array(), baseOffset, inBytes, heapBuf.array(), baseOffset);
out.add(heapBuf);
} catch (Exception e) {
heapBuf.release();
throw new DecoderException("Error decrypting packet", e);
}
}

private byte[] getBytes(ByteBuf buf) {
int length = buf.readableBytes();
if (this.decryptedArray.length < length) {
this.decryptedArray = new byte[length];
private ByteBuf ensureHeapBuffer(ByteBufAllocator alloc, ByteBuf buf) {
if (buf.hasArray()) {
return buf.retain();
} else {
ByteBuf heapBuf = alloc.heapBuffer(buf.readableBytes());
heapBuf.writeBytes(buf);
return heapBuf;
}

buf.readBytes(this.decryptedArray, 0, length);
return this.decryptedArray;
}
}
Loading

0 comments on commit bb38c8a

Please sign in to comment.