2024-07-28
Backend

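A couple of days ago I finished *Netty in Action*. It's about three hundred pages and took me two days, so I'd say I'm past the beginner stage. I also wrote a few demos, and I now have at least a basic understanding of Netty. Next I plan to dig into its source code and use Netty to build a web framework along the lines of Gin. I really dislike annotations, so I want a framework that doesn't rely on them. I'm not sure how to approach that yet; the few designs I've tried didn't feel elegant.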

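No demo today. Instead, let's read the source of an older open-source project. The codebase is small, but it has over four hundred stars: it's a Netty-based WebSocket package modeled on Spring Boot's WebSocket support, which also makes it a good chance to review how to write a Starter.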

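First, let's see how auto-configuration works. It's much the same in every starter: a configuration class registers the beans. There are several ways to register beans, so it comes down to taste; since only two beans are needed here, the author uses `@Bean` methods. The configuration class is listed in `META-INF/spring.factories`: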


```properties
# Auto Configure
org.springframework.boot.autoconfigure.EnableAutoConfiguration=\
  com.niezhiliang.netty.websocket.starter.NettyWebsocketAutoConfiguration
```

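The project targets Spring Boot 2.x. Starting with 2.7, auto-configuration classes are listed in the `META-INF/spring/org.springframework.boot.autoconfigure.AutoConfiguration.imports` file instead (registration through `spring.factories` was deprecated in 2.7 and removed in 3.0).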

Here is the configuration class:

```java
@Configuration
public class NettyWebsocketAutoConfiguration {

    @Bean
    public WebSocketAnnotationPostProcessor webSocketAnnotationPostProcessor() {
        return new WebSocketAnnotationPostProcessor();
    }

    @Bean
    public WebsocketProperties websocketProperties() {
        return new WebsocketProperties();
    }
}
```

Two beans are registered in total: WebSocketAnnotationPostProcessor, which processes the annotations, and WebsocketProperties, which holds the configuration properties.

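The starter also defines a set of annotations modeled on those of Spring WebSocket: `@WsServerEndpoint`, `@HandshakeBefore`, `@OnOpen`, `@OnClose`, `@OnMessage`, `@OnEvent` and `@OnError`.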

Let's take it step by step, starting with the properties class:

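```java
@ConfigurationProperties(prefix = WebsocketProperties.WEBSOCKET_PREFIX)
@Data
public class WebsocketProperties {

    public static final String WEBSOCKET_PREFIX = "netty.websocket";

    /** Port the Netty server binds to; no default, so it must be configured */
    private Integer port;

    /** Threads in the boss (accept) event loop group */
    private Integer bossThreadNums = 1;

    /** Threads in the worker (I/O) event loop group */
    private Integer workerThreadNums = 2;

    /** Connection timeout in milliseconds */
    private Integer connectTimeout = 15000;

    /** Maximum length of the TCP accept queue, 128 by default */
    private Integer backLog = 128;

    /** Whether messages are sent immediately (TCP_NODELAY) */
    private boolean tcpNoDelay = true;

    /** Heartbeat read-idle timeout in seconds */
    private Integer readerIdleTimeSeconds = 60;

    /** Heartbeat write-idle timeout in seconds */
    private Integer writerIdleTimeSeconds = 60;

    /** Timeout in seconds with neither reads nor writes */
    private Integer allIdleTimeSeconds = 60;

    /** Maximum aggregated HTTP message size in bytes (HttpObjectAggregator) */
    private Integer maxContentLength = 65536;
}
```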
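Spring Boot binds these fields with relaxed binding, and its documentation recommends writing keys in kebab-case in `application.properties`. As a small illustration, the hypothetical `PropertyKeys` helper below (not part of the starter) derives the canonical key for each field name under the `netty.websocket` prefix; it is a sketch of the naming rule, not of Spring's binder.

```java
import java.util.List;
import java.util.stream.Collectors;

/** Hypothetical helper: lists the kebab-case keys that map to WebsocketProperties fields. */
final class PropertyKeys {
    static final String PREFIX = "netty.websocket";

    // Field names copied from WebsocketProperties
    static final List<String> FIELDS = List.of(
            "port", "bossThreadNums", "workerThreadNums", "connectTimeout", "backLog",
            "tcpNoDelay", "readerIdleTimeSeconds", "writerIdleTimeSeconds",
            "allIdleTimeSeconds", "maxContentLength");

    /** camelCase to kebab-case: each upper-case letter becomes '-' plus its lower-case form. */
    static String toKebabCase(String camel) {
        StringBuilder sb = new StringBuilder();
        for (char c : camel.toCharArray()) {
            if (Character.isUpperCase(c)) {
                sb.append('-').append(Character.toLowerCase(c));
            } else {
                sb.append(c);
            }
        }
        return sb.toString();
    }

    /** Full property keys, e.g. for documenting what can go into application.properties. */
    static List<String> keys() {
        return FIELDS.stream()
                .map(f -> PREFIX + "." + toKebabCase(f))
                .collect(Collectors.toList());
    }

    public static void main(String[] args) {
        keys().forEach(System.out::println);
    }
}
```

For example, `bossThreadNums` becomes `netty.websocket.boss-thread-nums`. Since `port` has no default, `netty.websocket.port` has to be set; otherwise the null `Integer` is unboxed when it is passed to `bootstrap.bind(...)` and startup fails with a `NullPointerException`.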

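Knowing these properties makes the rest of the source easier to read. Next, let's see how WebSocketAnnotationPostProcessor works. Judging by the name, it's a post-processor that handles annotations.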

Here's the source:

```java
public class WebSocketAnnotationPostProcessor implements SmartInitializingSingleton {

    @Autowired
    private DefaultListableBeanFactory beanFactory;

    @Autowired
    private WebsocketProperties websocketProperties;

    @Override
    public void afterSingletonsInstantiated() {
        // Use the package of the @SpringBootApplication class as the scan root
        String[] beanNamesForAnnotation = beanFactory.getBeanNamesForAnnotation(SpringBootApplication.class);
        String applicationStartBean = beanNamesForAnnotation[0];
        Object bean = beanFactory.getBean(applicationStartBean);
        String basePackage = ClassUtils.getPackageName(bean.getClass());
        scanWebsocketServiceBeans(basePackage, beanFactory);
        registerServerEndpoints();
    }

    @SneakyThrows
    private void registerServerEndpoints() {
        String[] beanNamesForAnnotation = beanFactory.getBeanNamesForAnnotation(WsServerEndpoint.class);
        WebsocketActionDispatch actionDispatch = new WebsocketActionDispatch();
        for (String beanName : beanNamesForAnnotation) {
            Class<?> beanType = beanFactory.getType(beanName);
            Class<?> targetClass = getTargetClass(beanType);
            WsServerEndpoint wsServerEndpoint = targetClass.getAnnotation(WsServerEndpoint.class);
            WebsocketServerEndpoint websocketServerEndpoint = new WebsocketServerEndpoint(targetClass,
                    beanFactory.getBean(targetClass), wsServerEndpoint.value());
            actionDispatch.addWebsocketServerEndpoint(websocketServerEndpoint);
        }
        NettyWebsocketServer websocketServer = new NettyWebsocketServer(actionDispatch, websocketProperties);
        // Start the WebSocket server
        websocketServer.start();
    }

    /**
     * Scan for beans annotated with @WsServerEndpoint
     * @param packagesToScan base package to scan
     * @param registry
     */
    private void scanWebsocketServiceBeans(String packagesToScan, BeanDefinitionRegistry registry) {
        ClassPathBeanDefinitionScanner scanner = new ClassPathBeanDefinitionScanner(registry);
        // Include classes annotated with @WsServerEndpoint
        scanner.addIncludeFilter(new AnnotationTypeFilter(WsServerEndpoint.class));
        scanner.scan(packagesToScan);
    }

    /**
     * Get the user class behind a possibly CGLIB-proxied type
     * @param clazz
     * @return
     */
    public Class<?> getTargetClass(Class<?> clazz) {
        // Fix: the original called AopUtils.isCglibProxy(clazz), which expects a bean
        // instance rather than a Class and therefore never matched; CGLIB proxy class
        // names contain "$$"
        if (clazz.getName().contains("$$")) {
            return clazz.getSuperclass();
        }
        return clazz;
    }
}
```

The author's code is simple and easy to follow.

First, the class implements this Spring interface:

```java
public interface SmartInitializingSingleton {
    void afterSingletonsInstantiated();
}
```

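Spring calls `afterSingletonsInstantiated()` once all non-lazy singleton beans have been instantiated. Here the callback looks up the package of the startup class and scans it for every class carrying this annotation: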

```java
@Retention(RetentionPolicy.RUNTIME)
@Target(ElementType.TYPE)
public @interface WsServerEndpoint {
    String value() default "/ws/{arg}";
}
```

Each matching class is then registered with the container as a bean definition.
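Conceptually, the `AnnotationTypeFilter` keeps only classes carrying `@WsServerEndpoint`, and the endpoint path later comes from `value()`. Spring's scanner finds candidates by reading class metadata from the classpath; the plain-Java sketch below skips that part, assumes a given list of candidate classes, and filters them with reflection. The nested annotation mirrors the starter's `@WsServerEndpoint`, and `ChatEndpoint`, `DefaultEndpoint` and `PlainBean` are hypothetical.

```java
import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

final class EndpointScan {

    // Mirrors the starter's annotation; RUNTIME retention is what makes it visible to reflection
    @Retention(RetentionPolicy.RUNTIME)
    @Target(ElementType.TYPE)
    @interface WsServerEndpoint {
        String value() default "/ws/{arg}";
    }

    // Hypothetical candidate classes
    @WsServerEndpoint("/chat/{room}")
    static class ChatEndpoint { }

    @WsServerEndpoint
    static class DefaultEndpoint { }

    static class PlainBean { }

    /** Keep only annotated classes, mapping each configured path to its class. */
    static Map<String, Class<?>> scan(List<Class<?>> candidates) {
        Map<String, Class<?>> result = new LinkedHashMap<>();
        for (Class<?> c : candidates) {
            WsServerEndpoint ann = c.getAnnotation(WsServerEndpoint.class);
            if (ann != null) {
                result.put(ann.value(), c);
            }
        }
        return result;
    }

    public static void main(String[] args) {
        System.out.println(scan(List.of(ChatEndpoint.class, DefaultEndpoint.class, PlainBean.class)));
    }
}
```

`DefaultEndpoint` ends up under `/ws/{arg}` because it relies on the annotation's default value, while `PlainBean` is filtered out.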

The scan:

```java
private void scanWebsocketServiceBeans(String packagesToScan, BeanDefinitionRegistry registry) {
    ClassPathBeanDefinitionScanner scanner = new ClassPathBeanDefinitionScanner(registry);
    // Include classes annotated with @WsServerEndpoint
    scanner.addIncludeFilter(new AnnotationTypeFilter(WsServerEndpoint.class));
    scanner.scan(packagesToScan);
}
```

The registration logic:

```java
@SneakyThrows
private void registerServerEndpoints() {
    String[] beanNamesForAnnotation = beanFactory.getBeanNamesForAnnotation(WsServerEndpoint.class);
    WebsocketActionDispatch actionDispatch = new WebsocketActionDispatch();
    for (String beanName : beanNamesForAnnotation) {
        Class<?> beanType = beanFactory.getType(beanName);
        Class<?> targetClass = getTargetClass(beanType);
        WsServerEndpoint wsServerEndpoint = targetClass.getAnnotation(WsServerEndpoint.class);
        WebsocketServerEndpoint websocketServerEndpoint = new WebsocketServerEndpoint(targetClass,
                beanFactory.getBean(targetClass), wsServerEndpoint.value());
        actionDispatch.addWebsocketServerEndpoint(websocketServerEndpoint);
    }
    NettyWebsocketServer websocketServer = new NettyWebsocketServer(actionDispatch, websocketProperties);
    // Start the WebSocket server
    websocketServer.start();
}
```

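This is the key part. It first collects the names of all beans annotated with `@WsServerEndpoint`, builds a WebsocketServerEndpoint for each one and adds it to a WebsocketActionDispatch, which appears to be a registry and dispatcher for the endpoints, and finally starts the Netty server: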

```java
public class NettyWebsocketServer {

    private final WebsocketActionDispatch websocketActionDispatch;

    private WebsocketProperties websocketProperties;

    public NettyWebsocketServer(WebsocketActionDispatch websocketActionDispatch, WebsocketProperties websocketProperties) {
        this.websocketActionDispatch = websocketActionDispatch;
        this.websocketProperties = websocketProperties;
    }

    public void start() throws InterruptedException {
        NioEventLoopGroup boss = new NioEventLoopGroup(websocketProperties.getBossThreadNums());
        NioEventLoopGroup worker = new NioEventLoopGroup(websocketProperties.getWorkerThreadNums());
        ServerBootstrap bootstrap = new ServerBootstrap();
        bootstrap.group(boss, worker)
                .channel(NioServerSocketChannel.class)
                .childHandler(new ChannelInitializer<NioSocketChannel>() {
                    @Override
                    protected void initChannel(NioSocketChannel channel) {
                        ChannelPipeline pipeline = channel.pipeline();
                        pipeline.addLast(new HttpServerCodec())
                                .addLast(new HttpObjectAggregator(websocketProperties.getMaxContentLength()))
                                .addLast(new IdleStateHandler(websocketProperties.getReaderIdleTimeSeconds(),
                                        websocketProperties.getWriterIdleTimeSeconds(),
                                        websocketProperties.getAllIdleTimeSeconds()))
                                .addLast(new HttpRequestHandler(websocketActionDispatch))
                                .addLast(new WebSocketFrameAggregator(Integer.MAX_VALUE))
                                .addLast(new GenericHandler(websocketActionDispatch))
                                .addLast(new WebSocketServerHandler(websocketActionDispatch));
                    }
                })
                // Connection timeout. Note: CONNECT_TIMEOUT_MILLIS is only used by outbound
                // connect() calls, so it has no effect on a server channel
                .option(ChannelOption.CONNECT_TIMEOUT_MILLIS, websocketProperties.getConnectTimeout())
                // Maximum length of the TCP accept queue
                .option(ChannelOption.SO_BACKLOG, websocketProperties.getBackLog())
                // Send messages immediately. Fix: the original set this with .option(), but
                // TCP_NODELAY only applies to the accepted child channels
                .childOption(ChannelOption.TCP_NODELAY, websocketProperties.isTcpNoDelay())
                // TCP keepalive: the OS periodically probes established connections
                .childOption(ChannelOption.SO_KEEPALIVE, Boolean.TRUE);
        ChannelFuture channelFuture = bootstrap.bind(websocketProperties.getPort()).sync();
        Channel serverChannel = channelFuture.channel();
        // Release both event loop groups once the server channel closes
        serverChannel.closeFuture().addListener(new ChannelFutureListener() {
            @Override
            public void operationComplete(ChannelFuture channelFuture) throws Exception {
                boss.shutdownGracefully();
                worker.shutdownGracefully();
            }
        });
    }
}
```

That's the source. Its fields are the properties object and a WebsocketActionDispatch, the same one we just saw passed in here:

```java
new NettyWebsocketServer(actionDispatch, websocketProperties);
```

Next, naturally, is how this dispatcher works, since all three custom handlers use it:

```java
public class WebsocketActionDispatch {

    private AntPathMatcher antPathMatcher = new AntPathMatcher();

    private final static Map<String, WebsocketServerEndpoint> endpointMap = new ConcurrentHashMap<>(16);

    /**
     * Check whether the request path matches a registered endpoint
     * @param uri
     * @return
     */
    protected boolean verifyUri(String uri) {
        return endpointMap.keySet().stream().anyMatch(e -> antPathMatcher.match(e, uri));
    }

    /**
     * Register a WebSocket endpoint
     * @param endpoint
     */
    public void addWebsocketServerEndpoint(WebsocketServerEndpoint endpoint) {
        endpointMap.putIfAbsent(endpoint.getPath(), endpoint);
    }

    /**
     * Find the WebSocket endpoint whose path pattern matches the uri
     * @param uri
     * @return
     */
    protected WebsocketServerEndpoint matchServerEndpoint(String uri) {
        for (Map.Entry<String, WebsocketServerEndpoint> entry : endpointMap.entrySet()) {
            if (antPathMatcher.match(entry.getKey(), uri)) {
                return entry.getValue();
            }
        }
        return null;
    }

    /**
     * Dispatch an event to the matching annotated method
     * @param uri
     * @param action
     */
    protected void dispatch(String uri, Action action, Channel channel) {
        WebsocketServerEndpoint websocketServerEndpoint = matchServerEndpoint(uri);
        if (Objects.nonNull(websocketServerEndpoint)) {
            Method method = null;
            Object obj = websocketServerEndpoint.getObject();
            switch (action) {
                case HAND_SHAKE:
                    method = websocketServerEndpoint.getOnHandShake();
                    break;
                case OPEN:
                    method = websocketServerEndpoint.getOnOpen();
                    break;
                case CLOSE:
                    method = websocketServerEndpoint.getOnClose();
                    break;
                case MESSAGE:
                    method = websocketServerEndpoint.getOnMessage();
                    break;
                case EVENT:
                    method = websocketServerEndpoint.getOnEvent();
                    break;
                case ERROR:
                    method = websocketServerEndpoint.getOnError();
                    break;
                default:
                    break;
            }
            if (Objects.nonNull(method)) {
                Object[] args = new MethodParamsBuild().getMethodArgumentValues(method, channel);
                ReflectionUtils.invokeMethod(method, obj, args);
            }
        }
    }

    public Map<String, String> getUriTemplateVariables(String lookupPath) {
        WebsocketServerEndpoint websocketServerEndpoint = matchServerEndpoint(lookupPath);
        return antPathMatcher.extractUriTemplateVariables(websocketServerEndpoint.getPath(), lookupPath);
    }

    enum Action {
        HAND_SHAKE, OPEN, CLOSE, MESSAGE, EVENT, ERROR
    }
}
```
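To make the dispatcher's behavior concrete, here is a plain-Java sketch of the same idea without Spring: endpoints are stored by path pattern with `putIfAbsent` (so the first registration of a pattern wins), a URI is matched against each pattern, and `{name}` segments are extracted as variables. It is only a stand-in under simplifying assumptions: the hypothetical `PathDispatch` supports whole-segment `{var}` placeholders only, not the `*`, `**` and `?` wildcards that Spring's `AntPathMatcher` handles, and handlers are plain `Consumer`s instead of reflectively invoked annotated methods.

```java
import java.util.Collections;
import java.util.LinkedHashMap;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.function.Consumer;

final class PathDispatch {

    enum Action { HAND_SHAKE, OPEN, CLOSE, MESSAGE, EVENT, ERROR }

    // pattern -> (action -> handler); handlers receive the extracted URI variables
    private final Map<String, Map<Action, Consumer<Map<String, String>>>> endpoints = new ConcurrentHashMap<>();

    /** Register an endpoint; as with putIfAbsent in the original, a duplicate pattern is ignored. */
    void addEndpoint(String pattern, Map<Action, Consumer<Map<String, String>>> handlers) {
        endpoints.putIfAbsent(pattern, handlers);
    }

    /** Returns the {name} -> value map if uri matches pattern segment by segment, else null. */
    static Map<String, String> match(String pattern, String uri) {
        String[] p = pattern.split("/", -1);
        String[] u = uri.split("/", -1);
        if (p.length != u.length) {
            return null;
        }
        Map<String, String> vars = new LinkedHashMap<>();
        for (int i = 0; i < p.length; i++) {
            boolean placeholder = p[i].startsWith("{") && p[i].endsWith("}");
            if (placeholder && !u[i].isEmpty()) {
                vars.put(p[i].substring(1, p[i].length() - 1), u[i]);
            } else if (!p[i].equals(u[i])) {
                return null;
            }
        }
        return vars;
    }

    /** Find the first matching endpoint and run its handler for the action, if it has one. */
    boolean dispatch(String uri, Action action) {
        for (Map.Entry<String, Map<Action, Consumer<Map<String, String>>>> e : endpoints.entrySet()) {
            Map<String, String> vars = match(e.getKey(), uri);
            if (vars != null) {
                Consumer<Map<String, String>> handler = e.getValue().get(action);
                if (handler != null) {
                    handler.accept(Collections.unmodifiableMap(vars));
                }
                return true;
            }
        }
        return false;
    }
}
```

For example, after registering `/ws/{arg}` with an `OPEN` handler, `dispatch("/ws/alice", Action.OPEN)` runs that handler with `arg` bound to `alice`, while `/chat/alice` matches nothing and returns `false`. Like the original `ConcurrentHashMap`, the iteration order over patterns is unspecified, so overlapping patterns are matched in no guaranteed order.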

Now let's see what WebsocketServerEndpoint itself does:

```java
@Getter
public class WebsocketServerEndpoint {

    /** Path configured in @WsServerEndpoint */
    private String path;

    /** Method invoked before the handshake */
    private Method onHandShake;

    /** Method invoked when the connection closes */
    private Method onClose;

    /** Method invoked on heartbeat (idle) events */
    private Method onEvent;

    /** Method invoked once the connection is established */
    private Method onOpen;

    /** Method invoked when a message arrives */
    private Method onMessage;

    /** Method invoked on errors */
    private Method onError;

    /** Instance of the @WsServerEndpoint class mapped to this path */
    private Object object;

    public WebsocketServerEndpoint(Class<?> pojoClazz, Object o, String path) {
        this.object = o;
        this.path = path;
        AtomicReference<Method> handShake = new AtomicReference<>();
        AtomicReference<Method> close = new AtomicReference<>();
        AtomicReference<Method> event = new AtomicReference<>();
        AtomicReference<Method> open = new AtomicReference<>();
        AtomicReference<Method> message = new AtomicReference<>();
        AtomicReference<Method> error = new AtomicReference<>();
        Method[] pojoClazzMethods = null;
        Class<?> currentClazz = pojoClazz;
        // Walk from the endpoint class up to Object, collecting annotated methods
        while (!currentClazz.equals(Object.class)) {
            Method[] currentClazzMethods = currentClazz.getDeclaredMethods();
            if (currentClazz == pojoClazz) {
                pojoClazzMethods = currentClazzMethods;
            }
            for (Method method : currentClazzMethods) {
                if (Objects.nonNull(method.getAnnotation(HandshakeBefore.class))) {
                    methodFill(currentClazz, method, pojoClazz, handShake, HandshakeBefore.class);
                } else if (Objects.nonNull(method.getAnnotation(OnClose.class))) {
                    methodFill(currentClazz, method, pojoClazz, close, OnClose.class);
                } else if (Objects.nonNull(method.getAnnotation(OnEvent.class))) {
                    methodFill(currentClazz, method, pojoClazz, event, OnEvent.class);
                } else if (Objects.nonNull(method.getAnnotation(OnOpen.class))) {
                    methodFill(currentClazz, method, pojoClazz, open, OnOpen.class);
                } else if (Objects.nonNull(method.getAnnotation(OnMessage.class))) {
                    methodFill(currentClazz, method, pojoClazz, message, OnMessage.class);
                } else if (Objects.nonNull(method.getAnnotation(OnError.class))) {
                    methodFill(currentClazz, method, pojoClazz, error, OnError.class);
                }
            }
            currentClazz = currentClazz.getSuperclass();
        }
        // The original assigned these inside the loop on every iteration;
        // assigning once after the loop yields the same values
        this.onHandShake = handShake.get();
        this.onClose = close.get();
        this.onEvent = event.get();
        this.onOpen = open.get();
        this.onMessage = message.get();
        this.onError = error.get();
    }

    private void methodFill(Class<?> currentClazz, Method method, Class<?> pojoClazz,
                            AtomicReference<Method> point, Class annotation) {
        checkPublic(method);
        if (Objects.isNull(point.get())) {
            point.set(method);
        } else {
            // A second annotated method is only tolerated when it is a superclass method
            // overridden by the one already registered
            if (currentClazz == pojoClazz || !isMethodOverride(point.get(), method)) {
                throw new WebsocketDeploymentException(
                        "wsServerEndpoint.duplicateAnnotation " + annotation.getSimpleName());
            }
        }
    }

    /**
     * Check that the method is public
     * @param m
     * @throws WebsocketDeploymentException
     */
    private void checkPublic(Method m) throws WebsocketDeploymentException {
        if (!Modifier.isPublic(m.getModifiers())) {
            throw new WebsocketDeploymentException(
                    "pojoMethodMapping.methodNotPublic " + m.getName());
        }
    }

    /**
     * Check whether one method overrides the other (same name, return type and parameter types)
     * @param method1
     * @param method2
     */
    private boolean isMethodOverride(Method method1, Method method2) {
        return (method1.getName().equals(method2.getName())
                && method1.getReturnType().equals(method2.getReturnType())
                && Arrays.equals(method1.getParameterTypes(), method2.getParameterTypes()));
    }
}
```
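The heart of this constructor, collecting one public method per lifecycle annotation and rejecting duplicates, can be reproduced with plain reflection. The sketch below is simplified: it only looks at the class's own declared methods, so it skips the superclass walk and override check the original performs. `OnOpen`, `OnMessage` and `EchoEndpoint` are hypothetical stand-ins for the starter's annotations and a user endpoint.

```java
import java.lang.annotation.Annotation;
import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;
import java.lang.reflect.Method;
import java.lang.reflect.Modifier;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

final class LifecycleMethods {

    @Retention(RetentionPolicy.RUNTIME)
    @Target(ElementType.METHOD)
    @interface OnOpen { }

    @Retention(RetentionPolicy.RUNTIME)
    @Target(ElementType.METHOD)
    @interface OnMessage { }

    // Hypothetical user endpoint
    static class EchoEndpoint {
        @OnOpen
        public String open() { return "opened"; }

        @OnMessage
        public String message(String text) { return "echo:" + text; }
    }

    /** Map each lifecycle annotation to the single public method that carries it. */
    static Map<Class<? extends Annotation>, Method> collect(Class<?> clazz,
                                                           List<Class<? extends Annotation>> annotations) {
        Map<Class<? extends Annotation>, Method> found = new HashMap<>();
        for (Method m : clazz.getDeclaredMethods()) {
            for (Class<? extends Annotation> a : annotations) {
                if (!m.isAnnotationPresent(a)) {
                    continue;
                }
                if (!Modifier.isPublic(m.getModifiers())) {
                    throw new IllegalStateException("method not public: " + m.getName());
                }
                if (found.putIfAbsent(a, m) != null) {
                    throw new IllegalStateException("duplicate @" + a.getSimpleName());
                }
            }
        }
        return found;
    }

    public static void main(String[] args) throws Exception {
        Map<Class<? extends Annotation>, Method> methods =
                collect(EchoEndpoint.class, List.of(OnOpen.class, OnMessage.class));
        // Invoke the @OnMessage method reflectively, much as dispatch() does via ReflectionUtils
        System.out.println(methods.get(OnMessage.class).invoke(new EchoEndpoint(), "hi"));
    }
}
```

Resolving the methods once at startup, as both the original and this sketch do, means each incoming event costs only a map lookup and a reflective call rather than a fresh annotation scan.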

This is easy to follow as well; the important part remains the dispatcher's event-dispatch method. Next come the handlers.


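The pipeline contains three custom handlers. The rest handle HTTP decoding and message aggregation, plus an IdleStateHandler, which does not close idle connections itself but fires idle events that GenericHandler passes on to the `@OnEvent` method, where the connection can be dropped.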
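`IdleStateHandler` is configured here with the reader, writer and all-idle seconds from `WebsocketProperties`. It fires an `IdleStateEvent` when a channel has seen no reads, no writes, or neither for the configured time, and a timeout of 0 disables that check. The sketch below models only this decision rule with timestamps; `IdleCheck` is hypothetical, and the real handler schedules timer tasks on the event loop instead of being polled.

```java
import java.util.EnumSet;
import java.util.Set;

final class IdleCheck {

    // Same constant names as Netty's IdleState, redefined so the sketch is self-contained
    enum IdleState { READER_IDLE, WRITER_IDLE, ALL_IDLE }

    /** Idle states that apply at nowMs, given the last read/write times (0 timeout = disabled). */
    static Set<IdleState> idleStates(long nowMs, long lastReadMs, long lastWriteMs,
                                     long readerIdleMs, long writerIdleMs, long allIdleMs) {
        Set<IdleState> states = EnumSet.noneOf(IdleState.class);
        if (readerIdleMs > 0 && nowMs - lastReadMs >= readerIdleMs) {
            states.add(IdleState.READER_IDLE);
        }
        if (writerIdleMs > 0 && nowMs - lastWriteMs >= writerIdleMs) {
            states.add(IdleState.WRITER_IDLE);
        }
        // ALL_IDLE: neither a read nor a write happened within allIdleMs
        if (allIdleMs > 0 && nowMs - Math.max(lastReadMs, lastWriteMs) >= allIdleMs) {
            states.add(IdleState.ALL_IDLE);
        }
        return states;
    }
}
```

With the 60-second defaults, a channel that last read at 0 ms and last wrote at 50,000 ms is only reader-idle at 90,000 ms; by 130,000 ms all three states apply.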

Let's go through how each one works, in pipeline order.

The first one a request reaches is HttpRequestHandler, which handles the protocol upgrade:

```java
@ChannelHandler.Sharable
public class HttpRequestHandler extends SimpleChannelInboundHandler<FullHttpRequest> {

    private final WebsocketActionDispatch websocketActionDispatch;

    public HttpRequestHandler(WebsocketActionDispatch websocketActionDispatch) {
        this.websocketActionDispatch = websocketActionDispatch;
    }

    @Override
    protected void channelRead0(ChannelHandlerContext ctx, FullHttpRequest request) {
        // Validate the request
        boolean pass = verifyRequest(request);
        if (!pass) {
            ctx.close();
            // Fix: the original fell through here and kept processing a rejected request;
            // for an unknown path, getUriTemplateVariables() then throws a NullPointerException
            return;
        }
        // Pass parameters on to the WebSocket handlers via channel attributes
        ctx.channel().attr(AttributeKeyConstant.fullHttpRequest).set(request);
        ctx.channel().attr(AttributeKeyConstant.PATH_KEY).set(request.uri());
        Map<String, String> uriTemplateVariables = websocketActionDispatch.getUriTemplateVariables(request.uri());
        ctx.channel().attr(AttributeKeyConstant.uriTemplateVariables).set(uriTemplateVariables);
        websocketActionDispatch.dispatch(request.uri(), WebsocketActionDispatch.Action.HAND_SHAKE, ctx.channel());
        WebSocketServerHandshakerFactory wsFactory =
                new WebSocketServerHandshakerFactory(getWebSocketLocation(request), null, true, 65536);
        WebSocketServerHandshaker handshaker = wsFactory.newHandshaker(request);
        if (handshaker == null) {
            WebSocketServerHandshakerFactory.sendUnsupportedVersionResponse(ctx.channel());
        } else {
            handshaker.handshake(ctx.channel(), request).addListener(future -> {
                if (future.isSuccess()) {
                    websocketActionDispatch.dispatch(request.uri(), WebsocketActionDispatch.Action.OPEN, ctx.channel());
                } else {
                    handshaker.close(ctx.channel(), new CloseWebSocketFrame());
                }
            });
        }
    }

    private static String getWebSocketLocation(FullHttpRequest req) {
        String location = req.headers().get(HttpHeaderNames.HOST) + req.uri();
        return "ws://" + location;
    }

    /**
     * Check that the request is an HTTP-to-WebSocket upgrade
     * and that the uri is valid
     * @param request
     * @return
     */
    private boolean verifyRequest(FullHttpRequest request) {
        HttpHeaders headers = request.headers();
        String connection = headers.get("Connection");
        String upgrade = headers.get("Upgrade");
        String host = headers.get("Host");
        if (Objects.isNull(connection) || Objects.isNull(upgrade) || Objects.isNull(host)) {
            return false;
        }
        // Fix: Connection is a comma-separated token list (e.g. "keep-alive, Upgrade"),
        // so look for the Upgrade token instead of comparing the whole value
        boolean connectionUpgrade = Arrays.stream(connection.split(","))
                .map(String::trim)
                .anyMatch("Upgrade"::equalsIgnoreCase);
        if (!connectionUpgrade || !"websocket".equalsIgnoreCase(upgrade)) {
            return false;
        }
        if (!"GET".equalsIgnoreCase(request.method().name())) {
            return false;
        }
        return websocketActionDispatch.verifyUri(request.uri());
    }
}
```

Its main method, channelRead0, does the following:

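  • Validates the request, checking both the headers and the request path.

  • Stores the request, its path and the URI template variables as Channel attributes so that later handlers can use them.

  • Dispatches the handshake event through WebsocketActionDispatch.

  • Creates a WebSocketServerHandshaker and performs the handshake.

  • If the handshake succeeds, dispatches the open event; if it fails, closes the connection.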
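The handshake step itself is delegated to Netty's `WebSocketServerHandshaker`. At its core, RFC 6455 requires the server to answer with a `Sec-WebSocket-Accept` header derived from the client's `Sec-WebSocket-Key`: the SHA-1 hash of the key concatenated with a fixed GUID, Base64-encoded. A minimal sketch of just that computation (the `HandshakeAccept` class is hypothetical; Netty also validates the version and writes the full 101 response):

```java
import java.nio.charset.StandardCharsets;
import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException;
import java.util.Base64;

final class HandshakeAccept {

    // Fixed GUID defined by RFC 6455
    static final String WS_GUID = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";

    /** Sec-WebSocket-Accept = base64(SHA-1(Sec-WebSocket-Key + GUID)) */
    static String accept(String secWebSocketKey) {
        try {
            MessageDigest sha1 = MessageDigest.getInstance("SHA-1");
            byte[] digest = sha1.digest((secWebSocketKey.trim() + WS_GUID)
                    .getBytes(StandardCharsets.US_ASCII));
            return Base64.getEncoder().encodeToString(digest);
        } catch (NoSuchAlgorithmException e) {
            // Every Java platform is required to support SHA-1
            throw new IllegalStateException(e);
        }
    }

    public static void main(String[] args) {
        // Sample key from RFC 6455
        System.out.println(accept("dGhlIHNhbXBsZSBub25jZQ=="));
    }
}
```

With the RFC's sample key `dGhlIHNhbXBsZSBub25jZQ==`, this returns `s3pPLMBiTxaQ9kBWw4x1HBZGmRo=`, the value given in the RFC.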

The verifyRequest method:

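  • Checks that the Connection, Upgrade and Host headers are present and that they request an upgrade to websocket.

  • Checks that the request method is GET.

  • Checks that the request path matches a registered endpoint.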
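The same checks can be expressed without Netty. In this sketch, the hypothetical `UpgradeCheck` takes the headers as a case-insensitive map and the endpoint lookup as a predicate. It differs from the author's verifyRequest in one deliberate way: it treats `Connection` as a comma-separated token list, because RFC 6455 only requires that header to *include* the `Upgrade` token, and some browsers (Firefox, for example) send `Connection: keep-alive, Upgrade`, which an exact `equalsIgnoreCase` comparison rejects.

```java
import java.util.Arrays;
import java.util.Map;
import java.util.TreeMap;
import java.util.function.Predicate;

final class UpgradeCheck {

    /** Case-insensitive header map, since HTTP header names are case-insensitive. */
    static Map<String, String> headers(String... nameValuePairs) {
        Map<String, String> h = new TreeMap<>(String.CASE_INSENSITIVE_ORDER);
        for (int i = 0; i + 1 < nameValuePairs.length; i += 2) {
            h.put(nameValuePairs[i], nameValuePairs[i + 1]);
        }
        return h;
    }

    /** True if the comma-separated header value contains the token (case-insensitive). */
    static boolean hasToken(String headerValue, String token) {
        return headerValue != null && Arrays.stream(headerValue.split(","))
                .map(String::trim)
                .anyMatch(token::equalsIgnoreCase);
    }

    /** GET, Host present, Connection includes Upgrade, Upgrade is websocket, and the URI is registered. */
    static boolean isWebSocketUpgrade(String method, String uri, Map<String, String> headers,
                                      Predicate<String> uriIsRegistered) {
        if (!"GET".equalsIgnoreCase(method) || headers.get("Host") == null) {
            return false;
        }
        if (!hasToken(headers.get("Connection"), "Upgrade")
                || !"websocket".equalsIgnoreCase(headers.get("Upgrade"))) {
            return false;
        }
        return uriIsRegistered.test(uri);
    }
}
```

Passing the URI check in as a predicate keeps the header logic independent of how endpoints are matched; in the starter that role is played by `websocketActionDispatch.verifyUri`.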

For our purposes, it's enough to know that this is the ChannelHandler responsible for the protocol upgrade.

Next up is GenericHandler:

```java
public class GenericHandler extends ChannelInboundHandlerAdapter {

    private final WebsocketActionDispatch websocketActionDispatch;

    public GenericHandler(WebsocketActionDispatch websocketActionDispatch) {
        this.websocketActionDispatch = websocketActionDispatch;
    }

    @Override
    public void channelInactive(ChannelHandlerContext ctx) throws Exception {
        super.channelInactive(ctx);
    }

    @Override
    public void userEventTriggered(ChannelHandlerContext ctx, Object evt) throws Exception {
        String uri = ctx.channel().attr(AttributeKeyConstant.PATH_KEY).get();
        ctx.channel().attr(AttributeKeyConstant.idleStateEvent).set(evt);
        websocketActionDispatch.dispatch(uri, WebsocketActionDispatch.Action.EVENT, ctx.channel());
        super.userEventTriggered(ctx, evt);
    }

    @Override
    public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
        String uri = ctx.channel().attr(AttributeKeyConstant.PATH_KEY).get();
        ctx.channel().attr(AttributeKeyConstant.throwable).set(cause);
        websocketActionDispatch.dispatch(uri, WebsocketActionDispatch.Action.ERROR, ctx.channel());
        super.exceptionCaught(ctx, cause);
    }
}
```

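This handler deals with channel-level events: user events such as the IdleStateEvent are dispatched to `@OnEvent`, and exceptions to `@OnError`. On to the next one: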

```java
@ChannelHandler.Sharable
public class WebSocketServerHandler extends SimpleChannelInboundHandler<WebSocketFrame> {

    private final WebsocketActionDispatch websocketActionDispatch;

    public WebSocketServerHandler(WebsocketActionDispatch websocketActionDispatch) {
        this.websocketActionDispatch = websocketActionDispatch;
    }

    @Override
    protected void channelRead0(ChannelHandlerContext ctx, WebSocketFrame msg) throws Exception {
        handleWebSocketFrame(ctx, msg);
    }

    private void handleWebSocketFrame(ChannelHandlerContext ctx, WebSocketFrame frame) {
        // The uri passed along in the channel attributes, used by the dispatcher
        String uri = ctx.channel().attr(AttributeKeyConstant.PATH_KEY).get();
        if (frame instanceof TextWebSocketFrame) {
            ctx.channel().attr(AttributeKeyConstant.textWebSocketFrame).set((TextWebSocketFrame) frame);
            websocketActionDispatch.dispatch(uri, WebsocketActionDispatch.Action.MESSAGE, ctx.channel());
            return;
        }
        if (frame instanceof PingWebSocketFrame) {
            ctx.writeAndFlush(new PongWebSocketFrame(frame.content().retain()));
            return;
        }
        if (frame instanceof CloseWebSocketFrame) {
            websocketActionDispatch.dispatch(uri, WebsocketActionDispatch.Action.CLOSE, ctx.channel());
            ctx.writeAndFlush(frame.retainedDuplicate()).addListener(ChannelFutureListener.CLOSE);
            return;
        }
        if (frame instanceof BinaryWebSocketFrame) {
            return;
        }
        if (frame instanceof PongWebSocketFrame) {
            return;
        }
    }
}
```

This code is easy to follow too, assuming some background in Netty and Java.
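For reference, each Netty frame class corresponds to an RFC 6455 opcode, and the sketch below summarizes how WebSocketServerHandler reacts to each one. The `FrameRouting` class and its `Reaction` names are hypothetical labels for the branches of the handler, not part of Netty or the starter. Continuation frames (opcode 0x0) never reach this handler, because `WebSocketFrameAggregator` sits in front of it and merges fragmented messages into a single frame.

```java
final class FrameRouting {

    /** What WebSocketServerHandler does with each frame type. */
    enum Reaction { DISPATCH_MESSAGE, REPLY_PONG, DISPATCH_CLOSE_AND_ECHO_CLOSE, IGNORE }

    /** Map an RFC 6455 opcode to the handler's reaction. */
    static Reaction react(int opcode) {
        switch (opcode) {
            case 0x1: // TextWebSocketFrame: stored as an attribute, then @OnMessage
                return Reaction.DISPATCH_MESSAGE;
            case 0x9: // PingWebSocketFrame: answered with a Pong carrying the same payload
                return Reaction.REPLY_PONG;
            case 0x8: // CloseWebSocketFrame: @OnClose, then the close frame is echoed and the channel closed
                return Reaction.DISPATCH_CLOSE_AND_ECHO_CLOSE;
            case 0x2: // BinaryWebSocketFrame: silently dropped by this starter
            case 0xA: // PongWebSocketFrame: dropped
                return Reaction.IGNORE;
            default:
                throw new IllegalArgumentException("unexpected opcode: 0x" + Integer.toHexString(opcode));
        }
    }
}
```

One consequence worth noticing: binary messages are dropped without reaching any annotated method, so an endpoint built on this starter can only receive text.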

Author: yowayimono


License: unless otherwise stated, all posts on this blog are licensed under BY-NC-SA. Please credit the source when reposting!