package net.damschen.swatchit.integrationTest.infrastrcuture.repository

import dagger.hilt.android.testing.HiltAndroidRule
import dagger.hilt.android.testing.HiltAndroidTest
import dagger.hilt.android.testing.HiltTestApplication
import kotlinx.coroutines.flow.first
import kotlinx.coroutines.test.runTest
import net.damschen.swatchit.domain.aggregates.swatch.EpochMillis
import net.damschen.swatchit.domain.aggregates.swatch.Gauge
import net.damschen.swatchit.domain.aggregates.swatch.GaugeCount
import net.damschen.swatchit.domain.aggregates.swatch.GaugeSize
import net.damschen.swatchit.domain.aggregates.swatch.KnittingNeedleSize
import net.damschen.swatchit.domain.aggregates.swatch.Measurement
import net.damschen.swatchit.domain.aggregates.swatch.Notes
import net.damschen.swatchit.domain.aggregates.swatch.Pattern
import net.damschen.swatchit.domain.aggregates.swatch.Photo
import net.damschen.swatchit.domain.aggregates.swatch.Swatch
import net.damschen.swatchit.domain.aggregates.swatch.SwatchId
import net.damschen.swatchit.domain.aggregates.swatch.SwatchName
import net.damschen.swatchit.domain.aggregates.swatch.Yarn
import net.damschen.swatchit.domain.aggregates.swatch.YarnManufacturer
import net.damschen.swatchit.domain.aggregates.swatch.YarnName
import net.damschen.swatchit.domain.repositories.SwatchRepository
import net.damschen.swatchit.domain.resultWrappers.DatabaseResult
import net.damschen.swatchit.infrastructure.database.AppDatabase
import net.damschen.swatchit.infrastructure.database.MeasurementDao
import net.damschen.swatchit.infrastructure.database.MeasurementEntity
import net.damschen.swatchit.infrastructure.database.RoomTransactionProvider
import net.damschen.swatchit.infrastructure.database.SwatchAggregate
import net.damschen.swatchit.infrastructure.database.SwatchDao
import net.damschen.swatchit.infrastructure.database.SwatchEntity
import net.damschen.swatchit.infrastructure.repository.SqlSwatchRepository
import net.damschen.swatchit.shared.testhelpers.FakeUUIDProvider
import net.damschen.swatchit.shared.testhelpers.testdata.SwatchTestData
import org.junit.After
import org.junit.Assert.assertEquals
import org.junit.Assert.assertNull
import org.junit.Assert.assertTrue
import org.junit.Before
import org.junit.Rule
import org.junit.Test
import org.junit.runner.RunWith
import org.robolectric.RobolectricTestRunner
import org.robolectric.annotation.Config
import java.time.LocalDate
import java.time.ZoneId
import javax.inject.Inject
import javax.inject.Named

/**
 * Robolectric integration tests for [SqlSwatchRepository] running against the Hilt-provided
 * `test_db` [AppDatabase].
 *
 * Before every test the database is injected, its DAOs are wired into a fresh repository and all
 * tables are cleared; after every test the database is closed. Tests that seed rows through the
 * DAOs use the id returned by the insert instead of assuming a particular auto-generated key.
 */
@RunWith(RobolectricTestRunner::class)
@HiltAndroidTest
@Config(application = HiltTestApplication::class)
class SwatchRepositoryImplTests {
    @Inject
    @Named("test_db")
    lateinit var database: AppDatabase
    private lateinit var swatchDao: SwatchDao
    private lateinit var measurementDao: MeasurementDao
    private lateinit var repository: SwatchRepository

    // Never reassigned, so a read-only property is enough; `@get:` places the annotation on the
    // public getter, which JUnit 4 accepts as a rule provider.
    @get:Rule
    val hiltRule = HiltAndroidRule(this)

    @Before
    fun initDb() {
        hiltRule.inject()
        swatchDao = database.swatchDao()
        measurementDao = database.measurementDao()
        repository =
            SqlSwatchRepository(swatchDao, measurementDao, RoomTransactionProvider(database))

        database.clearAllTables()
    }

    @After
    fun closeDb() {
        database.close()
    }

    @Test
    fun insert_newSwatch_insertsIntoDatabase() = runTest {
        val swatchModel = createSwatch().withMeasurement(defaultMeasurement)

        val result = repository.add(swatchModel)

        assertTrue(result is DatabaseResult.Success)
        // The generated key is not read from `result`; the table was empty before this insert,
        // so the first auto-generated id is presumably 1 (assumes a new database per test —
        // confirm against the module providing "test_db").
        val swatchInDb = swatchDao.get(1)
        assertEquals(createSwatchAggregate(), swatchInDb)
    }

    @Test
    fun insert_swatchWithId_returnsError() = runTest {
        val swatchModel = createSwatch(id = 5).withMeasurement(defaultMeasurement)

        val result = repository.add(swatchModel)

        assertTrue(result is DatabaseResult.Error)
    }

    @Test
    fun get_emptyDatabase_returnsEmptyList() = runTest {
        val result = repository.getAll().first()

        assertTrue(result is DatabaseResult.Success)
        val swatches = (result as DatabaseResult.Success).data
        assertEquals(0, swatches.count())
    }

    @Test
    fun get_existingSwatch_returnsList() = runTest {
        val id = swatchDao.insert(createSwatchEntity()).toInt()
        measurementDao.insert(createMeasurementEntity(id))

        val result = repository.getAll().first()

        assertTrue(result is DatabaseResult.Success)
        val swatches = (result as DatabaseResult.Success).data
        assertEquals(1, swatches.count())
        val swatch = swatches.first()
        // Expect the id the DAO actually generated rather than a hard-coded key.
        val expected = createSwatch(id).withMeasurement(defaultMeasurement)
        assertEquals(SwatchTestData.from(expected), SwatchTestData.from(swatch))
    }

    @Test
    fun get_nonExistingSwatch_returnsNull() = runTest {
        val result = repository.get(1)

        assertTrue(result is DatabaseResult.Success)
        val swatch = (result as DatabaseResult.Success).data
        assertNull(swatch)
    }

    @Test
    fun get_existingSwatch_returnsSwatch() = runTest {
        val id = swatchDao.insert(createSwatchEntity()).toInt()
        measurementDao.insert(createMeasurementEntity(id))

        // Look up by the generated id so the test does not depend on the key sequence.
        val result = repository.get(id)

        assertTrue(result is DatabaseResult.Success)
        val swatch = (result as DatabaseResult.Success).data
        val expected = createSwatch(id).withMeasurement(defaultMeasurement)
        assertEquals(SwatchTestData.from(expected), SwatchTestData.from(swatch))
    }

    @Test
    fun update_existingSwatch_updatesSwatchTable() = runTest {
        val id = swatchDao.insert(createSwatchEntity()).toInt()
        val newPattern = Pattern.create("New Pattern")!!
        val swatch = createSwatch(id = id, pattern = newPattern)

        val result = repository.update(swatch)

        val swatchAggregate = requireNotNull(swatchDao.get(id)) { "Swatch $id missing after update" }

        assertTrue(result is DatabaseResult.Success)
        assertEquals(id, swatchAggregate.swatch.id)
        assertEquals(newPattern.value, swatchAggregate.swatch.pattern)
    }

    @Test
    fun update_addNewMeasurement_updatesMeasurementsTable() = runTest {
        val id = swatchDao.insert(createSwatchEntity()).toInt()
        val swatch = createSwatch(id = id).withMeasurement(defaultMeasurement)

        val result = repository.update(swatch)

        val measurementEntities = measurementDao.getBySwatchId(id)

        // Without this check a failed update could go unnoticed if the table happened to match.
        assertTrue(result is DatabaseResult.Success)
        assertEquals(listOf(createMeasurementEntity(id)), measurementEntities)
    }

    @Test
    fun update_swatchWithoutId_returnsErrorResult() = runTest {
        val swatch = createSwatch()

        val result = repository.update(swatch)

        assertTrue(result is DatabaseResult.Error)
    }

    @Test
    fun update_nonExistingSwatch_returnsErrorResult() = runTest {
        val swatch = createSwatch(id = 1)

        val result = repository.update(swatch)

        assertTrue(result is DatabaseResult.Error)
    }

    @Test
    fun update_deleteMeasurement_updatesMeasurementsTable() = runTest {
        val id = swatchDao.insert(createSwatchEntity()).toInt()
        measurementDao.insert(createMeasurementEntity(id))
        val loaded = repository.get(id)
        val swatch = (loaded as DatabaseResult.Success).data!!.withoutMeasurementAt(0)

        val updateResult = repository.update(swatch)

        val measurementEntities = measurementDao.getBySwatchId(id)
        assertTrue(updateResult is DatabaseResult.Success)
        assertEquals(emptyList<MeasurementEntity>(), measurementEntities)
    }

    @Test
    fun delete_existingSwatch_deletesSwatchFromDatabase() = runTest {
        val id = swatchDao.insert(createSwatchEntity()).toInt()

        repository.delete(SwatchId(id))

        val swatchesInDb = swatchDao.get().first()
        assertEquals(0, swatchesInDb.count())
    }

    @Test
    fun delete_existingSwatch_deletesMeasurementsFromDatabase() = runTest {
        val id = swatchDao.insert(createSwatchEntity()).toInt()
        measurementDao.insert(createMeasurementEntity(id))

        repository.delete(SwatchId(id))

        val measurementEntities = measurementDao.getBySwatchId(id)
        assertTrue(measurementEntities.isEmpty())
    }
}

/**
 * Builds a domain [Swatch] from the file's default test values, then applies the default gauge
 * and the default photo to it.
 *
 * @param id raw id to wrap in a [SwatchId]; when `null`, `null` is passed as the swatch id.
 * @param pattern pattern for the swatch; defaults to `defaultPattern`.
 */
private fun createSwatch(id: Int? = null, pattern: Pattern? = defaultPattern): Swatch {
    val swatchId = if (id != null) SwatchId(id) else null
    val created = Swatch.create(
        name = defaultName,
        createdAt = defaultCreatedAt,
        id = swatchId,
        needleSize = defaultNeedleSize,
        pattern = pattern,
        yarn = defaultYarn,
        notes = defaultNotes
    )
    val withGauge = created.withUpdatedGauge(defaultGauge)
    return withGauge.withUpdatedPhoto(defaultPhoto)
}

/**
 * Expected database aggregate for the default swatch: the default [SwatchEntity] paired with a
 * single default measurement row whose swatch id is [swatchId].
 */
private fun createSwatchAggregate(swatchId: Int = 1): SwatchAggregate {
    val swatchEntity = createSwatchEntity()
    val measurementEntities = listOf(createMeasurementEntity(swatchId))
    return SwatchAggregate(swatchEntity, measurementEntities)
}

// Default domain values shared by the test class and the builder functions in this file.
private val defaultName = SwatchName.create("TestName")!!
private val defaultPhoto = Photo(FakeUUIDProvider.defaultUUID)
private val defaultMeasurement = Measurement.Stitches(GaugeCount(3), GaugeSize(12.0))
private val defaultNeedleSize = KnittingNeedleSize.SIZE_2_5
private val defaultPattern = Pattern.create("Stockinette")!!
private val defaultGauge = Gauge(GaugeCount(20), GaugeCount(30), GaugeSize(10.0))
private val defaultYarn = Yarn.create(
    YarnName.create("Yarn Name"),
    YarnManufacturer.create("Yarn Manufacturer")
)!!
private val defaultNotes = Notes.create("Test notes!")!!

// Start of day 2025-02-17 in the "UTC" zone, expressed as epoch milliseconds.
private val defaultCreatedAt = EpochMillis(
    LocalDate.of(2025, 2, 17)
        .atStartOfDay(ZoneId.of("UTC"))
        .toInstant()
        .toEpochMilli()
)

/**
 * Builds the persistence-layer [SwatchEntity] that corresponds to the default domain swatch.
 * It uses the same name, creation time, pattern, notes, yarn, gauge and photo UUID as the
 * defaults, with the database needle size `SIZE_2_5`.
 */
private fun createSwatchEntity(): SwatchEntity = SwatchEntity(
    name = defaultName.value,
    createdAt = defaultCreatedAt.value,
    pattern = defaultPattern.value,
    notes = defaultNotes.value,
    needleSize = net.damschen.swatchit.infrastructure.database.KnittingNeedleSize.SIZE_2_5,
    yarnName = defaultYarn.name!!.value,
    yarnManufacturer = defaultYarn.manufacturer!!.value,
    nrOfStitches = defaultGauge.nrOfStitches.value,
    nrOfRows = defaultGauge.nrOfRows.value,
    gaugeLength = defaultGauge.size.value,
    photoUUID = FakeUUIDProvider.defaultUUID
)

/**
 * Builds the [MeasurementEntity] row for `defaultMeasurement` (a stitches measurement) attached to
 * the swatch with id [swatchId].
 */
private fun createMeasurementEntity(swatchId: Int): MeasurementEntity {
    val stitchCount = defaultMeasurement.count.value
    val measuredSize = defaultMeasurement.size.value
    return MeasurementEntity(
        net.damschen.swatchit.infrastructure.database.MeasurementType.Stitches,
        stitchCount,
        measuredSize,
        swatchId
    )
}