Java时间轮算法的实现代码示例
作者:唐华春
本篇文章主要介绍了Java时间轮算法的实现代码示例,小编觉得挺不错的,现在分享给大家,也给大家做个参考。一起跟随小编过来看看吧
考虑这样一个场景,现在有5000个任务,要让这5000个任务每隔5分钟触发某个操作,怎么去实现这个需求。大部分人首先想到的是使用定时器,但是5000个任务,你就要用5000个定时器,一个定时器就是一个线程,你懂了吧,这种方法肯定是不行的。
针对这个场景,催生了时间轮算法,时间轮到底是什么?我一贯的风格,自行谷歌去。大发慈悲,发个时间轮介绍你们看看,看文字和图就好了,代码不要看了,那个文章里的代码运行不起来,时间轮介绍。
看好了介绍,我们就开始动手吧。
开发环境:idea + jdk1.8 + maven
新建一个maven工程
创建如下的目录结构
不要忘了pom.xml中添加netty库
<!--
  Netty dependency used by TimerWheel.java (logging, leak detection and
  PlatformDependent utilities).
  NOTE(review): TimerWheel.java imports classes from io.netty.util.internal,
  so changing this version may require code changes - verify before upgrading.
-->
<dependencies>
    <dependency>
        <groupId>io.netty</groupId>
        <artifactId>netty-all</artifactId>
        <version>4.1.5.Final</version>
    </dependency>
</dependencies>
代码如下
Timeout.java
package com.tanghuachun.timer;

/**
 * Handle returned by {@link Timer#newTimeout} for a scheduled {@link TimerTask}.
 * It lets the caller inspect the state of the scheduled task and cancel it.
 */
public interface Timeout {

    /**
     * Attempts to cancel the task associated with this handle.
     *
     * @return {@code true} if this call cancelled the task; {@code false}
     *         if the cancellation did not take effect (for example because
     *         the task has already expired or was cancelled before)
     */
    boolean cancel();

    /**
     * @return {@code true} if the task associated with this handle has been cancelled
     */
    boolean isCancelled();

    /**
     * @return {@code true} if the task associated with this handle has expired
     */
    boolean isExpired();

    /**
     * @return the task associated with this handle
     */
    TimerTask task();

    /**
     * @return the {@link Timer} that created this handle
     */
    Timer timer();
}
Timer.java
package com.tanghuachun.timer;

import java.util.Set;
import java.util.concurrent.TimeUnit;

/**
 * Schedules {@link TimerTask}s for one-time execution after a delay.
 */
public interface Timer {

    /**
     * Stops this timer.
     *
     * @return the handles of the timeouts that were still pending when the
     *         timer was stopped
     */
    Set<Timeout> stop();

    /**
     * Schedules {@code task} to run once after the given delay.
     *
     * @param task  the task to execute
     * @param delay how long to wait before executing the task
     * @param unit  the time unit of {@code delay}
     * @param argv  an arbitrary string handed back to the task when it runs
     * @return a handle associated with the scheduled task
     */
    Timeout newTimeout(TimerTask task, long delay, TimeUnit unit, String argv);
}
TimerTask.java
package com.tanghuachun.timer;

/**
 * A unit of work that a {@link Timer} executes once its delay has elapsed.
 */
public interface TimerTask {

    /**
     * Called when the delay given to {@link Timer#newTimeout} has elapsed.
     *
     * @param timeout the handle associated with this task
     * @param argv    the string that was supplied when the task was scheduled
     * @throws Exception if the task fails
     */
    void run(Timeout timeout, String argv) throws Exception;
}
TimerWheel.java
/* * Copyright 2012 The Netty Project * * The Netty Project licenses this file to you under the Apache License, * version 2.0 (the "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at: * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the * License for the specific language governing permissions and limitations * under the License. */ package com.tanghuachun.timer; import io.netty.util.*; import io.netty.util.internal.PlatformDependent; import io.netty.util.internal.StringUtil; import io.netty.util.internal.logging.InternalLogger; import io.netty.util.internal.logging.InternalLoggerFactory; import java.util.Collections; import java.util.HashSet; import java.util.Queue; import java.util.Set; import java.util.concurrent.CountDownLatch; import java.util.concurrent.Executors; import java.util.concurrent.ThreadFactory; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicIntegerFieldUpdater; public class TimerWheel implements Timer { static final InternalLogger logger = InternalLoggerFactory.getInstance(TimerWheel.class); private static final ResourceLeakDetector<TimerWheel> leakDetector = ResourceLeakDetectorFactory.instance() .newResourceLeakDetector(TimerWheel.class, 1, Runtime.getRuntime().availableProcessors() * 4L); private static final AtomicIntegerFieldUpdater<TimerWheel> WORKER_STATE_UPDATER; static { AtomicIntegerFieldUpdater<TimerWheel> workerStateUpdater = PlatformDependent.newAtomicIntegerFieldUpdater(TimerWheel.class, "workerState"); if (workerStateUpdater == null) { workerStateUpdater = AtomicIntegerFieldUpdater.newUpdater(TimerWheel.class, "workerState"); } WORKER_STATE_UPDATER = workerStateUpdater; } private final ResourceLeak leak; private final 
Worker worker = new Worker(); private final Thread workerThread; public static final int WORKER_STATE_INIT = 0; public static final int WORKER_STATE_STARTED = 1; public static final int WORKER_STATE_SHUTDOWN = 2; @SuppressWarnings({ "unused", "FieldMayBeFinal", "RedundantFieldInitialization" }) private volatile int workerState = WORKER_STATE_INIT; // 0 - init, 1 - started, 2 - shut down private final long tickDuration; private final HashedWheelBucket[] wheel; private final int mask; private final CountDownLatch startTimeInitialized = new CountDownLatch(1); private final Queue<HashedWheelTimeout> timeouts = PlatformDependent.newMpscQueue(); private final Queue<HashedWheelTimeout> cancelledTimeouts = PlatformDependent.newMpscQueue(); private volatile long startTime; /** * Creates a new timer with the default thread factory * ({@link Executors#defaultThreadFactory()}), default tick duration, and * default number of ticks per wheel. */ public TimerWheel() { this(Executors.defaultThreadFactory()); } /** * Creates a new timer with the default thread factory * ({@link Executors#defaultThreadFactory()}) and default number of ticks * per wheel. * * @param tickDuration the duration between tick * @param unit the time unit of the {@code tickDuration} * @throws NullPointerException if {@code unit} is {@code null} * @throws IllegalArgumentException if {@code tickDuration} is <= 0 */ public TimerWheel(long tickDuration, TimeUnit unit) { this(Executors.defaultThreadFactory(), tickDuration, unit); } /** * Creates a new timer with the default thread factory * ({@link Executors#defaultThreadFactory()}). 
* * @param tickDuration the duration between tick * @param unit the time unit of the {@code tickDuration} * @param ticksPerWheel the size of the wheel * @throws NullPointerException if {@code unit} is {@code null} * @throws IllegalArgumentException if either of {@code tickDuration} and {@code ticksPerWheel} is <= 0 */ public TimerWheel(long tickDuration, TimeUnit unit, int ticksPerWheel) { this(Executors.defaultThreadFactory(), tickDuration, unit, ticksPerWheel); } /** * Creates a new timer with the default tick duration and default number of * ticks per wheel. * * @param threadFactory a {@link ThreadFactory} that creates a * background {@link Thread} which is dedicated to * {@link TimerTask} execution. * @throws NullPointerException if {@code threadFactory} is {@code null} */ public TimerWheel(ThreadFactory threadFactory) { this(threadFactory, 100, TimeUnit.MILLISECONDS); } /** * Creates a new timer with the default number of ticks per wheel. * * @param threadFactory a {@link ThreadFactory} that creates a * background {@link Thread} which is dedicated to * {@link TimerTask} execution. * @param tickDuration the duration between tick * @param unit the time unit of the {@code tickDuration} * @throws NullPointerException if either of {@code threadFactory} and {@code unit} is {@code null} * @throws IllegalArgumentException if {@code tickDuration} is <= 0 */ public TimerWheel( ThreadFactory threadFactory, long tickDuration, TimeUnit unit) { this(threadFactory, tickDuration, unit, 512); } /** * Creates a new timer. * * @param threadFactory a {@link ThreadFactory} that creates a * background {@link Thread} which is dedicated to * {@link TimerTask} execution. 
* @param tickDuration the duration between tick * @param unit the time unit of the {@code tickDuration} * @param ticksPerWheel the size of the wheel * @throws NullPointerException if either of {@code threadFactory} and {@code unit} is {@code null} * @throws IllegalArgumentException if either of {@code tickDuration} and {@code ticksPerWheel} is <= 0 */ public TimerWheel( ThreadFactory threadFactory, long tickDuration, TimeUnit unit, int ticksPerWheel) { this(threadFactory, tickDuration, unit, ticksPerWheel, true); } /** * Creates a new timer. * * @param threadFactory a {@link ThreadFactory} that creates a * background {@link Thread} which is dedicated to * {@link TimerTask} execution. * @param tickDuration the duration between tick * @param unit the time unit of the {@code tickDuration} * @param ticksPerWheel the size of the wheel * @param leakDetection {@code true} if leak detection should be enabled always, if false it will only be enabled * if the worker thread is not a daemon thread. * @throws NullPointerException if either of {@code threadFactory} and {@code unit} is {@code null} * @throws IllegalArgumentException if either of {@code tickDuration} and {@code ticksPerWheel} is <= 0 */ public TimerWheel( ThreadFactory threadFactory, long tickDuration, TimeUnit unit, int ticksPerWheel, boolean leakDetection) { if (threadFactory == null) { throw new NullPointerException("threadFactory"); } if (unit == null) { throw new NullPointerException("unit"); } if (tickDuration <= 0) { throw new IllegalArgumentException("tickDuration must be greater than 0: " + tickDuration); } if (ticksPerWheel <= 0) { throw new IllegalArgumentException("ticksPerWheel must be greater than 0: " + ticksPerWheel); } // Normalize ticksPerWheel to power of two and initialize the wheel. wheel = createWheel(ticksPerWheel); mask = wheel.length - 1; // Convert tickDuration to nanos. this.tickDuration = unit.toNanos(tickDuration); // Prevent overflow. 
if (this.tickDuration >= Long.MAX_VALUE / wheel.length) { throw new IllegalArgumentException(String.format( "tickDuration: %d (expected: 0 < tickDuration in nanos < %d", tickDuration, Long.MAX_VALUE / wheel.length)); } workerThread = threadFactory.newThread(worker); leak = leakDetection || !workerThread.isDaemon() ? leakDetector.open(this) : null; } private static HashedWheelBucket[] createWheel(int ticksPerWheel) { if (ticksPerWheel <= 0) { throw new IllegalArgumentException( "ticksPerWheel must be greater than 0: " + ticksPerWheel); } if (ticksPerWheel > 1073741824) { throw new IllegalArgumentException( "ticksPerWheel may not be greater than 2^30: " + ticksPerWheel); } ticksPerWheel = normalizeTicksPerWheel(ticksPerWheel); HashedWheelBucket[] wheel = new HashedWheelBucket[ticksPerWheel]; for (int i = 0; i < wheel.length; i ++) { wheel[i] = new HashedWheelBucket(); } return wheel; } private static int normalizeTicksPerWheel(int ticksPerWheel) { int normalizedTicksPerWheel = 1; while (normalizedTicksPerWheel < ticksPerWheel) { normalizedTicksPerWheel <<= 1; } return normalizedTicksPerWheel; } /** * Starts the background thread explicitly. The background thread will * start automatically on demand even if you did not call this method. * * @throws IllegalStateException if this timer has been * {@linkplain #stop() stopped} already */ public void start() { switch (WORKER_STATE_UPDATER.get(this)) { case WORKER_STATE_INIT: if (WORKER_STATE_UPDATER.compareAndSet(this, WORKER_STATE_INIT, WORKER_STATE_STARTED)) { workerThread.start(); } break; case WORKER_STATE_STARTED: break; case WORKER_STATE_SHUTDOWN: throw new IllegalStateException("cannot be started once stopped"); default: throw new Error("Invalid WorkerState"); } // Wait until the startTime is initialized by the worker. while (startTime == 0) { try { startTimeInitialized.await(); } catch (InterruptedException ignore) { // Ignore - it will be ready very soon. 
} } } @Override public Set<Timeout> stop() { if (Thread.currentThread() == workerThread) { throw new IllegalStateException( TimerWheel.class.getSimpleName() + ".stop() cannot be called from " + TimerTask.class.getSimpleName()); } if (!WORKER_STATE_UPDATER.compareAndSet(this, WORKER_STATE_STARTED, WORKER_STATE_SHUTDOWN)) { // workerState can be 0 or 2 at this moment - let it always be 2. WORKER_STATE_UPDATER.set(this, WORKER_STATE_SHUTDOWN); if (leak != null) { leak.close(); } return Collections.emptySet(); } boolean interrupted = false; while (workerThread.isAlive()) { workerThread.interrupt(); try { workerThread.join(100); } catch (InterruptedException ignored) { interrupted = true; } } if (interrupted) { Thread.currentThread().interrupt(); } if (leak != null) { leak.close(); } return worker.unprocessedTimeouts(); } @Override public Timeout newTimeout(TimerTask task, long delay, TimeUnit unit, String argv) { if (task == null) { throw new NullPointerException("task"); } if (unit == null) { throw new NullPointerException("unit"); } start(); // Add the timeout to the timeout queue which will be processed on the next tick. // During processing all the queued HashedWheelTimeouts will be added to the correct HashedWheelBucket. long deadline = System.nanoTime() + unit.toNanos(delay) - startTime; HashedWheelTimeout timeout = new HashedWheelTimeout(this, task, deadline, argv); timeouts.add(timeout); return timeout; } private final class Worker implements Runnable { private final Set<Timeout> unprocessedTimeouts = new HashSet<Timeout>(); private long tick; @Override public void run() { // Initialize the startTime. startTime = System.nanoTime(); if (startTime == 0) { // We use 0 as an indicator for the uninitialized value here, so make sure it's not 0 when initialized. startTime = 1; } // Notify the other threads waiting for the initialization at start(). 
startTimeInitialized.countDown(); do { final long deadline = waitForNextTick(); if (deadline > 0) { int idx = (int) (tick & mask); processCancelledTasks(); HashedWheelBucket bucket = wheel[idx]; transferTimeoutsToBuckets(); bucket.expireTimeouts(deadline); tick++; } } while (WORKER_STATE_UPDATER.get(TimerWheel.this) == WORKER_STATE_STARTED); // Fill the unprocessedTimeouts so we can return them from stop() method. for (HashedWheelBucket bucket: wheel) { bucket.clearTimeouts(unprocessedTimeouts); } for (;;) { HashedWheelTimeout timeout = timeouts.poll(); if (timeout == null) { break; } if (!timeout.isCancelled()) { unprocessedTimeouts.add(timeout); } } processCancelledTasks(); } private void transferTimeoutsToBuckets() { // transfer only max. 100000 timeouts per tick to prevent a thread to stale the workerThread when it just // adds new timeouts in a loop. for (int i = 0; i < 100000; i++) { HashedWheelTimeout timeout = timeouts.poll(); if (timeout == null) { // all processed break; } if (timeout.state() == HashedWheelTimeout.ST_CANCELLED) { // Was cancelled in the meantime. continue; } long calculated = timeout.deadline / tickDuration; timeout.remainingRounds = (calculated - tick) / wheel.length; final long ticks = Math.max(calculated, tick); // Ensure we don't schedule for past. int stopIndex = (int) (ticks & mask); HashedWheelBucket bucket = wheel[stopIndex]; bucket.addTimeout(timeout); } } private void processCancelledTasks() { for (;;) { HashedWheelTimeout timeout = cancelledTimeouts.poll(); if (timeout == null) { // all processed break; } try { timeout.remove(); } catch (Throwable t) { if (logger.isWarnEnabled()) { logger.warn("An exception was thrown while process a cancellation task", t); } } } } /** * calculate goal nanoTime from startTime and current tick number, * then wait until that goal has been reached. 
* @return Long.MIN_VALUE if received a shutdown request, * current time otherwise (with Long.MIN_VALUE changed by +1) */ private long waitForNextTick() { long deadline = tickDuration * (tick + 1); for (;;) { final long currentTime = System.nanoTime() - startTime; long sleepTimeMs = (deadline - currentTime + 999999) / 1000000; if (sleepTimeMs <= 0) { if (currentTime == Long.MIN_VALUE) { return -Long.MAX_VALUE; } else { return currentTime; } } // Check if we run on windows, as if thats the case we will need // to round the sleepTime as workaround for a bug that only affect // the JVM if it runs on windows. // // See https://github.com/netty/netty/issues/356 if (PlatformDependent.isWindows()) { sleepTimeMs = sleepTimeMs / 10 * 10; } try { Thread.sleep(sleepTimeMs); } catch (InterruptedException ignored) { if (WORKER_STATE_UPDATER.get(TimerWheel.this) == WORKER_STATE_SHUTDOWN) { return Long.MIN_VALUE; } } } } public Set<Timeout> unprocessedTimeouts() { return Collections.unmodifiableSet(unprocessedTimeouts); } } private static final class HashedWheelTimeout implements Timeout { private static final int ST_INIT = 0; private static final int ST_CANCELLED = 1; private static final int ST_EXPIRED = 2; private static final AtomicIntegerFieldUpdater<HashedWheelTimeout> STATE_UPDATER; static { AtomicIntegerFieldUpdater<HashedWheelTimeout> updater = PlatformDependent.newAtomicIntegerFieldUpdater(HashedWheelTimeout.class, "state"); if (updater == null) { updater = AtomicIntegerFieldUpdater.newUpdater(HashedWheelTimeout.class, "state"); } STATE_UPDATER = updater; } private final TimerWheel timer; private final TimerTask task; private final long deadline; @SuppressWarnings({"unused", "FieldMayBeFinal", "RedundantFieldInitialization" }) private volatile int state = ST_INIT; // remainingRounds will be calculated and set by Worker.transferTimeoutsToBuckets() before the // HashedWheelTimeout will be added to the correct HashedWheelBucket. 
long remainingRounds; String argv; // This will be used to chain timeouts in HashedWheelTimerBucket via a double-linked-list. // As only the workerThread will act on it there is no need for synchronization / volatile. HashedWheelTimeout next; HashedWheelTimeout prev; // The bucket to which the timeout was added HashedWheelBucket bucket; HashedWheelTimeout(TimerWheel timer, TimerTask task, long deadline, String argv) { this.timer = timer; this.task = task; this.deadline = deadline; this.argv = argv; } @Override public Timer timer() { return timer; } @Override public TimerTask task() { return task; } @Override public boolean cancel() { // only update the state it will be removed from HashedWheelBucket on next tick. if (!compareAndSetState(ST_INIT, ST_CANCELLED)) { return false; } // If a task should be canceled we put this to another queue which will be processed on each tick. // So this means that we will have a GC latency of max. 1 tick duration which is good enough. This way // we can make again use of our MpscLinkedQueue and so minimize the locking / overhead as much as possible. 
timer.cancelledTimeouts.add(this); return true; } void remove() { HashedWheelBucket bucket = this.bucket; if (bucket != null) { bucket.remove(this); } } public boolean compareAndSetState(int expected, int state) { return STATE_UPDATER.compareAndSet(this, expected, state); } public int state() { return state; } @Override public boolean isCancelled() { return state() == ST_CANCELLED; } @Override public boolean isExpired() { return state() == ST_EXPIRED; } public void expire() { if (!compareAndSetState(ST_INIT, ST_EXPIRED)) { return; } try { task.run(this, argv); } catch (Throwable t) { if (logger.isWarnEnabled()) { logger.warn("An exception was thrown by " + TimerTask.class.getSimpleName() + '.', t); } } } @Override public String toString() { final long currentTime = System.nanoTime(); long remaining = deadline - currentTime + timer.startTime; StringBuilder buf = new StringBuilder(192) .append(StringUtil.simpleClassName(this)) .append('(') .append("deadline: "); if (remaining > 0) { buf.append(remaining) .append(" ns later"); } else if (remaining < 0) { buf.append(-remaining) .append(" ns ago"); } else { buf.append("now"); } if (isCancelled()) { buf.append(", cancelled"); } return buf.append(", task: ") .append(task()) .append(')') .toString(); } } /** * Bucket that stores HashedWheelTimeouts. These are stored in a linked-list like datastructure to allow easy * removal of HashedWheelTimeouts in the middle. Also the HashedWheelTimeout act as nodes themself and so no * extra object creation is needed. */ private static final class HashedWheelBucket { // Used for the linked-list datastructure private HashedWheelTimeout head; private HashedWheelTimeout tail; /** * Add {@link HashedWheelTimeout} to this bucket. 
*/ public void addTimeout(HashedWheelTimeout timeout) { assert timeout.bucket == null; timeout.bucket = this; if (head == null) { head = tail = timeout; } else { tail.next = timeout; timeout.prev = tail; tail = timeout; } } /** * Expire all {@link HashedWheelTimeout}s for the given {@code deadline}. */ public void expireTimeouts(long deadline) { HashedWheelTimeout timeout = head; // process all timeouts while (timeout != null) { boolean remove = false; if (timeout.remainingRounds <= 0) { if (timeout.deadline <= deadline) { timeout.expire(); } else { // The timeout was placed into a wrong slot. This should never happen. throw new IllegalStateException(String.format( "timeout.deadline (%d) > deadline (%d)", timeout.deadline, deadline)); } remove = true; } else if (timeout.isCancelled()) { remove = true; } else { timeout.remainingRounds --; } // store reference to next as we may null out timeout.next in the remove block. HashedWheelTimeout next = timeout.next; if (remove) { remove(timeout); } timeout = next; } } public void remove(HashedWheelTimeout timeout) { HashedWheelTimeout next = timeout.next; // remove timeout that was either processed or cancelled by updating the linked-list if (timeout.prev != null) { timeout.prev.next = next; } if (timeout.next != null) { timeout.next.prev = timeout.prev; } if (timeout == head) { // if timeout is also the tail we need to adjust the entry too if (timeout == tail) { tail = null; head = null; } else { head = next; } } else if (timeout == tail) { // if the timeout is the tail modify the tail to be the prev node. tail = timeout.prev; } // null out prev, next and bucket to allow for GC. timeout.prev = null; timeout.next = null; timeout.bucket = null; } /** * Clear this bucket and return all not expired / cancelled {@link Timeout}s. 
*/ public void clearTimeouts(Set<Timeout> set) { for (;;) { HashedWheelTimeout timeout = pollTimeout(); if (timeout == null) { return; } if (timeout.isExpired() || timeout.isCancelled()) { continue; } set.add(timeout); } } private HashedWheelTimeout pollTimeout() { HashedWheelTimeout head = this.head; if (head == null) { return null; } HashedWheelTimeout next = head.next; if (next == null) { tail = this.head = null; } else { this.head = next; next.prev = null; } // null out prev and next to allow for GC. head.next = null; head.prev = null; head.bucket = null; return head; } } }
编写测试类Main.java
package com.tanghuachun.timer;

import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;

/**
 * Created by darren on 2016/11/17.
 *
 * Demo: schedules ten tasks on a {@link TimerWheel}, each firing after five
 * seconds, waits for all of them to run and then stops the timer.
 */
public class Main implements TimerTask {

    /** Number of demo tasks scheduled by {@link #main(String[])}. */
    private static final int TASK_COUNT = 10;

    final static Timer timer = new TimerWheel();

    /** Counted down once per executed task so that main() knows when to stop the timer. */
    private static final CountDownLatch pending = new CountDownLatch(TASK_COUNT);

    /**
     * Schedules the demo tasks and blocks until all of them have run.
     *
     * @param args unused
     */
    public static void main(String[] args) {
        TimerTask timerTask = new Main();
        for (int i = 0; i < TASK_COUNT; i++) {
            timer.newTimeout(timerTask, 5, TimeUnit.SECONDS, "" + i);
        }

        try {
            pending.await();
        } catch (InterruptedException e) {
            // Preserve the interrupt status and fall through to the shutdown below.
            Thread.currentThread().interrupt();
        } finally {
            // The default TimerWheel uses Executors.defaultThreadFactory(), whose
            // threads are non-daemon: without stop() the JVM would never exit.
            timer.stop();
        }
    }

    /**
     * Prints the argument the task was scheduled with.
     *
     * @param timeout the handle associated with this task
     * @param argv    the string supplied when the task was scheduled
     */
    @Override
    public void run(Timeout timeout, String argv) throws Exception {
        try {
            System.out.println("timeout, argv = " + argv);
        } finally {
            pending.countDown();
        }
    }
}
然后就可以看到运行结果啦。
工程代码下载(以maven的方式导入)。
以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持脚本之家。