IT++ Logo Newcom Logo

ls_solve.cpp

Go to the documentation of this file.
00001 
00033 #ifndef _MSC_VER
00034 #  include <itpp/config.h>
00035 #else
00036 #  include <itpp/config_msvc.h>
00037 #endif
00038 
00039 #if defined(HAVE_LAPACK)
00040 #  include <itpp/base/lapack.h>
00041 #endif
00042 
00043 #include <itpp/base/ls_solve.h>
00044 
00045 
00046 namespace itpp { 
00047 
00048   // ----------- ls_solve_chol -----------------------------------------------------------
00049 
00050 #if defined(HAVE_LAPACK)
00051 
00052   bool ls_solve_chol(const mat &A, const vec &b, vec &x)
00053   {
00054     int n, lda, ldb, nrhs, info;
00055     n = lda = ldb = A.rows();
00056     nrhs = 1;
00057     char uplo='U';
00058 
00059     it_assert1(A.cols() == n, "ls_solve_chol: System-matrix is not square");
00060     it_assert1(n == b.size(), "The number of rows in A must equal the length of b!");
00061 
00062     ivec ipiv(n);
00063     x = b;
00064     mat Chol = A;
00065 
00066     dposv_(&uplo, &n, &nrhs, Chol._data(), &lda, x._data(), &ldb, &info);
00067 
00068     return (info==0);
00069   }
00070 
00071 
00072   bool ls_solve_chol(const mat &A, const mat &B, mat &X)
00073   {
00074     int n, lda, ldb, nrhs, info;
00075     n = lda = ldb = A.rows();
00076     nrhs = B.cols();
00077     char uplo='U';
00078 
00079     it_assert1(A.cols() == n, "ls_solve_chol: System-matrix is not square");
00080     it_assert1(n == B.rows(), "The number of rows in A must equal the length of B!");
00081 
00082     ivec ipiv(n);
00083     X = B;
00084     mat Chol = A;
00085 
00086     dposv_(&uplo, &n, &nrhs, Chol._data(), &lda, X._data(), &ldb, &info);
00087 
00088     return (info==0);
00089   }
00090 
00091   bool ls_solve_chol(const cmat &A, const cvec &b, cvec &x)
00092   {
00093     int n, lda, ldb, nrhs, info;
00094     n = lda = ldb = A.rows();
00095     nrhs = 1;
00096     char uplo='U';
00097 
00098     it_assert1(A.cols() == n, "ls_solve_chol: System-matrix is not square");
00099     it_assert1(n == b.size(), "The number of rows in A must equal the length of b!");
00100 
00101     ivec ipiv(n);
00102     x = b;
00103     cmat Chol = A;
00104 
00105     zposv_(&uplo, &n, &nrhs, Chol._data(), &lda, x._data(), &ldb, &info);
00106 
00107     return (info==0);
00108   }
00109 
00110   bool ls_solve_chol(const cmat &A, const cmat &B, cmat &X)
00111   {
00112     int n, lda, ldb, nrhs, info;
00113     n = lda = ldb = A.rows();
00114     nrhs = B.cols();
00115     char uplo='U';
00116 
00117     it_assert1(A.cols() == n, "ls_solve_chol: System-matrix is not square");
00118     it_assert1(n == B.rows(), "The number of rows in A must equal the length of B!");
00119 
00120     ivec ipiv(n);
00121     X = B;
00122     cmat Chol = A;
00123 
00124     zposv_(&uplo, &n, &nrhs, Chol._data(), &lda, X._data(), &ldb, &info);
00125 
00126     return (info==0);
00127   }
00128 
00129 #else
00130 
00131   bool ls_solve_chol(const mat &A, const vec &b, vec &x)
00132   {
00133     it_error("LAPACK library is needed to use ls_solve_chol() function");
00134     return false;
00135   }
00136 
00137   bool ls_solve_chol(const mat &A, const mat &B, mat &X)
00138   {
00139     it_error("LAPACK library is needed to use ls_solve_chol() function");
00140     return false;
00141   }
00142 
00143   bool ls_solve_chol(const cmat &A, const cvec &b, cvec &x)
00144   {
00145     it_error("LAPACK library is needed to use ls_solve_chol() function");
00146     return false;
00147   }
00148 
00149   bool ls_solve_chol(const cmat &A, const cmat &B, cmat &X)
00150   {
00151     it_error("LAPACK library is needed to use ls_solve_chol() function");
00152     return false;
00153   }
00154 
00155 #endif // HAVE_LAPACK
00156 
00157   vec ls_solve_chol(const mat &A, const vec &b)
00158   {
00159     vec x;
00160     bool info;
00161     info = ls_solve_chol(A, b, x);
00162     it_assert1(info, "ls_solve_chol: Failed solving the system");
00163     return x;
00164   }
00165 
00166   mat ls_solve_chol(const mat &A, const mat &B)
00167   {
00168     mat X;
00169     bool info;
00170     info = ls_solve_chol(A, B, X);
00171     it_assert1(info, "ls_solve_chol: Failed solving the system");
00172     return X;
00173   }
00174 
00175   cvec ls_solve_chol(const cmat &A, const cvec &b)
00176   {
00177     cvec x;
00178     bool info;
00179     info = ls_solve_chol(A, b, x);
00180     it_assert1(info, "ls_solve_chol: Failed solving the system");
00181     return x;
00182   }
00183 
00184   cmat ls_solve_chol(const cmat &A, const cmat &B)
00185   {
00186     cmat X;
00187     bool info;
00188     info = ls_solve_chol(A, B, X);
00189     it_assert1(info, "ls_solve_chol: Failed solving the system");
00190     return X;
00191   }
00192 
00193 
00194   // --------- ls_solve ---------------------------------------------------------------
00195 #if defined(HAVE_LAPACK)
00196 
00197   bool ls_solve(const mat &A, const vec &b, vec &x)
00198   {
00199     int n, lda, ldb, nrhs, info;
00200     n = lda = ldb = A.rows();
00201     nrhs = 1;
00202 
00203     it_assert1(A.cols() == n, "ls_solve: System-matrix is not square");
00204     it_assert1(n == b.size(), "The number of rows in A must equal the length of b!");
00205 
00206     ivec ipiv(n);
00207     x = b;
00208     mat LU = A;
00209 
00210     dgesv_(&n, &nrhs, LU._data(), &lda, ipiv._data(), x._data(), &ldb, &info);
00211 
00212     return (info==0);
00213   }
00214 
00215   bool ls_solve(const mat &A, const mat &B, mat &X)
00216   {
00217     int n, lda, ldb, nrhs, info;
00218     n = lda = ldb = A.rows();
00219     nrhs = B.cols();
00220 
00221     it_assert1(A.cols() == n, "ls_solve: System-matrix is not square");
00222     it_assert1(n == B.rows(), "The number of rows in A must equal the length of B!");
00223 
00224     ivec ipiv(n);
00225     X = B;
00226     mat LU = A;
00227 
00228     dgesv_(&n, &nrhs, LU._data(), &lda, ipiv._data(), X._data(), &ldb, &info);
00229 
00230     return (info==0);
00231   }
00232 
00233   bool ls_solve(const cmat &A, const cvec &b, cvec &x)
00234   {
00235     int n, lda, ldb, nrhs, info;
00236     n = lda = ldb = A.rows();
00237     nrhs = 1;
00238 
00239     it_assert1(A.cols() == n, "ls_solve: System-matrix is not square");
00240     it_assert1(n == b.size(), "The number of rows in A must equal the length of b!");
00241 
00242     ivec ipiv(n);
00243     x = b;
00244     cmat LU = A;
00245 
00246     zgesv_(&n, &nrhs, LU._data(), &lda, ipiv._data(), x._data(), &ldb, &info);
00247 
00248     return (info==0);
00249   }
00250 
00251   bool ls_solve(const cmat &A, const cmat &B, cmat &X)
00252   {
00253     int n, lda, ldb, nrhs, info;
00254     n = lda = ldb = A.rows();
00255     nrhs = B.cols();
00256 
00257     it_assert1(A.cols() == n, "ls_solve: System-matrix is not square");
00258     it_assert1(n == B.rows(), "The number of rows in A must equal the length of B!");
00259 
00260     ivec ipiv(n);
00261     X = B;
00262     cmat LU = A;
00263 
00264     zgesv_(&n, &nrhs, LU._data(), &lda, ipiv._data(), X._data(), &ldb, &info);
00265 
00266     return (info==0);
00267   }
00268 
00269 #else
00270 
00271   bool ls_solve(const mat &A, const vec &b, vec &x)
00272   {
00273     it_error("LAPACK library is needed to use ls_solve() function");
00274     return false;   
00275   }
00276 
00277   bool ls_solve(const mat &A, const mat &B, mat &X)
00278   {
00279     it_error("LAPACK library is needed to use ls_solve() function");
00280     return false;   
00281   }
00282 
00283   bool ls_solve(const cmat &A, const cvec &b, cvec &x)
00284   {
00285     it_error("LAPACK library is needed to use ls_solve() function");
00286     return false;   
00287   }
00288 
00289   bool ls_solve(const cmat &A, const cmat &B, cmat &X)
00290   {
00291     it_error("LAPACK library is needed to use ls_solve() function");
00292     return false;   
00293   }
00294 
00295 #endif // HAVE_LAPACK
00296 
00297   vec ls_solve(const mat &A, const vec &b)
00298   {
00299     vec x;
00300     bool info;
00301     info = ls_solve(A, b, x);
00302     it_assert1(info, "ls_solve: Failed solving the system");
00303     return x;
00304   }
00305 
00306   mat ls_solve(const mat &A, const mat &B)
00307   {
00308     mat X;
00309     bool info;
00310     info = ls_solve(A, B, X);
00311     it_assert1(info, "ls_solve: Failed solving the system");
00312     return X;
00313   }
00314 
00315   cvec ls_solve(const cmat &A, const cvec &b)
00316   {
00317     cvec x;
00318     bool info;
00319     info = ls_solve(A, b, x);
00320     it_assert1(info, "ls_solve: Failed solving the system");
00321     return x;
00322   }
00323 
00324   cmat ls_solve(const cmat &A, const cmat &B)
00325   {
00326     cmat X;
00327     bool info;
00328     info = ls_solve(A, B, X);
00329     it_assert1(info, "ls_solve: Failed solving the system");
00330     return X;
00331   }
00332 
00333 
00334   // ----------------- ls_solve_od ------------------------------------------------------------------
00335 #if defined(HAVE_LAPACK)
00336 
00337   bool ls_solve_od(const mat &A, const vec &b, vec &x)
00338   {
00339     int m, n, lda, ldb, nrhs, lwork, info;
00340     char trans='N';
00341     m = lda = ldb = A.rows();
00342     n = A.cols();
00343     nrhs = 1;
00344     lwork = n + std::max(m,nrhs);
00345 
00346     it_assert1(m >= n, "The system is under-determined!");
00347     it_assert1(m == b.size(), "The number of rows in A must equal the length of b!");
00348 
00349     vec work(lwork);
00350     x = b;
00351     mat QR = A;
00352 
00353     dgels_(&trans, &m, &n, &nrhs, QR._data(), &lda, x._data(), &ldb, work._data(), &lwork, &info);
00354     x.set_size(n, true);
00355 
00356     return (info==0);
00357   }
00358 
00359   bool ls_solve_od(const mat &A, const mat &B, mat &X)
00360   {
00361     int m, n, lda, ldb, nrhs, lwork, info;
00362     char trans='N';
00363     m = lda = ldb = A.rows();
00364     n = A.cols();
00365     nrhs = B.cols();
00366     lwork = n + std::max(m,nrhs);
00367 
00368     it_assert1(m >= n, "The system is under-determined!");
00369     it_assert1(m == B.rows(), "The number of rows in A must equal the length of b!");
00370 
00371     vec work(lwork);
00372     X = B;
00373     mat QR = A;
00374 
00375     dgels_(&trans, &m, &n, &nrhs, QR._data(), &lda, X._data(), &ldb, work._data(), &lwork, &info);
00376     X.set_size(n, nrhs, true);
00377 
00378     return (info==0);    
00379   }
00380 
00381   bool ls_solve_od(const cmat &A, const cvec &b, cvec &x)
00382   {
00383     int m, n, lda, ldb, nrhs, lwork, info;
00384     char trans='N';
00385     m = lda = ldb = A.rows();
00386     n = A.cols();
00387     nrhs = 1;
00388     lwork = n + std::max(m,nrhs);
00389 
00390     it_assert1(m >= n, "The system is under-determined!");
00391     it_assert1(m == b.size(), "The number of rows in A must equal the length of b!");
00392 
00393     cvec work(lwork);
00394     x = b;
00395     cmat QR = A;
00396 
00397     zgels_(&trans, &m, &n, &nrhs, QR._data(), &lda, x._data(), &ldb, work._data(), &lwork, &info);
00398     x.set_size(n, true);
00399 
00400     return (info==0);
00401   }
00402 
00403   bool ls_solve_od(const cmat &A, const cmat &B, cmat &X)
00404   {
00405     int m, n, lda, ldb, nrhs, lwork, info;
00406     char trans='N';
00407     m = lda = ldb = A.rows();
00408     n = A.cols();
00409     nrhs = B.cols();
00410     lwork = n + std::max(m,nrhs);
00411 
00412     it_assert1(m >= n, "The system is under-determined!");
00413     it_assert1(m == B.rows(), "The number of rows in A must equal the length of b!");
00414 
00415     cvec work(lwork);
00416     X = B;
00417     cmat QR = A;
00418 
00419     zgels_(&trans, &m, &n, &nrhs, QR._data(), &lda, X._data(), &ldb, work._data(), &lwork, &info);
00420     X.set_size(n, nrhs, true);
00421 
00422     return (info==0);    
00423   }
00424 
00425 #else
00426 
00427   bool ls_solve_od(const mat &A, const vec &b, vec &x)
00428   {
00429     it_error("LAPACK library is needed to use ls_solve_od() function");
00430     return false;   
00431   }
00432 
00433   bool ls_solve_od(const mat &A, const mat &B, mat &X)
00434   { 
00435     it_error("LAPACK library is needed to use ls_solve_od() function");
00436     return false;   
00437   }
00438 
00439   bool ls_solve_od(const cmat &A, const cvec &b, cvec &x)
00440   {
00441     it_error("LAPACK library is needed to use ls_solve_od() function");
00442     return false;   
00443   }
00444 
00445   bool ls_solve_od(const cmat &A, const cmat &B, cmat &X)
00446   {
00447     it_error("LAPACK library is needed to use ls_solve_od() function");
00448     return false;   
00449   }
00450 
00451 #endif // HAVE_LAPACK
00452 
00453   vec ls_solve_od(const mat &A, const vec &b)
00454   {
00455     vec x;
00456     bool info;
00457     info = ls_solve_od(A, b, x);
00458     it_assert1(info, "ls_solve_od: Failed solving the system");
00459     return x;
00460   }
00461 
00462   mat ls_solve_od(const mat &A, const mat &B)
00463   {
00464     mat X;
00465     bool info;
00466     info = ls_solve_od(A, B, X);
00467     it_assert1(info, "ls_solve_od: Failed solving the system");
00468     return X;
00469   }
00470 
00471   cvec ls_solve_od(const cmat &A, const cvec &b)
00472   {
00473     cvec x;
00474     bool info;
00475     info = ls_solve_od(A, b, x);
00476     it_assert1(info, "ls_solve_od: Failed solving the system");
00477     return x;
00478   }
00479 
00480   cmat ls_solve_od(const cmat &A, const cmat &B)
00481   {
00482     cmat X;
00483     bool info;
00484     info = ls_solve_od(A, B, X);
00485     it_assert1(info, "ls_solve_od: Failed solving the system");
00486     return X;
00487   }
00488 
00489   // ------------------- ls_solve_ud -----------------------------------------------------------
00490 #if defined(HAVE_LAPACK)
00491 
00492   bool ls_solve_ud(const mat &A, const vec &b, vec &x)
00493   {
00494     int m, n, lda, ldb, nrhs, lwork, info;
00495     char trans='N';
00496     m = lda = A.rows();
00497     n = A.cols();
00498     ldb = n;
00499     nrhs = 1;
00500     lwork = m + std::max(n,nrhs);
00501 
00502     it_assert1(m < n, "The system is over-determined!");
00503     it_assert1(m == b.size(), "The number of rows in A must equal the length of b!");
00504 
00505     vec work(lwork);
00506     x = b;
00507     x.set_size(n, true);
00508     mat QR = A;
00509 
00510     dgels_(&trans, &m, &n, &nrhs, QR._data(), &lda, x._data(), &ldb, work._data(), &lwork, &info);
00511 
00512     return (info==0);
00513   }
00514 
00515   bool ls_solve_ud(const mat &A, const mat &B, mat &X)
00516   {
00517     int m, n, lda, ldb, nrhs, lwork, info;
00518     char trans='N';
00519     m = lda = A.rows();
00520     n = A.cols();
00521     ldb = n;
00522     nrhs = B.cols();
00523     lwork = m + std::max(n,nrhs);
00524 
00525     it_assert1(m < n, "The system is over-determined!");
00526     it_assert1(m == B.rows(), "The number of rows in A must equal the length of b!");
00527 
00528     vec work(lwork);
00529     X = B;
00530     X.set_size(n, std::max(m, nrhs), true);
00531     mat QR = A;
00532 
00533     dgels_(&trans, &m, &n, &nrhs, QR._data(), &lda, X._data(), &ldb, work._data(), &lwork, &info);
00534     X.set_size(n, nrhs, true);
00535 
00536     return (info==0);    
00537   }
00538 
00539   bool ls_solve_ud(const cmat &A, const cvec &b, cvec &x)
00540   {
00541     int m, n, lda, ldb, nrhs, lwork, info;
00542     char trans='N';
00543     m = lda = A.rows();
00544     n = A.cols();
00545     ldb = n;
00546     nrhs = 1;
00547     lwork = m + std::max(n,nrhs);
00548 
00549     it_assert1(m < n, "The system is over-determined!");
00550     it_assert1(m == b.size(), "The number of rows in A must equal the length of b!");
00551 
00552     cvec work(lwork);
00553     x = b;
00554     x.set_size(n, true);
00555     cmat QR = A;
00556 
00557     zgels_(&trans, &m, &n, &nrhs, QR._data(), &lda, x._data(), &ldb, work._data(), &lwork, &info);
00558 
00559     return (info==0);
00560   }
00561 
00562   bool ls_solve_ud(const cmat &A, const cmat &B, cmat &X)
00563   {
00564     int m, n, lda, ldb, nrhs, lwork, info;
00565     char trans='N';
00566     m = lda = A.rows();
00567     n = A.cols();
00568     ldb = n;
00569     nrhs = B.cols();
00570     lwork = m + std::max(n,nrhs);
00571 
00572     it_assert1(m < n, "The system is over-determined!");
00573     it_assert1(m == B.rows(), "The number of rows in A must equal the length of b!");
00574 
00575     cvec work(lwork);
00576     X = B;
00577     X.set_size(n, std::max(m, nrhs), true);
00578     cmat QR = A;
00579 
00580     zgels_(&trans, &m, &n, &nrhs, QR._data(), &lda, X._data(), &ldb, work._data(), &lwork, &info);
00581     X.set_size(n, nrhs, true);
00582 
00583     return (info==0);    
00584   }
00585 
00586 #else
00587 
00588   bool ls_solve_ud(const mat &A, const vec &b, vec &x)
00589   {
00590     it_error("LAPACK library is needed to use ls_solve_ud() function");
00591     return false;   
00592   }
00593 
00594   bool ls_solve_ud(const mat &A, const mat &B, mat &X)
00595   {
00596     it_error("LAPACK library is needed to use ls_solve_ud() function");
00597     return false;   
00598   }
00599 
00600   bool ls_solve_ud(const cmat &A, const cvec &b, cvec &x)
00601   {
00602     it_error("LAPACK library is needed to use ls_solve_ud() function");
00603     return false;   
00604   }
00605 
00606   bool ls_solve_ud(const cmat &A, const cmat &B, cmat &X)
00607   {
00608     it_error("LAPACK library is needed to use ls_solve_ud() function");
00609     return false;   
00610   }
00611 
00612 #endif // HAVE_LAPACK
00613 
00614 
00615   vec ls_solve_ud(const mat &A, const vec &b)
00616   {
00617     vec x;
00618     bool info;
00619     info = ls_solve_ud(A, b, x);
00620     it_assert1(info, "ls_solve_ud: Failed solving the system");
00621     return x;
00622   }
00623 
00624   mat ls_solve_ud(const mat &A, const mat &B)
00625   {
00626     mat X;
00627     bool info;
00628     info = ls_solve_ud(A, B, X);
00629     it_assert1(info, "ls_solve_ud: Failed solving the system");
00630     return X;
00631   }
00632 
00633   cvec ls_solve_ud(const cmat &A, const cvec &b)
00634   {
00635     cvec x;
00636     bool info;
00637     info = ls_solve_ud(A, b, x);
00638     it_assert1(info, "ls_solve_ud: Failed solving the system");
00639     return x;
00640   }
00641 
00642   cmat ls_solve_ud(const cmat &A, const cmat &B)
00643   {
00644     cmat X;
00645     bool info;
00646     info = ls_solve_ud(A, B, X);
00647     it_assert1(info, "ls_solve_ud: Failed solving the system");
00648     return X;
00649   }
00650 
00651 
00652   // ---------------------- backslash -----------------------------------------
00653 
00654   bool backslash(const mat &A, const vec &b, vec &x)
00655   {
00656     int m=A.rows(), n=A.cols();
00657     bool info;
00658 
00659     if (m == n)
00660       info = ls_solve(A,b,x);
00661     else if (m > n)
00662       info = ls_solve_od(A,b,x);
00663     else
00664       info = ls_solve_ud(A,b,x);
00665     
00666     return info;
00667   }
00668 
00669 
00670   vec backslash(const mat &A, const vec &b)
00671   {
00672     vec x;
00673     bool info;
00674     info = backslash(A, b, x);
00675     it_assert1(info, "backslash(): solution was not found");
00676     return x;
00677   }
00678 
00679 
00680   bool backslash(const mat &A, const mat &B, mat &X)
00681   {
00682     int m=A.rows(), n=A.cols();
00683     bool info;
00684 
00685     if (m == n)
00686       info = ls_solve(A, B, X);
00687     else if (m > n)
00688       info = ls_solve_od(A, B, X);
00689     else
00690       info = ls_solve_ud(A, B, X);
00691     
00692     return info;
00693   }
00694 
00695 
00696   mat backslash(const mat &A, const mat &B)
00697   {
00698     mat X;
00699     bool info;
00700     info = backslash(A, B, X);
00701     it_assert1(info, "backslash(): solution was not found");
00702     return X;
00703   }
00704 
00705 
00706   bool backslash(const cmat &A, const cvec &b, cvec &x)
00707   {
00708     int m=A.rows(), n=A.cols();
00709     bool info;
00710 
00711     if (m == n)
00712       info = ls_solve(A,b,x);
00713     else if (m > n)
00714       info = ls_solve_od(A,b,x);
00715     else
00716       info = ls_solve_ud(A,b,x);
00717     
00718     return info;
00719   }
00720 
00721 
00722   cvec backslash(const cmat &A, const cvec &b)
00723   {
00724     cvec x;
00725     bool info;
00726     info = backslash(A, b, x);
00727     it_assert1(info, "backslash(): solution was not found");
00728     return x;
00729   }
00730 
00731 
00732   bool backslash(const cmat &A, const cmat &B, cmat &X)
00733   {
00734     int m=A.rows(), n=A.cols();
00735     bool info;
00736 
00737     if (m == n)
00738       info = ls_solve(A, B, X);
00739     else if (m > n)
00740       info = ls_solve_od(A, B, X);
00741     else
00742       info = ls_solve_ud(A, B, X);
00743     
00744     return info;
00745   }
00746 
00747   cmat backslash(const cmat &A, const cmat &B)
00748   {
00749     cmat X;
00750     bool info;
00751     info = backslash(A, B, X);
00752     it_assert1(info, "backslash(): solution was not found");
00753     return X;
00754   }
00755 
00756 
00757   // --------------------------------------------------------------------------
00758 
00759   vec forward_substitution(const mat &L, const vec &b)
00760   {
00761     int n = L.rows();
00762     vec x(n);
00763   
00764     forward_substitution(L, b, x);
00765 
00766     return x;
00767   }
00768 
00769   void forward_substitution(const mat &L, const vec &b, vec &x)
00770   {
00771     it_assert( L.rows() == L.cols() && L.cols() == b.size() && b.size() == x.size(), 
00772                "forward_substitution: dimension mismatch" );
00773     int n = L.rows(), i, j;
00774     double temp;
00775 
00776     x(0)=b(0)/L(0,0);
00777     for (i=1;i<n;i++) {
00778       // Should be: x(i)=((b(i)-L(i,i,0,i-1)*x(0,i-1))/L(i,i))(0); but this is to slow.
00779       //i_pos=i*L._row_offset();
00780       temp=0;
00781       for (j=0; j<i; j++) {
00782         temp += L._elem(i,j) * x(j);
00783         //temp+=L._data()[i_pos+j]*x(j);
00784       }
00785       x(i) = (b(i)-temp)/L._elem(i,i);
00786       //x(i)=(b(i)-temp)/L._data()[i_pos+i];
00787     }
00788   }
00789 
00790   vec forward_substitution(const mat &L, int p, const vec &b)
00791   {
00792     int n = L.rows();
00793     vec x(n);
00794   
00795     forward_substitution(L, p, b, x);
00796 
00797     return x;
00798   }
00799 
00800   void forward_substitution(const mat &L, int p, const vec &b, vec &x)
00801   {
00802     it_assert( L.rows() == L.cols() && L.cols() == b.size() && b.size() == x.size() && p <= L.rows()/2,
00803                "forward_substitution: dimension mismatch");
00804     int n = L.rows(), i, j;
00805 
00806     x=b;
00807   
00808     for (j=0;j<n;j++) {
00809       x(j)/=L(j,j);
00810       for (i=j+1;i<std::min(j+p+1,n);i++) {
00811         x(i)-=L(i,j)*x(j);
00812       }
00813     }
00814   }
00815 
00816   vec backward_substitution(const mat &U, const vec &b)
00817   {
00818     vec x(U.rows());
00819     backward_substitution(U, b, x);
00820 
00821     return x;
00822   }
00823 
00824   void backward_substitution(const mat &U, const vec &b, vec &x)
00825   {
00826     it_assert( U.rows() == U.cols() && U.cols() == b.size() && b.size() == x.size(),
00827                "backward_substitution: dimension mismatch" );
00828     int n = U.rows(), i, j;
00829     double temp;
00830 
00831     x(n-1)=b(n-1)/U(n-1,n-1);
00832     for (i=n-2; i>=0; i--) {
00833       // Should be: x(i)=((b(i)-U(i,i,i+1,n-1)*x(i+1,n-1))/U(i,i))(0); but this is too slow.
00834       temp=0;
00835       //i_pos=i*U._row_offset();
00836       for (j=i+1; j<n; j++) {
00837         temp += U._elem(i,j) * x(j);
00838         //temp+=U._data()[i_pos+j]*x(j);
00839       }
00840       x(i) = (b(i)-temp)/U._elem(i,i);
00841       //x(i)=(b(i)-temp)/U._data()[i_pos+i];
00842     }
00843   }
00844 
00845   vec backward_substitution(const mat &U, int q, const vec &b)
00846   {
00847     vec x(U.rows());
00848     backward_substitution(U, q, b, x);
00849 
00850     return x;
00851   }
00852 
00853   void backward_substitution(const mat &U, int q, const vec &b, vec &x)
00854   {
00855     it_assert( U.rows() == U.cols() && U.cols() == b.size() && b.size() == x.size() && q <= U.rows()/2,
00856                "backward_substitution: dimension mismatch" );
00857     int n = U.rows(), i, j;
00858 
00859     x=b;
00860   
00861     for (j=n-1; j>=0; j--) {
00862       x(j) /= U(j,j);
00863       for (i=std::max(0,j-q); i<j; i++) {
00864         x(i)-=U(i,j)*x(j);
00865       }
00866     }
00867   }
00868 
00869 } // namespace itpp
SourceForge Logo

Generated on Fri Jun 8 01:07:08 2007 for IT++ by Doxygen 1.5.2