source: mmcs/armadillo_bits/spglue_times_meat.hpp @ 8daa049

matrices
Last change on this file since 8daa049 was 9dd61b1, checked in by rboet <rboet@…>, 9 years ago

Avance del proyecto 60%

  • Property mode set to 100644
File size: 7.8 KB
Line 
1// Copyright (C) 2012 Ryan Curtin
2// Copyright (C) 2012 Conrad Sanderson
3//
4// This Source Code Form is subject to the terms of the Mozilla Public
5// License, v. 2.0. If a copy of the MPL was not distributed with this
6// file, You can obtain one at http://mozilla.org/MPL/2.0/.
7
8
9//! \addtogroup spglue_times
10//! @{
11
12
13
14template<typename T1, typename T2>
15inline
16void
17spglue_times::apply(SpMat<typename T1::elem_type>& out, const SpGlue<T1,T2,spglue_times>& X)
18  {
19  arma_extra_debug_sigprint();
20 
21  typedef typename T1::elem_type eT;
22 
23  const SpProxy<T1> pa(X.A);
24  const SpProxy<T2> pb(X.B);
25 
26  const bool is_alias = pa.is_alias(out) || pb.is_alias(out);
27 
28  if(is_alias == false)
29    {
30    spglue_times::apply_noalias(out, pa, pb);
31    }
32  else
33    {
34    SpMat<eT> tmp;
35    spglue_times::apply_noalias(tmp, pa, pb);
36   
37    out.steal_mem(tmp);
38    }
39  }
40
41
42
43template<typename eT, typename T1, typename T2>
44arma_hot
45inline
46void
47spglue_times::apply_noalias(SpMat<eT>& c, const SpProxy<T1>& pa, const SpProxy<T2>& pb)
48  {
49  arma_extra_debug_sigprint();
50 
51  const uword x_n_rows = pa.get_n_rows();
52  const uword x_n_cols = pa.get_n_cols();
53  const uword y_n_rows = pb.get_n_rows();
54  const uword y_n_cols = pb.get_n_cols();
55
56  arma_debug_assert_mul_size(x_n_rows, x_n_cols, y_n_rows, y_n_cols, "matrix multiplication");
57
58  // First we must determine the structure of the new matrix (column pointers).
59  // This follows the algorithm described in 'Sparse Matrix Multiplication
60  // Package (SMMP)' (R.E. Bank and C.C. Douglas, 2001).  Their description of
61  // "SYMBMM" does not include anything about memory allocation.  In addition it
62  // does not consider that there may be elements which space may be allocated
63  // for but which evaluate to zero anyway.  So we have to modify the algorithm
64  // to work that way.  For the "SYMBMM" implementation we will not determine
65  // the row indices but instead just the column pointers.
66 
67  //SpMat<typename T1::elem_type> c(x_n_rows, y_n_cols); // Initializes col_ptrs to 0.
68  c.zeros(x_n_rows, y_n_cols);
69 
70  //if( (pa.get_n_elem() == 0) || (pb.get_n_elem() == 0) )
71  if( (pa.get_n_nonzero() == 0) || (pb.get_n_nonzero() == 0) )
72    {
73    return;
74    }
75 
76  // Auxiliary storage which denotes when items have been found.
77  podarray<uword> index(x_n_rows);
78  index.fill(x_n_rows); // Fill with invalid links.
79 
80  typename SpProxy<T2>::const_iterator_type y_it  = pb.begin();
81  typename SpProxy<T2>::const_iterator_type y_end = pb.end();
82
83  // SYMBMM: calculate column pointers for resultant matrix to obtain a good
84  // upper bound on the number of nonzero elements.
85  uword cur_col_length = 0;
86  uword last_ind = x_n_rows + 1;
87  do
88    {
89    const uword y_it_row = y_it.row();
90   
91    // Look through the column that this point (*y_it) could affect.
92    typename SpProxy<T1>::const_iterator_type x_it = pa.begin_col(y_it_row);
93   
94    while(x_it.col() == y_it_row)
95      {
96      // A point at x(i, j) and y(j, k) implies a point at c(i, k).
97      if(index[x_it.row()] == x_n_rows)
98        {
99        index[x_it.row()] = last_ind;
100        last_ind = x_it.row();
101        ++cur_col_length;
102        }
103
104      ++x_it;
105      }
106
107    const uword old_col = y_it.col();
108    ++y_it;
109
110    // See if column incremented.
111    if(old_col != y_it.col())
112      {
113      // Set column pointer (this is not a cumulative count; that is done later).
114      access::rw(c.col_ptrs[old_col + 1]) = cur_col_length;
115      cur_col_length = 0;
116
117      // Return index markers to zero.  Use last_ind for traversal.
118      while(last_ind != x_n_rows + 1)
119        {
120        const uword tmp = index[last_ind];
121        index[last_ind] = x_n_rows;
122        last_ind = tmp;
123        }
124      }
125    }
126  while(y_it != y_end);
127
128  // Accumulate column pointers.
129  for(uword i = 0; i < c.n_cols; ++i)
130    {
131    access::rw(c.col_ptrs[i + 1]) += c.col_ptrs[i];
132    }
133
134  // Now that we know a decent bound on the number of nonzero elements, allocate
135  // the memory and fill it.
136  c.mem_resize(c.col_ptrs[c.n_cols]);
137
138  // Now the implementation of the NUMBMM algorithm.
139  uword cur_pos = 0; // Current position in c matrix.
140  podarray<eT> sums(x_n_rows); // Partial sums.
141  sums.zeros();
142 
143  // setting the size of 'sorted_indices' to x_n_rows is a better-than-nothing guess;
144  // the correct minimum size is determined later
145  podarray<uword> sorted_indices(x_n_rows);
146 
147  // last_ind is already set to x_n_rows, and cur_col_length is already set to 0.
148  // We will loop through all columns as necessary.
149  uword cur_col = 0;
150  while(cur_col < c.n_cols)
151    {
152    // Skip to next column with elements in it.
153    while((cur_col < c.n_cols) && (c.col_ptrs[cur_col] == c.col_ptrs[cur_col + 1]))
154      {
155      // Update current column pointer to actual number of nonzero elements up
156      // to this point.
157      access::rw(c.col_ptrs[cur_col]) = cur_pos;
158      ++cur_col;
159      }
160
161    if(cur_col == c.n_cols)
162      {
163      break;
164      }
165
166    // Update current column pointer.
167    access::rw(c.col_ptrs[cur_col]) = cur_pos;
168
169    // Check all elements in this column.
170    typename SpProxy<T2>::const_iterator_type y_col_it = pb.begin_col(cur_col);
171   
172    while(y_col_it.col() == cur_col)
173      {
174      // Check all elements in the column of the other matrix corresponding to
175      // the row of this column.
176      typename SpProxy<T1>::const_iterator_type x_col_it = pa.begin_col(y_col_it.row());
177
178      const eT y_value = (*y_col_it);
179
180      while(x_col_it.col() == y_col_it.row())
181        {
182        // A point at x(i, j) and y(j, k) implies a point at c(i, k).
183        // Add to partial sum.
184        const eT x_value = (*x_col_it);
185        sums[x_col_it.row()] += (x_value * y_value);
186
187        // Add point if it hasn't already been marked.
188        if(index[x_col_it.row()] == x_n_rows)
189          {
190          index[x_col_it.row()] = last_ind;
191          last_ind = x_col_it.row();
192          }
193
194        ++x_col_it;
195        }
196
197      ++y_col_it;
198      }
199
200    // Now sort the indices that were used in this column.
201    //podarray<uword> sorted_indices(c.col_ptrs[cur_col + 1] - c.col_ptrs[cur_col]);
202    sorted_indices.set_min_size(c.col_ptrs[cur_col + 1] - c.col_ptrs[cur_col]);
203   
204    // .set_min_size() can only enlarge the array to the specified size,
205    // hence if we request a smaller size than already allocated,
206    // no new memory allocation is done
207   
208   
209    uword cur_index = 0;
210    while(last_ind != x_n_rows + 1)
211      {
212      const uword tmp = last_ind;
213
214      // Check that it wasn't a "fake" nonzero element.
215      if(sums[tmp] != eT(0))
216        {
217        // Assign to next open position.
218        sorted_indices[cur_index] = tmp;
219        ++cur_index;
220        }
221
222      last_ind = index[tmp];
223      index[tmp] = x_n_rows;
224      }
225
226    // Now sort the indices.
227    if (cur_index != 0)
228      {
229      op_sort::direct_sort_ascending(sorted_indices.memptr(), cur_index);
230
231      for(uword k = 0; k < cur_index; ++k)
232        {
233        const uword row = sorted_indices[k];
234        access::rw(c.row_indices[cur_pos]) = row;
235        access::rw(c.values[cur_pos]) = sums[row];
236        sums[row] = eT(0);
237        ++cur_pos;
238        }
239      }
240
241    // Move to next column.
242    ++cur_col;
243    }
244
245  // Update last column pointer and resize to actual memory size.
246  access::rw(c.col_ptrs[c.n_cols]) = cur_pos;
247  c.mem_resize(cur_pos);
248  }
249
250
251
252//
253//
254// spglue_times2: scalar*(A * B)
255
256
257
258template<typename T1, typename T2>
259inline
260void
261spglue_times2::apply(SpMat<typename T1::elem_type>& out, const SpGlue<T1,T2,spglue_times2>& X)
262  {
263  arma_extra_debug_sigprint();
264 
265  typedef typename T1::elem_type eT;
266 
267  const SpProxy<T1> pa(X.A);
268  const SpProxy<T2> pb(X.B);
269 
270  const bool is_alias = pa.is_alias(out) || pb.is_alias(out);
271 
272  if(is_alias == false)
273    {
274    spglue_times::apply_noalias(out, pa, pb);
275    }
276  else
277    {
278    SpMat<eT> tmp;
279    spglue_times::apply_noalias(tmp, pa, pb);
280   
281    out.steal_mem(tmp);
282    }
283 
284  out *= X.aux;
285  }
286
287
288
289//! @}
Note: See TracBrowser for help on using the repository browser.