Spring MVC从源码分析到实践如何动态添加拦截器
2020-03-30
目前Java的后端开发基本都会用Spring MVC,在SpringBoot的封装里,拦截器如何添加网上教程千篇一律,即使寻找官方的文档,也只有一种比较死板的添加方式,如下:
1 2 3 4 5 6 7 8
| @EnableWebMvc @Configuration public class WebMvcConfig implements WebMvcConfigurer { @Override public void addInterceptors(InterceptorRegistry registry) { registry.addInterceptor(new LogInterceptor()).addPathPatterns("/**").order(1); } }
|
当我要实现在项目启动加载配置时,根据配置添加拦截器时,就无法实现了,官方提供的这种方式比较被动。
下面从源码分析,内部针对请求处理器加载拦截器的流程。
先来看看@EnableWebMvc
加载了什么配置:
1 2 3 4 5 6 7
| @Retention(RetentionPolicy.RUNTIME) @Target(ElementType.TYPE) @Documented // 导入了配置类 @Import(DelegatingWebMvcConfiguration.class) public @interface EnableWebMvc { }
|
跟进DelegatingWebMvcConfiguration.class
:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20
| @Configuration(proxyBeanMethods = false) public class DelegatingWebMvcConfiguration extends WebMvcConfigurationSupport { // WebMvcConfigurer组合类,内部对所有WebMvcConfigurer对象遍历调用(意味着可以存在多个WebMvcConfigurer) private final WebMvcConfigurerComposite configurers = new WebMvcConfigurerComposite();
// 注入所有WebMvcConfigurer类型组件(不强制查找) @Autowired(required = false) public void setConfigurers(List<WebMvcConfigurer> configurers) { if (!CollectionUtils.isEmpty(configurers)) { this.configurers.addWebMvcConfigurers(configurers); } }
// 其它方法略。。。 @Override protected void addInterceptors(InterceptorRegistry registry) { this.configurers.addInterceptors(registry); } // 其它方法略。。。 }
|
通过以上可以发现,我们实现WebMvcConfigurer
的组件会委托给DelegatingWebMvcConfiguration
管理,那么DelegatingWebMvcConfiguration
里的回调又会被调用呢,我们走进它的父类WebMvcConfigurationSupport
,这是一个Spring MVC比较核心的配置类,你会看到它没有父类:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44
| public class WebMvcConfigurationSupport implements ApplicationContextAware, ServletContextAware { // 忽略上面代码。。。
// 这里注册请求映射处理器Bean @Bean public RequestMappingHandlerMapping requestMappingHandlerMapping( @Qualifier("mvcContentNegotiationManager") ContentNegotiationManager contentNegotiationManager, @Qualifier("mvcConversionService") FormattingConversionService conversionService, @Qualifier("mvcResourceUrlProvider") ResourceUrlProvider resourceUrlProvider) { // 创建请求映射处理器(这个比较关键,后面会讲到) RequestMappingHandlerMapping mapping = createRequestMappingHandlerMapping(); mapping.setOrder(0); // 注册拦截器 mapping.setInterceptors(getInterceptors(conversionService, resourceUrlProvider));
// 省略不在本主题范围的设置,防止跑题。。 return mapping; }
protected RequestMappingHandlerMapping createRequestMappingHandlerMapping() { return new RequestMappingHandlerMapping(); }
protected final Object[] getInterceptors( FormattingConversionService mvcConversionService, ResourceUrlProvider mvcResourceUrlProvider) { if (this.interceptors == null) { // InterceptorRegistry是用于注册拦截器的封装类,用于添加拦截路径、拦截顺序等,这个类的用处,后面还会讲到 InterceptorRegistry registry = new InterceptorRegistry(); // 调用勾子方法,添加拦截开发者的拦截器 addInterceptors(registry); // 下面是内置的拦截器 registry.addInterceptor(new ConversionServiceExposingInterceptor(mvcConversionService)); registry.addInterceptor(new ResourceUrlProviderExposingInterceptor(mvcResourceUrlProvider)); this.interceptors = registry.getInterceptors(); } return this.interceptors.toArray(); }
// 添加拦截器勾子方法 protected void addInterceptors(InterceptorRegistry registry) { } // 忽略下面的代码 }
|
通过这个类,我们发现addInterceptors
在这里定义为作为一个勾子方法,所能也可以直接继承这个类并覆盖这个方法来添加拦截器。
既然添加拦截器的关键就是RequestMappingHandlerMapping
这个类,那么就尝试一下在我们项目配置类里注入这个RequestMappingHandlerMapping
的Bean,然后调用它的设置方法:
1 2 3 4 5 6 7 8 9
| @Configuration public class Config { @SuppressWarnings("all") @Autowired public void configRequestMappingHandlerMapping(RequestMappingHandlerMapping requestMappingHandlerMapping) { // 这个直接传HandlerInterceptor的么。。 requestMappingHandlerMapping.setInterceptors(Object...); } }
|
但这个传的居然不是HandlerInterceptor
,那要设置什么值呢,还是走走源码看:
1 2 3 4 5 6 7 8 9 10 11 12 13
| protected final Object[] getInterceptors( FormattingConversionService mvcConversionService, ResourceUrlProvider mvcResourceUrlProvider) { if (this.interceptors == null) { InterceptorRegistry registry = new InterceptorRegistry(); addInterceptors(registry); registry.addInterceptor(new ConversionServiceExposingInterceptor(mvcConversionService)); registry.addInterceptor(new ResourceUrlProviderExposingInterceptor(mvcResourceUrlProvider)); // 通过InterceptorRegistry获取拦截器 this.interceptors = registry.getInterceptors(); } return this.interceptors.toArray(); }
|
从源码来看,是通过这个方法返回的,所以就要看InterceptorRegistry
的getInterceptors
方法:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36
| public class InterceptorRegistry {
private final List<InterceptorRegistration> registrations = new ArrayList<>();
// 我们传进来的HandlerInterceptor被封装成了InterceptorRegistration public InterceptorRegistration addInterceptor(HandlerInterceptor interceptor) { InterceptorRegistration registration = new InterceptorRegistration(interceptor); this.registrations.add(registration); return registration; }
public InterceptorRegistration addWebRequestInterceptor(WebRequestInterceptor interceptor) { WebRequestHandlerInterceptorAdapter adapted = new WebRequestHandlerInterceptorAdapter(interceptor); InterceptorRegistration registration = new InterceptorRegistration(adapted); this.registrations.add(registration); return registration; }
// 返回前内部会对拦截器排序,再调用InterceptorRegistration的getInterceptor方法返回 protected List<Object> getInterceptors() { return this.registrations.stream() .sorted(INTERCEPTOR_ORDER_COMPARATOR) .map(InterceptorRegistration::getInterceptor) .collect(Collectors.toList()); }
private static final Comparator<Object> INTERCEPTOR_ORDER_COMPARATOR = OrderComparator.INSTANCE.withSourceProvider(object -> { if (object instanceof InterceptorRegistration) { return (Ordered) ((InterceptorRegistration) object)::getOrder; } return null; });
}
|
这个类加上注释都不超过100行代码,阅读起来还是比较容易的。好了,顺着这个思路,看看InterceptorRegistration
类的getInterceptor
方法:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17
| public class InterceptorRegistration { // 忽略前面的。。。
protected Object getInterceptor() { if (this.includePatterns.isEmpty() && this.excludePatterns.isEmpty()) { return this.interceptor; }
String[] include = StringUtils.toStringArray(this.includePatterns); String[] exclude = StringUtils.toStringArray(this.excludePatterns); MappedInterceptor mappedInterceptor = new MappedInterceptor(include, exclude, this.interceptor); if (this.pathMatcher != null) { mappedInterceptor.setPathMatcher(this.pathMatcher); } return mappedInterceptor; } }
|
可以看到内部返回的MappedInterceptor
类,也就是说这才是Spring MVC用得上拦截器类型,我们就可以把我们的配置改成这个类型的:
1 2 3 4 5 6 7 8 9
| @Configuration public class Config { @SuppressWarnings("all") @Autowired public void configRequestMappingHandlerMapping(RequestMappingHandlerMapping requestMappingHandlerMapping) { MappedInterceptor mappedInterceptor = new MappedInterceptor(StringUtils.toStringArray(Arrays.asList("/**"), null, interceptor); requestMappingHandlerMapping.setInterceptors(mappedInterceptor); } }
|
这看上去应该可以了吧,但运行后发现拦截器没有生效,看来只能走一下Spring MVC查找请求处理器的流程了。
走流程之前,我们先记一个flag,我们会再回来的,先查看setInterceptors
源码:
1 2 3 4 5 6 7 8
| public abstract class AbstractHandlerMapping extends WebApplicationObjectSupport implements HandlerMapping, Ordered, BeanNameAware { // 忽略前面 public void setInterceptors(Object... interceptors) { this.interceptors.addAll(Arrays.asList(interceptors)); } // 忽略后面 }
|
AbstractHandlerMapping
是RequestMappingHandlerMapping
类的顶级抽象父类,拦截器就是在这里引用着。
回到Spring MVC请求映射处理器获取拦截器的流程,老地方先看org.springframework.web.servlet.DispatcherServlet#doDispatch
方法:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57
| protected void doDispatch(HttpServletRequest request, HttpServletResponse response) throws Exception { // 忽略前面的。。。 try { ModelAndView mv = null; Exception dispatchException = null;
try { processedRequest = checkMultipart(request); multipartRequestParsed = (processedRequest != request);
// 请求处理器是通过getHandler来获得 // Determine handler for the current request. mappedHandler = getHandler(processedRequest); if (mappedHandler == null) { noHandlerFound(processedRequest, response); return; }
// Determine handler adapter for the current request. HandlerAdapter ha = getHandlerAdapter(mappedHandler.getHandler());
// Process last-modified header, if supported by the handler. String method = request.getMethod(); boolean isGet = "GET".equals(method); if (isGet || "HEAD".equals(method)) { long lastModified = ha.getLastModified(request, mappedHandler.getHandler()); if (new ServletWebRequest(request, response).checkNotModified(lastModified) && isGet) { return; } }
if (!mappedHandler.applyPreHandle(processedRequest, response)) { return; }
// Actually invoke the handler. mv = ha.handle(processedRequest, response, mappedHandler.getHandler());
if (asyncManager.isConcurrentHandlingStarted()) { return; }
applyDefaultViewName(processedRequest, mv); mappedHandler.applyPostHandle(processedRequest, response, mv); } catch (Exception ex) { dispatchException = ex; } catch (Throwable err) { // As of 4.3, we're processing Errors thrown from handler methods as well, // making them available for @ExceptionHandler methods and other scenarios. dispatchException = new NestedServletException("Handler dispatch failed", err); } processDispatchResult(processedRequest, response, mappedHandler, mv, dispatchException); } // 忽略后面的。。。 }
|
我们进入getHandler
方法:
1 2 3 4 5 6 7 8 9 10 11 12 13 14
| @Nullable protected HandlerExecutionChain getHandler(HttpServletRequest request) throws Exception { if (this.handlerMappings != null) { // 注意这个HandlerMapping,当前实例可能是RequestMappingHandlerMapping for (HandlerMapping mapping : this.handlerMappings) { // 处理器执行链获取,这是我们找的关键 HandlerExecutionChain handler = mapping.getHandler(request); if (handler != null) { return handler; } } } return null; }
|
再进入HandlerExecutionChain
的getHandler
方法:
1 2 3 4 5 6 7 8
| public abstract class AbstractHandlerMapping extends WebApplicationObjectSupport implements HandlerMapping, Ordered, BeanNameAware { // 忽略前面 public final HandlerExecutionChain getHandler(HttpServletRequest request) throws Exception { // ... } // 忽略后面 }
|
我们又回到AbstractHandlerMapping
类了,上面的flag实现了,下面列出对我们拦截器流程有用的详细信息:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93
| public abstract class AbstractHandlerMapping extends WebApplicationObjectSupport implements HandlerMapping, Ordered, BeanNameAware {
private final List<Object> interceptors = new ArrayList<>();
private final List<HandlerInterceptor> adaptedInterceptors = new ArrayList<>();
// 忽略前面
public void setInterceptors(Object... interceptors) { this.interceptors.addAll(Arrays.asList(interceptors)); }
// 初始化应用上下文时调用 @Override protected void initApplicationContext() throws BeansException { // 扩展拦截器 extendInterceptors(this.interceptors); // 检测MappedInterceptor类型的bean,并添加到映射的拦截器列表 detectMappedInterceptors(this.adaptedInterceptors); // 初始化拦截器 initInterceptors(); } // 将符合规范的拦截器interceptors添加到adaptedInterceptors protected void initInterceptors() { if (!this.interceptors.isEmpty()) { for (int i = 0; i < this.interceptors.size(); i++) { Object interceptor = this.interceptors.get(i); if (interceptor == null) { throw new IllegalArgumentException("Entry number " + i + " in interceptors array is null"); } this.adaptedInterceptors.add(adaptInterceptor(interceptor)); } } }
@Override @Nullable public final HandlerExecutionChain getHandler(HttpServletRequest request) throws Exception { Object handler = getHandlerInternal(request); if (handler == null) { handler = getDefaultHandler(); } if (handler == null) { return null; } // Bean name or resolved handler? if (handler instanceof String) { String handlerName = (String) handler; handler = obtainApplicationContext().getBean(handlerName); }
// 1. 走当前类的getHandlerExecutionChain方法 HandlerExecutionChain executionChain = getHandlerExecutionChain(handler, request);
if (logger.isTraceEnabled()) { logger.trace("Mapped to " + handler); } else if (logger.isDebugEnabled() && !request.getDispatcherType().equals(DispatcherType.ASYNC)) { logger.debug("Mapped to " + executionChain.getHandler()); }
if (hasCorsConfigurationSource(handler) || CorsUtils.isPreFlightRequest(request)) { CorsConfiguration config = (this.corsConfigurationSource != null ? this.corsConfigurationSource.getCorsConfiguration(request) : null); CorsConfiguration handlerConfig = getCorsConfiguration(handler, request); config = (config != null ? config.combine(handlerConfig) : handlerConfig); executionChain = getCorsHandlerExecutionChain(request, executionChain, config); }
return executionChain; }
protected HandlerExecutionChain getHandlerExecutionChain(Object handler, HttpServletRequest request) { HandlerExecutionChain chain = (handler instanceof HandlerExecutionChain ? (HandlerExecutionChain) handler : new HandlerExecutionChain(handler));
String lookupPath = this.urlPathHelper.getLookupPathForRequest(request, LOOKUP_PATH); // 2. 查找拦截器是通过adaptedInterceptors,而不是interceptors for (HandlerInterceptor interceptor : this.adaptedInterceptors) { if (interceptor instanceof MappedInterceptor) { MappedInterceptor mappedInterceptor = (MappedInterceptor) interceptor; if (mappedInterceptor.matches(lookupPath, this.pathMatcher)) { chain.addInterceptor(mappedInterceptor.getInterceptor()); } } else { chain.addInterceptor(interceptor); } } return chain; } // 忽略后面 }
|
能过以上源码的走向:1、2,我们发现拦截是从该类属性adaptedInterceptors
上获取,那我们设置到interceptors
属性上不是泡汤了,事实上确实如些,因为从interceptors
迁移到adaptedInterceptors
的调用时机不受我们掌控,在我们调用调用setInterceptors
设置前,就已经迁移完成,也就说上面的initApplicationContext
调用过了,我们设置的拦截器作废了。
现在的思路已经很明了,我们需要设置拦截器到adaptedInterceptors
属性上,框架才会采纳,既然又没提供set方法,那只能用反射了,虽然反射性能比直接调用方法差一些,还好的是调用在项目启用的时候,下面是动态添加拦截器的工具类方法:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43
| @Slf4j public class SpringMvcPolyfill { // 缓存反射字段 private static Field adaptedInterceptorsField;
/** * 动态添加拦截器 * @param interceptor 拦截器 * @param includeURLs 需要拦截的URL * @param excludeURLs 排除拦截的URL * @param handlerMapping AbstractHandlerMapping实现类 */ @SuppressWarnings("all") public static void addDynamicInterceptor(HandlerInterceptor interceptor, List<String> includeURLs, List<String> excludeURLs, AbstractHandlerMapping handlerMapping) { String[] include = StringUtils.toStringArray(includeURLs); String[] exclude = StringUtils.toStringArray(excludeURLs); // Interceptor -> MappedInterceptor MappedInterceptor mappedInterceptor = new MappedInterceptor(include, exclude, interceptor); // 下面这行可以省略,但为了保持内部的处理流程,使表达式成立:interceptors.count() == adaptedInterceptors.count() handlerMapping.setInterceptors(mappedInterceptor); try { if (adaptedInterceptorsField == null) { // 虽然用了反射,但这些代码在只在启动时加载 // 查找继承链 // RequestMappingHandlerMapping -> RequestMappingInfoHandlerMapping -> AbstractHandlerMethodMapping -> AbstractHandlerMapping // WelcomePageHandlerMapping -> AbstractUrlHandlerMapping -> AbstractHandlerMapping Class<?> abstractHandlerMapping = handlerMapping.getClass(); while (abstractHandlerMapping != AbstractHandlerMapping.class) { abstractHandlerMapping = abstractHandlerMapping.getSuperclass(); } // TODO <mark> 由于使用底层API, 这个AbstractHandlerMapping.adaptedInterceptors很后的版本可能会改 adaptedInterceptorsField = abstractHandlerMapping.getDeclaredField("adaptedInterceptors"); adaptedInterceptorsField.setAccessible(true); } // 添加到可采纳的拦截器列表,让拦截器处理器Chain流程获取得到这个拦截器 List<HandlerInterceptor> handlerInterceptors = (List<HandlerInterceptor>) adaptedInterceptorsField.get(handlerMapping); handlerInterceptors.add(mappedInterceptor); adaptedInterceptorsField.set(handlerMapping, handlerInterceptors); } catch (Exception e) { log.error("SpringMvcPolyfill invoke AbstractHandlerMapping.adaptedInterceptors error with msg: {}", e.getMessage(), e); } } }
|
配置文件:
1 2 3 4 5 6
| @SuppressWarnings("all") @Autowired public void configRequestMappingHandlerMapping(RequestMappingHandlerMapping requestMappingHandlerMapping) { SpringMvcPolyfill.addDynamicInterceptor(new LogInterceptor(), Collections.singletonList("/**"), null, requestMappingHandlerMapping); }
|
该实现在开源项目中已提供:Milkomeda