Custom annotations
First, we need to create a few annotations for using the RPC framework:
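EnableTensorRPC: enables the RPC service
RpcRegister: registers a bean as the provider of an interface's service
RpcService: marks a field of interface type as a remote service to be called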
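The annotation definitions themselves are not shown. The sketch below reconstructs them from how the post-processors later in this post use them (type(), value(), serviceName(), ip(), port()). The RpcType.CONSUMER constant, the default values, and the usage classes (ProviderApp, HelloService, HelloServiceImpl, HelloConsumer) are assumptions for illustration. One detail is not optional: the annotations need RUNTIME retention, because the post-processors read them via reflection with getAnnotation().

```java
import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;

// Assumption: only PROVIDER appears in the article; CONSUMER is inferred from its log message
enum RpcType { PROVIDER, CONSUMER }

// RUNTIME retention is required: the post-processors call getAnnotation() on these
@Retention(RetentionPolicy.RUNTIME)
@Target(ElementType.TYPE)
@interface EnableTensorRPC {
    RpcType type() default RpcType.CONSUMER; // default value is an assumption
}

@Retention(RetentionPolicy.RUNTIME)
@Target(ElementType.TYPE)
@interface RpcRegister {
    Class<?> value();                 // the interface exposed as a service
    String serviceName();
    String ip() default "127.0.0.1";  // defaults are assumptions
    int port() default 8080;
}

@Retention(RetentionPolicy.RUNTIME)
@Target(ElementType.FIELD)
@interface RpcService {
    String serviceName();
    String ip() default "127.0.0.1";
    int port() default 8080;
}

// Hypothetical usage
@EnableTensorRPC(type = RpcType.PROVIDER)
class ProviderApp { }

interface HelloService {
    String hello(String name);
}

@RpcRegister(value = HelloService.class, serviceName = "hello")
class HelloServiceImpl implements HelloService {
    @Override
    public String hello(String name) {
        return "hello " + name;
    }
}

class HelloConsumer {
    @RpcService(serviceName = "hello")
    HelloService helloService; // a post-processor would replace this with a proxy
}
```

In this layout the provider annotates its implementation class, while the consumer only declares an interface-typed field and leaves the wiring to the framework.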
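Creating proxy classes for service interfaces
For convenience, I integrated the framework directly with Spring and used Spring's BeanPostProcessor callback interface to process the three annotations above (admittedly, this is still quite rough).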
Handling the EnableTensorRPC annotation
```java
import java.util.Objects;

import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.BeansException;
import org.springframework.beans.factory.config.BeanPostProcessor;

// EnableTensorRPC, RpcType, RpcApplication and NettyServer are tensor-rpc classes
@Slf4j // provides the `log` field used below
public class EnableTensorRPCBeanPostProcessor implements BeanPostProcessor {

    @Override
    public Object postProcessBeforeInitialization(Object bean, String beanName) throws BeansException {
        EnableTensorRPC rpc = bean.getClass().getAnnotation(EnableTensorRPC.class);
        if (rpc == null) {
            // Not the bean carrying @EnableTensorRPC: nothing to do
            return bean;
        }
        RpcApplication.init(rpc);
        if (Objects.equals(rpc.type(), RpcType.PROVIDER)) {
            log.info("[TENSOR RPC] work type is provider");
            // Only a provider needs to listen for incoming RPC requests
            NettyServer.start(rpc);
        } else {
            log.info("[TENSOR RPC] work type is consumer");
        }
        return bean;
    }

    @Override
    public Object postProcessAfterInitialization(Object bean, String beanName) throws BeansException {
        return bean;
    }
}
```
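The main job of this BeanPostProcessor is to start the RPC server. A project may act as both a service provider and a service consumer, but it may also be a consumer only, in which case no server is needed. So the processor checks the type attribute of the annotation and starts the server only when it is RpcType.PROVIDER.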
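Spring calls a BeanPostProcessor for every bean it creates, so the annotation check above runs once per bean; if more than one bean carried @EnableTensorRPC, NettyServer.start would be called more than once. The following is a minimal, hypothetical sketch of the provider/consumer switch with a start-at-most-once guard. WorkType, RpcBootstrapSketch and the serverStarter Runnable (standing in for NettyServer.start) are illustrative names, not tensor-rpc types.

```java
import java.util.concurrent.atomic.AtomicBoolean;

// Hypothetical stand-in for RpcType
enum WorkType { PROVIDER, CONSUMER }

class RpcBootstrapSketch {
    private final AtomicBoolean started = new AtomicBoolean(false);
    private final Runnable serverStarter; // stands in for NettyServer.start(rpc)

    RpcBootstrapSketch(Runnable serverStarter) {
        this.serverStarter = serverStarter;
    }

    // Returns true only if this call actually started the server
    boolean onEnableAnnotation(WorkType type) {
        if (type != WorkType.PROVIDER) {
            return false; // consumer-only process: no server needed
        }
        // compareAndSet lets exactly one caller flip false -> true, even across threads
        if (started.compareAndSet(false, true)) {
            serverStarter.run();
            return true;
        }
        return false;
    }
}
```

A second PROVIDER annotation then becomes a no-op instead of trying to bind the same port again.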
Handling the RpcRegister annotation
```java
import org.springframework.beans.BeansException;
import org.springframework.beans.factory.config.BeanPostProcessor;
import org.springframework.stereotype.Component;

@Component
public class RpcRegisterInjectBeanPostProcessor implements BeanPostProcessor {

    @Override
    public Object postProcessBeforeInitialization(Object bean, String beanName) throws BeansException {
        Class<?> cls = bean.getClass();
        RpcRegister rpcRegister = cls.getAnnotation(RpcRegister.class);
        if (rpcRegister != null) {
            String serviceName = rpcRegister.serviceName();
            // The interface this bean exposes as a remote service
            Class<?> registerType = rpcRegister.value();
            // Service key built from name/ip/port (not used further yet)
            String key = KeyBuilder.buildServiceKey(serviceName, rpcRegister.ip(), rpcRegister.port());
            // Record the provider metadata under the interface's fully qualified name
            ApplicationManager.getRpcInfoManager().providerRegister(registerType.getCanonicalName(), rpcRegister);
            // Put the bean's interface methods into the local method pool
            ApplicationManager.getNativeMethodManager().register(bean, cls);
        }
        return bean;
    }

    @Override
    public Object postProcessAfterInitialization(Object bean, String beanName) throws BeansException {
        return bean;
    }
}
```
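This BeanPostProcessor handles the interfaces a bean exposes to other services. It is deliberately simple: for every interface the annotated bean class directly implements, it builds a key and puts the bean and that interface's methods into a local method pool. For now the key is simply the interface's fully qualified class name; inside each entry, individual methods are looked up by KeyBuilder.methodSign.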
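KeyBuilder.methodSign is not shown in the article. Because each interface's methods are stored in a map keyed by that signature, the signature has to distinguish overloaded methods, otherwise one overload would silently replace another. The sketch below assumes a signature made of the method name plus its parameter types; it illustrates the idea and is not tensor-rpc's actual format. MethodSignSketch and Calculator are hypothetical names.

```java
import java.lang.reflect.Method;
import java.util.HashMap;
import java.util.Map;
import java.util.StringJoiner;

final class MethodSignSketch {
    // Hypothetical signature: name plus parameter types, e.g. "add(int,int)"
    static String methodSign(Method method) {
        StringJoiner params = new StringJoiner(",", method.getName() + "(", ")");
        for (Class<?> p : method.getParameterTypes()) {
            params.add(p.getCanonicalName());
        }
        return params.toString();
    }

    // Index every public method of an exposed interface by its signature
    static Map<String, Method> index(Class<?> iface) {
        Map<String, Method> map = new HashMap<>();
        for (Method m : iface.getMethods()) {
            map.put(methodSign(m), m);
        }
        return map;
    }
}

// Two overloads that a name-only key could not tell apart
interface Calculator {
    int add(int a, int b);
    double add(double a, double b);
}
```

With this scheme, index(Calculator.class) keeps both overloads under "add(int,int)" and "add(double,double)".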
Handling the RpcService annotation
```java
import java.lang.reflect.Field;

import org.springframework.beans.BeansException;
import org.springframework.beans.factory.BeanInitializationException;
import org.springframework.beans.factory.config.BeanPostProcessor;

public class RpcServiceInjectBeanPostProcessor implements BeanPostProcessor {

    @Override
    public Object postProcessBeforeInitialization(Object bean, String beanName) throws BeansException {
        return bean;
    }

    @Override
    public Object postProcessAfterInitialization(Object bean, String beanName) throws BeansException {
        rpcServiceInject(bean.getClass(), bean);
        return bean;
    }

    private void rpcServiceInject(Class<?> cls, Object bean) {
        // Only fields declared directly on this class are scanned (not superclass fields)
        for (Field field : cls.getDeclaredFields()) {
            RpcService rpcService = field.getAnnotation(RpcService.class);
            if (rpcService == null) {
                continue;
            }
            field.setAccessible(true);
            try {
                // Replace the field value with a proxy that performs the remote call
                field.set(bean, RpcInjectProxy.inject(rpcService, field.getType()));
                // Remember the @RpcService metadata under the interface's canonical name
                ApplicationManager.getRpcInfoManager().consumerRegister(field.getType().getCanonicalName(), rpcService);
            } catch (IllegalAccessException e) {
                // Fail fast instead of silently leaving the field null
                throw new BeanInitializationException("Failed to inject @RpcService field " + field.getName(), e);
            }
        }
    }
}
```
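This BeanPostProcessor creates a proxy for each remote interface. When a service consumer calls a method on the remote interface, it is actually the proxy that makes the remote call to the provider's method.
How method calls are dispatched correctly
When a service consumer calls a remote service, the call first arrives at the proxy's interception method: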
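RpcInjectProxy.inject is not shown, but the interceptor class in the next section suggests it builds a CGLIB proxy. As a self-contained illustration, the sketch below swaps in the JDK's java.lang.reflect.Proxy instead, which is enough here because @RpcService fields are typed as interfaces. Each call becomes a request carrying a fresh id, the interface name and the method name, which is handed to a transport function. SketchRequest, ProxySketch and Greeter are hypothetical; a real transport would send the request over the network and wait for the matching response.

```java
import java.lang.reflect.InvocationHandler;
import java.lang.reflect.Proxy;
import java.util.UUID;
import java.util.function.Function;

// Hypothetical request shape, loosely modeled on the fields of RpcMethodRequest
final class SketchRequest {
    final String reqId;
    final String ownerName;
    final String methodName;
    final Object[] params;

    SketchRequest(String reqId, String ownerName, String methodName, Object[] params) {
        this.reqId = reqId;
        this.ownerName = ownerName;
        this.methodName = methodName;
        this.params = params;
    }
}

final class ProxySketch {
    // Returns a proxy for iface whose calls become SketchRequests handed to transport
    @SuppressWarnings("unchecked")
    static <T> T create(Class<T> iface, Function<SketchRequest, Object> transport) {
        InvocationHandler handler = (proxy, method, args) -> {
            // Object methods are answered locally instead of being sent over the wire
            if (method.getDeclaringClass() == Object.class) {
                switch (method.getName()) {
                    case "equals":
                        return proxy == args[0];
                    case "hashCode":
                        return System.identityHashCode(proxy);
                    default:
                        return iface.getName() + "$RpcProxy";
                }
            }
            SketchRequest request = new SketchRequest(
                    UUID.randomUUID().toString(),    // fresh id per call
                    iface.getCanonicalName(),        // routing key on the provider side
                    method.getName(),                // the article uses KeyBuilder.methodSign here
                    args == null ? new Object[0] : args);
            return transport.apply(request);
        };
        return (T) Proxy.newProxyInstance(iface.getClassLoader(), new Class<?>[] {iface}, handler);
    }
}

interface Greeter {
    String greet(String name);
}
```

The JDK proxy only works for interfaces; CGLIB, as used in the article, can also subclass concrete classes, at the cost of an extra dependency.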
```java
import java.lang.reflect.Method;
import java.util.UUID;

import net.sf.cglib.proxy.MethodProxy; // or org.springframework.cglib.proxy.MethodProxy with Spring's bundled CGLIB

public class RpcCGlibCallBackHandler extends AbstractRpcCallBackHandler {

    private String serviceName;

    public RpcCGlibCallBackHandler(String serviceName) {
        this.serviceName = serviceName;
    }

    @Override
    public Object intercept(Object obj, Method method, Object[] args, MethodProxy proxy) throws Throwable {
        Class<?> cls = method.getDeclaringClass();
        // Look up the @RpcService metadata recorded at injection time
        RpcService rpcService = ApplicationManager.getRpcInfoManager().getRpcService(cls.getCanonicalName());
        RpcResult future = getMethodInvoker().invoke(methodRequest(cls.getCanonicalName(), method, args), rpcService);
        // Blocks until the response with the matching request id arrives
        return future.result();
    }

    private RpcMethodRequest methodRequest(String className, Method method, Object[] args) {
        // A fresh id per call lets the response be matched back to this request
        String reqId = UUID.randomUUID().toString();
        Class<?> returnType = method.getReturnType();
        return RpcMethodRequest.builder()
                .reqId(reqId)
                .ownerName(className)
                .methodName(KeyBuilder.methodSign(method))
                .param(args)
                .returnType(returnType)
                .build();
    }
}
```
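The request then enters the method executor chain, which decides whether the call goes over RPC or is executed natively in the local process. Here is the executor used on the RPC path: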
```java
import io.netty.channel.Channel;
import lombok.extern.slf4j.Slf4j;
import reactor.core.publisher.Mono;
import reactor.core.scheduler.Schedulers;

@Slf4j
public class RpcMethodExecutor implements MethodExecutor {

    public RpcMethodExecutor() {
    }

    @Override
    public RpcResult invoke(Invoker invoker, MethodExecutorChain chain) throws InterruptedException {
        RpcMethodRequest request = invoker.getRequest();
        if (request == null) {
            throw new RuntimeException("[An exception occurred during the RPC call!]");
        }
        Class<?> returnType = request.getReturnType();
        // Method.getReturnType() yields void.class, not Void.class, for a void method
        boolean isVoid = void.class.equals(returnType) || Void.class.equals(returnType);
        // Register the future before sending, so an early response still finds it
        RpcResult future = RpcResultPool.createFuture(request.getReqId(), isVoid);
        Mono.just(invoker)
                // The network write runs on the RPC thread pool, not on the caller's thread
                .publishOn(Schedulers.fromExecutor(RpcSchedule.RpcExecutor.RPC))
                .subscribe(this::sendRequest);
        return future;
    }

    @Override
    public int priority() {
        return -1;
    }

    private void sendRequest(Invoker invoker) {
        RpcService rpcService = invoker.getService();
        RpcMethodRequest request = invoker.getRequest();
        String serviceName = rpcService.serviceName();
        String ip = rpcService.ip();
        int port = rpcService.port();
        Channel channel = null;
        try {
            channel = NettyClient.getConnection(serviceName, ip, port);
            channel.writeAndFlush(request).sync();
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt(); // preserve the interrupt status
            throw new RuntimeException(e);
        } finally {
            // Return the connection to the pool even if the write failed
            if (channel != null) {
                NettyClient.release(channel, serviceName, ip, port);
            }
        }
    }
}
```
Next, here is how an RPC request that reaches the service provider is routed to the right method:
```java
import io.netty.channel.Channel;
import reactor.core.publisher.Mono;
import reactor.core.scheduler.Schedulers;

public class NativeMethodExecutor implements MethodExecutor {

    public NativeMethodExecutor() {
    }

    @Override
    public RpcResult invoke(Invoker invoker, MethodExecutorChain chain) throws InterruptedException {
        if (invoker.isNative()) {
            // The target service lives in this process: call it directly
            return innerExec(invoker.getRequest());
        }
        if (invoker.isRpcRequest()) {
            // An RPC request received from a consumer: execute it and reply on its channel
            return innerExec(invoker.getRequest(), invoker.getChannel());
        }
        // Not handled here: pass the call on to the next executor
        return chain.chain(invoker);
    }

    private RpcResult innerExec(RpcMethodRequest msg) {
        RpcResult future = RpcResultPool.createFuture(msg.getReqId());
        Mono.just(msg)
                .publishOn(Schedulers.fromExecutor(RpcSchedule.RpcExecutor.RPC))
                // Look up the executor by the interface's fully qualified name
                .map(request -> ApplicationManager.getNativeMethodManager().getExecutor(request.getOwnerName()))
                .map(executor -> executor.apply(msg))
                .subscribe(future::complete);
        return future;
    }

    private RpcResult innerExec(RpcMethodRequest msg, Channel channel) {
        Mono.just(msg)
                .publishOn(Schedulers.fromExecutor(RpcSchedule.RpcExecutor.RPC))
                .map(request -> ApplicationManager.getNativeMethodManager().getExecutor(request.getOwnerName()))
                .map(executor -> executor.apply(msg))
                // The response goes straight back to the consumer over the channel
                .subscribe(channel::writeAndFlush);
        // No local future: the consumer completes its own RpcResult when the response arrives
        return null;
    }

    @Override
    public int priority() {
        return -2;
    }
}
```
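MethodExecutorChain itself is not shown. From the two executors above, it appears that executors run in ascending priority order: NativeMethodExecutor (priority -2) either handles the call or hands it on with chain.chain(invoker), and RpcMethodExecutor (priority -1) handles whatever remains. The following is a minimal, hypothetical chain-of-responsibility sketch of that ordering; SketchExecutor, SketchChain and the String used in place of Invoker are illustrative, not tensor-rpc types.

```java
import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;

// Hypothetical executor contract; a String target stands in for Invoker
interface SketchExecutor {
    String invoke(String target, SketchChain chain);
    int priority();
}

// One chain instance per call: it keeps a cursor into the sorted executor list
final class SketchChain {
    private final List<SketchExecutor> executors;
    private int index = 0;

    SketchChain(List<SketchExecutor> executors) {
        this.executors = new ArrayList<>(executors);
        // Smaller priority value runs first, so -2 (native) precedes -1 (rpc)
        this.executors.sort(Comparator.comparingInt(SketchExecutor::priority));
    }

    String chain(String target) {
        if (index >= executors.size()) {
            throw new IllegalStateException("no executor handled " + target);
        }
        return executors.get(index++).invoke(target, this);
    }
}
```

Because the cursor is mutable, a fresh SketchChain has to be created for every call; sharing one across concurrent calls would skip executors.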
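NativeMethodExecutor relies on the interface's fully qualified name, which was set when the request was built, together with ApplicationManager, the object that manages the interface methods registered locally as services: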
```java
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.function.Function;

public class ApplicationManager {

    private static final RpcInfoManager RPC_INFO_MANAGER = new RpcInfoManager();
    private static final NativeMethodManager NATIVE_METHOD_MANAGER = new NativeMethodManager();

    public static RpcInfoManager getRpcInfoManager() {
        return RPC_INFO_MANAGER;
    }

    public static NativeMethodManager getNativeMethodManager() {
        return NATIVE_METHOD_MANAGER;
    }

    // Annotation metadata, keyed by the interface's canonical name
    public static class RpcInfoManager {

        RpcInfoManager() {
        }

        // @RpcService metadata (consumer side)
        private final ConcurrentHashMap<String, RpcService> SERVICE_MAP = new ConcurrentHashMap<>();
        // @RpcRegister metadata (provider side)
        private final ConcurrentHashMap<String, RpcRegister> REGISTER_MAP = new ConcurrentHashMap<>();

        public void consumerRegister(String key, RpcService value) {
            SERVICE_MAP.put(key, value);
        }

        public void providerRegister(String key, RpcRegister value) {
            REGISTER_MAP.put(key, value);
        }

        public RpcService getRpcService(String key) {
            return SERVICE_MAP.get(key);
        }

        public RpcRegister getRpcRegister(String key) {
            return REGISTER_MAP.get(key);
        }

        public boolean containRegisterInfo(String key) {
            return REGISTER_MAP.containsKey(key);
        }
    }

    // The local method pool: interface canonical name -> implementing bean and its methods
    public static class NativeMethodManager {

        private final ConcurrentHashMap<String, RegisterInfo> METHOD_EXECUTOR = new ConcurrentHashMap<>();

        NativeMethodManager() {
        }

        public void register(Object owner, Class<?> cls) {
            // Every interface the bean class directly implements becomes callable by name
            for (Class<?> inter : cls.getInterfaces()) {
                METHOD_EXECUTOR.put(inter.getCanonicalName(), new RegisterInfo(owner, inter));
            }
        }

        public Executor getExecutor(String key) {
            return new Executor(METHOD_EXECUTOR.get(key));
        }

        private static class RegisterInfo {
            final Object owner;
            final Map<String, Method> methodMap = new HashMap<>();

            RegisterInfo(Object owner, Class<?> cls) {
                this.owner = owner;
                for (Method method : cls.getMethods()) {
                    methodMap.put(KeyBuilder.methodSign(method), method);
                }
            }

            Method getMethod(String sign) {
                return methodMap.get(sign);
            }
        }

        // Runs a request against the local bean and wraps the outcome in a response
        public static class Executor implements Function<RpcMethodRequest, RpcMethodResponse> {

            private final RegisterInfo worker;

            Executor(RegisterInfo worker) {
                this.worker = worker;
            }

            @Override
            public RpcMethodResponse apply(RpcMethodRequest request) {
                Exception err = null;
                Object result = null; // not new Object(), which would serialize to "{}" on failure
                Method method = worker == null ? null : worker.getMethod(request.getMethodName());
                if (method == null) {
                    // Unknown interface or unknown method signature
                    err = new Exception("[Tensor RPC] : Not this function");
                } else {
                    try {
                        result = method.invoke(worker.owner, request.getParam());
                    } catch (IllegalAccessException | InvocationTargetException e) {
                        err = e;
                    }
                }
                return RpcMethodResponse.builder()
                        .respId(request.getReqId())
                        .returnVal(GsonSerializer.encode(result))
                        .returnType(request.getReturnType())
                        .error(err)
                        .build();
            }
        }

        public boolean isNative(RpcMethodRequest request) {
            // "Native" means this process registered a provider for the interface
            return RPC_INFO_MANAGER.containRegisterInfo(request.getOwnerName());
        }
    }
}
```
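The ApplicationManager class manages the methods registered locally through @RpcRegister, along with the metadata from every @RpcRegister and @RpcService annotation. Every RPC request that invokes a local method is ultimately handled by the Executor class.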
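Executor.apply above boils down to reflective dispatch plus error capture. The self-contained sketch below isolates that step, with one deliberate change from the article: it unwraps InvocationTargetException so that the caller sees the exception the service method actually threw rather than the reflection wrapper. DispatchResult, DispatchSketch, MathService and MathServiceImpl are hypothetical names.

```java
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;

// Hypothetical result holder; RpcMethodResponse in the article also carries respId and returnType
final class DispatchResult {
    final Object value;
    final Throwable error;

    DispatchResult(Object value, Throwable error) {
        this.value = value;
        this.error = error;
    }
}

final class DispatchSketch {
    static DispatchResult dispatch(Object owner, Method method, Object[] args) {
        if (owner == null || method == null) {
            // Mirrors the article's "Not this function" case
            return new DispatchResult(null, new NoSuchMethodException("[Tensor RPC] no such function"));
        }
        try {
            return new DispatchResult(method.invoke(owner, args), null);
        } catch (InvocationTargetException e) {
            // Unwrap: report what the service method itself threw
            return new DispatchResult(null, e.getCause());
        } catch (IllegalAccessException e) {
            return new DispatchResult(null, e);
        }
    }
}

interface MathService {
    int div(int a, int b);
}

class MathServiceImpl implements MathService {
    @Override
    public int div(int a, int b) {
        return a / b;
    }
}
```

Dividing by zero through dispatch then yields an ArithmeticException in the error slot instead of an InvocationTargetException.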
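Asynchronous RPC calls: receiving the right result
Every RPC request is sent with a freshly generated request id, and every RPC response carries that id back in the response object. So when a response arrives, it is enough to look up the future registered under that id and complete it: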
```java
// RpcResultPool.java
import java.util.concurrent.ConcurrentHashMap;

public class RpcResultPool {

    // Request id -> future waiting for the matching response
    private static final ConcurrentHashMap<String, RpcResult> RESULT_FUTURE = new ConcurrentHashMap<>();

    public static RpcResult getFuture(String key) {
        return RESULT_FUTURE.get(key);
    }

    public static RpcResult createFuture(String key) {
        return createFuture(key, false);
    }

    public static RpcResult createFuture(String key, boolean isVoid) {
        RpcResult result = new RpcResult(isVoid);
        RESULT_FUTURE.put(key, result);
        return result;
    }

    public static void remove(String key) {
        RESULT_FUTURE.remove(key);
    }
}

// RpcResult.java
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutionException;

public class RpcResult extends CompletableFuture<RpcMethodResponse> {

    private boolean isVoid = false;

    public RpcResult() {
    }

    public RpcResult(boolean isVoid) {
        this.isVoid = isVoid;
    }

    public Object result() throws ExecutionException, InterruptedException, RpcTimeOutException {
        if (isVoid) {
            // Void call: the caller does not wait for the response
            return null;
        }
        // Blocks until the response handler completes this future
        RpcMethodResponse response = get();
        RpcResultPool.remove(response.getRespId());
        Throwable error = response.getError();
        if (error != null) {
            throw new RuntimeException(error);
        }
        Class<?> cls = response.getReturnType();
        return GsonSerializer.decode(response.getReturnVal(), cls);
    }
}
```
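The handler that receives a response on the consumer side and completes the future is not shown. Note also that RpcResult.result() calls get() without a timeout, so a lost response would block the caller indefinitely, and in the code shown a pool entry is only removed inside result(), which returns early for void calls. Here is a minimal sketch of request-id correlation with a bounded wait, assuming a plain String payload in place of RpcMethodResponse; PendingCalls and its methods are hypothetical names.

```java
import java.util.Map;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;

final class PendingCalls {
    private final Map<String, CompletableFuture<String>> pending = new ConcurrentHashMap<>();

    // Register before sending, so a fast response cannot arrive before the future exists
    CompletableFuture<String> register(String reqId) {
        CompletableFuture<String> f = new CompletableFuture<>();
        pending.put(reqId, f);
        return f;
    }

    // Called by the network read handler when a response carrying respId arrives
    boolean onResponse(String respId, String payload) {
        CompletableFuture<String> f = pending.remove(respId);
        return f != null && f.complete(payload);
    }

    // Wait with an upper bound, and always drop the entry so timed-out calls do not leak
    String await(String reqId, CompletableFuture<String> f, long timeoutMs)
            throws InterruptedException, ExecutionException, TimeoutException {
        try {
            return f.get(timeoutMs, TimeUnit.MILLISECONDS);
        } finally {
            pending.remove(reqId);
        }
    }

    int size() {
        return pending.size();
    }
}
```

Removing the entry in onResponse and again in a finally block keeps the map bounded whether the response arrives, arrives late, or never arrives; a late response for a removed id is simply ignored.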
Project repository
tensor-rpc