Functional Programming in Java


This is the blog associated with my book "Functional Programming in Java", published by Manning (https://manning.com/books/functional-programming-in-java). Functional Programming is not tied to any "functional" language. It is a way of thinking and a paradigm for writing programs. Some languages are more "functional friendly", and if you intend to write programs using the functional paradigm, you should probably pick one of them... if you can. This blog (and the associated book) is dedicated to those who can't.


Stack safe recursion in Java

In this article, excerpted from my book Functional Programming in Java, I explain how to use recursion while avoiding the risk of a StackOverflowError.

Corecursion is composing computing steps by using the output of one step as the input of the next one starting with the first step. Recursion is the same operation, but starting with the last step. In this case, we have to delay evaluation until we encounter a base condition (corresponding to the first step of corecursion).

Let’s say we have only two instructions in our programming language: incrementation (adding 1 to a value) and decrementation (subtracting 1 from a value). Let’s implement addition by composing these instructions.

Corecursive and recursive addition examples

In order to add two positive numbers, x and y, we can do the following:

  • if y == 0, return x

  • otherwise, increment x, decrement y and start again.

This may be written in Java as:

static int add(int x, int y) {
  while(y > 0) {
    ++x;
    --y;
  }
  return x;
}

or simpler:

static int add(int x, int y) {
  while(y-- > 0) {
    ++x;
  }
  return x;
}

Note that there is no problem with using the parameters x and y directly, since in Java all parameters are passed by value. Also note that we used post-decrement to simplify the code. We could instead have used pre-decrement by slightly changing the condition, thus switching from iterating from y down to 1 to iterating from y - 1 down to 0:

static int add(int x, int y) {
  while(--y >= 0) {
    ++x;
  }
  return x;
}

The recursive version is trickier, but still very simple:

static int addRec(int x, int y) {
  return y == 0
      ? x
      : addRec(++x, --y);
}

Both approaches seem to work but if we try the recursive version with big numbers, we may get a surprise. Although

addRec(10000, 3);

produces the expected result 10003, switching the parameters, like this:

addRec(3, 10000);

produces a StackOverflowError.
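
The asymmetry can be observed with a small sketch like the following. The class name StackOverflowDemo and the helper overflows are hypothetical names introduced only for this illustration. Since the depth at which the stack overflows depends on the JVM and its configuration, the sketch uses a deliberately huge value for y rather than 10000.

```java
class StackOverflowDemo {

  // The recursive addition from the article, unchanged.
  static int addRec(int x, int y) {
    return y == 0
        ? x
        : addRec(++x, --y);
  }

  // Hypothetical helper: reports whether addRec(x, y) exhausts the thread stack.
  static boolean overflows(int x, int y) {
    try {
      addRec(x, y);
      return false;
    } catch (StackOverflowError e) {
      // Stack exhaustion is signaled by java.lang.StackOverflowError.
      return true;
    }
  }

  public static void main(String[] args) {
    System.out.println(addRec(10_000, 3));          // prints 10003
    // y drives the recursion depth, so a very large y is what causes trouble.
    System.out.println(overflows(3, 50_000_000));
  }
}
```

Catching the error here is only a way to observe the behavior. It is not a recommended way to handle deep recursion.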

How is recursion implemented in Java?

To understand what’s happening, we must look at how Java handles method calls. When a method is called, Java suspends what it is currently doing and pushes the environment on the stack to make room for the execution of the called method. When this method returns, Java pops the stack to restore the environment and resume program execution. If we call one method after the other, the stack will hold at most one of these method call environments at any time.

But methods are not only composed by calling them one after the other. Methods call methods. If method1 calls method2 as part of its implementation, Java again suspends the execution of method1, pushes the current environment on the stack, and starts executing method2. When method2 returns, Java pops the last pushed environment from the stack and resumes execution (of method1 in our case). When method1 completes, Java pops again from the stack and resumes what it was doing before calling this method.

Of course, method calls may be deeply nested. Is there a limit to method nesting depth? Yes. The limit is the size of the stack. In practice, the limit is somewhere around a few thousand levels, although it is possible to raise it by configuring the stack size. However, the configured stack size is the default for all threads, so increasing it for a single computation will generally waste space. The default stack size varies between 320k and 1024k depending on the version of Java and the system used. For a 64-bit Java 8 program with minimal stack usage, the maximum number of nested method calls is about 7,000. Generally, we don’t need more, except in very specific cases. One such case is recursive method calls.
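
To get a feeling for these numbers on a given machine, one can measure the depth reached before the stack overflows. The following is a rough sketch: StackDepthProbe and its methods are hypothetical names, and the figures it reports depend on the JVM, on JIT compilation and on the size of each frame. It also uses the Thread constructor that accepts a stack size, which is a way to request a bigger stack for one thread only. The JVM treats that size as a hint and may round it or ignore it.

```java
final class StackDepthProbe {

  private int depth;

  // Recurses until the stack is exhausted, counting frames on the way down.
  private void descend() {
    depth++;
    descend();
  }

  // Returns the approximate number of nested calls reached on the current thread.
  static int measureDepth() {
    StackDepthProbe probe = new StackDepthProbe();
    try {
      probe.descend();
    } catch (StackOverflowError e) {
      // Expected: the stack is full, the counter holds the depth reached.
    }
    return probe.depth;
  }

  // Runs the same measurement on a new thread created with the requested
  // stack size in bytes (a hint that the JVM is free to adjust).
  static int measureDepthWithStackSize(long stackSizeBytes) throws InterruptedException {
    int[] result = new int[1];
    Thread thread = new Thread(null, () -> result[0] = measureDepth(), "probe", stackSizeBytes);
    thread.start();
    thread.join();
    return result[0];
  }

  public static void main(String[] args) throws InterruptedException {
    System.out.println("default stack: " + measureDepth());
    System.out.println("64 MB request: " + measureDepthWithStackSize(64L * 1024 * 1024));
  }
}
```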

Tail Call Elimination (TCE)

Pushing the environment on the stack seems necessary in order to allow resuming computation after the called method returns. But not always. When the call to a method is the last thing to do in the calling method, there is nothing to resume when it returns, so it should be OK to resume directly with the caller of the current method instead of the current method itself. A method call occurring in last position, meaning as the last thing to do before returning, is called a tail call. Avoiding pushing the environment onto the stack to resume method processing after a tail call is an optimization technique known as Tail Call Elimination (TCE). Unfortunately, Java does not implement TCE.

Tail Call Elimination is sometimes called Tail Call Optimization (TCO). TCE is generally an optimization, and we may live without it. However, when it comes to recursive function calls, TCE is no longer an optimization. It is a mandatory feature. That’s why TCE is a better term than TCO when it comes to handling recursion.

Tail recursive methods and functions

Most functional languages implement TCE. However, TCE is not enough to make every recursive call possible. To be a candidate for TCE, the recursive call must be the last thing the method has to do. Think of the following method computing the sum of the elements of a list:

static Integer sum(List<Integer> list) {
  return list.isEmpty()
      ? 0
      : head(list) + sum(tail(list));
}

This method uses the head() and tail() methods. Note that the recursive call to the sum method is not the last thing the method has to do. The last four things the method does are:

  • calling the head method

  • calling the tail method

  • calling the sum method

  • adding the result of head and the result of sum

Even if we had TCE, we would not be able to use this method with lists of 10,000 elements, because the recursive call is not in tail position. However, this method may be rewritten in order to put the call to sum in tail position:

static Integer sum_(List<Integer> list) {
  return sumTail(list, 0);
}

static Integer sumTail(List<Integer> list, int acc) {
  return list.isEmpty()
      ? acc
      : sumTail(tail(list), acc + head(list));
}

Now, the sumTail method is tail recursive and can be optimized through TCE.
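
The head() and tail() methods are not shown in this excerpt. The sketch below assumes one possible implementation of them on java.util.List (they are assumptions, not the book's code) and adds a hypothetical sumLoop method showing the loop that TCE would effectively turn sumTail into: since nothing remains to be done after the recursive call, each call can simply overwrite the parameters and start again.

```java
import java.util.Arrays;
import java.util.List;

class ListSum {

  // Assumed helper: first element of a non-empty list.
  static <T> T head(List<T> list) {
    if (list.isEmpty()) throw new IllegalStateException("head of empty list");
    return list.get(0);
  }

  // Assumed helper: everything but the first element (a view, not a copy).
  static <T> List<T> tail(List<T> list) {
    if (list.isEmpty()) throw new IllegalStateException("tail of empty list");
    return list.subList(1, list.size());
  }

  // Not tail recursive: the addition happens after the recursive call returns.
  static Integer sum(List<Integer> list) {
    return list.isEmpty()
        ? 0
        : head(list) + sum(tail(list));
  }

  // Tail recursive: the recursive call is the last thing done.
  static Integer sumTail(List<Integer> list, int acc) {
    return list.isEmpty()
        ? acc
        : sumTail(tail(list), acc + head(list));
  }

  // What TCE would effectively produce from sumTail: the parameters are
  // reassigned and the body is repeated, with no new stack frame.
  static Integer sumLoop(List<Integer> list) {
    List<Integer> rest = list;
    int acc = 0;
    while (!rest.isEmpty()) {
      acc = acc + head(rest);
      rest = tail(rest);
    }
    return acc;
  }

  public static void main(String[] args) {
    List<Integer> list = Arrays.asList(1, 2, 3, 4, 5);
    System.out.println(sum(list));        // prints 15
    System.out.println(sumTail(list, 0)); // prints 15
    System.out.println(sumLoop(list));    // prints 15
  }
}
```

In plain Java, sumTail still consumes one stack frame per element, so only the loop version is safe for long lists.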

Abstracting recursion

So far, so good, but why bother with all this since Java does not implement TCE? Well, Java does not implement it, but we can do without it. All we need to do is:

  • Represent unevaluated method calls

  • Store them in a stack-like structure until we encounter a terminal condition

  • Evaluate the calls in LIFO order
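
Before building the abstraction, the three steps above can be sketched in the most direct way, with an explicit stack living on the heap. This is not the technique developed in the rest of the article, only an illustration of the idea. The class ExplicitStackSum is a hypothetical name, and the example sums a list because each pending step ("this element plus the sum of the rest") is easy to represent as an unevaluated function.

```java
import java.util.ArrayDeque;
import java.util.Arrays;
import java.util.Deque;
import java.util.List;
import java.util.function.IntUnaryOperator;

class ExplicitStackSum {

  static int sum(List<Integer> list) {
    Deque<IntUnaryOperator> pending = new ArrayDeque<>();

    // Steps 1 and 2: represent each pending "element + <sum of the rest>"
    // as an unevaluated function and push it on a heap-allocated stack.
    for (int element : list) {
      pending.push(rest -> element + rest);
    }

    // Terminal condition reached: the sum of the empty list is 0.
    int result = 0;

    // Step 3: evaluate the stored steps in LIFO order, feeding each result
    // into the step that was waiting for it.
    while (!pending.isEmpty()) {
      result = pending.pop().applyAsInt(result);
    }
    return result;
  }

  public static void main(String[] args) {
    System.out.println(sum(Arrays.asList(1, 2, 3, 4, 5))); // prints 15
  }
}
```

Because the pending steps are stored on the heap rather than on the thread stack, the depth is limited only by available memory.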

Most examples of recursive methods use the Factorial function as an example. The others use the Fibonacci series example. To start our study, we will use the much simpler recursive addition method.

Recursive and corecursive functions are both functions where f(n) is a composition of f(n - 1), f(n - 2), f(n - 3) and so on until a terminal condition is encountered (generally f(0) or f(1)). Remember that in traditional programming, composing generally means composing results of evaluations. This means that composing the functions f and g consists in evaluating g(a), then using the result as the input to f. It does not have to be done that way. You could develop a compose method to compose functions, and a higherCompose function to do the same thing. Neither this method nor this function would evaluate the composed functions. They would only produce another function that could be applied later.
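
As an illustration, here is one way such a compose method and higherCompose function might look when written against java.util.function.Function. The signatures are an assumption made for this sketch, and the class name Composition is hypothetical. The point is that nothing is evaluated when the functions are composed. Evaluation happens only when the resulting function is applied.

```java
import java.util.function.Function;

class Composition {

  // A method composing two functions: the result applies g first, then f.
  static <T, U, V> Function<T, V> compose(Function<U, V> f, Function<T, U> g) {
    return x -> f.apply(g.apply(x));
  }

  // The same operation expressed as a curried function value.
  static <T, U, V> Function<Function<U, V>, Function<Function<T, U>, Function<T, V>>> higherCompose() {
    return f -> g -> x -> f.apply(g.apply(x));
  }

  public static void main(String[] args) {
    Function<Integer, Integer> triple = x -> x * 3;
    Function<Integer, Integer> square = x -> x * x;

    // Composing evaluates nothing. Only apply() triggers the computation.
    Function<Integer, Integer> tripleOfSquare = compose(triple, square);
    System.out.println(tripleOfSquare.apply(2)); // prints 12

    Function<Integer, Integer> squareOfTriple =
        Composition.<Integer, Integer, Integer>higherCompose().apply(square).apply(triple);
    System.out.println(squareOfTriple.apply(2)); // prints 36
  }
}
```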

Recursion and corecursion are similar, although with a difference. We create a list of function calls instead of a list of functions. With corecursion, each step is terminal, so it may be evaluated in order to get the result and use it as input for the next step. With recursion, we start from the other end. So we have to put non evaluated calls in the list until we find a terminal condition, from which we can process the list in reverse order. In other words, we stack the steps (without evaluating them) until the last one is found, and then we process the stack in reverse order (last in first out), evaluating each step and using the result as the input for the next (in fact the previous) one.

The problem we have is that Java uses the thread stack for this, and its capacity is very limited. Typically, the stack will overflow after between 6,000 and 7,000 steps.

What we have to do is to create a function or a method returning a non evaluated step. To represent a step in the calculation, we will use an abstract class called TailCall (since we want to represent a call to a method that appears in tail position).

This TailCall abstract class will have two subclasses. The first will represent an intermediate call, when the processing of one step is suspended to call the method again for evaluating the next step. This will be represented by a class named Suspend. It will be instantiated with a Supplier<TailCall<T>>, which will represent the next recursive call. This way, instead of putting all TailCall instances in a list, we will construct a linked list by linking each TailCall with the next. The benefit of this approach is that such a linked list is in fact a stack, offering constant time insertion as well as constant time access to the last inserted element, which is optimal for a LIFO structure.

The second implementation will represent the last call, which is supposed to return the result. So we will call it Return. It will not hold a link to the next TailCall, since there is nothing next, but it will hold the result. Here is what we get:

import java.util.function.Supplier;

public abstract class TailCall<T> {

  public static class Return<T> extends TailCall<T> {

    private final T t;

    public Return(T t) {
      this.t = t;
    }
  }

  public static class Suspend<T> extends TailCall<T> {

    private final Supplier<TailCall<T>> resume;

    public Suspend(Supplier<TailCall<T>> resume) {
      this.resume = resume;
    }
  }
}

To handle these classes, we will need some methods: one to return the result, one to return the next call, and one helper method to determine if a TailCall is a Suspend or a Return. We could avoid this last method, but we would have to use instanceof to do the job, which is ugly. The three methods will be:

public abstract TailCall<T> resume();

public abstract T eval();

public abstract boolean isSuspend();

The resume method will have no meaningful implementation in Return and will just throw a RuntimeException. The user of our API should not be in a situation to call this method, so if it is ever called, it will be a bug and we will stop the application. In the Suspend class, it will return the next TailCall.

The eval method will return the result stored in the Return class. In our first version, it will throw a RuntimeException if called on the Suspend class.

The isSuspend method will simply return true in Suspend and false in Return. Listing 1 shows this first version.

Listing 1 The TailCall abstract class and its two sub classes
import java.util.function.Supplier;

public abstract class TailCall<T> {

  public abstract TailCall<T> resume();

  public abstract T eval();

  public abstract boolean isSuspend();

  public static class Return<T> extends TailCall<T> {

    private final T t;

    public Return(T t) {
      this.t = t;
    }

    @Override
    public T eval() {
      return t;
    }

    @Override
    public boolean isSuspend() {
      return false;
    }

    @Override
    public TailCall<T> resume() {
      throw new IllegalStateException("Return has no resume");
    }
  }

  public static class Suspend<T> extends TailCall<T> {

    private final Supplier<TailCall<T>> resume;

    public Suspend(Supplier<TailCall<T>> resume) {
      this.resume = resume;
    }

    @Override
    public T eval() {
      throw new IllegalStateException("Suspend has no value");
    }

    @Override
    public boolean isSuspend() {
      return true;
    }

    @Override
    public TailCall<T> resume() {
      return resume.get();
    }
  }
}

Now, to make our recursive add method work with any number of steps (within the limits of available memory!), we have only a few changes to make. Starting with our original method:

static int add(int x, int y) {
  return y == 0
      ? x
      : add(++x, --y);
}

We just have to make the modifications shown in listing 2.

Listing 2 The modified recursive method
static TailCall<Integer> add(int x, int y) {               // (1)
  return y == 0
      ? new TailCall.Return<>(x)                            // (2)
      : new TailCall.Suspend<>(() -> add(x + 1, y - 1));    // (3)
}
(1) Method now returns a TailCall
(2) In terminal condition, a Return is returned
(3) In non terminal condition, a Suspend is returned

Our method now returns a TailCall<Integer> instead of an int (1). This return value may be a Return<Integer> (2) if we have reached the terminal condition, or a Suspend<Integer> (3) if we have not yet. The Return is instantiated with the result of the computation (which is x, since y is 0) and the Suspend is instantiated with a Supplier<TailCall<Integer>>, which is the next step of the computation in terms of execution sequence, or the previous one in terms of calling sequence. It is very important to understand that the Return corresponds to the last step in terms of method calls, but to the first step in terms of evaluation. Also note that we have slightly changed the evaluation, replacing ++x and --y by x + 1 and y - 1. This was necessary because we are using a closure, which works only if the closed-over variables are effectively final. This is cheating, but not so much. We could have created and called two methods dec and inc using the original operators.
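
For the curious, here is a small sketch of what such inc and dec methods could look like, and of why they would satisfy the compiler. The class ClosureDemo and the method next are hypothetical names used only for this illustration.

```java
import java.util.function.Supplier;

class ClosureDemo {

  // Hypothetical helpers wrapping the original operators. Mutating the
  // parameter is harmless because it is a local copy of the argument.
  static int inc(int n) {
    return ++n;
  }

  static int dec(int n) {
    return --n;
  }

  // x is never reassigned in this method, so it is effectively final and
  // may be captured by the lambda.
  static Supplier<Integer> next(int x) {
    return () -> inc(x);
    // return () -> ++x;  // would not compile: x would no longer be effectively final
  }

  public static void main(String[] args) {
    System.out.println(next(41).get()); // prints 42
  }
}
```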

What this method returns is a chain of TailCall instances, all being Suspend instances except the last one, which is a Return.

So far so good, but obviously, this method is not a drop in replacement for the original one. Not a big deal! The original method was used as:

System.out.println(add(x, y));

We can use our new method like this:

TailCall<Integer> tailCall = add(3, 100000000);
while(tailCall.isSuspend()) {
  tailCall = tailCall.resume();
}
System.out.println(tailCall.eval());

Doesn’t it look nice? Well, if you feel somewhat frustrated, I can understand. You thought we would just use a new method in place of the old one, in a transparent manner. We seem to be far from this. But we can make things much better with little effort.

Drop-in replacement for stack-based recursive methods

In the beginning of the previous section, we said that the user of our recursive API would have no opportunity to mess with the TailCall instances by calling resume on a Return or eval on a Suspend. This is easy to achieve by putting the evaluation code in the eval method of the Suspend class:

public static class Suspend<T> extends TailCall<T> {

  ...

  @Override
  public T eval() {
    TailCall<T> tailRec = this;
    while(tailRec.isSuspend()) {
      tailRec = tailRec.resume();
    }
    return tailRec.eval();
  }
}

Now we may get the result of the recursive call in a much simpler and safer way:

add(3, 100000000).eval()

But this is not yet what we want. We want to get rid of this call to the eval method. This can be done with a helper method:

import static com.fpinjava.excerpt.TailCall.ret;
import static com.fpinjava.excerpt.TailCall.sus;

. . .

public static int add(int x, int y) {
  return addRec(x, y).eval();
}

private static TailCall<Integer> addRec(int x, int y) {
  return y == 0
      ? ret(x)
      : sus(() -> addRec(x + 1, y - 1));
}

Now we can call the add method exactly as the original one. Note that we have made our recursive API easier to use by providing static factory methods to instantiate Return and Suspend:

public static <T> Return<T> ret(T t) {
  return new Return<>(t);
}

public static <T> Suspend<T> sus(Supplier<TailCall<T>> s) {
  return new Suspend<>(s);
}

Listing 3 shows the complete TailCall class. We have added a private no arg constructor in order to prevent extension by other classes.

Listing 3 The complete TailCall class
package com.fpinjava.excerpt;

import java.util.function.Supplier;

public abstract class TailCall<T> {

  public abstract TailCall<T> resume();

  public abstract T eval();

  public abstract boolean isSuspend();

  private TailCall() {}

  public static class Return<T> extends TailCall<T> {

    private final T t;

    private Return(T t) {
      this.t = t;
    }

    @Override
    public T eval() {
      return t;
    }

    @Override
    public boolean isSuspend() {
      return false;
    }

    @Override
    public TailCall<T> resume() {
      throw new IllegalStateException("Return has no resume");
    }
  }

  public static class Suspend<T> extends TailCall<T> {

    private final Supplier<TailCall<T>> resume;

    private Suspend(Supplier<TailCall<T>> resume) {
      this.resume = resume;
    }

    @Override
    public T eval() {
      TailCall<T> tailRec = this;
      while(tailRec.isSuspend()) {
        tailRec = tailRec.resume();
      }
      return tailRec.eval();
    }

    @Override
    public boolean isSuspend() {
      return true;
    }

    @Override
    public TailCall<T> resume() {
      return resume.get();
    }
  }

  public static <T> Return<T> ret(T t) {
    return new Return<>(t);
  }

  public static <T> Suspend<T> sus(Supplier<TailCall<T>> s) {
    return new Suspend<>(s);
  }
}
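
The same recipe applies to other tail recursive methods. As a sketch, here is how the list sum seen earlier might be made stack-safe. To keep the example compilable on its own, it embeds a condensed stand-in for TailCall based on anonymous classes. This stand-in is an assumption of the sketch, not the class of Listing 3, although it is meant to behave the same way. The class name StackSafeSum is hypothetical, and an index parameter replaces the call to tail(list).

```java
import java.util.Collections;
import java.util.List;
import java.util.function.Supplier;

class StackSafeSum {

  // Condensed stand-in for the TailCall class, inlined only so that
  // this sketch is self-contained.
  static abstract class TailCall<T> {
    abstract TailCall<T> resume();
    abstract T eval();
    abstract boolean isSuspend();

    static <T> TailCall<T> ret(T t) {
      return new TailCall<T>() {
        TailCall<T> resume() { throw new IllegalStateException("Return has no resume"); }
        T eval() { return t; }
        boolean isSuspend() { return false; }
      };
    }

    static <T> TailCall<T> sus(Supplier<TailCall<T>> s) {
      return new TailCall<T>() {
        TailCall<T> resume() { return s.get(); }
        T eval() {
          // Same loop as in Listing 3: unwind the chain without nesting calls.
          TailCall<T> tailRec = this;
          while (tailRec.isSuspend()) {
            tailRec = tailRec.resume();
          }
          return tailRec.eval();
        }
        boolean isSuspend() { return true; }
      };
    }
  }

  // Public entry point: callers never see a TailCall.
  static int sum(List<Integer> list) {
    return sumTail(list, 0, 0).eval();
  }

  // Tail recursive helper. Each call returns immediately with either the
  // result or a description of the next step.
  private static TailCall<Integer> sumTail(List<Integer> list, int index, int acc) {
    return index == list.size()
        ? TailCall.ret(acc)
        : TailCall.sus(() -> sumTail(list, index + 1, acc + list.get(index)));
  }

  public static void main(String[] args) {
    // One million elements: far beyond what the thread stack would allow.
    System.out.println(sum(Collections.nCopies(1_000_000, 1))); // prints 1000000
  }
}
```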

Now that you have a stack-safe tail recursive method, can you do the same thing with a function? In my book, Functional Programming in Java, I talk about how to do that.
