49

Netty源码分析 (十一)----- 拆包器之LengthFieldBasedFrameDecoder

 4 years ago
source link: https://www.tuicool.com/articles/viy2Efz
Go to the source link to view the article. You can view the picture content, updated content and better typesetting reading experience. If the link is broken, please click the button below to view the snapshot at that time.
neoserver,ios ssh client

本篇文章主要是介绍使用LengthFieldBasedFrameDecoder解码器自定义协议。通常,协议的格式如下:

nARVR3V.png!web

LengthFieldBasedFrameDecoder是netty解决拆包粘包问题的一个重要的类,主要结构就是header+body结构。我们只需要传入正确的参数就可以发送和接收正确的数据,那么重点就在于这几个参数的意义。下面我们就具体了解一下这几个参数的意义。先来看一下LengthFieldBasedFrameDecoder主要的构造方法:

public LengthFieldBasedFrameDecoder(
            int maxFrameLength,
            int lengthFieldOffset, int lengthFieldLength,
            int lengthAdjustment, int initialBytesToStrip)

那么这几个重要的参数如下:

  • maxFrameLength:最大帧长度。也就是可以接收的数据的最大长度。如果超过,此次数据会被丢弃。
  • lengthFieldOffset:长度域偏移。就是说数据开始的几个字节可能不是表示数据长度,需要后移几个字节才是长度域。
  • lengthFieldLength:长度域字节数。用几个字节来表示数据长度。
  • lengthAdjustment:数据长度修正。因为长度域指定的长度可以使header+body的整个长度,也可以只是body的长度。如果表示header+body的整个长度,那么我们需要修正数据长度。
  • initialBytesToStrip:跳过的字节数。如果你需要接收header+body的所有数据,此值就是0,如果你只想接收body数据,那么需要跳过header所占用的字节数。

下面我们根据几个例子的使用来具体说明这几个参数的使用。

LengthFieldBasedFrameDecoder 的用法

需求1

长度域为2个字节,我们要求发送和接收的数据如下所示:

     发送的数据 (14 bytes)          接收到数据 (14 bytes)
+--------+----------------+      +--------+----------------+
| Length | Actual Content |----->| Length | Actual Content |
|  12    | "HELLO, WORLD" |      |   12   | "HELLO, WORLD" |
+--------+----------------+      +--------+----------------+

留心的你肯定发现了,长度域只是实际内容的长度,不包括长度域的长度。下面是参数的值:

  • lengthFieldOffset=0:开始的2个字节就是长度域,所以不需要长度域偏移。
  • lengthFieldLength=2:长度域2个字节。
  • lengthAdjustment=0:数据长度修正为0,因为长度域只包含数据的长度,所以不需要修正。
  • initialBytesToStrip=0:发送和接收的数据完全一致,所以不需要跳过任何字节。

需求2

长度域为2个字节,我们要求发送和接收的数据如下所示:

   发送的数据 (14 bytes)        接收到数据 (12 bytes)
+--------+----------------+      +----------------+
| Length | Actual Content |----->| Actual Content |
|  12    | "HELLO, WORLD" |      | "HELLO, WORLD" |
+--------+----------------+      +----------------+

参数值如下:

  • lengthFieldOffset=0:开始的2个字节就是长度域,所以不需要长度域偏移。
  • lengthFieldLength=2:长度域2个字节。
  • lengthAdjustment=0:数据长度修正为0,因为长度域只包含数据的长度,所以不需要修正。
  • initialBytesToStrip=2:我们发现接收的数据没有长度域的数据,所以要跳过长度域的2个字节。

需求3

长度域为2个字节,我们要求发送和接收的数据如下所示:

 BEFORE DECODE (14 bytes)         AFTER DECODE (14 bytes)
+--------+----------------+      +--------+----------------+
| Length | Actual Content |----->| Length | Actual Content |
| 14     | "HELLO, WORLD" |      |  14    | "HELLO, WORLD" |
+--------+----------------+      +--------+----------------+  

留心的你肯定又发现了,长度域表示的长度是总长度 也就是header+body的总长度。参数如下:

  • lengthFieldOffset=0:开始的2个字节就是长度域,所以不需要长度域偏移。
  • lengthFieldLength=2:长度域2个字节。
  • lengthAdjustment=-2:因为长度域为总长度,所以我们需要修正数据长度,也就是减去2。
  • initialBytesToStrip=0:我们发现接收的数据没有长度域的数据,所以要跳过长度域的2个字节。

需求4

长度域为2个字节,我们要求发送和接收的数据如下所示:

   BEFORE DECODE (17 bytes)                      AFTER DECODE (17 bytes)
+----------+----------+----------------+      +----------+----------+----------------+
| meta     |  Length  | Actual Content |----->| meta | Length | Actual Content |
|  0xCAFE  | 12       | "HELLO, WORLD" |      |  0xCAFE  | 12       | "HELLO, WORLD" |
+----------+----------+----------------+      +----------+----------+----------------+

我们发现,数据的结构有点变化,变成了 meta+header+body的结构。meta一般表示元数据,魔数等。我们定义这里meta有三个字节。参数如下:

  • lengthFieldOffset=3:开始的3个字节是meta,然后才是长度域,所以长度域偏移为3。
  • lengthFieldLength=2:长度域2个字节。
  • lengthAdjustment=0:长度域指定的长度位数据长度,所以数据长度不需要修正。
  • initialBytesToStrip=0:发送和接收数据相同,不需要跳过数据。

需求5

长度域为2个字节,我们要求发送和接收的数据如下所示:

    BEFORE DECODE (17 bytes)                      AFTER DECODE (17 bytes)
+----------+----------+----------------+      +----------+----------+----------------+
|  Length  | meta     | Actual Content |----->| Length | meta | Actual Content |
|   12     |  0xCAFE  | "HELLO, WORLD" |      |    12    |  0xCAFE  | "HELLO, WORLD" |
+----------+----------+----------------+      +----------+----------+----------------+

我们发现,数据的结构有点变化,变成了 header+meta+body的结构。meta一般表示元数据,魔数等。我们定义这里meta有三个字节。参数如下:

  • lengthFieldOffset=0:开始的2个字节就是长度域,所以不需要长度域偏移。
  • lengthFieldLength=2:长度域2个字节。
  • lengthAdjustment=3:我们需要把meta+body当做body处理,所以数据长度需要加3。
  • initialBytesToStrip=0:发送和接收数据相同,不需要跳过数据。

需求6

长度域为2个字节,我们要求发送和接收的数据如下所示:

    BEFORE DECODE (16 bytes)                    AFTER DECODE (13 bytes)
+------+--------+------+----------------+      +------+----------------+
| HDR1 | Length | HDR2 | Actual Content |----->| HDR2 | Actual Content |
| 0xCA | 0x000C | 0xFE | "HELLO, WORLD" |      | 0xFE | "HELLO, WORLD" |
+------+--------+------+----------------+      +------+----------------+

我们发现,数据的结构有点变化,变成了 hdr1+header+hdr2+body的结构。我们定义这里hdr1和hdr2都只有1个字节。参数如下:

  • lengthFieldOffset=1:开始的1个字节是长度域,所以需要设置长度域偏移为1。
  • lengthFieldLength=2:长度域2个字节。
  • lengthAdjustment=1:我们需要把hdr2+body当做body处理,所以数据长度需要加1。
  • initialBytesToStrip=3:接收数据不包括hdr1和长度域相同,所以需要跳过3个字节。

LengthFieldBasedFrameDecoder 源码剖析

实现拆包抽象

在前面的文章中我们知道,具体的拆包协议只需要实现

void decode(ChannelHandlerContext ctx, ByteBuf in, List<Object> out) 

其中 in 表示目前为止还未拆的数据,拆完之后的包添加到 out这个list中即可实现包向下传递,第一层实现比较简单

@Override
protected final void decode(ChannelHandlerContext ctx, ByteBuf in, List<Object> out) throws Exception {
    Object decoded = decode(ctx, in);
    if (decoded != null) {
        out.add(decoded);
    }
}

重载的protected函数decode做真正的拆包动作

protected Object decode(ChannelHandlerContext ctx, ByteBuf in) throws Exception {
    if (this.discardingTooLongFrame) {
        long bytesToDiscard = this.bytesToDiscard;
        int localBytesToDiscard = (int)Math.min(bytesToDiscard, (long)in.readableBytes());
        in.skipBytes(localBytesToDiscard);
        bytesToDiscard -= (long)localBytesToDiscard;
        this.bytesToDiscard = bytesToDiscard;
        this.failIfNecessary(false);
    }

    // 如果当前可读字节还未达到长度长度域的偏移,那说明肯定是读不到长度域的,直接不读
    if (in.readableBytes() < this.lengthFieldEndOffset) {
        return null;
    } else {
        // 拿到长度域的实际字节偏移,就是长度域的开始下标
        // 这里就是需求4,开始的几个字节并不是长度域
        int actualLengthFieldOffset = in.readerIndex() + this.lengthFieldOffset;
        // 拿到实际的未调整过的包长度
        // 就是读取长度域的十进制值,最原始传过来的包的长度
        long frameLength = this.getUnadjustedFrameLength(in, actualLengthFieldOffset, this.lengthFieldLength, this.byteOrder);
        // 如果拿到的长度为负数,直接跳过长度域并抛出异常
        if (frameLength < 0L) {
            in.skipBytes(this.lengthFieldEndOffset);
            throw new CorruptedFrameException("negative pre-adjustment length field: " + frameLength);
        } else {
            // 调整包的长度
            frameLength += (long)(this.lengthAdjustment + this.lengthFieldEndOffset);
            // 整个数据包的长度还没有长度域长,直接抛出异常
            if (frameLength < (long)this.lengthFieldEndOffset) {
                in.skipBytes(this.lengthFieldEndOffset);
                throw new CorruptedFrameException("Adjusted frame length (" + frameLength + ") is less " + "than lengthFieldEndOffset: " + this.lengthFieldEndOffset);
            // 数据包长度超出最大包长度,进入丢弃模式
            } else if (frameLength > (long)this.maxFrameLength) {
                long discard = frameLength - (long)in.readableBytes();
                this.tooLongFrameLength = frameLength;
                if (discard < 0L) {
                    in.skipBytes((int)frameLength);
                } else {
                    this.discardingTooLongFrame = true;
                    this.bytesToDiscard = discard;
                    in.skipBytes(in.readableBytes());
                }

                this.failIfNecessary(true);
                return null;
            } else {
                int frameLengthInt = (int)frameLength;
                //当前可读的字节数小于包中的length,什么都不做,等待下一次解码
                if (in.readableBytes() < frameLengthInt) {
                    return null;
                //跳过的字节不能大于数据包的长度,否则就抛出 CorruptedFrameException 的异常
                } else if (this.initialBytesToStrip > frameLengthInt) {
                    in.skipBytes(frameLengthInt);
                    throw new CorruptedFrameException("Adjusted frame length (" + frameLength + ") is less " + "than initialBytesToStrip: " + this.initialBytesToStrip);
                } else {
                    //根据initialBytesToStrip的设置来跳过某些字节
                    in.skipBytes(this.initialBytesToStrip);
                    //拿到当前累积数据的读指针
                    int readerIndex = in.readerIndex();
                    //拿到待抽取数据包的实际长度
                    int actualFrameLength = frameLengthInt - this.initialBytesToStrip;
                    //进行抽取
                    ByteBuf frame = this.extractFrame(ctx, in, readerIndex, actualFrameLength);
                    //移动读指针
                    in.readerIndex(readerIndex + actualFrameLength);
                    return frame;
                }
            }
        }
    }
}

下面分几个部分来分析一下这个重量级函数

获取frame长度

获取需要待拆包的包大小

// 拿到长度域的实际字节偏移,就是长度域的开始下标
// 这里就是需求4,开始的几个字节并不是长度域
int actualLengthFieldOffset = in.readerIndex() + this.lengthFieldOffset;
// 拿到实际的未调整过的包长度
// 就是读取长度域的十进制值,最原始传过来的包的长度
long frameLength = this.getUnadjustedFrameLength(in, actualLengthFieldOffset, this.lengthFieldLength, this.byteOrder);
// 调整包的长度
frameLength += (long)(this.lengthAdjustment + this.lengthFieldEndOffset);

上面这一段内容有个扩展点 getUnadjustedFrameLength,如果你的长度域代表的值表达的含义不是正常的int,short等基本类型,你可以重写这个函数

protected long getUnadjustedFrameLength(ByteBuf buf, int offset, int length, ByteOrder order) {
    buf = buf.order(order);
    long frameLength;
    switch (length) {
    case 1:
        frameLength = buf.getUnsignedByte(offset);
        break;
    case 2:
        frameLength = buf.getUnsignedShort(offset);
        break;
    case 3:
        frameLength = buf.getUnsignedMedium(offset);
        break;
    case 4:
        frameLength = buf.getUnsignedInt(offset);
        break;
    case 8:
        frameLength = buf.getLong(offset);
        break;
    default:
        throw new DecoderException(
                "unsupported lengthFieldLength: " + lengthFieldLength + " (expected: 1, 2, 3, 4, or 8)");
    }
    return frameLength;
}

跳过指定字节长度

int frameLengthInt = (int)frameLength;
//当前可读的字节数小于包中的length,什么都不做,等待下一次解码
if (in.readableBytes() < frameLengthInt) {
    return null;
//跳过的字节不能大于数据包的长度,否则就抛出 CorruptedFrameException 的异常
} else if (this.initialBytesToStrip > frameLengthInt) {
    in.skipBytes(frameLengthInt);
    throw new CorruptedFrameException("Adjusted frame length (" + frameLength + ") is less " + "than initialBytesToStrip: " + this.initialBytesToStrip);
}
//根据initialBytesToStrip的设置来跳过某些字节
in.skipBytes(this.initialBytesToStrip);

先验证当前是否已经读到足够的字节,如果读到了,在下一步抽取一个完整的数据包之前,需要根据initialBytesToStrip的设置来跳过某些字节(见文章开篇),当然,跳过的字节不能大于数据包的长度,否则就抛出 CorruptedFrameException 的异常

抽取frame

//根据initialBytesToStrip的设置来跳过某些字节
in.skipBytes(this.initialBytesToStrip);
//拿到当前累积数据的读指针
int readerIndex = in.readerIndex();
//拿到待抽取数据包的实际长度
int actualFrameLength = frameLengthInt - this.initialBytesToStrip;
//进行抽取
ByteBuf frame = this.extractFrame(ctx, in, readerIndex, actualFrameLength);
//移动读指针
in.readerIndex(readerIndex + actualFrameLength);
return frame;

到了最后抽取数据包其实就很简单了,拿到当前累积数据的读指针,然后拿到待抽取数据包的实际长度进行抽取,抽取之后,移动读指针

protected ByteBuf extractFrame(ChannelHandlerContext ctx, ByteBuf buffer, int index, int length) {
    return buffer.retainedSlice(index, length);
}

抽取的过程是简单的调用了一下 ByteBuf 的retainedSliceapi,该api无内存copy开销

自定义解码器

协议实体的定义

public class MyProtocolBean {
    //类型  系统编号 0xA 表示A系统,0xB 表示B系统
    private byte type;

    //信息标志  0xA 表示心跳包    0xB 表示超时包  0xC 业务信息包
    private byte flag;

    //内容长度
    private int length;

    //内容
    private String content;

    //省略get/set
}

服务器端

服务端的实现

public class Server {

    private static final int MAX_FRAME_LENGTH = 1024 * 1024;  //最大长度
    private static final int LENGTH_FIELD_LENGTH = 4;  //长度字段所占的字节数
    private static final int LENGTH_FIELD_OFFSET = 2;  //长度偏移
    private static final int LENGTH_ADJUSTMENT = 0;
    private static final int INITIAL_BYTES_TO_STRIP = 0;

    private int port;

    public Server(int port) {
        this.port = port;
    }

    public void start(){
        EventLoopGroup bossGroup = new NioEventLoopGroup(1);
        EventLoopGroup workerGroup = new NioEventLoopGroup();
        try {
            ServerBootstrap sbs = new ServerBootstrap().group(bossGroup,workerGroup).channel(NioServerSocketChannel.class).localAddress(new InetSocketAddress(port))
                    .childHandler(new ChannelInitializer<SocketChannel>() {

                        protected void initChannel(SocketChannel ch) throws Exception {
                            ch.pipeline().addLast(new MyProtocolDecoder(MAX_FRAME_LENGTH,LENGTH_FIELD_OFFSET,LENGTH_FIELD_LENGTH,LENGTH_ADJUSTMENT,INITIAL_BYTES_TO_STRIP,false));
                            ch.pipeline().addLast(new ServerHandler());
                        };

                    }).option(ChannelOption.SO_BACKLOG, 128)
                    .childOption(ChannelOption.SO_KEEPALIVE, true);
            // 绑定端口,开始接收进来的连接
            ChannelFuture future = sbs.bind(port).sync();

            System.out.println("Server start listen at " + port );
            future.channel().closeFuture().sync();
        } catch (Exception e) {
            bossGroup.shutdownGracefully();
            workerGroup.shutdownGracefully();
        }
    }

    public static void main(String[] args) throws Exception {
        int port;
        if (args.length > 0) {
            port = Integer.parseInt(args[0]);
        } else {
            port = 8080;
        }
        new Server(port).start();
    }
}

自定义解码器MyProtocolDecoder

public class MyProtocolDecoder extends LengthFieldBasedFrameDecoder {

    private static final int HEADER_SIZE = 6;

    /**
     * @param maxFrameLength  帧的最大长度
     * @param lengthFieldOffset length字段偏移的地址
     * @param lengthFieldLength length字段所占的字节长
     * @param lengthAdjustment 修改帧数据长度字段中定义的值,可以为负数 因为有时候我们习惯把头部记入长度,若为负数,则说明要推后多少个字段
     * @param initialBytesToStrip 解析时候跳过多少个长度
     * @param failFast 为true,当frame长度超过maxFrameLength时立即报TooLongFrameException异常,为false,读取完整个帧再报异
     */

    public MyProtocolDecoder(int maxFrameLength, int lengthFieldOffset, int lengthFieldLength, int lengthAdjustment, int initialBytesToStrip, boolean failFast) {

        super(maxFrameLength, lengthFieldOffset, lengthFieldLength, lengthAdjustment, initialBytesToStrip, failFast);

    }

    @Override
    protected Object decode(ChannelHandlerContext ctx, ByteBuf in) throws Exception {
        //在这里调用父类的方法,实现指得到想要的部分,我在这里全部都要,也可以只要body部分
        in = (ByteBuf) super.decode(ctx,in);  

        if(in == null){
            return null;
        }
        if(in.readableBytes()<HEADER_SIZE){
            throw new Exception("字节数不足");
        }
        //读取type字段
        byte type = in.readByte();
        //读取flag字段
        byte flag = in.readByte();
        //读取length字段
        int length = in.readInt();
        
        if(in.readableBytes()!=length){
            throw new Exception("标记的长度不符合实际长度");
        }
        //读取body
        byte []bytes = new byte[in.readableBytes()];
        in.readBytes(bytes);

        return new MyProtocolBean(type,flag,length,new String(bytes,"UTF-8"));

    }
}

服务端Hanlder

public class ServerHandler extends ChannelInboundHandlerAdapter {

    @Override
    public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
        MyProtocolBean myProtocolBean = (MyProtocolBean)msg;  //直接转化成协议消息实体
        System.out.println(myProtocolBean.getContent());
    }

    @Override
    public void channelActive(ChannelHandlerContext ctx) throws Exception {
        super.channelActive(ctx);
    }
}

客户端和客户端Handler

public class Client {
    static final String HOST = System.getProperty("host", "127.0.0.1");
    static final int PORT = Integer.parseInt(System.getProperty("port", "8080"));
    static final int SIZE = Integer.parseInt(System.getProperty("size", "256"));

    public static void main(String[] args) throws Exception {

        // Configure the client.
        EventLoopGroup group = new NioEventLoopGroup();

        try {
            Bootstrap b = new Bootstrap();
            b.group(group)
                    .channel(NioSocketChannel.class)
                    .option(ChannelOption.TCP_NODELAY, true)
                    .handler(new ChannelInitializer<SocketChannel>() {
                        @Override
                        public void initChannel(SocketChannel ch) throws Exception {
                            ch.pipeline().addLast(new MyProtocolEncoder());
                            ch.pipeline().addLast(new ClientHandler());
                        }
                    });

            ChannelFuture future = b.connect(HOST, PORT).sync();
            future.channel().closeFuture().sync();
        } finally {
            group.shutdownGracefully();
        }
    }

}

客户端编码器

public class MyProtocolEncoder extends MessageToByteEncoder<MyProtocolBean> {

    @Override
    protected void encode(ChannelHandlerContext ctx, MyProtocolBean msg, ByteBuf out) throws Exception {
        if(msg == null){
            throw new Exception("msg is null");
        }
        out.writeByte(msg.getType());
        out.writeByte(msg.getFlag());
        out.writeInt(msg.getLength());
        out.writeBytes(msg.getContent().getBytes(Charset.forName("UTF-8")));
    }
}
  • 编码的时候,只需要按照定义的顺序依次写入到ByteBuf中.

客户端Handler

public class ClientHandler extends ChannelInboundHandlerAdapter {

    @Override
    public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
        super.channelRead(ctx, msg);
    }

    @Override
    public void channelActive(ChannelHandlerContext ctx) throws Exception {

        MyProtocolBean myProtocolBean = new MyProtocolBean((byte)0xA, (byte)0xC, "Hello,Netty".length(), "Hello,Netty");
        ctx.writeAndFlush(myProtocolBean);

    }
}

About Joyk


Aggregate valuable and interesting links.
Joyk means Joy of geeK