diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 5931769..6e11ce7 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -1,9 +1,9 @@ # Making a Pull Request -* Fork this repository and make sure your local **master** branch is up to date with the main repository. -* Create a new branch for your addition with an appropriate name, e.g. **add-restart-command** +* Fork this repository and make sure your local **staging** branch is up to date with the main repository. +* Create a new branch from the **staging** branch for your addition with an appropriate name, e.g. **add-restart-command** * PRs work by submitting the *entire* branch, so this allows you to continue work without locking up your whole repository. * Commit your changes to that branch, making sure that you **follow the code guidelines below**. -* Submit your branch as a PR to be reviewed. +* Submit your branch as a PR to be reviewed, with Torch's **staging** branch as the base. ## Naming Conventions * Types: **PascalCase** diff --git a/Jenkins/release.ps1 b/Jenkins/release.ps1 new file mode 100644 index 0000000..81e5e63 --- /dev/null +++ b/Jenkins/release.ps1 @@ -0,0 +1,52 @@ +param([string] $ApiBase, [string]$tagName, [string]$authinfo, [string[]] $assetPaths) +Add-Type -AssemblyName "System.Web" + +$headers = @{ + Authorization = "Basic " + [System.Convert]::ToBase64String([System.Text.Encoding]::ASCII.GetBytes($authinfo)) + Accept = "application/vnd.github.v3+json" +} +try +{ + Write-Output("Checking if release with tag " + $tagName + " already exists...") + $release = Invoke-RestMethod -Uri ($ApiBase+"releases/tags/$tagName") -Method "GET" -Headers $headers + Write-Output(" Using existing release " + $release.id + " at " + $release.html_url) +} catch { + Write-Output(" Doesn't exist") + $rel_arg = @{ + tag_name=$tagName + name="Generated $tagName" + body="" + draft=$TRUE + prerelease=$tagName.Contains("alpha") -or $tagName.Contains("beta") + } + Write-Output("Creating new release " + $tagName + "...") + $release = Invoke-RestMethod -Uri ($ApiBase+"releases") -Method "POST" -Headers $headers -Body (ConvertTo-Json($rel_arg)) + Write-Output(" Created new release " + $tagName + " at " + $release.html_url) +} + +$assetsApiBase = $release.assets_url +Write-Output("Checking for existing assets...") +$existingAssets = Invoke-RestMethod -Uri ($assetsApiBase) -Method "GET" -Headers $headers +$assetLabels = ($assetPaths | ForEach-Object {[System.IO.Path]::GetFileName($_)}) +foreach ($asset in $existingAssets) { + if ($assetLabels -contains $asset.name) { + $uri = $asset.url + Write-Output(" Deleting old asset " + $asset.name + " (id " + $asset.id + "); URI=" + $uri) + $result = Invoke-RestMethod -Uri $uri -Method "DELETE" -Headers $headers + } +} +Write-Output("Uploading assets...") +$uploadUrl = $release.upload_url.Substring(0, $release.upload_url.LastIndexOf('{')) +foreach ($asset in $assetPaths) { + $assetName = [System.IO.Path]::GetFileName($asset) + $assetType = [System.Web.MimeMapping]::GetMimeMapping($asset) + $assetData = [System.IO.File]::ReadAllBytes($asset) + $headerExtra = $headers + @{ + "Content-Type" = $assetType + Name = $assetName + } + $uri = $uploadUrl + "?name=" + $assetName + Write-Output(" Uploading " + $asset + " as " + $assetType + "; URI=" + $uri) + $result = Invoke-RestMethod -Uri $uri -Method "POST" -Headers $headerExtra -Body $assetData + Write-Output(" ID=" + $result.id + ", found at=" + $result.browser_download_url) +} \ No newline at end of file diff --git a/Jenkinsfile b/Jenkinsfile index 841e5bc..cec3b0b 100644 --- a/Jenkinsfile +++ b/Jenkinsfile @@ -54,4 +54,14 @@ node { archiveArtifacts artifacts: 'bin/x64/Release/Torch*', caseSensitive: false, fingerprint: true, onlyIfSuccessful: true } + + gitVersion = bat(returnStdout: true, script: "@git describe --tags").trim() + gitSimpleVersion = bat(returnStdout: true, script: "@git describe --tags --abbrev=0").trim() + if (gitVersion == gitSimpleVersion) { + stage('Release') { + withCredentials([usernamePassword(credentialsId: 'torch-github', usernameVariable: 'USERNAME', passwordVariable: 'PASSWORD')]) { + powershell "& ./Jenkins/release.ps1 \"https://api.github.com/repos/TorchAPI/Torch/\" \"$gitSimpleVersion\" \"$USERNAME:$PASSWORD\" @(\"bin/torch-server.zip\", \"bin/torch-client.zip\")" + } + } + } } \ No newline at end of file diff --git a/Torch.API/IChatMessage.cs b/Torch.API/IChatMessage.cs deleted file mode 100644 index e9dbc11..0000000 --- a/Torch.API/IChatMessage.cs +++ /dev/null @@ -1,31 +0,0 @@ -using System; -using System.Collections.Generic; -using System.Linq; -using System.Text; -using System.Threading.Tasks; - -namespace Torch.API -{ - public interface IChatMessage - { - /// - /// The time the message was created. - /// - DateTime Timestamp { get; } - - /// - /// The SteamID of the message author. - /// - ulong SteamId { get; } - - /// - /// The name of the message author. - /// - string Name { get; } - - /// - /// The content of the message. - /// - string Message { get; } - } -} diff --git a/Torch.API/Managers/IChatManagerClient.cs b/Torch.API/Managers/IChatManagerClient.cs index a3a1f64..d9a0bba 100644 --- a/Torch.API/Managers/IChatManagerClient.cs +++ b/Torch.API/Managers/IChatManagerClient.cs @@ -3,6 +3,8 @@ using System.Collections.Generic; using System.Linq; using System.Text; using System.Threading.Tasks; +using Sandbox.Engine.Multiplayer; +using Sandbox.Game.Multiplayer; using VRage.Network; namespace Torch.API.Managers @@ -12,6 +14,38 @@ namespace Torch.API.Managers /// public struct TorchChatMessage { + /// + /// Creates a new torch chat message with the given author and message. + /// + /// Author's name + /// Message + public TorchChatMessage(string author, string message) + { + Timestamp = DateTime.Now; + AuthorSteamId = null; + Author = author; + Message = message; + Font = "Blue"; + } + + /// + /// Creates a new torch chat message with the given author and message. + /// + /// Author's steam ID + /// Message + public TorchChatMessage(ulong authorSteamId, string message) + { + Timestamp = DateTime.Now; + AuthorSteamId = authorSteamId; + Author = MyMultiplayer.Static?.GetMemberName(authorSteamId) ?? "Player"; + Message = message; + Font = "Blue"; + } + + /// + /// This message's timestamp. + /// + public DateTime Timestamp; /// /// The author's steam ID, if available. Else, null. /// diff --git a/Torch.API/Managers/IMultiplayerManagerBase.cs b/Torch.API/Managers/IMultiplayerManagerBase.cs index 8b552e7..09b76c9 100644 --- a/Torch.API/Managers/IMultiplayerManagerBase.cs +++ b/Torch.API/Managers/IMultiplayerManagerBase.cs @@ -7,14 +7,7 @@ using VRage.Game.ModAPI; namespace Torch.API.Managers { /// - /// Delegate for received messages. - /// - /// Message data. - /// Flag to broadcast message to other players. - public delegate void MessageReceivedDel(IChatMessage message, ref bool sendToOthers); - - /// - /// API for multiplayer related functions. + /// API for multiplayer related functions common to servers and clients. /// public interface IMultiplayerManagerBase : IManager { diff --git a/Torch.API/Managers/IMultiplayerManagerServer.cs b/Torch.API/Managers/IMultiplayerManagerServer.cs index 36dddf8..b0247ba 100644 --- a/Torch.API/Managers/IMultiplayerManagerServer.cs +++ b/Torch.API/Managers/IMultiplayerManagerServer.cs @@ -6,6 +6,9 @@ using System.Threading.Tasks; namespace Torch.API.Managers { + /// + /// API for multiplayer functions that exist on servers and lobbies + /// public interface IMultiplayerManagerServer : IMultiplayerManagerBase { /// diff --git a/Torch.API/Session/ITorchSession.cs b/Torch.API/Session/ITorchSession.cs index 710c9dd..59add06 100644 --- a/Torch.API/Session/ITorchSession.cs +++ b/Torch.API/Session/ITorchSession.cs @@ -25,5 +25,15 @@ namespace Torch.API.Session /// IDependencyManager Managers { get; } + + /// + /// The current state of the session + /// + TorchSessionState State { get; } + + /// + /// Event raised when the changes. + /// + event TorchSessionStateChangedDel StateChanged; } } diff --git a/Torch.API/Session/ITorchSessionManager.cs b/Torch.API/Session/ITorchSessionManager.cs index 8e5a5ed..bfa3b88 100644 --- a/Torch.API/Session/ITorchSessionManager.cs +++ b/Torch.API/Session/ITorchSessionManager.cs @@ -17,32 +17,21 @@ namespace Torch.API.Session /// The manager that will live in the session, or null if none. public delegate IManager SessionManagerFactoryDel(ITorchSession session); - /// - /// Fired when the given session has been completely loaded or is unloading. - /// - /// The session - public delegate void TorchSessionLoadDel(ITorchSession session); - /// /// Manages the creation and destruction of instances for each created by Space Engineers. /// public interface ITorchSessionManager : IManager { - /// - /// Fired when a has finished loading. - /// - event TorchSessionLoadDel SessionLoaded; - - /// - /// Fired when a has begun unloading. - /// - event TorchSessionLoadDel SessionUnloading; - /// /// The currently running session /// ITorchSession CurrentSession { get; } + /// + /// Raised when any changes. + /// + event TorchSessionStateChangedDel SessionStateChanged; + /// /// Adds the given factory as a supplier for session based managers /// diff --git a/Torch.API/Session/TorchSessionState.cs b/Torch.API/Session/TorchSessionState.cs new file mode 100644 index 0000000..6d02da3 --- /dev/null +++ b/Torch.API/Session/TorchSessionState.cs @@ -0,0 +1,38 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using System.Threading.Tasks; + +namespace Torch.API.Session +{ + /// + /// Represents the state of a + /// + public enum TorchSessionState + { + /// + /// The session has been created, and is now loading. + /// + Loading, + /// + /// The session has loaded, and is now running. + /// + Loaded, + /// + /// The session was running, and is now unloading. + /// + Unloading, + /// + /// The session was unloading, and is now unloaded and stopped. + /// + Unloaded + } + + /// + /// Callback raised when a session's state changes + /// + /// The session who had a state change + /// The session's new state + public delegate void TorchSessionStateChangedDel(ITorchSession session, TorchSessionState newState); +} diff --git a/Torch.API/Torch.API.csproj b/Torch.API/Torch.API.csproj index 6ffcc86..d771abd 100644 --- a/Torch.API/Torch.API.csproj +++ b/Torch.API/Torch.API.csproj @@ -160,7 +160,6 @@ Properties\AssemblyVersion.cs - @@ -186,6 +185,7 @@ + diff --git a/Torch.Server/Initializer.cs b/Torch.Server/Initializer.cs index 933d683..e4fe99e 100644 --- a/Torch.Server/Initializer.cs +++ b/Torch.Server/Initializer.cs @@ -8,6 +8,7 @@ using System.Net; using System.Text; using System.Threading; using System.Threading.Tasks; +using System.Windows.Threading; using NLog; using Torch.Utils; @@ -84,15 +85,15 @@ quit"; _server = new TorchServer(_config); _server.Init(); - if (_config.NoGui || _config.Autostart) - { - new Thread(_server.Start).Start(); - } - if (!_config.NoGui) { - new TorchUI(_server).ShowDialog(); + var ui = new TorchUI(_server); + if (_config.Autostart) + new Thread(_server.Start).Start(); + ui.ShowDialog(); } + else + _server.Start(); _resolver?.Dispose(); } diff --git a/Torch.Server/Managers/InstanceManager.cs b/Torch.Server/Managers/InstanceManager.cs index c8d52b8..f11a446 100644 --- a/Torch.Server/Managers/InstanceManager.cs +++ b/Torch.Server/Managers/InstanceManager.cs @@ -131,7 +131,7 @@ namespace Torch.Server.Managers public void SaveConfig() { - DedicatedConfig.Save(); + DedicatedConfig.Save(Path.Combine(Torch.Config.InstancePath, CONFIG_NAME)); Log.Info("Saved dedicated config."); try diff --git a/Torch.Server/Views/ChatControl.xaml b/Torch.Server/Views/ChatControl.xaml index 8d3e692..3f6ea01 100644 --- a/Torch.Server/Views/ChatControl.xaml +++ b/Torch.Server/Views/ChatControl.xaml @@ -10,20 +10,9 @@ - - - - - - - - - - - - - - + + + diff --git a/Torch.Server/Views/ChatControl.xaml.cs b/Torch.Server/Views/ChatControl.xaml.cs index 1c3152e..59334b2 100644 --- a/Torch.Server/Views/ChatControl.xaml.cs +++ b/Torch.Server/Views/ChatControl.xaml.cs @@ -2,6 +2,7 @@ using System.Collections.Generic; using System.Collections.Specialized; using System.Linq; +using System.Reflection; using System.Text; using System.Threading.Tasks; using System.Windows; @@ -24,6 +25,7 @@ using Torch.API.Managers; using Torch.API.Session; using Torch.Managers; using Torch.Server.Managers; +using VRage.Game; namespace Torch.Server { @@ -42,43 +44,75 @@ namespace Torch.Server public void BindServer(ITorchServer server) { _server = (TorchBase)server; - ChatItems.Items.Clear(); + Dispatcher.Invoke(() => + { + ChatItems.Inlines.Clear(); + }); var sessionManager = server.Managers.GetManager(); - sessionManager.SessionLoaded += BindSession; - sessionManager.SessionUnloading += UnbindSession; + if (sessionManager != null) + sessionManager.SessionStateChanged += SessionStateChanged; } - private void BindSession(ITorchSession session) + private void SessionStateChanged(ITorchSession session, TorchSessionState state) { - Dispatcher.Invoke(() => + switch (state) { - var chatMgr = _server?.CurrentSession?.Managers.GetManager(); - if (chatMgr != null) - DataContext = new ChatManagerProxy(chatMgr); - }); + case TorchSessionState.Loading: + Dispatcher.Invoke(() => ChatItems.Inlines.Clear()); + break; + case TorchSessionState.Loaded: + { + var chatMgr = session.Managers.GetManager(); + if (chatMgr != null) + chatMgr.MessageRecieved += OnMessageRecieved; + } + break; + case TorchSessionState.Unloading: + { + var chatMgr = session.Managers.GetManager(); + if (chatMgr != null) + chatMgr.MessageRecieved -= OnMessageRecieved; + } + break; + case TorchSessionState.Unloaded: + break; + default: + throw new ArgumentOutOfRangeException(nameof(state), state, null); + } } - private void UnbindSession(ITorchSession session) + private void OnMessageRecieved(TorchChatMessage msg, ref bool consumed) { - Dispatcher.Invoke(() => - { - (DataContext as ChatManagerProxy)?.Dispose(); - DataContext = null; - }); + InsertMessage(msg); } - private void ChatHistory_CollectionChanged(object sender, NotifyCollectionChangedEventArgs e) + private static readonly Dictionary _brushes = new Dictionary(); + private static Brush LookupBrush(string font) { - ChatItems.ScrollToItem(ChatItems.Items.Count - 1); - /* - if (VisualTreeHelper.GetChildrenCount(ChatItems) > 0) + if (_brushes.TryGetValue(font, out Brush result)) + return result; + Brush brush = typeof(Brushes).GetField(font, BindingFlags.Static)?.GetValue(null) as Brush ?? Brushes.Blue; + _brushes.Add(font, brush); + return brush; + } + + private void InsertMessage(TorchChatMessage msg) + { + if (Dispatcher.CheckAccess()) { - - Border border = (Border)VisualTreeHelper.GetChild(ChatItems, 0); - ScrollViewer scrollViewer = (ScrollViewer)VisualTreeHelper.GetChild(border, 0); - scrollViewer.ScrollToBottom(); - }*/ + bool atBottom = ChatScroller.VerticalOffset + 8 > ChatScroller.ScrollableHeight; + var span = new Span(); + span.Inlines.Add($"{msg.Timestamp} "); + span.Inlines.Add(new Run(msg.Author) { Foreground = LookupBrush(msg.Font) }); + span.Inlines.Add($": {msg.Message}"); + span.Inlines.Add(new LineBreak()); + ChatItems.Inlines.Add(span); + if (atBottom) + ChatScroller.ScrollToBottom(); + } + else + Dispatcher.Invoke(() => InsertMessage(msg)); } private void SendButton_Click(object sender, RoutedEventArgs e) @@ -102,11 +136,12 @@ namespace Torch.Server var commands = _server.CurrentSession?.Managers.GetManager(); if (commands != null && commands.IsCommand(text)) { - (DataContext as ChatManagerProxy)?.AddMessage(new TorchChatMessage() { Author = "Server", Message = text }); + InsertMessage(new TorchChatMessage("Server", text) { Font = MyFontEnum.DarkBlue }); _server.Invoke(() => { - var response = commands.HandleCommandFromServer(text); - Dispatcher.BeginInvoke(() => OnMessageEntered_Callback(response)); + string response = commands.HandleCommandFromServer(text); + if (!string.IsNullOrWhiteSpace(response)) + InsertMessage(new TorchChatMessage("Server", response) { Font = MyFontEnum.Blue }); }); } else @@ -115,40 +150,5 @@ namespace Torch.Server } Message.Text = ""; } - - private void OnMessageEntered_Callback(string response) - { - if (!string.IsNullOrEmpty(response)) - (DataContext as ChatManagerProxy)?.AddMessage(new TorchChatMessage() { Author = "Server", Message = response }); - } - - private class ChatManagerProxy : IDisposable - { - private readonly IChatManagerClient _chatMgr; - - public ChatManagerProxy(IChatManagerClient chatMgr) - { - this._chatMgr = chatMgr; - this._chatMgr.MessageRecieved += ChatMgr_MessageRecieved; ; - } - - public IList ChatHistory { get; } = new ObservableList(); - - /// - public void Dispose() - { - _chatMgr.MessageRecieved -= ChatMgr_MessageRecieved; - } - - private void ChatMgr_MessageRecieved(TorchChatMessage msg, ref bool consumed) - { - AddMessage(msg); - } - - internal void AddMessage(TorchChatMessage msg) - { - ChatHistory.Add(new ChatMessage(DateTime.Now, msg.AuthorSteamId ?? 0, msg.Author, msg.Message)); - } - } } } diff --git a/Torch.Server/Views/PlayerListControl.xaml.cs b/Torch.Server/Views/PlayerListControl.xaml.cs index b76b7c7..c9e9676 100644 --- a/Torch.Server/Views/PlayerListControl.xaml.cs +++ b/Torch.Server/Views/PlayerListControl.xaml.cs @@ -46,18 +46,20 @@ namespace Torch.Server _server = server; var sessionManager = server.Managers.GetManager(); - sessionManager.SessionLoaded += BindSession; - sessionManager.SessionUnloading += UnbindSession; + sessionManager.SessionStateChanged += SessionStateChanged; } - private void BindSession(ITorchSession session) + private void SessionStateChanged(ITorchSession session, TorchSessionState newState) { - Dispatcher.Invoke(() => DataContext = _server?.CurrentSession?.Managers.GetManager()); - } - - private void UnbindSession(ITorchSession session) - { - Dispatcher.Invoke(() => DataContext = null); + switch (newState) + { + case TorchSessionState.Loaded: + Dispatcher.Invoke(() => DataContext = _server?.CurrentSession?.Managers.GetManager()); + break; + case TorchSessionState.Unloading: + Dispatcher.Invoke(() => DataContext = null); + break; + } } private void KickButton_Click(object sender, RoutedEventArgs e) @@ -68,7 +70,7 @@ namespace Torch.Server private void BanButton_Click(object sender, RoutedEventArgs e) { - var player = (KeyValuePair) PlayerList.SelectedItem; + var player = (KeyValuePair)PlayerList.SelectedItem; _server.CurrentSession?.Managers.GetManager()?.BanPlayer(player.Key); } } diff --git a/Torch.Tests/PatchTest.cs b/Torch.Tests/PatchTest.cs new file mode 100644 index 0000000..1ee4225 --- /dev/null +++ b/Torch.Tests/PatchTest.cs @@ -0,0 +1,386 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Reflection; +using System.Reflection.Emit; +using System.Runtime.CompilerServices; +using System.Text; +using Torch.Managers.PatchManager; +using Torch.Managers.PatchManager.MSIL; +using Torch.Utils; +using Xunit; + +// ReSharper disable UnusedMember.Local +namespace Torch.Tests +{ +#pragma warning disable 414 + public class PatchTest + { + #region TestRunner + private static readonly PatchManager _patchContext = new PatchManager(null); + + [Theory] + [MemberData(nameof(Prefixes))] + public void TestPrefix(TestBootstrap runner) + { + runner.TestPrefix(); + } + + [Theory] + [MemberData(nameof(Transpilers))] + public void TestTranspile(TestBootstrap runner) + { + runner.TestTranspile(); + } + + [Theory] + [MemberData(nameof(Suffixes))] + public void TestSuffix(TestBootstrap runner) + { + runner.TestSuffix(); + } + + [Theory] + [MemberData(nameof(Combo))] + public void TestCombo(TestBootstrap runner) + { + runner.TestCombo(); + } + + + + public class TestBootstrap + { + public bool HasPrefix => _prefixMethod != null; + public bool HasTranspile => _transpileMethod != null; + public bool HasSuffix => _suffixMethod != null; + + private readonly MethodInfo _prefixMethod, _prefixAssert; + private readonly MethodInfo _suffixMethod, _suffixAssert; + private readonly MethodInfo _transpileMethod, _transpileAssert; + private readonly MethodInfo _targetMethod, _targetAssert; + private readonly MethodInfo _resetMethod; + private readonly object _instance; + private readonly object[] _targetParams; + private readonly Type _type; + + public TestBootstrap(Type t) + { + const BindingFlags flags = BindingFlags.NonPublic | BindingFlags.Public | BindingFlags.Static | BindingFlags.Instance; + _type = t; + _prefixMethod = t.GetMethod("Prefix", flags); + _prefixAssert = t.GetMethod("AssertPrefix", flags); + _suffixMethod = t.GetMethod("Suffix", flags); + _suffixAssert = t.GetMethod("AssertSuffix", flags); + _transpileMethod = t.GetMethod("Transpile", flags); + _transpileAssert = t.GetMethod("AssertTranspile", flags); + _targetMethod = t.GetMethod("Target", flags); + _targetAssert = t.GetMethod("AssertNormal", flags); + _resetMethod = t.GetMethod("Reset", flags); + if (_targetMethod == null) + throw new Exception($"{t.FullName} must have a method named Target"); + if (_targetAssert == null) + throw new Exception($"{t.FullName} must have a method named AssertNormal"); + _instance = !_targetMethod.IsStatic ? Activator.CreateInstance(t) : null; + _targetParams = (object[])t.GetField("_targetParams", flags)?.GetValue(null) ?? new object[0]; + } + + private void Invoke(MethodBase i, params object[] args) + { + if (i == null) return; + i.Invoke(i.IsStatic ? null : _instance, args); + } + + private void Invoke() + { + _targetMethod.Invoke(_instance, _targetParams); + Invoke(_targetAssert); + } + + public void TestPrefix() + { + Invoke(_resetMethod); + PatchContext context = _patchContext.AcquireContext(); + context.GetPattern(_targetMethod).Prefixes.Add(_prefixMethod); + _patchContext.Commit(); + + Invoke(); + Invoke(_prefixAssert); + + _patchContext.FreeContext(context); + _patchContext.Commit(); + } + + public void TestSuffix() + { + Invoke(_resetMethod); + PatchContext context = _patchContext.AcquireContext(); + context.GetPattern(_targetMethod).Suffixes.Add(_suffixMethod); + _patchContext.Commit(); + + Invoke(); + Invoke(_suffixAssert); + + _patchContext.FreeContext(context); + _patchContext.Commit(); + } + + public void TestTranspile() + { + Invoke(_resetMethod); + PatchContext context = _patchContext.AcquireContext(); + context.GetPattern(_targetMethod).Transpilers.Add(_transpileMethod); + _patchContext.Commit(); + + Invoke(); + Invoke(_transpileAssert); + + _patchContext.FreeContext(context); + _patchContext.Commit(); + } + + public void TestCombo() + { + Invoke(_resetMethod); + PatchContext context = _patchContext.AcquireContext(); + if (_prefixMethod != null) + context.GetPattern(_targetMethod).Prefixes.Add(_prefixMethod); + if (_transpileMethod != null) + context.GetPattern(_targetMethod).Transpilers.Add(_transpileMethod); + if (_suffixMethod != null) + context.GetPattern(_targetMethod).Suffixes.Add(_suffixMethod); + _patchContext.Commit(); + + Invoke(); + Invoke(_prefixAssert); + Invoke(_transpileAssert); + Invoke(_suffixAssert); + + _patchContext.FreeContext(context); + _patchContext.Commit(); + } + + public override string ToString() + { + return _type.Name; + } + } + + private class PatchTestAttribute : Attribute + { + } + + private static readonly List _patchTest; + + static PatchTest() + { + TestUtils.Init(); + foreach (Type type in typeof(PatchManager).Assembly.GetTypes()) + if (type.Namespace?.StartsWith(typeof(PatchManager).Namespace ?? "") ?? false) + ReflectedManager.Process(type); + + _patchTest = new List(); + foreach (Type type in typeof(PatchTest).GetNestedTypes(BindingFlags.NonPublic)) + if (type.GetCustomAttribute(typeof(PatchTestAttribute)) != null) + _patchTest.Add(new TestBootstrap(type)); + } + + public static IEnumerable Prefixes => _patchTest.Where(x => x.HasPrefix).Select(x => new object[] { x }); + public static IEnumerable Transpilers => _patchTest.Where(x => x.HasTranspile).Select(x => new object[] { x }); + public static IEnumerable Suffixes => _patchTest.Where(x => x.HasSuffix).Select(x => new object[] { x }); + public static IEnumerable Combo => _patchTest.Where(x => x.HasPrefix || x.HasTranspile || x.HasSuffix).Select(x => new object[] { x }); + #endregion + + #region PatchTests + + [PatchTest] + private class StaticNoRetNoParm + { + private static bool _prefixHit, _normalHit, _suffixHit, _transpileHit; + + [MethodImpl(MethodImplOptions.NoInlining)] + public static void Prefix() + { + _prefixHit = true; + } + + [MethodImpl(MethodImplOptions.NoInlining)] + public static void Target() + { + _normalHit = true; + } + + [MethodImpl(MethodImplOptions.NoInlining)] + public static void Suffix() + { + _suffixHit = true; + } + + public static IEnumerable Transpile(IEnumerable instructions) + { + yield return new MsilInstruction(OpCodes.Ldnull); + yield return new MsilInstruction(OpCodes.Ldc_I4_1); + yield return new MsilInstruction(OpCodes.Stfld).InlineValue(typeof(StaticNoRetNoParm).GetField("_transpileHit", BindingFlags.Static | BindingFlags.NonPublic | BindingFlags.Public)); + foreach (MsilInstruction i in instructions) + yield return i; + } + + public static void Reset() + { + _prefixHit = _normalHit = _suffixHit = _transpileHit = false; + } + + public static void AssertTranspile() + { + Assert.True(_transpileHit, "Failed to transpile"); + } + + public static void AssertSuffix() + { + Assert.True(_suffixHit, "Failed to suffix"); + } + + public static void AssertNormal() + { + Assert.True(_normalHit, "Failed to execute normally"); + } + + public static void AssertPrefix() + { + Assert.True(_prefixHit, "Failed to prefix"); + } + } + + [PatchTest] + private class StaticNoRetParam + { + private static bool _prefixHit, _normalHit, _suffixHit; + private static readonly object[] _targetParams = { "test", 1, new StringBuilder("test1") }; + + [MethodImpl(MethodImplOptions.NoInlining)] + public static void Prefix(string str, int i, StringBuilder o) + { + Assert.Equal(_targetParams[0], str); + Assert.Equal(_targetParams[1], i); + Assert.Equal(_targetParams[2], o); + _prefixHit = true; + } + + [MethodImpl(MethodImplOptions.NoInlining)] + public static void Target(string str, int i, StringBuilder o) + { + _normalHit = true; + } + + [MethodImpl(MethodImplOptions.NoInlining)] + public static void Suffix(string str, int i, StringBuilder o) + { + Assert.Equal(_targetParams[0], str); + Assert.Equal(_targetParams[1], i); + Assert.Equal(_targetParams[2], o); + _suffixHit = true; + } + + public static void Reset() + { + _prefixHit = _normalHit = _suffixHit = false; + } + + public static void AssertSuffix() + { + Assert.True(_suffixHit, "Failed to suffix"); + } + + public static void AssertNormal() + { + Assert.True(_normalHit, "Failed to execute normally"); + } + + public static void AssertPrefix() + { + Assert.True(_prefixHit, "Failed to prefix"); + } + } + + [PatchTest] + private class StaticNoRetParamReplace + { + private static bool _prefixHit, _normalHit, _suffixHit; + private static readonly object[] _targetParams = { "test", 1, new StringBuilder("stest1") }; + private static readonly object[] _replacedParams = { "test2", 2, new StringBuilder("stest2") }; + private static object[] _calledParams; + + [MethodImpl(MethodImplOptions.NoInlining)] + public static void Prefix(ref string str, ref int i, ref StringBuilder o) + { + Assert.Equal(_targetParams[0], str); + Assert.Equal(_targetParams[1], i); + Assert.Equal(_targetParams[2], o); + str = (string)_replacedParams[0]; + i = (int)_replacedParams[1]; + o = (StringBuilder)_replacedParams[2]; + _prefixHit = true; + } + + [MethodImpl(MethodImplOptions.NoInlining)] + public static void Target(string str, int i, StringBuilder o) + { + _calledParams = new object[] { str, i, o }; + _normalHit = true; + } + + public static void Reset() + { + _prefixHit = _normalHit = _suffixHit = false; + } + + public static void AssertNormal() + { + Assert.True(_normalHit, "Failed to execute normally"); + } + + public static void AssertPrefix() + { + Assert.True(_prefixHit, "Failed to prefix"); + for (var i = 0; i < 3; i++) + Assert.Equal(_replacedParams[i], _calledParams[i]); + } + } + + [PatchTest] + private class StaticCancelExec + { + private static bool _prefixHit, _normalHit, _suffixHit; + + [MethodImpl(MethodImplOptions.NoInlining)] + public static bool Prefix() + { + _prefixHit = true; + return false; + } + + [MethodImpl(MethodImplOptions.NoInlining)] + public static void Target() + { + _normalHit = true; + } + + public static void Reset() + { + _prefixHit = _normalHit = _suffixHit = false; + } + + public static void AssertNormal() + { + Assert.False(_normalHit, "Executed normally when canceled"); + } + + public static void AssertPrefix() + { + Assert.True(_prefixHit, "Failed to prefix"); + } + } + #endregion + } +#pragma warning restore 414 +} diff --git a/Torch.Tests/Torch.Tests.csproj b/Torch.Tests/Torch.Tests.csproj index fc5f81a..d99bc0b 100644 --- a/Torch.Tests/Torch.Tests.csproj +++ b/Torch.Tests/Torch.Tests.csproj @@ -63,6 +63,7 @@ Properties\AssemblyVersion.cs + diff --git a/Torch/ChatMessage.cs b/Torch/ChatMessage.cs deleted file mode 100644 index f2307ab..0000000 --- a/Torch/ChatMessage.cs +++ /dev/null @@ -1,37 +0,0 @@ -using System; -using System.Collections.Generic; -using System.Linq; -using System.Text; -using System.Threading.Tasks; -using Sandbox.Engine.Multiplayer; -using Sandbox.Engine.Networking; -using Torch.API; -using VRage.Network; - -namespace Torch -{ - public class ChatMessage : IChatMessage - { - public DateTime Timestamp { get; } - public ulong SteamId { get; } - public string Name { get; } - public string Message { get; } - - public ChatMessage(DateTime timestamp, ulong steamId, string name, string message) - { - Timestamp = timestamp; - SteamId = steamId; - Name = name; - Message = message; - } - - public static ChatMessage FromChatMsg(ChatMsg msg, DateTime dt = default(DateTime)) - { - return new ChatMessage( - dt == default(DateTime) ? DateTime.Now : dt, - msg.Author, - MyMultiplayer.Static.GetMemberName(msg.Author), - msg.Text); - } - } -} diff --git a/Torch/Managers/PatchManager/AssemblyMemory.cs b/Torch/Managers/PatchManager/AssemblyMemory.cs new file mode 100644 index 0000000..4504993 --- /dev/null +++ b/Torch/Managers/PatchManager/AssemblyMemory.cs @@ -0,0 +1,104 @@ +using System; +using System.Reflection; +using System.Reflection.Emit; +using System.Runtime.CompilerServices; +using System.Runtime.InteropServices; +using Torch.Utils; + +namespace Torch.Managers.PatchManager +{ + internal class AssemblyMemory + { +#pragma warning disable 649 + [ReflectedMethod(Name = "GetMethodDescriptor")] + private static Func _getMethodHandle; +#pragma warning restore 649 + + /// + /// Gets the address, in RAM, where the body of a method starts. + /// + /// Method to find the start of + /// Address of the method's start + public static long GetMethodBodyStart(MethodBase method) + { + RuntimeMethodHandle handle; + if (method is DynamicMethod dyn) + handle = _getMethodHandle.Invoke(dyn); + else + handle = method.MethodHandle; + RuntimeHelpers.PrepareMethod(handle); + return handle.GetFunctionPointer().ToInt64(); + } + + + + // x64 ISA format: + // [prefixes] [opcode] [mod-r/m] + // [mod-r/m] is bitfield: + // [7-6] = "mod" adressing mode + // [5-3] = register or opcode extension + // [2-0] = "r/m" extra addressing mode + + + // http://ref.x86asm.net/coder64.html + /// Direct register addressing mode. (Jump directly to register) + private const byte MODRM_MOD_DIRECT = 0b11; + + /// Long-mode prefix (64-bit operand) + private const byte REX_W = 0x48; + + /// Moves a 16/32/64 operand into register i when opcode is (MOV_R0+i) + private const byte MOV_R0 = 0xB8; + + // Extra opcodes. Used with opcode extension. + private const byte EXT = 0xFF; + + /// Opcode extension used with for the JMP opcode. + private const byte OPCODE_EXTENSION_JMP = 4; + + + /// + /// Reads a byte array from a memory location + /// + /// Address to read from + /// Number of bytes to read + /// The bytes that were read + public static byte[] ReadMemory(long memory, int bytes) + { + var data = new byte[bytes]; + Marshal.Copy(new IntPtr(memory), data,0, bytes); + return data; + } + + /// + /// Writes a byte array to a memory location. + /// + /// Address to write to + /// Data to write + public static void WriteMemory(long memory, byte[] bytes) + { + Marshal.Copy(bytes,0, new IntPtr(memory), bytes.Length); + } + + /// + /// Writes an x64 assembly jump instruction at the given address. + /// + /// Address to write the instruction at + /// Target address of the jump + /// The bytes that were overwritten + public static byte[] WriteJump(long memory, long jumpTarget) + { + byte[] result = ReadMemory(memory, 12); + unsafe + { + var ptr = (byte*)memory; + *ptr = REX_W; + *(ptr + 1) = MOV_R0; + *((long*)(ptr + 2)) = jumpTarget; + *(ptr + 10) = EXT; + *(ptr + 11) = (MODRM_MOD_DIRECT << 6) | (OPCODE_EXTENSION_JMP << 3) | 0; + } + return result; + } + } +} diff --git a/Torch/Managers/PatchManager/DecoratedMethod.cs b/Torch/Managers/PatchManager/DecoratedMethod.cs new file mode 100644 index 0000000..025b61a --- /dev/null +++ b/Torch/Managers/PatchManager/DecoratedMethod.cs @@ -0,0 +1,219 @@ +using System; +using System.Collections.Generic; +using System.Diagnostics; +using System.Linq; +using System.Reflection; +using System.Reflection.Emit; +using System.Runtime.CompilerServices; +using System.Runtime.InteropServices; +using NLog; +using Torch.Managers.PatchManager.Transpile; +using Torch.Utils; + +namespace Torch.Managers.PatchManager +{ + internal class DecoratedMethod : MethodRewritePattern + { + private static readonly Logger _log = LogManager.GetCurrentClassLogger(); + private readonly MethodBase _method; + + internal DecoratedMethod(MethodBase method) : base(null) + { + _method = method; + } + + private long _revertAddress; + private byte[] _revertData = null; + private GCHandle? _pinnedPatch; + + internal void Commit() + { + if (!Prefixes.HasChanges() && !Suffixes.HasChanges() && !Transpilers.HasChanges()) + return; + Revert(); + + if (Prefixes.Count == 0 && Suffixes.Count == 0 && Transpilers.Count == 0) + return; + var patch = ComposePatchedMethod(); + + _revertAddress = AssemblyMemory.GetMethodBodyStart(_method); + var newAddress = AssemblyMemory.GetMethodBodyStart(patch); + _revertData = AssemblyMemory.WriteJump(_revertAddress, newAddress); + _pinnedPatch = GCHandle.Alloc(patch); + } + + internal void Revert() + { + if (_pinnedPatch.HasValue) + { + AssemblyMemory.WriteMemory(_revertAddress, _revertData); + _revertData = null; + _pinnedPatch.Value.Free(); + _pinnedPatch = null; + } + } + + #region Create + private int _patchSalt = 0; + private DynamicMethod AllocatePatchMethod() + { + Debug.Assert(_method.DeclaringType != null); + var methodName = _method.Name + $"_{_patchSalt++}"; + var returnType = _method is MethodInfo meth ? meth.ReturnType : typeof(void); + var parameters = _method.GetParameters(); + var parameterTypes = (_method.IsStatic ? Enumerable.Empty() : new[] { typeof(object) }) + .Concat(parameters.Select(x => x.ParameterType)).ToArray(); + + var patchMethod = new DynamicMethod(methodName, MethodAttributes.Public | MethodAttributes.Static, CallingConventions.Standard, + returnType, parameterTypes, _method.DeclaringType, true); + if (!_method.IsStatic) + patchMethod.DefineParameter(0, ParameterAttributes.None, INSTANCE_PARAMETER); + for (var i = 0; i < parameters.Length; i++) + patchMethod.DefineParameter((patchMethod.IsStatic ? 0 : 1) + i, parameters[i].Attributes, parameters[i].Name); + + return patchMethod; + } + + + private const string INSTANCE_PARAMETER = "__instance"; + private const string RESULT_PARAMETER = "__result"; + +#pragma warning disable 649 + [ReflectedStaticMethod(Type = typeof(RuntimeHelpers), Name = "_CompileMethod", OverrideTypeNames = new[] { "System.IRuntimeMethodInfo" })] + private static Action _compileDynamicMethod; + [ReflectedMethod(Name = "GetMethodInfo")] + private static Func _getMethodInfo; + [ReflectedMethod(Name = "GetMethodDescriptor")] + private static Func _getMethodHandle; +#pragma warning restore 649 + + public DynamicMethod ComposePatchedMethod() + { + DynamicMethod method = AllocatePatchMethod(); + var generator = new LoggingIlGenerator(method.GetILGenerator()); + EmitPatched(generator); + + // Force it to compile + RuntimeMethodHandle handle = _getMethodHandle.Invoke(method); + object runtimeMethodInfo = _getMethodInfo.Invoke(handle); + _compileDynamicMethod.Invoke(runtimeMethodInfo); + return method; + } + #endregion + + #region Emit + private void EmitPatched(LoggingIlGenerator target) + { + var originalLocalVariables = _method.GetMethodBody().LocalVariables + .Select(x => + { + Debug.Assert(x.LocalType != null); + return target.DeclareLocal(x.LocalType, x.IsPinned); + }).ToArray(); + + var specialVariables = new Dictionary(); + + Label? labelAfterOriginalContent = Suffixes.Count > 0 ? target.DefineLabel() : (Label?)null; + Label? labelAfterOriginalReturn = Prefixes.Any(x => x.ReturnType == typeof(bool)) ? target.DefineLabel() : (Label?)null; + + + var returnType = _method is MethodInfo meth ? meth.ReturnType : typeof(void); + var resultVariable = returnType != typeof(void) && (labelAfterOriginalReturn.HasValue || // If we jump past main content we need local to store return val + Prefixes.Concat(Suffixes).SelectMany(x => x.GetParameters()).Any(x => x.Name == RESULT_PARAMETER)) + ? target.DeclareLocal(returnType) + : null; + resultVariable?.SetToDefault(target); + + if (resultVariable != null) + specialVariables.Add(RESULT_PARAMETER, resultVariable); + + target.EmitComment("Prefixes Begin"); + foreach (var prefix in Prefixes) + { + EmitMonkeyCall(target, prefix, specialVariables); + if (prefix.ReturnType == typeof(bool)) + { + Debug.Assert(labelAfterOriginalReturn.HasValue); + target.Emit(OpCodes.Brfalse, labelAfterOriginalReturn.Value); + } + else if (prefix.ReturnType != typeof(void)) + throw new Exception( + $"Prefixes must return void or bool. {prefix.DeclaringType?.FullName}.{prefix.Name} returns {prefix.ReturnType}"); + } + target.EmitComment("Prefixes End"); + + target.EmitComment("Original Begin"); + MethodTranspiler.Transpile(_method, Transpilers, target, labelAfterOriginalContent); + target.EmitComment("Original End"); + if (labelAfterOriginalContent.HasValue) + { + target.MarkLabel(labelAfterOriginalContent.Value); + if (resultVariable != null) + target.Emit(OpCodes.Stloc, resultVariable); + } + if (labelAfterOriginalReturn.HasValue) + target.MarkLabel(labelAfterOriginalReturn.Value); + + target.EmitComment("Suffixes Begin"); + foreach (var suffix in Suffixes) + { + EmitMonkeyCall(target, suffix, specialVariables); + if (suffix.ReturnType != typeof(void)) + throw new Exception($"Suffixes must return void. {suffix.DeclaringType?.FullName}.{suffix.Name} returns {suffix.ReturnType}"); + } + target.EmitComment("Suffixes End"); + + if (labelAfterOriginalContent.HasValue || labelAfterOriginalReturn.HasValue) + { + if (resultVariable != null) + target.Emit(OpCodes.Ldloc, resultVariable); + target.Emit(OpCodes.Ret); + } + } + + private void EmitMonkeyCall(LoggingIlGenerator target, MethodInfo patch, + IReadOnlyDictionary specialVariables) + { + target.EmitComment($"Call {patch.DeclaringType?.FullName}#{patch.Name}"); + foreach (var param in patch.GetParameters()) + { + switch (param.Name) + { + case INSTANCE_PARAMETER: + if (_method.IsStatic) + throw new Exception("Can't use an instance parameter for a static method"); + target.Emit(OpCodes.Ldarg_0); + break; + case RESULT_PARAMETER: + var retType = param.ParameterType.IsByRef + ? param.ParameterType.GetElementType() + : param.ParameterType; + if (retType == null || !retType.IsAssignableFrom(specialVariables[RESULT_PARAMETER].LocalType)) + throw new Exception($"Return type {specialVariables[RESULT_PARAMETER].LocalType} can't be assigned to result parameter type {retType}"); + target.Emit(param.ParameterType.IsByRef ? OpCodes.Ldloca : OpCodes.Ldloc, specialVariables[RESULT_PARAMETER]); + break; + default: + var declParam = _method.GetParameters().FirstOrDefault(x => x.Name == param.Name); + if (declParam == null) + throw new Exception($"Parameter name {param.Name} not found"); + var paramIdx = (_method.IsStatic ? 0 : 1) + declParam.Position; + + var patchByRef = param.IsOut || param.ParameterType.IsByRef; + var declByRef = declParam.IsOut || declParam.ParameterType.IsByRef; + if (patchByRef == declByRef) + target.Emit(OpCodes.Ldarg, paramIdx); + else if (patchByRef) + target.Emit(OpCodes.Ldarga, paramIdx); + else + { + target.Emit(OpCodes.Ldarg, paramIdx); + target.EmitDereference(declParam.ParameterType); + } + break; + } + } + target.Emit(OpCodes.Call, patch); + } + #endregion + } +} diff --git a/Torch/Managers/PatchManager/EmitExtensions.cs b/Torch/Managers/PatchManager/EmitExtensions.cs new file mode 100644 index 0000000..90f1ffd --- /dev/null +++ b/Torch/Managers/PatchManager/EmitExtensions.cs @@ -0,0 +1,75 @@ +using System; +using System.Diagnostics; +using System.Reflection.Emit; +using Torch.Managers.PatchManager.Transpile; + +namespace Torch.Managers.PatchManager +{ + internal static class EmitExtensions + { + /// + /// Sets the given local to its default value in the given IL generator. + /// + /// Local to set to default + /// The IL generator + public static void SetToDefault(this LocalBuilder local, LoggingIlGenerator target) + { + Debug.Assert(local.LocalType != null); + if (local.LocalType.IsEnum || local.LocalType.IsPrimitive) + { + if (local.LocalType == typeof(float)) + target.Emit(OpCodes.Ldc_R4, 0f); + else if (local.LocalType == typeof(double)) + target.Emit(OpCodes.Ldc_R8, 0d); + else if (local.LocalType == typeof(long) || local.LocalType == typeof(ulong)) + target.Emit(OpCodes.Ldc_I8, 0L); + else + target.Emit(OpCodes.Ldc_I4, 0); + target.Emit(OpCodes.Stloc, local); + } + else if (local.LocalType.IsValueType) // struct + { + target.Emit(OpCodes.Ldloca, local); + target.Emit(OpCodes.Initobj, local.LocalType); + } + else // class + { + target.Emit(OpCodes.Ldnull); + target.Emit(OpCodes.Stloc, local); + } + } + + /// + /// Emits a dereference for the given type. + /// + /// IL Generator to emit on + /// Type to dereference + public static void EmitDereference(this LoggingIlGenerator target, Type type) + { + if (type.IsByRef) + type = type.GetElementType(); + Debug.Assert(type != null); + + if (type == typeof(float)) + target.Emit(OpCodes.Ldind_R4); + else if (type == typeof(double)) + target.Emit(OpCodes.Ldind_R8); + else if (type == typeof(byte)) + target.Emit(OpCodes.Ldind_U1); + else if (type == typeof(ushort) || type == typeof(char)) + target.Emit(OpCodes.Ldind_U2); + else if (type == typeof(uint)) + target.Emit(OpCodes.Ldind_U4); + else if (type == typeof(sbyte)) + target.Emit(OpCodes.Ldind_I1); + else if (type == typeof(short)) + target.Emit(OpCodes.Ldind_I2); + else if (type == typeof(int) || type.IsEnum) + target.Emit(OpCodes.Ldind_I4); + else if (type == typeof(long) || type == typeof(ulong)) + target.Emit(OpCodes.Ldind_I8); + else + target.Emit(OpCodes.Ldind_Ref); + } + } +} diff --git a/Torch/Managers/PatchManager/MSIL/ITokenResolver.cs b/Torch/Managers/PatchManager/MSIL/ITokenResolver.cs new file mode 100644 index 0000000..ffaac0a --- /dev/null +++ b/Torch/Managers/PatchManager/MSIL/ITokenResolver.cs @@ -0,0 +1,226 @@ +using System; +using System.Reflection; +using System.Reflection.Emit; + +namespace Torch.Managers.PatchManager.MSIL +{ + //https://stackoverflow.com/questions/4148297/resolving-the-tokens-found-in-the-il-from-a-dynamic-method/35711376#35711376 + internal interface ITokenResolver + { + MemberInfo ResolveMember(int token); + Type ResolveType(int token); + FieldInfo ResolveField(int token); + MethodBase ResolveMethod(int token); + byte[] ResolveSignature(int token); + string ResolveString(int token); + } + + internal sealed class NormalTokenResolver : ITokenResolver + { + private readonly Type[] _genericTypeArgs, _genericMethArgs; + private readonly Module _module; + + internal NormalTokenResolver(MethodBase method) + { + _module = method.Module; + _genericTypeArgs = method.DeclaringType?.GenericTypeArguments ?? new Type[0]; + _genericMethArgs = method.GetGenericArguments(); + } + + public MemberInfo ResolveMember(int token) + { + return _module.ResolveMember(token, _genericTypeArgs, _genericMethArgs); + } + + public Type ResolveType(int token) + { + return _module.ResolveType(token, _genericTypeArgs, _genericMethArgs); + } + + public FieldInfo ResolveField(int token) + { + return _module.ResolveField(token, _genericTypeArgs, _genericMethArgs); + } + + public MethodBase ResolveMethod(int token) + { + return _module.ResolveMethod(token, _genericTypeArgs, _genericMethArgs); + } + + public byte[] ResolveSignature(int token) + { + return _module.ResolveSignature(token); + } + + public string ResolveString(int token) + { + return _module.ResolveString(token); + } + } + + internal sealed class NullTokenResolver : ITokenResolver + { + internal static readonly NullTokenResolver Instance = new NullTokenResolver(); + + private NullTokenResolver() + { + } + + public MemberInfo ResolveMember(int token) + { + return null; + } + + public Type ResolveType(int token) + { + return null; + } + + public FieldInfo ResolveField(int token) + { + return null; + } + + public MethodBase ResolveMethod(int token) + { + return null; + } + + public byte[] ResolveSignature(int token) + { + return null; + } + + public string ResolveString(int token) + { + return null; + } + } + + internal sealed class DynamicMethodTokenResolver : ITokenResolver + { + private readonly MethodInfo _getFieldInfo; + private readonly MethodInfo _getMethodBase; + private readonly GetTypeFromHandleUnsafe _getTypeFromHandleUnsafe; + private readonly ConstructorInfo _runtimeFieldHandleStubCtor; + private readonly ConstructorInfo _runtimeMethodHandleInternalCtor; + private readonly SignatureResolver _signatureResolver; + private readonly StringResolver _stringResolver; + + private readonly TokenResolver _tokenResolver; + + public DynamicMethodTokenResolver(DynamicMethod dynamicMethod) + { + object resolver = typeof(DynamicMethod) + .GetField("m_resolver", BindingFlags.Instance | BindingFlags.NonPublic)?.GetValue(dynamicMethod); + if (resolver == null) throw new ArgumentException("The dynamic method's IL has not been finalized."); + + _tokenResolver = (TokenResolver) resolver.GetType() + .GetMethod("ResolveToken", BindingFlags.Instance | BindingFlags.NonPublic) + .CreateDelegate(typeof(TokenResolver), resolver); + _stringResolver = (StringResolver) resolver.GetType() + .GetMethod("GetStringLiteral", BindingFlags.Instance | BindingFlags.NonPublic) + .CreateDelegate(typeof(StringResolver), resolver); + _signatureResolver = (SignatureResolver) resolver.GetType() + .GetMethod("ResolveSignature", BindingFlags.Instance | BindingFlags.NonPublic) + .CreateDelegate(typeof(SignatureResolver), resolver); + + _getTypeFromHandleUnsafe = (GetTypeFromHandleUnsafe) typeof(Type) + .GetMethod("GetTypeFromHandleUnsafe", BindingFlags.Static | BindingFlags.NonPublic, null, + new[] {typeof(IntPtr)}, null).CreateDelegate(typeof(GetTypeFromHandleUnsafe), null); + Type runtimeType = typeof(RuntimeTypeHandle).Assembly.GetType("System.RuntimeType"); + + Type runtimeMethodHandleInternal = + typeof(RuntimeTypeHandle).Assembly.GetType("System.RuntimeMethodHandleInternal"); + _getMethodBase = runtimeType.GetMethod("GetMethodBase", BindingFlags.Static | BindingFlags.NonPublic, null, + new[] {runtimeType, runtimeMethodHandleInternal}, null); + _runtimeMethodHandleInternalCtor = + runtimeMethodHandleInternal.GetConstructor(BindingFlags.Instance | BindingFlags.NonPublic, null, + new[] {typeof(IntPtr)}, null); + + Type runtimeFieldInfoStub = typeof(RuntimeTypeHandle).Assembly.GetType("System.RuntimeFieldInfoStub"); + _runtimeFieldHandleStubCtor = + runtimeFieldInfoStub.GetConstructor(BindingFlags.Instance | BindingFlags.Public, null, + new[] {typeof(IntPtr), typeof(object)}, null); + _getFieldInfo = runtimeType.GetMethod("GetFieldInfo", BindingFlags.Static | BindingFlags.NonPublic, null, + new[] {runtimeType, typeof(RuntimeTypeHandle).Assembly.GetType("System.IRuntimeFieldInfo")}, null); + } + + public Type ResolveType(int token) + { + IntPtr typeHandle, methodHandle, fieldHandle; + _tokenResolver.Invoke(token, out typeHandle, out methodHandle, out fieldHandle); + + return _getTypeFromHandleUnsafe.Invoke(typeHandle); + } + + public MethodBase ResolveMethod(int token) + { + IntPtr typeHandle, methodHandle, fieldHandle; + _tokenResolver.Invoke(token, out typeHandle, out methodHandle, out fieldHandle); + + return (MethodBase) _getMethodBase.Invoke(null, new[] + { + typeHandle == IntPtr.Zero ? null : _getTypeFromHandleUnsafe.Invoke(typeHandle), + _runtimeMethodHandleInternalCtor.Invoke(new object[] {methodHandle}) + }); + } + + public FieldInfo ResolveField(int token) + { + IntPtr typeHandle, methodHandle, fieldHandle; + _tokenResolver.Invoke(token, out typeHandle, out methodHandle, out fieldHandle); + + return (FieldInfo) _getFieldInfo.Invoke(null, new[] + { + typeHandle == IntPtr.Zero ? null : _getTypeFromHandleUnsafe.Invoke(typeHandle), + _runtimeFieldHandleStubCtor.Invoke(new object[] {fieldHandle, null}) + }); + } + + public MemberInfo ResolveMember(int token) + { + IntPtr typeHandle, methodHandle, fieldHandle; + _tokenResolver.Invoke(token, out typeHandle, out methodHandle, out fieldHandle); + + if (methodHandle != IntPtr.Zero) + return (MethodBase) _getMethodBase.Invoke(null, new[] + { + typeHandle == IntPtr.Zero ? null : _getTypeFromHandleUnsafe.Invoke(typeHandle), + _runtimeMethodHandleInternalCtor.Invoke(new object[] {methodHandle}) + }); + + if (fieldHandle != IntPtr.Zero) + return (FieldInfo) _getFieldInfo.Invoke(null, new[] + { + typeHandle == IntPtr.Zero ? null : _getTypeFromHandleUnsafe.Invoke(typeHandle), + _runtimeFieldHandleStubCtor.Invoke(new object[] {fieldHandle, null}) + }); + + if (typeHandle != IntPtr.Zero) + return _getTypeFromHandleUnsafe.Invoke(typeHandle); + + throw new NotImplementedException( + "DynamicMethods are not able to reference members by token other than types, methods and fields."); + } + + public byte[] ResolveSignature(int token) + { + return _signatureResolver.Invoke(token, 0); + } + + public string ResolveString(int token) + { + return _stringResolver.Invoke(token); + } + + private delegate void TokenResolver(int token, out IntPtr typeHandle, out IntPtr methodHandle, + out IntPtr fieldHandle); + + private delegate string StringResolver(int token); + + private delegate byte[] SignatureResolver(int token, int fromMethod); + + private delegate Type GetTypeFromHandleUnsafe(IntPtr handle); + } +} \ No newline at end of file diff --git a/Torch/Managers/PatchManager/MSIL/MsilInstruction.cs b/Torch/Managers/PatchManager/MSIL/MsilInstruction.cs new file mode 100644 index 0000000..9ac0359 --- /dev/null +++ b/Torch/Managers/PatchManager/MSIL/MsilInstruction.cs @@ -0,0 +1,162 @@ +using System; +using System.Collections.Generic; +using System.Reflection; +using System.Reflection.Emit; +using System.Text; +using Torch.Managers.PatchManager.Transpile; + +namespace Torch.Managers.PatchManager.MSIL +{ + /// + /// Represents a single MSIL instruction, and its operand + /// + public class MsilInstruction + { + private MsilOperand _operandBacking; + + /// + /// Creates a new instruction with the given opcode. + /// + /// Opcode + public MsilInstruction(OpCode opcode) + { + OpCode = opcode; + switch (opcode.OperandType) + { + case OperandType.InlineNone: + Operand = null; + break; + case OperandType.ShortInlineBrTarget: + case OperandType.InlineBrTarget: + Operand = new MsilOperandBrTarget(this); + break; + case OperandType.InlineField: + Operand = new MsilOperandInline.MsilOperandReflected(this); + break; + case OperandType.InlineI: + Operand = new MsilOperandInline.MsilOperandInt32(this); + break; + case OperandType.InlineI8: + Operand = new MsilOperandInline.MsilOperandInt64(this); + break; + case OperandType.InlineMethod: + Operand = new MsilOperandInline.MsilOperandReflected(this); + break; + case OperandType.InlineR: + Operand = new MsilOperandInline.MsilOperandDouble(this); + break; + case OperandType.InlineSig: + Operand = new MsilOperandInline.MsilOperandSignature(this); + break; + case OperandType.InlineString: + Operand = new MsilOperandInline.MsilOperandString(this); + break; + case OperandType.InlineSwitch: + Operand = new MsilOperandSwitch(this); + break; + case OperandType.InlineTok: + Operand = new MsilOperandInline.MsilOperandReflected(this); + break; + case OperandType.InlineType: + Operand = new MsilOperandInline.MsilOperandReflected(this); + break; + case OperandType.ShortInlineVar: + case OperandType.InlineVar: + if (OpCode.Name.IndexOf("loc", StringComparison.OrdinalIgnoreCase) != -1) + Operand = new MsilOperandInline.MsilOperandLocal(this); + else + Operand = new MsilOperandInline.MsilOperandParameter(this); + break; + case OperandType.ShortInlineI: + Operand = OpCode == OpCodes.Ldc_I4_S + ? (MsilOperand) new MsilOperandInline.MsilOperandInt8(this) + : new MsilOperandInline.MsilOperandUInt8(this); + break; + case OperandType.ShortInlineR: + Operand = new MsilOperandInline.MsilOperandSingle(this); + break; +#pragma warning disable 618 + case OperandType.InlinePhi: +#pragma warning restore 618 + default: + throw new ArgumentOutOfRangeException(); + } + } + + /// + /// Opcode of this instruction + /// + public OpCode OpCode { get; } + + /// + /// Raw memory offset of this instruction; optional. + /// + public int Offset { get; internal set; } + + /// + /// The operand for this instruction, or null. + /// + public MsilOperand Operand + { + get => _operandBacking; + set + { + if (_operandBacking != null && value.GetType() != _operandBacking.GetType()) + throw new ArgumentException($"Operand for {OpCode.Name} must be {_operandBacking.GetType().Name}"); + _operandBacking = value; + } + } + + /// + /// Labels pointing to this instruction. + /// + public HashSet Labels { get; } = new HashSet(); + + /// + /// Sets the inline value for this instruction. + /// + /// The type of the inline constraint + /// Value + /// This instruction + public MsilInstruction InlineValue(T o) + { + ((MsilOperandInline) Operand).Value = o; + return this; + } + + /// + /// Sets the inline branch target for this instruction. + /// + /// Target to jump to + /// This instruction + public MsilInstruction InlineTarget(MsilLabel label) + { + ((MsilOperandBrTarget) Operand).Target = label; + return this; + } + + /// + /// Emits this instruction to the given generator + /// + /// Emit target + public void Emit(LoggingIlGenerator target) + { + foreach (MsilLabel label in Labels) + target.MarkLabel(label.LabelFor(target)); + if (Operand != null) + Operand.Emit(target); + else + target.Emit(OpCode); + } + + /// + public override string ToString() + { + var sb = new StringBuilder(); + foreach (MsilLabel label in Labels) + sb.Append(label).Append(": "); + sb.Append(OpCode.Name).Append("\t").Append(Operand); + return sb.ToString(); + } + } +} \ No newline at end of file diff --git a/Torch/Managers/PatchManager/MSIL/MsilLabel.cs b/Torch/Managers/PatchManager/MSIL/MsilLabel.cs new file mode 100644 index 0000000..2111d27 --- /dev/null +++ b/Torch/Managers/PatchManager/MSIL/MsilLabel.cs @@ -0,0 +1,67 @@ +using System; +using System.Collections.Generic; +using System.Reflection.Emit; +using Torch.Managers.PatchManager.Transpile; + +namespace Torch.Managers.PatchManager.MSIL +{ + /// + /// Represents an abstract label, identified by its reference. + /// + public class MsilLabel + { + private readonly List, Label>> _labelInstances = + new List, Label>>(); + + private readonly Label? _overrideLabel; + + /// + /// Creates an empty label the allocates a new when requested. + /// + public MsilLabel() + { + _overrideLabel = null; + } + + /// + /// Creates a label the always supplies the given + /// + public MsilLabel(Label overrideLabel) + { + _overrideLabel = overrideLabel; + } + + /// + /// Creates a label that supplies the given when a label for the given generator is requested, + /// otherwise it creates a new label. + /// + /// Generator to register the label on + /// Label to register + public MsilLabel(LoggingIlGenerator generator, Label label) + { + _labelInstances.Add( + new KeyValuePair, Label>( + new WeakReference(generator), label)); + } + + internal Label LabelFor(LoggingIlGenerator gen) + { + if (_overrideLabel.HasValue) + return _overrideLabel.Value; + foreach (KeyValuePair, Label> kv in _labelInstances) + if (kv.Key.TryGetTarget(out LoggingIlGenerator gen2) && gen2 == gen) + return kv.Value; + Label label = gen.DefineLabel(); + _labelInstances.Add( + new KeyValuePair, Label>(new WeakReference(gen), + label)); + return label; + } + + /// + public override string ToString() + { + return $"L{GetHashCode() & 0xFFFF:X4}"; + } + } +} \ No newline at end of file diff --git a/Torch/Managers/PatchManager/MSIL/MsilOperand.cs b/Torch/Managers/PatchManager/MSIL/MsilOperand.cs new file mode 100644 index 0000000..d9dbf2f --- /dev/null +++ b/Torch/Managers/PatchManager/MSIL/MsilOperand.cs @@ -0,0 +1,25 @@ +using System.IO; +using Torch.Managers.PatchManager.Transpile; + +namespace Torch.Managers.PatchManager.MSIL +{ + /// + /// Represents an operand for a MSIL instruction + /// + public abstract class MsilOperand + { + protected MsilOperand(MsilInstruction instruction) + { + Instruction = instruction; + } + + /// + /// Instruction this operand is associated with + /// + public MsilInstruction Instruction { get; } + + internal abstract void Read(MethodContext context, BinaryReader reader); + + internal abstract void Emit(LoggingIlGenerator generator); + } +} \ No newline at end of file diff --git a/Torch/Managers/PatchManager/MSIL/MsilOperandBrTarget.cs b/Torch/Managers/PatchManager/MSIL/MsilOperandBrTarget.cs new file mode 100644 index 0000000..01e0913 --- /dev/null +++ b/Torch/Managers/PatchManager/MSIL/MsilOperandBrTarget.cs @@ -0,0 +1,40 @@ +using System.IO; +using System.Reflection.Emit; +using Torch.Managers.PatchManager.Transpile; + +namespace Torch.Managers.PatchManager.MSIL +{ + /// + /// Represents a branch target operand. + /// + public class MsilOperandBrTarget : MsilOperand + { + internal MsilOperandBrTarget(MsilInstruction instruction) : base(instruction) + { + } + + /// + /// Branch target + /// + public MsilLabel Target { get; set; } + + internal override void Read(MethodContext context, BinaryReader reader) + { + int val = Instruction.OpCode.OperandType == OperandType.InlineBrTarget + ? reader.ReadInt32() + : reader.ReadByte(); + Target = context.LabelAt((int) reader.BaseStream.Position + val); + } + + internal override void Emit(LoggingIlGenerator generator) + { + generator.Emit(Instruction.OpCode, Target.LabelFor(generator)); + } + + /// + public override string ToString() + { + return Target?.ToString() ?? "null"; + } + } +} \ No newline at end of file diff --git a/Torch/Managers/PatchManager/MSIL/MsilOperandInline.cs b/Torch/Managers/PatchManager/MSIL/MsilOperandInline.cs new file mode 100644 index 0000000..e327054 --- /dev/null +++ b/Torch/Managers/PatchManager/MSIL/MsilOperandInline.cs @@ -0,0 +1,298 @@ +using System; +using System.Diagnostics; +using System.IO; +using System.Reflection; +using System.Reflection.Emit; +using Torch.Managers.PatchManager.Transpile; + +namespace Torch.Managers.PatchManager.MSIL +{ + /// + /// Represents an inline value + /// + /// The type of the inline value + public abstract class MsilOperandInline : MsilOperand + { + internal MsilOperandInline(MsilInstruction instruction) : base(instruction) + { + } + + /// + /// Inline value + /// + public T Value { get; set; } + + /// + public override string ToString() + { + return Value?.ToString() ?? "null"; + } + } + + /// + /// Registry of different inline operand types + /// + public static class MsilOperandInline + { + /// + /// Inline unsigned byte + /// + public class MsilOperandUInt8 : MsilOperandInline + { + internal MsilOperandUInt8(MsilInstruction instruction) : base(instruction) + { + } + + internal override void Read(MethodContext context, BinaryReader reader) + { + Value = reader.ReadByte(); + } + + internal override void Emit(LoggingIlGenerator generator) + { + generator.Emit(Instruction.OpCode, Value); + } + } + + /// + /// Inline signed byte + /// + public class MsilOperandInt8 : MsilOperandInline + { + internal MsilOperandInt8(MsilInstruction instruction) : base(instruction) + { + } + + internal override void Read(MethodContext context, BinaryReader reader) + { + Value = + (sbyte) reader.ReadByte(); + } + + internal override void Emit(LoggingIlGenerator generator) + { + generator.Emit(Instruction.OpCode, Value); + } + } + + /// + /// Inline integer + /// + public class MsilOperandInt32 : MsilOperandInline + { + internal MsilOperandInt32(MsilInstruction instruction) : base(instruction) + { + } + + internal override void Read(MethodContext context, BinaryReader reader) + { + Value = reader.ReadInt32(); + } + + internal override void Emit(LoggingIlGenerator generator) + { + generator.Emit(Instruction.OpCode, Value); + } + } + + /// + /// Inline single + /// + public class MsilOperandSingle : MsilOperandInline + { + internal MsilOperandSingle(MsilInstruction instruction) : base(instruction) + { + } + + internal override void Read(MethodContext context, BinaryReader reader) + { + Value = reader.ReadSingle(); + } + + internal override void Emit(LoggingIlGenerator generator) + { + generator.Emit(Instruction.OpCode, Value); + } + } + + /// + /// Inline double + /// + public class MsilOperandDouble : MsilOperandInline + { + internal MsilOperandDouble(MsilInstruction instruction) : base(instruction) + { + } + + internal override void Read(MethodContext context, BinaryReader reader) + { + Value = reader.ReadDouble(); + } + + internal override void Emit(LoggingIlGenerator generator) + { + generator.Emit(Instruction.OpCode, Value); + } + } + + /// + /// Inline long + /// + public class MsilOperandInt64 : MsilOperandInline + { + internal MsilOperandInt64(MsilInstruction instruction) : base(instruction) + { + } + + internal override void Read(MethodContext context, BinaryReader reader) + { + Value = reader.ReadInt64(); + } + + internal override void Emit(LoggingIlGenerator generator) + { + generator.Emit(Instruction.OpCode, Value); + } + } + + /// + /// Inline string + /// + public class MsilOperandString : MsilOperandInline + { + internal MsilOperandString(MsilInstruction instruction) : base(instruction) + { + } + + internal override void Read(MethodContext context, BinaryReader reader) + { + Value = + context.TokenResolver.ResolveString(reader.ReadInt32()); + } + + internal override void Emit(LoggingIlGenerator generator) + { + generator.Emit(Instruction.OpCode, Value); + } + } + + /// + /// Inline CLR signature + /// + public class MsilOperandSignature : MsilOperandInline + { + internal MsilOperandSignature(MsilInstruction instruction) : base(instruction) + { + } + + internal override void Read(MethodContext context, BinaryReader reader) + { + byte[] sig = context.TokenResolver + .ResolveSignature(reader.ReadInt32()); + throw new ArgumentException("Can't figure out how to convert this."); + } + + internal override void Emit(LoggingIlGenerator generator) + { + generator.Emit(Instruction.OpCode, Value); + } + } + + /// + /// Inline parameter reference + /// + public class MsilOperandParameter : MsilOperandInline + { + internal MsilOperandParameter(MsilInstruction instruction) : base(instruction) + { + } + + internal override void Read(MethodContext context, BinaryReader reader) + { + Value = + context.Method.GetParameters()[ + Instruction.OpCode.OperandType == OperandType.ShortInlineVar + ? reader.ReadByte() + : reader.ReadUInt16()]; + } + + internal override void Emit(LoggingIlGenerator generator) + { + generator.Emit(Instruction.OpCode, Value.Position); + } + } + + /// + /// Inline local variable reference + /// + public class MsilOperandLocal : MsilOperandInline + { + internal MsilOperandLocal(MsilInstruction instruction) : base(instruction) + { + } + + internal override void Read(MethodContext context, BinaryReader reader) + { + Value = + context.Method.GetMethodBody().LocalVariables[ + Instruction.OpCode.OperandType == OperandType.ShortInlineVar + ? reader.ReadByte() + : reader.ReadUInt16()]; + } + + internal override void Emit(LoggingIlGenerator generator) + { + generator.Emit(Instruction.OpCode, Value.LocalIndex); + } + } + + /// + /// Inline or + /// + /// Actual member type + public class MsilOperandReflected : MsilOperandInline where TY : class + { + internal MsilOperandReflected(MsilInstruction instruction) : base(instruction) + { + } + + internal override void Read(MethodContext context, BinaryReader reader) + { + object value = null; + switch (Instruction.OpCode.OperandType) + { + case OperandType.InlineTok: + value = context.TokenResolver.ResolveMember(reader.ReadInt32()); + break; + case OperandType.InlineType: + value = context.TokenResolver.ResolveType(reader.ReadInt32()); + break; + case OperandType.InlineMethod: + value = context.TokenResolver.ResolveMethod(reader.ReadInt32()); + break; + case OperandType.InlineField: + value = context.TokenResolver.ResolveField(reader.ReadInt32()); + break; + default: + throw new ArgumentException("Reflected operand only applies to inline reflected types"); + } + if (value is TY vty) + Value = vty; + else + throw new Exception($"Expected type {typeof(TY).Name} from operand {Instruction.OpCode.OperandType}, got {value.GetType()?.Name ?? "null"}"); + } + + internal override void Emit(LoggingIlGenerator generator) + { + if (Value is ConstructorInfo) + generator.Emit(Instruction.OpCode, Value as ConstructorInfo); + else if (Value is FieldInfo) + generator.Emit(Instruction.OpCode, Value as FieldInfo); + else if (Value is Type) + generator.Emit(Instruction.OpCode, Value as Type); + else if (Value is MethodInfo) + generator.Emit(Instruction.OpCode, Value as MethodInfo); + } + } + } +} \ No newline at end of file diff --git a/Torch/Managers/PatchManager/MSIL/MsilOperandSwitch.cs b/Torch/Managers/PatchManager/MSIL/MsilOperandSwitch.cs new file mode 100644 index 0000000..66e3b2d --- /dev/null +++ b/Torch/Managers/PatchManager/MSIL/MsilOperandSwitch.cs @@ -0,0 +1,36 @@ +using System.IO; +using System.Linq; +using System.Reflection.Emit; +using Torch.Managers.PatchManager.Transpile; + +namespace Torch.Managers.PatchManager.MSIL +{ + /// + /// Represents the operand for an inline switch statement + /// + public class MsilOperandSwitch : MsilOperand + { + internal MsilOperandSwitch(MsilInstruction instruction) : base(instruction) + { + } + + /// + /// The target labels for this switch + /// + public MsilLabel[] Labels { get; set; } + + internal override void Read(MethodContext context, BinaryReader reader) + { + int length = reader.ReadInt32(); + int offset = (int) reader.BaseStream.Position + 4 * length; + Labels = new MsilLabel[length]; + for (var i = 0; i < Labels.Length; i++) + Labels[i] = context.LabelAt(offset + reader.ReadInt32()); + } + + internal override void Emit(LoggingIlGenerator generator) + { + generator.Emit(Instruction.OpCode, Labels?.Select(x => x.LabelFor(generator))?.ToArray() ?? new Label[0]); + } + } +} \ No newline at end of file diff --git a/Torch/Managers/PatchManager/MethodRewritePattern.cs b/Torch/Managers/PatchManager/MethodRewritePattern.cs new file mode 100644 index 0000000..3696371 --- /dev/null +++ b/Torch/Managers/PatchManager/MethodRewritePattern.cs @@ -0,0 +1,172 @@ +using System; +using System.Collections; +using System.Collections.Generic; +using System.Reflection; +using System.Threading; + +namespace Torch.Managers.PatchManager +{ + /// + /// Defines the different components used to rewrite a method. + /// + public class MethodRewritePattern + { + /// + /// Sorts methods so that their priority is in descending order. Assumes priority zero if no attribute exists. + /// + private class MethodPriorityCompare : Comparer + { + internal static readonly MethodPriorityCompare Instance = new MethodPriorityCompare(); + + public override int Compare(MethodInfo x, MethodInfo y) + { + return -(x?.GetCustomAttribute()?.Priority ?? 0).CompareTo( + y?.GetCustomAttribute()?.Priority ?? 0); + } + } + + /// + /// Stores an set of methods according to a certain order. + /// + public class MethodRewriteSet : IEnumerable + { + private readonly MethodRewriteSet _backingSet; + private bool _sortDirty = false; + private readonly List _backingList = new List(); + + private int _hasChanges = 0; + + internal bool HasChanges() + { + return Interlocked.Exchange(ref _hasChanges, 0) != 0; + } + + /// + /// + /// The set to track changes on + internal MethodRewriteSet(MethodRewriteSet backingSet) + { + _backingSet = backingSet; + } + + /// + /// Adds the given method to this set if it doesn't already exist in the tracked set and this set. + /// + /// Method to add + /// true if added + public bool Add(MethodInfo m) + { + if (!m.IsStatic) + throw new ArgumentException("Patch methods must be static"); + if (_backingSet != null && !_backingSet.Add(m)) + return false; + if (_backingList.Contains(m)) + return false; + _sortDirty = true; + Interlocked.Exchange(ref _hasChanges, 1); + _backingList.Add(m); + return true; + } + + /// + /// Removes the given method from this set, and from the tracked set if it existed in this set. + /// + /// Method to remove + /// true if removed + public bool Remove(MethodInfo m) + { + if (_backingList.Remove(m)) + { + _sortDirty = true; + Interlocked.Exchange(ref _hasChanges, 1); + return _backingSet == null || _backingSet.Remove(m); + } + return false; + } + + /// + /// Removes all methods from this set, and their matches in the tracked set. + /// + public void RemoveAll() + { + foreach (var k in _backingList) + _backingSet?.Remove(k); + _backingList.Clear(); + _sortDirty = true; + Interlocked.Exchange(ref _hasChanges, 1); + } + + /// + /// Gets the number of methods stored in this set. + /// + public int Count => _backingList.Count; + + /// + /// Gets an ordered enumerator over this set + /// + /// + public IEnumerator GetEnumerator() + { + CheckSort(); + return _backingList.GetEnumerator(); + } + + /// + /// Gets an ordered enumerator over this set + /// + /// + IEnumerator IEnumerable.GetEnumerator() + { + CheckSort(); + return _backingList.GetEnumerator(); + } + + private void CheckSort() + { + if (!_sortDirty) + return; + var tmp = _backingList.ToArray(); + MergeSort(tmp, _backingList, MethodPriorityCompare.Instance, 0, _backingList.Count); + _sortDirty = false; + } + + private static void MergeSort(IList src, IList dst, Comparer comparer, int left, int right) + { + if (left + 1 >= right) + return; + var mid = (left + right) / 2; + MergeSort(dst, src, comparer, left, mid); + MergeSort(dst, src, comparer, mid, right); + for (int i = left, j = left, k = mid; i < right; i++) + if ((k >= right || j < mid) && comparer.Compare(src[j], src[k]) <= 0) + dst[i] = src[j++]; + else + dst[i] = src[k++]; + } + } + + /// + /// Methods run before the original method is run. If they return false the original method is skipped. + /// + public MethodRewriteSet Prefixes { get; } + /// + /// Methods capable of accepting one and returing another, modified. + /// + public MethodRewriteSet Transpilers { get; } + /// + /// Methods run after the original method has run. + /// + public MethodRewriteSet Suffixes { get; } + + /// + /// + /// + /// The pattern to track changes on, or null + public MethodRewritePattern(MethodRewritePattern parentPattern) + { + Prefixes = new MethodRewriteSet(parentPattern?.Prefixes); + Transpilers = new MethodRewriteSet(parentPattern?.Transpilers); + Suffixes = new MethodRewriteSet(parentPattern?.Suffixes); + } + } +} diff --git a/Torch/Managers/PatchManager/PatchContext.cs b/Torch/Managers/PatchManager/PatchContext.cs new file mode 100644 index 0000000..a213d99 --- /dev/null +++ b/Torch/Managers/PatchManager/PatchContext.cs @@ -0,0 +1,44 @@ +using System.Collections.Generic; +using System.Reflection; + +namespace Torch.Managers.PatchManager +{ + /// + /// Represents a set of common patches that can all be reversed in a single step. + /// + public class PatchContext + { + private readonly PatchManager _replacer; + private readonly Dictionary _rewritePatterns = new Dictionary(); + + internal PatchContext(PatchManager replacer) + { + _replacer = replacer; + } + + /// + /// Gets the rewrite pattern used to tracking changes in this context, creating one if it doesn't exist. + /// + /// Method to get the pattern for + /// + public MethodRewritePattern GetPattern(MethodBase method) + { + if (_rewritePatterns.TryGetValue(method, out MethodRewritePattern pattern)) + return pattern; + MethodRewritePattern parent = _replacer.GetPattern(method); + var res = new MethodRewritePattern(parent); + _rewritePatterns.Add(method, res); + return res; + } + + internal void RemoveAll() + { + foreach (MethodRewritePattern pattern in _rewritePatterns.Values) + { + pattern.Prefixes.RemoveAll(); + pattern.Transpilers.RemoveAll(); + pattern.Suffixes.RemoveAll(); + } + } + } +} diff --git a/Torch/Managers/PatchManager/PatchManager.cs b/Torch/Managers/PatchManager/PatchManager.cs new file mode 100644 index 0000000..86b0905 --- /dev/null +++ b/Torch/Managers/PatchManager/PatchManager.cs @@ -0,0 +1,86 @@ +using System.Collections.Generic; +using System.Reflection; +using Torch.API; + +namespace Torch.Managers.PatchManager +{ + /// + /// Applies and removes patches from the IL of methods. + /// + public class PatchManager : Manager + { + /// + /// Creates a new patch manager. Only have one active at a time. + /// + /// + public PatchManager(ITorchBase torchInstance) : base(torchInstance) + { + } + + private readonly Dictionary _rewritePatterns = new Dictionary(); + private readonly HashSet _contexts = new HashSet(); + + /// + /// Gets the rewrite pattern for the given method, creating one if it doesn't exist. + /// + /// Method to get the pattern for + /// + public MethodRewritePattern GetPattern(MethodBase method) + { + if (_rewritePatterns.TryGetValue(method, out DecoratedMethod pattern)) + return pattern; + var res = new DecoratedMethod(method); + _rewritePatterns.Add(method, res); + return res; + } + + + /// + /// Creates a new used for tracking changes. A call to will apply the patches. + /// + public PatchContext AcquireContext() + { + var context = new PatchContext(this); + _contexts.Add(context); + return context; + } + + /// + /// Frees the given context, and unregister all patches from it. A call to will apply the unpatching operation. + /// + /// Context to remove + public void FreeContext(PatchContext context) + { + context.RemoveAll(); + _contexts.Remove(context); + } + + /// + /// Commits all method decorations into IL. + /// + public void Commit() + { + foreach (DecoratedMethod m in _rewritePatterns.Values) + m.Commit(); + } + + /// + /// Commits any existing patches. + /// + public override void Attach() + { + Commit(); + } + + /// + /// Unregisters and removes all patches, then applies the unpatching operation. + /// + public override void Detach() + { + foreach (DecoratedMethod m in _rewritePatterns.Values) + m.Revert(); + _rewritePatterns.Clear(); + _contexts.Clear(); + } + } +} diff --git a/Torch/Managers/PatchManager/PatchPriorityAttribute.cs b/Torch/Managers/PatchManager/PatchPriorityAttribute.cs new file mode 100644 index 0000000..a097693 --- /dev/null +++ b/Torch/Managers/PatchManager/PatchPriorityAttribute.cs @@ -0,0 +1,24 @@ +using System; + +namespace Torch.Managers.PatchManager +{ + /// + /// Attribute used to decorate methods used for replacement. + /// + [AttributeUsage(AttributeTargets.Method)] + public class PatchPriorityAttribute : Attribute + { + /// + /// + /// + public PatchPriorityAttribute(int priority) + { + Priority = priority; + } + + /// + /// The priority of this replacement. A high priority prefix occurs first, and a high priority suffix or transpiler occurs last. + /// + public int Priority { get; set; } = 0; + } +} diff --git a/Torch/Managers/PatchManager/Transpile/LoggingILGenerator.cs b/Torch/Managers/PatchManager/Transpile/LoggingILGenerator.cs new file mode 100644 index 0000000..e18eb30 --- /dev/null +++ b/Torch/Managers/PatchManager/Transpile/LoggingILGenerator.cs @@ -0,0 +1,173 @@ +using System; +using System.Diagnostics; +using System.Linq; +using System.Reflection; +using System.Reflection.Emit; +using NLog; +using Torch.Utils; + +// ReSharper disable ConditionIsAlwaysTrueOrFalse +#pragma warning disable 162 // unreachable code +namespace Torch.Managers.PatchManager.Transpile +{ + /// + /// An ILGenerator that can log emit calls when the TRACE level is enabled. + /// + public class LoggingIlGenerator + { + private const int _opcodePadding = -10; + + private static readonly Logger _log = LogManager.GetCurrentClassLogger(); + + /// + /// Backing generator + /// + public ILGenerator Backing { get; } + + /// + /// Creates a new logging IL generator backed by the given generator. + /// + /// Backing generator + public LoggingIlGenerator(ILGenerator backing) + { + Backing = backing; + } + + /// + public LocalBuilder DeclareLocal(Type localType, bool isPinned = false) + { + LocalBuilder res = Backing.DeclareLocal(localType, isPinned); + _log?.Trace($"DclLoc\t{res.LocalIndex}\t=> {res.LocalType} {res.IsPinned}"); + return res; + } + + + /// + public void Emit(OpCode op) + { + _log?.Trace($"Emit\t{op,_opcodePadding}"); + Backing.Emit(op); + } + + /// + public void Emit(OpCode op, LocalBuilder arg) + { + _log?.Trace($"Emit\t{op,_opcodePadding} L:{arg.LocalIndex} {arg.LocalType}"); + Backing.Emit(op, arg); + } + + /// + public void Emit(OpCode op, int arg) + { + _log?.Trace($"Emit\t{op,_opcodePadding} {arg}"); + Backing.Emit(op, arg); + } + + /// + public void Emit(OpCode op, long arg) + { + _log?.Trace($"Emit\t{op,_opcodePadding} {arg}"); + Backing.Emit(op, arg); + } + + /// + public void Emit(OpCode op, float arg) + { + _log?.Trace($"Emit\t{op,_opcodePadding} {arg}"); + Backing.Emit(op, arg); + } + + /// + public void Emit(OpCode op, double arg) + { + _log?.Trace($"Emit\t{op,_opcodePadding} {arg}"); + Backing.Emit(op, arg); + } + + /// + public void Emit(OpCode op, string arg) + { + _log?.Trace($"Emit\t{op,_opcodePadding} {arg}"); + Backing.Emit(op, arg); + } + + /// + public void Emit(OpCode op, Type arg) + { + _log?.Trace($"Emit\t{op,_opcodePadding} {arg}"); + Backing.Emit(op, arg); + } + + /// + public void Emit(OpCode op, FieldInfo arg) + { + _log?.Trace($"Emit\t{op,_opcodePadding} {arg}"); + Backing.Emit(op, arg); + } + + /// + public void Emit(OpCode op, MethodInfo arg) + { + _log?.Trace($"Emit\t{op,_opcodePadding} {arg}"); + Backing.Emit(op, arg); + } + + +#pragma warning disable 649 + [ReflectedGetter(Name="m_label")] + private static Func _labelID; +#pragma warning restore 649 + + /// + public void Emit(OpCode op, Label arg) + { + _log?.Trace($"Emit\t{op,_opcodePadding}\tL:{_labelID.Invoke(arg)}"); + Backing.Emit(op, arg); + } + + /// + public void Emit(OpCode op, Label[] arg) + { + _log?.Trace($"Emit\t{op,_opcodePadding}\t{string.Join(", ", arg.Select(x => "L:" + _labelID.Invoke(x)))}"); + Backing.Emit(op, arg); + } + + /// + public void Emit(OpCode op, SignatureHelper arg) + { + _log?.Trace($"Emit\t{op,_opcodePadding} {arg}"); + Backing.Emit(op, arg); + } + + /// + public void Emit(OpCode op, ConstructorInfo arg) + { + _log?.Trace($"Emit\t{op,_opcodePadding} {arg}"); + Backing.Emit(op, arg); + } + + /// + public void MarkLabel(Label label) + { + _log?.Trace($"MkLbl\tL:{_labelID.Invoke(label)}"); + Backing.MarkLabel(label); + } + + /// + public Label DefineLabel() + { + return Backing.DefineLabel(); + } + + /// + /// Emits a comment to the log. + /// + /// Comment + [Conditional("DEBUG")] + public void EmitComment(string comment) + { + _log?.Trace($"// {comment}"); + } + } +#pragma warning restore 162 +} diff --git a/Torch/Managers/PatchManager/Transpile/MethodContext.cs b/Torch/Managers/PatchManager/Transpile/MethodContext.cs new file mode 100644 index 0000000..8c5be0d --- /dev/null +++ b/Torch/Managers/PatchManager/Transpile/MethodContext.cs @@ -0,0 +1,116 @@ +using System; +using System.Collections.Generic; +using System.Diagnostics; +using System.IO; +using System.Linq; +using System.Reflection; +using System.Reflection.Emit; +using NLog; +using Torch.Managers.PatchManager.MSIL; + +namespace Torch.Managers.PatchManager.Transpile +{ + internal class MethodContext + { + private static readonly Logger _log = LogManager.GetCurrentClassLogger(); + + public readonly MethodBase Method; + private readonly byte[] _msilBytes; + + internal Dictionary Labels { get; } = new Dictionary(); + private readonly List _instructions = new List(); + public IReadOnlyList Instructions => _instructions; + + internal ITokenResolver TokenResolver { get; } + + internal MsilLabel LabelAt(int i) + { + if (Labels.TryGetValue(i, out MsilLabel label)) + return label; + Labels.Add(i, label = new MsilLabel()); + return label; + } + + public MethodContext(MethodBase method) + { + Method = method; + _msilBytes = Method.GetMethodBody().GetILAsByteArray(); + TokenResolver = new NormalTokenResolver(method); + } + + public void Read() + { + ReadInstructions(); + ResolveLabels(); + } + + private void ReadInstructions() + { + Labels.Clear(); + _instructions.Clear(); + using (var memory = new MemoryStream(_msilBytes)) + using (var reader = new BinaryReader(memory)) + while (memory.Length > memory.Position) + { + var count = 1; + var instructionValue = (short)memory.ReadByte(); + if (Prefixes.Contains(instructionValue)) + { + instructionValue = (short) ((instructionValue << 8) | memory.ReadByte()); + count++; + } + if (!OpCodeLookup.TryGetValue(instructionValue, out OpCode opcode)) + throw new Exception($"Unknown opcode {instructionValue:X}"); + if (opcode.Size != count) + throw new Exception($"Opcode said it was {opcode.Size} but we read {count}"); + var instruction = new MsilInstruction(opcode) + { + Offset = (int) memory.Position + }; + _instructions.Add(instruction); + instruction.Operand?.Read(this, reader); + } + } + + private void ResolveLabels() + { + foreach (var label in Labels) + { + int min = 0, max = _instructions.Count - 1; + while (min <= max) + { + var mid = min + ((max - min) / 2); + if (label.Key < _instructions[mid].Offset) + max = mid - 1; + else + min = mid + 1; + } + _instructions[min]?.Labels?.Add(label.Value); + } + } + + public string ToHumanMsil() + { + return string.Join("\n", _instructions.Select(x => $"IL_{x.Offset:X4}: {x}")); + } + + private static readonly Dictionary OpCodeLookup; + private static readonly HashSet Prefixes; + + static MethodContext() + { + OpCodeLookup = new Dictionary(); + Prefixes = new HashSet(); + foreach (FieldInfo field in typeof(OpCodes).GetFields(BindingFlags.Static | BindingFlags.Public)) + { + var opcode = (OpCode)field.GetValue(null); + if (opcode.OpCodeType != OpCodeType.Nternal) + OpCodeLookup.Add(opcode.Value, opcode); + if ((ushort) opcode.Value > 0xFF) + { + Prefixes.Add((short) ((ushort) opcode.Value >> 8)); + } + } + } + } +} diff --git a/Torch/Managers/PatchManager/Transpile/MethodTranspiler.cs b/Torch/Managers/PatchManager/Transpile/MethodTranspiler.cs new file mode 100644 index 0000000..d9a37eb --- /dev/null +++ b/Torch/Managers/PatchManager/Transpile/MethodTranspiler.cs @@ -0,0 +1,67 @@ +using System; +using System.Collections.Generic; +using System.Reflection; +using System.Reflection.Emit; +using NLog; +using Torch.Managers.PatchManager.MSIL; + +namespace Torch.Managers.PatchManager.Transpile +{ + internal class MethodTranspiler + { + private static readonly Logger _log = LogManager.GetCurrentClassLogger(); + + internal static void Transpile(MethodBase baseMethod, IEnumerable transpilers, LoggingIlGenerator output, Label? retLabel) + { + var context = new MethodContext(baseMethod); + context.Read(); + _log.Trace("Input Method:"); + _log.Trace(context.ToHumanMsil); + + var methodContent = (IEnumerable) context.Instructions; + foreach (var transpiler in transpilers) + methodContent = (IEnumerable)transpiler.Invoke(null, new object[] { methodContent }); + methodContent = FixBranchAndReturn(methodContent, retLabel); + foreach (var k in methodContent) + k.Emit(output); + } + + private static IEnumerable FixBranchAndReturn(IEnumerable insn, Label? retTarget) + { + foreach (var i in insn) + { + if (retTarget.HasValue && i.OpCode == OpCodes.Ret) + { + var j = new MsilInstruction(OpCodes.Br); + ((MsilOperandBrTarget)j.Operand).Target = new MsilLabel(retTarget.Value); + yield return j; + continue; + } + if (_opcodeReplaceRule.TryGetValue(i.OpCode, out OpCode replaceOpcode)) + { + yield return new MsilInstruction(replaceOpcode) { Operand = i.Operand }; + continue; + } + yield return i; + } + } + + private static readonly Dictionary _opcodeReplaceRule; + static MethodTranspiler() + { + _opcodeReplaceRule = new Dictionary(); + foreach (var field in typeof(OpCodes).GetFields(BindingFlags.Static | BindingFlags.Public)) + { + var opcode = (OpCode)field.GetValue(null); + if (opcode.OperandType == OperandType.ShortInlineBrTarget && + opcode.Name.EndsWith(".s", StringComparison.OrdinalIgnoreCase)) + { + var other = (OpCode?) typeof(OpCodes).GetField(field.Name.Substring(0, field.Name.Length - 2), + BindingFlags.Static | BindingFlags.Public)?.GetValue(null); + if (other.HasValue && other.Value.OperandType == OperandType.InlineBrTarget) + _opcodeReplaceRule.Add(opcode, other.Value); + } + } + } + } +} diff --git a/Torch/Managers/PluginManager.cs b/Torch/Managers/PluginManager.cs index 0b136ab..c6828e2 100644 --- a/Torch/Managers/PluginManager.cs +++ b/Torch/Managers/PluginManager.cs @@ -29,6 +29,7 @@ namespace Torch.Managers /// public IList Plugins { get; } = new ObservableList(); + public event Action PluginLoaded; public event Action> PluginsLoaded; public PluginManager(ITorchBase torchInstance) : base(torchInstance) @@ -42,46 +43,68 @@ namespace Torch.Managers /// public void UpdatePlugins() { - if (_sessionManager != null) - { - _sessionManager.SessionLoaded += AttachCommandsToSession; - _sessionManager.SessionUnloading += DetachCommandsFromSession; - } foreach (var plugin in Plugins) plugin.Update(); } + private Action _attachCommandsHandler = null; + + private void SessionStateChanged(ITorchSession session, TorchSessionState newState) + { + var cmdManager = session.Managers.GetManager(); + if (cmdManager == null) + return; + switch (newState) + { + case TorchSessionState.Loaded: + if (_attachCommandsHandler != null) + PluginLoaded -= _attachCommandsHandler; + _attachCommandsHandler = (x) => cmdManager.RegisterPluginCommands(x); + PluginLoaded += _attachCommandsHandler; + foreach (ITorchPlugin plugin in Plugins) + cmdManager.RegisterPluginCommands(plugin); + break; + case TorchSessionState.Unloading: + if (_attachCommandsHandler != null) + { + PluginLoaded -= _attachCommandsHandler; + _attachCommandsHandler = null; + } + foreach (ITorchPlugin plugin in Plugins) + { + // cmdMgr?.UnregisterPluginCommands(plugin); + } + break; + case TorchSessionState.Loading: + case TorchSessionState.Unloaded: + break; + default: + throw new ArgumentOutOfRangeException(nameof(newState), newState, null); + } + } + + /// + /// Prepares the plugin manager for loading. + /// + public override void Attach() + { + if (_sessionManager != null) + _sessionManager.SessionStateChanged += SessionStateChanged; + } + /// /// Unloads all plugins. /// public override void Detach() { if (_sessionManager != null) - { - _sessionManager.SessionLoaded -= AttachCommandsToSession; - _sessionManager.SessionUnloading -= DetachCommandsFromSession; - } + _sessionManager.SessionStateChanged -= SessionStateChanged; foreach (var plugin in Plugins) plugin.Dispose(); Plugins.Clear(); } - private void AttachCommandsToSession(ITorchSession session) - { - var cmdMgr = session.Managers.GetManager(); - foreach (ITorchPlugin plugin in Plugins) - cmdMgr?.RegisterPluginCommands(plugin); - } - - private void DetachCommandsFromSession(ITorchSession session) - { - var cmdMgr = session.Managers.GetManager(); - foreach (ITorchPlugin plugin in Plugins) { - // cmdMgr?.UnregisterPluginCommands(plugin); - } - } - private void DownloadPlugins() { var folders = Directory.GetDirectories(PluginDir); @@ -144,6 +167,7 @@ namespace Torch.Managers _log.Info($"Loading plugin {plugin.Name} ({plugin.Version})"); plugin.StoragePath = Torch.Config.InstancePath; Plugins.Add(plugin); + PluginLoaded?.Invoke(plugin); } catch (Exception e) { diff --git a/Torch/Session/TorchSession.cs b/Torch/Session/TorchSession.cs index 6ec4a2d..2120ef1 100644 --- a/Torch/Session/TorchSession.cs +++ b/Torch/Session/TorchSession.cs @@ -45,5 +45,20 @@ namespace Torch.Session { Managers.Detach(); } + + private TorchSessionState _state = TorchSessionState.Loading; + /// + public TorchSessionState State + { + get => _state; + internal set + { + _state = value; + StateChanged?.Invoke(this, _state); + } + } + + /// + public event TorchSessionStateChangedDel StateChanged; } } diff --git a/Torch/Session/TorchSessionManager.cs b/Torch/Session/TorchSessionManager.cs index 933f61f..426b320 100644 --- a/Torch/Session/TorchSessionManager.cs +++ b/Torch/Session/TorchSessionManager.cs @@ -22,10 +22,7 @@ namespace Torch.Session private TorchSession _currentSession; /// - public event TorchSessionLoadDel SessionLoaded; - - /// - public event TorchSessionLoadDel SessionUnloading; + public event TorchSessionStateChangedDel SessionStateChanged; /// public ITorchSession CurrentSession => _currentSession; @@ -52,50 +49,130 @@ namespace Torch.Session return _factories.Remove(factory); } - private void LoadSession() - { - if (_currentSession != null) - { - _log.Warn($"Override old torch session {_currentSession.KeenSession.Name}"); - _currentSession.Detach(); - } + #region Session events - _log.Info($"Starting new torch session for {MySession.Static.Name}"); - _currentSession = new TorchSession(Torch, MySession.Static); - foreach (SessionManagerFactoryDel factory in _factories) - { - IManager manager = factory(CurrentSession); - if (manager != null) - CurrentSession.Managers.AddManager(manager); - } - (CurrentSession as TorchSession)?.Attach(); - SessionLoaded?.Invoke(_currentSession); - } - - private void UnloadSession() + private void SetState(TorchSessionState state) { if (_currentSession == null) return; - SessionUnloading?.Invoke(_currentSession); - _log.Info($"Unloading torch session for {_currentSession.KeenSession.Name}"); - _currentSession.Detach(); - _currentSession = null; + _currentSession.State = state; + SessionStateChanged?.Invoke(_currentSession, _currentSession.State); } + private void SessionLoading() + { + try + { + if (_currentSession != null) + { + _log.Warn($"Override old torch session {_currentSession.KeenSession.Name}"); + _currentSession.Detach(); + } + + _log.Info($"Starting new torch session for {MySession.Static.Name}"); + + _currentSession = new TorchSession(Torch, MySession.Static); + SetState(TorchSessionState.Loading); + } + catch (Exception e) + { + _log.Error(e); + throw; + } + } + + private void SessionLoaded() + { + try + { + if (_currentSession == null) + { + _log.Warn("Session loaded event occurred when we don't have a session."); + return; + } + foreach (SessionManagerFactoryDel factory in _factories) + { + IManager manager = factory(CurrentSession); + if (manager != null) + CurrentSession.Managers.AddManager(manager); + } + (CurrentSession as TorchSession)?.Attach(); + SetState(TorchSessionState.Loaded); + } + catch (Exception e) + { + _log.Error(e); + throw; + } + } + + private void SessionUnloading() + { + try + { + if (_currentSession == null) + { + _log.Warn("Session unloading event occurred when we don't have a session."); + return; + } + SetState(TorchSessionState.Unloading); + } + catch (Exception e) + { + _log.Error(e); + throw; + } + } + + private void SessionUnloaded() + { + try + { + if (_currentSession == null) + { + _log.Warn("Session unloading event occurred when we don't have a session."); + return; + } + _log.Info($"Unloading torch session for {_currentSession.KeenSession.Name}"); + SetState(TorchSessionState.Unloaded); + _currentSession.Detach(); + _currentSession = null; + } + catch (Exception e) + { + _log.Error(e); + throw; + } + } + #endregion + /// public override void Attach() { - MySession.AfterLoading += LoadSession; - MySession.OnUnloaded += UnloadSession; + MySession.OnLoading += SessionLoading; + MySession.AfterLoading += SessionLoaded; + MySession.OnUnloading += SessionUnloading; + MySession.OnUnloaded += SessionUnloaded; } + /// public override void Detach() { - _currentSession?.Detach(); - _currentSession = null; - MySession.AfterLoading -= LoadSession; - MySession.OnUnloaded -= UnloadSession; + MySession.OnLoading -= SessionLoading; + MySession.AfterLoading -= SessionLoaded; + MySession.OnUnloading -= SessionUnloading; + MySession.OnUnloaded -= SessionUnloaded; + + if (_currentSession != null) + { + if (_currentSession.State == TorchSessionState.Loaded) + SetState(TorchSessionState.Unloading); + if (_currentSession.State == TorchSessionState.Unloading) + SetState(TorchSessionState.Unloaded); + _currentSession.Detach(); + _currentSession = null; + } } } } diff --git a/Torch/Torch.csproj b/Torch/Torch.csproj index 4b7be00..aed1db0 100644 --- a/Torch/Torch.csproj +++ b/Torch/Torch.csproj @@ -21,6 +21,7 @@ x64 prompt MinimumRecommendedRules.ruleset + true $(SolutionDir)\bin\x64\Release\ @@ -31,6 +32,7 @@ prompt MinimumRecommendedRules.ruleset $(SolutionDir)\bin\x64\Release\Torch.xml + true @@ -152,12 +154,28 @@ Properties\AssemblyVersion.cs - + + + + + + + + + + + + + + + + + diff --git a/Torch/Utils/ReflectedManager.cs b/Torch/Utils/ReflectedManager.cs index 2f52507..39dae0b 100644 --- a/Torch/Utils/ReflectedManager.cs +++ b/Torch/Utils/ReflectedManager.cs @@ -23,6 +23,15 @@ namespace Torch.Utils /// Declaring type of the member to access. If null, inferred from the instance argument type. /// public Type Type { get; set; } = null; + + /// + /// Assembly qualified name of + /// + public string TypeName + { + get => Type?.AssemblyQualifiedName; + set => Type = value == null ? null : Type.GetType(value, true); + } } #region MemberInfoAttributes @@ -160,6 +169,19 @@ namespace Torch.Utils [AttributeUsage(AttributeTargets.Field)] public class ReflectedMethodAttribute : ReflectedMemberAttribute { + /// + /// When set the parameters types for the method are assumed to be this. + /// + public Type[] OverrideTypes { get; set; } + + /// + /// Assembly qualified names of + /// + public string[] OverrideTypeNames + { + get => OverrideTypes.Select(x => x.AssemblyQualifiedName).ToArray(); + set => OverrideTypes = value?.Select(x => x == null ? null : Type.GetType(x)).ToArray(); + } } /// @@ -534,10 +556,14 @@ namespace Torch.Utils trueParameterTypes = parameters.Skip(1).Select(x => x.ParameterType).ToArray(); } + var invokeTypes = new Type[trueParameterTypes.Length]; + for (var i = 0; i < invokeTypes.Length; i++) + invokeTypes[i] = attr.OverrideTypes?[i] ?? trueParameterTypes[i]; + MethodInfo methodInstance = trueType.GetMethod(attr.Name ?? field.Name, (attr is ReflectedStaticMethodAttribute ? BindingFlags.Static : BindingFlags.Instance) | BindingFlags.Public | - BindingFlags.NonPublic, null, CallingConventions.Any, trueParameterTypes, null); + BindingFlags.NonPublic, null, CallingConventions.Any, invokeTypes, null); if (methodInstance == null) { string methodType = attr is ReflectedStaticMethodAttribute ? "static" : "instance"; @@ -547,13 +573,38 @@ namespace Torch.Utils $"Unable to find {methodType} method {attr.Name ?? field.Name} in type {trueType.FullName} with parameters {methodParams}"); } + if (attr is ReflectedStaticMethodAttribute) - field.SetValue(null, Delegate.CreateDelegate(field.FieldType, methodInstance)); + { + if (attr.OverrideTypes != null) + { + ParameterExpression[] paramExp = + parameters.Select(x => Expression.Parameter(x.ParameterType)).ToArray(); + var argExp = new Expression[invokeTypes.Length]; + for (var i = 0; i < argExp.Length; i++) + if (invokeTypes[i] != paramExp[i].Type) + argExp[i] = Expression.Convert(paramExp[i], invokeTypes[i]); + else + argExp[i] = paramExp[i]; + field.SetValue(null, + Expression.Lambda(Expression.Call(methodInstance, argExp), paramExp) + .Compile()); + } + else + field.SetValue(null, Delegate.CreateDelegate(field.FieldType, methodInstance)); + } else { - ParameterExpression[] paramExp = parameters.Select(x => Expression.Parameter(x.ParameterType, x.Name)).ToArray(); + ParameterExpression[] paramExp = + parameters.Select(x => Expression.Parameter(x.ParameterType)).ToArray(); + var argExp = new Expression[invokeTypes.Length]; + for (var i = 0; i < argExp.Length; i++) + if (invokeTypes[i] != paramExp[i + 1].Type) + argExp[i] = Expression.Convert(paramExp[i + 1], invokeTypes[i]); + else + argExp[i] = paramExp[i + 1]; field.SetValue(null, - Expression.Lambda(Expression.Call(paramExp[0], methodInstance, paramExp.Skip(1)), paramExp) + Expression.Lambda(Expression.Call(paramExp[0], methodInstance, argExp), paramExp) .Compile()); } } @@ -589,7 +640,7 @@ namespace Torch.Utils if (trueType == null && isStatic) throw new ArgumentException("Static field setters need their type defined", nameof(field)); - if (!isStatic) + if (!isStatic && trueType == null) trueType = parameters[0].ParameterType; } else if (attr is ReflectedGetterAttribute) @@ -602,7 +653,7 @@ namespace Torch.Utils if (trueType == null && isStatic) throw new ArgumentException("Static field getters need their type defined", nameof(field)); - if (!isStatic) + if (!isStatic && trueType == null) trueType = parameters[0].ParameterType; } else @@ -620,10 +671,15 @@ namespace Torch.Utils $"Unable to find field or property for {trueName} in {trueType.FullName} or its base types", nameof(field)); ParameterExpression[] paramExp = parameters.Select(x => Expression.Parameter(x.ParameterType)).ToArray(); + Expression instanceExpr = null; + if (!isStatic) + { + instanceExpr = trueType == paramExp[0].Type ? (Expression) paramExp[0] : Expression.Convert(paramExp[0], trueType); + } MemberExpression fieldExp = sourceField != null - ? Expression.Field(isStatic ? null : paramExp[0], sourceField) - : Expression.Property(isStatic ? null : paramExp[0], sourceProperty); + ? Expression.Field(instanceExpr, sourceField) + : Expression.Property(instanceExpr, sourceProperty); Expression impl; if (attr is ReflectedSetterAttribute) {