From f2a8bf9ca99beed344a31ac9f9fe53717412c660 Mon Sep 17 00:00:00 2001 From: Adrianna Chang Date: Tue, 19 Mar 2024 17:32:07 -0400 Subject: [PATCH] Make known `Arel::Nodes::SqlLiteral` retryable. --- .../associations/join_dependency/join_association.rb | 2 +- .../abstract/database_statements.rb | 2 +- .../connection_adapters/mysql/database_statements.rb | 2 +- .../postgresql/database_statements.rb | 2 +- .../connection_adapters/sqlite3/database_statements.rb | 2 +- activerecord/lib/active_record/internal_metadata.rb | 2 +- .../lib/active_record/relation/calculations.rb | 4 ++-- .../lib/active_record/relation/predicate_builder.rb | 4 ++-- .../lib/active_record/relation/query_methods.rb | 4 ++-- activerecord/lib/arel.rb | 10 +++++++--- activerecord/lib/arel/alias_predication.rb | 2 +- activerecord/lib/arel/nodes/sql_literal.rb | 7 +++++++ activerecord/lib/arel/select_manager.rb | 6 +++--- activerecord/lib/arel/visitors/mysql.rb | 4 ++-- activerecord/lib/arel/visitors/to_sql.rb | 2 +- activerecord/test/cases/arel/visitors/to_sql_test.rb | 10 +++++++++- 16 files changed, 42 insertions(+), 23 deletions(-) diff --git a/activerecord/lib/active_record/associations/join_dependency/join_association.rb b/activerecord/lib/active_record/associations/join_dependency/join_association.rb index bd87870a3eb46..809d8e0455a9b 100644 --- a/activerecord/lib/active_record/associations/join_dependency/join_association.rb +++ b/activerecord/lib/active_record/associations/join_dependency/join_association.rb @@ -93,7 +93,7 @@ def strict_loading? def append_constraints(connection, join, constraints) if join.is_a?(Arel::Nodes::StringJoin) join_string = Arel::Nodes::And.new(constraints.unshift join.left) - join.left = Arel.sql(connection.visitor.compile(join_string)) + join.left = Arel.sql(connection.visitor.compile(join_string), retryable: true) else right = join.right right.expr = Arel::Nodes::And.new(constraints.unshift right.expr) diff --git a/activerecord/lib/active_record/connection_adapters/abstract/database_statements.rb b/activerecord/lib/active_record/connection_adapters/abstract/database_statements.rb index d8244ff61e178..3e1ea4cf4de8d 100644 --- a/activerecord/lib/active_record/connection_adapters/abstract/database_statements.rb +++ b/activerecord/lib/active_record/connection_adapters/abstract/database_statements.rb @@ -501,7 +501,7 @@ def with_yaml_fallback(value) # :nodoc: end # This is a safe default, even if not high precision on all databases - HIGH_PRECISION_CURRENT_TIMESTAMP = Arel.sql("CURRENT_TIMESTAMP").freeze # :nodoc: + HIGH_PRECISION_CURRENT_TIMESTAMP = Arel.sql("CURRENT_TIMESTAMP", retryable: true).freeze # :nodoc: private_constant :HIGH_PRECISION_CURRENT_TIMESTAMP # Returns an Arel SQL literal for the CURRENT_TIMESTAMP for usage with diff --git a/activerecord/lib/active_record/connection_adapters/mysql/database_statements.rb b/activerecord/lib/active_record/connection_adapters/mysql/database_statements.rb index a5abb36c9ee4d..a358cd2b2b6a5 100644 --- a/activerecord/lib/active_record/connection_adapters/mysql/database_statements.rb +++ b/activerecord/lib/active_record/connection_adapters/mysql/database_statements.rb @@ -11,7 +11,7 @@ module DatabaseStatements # https://dev.mysql.com/doc/refman/5.7/en/date-and-time-functions.html#function_current-timestamp # https://dev.mysql.com/doc/refman/5.7/en/date-and-time-type-syntax.html - HIGH_PRECISION_CURRENT_TIMESTAMP = Arel.sql("CURRENT_TIMESTAMP(6)").freeze # :nodoc: + HIGH_PRECISION_CURRENT_TIMESTAMP = Arel.sql("CURRENT_TIMESTAMP(6)", retryable: true).freeze # :nodoc: private_constant :HIGH_PRECISION_CURRENT_TIMESTAMP def write_query?(sql) # :nodoc: diff --git a/activerecord/lib/active_record/connection_adapters/postgresql/database_statements.rb b/activerecord/lib/active_record/connection_adapters/postgresql/database_statements.rb index 6fb13f0c54f13..7ece6a11bca6e 100644 --- a/activerecord/lib/active_record/connection_adapters/postgresql/database_statements.rb +++ b/activerecord/lib/active_record/connection_adapters/postgresql/database_statements.rb @@ -124,7 +124,7 @@ def exec_restart_db_transaction # :nodoc: end # From https://www.postgresql.org/docs/current/functions-datetime.html#FUNCTIONS-DATETIME-CURRENT - HIGH_PRECISION_CURRENT_TIMESTAMP = Arel.sql("CURRENT_TIMESTAMP").freeze # :nodoc: + HIGH_PRECISION_CURRENT_TIMESTAMP = Arel.sql("CURRENT_TIMESTAMP", retryable: true).freeze # :nodoc: private_constant :HIGH_PRECISION_CURRENT_TIMESTAMP def high_precision_current_timestamp diff --git a/activerecord/lib/active_record/connection_adapters/sqlite3/database_statements.rb b/activerecord/lib/active_record/connection_adapters/sqlite3/database_statements.rb index feffa32690448..03453b48f4b14 100644 --- a/activerecord/lib/active_record/connection_adapters/sqlite3/database_statements.rb +++ b/activerecord/lib/active_record/connection_adapters/sqlite3/database_statements.rb @@ -106,7 +106,7 @@ def exec_rollback_db_transaction # :nodoc: # https://stackoverflow.com/questions/17574784 # https://www.sqlite.org/lang_datefunc.html - HIGH_PRECISION_CURRENT_TIMESTAMP = Arel.sql("STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')").freeze # :nodoc: + HIGH_PRECISION_CURRENT_TIMESTAMP = Arel.sql("STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')", retryable: true).freeze # :nodoc: private_constant :HIGH_PRECISION_CURRENT_TIMESTAMP def high_precision_current_timestamp diff --git a/activerecord/lib/active_record/internal_metadata.rb b/activerecord/lib/active_record/internal_metadata.rb index 1e1cccb44afc4..ac2f3333b5143 100644 --- a/activerecord/lib/active_record/internal_metadata.rb +++ b/activerecord/lib/active_record/internal_metadata.rb @@ -153,7 +153,7 @@ def update_entry(connection, key, new_value) def select_entry(connection, key) sm = Arel::SelectManager.new(arel_table) - sm.project(Arel::Nodes::SqlLiteral.new("*")) + sm.project(Arel::Nodes::SqlLiteral.new("*", retryable: true)) sm.where(arel_table[primary_key].eq(Arel::Nodes::BindParam.new(key))) sm.order(arel_table[primary_key].asc) sm.limit = 1 diff --git a/activerecord/lib/active_record/relation/calculations.rb b/activerecord/lib/active_record/relation/calculations.rb index 620f99826a3a5..c5c42540f1c1c 100644 --- a/activerecord/lib/active_record/relation/calculations.rb +++ b/activerecord/lib/active_record/relation/calculations.rb @@ -446,7 +446,7 @@ def aggregate_column(column_name) return column_name if Arel::Expressions === column_name arel_column(column_name.to_s) do |name| - Arel.sql(column_name == :all ? "*" : name) + Arel.sql(column_name == :all ? "*" : name, retryable: true) end end @@ -643,7 +643,7 @@ def build_count_subquery(relation, column_name, distinct) relation.select_values = [ aggregate_column(column_name).as(column_alias) ] end - subquery_alias = Arel.sql("subquery_for_count") + subquery_alias = Arel.sql("subquery_for_count", retryable: true) select_value = operation_over_aggregate_column(column_alias, "count", false) relation.build_subquery(subquery_alias, select_value) diff --git a/activerecord/lib/active_record/relation/predicate_builder.rb b/activerecord/lib/active_record/relation/predicate_builder.rb index dd6cd573d8a99..878e74a42d652 100644 --- a/activerecord/lib/active_record/relation/predicate_builder.rb +++ b/activerecord/lib/active_record/relation/predicate_builder.rb @@ -28,9 +28,9 @@ def build_from_hash(attributes, &block) def self.references(attributes) attributes.each_with_object([]) do |(key, value), result| if value.is_a?(Hash) - result << Arel.sql(key) + result << Arel.sql(key, retryable: true) elsif (idx = key.rindex(".")) - result << Arel.sql(key[0, idx]) + result << Arel.sql(key[0, idx], retryable: true) end end end diff --git a/activerecord/lib/active_record/relation/query_methods.rb b/activerecord/lib/active_record/relation/query_methods.rb index 00425d93d3752..bafe6eb3eeb91 100644 --- a/activerecord/lib/active_record/relation/query_methods.rb +++ b/activerecord/lib/active_record/relation/query_methods.rb @@ -1749,7 +1749,7 @@ def build_join_buckets end joins.each_with_index do |join, i| - joins[i] = Arel::Nodes::StringJoin.new(Arel.sql(join.strip)) if join.is_a?(String) + joins[i] = Arel::Nodes::StringJoin.new(Arel.sql(join.strip, retryable: true)) if join.is_a?(String) end while joins.first.is_a?(Arel::Nodes::Join) @@ -2013,7 +2013,7 @@ def order_column(field) if attr_name == "count" && !group_values.empty? table[attr_name] else - Arel.sql(adapter_class.quote_table_name(attr_name)) + Arel.sql(adapter_class.quote_table_name(attr_name), retryable: true) end end end diff --git a/activerecord/lib/arel.rb b/activerecord/lib/arel.rb index ecbde97e22ddf..738e80df359fd 100644 --- a/activerecord/lib/arel.rb +++ b/activerecord/lib/arel.rb @@ -45,16 +45,20 @@ module Arel # that this behavior only applies when bind value parameters are # supplied in the call; without them, the placeholder tokens have no # special meaning, and will be passed through to the query as-is. - def self.sql(sql_string, *positional_binds, **named_binds) + # + # The +:retryable+ option can be used to mark the SQL as safe to retry. + # Use this option only if the SQL is idempotent, as it could be executed + # more than once. + def self.sql(sql_string, *positional_binds, retryable: false, **named_binds) if positional_binds.empty? && named_binds.empty? - Arel::Nodes::SqlLiteral.new sql_string + Arel::Nodes::SqlLiteral.new(sql_string, retryable: retryable) else Arel::Nodes::BoundSqlLiteral.new sql_string, positional_binds, named_binds end end def self.star # :nodoc: - sql "*" + sql("*", retryable: true) end def self.arel_node?(value) # :nodoc: diff --git a/activerecord/lib/arel/alias_predication.rb b/activerecord/lib/arel/alias_predication.rb index 4abbbb7ef6def..1f7af26c25b49 100644 --- a/activerecord/lib/arel/alias_predication.rb +++ b/activerecord/lib/arel/alias_predication.rb @@ -3,7 +3,7 @@ module Arel # :nodoc: all module AliasPredication def as(other) - Nodes::As.new self, Nodes::SqlLiteral.new(other) + Nodes::As.new self, Nodes::SqlLiteral.new(other, retryable: true) end end end diff --git a/activerecord/lib/arel/nodes/sql_literal.rb b/activerecord/lib/arel/nodes/sql_literal.rb index a6138b96bf01b..f1862605d7fbd 100644 --- a/activerecord/lib/arel/nodes/sql_literal.rb +++ b/activerecord/lib/arel/nodes/sql_literal.rb @@ -8,6 +8,13 @@ class SqlLiteral < String include Arel::AliasPredication include Arel::OrderPredications + attr_reader :retryable + + def initialize(string, retryable: false) + @retryable = retryable + super(string) + end + def encode_with(coder) coder.scalar = self.to_s end diff --git a/activerecord/lib/arel/select_manager.rb b/activerecord/lib/arel/select_manager.rb index dada3324eab4e..d34e74d7f953f 100644 --- a/activerecord/lib/arel/select_manager.rb +++ b/activerecord/lib/arel/select_manager.rb @@ -46,7 +46,7 @@ def exists end def as(other) - create_table_alias grouping(@ast), Nodes::SqlLiteral.new(other) + create_table_alias grouping(@ast), Nodes::SqlLiteral.new(other, retryable: true) end def lock(locking = Arel.sql("FOR UPDATE")) @@ -131,7 +131,7 @@ def project(*projections) # FIXME: converting these to SQLLiterals is probably not good, but # rails tests require it. @ctx.projections.concat projections.map { |x| - STRING_OR_SYMBOL_CLASS.include?(x.class) ? Nodes::SqlLiteral.new(x.to_s) : x + STRING_OR_SYMBOL_CLASS.include?(x.class) ? Nodes::SqlLiteral.new(x.to_s, retryable: true) : x } self end @@ -172,7 +172,7 @@ def distinct_on(value) def order(*expr) # FIXME: We SHOULD NOT be converting these to SqlLiteral automatically @ast.orders.concat expr.map { |x| - STRING_OR_SYMBOL_CLASS.include?(x.class) ? Nodes::SqlLiteral.new(x.to_s) : x + STRING_OR_SYMBOL_CLASS.include?(x.class) ? Nodes::SqlLiteral.new(x.to_s, retryable: true) : x } self end diff --git a/activerecord/lib/arel/visitors/mysql.rb b/activerecord/lib/arel/visitors/mysql.rb index 495c78ce3d047..937aa96da9163 100644 --- a/activerecord/lib/arel/visitors/mysql.rb +++ b/activerecord/lib/arel/visitors/mysql.rb @@ -27,7 +27,7 @@ def visit_Arel_Nodes_SelectStatement(o, collector) end def visit_Arel_Nodes_SelectCore(o, collector) - o.froms ||= Arel.sql("DUAL") + o.froms ||= Arel.sql("DUAL", retryable: true) super end @@ -103,7 +103,7 @@ def build_subselect(key, o) Nodes::SelectStatement.new.tap do |stmt| core = stmt.cores.last core.froms = Nodes::Grouping.new(subselect).as("__active_record_temp") - core.projections = [Arel.sql(quote_column_name(key.name))] + core.projections = [Arel.sql(quote_column_name(key.name), retryable: true)] end end end diff --git a/activerecord/lib/arel/visitors/to_sql.rb b/activerecord/lib/arel/visitors/to_sql.rb index 46ce3407c907a..8c1586f9ff5ef 100644 --- a/activerecord/lib/arel/visitors/to_sql.rb +++ b/activerecord/lib/arel/visitors/to_sql.rb @@ -772,7 +772,7 @@ def visit_Arel_Nodes_BindParam(o, collector) def visit_Arel_Nodes_SqlLiteral(o, collector) collector.preparable = false - collector.retryable = false + collector.retryable = o.retryable collector << o.to_s end diff --git a/activerecord/test/cases/arel/visitors/to_sql_test.rb b/activerecord/test/cases/arel/visitors/to_sql_test.rb index 888205ad56d52..b43e27561a990 100644 --- a/activerecord/test/cases/arel/visitors/to_sql_test.rb +++ b/activerecord/test/cases/arel/visitors/to_sql_test.rb @@ -70,7 +70,7 @@ def dispatch end it "should mark collector as non-retryable when visiting named function" do - function = Nodes::NamedFunction.new("omg", [Arel.star]) + function = Nodes::NamedFunction.new("ABS", [@table]) collector = Collectors::SQLString.new @visitor.accept(function, collector) @@ -85,6 +85,14 @@ def dispatch assert_equal false, collector.retryable end + it "should mark collector as retryable if SQL literal is marked as retryable" do + node = Nodes::SqlLiteral.new("COUNT(*)", retryable: true) + collector = Collectors::SQLString.new + @visitor.accept(node, collector) + + assert collector.retryable + end + it "should mark collector as non-retryable when visiting bound SQL literal" do node = Nodes::BoundSqlLiteral.new("id IN (?)", [[1, 2, 3]], {}) collector = Collectors::SQLString.new