Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Inheritance Support #10

Merged
merged 1 commit into from
Dec 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 62 additions & 11 deletions src/main/java/io/papermc/restamp/recipe/MethodATMutator.java
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,14 @@
import org.cadixdev.bombe.type.VoidType;
import org.cadixdev.bombe.type.signature.MethodSignature;
import org.jspecify.annotations.NullMarked;
import org.jspecify.annotations.Nullable;
import org.openrewrite.ExecutionContext;
import org.openrewrite.Recipe;
import org.openrewrite.TreeVisitor;
import org.openrewrite.java.JavaIsoVisitor;
import org.openrewrite.java.tree.J;
import org.openrewrite.java.tree.JavaType;
import org.openrewrite.java.tree.JavaType.FullyQualified;
import org.openrewrite.java.tree.TypeTree;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
Expand All @@ -34,6 +36,7 @@ public class MethodATMutator extends Recipe {
private static final Logger LOGGER = LoggerFactory.getLogger(MethodATMutator.class);

private final AccessTransformSet atDictionary;
private final AccessTransformSet inheritanceAccessTransformAtDirectory;
private final ModifierTransformer modifierTransformer;
private final AccessTransformerTypeConverter atTypeConverter;

Expand All @@ -43,6 +46,12 @@ public MethodATMutator(final AccessTransformSet atDictionary,
this.atDictionary = atDictionary;
this.modifierTransformer = modifierTransformer;
this.atTypeConverter = atTypeConverter;

// Create a copy of the atDirectory for inherited at lookups.
// Needed as the parent type may be processed first, removing its access transformer for tracking purposes.
// Child types hence lookup using this.
this.inheritanceAccessTransformAtDirectory = AccessTransformSet.create();
this.inheritanceAccessTransformAtDirectory.merge(this.atDictionary);
}

@Override
Expand All @@ -67,20 +76,14 @@ public J.MethodDeclaration visitMethodDeclaration(final J.MethodDeclaration unre
if (parentClassDeclaration == null || parentClassDeclaration.getType() == null)
return methodDeclaration;

// Find access transformers for class
final AccessTransformSet.Class transformerClass = atDictionary.getClass(
parentClassDeclaration.getType().getFullyQualifiedName()
).orElse(null);
if (transformerClass == null) return methodDeclaration;

final String methodIdentifier = parentClassDeclaration.getType().getFullyQualifiedName() + "#" + methodDeclaration.getName();

if (methodDeclaration.getMethodType() == null) {
LOGGER.warn("Method {} did not have a method type!", methodIdentifier);
return methodDeclaration;
}

// Fetch access transformer to apply to specific field.
// Fetch access transformer to apply to specific method.
String atMethodName = methodDeclaration.getMethodType().getName();
Type returnType = atTypeConverter.convert(methodDeclaration.getMethodType().getReturnType(),
() -> "Parsing return type " + methodDeclaration.getReturnTypeExpression().toString() + " of method " + methodIdentifier);
Expand All @@ -101,10 +104,14 @@ public J.MethodDeclaration visitMethodDeclaration(final J.MethodDeclaration unre
returnType = VoidType.INSTANCE;
}

final AccessTransform accessTransform = transformerClass.replaceMethod(new MethodSignature(
atMethodName, new MethodDescriptor(parameterTypes, returnType)
), AccessTransform.EMPTY);
if (accessTransform == null || accessTransform.isEmpty()) return methodDeclaration;
// Find access transformers for method
final AccessTransform accessTransform = findApplicableAccessTransformer(
parentClassDeclaration.getType(),
atMethodName,
returnType,
parameterTypes
);
if (accessTransform == null) return methodDeclaration;

final TypeTree returnTypeExpression = methodDeclaration.getReturnTypeExpression();
final ModifierTransformationResult transformationResult = modifierTransformer.transformModifiers(
Expand All @@ -125,4 +132,48 @@ atMethodName, new MethodDescriptor(parameterTypes, returnType)
};
}

/**
* Finds the applicable access transformer for a method and *optionally* removes it from the atDirectory.
*
* @param owningType the owning type of the method, e.g. the type it is defined in.
* @param atMethodName the method name.
* @param returnType the return type.
* @param parameterTypes the method parameters.
*
* @return the access transformer or null.
*/
@Nullable
private AccessTransform findApplicableAccessTransformer(
final FullyQualified owningType,
final String atMethodName,
final Type returnType,
final List<FieldType> parameterTypes
) {
final MethodSignature methodSignature = new MethodSignature(
atMethodName,
new MethodDescriptor(parameterTypes, returnType)
);

for (FullyQualified currentCheckedType = owningType; currentCheckedType != null; currentCheckedType = currentCheckedType.getSupertype()) {
// The class at data from the copy of the at dir.
// Removal of these happens later but we need the original state to ensure overrides are updated.
final AccessTransformSet.Class transformerClass = inheritanceAccessTransformAtDirectory
.getClass(currentCheckedType.getFullyQualifiedName())
.orElse(null);
if (transformerClass == null) continue;

// Only get the method here.
final AccessTransform accessTransform = transformerClass.getMethod(methodSignature);
if (accessTransform == null || accessTransform.isEmpty()) continue;

// If we *did* find an AT here and this *is* the direct owning type, remove it from the original atDirectory.
if (currentCheckedType == owningType) {
atDictionary.getClass(transformerClass.getName()).ifPresent(c -> c.replaceMethod(methodSignature, AccessTransform.EMPTY));
}
return accessTransform;
}

return null; // We did not find anything applicable.
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import org.openrewrite.java.tree.Space;
import org.openrewrite.marker.Markers;

import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.stream.Stream;
Expand All @@ -33,16 +34,16 @@ public class RestampFunctionTestHelper {
* Constructs a new restamp input object from a single java class' source in a string.
*
* @param accessTransformSet the access transformers to apply.
* @param javaClassSource the source code of a java class.
* @param javaClassesSource the source code of a java class.
*
* @return the constructed restamp input.
*/
public static RestampInput inputFromSourceString(final AccessTransformSet accessTransformSet,
final String javaClassSource) {
final String... javaClassesSource) {
final Java21Parser javaParser = Java21Parser.builder().build();
final InMemoryExecutionContext executionContext = new InMemoryExecutionContext(t -> Assertions.fail("Failed to parse inputs", t));
final List<SourceFile> sourceFiles = javaParser.parseInputs(
List.of(Parser.Input.fromString(javaClassSource)),
Arrays.stream(javaClassesSource).map(Parser.Input::fromString).toList(),
null,
executionContext
).toList();
Expand Down
65 changes: 65 additions & 0 deletions src/test/java/io/papermc/restamp/at/InheritanceMethodATTest.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
package io.papermc.restamp.at;

import io.papermc.restamp.Restamp;
import io.papermc.restamp.RestampFunctionTestHelper;
import io.papermc.restamp.RestampInput;
import org.cadixdev.at.AccessTransform;
import org.cadixdev.at.AccessTransformSet;
import org.cadixdev.bombe.type.signature.MethodSignature;
import org.jspecify.annotations.NullMarked;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;
import org.openrewrite.Result;

import java.util.List;

@NullMarked
public class InheritanceMethodATTest {

@Test
public void testInheritedATs() {
final AccessTransformSet accessTransformSet = AccessTransformSet.create();
accessTransformSet.getOrCreateClass("io.papermc.test.Test").replaceMethod(
MethodSignature.of("test", "(Ljava.lang.Object;)Ljava.lang.String;"), AccessTransform.PUBLIC
);

final RestampInput input = RestampFunctionTestHelper.inputFromSourceString(
accessTransformSet,
"""
package io.papermc.test;

public class Test {
protected String test(final Object parameter) {
return "hi there";
}
}
""",
"""
package io.papermc.test;

public class SuperTest extends Test {
@Override
protected String test(final Object parameter) {
return "hi there but better";
}
}
"""
);

final List<Result> results = Restamp.run(input).getAllResults();
Assertions.assertEquals(
"""
package io.papermc.test;

public class SuperTest extends Test {
@Override
public String test(final Object parameter) {
return "hi there but better";
}
}
""",
results.get(1).getAfter().printAll()
);
}

}
Loading