Skip to content

Commit 6adc4cf

Browse files
authored
Changed inference logic for exception groups to more closely match the runtime. If a non-base exception is targeted, the inferred type is now ExceptionGroup rather than BaseExceptionGroup. This addresses #9466. (#9467)
1 parent 7b8ce24 commit 6adc4cf

File tree

3 files changed

+28
-4
lines changed

3 files changed

+28
-4
lines changed

packages/pyright-internal/src/analyzer/typeEvaluator.ts

+10-2
Original file line numberDiff line numberDiff line change
@@ -19446,6 +19446,7 @@ export function createTypeEvaluator(
1944619446

1944719447
const exceptionTypeResult = getTypeOfExpression(node.d.typeExpr!);
1944819448
const exceptionTypes = exceptionTypeResult.type;
19449+
let includesBaseException = false;
1944919450

1945019451
function getExceptionType(exceptionType: Type, errorNode: ExpressionNode) {
1945119452
exceptionType = makeTopLevelTypeVarsConcrete(exceptionType);
@@ -19455,6 +19456,9 @@ export function createTypeEvaluator(
1945519456
}
1945619457

1945719458
if (isInstantiableClass(exceptionType)) {
19459+
if (ClassType.isBuiltIn(exceptionType, 'BaseException')) {
19460+
includesBaseException = true;
19461+
}
1945819462
return ClassType.cloneAsInstance(exceptionType);
1945919463
}
1946019464

@@ -19492,9 +19496,13 @@ export function createTypeEvaluator(
1949219496
return getExceptionType(subType, node.d.typeExpr!);
1949319497
});
1949419498

19495-
// If this is an except group, wrap the exception type in an BaseExceptionGroup.
19499+
// If this is an except group, wrap the exception type in an ExceptionGroup
19500+
// or BaseExceptionGroup depending on whether the target exception is
19501+
// a BaseException.
1949619502
if (node.d.isExceptGroup) {
19497-
targetType = getBuiltInObject(node, 'BaseExceptionGroup', [targetType]);
19503+
targetType = getBuiltInObject(node, includesBaseException ? 'BaseExceptionGroup' : 'ExceptionGroup', [
19504+
targetType,
19505+
]);
1949819506
}
1949919507

1950019508
if (node.d.name) {

packages/pyright-internal/src/tests/samples/exceptionGroup1.py

+17-1
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ def func1():
99

1010
# This should generate an error if using Python 3.10 or earlier.
1111
except* ValueError as e:
12-
reveal_type(e, expected_text="BaseExceptionGroup[ValueError]")
12+
reveal_type(e, expected_text="ExceptionGroup[ValueError]")
1313
pass
1414

1515
# This should generate an error if using Python 3.10 or earlier.
@@ -105,3 +105,19 @@ def inner():
105105
# return is not allowed in an except* block.
106106
return
107107

108+
109+
110+
def func8():
111+
112+
try:
113+
pass
114+
115+
# This should generate an error if using Python 3.10 or earlier.
116+
except* (ValueError, FloatingPointError) as e:
117+
reveal_type(e, expected_text="ExceptionGroup[ValueError | FloatingPointError]")
118+
pass
119+
120+
# This should generate an error if using Python 3.10 or earlier.
121+
except* BaseException as e:
122+
reveal_type(e, expected_text="BaseExceptionGroup[BaseException]")
123+
pass

packages/pyright-internal/src/tests/typeEvaluator7.test.ts

+1-1
Original file line numberDiff line numberDiff line change
@@ -971,7 +971,7 @@ test('exceptionGroup1', () => {
971971

972972
configOptions.defaultPythonVersion = pythonVersion3_10;
973973
const analysisResults1 = TestUtils.typeAnalyzeSampleFiles(['exceptionGroup1.py'], configOptions);
974-
TestUtils.validateResults(analysisResults1, 28);
974+
TestUtils.validateResults(analysisResults1, 34);
975975

976976
configOptions.defaultPythonVersion = pythonVersion3_11;
977977
const analysisResults2 = TestUtils.typeAnalyzeSampleFiles(['exceptionGroup1.py'], configOptions);

0 commit comments

Comments
 (0)