`

protobuf在netty里面的应用举例

 
阅读更多

netty为protobuf提供了两个编码器(ProtobufEncoder,ProtobufVarint32LengthFieldPrepender),两个解码器(ProtobufVarint32FrameDecoder,ProtobufDecoder)
[注]所谓的编码就是把应用程序使用的数据类型编码成在网络上传输的二进制字节流,反之同理。
看一个netty官网上提供的一个使用protobuf的例子:
LocalTimeProtocol.proto文件:

[java] view plaincopy
  1. package org.jboss.netty.example.localtime;  
  2. option optimize_for = SPEED;  
  3. enum Continent {  
  4.   AFRICA = 0;  
  5.   AMERICA = 1;  
  6.   ANTARCTICA = 2;  
  7.   ARCTIC = 3;  
  8.   ASIA = 4;  
  9.   ATLANTIC = 5;  
  10.   AUSTRALIA = 6;  
  11.   EUROPE = 7;  
  12.   INDIAN = 8;  
  13.   MIDEAST = 9;  
  14.   PACIFIC = 10;  
  15. }  
  16. message Location {  
  17.   required Continent continent = 1;  
  18.   required string city = 2;  
  19. }  
  20. message Locations {  
  21.   repeated Location location = 1;  
  22. }  
  23. enum DayOfWeek {  
  24.   SUNDAY = 1;  
  25.   MONDAY = 2;  
  26.   TUESDAY = 3;  
  27.   WEDNESDAY = 4;  
  28.   THURSDAY = 5;  
  29.   FRIDAY = 6;  
  30.   SATURDAY = 7;  
  31. }  
  32. message LocalTime {  
  33.   required uint32 year = 1;  
  34.   required uint32 month = 2;  
  35.   required uint32 dayOfMonth = 4;  
  36.   required DayOfWeek dayOfWeek = 5;  
  37.   required uint32 hour = 6;  
  38.   required uint32 minute = 7;  
  39.   required uint32 second = 8;  
  40. }  
  41. message LocalTimes {  
  42.   repeated LocalTime localTime = 1;  
  43. }  

 

客户端:

[java] view plaincopy
  1. public class LocalTimeClient {  
  2.   
  3.     public static void main(String[] args) throws Exception {  
  4.         // Parse options.  
  5.         String host = "localhost";  
  6.         int port = 8080;  
  7.         Collection<String> cities = new ArrayList<String>(){  
  8.   private static final long serialVersionUID = 1L;  
  9.   {  
  10.           add("America/New_York");  
  11.           add("Asia/Seoul");  
  12.          }  
  13.         };  
  14.         // Set up.  
  15.         ClientBootstrap bootstrap = new ClientBootstrap(  
  16.                 new NioClientSocketChannelFactory(  
  17.                         Executors.newCachedThreadPool(),  
  18.                         Executors.newCachedThreadPool()));  
  19.   
  20.         // Configure the event pipeline factory.  
  21.         bootstrap.setPipelineFactory(new LocalTimeClientPipelineFactory());  
  22.   
  23.         // Make a new connection.  
  24.         ChannelFuture connectFuture =  
  25.             bootstrap.connect(new InetSocketAddress(host, port));  
  26.   
  27.         // Wait until the connection is made successfully.  
  28.         Channel channel = connectFuture.awaitUninterruptibly().getChannel();  
  29.   
  30.         // Get the handler instance to initiate the request.  
  31.         LocalTimeClientHandler handler =  
  32.             channel.getPipeline().get(LocalTimeClientHandler.class);  
  33.   
  34.         // Request and get the response.  
  35.         List<String> response = handler.getLocalTimes(cities);  
  36.         // Close the connection.  
  37.         channel.close().awaitUninterruptibly();  
  38.   
  39.         // Shut down all thread pools to exit.  
  40.         bootstrap.releaseExternalResources();  
  41.   
  42.         // Print the response at last but not least.  
  43.         Iterator<String> i1 = cities.iterator();  
  44.         Iterator<String> i2 = response.iterator();  
  45.         while (i1.hasNext()) {  
  46.             System.out.format("%28s: %s%n", i1.next(), i2.next());  
  47.         }  
  48.     }  
  49. }  
[java] view plaincopy
  1. public class LocalTimeClientPipelineFactory implements ChannelPipelineFactory {  
  2.   
  3.     public ChannelPipeline getPipeline() throws Exception {  
  4.         ChannelPipeline p = pipeline();  
  5.  //解码用  
  6.         p.addLast("frameDecoder"new ProtobufVarint32FrameDecoder());  
  7.         //构造函数传递要解码成的类型  
  8.         p.addLast("protobufDecoder"new ProtobufDecoder(LocalTimeProtocol.LocalTimes.getDefaultInstance()));  
  9.  //编码用  
  10.         p.addLast("frameEncoder"new ProtobufVarint32LengthFieldPrepender());  
  11.         p.addLast("protobufEncoder"new ProtobufEncoder());  
  12.  //业务逻辑用  
  13.         p.addLast("handler"new LocalTimeClientHandler());  
  14.         return p;  
  15.     }  
  16. }  


 

[java] view plaincopy
  1. public class LocalTimeClientHandler extends SimpleChannelUpstreamHandler {  
  2.   
  3.     private static final Logger logger = Logger.getLogger(  
  4.             LocalTimeClientHandler.class.getName());  
  5.   
  6.     // Stateful properties  
  7.     private volatile Channel channel;  
  8.     //用来存储服务端返回的结果  
  9.     private final BlockingQueue<LocalTimes> answer = new LinkedBlockingQueue<LocalTimes>();  
  10.   
  11.     public List<String> getLocalTimes(Collection<String> cities) {  
  12.         Locations.Builder builder = Locations.newBuilder();  
  13.  //构造传输给服务端的Locations对象  
  14.         for (String c: cities) {  
  15.             String[] components = c.split("/");  
  16.             builder.addLocation(Location.newBuilder().  
  17.                 setContinent(Continent.valueOf(components[0].toUpperCase())).  
  18.                 setCity(components[1]).build());  
  19.         }  
  20.   
  21.         channel.write(builder.build());  
  22.   
  23.         LocalTimes localTimes;  
  24.         boolean interrupted = false;  
  25.         for (;;) {  
  26.             try {  
  27.   //从queue里面得到的,也就是服务端传过来的LocalTimes。  
  28.                 localTimes = answer.take();  
  29.                 break;  
  30.             } catch (InterruptedException e) {  
  31.                 interrupted = true;  
  32.             }  
  33.         }  
  34.   
  35.         if (interrupted) {  
  36.             Thread.currentThread().interrupt();  
  37.         }  
  38.   
  39.         List<String> result = new ArrayList<String>();  
  40.         for (LocalTime lt: localTimes.getLocalTimeList()) {  
  41.             result.add(  
  42.                     new Formatter().format(  
  43.                             "%4d-%02d-%02d %02d:%02d:%02d %s",  
  44.                             lt.getYear(),  
  45.                             lt.getMonth(),  
  46.                             lt.getDayOfMonth(),  
  47.                             lt.getHour(),  
  48.                             lt.getMinute(),  
  49.                             lt.getSecond(),  
  50.                             lt.getDayOfWeek().name()).toString());  
  51.         }  
  52.   
  53.         return result;  
  54.     }  
  55.   
  56.     @Override  
  57.     public void handleUpstream(  
  58.             ChannelHandlerContext ctx, ChannelEvent e) throws Exception {  
  59.         if (e instanceof ChannelStateEvent) {  
  60.             logger.info(e.toString());  
  61.         }  
  62.         super.handleUpstream(ctx, e);  
  63.     }  
  64.   
  65.     @Override  
  66.     public void channelOpen(ChannelHandlerContext ctx, ChannelStateEvent e)  
  67.             throws Exception {  
  68.         channel = e.getChannel();  
  69.         super.channelOpen(ctx, e);  
  70.     }  
  71.   
  72.     @Override  
  73.     public void messageReceived(  
  74.             ChannelHandlerContext ctx, final MessageEvent e) {  
  75.  //收到服务端返回的消息,已经解码成了LocalTimes类型  
  76.         boolean offered = answer.offer((LocalTimes) e.getMessage());  
  77.         assert offered;  
  78.     }  
  79.   
  80.     @Override  
  81.     public void exceptionCaught(  
  82.             ChannelHandlerContext ctx, ExceptionEvent e) {  
  83.         logger.log(  
  84.                 Level.WARNING,  
  85.                 "Unexpected exception from downstream.",  
  86.                 e.getCause());  
  87.         e.getChannel().close();  
  88.     }  
  89. }  


服务端的处理:

[java] view plaincopy
  1. public class LocalTimeServer {  
  2.   
  3.     public static void main(String[] args) throws Exception {  
  4.         // Configure the server.  
  5.         ServerBootstrap bootstrap = new ServerBootstrap(  
  6.                 new NioServerSocketChannelFactory(  
  7.                         Executors.newCachedThreadPool(),  
  8.                         Executors.newCachedThreadPool()));  
  9.   
  10.         // Set up the event pipeline factory.  
  11.         bootstrap.setPipelineFactory(new LocalTimeServerPipelineFactory());  
  12.   
  13.         // Bind and start to accept incoming connections.  
  14.         bootstrap.bind(new InetSocketAddress(8080));  
  15.     }  
  16. }  
  17. public class LocalTimeServerPipelineFactory implements ChannelPipelineFactory {  
  18.   
  19.     public ChannelPipeline getPipeline() throws Exception {  
  20.         ChannelPipeline p = pipeline();  
  21.  //解码  
  22.         p.addLast("frameDecoder"new ProtobufVarint32FrameDecoder());  
  23.  //构造函数传递要解码成的类型  
  24.         p.addLast("protobufDecoder"new ProtobufDecoder(LocalTimeProtocol.Locations.getDefaultInstance()));  
  25.  //编码  
  26.         p.addLast("frameEncoder"new ProtobufVarint32LengthFieldPrepender());  
  27.         p.addLast("protobufEncoder"new ProtobufEncoder());  
  28.  //业务逻辑处理  
  29.         p.addLast("handler"new LocalTimeServerHandler());  
  30.         return p;  
  31.     }  
  32. }  


 

[java] view plaincopy
  1. public class LocalTimeServerHandler extends SimpleChannelUpstreamHandler {  
  2.   
  3.     private static final Logger logger = Logger.getLogger(  
  4.             LocalTimeServerHandler.class.getName());  
  5.   
  6.     @Override  
  7.     public void handleUpstream(  
  8.             ChannelHandlerContext ctx, ChannelEvent e) throws Exception {  
  9.         if (e instanceof ChannelStateEvent) {  
  10.             logger.info(e.toString());  
  11.         }  
  12.         super.handleUpstream(ctx, e);  
  13.     }  
  14.   
  15.     @Override  
  16.     public void messageReceived(  
  17.             ChannelHandlerContext ctx, MessageEvent e) {  
  18.  //收到的消息是Locations  
  19.         Locations locations = (Locations) e.getMessage();  
  20.         long currentTime = System.currentTimeMillis();  
  21.   
  22.         LocalTimes.Builder builder = LocalTimes.newBuilder();  
  23.         for (Location l: locations.getLocationList()) {  
  24.             TimeZone tz = TimeZone.getTimeZone(  
  25.                     toString(l.getContinent()) + '/' + l.getCity());  
  26.             Calendar calendar = Calendar.getInstance(tz);  
  27.             calendar.setTimeInMillis(currentTime);  
  28.   
  29.             builder.addLocalTime(LocalTime.newBuilder().  
  30.                     setYear(calendar.get(YEAR)).  
  31.                     setMonth(calendar.get(MONTH) + 1).  
  32.                     setDayOfMonth(calendar.get(DAY_OF_MONTH)).  
  33.                     setDayOfWeek(DayOfWeek.valueOf(calendar.get(DAY_OF_WEEK))).  
  34.                     setHour(calendar.get(HOUR_OF_DAY)).  
  35.                     setMinute(calendar.get(MINUTE)).  
  36.                     setSecond(calendar.get(SECOND)).build());  
  37.         }  
  38.  //返回LocalTimes  
  39.         e.getChannel().write(builder.build());  
  40.     }  
  41.   
  42.     @Override  
  43.     public void exceptionCaught(  
  44.             ChannelHandlerContext ctx, ExceptionEvent e) {  
  45.         logger.log(  
  46.                 Level.WARNING,  
  47.                 "Unexpected exception from downstream.",  
  48.                 e.getCause());  
  49.         e.getChannel().close();  
  50.     }  
  51.   
  52.     private static String toString(Continent c) {  
  53.         return "" + c.name().charAt(0) + c.name().toLowerCase().substring(1);  
  54.     }  
  55. }  

 

从这个例子中也可以看出来,netty已经把所有的protobuf的细节给封装过了,我们现在就看一下netty是如何发送和接受protobuf数据的。
先看一下ProtobufEncoder,

[java] view plaincopy
  1. @Override  
  2. protected Object encode(  
  3.             ChannelHandlerContext ctx, Channel channel, Object msg) throws Exception {  
  4.         if (!(msg instanceof MessageLite)) {  
  5.             return msg;  
  6.         }  
  7.         return wrappedBuffer(((MessageLite) msg).toByteArray());  
  8. }  

encode方法很简单,实际上它会调用protobuf的api,把消息编码成protobuf格式的字节数组。
然后看一下ProtobufVarint32LengthFieldPrepender:
它会在原来的数据的前面,追加一个使用Base 128 Varints编码过的length:
 BEFORE DECODE (300 bytes)       AFTER DECODE (302 bytes)
 +---------------+               +--------+---------------+
 | Protobuf Data |-------------->| Length | Protobuf Data |
 |  (300 bytes)  |               | 0xAC02 |  (300 bytes)  |
 +---------------+               +--------+---------------+
因此,netty实际上只做了这么一点工作,其余的全部都是protobuf自己完成的。

[java] view plaincopy
  1. @Override  
  2. protected Object encode(ChannelHandlerContext ctx, Channel channel,  
  3.             Object msg) throws Exception {  
  4.         if (!(msg instanceof ChannelBuffer)) {  
  5.             return msg;  
  6.         }  
  7.         ChannelBuffer body = (ChannelBuffer) msg;  
  8.         int length = body.readableBytes();  
  9.  //header使用跟body同样的字节序,容量是length这个整数所占的字节数  
  10.         ChannelBuffer header =  
  11.             channel.getConfig().getBufferFactory().getBuffer(  
  12.                     body.order(),  
  13.                     CodedOutputStream.computeRawVarint32Size(length));  
  14.         CodedOutputStream codedOutputStream = CodedOutputStream  
  15.                 .newInstance(new ChannelBufferOutputStream(header));  
  16.  //把length按照Base 128 Varints的方式写入header里面  
  17.         codedOutputStream.writeRawVarint32(length);  
  18.         codedOutputStream.flush();  
  19.  //把header和body组合到一块  
  20.         return wrappedBuffer(header, body);  
  21. }  

唯一值得一看的是:

[java] view plaincopy
  1. //计算一个整数在varint编码下所占的字节数,  
  2. public static int computeRawVarint32Size(final int value) {  
  3.     if ((value & (0xffffffff <<  7)) == 0return 1;  
  4.     if ((value & (0xffffffff << 14)) == 0return 2;  
  5.     if ((value & (0xffffffff << 21)) == 0return 3;  
  6.     if ((value & (0xffffffff << 28)) == 0return 4;  
  7.     return 5;  
  8. }  
  9. public void writeRawVarint32(int value) throws IOException {  
  10.     while (true) {  
  11.       if ((value & ~0x7F) == 0) {//如果最高位是0  
  12.         writeRawByte(value);//直接写  
  13.         return;  
  14.       } else {  
  15.         writeRawByte((value & 0x7F) | 0x80);//先写入低7位,最高位置1  
  16.         value >>>= 7;//再写高7位  
  17.       }  
  18.     }  
  19. }  

 

解码的过程无非就是先读出length来,根据length读取出所有的数据来,交给protobuf就能还原消息出来。
看一下具体的解码过程:
ProtobufVarint32FrameDecoder:

[java] view plaincopy
  1. protected Object decode(ChannelHandlerContext ctx, Channel channel, ChannelBuffer buffer) throws Exception {  
  2.         buffer.markReaderIndex();  
  3.         final byte[] buf = new byte[5];//存放长度,最多也就5个字节  
  4.         for (int i = 0; i < buf.length; i ++) {  
  5.             if (!buffer.readable()) {  
  6.                 buffer.resetReaderIndex();  
  7.                 return null;  
  8.             }  
  9.   
  10.             buf[i] = buffer.readByte();  
  11.             if (buf[i] >= 0) {  
  12.   //读取长度  
  13.                 int length = CodedInputStream.newInstance(buf, 0, i + 1).readRawVarint32();  
  14.                 if (length < 0) {  
  15.                     throw new CorruptedFrameException("negative length: " + length);  
  16.                 }  
  17.   
  18.                 if (buffer.readableBytes() < length) {  
  19.                     buffer.resetReaderIndex();  
  20.                     return null;  
  21.                 } else {  
  22.       //读取数据  
  23.                     return buffer.readBytes(length);  
  24.                 }  
  25.             }  
  26.         }  
  27.   
  28.         // Couldn't find the byte whose MSB is off.  
  29.         throw new CorruptedFrameException("length wider than 32-bit");  
  30. }  

 

看一下readRawVarint32:

[java] view plaincopy
  1. public int readRawVarint32() throws IOException {  
  2.     byte tmp = readRawByte();//先读一个字节  
  3.     if (tmp >= 0) {  
  4.       return tmp;  
  5.     }  
  6.     int result = tmp & 0x7f;//长度大于一个字节  
  7.     if ((tmp = readRawByte()) >= 0) {  
  8.       result |= tmp << 7;  
  9.     } else {  
  10.       result |= (tmp & 0x7f) << 7;  
  11.       if ((tmp = readRawByte()) >= 0) {  
  12.         result |= tmp << 14;  
  13.       } else {  
  14.         result |= (tmp & 0x7f) << 14;  
  15.         if ((tmp = readRawByte()) >= 0) {  
  16.           result |= tmp << 21;  
  17.         } else {  
  18.           result |= (tmp & 0x7f) << 21;  
  19.           result |= (tmp = readRawByte()) << 28;  
  20.           if (tmp < 0) {  
  21.             // Discard upper 32 bits.  
  22.             for (int i = 0; i < 5; i++) {  
  23.               if (readRawByte() >= 0) {  
  24.                 return result;  
  25.               }  
  26.             }  
  27.             throw InvalidProtocolBufferException.malformedVarint();  
  28.           }  
  29.         }  
  30.       }  
  31.     }  
  32.     return result;  
  33.   }  
  34. ProtobufDecoder:  
  35. @Override  
  36.     protected Object decode(  
  37.             ChannelHandlerContext ctx, Channel channel, Object msg) throws Exception {  
  38.         if (!(msg instanceof ChannelBuffer)) {  
  39.             return msg;  
  40.         }  
  41.   
  42.         ChannelBuffer buf = (ChannelBuffer) msg;  
  43.         if (buf.hasArray()) {  
  44.             final int offset = buf.readerIndex();  
  45.             if(extensionRegistry == null) {  
  46.   //从字节数组里面还原出消息,上层收到的就是构造函数里面传递进来的类型  
  47.                 return prototype.newBuilderForType().mergeFrom(  
  48.                         buf.array(), buf.arrayOffset() + offset, buf.readableBytes()).build();  
  49.             } else {  
  50.                 return prototype.newBuilderForType().mergeFrom(  
  51.                         buf.array(), buf.arrayOffset() + offset, buf.readableBytes(), extensionRegistry).build();  
  52.             }  
  53.         } else {  
  54.             if (extensionRegistry == null) {  
  55.                 return prototype.newBuilderForType().mergeFrom(  
  56.                         new ChannelBufferInputStream((ChannelBuffer) msg)).build();  
  57.             } else {  
  58.                 return prototype.newBuilderForType().mergeFrom(  
  59.                         new ChannelBufferInputStream((ChannelBuffer) msg), extensionRegistry).build();  
  60.             }  
  61.         }  
  62.     }  

 

分享到:
评论

相关推荐

Global site tag (gtag.js) - Google Analytics