diff --git a/Analysis/alterlib/alterlib.py b/Analysis/alterlib/alterlib.py new file mode 100644 index 000000000..bb18b5f9e --- /dev/null +++ b/Analysis/alterlib/alterlib.py @@ -0,0 +1,71 @@ +#!/usr/bin/env python + +import os, sys, io, collections, struct, unix_ar +LuaObject = collections.namedtuple('LuaObject', ['info', 'contents']) + +if input("Have you made sure you verified that FileUtils::CFileToLua is compatible with the new lua51.lib file? (y/n) ") != 'y': + exit() + +# Open lua lib in read mode +lua_lib = unix_ar.open(sys.argv[1], mode = "r") + +# Get contents of io library +lua_objects = [] +for entry in lua_lib.infolist(): + current_index = len(lua_objects) + lua_objects.append(LuaObject(entry, lua_lib.extract(entry, path=io.BytesIO()).read())) + if entry.name.strip(b"/").decode('iso-8859-1') == "lib_io.obj": + io_lib_index = current_index + +# Close lua lib +lua_lib.close() + +# Get io object file +io_lib = bytearray(lua_objects[io_lib_index].contents) + +# Get address of string table +symbolTableAddr, = struct.unpack_from("= 0, 'String "_io_std_new" was not found' +functionStringOffset = functionStringAddr - stringTableAddr +assert functionStringOffset >= 0, 'String "_io_std_new" is outside of the string table' + +# Search address of symbol table entry +for symbolIndex in range(numberOfSymbols): + symbolAddr = symbolTableAddr + symbolIndex * 18 + mustBeZero, = struct.unpack_from("\n') + else: + raise ValueError("mode must be one of 'r' or 'w'") + + def _read_entries(self): + if self._file.read(8) != b'!\n': + raise ValueError("Invalid archive signature") + + self._entries = [] + self._name_map = {} + pos = 8 + while True: + buffer = self._file.read(60) + if len(buffer) == 0: + break + elif len(buffer) == 60: + member = ArInfo.frombuffer(buffer) + member.offset = pos + self._name_map[member.name] = len(self._entries) + self._entries.append(member) + skip = member.size + if skip % 2 != 0: + skip += 1 + pos += 60 + skip + self._file.seek(skip, 1) + if pos == self._file.tell(): + continue + raise ValueError("Truncated archive?") + + def _check(self, expected_mode): + if self._file is None: + raise ValueError("Attempted to use a closed %s" % self.__class__.__name__) + if self._mode != expected_mode: + if self._mode == 'r': + raise ValueError("Can't change a read-only archive") + else: + raise ValueError("Can't read from a write-only archive") + + def add(self, name, arcname=None): + """ + Add a file to the archive. + + :param name: Path to the file to be added. + :type name: bytes | unicode + :param arcname: Name the file will be stored as in the archive, or + a full :class:`~unix_ar.ArInfo`. If unset, `name` will be used. + :type arcname: None | bytes | unicode | unix_ar.ArInfo + """ + self._check('w') + if arcname is None: + arcname = ArInfo(name) + elif not isinstance(arcname, ArInfo): + arcname = ArInfo(arcname) + arcname = arcname.updatefromdisk(name) + with _open(name, 'rb') as fp: + self.addfile(arcname, fp) + + def addfile(self, name, fileobj=None): + """ + Add a file to the archive from a file object. + + :param name: Name the file will be stored as in the archive, or + a full :class:`~unix_ar.ArInfo`. + :type name: bytes | unicode | unix_ar.ArInfo + :param fileobj: File object to read from. + """ + self._check('w') + if not isinstance(name, ArInfo): + name = ArInfo(name) + + name = name.updatefromdisk() + + self._file.write(name.tobuffer()) + if fileobj is None: + fp = _open(name.name, 'rb') + else: + fp = fileobj + + for pos in range(0, name.size, CHUNKSIZE): + chunk = fp.read(min(CHUNKSIZE, name.size - pos)) + if len(chunk) != CHUNKSIZE and len(chunk) != name.size - pos: + raise RuntimeError("File changed size?") + self._file.write(chunk) + if name.size % 2 == 1: + self._file.write(b'\n') + + if fileobj is None: + fp.close() + + def infolist(self): + """ + Return a list of :class:`~unix_ar.ArInfo` for files in the archive. + + These objects are copy, so feel free to change them before feeding them + to :meth:`~unix_ar.ArFile.add()` or :meth:`~unix_ar.ArFile.addfile()`. + + :rtype: [unix_ar.ArInfo] + """ + self._check('r') + return list(i.__copy__() for i in self._entries) + + def getinfo(self, member): + """ + Return an :class:`~unix_ar.ArInfo` for a specific file. + + This object is a copy, so feel free to change it before feeding them to + :meth:`~unix_ar.ArFile.add()` or :meth:`~unix_ar.ArFile.addfile()`. + + :param member: Either a file name or an incomplete + :class:`unix_ar.ArInfo` object to search for. + :type member: bytes | unicode | unix_ar.ArInfo + :rtype: unix_ar.ArInfo + """ + self._check('r') + if isinstance(member, ArInfo): + if member.offset is not None: + self._file.seek(member.offset, 0) + return ArInfo.frombuffer(self._file.read(60)) + else: + index = self._name_map[member.name] + return self._entries[index].__copy__() + else: + index = self._name_map[utf8(member)] + return self._entries[index].__copy__() + + def _extract(self, member, path): + if hasattr(path, 'write'): + fp = path + else: + fp = _open(path.rstrip(b'/'), 'wb') + + self._file.seek(member.offset + 60, 0) + for pos in range(0, member.size, CHUNKSIZE): + chunk = self._file.read(min(CHUNKSIZE, member.size - pos)) + fp.write(chunk) + fp.flush() + fp.seek(0) + return fp + + def extract(self, member, path='') -> 'filelike': + """ + Extract a single file from the archive. + + :param member: Either a file name or an :class:`unix_ar.ArInfo` object + to extract. + :type member: bytes | unicode | unix_ar.ArInfo + :param path: Destination path (current directory by default). You can + also change the `name` attribute on the `ArInfo` you pass this + method to extract to any file name. + :type path: bytes | unicode + """ + self._check('r') + actualmember = self.getinfo(member) + if isinstance(member, ArInfo): + if member.offset is None: + member.offset = actualmember.offset + if member.size > actualmember.size: + member.size = actualmember.size + else: + member = actualmember + if not hasattr(path, 'write'): + if not path or os.path.isdir(path): + path = os.path.join(utf8(path), member.name) + return self._extract(member, path) + + def extractfile(self, member): + self._check('r') + raise NotImplementedError("extractfile() is not yet implemented") + + def extractall(self, path=''): + """ + Extract all the files in the archive. + + :param path: Destination path (current directory by default). + :type path: bytes | unicode + """ + self._check('r') + # Iterate on _name_map instead of plain _entries so we don't extract + # multiple files with the same name, just the last one + for index in self._name_map.values(): + member = self._entries[index] + self._extract(member, os.path.join(utf8(path), member.name)) + + def open(self, member: str) -> io.BytesIO: + filelike = self.extract(member, path=io.BytesIO()) + filelike.name = member.strip('/') + return filelike + + def close(self): + """ + Close this archive and the underlying file. + + No method should be called on the object after this. + """ + if self._file is not None: + self._file.close() + self._file = None + self._entries = None + self._name_map = None + + +def open(file, mode='r'): + """ + Open an archive file. + + :param file: File name to open. + :type file: bytes | unicode + :param mode: Either ''r' or 'w' + :rtype: unix_ar.ArFile + """ + if hasattr(file, 'read'): + return ArFile(file, mode) + else: + if mode == 'r' or mode == 'rb': + omode = 'rb' + elif mode == 'w' or mode == 'wb': + omode = 'wb' + else: + raise ValueError("mode must be one of 'r' or 'w'") + return ArFile(_open(file, omode), mode) diff --git a/LunaDll/CMakeLists.txt b/LunaDll/CMakeLists.txt index ac516fde0..dabe83cbe 100644 --- a/LunaDll/CMakeLists.txt +++ b/LunaDll/CMakeLists.txt @@ -171,6 +171,7 @@ set(LunaLua_Sources Misc/AsyncHTTPClient.cpp Misc/CollisionMatrix.cpp Misc/ErrorReporter.cpp + Misc/FileUtils.cpp Misc/FreeImageUtils/FreeImageData.cpp Misc/FreeImageUtils/FreeImageGifData.cpp Misc/FreeImageUtils/FreeImageHelper.cpp @@ -198,6 +199,7 @@ set(LunaLua_Sources Misc/RuntimeHookManagers/LevelHUDController.cpp Misc/RuntimeHookUtils/APIHook.cpp Misc/SafeFPUControl.cpp + Misc/Syscalls.cpp Misc/TestMode.cpp Misc/TestModeMenu.cpp Misc/TypeLib.cpp diff --git a/LunaDll/FileManager/SaveFile.cpp b/LunaDll/FileManager/SaveFile.cpp index b6df0ebd7..f212240f9 100644 --- a/LunaDll/FileManager/SaveFile.cpp +++ b/LunaDll/FileManager/SaveFile.cpp @@ -107,11 +107,8 @@ void __stdcall SMBXSaveFile::Save() // Write data to file in atomic fashion std::wstring worldPath = SMBX13::Vars::SelectWorld[SMBX13::Vars::selWorld].WorldPath; std::wstring saveFilePath = worldPath + L"save" + std::to_wstring(SMBX13::Vars::selSave) + L".sav"; - LunaPathValidator::Result* ret = LunaPathValidator::GetForThread().CheckPath(WStr2Str(saveFilePath).c_str()); - if (ret && ret->canWrite) - { - writeFileAtomic(saveFilePath, rawStr.c_str(), rawStr.size()); - } + + LunaPathValidator::GetForThread().WriteFileAtomic(saveFilePath, rawStr.c_str(), rawStr.size()); } //void __stdcall SMBXSaveFile::Load() diff --git a/LunaDll/LuaMain/LuaProxy.h b/LunaDll/LuaMain/LuaProxy.h index 9a2af432c..c1be514bb 100644 --- a/LunaDll/LuaMain/LuaProxy.h +++ b/LunaDll/LuaMain/LuaProxy.h @@ -917,6 +917,7 @@ namespace LuaProxy { void warning(const std::string& str); void registerCharacterId(const luabind::object& namedArgs, lua_State* L); std::string showRichDialog(const std::string& title, const std::string& rtfText, bool isReadOnly); + luabind::object __getLuaFileFromCFile(std::uintptr_t address, bool forIoLines, lua_State* L); // Internal use profiler functions void __enablePerfTracker(); diff --git a/LunaDll/LuaMain/LuaProxyComponent/LuaProxyGlobalFunctions/LuaProxyGlobalFuncMisc.cpp b/LunaDll/LuaMain/LuaProxyComponent/LuaProxyGlobalFunctions/LuaProxyGlobalFuncMisc.cpp index 5a31deaec..7960278fb 100644 --- a/LunaDll/LuaMain/LuaProxyComponent/LuaProxyGlobalFunctions/LuaProxyGlobalFuncMisc.cpp +++ b/LunaDll/LuaMain/LuaProxyComponent/LuaProxyGlobalFunctions/LuaProxyGlobalFuncMisc.cpp @@ -8,6 +8,10 @@ #include "../../../Misc/RuntimeHook.h" #include "../../../Misc/Gui/RichTextDialog.h" #include "../../../Misc/PerfTracker.h" +#include "../../../Misc/FileUtils.h" +#include "../../LunaPathValidator.h" +#include "lauxlib.h" +#include void LuaProxy::Misc::npcToCoins() { @@ -70,20 +74,17 @@ void LuaProxy::Misc::cheatBuffer(const luabind::object &value, lua_State* L) } -luabind::object listByAttributes(const std::string& path, DWORD attributes, lua_State* L) +luabind::object listByAttributes(const std::string& path, DWORD attributes, lua_State* L) { luabind::object theList = luabind::newtable(L); - std::string modulePath = path; - if (!isAbsolutePath(path)) - { - modulePath = gAppPathUTF8; - modulePath += "\\"; - modulePath += path; - } - std::vector listedFiles = listOfDir(path, attributes); - for (unsigned int i = 0; i < listedFiles.size(); ++i) { - theList[i + 1] = listedFiles[i]; + std::vector listedFiles; + + if (LunaPathValidator::GetForThread().ListOfDir(path, attributes, listedFiles)) { + for (unsigned int i = 0; i < listedFiles.size(); ++i) { + theList[i + 1] = listedFiles[i]; + } } + return theList; } @@ -322,6 +323,10 @@ std::string LuaProxy::Misc::showRichDialog(const std::string& title, const std:: return dialog.getRtfText(); } +luabind::object LuaProxy::Misc::__getLuaFileFromCFile(std::uintptr_t address, bool forIoLines, lua_State* L) { + return FileUtils::CFileToLua(L, (std::FILE*) address, forIoLines); +} + // Internal use profiler functions void LuaProxy::Misc::__enablePerfTracker() { diff --git a/LunaDll/LuaMain/LuaProxyFFI.cpp b/LunaDll/LuaMain/LuaProxyFFI.cpp index 87cf157bd..ad9c9b4d8 100644 --- a/LunaDll/LuaMain/LuaProxyFFI.cpp +++ b/LunaDll/LuaMain/LuaProxyFFI.cpp @@ -1,4 +1,5 @@ +#include #include #include #include @@ -39,6 +40,14 @@ PlayerMOB* getTemplateForCharacter(int id); // Defined in RuntimeHookNpcHarm.cpp void markNPCTransformationAsHandledByLua(short npcIdx, short oldID, short newID); +extern "C" { + // For exposing strings to lua + struct FFIString { + char const* buf; + std::size_t size; + }; +} + extern "C" { FFI_EXPORT(void*) LunaLuaAlloc(size_t size) { CLunaFFILock ffiLock(__FUNCTION__); @@ -487,29 +496,35 @@ typedef struct ExtendedPlayerFields_\ } ExtendedPlayerFields;"; } - FFI_EXPORT(unsigned int) LunaLuaCollisionMatrixAllocateIndex() + FFI_EXPORT(FFIString*) LunaLuaCollisionMatrixGetGroupFromIndex(unsigned int groupIndex) { - return gCollisionMatrix.allocateIndex(); + std::string const& collisionGroup = gCollisionMatrix.getGroupFromIndex(groupIndex); + + static FFIString collisionGroupFFI; + collisionGroupFFI.buf = collisionGroup.c_str(); + collisionGroupFFI.size = collisionGroup.size(); + + return &collisionGroupFFI; } - FFI_EXPORT(void) LunaLuaCollisionMatrixIncrementReferenceCount(unsigned int group) - { - gCollisionMatrix.incrementReferenceCount(group); + FFI_EXPORT(unsigned int) LunaLuaCollisionMatrixAssignGroup(unsigned int previousGroupIndex, char const* newGroup) { + return gCollisionMatrix.assignGroup(previousGroupIndex, newGroup); } - FFI_EXPORT(void) LunaLuaCollisionMatrixDecrementReferenceCount(unsigned int group) - { - gCollisionMatrix.decrementReferenceCount(group); + FFI_EXPORT(bool) LunaLuaCollisionMatrixGetGroupsCollide(char const* i, char const* j) { + return gCollisionMatrix.getGroupsCollide(i, j); } - FFI_EXPORT(void) LunaLuaGlobalCollisionMatrixSetIndicesCollide(unsigned int first, unsigned int second, bool collide) - { - gCollisionMatrix.setIndicesCollide(first, second, collide); + FFI_EXPORT(bool) LunaLuaCollisionMatrixGetGroupIndexCollidesWithGroup(unsigned int i, char const* j) { + return gCollisionMatrix.getGroupsCollide(i, j); } - FFI_EXPORT(bool) LunaLuaGlobalCollisionMatrixGetIndicesCollide(unsigned int first, unsigned int second) - { - return gCollisionMatrix.getIndicesCollide(first, second); + FFI_EXPORT(bool) LunaLuaCollisionMatrixGetGroupIndicesCollide(unsigned int i, unsigned int j) { + return gCollisionMatrix.getGroupsCollide(i, j); + } + + FFI_EXPORT(void) LunaLuaCollisionMatrixSetGroupsCollide(char const* i, char const* j, bool collides) { + return gCollisionMatrix.setGroupsCollide(i, j, collides); } FFI_EXPORT(void) LunaLuaSetPlayerFilterBounceFix(bool enable) @@ -658,14 +673,14 @@ typedef struct ExtendedPlayerFields_\ CLunaFFILock ffiLock(__FUNCTION__); std::unique_lock lck(readFileMutex); - LunaPathValidator::Result* ptr = LunaPathValidator::GetForThread().CheckPath(path); - if (!ptr) return nullptr; - path = ptr->path; + FILE* f = LunaPathValidator::GetForThread().OpenFile(path, "rb"); + if (!f) return nullptr; - std::wstring wpath = Str2WStr(path); + std::wstring wpath = LunaPathValidator::GetForThread().LastPath(); CachedFileDataWeakPtr>::Entry* cacheEntry = g_lunaFileCache.get(wpath); if (cacheEntry == nullptr) { + fclose(f); return nullptr; } @@ -674,11 +689,6 @@ typedef struct ExtendedPlayerFields_\ if (!data) { // No data, try to read the file - FILE* f = _wfopen(wpath.c_str(), L"rb"); - if (!f) - { - return nullptr; - } fseek(f, 0, SEEK_END); size_t len = ftell(f); rewind(f); @@ -689,9 +699,9 @@ typedef struct ExtendedPlayerFields_\ data->resize(len); fread(&((*data)[0]), 1, len, f); } - fclose(f); cacheEntry->data = data; } + fclose(f); g_lunaFileCacheSet.insert(data); ReadFileStruct* cpy = (ReadFileStruct*)malloc(data->size() + sizeof(int)); @@ -714,10 +724,11 @@ typedef struct ExtendedPlayerFields_\ { CLunaFFILock ffiLock(__FUNCTION__); - LunaPathValidator::Result* ptr = LunaPathValidator::GetForThread().CheckPath(path); - if (!ptr) return false; + FILE* f = LunaPathValidator::GetForThread().OpenFile(path, "r"); + if (!f) return false; + fclose(f); - std::wstring wpath = Str2WStr(path); + std::wstring wpath = LunaPathValidator::GetForThread().LastPath(); return gCachedFileMetadata.exists(wpath); } @@ -725,16 +736,15 @@ typedef struct ExtendedPlayerFields_\ { CLunaFFILock ffiLock(__FUNCTION__); - LunaPathValidator::Result* ptr = LunaPathValidator::GetForThread().CheckPath(path); - if (!ptr) return false; - if (!ptr->canWrite) return false; - path = ptr->path; - // Try to write file - bool ret = writeFileAtomic(path, data, dataLen); + bool ret = LunaPathValidator::GetForThread().WriteFileAtomic(path, data, dataLen); + + if (!ret) { + return false; + } // If successful, update cache, if cached - std::wstring wpath = Str2WStr(path); + std::wstring wpath = LunaPathValidator::GetForThread().LastPath(); CachedFileDataWeakPtr>::Entry* cacheEntry = g_lunaFileCache.get(wpath); if (cacheEntry == nullptr) { @@ -994,14 +1004,23 @@ void CachedReadFile::releaseCached(bool isWorld) g_lunaFileCache.release(isWorld); } + extern "C" { - FFI_EXPORT(LunaPathValidator::Result*) LunaLuaMakeSafeAbsolutePath(const char* path) - { - if (!path) return nullptr; - return LunaPathValidator::GetForThread().CheckPath(path); + FFI_EXPORT(std::uintptr_t) LunaLuaOpenFileSafe(const char* path, const char* mode) { + return (std::uintptr_t) LunaPathValidator::GetForThread().OpenFile(path, mode); + } + + FFI_EXPORT(FFIString const*) LunaLuaGetPathValidatorLastErrorMessage(void) { + std::string const& lastError = LunaPathValidator::GetForThread().ErrorMessage(); + static FFIString lastErrorFFI; + lastErrorFFI.buf = lastError.c_str(); + lastErrorFFI.size = lastError.length(); + + return &lastErrorFFI; } } + extern "C" { FFI_EXPORT(void) LunaLuaSetWeakLava(bool value) { diff --git a/LunaDll/LuaMain/LunaLuaMain.cpp b/LunaDll/LuaMain/LunaLuaMain.cpp index 7fe1f690e..7625180a4 100644 --- a/LunaDll/LuaMain/LunaLuaMain.cpp +++ b/LunaDll/LuaMain/LunaLuaMain.cpp @@ -731,6 +731,7 @@ void CLunaLua::bindAll() // This used to be Level.loadPlayerHitBoxes, but it needs to be in a namespace that's usable from the overworld. def("loadCharacterHitBoxes", (void(*)(int, int, const std::string&))&LuaProxy::loadHitboxes), def("showRichDialog", &LuaProxy::Misc::showRichDialog), + def("__getLuaFileFromCFile", &LuaProxy::Misc::__getLuaFileFromCFile), def("__enablePerfTracker", &LuaProxy::Misc::__enablePerfTracker), def("__disablePerfTracker", &LuaProxy::Misc::__disablePerfTracker), def("__getPerfTrackerData", &LuaProxy::Misc::__getPerfTrackerData), diff --git a/LunaDll/LuaMain/LunaPathValidator.cpp b/LunaDll/LuaMain/LunaPathValidator.cpp index ef048db79..8b8d4e63d 100644 --- a/LunaDll/LuaMain/LunaPathValidator.cpp +++ b/LunaDll/LuaMain/LunaPathValidator.cpp @@ -1,7 +1,114 @@ +#include +#include +#include +#include +#include +#include +#include #include #include "LunaPathValidator.h" #include "../Globals.h" #include "../Misc/LoadScreen.h" +#include "../Misc/FileUtils.h" + +#include "../Misc/Syscalls.h" + +// I know this is incredibly cursed, but this is the only way I can think of to safely list the files of a directory +// Huge props to https://blog.s-schoener.com/2024-06-24-find-files-internals/ for these definitions +typedef struct _FILE_DIRECTORY_INFORMATION +{ + ULONG NextEntryOffset; + ULONG FileIndex; + LARGE_INTEGER CreationTime; + LARGE_INTEGER LastAccessTime; + LARGE_INTEGER LastWriteTime; + LARGE_INTEGER ChangeTime; + LARGE_INTEGER EndOfFile; + LARGE_INTEGER AllocationSize; + ULONG FileAttributes; + ULONG FileNameLength; + WCHAR FileName[1]; +} FILE_DIRECTORY_INFORMATION; + +// Taken from https://learn.microsoft.com/en-us/windows-hardware/drivers/ddi/ntifs/ns-ntifs-_file_rename_information +typedef struct _FILE_RENAME_INFORMATION { +#if (_WIN32_WINNT >= _WIN32_WINNT_WIN10_RS1) + union { + BOOLEAN ReplaceIfExists; // FileRenameInformation + ULONG Flags; // FileRenameInformationEx + } DUMMYUNIONNAME; +#else + BOOLEAN ReplaceIfExists; +#endif + HANDLE RootDirectory; + ULONG FileNameLength; + WCHAR FileName[1]; +} FILE_RENAME_INFORMATION, *PFILE_RENAME_INFORMATION; + +// Construct directory handle from path +static RAIIHandle getDirectoryHandle(wchar_t const* path, DWORD permissions, DWORD& errorCode) { + // Create directory handle + RAIIHandle directoryHandle = CreateFileW( + path, // Folder path + permissions, + FILE_SHARE_DELETE | FILE_SHARE_READ | FILE_SHARE_WRITE, // We're allowing other processes to do anything to the folder + nullptr, // This handle can't be inherited by child processes + OPEN_EXISTING, // We don't want to create a new folder + FILE_FLAG_BACKUP_SEMANTICS, // Neccesary to create a directory handle + nullptr // Useless for opening existing folders + ); + + // Check if the directory handle was successfully created + if (!directoryHandle.isValid()) { + errorCode = GetLastError(); + } + + return directoryHandle; +} + +// Get the path of a file or directory handle after symlink resolution +static DWORD getHandleFinalPath(HANDLE handle, std::wstring& finalPath, bool addTrailingBackslash = false) { + // Get final path length + DWORD pathLength = GetFinalPathNameByHandleW(handle, nullptr, 0, FILE_NAME_NORMALIZED); + + if (pathLength == 0) { + return GetLastError(); + } + + // Allocate string buffer + finalPath.resize(pathLength - 1); + + // Get final path + if (GetFinalPathNameByHandleW(handle, &finalPath[0], pathLength, FILE_NAME_NORMALIZED) == 0) { + return GetLastError(); + } + + // Remove "\\?\" prefix if needed + if (finalPath.rfind(L"\\\\?\\", 0) == 0) { + finalPath = finalPath.substr(4); + } + + // Add a trailing backslash if needed + if (addTrailingBackslash) { + if ((finalPath.size() > 0) && (finalPath[finalPath.size() - 1] != L'\\')) { + finalPath += L"\\"; + } + } + + return ERROR_SUCCESS; +} + +static DWORD getDirectoryFinalPath(wchar_t const* path, std::wstring& finalPath) { + DWORD errorCode; + + RAIIHandle directoryHandle = getDirectoryHandle(path, 0, errorCode); + + if (!directoryHandle.isValid()) { + return errorCode; + } + + return getHandleFinalPath(directoryHandle.borrow(), finalPath, true); +} // Instances LunaPathValidator gLunaPathValidator; @@ -14,7 +121,7 @@ static std::unordered_set naughtyExtensionMap( LunaPathValidator::LunaPathValidator() : - mEnginePath(), mMatchingEnginePath(), mMatchingEpisodePath() + mFinalEnginePath(), mFinalEpisodePath() { } @@ -24,83 +131,780 @@ LunaPathValidator::~LunaPathValidator() void LunaPathValidator::SetPaths() { - mEnginePath = NormalizedPath(gAppPathUTF8); + DWORD errorCode; + std::wstring mMatchingEnginePath; + std::wstring mMatchingEpisodePath; + + std::string mEnginePath = NormalizedPath(gAppPathUTF8); if ((mEnginePath.size() > 0) && (mEnginePath[mEnginePath.size() - 1] != '\\')) { mEnginePath += "\\"; } + mMatchingEnginePath = NormalizedPath(gAppPathWCHAR); if ((mMatchingEnginePath.size() > 0) && (mMatchingEnginePath[mMatchingEnginePath.size() - 1] != L'\\')) { mMatchingEnginePath += L"\\"; } std::transform(mMatchingEnginePath.begin(), mMatchingEnginePath.end(), mMatchingEnginePath.begin(), towlower); + + // Get the final path of the engine folder + if (mMatchingEnginePath.size() > 0) { + errorCode = getDirectoryFinalPath(mMatchingEnginePath.c_str(), mFinalEnginePath); + if (errorCode != ERROR_SUCCESS) { + mFinalEnginePath.resize(0); + } + } + // No need to add a trailing backslash, getDirectoryFinalPath adds it automatically + std::transform(mFinalEnginePath.begin(), mFinalEnginePath.end(), mFinalEnginePath.begin(), towlower); + mMatchingEpisodePath = NormalizedPath(GM_FULLDIR); if ((mMatchingEpisodePath.size() > 0) && (mMatchingEpisodePath[mMatchingEpisodePath.size() - 1] != L'\\')) { mMatchingEpisodePath += L"\\"; } std::transform(mMatchingEpisodePath.begin(), mMatchingEpisodePath.end(), mMatchingEpisodePath.begin(), towlower); + + // Get the final path of the episode folder + if (mMatchingEpisodePath.size() > 0) { + errorCode = getDirectoryFinalPath(mMatchingEpisodePath.c_str(), mFinalEpisodePath); + if (errorCode != ERROR_SUCCESS) { + mFinalEpisodePath.resize(0); + } + } + // No need to add a trailing backslash, getDirectoryFinalPath adds it automatically + std::transform(mFinalEpisodePath.begin(), mFinalEpisodePath.end(), mFinalEpisodePath.begin(), towlower); } -LunaPathValidator::Result* LunaPathValidator::CheckPath(const char* path) -{ +DWORD LunaPathValidator::CheckPath(std::wstring const& pathArg, bool requestWrite) { + // Make path lowercase + std::wstring path = pathArg; + std::transform(path.begin(), path.end(), path.begin(), towlower); + + if ((mFinalEpisodePath.size() == 0) || (mFinalEpisodePath != path.substr(0, mFinalEpisodePath.size()))) { + // If the episode path doesn't match + + if ((mFinalEnginePath.size() > 0) && (mFinalEnginePath == path.substr(0, mFinalEnginePath.size()))) { + // If engine path matches, check for write protection + bool canWrite = ((path.substr(mFinalEnginePath.size(), 5) == L"logs\\") || + (path.substr(mFinalEnginePath.size(), std::wstring::npos) == L"worlds\\mario challenge\\data.json")); + + if (requestWrite && !canWrite) { + return ERROR_WRITE_PROTECT; + } + } else { + // Otherwise, refuse file access + return ERROR_ACCESS_DENIED; + } + } + if (requestWrite) { + // Check file extension if write access is requested + std::wstring fileExt = L""; + std::wstring::size_type pathIdx = path.rfind(L'\\'); + std::wstring::size_type extIdx = path.rfind(L'.'); + if ((extIdx != std::wstring::npos) && ((pathIdx == std::wstring::npos) || (extIdx > pathIdx))) { + std::wstring fileExt = path.substr(extIdx + 1); + if (naughtyExtensionMap.find(fileExt) != naughtyExtensionMap.cend()) { + return ERROR_WRITE_PROTECT; + } + } + } + + return ERROR_SUCCESS; +} + +std::wstring LunaPathValidator::NormalizePath(std::wstring const& path) { + std::wstring wNormalPath; + // Normalize path and make it absolute if necessary - if (( - ((path[0] >= 'A') && (path[0] <= 'Z')) || - ((path[0] >= 'a') && (path[0] <= 'z')) + if (path.size() >= 3 && + ( + ((path[0] >= L'A') && (path[0] <= L'Z')) || + ((path[0] >= L'a') && (path[0] <= L'z')) ) && - (path[1] == ':') && - ((path[2] == '/') || (path[2] == '\\')) + (path[1] == L':') && + ((path[2] == L'/') || (path[2] == L'\\')) ) { // It's an absolute path already - mNormalPath = NormalizedPath(path); + wNormalPath = NormalizedPath(path); } else { // Not absolute path - mNormalPath = NormalizedPath(mEnginePath + path); + wNormalPath = NormalizedPath(mFinalEnginePath + path); } // Get wchar_t version of path for checking what it starts with, so that we use towlower to better handle unicode case insensitivity - std::wstring wNormalPath = Str2WStr(mNormalPath); std::transform(wNormalPath.begin(), wNormalPath.end(), wNormalPath.begin(), towlower); - if ((mMatchingEpisodePath.size() > 0) && (mMatchingEpisodePath == wNormalPath.substr(0, mMatchingEpisodePath.size()))) - { - // If episode path matches - mResult = { mNormalPath.c_str(), mNormalPath.length(), true }; + return wNormalPath; +} + +std::FILE* LunaPathValidator::OpenFile(const char* path, const char* mode) { + if (!path) { + mLastError.type = ErrorType::NULL_PATH; + return nullptr; } - else if ((mMatchingEnginePath.size() > 0) && (mMatchingEnginePath == wNormalPath.substr(0, mMatchingEnginePath.size()))) - { - // If engine path matches - bool canWrite = ((wNormalPath.substr(mMatchingEnginePath.size(), 5) == L"logs\\") || - (wNormalPath.substr(mMatchingEnginePath.size(), std::wstring::npos) == L"worlds\\mario challenge\\data.json")); - mResult = { mNormalPath.c_str(), mNormalPath.length(), canWrite }; + + return OpenFile(std::string(path), mode); +} + +std::FILE* LunaPathValidator::OpenFile(const wchar_t* path, const char* mode) { + if (!path) { + mLastError.type = ErrorType::NULL_PATH; + return nullptr; } - else - { + + return OpenFile(std::wstring(path), mode); +} + +std::FILE* LunaPathValidator::OpenFile(std::string const& path, const char* mode) { + return OpenFile(Str2WStr(path), mode); +} + +std::FILE* LunaPathValidator::OpenFile(std::wstring const& path, const char* mode) { + if (!mode) { + mode = "r"; + } + + FileUtils::FileOpeningMode modeInfo; + + // Parse file opening mode + if (!FileUtils::ParseFileOpeningMode(mode, modeInfo)) { + mLastError.type = ErrorType::MODE_PARSING_ERROR; + mLastError.pathOrMode = mode; + return nullptr; + } + + // Normalize path + std::wstring wNormalPath = NormalizePath(path); + std::wstring wLongPath = L"\\\\?\\"; + wLongPath += wNormalPath; + + // The file we're trying to open + RAIIHandle fileHandle = CreateFileW( + wLongPath.c_str(), // File path + modeInfo.requestWrite ? (GENERIC_READ | GENERIC_WRITE) : GENERIC_READ, // Open in readonly mode unless we request write permissions + FILE_SHARE_DELETE | FILE_SHARE_READ | FILE_SHARE_WRITE, // We're allowing other processes to do anything to the file + nullptr, // This handle can't be inherited by child processes + OPEN_EXISTING, // We're trying to open an existing file without truncating it + FILE_ATTRIBUTE_NORMAL, // No extra attributes + nullptr // Useless for opening existing files/folders + ); + + if (!fileHandle.isValid()) { + // We couldn't open the file, get the last error to understand why + mLastError.errorCode = GetLastError(); + + // If the file must exist or the reason is not ERROR_FILE_NOT_FOUND, return invalid handle + if (modeInfo.fileMustExist || mLastError.errorCode != ERROR_FILE_NOT_FOUND) { + mLastError.type = ErrorType::FILE_OPENING_ERROR; + mLastError.pathOrMode = WStr2Str(wNormalPath); + return nullptr; + } + + // We're trying to create a new file + + // Get the path of the parent folder + std::wstring parentFolder = wNormalPath; + removeFilePathW(parentFolder); + std::wstring longParentFolder = L"\\\\?\\"; + longParentFolder += parentFolder; + + // Get the name of the file to create + const wchar_t* filename = &wNormalPath[wNormalPath.rfind(L"\\") + 1]; + + // Get a handle of the parent folder + RAIIHandle directoryHandle = getDirectoryHandle(longParentFolder.c_str(), FILE_ADD_FILE, mLastError.errorCode); + + // Error if we can't get the handle + if (!directoryHandle.isValid()) { + mLastError.type = ErrorType::DIR_OPENING_ERROR; + mLastError.pathOrMode = WStr2Str(parentFolder); + return nullptr; + } + + // Try to get the final path of the parent folder + std::wstring finalDirectoryPath; + mLastError.errorCode = getHandleFinalPath(directoryHandle.borrow(), finalDirectoryPath, true); + if (mLastError.errorCode != ERROR_SUCCESS) { + mLastError.type = ErrorType::GET_DIR_FINAL_PATH_ERROR; + mLastError.pathOrMode = WStr2Str(parentFolder); + return nullptr; + } + + // Get path of file to create + std::wstring newFilePath = NormalizedPath(finalDirectoryPath + filename); + + // Check if we authorize file creation + mLastError.errorCode = CheckPath(newFilePath, modeInfo.requestWrite); + if (mLastError.errorCode != ERROR_SUCCESS) { + mLastError.type = ErrorType::UNAUTHORIZED_FILE_CREATION; + mLastError.pathOrMode = WStr2Str(newFilePath); + return nullptr; + } + + // Initialize unicode string containing the filename + UNICODE_STRING filenameUnicode; + ntdll::RtlInitUnicodeString(&filenameUnicode, filename); + + // Initialize file creation attributes + OBJECT_ATTRIBUTES attributes; + InitializeObjectAttributes( + &attributes, + &filenameUnicode, // The name of the file to create + OBJ_CASE_INSENSITIVE, // We ignore case for the file existence check + directoryHandle.borrow(), // The handle of the directory where the file is to be created + nullptr + ); + + // IO status block object written by NtCreateFile + IO_STATUS_BLOCK ioStatusBlock; + + // Create the file + NTSTATUS status = ntdll::NtCreateFile( + &fileHandle.getHandleRef(), // Where to write the handle + SYNCHRONIZE | FILE_READ_ATTRIBUTES | (modeInfo.requestWrite ? (FILE_GENERIC_READ | FILE_GENERIC_WRITE) : FILE_GENERIC_READ), // Open in readonly mode unless we request write permissions + &attributes, + &ioStatusBlock, + nullptr, // We don't care about setting an initial allocation size + FILE_ATTRIBUTE_NORMAL, // We're creating a normal file + FILE_SHARE_READ | FILE_SHARE_WRITE | FILE_SHARE_DELETE, // We're allowing other processes to do anything to the file + FILE_CREATE, // Error if the file already exists + FILE_NON_DIRECTORY_FILE| FILE_SYNCHRONOUS_IO_NONALERT, // We're not creating a directory, we open the file in synchronous IO mode + nullptr, // I don't even know what's an EA buffer + 0 + ); + + if (status != STATUS_SUCCESS) { + mLastError.errorCode = ntdll::RtlNtStatusToDosError(status); + mLastError.type = ErrorType::FILE_CREATION_ERROR; + mLastError.pathOrMode = WStr2Str(newFilePath); + return nullptr; + } + + UpdateLastPath(newFilePath); + } else { + // Get final path of opened file + std::wstring finalFilePath; + mLastError.errorCode = getHandleFinalPath(fileHandle.borrow(), finalFilePath); + if (mLastError.errorCode != ERROR_SUCCESS) { + mLastError.type = ErrorType::GET_FILE_FINAL_PATH_ERROR; + mLastError.pathOrMode = WStr2Str(wNormalPath); + return nullptr; + } + + // Check if we authorize file access + mLastError.errorCode = CheckPath(finalFilePath, modeInfo.requestWrite); + if (mLastError.errorCode != ERROR_SUCCESS) { + mLastError.type = ErrorType::UNAUTHORIZED_FILE_ACCESS; + mLastError.pathOrMode = WStr2Str(finalFilePath); + return nullptr; + } + + UpdateLastPath(finalFilePath); + } + + // Convert file handle to file descriptor + HANDLE rawFileHandle = fileHandle.takeOwnership(); + int fd = _open_osfhandle((std::intptr_t) rawFileHandle, modeInfo.flags); + if (fd == -1) { + CloseHandle(rawFileHandle); + mLastError.type = ErrorType::DESCRIPTOR_CREATION_ERROR; return nullptr; } + + // Convert file descriptor to C file object + FILE* fileObject = _fdopen(fd, mode); + if (!fileObject) { + mLastError.type = ErrorType::FILE_OBJECT_CREATION_ERROR; + mLastError.errorCode = _doserrno; + _close(fd); + } + + return fileObject; +} + +HANDLE LunaPathValidator::CreateTempFile(HANDLE parentFolder) { + // RNG for generating temp file names + static thread_local std::mt19937 rng(GetTickCount()); - if (mResult.canWrite) + HANDLE tmpHwnd = INVALID_HANDLE_VALUE; + for (uint32_t i=0; (i<=0xFFFF) && (tmpHwnd == INVALID_HANDLE_VALUE); i++) { - // Check extension - std::wstring fileExt = L""; - std::wstring::size_type pathIdx = wNormalPath.rfind(L'\\'); - std::wstring::size_type extIdx = wNormalPath.rfind(L'.'); - if ( - (extIdx != std::wstring::npos) && - ((pathIdx == std::wstring::npos) || (extIdx > pathIdx)) - ) + static const wchar_t* digits = L"0123456789ABCDEFGHIJKLMNOPQRSTUV"; + std::wstring tmpPath = L"."; + uint32_t randomNumber = rng(); + for (int j = 0; j < 16; j += 5) { - std::wstring fileExt = wNormalPath.substr(extIdx + 1); - mResult.canWrite = (naughtyExtensionMap.find(fileExt) == naughtyExtensionMap.cend()); + tmpPath += digits[(randomNumber >> j) & 0xF]; + } + tmpPath += L".TMP"; + + // Initialize unicode string containing the filename + UNICODE_STRING filenameUnicode; + ntdll::RtlInitUnicodeString(&filenameUnicode, tmpPath.c_str()); + + // Initialize file creation attributes + OBJECT_ATTRIBUTES attributes; + InitializeObjectAttributes( + &attributes, + &filenameUnicode, // The name of the file to create + OBJ_CASE_INSENSITIVE, // We ignore case for the file existence check + parentFolder, // The handle of the directory where the file is to be created + nullptr + ); + + // IO status block object written by NtCreateFile + IO_STATUS_BLOCK ioStatusBlock; + + // Create the file + NTSTATUS status = ntdll::NtCreateFile( + &tmpHwnd, // Where to write the handle + FILE_GENERIC_WRITE | DELETE | SYNCHRONIZE | FILE_READ_ATTRIBUTES, // We can write to the file or delete it + &attributes, + &ioStatusBlock, + nullptr, // We don't care about setting an initial allocation size + FILE_ATTRIBUTE_NORMAL, // We're creating a normal file + 0, // We're not allowing other processes to do anything to the file + FILE_CREATE, // Error if the file already exists + FILE_NON_DIRECTORY_FILE | FILE_SYNCHRONOUS_IO_NONALERT, // We're not creating a directory, we open the file in synchronous IO mode + nullptr, // I don't even know what's an EA buffer + 0 + ); + + if (tmpHwnd == INVALID_HANDLE_VALUE) + { + // No success + mLastError.errorCode = GetLastError(); + if (mLastError.errorCode == ERROR_FILE_EXISTS) + { + // File exists? Retry + continue; + } + else + { + // Other failure, abort + return INVALID_HANDLE_VALUE; + } + } + } + + if (tmpHwnd == INVALID_HANDLE_VALUE) + { + // Something very wrong... even 0xFFFF retries got "ERROR_FILE_EXISTS" + return INVALID_HANDLE_VALUE; + } + return tmpHwnd; +} + +// Mark a file handle for deletion, ignore errors +void MarkForDeletion(HANDLE handle) { + FILE_DISPOSITION_INFO deletionInfo; + deletionInfo.DeleteFileW = true; + + SetFileInformationByHandle( + handle, + FileDispositionInfo, + &deletionInfo, + sizeof(deletionInfo) + ); +} + +bool LunaPathValidator::WriteFileAtomic(const char* path, const void* data, ptrdiff_t dataLen) { + if (!path) { + mLastError.type = ErrorType::NULL_PATH; + return false; + } + + return WriteFileAtomic(std::string(path), data, dataLen); +} + +bool LunaPathValidator::WriteFileAtomic(const wchar_t* path, const void* data, ptrdiff_t dataLen) { + if (!path) { + mLastError.type = ErrorType::NULL_PATH; + return false; + } + + return WriteFileAtomic(std::wstring(path), data, dataLen); +} + +bool LunaPathValidator::WriteFileAtomic(std::string const& path, const void* data, ptrdiff_t dataLen) { + return WriteFileAtomic(Str2WStr(path), data, dataLen); +} + +bool LunaPathValidator::WriteFileAtomic(std::wstring const& path, const void* data, ptrdiff_t dataLen) { + // Normalize path + std::wstring wNormalPath = NormalizePath(path); + + // Get the path of the parent folder + std::wstring parentFolder = wNormalPath; + removeFilePathW(parentFolder); + std::wstring longParentFolder = L"\\\\?\\"; + longParentFolder += parentFolder; + + // Get the name of the file to create + const wchar_t* filename = &wNormalPath[wNormalPath.rfind(L"\\") + 1]; + std::size_t filenameLength = std::wcslen(filename); + + // Get a handle of the parent folder + RAIIHandle directoryHandle = getDirectoryHandle(longParentFolder.c_str(), FILE_TRAVERSE | FILE_READ_ATTRIBUTES, mLastError.errorCode); + + // Error if we can't get the handle + if (!directoryHandle.isValid()) { + mLastError.type = ErrorType::DIR_OPENING_ERROR; + mLastError.pathOrMode = WStr2Str(parentFolder); + return false; + } + + // Try to get the final path of the parent folder + std::wstring finalDirectoryPath; + mLastError.errorCode = getHandleFinalPath(directoryHandle.borrow(), finalDirectoryPath, true); + if (mLastError.errorCode != ERROR_SUCCESS) { + mLastError.type = ErrorType::GET_DIR_FINAL_PATH_ERROR; + mLastError.pathOrMode = WStr2Str(parentFolder); + return false; + } + + // Check if we authorize writing to the target file + // We don't care about checking for symlinks since SetFileInformationByHandle will overwrite them anyways + std::wstring actualFilePath = NormalizedPath(finalDirectoryPath + filename); + mLastError.errorCode = CheckPath(actualFilePath, true); + if (mLastError.errorCode != ERROR_SUCCESS) { + mLastError.type = ErrorType::UNAUTHORIZED_FILE_ACCESS; + mLastError.pathOrMode = WStr2Str(actualFilePath); + return false; + } + + // Check if we authorize temp file creation + mLastError.errorCode = CheckPath(finalDirectoryPath, true); + if (mLastError.errorCode != ERROR_SUCCESS) { + mLastError.type = ErrorType::UNAUTHORIZED_TEMP_FILE_CREATION; + mLastError.pathOrMode = WStr2Str(parentFolder); + return false; + } + + // Create temp file + RAIIHandle tempFile = CreateTempFile(directoryHandle.borrow()); + + if (!tempFile.isValid()) { + // mLastError.errorCode already set by CreateTempFile + mLastError.type = ErrorType::TEMP_FILE_CREATION_ERROR; + mLastError.pathOrMode = WStr2Str(parentFolder); + return false; + } + + // Write data to temp file + DWORD bytesWritten = 0; + if (WriteFile(tempFile.borrow(), data, dataLen, &bytesWritten, NULL) == 0) { + // Write failed + mLastError.type = ErrorType::TEMP_FILE_WRITE_ERROR; + mLastError.errorCode = GetLastError(); + + // Mark temp file for deletion + MarkForDeletion(tempFile.borrow()); + return false; + } + + // Not enough bytes written + if (bytesWritten != dataLen) { + mLastError.type = ErrorType::TEMP_FILE_INCOMPLETE_WRITE; + + // Mark temp file for deletion + MarkForDeletion(tempFile.borrow()); + return false; + } + + // IO status block object written by NtQueryDirectoryFile + IO_STATUS_BLOCK ioStatusBlock; + + // Actually replace the target file + std::size_t renameInfoSize = sizeof(FILE_RENAME_INFORMATION) + sizeof(wchar_t) * filenameLength; + std::unique_ptr renameInfo((PFILE_RENAME_INFORMATION) std::malloc(renameInfoSize), std::free); + renameInfo->ReplaceIfExists = true; // Replace the target file if it exists + renameInfo->RootDirectory = directoryHandle.borrow(); // Directory where to move the file + renameInfo->FileNameLength = filenameLength * sizeof(wchar_t); + std::wcscpy(renameInfo->FileName, filename); + NTSTATUS status = ntdll::NtSetInformationFile( + tempFile.borrow(), + &ioStatusBlock, + renameInfo.get(), + renameInfoSize, + (FILE_INFORMATION_CLASS) 10 // FileRenameInformation + ); + + if (status != STATUS_SUCCESS) { + // Unsuccessful replace + mLastError.type = ErrorType::FILE_REPLACE_ERROR; + mLastError.errorCode = ntdll::RtlNtStatusToDosError(status); + mLastError.pathOrMode = WStr2Str(actualFilePath); + + // Mark temp file for deletion + MarkForDeletion(tempFile.borrow()); + return false; + } + + UpdateLastPath(actualFilePath); + + // Everything worked as intended! + return true; +} + +bool LunaPathValidator::ListOfDir(const char* path, DWORD attributes, std::vector& outputList) { + if (!path) { + mLastError.type = ErrorType::NULL_PATH; + return false; + } + + return ListOfDir(std::string(path), attributes, outputList); +} + +bool LunaPathValidator::ListOfDir(const wchar_t* path, DWORD attributes, std::vector& outputList) { + if (!path) { + mLastError.type = ErrorType::NULL_PATH; + return false; + } + + return ListOfDir(std::wstring(path), attributes, outputList); +} + +bool LunaPathValidator::ListOfDir(std::string const& path, DWORD attributes, std::vector& outputList) { + return ListOfDir(Str2WStr(path), attributes, outputList); +} + +bool LunaPathValidator::ListOfDir(std::wstring const& path, DWORD attributes, std::vector& outputList) { + // Normalize path + std::wstring wNormalPath = NormalizePath(path); + std::wstring wLongPath = L"\\\\?\\"; + wLongPath += wNormalPath; + + // Get directory handle + RAIIHandle directory = getDirectoryHandle(wLongPath.c_str(), FILE_LIST_DIRECTORY, mLastError.errorCode); + if (!directory.isValid()) { + mLastError.type = ErrorType::DIR_OPENING_ERROR; + mLastError.pathOrMode = WStr2Str(wNormalPath); + return false; + } + + // Get directory final path + std::wstring finalDirectoryPath; + mLastError.errorCode = getHandleFinalPath(directory.borrow(), finalDirectoryPath, true); + if (mLastError.errorCode != ERROR_SUCCESS) { + mLastError.type = ErrorType::GET_DIR_FINAL_PATH_ERROR; + mLastError.pathOrMode = WStr2Str(wNormalPath); + return false; + } + + // Check if we're allowed to get the contents of the directory + mLastError.errorCode = CheckPath(finalDirectoryPath, false); + if (mLastError.errorCode != ERROR_SUCCESS) { + mLastError.type = ErrorType::UNAUTHORIZED_FOLDER_LIST; + mLastError.pathOrMode = WStr2Str(finalDirectoryPath); + return false; + } + + // IO status block object written by NtQueryDirectoryFile + IO_STATUS_BLOCK ioStatusBlock; + + // File data written to by NtQueryDirectoryFile + constexpr std::size_t fileDataSize = 1024 * 64; + std::unique_ptr fileData((FILE_DIRECTORY_INFORMATION*) malloc(fileDataSize), std::free); + + // Get first entry in directory + NTSTATUS status = ntdll::NtQueryDirectoryFile( + directory.borrow(), 0, nullptr, nullptr, + &ioStatusBlock, fileData.get(), fileDataSize, + FileDirectoryInformation, + false, // Return as many entries as possible + nullptr, + true // Restart scan + ); + + while (status != STATUS_NO_MORE_FILES) { + // Error if we encounter an error + if (status != STATUS_SUCCESS) { + mLastError.errorCode = ntdll::RtlNtStatusToDosError(status); + mLastError.type = ErrorType::FOLDER_LIST_ERROR; + mLastError.pathOrMode = WStr2Str(finalDirectoryPath); + return false; + } + + // Iterate all found files for current iteration + FILE_DIRECTORY_INFORMATION* currentFile = fileData.get(); + while (true) { + // Check if the file has the correct attributes + if ((currentFile->FileAttributes & attributes) != 0) { + // Get filename + std::size_t filenameSize = currentFile->FileNameLength / sizeof(wchar_t); + std::wstring filename(currentFile->FileName, filenameSize); + + // Add filename to return list if it's not . or .. and it has the correct attributes + if (filename != L"." && filename != L"..") { + outputList.push_back(WStr2Str(filename)); + } + } + + if (currentFile->NextEntryOffset == 0) { + break; + } + + currentFile = (FILE_DIRECTORY_INFORMATION*) (((std::uintptr_t) currentFile) + currentFile->NextEntryOffset); + } + + // Get next entry in directory + status = ntdll::NtQueryDirectoryFile( + directory.borrow(), 0, nullptr, nullptr, + &ioStatusBlock, fileData.get(), fileDataSize, + FileDirectoryInformation, + false, // Return as many entries as possible + nullptr, + false // Don't restart scan + ); + } + + // Success! + mLastSuccessfulPath = finalDirectoryPath; + return true; +} + +void LunaPathValidator::UpdateLastPath(std::wstring const& path) { + mLastSuccessfulPath = path; + std::transform(mLastSuccessfulPath.begin(), mLastSuccessfulPath.end(), mLastSuccessfulPath.begin(), towlower); +} + +std::wstring const& LunaPathValidator::LastPath() { + return mLastSuccessfulPath; +} + +LunaPathValidator::Error const& LunaPathValidator::LastError() { + return mLastError; +} + +std::string LunaPathValidator::ErrorMessage() { + std::string errorMessage; + switch (mLastError.type) { + case ErrorType::NULL_PATH: + errorMessage += "Path is null"; + break; + + case ErrorType::MODE_PARSING_ERROR: + errorMessage += mLastError.pathOrMode; + errorMessage += " is not a valid file opening mode"; + break; + case ErrorType::FILE_OPENING_ERROR: + errorMessage += "Couldn't open handle for file "; + errorMessage += mLastError.pathOrMode; + break; + + case ErrorType::GET_FILE_FINAL_PATH_ERROR: + errorMessage += "Couldn't get final path of file "; + errorMessage += mLastError.pathOrMode; + break; + + case ErrorType::UNAUTHORIZED_FILE_ACCESS: + errorMessage += "Access to "; + errorMessage += mLastError.pathOrMode; + errorMessage += " is unauthorized"; + break; + + case ErrorType::DIR_OPENING_ERROR: + errorMessage += "Couldn't open handle for directory "; + errorMessage += mLastError.pathOrMode; + break; + + case ErrorType::GET_DIR_FINAL_PATH_ERROR: + errorMessage += "Couldn't get final path of directory "; + errorMessage += mLastError.pathOrMode; + break; + + case ErrorType::UNAUTHORIZED_FILE_CREATION: + errorMessage += "Creating file "; + errorMessage += mLastError.pathOrMode; + errorMessage += " is unauthorized"; + break; + + case ErrorType::FILE_CREATION_ERROR: + errorMessage += "Couldn't create file "; + errorMessage += mLastError.pathOrMode; + break; + + case ErrorType::DESCRIPTOR_CREATION_ERROR: + errorMessage += "Couldn't convert handle into file descriptor"; + break; + + case ErrorType::FILE_OBJECT_CREATION_ERROR: + errorMessage += "Couldn't convert file descriptor into C file object"; + break; + + case ErrorType::UNAUTHORIZED_TEMP_FILE_CREATION: + errorMessage += "Creating temporary file in directory "; + errorMessage += mLastError.pathOrMode; + errorMessage += " is unauthorized"; + break; + + case ErrorType::TEMP_FILE_CREATION_ERROR: + errorMessage += "Couldn't create temporary file in directory "; + errorMessage += mLastError.pathOrMode; + break; + + case ErrorType::TEMP_FILE_WRITE_ERROR: + errorMessage += "Couldn't write data to temporary file"; + break; + + case ErrorType::TEMP_FILE_INCOMPLETE_WRITE: + errorMessage += "Write operation to temporary file was incomplete"; + break; + + case ErrorType::FILE_REPLACE_ERROR: + errorMessage += "Couldn't replace "; + errorMessage += mLastError.pathOrMode; + errorMessage += " by temporary file"; + break; + + case ErrorType::UNAUTHORIZED_FOLDER_LIST: + errorMessage += "Listing the contents of directory "; + errorMessage += mLastError.pathOrMode; + errorMessage += " is unauthorized"; + break; + + case ErrorType::FOLDER_LIST_ERROR: + errorMessage += "Error while listing the contents of directory "; + errorMessage += mLastError.pathOrMode; + break; + }; + + if ( + mLastError.type != ErrorType::MODE_PARSING_ERROR && + mLastError.type != ErrorType::DESCRIPTOR_CREATION_ERROR && + mLastError.type != ErrorType::NULL_PATH && + mLastError.type != ErrorType::TEMP_FILE_INCOMPLETE_WRITE + ) { + // mLastError.errorCode is a win32 error code + wchar_t* win32ErrorString; + + if (FormatMessageW( + FORMAT_MESSAGE_FROM_SYSTEM | FORMAT_MESSAGE_ALLOCATE_BUFFER | FORMAT_MESSAGE_IGNORE_INSERTS | FORMAT_MESSAGE_MAX_WIDTH_MASK, + nullptr, + mLastError.errorCode, + MAKELANGID(LANG_ENGLISH, SUBLANG_DEFAULT), + (wchar_t*) &win32ErrorString, + 0, + nullptr + ) > 0) { + errorMessage += " ("; + errorMessage += WStr2Str(win32ErrorString); + errorMessage += ')'; + + LocalFree(win32ErrorString); + } else { + errorMessage += " (Win32 error code "; + errorMessage += std::to_string(mLastError.errorCode); + errorMessage += ')'; } } - return &mResult; + return errorMessage; } LunaPathValidator& LunaPathValidator::GetForThread() diff --git a/LunaDll/LuaMain/LunaPathValidator.h b/LunaDll/LuaMain/LunaPathValidator.h index 632fdf40b..9b8ca6184 100644 --- a/LunaDll/LuaMain/LunaPathValidator.h +++ b/LunaDll/LuaMain/LunaPathValidator.h @@ -4,34 +4,82 @@ #include #include #include "../GlobalFuncs.h" +#include "../Misc/RAIIHandle.h" class LunaPathValidator { public: - struct Result - { - const char* path; - unsigned int len; - bool canWrite; - Result() : - path(nullptr), len(0), canWrite(false) - {} - Result(const char* _path, unsigned int _len, bool _canWrite) : - path(_path), len(_len), canWrite(_canWrite) - {} + enum class ErrorType { + NULL_PATH, + MODE_PARSING_ERROR, + FILE_OPENING_ERROR, + GET_FILE_FINAL_PATH_ERROR, + UNAUTHORIZED_FILE_ACCESS, + DIR_OPENING_ERROR, + GET_DIR_FINAL_PATH_ERROR, + UNAUTHORIZED_FILE_CREATION, + FILE_CREATION_ERROR, + DESCRIPTOR_CREATION_ERROR, + FILE_OBJECT_CREATION_ERROR, + UNAUTHORIZED_TEMP_FILE_CREATION, + TEMP_FILE_CREATION_ERROR, + TEMP_FILE_WRITE_ERROR, + TEMP_FILE_INCOMPLETE_WRITE, + FILE_REPLACE_ERROR, + UNAUTHORIZED_FOLDER_LIST, + FOLDER_LIST_ERROR, + }; + + struct Error { + ErrorType type; + DWORD errorCode; + std::string pathOrMode; }; private: - std::string mEnginePath; - std::wstring mMatchingEnginePath; - std::wstring mMatchingEpisodePath; + std::wstring mFinalEnginePath; + std::wstring mFinalEpisodePath; + + Error mLastError; - std::string mNormalPath; - Result mResult; + std::wstring mLastSuccessfulPath; + + DWORD CheckPath(std::wstring const& path, bool requestWrite); // ERROR_SUCCESS ERROR_ACCESS_DENIED ERROR_WRITE_PROTECT + std::wstring NormalizePath(std::wstring const& path); + HANDLE CreateTempFile(HANDLE parentFolder); + void UpdateLastPath(std::wstring const& path); public: LunaPathValidator(); ~LunaPathValidator(); + + // Initialize engine and episode paths void SetPaths(); - Result* CheckPath(const char* path); + + // Safely opens a file + std::FILE* OpenFile(const char* path, const char* mode); + std::FILE* OpenFile(const wchar_t* path, const char* mode); + std::FILE* OpenFile(std::string const& path, const char* mode); + std::FILE* OpenFile(std::wstring const& path, const char* mode); + + // Function to safely write data to a file, making repalcement as close to atomic as Windows seems to allow + bool WriteFileAtomic(const char* path, const void* data, ptrdiff_t dataLen); + bool WriteFileAtomic(const wchar_t* path, const void* data, ptrdiff_t dataLen); + bool WriteFileAtomic(std::string const& path, const void* data, ptrdiff_t dataLen); + bool WriteFileAtomic(std::wstring const& path, const void* data, ptrdiff_t dataLen); + + // List the contents of a directory in a safe way + bool ListOfDir(const char* path, DWORD attributes, std::vector& outputList); + bool ListOfDir(const wchar_t* path, DWORD attributes, std::vector& outputList); + bool ListOfDir(std::string const& path, DWORD attributes, std::vector& outputList); + bool ListOfDir(std::wstring const& path, DWORD attributes, std::vector& outputList); + + // If the previous call to OpenFile, WriteFileAtomic or ListOfDir succeeded, return the resolved path of the file, otherwise, return an unspecified value + std::wstring const& LastPath(); + + // If the previous call to OpenFile, WriteFileAtomic or ListOfDir failed, return the last error, otherwise, return an unspecified value + Error const& LastError(); + + // If the previous call to OpenFile, WriteFileAtomic or ListOfDir failed, return the last error message, otherwise, return an unspecified value + std::string ErrorMessage(); public: static LunaPathValidator& GetForThread(); }; diff --git a/LunaDll/LunaDll.vcxproj b/LunaDll/LunaDll.vcxproj index 96d4ae12a..bb2a65bbe 100644 --- a/LunaDll/LunaDll.vcxproj +++ b/LunaDll/LunaDll.vcxproj @@ -327,6 +327,10 @@ + + + + @@ -566,6 +570,8 @@ + + @@ -587,4 +593,4 @@ - \ No newline at end of file + diff --git a/LunaDll/LunaDll.vcxproj.filters b/LunaDll/LunaDll.vcxproj.filters index 0b0f61302..d4652a5b9 100644 --- a/LunaDll/LunaDll.vcxproj.filters +++ b/LunaDll/LunaDll.vcxproj.filters @@ -748,6 +748,18 @@ FileManager + + Misc + + + Misc + + + Misc + + + Misc + @@ -1502,6 +1514,12 @@ FileManager + + Misc + + + Misc + @@ -1523,4 +1541,4 @@ - \ No newline at end of file + diff --git a/LunaDll/Misc/CollisionMatrix.cpp b/LunaDll/Misc/CollisionMatrix.cpp index 8770ba306..2204bb3eb 100644 --- a/LunaDll/Misc/CollisionMatrix.cpp +++ b/LunaDll/Misc/CollisionMatrix.cpp @@ -1,19 +1,144 @@ #include "CollisionMatrix.h" #include -#include -#include -#include +#include #include #include "../../Globals.h" -CollisionMatrix gCollisionMatrix("onGroupDeallocationInternal"); +CollisionMatrix gCollisionMatrix; -CollisionMatrix::CollisionMatrix(char const* ev) : +// === PRIVATE METHODS === + +CollisionMatrix::GroupOrIndex::GroupOrIndex(unsigned int i): index(i), is_allocated(true) {} + +CollisionMatrix::GroupOrIndex::GroupOrIndex(char const* n): name(n), is_allocated(false) {} + +bool CollisionMatrix::GroupOrIndex::isAllocated() { + return is_allocated; +} + +unsigned int CollisionMatrix::GroupOrIndex::getIndex() { + return index; +} + +char const* CollisionMatrix::GroupOrIndex::getGroup() { + return name; +} + +void CollisionMatrix::cleanupMatrix() { + // Search last nonempty group index + unsigned int last_nonempty_group = matrix.size() - 1; + for (; last_nonempty_group < matrix.size() && matrix[last_nonempty_group].empty(); last_nonempty_group--) {} + + // Resize collision matrix + matrix.resize(last_nonempty_group + 1); +} + +unsigned int CollisionMatrix::allocateGroupIndex(char const* collisionGroup) { + //printBoxA("Entering allocateGroupIndex(collisionGroup = '%s')", collisionGroup); + unsigned int new_group; + + if (deallocated_groups.empty()) { // If there's no available deallocated group index, create a new one. + //printBoxA("No available unallocated index"); + new_group = next_group; + next_group++; + + reference_count.push_back(0u); + index_to_string.push_back(collisionGroup); + + //printBoxA("index_to_string.size() = %u, new_group = %u", index_to_string.size(), new_group); + } else { // Otherwise, take one from the set. + //printBoxA("Unallocated index available"); + new_group = deallocated_groups.top(); + deallocated_groups.pop(); + + // No need to set the reference count to 0 since a deallocated group index always has a reference count of 0. + index_to_string[new_group - 1] = collisionGroup; + } + //printBoxA("string_to_index['%s'] = %u", index_to_string[new_group - 1].c_str(), new_group); + string_to_index[index_to_string[new_group - 1]] = new_group; + + //printBoxA("Allocated index %u for group '%s'", new_group, collisionGroup); + + return new_group; +} + +CollisionMatrix::GroupOrIndex CollisionMatrix::tryGetGroupIndex(char const* collisionGroup) { + //printBoxA("Entering tryGetGroupIndex(collisionGroup = '%s')", collisionGroup); + if (*collisionGroup == '\0') { // Check for empty string + //printBoxA("Empty string, return index 0"); + return GroupOrIndex(0u); + } + + // Try to search for an already allocated index + auto foundIndex = string_to_index.find(collisionGroup); + + // Return it if it was found + if (foundIndex != string_to_index.end()) { + //printBoxA("Found index %u associated to '%s' in string_to_index", foundIndex->second, collisionGroup); + return GroupOrIndex(foundIndex->second); + } + + //printBoxA("Didn't find any index associated to '%s' in string_to_index", collisionGroup); + // Return collision group + return GroupOrIndex(collisionGroup); +} + +unsigned int CollisionMatrix::getOrAllocateGroupIndex(char const* collisionGroup) { + GroupOrIndex resolved = tryGetGroupIndex(collisionGroup); + + if (resolved.isAllocated()) { + return resolved.getIndex(); + } else { + return allocateGroupIndex(resolved.getGroup()); + } +} + +void CollisionMatrix::deallocateGroupIndex(unsigned int groupIndex) { + // Put group index in deallocated group indices + deallocated_groups.push(groupIndex); + + // Remove name <=> index mapping for this group + string_to_index.erase(index_to_string[groupIndex - 1]); + index_to_string[groupIndex - 1].clear(); +} + +void CollisionMatrix::incrementReferenceCount(unsigned int groupIndex) { + if (groupIndex != 0) { // We don't use reference counting for the default group index + reference_count[groupIndex - 1]++; + } +} + +bool CollisionMatrix::defaultBehavior(GroupOrIndex i, GroupOrIndex j) const { + if (i.isAllocated() && j.isAllocated()) { + // Both groups are allocated, perform the check on the collision group indices + return defaultBehavior(i.getIndex(), j.getIndex()); + } else if (!i.isAllocated() && !j.isAllocated()) { + // No groups are allocated, perform the check on the collision group strings + return defaultBehavior(i.getGroup(), j.getGroup()); + } else { + // Only one of the two collision groups are allocated. Therefore, i and j must be different. + // Two distinct collision groups always collide with eachother by default. + return true; + } +} + +bool CollisionMatrix::defaultBehavior(unsigned int i, unsigned int j) const { + return i == 0 || i != j; +} + +bool CollisionMatrix::defaultBehavior(char const* i, char const* j) const { + return *i == '\0' || std::strcmp(i, j) != 0; +} + +// === PUBLIC METHODS === + +CollisionMatrix::CollisionMatrix() : matrix(), deallocated_groups(), next_group(1), reference_count(), - deallocation_event_name(ev) + string_to_index(), + index_to_string() {} void CollisionMatrix::clear() { @@ -28,44 +153,86 @@ void CollisionMatrix::clear() { // Clear reference count vector reference_count.clear(); -} - -unsigned int CollisionMatrix::allocateIndex() { - unsigned int new_group; - if (deallocated_groups.empty()) { // If there's no available deallocated group index, create a new one. - new_group = next_group; - next_group++; + // Clear name <=> index mapping + string_to_index.clear(); + index_to_string.clear(); +} - reference_count.push_back(0u); - } else { // Otherwise, take one from the set. - new_group = deallocated_groups.top(); - deallocated_groups.pop(); +std::string const& CollisionMatrix::getGroupFromIndex(unsigned int groupIndex) { + // Empty string + static const std::string empty; - // No need to set the reference count to 0 since a deallocated group index always has a reference count of 0. + if (groupIndex == 0) { + return empty; } - - return new_group; + return index_to_string[groupIndex - 1]; } -void CollisionMatrix::incrementReferenceCount(unsigned int group) { - if (group != 0) { // We don't use reference counting for the default group index - reference_count[group - 1]++; +unsigned int CollisionMatrix::assignGroup(unsigned int previousGroupIndex, char const* newGroup) { + //printBoxA("Entering assignGroup(previousGroupIndex = %u, newGroup = '%s')", previousGroupIndex, newGroup); + unsigned int newIndex = getOrAllocateGroupIndex(newGroup); + //printBoxA("Index for '%s' is %u", newGroup, newIndex); + + if (newIndex != previousGroupIndex) { + incrementReferenceCount(newIndex); + decrementReferenceCount(previousGroupIndex); } + return newIndex; } -void CollisionMatrix::decrementReferenceCount(unsigned int group) { - if (group != 0) { // We don't use reference counting for the default group index - reference_count[group - 1]--; +void CollisionMatrix::decrementReferenceCount(unsigned int groupIndex) { + if (groupIndex != 0) { // We don't use reference counting for the default group index + reference_count[groupIndex - 1]--; - if (reference_count[group - 1] == 0) { - deallocateGroup(group); + if (reference_count[groupIndex - 1] == 0) { + deallocateGroupIndex(groupIndex); } } } +bool CollisionMatrix::getGroupsCollide(char const* i, char const* j) { + //printBoxA("Entering getGroupsCollide(i = '%s', j = '%s')", i, j); + + // Try to get the collision group indices + GroupOrIndex iResolved = tryGetGroupIndex(i); + GroupOrIndex jResolved = tryGetGroupIndex(j); -bool CollisionMatrix::getIndicesCollide(unsigned int i, unsigned int j) { + //printBoxA("Resolved collision groups: '%s' -> %u, '%s' -> %u", i, iResolved.getGroup(), j, jResolved.getGroup()); + + if (iResolved.isAllocated() && jResolved.isAllocated()) { + //printBoxA("Two groups allocated, int check"); + // Both groups are allocated, perform the check on the collision group indices + return getGroupsCollide(iResolved.getIndex(), jResolved.getIndex()); + } else if (!iResolved.isAllocated() && !jResolved.isAllocated()) { + //printBoxA("Zero groups allocated, int check"); + // No groups are allocated, fallback to default behavior + return defaultBehavior(iResolved.getGroup(), jResolved.getGroup()); + } else { + //printBoxA("One group allocated, return true"); + // Only one of the two collision groups are allocated. Therefore, i and j must be different. + // Furthermore (i, j) is not in the collision matrix, so their collision behavior must be the default one. + // Two distinct collision groups always collide with eachother by default. + return true; + } +} + +bool CollisionMatrix::getGroupsCollide(unsigned int i, char const* j) { + // Try to get the collision group index of j + GroupOrIndex jResolved = tryGetGroupIndex(j); + + if (jResolved.isAllocated()) { + // Both groups are allocated, perform the check on the collision group indices + return getGroupsCollide(i, jResolved.getIndex()); + } else { + // Only one of the two collision groups are allocated. Therefore, i and j must be different. + // Furthermore (i, j) is not in the collision matrix, so their collision behavior must be the default one. + // Two distinct collision groups always collide with eachother by default. + return true; + } +} + +bool CollisionMatrix::getGroupsCollide(unsigned int i, unsigned int j) { unsigned int min, max; std::tie(min, max) = std::minmax(i, j); @@ -73,12 +240,38 @@ bool CollisionMatrix::getIndicesCollide(unsigned int i, unsigned int j) { return (max < matrix.size() && matrix[max].count(min) == 1) != defaultBehavior(min, max); } -void CollisionMatrix::setIndicesCollide(unsigned int i, unsigned int j, bool collide) { +void CollisionMatrix::setGroupsCollide(char const* i, char const* j, bool collide) { + //printBoxA("Entering setGroupsCollide(i = '%s', j = '%s', collide = %u)", i, j, collide); + + // Try to get the collision group indices + GroupOrIndex iResolved = tryGetGroupIndex(i); + GroupOrIndex jResolved = tryGetGroupIndex(j); + + //printBoxA("Resolved collision groups: '%s' -> %u, '%s' -> %u", i, iResolved.getGroup(), j, jResolved.getGroup()); + + // Get default behavior for collision group pair + bool default_collide = defaultBehavior(iResolved, jResolved); + + //printBoxA("Default behavior: ('%s', '%s') -> %u", i, j, default_collide); + + // Return prematurely if at least one group isn't allocated and we don't modify the default value, avoids needless allocations + if ((!iResolved.isAllocated() || !jResolved.isAllocated()) && (collide == default_collide)) { + return; + } + + // Get or allocate collision group indices + unsigned int iIndex = getOrAllocateGroupIndex(i); + unsigned int jIndex = getOrAllocateGroupIndex(j); + + //printBoxA("Final indices: '%s' -> %u, '%s' -> %u", i, iIndex, j, jIndex); + unsigned int min, max; - std::tie(min, max) = std::minmax(i, j); + std::tie(min, max) = std::minmax(iIndex, jIndex); if (max >= matrix.size()) { // If the matrix is too small to contain (min, max) - if (collide != defaultBehavior(min, max)) { // If element (min, max) of the matrix is modified + if (collide != default_collide) { // If element (min, max) of the matrix is modified + //printBoxA("Matrix too small, inserted pair (%u, %u)", min, max); + // Resize the matrix matrix.resize(max + 1); @@ -92,9 +285,10 @@ void CollisionMatrix::setIndicesCollide(unsigned int i, unsigned int j, bool col } else { auto current_collide_iter = matrix[max].find(min); // Search collision group index min in set max bool contains = current_collide_iter != matrix[max].end(); // Check whether the previous search has succeeded or not - bool default_collide = defaultBehavior(min, max); // Get default behavior if (contains && (collide == default_collide)) { // If we have to remove collision group index min from set max + //printBoxA("Matrix large enough, removed pair (%u, %u)", min, max); + // Remove it matrix[max].erase(current_collide_iter); @@ -105,6 +299,8 @@ void CollisionMatrix::setIndicesCollide(unsigned int i, unsigned int j, bool col // Cleanup collision matrix cleanupMatrix(); } else if (!contains && (collide != default_collide)) { // If we have to insert collision group index min in set max + //printBoxA("Matrix large enough, inserted pair (%u, %u)", min, max); + // Insert it matrix[max].insert(min); @@ -113,30 +309,4 @@ void CollisionMatrix::setIndicesCollide(unsigned int i, unsigned int j, bool col incrementReferenceCount(max); } } -} - -void CollisionMatrix::cleanupMatrix() { - // Search last nonempty group index - unsigned int last_nonempty_group = matrix.size() - 1; - for (; last_nonempty_group < matrix.size() && matrix[last_nonempty_group].empty(); last_nonempty_group--) {} - - // Resize collision matrix - matrix.resize(last_nonempty_group + 1); -} - -void CollisionMatrix::deallocateGroup(unsigned int group) { - // Call lunalua event - if (gLunaLua.isValid()) { - std::shared_ptr deallocation_event = std::make_shared(deallocation_event_name, false); - deallocation_event->setDirectEventName(deallocation_event_name); - deallocation_event->setLoopable(false); - gLunaLua.callEvent(deallocation_event, group); - } - - // Put group index in deallocated group indices - deallocated_groups.push(group); -} - -bool CollisionMatrix::defaultBehavior(unsigned int i, unsigned int j) const { - return i == 0 || i != j; } \ No newline at end of file diff --git a/LunaDll/Misc/CollisionMatrix.h b/LunaDll/Misc/CollisionMatrix.h index db5b8d3e9..5cd052826 100644 --- a/LunaDll/Misc/CollisionMatrix.h +++ b/LunaDll/Misc/CollisionMatrix.h @@ -3,40 +3,105 @@ #include #include +#include +#include #include +#include class CollisionMatrix { + // Forward declaration + class GroupOrIndex; + + // The collision matrix itself. matrix[j].count(i) == 1 if getGroupsCollide(i, j) != default_behavior(i, j). + std::vector> matrix; + + // The type of `deallocated_groups` using queue_type = std::priority_queue::container_type, std::greater>; - std::vector> matrix; // The collision matrix itself. matrix[j].count(i) == 1 if getGroupsCollide(i, j) != default_behavior(i, j). + // Contains all currently deallocated groups + queue_type deallocated_groups; + + // The next group to be allocated if deallocated_groups is empty + unsigned int next_group; + + // How many references to the group exist. Index 0 corresponds to group 1 + std::vector reference_count; + + // Maps collision groups to their respective collision group indices. Doesn't contain the default collision group + std::unordered_map string_to_index; + + // Maps collision group indices to their respective collision groups. Index 0 corresponds to group 1 + std::vector index_to_string; + + // Removes all trailing empty sets from the collision matrix + void cleanupMatrix(); + + // Allocates a new collision group index + unsigned int allocateGroupIndex(char const* collisionGroup); + + // Tries to get the collision group index of a collision group + GroupOrIndex tryGetGroupIndex(char const* collisionGroup); - queue_type deallocated_groups; // Contains all currently deallocated groups - unsigned int next_group; // The next group to be allocated if deallocated_groups is empty + // Auxiliary class for representing a possibly allocated collision group + class GroupOrIndex { + friend GroupOrIndex CollisionMatrix::tryGetGroupIndex(char const* collisionGroup); - std::vector reference_count; // How many references to the group exist. Index 0 corresponds to group 1 + union { + unsigned int index; + char const* name; + }; + bool is_allocated; + explicit GroupOrIndex(unsigned int i); + explicit GroupOrIndex(char const* n); - char const* deallocation_event_name; // The lunalua event which is called whenever a group is deallocated - - void cleanupMatrix(); // Removes all trailing empty sets from the collision matrix - void deallocateGroup(unsigned int group); // Deallocates a collision group + public: + bool isAllocated(); + unsigned int getIndex(); + char const* getGroup(); + }; -protected: - bool defaultBehavior(unsigned int i, unsigned int j) const; // The default behavior of the collision matrix + + // Gets the collision group index of a collision group, allocates it if it's not + unsigned int getOrAllocateGroupIndex(char const* collisionGroup); + + // Deallocates a collision group + void deallocateGroupIndex(unsigned int groupIndex); + + // Increments the reference count of a group index + void incrementReferenceCount(unsigned int groupIndex); + + // The default behavior of the collision matrix + bool defaultBehavior(GroupOrIndex i, GroupOrIndex j) const; + bool defaultBehavior(unsigned int i, unsigned int j) const; + bool defaultBehavior(char const* i, char const* j) const; public: - CollisionMatrix() = delete; - CollisionMatrix(char const* deallocation_event_name); + CollisionMatrix(); + + // Resets this collision matrix + void clear(); + + // Returns the collision group of a collision group index + std::string const& getGroupFromIndex(unsigned int groupIndex); + + // Handles reference counting when a collision group field is assigned a new value, returns the collision group index of `newGroup` + unsigned int assignGroup(unsigned int previousGroupIndex, char const* newGroup); + + // Decrements the reference count of a group index + void decrementReferenceCount(unsigned int groupIndex); + + // Reads the collision matrix + bool getGroupsCollide(char const* i, char const* j); + bool getGroupsCollide(unsigned int i, char const* j); + bool getGroupsCollide(unsigned int i, unsigned int j); - void clear(); // Resets this collision matrix - unsigned int allocateIndex(); // Allocates a new collision group index - void incrementReferenceCount(unsigned int group); // Increments the reference count of a group index - void decrementReferenceCount(unsigned int group); // Decrements the reference count of a group index - bool getIndicesCollide(unsigned int i, unsigned int j); // Reads the collision matrix - void setIndicesCollide(unsigned int i, unsigned int j, bool collide); // Writes to the collision matrix + // Writes to the collision matrix + void setGroupsCollide(char const* i, char const* j, bool collide); }; +// Global collision matrix extern CollisionMatrix gCollisionMatrix; #endif diff --git a/LunaDll/Misc/DeclareCall.h b/LunaDll/Misc/DeclareCall.h new file mode 100644 index 000000000..f1064114b --- /dev/null +++ b/LunaDll/Misc/DeclareCall.h @@ -0,0 +1,81 @@ +#ifndef DECLARECALL_H_ +#define DECLARECALL_H_ + +#ifdef SYSCALLS_IMPL + #include + #include + + #ifdef __clang__ + // Auxiliary macros to only expand __COUNTER__ once + #define DECLARE_CALL_AUX(counter, ...) DECLARE_CALL_AUX2(counter, __VA_ARGS__) + #define DECLARE_CALL_AUX2(counter, ...) DECLARE_CALL_AUX3(functionAddr ## counter, __VA_ARGS__) + + // NB: I'm using AT&T syntax because of a clang bug: https://github.com/llvm/llvm-project/issues/60893 + #define DECLARE_CALL_AUX3(functionAddr, ...) \ + static std::uintptr_t functionAddr = 0; \ + __declspec(naked) __VA_ARGS__ { \ + __asm__ volatile ( \ + "movl %[FunctionAddr], %%eax\n" \ + "testl %%eax, %%eax\n" \ + "jnz 1f\n" \ + "movl %[ModuleHandle], %%eax\n" \ + "testl %%eax, %%eax\n" \ + "jnz 2f\n" \ + "push %[DllName]\n" \ + "call %P[GetModuleHandleW]\n" \ + "movl %%eax, %[ModuleHandle]\n" \ + "2: push %[FunctionName]\n" \ + "push %%eax\n" \ + "call %P[GetProcAddress]\n" \ + "movl %%eax, %[FunctionAddr]\n" \ + "1: jmp *%%eax\n" \ + : \ + : [DllName] "i" (dllName), \ + [FunctionName] "i" (&__func__), \ + [GetModuleHandleW] "s" (&GetModuleHandleW), \ + [GetProcAddress] "s" (&GetProcAddress), \ + [FunctionAddr] "m" (functionAddr), \ + [ModuleHandle] "m" (moduleHandle) \ + ); \ + } + + #define DECLARE_CALL(...) DECLARE_CALL_AUX(__COUNTER__, __VA_ARGS__) + + #else + #define DECLARE_CALL(...) \ + __declspec(naked) __VA_ARGS__ { \ + static char functionName[] = __func__; \ + static std::uintptr_t functionAddr = 0; \ + __asm { \ + __asm mov eax, functionAddr \ + __asm test eax, eax \ + __asm jnz funcExists \ + __asm mov eax, moduleHandle \ + __asm test eax, eax \ + __asm jnz moduleExists \ + __asm push offset dllName \ + __asm call GetModuleHandleW \ + __asm mov moduleHandle, eax \ + __asm moduleExists: push offset functionName \ + __asm push eax \ + __asm call GetProcAddress \ + __asm mov functionAddr, eax \ + __asm funcExists: jmp eax \ + } \ + } + #endif + + #define IMPORT_FROM(dllFilename) \ + namespace dllFilename { \ + static HMODULE moduleHandle = NULL; \ + static constexpr wchar_t dllName[] = L ## #dllFilename ".dll"; \ + } \ + namespace dllFilename + +#else + #define DECLARE_CALL(...) __VA_ARGS__; + + #define IMPORT_FROM(dllFilename) namespace dllFilename +#endif + +#endif \ No newline at end of file diff --git a/LunaDll/Misc/FileUtils.cpp b/LunaDll/Misc/FileUtils.cpp new file mode 100644 index 000000000..f5355974f --- /dev/null +++ b/LunaDll/Misc/FileUtils.cpp @@ -0,0 +1,152 @@ +#include "FileUtils.h" +#include +#include +#include +#include +#include +#include + +// Lua file object type, taken from LuaJIT's lib_io.c +enum class IOFileUDType : std::uint32_t { + IOFILE_TYPE_FILE = 0, /* Regular file. */ + IOFILE_TYPE_PIPE = 1, /* Pipe. */ + IOFILE_TYPE_STDF = 2, /* Standard file handle. */ + IOFILE_TYPE_MASK = 3, + IOFILE_FLAG_CLOSE = 4 /* Close after io.lines() iterator. */ +}; + +inline IOFileUDType operator|(IOFileUDType a, IOFileUDType b) { + return static_cast(static_cast::type>(a) | static_cast::type>(b)); +} + +// Lua file object structure, taken from LuaJIT's lib_io.c +struct IOFileUD { + std::FILE *fp; /* File handle. */ + IOFileUDType type; /* File type. */ +}; + +// Create a standard file object, set the field `name` of lua stack element -1 to it, then set its metatable to lua stack element -2, taken from LuaJIT's lib_io.c +// Note: This requires modifying lua51.lib to make _io_std_new a global symbol. I know this is incredibly cursed but I couldn't think of any other way. +// Note 2: The actual return type of io_std_new is GCobj* (internal luajit garbage collected object), we don't need it so we can ignore it. +extern "C" { + void io_std_new(lua_State *L, std::FILE* file, const char* name); +} + +// Creates a lua file object from a nonnull C FILE* object +luabind::object FileUtils::CFileToLua(lua_State* L, std::FILE* file, bool forIoLines) { + + // The table which will contain the newly created file + luabind::object fileTable = luabind::newtable(L); + + // The metatable of file objects + luabind::object fileMetatable = luabind::object(luabind::from_stack(L, LUA_REGISTRYINDEX))["FILE*"]; + + // Push metatable and table to store the new file object to stack + fileMetatable.push(L); + fileTable.push(L); + + // Create a new file object and store it to fileTable + io_std_new(L, file, "fileObject"); + + // Pop fileMetatable and fileTable + lua_pop(L, 2); + + // Get newly created file object + luabind::object fileObject = fileTable["fileObject"]; + + // Push fileObject to stack + fileObject.push(L); + + // Get address of file userdata memory + IOFileUD* luaFile = (IOFileUD*) lua_topointer(L, -1); + + // Pop fileObject from stack + lua_pop(L, 1); + + // Properly set the type of the file object (io_std_new initializes it to IOFileUDType::IOFILE_TYPE_STDF) + if (forIoLines) { + luaFile->type = IOFileUDType::IOFILE_TYPE_FILE | IOFileUDType::IOFILE_FLAG_CLOSE; + } else { + luaFile->type = IOFileUDType::IOFILE_TYPE_FILE; + } + + // Return new file object + return fileObject; +} + +// Converts a file opening mode string to flags, return false if the mode string is invalid +bool FileUtils::ParseFileOpeningMode(const char* mode, FileOpeningMode& out) { + // Get mode string length + std::size_t len = std::strlen(mode); + + // Empty string is invalid + if (len == 0) { + return false; + } + + // Get base character + char baseMode = mode[0]; + + // Refuse invalid modes + if (baseMode != 'r' && baseMode != 'w' && baseMode != 'a') { + return false; + } + + // Does the mode contains a '+'? + bool readWriteFlag = false; + + // Does the mode contains a 'b'? + bool binaryModeFlag = false; + + // check for modifiers + for (std::size_t i = 1; i < len; i++) { + char modifier = mode[i]; + + if (modifier == 'b') { // binary flag + // only one occurence allowed + if (binaryModeFlag) { + return false; + } + + binaryModeFlag = true; + } else if (modifier == '+') { // read+write flag + // only one occurence allowed + if (readWriteFlag) { + return false; + } + + readWriteFlag = true; + } + } + + // determine if the file must exist + out.fileMustExist = (baseMode == 'r'); + + // choose between _O_RDONLY, _O_WRONLY and _O_RDWR + if (readWriteFlag) { + out.flags = _O_RDWR; + } else if (baseMode == 'r') { + out.flags = _O_RDONLY; + } else { + out.flags = _O_WRONLY; + } + + // determine if we request write access + out.requestWrite = (out.flags != _O_RDONLY); + + // choose between _O_APPEND and _O_TRUNC + if (baseMode == 'w') { + out.flags |= _O_TRUNC; + } else if (baseMode == 'a') { + out.flags |= _O_APPEND; + } + + // Add binary or text flag + if (binaryModeFlag) { + out.flags |= _O_BINARY; + } else { + out.flags |= _O_TEXT; + } + + return true; +} \ No newline at end of file diff --git a/LunaDll/Misc/FileUtils.h b/LunaDll/Misc/FileUtils.h new file mode 100644 index 000000000..98c12ce02 --- /dev/null +++ b/LunaDll/Misc/FileUtils.h @@ -0,0 +1,19 @@ +#ifndef FileUtils_hhh +#define FileUtils_hhh + +#include +#include +#include + +namespace FileUtils { + struct FileOpeningMode { + int flags; // Win32 file opening flags + bool requestWrite; // Do we require write access? + bool fileMustExist; // Does the file must exist? + }; + + luabind::object CFileToLua(lua_State* L, std::FILE* file, bool forIoLines); + bool ParseFileOpeningMode(const char* mode, FileOpeningMode& out); +} + +#endif \ No newline at end of file diff --git a/LunaDll/Misc/RAIIHandle.h b/LunaDll/Misc/RAIIHandle.h new file mode 100644 index 000000000..ced1f8807 --- /dev/null +++ b/LunaDll/Misc/RAIIHandle.h @@ -0,0 +1,89 @@ +#ifndef RAIIHandle_hhh +#define RAIIHandle_hhh + +#include + +class RAIIHandle { + HANDLE h; + +public: + // Close the handle upon destruction + ~RAIIHandle() { + close(); + } + + // Initialize with invalid handle + RAIIHandle() : h(INVALID_HANDLE_VALUE) {} + + // No copy constructor + RAIIHandle(RAIIHandle const&) = delete; + + // Move constructor + RAIIHandle(RAIIHandle&& that) : h(that.h) { + // Set the handle of that to INVALID_HANDLE_VALUE to avoid it getting closed + that.h = INVALID_HANDLE_VALUE; + } + + // Construct from handle + RAIIHandle(HANDLE handle) : h(handle) {} + + // No copy assignment + RAIIHandle& operator=(RAIIHandle const&) = delete; + + // Move assignment + RAIIHandle& operator=(RAIIHandle&& that) { + *this = that.h; + + that.h = INVALID_HANDLE_VALUE; + + return *this; + } + + // Assign handle + RAIIHandle& operator=(HANDLE handle) { + // Close current handle + close(); + + // Assign new handle + h = handle; + + return *this; + } + + // Get mutable reference to handle + HANDLE& getHandleRef() { + return h; + } + + // Get const reference to handle + HANDLE const& getHandleRef() const { + return h; + } + + // Get handle without taking ownership + HANDLE borrow() const { + return h; + } + + // Get handle ownership + HANDLE takeOwnership() { + HANDLE handle = h; + h = INVALID_HANDLE_VALUE; + return handle; + } + + // Check if the handle is valid + bool isValid() const { + return h != INVALID_HANDLE_VALUE; + } + + // Close handle + void close() { + if (h != INVALID_HANDLE_VALUE) { + CloseHandle(h); + } + h = INVALID_HANDLE_VALUE; + } +}; + +#endif \ No newline at end of file diff --git a/LunaDll/Misc/RuntimeHookComponents/RuntimeHookHooks.cpp b/LunaDll/Misc/RuntimeHookComponents/RuntimeHookHooks.cpp index 948e9644a..94c37b61a 100644 --- a/LunaDll/Misc/RuntimeHookComponents/RuntimeHookHooks.cpp +++ b/LunaDll/Misc/RuntimeHookComponents/RuntimeHookHooks.cpp @@ -1966,7 +1966,7 @@ static unsigned int __stdcall runtimeHookGrabbedNPCCollisionGroupInternal(int np ExtendedNPCFields* extA = NPC::GetRawExtended(npcAIdx); ExtendedNPCFields* extB = NPC::GetRawExtended(npcBIdx); - if (!gCollisionMatrix.getIndicesCollide(extA->collisionGroup, extB->collisionGroup)) // Check collision matrix + if (!gCollisionMatrix.getGroupsCollide(extA->collisionGroup, extB->collisionGroup)) // Check collision matrix return 0; // Collision cancelled return -1; // Collision goes ahead @@ -3910,14 +3910,14 @@ static unsigned int __stdcall runtimeHookBlockNPCFilterInternal(unsigned int hit { ExtendedNPCFields* ownerExt = NPC::GetRawExtended(block->OwnerNPCIdx); - if (!gCollisionMatrix.getIndicesCollide(ext->collisionGroup,ownerExt->collisionGroup)) // Check collision matrix + if (!gCollisionMatrix.getGroupsCollide(ext->collisionGroup,ownerExt->collisionGroup)) // Check collision matrix return 0; } else { ExtendedBlockFields* blockExt = Blocks::GetRawExtended(blockIdx); - if (!gCollisionMatrix.getIndicesCollide(ext->collisionGroup,blockExt->collisionGroup)) // Check collision matrix + if (!gCollisionMatrix.getGroupsCollide(ext->collisionGroup,blockExt->collisionGroup)) // Check collision matrix return 0; } @@ -3968,7 +3968,7 @@ static unsigned int __stdcall runtimeHookNPCCollisionGroupInternal(int npcAIdx, ExtendedNPCFields* extA = NPC::GetRawExtended(npcAIdx); ExtendedNPCFields* extB = NPC::GetRawExtended(npcBIdx); - if (!gCollisionMatrix.getIndicesCollide(extA->collisionGroup,extB->collisionGroup)) // Check collision matrix + if (!gCollisionMatrix.getGroupsCollide(extA->collisionGroup,extB->collisionGroup)) // Check collision matrix return 0; // Collision cancelled return -1; // Collision goes ahead @@ -4057,7 +4057,7 @@ static unsigned int __stdcall runtimeHookBlockPlayerFilterInternal(short playerI // Collision groups ExtendedBlockFields* blockExt = Blocks::GetRawExtended(blockIdx); - if (!gCollisionMatrix.getIndicesCollide(playerExt->collisionGroup,blockExt->collisionGroup)) // Check collision matrix + if (!gCollisionMatrix.getGroupsCollide(playerExt->collisionGroup,blockExt->collisionGroup)) // Check collision matrix { return 0; } @@ -4115,7 +4115,7 @@ static unsigned int __stdcall runtimeHookPlayerNPCInteractionCheckInternal(short // Collision groups ExtendedNPCFields* npcExt = NPC::GetRawExtended(npcIdx); - if (!gCollisionMatrix.getIndicesCollide(playerExt->collisionGroup,npcExt->collisionGroup)) // Check collision matrix + if (!gCollisionMatrix.getGroupsCollide(playerExt->collisionGroup,npcExt->collisionGroup)) // Check collision matrix { return 0; } @@ -4234,7 +4234,7 @@ static unsigned int __stdcall runtimeHookPlayerPlayerInteractionInternal(short* } // Collision groups - if (!gCollisionMatrix.getIndicesCollide(extA->collisionGroup,extB->collisionGroup)) // Check collision matrix + if (!gCollisionMatrix.getGroupsCollide(extA->collisionGroup,extB->collisionGroup)) // Check collision matrix { return 0; } diff --git a/LunaDll/Misc/Syscalls.cpp b/LunaDll/Misc/Syscalls.cpp new file mode 100644 index 000000000..664d5bf7b --- /dev/null +++ b/LunaDll/Misc/Syscalls.cpp @@ -0,0 +1,2 @@ +#define SYSCALLS_IMPL +#include "Syscalls.h" \ No newline at end of file diff --git a/LunaDll/Misc/Syscalls.h b/LunaDll/Misc/Syscalls.h new file mode 100644 index 000000000..22d4bcf3c --- /dev/null +++ b/LunaDll/Misc/Syscalls.h @@ -0,0 +1,66 @@ +#ifndef SYSCALLS_H_ +#define SYSCALLS_H_ + +#include +#include +#include +#include "DeclareCall.h" + +IMPORT_FROM(ntdll) { + DECLARE_CALL( + NTSTATUS NTAPI NtCreateFile( + PHANDLE FileHandle, + ACCESS_MASK DesiredAccess, + POBJECT_ATTRIBUTES ObjectAttributes, + PIO_STATUS_BLOCK IoStatusBlock, + PLARGE_INTEGER AllocationSize, + ULONG FileAttributes, + ULONG ShareAccess, + ULONG CreateDisposition, + ULONG CreateOptions, + PVOID EaBuffer, + ULONG EaLength + ) + ) + + DECLARE_CALL( + NTSTATUS NTAPI NtQueryDirectoryFile( + HANDLE FileHandle, + HANDLE Event, + PIO_APC_ROUTINE ApcRoutine, + PVOID ApcContext, + PIO_STATUS_BLOCK IoStatusBlock, + PVOID FileInformation, + ULONG Length, + FILE_INFORMATION_CLASS FileInformationClass, + BOOLEAN ReturnSingleEntry, + PUNICODE_STRING FileName, + BOOLEAN RestartScan + ) + ) + + DECLARE_CALL( + VOID NTAPI RtlInitUnicodeString( + PUNICODE_STRING DestinationString, + __drv_aliasesMem PCWSTR SourceString + ) + ) + + DECLARE_CALL( + ULONG NTAPI RtlNtStatusToDosError( + NTSTATUS Status + ) + ) + + DECLARE_CALL( + NTSTATUS NTAPI NtSetInformationFile( + HANDLE FileHandle, + PIO_STATUS_BLOCK IoStatusBlock, + PVOID FileInformation, + ULONG Length, + FILE_INFORMATION_CLASS FileInformationClass + ) + ) +} + +#endif \ No newline at end of file diff --git a/LunaDll/Misc/Win32PathUtils.cpp b/LunaDll/Misc/Win32PathUtils.cpp index 2f815c18a..68aaac672 100644 --- a/LunaDll/Misc/Win32PathUtils.cpp +++ b/LunaDll/Misc/Win32PathUtils.cpp @@ -107,93 +107,6 @@ static uint32_t getRandomU32() return rng(); } -// Function to write data to a file, making repalcement as close to atomic as Windows seems to allow -bool writeFileAtomic(const std::string& path, const void* data, ptrdiff_t dataLen) -{ - return writeFileAtomic(Str2WStr(path), data, dataLen); -} - -// Function to write data to a file, making repalcement as close to atomic as Windows seems to allow -bool writeFileAtomic(const std::wstring& path, const void* data, ptrdiff_t dataLen) -{ - std::wstring pathW = GetWin32LongPath(path); - if ((pathW.size() <= 0) || (pathW[pathW.size() - 1] == L'\\')) - { - // Can't end with backslash - return false; - } - - // Make temporary file path - std::wstring tmpPath; - HANDLE tmpHwnd = INVALID_HANDLE_VALUE; - for (uint32_t i=0; (i<=0xFFFF) && (tmpHwnd == INVALID_HANDLE_VALUE); i++) - { - static const wchar_t* digits = L"0123456789ABCDEFGHIJKLMNOPQRSTUV"; - tmpPath = pathW + L"."; - uint32_t rng = getRandomU32(); - for (int j = 0; j < 16; j += 5) - { - tmpPath += digits[(rng >> j) & 0xF]; - } - tmpPath += L".TMP"; - tmpHwnd = CreateFileW(tmpPath.c_str(), GENERIC_WRITE, 0, NULL, CREATE_NEW, NULL, NULL); - if (tmpHwnd == INVALID_HANDLE_VALUE) - { - // No success - if (GetLastError() == ERROR_FILE_EXISTS) - { - // File exists? Retry - continue; - } - else - { - // Other failure, abort - return false; - } - } - } - if (tmpHwnd == INVALID_HANDLE_VALUE) - { - // Something very wrong... even 0xFFFF retries got "ERROR_FILE_EXISTS" - return false; - } - - // DEBUG: wprintf(L"Opened tmp file: %s\n", tmpPath.c_str()); - - DWORD bytesWritten = 0; - if ((WriteFile(tmpHwnd, data, dataLen, &bytesWritten, NULL) == 0) || (bytesWritten != dataLen)) - { - // Write failed - CloseHandle(tmpHwnd); - DeleteFileW(tmpPath.c_str()); - return false; - } - - // DEBUG: wprintf(L"Wrote %u bytes\n", bytesWritten); - - // Close temporary file - CloseHandle(tmpHwnd); - - // Try to replace target file if possible - if (ReplaceFileW(pathW.c_str(), tmpPath.c_str(), NULL, REPLACEFILE_IGNORE_MERGE_ERRORS, NULL, NULL) != 0) - { - // Success! We're done! - // DEBUG: wprintf(L"Replaced file %s\n", pathW.c_str()); - return true; - } - - // Otherwise, let's try to move the file - if (MoveFileEx(tmpPath.c_str(), pathW.c_str(), MOVEFILE_REPLACE_EXISTING) != 0) - { - // DEBUG: wprintf(L"Moved file %s\n", pathW.c_str()); - return true; - } - - // DEBUG: wprintf(L"Failed to write to %s\n", pathW.c_str()); - DeleteFileW(tmpPath.c_str()); - return false; -} - bool readFileToStr(const std::string& path, std::string& out) { return readFileToStr(Str2WStr(path), out); diff --git a/LunaDll/Misc/Win32PathUtils.h b/LunaDll/Misc/Win32PathUtils.h index 813a62ac2..a6bb3aca9 100644 --- a/LunaDll/Misc/Win32PathUtils.h +++ b/LunaDll/Misc/Win32PathUtils.h @@ -165,9 +165,6 @@ std::wstring GetWin32LongPath(const char* path); std::wstring GetWin32LongPath(const std::string& path); std::wstring GetWin32LongPath(const std::wstring& path); -bool writeFileAtomic(const std::string& path, const void* data, ptrdiff_t dataLen); -bool writeFileAtomic(const std::wstring& path, const void* data, ptrdiff_t dataLen); - bool readFileToStr(const std::string& path, std::string& out); bool readFileToStr(const std::wstring& path, std::string& out); diff --git a/LunaDll/libs/lua/lib/lua51.lib b/LunaDll/libs/lua/lib/lua51.lib index 909a9a09c..2e7c2eb75 100644 Binary files a/LunaDll/libs/lua/lib/lua51.lib and b/LunaDll/libs/lua/lib/lua51.lib differ diff --git a/LunadllNewLauncher/SMBXLauncher/SMBXLauncher.pro b/LunadllNewLauncher/SMBXLauncher/SMBXLauncher.pro index a5c1f7950..cc2c79594 100644 --- a/LunadllNewLauncher/SMBXLauncher/SMBXLauncher.pro +++ b/LunadllNewLauncher/SMBXLauncher/SMBXLauncher.pro @@ -41,7 +41,8 @@ SOURCES += main.cpp\ Utils/Json/qjsonurlvalidationexception.cpp \ launchercustomwebpage.cpp \ hybridlogger.cpp \ - devtoolsdialog.cpp + devtoolsdialog.cpp \ + launcherurlrequestinterceptor.cpp HEADERS += mainlauncherwindow.h \ ../../LunaDll/Input/LunaGameController.h \ @@ -64,7 +65,8 @@ HEADERS += mainlauncherwindow.h \ Utils/Json/qjsonurlvalidationexception.h \ launchercustomwebpage.h \ hybridlogger.h \ - devtoolsdialog.h + devtoolsdialog.h \ + launcherurlrequestinterceptor.h # LunaLoader # win32: SOURCES += ../../LunaLoader/LunaLoaderPatch.cpp diff --git a/LunadllNewLauncher/SMBXLauncher/launchercustomwebpage.cpp b/LunadllNewLauncher/SMBXLauncher/launchercustomwebpage.cpp index 22cf93661..89a5947e5 100644 --- a/LunadllNewLauncher/SMBXLauncher/launchercustomwebpage.cpp +++ b/LunadllNewLauncher/SMBXLauncher/launchercustomwebpage.cpp @@ -8,13 +8,10 @@ LauncherCustomWebPage::LauncherCustomWebPage(QObject *parent) : bool LauncherCustomWebPage::acceptNavigationRequest(const QUrl &url, QWebEnginePage::NavigationType type, bool isMainFrame) { - if (type == QWebEnginePage::NavigationTypeLinkClicked) - { - qDebug() << url; - if(!url.isLocalFile()) { - QDesktopServices::openUrl(url); - return false; - } + qDebug() << "Entering acceptNavigationRequest(url = " << url << ", type = " << type << ", isMainFrame = " << isMainFrame << ")"; + if(!url.isLocalFile()) { + QDesktopServices::openUrl(url); + return false; } return QWebEnginePage::acceptNavigationRequest(url, type, isMainFrame); } diff --git a/LunadllNewLauncher/SMBXLauncher/launcherurlrequestinterceptor.cpp b/LunadllNewLauncher/SMBXLauncher/launcherurlrequestinterceptor.cpp new file mode 100644 index 000000000..630d681c0 --- /dev/null +++ b/LunadllNewLauncher/SMBXLauncher/launcherurlrequestinterceptor.cpp @@ -0,0 +1,23 @@ +#include "launcherurlrequestinterceptor.h" + +LauncherUrlRequestInterceptor::LauncherUrlRequestInterceptor(QObject* parent, QDir dataFolder) : + QWebEngineUrlRequestInterceptor(parent), + dataFolderCanonicalPath(dataFolder.canonicalPath()) +{ + // Add trailing slash to the data folder canonical path if needed + if (!dataFolderCanonicalPath.endsWith('/')) { + dataFolderCanonicalPath.append('/'); + } +} + +void LauncherUrlRequestInterceptor::interceptRequest(QWebEngineUrlRequestInfo &info) { + QUrl requestedFileUrl = info.requestUrl(); + + if (requestedFileUrl.isLocalFile()) { + QDir requestedFilePath(requestedFileUrl.toLocalFile()); + + if (!requestedFilePath.canonicalPath().startsWith(dataFolderCanonicalPath)) { + info.block(true); + } + } +} \ No newline at end of file diff --git a/LunadllNewLauncher/SMBXLauncher/launcherurlrequestinterceptor.h b/LunadllNewLauncher/SMBXLauncher/launcherurlrequestinterceptor.h new file mode 100644 index 000000000..09d3b481b --- /dev/null +++ b/LunadllNewLauncher/SMBXLauncher/launcherurlrequestinterceptor.h @@ -0,0 +1,18 @@ +#ifndef LAUNCHERURLREQUESTINTERCEPTOR_H +#define LAUNCHERURLREQUESTINTERCEPTOR_H + +#include +#include +#include +#include + +class LauncherUrlRequestInterceptor : public QWebEngineUrlRequestInterceptor { + // Path to the data folder after symlink resolution + QString dataFolderCanonicalPath; + +public: + LauncherUrlRequestInterceptor(QObject* parent, QDir dataFolder); + virtual void interceptRequest(QWebEngineUrlRequestInfo &info) override; +}; + +#endif \ No newline at end of file diff --git a/LunadllNewLauncher/SMBXLauncher/mainlauncherwindow.cpp b/LunadllNewLauncher/SMBXLauncher/mainlauncherwindow.cpp index c838cad24..db25202c2 100644 --- a/LunadllNewLauncher/SMBXLauncher/mainlauncherwindow.cpp +++ b/LunadllNewLauncher/SMBXLauncher/mainlauncherwindow.cpp @@ -1,3 +1,5 @@ +#include "launcherurlrequestinterceptor.h" +#include #if defined(_WIN32) && !defined(_WIN64) #include #define WIN_CHECK_FOR_64BIT_CPU @@ -91,6 +93,9 @@ MainLauncherWindow::MainLauncherWindow(QWidget *parent) : ui->webLauncherPage->setPage(new LauncherCustomWebPage(ui->webLauncherPage)); QWebEnginePage* page = ui->webLauncherPage->page(); + LauncherUrlRequestInterceptor* interceptor = new LauncherUrlRequestInterceptor(this, QDir::current()); + page->profile()->setRequestInterceptor(interceptor); + // Set up the development tools dialog if (devDialogPtr == nullptr) {