from Cimpl import *
from simple_Cimpl_filters import grayscale

def red_channel(original_image: Image) -> Image:
    """Return a copy of original_image that keeps only the red component.

    For every pixel the green and blue components are set to 0 and the
    red component is carried over unchanged. All writes go to the copy
    made with copy(); original_image is only read.
    """
    red_only = copy(original_image)
    for x, y, (red, _green, _blue) in original_image:
        set_color(red_only, x, y, create_color(red, 0, 0))
    return red_only


def Green_channel(First_image: Image) -> Image:
    """Return a copy of First_image that keeps only the green component.

    Each pixel of the result has red and blue set to 0 and the green
    value taken from the same pixel of First_image, which is only read.
    """
    result = copy(First_image)
    for x, y, rgb in First_image:
        _red, green, _blue = rgb
        green_only = create_color(0, green, 0)
        set_color(result, x, y, green_only)
    return result


def blue_channel(image: Image) -> Image:
    """Return a copy of image that keeps only the blue component.

    Red and green are zeroed for every pixel; blue is carried over from
    the corresponding pixel of image. image itself is only read.
    """
    blue_only = copy(image)
    for pixel in image:
        x, y, (_red, _green, blue_level) = pixel
        set_color(blue_only, x, y, create_color(0, 0, blue_level))
    return blue_only


def combine(rI: Image, gI: Image, bI: Image) -> Image:
    """Merge three single-channel images into one image.

    The result starts as a copy of rI. For each pixel, the red component
    comes from that copy, the green component from the same position in
    gI and the blue component from the same position in bI.
    """
    merged = copy(rI)
    for x, y, (red_level, _g, _b) in merged:
        green_source = get_color(gI, x, y)
        blue_source = get_color(bI, x, y)
        # Index 1 is the green component, index 2 the blue component.
        mixed = create_color(red_level, green_source[1], blue_source[2])
        set_color(merged, x, y, mixed)
    return merged

def two_tone(image: Image, tone1: str, tone2: str) -> Image:
    """Return a two-colour copy of image.

    tone1 and tone2 are colour names (case-insensitive) taken from:
    black, white, red, lime, blue, yellow, cyan, magenta, gray.

    A pixel whose brightness, (r + g + b) // 3, is below 128 becomes
    tone1; every other pixel becomes tone2.

    If either name is not recognised, a message is printed and None is
    returned; the image is not touched in that case.
    """
    palette = {
        'black': (0, 0, 0),
        'white': (255, 255, 255),
        'red': (255, 0, 0),
        'lime': (0, 255, 0),
        'blue': (0, 0, 255),
        'yellow': (255, 255, 0),
        'cyan': (0, 255, 255),
        'magenta': (255, 0, 255),
        'gray': (128, 128, 128),
    }
    tone1 = tone1.lower()
    tone2 = tone2.lower()

    # Guard clause: reject unknown colour names before doing any work.
    if not (tone1 in palette and tone2 in palette):
        print("Invalid input of colours")
        return None

    dark_rgb = palette[tone1]
    light_rgb = palette[tone2]

    result = copy(image)
    for x, y, (r, g, b) in result:
        brightness = (r + g + b) // 3
        chosen = dark_rgb if brightness < 128 else light_rgb
        set_color(result, x, y, create_color(*chosen))
    return result

def three_tone(image: Image, tone1: str, tone2: str, tone3: str) -> Image:
    """Return a three-colour copy of image.

    tone1, tone2 and tone3 are colour names (case-insensitive) taken from:
    black, white, red, lime, blue, yellow, cyan, magenta, gray.

    Each pixel is replaced according to its brightness, (r + g + b) // 3:

        brightness < 84           -> tone1
        84 <= brightness < 170    -> tone2
        brightness >= 170         -> tone3

    If any of the three names is not recognised, a message is printed and
    None is returned; the image is not touched in that case.
    """
    colour_dict = {
        'black': (0, 0, 0),
        'white': (255, 255, 255),
        'red': (255, 0, 0),
        'lime': (0, 255, 0),
        'blue': (0, 0, 255),
        'yellow': (255, 255, 0),
        'cyan': (0, 255, 255),
        'magenta': (255, 0, 255),
        'gray': (128, 128, 128),
    }
    tones = (tone1.lower(), tone2.lower(), tone3.lower())

    # All three names must be validated up front; otherwise an unknown
    # third tone would only fail (with KeyError) once a bright pixel is hit.
    if not all(tone in colour_dict for tone in tones):
        print("Invalid input of colours")
        return None

    dark_rgb, mid_rgb, light_rgb = (colour_dict[tone] for tone in tones)

    new_image = copy(image)
    for x, y, (r, g, b) in new_image:
        avg = (r + g + b) // 3
        if avg < 84:
            rgb = dark_rgb
        elif avg < 170:
            rgb = mid_rgb
        else:
            rgb = light_rgb
        set_color(new_image, x, y, create_color(*rgb))
    return new_image

def _adjust_component(value: int) -> int:
    """Map a colour component to the midpoint value of its quadrant.

    value >= 192 -> 223, value >= 128 -> 159, value >= 64 -> 95,
    anything lower -> 31.
    """
    # (lower bound, replacement) pairs, checked from the highest bound down.
    quadrants = ((192, 223), (128, 159), (64, 95))
    for lower_bound, midpoint in quadrants:
        if value >= lower_bound:
            return midpoint
    return 31

def posterize(image: Image) -> Image:
    """Return a posterized copy of image.

    Each of the red, green and blue components of every pixel is passed
    through _adjust_component, which replaces it with one of four fixed
    values. image is only read; the result is written to a copy.
    """
    result = copy(image)
    for x, y, rgb in image:
        red, green, blue = (_adjust_component(level) for level in rgb)
        set_color(result, x, y, create_color(red, green, blue))
    return result

def sepia(image: Image) -> Image:
    """Return a sepia-tinted copy of image.

    The image is first passed through grayscale(); then, for each pixel,
    the red component is scaled up and the blue component scaled down by
    factors chosen from the pixel's red (gray) level:

        level < 63            red * 1.1,  blue * 0.9
        63 <= level < 192     red * 1.15, blue * 0.85
        level >= 192          red * 1.08, blue * 0.93

    The green component is left unchanged.
    """
    new_image = grayscale(copy(image))

    # (exclusive upper bound of the band, red factor, blue factor), in
    # ascending order. The last bound is 256 rather than 255 so that a
    # level of exactly 255 still falls into the brightest band instead of
    # being skipped.
    bands = ((63, 1.1, 0.9), (192, 1.15, 0.85), (256, 1.08, 0.93))

    # presumably r == g == b after grayscale(); the band is chosen from r,
    # as before -- TODO confirm against grayscale's implementation.
    for x, y, (r, g, b) in new_image:
        for upper_bound, red_factor, blue_factor in bands:
            if r < upper_bound:
                # The scaled values are floats and the red one can exceed
                # 255 (e.g. 240 * 1.08), so truncate to int and clamp.
                # NOTE(review): how Cimpl treats float or out-of-range
                # components is not visible here, hence the explicit
                # normalisation.
                red = min(255, int(r * red_factor))
                blue = int(b * blue_factor)
                set_color(new_image, x, y, create_color(red, g, blue))
                break
    return new_image

def color(value: int) -> int:
    """Push a colour component to one of the two extremes.

    Returns 255 when value is greater than 127, otherwise 0.
    """
    # A single two-way expression: every input gets an int result, with no
    # implicit fall-through to None.
    return 255 if value > 127 else 0
    
def extreme_contrast(image: Image) -> Image:
    """Return a maximum-contrast copy of image.

    Each of the red, green and blue components of every pixel is replaced
    by the result of color(), i.e. 0 or 255. image is only read; the
    result is written to a copy.
    """
    result = copy(image)
    for pixel in image:
        x, y, (red, green, blue) = pixel
        extreme = create_color(color(red), color(green), color(blue))
        set_color(result, x, y, extreme)
    return result

def detect_edges(image: Image, threshold: float) -> Image:
    """Return a black-and-white edge map of image.

    A pixel becomes black when the absolute difference between its
    brightness (mean of its three components) and the brightness of the
    pixel directly below it is greater than threshold; otherwise it
    becomes white. Pixels in the bottom row have no pixel below them and
    are always white.
    """
    new_image = copy(image)
    last_row = get_height(image) - 1
    WHITE = create_color(255, 255, 255)
    BLACK = create_color(0, 0, 0)

    # Brightness is always read from the untouched source image and the
    # result written to the copy, so the outcome does not depend on the
    # order in which the pixels are visited.
    for x, y, colour in image:
        if y == last_row:
            set_color(new_image, x, y, WHITE)
            continue

        brightness = sum(colour) / 3
        brightness_below = sum(get_color(image, x, y + 1)) / 3
        if abs(brightness - brightness_below) > threshold:
            set_color(new_image, x, y, BLACK)
        else:
            set_color(new_image, x, y, WHITE)
    return new_image

def detect_edges_better(image: Image, threshold: float) -> Image:
    """Return a black-and-white edge map of image using two neighbours.

    A pixel becomes black when its brightness (mean of its three
    components) differs by more than threshold from the brightness of
    either the pixel directly below it or the pixel directly to its
    right; otherwise it becomes white. Pixels in the bottom row or the
    rightmost column lack one of those neighbours and are always white.
    """
    new_image = copy(image)
    last_row = get_height(image) - 1
    last_col = get_width(image) - 1
    WHITE = create_color(255, 255, 255)
    BLACK = create_color(0, 0, 0)

    # Brightness is always read from the untouched source image and the
    # result written to the copy, so the outcome does not depend on the
    # order in which the pixels are visited.
    for x, y, colour in image:
        is_edge = False
        if y != last_row and x != last_col:
            brightness = sum(colour) / 3
            brightness_below = sum(get_color(image, x, y + 1)) / 3
            brightness_right = sum(get_color(image, x + 1, y)) / 3
            is_edge = (abs(brightness - brightness_below) > threshold
                       or abs(brightness - brightness_right) > threshold)

        set_color(new_image, x, y, BLACK if is_edge else WHITE)
    return new_image

def flip_vertical(image: Image) -> Image:
    """Return a copy of image mirrored about its vertical centre line.

    The pixel at (x, y) in the result takes the colour of the pixel at
    (width - x - 1, y) in image, so the columns are reversed while every
    row stays in place.
    """
    width = get_width(image)
    height = get_height(image)
    flipped = copy(image)

    for x in range(width):
        source_x = width - x - 1
        for y in range(height):
            mirrored = get_color(image, source_x, y)
            set_color(flipped, x, y, mirrored)

    return flipped

def flip_horizontal(image: Image) -> Image:
    """Return a copy of image mirrored about its horizontal centre line.

    The pixel at (x, y) in the result takes the colour of the pixel at
    (x, height - y - 1) in image, so the rows are reversed while every
    column stays in place.
    """
    width = get_width(image)
    height = get_height(image)
    flipped = copy(image)

    for x in range(width):
        for y in range(height):
            source_y = height - y - 1
            set_color(flipped, x, y, get_color(image, x, source_y))

    return flipped
