diff --git a/EasyGA/database/sql_database.py b/EasyGA/database/sql_database.py index cc0344d..6c1d39b 100644 --- a/EasyGA/database/sql_database.py +++ b/EasyGA/database/sql_database.py @@ -1,5 +1,5 @@ import sqlite3 - +from typing import get_type_hints from tabulate import tabulate class SQL_Database: @@ -42,8 +42,17 @@ class SQL_Database: - def insert_config(self,ga): - """Insert the configuration attributes into the config.""" + def insert_config(self, ga): + """ + Insert the configuration attributes into the config. + + Notes: + + "Attributes" here refers to ga.__annotations__.keys(), + which allows the attributes to be customized. + + Only attributes that are bool, float, int, or str will be used. + """ # Get the current config and add one for the new config key self.config_id = self.get_current_config() @@ -55,33 +64,18 @@ class SQL_Database: self.config_id = self.config_id + 1 # Getting all the attributes from the attributes class - db_config_dict = ( - (attr_name, getattr(ga, attr_name)) + db_config = [ + (self.config_id, attr_name, attr_value) for attr_name in ga.__annotations__ - if attr_name != "population" - ) - - # Types supported in the database - sql_type_list = [int, float, str] - - # Loop through all attributes - for name, value in db_config_dict: - - # not a function - if not callable(value): - - # Convert to the right type - value = str(value) - - if "'" not in value and '"' not in value: - - # Insert into database - self.conn.execute(f""" - INSERT INTO config(config_id, attribute_name, attribute_value) - VALUES ('{self.config_id}', '{name}','{value}');""") - + if isinstance((attr_value := getattr(ga, attr_name)), (bool, float, int, str)) + ] + query = f""" + INSERT INTO config(config_id, attribute_name, attribute_value) + VALUES (?, ?, ?); + """ + self.conn.executemany(query, db_config) self.config_id = self.get_current_config() @@ -197,7 +191,7 @@ class SQL_Database: SELECT DISTINCT config_id FROM config;""") - def get_each_generation_number(self,config_id): + def get_each_generation_number(self, config_id): """Get an array of all the generation numbers""" return self.query_all(f""" @@ -305,16 +299,6 @@ class SQL_Database: os.remove(self._database_name) - def get_var_names(self, ga): - """Returns a list of the names of attributes of the ga.""" - - # Loop through all attributes - for var in ga.__dict__.keys(): - - # Remove leading underscore - yield (var[1:] if (var[0] == '_') else var) - - #=====================================# # Setters and Getters: # #=====================================#