diff --git a/Shuttle.Core.Data/DatabaseContextCache.cs b/Shuttle.Core.Data/DatabaseContextCache.cs index a467d05..fde0c9a 100644 --- a/Shuttle.Core.Data/DatabaseContextCache.cs +++ b/Shuttle.Core.Data/DatabaseContextCache.cs @@ -7,27 +7,35 @@ public class DatabaseContextCache : IDatabaseContextCache { public IDatabaseContext Current { get; private set; } - public void Use(string name) + public ActiveDatabaseContext Use(string name) { - Current = DatabaseContexts.Find(candidate => candidate.Name.Equals(name, StringComparison.OrdinalIgnoreCase)); + var current = Current; + + Current = DatabaseContexts.Find(candidate => candidate.Name.Equals(name, StringComparison.OrdinalIgnoreCase)); if (Current == null) { throw new Exception(string.Format(Resources.DatabaseContextNameNotFoundException, name)); } + + return new ActiveDatabaseContext(this, current); } - public void Use(IDatabaseContext context) + public ActiveDatabaseContext Use(IDatabaseContext context) { Guard.AgainstNull(context, nameof(context)); + var current = Current; + Current = DatabaseContexts.Find(candidate => candidate.Key.Equals(context.Key)); if (Current == null) { throw new Exception(string.Format(Resources.DatabaseContextKeyNotFoundException, context.Key)); } - } + + return new ActiveDatabaseContext(this, current); + } public bool Contains(string connectionString) { @@ -42,9 +50,10 @@ public void Add(IDatabaseContext context) } DatabaseContexts.Add(context); + Use(context); } - private IDatabaseContext Find(IDatabaseContext context) + private IDatabaseContext Find(IDatabaseContext context) { return DatabaseContexts.Find(candidate => candidate.Key.Equals(context.Key)); } @@ -78,7 +87,7 @@ public IDatabaseContext Get(string connectionString) return result; } - public DatabaseContextCollection DatabaseContexts { get; private set; } + public DatabaseContextCollection DatabaseContexts { get; } public DatabaseContextCache() { diff --git a/Shuttle.Core.Data/DatabaseContextFactory.cs b/Shuttle.Core.Data/DatabaseContextFactory.cs index 29c56af..88f6c72 100644 --- a/Shuttle.Core.Data/DatabaseContextFactory.cs +++ b/Shuttle.Core.Data/DatabaseContextFactory.cs @@ -1,5 +1,4 @@ using System; -using System.Configuration; using System.Data; using Shuttle.Core.Contract; diff --git a/Shuttle.Core.Data/IDatabaseContextCache.cs b/Shuttle.Core.Data/IDatabaseContextCache.cs index 0f1b769..691e6bf 100644 --- a/Shuttle.Core.Data/IDatabaseContextCache.cs +++ b/Shuttle.Core.Data/IDatabaseContextCache.cs @@ -3,8 +3,8 @@ namespace Shuttle.Core.Data public interface IDatabaseContextCache { IDatabaseContext Current { get; } - void Use(string name); - void Use(IDatabaseContext context); + ActiveDatabaseContext Use(string name); + ActiveDatabaseContext Use(IDatabaseContext context); bool Contains(string connectionString); void Add(IDatabaseContext context); void Remove(IDatabaseContext context); diff --git a/Shuttle.Core.Data/Properties/AssemblyInfo.cs b/Shuttle.Core.Data/Properties/AssemblyInfo.cs index 8c4159a..8f009fc 100644 --- a/Shuttle.Core.Data/Properties/AssemblyInfo.cs +++ b/Shuttle.Core.Data/Properties/AssemblyInfo.cs @@ -33,10 +33,10 @@ [assembly: AssemblyTitle(".NET Standard 2.0")] #endif -[assembly: AssemblyVersion("10.0.9.0")] +[assembly: AssemblyVersion("10.0.10.0")] [assembly: AssemblyCopyright("Copyright © Eben Roux 2017")] [assembly: AssemblyProduct("Shuttle.Core.Data")] [assembly: AssemblyCompany("Shuttle")] [assembly: AssemblyConfiguration("Release")] -[assembly: AssemblyInformationalVersion("10.0.9")] +[assembly: AssemblyInformationalVersion("10.0.10")] [assembly: ComVisible(false)] diff --git a/Shuttle.Core.Data/ThreadStaticDatabaseContextCache.cs b/Shuttle.Core.Data/ThreadStaticDatabaseContextCache.cs index f3dc5bb..f416a79 100644 --- a/Shuttle.Core.Data/ThreadStaticDatabaseContextCache.cs +++ b/Shuttle.Core.Data/ThreadStaticDatabaseContextCache.cs @@ -6,19 +6,16 @@ public class ThreadStaticDatabaseContextCache : IDatabaseContextCache { [ThreadStatic] private static DatabaseContextCache _cache; - public IDatabaseContext Current - { - get { return GuardedCache().Current; } - } + public IDatabaseContext Current => GuardedCache().Current; - public void Use(string name) + public ActiveDatabaseContext Use(string name) { - GuardedCache().Use(name); + return GuardedCache().Use(name); } - public void Use(IDatabaseContext context) + public ActiveDatabaseContext Use(IDatabaseContext context) { - GuardedCache().Use(context); + return GuardedCache().Use(context); } public bool Contains(string connectionString) @@ -29,7 +26,6 @@ public bool Contains(string connectionString) public void Add(IDatabaseContext context) { GuardedCache().Add(context); - GuardedCache().Use(context); } public void Remove(IDatabaseContext context)