package org.residuum.alligator.samplefiles;

import android.content.ContentResolver;
import android.content.Context;
import android.net.Uri;
import android.os.ParcelFileDescriptor;

import org.residuum.alligator.utils.FileUri;

import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.time.LocalDateTime;
import java.time.format.DateTimeFormatter;
import java.util.Locale;
import java.util.Objects;
import java.util.stream.IntStream;

import androidx.annotation.NonNull;
import androidx.documentfile.provider.DocumentFile;

/**
 * Exports recorded audio as 16-bit PCM WAVE files through the Android {@link ContentResolver}.
 *
 * <p>Sample recordings are written in one go ({@link #writeSampleRecording}). Session recordings
 * are streamed: {@link #writeSessionRecording} appends header-less PCM chunks to a temporary file,
 * and {@link #finalizeSessionRecording} copies that data behind a WAVE header into the final file.
 */
public final class WaveFileExporter {

    private static final String SESSION_RECORDING = "Alligator_Bytes_%s%s.wav";
    private static final String SAMPLE_RECORDING = "recording_%d_%s.wav";
    private static final String TMP_FILE_SNIPPET = "_tmp";
    private static final String WAV_FILE_SNIPPET = "";
    private static final String WAV_MIME_TYPE = "audio/wav";
    /** Gain applied to float samples before 16-bit conversion (value carried over unchanged). */
    private static final double PCM_GAIN = 1.414;
    /**
     * Channel count written into finalized session headers. NOTE(review): assumes the appended
     * session chunks are interleaved stereo — confirm against the recorder.
     */
    private static final int SESSION_CHANNELS = 2;
    // DateTimeFormatter is immutable and thread-safe, so a single shared instance suffices.
    private static final DateTimeFormatter DATE_TIME_FORMAT = DateTimeFormatter.ofPattern("yyyy-MM-dd--HH-mm-ss");

    private final ContentResolver resolver;
    private final Context context;

    public WaveFileExporter(@NonNull final Context context) {
        this.resolver = context.getContentResolver();
        this.context = context;
    }

    /**
     * Writes a mono sample recording as a complete WAVE file.
     *
     * @param audioData       float samples, scaled by {@link #PCM_GAIN} and clamped to 16 bit
     * @param sampleRate      sample rate in Hz written into the header
     * @param recordingNo     number embedded in the file name
     * @param dateOfRecording timestamp embedded in the file name
     * @return the generated file name
     * @throws IOException if no target URI can be obtained or writing fails
     */
    public String writeSampleRecording(final float[] audioData, final int sampleRate, final int recordingNo, @NonNull final LocalDateTime dateOfRecording) throws IOException {
        // Locale.ROOT keeps %d from producing locale-specific (non-ASCII) digits in file names.
        final String fileName = String.format(Locale.ROOT, WaveFileExporter.SAMPLE_RECORDING, recordingNo, dateOfRecording.format(DATE_TIME_FORMAT));
        this.writeWaveFile(fileName, audioData, sampleRate, 1);
        return fileName;
    }

    /** Converts {@code audioData} to PCM and writes it, with a WAVE header, to {@code fileName}. */
    private void writeWaveFile(final String fileName, @NonNull final float[] audioData, final int sampleRate, final int channels) throws IOException {
        final short[] pcmContent = toPcm16(audioData);
        final Uri fileUri = this.requireUri(fileName, "Could not write to folder");
        new WaveFileWriter().rawToWave(pcmContent, sampleRate, channels, this.resolver, fileUri);
    }

    /**
     * Appends a chunk of session audio (no header) to the session's temporary file.
     *
     * @param audioData       float samples to append
     * @param dateOfRecording session start; identifies the temporary file
     * @throws IOException if no target URI can be obtained or writing fails
     */
    public void writeSessionRecording(final float[] audioData, @NonNull final LocalDateTime dateOfRecording) throws IOException {
        this.writeHeadlessWaveFile(sessionFileName(dateOfRecording, WaveFileExporter.TMP_FILE_SNIPPET), audioData);
    }

    /** Converts {@code audioData} to PCM and appends it, without a header, to {@code fileName}. */
    private void writeHeadlessWaveFile(final String fileName, @NonNull final float[] audioData) throws IOException {
        final short[] pcmContent = toPcm16(audioData);
        final Uri fileUri = this.requireUri(fileName, "Could not write to folder");
        new WaveFileWriter().rawToHeadlessWave(pcmContent, this.resolver, fileUri);
    }

    /**
     * Turns the session's temporary PCM file into a playable WAVE file and deletes the temporary file.
     *
     * @param recordDateTime session start; identifies the temporary and final files
     * @param sampleRate     sample rate in Hz written into the header
     * @return the final file name
     * @throws IOException if either URI cannot be obtained or copying fails
     */
    public String finalizeSessionRecording(@NonNull final LocalDateTime recordDateTime, final int sampleRate) throws IOException {
        final Uri tmpFileUri = this.requireUri(sessionFileName(recordDateTime, WaveFileExporter.TMP_FILE_SNIPPET), "Could not read temporary file");
        final String fileName = sessionFileName(recordDateTime, WaveFileExporter.WAV_FILE_SNIPPET);
        final Uri wavFileUri = this.requireUri(fileName, "Could not write to folder");
        new WaveFileWriter().copyRawDataToWave(tmpFileUri, wavFileUri, sampleRate, SESSION_CHANNELS, this.resolver, this.context);
        return fileName;
    }

    /** Builds a session file name; the shared timestamp links the temporary and the final file. */
    private static String sessionFileName(final LocalDateTime recordDateTime, final String suffix) {
        return String.format(Locale.ROOT, WaveFileExporter.SESSION_RECORDING, recordDateTime.format(DATE_TIME_FORMAT), suffix);
    }

    /** Resolves {@code fileName} to a URI via {@link FileUri}, throwing {@code errorMessage} if none is available. */
    @NonNull
    private Uri requireUri(final String fileName, final String errorMessage) throws IOException {
        final Uri uri = FileUri.getUriForFile(fileName, this.context, WAV_MIME_TYPE);
        if (uri == null) {
            throw new IOException(errorMessage);
        }
        return uri;
    }

    /**
     * Scales float samples by {@link #PCM_GAIN} and converts them to signed 16-bit values.
     * Results are clamped, because a plain {@code (short)} cast of an out-of-range value wraps
     * around and produces sign-flipped clicks.
     */
    private static short[] toPcm16(@NonNull final float[] audioData) {
        final short[] pcm = new short[audioData.length];
        for (int i = 0; i < audioData.length; i++) {
            final double scaled = audioData[i] * PCM_GAIN * Short.MAX_VALUE;
            pcm[i] = (short) Math.max(Short.MIN_VALUE, Math.min(Short.MAX_VALUE, scaled));
        }
        return pcm;
    }

    /** Low-level writer for little-endian 16-bit PCM WAVE data. */
    protected static class WaveFileWriter {
        private static final int BYTES_PER_SAMPLE = 2;
        private static final int BITS_PER_SAMPLE = BYTES_PER_SAMPLE * 8;
        /** Bytes counted by the RIFF chunk size besides the sample data: "WAVE" + fmt chunk + data chunk header. */
        private static final int RIFF_HEADER_OVERHEAD = 36;
        /** RIFF sizes are unsigned 32-bit, so the data chunk cannot exceed this. */
        private static final long MAX_DATA_SIZE = 0xFFFFFFFFL - RIFF_HEADER_OVERHEAD;
        private static final int COPY_BUFFER_SIZE = 16384;

        /** Writes {@code audioData} as a complete WAVE file (header + samples) to {@code fileUri}. */
        void rawToWave(@NonNull final short[] audioData, final int sampleRate, final int channels, ContentResolver contentResolver, Uri fileUri) throws IOException {
            final byte[] bytes = get16BitPcm(audioData);
            try (final OutputStream output = requireOpened(contentResolver.openOutputStream(fileUri), fileUri)) {
                writeWaveHeader(bytes.length, sampleRate, channels, output);
                output.write(bytes);
            }
        }

        /** Appends {@code audioData} as raw PCM (no header) to {@code fileUri}. */
        void rawToHeadlessWave(final short[] audioData, ContentResolver contentResolver, Uri fileUri) throws IOException {
            final byte[] bytes = get16BitPcm(audioData);
            // "wa" opens for writing in append mode, so successive chunks accumulate.
            try (final OutputStream output = requireOpened(contentResolver.openOutputStream(fileUri, "wa"), fileUri)) {
                output.write(bytes);
            }
        }

        /**
         * Copies the raw PCM in {@code tmpFileUri} behind a WAVE header into {@code finalFileUri},
         * then deletes the temporary file. The temporary file is kept if copying fails.
         */
        void copyRawDataToWave(final Uri tmpFileUri, final Uri finalFileUri, final int sampleRate, final int channels, final ContentResolver contentResolver, final Context context) throws IOException {
            final long dataSize = dataSizeOf(tmpFileUri, contentResolver);
            // Open the source first so a missing temporary file does not clobber the target.
            try (final InputStream input = requireOpened(contentResolver.openInputStream(tmpFileUri), tmpFileUri);
                 final OutputStream output = requireOpened(contentResolver.openOutputStream(finalFileUri), finalFileUri)) {
                writeWaveHeader(dataSize, sampleRate, channels, output);
                final byte[] buffer = new byte[COPY_BUFFER_SIZE];
                int read;
                while ((read = input.read(buffer)) != -1) {
                    output.write(buffer, 0, read);
                }
                output.flush();
            }
            deleteTemporaryFile(tmpFileUri, contentResolver, context);
        }

        /**
         * Returns the byte size of {@code uri}. Falls back to counting bytes when
         * {@link ParcelFileDescriptor#getStatSize()} reports -1 (size unknown).
         */
        private static long dataSizeOf(final Uri uri, final ContentResolver contentResolver) throws IOException {
            try (final ParcelFileDescriptor descriptor = requireOpened(contentResolver.openFileDescriptor(uri, "r"), uri)) {
                final long statSize = descriptor.getStatSize();
                if (statSize >= 0) {
                    return statSize;
                }
            }
            try (final InputStream input = requireOpened(contentResolver.openInputStream(uri), uri)) {
                long total = 0;
                final byte[] buffer = new byte[COPY_BUFFER_SIZE];
                int read;
                while ((read = input.read(buffer)) != -1) {
                    total += read;
                }
                return total;
            }
        }

        /**
         * Best-effort removal of the temporary file. Uses the SAF {@link DocumentFile} first; URIs it
         * rejects with {@link IllegalArgumentException} are deleted through the content resolver.
         * The result of {@code delete()} is ignored.
         */
        private static void deleteTemporaryFile(final Uri tmpFileUri, final ContentResolver contentResolver, final Context context) {
            try {
                final DocumentFile file = DocumentFile.fromTreeUri(context, tmpFileUri);
                Objects.requireNonNull(file).delete();
            } catch (IllegalArgumentException e) {
                contentResolver.delete(tmpFileUri, null, null);
            }
        }

        /** Throws an {@link IOException} when a resolver returned {@code null} instead of a stream/descriptor. */
        @NonNull
        private static <T> T requireOpened(final T resource, final Uri uri) throws IOException {
            if (resource == null) {
                throw new IOException("Could not open " + uri);
            }
            return resource;
        }

        /** Encodes samples as little-endian 16-bit PCM bytes. */
        private static byte[] get16BitPcm(final short[] samples) {
            final byte[] pcm = new byte[samples.length * BYTES_PER_SAMPLE];
            for (int i = 0; i < samples.length; i++) {
                pcm[2 * i] = (byte) samples[i];             // low byte
                pcm[2 * i + 1] = (byte) (samples[i] >> 8);  // high byte
            }
            return pcm;
        }

        /**
         * Writes the 44-byte canonical PCM WAVE header.
         * See https://web.archive.org/web/20100325183246/http://www-mmsp.ece.mcgill.ca/Documents/AudioFormats/WAVE/WAVE.html
         *
         * @param dataSize   number of sample bytes that follow the header
         * @param sampleRate sample rate in Hz
         * @param channels   number of interleaved channels
         * @throws IOException if {@code dataSize} does not fit a RIFF chunk or writing fails
         */
        private static void writeWaveHeader(final long dataSize, final int sampleRate, final int channels, final OutputStream output) throws IOException {
            if (dataSize < 0 || dataSize > MAX_DATA_SIZE) {
                throw new IOException("Audio data size out of WAVE range: " + dataSize);
            }
            final int blockAlign = channels * BYTES_PER_SAMPLE; // bytes per frame across all channels
            writeString(output, "RIFF"); // chunk id
            // Sizes are unsigned 32-bit; writeInt emits the low four bytes of the (possibly negative) int.
            writeInt(output, (int) (RIFF_HEADER_OVERHEAD + dataSize)); // chunk size
            writeString(output, "WAVE"); // format
            writeString(output, "fmt "); // subchunk 1 id
            writeInt(output, 16); // subchunk 1 size
            writeShort(output, (short) 1); // audio format (1 = PCM)
            writeShort(output, (short) channels); // number of channels
            writeInt(output, sampleRate); // sample rate
            writeInt(output, sampleRate * blockAlign); // byte rate
            writeShort(output, (short) blockAlign); // block align
            writeShort(output, (short) BITS_PER_SAMPLE); // bits per sample
            writeString(output, "data"); // subchunk 2 id
            writeInt(output, (int) dataSize); // subchunk 2 size
        }

        /** Writes each (ASCII) character of {@code value} as a single byte. */
        private static void writeString(final OutputStream output, final String value) throws IOException {
            for (int i = 0; i < value.length(); i++) {
                output.write(value.charAt(i));
            }
        }

        /** Writes {@code value} as four little-endian bytes. */
        private static void writeInt(final OutputStream output, final int value) throws IOException {
            output.write(value);
            output.write(value >> 8);
            output.write(value >> 16);
            output.write(value >> 24);
        }

        /** Writes {@code value} as two little-endian bytes. */
        private static void writeShort(final OutputStream output, final short value) throws IOException {
            output.write(value);
            output.write(value >> 8);
        }
    }
}
