Skip to content

Commit

Permalink
[compiler] Delete LoweredFunction.dependencies and hoisted instructions
Browse files Browse the repository at this point in the history
LoweredFunction dependencies were exclusively used for dependency extraction (in `propagateScopeDeps`). Now that we have a `propagateScopeDepsHIR` that recursively traverses into nested functions, we can delete `dependencies` and their associated artificial `LoadLocal`/`PropertyLoad` instructions.

'
Summary:

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:
Summary:

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:
Summary:

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:
Summary:

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:
  • Loading branch information
mofeiZ committed Jan 16, 2025
1 parent eede39a commit 3819d3a
Show file tree
Hide file tree
Showing 46 changed files with 516 additions and 475 deletions.
154 changes: 12 additions & 142 deletions compiler/packages/babel-plugin-react-compiler/src/HIR/BuildHIR.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@

import {NodePath, Scope} from '@babel/traverse';
import * as t from '@babel/types';
import {Expression} from '@babel/types';
import invariant from 'invariant';
import {
CompilerError,
Expand Down Expand Up @@ -75,7 +74,7 @@ export function lower(
parent: NodePath<t.Function> | null = null,
): Result<HIRFunction, CompilerError> {
const builder = new HIRBuilder(env, parent ?? func, bindings, capturedRefs);
const context: Array<Place> = [];
const context: HIRFunction['context'] = [];

for (const ref of capturedRefs ?? []) {
context.push({
Expand Down Expand Up @@ -3377,7 +3376,7 @@ function lowerFunction(
>,
): LoweredFunction | null {
const componentScope: Scope = builder.parentFunction.scope;
const captured = gatherCapturedDeps(builder, expr, componentScope);
const capturedContext = gatherCapturedContext(expr, componentScope);

/*
* TODO(gsn): In the future, we could only pass in the context identifiers
Expand All @@ -3391,7 +3390,7 @@ function lowerFunction(
expr,
builder.environment,
builder.bindings,
[...builder.context, ...captured.identifiers],
[...builder.context, ...capturedContext],
builder.parentFunction,
);
let loweredFunc: HIRFunction;
Expand All @@ -3404,7 +3403,6 @@ function lowerFunction(
loweredFunc = lowering.unwrap();
return {
func: loweredFunc,
dependencies: captured.refs,
};
}

Expand Down Expand Up @@ -4078,14 +4076,6 @@ function lowerAssignment(
}
}

function isValidDependency(path: NodePath<t.MemberExpression>): boolean {
const parent: NodePath<t.Node> = path.parentPath;
return (
!path.node.computed &&
!(parent.isCallExpression() && parent.get('callee') === path)
);
}

function captureScopes({from, to}: {from: Scope; to: Scope}): Set<Scope> {
let scopes: Set<Scope> = new Set();
while (from) {
Expand All @@ -4100,19 +4090,16 @@ function captureScopes({from, to}: {from: Scope; to: Scope}): Set<Scope> {
return scopes;
}

function gatherCapturedDeps(
builder: HIRBuilder,
function gatherCapturedContext(
fn: NodePath<
| t.FunctionExpression
| t.ArrowFunctionExpression
| t.FunctionDeclaration
| t.ObjectMethod
>,
componentScope: Scope,
): {identifiers: Array<t.Identifier>; refs: Array<Place>} {
const capturedIds: Map<t.Identifier, number> = new Map();
const capturedRefs: Set<Place> = new Set();
const seenPaths: Set<string> = new Set();
): Array<t.Identifier> {
const capturedIds = new Set<t.Identifier>();

/*
* Capture all the scopes from the parent of this function up to and including
Expand All @@ -4123,33 +4110,11 @@ function gatherCapturedDeps(
to: componentScope,
});

function addCapturedId(bindingIdentifier: t.Identifier): number {
if (!capturedIds.has(bindingIdentifier)) {
const index = capturedIds.size;
capturedIds.set(bindingIdentifier, index);
return index;
} else {
return capturedIds.get(bindingIdentifier)!;
}
}

function handleMaybeDependency(
path:
| NodePath<t.MemberExpression>
| NodePath<t.Identifier>
| NodePath<t.JSXOpeningElement>,
path: NodePath<t.Identifier> | NodePath<t.JSXOpeningElement>,
): void {
// Base context variable to depend on
let baseIdentifier: NodePath<t.Identifier> | NodePath<t.JSXIdentifier>;
/*
* Base expression to depend on, which (for now) may contain non side-effectful
* member expressions
*/
let dependency:
| NodePath<t.MemberExpression>
| NodePath<t.JSXMemberExpression>
| NodePath<t.Identifier>
| NodePath<t.JSXIdentifier>;
if (path.isJSXOpeningElement()) {
const name = path.get('name');
if (!(name.isJSXMemberExpression() || name.isJSXIdentifier())) {
Expand All @@ -4165,115 +4130,20 @@ function gatherCapturedDeps(
'Invalid logic in gatherCapturedDeps',
);
baseIdentifier = current;

/*
* Get the expression to depend on, which may involve PropertyLoads
* for member expressions
*/
let currentDep:
| NodePath<t.JSXMemberExpression>
| NodePath<t.Identifier>
| NodePath<t.JSXIdentifier> = baseIdentifier;

while (true) {
const nextDep: null | NodePath<t.Node> = currentDep.parentPath;
if (nextDep && nextDep.isJSXMemberExpression()) {
currentDep = nextDep;
} else {
break;
}
}
dependency = currentDep;
} else if (path.isMemberExpression()) {
// Calculate baseIdentifier
let currentId: NodePath<Expression> = path;
while (currentId.isMemberExpression()) {
currentId = currentId.get('object');
}
if (!currentId.isIdentifier()) {
return;
}
baseIdentifier = currentId;

/*
* Get the expression to depend on, which may involve PropertyLoads
* for member expressions
*/
let currentDep:
| NodePath<t.MemberExpression>
| NodePath<t.Identifier>
| NodePath<t.JSXIdentifier> = baseIdentifier;

while (true) {
const nextDep: null | NodePath<t.Node> = currentDep.parentPath;
if (
nextDep &&
nextDep.isMemberExpression() &&
isValidDependency(nextDep)
) {
currentDep = nextDep;
} else {
break;
}
}

dependency = currentDep;
} else {
baseIdentifier = path;
dependency = path;
}

/*
* Skip dependency path, as we already tried to recursively add it (+ all subexpressions)
* as a dependency.
*/
dependency.skip();
path.skip();

// Add the base identifier binding as a dependency.
const binding = baseIdentifier.scope.getBinding(baseIdentifier.node.name);
if (binding === undefined || !pureScopes.has(binding.scope)) {
return;
}
const idKey = String(addCapturedId(binding.identifier));

// Add the expression (potentially a memberexpr path) as a dependency.
let exprKey = idKey;
if (dependency.isMemberExpression()) {
let pathTokens = [];
let current: NodePath<Expression> = dependency;
while (current.isMemberExpression()) {
const property = current.get('property') as NodePath<t.Identifier>;
pathTokens.push(property.node.name);
current = current.get('object');
}

exprKey += '.' + pathTokens.reverse().join('.');
} else if (dependency.isJSXMemberExpression()) {
let pathTokens = [];
let current: NodePath<t.JSXMemberExpression | t.JSXIdentifier> =
dependency;
while (current.isJSXMemberExpression()) {
const property = current.get('property');
pathTokens.push(property.node.name);
current = current.get('object');
}
}

if (!seenPaths.has(exprKey)) {
let loweredDep: Place;
if (dependency.isJSXIdentifier()) {
loweredDep = lowerValueToTemporary(builder, {
kind: 'LoadLocal',
place: lowerIdentifier(builder, dependency),
loc: path.node.loc ?? GeneratedSource,
});
} else if (dependency.isJSXMemberExpression()) {
loweredDep = lowerJsxMemberExpression(builder, dependency);
} else {
loweredDep = lowerExpressionToTemporary(builder, dependency);
}
capturedRefs.add(loweredDep);
seenPaths.add(exprKey);
if (binding !== undefined && pureScopes.has(binding.scope)) {
capturedIds.add(binding.identifier);
}
}

Expand Down Expand Up @@ -4304,13 +4174,13 @@ function gatherCapturedDeps(
return;
} else if (path.isJSXElement()) {
handleMaybeDependency(path.get('openingElement'));
} else if (path.isMemberExpression() || path.isIdentifier()) {
} else if (path.isIdentifier()) {
handleMaybeDependency(path);
}
},
});

return {identifiers: [...capturedIds.keys()], refs: [...capturedRefs]};
return [...capturedIds.keys()];
}

function notNull<T>(value: T | null): value is T {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -131,15 +131,7 @@ function collectHoistablePropertyLoadsImpl(
fn: HIRFunction,
context: CollectHoistablePropertyLoadsContext,
): ReadonlyMap<BlockId, BlockInfo> {
const functionExpressionLoads = collectFunctionExpressionFakeLoads(fn);
const actuallyEvaluatedTemporaries = new Map(
[...context.temporaries].filter(([id]) => !functionExpressionLoads.has(id)),
);

const nodes = collectNonNullsInBlocks(fn, {
...context,
temporaries: actuallyEvaluatedTemporaries,
});
const nodes = collectNonNullsInBlocks(fn, context);
propagateNonNull(fn, nodes, context.registry);

if (DEBUG_PRINT) {
Expand Down Expand Up @@ -598,30 +590,3 @@ function reduceMaybeOptionalChains(
}
} while (changed);
}

function collectFunctionExpressionFakeLoads(
fn: HIRFunction,
): Set<IdentifierId> {
const sources = new Map<IdentifierId, IdentifierId>();
const functionExpressionReferences = new Set<IdentifierId>();

for (const [_, block] of fn.body.blocks) {
for (const {lvalue, value} of block.instructions) {
if (
value.kind === 'FunctionExpression' ||
value.kind === 'ObjectMethod'
) {
for (const reference of value.loweredFunc.dependencies) {
let curr: IdentifierId | undefined = reference.identifier.id;
while (curr != null) {
functionExpressionReferences.add(curr);
curr = sources.get(curr);
}
}
} else if (value.kind === 'PropertyLoad') {
sources.set(lvalue.identifier.id, value.object.identifier.id);
}
}
}
return functionExpressionReferences;
}
Original file line number Diff line number Diff line change
Expand Up @@ -245,8 +245,6 @@ const EnvironmentConfigSchema = z.object({
*/
enableUseTypeAnnotations: z.boolean().default(false),

enableFunctionDependencyRewrite: z.boolean().default(true),

/**
* Enables inference of optional dependency chains. Without this flag
* a property chain such as `props?.items?.foo` will infer as a dep on
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -722,7 +722,6 @@ export type ObjectProperty = {
};

export type LoweredFunction = {
dependencies: Array<Place>;
func: HIRFunction;
};

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -248,13 +248,21 @@ function visitPlace(
id: InstructionId,
place: Place,
{activeScopes, joined}: TraversalState,
isFnExpr: boolean,
): void {
// Here, behavior differs:

Check failure on line 253 in compiler/packages/babel-plugin-react-compiler/src/HIR/MergeOverlappingReactiveScopesHIR.ts

View workflow job for this annotation

GitHub Actions / Lint babel-plugin-react-compiler

Expected a block comment instead of consecutive line comments
// With LoweredFunction.dependencies, we never infer functions as mutating primitives
// as the layer of indirection (LoadLocal etc) ensures we don't treat this as a write
// We make the same "hack" in InferReactiveScopeVariables
/**
* If an instruction mutates an outer scope, flatten all scopes from the top
* of the stack to the mutated outer scope.
*/
const placeScope = getPlaceScope(id, place);
if (placeScope != null && isMutable({id} as any, place)) {
if (isFnExpr && place.identifier.type.kind === 'Primitive') {
return;
}
const placeScopeIdx = activeScopes.indexOf(placeScope);
if (placeScopeIdx !== -1 && placeScopeIdx !== activeScopes.length - 1) {
joined.union([placeScope, ...activeScopes.slice(placeScopeIdx + 1)]);
Expand All @@ -275,15 +283,21 @@ function getOverlappingReactiveScopes(
for (const instr of block.instructions) {
visitInstructionId(instr.id, context, state);
for (const place of eachInstructionOperand(instr)) {
visitPlace(instr.id, place, state);
visitPlace(
instr.id,
place,
state,
instr.value.kind === 'FunctionExpression' ||
instr.value.kind === 'ObjectMethod',
);
}
for (const place of eachInstructionLValue(instr)) {
visitPlace(instr.id, place, state);
visitPlace(instr.id, place, state, false);
}
}
visitInstructionId(block.terminal.id, context, state);
for (const place of eachTerminalOperand(block.terminal)) {
visitPlace(block.terminal.id, place, state);
visitPlace(block.terminal.id, place, state, false);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -538,11 +538,8 @@ export function printInstructionValue(instrValue: ReactiveValue): string {
.split('\n')
.map(line => ` ${line}`)
.join('\n');
const deps = instrValue.loweredFunc.dependencies
.map(dep => printPlace(dep))
.join(',');
const context = instrValue.loweredFunc.func.context
.map(dep => printPlace(dep))
.map(dep => `${printPlace(dep)}`)
.join(',');
const effects =
instrValue.loweredFunc.func.effects
Expand All @@ -557,7 +554,7 @@ export function printInstructionValue(instrValue: ReactiveValue): string {
})
.join(', ') ?? '';
const type = printType(instrValue.loweredFunc.func.returnType).trim();
value = `${kind} ${name} @deps[${deps}] @context[${context}] @effects[${effects}]${type !== '' ? ` return${type}` : ''}:\n${fn}`;
value = `${kind} ${name} @context[${context}] @effects[${effects}]${type !== '' ? ` return${type}` : ''}:\n${fn}`;
break;
}
case 'TaggedTemplateExpression': {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -738,9 +738,8 @@ function collectDependencies(
}
for (const instr of block.instructions) {
if (
fn.env.config.enableFunctionDependencyRewrite &&
(instr.value.kind === 'FunctionExpression' ||
instr.value.kind === 'ObjectMethod')
instr.value.kind === 'FunctionExpression' ||
instr.value.kind === 'ObjectMethod'
) {
context.declare(instr.lvalue.identifier, {
id: instr.id,
Expand Down
Loading

0 comments on commit 3819d3a

Please sign in to comment.