Sunday, July 15, 2012

Concurrency Tests with JUnit

Sometimes in the life of a programmer, a concurrency test is needed to check that a given framework function does not corrupt data when it is used simultaneously by different threads.

JUnit offers something close to AOP through the @Rule annotation, which lets us modify how a given test runs. I was trying to create this kind of test for my DTO framework (jDTO Binder) and, inspired by this blog post, I created a test rule that lets the programmer specify the number of threads that run a given test case and collects some basic statistics about the running time of the test.

First of all, I created an implementation of TestRule:

import java.util.concurrent.ArrayBlockingQueue;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ThreadPoolExecutor;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;

import org.junit.Assert;
import org.junit.rules.TestRule;
import org.junit.runner.Description;
import org.junit.runners.model.Statement;
// SLF4J is assumed for logging.
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class ConcurrencyRule implements TestRule {

    private static final Logger logger = LoggerFactory.getLogger(ConcurrencyRule.class);

    private final int threadsSpawn;

    /**
     * Builds a concurrency test with the given number of worker threads.
     * @param threadsSpawn the number of threads that run the test concurrently.
     */
    public ConcurrencyRule(int threadsSpawn) {
        this.threadsSpawn = threadsSpawn;
    }

    @Override
    public Statement apply(final Statement base, Description description) {
        //wrap the test so the threads run when JUnit evaluates it, not while the rule
        //is applied; returning base itself would run the test once more afterwards.
        return new Statement() {
            @Override
            public void evaluate() throws Throwable {
                ArrayBlockingQueue<Runnable> queue = new ArrayBlockingQueue<Runnable>(threadsSpawn + 1);
                //build a new executor
                ThreadPoolExecutor executor = new ThreadPoolExecutor(threadsSpawn, threadsSpawn, 1, TimeUnit.HOURS, queue);
                //build the count down latch to synchronize the start.
                final CountDownLatch startLatch = new CountDownLatch(1);
                //build the count down latch to synchronize the end.
                final CountDownLatch endLatch = new CountDownLatch(threadsSpawn);
                //atomic flag, set to true as soon as any thread fails.
                final AtomicBoolean failed = new AtomicBoolean(false);
                logger.info("Scheduling " + threadsSpawn + " threads for testing...");
                try {
                    //iterate until all the threads are created
                    for (int i = 0; i < threadsSpawn; i++) {
                        executor.execute(new Runnable() {
                            @Override
                            public void run() {
                                try {
                                    startLatch.await();
                                    base.evaluate();
                                } catch (Throwable ex) {
                                    logger.error("Error while executing test", ex);
                                    failed.set(true);
                                } finally {
                                    endLatch.countDown();
                                }
                            }
                        });
                    }
                    //raise the flag and allow everyone to start.
                    long startTime = System.currentTimeMillis();
                    startLatch.countDown();
                    logger.info("Finished scheduling threads, now waiting for the results...");
                    endLatch.await();
                    long endTime = System.currentTimeMillis();
                    //compute the running time
                    logger.info("Running Time: " + (endTime - startTime) + " millis");
                    logger.info("All threads have finished.");
                } catch (InterruptedException ex) {
                    logger.error("Error while executing concurrency test");
                    //restore the interrupt status before giving up.
                    Thread.currentThread().interrupt();
                    throw new RuntimeException(ex);
                } finally {
                    //stop the pool so the worker threads do not outlive the test.
                    executor.shutdown();
                }
                //if some of the tests have failed, then fail all!
                if (failed.get()) {
                    Assert.fail("Some of the child threads have failed the test!");
                }
            }
        };
    }
}

The code above may be used by anyone who wants to create concurrency tests and cares more about the overall result of the tests than about the individual exceptions, which are only logged.

The code does the following:

  1. Create a ThreadPoolExecutor sized to the number of test threads; you may use another executor implementation if you prefer.
  2. Create a CountDownLatch to make sure all the tests start at almost the same time.
  3. Create a CountDownLatch to block the test thread until all the tests finish.
  4. Submit all the required tasks to the executor; each task waits on the start latch until all of them have been scheduled.
  5. If the test fails in any thread, set a flag to true.
  6. Log the execution time.
  7. Fail the test if any of the threads threw an exception.
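
The two-latch pattern in the steps above does not depend on JUnit, so it can be tried on its own. What follows is a minimal sketch, assuming Java 8 or later for the lambdas; the class LatchHarness and its method runConcurrently are hypothetical names made up for this illustration and are not part of the rule or of any library.

```java
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.atomic.AtomicInteger;

/** Hypothetical helper for illustration only: runs one task on many threads at once. */
final class LatchHarness {

    /** Runs the task on the given number of threads and returns how many runs threw. */
    static int runConcurrently(int threads, Runnable task) throws InterruptedException {
        ExecutorService executor = Executors.newFixedThreadPool(threads);
        CountDownLatch startLatch = new CountDownLatch(1);
        CountDownLatch endLatch = new CountDownLatch(threads);
        AtomicInteger failures = new AtomicInteger();
        try {
            for (int i = 0; i < threads; i++) {
                executor.execute(() -> {
                    try {
                        startLatch.await(); // park until every worker has been scheduled
                        task.run();
                    } catch (Throwable ex) {
                        failures.incrementAndGet(); // count the failure instead of logging it
                    } finally {
                        endLatch.countDown();
                    }
                });
            }
            startLatch.countDown(); // release all the workers together
            endLatch.await();       // block until every worker has finished
        } finally {
            executor.shutdown();
        }
        return failures.get();
    }

    public static void main(String[] args) throws InterruptedException {
        AtomicInteger counter = new AtomicInteger();
        int failures = runConcurrently(8, () -> {
            for (int i = 0; i < 1000; i++) {
                counter.incrementAndGet();
            }
        });
        // prints: failures=0, counter=8000
        System.out.println("failures=" + failures + ", counter=" + counter.get());
    }
}
```

The counter reaches exactly 8000 because AtomicInteger is thread safe; replacing it with a plain int field is a quick way to see the kind of lost update such a harness is meant to expose.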

So at this point we're ready to use the rule in a simple concurrency test:

import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotSame;

import java.util.ArrayList;
import java.util.List;

import org.junit.BeforeClass;
import org.junit.Rule;
import org.junit.Test;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
// DTOBinder, DTOBinderFactory and GeneralPurposeEntity come from jDTO Binder
// and its test sources; import them from wherever they live in your project.

public class TestConcurrencyStability {

    private static final Logger logger = LoggerFactory.getLogger(TestConcurrencyStability.class);

    private static DTOBinder binder;

    @BeforeClass
    public static void globalInit() {
        binder = DTOBinderFactory.buildBinder();
    }

    /**
     * The test will run in parallel on 100 threads.
     */
    @Rule
    public ConcurrencyRule rule = new ConcurrencyRule(100);

    @Test
    public void testCase() {
        String threadName = Thread.currentThread().getName();
        logger.info(threadName);
        ArrayList<GeneralPurposeEntity> simpleEntities = new ArrayList<GeneralPurposeEntity>();
        //add 100 entities to the source collection
        for (int i = 0; i < 100; i++) {
            simpleEntities.add(new GeneralPurposeEntity(threadName + " " + i, null, null, 0.0, null, i));
        }
        List<GeneralPurposeEntity> dtos = binder.bindFromBusinessObjectList(GeneralPurposeEntity.class, simpleEntities);
        assertEquals("Lists should have the same size", simpleEntities.size(), dtos.size());
        //check every relevant field
        for (int i = 0; i < simpleEntities.size(); i++) {
            GeneralPurposeEntity source = simpleEntities.get(i);
            GeneralPurposeEntity target = dtos.get(i);
            assertNotSame("Objects should not be the same", source, target);
            assertEquals("The String should be equal", source.getTheString(), target.getTheString());
            assertEquals("The int should be equal", source.getTheInt(), target.getTheInt());
        }
    }
}

And just like that, we have a concurrency test with some statistics on the running time.
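
The rule only records that some thread failed and leaves the exceptions in the log. If you do care about the individual exceptions, one option is to gather them in a thread-safe collection and inspect them once every thread is done. The sketch below shows that idea with plain threads, assuming Java 8 or later; FailureCollector and runAndCollect are hypothetical names used only for this example and do not belong to the rule or to JUnit.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.Queue;
import java.util.concurrent.ConcurrentLinkedQueue;
import java.util.concurrent.CountDownLatch;

/** Hypothetical helper for illustration only: keeps every exception thrown by the workers. */
final class FailureCollector {

    /** Runs the task on the given number of threads and returns everything they threw. */
    static List<Throwable> runAndCollect(int threads, Runnable task) throws InterruptedException {
        Queue<Throwable> errors = new ConcurrentLinkedQueue<>();
        CountDownLatch done = new CountDownLatch(threads);
        for (int i = 0; i < threads; i++) {
            new Thread(() -> {
                try {
                    task.run();
                } catch (Throwable ex) {
                    errors.add(ex); // safe to call from many threads at once
                } finally {
                    done.countDown();
                }
            }).start();
        }
        done.await(); // wait until every worker has finished
        return new ArrayList<>(errors);
    }

    public static void main(String[] args) throws InterruptedException {
        List<Throwable> errors = runAndCollect(3, () -> {
            throw new IllegalStateException("boom");
        });
        // prints: collected 3 failure(s)
        System.out.println("collected " + errors.size() + " failure(s)");
    }
}
```

Inside a rule, the same queue could take the place of the boolean flag, and the first collected exception could be rethrown so that the JUnit report shows the real cause of the failure.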