/*
 *  Joris, Mar 2006, Jul 2006.
 *
 *  Test different low-level implementations of the 3n+1 iteration.
 *
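 *  Compiling (select exactly one implementation by defining one of the
 *  IMPL_* macros listed in the results below, e.g. -DIMPL_C):
 *    i386:    gcc-3.4 -O2 -march=athlon -mtune=athlon -DIMPL_xxx -o collatz collatz.c
 *    x86_64:  gcc-3.4 -O2 -DIMPL_xxx -o collatz collatz.c
 *
 *  Running:
 *    collatz RANGE_START RANGE_LEN
 *  tests RANGE_LEN consecutive blocks of 2^RANGE_BITS starting values,
 *  beginning at RANGE_START * 2^RANGE_BITS.
 *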
 *  Results on an Athlon64 3400+ 2.4 GHz ("parameters" are the two
 *  command-line arguments, range start and range length):
 *    arch    code                 parameters   usertime
 *    i386    IMPL_C               0 16          65.6
 *    i386    IMPL_ASM_BASE        0 16          63.4
 *    i386    IMPL_ASM_FAST        0 16          49.9
 *    x86_64  IMPL_C               0 16          43.3
 *    x86_64  IMPL_ASM64_FAST      0 16          32.5
 *    i386    IMPL_ASM_BASE_128    1000000 16    82.8
 *    i386    IMPL_ASM_FAST_128    1000000 16    70.2
 *    x86_64  IMPL_C_128           0 16          92.5
 *    x86_64  IMPL_C_128           1000000 16    93.7
 *    x86_64  IMPL_ASM64_FAST_128  1000000 16    43.3
 */

#include <stdlib.h>
#include <stdio.h>
#include <assert.h>
#include <limits.h>

/*
 * Each range unit given on the command line covers 2^RANGE_BITS
 * consecutive starting values (see test_range()).
 */
#define RANGE_BITS 28

/*
 * test_range() shifts the unit index left by RANGE_BITS and computes
 * 1 << (RANGE_BITS - 2) as an int.  It asserts 2 < RANGE_BITS < 32 at run
 * time, but that assert disappears under -DNDEBUG, so reject unusable
 * values here at compile time as well.
 */
#if RANGE_BITS <= 2 || RANGE_BITS >= 32
#error "RANGE_BITS must be between 3 and 31"
#endif

/*
 * The assembler implementations are tied to one architecture; refuse to
 * build an implementation for the wrong target.  The implementation is
 * chosen by defining one IMPL_* macro on the compiler command line.
 */
#if ( defined(IMPL_ASM_BASE) || defined(IMPL_ASM_FAST) || \
      defined(IMPL_ASM_BASE_128)||defined(IMPL_ASM_FAST_128) ) && \
    !defined(__i386)
#error "Selected code only works on i386 architecture"
#endif

#if ( defined(IMPL_C_128) || \
      defined(IMPL_ASM64_FAST) || defined(IMPL_ASM64_FAST_128) ) && \
    !defined(__x86_64)
#error "Selected code only works on X86_64 architecture"
#endif

/* Architecture name printed in the banner line by main(). */
#if defined(__x86_64)
#define ARCH_NAME "x86_64"
#elif defined(__i386)
#define ARCH_NAME "i386"
#else
#define ARCH_NAME "unknown"
#endif

/*
 * Per-implementation evaluation width (EVALBITS) and storage for maxv,
 * the largest working value seen so far.  maxv is never reset, so it
 * accumulates over every number tested; main() prints it after each
 * range unit, as maxv_hi:maxv_lo for 128-bit builds and maxv otherwise.
 */
#ifdef IMPL_C_128
#define EVALBITS 128
/* 128-bit evaluation in C uses 128-bit integers */
/* (GCC's TI machine mode: a 16-byte integer; needs a 64-bit target) */
typedef unsigned eval_t __attribute__ ((mode (TI)));
static eval_t maxv = 0;
/* 64-bit halves of maxv, so main() can print it with %llu */
#define maxv_lo ((unsigned long long)maxv)
#define maxv_hi ((unsigned long long)(maxv>>64))
#endif

#if defined(IMPL_ASM_BASE_128) || defined(IMPL_ASM_FAST_128) || \
    defined(IMPL_ASM64_FAST_128)
#define EVALBITS 128
/* 128-bit evaluation in assembler uses a pair of 64-bit integers */
static unsigned long long maxv_lo = 0, maxv_hi = 0;
#endif

#if defined(IMPL_C) || defined(IMPL_ASM_BASE) || defined(IMPL_ASM_FAST) || \
    defined(IMPL_ASM64_FAST)
#define EVALBITS 64
/* 64-bit evaluation uses 64-bit integers */
typedef unsigned long long eval_t;
static eval_t maxv = 0;
#endif

/* largest number of steps until convergence */
/* ("convergence" here: the working value dropped below its starting value;
   updated and printed by test_range()) */
static unsigned int maxstep = 0;


#if defined(IMPL_C) || defined(IMPL_C_128)
/* Pure C code for 64-bit or 128-bit */
#define IMPL_NAME "C"
/*
 * Iterate the shortened 3n+1 map starting from x until the working value
 * drops below x.
 *
 *   odd v:  v -> (3v+1)/2, computed as v + (v >> 1) + 1 so that the
 *           larger intermediate 3v+1 is never formed
 *   even v: raise maxv to v if v is larger, then v -> v/2
 *
 * *step receives the number of iterations performed.  If an odd step
 * overflows eval_t, *step is set to UINT_MAX, the same value the
 * assembler variants store with "movl $-1", instead of continuing with
 * a wrapped (wrong) value.
 *
 * Does not terminate for x = 0 or x = 1 (the trajectory never drops
 * below x); test_range() never passes those values.
 */
static inline void test_number(unsigned long long x, unsigned int *step)
{
        unsigned int n = 0;
        eval_t v = x;
        while (1) {
                n++;
                if (v & 1) {
                        eval_t next = v + (v >> 1) + 1;
                        /*
                         * The sum is below 2 * 2^EVALBITS, so it wraps at
                         * most once, and a wrapped result is always
                         * smaller than v.
                         */
                        if (next < v) {
                                n = UINT_MAX;
                                break;
                        }
                        /* an odd step increases v: no need to compare with x */
                        v = next;
                } else {
                        if (v > maxv)
                                maxv = v;
                        v = v >> 1;
                        if (v < x)
                                break;
                }
        }
        *step = n;
}
#endif


#if defined(IMPL_ASM_BASE)
/* Baseline i386 assembler code for 64-bit evaluation */
#define IMPL_NAME "ASM_BASE"
/*
 * Iterate the shortened 3n+1 map from x until the working value drops
 * below x, counting iterations in *step.
 *   odd v:  v -> v + (v >> 1) + 1, i.e. (3v+1)/2; a carry out of the
 *           64-bit add is treated as overflow and sets *step to
 *           (unsigned)-1
 *   even v: raise the global maxv to v if larger, then v -> v/2 and
 *           stop once v < x
 */
static inline void test_number(unsigned long long x, unsigned int *step)
{
        int dum0, dum1, dum2, dum3;
        *step = 0;
        /*
         * edx:eax  contains working value
         * ecx:ebx  used for temporary values
         *      %0  is the step counter (*step)
         *      %1  is maxv itself (offsettable memory: low dword at %1,
         *          high dword at 4+%1).  The asm reads the old value, so
         *          it is an in/out ("+o") operand, not output-only ("=o")
         *   %8:%7  is the starting number (x)
         */
        asm("\n"
          " movl %7, %%eax \n"
          " movl %8, %%edx \n"
                /* main iteration loop */
          "1: \n"
          " incl %0 \n"
          " testl $1, %%eax \n"
          " jz 2f \n"
                /* handle odd number; shift right and add */
                /* (RCR leaves the odd low bit in CF, supplying the +1) */
          " movl %%edx, %%ecx \n"
          " shrl $1, %%edx \n"
          " movl %%eax, %%ebx \n"
          " rcrl $1, %%eax \n"
          " adcl %%ebx, %%eax \n"
          " adcl %%ecx, %%edx \n"
                /* an odd step increases v, so no compare with x is needed */
          " jnc 1b \n"
                /* overflow */
          " movl $-1, %0 \n"
          " jmp 5f \n"
          "2: \n"
                /* handle even number; compare to maxv */
          " cmpl 4+%1, %%edx \n"
          " jb 3f \n"
          " ja 4f \n"
          " cmpl %1, %%eax \n"
          " jbe 3f \n"
          "4: \n"
                /* working number is larger than maxv; update maxv */
          " movl %%edx, 4+%1 \n"
          " movl %%eax, %1 \n"
          "3: \n"
                /* shift right */
          " shrl $1, %%edx \n"
          " rcrl $1, %%eax \n"
                /* compare to x; loop while greater or equal */
          " cmpl %8, %%edx \n"
          " ja 1b \n"
          " jb 5f \n"
          " cmpl %7, %%eax \n"
          " jnb 1b \n"
          "5: "
         : "=&r" (*step), "+o" (maxv),
           "=&a" (dum0), "=&b" (dum1), "=&c" (dum2), "=&d" (dum3)
         : "0" (*step), "g" ((unsigned int)x), "g" ((unsigned int)(x >> 32)) );
}
#endif


#if defined(IMPL_ASM_FAST)
/* Optimized i386 assembler code for 64-bit evaluation */
#define IMPL_NAME "ASM_FAST"
/*
 * Iterate the shortened 3n+1 map from x until the working value drops
 * below x, counting iterations in *step ((unsigned)-1 on 64-bit
 * overflow).  Unlike the baseline, every step computes
 * v -> (v >> 1) + (odd ? v + 1 : 0) without branching on parity, and the
 * value produced by each step is compared against the global maxv.
 */
static inline void test_number(unsigned long long x, unsigned int *step)
{
        int dum0, dum1, dum2, dum3;
        *step = 0;
        /*
         * edx:eax  contains working value
         * ecx:ebx  used for temporary values
         *      %0  is the step counter (*step)
         *      %1  is maxv itself (offsettable memory: low dword at %1,
         *          high dword at 4+%1).  The asm reads the old value, so
         *          it is an in/out ("+o") operand, not output-only ("=o")
         *   %8:%7  is the starting number (x)
         *      %9  fixed to zero
         * odd/even branching is avoided through use of CMOV
         * 64-bit compare uses SBB to avoid extra branching
         */
        asm("\n"
          " movl %7, %%eax \n"
          " movl %8, %%edx \n"
                /* main iteration loop */
          ".p2align 4 \n"
          "1: \n"
          " incl %0 \n"
                /* copy working value; shift right; zero temporary if even */
                /* (CMOV does not touch flags, so CF = old low bit survives) */
          " movl %%eax, %%ebx \n"
          " movl %%edx, %%ecx \n"
          " shrl $1, %%edx \n"
          " rcrl $1, %%eax \n"
          " cmovncl %9, %%ebx \n"
          " cmovncl %9, %%ecx \n"
                /* add temporary to working value; check for overflow */
          " adcl %%ebx, %%eax \n"
          " adcl %%ecx, %%edx \n"
          " jc 4f \n"
                /* compare to maxv */
          " cmpl 4+%1, %%edx \n"
          " jb 2f \n"
          " ja 3f \n"
          " cmpl %1, %%eax \n"
          " ja 3f \n"
          "2: \n"
                /* compare to x; loop while greater or equal */
                /* (CMP+SBB: CF is the borrow of the 64-bit v - x) */
          " movl %%edx, %%ecx \n"
          " cmpl %7, %%eax \n"
          " sbbl %8, %%ecx \n"
          " jnb 1b \n"
          " jmp 5f \n"
          "3: \n"
                /* update maxv */
          " movl %%edx, 4+%1 \n"
          " movl %%eax, %1 \n"
          " jmp 2b \n"
          "4: \n"
                /* overflow */
          " movl $-1, %0 \n"
          "5: "
         : "=&r" (*step), "+o" (maxv),
           "=&a" (dum0), "=&b" (dum1), "=&c" (dum2), "=&d" (dum3)
         : "0" (*step), "g" ((unsigned int)x), "g" ((unsigned int)(x >> 32)), "r" (0) );
}
#endif


#if defined(IMPL_ASM_BASE_128)
/* Baseline i386 assembler code for 128-bit evaluation */
#define IMPL_NAME "ASM_BASE_128"
/*
 * 128-bit version of the baseline: the working value starts as 0:0:x and
 * is iterated with the shortened 3n+1 map until it drops below x.  Each
 * even value is compared against maxv_hi:maxv_lo before halving.  *step
 * receives the iteration count, or (unsigned)-1 if an odd step carries
 * out of 128 bits.
 */
static inline void test_number(unsigned long long x, unsigned int *step)
{
        int dum0, dum1, dum2, dum3, dum4, dum5, dum6, dum7;
        *step = 0;
        /*
         * edx:ecx:ebx:eax  contains working value
         *    %10:%9:%8:%7  used for temporary values ("g": register or
         *                  memory, whichever the compiler picks)
         *              %0  is the step counter (*step)
         *           %2:%1  are maxv_hi:maxv_lo themselves (offsettable
         *                  memory, low dword at %n, high dword at 4+%n);
         *                  read and written, hence "+o"
         *         %13:%12  is the starting number (x)
         */
        asm("\n"
          " movl %12, %%eax \n"
          " movl %13, %%ebx \n"
          " xorl %%ecx, %%ecx \n"
          " xorl %%edx, %%edx \n"
                /* main iteration loop */
          "1: \n"
          " incl %0 \n"
          " testl $1, %%eax \n"
          " jz 2f \n"
                /* handle odd number; shift right and add */
          " movl %%edx, %10 \n"
          " movl %%ecx, %9 \n"
          " movl %%ebx, %8 \n"
          " movl %%eax, %7 \n"
          " shrl $1, %%edx \n"
          " rcrl $1, %%ecx \n"
          " rcrl $1, %%ebx \n"
          " rcrl $1, %%eax \n"
          " adcl %7, %%eax \n"
          " adcl %8, %%ebx \n"
          " adcl %9, %%ecx \n"
          " adcl %10, %%edx \n"
                /* an odd step increases v, so no compare with x is needed */
          " jnc 1b \n"
                /* overflow */
          " movl $-1, %0 \n"
          " jmp 5f \n"
          "2: \n"
                /* handle even number; compare to maxv */
          " cmpl 4+%2, %%edx \n"
          " jb 4f \n"
          " ja 3f \n"
          " cmpl %2, %%ecx \n"
          " jb 4f \n"
          " ja 3f \n"
          " cmpl 4+%1, %%ebx \n"
          " jb 4f \n"
          " ja 3f \n"
          " cmpl %1, %%eax \n"
          " jbe 4f \n"
          "3: \n"
                /* working number is larger than maxv; update maxv */
          " movl %%edx, 4+%2 \n"
          " movl %%ecx, %2 \n"
          " movl %%ebx, 4+%1 \n"
          " movl %%eax, %1 \n"
          "4: \n"
                /* shift right */
          " shrl $1, %%edx \n"
          " rcrl $1, %%ecx \n"
          " rcrl $1, %%ebx \n"
          " rcrl $1, %%eax \n"
                /* compare to x; loop while greater or equal */
                /* (x fits in 64 bits, so any set upper dword means v > x) */
          " or %%edx, %%edx \n"
          " jnz 1b \n"
          " or %%ecx, %%ecx \n"
          " jnz 1b \n"
          " cmpl %13, %%ebx \n"
          " ja 1b \n"
          " jb 5f \n"
          " cmpl %12, %%eax \n"
          " jnb 1b \n"
          "5: "
         : "=&g" (*step), "+o" (maxv_lo), "+o" (maxv_hi),
           "=&a" (dum0), "=&b" (dum1), "=&c" (dum2), "=&d" (dum3),
           "=&g" (dum4), "=&g" (dum5), "=&g" (dum6), "=&g" (dum7)
         : "0" (*step), "g" ((unsigned int)x), "g" ((unsigned int)(x >> 32)) );
}
#endif


#if defined(IMPL_ASM_FAST_128)
/* Optimized i386 assembler code for 128-bit evaluation */
#define IMPL_NAME "ASM_FAST_128"
/*
 * 128-bit shortened 3n+1 iteration with reorganized basic blocks.  The
 * working value starts as 0:0:x.  The result of every odd step is
 * compared against maxv_hi:maxv_lo; even steps only halve the value and
 * check whether it has dropped below x.  *step receives the iteration
 * count, or (unsigned)-1 if an odd step carries out of 128 bits.
 */
static inline void test_number(unsigned long long x, unsigned int *step)
{
        int dum0, dum1, dum2, dum3, dum4, dum5, dum6, dum7;
        *step = 0;
        /*
         * edx:ecx:ebx:eax  contains working value
         *    %10:%9:%8:%7  used for temporary values (%7 in a register,
         *                  the others register or memory)
         *              %0  is the step counter (*step)
         *           %2:%1  are maxv_hi:maxv_lo themselves (offsettable
         *                  memory, low dword at %n, high dword at 4+%n);
         *                  read and written, hence "+o"
         *         %13:%12  is the starting number (x)
         * basic blocks aligned and reorganized
         * branch elimination seems infeasible due to high register pressure
         */
        asm("\n"
          " movl %12, %%eax \n"
          " movl %13, %%ebx \n"
          " xorl %%ecx, %%ecx \n"
          " xorl %%edx, %%edx \n"
                /* main iteration loop */
          ".p2align 4 \n"
          "1: \n"
          " incl %0 \n"
          " testb $1, %%al \n"
          " jz 2f \n"
                /* handle odd number; shift right and add */
          " movl %%edx, %10 \n"
          " movl %%ecx, %9 \n"
          " movl %%ebx, %8 \n"
          " movl %%eax, %7 \n"
          " shrl $1, %%edx \n"
          " rcrl $1, %%ecx \n"
          " rcrl $1, %%ebx \n"
          " rcrl $1, %%eax \n"
          " adcl %7, %%eax \n"
          " adcl %8, %%ebx \n"
          " adcl %9, %%ecx \n"
          " adcl %10, %%edx \n"
          " jc 4f \n"
                /* compare to maxv */
          " cmpl 4+%2, %%edx \n"
          " jb 1b \n"
          " ja 3f \n"
          " cmpl %2, %%ecx \n"
          " jb 1b \n"
          " ja 3f \n"
          " cmpl 4+%1, %%ebx \n"
          " jb 1b \n"
          " ja 3f \n"
          " cmpl %1, %%eax \n"
                /* upper dwords equal: update only if the low dword is
                   strictly above maxv's (jbe skips otherwise) */
          " jbe 1b \n"
          " jmp 3f \n"
          "2: \n"
                /* handle even number; shift right */
          " shrl $1, %%edx \n"
          " rcrl $1, %%ecx \n"
          " rcrl $1, %%ebx \n"
          " rcrl $1, %%eax \n"
                /* compare to x; loop while greater or equal */
                /* (x fits in 64 bits, so any set upper dword means v > x) */
          " or %%edx, %%edx \n"
          " jnz 1b \n"
          " or %%ecx, %%ecx \n"
          " jnz 1b \n"
          " cmpl %13, %%ebx \n"
          " ja 1b \n"
          " jb 5f \n"
          " cmpl %12, %%eax \n"
          " jnb 1b \n"
          " jmp 5f \n"
          "3: \n"
                /* working number is larger than maxv; update maxv */
          " movl %%edx, 4+%2 \n"
          " movl %%ecx, %2 \n"
          " movl %%ebx, 4+%1 \n"
          " movl %%eax, %1 \n"
          " jmp 1b \n"
          "4: \n"
                /* overflow */
          " movl $-1, %0 \n"
          "5: "
         : "=&g" (*step), "+o" (maxv_lo), "+o" (maxv_hi),
           "=&a" (dum0), "=&b" (dum1), "=&c" (dum2), "=&d" (dum3),
           "=&r" (dum4), "=&g" (dum5), "=&g" (dum6), "=&g" (dum7)
         : "0" (*step), "g" ((unsigned int)x), "g" ((unsigned int)(x >> 32)) );
}
#endif


#if defined(IMPL_ASM64_FAST)
/* Optimized x86_64 assembler code for 64-bit evaluation */
#define IMPL_NAME "ASM64_FAST"
/*
 * Iterate the shortened 3n+1 map from x until the working value drops
 * below x.  *step receives the iteration count, or (unsigned)-1 if an
 * odd step carries out of 64 bits.  At the start of every iteration
 * (including the first, which sees x itself) maxv is raised to the
 * working value if the working value is larger.
 */
static inline void test_number(unsigned long long x, unsigned int *step)
{
        long dum0, dum1;
        unsigned int n = 0;
        /*
         * rax  contains working value
         *  %3  used for temporary values
         *  %0  is the step counter (n)
         *  %1  is maxv
         *  %6  is the starting number (x)
         *  %7  fixed to zero
         * most data dependent branching is avoided through use of CMOV
         */
        /* n and maxv are kept in registers: loaded through the "0"/"1"
           matching inputs and written back through outputs %0/%1 */
        asm("\n"
          " movq %6, %%rax \n"
                /* main iteration loop */
          ".p2align 4 \n"
          "1: \n"
          " incl %0 \n"
                /* compare to maxv and conditionally update maxv */
          " cmpq %1, %%rax \n"
          " cmovaq %%rax, %1 \n"
                /* copy working value; shift right; zero temporary if even */
                /* (SHR leaves the old low bit in CF; CMOV does not touch
                   flags, so ADC below adds v + 1 for odd v, 0 for even) */
          " movq %%rax, %3 \n"
          " shrq $1, %%rax \n"
          " cmovncq %7, %3 \n"
                /* add temporary to working value; check for overflow */
          " adcq %3, %%rax \n"
          " jc 4f \n"
                /* compare to x; loop while greater or equal */
          " cmpq %6, %%rax \n"
          " jnb 1b \n"
          " jmp 5f \n"
          "4: \n"
                /* overflow */
          " movl $-1, %0 \n"
          "5:"
         : "=&r" (n), "=&r" (maxv), "=&a" (dum0), "=&r" (dum1)
         : "0" (n), "1" (maxv), "r" (x), "r" (0ULL) );
        *step = n;
}
#endif


#ifdef IMPL_ASM64_FAST_128
/* Optimized x86_64 assembler code for 128-bit evaluation */
#define IMPL_NAME "ASM64_FAST_128"
/*
 * 128-bit shortened 3n+1 iteration: rdx:rax starts as 0:x and is iterated
 * until it drops below x.  At the start of every iteration
 * maxv_hi:maxv_lo is replaced by the working value when the working value
 * is greater than or equal to it.  *step receives the iteration count, or
 * (unsigned)-1 if an odd step carries out of 128 bits.
 */
static inline void test_number(unsigned long long x, unsigned int *step)
{
        long dum0, dum1, dum2, dum3;
        unsigned int n = 0;
        /*
         * rdx:rax  contains working value
         *   %6:%5  used for temporary values
         *      %0  is the step counter (n)
         *   %2:%1  is maxv
         *     %10  is the starting number (x)
         *     %11  fixed to zero
         * most data dependent branching is avoided through use of CMOV
         * 128-bit compares use SBB or CMOV to avoid extra branching
         */
        /* n, maxv_lo and maxv_hi are kept in registers: loaded through
           the "0"/"1"/"2" matching inputs and written back afterwards */
        asm("\n"
          " movq %10, %%rax \n"
                /* (a 32-bit XOR also clears the upper half of rdx) */
          " xorl %%edx, %%edx \n"
          ".p2align 4 \n"
                /* main iteration loop */
          "1: \n"
          " incl %0 \n"
                /* compare to maxv and conditionally update maxv */
                /* (CMP+SBB: CF is the borrow of the 128-bit v - maxv; the
                   MOVs that follow leave flags intact and copy v into the
                   temporary %6:%5) */
          " movq %%rdx, %6 \n"
          " cmpq %1, %%rax \n"
          " sbbq %2, %6 \n"
          " movq %%rax, %5 \n"
          " movq %%rdx, %6 \n"
          " cmovnbq %%rax, %1 \n"
          " cmovnbq %%rdx, %2 \n"
                /* shift working value right (already copied to temporary) */
          " shrq $1, %%rdx \n"
          " rcrq $1, %%rax \n"
                /* zero temporary if even; add temporary to working value */
          " cmovncq %11, %5 \n"
          " cmovncq %11, %6 \n"
          " adcq %5, %%rax \n"
          " adcq %6, %%rdx \n"
          " jc 4f \n"
                /* compare working value to x; loop while greater or equal */
                /* (ZF comes from the last ADC, i.e. is set iff rdx == 0; then
                   rax is compared with x, otherwise with 0, which always
                   loops) */
          " movq %11, %5 \n"
          " cmovzq %10, %5 \n"
          " cmpq %5, %%rax \n"
          " jnb 1b \n"
          " jmp 5f \n"
          "4: \n"
                /* overflow */
          " movl $-1, %0 \n"
          "5:"
         : "=&r" (n), "=&r" (maxv_lo), "=&r" (maxv_hi),
           "=&a" (dum0), "=&d" (dum1), "=&r" (dum2), "=&r" (dum3)
         : "0" (n), "1" (maxv_lo), "2" (maxv_hi), "r" (x), "r" (0ULL) );
        *step = n;
}
#endif


/* Test a range of numbers */
static void test_range(unsigned int i)
{
        assert(RANGE_BITS > 2 && RANGE_BITS < 32);
        unsigned long long x = (((unsigned long long)i) << RANGE_BITS) + 3;
        unsigned int cnt = 1 << (RANGE_BITS - 2);
        for ( ; cnt > 0; cnt--, x += 4) {
                unsigned int step;
                test_number(x, &step);
                if (step > maxstep) {
                        maxstep = step;
                        printf("x=%-16llu maxstep=%-6u\n", x, maxstep);
                }
        }
}


/*
 * Parse one range argument: a non-empty decimal number below INT_MAX.
 * On success store it in *out and return 1; otherwise return 0.
 */
static int parse_range_arg(const char *s, unsigned int *out)
{
        char *end;
        unsigned long v = strtoul(s, &end, 10);

        /*
         * Reject empty or non-numeric input and trailing junk.  Overflow
         * (strtoul returns ULONG_MAX) and nonzero negative input (which
         * strtoul negates into a large unsigned value) fail the bound.
         */
        if (end == s || *end != '\0' || v >= INT_MAX)
                return 0;
        *out = (unsigned int)v;
        return 1;
}

/*
 * Usage: collatz RANGE_START RANGE_LEN
 *
 * Tests range units RANGE_START .. RANGE_START+RANGE_LEN-1 (each unit is
 * 2^RANGE_BITS starting values, see test_range()).  After each unit it
 * prints the end of the tested interval, the step record so far and the
 * largest value seen (maxv).  Both arguments must be below INT_MAX and
 * RANGE_LEN must be positive.  Invalid arguments print a usage message
 * to stderr and exit with EXIT_FAILURE.
 */
int main(int argc, const char **argv)
{
        unsigned int i = 0;
        unsigned int range_start, range_len;
        printf("impl=%s bits=%d arch=%s\n", IMPL_NAME, EVALBITS, ARCH_NAME);
        /*
         * Explicit checks rather than assert(): asserts vanish under
         * -DNDEBUG, which would leave argv[1]/argv[2] unchecked.
         */
        if (argc != 3 ||
            !parse_range_arg(argv[1], &range_start) ||
            !parse_range_arg(argv[2], &range_len) ||
            range_len == 0) {
                fprintf(stderr,
                        "usage: %s RANGE_START RANGE_LEN\n"
                        "  both below %d, RANGE_LEN > 0\n",
                        argc > 0 ? argv[0] : "collatz", INT_MAX);
                return EXIT_FAILURE;
        }
        /* both operands are < INT_MAX, so the sum cannot wrap in unsigned */
        printf("test range: %u ... %u\n", range_start, range_start+range_len);
        for (i = range_start; i < range_start + range_len; i++) {
                test_range(i);
                printf("x=%-16llu maxstep=%-6u ",
                  ((unsigned long long)(i+1)) << RANGE_BITS, maxstep);
#if EVALBITS == 128
                printf("maxv=%llu:%llu\n", maxv_hi, maxv_lo);
#elif EVALBITS == 64
                printf("maxv=%llu\n", maxv);
#else
#error "Hmmm?"
#endif
        }
        return 0;
}

/* end */