/*
 * Decompiled with CFR 0.152.
 */
package org.elasticsearch.xpack.esql.optimizer.rules.physical.local;

import java.nio.ByteOrder;
import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import org.apache.lucene.util.BytesRef;
import org.elasticsearch.geometry.Circle;
import org.elasticsearch.geometry.Geometry;
import org.elasticsearch.geometry.Point;
import org.elasticsearch.geometry.utils.WellKnownBinary;
import org.elasticsearch.xpack.esql.core.expression.Alias;
import org.elasticsearch.xpack.esql.core.expression.Attribute;
import org.elasticsearch.xpack.esql.core.expression.AttributeMap;
import org.elasticsearch.xpack.esql.core.expression.Expression;
import org.elasticsearch.xpack.esql.core.expression.FoldContext;
import org.elasticsearch.xpack.esql.core.expression.Literal;
import org.elasticsearch.xpack.esql.core.expression.NameId;
import org.elasticsearch.xpack.esql.core.expression.ReferenceAttribute;
import org.elasticsearch.xpack.esql.core.tree.Source;
import org.elasticsearch.xpack.esql.core.type.DataType;
import org.elasticsearch.xpack.esql.expression.function.scalar.spatial.SpatialDisjoint;
import org.elasticsearch.xpack.esql.expression.function.scalar.spatial.SpatialIntersects;
import org.elasticsearch.xpack.esql.expression.function.scalar.spatial.SpatialRelatesUtils;
import org.elasticsearch.xpack.esql.expression.function.scalar.spatial.StDistance;
import org.elasticsearch.xpack.esql.expression.predicate.Predicates;
import org.elasticsearch.xpack.esql.expression.predicate.logical.And;
import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.EsqlBinaryComparison;
import org.elasticsearch.xpack.esql.optimizer.LocalPhysicalOptimizerContext;
import org.elasticsearch.xpack.esql.optimizer.PhysicalOptimizerRules;
import org.elasticsearch.xpack.esql.optimizer.rules.physical.local.LucenePushdownPredicates;
import org.elasticsearch.xpack.esql.optimizer.rules.physical.local.PushFiltersToSource;
import org.elasticsearch.xpack.esql.plan.physical.EsQueryExec;
import org.elasticsearch.xpack.esql.plan.physical.EvalExec;
import org.elasticsearch.xpack.esql.plan.physical.FilterExec;
import org.elasticsearch.xpack.esql.plan.physical.PhysicalPlan;

/**
 * Local physical-plan rule that turns filters on {@code ST_DISTANCE} into spatial relation predicates
 * ({@code ST_INTERSECTS} / {@code ST_DISJOINT} against a circle literal), so that they can be pushed down to
 * the {@link EsQueryExec} source.
 * <p>
 * Two plan shapes are handled:
 * <ul>
 *   <li>{@code FilterExec -> EsQueryExec}: the whole filter condition is rewritten in place. The rewrite is
 *       kept only if the result can be pushed to source.</li>
 *   <li>{@code FilterExec -> EvalExec -> EsQueryExec}: the filter is split on AND. Conjuncts comparing a
 *       distance computed in the EVAL against a foldable value are rewritten. The pushable ones are moved
 *       below the EVAL, and the remaining ones stay in a filter above it.</li>
 * </ul>
 * Comparisons that cannot be expressed as a single circle relation are left untouched and evaluated normally.
 * This includes operators other than {@code < <= > >= ==}, non-point literals, and bounds that would need a
 * negative or NaN radius.
 */
public class EnableSpatialDistancePushdown
    extends PhysicalOptimizerRules.ParameterizedOptimizerRule<FilterExec, LocalPhysicalOptimizerContext> {

    @Override
    protected PhysicalPlan rule(FilterExec filterExec, LocalPhysicalOptimizerContext ctx) {
        PhysicalPlan child = filterExec.child();
        if (child instanceof EsQueryExec esQueryExec) {
            return rewrite(ctx.foldCtx(), filterExec, esQueryExec, LucenePushdownPredicates.from(ctx.searchStats()));
        }
        if (child instanceof EvalExec evalExec && evalExec.child() instanceof EsQueryExec queryBelowEval) {
            return rewriteBySplittingFilter(
                ctx.foldCtx(),
                filterExec,
                evalExec,
                queryBelowEval,
                LucenePushdownPredicates.from(ctx.searchStats())
            );
        }
        return filterExec;
    }

    /**
     * Rewrites distance comparisons that appear directly in a filter sitting on top of the source.
     *
     * @return a new filter over {@code esQueryExec} with the rewritten condition, or {@code filterExec}
     *         itself if nothing was rewritten or the rewritten condition cannot be pushed to source
     */
    private FilterExec rewrite(
        FoldContext ctx,
        FilterExec filterExec,
        EsQueryExec esQueryExec,
        LucenePushdownPredicates lucenePushdownPredicates
    ) {
        Expression rewritten = filterExec.condition().transformDown(EsqlBinaryComparison.class, comparison -> {
            ComparisonType comparisonType = ComparisonType.fromSupported(comparison.getFunctionType());
            if (comparisonType == null) {
                // Operators such as '!=' have no single-circle equivalent; never treat them as '=='.
                return comparison;
            }
            if (comparison.left() instanceof StDistance dist && comparison.right().foldable()) {
                return rewriteComparison(ctx, comparison, dist, comparison.right(), comparisonType);
            }
            if (comparison.right() instanceof StDistance dist && comparison.left().foldable()) {
                // "literal OP distance": flip the operator so it reads "distance OP' literal".
                return rewriteComparison(ctx, comparison, dist, comparison.left(), ComparisonType.invert(comparisonType));
            }
            return comparison;
        });
        if (rewritten.equals(filterExec.condition()) == false
            && PushFiltersToSource.canPushToSource(rewritten, lucenePushdownPredicates)) {
            return new FilterExec(filterExec.source(), esQueryExec, rewritten);
        }
        return filterExec;
    }

    /**
     * Handles {@code FilterExec -> EvalExec -> EsQueryExec}, where the filter references distances computed
     * in the EVAL. Each AND-ed conjunct is rewritten independently. Pushable rewrites go into a new filter
     * directly above the source, below the EVAL. Everything else stays in a filter above the EVAL.
     *
     * @return the restructured plan, or {@code filterExec} if no conjunct could be pushed down
     */
    private PhysicalPlan rewriteBySplittingFilter(
        FoldContext ctx,
        FilterExec filterExec,
        EvalExec evalExec,
        EsQueryExec esQueryExec,
        LucenePushdownPredicates lucenePushdownPredicates
    ) {
        // Pushable distance expressions defined in the EVAL, keyed by the id of the alias naming them.
        Map<NameId, StDistance> distances = getPushableDistances(evalExec.fields(), lucenePushdownPredicates);
        if (distances.isEmpty()) {
            return filterExec;
        }
        AttributeMap<Attribute> aliasReplacedBy = PushFiltersToSource.getAliasReplacedBy(evalExec);
        List<Expression> pushable = new ArrayList<>();
        List<Expression> nonPushable = new ArrayList<>();
        for (Expression exp : Predicates.splitAnd(filterExec.condition())) {
            Expression resolved = exp.transformUp(ReferenceAttribute.class, r -> aliasReplacedBy.resolve(r, r));
            Expression rewritten = rewriteDistanceFilters(ctx, resolved, distances);
            if (rewritten.equals(resolved) == false && PushFiltersToSource.canPushToSource(rewritten, lucenePushdownPredicates)) {
                pushable.add(rewritten);
            } else {
                // Keep the original conjunct: it will still be evaluated above the EVAL.
                nonPushable.add(exp);
            }
        }
        if (pushable.isEmpty()) {
            return filterExec;
        }
        FilterExec distanceFilter = new FilterExec(filterExec.source(), esQueryExec, Predicates.combineAnd(pushable));
        EvalExec newEval = new EvalExec(evalExec.source(), distanceFilter, evalExec.fields());
        if (nonPushable.isEmpty()) {
            return newEval;
        }
        return new FilterExec(filterExec.source(), newEval, Predicates.combineAnd(nonPushable));
    }

    /**
     * Collects EVAL aliases that are either a translatable {@code StDistance} or a plain reference to an
     * earlier such alias (e.g. {@code d2 = d}). Both map to the same {@code StDistance}.
     */
    private Map<NameId, StDistance> getPushableDistances(List<Alias> aliases, LucenePushdownPredicates lucenePushdownPredicates) {
        Map<NameId, StDistance> distances = new LinkedHashMap<>();
        for (Alias alias : aliases) {
            Expression child = alias.child();
            if (child instanceof StDistance distance && distance.translatable(lucenePushdownPredicates)) {
                distances.put(alias.id(), distance);
            } else if (child instanceof ReferenceAttribute ref && distances.containsKey(ref.id())) {
                distances.put(alias.id(), distances.get(ref.id()));
            }
        }
        return distances;
    }

    /**
     * Rewrites comparisons between a reference to a known distance alias and a foldable value.
     */
    private Expression rewriteDistanceFilters(FoldContext ctx, Expression expr, Map<NameId, StDistance> distances) {
        return expr.transformDown(EsqlBinaryComparison.class, comparison -> {
            ComparisonType comparisonType = ComparisonType.fromSupported(comparison.getFunctionType());
            if (comparisonType == null) {
                return comparison;
            }
            if (comparison.left() instanceof ReferenceAttribute r && distances.containsKey(r.id()) && comparison.right().foldable()) {
                return rewriteComparison(ctx, comparison, distances.get(r.id()), comparison.right(), comparisonType);
            }
            if (comparison.right() instanceof ReferenceAttribute r && distances.containsKey(r.id()) && comparison.left().foldable()) {
                return rewriteComparison(
                    ctx,
                    comparison,
                    distances.get(r.id()),
                    comparison.left(),
                    ComparisonType.invert(comparisonType)
                );
            }
            return comparison;
        });
    }

    /**
     * Folds {@code literal} and, if it is numeric and one side of {@code dist} is foldable, delegates to
     * {@link #rewriteDistanceFilter}. Otherwise the comparison is returned unchanged.
     */
    private Expression rewriteComparison(
        FoldContext ctx,
        EsqlBinaryComparison comparison,
        StDistance dist,
        Expression literal,
        ComparisonType comparisonType
    ) {
        if (literal.fold(ctx) instanceof Number number) {
            if (dist.right().foldable()) {
                return rewriteDistanceFilter(ctx, comparison, dist.left(), dist.right(), number, comparisonType);
            }
            if (dist.left().foldable()) {
                return rewriteDistanceFilter(ctx, comparison, dist.right(), dist.left(), number, comparisonType);
            }
        }
        return comparison;
    }

    /**
     * Builds the spatial predicate equivalent to {@code ST_DISTANCE(spatialExp, literalExp) OP number}, when
     * {@code literalExp} folds to a point.
     * <p>
     * The mapping is:
     * <ul>
     *   <li>{@code <=} and {@code <} become {@code ST_INTERSECTS} with a circle of radius {@code d} or
     *       {@code nextDown(d)}.</li>
     *   <li>{@code >=} and {@code >} become {@code ST_DISJOINT} with a circle of radius {@code d} or
     *       {@code nextUp(d)}.</li>
     *   <li>{@code ==} becomes the conjunction of both relations.</li>
     * </ul>
     * If a required radius would be negative or NaN, the comparison is returned unchanged so that it is
     * evaluated normally.
     *
     * @throws IllegalArgumentException if {@code spatialExp} is not a supported spatial type
     */
    private Expression rewriteDistanceFilter(
        FoldContext ctx,
        EsqlBinaryComparison comparison,
        Expression spatialExp,
        Expression literalExp,
        Number number,
        ComparisonType comparisonType
    ) {
        DataType shapeDataType = getShapeDataType(spatialExp);
        Geometry geometry = SpatialRelatesUtils.makeGeometryFromLiteral(ctx, literalExp);
        if (geometry instanceof Point point) {
            double distance = number.doubleValue();
            Source source = comparison.source();
            if (comparisonType.lt) {
                double radius = comparisonType.eq ? distance : Math.nextDown(distance);
                if (isValidRadius(radius)) {
                    return new SpatialIntersects(source, spatialExp, makeCircleLiteral(point, radius, literalExp, shapeDataType));
                }
            } else if (comparisonType.gt) {
                // NOTE(review): the EQ branch below treats DISJOINT(circle(r)) as "distance > r", under which
                // GTE would need nextDown(d) and GT would need d. The existing choice is kept; confirm intent.
                double radius = comparisonType.eq ? distance : Math.nextUp(distance);
                if (isValidRadius(radius)) {
                    return new SpatialDisjoint(source, spatialExp, makeCircleLiteral(point, radius, literalExp, shapeDataType));
                }
            } else if (comparisonType.eq) {
                double innerRadius = Math.nextDown(distance);
                // innerRadius < distance, so a valid inner radius implies a valid outer one.
                if (isValidRadius(innerRadius)) {
                    return new And(
                        source,
                        new SpatialIntersects(source, spatialExp, makeCircleLiteral(point, distance, literalExp, shapeDataType)),
                        new SpatialDisjoint(source, spatialExp, makeCircleLiteral(point, innerRadius, literalExp, shapeDataType))
                    );
                }
            }
        }
        return comparison;
    }

    /**
     * Returns whether {@code radius} is usable as a circle radius, i.e. non-negative and not NaN.
     * Negative values arise from bounds such as {@code < 0}, {@code == 0} or {@code > -1}. Such bounds are
     * not rewritten, presumably because {@link Circle} rejects negative radii (confirm).
     */
    private static boolean isValidRadius(double radius) {
        return radius >= 0;
    }

    /** Encodes a circle around {@code point} as a little-endian WKB literal of {@code shapeDataType}. */
    private Literal makeCircleLiteral(Point point, double distance, Expression literalExpression, DataType shapeDataType) {
        Circle circle = new Circle(point.getX(), point.getY(), distance);
        byte[] wkb = WellKnownBinary.toWKB(circle, ByteOrder.LITTLE_ENDIAN);
        return new Literal(literalExpression.source(), new BytesRef(wkb), shapeDataType);
    }

    /** Maps a point or shape type to the shape type used for the circle literal in the same coordinate system. */
    private DataType getShapeDataType(Expression expression) {
        return switch (expression.dataType()) {
            case GEO_POINT, GEO_SHAPE -> DataType.GEO_SHAPE;
            case CARTESIAN_POINT, CARTESIAN_SHAPE -> DataType.CARTESIAN_SHAPE;
            default -> throw new IllegalArgumentException("Unsupported spatial data type: " + expression.dataType());
        };
    }

    /**
     * Direction of a distance comparison, normalized so that the distance is on the left-hand side.
     * {@code lt}/{@code gt} give the direction and {@code eq} whether the bound itself is included.
     */
    enum ComparisonType {
        LTE(true, false, true),
        LT(true, false, false),
        GTE(false, true, true),
        GT(false, true, false),
        EQ(false, false, true);

        private final boolean lt;
        private final boolean gt;
        private final boolean eq;

        ComparisonType(boolean lt, boolean gt, boolean eq) {
            this.lt = lt;
            this.gt = gt;
            this.eq = eq;
        }

        /**
         * Lenient mapping kept for compatibility: any operator other than LT/LTE/GT/GTE maps to {@link #EQ}.
         * This includes {@code !=}. Rule code should use {@link #fromSupported} instead.
         */
        static ComparisonType from(EsqlBinaryComparison.BinaryComparisonOperation op) {
            ComparisonType supported = fromSupported(op);
            return supported != null ? supported : EQ;
        }

        /**
         * Maps {@code op} to a comparison type.
         *
         * @return the matching type, or {@code null} for operators that cannot be expressed as a circle
         *         relation (e.g. {@code !=})
         */
        static ComparisonType fromSupported(EsqlBinaryComparison.BinaryComparisonOperation op) {
            // Case labels resolve against the operator enum; the results are ComparisonType constants.
            return switch (op) {
                case LT -> LT;
                case LTE -> LTE;
                case GT -> GT;
                case GTE -> GTE;
                case EQ -> EQ;
                default -> null;
            };
        }

        /** Returns the type for the same comparison with its operands swapped ({@code a < b} becomes {@code b > a}). */
        static ComparisonType invert(ComparisonType comparisonType) {
            return switch (comparisonType) {
                case LT -> GT;
                case LTE -> GTE;
                case GT -> LT;
                case GTE -> LTE;
                case EQ -> EQ;
            };
        }
    }
}

