From 9eddcf9e58c8e1f39ba771f7a4b1bffa27ff4364 Mon Sep 17 00:00:00 2001 From: Dmitry Spikhalsky Date: Fri, 31 Oct 2025 00:03:10 +0000 Subject: [PATCH] Make Database.tables and Table.columns read-only to prevent inconsistent state --- pydbml/_classes/table.py | 24 +++++++++++++++--------- pydbml/database.py | 20 +++++++++++++------- pydbml/renderer/sql/default/utils.py | 4 ++-- test/test_classes/test_table.py | 6 +++++- test/test_database.py | 3 +++ 5 files changed, 38 insertions(+), 19 deletions(-) diff --git a/pydbml/_classes/table.py b/pydbml/_classes/table.py index b8f340d..08939b8 100644 --- a/pydbml/_classes/table.py +++ b/pydbml/_classes/table.py @@ -2,6 +2,7 @@ from typing import List from typing import Optional from typing import TYPE_CHECKING +from typing import Tuple from typing import Union from pydbml.exceptions import ColumnNotFoundError @@ -38,7 +39,7 @@ def __init__(self, self.database: Optional[Database] = None self.name = name self.schema = schema - self.columns: List[Column] = [] + self._columns: List[Column] = [] for column in columns or []: self.add_column(column) self.indexes: List[Index] = [] @@ -51,6 +52,11 @@ def __init__(self, self.abstract = abstract self.properties = properties if properties else {} + @property + def columns(self) -> Tuple[Column, ...]: + """Returns a read-only tuple of columns.""" + return tuple(self._columns) + @property def note(self): return self._note @@ -75,18 +81,18 @@ def add_column(self, c: Column) -> None: if not isinstance(c, Column): raise TypeError('Columns must be of type Column') c.table = self - self.columns.append(c) + self._columns.append(c) def delete_column(self, c: Union[Column, int]) -> Column: if isinstance(c, Column): - if c in self.columns: + if c in self._columns: c.table = None - return self.columns.pop(self.columns.index(c)) + return self._columns.pop(self._columns.index(c)) else: raise ColumnNotFoundError(f'Column {c} if missing in the table') elif isinstance(c, int): - self.columns[c].table = None - return self.columns.pop(c) + self._columns[c].table = None + return self._columns.pop(c) def add_index(self, i: Index) -> None: ''' @@ -119,9 +125,9 @@ def get_refs(self) -> List['Reference']: def __getitem__(self, k: Union[int, str]) -> Column: if isinstance(k, int): - return self.columns[k] + return self._columns[k] elif isinstance(k, str): - for c in self.columns: + for c in self._columns: if c.name == k: return c raise ColumnNotFoundError(f'Column {k} not present in table {self.name}') @@ -135,7 +141,7 @@ def get(self, k, default: Optional[Column] = None) -> Optional[Column]: return default def __iter__(self): - return iter(self.columns) + return iter(self._columns) def __repr__(self): ''' diff --git a/pydbml/database.py b/pydbml/database.py index 41dd589..4c90240 100644 --- a/pydbml/database.py +++ b/pydbml/database.py @@ -2,6 +2,7 @@ from typing import Dict from typing import List from typing import Optional +from typing import Tuple from typing import Union from ._classes.sticky_note import StickyNote @@ -25,7 +26,7 @@ def __init__( ) -> None: self.sql_renderer = sql_renderer self.dbml_renderer = dbml_renderer - self.tables: List['Table'] = [] + self._tables: List['Table'] = [] self.table_dict: Dict[str, 'Table'] = {} self.refs: List['Reference'] = [] self.enums: List['Enum'] = [] @@ -34,19 +35,24 @@ def __init__( self.project: Optional['Project'] = None self.allow_properties = allow_properties + @property + def tables(self) -> Tuple['Table', ...]: + """Returns a read-only tuple of tables.""" + return tuple(self._tables) + def __repr__(self) -> str: return f"" def __getitem__(self, k: Union[int, str]) -> Table: if isinstance(k, int): - return self.tables[k] + return self._tables[k] elif isinstance(k, str): return self.table_dict[k] else: raise TypeError('indeces must be str or int') def __iter__(self): - return iter(self.tables) + return iter(self._tables) def _set_database(self, obj: Any) -> None: obj.database = self @@ -71,7 +77,7 @@ def add(self, obj: Any) -> Any: raise DatabaseValidationError(f'Unsupported type {type(obj)}.') def add_table(self, obj: Table) -> Table: - if obj in self.tables: + if obj in self._tables: raise DatabaseValidationError(f'{obj} is already in the database.') if obj.full_name in self.table_dict: raise DatabaseValidationError(f'Table {obj.full_name} is already in the database.') @@ -80,7 +86,7 @@ def add_table(self, obj: Table) -> Table: self._set_database(obj) - self.tables.append(obj) + self._tables.append(obj) self.table_dict[obj.full_name] = obj if obj.alias: self.table_dict[obj.alias] = obj @@ -152,10 +158,10 @@ def delete(self, obj: Any) -> Any: def delete_table(self, obj: Table) -> Table: try: - index = self.tables.index(obj) + index = self._tables.index(obj) except ValueError: raise DatabaseValidationError(f'{obj} is not in the database.') - self._unset_database(self.tables.pop(index)) + self._unset_database(self._tables.pop(index)) result = self.table_dict.pop(obj.full_name) if obj.alias: self.table_dict.pop(obj.alias) diff --git a/pydbml/renderer/sql/default/utils.py b/pydbml/renderer/sql/default/utils.py index 8befac3..6afe70e 100644 --- a/pydbml/renderer/sql/default/utils.py +++ b/pydbml/renderer/sql/default/utils.py @@ -1,4 +1,4 @@ -from typing import List, Dict, Union +from typing import List, Dict, Union, Sequence from pydbml.classes import Enum, Reference, Table from pydbml.constants import MANY_TO_ONE, ONE_TO_MANY @@ -9,7 +9,7 @@ def comment_to_sql(val: str) -> str: return comment(val, '--') -def reorder_tables_for_sql(tables: List['Table'], refs: List['Reference']) -> List['Table']: +def reorder_tables_for_sql(tables: Sequence['Table'], refs: List['Reference']) -> List['Table']: """ Attempt to reorder the tables, so that they are defined in SQL before they are referenced by inline foreign keys. diff --git a/test/test_classes/test_table.py b/test/test_classes/test_table.py index 345ee99..924bf7d 100644 --- a/test/test_classes/test_table.py +++ b/test/test_classes/test_table.py @@ -96,9 +96,13 @@ def test_add_column(self) -> None: t.add_column(c2) self.assertEqual(c1.table, t) self.assertEqual(c2.table, t) - self.assertEqual(t.columns, [c1, c2]) + self.assertEqual(t.columns, (c1, c2)) with self.assertRaises(TypeError): t.add_column('wrong type') + c3 = Column('value', 'varchar') + with self.assertRaises(AttributeError): + # shouldn't be possible to modify columns directly + t.columns.append(c3) def test_delete_column(self) -> None: t = Table('products') diff --git a/test/test_database.py b/test/test_database.py index 095e75c..2bce004 100644 --- a/test/test_database.py +++ b/test/test_database.py @@ -61,6 +61,9 @@ def test_add_table_bad(self) -> None: t2 = Table('test_table') with self.assertRaises(DatabaseValidationError): database.add_table(t2) + with self.assertRaises(AttributeError): + # shouldn't be possible to modify tables directly + database.tables.append(t2) def test_delete_table(self) -> None: c = Column('test', 'varchar', True)