Preventing Recursion in Ruby

Sometimes you need to prevent code from executing recursively. I had to do this recently to stop an infinite recursion in ActiveRecord callbacks. A before-save hook needed to save an attribute on the parent model, which, because of an :autosave option, would then try to save the child model, which would then try to save the parent… and so on until the stack overflowed.

Unless the method calls itself directly (in which case it could simply pass a flag along as an argument), the only practical way to short-circuit it when it recurses is to have it set a global flag, which recursive calls can then check. If the flag is already set, the recursively called method bails out early. Once the original invocation finishes, it unsets the flag.

It’s not safe to use a true global variable for this flag, since it would not be thread-safe: another thread running the same method at the same time would see the flag and wrongly bail out. The safe way to do it (albeit not a terribly pretty one) is to use thread-local variables.
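A quick sketch of why thread-locals solve the problem: a value stored in Thread.current in one thread is simply absent in another. The keys :in_demo and :in_demo_tv are hypothetical. One caveat, assuming Ruby 1.9 or later: Thread#[] is actually fiber-local, so code re-entering through a different Fiber won't see the flag, whereas Thread#thread_variable_get/thread_variable_set (Ruby 2.0+) are shared by every fiber in the thread.

```ruby
# Set a flag in the current thread (hypothetical key :in_demo).
Thread.current[:in_demo] = true

# A new thread starts with its own, empty set of locals.
seen_in_thread = Thread.new { Thread.current[:in_demo] }.value

# On Ruby 1.9+, Thread#[] is fiber-local, so a new Fiber doesn't see it either.
seen_in_fiber = Fiber.new { Thread.current[:in_demo] }.resume

# Thread variables (Ruby 2.0+) are shared across all fibers of a thread.
Thread.current.thread_variable_set(:in_demo_tv, true)
seen_tv_in_fiber = Fiber.new { Thread.current.thread_variable_get(:in_demo_tv) }.resume

p Thread.current[:in_demo]  # prints true
p seen_in_thread            # prints nil
p seen_in_fiber             # prints nil
p seen_tv_in_fiber          # prints true

# Clean up so the flags don't linger in this thread.
Thread.current[:in_demo] = nil
Thread.current.thread_variable_set(:in_demo_tv, nil)
```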

I decided to go ahead and write a macro which can turn any method into a non-recursive one. Here’s the code:

module RecursionHelpers
  def prevent_recursion(method_name)
    flag_name = "in:#{name}##{method_name}"
    original = instance_method(method_name)

    define_method(method_name) do |*args|
      return if Thread.current[flag_name]
      begin
        Thread.current[flag_name] = true
        original.bind(self).call(*args)
      ensure
        Thread.current[flag_name] = nil
      end
    end
  end
end

Let’s take a closer look at that.

The method takes an argument, method_name, which is the name of the method to make non-recursive.

def prevent_recursion(method_name)

Then it creates a name for the recursion flag. Since thread-local variables are global for the whole thread, we need a variable name which is reasonably unique. We get one by combining the current class and the method name.

flag_name = "in:#{name}##{method_name}"
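To make the interpolation concrete: "##{method_name}" is a literal # followed by the interpolated method name, and inside a class-level method, name is the class's name. A small check using a hypothetical Invoice class and a hypothetical flag_for helper that builds the string the same way. It also shows that Thread#[] accepts either a String or a Symbol for the same slot.

```ruby
class Invoice
  # Hypothetical helper: builds the flag name exactly as prevent_recursion does.
  def self.flag_for(method_name)
    "in:#{name}##{method_name}"
  end
end

flag = Invoice.flag_for(:total)
puts flag                              # prints in:Invoice#total

# String and Symbol keys refer to the same thread-local slot.
Thread.current[flag] = true
p Thread.current[flag.to_sym]          # prints true
Thread.current[flag] = nil             # clean up
```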

We’re going to replace the method with a modified version, so we need to stash the original implementation somewhere. We grab an UnboundMethod object by calling #instance_method() on the current class:

original = instance_method(method_name)
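The point of stashing the UnboundMethod is that it keeps referring to the old method body even after the method is redefined. A minimal sketch with a hypothetical Greeter class:

```ruby
class Greeter
  def greet(name)
    "Hello, #{name}"
  end
end

# Stash the original implementation as an UnboundMethod.
original = Greeter.instance_method(:greet)

# Redefine #greet; the stashed UnboundMethod still holds the old body,
# so the new version can bind it to the receiver and call it.
Greeter.send(:define_method, :greet) do |name|
  original.bind(self).call(name).upcase
end

puts original.class             # prints UnboundMethod
puts Greeter.new.greet("Ruby")  # prints HELLO, RUBY
```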

Then we proceed to redefine the method:

define_method(method_name) do |*args|

The first thing the modified method does is check to see if the recursion flag is set. If it is, it bails out.

return if Thread.current[flag_name]

Otherwise, it sets the flag:

Thread.current[flag_name] = true

And then it binds the UnboundMethod stored in original to the current instance. This generates a Method instance, which it proceeds to call, passing any arguments along.

(Note that I wrote this code for a Ruby 1.8 codebase. If it were for 1.9, I would pass the &block down to the original as well. Since blocks in 1.8 can’t accept &block arguments, I can’t do that here as long as I’m using the block form of define_method. So this macro only works for methods that don’t take blocks: a block passed to the wrapped method is not passed through to the original.)

original.bind(self).call(*args)
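For Ruby 1.9 and later, a minimal sketch of the block-forwarding variant described in the note above might look like this. It is not the original 1.8 code; the module name RecursionHelpers19 and the Counter class are hypothetical. It forwards positional arguments and a block only; on Ruby 3, a method taking keyword arguments would also need **kwargs forwarded.

```ruby
module RecursionHelpers19
  def prevent_recursion(method_name)
    flag_name = "in:#{name}##{method_name}"
    original = instance_method(method_name)

    # |*args, &block| is legal in Ruby 1.9+, so the block can be forwarded.
    define_method(method_name) do |*args, &block|
      return if Thread.current[flag_name]
      begin
        Thread.current[flag_name] = true
        original.bind(self).call(*args, &block)
      ensure
        Thread.current[flag_name] = nil
      end
    end
  end
end

# Hypothetical class: #visit calls its block, then re-enters itself.
class Counter
  extend RecursionHelpers19

  def visit(n, &block)
    block.call(n)            # fails with NoMethodError if the block were dropped
    visit(n + 1, &block)     # re-entry is short-circuited and returns nil
    :done
  end

  prevent_recursion :visit
end

seen = []
p Counter.new.visit(1) { |n| seen << n }  # prints :done
p seen                                    # prints [1]
```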

Finally, the code ensures that the recursion flag will be unset when the method is finished, even if an exception is raised.

ensure
  Thread.current[flag_name] = nil
end

Note that setting a thread-local variable to nil removes it:

Thread.current.keys               # => []
Thread.current[:foo] = 42
Thread.current.keys               # => [:foo, :__inspect_key__]
Thread.current[:foo] = nil
Thread.current.keys               # => [:__inspect_key__]

I have no idea what :__inspect_key__ is, but as you can see the :foo key goes away after setting it to nil.
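The ensure clause is what keeps an exception from leaving the flag stuck; without it, every later call on that thread would silently return nil. A sketch that checks this, repeating the module so the snippet runs on its own; the Flaky class, which fails only on its first call, is hypothetical.

```ruby
module RecursionHelpers
  def prevent_recursion(method_name)
    flag_name = "in:#{name}##{method_name}"
    original = instance_method(method_name)

    define_method(method_name) do |*args|
      return if Thread.current[flag_name]
      begin
        Thread.current[flag_name] = true
        original.bind(self).call(*args)
      ensure
        Thread.current[flag_name] = nil   # runs even when the call raises
      end
    end
  end
end

# Hypothetical class whose method raises the first time it is called.
class Flaky
  extend RecursionHelpers

  def initialize
    @calls = 0
  end

  def run
    @calls += 1
    raise "boom" if @calls == 1
    "ok on call #{@calls}"
  end

  prevent_recursion :run
end

flaky = Flaky.new
first = begin
  flaky.run
rescue RuntimeError => e
  e.message
end
second = flaky.run   # not short-circuited: the flag was cleared by ensure

p first                            # prints "boom"
p second                           # prints "ok on call 2"
p Thread.current[:"in:Flaky#run"]  # prints nil
```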

Let’s take a look at this in practice. Here are two methods, #foo and #bar (original, no?). #foo is indirectly recursive: it calls #bar, which calls #foo again, so #foo can’t just pass a recursion flag to itself as an argument.

class X
  def foo
    bar
    "foo done"
  end

  def bar
    foo
  end
end

Calling the #foo method results in a stack overflow:

X.new.foo rescue $! # => #<SystemStackError: stack level too deep>

But once we wrap #foo with prevent_recursion, the call returns normally:

require './prevent-recursion.rb'

class X
  extend RecursionHelpers

  prevent_recursion :foo
end

X.new.foo # => "foo done"
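One detail worth keeping in mind when applying this to real code: the short-circuited inner call returns nil, so whatever sits in the middle of the cycle sees nil rather than a real result. A sketch of that, repeating the module so it runs standalone; the class Y and its inner attribute are hypothetical.

```ruby
module RecursionHelpers
  def prevent_recursion(method_name)
    flag_name = "in:#{name}##{method_name}"
    original = instance_method(method_name)

    define_method(method_name) do |*args|
      return if Thread.current[flag_name]   # short-circuited calls return nil
      begin
        Thread.current[flag_name] = true
        original.bind(self).call(*args)
      ensure
        Thread.current[flag_name] = nil
      end
    end
  end
end

# Hypothetical variant of X that records what #bar got back.
class Y
  extend RecursionHelpers
  attr_reader :inner

  def foo
    @inner = bar   # bar re-enters foo, which is short-circuited
    "foo done"
  end

  def bar
    foo
  end

  prevent_recursion :foo
end

y = Y.new
p y.foo    # prints "foo done"
p y.inner  # prints nil
```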
