# -*- shell-script -*-
#
# Copyright (c) 2019-2020 The University of Tennessee and The University
#                         of Tennessee Research Foundation.  All rights
#                         reserved.
# Copyright (c) 2020      Cisco Systems, Inc.  All rights reserved.
# Copyright (c) 2020      Research Organization for Information Science
#                         and Technology (RIST).  All rights reserved.
#
# $COPYRIGHT$
#
# Additional copyrights may follow
#
# $HEADER$
#

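# MCA_ompi_op_avx_CONFIG([action-if-can-compile],
#                         [action-if-cant-compile])
# ------------------------------------------------
# On x86_64, probe the C compiler for AVX512, AVX2 and AVX support (plus the
# SSE4.1 and SSE3 instructions needed by the AVX code). If no extra flags are
# needed, none are used; otherwise the required -m flags are recorded.
# Run action-if-can-compile when at least one of AVX, AVX2 or AVX512 is
# usable; otherwise run action-if-cant-compile.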
AC_DEFUN([MCA_ompi_op_avx_CONFIG],[
    AC_CONFIG_FILES([ompi/mca/op/avx/Makefile])

    # Extra compiler flags needed for each ISA level (substituted into the
    # Makefile) and 0/1 support indicators. They stay empty / 0 unless one
    # of the probes below succeeds.
    MCA_BUILD_OP_AVX_FLAGS=""
    MCA_BUILD_OP_AVX2_FLAGS=""
    MCA_BUILD_OP_AVX512_FLAGS=""
    op_sse3_support=0
    op_sse41_support=0
    op_avx_support=0
    op_avx2_support=0
    op_avx512_support=0

    AS_VAR_PUSHDEF([op_avx_check_sse3], [ompi_cv_op_avx_check_sse3])
    AS_VAR_PUSHDEF([op_avx_check_sse41], [ompi_cv_op_avx_check_sse41])
    AS_VAR_PUSHDEF([op_avx_check_avx], [ompi_cv_op_avx_check_avx])
    AS_VAR_PUSHDEF([op_avx_check_avx2], [ompi_cv_op_avx_check_avx2])
    AS_VAR_PUSHDEF([op_avx_check_avx512], [ompi_cv_op_avx_check_avx512])

    OPAL_VAR_SCOPE_PUSH([op_avx_cflags_save])

    # All probes are x86_64 only. On any other architecture every support
    # indicator stays 0 and the component is not built.
    AS_IF([test "$opal_cv_asm_arch" = "X86_64"],
          [AC_LANG_PUSH([C])

           #
           # Check for AVX512 support
           #
           AC_CACHE_CHECK([for AVX512 support], op_avx_check_avx512, AS_VAR_SET(op_avx_check_avx512, yes))
           AS_IF([test "$op_avx_check_avx512" = "yes"],
                 [AC_MSG_CHECKING([for AVX512 support (no additional flags)])
                  AC_LINK_IFELSE(
                      [AC_LANG_PROGRAM([[#include <immintrin.h>]],
                                       [[
#if defined(__ICC) && !defined(__AVX512F__)
#error "icc needs the -m flags to provide the AVX* detection macros"
#endif
    __m512 vA, vB;
    _mm512_add_ps(vA, vB)
                                       ]])],
                      [op_avx512_support=1
                       AC_MSG_RESULT([yes])],
                      [AC_MSG_RESULT([no])])

                  AS_IF([test $op_avx512_support -eq 0],
                        [AC_MSG_CHECKING([for AVX512 support (with -mavx512f -mavx512bw -mavx512vl -mavx512dq)])
                         op_avx_cflags_save="$CFLAGS"
                         CFLAGS="-mavx512f -mavx512bw -mavx512vl -mavx512dq $CFLAGS"
                         AC_LINK_IFELSE(
                             [AC_LANG_PROGRAM([[#include <immintrin.h>]],
                                              [[
#if defined(__ICC) && !defined(__AVX512F__)
#error "icc needs the -m flags to provide the AVX* detection macros"
#endif
    __m512 vA, vB;
    _mm512_add_ps(vA, vB)
                                       ]])],
                             [op_avx512_support=1
                              MCA_BUILD_OP_AVX512_FLAGS="-mavx512f -mavx512bw -mavx512vl -mavx512dq"
                              AC_MSG_RESULT([yes])],
                             [AC_MSG_RESULT([no])])
                         CFLAGS="$op_avx_cflags_save"
                        ])
                  #
                  # Some combinations of gcc and older assemblers (as) would not
                  # correctly build the code generated by _mm512_loadu_si512.
                  # Screen them out.
                  #
                  AS_IF([test $op_avx512_support -eq 1],
                        [AC_MSG_CHECKING([if _mm512_loadu_si512 generates code that can be compiled])
                         op_avx_cflags_save="$CFLAGS"
                         CFLAGS="$CFLAGS_WITHOUT_OPTFLAGS -O0 $MCA_BUILD_OP_AVX512_FLAGS"
                         AC_LINK_IFELSE(
                             [AC_LANG_PROGRAM([[#include <immintrin.h>]],
                                      [[
#if defined(__ICC) && !defined(__AVX512F__)
#error "icc needs the -m flags to provide the AVX* detection macros"
#endif
    int A[] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16};
    __m512i vA = _mm512_loadu_si512((__m512i*)&(A[1]))
                                      ]])],
                             [AC_MSG_RESULT([yes])],
                             [op_avx512_support=0
                              MCA_BUILD_OP_AVX512_FLAGS=""
                              AC_MSG_RESULT([no])])
                         CFLAGS="$op_avx_cflags_save"
                        ])
                  #
                  # Some PGI compilers do not define _mm512_mullo_epi64. Screen them out.
                  #
                  AS_IF([test $op_avx512_support -eq 1],
                        [AC_MSG_CHECKING([if _mm512_mullo_epi64 generates code that can be compiled])
                         op_avx_cflags_save="$CFLAGS"
                         CFLAGS="$CFLAGS_WITHOUT_OPTFLAGS -O0 $MCA_BUILD_OP_AVX512_FLAGS"
                         AC_LINK_IFELSE(
                             [AC_LANG_PROGRAM([[#include <immintrin.h>]],
                                      [[
#if defined(__ICC) && !defined(__AVX512F__)
#error "icc needs the -m flags to provide the AVX* detection macros"
#endif
    __m512i vA, vB;
    _mm512_mullo_epi64(vA, vB)
                                      ]])],
                             [AC_MSG_RESULT([yes])],
                             [op_avx512_support=0
                              MCA_BUILD_OP_AVX512_FLAGS=""
                              AC_MSG_RESULT([no])])
                         CFLAGS="$op_avx_cflags_save"
                        ])])
           #
           # Check support for AVX2
           #
           AC_CACHE_CHECK([for AVX2 support], op_avx_check_avx2, AS_VAR_SET(op_avx_check_avx2, yes))
           AS_IF([test "$op_avx_check_avx2" = "yes"],
                 [AC_MSG_CHECKING([for AVX2 support (no additional flags)])
                  AC_LINK_IFELSE(
                      [AC_LANG_PROGRAM([[#include <immintrin.h>]],
                              [[
#if defined(__ICC) && !defined(__AVX2__)
#error "icc needs the -m flags to provide the AVX* detection macros"
#endif
    __m256i vA, vB, vC;
    vC = _mm256_and_si256(vA, vB)
                              ]])],
                      [op_avx2_support=1
                       AC_MSG_RESULT([yes])],
                      [AC_MSG_RESULT([no])])
                  AS_IF([test $op_avx2_support -eq 0],
                      [AC_MSG_CHECKING([for AVX2 support (with -mavx2)])
                       op_avx_cflags_save="$CFLAGS"
                       CFLAGS="-mavx2 $CFLAGS"
                       AC_LINK_IFELSE(
                           [AC_LANG_PROGRAM([[#include <immintrin.h>]],
                                   [[
#if defined(__ICC) && !defined(__AVX2__)
#error "icc needs the -m flags to provide the AVX* detection macros"
#endif
    __m256i vA, vB, vC;
    vC = _mm256_and_si256(vA, vB)
                                   ]])],
                           [op_avx2_support=1
                            MCA_BUILD_OP_AVX2_FLAGS="-mavx2"
                            AC_MSG_RESULT([yes])],
                           [AC_MSG_RESULT([no])])
                       CFLAGS="$op_avx_cflags_save"
                       ])
                  #
                  # Some combinations of gcc and older assemblers (as) would not
                  # correctly build the code generated by _mm256_loadu_si256.
                  # Screen them out.
                  #
                  AS_IF([test $op_avx2_support -eq 1],
                        [AC_MSG_CHECKING([if _mm256_loadu_si256 generates code that can be compiled])
                         op_avx_cflags_save="$CFLAGS"
                         CFLAGS="$CFLAGS_WITHOUT_OPTFLAGS -O0 $MCA_BUILD_OP_AVX2_FLAGS"
                         AC_LINK_IFELSE(
                             [AC_LANG_PROGRAM([[#include <immintrin.h>]],
                                      [[
#if defined(__ICC) && !defined(__AVX2__)
#error "icc needs the -m flags to provide the AVX* detection macros"
#endif
    int A[8] = {0, 1, 2, 3, 4, 5, 6, 7};
    __m256i vA = _mm256_loadu_si256((__m256i*)&A)
                                      ]])],
                             [AC_MSG_RESULT([yes])],
                             [op_avx2_support=0
                              MCA_BUILD_OP_AVX2_FLAGS=""
                              AC_MSG_RESULT([no])])
                         CFLAGS="$op_avx_cflags_save"
                        ])])
           #
           # What about early AVX support? The rest of the logic is slightly different as
           # we need to include some of the SSE4.1 and SSE3 instructions. So we first check
           # if we can compile AVX code without a flag. Then we validate that we have support
           # for the SSE4.1 and SSE3 instructions we need. If not, we check for the usage of
           # the AVX flag, and then recheck if we have support for the SSE4.1 and SSE3
           # instructions.
           #
           AC_CACHE_CHECK([for AVX support], op_avx_check_avx, AS_VAR_SET(op_avx_check_avx, yes))
           AS_IF([test "$op_avx_check_avx" = "yes"],
                 [AC_MSG_CHECKING([for AVX support (no additional flags)])
                  AC_LINK_IFELSE(
                      [AC_LANG_PROGRAM([[#include <immintrin.h>]],
                              [[
#if defined(__ICC) && !defined(__AVX__)
#error "icc needs the -m flags to provide the AVX* detection macros"
#endif
    __m256 vA, vB, vC;
    vC = _mm256_add_ps(vA, vB)
                              ]])],
                      [op_avx_support=1
                       AC_MSG_RESULT([yes])],
                      [AC_MSG_RESULT([no])])])
           #
           # Check for SSE4.1 support (only meaningful once AVX compiles without flags)
           #
           AC_CACHE_CHECK([for SSE4.1 support], op_avx_check_sse41, AS_VAR_SET(op_avx_check_sse41, yes))
           AS_IF([test $op_avx_support -eq 1 && test "$op_avx_check_sse41" = "yes"],
                 [AC_MSG_CHECKING([for SSE4.1 support])
                  AC_LINK_IFELSE(
                      [AC_LANG_PROGRAM([[#include <immintrin.h>]],
                              [[
#if defined(__ICC) && !defined(__SSE4_1__)
#error "icc needs the -m flags to provide the AVX* detection macros"
#endif
    __m128i vA, vB;
    (void)_mm_max_epi8(vA, vB)
                              ]])],
                      [op_sse41_support=1
                       AC_MSG_RESULT([yes])],
                      [AC_MSG_RESULT([no])])
                  ])
           #
           # Check for SSE3 support (only meaningful once AVX compiles without flags)
           #
           AC_CACHE_CHECK([for SSE3 support], op_avx_check_sse3, AS_VAR_SET(op_avx_check_sse3, yes))
           AS_IF([test $op_avx_support -eq 1 && test "$op_avx_check_sse3" = "yes"],
                 [AC_MSG_CHECKING([for SSE3 support])
                  AC_LINK_IFELSE(
                      [AC_LANG_PROGRAM([[#include <immintrin.h>]],
                              [[
#if defined(__ICC) && !defined(__SSE3__)
#error "icc needs the -m flags to provide the AVX* detection macros"
#endif
    int A[4] = {0, 1, 2, 3};
    __m128i vA = _mm_lddqu_si128((__m128i*)&A)
                              ]])],
                      [op_sse3_support=1
                       AC_MSG_RESULT([yes])],
                      [AC_MSG_RESULT([no])])
                  ])
           # Second pass: do we need to add the AVX flag?
           # This runs when AVX or one of the SSE4.1 / SSE3 helpers was not
           # detected in the first pass. CFLAGS is saved unconditionally at the
           # top of this branch. The restore at its end therefore never uses an
           # unset or stale saved value. That case occurs when the AVX probe
           # is disabled through the cache.
           AS_IF([test $op_avx_support -eq 0 || test $op_sse41_support -eq 0 || test $op_sse3_support -eq 0],
                 [op_avx_cflags_save="$CFLAGS"
                  AS_IF([test "$op_avx_check_avx" = "yes"],
                        [AC_MSG_CHECKING([for AVX support (with -mavx)])
                         CFLAGS="-mavx $CFLAGS"
                         AC_LINK_IFELSE(
                             [AC_LANG_PROGRAM([[#include <immintrin.h>]],
                                   [[
    __m256 vA, vB, vC;
#if defined(__ICC) && !defined(__AVX__)
#error "icc needs the -m flags to provide the AVX* detection macros"
#endif
    vC = _mm256_add_ps(vA, vB)
                            ]])],
                             [op_avx_support=1
                              MCA_BUILD_OP_AVX_FLAGS="-mavx"
                              op_sse41_support=0
                              op_sse3_support=0
                              AC_MSG_RESULT([yes])],
                             [AC_MSG_RESULT([no])])])

                  AS_IF([test "$op_avx_check_sse41" = "yes" && test $op_sse41_support -eq 0],
                        [AC_MSG_CHECKING([for SSE4.1 support])
                         AC_LINK_IFELSE(
                             [AC_LANG_PROGRAM([[#include <immintrin.h>]],
                                     [[
#if defined(__ICC) && !defined(__SSE4_1__)
#error "icc needs the -m flags to provide the AVX* detection macros"
#endif
    __m128i vA, vB;
    (void)_mm_max_epi8(vA, vB)
                                     ]])],
                             [op_sse41_support=1
                              AC_MSG_RESULT([yes])],
                             [AC_MSG_RESULT([no])])])
                  AS_IF([test "$op_avx_check_sse3" = "yes" && test $op_sse3_support -eq 0],
                        [AC_MSG_CHECKING([for SSE3 support])
                         AC_LINK_IFELSE(
                             [AC_LANG_PROGRAM([[#include <immintrin.h>]],
                                 [[
#if defined(__ICC) && !defined(__SSE3__)
#error "icc needs the -m flags to provide the AVX* detection macros"
#endif
    int A[4] = {0, 1, 2, 3};
    __m128i vA = _mm_lddqu_si128((__m128i*)&A)
                                 ]])],
                             [op_sse3_support=1
                              AC_MSG_RESULT([yes])],
                             [AC_MSG_RESULT([no])])])
                  CFLAGS="$op_avx_cflags_save"])

           AC_LANG_POP([C])
          ])
    AC_DEFINE_UNQUOTED([OMPI_MCA_OP_HAVE_AVX512],
                       [$op_avx512_support],
                       [AVX512 supported in the current build])
    AC_DEFINE_UNQUOTED([OMPI_MCA_OP_HAVE_AVX2],
                       [$op_avx2_support],
                       [AVX2 supported in the current build])
    AC_DEFINE_UNQUOTED([OMPI_MCA_OP_HAVE_AVX],
                       [$op_avx_support],
                       [AVX supported in the current build])
    AC_DEFINE_UNQUOTED([OMPI_MCA_OP_HAVE_SSE41],
                       [$op_sse41_support],
                       [SSE4.1 supported in the current build])
    AC_DEFINE_UNQUOTED([OMPI_MCA_OP_HAVE_SSE3],
                       [$op_sse3_support],
                       [SSE3 supported in the current build])
    # Use the POSIX test string operator; the double-equals form is a
    # bashism that /bin/sh implementations such as dash reject.
    AM_CONDITIONAL([MCA_BUILD_ompi_op_has_avx512_support],
                   [test "$op_avx512_support" = "1"])
    AM_CONDITIONAL([MCA_BUILD_ompi_op_has_avx2_support],
                   [test "$op_avx2_support" = "1"])
    AM_CONDITIONAL([MCA_BUILD_ompi_op_has_avx_support],
                   [test "$op_avx_support" = "1"])
    AM_CONDITIONAL([MCA_BUILD_ompi_op_has_sse41_support],
                   [test "$op_sse41_support" = "1"])
    AM_CONDITIONAL([MCA_BUILD_ompi_op_has_sse3_support],
                   [test "$op_sse3_support" = "1"])
    AC_SUBST(MCA_BUILD_OP_AVX512_FLAGS)
    AC_SUBST(MCA_BUILD_OP_AVX2_FLAGS)
    AC_SUBST(MCA_BUILD_OP_AVX_FLAGS)

    AS_VAR_POPDEF([op_avx_check_avx512])
    AS_VAR_POPDEF([op_avx_check_avx2])
    AS_VAR_POPDEF([op_avx_check_avx])
    AS_VAR_POPDEF([op_avx_check_sse41])
    AS_VAR_POPDEF([op_avx_check_sse3])

    OPAL_VAR_SCOPE_POP
    # Enable this component iff we have at least the most basic form of support
    # for vectorial ISA
    AS_IF([test $op_avx_support -eq 1 || test $op_avx2_support -eq 1 || test $op_avx512_support -eq 1],
          [$1],
          [$2])

])dnl
