source: mmcs/armadillo_bits/fn_accu.hpp @ 8daa049

matrices
Last change on this file since 8daa049 was 9dd61b1, checked in by rboet <rboet@…>, 9 years ago

Avance del proyecto 60%

  • Property mode set to 100644
File size: 8.3 KB
Line 
1// Copyright (C) 2008-2015 Conrad Sanderson
2// Copyright (C) 2008-2015 NICTA (www.nicta.com.au)
3// Copyright (C) 2012 Ryan Curtin
4//
5// This Source Code Form is subject to the terms of the Mozilla Public
6// License, v. 2.0. If a copy of the MPL was not distributed with this
7// file, You can obtain one at http://mozilla.org/MPL/2.0/.
8
9
10//! \addtogroup fn_accu
11//! @{
12
13
14
15template<typename T1>
16arma_hot
17inline
18typename T1::elem_type
19accu_proxy_linear(const Proxy<T1>& P)
20  {
21  typedef typename T1::elem_type eT;
22 
23  const uword n_elem = P.get_n_elem();
24 
25  #if defined(__FINITE_MATH_ONLY__) && (__FINITE_MATH_ONLY__ > 0)
26    {
27    eT val = eT(0);
28   
29    if(P.is_aligned())
30      {
31      typename Proxy<T1>::aligned_ea_type A = P.get_aligned_ea();
32     
33      for(uword i=0; i<n_elem; ++i)  { val += A.at_alt(i); }
34      }
35    else
36      {
37      typename Proxy<T1>::ea_type A = P.get_ea();
38     
39      for(uword i=0; i<n_elem; ++i)  { val += A[i]; }
40      }
41   
42    return val;
43    }
44  #else
45    {
46    eT val1 = eT(0);
47    eT val2 = eT(0);
48   
49    typename Proxy<T1>::ea_type A = P.get_ea();
50   
51    uword i,j;
52    for(i=0, j=1; j < n_elem; i+=2, j+=2)
53      {
54      val1 += A[i];
55      val2 += A[j];
56      }
57   
58    if(i < n_elem)
59      {
60      val1 += A[i];   // equivalent to: val1 += A[n_elem-1];
61      }
62   
63    return (val1 + val2);
64    }
65  #endif
66  }
67
68
69
70template<typename T1>
71arma_hot
72inline
73typename T1::elem_type
74accu_proxy_mat(const Proxy<T1>& P)
75  {
76  const quasi_unwrap<typename Proxy<T1>::stored_type> tmp(P.Q);
77 
78  return arrayops::accumulate(tmp.M.memptr(), tmp.M.n_elem);
79  }
80
81
82
83template<typename T1>
84arma_hot
85inline
86typename T1::elem_type
87accu_proxy_at(const Proxy<T1>& P)
88  {
89  typedef typename T1::elem_type eT;
90 
91  const uword n_rows = P.get_n_rows();
92  const uword n_cols = P.get_n_cols();
93 
94  eT val = eT(0);
95 
96  if(n_rows != 1)
97    {
98    eT val1 = eT(0);
99    eT val2 = eT(0);
100   
101    for(uword col=0; col < n_cols; ++col)
102      {
103      uword i,j;
104      for(i=0, j=1; j < n_rows; i+=2, j+=2)
105        {
106        val1 += P.at(i,col);
107        val2 += P.at(j,col);
108        }
109     
110      if(i < n_rows)
111        {
112        val1 += P.at(i,col);
113        }
114      }
115   
116    val = val1 + val2;
117    }
118  else
119    {
120    for(uword col=0; col < n_cols; ++col)
121      {
122      val += P.at(0,col);
123      }
124    }
125 
126  return val;
127  }
128
129
130
131//! accumulate the elements of a matrix
132template<typename T1>
133arma_hot
134inline
135typename enable_if2< is_arma_type<T1>::value, typename T1::elem_type >::result
136accu(const T1& X)
137  {
138  arma_extra_debug_sigprint();
139 
140  const Proxy<T1> P(X);
141 
142  const bool have_direct_mem = (is_Mat<typename Proxy<T1>::stored_type>::value) || (is_subview_col<typename Proxy<T1>::stored_type>::value);
143 
144  return (Proxy<T1>::prefer_at_accessor) ? accu_proxy_at(P) : (have_direct_mem ? accu_proxy_mat(P) : accu_proxy_linear(P));
145  }
146
147
148
149//! explicit handling of Hamming norm (also known as zero norm)
150template<typename T1>
151inline
152arma_warn_unused
153uword
154accu(const mtOp<uword,T1,op_rel_noteq>& X)
155  {
156  arma_extra_debug_sigprint();
157 
158  typedef typename T1::elem_type eT;
159 
160  const eT val = X.aux;
161 
162  const Proxy<T1> P(X.m);
163 
164  uword n_nonzero = 0;
165 
166  if(Proxy<T1>::prefer_at_accessor == false)
167    {
168    typedef typename Proxy<T1>::ea_type ea_type;
169   
170          ea_type A      = P.get_ea();
171    const uword   n_elem = P.get_n_elem();
172   
173    for(uword i=0; i<n_elem; ++i)
174      {
175      n_nonzero += (A[i] != val) ? uword(1) : uword(0);
176      }
177    }
178  else
179    {
180    const uword P_n_cols = P.get_n_cols();
181    const uword P_n_rows = P.get_n_rows();
182   
183    if(P_n_rows == 1)
184      {
185      for(uword col=0; col < P_n_cols; ++col)
186        {
187        n_nonzero += (P.at(0,col) != val) ? uword(1) : uword(0);
188        }
189      }
190    else
191      {
192      for(uword col=0; col < P_n_cols; ++col)
193      for(uword row=0; row < P_n_rows; ++row)
194        {
195        n_nonzero += (P.at(row,col) != val) ? uword(1) : uword(0);
196        }
197      }
198    }
199 
200  return n_nonzero;
201  }
202
203
204
205template<typename T1>
206inline
207arma_warn_unused
208uword
209accu(const mtOp<uword,T1,op_rel_eq>& X)
210  {
211  arma_extra_debug_sigprint();
212 
213  typedef typename T1::elem_type eT;
214 
215  const eT val = X.aux;
216 
217  const Proxy<T1> P(X.m);
218 
219  uword n_nonzero = 0;
220 
221  if(Proxy<T1>::prefer_at_accessor == false)
222    {
223    typedef typename Proxy<T1>::ea_type ea_type;
224   
225          ea_type A      = P.get_ea();
226    const uword   n_elem = P.get_n_elem();
227   
228    for(uword i=0; i<n_elem; ++i)
229      {
230      n_nonzero += (A[i] == val) ? uword(1) : uword(0);
231      }
232    }
233  else
234    {
235    const uword P_n_cols = P.get_n_cols();
236    const uword P_n_rows = P.get_n_rows();
237   
238    if(P_n_rows == 1)
239      {
240      for(uword col=0; col < P_n_cols; ++col)
241        {
242        n_nonzero += (P.at(0,col) == val) ? uword(1) : uword(0);
243        }
244      }
245    else
246      {
247      for(uword col=0; col < P_n_cols; ++col)
248      for(uword row=0; row < P_n_rows; ++row)
249        {
250        n_nonzero += (P.at(row,col) == val) ? uword(1) : uword(0);
251        }
252      }
253    }
254 
255  return n_nonzero;
256  }
257
258
259
260//! accumulate the elements of a subview (submatrix)
261template<typename eT>
262arma_hot
263arma_pure
264arma_warn_unused
265inline
266eT
267accu(const subview<eT>& X)
268  {
269  arma_extra_debug_sigprint(); 
270 
271  const uword X_n_rows = X.n_rows;
272  const uword X_n_cols = X.n_cols;
273 
274  eT val = eT(0);
275 
276  if(X_n_rows == 1)
277    {
278    typedef subview_row<eT> sv_type;
279   
280    const sv_type& sv = reinterpret_cast<const sv_type&>(X);  // subview_row<eT> is a child class of subview<eT> and has no extra data
281   
282    const Proxy<sv_type> P(sv);
283   
284    val = accu_proxy_linear(P);
285    }
286  else
287  if(X_n_cols == 1)
288    {
289    val = arrayops::accumulate( X.colptr(0), X_n_rows );
290    }
291  else
292    {
293    for(uword col=0; col < X_n_cols; ++col)
294      {
295      val += arrayops::accumulate( X.colptr(col), X_n_rows );
296      }
297    }
298 
299  return val;
300  }
301
302
303
304template<typename eT>
305arma_hot
306arma_pure
307arma_warn_unused
308inline
309eT
310accu(const subview_col<eT>& X)
311  {
312  arma_extra_debug_sigprint(); 
313 
314  return arrayops::accumulate( X.colptr(0), X.n_rows );
315  }
316
317
318
319//! accumulate the elements of a cube
320template<typename T1>
321arma_hot
322arma_warn_unused
323inline
324typename T1::elem_type
325accu(const BaseCube<typename T1::elem_type,T1>& X)
326  {
327  arma_extra_debug_sigprint();
328 
329  typedef typename T1::elem_type          eT;
330  typedef typename ProxyCube<T1>::ea_type ea_type;
331 
332  const ProxyCube<T1> A(X.get_ref());
333 
334  if(is_Cube<typename ProxyCube<T1>::stored_type>::value)
335    {
336    unwrap_cube<typename ProxyCube<T1>::stored_type> tmp(A.Q);
337   
338    return arrayops::accumulate(tmp.M.memptr(), tmp.M.n_elem);
339    }
340 
341 
342  if(ProxyCube<T1>::prefer_at_accessor == false)
343    {
344          ea_type P      = A.get_ea();
345    const uword   n_elem = A.get_n_elem();
346   
347    eT val1 = eT(0);
348    eT val2 = eT(0);
349   
350    uword i,j;
351   
352    for(i=0, j=1; j<n_elem; i+=2, j+=2)
353      {
354      val1 += P[i];
355      val2 += P[j];
356      }
357   
358    if(i < n_elem)
359      {
360      val1 += P[i];
361      }
362   
363    return val1 + val2;
364    }
365  else
366    {
367    const uword n_rows   = A.get_n_rows();
368    const uword n_cols   = A.get_n_cols();
369    const uword n_slices = A.get_n_slices();
370   
371    eT val1 = eT(0);
372    eT val2 = eT(0);
373   
374    for(uword slice=0; slice<n_slices; ++slice)
375    for(uword col=0; col<n_cols; ++col)
376      {
377      uword i,j;
378      for(i=0, j=1; j<n_rows; i+=2, j+=2)
379        {
380        val1 += A.at(i,col,slice);
381        val2 += A.at(j,col,slice);
382        }
383     
384      if(i < n_rows)
385        {
386        val1 += A.at(i,col,slice);
387        }
388      }
389   
390    return val1 + val2;
391    }
392  }
393
394
395
396template<typename T>
397arma_inline
398arma_warn_unused
399const typename arma_scalar_only<T>::result &
400accu(const T& x)
401  {
402  return x;
403  }
404
405
406
407//! accumulate values in a sparse object
408template<typename T1>
409arma_hot
410inline
411arma_warn_unused
412typename enable_if2<is_arma_sparse_type<T1>::value, typename T1::elem_type>::result
413accu(const T1& x)
414  {
415  arma_extra_debug_sigprint();
416 
417  typedef typename T1::elem_type eT;
418 
419  const SpProxy<T1> p(x);
420 
421  if(SpProxy<T1>::must_use_iterator == false)
422    {
423    // direct counting
424    return arrayops::accumulate(p.get_values(), p.get_n_nonzero());
425    }
426  else
427    {
428    typename SpProxy<T1>::const_iterator_type it     = p.begin();
429    typename SpProxy<T1>::const_iterator_type it_end = p.end();
430   
431    eT result = eT(0);
432   
433    while(it != it_end)
434      {
435      result += (*it);
436      ++it;
437      }
438   
439    return result;
440    }
441  }
442
443
444
445//! @}
Note: See TracBrowser for help on using the repository browser.