From 9c3a22c556a41d3652f0d6d7a0b7e00d8fa714d2 Mon Sep 17 00:00:00 2001 From: Westin Miller Date: Sat, 9 Sep 2017 00:27:43 -0700 Subject: [PATCH 1/5] Method patching framework --- Torch/Managers/PatchManager/AssemblyMemory.cs | 99 ++++++ .../Managers/PatchManager/DecoratedMethod.cs | 193 ++++++++++++ Torch/Managers/PatchManager/EmitExtensions.cs | 75 +++++ .../PatchManager/MSIL/ITokenResolver.cs | 226 ++++++++++++++ .../PatchManager/MSIL/MsilInstruction.cs | 162 ++++++++++ Torch/Managers/PatchManager/MSIL/MsilLabel.cs | 67 ++++ .../Managers/PatchManager/MSIL/MsilOperand.cs | 25 ++ .../PatchManager/MSIL/MsilOperandBrTarget.cs | 40 +++ .../PatchManager/MSIL/MsilOperandInline.cs | 292 ++++++++++++++++++ .../PatchManager/MSIL/MsilOperandSwitch.cs | 36 +++ .../PatchManager/MethodRewritePattern.cs | 172 +++++++++++ Torch/Managers/PatchManager/PatchContext.cs | 44 +++ Torch/Managers/PatchManager/PatchManager.cs | 86 ++++++ .../PatchManager/PatchPriorityAttribute.cs | 24 ++ .../Transpile/LoggingILGenerator.cs | 173 +++++++++++ .../PatchManager/Transpile/MethodContext.cs | 110 +++++++ .../Transpile/MethodTranspiler.cs | 61 ++++ Torch/Torch.csproj | 19 ++ 18 files changed, 1904 insertions(+) create mode 100644 Torch/Managers/PatchManager/AssemblyMemory.cs create mode 100644 Torch/Managers/PatchManager/DecoratedMethod.cs create mode 100644 Torch/Managers/PatchManager/EmitExtensions.cs create mode 100644 Torch/Managers/PatchManager/MSIL/ITokenResolver.cs create mode 100644 Torch/Managers/PatchManager/MSIL/MsilInstruction.cs create mode 100644 Torch/Managers/PatchManager/MSIL/MsilLabel.cs create mode 100644 Torch/Managers/PatchManager/MSIL/MsilOperand.cs create mode 100644 Torch/Managers/PatchManager/MSIL/MsilOperandBrTarget.cs create mode 100644 Torch/Managers/PatchManager/MSIL/MsilOperandInline.cs create mode 100644 Torch/Managers/PatchManager/MSIL/MsilOperandSwitch.cs create mode 100644 Torch/Managers/PatchManager/MethodRewritePattern.cs create mode 100644 Torch/Managers/PatchManager/PatchContext.cs create mode 100644 Torch/Managers/PatchManager/PatchManager.cs create mode 100644 Torch/Managers/PatchManager/PatchPriorityAttribute.cs create mode 100644 Torch/Managers/PatchManager/Transpile/LoggingILGenerator.cs create mode 100644 Torch/Managers/PatchManager/Transpile/MethodContext.cs create mode 100644 Torch/Managers/PatchManager/Transpile/MethodTranspiler.cs diff --git a/Torch/Managers/PatchManager/AssemblyMemory.cs b/Torch/Managers/PatchManager/AssemblyMemory.cs new file mode 100644 index 0000000..0998384 --- /dev/null +++ b/Torch/Managers/PatchManager/AssemblyMemory.cs @@ -0,0 +1,99 @@ +using System; +using System.Reflection; +using System.Reflection.Emit; +using System.Runtime.CompilerServices; +using System.Runtime.InteropServices; + +namespace Torch.Managers.PatchManager +{ + internal class AssemblyMemory + { + /// + /// Gets the address, in RAM, where the body of a method starts. + /// + /// Method to find the start of + /// Address of the method's start + public static long GetMethodBodyStart(MethodBase method) + { + RuntimeMethodHandle handle; + if (method is DynamicMethod) + handle = (RuntimeMethodHandle)typeof(DynamicMethod).GetMethod("GetMethodDescriptor", BindingFlags.NonPublic | BindingFlags.Instance) + .Invoke(method, new object[0]); + else + handle = method.MethodHandle; + RuntimeHelpers.PrepareMethod(handle); + return handle.GetFunctionPointer().ToInt64(); + } + + + + // x64 ISA format: + // [prefixes] [opcode] [mod-r/m] + // [mod-r/m] is bitfield: + // [7-6] = "mod" adressing mode + // [5-3] = register or opcode extension + // [2-0] = "r/m" extra addressing mode + + + // http://ref.x86asm.net/coder64.html + /// Direct register addressing mode. (Jump directly to register) + private const byte MODRM_MOD_DIRECT = 0b11; + + /// Long-mode prefix (64-bit operand) + private const byte REX_W = 0x48; + + /// Moves a 16/32/64 operand into register i when opcode is (MOV_R0+i) + private const byte MOV_R0 = 0xB8; + + // Extra opcodes. Used with opcode extension. + private const byte EXT = 0xFF; + + /// Opcode extension used with for the JMP opcode. + private const byte OPCODE_EXTENSION_JMP = 4; + + + /// + /// Reads a byte array from a memory location + /// + /// Address to read from + /// Number of bytes to read + /// The bytes that were read + public static byte[] ReadMemory(long memory, int bytes) + { + var data = new byte[bytes]; + Marshal.Copy(new IntPtr(memory), data,0, bytes); + return data; + } + + /// + /// Writes a byte array to a memory location. + /// + /// Address to write to + /// Data to write + public static void WriteMemory(long memory, byte[] bytes) + { + Marshal.Copy(bytes,0, new IntPtr(memory), bytes.Length); + } + + /// + /// Writes an x64 assembly jump instruction at the given address. + /// + /// Address to write the instruction at + /// Target address of the jump + /// The bytes that were overwritten + public static byte[] WriteJump(long memory, long jumpTarget) + { + byte[] result = ReadMemory(memory, 12); + unsafe + { + var ptr = (byte*)memory; + *ptr = REX_W; + *(ptr + 1) = MOV_R0; + *((long*)(ptr + 2)) = jumpTarget; + *(ptr + 10) = EXT; + *(ptr + 11) = (MODRM_MOD_DIRECT << 6) | (OPCODE_EXTENSION_JMP << 3) | 0; + } + return result; + } + } +} diff --git a/Torch/Managers/PatchManager/DecoratedMethod.cs b/Torch/Managers/PatchManager/DecoratedMethod.cs new file mode 100644 index 0000000..326f004 --- /dev/null +++ b/Torch/Managers/PatchManager/DecoratedMethod.cs @@ -0,0 +1,193 @@ +using System; +using System.Collections.Generic; +using System.Diagnostics; +using System.Linq; +using System.Reflection; +using System.Reflection.Emit; +using System.Runtime.CompilerServices; +using System.Runtime.InteropServices; +using NLog; +using Torch.Managers.PatchManager.Transpile; + +namespace Torch.Managers.PatchManager +{ + internal class DecoratedMethod : MethodRewritePattern + { + private static readonly Logger _log = LogManager.GetCurrentClassLogger(); + private readonly MethodBase _method; + + internal DecoratedMethod(MethodBase method) : base(null) + { + _method = method; + } + + private long _revertAddress; + private byte[] _revertData = null; + private GCHandle? _pinnedPatch; + + internal void Commit() + { + if (!Prefixes.HasChanges() && !Suffixes.HasChanges() && !Transpilers.HasChanges()) + return; + Revert(); + + if (Prefixes.Count == 0 && Suffixes.Count == 0 && Transpilers.Count == 0) + return; + var patch = ComposePatchedMethod(); + + _revertAddress = AssemblyMemory.GetMethodBodyStart(_method); + var newAddress = AssemblyMemory.GetMethodBodyStart(patch); + _revertData = AssemblyMemory.WriteJump(_revertAddress, newAddress); + _pinnedPatch = GCHandle.Alloc(patch); + } + + internal void Revert() + { + if (_pinnedPatch.HasValue) + { + AssemblyMemory.WriteMemory(_revertAddress, _revertData); + _revertData = null; + _pinnedPatch.Value.Free(); + } + } + + #region Create + private int _patchSalt = 0; + private DynamicMethod AllocatePatchMethod() + { + Debug.Assert(_method.DeclaringType != null); + var methodName = _method.Name + $"_{_patchSalt++}"; + var returnType = _method is MethodInfo meth ? meth.ReturnType : typeof(void); + var parameters = _method.GetParameters(); + var parameterTypes = (_method.IsStatic ? Enumerable.Empty() : new[] { typeof(object) }) + .Concat(parameters.Select(x => x.ParameterType)).ToArray(); + + var patchMethod = new DynamicMethod(methodName, MethodAttributes.Public | MethodAttributes.Static, CallingConventions.Standard, + returnType, parameterTypes, _method.DeclaringType, true); + if (!_method.IsStatic) + patchMethod.DefineParameter(0, ParameterAttributes.None, INSTANCE_PARAMETER); + for (var i = 0; i < parameters.Length; i++) + patchMethod.DefineParameter((patchMethod.IsStatic ? 0 : 1) + i, parameters[i].Attributes, parameters[i].Name); + + return patchMethod; + } + + + private const string INSTANCE_PARAMETER = "__instance"; + private const string RESULT_PARAMETER = "__result"; + + public DynamicMethod ComposePatchedMethod() + { + var method = AllocatePatchMethod(); + var generator = new LoggingIlGenerator(method.GetILGenerator()); + EmitPatched(generator); + + // Force it to compile + const BindingFlags nonPublicInstance = BindingFlags.NonPublic | BindingFlags.Instance; + const BindingFlags nonPublicStatic = BindingFlags.NonPublic | BindingFlags.Static; + var compileMethod = typeof(RuntimeHelpers).GetMethod("_CompileMethod", nonPublicStatic); + var getMethodDescriptor = typeof(DynamicMethod).GetMethod("GetMethodDescriptor", nonPublicInstance); + var handle = (RuntimeMethodHandle)getMethodDescriptor.Invoke(method, new object[0]); + var getMethodInfo = typeof(RuntimeMethodHandle).GetMethod("GetMethodInfo", nonPublicInstance); + var runtimeMethodInfo = getMethodInfo.Invoke(handle, new object[0]); + compileMethod.Invoke(null, new[] { runtimeMethodInfo }); + return method; + } + #endregion + + #region Emit + private void EmitPatched(LoggingIlGenerator target) + { + var originalLocalVariables = _method.GetMethodBody().LocalVariables + .Select(x => + { + Debug.Assert(x.LocalType != null); + return target.DeclareLocal(x.LocalType, x.IsPinned); + }).ToArray(); + + var specialVariables = new Dictionary(); + + var returnType = _method is MethodInfo meth ? meth.ReturnType : typeof(void); + var resultVariable = returnType != typeof(void) && Prefixes.Concat(Suffixes).SelectMany(x => x.GetParameters()).Any(x => x.Name == RESULT_PARAMETER) + ? target.DeclareLocal(returnType) + : null; + resultVariable?.SetToDefault(target); + + if (resultVariable != null) + specialVariables.Add(RESULT_PARAMETER, resultVariable); + + var labelAfterOriginalContent = target.DefineLabel(); + var labelAfterOriginalReturn = target.DefineLabel(); + + foreach (var prefix in Prefixes) + { + EmitMonkeyCall(target, prefix, specialVariables); + if (prefix.ReturnType == typeof(bool)) + target.Emit(OpCodes.Brfalse, labelAfterOriginalReturn); + else if (prefix.ReturnType != typeof(void)) + throw new Exception($"Prefixes must return void or bool. {prefix.DeclaringType?.FullName}.{prefix.Name} returns {prefix.ReturnType}"); + } + + MethodTranspiler.Transpile(_method, Transpilers, target, labelAfterOriginalContent); + target.MarkLabel(labelAfterOriginalContent); + if (resultVariable != null) + target.Emit(OpCodes.Stloc, resultVariable); + target.MarkLabel(labelAfterOriginalReturn); + + foreach (var suffix in Suffixes) + { + EmitMonkeyCall(target, suffix, specialVariables); + if (suffix.ReturnType != typeof(void)) + throw new Exception($"Suffixes must return void. {suffix.DeclaringType?.FullName}.{suffix.Name} returns {suffix.ReturnType}"); + } + + if (resultVariable != null) + target.Emit(OpCodes.Ldloc, resultVariable); + target.Emit(OpCodes.Ret); + } + + private void EmitMonkeyCall(LoggingIlGenerator target, MethodInfo patch, + IReadOnlyDictionary specialVariables) + { + foreach (var param in patch.GetParameters()) + { + switch (param.Name) + { + case INSTANCE_PARAMETER: + if (_method.IsStatic) + throw new Exception("Can't use an instance parameter for a static method"); + target.Emit(OpCodes.Ldarg_0); + break; + case RESULT_PARAMETER: + var retType = param.ParameterType.IsByRef + ? param.ParameterType.GetElementType() + : param.ParameterType; + if (retType == null || !retType.IsAssignableFrom(specialVariables[RESULT_PARAMETER].LocalType)) + throw new Exception($"Return type {specialVariables[RESULT_PARAMETER].LocalType} can't be assigned to result parameter type {retType}"); + target.Emit(param.ParameterType.IsByRef ? OpCodes.Ldloca : OpCodes.Ldloc, specialVariables[RESULT_PARAMETER]); + break; + default: + var declParam = _method.GetParameters().FirstOrDefault(x => x.Name == param.Name); + if (declParam == null) + throw new Exception($"Parameter name {param.Name} not found"); + var paramIdx = (_method.IsStatic ? 0 : 1) + declParam.Position; + + var patchByRef = param.IsOut || param.ParameterType.IsByRef; + var declByRef = declParam.IsOut || declParam.ParameterType.IsByRef; + if (patchByRef == declByRef) + target.Emit(OpCodes.Ldarg, paramIdx); + else if (patchByRef) + target.Emit(OpCodes.Ldarga, paramIdx); + else + { + target.Emit(OpCodes.Ldarg, paramIdx); + target.EmitDereference(declParam.ParameterType); + } + break; + } + } + target.Emit(OpCodes.Call, patch); + } + #endregion + } +} diff --git a/Torch/Managers/PatchManager/EmitExtensions.cs b/Torch/Managers/PatchManager/EmitExtensions.cs new file mode 100644 index 0000000..90f1ffd --- /dev/null +++ b/Torch/Managers/PatchManager/EmitExtensions.cs @@ -0,0 +1,75 @@ +using System; +using System.Diagnostics; +using System.Reflection.Emit; +using Torch.Managers.PatchManager.Transpile; + +namespace Torch.Managers.PatchManager +{ + internal static class EmitExtensions + { + /// + /// Sets the given local to its default value in the given IL generator. + /// + /// Local to set to default + /// The IL generator + public static void SetToDefault(this LocalBuilder local, LoggingIlGenerator target) + { + Debug.Assert(local.LocalType != null); + if (local.LocalType.IsEnum || local.LocalType.IsPrimitive) + { + if (local.LocalType == typeof(float)) + target.Emit(OpCodes.Ldc_R4, 0f); + else if (local.LocalType == typeof(double)) + target.Emit(OpCodes.Ldc_R8, 0d); + else if (local.LocalType == typeof(long) || local.LocalType == typeof(ulong)) + target.Emit(OpCodes.Ldc_I8, 0L); + else + target.Emit(OpCodes.Ldc_I4, 0); + target.Emit(OpCodes.Stloc, local); + } + else if (local.LocalType.IsValueType) // struct + { + target.Emit(OpCodes.Ldloca, local); + target.Emit(OpCodes.Initobj, local.LocalType); + } + else // class + { + target.Emit(OpCodes.Ldnull); + target.Emit(OpCodes.Stloc, local); + } + } + + /// + /// Emits a dereference for the given type. + /// + /// IL Generator to emit on + /// Type to dereference + public static void EmitDereference(this LoggingIlGenerator target, Type type) + { + if (type.IsByRef) + type = type.GetElementType(); + Debug.Assert(type != null); + + if (type == typeof(float)) + target.Emit(OpCodes.Ldind_R4); + else if (type == typeof(double)) + target.Emit(OpCodes.Ldind_R8); + else if (type == typeof(byte)) + target.Emit(OpCodes.Ldind_U1); + else if (type == typeof(ushort) || type == typeof(char)) + target.Emit(OpCodes.Ldind_U2); + else if (type == typeof(uint)) + target.Emit(OpCodes.Ldind_U4); + else if (type == typeof(sbyte)) + target.Emit(OpCodes.Ldind_I1); + else if (type == typeof(short)) + target.Emit(OpCodes.Ldind_I2); + else if (type == typeof(int) || type.IsEnum) + target.Emit(OpCodes.Ldind_I4); + else if (type == typeof(long) || type == typeof(ulong)) + target.Emit(OpCodes.Ldind_I8); + else + target.Emit(OpCodes.Ldind_Ref); + } + } +} diff --git a/Torch/Managers/PatchManager/MSIL/ITokenResolver.cs b/Torch/Managers/PatchManager/MSIL/ITokenResolver.cs new file mode 100644 index 0000000..ffaac0a --- /dev/null +++ b/Torch/Managers/PatchManager/MSIL/ITokenResolver.cs @@ -0,0 +1,226 @@ +using System; +using System.Reflection; +using System.Reflection.Emit; + +namespace Torch.Managers.PatchManager.MSIL +{ + //https://stackoverflow.com/questions/4148297/resolving-the-tokens-found-in-the-il-from-a-dynamic-method/35711376#35711376 + internal interface ITokenResolver + { + MemberInfo ResolveMember(int token); + Type ResolveType(int token); + FieldInfo ResolveField(int token); + MethodBase ResolveMethod(int token); + byte[] ResolveSignature(int token); + string ResolveString(int token); + } + + internal sealed class NormalTokenResolver : ITokenResolver + { + private readonly Type[] _genericTypeArgs, _genericMethArgs; + private readonly Module _module; + + internal NormalTokenResolver(MethodBase method) + { + _module = method.Module; + _genericTypeArgs = method.DeclaringType?.GenericTypeArguments ?? new Type[0]; + _genericMethArgs = method.GetGenericArguments(); + } + + public MemberInfo ResolveMember(int token) + { + return _module.ResolveMember(token, _genericTypeArgs, _genericMethArgs); + } + + public Type ResolveType(int token) + { + return _module.ResolveType(token, _genericTypeArgs, _genericMethArgs); + } + + public FieldInfo ResolveField(int token) + { + return _module.ResolveField(token, _genericTypeArgs, _genericMethArgs); + } + + public MethodBase ResolveMethod(int token) + { + return _module.ResolveMethod(token, _genericTypeArgs, _genericMethArgs); + } + + public byte[] ResolveSignature(int token) + { + return _module.ResolveSignature(token); + } + + public string ResolveString(int token) + { + return _module.ResolveString(token); + } + } + + internal sealed class NullTokenResolver : ITokenResolver + { + internal static readonly NullTokenResolver Instance = new NullTokenResolver(); + + private NullTokenResolver() + { + } + + public MemberInfo ResolveMember(int token) + { + return null; + } + + public Type ResolveType(int token) + { + return null; + } + + public FieldInfo ResolveField(int token) + { + return null; + } + + public MethodBase ResolveMethod(int token) + { + return null; + } + + public byte[] ResolveSignature(int token) + { + return null; + } + + public string ResolveString(int token) + { + return null; + } + } + + internal sealed class DynamicMethodTokenResolver : ITokenResolver + { + private readonly MethodInfo _getFieldInfo; + private readonly MethodInfo _getMethodBase; + private readonly GetTypeFromHandleUnsafe _getTypeFromHandleUnsafe; + private readonly ConstructorInfo _runtimeFieldHandleStubCtor; + private readonly ConstructorInfo _runtimeMethodHandleInternalCtor; + private readonly SignatureResolver _signatureResolver; + private readonly StringResolver _stringResolver; + + private readonly TokenResolver _tokenResolver; + + public DynamicMethodTokenResolver(DynamicMethod dynamicMethod) + { + object resolver = typeof(DynamicMethod) + .GetField("m_resolver", BindingFlags.Instance | BindingFlags.NonPublic)?.GetValue(dynamicMethod); + if (resolver == null) throw new ArgumentException("The dynamic method's IL has not been finalized."); + + _tokenResolver = (TokenResolver) resolver.GetType() + .GetMethod("ResolveToken", BindingFlags.Instance | BindingFlags.NonPublic) + .CreateDelegate(typeof(TokenResolver), resolver); + _stringResolver = (StringResolver) resolver.GetType() + .GetMethod("GetStringLiteral", BindingFlags.Instance | BindingFlags.NonPublic) + .CreateDelegate(typeof(StringResolver), resolver); + _signatureResolver = (SignatureResolver) resolver.GetType() + .GetMethod("ResolveSignature", BindingFlags.Instance | BindingFlags.NonPublic) + .CreateDelegate(typeof(SignatureResolver), resolver); + + _getTypeFromHandleUnsafe = (GetTypeFromHandleUnsafe) typeof(Type) + .GetMethod("GetTypeFromHandleUnsafe", BindingFlags.Static | BindingFlags.NonPublic, null, + new[] {typeof(IntPtr)}, null).CreateDelegate(typeof(GetTypeFromHandleUnsafe), null); + Type runtimeType = typeof(RuntimeTypeHandle).Assembly.GetType("System.RuntimeType"); + + Type runtimeMethodHandleInternal = + typeof(RuntimeTypeHandle).Assembly.GetType("System.RuntimeMethodHandleInternal"); + _getMethodBase = runtimeType.GetMethod("GetMethodBase", BindingFlags.Static | BindingFlags.NonPublic, null, + new[] {runtimeType, runtimeMethodHandleInternal}, null); + _runtimeMethodHandleInternalCtor = + runtimeMethodHandleInternal.GetConstructor(BindingFlags.Instance | BindingFlags.NonPublic, null, + new[] {typeof(IntPtr)}, null); + + Type runtimeFieldInfoStub = typeof(RuntimeTypeHandle).Assembly.GetType("System.RuntimeFieldInfoStub"); + _runtimeFieldHandleStubCtor = + runtimeFieldInfoStub.GetConstructor(BindingFlags.Instance | BindingFlags.Public, null, + new[] {typeof(IntPtr), typeof(object)}, null); + _getFieldInfo = runtimeType.GetMethod("GetFieldInfo", BindingFlags.Static | BindingFlags.NonPublic, null, + new[] {runtimeType, typeof(RuntimeTypeHandle).Assembly.GetType("System.IRuntimeFieldInfo")}, null); + } + + public Type ResolveType(int token) + { + IntPtr typeHandle, methodHandle, fieldHandle; + _tokenResolver.Invoke(token, out typeHandle, out methodHandle, out fieldHandle); + + return _getTypeFromHandleUnsafe.Invoke(typeHandle); + } + + public MethodBase ResolveMethod(int token) + { + IntPtr typeHandle, methodHandle, fieldHandle; + _tokenResolver.Invoke(token, out typeHandle, out methodHandle, out fieldHandle); + + return (MethodBase) _getMethodBase.Invoke(null, new[] + { + typeHandle == IntPtr.Zero ? null : _getTypeFromHandleUnsafe.Invoke(typeHandle), + _runtimeMethodHandleInternalCtor.Invoke(new object[] {methodHandle}) + }); + } + + public FieldInfo ResolveField(int token) + { + IntPtr typeHandle, methodHandle, fieldHandle; + _tokenResolver.Invoke(token, out typeHandle, out methodHandle, out fieldHandle); + + return (FieldInfo) _getFieldInfo.Invoke(null, new[] + { + typeHandle == IntPtr.Zero ? null : _getTypeFromHandleUnsafe.Invoke(typeHandle), + _runtimeFieldHandleStubCtor.Invoke(new object[] {fieldHandle, null}) + }); + } + + public MemberInfo ResolveMember(int token) + { + IntPtr typeHandle, methodHandle, fieldHandle; + _tokenResolver.Invoke(token, out typeHandle, out methodHandle, out fieldHandle); + + if (methodHandle != IntPtr.Zero) + return (MethodBase) _getMethodBase.Invoke(null, new[] + { + typeHandle == IntPtr.Zero ? null : _getTypeFromHandleUnsafe.Invoke(typeHandle), + _runtimeMethodHandleInternalCtor.Invoke(new object[] {methodHandle}) + }); + + if (fieldHandle != IntPtr.Zero) + return (FieldInfo) _getFieldInfo.Invoke(null, new[] + { + typeHandle == IntPtr.Zero ? null : _getTypeFromHandleUnsafe.Invoke(typeHandle), + _runtimeFieldHandleStubCtor.Invoke(new object[] {fieldHandle, null}) + }); + + if (typeHandle != IntPtr.Zero) + return _getTypeFromHandleUnsafe.Invoke(typeHandle); + + throw new NotImplementedException( + "DynamicMethods are not able to reference members by token other than types, methods and fields."); + } + + public byte[] ResolveSignature(int token) + { + return _signatureResolver.Invoke(token, 0); + } + + public string ResolveString(int token) + { + return _stringResolver.Invoke(token); + } + + private delegate void TokenResolver(int token, out IntPtr typeHandle, out IntPtr methodHandle, + out IntPtr fieldHandle); + + private delegate string StringResolver(int token); + + private delegate byte[] SignatureResolver(int token, int fromMethod); + + private delegate Type GetTypeFromHandleUnsafe(IntPtr handle); + } +} \ No newline at end of file diff --git a/Torch/Managers/PatchManager/MSIL/MsilInstruction.cs b/Torch/Managers/PatchManager/MSIL/MsilInstruction.cs new file mode 100644 index 0000000..a179a6e --- /dev/null +++ b/Torch/Managers/PatchManager/MSIL/MsilInstruction.cs @@ -0,0 +1,162 @@ +using System; +using System.Collections.Generic; +using System.Reflection; +using System.Reflection.Emit; +using System.Text; +using Torch.Managers.PatchManager.Transpile; + +namespace Torch.Managers.PatchManager.MSIL +{ + /// + /// Represents a single MSIL instruction, and its operand + /// + public class MsilInstruction + { + private MsilOperand _operandBacking; + + /// + /// Creates a new instruction with the given opcode. + /// + /// Opcode + public MsilInstruction(OpCode opcode) + { + OpCode = opcode; + switch (opcode.OperandType) + { + case OperandType.InlineNone: + Operand = null; + break; + case OperandType.ShortInlineBrTarget: + case OperandType.InlineBrTarget: + Operand = new MsilOperandBrTarget(this); + break; + case OperandType.InlineField: + Operand = new MsilOperandInline.MsilOperandReflected(this); + break; + case OperandType.InlineI: + Operand = new MsilOperandInline.MsilOperandInt32(this); + break; + case OperandType.InlineI8: + Operand = new MsilOperandInline.MsilOperandInt64(this); + break; + case OperandType.InlineMethod: + Operand = new MsilOperandInline.MsilOperandReflected(this); + break; + case OperandType.InlineR: + Operand = new MsilOperandInline.MsilOperandDouble(this); + break; + case OperandType.InlineSig: + Operand = new MsilOperandInline.MsilOperandSignature(this); + break; + case OperandType.InlineString: + Operand = new MsilOperandInline.MsilOperandString(this); + break; + case OperandType.InlineSwitch: + Operand = new MsilOperandSwitch(this); + break; + case OperandType.InlineTok: + Operand = new MsilOperandInline.MsilOperandReflected(this); + break; + case OperandType.InlineType: + Operand = new MsilOperandInline.MsilOperandReflected(this); + break; + case OperandType.ShortInlineVar: + case OperandType.InlineVar: + if (OpCode.Name.IndexOf("loc", StringComparison.OrdinalIgnoreCase) != -1) + Operand = new MsilOperandInline.MsilOperandLocal(this); + else + Operand = new MsilOperandInline.MsilOperandParameter(this); + break; + case OperandType.ShortInlineI: + Operand = OpCode == OpCodes.Ldc_I4_S + ? (MsilOperand) new MsilOperandInline.MsilOperandInt8(this) + : new MsilOperandInline.MsilOperandUInt8(this); + break; + case OperandType.ShortInlineR: + Operand = new MsilOperandInline.MsilOperandSingle(this); + break; +#pragma warning disable 618 + case OperandType.InlinePhi: +#pragma warning restore 618 + default: + throw new ArgumentOutOfRangeException(); + } + } + + /// + /// Opcode of this instruction + /// + public OpCode OpCode { get; } + + /// + /// Raw memory offset of this instruction; optional. + /// + public int Offset { get; internal set; } + + /// + /// The operand for this instruction, or null. + /// + public MsilOperand Operand + { + get => _operandBacking; + set + { + if (_operandBacking != null && value.GetType() != _operandBacking.GetType()) + throw new ArgumentException($"Operand for {OpCode.Name} must be {_operandBacking.GetType().Name}"); + _operandBacking = value; + } + } + + /// + /// Labels pointing to this instruction. + /// + public HashSet Labels { get; } = new HashSet(); + + /// + /// Sets the inline value for this instruction. + /// + /// The type of the inline constraint + /// Value + /// This instruction + public MsilInstruction InlineValue(T o) + { + ((MsilOperandInline) Operand).Value = o; + return this; + } + + /// + /// Sets the inline branch target for this instruction. + /// + /// Target to jump to + /// This instruction + public MsilInstruction InlineTarget(MsilLabel label) + { + ((MsilOperandBrTarget) Operand).Target = label; + return this; + } + + /// + /// Emits this instruction to the given generator + /// + /// Emit target + public void Emit(LoggingIlGenerator target) + { + foreach (MsilLabel label in Labels) + target.MarkLabel(label.LabelFor(target)); + if (Operand != null) + Operand.Emit(target); + else + target.Emit(OpCode); + } + + /// + public override string ToString() + { + var sb = new StringBuilder(); + foreach (MsilLabel label in Labels) + sb.Append(label).Append(": "); + sb.Append(OpCode.Name).Append("\t").Append(Operand); + return sb.ToString(); + } + } +} \ No newline at end of file diff --git a/Torch/Managers/PatchManager/MSIL/MsilLabel.cs b/Torch/Managers/PatchManager/MSIL/MsilLabel.cs new file mode 100644 index 0000000..2111d27 --- /dev/null +++ b/Torch/Managers/PatchManager/MSIL/MsilLabel.cs @@ -0,0 +1,67 @@ +using System; +using System.Collections.Generic; +using System.Reflection.Emit; +using Torch.Managers.PatchManager.Transpile; + +namespace Torch.Managers.PatchManager.MSIL +{ + /// + /// Represents an abstract label, identified by its reference. + /// + public class MsilLabel + { + private readonly List, Label>> _labelInstances = + new List, Label>>(); + + private readonly Label? _overrideLabel; + + /// + /// Creates an empty label the allocates a new when requested. + /// + public MsilLabel() + { + _overrideLabel = null; + } + + /// + /// Creates a label the always supplies the given + /// + public MsilLabel(Label overrideLabel) + { + _overrideLabel = overrideLabel; + } + + /// + /// Creates a label that supplies the given when a label for the given generator is requested, + /// otherwise it creates a new label. + /// + /// Generator to register the label on + /// Label to register + public MsilLabel(LoggingIlGenerator generator, Label label) + { + _labelInstances.Add( + new KeyValuePair, Label>( + new WeakReference(generator), label)); + } + + internal Label LabelFor(LoggingIlGenerator gen) + { + if (_overrideLabel.HasValue) + return _overrideLabel.Value; + foreach (KeyValuePair, Label> kv in _labelInstances) + if (kv.Key.TryGetTarget(out LoggingIlGenerator gen2) && gen2 == gen) + return kv.Value; + Label label = gen.DefineLabel(); + _labelInstances.Add( + new KeyValuePair, Label>(new WeakReference(gen), + label)); + return label; + } + + /// + public override string ToString() + { + return $"L{GetHashCode() & 0xFFFF:X4}"; + } + } +} \ No newline at end of file diff --git a/Torch/Managers/PatchManager/MSIL/MsilOperand.cs b/Torch/Managers/PatchManager/MSIL/MsilOperand.cs new file mode 100644 index 0000000..d9dbf2f --- /dev/null +++ b/Torch/Managers/PatchManager/MSIL/MsilOperand.cs @@ -0,0 +1,25 @@ +using System.IO; +using Torch.Managers.PatchManager.Transpile; + +namespace Torch.Managers.PatchManager.MSIL +{ + /// + /// Represents an operand for a MSIL instruction + /// + public abstract class MsilOperand + { + protected MsilOperand(MsilInstruction instruction) + { + Instruction = instruction; + } + + /// + /// Instruction this operand is associated with + /// + public MsilInstruction Instruction { get; } + + internal abstract void Read(MethodContext context, BinaryReader reader); + + internal abstract void Emit(LoggingIlGenerator generator); + } +} \ No newline at end of file diff --git a/Torch/Managers/PatchManager/MSIL/MsilOperandBrTarget.cs b/Torch/Managers/PatchManager/MSIL/MsilOperandBrTarget.cs new file mode 100644 index 0000000..01e0913 --- /dev/null +++ b/Torch/Managers/PatchManager/MSIL/MsilOperandBrTarget.cs @@ -0,0 +1,40 @@ +using System.IO; +using System.Reflection.Emit; +using Torch.Managers.PatchManager.Transpile; + +namespace Torch.Managers.PatchManager.MSIL +{ + /// + /// Represents a branch target operand. + /// + public class MsilOperandBrTarget : MsilOperand + { + internal MsilOperandBrTarget(MsilInstruction instruction) : base(instruction) + { + } + + /// + /// Branch target + /// + public MsilLabel Target { get; set; } + + internal override void Read(MethodContext context, BinaryReader reader) + { + int val = Instruction.OpCode.OperandType == OperandType.InlineBrTarget + ? reader.ReadInt32() + : reader.ReadByte(); + Target = context.LabelAt((int) reader.BaseStream.Position + val); + } + + internal override void Emit(LoggingIlGenerator generator) + { + generator.Emit(Instruction.OpCode, Target.LabelFor(generator)); + } + + /// + public override string ToString() + { + return Target?.ToString() ?? "null"; + } + } +} \ No newline at end of file diff --git a/Torch/Managers/PatchManager/MSIL/MsilOperandInline.cs b/Torch/Managers/PatchManager/MSIL/MsilOperandInline.cs new file mode 100644 index 0000000..f4c45d1 --- /dev/null +++ b/Torch/Managers/PatchManager/MSIL/MsilOperandInline.cs @@ -0,0 +1,292 @@ +using System; +using System.IO; +using System.Reflection; +using System.Reflection.Emit; +using Torch.Managers.PatchManager.Transpile; + +namespace Torch.Managers.PatchManager.MSIL +{ + /// + /// Represents an inline value + /// + /// The type of the inline value + public abstract class MsilOperandInline : MsilOperand + { + internal MsilOperandInline(MsilInstruction instruction) : base(instruction) + { + } + + /// + /// Inline value + /// + public T Value { get; set; } + + /// + public override string ToString() + { + return Value?.ToString() ?? "null"; + } + } + + /// + /// Registry of different inline operand types + /// + public static class MsilOperandInline + { + /// + /// Inline unsigned byte + /// + public class MsilOperandUInt8 : MsilOperandInline + { + internal MsilOperandUInt8(MsilInstruction instruction) : base(instruction) + { + } + + internal override void Read(MethodContext context, BinaryReader reader) + { + Value = reader.ReadByte(); + } + + internal override void Emit(LoggingIlGenerator generator) + { + generator.Emit(Instruction.OpCode, Value); + } + } + + /// + /// Inline signed byte + /// + public class MsilOperandInt8 : MsilOperandInline + { + internal MsilOperandInt8(MsilInstruction instruction) : base(instruction) + { + } + + internal override void Read(MethodContext context, BinaryReader reader) + { + Value = + (sbyte) reader.ReadByte(); + } + + internal override void Emit(LoggingIlGenerator generator) + { + generator.Emit(Instruction.OpCode, Value); + } + } + + /// + /// Inline integer + /// + public class MsilOperandInt32 : MsilOperandInline + { + internal MsilOperandInt32(MsilInstruction instruction) : base(instruction) + { + } + + internal override void Read(MethodContext context, BinaryReader reader) + { + Value = reader.ReadInt32(); + } + + internal override void Emit(LoggingIlGenerator generator) + { + generator.Emit(Instruction.OpCode, Value); + } + } + + /// + /// Inline single + /// + public class MsilOperandSingle : MsilOperandInline + { + internal MsilOperandSingle(MsilInstruction instruction) : base(instruction) + { + } + + internal override void Read(MethodContext context, BinaryReader reader) + { + Value = reader.ReadSingle(); + } + + internal override void Emit(LoggingIlGenerator generator) + { + generator.Emit(Instruction.OpCode, Value); + } + } + + /// + /// Inline double + /// + public class MsilOperandDouble : MsilOperandInline + { + internal MsilOperandDouble(MsilInstruction instruction) : base(instruction) + { + } + + internal override void Read(MethodContext context, BinaryReader reader) + { + Value = reader.ReadDouble(); + } + + internal override void Emit(LoggingIlGenerator generator) + { + generator.Emit(Instruction.OpCode, Value); + } + } + + /// + /// Inline long + /// + public class MsilOperandInt64 : MsilOperandInline + { + internal MsilOperandInt64(MsilInstruction instruction) : base(instruction) + { + } + + internal override void Read(MethodContext context, BinaryReader reader) + { + Value = reader.ReadInt64(); + } + + internal override void Emit(LoggingIlGenerator generator) + { + generator.Emit(Instruction.OpCode, Value); + } + } + + /// + /// Inline string + /// + public class MsilOperandString : MsilOperandInline + { + internal MsilOperandString(MsilInstruction instruction) : base(instruction) + { + } + + internal override void Read(MethodContext context, BinaryReader reader) + { + Value = + context.TokenResolver.ResolveString(reader.ReadInt32()); + } + + internal override void Emit(LoggingIlGenerator generator) + { + generator.Emit(Instruction.OpCode, Value); + } + } + + /// + /// Inline CLR signature + /// + public class MsilOperandSignature : MsilOperandInline + { + internal MsilOperandSignature(MsilInstruction instruction) : base(instruction) + { + } + + internal override void Read(MethodContext context, BinaryReader reader) + { + byte[] sig = context.TokenResolver + .ResolveSignature(reader.ReadInt32()); + throw new ArgumentException("Can't figure out how to convert this."); + } + + internal override void Emit(LoggingIlGenerator generator) + { + generator.Emit(Instruction.OpCode, Value); + } + } + + /// + /// Inline parameter reference + /// + public class MsilOperandParameter : MsilOperandInline + { + internal MsilOperandParameter(MsilInstruction instruction) : base(instruction) + { + } + + internal override void Read(MethodContext context, BinaryReader reader) + { + Value = + context.Method.GetParameters()[ + Instruction.OpCode.OperandType == OperandType.ShortInlineVar + ? reader.ReadByte() + : reader.ReadUInt16()]; + } + + internal override void Emit(LoggingIlGenerator generator) + { + generator.Emit(Instruction.OpCode, Value.Position); + } + } + + /// + /// Inline local variable reference + /// + public class MsilOperandLocal : MsilOperandInline + { + internal MsilOperandLocal(MsilInstruction instruction) : base(instruction) + { + } + + internal override void Read(MethodContext context, BinaryReader reader) + { + Value = + context.Method.GetMethodBody().LocalVariables[ + Instruction.OpCode.OperandType == OperandType.ShortInlineVar + ? reader.ReadByte() + : reader.ReadUInt16()]; + } + + internal override void Emit(LoggingIlGenerator generator) + { + generator.Emit(Instruction.OpCode, Value.LocalIndex); + } + } + + /// + /// Inline or + /// + /// Actual member type + public class MsilOperandReflected : MsilOperandInline where TY : class + { + internal MsilOperandReflected(MsilInstruction instruction) : base(instruction) + { + } + + internal override void Read(MethodContext context, BinaryReader reader) + { + switch (Instruction.OpCode.OperandType) + { + case OperandType.InlineTok: + Value = context.TokenResolver.ResolveMember(reader.ReadInt32()) as TY; + break; + case OperandType.InlineType: + Value = context.TokenResolver.ResolveType(reader.ReadInt32()) as TY; + break; + case OperandType.InlineMethod: + Value = context.TokenResolver.ResolveMethod(reader.ReadInt32()) as TY; + break; + case OperandType.InlineField: + Value = context.TokenResolver.ResolveField(reader.ReadInt32()) as TY; + break; + default: + throw new ArgumentException("Reflected operand only applies to inline reflected types"); + } + } + + internal override void Emit(LoggingIlGenerator generator) + { + if (Value is ConstructorInfo) + generator.Emit(Instruction.OpCode, Value as ConstructorInfo); + else if (Value is FieldInfo) + generator.Emit(Instruction.OpCode, Value as FieldInfo); + else if (Value is Type) + generator.Emit(Instruction.OpCode, Value as Type); + else if (Value is MethodInfo) + generator.Emit(Instruction.OpCode, Value as MethodInfo); + } + } + } +} \ No newline at end of file diff --git a/Torch/Managers/PatchManager/MSIL/MsilOperandSwitch.cs b/Torch/Managers/PatchManager/MSIL/MsilOperandSwitch.cs new file mode 100644 index 0000000..66e3b2d --- /dev/null +++ b/Torch/Managers/PatchManager/MSIL/MsilOperandSwitch.cs @@ -0,0 +1,36 @@ +using System.IO; +using System.Linq; +using System.Reflection.Emit; +using Torch.Managers.PatchManager.Transpile; + +namespace Torch.Managers.PatchManager.MSIL +{ + /// + /// Represents the operand for an inline switch statement + /// + public class MsilOperandSwitch : MsilOperand + { + internal MsilOperandSwitch(MsilInstruction instruction) : base(instruction) + { + } + + /// + /// The target labels for this switch + /// + public MsilLabel[] Labels { get; set; } + + internal override void Read(MethodContext context, BinaryReader reader) + { + int length = reader.ReadInt32(); + int offset = (int) reader.BaseStream.Position + 4 * length; + Labels = new MsilLabel[length]; + for (var i = 0; i < Labels.Length; i++) + Labels[i] = context.LabelAt(offset + reader.ReadInt32()); + } + + internal override void Emit(LoggingIlGenerator generator) + { + generator.Emit(Instruction.OpCode, Labels?.Select(x => x.LabelFor(generator))?.ToArray() ?? new Label[0]); + } + } +} \ No newline at end of file diff --git a/Torch/Managers/PatchManager/MethodRewritePattern.cs b/Torch/Managers/PatchManager/MethodRewritePattern.cs new file mode 100644 index 0000000..b52226d --- /dev/null +++ b/Torch/Managers/PatchManager/MethodRewritePattern.cs @@ -0,0 +1,172 @@ +using System; +using System.Collections; +using System.Collections.Generic; +using System.Reflection; +using System.Threading; + +namespace Torch.Managers.PatchManager +{ + /// + /// Defines the different components used to rewrite a method. + /// + public class MethodRewritePattern + { + /// + /// Sorts methods so that their priority is in descending order. Assumes priority zero if no attribute exists. + /// + private class MethodPriorityCompare : Comparer + { + internal static readonly MethodPriorityCompare Instance = new MethodPriorityCompare(); + + public override int Compare(MethodInfo x, MethodInfo y) + { + return -(x?.GetCustomAttribute()?.Priority ?? 0).CompareTo( + y?.GetCustomAttribute()?.Priority ?? 0); + } + } + + /// + /// Stores an set of methods according to a certain order. + /// + public class MethodRewriteSet : IEnumerable + { + private readonly MethodRewriteSet _backingSet; + private bool _sortDirty = false; + private readonly List _backingList = new List(); + + private int _hasChanges = 0; + + internal bool HasChanges() + { + return Interlocked.Exchange(ref _hasChanges, 0) != 0; + } + + /// + /// + /// The set to track changes on + internal MethodRewriteSet(MethodRewriteSet backingSet) + { + _backingSet = backingSet; + } + + /// + /// Adds the given method to this set if it doesn't already exist in the tracked set and this set. + /// + /// Method to add + /// true if added + public bool Add(MethodInfo m) + { + if (!m.IsStatic) + throw new ArgumentException("Patch methods must be static"); + if (_backingSet != null && !_backingSet.Add(m)) + return false; + if (_backingList.Contains(m)) + return false; + _sortDirty = true; + Interlocked.Exchange(ref _hasChanges, 1); + _backingList.Add(m); + return true; + } + + /// + /// Removes the given method from this set, and from the tracked set if it existed in this set. + /// + /// Method to remove + /// true if removed + public bool Remove(MethodInfo m) + { + if (_backingList.Remove(m)) + { + _sortDirty = true; + Interlocked.Exchange(ref _hasChanges, 1); + return _backingSet == null || _backingSet.Remove(m); + } + return false; + } + + /// + /// Removes all methods from this set, and their matches in the tracked set. + /// + public void RemoveAll() + { + foreach (var k in _backingList) + _backingSet.Remove(k); + _backingList.Clear(); + _sortDirty = true; + Interlocked.Exchange(ref _hasChanges, 1); + } + + /// + /// Gets the number of methods stored in this set. + /// + public int Count => _backingList.Count; + + /// + /// Gets an ordered enumerator over this set + /// + /// + public IEnumerator GetEnumerator() + { + CheckSort(); + return _backingList.GetEnumerator(); + } + + /// + /// Gets an ordered enumerator over this set + /// + /// + IEnumerator IEnumerable.GetEnumerator() + { + CheckSort(); + return _backingList.GetEnumerator(); + } + + private void CheckSort() + { + if (!_sortDirty) + return; + var tmp = _backingList.ToArray(); + MergeSort(tmp, _backingList, MethodPriorityCompare.Instance, 0, _backingList.Count); + _sortDirty = false; + } + + private static void MergeSort(IList src, IList dst, Comparer comparer, int left, int right) + { + if (left + 1 >= right) + return; + var mid = (left + right) / 2; + MergeSort(dst, src, comparer, left, mid); + MergeSort(dst, src, comparer, mid, right); + for (int i = left, j = left, k = mid; i < right; i++) + if ((k >= right || j < mid) && comparer.Compare(src[j], src[k]) <= 0) + dst[i] = src[j++]; + else + dst[i] = src[k++]; + } + } + + /// + /// Methods run before the original method is run. If they return false the original method is skipped. + /// + public MethodRewriteSet Prefixes { get; } + /// + /// Methods capable of accepting one and returing another, modified. + /// + public MethodRewriteSet Transpilers { get; } + /// + /// Methods run after the original method has run. + /// + public MethodRewriteSet Suffixes { get; } + + /// + /// + /// + /// The pattern to track changes on, or null + public MethodRewritePattern(MethodRewritePattern parentPattern) + { + Prefixes = new MethodRewriteSet(parentPattern?.Prefixes); + Transpilers = new MethodRewriteSet(parentPattern?.Transpilers); + Suffixes = new MethodRewriteSet(parentPattern?.Suffixes); + } + } +} diff --git a/Torch/Managers/PatchManager/PatchContext.cs b/Torch/Managers/PatchManager/PatchContext.cs new file mode 100644 index 0000000..a213d99 --- /dev/null +++ b/Torch/Managers/PatchManager/PatchContext.cs @@ -0,0 +1,44 @@ +using System.Collections.Generic; +using System.Reflection; + +namespace Torch.Managers.PatchManager +{ + /// + /// Represents a set of common patches that can all be reversed in a single step. + /// + public class PatchContext + { + private readonly PatchManager _replacer; + private readonly Dictionary _rewritePatterns = new Dictionary(); + + internal PatchContext(PatchManager replacer) + { + _replacer = replacer; + } + + /// + /// Gets the rewrite pattern used to tracking changes in this context, creating one if it doesn't exist. + /// + /// Method to get the pattern for + /// + public MethodRewritePattern GetPattern(MethodBase method) + { + if (_rewritePatterns.TryGetValue(method, out MethodRewritePattern pattern)) + return pattern; + MethodRewritePattern parent = _replacer.GetPattern(method); + var res = new MethodRewritePattern(parent); + _rewritePatterns.Add(method, res); + return res; + } + + internal void RemoveAll() + { + foreach (MethodRewritePattern pattern in _rewritePatterns.Values) + { + pattern.Prefixes.RemoveAll(); + pattern.Transpilers.RemoveAll(); + pattern.Suffixes.RemoveAll(); + } + } + } +} diff --git a/Torch/Managers/PatchManager/PatchManager.cs b/Torch/Managers/PatchManager/PatchManager.cs new file mode 100644 index 0000000..86b0905 --- /dev/null +++ b/Torch/Managers/PatchManager/PatchManager.cs @@ -0,0 +1,86 @@ +using System.Collections.Generic; +using System.Reflection; +using Torch.API; + +namespace Torch.Managers.PatchManager +{ + /// + /// Applies and removes patches from the IL of methods. + /// + public class PatchManager : Manager + { + /// + /// Creates a new patch manager. Only have one active at a time. + /// + /// + public PatchManager(ITorchBase torchInstance) : base(torchInstance) + { + } + + private readonly Dictionary _rewritePatterns = new Dictionary(); + private readonly HashSet _contexts = new HashSet(); + + /// + /// Gets the rewrite pattern for the given method, creating one if it doesn't exist. + /// + /// Method to get the pattern for + /// + public MethodRewritePattern GetPattern(MethodBase method) + { + if (_rewritePatterns.TryGetValue(method, out DecoratedMethod pattern)) + return pattern; + var res = new DecoratedMethod(method); + _rewritePatterns.Add(method, res); + return res; + } + + + /// + /// Creates a new used for tracking changes. A call to will apply the patches. + /// + public PatchContext AcquireContext() + { + var context = new PatchContext(this); + _contexts.Add(context); + return context; + } + + /// + /// Frees the given context, and unregister all patches from it. A call to will apply the unpatching operation. + /// + /// Context to remove + public void FreeContext(PatchContext context) + { + context.RemoveAll(); + _contexts.Remove(context); + } + + /// + /// Commits all method decorations into IL. + /// + public void Commit() + { + foreach (DecoratedMethod m in _rewritePatterns.Values) + m.Commit(); + } + + /// + /// Commits any existing patches. + /// + public override void Attach() + { + Commit(); + } + + /// + /// Unregisters and removes all patches, then applies the unpatching operation. + /// + public override void Detach() + { + foreach (DecoratedMethod m in _rewritePatterns.Values) + m.Revert(); + _rewritePatterns.Clear(); + _contexts.Clear(); + } + } +} diff --git a/Torch/Managers/PatchManager/PatchPriorityAttribute.cs b/Torch/Managers/PatchManager/PatchPriorityAttribute.cs new file mode 100644 index 0000000..a097693 --- /dev/null +++ b/Torch/Managers/PatchManager/PatchPriorityAttribute.cs @@ -0,0 +1,24 @@ +using System; + +namespace Torch.Managers.PatchManager +{ + /// + /// Attribute used to decorate methods used for replacement. + /// + [AttributeUsage(AttributeTargets.Method)] + public class PatchPriorityAttribute : Attribute + { + /// + /// + /// + public PatchPriorityAttribute(int priority) + { + Priority = priority; + } + + /// + /// The priority of this replacement. A high priority prefix occurs first, and a high priority suffix or transpiler occurs last. + /// + public int Priority { get; set; } = 0; + } +} diff --git a/Torch/Managers/PatchManager/Transpile/LoggingILGenerator.cs b/Torch/Managers/PatchManager/Transpile/LoggingILGenerator.cs new file mode 100644 index 0000000..0f49293 --- /dev/null +++ b/Torch/Managers/PatchManager/Transpile/LoggingILGenerator.cs @@ -0,0 +1,173 @@ +using System; +using System.Linq; +using System.Reflection; +using System.Reflection.Emit; +using NLog; + +// ReSharper disable ConditionIsAlwaysTrueOrFalse +#pragma warning disable 162 // unreachable code +namespace Torch.Managers.PatchManager.Transpile +{ + /// + /// An ILGenerator that can log emit calls when is enabled. + /// + public class LoggingIlGenerator + { + private const bool _logging = false; + private static readonly Logger _log = LogManager.GetCurrentClassLogger(); + + /// + /// Backing generator + /// + public ILGenerator Backing { get; } + + /// + /// Creates a new logging IL generator backed by the given generator. + /// + /// Backing generator + public LoggingIlGenerator(ILGenerator backing) + { + Backing = backing; + } + + /// + public LocalBuilder DeclareLocal(Type localType, bool isPinned = false) + { + LocalBuilder res = Backing.DeclareLocal(localType, isPinned); + if (_logging) + _log.Trace($"DeclareLocal {res.LocalType} {res.IsPinned} => {res.LocalIndex}"); + return res; + } + + + /// + public void Emit(OpCode op) + { + if (_logging) + _log.Trace($"Emit {op}"); + Backing.Emit(op); + } + + /// + public void Emit(OpCode op, LocalBuilder arg) + { + if (_logging) + _log.Trace($"Emit {op} L:{arg.LocalIndex} {arg.LocalType}"); + Backing.Emit(op, arg); + } + + /// + public void Emit(OpCode op, int arg) + { + if (_logging) + _log.Trace($"Emit {op} {arg}"); + Backing.Emit(op, arg); + } + + /// + public void Emit(OpCode op, long arg) + { + if (_logging) + _log.Trace($"Emit {op} {arg}"); + Backing.Emit(op, arg); + } + + /// + public void Emit(OpCode op, float arg) + { + if (_logging) + _log.Trace($"Emit {op} {arg}"); + Backing.Emit(op, arg); + } + + /// + public void Emit(OpCode op, double arg) + { + if (_logging) + _log.Trace($"Emit {op} {arg}"); + Backing.Emit(op, arg); + } + + /// + public void Emit(OpCode op, string arg) + { + if (_logging) + _log.Trace($"Emit {op} {arg}"); + Backing.Emit(op, arg); + } + + /// + public void Emit(OpCode op, Type arg) + { + if (_logging) + _log.Trace($"Emit {op} {arg}"); + Backing.Emit(op, arg); + } + + /// + public void Emit(OpCode op, FieldInfo arg) + { + if (_logging) + _log.Trace($"Emit {op} {arg}"); + Backing.Emit(op, arg); + } + + /// + public void Emit(OpCode op, MethodInfo arg) + { + if (_logging) + _log.Trace($"Emit {op} {arg}"); + Backing.Emit(op, arg); + } + + private static FieldInfo _labelID = + typeof(Label).GetField("m_label", BindingFlags.Instance | BindingFlags.NonPublic); + + /// + public void Emit(OpCode op, Label arg) + { + if (_logging) + _log.Trace($"Emit {op} L:{_labelID.GetValue(arg)}"); + Backing.Emit(op, arg); + } + + /// + public void Emit(OpCode op, Label[] arg) + { + if (_logging) + _log.Trace($"Emit {op} {string.Join(", ", arg.Select(x => "L:" + _labelID.GetValue(x)))}"); + Backing.Emit(op, arg); + } + + /// + public void Emit(OpCode op, SignatureHelper arg) + { + if (_logging) + _log.Trace($"Emit {op} {arg}"); + Backing.Emit(op, arg); + } + + /// + public void Emit(OpCode op, ConstructorInfo arg) + { + if (_logging) + _log.Trace($"Emit {op} {arg}"); + Backing.Emit(op, arg); + } + + /// + public void MarkLabel(Label label) + { + if (_logging) + _log.Trace($"MarkLabel L:{_labelID.GetValue(label)}"); + Backing.MarkLabel(label); + } + + /// + public Label DefineLabel() + { + return Backing.DefineLabel(); + } + } +#pragma warning restore 162 +} diff --git a/Torch/Managers/PatchManager/Transpile/MethodContext.cs b/Torch/Managers/PatchManager/Transpile/MethodContext.cs new file mode 100644 index 0000000..afa230c --- /dev/null +++ b/Torch/Managers/PatchManager/Transpile/MethodContext.cs @@ -0,0 +1,110 @@ +using System; +using System.Collections.Generic; +using System.Diagnostics; +using System.IO; +using System.Linq; +using System.Reflection; +using System.Reflection.Emit; +using NLog; +using Torch.Managers.PatchManager.MSIL; + +namespace Torch.Managers.PatchManager.Transpile +{ + internal class MethodContext + { + private static readonly Logger _log = LogManager.GetCurrentClassLogger(); + + public readonly MethodBase Method; + private readonly byte[] _msilBytes; + + internal Dictionary Labels { get; } = new Dictionary(); + private readonly List _instructions = new List(); + public IReadOnlyList Instructions => _instructions; + + internal ITokenResolver TokenResolver { get; } + + internal MsilLabel LabelAt(int i) + { + if (Labels.TryGetValue(i, out MsilLabel label)) + return label; + Labels.Add(i, label = new MsilLabel()); + return label; + } + + public MethodContext(MethodBase method) + { + Method = method; + _msilBytes = Method.GetMethodBody().GetILAsByteArray(); + TokenResolver = new NormalTokenResolver(method); + } + + public void Read() + { + ReadInstructions(); + ResolveLabels(); + } + + private void ReadInstructions() + { + Labels.Clear(); + _instructions.Clear(); + using (var memory = new MemoryStream(_msilBytes)) + using (var reader = new BinaryReader(memory)) + while (memory.Length > memory.Position) + { + var instructionValue = (short)memory.ReadByte(); + if (Prefixes.Contains(instructionValue)) + instructionValue = (short)((instructionValue << 8) | memory.ReadByte()); + if (!OpCodeLookup.TryGetValue(instructionValue, out OpCode opcode)) + throw new Exception($"Unknown opcode {instructionValue:X}"); + var instruction = new MsilInstruction(opcode) + { + Offset = (int) memory.Position + }; + _instructions.Add(instruction); + instruction.Operand?.Read(this, reader); + } + } + + private void ResolveLabels() + { + foreach (var label in Labels) + { + int min = 0, max = _instructions.Count - 1; + while (min <= max) + { + var mid = min + ((max - min) / 2); + if (label.Key < _instructions[mid].Offset) + max = mid - 1; + else + min = mid + 1; + } + _instructions[min]?.Labels?.Add(label.Value); + } + } + + public string ToHumanMsil() + { + return string.Join("\n", _instructions.Select(x => x.Offset + "\t" + x)); + } + + private static readonly Dictionary OpCodeLookup; + private static readonly HashSet Prefixes; + + static MethodContext() + { + OpCodeLookup = new Dictionary(); + Prefixes = new HashSet(); + foreach (FieldInfo field in typeof(OpCodes).GetFields(BindingFlags.Static | BindingFlags.Public)) + { + var opcode = (OpCode)field.GetValue(null); + if (opcode.OpCodeType != OpCodeType.Nternal) + OpCodeLookup.Add(opcode.Value, opcode); + if ((ushort) opcode.Value > 0xFF) + { + Prefixes.Add((short) ((ushort) opcode.Value >> 8)); + } + } + } + } +} diff --git a/Torch/Managers/PatchManager/Transpile/MethodTranspiler.cs b/Torch/Managers/PatchManager/Transpile/MethodTranspiler.cs new file mode 100644 index 0000000..9709ed6 --- /dev/null +++ b/Torch/Managers/PatchManager/Transpile/MethodTranspiler.cs @@ -0,0 +1,61 @@ +using System; +using System.Collections.Generic; +using System.Reflection; +using System.Reflection.Emit; +using Torch.Managers.PatchManager.MSIL; + +namespace Torch.Managers.PatchManager.Transpile +{ + internal class MethodTranspiler + { + internal static void Transpile(MethodBase baseMethod, IEnumerable transpilers, LoggingIlGenerator output, Label? retLabel) + { + var context = new MethodContext(baseMethod); + context.Read(); + var methodContent = (IEnumerable) context.Instructions; + foreach (var transpiler in transpilers) + methodContent = (IEnumerable)transpiler.Invoke(null, new object[] { methodContent }); + methodContent = FixBranchAndReturn(methodContent, retLabel); + foreach (var k in methodContent) + k.Emit(output); + } + + private static IEnumerable FixBranchAndReturn(IEnumerable insn, Label? retTarget) + { + foreach (var i in insn) + { + if (retTarget.HasValue && i.OpCode == OpCodes.Ret) + { + var j = new MsilInstruction(OpCodes.Br); + ((MsilOperandBrTarget)j.Operand).Target = new MsilLabel(retTarget.Value); + yield return j; + continue; + } + if (_opcodeReplaceRule.TryGetValue(i.OpCode, out OpCode replaceOpcode)) + { + yield return new MsilInstruction(replaceOpcode) { Operand = i.Operand }; + continue; + } + yield return i; + } + } + + private static readonly Dictionary _opcodeReplaceRule; + static MethodTranspiler() + { + _opcodeReplaceRule = new Dictionary(); + foreach (var field in typeof(OpCodes).GetFields(BindingFlags.Static | BindingFlags.Public)) + { + var opcode = (OpCode)field.GetValue(null); + if (opcode.OperandType == OperandType.ShortInlineBrTarget && + opcode.Name.EndsWith("_S", StringComparison.OrdinalIgnoreCase)) + { + var other = (OpCode?) typeof(OpCodes).GetField(field.Name.Substring(0, field.Name.Length - 2), + BindingFlags.Static | BindingFlags.Public)?.GetValue(null); + if (other.HasValue && other.Value.OperandType == OperandType.InlineBrTarget) + _opcodeReplaceRule.Add(opcode, other.Value); + } + } + } + } +} diff --git a/Torch/Torch.csproj b/Torch/Torch.csproj index 77002b8..32e91d1 100644 --- a/Torch/Torch.csproj +++ b/Torch/Torch.csproj @@ -21,6 +21,7 @@ x64 prompt MinimumRecommendedRules.ruleset + true $(SolutionDir)\bin\x64\Release\ @@ -31,6 +32,7 @@ prompt MinimumRecommendedRules.ruleset $(SolutionDir)\bin\x64\Release\Torch.xml + true @@ -156,6 +158,23 @@ + + + + + + + + + + + + + + + + + From 837b56462f2620efa1c9fe1eb2dacb9869fd90cc Mon Sep 17 00:00:00 2001 From: Westin Miller Date: Sat, 9 Sep 2017 00:50:45 -0700 Subject: [PATCH 2/5] Use reflected manager for patch internals --- Torch/Managers/PatchManager/AssemblyMemory.cs | 11 ++++++--- .../Managers/PatchManager/DecoratedMethod.cs | 23 +++++++++++-------- 2 files changed, 22 insertions(+), 12 deletions(-) diff --git a/Torch/Managers/PatchManager/AssemblyMemory.cs b/Torch/Managers/PatchManager/AssemblyMemory.cs index 0998384..4504993 100644 --- a/Torch/Managers/PatchManager/AssemblyMemory.cs +++ b/Torch/Managers/PatchManager/AssemblyMemory.cs @@ -3,11 +3,17 @@ using System.Reflection; using System.Reflection.Emit; using System.Runtime.CompilerServices; using System.Runtime.InteropServices; +using Torch.Utils; namespace Torch.Managers.PatchManager { internal class AssemblyMemory { +#pragma warning disable 649 + [ReflectedMethod(Name = "GetMethodDescriptor")] + private static Func _getMethodHandle; +#pragma warning restore 649 + /// /// Gets the address, in RAM, where the body of a method starts. /// @@ -16,9 +22,8 @@ namespace Torch.Managers.PatchManager public static long GetMethodBodyStart(MethodBase method) { RuntimeMethodHandle handle; - if (method is DynamicMethod) - handle = (RuntimeMethodHandle)typeof(DynamicMethod).GetMethod("GetMethodDescriptor", BindingFlags.NonPublic | BindingFlags.Instance) - .Invoke(method, new object[0]); + if (method is DynamicMethod dyn) + handle = _getMethodHandle.Invoke(dyn); else handle = method.MethodHandle; RuntimeHelpers.PrepareMethod(handle); diff --git a/Torch/Managers/PatchManager/DecoratedMethod.cs b/Torch/Managers/PatchManager/DecoratedMethod.cs index 326f004..75d09ff 100644 --- a/Torch/Managers/PatchManager/DecoratedMethod.cs +++ b/Torch/Managers/PatchManager/DecoratedMethod.cs @@ -8,6 +8,7 @@ using System.Runtime.CompilerServices; using System.Runtime.InteropServices; using NLog; using Torch.Managers.PatchManager.Transpile; +using Torch.Utils; namespace Torch.Managers.PatchManager { @@ -76,21 +77,25 @@ namespace Torch.Managers.PatchManager private const string INSTANCE_PARAMETER = "__instance"; private const string RESULT_PARAMETER = "__result"; +#pragma warning disable 649 + [ReflectedStaticMethod(Type = typeof(RuntimeHelpers), Name = "_CompileMethod", OverrideTypeNames = new[] { "System.IRuntimeMethodInfo" })] + private static Action _compileDynamicMethod; + [ReflectedMethod(Name = "GetMethodInfo")] + private static Func _getMethodInfo; + [ReflectedMethod(Name = "GetMethodDescriptor")] + private static Func _getMethodHandle; +#pragma warning restore 649 + public DynamicMethod ComposePatchedMethod() { - var method = AllocatePatchMethod(); + DynamicMethod method = AllocatePatchMethod(); var generator = new LoggingIlGenerator(method.GetILGenerator()); EmitPatched(generator); // Force it to compile - const BindingFlags nonPublicInstance = BindingFlags.NonPublic | BindingFlags.Instance; - const BindingFlags nonPublicStatic = BindingFlags.NonPublic | BindingFlags.Static; - var compileMethod = typeof(RuntimeHelpers).GetMethod("_CompileMethod", nonPublicStatic); - var getMethodDescriptor = typeof(DynamicMethod).GetMethod("GetMethodDescriptor", nonPublicInstance); - var handle = (RuntimeMethodHandle)getMethodDescriptor.Invoke(method, new object[0]); - var getMethodInfo = typeof(RuntimeMethodHandle).GetMethod("GetMethodInfo", nonPublicInstance); - var runtimeMethodInfo = getMethodInfo.Invoke(handle, new object[0]); - compileMethod.Invoke(null, new[] { runtimeMethodInfo }); + RuntimeMethodHandle handle = _getMethodHandle.Invoke(method); + object runtimeMethodInfo = _getMethodInfo.Invoke(handle); + _compileDynamicMethod.Invoke(runtimeMethodInfo); return method; } #endregion From 4f84cd8963c9018085cb5d4741f58e2dde33f830 Mon Sep 17 00:00:00 2001 From: Westin Miller Date: Sat, 9 Sep 2017 22:15:42 -0700 Subject: [PATCH 3/5] Normal logging for patch manager Fix: RE treating constructors as normal methods --- .../Managers/PatchManager/DecoratedMethod.cs | 46 ++++++++---- .../PatchManager/MSIL/MsilInstruction.cs | 2 +- .../PatchManager/MSIL/MsilOperandInline.cs | 14 ++-- .../PatchManager/MethodRewritePattern.cs | 2 +- .../Transpile/LoggingILGenerator.cs | 72 +++++++++---------- .../PatchManager/Transpile/MethodContext.cs | 10 ++- .../Transpile/MethodTranspiler.cs | 8 ++- 7 files changed, 96 insertions(+), 58 deletions(-) diff --git a/Torch/Managers/PatchManager/DecoratedMethod.cs b/Torch/Managers/PatchManager/DecoratedMethod.cs index 75d09ff..8c03138 100644 --- a/Torch/Managers/PatchManager/DecoratedMethod.cs +++ b/Torch/Managers/PatchManager/DecoratedMethod.cs @@ -112,8 +112,13 @@ namespace Torch.Managers.PatchManager var specialVariables = new Dictionary(); + Label? labelAfterOriginalContent = Suffixes.Count > 0 ? target.DefineLabel() : (Label?)null; + Label? labelAfterOriginalReturn = Prefixes.Any(x => x.ReturnType == typeof(bool)) ? target.DefineLabel() : (Label?)null; + + var returnType = _method is MethodInfo meth ? meth.ReturnType : typeof(void); - var resultVariable = returnType != typeof(void) && Prefixes.Concat(Suffixes).SelectMany(x => x.GetParameters()).Any(x => x.Name == RESULT_PARAMETER) + var resultVariable = returnType != typeof(void) && (labelAfterOriginalReturn.HasValue || // If we jump past main content we need local to store return val + Prefixes.Concat(Suffixes).SelectMany(x => x.GetParameters()).Any(x => x.Name == RESULT_PARAMETER)) ? target.DeclareLocal(returnType) : null; resultVariable?.SetToDefault(target); @@ -121,39 +126,54 @@ namespace Torch.Managers.PatchManager if (resultVariable != null) specialVariables.Add(RESULT_PARAMETER, resultVariable); - var labelAfterOriginalContent = target.DefineLabel(); - var labelAfterOriginalReturn = target.DefineLabel(); - + target.EmitComment("Prefixes Begin"); foreach (var prefix in Prefixes) { EmitMonkeyCall(target, prefix, specialVariables); if (prefix.ReturnType == typeof(bool)) - target.Emit(OpCodes.Brfalse, labelAfterOriginalReturn); + { + Debug.Assert(labelAfterOriginalReturn.HasValue); + target.Emit(OpCodes.Brfalse, labelAfterOriginalReturn.Value); + } else if (prefix.ReturnType != typeof(void)) - throw new Exception($"Prefixes must return void or bool. {prefix.DeclaringType?.FullName}.{prefix.Name} returns {prefix.ReturnType}"); + throw new Exception( + $"Prefixes must return void or bool. {prefix.DeclaringType?.FullName}.{prefix.Name} returns {prefix.ReturnType}"); } + target.EmitComment("Prefixes End"); + target.EmitComment("Original Begin"); MethodTranspiler.Transpile(_method, Transpilers, target, labelAfterOriginalContent); - target.MarkLabel(labelAfterOriginalContent); - if (resultVariable != null) - target.Emit(OpCodes.Stloc, resultVariable); - target.MarkLabel(labelAfterOriginalReturn); + target.EmitComment("Original End"); + if (labelAfterOriginalContent.HasValue) + { + target.MarkLabel(labelAfterOriginalContent.Value); + if (resultVariable != null) + target.Emit(OpCodes.Stloc, resultVariable); + } + if (labelAfterOriginalReturn.HasValue) + target.MarkLabel(labelAfterOriginalReturn.Value); + target.EmitComment("Suffixes Begin"); foreach (var suffix in Suffixes) { EmitMonkeyCall(target, suffix, specialVariables); if (suffix.ReturnType != typeof(void)) throw new Exception($"Suffixes must return void. {suffix.DeclaringType?.FullName}.{suffix.Name} returns {suffix.ReturnType}"); } + target.EmitComment("Suffixes End"); - if (resultVariable != null) - target.Emit(OpCodes.Ldloc, resultVariable); - target.Emit(OpCodes.Ret); + if (labelAfterOriginalContent.HasValue || labelAfterOriginalReturn.HasValue) + { + if (resultVariable != null) + target.Emit(OpCodes.Ldloc, resultVariable); + target.Emit(OpCodes.Ret); + } } private void EmitMonkeyCall(LoggingIlGenerator target, MethodInfo patch, IReadOnlyDictionary specialVariables) { + target.EmitComment($"Call {patch.DeclaringType?.FullName}#{patch.Name}"); foreach (var param in patch.GetParameters()) { switch (param.Name) diff --git a/Torch/Managers/PatchManager/MSIL/MsilInstruction.cs b/Torch/Managers/PatchManager/MSIL/MsilInstruction.cs index a179a6e..9ac0359 100644 --- a/Torch/Managers/PatchManager/MSIL/MsilInstruction.cs +++ b/Torch/Managers/PatchManager/MSIL/MsilInstruction.cs @@ -40,7 +40,7 @@ namespace Torch.Managers.PatchManager.MSIL Operand = new MsilOperandInline.MsilOperandInt64(this); break; case OperandType.InlineMethod: - Operand = new MsilOperandInline.MsilOperandReflected(this); + Operand = new MsilOperandInline.MsilOperandReflected(this); break; case OperandType.InlineR: Operand = new MsilOperandInline.MsilOperandDouble(this); diff --git a/Torch/Managers/PatchManager/MSIL/MsilOperandInline.cs b/Torch/Managers/PatchManager/MSIL/MsilOperandInline.cs index f4c45d1..e327054 100644 --- a/Torch/Managers/PatchManager/MSIL/MsilOperandInline.cs +++ b/Torch/Managers/PatchManager/MSIL/MsilOperandInline.cs @@ -1,4 +1,5 @@ using System; +using System.Diagnostics; using System.IO; using System.Reflection; using System.Reflection.Emit; @@ -257,23 +258,28 @@ namespace Torch.Managers.PatchManager.MSIL internal override void Read(MethodContext context, BinaryReader reader) { + object value = null; switch (Instruction.OpCode.OperandType) { case OperandType.InlineTok: - Value = context.TokenResolver.ResolveMember(reader.ReadInt32()) as TY; + value = context.TokenResolver.ResolveMember(reader.ReadInt32()); break; case OperandType.InlineType: - Value = context.TokenResolver.ResolveType(reader.ReadInt32()) as TY; + value = context.TokenResolver.ResolveType(reader.ReadInt32()); break; case OperandType.InlineMethod: - Value = context.TokenResolver.ResolveMethod(reader.ReadInt32()) as TY; + value = context.TokenResolver.ResolveMethod(reader.ReadInt32()); break; case OperandType.InlineField: - Value = context.TokenResolver.ResolveField(reader.ReadInt32()) as TY; + value = context.TokenResolver.ResolveField(reader.ReadInt32()); break; default: throw new ArgumentException("Reflected operand only applies to inline reflected types"); } + if (value is TY vty) + Value = vty; + else + throw new Exception($"Expected type {typeof(TY).Name} from operand {Instruction.OpCode.OperandType}, got {value.GetType()?.Name ?? "null"}"); } internal override void Emit(LoggingIlGenerator generator) diff --git a/Torch/Managers/PatchManager/MethodRewritePattern.cs b/Torch/Managers/PatchManager/MethodRewritePattern.cs index b52226d..3696371 100644 --- a/Torch/Managers/PatchManager/MethodRewritePattern.cs +++ b/Torch/Managers/PatchManager/MethodRewritePattern.cs @@ -90,7 +90,7 @@ namespace Torch.Managers.PatchManager public void RemoveAll() { foreach (var k in _backingList) - _backingSet.Remove(k); + _backingSet?.Remove(k); _backingList.Clear(); _sortDirty = true; Interlocked.Exchange(ref _hasChanges, 1); diff --git a/Torch/Managers/PatchManager/Transpile/LoggingILGenerator.cs b/Torch/Managers/PatchManager/Transpile/LoggingILGenerator.cs index 0f49293..a170d0c 100644 --- a/Torch/Managers/PatchManager/Transpile/LoggingILGenerator.cs +++ b/Torch/Managers/PatchManager/Transpile/LoggingILGenerator.cs @@ -1,19 +1,22 @@ using System; +using System.Diagnostics; using System.Linq; using System.Reflection; using System.Reflection.Emit; using NLog; +using Torch.Utils; // ReSharper disable ConditionIsAlwaysTrueOrFalse #pragma warning disable 162 // unreachable code namespace Torch.Managers.PatchManager.Transpile { /// - /// An ILGenerator that can log emit calls when is enabled. + /// An ILGenerator that can log emit calls when the TRACE level is enabled. /// public class LoggingIlGenerator { - private const bool _logging = false; + private const int _opcodePadding = -10; + private static readonly Logger _log = LogManager.GetCurrentClassLogger(); /// @@ -34,8 +37,7 @@ namespace Torch.Managers.PatchManager.Transpile public LocalBuilder DeclareLocal(Type localType, bool isPinned = false) { LocalBuilder res = Backing.DeclareLocal(localType, isPinned); - if (_logging) - _log.Trace($"DeclareLocal {res.LocalType} {res.IsPinned} => {res.LocalIndex}"); + _log.Trace($"DclLoc\t{res.LocalIndex}\t=> {res.LocalType} {res.IsPinned}"); return res; } @@ -43,123 +45,111 @@ namespace Torch.Managers.PatchManager.Transpile /// public void Emit(OpCode op) { - if (_logging) - _log.Trace($"Emit {op}"); + _log.Trace($"Emit\t{op,_opcodePadding}"); Backing.Emit(op); } /// public void Emit(OpCode op, LocalBuilder arg) { - if (_logging) - _log.Trace($"Emit {op} L:{arg.LocalIndex} {arg.LocalType}"); + _log.Trace($"Emit\t{op,_opcodePadding} L:{arg.LocalIndex} {arg.LocalType}"); Backing.Emit(op, arg); } /// public void Emit(OpCode op, int arg) { - if (_logging) - _log.Trace($"Emit {op} {arg}"); + _log.Trace($"Emit\t{op,_opcodePadding} {arg}"); Backing.Emit(op, arg); } /// public void Emit(OpCode op, long arg) { - if (_logging) - _log.Trace($"Emit {op} {arg}"); + _log.Trace($"Emit\t{op,_opcodePadding} {arg}"); Backing.Emit(op, arg); } /// public void Emit(OpCode op, float arg) { - if (_logging) - _log.Trace($"Emit {op} {arg}"); + _log.Trace($"Emit\t{op,_opcodePadding} {arg}"); Backing.Emit(op, arg); } /// public void Emit(OpCode op, double arg) { - if (_logging) - _log.Trace($"Emit {op} {arg}"); + _log.Trace($"Emit\t{op,_opcodePadding} {arg}"); Backing.Emit(op, arg); } /// public void Emit(OpCode op, string arg) { - if (_logging) - _log.Trace($"Emit {op} {arg}"); + _log.Trace($"Emit\t{op,_opcodePadding} {arg}"); Backing.Emit(op, arg); } /// public void Emit(OpCode op, Type arg) { - if (_logging) - _log.Trace($"Emit {op} {arg}"); + _log.Trace($"Emit\t{op,_opcodePadding} {arg}"); Backing.Emit(op, arg); } /// public void Emit(OpCode op, FieldInfo arg) { - if (_logging) - _log.Trace($"Emit {op} {arg}"); + _log.Trace($"Emit\t{op,_opcodePadding} {arg}"); Backing.Emit(op, arg); } /// public void Emit(OpCode op, MethodInfo arg) { - if (_logging) - _log.Trace($"Emit {op} {arg}"); + _log.Trace($"Emit\t{op,_opcodePadding} {arg}"); Backing.Emit(op, arg); } - private static FieldInfo _labelID = - typeof(Label).GetField("m_label", BindingFlags.Instance | BindingFlags.NonPublic); + +#pragma warning disable 649 + [ReflectedGetter(Name="m_label")] + private static Func _labelID; +#pragma warning restore 649 /// public void Emit(OpCode op, Label arg) { - if (_logging) - _log.Trace($"Emit {op} L:{_labelID.GetValue(arg)}"); + _log.Trace($"Emit\t{op,_opcodePadding}\tL:{_labelID.Invoke(arg)}"); Backing.Emit(op, arg); } /// public void Emit(OpCode op, Label[] arg) { - if (_logging) - _log.Trace($"Emit {op} {string.Join(", ", arg.Select(x => "L:" + _labelID.GetValue(x)))}"); + _log.Trace($"Emit\t{op,_opcodePadding}\t{string.Join(", ", arg.Select(x => "L:" + _labelID.Invoke(x)))}"); Backing.Emit(op, arg); } /// public void Emit(OpCode op, SignatureHelper arg) { - if (_logging) - _log.Trace($"Emit {op} {arg}"); + _log.Trace($"Emit\t{op,_opcodePadding} {arg}"); Backing.Emit(op, arg); } /// public void Emit(OpCode op, ConstructorInfo arg) { - if (_logging) - _log.Trace($"Emit {op} {arg}"); + _log.Trace($"Emit\t{op,_opcodePadding} {arg}"); Backing.Emit(op, arg); } /// public void MarkLabel(Label label) { - if (_logging) - _log.Trace($"MarkLabel L:{_labelID.GetValue(label)}"); + _log.Trace($"MkLbl\tL:{_labelID.Invoke(label)}"); Backing.MarkLabel(label); } @@ -168,6 +158,16 @@ namespace Torch.Managers.PatchManager.Transpile { return Backing.DefineLabel(); } + + /// + /// Emits a comment to the log. + /// + /// Comment + [Conditional("DEBUG")] + public void EmitComment(string comment) + { + _log.Trace($"// {comment}"); + } } #pragma warning restore 162 } diff --git a/Torch/Managers/PatchManager/Transpile/MethodContext.cs b/Torch/Managers/PatchManager/Transpile/MethodContext.cs index afa230c..8c5be0d 100644 --- a/Torch/Managers/PatchManager/Transpile/MethodContext.cs +++ b/Torch/Managers/PatchManager/Transpile/MethodContext.cs @@ -52,11 +52,17 @@ namespace Torch.Managers.PatchManager.Transpile using (var reader = new BinaryReader(memory)) while (memory.Length > memory.Position) { + var count = 1; var instructionValue = (short)memory.ReadByte(); if (Prefixes.Contains(instructionValue)) - instructionValue = (short)((instructionValue << 8) | memory.ReadByte()); + { + instructionValue = (short) ((instructionValue << 8) | memory.ReadByte()); + count++; + } if (!OpCodeLookup.TryGetValue(instructionValue, out OpCode opcode)) throw new Exception($"Unknown opcode {instructionValue:X}"); + if (opcode.Size != count) + throw new Exception($"Opcode said it was {opcode.Size} but we read {count}"); var instruction = new MsilInstruction(opcode) { Offset = (int) memory.Position @@ -85,7 +91,7 @@ namespace Torch.Managers.PatchManager.Transpile public string ToHumanMsil() { - return string.Join("\n", _instructions.Select(x => x.Offset + "\t" + x)); + return string.Join("\n", _instructions.Select(x => $"IL_{x.Offset:X4}: {x}")); } private static readonly Dictionary OpCodeLookup; diff --git a/Torch/Managers/PatchManager/Transpile/MethodTranspiler.cs b/Torch/Managers/PatchManager/Transpile/MethodTranspiler.cs index 9709ed6..d9a37eb 100644 --- a/Torch/Managers/PatchManager/Transpile/MethodTranspiler.cs +++ b/Torch/Managers/PatchManager/Transpile/MethodTranspiler.cs @@ -2,16 +2,22 @@ using System.Collections.Generic; using System.Reflection; using System.Reflection.Emit; +using NLog; using Torch.Managers.PatchManager.MSIL; namespace Torch.Managers.PatchManager.Transpile { internal class MethodTranspiler { + private static readonly Logger _log = LogManager.GetCurrentClassLogger(); + internal static void Transpile(MethodBase baseMethod, IEnumerable transpilers, LoggingIlGenerator output, Label? retLabel) { var context = new MethodContext(baseMethod); context.Read(); + _log.Trace("Input Method:"); + _log.Trace(context.ToHumanMsil); + var methodContent = (IEnumerable) context.Instructions; foreach (var transpiler in transpilers) methodContent = (IEnumerable)transpiler.Invoke(null, new object[] { methodContent }); @@ -48,7 +54,7 @@ namespace Torch.Managers.PatchManager.Transpile { var opcode = (OpCode)field.GetValue(null); if (opcode.OperandType == OperandType.ShortInlineBrTarget && - opcode.Name.EndsWith("_S", StringComparison.OrdinalIgnoreCase)) + opcode.Name.EndsWith(".s", StringComparison.OrdinalIgnoreCase)) { var other = (OpCode?) typeof(OpCodes).GetField(field.Name.Substring(0, field.Name.Length - 2), BindingFlags.Static | BindingFlags.Public)?.GetValue(null); From 9a68ed6bd0a6697c30015afa53031a364fcf3c4c Mon Sep 17 00:00:00 2001 From: Westin Miller Date: Sat, 9 Sep 2017 23:30:49 -0700 Subject: [PATCH 4/5] Added patching tests Fixed error when reverting patches Made the LoggingILGenerator not break without a logger --- Torch.Tests/PatchTest.cs | 386 ++++++++++++++++++ Torch.Tests/Torch.Tests.csproj | 1 + .../Managers/PatchManager/DecoratedMethod.cs | 1 + .../Transpile/LoggingILGenerator.cs | 34 +- 4 files changed, 405 insertions(+), 17 deletions(-) create mode 100644 Torch.Tests/PatchTest.cs diff --git a/Torch.Tests/PatchTest.cs b/Torch.Tests/PatchTest.cs new file mode 100644 index 0000000..1ee4225 --- /dev/null +++ b/Torch.Tests/PatchTest.cs @@ -0,0 +1,386 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Reflection; +using System.Reflection.Emit; +using System.Runtime.CompilerServices; +using System.Text; +using Torch.Managers.PatchManager; +using Torch.Managers.PatchManager.MSIL; +using Torch.Utils; +using Xunit; + +// ReSharper disable UnusedMember.Local +namespace Torch.Tests +{ +#pragma warning disable 414 + public class PatchTest + { + #region TestRunner + private static readonly PatchManager _patchContext = new PatchManager(null); + + [Theory] + [MemberData(nameof(Prefixes))] + public void TestPrefix(TestBootstrap runner) + { + runner.TestPrefix(); + } + + [Theory] + [MemberData(nameof(Transpilers))] + public void TestTranspile(TestBootstrap runner) + { + runner.TestTranspile(); + } + + [Theory] + [MemberData(nameof(Suffixes))] + public void TestSuffix(TestBootstrap runner) + { + runner.TestSuffix(); + } + + [Theory] + [MemberData(nameof(Combo))] + public void TestCombo(TestBootstrap runner) + { + runner.TestCombo(); + } + + + + public class TestBootstrap + { + public bool HasPrefix => _prefixMethod != null; + public bool HasTranspile => _transpileMethod != null; + public bool HasSuffix => _suffixMethod != null; + + private readonly MethodInfo _prefixMethod, _prefixAssert; + private readonly MethodInfo _suffixMethod, _suffixAssert; + private readonly MethodInfo _transpileMethod, _transpileAssert; + private readonly MethodInfo _targetMethod, _targetAssert; + private readonly MethodInfo _resetMethod; + private readonly object _instance; + private readonly object[] _targetParams; + private readonly Type _type; + + public TestBootstrap(Type t) + { + const BindingFlags flags = BindingFlags.NonPublic | BindingFlags.Public | BindingFlags.Static | BindingFlags.Instance; + _type = t; + _prefixMethod = t.GetMethod("Prefix", flags); + _prefixAssert = t.GetMethod("AssertPrefix", flags); + _suffixMethod = t.GetMethod("Suffix", flags); + _suffixAssert = t.GetMethod("AssertSuffix", flags); + _transpileMethod = t.GetMethod("Transpile", flags); + _transpileAssert = t.GetMethod("AssertTranspile", flags); + _targetMethod = t.GetMethod("Target", flags); + _targetAssert = t.GetMethod("AssertNormal", flags); + _resetMethod = t.GetMethod("Reset", flags); + if (_targetMethod == null) + throw new Exception($"{t.FullName} must have a method named Target"); + if (_targetAssert == null) + throw new Exception($"{t.FullName} must have a method named AssertNormal"); + _instance = !_targetMethod.IsStatic ? Activator.CreateInstance(t) : null; + _targetParams = (object[])t.GetField("_targetParams", flags)?.GetValue(null) ?? new object[0]; + } + + private void Invoke(MethodBase i, params object[] args) + { + if (i == null) return; + i.Invoke(i.IsStatic ? null : _instance, args); + } + + private void Invoke() + { + _targetMethod.Invoke(_instance, _targetParams); + Invoke(_targetAssert); + } + + public void TestPrefix() + { + Invoke(_resetMethod); + PatchContext context = _patchContext.AcquireContext(); + context.GetPattern(_targetMethod).Prefixes.Add(_prefixMethod); + _patchContext.Commit(); + + Invoke(); + Invoke(_prefixAssert); + + _patchContext.FreeContext(context); + _patchContext.Commit(); + } + + public void TestSuffix() + { + Invoke(_resetMethod); + PatchContext context = _patchContext.AcquireContext(); + context.GetPattern(_targetMethod).Suffixes.Add(_suffixMethod); + _patchContext.Commit(); + + Invoke(); + Invoke(_suffixAssert); + + _patchContext.FreeContext(context); + _patchContext.Commit(); + } + + public void TestTranspile() + { + Invoke(_resetMethod); + PatchContext context = _patchContext.AcquireContext(); + context.GetPattern(_targetMethod).Transpilers.Add(_transpileMethod); + _patchContext.Commit(); + + Invoke(); + Invoke(_transpileAssert); + + _patchContext.FreeContext(context); + _patchContext.Commit(); + } + + public void TestCombo() + { + Invoke(_resetMethod); + PatchContext context = _patchContext.AcquireContext(); + if (_prefixMethod != null) + context.GetPattern(_targetMethod).Prefixes.Add(_prefixMethod); + if (_transpileMethod != null) + context.GetPattern(_targetMethod).Transpilers.Add(_transpileMethod); + if (_suffixMethod != null) + context.GetPattern(_targetMethod).Suffixes.Add(_suffixMethod); + _patchContext.Commit(); + + Invoke(); + Invoke(_prefixAssert); + Invoke(_transpileAssert); + Invoke(_suffixAssert); + + _patchContext.FreeContext(context); + _patchContext.Commit(); + } + + public override string ToString() + { + return _type.Name; + } + } + + private class PatchTestAttribute : Attribute + { + } + + private static readonly List _patchTest; + + static PatchTest() + { + TestUtils.Init(); + foreach (Type type in typeof(PatchManager).Assembly.GetTypes()) + if (type.Namespace?.StartsWith(typeof(PatchManager).Namespace ?? "") ?? false) + ReflectedManager.Process(type); + + _patchTest = new List(); + foreach (Type type in typeof(PatchTest).GetNestedTypes(BindingFlags.NonPublic)) + if (type.GetCustomAttribute(typeof(PatchTestAttribute)) != null) + _patchTest.Add(new TestBootstrap(type)); + } + + public static IEnumerable Prefixes => _patchTest.Where(x => x.HasPrefix).Select(x => new object[] { x }); + public static IEnumerable Transpilers => _patchTest.Where(x => x.HasTranspile).Select(x => new object[] { x }); + public static IEnumerable Suffixes => _patchTest.Where(x => x.HasSuffix).Select(x => new object[] { x }); + public static IEnumerable Combo => _patchTest.Where(x => x.HasPrefix || x.HasTranspile || x.HasSuffix).Select(x => new object[] { x }); + #endregion + + #region PatchTests + + [PatchTest] + private class StaticNoRetNoParm + { + private static bool _prefixHit, _normalHit, _suffixHit, _transpileHit; + + [MethodImpl(MethodImplOptions.NoInlining)] + public static void Prefix() + { + _prefixHit = true; + } + + [MethodImpl(MethodImplOptions.NoInlining)] + public static void Target() + { + _normalHit = true; + } + + [MethodImpl(MethodImplOptions.NoInlining)] + public static void Suffix() + { + _suffixHit = true; + } + + public static IEnumerable Transpile(IEnumerable instructions) + { + yield return new MsilInstruction(OpCodes.Ldnull); + yield return new MsilInstruction(OpCodes.Ldc_I4_1); + yield return new MsilInstruction(OpCodes.Stfld).InlineValue(typeof(StaticNoRetNoParm).GetField("_transpileHit", BindingFlags.Static | BindingFlags.NonPublic | BindingFlags.Public)); + foreach (MsilInstruction i in instructions) + yield return i; + } + + public static void Reset() + { + _prefixHit = _normalHit = _suffixHit = _transpileHit = false; + } + + public static void AssertTranspile() + { + Assert.True(_transpileHit, "Failed to transpile"); + } + + public static void AssertSuffix() + { + Assert.True(_suffixHit, "Failed to suffix"); + } + + public static void AssertNormal() + { + Assert.True(_normalHit, "Failed to execute normally"); + } + + public static void AssertPrefix() + { + Assert.True(_prefixHit, "Failed to prefix"); + } + } + + [PatchTest] + private class StaticNoRetParam + { + private static bool _prefixHit, _normalHit, _suffixHit; + private static readonly object[] _targetParams = { "test", 1, new StringBuilder("test1") }; + + [MethodImpl(MethodImplOptions.NoInlining)] + public static void Prefix(string str, int i, StringBuilder o) + { + Assert.Equal(_targetParams[0], str); + Assert.Equal(_targetParams[1], i); + Assert.Equal(_targetParams[2], o); + _prefixHit = true; + } + + [MethodImpl(MethodImplOptions.NoInlining)] + public static void Target(string str, int i, StringBuilder o) + { + _normalHit = true; + } + + [MethodImpl(MethodImplOptions.NoInlining)] + public static void Suffix(string str, int i, StringBuilder o) + { + Assert.Equal(_targetParams[0], str); + Assert.Equal(_targetParams[1], i); + Assert.Equal(_targetParams[2], o); + _suffixHit = true; + } + + public static void Reset() + { + _prefixHit = _normalHit = _suffixHit = false; + } + + public static void AssertSuffix() + { + Assert.True(_suffixHit, "Failed to suffix"); + } + + public static void AssertNormal() + { + Assert.True(_normalHit, "Failed to execute normally"); + } + + public static void AssertPrefix() + { + Assert.True(_prefixHit, "Failed to prefix"); + } + } + + [PatchTest] + private class StaticNoRetParamReplace + { + private static bool _prefixHit, _normalHit, _suffixHit; + private static readonly object[] _targetParams = { "test", 1, new StringBuilder("stest1") }; + private static readonly object[] _replacedParams = { "test2", 2, new StringBuilder("stest2") }; + private static object[] _calledParams; + + [MethodImpl(MethodImplOptions.NoInlining)] + public static void Prefix(ref string str, ref int i, ref StringBuilder o) + { + Assert.Equal(_targetParams[0], str); + Assert.Equal(_targetParams[1], i); + Assert.Equal(_targetParams[2], o); + str = (string)_replacedParams[0]; + i = (int)_replacedParams[1]; + o = (StringBuilder)_replacedParams[2]; + _prefixHit = true; + } + + [MethodImpl(MethodImplOptions.NoInlining)] + public static void Target(string str, int i, StringBuilder o) + { + _calledParams = new object[] { str, i, o }; + _normalHit = true; + } + + public static void Reset() + { + _prefixHit = _normalHit = _suffixHit = false; + } + + public static void AssertNormal() + { + Assert.True(_normalHit, "Failed to execute normally"); + } + + public static void AssertPrefix() + { + Assert.True(_prefixHit, "Failed to prefix"); + for (var i = 0; i < 3; i++) + Assert.Equal(_replacedParams[i], _calledParams[i]); + } + } + + [PatchTest] + private class StaticCancelExec + { + private static bool _prefixHit, _normalHit, _suffixHit; + + [MethodImpl(MethodImplOptions.NoInlining)] + public static bool Prefix() + { + _prefixHit = true; + return false; + } + + [MethodImpl(MethodImplOptions.NoInlining)] + public static void Target() + { + _normalHit = true; + } + + public static void Reset() + { + _prefixHit = _normalHit = _suffixHit = false; + } + + public static void AssertNormal() + { + Assert.False(_normalHit, "Executed normally when canceled"); + } + + public static void AssertPrefix() + { + Assert.True(_prefixHit, "Failed to prefix"); + } + } + #endregion + } +#pragma warning restore 414 +} diff --git a/Torch.Tests/Torch.Tests.csproj b/Torch.Tests/Torch.Tests.csproj index fc5f81a..d99bc0b 100644 --- a/Torch.Tests/Torch.Tests.csproj +++ b/Torch.Tests/Torch.Tests.csproj @@ -63,6 +63,7 @@ Properties\AssemblyVersion.cs + diff --git a/Torch/Managers/PatchManager/DecoratedMethod.cs b/Torch/Managers/PatchManager/DecoratedMethod.cs index 8c03138..025b61a 100644 --- a/Torch/Managers/PatchManager/DecoratedMethod.cs +++ b/Torch/Managers/PatchManager/DecoratedMethod.cs @@ -49,6 +49,7 @@ namespace Torch.Managers.PatchManager AssemblyMemory.WriteMemory(_revertAddress, _revertData); _revertData = null; _pinnedPatch.Value.Free(); + _pinnedPatch = null; } } diff --git a/Torch/Managers/PatchManager/Transpile/LoggingILGenerator.cs b/Torch/Managers/PatchManager/Transpile/LoggingILGenerator.cs index a170d0c..e18eb30 100644 --- a/Torch/Managers/PatchManager/Transpile/LoggingILGenerator.cs +++ b/Torch/Managers/PatchManager/Transpile/LoggingILGenerator.cs @@ -37,7 +37,7 @@ namespace Torch.Managers.PatchManager.Transpile public LocalBuilder DeclareLocal(Type localType, bool isPinned = false) { LocalBuilder res = Backing.DeclareLocal(localType, isPinned); - _log.Trace($"DclLoc\t{res.LocalIndex}\t=> {res.LocalType} {res.IsPinned}"); + _log?.Trace($"DclLoc\t{res.LocalIndex}\t=> {res.LocalType} {res.IsPinned}"); return res; } @@ -45,70 +45,70 @@ namespace Torch.Managers.PatchManager.Transpile /// public void Emit(OpCode op) { - _log.Trace($"Emit\t{op,_opcodePadding}"); + _log?.Trace($"Emit\t{op,_opcodePadding}"); Backing.Emit(op); } /// public void Emit(OpCode op, LocalBuilder arg) { - _log.Trace($"Emit\t{op,_opcodePadding} L:{arg.LocalIndex} {arg.LocalType}"); + _log?.Trace($"Emit\t{op,_opcodePadding} L:{arg.LocalIndex} {arg.LocalType}"); Backing.Emit(op, arg); } /// public void Emit(OpCode op, int arg) { - _log.Trace($"Emit\t{op,_opcodePadding} {arg}"); + _log?.Trace($"Emit\t{op,_opcodePadding} {arg}"); Backing.Emit(op, arg); } /// public void Emit(OpCode op, long arg) { - _log.Trace($"Emit\t{op,_opcodePadding} {arg}"); + _log?.Trace($"Emit\t{op,_opcodePadding} {arg}"); Backing.Emit(op, arg); } /// public void Emit(OpCode op, float arg) { - _log.Trace($"Emit\t{op,_opcodePadding} {arg}"); + _log?.Trace($"Emit\t{op,_opcodePadding} {arg}"); Backing.Emit(op, arg); } /// public void Emit(OpCode op, double arg) { - _log.Trace($"Emit\t{op,_opcodePadding} {arg}"); + _log?.Trace($"Emit\t{op,_opcodePadding} {arg}"); Backing.Emit(op, arg); } /// public void Emit(OpCode op, string arg) { - _log.Trace($"Emit\t{op,_opcodePadding} {arg}"); + _log?.Trace($"Emit\t{op,_opcodePadding} {arg}"); Backing.Emit(op, arg); } /// public void Emit(OpCode op, Type arg) { - _log.Trace($"Emit\t{op,_opcodePadding} {arg}"); + _log?.Trace($"Emit\t{op,_opcodePadding} {arg}"); Backing.Emit(op, arg); } /// public void Emit(OpCode op, FieldInfo arg) { - _log.Trace($"Emit\t{op,_opcodePadding} {arg}"); + _log?.Trace($"Emit\t{op,_opcodePadding} {arg}"); Backing.Emit(op, arg); } /// public void Emit(OpCode op, MethodInfo arg) { - _log.Trace($"Emit\t{op,_opcodePadding} {arg}"); + _log?.Trace($"Emit\t{op,_opcodePadding} {arg}"); Backing.Emit(op, arg); } @@ -121,35 +121,35 @@ namespace Torch.Managers.PatchManager.Transpile /// public void Emit(OpCode op, Label arg) { - _log.Trace($"Emit\t{op,_opcodePadding}\tL:{_labelID.Invoke(arg)}"); + _log?.Trace($"Emit\t{op,_opcodePadding}\tL:{_labelID.Invoke(arg)}"); Backing.Emit(op, arg); } /// public void Emit(OpCode op, Label[] arg) { - _log.Trace($"Emit\t{op,_opcodePadding}\t{string.Join(", ", arg.Select(x => "L:" + _labelID.Invoke(x)))}"); + _log?.Trace($"Emit\t{op,_opcodePadding}\t{string.Join(", ", arg.Select(x => "L:" + _labelID.Invoke(x)))}"); Backing.Emit(op, arg); } /// public void Emit(OpCode op, SignatureHelper arg) { - _log.Trace($"Emit\t{op,_opcodePadding} {arg}"); + _log?.Trace($"Emit\t{op,_opcodePadding} {arg}"); Backing.Emit(op, arg); } /// public void Emit(OpCode op, ConstructorInfo arg) { - _log.Trace($"Emit\t{op,_opcodePadding} {arg}"); + _log?.Trace($"Emit\t{op,_opcodePadding} {arg}"); Backing.Emit(op, arg); } /// public void MarkLabel(Label label) { - _log.Trace($"MkLbl\tL:{_labelID.Invoke(label)}"); + _log?.Trace($"MkLbl\tL:{_labelID.Invoke(label)}"); Backing.MarkLabel(label); } @@ -166,7 +166,7 @@ namespace Torch.Managers.PatchManager.Transpile [Conditional("DEBUG")] public void EmitComment(string comment) { - _log.Trace($"// {comment}"); + _log?.Trace($"// {comment}"); } } #pragma warning restore 162 From f377d044d6f0c321a01789f55d4a024bfdb8129e Mon Sep 17 00:00:00 2001 From: John Gross Date: Thu, 21 Sep 2017 20:15:18 -0700 Subject: [PATCH 5/5] Observable type improvements --- Torch/Collections/ObservableDictionary.cs | 147 +++++++++++++++++++--- Torch/ViewModels/ViewModel.cs | 11 ++ 2 files changed, 142 insertions(+), 16 deletions(-) diff --git a/Torch/Collections/ObservableDictionary.cs b/Torch/Collections/ObservableDictionary.cs index 2d3f955..68a4a33 100644 --- a/Torch/Collections/ObservableDictionary.cs +++ b/Torch/Collections/ObservableDictionary.cs @@ -1,8 +1,10 @@ using System; +using System.Collections; using System.Collections.Generic; using System.Collections.Specialized; using System.ComponentModel; using System.Linq; +using System.Runtime.CompilerServices; using System.Text; using System.Threading.Tasks; using System.Windows.Threading; @@ -10,28 +12,148 @@ using System.Windows.Threading; namespace Torch.Collections { [Serializable] - public class ObservableDictionary : Dictionary, INotifyCollectionChanged, INotifyPropertyChanged + public class ObservableDictionary : ViewModel, IDictionary, INotifyCollectionChanged { - /// - public new void Add(TKey key, TValue value) + private IDictionary _internalDict; + + public ObservableDictionary() { - base.Add(key, value); - var kv = new KeyValuePair(key, value); - OnCollectionChanged(new NotifyCollectionChangedEventArgs(NotifyCollectionChangedAction.Add, kv)); + _internalDict = new Dictionary(); + } + + public ObservableDictionary(IDictionary dictionary) + { + _internalDict = new Dictionary(dictionary); + } + + /// + /// Create a using the given dictionary by reference. The original dictionary should not be used after calling this. + /// + public static ObservableDictionary ByReference(IDictionary dictionary) + { + return new ObservableDictionary + { + _internalDict = dictionary + }; } /// - public new bool Remove(TKey key) + public event NotifyCollectionChangedEventHandler CollectionChanged; + + /// + public event PropertyChangedEventHandler PropertyChanged; + + /// + public IEnumerator> GetEnumerator() { - if (!ContainsKey(key)) + return _internalDict.GetEnumerator(); + } + + /// + IEnumerator IEnumerable.GetEnumerator() + { + return ((IEnumerable)_internalDict).GetEnumerator(); + } + + /// + public void Add(KeyValuePair item) + { + Add(item.Key, item.Value); + } + + /// + public bool Remove(KeyValuePair item) + { + return Remove(item.Key); + } + + /// + public void Clear() + { + _internalDict.Clear(); + OnCollectionChanged(new NotifyCollectionChangedEventArgs(NotifyCollectionChangedAction.Reset)); + OnPropertyChanged(nameof(Count)); + } + + /// + public bool Contains(KeyValuePair item) + { + return _internalDict.Contains(item); + } + + /// + public void CopyTo(KeyValuePair[] array, int arrayIndex) + { + foreach (var kv in _internalDict) + { + array[arrayIndex] = kv; + arrayIndex++; + } + } + + /// + public int Count => _internalDict.Count; + + /// + public bool IsReadOnly => false; + + /// + public bool ContainsKey(TKey key) + { + return _internalDict.ContainsKey(key); + } + + /// + public void Add(TKey key, TValue value) + { + _internalDict.Add(key, value); + var kv = new KeyValuePair(key, value); + OnCollectionChanged(new NotifyCollectionChangedEventArgs(NotifyCollectionChangedAction.Add, kv)); + OnPropertyChanged(nameof(Count)); + } + + /// + public bool Remove(TKey key) + { + if (!_internalDict.ContainsKey(key)) return false; var kv = new KeyValuePair(key, this[key]); - base.Remove(key); + if (!_internalDict.Remove(key)) + return false; + OnCollectionChanged(new NotifyCollectionChangedEventArgs(NotifyCollectionChangedAction.Remove, kv)); + OnPropertyChanged(nameof(Count)); return true; } + /// + public bool TryGetValue(TKey key, out TValue value) + { + return _internalDict.TryGetValue(key, out value); + } + + /// + public TValue this[TKey key] + { + get => _internalDict[key]; + set + { + var oldKv = new KeyValuePair(key, _internalDict[key]); + var newKv = new KeyValuePair(key, value); + _internalDict[key] = value; + OnCollectionChanged(new NotifyCollectionChangedEventArgs(NotifyCollectionChangedAction.Replace, newKv, oldKv)); + } + + + } + + /// + public ICollection Keys => _internalDict.Keys; + + /// + public ICollection Values => _internalDict.Values; + private void OnCollectionChanged(NotifyCollectionChangedEventArgs e) { NotifyCollectionChangedEventHandler collectionChanged = CollectionChanged; @@ -52,12 +174,5 @@ namespace Torch.Collections nh.Invoke(this, e); } } - - - /// - public event NotifyCollectionChangedEventHandler CollectionChanged; - - /// - public event PropertyChangedEventHandler PropertyChanged; } } diff --git a/Torch/ViewModels/ViewModel.cs b/Torch/ViewModels/ViewModel.cs index 010f34c..7233102 100644 --- a/Torch/ViewModels/ViewModel.cs +++ b/Torch/ViewModels/ViewModel.cs @@ -2,6 +2,7 @@ using System.Collections.Generic; using System.ComponentModel; using System.Linq; +using System.Reflection; using System.Runtime.CompilerServices; using System.Text; using System.Threading.Tasks; @@ -20,6 +21,16 @@ namespace Torch PropertyChanged?.Invoke(this, new PropertyChangedEventArgs(propName)); } + protected virtual void SetValue(ref T backingField, T value, [CallerMemberName] string propName = "") + { + if (backingField.Equals(value)) + return; + + backingField = value; + // ReSharper disable once ExplicitCallerInfoArgument + OnPropertyChanged(propName); + } + /// /// Fires PropertyChanged for all properties. ///