Sometimes you need to prevent code from executing recursively. I had to do this recently in order to prevent an infinite recursion in ActiveRecord callbacks. A before-save hook needed to save a parent model attribute, which, because of an :autosave option, would then try to save the child model, which would then try to save the parent… and so on until the stack overflowed.
Unless the method is directly calling itself, the only way to short-circuit a method when it recurses is to have it set some global flag which it can then check in recursive calls. If the flag is already set, the recursively called method bails out early. Once the original invocation is finished it unsets the flag.
It’s not safe to use a true global variable for this flag, since this would not be threadsafe. The safe way to do it (albeit not a terribly pretty one) is to use thread-local variables.
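Written out by hand for a single method, the flag pattern might look like the following sketch. The class Node, its methods, and the key :in_touch_parent are hypothetical names chosen for illustration; they are not from the codebase described above.

```ruby
# Hypothetical example: Node, touch_parent, touch_child and the
# :in_touch_parent key are made-up names for illustration only.
class Node
  attr_reader :calls

  def initialize
    @calls = 0
  end

  def touch_parent
    # Bail out early if this thread is already inside touch_parent.
    return if Thread.current[:in_touch_parent]

    begin
      Thread.current[:in_touch_parent] = true
      @calls += 1
      touch_child # indirectly calls touch_parent again
    ensure
      # Clear the flag even if an exception was raised.
      Thread.current[:in_touch_parent] = nil
    end
  end

  def touch_child
    touch_parent
  end
end

node = Node.new
node.touch_parent
node.calls # => 1
```

The body runs once; the nested call sees the flag and returns immediately. Repeating this boilerplate in every affected method gets tedious quickly.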
I decided to go ahead and write a macro which can turn any method into a non-recursive one. Here’s the code:
module RecursionHelpers
  def prevent_recursion(method_name)
    flag_name = "in:#{name}##{method_name}"
    original = instance_method(method_name)
    define_method(method_name) do |*args|
      return if Thread.current[flag_name]
      begin
        Thread.current[flag_name] = true
        original.bind(self).call(*args)
      ensure
        Thread.current[flag_name] = nil
      end
    end
  end
end
Let’s take a closer look at that.
The method takes an argument, method_name, which is the name of the method to make non-recursive.
def prevent_recursion(method_name)
Then it creates a name for the recursion flag. Since thread-local variables are global for the whole thread, we need a variable name which is reasonably unique. We get one by combining the current class and the method name.
flag_name = "in:#{name}##{method_name}"
We’re going to replace the method with a modified version, so we need to stash the original implementation somewhere. We grab an UnboundMethod object by calling #instance_method() on the current class:
original = instance_method(method_name)
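As a standalone illustration of what #instance_method() returns, and of the binding step the macro performs later, here is a small sketch. Greeter is a hypothetical class used only for this example.

```ruby
# Greeter is a hypothetical class used only for this illustration.
class Greeter
  def initialize(name)
    @name = name
  end

  def greet(greeting)
    "#{greeting}, #{@name}"
  end
end

# An UnboundMethod is a method body detached from any particular object.
unbound = Greeter.instance_method(:greet)
unbound.class # => UnboundMethod

# It cannot be called until it is bound to an instance of the class.
bound = unbound.bind(Greeter.new("Ada"))
bound.class         # => Method
bound.call("Hello") # => "Hello, Ada"
```

Because the UnboundMethod is captured before the method is redefined, it keeps pointing at the original implementation.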
Then we proceed to redefine the method:
define_method(method_name) do |*args|
The first thing the modified method does is check to see if the recursion flag is set. If it is, it bails out.
return if Thread.current[flag_name]
Otherwise, it sets the flag:
Thread.current[flag_name] = true
And then it binds the UnboundMethod stored in original to the current instance. This generates a Method instance, which it proceeds to call, passing any arguments along.
(Note that I wrote this code for a Ruby 1.8 codebase. If it were for 1.9 I would pass the &block down to the original as well. Since blocks in 1.8 can’t accept &block arguments, I can’t do that here so long as I’m using the block form of define_method. So this macro will only work for methods which do not take blocks.)
original.bind(self).call(*args)
Finally, the code ensures that the recursion flag will be unset when the method is finished, even if an exception is raised.
ensure
  Thread.current[flag_name] = nil
end
Note that setting a thread-local variable to nil removes it:
Thread.current.keys # => []
Thread.current[:foo] = 42
Thread.current.keys # => [:foo, :__inspect_key__]
Thread.current[:foo] = nil
Thread.current.keys # => [:__inspect_key__]
I have no idea what :__inspect_key__ is, but as you can see the :foo key goes away after setting it to nil.
Let’s take a look at this in practice. Here are two methods, #foo and #bar (original, no?). #foo is recursive, with #bar as an intermediary, preventing it from passing a recursion flag directly to itself.
class X
  def foo
    bar
    "foo done"
  end
  def bar
    foo
  end
end
Calling the #foo method results in a stack overflow:
X.new.foo rescue $! # => #<SystemStackError: stack level too deep>
But if we update the method with prevent_recursion, it successfully exits:
require './prevent-recursion.rb'
class X
  extend RecursionHelpers
  prevent_recursion :foo
end
X.new.foo # => "foo done"
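For Ruby 1.9 or later, the block-forwarding variant mentioned in the note above might look something like this sketch. It assumes a Ruby version where define_method blocks accept a &block parameter; the names BlockAwareRecursionHelpers and Walker are hypothetical.

```ruby
# Sketch for Ruby 1.9 or later. BlockAwareRecursionHelpers and Walker
# are hypothetical names used only for this illustration.
module BlockAwareRecursionHelpers
  def prevent_recursion(method_name)
    flag_name = "in:#{name}##{method_name}"
    original  = instance_method(method_name)
    # Capturing &block here is what Ruby 1.8 did not allow.
    define_method(method_name) do |*args, &block|
      return if Thread.current[flag_name]
      begin
        Thread.current[flag_name] = true
        # Forward both the arguments and the caller's block.
        original.bind(self).call(*args, &block)
      ensure
        Thread.current[flag_name] = nil
      end
    end
  end
end

class Walker
  extend BlockAwareRecursionHelpers

  def walk(&block)
    yield :visited
    step(&block)
    "walk done"
  end

  def step(&block)
    walk(&block)
  end

  prevent_recursion :walk
end

seen = []
Walker.new.walk { |event| seen << event } # => "walk done"
seen # => [:visited]
```

The block is yielded to exactly once, because the nested call to #walk returns before reaching the original method body.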
I tried to use this code to prevent recursion with a validate method in a model. It didn’t prevent the recursion. Is there anything unique to the validate method that would prevent the prevent from working?