読者です 読者をやめる 読者になる 読者になる

MPFR を使ってみる

C/C++

参考 GNU MPFR 3.1.5

高精度の計算が必要になったので MPFR を使ってみた。 GMP を土台に初等関数を定義してあるので、素で GMP を使うよりは便利かと思う。

2 の平方根を 200桁求めてみる。最後の一桁を丸めるのに MPFR_RNDN (近いほうの値、誤差は 10^(-n) * (1/2) 以下)で統一したが、丸めかたは他にもあり参考サイトを参照。

#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
#include <math.h>

#include <gmp.h>
#include <mpfr.h>

#define PRECISION_BITS 670 // 201 digits
#define BASE 10

mpfr_t MPFR_TRUE, MPFR_FALSE;

#define make_num(var,str) mpfr_t var; mpfr_init_set_str(var, str, BASE, MPFR_RNDN);

// 動的配列はあったほうが便利
// size の配列を確保して 0 初期化する
mpfr_t *mpfr_array(int32_t size)
{
    mpfr_t *v = malloc(sizeof(mpfr_t)*size);
    for (int32_t i=0; i<size; i++) mpfr_init_set_d(v[i], 0.0, MPFR_RNDN);
    return v;
}

// メモリ開放
void mpfr_free_array(mpfr_t *array, int32_t size)
{
    for (int32_t i=0; i<size; i++) mpfr_clear(array[i]);
    free(array);
}

// 配列プリント
void mpfr_print_array(mpfr_t *v, int32_t size)
{
    printf("[");
    for (int32_t i=0; i<size; i++) mpfr_printf("%.200RNf,", v[i]);
    printf("\b]");
}

// 区間分割 [a,b] を n 等分した区間を返す
mpfr_t *divide_interval(mpfr_t a, mpfr_t b, int32_t n)
{
    mpfr_t *v = mpfr_array(n+1);
    
    make_num(step, "0.0");
    make_num(d, "0.0");

    mpfr_sub(step, b, a, MPFR_RNDN);                 // step = b-a
    mpfr_div_ui(step, step, (uint64_t)n, MPFR_RNDN); // step = (b-a)/n

    mpfr_set(v[0], a, MPFR_RNDN);
    for (int32_t i=1; i<n; i++) {
        mpfr_mul_ui(d, step, (uint64_t)i, MPFR_RNDN); // d = step * i
        mpfr_add(v[i], a, d, MPFR_RNDN);  // v[i] = a + step * i
    }
    mpfr_set(v[n], b, MPFR_RNDN);

    mpfr_clears(step,d,(mpfr_ptr)0);
    
    return v;
}

// f(a), f(b) の符号が相異なるような区間 [a,b] を見つける。
// 返り値は [a,b,flag] 
// flag は見つかったら true, 見つからなければ false 
mpfr_t *find_subinterval(void (*f)(mpfr_t result, mpfr_t arg), mpfr_t *interval, int32_t size)
{
    mpfr_t *v = mpfr_array(3); // v = [a,b,flag]
    mpfr_t prev, tmp;
    mpfr_inits2(PRECISION_BITS, prev,tmp, (mpfr_ptr)0);
    f(tmp, interval[0]);
    mpfr_set(prev, tmp, MPFR_RNDN);
    mpfr_set(v[2], MPFR_FALSE, MPFR_RNDN); // set flag
    for (int32_t i=1; i<size; i++) {
        f(tmp, interval[i]);
        if (mpfr_sgn(prev) != mpfr_sgn(tmp)) {
            mpfr_set(v[0], interval[i-1], MPFR_RNDN);
            mpfr_set(v[1], interval[i], MPFR_RNDN);
            mpfr_set(v[2], MPFR_TRUE, MPFR_RNDN); // set flag
            break;
        } else {
            mpfr_set(prev, tmp, MPFR_RNDN);
        }
    }

    mpfr_clears(prev, tmp, (mpfr_ptr)0);
    
    return v;
}

// f の根を再帰的に求める
// PARTITION は滑らかな関数なら 2 で良いが、振動の激しい関数の場合は 10 くらいを試してみると良い
void find_root(mpfr_t result, void (*f)(mpfr_t result, mpfr_t arg), mpfr_t min, mpfr_t max, mpfr_t epsilon)
{
#define PARTITION 2

    mpfr_t a,b, tmp;
    mpfr_inits2(PRECISION_BITS, a,b,tmp, (mpfr_ptr)0);
    mpfr_add(tmp, min, epsilon, MPFR_RNDN);

    if (mpfr_cmp(max, tmp) < 0) {
        mpfr_add(result, min, max, MPFR_RNDN); // result = max+min
        mpfr_div_ui(result, result, 2, MPFR_RNDN); // result = (max+min)/2
        return;
    }

    mpfr_t *initerval = divide_interval(min,max,PARTITION);
    mpfr_t *sub_interval = find_subinterval(f, initerval, PARTITION+1);

    if (mpfr_cmp(sub_interval[2], MPFR_FALSE) == 0) {
        mpfr_printf("find_root : function has same sign over interval [%.200RNf, %.200RNf].\n", min, max);
        mpfr_free_array(initerval, PARTITION+1);
        mpfr_free_array(sub_interval, 3);
        return;
    }
    
    mpfr_set(a, sub_interval[0], MPFR_RNDN);
    mpfr_set(b, sub_interval[1], MPFR_RNDN);
    mpfr_free_array(initerval, PARTITION+1);
    mpfr_free_array(sub_interval, 3);

    find_root(result, f, a, b, epsilon);

    mpfr_clears(a,b,tmp,(mpfr_ptr)0);
    
#undef PARTITION
}

// f(x) = x^2 - 2
void f(mpfr_t result, mpfr_t x)
{
    mpfr_mul(result, x, x, MPFR_RNDN);         // result = x^2
    mpfr_sub_ui(result, result, 2, MPFR_RNDN); // result = x^2 - 2
}

int main (void)
{
    printf("MPFR VERSION : %s\n", MPFR_VERSION_STRING);

    // これは必須の設定
    mpfr_set_default_prec(PRECISION_BITS);

    mpfr_inits2(PRECISION_BITS, MPFR_TRUE, MPFR_FALSE, (mpfr_ptr)0);
    mpfr_set_d(MPFR_TRUE, 1.0, MPFR_RNDN);
    mpfr_set_d(MPFR_FALSE, -1.0, MPFR_RNDN);
    
    make_num(_1, "1.0");
    make_num(_2, "2.0");
    make_num(result, "0.0");
    mpfr_t epsilon;
    mpfr_init_set_si(epsilon, -200, MPFR_RNDN);
    mpfr_exp10(epsilon, epsilon, MPFR_RNDN); // epsilon = 10^(-200)

    find_root(result, f, _1, _2, epsilon);
    mpfr_printf("sqrt(2) = %.200RNf\n", result);
    
    mpfr_clears(result,epsilon,_1,_2,(mpfr_ptr)0);
    
    return 0;
}

コンパイル

gcc -lgmp -lmpfr hoge.c -o hoge

結果

MPFR VERSION : 3.1.4-p1
sqrt(2) = 1.41421356237309504880168872420969807856967187537694807317667973799073247846210703885038753432764157273501384623091229702492483605585073721264412149709993583141322266592750559275579995050115278206057147

mpfr 組込みの sqrt の結果と一致する。

mpfr_sqrt(result, _2, MPFR_RNDN);
mpfr_printf("sqrt(2) = %.200RNf\n", result);