37 #if defined(GMM_USES_SUPERLU)
39 #ifndef GMM_SUPERLU_INTERFACE_H
40 #define GMM_SUPERLU_INTERFACE_H
59 #if defined(GMM_NO_SUPERLU_INCLUDE_SUBDIR)
60 #include "slu_Cnames.h"
61 #include "supermatrix.h"
63 #include "slu_scomplex.h"
64 #include "slu_dcomplex.h"
66 #include "superlu/slu_Cnames.h"
67 #include "superlu/supermatrix.h"
68 #include "superlu/slu_util.h"
69 #include "superlu/slu_scomplex.h"
70 #include "superlu/slu_dcomplex.h"
73 #if (SUPERLU_MAJOR_VERSION <= 6)
74 # define singlecomplex complex
// Declarations of the SuperLU simple-driver routines, one per scalar type:
// sgssv (float), dgssv (double), cgssv (std::complex<float>) and
// zgssv (std::complex<double>), as paired up by the DECL_GSSV instantiations
// later in this file. The SuperLU_gssv wrappers forward their arguments
// (options, A, p, q, L, U, B, stats, info) to these in exactly this order.
// NOTE(review): the line carrying the return type / linkage of each
// declaration is not visible in this listing.
87 sgssv(superlu_options_t *, SuperMatrix *,
int *,
int *, SuperMatrix *,
88 SuperMatrix *, SuperMatrix *, SuperLUStat_t *, int_t *info);
90 dgssv(superlu_options_t *, SuperMatrix *,
int *,
int *, SuperMatrix *,
91 SuperMatrix *, SuperMatrix *, SuperLUStat_t *, int_t *info);
93 cgssv(superlu_options_t *, SuperMatrix *,
int *,
int *, SuperMatrix *,
94 SuperMatrix *, SuperMatrix *, SuperLUStat_t *, int_t *info);
96 zgssv(superlu_options_t *, SuperMatrix *,
int *,
int *, SuperMatrix *,
97 SuperMatrix *, SuperMatrix *, SuperLUStat_t *, int_t *info);
// Declarations of the SuperLU expert-driver routines, one per scalar type
// (s = float, d = double, c = std::complex<float>, z = std::complex<double>,
// as paired up by the DECL_GSSVX instantiations later in this file).
// The SuperLU_gssvx wrappers call them with, in order: options, A, perm_c,
// perm_r, etree, equed, R, C, L, U, work, lwork, B, X, recip_pivot_growth,
// rcond, ferr, berr, a GlobalLU_t, a mem_usage_t, stats and info.
// Note that the complex variants (c/z) still use real (float/double) arrays
// for the scaling and error-estimate arguments.
// NOTE(review): the line carrying the return type / linkage of each
// declaration is not visible in this listing.
99 sgssvx(superlu_options_t *, SuperMatrix *,
int *,
int *,
int *,
100 char *,
float *,
float *, SuperMatrix *, SuperMatrix *,
101 void *, int_t lwork, SuperMatrix *, SuperMatrix *,
102 float *,
float *,
float *,
float *,
103 GlobalLU_t *, mem_usage_t *, SuperLUStat_t *, int_t *info);
105 dgssvx(superlu_options_t *, SuperMatrix *,
int *,
int *,
int *,
106 char *,
double *,
double *, SuperMatrix *, SuperMatrix *,
107 void *, int_t lwork, SuperMatrix *, SuperMatrix *,
108 double *,
double *,
double *,
double *,
109 GlobalLU_t *, mem_usage_t *, SuperLUStat_t *, int_t *info);
111 cgssvx(superlu_options_t *, SuperMatrix *,
int *,
int *,
int *,
112 char *,
float *,
float *, SuperMatrix *, SuperMatrix *,
113 void *, int_t lwork, SuperMatrix *, SuperMatrix *,
114 float *,
float *,
float *,
float *,
115 GlobalLU_t *, mem_usage_t *, SuperLUStat_t *, int_t *info);
117 zgssvx(superlu_options_t *, SuperMatrix *,
int *,
int *,
int *,
118 char *,
double *,
double *, SuperMatrix *, SuperMatrix *,
119 void *, int_t lwork, SuperMatrix *, SuperMatrix *,
120 double *,
double *,
double *,
double *,
121 GlobalLU_t *, mem_usage_t *, SuperLUStat_t *, int_t *info);
// Declarations of the SuperLU constructors for compressed-column matrices,
// one per scalar type (float, double, singlecomplex, doublecomplex).
// The gmm-side Create_CompCol_Matrix overloads call them as
// (A, m, n, nnz, values, ir, jc, storage type, data type, ...).
// The two index arrays are int_t here, whereas the gmm-side overloads take
// plain int pointers.
// NOTE(review): the line carrying the return type / linkage of each
// declaration is not visible in this listing.
123 sCreate_CompCol_Matrix(SuperMatrix *,
int,
int, int_t,
float *,
124 int_t *, int_t *, Stype_t, Dtype_t, Mtype_t);
126 dCreate_CompCol_Matrix(SuperMatrix *,
int,
int, int_t,
double *,
127 int_t *, int_t *, Stype_t, Dtype_t, Mtype_t);
129 cCreate_CompCol_Matrix(SuperMatrix *,
int,
int, int_t, singlecomplex *,
130 int_t *, int_t *, Stype_t, Dtype_t, Mtype_t);
132 zCreate_CompCol_Matrix(SuperMatrix *,
int,
int, int_t, doublecomplex *,
133 int_t *, int_t *, Stype_t, Dtype_t, Mtype_t);
// Declarations of the SuperLU constructors for dense matrices, one per
// scalar type (float, double, singlecomplex, doublecomplex).
// The gmm-side Create_Dense_Matrix overloads call them as
// (A, m, n, values, k, storage type, data type, ...).
// NOTE(review): the line carrying the return type / linkage of each
// declaration is not visible in this listing.
135 sCreate_Dense_Matrix(SuperMatrix *,
int,
int,
float *,
int,
136 Stype_t, Dtype_t, Mtype_t);
138 dCreate_Dense_Matrix(SuperMatrix *,
int,
int,
double *,
int,
139 Stype_t, Dtype_t, Mtype_t);
141 cCreate_Dense_Matrix(SuperMatrix *,
int,
int, singlecomplex *,
int,
142 Stype_t, Dtype_t, Mtype_t);
144 zCreate_Dense_Matrix(SuperMatrix *,
int,
int, doublecomplex *,
int,
145 Stype_t, Dtype_t, Mtype_t);
// Overload for float values: forwards (A, m, n, nnz, a, ir, jc) unchanged to
// SuperLU::sCreate_CompCol_Matrix, selecting storage type SLU_NC and data
// type SLU_S. Overloading on the value pointer type lets templated callers
// pick the right SuperLU routine without naming it.
// NOTE(review): the last argument of the call and the closing brace are not
// visible in this listing.
153 inline void Create_CompCol_Matrix(SuperLU::SuperMatrix *A,
int m,
int n,
154 int nnz,
float *a,
int *ir,
int *jc) {
155 SuperLU::sCreate_CompCol_Matrix(A, m, n, nnz, a, ir, jc,
156 SuperLU::SLU_NC, SuperLU::SLU_S,
// Overload for double values: forwards (A, m, n, nnz, a, ir, jc) unchanged to
// SuperLU::dCreate_CompCol_Matrix, selecting storage type SLU_NC and data
// type SLU_D.
// NOTE(review): the last argument of the call and the closing brace are not
// visible in this listing.
160 inline void Create_CompCol_Matrix(SuperLU::SuperMatrix *A,
int m,
int n,
161 int nnz,
double *a,
int *ir,
int *jc) {
162 SuperLU::dCreate_CompCol_Matrix(A, m, n, nnz, a, ir, jc,
163 SuperLU::SLU_NC, SuperLU::SLU_D,
// Overload for std::complex<float> values: forwards to
// SuperLU::cCreate_CompCol_Matrix with storage type SLU_NC and data type
// SLU_C. The value pointer is reinterpreted as SuperLU::singlecomplex* with a
// C-style cast -- this assumes both types share the same memory layout
// (TODO confirm against the SuperLU headers).
// NOTE(review): the end of the parameter list (where ir and jc are declared),
// the last argument of the call and the closing brace are not visible in
// this listing.
167 inline void Create_CompCol_Matrix(SuperLU::SuperMatrix *A,
int m,
int n,
168 int nnz, std::complex<float> *a,
170 SuperLU::cCreate_CompCol_Matrix(A, m, n, nnz,
171 (SuperLU::singlecomplex *)(a),
172 ir, jc, SuperLU::SLU_NC, SuperLU::SLU_C,
// Overload for std::complex<double> values: forwards to
// SuperLU::zCreate_CompCol_Matrix with storage type SLU_NC and data type
// SLU_Z. The value pointer is reinterpreted as SuperLU::doublecomplex* with a
// C-style cast -- this assumes both types share the same memory layout
// (TODO confirm against the SuperLU headers).
// NOTE(review): the end of the parameter list (where ir and jc are declared),
// the last argument of the call and the closing brace are not visible in
// this listing.
176 inline void Create_CompCol_Matrix(SuperLU::SuperMatrix *A,
int m,
int n,
177 int nnz, std::complex<double> *a,
179 SuperLU::zCreate_CompCol_Matrix(A, m, n, nnz,
180 (SuperLU::doublecomplex *)(a), ir, jc,
181 SuperLU::SLU_NC, SuperLU::SLU_Z,
// Dense-matrix overload that forwards (A, m, n, a, k) to
// SuperLU::sCreate_Dense_Matrix with storage type SLU_DN; presumably the
// float variant, given the callee name.
// NOTE(review): the end of the parameter list (where a and k are declared),
// the remaining arguments of the call and the closing brace are not visible
// in this listing.
187 inline void Create_Dense_Matrix(SuperLU::SuperMatrix *A,
int m,
int n,
189 SuperLU::sCreate_Dense_Matrix(A, m, n, a, k, SuperLU::SLU_DN,
// Dense-matrix overload that forwards (A, m, n, a, k) to
// SuperLU::dCreate_Dense_Matrix with storage type SLU_DN; presumably the
// double variant, given the callee name.
// NOTE(review): the end of the parameter list (where a and k are declared),
// the remaining arguments of the call and the closing brace are not visible
// in this listing.
193 inline void Create_Dense_Matrix(SuperLU::SuperMatrix *A,
int m,
int n,
195 SuperLU::dCreate_Dense_Matrix(A, m, n, a, k, SuperLU::SLU_DN,
// Overload for std::complex<float> values: forwards to
// SuperLU::cCreate_Dense_Matrix with storage type SLU_DN and data type SLU_C.
// The value pointer is reinterpreted as SuperLU::singlecomplex* with a
// C-style cast -- this assumes both types share the same memory layout
// (TODO confirm against the SuperLU headers).
// NOTE(review): the last argument of the call and the closing brace are not
// visible in this listing.
199 inline void Create_Dense_Matrix(SuperLU::SuperMatrix *A,
int m,
int n,
200 std::complex<float> *a,
int k) {
201 SuperLU::cCreate_Dense_Matrix(A, m, n,
202 (SuperLU::singlecomplex *)(a),
203 k, SuperLU::SLU_DN, SuperLU::SLU_C,
// Overload for std::complex<double> values: forwards to
// SuperLU::zCreate_Dense_Matrix with storage type SLU_DN and data type SLU_Z.
// The value pointer is reinterpreted as SuperLU::doublecomplex* with a
// C-style cast -- this assumes both types share the same memory layout
// (TODO confirm against the SuperLU headers).
// NOTE(review): the last argument of the call and the closing brace are not
// visible in this listing.
206 inline void Create_Dense_Matrix(SuperLU::SuperMatrix *A,
int m,
int n,
207 std::complex<double> *a,
int k) {
208 SuperLU::zCreate_Dense_Matrix(A, m, n, (SuperLU::doublecomplex *)(a),
209 k, SuperLU::SLU_DN, SuperLU::SLU_Z,
// DECL_GSSV(FNAME, KEYTYPE) generates one inline overload of SuperLU_gssv
// that forwards (options, A, p, q, L, U, B, stats, info) to SuperLU::FNAME.
// The trailing unnamed KEYTYPE parameter is never read: it exists only so
// that overload resolution picks the routine matching the scalar type.
// Instantiated below for float, std::complex<float>, double and
// std::complex<double>.
// NOTE(review): the line closing the generated function body (and thereby
// terminating the macro) is not visible in this listing.
215 #define DECL_GSSV(FNAME,KEYTYPE) \
216 inline void SuperLU_gssv(SuperLU::superlu_options_t *options, \
217 SuperLU::SuperMatrix *A, int *p, int *q, \
218 SuperLU::SuperMatrix *L, \
219 SuperLU::SuperMatrix *U, \
220 SuperLU::SuperMatrix *B, \
221 SuperLU::SuperLUStat_t *stats, \
222 int *info, KEYTYPE) { \
223 SuperLU::FNAME(options, A, p, q, L, U, B, stats, info); \
226 DECL_GSSV(sgssv,
float)
227 DECL_GSSV(cgssv, std::complex<float>)
228 DECL_GSSV(dgssv,
double)
229 DECL_GSSV(zgssv, std::complex<double>)
// DECL_GSSVX(FNAME, FLOATTYPE, KEYTYPE) generates one inline overload of
// SuperLU_gssvx that forwards its arguments to SuperLU::FNAME.
//   FLOATTYPE - real type used for R, C, recip_pivot_growth, rcond and ferr
//               (float for the s/c routines, double for d/z, per the
//               instantiations below).
//   KEYTYPE   - type of a trailing unnamed parameter that is never read; it
//               only steers overload resolution to the matching routine.
// The GlobalLU_t and mem_usage_t arguments required by FNAME are locals of
// the wrapper; the wrapper returns mem_usage.for_lu as a float.
// NOTE(review): the parameter lines declaring `equed` and `berr` (both used
// in the forwarded call) and the line closing the generated function body
// are not visible in this listing.
233 #define DECL_GSSVX(FNAME,FLOATTYPE,KEYTYPE) \
234 inline float SuperLU_gssvx(SuperLU::superlu_options_t *options, \
235 SuperLU::SuperMatrix *A, \
236 int *perm_c, int *perm_r, int *etree, \
238 FLOATTYPE *R, FLOATTYPE *C, \
239 SuperLU::SuperMatrix *L, \
240 SuperLU::SuperMatrix *U, \
241 void *work, int lwork, \
242 SuperLU::SuperMatrix *B, \
243 SuperLU::SuperMatrix *X, \
244 FLOATTYPE *recip_pivot_growth, \
245 FLOATTYPE *rcond, FLOATTYPE *ferr, \
247 SuperLU::SuperLUStat_t *stats, \
248 int *info, KEYTYPE) { \
249 SuperLU::mem_usage_t mem_usage; \
250 SuperLU::GlobalLU_t Glu; \
251 SuperLU::FNAME(options, A, perm_c, perm_r, etree, equed, R, C, L, \
252 U, work, lwork, B, X, recip_pivot_growth, rcond, \
253 ferr, berr, &Glu, &mem_usage, stats, info); \
254 return mem_usage.for_lu; \
257 DECL_GSSVX(sgssvx,
float,
float)
258 DECL_GSSVX(cgssvx,
float, std::complex<float>)
259 DECL_GSSVX(dgssvx,
double,
double)
260 DECL_GSSVX(zgssvx,
double, std::complex<double>)
// One-shot solve of a linear system with matrix A and right-hand side B
// through the SuperLU expert driver (SuperLU_gssvx).
//   A          - input matrix; its value type selects the SuperLU routine.
//   X          - presumably receives the solution (the copy from `sol` back
//                to X is not visible in this listing).
//   B          - right-hand side (presumably copied into `rhs`; not visible).
//   rcond_     - output; presumably set from the local `rcond` estimate, which
//                is requested via options.ConditionNumber = YES.
//   permc_spec - column ordering: 1 = MMD_ATA, 2 = MMD_AT_PLUS_A, 3 = COLAMD
//                (default); other values keep NATURAL as far as the visible
//                cases show.
// Failure handling visible below: info == -333333333 and info < 0 trigger
// GMM_ASSERT1, info > 0 only emits a warning.
// NOTE(review): several lines of this function (matrix/vector copies, stat
// initialisation, the tail of the gssvx call, the return statement) are
// missing from this listing.
266 template <
typename MAT,
typename VECTX,
typename VECTB>
267 int SuperLU_solve(
const MAT &A,
const VECTX &X,
const VECTB &B,
268 double& rcond_,
int permc_spec = 3) {
// T: scalar type of the matrix; R: the associated real (magnitude) type.
276 typedef typename linalg_traits<MAT>::value_type T;
277 typedef typename number_traits<T>::magnitude_type R;
// Only a single right-hand side is handled (nrhs is fixed to 1).
279 int m = int(mat_nrows(A)), n = int(mat_ncols(A)), nrhs = 1, info = 0;
// Compressed-column working matrix whose pr/ir/jc arrays are handed to SuperLU.
281 csc_matrix<T> csc_A(m, n);
283 std::vector<T> rhs(m), sol(m);
286 int nz = int(
nnz(csc_A));
// Heuristic density check, done in integer arithmetic.
// NOTE(review): divides by n, so a matrix with zero columns would divide by zero.
287 if ((2 * nz / n) >= m)
288 GMM_WARNING2(
"CAUTION : it seems that SuperLU has a problem"
289 " for nearly dense sparse matrices");
// Start from SuperLU defaults, then override ordering, statistics printing
// and condition-number estimation.
291 SuperLU::superlu_options_t options;
292 set_default_options(&options);
293 options.ColPerm = SuperLU::NATURAL;
294 options.PrintStat = SuperLU::NO;
295 options.ConditionNumber = SuperLU::YES;
296 switch (permc_spec) {
297 case 1 : options.ColPerm = SuperLU::MMD_ATA;
break;
298 case 2 : options.ColPerm = SuperLU::MMD_AT_PLUS_A;
break;
299 case 3 : options.ColPerm = SuperLU::COLAMD;
break;
301 SuperLU::SuperLUStat_t stat;
// Wrap the existing buffers in SuperMatrix descriptors: SA over csc_A's
// arrays, SB/SX as m x nrhs dense matrices over rhs/sol (leading dimension m).
304 SuperLU::SuperMatrix SA, SL, SU, SB, SX;
305 Create_CompCol_Matrix(&SA, m, n, nz, (T *)(&csc_A.pr[0]),
306 (
int *)(&csc_A.ir[0]),
307 (
int *)(&csc_A.jc[0]));
308 Create_Dense_Matrix(&SB, m, nrhs, &rhs[0], m);
309 Create_Dense_Matrix(&SX, m, nrhs, &sol[0], m);
// Zero SL and SU so that their Store pointers are null until SuperLU fills
// them; the cleanup below tests Store before destroying.
310 memset(&SL,0,
sizeof SL);
311 memset(&SU,0,
sizeof SU);
// Work arrays passed to the expert driver.
313 std::vector<int> etree(n);
315 std::vector<R> Rscale(m),Cscale(n);
316 std::vector<R> ferr(nrhs), berr(nrhs);
317 R recip_pivot_gross, rcond;
318 std::vector<int> perm_r(m), perm_c(n);
// NOTE(review): the remaining arguments of this call are not visible here.
320 SuperLU_gssvx(&options, &SA, &perm_c[0], &perm_r[0],
// Release SuperLU-side storage. SB/SX/SA only wrap buffers owned by this
// function, hence Destroy_SuperMatrix_Store; SL/SU are destroyed with the
// routines matching their storage formats.
336 if (SB.Store) Destroy_SuperMatrix_Store(&SB);
337 if (SX.Store) Destroy_SuperMatrix_Store(&SX);
338 if (SA.Store) Destroy_SuperMatrix_Store(&SA);
339 if (SL.Store) Destroy_SuperNode_Matrix(&SL);
340 if (SU.Store) Destroy_CompCol_Matrix(&SU);
// -333333333 is treated as a dedicated "cancelled" code.
342 GMM_ASSERT1(info != -333333333,
"SuperLU was cancelled.");
// Negative info is a hard error; positive info is only reported as a warning.
344 GMM_ASSERT1(info >= 0,
"SuperLU solve failed: info =" << info);
345 if (info > 0) GMM_WARNING1(
"SuperLU solve failed: info =" << info);
// Holds the SuperLU matrices, options and work arrays for one matrix so that
// build_with() can be followed by solve() calls that reuse them (solve()
// sets options.Fact to FACTORED).
// All state is `mutable` because solve() is declared const yet has to update
// the SuperLU options and buffers.
// NOTE(review): the `template <class T>` header, access specifiers, the
// declaration of `memory_used` and the closing of the class are not visible
// in this listing.
351 class SuperLU_factor {
// Real (magnitude) type associated with the scalar type T.
352 typedef typename number_traits<T>::magnitude_type R;
// SuperLU descriptors: system matrix, L and U factors, right-hand side, solution.
355 mutable SuperLU::SuperMatrix SA, SL, SB, SU, SX;
356 mutable SuperLU::SuperLUStat_t stat;
357 mutable SuperLU::superlu_options_t options;
// Work arrays handed to SuperLU_gssvx.
359 mutable std::vector<int> etree, perm_r, perm_c;
360 mutable std::vector<R> Rscale, Cscale;
361 mutable std::vector<R> ferr, berr;
// Buffers wrapped by SB and SX respectively.
362 mutable std::vector<T> rhs;
363 mutable std::vector<T> sol;
// Initialisation flag; false after default construction, and checked by the
// copy constructor and assignment operator below.
364 mutable bool is_init;
// Values accepted by the `transp` argument of solve().
368 enum { LU_NOTRANSP, LU_TRANSP, LU_CONJUGATED };
// Releases the SuperLU-side storage of every matrix whose Store is non-null.
// NOTE(review): any guard around these calls (and the closing brace) is not
// visible in this listing.
369 void free_supermatrix() {
371 if (SB.Store) Destroy_SuperMatrix_Store(&SB);
372 if (SX.Store) Destroy_SuperMatrix_Store(&SX);
373 if (SA.Store) Destroy_SuperMatrix_Store(&SA);
374 if (SL.Store) Destroy_SuperNode_Matrix(&SL);
375 if (SU.Store) Destroy_CompCol_Matrix(&SU);
// Factorizes A; permc_spec selects the column ordering (1 = MMD_ATA,
// 2 = MMD_AT_PLUS_A, 3 = COLAMD, per the out-of-class definition).
378 template <
class MAT>
void build_with(
const MAT &A,
int permc_spec = 3);
// Solves with the stored factorization; `transp` is one of the LU_* values.
379 template <
typename VECTX,
typename VECTB>
383 void solve(
const VECTX &X_,
const VECTB &B,
int transp=LU_NOTRANSP)
const;
384 SuperLU_factor() { is_init =
false; }
// Copying is only tolerated for objects that hold no factorization: the
// SuperLU descriptors are not duplicated.
385 SuperLU_factor(
const SuperLU_factor& other) {
386 GMM_ASSERT2(!(other.is_init),
387 "copy of initialized SuperLU_factor is forbidden");
// Same restriction for assignment: neither side may be initialized.
390 SuperLU_factor& operator=(
const SuperLU_factor& other) {
391 GMM_ASSERT2(!(other.is_init) && !is_init,
392 "assignment of initialized SuperLU_factor is forbidden");
395 ~SuperLU_factor() { free_supermatrix(); }
// Returns memory_used, which build_with() sets from SuperLU_gssvx's return value.
396 float memsize() {
return memory_used; }
// Runs the SuperLU expert driver on A with zero right-hand-side columns, so
// that only the factorization is performed (presumably -- SB/SX are created
// with 0 columns), and records the memory figure returned by SuperLU_gssvx in
// memory_used.
//   A          - matrix to factorize.
//   permc_spec - column ordering: 1 = MMD_ATA, 2 = MMD_AT_PLUS_A, 3 = COLAMD;
//                other values keep NATURAL as far as the visible cases show.
// Any non-zero info from SuperLU triggers GMM_ASSERT1.
// NOTE(review): several lines (declaration/filling of csc_A, release of a
// previous factorization, stat initialisation, tail of the gssvx call, update
// of is_init) are not visible in this listing.
400 template <
class T>
template <
class MAT>
401 void SuperLU_factor<T>::build_with(
const MAT &A,
int permc_spec) {
// NOTE(review): here n is the row count and m the column count, the reverse
// of the naming used in SuperLU_solve; the two only coincide for square
// matrices, since m and n are later passed in the same positions.
410 int n = int(mat_nrows(A)), m = int(mat_ncols(A)), info = 0;
413 rhs.resize(m); sol.resize(m);
415 int nz = int(
nnz(csc_A));
// SuperLU defaults, with ordering selected by permc_spec and no
// condition-number estimation.
417 set_default_options(&options);
418 options.ColPerm = SuperLU::NATURAL;
419 options.PrintStat = SuperLU::NO;
420 options.ConditionNumber = SuperLU::NO;
421 switch (permc_spec) {
422 case 1 : options.ColPerm = SuperLU::MMD_ATA;
break;
423 case 2 : options.ColPerm = SuperLU::MMD_AT_PLUS_A;
break;
424 case 3 : options.ColPerm = SuperLU::COLAMD;
break;
428 Create_CompCol_Matrix(&SA, m, n, nz, (T *)(&csc_A.pr[0]),
429 (
int *)(&csc_A.ir[0]),
430 (
int *)(&csc_A.jc[0]));
// Zero-column right-hand side and solution for the factorization call.
431 Create_Dense_Matrix(&SB, m, 0, &rhs[0], m);
432 Create_Dense_Matrix(&SX, m, 0, &sol[0], m);
// Zero SL/SU so their Store pointers are null until SuperLU fills them.
433 memset(&SL,0,
sizeof SL);
434 memset(&SU,0,
sizeof SU);
436 Rscale.resize(m); Cscale.resize(n); etree.resize(n);
437 ferr.resize(1); berr.resize(1);
438 R recip_pivot_gross, rcond;
439 perm_r.resize(m); perm_c.resize(n);
// NOTE(review): the remaining arguments of this call are not visible here.
440 memory_used = SuperLU_gssvx(&options, &SA, &perm_c[0], &perm_r[0],
// Replace the zero-column descriptors with one-column ones over the same
// rhs/sol buffers, ready for solve().
457 Destroy_SuperMatrix_Store(&SB);
458 Destroy_SuperMatrix_Store(&SX);
459 Create_Dense_Matrix(&SB, m, 1, &rhs[0], m);
460 Create_Dense_Matrix(&SX, m, 1, &sol[0], m);
// -333333333 is treated as a dedicated "cancelled" code.
463 GMM_ASSERT1(info != -333333333,
"SuperLU was cancelled.");
464 GMM_ASSERT1(info == 0,
"SuperLU solve failed: info=" << info);
// Solves a system with the factorization stored in this object by calling
// SuperLU_gssvx with options.Fact = FACTORED and iterative refinement
// disabled. The transposition mode is chosen from the LU_NOTRANSP /
// LU_TRANSP / LU_CONJUGATED values; any other value trips GMM_ASSERT1.
// A non-zero info from SuperLU also trips GMM_ASSERT1.
// NOTE(review): the end of the signature, the copy of B into the internal
// right-hand side, the switch header, the declaration of `info`, the tail of
// the gssvx call and the copy of the solution into X are not visible in this
// listing.
468 template <
class T>
template <
typename VECTX,
typename VECTB>
469 void SuperLU_factor<T>::solve(
const VECTX &X,
const VECTB &B,
472 options.Fact = SuperLU::FACTORED;
473 options.IterRefine = SuperLU::NOREFINE;
475 case LU_NOTRANSP: options.Trans = SuperLU::NOTRANS;
break;
476 case LU_TRANSP: options.Trans = SuperLU::TRANS;
break;
477 case LU_CONJUGATED: options.Trans = SuperLU::CONJ;
break;
478 default: GMM_ASSERT1(
false,
"invalid value for transposition option");
// Required by the gssvx signature.
482 R recip_pivot_gross, rcond;
483 SuperLU_gssvx(&options, &SA, &perm_c[0], &perm_r[0],
499 GMM_ASSERT1(info == 0,
"SuperLU solve failed: info=" << info);
// Free-function mult() overload taking a SuperLU_factor as first argument.
// NOTE(review): the body is not visible in this listing; presumably it calls
// P.solve(v2, v1) in the same way transposed_mult() does with LU_TRANSP --
// confirm against the full file.
503 template <
typename T,
typename V1,
typename V2>
inline
504 void mult(
const SuperLU_factor<T>& P,
const V1 &v1,
const V2 &v2) {
// Applies the factorization P in transposed mode: calls
// P.solve(v2, v1, LU_TRANSP), so v1 is the right-hand side and v2 is passed
// in the position of solve()'s X argument (declared const there as well).
// NOTE(review): the closing brace is not visible in this listing.
508 template <
typename T,
typename V1,
typename V2>
inline
509 void transposed_mult(
const SuperLU_factor<T>& P,
const V1 &v1,
const V2 &v2) {
510 P.solve(v2, v1, SuperLU_factor<T>::LU_TRANSP);
size_type nnz(const L &l)
count the number of non-zero entries of a vector or matrix.
void copy(const L1 &l1, L2 &l2)
*/
void clear(L &l)
clear (fill with zeros) a vector or matrix.
void mult(const L1 &l1, const L2 &l2, L3 &l3)
*/
Include the base gmm files.