/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.timeseries.ml;

import com.amazon.randomcutforest.RandomCutForest;
import com.amazon.randomcutforest.parkservices.AnomalyDescriptor;
import com.amazon.randomcutforest.parkservices.ThresholdedRandomCutForest;
import java.time.Clock;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Iterator;
import java.util.Map;
import java.util.Optional;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.apache.logging.log4j.message.Message;
import org.apache.logging.log4j.message.ParameterizedMessage;
import org.opensearch.core.action.ActionListener;
import org.opensearch.timeseries.MemoryTracker;
import org.opensearch.timeseries.feature.FeatureManager;
import org.opensearch.timeseries.indices.IndexManagement;
import org.opensearch.timeseries.ml.CheckpointDao;
import org.opensearch.timeseries.ml.IntermediateResult;
import org.opensearch.timeseries.ml.ModelColdStart;
import org.opensearch.timeseries.ml.ModelState;
import org.opensearch.timeseries.ml.Sample;
import org.opensearch.timeseries.ml.SingleStreamModelIdMapper;
import org.opensearch.timeseries.model.Config;
import org.opensearch.timeseries.model.IndexableResult;
import org.opensearch.timeseries.ratelimit.CheckpointWriteWorker;
import org.opensearch.timeseries.util.DataUtil;

/**
 * Base class for managing thresholded RCF models of a config.
 *
 * <p>Responsibilities visible here:
 * <ul>
 *   <li>scoring samples with a loaded model;</li>
 *   <li>asking the cold starter to train a model from buffered samples when none is loaded;</li>
 *   <li>removing in-memory models and their checkpoints for a given config id.</li>
 * </ul>
 * Subclasses decide how an RCF descriptor is turned into a concrete result type.
 */
public abstract class ModelManager<
    RCFModelType extends ThresholdedRandomCutForest,
    IndexableResultType extends IndexableResult,
    IntermediateResultType extends IntermediateResult<IndexableResultType>,
    IndexType extends Enum<IndexType>,
    IndexManagementType extends IndexManagement<IndexType>,
    CheckpointDaoType extends CheckpointDao<RCFModelType, IndexType, IndexManagementType>,
    CheckpointWriteWorkerType extends CheckpointWriteWorker<RCFModelType, IndexType, IndexManagementType, CheckpointDaoType>,
    ColdStarterType extends ModelColdStart<RCFModelType, IndexType, IndexManagementType, CheckpointDaoType, CheckpointWriteWorkerType, IndexableResultType>> {

    private static final Logger LOG = LogManager.getLogger(ModelManager.class);

    protected final int rcfNumTrees;
    protected final int rcfNumSamplesInTree;
    protected final int rcfNumMinSamples;
    protected ColdStarterType coldStarter;
    protected MemoryTracker memoryTracker;
    protected final Clock clock;
    protected FeatureManager featureManager;
    protected final CheckpointDaoType checkpointDao;

    /**
     * Creates a model manager.
     *
     * @param rcfNumTrees         number of trees setting, stored for subclasses
     * @param rcfNumSamplesInTree samples-per-tree setting, stored for subclasses
     * @param rcfNumMinSamples    minimum-samples setting, stored for subclasses
     * @param coldStarter         used to train a model from buffered samples when none is loaded
     * @param memoryTracker       memory tracker, stored for subclasses
     * @param clock               source of "now" for {@code lastUsedTime} bookkeeping
     * @param featureManager      feature manager, stored for subclasses
     * @param checkpointDao       used to delete model checkpoints in {@link #clearModels}
     */
    public ModelManager(int rcfNumTrees, int rcfNumSamplesInTree, int rcfNumMinSamples, ColdStarterType coldStarter, MemoryTracker memoryTracker, Clock clock, FeatureManager featureManager, CheckpointDaoType checkpointDao) {
        this.rcfNumTrees = rcfNumTrees;
        this.rcfNumSamplesInTree = rcfNumSamplesInTree;
        this.rcfNumMinSamples = rcfNumMinSamples;
        this.coldStarter = coldStarter;
        this.memoryTracker = memoryTracker;
        this.clock = clock;
        this.featureManager = featureManager;
        this.checkpointDao = checkpointDao;
    }

    /**
     * Produces a result for {@code sample} using the model held in {@code modelState}.
     *
     * <p>If no model is loaded, the cold starter is first asked to train one from the samples
     * already buffered in {@code modelState}. If a model is then available, the sample is
     * scored. Otherwise the sample is buffered in {@code modelState} and an empty result is
     * returned.
     *
     * @param sample     the sample to score
     * @param modelState model holder; when {@code null} an empty result is returned
     * @param modelId    model id, used for logging on scoring failure
     * @param config     the config the model belongs to
     * @param taskId     task id forwarded to the cold starter
     * @return the scoring result, or an empty result when no model could be used
     */
    public IntermediateResultType getResult(Sample sample, ModelState<RCFModelType> modelState, String modelId, Config config, String taskId) {
        if (modelState == null) {
            return this.createEmptyResult();
        }
        if (modelState.getModel().isEmpty()) {
            this.coldStarter.trainModelFromExistingSamples(modelState, config, taskId);
        }
        // Re-read the model: training above may have populated it.
        if (modelState.getModel().isPresent()) {
            return this.score(sample, modelId, modelState, config);
        }
        modelState.addSample(sample);
        return this.createEmptyResult();
    }

    /**
     * Removes every model in {@code models} whose config id equals {@code detectorId} and
     * deletes each removed model's checkpoint, one after another.
     *
     * <p>{@code listener} is notified with {@code null} once all matching models have been
     * processed, or with the first checkpoint-deletion failure.
     *
     * @param detectorId config id whose models should be cleared
     * @param models     model map keyed by model id; matching entries are removed from it
     * @param listener   completion listener
     */
    public void clearModels(String detectorId, Map<String, ?> models, ActionListener<Void> listener) {
        // Iterate over a snapshot of the keys. clearModelForIterator removes entries from
        // `models`, which would make a live key-set iterator of a non-concurrent map throw
        // ConcurrentModificationException.
        Iterator<String> id = new ArrayList<>(models.keySet()).iterator();
        this.clearModelForIterator(detectorId, models, id, listener);
    }

    /**
     * Advances {@code idIter} to the next model id belonging to {@code detectorId}. That model
     * is removed from {@code models` and its checkpoint is deleted; iteration continues from the
     * deletion callback. Ids belonging to other configs are skipped in a loop rather than by
     * recursion, so a large map cannot exhaust the stack.
     *
     * @param detectorId config id whose models should be cleared
     * @param models     model map; matching entries are removed from it
     * @param idIter     iterator over the model ids still to be examined
     * @param listener   notified with {@code null} when the iterator is exhausted, or with the
     *                   first checkpoint-deletion failure
     */
    protected void clearModelForIterator(String detectorId, Map<String, ?> models, Iterator<String> idIter, ActionListener<Void> listener) {
        while (idIter.hasNext()) {
            String modelId = idIter.next();
            if (SingleStreamModelIdMapper.getConfigIdForModelId(modelId).equals(detectorId)) {
                models.remove(modelId);
                this.checkpointDao
                    .deleteModelCheckpoint(
                        modelId,
                        ActionListener.wrap(r -> this.clearModelForIterator(detectorId, models, idIter, listener), listener::onFailure)
                    );
                // The deletion callback resumes iteration; stop here.
                return;
            }
        }
        listener.onResponse(null);
    }

    /**
     * Scores {@code sample} with the model in {@code modelState}.
     *
     * <p>Before scoring, any samples buffered in {@code modelState} are first fed to the model
     * in iteration order, and the buffer is then cleared. Whether scoring succeeds or fails,
     * {@code modelState}'s last-used time ({@link #clock}) and last-seen data end time
     * ({@code sample.getDataEndTime()}) are updated.
     *
     * @param sample     the sample to score
     * @param modelId    model id, used for logging on failure
     * @param modelState model holder
     * @param config     the config the model belongs to
     * @return the scoring result, or an empty result when no model is present
     * @throws RuntimeException any exception raised while scoring, rethrown after logging
     */
    public <RCFDescriptor extends AnomalyDescriptor> IntermediateResultType score(Sample sample, String modelId, ModelState<RCFModelType> modelState, Config config) {
        Optional<RCFModelType> model = modelState.getModel();
        try {
            if (model != null && model.isPresent()) {
                RCFModelType rcfModel = model.get();
                this.replayBufferedSamples(modelState, rcfModel);
                return this.score(sample, config, rcfModel);
            }
        } catch (Exception e) {
            Object target = modelState.getEntity().isEmpty() ? modelState.getConfigId() : modelState.getEntity().get();
            Message message = new ParameterizedMessage(
                "Fail to score for [{}] at [{}]: model Id [{}], feature [{}]",
                target,
                sample.getDataEndTime().getEpochSecond(),
                modelId,
                Arrays.toString(sample.getValueList())
            );
            LOG.error(message, e);
            throw e;
        } finally {
            modelState.setLastUsedTime(this.clock.instant());
            modelState.setLastSeenDataEndTime(sample.getDataEndTime());
        }
        return this.createEmptyResult();
    }

    /** Feeds the samples buffered in {@code modelState} into {@code rcfModel}, then clears the buffer. */
    private void replayBufferedSamples(ModelState<RCFModelType> modelState, RCFModelType rcfModel) {
        if (modelState.getSamples().isEmpty()) {
            return;
        }
        for (Sample buffered : modelState.getSamples()) {
            double[] point = buffered.getValueList();
            int[] missingIndices = DataUtil.generateMissingIndicesArray(point);
            rcfModel.process(point, buffered.getDataEndTime().getEpochSecond(), missingIndices);
        }
        modelState.clearSamples();
    }

    /**
     * Processes {@code sample} with {@code rcfModel} and converts the resulting descriptor
     * with {@link #toResult}.
     *
     * @param sample   the sample to process
     * @param config   the config the model belongs to
     * @param rcfModel the model to process the sample with
     * @return the converted result, or an empty result when the model returns no descriptor
     */
    public <RCFDescriptor extends AnomalyDescriptor> IntermediateResultType score(Sample sample, Config config, RCFModelType rcfModel) {
        double[] point = sample.getValueList();
        int[] missingValues = DataUtil.generateMissingIndicesArray(point);
        AnomalyDescriptor lastResult = rcfModel.process(point, sample.getDataEndTime().getEpochSecond(), missingValues);
        if (lastResult == null) {
            return this.createEmptyResult();
        }
        // NOTE(review): assumes generateMissingIndicesArray returns null when no value is
        // missing, so this flag is true only when some values were missing — confirm.
        return this.toResult(rcfModel.getForest(), lastResult, point, missingValues != null, config);
    }

    /** Returns the result used when no model output is available. */
    protected abstract IntermediateResultType createEmptyResult();

    /**
     * Converts an RCF descriptor into a result.
     *
     * @param var1 the forest inside the thresholded model
     * @param var2 the descriptor returned by processing the point
     * @param var3 the processed point
     * @param var4 true when the missing-indices array for the point was non-null
     * @param var5 the config the model belongs to
     * @return the converted result
     */
    protected abstract <RCFDescriptor extends AnomalyDescriptor> IntermediateResultType toResult(RandomCutForest var1, RCFDescriptor var2, double[] var3, boolean var4, Config var5);

    /** Kinds of model, each with a short string name. */
    public enum ModelType {
        RCF("rcf"),
        THRESHOLD("threshold"),
        TRCF("trcf"),
        RCFCASTER("rcf_caster");

        private final String name;

        ModelType(String name) {
            this.name = name;
        }

        /** Returns this model type's string name, e.g. {@code "rcf_caster"}. */
        public String getName() {
            return this.name;
        }
    }
}

