Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: support get pattern from Spring requestMappingInfo #534

Merged
merged 7 commits into from
Jul 15, 2024
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package io.arex.agent.bootstrap.util;

import java.util.Objects;
import java.util.function.Function;

public class ArrayUtils {
Expand Down Expand Up @@ -51,4 +52,22 @@ public static String toString(Object[] array, Function<Object, String> parser) {
builder.append("\", \"");
}
}

public static boolean equals(String[] array1, String[] array2) {
if (array1 == array2) {
return true;
}
if (array1 == null || array2 == null) {
return false;
}
if (array1.length != array2.length) {
return false;
}
for (int i = 0; i < array1.length; i++) {
if (!Objects.equals(array1[i], array2[i])) {
return false;
}
}
return true;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -95,4 +95,49 @@ public static <T> Constructor<T> getConstructor(Class<T> clazz, Class<?>... argT
return null;
}
}

/**
* The caller not need to consider the issue of dependency class and method version differences.
*/
public static Method getMethodWithoutClassType(String className, String methodName, String... argTypes) {
try {
Class<?> clazz = Class.forName(className, false, Thread.currentThread().getContextClassLoader());
Method[] methods = clazz.getDeclaredMethods();
for (Method method : methods) {
if (methodName.equals(method.getName()) && sameParamType(argTypes, method.getParameterTypes())) {
return method;
}
}
} catch (Exception e) {
// ignore exception
}
return null;
}

private static boolean sameParamType(String[] argTypeArray, Class<?>[] parameterTypes) {
if (ArrayUtils.isEmpty(argTypeArray) && ArrayUtils.isEmpty(parameterTypes)) {
return true;
}
List<String> parameterTypeList = new ArrayList<>();
if (ArrayUtils.isNotEmpty(parameterTypes)) {
for (Class<?> parameterType : parameterTypes) {
parameterTypeList.add(parameterType.getName());
}
}
return ArrayUtils.equals(argTypeArray, parameterTypeList.toArray(StringUtil.EMPTY_STRING_ARRAY));
}

public static Object invoke(Method method, Object instance, Object... args) {
if (method == null) {
return null;
}
try {
method.setAccessible(true);
return method.invoke(instance, args);
} catch (Exception e) {
return null;
} finally {
method.setAccessible(false);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -51,5 +51,21 @@ static Stream<Arguments> ArrayUtilsToStringCase() {
);
}

@ParameterizedTest
@MethodSource("equalsCase")
void equals(String[] array1, String[] array2, Predicate<Boolean> asserts) {
asserts.test(ArrayUtils.equals(array1, array2));
}

static Stream<Arguments> equalsCase() {
Predicate<Boolean> assertTrue = bool -> bool;
Predicate<Boolean> assertFalse = bool -> !bool;
return Stream.of(
arguments(null, null, assertTrue),
arguments(new String[]{"mock"}, null, assertFalse),
arguments(new String[]{"mock"}, new String[]{}, assertFalse),
arguments(new String[]{"mock"}, new String[]{"mock1"}, assertFalse),
arguments(new String[]{"mock"}, new String[]{"mock"}, assertTrue)
);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
import java.util.LinkedList;
import java.util.TreeSet;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.CsvSource;

import static org.junit.jupiter.api.Assertions.*;

Expand Down Expand Up @@ -59,4 +61,29 @@ void getMethod() {
void getConstructor() {
assertNotNull(ReflectUtil.getConstructor(String.class));
}

@ParameterizedTest
@CsvSource(value ={
"java.lang.String, indexOf, int",
"java.lang.String, hashCode, null",
"java.lang.String, noExistMethod, null"
}, nullValues={"null"})
void getMethodWithoutClassType(String className, String methodName, String argTypes) {
if (argTypes == null) {
assertDoesNotThrow(() -> ReflectUtil.getMethodWithoutClassType(className, methodName));
} else {
assertDoesNotThrow(() -> ReflectUtil.getMethodWithoutClassType(className, methodName, argTypes));
}
}

@ParameterizedTest
@CsvSource(value ={
"null, instance, param",
"hashCode, instance, param",
"hashCode, instance, param"
}, nullValues={"null"})
void invoke(String methodName, String instance, String param) throws Exception {
assertDoesNotThrow(() ->
ReflectUtil.invoke(methodName == null ? null : String.class.getDeclaredMethod(methodName), instance, param));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -237,15 +237,15 @@ private static <TRequest> boolean shouldSkip(ServletAdapter<TRequest, ?> adapter
if (StringUtil.isEmpty(requestURI)) {
return false;
}

String pattern = adapter.getPattern(httpServletRequest);
// As long as one parameter is hit in includeServiceOperations, the operation will not be skipped
if (CollectionUtil.isNotEmpty(Config.get().getIncludeServiceOperations()) &&
!(IgnoreUtils.includeOperation(adapter.getPattern(httpServletRequest)) ||
!(IgnoreUtils.includeOperation(pattern) ||
IgnoreUtils.includeOperation(requestURI))) {
return true;
}
// As long as one parameter is hit in excludeServiceOperations, the operation will be skipped
if (IgnoreUtils.excludeOperation(adapter.getPattern(httpServletRequest)) ||
if (IgnoreUtils.excludeOperation(pattern) ||
IgnoreUtils.excludeOperation(requestURI)) {
return true;
}
Expand All @@ -261,7 +261,7 @@ private static <TRequest> boolean shouldSkip(ServletAdapter<TRequest, ?> adapter
return true;
}

return Config.get().invalidRecord(requestURI);
return Config.get().invalidRecord(pattern);
}

private static <TRequest, TResponse> String getRedirectRecordId(ServletAdapter<TRequest, TResponse> adapter,
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
package io.arex.inst.httpservlet;

import io.arex.agent.bootstrap.util.CollectionUtil;
import io.arex.agent.bootstrap.util.ReflectUtil;
import org.springframework.context.ApplicationContext;
import org.springframework.web.context.support.WebApplicationContextUtils;
import org.springframework.web.servlet.mvc.method.RequestMappingInfo;
import org.springframework.web.servlet.mvc.method.RequestMappingInfoHandlerMapping;
import org.springframework.web.util.ServletRequestPathUtils;
import org.springframework.web.util.pattern.PathPattern;

import java.lang.reflect.Method;
import java.util.Set;

/**
* The caller should consider the issue of Spring version differences that may throw exceptions.
* (NoSuchMethodException, NoSuchMethodError, etc.)
*/
public class SpringUtil {

private static Method getMatchMethod = null;

private static Method parseAndCacheMethod = null;

private static ApplicationContext springApplicationContext = null;

private static Set<RequestMappingInfo> springRequestMappings = null;

public static ApplicationContext getApplicationContext(Object servletContext) {
if (springApplicationContext != null) {
return springApplicationContext;
}
Method findWebApplicationContextMethod = ReflectUtil.getMethodWithoutClassType(WebApplicationContextUtils.class.getName(),
"findWebApplicationContext", servletContext.getClass().getInterfaces()[0].getName());
springApplicationContext = (ApplicationContext) ReflectUtil.invoke(findWebApplicationContextMethod, null, servletContext);
return springApplicationContext;
}

public static Set<RequestMappingInfo> getAllRequestMappingInfo(Object servletContext) {
if (springRequestMappings == null) {
ApplicationContext applicationContext = getApplicationContext(servletContext);
if (applicationContext == null) {
return null;
}
RequestMappingInfoHandlerMapping handlerMapping = applicationContext.getBean(RequestMappingInfoHandlerMapping.class);
springRequestMappings = handlerMapping.getHandlerMethods().keySet();
}
return springRequestMappings;
}

public static String getPatternFromRequestMapping(Object httpServletRequest, Object servletContext) {
Set<RequestMappingInfo> requestMappingInfos = getAllRequestMappingInfo(servletContext);
if (CollectionUtil.isEmpty(requestMappingInfos)) {
return null;
}

setPathAttribute(httpServletRequest);

if (getMatchMethod == null) {
getMatchMethod = ReflectUtil.getMethodWithoutClassType(RequestMappingInfo.class.getName(),
"getMatchingCondition", httpServletRequest.getClass().getInterfaces()[0].getName());
}

Set<PathPattern> patterns = null;
for (RequestMappingInfo mappingInfo : requestMappingInfos) {
/*
* execute RequestMappingInfo.getMatchingCondition to get matched pattern
* eg: /user/lucas/22 -> /user/{name}/{age}
*/
RequestMappingInfo matchMapping = (RequestMappingInfo) ReflectUtil.invoke(getMatchMethod, mappingInfo, httpServletRequest);
if (matchMapping != null && matchMapping.getPathPatternsCondition() != null) {
patterns = matchMapping.getPathPatternsCondition().getPatterns();
break;
}
}

if (CollectionUtil.isEmpty(patterns)) {
return null;
}
return patterns.iterator().next().getPatternString();
}

private static void setPathAttribute(Object httpServletRequest) {
if (parseAndCacheMethod == null) {
parseAndCacheMethod = ReflectUtil.getMethodWithoutClassType(ServletRequestPathUtils.class.getName(),
"parseAndCache", httpServletRequest.getClass().getInterfaces()[0].getName());
}
// execute ServletRequestPathUtils.parseAndCache, because getMatchingCondition need path attribute from request
ReflectUtil.invoke(parseAndCacheMethod, null, httpServletRequest);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import io.arex.agent.bootstrap.util.IOUtils;
import io.arex.agent.bootstrap.util.StringUtil;
import io.arex.inst.httpservlet.SpringUtil;
import io.arex.inst.httpservlet.adapter.ServletAdapter;
import io.arex.inst.httpservlet.wrapper.CachedBodyRequestWrapperV3;
import io.arex.inst.httpservlet.wrapper.CachedBodyResponseWrapperV3;
Expand Down Expand Up @@ -138,18 +139,36 @@ public String getRequestURI(HttpServletRequest httpServletRequest) {

@Override
public String getPattern(HttpServletRequest httpServletRequest) {
// in org.springframework.web.servlet.DispatcherServlet#doService set pattern attribute
Object pattern = httpServletRequest.getAttribute("org.springframework.web.servlet.HandlerMapping.bestMatchingPattern");
if (pattern != null) {
return String.valueOf(pattern);
}
/*
* if can't get pattern attribute from request, try to get pattern from request mapping in spring applicationContext
* maybe called in filter before DispatcherServlet#doService (filter -> service)
*/
String patternStr = getPatternFromRequestMapping(httpServletRequest);
if (StringUtil.isNotEmpty(patternStr)) {
return patternStr;
}
final String requestURI = httpServletRequest.getRequestURI();
if (StringUtil.isNotEmpty(httpServletRequest.getContextPath()) && requestURI.contains(
httpServletRequest.getContextPath())) {
httpServletRequest.getContextPath())) {
return requestURI.replace(httpServletRequest.getContextPath(), "");
}
return requestURI;
}

public String getPatternFromRequestMapping(HttpServletRequest httpServletRequest) {
try {
return SpringUtil.getPatternFromRequestMapping(httpServletRequest, httpServletRequest.getServletContext());
} catch (Throwable ignore) {
// ignore exception
}
return null;
}

@Override
public String getResponseHeader(HttpServletResponse httpServletResponse, String name) {
return httpServletResponse.getHeader(name);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import io.arex.agent.bootstrap.util.IOUtils;
import io.arex.agent.bootstrap.util.StringUtil;
import io.arex.inst.httpservlet.SpringUtil;
import io.arex.inst.httpservlet.adapter.ServletAdapter;
import io.arex.inst.httpservlet.listener.ServletAsyncListenerV5;
import io.arex.inst.httpservlet.wrapper.CachedBodyRequestWrapperV5;
Expand Down Expand Up @@ -138,18 +139,36 @@ public String getRequestURI(HttpServletRequest httpServletRequest) {

@Override
public String getPattern(HttpServletRequest httpServletRequest) {
// in org.springframework.web.servlet.DispatcherServlet#doService set pattern attribute
Object pattern = httpServletRequest.getAttribute("org.springframework.web.servlet.HandlerMapping.bestMatchingPattern");
if (pattern != null) {
return String.valueOf(pattern);
}
/*
* if can't get pattern attribute from request, try to get pattern from request mapping in spring applicationContext
* maybe called in filter before DispatcherServlet#doService (filter -> service)
*/
String patternStr = getPatternFromRequestMapping(httpServletRequest);
if (StringUtil.isNotEmpty(patternStr)) {
return patternStr;
}
final String requestURI = httpServletRequest.getRequestURI();
if (StringUtil.isNotEmpty(httpServletRequest.getContextPath()) && requestURI.contains(
httpServletRequest.getContextPath())) {
httpServletRequest.getContextPath())) {
return requestURI.replace(httpServletRequest.getContextPath(), "");
}
return requestURI;
}

public String getPatternFromRequestMapping(HttpServletRequest httpServletRequest) {
try {
return SpringUtil.getPatternFromRequestMapping(httpServletRequest, httpServletRequest.getServletContext());
} catch (Throwable ignore) {
// ignore exception
}
return null;
}

@Override
public String getResponseHeader(HttpServletResponse httpServletResponse, String name) {
return httpServletResponse.getHeader(name);
Expand Down
Loading