Files
se-launcher/CringeLauncher/Patches/ModScriptCompilerPatch.cs
pas2704 bd626f7a2b
All checks were successful
Build / Compute Version (push) Successful in 6s
Build / Build Nuget package (CringeBootstrap.Abstractions) (push) Successful in 1m30s
Build / Build Nuget package (SharedCringe) (push) Successful in 1m45s
Build / Build Nuget package (NuGet) (push) Successful in 1m47s
Build / Build Nuget package (CringePlugins) (push) Successful in 1m58s
Build / Build Launcher (push) Successful in 2m24s
Fix init when pasting in a programmable block
Improvements for imgui input handling
2025-05-16 22:52:15 -04:00

319 lines
14 KiB
C#

using CringeBootstrap.Abstractions;
using CringeLauncher.Loader;
using HarmonyLib;
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.CSharp;
using Microsoft.CodeAnalysis.Diagnostics;
using Microsoft.CodeAnalysis.Emit;
using NLog;
using Sandbox;
using Sandbox.Game;
using Sandbox.Game.Entities.Blocks;
using Sandbox.Game.EntityComponents;
using Sandbox.Game.Gui;
using Sandbox.Game.Localization;
using Sandbox.Game.World;
using Sandbox.Graphics.GUI;
using Sandbox.ModAPI;
using Sandbox.ModAPI.Ingame;
using System.Reflection;
using System.Runtime.CompilerServices;
using System.Runtime.Loader;
using System.Text;
using VRage;
using VRage.Collections;
using VRage.ModAPI;
using VRage.Scripting;
using Message = VRage.Scripting.Message;
namespace CringeLauncher.Patches;
/// <summary>
/// Harmony patches that take over script compilation from the game: mod scripts
/// (<c>MyScriptCompiler.Compile</c>), programmable block programs (<c>MyProgrammableBlock.Compile</c>)
/// and the PB code editor's "check code" button (<c>MyGuiScreenEditor.CheckCodeButtonClicked</c>).
/// Compiled assemblies are loaded into dedicated <see cref="AssemblyLoadContext"/>s so they can be
/// unloaded again: the mod context on session unload, a PB's context when that block recompiles,
/// and the editor's context right after the check.
/// </summary>
[HarmonyPatch]
public static class ModScriptCompilerPatch
{
    /// <summary>Programmable blocks whose compilation is currently in flight; further compile requests for them are ignored.</summary>
    internal static readonly MyConcurrentHashSet<MyProgrammableBlock> CompilingPbs = [];

    private static readonly Logger Log = LogManager.GetCurrentClassLogger();

    // Load context shared by every mod assembly of the current session; replaced in OnUnloaded.
    private static ModAssemblyLoadContext _modContext;

    // Mod assembly file names already compiled this session; repeated requests for the same name are skipped.
    private static readonly HashSet<string> LoadedModAssemblyNames = [];

    // One load context per programmable block; the previous one is unloaded when the block recompiles.
    private static readonly ConditionalWeakTable<MyProgrammableBlock, PbAssemblyLoadContext> LoadContexts = [];

    // Reflection handles to private members of MyProgrammableBlock.
    private static readonly FieldInfo InstanceField = AccessTools.Field(typeof(MyProgrammableBlock), "m_instance");
    private static readonly PropertyInfo AssemblyProperty = AccessTools.Property(typeof(MyProgrammableBlock), "CurrentAssembly");
    private static readonly FieldInfo CompilerErrorsField = AccessTools.Field(typeof(MyProgrammableBlock), "m_compilerErrors");
    private static readonly MethodInfo CreateInstanceMethod = AccessTools.Method(typeof(MyProgrammableBlock), "CreateInstance");
    private static readonly MethodInfo SetDetailedInfoMethod = AccessTools.Method(typeof(MyProgrammableBlock), "SetDetailedInfo");

    // Load context that holds the game assemblies; used as the parent of every script context created here.
    private static readonly ICoreLoadContext CoreContext = (ICoreLoadContext)AssemblyLoadContext.GetLoadContext(typeof(MySession).Assembly)!;

    // Whitelist analyzers and helpers reused from the game's own MyScriptCompiler instance.
    private static readonly DiagnosticAnalyzer ModWhitelistAnalyzer = AccessTools.FieldRefAccess<MyScriptCompiler, DiagnosticAnalyzer>(
        MyScriptCompiler.Static, "m_modApiWhitelistDiagnosticAnalyzer");
    private static readonly DiagnosticAnalyzer ScriptWhitelistAnalyzer =
        AccessTools.FieldRefAccess<MyScriptCompiler, DiagnosticAnalyzer>(MyScriptCompiler.Static, "m_inGameWhitelistDiagnosticAnalyzer");
    private static readonly Func<CSharpCompilation, SyntaxTree, int, SyntaxTree> InjectMod = AccessTools.MethodDelegate<Func<CSharpCompilation, SyntaxTree, int, SyntaxTree>>(
        AccessTools.Method(typeof(MyScriptCompiler), "InjectMod"), MyScriptCompiler.Static);
    private static readonly Func<CSharpCompilation, SyntaxTree, bool, SyntaxTree> InjectResourceMonitoring = AccessTools.MethodDelegate<Func<CSharpCompilation, SyntaxTree, bool, SyntaxTree>>(
        AccessTools.Method(typeof(MyScriptCompiler), "InjectResourceMonitoring"), MyScriptCompiler.Static);
    private static readonly Func<CompilationWithAnalyzers, EmitResult, List<Message>, bool, Task<bool>> EmitDiagnostics = AccessTools.MethodDelegate<Func<CompilationWithAnalyzers, EmitResult, List<Message>, bool, Task<bool>>>(
        AccessTools.Method(typeof(MyScriptCompiler), "EmitDiagnostics"), MyScriptCompiler.Static);
    private static readonly Func<string, string> MakeAssemblyName = AccessTools.MethodDelegate<Func<string, string>>(AccessTools.Method(typeof(MyScriptCompiler),
        "MakeAssemblyName"));
    private static readonly Func<MyScriptCompiler, string, IEnumerable<Script>, bool, CSharpCompilation> CreateCompilation =
        AccessTools.MethodDelegate<Func<MyScriptCompiler, string, IEnumerable<Script>, bool, CSharpCompilation>>(AccessTools.Method(typeof(MyScriptCompiler),
            "CreateCompilation"));

    static ModScriptCompilerPatch()
    {
        MySession.OnUnloaded += OnUnloaded;
        _modContext = new(CoreContext);
    }

    /// <summary>
    /// Session-unload hook: forgets which mod assemblies were compiled and, if the mod context
    /// actually holds assemblies, unloads it and replaces it with a fresh one.
    /// </summary>
    private static void OnUnloaded()
    {
        LoadedModAssemblyNames.Clear();
        if (!_modContext.Assemblies.Any())
            return;
        _modContext.Unload();
        _modContext = new(CoreContext);
    }

    /// <summary>
    /// Replaces <c>MyProgrammableBlock.Compile</c> with an asynchronous compile.
    /// Does nothing when in-game scripts are disabled, the grid is a preview or has
    /// <c>CreatePhysics</c> off, or the block is already compiling.
    /// </summary>
    /// <returns>Always <c>false</c>, so the original method is skipped.</returns>
    [HarmonyPatch(typeof(MyProgrammableBlock), "Compile")]
    [HarmonyPrefix]
    private static bool CompilePrefix(MyProgrammableBlock __instance, string program, string storage, bool instantiate,
        ref MyProgrammableBlock.ScriptTerminationReason ___m_terminationReason,
        MyIngameScriptComponent ___m_scriptComponent)
    {
        if (!MySession.Static.EnableIngameScripts || __instance.CubeGrid is { IsPreview: true } or { CreatePhysics: false } || !CompilingPbs.Add(__instance))
            return false;
        ___m_terminationReason = MyProgrammableBlock.ScriptTerminationReason.None;
        CompileAsync(__instance, program, storage, instantiate, ___m_scriptComponent);
        return false;
    }

    /// <summary>
    /// Replaces the PB editor's "check code" handler. Compiles the editor text in a throw-away
    /// load context, showing a progress screen while it runs, and blocks until the check has finished.
    /// </summary>
    /// <returns>Always <c>false</c>, so the original method is skipped.</returns>
    [HarmonyPatch(typeof(MyGuiScreenEditor), "CheckCodeButtonClicked")]
    [HarmonyPrefix]
    private static bool GuiCompilePrefix(List<string> ___m_compilerErrors, MyGuiScreenEditor __instance)
    {
        ___m_compilerErrors.Clear();

        // Nothing to check. Bail out before adding the progress screen, which would otherwise never be removed.
        if (__instance.Description.Text.Length == 0)
            return false;

        var progress = new MyGuiScreenProgress(MyTexts.Get(MySpaceTexts.ProgrammableBlock_Editor_CheckingCode));
        MyScreenManager.AddScreen(progress);
        try
        {
            var task = CompileAsync(__instance, ___m_compilerErrors, __instance.Description.Text.ToString(), progress);
            task.ConfigureAwait(false).GetAwaiter().GetResult();
        }
        finally
        {
            // Make sure the progress screen goes away even if the check throws.
            MyScreenManager.RemoveScreen(progress);
        }

        MyVRage.Platform.ImeProcessor?.RegisterActiveScreen(__instance);
        __instance.FocusedControl = __instance.Description;
        return false;
    }

    /// <summary>
    /// Replaces <c>MyScriptCompiler.Compile</c>. Assemblies produced here are loaded into the
    /// session-wide mod load context.
    /// </summary>
    /// <returns>Always <c>false</c>, so the original method is skipped.</returns>
    [HarmonyPatch(typeof(MyScriptCompiler), nameof(MyScriptCompiler.Compile))]
    [HarmonyPrefix]
    private static bool Prefix(ref Task<Assembly?> __result, MyApiTarget target, string assemblyName, IEnumerable<Script> scripts,
        List<Message> messages, string friendlyName, bool enableDebugInformation = false)
    {
        // Named argument: passed positionally, this flag would land in the trackMemoryUsage parameter.
        __result = CompileAsync(_modContext, target, assemblyName, scripts, messages, friendlyName,
            enableDebugInformation: enableDebugInformation);
        return false;
    }

    /// <summary>
    /// Compiles the PB editor contents to validate them, then reports the result: either a
    /// <see cref="MyGuiScreenEditorError"/> listing the messages (errors first) or an "OK" message box.
    /// </summary>
    /// <param name="editor">Editor screen being checked; its name labels the temporary load context.</param>
    /// <param name="errors">Receives the compiler message texts.</param>
    /// <param name="program">Program source to compile.</param>
    /// <param name="progress">Progress screen to close once compilation is done.</param>
    private static async Task CompileAsync(MyGuiScreenEditor editor, List<string> errors, string program, MyGuiScreenProgress progress)
    {
        var context = new PbAssemblyLoadContext(CoreContext, editor.Name);
        var messages = new List<Message>();
        try
        {
            var script = MyVRage.Platform.Scripting.GetIngameScript(program, "Program", nameof(MyGridProgram));
            // ConfigureAwait(false): the caller blocks on this task.
            await CompileAsync(context, MyApiTarget.Ingame, "check", [script], messages,
                "PB Code Editor", trackMemoryUsage: true).ConfigureAwait(false);
        }
        finally
        {
            // The checked assembly is never used, so always release its load context, even on failure.
            context.Unload();
        }

        errors.AddRange(messages.OrderBy(b => b.IsError ? 0 : 1).Select(b => b.Text));
        progress.CloseScreen();
        if (errors.Count > 0)
        {
            var sb = new StringBuilder(errors.Sum(b => b.Length + Environment.NewLine.Length));
            foreach (var error in errors)
            {
                sb.AppendLine(error);
            }
            MyScreenManager.AddScreen(new MyGuiScreenEditorError(sb.ToString()));
            return;
        }
        var messageBox = MyGuiSandbox.CreateMessageBox(MyMessageBoxStyleEnum.Info, MyMessageBoxButtonsType.OK,
            MyTexts.Get(MySpaceTexts.ProgrammableBlock_Editor_CompilationOk),
            MyTexts.Get(MySpaceTexts.ProgrammableBlock_CodeEditor_Title));
        MyGuiSandbox.AddScreen(messageBox);
    }

    /// <summary>
    /// Fire-and-forget compile of a programmable block's program. Unloads the block's previous load
    /// context, compiles into a fresh one, stores the assembly and compiler messages on the block and,
    /// when <paramref name="instantiate"/> is set, schedules <c>CreateInstance</c> through
    /// <c>MySandboxGame.Static.Invoke</c>.
    /// </summary>
    /// <remarks>
    /// <c>async void</c> because it is launched from a Harmony prefix that cannot await. All
    /// exceptions are caught, and the block is always removed from <see cref="CompilingPbs"/>.
    /// </remarks>
    private static async void CompileAsync(MyProgrammableBlock block,
        string program,
        string storage,
        bool instantiate, MyIngameScriptComponent scriptComponent)
    {
        try
        {
            scriptComponent.NeedsUpdate = MyEntityUpdateEnum.NONE;
            scriptComponent.UpdateFrequency = UpdateFrequency.None;
            SetDetailedInfoMethod.Invoke(block, ["Compiling..."]);

            if (LoadContexts.TryGetValue(block, out var context))
            {
                // Clear the block's references to the old program instance and assembly before unloading their context.
                AccessTools.FieldRefAccess<MyProgrammableBlock, IMyGridProgram?>(block, InstanceField) = null;
                AssemblyProperty.SetValue(block, null);
                context.Unload();
            }

            LoadContexts.AddOrUpdate(block, context = new(CoreContext, $"pb_{block.EntityId}"));
            var messages = new List<Message>();
            // The random suffix gives every recompile of the same block a distinct assembly name.
            var assembly = await CompileAsync(context, MyApiTarget.Ingame, $"pb_{block.EntityId}_{Random.Shared.NextInt64()}",
                [MyVRage.Platform.Scripting.GetIngameScript(program, "Program", nameof(MyGridProgram))],
                messages, $"PB: {block.DisplayName} ({block.EntityId})", trackMemoryUsage: true);
            AssemblyProperty.SetValue(block, assembly);

            var errors = AccessTools.FieldRefAccess<MyProgrammableBlock, List<string>>(block, CompilerErrorsField);
            errors.Clear();
            errors.AddRange(messages.Select(b => b.Text));

            if (instantiate)
            {
                // NOTE(review): presumably marshals CreateInstance onto the game's update thread - confirm Invoke semantics.
                MySandboxGame.Static.Invoke(() => CreateInstanceMethod.Invoke(block, [assembly, errors, storage]),
                    nameof(CompileAsync));
            }
        }
        catch (Exception e)
        {
            // Log first, so the root cause is recorded even if the reflective SetDetailedInfo call itself fails.
            Log.Error(e);
            SetDetailedInfoMethod.Invoke(block, [e.ToString()]);
        }
        finally
        {
            CompilingPbs.Remove(block);
        }
    }

    /// <summary>
    /// Core compile routine shared by all entry points. It builds a compilation through the game's
    /// <c>CreateCompilation</c>, applies the target's syntax-tree injector and whitelist analyzer,
    /// runs the blacklist visitor, and loads the emitted assembly into <paramref name="context"/> on success.
    /// </summary>
    /// <param name="context">Load context that receives the assembly on success.</param>
    /// <param name="target">
    /// Compilation target. <see cref="MyApiTarget.Mod"/> uses the mod whitelist and mod injection.
    /// <see cref="MyApiTarget.Ingame"/> uses the in-game whitelist and resource monitoring.
    /// <see cref="MyApiTarget.None"/> uses neither.
    /// </param>
    /// <param name="assemblyName">Base name, passed through <c>MakeAssemblyName</c>.</param>
    /// <param name="scripts">Scripts to compile.</param>
    /// <param name="messages">Receives diagnostics.</param>
    /// <param name="friendlyName">Display name; <c>"&lt;No Name&gt;"</c> when null.</param>
    /// <param name="trackMemoryUsage">Forwarded to the syntax-tree injector (ignored by the mod injector).</param>
    /// <param name="enableDebugInformation">Forwarded to <c>CreateCompilation</c>.</param>
    /// <returns>
    /// The loaded assembly, or <c>null</c> in any of these cases: compilation failed, a blacklisted
    /// API was used, injection threw, or a mod assembly with the same file name was already compiled
    /// this session.
    /// </returns>
    /// <exception cref="ArgumentOutOfRangeException">Unknown <paramref name="target"/>.</exception>
    private static async Task<Assembly?> CompileAsync(AssemblyLoadContext context, MyApiTarget target,
        string assemblyName, IEnumerable<Script> scripts,
        List<Message> messages, string? friendlyName, bool trackMemoryUsage = false,
        bool enableDebugInformation = false)
    {
        friendlyName ??= "<No Name>";
        var assemblyFileName = MakeAssemblyName(assemblyName);
        Func<CSharpCompilation, SyntaxTree, bool, SyntaxTree>? syntaxTreeInjector;
        DiagnosticAnalyzer? whitelistAnalyzer;
        switch (target)
        {
            case MyApiTarget.None:
                whitelistAnalyzer = null;
                syntaxTreeInjector = null;
                break;
            case MyApiTarget.Mod:
            {
                // Check for a duplicate before allocating a watchdog id, so skipped mods don't consume one.
                if (!LoadedModAssemblyNames.Add(assemblyFileName))
                {
                    Log.Info($"{assemblyFileName} is already loaded, skipping");
                    return null;
                }
                var modId = MyModWatchdog.AllocateModId(friendlyName);
                whitelistAnalyzer = ModWhitelistAnalyzer;
                syntaxTreeInjector = (c, st, _) => InjectMod(c, st, modId);
                break;
            }
            case MyApiTarget.Ingame:
                syntaxTreeInjector = InjectResourceMonitoring;
                whitelistAnalyzer = ScriptWhitelistAnalyzer;
                break;
            default:
                throw new ArgumentOutOfRangeException(nameof(target), target, "Invalid compilation target");
        }

        var compilation = CreateCompilation(MyScriptCompiler.Static, assemblyFileName, scripts, enableDebugInformation);
        // Kept so diagnostics can be re-emitted against the user's original code if the injected build fails.
        var compilationWithoutInjection = compilation;
        var injectionFailed = false;
        if (syntaxTreeInjector != null)
        {
            SyntaxTree[]? newSyntaxTrees = null;
            try
            {
                var syntaxTrees = compilation.SyntaxTrees;
                if (syntaxTrees.Length == 1)
                {
                    newSyntaxTrees = [syntaxTreeInjector(compilation, syntaxTrees[0], trackMemoryUsage)];
                }
                else
                {
                    var compilation1 = compilation;
                    newSyntaxTrees = await Task
                        .WhenAll(syntaxTrees.Select(
                            x => Task.Run(() => syntaxTreeInjector(compilation1, x, trackMemoryUsage)))).ConfigureAwait(false);
                }
            }
            catch (Exception e)
            {
                Log.Warn(e);
                injectionFailed = true;
            }
            if (newSyntaxTrees is not null)
                compilation = compilation.RemoveAllSyntaxTrees().AddSyntaxTrees(newSyntaxTrees);
        }

        CompilationWithAnalyzers? analyticCompilation = null;
        if (whitelistAnalyzer != null)
        {
            analyticCompilation = compilation.WithAnalyzers([whitelistAnalyzer]);
            compilation = (CSharpCompilation)analyticCompilation.Compilation;
        }

        await using var assemblyStream = new MemoryStream();
        var emitResult = compilation.Emit(assemblyStream);
        var success = emitResult.Success;

        var myBlacklistSyntaxVisitor = new MyBlacklistSyntaxVisitor();
        foreach (var syntaxTree in compilation.SyntaxTrees)
        {
            myBlacklistSyntaxVisitor.SetSemanticModel(compilation.GetSemanticModel(syntaxTree, false));
            myBlacklistSyntaxVisitor.Visit(await syntaxTree.GetRootAsync().ConfigureAwait(false));
        }

        if (myBlacklistSyntaxVisitor.HasAnyResult())
        {
            myBlacklistSyntaxVisitor.GetResultMessages(messages);
        }
        else
        {
            success = await EmitDiagnostics(analyticCompilation!, emitResult, messages, success).ConfigureAwait(false);
            assemblyStream.Seek(0, SeekOrigin.Begin);
            if (injectionFailed)
                return null;
            if (success)
                return context.LoadFromStream(assemblyStream);
            // Build failed: replace the messages with diagnostics from the un-injected compilation,
            // so they refer to the user's own code.
            await EmitDiagnostics(analyticCompilation!, compilationWithoutInjection.Emit(assemblyStream), messages,
                false).ConfigureAwait(false);
        }
        return null;
    }
}