/* $OpenXM: OpenXM/src/hgm/so3/src/so3_nc.c,v 1.9 2019/11/16 10:57:21 takayama Exp $ */
/*
 * so3_nc: evaluate the normalization constant nc(a,b,c) of the Fisher
 * distribution on SO(3) for diag(a,b,c), together with its gradient
 * (c_a, c_b, c_c).  A truncated power series gives initial values at
 * t0*(a,b,c); a 4-dimensional ODE in t (so3_func) is then integrated with
 * GSL (rkf45) from t0 to 1.  See so3_usage() and
 * http://arxiv.org/abs/1110.0721 .
 *
 * Built either as a command-line program (-DSTANDALONE, entry point main())
 * or as a library entry point so3_main() taking all arguments by pointer.
 *
 * NOTE(review): this copy of the file appears damaged.  Its newlines were
 * lost, and text beginning with '<' followed by a letter, up to the next
 * '>', seems to have been stripped (as if by an HTML filter).  Visible
 * consequences:
 *   - the targets of the three bare #include lines below and of the two
 *     #include lines in the USE_GSL_LIB branch are missing.  The code calls
 *     sscanf, exp and log, so <stdio.h> and <math.h> are presumably among
 *     them.  The GSL branch presumably used <gsl/gsl_errno.h> and
 *     <gsl/gsl_odeiv.h> (counterparts of the t-gsl_*.h files).
 *     TODO confirm against upstream;
 *   - the loop header and option parsing in main();
 *   - the loop headers and coefficient-table set-up in so3_evalByS().
 * The damaged lines are kept exactly as found.  The file will not compile
 * until they are restored from upstream.
 */
#include
#include
#include
#ifdef USE_GSL_LIB
#include
#include
#else
#include "t-gsl_errno.h"
#include "t-gsl_odeiv.h"
#endif
#include "oxprint.h"
#ifndef STANDALONE
/* Called from so3_func() on every ODE evaluation; presumably an interrupt
   check provided by the host interface -- TODO confirm. */
void mh_check_intr(int n);
#endif
/* Build hints (evalnc.c is presumably an earlier name of this file): */
/* gcc evalnc.c -lgsl -lblas -lm */
/* gcc evalnc.c `pkg-config --cflags gsl` `pkg-config --libs gsl` */

/* NOTE(review): `int so3_func();` is a non-prototype declaration, and it is
   used below to initialize a gsl_odeiv_system.  Under C23 `()` means
   `(void)`, which no longer matches the definition
   int so3_func(double, const double[], double[], void *).
   A full prototype would be safer. */
int so3_func();
void so3_nc(double a[3],double t0,double y[4]);
void so3_evalByS(int deg,double a,double b,double c,double t,double f[4]);
int so3_usage(void);

/* Dimension of the per-index tables in so3_evalByS().  The entry points
   reject SO3_Deg > MDEG-2, so the deg+1 evaluation in so3_nc() stays
   below MDEG. */
#define MDEG 20
#define SO3_QUIET_DEFAULT 0
#define SO3_Deg_DEFAULT 10
#define SO3_Log_DEFAULT 0
/* Run-time options set by main()/so3_main() and read by so3_nc(). */
int SO3_Quiet = SO3_QUIET_DEFAULT;  /* nonzero: suppress diagnostic output */
int SO3_Deg = SO3_Deg_DEFAULT;      /* degree passed to so3_evalByS() */
int SO3_Log = SO3_Log_DEFAULT;      /* nonzero: return logarithms */

#ifdef STANDALONE
/* Command-line entry point.  Parses options and the three numbers a b c,
   then prints nc(a,b,c) and its gradient (their logs with --log).
   Returns -1 on a bad degree or when a b c are not all given. */
int main(int argc,char *argv[]) {
  double a[3];
  double y[4];
  double t0;
  int i,j;
  t0 = 0.0001; /* Small enough number seems to be good (Heuristics) */
  j = 0;       /* number of positional values a, b, c read so far */
  /* NOTE(review): everything from the loop condition up to "(SO3_Deg >" is
     missing here (see the file header).  Presumably it held the loop
     condition and the --quiet/--t0/--deg/--log parsing listed in
     so3_usage().  The next line is kept exactly as found and is not
     valid C. */
  for (i=1; i MDEG-2) || (SO3_Deg < 0)) {
        /* NOTE(review): deg == MDEG-2 is accepted, so "less than" in this
           message is off by one. */
        oxprintfe("Error: deg should be less than %d\n",MDEG-2);
        return(-1);
      }
    } else if (j<3) sscanf(argv[i],"%lg",&(a[j++]));
  }
  if (j != 3) {
    so3_usage();
    return(-1);
  }
  so3_nc(a,t0,y);
  // oxprintf("%lg, %lg, %lg, %lg\n",y[0],y[1],y[2],y[3]);
  if (SO3_Log) oxprintf("log of nc : ");
  oxprintf("%.15e, %.15e, %.15e, %.15e\n",y[0],y[1],y[2],y[3]);
}
#else
/*
 * Library entry point (built without STANDALONE); every argument is passed
 * by pointer.  The note in so3_nc() says mh-r.c (the interface module) must
 * change whenever the argument list changes.
 *   in1,in2,in3 : a, b, c
 *   t0p   : starting t0 if *t0p > 0.0, otherwise 0.0001
 *   quiet : nonzero suppresses diagnostics
 *   deg   : series degree if nonzero, otherwise SO3_Deg_DEFAULT;
 *           must satisfy 0 <= deg <= MDEG-2
 *   log   : nonzero returns logarithms
 *   out   : receives the 4 values produced by so3_nc()
 * The global options are reset to their defaults on every call.
 */
void so3_main(double *in1,double *in2,double *in3,double *t0p,int *quiet,int *deg,int *log,double *out) {
  double a[3];
  double y[4];
  double t0;
  int i;
  SO3_Quiet = SO3_QUIET_DEFAULT;
  SO3_Deg = SO3_Deg_DEFAULT;
  SO3_Log = SO3_Log_DEFAULT;
  if (*quiet) SO3_Quiet = 1;
  if (*deg) {
    SO3_Deg = *deg;
    if (!SO3_Quiet) oxprintfe("deg is set to %d\n",SO3_Deg);
  }
  if (*log) SO3_Log=1;
  t0 = 0.0001; /* Small enough number seems to be good (Heuristics) */
  if (*t0p > 0.0) {
    t0 = *t0p;
    if (!SO3_Quiet) oxprintfe("t0 is set to %lf\n",t0);
  }
  // j = 0;
  if ((SO3_Deg > MDEG-2) || (SO3_Deg < 0)) {
    oxprintfe("Error: deg should be less than %d\n",MDEG-2);
    *out = 0.0; /* NOTE(review): only out[0] is set; out[1..3] are left untouched */
    return;
  }
  a[0] = *in1; a[1] = *in2; a[2] = *in3;
  // oxprintfe("DEBUG: %lf,%lf,%lf,%lf\n",t0,a[0],a[1],a[2]);
  so3_nc(a,t0,y);
  // oxprintf("%lg, %lg, %lg, %lg\n",y[0],y[1],y[2],y[3]);
  if ((!SO3_Quiet) && SO3_Log) oxprintf("log of nc : ");
  if (!SO3_Quiet) oxprintf("%.15e, %.15e, %.15e, %.15e\n",y[0],y[1],y[2],y[3]);
  for (i=0; i<4; i++) out[i] = y[i];
  return;
}
#endif

/* Print the usage message with oxprintfe(); always returns 0. */
int so3_usage(void) {
  oxprintfe("Usage: so3_nc a b c returns nc(a,b,c) and its gradients\n");
  oxprintfe(" where nc is the normalization constant\n" );
  oxprintfe(" of the Fisher distribution on SO(3) for the diagonal matrix diag(a,b,c).\n");
  oxprintfe(" See http://arxiv.org/abs/1110.0721\n");
  oxprintfe("Options: --quiet --t0 T0 --deg DEG --log\n");
  oxprintfe(" Series is evaluated at T0*(a,b,c) and the value is extended to (a,b,c) by diff. eq. With --log, log(nc(a,b,c)) is returned.\n");
  return(0);
}

/* Evaluate normalization constant */
/* State shared between so3_nc() and the ODE right-hand side so3_func(). */
double SO3_A[3]; /* (a,b,c), copied in by so3_nc() */
double SO3_R; /* rho = max(0, a-b-c, -a+b-c, -a-b+c, a+b+c); used to rescale by exp(-SO3_R t) */
int so3_func();

/*
 * Compute V = [nc(a,b,c), c_a, c_b, c_c] into y[0..3] (the layout printed by
 * the "Returned value is V=[...]" message), or their logarithms when SO3_Log
 * is set.
 *   a  : (a,b,c); also copied into SO3_A for so3_func()
 *   t0 : first trial starting point.  It is halved until the series of
 *        degree SO3_Deg and SO3_Deg+1 agree at t0*(a,b,c) to relative error
 *        1e-6 in every component.
 *   y  : output, 4 values
 * If a[0]^2+a[1]^2+a[2]^2 < 0.01, the series alone is evaluated at (a,b,c)
 * and returned.  Otherwise the series values at t0*(a,b,c) are scaled by
 * exp(-SO3_R*t0), the system so3_func() is integrated from t0 to 1, and the
 * scaling is undone (additively in log space when SO3_Log is set).
 * Diagnostics go to oxprintfe() unless SO3_Quiet is set.
 */
void so3_nc(double a[3],double t0,double y[4]) {
  int i;
  int deg;
  double r;
  double y0[4];
  double myerr,myerr2;
  double aa;
  deg = SO3_Deg;
  for (i=0; i<3; i++) SO3_A[i]=a[i];
  /* When the argument is small, eval it only by series */
  aa = 0.0;
  for (i=0; i<3; i++) aa += a[i]*a[i];
  if (aa < 0.01) {
    /* NOTE(review): this early return skips the log transform even when
       SO3_Log is set, so --log output is not logarithmic for small (a,b,c). */
    so3_evalByS(deg,a[0],a[1],a[2], 1.0, y);
    return;
  }
  /* SO3_R = max(0, a-b-c, -a+b-c, -a-b+c, a+b+c) */
  SO3_R = 0.0;
  r = a[0]-a[1]-a[2]; if (r > SO3_R) SO3_R = r;
  r = -a[0]+a[1]-a[2]; if (r > SO3_R) SO3_R = r;
  r = -a[0]-a[1]+a[2]; if (r > SO3_R) SO3_R = r;
  r = a[0]+a[1]+a[2]; if (r > SO3_R) SO3_R = r;
  if (!SO3_Quiet) oxprintfe("SO3_R=%lg, exp(SO3_R t) is the asymptotics of nc.\n",SO3_R);

  const gsl_odeiv_step_type *T = gsl_odeiv_step_rkf45;
  gsl_odeiv_step *s = gsl_odeiv_step_alloc(T, 4); /* rank4 */
  /*
   * Absolute error 1e-6
   * Relative error 0.0.
   */
  gsl_odeiv_control *c = gsl_odeiv_control_y_new(1e-6, 0.0);
  /* rank 4 system */
  gsl_odeiv_evolve *e = gsl_odeiv_evolve_alloc(4);
  gsl_odeiv_system sys = {so3_func, NULL, 4, NULL};
  /* t : start, t1 : goal */
  double t = t0, t1 = 1.0;
  double h = 1e-6; /* initial step size; updated by gsl_odeiv_evolve_apply */
  if (!SO3_Quiet) {
    oxprintfe("Set initial values at t0*(a,b,c) by evaluating series and find relevant t0.\n");
    oxprintfe("t0=%lf, a=a[0]=%lf, b=a[1]=%lf, c=a[2]=%lf\n",t0,a[0],a[1],a[2]);
  }
  /* Halve t0 until the degree-deg and degree-(deg+1) series agree.
     NOTE(review): there is no lower bound on t0, and so3_func() contains
     2/t terms, so a very small t0 may make the integration ill-conditioned. */
  do {
    if (!SO3_Quiet) oxprintfe("t0=%lf\n",t0);
    so3_evalByS(deg,a[0],a[1],a[2], t0, y0);
    if (!SO3_Quiet) oxprintfe("[%2d]: %lg,%lg,%lg,%lg\n",deg,y0[0],y0[1],y0[2],y0[3]);
    so3_evalByS(deg+1,a[0],a[1],a[2], t0, y);
    if (!SO3_Quiet) oxprintfe("[%2d]: %lg,%lg,%lg,%lg\n",deg+1,y[0],y[1],y[2],y[3]);
    myerr=0.0;
    for (i=0; i<4; i++) {
      myerr2 = (y0[i]-y[i])/y0[i];
      if (myerr2 <0) myerr2 = -myerr2;
      if (myerr2 > myerr) myerr = myerr2; /* sup norm */
    }
    if (myerr < 1e-6) break; /* should take smaller value? */
    t0 = t0/2.0;
  } while (1);
  t=t0;
  /* The ODE is solved for V*exp(-SO3_R t); so3_func() carries the matching
     -SO3_R*y terms. */
  for (i=0; i<4; i++) y[i]=y[i]*exp(-SO3_R*t0);
  if (!SO3_Quiet) oxprintfe("[%2d]*exp(-SO3_R*t0): %lg,%lg,%lg,%lg\n",deg+1,y[0],y[1],y[2],y[3]);
  if (!SO3_Quiet) oxprintfe("Result by HGM (solving ODE) ------ \n");
  /* NOTE(review): if a step returns a status other than GSL_SUCCESS, the
     loop stops with t < t1.  The result below is then silently the value at
     t*(a,b,c), not at (a,b,c), and no error is reported. */
  while (t < t1) {
    int status = gsl_odeiv_evolve_apply(e, c, s, &sys, &t, t1, &h, y);
    if (status != GSL_SUCCESS) break;
  }
  if (!SO3_Quiet) oxprintfe("t and V : t=%.5e %.5e %.5e, %.5e %.5e\n", t, y[0], y[1],y[2],y[3]);
  if (!SO3_Log) {
    /* Undo the exp(-SO3_R t) scaling. */
    for (i=0; i<4; i++) y[i]=y[i]*exp(SO3_R*t);
    if (!SO3_Quiet) oxprintfe("V*exp(SO3_R*1): t= %.5e %.5e %.5e, %.5e %.5e\n", t, y[0], y[1],y[2],y[3]);
    if (!SO3_Quiet) oxprintfe("Returned value is V=[so3_nc(a,b,c), c_a, c_b, c_c]\n");
  }else{
    /* Undo the scaling in log space, presumably so that large arguments
       (cf. the test input below) do not overflow exp().
       NOTE(review): log() of a component <= 0 gives NaN or -inf; confirm
       whether c_a, c_b, c_c can be non-positive (e.g. for negative a,b,c). */
    for (i=0; i<4; i++) y[i]=log(y[i]) + SO3_R*t;
    if (!SO3_Quiet) oxprintfe("log(V*exp(SO3_R*1)): t= %.5e %.5e %.5e, %.5e %.5e\n", t, y[0], y[1],y[2],y[3]);
    if (!SO3_Quiet) oxprintfe("Returned value is V=[log(so3_nc(a,b,c)), log(c_a), log(c_b), log(c_c)]\n");
    /* test input: ./hgm_so3_nc --t0 0.001 --deg 18 --log 1000 500 50
       2019.11.16 When the argument list changes, change mh-r.c
       (the interface module) as well. */
  }
  gsl_odeiv_evolve_free(e);
  gsl_odeiv_control_free(c);
  gsl_odeiv_step_free(s);
}

/* From FB2/Prog/gls_ode_test2b.c */
/*
 * Right-hand side of the ODE system for gsl_odeiv:
 *   d/dt y_i = f_i(t, y_1, ..., y_n),  f_i = f[i],  here n = 4.
 * y[0] is V[0]*exp(-SO3_R t) along the ray t*(a,b,c), and y[1..3] are the
 * gradient components scaled the same way, with (a,b,c) = SO3_A.
 * f[0] is the chain rule a*y[1]+b*y[2]+c*y[3] along the ray; every
 * component carries -SO3_R*y[i] for the rescaling.  The 2/t terms are
 * singular at t = 0, which is why so3_nc() starts the integration at t0 > 0.
 * params is unused.  Without STANDALONE, mh_check_intr(100) is called on
 * every evaluation.
 * Ref: Note: See the Corollary 1 of the paper rotation.c on this function func.
 */
int so3_func(double t, const double y[], double f[], void *params) {
  extern double SO3_A[3];
#ifndef STANDALONE
  mh_check_intr(100);
#endif
  f[0] = SO3_A[0]*y[1]+SO3_A[1]*y[2]+SO3_A[2]*y[3] -SO3_R*y[0];
  f[1] = SO3_A[0]*y[0]+SO3_A[2]*y[2]+SO3_A[1]*y[3] - (2/t)*y[1] -SO3_R*y[1];
  f[2] = SO3_A[1]*y[0]+SO3_A[2]*y[1]+SO3_A[0]*y[3] - (2/t)*y[2] -SO3_R*y[2];
  f[3] = SO3_A[2]*y[0]+SO3_A[1]*y[1]+SO3_A[0]*y[2] - (2/t)*y[3] -SO3_R*y[3];
  return GSL_SUCCESS;
}

/* Evaluation of nc by a series */
/*
 * Evaluate the truncated series at the point t*(a,b,c) into f[0..3]:
 *   f[0] = sum Tnc[i][j][k] * ex[i][j][k]
 *   f[1] = sum i * Tnc[i][j][k] * ex[i-1][j][k]
 *          (derivative in the first variable); f[2], f[3] likewise
 *          for the second and third variables.
 * ex[i][j][k] is built from a neighbouring entry times a*t, b*t or c*t, i.e.
 * (a t)^i (b t)^j (c t)^k, assuming the missing code sets ex[0][0][0] = 1
 * and visits lower indices first -- TODO confirm.
 * NOTE(review): deg >= MDEG is reported but execution continues.  If the
 * missing loop bounds depend on deg, the MDEG-sized tables would be indexed
 * out of range, so returning here looks intended.
 * NOTE(review): Tnc and ex are two 20x20x20 double arrays (64000 bytes each)
 * on the stack per call.
 * NOTE(review): the start of the loop nest was lost from this copy: the loop
 * headers, the set-up of the coefficient table Tnc and presumably the
 * zeroing of f[].  The "for (i=0; i0)" line below is kept exactly as found
 * and is not valid C.
 */
void so3_evalByS(int deg,double a,double b,double c,double t,double f[4]) {
  double Tnc[MDEG][MDEG][MDEG];
  double ex[MDEG][MDEG][MDEG];
  int i,j,k;
  if (deg >= MDEG) {
    oxprintfe("Error: degree is too high\n");
  }
  for (i=0; i0) ex[i][j][k] = ex[i-1][j][k]*(a*t);
        else if (j>0) ex[i][j][k] = ex[i][j-1][k]*(b*t);
        else if (k>0) ex[i][j][k] = ex[i][j][k-1]*(c*t);
        f[0] += (ex[i][j][k])*(Tnc[i][j][k]);
        if (i>0) f[1] += ((double) i)*(ex[i-1][j][k])*(Tnc[i][j][k]);
        if (j>0) f[2] += ((double) j)*(ex[i][j-1][k])*(Tnc[i][j][k]);
        if (k>0) f[3] += ((double) k)*(ex[i][j][k-1])*(Tnc[i][j][k]);
      }
    }
  }
}