JRaft里面的时间轮算法

时间轮算法

java-web-61-1.png

时间轮算法简单来说就是有一个环形数组(一般我们会将数组长度设置为2^n),数组中的每一个槽位(slot)代表一个tick的时间间隔。每一个延迟任务根据自己的延迟时间以及tick的长度,先算出自己应该落在哪一个slot,然后再计算还需要经过几轮完整的时间轮之后才算过期。

JRaft里面的实现

我们先来看看HashedWheelTimer里面的成员属性有什么

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
// Threshold for the number of HashedWheelTimer instances. Each instance creates one worker
// thread (see the constructor), so this effectively bounds the number of timer threads before
// a warning is reported.
private static final int INSTANCE_COUNT_LIMIT = 256;
// Number of instances created so far; incremented in the constructor.
private static final AtomicInteger instanceCounter = new AtomicInteger();
// Guards the "too many instances" report so that it is issued only once.
private static final AtomicBoolean warnedTooManyInstances = new AtomicBoolean();
// Provides atomic read/compare-and-set access to the volatile workerState field.
private static final AtomicIntegerFieldUpdater<HashedWheelTimer> workerStateUpdater = AtomicIntegerFieldUpdater
.newUpdater(
HashedWheelTimer.class,
"workerState");
// The worker that implements the core timing-wheel loop.
private final Worker worker = new Worker();
// The thread that runs the worker; created in the constructor, started lazily by start().
private final Thread workerThread;
// Worker state: created but not started yet.
public static final int WORKER_STATE_INIT = 0;
// Worker state: started and running.
public static final int WORKER_STATE_STARTED = 1;
// Worker state: shut down.
public static final int WORKER_STATE_SHUTDOWN = 2;
// Current worker state. volatile gives cross-thread visibility; atomic transitions go through
// workerStateUpdater.
private volatile int workerState; // 0 - init, 1 - started, 2 - shut down
// Length of one tick, in nanoseconds (converted from the caller's unit in the constructor).
private final long tickDuration;
// The timing wheel: one bucket per slot.
private final HashedWheelBucket[] wheel;
// wheel.length - 1; used as a bit mask to map a tick number to a wheel index.
private final int mask;
// Counted down by the worker once startTime has been set; start() waits on it.
private final CountDownLatch startTimeInitialized = new CountDownLatch(1);
// Holds newly added timeouts until the worker transfers them into wheel buckets.
private final Queue<HashedWheelTimeout> timeouts = new ConcurrentLinkedQueue<>();
// Records timeouts that have been cancelled.
private final Queue<HashedWheelTimeout> cancelledTimeouts = new ConcurrentLinkedQueue<>();
// Number of timeouts currently pending execution.
private final AtomicLong pendingTimeouts = new AtomicLong(0);
// Maximum number of pending timeouts; the limit is only enforced when this value is > 0.
private final long maxPendingTimeouts;

// Reference point (System.nanoTime()) set by the worker thread; 0 means "not initialized yet".
private volatile long startTime;

紧接着,就是创建一个HashedWheelTimer,在创建的过程当中,会初始化相应的数据结构以及资源。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
/**
 * Creates a timer and prepares (but does not start) its worker thread.
 *
 * @param threadFactory      factory used to create the single worker thread
 * @param tickDuration       length of one tick, expressed in {@code unit}
 * @param unit               time unit of {@code tickDuration}
 * @param ticksPerWheel      requested number of wheel slots; the actual size is decided by
 *                           {@code createWheel}
 * @param maxPendingTimeouts maximum number of pending timeouts; the limit is only enforced by
 *                           {@code newTimeout} when this value is greater than 0
 * @throws NullPointerException     if {@code threadFactory} or {@code unit} is {@code null}
 * @throws IllegalArgumentException if {@code tickDuration} or {@code ticksPerWheel} is not
 *                                  positive, or if the tick duration in nanoseconds is too
 *                                  large for the wheel size
 */
public HashedWheelTimer(ThreadFactory threadFactory, long tickDuration, TimeUnit unit, int ticksPerWheel,
                        long maxPendingTimeouts) {

    if (threadFactory == null) {
        throw new NullPointerException("threadFactory");
    }
    if (unit == null) {
        throw new NullPointerException("unit");
    }
    if (tickDuration <= 0) {
        throw new IllegalArgumentException("tickDuration must be greater than 0: " + tickDuration);
    }
    if (ticksPerWheel <= 0) {
        throw new IllegalArgumentException("ticksPerWheel must be greater than 0: " + ticksPerWheel);
    }

    // Build the wheel, i.e. the HashedWheelBucket[] array sized from ticksPerWheel.
    wheel = createWheel(ticksPerWheel);
    // Used as a bit mask to map a tick number onto a wheel index.
    mask = wheel.length - 1;

    // Convert tickDuration to nanos.
    this.tickDuration = unit.toNanos(tickDuration);

    // Prevent overflow: tickDuration * wheel.length must stay below Long.MAX_VALUE.
    if (this.tickDuration >= Long.MAX_VALUE / wheel.length) {
        throw new IllegalArgumentException(String.format(
            "tickDuration: %d (expected: 0 < tickDuration in nanos < %d)", tickDuration,
            Long.MAX_VALUE / wheel.length));
    }

    // Create the worker thread; it is only started later, by start().
    workerThread = threadFactory.newThread(worker);
    // Upper bound on the number of pending timeouts (a count, not a duration).
    this.maxPendingTimeouts = maxPendingTimeouts;

    // Report once when the number of live timer instances exceeds the limit.
    if (instanceCounter.incrementAndGet() > INSTANCE_COUNT_LIMIT
        && warnedTooManyInstances.compareAndSet(false, true)) {
        reportTooManyInstances();
    }
}

/**
 * Allocates the wheel and fills every slot with a fresh, empty bucket.
 *
 * @param ticksPerWheel requested number of slots; must be in the range (0, 2^30]
 * @return the populated bucket array, whose length is the value returned by
 *         {@code normalizeTicksPerWheel}
 * @throws IllegalArgumentException if {@code ticksPerWheel} is outside the allowed range
 */
private static HashedWheelBucket[] createWheel(int ticksPerWheel) {
    if (ticksPerWheel <= 0) {
        throw new IllegalArgumentException("ticksPerWheel must be greater than 0: " + ticksPerWheel);
    }
    if (ticksPerWheel > (1 << 30)) {
        throw new IllegalArgumentException("ticksPerWheel may not be greater than 2^30: " + ticksPerWheel);
    }

    // NOTE(review): normalizeTicksPerWheel is not visible here; presumably it rounds the size
    // to a power of two, which the constructor's "length - 1" mask relies on - confirm.
    final int size = normalizeTicksPerWheel(ticksPerWheel);
    final HashedWheelBucket[] buckets = new HashedWheelBucket[size];
    int slot = 0;
    while (slot < size) {
        // Every slot gets its own bucket object that collects the timeouts of that slot.
        buckets[slot++] = new HashedWheelBucket();
    }
    return buckets;
}

至此,HashedWheelTimer这个对象就创建完成了。可以看到,在创建的过程中,workerThread只是被构建出来而已,并没有真正开始运转。那么它具体是什么时候才会运行起来呢?这里其实用的是延迟启动(懒加载)的思想:只有当第一个任务被添加进来之后,才会触发时间轮算法的正式运行。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
/**
 * Schedules {@code task} to run after the given delay and lazily starts the worker thread.
 *
 * @param task  the task to execute when the timeout expires
 * @param delay the delay, expressed in {@code unit}
 * @param unit  time unit of {@code delay}
 * @return the handle of the newly queued timeout
 * @throws NullPointerException       if {@code task} or {@code unit} is {@code null}
 * @throws RejectedExecutionException if a positive {@code maxPendingTimeouts} is configured
 *                                    and would be exceeded by this timeout
 */
@Override
public Timeout newTimeout(TimerTask task, long delay, TimeUnit unit) {
    if (task == null) {
        throw new NullPointerException("task");
    }
    if (unit == null) {
        throw new NullPointerException("unit");
    }

    // Reserve a slot first; roll the counter back if the configured limit is exceeded.
    final long pendingNow = pendingTimeouts.incrementAndGet();
    final boolean limitEnabled = maxPendingTimeouts > 0;
    if (limitEnabled && pendingNow > maxPendingTimeouts) {
        pendingTimeouts.decrementAndGet();
        throw new RejectedExecutionException("Number of pending timeouts (" + pendingNow
            + ") is greater than or equal to maximum allowed pending "
            + "timeouts (" + maxPendingTimeouts + ")");
    }

    // Make sure the worker is running and startTime has been published.
    start();

    // The deadline is stored relative to startTime. The timeout is only queued here; the
    // worker moves it into the matching bucket on a later tick.
    final long now = System.nanoTime();
    final long relative = now + unit.toNanos(delay) - startTime;
    // A positive delay that yields a negative deadline means the addition overflowed.
    final long deadline = (delay > 0 && relative < 0) ? Long.MAX_VALUE : relative;

    final HashedWheelTimeout scheduled = new HashedWheelTimeout(this, task, deadline);
    timeouts.add(scheduled);
    return scheduled;
}
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
/**
 * Starts the worker thread if it has not been started yet, then blocks until the worker has
 * published {@code startTime}.
 *
 * <p>If the calling thread is interrupted while waiting, the wait continues, and the
 * interrupt status is restored before this method returns so the caller can still observe it.
 *
 * @throws IllegalStateException if the timer has already been shut down
 */
public void start() {
    switch (workerStateUpdater.get(this)) {
        case WORKER_STATE_INIT:
            // Only the thread that wins the INIT -> STARTED transition starts the worker.
            if (workerStateUpdater.compareAndSet(this, WORKER_STATE_INIT, WORKER_STATE_STARTED)) {
                workerThread.start();
            }
            break;
        case WORKER_STATE_STARTED:
            break;
        case WORKER_STATE_SHUTDOWN:
            throw new IllegalStateException("cannot be started once stopped");
        default:
            throw new Error("Invalid WorkerState");
    }

    // Wait until the worker thread has initialized startTime. The loop is needed because
    // await() may be cut short by an interrupt, in which case we simply wait again.
    boolean interrupted = false;
    while (startTime == 0) {
        try {
            startTimeInitialized.await();
        } catch (InterruptedException ignored) {
            // It will be ready very soon; keep waiting but remember the interrupt.
            interrupted = true;
        }
    }
    if (interrupted) {
        // Do not swallow the interrupt: hand it back to the caller.
        Thread.currentThread().interrupt();
    }
}

调用 start() 的线程在把 workerThread 启动起来之后,会一直等待 startTime 被赋值。因为时间轮算法既然是根据时间来进行调度的,那么就一定需要一个时间点作为参照,而这个参照就是 startTime 变量。startTime 变量的初始化是在 Worker 的 run 方法中完成的。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
/**
 * Worker loop: publishes the time reference, then processes one wheel slot per tick for as
 * long as the timer is in the STARTED state. After the loop it gathers every timeout that was
 * never processed into {@code unprocessedTimeouts}.
 */
@Override
public void run() {
    // Publish the time reference. 0 is reserved as the "not initialized" marker, so a
    // reading of exactly 0 is replaced by 1.
    final long now = System.nanoTime();
    startTime = (now == 0) ? 1 : now;

    // Release the threads blocked in start().
    startTimeInitialized.countDown();

    do {
        // Block until the next tick is due.
        final long tickDeadline = waitForNextTick();
        if (tickDeadline <= 0) {
            // Nothing to process for a non-positive deadline; go straight to the state check.
            continue;
        }
        // Slot of the wheel that belongs to the current tick.
        final int slot = (int) (tick & mask);
        // Drop the tasks that were cancelled in the meantime.
        processCancelledTasks();
        // Move the newly queued timeouts into their buckets.
        transferTimeoutsToBuckets();
        // Fire everything in the current slot that is due.
        wheel[slot].expireTimeouts(tickDeadline);
        // Advance the wheel by one step.
        tick++;
    } while (workerStateUpdater.get(HashedWheelTimer.this) == WORKER_STATE_STARTED);

    // The timer is no longer running: collect what is still sitting in the buckets ...
    for (int i = 0; i < wheel.length; i++) {
        wheel[i].clearTimeouts(unprocessedTimeouts);
    }
    // ... and what was queued but never transferred, skipping cancelled entries.
    HashedWheelTimeout leftover;
    while ((leftover = timeouts.poll()) != null) {
        if (!leftover.isCancelled()) {
            unprocessedTimeouts.add(leftover);
        }
    }
    processCancelledTasks();
}

刚刚我们看到,在处理已经到期的任务之前,会先做一次transferTimeoutsToBuckets的操作,下面就来看看这个方法里面做了什么事情。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
/**
 * Moves queued timeouts from {@code timeouts} into the wheel bucket that matches their
 * deadline. At most 100000 queue entries are consumed per call, so a flood of new timeouts
 * cannot stall the worker loop.
 */
private void transferTimeoutsToBuckets() {
    for (int budget = 100000; budget > 0; budget--) {
        final HashedWheelTimeout pending = timeouts.poll();
        if (pending == null) {
            // The queue is drained.
            return;
        }
        // Entries cancelled while waiting in the queue are skipped, but still use up budget.
        if (pending.state() != HashedWheelTimeout.ST_CANCELLED) {
            // Absolute tick number at which the timeout becomes due.
            final long dueTick = pending.deadline / tickDuration;
            // Number of full wheel revolutions that still have to pass before it is due.
            pending.remainingRounds = (dueTick - tick) / wheel.length;

            // Never schedule into the past: an overdue timeout goes into the current tick.
            final long effectiveTick = dueTick > tick ? dueTick : tick;
            final int slot = (int) (effectiveTick & mask);

            wheel[slot].addTimeout(pending);
        }
    }
}

这一步其实是将暂存队列中收集到的任务,转移到时间轮上各自对应的 bucket(槽位)当中。

现在就来看看当任务过期时的处理方式吧

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
/**
 * Walks the timeouts of this bucket once: entries with no rounds left are removed and
 * expired, cancelled entries are removed, and all others have their round counter
 * decremented.
 *
 * @param deadline the deadline of the current tick; an entry with no remaining rounds must
 *                 not have a later deadline
 * @throws IllegalStateException if an entry with no remaining rounds has a deadline later
 *                               than {@code deadline}, i.e. it sits in the wrong slot
 */
public void expireTimeouts(long deadline) {
    for (HashedWheelTimeout current = head; current != null; ) {
        final HashedWheelTimeout successor;

        if (current.remainingRounds > 0) {
            if (current.isCancelled()) {
                // A cancelled entry is unlinked.
                successor = remove(current);
            } else {
                // Still has rounds to wait: one wheel revolution less to go.
                successor = current.next;
                current.remainingRounds--;
            }
        } else {
            // No rounds left: unlink the entry, then fire it.
            successor = remove(current);
            if (current.deadline > deadline) {
                // The timeout was placed into a wrong slot. This should never happen.
                throw new IllegalStateException(String.format("timeout.deadline (%d) > deadline (%d)",
                    current.deadline, deadline));
            }
            // Trigger the expiry of this timeout.
            current.expire();
        }

        current = successor;
    }
}