/**
 * @license
 * Copyright 2021 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */
import { backend_util, Einsum, util } from '@tensorflow/tfjs-core';
import { multiply } from './Multiply';
import { reshape } from './Reshape';
import { sum } from './Sum';
import { transpose } from './Transpose';
export function einsum(args) {
    const { inputs, backend, attrs } = args;
    const { equation } = attrs;
    const tensors = inputs;
    const { allDims, summedDims, idDims } = backend_util.decodeEinsumEquation(equation, tensors.length);
    backend_util.checkEinsumDimSizes(allDims.length, idDims, tensors);
    const { path, steps } = backend_util.getEinsumComputePath(summedDims, idDims);
    const nSteps = steps.length;
    let out = null;
    let numDimsRemaining = allDims.length;
    const tensorsToDispose = [];
    for (let i = 0; i < nSteps; ++i) {
        for (const idTerm of steps[i]) {
            const { permutationIndices: perm, expandDims: dimsToExpand } = backend_util.getEinsumPermutation(numDimsRemaining, idDims[idTerm]);
            let x;
            if (backend_util.isIdentityPermutation(perm)) {
                x = tensors[idTerm];
            }
            else {
                x = transpose({ inputs: { x: tensors[idTerm] }, backend, attrs: { perm } });
                tensorsToDispose.push(x);
            }
            const targetShape = x.shape.slice();
            for (let k = 0; k < dimsToExpand.length; ++k) {
                targetShape.splice(dimsToExpand[k], 0, 1);
            }
            if (!util.arraysEqual(x.shape, targetShape)) {
                x = reshape({ inputs: { x }, backend, attrs: { shape: targetShape } });
                tensorsToDispose.push(x);
            }
            if (out === null) {
                out = x;
            }
            else {
                // tslint:disable-next-line: no-unnecessary-type-assertion
                out = multiply({ inputs: { a: x, b: out }, backend });
                tensorsToDispose.push(out);
            }
        }
        if (i < nSteps - 1) {
            if (path[i] >= 0) {
                out = sum({
                    inputs: { x: out },
                    backend,
                    attrs: {
                        axis: path[i] - (allDims.length - numDimsRemaining),
                        keepDims: false
                    }
                });
                tensorsToDispose.push(out);
            }
            numDimsRemaining--;
        }
    }
    // Clean up intermediate tensors.
    for (const tensorInfo of tensorsToDispose) {
        if (tensorInfo === out) {
            continue;
        }
        backend.disposeIntermediateTensorInfo(tensorInfo);
    }
    return out;
}
export const einsumConfig = {
    kernelName: Einsum,
    backendName: 'webgl',
    kernelFunc: einsum
};
//# sourceMappingURL=data:application/json;base64,eyJ2ZXJzaW9uIjozLCJmaWxlIjoiRWluc3VtLmpzIiwic291cmNlUm9vdCI6IiIsInNvdXJjZXMiOlsiLi4vLi4vLi4vLi4vLi4vLi4vdGZqcy1iYWNrZW5kLXdlYmdsL3NyYy9rZXJuZWxzL0VpbnN1bS50cyJdLCJuYW1lcyI6W10sIm1hcHBpbmdzIjoiQUFBQTs7Ozs7Ozs7Ozs7Ozs7O0dBZUc7QUFFSCxPQUFPLEVBQUMsWUFBWSxFQUFFLE1BQU0sRUFBMkUsSUFBSSxFQUFDLE1BQU0sdUJBQXVCLENBQUM7QUFJMUksT0FBTyxFQUFDLFFBQVEsRUFBQyxNQUFNLFlBQVksQ0FBQztBQUNwQyxPQUFPLEVBQUMsT0FBTyxFQUFDLE1BQU0sV0FBVyxDQUFDO0FBQ2xDLE9BQU8sRUFBQyxHQUFHLEVBQUMsTUFBTSxPQUFPLENBQUM7QUFDMUIsT0FBTyxFQUFDLFNBQVMsRUFBQyxNQUFNLGFBQWEsQ0FBQztBQUV0QyxNQUFNLFVBQVUsTUFBTSxDQUNsQixJQUN5RTtJQUUzRSxNQUFNLEVBQUMsTUFBTSxFQUFFLE9BQU8sRUFBRSxLQUFLLEVBQUMsR0FBRyxJQUFJLENBQUM7SUFDdEMsTUFBTSxFQUFDLFFBQVEsRUFBQyxHQUFHLEtBQUssQ0FBQztJQUN6QixNQUFNLE9BQU8sR0FBRyxNQUFrQixDQUFDO0lBRW5DLE1BQU0sRUFBQyxPQUFPLEVBQUUsVUFBVSxFQUFFLE1BQU0sRUFBQyxHQUMvQixZQUFZLENBQUMsb0JBQW9CLENBQUMsUUFBUSxFQUFFLE9BQU8sQ0FBQyxNQUFNLENBQUMsQ0FBQztJQUNoRSxZQUFZLENBQUMsbUJBQW1CLENBQUMsT0FBTyxDQUFDLE1BQU0sRUFBRSxNQUFNLEVBQUUsT0FBTyxDQUFDLENBQUM7SUFDbEUsTUFBTSxFQUFDLElBQUksRUFBRSxLQUFLLEVBQUMsR0FBRyxZQUFZLENBQUMsb0JBQW9CLENBQUMsVUFBVSxFQUFFLE1BQU0sQ0FBQyxDQUFDO0lBRTVFLE1BQU0sTUFBTSxHQUFHLEtBQUssQ0FBQyxNQUFNLENBQUM7SUFDNUIsSUFBSSxHQUFHLEdBQW9CLElBQUksQ0FBQztJQUNoQyxJQUFJLGdCQUFnQixHQUFHLE9BQU8sQ0FBQyxNQUFNLENBQUM7SUFDdEMsTUFBTSxnQkFBZ0IsR0FBaUIsRUFBRSxDQUFDO0lBQzFDLEtBQUssSUFBSSxDQUFDLEdBQUcsQ0FBQyxFQUFFLENBQUMsR0FBRyxNQUFNLEVBQUUsRUFBRSxDQUFDLEVBQUU7UUFDL0IsS0FBSyxNQUFNLE1BQU0sSUFBSSxLQUFLLENBQUMsQ0FBQyxDQUFDLEVBQUU7WUFDN0IsTUFBTSxFQUFDLGtCQUFrQixFQUFFLElBQUksRUFBRSxVQUFVLEVBQUUsWUFBWSxFQUFDLEdBQ3RELFlBQVksQ0FBQyxvQkFBb0IsQ0FBQyxnQkFBZ0IsRUFBRSxNQUFNLENBQUMsTUFBTSxDQUFDLENBQUMsQ0FBQztZQUN4RSxJQUFJLENBQWEsQ0FBQztZQUNsQixJQUFJLFlBQVksQ0FBQyxxQkFBcUIsQ0FBQyxJQUFJLENBQUMsRUFBRTtnQkFDNUMsQ0FBQyxHQUFHLE9BQU8sQ0FBQyxNQUFNLENBQUMsQ0FBQzthQUNyQjtpQkFBTTtnQkFDTCxDQUFDLEdBQUcsU0FBUyxDQUFDLEVBQUMsTUFBTSxFQUFFLEVBQUMsQ0FBQyxFQUFFLE9BQU8sQ0FBQyxNQUFNLENBQUMsRUFBQyxFQUFFLE9BQU8sRUFBRSxLQUFLLEVBQUUsRUFBQyxJQUFJLEVBQUMsRUFBQyxDQUFDLENBQUM7Z0JBQ3RFLGdCQUFnQixDQUFDLElBQUksQ0FBQyxDQUFDLENBQUMsQ0FBQzthQUMxQjtZQUNELE1BQU0sV0FBVyxHQUFhLENBQUMsQ0FBQyxLQUFLLENBQUMsS0FBSyxFQUFFLENBQUM7WUFDOUMsS0FBSyxJQUFJLENBQUMsR0FBRyxDQUFDLEVBQUUsQ0FBQyxHQUFHLFlBQVksQ0FBQyxNQUFNLEVBQUUsRUFBRSxDQUFDLEVBQUU7Z0JBQzVDLFdBQVcsQ0FBQyxNQUFNLENBQUMsWUFBWSxDQUFDLENBQUMsQ0FBQyxFQUFFLENBQUMsRUFBRSxDQUFDLENBQUMsQ0FBQzthQUMzQztZQUVELElBQUksQ0FBQyxJQUFJLENBQUMsV0FBVyxDQUFDLENBQUMsQ0FBQyxLQUFLLEVBQUUsV0FBVyxDQUFDLEVBQUU7Z0JBQzNDLENBQUMsR0FBRyxPQUFPLENBQUMsRUFBQyxNQUFNLEVBQUUsRUFBQyxDQUFDLEVBQUMsRUFBRSxPQUFPLEVBQUUsS0FBSyxFQUFFLEVBQUMsS0FBSyxFQUFFLFdBQVcsRUFBQyxFQUFDLENBQUMsQ0FBQztnQkFDakUsZ0JBQWdCLENBQUMsSUFBSSxDQUFDLENBQUMsQ0FBQyxDQUFDO2FBQzFCO1lBQ0QsSUFBSSxHQUFHLEtBQUssSUFBSSxFQUFFO2dCQUNoQixHQUFHLEdBQUcsQ0FBQyxDQUFDO2FBQ1Q7aUJBQU07Z0JBQ0wsMERBQTBEO2dCQUMxRCxHQUFHLEdBQUcsUUFBUSxDQUFDLEVBQUMsTUFBTSxFQUFFLEVBQUMsQ0FBQyxFQUFFLENBQUMsRUFBRSxDQUFDLEVBQUUsR0FBRyxFQUFDLEVBQUUsT0FBTyxFQUFDLENBQWUsQ0FBQztnQkFDaEUsZ0JBQWdCLENBQUMsSUFBSSxDQUFDLEdBQUcsQ0FBQyxDQUFDO2FBQzVCO1NBQ0Y7UUFDRCxJQUFJLENBQUMsR0FBRyxNQUFNLEdBQUcsQ0FBQyxFQUFFO1lBQ2xCLElBQUksSUFBSSxDQUFDLENBQUMsQ0FBQyxJQUFJLENBQUMsRUFBRTtnQkFDaEIsR0FBRyxHQUFHLEdBQUcsQ0FBQztvQkFDUixNQUFNLEVBQUUsRUFBQyxDQUFDLEVBQUUsR0FBRyxFQUFDO29CQUNoQixPQUFPO29CQUNQLEtBQUssRUFBRTt3QkFDTCxJQUFJLEVBQUUsSUFBSSxDQUFDLENBQUMsQ0FBQyxHQUFHLENBQUMsT0FBTyxDQUFDLE1BQU0sR0FBRyxnQkFBZ0IsQ0FBQzt3QkFDbkQsUUFBUSxFQUFFLEtBQUs7cUJBQ2hCO2lCQUNGLENBQUMsQ0FBQztnQkFDSCxnQkFBZ0IsQ0FBQyxJQUFJLENBQUMsR0FBRyxDQUFDLENBQUM7YUFDNUI7WUFDRCxnQkFBZ0IsRUFBRSxDQUFDO1NBQ3BCO0tBQ0Y7SUFFRCxpQ0FBaUM7SUFDakMsS0FBSyxNQUFNLFVBQVUsSUFBSSxnQkFBZ0IsRUFBRTtRQUN6QyxJQUFJLFVBQVUsS0FBSyxHQUFHLEVBQUU7WUFDdEIsU0FBUztTQUNWO1FBQ0QsT0FBTyxDQUFDLDZCQUE2QixDQUFDLFVBQVUsQ0FBQyxDQUFDO0tBQ25EO0lBRUQsT0FBTyxHQUFHLENBQUM7QUFDYixDQUFDO0FBRUQsTUFBTSxDQUFDLE1BQU0sWUFBWSxHQUFpQjtJQUN4QyxVQUFVLEVBQUUsTUFBTTtJQUNsQixXQUFXLEVBQUUsT0FBTztJQUNwQixVQUFVLEVBQUUsTUFBK0I7Q0FDNUMsQ0FBQyIsInNvdXJjZXNDb250ZW50IjpbIi8qKlxuICogQGxpY2Vuc2VcbiAqIENvcHlyaWdodCAyMDIxIEdvb2dsZSBMTEMuIEFsbCBSaWdodHMgUmVzZXJ2ZWQuXG4gKiBMaWNlbnNlZCB1bmRlciB0aGUgQXBhY2hlIExpY2Vuc2UsIFZlcnNpb24gMi4wICh0aGUgXCJMaWNlbnNlXCIpO1xuICogeW91IG1heSBub3QgdXNlIHRoaXMgZmlsZSBleGNlcHQgaW4gY29tcGxpYW5jZSB3aXRoIHRoZSBMaWNlbnNlLlxuICogWW91IG1heSBvYnRhaW4gYSBjb3B5IG9mIHRoZSBMaWNlbnNlIGF0XG4gKlxuICogaHR0cDovL3d3dy5hcGFjaGUub3JnL2xpY2Vuc2VzL0xJQ0VOU0UtMi4wXG4gKlxuICogVW5sZXNzIHJlcXVpcmVkIGJ5IGFwcGxpY2FibGUgbGF3IG9yIGFncmVlZCB0byBpbiB3cml0aW5nLCBzb2Z0d2FyZVxuICogZGlzdHJpYnV0ZWQgdW5kZXIgdGhlIExpY2Vuc2UgaXMgZGlzdHJpYnV0ZWQgb24gYW4gXCJBUyBJU1wiIEJBU0lTLFxuICogV0lUSE9VVCBXQVJSQU5USUVTIE9SIENPTkRJVElPTlMgT0YgQU5ZIEtJTkQsIGVpdGhlciBleHByZXNzIG9yIGltcGxpZWQuXG4gKiBTZWUgdGhlIExpY2Vuc2UgZm9yIHRoZSBzcGVjaWZpYyBsYW5ndWFnZSBnb3Zlcm5pbmcgcGVybWlzc2lvbnMgYW5kXG4gKiBsaW1pdGF0aW9ucyB1bmRlciB0aGUgTGljZW5zZS5cbiAqID09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09XG4gKi9cblxuaW1wb3J0IHtiYWNrZW5kX3V0aWwsIEVpbnN1bSwgRWluc3VtQXR0cnMsIEVpbnN1bUlucHV0cywgS2VybmVsQ29uZmlnLCBLZXJuZWxGdW5jLCBUZW5zb3IsIFRlbnNvckluZm8sIHV0aWx9IGZyb20gJ0B0ZW5zb3JmbG93L3RmanMtY29yZSc7XG5cbmltcG9ydCB7TWF0aEJhY2tlbmRXZWJHTH0gZnJvbSAnLi4vYmFja2VuZF93ZWJnbCc7XG5cbmltcG9ydCB7bXVsdGlwbHl9IGZyb20gJy4vTXVsdGlwbHknO1xuaW1wb3J0IHtyZXNoYXBlfSBmcm9tICcuL1Jlc2hhcGUnO1xuaW1wb3J0IHtzdW19IGZyb20gJy4vU3VtJztcbmltcG9ydCB7dHJhbnNwb3NlfSBmcm9tICcuL1RyYW5zcG9zZSc7XG5cbmV4cG9ydCBmdW5jdGlvbiBlaW5zdW0oXG4gICAgYXJnczpcbiAgICAgICAge2lucHV0czogRWluc3VtSW5wdXRzLCBiYWNrZW5kOiBNYXRoQmFja2VuZFdlYkdMLCBhdHRyczogRWluc3VtQXR0cnN9KTpcbiAgICBUZW5zb3JJbmZvIHtcbiAgY29uc3Qge2lucHV0cywgYmFja2VuZCwgYXR0cnN9ID0gYXJncztcbiAgY29uc3Qge2VxdWF0aW9ufSA9IGF0dHJzO1xuICBjb25zdCB0ZW5zb3JzID0gaW5wdXRzIGFzIFRlbnNvcltdO1xuXG4gIGNvbnN0IHthbGxEaW1zLCBzdW1tZWREaW1zLCBpZERpbXN9ID1cbiAgICAgIGJhY2tlbmRfdXRpbC5kZWNvZGVFaW5zdW1FcXVhdGlvbihlcXVhdGlvbiwgdGVuc29ycy5sZW5ndGgpO1xuICBiYWNrZW5kX3V0aWwuY2hlY2tFaW5zdW1EaW1TaXplcyhhbGxEaW1zLmxlbmd0aCwgaWREaW1zLCB0ZW5zb3JzKTtcbiAgY29uc3Qge3BhdGgsIHN0ZXBzfSA9IGJhY2tlbmRfdXRpbC5nZXRFaW5zdW1Db21wdXRlUGF0aChzdW1tZWREaW1zLCBpZERpbXMpO1xuXG4gIGNvbnN0IG5TdGVwcyA9IHN0ZXBzLmxlbmd0aDtcbiAgbGV0IG91dDogVGVuc29ySW5mb3xudWxsID0gbnVsbDtcbiAgbGV0IG51bURpbXNSZW1haW5pbmcgPSBhbGxEaW1zLmxlbmd0aDtcbiAgY29uc3QgdGVuc29yc1RvRGlzcG9zZTogVGVuc29ySW5mb1tdID0gW107XG4gIGZvciAobGV0IGkgPSAwOyBpIDwgblN0ZXBzOyArK2kpIHtcbiAgICBmb3IgKGNvbnN0IGlkVGVybSBvZiBzdGVwc1tpXSkge1xuICAgICAgY29uc3Qge3Blcm11dGF0aW9uSW5kaWNlczogcGVybSwgZXhwYW5kRGltczogZGltc1RvRXhwYW5kfSA9XG4gICAgICAgICAgYmFja2VuZF91dGlsLmdldEVpbnN1bVBlcm11dGF0aW9uKG51bURpbXNSZW1haW5pbmcsIGlkRGltc1tpZFRlcm1dKTtcbiAgICAgIGxldCB4OiBUZW5zb3JJbmZvO1xuICAgICAgaWYgKGJhY2tlbmRfdXRpbC5pc0lkZW50aXR5UGVybXV0YXRpb24ocGVybSkpIHtcbiAgICAgICAgeCA9IHRlbnNvcnNbaWRUZXJtXTtcbiAgICAgIH0gZWxzZSB7XG4gICAgICAgIHggPSB0cmFuc3Bvc2Uoe2lucHV0czoge3g6IHRlbnNvcnNbaWRUZXJtXX0sIGJhY2tlbmQsIGF0dHJzOiB7cGVybX19KTtcbiAgICAgICAgdGVuc29yc1RvRGlzcG9zZS5wdXNoKHgpO1xuICAgICAgfVxuICAgICAgY29uc3QgdGFyZ2V0U2hhcGU6IG51bWJlcltdID0geC5zaGFwZS5zbGljZSgpO1xuICAgICAgZm9yIChsZXQgayA9IDA7IGsgPCBkaW1zVG9FeHBhbmQubGVuZ3RoOyArK2spIHtcbiAgICAgICAgdGFyZ2V0U2hhcGUuc3BsaWNlKGRpbXNUb0V4cGFuZFtrXSwgMCwgMSk7XG4gICAgICB9XG5cbiAgICAgIGlmICghdXRpbC5hcnJheXNFcXVhbCh4LnNoYXBlLCB0YXJnZXRTaGFwZSkpIHtcbiAgICAgICAgeCA9IHJlc2hhcGUoe2lucHV0czoge3h9LCBiYWNrZW5kLCBhdHRyczoge3NoYXBlOiB0YXJnZXRTaGFwZX19KTtcbiAgICAgICAgdGVuc29yc1RvRGlzcG9zZS5wdXNoKHgpO1xuICAgICAgfVxuICAgICAgaWYgKG91dCA9PT0gbnVsbCkge1xuICAgICAgICBvdXQgPSB4O1xuICAgICAgfSBlbHNlIHtcbiAgICAgICAgLy8gdHNsaW50OmRpc2FibGUtbmV4dC1saW5lOiBuby11bm5lY2Vzc2FyeS10eXBlLWFzc2VydGlvblxuICAgICAgICBvdXQgPSBtdWx0aXBseSh7aW5wdXRzOiB7YTogeCwgYjogb3V0fSwgYmFja2VuZH0pIGFzIFRlbnNvckluZm87XG4gICAgICAgIHRlbnNvcnNUb0Rpc3Bvc2UucHVzaChvdXQpO1xuICAgICAgfVxuICAgIH1cbiAgICBpZiAoaSA8IG5TdGVwcyAtIDEpIHtcbiAgICAgIGlmIChwYXRoW2ldID49IDApIHtcbiAgICAgICAgb3V0ID0gc3VtKHtcbiAgICAgICAgICBpbnB1dHM6IHt4OiBvdXR9LFxuICAgICAgICAgIGJhY2tlbmQsXG4gICAgICAgICAgYXR0cnM6IHtcbiAgICAgICAgICAgIGF4aXM6IHBhdGhbaV0gLSAoYWxsRGltcy5sZW5ndGggLSBudW1EaW1zUmVtYWluaW5nKSxcbiAgICAgICAgICAgIGtlZXBEaW1zOiBmYWxzZVxuICAgICAgICAgIH1cbiAgICAgICAgfSk7XG4gICAgICAgIHRlbnNvcnNUb0Rpc3Bvc2UucHVzaChvdXQpO1xuICAgICAgfVxuICAgICAgbnVtRGltc1JlbWFpbmluZy0tO1xuICAgIH1cbiAgfVxuXG4gIC8vIENsZWFuIHVwIGludGVybWVkaWF0ZSB0ZW5zb3JzLlxuICBmb3IgKGNvbnN0IHRlbnNvckluZm8gb2YgdGVuc29yc1RvRGlzcG9zZSkge1xuICAgIGlmICh0ZW5zb3JJbmZvID09PSBvdXQpIHtcbiAgICAgIGNvbnRpbnVlO1xuICAgIH1cbiAgICBiYWNrZW5kLmRpc3Bvc2VJbnRlcm1lZGlhdGVUZW5zb3JJbmZvKHRlbnNvckluZm8pO1xuICB9XG5cbiAgcmV0dXJuIG91dDtcbn1cblxuZXhwb3J0IGNvbnN0IGVpbnN1bUNvbmZpZzogS2VybmVsQ29uZmlnID0ge1xuICBrZXJuZWxOYW1lOiBFaW5zdW0sXG4gIGJhY2tlbmROYW1lOiAnd2ViZ2wnLFxuICBrZXJuZWxGdW5jOiBlaW5zdW0gYXMgdW5rbm93biBhcyBLZXJuZWxGdW5jXG59O1xuIl19