#include "common.h"
#include "heap.h"

static void swap_down(HEAP *h, int k);


/*---------------------------------------------------------------------------*/
/* "Constructors" and "destructors"                                          */
/*---------------------------------------------------------------------------*/
/*
 * init_heap: initialise an empty heap that owns its element storage.
 *
 *   h            heap to initialise (its fields are overwritten, not freed)
 *   capacity     initial number of element slots; 0 is treated as 1
 *   element_size size in bytes of one element
 *   comparator   stored in h->comparator; called elsewhere in this file as
 *                comparator(h, i, j) with two heap positions
 *
 * Returns 0 on success. On allocation failure it prints a message to stderr,
 * releases whatever was already allocated, leaves all three buffer pointers
 * NULL (so a later free_heap(h) is still safe) and returns -1.
 */
int init_heap(HEAP *h, unsigned int capacity, size_t element_size,
               COMPARATOR comparator) {
    /* malloc(0)/calloc(0, ...) may legally return NULL, which would be
       misreported as out-of-memory; always reserve at least one slot. */
    if (capacity == 0)
        capacity = 1;

    h->element_size = element_size;
    h->num_allocated = capacity;
    h->num_used = 0;
    h->comparator = comparator;
    h->free_data = true;

    /* calloc checks capacity * element_size for overflow, unlike malloc. */
    h->data = calloc(capacity, element_size);
    h->index_in_heap = Malloc(capacity, int);
    /* 1-based indexing => need 1 extra element in h->heap */
    h->heap = Malloc(capacity + 1, unsigned int);

    if (h->data == NULL || h->index_in_heap == NULL || h->heap == NULL) {
        fprintf(stderr, "init_heap: out of memory\n");
        /* Release whichever allocations did succeed (free(NULL) is a no-op). */
        free(h->data);
        free(h->index_in_heap);
        free(h->heap);
        h->data = NULL;
        h->index_in_heap = NULL;
        h->heap = NULL;
        return -1;
    }
    return 0;
}

/*
 * build_heap: turn an existing array of num_elements elements into a heap,
 * using a bottom-up heapify (sift down every internal node, last first).
 *
 *   h            heap to initialise (its fields are overwritten, not freed)
 *   data         caller's element array; the heap does NOT take ownership
 *                (h->free_data = false, so free_heap() will not free it)
 *   num_elements number of elements in data
 *   element_size size in bytes of one element
 *   comparator   stored in h->comparator; called as comparator(h, i, j)
 *
 * Initially data[i] sits at heap position i+1. Only the two index arrays are
 * allocated here.
 *
 * Returns 0 on success. On allocation failure it prints a message to stderr,
 * frees any index array already allocated, leaves both index pointers NULL
 * and returns -1.
 */
int build_heap(HEAP *h, void *data, unsigned int num_elements,
                size_t element_size, COMPARATOR comparator) {
    /* Malloc(0, ...) may legally return NULL; size the index arrays for at
       least one element so an empty input is not misreported as OOM. */
    unsigned int slots = num_elements > 0 ? num_elements : 1;

    h->element_size = element_size;
    h->num_allocated = h->num_used = num_elements;
    h->data = data;
    h->comparator = comparator;
    h->free_data = false;   /* data belongs to the caller */

    h->index_in_heap = Malloc(slots, int);
    /* 1-based indexing => need 1 extra element in h->heap */
    h->heap = Malloc(slots + 1, unsigned int);
    if (h->index_in_heap == NULL || h->heap == NULL) {
        fprintf(stderr, "build_heap: out of memory\n");
        /* Release whichever allocation did succeed (free(NULL) is a no-op). */
        free(h->index_in_heap);
        free(h->heap);
        h->index_in_heap = NULL;
        h->heap = NULL;
        return -1;
    }

    for (unsigned int i = 0; i < num_elements; i++) {
        h->index_in_heap[i] = (int) i + 1; /* data[0] will correspond to heap[1] */
        h->heap[i + 1] = i;                /* and vice versa */
    }

    /* Heapify: positions above num_elements/2 are leaves and already valid. */
    for (int k = (int) (num_elements / 2); k >= 1; k--)
        swap_down(h, k);
    return 0;
}

/*
 * free_heap: release the memory owned by the heap.
 *
 * The two index arrays are always freed. The element array is freed only when
 * h->free_data is set, i.e. when the heap owns it. Afterwards every pointer is
 * NULL and both counts are 0, so calling free_heap() again on the same HEAP
 * is harmless instead of a double free.
 */
void free_heap(HEAP *h) {
    free(h->index_in_heap);
    free(h->heap);
    if (h->free_data)
        free(h->data);

    /* Do not leave dangling pointers behind in a struct the caller keeps. */
    h->index_in_heap = NULL;
    h->heap = NULL;
    h->data = NULL;
    h->num_used = 0;
    h->num_allocated = 0;
}


/*---------------------------------------------------------------------------*/
/* The crucial swap routine                                                  */
/*---------------------------------------------------------------------------*/

/*
 * swap: exchange the data indices stored at heap positions i and j.
 *
 * The reverse map is updated in the same step, so each of the two moved
 * elements d still satisfies h->heap[h->index_in_heap[d]] == d afterwards.
 * The assertion checks that this held for both positions on entry.
 */
static void swap(HEAP *h, int i, int j)
{
    assert(h->index_in_heap[h->heap[i]] == i &&
           h->index_in_heap[h->heap[j]] == j);

    unsigned int at_i = h->heap[i];
    unsigned int at_j = h->heap[j];

    /* The element formerly at j moves to i ... */
    h->heap[i] = at_j;
    h->index_in_heap[at_j] = i;

    /* ... and the element formerly at i moves to j. */
    h->heap[j] = at_i;
    h->index_in_heap[at_i] = j;
}


/*---------------------------------------------------------------------------*/
/* Insert routine                                                            */
/*---------------------------------------------------------------------------*/

/*
 * swap_up: move the element at heap position k toward the root.
 *
 * At each step the element is swapped with its parent (position k/2) as long
 * as comparator(h, parent, k) is negative. It stops at the root (k == 1) or
 * at the first parent that does not compare lower.
 */
static void swap_up(HEAP *h, int k)
{
    while (k > 1) {
        int parent = k / 2;

        if (h->comparator(h, parent, k) >= 0)
            break;              /* parent already ranks at least as high */

        swap(h, k, parent);
        k = parent;
    }
}

/*
 * grow_heap_storage: double the capacity of the element array and of both
 * index arrays. A capacity of 0 grows to 1.
 *
 * If the element array is caller-owned (h->free_data == false, as set by
 * build_heap), it is never realloc'ed or freed. Its contents are copied into
 * a fresh heap-owned buffer instead, and h->free_data becomes true.
 *
 * Returns 0 on success. On failure it prints a message to stderr and returns
 * -1. The heap then remains valid at its old capacity, every buffer it
 * references is still live, and it must still be released with free_heap().
 */
static int grow_heap_storage(HEAP *h)
{
    unsigned int old_cap = h->num_allocated;
    unsigned int new_cap = old_cap > 0 ? 2 * old_cap : 1;
    void *new_data;
    int *new_index;
    unsigned int *new_heap;

    /* Reject wrap-around of the element count or of the data size in bytes.
       new_cap is 1 or even, so new_cap + 1 below cannot wrap. */
    if (new_cap <= old_cap ||
        (h->element_size != 0 && new_cap > (size_t) -1 / h->element_size)) {
        fprintf(stderr, "insert: heap capacity overflow\n");
        return -1;
    }

    if (h->free_data) {
        new_data = realloc(h->data, (size_t) new_cap * h->element_size);
    } else {
        /* Caller's buffer: copy it rather than realloc'ing memory we do not own. */
        new_data = malloc((size_t) new_cap * h->element_size);
        if (new_data != NULL && old_cap > 0)
            memcpy(new_data, h->data, (size_t) old_cap * h->element_size);
    }
    if (new_data == NULL)
        goto out_of_memory;
    /* Store each new buffer immediately: a successful realloc invalidates the
       old pointer, so it must never be used (or freed) again. */
    h->data = new_data;
    h->free_data = true;

    new_index = Realloc(h->index_in_heap, new_cap, int);
    if (new_index == NULL)
        goto out_of_memory;
    h->index_in_heap = new_index;

    /* 1-based indexing => need 1 extra element in h->heap */
    new_heap = Realloc(h->heap, new_cap + 1, unsigned int);
    if (new_heap == NULL)
        goto out_of_memory;
    h->heap = new_heap;

    /* Only now are all three buffers at least new_cap long. */
    h->num_allocated = new_cap;
    return 0;

out_of_memory:
    fprintf(stderr, "insert: out of memory\n");
    return -1;
}

/*
 * insert: copy the element pointed to by x (h->element_size bytes) into the
 * heap, then restore the heap property.
 *
 * The element is written to data index h->num_used, placed at heap position
 * h->num_used + 1, and sifted up. Storage is doubled when it is full.
 *
 * Returns 0 on success. Returns -1 if the storage could not be grown; the
 * heap then remains valid and must still be released with free_heap().
 *
 * NOTE(review): delete_max() decrements num_used but leaves elements in place
 * in h->data. After a delete_max(), data index num_used can therefore still
 * belong to a live element, and this insert would overwrite it. Fixing that
 * needs a separate "data slots used" counter in HEAP. TODO: confirm whether
 * callers ever insert after delete_max.
 */
int insert(HEAP *h, void *x)
{
    unsigned int slot;

    /* First, make sure there's space for another element */
    if (h->num_used == h->num_allocated && grow_heap_storage(h) != 0)
        return -1;

    /* Insert element at end of both the data and the heap arrays */
    slot = h->num_used;
    memcpy((char *) h->data + (size_t) slot * h->element_size, x,
           h->element_size);
    h->index_in_heap[slot] = (int) slot + 1;
    h->heap[slot + 1] = slot;
    h->num_used++;

    /* Restore heap property */
    swap_up(h, (int) h->num_used);
    return 0;
}

/*---------------------------------------------------------------------------*/
/* Delete routine                                                            */
/*---------------------------------------------------------------------------*/

/*
 * swap_down: move the element at heap position k toward the leaves.
 *
 * Positions 2k and 2k+1 are its children. At each step the element is
 * compared with the child the comparator ranks higher. It is swapped with
 * that child when comparator(h, k, child) is negative. It stops at a leaf or
 * when no child outranks it.
 */
static void swap_down(HEAP *h, int k)
{
    for (;;) {
        int child = 2 * k;

        if (child > h->num_used)
            break;              /* k is a leaf */

        /* Prefer the right child if it exists and ranks higher. */
        if (child < h->num_used && h->comparator(h, child, child + 1) < 0)
            child++;

        if (h->comparator(h, k, child) >= 0)
            break;              /* heap property holds at k */

        swap(h, k, child);
        k = child;
    }
}

/*
 * delete_max: remove the element at the root of the heap.
 *
 * Returns the data index of the removed element, or -1 if the heap is empty.
 * The element's bytes stay in h->data at that index. Its index_in_heap entry
 * is set to -1 to mark it as no longer in the heap. The last element takes
 * its place at the root and is sifted down.
 */
int delete_max(HEAP *h)
{
    unsigned int max_index, last_index;

    /* Reading heap[1] of an empty heap would use stale or uninitialised data. */
    if (h->num_used == 0)
        return -1;

    max_index = h->heap[1];             /* max is at root (index 1) of heap */
    last_index = h->heap[h->num_used];

    /* Copy last element to root and shrink the heap */
    h->heap[1] = last_index;
    h->index_in_heap[last_index] = 1;
    h->num_used--;

    /* Mark the removed element only now. When it was the sole element it is
       also the "last" one, and the assignment above would otherwise have
       overwritten its -1 marker with 1. */
    h->index_in_heap[max_index] = -1;

    /* Restore heap property */
    swap_down(h, 1);
    return (int) max_index;
}


/*---------------------------------------------------------------------------*/
/* increase_key                                                              */
/*---------------------------------------------------------------------------*/

/*
 * increase_key: restore the heap property after the key of the element at
 * data index index_in_data has increased.
 *
 * In this max-heap a larger key can only break the ordering with the
 * element's ancestors, so the element is sifted up from its current heap
 * position. Elements that are no longer in the heap (index_in_heap == -1, as
 * set by delete_max) are ignored.
 *
 * NOTE(review): this function takes no new key, so it assumes the caller has
 * already updated the element in h->data[index_in_data]. Confirm with callers.
 */
void increase_key(HEAP *h, int index_in_data) {
    int k;

    assert(index_in_data >= 0 &&
           (unsigned int) index_in_data < h->num_allocated);

    k = h->index_in_heap[index_in_data];
    if (k < 1)
        return;                 /* not (or no longer) in the heap */

    swap_up(h, k);
}


/*---------------------------------------------------------------------------*/
/* Heapsort                                                                  */
/*---------------------------------------------------------------------------*/

/*
 * heapsort: sort the elements currently in the heap, in place in h->heap.
 *
 * It repeatedly swaps the root with the last slot of a shrinking prefix and
 * sifts the new root down within that prefix. Afterwards h->heap[1..num_used]
 * lists data indices in non-decreasing order (assuming the comparator is a
 * consistent ordering). num_used is restored to its value on entry, but the
 * array is in general no longer a valid max-heap.
 */
void heapsort(HEAP *h)
{
    const unsigned int total = h->num_used;

    for (unsigned int end = total; end > 1; end--) {
        swap(h, 1, (int) end);  /* current max goes to the end */
        h->num_used = end - 1;  /* exclude it from further sifting */
        swap_down(h, 1);
    }

    h->num_used = total;
}
