diff --git a/Torch.Tests/PatchTest.cs b/Torch.Tests/PatchTest.cs new file mode 100644 index 0000000..1ee4225 --- /dev/null +++ b/Torch.Tests/PatchTest.cs @@ -0,0 +1,386 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Reflection; +using System.Reflection.Emit; +using System.Runtime.CompilerServices; +using System.Text; +using Torch.Managers.PatchManager; +using Torch.Managers.PatchManager.MSIL; +using Torch.Utils; +using Xunit; + +// ReSharper disable UnusedMember.Local +namespace Torch.Tests +{ +#pragma warning disable 414 + public class PatchTest + { + #region TestRunner + private static readonly PatchManager _patchContext = new PatchManager(null); + + [Theory] + [MemberData(nameof(Prefixes))] + public void TestPrefix(TestBootstrap runner) + { + runner.TestPrefix(); + } + + [Theory] + [MemberData(nameof(Transpilers))] + public void TestTranspile(TestBootstrap runner) + { + runner.TestTranspile(); + } + + [Theory] + [MemberData(nameof(Suffixes))] + public void TestSuffix(TestBootstrap runner) + { + runner.TestSuffix(); + } + + [Theory] + [MemberData(nameof(Combo))] + public void TestCombo(TestBootstrap runner) + { + runner.TestCombo(); + } + + + + public class TestBootstrap + { + public bool HasPrefix => _prefixMethod != null; + public bool HasTranspile => _transpileMethod != null; + public bool HasSuffix => _suffixMethod != null; + + private readonly MethodInfo _prefixMethod, _prefixAssert; + private readonly MethodInfo _suffixMethod, _suffixAssert; + private readonly MethodInfo _transpileMethod, _transpileAssert; + private readonly MethodInfo _targetMethod, _targetAssert; + private readonly MethodInfo _resetMethod; + private readonly object _instance; + private readonly object[] _targetParams; + private readonly Type _type; + + public TestBootstrap(Type t) + { + const BindingFlags flags = BindingFlags.NonPublic | BindingFlags.Public | BindingFlags.Static | BindingFlags.Instance; + _type = t; + _prefixMethod = t.GetMethod("Prefix", flags); + _prefixAssert = t.GetMethod("AssertPrefix", flags); + _suffixMethod = t.GetMethod("Suffix", flags); + _suffixAssert = t.GetMethod("AssertSuffix", flags); + _transpileMethod = t.GetMethod("Transpile", flags); + _transpileAssert = t.GetMethod("AssertTranspile", flags); + _targetMethod = t.GetMethod("Target", flags); + _targetAssert = t.GetMethod("AssertNormal", flags); + _resetMethod = t.GetMethod("Reset", flags); + if (_targetMethod == null) + throw new Exception($"{t.FullName} must have a method named Target"); + if (_targetAssert == null) + throw new Exception($"{t.FullName} must have a method named AssertNormal"); + _instance = !_targetMethod.IsStatic ? Activator.CreateInstance(t) : null; + _targetParams = (object[])t.GetField("_targetParams", flags)?.GetValue(null) ?? new object[0]; + } + + private void Invoke(MethodBase i, params object[] args) + { + if (i == null) return; + i.Invoke(i.IsStatic ? null : _instance, args); + } + + private void Invoke() + { + _targetMethod.Invoke(_instance, _targetParams); + Invoke(_targetAssert); + } + + public void TestPrefix() + { + Invoke(_resetMethod); + PatchContext context = _patchContext.AcquireContext(); + context.GetPattern(_targetMethod).Prefixes.Add(_prefixMethod); + _patchContext.Commit(); + + Invoke(); + Invoke(_prefixAssert); + + _patchContext.FreeContext(context); + _patchContext.Commit(); + } + + public void TestSuffix() + { + Invoke(_resetMethod); + PatchContext context = _patchContext.AcquireContext(); + context.GetPattern(_targetMethod).Suffixes.Add(_suffixMethod); + _patchContext.Commit(); + + Invoke(); + Invoke(_suffixAssert); + + _patchContext.FreeContext(context); + _patchContext.Commit(); + } + + public void TestTranspile() + { + Invoke(_resetMethod); + PatchContext context = _patchContext.AcquireContext(); + context.GetPattern(_targetMethod).Transpilers.Add(_transpileMethod); + _patchContext.Commit(); + + Invoke(); + Invoke(_transpileAssert); + + _patchContext.FreeContext(context); + _patchContext.Commit(); + } + + public void TestCombo() + { + Invoke(_resetMethod); + PatchContext context = _patchContext.AcquireContext(); + if (_prefixMethod != null) + context.GetPattern(_targetMethod).Prefixes.Add(_prefixMethod); + if (_transpileMethod != null) + context.GetPattern(_targetMethod).Transpilers.Add(_transpileMethod); + if (_suffixMethod != null) + context.GetPattern(_targetMethod).Suffixes.Add(_suffixMethod); + _patchContext.Commit(); + + Invoke(); + Invoke(_prefixAssert); + Invoke(_transpileAssert); + Invoke(_suffixAssert); + + _patchContext.FreeContext(context); + _patchContext.Commit(); + } + + public override string ToString() + { + return _type.Name; + } + } + + private class PatchTestAttribute : Attribute + { + } + + private static readonly List _patchTest; + + static PatchTest() + { + TestUtils.Init(); + foreach (Type type in typeof(PatchManager).Assembly.GetTypes()) + if (type.Namespace?.StartsWith(typeof(PatchManager).Namespace ?? "") ?? false) + ReflectedManager.Process(type); + + _patchTest = new List(); + foreach (Type type in typeof(PatchTest).GetNestedTypes(BindingFlags.NonPublic)) + if (type.GetCustomAttribute(typeof(PatchTestAttribute)) != null) + _patchTest.Add(new TestBootstrap(type)); + } + + public static IEnumerable Prefixes => _patchTest.Where(x => x.HasPrefix).Select(x => new object[] { x }); + public static IEnumerable Transpilers => _patchTest.Where(x => x.HasTranspile).Select(x => new object[] { x }); + public static IEnumerable Suffixes => _patchTest.Where(x => x.HasSuffix).Select(x => new object[] { x }); + public static IEnumerable Combo => _patchTest.Where(x => x.HasPrefix || x.HasTranspile || x.HasSuffix).Select(x => new object[] { x }); + #endregion + + #region PatchTests + + [PatchTest] + private class StaticNoRetNoParm + { + private static bool _prefixHit, _normalHit, _suffixHit, _transpileHit; + + [MethodImpl(MethodImplOptions.NoInlining)] + public static void Prefix() + { + _prefixHit = true; + } + + [MethodImpl(MethodImplOptions.NoInlining)] + public static void Target() + { + _normalHit = true; + } + + [MethodImpl(MethodImplOptions.NoInlining)] + public static void Suffix() + { + _suffixHit = true; + } + + public static IEnumerable Transpile(IEnumerable instructions) + { + yield return new MsilInstruction(OpCodes.Ldnull); + yield return new MsilInstruction(OpCodes.Ldc_I4_1); + yield return new MsilInstruction(OpCodes.Stfld).InlineValue(typeof(StaticNoRetNoParm).GetField("_transpileHit", BindingFlags.Static | BindingFlags.NonPublic | BindingFlags.Public)); + foreach (MsilInstruction i in instructions) + yield return i; + } + + public static void Reset() + { + _prefixHit = _normalHit = _suffixHit = _transpileHit = false; + } + + public static void AssertTranspile() + { + Assert.True(_transpileHit, "Failed to transpile"); + } + + public static void AssertSuffix() + { + Assert.True(_suffixHit, "Failed to suffix"); + } + + public static void AssertNormal() + { + Assert.True(_normalHit, "Failed to execute normally"); + } + + public static void AssertPrefix() + { + Assert.True(_prefixHit, "Failed to prefix"); + } + } + + [PatchTest] + private class StaticNoRetParam + { + private static bool _prefixHit, _normalHit, _suffixHit; + private static readonly object[] _targetParams = { "test", 1, new StringBuilder("test1") }; + + [MethodImpl(MethodImplOptions.NoInlining)] + public static void Prefix(string str, int i, StringBuilder o) + { + Assert.Equal(_targetParams[0], str); + Assert.Equal(_targetParams[1], i); + Assert.Equal(_targetParams[2], o); + _prefixHit = true; + } + + [MethodImpl(MethodImplOptions.NoInlining)] + public static void Target(string str, int i, StringBuilder o) + { + _normalHit = true; + } + + [MethodImpl(MethodImplOptions.NoInlining)] + public static void Suffix(string str, int i, StringBuilder o) + { + Assert.Equal(_targetParams[0], str); + Assert.Equal(_targetParams[1], i); + Assert.Equal(_targetParams[2], o); + _suffixHit = true; + } + + public static void Reset() + { + _prefixHit = _normalHit = _suffixHit = false; + } + + public static void AssertSuffix() + { + Assert.True(_suffixHit, "Failed to suffix"); + } + + public static void AssertNormal() + { + Assert.True(_normalHit, "Failed to execute normally"); + } + + public static void AssertPrefix() + { + Assert.True(_prefixHit, "Failed to prefix"); + } + } + + [PatchTest] + private class StaticNoRetParamReplace + { + private static bool _prefixHit, _normalHit, _suffixHit; + private static readonly object[] _targetParams = { "test", 1, new StringBuilder("stest1") }; + private static readonly object[] _replacedParams = { "test2", 2, new StringBuilder("stest2") }; + private static object[] _calledParams; + + [MethodImpl(MethodImplOptions.NoInlining)] + public static void Prefix(ref string str, ref int i, ref StringBuilder o) + { + Assert.Equal(_targetParams[0], str); + Assert.Equal(_targetParams[1], i); + Assert.Equal(_targetParams[2], o); + str = (string)_replacedParams[0]; + i = (int)_replacedParams[1]; + o = (StringBuilder)_replacedParams[2]; + _prefixHit = true; + } + + [MethodImpl(MethodImplOptions.NoInlining)] + public static void Target(string str, int i, StringBuilder o) + { + _calledParams = new object[] { str, i, o }; + _normalHit = true; + } + + public static void Reset() + { + _prefixHit = _normalHit = _suffixHit = false; + } + + public static void AssertNormal() + { + Assert.True(_normalHit, "Failed to execute normally"); + } + + public static void AssertPrefix() + { + Assert.True(_prefixHit, "Failed to prefix"); + for (var i = 0; i < 3; i++) + Assert.Equal(_replacedParams[i], _calledParams[i]); + } + } + + [PatchTest] + private class StaticCancelExec + { + private static bool _prefixHit, _normalHit, _suffixHit; + + [MethodImpl(MethodImplOptions.NoInlining)] + public static bool Prefix() + { + _prefixHit = true; + return false; + } + + [MethodImpl(MethodImplOptions.NoInlining)] + public static void Target() + { + _normalHit = true; + } + + public static void Reset() + { + _prefixHit = _normalHit = _suffixHit = false; + } + + public static void AssertNormal() + { + Assert.False(_normalHit, "Executed normally when canceled"); + } + + public static void AssertPrefix() + { + Assert.True(_prefixHit, "Failed to prefix"); + } + } + #endregion + } +#pragma warning restore 414 +} diff --git a/Torch.Tests/Torch.Tests.csproj b/Torch.Tests/Torch.Tests.csproj index fc5f81a..d99bc0b 100644 --- a/Torch.Tests/Torch.Tests.csproj +++ b/Torch.Tests/Torch.Tests.csproj @@ -63,6 +63,7 @@ Properties\AssemblyVersion.cs + diff --git a/Torch/Managers/PatchManager/DecoratedMethod.cs b/Torch/Managers/PatchManager/DecoratedMethod.cs index 8c03138..025b61a 100644 --- a/Torch/Managers/PatchManager/DecoratedMethod.cs +++ b/Torch/Managers/PatchManager/DecoratedMethod.cs @@ -49,6 +49,7 @@ namespace Torch.Managers.PatchManager AssemblyMemory.WriteMemory(_revertAddress, _revertData); _revertData = null; _pinnedPatch.Value.Free(); + _pinnedPatch = null; } } diff --git a/Torch/Managers/PatchManager/Transpile/LoggingILGenerator.cs b/Torch/Managers/PatchManager/Transpile/LoggingILGenerator.cs index a170d0c..e18eb30 100644 --- a/Torch/Managers/PatchManager/Transpile/LoggingILGenerator.cs +++ b/Torch/Managers/PatchManager/Transpile/LoggingILGenerator.cs @@ -37,7 +37,7 @@ namespace Torch.Managers.PatchManager.Transpile public LocalBuilder DeclareLocal(Type localType, bool isPinned = false) { LocalBuilder res = Backing.DeclareLocal(localType, isPinned); - _log.Trace($"DclLoc\t{res.LocalIndex}\t=> {res.LocalType} {res.IsPinned}"); + _log?.Trace($"DclLoc\t{res.LocalIndex}\t=> {res.LocalType} {res.IsPinned}"); return res; } @@ -45,70 +45,70 @@ namespace Torch.Managers.PatchManager.Transpile /// public void Emit(OpCode op) { - _log.Trace($"Emit\t{op,_opcodePadding}"); + _log?.Trace($"Emit\t{op,_opcodePadding}"); Backing.Emit(op); } /// public void Emit(OpCode op, LocalBuilder arg) { - _log.Trace($"Emit\t{op,_opcodePadding} L:{arg.LocalIndex} {arg.LocalType}"); + _log?.Trace($"Emit\t{op,_opcodePadding} L:{arg.LocalIndex} {arg.LocalType}"); Backing.Emit(op, arg); } /// public void Emit(OpCode op, int arg) { - _log.Trace($"Emit\t{op,_opcodePadding} {arg}"); + _log?.Trace($"Emit\t{op,_opcodePadding} {arg}"); Backing.Emit(op, arg); } /// public void Emit(OpCode op, long arg) { - _log.Trace($"Emit\t{op,_opcodePadding} {arg}"); + _log?.Trace($"Emit\t{op,_opcodePadding} {arg}"); Backing.Emit(op, arg); } /// public void Emit(OpCode op, float arg) { - _log.Trace($"Emit\t{op,_opcodePadding} {arg}"); + _log?.Trace($"Emit\t{op,_opcodePadding} {arg}"); Backing.Emit(op, arg); } /// public void Emit(OpCode op, double arg) { - _log.Trace($"Emit\t{op,_opcodePadding} {arg}"); + _log?.Trace($"Emit\t{op,_opcodePadding} {arg}"); Backing.Emit(op, arg); } /// public void Emit(OpCode op, string arg) { - _log.Trace($"Emit\t{op,_opcodePadding} {arg}"); + _log?.Trace($"Emit\t{op,_opcodePadding} {arg}"); Backing.Emit(op, arg); } /// public void Emit(OpCode op, Type arg) { - _log.Trace($"Emit\t{op,_opcodePadding} {arg}"); + _log?.Trace($"Emit\t{op,_opcodePadding} {arg}"); Backing.Emit(op, arg); } /// public void Emit(OpCode op, FieldInfo arg) { - _log.Trace($"Emit\t{op,_opcodePadding} {arg}"); + _log?.Trace($"Emit\t{op,_opcodePadding} {arg}"); Backing.Emit(op, arg); } /// public void Emit(OpCode op, MethodInfo arg) { - _log.Trace($"Emit\t{op,_opcodePadding} {arg}"); + _log?.Trace($"Emit\t{op,_opcodePadding} {arg}"); Backing.Emit(op, arg); } @@ -121,35 +121,35 @@ namespace Torch.Managers.PatchManager.Transpile /// public void Emit(OpCode op, Label arg) { - _log.Trace($"Emit\t{op,_opcodePadding}\tL:{_labelID.Invoke(arg)}"); + _log?.Trace($"Emit\t{op,_opcodePadding}\tL:{_labelID.Invoke(arg)}"); Backing.Emit(op, arg); } /// public void Emit(OpCode op, Label[] arg) { - _log.Trace($"Emit\t{op,_opcodePadding}\t{string.Join(", ", arg.Select(x => "L:" + _labelID.Invoke(x)))}"); + _log?.Trace($"Emit\t{op,_opcodePadding}\t{string.Join(", ", arg.Select(x => "L:" + _labelID.Invoke(x)))}"); Backing.Emit(op, arg); } /// public void Emit(OpCode op, SignatureHelper arg) { - _log.Trace($"Emit\t{op,_opcodePadding} {arg}"); + _log?.Trace($"Emit\t{op,_opcodePadding} {arg}"); Backing.Emit(op, arg); } /// public void Emit(OpCode op, ConstructorInfo arg) { - _log.Trace($"Emit\t{op,_opcodePadding} {arg}"); + _log?.Trace($"Emit\t{op,_opcodePadding} {arg}"); Backing.Emit(op, arg); } /// public void MarkLabel(Label label) { - _log.Trace($"MkLbl\tL:{_labelID.Invoke(label)}"); + _log?.Trace($"MkLbl\tL:{_labelID.Invoke(label)}"); Backing.MarkLabel(label); } @@ -166,7 +166,7 @@ namespace Torch.Managers.PatchManager.Transpile [Conditional("DEBUG")] public void EmitComment(string comment) { - _log.Trace($"// {comment}"); + _log?.Trace($"// {comment}"); } } #pragma warning restore 162