diff --git a/lib/polo/adapters/mysql.rb b/lib/polo/adapters/mysql.rb index 11a3486..6562b51 100644 --- a/lib/polo/adapters/mysql.rb +++ b/lib/polo/adapters/mysql.rb @@ -4,7 +4,7 @@ class MySQL def on_duplicate_key_update(inserts, records) insert_and_record = inserts.zip(records) insert_and_record.map do |insert, record| - attrs = record.is_a?(Hash) ? record.fetch(:values) : record.attributes + attrs = record.is_a?(Hash) ? record.fetch(:values) : record.attributes.slice(*record.class.column_names) values_syntax = attrs.keys.map do |key| "`#{key}` = VALUES(`#{key}`)" end @@ -22,4 +22,4 @@ def ignore_transform(inserts, records) end end end -end \ No newline at end of file +end diff --git a/lib/polo/adapters/postgres.rb b/lib/polo/adapters/postgres.rb index 6bd1c37..e6e5c47 100644 --- a/lib/polo/adapters/postgres.rb +++ b/lib/polo/adapters/postgres.rb @@ -1,11 +1,59 @@ module Polo module Adapters class Postgres - # TODO: Implement UPSERT. This command became available in 9.1. - # - # See: http://www.the-art-of-web.com/sql/upsert/ def on_duplicate_key_update(inserts, records) - raise 'on_duplicate: :override is not currently supported in the PostgreSQL adapter' + @pg_version ||= ActiveRecord::Base.connection.select_value('SELECT version()')[/PostgreSQL ([\d\.]+)/, 1] + + insert_and_record = inserts.zip(records) + insert_and_record.map do |insert, record| + if @pg_version < '9.5.0' + naive_update_insert(insert, record) + else + add_upsert_to_insert(insert, record) + end + end + end + + def add_upsert_to_insert(insert, record) + if record.is_a?(Hash) + return naive_update_insert(insert, record) + end + + attrs = record.is_a?(Hash) ? record.fetch(:values) : record.attributes.slice(*record.class.column_names) + values_syntax = attrs.keys.reject { |key| key.to_s == 'id' }.map do |key| + %{"#{key}" = EXCLUDED."#{key}"} + end + + # Conflict on id column + on_dup_syntax = "ON CONFLICT (#{record.class.primary_key}) DO UPDATE SET #{values_syntax.join(', ')}" + + "#{insert} #{on_dup_syntax}" + end + + def naive_update_insert(insert, record) + table_name, id = table_name_and_key_for(record) + + attrs = record.is_a?(Hash) ? record.fetch(:values) : record.attributes_before_type_cast.slice(*record.class.column_names) + updates = attrs.except('id').map do |key, value| + column = ActiveRecord::Base.connection.send(:quote_column_name, key) + + ActiveRecord::Base.send(:sanitize_sql_array, ["#{column} = ?", value]) + end + condition = if id.blank? + record[:values].map { |k, v| + column = ActiveRecord::Base.connection.send(:quote_column_name, k) + ActiveRecord::Base.send(:sanitize_sql_array, ["#{column} = ?", v]) + }.join(' and ') + else + "id = #{id}" + end + + "do $$ + begin + #{insert}; + exception when unique_violation then + update #{table_name} set #{updates.join(', ')} where #{condition}; + end $$;" end # Internal: Transforms an INSERT with PostgreSQL-specific syntax. Ignores @@ -21,17 +69,22 @@ def on_duplicate_key_update(inserts, records) def ignore_transform(inserts, records) insert_and_record = inserts.zip(records) insert_and_record.map do |insert, record| - if record.is_a?(Hash) - id = record.fetch(:values)[:id] - table_name = record.fetch(:table_name) - else - id = record[:id] - table_name = record.class.arel_table.name - end + table_name, id = table_name_and_key_for(record) insert = insert.gsub(/VALUES \((.+)\)$/m, 'SELECT \\1') insert << " WHERE NOT EXISTS (SELECT 1 FROM #{table_name} WHERE id=#{id});" end end + + def table_name_and_key_for(record) + if record.is_a?(Hash) + id = record.fetch(:values)[:id] + table_name = record.fetch(:table_name) + else + id = record[:id] + table_name = record.class.arel_table.name + end + [table_name, id] + end end end end