diff --git a/MemoryModule.c b/MemoryModule.c index 8a348a1..924be3d 100644 --- a/MemoryModule.c +++ b/MemoryModule.c @@ -29,6 +29,8 @@ #pragma warning( disable : 4311 4312 ) #endif +#include + #include #include #include @@ -42,6 +44,11 @@ #define IMAGE_SIZEOF_BASE_RELOCATION (sizeof(IMAGE_BASE_RELOCATION)) #endif +#ifdef _WIN64 +// Support for SEH is currently only available for Win64 +#define WITH_SEH +#endif + #include "MemoryModule.h" typedef BOOL (WINAPI *DllEntryProc)(HINSTANCE hinstDLL, DWORD fdwReason, LPVOID lpReserved); @@ -55,6 +62,9 @@ typedef struct { BOOL initialized; BOOL isDLL; BOOL isRelocated; +#ifdef WITH_SEH + BOOL hasSEH; +#endif CustomLoadLibraryFunc loadLibrary; CustomGetProcAddressFunc getProcAddress; CustomFreeLibraryFunc freeLibrary; @@ -411,6 +421,50 @@ BuildImportTable(PMEMORYMODULE module) return result; } +#ifdef WITH_SEH + +static BOOL +SetupSEH(PMEMORYMODULE module) +{ + PIMAGE_RUNTIME_FUNCTION_ENTRY exceptionDesc; + unsigned char *codeBase = module->codeBase; + + PIMAGE_DATA_DIRECTORY directory = GET_HEADER_DICTIONARY(module, IMAGE_DIRECTORY_ENTRY_EXCEPTION); + if (directory->Size == 0) { + return TRUE; + } + + exceptionDesc = (PIMAGE_RUNTIME_FUNCTION_ENTRY) (codeBase + directory->VirtualAddress); + if (!RtlAddFunctionTable((PRUNTIME_FUNCTION) (exceptionDesc), directory->Size / sizeof(IMAGE_RUNTIME_FUNCTION_ENTRY), (DWORD64) codeBase)) { + return FALSE; + } + + module->hasSEH = TRUE; + return TRUE; +} + +static BOOL +CleanupSEH(PMEMORYMODULE module) +{ + PIMAGE_RUNTIME_FUNCTION_ENTRY exceptionDesc; + unsigned char *codeBase = module->codeBase; + + PIMAGE_DATA_DIRECTORY directory = GET_HEADER_DICTIONARY(module, IMAGE_DIRECTORY_ENTRY_EXCEPTION); + if (directory->Size == 0) { + return TRUE; + } + + exceptionDesc = (PIMAGE_RUNTIME_FUNCTION_ENTRY) (codeBase + directory->VirtualAddress); + if (!RtlDeleteFunctionTable((PRUNTIME_FUNCTION) (exceptionDesc))) { + return FALSE; + } + + module->hasSEH = FALSE; + return TRUE; +} + +#endif // WITH_SEH + static HCUSTOMMODULE _LoadLibrary(LPCSTR filename, void *userdata) { HMODULE result = LoadLibraryA(filename); @@ -508,6 +562,9 @@ HMEMORYMODULE MemoryLoadLibraryEx(const void *data, result->modules = NULL; result->initialized = FALSE; result->isDLL = (old_header->FileHeader.Characteristics & IMAGE_FILE_DLL) != 0; +#ifdef WITH_SEH + result->hasSEH = FALSE; +#endif result->loadLibrary = loadLibrary; result->getProcAddress = getProcAddress; result->freeLibrary = freeLibrary; @@ -558,6 +615,12 @@ HMEMORYMODULE MemoryLoadLibraryEx(const void *data, goto error; } +#ifdef WITH_SEH + if (!SetupSEH(result)) { + goto error; + } +#endif + // get entry point of loaded library if (result->headers->OptionalHeader.AddressOfEntryPoint != 0) { if (result->isDLL) { @@ -638,6 +701,13 @@ void MemoryFreeLibrary(HMEMORYMODULE mod) if (module == NULL) { return; } + +#ifdef WITH_SEH + if (module->hasSEH) { + CleanupSEH(module); + } +#endif + if (module->initialized) { // notify library about detaching from process DllEntryProc DllEntry = (DllEntryProc) (module->codeBase + module->headers->OptionalHeader.AddressOfEntryPoint);