use bevy::{
    core_pipeline::{
        core_3d::graph::{Core3d, Node3d},
        fullscreen_vertex_shader::fullscreen_shader_vertex_state,
    },
    ecs::query::QueryItem,
    prelude::*,
    render::{
        extract_component::{
            ComponentUniforms, DynamicUniformIndex, ExtractComponent, ExtractComponentPlugin,
            UniformComponentPlugin,
        },
        render_graph::{
            NodeRunError, RenderGraphApp, RenderGraphContext, RenderLabel, ViewNode, ViewNodeRunner,
        },
        render_resource::{
            binding_types::{sampler, texture_2d, uniform_buffer},
            *,
        },
        renderer::{RenderContext, RenderDevice},
        view::ViewTarget,
        RenderApp,
    },
};

use crate::define_asset_collection;

/// Asset path of the WGSL shader loaded by `BlurPipeline`; it must expose the `vertical` and
/// `horizontal` fragment entry points that the blur pipelines are created with.
const SHADER_ASSET_PATH: &str = "shaders/gaussian_blur.wgsl";

// Presumably registers a `BlurShader` asset collection that loads the same shader (the macro is
// defined elsewhere in the crate — confirm its semantics there).
// NOTE(review): the path below duplicates `SHADER_ASSET_PATH` as a literal; keep the two in sync.
// Whether the macro accepts a const instead of a literal has not been verified.
define_asset_collection!(
    BlurShader,
    !shader : Shader = "shaders/gaussian_blur.wgsl",
);

/// Adds a two-pass (vertical, then horizontal) blur to the 3D render graph.
///
/// The blur node runs between `Node3d::Tonemapping` and `Node3d::EndMainPassPostProcessing`.
/// It only processes views whose camera carries a [`BlurSettings`] component.
pub struct BlurPlugin;

impl Plugin for BlurPlugin {
    fn build(&self, app: &mut App) {
        // Copy `BlurSettings` into the render world every frame...
        app.add_plugins(ExtractComponentPlugin::<BlurSettings>::default())
            // ...and upload them into a dynamic uniform buffer indexed per view.
            .add_plugins(UniformComponentPlugin::<BlurSettings>::default());

        // Without a render sub-app (e.g. headless setups) there is no graph to hook into.
        if let Some(render_app) = app.get_sub_app_mut(RenderApp) {
            render_app.add_render_graph_node::<ViewNodeRunner<BlurNode>>(Core3d, BlurLabel);
            // Slot the blur after tonemapping so it operates on the final, tonemapped colors.
            render_app.add_render_graph_edges(
                Core3d,
                (
                    Node3d::Tonemapping,
                    BlurLabel,
                    Node3d::EndMainPassPostProcessing,
                ),
            );
        }
    }

    fn finish(&self, app: &mut App) {
        // `BlurPipeline::from_world` reads the `RenderDevice`. It is therefore created in
        // `finish`, after all plugins (including the renderer) have been built, rather than in
        // `build`.
        if let Some(render_app) = app.get_sub_app_mut(RenderApp) {
            render_app.init_resource::<BlurPipeline>();
        }
    }
}

/// Identifies the blur node inside the `Core3d` render graph.
#[derive(RenderLabel, Clone, Debug, PartialEq, Eq, Hash)]
struct BlurLabel;

/// The blur's render-graph node.
///
/// It holds no state of its own: the pipelines, sampler and settings are all read from the
/// render world each time it runs.
#[derive(Default)]
struct BlurNode;

// The ViewNode trait is required by the ViewNodeRunner
impl ViewNode for BlurNode {
    type ViewQuery = (
        &'static ViewTarget,
        // Requiring the settings component limits the node to views that opted into the blur.
        &'static BlurSettings,
        // As there could be multiple post processing components sent to the GPU (one per camera),
        // we need to get the index of the one that is associated with the current view.
        &'static DynamicUniformIndex<BlurSettings>,
    );

    /// Blurs the view's main texture with two fullscreen passes: vertical, then horizontal.
    ///
    /// Returns `Ok(())` without drawing anything in two cases: either pass's pipeline has not
    /// finished compiling yet, or the settings uniform buffer has not been uploaded. A frame is
    /// therefore either fully blurred or left untouched, never blurred along only one axis.
    fn run(
        &self,
        _graph: &mut RenderGraphContext,
        render_context: &mut RenderContext,
        (view_target, _post_process_settings, settings_index): QueryItem<Self::ViewQuery>,
        world: &World,
    ) -> Result<(), NodeRunError> {
        // Global, view-independent data created once in `BlurPipeline::from_world`.
        let post_process_pipeline = world.resource::<BlurPipeline>();

        // The pipeline cache holds every previously created pipeline, so pipelines are not
        // rebuilt (and shaders recompiled) every frame.
        let pipeline_cache = world.resource::<PipelineCache>();

        // Pick the pipeline variants whose color target matches this view's texture format.
        let pipeline_ids = if view_target.is_hdr() {
            [
                post_process_pipeline.vertical_pipeline_hdr_id,
                post_process_pipeline.horizontal_pipeline_hdr_id,
            ]
        } else {
            [
                post_process_pipeline.vertical_pipeline_id,
                post_process_pipeline.horizontal_pipeline_id,
            ]
        };

        // Resolve BOTH passes before touching the view target. Each pass calls
        // `post_process_write`, which flips the view's main texture. Bailing out between passes
        // would leave a one-directional blur on screen while the other pipeline is still
        // compiling.
        let [Some(vertical_pass), Some(horizontal_pass)] =
            pipeline_ids.map(|id| pipeline_cache.get_render_pipeline(id))
        else {
            return Ok(());
        };

        // The settings uniform buffer is shared by all views; `settings_index` selects ours.
        let settings_uniforms = world.resource::<ComponentUniforms<BlurSettings>>();
        let Some(settings_binding) = settings_uniforms.uniforms().binding() else {
            return Ok(());
        };

        for pipeline in [vertical_pass, horizontal_pass] {
            // Starts a new "post process write": `source` is the current main texture, and we
            // _must_ write into `destination`. This call flips the [`ViewTarget`]'s main texture
            // to `destination`, so failing to write it would lose the image. It also means the
            // second pass reads what the first pass wrote.
            let post_process = view_target.post_process_write();

            // The bind group is rebuilt per pass, because `source` alternates with each
            // `post_process_write()` and is only known during node execution. The entries must
            // match the layout created in `BlurPipeline::from_world`.
            let bind_group = render_context.render_device().create_bind_group(
                "blur_bind_group",
                &post_process_pipeline.layout,
                &BindGroupEntries::sequential((
                    post_process.source,
                    &post_process_pipeline.sampler,
                    settings_binding.clone(),
                )),
            );

            let mut render_pass = render_context.begin_tracked_render_pass(RenderPassDescriptor {
                label: Some("blur_pass"),
                color_attachments: &[Some(RenderPassColorAttachment {
                    // Write into the post-process destination so the flip above is honored.
                    view: post_process.destination,
                    resolve_target: None,
                    ops: Operations::default(),
                })],
                depth_stencil_attachment: None,
                timestamp_writes: None,
                occlusion_query_set: None,
            });

            // Draw a fullscreen triangle. The dynamic offset selects this view's settings when
            // several cameras have uploaded `BlurSettings`.
            render_pass.set_render_pipeline(pipeline);
            render_pass.set_bind_group(0, &bind_group, &[settings_index.index()]);
            render_pass.draw(0..3, 0..1);
        }

        Ok(())
    }
}

/// Render-world resource with the GPU objects shared by every blur pass and every view.
///
/// It is built once through its [`FromWorld`] implementation. The pipeline ids refer to entries
/// in the `PipelineCache` and may not resolve to a usable pipeline until compilation finishes.
#[derive(Resource)]
struct BlurPipeline {
    /// Layout of bind group 0: screen texture, filtering sampler, dynamic `BlurSettings` uniform.
    layout: BindGroupLayout,
    /// Sampler for the screen texture; it does not vary per view or per frame.
    sampler: Sampler,

    /// Vertical pass rendering to the default (non-HDR) texture format.
    vertical_pipeline_id: CachedRenderPipelineId,
    /// Vertical pass rendering to the HDR texture format.
    vertical_pipeline_hdr_id: CachedRenderPipelineId,

    /// Horizontal pass rendering to the default (non-HDR) texture format.
    horizontal_pipeline_id: CachedRenderPipelineId,
    /// Horizontal pass rendering to the HDR texture format.
    horizontal_pipeline_hdr_id: CachedRenderPipelineId,
}

impl FromWorld for BlurPipeline {
    fn from_world(world: &mut World) -> Self {
        let render_device = world.resource::<RenderDevice>();

        // We need to define the bind group layout used for our pipeline
        let layout = render_device.create_bind_group_layout(
            "blur_group_layout",
            &BindGroupLayoutEntries::sequential(
                // The layout entries will only be visible in the fragment stage
                ShaderStages::FRAGMENT,
                (
                    // The screen texture
                    texture_2d(TextureSampleType::Float { filterable: true }),
                    // The sampler that will be used to sample the screen texture
                    sampler(SamplerBindingType::Filtering),
                    // The settings uniform that will control the effect
                    uniform_buffer::<BlurSettings>(true),
                ),
            ),
        );

        // We can create the sampler here since it won't change at runtime and doesn't depend on the view
        let sampler = render_device.create_sampler(&SamplerDescriptor::default());

        // Get the shader handle
        let shader = world.load_asset(SHADER_ASSET_PATH);

        let mut pipeline: [CachedRenderPipelineId; 4] = [CachedRenderPipelineId::INVALID; 4];

        for (id, name, hdr) in [
            (0, "vertical", false),
            (1, "horizontal", false),
            (2, "vertical", true),
            (3, "horizontal", true),
        ] {
            pipeline[id] = world
                .resource_mut::<PipelineCache>()
                // This will add the pipeline to the cache and queue its creation
                .queue_render_pipeline(RenderPipelineDescriptor {
                    label: Some(format!("blur_{name}_pipeline").into()),
                    layout: vec![layout.clone()],
                    // This will setup a fullscreen triangle for the vertex state
                    vertex: fullscreen_shader_vertex_state(),
                    fragment: Some(FragmentState {
                        shader: shader.clone(),
                        shader_defs: vec![],
                        // Make sure this matches the entry point of your shader.
                        // It can be anything as long as it matches here and in the shader.
                        entry_point: name.into(),
                        targets: vec![Some(ColorTargetState {
                            format: if hdr {
                                ViewTarget::TEXTURE_FORMAT_HDR
                            } else {
                                TextureFormat::bevy_default()
                            },
                            blend: None,
                            write_mask: ColorWrites::ALL,
                        })],
                    }),
                    // All of the following properties are not important for this effect so just use the default values.
                    // This struct doesn't have the Default trait implemented because not all fields can have a default value.
                    primitive: PrimitiveState::default(),
                    depth_stencil: None,
                    multisample: MultisampleState::default(),
                    push_constant_ranges: vec![],
                    zero_initialize_workgroup_memory: false,
                });
        }

        Self {
            layout,
            sampler,
            vertical_pipeline_id: pipeline[0],
            horizontal_pipeline_id: pipeline[1],
            vertical_pipeline_hdr_id: pipeline[2],
            horizontal_pipeline_hdr_id: pipeline[3],
        }
    }
}

/// Per-camera blur parameters.
///
/// Extracted into the render world and uploaded as a dynamic uniform. Only cameras carrying this
/// component are processed by the blur node.
#[derive(Component, ExtractComponent, ShaderType, Clone, Copy, Default)]
pub struct BlurSettings {
    /// Forwarded unchanged to the blur shader. How it maps to blur strength is defined in
    /// `gaussian_blur.wgsl` (presumably a radius/spread — confirm in the shader).
    circle_of_confusion: f32,
    // WebGL2 requires uniform structs to be 16-byte aligned, so the lone `f32` is padded on wasm.
    #[cfg(target_family = "wasm")]
    _webgl2_padding: Vec3,
}

impl BlurSettings {
    /// Creates settings carrying `circle_of_confusion`, with any platform padding zeroed.
    pub fn new(circle_of_confusion: f32) -> Self {
        BlurSettings {
            circle_of_confusion,
            #[cfg(target_family = "wasm")]
            _webgl2_padding: Vec3::default(),
        }
    }
}
