Spring环境下WebSocket的2种实现

0 运行环境及主要工具类简要说明

  • 基于Spring Boot环境

  • 由于在@ServerEndpoint的实现方式中,端点实例由WebSocket容器而不是Spring创建,无法直接使用@Autowired之类的注解注入业务相关的Bean,所以实现了一个通过ApplicationContext获取Bean的工具类,具体代码如下:

    package cn.xue.common;
    import org.springframework.beans.BeansException;
    import org.springframework.context.ApplicationContext;
    
    /**
     * Static access point to the Spring {@link ApplicationContext} for objects that Spring does not
     * create itself (for example JSR-356 endpoint configurators), where {@code @Autowired} cannot be used.
     *
     * <p>The context is handed in once through {@link #setApplicationContext(ApplicationContext)};
     * any later call is ignored.
     */
    public final class SpringContextUtil {

        // volatile: the context is written once by the startup thread but may be read from
        // web-container threads that were already running before that write happened.
        private static volatile ApplicationContext applicationContext;

        private SpringContextUtil() {
            // static utility, not meant to be instantiated
        }

        /**
         * Stores the application context; only the first non-null call has an effect.
         *
         * @param applicationContext the context returned by {@code SpringApplication.run}
         */
        public static void setApplicationContext(ApplicationContext applicationContext) throws BeansException {
            if (SpringContextUtil.applicationContext == null) {
                SpringContextUtil.applicationContext = applicationContext;
            }
        }

        /**
         * Returns the stored application context.
         *
         * @return the context, or {@code null} if it has not been set yet
         */
        public static ApplicationContext getApplicationContext() {
            return applicationContext;
        }

        /**
         * Looks up a bean by name.
         *
         * @throws IllegalStateException if the context has not been set yet
         */
        public static Object getBean(String name) {
            return requireContext().getBean(name);
        }

        /**
         * Looks up a bean by type.
         *
         * @throws IllegalStateException if the context has not been set yet
         */
        public static <T> T getBean(Class<T> clazz) {
            return requireContext().getBean(clazz);
        }

        /**
         * Looks up a bean by name and checks it against the required type.
         *
         * @throws IllegalStateException if the context has not been set yet
         */
        public static <T> T getBean(String name, Class<T> clazz) {
            return requireContext().getBean(name, clazz);
        }

        /** Returns the context or fails with a descriptive message instead of a bare NPE. */
        private static ApplicationContext requireContext() {
            ApplicationContext context = applicationContext;
            if (context == null) {
                throw new IllegalStateException(
                        "ApplicationContext has not been set; call SpringContextUtil.setApplicationContext first");
            }
            return context;
        }
    }
    

    并在主函数中将SpringApplication.run返回的ApplicationContext传给上述工具类,代码如下:

    /**
     * Spring Boot entry point. After the application has started, the resulting context is handed
     * to {@code SpringContextUtil} so non-Spring-managed code can look up beans.
     */
    @SpringBootApplication
    public class MainApplication {

        // NOTE(review): the context is only handed over after run() returns; a WebSocket handshake
        // processed while startup is still in progress would not find it yet — confirm this is acceptable.
        public static void main(String[] args) {
            SpringContextUtil.setApplicationContext(SpringApplication.run(MainApplication.class, args));
        }
    }
    
    

1 基于@ServerEndpoint注解的实现

  • 需要在@Configuration注解的类中声明一个ServerEndpointExporter类型的Bean,由它把@ServerEndpoint注解的类注册到WebSocket容器中:

      /**
       * Declares the {@code ServerEndpointExporter} bean, which registers {@code @ServerEndpoint}
       * annotated beans with the servlet container's standard WebSocket runtime.
       *
       * @return a new exporter instance
       */
      @Bean
      public ServerEndpointExporter serverEndpointExporter() {
          ServerEndpointExporter exporter = new ServerEndpointExporter();
          return exporter;
      }
    
  • Configurator类实现

    package cn.xue.config;
    
    import cn.xue.common.Constants;
    import cn.xue.common.SpringContextUtil;
    import cn.xue.model.user.User;
    import cn.xue.service.user.UserService;
    
    import javax.servlet.annotation.WebListener;
    import javax.websocket.HandshakeResponse;
    import javax.websocket.server.HandshakeRequest;
    import javax.websocket.server.ServerEndpointConfig;
    
    /**
     * Handshake hook for the JSR-356 chat endpoint: authenticates the connecting user from the
     * {@code userId} and {@code password} query parameters and, on success, stores the {@code User}
     * under {@code Constants.WEBSOCKET_USER} in the endpoint config's user properties.
     *
     * <p>On any failure (missing/malformed parameters, unknown user, wrong password) the property is
     * left unset, so the endpoint's {@code @OnOpen} handler can detect and close the connection.
     *
     * <p>Not annotated with {@code @WebListener}: this class is not a servlet listener.
     */
    public class WebSocketSessionConfigurator extends ServerEndpointConfig.Configurator {

        @Override
        public void modifyHandshake(ServerEndpointConfig config, HandshakeRequest request, HandshakeResponse response) {
            // Clear first so a failed handshake can never see a user left behind in these properties
            // (relevant if the container shares the config object between handshakes).
            config.getUserProperties().remove(Constants.WEBSOCKET_USER);

            Long userId = parseUserId(firstParameter(request, "userId"));
            String password = firstParameter(request, "password");
            if (userId == null || password == null) {
                return;
            }
            UserService userService = SpringContextUtil.getBean(UserService.class);
            User user = userService.getUserById(userId);
            // NOTE(review): plaintext comparison of a password sent in the query string; consider a
            // token plus hashed-password check instead.
            if (user != null && password.equals(user.getPassword())) {
                config.getUserProperties().put(Constants.WEBSOCKET_USER, user);
            }
        }

        /** Returns the first value of a query parameter, or {@code null} when it is absent. */
        private static String firstParameter(HandshakeRequest request, String name) {
            java.util.List<String> values = request.getParameterMap().get(name);
            return (values == null || values.isEmpty()) ? null : values.get(0);
        }

        /** Parses the user id, returning {@code null} for missing or non-numeric input. */
        private static Long parseUserId(String raw) {
            if (raw == null) {
                return null;
            }
            try {
                return Long.valueOf(raw);
            } catch (NumberFormatException e) {
                return null;
            }
        }
    }
    
    

    简述:上述代码的主要作用是在握手阶段截取WebSocket连接请求中的参数(userId和password),进行简单的身份校验;如果校验通过,就把用户信息存放到ServerEndpointConfig的UserProperties中,这样连接建立后即可在@OnOpen中通过EndpointConfig的UserProperties获取到用户信息

  • 具体实现,代码如下:

    package cn.xue.controller.socket;
    
    import cn.xue.common.Constants;
    import cn.xue.common.SpringContextUtil;
    import cn.xue.config.WebSocketSessionConfigurator;
    import cn.xue.model.socket.Message;
    import cn.xue.model.user.User;
    import cn.xue.service.user.UserService;
    import com.alibaba.fastjson.JSONObject;
    import org.apache.log4j.Logger;
    import org.springframework.stereotype.Component;
    
    import javax.websocket.*;
    import javax.websocket.server.ServerEndpoint;
    import java.io.IOException;
    import java.util.Map;
    import java.util.concurrent.ConcurrentHashMap;
    @Component
    @ServerEndpoint(value = "/ChatEndPoint", configurator = WebSocketSessionConfigurator.class)
    public class WebSocketHandler {
        private final static Logger logger = Logger.getLogger(WebSocketHandler.class);
        @Autowired
        private RedisService redisService;
    
        private static final Map<Long, Session> users;
        static {
           users = new ConcurrentHashMap<Long, Session>();
        }
        @OnOpen
        public void onOpen(EndpointConfig conf, Session session) {
            User user = (User) conf.getUserProperties().get(Constants.WEBSOCKET_USER);
            if(user != null) {
                logger.info(user.getNickName() + "连接成功");
                users.put(user.getId(), session);
            } else {
                try {
                    session.close();
                } catch (IOException e) {
                    e.printStackTrace();
                }
            }
        }
        @OnMessage
        public void handleMessage(Session session, String message) {
            sendMessage(JSONObject.parseObject(message, Message.class));
        }
    
        public static boolean sendMessage(Message message) {
            Session session = users.get(message.getDesUserId());
            if (session == null) {
                return false;
            }
            session.getAsyncRemote().sendText(JSONObject.toJSONString(message), new SendHandler() {
                @Override
                public void onResult(SendResult result) {
                    if (!result.isOK()) {
                        users.remove(message.getDesUserId());
                        try {
                            session.close();
                        } catch (IOException e) {
                        }
                    }
                }
            });
            return true;
        }
    
        @OnError
        public void error(Session session, java.lang.Throwable throwable) {
            if (throwable.getMessage() != null) {
                onClose(session);
            }
        }
        @OnClose
        public void onClose(Session session) {
            try {
                session.close();
                User user = (User) session.getUserProperties().get(Constants.WEBSOCKET_USER);
                users.remove(user.getId());
            } catch (IOException e) {
                e.printStackTrace();
            }
        }
    
    }
    
    

    代码内容简述:

    • 请注意@ServerEndpoint的configurator属性:发起WebSocket连接请求时,握手请求会先经过WebSocketSessionConfigurator的modifyHandshake方法处理,并把用户身份校验的结果放入UserProperties中
    • 在本类的@OnOpen方法中,根据上一步的处理结果检查能否拿到用户信息;如果拿不到,说明连接未通过校验,直接关闭连接
    • 其他像@OnMessage @OnClose等具体功能这里就略过
    • 另外代码中Message是一个自定义实体

2 基于Spring WebSocket的实现

  • 定义WebSocket握手拦截器,用以进行用户连接过程中的身份校验,代码如下:

    package cn.xue.interceptor.socket;
    
    import cn.xue.common.Constants;
    import cn.xue.model.user.User;
    import cn.xue.service.user.UserService;
    import org.apache.log4j.Logger;
    import org.springframework.beans.factory.annotation.Autowired;
    import org.springframework.http.server.ServerHttpRequest;
    import org.springframework.http.server.ServerHttpResponse;
    import org.springframework.http.server.ServletServerHttpRequest;
    import org.springframework.stereotype.Component;
    import org.springframework.web.socket.WebSocketHandler;
    import org.springframework.web.socket.server.HandshakeInterceptor;
    
    import javax.servlet.http.HttpServletRequest;
    import javax.servlet.http.HttpSession;
    import java.util.Map;
    
    /**
     * Spring WebSocket handshake interceptor that authenticates the connecting user from the
     * {@code userId} and {@code password} request parameters.
     *
     * <p>On success the {@code User} is put into the WebSocket session attributes under
     * {@code Constants.WEBSOCKET_USER}; on any failure (non-servlet request, missing or malformed
     * parameters, unknown user, wrong password) the handshake is rejected by returning {@code false}.
     */
    @Component
    public class WebSocketInterceptor implements HandshakeInterceptor {
        private final Logger logger = Logger.getLogger(getClass());
        @Autowired
        private UserService userService;

        @Override
        public boolean beforeHandshake(ServerHttpRequest request, ServerHttpResponse serverHttpResponse, WebSocketHandler webSocketHandler, Map<String, Object> attributes) throws Exception {
            if (!(request instanceof ServletServerHttpRequest)) {
                return false;
            }
            HttpServletRequest httpServletRequest = ((ServletServerHttpRequest) request).getServletRequest();
            Long userId = parseUserId(httpServletRequest.getParameter("userId"));
            String password = httpServletRequest.getParameter("password");
            if (userId == null || password == null) {
                return false;
            }
            User user = userService.getUserById(userId);
            // NOTE(review): plaintext comparison of a password sent in the query string (which tends to
            // end up in access logs); consider a token plus hashed-password check instead.
            if (user == null || !password.equals(user.getPassword())) {
                return false;
            }
            attributes.put(Constants.WEBSOCKET_USER, user);
            logger.info(user.getNickName() + "连接并验证成功");
            return true;
        }

        @Override
        public void afterHandshake(ServerHttpRequest serverHttpRequest, ServerHttpResponse serverHttpResponse, WebSocketHandler webSocketHandler, Exception e) {
            // nothing to do after the handshake
        }

        /** Parses the user id, returning {@code null} for missing or non-numeric input. */
        private static Long parseUserId(String raw) {
            if (raw == null) {
                return null;
            }
            try {
                return Long.valueOf(raw);
            } catch (NumberFormatException e) {
                return null;
            }
        }
    }
    
    

    代码简述:上述代码在beforeHandshake方法中截取WebSocket连接请求的参数并进行身份校验;如果校验失败则返回false,握手过程无法继续,连接也就建立失败

  • 定义Socket处理类

    package cn.xue.service.socket;
    
    import cn.xue.common.Constants;
    import cn.xue.model.socket.Message;
    import cn.xue.model.user.User;
    import cn.xue.service.redis.RedisService;
    import com.alibaba.fastjson.JSONObject;
    import com.gexin.fastjson.JSON;
    import org.apache.log4j.Logger;
    import org.springframework.beans.factory.annotation.Autowired;
    import org.springframework.stereotype.Service;
    import org.springframework.web.socket.*;
    
    import java.io.IOException;
    import java.nio.ByteBuffer;
    import java.nio.CharBuffer;
    import java.nio.charset.Charset;
    import java.nio.charset.CharsetDecoder;
    import java.util.Map;
    import java.util.Set;
    import java.util.concurrent.ConcurrentHashMap;
    
    
    /**
     * Spring WebSocket handler for the chat endpoint.
     *
     * <p>Tracks online users (user id -> session) using the {@code User} that
     * {@code WebSocketInterceptor} stored in the session attributes, forwards incoming JSON
     * {@code Message}s to their receiver, and supports broadcasting.
     *
     * <p>{@code WebSocketSession.sendMessage} must not be called concurrently on one session, so all
     * sends go through {@link #send}, which serializes them per session.
     */
    @Service
    public class SocketService implements WebSocketHandler {
        private static final Logger logger = Logger.getLogger(SocketService.class);

        /** Online users: user id -> that user's current session. */
        private static final Map<Long, WebSocketSession> users = new ConcurrentHashMap<>();

        /** Registers the authenticated session and greets the client; closes sessions without a user. */
        @Override
        public void afterConnectionEstablished(WebSocketSession session) throws Exception {
            User user = (User) session.getAttributes().get(Constants.WEBSOCKET_USER);
            if (user == null) {
                // Not expected behind the interceptor, but never keep an untracked connection open.
                session.close(CloseStatus.POLICY_VIOLATION);
                return;
            }
            users.put(user.getId(), session);
            send(session, new TextMessage("成功建立socket连接"));
            logger.info(user.getNickName() + "成功连接!");
            logger.info("当前在线人数:" + users.size());
        }

        /**
         * Receives a frame: text frames are parsed as JSON {@code Message}s and forwarded to the
         * receiver; binary frames are decoded as UTF-8 and logged. Errors are logged and not rethrown,
         * so a single bad frame does not propagate out of the handler.
         */
        @Override
        public void handleMessage(WebSocketSession webSocketSession, WebSocketMessage<?> webSocketMessage) throws Exception {
            try {
                if (webSocketMessage instanceof TextMessage) {
                    String payload = ((TextMessage) webSocketMessage).getPayload();
                    Message message = JSONObject.parseObject(payload, Message.class);
                    if (message != null) {
                        sendMessageToUser(message);
                    }
                } else if (webSocketMessage instanceof BinaryMessage) {
                    CharBuffer text = Charset.forName("UTF-8").newDecoder()
                            .decode(((BinaryMessage) webSocketMessage).getPayload());
                    logger.info(text.toString());
                }
            } catch (Exception e) {
                logger.error("处理socket消息失败", e);
            }
        }

        /**
         * Sends a message to its receiver ({@code message.getDesUserId()}).
         *
         * @param message the message to deliver; may be {@code null}
         * @return {@code true} if the message was sent; {@code false} if there is no message/receiver,
         *         the receiver is offline, or sending failed
         */
        public boolean sendMessageToUser(Message message) {
            if (message == null) {
                return false;
            }
            // Route by the receiver, not the sender.
            Long receiverId = message.getDesUserId();
            if (receiverId == null) {
                return false;
            }
            WebSocketSession session = users.get(receiverId);
            if (session == null || !session.isOpen()) {
                // receiver is not online
                return false;
            }
            try {
                send(session, new TextMessage(JSONObject.toJSONString(message)));
                return true;
            } catch (IOException e) {
                logger.error("向用户" + receiverId + "发送消息失败", e);
                return false;
            }
        }

        /**
         * Broadcasts a message to every open session.
         *
         * @param message the message to broadcast
         * @return {@code true} if no send failed with an {@link IOException}
         */
        public boolean sendMessageToAllUsers(TextMessage message) {
            boolean allSendSuccess = true;
            // Iterate the values directly: a keySet() + get() pair can observe a session removed in between.
            for (WebSocketSession session : users.values()) {
                if (!session.isOpen()) {
                    continue;
                }
                try {
                    send(session, message);
                } catch (IOException e) {
                    logger.error("广播消息失败", e);
                    allSendSuccess = false;
                }
            }
            return allSendSuccess;
        }

        /** Logs the transport error, unregisters the session and closes it if still open. */
        @Override
        public void handleTransportError(WebSocketSession session, Throwable exception) throws Exception {
            logger.error("连接出错", exception);
            untrack(session);
            if (session.isOpen()) {
                session.close();
            }
        }

        /** Unregisters the closed session. */
        @Override
        public void afterConnectionClosed(WebSocketSession session, CloseStatus status) throws Exception {
            logger.info("连接已关闭:" + status);
            untrack(session);
        }

        @Override
        public boolean supportsPartialMessages() {
            return false;
        }

        /** Sends one message, serializing concurrent sends on the same session. */
        private static void send(WebSocketSession session, WebSocketMessage<?> message) throws IOException {
            synchronized (session) {
                session.sendMessage(message);
            }
        }

        /**
         * Removes the session's user mapping, but only if it still maps to this very session, so closing
         * an old connection does not evict a newer one of the same user. Null ids are skipped because
         * {@link ConcurrentHashMap} rejects null keys.
         */
        private void untrack(WebSocketSession session) {
            Long userId = getUserId(session);
            if (userId != null) {
                users.remove(userId, session);
            }
        }

        /**
         * Returns the id of the user stored in the session attributes.
         *
         * @param session the WebSocket session
         * @return the user id, or {@code null} if no user is attached or the lookup fails
         */
        private Long getUserId(WebSocketSession session) {
            try {
                User user = (User) session.getAttributes().get(Constants.WEBSOCKET_USER);
                return user != null ? user.getId() : null;
            } catch (Exception e) {
                return null;
            }
        }
    }
    

    代码简述:

    • 上述实现类直接实现了spring-websocket提供的WebSocketHandler接口
    • afterConnectionEstablished方法的作用是:在拦截器完成身份校验后,把用户与其使用的连接存入在线用户缓存,其功能与@ServerEndpoint实现类中的onOpen()基本一致
    • handleMessage方法用于接收用户发来的消息,其功能与@ServerEndpoint实现类中由@OnMessage注解的handleMessage()基本一致
  • WebSocket配置,代码如下:

    package cn.xue.config;
    
    import cn.xue.interceptor.socket.WebSocketInterceptor;
    import cn.xue.service.socket.SocketService;
    import org.springframework.beans.factory.annotation.Autowired;
    import org.springframework.context.annotation.Configuration;
    import org.springframework.web.socket.config.annotation.*;
    
    import javax.annotation.Resource;
    
    /**
     * Spring WebSocket configuration: exposes {@code SocketService} at {@code /ChatServer} and runs
     * {@code WebSocketInterceptor} before every handshake on that path.
     */
    @Configuration
    @EnableWebSocket
    public class WebSocketConfig implements WebSocketConfigurer {
        @Autowired
        private WebSocketInterceptor webSocketInterceptor;
        @Resource
        private SocketService socketService;

        @Override
        public void registerWebSocketHandlers(WebSocketHandlerRegistry registry) {
            WebSocketHandlerRegistration registration = registry.addHandler(socketService, "/ChatServer");
            // NOTE(review): "*" accepts handshakes from any origin; restrict to known origins in production.
            registration.setAllowedOrigins("*");
            registration.addInterceptors(webSocketInterceptor);
        }
    }
    

    代码简述:上述代码实现了spring-websocket中的WebSocketConfigurer接口,在registerWebSocketHandlers方法中为指定的WebSocket连接路径添加处理器和拦截器

3 总结

综合以上内容,可以认为两种实现方式可以相互对照、相互替代,实现思想也基本一致。例如:

  • @ServerEndpoint的作用可以类比于WebSocketConfigurer,二者都为WebSocket连接路径指定了处理器和前置处理器
  • Spring WebSocket中拦截器(HandshakeInterceptor)的作用可以类比于@ServerEndpoint实现方式中的ServerEndpointConfig.Configurator
  • 两种实现方式中的处理器就更不用多说了,形式不同,本质相同。
最后编辑于
©著作权归作者所有,转载或内容合作请联系作者
【社区内容提示】社区部分内容疑似由AI辅助生成,浏览时请结合常识与多方信息审慎甄别。
平台声明:文章内容(如有图片或视频亦包括在内)由作者上传并发布,文章内容仅代表作者本人观点,简书系信息发布平台,仅提供信息存储服务。

相关阅读更多精彩内容

友情链接更多精彩内容