@file:OptIn(ExperimentalCoroutinesApi::class)

package net.damschen.swatchit.test.ui.viewmodels

import kotlinx.coroutines.ExperimentalCoroutinesApi
import kotlinx.coroutines.launch
import kotlinx.coroutines.test.UnconfinedTestDispatcher
import kotlinx.coroutines.test.advanceUntilIdle
import kotlinx.coroutines.test.runTest
import net.damschen.swatchit.domain.aggregates.swatch.Gauge
import net.damschen.swatchit.domain.aggregates.swatch.GaugeCount
import net.damschen.swatchit.domain.aggregates.swatch.GaugeSize
import net.damschen.swatchit.domain.aggregates.swatch.Measurement
import net.damschen.swatchit.shared.testhelpers.testdata.SwatchTestData
import net.damschen.swatchit.shared.testhelpers.MainDispatcherRule
import net.damschen.swatchit.test.testHelpers.database.FakeRepo
import net.damschen.swatchit.ui.models.GaugeCalculationState
import net.damschen.swatchit.ui.models.GaugeState
import net.damschen.swatchit.ui.models.LoadState
import net.damschen.swatchit.ui.models.MeasurementFormState
import net.damschen.swatchit.ui.models.MeasurementListItem
import net.damschen.swatchit.ui.models.MeasurementsState
import net.damschen.swatchit.ui.models.ValidatedInput
import net.damschen.swatchit.ui.viewmodels.MeasurementViewModel
import org.junit.Assert.assertEquals
import org.junit.Assert.assertFalse
import org.junit.Assert.assertNotNull
import org.junit.Assert.assertTrue
import org.junit.Before
import org.junit.Rule
import org.junit.Test

/**
 * Unit tests for [MeasurementViewModel], run against an in-memory [FakeRepo].
 *
 * Tests that depend on coroutine work launched by the view model (loading the swatch, saving,
 * deleting) run inside `runTest` and call `advanceUntilIdle()` before asserting, so results do not
 * depend on whether [MainDispatcherRule] executes eagerly.
 */
class MeasurementViewModelTests {
    @get:Rule
    val mainDispatcherRule = MainDispatcherRule()

    // Re-created in initSut() before every test so fake flags and recorded updates never leak between tests.
    private lateinit var repo: FakeRepo
    private lateinit var sut: MeasurementViewModel

    // One rows and one stitches measurement; the calculateGauge/saveGauge tests below assert the gauge derived from them.
    private val rowsAndStitchesMeasurements: List<Measurement> = listOf(
        Measurement.Rows(GaugeCount(34), GaugeSize(10.0)),
        Measurement.Stitches(GaugeCount(34), GaugeSize(10.0)),
    )

    @Before
    fun initSut() {
        repo = FakeRepo()
        sut = MeasurementViewModel(repo.defaultId, repo)
    }

    /**
     * Makes the repo return the default swatch with its measurements replaced by [measurements],
     * then re-creates [sut] so it loads that swatch.
     *
     * Callers must advance the test scheduler (e.g. `advanceUntilIdle()`) before relying on the
     * loaded state.
     */
    private fun recreateSutWithMeasurements(measurements: List<Measurement>) {
        val swatch = requireNotNull(repo.swatchToReturn()) { "FakeRepo should provide a default swatch" }
            .withNewMeasurements(measurements)
        repo.swatchToReturn = { swatch }
        sut = MeasurementViewModel(repo.defaultId, repo)
    }

    @Test
    fun onCountChanged_validCount_updatesState() {
        sut.onCountChanged("13")

        assertEquals("13", sut.measurementFormState.value.count.value)
        assertTrue(sut.measurementFormState.value.count is ValidatedInput.Valid)
    }

    @Test
    fun onCountChanged_invalidCount_updatesState() {
        sut.onCountChanged("abc")

        assertEquals("abc", sut.measurementFormState.value.count.value)
        assertFalse(sut.measurementFormState.value.count is ValidatedInput.Valid)
    }

    @Test
    fun onSizeChanged_validSize_updatesState() {
        sut.onSizeChanged("13")

        assertEquals("13", sut.measurementFormState.value.size.value)
        assertTrue(sut.measurementFormState.value.size is ValidatedInput.Valid)
    }

    @Test
    fun onSizeChanged_invalidSize_updatesState() {
        sut.onSizeChanged("-1")

        assertEquals("-1", sut.measurementFormState.value.size.value)
        assertFalse(sut.measurementFormState.value.size is ValidatedInput.Valid)
    }

    @Test
    fun onCountTypeChanged_toRows_updatesState() {
        sut.onTypeChanged(false)

        assertTrue(sut.measurementFormState.value is MeasurementFormState.Rows)
    }

    @Test
    fun onCountTypeChanged_toStitches_updatesState() {
        sut.onTypeChanged(true)

        assertTrue(sut.measurementFormState.value is MeasurementFormState.Stitches)
    }

    @Test
    fun loadSwatch_existingId_loadsList() = runTest {
        advanceUntilIdle()
        assertTrue(sut.loadState.value is LoadState.Success)
        assertEquals(
            MeasurementsState(
                listOf(
                    MeasurementListItem.Rows(
                        repo.defaultMeasurement.count.value,
                        repo.defaultMeasurement.size.value,
                    )
                )
            ), sut.measurementsState.value
        )
    }

    @Test
    fun loadSwatch_nonExistingId_setsLoadStateToNotFound() = runTest {
        repo.returnNull = true
        sut = MeasurementViewModel(13, repo)
        advanceUntilIdle()

        assertTrue(sut.loadState.value is LoadState.NotFound)
        assertEquals(MeasurementsState(emptyList()), sut.measurementsState.value)
    }

    @Test
    fun loadSwatch_repoReturnsError_setsLoadStateToError() = runTest {
        repo.returnError = true
        sut = MeasurementViewModel(repo.defaultId, repo)
        advanceUntilIdle()

        assertTrue(sut.loadState.value is LoadState.Error)
        assertEquals(MeasurementsState(emptyList()), sut.measurementsState.value)
    }

    @Test
    fun addMeasurement_swatchWithGauge_callsRepoWithUnchangedGauge() = runTest {
        val initialSwatch = requireNotNull(SwatchTestData.from(repo.swatchToReturn())) {
            "FakeRepo should provide a default swatch"
        }

        sut.onCountChanged("13")
        sut.onSizeChanged("20")
        sut.onTypeChanged(true)
        sut.addMeasurement()
        advanceUntilIdle()

        // The new measurement is appended; everything else (including the gauge) must stay as loaded.
        val expected = initialSwatch.copy(
            measurements = initialSwatch.measurements + Measurement.Stitches(GaugeCount(13), GaugeSize(20.0))
        )

        assertEquals(expected, SwatchTestData.from(repo.updatedSwatch))
    }

    @Test
    fun addMeasurement_repoReturnsSuccess_emitSuccess() = runTest {
        sut.onCountChanged("13")
        sut.onSizeChanged("20")
        sut.onTypeChanged(true)

        val values = mutableListOf<Boolean>()
        backgroundScope.launch(UnconfinedTestDispatcher(testScheduler)) {
            sut.savedSuccessfully.collect {
                values.add(it)
            }
        }

        sut.addMeasurement()
        advanceUntilIdle()

        assertTrue(values.last())
    }

    @Test
    fun init_uninitializedInput_errorIdsSetInState() = runTest {
        assertNotNull((sut.measurementFormState.value.count as ValidatedInput.Invalid).errorMessageId)
        assertNotNull((sut.measurementFormState.value.size as ValidatedInput.Invalid).errorMessageId)
    }

    @Test
    fun addMeasurement_invalidInput_setsError() = runTest {
        sut.onCountChanged("-13")
        sut.onSizeChanged("20")

        val values = mutableListOf<Boolean>()
        backgroundScope.launch(UnconfinedTestDispatcher(testScheduler)) {
            sut.savedSuccessfully.collect {
                values.add(it)
            }
        }

        sut.addMeasurement()
        advanceUntilIdle()

        assertFalse(values.last())
    }

    @Test
    fun addMeasurement_repoReturnsSuccess_MeasurementAddedToStateList() = runTest {
        sut.onCountChanged("13")
        sut.onSizeChanged("20")
        sut.onTypeChanged(true)
        sut.addMeasurement()
        advanceUntilIdle()

        assertEquals(
            MeasurementListItem.Stitches(13, 20.0),
            sut.measurementsState.value.items.last()
        )
    }

    @Test
    fun addMeasurement_twoConsecutiveMeasurements_MeasurementIsNotOverwritten() = runTest {
        sut.onCountChanged("13")
        sut.onSizeChanged("20")
        sut.onTypeChanged(true)
        sut.addMeasurement()
        sut.onCountChanged("13")
        sut.onSizeChanged("20")
        sut.onTypeChanged(true)
        sut.addMeasurement()
        advanceUntilIdle()

        val expectedMeasurements = listOf(
            repo.defaultMeasurement,
            Measurement.Stitches(GaugeCount(13), GaugeSize(20.0)),
            Measurement.Stitches(GaugeCount(13), GaugeSize(20.0))
        )

        val updatedSwatch = requireNotNull(repo.updatedSwatch) { "repo was never asked to update the swatch" }
        assertEquals(expectedMeasurements, updatedSwatch.measurements)
    }

    @Test
    fun addMeasurement_repoReturnsErrorDuringUpdate_setsErrorState() = runTest {
        sut.onCountChanged("13")
        sut.onSizeChanged("20")
        sut.onTypeChanged(true)
        val values = mutableListOf<Boolean>()
        backgroundScope.launch(UnconfinedTestDispatcher(testScheduler)) {
            sut.savedSuccessfully.collect {
                values.add(it)
            }
        }

        repo.returnErrorDuringUpdate = true

        sut.addMeasurement()
        advanceUntilIdle()

        assertFalse(values.last())
    }

    @Test
    fun deleteMeasurementAt_existingMeasurement_callsRepo() = runTest {
        val initialSwatch = requireNotNull(SwatchTestData.from(repo.swatchToReturn())) {
            "FakeRepo should provide a default swatch"
        }

        sut.deleteMeasurementAt(0)
        advanceUntilIdle()

        // Deleting index 0 removes the first measurement, so drop from the front.
        val expected = initialSwatch.copy(measurements = initialSwatch.measurements.drop(1))

        assertEquals(expected, SwatchTestData.from(repo.updatedSwatch))
    }

    @Test
    fun deleteMeasurementAt_repoReturnsSuccess_emitSuccess() = runTest {
        val values = mutableListOf<Boolean>()
        backgroundScope.launch(UnconfinedTestDispatcher(testScheduler)) {
            sut.deletedSuccessfully.collect {
                values.add(it)
            }
        }

        sut.deleteMeasurementAt(0)
        advanceUntilIdle()

        assertTrue(values.last())
    }

    @Test
    fun deleteMeasurementAt_repoReturnsSuccess_MeasurementDeletedFromStateList() = runTest {
        sut.deleteMeasurementAt(0)
        // Drain pending coroutines so the assertion does not depend on an eager main dispatcher.
        advanceUntilIdle()

        assertEquals(0, sut.measurementsState.value.items.count())
    }

    @Test
    fun deleteMeasurementAt_repoReturnsErrorDuringUpdate_setsErrorState() = runTest {
        val values = mutableListOf<Boolean>()
        backgroundScope.launch(UnconfinedTestDispatcher(testScheduler)) {
            sut.deletedSuccessfully.collect {
                values.add(it)
            }
        }

        repo.returnErrorDuringUpdate = true
        sut.deleteMeasurementAt(0)
        advanceUntilIdle()

        assertFalse(values.last())
    }

    @Test
    fun deleteMeasurementAt_nonExistingMeasurement_setsErrorState() = runTest {
        val values = mutableListOf<Boolean>()
        backgroundScope.launch(UnconfinedTestDispatcher(testScheduler)) {
            sut.deletedSuccessfully.collect {
                values.add(it)
            }
        }

        sut.deleteMeasurementAt(3)
        advanceUntilIdle()

        assertFalse(values.last())
    }

    @Test
    fun deleteMeasurementAt_twoEqualItems_oneMeasurementDeletedFromStateList() = runTest {
        val swatch = requireNotNull(repo.swatchToReturn()) { "FakeRepo should provide a default swatch" }
            .withMeasurement(repo.defaultMeasurement)
        repo.swatchToReturn = { swatch }
        sut = MeasurementViewModel(repo.defaultId, repo)
        advanceUntilIdle()

        sut.deleteMeasurementAt(0)
        advanceUntilIdle()

        assertEquals(1, sut.measurementsState.value.items.count())
    }

    @Test
    fun calculateGauge_validMeasurements_updatesGaugeState() = runTest {
        recreateSutWithMeasurements(rowsAndStitchesMeasurements)
        advanceUntilIdle()

        sut.calculateGauge()
        assertEquals(
            GaugeCalculationState.Valid(
                GaugeState(34, 34, 10.0)
            ), sut.gaugeCalculationState.value
        )
    }

    @Test
    fun calculateGauge_noMeasurements_setsErrorIdInGaugeState() = runTest {
        // Explicitly clear the measurements; the default swatch still contains one Rows measurement.
        recreateSutWithMeasurements(emptyList())
        advanceUntilIdle()

        sut.calculateGauge()
        assertNotNull((sut.gaugeCalculationState.value as GaugeCalculationState.Invalid).errorMessageId)
    }

    @Test
    fun saveGaugeCalculationState_gaugeHasBeenCalculated_callsRepo() = runTest {
        recreateSutWithMeasurements(rowsAndStitchesMeasurements)
        advanceUntilIdle()

        sut.calculateGauge()
        sut.saveGaugeCalculationState()
        advanceUntilIdle()

        val updatedSwatch = requireNotNull(repo.updatedSwatch) { "repo was never asked to update the swatch" }
        assertEquals(
            Gauge(GaugeCount(34), GaugeCount(34), GaugeSize(10.0)),
            updatedSwatch.gauge
        )
    }

    @Test
    fun saveGaugeCalculationState_gaugeHasBeenCalculated_emitsSuccess() = runTest {
        recreateSutWithMeasurements(rowsAndStitchesMeasurements)
        advanceUntilIdle()

        sut.calculateGauge()

        val values = mutableListOf<Boolean>()
        backgroundScope.launch(UnconfinedTestDispatcher(testScheduler)) {
            sut.gaugeSuccessfullySaved.collect {
                values.add(it)
            }
        }

        sut.saveGaugeCalculationState()
        advanceUntilIdle()

        assertTrue(values.last())
    }

    @Test
    fun saveGaugeCalculationState_gaugeHasNotBeenCalculated_emitsError() = runTest {
        recreateSutWithMeasurements(rowsAndStitchesMeasurements)
        // Let the swatch load so the only reason for failure is the missing calculateGauge() call.
        advanceUntilIdle()

        val values = mutableListOf<Boolean>()
        backgroundScope.launch(UnconfinedTestDispatcher(testScheduler)) {
            sut.gaugeSuccessfullySaved.collect {
                values.add(it)
            }
        }

        sut.saveGaugeCalculationState()
        advanceUntilIdle()

        assertFalse(values.last())
    }

    @Test
    fun saveGaugeCalculationState_repoReturnsError_emitsError() = runTest {
        recreateSutWithMeasurements(rowsAndStitchesMeasurements)
        // Without this, calculateGauge() could run before the swatch is loaded and the test
        // would pass because the gauge is invalid, not because the repo update failed.
        advanceUntilIdle()

        sut.calculateGauge()
        repo.returnErrorDuringUpdate = true

        val values = mutableListOf<Boolean>()
        backgroundScope.launch(UnconfinedTestDispatcher(testScheduler)) {
            sut.gaugeSuccessfullySaved.collect {
                values.add(it)
            }
        }

        sut.saveGaugeCalculationState()
        advanceUntilIdle()

        assertFalse(values.last())
    }
}