Bootstrap

请求体只能获取一次的解决方案

问题背景

在日常开发中,请求体一般只能读取一次。例如在使用@RequestBody注解时,同一个接口的参数里面只能使用一个@RequestBody。因此在日常开发中有时候不得不使用Map接收参数、自己封装DTO对象,或者编写自定义的参数解析注解。另外,如果想要在过滤器中对请求体中的参数进行判断,那么请求体在过滤器中被读取一次之后,当请求再到达Controller层时就没办法再获取了。本案例演示两种解决方案:自定义参数解析注解,以及封装Request请求。

自定义参数解析注解

自定义一个注解MultiRequestBody
package com.we.applet.wifi.utils.annotion;

import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;

/**
 * Marks a handler-method parameter whose value should be extracted from the JSON request body,
 * so that several parameters of one method can each take a part of the same body.
 *
 * @author 张定辉
 * @since 2022/9/3
 */
@Retention(RetentionPolicy.RUNTIME)
@Target({ElementType.PARAMETER})
public @interface MultiRequestBody {
    /**
     * JSON key used to look up the value. Optional: when left empty, the name of the
     * annotated method parameter is used as the key.
     */
    String value() default "";

    /**
     * Whether the parameter must be present in the request body.
     */
    boolean required() default true;

    /**
     * Whether, when neither {@link #value()} nor the parameter name matches a key, the
     * outermost properties of the JSON body may be parsed into the parameter's object.
     */
    boolean parseAllFields() default true;
}

编写该注解的解析器
package com.we.applet.wifi.handler.annotion;


import com.alibaba.fastjson.JSON;
import com.alibaba.fastjson.JSONException;
import com.alibaba.fastjson.JSONObject;
import com.we.applet.wifi.utils.annotion.MultiRequestBody;
import org.apache.commons.io.IOUtils;
import org.apache.commons.lang.StringUtils;
import org.springframework.core.MethodParameter;
import org.springframework.web.bind.support.WebDataBinderFactory;
import org.springframework.web.context.request.NativeWebRequest;
import org.springframework.web.method.support.HandlerMethodArgumentResolver;
import org.springframework.web.method.support.ModelAndViewContainer;

import javax.servlet.http.HttpServletRequest;
import java.io.IOException;
import java.lang.reflect.Field;
import java.util.HashSet;
import java.util.Set;

/**
 * Resolves handler-method parameters annotated with {@link MultiRequestBody} from the JSON
 * request body, using fastjson.
 *
 * <p>The raw body is read at most once per request and cached in a request-scoped attribute,
 * so every annotated parameter of the same handler method can pick its own key out of it.
 *
 * @author 张定辉
 * @since 2022/9/3
 */
public class MultiRequestBodyArgumentResolver implements HandlerMethodArgumentResolver {
    /** Name of the request attribute under which the raw JSON body is cached. */
    private static final String JSONBODY_ATTRIBUTE = "JSON_REQUEST_BODY";

    private static final String REQUIRED_PARAM_RESOLVER_ERROR = "required param %s is not present";

    /** Wrapper classes of the primitive types; these are converted as scalars. */
    private static final Set<Class<?>> CLASS_SET;

    static {
        CLASS_SET = new HashSet<>();
        CLASS_SET.add(Integer.class);
        CLASS_SET.add(Long.class);
        CLASS_SET.add(Short.class);
        CLASS_SET.add(Float.class);
        CLASS_SET.add(Double.class);
        CLASS_SET.add(Boolean.class);
        CLASS_SET.add(Byte.class);
        CLASS_SET.add(Character.class);
    }

    /**
     * Declares which method parameters this resolver handles.
     *
     * @param parameter the method parameter to check
     * @return {@code true} when the parameter carries {@link MultiRequestBody}
     */
    @Override
    public boolean supportsParameter(MethodParameter parameter) {
        return parameter.hasParameterAnnotation(MultiRequestBody.class);
    }

    /**
     * Extracts the argument value for one annotated parameter.
     *
     * <p>Lookup order: the annotation's {@code value()} (or the parameter name when it is empty)
     * is used as the JSON key. If that key yields a value it is converted to the parameter type.
     * Otherwise, for non-scalar types with {@code parseAllFields()} enabled, the outermost JSON
     * object itself is parsed into the parameter type.
     *
     * @return the converted value; {@code null} only for an optional parameter that could not
     *         be resolved
     * @throws IllegalArgumentException if a required parameter is missing or a scalar value
     *                                  cannot be converted
     */
    @Override
    public Object resolveArgument(MethodParameter parameter, ModelAndViewContainer mavContainer, NativeWebRequest webRequest, WebDataBinderFactory binderFactory) throws Exception {
        MultiRequestBody parameterAnnotation = parameter.getParameterAnnotation(MultiRequestBody.class);
        if (parameterAnnotation == null) {
            // supportsParameter() should make this unreachable; fail loudly even without -ea
            throw new IllegalStateException("parameter is not annotated with @MultiRequestBody");
        }

        JSONObject jsonObject = JSON.parseObject(getRequestBody(webRequest));
        if (jsonObject == null) {
            // Nothing parsed (e.g. empty body): continue with an empty object so the
            // "missing key" handling below applies instead of a NullPointerException.
            jsonObject = new JSONObject();
        }

        // The annotation value is the JSON key; fall back to the parameter name when unset.
        String key = parameterAnnotation.value();
        Object value;
        if (StringUtils.isNotEmpty(key)) {
            value = jsonObject.get(key);
            // An explicit key that cannot be found is an error for required parameters.
            if (value == null && parameterAnnotation.required()) {
                throw new IllegalArgumentException(String.format(REQUIRED_PARAM_RESOLVER_ERROR, key));
            }
        } else {
            key = parameter.getParameterName();
            value = jsonObject.get(key);
        }

        Class<?> parameterType = parameter.getParameterType();

        // A value was found under the key: convert it to the declared parameter type.
        if (value != null) {
            if (parameterType.isPrimitive()) {
                return parsePrimitive(parameterType.getName(), value);
            }
            if (isBasicDataTypes(parameterType)) {
                return parseBasicTypeWrapper(parameterType, value);
            }
            if (parameterType == String.class) {
                return value.toString();
            }
            // Any other type is treated as a complex object and parsed from its JSON text.
            return JSON.parseObject(value.toString(), parameterType);
        }

        // No value under the key. A scalar cannot be built from the whole JSON object.
        if (parameterType.isPrimitive() || isBasicDataTypes(parameterType)) {
            if (parameterAnnotation.required()) {
                throw new IllegalArgumentException(String.format(REQUIRED_PARAM_RESOLVER_ERROR, key));
            }
            return null;
        }

        // Non-scalar type, but parsing the outer object is disabled for this parameter.
        if (!parameterAnnotation.parseAllFields()) {
            if (parameterAnnotation.required()) {
                throw new IllegalArgumentException(String.format(REQUIRED_PARAM_RESOLVER_ERROR, key));
            }
            return null;
        }

        // Non-scalar type with parseAllFields enabled: map the outermost properties onto it.
        Object result;
        try {
            result = JSON.parseObject(jsonObject.toString(), parameterType);
        } catch (JSONException jsonException) {
            if (parameterAnnotation.required()) {
                throw new IllegalArgumentException(String.format(REQUIRED_PARAM_RESOLVER_ERROR, key), jsonException);
            }
            return null;
        }

        // Optional parameter: hand back whatever could be parsed.
        if (!parameterAnnotation.required()) {
            return result;
        }
        // Required parameter: at least one instance field must have received a value.
        if (result == null || !hasAnyFieldValue(parameterType, result)) {
            throw new IllegalArgumentException(String.format(REQUIRED_PARAM_RESOLVER_ERROR, key));
        }
        return result;
    }

    /**
     * Reports whether any instance field declared directly on {@code type} is non-null in
     * {@code instance}. Static fields are skipped because they do not come from the request.
     */
    private boolean hasAnyFieldValue(Class<?> type, Object instance) throws IllegalAccessException {
        for (Field field : type.getDeclaredFields()) {
            if (java.lang.reflect.Modifier.isStatic(field.getModifiers())) {
                continue;
            }
            field.setAccessible(true);
            if (field.get(instance) != null) {
                return true;
            }
        }
        return false;
    }

    /**
     * Converts a JSON value to the primitive type with the given name.
     *
     * @return the boxed value, or {@code null} if the type name is not a known primitive
     */
    private Object parsePrimitive(String parameterTypeName, Object value) {
        switch (parameterTypeName) {
            case "boolean":
                return convertScalar(Boolean.class, value);
            case "int":
                return convertScalar(Integer.class, value);
            case "char":
                return convertScalar(Character.class, value);
            case "short":
                return convertScalar(Short.class, value);
            case "long":
                return convertScalar(Long.class, value);
            case "float":
                return convertScalar(Float.class, value);
            case "double":
                return convertScalar(Double.class, value);
            case "byte":
                return convertScalar(Byte.class, value);
            default:
                return null;
        }
    }

    /**
     * Converts a JSON value to the given primitive wrapper type.
     *
     * @return the converted value, or {@code null} if the type is not a primitive wrapper
     */
    private Object parseBasicTypeWrapper(Class<?> parameterType, Object value) {
        return convertScalar(parameterType, value);
    }

    /**
     * Shared scalar conversion for primitives and their wrappers. Numeric targets accept both
     * JSON numbers and numeric strings.
     *
     * @throws IllegalArgumentException if the value cannot be converted (includes
     *                                  {@link NumberFormatException} for malformed numbers)
     */
    private Object convertScalar(Class<?> boxedType, Object value) {
        String text = value.toString();
        if (boxedType == Boolean.class) {
            return Boolean.valueOf(text);
        }
        if (boxedType == Character.class) {
            if (text.isEmpty()) {
                throw new IllegalArgumentException("cannot convert an empty string to a char");
            }
            return text.charAt(0);
        }
        if (value instanceof Number) {
            Number number = (Number) value;
            if (boxedType == Integer.class) {
                return number.intValue();
            } else if (boxedType == Short.class) {
                return number.shortValue();
            } else if (boxedType == Long.class) {
                return number.longValue();
            } else if (boxedType == Float.class) {
                return number.floatValue();
            } else if (boxedType == Double.class) {
                return number.doubleValue();
            } else if (boxedType == Byte.class) {
                return number.byteValue();
            }
            return null;
        }
        // Not a JSON number (e.g. "18" sent as a string): parse the text instead of casting.
        String trimmed = text.trim();
        if (boxedType == Integer.class) {
            return Integer.valueOf(trimmed);
        } else if (boxedType == Short.class) {
            return Short.valueOf(trimmed);
        } else if (boxedType == Long.class) {
            return Long.valueOf(trimmed);
        } else if (boxedType == Float.class) {
            return Float.valueOf(trimmed);
        } else if (boxedType == Double.class) {
            return Double.valueOf(trimmed);
        } else if (boxedType == Byte.class) {
            return Byte.valueOf(trimmed);
        }
        return null;
    }

    /**
     * Reports whether the class is one of the primitive wrapper types.
     */
    private boolean isBasicDataTypes(Class<?> clazz) {
        return CLASS_SET.contains(clazz);
    }

    /**
     * Returns the request body as a string, reading it from the request on first use and
     * caching it in a request-scoped attribute for later parameters.
     */
    private String getRequestBody(NativeWebRequest webRequest) {
        // Reuse the cached body when an earlier parameter already read it.
        String jsonBody = (String) webRequest.getAttribute(JSONBODY_ATTRIBUTE, NativeWebRequest.SCOPE_REQUEST);
        if (jsonBody != null) {
            return jsonBody;
        }
        HttpServletRequest servletRequest = webRequest.getNativeRequest(HttpServletRequest.class);
        if (servletRequest == null) {
            throw new IllegalStateException("current request is not an HttpServletRequest");
        }
        try {
            jsonBody = IOUtils.toString(servletRequest.getReader());
            webRequest.setAttribute(JSONBODY_ATTRIBUTE, jsonBody, NativeWebRequest.SCOPE_REQUEST);
        } catch (IOException e) {
            throw new RuntimeException("failed to read request body", e);
        }
        return jsonBody;
    }
}
编写Web配置类
package com.we.applet.wifi.config;
import com.we.applet.wifi.filter.CSRFFilter;
import com.we.applet.wifi.handler.annotion.MultiRequestBodyArgumentResolver;
import com.we.applet.wifi.common.Constant;
import org.springframework.boot.web.servlet.FilterRegistrationBean;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.http.converter.HttpMessageConverter;
import org.springframework.http.converter.StringHttpMessageConverter;
import org.springframework.web.method.support.HandlerMethodArgumentResolver;
import org.springframework.web.servlet.config.annotation.CorsRegistry;
import org.springframework.web.servlet.config.annotation.WebMvcConfigurer;

import javax.annotation.Resource;
import java.nio.charset.StandardCharsets;
import java.util.List;

/**
 * Spring MVC configuration: registers the {@link MultiRequestBodyArgumentResolver} and a
 * UTF-8 {@link StringHttpMessageConverter}.
 *
 * @author 张定辉
 * @since 2022/9/3
 */
@Configuration
public class WebConfig implements WebMvcConfigurer {

    /** Injected via {@code @Resource}; not read by any method in this class. */
    @Resource
    private Constant constant;

    /**
     * Exposes a string converter that uses UTF-8 as a bean.
     *
     * @return the UTF-8 string message converter
     */
    @Bean
    public HttpMessageConverter<String> responseBodyConverter() {
        StringHttpMessageConverter utf8Converter = new StringHttpMessageConverter(StandardCharsets.UTF_8);
        return utf8Converter;
    }

    /**
     * Appends the UTF-8 string converter to the converter list.
     *
     * @param converters the converter list supplied by the framework
     */
    @Override
    public void configureMessageConverters(List<HttpMessageConverter<?>> converters) {
        HttpMessageConverter<String> stringConverter = responseBodyConverter();
        converters.add(stringConverter);
    }

    /**
     * Adds the custom argument resolver that handles the {@code MultiRequestBody} annotation.
     *
     * @param resolvers the resolver list supplied by the framework
     */
    @Override
    public void addArgumentResolvers(List<HandlerMethodArgumentResolver> resolvers) {
        HandlerMethodArgumentResolver multiBodyResolver = new MultiRequestBodyArgumentResolver();
        resolvers.add(multiBodyResolver);
    }
}
应用

请求json示例

//多实体请求
{
    "student":{
        "name":"张三",
        "age":18
    },
    "teacher":{
        "teacherId":"123456",
        "name":"李四"
    }
}

//单实体请求
{
    //不能直接使用属性名的方式,要封装起来
    "student":{
        "name":"张三",
        "age":18
    }
}
@RestController
public class TestCoontroller {
    /**
     * Example endpoint that receives two entities from a single JSON body.
     *
     * <p>No key is given on the annotations, so each parameter name ({@code student},
     * {@code teacher}) is used as the key to look up its data in the JSON.
     */
    @PostMapping("/test")
    public Result<Object> test(@MultiRequestBody Student student, @MultiRequestBody Teacher teacher) {
        /* handling logic omitted */
    }

    // Nested DTOs are static so they can be instantiated without an enclosing controller
    // instance; a JSON library has no such instance available when deserializing.
    static class Student {
        private String name;
        private Integer age;
        /* getters and setters omitted */
    }

    static class Teacher {
        private String teacherId;
        private String name;
        /* getters and setters omitted */
    }
}

封装Request(推荐使用)

自定义BodyReaderHttpServletRequestWrapper继承自HttpServletRequestWrapper
package com.we.applet.wifi.config;

import org.springframework.util.StreamUtils;

import javax.servlet.ReadListener;
import javax.servlet.ServletInputStream;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletRequestWrapper;
import java.io.BufferedReader;
import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.io.InputStreamReader;
import java.nio.charset.Charset;
import java.nio.charset.StandardCharsets;

/**
 * Request wrapper that buffers the body so it can be read more than once.
 *
 * <p>The body of the wrapped request is copied into a byte array in the constructor; every
 * later call to {@link #getInputStream()} or {@link #getReader()} reads from that copy.
 *
 * @author 张定辉
 * @since 2022/11/27
 */
public class BodyReaderHttpServletRequestWrapper extends HttpServletRequestWrapper {
    /**
     * Copy of the request body taken at construction time; all subsequent reads use it.
     */
    private final byte[] requestBody;

    /**
     * Wraps the request and consumes its input stream into memory.
     *
     * @param request the original request
     * @throws IOException if the body cannot be read
     */
    public BodyReaderHttpServletRequestWrapper(HttpServletRequest request) throws IOException {
        super(request);
        requestBody = StreamUtils.copyToByteArray(request.getInputStream());
    }

    /**
     * Returns a fresh stream over the buffered body; each call starts at the beginning.
     */
    @Override
    public ServletInputStream getInputStream() {

        final ByteArrayInputStream bais = new ByteArrayInputStream(requestBody);

        return new ServletInputStream() {

            @Override
            public int read() {
                return bais.read();
            }

            @Override
            public int read(byte[] b, int off, int len) {
                // Delegate bulk reads so callers are not served one byte at a time.
                return bais.read(b, off, len);
            }

            @Override
            public int available() {
                return bais.available();
            }

            @Override
            public boolean isFinished() {
                // Finished once every buffered byte has been consumed.
                return bais.available() == 0;
            }

            @Override
            public boolean isReady() {
                // The data is already in memory, so a read never blocks.
                return true;
            }

            @Override
            public void setReadListener(ReadListener listener) {
                // Read listeners are not used with the in-memory body.
            }
        };
    }

    /**
     * Returns a reader over the buffered body, decoded with the request's character encoding,
     * or UTF-8 when the request does not declare one.
     *
     * @throws IllegalArgumentException if the declared encoding name is illegal or unsupported
     */
    @Override
    public BufferedReader getReader() {
        return new BufferedReader(new InputStreamReader(getInputStream(), resolveCharset()));
    }

    /** Picks the charset for decoding the body instead of relying on the platform default. */
    private Charset resolveCharset() {
        String encoding = getCharacterEncoding();
        if (encoding == null) {
            return StandardCharsets.UTF_8;
        }
        return Charset.forName(encoding);
    }
}
新建过滤器将默认的Request替换为自定义的
@WebFilter(urlPatterns = "/*",filterName = "logFilter")
@Slf4j
public class LogFilter implements Filter {
    /**
     * Replaces the incoming request with a {@link BodyReaderHttpServletRequestWrapper} so the
     * body can be read again further down the chain.
     *
     * <p>Multipart (file upload) requests are passed through unwrapped; for the background see
     * https://blog.csdn.net/weixin_52195362/article/details/135052678?spm=1001.2014.3001.5501
     *
     * @param request  the incoming request
     * @param response the outgoing response
     * @param chain    the remaining filter chain
     * @throws IOException      if buffering the body or a downstream filter fails with I/O
     * @throws ServletException if a downstream filter or servlet fails
     */
    @Override
    public void doFilter(ServletRequest request, ServletResponse response, FilterChain chain) throws IOException, ServletException {
        ServletRequest forwarded = request;
        String contentType = request.getContentType();
        boolean multipart = contentType != null && StringUtils.startsWithIgnoreCase(contentType, "multipart/");
        // Only HTTP requests can be wrapped; anything else is forwarded untouched.
        if (request instanceof HttpServletRequest && !multipart) {
            forwarded = new BodyReaderHttpServletRequestWrapper((HttpServletRequest) request);
        }
        chain.doFilter(forwarded, response);
    }
}
;