如何利用JDK代理以及注解简单实现SpringSecurity的方法权限控制

SpringSecuriry 方法权限控制

在java-web中,经常会涉及权限问题:restful接口的权限访问、数据库接口权限访问、方法权限访问等等,这些权限的问题,如果要手动去实现的话,对于restful接口的权限访问还比较好解决,写一个HttpFilter拦截器,拦截需要验证的url进行相应的权限验证即可;但是方法呢?方法可没有像HttpFiler一样的东西呀,怎么去实现这个方法的权限拦截验证?这时,SpringSecurity解决了这个问题,利用注解以及依赖注入、动态代理,成功的解决了方法的权限拦截这个问题。

利用代理将方法执行方反转

参考以前的一篇文章——如何简单实现JDK的动态代理功能

实现注解扫描

实现了方法的拦截后,还有一个问题,怎么才能知道,那些方法是需要权限验证的?简单的来说,就是怎么去查找到那些被注解的类以及方法。这个时候,就需要使用JDK中线程运行的上下文类环境,通过代码

1
2
String filePath = packageName.replace('.', '/');
Enumeration<URL> dirs = Thread.currentThread().getContextClassLoader().getResources(filePath)

就可以扫描所指定的包,扫描的结果都存放在了Enumeration<URL> dirs中,随后,我们可以通过对URL的协议进行分别处理,在这里,我只简单的处理了——只对file协议进行扫描

1
2
3
4
5
6
while (dirs.hasMoreElements()) {
URL url = dirs.nextElement();
if ("file".equals(url.getProtocol())) {
fileScan(packageName, URLDecoder.decode(url.getFile(), "UTF-8"));
}
}

fileScan方法则是进一布对URL的子目录、文件进行扫描操作,如果遇到的是目录类型,就递归调用fileScan方法,如果发现是以.class结尾的文件,表明找到了类文件,这个时候,就把类的路径(packageName + className)存储至Set集合中,等待下一步的操作

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
protected void fileScan(String packageName, String filePath) {
File dir = new File(filePath);
if (dir.exists()) {
File[] files = dir.listFiles(pathname -> (true && pathname.isDirectory()) || (pathname.getName().endsWith(".class")));
for (File file : files) {
if (file.isDirectory()) {
fileScan(packageName + "." + file.getName(), file.getPath());
} else {
String className = file.getName().substring(0, file.getName().length() - 6);
try {
classes.add(Thread.currentThread().getContextClassLoader().loadClass(packageName + "." + className));
} catch (ClassNotFoundException e) {
e.printStackTrace();
}
}
}
}
}

经过这一步的处理之后,我们就得到了一个装了被注解@EnableSecure以及@Secure所修饰的类的类名的Set集合了,有了类名,就下来就好办了。我们只需要利用java自带的API——Thread.currentThread().getContextClassLoader().loadClass(),根据类名去装载这个类,然后,利用java的反射机制获取这个类下所共用的方法(public域的方法)数组Method[],最后只需要去便利这个方法数组,利用Method所带有的方法getAnnotation()获取被注解@Secure所注解的方法,将这些方法装进List或者Set容器,又或者数组,存储这些方法起来,剩下的就只有JDK动态代理去实现方法的权限拦截验证了。

后期优化

目前实现的只是简单的方法拦截,需要继承代理接口才可以实现方法的拦截,后期打算参考参考依赖注入,实现cglib动态代理

最终代码实现

注解扫描代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
@Data
public class ScanPackage {

private String packageName;
private Set<Class> classes;

/**
* 构造函数,传入要扫描的包名
* @param packageName
*/
public ScanPackage(String packageName) {
this.packageName = packageName;
this.classes = new HashSet<>();
sacn();
}

/**
* 扫描指定包下的所有class文件
*/
protected void sacn() {
String filePath = packageName.replace('.', '/');
try {
Enumeration<URL> dirs = Thread.currentThread().getContextClassLoader()
.getResources(filePath);
while (dirs.hasMoreElements()) {
URL url = dirs.nextElement();
if ("file".equals(url.getProtocol())) {
fileScan(packageName, URLDecoder.decode(url.getFile(), "UTF-8"));
}
}
} catch (IOException e) {
e.printStackTrace();
}
}

/**
* 扫描所有class文件,并将扫描到的class文件装载进Set集合
* @param packageName
* @param filePath
*/
protected void fileScan(String packageName, String filePath) {
File dir = new File(filePath);
if (dir.exists()) {
File[] files = dir.listFiles(pathname -> (true && pathname.isDirectory())
|| (pathname.getName().endsWith(".class")));
for (File file : files) {
if (file.isDirectory()) {
fileScan(packageName + "." + file.getName(), file.getPath());
} else {
String className = file.getName().substring(0, file.getName().length() - 6);
try {
classes.add(Thread.currentThread().getContextClassLoader()
.loadClass(packageName + "." + className));
} catch (ClassNotFoundException e) {
e.printStackTrace();
}
}
}
}
}

}

过滤出被注解的类以及方法

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
@Data
public class MethodPermissionFilter {

private HashMap<String, Method[]> proxyMap;

public MethodPermissionFilter() {
String packageName = Main.class.getPackage().getName();
ScanPackage scanPackage = new ScanPackage(packageName);
proxyMap = getEnableSecureClass(scanPackage.getClasses());
}

/**
* 利用Java原生的lambda表达式,构建{@link HashMap<String, Method[]>} String => 方法类名,Method[] => 该类的所有方法
* @param classes
* @return
*/
private HashMap<String, Method[]> getEnableSecureClass(Set<Class> classes) {
return classes.stream().filter(cls -> cls.getAnnotation(EnableSecure.class) != null)
.collect(HashMap::new, (m, v) -> m.put(v.getName(), v.getMethods()), HashMap::putAll);
}

/**
* 方法权限验证,如果所拥有的权限超出方法所允许的权限则抛出异常
* @param clsName
* @param meth
* @param r
* @return
*/
public boolean isAuthority(String clsName, Method meth, String r) {
Method[] methods = proxyMap.get(clsName);
for (Method method : methods) {
if (meth.getName().equals(method.getName())) {
Secure secure = method.getAnnotation(Secure.class);
String role = secure.role();
if (r.equals(role)) {
return true;
}
}
}
return false;
}

}

方法拦截

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
public class MyInvokationHandler<E> implements InvocationHandler {

private E target;
private String role;

private MethodPermissionFilter methodPermissionFilter;

public MyInvokationHandler(E target, String role) {
this.target = target;
this.role = role;
this.methodPermissionFilter = new MethodPermissionFilter();
}

@Override
public Object invoke(Object proxy, Method method, Object[] args) throws Throwable {
if (methodPermissionFilter.isAuthority(target.getClass().getName(), method, role)) {
return method.invoke(target, args);
}
throw new RuntimeException("You have no permission use this function");
}
}

代理工厂

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
@Data
public class JDKFactory {

private InvocationHandler handler;
private Object proxy;

/**
* JDK代理实现
* @param proxy
* @param role
*/
public JDKFactory(Object proxy, String role) {
this.proxy = proxy;
this.handler = new MyInvokationHandler<>(proxy, role);
this.proxy = (Object) Proxy.newProxyInstance(proxy.getClass().getClassLoader(),
new Class[]{proxy.getClass()}, handler);
}
}

Cglib 动态代理实现Spring依赖注入(2018-9-20补充)

之前用的都是JDK自带的方法代理拦截,虽然实现简单,但是有一个缺点,就是我们需要每个service都需要去implements一个接口才可以实现JDK的原生代理功能;在使用spring的依赖注入时,我们可能没有继承接口,那么这个时候,就没办法使用JDK的原生代理功能了,这个时候就需要使用CGLib动态代理了。什么是CGLib

获取CGLib代理的对象

1
2
3
4
5
6
7
8
9
10
11
public class CglibProxy {

private static final Enhancer enhancer = new Enhancer();

public <T> T getProxy(Object target) {
enhancer.setSuperclass(target.getClass());
enhancer.setCallback(new CglibMethodInterceptor());
return (T) enhancer.create();
}

}

CGLib方法拦截操作

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
public class CglibMethodInterceptor implements MethodInterceptor {

@Override
public Object intercept(Object obj, Method method, Object[] args, MethodProxy proxy) throws Throwable {
if (obj.getClass().equals(method.getDeclaringClass())) {
return method.invoke(obj, args);
}
String methodName = method.getName();
System.out.println("[before] The method " + methodName + " begins with " + (args!=null ? Arrays.asList(args) : "[]"));
Object result = proxy.invokeSuper(obj, args);
System.out.println(String.format("after method:%s execute", method.getName()));
System.out.println("[after] The method ends with " + result);
return result;
}
}

将代理对象注入到类中

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
@Data
public class ServiceFilter {

private HashMap<Class, Object> serviceMap;
private HashMap<String, Object> autowiredMap;
private CglibProxy cglibProxy;

public ServiceFilter() {
cglibProxy = new CglibProxy();
String packageName = Main.class.getPackage().getName();
ScanPackage scanPackage = new ScanPackage(packageName);
serviceMap = getServiceClasses(scanPackage.getClasses());
setServiceFields();
}

public HashMap getServiceClasses(Set<Class> classes) {
return classes.stream().filter(c -> c.isAnnotationPresent(Service.class))
.collect(HashMap::new, (m, v) -> {
try {
if (v.getInterfaces().length > 0) {
m.put(v.getInterfaces()[0], cglibProxy.getProxy(v.newInstance()));
} else {
m.put(v, v.newInstance());
}
} catch (InstantiationException | IllegalAccessException e) {
e.printStackTrace();
}
}, HashMap::putAll);
}

/**
* 将CGLib代理得到的对象注入到类中,实现Spring的@Autowired注解,从而实现IoC控制反转,将对象的生成转交
* 利用Field的set方法实现属性的赋值
*/
public void setServiceFields() {
serviceMap.keySet().forEach(cls -> {
Field[] fields = cls.getDeclaredFields();
for (Field field : fields) {
field.setAccessible(true);
if (field.isAnnotationPresent(Autowired.class)) {
try {
field.set(serviceMap.get(cls), serviceMap.get(field.getType()));
} catch (IllegalAccessException e) {
e.printStackTrace();
}
}
}
});
}
}

测试用例

1
2
3
4
5
6
7
8
9
10
11
12
13
@Service(value = "Main")
public class Main {

@Autowired
private TestService testService;

public static void main(String[] args) {
ServiceFilter serviceFilter = new ServiceFilter();
Main main = (Main) serviceFilter.getServiceMap().get(Main.class);
main.testService.print();
}

}

源码

github-spring-demo