diff --git a/src/main/java/io/papermc/restamp/recipe/MethodATMutator.java b/src/main/java/io/papermc/restamp/recipe/MethodATMutator.java index 84c3366..7c17469 100644 --- a/src/main/java/io/papermc/restamp/recipe/MethodATMutator.java +++ b/src/main/java/io/papermc/restamp/recipe/MethodATMutator.java @@ -12,12 +12,14 @@ import org.cadixdev.bombe.type.VoidType; import org.cadixdev.bombe.type.signature.MethodSignature; import org.jspecify.annotations.NullMarked; +import org.jspecify.annotations.Nullable; import org.openrewrite.ExecutionContext; import org.openrewrite.Recipe; import org.openrewrite.TreeVisitor; import org.openrewrite.java.JavaIsoVisitor; import org.openrewrite.java.tree.J; import org.openrewrite.java.tree.JavaType; +import org.openrewrite.java.tree.JavaType.FullyQualified; import org.openrewrite.java.tree.TypeTree; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -34,6 +36,7 @@ public class MethodATMutator extends Recipe { private static final Logger LOGGER = LoggerFactory.getLogger(MethodATMutator.class); private final AccessTransformSet atDictionary; + private final AccessTransformSet inheritanceAccessTransformAtDirectory; private final ModifierTransformer modifierTransformer; private final AccessTransformerTypeConverter atTypeConverter; @@ -43,6 +46,12 @@ public MethodATMutator(final AccessTransformSet atDictionary, this.atDictionary = atDictionary; this.modifierTransformer = modifierTransformer; this.atTypeConverter = atTypeConverter; + + // Create a copy of the atDirectory for inherited at lookups. + // Needed as the parent type may be processed first, removing its access transformer for tracking purposes. + // Child types hence lookup using this. + this.inheritanceAccessTransformAtDirectory = AccessTransformSet.create(); + this.inheritanceAccessTransformAtDirectory.merge(this.atDictionary); } @Override @@ -67,12 +76,6 @@ public J.MethodDeclaration visitMethodDeclaration(final J.MethodDeclaration unre if (parentClassDeclaration == null || parentClassDeclaration.getType() == null) return methodDeclaration; - // Find access transformers for class - final AccessTransformSet.Class transformerClass = atDictionary.getClass( - parentClassDeclaration.getType().getFullyQualifiedName() - ).orElse(null); - if (transformerClass == null) return methodDeclaration; - final String methodIdentifier = parentClassDeclaration.getType().getFullyQualifiedName() + "#" + methodDeclaration.getName(); if (methodDeclaration.getMethodType() == null) { @@ -80,7 +83,7 @@ public J.MethodDeclaration visitMethodDeclaration(final J.MethodDeclaration unre return methodDeclaration; } - // Fetch access transformer to apply to specific field. + // Fetch access transformer to apply to specific method. String atMethodName = methodDeclaration.getMethodType().getName(); Type returnType = atTypeConverter.convert(methodDeclaration.getMethodType().getReturnType(), () -> "Parsing return type " + methodDeclaration.getReturnTypeExpression().toString() + " of method " + methodIdentifier); @@ -101,10 +104,14 @@ public J.MethodDeclaration visitMethodDeclaration(final J.MethodDeclaration unre returnType = VoidType.INSTANCE; } - final AccessTransform accessTransform = transformerClass.replaceMethod(new MethodSignature( - atMethodName, new MethodDescriptor(parameterTypes, returnType) - ), AccessTransform.EMPTY); - if (accessTransform == null || accessTransform.isEmpty()) return methodDeclaration; + // Find access transformers for method + final AccessTransform accessTransform = findApplicableAccessTransformer( + parentClassDeclaration.getType(), + atMethodName, + returnType, + parameterTypes + ); + if (accessTransform == null) return methodDeclaration; final TypeTree returnTypeExpression = methodDeclaration.getReturnTypeExpression(); final ModifierTransformationResult transformationResult = modifierTransformer.transformModifiers( @@ -125,4 +132,48 @@ atMethodName, new MethodDescriptor(parameterTypes, returnType) }; } + /** + * Finds the applicable access transformer for a method and *optionally* removes it from the atDirectory. + * + * @param owningType the owning type of the method, e.g. the type it is defined in. + * @param atMethodName the method name. + * @param returnType the return type. + * @param parameterTypes the method parameters. + * + * @return the access transformer or null. + */ + @Nullable + private AccessTransform findApplicableAccessTransformer( + final FullyQualified owningType, + final String atMethodName, + final Type returnType, + final List parameterTypes + ) { + final MethodSignature methodSignature = new MethodSignature( + atMethodName, + new MethodDescriptor(parameterTypes, returnType) + ); + + for (FullyQualified currentCheckedType = owningType; currentCheckedType != null; currentCheckedType = currentCheckedType.getSupertype()) { + // The class at data from the copy of the at dir. + // Removal of these happens later but we need the original state to ensure overrides are updated. + final AccessTransformSet.Class transformerClass = inheritanceAccessTransformAtDirectory + .getClass(currentCheckedType.getFullyQualifiedName()) + .orElse(null); + if (transformerClass == null) continue; + + // Only get the method here. + final AccessTransform accessTransform = transformerClass.getMethod(methodSignature); + if (accessTransform == null || accessTransform.isEmpty()) continue; + + // If we *did* find an AT here and this *is* the direct owning type, remove it from the original atDirectory. + if (currentCheckedType == owningType) { + atDictionary.getClass(transformerClass.getName()).ifPresent(c -> c.replaceMethod(methodSignature, AccessTransform.EMPTY)); + } + return accessTransform; + } + + return null; // We did not find anything applicable. + } + } diff --git a/src/test/java/io/papermc/restamp/RestampFunctionTestHelper.java b/src/test/java/io/papermc/restamp/RestampFunctionTestHelper.java index f1cd353..dc2d69a 100644 --- a/src/test/java/io/papermc/restamp/RestampFunctionTestHelper.java +++ b/src/test/java/io/papermc/restamp/RestampFunctionTestHelper.java @@ -19,6 +19,7 @@ import org.openrewrite.java.tree.Space; import org.openrewrite.marker.Markers; +import java.util.Arrays; import java.util.Collections; import java.util.List; import java.util.stream.Stream; @@ -33,16 +34,16 @@ public class RestampFunctionTestHelper { * Constructs a new restamp input object from a single java class' source in a string. * * @param accessTransformSet the access transformers to apply. - * @param javaClassSource the source code of a java class. + * @param javaClassesSource the source code of a java class. * * @return the constructed restamp input. */ public static RestampInput inputFromSourceString(final AccessTransformSet accessTransformSet, - final String javaClassSource) { + final String... javaClassesSource) { final Java21Parser javaParser = Java21Parser.builder().build(); final InMemoryExecutionContext executionContext = new InMemoryExecutionContext(t -> Assertions.fail("Failed to parse inputs", t)); final List sourceFiles = javaParser.parseInputs( - List.of(Parser.Input.fromString(javaClassSource)), + Arrays.stream(javaClassesSource).map(Parser.Input::fromString).toList(), null, executionContext ).toList(); diff --git a/src/test/java/io/papermc/restamp/at/InheritanceMethodATTest.java b/src/test/java/io/papermc/restamp/at/InheritanceMethodATTest.java new file mode 100644 index 0000000..0d1a0c9 --- /dev/null +++ b/src/test/java/io/papermc/restamp/at/InheritanceMethodATTest.java @@ -0,0 +1,65 @@ +package io.papermc.restamp.at; + +import io.papermc.restamp.Restamp; +import io.papermc.restamp.RestampFunctionTestHelper; +import io.papermc.restamp.RestampInput; +import org.cadixdev.at.AccessTransform; +import org.cadixdev.at.AccessTransformSet; +import org.cadixdev.bombe.type.signature.MethodSignature; +import org.jspecify.annotations.NullMarked; +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.Test; +import org.openrewrite.Result; + +import java.util.List; + +@NullMarked +public class InheritanceMethodATTest { + + @Test + public void testInheritedATs() { + final AccessTransformSet accessTransformSet = AccessTransformSet.create(); + accessTransformSet.getOrCreateClass("io.papermc.test.Test").replaceMethod( + MethodSignature.of("test", "(Ljava.lang.Object;)Ljava.lang.String;"), AccessTransform.PUBLIC + ); + + final RestampInput input = RestampFunctionTestHelper.inputFromSourceString( + accessTransformSet, + """ + package io.papermc.test; + + public class Test { + protected String test(final Object parameter) { + return "hi there"; + } + } + """, + """ + package io.papermc.test; + + public class SuperTest extends Test { + @Override + protected String test(final Object parameter) { + return "hi there but better"; + } + } + """ + ); + + final List results = Restamp.run(input).getAllResults(); + Assertions.assertEquals( + """ + package io.papermc.test; + + public class SuperTest extends Test { + @Override + public String test(final Object parameter) { + return "hi there but better"; + } + } + """, + results.get(1).getAfter().printAll() + ); + } + +}