Skip to content

Commit 7205f05

Browse files
committed
Allow @BeforeAll and @afterall methods to be non-static
Issue: #419
1 parent 551cf90 commit 7205f05

File tree

5 files changed

+92
-40
lines changed

5 files changed

+92
-40
lines changed

junit-jupiter-engine/src/main/java/org/junit/jupiter/engine/descriptor/ClassTestDescriptor.java

+24-16
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,7 @@ public class ClassTestDescriptor extends JupiterTestDescriptor {
6868
private static final ExecutableInvoker executableInvoker = new ExecutableInvoker();
6969

7070
private final Class<?> testClass;
71+
private final Lifecycle lifecycle;
7172

7273
private final List<Method> beforeAllMethods;
7374
private final List<Method> afterAllMethods;
@@ -85,9 +86,10 @@ protected ClassTestDescriptor(UniqueId uniqueId, Function<Class<?>, String> defa
8586
defaultDisplayNameGenerator));
8687

8788
this.testClass = testClass;
89+
this.lifecycle = getTestInstanceLifecycle(testClass);
8890

89-
this.beforeAllMethods = findBeforeAllMethods(testClass);
90-
this.afterAllMethods = findAfterAllMethods(testClass);
91+
this.beforeAllMethods = findBeforeAllMethods(testClass, this.lifecycle == Lifecycle.PER_METHOD);
92+
this.afterAllMethods = findAfterAllMethods(testClass, this.lifecycle == Lifecycle.PER_METHOD);
9193
this.beforeEachMethods = findBeforeEachMethods(testClass);
9294
this.afterEachMethods = findAfterEachMethods(testClass);
9395

@@ -204,9 +206,11 @@ private void invokeBeforeAllMethods(JupiterEngineExecutionContext context) {
204206
ExtensionRegistry registry = context.getExtensionRegistry();
205207
ContainerExtensionContext extensionContext = (ContainerExtensionContext) context.getExtensionContext();
206208
ThrowableCollector throwableCollector = context.getThrowableCollector();
209+
Object testInstance = getTestInstanceForClassLevelCallbacks(context);
207210

208211
for (Method method : this.beforeAllMethods) {
209-
throwableCollector.execute(() -> executableInvoker.invoke(method, extensionContext, registry));
212+
throwableCollector.execute(
213+
() -> executableInvoker.invoke(method, testInstance, extensionContext, registry));
210214
if (throwableCollector.isNotEmpty()) {
211215
break;
212216
}
@@ -217,9 +221,15 @@ private void invokeAfterAllMethods(JupiterEngineExecutionContext context) {
217221
ExtensionRegistry registry = context.getExtensionRegistry();
218222
ContainerExtensionContext extensionContext = (ContainerExtensionContext) context.getExtensionContext();
219223
ThrowableCollector throwableCollector = context.getThrowableCollector();
224+
Object testInstance = getTestInstanceForClassLevelCallbacks(context);
220225

221-
this.afterAllMethods.forEach(
222-
method -> throwableCollector.execute(() -> executableInvoker.invoke(method, extensionContext, registry)));
226+
this.afterAllMethods.forEach(method -> throwableCollector.execute(
227+
() -> executableInvoker.invoke(method, testInstance, extensionContext, registry)));
228+
}
229+
230+
private Object getTestInstanceForClassLevelCallbacks(JupiterEngineExecutionContext context) {
231+
return this.lifecycle == Lifecycle.PER_CLASS
232+
? context.getTestInstanceProvider().getTestInstance(Optional.empty()) : null;
223233
}
224234

225235
private void invokeAfterAllCallbacks(JupiterEngineExecutionContext context) {
@@ -274,7 +284,6 @@ private void invokeMethodInTestExtensionContext(Method method, TestExtensionCont
274284
private final class LifecycleAwareTestInstanceProvider implements TestInstanceProvider {
275285

276286
private final Class<?> testClass;
277-
private final Lifecycle lifecycle;
278287
private final ExtensionRegistry registry;
279288
private final ExtensionContext extensionContext;
280289
private Object testInstance;
@@ -283,14 +292,13 @@ private final class LifecycleAwareTestInstanceProvider implements TestInstancePr
283292
ExtensionContext extensionContext) {
284293

285294
this.testClass = testClass;
286-
this.lifecycle = getInstanceLifecycle(testClass);
287295
this.registry = registry;
288296
this.extensionContext = extensionContext;
289297
}
290298

291299
@Override
292-
public Object getTestInstance(Optional<ExtensionRegistry> childExtensionRegistry) throws Exception {
293-
if (this.lifecycle == Lifecycle.PER_METHOD) {
300+
public Object getTestInstance(Optional<ExtensionRegistry> childExtensionRegistry) {
301+
if (ClassTestDescriptor.this.lifecycle == Lifecycle.PER_METHOD) {
294302
return createTestInstance(childExtensionRegistry);
295303
}
296304

@@ -310,14 +318,14 @@ private Object createTestInstance(Optional<ExtensionRegistry> childExtensionRegi
310318
return instance;
311319
}
312320

313-
private TestInstance.Lifecycle getInstanceLifecycle(Class<?> testClass) {
314-
// @formatter:off
315-
return AnnotationUtils.findAnnotation(testClass, TestInstance.class)
316-
.map(TestInstance::value)
317-
.orElse(Lifecycle.PER_METHOD);
318-
// @formatter:on
319-
}
321+
}
320322

323+
private static TestInstance.Lifecycle getTestInstanceLifecycle(Class<?> testClass) {
324+
// @formatter:off
325+
return AnnotationUtils.findAnnotation(testClass, TestInstance.class)
326+
.map(TestInstance::value)
327+
.orElse(Lifecycle.PER_METHOD);
328+
// @formatter:on
321329
}
322330

323331
}

junit-jupiter-engine/src/main/java/org/junit/jupiter/engine/descriptor/LifecycleMethodUtils.java

+11-6
Original file line numberDiff line numberDiff line change
@@ -37,15 +37,19 @@ private LifecycleMethodUtils() {
3737
}
3838
///CLOVER:ON
3939

40-
static List<Method> findBeforeAllMethods(Class<?> testClass) {
40+
static List<Method> findBeforeAllMethods(Class<?> testClass, boolean requireStatic) {
4141
List<Method> methods = findAnnotatedMethods(testClass, BeforeAll.class, HierarchyTraversalMode.TOP_DOWN);
42-
methods.forEach(method -> assertStatic(BeforeAll.class, method));
42+
if (requireStatic) {
43+
methods.forEach(method -> assertStatic(BeforeAll.class, method));
44+
}
4345
return methods;
4446
}
4547

46-
static List<Method> findAfterAllMethods(Class<?> testClass) {
48+
static List<Method> findAfterAllMethods(Class<?> testClass, boolean requireStatic) {
4749
List<Method> methods = findAnnotatedMethods(testClass, AfterAll.class, HierarchyTraversalMode.BOTTOM_UP);
48-
methods.forEach(method -> assertStatic(AfterAll.class, method));
50+
if (requireStatic) {
51+
methods.forEach(method -> assertStatic(AfterAll.class, method));
52+
}
4953
return methods;
5054
}
5155

@@ -63,8 +67,9 @@ static List<Method> findAfterEachMethods(Class<?> testClass) {
6367

6468
private static void assertStatic(Class<? extends Annotation> annotationType, Method method) {
6569
if (!ReflectionUtils.isStatic(method)) {
66-
throw new JUnitException(String.format("@%s method '%s' must be static.", annotationType.getSimpleName(),
67-
method.toGenericString()));
70+
throw new JUnitException(String.format(
71+
"@%s method '%s' must be static unless the test class is annotated with @TestInstance(Lifecycle.PER_CLASS).",
72+
annotationType.getSimpleName(), method.toGenericString()));
6873
}
6974
}
7075

junit-jupiter-engine/src/main/java/org/junit/jupiter/engine/execution/TestInstanceProvider.java

+1-1
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,6 @@
2424
@API(Internal)
2525
public interface TestInstanceProvider {
2626

27-
Object getTestInstance(Optional<ExtensionRegistry> childExtensionRegistry) throws Exception;
27+
Object getTestInstance(Optional<ExtensionRegistry> childExtensionRegistry);
2828

2929
}

junit-jupiter-engine/src/test/java/org/junit/jupiter/engine/NestedTestClassesTests.java

+1-1
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,7 @@ void failing() {
115115
}
116116
}
117117

118-
static private class TestCaseWithDoubleNesting {
118+
private static class TestCaseWithDoubleNesting {
119119

120120
static int beforeTopCount = 0;
121121
static int beforeNestedCount = 0;

junit-jupiter-engine/src/test/java/org/junit/jupiter/engine/TestInstanceLifecycleTests.java

+55-16
Original file line numberDiff line numberDiff line change
@@ -10,12 +10,16 @@
1010

1111
package org.junit.jupiter.engine;
1212

13+
import static org.junit.Assert.assertNotNull;
1314
import static org.junit.jupiter.api.Assertions.assertAll;
1415
import static org.junit.jupiter.api.Assertions.assertEquals;
1516

17+
import org.junit.jupiter.api.AfterAll;
1618
import org.junit.jupiter.api.AfterEach;
19+
import org.junit.jupiter.api.BeforeAll;
1720
import org.junit.jupiter.api.BeforeEach;
1821
import org.junit.jupiter.api.Test;
22+
import org.junit.jupiter.api.TestInfo;
1923
import org.junit.jupiter.api.TestInstance;
2024
import org.junit.jupiter.api.TestInstance.Lifecycle;
2125
import org.junit.platform.engine.test.event.ExecutionEventRecorder;
@@ -33,38 +37,48 @@
3337
class TestInstanceLifecycleTests extends AbstractJupiterTestEngineTests {
3438

3539
private static int instanceCount;
36-
private static int beforeCount;
37-
private static int afterCount;
40+
private static int beforeAllCount;
41+
private static int afterAllCount;
42+
private static int beforeEachCount;
43+
private static int afterEachCount;
3844

3945
@BeforeEach
4046
void init() {
4147
instanceCount = 0;
42-
beforeCount = 0;
43-
afterCount = 0;
48+
beforeAllCount = 0;
49+
afterAllCount = 0;
50+
beforeEachCount = 0;
51+
afterEachCount = 0;
4452
}
4553

4654
@Test
4755
void instancePerMethod() {
48-
performAssertions(InstancePerMethodTestCase.class, 2);
56+
performAssertions(InstancePerMethodTestCase.class, 2, 1, 1);
4957
}
5058

5159
@Test
5260
void instancePerClass() {
53-
performAssertions(InstancePerClassTestCase.class, 1);
61+
performAssertions(InstancePerClassTestCase.class, 1, 2, 2);
5462
}
5563

56-
private void performAssertions(Class<?> testClass, int expectedInstanceCount) {
64+
private void performAssertions(Class<?> testClass, int expectedInstanceCount, int expectedBeforeAllCount,
65+
int expectedAfterAllCount) {
66+
5767
ExecutionEventRecorder eventRecorder = executeTestsForClass(testClass);
5868

69+
// eventRecorder.eventStream().forEach(System.out::println);
70+
5971
// @formatter:off
6072
assertAll(
61-
() -> assertEquals(expectedInstanceCount, instanceCount, "instance count"),
6273
() -> assertEquals(2, eventRecorder.getContainerStartedCount(), "# containers started"),
6374
() -> assertEquals(2, eventRecorder.getContainerFinishedCount(), "# containers finished"),
6475
() -> assertEquals(2, eventRecorder.getTestStartedCount(), "# tests started"),
6576
() -> assertEquals(2, eventRecorder.getTestSuccessfulCount(), "# tests succeeded"),
66-
() -> assertEquals(2, beforeCount, "# before calls"),
67-
() -> assertEquals(2, afterCount, "# after calls")
77+
() -> assertEquals(expectedInstanceCount, instanceCount, "instance count"),
78+
() -> assertEquals(expectedBeforeAllCount, beforeAllCount, "@BeforeAll count"),
79+
() -> assertEquals(expectedAfterAllCount, afterAllCount, "@AfterAll count"),
80+
() -> assertEquals(2, beforeEachCount, "@BeforeEach count"),
81+
() -> assertEquals(2, afterEachCount, "@AfterEach count")
6882
);
6983
// @formatter:on
7084
}
@@ -77,14 +91,15 @@ private static class InstancePerMethodTestCase {
7791
instanceCount++;
7892
}
7993

80-
@BeforeEach
81-
void before() {
82-
beforeCount++;
94+
@BeforeAll
95+
static void beforeAllStatic(TestInfo testInfo) {
96+
assertNotNull(testInfo);
97+
beforeAllCount++;
8398
}
8499

85-
@AfterEach
86-
void after() {
87-
afterCount++;
100+
@BeforeEach
101+
void beforeEach() {
102+
beforeEachCount++;
88103
}
89104

90105
@Test
@@ -95,10 +110,34 @@ void test1() {
95110
void test2() {
96111
}
97112

113+
@AfterEach
114+
void afterEach() {
115+
afterEachCount++;
116+
}
117+
118+
@AfterAll
119+
static void afterAllStatic(TestInfo testInfo) {
120+
assertNotNull(testInfo);
121+
afterAllCount++;
122+
}
123+
98124
}
99125

100126
@TestInstance(Lifecycle.PER_CLASS)
101127
private static class InstancePerClassTestCase extends InstancePerMethodTestCase {
128+
129+
@BeforeAll
130+
void beforeAll(TestInfo testInfo) {
131+
assertNotNull(testInfo);
132+
beforeAllCount++;
133+
}
134+
135+
@AfterAll
136+
void afterAll(TestInfo testInfo) {
137+
assertNotNull(testInfo);
138+
afterAllCount++;
139+
}
140+
102141
}
103142

104143
}

0 commit comments

Comments
 (0)