# netty-websocket-server

**Repository Path**: ichiva_admin/netty-websocket-server

## Basic Information

- **Project Name**: netty-websocket-server
- **Description**: 封装netty用于快速创建websocket服务器
- **Primary Language**: Unknown
- **License**: MulanPSL-2.0
- **Default Branch**: master
- **Homepage**: None
- **GVP Project**: No

## Statistics

- **Stars**: 4
- **Forks**: 2
- **Created**: 2022-01-11
- **Last Updated**: 2023-09-08

## Categories & Tags

**Categories**: Uncategorized

**Tags**: None

## README

# netty-websocket-server

## 介绍

封装 netty,用于快速创建 websocket 服务器。

## 快速入门

继承抽象类 FastNettyWebSocketServer,实现 onMessage 回调,然后调用 start 开启监听:

```java
public static void main(String[] args) {
    new FastNettyWebSocketServer() {
        @Override
        public void onMessage(WebSocketSession session, String message) {
            System.out.println("收到:" + message);
            send(session, "收到,over");
        }
    }.start(8080);
}
```

## 实现细节

- 定义 WebSocketServer 接口

```java
public interface WebSocketServer {

    /**
     * Called when a client requests the WebSocket upgrade.
     * The handshake has NOT completed when this runs, so the WebSocket connection
     * does not exist yet; do not send frames from this callback.
     */
    default void onOpen(WebSocketSession session) {
    }

    /**
     * Called for every text frame received.
     */
    void onMessage(WebSocketSession session, String message);

    /**
     * Called for every binary frame received. Does nothing by default.
     */
    default void onMessage(WebSocketSession session, byte[] message) {
    }

    /**
     * Sends a text frame to the client.
     */
    default void send(WebSocketSession session, String message) {
        session.getChannel().writeAndFlush(new TextWebSocketFrame(message));
    }

    /**
     * Sends a binary frame to the client. The bytes are copied into a new buffer,
     * so the caller may reuse {@code message} afterwards.
     */
    default void send(WebSocketSession session, byte[] message) {
        session.getChannel().writeAndFlush(new BinaryWebSocketFrame(
                Unpooled.buffer().writeBytes(message)
        ));
    }

    /**
     * Called after the connection has been closed.
     */
    default void onClose(WebSocketSession session) {
    }

    /**
     * Called when a message callback throws or the pipeline reports an exception.
     */
    default void onError(WebSocketSession session, Throwable e) {
    }
}
```

- 定义配置接口 NettyWebsocketServerConfig

```java
public interface NettyWebsocketServerConfig {

    /**
     * Worker event loop group. The default uses Netty's default thread count
     * (2 x available logical processors, unless overridden with -Dio.netty.eventLoopThreads).
     * <p>
     * NOTE: the default implementation creates a NEW group on every call. Call it once
     * and reuse the returned instance; overriding implementations should likewise return
     * the same instance on every call, otherwise groups leak and cannot be shut down.
     */
    default NioEventLoopGroup getWorkerGroup() {
        return new NioEventLoopGroup();
    }

    /**
     * Boss (accept) event loop group, one thread by default.
     * The method name's spelling ("Boos") is kept as-is for compatibility.
     * <p>
     * NOTE: like getWorkerGroup(), the default implementation creates a NEW group on
     * every call; call it once and reuse the returned instance.
     */
    default NioEventLoopGroup getBoosGroup() {
        return new NioEventLoopGroup(1);
    }

    /**
     * Handler installed on every accepted child channel.
     */
    ChannelHandler getChildHandler();

    /**
     * Listening port, 8080 by default.
     */
    default int getPort() {
        return 8080;
    }
}
```

- 
WebSocketServer 的核心实现 FastNettyWebSocketServer,用于启动和关闭 netty

```java
public abstract class FastNettyWebSocketServer implements WebSocketServer {

    private Channel serverChannel;
    /** Boss group the running server is bound to (read from the config exactly once). */
    private NioEventLoopGroup bossGroup;
    /** Worker group the running server is bound to (read from the config exactly once). */
    private NioEventLoopGroup workerGroup;

    /**
     * Binds the server asynchronously with the given configuration.
     * This method does not wait for the bind; its success or failure is only logged.
     * On failure the event loop groups are shut down.
     */
    public void start(NettyWebsocketServerConfig config) {
        // The config's group getters may return a NEW group on every call (the interface
        // defaults do), so read them once and use the same instances for bind, logging and
        // shutdown. Calling them again would leak groups and make stop() shut down the
        // wrong ones, leaving the real event loop threads running.
        final NioEventLoopGroup boss = config.getBoosGroup();
        final NioEventLoopGroup worker = config.getWorkerGroup();
        this.bossGroup = boss;
        this.workerGroup = worker;

        ServerBootstrap server = new ServerBootstrap();
        server.group(boss, worker);
        server.channel(NioServerSocketChannel.class);
        server.childHandler(config.getChildHandler());

        ChannelFuture future = server.bind(config.getPort());
        // Remember the channel right away so stop() can close it even while bind is pending.
        this.serverChannel = future.channel();
        future.addListener(f -> {
            if (f.isSuccess()) {
                log.info("Start ws server success");
                log.info("boos group thread number {}", boss.executorCount());
                log.info("worker group thread number {}", worker.executorCount());
            } else {
                // Throwable as the extra last argument makes the logger print the stack trace too.
                log.error("Start ws server fail throw={}", f.cause().getMessage(), f.cause());
                future.channel().close();
                // Nothing is bound, so release the event loop threads instead of leaking them.
                boss.shutdownGracefully();
                worker.shutdownGracefully();
            }
        });
    }

    /**
     * Starts the server on {@code port} with the default event loop groups and a
     * {@link WebSocketChannelInitializer} that dispatches to this instance.
     */
    public void start(final int port) {
        start(new NettyWebsocketServerConfig() {
            @Override
            public ChannelHandler getChildHandler() {
                return new WebSocketChannelInitializer(FastNettyWebSocketServer.this);
            }

            @Override
            public int getPort() {
                return port;
            }
        });
    }

    /**
     * Starts the server on port 8080.
     */
    public void start() {
        start(8080);
    }

    /**
     * Stops a running server.
     * <p>
     * It closes the server socket and, once that completes, starts a graceful shutdown of
     * the boss group. The worker group's graceful shutdown is scheduled 8 seconds later.
     * The caller blocks for at most 10 seconds.
     * <p>
     * Does nothing if the server channel is absent or already closed.
     */
    public void stop() {
        final Channel channel = this.serverChannel;
        if (channel == null || !channel.isOpen()) {
            return;
        }
        final int waitSec = 10;
        final NioEventLoopGroup boss = this.bossGroup;
        final NioEventLoopGroup worker = this.workerGroup;
        CountDownLatch latch = new CountDownLatch(1);
        channel.close().addListener(f -> {
            // Give in-flight work on the worker threads a grace period before shutting them down.
            worker.schedule(() -> {
                log.info("Shutdown dispatcher success...");
                worker.shutdownGracefully();
                latch.countDown();
            }, waitSec - 2, TimeUnit.SECONDS);
            log.info("Close ws server socket success={}", f.isSuccess());
            boss.shutdownGracefully();
        });
        try {
            boolean finished = latch.await(waitSec, TimeUnit.SECONDS);
            if (!finished) {
                log.warn("Shutdown ws server timeout");
            }
        } catch (InterruptedException e) {
            // Restore the interrupt flag so callers further up can still observe it.
            Thread.currentThread().interrupt();
            log.warn("Shutdown ws server interrupted exception={}", e.getMessage());
        }
    }
}
```

- 默认的通道实现 WebSocketChannelInitializer
```java
public class WebSocketChannelInitializer extends ChannelInitializer<SocketChannel> {

    private final WebSocketServer server;

    public WebSocketChannelInitializer(WebSocketServer server) {
        this.server = server;
    }

    /**
     * Builds the per-connection pipeline: HTTP codec and aggregation for the upgrade
     * request, WebSocket compression and protocol handling, then the application handler.
     */
    @Override
    protected void initChannel(SocketChannel ch) {
        ChannelPipeline pipeline = ch.pipeline();
        // Combination of HttpRequestDecoder and HttpResponseEncoder: encodes/decodes HTTP.
        pipeline.addLast("httpServerCodec", new HttpServerCodec());
        // Writes data to the client in chunks so large payloads (e.g.
        // channel.write(new ChunkedFile(new File("bigFile.mkv")))) do not exhaust memory.
        pipeline.addLast(new ChunkedWriteHandler());
        // Aggregates HttpMessage + HttpContents into one FullHttpRequest / FullHttpResponse
        // (depending on direction). Must come after HttpServerCodec. The handshake handler
        // below only upgrades FullHttpRequest messages.
        pipeline.addLast(new HttpObjectAggregator(1024 * 2));
        // WebSocket compression extension. When it is present, the third argument
        // (allowExtensions) of the protocol handler below must be true.
        pipeline.addLast(new WebSocketServerCompressionHandler());
        // WebSocket endpoint exposed by the server. Raise maxFrameSize (65536 here) if
        // clients send large payloads.
        pipeline.addLast(new WebSocketServerAuthProtocolHandler("/", null, true, 65536, server));
        // No LengthFieldPrepender here: WebSocket frames are already length-delimited by the
        // protocol, and a length prefix on raw outbound ByteBufs would corrupt the stream.
        // Application logic.
        pipeline.addLast(new WebSocketServerChannelInboundHandler(server));
    }
}
```

- 业务处理器,提供 session 支持

```java
/**
 * Dispatches decoded WebSocket frames to the {@link WebSocketServer} callbacks,
 * looking up the connection's session through {@code Sessions}.
 */
public class WebSocketServerChannelInboundHandler extends SimpleChannelInboundHandler<Object> {

    private final WebSocketServer webSocketServer;

    public WebSocketServerChannelInboundHandler(WebSocketServer webSocketServer) {
        this.webSocketServer = webSocketServer;
    }

    /**
     * Routes each frame to the matching callback:
     * - text frames go to onMessage(session, String);
     * - binary frames go to onMessage(session, byte[]).
     * An exception thrown by either callback is reported via onError. Other frame types
     * are only printed.
     */
    @Override
    protected void channelRead0(ChannelHandlerContext ctx, Object msg) {
        WebSocketSession session = Sessions.getSession(ctx);
        if (msg instanceof TextWebSocketFrame) {
            String message = ((TextWebSocketFrame) msg).text();
            try {
                webSocketServer.onMessage(session, message);
            } catch (Throwable e) {
                webSocketServer.onError(session, e);
            }
        } else if (msg instanceof BinaryWebSocketFrame) {
            // Copy exactly the readable bytes. content().array() throws for direct buffers,
            // which the default allocator produces. It also returns the whole backing array,
            // ignoring the reader index and length, and a pooled backing array is reused
            // after the frame is released.
            byte[] bytes = ByteBufUtil.getBytes(((BinaryWebSocketFrame) msg).content());
            try {
                webSocketServer.onMessage(session, bytes);
            } catch (Throwable e) {
                webSocketServer.onError(session, e);
            }
        } else {
            System.out.println("未知消息类型:" + msg.getClass().getName());
        }
    }

    /**
     * Forwards the exception to the next handler, then reports it via onError.
     */
    @Override
    public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
        super.exceptionCaught(ctx, cause);
        webSocketServer.onError(Sessions.getSession(ctx), cause);
    }

    /**
     * Destroys this connection's session via Sessions.destroy and passes it to onClose.
     */
    @Override
    public void channelInactive(ChannelHandlerContext ctx) throws Exception {
        super.channelInactive(ctx);
        WebSocketSession destroy = Sessions.destroy(ctx);
        webSocketServer.onClose(destroy);
    }
}
```

- 扩展握手处理:解析握手 uri 及其参数并写入 session

```java
public class WebSocketServerAuthProtocolHandler extends WebSocketServerProtocolHandler {

    private final WebSocketServer webSocketServer;

    public WebSocketServerAuthProtocolHandler(String websocketPath, WebSocketServer webSocketServer) {
        this(websocketPath, null, false, webSocketServer);
    }

    public WebSocketServerAuthProtocolHandler(String websocketPath, String subprotocols,
                                              WebSocketServer webSocketServer) {
        this(websocketPath, subprotocols, false, webSocketServer);
    }

    public WebSocketServerAuthProtocolHandler(String websocketPath, String subprotocols,
                                              boolean allowExtensions,
                                              WebSocketServer webSocketServer) {
        this(websocketPath, subprotocols, allowExtensions, 65536, webSocketServer);
    }

    public WebSocketServerAuthProtocolHandler(String websocketPath, String subprotocols,
                                              boolean allowExtensions, int maxFrameSize,
                                              WebSocketServer webSocketServer) {
        this(websocketPath, subprotocols, allowExtensions, maxFrameSize, false, webSocketServer);
    }

    public WebSocketServerAuthProtocolHandler(String websocketPath, String subprotocols,
                                              boolean allowExtensions, int maxFrameSize,
                                              boolean allowMaskMismatch,
                                              WebSocketServer webSocketServer) {
        super(websocketPath, subprotocols, allowExtensions, maxFrameSize, allowMaskMismatch);
        this._webSocketPathPrefix = websocketPath;
        this._subprotocols = subprotocols;
        this._allowExtensions = allowExtensions;
        this._maxFrameSize = maxFrameSize;
        this._allowMaskMismatch = allowMaskMismatch;
        this.webSocketServer = webSocketServer;
    }

    // Copies of the constructor arguments, used to build the handshake handler in handlerAdded().
    String _webSocketPathPrefix;
    String _subprotocols;
    boolean _allowExtensions;
    int _maxFrameSize;
    boolean _allowMaskMismatch;

    /**
     * Installs the helper handlers in front of this one, in place of the parent's own
     * installation:
     * - the custom handshake handler, which creates the session and calls onOpen;
     * - a UTF-8 frame validator.
     * Each is added only if not already present in the pipeline.
     */
    @Override
    public void handlerAdded(ChannelHandlerContext ctx) {
        ChannelPipeline cp = ctx.pipeline();
        if (cp.get(WebSocketServerAuthHandshakeHandler.class) == null) {
            // Add the handshake handler before this one.
            cp.addBefore(ctx.name(), WebSocketServerAuthHandshakeHandler.class.getName(),
                    new WebSocketServerAuthHandshakeHandler(_webSocketPathPrefix, _subprotocols,
                            _allowExtensions, _maxFrameSize, _allowMaskMismatch, webSocketServer));
        }
        if (cp.get(Utf8FrameValidator.class) == null) {
            // Add UTF-8 checking of text frames before this one.
            cp.addBefore(ctx.name(), Utf8FrameValidator.class.getName(), new Utf8FrameValidator());
        }
    }
}

public class WebSocketServerAuthHandshakeHandler extends ChannelInboundHandlerAdapter {

    private final String websocketPath;
    private final String subprotocols;
    private final boolean allowExtensions;
    private final int maxFramePayloadSize;
    private final boolean allowMaskMismatch;
    private final WebSocketServer webSocketServer;

    public WebSocketServerAuthHandshakeHandler(String websocketPath, String subprotocols,
                                               boolean allowExtensions, int maxFrameSize,
                                               boolean allowMaskMismatch,
                                               WebSocketServer webSocketServer) {
        this.websocketPath = websocketPath;
        this.subprotocols = subprotocols;
        this.allowExtensions = allowExtensions;
        this.maxFramePayloadSize = maxFrameSize;
        this.allowMaskMismatch = allowMaskMismatch;
        this.webSocketServer = webSocketServer;
    }

    /**
     * Upgrades HTTP GET requests whose URI starts with websocketPath:
     * - creates and fills the session (channel, id, uri, base uri, query params);
     * - calls onOpen;
     * - performs the handshake;
     * - replaces itself with a responder that answers further HTTP requests with 403.
     * Non-GET requests on the path get 403. Everything else is passed on unchanged.
     */
    @Override
    public void channelRead(final ChannelHandlerContext ctx, Object msg) {
        // Only aggregated HTTP requests can be upgraded. Pass anything else along instead of
        // failing with a ClassCastException.
        if (!(msg instanceof FullHttpRequest)) {
            ctx.fireChannelRead(msg);
            return;
        }
        FullHttpRequest req = (FullHttpRequest) msg;
        if (!req.uri().startsWith(websocketPath)) {
            ctx.fireChannelRead(msg);
            return;
        }

        try {
            if (!GET.equals(req.method())) {
                sendHttpResponse(ctx, req, new DefaultFullHttpResponse(HTTP_1_1, FORBIDDEN));
                return;
            }

            WebSocketSession session = Sessions.createSession(ctx);
            session.setChannel(ctx.channel());
            // NOTE(review): hashCode() is not guaranteed unique across channels;
            // ctx.channel().id() would be, if the session id type can hold it.
            session.setId(ctx.channel().hashCode());
            session.setUri(req.uri());
            UrlEntity entity = UrlEntity.parse(req.uri());
            session.setUriBase(entity.getBaseUrl());
            session.setParams(entity.getParams());
            // Runs before the handshake below, i.e. before the WebSocket connection exists.
            webSocketServer.onOpen(session);

            final WebSocketServerHandshakerFactory wsFactory = new WebSocketServerHandshakerFactory(
                    getWebSocketLocation(ctx.pipeline(), req, websocketPath), subprotocols,
                    allowExtensions, maxFramePayloadSize, allowMaskMismatch);
            final WebSocketServerHandshaker handshaker = wsFactory.newHandshaker(req);
            if (handshaker == null) {
                WebSocketServerHandshakerFactory.sendUnsupportedVersionResponse(ctx.channel());
            } else {
                final ChannelFuture handshakeFuture = handshaker.handshake(ctx.channel(), req);
                handshakeFuture.addListener((ChannelFutureListener) future -> {
                    if (!future.isSuccess()) {
                        ctx.fireExceptionCaught(future.cause());
                    } else {
                        ctx.fireUserEventTriggered(
                                WebSocketServerProtocolHandler.ServerHandshakeStateEvent.HANDSHAKE_COMPLETE);
                    }
                });
                setHandshaker(ctx.channel(), handshaker);
                // The connection is now a WebSocket; any further HTTP request gets 403.
                ctx.pipeline().replace(this, "WS403Responder", forbiddenHttpRequestResponder());
            }
        } finally {
            req.release();
        }
    }

    // NOTE(review): presumably the same key WebSocketServerProtocolHandler uses to look up
    // the handshaker (e.g. when answering close frames); confirm against the Netty version used.
    private static final AttributeKey<WebSocketServerHandshaker> HANDSHAKER_ATTR_KEY =
            AttributeKey.valueOf(WebSocketServerHandshaker.class, "HANDSHAKER");

    /**
     * Stores the handshaker on the channel under HANDSHAKER_ATTR_KEY.
     */
    static void setHandshaker(Channel channel, WebSocketServerHandshaker handshaker) {
        channel.attr(HANDSHAKER_ATTR_KEY).set(handshaker);
    }

    /**
     * Returns a handler that releases any FullHttpRequest, answers it with 403 Forbidden,
     * and passes every other message on.
     */
    static ChannelHandler forbiddenHttpRequestResponder() {
        return new ChannelInboundHandlerAdapter() {
            @Override
            public void channelRead(ChannelHandlerContext ctx, Object msg) {
                if (msg instanceof FullHttpRequest) {
                    ((FullHttpRequest) msg).release();
                    FullHttpResponse response =
                            new DefaultFullHttpResponse(HTTP_1_1, HttpResponseStatus.FORBIDDEN);
                    ctx.channel().writeAndFlush(response);
                } else {
                    ctx.fireChannelRead(msg);
                }
            }
        };
    }

    /**
     * Writes {@code res}. The connection is closed afterwards if the request is not
     * keep-alive or the status is not 200.
     */
    private static void sendHttpResponse(ChannelHandlerContext ctx, HttpRequest req, HttpResponse res) {
        ChannelFuture f = ctx.channel().writeAndFlush(res);
        if (!isKeepAlive(req) || res.status().code() != 200) {
            f.addListener(ChannelFutureListener.CLOSE);
        }
    }

    /**
     * Builds the ws:// (or wss:// when an SslHandler is in the pipeline) URL from the
     * request's Host header and {@code path}.
     */
    private static String getWebSocketLocation(ChannelPipeline cp, HttpRequest req, String path) {
        String protocol = "ws";
        if (cp.get(SslHandler.class) != null) {
            // SSL in use, so use the secure WebSocket protocol.
            protocol = "wss";
        }
        return protocol + "://" + req.headers().get(HttpHeaderNames.HOST) + path;
    }
}
```