#include <dlfcn.h>
#include <elf.h>
#include <stdint.h>
#include <stdio.h>
#include <string.h>
#include <sys/mman.h>
#include <unistd.h>

#include <string>

#include <jni.h>

// Capacity of soinfo::name, in bytes.
#define SOINFO_NAME_LEN 128

// Local mirror of the dynamic linker's per-library bookkeeping record.
// NOTE(review): the field order presumably copies struct soinfo from the
// Android 4.0.x bionic linker (FillSoInfoStruct cites that source). The
// layout only matters if an instance is ever overlaid on the linker's own
// data; confirm it against the target release before doing that.
// In this file an instance is zero-initialized in fun1, and only the fields
// marked [used] below are written (by FillSoInfoStruct) and read (by
// GetSymtabIndex and fun1). Addresses are held as `unsigned` and the ELF
// types are Elf32_*, so the struct is only meaningful in 32-bit processes.
struct soinfo
{
    // Library name; never set in this file (stays zeroed).
    const char name[SOINFO_NAME_LEN];
    // Program header table pointer and entry count; not set in this file.
    Elf32_Phdr *phdr;
    int phnum;
    unsigned entry;
    // [used] Load address of the library (start of its first mapping).
    unsigned base;
    unsigned size;

    int unused;  // DO NOT USE, maintained for compatibility.

    // [used] Start of the library's PT_DYNAMIC entries.
    unsigned *dynamic;

    // Not used in this file.
    unsigned wrprotect_start;
    unsigned wrprotect_end;

    // Not used in this file (presumably the linker's list link and flags).
    soinfo *next;
    unsigned flags;

    // [used] Dynamic string table (DT_STRTAB) and symbol table (DT_SYMTAB).
    const char *strtab;
    Elf32_Sym *symtab;

    // [used] DT_HASH table: bucket/chain counts and the two arrays.
    unsigned nbucket;
    unsigned nchain;
    unsigned *bucket;
    unsigned *chain;

    // Remaining linker bookkeeping (names follow the bionic linker); none of
    // these fields is read or written in this file.
    unsigned *plt_got;

    Elf32_Rel *plt_rel;
    unsigned plt_rel_count;

    Elf32_Rel *rel;
    unsigned rel_count;

    Elf32_Rela *plt_rela;
    unsigned plt_rela_count;

    Elf32_Rela *rela;
    unsigned rela_count;

    unsigned *preinit_array;
    unsigned preinit_array_count;

    unsigned *init_array;
    unsigned init_array_count;
    unsigned *fini_array;
    unsigned fini_array_count;

    void (*init_func)(void);
    void (*fini_func)(void);

    unsigned *ARM_exidx;
    unsigned ARM_exidx_count;
    
    unsigned refcount;
};

/*
 * System V ELF symbol hash, the function used to index DT_HASH tables.
 * Each byte is shifted into the running hash four bits at a time. Whenever
 * the top nibble becomes non-zero, it is folded back into bits 4..7 and then
 * cleared, so the result always fits in 28 bits.
 */
static unsigned elfhash(const char *_name)
{
    unsigned hash = 0;

    for (const unsigned char *p = (const unsigned char *) _name; *p != '\0'; ++p) {
        hash = (hash << 4) + *p;

        unsigned high = hash & 0xf0000000u;
        if (high != 0) {
            hash ^= high >> 24;
        }
        hash &= ~high;
    }
    return hash;
}

// Returns the start address of the first /proc/self/maps entry whose line
// contains `moduleName` (a plain substring match), or NULL if no entry
// matches, the name is NULL/empty, or the maps file cannot be opened.
// For a shared library this is treated as its load base.
// NOTE(review): that assumes the first matching mapping starts at the ELF
// header, which is the usual loader layout but is not checked here.
void *get_module_base(const char *moduleName) {
    if (moduleName == NULL || moduleName[0] == '\0') {
        return NULL;
    }

    FILE *fp = fopen("/proc/self/maps", "r");
    if (fp == NULL) {
        perror("fopen");
        return NULL;
    }

    // Roomy enough for the fixed-width fields plus a typical path. Longer
    // lines arrive in several fgets() chunks and are handled below.
    char buf[1024];
    void *addr = NULL;
    unsigned long line_start_addr = 0;
    int have_addr = 0;
    int at_line_start = 1;

    while (fgets(buf, sizeof buf, fp) != NULL) {
        // Only the first chunk of a line begins with "start-end". Later
        // chunks of an over-long line reuse the address parsed from it
        // instead of misreading path text as an address.
        if (at_line_start) {
            // unsigned long is pointer-sized on ILP32 and LP64 Linux/Android,
            // and "%lx" has no width limit, so 64-bit addresses are not
            // truncated (a "%08x" conversion stops after 8 hex digits).
            have_addr = sscanf(buf, "%lx", &line_start_addr) == 1;
        }
        at_line_start = strchr(buf, '\n') != NULL;

        if (have_addr && strstr(buf, moduleName) != NULL) {
            addr = (void *) line_start_addr;
            break;
        }
    }
    fclose(fp);
    return addr;
}

// Fills the symbol-lookup fields of `si` for the library whose
// /proc/self/maps entry contains `szLibName`. The fields are base, dynamic,
// strtab, symtab and the DT_HASH table (nbucket, nchain, bucket, chain).
//
// Those fields are reset first. On any failure (library not mapped, no
// Elf32 header at the base, no PT_DYNAMIC) they therefore stay zero/NULL,
// and a later hash lookup finds nothing instead of reading garbage. If the
// library has no DT_HASH entry, the hash fields likewise stay zero.
//
// 32-bit only: soinfo stores addresses as `unsigned` and Elf32 structures
// are parsed.
void FillSoInfoStruct(soinfo* si, const char* szLibName){
    si->base = 0;
    si->dynamic = NULL;
    si->strtab = NULL;
    si->symtab = NULL;
    si->nbucket = 0;
    si->nchain = 0;
    si->bucket = NULL;
    si->chain = NULL;

    // Load base of the library (start of its first mapping).
    char *lib_base = (char *) get_module_base(szLibName);
    if (lib_base == NULL) {
        return;
    }
    si->base = (unsigned) lib_base;

    // Make sure an Elf32 header really sits at the base before trusting any
    // offsets read from it.
    const Elf32_Ehdr *header = (const Elf32_Ehdr *) lib_base;
    if (memcmp(header->e_ident, ELFMAG, SELFMAG) != 0 ||
        header->e_ident[EI_CLASS] != ELFCLASS32) {
        return;
    }

    // Locate the dynamic section through the PT_DYNAMIC program header.
    // NOTE(review): p_vaddr is added to the mapping start, which assumes the
    // first PT_LOAD segment has p_vaddr == 0 - confirm for the target library.
    const Elf32_Phdr *phdr = (const Elf32_Phdr *) (lib_base + header->e_phoff);
    Elf32_Dyn *dyn = NULL;
    for (int i = 0; i < header->e_phnum; ++i) {
        if (phdr[i].p_type == PT_DYNAMIC) {
            dyn = (Elf32_Dyn *) (lib_base + phdr[i].p_vaddr);
            break;
        }
    }

    printf("dyn = %p\n", dyn);
    if (dyn == NULL) {
        return;
    }
    si->dynamic = (unsigned *) dyn;

    // Walk the dynamic entries up to DT_NULL, mirroring the DT_HASH handling
    // in the Android 4.0.3_r1 linker (bionic/linker/linker.c). Pointer
    // values are treated as offsets from the load base.
    for (const Elf32_Dyn *d = dyn; d->d_tag != DT_NULL; ++d) {
        switch (d->d_tag) {
            case DT_HASH: {
                // Layout: nbucket, nchain, bucket[nbucket], chain[nchain].
                unsigned *hash = (unsigned *) (lib_base + d->d_un.d_ptr);
                si->nbucket = hash[0];
                si->nchain = hash[1];
                si->bucket = hash + 2;
                si->chain = hash + 2 + si->nbucket;
                break;
            }
            case DT_STRTAB:
                si->strtab = (const char *) (lib_base + d->d_un.d_ptr);
                break;
            case DT_SYMTAB:
                si->symtab = (Elf32_Sym *) (lib_base + d->d_un.d_ptr);
                break;
            default:
                break;
        }
    }
}

//获取函数名对应的符号表的下标
int GetSymtabIndex(soinfo* si, char* szFuncName){

    //查询hash表
    int nIndex = elfhash(szFuncName) % si->nbucket;
    nIndex = si->bucket[nIndex];
    if (nIndex == 0){
        return 0;
    }

    do {
        if (strcmp(si->strtab + si->symtab[nIndex].st_name, szFuncName) == 0){
            break;
        }
        nIndex = si->chain[nIndex];
    } while (nIndex != 0);

    return nIndex;
}

// Replacement for libc's exit(): fun1 rewrites the "exit" symbol entry to
// point here. Declared static so it has internal linkage and is not exported
// from this module.
static void fun2(int status){
    (void) status;  // the requested exit status is deliberately ignored
    puts("fun2");
}

// Load-time constructor: patches libc.so's dynamic symbol table so that its
// "exit" entry points at fun2. The patch is aimed at later symbol-table
// lookups such as the dlsym call in main. Presumably calls already bound
// through a PLT/GOT slot are unaffected.
__attribute__((constructor)) void fun1(){
    puts("fun1");

    // Local copy of the linker's internal soinfo record; only the
    // symbol-lookup fields get filled in.
    soinfo si = {0};
    FillSoInfoStruct(&si, "libc.so");

    // Index of "exit" in libc.so's dynamic symbol table (0 = not found).
    int nIndex = GetSymtabIndex(&si, "exit");
    if (nIndex == 0){
        return;
    }

    Elf32_Sym *sym = &si.symtab[nIndex];

    // The symbol table is not writable by default, so unprotect the page
    // holding st_value (the field written below). That is not necessarily
    // the page on which the Elf32_Sym entry itself starts.
    // NOTE(review): this sets the page to RW only, dropping any PROT_EXEC it
    // had; confirm the page holds no code that runs afterwards.
    long page_size = sysconf(_SC_PAGESIZE);
    if (page_size <= 0) {
        page_size = 0x1000;  // fall back to the previously hard-coded size
    }
    uintptr_t page = (uintptr_t) &sym->st_value & ~((uintptr_t) page_size - 1);
    if (mprotect((void *) page, (size_t) page_size, PROT_READ | PROT_WRITE) != 0) {
        perror("mprotect");
        return;
    }

    // Store fun2's address as an offset from the library base; presumably
    // the linker resolves a symbol as base + st_value. Unsigned wrap-around
    // keeps this correct even when fun2 is mapped below libc.
    sym->st_value = (Elf32_Addr) ((uintptr_t) fun2 - si.base);
}

typedef void (*pfnEXIT)(int);

// Demo driver. A direct exit(0) call would go through this module's PLT
// entry, so instead libc's exit is resolved with dlsym, which consults the
// symbol table the constructor patched. If the patch took effect, the call
// lands in fun2 and returns; otherwise the real exit() terminates the process.
// Returns 1 if libc.so or its "exit" symbol cannot be resolved.
int main(int argc, char* argv[]){
    (void) argc;
    (void) argv;

    // POSIX requires RTLD_LAZY or RTLD_NOW in the mode; glibc rejects 0.
    void* handle = dlopen("libc.so", RTLD_NOW);
    if (handle == NULL) {
        const char *err = dlerror();
        fprintf(stderr, "dlopen: %s\n", err != NULL ? err : "unknown error");
        return 1;
    }

    pfnEXIT pfnExit = (pfnEXIT) dlsym(handle, "exit");
    if (pfnExit == NULL) {
        const char *err = dlerror();
        fprintf(stderr, "dlsym: %s\n", err != NULL ? err : "symbol not found");
        dlclose(handle);
        return 1;
    }
    pfnExit(0);

    dlclose(handle);
    return 0;
}