Bootstrap

websocket前后端长连接之java部分

一共有4个类,第一个WebSocketConfig 配置类

@Configuration
@EnableWebSocket
public class WebSocketConfig implements WebSocketConfigurer {

    @Autowired
    private WebSocketHandler webSocketHandler;

    @Autowired
    private WebSocketInterceptor webSocketInterceptor;

    @Override
    public void registerWebSocketHandlers(WebSocketHandlerRegistry registry) {
        registry.addHandler(webSocketHandler, "/ws")
                .addInterceptors(webSocketInterceptor)
                .setAllowedOrigins("*");
    }
}

第二个,拦截器,这里我区分了pc和app,因为代码需求是同一个id登录的用户要在pc端和app端同时连接websocket,为做区分,在pc的userid后面加了pc两个字母.

@Component
public class WebSocketInterceptor implements HandshakeInterceptor {
    private final Logger logger = LoggerFactory.getLogger(WebSocketInterceptor.class);

    @Resource
    private ISysUserService userService;
    /**
     * 握手前
     * @param request    请求对象
     * @param response   响应对象
     * @param wsHandler  请求处理器
     * @param attributes 属性域
     * @return true放行,false拒绝
     * @throws Exception 可能抛出的异常
     */
    @Override
    public boolean beforeHandshake(ServerHttpRequest request, ServerHttpResponse response, org.springframework.web.socket.WebSocketHandler wsHandler, Map<String, Object> attributes) throws Exception {
        // 获得请求参数
        Map<String, String> paramMap = HttpUtil.decodeParamMap(request.getURI().getQuery(), Charset.defaultCharset());
        String userId = paramMap.get("userId");
        if (CharSequenceUtil.isNotBlank(userId)) {
            if (userId.endsWith("pc")){
//                String substring = userId.substring(0, userId.length() - 2);
//                // 校验连接人在系统是否存在
//                SysUser user = userService.selectUserById(Long.valueOf(substring));
//                if (user == null) {
//                    response.setStatusCode(HttpStatus.UNAUTHORIZED);
//                    return false;
//                }
            }else {
                // 校验连接人在系统是否存在
                SysUser user = userService.selectUserById(Long.valueOf(userId));
                if (user == null) {
                    response.setStatusCode(HttpStatus.UNAUTHORIZED);
                    return false;
                }
            }
            // 放入属性域
            attributes.put("userId", userId);
            logger.info("用户:{}握手成功!", userId);
            return true;
        } else {
            logger.info("接受到一个websocket连接请求但是没有参数!");
        }
        return false;
    }

    /**
     * 握手后
     *
     * @param request   请求独享
     * @param response  响应对象
     * @param wsHandler 处理器
     * @param exception 抛出的异常
     */
    @Override
    public void afterHandshake(ServerHttpRequest request, ServerHttpResponse response, org.springframework.web.socket.WebSocketHandler wsHandler, Exception exception) {
        logger.info("握手结束!");
    }
}

第三个是管理器,其中的add方法,本身是有一个判重机制,如果该连接已存在就把原来的踢下线,重新连接新的,防止出现多个同样的id的问题.但是这又导致了新的频繁关闭重连的问题,所以后来改成了如果已经存在就直接return

@Slf4j
public class WsSessionManager {
    private WsSessionManager() {
    }

    private static final Logger logger = LoggerFactory.getLogger(WsSessionManager.class);

    /**
     * 记录当前在线连接数
     */
    private static AtomicInteger onlineCount = new AtomicInteger(0);

    /**
     * 保存连接 session 的地方
     */
    private static final ConcurrentHashMap<String, WebSocketSession> SESSION_POOL = new ConcurrentHashMap<>(99999);

    /**
     * 添加 session
     *
     * @param key     键
     * @param session 值
     */
    public static synchronized void add(String key, WebSocketSession session) {
        WebSocketSession existingSession = SESSION_POOL.get(key);
        if (existingSession != null) {
            if (existingSession.equals(session)) {
                logger.info("用户 {} 的 WebSocket 已存在,无需重复添加", key);
                return;
            }
//            if (existingSession.isOpen()) {
//                try {
//                    existingSession.close();
//                    logger.info("关闭旧的连接, userId: {}", key);
//                } catch (IOException e) {
//                    logger.error("关闭旧的连接时出现异常, userId: {}, 异常: {}", key, e.getMessage());
//                }
//            }
            if (existingSession.isOpen()) return;

        }
        SESSION_POOL.put(key, session);
        onlineCount.incrementAndGet();
        logger.info("新连接已添加, userId: {}, 当前在线人数: {}", key, getOnlineCount());
    }


    /**
     * 删除 session, 会返回删除的 session
     *
     * @param key 键
     * @return 值
     */
    public static synchronized WebSocketSession remove(String key) {
        WebSocketSession session = SESSION_POOL.remove(key);
        if (session != null) {
            onlineCount.decrementAndGet();
            logger.info("连接已移除, userId: {}, 当前在线人数: {}", key, getOnlineCount());
        }
        return session;
    }

    /**
     * 删除并同步关闭连接
     *
     * @param key 键
     */
    public static synchronized void removeAndClose(String key) {
        WebSocketSession session = remove(key);
        if (session != null) {
            try {
                session.close();
                logger.warn("关闭WebSocket会话, userId: {}", key);
            } catch (IOException e) {
                logger.error("关闭会话时出现异常, userId: {}, 异常: {}, {}", key, e.getMessage(), e);
            }
        }
    }

    /**
     * 获得 session
     *
     * @param key 键
     * @return 值
     */
    public static WebSocketSession get(String key) {
        return SESSION_POOL.get(key);
    }

    /**
     * 获取当前在线连接数
     *
     * @return 在线连接数
     */
    public static int getOnlineCount() {
        return onlineCount.get();
    }

    /**
     * 获得 Map
     *
     * @return 值
     */
    public static ConcurrentMap<String, WebSocketSession> getMap() {
        return SESSION_POOL;
    }
}

第四个是真正发送消息的处理器

@Component
public class WebSocketHandler extends TextWebSocketHandler {
    private final Logger logger = LoggerFactory.getLogger(WebSocketHandler.class);

    private static final String KEY = "userId";


    /**
     * socket 建立成功事件
     * @param session session对象
     */
    @Override
    public void afterConnectionEstablished(WebSocketSession session) {
        Object userId = session.getAttributes().get(KEY);
        if (userId != null) {
            // 将用户的连接放入 WsSessionManager,会自动关闭之前的旧连接
            WsSessionManager.add(userId.toString(), session);
            logger.info("用户连接成功, userId: {}", userId);
        } else {
            logger.warn("未能在连接中找到 userId 属性");
        }
        logger.info("建立连接了, 当前在线人数: {}, session: {}, 当前map: {}", WsSessionManager.getOnlineCount(), session, WsSessionManager.getMap());
    }


    /**
     * 接收消息事件
     *
     * @param session session对象
     * @param message 接收到的消息
     * @throws Exception 可能抛出的异常
     */
    @Override
    protected void handleTextMessage(WebSocketSession session, TextMessage message) throws Exception {
        // 获得客户端传来的消息
//        String payload = message.getPayload();
        logger.info("收到ws消息: {}", message);
        // 返回一条确认消息给发消息的用户
        TextMessage responseMessage = new TextMessage("pong");
        session.sendMessage(responseMessage);
    }

    /**
     * socket 断开连接时
     *
     * @param session session对象
     * @param status  断开状态
     * @throws Exception 可能抛出的异常
     */
    @Override
    public void afterConnectionClosed(WebSocketSession session, CloseStatus status) throws Exception {
        logger.info("断开连接了,session为{}", session == null ? "" : session);
        Object token = session.getAttributes().get(KEY);
        if (token != null) {
            // 用户退出,移除缓存
            WsSessionManager.removeAndClose(token.toString());
        }
    }

    /**
     * 发送消息给指定设备
     *
     * @param serialNumber 序列号
     * @param message      消息内容
     * @param type         1跳通知 2跳客户 3手机打电话 前端用  8 pc用未读消息数量   9 当前app是否在线,true或者false
     * @param noticeId     通知id,已读用
     */
    public void sendMessage(String serialNumber, String message, Integer type, Long noticeId) {
        WebSocketSession webSocketSession = WsSessionManager.get(serialNumber);
        try {
            if (webSocketSession != null && webSocketSession.isOpen()) {
                JSONObject jsonObject;
                jsonObject = JSONObject.of("type", type, "value", message, "noticeId", noticeId);
                webSocketSession.sendMessage(new TextMessage(jsonObject.toString()));
                logger.info("发送消息给{},消息内容为{}", serialNumber, message);
            }
        } catch (Exception e) {
            logger.error("消息发送失败,设备{},失败原因{}{}", webSocketSession.getAttributes().get(KEY), e.getMessage(), e);
        }
    }

    /**
     * 发送消息给指定设备
     *
     * @param serialNumber 序列号
     * @param message      消息内容
     * @param type         1跳通知 2跳客户 3手机打电话 前端用  8 pc用未读消息数量   9 当前app是否在线,true或者false
     * @param notice     通知整个对象
     */
    public void sendMessage(String serialNumber, String message, Integer type, ClientNoticeDO notice, Integer other) {
        WebSocketSession webSocketSession = WsSessionManager.get(serialNumber);
        try {
            if (webSocketSession != null && webSocketSession.isOpen()) {
                JSONObject jsonObject = JSONObject.of("type", type, "value", message, "notice", notice);
                webSocketSession.sendMessage(new TextMessage(jsonObject.toString()));
                logger.info("发送消息给{},消息内容为{}", serialNumber, message);
            } else {
                logger.warn("WebSocket 会话不可用, userId: {}", serialNumber);
            }
        } catch (IOException e) {
            logger.error("WebSocket 消息发送失败, userId: {}, 原因: {}", serialNumber, e.getMessage(), e);
            WsSessionManager.remove(serialNumber); // 自动移除无效会话
        } catch (Exception e) {
            logger.error("消息发送时发生未知错误, userId: {}, 原因: {}", serialNumber, e.getMessage(), e);
        }
    }


    /**
     * 广播消息
     *
     * @param message 消息
     */
    public void sendMessageAll(String message) {
        WsSessionManager.getMap().keySet().forEach(e -> sendMessage(e, message, 2, (Long) null));
    }
}

其中的sendMessage方法根据自己的业务需求有一个重载方法,正常一个sendMessage就足够了.日志相关的酌情增减.

心跳:在handleTextMessage方法中,接收到前端任何消息都返回一个pong,前端如果一段时间未收到pong就会发起重连,以此保证连接不中断.如果业务有前端发来的其他消息则加个if判断即可.

最终使用的时候注入

@Autowired
private WebSocketHandler webSocketHandler;
  //然后调用
  webSocketHandler.sendMessage(XXX,XXX,XXX)
  //即可.

连接的地址:ws://IP:端口/?userId=1
其中/ws是在WebSocketConfig配置的,
userId是在WebSocketHandler配置的KEY

最后附上在线连接websocket测试的网站:http://www.websocket-test.com/
以及相关可以直接测试的idea插件:CoolRequest
在这里插入图片描述

;