Skip to content

Commit

Permalink
Retry known idempotent SELECT queries on connection-related exceptions
Browse files Browse the repository at this point in the history
This commit makes two types of queries retry-able by opting into our `allow_retry` flag:
1) SELECT queries we construct by walking the Arel tree via `#to_sql_and_binds`. We use a
new `retryable` attribute on collector classes, which defaults to true for most node types,
but will be set to false for non-idempotent node types (functions, SQL literals, etc). The
`retryable` value is returned from  `#to_sql_and_binds` and used by `#select_all` and
passed down the call stack, eventually reaching the adapter's `#internal_exec_query` method.

2) `#find` and `#find_by` queries with known attributes. We set `allow_retry: true` in `#cached_find_by`,
and pass this down to `#find_by_sql` and `#_query_by_sql`.

These changes ensure that queries we know are safe to retry can be retried automatically.
  • Loading branch information
adrianna-chang-shopify committed Mar 19, 2024
1 parent a534bac commit 7d9c15d
Show file tree
Hide file tree
Showing 17 changed files with 166 additions and 35 deletions.
9 changes: 9 additions & 0 deletions activerecord/CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,12 @@
* Retry known idempotent SELECT queries on connection-related exceptions

SELECT queries we construct by walking the Arel tree and / or with known model attributes
are idempotent and can safely be retried in the case of a connection error. Previously,
adapters such as `TrilogyAdapter` would raise `ActiveRecord::ConnectionFailed: Trilogy::EOFError`
when encountering a connection error mid-request.

*Adrianna Chang*

* Add dirties option to uncached

This adds a `dirties` option to `ActiveRecord::Base.uncached` and
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ def to_sql(arel_or_sql_string, binds = [])
sql
end

def to_sql_and_binds(arel_or_sql_string, binds = [], preparable = nil) # :nodoc:
def to_sql_and_binds(arel_or_sql_string, binds = [], preparable = nil, allow_retry = false) # :nodoc:
# Arel::TreeManager -> Arel::Node
if arel_or_sql_string.respond_to?(:ast)
arel_or_sql_string = arel_or_sql_string.ast
Expand All @@ -27,6 +27,7 @@ def to_sql_and_binds(arel_or_sql_string, binds = [], preparable = nil) # :nodoc:
end

collector = collector()
collector.retryable = true

if prepared_statements
collector.preparable = true
Expand All @@ -41,10 +42,11 @@ def to_sql_and_binds(arel_or_sql_string, binds = [], preparable = nil) # :nodoc:
else
sql = visitor.compile(arel_or_sql_string, collector)
end
[sql.freeze, binds, preparable]
allow_retry = collector.retryable
[sql.freeze, binds, preparable, allow_retry]
else
arel_or_sql_string = arel_or_sql_string.dup.freeze unless arel_or_sql_string.frozen?
[arel_or_sql_string, binds, preparable]
[arel_or_sql_string, binds, preparable, allow_retry]
end
end
private :to_sql_and_binds
Expand All @@ -64,11 +66,15 @@ def cacheable_query(klass, arel) # :nodoc:
end

# Returns an ActiveRecord::Result instance.
def select_all(arel, name = nil, binds = [], preparable: nil, async: false)
def select_all(arel, name = nil, binds = [], preparable: nil, async: false, allow_retry: false)
arel = arel_from_relation(arel)
sql, binds, preparable = to_sql_and_binds(arel, binds, preparable)
sql, binds, preparable, allow_retry = to_sql_and_binds(arel, binds, preparable, allow_retry)

select(sql, name, binds, prepare: prepared_statements && preparable, async: async && FutureResult::SelectAll)
select(sql, name, binds,
prepare: prepared_statements && preparable,
async: async && FutureResult::SelectAll,
allow_retry: allow_retry
)
rescue ::RangeError
ActiveRecord::Result.empty(async: async)
end
Expand Down Expand Up @@ -507,7 +513,7 @@ def high_precision_current_timestamp
HIGH_PRECISION_CURRENT_TIMESTAMP
end

def internal_exec_query(sql, name = "SQL", binds = [], prepare: false, async: false) # :nodoc:
def internal_exec_query(sql, name = "SQL", binds = [], prepare: false, async: false, allow_retry: false) # :nodoc:
raise NotImplementedError
end

Expand Down Expand Up @@ -606,7 +612,7 @@ def combine_multi_statements(total_sql)
end

# Returns an ActiveRecord::Result instance.
def select(sql, name = nil, binds = [], prepare: false, async: false)
def select(sql, name = nil, binds = [], prepare: false, async: false, allow_retry: false)
if async && async_enabled?
if current_transaction.joinable?
raise AsynchronousQueryInsideTransactionError, "Asynchronous queries are not allowed inside transactions"
Expand All @@ -627,7 +633,7 @@ def select(sql, name = nil, binds = [], prepare: false, async: false)
return future_result
end

result = internal_exec_query(sql, name, binds, prepare: prepare)
result = internal_exec_query(sql, name, binds, prepare: prepare, allow_retry: allow_retry)
if async
FutureResult.wrap(result)
else
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -204,19 +204,19 @@ def clear_query_cache
pool.clear_query_cache
end

def select_all(arel, name = nil, binds = [], preparable: nil, async: false) # :nodoc:
def select_all(arel, name = nil, binds = [], preparable: nil, async: false, allow_retry: false) # :nodoc:
arel = arel_from_relation(arel)

# If arel is locked this is a SELECT ... FOR UPDATE or somesuch.
# Such queries should not be cached.
if @query_cache&.enabled? && !(arel.respond_to?(:locked) && arel.locked)
sql, binds, preparable = to_sql_and_binds(arel, binds, preparable)
sql, binds, preparable, allow_retry = to_sql_and_binds(arel, binds, preparable)

if async
result = lookup_sql_cache(sql, name, binds) || super(sql, name, binds, preparable: preparable, async: async)
result = lookup_sql_cache(sql, name, binds) || super(sql, name, binds, preparable: preparable, async: async, allow_retry: allow_retry)
FutureResult.wrap(result)
else
cache_sql(sql, name, binds) { super(sql, name, binds, preparable: preparable, async: async) }
cache_sql(sql, name, binds) { super(sql, name, binds, preparable: preparable, async: async, allow_retry: allow_retry) }
end
else
super
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -229,12 +229,12 @@ def disable_referential_integrity # :nodoc:
# Mysql2Adapter doesn't have to free a result after using it, but we use this method
# to write stuff in an abstract way without concerning ourselves about whether it
# needs to be explicitly freed or not.
def execute_and_free(sql, name = nil, async: false) # :nodoc:
def execute_and_free(sql, name = nil, async: false, allow_retry: false) # :nodoc:
sql = transform_query(sql)
check_if_write_query(sql)

mark_transaction_written_if_write(sql)
yield raw_execute(sql, name, async: async)
yield raw_execute(sql, name, async: async, allow_retry: allow_retry)
end

def begin_db_transaction # :nodoc:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,9 @@ def select_all(*, **) # :nodoc:
result
end

def internal_exec_query(sql, name = "SQL", binds = [], prepare: false, async: false) # :nodoc:
def internal_exec_query(sql, name = "SQL", binds = [], prepare: false, async: false, allow_retry: false) # :nodoc:
if without_prepared_statement?(binds)
execute_and_free(sql, name, async: async) do |result|
execute_and_free(sql, name, async: async, allow_retry: allow_retry) do |result|
if result
build_result(columns: result.fields, rows: result.to_a)
else
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -881,7 +881,7 @@ def exec_no_cache(sql, name, binds, async:, allow_retry:, materialize_transactio

type_casted_binds = type_casted_binds(binds)
log(sql, name, binds, type_casted_binds, async: async) do |notification_payload|
with_raw_connection(allow_retry: false, materialize_transactions: materialize_transactions) do |conn|
with_raw_connection(allow_retry: allow_retry, materialize_transactions: materialize_transactions) do |conn|
result = conn.exec_params(sql, type_casted_binds)
verified!
notification_payload[:row_count] = result.count
Expand All @@ -895,7 +895,7 @@ def exec_cache(sql, name, binds, async:, allow_retry:, materialize_transactions:

update_typemap_for_default_timezone

with_raw_connection(allow_retry: false, materialize_transactions: materialize_transactions) do |conn|
with_raw_connection(allow_retry: allow_retry, materialize_transactions: materialize_transactions) do |conn|
stmt_key = prepare_statement(sql, binds, conn)
type_casted_binds = type_casted_binds(binds)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ def explain(arel, binds = [], _options = [])
SQLite3::ExplainPrettyPrinter.new.pp(result)
end

def internal_exec_query(sql, name = nil, binds = [], prepare: false, async: false) # :nodoc:
def internal_exec_query(sql, name = nil, binds = [], prepare: false, async: false, allow_retry: false) # :nodoc:
sql = transform_query(sql)
check_if_write_query(sql)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,12 @@ def select_all(*, **) # :nodoc:
result
end

def internal_exec_query(sql, name = "SQL", binds = [], prepare: false, async: false) # :nodoc:
def internal_exec_query(sql, name = "SQL", binds = [], prepare: false, async: false, allow_retry: false) # :nodoc:
sql = transform_query(sql)
check_if_write_query(sql)
mark_transaction_written_if_write(sql)

result = raw_execute(sql, name, async: async)
result = raw_execute(sql, name, async: async, allow_retry: allow_retry)
ActiveRecord::Result.new(result.fields, result.to_a)
end

Expand Down
2 changes: 1 addition & 1 deletion activerecord/lib/active_record/core.rb
Original file line number Diff line number Diff line change
Expand Up @@ -431,7 +431,7 @@ def cached_find_by(keys, values)
}

begin
statement.execute(values.flatten, lease_connection).first
statement.execute(values.flatten, lease_connection, allow_retry: true).first
rescue TypeError
raise ActiveRecord::StatementInvalid
end
Expand Down
8 changes: 4 additions & 4 deletions activerecord/lib/active_record/querying.rb
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,8 @@ module Querying
#
# Note that building your own SQL query string from user input may expose your application to
# injection attacks (https://guides.rubyonrails.org/security.html#sql-injection).
def find_by_sql(sql, binds = [], preparable: nil, &block)
_load_from_sql(_query_by_sql(sql, binds, preparable: preparable), &block)
def find_by_sql(sql, binds = [], preparable: nil, allow_retry: false, &block)
_load_from_sql(_query_by_sql(sql, binds, preparable: preparable, allow_retry: allow_retry), &block)
end

# Same as <tt>#find_by_sql</tt> but perform the query asynchronously and returns an ActiveRecord::Promise.
Expand All @@ -58,8 +58,8 @@ def async_find_by_sql(sql, binds = [], preparable: nil, &block)
end
end

def _query_by_sql(sql, binds = [], preparable: nil, async: false) # :nodoc:
lease_connection.select_all(sanitize_sql(sql), "#{name} Load", binds, preparable: preparable, async: async)
def _query_by_sql(sql, binds = [], preparable: nil, async: false, allow_retry: false) # :nodoc:
lease_connection.select_all(sanitize_sql(sql), "#{name} Load", binds, preparable: preparable, async: async, allow_retry: allow_retry)
end

def _load_from_sql(result_set, &block) # :nodoc:
Expand Down
6 changes: 3 additions & 3 deletions activerecord/lib/active_record/statement_cache.rb
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ def sql_for(binds, connection)
end

class PartialQueryCollector
attr_accessor :preparable
attr_accessor :preparable, :retryable

def initialize
@parts = []
Expand Down Expand Up @@ -142,12 +142,12 @@ def initialize(query_builder, bind_map, klass)
@klass = klass
end

def execute(params, connection, &block)
def execute(params, connection, allow_retry: false, &block)
bind_values = bind_map.bind params

sql = query_builder.sql_for bind_values, connection

klass.find_by_sql(sql, bind_values, preparable: true, &block)
klass.find_by_sql(sql, bind_values, preparable: true, allow_retry: allow_retry, &block)
rescue ::RangeError
[]
end
Expand Down
2 changes: 1 addition & 1 deletion activerecord/lib/arel/collectors/composite.rb
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
module Arel # :nodoc: all
module Collectors
class Composite
attr_accessor :preparable
attr_accessor :preparable, :retryable

def initialize(left, right)
@left = left
Expand Down
2 changes: 1 addition & 1 deletion activerecord/lib/arel/collectors/sql_string.rb
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
module Arel # :nodoc: all
module Collectors
class SQLString < PlainString
attr_accessor :preparable
attr_accessor :preparable, :retryable

def initialize(*)
super
Expand Down
2 changes: 1 addition & 1 deletion activerecord/lib/arel/collectors/substitute_binds.rb
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
module Arel # :nodoc: all
module Collectors
class SubstituteBinds
attr_accessor :preparable
attr_accessor :preparable, :retryable

def initialize(quoter, delegate_collector)
@quoter = quoter
Expand Down
6 changes: 6 additions & 0 deletions activerecord/lib/arel/visitors/to_sql.rb
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ def compile(node, collector = Arel::Collectors::SQLString.new)

private
def visit_Arel_Nodes_DeleteStatement(o, collector)
collector.retryable = false
o = prepare_delete_statement(o)

if has_join_sources?(o)
Expand All @@ -37,6 +38,7 @@ def visit_Arel_Nodes_DeleteStatement(o, collector)
end

def visit_Arel_Nodes_UpdateStatement(o, collector)
collector.retryable = false
o = prepare_update_statement(o)

collector << "UPDATE "
Expand All @@ -49,6 +51,7 @@ def visit_Arel_Nodes_UpdateStatement(o, collector)
end

def visit_Arel_Nodes_InsertStatement(o, collector)
collector.retryable = false
collector << "INSERT INTO "
collector = visit o.relation, collector

Expand Down Expand Up @@ -381,6 +384,7 @@ def visit_Arel_Nodes_Group(o, collector)
end

def visit_Arel_Nodes_NamedFunction(o, collector)
collector.retryable = false
collector << o.name
collector << "("
collector << "DISTINCT " if o.distinct
Expand Down Expand Up @@ -768,10 +772,12 @@ def visit_Arel_Nodes_BindParam(o, collector)

def visit_Arel_Nodes_SqlLiteral(o, collector)
collector.preparable = false
collector.retryable = false
collector << o.to_s
end

def visit_Arel_Nodes_BoundSqlLiteral(o, collector)
collector.retryable = false
bind_index = 0

new_bind = lambda do |value|
Expand Down
66 changes: 64 additions & 2 deletions activerecord/test/cases/adapter_test.rb
Original file line number Diff line number Diff line change
Expand Up @@ -630,18 +630,78 @@ def teardown
assert_predicate @connection, :active?
end

test "querying after a failed query restores and succeeds" do
test "querying after a failed non-retryable query restores and succeeds" do
Post.first # Connection verified (and prepared statement pool populated if enabled)

remote_disconnect @connection

assert_raises(ActiveRecord::ConnectionFailed) do
Post.first # Connection no longer verified after failed query
@connection.execute("INSERT INTO posts(title, body) VALUES ('foo', 'bar')")
end

assert Post.first # Verifying the connection causes a reconnect and the query succeeds
assert_predicate @connection, :active?
end

test "idempotent SELECT queries are retried and result in a reconnect" do
Post.first

remote_disconnect @connection

assert Post.first
assert_predicate @connection, :active?

remote_disconnect @connection

assert Post.where(id: [1, 2]).first
assert_predicate @connection, :active?
end

test "#find and #find_by queries with known attributes are retried and result in a reconnect" do
Post.first

remote_disconnect @connection

assert Post.find(1)
assert_predicate @connection, :active?

remote_disconnect @connection

assert Post.find_by(title: "Welcome to the weblog")
assert_predicate @connection, :active?
end

test "queries containing SQL fragments are not retried" do
Post.first

remote_disconnect @connection

assert_raises(ActiveRecord::ConnectionFailed) { Post.where("1 = 1").to_a }
assert_not_predicate @connection, :active?

remote_disconnect @connection

assert_raises(ActiveRecord::ConnectionFailed) { Post.select("title AS custom_title").first }
assert_not_predicate @connection, :active?

remote_disconnect @connection

assert_raises(ActiveRecord::ConnectionFailed) { Post.find_by("updated_at < ?", 2.weeks.ago) }
assert_not_predicate @connection, :active?
end

test "queries containing SQL functions are not retried" do
Post.first

remote_disconnect @connection

tags_count_attr = Post.arel_table[:tags_count]
abs_tags_count = Arel::Nodes::NamedFunction.new("ABS", [tags_count_attr])

assert_raises(ActiveRecord::ConnectionFailed) do
Post.where(abs_tags_count.eq(2)).first
end
assert_not_predicate @connection, :active?
end

test "transaction restores after remote disconnection" do
Expand Down Expand Up @@ -779,6 +839,8 @@ def raw_transaction_open?(connection)
def remote_disconnect(connection)
case connection.adapter_name
when "PostgreSQL"
# Connection was left in a bad state, need to reconnect to simulate fresh disconnect
connection.verify! if connection.instance_variable_get(:@raw_connection).status == ::PG::CONNECTION_BAD
unless connection.instance_variable_get(:@raw_connection).transaction_status == ::PG::PQTRANS_INTRANS
connection.instance_variable_get(:@raw_connection).async_exec("begin")
end
Expand Down
Loading

0 comments on commit 7d9c15d

Please sign in to comment.