#include #include #include #include "Set.h" struct set { struct node *tree; int size; }; struct node { int item; struct node *left; struct node *right; int height; }; static void doFree(struct node *t); static struct node *doInsert(struct node *t, int item, Set s); static struct node *newNode(int item); static struct node *avlRebalance(struct node *t); static int balance(struct node *t); static int height(struct node *t); static struct node *rotateLeft(struct node *t); static struct node *rotateRight(struct node *t); static void recalcHeight(struct node *t); static int max(int a, int b); struct node *doDelete(struct node *t, int item, Set s); static struct node *bstJoin(struct node *t1, struct node *t2); static bool doContains(struct node *t, int item); /** * Creates a new empty set */ Set SetNew(void) { Set s = malloc(sizeof(struct set)); if (s == NULL) { fprintf(stderr, "error: out of memory\n"); exit(EXIT_FAILURE); } s->tree = NULL; return s; } /** * Frees memory used by set */ void SetFree(Set s) { doFree(s->tree); free(s); } static void doFree(struct node *t) { if (t == NULL) return; doFree(t->left); doFree(t->right); free(t); } /** * Inserts an element into the set */ void SetInsert(Set s, int item) { s->tree = doInsert(s->tree, item, s); } static struct node *doInsert(struct node *t, int item, Set s) { if (t == NULL) { s->size++; return newNode(item); } if (item < t->item) { t->left = doInsert(t->left, item, s); } else if (item > t->item) { t->right = doInsert(t->right, item, s); } else { return t; } recalcHeight(t); return avlRebalance(t); } static struct node *newNode(int item) { struct node *n = malloc(sizeof(struct node)); if (n == NULL) { fprintf(stderr, "error: out of memory\n"); exit(EXIT_FAILURE); } n->item = item; n->left = NULL; n->right = NULL; n->height = 0; return n; } static struct node *avlRebalance(struct node *t) { int bal = balance(t); if (bal > 1) { if (balance(t->left) < 0) { t->left = rotateLeft(t->left); } t = rotateRight(t); } else if (bal < -1) { if (balance(t->right) > 0) { t->right = rotateRight(t->right); } t = rotateLeft(t); } return t; } static int balance(struct node *t) { return height(t->left) - height(t->right); } static int height(struct node *t) { return t == NULL ? -1 : t->height; } static struct node *rotateLeft(struct node *t) { if (t == NULL || t->right == NULL) return t; struct node *newRoot = t->right; t->right = newRoot->left; newRoot->left = t; recalcHeight(t); recalcHeight(newRoot); return newRoot; } static struct node *rotateRight(struct node *t) { if (t == NULL || t->left == NULL) return t; struct node *newRoot = t->left; t->left = newRoot->right; newRoot->right = t; recalcHeight(t); recalcHeight(newRoot); return newRoot; } static void recalcHeight(struct node *t) { t->height = 1 + max(height(t->left), height(t->right)); } static int max(int a, int b) { return a > b ? a : b; } /** * Deletes an element from the set */ void SetDelete(Set s, int item) { s->tree = doDelete(s->tree, item, s); } struct node *doDelete(struct node *t, int item, Set s) { if (t == NULL) { return NULL; } if (item < t->item) { t->left = doDelete(t->left, item, s); } else if (item > t->item) { t->right = doDelete(t->right, item, s); } else { s->size--; if (t->left == NULL || t->right == NULL) { struct node *temp = t->left != NULL ? t->left : t->right; free(t); return temp; } else { struct node *temp = t; t = bstJoin(t->left, t->right); free(temp); t->right = doDelete(t->right, t->item, s); } } recalcHeight(t); return avlRebalance(t); } static struct node *bstJoin(struct node *t1, struct node *t2) { if (t1 == NULL) { return t2; } else if (t2 == NULL) { return t1; } else { struct node *curr = t2; struct node *parent = NULL; while (curr->left != NULL) { parent = curr; curr = curr->left; } if (parent != NULL) { parent->left = curr->right; curr->right = t2; } curr->left = t1; return curr; } } /** * Checks if an element is in the set */ bool SetContains(Set s, int item) { return doContains(s->tree, item); } static bool doContains(struct node *t, int item) { if (t == NULL) { return false; } else if (item < t->item) { return doContains(t->left, item); } else if (item > t->item) { return doContains(t->right, item); } else { return true; } } /** * Returns the size of the set */ int SetSize(Set s) { return s->size; }