Massive Matrix Factorization for fMRI and Recommender Systems. Arthur Mensch, 2016-06-10. https://www.amensch.fr/2016/06/10/massive-matrix-factorization-for-fMRI-and-recommender-systems<p>Before presenting it at ICML New York, I will give a quick overview of our work on efficient
matrix factorization for very large datasets. Our focus was to scale <em>matrix
factorization</em> techniques for functional MRI, a domain where data to
decompose is now at terabyte scale. Along the way, we also designed an encouraging proof-of-concept
experiment for collaborative filtering.</p>
<p>We’ll start by reviewing matrix factorization techniques for interpretable data
decomposition.</p>
<h2 id="understanding-data-with-matrix-factorization">Understanding data with matrix factorization</h2>
<p>Unsupervised learning aims at finding patterns in a sequence of <code class="MathJax_Preview">n</code><script type="math/tex">n</script> samples
<code class="MathJax_Preview">(x_t)_t</code><script type="math/tex">(x_t)_t</script> living in a <code class="MathJax_Preview">p</code><script type="math/tex">p</script>-dimensional space. Typically, this involves finding a few statistics that describe the data in a <em>compressed</em> manner. Stacking the samples as columns, our dataset can be seen as a large matrix <code class="MathJax_Preview">X \in R^{p \times n}</code><script type="math/tex">X \in R^{p \times n}</script>. Factorizing such a matrix has proven a very flexible way to extract interesting patterns. Namely, we want to find two <em>small</em> matrices <code class="MathJax_Preview">D \in R^{p \times k}</code><script type="math/tex">D \in R^{p \times k}</script> (the <em>dictionary</em>) and <code class="MathJax_Preview">A \in R^{k \times n}</code><script type="math/tex">A \in R^{k \times n}</script> (the <em>code</em>), with <code class="MathJax_Preview">k</code><script type="math/tex">k</script> columns and rows respectively, whose product approximates <code class="MathJax_Preview">X</code><script type="math/tex">X</script>.</p>
<p><img src="/assets/img/16-mmf/drawings/poster_model_sparse.png" width="80%" style="display: block; margin: 0 auto;" title="Model" /></p>
<p>Small can mean several things here: we may impose <code class="MathJax_Preview">k</code><script type="math/tex">k</script> to be small, which amounts to searching for a low-rank representation of the matrix <code class="MathJax_Preview">X</code><script type="math/tex">X</script>, and thus for a subspace of <code class="MathJax_Preview">R^p</code><script type="math/tex">R^p</script> that approximately contains all samples. For interpretability, it can be useful, as in the drawing above, to impose sparsity on <code class="MathJax_Preview">D</code><script type="math/tex">D</script>: this is what we’ll do in fMRI.</p>
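<p>Without any sparsity constraint, the best rank-<code>k</code> factorization in Frobenius norm is given by a truncated singular value decomposition (the Eckart-Young theorem). The following minimal NumPy sketch assumes samples stored as the columns of <code>X</code>, so that <code>D</code> is p × k and <code>A</code> is k × n; the function name and toy sizes are illustrative only.</p>

```python
import numpy as np

def low_rank_factorization(X, k):
    """Best rank-k approximation D @ A of X (Frobenius norm), via truncated SVD."""
    U, s, Vt = np.linalg.svd(X, full_matrices=False)
    D = U[:, :k] * s[:k]   # p x k: leading left singular vectors, scaled
    A = Vt[:k]             # k x n: one code column per sample
    return D, A

rng = np.random.default_rng(0)
p, n, k = 50, 200, 5
# toy data: exactly rank k, plus a little noise
X = rng.standard_normal((p, k)) @ rng.standard_normal((k, n))
X += 0.01 * rng.standard_normal((p, n))
D, A = low_rank_factorization(X, k)
rel_err = np.linalg.norm(X - D @ A) / np.linalg.norm(X)
print(f"D: {D.shape}, A: {A.shape}, relative error: {rel_err:.4f}")
```

<p>The SVD factors are dense and not unique up to rotation; imposing sparsity on <code>D</code>, as done for fMRI, is what makes the maps interpretable, and it rules out this closed form.</p>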
<p>In other settings, we may have <code class="MathJax_Preview">k</code><script type="math/tex">k</script> large but require <code class="MathJax_Preview">A</code><script type="math/tex">A</script> to be <em>sparse</em>, leading to an overcomplete dictionary learning setting that generalizes the k-means algorithm. This setting won’t interest us today, although we borrow its terminology.</p>
<h3 id="fmri-example">fMRI example</h3>
<p>We can already instantiate matrix factorization for fMRI, as this will make things clearer. We study resting-state functional imaging: 500 subjects each go into a scanner four times, and their brain activity is recorded for 15 minutes while at rest; roughly, a 3D image of their brain activity is acquired every second. This yields 2 million 3D images of brain activity, each with 200,000 <em>voxels</em>: <strong>2TB</strong> of dense data. We want to extract spatial activity maps that constitute a good basis for these images:</p>
<p><img src="/assets/img/16-mmf/drawings/poster_fmri_dl_flat.png" width="80%" style="display: block; margin: 0 auto;" title="Model" /></p>
<p>What we are most interested in is the dictionary <code class="MathJax_Preview">D</code><script type="math/tex">D</script>, which holds, say, 70 sparse spatial maps. We expect those to capture functional networks, segmenting the auditory, visual and motor cortices, etc. Sparsity and low rank are key for pattern discovery: we want to find few maps, with few activated regions.</p>
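<p>As a back-of-the-envelope check of the data size quoted above, this Python snippet multiplies out the acquisition figures. It assumes exactly one image per second and 4-byte (float32) or 8-byte (float64) values per voxel; the text only says “roughly” one image per second and does not state the storage type.</p>

```python
def dataset_size_tb(subjects=500, sessions=4, minutes=15, voxels=200_000, bytes_per_value=4):
    """Number of 3D images and dense size in terabytes, assuming one image per second."""
    n_images = subjects * sessions * minutes * 60
    return n_images, n_images * voxels * bytes_per_value / 1e12

n_images, size_float32 = dataset_size_tb()
_, size_float64 = dataset_size_tb(bytes_per_value=8)
print(n_images, round(size_float32, 2), round(size_float64, 2))  # 1800000 1.44 2.88
```

<p>This gives 1.8 million images (rounded to 2 million above) and about 1.4 TB in single precision or 2.9 TB in double precision, bracketing the 2TB figure.</p>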
<h2 id="matrix-factorization-for-large-datasets">Matrix factorization for large datasets</h2>
<p>A little math helps to better grasp our problem. Decomposing <code class="MathJax_Preview">X</code><script type="math/tex">X</script> into the product <code class="MathJax_Preview">D A</code><script type="math/tex">D A</script> can be done by solving an optimization problem (see <strong>[Olshausen ‘97]</strong> for the initial problem setting):</p>
<pre class="MathJax_Preview"><code>\min_{D \in \mathcal{C}, A \in R^{k \times n}} \Vert X - D A \Vert_F^2 + \lambda \Omega(A)</code></pre>
<script type="math/tex; mode=display">\min_{D \in \mathcal{C}, A \in R^{k \times n}} \Vert X - D A \Vert_F^2 + \lambda \Omega(A)</script>
<p>where structure and sparsity can be imposed via constraints (convex set <code class="MathJax_Preview">\mathcal{C}</code><script type="math/tex">\mathcal{C}</script>)
and penalties. For example, we may constrain dictionary columns to lie in <code class="MathJax_Preview">\ell_1</code><script type="math/tex">\ell_1</script> balls, to get a sparse dictionary.</p>
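<p>A common way to enforce the ℓ1-ball constraint on each dictionary column is a Euclidean projection onto the ℓ1 ball. Below is a sketch of the classical sort-based projection (Duchi et al., 2008) in NumPy; the function names are illustrative, and this is not necessarily how the original package implements the constraint.</p>

```python
import numpy as np

def project_l1_ball(v, radius=1.0):
    """Euclidean projection of v onto {w : ||w||_1 <= radius} (sort-based)."""
    if np.abs(v).sum() <= radius:
        return v.copy()                          # already feasible
    u = np.sort(np.abs(v))[::-1]                 # magnitudes, largest first
    cssv = np.cumsum(u)
    ranks = np.arange(1, len(u) + 1)
    # largest index whose magnitude stays above the soft-thresholding level
    rho = np.nonzero(u * ranks > cssv - radius)[0][-1]
    theta = (cssv[rho] - radius) / (rho + 1)     # soft-thresholding level
    return np.sign(v) * np.maximum(np.abs(v) - theta, 0.0)

def project_dictionary(D, radius=1.0):
    """Project every column (atom) of D onto the l1 ball independently."""
    return np.column_stack([project_l1_ball(d, radius) for d in D.T])

w = project_l1_ball(np.array([3.0, -1.0, 0.5]))
```

<p>Projecting (3, -1, 0.5) onto the unit ℓ1 ball gives (1, 0, 0): the projection is a soft-thresholding that zeroes out small entries, which is why this constraint yields sparse atoms.</p>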
<p>Solving this minimization problem is where all the honey is: let’s see what methods can be used when <code class="MathJax_Preview">X</code><script type="math/tex">X</script> grows large.</p>
<p>A naive solver alternately minimizes the loss function over <code class="MathJax_Preview">A</code><script type="math/tex">A</script> and <code class="MathJax_Preview">D</code><script type="math/tex">D</script>. That is: given <code class="MathJax_Preview">X</code><script type="math/tex">X</script> and <code class="MathJax_Preview">A</code><script type="math/tex">A</script>, find the best <code class="MathJax_Preview">D</code><script type="math/tex">D</script>; given <code class="MathJax_Preview">X</code><script type="math/tex">X</script> and <code class="MathJax_Preview">D</code><script type="math/tex">D</script>, find the best <code class="MathJax_Preview">A</code><script type="math/tex">A</script>; and repeat. From a dictionary-oriented point of view, we define
<code class="MathJax_Preview">A(D) = \text{arg}\,\min_{A \in R^{k \times n}} \Vert X - D A \Vert_F^2 + \lambda \Omega(A)</code><script type="math/tex">A(D) = \text{arg}\,\min_{A \in R^{k \times n}} \Vert X - D A \Vert_F^2 + \lambda \Omega(A)</script></p>
<pre class="MathJax_Preview"><code>\alpha_i(D) = \text{arg}\,\min_{\alpha \in R^{k}} \Vert x_i - D \alpha \Vert_F^2 + \lambda \Omega(\alpha)</code></pre>
<script type="math/tex; mode=display">\alpha_i(D) = \text{arg}\,\min_{\alpha \in R^{k}} \Vert x_i - D \alpha \Vert_F^2 + \lambda \Omega(\alpha)</script>
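<p>where the column-wise form uses the columns <code class="MathJax_Preview">(\alpha_i)_i</code><script type="math/tex">(\alpha_i)_i</script> of <code class="MathJax_Preview">A</code><script type="math/tex">A</script>, which is valid when <code class="MathJax_Preview">\Omega</code><script type="math/tex">\Omega</script> is separable over these columns; we’ll see why this matters in a minute. The naive algorithm simply consists in iterating</p>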
<pre class="MathJax_Preview"><code>\begin{aligned}
D_t &= \text{arg}\,\min_{D \in \mathcal{C}} \Vert X - D A(D_{t-1}) \Vert_F^2 \\
&= \text{arg}\,\min_{D \in \mathcal{C}} \sum_{i=1}^n \Vert x_i - D \alpha_i(D_{t-1}) \Vert_F^2
\end{aligned}</code></pre>
<script type="math/tex; mode=display">% <![CDATA[
\begin{aligned}
D_t &= \text{arg}\,\min_{D \in \mathcal{C}} \Vert X - D A(D_{t-1}) \Vert_F^2 \\
&= \text{arg}\,\min_{D \in \mathcal{C}} \sum_{i=1}^n \Vert x_i - D \alpha_i(D_{t-1}) \Vert_F^2
\end{aligned} %]]></script>
<p>This takes time, as the whole dataset <code class="MathJax_Preview">X</code><script type="math/tex">X</script> is loaded at each iteration. In fact, it quickly becomes intractable: beyond 1 million entries in <code class="MathJax_Preview">X</code><script type="math/tex">X</script>, it already takes hours.</p>
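<p>The naive alternating scheme can be sketched in NumPy under simplifying assumptions: a ridge penalty Ω(A) = ‖A‖²<sub>F</sub>, so that codes have a closed form, and a unit ℓ2-ball constraint on each atom (an ℓ1 ball would need a projection such as the one sketched earlier). The dictionary step uses block coordinate descent over columns; all names and defaults are hypothetical. Note that every iteration reads the whole of <code>X</code>, which is what makes the method intractable at scale.</p>

```python
import numpy as np

def ridge_codes(X, D, lam):
    """All codes at once: column i solves min_a ||x_i - D a||^2 + lam ||a||^2."""
    k = D.shape[1]
    return np.linalg.solve(D.T @ D + lam * np.eye(k), D.T @ X)

def dictionary_step(D, A, X, n_passes=5):
    """Block coordinate descent on ||X - D A||_F^2 over D, columns in the unit l2 ball."""
    C, B = A @ A.T, X @ A.T            # X @ A.T reads the whole dataset
    D = D.copy()
    for _ in range(n_passes):
        for j in range(D.shape[1]):
            if C[j, j] < 1e-12:
                continue               # unused atom: leave it unchanged
            # exact minimizer over column j, then projection onto the ball
            u = D[:, j] + (B[:, j] - D @ C[:, j]) / C[j, j]
            D[:, j] = u / max(np.linalg.norm(u), 1.0)
    return D

def objective(X, D, A, lam):
    return np.linalg.norm(X - D @ A) ** 2 + lam * np.linalg.norm(A) ** 2

def batch_alternating(X, k, lam=0.1, n_iter=20, seed=0):
    rng = np.random.default_rng(seed)
    D = rng.standard_normal((X.shape[0], k))
    D /= np.linalg.norm(D, axis=0)
    history = []
    for _ in range(n_iter):
        A = ridge_codes(X, D, lam)     # O(n p k): touches every sample
        D = dictionary_step(D, A, X)
        history.append(objective(X, D, A, lam))
    return D, history
```

<p>Both steps are exact minimizations of the same objective over one block of variables, so the recorded objective never increases; the cost per iteration, however, grows linearly with n.</p>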
<h3 id="going-online">Going online</h3>
<p>A very efficient way to get past this intractability was introduced by <strong>[Mairal ‘10]</strong>. Computing <code class="MathJax_Preview">A</code><script type="math/tex">A</script> for the whole dataset is costly, and overkill for a single step of improving the dictionary: we can maintain an approximation of this code by streaming the data and optimizing the dictionary along the stream.</p>
<p><img src="/assets/img/16-mmf/drawings/poster_model_sparse_online.png" width="80%" style="display: block; margin: 0 auto;" title="Model" /></p>
<p>As the drawing above indicates, we look at the data one sample <code class="MathJax_Preview">x_t</code><script type="math/tex">x_t</script> at a
time. At iteration <code class="MathJax_Preview">t</code><script type="math/tex">t</script>, we use the current dictionary <code class="MathJax_Preview">D_{t-1}</code><script type="math/tex">D_{t-1}</script> to compute the associated loadings
<code class="MathJax_Preview">\alpha_t</code><script type="math/tex">\alpha_t</script>:</p>
<pre class="MathJax_Preview"><code>\alpha_t = \text{arg}\,\min_{\alpha \in R^{k}} \Vert x_t - D_{t-1} \alpha \Vert_F^2 + \lambda \Omega(\alpha)</code></pre>
<script type="math/tex; mode=display">\alpha_t = \text{arg}\,\min_{\alpha \in R^{k}} \Vert x_t - D_{t-1} \alpha \Vert_F^2 + \lambda \Omega(\alpha)</script>
<p>We then solve, at each iteration</p>
<pre class="MathJax_Preview"><code>D_t = \text{arg}\,\min_{D \in \mathcal{C}} \sum_{i=1}^t \Vert x_i - D \alpha_i \Vert_F^2</code></pre>
<script type="math/tex; mode=display">D_t = \text{arg}\,\min_{D \in \mathcal{C}} \sum_{i=1}^t \Vert x_i - D \alpha_i \Vert_F^2</script>
<p>This looks very much like the original update, except that we use outdated codes
<code class="MathJax_Preview">\alpha_i</code><script type="math/tex">\alpha_i</script>, each computed with the dictionary available when sample i was seen, to approximate our objective function. The essential idea here is to
start solving the problem with a very inaccurate approximation of it, and to
improve it by looking at more data.</p>
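<p>A single iteration of the algorithm depends on <code class="MathJax_Preview">p</code><script type="math/tex">p</script> but no longer on <code class="MathJax_Preview">n</code><script type="math/tex">n</script>: the sum over past samples can be summarized by two small matrices, <code class="MathJax_Preview">\sum_i \alpha_i \alpha_i^\top</code><script type="math/tex">\sum_i \alpha_i \alpha_i^\top</script> (of size <code class="MathJax_Preview">k \times k</code><script type="math/tex">k \times k</script>) and <code class="MathJax_Preview">\sum_i x_i \alpha_i^\top</code><script type="math/tex">\sum_i x_i \alpha_i^\top</script> (of size <code class="MathJax_Preview">p \times k</code><script type="math/tex">p \times k</script>), updated with each new sample. The
algorithm empirically converges in a few epochs over the data. This is very efficient
when the data dimension <code class="MathJax_Preview">p</code><script type="math/tex">p</script> is reasonable: as a matter of fact, the online algorithm
was initially designed to handle large sequences of 16x16 image patches, <strong>a
very low p compared to the fMRI setting</strong>.</p>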
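<p>The online scheme can be sketched as follows, in the spirit of [Mairal ‘10] but with simplifications: a ridge penalty instead of a sparsity-inducing one (so each code is a small linear solve), a unit ℓ2-ball constraint on atoms, one sample per iteration and a single block coordinate descent pass. The function name and defaults are hypothetical. The past is summarized by <code>C</code> (k × k) and <code>B</code> (p × k), so each iteration costs O(p k²) regardless of n.</p>

```python
import numpy as np

def online_dictionary_learning(X, k, lam=0.01, n_epochs=2, seed=0):
    """Stream the columns of X; the past is summarized by C (k x k) and B (p x k)."""
    rng = np.random.default_rng(seed)
    p, n = X.shape
    D = X[:, rng.choice(n, size=k, replace=False)].copy()
    D /= np.linalg.norm(D, axis=0)        # unit-norm initial atoms drawn from the data
    C = np.zeros((k, k))                  # running sum of alpha_i alpha_i^T
    B = np.zeros((p, k))                  # running sum of x_i alpha_i^T
    for _ in range(n_epochs):
        for i in rng.permutation(n):
            x = X[:, i]
            # code of the new sample with the previous dictionary D_{t-1}
            alpha = np.linalg.solve(D.T @ D + lam * np.eye(k), D.T @ x)
            C += np.outer(alpha, alpha)
            B += np.outer(x, alpha)
            # one pass of block coordinate descent on sum_i ||x_i - D alpha_i||^2
            for j in range(k):
                if C[j, j] < 1e-12:
                    continue
                u = D[:, j] + (B[:, j] - D @ C[:, j]) / C[j, j]
                D[:, j] = u / max(np.linalg.norm(u), 1.0)   # stay in the unit l2 ball
    return D
```

<p>Because the old codes are frozen inside <code>C</code> and <code>B</code>, the dictionary minimizes an approximation of the true objective that improves as more samples are streamed.</p>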
<h2 id="handling-large-sample-dimension">Handling large sample dimension</h2>
<p>This is where our contribution begins. We want to provide an algorithm that
scales not only in the number of samples but also in the sample dimension. To
scale in the number of samples, we went from using <code class="MathJax_Preview">X</code><script type="math/tex">X</script> to using <code class="MathJax_Preview">x_t</code><script type="math/tex">x_t</script> at
each iteration, allowing iterations roughly n times faster. Here, <code class="MathJax_Preview">x_t</code><script type="math/tex">x_t</script> is
still too large, and <strong>we want to acquire information even faster</strong>.</p>
<p>This is where we introduce <em>random subsampling</em>: can we improve the dictionary
with only a <em>fraction</em> of a sample at each iteration? The answer is yes, as we’ll
now show. The algorithm we propose loads a fraction of a sample <code class="MathJax_Preview">x_t</code><script type="math/tex">x_t</script> at each
iteration and uses it to update the approximation of the optimization problem.
The fraction is different at each iteration: this way, we obtain
information about the whole feature space, in a stochastic manner. We thus go one step
further in randomness:</p>
<p><img src="/assets/img/16-mmf/drawings/poster_next_level.png" alt="Random subsampling" /></p>
<p><code class="MathJax_Preview">M_t x_t</code><script type="math/tex">M_t x_t</script> corresponds to a subsampling of <code class="MathJax_Preview">x_t</code><script type="math/tex">x_t</script>, where <code class="MathJax_Preview">M_t</code><script type="math/tex">M_t</script> is a diagonal matrix with entries in <code class="MathJax_Preview">\{0, 1\}</code><script type="math/tex">\{0, 1\}</script>, say 90% of them zeros.</p>
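<p>In practice <code>M_t</code> need not be formed as a p × p matrix. A minimal NumPy sketch, assuming a mask drawn uniformly at random, shows that applying <code>M_t</code> amounts to reading only s of the p entries of <code>x_t</code>; the helper name is hypothetical.</p>

```python
import numpy as np

def random_mask_indices(p, reduction, rng):
    """Indices of the diagonal entries of M_t equal to 1 (about p / reduction of them)."""
    s = max(1, p // reduction)
    return np.sort(rng.choice(p, size=s, replace=False))

rng = np.random.default_rng(0)
p = 20
x_t = np.arange(1.0, p + 1)                             # a toy sample
idx = random_mask_indices(p, reduction=10, rng=rng)     # 90% of features dropped

M_t = np.zeros((p, p))        # dense M_t, built here only to check the equivalence
M_t[idx, idx] = 1.0
masked = M_t @ x_t            # p values, mostly zeros
loaded = x_t[idx]             # what an implementation actually reads: s values
```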
<p>The whole difficulty lies in constructing the right approximations so that the
problem we solve at each iteration looks more and more like the original
optimization problem – just like the online algorithm does.</p>
<p>The online algorithm relies on a few small statistics that
sufficiently describe the approximate problem. These are updated at an
<code class="MathJax_Preview">\mathcal{O}(p)</code><script type="math/tex">\mathcal{O}(p)</script> cost per iteration, which keeps single iterations cheap: hence the online magic.</p>
<p>Our objective here is to speed up each iteration by a constant factor, which
corresponds to the dimension reduction factor. We must therefore ensure that
everything we do at iteration t scales in <code class="MathJax_Preview">\mathcal{O}(s)</code><script type="math/tex">\mathcal{O}(s)</script>, where <code class="MathJax_Preview">s</code><script type="math/tex">s</script> is the <em>reduced</em> dimension.
That way, we gain a constant factor (from 2 to 12 on large datasets, as we’ll
see) in single-iteration complexity (<em>computational speed-up</em>), and we expect not to
lose it because of the approximation we introduce (<em>approximation error</em>).</p>
<p>This is because <strong>very large datasets often have many redundancies</strong>: accessing
a random part of a sample does not much reduce the information acquired at
each iteration. As we’ll see, on large datasets, the balance therefore tips strongly
toward the single-iteration computational speed-up.</p>
<p>This constraint on iteration complexity greatly restricts what we can do. To sum up, we have to adapt the three steps of the online algorithm:</p>
<ul>
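<li>Computing the code from the past iterate: we rely on a <em>sketched</em> code computation, where we only look at the features of <code class="MathJax_Preview">x_t</code><script type="math/tex">x_t</script> and the rows of <code class="MathJax_Preview">D_{t-1}</code><script type="math/tex">D_{t-1}</script> selected by <code class="MathJax_Preview">M_t</code><script type="math/tex">M_t</script>; since the masked loss only involves <code class="MathJax_Preview">s</code><script type="math/tex">s</script> of the <code class="MathJax_Preview">p</code><script type="math/tex">p</script> features, the penalty is scaled by <code class="MathJax_Preview">s/p</code><script type="math/tex">s/p</script>:</li>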
</ul>
<pre class="MathJax_Preview"><code>\begin{aligned}\alpha_t &= \text{arg}\,\min_{\alpha \in R^{k}} \Vert M_t(x_t - D_{t-1} \alpha) \Vert_F^2 + \lambda \frac{s}{p} \Omega(\alpha)
\end{aligned}</code></pre>
<script type="math/tex; mode=display">% <![CDATA[
\begin{aligned}\alpha_t &= \text{arg}\,\min_{\alpha \in R^{k}} \Vert M_t(x_t - D_{t-1} \alpha) \Vert_F^2 + \lambda \frac{s}{p} \Omega(\alpha)
\end{aligned} %]]></script>
<ul>
<li>
<p>Aggregating this partial sample and code into an approximate objective, as the online algorithm does by summing <code class="MathJax_Preview">t</code><script type="math/tex">t</script> terms. We have to do this cleverly, so that we only update statistics whose size scales with s and not with p. This includes keeping track of the number of times we have seen each feature in the past.</p>
</li>
<li>
<p>Updating the dictionary: we can’t update the full <code class="MathJax_Preview">D</code><script type="math/tex">D</script> at each iteration, as this has an <code class="MathJax_Preview">\mathcal{O}(p)</code><script type="math/tex">\mathcal{O}(p)</script> cost. Instead, we update only the rows of the dictionary atoms selected by <code class="MathJax_Preview">M_t</code><script type="math/tex">M_t</script>, and ensure that <code class="MathJax_Preview">D</code><script type="math/tex">D</script> remains in <code class="MathJax_Preview">\mathcal{C}</code><script type="math/tex">\mathcal{C}</script> by projection.</p>
</li>
</ul>
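<p>The three adapted steps can be illustrated with a simplified NumPy sketch. It is not the algorithm from the paper nor the API of the Python package; it assumes a ridge penalty instead of a sparsity-inducing Ω, a unit ℓ2-ball constraint on each atom, plain unweighted sums, and a hypothetical correction that rescales each row of <code>B</code> by t divided by the number of times that feature was seen. Apart from drawing the mask, each iteration only touches the s selected rows plus k × k quantities. To keep the projection O(s), the squared norm of every atom is tracked incrementally, and the selected rows may only use the norm budget left by the unselected ones.</p>

```python
import numpy as np

def sketched_code(x_t, D, idx, lam):
    """Ridge code from the s rows selected by the mask; penalty scaled by s / p."""
    p, k = D.shape
    D_sub = D[idx]
    gram = D_sub.T @ D_sub + lam * (len(idx) / p) * np.eye(k)
    return np.linalg.solve(gram, D_sub.T @ x_t[idx])

def masked_online_step(x_t, idx, D, C, B, counts, sq_norms, t, lam):
    """One simplified iteration; D, C, B, counts and sq_norms are updated in place."""
    k = D.shape[1]
    # 1. sketched code, using the previous dictionary D_{t-1}
    alpha = sketched_code(x_t, D, idx, lam)
    # 2. statistics: C sees every code, B only the observed rows
    C += np.outer(alpha, alpha)
    B[idx] += np.outer(x_t[idx], alpha)
    counts[idx] += 1
    # rows seen fewer than t times are rescaled as if seen at every iteration
    B_hat = B[idx] * (t / counts[idx])[:, None]
    # 3. block coordinate descent restricted to the observed rows
    for j in range(k):
        if C[j, j] < 1e-12:
            continue
        old = D[idx, j]                               # fancy indexing returns a copy
        u = old + (B_hat[:, j] - D[idx] @ C[:, j]) / C[j, j]
        # squared-norm budget left for these rows so that ||d_j||_2 <= 1 overall
        budget = max(1.0 - (sq_norms[j] - old @ old), 0.0)
        norm_u = np.linalg.norm(u)
        if norm_u ** 2 > budget:
            u *= np.sqrt(budget) / norm_u
        D[idx, j] = u
        sq_norms[j] += u @ u - old @ old              # O(s) bookkeeping, no O(p) norm
    return alpha

def subsampled_online_mf(X, k, reduction=4, lam=0.01, n_epochs=3, seed=0):
    rng = np.random.default_rng(seed)
    p, n = X.shape
    s = max(k, p // reduction)                        # reduced dimension
    D = X[:, rng.choice(n, size=k, replace=False)].copy()
    D /= np.linalg.norm(D, axis=0)                    # unit-norm initial atoms
    C, B, counts = np.zeros((k, k)), np.zeros((p, k)), np.zeros(p)
    sq_norms = (D ** 2).sum(axis=0)
    t = 0
    for _ in range(n_epochs):
        for i in rng.permutation(n):
            t += 1
            idx = rng.choice(p, size=s, replace=False)   # a fresh mask M_t
            masked_online_step(X[:, i], idx, D, C, B, counts, sq_norms, t, lam)
    return D
```

<p>With a full mask (all p indices selected), <code>sketched_code</code> reduces to the ordinary ridge code, which is a convenient sanity check for the s/p scaling of the penalty.</p>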
<p>I skipped the math in the last two parts, but you can find it in more detail <a href="/docs/presentations/icml_presentation.pdf">on these slides</a>, along with a detailed comparison between our algorithm and the original online algorithm.</p>
<h2 id="results">Results</h2>
<p>Let’s get to the most important part: do we get the desired speed-up, and is the dictionary
we compute as good as the ones obtained with previous algorithms?</p>
<p><strong>On fMRI, we can push the reduction up to x12 and obtain a x10 speed-up compared to the online algorithm</strong>. For reference, a single pass over the data would take 235h with the online algorithm; we use the maps it produces as a baseline. These maps are blob-like, with noiseless contours.</p>
<p>In no more than 10h, our algorithm, using a 12-fold reduction, recovers maps that are almost as exploitable as the baseline ones. In comparison, the original algorithm stopped after 10h yields very poor results: noisy maps with many blobs.</p>
<p>Displaying the contours of these maps makes this clear:</p>
<p><img src="/assets/img/16-mmf/figures/brains.png" alt="Brains" /></p>
<p>We can quantify the speed-up by looking at convergence curves, which describe how well the dictionary performs as a basis on a test set, as a function of computation time.</p>
<p><img src="/assets/img/16-mmf/figures/bench.png" width="70%" style="display: block; margin: 0 auto;" title="Bench" /></p>
<p><strong>Convergence is obtained x10 more quickly</strong> with a 12-fold reduction.
This is very valuable for practitioners! Information is indeed acquired faster,
as the speed-up we obtain is close to the reduction we impose.</p>
<h2 id="collaborative-filtering">Collaborative filtering</h2>
<p>Our setting imposes masks on data to speed up learning. Quite interestingly,
collaborative filtering is a setting where we can only access <em>masked</em>
data, corresponding, for example, to the few movies that a user has rated.
Matrix factorization is then used to reconstruct the incomplete matrix <code class="MathJax_Preview">X</code><script type="math/tex">X</script> (see, for instance, <strong>[Szabo ‘11]</strong>).
To evaluate the performance of our model, we look at rating prediction on a
test set. We compare our algorithm with a fast coordinate descent solver from
<a href="http://github.com/mathieublondel/spira">spira</a>, which does not require setting any
hyperparameter; like it, and unlike SGD, our algorithm does not depend much on
hyperparameters. We get good results on large datasets (Netflix,
MovieLens 10M), as these benchmarks show. On <strong>Netflix</strong>, our algorithm is <strong>7x faster</strong> than the coordinate descent solver, which was the fastest well-packaged collaborative filtering algorithm we could find.</p>
<p><img src="/assets/img/16-mmf/figures/rec_bench.png" width="100%" style="display: block; margin: 0 auto;" title="Collaborative filtering benches" /></p>
<p>Our model is very simple (minimization of an <code class="MathJax_Preview">\ell_2</code><script type="math/tex">\ell_2</script> loss), and we do not get
state-of-the-art prediction on Netflix. However, this experiment shows that our
algorithm is able to learn a decomposition even with non-random masks, and
demonstrates the efficiency of imposing the complexity constraints explained
above.</p>
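<p>In collaborative filtering, the mask comes from the data rather than being drawn at random. A minimal sketch, assuming ratings arrive as (user, item, rating) triplets and that a matrix <code>D</code> of item factors (items × k) is already available: each user’s rated items play the role of the rows selected by <code>M_t</code>, the user’s code is fitted on those rows only, and unseen items are scored with the full <code>D</code>. Function names are hypothetical, and the rating means and biases that practical recommenders usually model are omitted.</p>

```python
import numpy as np
from collections import defaultdict

def group_by_user(triplets):
    """Map each user to (item indices, ratings): the observed entries of its column."""
    items, ratings = defaultdict(list), defaultdict(list)
    for user, item, rating in triplets:
        items[user].append(item)
        ratings[user].append(rating)
    return {u: (np.array(items[u]), np.array(ratings[u], dtype=float)) for u in items}

def predict_ratings(D, observed_items, observed_ratings, query_items, lam=0.1):
    """Fit a user's code on the observed rows of D only, then score query items."""
    D_obs = D[observed_items]                 # masked rows of the dictionary
    k = D.shape[1]
    alpha = np.linalg.solve(D_obs.T @ D_obs + lam * np.eye(k), D_obs.T @ observed_ratings)
    return D[query_items] @ alpha

# toy usage: user 0 rated items 2, 5 and 1; user 1 rated item 0
triplets = [(0, 2, 4.0), (1, 0, 3.0), (0, 5, 5.0), (0, 1, 2.0)]
by_user = group_by_user(triplets)
```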
<h2 id="conclusion">Conclusion</h2>
<p>Leveraging random subsampling with online learning is thus a very efficient way to perform matrix factorization on datasets that are large in both directions. Our algorithm had no convergence guarantee at the time of submission (February), but we now have a slightly adapted algorithm that converges with the same guarantees as in the original online matrix factorization paper; we rely on the stochastic majorization-minimization framework <strong>[Mairal ‘13]</strong>.</p>
<p><a href="http://github.com/arthurmensch/modl">A Python package</a> is available for reproducibility. We hope to integrate this algorithm into better-known libraries in the long term.</p>
<p>I hope that this post was readable and that it interested you. You’ll find
more details in our <a href="https://hal.archives-ouvertes.fr/hal-01308934">paper</a>,
<a href="/docs/posters/icml_poster.pdf">poster</a> and <a href="/docs/presentations/icml_presentation.pdf">slides</a>. I
will present this work at ICML in New York on Monday, June 20th, at 11:45. Discussions are
welcome!</p>
<h2 id="references">References</h2>
<ul>
<li>
<p><strong>[Mairal ‘10]</strong> Mairal, Julien, Francis Bach, Jean Ponce, and Guillermo Sapiro. “Online Learning for Matrix Factorization and Sparse Coding.” The Journal of Machine Learning Research, 2010.</p>
</li>
<li>
<p><strong>[Mairal ‘13]</strong> Mairal, Julien. “Stochastic Majorization-Minimization Algorithms for Large-Scale Optimization.” In Advances in Neural Information Processing Systems, 2013.</p>
</li>
<li>
<p><strong>[Olshausen ‘97]</strong> Olshausen, Bruno A., and David J. Field. “Sparse Coding with an Overcomplete Basis Set: A Strategy Employed by V1?” Vision Research, 1997.</p>
</li>
<li>
<p><strong>[Szabo ‘11]</strong> Szabó, Zoltán, Barnabás Póczos, and András Lorincz. “Online Group-Structured Dictionary Learning.” In Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, 2011.</p>
</li>
<li>
<p>See also <a href="/docs/presentations/icml_presentation.pdf">these slides</a></p>
</li>
</ul>Arthur Mensch (arthur.mensch@m4x.org)