/*
 * Decompiled with CFR 0.152.
 */
package org.elasticsearch.xpack.ml.inference.nlp.tokenizers;

import java.util.Collections;
import java.util.List;
import java.util.Set;
import java.util.SortedMap;
import java.util.TreeMap;
import org.elasticsearch.common.util.set.Sets;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.Tokenization;
import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.BertTokenizer;
import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.MPNetTokenizationResult;
import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.TokenizationResult;

/**
 * Tokenizer for MPNet models, built on top of {@link BertTokenizer} but using MPNet's
 * special-token strings ({@code <s>}, {@code </s>}, {@code <pad>}, {@code <mask>}) and
 * MPNet-specific token/result builders.
 */
public class MPNetTokenizer extends BertTokenizer {
    public static final String UNKNOWN_TOKEN = "[UNK]";
    public static final String SEPARATOR_TOKEN = "</s>";
    public static final String PAD_TOKEN = "<pad>";
    public static final String CLASS_TOKEN = "<s>";
    public static final String MASK_TOKEN = "<mask>";
    // Tokens that must never be split by the basic tokenizer; reuse the constant rather than
    // repeating the literal so the two cannot drift apart.
    private static final Set<String> NEVER_SPLIT = Set.of(MASK_TOKEN);

    /**
     * Creates the tokenizer. Callers normally go through {@link #mpBuilder(List, Tokenization)}.
     *
     * @param originalVocab the vocabulary in its original (index) order
     * @param vocab the vocabulary sorted by token, mapping token to its index in {@code originalVocab}
     * @param doLowerCase whether input text is lower-cased
     * @param doTokenizeCjKChars whether CJK characters are tokenized individually
     * @param doStripAccents whether accents are stripped
     * @param withSpecialTokens whether special tokens are added to the output
     * @param maxSequenceLength maximum sequence length passed through to {@link BertTokenizer}
     * @param neverSplit caller-supplied tokens that must not be split; {@link #MASK_TOKEN} is
     *        always added to this set
     */
    protected MPNetTokenizer(List<String> originalVocab, SortedMap<String, Integer> vocab, boolean doLowerCase, boolean doTokenizeCjKChars, boolean doStripAccents, boolean withSpecialTokens, int maxSequenceLength, Set<String> neverSplit) {
        super(originalVocab, vocab, doLowerCase, doTokenizeCjKChars, doStripAccents, withSpecialTokens, maxSequenceLength, Sets.union(neverSplit, NEVER_SPLIT), SEPARATOR_TOKEN, CLASS_TOKEN, PAD_TOKEN, MASK_TOKEN, UNKNOWN_TOKEN);
    }

    /**
     * Returns the number of special tokens added when encoding a sequence pair.
     * NOTE(review): presumably the MPNet/RoBERTa layout {@code <s> A </s></s> B </s>} — confirm
     * against {@code MPNetTokensBuilder}.
     */
    @Override
    protected int getNumExtraTokensForSeqPair() {
        return 4;
    }

    /** Creates the MPNet-specific tokens builder for the given special-token ids. */
    @Override
    TokenizationResult.TokensBuilder createTokensBuilder(int clsTokenId, int sepTokenId, boolean withSpecialTokens) {
        return new MPNetTokenizationResult.MPNetTokensBuilder(withSpecialTokens, clsTokenId, sepTokenId);
    }

    /**
     * Wraps the given tokenizations in an {@link MPNetTokenizationResult}.
     *
     * @param tokenizations per-input tokenization results
     * @return the combined result, carrying the original vocabulary and the pad token id
     * @throws java.util.NoSuchElementException if no pad token id is available
     *         (NOTE(review): assumes {@code getPadTokenId()} returns an Optional-like type)
     */
    @Override
    public TokenizationResult buildTokenizationResult(List<TokenizationResult.Tokens> tokenizations) {
        return new MPNetTokenizationResult(this.originalVocab, tokenizations, this.getPadTokenId().orElseThrow());
    }

    /**
     * Returns a builder seeded from {@code tokenization}'s lower-case, special-token and
     * max-sequence-length settings.
     */
    public static Builder mpBuilder(List<String> vocab, Tokenization tokenization) {
        return new Builder(vocab, tokenization);
    }

    /**
     * Builder for {@link MPNetTokenizer}.
     *
     * <p>Defaults: CJK tokenization is on; accent stripping, if never set, follows
     * {@code doLowerCase} at build time; the never-split set, if never set, is empty.
     */
    public static class Builder {
        protected final List<String> originalVocab;
        protected final SortedMap<String, Integer> vocab;
        protected boolean doLowerCase;
        protected boolean doTokenizeCjKChars = true;
        protected boolean withSpecialTokens;
        protected int maxSequenceLength;
        // null means "not set": resolved from doLowerCase when build() is called.
        protected Boolean doStripAccents = null;
        // null means "not set": treated as an empty set when build() is called.
        protected Set<String> neverSplit;

        protected Builder(List<String> vocab, Tokenization tokenization) {
            this.originalVocab = vocab;
            this.vocab = Builder.buildSortedVocab(vocab);
            this.doLowerCase = tokenization.doLowerCase();
            this.withSpecialTokens = tokenization.withSpecialTokens();
            this.maxSequenceLength = tokenization.maxSequenceLength();
        }

        /**
         * Maps each token to its index in {@code vocab}, sorted by token. If a token occurs more
         * than once, the last occurrence's index is kept.
         */
        private static SortedMap<String, Integer> buildSortedVocab(List<String> vocab) {
            SortedMap<String, Integer> sortedVocab = new TreeMap<>();
            for (int i = 0; i < vocab.size(); ++i) {
                sortedVocab.put(vocab.get(i), i);
            }
            return sortedVocab;
        }

        /** Sets whether input text is lower-cased. */
        public Builder setDoLowerCase(boolean doLowerCase) {
            this.doLowerCase = doLowerCase;
            return this;
        }

        /** Sets whether CJK characters are tokenized individually. */
        public Builder setDoTokenizeCjKChars(boolean doTokenizeCjKChars) {
            this.doTokenizeCjKChars = doTokenizeCjKChars;
            return this;
        }

        /** Sets accent stripping; {@code null} means "follow {@code doLowerCase}". */
        public Builder setDoStripAccents(Boolean doStripAccents) {
            this.doStripAccents = doStripAccents;
            return this;
        }

        /** Sets extra tokens that must never be split; {@code null} means none. */
        public Builder setNeverSplit(Set<String> neverSplit) {
            this.neverSplit = neverSplit;
            return this;
        }

        /** Sets the maximum sequence length. */
        public Builder setMaxSequenceLength(int maxSequenceLength) {
            this.maxSequenceLength = maxSequenceLength;
            return this;
        }

        /** Sets whether special tokens are added to the output. */
        public Builder setWithSpecialTokens(boolean withSpecialTokens) {
            this.withSpecialTokens = withSpecialTokens;
            return this;
        }

        /**
         * Builds the tokenizer from the current settings.
         *
         * <p>Unset options are resolved into locals instead of being written back into the
         * builder. A builder that is reused after a later setter call (e.g.
         * {@link #setDoLowerCase(boolean)}) therefore derives its defaults afresh instead of
         * keeping values frozen by an earlier {@code build()}.
         */
        public MPNetTokenizer build() {
            boolean stripAccents = this.doStripAccents == null ? this.doLowerCase : this.doStripAccents;
            Set<String> extraNeverSplit = this.neverSplit == null ? Collections.emptySet() : this.neverSplit;
            return new MPNetTokenizer(this.originalVocab, this.vocab, this.doLowerCase, this.doTokenizeCjKChars, stripAccents, this.withSpecialTokens, this.maxSequenceLength, extraNeverSplit);
        }
    }
}

