Commit 54f75cbe authored by Kirill Terekhov's avatar Kirill Terekhov

Fixes for hessian calculation

+unit tests
parent fd6a6a92
......@@ -403,15 +403,19 @@ namespace INMOST
assert(HR.isSorted());
output.Resize(HL.Size()+HR.Size()+JL.Size()*JR.Size());
INMOST_DATA_ENUM_TYPE i = 0, j = 0, k1 = 0, l1 = 0, k2 = 0, l2 = 0, q = 0, kk1 = 0, ll2 = 0, r;
while(kk1 < JL.Size() && JL.GetIndex(kk1) < JR.GetIndex(l1) ) kk1++;
k1 = kk1;
while(ll2 < JR.Size() && JL.GetIndex(k2) > JR.GetIndex(ll2) ) ll2++;
l2 = ll2;
entry candidate[4] = {stub_entry,stub_entry,stub_entry,stub_entry};
if( i < HL.Size() )
candidate[0] = make_entry(HL.GetIndex(i),b*HL.GetValue(i));
if( j < HR.Size() )
candidate[1] = make_entry(HR.GetIndex(j),c*HR.GetValue(j));
if( k1 < JL.Size() && l1 < JR.Size() )
candidate[2] = make_entry(make_index(JL.GetIndex(k1),JR.GetIndex(l1)),0.5*a*JL.GetValue(k1)*JR.GetValue(l1));
candidate[2] = make_entry(make_index(JL.GetIndex(k1),JR.GetIndex(l1)),(JL.GetIndex(k1)==JR.GetIndex(l1)?0.5:1)*a*JL.GetValue(k1)*JR.GetValue(l1));
if( k2 < JL.Size() && l2 < JR.Size() )
candidate[3] = make_entry(make_index(JL.GetIndex(k2),JR.GetIndex(l2)),0.5*a*JL.GetValue(k2)*JR.GetValue(l2));
candidate[3] = make_entry(make_index(JL.GetIndex(k2),JR.GetIndex(l2)),(JL.GetIndex(k2)==JR.GetIndex(l2)?0.5:1)*a*JL.GetValue(k2)*JR.GetValue(l2));
do
{
//pick smallest
......@@ -454,7 +458,7 @@ namespace INMOST
}
}
if( k1 < JL.Size() && l1 < JR.Size() )
candidate[2] = make_entry(make_index(JL.GetIndex(k1),JR.GetIndex(l1)),(k1==l1?0.5:1)*a*JL.GetValue(k1)*JR.GetValue(l1));
candidate[2] = make_entry(make_index(JL.GetIndex(k1),JR.GetIndex(l1)),(JL.GetIndex(k1)==JR.GetIndex(l1)?0.5:1)*a*JL.GetValue(k1)*JR.GetValue(l1));
else
candidate[2] = stub_entry;
}
......@@ -470,7 +474,7 @@ namespace INMOST
}
}
if( k2 < JL.Size() && l2 < JR.Size() )
candidate[3] = make_entry(make_index(JL.GetIndex(k2),JR.GetIndex(l2)),(k2==l2?0.5:1)*a*JL.GetValue(k2)*JR.GetValue(l2));
candidate[3] = make_entry(make_index(JL.GetIndex(k2),JR.GetIndex(l2)),(JL.GetIndex(k2)==JR.GetIndex(l2)?0.5:1)*a*JL.GetValue(k2)*JR.GetValue(l2));
else
candidate[3] = stub_entry;
}
......@@ -489,13 +493,17 @@ namespace INMOST
assert(H.isSorted());
output.Resize(H.Size()+JL.Size()*JR.Size());
INMOST_DATA_ENUM_TYPE i = 0, k1 = 0, l1 = 0, q = 0, k2 = 0, l2 = 0, r, ll2 = 0, kk1 = 0;
while(kk1 < JL.Size() && JL.GetIndex(kk1) < JR.GetIndex(l1) ) kk1++;
k1 = kk1;
while(ll2 < JR.Size() && JL.GetIndex(k2) > JR.GetIndex(ll2) ) ll2++;
l2 = ll2;
entry candidate[3] = {stub_entry,stub_entry,stub_entry};
if( i < H.Size() )
candidate[0] = make_entry(H.GetIndex(i),b*H.GetValue(i));
if( k1 < JL.Size() && l1 < JR.Size() )
candidate[1] = make_entry(make_index(JL.GetIndex(k1),JR.GetIndex(l1)),0.5*a*JL.GetValue(k1)*JR.GetValue(l1));
candidate[1] = make_entry(make_index(JL.GetIndex(k1),JR.GetIndex(l1)),(JL.GetIndex(k1)==JR.GetIndex(l1)?0.5:1)*a*JL.GetValue(k1)*JR.GetValue(l1));
if( k2 < JL.Size() && l2 < JR.Size() )
candidate[2] = make_entry(make_index(JL.GetIndex(k2),JR.GetIndex(l2)),0.5*a*JL.GetValue(k2)*JR.GetValue(l2));
candidate[2] = make_entry(make_index(JL.GetIndex(k2),JR.GetIndex(l2)),(JL.GetIndex(k2)==JR.GetIndex(l2)?0.5:1)*a*JL.GetValue(k2)*JR.GetValue(l2));
do
{
//pick smallest
......@@ -533,7 +541,7 @@ namespace INMOST
}
}
if( k1 < JL.Size() && l1 < JR.Size() )
candidate[1] = make_entry(make_index(JL.GetIndex(k1),JR.GetIndex(l1)),(k1==l1?0.5:1)*a*JL.GetValue(k1)*JR.GetValue(l1));
candidate[1] = make_entry(make_index(JL.GetIndex(k1),JR.GetIndex(l1)),(JL.GetIndex(k1)==JR.GetIndex(l1)?0.5:1)*a*JL.GetValue(k1)*JR.GetValue(l1));
else
candidate[1] = stub_entry;
}
......@@ -549,7 +557,7 @@ namespace INMOST
}
}
if( k2 < JL.Size() && l2 < JR.Size() )
candidate[2] = make_entry(make_index(JL.GetIndex(k2),JR.GetIndex(l2)),(k2==l2?0.5:1)*a*JL.GetValue(k2)*JR.GetValue(l2));
candidate[2] = make_entry(make_index(JL.GetIndex(k2),JR.GetIndex(l2)),(JL.GetIndex(k2)==JR.GetIndex(l2)?0.5:1)*a*JL.GetValue(k2)*JR.GetValue(l2));
else
candidate[2] = stub_entry;
}
......
if(USE_AUTODIFF)
add_subdirectory(autodiff_test000)
add_subdirectory(autodiff_test001)
add_subdirectory(autodiff_test002)
endif(USE_AUTODIFF)
if(USE_SOLVER)
......
......@@ -13,6 +13,7 @@ int main(int argc,char ** argv)
double _dx, _dy, _dxdx, _dydy, _dxdy;
unknown x(_x,0), y(_y,1);
hessian_variable f;
variable f2;
if( test == 0 ) //check derivative and hessian of sin(x*x+y*y)
......@@ -34,6 +35,7 @@ int main(int argc,char ** argv)
_dydy = 2*cos(_x*_x+_y*_y)-4*sin(_x*_x+_y*_y)*_y*_y;
_dxdy = -8*sin(_x*_x+_y*_y)*_x*_y;
f = sin(x*x+y*y);
f2 = sin(x*x + y*y);
}
else if( test == 1 )
{
......@@ -55,6 +57,7 @@ int main(int argc,char ** argv)
_dydy = 2*cos(_x*_x+_y*_y+_x*_y)-sin(_x*_x+_y*_y+_x*_y)*(_x+2*_y)*(_x+2*_y);
_dxdy = 2*cos(_x*_x+_y*_y+_x*_y)-2*sin(_x*_x+_y*_y+_x*_y)*(2*_x+_y)*(_x+2*_y);
f = sin(x*x+y*y+x*y);
f2 = sin(x*x + y*y + x*y);
}
else if( test == 2 )
{
......@@ -76,6 +79,7 @@ int main(int argc,char ** argv)
_dydy = -2*sin(_x*_x+_y*_y+_x*_y)-cos(_x*_x+_y*_y+_x*_y)*(_x+2*_y)*(_x+2*_y);
_dxdy = -2*sin(_x*_x+_y*_y+_x*_y)-2*cos(_x*_x+_y*_y+_x*_y)*(2*_x+_y)*(_x+2*_y);
f = cos(x*x+y*y+x*y);
f2 = cos(x*x + y*y + x*y);
}
else if( test == 3 )
{
......@@ -93,6 +97,7 @@ int main(int argc,char ** argv)
_dxdy = 2*cos(_x);
_dydy = 0;
f = sin(x)*y;
f2 = sin(x)*y;
}
else if( test == 4 )
{
......@@ -110,6 +115,7 @@ int main(int argc,char ** argv)
_dxdy = -2*sin(_x);
_dydy = 0;
f = cos(x)*y;
f2 = cos(x)*y;
}
else if( test == 5 )
{
......@@ -127,6 +133,7 @@ int main(int argc,char ** argv)
_dxdy = -6*_x*_y/(4*pow(_x*_x+_x*_y+_y*_y,1.5));
_dydy = 3*_x*_x/(4*pow(_x*_x+_x*_y+_y*_y,1.5));
f = sqrt(x*x+y*y+x*y);
f2 = sqrt(x*x + y*y + x*y);
}
else if( test == 6 )
{
......@@ -136,6 +143,7 @@ int main(int argc,char ** argv)
_dxdy = 8;
_dydy = 6;
f = 2*x*x+3*y*y+4*x*y;
f2 = 2 * x*x + 3 * y*y + 4 * x*y;
}
else if( test == 7 )
{
......@@ -145,6 +153,7 @@ int main(int argc,char ** argv)
_dxdy = -32*(_x-0.5)*(_y-0.5)*sin(2*((_x-0.5)*(_x-0.5)+(_y-0.5)*(_y-0.5)));
_dydy = 4*cos(2*((_x - 0.5)*(_x - 0.5) + (_y - 0.5)*(_y - 0.5))) - 16*(_y - 0.5)*(_y - 0.5)*sin(2.0*((_x - 0.5)*(_x - 0.5) + (_y - 0.5)*(_y - 0.5)));
f = sin(2*((x-0.5)*(x-0.5)+(y-0.5)*(y-0.5)));
f2 = sin(2 * ((x - 0.5)*(x - 0.5) + (y - 0.5)*(y - 0.5)));
}
else if( test == 8 )
{
......@@ -154,6 +163,7 @@ int main(int argc,char ** argv)
_dxdy = -(-32*(_x-0.5)*(_y-0.5)*sin(2*((_x-0.5)*(_x-0.5)+(_y-0.5)*(_y-0.5))));
_dydy = -(4*cos(2*((_x - 0.5)*(_x - 0.5) + (_y - 0.5)*(_y - 0.5))) - 16*(_y - 0.5)*(_y - 0.5)*sin(2.0*((_x - 0.5)*(_x - 0.5) + (_y - 0.5)*(_y - 0.5))));
f = 1.0-sin(2*((x-0.5)*(x-0.5)+(y-0.5)*(y-0.5)));
f2 = 1.0 - sin(2 * ((x - 0.5)*(x - 0.5) + (y - 0.5)*(y - 0.5)));
}
else if( test == 9 )
{
......@@ -163,6 +173,7 @@ int main(int argc,char ** argv)
_dxdy = 32*(_x-0.5)*(_y-0.5)*cos(2*((_x-0.5)*(_x-0.5)+(_y-0.5)*(_y-0.5)));
_dydy = 4*sin(2*((_x - 0.5)*(_x - 0.5) + (_y - 0.5)*(_y - 0.5))) + 16*(_y - 0.5)*(_y - 0.5)*cos(2.0*((_x - 0.5)*(_x - 0.5) + (_y - 0.5)*(_y - 0.5)));
f = 1.0-cos(2*((x-0.5)*(x-0.5)+(y-0.5)*(y-0.5)));
f2 = 1.0 - cos(2 * ((x - 0.5)*(x - 0.5) + (y - 0.5)*(y - 0.5)));
}
else if( test == 10 )
{
......@@ -172,6 +183,7 @@ int main(int argc,char ** argv)
_dxdy = 8*(_y-0.5)*cos(2*((_x-0.5)*(_x-0.5)+(_y-0.5)*(_y-0.5)))-4*(_x-0.5)*_x*sin(2*((_x-0.5)*(_x-0.5)+(_y-0.5)*(_y-0.5)));
_dydy = 4*_x*cos(2*((_x-0.5)*(_x-0.5)+(_y-0.5)*(_y-0.5))) - 16*_x*(_y-0.5)*(_y-0.5)*sin(2*((_x-0.5)*(_x-0.5)+(_y-0.5)*(_y-0.5)));
f = sin(2*((x-0.5)*(x-0.5)+(y-0.5)*(y-0.5)))*x;
f2 = sin(2 * ((x - 0.5)*(x - 0.5) + (y - 0.5)*(y - 0.5)))*x;
}
else if( test == 11 )
{
......@@ -189,6 +201,7 @@ int main(int argc,char ** argv)
_dxdy = 2*cos(_x*_x+_y) - 4*_x*_x*sin(_x*_x+_y);
_dydy = -_x*sin(_x*_x+_y);
f = sin(x*x+y)*x;
f2 = sin(x*x + y)*x;
}
else if( test == 12 )
{
......@@ -206,6 +219,7 @@ int main(int argc,char ** argv)
_dxdy = 4*_y*(cos(_x*_x+_y*_y) - 2*_x*_x*sin(_x*_x+_y*_y));
_dydy = 2*_x*(cos(_x*_x+_y*_y)-2*_y*_y*sin(_x*_x+_y*_y));
f = sin(x*x+y*y)*x;
f2 = sin(x*x + y*y)*x;
}
double dx = f.GetRow()[0];
double dy = f.GetRow()[1];
......@@ -214,6 +228,7 @@ int main(int argc,char ** argv)
double dydy = f.GetHessianRow()[Sparse::HessianRow::make_index(1,1)];
bool error = false;
std::cout << "For GetHessian:" << std::endl;
std::cout << std::setw(10) << "derivative " << std::setw(10) << "original " << std::setw(10) << "computed" << std::endl;
std::cout << std::setw(10) << "dx " << std::setw(10) << _dx << std::setw(10) << dx << std::endl;
std::cout << std::setw(10) << "dy " << std::setw(10) << _dy << std::setw(10) << dy << std::endl;
......@@ -226,7 +241,15 @@ int main(int argc,char ** argv)
if( std::abs(dxdy-_dxdy) > 1.0e-9 ) error = true, std::cout << "Error in dxdy: " << std::abs(dxdy-_dxdy) << " original " << _dxdy << " computed " << dxdy << std::endl;
if( std::abs(dydy-_dydy) > 1.0e-9 ) error = true, std::cout << "Error in dydy: " << std::abs(dydy-_dydy) << " original " << _dydy << " computed " << dydy << std::endl;
if( error ) return -1;
dx = f2.GetRow()[0];
dy = f2.GetRow()[1];
std::cout << "For GetJacobian:" << std::endl;
std::cout << std::setw(10) << "derivative " << std::setw(10) << "original " << std::setw(10) << "computed" << std::endl;
std::cout << std::setw(10) << "dx " << std::setw(10) << _dx << std::setw(10) << dx << std::endl;
std::cout << std::setw(10) << "dy " << std::setw(10) << _dy << std::setw(10) << dy << std::endl;
if( error ) return -1;
return 0;
}
......@@ -14,6 +14,7 @@ int main(int argc,char ** argv)
double dx, dy, dz, dxdx, dydy, dzdz, dxdy, dxdz, dydz;
unknown vx(x,0), vy(y,1), vz(z,2);
hessian_variable f;
variable f2;
if( test == 0 ) //check derivative and hessian of sin(x*x+y*y)
......@@ -33,6 +34,7 @@ int main(int argc,char ** argv)
f = sin(8*((vx-0.5)*(vx-0.5)+(vy-0.5)*(vy-0.5)+(vz-0.5)*(vz-0.5)))*(2*vx-1);
f2 = sin(8 * ((vx - 0.5)*(vx - 0.5) + (vy - 0.5)*(vy - 0.5) + (vz - 0.5)*(vz - 0.5)))*(2 * vx - 1);
}
//mixed derivative computed twice: dxdy and dydx
dxdy *= 2;
......@@ -51,6 +53,7 @@ int main(int argc,char ** argv)
bool error = false;
std::cout << "For GetHessian:" << std::endl;
std::cout << std::setw(10) << "derivative " << std::setw(10) << "original " << std::setw(10) << "computed" << std::endl;
std::cout << std::setw(10) << "dx " << std::setw(10) << dx << std::setw(10) << vdx << std::endl;
std::cout << std::setw(10) << "dy " << std::setw(10) << dy << std::setw(10) << vdy << std::endl;
......@@ -70,7 +73,22 @@ int main(int argc,char ** argv)
if( std::abs(dydy-vdydy) > 1.0e-9 ) error = true, std::cout << "Error in dydy: " << std::abs(dydy-vdydy) << " original " << dydy << " computed " << vdydy << std::endl;
if( std::abs(dydz-vdydz) > 1.0e-9 ) error = true, std::cout << "Error in dydz: " << std::abs(dydz-vdydz) << " original " << dydz << " computed " << vdydz << std::endl;
if( std::abs(dzdz-vdzdz) > 1.0e-9 ) error = true, std::cout << "Error in dzdz: " << std::abs(dzdz-vdzdz) << " original " << dzdz << " computed " << vdzdz << std::endl;
if( error ) return -1;
vdx = f2.GetRow()[0];
vdy = f2.GetRow()[1];
vdz = f2.GetRow()[2];
std::cout << "For GetJacobian:" << std::endl;
std::cout << std::setw(10) << "derivative " << std::setw(10) << "original " << std::setw(10) << "computed" << std::endl;
std::cout << std::setw(10) << "dx " << std::setw(10) << dx << std::setw(10) << vdx << std::endl;
std::cout << std::setw(10) << "dy " << std::setw(10) << dy << std::setw(10) << vdy << std::endl;
std::cout << std::setw(10) << "dz " << std::setw(10) << dz << std::setw(10) << vdz << std::endl;
if (std::abs(dx - vdx) > 1.0e-9) error = true, std::cout << "Error in dx: " << std::abs(dx - vdx) << " original " << dx << " computed " << vdx << std::endl;
if (std::abs(dy - vdy) > 1.0e-9) error = true, std::cout << "Error in dy: " << std::abs(dy - vdy) << " original " << dy << " computed " << vdy << std::endl;
if (std::abs(dz - vdz) > 1.0e-9) error = true, std::cout << "Error in dz: " << std::abs(dz - vdz) << " original " << dz << " computed " << vdz << std::endl;
if (error) return -1;
return 0;
}
project(autodiff_test002)
set(SOURCE main.cpp)
add_executable(autodiff_test002 ${SOURCE})
target_link_libraries(autodiff_test002 inmost)
if(USE_MPI)
message("linking autodiff_test002 with MPI")
target_link_libraries(autodiff_test002 ${MPI_LIBRARIES})
if(MPI_LINK_FLAGS)
set_target_properties(autodiff_test002 PROPERTIES LINK_FLAGS "${MPI_LINK_FLAGS}")
endif()
endif(USE_MPI)
add_test(NAME autodiff_test002_hessian_0 COMMAND $<TARGET_FILE:autodiff_test002> 0)
add_test(NAME autodiff_test002_hessian_1 COMMAND $<TARGET_FILE:autodiff_test002> 1)
add_test(NAME autodiff_test002_hessian_2 COMMAND $<TARGET_FILE:autodiff_test002> 2)
add_test(NAME autodiff_test002_hessian_3 COMMAND $<TARGET_FILE:autodiff_test002> 3)
add_test(NAME autodiff_test002_hessian_4 COMMAND $<TARGET_FILE:autodiff_test002> 4)
add_test(NAME autodiff_test002_hessian_5 COMMAND $<TARGET_FILE:autodiff_test002> 5)
add_test(NAME autodiff_test002_hessian_6 COMMAND $<TARGET_FILE:autodiff_test002> 6)
add_test(NAME autodiff_test002_hessian_7 COMMAND $<TARGET_FILE:autodiff_test002> 7)
This diff is collapsed.
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment