Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[compiler] Delete LoweredFunction.dependencies and hoisted instructions #32096

Draft
wants to merge 4 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
154 changes: 12 additions & 142 deletions compiler/packages/babel-plugin-react-compiler/src/HIR/BuildHIR.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@

import {NodePath, Scope} from '@babel/traverse';
import * as t from '@babel/types';
import {Expression} from '@babel/types';
import invariant from 'invariant';
import {
CompilerError,
Expand Down Expand Up @@ -75,7 +74,7 @@ export function lower(
parent: NodePath<t.Function> | null = null,
): Result<HIRFunction, CompilerError> {
const builder = new HIRBuilder(env, parent ?? func, bindings, capturedRefs);
const context: Array<Place> = [];
const context: HIRFunction['context'] = [];

for (const ref of capturedRefs ?? []) {
context.push({
Expand Down Expand Up @@ -3377,7 +3376,7 @@ function lowerFunction(
>,
): LoweredFunction | null {
const componentScope: Scope = builder.parentFunction.scope;
const captured = gatherCapturedDeps(builder, expr, componentScope);
const capturedContext = gatherCapturedContext(expr, componentScope);

/*
* TODO(gsn): In the future, we could only pass in the context identifiers
Expand All @@ -3391,7 +3390,7 @@ function lowerFunction(
expr,
builder.environment,
builder.bindings,
[...builder.context, ...captured.identifiers],
[...builder.context, ...capturedContext],
builder.parentFunction,
);
let loweredFunc: HIRFunction;
Expand All @@ -3404,7 +3403,6 @@ function lowerFunction(
loweredFunc = lowering.unwrap();
return {
func: loweredFunc,
dependencies: captured.refs,
};
}

Expand Down Expand Up @@ -4078,14 +4076,6 @@ function lowerAssignment(
}
}

function isValidDependency(path: NodePath<t.MemberExpression>): boolean {
const parent: NodePath<t.Node> = path.parentPath;
return (
!path.node.computed &&
!(parent.isCallExpression() && parent.get('callee') === path)
);
}

function captureScopes({from, to}: {from: Scope; to: Scope}): Set<Scope> {
let scopes: Set<Scope> = new Set();
while (from) {
Expand All @@ -4100,19 +4090,16 @@ function captureScopes({from, to}: {from: Scope; to: Scope}): Set<Scope> {
return scopes;
}

function gatherCapturedDeps(
builder: HIRBuilder,
function gatherCapturedContext(
fn: NodePath<
| t.FunctionExpression
| t.ArrowFunctionExpression
| t.FunctionDeclaration
| t.ObjectMethod
>,
componentScope: Scope,
): {identifiers: Array<t.Identifier>; refs: Array<Place>} {
const capturedIds: Map<t.Identifier, number> = new Map();
const capturedRefs: Set<Place> = new Set();
const seenPaths: Set<string> = new Set();
): Array<t.Identifier> {
const capturedIds = new Set<t.Identifier>();

/*
* Capture all the scopes from the parent of this function up to and including
Expand All @@ -4123,33 +4110,11 @@ function gatherCapturedDeps(
to: componentScope,
});

function addCapturedId(bindingIdentifier: t.Identifier): number {
if (!capturedIds.has(bindingIdentifier)) {
const index = capturedIds.size;
capturedIds.set(bindingIdentifier, index);
return index;
} else {
return capturedIds.get(bindingIdentifier)!;
}
}

function handleMaybeDependency(
path:
| NodePath<t.MemberExpression>
| NodePath<t.Identifier>
| NodePath<t.JSXOpeningElement>,
path: NodePath<t.Identifier> | NodePath<t.JSXOpeningElement>,
): void {
// Base context variable to depend on
let baseIdentifier: NodePath<t.Identifier> | NodePath<t.JSXIdentifier>;
/*
* Base expression to depend on, which (for now) may contain non side-effectful
* member expressions
*/
let dependency:
| NodePath<t.MemberExpression>
| NodePath<t.JSXMemberExpression>
| NodePath<t.Identifier>
| NodePath<t.JSXIdentifier>;
if (path.isJSXOpeningElement()) {
const name = path.get('name');
if (!(name.isJSXMemberExpression() || name.isJSXIdentifier())) {
Expand All @@ -4165,115 +4130,20 @@ function gatherCapturedDeps(
'Invalid logic in gatherCapturedDeps',
);
baseIdentifier = current;

/*
* Get the expression to depend on, which may involve PropertyLoads
* for member expressions
*/
let currentDep:
| NodePath<t.JSXMemberExpression>
| NodePath<t.Identifier>
| NodePath<t.JSXIdentifier> = baseIdentifier;

while (true) {
const nextDep: null | NodePath<t.Node> = currentDep.parentPath;
if (nextDep && nextDep.isJSXMemberExpression()) {
currentDep = nextDep;
} else {
break;
}
}
dependency = currentDep;
} else if (path.isMemberExpression()) {
// Calculate baseIdentifier
let currentId: NodePath<Expression> = path;
while (currentId.isMemberExpression()) {
currentId = currentId.get('object');
}
if (!currentId.isIdentifier()) {
return;
}
baseIdentifier = currentId;

/*
* Get the expression to depend on, which may involve PropertyLoads
* for member expressions
*/
let currentDep:
| NodePath<t.MemberExpression>
| NodePath<t.Identifier>
| NodePath<t.JSXIdentifier> = baseIdentifier;

while (true) {
const nextDep: null | NodePath<t.Node> = currentDep.parentPath;
if (
nextDep &&
nextDep.isMemberExpression() &&
isValidDependency(nextDep)
) {
currentDep = nextDep;
} else {
break;
}
}

dependency = currentDep;
} else {
baseIdentifier = path;
dependency = path;
}

/*
* Skip dependency path, as we already tried to recursively add it (+ all subexpressions)
* as a dependency.
*/
dependency.skip();
path.skip();

// Add the base identifier binding as a dependency.
const binding = baseIdentifier.scope.getBinding(baseIdentifier.node.name);
if (binding === undefined || !pureScopes.has(binding.scope)) {
return;
}
const idKey = String(addCapturedId(binding.identifier));

// Add the expression (potentially a memberexpr path) as a dependency.
let exprKey = idKey;
if (dependency.isMemberExpression()) {
let pathTokens = [];
let current: NodePath<Expression> = dependency;
while (current.isMemberExpression()) {
const property = current.get('property') as NodePath<t.Identifier>;
pathTokens.push(property.node.name);
current = current.get('object');
}

exprKey += '.' + pathTokens.reverse().join('.');
} else if (dependency.isJSXMemberExpression()) {
let pathTokens = [];
let current: NodePath<t.JSXMemberExpression | t.JSXIdentifier> =
dependency;
while (current.isJSXMemberExpression()) {
const property = current.get('property');
pathTokens.push(property.node.name);
current = current.get('object');
}
}

if (!seenPaths.has(exprKey)) {
let loweredDep: Place;
if (dependency.isJSXIdentifier()) {
loweredDep = lowerValueToTemporary(builder, {
kind: 'LoadLocal',
place: lowerIdentifier(builder, dependency),
loc: path.node.loc ?? GeneratedSource,
});
} else if (dependency.isJSXMemberExpression()) {
loweredDep = lowerJsxMemberExpression(builder, dependency);
} else {
loweredDep = lowerExpressionToTemporary(builder, dependency);
}
capturedRefs.add(loweredDep);
seenPaths.add(exprKey);
if (binding !== undefined && pureScopes.has(binding.scope)) {
capturedIds.add(binding.identifier);
}
}

Expand Down Expand Up @@ -4304,13 +4174,13 @@ function gatherCapturedDeps(
return;
} else if (path.isJSXElement()) {
handleMaybeDependency(path.get('openingElement'));
} else if (path.isMemberExpression() || path.isIdentifier()) {
} else if (path.isIdentifier()) {
handleMaybeDependency(path);
}
},
});

return {identifiers: [...capturedIds.keys()], refs: [...capturedRefs]};
return [...capturedIds.keys()];
}

function notNull<T>(value: T | null): value is T {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -131,15 +131,7 @@ function collectHoistablePropertyLoadsImpl(
fn: HIRFunction,
context: CollectHoistablePropertyLoadsContext,
): ReadonlyMap<BlockId, BlockInfo> {
const functionExpressionLoads = collectFunctionExpressionFakeLoads(fn);
const actuallyEvaluatedTemporaries = new Map(
[...context.temporaries].filter(([id]) => !functionExpressionLoads.has(id)),
);

const nodes = collectNonNullsInBlocks(fn, {
...context,
temporaries: actuallyEvaluatedTemporaries,
});
const nodes = collectNonNullsInBlocks(fn, context);
propagateNonNull(fn, nodes, context.registry);

if (DEBUG_PRINT) {
Expand Down Expand Up @@ -598,30 +590,3 @@ function reduceMaybeOptionalChains(
}
} while (changed);
}

function collectFunctionExpressionFakeLoads(
fn: HIRFunction,
): Set<IdentifierId> {
const sources = new Map<IdentifierId, IdentifierId>();
const functionExpressionReferences = new Set<IdentifierId>();

for (const [_, block] of fn.body.blocks) {
for (const {lvalue, value} of block.instructions) {
if (
value.kind === 'FunctionExpression' ||
value.kind === 'ObjectMethod'
) {
for (const reference of value.loweredFunc.dependencies) {
let curr: IdentifierId | undefined = reference.identifier.id;
while (curr != null) {
functionExpressionReferences.add(curr);
curr = sources.get(curr);
}
}
} else if (value.kind === 'PropertyLoad') {
sources.set(lvalue.identifier.id, value.object.identifier.id);
}
}
}
return functionExpressionReferences;
}
Original file line number Diff line number Diff line change
Expand Up @@ -245,8 +245,6 @@ const EnvironmentConfigSchema = z.object({
*/
enableUseTypeAnnotations: z.boolean().default(false),

enableFunctionDependencyRewrite: z.boolean().default(true),

/**
* Enables inference of optional dependency chains. Without this flag
* a property chain such as `props?.items?.foo` will infer as a dep on
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -722,7 +722,6 @@ export type ObjectProperty = {
};

export type LoweredFunction = {
dependencies: Array<Place>;
func: HIRFunction;
};

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -275,6 +275,13 @@ function getOverlappingReactiveScopes(
for (const instr of block.instructions) {
visitInstructionId(instr.id, context, state);
for (const place of eachInstructionOperand(instr)) {
if (
(instr.value.kind === 'FunctionExpression' ||
instr.value.kind === 'ObjectMethod') &&
place.identifier.type.kind === 'Primitive'
) {
continue;
}
visitPlace(instr.id, place, state);
}
for (const place of eachInstructionLValue(instr)) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -545,8 +545,16 @@ addObject(BUILTIN_SHAPES, BuiltInMixedReadonlyId, [
[
'map',
addFunction(BUILTIN_SHAPES, [], {
/**
* Note `map`'s arguments are annotated as Effect.ConditionallyMutate as
* calling `<array>.map(fn)` might invoke `fn`, which means replaying its
* effects.
*
* (Note that Effect.Read / Effect.Capture on a function type means
* potential data dependency or aliasing respectively.)
*/
positionalParams: [],
restParam: Effect.Read,
restParam: Effect.ConditionallyMutate,
returnType: {kind: 'Object', shapeId: BuiltInArrayId},
calleeEffect: Effect.ConditionallyMutate,
returnValueKind: ValueKind.Mutable,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -538,11 +538,8 @@ export function printInstructionValue(instrValue: ReactiveValue): string {
.split('\n')
.map(line => ` ${line}`)
.join('\n');
const deps = instrValue.loweredFunc.dependencies
.map(dep => printPlace(dep))
.join(',');
const context = instrValue.loweredFunc.func.context
.map(dep => printPlace(dep))
.map(dep => `${printPlace(dep)}`)
.join(',');
const effects =
instrValue.loweredFunc.func.effects
Expand All @@ -557,7 +554,7 @@ export function printInstructionValue(instrValue: ReactiveValue): string {
})
.join(', ') ?? '';
const type = printType(instrValue.loweredFunc.func.returnType).trim();
value = `${kind} ${name} @deps[${deps}] @context[${context}] @effects[${effects}]${type !== '' ? ` return${type}` : ''}:\n${fn}`;
value = `${kind} ${name} @context[${context}] @effects[${effects}]${type !== '' ? ` return${type}` : ''}:\n${fn}`;
break;
}
case 'TaggedTemplateExpression': {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -738,9 +738,8 @@ function collectDependencies(
}
for (const instr of block.instructions) {
if (
fn.env.config.enableFunctionDependencyRewrite &&
(instr.value.kind === 'FunctionExpression' ||
instr.value.kind === 'ObjectMethod')
instr.value.kind === 'FunctionExpression' ||
instr.value.kind === 'ObjectMethod'
) {
context.declare(instr.lvalue.identifier, {
id: instr.id,
Expand Down
Loading
Loading