diff --git a/src/database/sql_database.py b/src/database/sql_database.py index ce1ea9d..ecccf7e 100644 --- a/src/database/sql_database.py +++ b/src/database/sql_database.py @@ -9,6 +9,31 @@ class SQL_Database: sql_type_list = [int, float, str] + def default_config_id(method): + """Decorator used to set the default config_id""" + def new_method(self, config_id = None): + input_id = self.config_id if config_id is None else config_id + return method(self, input_id) + return new_method + + + def format_query_data(method): + """Decorator used to format query data""" + def new_method(self, config_id): + query = method(self, config_id) + + # Unpack elements if they are lists with only 1 element + if len(query[0]) == 1: + query = [i[0] for i in query] + + # Unpack list if it is a list with only 1 element + if len(query) == 1: + query = query[0] + + return query + return new_method + + def __init__(self): self.conn = None self.config_id = None @@ -171,6 +196,7 @@ class SQL_Database: return cur.lastrowid + @format_query_data def query_all(self, query): """Query for muliple rows of data""" @@ -180,6 +206,7 @@ class SQL_Database: return cur.fetchall() + @format_query_data def query_one_item(self, query): """Query for single data point""" @@ -187,7 +214,7 @@ class SQL_Database: cur.execute(query) query_data = cur.fetchone() - return query_data[0] + return query_data def past_runs(self): @@ -198,20 +225,6 @@ class SQL_Database: print(query_data) - def default_config_id(method): - """Decorator used to set the default config_id""" - def new_method(self, config_id = None): - input_id = self.config_id if config_id is None else config_id - return method(self, input_id) - return new_method - - - def format_query_data(method): - """Decorator used to format query data""" - return lambda self, config_id:\ - [i[0] for i in method(self, config_id)] - - def get_most_recent_config_id(self): """Function to get the most recent config_id from the database.""" @@ -219,7 +232,6 @@ class SQL_Database: @default_config_id - @format_query_data def get_generation_total_fitness(self, config_id): """Get each generations total fitness sum from the database """ @@ -234,7 +246,6 @@ class SQL_Database: @default_config_id - @format_query_data def get_highest_chromosome(self, config_id): """Get the highest fitness of each generation""" @@ -242,7 +253,6 @@ class SQL_Database: @default_config_id - @format_query_data def get_lowest_chromosome(self, config_id): """Get the lowest fitness of each generation"""