Bootstrap

Java WebSocket 中获取httpSession详解

 

一:本文适用范围

本文使用J2EE规范原生的WebSocket(JSR 356, javax.websocket)开发,适用于项目中WebSocket用得比较少的情况;这种情况下Spring WebSocket的封装显得比较繁重,反而不太适合实际项目。

自己在开发时由于项目的原因,不想用Spring-WebSocket,所以用了原生的,但是找了好多帖子、试了好多方法都不行,甚至将原生的和Spring-WebSocket混用也不行,源码也看了好久,最后终于解决了问题,在此记录一下,也希望能帮到大家。

二:配置webSocket

我们先来看一下@ServerEndpoint注解以及ServerEndpointConfig的源码(后者为节选)

/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package javax.websocket.server;

import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;

import javax.websocket.Decoder;
import javax.websocket.Encoder;

/**
 * Marks a class as a server-side WebSocket endpoint.
 *
 * <p>Retained at runtime and allowed only on types (see the meta-annotations
 * below), so it is visible via reflection on the endpoint class.
 */
@Retention(RetentionPolicy.RUNTIME)
@Target(ElementType.TYPE)
public @interface ServerEndpoint {

    /**
     * URI or URI-template that the annotated class should be mapped to.
     * @return The URI or URI-template that the annotated class should be mapped
     *         to.
     */
    String value();

    /** Sub-protocols supported by this endpoint; empty by default. */
    String[] subprotocols() default {};

    /** Decoder classes registered for this endpoint; none by default. */
    Class<? extends Decoder>[] decoders() default {};

    /** Encoder classes registered for this endpoint; none by default. */
    Class<? extends Encoder>[] encoders() default {};

    /**
     * Configurator class used for this endpoint. Defaults to the base
     * {@code ServerEndpointConfig.Configurator}; supplying a subclass (e.g. one
     * that overrides {@code modifyHandshake}) customizes the handshake.
     */
    public Class<? extends ServerEndpointConfig.Configurator> configurator()
            default ServerEndpointConfig.Configurator.class;
}
package javax.websocket.server;

import java.lang.reflect.InvocationTargetException;
import java.security.AccessController;
import java.security.PrivilegedAction;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.ServiceLoader;

import javax.websocket.Decoder;
import javax.websocket.Encoder;
import javax.websocket.EndpointConfig;
import javax.websocket.Extension;
import javax.websocket.HandshakeResponse;

/**
 * Provides configuration information for WebSocket endpoints published to a
 * server. Applications may provide their own implementation or use
 * {@link Builder}.
 */
public interface ServerEndpointConfig extends EndpointConfig {

    // NOTE: "......" marks members elided from this excerpt; the snippet is not
    // compilable as shown.
    ......


    /**
     * Hook class referenced by {@code @ServerEndpoint#configurator()}.
     * Applications subclass it to customize endpoint setup.
     */
    public class Configurator {

        ......

        /**
         * Hook for adjusting the opening handshake. This base implementation only
         * delegates to the container's default Configurator
         * ({@code fetchContainerDefaultConfigurator()}); subclasses override it,
         * e.g. to copy data from the handshake request into {@code sec}'s user
         * properties.
         */
        public void modifyHandshake(ServerEndpointConfig sec,
                HandshakeRequest request, HandshakeResponse response) {
            fetchContainerDefaultConfigurator().modifyHandshake(sec, request, response);
        }

        ......
    }
}

我们看到注解的最后一个属性configurator(),它要求传入一个ServerEndpointConfig.Configurator的子类,而Configurator类中有个modifyHandshake方法,从名字就能看出它是用来修改握手过程的。我们写一个类去继承它,并重写modifyHandshake方法,在握手时将httpSession放进WebSocket的配置(userProperties)中。

package com.demo.config.websocket;

import com.demo.exception.JMakerException;
import lombok.extern.slf4j.Slf4j;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.web.socket.server.standard.ServerEndpointExporter;

import javax.servlet.http.HttpSession;
import javax.websocket.HandshakeResponse;
import javax.websocket.server.HandshakeRequest;
import javax.websocket.server.ServerEndpointConfig;
import java.util.Objects;

/**
 * websocket核心配置
 *
 * @author Zany 2019/5/14
 */
@Slf4j
@Configuration
public class WebSocketConfig extends ServerEndpointConfig.Configurator {

    /**
     * 握手加入httpSession配置
     */
    @Override
    public void modifyHandshake(ServerEndpointConfig config, HandshakeRequest request, HandshakeResponse response) {
        HttpSession httpSession = (HttpSession)request.getHttpSession();
        if (Objects.isNull(httpSession)){
            log.error("httpSession为空, header = [{}], 请登录!", request.getHeaders());
            throw new JMakerException("httpSession为空, 请登录!");
        }
        log.debug("webSocket握手, sessionId = [{}]", httpSession.getId());
        config.getUserProperties().put("httpSession", httpSession);
    }

    /**
     * 如果不是spring boot项目,那就不需要进行这样的配置,
     * 因为在tomcat中运行的话,tomcat会扫描带有@ServerEndpoint的注解成为websocket,
     * 而spring boot项目中需要由这个bean来提供注册管理。
     */
    @Bean
    public ServerEndpointExporter serverEndpointExporter() {
        return new ServerEndpointExporter();
    }
}

注:本来还担心用了spring-session-redis之后,通过HandshakeRequest.getHttpSession()获取session会不会有问题,看了源码之后就放心了:从下面Tomcat的WsHandshakeRequest源码可以看到,构造HandshakeRequest时调用的是HttpServletRequest.getSession(false),不会创建新的session(没有session时返回null,所以上面的modifyHandshake中需要做判空处理),实际运行结果也证实了这一点。


package org.apache.tomcat.websocket.server;

import java.net.URI;
import java.net.URISyntaxException;
import java.security.Principal;
import java.util.Arrays;
import java.util.Collections;
import java.util.Enumeration;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;

import javax.servlet.http.HttpServletRequest;
import javax.websocket.server.HandshakeRequest;

import org.apache.tomcat.util.collections.CaseInsensitiveKeyMap;
import org.apache.tomcat.util.res.StringManager;

/**
 * Represents the request that this session was opened under.
 *
 * <p>The constructor snapshots the query string, user principal, existing HTTP
 * session, request URI, parameters and headers into final fields. Only
 * {@link #isUserInRole(String)} still consults the live servlet request, and only
 * until {@link #finished()} clears it.
 */
public class WsHandshakeRequest implements HandshakeRequest {

    private static final StringManager sm = StringManager.getManager(WsHandshakeRequest.class);

    private final URI requestUri;
    private final Map<String,List<String>> parameterMap;
    private final String queryString;
    private final Principal userPrincipal;
    private final Map<String,List<String>> headers;
    private final Object httpSession;

    // Live request reference, cleared by finished(); volatile for cross-thread visibility.
    private volatile HttpServletRequest request;


    /**
     * Captures everything the handshake needs from {@code request}.
     *
     * @param request    the servlet request carrying the WebSocket upgrade
     * @param pathParams values extracted from the endpoint's URI template; merged
     *                   into the parameter map
     */
    public WsHandshakeRequest(HttpServletRequest request, Map<String,String> pathParams) {

        this.request = request;

        queryString = request.getQueryString();
        userPrincipal = request.getUserPrincipal();
        // getSession(false): return the existing session, or null if there is none.
        // A new session is never created here.
        httpSession = request.getSession(false);
        requestUri = buildRequestUri(request);

        // ParameterMap
        Map<String,String[]> originalParameters = request.getParameterMap();
        Map<String,List<String>> newParameters =
                new HashMap<>(originalParameters.size());
        for (Entry<String,String[]> entry : originalParameters.entrySet()) {
            newParameters.put(entry.getKey(),
                    Collections.unmodifiableList(
                            Arrays.asList(entry.getValue())));
        }
        // Path parameters are put after request parameters, so a path parameter
        // replaces a request parameter with the same name.
        for (Entry<String,String> entry : pathParams.entrySet()) {
            newParameters.put(entry.getKey(),
                    Collections.unmodifiableList(
                            Arrays.asList(entry.getValue())));
        }
        parameterMap = Collections.unmodifiableMap(newParameters);

        // Headers (header names are looked up case-insensitively)
        Map<String,List<String>> newHeaders = new CaseInsensitiveKeyMap<>();

        Enumeration<String> headerNames = request.getHeaderNames();
        while (headerNames.hasMoreElements()) {
            String headerName = headerNames.nextElement();

            newHeaders.put(headerName, Collections.unmodifiableList(
                    Collections.list(request.getHeaders(headerName))));
        }

        headers = Collections.unmodifiableMap(newHeaders);
    }

    @Override
    public URI getRequestURI() {
        return requestUri;
    }

    @Override
    public Map<String,List<String>> getParameterMap() {
        return parameterMap;
    }

    @Override
    public String getQueryString() {
        return queryString;
    }

    @Override
    public Principal getUserPrincipal() {
        return userPrincipal;
    }

    @Override
    public Map<String,List<String>> getHeaders() {
        return headers;
    }

    /**
     * Delegates to the live servlet request.
     *
     * @throws IllegalStateException if called after {@link #finished()}
     */
    @Override
    public boolean isUserInRole(String role) {
        if (request == null) {
            throw new IllegalStateException();
        }

        return request.isUserInRole(role);
    }

    /**
     * Returns the HTTP session captured in the constructor, or {@code null} if the
     * request had no session at that time.
     */
    @Override
    public Object getHttpSession() {
        return httpSession;
    }

    /**
     * Called when the HandshakeRequest is no longer required. Since an instance
     * of this class retains a reference to the current HttpServletRequest that
     * reference needs to be cleared as the HttpServletRequest may be reused.
     *
     * There is no reason for instances of this class to be accessed once the
     * handshake has been completed.
     */
    void finished() {
        request = null;
    }


    /*
     * See RequestUtil.getRequestURL()
     *
     * Rebuilds the request URL with scheme http -> ws and https -> wss. The port is
     * appended only when it differs from the scheme's default (80 / 443). A
     * negative port is treated as 80. Any other scheme is rejected.
     */
    private static URI buildRequestUri(HttpServletRequest req) {

        StringBuffer uri = new StringBuffer();
        String scheme = req.getScheme();
        int port = req.getServerPort();
        if (port < 0) {
            // Work around java.net.URL bug
            port = 80;
        }

        if ("http".equals(scheme)) {
            uri.append("ws");
        } else if ("https".equals(scheme)) {
            uri.append("wss");
        } else {
            // Should never happen
            throw new IllegalArgumentException(
                    sm.getString("wsHandshakeRequest.unknownScheme", scheme));
        }

        uri.append("://");
        uri.append(req.getServerName());

        if ((scheme.equals("http") && (port != 80))
            || (scheme.equals("https") && (port != 443))) {
            uri.append(':');
            uri.append(port);
        }

        uri.append(req.getRequestURI());

        if (req.getQueryString() != null) {
            uri.append("?");
            uri.append(req.getQueryString());
        }

        try {
            return new URI(uri.toString());
        } catch (URISyntaxException e) {
            // Should never happen
            throw new IllegalArgumentException(
                    sm.getString("wsHandshakeRequest.invalidUri", uri.toString()), e);
        }
    }
}

三:WebSocket服务代码

package com.demo.config.websocket.server;

import com.demo.config.websocket.WebSocketConfig;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
import org.springframework.stereotype.Component;
import org.springframework.util.CollectionUtils;

import javax.servlet.http.HttpSession;
import javax.websocket.*;
import javax.websocket.server.ServerEndpoint;
import java.io.IOException;
import java.util.List;
import java.util.Objects;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicInteger;

@Slf4j
@Component
@ServerEndpoint(value = "/client/webSocket/scrollMessage", configurator = WebSocketConfig.class)
public class ScrollMessageWebSocketServer {

    /**
     * 客户端连接数量
     */
    private AtomicInteger onlineCount = new AtomicInteger(0);

    /**
     * 存放客户端对应的Session对象
     */
    private static ConcurrentHashMap<String, Session> webSocketServerPool = new ConcurrentHashMap<>();

    /**
     * 连接成功
     */
    @OnOpen
    public void onOpen(final Session session) {
        log.debug("wsSession中的userProperties = [{}]", session.getUserProperties());
        String userId = getHttpSessionUserId(session);
        final int users = onlineCount.incrementAndGet();
        webSocketServerPool.put(userId, session);
        log.info("用户userId=[{}]连接成功,当前在线人数=[{}]", userId, users);
    }

    /**
     * 连接关闭
     */
    @OnClose
    public void onClose(final Session session) {
        String userId = getHttpSessionUserId(session);
        webSocketServerPool.remove(userId);
        final int users = onlineCount.decrementAndGet();
//        final int users = onlineCount.get();
        log.info("用户userId=[{}]退出!当前在线人数为=[{}]", userId, users);
    }

    /**
     * 收到消息
     */
    @OnMessage
    public void onMessage(final String message, final Session session) {
        String userId = getHttpSessionUserId(session);
        log.debug("收到用户userId=[{}]的消息=[{}]", userId, message);
//        webSocketServerPool.forEach((a, b) -> send2User(message, a));
    }

    /**
     * 错误信息
     *
     * @param session
     * @param error
     */
    @OnError
    public void onError(final Session session, final Throwable error) {
        log.error("发生错误", error);
    }

    /**
     * 发送消息
     *
     * @param message
     * @throws IOException
     */
    private static void sendMessage(final Session session, final String message) {
        if (StringUtils.isBlank(message)) {
            return;
        }
        try {
            session.getBasicRemote().sendText(message);
        } catch (IOException e) {
            log.error("发送失败", e);
        }
    }


    /**
     * 发送信息给指定ID用户
     *
     * @param message 消息
     * @param userId  用户ID
     */
    public static void send2User(final String message, final String userId) {
        if (StringUtils.isBlank(userId) || StringUtils.isBlank(message)) {
            log.error("发送失败,消息message=[{}]为空或用户userId=[{}]为空", message, userId);
            return;
        }
        Session session;
        if (Objects.nonNull(session = webSocketServerPool.get(userId))) {
            sendMessage(session, message);
            log.debug("用户userId=[{}]发送发送成功", userId);
            return;
        }
        log.debug("用户userId=[{}]发送失败,不存在或已下线", userId);
    }

    /**
     * 发送信息给指定用户
     *
     * @param message    消息
     * @param userIdList 用户列表
     */
    public static void send2User(final String message, final List<String> userIdList) {
        if (StringUtils.isBlank(message) || CollectionUtils.isEmpty(userIdList)) {
            return;
        }
        userIdList.forEach((userId) -> send2User(message, userId));
    }

    /**
     * 发送信息给所有用户
     *
     * @param message    消息
     */
    public static void send2User(final String message) {
        if (StringUtils.isBlank(message)) {
            return;
        }
        webSocketServerPool.forEach((userId, session) -> sendMessage(session, message));
    }

    private String getHttpSessionUserId(Session session){
        HttpSession httpSession = (HttpSession) session.getUserProperties().get("httpSession");
        if (Objects.isNull(httpSession)){
            log.error("httpSession为空, 请登录");
            try {
                session.close();
            } catch (IOException e) {
                e.printStackTrace();
            }
        }
        return httpSession.getAttribute("sessionUserId").toString();
    }
}

最后解释一下为什么我直接从WebSocket的Session中调用getUserProperties():在配置中,属性本来是加在ServerEndpointConfig(EndpointConfig的子接口)中的,所以我们看看Session的源码(Tomcat的WsSession),可以看到在构造session时将endpointConfig中的userProperties全部加入了session中。

/**
 * Excerpt of Tomcat's WebSocket session implementation. Only the constructor is
 * shown; fields and all other members are omitted, so this snippet does not
 * compile on its own.
 */
public class WsSession implements Session {

    /**
     * Wires the session to its endpoint, remote endpoint and container, and copies
     * the handshake data (URI, parameters, query string, principal, HTTP session
     * id, extensions, sub-protocol, path parameters, secure flag) into fields.
     * Buffer sizes, idle timeout and send timeout come from the container defaults.
     *
     * @throws DeploymentException if the InstanceManager fails on the endpoint
     */
    public WsSession(Endpoint localEndpoint,
            WsRemoteEndpointImplBase wsRemoteEndpoint,
            WsWebSocketContainer wsWebSocketContainer,
            URI requestUri, Map<String, List<String>> requestParameterMap,
            String queryString, Principal userPrincipal, String httpSessionId,
            List<Extension> negotiatedExtensions, String subProtocol, Map<String, String> pathParameters,
            boolean secure, EndpointConfig endpointConfig) throws DeploymentException {
        this.localEndpoint = localEndpoint;
        this.wsRemoteEndpoint = wsRemoteEndpoint;
        this.wsRemoteEndpoint.setSession(this);
        this.remoteEndpointAsync = new WsRemoteEndpointAsync(wsRemoteEndpoint);
        this.remoteEndpointBasic = new WsRemoteEndpointBasic(wsRemoteEndpoint);
        this.webSocketContainer = wsWebSocketContainer;
        applicationClassLoader = Thread.currentThread().getContextClassLoader();
        wsRemoteEndpoint.setSendTimeout(wsWebSocketContainer.getDefaultAsyncSendTimeout());
        this.maxBinaryMessageBufferSize = webSocketContainer.getDefaultMaxBinaryMessageBufferSize();
        this.maxTextMessageBufferSize = webSocketContainer.getDefaultMaxTextMessageBufferSize();
        this.maxIdleTimeout = webSocketContainer.getDefaultMaxSessionIdleTimeout();
        this.requestUri = requestUri;
        // A null parameter map is normalized to an empty map.
        if (requestParameterMap == null) {
            this.requestParameterMap = Collections.emptyMap();
        } else {
            this.requestParameterMap = requestParameterMap;
        }
        this.queryString = queryString;
        this.userPrincipal = userPrincipal;
        this.httpSessionId = httpSessionId;
        this.negotiatedExtensions = negotiatedExtensions;
        // A null sub-protocol is normalized to the empty string.
        if (subProtocol == null) {
            this.subProtocol = "";
        } else {
            this.subProtocol = subProtocol;
        }
        this.pathParameters = pathParameters;
        this.secure = secure;
        this.wsRemoteEndpoint.setEncoders(endpointConfig);
        this.endpointConfig = endpointConfig;

        // When the session is constructed, every user property from endpointConfig
        // is copied into this session's own map. This is why values put into the
        // config in modifyHandshake are readable via Session.getUserProperties().
        this.userProperties.putAll(endpointConfig.getUserProperties());
        this.id = Long.toHexString(ids.getAndIncrement());

        // Prefer the container's InstanceManager, falling back to the one bound to
        // the application class loader; if one is found, hand it the endpoint and
        // wrap any failure in a DeploymentException.
        InstanceManager instanceManager = webSocketContainer.getInstanceManager();
        if (instanceManager == null) {
            instanceManager = InstanceManagerBindings.get(applicationClassLoader);
        }
        if (instanceManager != null) {
            try {
                instanceManager.newInstance(localEndpoint);
            } catch (Exception e) {
                throw new DeploymentException(sm.getString("wsSession.instanceNew"), e);
            }
        }

        if (log.isDebugEnabled()) {
            log.debug(sm.getString("wsSession.created", id));
        }
    }
}

 

;