diff --git a/src/zzuf.c b/src/zzuf.c index 84ff72c..179059f 100644 --- a/src/zzuf.c +++ b/src/zzuf.c @@ -44,6 +44,8 @@ #endif #if defined HAVE_WINDOWS_H # include +# include +# include #endif #if defined HAVE_IO_H # include @@ -121,7 +123,8 @@ static char const *sig2name(int); #endif #if defined HAVE_WINDOWS_H static int dll_inject(void *, void *); -static void *get_entry(char const *); +static intptr_t get_base_address(DWORD); +static intptr_t get_entry_point_offset(char const *); #endif static void finfo(FILE *, struct opts *, uint32_t); #if defined HAVE_REGEX_H @@ -1088,11 +1091,6 @@ static int run_process(struct opts *opts, int pipes[][2]) #elif HAVE_WINDOWS_H pid = GetCurrentProcess(); - /* Get entry point */ - epaddr = get_entry(opts->newargv[0]); - if(!epaddr) - return -1; - memset(&sinfo, 0, sizeof(sinfo)); sinfo.cb = sizeof(sinfo); DuplicateHandle(pid, (HANDLE)_get_osfhandle(pipes[0][1]), pid, @@ -1107,6 +1105,12 @@ static int run_process(struct opts *opts, int pipes[][2]) if(!ret) return -1; + /* Get the child process's entry point address */ + epaddr = (void *)(get_base_address(pinfo.dwProcessId) + + get_entry_point_offset(opts->newargv[0])); + if(!epaddr) + return -1; + /* Insert the replacement code */ ret = dll_inject(pinfo.hProcess, epaddr); if(ret < 0) @@ -1197,11 +1201,35 @@ static int dll_inject(void *process, void *epaddr) return 0; } -static void *get_entry(char const *name) +/* Find the process's base address once it is loaded in memory (the header + * information is unreliable because of Vista's ASLR). */ +static intptr_t get_base_address(DWORD pid) +{ + MODULEENTRY32 entry; + intptr_t ret = 0; + void *list; + int k; + + list = CreateToolhelp32Snapshot(TH32CS_SNAPMODULE, pid); + entry.dwSize = sizeof(entry); + for(k = Module32First(list, &entry); k; k = Module32Next(list, &entry)) + { + /* FIXME: how do we select the correct module? */ + ret = (intptr_t)entry.modBaseAddr; + } + CloseHandle(list); + + return ret; +} + +/* Find the process's entry point address offset. The information is in + * the file's PE header. */ +static intptr_t get_entry_point_offset(char const *name) { PIMAGE_DOS_HEADER dos; PIMAGE_NT_HEADERS nt; - void *file, *map, *base, *ret = NULL; + intptr_t ret = 0; + void *file, *map, *base; file = CreateFile(name, GENERIC_READ, FILE_SHARE_READ, NULL, OPEN_EXISTING, 0, NULL); @@ -1231,8 +1259,7 @@ static void *get_entry(char const *name) && nt->FileHeader.Machine == IMAGE_FILE_MACHINE_I386 && nt->OptionalHeader.Magic == 0x10b /* IMAGE_NT_OPTIONAL_HDR32_MAGIC */) { - ret = (void *)(uintptr_t)(nt->OptionalHeader.ImageBase + - nt->OptionalHeader.AddressOfEntryPoint); + ret = (intptr_t)nt->OptionalHeader.AddressOfEntryPoint; } UnmapViewOfFile(base);