diff --git a/ln2sql/database.py b/ln2sql/database.py index 1991d7b..d7ece87 100755 --- a/ln2sql/database.py +++ b/ln2sql/database.py @@ -83,15 +83,16 @@ def _generate_path(path): def load(self, path): with open(self._generate_path(path)) as f: content = f.read() + flag=re.search("(\w+)", content).group(0)=='PostgreSQL' tables_string = [p.split(';')[0] for p in content.split('CREATE') if ';' in p] for table_string in tables_string: if 'TABLE' in table_string: - table = self.create_table(table_string) + table = self.create_table(table_string,flag) self.add_table(table) alter_tables_string = [p.split(';')[0] for p in content.split('ALTER') if ';' in p] for alter_table_string in alter_tables_string: if 'TABLE' in alter_table_string: - self.alter_table(alter_table_string) + self.alter_table(alter_table_string, flag) def predict_type(self, string): if 'int' in string.lower(): @@ -103,21 +104,31 @@ def predict_type(self, string): else: return 'unknow' - def create_table(self, table_string): + def create_table(self, table_string, flag): lines = table_string.split("\n") table = Table() for line in lines: if 'TABLE' in line: - table_name = re.search("`(\w+)`", line) - table.name = table_name.group(1) + if flag: + table_name = re.search(r'(?<=public.)\w+', line) + table.name = table_name.group(0) + else: + table_name = re.search("`(\w+)`", line) + table.name = table_name.group(1) if self.thesaurus_object is not None: table.equivalences = self.thesaurus_object.get_synonyms_of_a_word(table.name) elif 'PRIMARY KEY' in line: - primary_key_columns = re.findall("`(\w+)`", line) + if flag: + primary_key_columns = re.findall("PRIMARY KEY \((\w+)\)", line) + else: + primary_key_columns = re.findall("`(\w+)`", line) for primary_key_column in primary_key_columns: table.add_primary_key(primary_key_column) else: - column_name = re.search("`(\w+)`", line) + if flag: + column_name = re.search("(\w+)", line) + else: + column_name = re.search("`(\w+)`", line) if column_name is not None: column_type = self.predict_type(line) if self.thesaurus_object is not None: @@ -127,22 +138,33 @@ def create_table(self, table_string): table.add_column(column_name.group(1), column_type, equivalences) return table - def alter_table(self, alter_string): + def alter_table(self, alter_string, flag): lines = alter_string.replace('\n', ' ').split(';') for line in lines: if 'PRIMARY KEY' in line: - table_name = re.search("TABLE `(\w+)`", line).group(1) - table = self.get_table_by_name(table_name) - primary_key_columns = re.findall("PRIMARY KEY \(`(\w+)`\)", line) + if flag: + table_name = re.search(r'(?<=public.)\w+', line).group(0) + table = self.get_table_by_name(table_name) + primary_key_columns = re.findall("PRIMARY KEY \((\w+)\)", line) + else: + table_name = re.search("TABLE `(\w+)`", line).group(1) + table = self.get_table_by_name(table_name) + primary_key_columns = re.findall("PRIMARY KEY \(`(\w+)`\)", line) for primary_key_column in primary_key_columns: table.add_primary_key(primary_key_column) elif 'FOREIGN KEY' in line: - table_name = re.search("TABLE `(\w+)`", line).group(1) - table = self.get_table_by_name(table_name) - foreign_keys_list = re.findall("FOREIGN KEY \(`(\w+)`\) REFERENCES `(\w+)` \(`(\w+)`\)", line) + if flag: + table_name = re.search(r'(?<=public.)\w+', line).group(0) + table = self.get_table_by_name(table_name) + foreign_keys_list = re.findall("FOREIGN KEY \((\w+)\) REFERENCES public.(\w+)\((\w+)\)", line) + else: + table_name = re.search("TABLE `(\w+)`", line).group(1) + table = self.get_table_by_name(table_name) + foreign_keys_list = re.findall("FOREIGN KEY \(`(\w+)`\) REFERENCES `(\w+)` \(`(\w+)`\)", line) for column, foreign_table, foreign_column in foreign_keys_list: table.add_foreign_key(column, foreign_table, foreign_column) + def print_me(self): for table in self.tables: print('+-------------------------------------+')