Merge pull request #242 from TorchAPI/eq-patcher-locals

Allow prefixes and suffixes to declare shared locals
This commit is contained in:
Westin Miller
2018-07-27 07:35:37 -07:00
committed by GitHub

View File

@@ -1,5 +1,6 @@
using System; using System;
using System.Collections.Generic; using System.Collections.Generic;
using System.ComponentModel.Design;
using System.Diagnostics; using System.Diagnostics;
using System.Linq; using System.Linq;
using System.Reflection; using System.Reflection;
@@ -65,7 +66,8 @@ namespace Torch.Managers.PatchManager
{ {
if (_pinnedPatch.HasValue) if (_pinnedPatch.HasValue)
{ {
_log.Debug($"Revert {_method.DeclaringType?.FullName}#{_method.Name}({string.Join(", ", _method.GetParameters().Select(x => x.ParameterType.Name))})"); _log.Debug(
$"Revert {_method.DeclaringType?.FullName}#{_method.Name}({string.Join(", ", _method.GetParameters().Select(x => x.ParameterType.Name))})");
AssemblyMemory.WriteMemory(_revertAddress, _revertData); AssemblyMemory.WriteMemory(_revertAddress, _revertData);
_revertData = null; _revertData = null;
_pinnedPatch.Value.Free(); _pinnedPatch.Value.Free();
@@ -74,7 +76,9 @@ namespace Torch.Managers.PatchManager
} }
#region Create #region Create
private int _patchSalt = 0; private int _patchSalt = 0;
private DynamicMethod AllocatePatchMethod() private DynamicMethod AllocatePatchMethod()
{ {
Debug.Assert(_method.DeclaringType != null); Debug.Assert(_method.DeclaringType != null);
@@ -98,6 +102,7 @@ namespace Torch.Managers.PatchManager
public const string INSTANCE_PARAMETER = "__instance"; public const string INSTANCE_PARAMETER = "__instance";
public const string RESULT_PARAMETER = "__result"; public const string RESULT_PARAMETER = "__result";
public const string PREFIX_SKIPPED_PARAMETER = "__prefixSkipped"; public const string PREFIX_SKIPPED_PARAMETER = "__prefixSkipped";
public const string LOCAL_PARAMETER = "__local";
public DynamicMethod ComposePatchedMethod() public DynamicMethod ComposePatchedMethod()
@@ -112,6 +117,7 @@ namespace Torch.Managers.PatchManager
MethodTranspiler.IntegrityAnalysis(LogLevel.Info, il); MethodTranspiler.IntegrityAnalysis(LogLevel.Info, il);
} }
} }
MethodTranspiler.EmitMethod(il, generator); MethodTranspiler.EmitMethod(il, generator);
try try
@@ -126,13 +132,17 @@ namespace Torch.Managers.PatchManager
ctx.Read(); ctx.Read();
MethodTranspiler.IntegrityAnalysis(LogLevel.Warn, ctx.Instructions); MethodTranspiler.IntegrityAnalysis(LogLevel.Warn, ctx.Instructions);
} }
throw; throw;
} }
return method; return method;
} }
#endregion #endregion
#region Emit #region Emit
private IEnumerable<MsilInstruction> EmitPatched(Func<Type, bool, MsilLocal> declareLocal) private IEnumerable<MsilInstruction> EmitPatched(Func<Type, bool, MsilLocal> declareLocal)
{ {
var methodBody = _method.GetMethodBody(); var methodBody = _method.GetMethodBody();
@@ -142,6 +152,7 @@ namespace Torch.Managers.PatchManager
Debug.Assert(localVar.LocalType != null); Debug.Assert(localVar.LocalType != null);
declareLocal(localVar.LocalType, localVar.IsPinned); declareLocal(localVar.LocalType, localVar.IsPinned);
} }
var instructions = new List<MsilInstruction>(); var instructions = new List<MsilInstruction>();
var specialVariables = new Dictionary<string, MsilLocal>(); var specialVariables = new Dictionary<string, MsilLocal>();
@@ -157,6 +168,7 @@ namespace Torch.Managers.PatchManager
|| Prefixes.Any(x => x.ReturnType == typeof(bool))) || Prefixes.Any(x => x.ReturnType == typeof(bool)))
resultVariable = declareLocal(returnType, false); resultVariable = declareLocal(returnType, false);
} }
if (resultVariable != null) if (resultVariable != null)
instructions.AddRange(resultVariable.SetToDefault()); instructions.AddRange(resultVariable.SetToDefault());
MsilLocal prefixSkippedVariable = null; MsilLocal prefixSkippedVariable = null;
@@ -170,6 +182,23 @@ namespace Torch.Managers.PatchManager
if (resultVariable != null) if (resultVariable != null)
specialVariables.Add(RESULT_PARAMETER, resultVariable); specialVariables.Add(RESULT_PARAMETER, resultVariable);
// Create special variables
foreach (var m in Prefixes.Concat(Suffixes))
foreach (var param in m.GetParameters())
if (param.Name.StartsWith(LOCAL_PARAMETER))
{
var requiredType = param.ParameterType.IsByRef ? param.ParameterType.GetElementType() : param.ParameterType;
if (specialVariables.TryGetValue(param.Name, out var existingParam))
{
if (existingParam.Type != requiredType)
throw new ArgumentException(
$"Trying to use injected local {param.Name} for {m.DeclaringType?.FullName}#{m.ToString()} with type {requiredType} but a local with the same name already exists with type {existingParam.Type}",
param.Name);
}
else
specialVariables.Add(param.Name, declareLocal(requiredType, false));
}
foreach (MethodInfo prefix in Prefixes) foreach (MethodInfo prefix in Prefixes)
{ {
instructions.AddRange(EmitMonkeyCall(prefix, specialVariables)); instructions.AddRange(EmitMonkeyCall(prefix, specialVariables));
@@ -179,6 +208,7 @@ namespace Torch.Managers.PatchManager
throw new Exception( throw new Exception(
$"Prefixes must return void or bool. {prefix.DeclaringType?.FullName}.{prefix.Name} returns {prefix.ReturnType}"); $"Prefixes must return void or bool. {prefix.DeclaringType?.FullName}.{prefix.Name} returns {prefix.ReturnType}");
} }
instructions.AddRange(MethodTranspiler.Transpile(_method, (x) => declareLocal(x, false), Transpilers, labelAfterOriginalContent)); instructions.AddRange(MethodTranspiler.Transpile(_method, (x) => declareLocal(x, false), Transpilers, labelAfterOriginalContent));
instructions.Add(new MsilInstruction(OpCodes.Nop).LabelWith(labelAfterOriginalContent)); instructions.Add(new MsilInstruction(OpCodes.Nop).LabelWith(labelAfterOriginalContent));
@@ -192,6 +222,7 @@ namespace Torch.Managers.PatchManager
instructions.Add(new MsilInstruction(OpCodes.Ldc_I4_1)); instructions.Add(new MsilInstruction(OpCodes.Ldc_I4_1));
instructions.Add(new MsilInstruction(OpCodes.Stloc).InlineValue(prefixSkippedVariable)); instructions.Add(new MsilInstruction(OpCodes.Stloc).InlineValue(prefixSkippedVariable));
} }
instructions.Add(new MsilInstruction(OpCodes.Nop).LabelWith(notSkip)); instructions.Add(new MsilInstruction(OpCodes.Nop).LabelWith(notSkip));
foreach (MethodInfo suffix in Suffixes) foreach (MethodInfo suffix in Suffixes)
@@ -200,6 +231,7 @@ namespace Torch.Managers.PatchManager
if (suffix.ReturnType != typeof(void)) if (suffix.ReturnType != typeof(void))
throw new Exception($"Suffixes must return void. {suffix.DeclaringType?.FullName}.{suffix.Name} returns {suffix.ReturnType}"); throw new Exception($"Suffixes must return void. {suffix.DeclaringType?.FullName}.{suffix.Name} returns {suffix.ReturnType}");
} }
if (resultVariable != null) if (resultVariable != null)
instructions.Add(new MsilInstruction(OpCodes.Ldloc).InlineValue(resultVariable)); instructions.Add(new MsilInstruction(OpCodes.Ldloc).InlineValue(resultVariable));
instructions.Add(new MsilInstruction(OpCodes.Ret)); instructions.Add(new MsilInstruction(OpCodes.Ret));
@@ -218,11 +250,14 @@ namespace Torch.Managers.PatchManager
switch (param.Name) switch (param.Name)
{ {
case INSTANCE_PARAMETER: case INSTANCE_PARAMETER:
{
if (_method.IsStatic) if (_method.IsStatic)
throw new Exception("Can't use an instance parameter for a static method"); throw new Exception("Can't use an instance parameter for a static method");
yield return new MsilInstruction(OpCodes.Ldarg_0); yield return new MsilInstruction(OpCodes.Ldarg_0);
break; break;
}
case PREFIX_SKIPPED_PARAMETER: case PREFIX_SKIPPED_PARAMETER:
{
if (param.ParameterType != typeof(bool)) if (param.ParameterType != typeof(bool))
throw new Exception($"Prefix skipped parameter {param.ParameterType} must be of type bool"); throw new Exception($"Prefix skipped parameter {param.ParameterType} must be of type bool");
if (param.ParameterType.IsByRef || param.IsOut) if (param.ParameterType.IsByRef || param.IsOut)
@@ -232,8 +267,10 @@ namespace Torch.Managers.PatchManager
else else
yield return new MsilInstruction(OpCodes.Ldc_I4_0); yield return new MsilInstruction(OpCodes.Ldc_I4_0);
break; break;
}
case RESULT_PARAMETER: case RESULT_PARAMETER:
Type retType = param.ParameterType.IsByRef {
var retType = param.ParameterType.IsByRef
? param.ParameterType.GetElementType() ? param.ParameterType.GetElementType()
: param.ParameterType; : param.ParameterType;
if (retType == null || !retType.IsAssignableFrom(specialVariables[RESULT_PARAMETER].Type)) if (retType == null || !retType.IsAssignableFrom(specialVariables[RESULT_PARAMETER].Type))
@@ -241,7 +278,16 @@ namespace Torch.Managers.PatchManager
yield return new MsilInstruction(param.ParameterType.IsByRef ? OpCodes.Ldloca : OpCodes.Ldloc) yield return new MsilInstruction(param.ParameterType.IsByRef ? OpCodes.Ldloca : OpCodes.Ldloc)
.InlineValue(specialVariables[RESULT_PARAMETER]); .InlineValue(specialVariables[RESULT_PARAMETER]);
break; break;
}
default: default:
{
if (specialVariables.TryGetValue(param.Name, out var specialVar))
{
yield return new MsilInstruction(param.ParameterType.IsByRef ? OpCodes.Ldloca : OpCodes.Ldloc)
.InlineValue(specialVar);
break;
}
ParameterInfo declParam = _method.GetParameters().FirstOrDefault(x => x.Name == param.Name); ParameterInfo declParam = _method.GetParameters().FirstOrDefault(x => x.Name == param.Name);
if (declParam == null) if (declParam == null)
throw new Exception($"Parameter name {param.Name} not found"); throw new Exception($"Parameter name {param.Name} not found");
@@ -258,11 +304,15 @@ namespace Torch.Managers.PatchManager
yield return new MsilInstruction(OpCodes.Ldarg).InlineValue(new MsilArgument(paramIdx)); yield return new MsilInstruction(OpCodes.Ldarg).InlineValue(new MsilArgument(paramIdx));
yield return EmitExtensions.EmitDereference(declParam.ParameterType); yield return EmitExtensions.EmitDereference(declParam.ParameterType);
} }
break; break;
} }
} }
}
yield return new MsilInstruction(OpCodes.Call).InlineValue(patch); yield return new MsilInstruction(OpCodes.Call).InlineValue(patch);
} }
#endregion #endregion
} }
} }