import itertools
import numpy as np

# Number of quality modules assumed in every recycler; RecyclerSolver multiplies
# it by quality_module_probability to get the recycler's upgrade chance.
NUM_RECYCLING_MODULES = 4
# Scale factor applied to every entry of the recycler output matrix
# (0.25 -> a quarter of the recycled amount comes back as ingredients).
RECYCLING_RATIO = 0.25

# Lookup tables of per-module quality chance / productivity bonus. They are not
# read by the solver classes in this file; the example run's settings (.062 and
# 0.25) equal the last entry of each table's third row.
# NOTE(review): rows presumably index module tier and columns module quality
# level -- confirm before relying on that layout.
QUALITY_PROBABILITIES = [
    [0.010, 0.013, 0.016, 0.019, 0.025],
    [0.020, 0.026, 0.032, 0.038, 0.050],
    [0.025, 0.032, 0.040, 0.047, 0.062],
]
PROD_BONUSES = [
    [0.04, 0.05, 0.06, 0.07, 0.10],
    [0.06, 0.07, 0.09, 0.11, 0.15],
    [0.10, 0.13, 0.16, 0.19, 0.25],
]

class NoRecyclerSolver:
    """Optimize a single craft (no recycling loop) for a target output quality.

    One recipe running at ``starting_quality`` consumes ingredients and yields
    products spread over qualities ``starting_quality..max_quality``. Every
    product quality other than ``ending_quality`` is treated as a free
    byproduct; the solver reports how many starting ingredients are needed per
    ``ending_quality`` product and which quality/prod module split is best.
    """

    def __init__(self, starting_quality, ending_quality, max_quality,
                 prod_module_bonus, quality_module_probability, enable_recycling,
                 module_slots, additional_prod):
        self.starting_quality = starting_quality
        self.ending_quality = ending_quality
        self.max_quality = max_quality
        self.prod_module_bonus = prod_module_bonus
        self.quality_module_probability = quality_module_probability
        # Stored for interface parity with RecyclerSolver; not read by this class.
        self.enable_recycling = enable_recycling
        self.module_slots = module_slots
        self.additional_prod = additional_prod

        self.max_quality_increase = max_quality - starting_quality
        self.end_quality_increase = ending_quality - starting_quality
        self.num_quality_items_in_solver = max_quality - starting_quality + 1
        self.num_quality_recipes_in_solver = ending_quality - starting_quality + 1
        self.num_extra_qualities = max_quality - ending_quality

    def initialize_recipe_matrix(self, frac_quality):
        """Return the per-craft output distribution as an ``(n, 1)`` column.

        Args:
            frac_quality: fraction of module slots holding quality modules; the
                rest hold productivity modules.

        Returns:
            Column whose row ``k`` is the expected number of products at quality
            ``starting_quality + k`` per craft. The column sums to the
            productivity multiplier ``p``.
        """
        n = self.num_quality_items_in_solver
        frac_prod = 1 - frac_quality
        q = self.module_slots * self.quality_module_probability * frac_quality
        p = 1 + self.module_slots * self.prod_module_bonus * frac_prod + self.additional_prod

        X = np.zeros(n)
        if n == 1:
            # No higher tier exists in the solver, so upgraded output stays on
            # the only tier (the top tier absorbs every upgrade past it).
            X[0] = p
        else:
            X[0] = (1 - q) * p
            # The upgraded share q*p is split 90% to +1 tier, 9% to +2, ...
            # Start at k=1 so the un-upgraded X[0] above is not overwritten.
            for k in range(1, n - 1):
                X[k] = 0.9 * 10 ** (1 - k) * q * p
            # The top tier collects the remaining tail of upgrades.
            X[n - 1] = 10 ** (2 - n) * q * p

        return X.reshape((n, 1))

    def solve(self, frac_quality):
        """Solve the linear balance for one module split.

        Rows: row 0 is the ingredient; rows ``1..n`` are products at qualities
        ``starting_quality..max_quality``. Columns: recipe crafts, one free
        byproduct sink per product quality other than ``ending_quality``
        (ascending), and the starting-ingredient source. The goal demands one
        ``ending_quality`` product.

        Returns:
            ``[crafts, byproducts..., ingredients_in]``.

        Raises:
            numpy.linalg.LinAlgError: if the system is singular.
        """
        n = self.num_quality_items_in_solver

        # The recipe consumes one ingredient and produces the distribution X.
        recipe_column = np.vstack([-np.ones((1, 1)), self.initialize_recipe_matrix(frac_quality)])

        source_column = np.zeros((n + 1, 1))
        source_column[0] = 1

        # Every product quality except the one of interest may be discarded freely.
        free_items = np.vstack([np.zeros((1, n)), -np.identity(n)])
        # Free-item columns are indexed by offset from starting_quality, so the
        # goal column is removed at end_quality_increase (not ending_quality-1).
        free_items = np.delete(free_items, self.end_quality_increase, axis=1)

        eqs = np.hstack([recipe_column, free_items, source_column])

        goal = np.zeros(n + 1)
        goal[1 + self.end_quality_increase] = 1

        return np.linalg.solve(eqs, goal)

    def optimize_modules(self):
        """Try every split with at least one quality module; keep the cheapest.

        Returns:
            ``(best_frac_quality, best_result)`` where ``best_result`` is the
            vector returned by :meth:`solve`.

        Raises:
            numpy.linalg.LinAlgError: if no split yields a solvable system.
        """
        best_frac_quality = None
        best_result = None
        best_num_input = np.inf
        possible_frac_qualities = np.linspace(1.0 / self.module_slots, 1.0, num=self.module_slots)
        for frac_quality in possible_frac_qualities:
            try:
                result = self.solve(frac_quality)
            except np.linalg.LinAlgError:
                # Skip singular configurations, matching RecyclerSolver.
                continue
            num_input = result[-1]
            if num_input < best_num_input:
                best_num_input = num_input
                best_frac_quality = frac_quality
                best_result = result
        if best_result is None:
            raise np.linalg.LinAlgError('no module configuration produced a solvable system')
        return (best_frac_quality, best_result)

    def run(self):
        """Optimize and print the input cost, module split and byproducts."""
        print('')
        print(f'optimizing production of output quality {self.ending_quality} from input quality {self.starting_quality}')
        print('')

        best_frac_quality, best_result = self.optimize_modules()
        best_num_input = best_result[-1]

        print(f'q{self.starting_quality} input per q{self.ending_quality} output: {best_num_input}')
        qual_modules = round(best_frac_quality * self.module_slots)
        prod_modules = round((1 - best_frac_quality) * self.module_slots)
        print(f'optimal recipe uses {qual_modules} quality modules and {prod_modules} prod modules')
        print('')

        print(f'you also get the following byproducts for each q{self.ending_quality} output:')
        # Byproduct columns follow the crafts column, in ascending quality order.
        free_item_idx = 1
        for i in range(self.starting_quality, self.max_quality + 1):
            if i != self.ending_quality:
                print(f'q{i} output: {best_result[free_item_idx]}')
                free_item_idx += 1

class RecyclerSolver:
    """Optimize a craft-and-recycle loop for a target quality.

    Recipes run at every quality from ``starting_quality`` to ``ending_quality``.
    Products below ``ending_quality`` are fed to recyclers that return
    ingredients. Qualities above ``ending_quality`` (up to ``max_quality``) are
    free byproducts. The solver searches quality/prod module splits for the
    non-final recipes and reports the split needing the fewest starting items
    per ``ending_quality`` output.
    """

    def __init__(self, starting_type, ending_type, starting_quality, ending_quality, max_quality,
                 prod_module_bonus, quality_module_probability, enable_recycling, module_slots, additional_prod):
        """Store settings and derived sizes.

        Raises:
            ValueError: if ``starting_type``/``ending_type`` is not
                'ingredient' or 'product' (case-insensitive).
        """
        self.starting_type = starting_type.lower()
        self.ending_type = ending_type.lower()

        if self.starting_type not in ('ingredient', 'product'):
            raise ValueError("starting type must be either 'ingredient' or 'product'")
        if self.ending_type not in ('ingredient', 'product'):
            raise ValueError("ending type must be either 'ingredient' or 'product'")

        self.starting_quality = starting_quality
        self.ending_quality = ending_quality
        self.max_quality = max_quality
        self.prod_module_bonus = prod_module_bonus
        self.quality_module_probability = quality_module_probability
        # Stored but not read by this class.
        self.enable_recycling = enable_recycling
        self.module_slots = module_slots
        self.additional_prod = additional_prod

        self.max_quality_increase = max_quality - starting_quality
        self.end_quality_increase = ending_quality - starting_quality
        self.num_quality_items_in_solver = max_quality - starting_quality + 1
        self.num_quality_recipes_in_solver = ending_quality - starting_quality + 1
        self.num_extra_qualities = max_quality - ending_quality

        # One ingredient row and one product row per quality tier.
        self.mat_size = 2 * self.num_quality_items_in_solver

    def initialize_recipe_matrix(self, frac_quality):
        """Return the recipe output matrix, shape ``(n_items, n_recipes)``.

        Args:
            frac_quality: array of length ``ending_quality - starting_quality``;
                entry ``i`` is the quality-module fraction for the recipe at
                quality ``starting_quality + i``.

        Column ``r`` is the expected product distribution of one craft at
        quality ``starting_quality + r``. The final (ending-quality) recipe uses
        productivity modules only.
        """
        n_items = self.num_quality_items_in_solver
        n_recipes = self.num_quality_recipes_in_solver
        frac_prod = 1 - frac_quality
        q = self.module_slots * self.quality_module_probability * frac_quality
        p = 1 + self.module_slots * self.prod_module_bonus * frac_prod + self.additional_prod

        X = np.zeros((n_recipes, n_items))
        for i in range(n_recipes - 1):
            # Un-upgraded output stays at the recipe's own tier.
            X[i, i] = (1 - q[i]) * p[i]
            # Upgrades: 90% land one tier up, 9% two up, ...
            for j in range(i + 1, n_items - 1):
                X[i, j] = 0.9 * 10 ** (i - j + 1) * q[i] * p[i]
            # The top tier absorbs the remaining tail.
            X[i, n_items - 1] = 10 ** (i - n_items + 2) * q[i] * p[i]

        X[n_recipes - 1, n_recipes - 1] = 1 + self.module_slots * self.prod_module_bonus + self.additional_prod
        return X.T

    def initialize_recycling_matrix(self):
        """Return the recycler output matrix, shape ``(n_items, n_recipes - 1)``.

        Column ``r`` is the ingredient distribution from recycling one product of
        quality ``starting_quality + r``. It uses ``NUM_RECYCLING_MODULES``
        quality modules and is scaled by ``RECYCLING_RATIO``.
        """
        n_items = self.num_quality_items_in_solver
        n_recyclers = self.num_quality_recipes_in_solver - 1
        r = NUM_RECYCLING_MODULES * self.quality_module_probability

        R = np.zeros((n_recyclers, n_items))
        for i in range(n_recyclers):
            R[i, i] = 1 - r
            for j in range(i + 1, n_items - 1):
                R[i, j] = 0.9 * 10 ** (i - j + 1) * r
            R[i, n_items - 1] = 10 ** (i - n_items + 2) * r

        R *= RECYCLING_RATIO
        return R.T

    def initialize_input_matrix(self, num_cols):
        """Return an ``(n_items, num_cols)`` matrix with -1 on the diagonal.

        Column ``c`` consumes one item of quality ``starting_quality + c``.
        """
        consumption = np.zeros((self.num_quality_items_in_solver, num_cols))
        for c in range(num_cols):
            consumption[c, c] = -1
        return consumption

    def solve(self, frac_quality):
        """Solve the loop's linear balance for one module split.

        Rows ``0..n-1`` are ingredient qualities; rows ``n..2n-1`` are product
        qualities. Columns are, in order:
        - recipe crafts (``n_recipes``);
        - recycler runs (``n_recipes - 1``);
        - free (ingredient, product) byproduct sinks for each extra quality;
        - the starting-item source.

        Returns:
            Solution vector whose last entry is starting items per goal item.

        Raises:
            numpy.linalg.LinAlgError: if the system is singular.
        """
        n_items = self.num_quality_items_in_solver
        n_recipes = self.num_quality_recipes_in_solver

        X = self.initialize_recipe_matrix(frac_quality)
        R = self.initialize_recycling_matrix()
        X_inputs = self.initialize_input_matrix(n_recipes)
        R_inputs = self.initialize_input_matrix(n_recipes - 1)
        # Recipes turn ingredients into products; recyclers turn products into ingredients.
        recipes = np.block([
            [X_inputs, R],
            [X, R_inputs],
        ])

        source = np.zeros((self.mat_size, 1))
        if self.starting_type == 'ingredient':
            source[0] = 1
        elif self.starting_type == 'product':
            source[n_items] = 1

        free_items = np.zeros((2 * n_items, 2 * self.num_extra_qualities))
        for k in range(self.num_extra_qualities):
            free_items[n_recipes + k, 2 * k] = -1
            free_items[n_items + n_recipes + k, 2 * k + 1] = -1

        eqs = np.block([[recipes, free_items, source]])

        goal = np.zeros(self.mat_size)
        if self.ending_type == 'ingredient':
            goal[n_items - 1 - self.num_extra_qualities] = 1
        if self.ending_type == 'product':
            goal[-1 - self.num_extra_qualities] = 1

        return np.linalg.solve(eqs, goal)

    def optimize_modules(self):
        """Exhaustively search per-recipe module splits; keep the cheapest.

        Returns:
            ``(best_frac_quality, best_result)``.

        Raises:
            numpy.linalg.LinAlgError: if no split yields a solvable system.
        """
        best_frac_quality = None
        best_result = None
        best_num_input = np.inf
        possible_frac_qualities = np.linspace(0, 1.0, num=self.module_slots + 1)
        for combo in itertools.product(possible_frac_qualities, repeat=self.end_quality_increase):
            frac_quality = np.array(combo)
            try:
                result = self.solve(frac_quality)
            except np.linalg.LinAlgError:
                continue
            num_input = result[-1]
            if num_input < best_num_input:
                best_num_input = num_input
                best_frac_quality = frac_quality
                best_result = result
        if best_result is None:
            raise np.linalg.LinAlgError('no module configuration produced a solvable system')
        return (best_frac_quality, best_result)

    def run(self):
        """Optimize and print the input cost, per-recipe modules and byproducts."""
        print('')
        print(f'optimizing recycling loop that turns {self.starting_type} quality {self.starting_quality} into {self.ending_type} quality {self.ending_quality}')
        print('')

        best_frac_quality, best_result = self.optimize_modules()
        best_num_input = best_result[-1]

        print(f'q{self.starting_quality} input per q{self.ending_quality} output: {best_num_input}')
        for i in range(self.starting_quality, self.ending_quality):
            # best_frac_quality is indexed by offset from starting_quality.
            frac = best_frac_quality[i - self.starting_quality]
            qual_modules = round(frac * self.module_slots)
            prod_modules = round((1 - frac) * self.module_slots)
            print(f'recipe q{i} uses {qual_modules} quality modules and {prod_modules} prod modules')
        print(f'recipe q{self.ending_quality} uses 0 quality modules and {self.module_slots} prod modules')

        if self.num_extra_qualities > 0:
            print('')
            print(f'as an additional bonus you get the following for each q{self.ending_quality} output:')
            # Byproduct sinks sit just before the trailing source column,
            # interleaved (ingredient, product) per extra quality.
            free_item_results = best_result[-(self.num_extra_qualities * 2) - 1:-1]
            for i in range(self.num_extra_qualities):
                print(f'q{self.max_quality-self.num_extra_qualities+i+1} ingredient: {free_item_results[i*2]}')
                print(f'q{self.max_quality-self.num_extra_qualities+i+1} output: {free_item_results[i*2+1]}')

if __name__ == '__main__':
    # Example run: q1 ingredients -> q5 products in 4-slot machines. The module
    # stats (.062 quality chance, 0.25 prod bonus) equal the last entries of the
    # third rows of QUALITY_PROBABILITIES / PROD_BONUSES.
    solver_settings = dict(
        starting_type='ingredient',
        ending_type='product',
        starting_quality=1,
        ending_quality=5,
        max_quality=5,
        prod_module_bonus=0.25,
        quality_module_probability=.062,
        enable_recycling=True,
        module_slots=4,
        additional_prod=0,
    )
    RecyclerSolver(**solver_settings).run()