Skip to content

Commit

Permalink
Make known Arel::Nodes::SqlLiteral retryable.
Browse files Browse the repository at this point in the history
  • Loading branch information
adrianna-chang-shopify committed Mar 19, 2024
1 parent dbbe689 commit f2a8bf9
Show file tree
Hide file tree
Showing 16 changed files with 42 additions and 23 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ def strict_loading?
def append_constraints(connection, join, constraints)
if join.is_a?(Arel::Nodes::StringJoin)
join_string = Arel::Nodes::And.new(constraints.unshift join.left)
join.left = Arel.sql(connection.visitor.compile(join_string))
join.left = Arel.sql(connection.visitor.compile(join_string), retryable: true)
else
right = join.right
right.expr = Arel::Nodes::And.new(constraints.unshift right.expr)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -501,7 +501,7 @@ def with_yaml_fallback(value) # :nodoc:
end

# This is a safe default, even if not high precision on all databases
HIGH_PRECISION_CURRENT_TIMESTAMP = Arel.sql("CURRENT_TIMESTAMP").freeze # :nodoc:
HIGH_PRECISION_CURRENT_TIMESTAMP = Arel.sql("CURRENT_TIMESTAMP", retryable: true).freeze # :nodoc:
private_constant :HIGH_PRECISION_CURRENT_TIMESTAMP

# Returns an Arel SQL literal for the CURRENT_TIMESTAMP for usage with
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ module DatabaseStatements

# https://dev.mysql.com/doc/refman/5.7/en/date-and-time-functions.html#function_current-timestamp
# https://dev.mysql.com/doc/refman/5.7/en/date-and-time-type-syntax.html
HIGH_PRECISION_CURRENT_TIMESTAMP = Arel.sql("CURRENT_TIMESTAMP(6)").freeze # :nodoc:
HIGH_PRECISION_CURRENT_TIMESTAMP = Arel.sql("CURRENT_TIMESTAMP(6)", retryable: true).freeze # :nodoc:
private_constant :HIGH_PRECISION_CURRENT_TIMESTAMP

def write_query?(sql) # :nodoc:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ def exec_restart_db_transaction # :nodoc:
end

# From https://www.postgresql.org/docs/current/functions-datetime.html#FUNCTIONS-DATETIME-CURRENT
HIGH_PRECISION_CURRENT_TIMESTAMP = Arel.sql("CURRENT_TIMESTAMP").freeze # :nodoc:
HIGH_PRECISION_CURRENT_TIMESTAMP = Arel.sql("CURRENT_TIMESTAMP", retryable: true).freeze # :nodoc:
private_constant :HIGH_PRECISION_CURRENT_TIMESTAMP

def high_precision_current_timestamp
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ def exec_rollback_db_transaction # :nodoc:

# https://stackoverflow.com/questions/17574784
# https://www.sqlite.org/lang_datefunc.html
HIGH_PRECISION_CURRENT_TIMESTAMP = Arel.sql("STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')").freeze # :nodoc:
HIGH_PRECISION_CURRENT_TIMESTAMP = Arel.sql("STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')", retryable: true).freeze # :nodoc:
private_constant :HIGH_PRECISION_CURRENT_TIMESTAMP

def high_precision_current_timestamp
Expand Down
2 changes: 1 addition & 1 deletion activerecord/lib/active_record/internal_metadata.rb
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@ def update_entry(connection, key, new_value)

def select_entry(connection, key)
sm = Arel::SelectManager.new(arel_table)
sm.project(Arel::Nodes::SqlLiteral.new("*"))
sm.project(Arel::Nodes::SqlLiteral.new("*", retryable: true))
sm.where(arel_table[primary_key].eq(Arel::Nodes::BindParam.new(key)))
sm.order(arel_table[primary_key].asc)
sm.limit = 1
Expand Down
4 changes: 2 additions & 2 deletions activerecord/lib/active_record/relation/calculations.rb
Original file line number Diff line number Diff line change
Expand Up @@ -446,7 +446,7 @@ def aggregate_column(column_name)
return column_name if Arel::Expressions === column_name

arel_column(column_name.to_s) do |name|
Arel.sql(column_name == :all ? "*" : name)
Arel.sql(column_name == :all ? "*" : name, retryable: true)
end
end

Expand Down Expand Up @@ -643,7 +643,7 @@ def build_count_subquery(relation, column_name, distinct)
relation.select_values = [ aggregate_column(column_name).as(column_alias) ]
end

subquery_alias = Arel.sql("subquery_for_count")
subquery_alias = Arel.sql("subquery_for_count", retryable: true)
select_value = operation_over_aggregate_column(column_alias, "count", false)

relation.build_subquery(subquery_alias, select_value)
Expand Down
4 changes: 2 additions & 2 deletions activerecord/lib/active_record/relation/predicate_builder.rb
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,9 @@ def build_from_hash(attributes, &block)
def self.references(attributes)
attributes.each_with_object([]) do |(key, value), result|
if value.is_a?(Hash)
result << Arel.sql(key)
result << Arel.sql(key, retryable: true)
elsif (idx = key.rindex("."))
result << Arel.sql(key[0, idx])
result << Arel.sql(key[0, idx], retryable: true)
end
end
end
Expand Down
4 changes: 2 additions & 2 deletions activerecord/lib/active_record/relation/query_methods.rb
Original file line number Diff line number Diff line change
Expand Up @@ -1749,7 +1749,7 @@ def build_join_buckets
end

joins.each_with_index do |join, i|
joins[i] = Arel::Nodes::StringJoin.new(Arel.sql(join.strip)) if join.is_a?(String)
joins[i] = Arel::Nodes::StringJoin.new(Arel.sql(join.strip, retryable: true)) if join.is_a?(String)
end

while joins.first.is_a?(Arel::Nodes::Join)
Expand Down Expand Up @@ -2013,7 +2013,7 @@ def order_column(field)
if attr_name == "count" && !group_values.empty?
table[attr_name]
else
Arel.sql(adapter_class.quote_table_name(attr_name))
Arel.sql(adapter_class.quote_table_name(attr_name), retryable: true)
end
end
end
Expand Down
10 changes: 7 additions & 3 deletions activerecord/lib/arel.rb
Original file line number Diff line number Diff line change
Expand Up @@ -45,16 +45,20 @@ module Arel
# that this behavior only applies when bind value parameters are
# supplied in the call; without them, the placeholder tokens have no
# special meaning, and will be passed through to the query as-is.
def self.sql(sql_string, *positional_binds, **named_binds)
#
# The +:retryable+ option can be used to mark the SQL as safe to retry.
# Use this option only if the SQL is idempotent, as it could be executed
# more than once.
def self.sql(sql_string, *positional_binds, retryable: false, **named_binds)
if positional_binds.empty? && named_binds.empty?
Arel::Nodes::SqlLiteral.new sql_string
Arel::Nodes::SqlLiteral.new(sql_string, retryable: retryable)
else
Arel::Nodes::BoundSqlLiteral.new sql_string, positional_binds, named_binds
end
end

def self.star # :nodoc:
sql "*"
sql("*", retryable: true)
end

def self.arel_node?(value) # :nodoc:
Expand Down
2 changes: 1 addition & 1 deletion activerecord/lib/arel/alias_predication.rb
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
module Arel # :nodoc: all
module AliasPredication
def as(other)
Nodes::As.new self, Nodes::SqlLiteral.new(other)
Nodes::As.new self, Nodes::SqlLiteral.new(other, retryable: true)
end
end
end
7 changes: 7 additions & 0 deletions activerecord/lib/arel/nodes/sql_literal.rb
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,13 @@ class SqlLiteral < String
include Arel::AliasPredication
include Arel::OrderPredications

attr_reader :retryable

def initialize(string, retryable: false)
@retryable = retryable
super(string)
end

def encode_with(coder)
coder.scalar = self.to_s
end
Expand Down
6 changes: 3 additions & 3 deletions activerecord/lib/arel/select_manager.rb
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def exists
end

def as(other)
create_table_alias grouping(@ast), Nodes::SqlLiteral.new(other)
create_table_alias grouping(@ast), Nodes::SqlLiteral.new(other, retryable: true)
end

def lock(locking = Arel.sql("FOR UPDATE"))
Expand Down Expand Up @@ -131,7 +131,7 @@ def project(*projections)
# FIXME: converting these to SQLLiterals is probably not good, but
# rails tests require it.
@ctx.projections.concat projections.map { |x|
STRING_OR_SYMBOL_CLASS.include?(x.class) ? Nodes::SqlLiteral.new(x.to_s) : x
STRING_OR_SYMBOL_CLASS.include?(x.class) ? Nodes::SqlLiteral.new(x.to_s, retryable: true) : x
}
self
end
Expand Down Expand Up @@ -172,7 +172,7 @@ def distinct_on(value)
def order(*expr)
# FIXME: We SHOULD NOT be converting these to SqlLiteral automatically
@ast.orders.concat expr.map { |x|
STRING_OR_SYMBOL_CLASS.include?(x.class) ? Nodes::SqlLiteral.new(x.to_s) : x
STRING_OR_SYMBOL_CLASS.include?(x.class) ? Nodes::SqlLiteral.new(x.to_s, retryable: true) : x
}
self
end
Expand Down
4 changes: 2 additions & 2 deletions activerecord/lib/arel/visitors/mysql.rb
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def visit_Arel_Nodes_SelectStatement(o, collector)
end

def visit_Arel_Nodes_SelectCore(o, collector)
o.froms ||= Arel.sql("DUAL")
o.froms ||= Arel.sql("DUAL", retryable: true)
super
end

Expand Down Expand Up @@ -103,7 +103,7 @@ def build_subselect(key, o)
Nodes::SelectStatement.new.tap do |stmt|
core = stmt.cores.last
core.froms = Nodes::Grouping.new(subselect).as("__active_record_temp")
core.projections = [Arel.sql(quote_column_name(key.name))]
core.projections = [Arel.sql(quote_column_name(key.name), retryable: true)]
end
end
end
Expand Down
2 changes: 1 addition & 1 deletion activerecord/lib/arel/visitors/to_sql.rb
Original file line number Diff line number Diff line change
Expand Up @@ -772,7 +772,7 @@ def visit_Arel_Nodes_BindParam(o, collector)

def visit_Arel_Nodes_SqlLiteral(o, collector)
collector.preparable = false
collector.retryable = false
collector.retryable = o.retryable
collector << o.to_s
end

Expand Down
10 changes: 9 additions & 1 deletion activerecord/test/cases/arel/visitors/to_sql_test.rb
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ def dispatch
end

it "should mark collector as non-retryable when visiting named function" do
function = Nodes::NamedFunction.new("omg", [Arel.star])
function = Nodes::NamedFunction.new("ABS", [@table])
collector = Collectors::SQLString.new
@visitor.accept(function, collector)

Expand All @@ -85,6 +85,14 @@ def dispatch
assert_equal false, collector.retryable
end

it "should mark collector as retryable if SQL literal is marked as retryable" do
node = Nodes::SqlLiteral.new("COUNT(*)", retryable: true)
collector = Collectors::SQLString.new
@visitor.accept(node, collector)

assert collector.retryable
end

it "should mark collector as non-retryable when visiting bound SQL literal" do
node = Nodes::BoundSqlLiteral.new("id IN (?)", [[1, 2, 3]], {})
collector = Collectors::SQLString.new
Expand Down

0 comments on commit f2a8bf9

Please sign in to comment.