import cv2
import numpy as np
import imutils


class AnswerSheet:
    """Reads the filled-in answers and the student ID from one sheet image."""

    def __init__(self, image, numberofquestions=100):
        """Locate all marks on ``image``.

        image: the decoded sheet image. It is handed straight to
            ImageHandler, which splits it into b, g, r channels, so it must
            be a 3-channel image array, not a file path.
        numberofquestions: how many questions ``answers`` reports on.
        """
        handler = ImageHandler(image)
        self.ImageHandler = handler
        # Portrait scans are turned to landscape before marks are located.
        handler.rotate_image_to_horizontal()
        self.LocationManager = LocationHandler(handler)
        self.numberofquestions = numberofquestions

    def _compiler(self):
        """Build a fresh AnswerCompiler over the located marks."""
        return AnswerCompiler(self.LocationManager)

    def answers(self):
        """Return the per-question answer list (see AnswerCompiler)."""
        return self._compiler().answers_by_number(self.numberofquestions)

    def id(self):
        """Return the student ID string read from the ID marks."""
        return self._compiler().id_number()


# noinspection PyUnresolvedReferences
class ImageHandler:
    """Holds a sheet image together with the dark marks detected in it."""

    # Reference sheet size (pixels) used to derive the size ratios below.
    KNOWN_HEIGHT = 1700
    KNOWN_WIDTH = 2200

    def __init__(self, image):
        """image: a 3-channel image array (it is unpacked as b, g, r)."""
        self.image = image
        self.height = 0
        self.width = 0
        self.height_ratio_to_known_size = 1
        self.width_ratio_to_known_size = 1
        self.contours = []
        self.boundingRectangles = []
        self.binaryimage = self._convert_to_binary_colors()
        self._set_dimensions_and_ratios()
        self._locate_marks()
        self._bounding_rectangles()

    def _convert_to_binary_colors(self):
        """Return an inverted binary mask built from the red channel.

        The red channel is blurred with a 3x3 Gaussian and thresholded with
        Otsu's method; pixels below the threshold (dark) become 255.
        """
        b, g, r = cv2.split(self.image)
        blur = cv2.GaussianBlur(r, (3, 3), 0)
        _, thresh = cv2.threshold(blur, 0, 255,
                                  cv2.THRESH_BINARY_INV | cv2.THRESH_OTSU)
        return thresh

    def rotate_image_to_horizontal(self):
        """Rotate a portrait image by -90 degrees and re-detect its marks."""
        if self.height > self.width:
            self.image = imutils.rotate_bound(self.image, -90)
            self.update()

    def resize(self, ratiotooriginalsize):
        """Scale the image by ``ratiotooriginalsize`` and re-detect marks.

        ratiotooriginalsize: new size / current size (0.5 halves both
            dimensions, 2 doubles them). The previous implementation divided
            by ``int(1 / ratio)``, which crashed for ratios above 1 and
            truncated ratios that are not reciprocals of integers.

        Raises:
            ValueError: if the ratio is not positive.
        """
        if ratiotooriginalsize <= 0:
            raise ValueError("ratiotooriginalsize must be positive")
        self.height = int(self.height * ratiotooriginalsize)
        self.width = int(self.width * ratiotooriginalsize)
        # cv2.resize takes the target size as (width, height).
        self.image = cv2.resize(self.image, (self.width, self.height))
        self.update()

    def _set_dimensions_and_ratios(self):
        """Record the image size and its ratio to the reference size."""
        self.height = np.size(self.image, 0)
        self.width = np.size(self.image, 1)
        self.height_ratio_to_known_size = self.height / self.KNOWN_HEIGHT
        self.width_ratio_to_known_size = self.width / self.KNOWN_WIDTH

    @staticmethod
    def _contours_from_find_result(result):
        """Return the contour list from a ``cv2.findContours`` result.

        OpenCV 3.x returns (image, contours, hierarchy) while 2.x and 4.x
        return (contours, hierarchy); indexing ``[1]`` unconditionally
        yields the hierarchy on the latter versions.
        """
        return result[0] if len(result) == 2 else result[1]

    def _locate_marks(self):
        """Find the outer contours of all dark marks in the binary image."""
        found = cv2.findContours(
            self.binaryimage,
            cv2.RETR_EXTERNAL,
            cv2.CHAIN_APPROX_SIMPLE
        )
        self.contours = self._contours_from_find_result(found)

    def _bounding_rectangles(self):
        """Store the upright (x, y, w, h) bounding box of every contour."""
        self.boundingRectangles = [cv2.boundingRect(c) for c in self.contours]

    def update(self):
        """Recompute the binary image, dimensions, contours and boxes."""
        self.contours = []
        self.binaryimage = self._convert_to_binary_colors()
        self._set_dimensions_and_ratios()
        self._locate_marks()
        self._bounding_rectangles()


class RectangleConverter:
    """Maps mark rectangles to grid positions using the sheet's line
    functions."""

    def __init__(self, locationhandler):
        self.LH = locationhandler
        self.FG = FunctionGenerator(locationhandler)
        self.RV = RectangleValidator()
        # Question-number marks ordered by x; kept as an attribute for
        # existing callers.
        self._answernumberrectangles = sorted(self.LH.questionnumbermarks,
                                              key=lambda x: x[0])
        # The grid lines depend only on the reference marks captured by
        # self.FG, so build them once rather than once per rectangle.
        self._columnfunctions = (
            self.FG.generate_multiple_horizontal_line_functions())
        self._rowfunctions = self.FG.generate_multiple_vertical_line_functions()

    def convert_answer_rectangle_to_question_number_column(self, rectangle):
        """Return (question number, column) for an answer mark.

        A mark on no column line gets column 10. Rows up to 10 give
        question number 0. Otherwise the number is ``row - 10`` plus 50 for
        each block of five columns.
        """
        row = self._row(rectangle)
        column = self._column(rectangle)
        if column is None:
            column = 10
        if row is not None and row > 10:
            answernumber = row - 10 + 50 * (column // 5)
        else:
            answernumber = 0
        return answernumber, column

    def convert_id_rectangles_to_row_column(self, rectangle):
        """Return (row, column) for an ID mark; a mark on no row gets 0."""
        row = self._row(rectangle)
        if row is None:
            row = 0
        column = self._column(rectangle)
        return row, column

    def _column(self, rectangle):
        """Index of the first column line the rectangle straddles, or None."""
        for column, fx in enumerate(self._columnfunctions):
            if self.RV.rectangleonfunction(fx, rectangle, 0):
                return column
        return None

    def _row(self, rectangle):
        """Index of the first row line the rectangle straddles, or None."""
        for row, fy in enumerate(self._rowfunctions):
            if self.RV.rectangleonfunction(fy, rectangle, 1):
                return row
        return None

    def all_answers_number_column(self):
        """Return (question number, column) for every answer mark."""
        return [self.convert_answer_rectangle_to_question_number_column(r)
                for r in self.LH.answers]

    def all_id_number_column(self):
        """Return (row, column) for every ID mark."""
        return [self.convert_id_rectangles_to_row_column(r)
                for r in self.LH.idmarkings]


class AnswerCompiler:
    """Turns located marks into per-question answers and a student ID."""

    # Column index -> answer letter. Columns 0-4 and 5-9 are the two answer
    # blocks. Index 10 marks "no column line matched" and adds nothing (it
    # used to be ``tuple()``, which raised TypeError when added to a str).
    COLUMN_TO_LETTER = ("A", "B", "C", "D", "E",
                        "A", "B", "C", "D", "E", "")
    # ID row index -> 1-tuple holding the digit. Rows 21-30 are digits 0-9;
    # every other row maps to "".
    ID_ROW_TO_DIGIT = ((("",),) * 21 + tuple((i,) for i in range(10))
                       + (("",),) * 30)

    def __init__(self, locationhandler):
        self.LH = locationhandler
        self.columnnumbertoanswer = list(self.COLUMN_TO_LETTER)
        self.idmarktonumber = list(self.ID_ROW_TO_DIGIT)
        self.MQN = RectangleConverter(self.LH)
        self.answer_number_column = self.MQN.all_answers_number_column()
        self.id_number_column = self.MQN.all_id_number_column()
        self.answers = []

    def answers_by_number(self, numberofquestions=100):
        """Return a list whose index q holds the letters marked for q.

        Questions run from 1 to ``numberofquestions`` inclusive; index 0 is
        unused. Several marks on one question are concatenated in the order
        they were located. Numbers outside the range and columns outside
        the letter table are ignored.
        """
        answers = [""] * (numberofquestions + 1)
        letters = self.columnnumbertoanswer
        for number, column in self.answer_number_column:
            # Upper bound is inclusive: the list has a slot for question N.
            if 0 < number <= numberofquestions and column < len(letters):
                answers[number] += letters[column]
        return answers

    def id_number(self):
        """Return the student ID string.

        Marks are ordered by column index, and that order is stored back
        on ``self.id_number_column``. Each mark's row selects a digit.
        Rows 0-20 and 31-60 contribute "".
        """
        self.id_number_column = sorted(self.id_number_column,
                                       key=lambda x: x[1])
        return "".join(str(self.idmarktonumber[row][0])
                       for row, _ in self.id_number_column)


# noinspection PyUnresolvedReferences
class LocationHandler:
    """Classifies the detected (x, y, w, h) boxes of an ImageHandler into
    corner, reference, answer and student-ID marks."""

    def __init__(self, imagehandler):
        self.image = imagehandler
        # Boxes whose size passes _is_box_correct_size.
        self.boxes = []
        # Boxes that are neither reference marks nor ID marks.
        self.answers = []
        # Reference marks on the horizontal (lower-left/lower-right) line.
        self.questionnumbermarks = []
        # All reference marks (horizontal and vertical lines).
        self.basicsheetmarks = []
        self.idmarkings = []
        # Starting values for the corner searches in _find_corner_rectangles.
        self.upperleft = (self.image.width, self.image.height, 0, 0)
        self.lowerleft = (0, 0, 0, 0)
        self.lowerright = (0, 0, 0, 0)
        self.spacingbox = (0, 0, 0, 0)
        self._find_marks_of_usable_size()
        self._find_corner_rectangles()
        # The reference-mark search only needs the corner marks ...
        self.FG = FunctionGenerator(self)
        self._find_basic_sheet_marks()
        # ... rebuild so the generator also sees spacingbox and the
        # question-number marks found above.
        self.FG = FunctionGenerator(self)
        self._find_answer_marks()
        self._find_id_set_marks()

    def _find_marks_of_usable_size(self):
        """Collect the bounding boxes whose width and height are plausible."""
        self.boxes.extend(
            (x, y, w, h)
            for (x, y, w, h) in self.image.boundingRectangles
            if self._is_box_correct_size(w, h))

    def _is_box_correct_size(self, w, h):
        """True if w and h fall within size limits scaled to the image."""
        widthratio = self.image.width_ratio_to_known_size
        heightratio = self.image.height_ratio_to_known_size
        minwidth = 3 * widthratio
        maxwidth = 24 * widthratio
        minheight = 16 * heightratio
        maxheight = 38 * heightratio

        return minwidth <= w <= maxwidth and minheight <= h <= maxheight

    def _find_corner_rectangles(self):
        """Pick the corner marks when the sheet is oriented horizontally.

        Distances use each box's (x, y) corner:
        - upperleft is the box nearest the image origin.
        - lowerright is the box farthest from the origin.
        - lowerleft is the box nearest (0, image height).

        The current attribute value competes as the first candidate, and
        ties keep the earliest candidate, matching a strict running
        comparison.
        """
        bottomleft = (0, self.image.height, 0, 0)

        def distance_to_bottomleft(box):
            return self._distance_between_boxes(box, bottomleft)

        self.upperleft = min([self.upperleft] + self.boxes,
                             key=self._distance_from_origin)
        self.lowerright = max([self.lowerright] + self.boxes,
                              key=self._distance_from_origin)
        self.lowerleft = min([self.lowerleft] + self.boxes,
                             key=distance_to_bottomleft)

    def _find_answer_marks(self):
        """Add every box that is not a reference mark to the answer list."""
        sheetmarks = set(self.basicsheetmarks)  # O(1) membership tests
        self.answers.extend(box for box in self.boxes
                            if box not in sheetmarks)

    def _find_basic_sheet_marks(self):
        """Find the reference marks common to every sheet.

        Boxes on the horizontal corner line are question-number marks.
        Boxes on the vertical corner line are also reference marks; the
        last such box that is not a corner becomes ``spacingbox``.
        """
        onfunction = RectangleValidator().rectangleonfunction
        horizontalfx = self.FG.generate_horizontal_line_function()
        verticalfy = self.FG.generate_vertical_line_function()
        for box in self.boxes:
            if onfunction(horizontalfx, box, 0):
                self.basicsheetmarks.append(box)
                self.questionnumbermarks.append(box)
            elif onfunction(verticalfy, box, 1):
                if box not in [self.upperleft, self.lowerleft]:
                    self.spacingbox = box
                self.basicsheetmarks.append(box)

    def _find_id_set_marks(self):
        """Move answer boxes that lie on a student-ID line into idmarkings.

        The ID lines are the horizontal line functions from index 10 on.
        Each box is recorded once even if it touches several ID lines.
        Previously such a box was appended once per line and then removed
        repeatedly from ``answers``, raising ValueError.
        """
        onfunction = RectangleValidator().rectangleonfunction
        idfxs = self.FG.generate_multiple_horizontal_line_functions()[10:]
        found = [box for box in self.answers
                 if any(onfunction(fx, box, 0) for fx in idfxs)]
        self.idmarkings.extend(found)
        foundset = set(found)
        # Filter in place so the list object held in self.answers is kept.
        self.answers[:] = [box for box in self.answers
                           if box not in foundset]

    def _distance_from_origin(self, box):
        """Euclidean distance of the box's (x, y) corner from (0, 0)."""
        (x1, y1, w1, h1) = box
        return np.sqrt(x1 ** 2 + y1 ** 2)

    def _distance_between_boxes(self, box1, box2):
        """Euclidean distance between the (x, y) corners of two boxes."""
        (x1, y1, w1, h1) = box1
        (x2, y2, w2, h2) = box2
        return np.sqrt((x2 - x1) ** 2 + (y2 - y1) ** 2)


class RectangleValidator:
    """Tests whether an (x, y, w, h) rectangle straddles a line function."""

    def rectangleonfunction(self, func, rectangle, indexforinputtofunction):
        """Return True if ``func`` crosses the rectangle's span.

        indexforinputtofunction picks which coordinate of the rectangle is
        fed to ``func``:
        - 0 evaluates func(x) and checks y <= value <= y + h.
        - 1 evaluates func(y) and checks x <= value <= x + w.
        """
        inputaxis = indexforinputtofunction
        # The opposite coordinate: x <-> y.
        otheraxis = (inputaxis + 1) % 2
        lower = rectangle[otheraxis]
        lineposition = func(rectangle[inputaxis])
        # rectangle[otheraxis - 2] is the extent on the opposite axis:
        # the last element (h) for y, the second-to-last (w) for x.
        return (lower <= lineposition
                and lineposition <= lower + rectangle[otheraxis - 2])


class FunctionGenerator:
    """Builds straight-line functions through the sheet's reference marks.

    The corner, spacing and question-number boxes are copied from the
    location handler when the generator is created. Slopes and offsets are
    computed once per generated function rather than on every call.
    """

    def __init__(self, locationhandler):
        self.LH = locationhandler
        self.upperleft = self.LH.upperleft
        self.lowerleft = self.LH.lowerleft
        self.lowerright = self.LH.lowerright
        self.spacingbox = self.LH.spacingbox
        # Question-number marks ordered by x.
        self.numberboxes = sorted(self.LH.questionnumbermarks,
                                  key=lambda x: x[0])

    def generate_vertical_line_function(self):
        """Return fy(y) -> int x for the line through the lower-left and
        upper-left corner marks, passing through the horizontal centre of
        the upper-left mark."""
        slope = self._y_slope(self.lowerleft, self.upperleft)
        y0 = self.upperleft[1]
        x0 = self.upperleft[0] + self.upperleft[-2] / 2

        def fy(y):
            return int(slope * (y - y0) + x0)

        return fy

    def generate_multiple_vertical_line_functions(self):
        """Return one fy(y) -> int x per question-number mark (ordered by x).

        Each line is parallel to the left corner line and passes through
        the horizontal centre of its mark.
        """
        slope = self._y_slope(self.lowerleft, self.upperleft)

        def function_creator(rect):
            y0 = rect[1]
            x0 = rect[0] + rect[-2] / 2
            return lambda y: int(slope * (y - y0) + x0)

        return [function_creator(rect) for rect in self.numberboxes]

    def generate_horizontal_line_function(self):
        """Return fx(x) -> int y for the line through the lower-left and
        lower-right corner marks, passing through the vertical centre of
        the lower-right mark."""
        slope = self._x_slope(self.lowerleft, self.lowerright)
        x0 = self.lowerright[0]
        y0 = self.lowerright[1] + self.lowerright[-1] / 2

        def fx(x):
            return int(slope * (x - x0) + y0)

        return fx

    def generate_multiple_horizontal_line_functions(self):
        """Return 19 column functions fx(x) -> int y.

        Each function is the base horizontal line shifted by a multiple of
        the mark spacing (delta_y):
        - indices 0-4 are shifted 2-6 spacings;
        - indices 5-9 are shifted 0-4 spacings plus the second-set offset;
        - indices 10-18 are shifted 1-9 spacings plus the ID-set offset.
        """
        delta_y = self._find_delta_y()
        secondrow = self._find_second_set_of_answers()
        thirdrow = self._find_id_set()
        # Loop-invariant: build the base line once, not once per column.
        base = self.generate_horizontal_line_function()

        def function_creator(j):
            if j < 7:
                return lambda x: int(base(x) - delta_y * j)
            if j < 12:
                return lambda x: int(base(x) - delta_y * (j - 7) - secondrow)
            return lambda x: int(base(x) - delta_y * (j - 11) - thirdrow)

        return [function_creator(j) for j in range(2, 21)]

    def _find_delta_y(self):
        """Spacing between adjacent lines: lower-left y minus spacing-box y."""
        return self.lowerleft[1] - self.spacingbox[1]

    def _find_second_set_of_answers(self):
        """Offset of the second answer set: a fixed fraction (127/324) of
        the vertical distance between the left corner marks."""
        return (self.lowerleft[1] - self.upperleft[1]) * (
            (64 * 2 - 1) / (162 * 2))

    def _find_id_set(self):
        """Offset of the ID set: 25/41 of the vertical distance between the
        left corner marks."""
        return (self.lowerleft[1] - self.upperleft[1]) * (
            25 / 41)

    def _x_slope(self, r1, r2):
        """dy/dx between the (x, y) corners of two boxes; 0 if vertical."""
        (x1, y1, w1, h1) = r1
        (x2, y2, w2, h2) = r2
        try:
            return (y2 - y1) / (x2 - x1)
        except ZeroDivisionError:
            return 0

    def _y_slope(self, r1, r2):
        """dx/dy between the (x, y) corners of two boxes; 0 if horizontal."""
        (x1, y1, w1, h1) = r1
        (x2, y2, w2, h2) = r2
        try:
            return (x2 - x1) / (y2 - y1)
        except ZeroDivisionError:
            return 0
