diff --git a/Jenkinsfile b/Jenkinsfile index 7180626..3b62660 100644 --- a/Jenkinsfile +++ b/Jenkinsfile @@ -39,8 +39,8 @@ node { } else { buildMode = "Debug" } - bat "rmdir /Q /S \"bin\"" - bat "rmdir /Q /S \"bin-test\"" + bat "IF EXIST \"bin\" rmdir /Q /S \"bin\"" + bat "IF EXIST \"bin-test\" rmdir /Q /S \"bin-test\"" bat "\"${tool 'MSBuild'}msbuild\" Torch.sln /p:Configuration=${buildMode} /p:Platform=x64 /t:Clean" bat "\"${tool 'MSBuild'}msbuild\" Torch.sln /p:Configuration=${buildMode} /p:Platform=x64" } diff --git a/Torch.Server/Managers/EntityControlManager.cs b/Torch.Server/Managers/EntityControlManager.cs new file mode 100644 index 0000000..5f43884 --- /dev/null +++ b/Torch.Server/Managers/EntityControlManager.cs @@ -0,0 +1,265 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Reflection; +using System.Runtime.CompilerServices; +using System.Text; +using System.Threading; +using System.Threading.Tasks; +using System.Windows.Controls; +using NLog; +using NLog.Fluent; +using Torch.API; +using Torch.Collections; +using Torch.Managers; +using Torch.Server.ViewModels.Entities; +using Torch.Utils; + +namespace Torch.Server.Managers +{ + /// + /// Manager that lets users bind random view models to entities in Torch's Entity Manager + /// + public class EntityControlManager : Manager + { + private static readonly Logger _log = LogManager.GetCurrentClassLogger(); + + /// + /// Creates an entity control manager for the given instance of torch + /// + /// Torch instance + internal EntityControlManager(ITorchBase torchInstance) : base(torchInstance) + { + } + + private abstract class ModelFactory + { + private readonly ConditionalWeakTable _models = new ConditionalWeakTable(); + + public abstract Delegate Delegate { get; } + + protected abstract EntityControlViewModel Create(EntityViewModel evm); + +#pragma warning disable 649 + [ReflectedGetter(Name = "Keys")] + private static readonly Func, ICollection> _weakTableKeys; +#pragma warning restore 649 + + /// + /// Warning: Creates a giant list, avoid if possible. + /// + internal ICollection Keys => _weakTableKeys(_models); + + internal EntityControlViewModel GetOrCreate(EntityViewModel evm) + { + return _models.GetValue(evm, Create); + } + + internal bool TryGet(EntityViewModel evm, out EntityControlViewModel res) + { + return _models.TryGetValue(evm, out res); + } + } + + private class ModelFactory : ModelFactory where T : EntityViewModel + { + private readonly Func _factory; + public override Delegate Delegate => _factory; + + internal ModelFactory(Func factory) + { + _factory = factory; + } + + + protected override EntityControlViewModel Create(EntityViewModel evm) + { + if (evm is T m) + { + var result = _factory(m); + _log.Debug($"Model factory {_factory.Method} created {result} for {evm}"); + return result; + } + return null; + } + } + + private readonly List _modelFactories = new List(); + private readonly List _controlFactories = new List(); + + private readonly List> _boundEntityViewModels = new List>(); + private readonly ConditionalWeakTable> _boundViewModels = new ConditionalWeakTable>(); + + /// + /// This factory will be used to create component models for matching entity models. + /// + /// entity model type to match + /// Method to create component model from entity model. + public void RegisterModelFactory(Func modelFactory) + where TEntityBaseModel : EntityViewModel + { + if (!typeof(TEntityBaseModel).IsAssignableFrom(modelFactory.Method.GetParameters()[0].ParameterType)) + throw new ArgumentException("Generic type must match lamda type", nameof(modelFactory)); + lock (this) + { + var factory = new ModelFactory(modelFactory); + _modelFactories.Add(factory); + + var i = 0; + while (i < _boundEntityViewModels.Count) + { + if (_boundEntityViewModels[i].TryGetTarget(out EntityViewModel target) && + _boundViewModels.TryGetValue(target, out MtObservableList components)) + { + if (target is TEntityBaseModel tent) + UpdateBinding(target, components); + i++; + } + else + _boundEntityViewModels.RemoveAtFast(i); + } + } + } + + /// + /// Unregisters a factory registered with + /// + /// entity model type to match + /// Method to create component model from entity model. + public void UnregisterModelFactory(Func modelFactory) + where TEntityBaseModel : EntityViewModel + { + if (!typeof(TEntityBaseModel).IsAssignableFrom(modelFactory.Method.GetParameters()[0].ParameterType)) + throw new ArgumentException("Generic type must match lamda type", nameof(modelFactory)); + lock (this) + { + for (var i = 0; i < _modelFactories.Count; i++) + { + if (_modelFactories[i].Delegate == (Delegate)modelFactory) + { + foreach (var entry in _modelFactories[i].Keys) + if (_modelFactories[i].TryGet(entry, out EntityControlViewModel ecvm) && ecvm != null + && _boundViewModels.TryGetValue(entry, out var binding)) + binding.Remove(ecvm); + _modelFactories.RemoveAt(i); + break; + } + } + } + } + + /// + /// This factory will be used to create controls for matching view models. + /// + /// component model to match + /// Method to create control from component model + public void RegisterControlFactory( + Func controlFactory) + where TEntityComponentModel : EntityControlViewModel + { + if (!typeof(TEntityComponentModel).IsAssignableFrom(controlFactory.Method.GetParameters()[0].ParameterType)) + throw new ArgumentException("Generic type must match lamda type", nameof(controlFactory)); + lock (this) + { + _controlFactories.Add(controlFactory); + RefreshControls(); + } + } + + /// + /// Unregisters a factory registered with + /// + /// component model to match + /// Method to create control from component model + public void UnregisterControlFactory( + Func controlFactory) + where TEntityComponentModel : EntityControlViewModel + { + if (!typeof(TEntityComponentModel).IsAssignableFrom(controlFactory.Method.GetParameters()[0].ParameterType)) + throw new ArgumentException("Generic type must match lamda type", nameof(controlFactory)); + lock (this) + { + _controlFactories.Remove(controlFactory); + RefreshControls(); + } + } + + private void RefreshControls() where TEntityComponentModel : EntityControlViewModel + { + var i = 0; + while (i < _boundEntityViewModels.Count) + { + if (_boundEntityViewModels[i].TryGetTarget(out EntityViewModel target) && + _boundViewModels.TryGetValue(target, out MtObservableList components)) + { + foreach (EntityControlViewModel component in components) + if (component is TEntityComponentModel) + component.InvalidateControl(); + i++; + } + else + _boundEntityViewModels.RemoveAtFast(i); + } + } + + /// + /// Gets the models bound to the given entity view model. + /// + /// view model to query + /// + public MtObservableList BoundModels(EntityViewModel entity) + { + return _boundViewModels.GetValue(entity, CreateFreshBinding); + } + + /// + /// Gets a control for the given view model type. + /// + /// model to create a control for + /// control, or null if none + public Control CreateControl(EntityControlViewModel model) + { + lock (this) + foreach (Delegate factory in _controlFactories) + if (factory.Method.GetParameters()[0].ParameterType.IsInstanceOfType(model) && + factory.DynamicInvoke(model) is Control result) + { + _log.Debug($"Control factory {factory.Method} created {result}"); + return result; + } + _log.Warn($"No control created for {model}"); + return null; + } + + private MtObservableList CreateFreshBinding(EntityViewModel key) + { + var binding = new MtObservableList(); + lock (this) + { + _boundEntityViewModels.Add(new WeakReference(key)); + } + binding.PropertyChanged += (x, args) => + { + if (nameof(binding.IsObserved).Equals(args.PropertyName)) + UpdateBinding(key, binding); + }; + return binding; + } + + private void UpdateBinding(EntityViewModel key, MtObservableList binding) + { + if (!binding.IsObserved) + return; + + lock (this) + { + foreach (ModelFactory factory in _modelFactories) + { + var result = factory.GetOrCreate(key); + if (result != null && !binding.Contains(result)) + binding.Add(result); + } + } + } + } +} diff --git a/Torch.Server/Torch.Server.csproj b/Torch.Server/Torch.Server.csproj index 3208cd3..ea5c705 100644 --- a/Torch.Server/Torch.Server.csproj +++ b/Torch.Server/Torch.Server.csproj @@ -195,6 +195,7 @@ Properties\AssemblyVersion.cs + @@ -214,6 +215,13 @@ + + + EntityControlHost.xaml + + + EntityControlsView.xaml + @@ -305,6 +313,14 @@ + + Designer + MSBuild:Compile + + + Designer + MSBuild:Compile + Designer MSBuild:Compile @@ -333,6 +349,10 @@ Designer MSBuild:Compile + + Designer + MSBuild:Compile + Designer MSBuild:Compile @@ -345,10 +365,6 @@ Designer MSBuild:Compile - - Designer - MSBuild:Compile - MSBuild:Compile Designer diff --git a/Torch.Server/TorchServer.cs b/Torch.Server/TorchServer.cs index 44729df..502d94b 100644 --- a/Torch.Server/TorchServer.cs +++ b/Torch.Server/TorchServer.cs @@ -72,6 +72,7 @@ namespace Torch.Server { DedicatedInstance = new InstanceManager(this); AddManager(DedicatedInstance); + AddManager(new EntityControlManager(this)); Config = config ?? new TorchConfig(); var sessionManager = Managers.GetManager(); @@ -102,9 +103,19 @@ namespace Torch.Server MyPlugins.Load(); MyGlobalTypeMetadata.Static.Init(); + Managers.GetManager().SessionStateChanged += OnSessionStateChanged; GetManager().LoadInstance(Config.InstancePath); } + private void OnSessionStateChanged(ITorchSession session, TorchSessionState newState) + { + if (newState == TorchSessionState.Unloading || newState == TorchSessionState.Unloaded) + { + _watchdog?.Dispose(); + _watchdog = null; + } + } + private void InvokeBeforeRun() { MySandboxGame.Log.Init("SpaceEngineers-Dedicated.log", MyFinalBuildConstants.APP_VERSION_STRING); @@ -202,11 +213,18 @@ namespace Torch.Server ((TorchServer)state).Invoke(() => mre.Set()); if (!mre.WaitOne(TimeSpan.FromSeconds(Instance.Config.TickTimeout))) { +#if DEBUG + Log.Error($"Server watchdog detected that the server was frozen for at least {((TorchServer)state).Config.TickTimeout} seconds."); + Log.Error(DumpFrozenThread(MySandboxGame.Static.UpdateThread)); +#else Log.Error(DumpFrozenThread(MySandboxGame.Static.UpdateThread)); throw new TimeoutException($"Server watchdog detected that the server was frozen for at least {((TorchServer)state).Config.TickTimeout} seconds."); +#endif + } + else + { + Log.Debug("Server watchdog responded"); } - - Log.Debug("Server watchdog responded"); } private static string DumpFrozenThread(Thread thread, int traces = 3, int pause = 5000) diff --git a/Torch.Server/ViewModels/Entities/EntityControlViewModel.cs b/Torch.Server/ViewModels/Entities/EntityControlViewModel.cs new file mode 100644 index 0000000..3258365 --- /dev/null +++ b/Torch.Server/ViewModels/Entities/EntityControlViewModel.cs @@ -0,0 +1,38 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using System.Threading.Tasks; +using System.Windows; + +namespace Torch.Server.ViewModels.Entities +{ + public class EntityControlViewModel : ViewModel + { + internal const string SignalPropertyInvalidateControl = + "InvalidateControl-4124a476-704f-4762-8b5e-336a18e2f7e5"; + + internal void InvalidateControl() + { + // ReSharper disable once ExplicitCallerInfoArgument + OnPropertyChanged(SignalPropertyInvalidateControl); + } + + private bool _hide; + + /// + /// Should this element be forced into the + /// + public bool Hide + { + get => _hide; + protected set + { + if (_hide == value) + return; + _hide = value; + OnPropertyChanged(); + } + } + } +} diff --git a/Torch.Server/ViewModels/Entities/EntityViewModel.cs b/Torch.Server/ViewModels/Entities/EntityViewModel.cs index 08ad9ae..05eabbc 100644 --- a/Torch.Server/ViewModels/Entities/EntityViewModel.cs +++ b/Torch.Server/ViewModels/Entities/EntityViewModel.cs @@ -1,4 +1,8 @@ -using VRage.Game.ModAPI; +using System.Windows.Controls; +using Torch.API.Managers; +using Torch.Collections; +using Torch.Server.Managers; +using VRage.Game.ModAPI; using VRage.ModAPI; using VRageMath; @@ -7,9 +11,25 @@ namespace Torch.Server.ViewModels.Entities public class EntityViewModel : ViewModel { protected EntityTreeViewModel Tree { get; } - public IMyEntity Entity { get; } + + private IMyEntity _backing; + public IMyEntity Entity + { + get => _backing; + protected set + { + _backing = value; + OnPropertyChanged(); + EntityControls = TorchBase.Instance?.Managers.GetManager()?.BoundModels(this); + // ReSharper disable once ExplicitCallerInfoArgument + OnPropertyChanged(nameof(EntityControls)); + } + } + public long Id => Entity.EntityId; + public MtObservableList EntityControls { get; private set; } + public virtual string Name { get => Entity.DisplayName; diff --git a/Torch.Server/ViewModels/Entities/GridViewModel.cs b/Torch.Server/ViewModels/Entities/GridViewModel.cs index b591c99..8d06f26 100644 --- a/Torch.Server/ViewModels/Entities/GridViewModel.cs +++ b/Torch.Server/ViewModels/Entities/GridViewModel.cs @@ -2,6 +2,7 @@ using System.Linq; using Sandbox.Game.Entities; using Sandbox.ModAPI; +using Torch.API.Managers; using Torch.Collections; using Torch.Server.ViewModels.Blocks; diff --git a/Torch.Server/ViewModels/EntityTreeViewModel.cs b/Torch.Server/ViewModels/EntityTreeViewModel.cs index 87b00cf..76de7da 100644 --- a/Torch.Server/ViewModels/EntityTreeViewModel.cs +++ b/Torch.Server/ViewModels/EntityTreeViewModel.cs @@ -17,6 +17,8 @@ namespace Torch.Server.ViewModels { public class EntityTreeViewModel : ViewModel { + private static readonly Logger _log = LogManager.GetCurrentClassLogger(); + //TODO: these should be sorted sets for speed public MtObservableList Grids { get; set; } = new MtObservableList(); public MtObservableList Characters { get; set; } = new MtObservableList(); @@ -46,39 +48,55 @@ namespace Torch.Server.ViewModels private void MyEntities_OnEntityRemove(VRage.Game.Entity.MyEntity obj) { - switch (obj) + try { - case MyCubeGrid grid: - Grids.RemoveWhere(m => m.Id == grid.EntityId); - break; - case MyCharacter character: - Characters.RemoveWhere(m => m.Id == character.EntityId); - break; - case MyFloatingObject floating: - FloatingObjects.RemoveWhere(m => m.Id == floating.EntityId); - break; - case MyVoxelBase voxel: - VoxelMaps.RemoveWhere(m => m.Id == voxel.EntityId); - break; + switch (obj) + { + case MyCubeGrid grid: + Grids.RemoveWhere(m => m.Id == grid.EntityId); + break; + case MyCharacter character: + Characters.RemoveWhere(m => m.Id == character.EntityId); + break; + case MyFloatingObject floating: + FloatingObjects.RemoveWhere(m => m.Id == floating.EntityId); + break; + case MyVoxelBase voxel: + VoxelMaps.RemoveWhere(m => m.Id == voxel.EntityId); + break; + } + } + catch (Exception e) + { + _log.Error(e); + // ignore error "it's only UI" } } private void MyEntities_OnEntityAdd(VRage.Game.Entity.MyEntity obj) { - switch (obj) + try { - case MyCubeGrid grid: - Grids.Add(new GridViewModel(grid, this)); - break; - case MyCharacter character: - Characters.Add(new CharacterViewModel(character, this)); - break; - case MyFloatingObject floating: - FloatingObjects.Add(new FloatingObjectViewModel(floating, this)); - break; - case MyVoxelBase voxel: - VoxelMaps.Add(new VoxelMapViewModel(voxel, this)); - break; + switch (obj) + { + case MyCubeGrid grid: + Grids.Add(new GridViewModel(grid, this)); + break; + case MyCharacter character: + Characters.Add(new CharacterViewModel(character, this)); + break; + case MyFloatingObject floating: + FloatingObjects.Add(new FloatingObjectViewModel(floating, this)); + break; + case MyVoxelBase voxel: + VoxelMaps.Add(new VoxelMapViewModel(voxel, this)); + break; + } + } + catch (Exception e) + { + _log.Error(e); + // ignore error "it's only UI" } } } diff --git a/Torch.Server/Views/ChatControl.xaml.cs b/Torch.Server/Views/ChatControl.xaml.cs index 6c966cc..a6c1cb7 100644 --- a/Torch.Server/Views/ChatControl.xaml.cs +++ b/Torch.Server/Views/ChatControl.xaml.cs @@ -44,7 +44,7 @@ namespace Torch.Server public void BindServer(ITorchServer server) { _server = (TorchBase)server; - Dispatcher.Invoke(() => + Dispatcher.InvokeAsync(() => { ChatItems.Inlines.Clear(); }); @@ -59,7 +59,7 @@ namespace Torch.Server switch (state) { case TorchSessionState.Loading: - Dispatcher.Invoke(() => ChatItems.Inlines.Clear()); + Dispatcher.InvokeAsync(() => ChatItems.Inlines.Clear()); break; case TorchSessionState.Loaded: { @@ -112,7 +112,7 @@ namespace Torch.Server ChatScroller.ScrollToBottom(); } else - Dispatcher.Invoke(() => InsertMessage(msg)); + Dispatcher.InvokeAsync(() => InsertMessage(msg)); } private void SendButton_Click(object sender, RoutedEventArgs e) diff --git a/Torch.Server/Views/Entities/Blocks/BlockView.xaml b/Torch.Server/Views/Entities/Blocks/BlockView.xaml index 8020097..53cca60 100644 --- a/Torch.Server/Views/Entities/Blocks/BlockView.xaml +++ b/Torch.Server/Views/Entities/Blocks/BlockView.xaml @@ -5,6 +5,8 @@ xmlns:d="http://schemas.microsoft.com/expression/blend/2008" xmlns:local="clr-namespace:Torch.Server.Views.Blocks" xmlns:blocks="clr-namespace:Torch.Server.ViewModels.Blocks" + xmlns:entities="clr-namespace:Torch.Server.Views.Entities" + xmlns:entities1="clr-namespace:Torch.Server.ViewModels.Entities" mc:Ignorable="d"> @@ -12,6 +14,7 @@ + @@ -22,22 +25,27 @@ \ No newline at end of file diff --git a/Torch.Server/Views/Entities/EntityControlHost.xaml b/Torch.Server/Views/Entities/EntityControlHost.xaml new file mode 100644 index 0000000..a7b11af --- /dev/null +++ b/Torch.Server/Views/Entities/EntityControlHost.xaml @@ -0,0 +1,8 @@ + + diff --git a/Torch.Server/Views/Entities/EntityControlHost.xaml.cs b/Torch.Server/Views/Entities/EntityControlHost.xaml.cs new file mode 100644 index 0000000..f4a080f --- /dev/null +++ b/Torch.Server/Views/Entities/EntityControlHost.xaml.cs @@ -0,0 +1,72 @@ +using System.ComponentModel; +using System.Threading; +using System.Windows; +using System.Windows.Controls; +using Torch.Server.Managers; +using Torch.API.Managers; +using Torch.Server.ViewModels.Entities; + +namespace Torch.Server.Views.Entities +{ + /// + /// Interaction logic for EntityControlHost.xaml + /// + public partial class EntityControlHost : UserControl + { + public EntityControlHost() + { + InitializeComponent(); + DataContextChanged += OnDataContextChanged; + } + + private void OnDataContextChanged(object sender, DependencyPropertyChangedEventArgs e) + { + if (e.OldValue is ViewModel vmo) + { + vmo.PropertyChanged -= DataContext_OnPropertyChanged; + } + if (e.NewValue is ViewModel vmn) + { + vmn.PropertyChanged += DataContext_OnPropertyChanged; + } + RefreshControl(); + } + + private void DataContext_OnPropertyChanged(object sender, PropertyChangedEventArgs pa) + { + if (pa.PropertyName.Equals(EntityControlViewModel.SignalPropertyInvalidateControl)) + RefreshControl(); + else if (pa.PropertyName.Equals(nameof(EntityControlViewModel.Hide))) + RefreshVisibility(); + } + + private Control _currentControl; + + private void RefreshControl() + { + if (Dispatcher.Thread != Thread.CurrentThread) + { + Dispatcher.InvokeAsync(RefreshControl); + return; + } + + _currentControl = DataContext is EntityControlViewModel ecvm + ? TorchBase.Instance?.Managers.GetManager()?.CreateControl(ecvm) + : null; + Content = _currentControl; + RefreshVisibility(); + } + + private void RefreshVisibility() + { + if (Dispatcher.Thread != Thread.CurrentThread) + { + Dispatcher.InvokeAsync(RefreshVisibility); + return; + } + Visibility = (DataContext is EntityControlViewModel ecvm) && !ecvm.Hide && _currentControl != null + ? Visibility.Visible + : Visibility.Collapsed; + } + } +} diff --git a/Torch.Server/Views/Entities/EntityControlsView.xaml b/Torch.Server/Views/Entities/EntityControlsView.xaml new file mode 100644 index 0000000..9f4fba1 --- /dev/null +++ b/Torch.Server/Views/Entities/EntityControlsView.xaml @@ -0,0 +1,31 @@ + + + + + + + + + + + + + + + + + + diff --git a/Torch.Server/Views/Entities/EntityControlsView.xaml.cs b/Torch.Server/Views/Entities/EntityControlsView.xaml.cs new file mode 100644 index 0000000..b490a5a --- /dev/null +++ b/Torch.Server/Views/Entities/EntityControlsView.xaml.cs @@ -0,0 +1,15 @@ +using System.Windows.Controls; + +namespace Torch.Server.Views.Entities +{ + /// + /// Interaction logic for EntityControlsView.xaml + /// + public partial class EntityControlsView : ItemsControl + { + public EntityControlsView() + { + InitializeComponent(); + } + } +} diff --git a/Torch.Server/Views/Entities/GridView.xaml b/Torch.Server/Views/Entities/GridView.xaml index e4a63b6..f183117 100644 --- a/Torch.Server/Views/Entities/GridView.xaml +++ b/Torch.Server/Views/Entities/GridView.xaml @@ -3,20 +3,28 @@ xmlns:x="http://schemas.microsoft.com/winfx/2006/xaml" xmlns:mc="http://schemas.openxmlformats.org/markup-compatibility/2006" xmlns:d="http://schemas.microsoft.com/expression/blend/2008" - xmlns:local="clr-namespace:Torch.Server.Views.Entities" xmlns:entities="clr-namespace:Torch.Server.ViewModels.Entities" + xmlns:local="clr-namespace:Torch.Server.Views.Entities" mc:Ignorable="d"> - - + + + + + + + - + - + + + + \ No newline at end of file diff --git a/Torch.Server/Views/Entities/VoxelMapView.xaml b/Torch.Server/Views/Entities/VoxelMapView.xaml index 4333edf..3c0409b 100644 --- a/Torch.Server/Views/Entities/VoxelMapView.xaml +++ b/Torch.Server/Views/Entities/VoxelMapView.xaml @@ -9,14 +9,23 @@ - - - - - - - - - - + + + + + + + + + + + + + + + + + + + diff --git a/Torch.Server/Views/PlayerListControl.xaml.cs b/Torch.Server/Views/PlayerListControl.xaml.cs index a920eb3..49c40a7 100644 --- a/Torch.Server/Views/PlayerListControl.xaml.cs +++ b/Torch.Server/Views/PlayerListControl.xaml.cs @@ -57,10 +57,10 @@ namespace Torch.Server switch (newState) { case TorchSessionState.Loaded: - Dispatcher.Invoke(() => DataContext = _server?.CurrentSession?.Managers.GetManager()); + Dispatcher.InvokeAsync(() => DataContext = _server?.CurrentSession?.Managers.GetManager()); break; case TorchSessionState.Unloading: - Dispatcher.Invoke(() => DataContext = null); + Dispatcher.InvokeAsync(() => DataContext = null); break; } } diff --git a/Torch/Collections/MTObservableCollection.cs b/Torch/Collections/MTObservableCollection.cs index 6cc3d38..dec4b60 100644 --- a/Torch/Collections/MTObservableCollection.cs +++ b/Torch/Collections/MTObservableCollection.cs @@ -3,13 +3,10 @@ using System.Collections; using System.Collections.Generic; using System.Collections.Specialized; using System.ComponentModel; -using System.Diagnostics; using System.Linq; -using System.Reflection; using System.Text; using System.Threading; using System.Threading.Tasks; -using System.Windows.Threading; using Torch.Utils; namespace Torch.Collections @@ -35,6 +32,11 @@ namespace Torch.Collections _threadViews = new ThreadLocal(() => new ThreadView(this)); } + /// + /// Should this observable collection actually dispatch events. + /// + public bool NotificationsEnabled { get; protected set; } = true; + /// /// Takes a snapshot of this collection. Note: This call is only done when a read lock is acquired. /// @@ -54,7 +56,7 @@ namespace Torch.Collections /// public void Add(TV item) { - using(Lock.WriteUsing()) + using (Lock.WriteUsing()) { Backing.Add(item); MarkSnapshotsDirty(); @@ -66,7 +68,7 @@ namespace Torch.Collections /// public void Clear() { - using(Lock.WriteUsing()) + using (Lock.WriteUsing()) { Backing.Clear(); MarkSnapshotsDirty(); @@ -92,11 +94,13 @@ namespace Torch.Collections /// public bool Remove(TV item) { - using(Lock.UpgradableReadUsing()) { + using (Lock.UpgradableReadUsing()) + { int? oldIndex = (Backing as IList)?.IndexOf(item); if (oldIndex == -1) return false; - using(Lock.WriteUsing()) { + using (Lock.WriteUsing()) + { if (!Backing.Remove(item)) return false; MarkSnapshotsDirty(); @@ -125,6 +129,56 @@ namespace Torch.Collections #endregion #region Event Wrappers + private readonly WeakReference _deferredSnapshot = new WeakReference(null); + private bool _deferredSnapshotTaken = false; + /// + /// Disposable that stops update signals and signals a full refresh when disposed. + /// + public IDisposable DeferredUpdate() + { + using (Lock.WriteUsing()) + { + if (_deferredSnapshotTaken) + return new DummyToken(); + DeferredUpdateToken token; + if (!_deferredSnapshot.TryGetTarget(out token)) + _deferredSnapshot.SetTarget(token = new DeferredUpdateToken()); + token.SetCollection(this); + return token; + } + } + + private struct DummyToken : IDisposable + { + public void Dispose() + { + } + } + + private class DeferredUpdateToken : IDisposable + { + private MtObservableCollection _collection; + + internal void SetCollection(MtObservableCollection c) + { + c._deferredSnapshotTaken = true; + _collection = c; + c.NotificationsEnabled = false; + } + + public void Dispose() + { + using (_collection.Lock.WriteUsing()) + { + _collection.NotificationsEnabled = true; + _collection.OnPropertyChanged(nameof(Count)); + _collection.OnCollectionChanged(new NotifyCollectionChangedEventArgs(NotifyCollectionChangedAction.Reset)); + _collection._deferredSnapshotTaken = false; + } + + } + } + protected void OnPropertyChanged(string propName) { NotifyEvent(this, new PropertyChangedEventArgs(propName)); @@ -133,108 +187,60 @@ namespace Torch.Collections protected void OnCollectionChanged(NotifyCollectionChangedEventArgs e) { NotifyEvent(this, e); + OnPropertyChanged("Item[]"); } protected void NotifyEvent(object sender, PropertyChangedEventArgs args) { - _propertyChangedEvent.Raise(sender, args); + if (NotificationsEnabled) + _propertyChangedEvent.Raise(sender, args); } protected void NotifyEvent(object sender, NotifyCollectionChangedEventArgs args) { - _collectionChangedEvent.Raise(sender, args); + if (NotificationsEnabled) + _collectionChangedEvent.Raise(sender, args); } - private readonly DispatcherEvent _propertyChangedEvent = - new DispatcherEvent(); + private readonly MtObservableEvent _propertyChangedEvent = + new MtObservableEvent(); /// public event PropertyChangedEventHandler PropertyChanged { - add => _propertyChangedEvent.Add(value); - remove => _propertyChangedEvent.Remove(value); + add + { + _propertyChangedEvent.Add(value); + OnPropertyChanged(nameof(IsObserved)); + } + remove + { + _propertyChangedEvent.Remove(value); + OnPropertyChanged(nameof(IsObserved)); + } } - private readonly DispatcherEvent _collectionChangedEvent = - new DispatcherEvent(); + private readonly MtObservableEvent _collectionChangedEvent = + new MtObservableEvent(); /// public event NotifyCollectionChangedEventHandler CollectionChanged { - add => _collectionChangedEvent.Add(value); - remove => _collectionChangedEvent.Remove(value); - } - /// - /// Event that invokes handlers registered by dispatchers on dispatchers. - /// - /// Event argument type - /// Event handler delegate type - private sealed class DispatcherEvent where TEvtArgs : EventArgs - { - private delegate void DelInvokeHandler(TEvtHandle handler, object sender, TEvtArgs args); - - private static readonly DelInvokeHandler _invokeDirectly; - static DispatcherEvent() + add { - MethodInfo invoke = typeof(TEvtHandle).GetMethod("Invoke", BindingFlags.Instance | BindingFlags.Public | BindingFlags.DeclaredOnly); - Debug.Assert(invoke != null, "No invoke method on handler type"); - _invokeDirectly = (DelInvokeHandler)Delegate.CreateDelegate(typeof(DelInvokeHandler), invoke); + _collectionChangedEvent.Add(value); + OnPropertyChanged(nameof(IsObserved)); } - - private static Dispatcher CurrentDispatcher => Dispatcher.FromThread(Thread.CurrentThread); - - - private event EventHandler _event; - - internal void Raise(object sender, TEvtArgs args) + remove { - _event?.Invoke(sender, args); - } - - internal void Add(TEvtHandle evt) - { - if (evt == null) - return; - _event += new DispatcherDelegate(evt).Invoke; - } - - internal void Remove(TEvtHandle evt) - { - if (_event == null || evt == null) - return; - Delegate[] invokeList = _event.GetInvocationList(); - for (int i = invokeList.Length - 1; i >= 0; i--) - { - var wrapper = (DispatcherDelegate)invokeList[i].Target; - if (wrapper._delegate.Equals(evt)) - { - _event -= wrapper.Invoke; - return; - } - } - } - - private struct DispatcherDelegate - { - private readonly Dispatcher _dispatcher; - internal readonly TEvtHandle _delegate; - - internal DispatcherDelegate(TEvtHandle del) - { - _dispatcher = CurrentDispatcher; - _delegate = del; - } - - public void Invoke(object sender, TEvtArgs args) - { - if (_dispatcher == null || _dispatcher == CurrentDispatcher) - _invokeDirectly(_delegate, sender, args); - else - // (Delegate) (object) == dual cast so that the compiler likes it - _dispatcher.BeginInvoke((Delegate)(object)_delegate, DispatcherPriority.DataBind, sender, args); - } + _collectionChangedEvent.Remove(value); + OnPropertyChanged(nameof(IsObserved)); } } - #endregion + + /// + /// Is this collection observed by any listeners. + /// + public bool IsObserved => _collectionChangedEvent.IsObserved || _propertyChangedEvent.IsObserved; #region Enumeration /// diff --git a/Torch/Collections/MtObservableDictionary.cs b/Torch/Collections/MtObservableDictionary.cs index feeda30..00bf2d7 100644 --- a/Torch/Collections/MtObservableDictionary.cs +++ b/Torch/Collections/MtObservableDictionary.cs @@ -91,12 +91,6 @@ namespace Torch.Collections /// private ProxyCollection ObservableValues { get; } - internal void RaiseFullReset() - { - OnPropertyChanged(nameof(Count)); - OnCollectionChanged(new NotifyCollectionChangedEventArgs(NotifyCollectionChangedAction.Reset)); - } - /// /// Proxy collection capable of raising notifications when the parent collection changes. /// diff --git a/Torch/Collections/MtObservableEvent.cs b/Torch/Collections/MtObservableEvent.cs new file mode 100644 index 0000000..4b8857a --- /dev/null +++ b/Torch/Collections/MtObservableEvent.cs @@ -0,0 +1,102 @@ +using System; +using System.Diagnostics; +using System.Reflection; +using System.Threading; +using System.Windows.Threading; + +namespace Torch.Collections +{ + /// + /// Event that invokes handlers registered by dispatchers on dispatchers. + /// + /// Event argument type + /// Event handler delegate type + public sealed class MtObservableEvent where TEvtArgs : EventArgs + { + private delegate void DelInvokeHandler(TEvtHandle handler, object sender, TEvtArgs args); + + private static readonly DelInvokeHandler _invokeDirectly; + static MtObservableEvent() + { + MethodInfo invoke = typeof(TEvtHandle).GetMethod("Invoke", BindingFlags.Instance | BindingFlags.Public | BindingFlags.DeclaredOnly); + Debug.Assert(invoke != null, "No invoke method on handler type"); + _invokeDirectly = (DelInvokeHandler)Delegate.CreateDelegate(typeof(DelInvokeHandler), invoke); + } + + private static Dispatcher CurrentDispatcher => Dispatcher.FromThread(Thread.CurrentThread); + + + private event EventHandler Event; + + private int _observerCount = 0; + + /// + /// Determines if this event has an observers. + /// + public bool IsObserved => _observerCount > 0; + + /// + /// Raises this event for the given sender, with the given args + /// + /// sender + /// args + public void Raise(object sender, TEvtArgs args) + { + Event?.Invoke(sender, args); + } + + /// + /// Adds the given event handler. + /// + /// + public void Add(TEvtHandle evt) + { + if (evt == null) + return; + _observerCount++; + Event += new DispatcherDelegate(evt).Invoke; + } + + /// + /// Removes the given event handler + /// + /// + public void Remove(TEvtHandle evt) + { + if (Event == null || evt == null) + return; + Delegate[] invokeList = Event.GetInvocationList(); + for (int i = invokeList.Length - 1; i >= 0; i--) + { + var wrapper = (DispatcherDelegate)invokeList[i].Target; + if (wrapper._delegate.Equals(evt)) + { + Event -= wrapper.Invoke; + _observerCount--; + return; + } + } + } + + private struct DispatcherDelegate + { + private readonly Dispatcher _dispatcher; + internal readonly TEvtHandle _delegate; + + internal DispatcherDelegate(TEvtHandle del) + { + _dispatcher = CurrentDispatcher; + _delegate = del; + } + + public void Invoke(object sender, TEvtArgs args) + { + if (_dispatcher == null || _dispatcher == CurrentDispatcher) + _invokeDirectly(_delegate, sender, args); + else + // (Delegate) (object) == dual cast so that the compiler likes it + _dispatcher.BeginInvoke((Delegate)(object)_delegate, DispatcherPriority.DataBind, sender, args); + } + } + } +} \ No newline at end of file diff --git a/Torch/Collections/MtObservableList.cs b/Torch/Collections/MtObservableList.cs index ec64c64..b290d03 100644 --- a/Torch/Collections/MtObservableList.cs +++ b/Torch/Collections/MtObservableList.cs @@ -106,7 +106,8 @@ namespace Torch.Collections /// public void Sort(Func selector, IComparer comparer = null) { - using (Lock.ReadUsing()) + using (DeferredUpdate()) + using (Lock.WriteUsing()) { comparer = comparer ?? Comparer.Default; if (Backing is List lst) @@ -118,8 +119,6 @@ namespace Torch.Collections foreach (T v in sortedItems) Backing.Add(v); } - - OnCollectionChanged(new NotifyCollectionChangedEventArgs(NotifyCollectionChangedAction.Reset)); } } } diff --git a/Torch/Managers/PatchManager/DecoratedMethod.cs b/Torch/Managers/PatchManager/DecoratedMethod.cs index ec0814c..22e3073 100644 --- a/Torch/Managers/PatchManager/DecoratedMethod.cs +++ b/Torch/Managers/PatchManager/DecoratedMethod.cs @@ -27,22 +27,38 @@ namespace Torch.Managers.PatchManager private byte[] _revertData = null; private GCHandle? _pinnedPatch; + internal bool HasChanged() + { + return Prefixes.HasChanges() || Suffixes.HasChanges() || Transpilers.HasChanges() || PostTranspilers.HasChanges(); + } + internal void Commit() { - if (!Prefixes.HasChanges() && !Suffixes.HasChanges() && !Transpilers.HasChanges()) - return; - Revert(); + try + { + // non-greedy so they are all reset + if (!Prefixes.HasChanges(true) & !Suffixes.HasChanges(true) & !Transpilers.HasChanges(true) & !PostTranspilers.HasChanges(true)) + return; + Revert(); - if (Prefixes.Count == 0 && Suffixes.Count == 0 && Transpilers.Count == 0) - return; - _log.Debug($"Begin patching {_method.DeclaringType?.FullName}#{_method.Name}({string.Join(", ", _method.GetParameters().Select(x => x.ParameterType.Name))})"); - var patch = ComposePatchedMethod(); + if (Prefixes.Count == 0 && Suffixes.Count == 0 && Transpilers.Count == 0 && PostTranspilers.Count == 0) + return; + _log.Log(PrintMsil ? LogLevel.Info : LogLevel.Debug, + $"Begin patching {_method.DeclaringType?.FullName}#{_method.Name}({string.Join(", ", _method.GetParameters().Select(x => x.ParameterType.Name))})"); + var patch = ComposePatchedMethod(); - _revertAddress = AssemblyMemory.GetMethodBodyStart(_method); - var newAddress = AssemblyMemory.GetMethodBodyStart(patch); - _revertData = AssemblyMemory.WriteJump(_revertAddress, newAddress); - _pinnedPatch = GCHandle.Alloc(patch); - _log.Debug($"Done patching {_method.DeclaringType?.FullName}#{_method.Name}({string.Join(", ", _method.GetParameters().Select(x => x.ParameterType.Name))})"); + _revertAddress = AssemblyMemory.GetMethodBodyStart(_method); + var newAddress = AssemblyMemory.GetMethodBodyStart(patch); + _revertData = AssemblyMemory.WriteJump(_revertAddress, newAddress); + _pinnedPatch = GCHandle.Alloc(patch); + _log.Log(PrintMsil ? LogLevel.Info : LogLevel.Debug, + $"Done patching {_method.DeclaringType?.FullName}#{_method.Name}({string.Join(", ", _method.GetParameters().Select(x => x.ParameterType.Name))})"); + } + catch (Exception exception) + { + _log.Fatal(exception, $"Error patching {_method.DeclaringType?.FullName}#{_method}"); + throw; + } } internal void Revert() @@ -95,100 +111,119 @@ namespace Torch.Managers.PatchManager public DynamicMethod ComposePatchedMethod() { DynamicMethod method = AllocatePatchMethod(); - var generator = new LoggingIlGenerator(method.GetILGenerator()); - EmitPatched(generator); + var generator = new LoggingIlGenerator(method.GetILGenerator(), PrintMsil ? LogLevel.Info : LogLevel.Trace); + List il = EmitPatched((type, pinned) => new MsilLocal(generator.DeclareLocal(type, pinned))).ToList(); + if (PrintMsil) + { + lock (_log) + { + MethodTranspiler.IntegrityAnalysis(LogLevel.Info, il); + } + } + MethodTranspiler.EmitMethod(il, generator); - // Force it to compile - RuntimeMethodHandle handle = _getMethodHandle.Invoke(method); - object runtimeMethodInfo = _getMethodInfo.Invoke(handle); - _compileDynamicMethod.Invoke(runtimeMethodInfo); + try + { + // Force it to compile + RuntimeMethodHandle handle = _getMethodHandle.Invoke(method); + object runtimeMethodInfo = _getMethodInfo.Invoke(handle); + _compileDynamicMethod.Invoke(runtimeMethodInfo); + } + catch + { + lock (_log) + { + var ctx = new MethodContext(method); + ctx.Read(); + MethodTranspiler.IntegrityAnalysis(LogLevel.Warn, ctx.Instructions); + } + throw; + } return method; } #endregion #region Emit - private void EmitPatched(LoggingIlGenerator target) + private IEnumerable EmitPatched(Func declareLocal) { - var originalLocalVariables = _method.GetMethodBody().LocalVariables - .Select(x => - { - Debug.Assert(x.LocalType != null); - return target.DeclareLocal(x.LocalType, x.IsPinned); - }).ToArray(); + var methodBody = _method.GetMethodBody(); + Debug.Assert(methodBody != null, "Method body is null"); + foreach (var localVar in methodBody.LocalVariables) + { + Debug.Assert(localVar.LocalType != null); + declareLocal(localVar.LocalType, localVar.IsPinned); + } + var instructions = new List(); + var specialVariables = new Dictionary(); - var specialVariables = new Dictionary(); - - Label labelAfterOriginalContent = target.DefineLabel(); - Label labelSkipMethodContent = target.DefineLabel(); + var labelAfterOriginalContent = new MsilLabel(); + var labelSkipMethodContent = new MsilLabel(); Type returnType = _method is MethodInfo meth ? meth.ReturnType : typeof(void); - LocalBuilder resultVariable = null; + MsilLocal resultVariable = null; if (returnType != typeof(void)) { - if (Prefixes.Concat(Suffixes).SelectMany(x => x.GetParameters()).Any(x => x.Name == RESULT_PARAMETER)) - resultVariable = target.DeclareLocal(returnType); - else if (Prefixes.Any(x => x.ReturnType == typeof(bool))) - resultVariable = target.DeclareLocal(returnType); + if (Prefixes.Concat(Suffixes).SelectMany(x => x.GetParameters()).Any(x => x.Name == RESULT_PARAMETER) + || Prefixes.Any(x => x.ReturnType == typeof(bool))) + resultVariable = declareLocal(returnType, false); } - resultVariable?.SetToDefault(target); - LocalBuilder prefixSkippedVariable = null; + if (resultVariable != null) + instructions.AddRange(resultVariable.SetToDefault()); + MsilLocal prefixSkippedVariable = null; if (Prefixes.Count > 0 && Suffixes.Any(x => x.GetParameters() .Any(y => y.Name.Equals(PREFIX_SKIPPED_PARAMETER)))) { - prefixSkippedVariable = target.DeclareLocal(typeof(bool)); + prefixSkippedVariable = declareLocal(typeof(bool), false); specialVariables.Add(PREFIX_SKIPPED_PARAMETER, prefixSkippedVariable); } if (resultVariable != null) specialVariables.Add(RESULT_PARAMETER, resultVariable); - target.EmitComment("Prefixes Begin"); foreach (MethodInfo prefix in Prefixes) { - EmitMonkeyCall(target, prefix, specialVariables); + instructions.AddRange(EmitMonkeyCall(prefix, specialVariables)); if (prefix.ReturnType == typeof(bool)) - target.Emit(OpCodes.Brfalse, labelSkipMethodContent); + instructions.Add(new MsilInstruction(OpCodes.Brfalse).InlineTarget(labelSkipMethodContent)); else if (prefix.ReturnType != typeof(void)) throw new Exception( $"Prefixes must return void or bool. {prefix.DeclaringType?.FullName}.{prefix.Name} returns {prefix.ReturnType}"); } - target.EmitComment("Prefixes End"); + instructions.AddRange(MethodTranspiler.Transpile(_method, (x) => declareLocal(x, false), Transpilers, labelAfterOriginalContent)); - target.EmitComment("Original Begin"); - MethodTranspiler.Transpile(_method, (type) => new MsilLocal(target.DeclareLocal(type)), Transpilers, target, labelAfterOriginalContent); - target.EmitComment("Original End"); - - target.MarkLabel(labelAfterOriginalContent); + instructions.Add(new MsilInstruction(OpCodes.Nop).LabelWith(labelAfterOriginalContent)); if (resultVariable != null) - target.Emit(OpCodes.Stloc, resultVariable); - Label notSkip = target.DefineLabel(); - target.Emit(OpCodes.Br, notSkip); - target.MarkLabel(labelSkipMethodContent); + instructions.Add(new MsilInstruction(OpCodes.Stloc).InlineValue(resultVariable)); + var notSkip = new MsilLabel(); + instructions.Add(new MsilInstruction(OpCodes.Br).InlineTarget(notSkip)); + instructions.Add(new MsilInstruction(OpCodes.Nop).LabelWith(labelSkipMethodContent)); if (prefixSkippedVariable != null) { - target.Emit(OpCodes.Ldc_I4_1); - target.Emit(OpCodes.Stloc, prefixSkippedVariable); + instructions.Add(new MsilInstruction(OpCodes.Ldc_I4_1)); + instructions.Add(new MsilInstruction(OpCodes.Stloc).InlineValue(prefixSkippedVariable)); } - target.MarkLabel(notSkip); + instructions.Add(new MsilInstruction(OpCodes.Nop).LabelWith(notSkip)); - target.EmitComment("Suffixes Begin"); foreach (MethodInfo suffix in Suffixes) { - EmitMonkeyCall(target, suffix, specialVariables); + instructions.AddRange(EmitMonkeyCall(suffix, specialVariables)); if (suffix.ReturnType != typeof(void)) throw new Exception($"Suffixes must return void. {suffix.DeclaringType?.FullName}.{suffix.Name} returns {suffix.ReturnType}"); } - target.EmitComment("Suffixes End"); if (resultVariable != null) - target.Emit(OpCodes.Ldloc, resultVariable); - target.Emit(OpCodes.Ret); + instructions.Add(new MsilInstruction(OpCodes.Ldloc).InlineValue(resultVariable)); + instructions.Add(new MsilInstruction(OpCodes.Ret)); + + var result = MethodTranspiler.Transpile(_method, instructions, (x) => declareLocal(x, false), PostTranspilers, null).ToList(); + if (result.Last().OpCode != OpCodes.Ret) + result.Add(new MsilInstruction(OpCodes.Ret)); + return result; } - private void EmitMonkeyCall(LoggingIlGenerator target, MethodInfo patch, - IReadOnlyDictionary specialVariables) + private IEnumerable EmitMonkeyCall(MethodInfo patch, + IReadOnlyDictionary specialVariables) { - target.EmitComment($"Call {patch.DeclaringType?.FullName}#{patch.Name}"); foreach (var param in patch.GetParameters()) { switch (param.Name) @@ -196,25 +231,26 @@ namespace Torch.Managers.PatchManager case INSTANCE_PARAMETER: if (_method.IsStatic) throw new Exception("Can't use an instance parameter for a static method"); - target.Emit(OpCodes.Ldarg_0); + yield return new MsilInstruction(OpCodes.Ldarg_0); break; case PREFIX_SKIPPED_PARAMETER: if (param.ParameterType != typeof(bool)) throw new Exception($"Prefix skipped parameter {param.ParameterType} must be of type bool"); if (param.ParameterType.IsByRef || param.IsOut) throw new Exception($"Prefix skipped parameter {param.ParameterType} can't be a reference type"); - if (specialVariables.TryGetValue(PREFIX_SKIPPED_PARAMETER, out LocalBuilder prefixSkip)) - target.Emit(OpCodes.Ldloc, prefixSkip); + if (specialVariables.TryGetValue(PREFIX_SKIPPED_PARAMETER, out MsilLocal prefixSkip)) + yield return new MsilInstruction(OpCodes.Ldloc).InlineValue(prefixSkip); else - target.Emit(OpCodes.Ldc_I4_0); + yield return new MsilInstruction(OpCodes.Ldc_I4_0); break; case RESULT_PARAMETER: Type retType = param.ParameterType.IsByRef ? param.ParameterType.GetElementType() : param.ParameterType; - if (retType == null || !retType.IsAssignableFrom(specialVariables[RESULT_PARAMETER].LocalType)) - throw new Exception($"Return type {specialVariables[RESULT_PARAMETER].LocalType} can't be assigned to result parameter type {retType}"); - target.Emit(param.ParameterType.IsByRef ? OpCodes.Ldloca : OpCodes.Ldloc, specialVariables[RESULT_PARAMETER]); + if (retType == null || !retType.IsAssignableFrom(specialVariables[RESULT_PARAMETER].Type)) + throw new Exception($"Return type {specialVariables[RESULT_PARAMETER].Type} can't be assigned to result parameter type {retType}"); + yield return new MsilInstruction(param.ParameterType.IsByRef ? OpCodes.Ldloca : OpCodes.Ldloc) + .InlineValue(specialVariables[RESULT_PARAMETER]); break; default: ParameterInfo declParam = _method.GetParameters().FirstOrDefault(x => x.Name == param.Name); @@ -225,18 +261,18 @@ namespace Torch.Managers.PatchManager bool patchByRef = param.IsOut || param.ParameterType.IsByRef; bool declByRef = declParam.IsOut || declParam.ParameterType.IsByRef; if (patchByRef == declByRef) - target.Emit(OpCodes.Ldarg, paramIdx); + yield return new MsilInstruction(OpCodes.Ldarg).InlineValue(new MsilArgument(paramIdx)); else if (patchByRef) - target.Emit(OpCodes.Ldarga, paramIdx); + yield return new MsilInstruction(OpCodes.Ldarga).InlineValue(new MsilArgument(paramIdx)); else { - target.Emit(OpCodes.Ldarg, paramIdx); - target.EmitDereference(declParam.ParameterType); + yield return new MsilInstruction(OpCodes.Ldarg).InlineValue(new MsilArgument(paramIdx)); + yield return EmitExtensions.EmitDereference(declParam.ParameterType); } break; } } - target.Emit(OpCodes.Call, patch); + yield return new MsilInstruction(OpCodes.Call).InlineValue(patch); } #endregion } diff --git a/Torch/Managers/PatchManager/EmitExtensions.cs b/Torch/Managers/PatchManager/EmitExtensions.cs index 90f1ffd..dad1a1a 100644 --- a/Torch/Managers/PatchManager/EmitExtensions.cs +++ b/Torch/Managers/PatchManager/EmitExtensions.cs @@ -1,6 +1,8 @@ using System; +using System.Collections.Generic; using System.Diagnostics; using System.Reflection.Emit; +using Torch.Managers.PatchManager.MSIL; using Torch.Managers.PatchManager.Transpile; namespace Torch.Managers.PatchManager @@ -11,65 +13,64 @@ namespace Torch.Managers.PatchManager /// Sets the given local to its default value in the given IL generator. /// /// Local to set to default - /// The IL generator - public static void SetToDefault(this LocalBuilder local, LoggingIlGenerator target) + /// Instructions + public static IEnumerable SetToDefault(this MsilLocal local) { - Debug.Assert(local.LocalType != null); - if (local.LocalType.IsEnum || local.LocalType.IsPrimitive) + Debug.Assert(local.Type != null); + if (local.Type.IsEnum || local.Type.IsPrimitive) { - if (local.LocalType == typeof(float)) - target.Emit(OpCodes.Ldc_R4, 0f); - else if (local.LocalType == typeof(double)) - target.Emit(OpCodes.Ldc_R8, 0d); - else if (local.LocalType == typeof(long) || local.LocalType == typeof(ulong)) - target.Emit(OpCodes.Ldc_I8, 0L); + if (local.Type == typeof(float)) + yield return new MsilInstruction(OpCodes.Ldc_R4).InlineValue(0f); + else if (local.Type == typeof(double)) + yield return new MsilInstruction(OpCodes.Ldc_R8).InlineValue(0d); + else if (local.Type == typeof(long) || local.Type == typeof(ulong)) + yield return new MsilInstruction(OpCodes.Ldc_I8).InlineValue(0L); else - target.Emit(OpCodes.Ldc_I4, 0); - target.Emit(OpCodes.Stloc, local); + yield return new MsilInstruction(OpCodes.Ldc_I4).InlineValue(0); + yield return new MsilInstruction(OpCodes.Stloc).InlineValue(local); } - else if (local.LocalType.IsValueType) // struct + else if (local.Type.IsValueType) // struct { - target.Emit(OpCodes.Ldloca, local); - target.Emit(OpCodes.Initobj, local.LocalType); + yield return new MsilInstruction(OpCodes.Ldloca).InlineValue(local); + yield return new MsilInstruction(OpCodes.Initobj).InlineValue(local.Type); } else // class { - target.Emit(OpCodes.Ldnull); - target.Emit(OpCodes.Stloc, local); + yield return new MsilInstruction(OpCodes.Ldnull); + yield return new MsilInstruction(OpCodes.Stloc).InlineValue(local); } } /// /// Emits a dereference for the given type. /// - /// IL Generator to emit on /// Type to dereference - public static void EmitDereference(this LoggingIlGenerator target, Type type) + /// Derference instruction + public static MsilInstruction EmitDereference(Type type) { if (type.IsByRef) type = type.GetElementType(); Debug.Assert(type != null); if (type == typeof(float)) - target.Emit(OpCodes.Ldind_R4); - else if (type == typeof(double)) - target.Emit(OpCodes.Ldind_R8); - else if (type == typeof(byte)) - target.Emit(OpCodes.Ldind_U1); - else if (type == typeof(ushort) || type == typeof(char)) - target.Emit(OpCodes.Ldind_U2); - else if (type == typeof(uint)) - target.Emit(OpCodes.Ldind_U4); - else if (type == typeof(sbyte)) - target.Emit(OpCodes.Ldind_I1); - else if (type == typeof(short)) - target.Emit(OpCodes.Ldind_I2); - else if (type == typeof(int) || type.IsEnum) - target.Emit(OpCodes.Ldind_I4); - else if (type == typeof(long) || type == typeof(ulong)) - target.Emit(OpCodes.Ldind_I8); - else - target.Emit(OpCodes.Ldind_Ref); + return new MsilInstruction(OpCodes.Ldind_R4); + if (type == typeof(double)) + return new MsilInstruction(OpCodes.Ldind_R8); + if (type == typeof(byte)) + return new MsilInstruction(OpCodes.Ldind_U1); + if (type == typeof(ushort) || type == typeof(char)) + return new MsilInstruction(OpCodes.Ldind_U2); + if (type == typeof(uint)) + return new MsilInstruction(OpCodes.Ldind_U4); + if (type == typeof(sbyte)) + return new MsilInstruction(OpCodes.Ldind_I1); + if (type == typeof(short)) + return new MsilInstruction(OpCodes.Ldind_I2); + if (type == typeof(int) || type.IsEnum) + return new MsilInstruction(OpCodes.Ldind_I4); + if (type == typeof(long) || type == typeof(ulong)) + return new MsilInstruction(OpCodes.Ldind_I8); + return new MsilInstruction(OpCodes.Ldind_Ref); } } } diff --git a/Torch/Managers/PatchManager/MSIL/ITokenResolver.cs b/Torch/Managers/PatchManager/MSIL/ITokenResolver.cs index 5b85c93..b8403fb 100644 --- a/Torch/Managers/PatchManager/MSIL/ITokenResolver.cs +++ b/Torch/Managers/PatchManager/MSIL/ITokenResolver.cs @@ -62,7 +62,7 @@ namespace Torch.Managers.PatchManager.MSIL { internal static readonly NullTokenResolver Instance = new NullTokenResolver(); - private NullTokenResolver() + internal NullTokenResolver() { } diff --git a/Torch/Managers/PatchManager/MSIL/MsilArgument.cs b/Torch/Managers/PatchManager/MSIL/MsilArgument.cs index 1246546..cb2e081 100644 --- a/Torch/Managers/PatchManager/MSIL/MsilArgument.cs +++ b/Torch/Managers/PatchManager/MSIL/MsilArgument.cs @@ -28,9 +28,20 @@ namespace Torch.Managers.PatchManager.MSIL /// public string Name { get; } - internal MsilArgument(ParameterInfo local) + /// + /// Creates an argument from the given parameter info. + /// + /// parameter info to use + public MsilArgument(ParameterInfo local) { - Position = (((MethodBase)local.Member).IsStatic ? 0 : 1) + local.Position; + bool isStatic; + if (local.Member is FieldInfo fi) + isStatic = fi.IsStatic; + else if (local.Member is MethodBase mb) + isStatic = mb.IsStatic; + else + throw new ArgumentException("ParameterInfo.Member must be MethodBase or FieldInfo", nameof(local)); + Position = (isStatic ? 0 : 1) + local.Position; Type = local.ParameterType; Name = local.Name; } diff --git a/Torch/Managers/PatchManager/MSIL/MsilInstruction.cs b/Torch/Managers/PatchManager/MSIL/MsilInstruction.cs index f39072d..879343d 100644 --- a/Torch/Managers/PatchManager/MSIL/MsilInstruction.cs +++ b/Torch/Managers/PatchManager/MSIL/MsilInstruction.cs @@ -34,6 +34,7 @@ namespace Torch.Managers.PatchManager.MSIL case OperandType.InlineField: Operand = new MsilOperandInline.MsilOperandReflected(this); break; + case OperandType.ShortInlineI: case OperandType.InlineI: Operand = new MsilOperandInline.MsilOperandInt32(this); break; @@ -63,16 +64,11 @@ namespace Torch.Managers.PatchManager.MSIL break; case OperandType.ShortInlineVar: case OperandType.InlineVar: - if (OpCode.Name.IndexOf("loc", StringComparison.OrdinalIgnoreCase) != -1) + if (OpCode.IsLocalStore() || OpCode.IsLocalLoad() || OpCode.IsLocalLoadByRef()) Operand = new MsilOperandInline.MsilOperandLocal(this); else Operand = new MsilOperandInline.MsilOperandArgument(this); break; - case OperandType.ShortInlineI: - Operand = OpCode == OpCodes.Ldc_I4_S - ? (MsilOperand)new MsilOperandInline.MsilOperandInt8(this) - : new MsilOperandInline.MsilOperandUInt8(this); - break; case OperandType.ShortInlineR: Operand = new MsilOperandInline.MsilOperandSingle(this); break; @@ -104,6 +100,11 @@ namespace Torch.Managers.PatchManager.MSIL /// public HashSet Labels { get; } = new HashSet(); + /// + /// The try catch operation that is performed here. + /// + public MsilTryCatchOperation TryCatchOperation { get; set; } = null; + private static readonly ConcurrentDictionary _setterInfoForInlines = new ConcurrentDictionary(); @@ -147,6 +148,7 @@ namespace Torch.Managers.PatchManager.MSIL Operand?.CopyTo(result.Operand); foreach (MsilLabel x in Labels) result.Labels.Add(x); + result.TryCatchOperation = TryCatchOperation; return result; } @@ -172,20 +174,6 @@ namespace Torch.Managers.PatchManager.MSIL return this; } - /// - /// Emits this instruction to the given generator - /// - /// Emit target - public void Emit(LoggingIlGenerator target) - { - foreach (MsilLabel label in Labels) - target.MarkLabel(label.LabelFor(target)); - if (Operand != null) - Operand.Emit(target); - else - target.Emit(OpCode); - } - /// public override string ToString() { @@ -214,6 +202,8 @@ namespace Torch.Managers.PatchManager.MSIL Operand is MsilOperandInline inline) { MethodBase op = inline.Value; + if (op == null) + return num; if (op is MethodInfo mi && mi.ReturnType != typeof(void)) num++; num -= op.GetParameters().Length; diff --git a/Torch/Managers/PatchManager/MSIL/MsilInstructionExtensions.cs b/Torch/Managers/PatchManager/MSIL/MsilInstructionExtensions.cs index 74b8912..a69ada4 100644 --- a/Torch/Managers/PatchManager/MSIL/MsilInstructionExtensions.cs +++ b/Torch/Managers/PatchManager/MSIL/MsilInstructionExtensions.cs @@ -19,8 +19,7 @@ namespace Torch.Managers.PatchManager.MSIL /// public static bool IsLocalLoad(this MsilInstruction me) { - return me.OpCode == OpCodes.Ldloc || me.OpCode == OpCodes.Ldloc_S || me.OpCode == OpCodes.Ldloc_0 || - me.OpCode == OpCodes.Ldloc_1 || me.OpCode == OpCodes.Ldloc_2 || me.OpCode == OpCodes.Ldloc_3; + return me.OpCode.IsLocalLoad(); } /// @@ -28,7 +27,7 @@ namespace Torch.Managers.PatchManager.MSIL /// public static bool IsLocalLoadByRef(this MsilInstruction me) { - return me.OpCode == OpCodes.Ldloca || me.OpCode == OpCodes.Ldloca_S; + return me.OpCode.IsLocalLoadByRef(); } /// @@ -36,8 +35,33 @@ namespace Torch.Managers.PatchManager.MSIL /// public static bool IsLocalStore(this MsilInstruction me) { - return me.OpCode == OpCodes.Stloc || me.OpCode == OpCodes.Stloc_S || me.OpCode == OpCodes.Stloc_0 || - me.OpCode == OpCodes.Stloc_1 || me.OpCode == OpCodes.Stloc_2 || me.OpCode == OpCodes.Stloc_3; + return me.OpCode.IsLocalStore(); + } + + /// + /// Is this instruction a local load-by-value instruction. + /// + public static bool IsLocalLoad(this OpCode opcode) + { + return opcode == OpCodes.Ldloc || opcode == OpCodes.Ldloc_S || opcode == OpCodes.Ldloc_0 || + opcode == OpCodes.Ldloc_1 || opcode == OpCodes.Ldloc_2 || opcode == OpCodes.Ldloc_3; + } + + /// + /// Is this instruction a local load-by-reference instruction. + /// + public static bool IsLocalLoadByRef(this OpCode opcode) + { + return opcode == OpCodes.Ldloca || opcode == OpCodes.Ldloca_S; + } + + /// + /// Is this instruction a local store instruction. + /// + public static bool IsLocalStore(this OpCode opcode) + { + return opcode == OpCodes.Stloc || opcode == OpCodes.Stloc_S || opcode == OpCodes.Stloc_0 || + opcode == OpCodes.Stloc_1 || opcode == OpCodes.Stloc_2 || opcode == OpCodes.Stloc_3; } /// diff --git a/Torch/Managers/PatchManager/MSIL/MsilLocal.cs b/Torch/Managers/PatchManager/MSIL/MsilLocal.cs index 1db7eb1..0918118 100644 --- a/Torch/Managers/PatchManager/MSIL/MsilLocal.cs +++ b/Torch/Managers/PatchManager/MSIL/MsilLocal.cs @@ -1,5 +1,6 @@ using System; using System.Collections.Generic; +using System.Diagnostics; using System.Linq; using System.Reflection; using System.Reflection.Emit; diff --git a/Torch/Managers/PatchManager/MSIL/MsilOperandBrTarget.cs b/Torch/Managers/PatchManager/MSIL/MsilOperandBrTarget.cs index ea93ce4..8e1c934 100644 --- a/Torch/Managers/PatchManager/MSIL/MsilOperandBrTarget.cs +++ b/Torch/Managers/PatchManager/MSIL/MsilOperandBrTarget.cs @@ -21,15 +21,40 @@ namespace Torch.Managers.PatchManager.MSIL internal override void Read(MethodContext context, BinaryReader reader) { - int val = Instruction.OpCode.OperandType == OperandType.InlineBrTarget - ? reader.ReadInt32() - : reader.ReadSByte(); - Target = context.LabelAt((int)reader.BaseStream.Position + val); + + long offset; + + // ReSharper disable once SwitchStatementMissingSomeCases + switch (Instruction.OpCode.OperandType) + { + case OperandType.ShortInlineBrTarget: + offset = reader.ReadSByte(); + break; + case OperandType.InlineBrTarget: + offset = reader.ReadInt32(); + break; + default: + throw new InvalidBranchException( + $"OpCode {Instruction.OpCode}, operand type {Instruction.OpCode.OperandType} doesn't match {GetType().Name}"); + } + + Target = context.LabelAt((int)(reader.BaseStream.Position + offset)); } internal override void Emit(LoggingIlGenerator generator) { - generator.Emit(Instruction.OpCode, Target.LabelFor(generator)); + // ReSharper disable once SwitchStatementMissingSomeCases + switch (Instruction.OpCode.OperandType) + { + case OperandType.ShortInlineBrTarget: + case OperandType.InlineBrTarget: + generator.Emit(Instruction.OpCode, Target.LabelFor(generator)); + break; + default: + throw new InvalidBranchException( + $"OpCode {Instruction.OpCode}, operand type {Instruction.OpCode.OperandType} doesn't match {GetType().Name}"); + } + } internal override void CopyTo(MsilOperand operand) diff --git a/Torch/Managers/PatchManager/MSIL/MsilOperandInline.cs b/Torch/Managers/PatchManager/MSIL/MsilOperandInline.cs index 2442fd9..cf0dc34 100644 --- a/Torch/Managers/PatchManager/MSIL/MsilOperandInline.cs +++ b/Torch/Managers/PatchManager/MSIL/MsilOperandInline.cs @@ -5,6 +5,7 @@ using System.Linq; using System.Reflection; using System.Reflection.Emit; using Torch.Managers.PatchManager.Transpile; +using Torch.Utils; namespace Torch.Managers.PatchManager.MSIL { @@ -44,47 +45,6 @@ namespace Torch.Managers.PatchManager.MSIL /// public static class MsilOperandInline { - /// - /// Inline unsigned byte - /// - public class MsilOperandUInt8 : MsilOperandInline - { - internal MsilOperandUInt8(MsilInstruction instruction) : base(instruction) - { - } - - internal override void Read(MethodContext context, BinaryReader reader) - { - Value = reader.ReadByte(); - } - - internal override void Emit(LoggingIlGenerator generator) - { - generator.Emit(Instruction.OpCode, Value); - } - } - - /// - /// Inline signed byte - /// - public class MsilOperandInt8 : MsilOperandInline - { - internal MsilOperandInt8(MsilInstruction instruction) : base(instruction) - { - } - - internal override void Read(MethodContext context, BinaryReader reader) - { - Value = - (sbyte)reader.ReadByte(); - } - - internal override void Emit(LoggingIlGenerator generator) - { - generator.Emit(Instruction.OpCode, Value); - } - } - /// /// Inline integer /// @@ -96,12 +56,36 @@ namespace Torch.Managers.PatchManager.MSIL internal override void Read(MethodContext context, BinaryReader reader) { - Value = reader.ReadInt32(); + // ReSharper disable once SwitchStatementMissingSomeCases + switch (Instruction.OpCode.OperandType) + { + case OperandType.ShortInlineI: + Value = reader.ReadByte(); + return; + case OperandType.InlineI: + Value = reader.ReadInt32(); + return; + default: + throw new InvalidBranchException( + $"OpCode {Instruction.OpCode}, operand type {Instruction.OpCode.OperandType} doesn't match {GetType().Name}"); + } } internal override void Emit(LoggingIlGenerator generator) { - generator.Emit(Instruction.OpCode, Value); + // ReSharper disable once SwitchStatementMissingSomeCases + switch (Instruction.OpCode.OperandType) + { + case OperandType.ShortInlineI: + generator.Emit(Instruction.OpCode, (byte)Value); + return; + case OperandType.InlineI: + generator.Emit(Instruction.OpCode, Value); + return; + default: + throw new InvalidBranchException( + $"OpCode {Instruction.OpCode}, operand type {Instruction.OpCode.OperandType} doesn't match {GetType().Name}"); + } } } @@ -116,12 +100,30 @@ namespace Torch.Managers.PatchManager.MSIL internal override void Read(MethodContext context, BinaryReader reader) { - Value = reader.ReadSingle(); + // ReSharper disable once SwitchStatementMissingSomeCases + switch (Instruction.OpCode.OperandType) + { + case OperandType.ShortInlineR: + Value = reader.ReadSingle(); + return; + default: + throw new InvalidBranchException( + $"OpCode {Instruction.OpCode}, operand type {Instruction.OpCode.OperandType} doesn't match {GetType().Name}"); + } } internal override void Emit(LoggingIlGenerator generator) { - generator.Emit(Instruction.OpCode, Value); + // ReSharper disable once SwitchStatementMissingSomeCases + switch (Instruction.OpCode.OperandType) + { + case OperandType.ShortInlineR: + generator.Emit(Instruction.OpCode, Value); + return; + default: + throw new InvalidBranchException( + $"OpCode {Instruction.OpCode}, operand type {Instruction.OpCode.OperandType} doesn't match {GetType().Name}"); + } } } @@ -136,12 +138,30 @@ namespace Torch.Managers.PatchManager.MSIL internal override void Read(MethodContext context, BinaryReader reader) { - Value = reader.ReadDouble(); + // ReSharper disable once SwitchStatementMissingSomeCases + switch (Instruction.OpCode.OperandType) + { + case OperandType.InlineR: + Value = reader.ReadDouble(); + return; + default: + throw new InvalidBranchException( + $"OpCode {Instruction.OpCode}, operand type {Instruction.OpCode.OperandType} doesn't match {GetType().Name}"); + } } internal override void Emit(LoggingIlGenerator generator) { - generator.Emit(Instruction.OpCode, Value); + // ReSharper disable once SwitchStatementMissingSomeCases + switch (Instruction.OpCode.OperandType) + { + case OperandType.InlineR: + generator.Emit(Instruction.OpCode, Value); + return; + default: + throw new InvalidBranchException( + $"OpCode {Instruction.OpCode}, operand type {Instruction.OpCode.OperandType} doesn't match {GetType().Name}"); + } } } @@ -156,12 +176,30 @@ namespace Torch.Managers.PatchManager.MSIL internal override void Read(MethodContext context, BinaryReader reader) { - Value = reader.ReadInt64(); + // ReSharper disable once SwitchStatementMissingSomeCases + switch (Instruction.OpCode.OperandType) + { + case OperandType.InlineI8: + Value = reader.ReadInt64(); + return; + default: + throw new InvalidBranchException( + $"OpCode {Instruction.OpCode}, operand type {Instruction.OpCode.OperandType} doesn't match {GetType().Name}"); + } } internal override void Emit(LoggingIlGenerator generator) { - generator.Emit(Instruction.OpCode, Value); + // ReSharper disable once SwitchStatementMissingSomeCases + switch (Instruction.OpCode.OperandType) + { + case OperandType.InlineI8: + generator.Emit(Instruction.OpCode, Value); + return; + default: + throw new InvalidBranchException( + $"OpCode {Instruction.OpCode}, operand type {Instruction.OpCode.OperandType} doesn't match {GetType().Name}"); + } } } @@ -176,13 +214,30 @@ namespace Torch.Managers.PatchManager.MSIL internal override void Read(MethodContext context, BinaryReader reader) { - Value = - context.TokenResolver.ResolveString(reader.ReadInt32()); + // ReSharper disable once SwitchStatementMissingSomeCases + switch (Instruction.OpCode.OperandType) + { + case OperandType.InlineString: + Value = context.TokenResolver.ResolveString(reader.ReadInt32()); + return; + default: + throw new InvalidBranchException( + $"OpCode {Instruction.OpCode}, operand type {Instruction.OpCode.OperandType} doesn't match {GetType().Name}"); + } } internal override void Emit(LoggingIlGenerator generator) { - generator.Emit(Instruction.OpCode, Value); + // ReSharper disable once SwitchStatementMissingSomeCases + switch (Instruction.OpCode.OperandType) + { + case OperandType.InlineString: + generator.Emit(Instruction.OpCode, Value); + return; + default: + throw new InvalidBranchException( + $"OpCode {Instruction.OpCode}, operand type {Instruction.OpCode.OperandType} doesn't match {GetType().Name}"); + } } } @@ -197,14 +252,28 @@ namespace Torch.Managers.PatchManager.MSIL internal override void Read(MethodContext context, BinaryReader reader) { - byte[] sig = context.TokenResolver - .ResolveSignature(reader.ReadInt32()); - throw new ArgumentException("Can't figure out how to convert this."); + // ReSharper disable once SwitchStatementMissingSomeCases + switch (Instruction.OpCode.OperandType) + { + case OperandType.InlineSig: + throw new NotImplementedException(); + default: + throw new InvalidBranchException( + $"OpCode {Instruction.OpCode}, operand type {Instruction.OpCode.OperandType} doesn't match {GetType().Name}"); + } } internal override void Emit(LoggingIlGenerator generator) { - generator.Emit(Instruction.OpCode, Value); + // ReSharper disable once SwitchStatementMissingSomeCases + switch (Instruction.OpCode.OperandType) + { + case OperandType.InlineSig: + throw new NotImplementedException(); + default: + throw new InvalidBranchException( + $"OpCode {Instruction.OpCode}, operand type {Instruction.OpCode.OperandType} doesn't match {GetType().Name}"); + } } } @@ -219,18 +288,45 @@ namespace Torch.Managers.PatchManager.MSIL internal override void Read(MethodContext context, BinaryReader reader) { - int paramID = - Instruction.OpCode.OperandType == OperandType.ShortInlineVar - ? reader.ReadByte() - : reader.ReadUInt16(); - if (paramID == 0 && !context.Method.IsStatic) + int id; + // ReSharper disable once SwitchStatementMissingSomeCases + switch (Instruction.OpCode.OperandType) + { + case OperandType.ShortInlineVar: + id = reader.ReadByte(); + break; + case OperandType.InlineVar: + id = reader.ReadUInt16(); + break; + default: + throw new InvalidBranchException( + $"OpCode {Instruction.OpCode}, operand type {Instruction.OpCode.OperandType} doesn't match {GetType().Name}"); + } + + if (id == 0 && !context.Method.IsStatic) throw new ArgumentException("Haven't figured out how to ldarg with the \"this\" argument"); - Value = new MsilArgument(context.Method.GetParameters()[paramID - (context.Method.IsStatic ? 0 : 1)]); + // ReSharper disable once ConvertIfStatementToConditionalTernaryExpression + if (context.Method == null) + Value = new MsilArgument(id); + else + Value = new MsilArgument(context.Method.GetParameters()[id - (context.Method.IsStatic ? 0 : 1)]); } internal override void Emit(LoggingIlGenerator generator) { - generator.Emit(Instruction.OpCode, Value.Position); + // ReSharper disable once SwitchStatementMissingSomeCases + switch (Instruction.OpCode.OperandType) + { + case OperandType.ShortInlineVar: + generator.Emit(Instruction.OpCode, (byte) Value.Position); + break; + case OperandType.InlineVar: + generator.Emit(Instruction.OpCode, (short)Value.Position); + break; + default: + throw new InvalidBranchException( + $"OpCode {Instruction.OpCode}, operand type {Instruction.OpCode.OperandType} doesn't match {GetType().Name}"); + } } } @@ -245,16 +341,42 @@ namespace Torch.Managers.PatchManager.MSIL internal override void Read(MethodContext context, BinaryReader reader) { - Value = - new MsilLocal(context.Method.GetMethodBody().LocalVariables[ - Instruction.OpCode.OperandType == OperandType.ShortInlineVar - ? reader.ReadByte() - : reader.ReadUInt16()]); + int id; + // ReSharper disable once SwitchStatementMissingSomeCases + switch (Instruction.OpCode.OperandType) + { + case OperandType.ShortInlineVar: + id = reader.ReadByte(); + break; + case OperandType.InlineVar: + id = reader.ReadUInt16(); + break; + default: + throw new InvalidBranchException( + $"OpCode {Instruction.OpCode}, operand type {Instruction.OpCode.OperandType} doesn't match {GetType().Name}"); + } + // ReSharper disable once ConvertIfStatementToConditionalTernaryExpression + if (context.MethodBody == null) + Value = new MsilLocal(id); + else + Value = new MsilLocal(context.MethodBody.LocalVariables[id]); } internal override void Emit(LoggingIlGenerator generator) { - generator.Emit(Instruction.OpCode, Value.Index); + // ReSharper disable once SwitchStatementMissingSomeCases + switch (Instruction.OpCode.OperandType) + { + case OperandType.ShortInlineVar: + generator.Emit(Instruction.OpCode, (byte)Value.Index); + break; + case OperandType.InlineVar: + generator.Emit(Instruction.OpCode, (short)Value.Index); + break; + default: + throw new InvalidBranchException( + $"OpCode {Instruction.OpCode}, operand type {Instruction.OpCode.OperandType} doesn't match {GetType().Name}"); + } } } @@ -286,16 +408,40 @@ namespace Torch.Managers.PatchManager.MSIL value = context.TokenResolver.ResolveField(reader.ReadInt32()); break; default: - throw new ArgumentException("Reflected operand only applies to inline reflected types"); + throw new InvalidBranchException( + $"OpCode {Instruction.OpCode}, operand type {Instruction.OpCode.OperandType} doesn't match {GetType().Name}"); } if (value is TY vty) Value = vty; + else if (value == null) + Value = null; else throw new Exception($"Expected type {typeof(TY).Name} from operand {Instruction.OpCode.OperandType}, got {value.GetType()?.Name ?? "null"}"); } internal override void Emit(LoggingIlGenerator generator) { + + switch (Instruction.OpCode.OperandType) + { + case OperandType.InlineTok: + Debug.Assert(Value is MethodBase || Value is Type || Value is FieldInfo, + $"Value {Value?.GetType()} doesn't match operand type"); + break; + case OperandType.InlineType: + Debug.Assert(Value is Type, $"Value {Value?.GetType()} doesn't match operand type"); + break; + case OperandType.InlineMethod: + Debug.Assert(Value is MethodBase, $"Value {Value?.GetType()} doesn't match operand type"); + break; + case OperandType.InlineField: + Debug.Assert(Value is FieldInfo, $"Value {Value?.GetType()} doesn't match operand type"); + break; + default: + throw new InvalidBranchException( + $"OpCode {Instruction.OpCode}, operand type {Instruction.OpCode.OperandType} doesn't match {GetType().Name}"); + } + if (Value is ConstructorInfo) generator.Emit(Instruction.OpCode, Value as ConstructorInfo); else if (Value is FieldInfo) diff --git a/Torch/Managers/PatchManager/MSIL/MsilTryCatchOperation.cs b/Torch/Managers/PatchManager/MSIL/MsilTryCatchOperation.cs new file mode 100644 index 0000000..d9aecae --- /dev/null +++ b/Torch/Managers/PatchManager/MSIL/MsilTryCatchOperation.cs @@ -0,0 +1,54 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Reflection; +using System.Text; +using System.Threading.Tasks; + +namespace Torch.Managers.PatchManager.MSIL +{ + /// + /// Represents a try/catch block operation type + /// + public enum MsilTryCatchOperationType + { + // TryCatchBlockIL: + // var exBlock = ILGenerator.BeginExceptionBlock(); + // try{ + // ILGenerator.BeginCatchBlock(typeof(Exception)); + // } catch(Exception e) { + // ILGenerator.BeginCatchBlock(null); + // } catch { + // ILGenerator.BeginFinallyBlock(); + // }finally { + // ILGenerator.EndExceptionBlock(); + // } + BeginExceptionBlock, + BeginCatchBlock, + BeginFinallyBlock, + EndExceptionBlock + } + + /// + /// Represents a try catch operation. + /// + public class MsilTryCatchOperation + { + /// + /// Operation type + /// + public readonly MsilTryCatchOperationType Type; + /// + /// Type caught by this operation, or null if none. + /// + public readonly Type CatchType; + + public MsilTryCatchOperation(MsilTryCatchOperationType op, Type caughtType = null) + { + Type = op; + if (caughtType != null && op != MsilTryCatchOperationType.BeginCatchBlock) + throw new ArgumentException($"Can't use caught type with operation type {op}", nameof(caughtType)); + CatchType = caughtType; + } + } +} diff --git a/Torch/Managers/PatchManager/MethodRewritePattern.cs b/Torch/Managers/PatchManager/MethodRewritePattern.cs index 3696371..83ce8eb 100644 --- a/Torch/Managers/PatchManager/MethodRewritePattern.cs +++ b/Torch/Managers/PatchManager/MethodRewritePattern.cs @@ -36,9 +36,11 @@ namespace Torch.Managers.PatchManager private int _hasChanges = 0; - internal bool HasChanges() + internal bool HasChanges(bool reset = false) { - return Interlocked.Exchange(ref _hasChanges, 0) != 0; + if (reset) + return Interlocked.Exchange(ref _hasChanges, 0) != 0; + return _hasChanges != 0; } /// @@ -154,10 +156,33 @@ namespace Torch.Managers.PatchManager /// public MethodRewriteSet Transpilers { get; } /// + /// Methods capable of accepting one and returing another, modified. + /// Runs after prefixes, suffixes, and normal transpilers are applied. + /// + public MethodRewriteSet PostTranspilers { get; } + /// /// Methods run after the original method has run. /// public MethodRewriteSet Suffixes { get; } + /// + /// Should the resulting MSIL of the transpile operation be printed. + /// + public bool PrintMsil + { + get => _parent?.PrintMsil ?? _printMsilBacking; + set + { + if (_parent != null) + _parent.PrintMsil = value; + else + _printMsilBacking = value; + } + } + private bool _printMsilBacking; + + private readonly MethodRewritePattern _parent; + /// /// /// @@ -166,7 +191,9 @@ namespace Torch.Managers.PatchManager { Prefixes = new MethodRewriteSet(parentPattern?.Prefixes); Transpilers = new MethodRewriteSet(parentPattern?.Transpilers); + PostTranspilers = new MethodRewriteSet(parentPattern?.PostTranspilers); Suffixes = new MethodRewriteSet(parentPattern?.Suffixes); + _parent = parentPattern; } } } diff --git a/Torch/Managers/PatchManager/PatchManager.cs b/Torch/Managers/PatchManager/PatchManager.cs index fce9991..0062a66 100644 --- a/Torch/Managers/PatchManager/PatchManager.cs +++ b/Torch/Managers/PatchManager/PatchManager.cs @@ -1,7 +1,9 @@ using System; using System.Collections.Generic; +using System.Linq; using System.Reflection; using System.Runtime.CompilerServices; +using System.Threading; using NLog; using Torch.API; using Torch.Managers.PatchManager.Transpile; @@ -148,12 +150,43 @@ namespace Torch.Managers.PatchManager return count; } + + private static int _finishedPatchCount, _dirtyPatchCount; + + private static void DoCommit(DecoratedMethod method) + { + if (!method.HasChanged()) + return; + method.Commit(); + int value = Interlocked.Increment(ref _finishedPatchCount); + var actualPercentage = (value * 100) / _dirtyPatchCount; + var currentPrintGroup = actualPercentage / 10; + var prevPrintGroup = (value - 1) * 10 / _dirtyPatchCount; + if (currentPrintGroup != prevPrintGroup && value >= 1) + { + _log.Info($"Patched {value}/{_dirtyPatchCount}. ({actualPercentage:D2}%)"); + } + } + /// internal static void CommitInternal() { lock (_rewritePatterns) + { + _log.Info("Patching begins..."); + _finishedPatchCount = 0; + _dirtyPatchCount = _rewritePatterns.Values.Sum(x => x.HasChanged() ? 1 : 0); +#if true + ParallelTasks.Parallel.ForEach(_rewritePatterns.Values.Where(x => !x.PrintMsil), DoCommit); + foreach (DecoratedMethod m in _rewritePatterns.Values.Where(x => x.PrintMsil)) + DoCommit(m); +#else foreach (DecoratedMethod m in _rewritePatterns.Values) - m.Commit(); + DoCommit(m); +#endif + _log.Info("Patching done"); + + } } /// @@ -164,12 +197,9 @@ namespace Torch.Managers.PatchManager CommitInternal(); } - /// - /// Commits any existing patches. - /// + /// public override void Attach() { - Commit(); } /// diff --git a/Torch/Managers/PatchManager/PatchUtilities.cs b/Torch/Managers/PatchManager/PatchUtilities.cs new file mode 100644 index 0000000..0320742 --- /dev/null +++ b/Torch/Managers/PatchManager/PatchUtilities.cs @@ -0,0 +1,40 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Reflection; +using System.Reflection.Emit; +using System.Text; +using System.Threading.Tasks; +using Torch.Managers.PatchManager.MSIL; +using Torch.Managers.PatchManager.Transpile; + +namespace Torch.Managers.PatchManager +{ + /// + /// Functions that let you read and write MSIL to methods directly. + /// + public class PatchUtilities + { + /// + /// Gets the content of a method as an instruction stream + /// + /// Method to examine + /// instruction stream + public static IEnumerable ReadInstructions(MethodBase method) + { + var context = new MethodContext(method); + context.Read(); + return context.Instructions; + } + + /// + /// Writes the given instruction stream to the given IL generator, fixing short branch instructions. + /// + /// Instruction stream + /// Output + public static void EmitInstructions(IEnumerable insn, LoggingIlGenerator generator) + { + MethodTranspiler.EmitMethod(insn.ToList(), generator); + } + } +} diff --git a/Torch/Managers/PatchManager/Transpile/LoggingILGenerator.cs b/Torch/Managers/PatchManager/Transpile/LoggingILGenerator.cs index 1835ac2..2de3b21 100644 --- a/Torch/Managers/PatchManager/Transpile/LoggingILGenerator.cs +++ b/Torch/Managers/PatchManager/Transpile/LoggingILGenerator.cs @@ -24,20 +24,23 @@ namespace Torch.Managers.PatchManager.Transpile /// public ILGenerator Backing { get; } + private readonly LogLevel _level; + /// /// Creates a new logging IL generator backed by the given generator. /// /// Backing generator - public LoggingIlGenerator(ILGenerator backing) + public LoggingIlGenerator(ILGenerator backing, LogLevel level) { Backing = backing; + _level = level; } /// public LocalBuilder DeclareLocal(Type localType, bool isPinned = false) { LocalBuilder res = Backing.DeclareLocal(localType, isPinned); - _log?.Trace($"DclLoc\t{res.LocalIndex}\t=> {res.LocalType} {res.IsPinned}"); + _log?.Log(_level, $"DclLoc\t{res.LocalIndex}\t=> {res.LocalType} {res.IsPinned}"); return res; } @@ -45,111 +48,170 @@ namespace Torch.Managers.PatchManager.Transpile /// public void Emit(OpCode op) { - _log?.Trace($"Emit\t{op,_opcodePadding}"); + _log?.Log(_level, $"Emit\t{op,_opcodePadding}"); Backing.Emit(op); } /// public void Emit(OpCode op, LocalBuilder arg) { - _log?.Trace($"Emit\t{op,_opcodePadding} Local:{arg.LocalIndex}/{arg.LocalType}"); + _log?.Log(_level, $"Emit\t{op,_opcodePadding} Local:{arg.LocalIndex}/{arg.LocalType}"); + Backing.Emit(op, arg); + } + + /// + public void Emit(OpCode op, byte arg) + { + _log?.Log(_level, $"Emit\t{op,_opcodePadding} {arg}"); + Backing.Emit(op, arg); + } + + /// + public void Emit(OpCode op, short arg) + { + _log?.Log(_level, $"Emit\t{op,_opcodePadding} {arg}"); Backing.Emit(op, arg); } /// public void Emit(OpCode op, int arg) { - _log?.Trace($"Emit\t{op,_opcodePadding} {arg}"); + _log?.Log(_level, $"Emit\t{op,_opcodePadding} {arg}"); Backing.Emit(op, arg); } /// public void Emit(OpCode op, long arg) { - _log?.Trace($"Emit\t{op,_opcodePadding} {arg}"); + _log?.Log(_level, $"Emit\t{op,_opcodePadding} {arg}"); Backing.Emit(op, arg); } /// public void Emit(OpCode op, float arg) { - _log?.Trace($"Emit\t{op,_opcodePadding} {arg}"); + _log?.Log(_level, $"Emit\t{op,_opcodePadding} {arg}"); Backing.Emit(op, arg); } /// public void Emit(OpCode op, double arg) { - _log?.Trace($"Emit\t{op,_opcodePadding} {arg}"); + _log?.Log(_level, $"Emit\t{op,_opcodePadding} {arg}"); Backing.Emit(op, arg); } /// public void Emit(OpCode op, string arg) { - _log?.Trace($"Emit\t{op,_opcodePadding} {arg}"); + _log?.Log(_level, $"Emit\t{op,_opcodePadding} {arg}"); Backing.Emit(op, arg); } /// public void Emit(OpCode op, Type arg) { - _log?.Trace($"Emit\t{op,_opcodePadding} {arg}"); + _log?.Log(_level, $"Emit\t{op,_opcodePadding} {arg}"); Backing.Emit(op, arg); } /// public void Emit(OpCode op, FieldInfo arg) { - _log?.Trace($"Emit\t{op,_opcodePadding} {arg}"); + _log?.Log(_level, $"Emit\t{op,_opcodePadding} {arg}"); Backing.Emit(op, arg); } /// public void Emit(OpCode op, MethodInfo arg) { - _log?.Trace($"Emit\t{op,_opcodePadding} {arg}"); + _log?.Log(_level, $"Emit\t{op,_opcodePadding} {arg}"); Backing.Emit(op, arg); } #pragma warning disable 649 - [ReflectedGetter(Name="m_label")] + [ReflectedGetter(Name = "m_label")] private static Func _labelID; #pragma warning restore 649 /// public void Emit(OpCode op, Label arg) { - _log?.Trace($"Emit\t{op,_opcodePadding}\tL:{_labelID.Invoke(arg)}"); + _log?.Log(_level, $"Emit\t{op,_opcodePadding}\tL:{_labelID.Invoke(arg)}"); Backing.Emit(op, arg); } /// public void Emit(OpCode op, Label[] arg) { - _log?.Trace($"Emit\t{op,_opcodePadding}\t{string.Join(", ", arg.Select(x => "L:" + _labelID.Invoke(x)))}"); + _log?.Log(_level, $"Emit\t{op,_opcodePadding}\t{string.Join(", ", arg.Select(x => "L:" + _labelID.Invoke(x)))}"); Backing.Emit(op, arg); } /// public void Emit(OpCode op, SignatureHelper arg) { - _log?.Trace($"Emit\t{op,_opcodePadding} {arg}"); + _log?.Log(_level, $"Emit\t{op,_opcodePadding} {arg}"); Backing.Emit(op, arg); } /// public void Emit(OpCode op, ConstructorInfo arg) { - _log?.Trace($"Emit\t{op,_opcodePadding} {arg}"); + _log?.Log(_level, $"Emit\t{op,_opcodePadding} {arg}"); Backing.Emit(op, arg); } + + #region Exceptions + /// + public Label BeginExceptionBlock() + { + _log?.Log(_level, $"BeginExceptionBlock"); + return Backing.BeginExceptionBlock(); + } + + /// + public void BeginCatchBlock(Type caught) + { + _log?.Log(_level, $"BeginCatchBlock {caught}"); + Backing.BeginCatchBlock(caught); + } + + /// + public void BeginExceptFilterBlock() + { + _log?.Log(_level, $"BeginExceptFilterBlock"); + Backing.BeginExceptFilterBlock(); + } + + /// + public void BeginFaultBlock() + { + _log?.Log(_level, $"BeginFaultBlock"); + Backing.BeginFaultBlock(); + } + + /// + public void BeginFinallyBlock() + { + _log?.Log(_level, $"BeginFinallyBlock"); + Backing.BeginFinallyBlock(); + } + + /// + public void EndExceptionBlock() + { + _log?.Log(_level, $"EndExceptionBlock"); + Backing.EndExceptionBlock(); + } + #endregion + /// public void MarkLabel(Label label) { - _log?.Trace($"MkLbl\tL:{_labelID.Invoke(label)}"); + _log?.Log(_level, $"MkLbl\tL:{_labelID.Invoke(label)}"); Backing.MarkLabel(label); } @@ -166,7 +228,7 @@ namespace Torch.Managers.PatchManager.Transpile [Conditional("DEBUG")] public void EmitComment(string comment) { - _log?.Trace($"// {comment}"); + _log?.Log(_level, $"// {comment}"); } } #pragma warning restore 162 diff --git a/Torch/Managers/PatchManager/Transpile/MethodContext.cs b/Torch/Managers/PatchManager/Transpile/MethodContext.cs index 7d316f7..aa7b203 100644 --- a/Torch/Managers/PatchManager/Transpile/MethodContext.cs +++ b/Torch/Managers/PatchManager/Transpile/MethodContext.cs @@ -16,6 +16,7 @@ namespace Torch.Managers.PatchManager.Transpile private static readonly Logger _log = LogManager.GetCurrentClassLogger(); public readonly MethodBase Method; + public readonly MethodBody MethodBody; private readonly byte[] _msilBytes; internal Dictionary Labels { get; } = new Dictionary(); @@ -35,14 +36,32 @@ namespace Torch.Managers.PatchManager.Transpile public MethodContext(MethodBase method) { Method = method; - _msilBytes = Method.GetMethodBody().GetILAsByteArray(); + MethodBody = method.GetMethodBody(); + Debug.Assert(MethodBody != null, "Method body is null"); + _msilBytes = MethodBody.GetILAsByteArray(); TokenResolver = new NormalTokenResolver(method); } + + +#pragma warning disable 649 + [ReflectedMethod(Name = "BakeByteArray")] + private static Func _ilGeneratorBakeByteArray; +#pragma warning restore 649 + + public MethodContext(DynamicMethod method) + { + Method = null; + MethodBody = null; + _msilBytes = _ilGeneratorBakeByteArray(method.GetILGenerator()); + TokenResolver = new DynamicMethodTokenResolver(method); + } + public void Read() { ReadInstructions(); ResolveLabels(); + ResolveCatchClauses(); } private void ReadInstructions() @@ -53,14 +72,19 @@ namespace Torch.Managers.PatchManager.Transpile using (var reader = new BinaryReader(memory)) while (memory.Length > memory.Position) { - var opcodeOffset = (int) memory.Position; + var opcodeOffset = (int)memory.Position; var instructionValue = (short)memory.ReadByte(); if (Prefixes.Contains(instructionValue)) { instructionValue = (short)((instructionValue << 8) | memory.ReadByte()); } if (!OpCodeLookup.TryGetValue(instructionValue, out OpCode opcode)) - throw new Exception($"Unknown opcode {instructionValue:X}"); + { + var msg = $"Unknown opcode {instructionValue:X}"; + _log.Error(msg); + Debug.Assert(false, msg); + continue; + } if (opcode.Size != memory.Position - opcodeOffset) throw new Exception($"Opcode said it was {opcode.Size} but we read {memory.Position - opcodeOffset}"); var instruction = new MsilInstruction(opcode) @@ -72,75 +96,49 @@ namespace Torch.Managers.PatchManager.Transpile } } + private void ResolveCatchClauses() + { + if (MethodBody == null) + return; + foreach (ExceptionHandlingClause clause in MethodBody.ExceptionHandlingClauses) + { + var beginInstruction = FindInstruction(clause.TryOffset); + var catchInstruction = FindInstruction(clause.HandlerOffset); + var finalInstruction = FindInstruction(clause.HandlerOffset + clause.HandlerLength); + beginInstruction.TryCatchOperation = new MsilTryCatchOperation(MsilTryCatchOperationType.BeginExceptionBlock); + if ((clause.Flags & ExceptionHandlingClauseOptions.Clause) != 0) + catchInstruction.TryCatchOperation = new MsilTryCatchOperation(MsilTryCatchOperationType.BeginCatchBlock, clause.CatchType); + else if ((clause.Flags & ExceptionHandlingClauseOptions.Finally) != 0) + catchInstruction.TryCatchOperation = new MsilTryCatchOperation(MsilTryCatchOperationType.BeginFinallyBlock); + finalInstruction.TryCatchOperation = new MsilTryCatchOperation(MsilTryCatchOperationType.EndExceptionBlock); + } + } + + private MsilInstruction FindInstruction(int offset) + { + int min = 0, max = _instructions.Count; + while (min != max) + { + int mid = (min + max) / 2; + if (_instructions[mid].Offset < offset) + min = mid + 1; + else + max = mid; + } + return min >= 0 && min < _instructions.Count ? _instructions[min] : null; + } + private void ResolveLabels() { foreach (var label in Labels) { - int min = 0, max = _instructions.Count; - while (min != max) - { - int mid = (min + max) / 2; - if (_instructions[mid].Offset < label.Key) - min = mid + 1; - else - max = mid; - } -#if DEBUG - if (min >= _instructions.Count || min < 0) - { - _log.Trace( - $"Want offset {label.Key} for {label.Value}, instruction offsets at\n {string.Join("\n", _instructions.Select(x => $"IL_{x.Offset:X4} {x}"))}"); - } - MsilInstruction prevInsn = min > 0 ? _instructions[min - 1] : null; - if ((prevInsn == null || prevInsn.Offset >= label.Key) || - _instructions[min].Offset < label.Key) - _log.Error($"Label {label.Value} wanted {label.Key} but instruction is at {_instructions[min].Offset}. Previous instruction is at {prevInsn?.Offset ?? -1}"); -#endif - _instructions[min]?.Labels?.Add(label.Value); + MsilInstruction target = FindInstruction(label.Key); + Debug.Assert(target != null, $"No label for offset {label.Key}"); + target?.Labels?.Add(label.Value); } } - [Conditional("DEBUG")] - public void CheckIntegrity() - { - var entryStackCount = new Dictionary>(); - var currentStackSize = 0; - foreach (MsilInstruction insn in _instructions) - { - // I don't want to deal with this, so I won't - if (insn.OpCode == OpCodes.Br || insn.OpCode == OpCodes.Br_S || insn.OpCode == OpCodes.Jmp || - insn.OpCode == OpCodes.Leave || insn.OpCode == OpCodes.Leave_S) - break; - foreach (MsilLabel label in insn.Labels) - if (entryStackCount.TryGetValue(label, out Dictionary dict)) - dict.Add(insn, currentStackSize); - else - (entryStackCount[label] = new Dictionary()).Add(insn, currentStackSize); - - currentStackSize += insn.StackChange(); - - if (insn.Operand is MsilOperandBrTarget br) - if (entryStackCount.TryGetValue(br.Target, out Dictionary dict)) - dict.Add(insn, currentStackSize); - else - (entryStackCount[br.Target] = new Dictionary()).Add(insn, currentStackSize); - } - foreach (KeyValuePair> label in entryStackCount) - { - if (label.Value.Values.Aggregate(new HashSet(), (a, b) => - { - a.Add(b); - return a; - }).Count > 1) - { - _log.Warn($"Label {label.Key} has multiple entry stack counts"); - foreach (KeyValuePair kv in label.Value) - _log.Warn($"{kv.Key.Offset:X4} {kv.Key} => {kv.Value}"); - } - } - } - public string ToHumanMsil() { return string.Join("\n", _instructions.Select(x => $"IL_{x.Offset:X4}: {x.StackChange():+0;-#} {x}")); diff --git a/Torch/Managers/PatchManager/Transpile/MethodTranspiler.cs b/Torch/Managers/PatchManager/Transpile/MethodTranspiler.cs index e2db491..b989e49 100644 --- a/Torch/Managers/PatchManager/Transpile/MethodTranspiler.cs +++ b/Torch/Managers/PatchManager/Transpile/MethodTranspiler.cs @@ -1,7 +1,10 @@ using System; using System.Collections.Generic; +using System.Linq; using System.Reflection; using System.Reflection.Emit; +using System.Text; +using System.Windows.Documents; using NLog; using Torch.Managers.PatchManager.MSIL; @@ -11,15 +14,19 @@ namespace Torch.Managers.PatchManager.Transpile { public static readonly Logger _log = LogManager.GetCurrentClassLogger(); - internal static void Transpile(MethodBase baseMethod, Func localCreator, IEnumerable transpilers, LoggingIlGenerator output, Label? retLabel) + internal static IEnumerable Transpile(MethodBase baseMethod, Func localCreator, + IEnumerable transpilers, MsilLabel retLabel) { var context = new MethodContext(baseMethod); context.Read(); - context.CheckIntegrity(); - // _log.Trace("Input Method:"); - // _log.Trace(context.ToHumanMsil); + // IntegrityAnalysis(LogLevel.Trace, context.Instructions); + return Transpile(baseMethod, context.Instructions, localCreator, transpilers, retLabel); + } - var methodContent = (IEnumerable)context.Instructions; + internal static IEnumerable Transpile(MethodBase baseMethod, IEnumerable methodContent, + Func localCreator, + IEnumerable transpilers, MsilLabel retLabel) + { foreach (MethodInfo transpiler in transpilers) { var paramList = new List(); @@ -27,6 +34,8 @@ namespace Torch.Managers.PatchManager.Transpile { if (parameter.Name.Equals("__methodBody")) paramList.Add(baseMethod.GetMethodBody()); + else if (parameter.Name.Equals("__methodBase")) + paramList.Add(baseMethod); else if (parameter.Name.Equals("__localCreator")) paramList.Add(localCreator); else if (parameter.ParameterType == typeof(IEnumerable)) @@ -37,20 +46,157 @@ namespace Torch.Managers.PatchManager.Transpile } methodContent = (IEnumerable)transpiler.Invoke(null, paramList.ToArray()); } - methodContent = FixBranchAndReturn(methodContent, retLabel); - foreach (var k in methodContent) - k.Emit(output); + return FixBranchAndReturn(methodContent, retLabel); } - private static IEnumerable FixBranchAndReturn(IEnumerable insn, Label? retTarget) + internal static void EmitMethod(IReadOnlyList instructions, LoggingIlGenerator target) + { + for (var i = 0; i < instructions.Count; i++) + { + MsilInstruction il = instructions[i]; + if (il.TryCatchOperation != null) + switch (il.TryCatchOperation.Type) + { + case MsilTryCatchOperationType.BeginExceptionBlock: + target.BeginExceptionBlock(); + break; + case MsilTryCatchOperationType.BeginCatchBlock: + target.BeginCatchBlock(il.TryCatchOperation.CatchType); + break; + case MsilTryCatchOperationType.BeginFinallyBlock: + target.BeginFinallyBlock(); + break; + case MsilTryCatchOperationType.EndExceptionBlock: + target.EndExceptionBlock(); + break; + default: + throw new ArgumentOutOfRangeException(); + } + + foreach (MsilLabel label in il.Labels) + target.MarkLabel(label.LabelFor(target)); + + MsilInstruction ilNext = i < instructions.Count - 1 ? instructions[i + 1] : null; + + // Leave opcodes emitted by these: + if (il.OpCode == OpCodes.Endfilter && ilNext?.TryCatchOperation?.Type == + MsilTryCatchOperationType.BeginCatchBlock) + continue; + if ((il.OpCode == OpCodes.Leave || il.OpCode == OpCodes.Leave_S) && + (ilNext?.TryCatchOperation?.Type == MsilTryCatchOperationType.EndExceptionBlock || + ilNext?.TryCatchOperation?.Type == MsilTryCatchOperationType.BeginCatchBlock || + ilNext?.TryCatchOperation?.Type == MsilTryCatchOperationType.BeginFinallyBlock)) + continue; + if ((il.OpCode == OpCodes.Leave || il.OpCode == OpCodes.Leave_S || il.OpCode == OpCodes.Endfinally) && + ilNext?.TryCatchOperation?.Type == MsilTryCatchOperationType.EndExceptionBlock) + continue; + + if (il.Operand != null) + il.Operand.Emit(target); + else + target.Emit(il.OpCode); + } + } + + /// + /// Analyzes the integrity of a set of instructions. + /// + /// default logging level + /// instructions + public static void IntegrityAnalysis(LogLevel level, IReadOnlyList instructions) + { + var targets = new Dictionary(); + for (var i = 0; i < instructions.Count; i++) + foreach (var label in instructions[i].Labels) + { + if (targets.TryGetValue(label, out var other)) + _log.Warn($"Label {label} is applied to ({i}: {instructions[i]}) and ({other}: {instructions[other]})"); + targets[label] = i; + } + + var reparsed = new HashSet(); + var labelStackSize = new Dictionary>(); + var stack = 0; + var unreachable = false; + var data = new StringBuilder[instructions.Count]; + for (var i = 0; i < instructions.Count; i++) + { + var k = instructions[i]; + var line = (data[i] ?? (data[i] = new StringBuilder())).Clear(); + if (!unreachable) + { + foreach (var label in k.Labels) + { + if (!labelStackSize.TryGetValue(label, out Dictionary otherStack)) + labelStackSize[label] = otherStack = new Dictionary(); + + otherStack[i - 1] = stack; + if (otherStack.Values.Distinct().Count() > 1 || (otherStack.Count == 1 && !otherStack.ContainsValue(stack))) + { + string otherDesc = string.Join(", ", otherStack.Select(x => $"{x.Key:X4}=>{x.Value}")); + line.AppendLine($"WARN// | Label {label} has multiple entry stack sizes ({otherDesc})"); + } + } + } + foreach (var label in k.Labels) + { + if (!labelStackSize.TryGetValue(label, out var entry)) + continue; + string desc = string.Join(", ", entry.Select(x => $"{x.Key:X4}=>{x.Value}")); + line.AppendLine($"// \\/ Label {label} has stack sizes {desc}"); + if (unreachable && entry.Any()) + { + stack = entry.Values.First(); + unreachable = false; + } + } + stack += k.StackChange(); + if (k.TryCatchOperation != null) + line.AppendLine($"// .{k.TryCatchOperation.Type} ({k.TryCatchOperation.CatchType})"); + line.AppendLine($"{i:X4} S:{stack:D2} dS:{k.StackChange():+0;-#}\t{k}" + (unreachable ? "\t// UNREACHABLE" : "")); + if (k.Operand is MsilOperandBrTarget br) + { + if (!targets.ContainsKey(br.Target)) + line.AppendLine($"WARN// ^ Unknown target {br.Target}"); + + if (!labelStackSize.TryGetValue(br.Target, out Dictionary otherStack)) + labelStackSize[br.Target] = otherStack = new Dictionary(); + + otherStack[i] = stack; + if (otherStack.Values.Distinct().Count() > 1 || (otherStack.Count == 1 && !otherStack.ContainsValue(stack))) + { + string otherDesc = string.Join(", ", otherStack.Select(x => $"{x.Key:X4}=>{x.Value}")); + line.AppendLine($"WARN// ^ Label {br.Target} has multiple entry stack sizes ({otherDesc})"); + } + if (targets.TryGetValue(br.Target, out var target) && target < i && reparsed.Add(br.Target)) + { + i = target - 1; + unreachable = false; + continue; + } + } + if (k.OpCode == OpCodes.Br || k.OpCode == OpCodes.Br_S) + unreachable = true; + } + foreach (var k in data) + foreach (var line in k.ToString().Split('\n')) + { + if (string.IsNullOrWhiteSpace(line)) + continue; + if (line.StartsWith("WARN", StringComparison.OrdinalIgnoreCase)) + _log.Warn(line.Substring(4).Trim()); + else + _log.Log(level, line.Trim()); + } + } + + private static IEnumerable FixBranchAndReturn(IEnumerable insn, MsilLabel retTarget) { foreach (MsilInstruction i in insn) { - if (retTarget.HasValue && i.OpCode == OpCodes.Ret) + if (retTarget != null && i.OpCode == OpCodes.Ret) { - MsilInstruction j = new MsilInstruction(OpCodes.Br).InlineTarget(new MsilLabel(retTarget.Value)); - foreach (MsilLabel l in i.Labels) - j.Labels.Add(l); + var j = i.CopyWith(OpCodes.Br).InlineTarget(retTarget); _log.Trace($"Replacing {i} with {j}"); yield return j; } diff --git a/Torch/Managers/PluginManager.cs b/Torch/Managers/PluginManager.cs deleted file mode 100644 index 92112f2..0000000 --- a/Torch/Managers/PluginManager.cs +++ /dev/null @@ -1,196 +0,0 @@ -using System; -using System.Collections; -using System.Collections.Generic; -using System.Diagnostics; -using System.IO; -using System.Linq; -using System.Reflection; -using System.Threading.Tasks; -using NLog; -using Torch.API; -using Torch.API.Managers; -using Torch.API.Plugins; -using Torch.API.Session; -using Torch.Commands; -using VRage.Collections; - -namespace Torch.Managers -{ - /// - public class PluginManager : Manager, IPluginManager - { - private static Logger _log = LogManager.GetLogger(nameof(PluginManager)); - public readonly string PluginDir = Path.Combine(AppDomain.CurrentDomain.BaseDirectory, "Plugins"); - [Dependency] - private UpdateManager _updateManager; - [Dependency(Optional = true)] - private ITorchSessionManager _sessionManager; - - /// - public IList Plugins { get; } = new ObservableList(); - - public event Action PluginLoaded; - public event Action> PluginsLoaded; - - public PluginManager(ITorchBase torchInstance) : base(torchInstance) - { - if (!Directory.Exists(PluginDir)) - Directory.CreateDirectory(PluginDir); - } - - /// - /// Updates loaded plugins in parallel. - /// - public void UpdatePlugins() - { - foreach (var plugin in Plugins) - plugin.Update(); - } - - private Action _attachCommandsHandler = null; - - private void SessionStateChanged(ITorchSession session, TorchSessionState newState) - { - var cmdManager = session.Managers.GetManager(); - if (cmdManager == null) - return; - switch (newState) - { - case TorchSessionState.Loaded: - if (_attachCommandsHandler != null) - PluginLoaded -= _attachCommandsHandler; - _attachCommandsHandler = (x) => cmdManager.RegisterPluginCommands(x); - PluginLoaded += _attachCommandsHandler; - foreach (ITorchPlugin plugin in Plugins) - cmdManager.RegisterPluginCommands(plugin); - break; - case TorchSessionState.Unloading: - if (_attachCommandsHandler != null) - { - PluginLoaded -= _attachCommandsHandler; - _attachCommandsHandler = null; - } - foreach (ITorchPlugin plugin in Plugins) - { - // cmdMgr?.UnregisterPluginCommands(plugin); - } - break; - case TorchSessionState.Loading: - case TorchSessionState.Unloaded: - break; - default: - throw new ArgumentOutOfRangeException(nameof(newState), newState, null); - } - } - - /// - /// Prepares the plugin manager for loading. - /// - public override void Attach() - { - if (_sessionManager != null) - _sessionManager.SessionStateChanged += SessionStateChanged; - } - - /// - /// Unloads all plugins. - /// - public override void Detach() - { - if (_sessionManager != null) - _sessionManager.SessionStateChanged -= SessionStateChanged; - foreach (var plugin in Plugins) - plugin.Dispose(); - - Plugins.Clear(); - } - - private void DownloadPlugins() - { - var folders = Directory.GetDirectories(PluginDir); - var taskList = new List(); - - //Copy list because we don't want to modify the config. - var toDownload = Torch.Config.Plugins.ToList(); - - foreach (var folder in folders) - { - var manifestPath = Path.Combine(folder, "manifest.xml"); - if (!File.Exists(manifestPath)) - { - _log.Debug($"No manifest in {folder}, skipping"); - continue; - } - - var manifest = PluginManifest.Load(manifestPath); - toDownload.RemoveAll(x => string.Compare(manifest.Repository, x, StringComparison.InvariantCultureIgnoreCase) == 0); - taskList.Add(_updateManager.CheckAndUpdatePlugin(manifest)); - } - - foreach (var repository in toDownload) - { - var manifest = new PluginManifest { Repository = repository, Version = "0.0" }; - taskList.Add(_updateManager.CheckAndUpdatePlugin(manifest)); - } - - Task.WaitAll(taskList.ToArray()); - } - - /// - public void LoadPlugins() - { - if (Torch.Config.ShouldUpdatePlugins) - DownloadPlugins(); - else - _log.Warn("Automatic plugin updates are disabled."); - - _log.Info("Loading plugins"); - var dlls = Directory.GetFiles(PluginDir, "*.dll", SearchOption.AllDirectories); - foreach (var dllPath in dlls) - { - _log.Info($"Loading plugin {dllPath}"); - var asm = Assembly.UnsafeLoadFrom(dllPath); - - foreach (var type in asm.GetExportedTypes()) - { - if (type.GetInterfaces().Contains(typeof(ITorchPlugin))) - { - if (type.GetCustomAttribute() == null) - continue; - - try - { - var plugin = (TorchPluginBase)Activator.CreateInstance(type); - if (plugin.Id == default(Guid)) - throw new TypeLoadException($"Plugin '{type.FullName}' is missing a {nameof(PluginAttribute)}"); - - _log.Info($"Loading plugin {plugin.Name} ({plugin.Version})"); - plugin.StoragePath = Torch.Config.InstancePath; - Plugins.Add(plugin); - PluginLoaded?.Invoke(plugin); - } - catch (Exception e) - { - _log.Error($"Error loading plugin '{type.FullName}'"); - _log.Error(e); - throw; - } - } - } - } - - Plugins.ForEach(p => p.Init(Torch)); - PluginsLoaded?.Invoke(Plugins.ToList()); - } - - public IEnumerator GetEnumerator() - { - return Plugins.GetEnumerator(); - } - - IEnumerator IEnumerable.GetEnumerator() - { - return GetEnumerator(); - } - } -} diff --git a/Torch/Patches/ObjectFactoryInitPatch.cs b/Torch/Patches/ObjectFactoryInitPatch.cs new file mode 100644 index 0000000..51fc46e --- /dev/null +++ b/Torch/Patches/ObjectFactoryInitPatch.cs @@ -0,0 +1,125 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Reflection; +using System.Runtime.CompilerServices; +using System.Text; +using System.Threading.Tasks; +using Sandbox; +using Sandbox.Game.Entities; +using Torch.Utils; +using VRage.Game.Common; +using VRage.Game.Components; +using VRage.Game.Entity; +using VRage.ObjectBuilders; +using VRage.Plugins; +using VRage.Utils; + +namespace Torch.Patches +{ + + /// + /// There are places in static ctors where the registered assembly depends on the + /// or . Here we force those registrations with the proper assemblies to ensure they work correctly. + /// + internal static class ObjectFactoryInitPatch + { +#pragma warning disable 649 + [ReflectedGetter(Name = "m_objectFactory", TypeName = "Sandbox.Game.Entities.MyEntityFactory, Sandbox.Game")] + private static readonly Func> _entityFactoryObjectFactory; +#pragma warning restore 649 + + internal static void ForceRegisterAssemblies() + { + // static MyEntities() called by MySandboxGame.ForceStaticCtor + RuntimeHelpers.RunClassConstructor(typeof(MyEntities).TypeHandle); + { + MyObjectFactory factory = _entityFactoryObjectFactory(); + ObjectFactory_RegisterFromAssemblySafe(factory, typeof(MySandboxGame).Assembly); // calling assembly + ObjectFactory_RegisterFromAssemblySafe(factory, MyPlugins.GameAssembly); + ObjectFactory_RegisterFromAssemblySafe(factory, MyPlugins.SandboxAssembly); + ObjectFactory_RegisterFromAssemblySafe(factory, MyPlugins.UserAssembly); + } + + // static MyGuiManager(): + // MyGuiControlsFactory.RegisterDescriptorsFromAssembly(); + + // static MyComponentTypeFactory() called by MyComponentContainer.Add + RuntimeHelpers.RunClassConstructor(typeof(MyComponentTypeFactory).TypeHandle); + { + ComponentTypeFactory_RegisterFromAssemblySafe(typeof(MyComponentContainer).Assembly); // calling assembly + ComponentTypeFactory_RegisterFromAssemblySafe(MyPlugins.SandboxAssembly); + ComponentTypeFactory_RegisterFromAssemblySafe(MyPlugins.GameAssembly); + ComponentTypeFactory_RegisterFromAssemblySafe(MyPlugins.SandboxGameAssembly); + ComponentTypeFactory_RegisterFromAssemblySafe(MyPlugins.UserAssembly); + } + + // static MyObjectPoolManager() + // Render, so should be fine. + } + + #region MyObjectFactory Adders + private static void ObjectFactory_RegisterDescriptorSafe( + MyObjectFactory factory, TAttribute descriptor, Type type) where TAttribute : MyFactoryTagAttribute where TCreatedObjectBase : class + { + if (factory.Attributes.TryGetValue(type, out _)) + return; + if (descriptor.ObjectBuilderType != null && factory.TryGetProducedType(descriptor.ObjectBuilderType) != null) + return; + if (typeof(MyObjectBuilder_Base).IsAssignableFrom(descriptor.ProducedType) && + factory.TryGetProducedType(descriptor.ProducedType) != null) + return; + factory.RegisterDescriptor(descriptor, type); + } + + private static void ObjectFactory_RegisterFromAssemblySafe(MyObjectFactory factory, Assembly assembly) where TAttribute : MyFactoryTagAttribute where TCreatedObjectBase : class + { + if (assembly == null) + { + return; + } + foreach (Type type in assembly.GetTypes()) + { + foreach (TAttribute descriptor in type.GetCustomAttributes()) + { + ObjectFactory_RegisterDescriptorSafe(factory, descriptor, type); + } + } + } + #endregion + #region MyComponentTypeFactory Adders + + [ReflectedGetter(Name = "m_idToType", Type = typeof(MyComponentTypeFactory))] + private static Func> _componentTypeFactoryIdToType; + [ReflectedGetter(Name = "m_typeToId", Type = typeof(MyComponentTypeFactory))] + private static Func> _componentTypeFactoryTypeToId; + [ReflectedGetter(Name = "m_typeToContainerComponentType", Type = typeof(MyComponentTypeFactory))] + private static Func> _componentTypeFactoryContainerComponentType; + + private static void ComponentTypeFactory_RegisterFromAssemblySafe(Assembly assembly) + { + if (assembly == null) + return; + foreach (Type type in assembly.GetTypes()) + if (typeof(MyComponentBase).IsAssignableFrom(type)) + { + ComponentTypeFactory_AddIdSafe(type, MyStringId.GetOrCompute(type.Name)); + ComponentTypeFactory_RegisterComponentTypeAttributeSafe(type); + } + } + + private static void ComponentTypeFactory_RegisterComponentTypeAttributeSafe(Type type) + { + Type componentType = type.GetCustomAttribute(true)?.ComponentType; + if (componentType != null) + _componentTypeFactoryContainerComponentType()[type] = componentType; + } + + private static void ComponentTypeFactory_AddIdSafe(Type type, MyStringId id) + { + _componentTypeFactoryIdToType()[id] = type; + _componentTypeFactoryTypeToId()[type] = id; + } + #endregion + } +} diff --git a/Torch/Persistent.cs b/Torch/Persistent.cs index 3beff31..670d2ba 100644 --- a/Torch/Persistent.cs +++ b/Torch/Persistent.cs @@ -44,7 +44,7 @@ namespace Torch path = Path; var ser = new XmlSerializer(typeof(T)); - using (var f = File.Create(path)) + using (var f = File.CreateText(path)) { ser.Serialize(f, Data); } @@ -57,7 +57,7 @@ namespace Torch if (File.Exists(path)) { var ser = new XmlSerializer(typeof(T)); - using (var f = File.OpenRead(path)) + using (var f = File.OpenText(path)) { config.Data = (T)ser.Deserialize(f); } diff --git a/Torch/Plugins/PluginManager.cs b/Torch/Plugins/PluginManager.cs index 848a2da..7c32ad3 100644 --- a/Torch/Plugins/PluginManager.cs +++ b/Torch/Plugins/PluginManager.cs @@ -17,6 +17,7 @@ using Torch.API.Plugins; using Torch.API.Session; using Torch.Collections; using Torch.Commands; +using Torch.Utils; namespace Torch.Managers { @@ -235,11 +236,35 @@ namespace Torch.Managers if (!file.Contains(".dll", StringComparison.CurrentCultureIgnoreCase)) continue; + if (false) + { + var asm = Assembly.LoadFrom(file); + assemblies.Add(asm); + TorchBase.RegisterAuxAssembly(asm); + continue; + } + using (var stream = File.OpenRead(file)) { - var data = new byte[stream.Length]; - stream.Read(data, 0, data.Length); + var data = stream.ReadToEnd(); +#if DEBUG + byte[] symbol = null; + var symbolPath = Path.Combine(Path.GetDirectoryName(file) ?? ".", + Path.GetFileNameWithoutExtension(file) + ".pdb"); + if (File.Exists(symbolPath)) + try + { + using (var symbolStream = File.OpenRead(symbolPath)) + symbol = symbolStream.ReadToEnd(); + } + catch (Exception e) + { + _log.Warn(e, $"Failed to read debugging symbols from {symbolPath}"); + } + Assembly asm = symbol != null ? Assembly.Load(data, symbol) : Assembly.Load(data); +#else Assembly asm = Assembly.Load(data); +#endif assemblies.Add(asm); TorchBase.RegisterAuxAssembly(asm); } @@ -266,11 +291,29 @@ namespace Torch.Managers if (!entry.Name.Contains(".dll", StringComparison.CurrentCultureIgnoreCase)) continue; + using (var stream = entry.Open()) { - var data = new byte[entry.Length]; - stream.Read(data, 0, data.Length); + var data = stream.ReadToEnd((int)entry.Length); +#if DEBUG + byte[] symbol = null; + var symbolEntryName = entry.FullName.Substring(0, entry.FullName.Length - "dll".Length) + "pdb"; + var symbolEntry = zipFile.GetEntry(symbolEntryName); + if (symbolEntry != null) + try + { + using (var symbolStream = symbolEntry.Open()) + symbol = symbolStream.ReadToEnd((int)symbolEntry.Length); + } + catch (Exception e) + { + _log.Warn(e, $"Failed to read debugging symbols from {path}:{symbolEntryName}"); + } + Assembly asm = symbol != null ? Assembly.Load(data, symbol) : Assembly.Load(data); +#else Assembly asm = Assembly.Load(data); +#endif + assemblies.Add(asm); TorchBase.RegisterAuxAssembly(asm); } } diff --git a/Torch/Torch.csproj b/Torch/Torch.csproj index ab4a85f..a77ba4a 100644 --- a/Torch/Torch.csproj +++ b/Torch/Torch.csproj @@ -156,6 +156,7 @@ + @@ -183,15 +184,18 @@ + + + @@ -216,12 +220,24 @@ + + + + + + + + + + + + - + diff --git a/Torch/TorchBase.cs b/Torch/TorchBase.cs index e000fd7..688e49a 100644 --- a/Torch/TorchBase.cs +++ b/Torch/TorchBase.cs @@ -15,6 +15,7 @@ using Sandbox.Game; using Sandbox.Game.Multiplayer; using Sandbox.Game.Screens.Helpers; using Sandbox.Game.World; +using Sandbox.Graphics.GUI; using Sandbox.ModAPI; using SpaceEngineers.Game; using Torch.API; @@ -32,6 +33,8 @@ using Torch.Session; using VRage.Collections; using VRage.FileSystem; using VRage.Game; +using VRage.Game.Common; +using VRage.Game.Components; using VRage.Game.ObjectBuilder; using VRage.ObjectBuilders; using VRage.Plugins; @@ -251,6 +254,18 @@ namespace Torch Debug.Assert(!_init, "Torch instance is already initialized."); SpaceEngineersGame.SetupBasicGameInfo(); SpaceEngineersGame.SetupPerGameSettings(); + // If the attached assemblies change (MySandboxGame.ctor => MySandboxGame.ParseArgs => MyPlugins.RegisterFromArgs) + // attach assemblies to object factories again. + ObjectFactoryInitPatch.ForceRegisterAssemblies(); + GameStateChanged += (game, state) => + { + if (state == TorchGameState.Created) + { + ObjectFactoryInitPatch.ForceRegisterAssemblies(); + // safe to commit here; all important static ctors have run + PatchManager.CommitInternal(); + } + }; Debug.Assert(MyPerGameSettings.BasicGameInfo.GameVersion != null, "MyPerGameSettings.BasicGameInfo.GameVersion != null"); GameVersion = new Version(new MyVersion(MyPerGameSettings.BasicGameInfo.GameVersion.Value).FormattedText.ToString().Replace("_", ".")); @@ -281,6 +296,10 @@ namespace Torch Managers.GetManager().LoadPlugins(); Managers.Attach(); _init = true; + + if (GameState >= TorchGameState.Created && GameState < TorchGameState.Unloading) + // safe to commit here; all important static ctors have run + PatchManager.CommitInternal(); } private void OnSessionLoading() diff --git a/Torch/Utils/MiscExtensions.cs b/Torch/Utils/MiscExtensions.cs new file mode 100644 index 0000000..c3a0362 --- /dev/null +++ b/Torch/Utils/MiscExtensions.cs @@ -0,0 +1,53 @@ +using System; +using System.Collections.Generic; +using System.IO; +using System.Linq; +using System.Text; +using System.Threading; +using System.Threading.Tasks; + +namespace Torch.Utils +{ + public static class MiscExtensions + { + private static readonly ThreadLocal> _streamBuffer = new ThreadLocal>(() => new WeakReference(null)); + + private static long LengthSafe(this Stream stream) + { + try + { + return stream.Length; + } + catch + { + return 512; + } + } + + public static byte[] ReadToEnd(this Stream stream, int optionalDataLength = -1) + { + byte[] buffer; + if (!_streamBuffer.Value.TryGetTarget(out buffer)) + buffer = new byte[stream.LengthSafe()]; + var initialBufferSize = optionalDataLength > 0 ? optionalDataLength : stream.LengthSafe(); + if (buffer.Length < initialBufferSize) + buffer = new byte[initialBufferSize]; + if (buffer.Length < 1024) + buffer = new byte[1024]; + var streamPosition = 0; + while (true) + { + if (buffer.Length == streamPosition) + Array.Resize(ref buffer, Math.Max((int)stream.LengthSafe(), buffer.Length * 2)); + int count = stream.Read(buffer, streamPosition, buffer.Length - streamPosition); + if (count == 0) + break; + streamPosition += count; + } + var result = new byte[streamPosition]; + Array.Copy(buffer, 0, result, 0, result.Length); + _streamBuffer.Value.SetTarget(buffer); + return result; + } + } +} diff --git a/Torch/Utils/Reflected/ReflectedEventReplaceAttribute.cs b/Torch/Utils/Reflected/ReflectedEventReplaceAttribute.cs new file mode 100644 index 0000000..948614d --- /dev/null +++ b/Torch/Utils/Reflected/ReflectedEventReplaceAttribute.cs @@ -0,0 +1,51 @@ +using System; + +namespace Torch.Utils +{ + /// + /// Attribute used to indicate that the the given field, of type ]]>, should be filled with + /// a function used to create a new event replacer. + /// + [AttributeUsage(AttributeTargets.Field)] + public class ReflectedEventReplaceAttribute : Attribute + { + /// + /// Type that the event is declared in + /// + public Type EventDeclaringType { get; set; } + /// + /// Name of the event + /// + public string EventName { get; set; } + + /// + /// Type that the method to replace is declared in + /// + public Type TargetDeclaringType { get; set; } + /// + /// Name of the method to replace + /// + public string TargetName { get; set; } + /// + /// Optional parameters of the method to replace. Null to ignore. + /// + public Type[] TargetParameters { get; set; } = null; + + /// + /// Creates a reflected event replacer attribute to, for the event defined as eventName in eventDeclaringType, + /// replace the method defined as targetName in targetDeclaringType with a custom callback. + /// + /// Type the event is declared in + /// Name of the event + /// Type the method to remove is declared in + /// Name of the method to remove + public ReflectedEventReplaceAttribute(Type eventDeclaringType, string eventName, Type targetDeclaringType, + string targetName) + { + EventDeclaringType = eventDeclaringType; + EventName = eventName; + TargetDeclaringType = targetDeclaringType; + TargetName = targetName; + } + } +} \ No newline at end of file diff --git a/Torch/Utils/Reflected/ReflectedEventReplacer.cs b/Torch/Utils/Reflected/ReflectedEventReplacer.cs new file mode 100644 index 0000000..f100e33 --- /dev/null +++ b/Torch/Utils/Reflected/ReflectedEventReplacer.cs @@ -0,0 +1,135 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Reflection; + +namespace Torch.Utils +{ + /// + /// Instance of statefully replacing and restoring the callbacks of an event. + /// + public class ReflectedEventReplacer + { + private const BindingFlags BindFlagAll = BindingFlags.Static | + BindingFlags.Instance | + BindingFlags.Public | + BindingFlags.NonPublic; + + private object _instance; + private Func> _backingStoreReader; + private Action _callbackAdder; + private Action _callbackRemover; + private readonly ReflectedEventReplaceAttribute _attributes; + private readonly HashSet _registeredCallbacks = new HashSet(); + private readonly MethodInfo _targetMethodInfo; + + internal ReflectedEventReplacer(ReflectedEventReplaceAttribute attr) + { + _attributes = attr; + FieldInfo backingStore = GetEventBackingField(attr.EventName, attr.EventDeclaringType); + if (backingStore == null) + throw new ArgumentException($"Unable to find backing field for event {attr.EventDeclaringType.FullName}#{attr.EventName}"); + EventInfo evtInfo = ReflectedManager.GetFieldPropRecursive(attr.EventDeclaringType, attr.EventName, BindFlagAll, (a, b, c) => a.GetEvent(b, c)); + if (evtInfo == null) + throw new ArgumentException($"Unable to find event info for event {attr.EventDeclaringType.FullName}#{attr.EventName}"); + _backingStoreReader = () => GetEventsInternal(_instance, backingStore); + _callbackAdder = (x) => evtInfo.AddEventHandler(_instance, x); + _callbackRemover = (x) => evtInfo.RemoveEventHandler(_instance, x); + if (attr.TargetParameters == null) + { + _targetMethodInfo = attr.TargetDeclaringType.GetMethod(attr.TargetName, BindFlagAll); + if (_targetMethodInfo == null) + throw new ArgumentException($"Unable to find method {attr.TargetDeclaringType.FullName}#{attr.TargetName} to replace"); + } + else + { + _targetMethodInfo = + attr.TargetDeclaringType.GetMethod(attr.TargetName, BindFlagAll, null, attr.TargetParameters, null); + if (_targetMethodInfo == null) + throw new ArgumentException($"Unable to find method {attr.TargetDeclaringType.FullName}#{attr.TargetName}){string.Join(", ", attr.TargetParameters.Select(x => x.FullName))}) to replace"); + } + } + + /// + /// Test that this replacement can be performed. + /// + /// The instance to operate on, or null if static + /// true if possible, false if unsuccessful + public bool Test(object instance) + { + _instance = instance; + _registeredCallbacks.Clear(); + foreach (Delegate callback in _backingStoreReader.Invoke()) + if (callback.Method == _targetMethodInfo) + _registeredCallbacks.Add(callback); + + return _registeredCallbacks.Count > 0; + } + + private Delegate _newCallback; + + /// + /// Removes the target callback defined in the attribute and replaces it with the provided callback. + /// + /// The new event callback + /// The instance to operate on, or null if static + public void Replace(Delegate newCallback, object instance) + { + _instance = instance; + if (_newCallback != null) + throw new Exception("Reflected event replacer is in invalid state: Replace when already replaced"); + _newCallback = newCallback; + Test(instance); + if (_registeredCallbacks.Count == 0) + throw new Exception("Reflected event replacer is in invalid state: Nothing to replace"); + foreach (Delegate callback in _registeredCallbacks) + _callbackRemover.Invoke(callback); + _callbackAdder.Invoke(_newCallback); + } + + /// + /// Checks if the callback is currently replaced + /// + public bool Replaced => _newCallback != null; + + /// + /// Removes the callback added by and puts the original callback back. + /// + /// The instance to operate on, or null if static + public void Restore(object instance) + { + _instance = instance; + if (_newCallback == null) + throw new Exception("Reflected event replacer is in invalid state: Restore when not replaced"); + _callbackRemover.Invoke(_newCallback); + foreach (Delegate callback in _registeredCallbacks) + _callbackAdder.Invoke(callback); + _newCallback = null; + } + + + private static readonly string[] _backingFieldForEvent = { "{0}", "{0}" }; + + private static FieldInfo GetEventBackingField(string eventName, Type baseType) + { + FieldInfo eventField = null; + Type type = baseType; + while (type != null && eventField == null) + { + for (var i = 0; i < _backingFieldForEvent.Length && eventField == null; i++) + eventField = type.GetField(string.Format(_backingFieldForEvent[i], eventName), BindFlagAll); + type = type.BaseType; + } + return eventField; + } + + private static IEnumerable GetEventsInternal(object instance, FieldInfo eventField) + { + if (eventField.GetValue(instance) is MulticastDelegate eventDel) + { + foreach (Delegate handle in eventDel.GetInvocationList()) + yield return handle; + } + } + } +} \ No newline at end of file diff --git a/Torch/Utils/Reflected/ReflectedFieldInfoAttribute.cs b/Torch/Utils/Reflected/ReflectedFieldInfoAttribute.cs new file mode 100644 index 0000000..9553085 --- /dev/null +++ b/Torch/Utils/Reflected/ReflectedFieldInfoAttribute.cs @@ -0,0 +1,22 @@ +using System; + +namespace Torch.Utils +{ + /// + /// Indicates that this field should contain the instance for the given field. + /// + [AttributeUsage(AttributeTargets.Field)] + public class ReflectedFieldInfoAttribute : ReflectedMemberAttribute + { + /// + /// Creates a reflected field info attribute using the given type and name. + /// + /// Type that contains the member + /// Name of the member + public ReflectedFieldInfoAttribute(Type type, string name) + { + Type = type; + Name = name; + } + } +} \ No newline at end of file diff --git a/Torch/Utils/Reflected/ReflectedGetterAttribute.cs b/Torch/Utils/Reflected/ReflectedGetterAttribute.cs new file mode 100644 index 0000000..738cda5 --- /dev/null +++ b/Torch/Utils/Reflected/ReflectedGetterAttribute.cs @@ -0,0 +1,28 @@ +using System; + +namespace Torch.Utils +{ + /// + /// Indicates that this field should contain a delegate capable of retrieving the value of a field. + /// + /// + /// + /// _instanceGetter; + /// + /// [ReflectedGetterAttribute(Name="_staticField", Type=typeof(Example))] + /// private static Func _staticGetter; + /// + /// private class Example { + /// private int _instanceField; + /// private static int _staticField; + /// } + /// ]]> + /// + /// + [AttributeUsage(AttributeTargets.Field)] + public class ReflectedGetterAttribute : ReflectedMemberAttribute + { + } +} \ No newline at end of file diff --git a/Torch/Utils/Reflected/ReflectedLazyAttribute.cs b/Torch/Utils/Reflected/ReflectedLazyAttribute.cs new file mode 100644 index 0000000..28d3896 --- /dev/null +++ b/Torch/Utils/Reflected/ReflectedLazyAttribute.cs @@ -0,0 +1,15 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using System.Threading.Tasks; + +namespace Torch.Utils.Reflected +{ + /// + /// Indicates that the type will perform its own call to + /// + public class ReflectedLazyAttribute : Attribute + { + } +} diff --git a/Torch/Utils/ReflectedManager.cs b/Torch/Utils/Reflected/ReflectedManager.cs similarity index 50% rename from Torch/Utils/ReflectedManager.cs rename to Torch/Utils/Reflected/ReflectedManager.cs index c9a9fce..08eb4a1 100644 --- a/Torch/Utils/ReflectedManager.cs +++ b/Torch/Utils/Reflected/ReflectedManager.cs @@ -10,389 +10,24 @@ using System.Threading.Tasks; using NLog; using Sandbox.Engine.Multiplayer; using Torch.API; +using Torch.Utils.Reflected; namespace Torch.Utils { - public abstract class ReflectedMemberAttribute : Attribute - { - /// - /// Name of the member to access. If null, the tagged field's name. - /// - public string Name { get; set; } = null; - - /// - /// Declaring type of the member to access. If null, inferred from the instance argument type. - /// - public Type Type { get; set; } = null; - - /// - /// Assembly qualified name of - /// - public string TypeName - { - get => Type?.AssemblyQualifiedName; - set => Type = value == null ? null : Type.GetType(value, true); - } - } - #region MemberInfoAttributes - /// - /// Indicates that this field should contain the instance for the given field. - /// - [AttributeUsage(AttributeTargets.Field)] - public class ReflectedFieldInfoAttribute : ReflectedMemberAttribute - { - /// - /// Creates a reflected field info attribute using the given type and name. - /// - /// Type that contains the member - /// Name of the member - public ReflectedFieldInfoAttribute(Type type, string name) - { - Type = type; - Name = name; - } - } - /// - /// Indicates that this field should contain the instance for the given method. - /// - [AttributeUsage(AttributeTargets.Field)] - public class ReflectedMethodInfoAttribute : ReflectedMemberAttribute - { - /// - /// Creates a reflected method info attribute using the given type and name. - /// - /// Type that contains the member - /// Name of the member - public ReflectedMethodInfoAttribute(Type type, string name) - { - Type = type; - Name = name; - } - /// - /// Expected parameters of this method, or null if any parameters are accepted. - /// - public Type[] Parameters { get; set; } = null; - - /// - /// Assembly qualified names of - /// - public string[] ParameterNames - { - get => Parameters.Select(x => x.AssemblyQualifiedName).ToArray(); - set => Parameters = value?.Select(x => x == null ? null : Type.GetType(x)).ToArray(); - } - - /// - /// Expected return type of this method, or null if any return type is accepted. - /// - public Type ReturnType { get; set; } = null; - } - - /// - /// Indicates that this field should contain the instance for the given property. - /// - [AttributeUsage(AttributeTargets.Field)] - public class ReflectedPropertyInfoAttribute : ReflectedMemberAttribute - { - /// - /// Creates a reflected property info attribute using the given type and name. - /// - /// Type that contains the member - /// Name of the member - public ReflectedPropertyInfoAttribute(Type type, string name) - { - Type = type; - Name = name; - } - } #endregion #region FieldPropGetSet - /// - /// Indicates that this field should contain a delegate capable of retrieving the value of a field. - /// - /// - /// - /// _instanceGetter; - /// - /// [ReflectedGetterAttribute(Name="_staticField", Type=typeof(Example))] - /// private static Func _staticGetter; - /// - /// private class Example { - /// private int _instanceField; - /// private static int _staticField; - /// } - /// ]]> - /// - /// - [AttributeUsage(AttributeTargets.Field)] - public class ReflectedGetterAttribute : ReflectedMemberAttribute - { - } - /// - /// Indicates that this field should contain a delegate capable of setting the value of a field. - /// - /// - /// - /// _instanceSetter; - /// - /// [ReflectedSetterAttribute(Name="_staticField", Type=typeof(Example))] - /// private static Action _staticSetter; - /// - /// private class Example { - /// private int _instanceField; - /// private static int _staticField; - /// } - /// ]]> - /// - /// - [AttributeUsage(AttributeTargets.Field)] - public class ReflectedSetterAttribute : ReflectedMemberAttribute - { - } #endregion #region Invoker - /// - /// Indicates that this field should contain a delegate capable of invoking an instance method. - /// - /// - /// - /// ExampleInstance; - /// - /// private class Example { - /// private int ExampleInstance(int a, float b) { - /// return a + ", " + b; - /// } - /// } - /// ]]> - /// - /// - [AttributeUsage(AttributeTargets.Field)] - public class ReflectedMethodAttribute : ReflectedMemberAttribute - { - /// - /// When set the parameters types for the method are assumed to be this. - /// - public Type[] OverrideTypes { get; set; } - /// - /// Assembly qualified names of - /// - public string[] OverrideTypeNames - { - get => OverrideTypes.Select(x => x.AssemblyQualifiedName).ToArray(); - set => OverrideTypes = value?.Select(x => x == null ? null : Type.GetType(x)).ToArray(); - } - } - - /// - /// Indicates that this field should contain a delegate capable of invoking a static method. - /// - /// - /// - /// ExampleStatic; - /// - /// private class Example { - /// private static int ExampleStatic(int a, float b) { - /// return a + ", " + b; - /// } - /// } - /// ]]> - /// - /// - [AttributeUsage(AttributeTargets.Field)] - public class ReflectedStaticMethodAttribute : ReflectedMethodAttribute - { - } #endregion #region EventReplacer - /// - /// Instance of statefully replacing and restoring the callbacks of an event. - /// - public class ReflectedEventReplacer - { - private const BindingFlags BindFlagAll = BindingFlags.Static | - BindingFlags.Instance | - BindingFlags.Public | - BindingFlags.NonPublic; - private object _instance; - private Func> _backingStoreReader; - private Action _callbackAdder; - private Action _callbackRemover; - private readonly ReflectedEventReplaceAttribute _attributes; - private readonly HashSet _registeredCallbacks = new HashSet(); - private readonly MethodInfo _targetMethodInfo; - - internal ReflectedEventReplacer(ReflectedEventReplaceAttribute attr) - { - _attributes = attr; - FieldInfo backingStore = GetEventBackingField(attr.EventName, attr.EventDeclaringType); - if (backingStore == null) - throw new ArgumentException($"Unable to find backing field for event {attr.EventDeclaringType.FullName}#{attr.EventName}"); - EventInfo evtInfo = ReflectedManager.GetFieldPropRecursive(attr.EventDeclaringType, attr.EventName, BindFlagAll, (a, b, c) => a.GetEvent(b, c)); - if (evtInfo == null) - throw new ArgumentException($"Unable to find event info for event {attr.EventDeclaringType.FullName}#{attr.EventName}"); - _backingStoreReader = () => GetEventsInternal(_instance, backingStore); - _callbackAdder = (x) => evtInfo.AddEventHandler(_instance, x); - _callbackRemover = (x) => evtInfo.RemoveEventHandler(_instance, x); - if (attr.TargetParameters == null) - { - _targetMethodInfo = attr.TargetDeclaringType.GetMethod(attr.TargetName, BindFlagAll); - if (_targetMethodInfo == null) - throw new ArgumentException($"Unable to find method {attr.TargetDeclaringType.FullName}#{attr.TargetName} to replace"); - } - else - { - _targetMethodInfo = - attr.TargetDeclaringType.GetMethod(attr.TargetName, BindFlagAll, null, attr.TargetParameters, null); - if (_targetMethodInfo == null) - throw new ArgumentException($"Unable to find method {attr.TargetDeclaringType.FullName}#{attr.TargetName}){string.Join(", ", attr.TargetParameters.Select(x => x.FullName))}) to replace"); - } - } - - /// - /// Test that this replacement can be performed. - /// - /// The instance to operate on, or null if static - /// true if possible, false if unsuccessful - public bool Test(object instance) - { - _instance = instance; - _registeredCallbacks.Clear(); - foreach (Delegate callback in _backingStoreReader.Invoke()) - if (callback.Method == _targetMethodInfo) - _registeredCallbacks.Add(callback); - - return _registeredCallbacks.Count > 0; - } - - private Delegate _newCallback; - - /// - /// Removes the target callback defined in the attribute and replaces it with the provided callback. - /// - /// The new event callback - /// The instance to operate on, or null if static - public void Replace(Delegate newCallback, object instance) - { - _instance = instance; - if (_newCallback != null) - throw new Exception("Reflected event replacer is in invalid state: Replace when already replaced"); - _newCallback = newCallback; - Test(instance); - if (_registeredCallbacks.Count == 0) - throw new Exception("Reflected event replacer is in invalid state: Nothing to replace"); - foreach (Delegate callback in _registeredCallbacks) - _callbackRemover.Invoke(callback); - _callbackAdder.Invoke(_newCallback); - } - - /// - /// Checks if the callback is currently replaced - /// - public bool Replaced => _newCallback != null; - - /// - /// Removes the callback added by and puts the original callback back. - /// - /// The instance to operate on, or null if static - public void Restore(object instance) - { - _instance = instance; - if (_newCallback == null) - throw new Exception("Reflected event replacer is in invalid state: Restore when not replaced"); - _callbackRemover.Invoke(_newCallback); - foreach (Delegate callback in _registeredCallbacks) - _callbackAdder.Invoke(callback); - _newCallback = null; - } - - - private static readonly string[] _backingFieldForEvent = { "{0}", "{0}" }; - - private static FieldInfo GetEventBackingField(string eventName, Type baseType) - { - FieldInfo eventField = null; - Type type = baseType; - while (type != null && eventField == null) - { - for (var i = 0; i < _backingFieldForEvent.Length && eventField == null; i++) - eventField = type.GetField(string.Format(_backingFieldForEvent[i], eventName), BindFlagAll); - type = type.BaseType; - } - return eventField; - } - - private static IEnumerable GetEventsInternal(object instance, FieldInfo eventField) - { - if (eventField.GetValue(instance) is MulticastDelegate eventDel) - { - foreach (Delegate handle in eventDel.GetInvocationList()) - yield return handle; - } - } - } - - /// - /// Attribute used to indicate that the the given field, of type ]]>, should be filled with - /// a function used to create a new event replacer. - /// - [AttributeUsage(AttributeTargets.Field)] - public class ReflectedEventReplaceAttribute : Attribute - { - /// - /// Type that the event is declared in - /// - public Type EventDeclaringType { get; set; } - /// - /// Name of the event - /// - public string EventName { get; set; } - - /// - /// Type that the method to replace is declared in - /// - public Type TargetDeclaringType { get; set; } - /// - /// Name of the method to replace - /// - public string TargetName { get; set; } - /// - /// Optional parameters of the method to replace. Null to ignore. - /// - public Type[] TargetParameters { get; set; } = null; - - /// - /// Creates a reflected event replacer attribute to, for the event defined as eventName in eventDeclaringType, - /// replace the method defined as targetName in targetDeclaringType with a custom callback. - /// - /// Type the event is declared in - /// Name of the event - /// Type the method to remove is declared in - /// Name of the method to remove - public ReflectedEventReplaceAttribute(Type eventDeclaringType, string eventName, Type targetDeclaringType, - string targetName) - { - EventDeclaringType = eventDeclaringType; - EventName = eventName; - TargetDeclaringType = targetDeclaringType; - TargetName = targetName; - } - } #endregion /// @@ -438,7 +73,8 @@ namespace Torch.Utils public static void Process(Assembly asm) { foreach (Type type in asm.GetTypes()) - Process(type); + if (!type.HasAttribute()) + Process(type); } /// diff --git a/Torch/Utils/Reflected/ReflectedMemberAttribute.cs b/Torch/Utils/Reflected/ReflectedMemberAttribute.cs new file mode 100644 index 0000000..a8c6fbd --- /dev/null +++ b/Torch/Utils/Reflected/ReflectedMemberAttribute.cs @@ -0,0 +1,26 @@ +using System; + +namespace Torch.Utils +{ + public abstract class ReflectedMemberAttribute : Attribute + { + /// + /// Name of the member to access. If null, the tagged field's name. + /// + public string Name { get; set; } = null; + + /// + /// Declaring type of the member to access. If null, inferred from the instance argument type. + /// + public Type Type { get; set; } = null; + + /// + /// Assembly qualified name of + /// + public string TypeName + { + get => Type?.AssemblyQualifiedName; + set => Type = value == null ? null : Type.GetType(value, true); + } + } +} \ No newline at end of file diff --git a/Torch/Utils/Reflected/ReflectedMethodAttribute.cs b/Torch/Utils/Reflected/ReflectedMethodAttribute.cs new file mode 100644 index 0000000..bd81a42 --- /dev/null +++ b/Torch/Utils/Reflected/ReflectedMethodAttribute.cs @@ -0,0 +1,40 @@ +using System; +using System.Linq; + +namespace Torch.Utils +{ + /// + /// Indicates that this field should contain a delegate capable of invoking an instance method. + /// + /// + /// + /// ExampleInstance; + /// + /// private class Example { + /// private int ExampleInstance(int a, float b) { + /// return a + ", " + b; + /// } + /// } + /// ]]> + /// + /// + [AttributeUsage(AttributeTargets.Field)] + public class ReflectedMethodAttribute : ReflectedMemberAttribute + { + /// + /// When set the parameters types for the method are assumed to be this. + /// + public Type[] OverrideTypes { get; set; } + + /// + /// Assembly qualified names of + /// + public string[] OverrideTypeNames + { + get => OverrideTypes.Select(x => x.AssemblyQualifiedName).ToArray(); + set => OverrideTypes = value?.Select(x => x == null ? null : Type.GetType(x)).ToArray(); + } + } +} \ No newline at end of file diff --git a/Torch/Utils/Reflected/ReflectedMethodInfoAttribute.cs b/Torch/Utils/Reflected/ReflectedMethodInfoAttribute.cs new file mode 100644 index 0000000..65be9e7 --- /dev/null +++ b/Torch/Utils/Reflected/ReflectedMethodInfoAttribute.cs @@ -0,0 +1,41 @@ +using System; +using System.Linq; + +namespace Torch.Utils +{ + /// + /// Indicates that this field should contain the instance for the given method. + /// + [AttributeUsage(AttributeTargets.Field)] + public class ReflectedMethodInfoAttribute : ReflectedMemberAttribute + { + /// + /// Creates a reflected method info attribute using the given type and name. + /// + /// Type that contains the member + /// Name of the member + public ReflectedMethodInfoAttribute(Type type, string name) + { + Type = type; + Name = name; + } + /// + /// Expected parameters of this method, or null if any parameters are accepted. + /// + public Type[] Parameters { get; set; } = null; + + /// + /// Assembly qualified names of + /// + public string[] ParameterNames + { + get => Parameters.Select(x => x.AssemblyQualifiedName).ToArray(); + set => Parameters = value?.Select(x => x == null ? null : Type.GetType(x)).ToArray(); + } + + /// + /// Expected return type of this method, or null if any return type is accepted. + /// + public Type ReturnType { get; set; } = null; + } +} \ No newline at end of file diff --git a/Torch/Utils/Reflected/ReflectedPropertyInfoAttribute.cs b/Torch/Utils/Reflected/ReflectedPropertyInfoAttribute.cs new file mode 100644 index 0000000..4964f8a --- /dev/null +++ b/Torch/Utils/Reflected/ReflectedPropertyInfoAttribute.cs @@ -0,0 +1,22 @@ +using System; + +namespace Torch.Utils +{ + /// + /// Indicates that this field should contain the instance for the given property. + /// + [AttributeUsage(AttributeTargets.Field)] + public class ReflectedPropertyInfoAttribute : ReflectedMemberAttribute + { + /// + /// Creates a reflected property info attribute using the given type and name. + /// + /// Type that contains the member + /// Name of the member + public ReflectedPropertyInfoAttribute(Type type, string name) + { + Type = type; + Name = name; + } + } +} \ No newline at end of file diff --git a/Torch/Utils/Reflected/ReflectedSetterAttribute.cs b/Torch/Utils/Reflected/ReflectedSetterAttribute.cs new file mode 100644 index 0000000..44cdd07 --- /dev/null +++ b/Torch/Utils/Reflected/ReflectedSetterAttribute.cs @@ -0,0 +1,28 @@ +using System; + +namespace Torch.Utils +{ + /// + /// Indicates that this field should contain a delegate capable of setting the value of a field. + /// + /// + /// + /// _instanceSetter; + /// + /// [ReflectedSetterAttribute(Name="_staticField", Type=typeof(Example))] + /// private static Action _staticSetter; + /// + /// private class Example { + /// private int _instanceField; + /// private static int _staticField; + /// } + /// ]]> + /// + /// + [AttributeUsage(AttributeTargets.Field)] + public class ReflectedSetterAttribute : ReflectedMemberAttribute + { + } +} \ No newline at end of file diff --git a/Torch/Utils/Reflected/ReflectedStaticMethodAttribute.cs b/Torch/Utils/Reflected/ReflectedStaticMethodAttribute.cs new file mode 100644 index 0000000..1552e3e --- /dev/null +++ b/Torch/Utils/Reflected/ReflectedStaticMethodAttribute.cs @@ -0,0 +1,26 @@ +using System; + +namespace Torch.Utils +{ + /// + /// Indicates that this field should contain a delegate capable of invoking a static method. + /// + /// + /// + /// ExampleStatic; + /// + /// private class Example { + /// private static int ExampleStatic(int a, float b) { + /// return a + ", " + b; + /// } + /// } + /// ]]> + /// + /// + [AttributeUsage(AttributeTargets.Field)] + public class ReflectedStaticMethodAttribute : ReflectedMethodAttribute + { + } +} \ No newline at end of file diff --git a/Torch/Utils/TorchAssemblyResolver.cs b/Torch/Utils/TorchAssemblyResolver.cs index db2c9ee..dab28c5 100644 --- a/Torch/Utils/TorchAssemblyResolver.cs +++ b/Torch/Utils/TorchAssemblyResolver.cs @@ -38,6 +38,8 @@ namespace Torch.Utils return path.StartsWith(_removablePathPrefix) ? path.Substring(_removablePathPrefix.Length) : path; } + private static readonly string[] _tryExtensions = {".dll", ".exe"}; + private Assembly CurrentDomainOnAssemblyResolve(object sender, ResolveEventArgs args) { string assemblyName = new AssemblyName(args.Name).Name; @@ -57,18 +59,21 @@ namespace Torch.Utils { try { - string assemblyPath = Path.Combine(path, assemblyName + ".dll"); - if (!File.Exists(assemblyPath)) - continue; - _log.Trace("Loading {0} from {1}", assemblyName, SimplifyPath(assemblyPath)); - LogManager.Flush(); - Assembly asm = Assembly.LoadFrom(assemblyPath); - _assemblies.Add(assemblyName, asm); - // Recursively load SE dependencies since they don't trigger AssemblyResolve. - // This trades some performance on load for actually working code. - foreach (AssemblyName dependency in asm.GetReferencedAssemblies()) - CurrentDomainOnAssemblyResolve(sender, new ResolveEventArgs(dependency.Name, asm)); - return asm; + foreach (var tryExt in _tryExtensions) + { + string assemblyPath = Path.Combine(path, assemblyName + tryExt); + if (!File.Exists(assemblyPath)) + continue; + _log.Trace("Loading {0} from {1}", assemblyName, SimplifyPath(assemblyPath)); + LogManager.Flush(); + Assembly asm = Assembly.LoadFrom(assemblyPath); + _assemblies.Add(assemblyName, asm); + // Recursively load SE dependencies since they don't trigger AssemblyResolve. + // This trades some performance on load for actually working code. + foreach (AssemblyName dependency in asm.GetReferencedAssemblies()) + CurrentDomainOnAssemblyResolve(sender, new ResolveEventArgs(dependency.Name, asm)); + return asm; + } } catch {