一共有4个类,第一个WebSocketConfig 配置类
@Configuration
@EnableWebSocket
public class WebSocketConfig implements WebSocketConfigurer {
@Autowired
private WebSocketHandler webSocketHandler;
@Autowired
private WebSocketInterceptor webSocketInterceptor;
@Override
public void registerWebSocketHandlers(WebSocketHandlerRegistry registry) {
registry.addHandler(webSocketHandler, "/ws")
.addInterceptors(webSocketInterceptor)
.setAllowedOrigins("*");
}
}
第二个,拦截器,这里我区分了pc和app,因为代码需求是同一个id登录的用户要在pc端和app端同时连接websocket,为做区分,在pc的userid后面加了pc两个字母.
@Component
public class WebSocketInterceptor implements HandshakeInterceptor {
private final Logger logger = LoggerFactory.getLogger(WebSocketInterceptor.class);
@Resource
private ISysUserService userService;
/**
* 握手前
* @param request 请求对象
* @param response 响应对象
* @param wsHandler 请求处理器
* @param attributes 属性域
* @return true放行,false拒绝
* @throws Exception 可能抛出的异常
*/
@Override
public boolean beforeHandshake(ServerHttpRequest request, ServerHttpResponse response, org.springframework.web.socket.WebSocketHandler wsHandler, Map<String, Object> attributes) throws Exception {
// 获得请求参数
Map<String, String> paramMap = HttpUtil.decodeParamMap(request.getURI().getQuery(), Charset.defaultCharset());
String userId = paramMap.get("userId");
if (CharSequenceUtil.isNotBlank(userId)) {
if (userId.endsWith("pc")){
// String substring = userId.substring(0, userId.length() - 2);
// // 校验连接人在系统是否存在
// SysUser user = userService.selectUserById(Long.valueOf(substring));
// if (user == null) {
// response.setStatusCode(HttpStatus.UNAUTHORIZED);
// return false;
// }
}else {
// 校验连接人在系统是否存在
SysUser user = userService.selectUserById(Long.valueOf(userId));
if (user == null) {
response.setStatusCode(HttpStatus.UNAUTHORIZED);
return false;
}
}
// 放入属性域
attributes.put("userId", userId);
logger.info("用户:{}握手成功!", userId);
return true;
} else {
logger.info("接受到一个websocket连接请求但是没有参数!");
}
return false;
}
/**
* 握手后
*
* @param request 请求独享
* @param response 响应对象
* @param wsHandler 处理器
* @param exception 抛出的异常
*/
@Override
public void afterHandshake(ServerHttpRequest request, ServerHttpResponse response, org.springframework.web.socket.WebSocketHandler wsHandler, Exception exception) {
logger.info("握手结束!");
}
}
第三个是管理器,其中的add方法,本身是有一个判重机制,如果该连接已存在就把原来的踢下线,重新连接新的,防止出现多个同样的id的问题.但是这又导致了新的频繁关闭重连的问题,所以后来改成了如果已经存在就直接return
@Slf4j
public class WsSessionManager {
private WsSessionManager() {
}
private static final Logger logger = LoggerFactory.getLogger(WsSessionManager.class);
/**
* 记录当前在线连接数
*/
private static AtomicInteger onlineCount = new AtomicInteger(0);
/**
* 保存连接 session 的地方
*/
private static final ConcurrentHashMap<String, WebSocketSession> SESSION_POOL = new ConcurrentHashMap<>(99999);
/**
* 添加 session
*
* @param key 键
* @param session 值
*/
public static synchronized void add(String key, WebSocketSession session) {
WebSocketSession existingSession = SESSION_POOL.get(key);
if (existingSession != null) {
if (existingSession.equals(session)) {
logger.info("用户 {} 的 WebSocket 已存在,无需重复添加", key);
return;
}
// if (existingSession.isOpen()) {
// try {
// existingSession.close();
// logger.info("关闭旧的连接, userId: {}", key);
// } catch (IOException e) {
// logger.error("关闭旧的连接时出现异常, userId: {}, 异常: {}", key, e.getMessage());
// }
// }
if (existingSession.isOpen()) return;
}
SESSION_POOL.put(key, session);
onlineCount.incrementAndGet();
logger.info("新连接已添加, userId: {}, 当前在线人数: {}", key, getOnlineCount());
}
/**
* 删除 session, 会返回删除的 session
*
* @param key 键
* @return 值
*/
public static synchronized WebSocketSession remove(String key) {
WebSocketSession session = SESSION_POOL.remove(key);
if (session != null) {
onlineCount.decrementAndGet();
logger.info("连接已移除, userId: {}, 当前在线人数: {}", key, getOnlineCount());
}
return session;
}
/**
* 删除并同步关闭连接
*
* @param key 键
*/
public static synchronized void removeAndClose(String key) {
WebSocketSession session = remove(key);
if (session != null) {
try {
session.close();
logger.warn("关闭WebSocket会话, userId: {}", key);
} catch (IOException e) {
logger.error("关闭会话时出现异常, userId: {}, 异常: {}, {}", key, e.getMessage(), e);
}
}
}
/**
* 获得 session
*
* @param key 键
* @return 值
*/
public static WebSocketSession get(String key) {
return SESSION_POOL.get(key);
}
/**
* 获取当前在线连接数
*
* @return 在线连接数
*/
public static int getOnlineCount() {
return onlineCount.get();
}
/**
* 获得 Map
*
* @return 值
*/
public static ConcurrentMap<String, WebSocketSession> getMap() {
return SESSION_POOL;
}
}
第四个是真正发送消息的处理器
@Component
public class WebSocketHandler extends TextWebSocketHandler {
private final Logger logger = LoggerFactory.getLogger(WebSocketHandler.class);
private static final String KEY = "userId";
/**
* socket 建立成功事件
* @param session session对象
*/
@Override
public void afterConnectionEstablished(WebSocketSession session) {
Object userId = session.getAttributes().get(KEY);
if (userId != null) {
// 将用户的连接放入 WsSessionManager,会自动关闭之前的旧连接
WsSessionManager.add(userId.toString(), session);
logger.info("用户连接成功, userId: {}", userId);
} else {
logger.warn("未能在连接中找到 userId 属性");
}
logger.info("建立连接了, 当前在线人数: {}, session: {}, 当前map: {}", WsSessionManager.getOnlineCount(), session, WsSessionManager.getMap());
}
/**
* 接收消息事件
*
* @param session session对象
* @param message 接收到的消息
* @throws Exception 可能抛出的异常
*/
@Override
protected void handleTextMessage(WebSocketSession session, TextMessage message) throws Exception {
// 获得客户端传来的消息
// String payload = message.getPayload();
logger.info("收到ws消息: {}", message);
// 返回一条确认消息给发消息的用户
TextMessage responseMessage = new TextMessage("pong");
session.sendMessage(responseMessage);
}
/**
* socket 断开连接时
*
* @param session session对象
* @param status 断开状态
* @throws Exception 可能抛出的异常
*/
@Override
public void afterConnectionClosed(WebSocketSession session, CloseStatus status) throws Exception {
logger.info("断开连接了,session为{}", session == null ? "" : session);
Object token = session.getAttributes().get(KEY);
if (token != null) {
// 用户退出,移除缓存
WsSessionManager.removeAndClose(token.toString());
}
}
/**
* 发送消息给指定设备
*
* @param serialNumber 序列号
* @param message 消息内容
* @param type 1跳通知 2跳客户 3手机打电话 前端用 8 pc用未读消息数量 9 当前app是否在线,true或者false
* @param noticeId 通知id,已读用
*/
public void sendMessage(String serialNumber, String message, Integer type, Long noticeId) {
WebSocketSession webSocketSession = WsSessionManager.get(serialNumber);
try {
if (webSocketSession != null && webSocketSession.isOpen()) {
JSONObject jsonObject;
jsonObject = JSONObject.of("type", type, "value", message, "noticeId", noticeId);
webSocketSession.sendMessage(new TextMessage(jsonObject.toString()));
logger.info("发送消息给{},消息内容为{}", serialNumber, message);
}
} catch (Exception e) {
logger.error("消息发送失败,设备{},失败原因{}{}", webSocketSession.getAttributes().get(KEY), e.getMessage(), e);
}
}
/**
* 发送消息给指定设备
*
* @param serialNumber 序列号
* @param message 消息内容
* @param type 1跳通知 2跳客户 3手机打电话 前端用 8 pc用未读消息数量 9 当前app是否在线,true或者false
* @param notice 通知整个对象
*/
public void sendMessage(String serialNumber, String message, Integer type, ClientNoticeDO notice, Integer other) {
WebSocketSession webSocketSession = WsSessionManager.get(serialNumber);
try {
if (webSocketSession != null && webSocketSession.isOpen()) {
JSONObject jsonObject = JSONObject.of("type", type, "value", message, "notice", notice);
webSocketSession.sendMessage(new TextMessage(jsonObject.toString()));
logger.info("发送消息给{},消息内容为{}", serialNumber, message);
} else {
logger.warn("WebSocket 会话不可用, userId: {}", serialNumber);
}
} catch (IOException e) {
logger.error("WebSocket 消息发送失败, userId: {}, 原因: {}", serialNumber, e.getMessage(), e);
WsSessionManager.remove(serialNumber); // 自动移除无效会话
} catch (Exception e) {
logger.error("消息发送时发生未知错误, userId: {}, 原因: {}", serialNumber, e.getMessage(), e);
}
}
/**
* 广播消息
*
* @param message 消息
*/
public void sendMessageAll(String message) {
WsSessionManager.getMap().keySet().forEach(e -> sendMessage(e, message, 2, (Long) null));
}
}
其中的sendMessage方法根据自己的业务需求有一个重载方法,正常一个sendMessage就足够了.日志相关的酌情增减.
心跳:在handleTextMessage方法中,接收到前端任何消息都返回一个pong,前端如果一段时间未收到pong就会发起重连,以此保证连接不中断.如果业务有前端发来的其他消息则加个if判断即可.
最终使用的时候注入
@Autowired
private WebSocketHandler webSocketHandler;
//然后调用
webSocketHandler.sendMessage(XXX,XXX,XXX)
//即可.
连接的地址:ws://IP:端口/?userId=1
其中/ws是在WebSocketConfig配置的,
userId是在WebSocketHandler配置的KEY
最后附上在线连接websocket测试的网站:http://www.websocket-test.com/
以及相关可以直接测试的idea插件:CoolRequest