From 47d19538c2ca1501f2005a6df7ce2d99dfa00ae4 Mon Sep 17 00:00:00 2001 From: Jean Boussier Date: Mon, 15 Apr 2024 09:25:16 +0200 Subject: [PATCH] Add `AR::Base#load_async(*associations)` as a nicer API --- activerecord/lib/active_record/associations.rb | 5 +++++ .../lib/active_record/associations/association.rb | 4 ++-- activerecord/lib/active_record/core.rb | 9 +++++---- .../cases/associations/belongs_to_associations_test.rb | 2 +- .../cases/associations/has_many_associations_test.rb | 2 +- .../test/cases/associations/has_one_associations_test.rb | 2 +- 6 files changed, 15 insertions(+), 9 deletions(-) diff --git a/activerecord/lib/active_record/associations.rb b/activerecord/lib/active_record/associations.rb index 81f5691f64ce..7fe61bab3c45 100644 --- a/activerecord/lib/active_record/associations.rb +++ b/activerecord/lib/active_record/associations.rb @@ -62,6 +62,11 @@ def association(name) # :nodoc: association end + def load_async(*associations) # TODO: doc + associations.map { |name| association(name) }.each(&:async_load_target) + self + end + def association_cached?(name) # :nodoc: @association_cache.key?(name) end diff --git a/activerecord/lib/active_record/associations/association.rb b/activerecord/lib/active_record/associations/association.rb index 408e8e5977c9..4c7f4f3acf23 100644 --- a/activerecord/lib/active_record/associations/association.rb +++ b/activerecord/lib/active_record/associations/association.rb @@ -191,7 +191,7 @@ def async_load_target @target = find_target(async: true) if (@stale_state && stale_target?) || find_target? loaded! unless loaded? - @target + nil end # We can't dump @reflection and @through_reflection since it contains the scope proc @@ -247,7 +247,7 @@ def find_target(async: false) if async return scope.load_async.then(&:to_a) else - return scope.to_a + return scope.to_a end end diff --git a/activerecord/lib/active_record/core.rb b/activerecord/lib/active_record/core.rb index e5962e3ead61..c67d261bad0e 100644 --- a/activerecord/lib/active_record/core.rb +++ b/activerecord/lib/active_record/core.rb @@ -432,10 +432,11 @@ def cached_find_by(keys, values) where(wheres).limit(1) } - begin - statement.execute(values.flatten, lease_connection, allow_retry: true).then(&:first) - rescue TypeError - raise ActiveRecord::StatementInvalid + begin + statement.execute(values.flatten, lease_connection, allow_retry: true).then(&:first) + rescue TypeError + raise ActiveRecord::StatementInvalid + end end end end diff --git a/activerecord/test/cases/associations/belongs_to_associations_test.rb b/activerecord/test/cases/associations/belongs_to_associations_test.rb index 28705b3e56bd..fc5c8684f36e 100644 --- a/activerecord/test/cases/associations/belongs_to_associations_test.rb +++ b/activerecord/test/cases/associations/belongs_to_associations_test.rb @@ -1851,7 +1851,7 @@ def test_async_load_belongs_to client = Client.find(3) first_firm = companies(:first_firm) - promise = client.association(:firm).async_load_target + promise = client.load_async(:firm) wait_for_async_query events = [] diff --git a/activerecord/test/cases/associations/has_many_associations_test.rb b/activerecord/test/cases/associations/has_many_associations_test.rb index 0e29ebe99c5a..e5c81f4f43bf 100644 --- a/activerecord/test/cases/associations/has_many_associations_test.rb +++ b/activerecord/test/cases/associations/has_many_associations_test.rb @@ -3263,7 +3263,7 @@ class AsyncHasOneAssociationsTest < ActiveRecord::TestCase def test_async_load_has_many firm = companies(:first_firm) - promise = firm.association(:clients).async_load_target + promise = firm.load_async(:clients) wait_for_async_query events = [] diff --git a/activerecord/test/cases/associations/has_one_associations_test.rb b/activerecord/test/cases/associations/has_one_associations_test.rb index d721a24c8180..6c8d8e763a9d 100644 --- a/activerecord/test/cases/associations/has_one_associations_test.rb +++ b/activerecord/test/cases/associations/has_one_associations_test.rb @@ -955,7 +955,7 @@ def test_async_load_has_one firm = companies(:first_firm) first_account = Account.find(1) - promise = firm.association(:account).async_load_target + promise = firm.load_async(:account) wait_for_async_query events = []