#include <jni.h>
#include <string>
#include <stdint.h>
#include <stdio.h>
#include <string.h>

#include <elf.h>
#include <dlfcn.h>
#include <sys/mman.h>

#include <unistd.h>
#include <limits.h>

// Size of soinfo::name in bytes.
#define SOINFO_NAME_LEN 128
// Per-module bookkeeping record. The field layout is modeled on the `soinfo`
// struct of Android's bionic dynamic linker (bionic/linker/linker.c,
// Android 4.0.3_r1). In this file it is a plain local record that
// FillSoInfoStruct() fills by parsing the module's in-memory ELF image; it is
// not the linker's own copy.
//
// Fields filled/read in this file: base, dynamic, strtab, symtab, nbucket,
// nchain, bucket, chain, plt_rel, plt_rel_count, rel, rel_count. The remaining
// fields are not touched here (presumably kept only to mirror the linker's
// layout — NOTE(review): confirm nothing relies on that).
struct soinfo
{
    const char name[SOINFO_NAME_LEN];
    Elf32_Phdr *phdr;
    int phnum;
    unsigned entry;
    unsigned base;      // load address of the module (where its ELF header is mapped)
    unsigned size;

    int unused;  // DO NOT USE, maintained for compatibility.

    unsigned *dynamic;  // start of the PT_DYNAMIC segment: (tag, value) pairs ending with DT_NULL

    unsigned wrprotect_start;
    unsigned wrprotect_end;

    soinfo *next;
    unsigned flags;

    const char *strtab; // DT_STRTAB: symbol-name string table
    Elf32_Sym *symtab;  // DT_SYMTAB: dynamic symbol table

    // DT_HASH table: nbucket/nchain header words, then bucket[nbucket], then chain[nchain].
    unsigned nbucket;
    unsigned nchain;
    unsigned *bucket;
    unsigned *chain;

    unsigned *plt_got;

    Elf32_Rel *plt_rel;         // DT_JMPREL (.rel.plt)
    unsigned plt_rel_count;     // number of Elf32_Rel entries in plt_rel (from DT_PLTRELSZ)

    Elf32_Rel *rel;             // DT_REL (.rel.dyn)
    unsigned rel_count;         // number of Elf32_Rel entries in rel (from DT_RELSZ)

    Elf32_Rela *plt_rela;
    unsigned plt_rela_count;

    Elf32_Rela *rela;
    unsigned rela_count;

    unsigned *preinit_array;
    unsigned preinit_array_count;

    unsigned *init_array;
    unsigned init_array_count;
    unsigned *fini_array;
    unsigned fini_array_count;

    void (*init_func)(void);
    void (*fini_func)(void);

    unsigned *ARM_exidx;
    unsigned ARM_exidx_count;

    unsigned refcount;
};

// SysV ELF symbol-name hash, as used to index DT_HASH tables.
// Each byte is shifted in 4 bits at a time; whenever the top nibble becomes
// non-zero it is folded back into bits 4..7 and then cleared.
static unsigned elfhash(const char *_name)
{
    unsigned hash = 0;

    for (const unsigned char *p = (const unsigned char *) _name; *p != '\0'; ++p) {
        hash = (hash << 4) + *p;
        const unsigned top = hash & 0xf0000000u;
        hash ^= top;        // drop the top nibble...
        hash ^= top >> 24;  // ...after mixing it into the low bits
    }
    return hash;
}

// Get a module's base address.
//
// Returns the start address of the first entry in /proc/self/maps whose text
// contains moduleName (plain substring match, e.g. a full path), or NULL if
// moduleName is NULL/empty, nothing matches, or the maps file cannot be read.
// /proc/self/maps lists mappings in ascending address order, so this is the
// lowest mapping of the module.
void *get_module_base(const char *moduleName) {
    enum { MAPS_LINE_MAX = 1024 };
    char line[MAPS_LINE_MAX];
    void *addr = NULL;
    unsigned long entryStart = 0;  // start address of the entry being read
    int haveStart = 0;             // entryStart was parsed successfully
    int atLineStart = 1;           // next fgets() chunk begins a new entry
    FILE *fp;

    if (moduleName == NULL || moduleName[0] == '\0') {
        return NULL;
    }

    fp = fopen("/proc/self/maps", "r");
    if (fp == NULL) {
        perror("fopen");
        return NULL;
    }

    // Drive the loop with fgets() itself: feof() only turns true after a read
    // has already failed, which would leave a stale buffer to be processed.
    while (fgets(line, sizeof line, fp) != NULL) {
        if (atLineStart) {
            // Parse a full pointer-width hex address ("%08x" into a void*
            // is undefined and truncates 64-bit addresses).
            haveStart = sscanf(line, "%lx", &entryStart) == 1;
        }
        // An entry longer than the buffer arrives in several chunks; only the
        // first chunk carries the address range.
        atLineStart = strchr(line, '\n') != NULL;

        if (haveStart && strstr(line, moduleName) != NULL) {
            addr = (void *) entryStart;
            break;
        }
    }
    fclose(fp);
    return addr;
}

// Fill si from the in-memory ELF image of the module whose /proc/self/maps
// path contains szLibName: base address, dynamic section, SysV hash table
// (DT_HASH), string/symbol tables and the two REL relocation tables
// (.rel.plt via DT_JMPREL/DT_PLTRELSZ, .rel.dyn via DT_REL/DT_RELSZ).
//
// The DT_* handling follows bionic's linker (Android 4.0.3_r1,
// bionic/linker/linker.c, search for DT_HASH).
//
// On failure the function returns early: si->base is 0 if the module is not
// mapped, and si->dynamic / the table fields are left untouched if the image
// is not an ELF file or has no PT_DYNAMIC segment. Callers should therefore
// zero-initialize si and check base/dynamic/nbucket afterwards.
//
// NOTE(review): addresses are computed as base + p_vaddr / d_ptr, which is
// only correct for position-independent modules whose first PT_LOAD is at
// vaddr 0. Only REL (not RELA) relocations are recorded.
void FillSoInfoStruct(soinfo* si, const char* szLibName){
    // Module base address == address of its ELF header.
    char* base = (char*)get_module_base(szLibName);
    si->base = (unsigned)base;
    if (base == NULL) {
        return;
    }

    const Elf32_Ehdr* header = (const Elf32_Ehdr*)base;
    if (memcmp(header->e_ident, ELFMAG, SELFMAG) != 0) {
        return;  // first mapping does not start with an ELF header
    }

    // Locate the dynamic section through the program headers.
    const Elf32_Phdr* phdr = (const Elf32_Phdr*)(base + header->e_phoff);
    Elf32_Dyn* dyn = NULL;
    for (int i = 0; i < header->e_phnum; ++i) {
        if (phdr[i].p_type == PT_DYNAMIC) {
            dyn = (Elf32_Dyn*)(si->base + phdr[i].p_vaddr);
            break;
        }
    }
    if (dyn == NULL) {
        return;
    }
    si->dynamic = (unsigned*)dyn;

    // Walk the (tag, value) entries until DT_NULL.
    for (const Elf32_Dyn* d = dyn; d->d_tag != DT_NULL; ++d) {
        switch (d->d_tag) {
            case DT_HASH: {
                // Layout: nbucket, nchain, bucket[nbucket], chain[nchain].
                unsigned* hash = (unsigned*)(si->base + d->d_un.d_ptr);
                si->nbucket = hash[0];
                si->nchain = hash[1];
                si->bucket = hash + 2;
                si->chain = hash + 2 + si->nbucket;
                break;
            }
            case DT_STRTAB:
                si->strtab = (const char*)(si->base + d->d_un.d_ptr);
                break;
            case DT_SYMTAB:
                si->symtab = (Elf32_Sym*)(si->base + d->d_un.d_ptr);
                break;
            case DT_JMPREL:
                // First relocation table: .rel.plt
                si->plt_rel = (Elf32_Rel*)(si->base + d->d_un.d_ptr);
                break;
            case DT_PLTRELSZ:
                si->plt_rel_count = d->d_un.d_val / sizeof(Elf32_Rel);
                break;
            case DT_REL:
                // Second relocation table: .rel.dyn
                si->rel = (Elf32_Rel*)(si->base + d->d_un.d_ptr);
                break;
            case DT_RELSZ:
                si->rel_count = d->d_un.d_val / sizeof(Elf32_Rel);
                break;
            default:
                break;  // other tags are not needed for GOT hooking
        }
    }
}

// Get the symbol-table index of a function name.
//
// Looks szFuncName up in the module's SysV hash table (DT_HASH) and returns
// its index into si->symtab, or 0 (STN_UNDEF) if the symbol is absent or the
// hash/symbol/string tables were not filled in (e.g. no DT_HASH section).
int GetSymtabIndex(soinfo* si, const char* szFuncName){
    // nbucket == 0 would otherwise cause a division by zero below.
    if (si == NULL || szFuncName == NULL || si->nbucket == 0 ||
        si->bucket == NULL || si->chain == NULL ||
        si->symtab == NULL || si->strtab == NULL) {
        return 0;
    }

    // bucket[h % nbucket] holds the first symbol index of the chain for this
    // hash; chain[i] links to the next candidate, and 0 ends the chain.
    unsigned index = si->bucket[elfhash(szFuncName) % si->nbucket];

    // nchain equals the number of symbol-table entries, so it bounds both the
    // valid indices and the chain length (guards against corrupt tables).
    for (unsigned steps = 0;
         index != 0 && index < si->nchain && steps < si->nchain;
         ++steps) {
        if (strcmp(si->strtab + si->symtab[index].st_name, szFuncName) == 0) {
            return (int)index;
        }
        index = si->chain[index];
    }
    return 0;
}


// Load-time constructor (runs before main): flushes any pending stdout
// output, then prints "fun1" on its own line.
__attribute__((constructor)) void fun1(){
    (void)fflush(stdout);
    (void)fputs("fun1\n", stdout);
}


// Replacement installed over strlen by the GOT hook in main(): ignores its
// argument and always reports a length of 100.
size_t FakeStrlen(const char* const str){
    (void)str;  // deliberately unused
    const size_t kFakeLength = 100;
    return kFakeLength;
}

// Function-pointer type matching strlen (and its replacement FakeStrlen).
using PFN_STRLEN = size_t (*)(const char* const);

// Receives the previous target of the hooked slot; main() passes its address
// to LuoHook. Null until a hook has been installed.
void* g_oldpfn = nullptr;

// Global function pointer initialised to strlen; main() calls through it to
// see whether hooking also affects pointers stored in global data.
PFN_STRLEN g_pfnstrlen = &strlen;

bool LuoHook(const char* lib, const char* name, void* addr, void** OldAddr){

    /*
     * 1.找到strlen符号索引(哈希表,符号表,字符串表)
     * 2.遍历重定位表,修改got地址(重定位表)
     **/

    //soinfo结构体是操作系统内部的结构体
    soinfo si = {0};
    //填充soinfo结构体
    FillSoInfoStruct(&si, lib);

    int nIndex = GetSymtabIndex(&si, name);

    if (nIndex == 0){
        return false;
    }

    //修改.rel.plt
    for (int i = 0; i < si.plt_rel_count; ++i) {
        //判断符号
        if (ELF32_R_SYM(si.plt_rel[i].r_info) == nIndex){
            //找到Got地址
            void** got = (void**)(si.base + si.plt_rel[i].r_offset);

            //修改内存保护属性
            size_t pageSize = sysconf(_SC_PAGE_SIZE);
            if (mprotect((void*)((int)got & ~(pageSize - 1)),
                    pageSize,
                    PROT_READ | PROT_WRITE) < 0){

                perror("mprotect");
                return false;
            }

            //保存旧的地址
            *OldAddr = *got;
            //修改got地址
            *got = addr;
        }
    }

    //.rel.dyn
    for (int i = 0; i < si.rel_count; ++i) {
        //判断符号
        if (ELF32_R_SYM(si.rel[i].r_info) == nIndex){
            //找到Got地址
            void** got = (void**)(si.base + si.rel[i].r_offset);

            //修改内存保护属性
            size_t pageSize = sysconf(_SC_PAGE_SIZE);
            if (mprotect((void*)((int)got & ~(pageSize - 1)),
                         pageSize,
                         PROT_READ | PROT_WRITE) < 0){

                perror("mprotect");
                return false;
            }

            //保存旧的地址
            *OldAddr = *got;
            //修改got地址
            *got = addr;
        }
    }

    return true;
}

// Demo: GOT-hook strlen in this executable, then call strlen three different
// ways and print each result so the effect of the hook can be compared.
int main(int argc, char* argv[]){
    (void)argc;
    (void)argv;

    // GOT hook: walk the relocation tables and rewrite the matching slots.
    // Both relocation tables need patching:
    //   .rel.plt
    //   .rel.dyn
    if (!LuoHook("/data/local/tmp/luoelf", "strlen", (void*)FakeStrlen, (void**)&g_oldpfn)) {
        fputs("LuoHook failed\n", stderr);
    }

    // Through a global function pointer. (strlen returns size_t, so %zu.)
    printf("%zu\n", g_pfnstrlen("1"));

    // Through a local function pointer.
    PFN_STRLEN pfnStrlen = strlen;
    printf("%zu\n", pfnStrlen("12"));

    // Direct call (the compiler may constant-fold this one).
    printf("%zu\n", strlen("123"));

    return 0;
}