!===============================================================================
! Copyright 2021-2022 Intel Corporation.
!
! This software and the related documents are Intel copyrighted  materials,  and
! your use of  them is  governed by the  express license  under which  they were
! provided to you (License).  Unless the License provides otherwise, you may not
! use, modify, copy, publish, distribute,  disclose or transmit this software or
! the related documents without Intel's prior written permission.
!
! This software and the related documents  are provided as  is,  with no express
! or implied  warranties,  other  than those  that are  expressly stated  in the
! License.
!===============================================================================

! Content:
! A simple example of batch single-precision real-to-complex, complex-to-real 
! in-place 2D FFT using Intel(R) oneAPI Math Kernel Library (oneMKL) DFTI
!
!*****************************************************************************

include "mkl_dfti_omp_offload.f90"

program sp_real_2d_batch
  ! Batch (M transforms) single-precision N1 x N2 in-place FFT example:
  ! a real-to-complex forward transform is checked against a known spectrum,
  ! then a complex-to-real backward transform is checked against a known
  ! signal. DFTI commit/compute calls are issued through OpenMP dispatch
  ! directives. The process exits with status 0 on success, 1 on failure.

  ! The module's DFTI_SINGLE_R is used under the name DFTI_SINGLE (the
  ! original DFTI_SINGLE is hidden as "forget").
  ! NOTE(review): presumably required so the generic DftiCreateDescriptor
  ! resolves to the single-precision specific -- confirm in the oneMKL docs.
  use MKL_DFTI_OMP_OFFLOAD, forget => DFTI_SINGLE, DFTI_SINGLE => DFTI_SINGLE_R
  use omp_lib, ONLY : omp_get_num_devices
  use, intrinsic :: ISO_C_BINDING
  implicit none

  ! Sizes of 2D transform
  integer, parameter :: N1 = 7
  integer, parameter :: N2 = 13

  ! Number of complex values stored along the first dimension after a
  ! real-to-complex transform (conjugate-even symmetry)
  integer, parameter :: halfN1plus1 = N1/2 + 1

  ! Number of transforms
  integer, parameter :: M = 5

  ! Arbitrary harmonic to test the FFT
  integer, parameter :: H1 = 1
  integer, parameter :: H2 = 2

  ! need single precision
  integer, parameter :: WP = selected_real_kind(6,37)

  ! 2*pi in working precision, shared by the initialization routines
  real(WP), parameter :: TWOPI = 6.2831853071795864769_WP

  ! Execution status
  integer :: status = 0, ignored_status

  ! In-place buffer. Per batch it holds N2 columns of 2*halfN1plus1 reals
  ! (N1 real samples plus padding), which the forward transform overwrites
  ! with halfN1plus1 interleaved (re, im) pairs per column.
  real(WP), allocatable :: x (:)

  type(DFTI_DESCRIPTOR), POINTER :: hand

  hand => null()

  print *,"Example sp_real_2d_batch"
  print *,"Batch forward and backward single-precision real-to-complex ",     &
    &    " and complex-to-real in-place 2D transform"
  print *,"Configuration parameters:"
  print *,"DFTI_PRECISION      = DFTI_SINGLE"
  print *,"DFTI_FORWARD_DOMAIN = DFTI_REAL"
  print *,"DFTI_DIMENSION      = 2"
  print '(" DFTI_NUMBER_OF_TRANSFORMS = ",I0)', M
  print '(" DFTI_LENGTHS        = /",I0,",",I0,"/")', N1, N2

  ! Every step sets status; the first nonzero status leaves this block.
  ! The error is then reported and resources are released below.
  ! (No exit happens inside an OpenMP region: each check follows the
  ! corresponding end directive.)
  run: block
    print *,"Create DFTI descriptor"
    status = DftiCreateDescriptor(hand, DFTI_SINGLE, DFTI_REAL, 2, [N1,N2])
    if (0 /= status) exit run

    print *,"Set DFTI descriptor for CCE storage "
    status = DftiSetValue(hand, DFTI_CONJUGATE_EVEN_STORAGE, DFTI_COMPLEX_COMPLEX)
    if (0 /= status) exit run

    print *,"Set DFTI descriptor for number of batch in-place transforms"
    status = DftiSetValue(hand, DFTI_NUMBER_OF_TRANSFORMS, M)
    if (0 /= status) exit run

    ! Forward direction: real input columns are 2*halfN1plus1 reals apart,
    ! complex output columns are halfN1plus1 complex values apart, so both
    ! views cover the same memory in place.
    print *,"Set input strides for DFTI descriptor (forward transform) "
    status = DftiSetValue(hand, DFTI_INPUT_STRIDES, [0, 1, 2*halfN1plus1])
    if (0 /= status) exit run

    print *,"Set output strides for DFTI descriptor (forward transform) "
    status = DftiSetValue(hand, DFTI_OUTPUT_STRIDES, [0, 1, halfN1plus1])
    if (0 /= status) exit run

    ! The memory address of the first element must coincide on I/O
    ! --> different DFTI_INPUT_DISTANCE and DFTI_OUTPUT_DISTANCE to be set
    ! (storing a complex number requires to store two real numbers)
    print *,"Set DFTI descriptor for input distance for forward FFT"
    status = DftiSetValue(hand, DFTI_INPUT_DISTANCE, 2*halfN1plus1*N2)
    if (0 /= status) exit run

    print *,"Set DFTI descriptor for output distance for forward FFT"
    status = DftiSetValue(hand, DFTI_OUTPUT_DISTANCE, halfN1plus1*N2)
    if (0 /= status) exit run

    print *,"Commit DFTI descriptor for forward FFT"
#if defined(ONEMKL_USE_OPENMP_VERSION) && (ONEMKL_USE_OPENMP_VERSION >= 202011)
    !$omp dispatch
#else
    !$omp target variant dispatch
#endif
    status = DftiCommitDescriptor(hand)
#if !defined(ONEMKL_USE_OPENMP_VERSION) || (ONEMKL_USE_OPENMP_VERSION < 202011)
    !$omp end target variant dispatch
#endif
    if (0 /= status) exit run

    print *,"Allocate array for input/output data"
    allocate ( x(2*halfN1plus1*N2*M), STAT = status)
    if (0 /= status) exit run

    print *,"Initialize input for forward transform"
    call init_r(x, M, N1, N2, H1, H2)

    print *,"Compute forward transform"
    !$omp target data map(tofrom:x)
#if defined(ONEMKL_USE_OPENMP_VERSION) && (ONEMKL_USE_OPENMP_VERSION >= 202011)
    !$omp dispatch
#else
    !$omp target variant dispatch use_device_ptr(x)
#endif
    status = DftiComputeForward(hand, x)
#if !defined(ONEMKL_USE_OPENMP_VERSION) || (ONEMKL_USE_OPENMP_VERSION < 202011)
    !$omp end target variant dispatch
#endif
    !$omp end target data
    if (0 /= status) exit run

    print *,"Verify the result"
    status = verify_c(x, M, N1, N2, H1, H2)
    if (0 /= status) exit run

    print *,"Initialize input for backward transform"
    call init_c(x, M, N1, N2, H1, H2)

    ! reset the strides of input and output data accordingly:
    print *,"Set input strides for DFTI descriptor (backward transform)"
    status = DftiSetValue(hand, DFTI_INPUT_STRIDES, [0, 1, halfN1plus1])
    if (0 /= status) exit run

    print *,"Set output strides for DFTI descriptor (backward transform) "
    status = DftiSetValue(hand, DFTI_OUTPUT_STRIDES, [0, 1, 2*halfN1plus1])
    if (0 /= status) exit run

    ! DFTI_INPUT_DISTANCE and DFTI_OUTPUT_DISTANCE are to be reset for the
    ! backward transform
    print *,"Set DFTI descriptor for input distance for backward FFT"
    status = DftiSetValue(hand, DFTI_INPUT_DISTANCE, halfN1plus1*N2)
    if (0 /= status) exit run

    print *,"Set DFTI descriptor for output distance for backward FFT"
    status = DftiSetValue(hand, DFTI_OUTPUT_DISTANCE, 2*halfN1plus1*N2)
    if (0 /= status) exit run

    ! we changed the configuration for the DFTI descriptor
    ! --> commit again
    print *,"Commit DFTI descriptor for backward transform "
#if defined(ONEMKL_USE_OPENMP_VERSION) && (ONEMKL_USE_OPENMP_VERSION >= 202011)
    !$omp dispatch
#else
    !$omp target variant dispatch
#endif
    status = DftiCommitDescriptor(hand)
#if !defined(ONEMKL_USE_OPENMP_VERSION) || (ONEMKL_USE_OPENMP_VERSION < 202011)
    !$omp end target variant dispatch
#endif
    if (0 /= status) exit run

    print *,"Compute backward transform"
    !$omp target data map(tofrom:x)
#if defined(ONEMKL_USE_OPENMP_VERSION) && (ONEMKL_USE_OPENMP_VERSION >= 202011)
    !$omp dispatch
#else
    !$omp target variant dispatch use_device_ptr(x)
#endif
    status = DftiComputeBackward(hand, x)
#if !defined(ONEMKL_USE_OPENMP_VERSION) || (ONEMKL_USE_OPENMP_VERSION < 202011)
    !$omp end target variant dispatch
#endif
    !$omp end target data
    if (0 /= status) exit run

    print *,"Verify the result"
    status = verify_r(x, M, N1, N2, H1, H2)
  end block run

  if (0 /= status) print '("  Error, status = ",I0)', status

  print *,"Release the DFTI descriptor"
  ! Only free a descriptor that was actually created; the result is ignored
  ! because nothing further can be done about a failed release here.
  if (associated(hand)) ignored_status = DftiFreeDescriptor(hand)

  if (allocated(x)) then
    print *,"Deallocate input data array"
    deallocate(x)
  end if

  if (status == 0) then
    print *,"TEST PASSED"
    call exit(0)
  else
    print *,"TEST FAILED"
    call exit(1)
  end if

contains

  ! Return mod(k*l, m) as a real(WP). The product is formed in 64-bit
  ! integer arithmetic so k*l cannot overflow a default integer.
  pure real(WP) function moda(k, l, m)
    integer, intent(in) :: k, l, m
    integer, parameter :: i8 = selected_int_kind(18)
    ! MOD requires both arguments to have the same kind, so m is widened too.
    moda = real(mod(int(k, i8)*l, int(m, i8)), WP)
  end function moda

  ! Fill the real input of each of the M batches with a 2D cosine of
  ! harmonic (H1, H2), scaled so that the forward transform has unit peaks
  ! at (H1, H2) and at its conjugate partner (-H1, -H2).
  ! Only the first N1 reals of each 2*halfN1plus1-long column are written;
  ! the padding entries are left unchanged (hence intent(inout)).
  subroutine init_r(x, M, N1, N2, H1, H2)
    integer, intent(in) :: M, N1, N2, H1, H2
    real(WP), intent(inout) :: x(:)

    integer :: j, k1, k2, halfN1plus1
    real(WP) :: TWOPI_phase, factor

    ! A real cosine contributes half its amplitude at +H and half at -H.
    ! If the two bins coincide (2*H == 0 mod N in both dimensions) they sum
    ! to 1 already; otherwise double the amplitude to make each peak 1.
    if (mod(2*(N1-H1), N1) == 0 .and. mod(2*(N2-H2), N2) == 0) then
      factor = 1.0_WP
    else
      factor = 2.0_WP
    end if

    halfN1plus1 = N1/2 + 1
    do j = 1, M
      do k2 = 1, N2
        do k1 = 1, N1
          TWOPI_phase = TWOPI*(moda(k1-1, H1, N1)/N1 + moda(k2-1, H2, N2)/N2)
          x((j-1)*2*halfN1plus1*N2 + (k2-1)*2*halfN1plus1 + k1) = &
              factor*cos(TWOPI_phase)/(N1*N2)
        end do
      end do
    end do
  end subroutine init_r

  ! Check that the packed complex spectrum in y is 1 at (H1, H2) or at its
  ! conjugate partner (-H1, -H2), and 0 everywhere else.
  ! y holds interleaved (re, im) pairs: per batch, N2 columns of
  ! halfN1plus1 complex values, each column 2*halfN1plus1 reals long.
  ! Returns 0 on success, or 100 at the first element whose error is not
  ! below the threshold (a NaN error counts as a failure).
  integer function verify_c(y, M, N1, N2, H1, H2)
    integer, intent(in) :: M, N1, N2, H1, H2
    real(WP), intent(in) :: y(:)

    integer :: j, k1, k2, halfN1plus1, idx
    real(WP) :: err, errthr, maxerr
    complex(WP) :: res_exp, res_got

    ! Note, this simple error bound doesn't take into account error of
    ! input data
    errthr = 2.5_WP * log(real(N1*N2, WP)) / log(2.0_WP) * epsilon(1.0_WP)
    print '("  Check if err is below errthr ",G10.3)', errthr

    halfN1plus1 = N1/2 + 1
    maxerr = 0.0_WP
    do j = 1, M
      do k2 = 1, N2
        do k1 = 1, halfN1plus1
          if (     (mod(k1-1-H1, N1) == 0 .and. mod(k2-1-H2, N2) == 0) &
              .or. (mod(1-k1-H1, N1) == 0 .and. mod(1-k2-H2, N2) == 0)) then
            res_exp = 1.0_WP
          else
            res_exp = 0.0_WP
          end if
          idx = (j-1)*2*halfN1plus1*N2 + (k2-1)*2*halfN1plus1 + 2*k1 - 1
          ! Without KIND=, CMPLX returns default complex regardless of the
          ! kind of its arguments; keep the value in working precision.
          res_got = cmplx(y(idx), y(idx + 1), kind=WP)
          err = abs(res_got - res_exp)
          maxerr = max(err, maxerr)
          ! Written as .not.(err < errthr) so that a NaN error also fails.
          if (.not. (err < errthr)) then
            write (*, '(" Batch #",I0," y(",I0,",",I0,"):")', advance='no') j, k1, k2
            write (*, '(" expected (",G14.7,",",G14.7,"),")', advance='no') res_exp
            write (*, '(" got (",G14.7,",",G14.7,"),")', advance='no') res_got
            write (*, '(" err ",G10.3)') err
            print *," Verification FAILED"
            verify_c = 100
            return
          end if
        end do
      end do
    end do
    print '("  Verified,  maximum error was ",G10.3)', maxerr
    verify_c = 0
  end function verify_c

  ! Fill the packed complex input of each of the M batches with
  ! exp(-i*2*pi*(k1*H1/N1 + k2*H2/N2))/(N1*N2), stored as interleaved
  ! (re, im) pairs. The backward transform of this spectrum is a unit peak
  ! at x(H1, H2). Every element of the packed layout is written.
  subroutine init_c(y, M, N1, N2, H1, H2)
    integer, intent(in) :: M, N1, N2, H1, H2
    real(WP), intent(out) :: y(:)

    integer :: j, k1, k2, halfN1plus1, idx
    real(WP) :: TWOPI_phase

    halfN1plus1 = N1/2 + 1
    do j = 1, M
      do k2 = 1, N2
        do k1 = 1, halfN1plus1
          TWOPI_phase = TWOPI*(moda(k1-1, H1, N1)/N1 + moda(k2-1, H2, N2)/N2)
          idx = (j-1)*2*halfN1plus1*N2 + (k2-1)*2*halfN1plus1 + 2*k1 - 1
          y(idx)      =  cos(TWOPI_phase)/(N1*N2)
          y(idx + 1)  = -sin(TWOPI_phase)/(N1*N2)
        end do
      end do
    end do
  end subroutine init_c

  ! Check that the real result x(k1, k2) of each batch is 1 at
  ! k1 = H1, k2 = H2 (0-based) and 0 elsewhere.
  ! Only the first N1 reals of each padded column are examined.
  ! Returns 0 on success, or 100 at the first element whose error is not
  ! below the threshold (a NaN error counts as a failure).
  integer function verify_r(x, M, N1, N2, H1, H2)
    integer, intent(in) :: M, N1, N2, H1, H2
    real(WP), intent(in) :: x(:)

    integer :: j, k1, k2, halfN1plus1
    real(WP) :: err, errthr, maxerr
    real(WP) :: res_exp, res_got

    ! Note, this simple error bound doesn't take into account error of
    ! input data
    errthr = 2.5_WP * log(real(N1*N2, WP)) / log(2.0_WP) * epsilon(1.0_WP)
    print '("  Check if err is below errthr ",G10.3)', errthr

    halfN1plus1 = N1/2 + 1
    maxerr = 0.0_WP
    do j = 1, M
      do k2 = 1, N2
        do k1 = 1, N1
          if (mod(k1-1-H1, N1) == 0 .and. mod(k2-1-H2, N2) == 0) then
            res_exp = 1.0_WP
          else
            res_exp = 0.0_WP
          end if
          res_got = x((j-1)*2*halfN1plus1*N2 + (k2-1)*2*halfN1plus1 + k1)
          err = abs(res_got - res_exp)
          maxerr = max(err, maxerr)
          ! Written as .not.(err < errthr) so that a NaN error also fails.
          if (.not. (err < errthr)) then
            write (*, '(" Batch #",I0," x(",I0,",",I0,"): ")', advance='no') j, k1, k2
            write (*, '(" expected ",G14.7,",")', advance='no') res_exp
            write (*, '(" got ",G14.7,",")', advance='no') res_got
            write (*, '(" err ",G10.3)') err
            print *," Verification FAILED"
            verify_r = 100
            return
          end if
        end do
      end do
    end do
    print '("  Verified,  maximum error was ",G10.3)', maxerr
    verify_r = 0
  end function verify_r

end program sp_real_2d_batch
