Sunday, November 20, 2011

How to categorize JUnit test methods and filter them for execution

I was looking for a way to categorize test methods and select them flexibly for execution.

The closest thing I found was the article by Romain Linsolas which was very helpful for me: http://linsolas.free.fr/wordpress/index.php/2011/02/how-to-categorize-junit-tests-with-maven/

Romain's requirement was to categorize test classes and run a subset of them.

My test classes map directly to their implementation classes using the naming convention <ClassName>Test.java. Some of the test methods in the same class run very fast, some are slow, and others require a network connection. So I needed the ability to categorize tests at the method level and to find test classes by a file name convention (similar to the way the Maven Surefire plugin does it).

My approach combines Romain's approach with JUnit's Categories class. The main modification is that a test suite can be annotated with a package name (@TestScanPackage), a class name prefix (@TestClassPrefix), a class name suffix (@TestClassSuffix) and a test method annotation (@TestMethodAnnotation); together they are used to scan the class path for matching test classes. Test methods can also be annotated with multiple categories (e.g. slow and requiring an internet connection).
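The @TestMethodAnnotation check boils down to plain reflection: a class is only kept if at least one of its methods carries the given annotation. Here is a minimal, self-contained sketch of that idea. The annotation MyCheck and the classes WithCheck and WithoutCheck are hypothetical stand-ins (in the runner itself the annotation would default to org.junit.Test). Note that Class.getMethods() only returns public methods, inherited ones included, which fits JUnit 4's requirement that test methods be public.

```java
import java.lang.annotation.Annotation;
import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;
import java.lang.reflect.Method;

public class MethodAnnotationCheckSketch {

    // Hypothetical stand-in for org.junit.Test; RUNTIME retention is required,
    // otherwise the annotation is invisible to reflection.
    @Retention(RetentionPolicy.RUNTIME)
    @Target(ElementType.METHOD)
    public @interface MyCheck {}

    public static class WithCheck {
        @MyCheck
        public void someTest() {}
    }

    public static class WithoutCheck {
        public void helper() {}
    }

    /** Returns true if at least one public method (inherited ones included) carries the annotation. */
    public static boolean hasAnnotatedMethod(Class<?> clazz, Class<? extends Annotation> annotation) {
        for (Method method : clazz.getMethods()) {
            if (method.isAnnotationPresent(annotation)) {
                return true;
            }
        }
        return false;
    }

    public static void main(String[] args) {
        System.out.println(hasAnnotatedMethod(WithCheck.class, MyCheck.class));    // true
        System.out.println(hasAnnotatedMethod(WithoutCheck.class, MyCheck.class)); // false
    }
}
```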

Here's a description of the relevant files:
  • SlowTestCategory.java: Category class to mark slow tests.
  • OnlineTestCategory.java: Category class to mark tests that require an internet connection.
  • SampleTest.java: Example JUnit test class that uses the categories above via the standard JUnit @Category annotation.
  • MyTestSuite.java: Example test suite that uses FlexibleCategories as its test runner.
  • FlexibleCategories.java: Test runner that does all the magic.
  • PatternClasspathClassesFinder.java: Helper class for FlexibleCategories to find all classes in the classpath which match the annotations (@TestScanPackage, @TestClassPrefix, @TestClassSuffix, @TestMethodAnnotation)
If you find this useful, you may also be interested in the issue I filed for JUnit: https://github.com/KentBeck/junit/issues/363

Here is an example that shows how to use it ...

SlowTestCategory.java

/** This category marks slow tests. */
public interface SlowTestCategory {
}

OnlineTestCategory.java

/** This category marks tests that require an internet connection. */
public interface OnlineTestCategory {
}

SampleTest.java

import org.junit.Test;
import org.junit.experimental.categories.Category;

public class SampleTest {
 @Test
 @Category({OnlineTestCategory.class, SlowTestCategory.class})
 public void onlineAndSlowTestCategoryMethod() {
 }

 @Test
 @Category(OnlineTestCategory.class)
 public void onlineTestCategoryMethod() {
 }

 @Test
 @Category(SlowTestCategory.class)
 public void slowTestCategoryMethod() {
 }

 @Test
 public void noTestCategoryMethod() {
 }
}
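To see which of these four methods the suite below actually runs, here is a simplified, stdlib-only model of the include/exclude rule that JUnit 4.10's CategoryFilter applies to a single method, as I understand it. It is a sketch, not JUnit's code: class-level @Category annotations are ignored, and the nested interfaces only mirror the category interfaces above. Because the check uses isAssignableFrom, a method tagged with a sub-interface of SlowTestCategory would also count as slow.

```java
import java.util.Arrays;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

public class CategoryRuleSketch {

    // Stand-ins mirroring the category interfaces from the post.
    public interface SlowTestCategory {}
    public interface OnlineTestCategory {}

    /** Simplified include/exclude decision for one test method (class-level categories ignored). */
    public static boolean shouldRun(List<Class<?>> categories, Class<?> included, Class<?> excluded) {
        if (categories.isEmpty()) {
            // Uncategorized methods only run when no include category is configured.
            return included == null;
        }
        for (Class<?> each : categories) {
            // Exclusion wins: any excluded category (or sub-type of it) drops the method.
            if (excluded != null && excluded.isAssignableFrom(each)) {
                return false;
            }
        }
        for (Class<?> each : categories) {
            if (included == null || included.isAssignableFrom(each)) {
                return true;
            }
        }
        return false;
    }

    public static void main(String[] args) {
        // The categories of the four methods in SampleTest.
        Map<String, List<Class<?>>> methods = new LinkedHashMap<>();
        methods.put("onlineAndSlowTestCategoryMethod",
                Arrays.<Class<?>>asList(OnlineTestCategory.class, SlowTestCategory.class));
        methods.put("onlineTestCategoryMethod", Arrays.<Class<?>>asList(OnlineTestCategory.class));
        methods.put("slowTestCategoryMethod", Arrays.<Class<?>>asList(SlowTestCategory.class));
        methods.put("noTestCategoryMethod", Arrays.<Class<?>>asList());

        // Same configuration as MyTestSuite: include slow tests, exclude online tests.
        for (Map.Entry<String, List<Class<?>>> entry : methods.entrySet()) {
            boolean runs = shouldRun(entry.getValue(), SlowTestCategory.class, OnlineTestCategory.class);
            System.out.println(entry.getKey() + " -> " + runs);
        }
    }
}
```

With that configuration only slowTestCategoryMethod prints true: both online methods are excluded, and the uncategorized method is skipped because an include category is set.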

MyTestSuite.java

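import org.junit.experimental.categories.Categories.ExcludeCategory;
import org.junit.experimental.categories.Categories.IncludeCategory;
import org.junit.runner.RunWith;
// Also import the nested annotations TestScanPackage, TestClassPrefix and TestClassSuffix
// from FlexibleCategories (e.g. import <your package>.FlexibleCategories.TestScanPackage;).

/** MyTestSuite runs all slow tests, excluding all tests which require a network connection. */
@RunWith(FlexibleCategories.class)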
@ExcludeCategory(OnlineTestCategory.class)
@IncludeCategory(SlowTestCategory.class)
@TestScanPackage("my.package")
@TestClassPrefix("")
@TestClassSuffix("Test")
public class MyTestSuite {
}

FlexibleCategories.java

import java.lang.annotation.Annotation;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;

import org.junit.Test;
import org.junit.experimental.categories.Categories.CategoryFilter;
import org.junit.experimental.categories.Categories.ExcludeCategory;
import org.junit.experimental.categories.Categories.IncludeCategory;
import org.junit.experimental.categories.Category;
import org.junit.runner.Description;
import org.junit.runner.manipulation.NoTestsRemainException;
import org.junit.runners.Suite;
import org.junit.runners.model.InitializationError;
import org.junit.runners.model.RunnerBuilder;

/**
 * This class is based on org.junit.experimental.categories.Categories from JUnit 4.10.
 *
 * All annotations and inner classes of the original Categories class have been removed,
 * since they are re-used from there.
 * Unfortunately, subclassing Categories did not work.
 */
public class FlexibleCategories extends Suite {

 /**
  * Specifies the package which should be scanned for test classes (e.g. @TestScanPackage("my.package")).
  * This annotation is required.
  */
 @Retention(RetentionPolicy.RUNTIME)
 public @interface TestScanPackage {
  public String value();
 }

 /**
  * Specifies the prefix of matching class names (e.g. @TestClassPrefix("Test")).
  * This annotation is optional (default: "").
  */
 @Retention(RetentionPolicy.RUNTIME)
 public @interface TestClassPrefix {
  public String value();
 }

 /**
  * Specifies the suffix of matching class names (e.g. @TestClassSuffix("Test")).
  * This annotation is optional (default: "Test").
  */
 @Retention(RetentionPolicy.RUNTIME)
 public @interface TestClassSuffix {
  public String value();
 }

 /**
  * Specifies an annotation for methods which must be present in a matching class (e.g. @TestMethodAnnotation(Test.class)).
  * This annotation is optional (default: org.junit.Test.class).
  */
 @Retention(RetentionPolicy.RUNTIME)
 public @interface TestMethodAnnotation {
  public Class<? extends Annotation> value();
 }

 public FlexibleCategories(Class<?> clazz, RunnerBuilder builder)
   throws InitializationError {
  this(builder, clazz, PatternClasspathClassesFinder.getSuiteClasses(
    getTestScanPackage(clazz), getTestClassPrefix(clazz), getTestClassSuffix(clazz),
    getTestMethodAnnotation(clazz)));
  try {
   filter(new CategoryFilter(getIncludedCategory(clazz),
     getExcludedCategory(clazz)));
  } catch (NoTestsRemainException e) {
   // No test in any of the scanned classes matches the categories:
   // run an empty suite instead of failing.
  }
  assertNoCategorizedDescendentsOfUncategorizeableParents(getDescription());
 }

 public FlexibleCategories(RunnerBuilder builder, Class<?> clazz,
   Class<?>[] suiteClasses) throws InitializationError {
  super(builder, clazz, suiteClasses);
 }

 private static String getTestScanPackage(Class<?> clazz) throws InitializationError {
  TestScanPackage annotation = clazz.getAnnotation(TestScanPackage.class);
  if (annotation == null) {
   throw new InitializationError("No package given to scan for tests!\nUse the annotation @TestScanPackage(\"my.package\") on the test suite " + clazz + ".");
  }
  return annotation.value();
 }

 private static String getTestClassPrefix(Class<?> clazz) {
  TestClassPrefix annotation = clazz.getAnnotation(TestClassPrefix.class);
  return annotation == null ? "" : annotation.value();
 }

 private static String getTestClassSuffix(Class<?> clazz) {
  TestClassSuffix annotation = clazz.getAnnotation(TestClassSuffix.class);
  return annotation == null ? "Test" : annotation.value();
 }

 private static Class<? extends Annotation> getTestMethodAnnotation(Class<?> clazz) {
  TestMethodAnnotation annotation = clazz.getAnnotation(TestMethodAnnotation.class);
  return annotation == null ? Test.class : annotation.value();
 }

 private Class<?> getIncludedCategory(Class<?> clazz) {
  IncludeCategory annotation= clazz.getAnnotation(IncludeCategory.class);
  return annotation == null ? null : annotation.value();
 }

 private Class<?> getExcludedCategory(Class<?> clazz) {
  ExcludeCategory annotation= clazz.getAnnotation(ExcludeCategory.class);
  return annotation == null ? null : annotation.value();
 }

 private void assertNoCategorizedDescendentsOfUncategorizeableParents(Description description) throws InitializationError {
  if (!canHaveCategorizedChildren(description))
   assertNoDescendantsHaveCategoryAnnotations(description);
  for (Description each : description.getChildren())
   assertNoCategorizedDescendentsOfUncategorizeableParents(each);
 }

 private void assertNoDescendantsHaveCategoryAnnotations(Description description) throws InitializationError {
  for (Description each : description.getChildren()) {
   if (each.getAnnotation(Category.class) != null)
    throw new InitializationError("Category annotations on Parameterized classes are not supported on individual methods.");
   assertNoDescendantsHaveCategoryAnnotations(each);
  }
 }

 // If children have names like [0], our current magical category code can't determine their
 // parentage.
 private static boolean canHaveCategorizedChildren(Description description) {
  for (Description each : description.getChildren())
   if (each.getTestClass() == null)
    return false;
  return true;
 }
}
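The getTestClassPrefix(), getTestClassSuffix() and getTestMethodAnnotation() helpers above fall back to hard-coded defaults when an annotation is missing. One possible refinement, sketched here under the assumption that you are free to change the annotation declarations, is to declare the default with Java's default clause and restrict the annotation to classes with @Target. The annotation below is a hypothetical variant of TestClassSuffix, not the runner's actual declaration. The null check is still needed, because getAnnotation() returns null for suites that omit the annotation entirely.

```java
import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;

public class AnnotationDefaultsSketch {

    // Hypothetical variant: the default lives in the annotation itself,
    // and @Target makes the compiler reject it on methods or fields.
    @Retention(RetentionPolicy.RUNTIME)
    @Target(ElementType.TYPE)
    public @interface TestClassSuffix {
        String value() default "Test";
    }

    @TestClassSuffix // no argument needed thanks to the default
    public static class SuiteUsingDefault {}

    @TestClassSuffix("IT")
    public static class SuiteWithCustomSuffix {}

    public static class SuiteWithoutAnnotation {}

    /** Resolves the suffix like getTestClassSuffix(): an absent annotation still needs a fallback. */
    public static String suffixOf(Class<?> suite) {
        TestClassSuffix annotation = suite.getAnnotation(TestClassSuffix.class);
        return annotation == null ? "Test" : annotation.value();
    }

    public static void main(String[] args) {
        System.out.println(suffixOf(SuiteUsingDefault.class));      // Test
        System.out.println(suffixOf(SuiteWithCustomSuffix.class));  // IT
        System.out.println(suffixOf(SuiteWithoutAnnotation.class)); // Test
    }
}
```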

PatternClasspathClassesFinder.java

import java.io.File;
import java.io.IOException;
import java.lang.annotation.Annotation;
import java.lang.reflect.Method;
import java.net.URL;
import java.util.ArrayList;
import java.util.Enumeration;
import java.util.List;

/**
 *
 * Modified version of ClasspathClassesFinder from:
 * http://linsolas.free.fr/wordpress/index.php/2011/02/how-to-categorize-junit-tests-with-maven/
 *
 * The difference is that it does not search for annotated classes, but for classes with a
 * certain class name prefix and suffix that contain at least one method with a given annotation.
 */
public final class PatternClasspathClassesFinder {

 /**
  * Get the classes in a given package (including sub-packages) whose names match
  * the given prefix and suffix and which contain at least one method annotated
  * with the given annotation.
  *
  * @param packageName
  *            The package name of the classes.
  * @param classPrefix
  *            The prefix of the class name.
  * @param classSuffix
  *            The suffix of the class name.
  * @param methodAnnotation
  *            Only return classes containing methods annotated with methodAnnotation.
  * @return The classes that match the requirements.
  * @throws IllegalStateException
  *             If the classpath could not be scanned.
  */
 public static Class<?>[] getSuiteClasses(String packageName,
   String classPrefix, String classSuffix,
   Class<? extends Annotation> methodAnnotation) {
  try {
   return getClasses(packageName, classPrefix, classSuffix, methodAnnotation);
  } catch (Exception e) {
   // Fail loudly: returning null here would typically surface later
   // as an obscure NullPointerException in the Suite constructor.
   throw new IllegalStateException("Could not scan package " + packageName + " for test classes", e);
  }
 }

 /**
  * Get the classes in a given package (including sub-packages) whose names match
  * the given prefix and suffix and which contain at least one method annotated
  * with the given annotation.
  *
  * @param packageName
  *            The package name of the classes.
  * @param classPrefix
  *            The prefix of the class name.
  * @param classSuffix
  *            The suffix of the class name.
  * @param methodAnnotation
  *            Only return classes containing methods annotated with methodAnnotation.
  * @return The classes that match the requirements.
  * @throws ClassNotFoundException
  *             If a matching class file cannot be loaded.
  * @throws IOException
  *             If the classpath resources cannot be read.
  * @throws java.net.URISyntaxException
  *             If a classpath URL cannot be converted to a URI.
  */
 private static Class<?>[] getClasses(String packageName,
   String classPrefix, String classSuffix,
   Class<? extends Annotation> methodAnnotation)
   throws ClassNotFoundException, IOException, java.net.URISyntaxException {
  ClassLoader classLoader = Thread.currentThread()
    .getContextClassLoader();
  String path = packageName.replace('.', '/');
  // Get all classpath entries that contain the package directory.
  Enumeration<URL> resources = classLoader.getResources(path);
  List<File> dirs = new ArrayList<File>();
  while (resources.hasMoreElements()) {
   URL resource = resources.nextElement();
   // Only directories on the file system are scanned; packages inside JAR files are skipped.
   // Going through toURI() decodes escaped characters such as %20 (spaces) in the path.
   if ("file".equals(resource.getProtocol())) {
    dirs.add(new File(resource.toURI()));
   }
  }
  // For each classpath directory, get the classes.
  ArrayList<Class<?>> classes = new ArrayList<Class<?>>();
  for (File directory : dirs) {
   classes.addAll(findClasses(directory, packageName, classPrefix, classSuffix, methodAnnotation));
  }
  return classes.toArray(new Class<?>[classes.size()]);
 }

 /**
  * Find classes in a given directory (recursively) for a given package name,
  * whose names match the given prefix and suffix and which contain at least one
  * method annotated with the given annotation.
  *
  * @param directory
  *            The directory to search in.
  * @param packageName
  *            The package name of the classes.
  * @param classPrefix
  *            The prefix of the class name.
  * @param classSuffix
  *            The suffix of the class name.
  * @param methodAnnotation
  *            Only return classes containing methods annotated with methodAnnotation.
  * @return The list of classes that match the requirements.
  * @throws ClassNotFoundException
  *             If a matching class file cannot be loaded.
  */
 private static List<Class<?>> findClasses(File directory,
   String packageName, String classPrefix, String classSuffix,
   Class<? extends Annotation> methodAnnotation)
   throws ClassNotFoundException {
  List<Class<?>> classes = new ArrayList<Class<?>>();
  if (!directory.exists()) {
   return classes;
  }
  File[] files = directory.listFiles();
  for (File file : files) {
   if (file.isDirectory()) {
    classes.addAll(findClasses(file,
      packageName + "." + file.getName(), classPrefix, classSuffix, methodAnnotation));
   } else if (file.getName().startsWith(classPrefix) && file.getName().endsWith(classSuffix + ".class")) {
    // We remove the .class at the end of the filename to get the
    // class name...
    Class<?> clazz = Class.forName(packageName
      + '.'
      + file.getName().substring(0,
        file.getName().length() - 6));

    // Only keep classes with at least one test method (prevents JUnit's "No runnable methods" error):
    boolean classHasTest = false;
    for (Method method : clazz.getMethods()) {
     if (method.getAnnotation(methodAnnotation) != null) {
      classHasTest = true;
      break;
     }
    }
    if (classHasTest) {
     classes.add(clazz);
    }
   }
  }
  return classes;
 }
}
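For comparison, the name-matching part of findClasses can also be written with java.nio.file (Java 8 or later). This sketch only derives fully qualified class names from file names; it neither loads the classes nor checks for test methods, and like the original it only handles class directories, not JAR files. The package name my.tests and the file names in main() are made up for the demo.

```java
import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.List;
import java.util.stream.Collectors;
import java.util.stream.Stream;

public class ClassNameScanSketch {

    /**
     * Recursively walks packageDir and returns the fully qualified names of all .class files
     * whose file name starts with prefix and ends with suffix + ".class", sorted alphabetically.
     */
    public static List<String> matchingClassNames(Path packageDir, String packageName,
            String prefix, String suffix) throws IOException {
        try (Stream<Path> paths = Files.walk(packageDir)) {
            return paths
                    .filter(Files::isRegularFile)
                    .filter(p -> {
                        String fileName = p.getFileName().toString();
                        return fileName.startsWith(prefix) && fileName.endsWith(suffix + ".class");
                    })
                    .map(p -> {
                        // e.g. "sub/BarTest.class" -> "my.tests.sub.BarTest"
                        String relative = packageDir.relativize(p).toString();
                        String withoutExtension = relative.substring(0, relative.length() - ".class".length());
                        String separator = p.getFileSystem().getSeparator();
                        return packageName + "." + withoutExtension.replace(separator, ".");
                    })
                    .sorted()
                    .collect(Collectors.toList());
        }
    }

    public static void main(String[] args) throws IOException {
        // Build a throw-away directory tree that looks like a compiled package.
        Path packageDir = Files.createTempDirectory("scan-sketch");
        Files.createFile(packageDir.resolve("FooTest.class"));
        Files.createFile(packageDir.resolve("Foo.class"));        // wrong suffix
        Files.createFile(packageDir.resolve("FooTest$1.class"));  // anonymous inner class
        Files.createDirectories(packageDir.resolve("sub"));
        Files.createFile(packageDir.resolve("sub").resolve("BarTest.class"));

        // Prints [my.tests.FooTest, my.tests.sub.BarTest]
        System.out.println(matchingClassNames(packageDir, "my.tests", "", "Test"));
    }
}
```

As in the original, sub-directories become sub-packages, and anonymous inner classes such as FooTest$1 drop out naturally because their file names do not end with the suffix.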