Iterative methods in Rust: Conjugate Gradient

Introduction: the Beauty and the Trap

Iterative methods (IM) power many magical-seeming uses of computers, from deep learning to PageRank. An IM repeatedly applies a simple recipe which improves an approximate solution to a problem. As a block of wood in the right hands gradually transforms into a figurine, the right method will produce a sequence of solutions approaching perfection from an arbitrarily bad initial value. IM are quite unlike most CS algorithms, and take some getting used to, but for many problems they are the only game in town. Unfortunately, implementations of IM are often hard to maintain (read, test, benchmark) because they mix the recipe with other concerns.

We start with a simple concrete example. Implementing an iterative method often starts as a for loop around the recipe (to try it out, press the play button):

#![allow(unused)]
fn main() {
// Problem: minimize the convex parabola f(x) = x^2 + x
// An iterative solution by gradient descent
let mut x = 2.0;
for i in 0..10 {
	// "2.0*x + 1.0" is the derivative of f(x)
	// so moving a little bit in the opposite direction
	x -= 0.2 * (2.0*x + 1.0);
	// should reduce the value of f(x)!
	println!("x_{} = {:.2}; f(x_{}) = {:.4}", i, x, i, x*x + x);
}
// An alternative that does not scale well to harder problems is to 
// use the fact `f` does not decrease/increase at the minimum to 
// find its location directly:
//
// `f` at a minimum zeros its derivative: 2*x + 1 = 0. 
// Rearranging terms gives the solution x = -1/2.
}

... but that loop body already mixes multiple concerns. If it didn't println progress, we wouldn't even see that it is working. If it didn't stop after 10 iterations, it would go on forever. But a reader might justly wonder why that limit is there: does the method work only for 10 iterations? And if a different user does not want any intermediate output, must they reimplement the recipe? If so, how would they ensure the reimplementation is also correct? The approach taken in this example does not allow additional functionalities to be composed cleanly, and thus they build up like barnacles on a ship. Is there a way to keep the recipe separate from other concerns?

The answer is yes. This series of posts is about such a way to implement IM in Rust using iterators. This design, realized for Rust in the iterative_methods crate (repo), follows (and expands on) the approach "Iterative Methods Done Right" (IMDR), by Lorenzo Stella, demonstrated with Julia iterables. The main idea is that each IM is an iterator over algorithm states, and each reusable "utility" is an iterator adaptor augmenting or processing them. These components are highly reusable, composable and can be tailored using closures. This is basically an extension to this domain of an approach the Rust standard library already follows. But if it seems a bit dense, that's fine, we'll unpack it over this post and the next, introducing along the way a common iterative method as well as measures of IM quality.

Beyond reuse of methods and utilities, there is another reason to separate them: iterative methods are often highly sensitive beasts. Small changes can cause subtle numerical and/or dynamic effects that are very difficult to trace and resolve. Thus a primary design goal is to minimize the need for modifications to the method implementation, even when attempting to study/debug it.

Our first full iterative method

Our main guest IM for this post and the next will be Conjugate Gradient (CG). What is CG for? Here is one example: to simulate physics like the spread of heat (or deformation, fluid flow, etc.) over a complex shape, we break the shape up into simpler pieces to apply the Finite Element Method (FEM). In FEM, the relations between temperatures at different elements implied by the heat equation are encoded into a matrix \(A\) and a vector \(b\). To find the vector of temperatures \(x\), it is enough to solve the matrix equation \(Ax = b\). CG is an IM for solving such equations when \(A\) is positive definite (PD)1, which it is for a wide variety of domains.

Why the conjugate gradient method works is beyond the scope of this post, but good sources include the Wikipedia exposition and also these lecture notes.

1 Positive Definite matrices: a positive definite matrix \(A\) only scales \(x\)'s differently in different directions; no rotation or flipping allowed.
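An equivalent characterization is that the quadratic form \(x^T A x\) is strictly positive for every nonzero \(x\). The sketch below (an illustration I've added, not part of the crate) checks this form on a few sample vectors for a small symmetric matrix with positive eigenvalues; checking a handful of vectors builds intuition, though of course it is not a proof of definiteness.

```rust
// Illustration: for a positive definite matrix A, the quadratic form
// x^T A x is strictly positive for every nonzero x. Checking a few
// sample vectors is not a proof, but it builds intuition.
fn quadratic_form(a: &[[f64; 2]; 2], x: &[f64; 2]) -> f64 {
    // Compute A*x, then dot it with x.
    let ax = [
        a[0][0] * x[0] + a[0][1] * x[1],
        a[1][0] * x[0] + a[1][1] * x[1],
    ];
    x[0] * ax[0] + x[1] * ax[1]
}

fn main() {
    // A symmetric matrix with eigenvalues 1 and 3, hence positive definite.
    let a = [[2.0, 1.0], [1.0, 2.0]];
    for x in [[1.0, 0.0], [0.0, 1.0], [1.0, -1.0], [-2.0, 0.5]] {
        let q = quadratic_form(&a, &x);
        println!("x^T A x = {:.3} for x = {:?}", q, x);
        assert!(q > 0.0);
    }
}
```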

Implementation and a general interface

To store the state our method maintains, we define a struct ConjugateGradient, for which we implement the StreamingIterator trait. This trait is simple, requiring us to implement two methods:

  • advance applies one iteration of the algorithm, updating state.
  • get returns a borrow of the Item type, generally some part of its state.

The benefit of the StreamingIterator trait over the ubiquitous Iterator is that get exposes information by reference; this leaves the decision of whether to copy state to the caller.
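To make the trait concrete before diving into CG, here is a minimal sketch: the core of the trait (simplified from the streaming-iterator crate, which also derives next from advance and get) together with the introduction's gradient descent recipe re-expressed as a streaming iterator.

```rust
// A simplified sketch of the StreamingIterator trait, plus the
// introduction's gradient descent recipe expressed as one.
trait StreamingIterator {
    type Item: ?Sized;
    fn advance(&mut self);
    fn get(&self) -> Option<&Self::Item>;
    // `next` is derived from the two required methods.
    fn next(&mut self) -> Option<&Self::Item> {
        self.advance();
        self.get()
    }
}

struct GradientDescent {
    x: f64,
}

impl StreamingIterator for GradientDescent {
    type Item = f64;
    fn advance(&mut self) {
        // One step of gradient descent on f(x) = x^2 + x.
        self.x -= 0.2 * (2.0 * self.x + 1.0);
    }
    fn get(&self) -> Option<&f64> {
        // Expose the current iterate by reference; no copy is forced.
        Some(&self.x)
    }
}

fn main() {
    let mut gd = GradientDescent { x: 2.0 };
    for _ in 0..50 {
        let _ = gd.next();
    }
    // The minimum of x^2 + x is at x = -1/2.
    assert!((gd.x + 0.5).abs() < 1e-6);
    println!("converged to x = {:.4}", gd.x);
}
```

Note how the recipe now lives only in advance, and any caller can observe the state between iterations through get.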

The signatures for implementing an iterative method in this style are as follows:

#[derive(Clone, Debug)]
pub struct ConjugateGradient {
    // State of the algorithm
}

impl ConjugateGradient {
    /// Initialize a conjugate gradient iterative solver to solve linear system `p`.
    pub fn for_problem(p: &LinearSystem) -> ConjugateGradient {
        // Problem data such as the LinearSystem is often large, and should not be
        // duplicated. This implementation uses ndarray's ArcArrays, which are cheap
        // to clone as they share the data they point to.
    }
}

impl StreamingIterator for ConjugateGradient {
    type Item = Self;

    fn advance(&mut self) {
        // The improvement recipe goes here.
    }

    fn get(&self) -> Option<&Self::Item> {
        // Return self, immutably borrowed. This allows callers read-only access
        // to method state. The following is a bit simplified:
        Some(self)
    }
}

Note a few design decisions in the above:

  • The problem is a distinct concept from any method for solving it. The same problem representation (here, LinearSystem) often can and should be reused to initialize different methods.
  • The constructor method for_problem is responsible for setting up the initial state for the first iteration, and so is part of the method definition.
  • Another constructor responsibility is to perform cheap, applicable checks of the input problem; expensive initialization is a bad fit for an iterative method.
  • Item is set to the whole ConjugateGradient, i.e., all algorithm state. We could instead set the Item type returned by the get method to be only a result field, thus hiding implementation details from downstream. Similarly, there is some flexibility in defining the iterable struct: beyond a minimal representation of the state required for the next iteration, should we add fields to store intermediate steps of calculations? How about auxiliary information not needed at all in the method itself? Consider the following excerpt from the implementation of advance:
        // while r_k != 0:
        //   alpha_k = ||r_k||^2 / ||p_k||^2_A

        self.alpha_k = self.r_k2 / self.pap_k;
        if (!too_small(self.r_k2)) && (!too_small(self.pap_k)) {
            //   x_{k+1} = x_k + alpha_k*p_k
			...

Here self.alpha_k is only read in the remainder of the recipe. So why not make it a temporary, instead of a field of ConjugateGradient? This would seem to shrink the struct, saving memory, and to hide an unnecessary detail; generally positive outcomes, right? But soon after implementing this code I found myself wanting to print alpha_k, which is impossible for a local without modifying the advance method! By storing more intermediate state in the iterator, exposing all of it via get, and inspecting it externally, we avoid modifying the method for our inspection, and the dreaded Heisenbugs that could ensue. On top of a solid whitebox implementation, we can always build an interface that abstracts away some aspects.

Running an Iterative Method

How do we call such an implementation? The example below illustrates a common workflow:

    // First we generate a problem, which consists of the pair (A,b).
    let p = make_3x3_pd_system_2();

    // Next convert it into an iterator
    let mut cg_iter = ConjugateGradient::for_problem(&p);

    // and loop over intermediate solutions.
    // Note `next` is provided by the StreamingIterator trait using
    // `advance` then `get`.
    while let Some(result) = cg_iter.next() {
        // We want to find x such that a.dot(x) = b
        // then the difference between the two sides (called the residual),
        // is a good measure of the error in a solution.
        let res = result.a.dot(&result.solution) - &result.b;

        // The (squared) length of the residual is a cost, a number
        // summarizing how bad a solution is. When working on iterative
        // methods, we want to see these numbers decrease quickly.
        let res_squared_length = res.dot(&res);

        // || ... ||_2 is notation for euclidean length of what
        // lies between the vertical lines.
        println!(
            "||Ax - b||_2 = {:.5}, for x = {:.4}, residual = {:.7}",
            res_squared_length.sqrt(),
            result.solution,
            res
        );
        // Stop if residual is small enough
        if res_squared_length < 1e-3 {
            break;
        }
    }

Indeed the output shows nice convergence, with the residual \(Ax - b\) tending quickly to zero:

||Ax - b||_2 = 1.00000, for x = [+0.000, +0.000, +0.000], residual = [+0.000, -1.000, +0.000]
||Ax - b||_2 = 0.94281, for x = [+0.000, +0.667, +0.000], residual = [+0.667, +0.000, +0.667]
||Ax - b||_2 = 0.00000, for x = [-4.000, +6.000, -4.000], residual = [+0.000, +0.000, +0.000]

In terms of the code, notice the algorithm is taken out of the loop! We do not modify it merely to report progress, not even to decide when to stop. But we do change that loop body, which gets a bit messy. Once we start looking for such niceties, soon we'll want to:

  • look at only every Nth iteration,
  • measure the runtime of an iteration (excluding the cost of reporting itself),
  • plot progress over time,
  • save progress in case of a power failure...

for basically every method we work on, and we certainly don't want all of those tangled up in our loops. We will want reusable components, named to convey intention! As mentioned above, the idea of representing processes with streaming iterators extends cleanly and orthogonally to such utilities as well. We demonstrate this in the next post.
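Meanwhile, the standard library already shows the flavor of this composition for ordinary iterators. In the sketch below (my analogy, not code from the crate), a simple numeric sequence plays the role of the method, and step_by, inspect, and take play the roles of the utilities listed above, composing around the core without modifying it:

```rust
// Analogy from the standard library: adaptors compose around a core
// iterator without touching it. The "method" here is just a sequence
// of squares; the adaptors are the reusable utilities.
fn main() {
    let results: Vec<u64> = (0u64..)
        .map(|k| k * k)                        // the "recipe": k-th state
        .step_by(3)                            // look at only every 3rd iteration
        .inspect(|s| println!("state: {}", s)) // report progress
        .take(4)                               // decide when to stop
        .collect();
    assert_eq!(results, vec![0, 9, 36, 81]);
}
```

The recipe (the map closure) stays untouched no matter which utilities we stack around it; the next post builds the analogous adaptors for streaming iterators.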

Looking beyond design for code reuse, IM also put a new twist on benchmarking and testing. How does one time or test code that doesn't really want to stop, and for which solutions only approach correctness? We'll get to those questions as well.


Thanks to Daniel Fox (a collaborator on this project) and Yevgenia Vainsencher for feedback on early versions of this post.