#include <assert.h>
#include <stdbool.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
#include <sys/mman.h>

/*
 * Map `size` bytes of anonymous, private memory that is readable, writable
 * and executable, and return a pointer to its first byte.
 *
 * This function never returns on failure: it reports the mmap error with
 * perror() and terminates the process.
 */
unsigned char *alloc_exec_mem(size_t size)
{
        void *region = mmap(NULL, size, PROT_READ | PROT_WRITE | PROT_EXEC,
                            MAP_PRIVATE | MAP_ANON, -1, 0);

        if (region != MAP_FAILED) {
                return region;
        }

        perror("mmap");
        exit(EXIT_FAILURE);
}

/*
 * A 32-bit general-purpose register, identified by the 3-bit number that
 * the emitters below place into opcode and ModR/M fields.
 * Wrapping the number in a struct keeps registers from being confused with
 * plain integers at call sites.
 */
typedef struct {
        unsigned char code;
} reg_t;

static const reg_t EAX = { .code = 0x0 };
static const reg_t ECX = { .code = 0x1 };
static const reg_t EDX = { .code = 0x2 };
static const reg_t EBX = { .code = 0x3 };
static const reg_t ESP = { .code = 0x4 };
static const reg_t EBP = { .code = 0x5 };
static const reg_t ESI = { .code = 0x6 };
static const reg_t EDI = { .code = 0x7 };

/*
 * Output buffer that machine code is assembled into.
 */
typedef struct asm_buf_t {
        unsigned char *data;    /* start of the code bytes */
        ssize_t size;           /* number of bytes emitted so far */
        ssize_t capacity;       /* upper bound for size, checked when emitting */
} asm_buf_t;

/*
 * Append one byte to the end of the buffer and advance its size.
 *
 * `byte` is passed as an int32_t so callers can hand over OR-ed integer
 * expressions; it must lie in 0..255, and the buffer must have room left.
 * Both conditions are checked with assert() only.
 */
void emit_byte(asm_buf_t *buffer, int32_t byte)
{
        unsigned char *slot;

        assert((uint32_t) byte <= UINT8_MAX);
        assert(buffer->size < buffer->capacity);

        slot = &buffer->data[buffer->size];
        *slot = (unsigned char) byte;
        buffer->size += 1;
}

/*
 * Append a 32-bit value to the buffer as four bytes in little-endian order
 * (least significant byte first).
 */
void emit_int32(asm_buf_t *buffer, int32_t value)
{
        /*
         * Shift the unsigned bit pattern rather than the signed value:
         * right-shifting a negative signed integer is implementation-defined,
         * whereas the unsigned shift is fully specified.
         */
        uint32_t bits = (uint32_t) value;
        int shift;

        for (shift = 0; shift < 32; shift += 8) {
                emit_byte(buffer, (int32_t) ((bits >> shift) & 0xff));
        }
}

/* Emit a one-byte RET instruction. */
void asm_ret(asm_buf_t *buffer)
{
        /* Opcode bit pattern: 1100 0011 */
        const int32_t ret_opcode = 0xC3;

        emit_byte(buffer, ret_opcode);
}

/*
 * Emit a register-to-register XOR: two bytes, opcode 0x31 followed by a
 * ModR/M byte with src in the reg field and dst in the rm field.
 */
void asm_xor(asm_buf_t *buffer, reg_t src, reg_t dst)
{
        /* ModR/M layout: 11 | src(3) | dst(3) */
        const int modrm = 0xC0 | (src.code << 3) | dst.code;

        /* Opcode layout: 0011 000 | w=1 */
        emit_byte(buffer, 0x31);
        emit_byte(buffer, modrm);
}

/*
 * Load the 32-bit constant `value` into register dst.
 *
 * A non-zero value is emitted as a 5-byte "mov dst, imm32". Zero is emitted
 * as the 2-byte "xor dst, dst" instead.
 * NOTE(review): XOR, unlike MOV, updates the CPU flags - confirm no caller
 * relies on flags surviving a load of zero.
 */
void asm_mov_imm32(asm_buf_t *buffer, reg_t dst, int32_t value)
{
        if (value != 0) {
                /* 1011 | w=1 | dst(3) | imm32 */
                emit_byte(buffer, 0xB8 | dst.code);
                emit_int32(buffer, value);
        } else {
                asm_xor(buffer, dst, dst);
        }
}

/*
 * Emit "mov dst, [src + disp]": load the 32-bit value stored at address
 * (register src + disp) into register dst.
 *
 * The shortest displacement form that can represent disp is chosen:
 *   mod=00  no displacement  (disp == 0, base other than EBP)
 *   mod=01  8-bit signed displacement
 *   mod=10  32-bit displacement
 */
void asm_mov_mem_reg(asm_buf_t *buffer, reg_t src, int32_t disp, reg_t dst)
{
        unsigned char mod, reg, rm;

        /* 1000 101 | w=1 | mod(2) | reg(3) | rm(3) | [sib(8)] | [disp(8/32)] */
        emit_byte(buffer, 0x8B);

        if (disp == 0 && src.code != EBP.code) {
                /*
                 * EBP is excluded: mod=00 with rm=101 does not mean [ebp],
                 * it means an absolute 32-bit address with no base register.
                 * [ebp] therefore has to be encoded as [ebp + disp8(0)].
                 */
                mod = 0x0;
        } else if (disp >= INT8_MIN && disp <= INT8_MAX) {
                /* Both bounds matter: disp8 is sign-extended by the CPU. */
                mod = 0x1;
        } else {
                mod = 0x2;
        }

        rm = src.code;
        reg = dst.code;

        emit_byte(buffer, (mod << 6) | (reg << 3) | rm);

        if (src.code == ESP.code) {
                /* rm=100 does not select ESP directly; it announces a SIB byte. */
                /* ss=00 | index=100 (none) | base=esp(3) */
                emit_byte(buffer, (0x0 << 6) | (0x4 << 3) | ESP.code);
        }

        if (mod == 0x1) {
                /* Two's-complement low byte of the displacement. */
                emit_byte(buffer, (unsigned char) disp);
        } else if (mod == 0x2) {
                emit_int32(buffer, disp);
        }
}

/*
 * Emit a register-to-register TEST: two bytes, opcode 0x85 followed by a
 * ModR/M byte with reg1 in the reg field and reg2 in the rm field.
 */
void asm_test(asm_buf_t *buffer, reg_t reg1, reg_t reg2)
{
        /* ModR/M layout: 11 | reg1(3) | reg2(3) */
        const int modrm = 0xC0 | (reg1.code << 3) | reg2.code;

        /* Opcode layout: 1000 010 | w=1 */
        emit_byte(buffer, 0x85);
        emit_byte(buffer, modrm);
}

/*
 * Emit an exchange of two registers.
 *
 * When either operand is EAX the one-byte form (0x90 + other register) is
 * used; the operand order does not matter in that case. Otherwise the
 * two-byte form with opcode 0x87 and a ModR/M byte is emitted.
 */
void asm_xchg(asm_buf_t *buffer, reg_t reg1, reg_t reg2)
{
        if (reg1.code == EAX.code || reg2.code == EAX.code) {
                /* Pick the operand that is paired with EAX. */
                reg_t other = (reg1.code == EAX.code) ? reg2 : reg1;

                /* 1001 | 0 | other(3) */
                emit_byte(buffer, 0x90 | other.code);
                return;
        }

        /* 1000 011 | w=1 | 11 | reg1(3) | reg2(3) */
        emit_byte(buffer, 0x87);
        emit_byte(buffer, 0xC0 | (reg1.code << 3) | reg2.code);
}

/*
 * Emit a register-to-register ADD: two bytes, opcode 0x01 followed by a
 * ModR/M byte with src in the reg field and dst in the rm field.
 */
void asm_add(asm_buf_t *buffer, reg_t src, reg_t dst)
{
        /* ModR/M layout: 11 | src(3) | dst(3) */
        const int modrm = 0xC0 | (src.code << 3) | dst.code;

        /* Opcode layout: 0000 000 | w=1 */
        emit_byte(buffer, 0x01);
        emit_byte(buffer, modrm);
}

/*
 * A jump target inside an asm_buf_t.
 *
 * A label can be bound to a buffer offset (target) and can be referenced by
 * at most one not-yet-resolved jump instruction (instr). Offsets are only
 * meaningful while the matching has_* flag is set.
 */
typedef struct label_t {
        ssize_t target_addr;    /* buffer offset the label is bound to */
        ssize_t instr_addr;     /* buffer offset of the jump awaiting a target */
        bool has_target;        /* target_addr is valid */
        bool has_instr;         /* instr_addr is valid */
} label_t;

/*
 * Emit a two-byte LOOP instruction (opcode 0xe2 + 8-bit displacement) that
 * jumps to `label`.
 *
 * If the label is already bound, the displacement is computed relative to
 * the end of this instruction. Otherwise a zero placeholder is emitted and
 * this instruction's offset is recorded in the label so that bind_label()
 * can patch it later; a label can carry only one such pending instruction.
 */
void asm_loop(asm_buf_t *buffer, label_t *label)
{
        /* 1110 0010 | displacement(8) */

        const ssize_t here = buffer->size;
        ssize_t disp = 0;

        if (!label->has_target) {
                assert(!label->has_instr && "Label already used!");
                label->has_instr = true;
                label->instr_addr = here;
        } else {
                /* Displacements count from the byte after the instruction. */
                disp = label->target_addr - (here + 2);
        }

        assert(disp >= INT8_MIN && disp <= INT8_MAX);
        emit_byte(buffer, 0xe2);
        emit_byte(buffer, (unsigned char) disp);
}

/*
 * Condition code: the 4-bit value placed in the low nibble of the second
 * opcode byte of a conditional jump (see asm_jcc).
 */
typedef struct { unsigned char code; } cc_t;

/*
 * "Equal" and "zero" are two names for the same condition and share one
 * encoding. Each constant needs its own initializer: in a declaration such
 * as `a, b = { x }` the initializer applies to b only and a is left zero.
 */
static const cc_t CC_E = { 0x4 };
static const cc_t CC_Z = { 0x4 };

/*
 * Emit a six-byte conditional jump (0x0f, 0x80 | cc, 32-bit displacement)
 * to `label`, taken when condition `cc` holds.
 *
 * If the label is already bound, the displacement is computed relative to
 * the end of this instruction. Otherwise a zero placeholder is emitted and
 * this instruction's offset is recorded in the label so that bind_label()
 * can patch it later; a label can carry only one such pending instruction.
 */
void asm_jcc(asm_buf_t *buffer, cc_t cc, label_t *label)
{
        /* 0000 1111 1000 | cc(4) | disp(32) */

        const ssize_t here = buffer->size;
        ssize_t disp = 0;

        if (!label->has_target) {
                assert(!label->has_instr && "Label already used!");
                label->has_instr = true;
                label->instr_addr = here;
        } else {
                /* Displacements count from the byte after the instruction. */
                disp = label->target_addr - (here + 6);
        }

        assert(disp >= INT32_MIN && disp <= INT32_MAX);
        emit_byte(buffer, 0x0f);
        emit_byte(buffer, 0x80 | cc.code);
        emit_int32(buffer, (int32_t) disp);
}

/*
 * Bind `label` to the current end of the buffer.
 *
 * If a jump that references the label was emitted earlier (has_instr), its
 * placeholder displacement is overwritten in place. The kind of jump is
 * recognized from its first opcode byte: 0xe2 (loop, 8-bit displacement) or
 * 0x0f (jcc, 32-bit displacement). A label may be bound only once.
 */
void bind_label(asm_buf_t *buffer, label_t *label)
{
        ssize_t disp;
        ssize_t saved_size;
        ssize_t patch_site;

        assert(!label->has_target && "Label already bound!");

        label->target_addr = buffer->size;
        label->has_target = true;

        if (!label->has_instr) {
                return;
        }

        saved_size = buffer->size;
        patch_site = label->instr_addr;

        /*
         * Patch by temporarily rewinding the buffer to the displacement
         * field and re-using the regular emitters; the real size is
         * restored afterwards.
         */
        if (buffer->data[patch_site] == 0xe2) {
                /* loop: opcode(1) + disp8 */
                disp = saved_size - (patch_site + 2);
                assert(disp >= INT8_MIN && disp <= INT8_MAX);
                buffer->size = patch_site + 1;
                emit_byte(buffer, (unsigned char) disp);
        } else if (buffer->data[patch_site] == 0x0f) {
                /* jcc: opcode(2) + disp32 */
                disp = saved_size - (patch_site + 6);
                assert(disp >= INT32_MIN && disp <= INT32_MAX);
                buffer->size = patch_site + 2;
                emit_int32(buffer, (int32_t) disp);
        } else {
                assert(0 && "Binding label to unknown jump.");
        }

        buffer->size = saved_size;
}

/*
 * Assemble a function into `buf` that computes a Fibonacci number
 * iteratively, using ECX as the loop counter and EAX/EDX as the two
 * running values; the result is left in EAX.
 *
 * NOTE(review): the argument is read from [esp + 4], which presumably
 * assumes a 32-bit x86 stack-based calling convention - confirm the build
 * target.
 */
void generate_fib_function(asm_buf_t *buf)
{
        label_t done = { .target_addr = 0, .instr_addr = 0,
                         .has_target = false, .has_instr = false };
        label_t loop_top = { .target_addr = 0, .instr_addr = 0,
                             .has_target = false, .has_instr = false };

        /* Prologue: ecx = [esp + 4], eax = 0, edx = 1. */
        asm_mov_mem_reg(buf, ESP, 4, ECX);
        asm_mov_imm32(buf, EAX, 0);
        asm_mov_imm32(buf, EDX, 1);

        /* Skip the loop entirely when ecx is zero. */
        asm_test(buf, ECX, ECX);
        asm_jcc(buf, CC_Z, &done);

        /* Loop body: swap eax and edx, then edx += eax; repeat via LOOP. */
        bind_label(buf, &loop_top);
        asm_xchg(buf, EAX, EDX);
        asm_add(buf, EAX, EDX);
        asm_loop(buf, &loop_top);

        bind_label(buf, &done);
        asm_ret(buf);
}

int main()
{
        typedef int (*func_ptr_t)(int);

        asm_buf_t buf;
        int i, x;
        func_ptr_t func;

        buf.data = alloc_exec_mem(1024);
        buf.size = 0;
        buf.capacity = 1024;

        generate_fib_function(&buf);

        func = (func_ptr_t) buf.data;

        printf("Code: ");
        for (i = 0; i < buf.size; i++) {
                printf("%02x", buf.data[i]);        
        }
        printf("\n");

        for (i = 0; i < 10; i++) {
                x = (*func)(i);
                printf("fib(%d) = %d\n", i, x);
        }

        return 0;
}
