source: mmcs/armadillo_bits/fn_svds.hpp @ 8daa049

matrices
Last change on this file since 8daa049 was 9dd61b1, checked in by rboet <rboet@…>, 9 years ago

Avance del proyecto 60%

  • Property mode set to 100644
File size: 8.6 KB
Line 
1// Copyright (C) 2015 Conrad Sanderson
2// Copyright (C) 2015 NICTA (www.nicta.com.au)
3//
4// This Source Code Form is subject to the terms of the Mozilla Public
5// License, v. 2.0. If a copy of the MPL was not distributed with this
6// file, You can obtain one at http://mozilla.org/MPL/2.0/.
7
8
9//! \addtogroup fn_svds
10//! @{
11
12
13template<typename T1>
14inline
15bool
16svds_helper
17  (
18           Mat<typename T1::elem_type>&    U,
19           Col<typename T1::pod_type >&    S,
20           Mat<typename T1::elem_type>&    V,
21  const SpBase<typename T1::elem_type,T1>& X,
22  const uword                              k,
23  const typename T1::pod_type              tol,
24  const bool                               calc_UV,
25  const typename arma_real_only<typename T1::elem_type>::result* junk = 0
26  )
27  {
28  arma_extra_debug_sigprint();
29  arma_ignore(junk);
30 
31  typedef typename T1::elem_type eT;
32  typedef typename T1::pod_type   T;
33 
34  if(arma_config::arpack == false)
35    {
36    arma_stop("svds(): use of ARPACK needs to be enabled");
37    return false;
38    }
39 
40  arma_debug_check
41    (
42    ( ((void*)(&U) == (void*)(&S)) || (&U == &V) || ((void*)(&S) == (void*)(&V)) ),
43    "svds(): two or more output objects are the same object"
44    );
45 
46  arma_debug_check( (tol < T(0)), "svds(): tol must be >= 0" );
47 
48  const unwrap_spmat<T1> tmp(X.get_ref());
49  const SpMat<eT>& A =   tmp.M;
50 
51  const uword kk = (std::min)( (std::min)(A.n_rows, A.n_cols), k );
52 
53  const T A_max = (A.n_nonzero > 0) ? T(max(abs(Col<eT>(const_cast<eT*>(A.values), A.n_nonzero, false)))) : T(0);
54 
55  if(A_max == T(0))
56    {
57    // TODO: use reset instead ?
58    S.zeros(kk);
59   
60    if(calc_UV)
61      {
62      U.eye(A.n_rows, kk);
63      V.eye(A.n_cols, kk);
64      }
65    }
66  else
67    {
68    SpMat<eT> C( (A.n_rows + A.n_cols), (A.n_rows + A.n_cols) );
69   
70    SpMat<eT> B  = A / A_max;
71    SpMat<eT> Bt = B.t();
72   
73    C(0, A.n_rows, size(B) ) = B;
74    C(A.n_rows, 0, size(Bt)) = Bt;
75   
76    Bt.reset();
77    B.reset();
78   
79    Col<eT> eigval;
80    Mat<eT> eigvec;
81   
82    const bool status = sp_auxlib::eigs_sym(eigval, eigvec, C, kk, "la", (tol / Datum<T>::sqrt2));
83   
84    if(status == false)
85      {
86      U.reset();
87      S.reset();
88      V.reset();
89     
90      return false;
91      }
92   
93    const T A_norm = max(eigval);
94   
95    const T tol2 = tol / Datum<T>::sqrt2 * A_norm;
96   
97    uvec indices = find(eigval > tol2);
98   
99    if(indices.n_elem > kk)
100      {
101      indices = indices.subvec(0,kk-1);
102      }
103    else
104    if(indices.n_elem < kk)
105      {
106      const uvec indices2 = find(abs(eigval) <= tol2);
107     
108      const uword N_extra = (std::min)( indices2.n_elem, (kk - indices.n_elem) );
109     
110      if(N_extra > 0)  { indices = join_cols(indices, indices2.subvec(0,N_extra-1)); }
111      }
112   
113    const uvec sorted_indices = sort_index(eigval, "descend");
114   
115    S = eigval.elem(sorted_indices);  S *= A_max;
116   
117    if(calc_UV)
118      {
119      uvec U_row_indices(A.n_rows);  for(uword i=0; i < A.n_rows; ++i)  { U_row_indices[i] = i;            }
120      uvec V_row_indices(A.n_cols);  for(uword i=0; i < A.n_cols; ++i)  { V_row_indices[i] = i + A.n_rows; }
121     
122      U = Datum<T>::sqrt2 * eigvec(U_row_indices, sorted_indices);
123      V = Datum<T>::sqrt2 * eigvec(V_row_indices, sorted_indices);
124      }
125    }
126 
127  arma_debug_warn( (S.n_elem < k), "svds(): warning: found fewer singular values than specified" );
128 
129  return true;
130  }
131
132
133
134template<typename T1>
135inline
136bool
137svds_helper
138  (
139           Mat<typename T1::elem_type>&    U,
140           Col<typename T1::pod_type >&    S,
141           Mat<typename T1::elem_type>&    V,
142  const SpBase<typename T1::elem_type,T1>& X,
143  const uword                              k,
144  const typename T1::pod_type              tol,
145  const bool                               calc_UV,
146  const typename arma_cx_only<typename T1::elem_type>::result* junk = 0
147  )
148  {
149  arma_extra_debug_sigprint();
150  arma_ignore(junk);
151 
152  typedef typename T1::elem_type eT;
153  typedef typename T1::pod_type   T;
154 
155  if(arma_config::arpack == false)
156    {
157    arma_stop("svds(): use of ARPACK needs to be enabled");
158    return false;
159    }
160 
161  arma_debug_check
162    (
163    ( ((void*)(&U) == (void*)(&S)) || (&U == &V) || ((void*)(&S) == (void*)(&V)) ),
164    "svds(): two or more output objects are the same object"
165    );
166 
167  arma_debug_check( (tol < T(0)), "svds(): tol must be >= 0" );
168 
169  const unwrap_spmat<T1> tmp(X.get_ref());
170  const SpMat<eT>& A =   tmp.M;
171 
172  const uword kk = (std::min)( (std::min)(A.n_rows, A.n_cols), k );
173 
174  const T A_max = (A.n_nonzero > 0) ? T(max(abs(Col<eT>(const_cast<eT*>(A.values), A.n_nonzero, false)))) : T(0);
175 
176  if(A_max == T(0))
177    {
178    // TODO: use reset instead ?
179    S.zeros(kk);
180   
181    if(calc_UV)
182      {
183      U.eye(A.n_rows, kk);
184      V.eye(A.n_cols, kk);
185      }
186    }
187  else
188    {
189    SpMat<eT> C( (A.n_rows + A.n_cols), (A.n_rows + A.n_cols) );
190   
191    SpMat<eT> B  = A / A_max;
192    SpMat<eT> Bt = B.t();
193   
194    C(0, A.n_rows, size(B) ) = B;
195    C(A.n_rows, 0, size(Bt)) = Bt;
196   
197    Bt.reset();
198    B.reset();
199   
200    Col<eT> eigval_tmp;
201    Mat<eT> eigvec;
202   
203    const bool status = sp_auxlib::eigs_gen(eigval_tmp, eigvec, C, kk, "lr", (tol / Datum<T>::sqrt2));
204   
205    if(status == false)
206      {
207      U.reset();
208      S.reset();
209      V.reset();
210      arma_bad("svds(): failed to converge", false);
211     
212      return false;
213      }
214   
215    const Col<T> eigval = real(eigval_tmp);
216   
217    const T A_norm = max(eigval);
218   
219    const T tol2 = tol / Datum<T>::sqrt2 * A_norm;
220   
221    uvec indices = find(eigval > tol2);
222   
223    if(indices.n_elem > kk)
224      {
225      indices = indices.subvec(0,kk-1);
226      }
227    else
228    if(indices.n_elem < kk)
229      {
230      const uvec indices2 = find(abs(eigval) <= tol2);
231     
232      const uword N_extra = (std::min)( indices2.n_elem, (kk - indices.n_elem) );
233     
234      if(N_extra > 0)  { indices = join_cols(indices, indices2.subvec(0,N_extra-1)); }
235      }
236   
237    const uvec sorted_indices = sort_index(eigval, "descend");
238   
239    S = eigval.elem(sorted_indices);  S *= A_max;
240   
241    if(calc_UV)
242      {
243      uvec U_row_indices(A.n_rows);  for(uword i=0; i < A.n_rows; ++i)  { U_row_indices[i] = i;            }
244      uvec V_row_indices(A.n_cols);  for(uword i=0; i < A.n_cols; ++i)  { V_row_indices[i] = i + A.n_rows; }
245     
246      U = Datum<T>::sqrt2 * eigvec(U_row_indices, sorted_indices);
247      V = Datum<T>::sqrt2 * eigvec(V_row_indices, sorted_indices);
248      }
249    }
250 
251  arma_debug_warn( (S.n_elem < k), "svds(): warning: found fewer singular values than specified" );
252 
253  return true;
254  }
255
256
257
258//! find the k largest singular values and corresponding singular vectors of sparse matrix X
259template<typename T1>
260inline
261bool
262svds
263  (
264           Mat<typename T1::elem_type>&    U,
265           Col<typename T1::pod_type >&    S,
266           Mat<typename T1::elem_type>&    V,
267  const SpBase<typename T1::elem_type,T1>& X,
268  const uword                              k,
269  const typename T1::pod_type              tol  = 0.0,
270  const typename arma_real_or_cx_only<typename T1::elem_type>::result* junk = 0
271  )
272  {
273  arma_extra_debug_sigprint();
274  arma_ignore(junk);
275 
276  const bool status = svds_helper(U, S, V, X.get_ref(), k, tol, true);
277 
278  if(status == false)
279    {
280    arma_bad("svds(): failed to converge", false);
281    }
282 
283  return status;
284  }
285
286
287
288//! find the k largest singular values of sparse matrix X
289template<typename T1>
290inline
291bool
292svds
293  (
294           Col<typename T1::pod_type >&    S,
295  const SpBase<typename T1::elem_type,T1>& X,
296  const uword                              k,
297  const typename T1::pod_type              tol  = 0.0,
298  const typename arma_real_or_cx_only<typename T1::elem_type>::result* junk = 0
299  )
300  {
301  arma_extra_debug_sigprint();
302  arma_ignore(junk);
303 
304  Mat<typename T1::elem_type> U;
305  Mat<typename T1::elem_type> V;
306 
307  const bool status = svds_helper(U, S, V, X.get_ref(), k, tol, false);
308
309  if(status == false)
310    {
311    arma_bad("svds(): failed to converge", false);
312    }
313 
314  return status;
315  }
316
317
318
319//! find the k largest singular values of sparse matrix X
320template<typename T1>
321inline
322Col<typename T1::pod_type>
323svds
324  (
325  const SpBase<typename T1::elem_type,T1>& X,
326  const uword                              k,
327  const typename T1::pod_type              tol  = 0.0,
328  const typename arma_real_or_cx_only<typename T1::elem_type>::result* junk = 0
329  )
330  {
331  arma_extra_debug_sigprint();
332  arma_ignore(junk);
333 
334  Col<typename T1::pod_type>  S;
335
336  Mat<typename T1::elem_type> U;
337  Mat<typename T1::elem_type> V;
338 
339  const bool status = svds_helper(U, S, V, X.get_ref(), k, tol, false);
340 
341  if(status == false)
342    {
343    arma_bad("svds(): failed to converge", true);
344    }
345 
346  return S;
347  }
348
349
350
351//! @}
Note: See TracBrowser for help on using the repository browser.