Fix loading plugins from ZIP files

This commit is contained in:
Westin Miller
2017-11-01 19:50:02 -07:00
parent 462eb77e0d
commit b3ab0cbd74
2 changed files with 54 additions and 12 deletions

View File

@@ -236,7 +236,8 @@ namespace Torch.Managers
if (!file.Contains(".dll", StringComparison.CurrentCultureIgnoreCase)) if (!file.Contains(".dll", StringComparison.CurrentCultureIgnoreCase))
continue; continue;
if (false) { if (false)
{
var asm = Assembly.LoadFrom(file); var asm = Assembly.LoadFrom(file);
assemblies.Add(asm); assemblies.Add(asm);
TorchBase.RegisterAuxAssembly(asm); TorchBase.RegisterAuxAssembly(asm);
@@ -251,8 +252,15 @@ namespace Torch.Managers
var symbolPath = Path.Combine(Path.GetDirectoryName(file) ?? ".", var symbolPath = Path.Combine(Path.GetDirectoryName(file) ?? ".",
Path.GetFileNameWithoutExtension(file) + ".pdb"); Path.GetFileNameWithoutExtension(file) + ".pdb");
if (File.Exists(symbolPath)) if (File.Exists(symbolPath))
try
{
using (var symbolStream = File.OpenRead(symbolPath)) using (var symbolStream = File.OpenRead(symbolPath))
symbol = symbolStream.ReadToEnd(); symbol = symbolStream.ReadToEnd();
}
catch (Exception e)
{
_log.Warn(e, $"Failed to read debugging symbols from {symbolPath}");
}
Assembly asm = symbol != null ? Assembly.Load(data, symbol) : Assembly.Load(data); Assembly asm = symbol != null ? Assembly.Load(data, symbol) : Assembly.Load(data);
#else #else
Assembly asm = Assembly.Load(data); Assembly asm = Assembly.Load(data);
@@ -283,10 +291,29 @@ namespace Torch.Managers
if (!entry.Name.Contains(".dll", StringComparison.CurrentCultureIgnoreCase)) if (!entry.Name.Contains(".dll", StringComparison.CurrentCultureIgnoreCase))
continue; continue;
using (var stream = entry.Open()) using (var stream = entry.Open())
{ {
var data = stream.ReadToEnd(); var data = stream.ReadToEnd((int)entry.Length);
#if DEBUG
byte[] symbol = null;
var symbolEntryName = entry.FullName.Substring(0, entry.FullName.Length - "dll".Length) + "pdb";
var symbolEntry = zipFile.GetEntry(symbolEntryName);
if (symbolEntry != null)
try
{
using (var symbolStream = symbolEntry.Open())
symbol = symbolStream.ReadToEnd((int)symbolEntry.Length);
}
catch (Exception e)
{
_log.Warn(e, $"Failed to read debugging symbols from {path}:{symbolEntryName}");
}
Assembly asm = symbol != null ? Assembly.Load(data, symbol) : Assembly.Load(data);
#else
Assembly asm = Assembly.Load(data); Assembly asm = Assembly.Load(data);
#endif
assemblies.Add(asm);
TorchBase.RegisterAuxAssembly(asm); TorchBase.RegisterAuxAssembly(asm);
} }
} }

View File

@@ -12,24 +12,39 @@ namespace Torch.Utils
{ {
private static readonly ThreadLocal<WeakReference<byte[]>> _streamBuffer = new ThreadLocal<WeakReference<byte[]>>(() => new WeakReference<byte[]>(null)); private static readonly ThreadLocal<WeakReference<byte[]>> _streamBuffer = new ThreadLocal<WeakReference<byte[]>>(() => new WeakReference<byte[]>(null));
public static byte[] ReadToEnd(this Stream stream) private static long LengthSafe(this Stream stream)
{
try
{
return stream.Length;
}
catch
{
return 512;
}
}
public static byte[] ReadToEnd(this Stream stream, int optionalDataLength = -1)
{ {
byte[] buffer; byte[] buffer;
if (!_streamBuffer.Value.TryGetTarget(out buffer)) if (!_streamBuffer.Value.TryGetTarget(out buffer))
buffer = new byte[stream.Length]; buffer = new byte[stream.LengthSafe()];
if (buffer.Length < stream.Length) var initialBufferSize = optionalDataLength > 0 ? optionalDataLength : stream.LengthSafe();
buffer = new byte[stream.Length]; if (buffer.Length < initialBufferSize)
buffer = new byte[initialBufferSize];
if (buffer.Length < 1024) if (buffer.Length < 1024)
buffer = new byte[1024]; buffer = new byte[1024];
var streamPosition = 0;
while (true) while (true)
{ {
if (buffer.Length == stream.Position) if (buffer.Length == streamPosition)
Array.Resize(ref buffer, Math.Max((int)stream.Length, buffer.Length * 2)); Array.Resize(ref buffer, Math.Max((int)stream.LengthSafe(), buffer.Length * 2));
int count = stream.Read(buffer, (int)stream.Position, buffer.Length - (int)stream.Position); int count = stream.Read(buffer, streamPosition, buffer.Length - streamPosition);
if (count == 0) if (count == 0)
break; break;
streamPosition += count;
} }
var result = new byte[(int)stream.Position]; var result = new byte[streamPosition];
Array.Copy(buffer, 0, result, 0, result.Length); Array.Copy(buffer, 0, result, 0, result.Length);
_streamBuffer.Value.SetTarget(buffer); _streamBuffer.Value.SetTarget(buffer);
return result; return result;