手动实现一个简单的RPC框架

自定义注解

首先需要创建几个注解用于使用RPC

EnableTensorRPC 开启RPC服务

RpcRegister 将接口注册为一个服务

RpcService 用于标注该接口是一个远程服务

服务接口创建代理类

为了方便,我直接和Spring进行了整合,这里使用了Spring的回调接口BeanPostProcessor来进行上述三个注解的处理工作(当然目前这个很粗糙)

处理 EnableTensorRPC 注解

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
public class EnableTensorRPCBeanPostProcessor implements BeanPostProcessor {

@Override
public Object postProcessBeforeInitialization(Object bean, String beanName) throws BeansException {
EnableTensorRPC rpc = bean.getClass().getAnnotation(EnableTensorRPC.class);
if (rpc == null) {
return bean;
} else {
// 初始化RPC服务信息
RpcApplication.init(rpc);
// 如果自己是Provider,则会开启server,否则什么都不做
if (Objects.equals(rpc.type(), RpcType.PROVIDER)) {
log.info("[TENSOR RPC] work type is provider");
NettyServer.start(rpc);
} else {
log.info("[TENSOR RPC] work type is consumer");
}
}
return bean;
}

@Override
public Object postProcessAfterInitialization(Object bean, String beanName) throws BeansException {
return bean;
}
}

BeanPostProcessor主要的工作就是启动rpc server操作,因为可能项目既作为服务的提供者也作为服务的消费者,但是由于存在仅仅作为消费者的情况,因此做了一定的优化,根据参数来判断是否需要开启server

处理 RpcRegister 注解

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
@Component
public class RpcRegisterInjectBeanPostProcessor implements BeanPostProcessor {

/**
* 接口服务注册
*
* @param bean
* @param beanName
* @return
* @throws BeansException
*/
@Override
public Object postProcessBeforeInitialization(Object bean, String beanName) throws BeansException {
Class cls = bean.getClass();
RpcRegister rpcRegister = (RpcRegister) cls.getAnnotation(RpcRegister.class);
if (rpcRegister != null) {
String serviceName = rpcRegister.serviceName();

// 为该接口服务创建一个方法签名信息

Class<?> registerType = rpcRegister.value();

String key = KeyBuilder.buildServiceKey(serviceName, rpcRegister.ip(), rpcRegister.port());
// 注册服务者信息到ApplicationManager中
ApplicationManager.getRpcInfoManager().providerRegister(registerType.getCanonicalName(), rpcRegister);
ApplicationManager.getNativeMethodManager().register(bean, cls);
}
return bean;
}

@Override
public Object postProcessAfterInitialization(Object bean, String beanName) throws BeansException {
return bean;
}
}

BeanPostProcessor注解是处理对外暴露接口的,这里做的很简单,根据注解所标注的接口,为该接口创建一个签名然后放入一个本地方法池中,而目前的方法签名很简单,就是类的全路径名称

处理 RpcService 注解

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
public class RpcServiceInjectBeanPostProcessor implements BeanPostProcessor {

@Override
public Object postProcessBeforeInitialization(Object bean, String beanName) throws BeansException {
return bean;
}

@Override
public Object postProcessAfterInitialization(Object bean, String beanName) throws BeansException {
Class<?> cls = bean.getClass();
rpcServiceInject(cls, bean);
return bean;
}

/**
* 接口进行代理,对接口的所有操作转为 RPC 远程调用操作
*
* @param cls {@link Class} interface type
* @param bean {@link Object} Spring's bean
*/
private void rpcServiceInject(Class cls, Object bean) {
Field[] fields = cls.getDeclaredFields();
for (Field field : fields) {
field.setAccessible(true);
if (field.isAnnotationPresent(RpcService.class)) {
RpcService rpcService = field.getAnnotation(RpcService.class);
try {
// 创建一个Proxy类进行接口的代理
field.set(bean, RpcInjectProxy.inject(rpcService, field.getType()));
ApplicationManager.getRpcInfoManager().consumerRegister(field.getType().getCanonicalName(), rpcService);
} catch (IllegalAccessException e) {
e.printStackTrace();
}
}
}
}

}

BeanPostProcessor注解是为远程接口创建一个代理类,当服务消费者调用远程接口服务时,实际是由代理类进行远程访问调用服务提供者的方法的

如何正确进行方法的调用

当一个服务消费者调用远程服务时,首先来的是代理类的拦截方法

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
public class RpcCGlibCallBackHandler extends AbstractRpcCallBackHandler {

private String serviceName;

public RpcCGlibCallBackHandler(String serviceName) {
this.serviceName = serviceName;
}

@Override
public Object intercept(Object obj, Method method, Object[] args, MethodProxy proxy) throws Throwable {
Class<?> cls = method.getDeclaringClass();
// 获取服务接口的注解信息(目的就是为了得知服务的接口的全信息,好做调用)
RpcService rpcService = ApplicationManager.getRpcInfoManager().getRpcService(cls.getCanonicalName());
// 创建future进行异步RPC请求
RpcResult future = getMethodInvoker().invoke(methodRequest(cls.getCanonicalName(), method, args), rpcService);
return future.result();
}

// 封装本次RPC请求的信息,请求id、服务者key、请求方法、请求参数、返回类型
private RpcMethodRequest methodRequest(String className, Method method, Object[] args) {
String reqId = UUID.randomUUID().toString();
Class<?> returnType = method.getReturnType();
return RpcMethodRequest.builder()
.reqId(reqId)
.ownerName(className)
.methodName(KeyBuilder.methodSign(method))
.param(args)
.returnType(returnType)
.build();
}

}

然后是进入方法调用链(用于判断走的是RPC还是Native),这里就给出走RPC时的方法吧

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
@Slf4j
public class RpcMethodExecutor implements MethodExecutor {

public RpcMethodExecutor() {
}

@Override
public RpcResult invoke(Invoker invoker, MethodExecutorChain chain) throws InterruptedException {
RpcResult[] future = new RpcResult[1];
Mono.just(invoker)
.map(invoker1 -> {
RpcMethodRequest request = invoker1.getRequest();
if (request == null) {
throw new RuntimeException("[RPC 调用过程出现异常!]");
} else if (Void.class.equals(request.getReturnType())) {
future[0] = RpcResultPool.createFuture(request.getReqId(), true);
} else {
future[0] = RpcResultPool.createFuture(request.getReqId());
}
return invoker1;
})
.publishOn(Schedulers.fromExecutor(RpcSchedule.RpcExecutor.RPC))
.subscribe(this::sendRequest);
return future[0];
}

@Override
public int priority() {
return -1;
}

// 这里会创建一个Netty channel连接(已做了优化,如果存在直接复用,不会每次请求都去创建)
private void sendRequest(Invoker invoker) {
RpcService rpcService = invoker.getService();
RpcMethodRequest request = invoker.getRequest();
String serviceName = rpcService.serviceName();
String ip = rpcService.ip();
int port = rpcService.port();
Channel channel = null;
try {
channel = NettyClient.getConnection(serviceName, ip, port);
channel.writeAndFlush(request).sync();
// 归还本次使用的channel
NettyClient.release(channel, serviceName, ip, port);
} catch (InterruptedException e) {
throw new RuntimeException(e);
}
}

}

接着就是当RPC请求来到了服务提供者时怎么路由到方法了

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
public class NativeMethodExecutor implements MethodExecutor {

public NativeMethodExecutor() {
}

@Override
public RpcResult invoke(Invoker invoker, MethodExecutorChain chain) throws InterruptedException {
// 本地方法
if (invoker.isNative()) {
return innerExec(invoker.getRequest());
}
// 远程RPC请求
if (invoker.isRpcRequest()) {
return innerExec(invoker.getRequest(), invoker.getChannel());
}
return chain.chain(invoker);
}

/**
* 本地方法直接调用
*
* @param msg {@link RpcMethodRequest}
* @return {@link RpcResult}
*/
private RpcResult innerExec(RpcMethodRequest msg) {
RpcResult[] future = new RpcResult[1];
future[0] = RpcResultPool.createFuture(msg.getReqId());
Mono.just(msg)
// 切换执行的线程
.publishOn(Schedulers.fromExecutor(RpcSchedule.RpcExecutor.RPC))
// 根据请求去获取一个Executor用于处理请求
.map(request -> ApplicationManager.getNativeMethodManager().getExecutor(request.getOwnerName()))
.map(f -> f.apply(msg))
// 直接结束方法调用
.subscribe(rpcMethodResponse -> future[0].complete(rpcMethodResponse));
return future[0];
}

/**
* 用于接收RPC Request请求,执行相应的方法
*
* @param msg {@link RpcMethodRequest}
* @param channel {@link Channel}
* @return {@link RpcResult}
*/
private RpcResult innerExec(RpcMethodRequest msg, Channel channel) {
Mono.just(msg)
.map(request -> ApplicationManager.getNativeMethodManager().getExecutor(request.getOwnerName()))
.map(executor -> executor)
.publishOn(Schedulers.fromExecutor(RpcSchedule.RpcExecutor.RPC))
.map(f -> f.apply(msg))
// RpcResponse对象写入channel中返回给调用者
.subscribe(channel::writeAndFlush);
return null;
}

@Override
public int priority() {
return -2;
}

}

这里其实就是用到了之前在创建请求时设置的接口全路径名了以及一个管理本地注册为服务的接口方法管理对象ApplicationManager

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
public class ApplicationManager {

private static final RpcInfoManager RPC_INFO_MANAGER = new RpcInfoManager();
private static final NativeMethodManager NATIVE_METHOD_MANAGER = new NativeMethodManager();

public static RpcInfoManager getRpcInfoManager() {
return RPC_INFO_MANAGER;
}

public static NativeMethodManager getNativeMethodManager() {
return NATIVE_METHOD_MANAGER;
}

// 注解信息的管理
public static class RpcInfoManager {

RpcInfoManager() {}

private final ConcurrentHashMap<String, RpcService> SERVICE_MAP = new ConcurrentHashMap<>();
private final ConcurrentHashMap<String, RpcRegister> REGISTER_MAP = new ConcurrentHashMap<>();

public void consumerRegister(String key, RpcService value) {
SERVICE_MAP.put(key, value);
}

public void providerRegister(String key, RpcRegister value) {
REGISTER_MAP.put(key, value);
}

public RpcService getRpcService(String key) {
return SERVICE_MAP.get(key);
}

public RpcRegister getRpcRegister(String key) {
return REGISTER_MAP.get(key);
}

public boolean containRegisterInfo(String key) {
return REGISTER_MAP.containsKey(key);
}

}

public static class NativeMethodManager {
private final ConcurrentHashMap<String, RegisterInfo> METHOD_EXECUTOR = new ConcurrentHashMap<>();

NativeMethodManager() {}

// 进行服务提供者注册,将服务提供者的所有方法信息进行注册到管理池中
public void register(Object owner, Class<?> cls) {
Class[] interfaces = cls.getInterfaces();
for (Class inter : interfaces) {
METHOD_EXECUTOR.put(inter.getCanonicalName(), new RegisterInfo(owner, inter));
}
}

public Executor getExecutor(String key) {
return new Executor(METHOD_EXECUTOR.get(key));
}

// 接口注册的信息,该接口下的所有方法信息
private static class RegisterInfo {

Object owner;
Map<String, Method> methodMap = new HashMap<>();

RegisterInfo(Object owner, Class<?> cls) {

this.owner = owner;

Method[] methods = cls.getMethods();
for (Method method : methods) {
methodMap.put(KeyBuilder.methodSign(method), method);
}
}

Method getMethod(String sign) {
return methodMap.get(sign);
}

}

public static class Executor implements Function<RpcMethodRequest, RpcMethodResponse> {

RegisterInfo worker;

Executor(RegisterInfo worker) {
this.worker = worker;
}

@Override
public RpcMethodResponse apply(RpcMethodRequest request) {
Exception err = null;
Object result = new Object();
if (worker == null) {
err = new Exception("[Tensor RPC] : Not this function");
} else {
// 方法对象的获取
Method method = worker.getMethod(request.getMethodName());
try {
result = method.invoke(worker.owner, request.getParam());
} catch (IllegalAccessException | InvocationTargetException e) {
err = e;
}
}
// 创建方法返回响应,会带请求id、返回对象序列化、返回错误信息
return RpcMethodResponse
.builder()
.respId(request.getReqId())
.returnVal(GsonSerializer.encode(result))
.returnType(request.getReturnType())
.error(err)
.build();
}
}

public boolean isNative(RpcMethodRequest request) {
String key = request.getOwnerName();
return RPC_INFO_MANAGER.containRegisterInfo(key);
}
}

}

ApplicationManager类管理着本地被注册为@RpcRegister的方法以及所有的@RpcRegister以及@RpcService注解修饰类的信息,对所有RPC请求调用本地方法的处理都交由Executor类来执行

异步RPC调用,如何能够正确接收返回结果

由于每一次RPC请求的发起都会创建一个请求id,以及RPC响应都会讲请求id回带到响应对象中,因此只需要根据请求id找到对应的future对象,complete下即可

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
// Rpc 请求缓存池,每次的Rpc请求生成的RpcResult会放在这里进行缓存,当接收到返回响应时
// 从此处获取对应的RpcResult对象,执行complete操作
public class RpcResultPool {

private static final ConcurrentHashMap<String, RpcResult> RESULT_FUTURE = new ConcurrentHashMap<>();

public static RpcResult getFuture(String key) {
return RESULT_FUTURE.get(key);
}

public static RpcResult createFuture(String key) {
return createFuture(key, false);
}

public static RpcResult createFuture(String key, boolean isVoid) {
RpcResult result = new RpcResult(isVoid);
RESULT_FUTURE.put(key, result);
return result;
}

public static void remove(String key) {
RESULT_FUTURE.remove(key);
}

}

// Rpc 请求结果对象
public class RpcResult extends CompletableFuture<RpcMethodResponse> {

private boolean isVoid = false;

public RpcResult() {
}

public RpcResult(boolean isVoid) {
this.isVoid = isVoid;
}

@Override
public boolean complete(RpcMethodResponse value) {
return super.complete(value);
}

public Object result() throws ExecutionException, InterruptedException, RpcTimeOutException {
if (isVoid) {
return null;
}
RpcMethodResponse response = null;
response = get();
RpcResultPool.remove(response.getRespId());
Class cls = response.getReturnType();
String val = response.getReturnVal();
Throwable error = response.getError();
if (error != null) {
throw new RuntimeException(error);
}
return GsonSerializer.decode(val, cls);
}
}

项目地址

tensor-rpc