diff --git a/Torch/Plugins/PluginManager.cs b/Torch/Plugins/PluginManager.cs index 7c32ad3..54e978c 100644 --- a/Torch/Plugins/PluginManager.cs +++ b/Torch/Plugins/PluginManager.cs @@ -233,17 +233,9 @@ namespace Torch.Managers foreach (var file in files) { - if (!file.Contains(".dll", StringComparison.CurrentCultureIgnoreCase)) + if (!file.EndsWith(".dll", StringComparison.CurrentCultureIgnoreCase)) continue; - if (false) - { - var asm = Assembly.LoadFrom(file); - assemblies.Add(asm); - TorchBase.RegisterAuxAssembly(asm); - continue; - } - using (var stream = File.OpenRead(file)) { var data = stream.ReadToEnd(); @@ -261,15 +253,14 @@ namespace Torch.Managers { _log.Warn(e, $"Failed to read debugging symbols from {symbolPath}"); } - Assembly asm = symbol != null ? Assembly.Load(data, symbol) : Assembly.Load(data); + assemblies.Add(symbol != null ? Assembly.Load(data, symbol) : Assembly.Load(data)); #else - Assembly asm = Assembly.Load(data); + assemblies.Add(Assembly.Load(data)); #endif - assemblies.Add(asm); - TorchBase.RegisterAuxAssembly(asm); } } + RegisterAllAssemblies(assemblies); InstantiatePlugin(manifest, assemblies); } @@ -288,7 +279,7 @@ namespace Torch.Managers foreach (var entry in zipFile.Entries) { - if (!entry.Name.Contains(".dll", StringComparison.CurrentCultureIgnoreCase)) + if (!entry.Name.EndsWith(".dll", StringComparison.CurrentCultureIgnoreCase)) continue; @@ -309,19 +300,52 @@ namespace Torch.Managers { _log.Warn(e, $"Failed to read debugging symbols from {path}:{symbolEntryName}"); } - Assembly asm = symbol != null ? Assembly.Load(data, symbol) : Assembly.Load(data); + assemblies.Add(symbol != null ? Assembly.Load(data, symbol) : Assembly.Load(data)); #else - Assembly asm = Assembly.Load(data); + assemblies.Add(Assembly.Load(data)); #endif - assemblies.Add(asm); - TorchBase.RegisterAuxAssembly(asm); } } } + RegisterAllAssemblies(assemblies); InstantiatePlugin(manifest, assemblies); } + private void RegisterAllAssemblies(IReadOnlyCollection assemblies) + { + Assembly ResolveDependentAssembly(object sender, ResolveEventArgs args) + { + var requiredAssemblyName = new AssemblyName(args.Name); + foreach (Assembly asm in assemblies) + { + if (IsAssemblyCompatible(requiredAssemblyName, asm.GetName())) + return asm; + } + _log.Warn($"Could find dependent assembly! Requesting assembly: {args.RequestingAssembly}, dependent assembly: {requiredAssemblyName}"); + return null; + } + + try + { + AppDomain.CurrentDomain.AssemblyResolve += ResolveDependentAssembly; + foreach (Assembly asm in assemblies) + { + TorchBase.RegisterAuxAssembly(asm); + } + } + finally + { + AppDomain.CurrentDomain.AssemblyResolve -= ResolveDependentAssembly; + } + } + + private static bool IsAssemblyCompatible(AssemblyName a, AssemblyName b) + { + return a.Name == b.Name && a.Version.Major == b.Version.Major && a.Version.Minor == b.Version.Minor; + } + + private PluginManifest GetManifestFromZip(string path) { using (var zipFile = ZipFile.OpenRead(path))