/*******************************************************************************
* Copyright 2021-2022 Intel Corporation.
*
* This software and the related documents are Intel copyrighted  materials,  and
* your use of  them is  governed by the  express license  under which  they were
* provided to you (License).  Unless the License provides otherwise, you may not
* use, modify, copy, publish, distribute,  disclose or transmit this software or
* the related documents without Intel's prior written permission.
*
* This software and the related documents  are provided as  is,  with no express
* or implied  warranties,  other  than those  that are  expressly stated  in the
* License.
*******************************************************************************/

/*
!  Content:
!      oneMKL CBLAS_DAXPY OpenMP offload Example Program Text with Multi-stack devices
!******************************************************************************/

#include <stdio.h>
#include <omp.h>
#include "mkl.h"
#include "mkl_omp_offload.h"
#include "common.h"

/*
 * Computes y = alpha * x + y with cblas_daxpy offloaded through OpenMP and
 * checks the result against a host-side cblas_daxpy reference (y_ref).
 *
 * When omp_get_num_devices() reports more than one device, the vectors are
 * split in two: elements [0, n/2) are processed on device 0 and elements
 * [n/2, n) (i.e. n - n/2 elements, one more than n/2 when n is odd) on
 * device 1. Otherwise the whole vector is processed on device 0.
 *
 * Returns 0 when every element matches the reference within 1e-4, and 1 on
 * allocation failure or on the first mismatching element.
 */
int main() {

    double *x, *y, *y_ref, alpha;
    MKL_INT n, incx, incy, i, bound_m;
    int num_devices;
    int status = 1;  // pessimistic until verification passes

    alpha = 1.0;

    n = 1120;

    incx = 1;
    incy = 1;

    // allocate vectors
    x = (double *)mkl_malloc(n * sizeof(double), 128);
    y = (double *)mkl_malloc(n * sizeof(double), 128);
    y_ref = (double *)mkl_malloc(n * sizeof(double), 128);

    if ((x == NULL) || (y == NULL) || (y_ref == NULL)) {
        printf("Cannot allocate matrices\n");
        // release whichever buffers did get allocated
        goto cleanup;
    }

    // initialize vectors
    init_double_array(n, x, 1);
    init_double_array(n, y, 1);

    for (i = 0; i < n; i++) {
        y_ref[i] = y[i];
    }

    // number of leading elements echoed at the end
    bound_m = (n > 10) ? 10 : n;

    // host reference result
    cblas_daxpy(n, alpha, x, incx, y_ref, incy);

    num_devices = omp_get_num_devices();
    printf("Number of GPU subdevices: %d\n", num_devices);

    if (num_devices > 1) {
        // device 0 gets [0, half), device 1 gets [half, n); for odd n the
        // second part has one more element than the first
        MKL_INT half = n / 2;
        MKL_INT rest = n - half;

        printf("Use GPU subdevice 0 for x[0..%d] and y[0..%d]\n", (int) (half - 1), (int) (half - 1));
#pragma omp target data map(to:x[0:half]) map(tofrom:y[0:half]) device(0)
        {
        printf("Use GPU subdevice 1 for x[%d..%d] and y[%d..%d]\n", (int) half, (int) n - 1, (int) half, (int) n - 1);
#pragma omp target data map(to:x[half:rest]) map(tofrom:y[half:rest]) device(1)
        {
            // host pointers to the start of the second part; the dispatch
            // below passes them on device 1
            double *x1 = &x[half];
            double *y1 = &y[half];
#if defined(ONEMKL_USE_OPENMP_VERSION) && (ONEMKL_USE_OPENMP_VERSION >= 202011)
#pragma omp dispatch device(0)
#else
#pragma omp target variant dispatch use_device_ptr(x, y) device(0) nowait
#endif
            cblas_daxpy(half, alpha, x, incx, y, incy);
#if defined(ONEMKL_USE_OPENMP_VERSION) && (ONEMKL_USE_OPENMP_VERSION >= 202011)
#pragma omp dispatch device(1)
#else
#pragma omp target variant dispatch use_device_ptr(x1, y1) device(1)
#endif
            cblas_daxpy(rest, alpha, x1, incx, y1, incy);
            // wait for the (possibly nowait) device-0 call before the
            // target data regions copy y back
#pragma omp taskwait
        }
        }
    } else {
#pragma omp target data map(to:x[0:n]) map(tofrom:y[0:n]) device(0)
        {
#if defined(ONEMKL_USE_OPENMP_VERSION) && (ONEMKL_USE_OPENMP_VERSION >= 202011)
#pragma omp dispatch device(0)
#else
#pragma omp target variant dispatch use_device_ptr(x, y) device(0)
#endif
            cblas_daxpy(n, alpha, x, incx, y, incy);
#pragma omp taskwait
        }
    }

    // compare the offloaded result with the host reference
    for (i = 0; i < n; i++) {
        double real = y[i] - y_ref[i];
        real = (real > 0) ? real : -real;
        if (real > 0.0001) {
#ifdef MKL_ILP64
            printf("y[%lld] != y_ref[%lld], computed value is %lf, reference value is %lf, difference is %lf\n",
                   i, i, y[i], y_ref[i], real);
#else
            printf("y[%d] != y_ref[%d], computed value is %lf, reference value is %lf, difference is %lf\n",
                   i, i, y[i], y_ref[i], real);
#endif
            goto cleanup;
        }
    }

    printf("First elements of the output vector Y:\n");
    printf("Y vector:\n");
    for (i = 0; i < bound_m; i++) {
        printf("%lf\t", y[i]);
    }
    printf("\n");

    printf("Reference vector:\n");
    for (i = 0; i < bound_m; i++) {
        printf("%lf\t", y_ref[i]);
    }
    printf("\n");

    status = 0;

cleanup:
    // single exit point: free only what was actually allocated
    if (x != NULL) {
        mkl_free(x);
    }
    if (y != NULL) {
        mkl_free(y);
    }
    if (y_ref != NULL) {
        mkl_free(y_ref);
    }
    return status;
}

