/*
 *     Copyright (c) 2024, NVIDIA CORPORATION.  All rights reserved.
 *
 * NVIDIA CORPORATION and its licensors retain all intellectual property
 * and proprietary rights in and to this software, related documentation
 * and any modifications thereto.  Any use, reproduction, disclosure or
 * distribution of this software and related documentation without an express
 * license agreement from NVIDIA CORPORATION is strictly prohibited.
 *
 *         THIS CODE AND INFORMATION ARE PROVIDED "AS IS" WITHOUT
 *  WARRANTY OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING BUT
 *  NOT LIMITED TO THE IMPLIED WARRANTIES OF MERCHANTABILITY AND/OR
 *  FITNESS FOR A PARTICULAR PURPOSE.
 */

#ifndef NVHPC_CURAND_RUNTIME_H
#define NVHPC_CURAND_RUNTIME_H

#include "curand_kernel.h"

#ifndef __NVHPC_CURAND_DEVICE
#define __NVHPC_CURAND_DEVICE __device__ static __inline__
#endif

/* XORWOW */

/*
 * Initialize the XORWOW generator state stored at h.
 * seed, seq and offset are converted to unsigned long long and handed to
 * curand_init(); h is treated as a curandStateXORWOW_t.
 */
__NVHPC_CURAND_DEVICE void
__pgicudalib_curandInitXORWOW(long long seed, long long seq, long long offset,
                              signed char *h)
{
    curandStateXORWOW_t *state = (curandStateXORWOW_t *)h;
    const unsigned long long useed = (unsigned long long)seed;
    const unsigned long long useq = (unsigned long long)seq;
    const unsigned long long uoffset = (unsigned long long)offset;

    curand_init(useed, useq, uoffset, state);
}

/*
 * Return the next curand() output of the XORWOW state at h, converted to
 * int to match this wrapper's return type.
 */
__NVHPC_CURAND_DEVICE int
__pgicudalib_curandGetXORWOW(signed char *h)
{
    curandStateXORWOW_t *state = (curandStateXORWOW_t *)h;

    return (int)curand(state);
}

/* Return the curand_uniform() result for the XORWOW state stored at h. */
__NVHPC_CURAND_DEVICE float
__pgicudalib_curandUniformXORWOW(signed char *h)
{
    curandStateXORWOW_t *state = (curandStateXORWOW_t *)h;
    const float sample = curand_uniform(state);

    return sample;
}

/* Return the curand_uniform_double() result for the XORWOW state at h. */
__NVHPC_CURAND_DEVICE double
__pgicudalib_curandUniformDoubleXORWOW(signed char *h)
{
    curandStateXORWOW_t *state = (curandStateXORWOW_t *)h;
    const double sample = curand_uniform_double(state);

    return sample;
}

/* Return the curand_normal() result for the XORWOW state stored at h. */
__NVHPC_CURAND_DEVICE float
__pgicudalib_curandNormalXORWOW(signed char *h)
{
    curandStateXORWOW_t *state = (curandStateXORWOW_t *)h;
    const float sample = curand_normal(state);

    return sample;
}

/*
 * Call curand_normal2() on the XORWOW state at h and store the resulting
 * pair as two consecutive floats starting at dst (x first, then y).
 */
__NVHPC_CURAND_DEVICE void
__pgicudalib_curandnormal2xorwow_(signed char *dst, signed char *h)
{
    float *out = (float *)dst;
    const float2 pair = curand_normal2((curandStateXORWOW_t *)h);

    out[0] = pair.x;
    out[1] = pair.y;
}

/* Return the float2 produced by curand_normal2() for the XORWOW state at h. */
__NVHPC_CURAND_DEVICE float2
__pgicudalib_curandNormal2XORWOW(signed char *h)
{
    curandStateXORWOW_t *state = (curandStateXORWOW_t *)h;
    const float2 pair = curand_normal2(state);

    return pair;
}

/* Return the curand_normal_double() result for the XORWOW state at h. */
__NVHPC_CURAND_DEVICE double
__pgicudalib_curandNormalDoubleXORWOW(signed char *h)
{
    curandStateXORWOW_t *state = (curandStateXORWOW_t *)h;
    const double sample = curand_normal_double(state);

    return sample;
}

/*
 * Call curand_normal2_double() on the XORWOW state at h and store the
 * resulting pair as two consecutive doubles starting at dst (x, then y).
 */
__NVHPC_CURAND_DEVICE void
__pgicudalib_curandnormal2doublexorwow_(signed char *dst, signed char *h)
{
    double *out = (double *)dst;
    const double2 pair = curand_normal2_double((curandStateXORWOW_t *)h);

    out[0] = pair.x;
    out[1] = pair.y;
}

/*
 * Return the double2 produced by curand_normal2_double() for the XORWOW
 * state at h.
 */
__NVHPC_CURAND_DEVICE double2
__pgicudalib_curandNormal2DoubleXORWOW(signed char *h)
{
    curandStateXORWOW_t *state = (curandStateXORWOW_t *)h;
    const double2 pair = curand_normal2_double(state);

    return pair;
}

/*
 * Return the curand_log_normal() result for the XORWOW state at h; mean and
 * stddev are passed through unchanged.
 */
__NVHPC_CURAND_DEVICE float
__pgicudalib_curandLogNormalXORWOW(signed char *h, float mean, float stddev)
{
    curandStateXORWOW_t *state = (curandStateXORWOW_t *)h;
    const float sample = curand_log_normal(state, mean, stddev);

    return sample;
}

/*
 * Call curand_log_normal2() on the XORWOW state at h with the given mean
 * and stddev, and store the pair as two consecutive floats at dst.
 */
__NVHPC_CURAND_DEVICE void
__pgicudalib_curandlognormal2xorwow_(signed char *dst, signed char *h, float mean, float stddev)
{
    float *out = (float *)dst;
    const float2 pair =
        curand_log_normal2((curandStateXORWOW_t *)h, mean, stddev);

    out[0] = pair.x;
    out[1] = pair.y;
}

/*
 * Return the float2 produced by curand_log_normal2() for the XORWOW state
 * at h, using the supplied mean and stddev.
 */
__NVHPC_CURAND_DEVICE float2
__pgicudalib_curandLogNormal2XORWOW(signed char *h, float mean, float stddev)
{
    curandStateXORWOW_t *state = (curandStateXORWOW_t *)h;
    const float2 pair = curand_log_normal2(state, mean, stddev);

    return pair;
}

/*
 * Return the curand_log_normal_double() result for the XORWOW state at h;
 * mean and stddev are passed through unchanged.
 */
__NVHPC_CURAND_DEVICE double
__pgicudalib_curandLogNormalDoubleXORWOW(signed char *h, double mean,
                                         double stddev)
{
    curandStateXORWOW_t *state = (curandStateXORWOW_t *)h;
    const double sample = curand_log_normal_double(state, mean, stddev);

    return sample;
}

/*
 * Call curand_log_normal2_double() on the XORWOW state at h with the given
 * mean and stddev, and store the pair as two consecutive doubles at dst.
 */
__NVHPC_CURAND_DEVICE void
__pgicudalib_curandlognormal2doublexorwow_(signed char *dst, signed char *h, double mean, double stddev)
{
    double *out = (double *)dst;
    const double2 pair =
        curand_log_normal2_double((curandStateXORWOW_t *)h, mean, stddev);

    out[0] = pair.x;
    out[1] = pair.y;
}

/*
 * Return the double2 produced by curand_log_normal2_double() for the XORWOW
 * state at h, using the supplied mean and stddev.
 */
__NVHPC_CURAND_DEVICE double2
__pgicudalib_curandLogNormal2DoubleXORWOW(signed char *h, double mean, double stddev)
{
    curandStateXORWOW_t *state = (curandStateXORWOW_t *)h;
    const double2 pair = curand_log_normal2_double(state, mean, stddev);

    return pair;
}

/* MRG32k3a */

/*
 * Initialize the MRG32k3a generator state stored at h.
 * seed, seq and offset are converted to unsigned long long and handed to
 * curand_init(); h is treated as a curandStateMRG32k3a_t.
 */
__NVHPC_CURAND_DEVICE void
__pgicudalib_curandInitMRG32k3a(long long seed, long long seq, long long offset,
                                signed char *h)
{
    curandStateMRG32k3a_t *state = (curandStateMRG32k3a_t *)h;
    const unsigned long long useed = (unsigned long long)seed;
    const unsigned long long useq = (unsigned long long)seq;
    const unsigned long long uoffset = (unsigned long long)offset;

    curand_init(useed, useq, uoffset, state);
}

/*
 * Return the next curand() output of the MRG32k3a state at h, converted to
 * int to match this wrapper's return type.
 */
__NVHPC_CURAND_DEVICE int
__pgicudalib_curandGetMRG32k3a(signed char *h)
{
    curandStateMRG32k3a_t *state = (curandStateMRG32k3a_t *)h;

    return (int)curand(state);
}

/* Return the curand_uniform() result for the MRG32k3a state stored at h. */
__NVHPC_CURAND_DEVICE float
__pgicudalib_curandUniformMRG32k3a(signed char *h)
{
    curandStateMRG32k3a_t *state = (curandStateMRG32k3a_t *)h;
    const float sample = curand_uniform(state);

    return sample;
}

/* Return the curand_uniform_double() result for the MRG32k3a state at h. */
__NVHPC_CURAND_DEVICE double
__pgicudalib_curandUniformDoubleMRG32k3a(signed char *h)
{
    curandStateMRG32k3a_t *state = (curandStateMRG32k3a_t *)h;
    const double sample = curand_uniform_double(state);

    return sample;
}

/* Return the curand_normal() result for the MRG32k3a state stored at h. */
__NVHPC_CURAND_DEVICE float
__pgicudalib_curandNormalMRG32k3a(signed char *h)
{
    curandStateMRG32k3a_t *state = (curandStateMRG32k3a_t *)h;
    const float sample = curand_normal(state);

    return sample;
}

/*
 * Call curand_normal2() on the MRG32k3a state at h and store the resulting
 * pair as two consecutive floats starting at dst (x first, then y).
 */
__NVHPC_CURAND_DEVICE void
__pgicudalib_curandnormal2mrg32k3a_(signed char *dst, signed char *h)
{
    float *out = (float *)dst;
    const float2 pair = curand_normal2((curandStateMRG32k3a_t *)h);

    out[0] = pair.x;
    out[1] = pair.y;
}

/* Return the curand_normal_double() result for the MRG32k3a state at h. */
__NVHPC_CURAND_DEVICE double
__pgicudalib_curandNormalDoubleMRG32k3a(signed char *h)
{
    curandStateMRG32k3a_t *state = (curandStateMRG32k3a_t *)h;
    const double sample = curand_normal_double(state);

    return sample;
}

/*
 * Call curand_normal2_double() on the MRG32k3a state at h and store the
 * resulting pair as two consecutive doubles starting at dst (x, then y).
 */
__NVHPC_CURAND_DEVICE void
__pgicudalib_curandnormal2doublemrg32k3a_(signed char *dst, signed char *h)
{
    double *out = (double *)dst;
    const double2 pair = curand_normal2_double((curandStateMRG32k3a_t *)h);

    out[0] = pair.x;
    out[1] = pair.y;
}

/*
 * Return the curand_log_normal() result for the MRG32k3a state at h; mean
 * and stddev are passed through unchanged.
 */
__NVHPC_CURAND_DEVICE float
__pgicudalib_curandLogNormalMRG32k3a(signed char *h, float mean, float stddev)
{
    curandStateMRG32k3a_t *state = (curandStateMRG32k3a_t *)h;
    const float sample = curand_log_normal(state, mean, stddev);

    return sample;
}

/*
 * Call curand_log_normal2() on the MRG32k3a state at h with the given mean
 * and stddev, and store the pair as two consecutive floats at dst.
 */
__NVHPC_CURAND_DEVICE void
__pgicudalib_curandlognormal2mrg32k3a_(signed char *dst, signed char *h, float mean, float stddev)
{
    float *out = (float *)dst;
    const float2 pair =
        curand_log_normal2((curandStateMRG32k3a_t *)h, mean, stddev);

    out[0] = pair.x;
    out[1] = pair.y;
}

/*
 * Return the curand_log_normal_double() result for the MRG32k3a state at h;
 * mean and stddev are passed through unchanged.
 */
__NVHPC_CURAND_DEVICE double
__pgicudalib_curandLogNormalDoubleMRG32k3a(signed char *h, double mean,
                                           double stddev)
{
    curandStateMRG32k3a_t *state = (curandStateMRG32k3a_t *)h;
    const double sample = curand_log_normal_double(state, mean, stddev);

    return sample;
}

/*
 * Call curand_log_normal2_double() on the MRG32k3a state at h with the given
 * mean and stddev, and store the pair as two consecutive doubles at dst.
 */
__NVHPC_CURAND_DEVICE void
__pgicudalib_curandlognormal2doublemrg32k3a_(signed char *dst, signed char *h, double mean, double stddev)
{
    double *out = (double *)dst;
    const double2 pair =
        curand_log_normal2_double((curandStateMRG32k3a_t *)h, mean, stddev);

    out[0] = pair.x;
    out[1] = pair.y;
}

/* Philox4_32_10 */

/*
 * Initialize the Philox4_32_10 generator state stored at h.
 * seed, seq and offset are converted to unsigned long long and handed to
 * curand_init(); h is treated as a curandStatePhilox4_32_10_t.
 */
__NVHPC_CURAND_DEVICE void
__pgicudalib_curandInitPhilox4_32_10(long long seed, long long seq,
                                     long long offset, signed char *h)
{
    curandStatePhilox4_32_10_t *state = (curandStatePhilox4_32_10_t *)h;
    const unsigned long long useed = (unsigned long long)seed;
    const unsigned long long useq = (unsigned long long)seq;
    const unsigned long long uoffset = (unsigned long long)offset;

    curand_init(useed, useq, uoffset, state);
}

/*
 * Return the next curand() output of the Philox4_32_10 state at h,
 * converted to int to match this wrapper's return type.
 */
__NVHPC_CURAND_DEVICE int
__pgicudalib_curandGetPhilox4_32_10(signed char *h)
{
    curandStatePhilox4_32_10_t *state = (curandStatePhilox4_32_10_t *)h;

    return (int)curand(state);
}

/* Return the curand_uniform() result for the Philox4_32_10 state at h. */
__NVHPC_CURAND_DEVICE float
__pgicudalib_curandUniformPhilox4_32_10(signed char *h)
{
    curandStatePhilox4_32_10_t *state = (curandStatePhilox4_32_10_t *)h;
    const float sample = curand_uniform(state);

    return sample;
}

/*
 * Return the curand_uniform_double() result for the Philox4_32_10 state
 * at h.
 */
__NVHPC_CURAND_DEVICE double
__pgicudalib_curandUniformDoublePhilox4_32_10(signed char *h)
{
    curandStatePhilox4_32_10_t *state = (curandStatePhilox4_32_10_t *)h;
    const double sample = curand_uniform_double(state);

    return sample;
}

/*
 * Call curand_uniform2_double() on the Philox4_32_10 state at h and store
 * the resulting pair as two consecutive doubles starting at dst (x, then y).
 */
__NVHPC_CURAND_DEVICE void
__pgicudalib_curanduniform2doublephilox4_32_10_(signed char *dst, signed char *h)
{
    double *out = (double *)dst;
    const double2 pair =
        curand_uniform2_double((curandStatePhilox4_32_10_t *)h);

    out[0] = pair.x;
    out[1] = pair.y;
}

/*
 * Call curand_uniform4() on the Philox4_32_10 state at h and store the four
 * components as consecutive floats starting at dst, in x, y, z, w order.
 */
__NVHPC_CURAND_DEVICE void
__pgicudalib_curanduniform4philox4_32_10_(signed char *dst, signed char *h)
{
    float *out = (float *)dst;
    const float4 quad = curand_uniform4((curandStatePhilox4_32_10_t *)h);

    out[0] = quad.x;
    out[1] = quad.y;
    out[2] = quad.z;
    out[3] = quad.w;
}

/*
 * Call curand_uniform4_double() on the Philox4_32_10 state at h and store
 * the four components as consecutive doubles at dst, in x, y, z, w order.
 */
__NVHPC_CURAND_DEVICE void
__pgicudalib_curanduniform4doublephilox4_32_10_(signed char *dst, signed char *h)
{
    double *out = (double *)dst;
    const double4 quad =
        curand_uniform4_double((curandStatePhilox4_32_10_t *)h);

    out[0] = quad.x;
    out[1] = quad.y;
    out[2] = quad.z;
    out[3] = quad.w;
}

/* Return the curand_normal() result for the Philox4_32_10 state at h. */
__NVHPC_CURAND_DEVICE float
__pgicudalib_curandNormalPhilox4_32_10(signed char *h)
{
    curandStatePhilox4_32_10_t *state = (curandStatePhilox4_32_10_t *)h;
    const float sample = curand_normal(state);

    return sample;
}

/*
 * Call curand_normal2() on the Philox4_32_10 state at h and store the
 * resulting pair as two consecutive floats starting at dst (x, then y).
 */
__NVHPC_CURAND_DEVICE void
__pgicudalib_curandnormal2philox4_32_10_(signed char *dst, signed char *h)
{
    float *out = (float *)dst;
    const float2 pair = curand_normal2((curandStatePhilox4_32_10_t *)h);

    out[0] = pair.x;
    out[1] = pair.y;
}

/*
 * Call curand_normal4() on the Philox4_32_10 state at h and store the four
 * components as consecutive floats starting at dst, in x, y, z, w order.
 */
__NVHPC_CURAND_DEVICE void
__pgicudalib_curandnormal4philox4_32_10_(signed char *dst, signed char *h)
{
    float *out = (float *)dst;
    const float4 quad = curand_normal4((curandStatePhilox4_32_10_t *)h);

    out[0] = quad.x;
    out[1] = quad.y;
    out[2] = quad.z;
    out[3] = quad.w;
}

/*
 * Return the curand_normal_double() result for the Philox4_32_10 state
 * at h.
 */
__NVHPC_CURAND_DEVICE double
__pgicudalib_curandNormalDoublePhilox4_32_10(signed char *h)
{
    curandStatePhilox4_32_10_t *state = (curandStatePhilox4_32_10_t *)h;
    const double sample = curand_normal_double(state);

    return sample;
}

/*
 * Call curand_normal2_double() on the Philox4_32_10 state at h and store
 * the resulting pair as two consecutive doubles starting at dst (x, then y).
 */
__NVHPC_CURAND_DEVICE void
__pgicudalib_curandnormal2doublephilox4_32_10_(signed char *dst, signed char *h)
{
    double *out = (double *)dst;
    const double2 pair =
        curand_normal2_double((curandStatePhilox4_32_10_t *)h);

    out[0] = pair.x;
    out[1] = pair.y;
}

/*
 * Call curand_normal4_double() on the Philox4_32_10 state at h and store
 * the four components as consecutive doubles at dst, in x, y, z, w order.
 */
__NVHPC_CURAND_DEVICE void
__pgicudalib_curandnormal4doublephilox4_32_10_(signed char *dst, signed char *h)
{
    double *out = (double *)dst;
    const double4 quad =
        curand_normal4_double((curandStatePhilox4_32_10_t *)h);

    out[0] = quad.x;
    out[1] = quad.y;
    out[2] = quad.z;
    out[3] = quad.w;
}

/*
 * Return the curand_log_normal() result for the Philox4_32_10 state at h;
 * mean and stddev are passed through unchanged.
 */
__NVHPC_CURAND_DEVICE float
__pgicudalib_curandLogNormalPhilox4_32_10(signed char *h, float mean,
                                          float stddev)
{
    curandStatePhilox4_32_10_t *state = (curandStatePhilox4_32_10_t *)h;
    const float sample = curand_log_normal(state, mean, stddev);

    return sample;
}

/*
 * Call curand_log_normal2() on the Philox4_32_10 state at h with the given
 * mean and stddev, and store the pair as two consecutive floats at dst.
 */
__NVHPC_CURAND_DEVICE void
__pgicudalib_curandlognormal2philox4_32_10_(signed char *dst, signed char *h, float mean, float stddev)
{
    float *out = (float *)dst;
    const float2 pair =
        curand_log_normal2((curandStatePhilox4_32_10_t *)h, mean, stddev);

    out[0] = pair.x;
    out[1] = pair.y;
}

/*
 * Call curand_log_normal4() on the Philox4_32_10 state at h with the given
 * mean and stddev, and store the four components as consecutive floats at
 * dst, in x, y, z, w order.
 */
__NVHPC_CURAND_DEVICE void
__pgicudalib_curandlognormal4philox4_32_10_(signed char *dst, signed char *h, float mean, float stddev)
{
    float *out = (float *)dst;
    const float4 quad =
        curand_log_normal4((curandStatePhilox4_32_10_t *)h, mean, stddev);

    out[0] = quad.x;
    out[1] = quad.y;
    out[2] = quad.z;
    out[3] = quad.w;
}

/*
 * Return the curand_log_normal_double() result for the Philox4_32_10 state
 * at h; mean and stddev are passed through unchanged.
 */
__NVHPC_CURAND_DEVICE double
__pgicudalib_curandLogNormalDoublePhilox4_32_10(signed char *h, double mean,
                                                double stddev)
{
    curandStatePhilox4_32_10_t *state = (curandStatePhilox4_32_10_t *)h;
    const double sample = curand_log_normal_double(state, mean, stddev);

    return sample;
}

/*
 * Call curand_log_normal2_double() on the Philox4_32_10 state at h with the
 * given mean and stddev, and store the pair as two consecutive doubles at
 * dst.
 */
__NVHPC_CURAND_DEVICE void
__pgicudalib_curandlognormal2doublephilox4_32_10_(signed char *dst, signed char *h, double mean, double stddev)
{
    double *out = (double *)dst;
    const double2 pair = curand_log_normal2_double(
        (curandStatePhilox4_32_10_t *)h, mean, stddev);

    out[0] = pair.x;
    out[1] = pair.y;
}

/*
 * Call curand_log_normal4_double() on the Philox4_32_10 state at h with the
 * given mean and stddev, and store the four components as consecutive
 * doubles at dst, in x, y, z, w order.
 */
__NVHPC_CURAND_DEVICE void
__pgicudalib_curandlognormal4doublephilox4_32_10_(signed char *dst, signed char *h, double mean, double stddev)
{
    double *out = (double *)dst;
    const double4 quad = curand_log_normal4_double(
        (curandStatePhilox4_32_10_t *)h, mean, stddev);

    out[0] = quad.x;
    out[1] = quad.y;
    out[2] = quad.z;
    out[3] = quad.w;
}

/* Sobol32 */

/*
 * Initialize the Sobol32 generator state stored at h.
 * direction_vectors is reinterpreted as unsigned int * and offset is
 * converted to unsigned int before both are handed to curand_init().
 */
__NVHPC_CURAND_DEVICE void
__pgicudalib_curandInitSobol32(signed char *direction_vectors, int offset,
                               signed char *h)
{
    unsigned int *vectors = (unsigned int *)direction_vectors;
    curandStateSobol32_t *state = (curandStateSobol32_t *)h;
    const unsigned int uoffset = (unsigned int)offset;

    curand_init(vectors, uoffset, state);
}

/*
 * Return the next curand() output of the Sobol32 state at h, converted to
 * int to match this wrapper's return type.
 */
__NVHPC_CURAND_DEVICE int
__pgicudalib_curandGetSobol32(signed char *h)
{
    curandStateSobol32_t *state = (curandStateSobol32_t *)h;

    return (int)curand(state);
}

/* Return the curand_uniform() result for the Sobol32 state stored at h. */
__NVHPC_CURAND_DEVICE float
__pgicudalib_curandUniformSobol32(signed char *h)
{
    curandStateSobol32_t *state = (curandStateSobol32_t *)h;
    const float sample = curand_uniform(state);

    return sample;
}

/* Return the curand_uniform_double() result for the Sobol32 state at h. */
__NVHPC_CURAND_DEVICE double
__pgicudalib_curandUniformDoubleSobol32(signed char *h)
{
    curandStateSobol32_t *state = (curandStateSobol32_t *)h;
    const double sample = curand_uniform_double(state);

    return sample;
}

/* Return the curand_normal() result for the Sobol32 state stored at h. */
__NVHPC_CURAND_DEVICE float
__pgicudalib_curandNormalSobol32(signed char *h)
{
    curandStateSobol32_t *state = (curandStateSobol32_t *)h;
    const float sample = curand_normal(state);

    return sample;
}

/* Return the curand_normal_double() result for the Sobol32 state at h. */
__NVHPC_CURAND_DEVICE double
__pgicudalib_curandNormalDoubleSobol32(signed char *h)
{
    curandStateSobol32_t *state = (curandStateSobol32_t *)h;
    const double sample = curand_normal_double(state);

    return sample;
}

/*
 * Return the curand_log_normal() result for the Sobol32 state at h; mean
 * and stddev are passed through unchanged.
 */
__NVHPC_CURAND_DEVICE float
__pgicudalib_curandLogNormalSobol32(signed char *h, float mean, float stddev)
{
    curandStateSobol32_t *state = (curandStateSobol32_t *)h;
    const float sample = curand_log_normal(state, mean, stddev);

    return sample;
}

/*
 * Return the curand_log_normal_double() result for the Sobol32 state at h;
 * mean and stddev are passed through unchanged.
 */
__NVHPC_CURAND_DEVICE double
__pgicudalib_curandLogNormalDoubleSobol32(signed char *h, double mean,
                                          double stddev)
{
    curandStateSobol32_t *state = (curandStateSobol32_t *)h;
    const double sample = curand_log_normal_double(state, mean, stddev);

    return sample;
}

/* ScrambledSobol32 */

/*
 * Initialize the ScrambledSobol32 generator state stored at h.
 * direction_vectors is reinterpreted as unsigned int *; scramble_c and
 * offset are converted to unsigned int before being handed to curand_init().
 */
__NVHPC_CURAND_DEVICE void
__pgicudalib_curandInitScrambledSobol32(signed char *direction_vectors,
                                        int scramble_c, int offset,
                                        signed char *h)
{
    unsigned int *vectors = (unsigned int *)direction_vectors;
    curandStateScrambledSobol32_t *state = (curandStateScrambledSobol32_t *)h;
    const unsigned int uscramble = (unsigned int)scramble_c;
    const unsigned int uoffset = (unsigned int)offset;

    curand_init(vectors, uscramble, uoffset, state);
}

/*
 * Return the next curand() output of the ScrambledSobol32 state at h,
 * converted to int to match this wrapper's return type.
 */
__NVHPC_CURAND_DEVICE int
__pgicudalib_curandGetScrambledSobol32(signed char *h)
{
    curandStateScrambledSobol32_t *state = (curandStateScrambledSobol32_t *)h;

    return (int)curand(state);
}

/* Return the curand_uniform() result for the ScrambledSobol32 state at h. */
__NVHPC_CURAND_DEVICE float
__pgicudalib_curandUniformScrambledSobol32(signed char *h)
{
    curandStateScrambledSobol32_t *state = (curandStateScrambledSobol32_t *)h;
    const float sample = curand_uniform(state);

    return sample;
}

/*
 * Return the curand_uniform_double() result for the ScrambledSobol32 state
 * at h.
 */
__NVHPC_CURAND_DEVICE double
__pgicudalib_curandUniformDoubleScrambledSobol32(signed char *h)
{
    curandStateScrambledSobol32_t *state = (curandStateScrambledSobol32_t *)h;
    const double sample = curand_uniform_double(state);

    return sample;
}

/* Return the curand_normal() result for the ScrambledSobol32 state at h. */
__NVHPC_CURAND_DEVICE float
__pgicudalib_curandNormalScrambledSobol32(signed char *h)
{
    curandStateScrambledSobol32_t *state = (curandStateScrambledSobol32_t *)h;
    const float sample = curand_normal(state);

    return sample;
}

/*
 * Return the curand_normal_double() result for the ScrambledSobol32 state
 * at h.
 */
__NVHPC_CURAND_DEVICE double
__pgicudalib_curandNormalDoubleScrambledSobol32(signed char *h)
{
    curandStateScrambledSobol32_t *state = (curandStateScrambledSobol32_t *)h;
    const double sample = curand_normal_double(state);

    return sample;
}

/*
 * Return the curand_log_normal() result for the ScrambledSobol32 state at h;
 * mean and stddev are passed through unchanged.
 */
__NVHPC_CURAND_DEVICE float
__pgicudalib_curandLogNormalScrambledSobol32(signed char *h, float mean,
                                             float stddev)
{
    curandStateScrambledSobol32_t *state = (curandStateScrambledSobol32_t *)h;
    const float sample = curand_log_normal(state, mean, stddev);

    return sample;
}

/*
 * Return the curand_log_normal_double() result for the ScrambledSobol32
 * state at h; mean and stddev are passed through unchanged.
 */
__NVHPC_CURAND_DEVICE double
__pgicudalib_curandLogNormalDoubleScrambledSobol32(signed char *h, double mean,
                                                   double stddev)
{
    curandStateScrambledSobol32_t *state = (curandStateScrambledSobol32_t *)h;
    const double sample = curand_log_normal_double(state, mean, stddev);

    return sample;
}

/* Sobol64 */

/*
 * Initialize the Sobol64 generator state stored at h.
 * direction_vectors is reinterpreted as unsigned long long * and offset is
 * converted to unsigned long long before both are handed to curand_init().
 */
__NVHPC_CURAND_DEVICE void
__pgicudalib_curandInitSobol64(signed char *direction_vectors,
                               long long offset, signed char *h)
{
    unsigned long long *vectors = (unsigned long long *)direction_vectors;
    curandStateSobol64_t *state = (curandStateSobol64_t *)h;
    const unsigned long long uoffset = (unsigned long long)offset;

    curand_init(vectors, uoffset, state);
}

/*
 * Return the next curand() output of the Sobol64 state at h as an
 * unsigned long long.
 */
__NVHPC_CURAND_DEVICE unsigned long long
__pgicudalib_curandGetSobol64(signed char *h)
{
    curandStateSobol64_t *state = (curandStateSobol64_t *)h;

    return curand(state);
}

/* Return the curand_uniform() result for the Sobol64 state stored at h. */
__NVHPC_CURAND_DEVICE float
__pgicudalib_curandUniformSobol64(signed char *h)
{
    curandStateSobol64_t *state = (curandStateSobol64_t *)h;
    const float sample = curand_uniform(state);

    return sample;
}

/* Return the curand_uniform_double() result for the Sobol64 state at h. */
__NVHPC_CURAND_DEVICE double
__pgicudalib_curandUniformDoubleSobol64(signed char *h)
{
    curandStateSobol64_t *state = (curandStateSobol64_t *)h;
    const double sample = curand_uniform_double(state);

    return sample;
}

/* Return the curand_normal() result for the Sobol64 state stored at h. */
__NVHPC_CURAND_DEVICE float
__pgicudalib_curandNormalSobol64(signed char *h)
{
    curandStateSobol64_t *state = (curandStateSobol64_t *)h;
    const float sample = curand_normal(state);

    return sample;
}

/* Return the curand_normal_double() result for the Sobol64 state at h. */
__NVHPC_CURAND_DEVICE double
__pgicudalib_curandNormalDoubleSobol64(signed char *h)
{
    curandStateSobol64_t *state = (curandStateSobol64_t *)h;
    const double sample = curand_normal_double(state);

    return sample;
}

/*
 * Return the curand_log_normal() result for the Sobol64 state at h; mean
 * and stddev are passed through unchanged.
 */
__NVHPC_CURAND_DEVICE float
__pgicudalib_curandLogNormalSobol64(signed char *h, float mean, float stddev)
{
    curandStateSobol64_t *state = (curandStateSobol64_t *)h;
    const float sample = curand_log_normal(state, mean, stddev);

    return sample;
}

/*
 * Return the curand_log_normal_double() result for the Sobol64 state at h;
 * mean and stddev are passed through unchanged.
 */
__NVHPC_CURAND_DEVICE double
__pgicudalib_curandLogNormalDoubleSobol64(signed char *h, double mean,
                                          double stddev)
{
    curandStateSobol64_t *state = (curandStateSobol64_t *)h;
    const double sample = curand_log_normal_double(state, mean, stddev);

    return sample;
}

/* ScrambledSobol64 */

/*
 * Initialize the ScrambledSobol64 generator state stored at h.
 * direction_vectors is reinterpreted as unsigned long long *; scramble_c and
 * offset are converted to unsigned long long before being handed to
 * curand_init().
 */
__NVHPC_CURAND_DEVICE void
__pgicudalib_curandInitScrambledSobol64(signed char *direction_vectors,
                                        long long scramble_c,
                                        long long offset,
                                        signed char *h)
{
    unsigned long long *vectors = (unsigned long long *)direction_vectors;
    curandStateScrambledSobol64_t *state = (curandStateScrambledSobol64_t *)h;
    const unsigned long long uscramble = (unsigned long long)scramble_c;
    const unsigned long long uoffset = (unsigned long long)offset;

    curand_init(vectors, uscramble, uoffset, state);
}

/*
 * Return the next curand() output of the ScrambledSobol64 state at h as an
 * unsigned long long.
 */
__NVHPC_CURAND_DEVICE unsigned long long
__pgicudalib_curandGetScrambledSobol64(signed char *h)
{
    curandStateScrambledSobol64_t *state = (curandStateScrambledSobol64_t *)h;

    return curand(state);
}

/* Return the curand_uniform() result for the ScrambledSobol64 state at h. */
__NVHPC_CURAND_DEVICE float
__pgicudalib_curandUniformScrambledSobol64(signed char *h)
{
    curandStateScrambledSobol64_t *state = (curandStateScrambledSobol64_t *)h;
    const float sample = curand_uniform(state);

    return sample;
}

/*
 * Return the curand_uniform_double() result for the ScrambledSobol64 state
 * at h.
 */
__NVHPC_CURAND_DEVICE double
__pgicudalib_curandUniformDoubleScrambledSobol64(signed char *h)
{
    curandStateScrambledSobol64_t *state = (curandStateScrambledSobol64_t *)h;
    const double sample = curand_uniform_double(state);

    return sample;
}

/* Return the curand_normal() result for the ScrambledSobol64 state at h. */
__NVHPC_CURAND_DEVICE float
__pgicudalib_curandNormalScrambledSobol64(signed char *h)
{
    curandStateScrambledSobol64_t *state = (curandStateScrambledSobol64_t *)h;
    const float sample = curand_normal(state);

    return sample;
}

/*
 * Return the curand_normal_double() result for the ScrambledSobol64 state
 * at h.
 */
__NVHPC_CURAND_DEVICE double
__pgicudalib_curandNormalDoubleScrambledSobol64(signed char *h)
{
    curandStateScrambledSobol64_t *state = (curandStateScrambledSobol64_t *)h;
    const double sample = curand_normal_double(state);

    return sample;
}

/*
 * Return the curand_log_normal() result for the ScrambledSobol64 state at h;
 * mean and stddev are passed through unchanged.
 */
__NVHPC_CURAND_DEVICE float
__pgicudalib_curandLogNormalScrambledSobol64(signed char *h, float mean,
                                             float stddev)
{
    curandStateScrambledSobol64_t *state = (curandStateScrambledSobol64_t *)h;
    const float sample = curand_log_normal(state, mean, stddev);

    return sample;
}

/*
 * Return the curand_log_normal_double() result for the ScrambledSobol64
 * state at h; mean and stddev are passed through unchanged.
 */
__NVHPC_CURAND_DEVICE double
__pgicudalib_curandLogNormalDoubleScrambledSobol64(signed char *h, double mean,
                                                   double stddev)
{
    curandStateScrambledSobol64_t *state = (curandStateScrambledSobol64_t *)h;
    const double sample = curand_log_normal_double(state, mean, stddev);

    return sample;
}

/*
 * Call skipahead() on the XORWOW state stored at h, with n converted to
 * unsigned long long.
 */
__NVHPC_CURAND_DEVICE void
__pgicudalib_curand_skipahead_xorwow(long long n, signed char *h)
{
    curandStateXORWOW_t *state = (curandStateXORWOW_t *)h;
    const unsigned long long count = (unsigned long long)n;

    skipahead(count, state);
}

/*
 * Call skipahead_sequence() on the XORWOW state stored at h, with n
 * converted to unsigned long long.
 */
__NVHPC_CURAND_DEVICE void
__pgicudalib_curand_skipaheadseq_xorwow(long long n, signed char *h)
{
    curandStateXORWOW_t *state = (curandStateXORWOW_t *)h;
    const unsigned long long count = (unsigned long long)n;

    skipahead_sequence(count, state);
}

/*
 * Call skipahead() on the MRG32k3a state stored at h, with n converted to
 * unsigned long long.
 */
__NVHPC_CURAND_DEVICE void
__pgicudalib_curand_skipahead_mrg32k3a(long long n, signed char *h)
{
    curandStateMRG32k3a_t *state = (curandStateMRG32k3a_t *)h;
    const unsigned long long count = (unsigned long long)n;

    skipahead(count, state);
}

/*
 * Call skipahead_sequence() on the MRG32k3a state stored at h, with n
 * converted to unsigned long long.
 */
__NVHPC_CURAND_DEVICE void
__pgicudalib_curand_skipaheadseq_mrg32k3a(long long n, signed char *h)
{
    curandStateMRG32k3a_t *state = (curandStateMRG32k3a_t *)h;
    const unsigned long long count = (unsigned long long)n;

    skipahead_sequence(count, state);
}

/*
 * Call skipahead() on the Philox4_32_10 state stored at h, with n converted
 * to unsigned long long.
 */
__NVHPC_CURAND_DEVICE void
__pgicudalib_curand_skipahead_philox4_32_10(long long n, signed char *h)
{
    curandStatePhilox4_32_10_t *state = (curandStatePhilox4_32_10_t *)h;
    const unsigned long long count = (unsigned long long)n;

    skipahead(count, state);
}

/*
 * Call skipahead_sequence() on the Philox4_32_10 state stored at h, with n
 * converted to unsigned long long.
 */
__NVHPC_CURAND_DEVICE void
__pgicudalib_curand_skipaheadseq_philox4_32_10(long long n, signed char *h)
{
    curandStatePhilox4_32_10_t *state = (curandStatePhilox4_32_10_t *)h;
    const unsigned long long count = (unsigned long long)n;

    skipahead_sequence(count, state);
}

#endif /* NVHPC_CURAND_RUNTIME_H */
