Commit a388b869 authored by Kirill Terekhov's avatar Kirill Terekhov

Fix algorithm for hessian calculation

Fix for procedure merging hessians and multiplication of jacobins

Few more unit tests for hessian
parent 8f5f181f
......@@ -460,20 +460,24 @@ namespace INMOST
__INLINE hessian_multivar_expression & operator +=(basic_expression const & expr)
{
value += expr.GetValue();
Sparse::Row tmp(entries);
Sparse::HessianRow htmp(hessian_entries);
expr.GetHessian(1.0,tmp,1.0,htmp);
Sparse::Row tmpr, tmp;
Sparse::HessianRow htmpr, htmp;
expr.GetHessian(1.0,tmpr,1.0,htmpr);
Sparse::Row::MergeSortedRows(1.0,entries,1.0,tmpr,tmp);
entries.Swap(tmp);
Sparse::HessianRow::MergeSortedRows(1.0,hessian_entries,1.0,htmpr,htmp);
hessian_entries.Swap(htmp);
return *this;
}
__INLINE hessian_multivar_expression & operator -=(basic_expression const & expr)
{
value -= expr.GetValue();
Sparse::Row tmp(entries);
Sparse::HessianRow htmp(hessian_entries);
expr.GetHessian(-1.0,tmp,-1.0,htmp);
Sparse::Row tmpr, tmp;
Sparse::HessianRow htmpr, htmp;
expr.GetHessian(1.0,tmpr,1.0,htmpr);
Sparse::Row::MergeSortedRows(1.0,entries,-1.0,tmpr,tmp);
entries.Swap(tmp);
Sparse::HessianRow::MergeSortedRows(1.0,hessian_entries,-1.0,htmpr,htmp);
hessian_entries.Swap(htmp);
return *this;
}
......@@ -1046,9 +1050,7 @@ namespace INMOST
__INLINE void GetHessian(INMOST_DATA_REAL_TYPE multJ, Sparse::Row & J, INMOST_DATA_REAL_TYPE multH, Sparse::HessianRow & H) const
{
Sparse::HessianRow htmp;
arg.GetHessian(multJ,J,multH,htmp);
Sparse::HessianRow::MergeJacobianHessian(-value,J,J,dmult,htmp,H);
for(Sparse::Row::iterator it = J.Begin(); it != J.End(); ++it) it->second*=dmult;
arg.GetHessian(1,J,1,htmp);
}
};
......@@ -1078,9 +1080,6 @@ namespace INMOST
{
//arg.GetHessian(multJ*dmult,J,-multH*value,H);
Sparse::HessianRow htmp;
arg.GetHessian(multJ,J,multH,htmp);
Sparse::HessianRow::MergeJacobianHessian(-value,J,J,dmult,htmp,H);
for(Sparse::Row::iterator it = J.Begin(); it != J.End(); ++it) it->second*=dmult;
}
};
......@@ -1109,8 +1108,6 @@ namespace INMOST
//general formula:
// (F(G))'' = F'(G) G'' + F''(G) G'.G'
Sparse::HessianRow htmp;
arg.GetHessian(multJ,J,multH,htmp);
Sparse::HessianRow::MergeJacobianHessian(-0.25/::pow(value,3.0),J,J,0.5/value,htmp,H);
for(Sparse::Row::iterator it = J.Begin(); it != J.End(); ++it) it->second *= 0.5/value;
//arg.GetHessian(0.5*multJ/value,J,-0.25*multH/::pow(value,3),H);
}
......@@ -1272,16 +1269,12 @@ namespace INMOST
{
Sparse::Row JL, JR; //temporary jacobian rows from left and right expressions
Sparse::HessianRow HL, HR; //temporary hessian rows form left and right expressions
left.GetHessian(multJ,JL,multH,HL); //retrive jacobian row and hessian matrix of the left expression
right.GetHessian(multJ,JR,multH,HR); //retrive jacobian row and hessian matrix of the right expression
//assume rows are sorted (this is to be ensured by corresponding GetHessian functions)
//preallocate J to JL.Size+JR.Size
//perform merging of two sorted arrays
//resize to correct size
Sparse::Row::MergeSortedRows(right.GetValue(),JL,left.GetValue(),JR,J);
//preallocate H to HL.Size+HR.Size+JL.Size*JR.Size
//merge sorted
Sparse::HessianRow::MergeJacobianHessian(2.0,JL,JR,right.GetValue(),HL,left.GetValue(),HR,H);
}
};
......@@ -1357,10 +1350,6 @@ namespace INMOST
{
Sparse::Row JL, JR; //temporary jacobian rows from left and right expressions
Sparse::HessianRow HL, HR; //temporary hessian rows form left and right expressions
left.GetHessian(multJ,JL,multH,HL); //retrive jacobian row and hessian matrix of the left expression
right.GetHessian(multJ,JR,multH,HR); //retrive jacobian row and hessian matrix of the right expression
Sparse::Row::MergeSortedRows(1.0,JL,1.0,JR,J);
Sparse::HessianRow::MergeSortedRows(1.0,HL,1.0,HR,H);
}
};
......@@ -1392,10 +1381,6 @@ namespace INMOST
{
Sparse::Row JL, JR; //temporary jacobian rows from left and right expressions
Sparse::HessianRow HL, HR; //temporary hessian rows form left and right expressions
left.GetHessian(multJ,JL,multH,HL); //retrive jacobian row and hessian matrix of the left expression
right.GetHessian(multJ,JR,multH,HR); //retrive jacobian row and hessian matrix of the right expression
Sparse::Row::MergeSortedRows(1.0,JL,-1.0,JR,J);
Sparse::HessianRow::MergeSortedRows(1.0,HL,-1.0,HR,H);
}
};
......
......@@ -319,9 +319,9 @@ namespace INMOST
/// that allow for the modification of individual entries.
/// @param size New size of the row.
void Resize(INMOST_DATA_ENUM_TYPE size) {data.resize(size);}
void Print()
void Print() const
{
for(iterator it = Begin(); it != End(); ++it) std::cout << "(" << it->first.first << "," << it->first.second << "," << it->second << ") ";
for(const_iterator it = Begin(); it != End(); ++it) std::cout << "(" << it->first.first << "," << it->first.second << "," << it->second << ") ";
std::cout << std::endl;
}
bool isSorted() const;
......
......@@ -401,20 +401,23 @@ namespace INMOST
assert(HL.isSorted());
assert(HR.isSorted());
output.Resize(HL.Size()+HR.Size()+JL.Size()*JR.Size());
INMOST_DATA_ENUM_TYPE i = 0, j = 0, k = 0, l = 0, q = 0, kk = 0, ll = 0, r;
entry candidate[3] = {stub_entry,stub_entry,stub_entry};
INMOST_DATA_ENUM_TYPE i = 0, j = 0, k1 = 0, l1 = 0, k2 = 0, l2 = 0, q = 0, kk1 = 0, ll2 = 0, r;
entry candidate[4] = {stub_entry,stub_entry,stub_entry,stub_entry};
if( i < HL.Size() )
candidate[0] = make_entry(HL.GetIndex(i),b*HL.GetValue(i));
if( j < HR.Size() )
candidate[1] = make_entry(HR.GetIndex(j),c*HR.GetValue(j));
if( k < JL.Size() && l < JR.Size() )
candidate[2] = make_entry(make_index(JL.GetIndex(kk),JR.GetIndex(ll)),a*JL.GetValue(kk)*JR.GetValue(ll));
if( k1 < JL.Size() && l1 < JR.Size() )
candidate[2] = make_entry(make_index(JL.GetIndex(k1),JR.GetIndex(l1)),0.5*a*JL.GetValue(k1)*JR.GetValue(l1));
if( k2 < JL.Size() && l2 < JR.Size() )
candidate[3] = make_entry(make_index(JL.GetIndex(k2),JR.GetIndex(l2)),0.5*a*JL.GetValue(k2)*JR.GetValue(l2));
do
{
//pick smallest
r = 0;
if( candidate[1].first < candidate[r].first ) r = 1;
if( candidate[2].first < candidate[r].first ) r = 2;
if( candidate[3].first < candidate[r].first ) r = 3;
//all candidates are stub - exit
if( candidate[r].first == stub_entry.first ) break;
//record selected entry
......@@ -438,39 +441,37 @@ namespace INMOST
candidate[1] = make_entry(HR.GetIndex(j),c*HR.GetValue(j));
else candidate[1] = stub_entry;
}
else //update jacobians indexes
else if( r == 2 ) //update jacobians indexes
{
if( JR.GetIndex(l) < JL.GetIndex(k) )
if( ++k1 == JL.Size() )
{
if( ++kk == JL.Size() )
++l1;
if( l1 < JR.Size() )
{
++l;
kk = k;
ll = l;
while(kk1 < JL.Size() && JL.GetIndex(kk1) < JR.GetIndex(l1) ) kk1++;
k1 = kk1;
}
}
else if( JL.GetIndex(k) < JR.GetIndex(l) )
{
if( ++ll == JR.Size() )
{
++k;
kk = k;
ll = l;
}
}
else //values are equal
if( k1 < JL.Size() && l1 < JR.Size() )
candidate[2] = make_entry(make_index(JL.GetIndex(k1),JR.GetIndex(l1)),(k1==l1?0.5:1)*a*JL.GetValue(k1)*JR.GetValue(l1));
else
candidate[2] = stub_entry;
}
else //update jacobians indexes
{
if( ++l2 == JR.Size() )
{
if( ++ll == JR.Size() )
++k2;
if( k2 < JL.Size() )
{
++k;
kk = k;
ll = l;
while(ll2 < JR.Size() && JL.GetIndex(k2) > JR.GetIndex(ll2) ) ll2++;
l2 = ll2;
}
}
if( kk < JL.Size() && ll < JR.Size() )
candidate[2] = make_entry(make_index(JL.GetIndex(kk),JR.GetIndex(ll)),a*JL.GetValue(kk)*JR.GetValue(ll));
if( k2 < JL.Size() && l2 < JR.Size() )
candidate[3] = make_entry(make_index(JL.GetIndex(k2),JR.GetIndex(l2)),(k2==l2?0.5:1)*a*JL.GetValue(k2)*JR.GetValue(l2));
else
candidate[2] = stub_entry;
candidate[3] = stub_entry;
}
}
while(true);
......@@ -486,22 +487,27 @@ namespace INMOST
assert(JR.isSorted());
assert(H.isSorted());
output.Resize(H.Size()+JL.Size()*JR.Size());
INMOST_DATA_ENUM_TYPE i = 0, k = 0, l = 0, q = 0, kk = 0, ll = 0, r;
entry candidate[2] = {stub_entry,stub_entry};
INMOST_DATA_ENUM_TYPE i = 0, k1 = 0, l1 = 0, q = 0, k2 = 0, l2 = 0, r, ll2 = 0, kk1 = 0;
entry candidate[3] = {stub_entry,stub_entry,stub_entry};
if( i < H.Size() )
candidate[0] = make_entry(H.GetIndex(i),b*H.GetValue(i));
if( k < JL.Size() && l < JR.Size() )
candidate[1] = make_entry(make_index(JL.GetIndex(kk),JR.GetIndex(ll)),a*JL.GetValue(kk)*JR.GetValue(ll));
if( k1 < JL.Size() && l1 < JR.Size() )
candidate[1] = make_entry(make_index(JL.GetIndex(k1),JR.GetIndex(l1)),0.5*a*JL.GetValue(k1)*JR.GetValue(l1));
if( k2 < JL.Size() && l2 < JR.Size() )
candidate[2] = make_entry(make_index(JL.GetIndex(k2),JR.GetIndex(l2)),0.5*a*JL.GetValue(k2)*JR.GetValue(l2));
do
{
//pick smallest
r = 0;
if( candidate[1].first < candidate[r].first ) r = 1;
if( candidate[2].first < candidate[r].first ) r = 2;
//all candidates are stub - exit
if( candidate[r].first == stub_entry.first ) break;
//record selected entry
if( q > 0 && (output.GetIndex(q-1) == candidate[r].first) )
{
output.GetValue(q-1) += candidate[r].second;
}
else
{
output.GetIndex(q) = candidate[r].first;
......@@ -514,39 +520,37 @@ namespace INMOST
candidate[0] = make_entry(H.GetIndex(i),b*H.GetValue(i));
else candidate[0] = stub_entry;
}
else //update jacobians indexes
else if( r == 1 )
{
if( JR.GetIndex(l) < JL.GetIndex(k) )
{
if( ++kk == JL.Size() )
{
++l;
kk = k;
ll = l;
}
}
else if( JL.GetIndex(k) < JR.GetIndex(l) )
if( ++k1 == JL.Size() )
{
if( ++ll == JR.Size() )
++l1;
if( l1 < JR.Size() )
{
++k;
kk = k;
ll = l;
while(kk1 < JL.Size() && JL.GetIndex(kk1) < JR.GetIndex(l1) ) kk1++;
k1 = kk1;
}
}
else //values are equal
if( k1 < JL.Size() && l1 < JR.Size() )
candidate[1] = make_entry(make_index(JL.GetIndex(k1),JR.GetIndex(l1)),(k1==l1?0.5:1)*a*JL.GetValue(k1)*JR.GetValue(l1));
else
candidate[1] = stub_entry;
}
else //update jacobians indexes
{
if( ++l2 == JR.Size() )
{
if( ++ll == JR.Size() )
++k2;
if( k2 < JL.Size() )
{
++k;
kk = k;
ll = l;
while(ll2 < JR.Size() && JL.GetIndex(k2) > JR.GetIndex(ll2) ) ll2++;
l2 = ll2;
}
}
if( kk < JL.Size() && ll < JR.Size() )
candidate[1] = make_entry(make_index(JL.GetIndex(kk),JR.GetIndex(ll)),a*JL.GetValue(kk)*JR.GetValue(ll));
if( k2 < JL.Size() && l2 < JR.Size() )
candidate[2] = make_entry(make_index(JL.GetIndex(k2),JR.GetIndex(l2)),(k2==l2?0.5:1)*a*JL.GetValue(k2)*JR.GetValue(l2));
else
candidate[1] = stub_entry;
candidate[2] = stub_entry;
}
}
while(true);
......
......@@ -20,3 +20,6 @@ add_test(NAME autodiff_test000_hessian_cos_mixed COMMAND $<TARGET_FILE
add_test(NAME autodiff_test000_hessian_sin_mult COMMAND $<TARGET_FILE:autodiff_test000> 3)
add_test(NAME autodiff_test000_hessian_cos_mult COMMAND $<TARGET_FILE:autodiff_test000> 4)
add_test(NAME autodiff_test000_hessian_sqrt_mixed COMMAND $<TARGET_FILE:autodiff_test000> 5)
add_test(NAME autodiff_test000_hessian_poly COMMAND $<TARGET_FILE:autodiff_test000> 6)
add_test(NAME autodiff_test000_hessian_sin_poly COMMAND $<TARGET_FILE:autodiff_test000> 7)
add_test(NAME autodiff_test000_hessian_minus_sin_poly COMMAND $<TARGET_FILE:autodiff_test000> 8)
......@@ -128,7 +128,34 @@ int main(int argc,char ** argv)
_dydy = 3*_x*_x/(4*pow(_x*_x+_x*_y+_y*_y,1.5));
f = sqrt(x*x+y*y+x*y);
}
else if( test == 6 )
{
_dx = 4*_x + 4*_y;
_dy = 6*_y + 4*_x;
_dxdx = 4;
_dxdy = 8;
_dydy = 6;
f = 2*x*x+3*y*y+4*x*y;
}
else if( test == 7 )
{
_dx = 4 *(_x - 0.5)*cos(2*((_x - 0.5)*(_x - 0.5) + (_y - 0.5)*(_y - 0.5)));
_dy = 4 *(_y - 0.5)*cos(2*((_x - 0.5)*(_x - 0.5) + (_y - 0.5)*(_y - 0.5)));
_dxdx = 4*cos(2*((_x - 0.5)*(_x - 0.5) + (_y - 0.5)*(_y - 0.5))) - 16*(_x - 0.5)*(_x - 0.5)*sin(2.0*((_x - 0.5)*(_x - 0.5) + (_y - 0.5)*(_y - 0.5)));
_dxdy = -32*(_x-0.5)*(_y-0.5)*sin(2*((_x-0.5)*(_x-0.5)+(_y-0.5)*(_y-0.5)));
_dydy = 4*cos(2*((_x - 0.5)*(_x - 0.5) + (_y - 0.5)*(_y - 0.5))) - 16*(_y - 0.5)*(_y - 0.5)*sin(2.0*((_x - 0.5)*(_x - 0.5) + (_y - 0.5)*(_y - 0.5)));
f = sin(2*((x-0.5)*(x-0.5)+(y-0.5)*(y-0.5)));
}
else if( test == 8 )
{
_dx = -(4 *(_x - 0.5)*cos(2*((_x - 0.5)*(_x - 0.5) + (_y - 0.5)*(_y - 0.5))));
_dy = -(4 *(_y - 0.5)*cos(2*((_x - 0.5)*(_x - 0.5) + (_y - 0.5)*(_y - 0.5))));
_dxdx = -(4*cos(2*((_x - 0.5)*(_x - 0.5) + (_y - 0.5)*(_y - 0.5))) - 16*(_x - 0.5)*(_x - 0.5)*sin(2.0*((_x - 0.5)*(_x - 0.5) + (_y - 0.5)*(_y - 0.5))));
_dxdy = -(-32*(_x-0.5)*(_y-0.5)*sin(2*((_x-0.5)*(_x-0.5)+(_y-0.5)*(_y-0.5))));
_dydy = -(4*cos(2*((_x - 0.5)*(_x - 0.5) + (_y - 0.5)*(_y - 0.5))) - 16*(_y - 0.5)*(_y - 0.5)*sin(2.0*((_x - 0.5)*(_x - 0.5) + (_y - 0.5)*(_y - 0.5))));
f = 1.0-sin(2*((x-0.5)*(x-0.5)+(y-0.5)*(y-0.5)));
}
double dx = f.GetRow()[0];
double dy = f.GetRow()[1];
double dxdx = f.GetHessianRow()[Sparse::HessianRow::make_index(0,0)];
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment