Skip to content

Commit fdc1b08

Browse files
phillooo and moz-wptsync-bot
authored and committed
Bug 1949929 [wpt PR 50891] - webnn: implement quantizeLinear on coreml, a=testonly
Automatic update from web-platform-tests webnn: implement quantizeLinear on coreml Implement `quantizeLinear` using CoreML `quantize` op. Add a cast_to_supported_type flag to `webnn_conformance_test` that skips the early return when input/output doesn't match opSupportLimits. `buildAndExecuteGraphFunc` already has logic that handles implicit casting. Change-Id: I370ffccfa51b3156216c4504622bfb81543ea853 Bug: 385173305 Reviewed-on: https://chromium-review.googlesource.com/c/chromium/src/+/6289944 Reviewed-by: Reilly Grant <[email protected]> Commit-Queue: Phillis Tang <[email protected]> Cr-Commit-Position: refs/heads/main@{#1423539} -- wpt-commits: f33579b93fb82e61f0c5302f05caa94bc0a274d7 wpt-pr: 50891
1 parent 383a786 commit fdc1b08

File tree

3 files changed

+76
-16
lines changed

3 files changed

+76
-16
lines changed

testing/web-platform/tests/webnn/conformance_tests/dequantizeLinear.https.any.js

+2-1
Original file line numberDiff line numberDiff line change
@@ -624,7 +624,8 @@ const dequantizeLinearTests = [
624624
if (navigator.ml) {
625625
dequantizeLinearTests.forEach((test) => {
626626
webnn_conformance_test(
627-
buildAndExecuteGraph, getDequantizeLinearPrecisionTolerance, test);
627+
buildAndExecuteGraph, getDequantizeLinearPrecisionTolerance, test,
628+
/*cast_to_supported_type=*/ true);
628629
});
629630
} else {
630631
test(() => assert_implements(navigator.ml, 'missing navigator.ml'));

testing/web-platform/tests/webnn/conformance_tests/quantizeLinear.https.any.js

+52-6
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,51 @@ const quantizeLinearTests = [
6161
}
6262
},
6363
{
64-
'name': 'quantizeLinear float32 1D constant tensor broadcasting zeroPoint',
64+
'name': 'quantizeLinear float32 1D constant tensor',
65+
'graph': {
66+
'inputs': {
67+
'quantizeLinearInput': {
68+
'data': [
69+
-2.549168109893799, -4.794857501983643, 8.413617134094238,
70+
6.108623504638672
71+
],
72+
'descriptor': {shape: [4], dataType: 'float32'},
73+
'constant': true
74+
},
75+
'quantizeLinearScale': {
76+
'data': [
77+
9.343092918395996,
78+
0.2800687253475189,
79+
4.617084980010986,
80+
1.1202747821807861,
81+
],
82+
'descriptor': {shape: [4], dataType: 'float32'},
83+
'constant': true
84+
},
85+
'quantizeLinearZeroPoint': {
86+
'data': [128, 128, 128, 128],
87+
'descriptor': {shape: [4], dataType: 'uint8'},
88+
'constant': true
89+
}
90+
},
91+
'operators': [{
92+
'name': 'quantizeLinear',
93+
'arguments': [
94+
{'input': 'quantizeLinearInput'}, {'scale': 'quantizeLinearScale'},
95+
{'zeroPoint': 'quantizeLinearZeroPoint'}
96+
],
97+
'outputs': 'quantizeLinearOutput'
98+
}],
99+
'expectedOutputs': {
100+
'quantizeLinearOutput': {
101+
'data': [128, 111, 130, 133],
102+
'descriptor': {shape: [4], dataType: 'uint8'}
103+
}
104+
}
105+
}
106+
},
107+
{
108+
'name': 'quantizeLinear float32 1D constant tensor with negative scale',
65109
'graph': {
66110
'inputs': {
67111
'quantizeLinearInput': {
@@ -185,8 +229,7 @@ const quantizeLinearTests = [
185229
}
186230
},
187231
{
188-
'name':
189-
'per-tensor quantizeLinear for float32 4D constant',
232+
'name': 'per-tensor quantizeLinear for float32 4D constant',
190233
'graph': {
191234
'inputs': {
192235
'quantizeLinearInput': {
@@ -198,8 +241,10 @@ const quantizeLinearTests = [
198241
'constant': true
199242
},
200243
'quantizeLinearScale': {
201-
'data': [0.2800687253475189, -4.617084980010986, 0.2800687253475189,
202-
-4.617084980010986],
244+
'data': [
245+
0.2800687253475189, -4.617084980010986, 0.2800687253475189,
246+
-4.617084980010986
247+
],
203248
'descriptor': {shape: [2, 2], dataType: 'float32'},
204249
'constant': true
205250
},
@@ -535,7 +580,8 @@ const quantizeLinearTests = [
535580
if (navigator.ml) {
536581
quantizeLinearTests.forEach((test) => {
537582
webnn_conformance_test(
538-
buildAndExecuteGraph, getQuantizeLinearPrecisionTolerance, test);
583+
buildAndExecuteGraph, getQuantizeLinearPrecisionTolerance, test,
584+
/*cast_to_supported_type=*/ true);
539585
});
540586
} else {
541587
test(() => assert_implements(navigator.ml, 'missing navigator.ml'));

testing/web-platform/tests/webnn/resources/utils.js

+22-9
Original file line numberDiff line numberDiff line change
@@ -93,12 +93,17 @@ const kIntTypes =
9393
['uint4', 'int4', 'uint8', 'int8', 'uint32', 'int32', 'uint64', 'int64'];
9494
const kFloatTypes = ['float16', 'float32'];
9595

96-
const findCompatibleType = (dataType, supportedTypes) => {
96+
const findCompatibleType = (dataType, supportedTypes, castOpSupportLimits) => {
97+
if (!castOpSupportLimits.input.dataTypes.includes(dataType)) {
98+
// Cannot cast from `dataType` to any other type.
99+
return null;
100+
}
101+
97102
for (let supportedType of supportedTypes) {
98-
if (kIntTypes.includes(dataType)) {
99-
if (kIntTypes.indexOf(supportedType) > kIntTypes.indexOf(dataType)) {
100-
return supportedType;
101-
}
103+
if (kIntTypes.includes(dataType) &&
104+
castOpSupportLimits.output.dataTypes.includes(dataType) &&
105+
kIntTypes.indexOf(supportedType) > kIntTypes.indexOf(dataType)) {
106+
return supportedType;
102107
}
103108

104109
if (kFloatTypes.includes(dataType)) {
@@ -483,7 +488,8 @@ const createOperand = (context, builder, operandName, resources) => {
483488
// If input data type is not supported on current platform, attempt to use
484489
// a supported type to pass the data, then cast back to original type.
485490
if (!supportedDataTypes.includes(dataType)) {
486-
const compatibleType = findCompatibleType(dataType, supportedDataTypes);
491+
const compatibleType = findCompatibleType(
492+
dataType, supportedDataTypes, context.opSupportLimits().cast);
487493
if (compatibleType) {
488494
descriptor.castedType = compatibleType;
489495
descriptor.dataType = compatibleType;
@@ -820,7 +826,8 @@ const buildAndExecuteGraph = async (context, builder, graphResources) => {
820826
expectedDescriptor.dataType)) {
821827
const compatibleType = findCompatibleType(
822828
expectedDescriptor.dataType,
823-
context.opSupportLimits().output.dataTypes);
829+
context.opSupportLimits().output.dataTypes,
830+
context.opSupportLimits().cast);
824831
outputOperands[i] = builder.cast(outputOperands[i], compatibleType);
825832
expectedDescriptor.castedType = compatibleType;
826833
}
@@ -989,8 +996,12 @@ const getReducedElementCount =
989996
1;
990997
};
991998

999+
// `cast_to_supported_type` will check if the graph input/output is
1000+
// supported by current context, if not, it will try to find a compatible
1001+
// type that's supported and use that type, then cast back to original type.
9921002
const webnn_conformance_test =
993-
(buildAndExecuteGraphFunc, toleranceFunc, testResources) => {
1003+
(buildAndExecuteGraphFunc, toleranceFunc, testResources,
1004+
cast_to_supported_type = false) => {
9941005
promise_test(async () => {
9951006
let context;
9961007
try {
@@ -999,7 +1010,9 @@ const webnn_conformance_test =
9991010
throw new AssertionError(
10001011
`Unable to create context for ${variant} variant. ${e}`);
10011012
}
1002-
validateContextSupportsGraph(context, testResources.graph);
1013+
if (!cast_to_supported_type) {
1014+
validateContextSupportsGraph(context, testResources.graph);
1015+
}
10031016
const builder = new MLGraphBuilder(context);
10041017
const {result, intermediateOperands} = await buildAndExecuteGraphFunc(
10051018
context, builder, testResources.graph);

0 commit comments

Comments
 (0)