/*
 * Decompiled with CFR 0.152.
 */
package org.elasticsearch.xpack.inference.action.task;

import java.util.List;
import java.util.Map;
import java.util.concurrent.Flow;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicLong;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
import org.elasticsearch.common.util.concurrent.ThreadContext;
import org.elasticsearch.injection.guice.Inject;
import org.elasticsearch.tasks.Task;
import org.elasticsearch.tasks.TaskAwareRequest;
import org.elasticsearch.tasks.TaskId;
import org.elasticsearch.tasks.TaskManager;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.transport.TransportService;
import org.elasticsearch.xpack.inference.action.task.FlowTask;

public class StreamingTaskManager {
    private final TaskManager taskManager;
    private final ThreadPool threadPool;

    @Inject
    public StreamingTaskManager(TransportService transportService, ThreadPool threadPool) {
        this(transportService.getTaskManager(), threadPool);
    }

    StreamingTaskManager(TaskManager taskManager, ThreadPool threadPool) {
        this.taskManager = taskManager;
        this.threadPool = threadPool;
    }

    public <E> Flow.Processor<E, E> create(String taskType, String taskAction) {
        return new TaskBackedProcessor(taskType, taskAction);
    }

    public static List<NamedWriteableRegistry.Entry> namedWriteables() {
        return List.of(new NamedWriteableRegistry.Entry(Task.Status.class, "streaming_task_manager_flow_status", FlowTask.FlowStatus.STREAM_READER));
    }

    private class TaskBackedProcessor<E>
    implements Flow.Processor<E, E> {
        private static final Logger log = LogManager.getLogger(TaskBackedProcessor.class);
        private final String taskType;
        private final String taskAction;
        private Flow.Subscriber<? super E> downstream;
        private Flow.Subscription upstream;
        private FlowTask task;
        private final AtomicBoolean isClosed = new AtomicBoolean(false);
        private final AtomicLong pendingRequests = new AtomicLong();

        private TaskBackedProcessor(String taskType, String taskAction) {
            this.taskType = taskType;
            this.taskAction = taskAction;
        }

        @Override
        public void subscribe(Flow.Subscriber<? super E> subscriber) {
            if (this.downstream != null) {
                subscriber.onError(new IllegalStateException("Another subscriber is already subscribed."));
                return;
            }
            this.downstream = subscriber;
            this.openOrUpdateTask();
            this.downstream.onSubscribe(this.forwardingSubscription());
        }

        private void openOrUpdateTask() {
            if (this.task != null) {
                this.task.updateStatus(FlowTask.FlowStatus.CONNECTED);
            } else {
                try (ThreadContext.StoredContext ignored = StreamingTaskManager.this.threadPool.getThreadContext().newTraceContext();){
                    this.task = (FlowTask)StreamingTaskManager.this.taskManager.register(this.taskType, this.taskAction, new TaskAwareRequest(){

                        public void setParentTask(TaskId taskId) {
                            throw new UnsupportedOperationException("parent task id for streaming results shouldn't change");
                        }

                        public void setRequestId(long requestId) {
                            throw new UnsupportedOperationException("does not have request ID");
                        }

                        public TaskId getParentTask() {
                            return TaskId.EMPTY_TASK_ID;
                        }

                        public Task createTask(long id, String type, String action, TaskId parentTaskId, Map<String, String> headers) {
                            FlowTask flowTask = new FlowTask(id, type, action, "", parentTaskId, headers);
                            flowTask.addListener(TaskBackedProcessor.this::cancelTask);
                            return flowTask;
                        }
                    });
                }
            }
        }

        private void cancelTask() {
            if (this.isClosed.compareAndSet(false, true)) {
                if (this.upstream != null) {
                    this.upstream.cancel();
                }
                if (this.downstream != null) {
                    this.downstream.onComplete();
                }
            }
        }

        private Flow.Subscription forwardingSubscription() {
            return new Flow.Subscription(){

                @Override
                public void request(long n) {
                    if (TaskBackedProcessor.this.isClosed.get()) {
                        TaskBackedProcessor.this.downstream.onComplete();
                    } else if (TaskBackedProcessor.this.upstream != null) {
                        TaskBackedProcessor.this.upstream.request(n);
                    } else {
                        TaskBackedProcessor.this.pendingRequests.accumulateAndGet(n, Long::sum);
                    }
                }

                @Override
                public void cancel() {
                    TaskBackedProcessor.this.finishTask();
                    if (TaskBackedProcessor.this.upstream != null) {
                        TaskBackedProcessor.this.upstream.cancel();
                    }
                }
            };
        }

        @Override
        public void onSubscribe(Flow.Subscription subscription) {
            if (this.isClosed.get()) {
                subscription.cancel();
                return;
            }
            this.upstream = subscription;
            this.openOrUpdateTask();
            long currentRequestCount = this.pendingRequests.getAndSet(0L);
            if (currentRequestCount != 0L) {
                this.upstream.request(currentRequestCount);
            }
        }

        @Override
        public void onNext(E item) {
            if (this.isClosed.get()) {
                this.upstream.cancel();
            } else {
                this.downstream.onNext(item);
            }
        }

        @Override
        public void onError(Throwable throwable) {
            this.finishTask();
            if (this.downstream == null) {
                log.atDebug().withThrowable(throwable).log("onError was called before the downstream subscription, rethrowing to close listener.");
                throw new IllegalStateException("onError was called before the downstream subscription", throwable);
            }
            this.downstream.onError(throwable);
        }

        @Override
        public void onComplete() {
            this.finishTask();
            if (this.downstream != null) {
                this.downstream.onComplete();
            }
        }

        private void finishTask() {
            if (this.isClosed.compareAndSet(false, true) && this.task != null) {
                StreamingTaskManager.this.taskManager.unregister((Task)this.task);
            }
        }
    }
}

