This commit is contained in:
Brant Martin
2018-10-23 16:30:15 -04:00
12 changed files with 616 additions and 131 deletions

View File

@@ -1,10 +1,12 @@
using System; using System;
using System.Collections.Generic; using System.Collections.Generic;
using System.IO;
using System.Linq; using System.Linq;
using System.Reflection; using System.Reflection;
using System.Reflection.Emit; using System.Reflection.Emit;
using System.Runtime.CompilerServices; using System.Runtime.CompilerServices;
using System.Text; using System.Text;
using System.Threading.Tasks;
using Torch.Managers.PatchManager; using Torch.Managers.PatchManager;
using Torch.Managers.PatchManager.MSIL; using Torch.Managers.PatchManager.MSIL;
using Torch.Utils; using Torch.Utils;
@@ -17,6 +19,7 @@ namespace Torch.Tests
public class PatchTest public class PatchTest
{ {
#region TestRunner #region TestRunner
private static readonly PatchManager _patchContext = new PatchManager(null); private static readonly PatchManager _patchContext = new PatchManager(null);
[Theory] [Theory]
@@ -48,6 +51,133 @@ namespace Torch.Tests
} }
[Fact]
public void TestTryCatchNop()
{
var ctx = _patchContext.AcquireContext();
ctx.GetPattern(TryCatchTest._target).Transpilers.Add(_nopTranspiler);
_patchContext.Commit();
Assert.False(TryCatchTest.Target());
Assert.True(TryCatchTest.FinallyHit);
_patchContext.FreeContext(ctx);
_patchContext.Commit();
}
[Fact]
public void TestTryCatchCancel()
{
var ctx = _patchContext.AcquireContext();
ctx.GetPattern(TryCatchTest._target).Transpilers.Add(TryCatchTest._removeThrowTranspiler);
ctx.GetPattern(TryCatchTest._target).DumpTarget = @"C:\tmp\dump.txt";
ctx.GetPattern(TryCatchTest._target).DumpMode = MethodRewritePattern.PrintModeEnum.Original | MethodRewritePattern.PrintModeEnum.Patched;
_patchContext.Commit();
Assert.True(TryCatchTest.Target());
Assert.True(TryCatchTest.FinallyHit);
_patchContext.FreeContext(ctx);
_patchContext.Commit();
}
private static readonly MethodInfo _nopTranspiler = typeof(PatchTest).GetMethod(nameof(NopTranspiler), BindingFlags.Static | BindingFlags.NonPublic);
private static IEnumerable<MsilInstruction> NopTranspiler(IEnumerable<MsilInstruction> input)
{
return input;
}
private class TryCatchTest
{
public static readonly MethodInfo _removeThrowTranspiler =
typeof(TryCatchTest).GetMethod(nameof(RemoveThrowTranspiler), BindingFlags.Static | BindingFlags.NonPublic);
private static IEnumerable<MsilInstruction> RemoveThrowTranspiler(IEnumerable<MsilInstruction> input)
{
foreach (var i in input)
if (i.OpCode == OpCodes.Throw)
yield return i.CopyWith(OpCodes.Pop);
else
yield return i;
}
public static readonly MethodInfo _target = typeof(TryCatchTest).GetMethod(nameof(Target), BindingFlags.Public | BindingFlags.Static);
public static bool FinallyHit = false;
public static bool Target()
{
FinallyHit = false;
try
{
try
{
// shim to prevent compiler optimization
if ("test".Length > "".Length)
throw new Exception();
return true;
}
catch (IOException ioe)
{
return false;
}
catch (Exception e)
{
return false;
}
finally
{
FinallyHit = true;
}
}
catch (Exception e)
{
throw;
}
}
}
[Fact]
public void TestAsyncNop()
{
var candidates = new List<Type>();
var nestedTypes = typeof(PatchTest).GetNestedTypes(BindingFlags.NonPublic | BindingFlags.Static);
foreach (var nested in nestedTypes)
if (nested.Name.StartsWith("<" + nameof(TestAsyncMethod) + ">"))
{
var good = false;
foreach (var itf in nested.GetInterfaces())
if (itf.FullName == typeof(IAsyncStateMachine).FullName)
{
good = true;
break;
}
if (good)
candidates.Add(nested);
}
if (candidates.Count != 1)
throw new Exception("Couldn't find async worker");
var method = candidates[0].GetMethod("MoveNext", BindingFlags.Instance | BindingFlags.NonPublic | BindingFlags.Public);
if (method == null)
throw new Exception("Failed to find state machine move next instruction, cannot proceed");
var ctx = _patchContext.AcquireContext();
ctx.GetPattern(method).Transpilers.Add(_nopTranspiler);
ctx.GetPattern(method).DumpTarget = @"C:\tmp\dump.txt";
ctx.GetPattern(method).DumpMode = MethodRewritePattern.PrintModeEnum.Original | MethodRewritePattern.PrintModeEnum.Patched;
_patchContext.Commit();
Assert.Equal("TEST", TestAsyncMethod().Result);
_patchContext.FreeContext(ctx);
_patchContext.Commit();
}
private async Task<string> TestAsyncMethod()
{
var first = await Task.Run(() => "TE");
var last = await Task.Run(() => "ST");
return await Task.Run(() => first + last);
}
public class TestBootstrap public class TestBootstrap
{ {
@@ -82,7 +212,7 @@ namespace Torch.Tests
if (_targetAssert == null) if (_targetAssert == null)
throw new Exception($"{t.FullName} must have a method named AssertNormal"); throw new Exception($"{t.FullName} must have a method named AssertNormal");
_instance = !_targetMethod.IsStatic ? Activator.CreateInstance(t) : null; _instance = !_targetMethod.IsStatic ? Activator.CreateInstance(t) : null;
_targetParams = (object[])t.GetField("_targetParams", flags)?.GetValue(null) ?? new object[0]; _targetParams = (object[]) t.GetField("_targetParams", flags)?.GetValue(null) ?? new object[0];
} }
private void Invoke(MethodBase i, params object[] args) private void Invoke(MethodBase i, params object[] args)
@@ -185,10 +315,11 @@ namespace Torch.Tests
_patchTest.Add(new TestBootstrap(type)); _patchTest.Add(new TestBootstrap(type));
} }
public static IEnumerable<object[]> Prefixes => _patchTest.Where(x => x.HasPrefix).Select(x => new object[] { x }); public static IEnumerable<object[]> Prefixes => _patchTest.Where(x => x.HasPrefix).Select(x => new object[] {x});
public static IEnumerable<object[]> Transpilers => _patchTest.Where(x => x.HasTranspile).Select(x => new object[] { x }); public static IEnumerable<object[]> Transpilers => _patchTest.Where(x => x.HasTranspile).Select(x => new object[] {x});
public static IEnumerable<object[]> Suffixes => _patchTest.Where(x => x.HasSuffix).Select(x => new object[] { x }); public static IEnumerable<object[]> Suffixes => _patchTest.Where(x => x.HasSuffix).Select(x => new object[] {x});
public static IEnumerable<object[]> Combo => _patchTest.Where(x => x.HasPrefix || x.HasTranspile || x.HasSuffix).Select(x => new object[] { x }); public static IEnumerable<object[]> Combo => _patchTest.Where(x => x.HasPrefix || x.HasTranspile || x.HasSuffix).Select(x => new object[] {x});
#endregion #endregion
#region PatchTests #region PatchTests
@@ -220,7 +351,8 @@ namespace Torch.Tests
{ {
yield return new MsilInstruction(OpCodes.Ldnull); yield return new MsilInstruction(OpCodes.Ldnull);
yield return new MsilInstruction(OpCodes.Ldc_I4_1); yield return new MsilInstruction(OpCodes.Ldc_I4_1);
yield return new MsilInstruction(OpCodes.Stfld).InlineValue(typeof(StaticNoRetNoParm).GetField("_transpileHit", BindingFlags.Static | BindingFlags.NonPublic | BindingFlags.Public)); yield return new MsilInstruction(OpCodes.Stfld).InlineValue(typeof(StaticNoRetNoParm).GetField("_transpileHit",
BindingFlags.Static | BindingFlags.NonPublic | BindingFlags.Public));
foreach (MsilInstruction i in instructions) foreach (MsilInstruction i in instructions)
yield return i; yield return i;
} }
@@ -255,7 +387,7 @@ namespace Torch.Tests
private class StaticNoRetParam private class StaticNoRetParam
{ {
private static bool _prefixHit, _normalHit, _suffixHit; private static bool _prefixHit, _normalHit, _suffixHit;
private static readonly object[] _targetParams = { "test", 1, new StringBuilder("test1") }; private static readonly object[] _targetParams = {"test", 1, new StringBuilder("test1")};
[MethodImpl(MethodImplOptions.NoInlining)] [MethodImpl(MethodImplOptions.NoInlining)]
public static void Prefix(string str, int i, StringBuilder o) public static void Prefix(string str, int i, StringBuilder o)
@@ -306,8 +438,8 @@ namespace Torch.Tests
private class StaticNoRetParamReplace private class StaticNoRetParamReplace
{ {
private static bool _prefixHit, _normalHit, _suffixHit; private static bool _prefixHit, _normalHit, _suffixHit;
private static readonly object[] _targetParams = { "test", 1, new StringBuilder("stest1") }; private static readonly object[] _targetParams = {"test", 1, new StringBuilder("stest1")};
private static readonly object[] _replacedParams = { "test2", 2, new StringBuilder("stest2") }; private static readonly object[] _replacedParams = {"test2", 2, new StringBuilder("stest2")};
private static object[] _calledParams; private static object[] _calledParams;
[MethodImpl(MethodImplOptions.NoInlining)] [MethodImpl(MethodImplOptions.NoInlining)]
@@ -316,16 +448,16 @@ namespace Torch.Tests
Assert.Equal(_targetParams[0], str); Assert.Equal(_targetParams[0], str);
Assert.Equal(_targetParams[1], i); Assert.Equal(_targetParams[1], i);
Assert.Equal(_targetParams[2], o); Assert.Equal(_targetParams[2], o);
str = (string)_replacedParams[0]; str = (string) _replacedParams[0];
i = (int)_replacedParams[1]; i = (int) _replacedParams[1];
o = (StringBuilder)_replacedParams[2]; o = (StringBuilder) _replacedParams[2];
_prefixHit = true; _prefixHit = true;
} }
[MethodImpl(MethodImplOptions.NoInlining)] [MethodImpl(MethodImplOptions.NoInlining)]
public static void Target(string str, int i, StringBuilder o) public static void Target(string str, int i, StringBuilder o)
{ {
_calledParams = new object[] { str, i, o }; _calledParams = new object[] {str, i, o};
_normalHit = true; _normalHit = true;
} }
@@ -380,6 +512,7 @@ namespace Torch.Tests
Assert.True(_prefixHit, "Failed to prefix"); Assert.True(_prefixHit, "Failed to prefix");
} }
} }
#endregion #endregion
} }
#pragma warning restore 414 #pragma warning restore 414

View File

@@ -2,6 +2,7 @@
using System.Collections.Generic; using System.Collections.Generic;
using System.ComponentModel.Design; using System.ComponentModel.Design;
using System.Diagnostics; using System.Diagnostics;
using System.IO;
using System.Linq; using System.Linq;
using System.Reflection; using System.Reflection;
using System.Reflection.Emit; using System.Reflection.Emit;
@@ -44,7 +45,7 @@ namespace Torch.Managers.PatchManager
if (Prefixes.Count == 0 && Suffixes.Count == 0 && Transpilers.Count == 0 && PostTranspilers.Count == 0) if (Prefixes.Count == 0 && Suffixes.Count == 0 && Transpilers.Count == 0 && PostTranspilers.Count == 0)
return; return;
_log.Log(PrintMsil ? LogLevel.Info : LogLevel.Debug, _log.Log(PrintMode != 0 ? LogLevel.Info : LogLevel.Debug,
$"Begin patching {_method.DeclaringType?.FullName}#{_method.Name}({string.Join(", ", _method.GetParameters().Select(x => x.ParameterType.Name))})"); $"Begin patching {_method.DeclaringType?.FullName}#{_method.Name}({string.Join(", ", _method.GetParameters().Select(x => x.ParameterType.Name))})");
var patch = ComposePatchedMethod(); var patch = ComposePatchedMethod();
@@ -52,7 +53,7 @@ namespace Torch.Managers.PatchManager
var newAddress = AssemblyMemory.GetMethodBodyStart(patch); var newAddress = AssemblyMemory.GetMethodBodyStart(patch);
_revertData = AssemblyMemory.WriteJump(_revertAddress, newAddress); _revertData = AssemblyMemory.WriteJump(_revertAddress, newAddress);
_pinnedPatch = GCHandle.Alloc(patch); _pinnedPatch = GCHandle.Alloc(patch);
_log.Log(PrintMsil ? LogLevel.Info : LogLevel.Debug, _log.Log(PrintMode != 0 ? LogLevel.Info : LogLevel.Debug,
$"Done patching {_method.DeclaringType?.FullName}#{_method.Name}({string.Join(", ", _method.GetParameters().Select(x => x.ParameterType.Name))})"); $"Done patching {_method.DeclaringType?.FullName}#{_method.Name}({string.Join(", ", _method.GetParameters().Select(x => x.ParameterType.Name))})");
} }
catch (Exception exception) catch (Exception exception)
@@ -104,17 +105,74 @@ namespace Torch.Managers.PatchManager
public const string PREFIX_SKIPPED_PARAMETER = "__prefixSkipped"; public const string PREFIX_SKIPPED_PARAMETER = "__prefixSkipped";
public const string LOCAL_PARAMETER = "__local"; public const string LOCAL_PARAMETER = "__local";
private void SavePatchedMethod(string target)
{
var asmBuilder =
AppDomain.CurrentDomain.DefineDynamicAssembly(new AssemblyName("SomeName"), AssemblyBuilderAccess.RunAndSave, Path.GetDirectoryName(target));
var moduleBuilder = asmBuilder.DefineDynamicModule(Path.GetFileNameWithoutExtension(target), Path.GetFileName(target));
var typeBuilder = moduleBuilder.DefineType("Test", TypeAttributes.Public);
var methodName = _method.Name + $"_{_patchSalt}";
var returnType = _method is MethodInfo meth ? meth.ReturnType : typeof(void);
var parameters = _method.GetParameters();
var parameterTypes = (_method.IsStatic ? Enumerable.Empty<Type>() : new[] {_method.DeclaringType})
.Concat(parameters.Select(x => x.ParameterType)).ToArray();
var patchMethod = typeBuilder.DefineMethod(methodName, MethodAttributes.Public | MethodAttributes.Static, CallingConventions.Standard,
returnType, parameterTypes);
if (!_method.IsStatic)
patchMethod.DefineParameter(0, ParameterAttributes.None, INSTANCE_PARAMETER);
for (var i = 0; i < parameters.Length; i++)
patchMethod.DefineParameter((patchMethod.IsStatic ? 0 : 1) + i, parameters[i].Attributes, parameters[i].Name);
var generator = new LoggingIlGenerator(patchMethod.GetILGenerator(), LogLevel.Trace);
List<MsilInstruction> il = EmitPatched((type, pinned) => new MsilLocal(generator.DeclareLocal(type, pinned))).ToList();
MethodTranspiler.EmitMethod(il, generator);
Type res = typeBuilder.CreateType();
asmBuilder.Save(Path.GetFileName(target));
foreach (var method in res.GetMethods(BindingFlags.Public | BindingFlags.Static))
_log.Info($"Information " + method);
}
public DynamicMethod ComposePatchedMethod() public DynamicMethod ComposePatchedMethod()
{ {
DynamicMethod method = AllocatePatchMethod(); DynamicMethod method = AllocatePatchMethod();
var generator = new LoggingIlGenerator(method.GetILGenerator(), PrintMsil ? LogLevel.Info : LogLevel.Trace); var generator = new LoggingIlGenerator(method.GetILGenerator(),
PrintMode.HasFlag(PrintModeEnum.EmittedReflection) ? LogLevel.Info : LogLevel.Trace);
List<MsilInstruction> il = EmitPatched((type, pinned) => new MsilLocal(generator.DeclareLocal(type, pinned))).ToList(); List<MsilInstruction> il = EmitPatched((type, pinned) => new MsilLocal(generator.DeclareLocal(type, pinned))).ToList();
if (PrintMsil)
var dumpTarget = DumpTarget != null ? File.CreateText(DumpTarget) : null;
try
{
const string gap = "\n\n\n\n\n";
void LogTarget(PrintModeEnum mode, bool err, string msg)
{
if (DumpMode.HasFlag(mode))
dumpTarget?.WriteLine((err ? "ERROR " : "") + msg);
if (!PrintMode.HasFlag(mode)) return;
if (err)
_log.Error(msg);
else
_log.Info(msg);
}
if (PrintMsil || DumpTarget != null)
{ {
lock (_log) lock (_log)
{ {
MethodTranspiler.IntegrityAnalysis(LogLevel.Info, il); var ctx = new MethodContext(_method);
ctx.Read();
LogTarget(PrintModeEnum.Original, false, "========== Original method ==========");
MethodTranspiler.IntegrityAnalysis((a, b) => LogTarget(PrintModeEnum.Original, a, b), ctx.Instructions, true);
LogTarget(PrintModeEnum.Original, false, gap);
LogTarget(PrintModeEnum.Emitted, false, "========== Desired method ==========");
MethodTranspiler.IntegrityAnalysis((a, b) => LogTarget(PrintModeEnum.Emitted, a, b), il);
LogTarget(PrintModeEnum.Emitted, false, gap);
} }
} }
@@ -130,12 +188,29 @@ namespace Torch.Managers.PatchManager
{ {
var ctx = new MethodContext(method); var ctx = new MethodContext(method);
ctx.Read(); ctx.Read();
MethodTranspiler.IntegrityAnalysis(LogLevel.Warn, ctx.Instructions); MethodTranspiler.IntegrityAnalysis((err, msg) => _log.Warn(msg), ctx.Instructions);
} }
throw; throw;
} }
if (PrintMsil || DumpTarget != null)
{
lock (_log)
{
var ctx = new MethodContext(method);
ctx.Read();
LogTarget(PrintModeEnum.Patched, false, "========== Patched method ==========");
MethodTranspiler.IntegrityAnalysis((a, b) => LogTarget(PrintModeEnum.Patched, a, b), ctx.Instructions, true);
LogTarget(PrintModeEnum.Patched, false, gap);
}
}
}
finally
{
dumpTarget?.Close();
}
return method; return method;
} }
@@ -274,7 +349,8 @@ namespace Torch.Managers.PatchManager
? param.ParameterType.GetElementType() ? param.ParameterType.GetElementType()
: param.ParameterType; : param.ParameterType;
if (retType == null || !retType.IsAssignableFrom(specialVariables[RESULT_PARAMETER].Type)) if (retType == null || !retType.IsAssignableFrom(specialVariables[RESULT_PARAMETER].Type))
throw new Exception($"Return type {specialVariables[RESULT_PARAMETER].Type} can't be assigned to result parameter type {retType}"); throw new Exception(
$"Return type {specialVariables[RESULT_PARAMETER].Type} can't be assigned to result parameter type {retType}");
yield return new MsilInstruction(param.ParameterType.IsByRef ? OpCodes.Ldloca : OpCodes.Ldloc) yield return new MsilInstruction(param.ParameterType.IsByRef ? OpCodes.Ldloca : OpCodes.Ldloc)
.InlineValue(specialVariables[RESULT_PARAMETER]); .InlineValue(specialVariables[RESULT_PARAMETER]);
break; break;

View File

@@ -2,9 +2,11 @@
using System.Collections.Concurrent; using System.Collections.Concurrent;
using System.Collections.Generic; using System.Collections.Generic;
using System.Diagnostics; using System.Diagnostics;
using System.Linq;
using System.Reflection; using System.Reflection;
using System.Reflection.Emit; using System.Reflection.Emit;
using System.Text; using System.Text;
using System.Windows.Documents;
using Torch.Managers.PatchManager.Transpile; using Torch.Managers.PatchManager.Transpile;
using Torch.Utils; using Torch.Utils;
@@ -103,8 +105,21 @@ namespace Torch.Managers.PatchManager.MSIL
/// <summary> /// <summary>
/// The try catch operation that is performed here. /// The try catch operation that is performed here.
/// </summary> /// </summary>
public MsilTryCatchOperation TryCatchOperation { get; set; } = null; [Obsolete("Since instructions can have multiple try catch operations you need to be using TryCatchOperations")]
public MsilTryCatchOperation TryCatchOperation
{
get => TryCatchOperations.FirstOrDefault();
set
{
TryCatchOperations.Clear();
TryCatchOperations.Add(value);
}
}
/// <summary>
/// The try catch operations performed here, in order from first to last.
/// </summary>
public readonly List<MsilTryCatchOperation> TryCatchOperations = new List<MsilTryCatchOperation>();
private static readonly ConcurrentDictionary<Type, PropertyInfo> _setterInfoForInlines = new ConcurrentDictionary<Type, PropertyInfo>(); private static readonly ConcurrentDictionary<Type, PropertyInfo> _setterInfoForInlines = new ConcurrentDictionary<Type, PropertyInfo>();
@@ -125,15 +140,18 @@ namespace Torch.Managers.PatchManager.MSIL
target = genType.GetProperty(nameof(MsilOperandInline<int>.Value)); target = genType.GetProperty(nameof(MsilOperandInline<int>.Value));
_setterInfoForInlines[type] = target; _setterInfoForInlines[type] = target;
} }
Debug.Assert(target?.DeclaringType != null); Debug.Assert(target?.DeclaringType != null);
if (target.DeclaringType.IsInstanceOfType(Operand)) if (target.DeclaringType.IsInstanceOfType(Operand))
{ {
target.SetValue(Operand, o); target.SetValue(Operand, o);
return this; return this;
} }
type = type.BaseType; type = type.BaseType;
} }
((MsilOperandInline<T>)Operand).Value = o;
((MsilOperandInline<T>) Operand).Value = o;
return this; return this;
} }
@@ -148,7 +166,8 @@ namespace Torch.Managers.PatchManager.MSIL
Operand?.CopyTo(result.Operand); Operand?.CopyTo(result.Operand);
foreach (MsilLabel x in Labels) foreach (MsilLabel x in Labels)
result.Labels.Add(x); result.Labels.Add(x);
result.TryCatchOperation = TryCatchOperation; foreach (var op in TryCatchOperations)
result.TryCatchOperations.Add(op);
return result; return result;
} }
@@ -170,7 +189,7 @@ namespace Torch.Managers.PatchManager.MSIL
/// <returns>This instruction</returns> /// <returns>This instruction</returns>
public MsilInstruction InlineTarget(MsilLabel label) public MsilInstruction InlineTarget(MsilLabel label)
{ {
((MsilOperandBrTarget)Operand).Target = label; ((MsilOperandBrTarget) Operand).Target = label;
return this; return this;
} }
@@ -185,7 +204,6 @@ namespace Torch.Managers.PatchManager.MSIL
} }
#pragma warning disable 649 #pragma warning disable 649
[ReflectedMethod(Name = "StackChange")] [ReflectedMethod(Name = "StackChange")]
private static Func<OpCode, int> _stackChange; private static Func<OpCode, int> _stackChange;
@@ -210,7 +228,13 @@ namespace Torch.Managers.PatchManager.MSIL
if (!op.IsStatic && OpCode != OpCodes.Newobj) if (!op.IsStatic && OpCode != OpCodes.Newobj)
num--; num--;
} }
return num; return num;
} }
/// <summary>
/// Gets the maximum amount of space this instruction will use.
/// </summary>
public int MaxBytes => 2 + (Operand?.MaxBytes ?? 0);
} }
} }

View File

@@ -18,6 +18,11 @@ namespace Torch.Managers.PatchManager.MSIL
/// </summary> /// </summary>
public MsilInstruction Instruction { get; } public MsilInstruction Instruction { get; }
/// <summary>
/// Gets the maximum amount of space this operand will use.
/// </summary>
public abstract int MaxBytes { get; }
internal abstract void CopyTo(MsilOperand operand); internal abstract void CopyTo(MsilOperand operand);
internal abstract void Read(MethodContext context, BinaryReader reader); internal abstract void Read(MethodContext context, BinaryReader reader);

View File

@@ -57,6 +57,8 @@ namespace Torch.Managers.PatchManager.MSIL
} }
public override int MaxBytes => 4; // Long branch
internal override void CopyTo(MsilOperand operand) internal override void CopyTo(MsilOperand operand)
{ {
var lt = operand as MsilOperandBrTarget; var lt = operand as MsilOperandBrTarget;

View File

@@ -54,6 +54,8 @@ namespace Torch.Managers.PatchManager.MSIL
{ {
} }
public override int MaxBytes => Instruction.OpCode.OperandType == OperandType.InlineI ? 4 : 1;
internal override void Read(MethodContext context, BinaryReader reader) internal override void Read(MethodContext context, BinaryReader reader)
{ {
// ReSharper disable once SwitchStatementMissingSomeCases // ReSharper disable once SwitchStatementMissingSomeCases
@@ -98,6 +100,8 @@ namespace Torch.Managers.PatchManager.MSIL
{ {
} }
public override int MaxBytes => 4;
internal override void Read(MethodContext context, BinaryReader reader) internal override void Read(MethodContext context, BinaryReader reader)
{ {
// ReSharper disable once SwitchStatementMissingSomeCases // ReSharper disable once SwitchStatementMissingSomeCases
@@ -136,6 +140,8 @@ namespace Torch.Managers.PatchManager.MSIL
{ {
} }
public override int MaxBytes => 8;
internal override void Read(MethodContext context, BinaryReader reader) internal override void Read(MethodContext context, BinaryReader reader)
{ {
// ReSharper disable once SwitchStatementMissingSomeCases // ReSharper disable once SwitchStatementMissingSomeCases
@@ -174,6 +180,8 @@ namespace Torch.Managers.PatchManager.MSIL
{ {
} }
public override int MaxBytes => 8;
internal override void Read(MethodContext context, BinaryReader reader) internal override void Read(MethodContext context, BinaryReader reader)
{ {
// ReSharper disable once SwitchStatementMissingSomeCases // ReSharper disable once SwitchStatementMissingSomeCases
@@ -212,6 +220,8 @@ namespace Torch.Managers.PatchManager.MSIL
{ {
} }
public override int MaxBytes => 4;
internal override void Read(MethodContext context, BinaryReader reader) internal override void Read(MethodContext context, BinaryReader reader)
{ {
// ReSharper disable once SwitchStatementMissingSomeCases // ReSharper disable once SwitchStatementMissingSomeCases
@@ -250,6 +260,8 @@ namespace Torch.Managers.PatchManager.MSIL
{ {
} }
public override int MaxBytes => throw new NotImplementedException();
internal override void Read(MethodContext context, BinaryReader reader) internal override void Read(MethodContext context, BinaryReader reader)
{ {
// ReSharper disable once SwitchStatementMissingSomeCases // ReSharper disable once SwitchStatementMissingSomeCases
@@ -286,6 +298,8 @@ namespace Torch.Managers.PatchManager.MSIL
{ {
} }
public override int MaxBytes => Instruction.OpCode.OperandType == OperandType.ShortInlineVar ? 1 : 2;
internal override void Read(MethodContext context, BinaryReader reader) internal override void Read(MethodContext context, BinaryReader reader)
{ {
int id; int id;
@@ -339,6 +353,8 @@ namespace Torch.Managers.PatchManager.MSIL
{ {
} }
public override int MaxBytes => 2;
internal override void Read(MethodContext context, BinaryReader reader) internal override void Read(MethodContext context, BinaryReader reader)
{ {
int id; int id;
@@ -390,6 +406,8 @@ namespace Torch.Managers.PatchManager.MSIL
{ {
} }
public override int MaxBytes => 4;
internal override void Read(MethodContext context, BinaryReader reader) internal override void Read(MethodContext context, BinaryReader reader)
{ {
object value = null; object value = null;

View File

@@ -20,6 +20,7 @@ namespace Torch.Managers.PatchManager.MSIL
/// </summary> /// </summary>
public MsilLabel[] Labels { get; set; } public MsilLabel[] Labels { get; set; }
public override int MaxBytes => 4 + (Labels?.Length * 4 ?? 0);
internal override void CopyTo(MsilOperand operand) internal override void CopyTo(MsilOperand operand)
{ {

View File

@@ -24,6 +24,8 @@ namespace Torch.Managers.PatchManager.MSIL
/// </summary> /// </summary>
public class MsilTryCatchOperation public class MsilTryCatchOperation
{ {
internal int NativeOffset;
/// <summary> /// <summary>
/// Operation type /// Operation type
/// </summary> /// </summary>

View File

@@ -168,18 +168,71 @@ namespace Torch.Managers.PatchManager
/// <summary> /// <summary>
/// Should the resulting MSIL of the transpile operation be printed. /// Should the resulting MSIL of the transpile operation be printed.
/// </summary> /// </summary>
[Obsolete]
public bool PrintMsil public bool PrintMsil
{ {
get => _parent?.PrintMsil ?? _printMsilBacking; get => PrintMode != 0;
set => PrintMode = PrintModeEnum.Emitted;
}
private PrintModeEnum _printMsilBacking;
/// <summary>
/// Types of IL to print to log
/// </summary>
public PrintModeEnum PrintMode
{
get => _parent?.PrintMode ?? _printMsilBacking;
set set
{ {
if (_parent != null) if (_parent != null)
_parent.PrintMsil = value; _parent.PrintMode = value;
else else
_printMsilBacking = value; _printMsilBacking = value;
} }
} }
private bool _printMsilBacking;
[Flags]
public enum PrintModeEnum
{
Original = 1,
Emitted = 2,
Patched = 4,
EmittedReflection = 8
}
/// <summary>
/// File to dump the emitted MSIL to.
/// </summary>
public string DumpTarget
{
get => _parent?.DumpTarget ?? _dumpTargetBacking;
set
{
if (_parent != null)
_parent.DumpTarget = value;
else
_dumpTargetBacking = value;
}
}
/// <summary>
/// Types of IL to dump to file
/// </summary>
public PrintModeEnum DumpMode
{
get => _parent?.DumpMode ?? _dumpTargetMode;
set
{
if (_parent != null)
_parent.DumpMode = value;
else
_dumpTargetMode = value;
}
}
private PrintModeEnum _dumpTargetMode;
private string _dumpTargetBacking;
private readonly MethodRewritePattern _parent; private readonly MethodRewritePattern _parent;

View File

@@ -40,21 +40,25 @@ namespace Torch.Managers.PatchManager
MethodTranspiler.EmitMethod(insn.ToList(), generator); MethodTranspiler.EmitMethod(insn.ToList(), generator);
} }
public delegate void DelPrintIntegrityInfo(bool error, string msg);
/// <summary> /// <summary>
/// Analyzes the integrity of a set of instructions. /// Analyzes the integrity of a set of instructions.
/// </summary> /// </summary>
/// <param name="level">default logging level</param> /// <param name="handler">Logger</param>
/// <param name="instructions">instructions</param> /// <param name="instructions">instructions</param>
public static void IntegrityAnalysis(LogLevel level, IReadOnlyList<MsilInstruction> instructions) public static void IntegrityAnalysis(DelPrintIntegrityInfo handler, IReadOnlyList<MsilInstruction> instructions)
{ {
MethodTranspiler.IntegrityAnalysis(level, instructions); MethodTranspiler.IntegrityAnalysis(handler, instructions);
} }
#pragma warning disable 649 #pragma warning disable 649
[ReflectedStaticMethod(Type = typeof(RuntimeHelpers), Name = "_CompileMethod", OverrideTypeNames = new[] { "System.IRuntimeMethodInfo" })] [ReflectedStaticMethod(Type = typeof(RuntimeHelpers), Name = "_CompileMethod", OverrideTypeNames = new[] {"System.IRuntimeMethodInfo"})]
private static Action<object> _compileDynamicMethod; private static Action<object> _compileDynamicMethod;
[ReflectedMethod(Name = "GetMethodInfo")] [ReflectedMethod(Name = "GetMethodInfo")]
private static Func<RuntimeMethodHandle, object> _getMethodInfo; private static Func<RuntimeMethodHandle, object> _getMethodInfo;
[ReflectedMethod(Name = "GetMethodDescriptor")] [ReflectedMethod(Name = "GetMethodDescriptor")]
private static Func<DynamicMethod, RuntimeMethodHandle> _getMethodHandle; private static Func<DynamicMethod, RuntimeMethodHandle> _getMethodHandle;
#pragma warning restore 649 #pragma warning restore 649

View File

@@ -8,6 +8,7 @@ using System.Reflection.Emit;
using NLog; using NLog;
using Torch.Managers.PatchManager.MSIL; using Torch.Managers.PatchManager.MSIL;
using Torch.Utils; using Torch.Utils;
using VRage.Game.VisualScripting.Utils;
namespace Torch.Managers.PatchManager.Transpile namespace Torch.Managers.PatchManager.Transpile
{ {
@@ -44,14 +45,47 @@ namespace Torch.Managers.PatchManager.Transpile
#pragma warning disable 649 #pragma warning disable 649
[ReflectedMethod(Name = "BakeByteArray")] private static Func<ILGenerator, byte[]> _ilGeneratorBakeByteArray; [ReflectedMethod(Name = "BakeByteArray")]
private static Func<ILGenerator, byte[]> _ilGeneratorBakeByteArray;
[ReflectedMethod(Name = "GetExceptions")]
private static Func<ILGenerator, Array> _ilGeneratorGetExceptionHandlers;
private const string InternalExceptionInfo = "System.Reflection.Emit.__ExceptionInfo, mscorlib";
[ReflectedMethod(Name = "GetExceptionTypes", TypeName = InternalExceptionInfo)]
private static Func<object, int[]> _exceptionHandlerGetTypes;
[ReflectedMethod(Name = "GetStartAddress", TypeName = InternalExceptionInfo)]
private static Func<object, int> _exceptionHandlerGetStart;
[ReflectedMethod(Name = "GetEndAddress", TypeName = InternalExceptionInfo)]
private static Func<object, int> _exceptionHandlerGetEnd;
[ReflectedMethod(Name = "GetFinallyEndAddress", TypeName = InternalExceptionInfo)]
private static Func<object, int> _exceptionHandlerGetFinallyEnd;
[ReflectedMethod(Name = "GetNumberOfCatches", TypeName = InternalExceptionInfo)]
private static Func<object, int> _exceptionHandlerGetCatchCount;
[ReflectedMethod(Name = "GetCatchAddresses", TypeName = InternalExceptionInfo)]
private static Func<object, int[]> _exceptionHandlerGetCatchAddrs;
[ReflectedMethod(Name = "GetCatchEndAddresses", TypeName = InternalExceptionInfo)]
private static Func<object, int[]> _exceptionHandlerGetCatchEndAddrs;
[ReflectedMethod(Name = "GetFilterAddresses", TypeName = InternalExceptionInfo)]
private static Func<object, int[]> _exceptionHandlerGetFilterAddrs;
#pragma warning restore 649 #pragma warning restore 649
private readonly Array _dynamicExceptionTable;
public MethodContext(DynamicMethod method) public MethodContext(DynamicMethod method)
{ {
Method = null; Method = null;
MethodBody = null; MethodBody = null;
_msilBytes = _ilGeneratorBakeByteArray(method.GetILGenerator()); _msilBytes = _ilGeneratorBakeByteArray(method.GetILGenerator());
_dynamicExceptionTable = _ilGeneratorGetExceptionHandlers(method.GetILGenerator());
TokenResolver = new DynamicMethodTokenResolver(method); TokenResolver = new DynamicMethodTokenResolver(method);
} }
@@ -76,6 +110,7 @@ namespace Torch.Managers.PatchManager.Transpile
{ {
instructionValue = (short) ((instructionValue << 8) | memory.ReadByte()); instructionValue = (short) ((instructionValue << 8) | memory.ReadByte());
} }
if (!OpCodeLookup.TryGetValue(instructionValue, out OpCode opcode)) if (!OpCodeLookup.TryGetValue(instructionValue, out OpCode opcode))
{ {
var msg = $"Unknown opcode {instructionValue:X}"; var msg = $"Unknown opcode {instructionValue:X}";
@@ -83,6 +118,7 @@ namespace Torch.Managers.PatchManager.Transpile
Debug.Assert(false, msg); Debug.Assert(false, msg);
continue; continue;
} }
if (opcode.Size != memory.Position - opcodeOffset) if (opcode.Size != memory.Position - opcodeOffset)
throw new Exception( throw new Exception(
$"Opcode said it was {opcode.Size} but we read {memory.Position - opcodeOffset}"); $"Opcode said it was {opcode.Size} but we read {memory.Position - opcodeOffset}");
@@ -97,28 +133,56 @@ namespace Torch.Managers.PatchManager.Transpile
private void ResolveCatchClauses() private void ResolveCatchClauses()
{ {
if (MethodBody == null) if (MethodBody != null)
return; foreach (var clause in MethodBody.ExceptionHandlingClauses)
foreach (ExceptionHandlingClause clause in MethodBody.ExceptionHandlingClauses)
{ {
var beginInstruction = FindInstruction(clause.TryOffset); AddEhHandler(clause.TryOffset, MsilTryCatchOperationType.BeginExceptionBlock);
var catchInstruction = FindInstruction(clause.HandlerOffset);
var finalInstruction = FindInstruction(clause.HandlerOffset + clause.HandlerLength);
beginInstruction.TryCatchOperation =
new MsilTryCatchOperation(MsilTryCatchOperationType.BeginExceptionBlock);
if ((clause.Flags & ExceptionHandlingClauseOptions.Fault) != 0) if ((clause.Flags & ExceptionHandlingClauseOptions.Fault) != 0)
catchInstruction.TryCatchOperation = AddEhHandler(clause.HandlerOffset, MsilTryCatchOperationType.BeginFaultBlock);
new MsilTryCatchOperation(MsilTryCatchOperationType.BeginFaultBlock);
else if ((clause.Flags & ExceptionHandlingClauseOptions.Finally) != 0) else if ((clause.Flags & ExceptionHandlingClauseOptions.Finally) != 0)
catchInstruction.TryCatchOperation = AddEhHandler(clause.HandlerOffset, MsilTryCatchOperationType.BeginFinallyBlock);
new MsilTryCatchOperation(MsilTryCatchOperationType.BeginFinallyBlock);
else else
catchInstruction.TryCatchOperation = AddEhHandler(clause.HandlerOffset, MsilTryCatchOperationType.BeginClauseBlock, clause.CatchType);
new MsilTryCatchOperation(MsilTryCatchOperationType.BeginClauseBlock, clause.CatchType); AddEhHandler(clause.HandlerOffset + clause.HandlerLength, MsilTryCatchOperationType.EndExceptionBlock);
finalInstruction.TryCatchOperation =
new MsilTryCatchOperation(MsilTryCatchOperationType.EndExceptionBlock);
} }
if (_dynamicExceptionTable != null)
foreach (var eh in _dynamicExceptionTable)
{
var catchCount = _exceptionHandlerGetCatchCount(eh);
var exTypes = _exceptionHandlerGetTypes(eh);
var exCatches = _exceptionHandlerGetCatchAddrs(eh);
var exCatchesEnd = _exceptionHandlerGetCatchEndAddrs(eh);
var exFilters = _exceptionHandlerGetFilterAddrs(eh);
var tryAddr = _exceptionHandlerGetStart(eh);
var endAddr = _exceptionHandlerGetEnd(eh);
var endFinallyAddr = _exceptionHandlerGetFinallyEnd(eh);
for (var i = 0; i < catchCount; i++)
{
var flags = (ExceptionHandlingClauseOptions) exTypes[i];
var endAddress = (flags & ExceptionHandlingClauseOptions.Finally) != 0 ? endFinallyAddr : endAddr;
var catchAddr = exCatches[i];
var catchEndAddr = exCatchesEnd[i];
var filterAddr = exFilters[i];
AddEhHandler(tryAddr, MsilTryCatchOperationType.BeginExceptionBlock);
if ((flags & ExceptionHandlingClauseOptions.Fault) != 0)
AddEhHandler(catchAddr, MsilTryCatchOperationType.BeginFaultBlock);
else if ((flags & ExceptionHandlingClauseOptions.Finally) != 0)
AddEhHandler(catchAddr, MsilTryCatchOperationType.BeginFinallyBlock);
else
AddEhHandler(catchAddr, MsilTryCatchOperationType.BeginClauseBlock);
AddEhHandler(catchEndAddr, MsilTryCatchOperationType.EndExceptionBlock);
}
}
}
private void AddEhHandler(int offset, MsilTryCatchOperationType op, Type type = null)
{
var instruction = FindInstruction(offset);
instruction.TryCatchOperations.Add(new MsilTryCatchOperation(op, type) {NativeOffset = offset});
instruction.TryCatchOperations.Sort((a, b) => a.NativeOffset.CompareTo(b.NativeOffset));
} }
public MsilInstruction FindInstruction(int offset) public MsilInstruction FindInstruction(int offset)
@@ -132,6 +196,7 @@ namespace Torch.Managers.PatchManager.Transpile
else else
max = mid; max = mid;
} }
return min >= 0 && min < _instructions.Count ? _instructions[min] : null; return min >= 0 && min < _instructions.Count ? _instructions[min] : null;
} }

View File

@@ -14,8 +14,8 @@ namespace Torch.Managers.PatchManager.Transpile
{ {
public static readonly Logger _log = LogManager.GetCurrentClassLogger(); public static readonly Logger _log = LogManager.GetCurrentClassLogger();
internal static IEnumerable<MsilInstruction> Transpile(MethodBase baseMethod, Func<Type, MsilLocal> localCreator, internal static IEnumerable<MsilInstruction> Transpile(MethodBase baseMethod, Func<Type, MsilLocal> localCreator, IEnumerable<MethodInfo> transpilers,
IEnumerable<MethodInfo> transpilers, MsilLabel retLabel) MsilLabel retLabel)
{ {
var context = new MethodContext(baseMethod); var context = new MethodContext(baseMethod);
context.Read(); context.Read();
@@ -24,8 +24,7 @@ namespace Torch.Managers.PatchManager.Transpile
} }
internal static IEnumerable<MsilInstruction> Transpile(MethodBase baseMethod, IEnumerable<MsilInstruction> methodContent, internal static IEnumerable<MsilInstruction> Transpile(MethodBase baseMethod, IEnumerable<MsilInstruction> methodContent,
Func<Type, MsilLocal> localCreator, Func<Type, MsilLocal> localCreator, IEnumerable<MethodInfo> transpilers, MsilLabel retLabel)
IEnumerable<MethodInfo> transpilers, MsilLabel retLabel)
{ {
foreach (MethodInfo transpiler in transpilers) foreach (MethodInfo transpiler in transpilers)
{ {
@@ -44,24 +43,61 @@ namespace Torch.Managers.PatchManager.Transpile
throw new ArgumentException( throw new ArgumentException(
$"Bad transpiler parameter type {parameter.ParameterType.FullName} {parameter.Name}"); $"Bad transpiler parameter type {parameter.ParameterType.FullName} {parameter.Name}");
} }
methodContent = (IEnumerable<MsilInstruction>)transpiler.Invoke(null, paramList.ToArray());
methodContent = (IEnumerable<MsilInstruction>) transpiler.Invoke(null, paramList.ToArray());
} }
return FixBranchAndReturn(methodContent, retLabel); return FixBranchAndReturn(methodContent, retLabel);
} }
internal static void EmitMethod(IReadOnlyList<MsilInstruction> instructions, LoggingIlGenerator target) internal static void EmitMethod(IReadOnlyList<MsilInstruction> source, LoggingIlGenerator target)
{ {
for (var i = 0; i < instructions.Count; i++) var instructions = source.ToArray();
var offsets = new int[instructions.Length];
// Calc worst case offsets
{
var j = 0;
for (var i = 0; i < instructions.Length; i++)
{
offsets[i] = j;
j += instructions[i].MaxBytes;
}
}
// Perform label markup
var targets = new Dictionary<MsilLabel, int>();
for (var i = 0; i < instructions.Length; i++)
foreach (var label in instructions[i].Labels)
{
if (targets.TryGetValue(label, out var other))
_log.Warn($"Label {label} is applied to ({i}: {instructions[i]}) and ({other}: {instructions[other]})");
targets[label] = i;
}
// Simplify branch instructions
for (var i = 0; i < instructions.Length; i++)
{
var existing = instructions[i];
if (existing.Operand is MsilOperandBrTarget brOperand && _longToShortBranch.TryGetValue(existing.OpCode, out var shortOpcode))
{
var targetIndex = targets[brOperand.Target];
var delta = offsets[targetIndex] - offsets[i];
if (sbyte.MinValue < delta && delta < sbyte.MaxValue)
instructions[i] = instructions[i].CopyWith(shortOpcode);
}
}
for (var i = 0; i < instructions.Length; i++)
{ {
MsilInstruction il = instructions[i]; MsilInstruction il = instructions[i];
if (il.TryCatchOperation != null) foreach (var tro in il.TryCatchOperations)
switch (il.TryCatchOperation.Type) switch (tro.Type)
{ {
case MsilTryCatchOperationType.BeginExceptionBlock: case MsilTryCatchOperationType.BeginExceptionBlock:
target.BeginExceptionBlock(); target.BeginExceptionBlock();
break; break;
case MsilTryCatchOperationType.BeginClauseBlock: case MsilTryCatchOperationType.BeginClauseBlock:
target.BeginCatchBlock(il.TryCatchOperation.CatchType); target.BeginCatchBlock(tro.CatchType);
break; break;
case MsilTryCatchOperationType.BeginFaultBlock: case MsilTryCatchOperationType.BeginFaultBlock:
target.BeginFaultBlock(); target.BeginFaultBlock();
@@ -79,20 +115,23 @@ namespace Torch.Managers.PatchManager.Transpile
foreach (MsilLabel label in il.Labels) foreach (MsilLabel label in il.Labels)
target.MarkLabel(label.LabelFor(target)); target.MarkLabel(label.LabelFor(target));
MsilInstruction ilNext = i < instructions.Count - 1 ? instructions[i + 1] : null; MsilInstruction ilNext = i < instructions.Length - 1 ? instructions[i + 1] : null;
// Leave opcodes emitted by these: // Leave opcodes emitted by these:
if (il.OpCode == OpCodes.Endfilter && ilNext?.TryCatchOperation?.Type == if (il.OpCode == OpCodes.Endfilter && ilNext != null &&
MsilTryCatchOperationType.BeginClauseBlock) ilNext.TryCatchOperations.Any(x => x.Type == MsilTryCatchOperationType.BeginClauseBlock))
continue; continue;
if ((il.OpCode == OpCodes.Leave || il.OpCode == OpCodes.Leave_S) && if ((il.OpCode == OpCodes.Leave || il.OpCode == OpCodes.Leave_S) && ilNext != null &&
(ilNext?.TryCatchOperation?.Type == MsilTryCatchOperationType.EndExceptionBlock || ilNext.TryCatchOperations.Any(x => x.Type == MsilTryCatchOperationType.EndExceptionBlock ||
ilNext?.TryCatchOperation?.Type == MsilTryCatchOperationType.BeginClauseBlock || x.Type == MsilTryCatchOperationType.BeginClauseBlock ||
ilNext?.TryCatchOperation?.Type == MsilTryCatchOperationType.BeginFaultBlock || x.Type == MsilTryCatchOperationType.BeginFaultBlock ||
ilNext?.TryCatchOperation?.Type == MsilTryCatchOperationType.BeginFinallyBlock)) x.Type == MsilTryCatchOperationType.BeginFinallyBlock))
continue; continue;
if ((il.OpCode == OpCodes.Leave || il.OpCode == OpCodes.Leave_S || il.OpCode == OpCodes.Endfinally) && if ((il.OpCode == OpCodes.Leave || il.OpCode == OpCodes.Leave_S || il.OpCode == OpCodes.Endfinally) &&
ilNext?.TryCatchOperation?.Type == MsilTryCatchOperationType.EndExceptionBlock) ilNext != null && ilNext.TryCatchOperations.Any(x => x.Type == MsilTryCatchOperationType.EndExceptionBlock))
continue;
if (il.OpCode == OpCodes.Endfinally && ilNext != null &&
ilNext.TryCatchOperations.Any(x => x.Type == MsilTryCatchOperationType.EndExceptionBlock))
continue; continue;
if (il.Operand != null) if (il.Operand != null)
@@ -107,7 +146,7 @@ namespace Torch.Managers.PatchManager.Transpile
/// </summary> /// </summary>
/// <param name="level">default logging level</param> /// <param name="level">default logging level</param>
/// <param name="instructions">instructions</param> /// <param name="instructions">instructions</param>
public static void IntegrityAnalysis(LogLevel level, IReadOnlyList<MsilInstruction> instructions) public static void IntegrityAnalysis(PatchUtilities.DelPrintIntegrityInfo log, IReadOnlyList<MsilInstruction> instructions, bool offests = false)
{ {
var targets = new Dictionary<MsilLabel, int>(); var targets = new Dictionary<MsilLabel, int>();
for (var i = 0; i < instructions.Count; i++) for (var i = 0; i < instructions.Count; i++)
@@ -118,15 +157,46 @@ namespace Torch.Managers.PatchManager.Transpile
targets[label] = i; targets[label] = i;
} }
var simpleLabelNames = new Dictionary<MsilLabel, string>();
foreach (var lbl in targets.OrderBy(x => x.Value))
simpleLabelNames.Add(lbl.Key, "L" + simpleLabelNames.Count);
var reparsed = new HashSet<MsilLabel>(); var reparsed = new HashSet<MsilLabel>();
var labelStackSize = new Dictionary<MsilLabel, Dictionary<int, int>>(); var labelStackSize = new Dictionary<MsilLabel, Dictionary<int, int>>();
var stack = 0; var stack = 0;
var unreachable = false; var unreachable = false;
var data = new StringBuilder[instructions.Count]; var data = new StringBuilder[instructions.Count];
for (var i = 0; i < instructions.Count; i++) var tryCatchDepth = new int[instructions.Count];
for (var i = 0; i < instructions.Count - 1; i++)
{ {
var k = instructions[i];
var prevDepth = i > 0 ? tryCatchDepth[i] : 0;
var currentDepth = prevDepth;
foreach (var tro in k.TryCatchOperations)
if (tro.Type == MsilTryCatchOperationType.BeginExceptionBlock)
currentDepth++;
else if (tro.Type == MsilTryCatchOperationType.EndExceptionBlock)
currentDepth--;
tryCatchDepth[i + 1] = currentDepth;
}
for (var i = 0; i < instructions.Count; i++)
{
var tryCatchDepthSelf = tryCatchDepth[i];
var k = instructions[i]; var k = instructions[i];
var line = (data[i] ?? (data[i] = new StringBuilder())).Clear(); var line = (data[i] ?? (data[i] = new StringBuilder())).Clear();
foreach (var tro in k.TryCatchOperations)
{
if (tro.Type == MsilTryCatchOperationType.BeginExceptionBlock)
tryCatchDepthSelf++;
line.AppendLine($"{new string(' ', (tryCatchDepthSelf - 1) * 2)}// {tro.Type} ({tro.CatchType}) ({tro.NativeOffset:X4})");
if (tro.Type == MsilTryCatchOperationType.EndExceptionBlock)
tryCatchDepthSelf--;
}
var tryCatchIndent = new string(' ', tryCatchDepthSelf * 2);
if (!unreachable) if (!unreachable)
{ {
foreach (var label in k.Labels) foreach (var label in k.Labels)
@@ -138,59 +208,83 @@ namespace Torch.Managers.PatchManager.Transpile
if (otherStack.Values.Distinct().Count() > 1 || (otherStack.Count == 1 && !otherStack.ContainsValue(stack))) if (otherStack.Values.Distinct().Count() > 1 || (otherStack.Count == 1 && !otherStack.ContainsValue(stack)))
{ {
string otherDesc = string.Join(", ", otherStack.Select(x => $"{x.Key:X4}=>{x.Value}")); string otherDesc = string.Join(", ", otherStack.Select(x => $"{x.Key:X4}=>{x.Value}"));
line.AppendLine($"WARN// | Label {label} has multiple entry stack sizes ({otherDesc})"); line.AppendLine($"WARN{tryCatchIndent}// | Label {simpleLabelNames[label]} has multiple entry stack sizes ({otherDesc})");
} }
} }
} }
foreach (var label in k.Labels) foreach (var label in k.Labels)
{ {
if (!labelStackSize.TryGetValue(label, out var entry)) if (!labelStackSize.TryGetValue(label, out var entry))
{
line.AppendLine($"{tryCatchIndent}// \\/ Label {simpleLabelNames[label]}");
continue; continue;
}
string desc = string.Join(", ", entry.Select(x => $"{x.Key:X4}=>{x.Value}")); string desc = string.Join(", ", entry.Select(x => $"{x.Key:X4}=>{x.Value}"));
line.AppendLine($"// \\/ Label {label} has stack sizes {desc}"); line.AppendLine($"{tryCatchIndent}// \\/ Label {simpleLabelNames[label]} has stack sizes {desc}");
if (unreachable && entry.Any()) if (unreachable && entry.Any())
{ {
stack = entry.Values.First(); stack = entry.Values.First();
unreachable = false; unreachable = false;
} }
} }
stack += k.StackChange();
if (k.TryCatchOperation != null)
line.AppendLine($"// .{k.TryCatchOperation.Type} ({k.TryCatchOperation.CatchType})");
line.AppendLine($"{i:X4} S:{stack:D2} dS:{k.StackChange():+0;-#}\t{k}" + (unreachable ? "\t// UNREACHABLE" : ""));
if (k.Operand is MsilOperandBrTarget br)
{
if (!targets.ContainsKey(br.Target))
line.AppendLine($"WARN// ^ Unknown target {br.Target}");
if (!labelStackSize.TryGetValue(br.Target, out Dictionary<int, int> otherStack)) if (k.TryCatchOperations.Any(x => x.Type == MsilTryCatchOperationType.BeginClauseBlock))
labelStackSize[br.Target] = otherStack = new Dictionary<int, int>(); stack++; // Exception info
stack += k.StackChange();
line.Append($"{tryCatchIndent}{(offests ? k.Offset : i):X4} S:{stack:D2} dS:{k.StackChange():+0;-#}\t{k.OpCode}\t");
if (k.Operand is MsilOperandBrTarget bri)
line.Append(simpleLabelNames[bri.Target]);
else
line.Append(k.Operand);
line.AppendLine($"\t{(unreachable ? "\t// UNREACHABLE" : "")}");
MsilLabel[] branchTargets = null;
if (k.Operand is MsilOperandBrTarget br)
branchTargets = new[] {br.Target};
else if (k.Operand is MsilOperandSwitch swi)
branchTargets = swi.Labels;
if (branchTargets != null)
{
var foundUnprocessed = false;
foreach (var brTarget in branchTargets)
{
if (!labelStackSize.TryGetValue(brTarget, out Dictionary<int, int> otherStack))
labelStackSize[brTarget] = otherStack = new Dictionary<int, int>();
otherStack[i] = stack; otherStack[i] = stack;
if (otherStack.Values.Distinct().Count() > 1 || (otherStack.Count == 1 && !otherStack.ContainsValue(stack))) if (otherStack.Values.Distinct().Count() > 1 || (otherStack.Count == 1 && !otherStack.ContainsValue(stack)))
{ {
string otherDesc = string.Join(", ", otherStack.Select(x => $"{x.Key:X4}=>{x.Value}")); string otherDesc = string.Join(", ", otherStack.Select(x => $"{x.Key:X4}=>{x.Value}"));
line.AppendLine($"WARN// ^ Label {br.Target} has multiple entry stack sizes ({otherDesc})"); line.AppendLine($"WARN{tryCatchIndent}// ^ Label {simpleLabelNames[brTarget]} has multiple entry stack sizes ({otherDesc})");
} }
if (targets.TryGetValue(br.Target, out var target) && target < i && reparsed.Add(br.Target))
if (targets.TryGetValue(brTarget, out var target) && target < i && reparsed.Add(brTarget))
{ {
i = target - 1; i = target - 1;
unreachable = false; unreachable = false;
foundUnprocessed = true;
break;
}
}
if (foundUnprocessed)
continue; continue;
} }
}
if (k.OpCode == OpCodes.Br || k.OpCode == OpCodes.Br_S || k.OpCode == OpCodes.Leave || k.OpCode == OpCodes.Leave_S) if (k.OpCode == OpCodes.Br || k.OpCode == OpCodes.Br_S || k.OpCode == OpCodes.Leave || k.OpCode == OpCodes.Leave_S)
unreachable = true; unreachable = true;
} }
foreach (var k in data) foreach (var k in data)
foreach (var line in k.ToString().Split('\n')) foreach (var line in k.ToString().Split('\n'))
{ {
if (string.IsNullOrWhiteSpace(line)) if (string.IsNullOrWhiteSpace(line))
continue; continue;
if (line.StartsWith("WARN", StringComparison.OrdinalIgnoreCase)) if (line.StartsWith("WARN", StringComparison.OrdinalIgnoreCase))
_log.Warn(line.Substring(4).Trim()); log(true, line.Substring(4).Trim());
else else
_log.Log(level, line.Trim()); log(false, line.Trim('\n', '\r'));
} }
} }
@@ -204,7 +298,7 @@ namespace Torch.Managers.PatchManager.Transpile
_log.Trace($"Replacing {i} with {j}"); _log.Trace($"Replacing {i} with {j}");
yield return j; yield return j;
} }
else if (_opcodeReplaceRule.TryGetValue(i.OpCode, out OpCode replaceOpcode)) else if (_shortToLongBranch.TryGetValue(i.OpCode, out OpCode replaceOpcode))
{ {
var result = i.CopyWith(replaceOpcode); var result = i.CopyWith(replaceOpcode);
_log.Trace($"Replacing {i} with {result}"); _log.Trace($"Replacing {i} with {result}");
@@ -215,23 +309,31 @@ namespace Torch.Managers.PatchManager.Transpile
} }
} }
private static readonly Dictionary<OpCode, OpCode> _opcodeReplaceRule; private static readonly Dictionary<OpCode, OpCode> _shortToLongBranch;
private static readonly Dictionary<OpCode, OpCode> _longToShortBranch;
static MethodTranspiler() static MethodTranspiler()
{ {
_opcodeReplaceRule = new Dictionary<OpCode, OpCode>(); _shortToLongBranch = new Dictionary<OpCode, OpCode>();
_longToShortBranch = new Dictionary<OpCode, OpCode>();
foreach (var field in typeof(OpCodes).GetFields(BindingFlags.Static | BindingFlags.Public)) foreach (var field in typeof(OpCodes).GetFields(BindingFlags.Static | BindingFlags.Public))
{ {
var opcode = (OpCode)field.GetValue(null); var opcode = (OpCode) field.GetValue(null);
if (opcode.OperandType == OperandType.ShortInlineBrTarget && if (opcode.OperandType == OperandType.ShortInlineBrTarget &&
opcode.Name.EndsWith(".s", StringComparison.OrdinalIgnoreCase)) opcode.Name.EndsWith(".s", StringComparison.OrdinalIgnoreCase))
{ {
var other = (OpCode?)typeof(OpCodes).GetField(field.Name.Substring(0, field.Name.Length - 2), var other = (OpCode?) typeof(OpCodes).GetField(field.Name.Substring(0, field.Name.Length - 2),
BindingFlags.Static | BindingFlags.Public)?.GetValue(null); BindingFlags.Static | BindingFlags.Public)?.GetValue(null);
if (other.HasValue && other.Value.OperandType == OperandType.InlineBrTarget) if (other.HasValue && other.Value.OperandType == OperandType.InlineBrTarget)
_opcodeReplaceRule.Add(opcode, other.Value); {
_shortToLongBranch.Add(opcode, other.Value);
_longToShortBranch.Add(other.Value, opcode);
} }
} }
_opcodeReplaceRule[OpCodes.Leave_S] = OpCodes.Leave; }
_shortToLongBranch[OpCodes.Leave_S] = OpCodes.Leave;
_longToShortBranch[OpCodes.Leave] = OpCodes.Leave_S;
} }
} }
} }