Rcppとboostで微分方程式を解く

Rcppとboostで微分方程式を解く

$$\frac{dx}{dt} = a x^3 + b x^2 + c x + d, \qquad x(0) = e$$

を解く。`pow3()` には初期値と係数を `c(e, d, c, b, a)` の順に並べたベクトルを渡す。

- cppファイル

#include<Rcpp.h>
#include <iostream>
#include <array>
#include <boost/numeric/odeint.hpp>
using namespace Rcpp;

namespace odeint = boost::numeric::odeint;
using state_type = std::array< double, 1 >;

// Solves the scalar initial-value problem
//
//     dx/dt = A*x^3 + B*x^2 + C*x + D,   x(0) = E
//
// with Boost.Odeint's Runge-Kutta-Fehlberg 7(8) stepper driven at a fixed
// step (integrate_const), and returns x observed at t = 0, dt, 2*dt, ...
// up to t1.
//
// a0 : numeric vector of length >= 5 laid out as c(E, D, C, B, A), i.e.
//      a0[0] is the initial value and a0[k] (k = 1..4) is the coefficient
//      of x^(k-1). Elements after the fifth are ignored.
// t1 : end of the integration interval; integration starts at t = 0
//      (default 1.0).
// dt : fixed step size, must be > 0 (default 0.01).
//
// Returns one value of x per observed time point. Throws an R error if a0
// has fewer than 5 elements or dt is not positive. If the exact solution
// blows up inside [0, t1], the returned values are not meaningful past that
// point (they may grow without bound or become Inf/NaN).
//[[Rcpp::export]]
NumericVector pow3(NumericVector a0, double t1 = 1.0, double dt = 0.01) {
    // Indexing a[4] on a shorter vector would be undefined behaviour.
    if (a0.size() < 5) {
        Rcpp::stop("pow3: a0 must have at least 5 elements c(x0, d, c, b, a)");
    }
    // dt == 0 would make integrate_const loop forever; negative or NaN dt
    // would silently yield no observations.
    if (!(dt > 0.0)) {
        Rcpp::stop("pow3: dt must be positive");
    }

    const std::vector<double> a = as< std::vector<double> >(a0);

    // Right-hand side f(x) = a[4]*x^3 + a[3]*x^2 + a[2]*x + a[1].
    // The system is autonomous, so the time argument is unused.
    auto rhs = [&a](const state_type &x, state_type &dxdt, const double /* t */) {
        const double xv = x[0];
        dxdt[0] = a[4] * xv * xv * xv + a[3] * xv * xv + a[2] * xv + a[1];
    };

    state_type x0 = {{ a[0] }};  // initial state x(0)
    const double t0 = 0.0;       // start of the integration interval

    std::vector<double> y_log;
    auto stepper = odeint::runge_kutta_fehlberg78< state_type >();
    // The observer is called at every fixed step, including t0 itself.
    odeint::integrate_const(
        stepper, rhs, x0, t0, t1, dt,
        [&y_log](const state_type &x, const double /* t */) {
            y_log.push_back(x[0]);
        });

    return wrap(y_log);
}
- Rファイル
library(Rcpp)
# Boost.Odeint is header-only, so compiling only needs the Boost include path.
# Linker search paths (-L) belong in PKG_LIBS; in PKG_CXXFLAGS they are passed
# only to the compile step, where they have no effect.
# The paths below are machine-specific; adjust them to your installation.
Sys.setenv("PKG_CXXFLAGS"="-std=c++11 -I/opt/local/include -I/home/user1/usr/local/boost/include")
Sys.setenv("PKG_LIBS"="-L/opt/local/lib/ -L/home/user1/usr/local/boost/")
sourceCpp('pow3_eq4.cpp')
 # Case 1: dx/dt = 1.1*x^3, x(0) = 1   (pow3 expects c(x0, d, c, b, a)).
 a <- c(1, 0, 0, 0, 1.1)
 print(a)
 ans <- pow3(a)  # integrate once and reuse the result
 print(ans)
 x0 <- 1
 a <- 1.1
 # Time grid matching the one pow3 integrates on (t from 0 to 1, step 0.01).
 x <- seq(0, 1, by = 0.01)
 print("exact")
 # Separating variables gives x(t) = 1 / sqrt(1/x0^2 - 2*a*t).
 # The solution blows up at t = 1/(2*a*x0^2) (about 0.4545 here), so the
 # exact values beyond that point are NaN and the numerical ones are not
 # meaningful there.
 print(1.0 / (1 / x0^2 - 2 * a * x)^0.5)


 # Case 2: dx/dt = 1.1*x^2, x(0) = 1.
 a <- c(1, 0, 0, 1.1, 0.0)
 print(a)
 print(pow3(a))

 x0 <- 1
 a <- 1.1
 x <- seq(0, 1, by = 0.01)
 print("exact")
 # x(t) = x0 / (1 - a*x0*t); blows up at t = 1/(a*x0) (about 0.909 here).
 print(x0 / (1 - a * x * x0))


 # Case 3: dx/dt = 1.1*x, x(0) = 0.5.
 a <- c(0.5, 0, 1.1, 0.0, 0.0)
 print(a)
 print(pow3(a))

 x0 <- 0.5
 a <- 1.1
 x <- seq(0, 1, by = 0.01)
 print("exact")
 # Linear ODE: x(t) = x0 * exp(a*t). Built from x0 and a (not repeated
 # literals) so the reference stays in sync with the parameters above.
 print(x0 * exp(a * x))

 # Case 4: dx/dt = 1.1 (constant), x(0) = 0.5.
 a <- c(0.5, 1.1, 0.0, 0.0, 0.0)
 print(a)
 print(pow3(a))

 x0 <- 0.5
 a <- 1.1
 x <- seq(0, 1, by = 0.01)
 print("exact")
 # x(t) = a*t + x0, expressed via x0 and a instead of hard-coded literals.
 print(a * x + x0)