#include "device_test.h"

#include <stdio.h>
#include <math.h>

/*
 * Run a single device test and emit the TEST_DONE marker afterwards.
 *
 * All arguments after `test` are forwarded unchanged to the test callback.
 * The marker is printed regardless of the callback's outcome, so a host-side
 * reader can always detect that the test finished.
 *
 * Returns whatever the test callback returned.
 */
int device_test( test_function test, uint32_t binding, void * task, void * input_a, void * input_b, void * expected, float epsilon ) {
    const int result = test( binding, task, input_a, input_b, expected, epsilon );

    /* TEST_DONE is a macro from device_test.h (presumably a string literal). */
    printf( TEST_DONE );

    return result;
}

/*
 * Check that `value` is within `abs_err` of `expected`.
 *
 * Exactly equal values always pass. This covers matching infinities, whose
 * difference is NaN and would otherwise fail the tolerance comparison.
 * A NaN on either side (other than that case) fails.
 *
 * On failure, prints a TEST_FAIL line showing the absolute difference and the
 * tolerance it exceeded.
 *
 * Returns non-zero on pass, 0 on failure.
 */
int assert_float_near( float expected, float value, float abs_err ) {
    float abs_diff = fabsf( expected - value );
    int pass = ( expected == value ) || ( abs_diff <= abs_err );
    if ( ! pass ) {
        /* On failure the difference exceeds the tolerance, hence '>'. */
        printf( TEST_FAIL " assert_float_near |%e - %e| = %e > %e\n", expected, value, abs_diff, abs_err );
    }

    return pass;
}

/*
 * Check that two 32-bit words are identical.
 *
 * On failure, prints a newline-terminated TEST_FAIL line with both values in
 * hex (0x-prefixed, matching the file's other hex output).
 *
 * Returns non-zero on pass, 0 on failure.
 */
int assert_eq( uint32_t expected, uint32_t value ) {
    int pass = expected == value;
    if ( ! pass ) {
        printf( TEST_FAIL " assert_eq 0x%" PRIx32 " != 0x%" PRIx32 "\n", expected, value );
    }

    return pass;
}

/*
 * Print `len` elements of `data` as a Python list literal of floats, e.g.
 *   py_float_data=[1.000000e+00,2.500000e-01]
 *
 * Each element's `.value` member is printed with "%e". An empty buffer
 * (len == 0) prints "py_float_data=[]" and never dereferences `data`.
 */
void print_float_data_for_python( const float_word * data, uint32_t len ) {
    printf( "py_float_data=[" );
    for ( uint32_t i = 0; i < len; ++i ) {
        /* Comma-separate every element after the first. */
        printf( "%s%e", ( i == 0 ) ? "" : ",", data[ i ].value );
    }
    printf( "]\n" );
}

/*
 * Print `len` elements of `data` as a Python list literal of hex integers, e.g.
 *   py_hex_data=[0x3f800000,0x3e800000]
 *
 * Each element's `.word` member is printed with PRIx32. An empty buffer
 * (len == 0) prints "py_hex_data=[]" and never dereferences `data`.
 */
void print_hex_data_for_python( const float_word * data, uint32_t len ) {
    printf( "py_hex_data=[" );
    for ( uint32_t i = 0; i < len; ++i ) {
        /* Comma-separate every element after the first. */
        printf( "%s0x%" PRIx32, ( i == 0 ) ? "" : ",", data[ i ].word );
    }
    printf( "]\n" );
}

/*
 * Emit a full result record as Python assignments, in this order:
 *   1. py_cycle_count = the given cycle count
 *   2. py_float_data  = `data` printed via each element's `.value`
 *   3. py_hex_data    = `data` printed via each element's `.word`
 */
void print_results_for_python( uint32_t cycle_count, const float_word * data, uint32_t len ) {
    const float_word * const buffer = data;
    const uint32_t count = len;

    printf( "py_cycle_count=%" PRIu32 "\n", cycle_count );

    /* Same buffer, two renderings: human-readable floats, then raw words. */
    print_float_data_for_python( buffer, count );
    print_hex_data_for_python( buffer, count );
}