Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 15 additions & 9 deletions pydbml/_classes/table.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from typing import List
from typing import Optional
from typing import TYPE_CHECKING
from typing import Tuple
from typing import Union

from pydbml.exceptions import ColumnNotFoundError
Expand Down Expand Up @@ -38,7 +39,7 @@ def __init__(self,
self.database: Optional[Database] = None
self.name = name
self.schema = schema
self.columns: List[Column] = []
self._columns: List[Column] = []
for column in columns or []:
self.add_column(column)
self.indexes: List[Index] = []
Expand All @@ -51,6 +52,11 @@ def __init__(self,
self.abstract = abstract
self.properties = properties if properties else {}

@property
def columns(self) -> Tuple[Column, ...]:
"""Returns a read-only tuple of columns."""
return tuple(self._columns)

@property
def note(self):
return self._note
Expand All @@ -75,18 +81,18 @@ def add_column(self, c: Column) -> None:
if not isinstance(c, Column):
raise TypeError('Columns must be of type Column')
c.table = self
self.columns.append(c)
self._columns.append(c)

def delete_column(self, c: Union[Column, int]) -> Column:
if isinstance(c, Column):
if c in self.columns:
if c in self._columns:
c.table = None
return self.columns.pop(self.columns.index(c))
return self._columns.pop(self._columns.index(c))
else:
raise ColumnNotFoundError(f'Column {c} if missing in the table')
elif isinstance(c, int):
self.columns[c].table = None
return self.columns.pop(c)
self._columns[c].table = None
return self._columns.pop(c)

def add_index(self, i: Index) -> None:
'''
Expand Down Expand Up @@ -119,9 +125,9 @@ def get_refs(self) -> List['Reference']:

def __getitem__(self, k: Union[int, str]) -> Column:
if isinstance(k, int):
return self.columns[k]
return self._columns[k]
elif isinstance(k, str):
for c in self.columns:
for c in self._columns:
if c.name == k:
return c
raise ColumnNotFoundError(f'Column {k} not present in table {self.name}')
Expand All @@ -135,7 +141,7 @@ def get(self, k, default: Optional[Column] = None) -> Optional[Column]:
return default

def __iter__(self):
return iter(self.columns)
return iter(self._columns)

def __repr__(self):
'''
Expand Down
20 changes: 13 additions & 7 deletions pydbml/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from typing import Dict
from typing import List
from typing import Optional
from typing import Tuple
from typing import Union

from ._classes.sticky_note import StickyNote
Expand All @@ -25,7 +26,7 @@ def __init__(
) -> None:
self.sql_renderer = sql_renderer
self.dbml_renderer = dbml_renderer
self.tables: List['Table'] = []
self._tables: List['Table'] = []
self.table_dict: Dict[str, 'Table'] = {}
self.refs: List['Reference'] = []
self.enums: List['Enum'] = []
Expand All @@ -34,19 +35,24 @@ def __init__(
self.project: Optional['Project'] = None
self.allow_properties = allow_properties

@property
def tables(self) -> Tuple['Table', ...]:
"""Returns a read-only tuple of tables."""
return tuple(self._tables)

def __repr__(self) -> str:
return f"<Database>"

def __getitem__(self, k: Union[int, str]) -> Table:
if isinstance(k, int):
return self.tables[k]
return self._tables[k]
elif isinstance(k, str):
return self.table_dict[k]
else:
raise TypeError('indeces must be str or int')

def __iter__(self):
return iter(self.tables)
return iter(self._tables)

def _set_database(self, obj: Any) -> None:
obj.database = self
Expand All @@ -71,7 +77,7 @@ def add(self, obj: Any) -> Any:
raise DatabaseValidationError(f'Unsupported type {type(obj)}.')

def add_table(self, obj: Table) -> Table:
if obj in self.tables:
if obj in self._tables:
raise DatabaseValidationError(f'{obj} is already in the database.')
if obj.full_name in self.table_dict:
raise DatabaseValidationError(f'Table {obj.full_name} is already in the database.')
Expand All @@ -80,7 +86,7 @@ def add_table(self, obj: Table) -> Table:

self._set_database(obj)

self.tables.append(obj)
self._tables.append(obj)
self.table_dict[obj.full_name] = obj
if obj.alias:
self.table_dict[obj.alias] = obj
Expand Down Expand Up @@ -152,10 +158,10 @@ def delete(self, obj: Any) -> Any:

def delete_table(self, obj: Table) -> Table:
try:
index = self.tables.index(obj)
index = self._tables.index(obj)
except ValueError:
raise DatabaseValidationError(f'{obj} is not in the database.')
self._unset_database(self.tables.pop(index))
self._unset_database(self._tables.pop(index))
result = self.table_dict.pop(obj.full_name)
if obj.alias:
self.table_dict.pop(obj.alias)
Expand Down
4 changes: 2 additions & 2 deletions pydbml/renderer/sql/default/utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import List, Dict, Union
from typing import List, Dict, Union, Sequence

from pydbml.classes import Enum, Reference, Table
from pydbml.constants import MANY_TO_ONE, ONE_TO_MANY
Expand All @@ -9,7 +9,7 @@ def comment_to_sql(val: str) -> str:
return comment(val, '--')


def reorder_tables_for_sql(tables: List['Table'], refs: List['Reference']) -> List['Table']:
def reorder_tables_for_sql(tables: Sequence['Table'], refs: List['Reference']) -> List['Table']:
"""
Attempt to reorder the tables, so that they are defined in SQL before they are referenced by
inline foreign keys.
Expand Down
6 changes: 5 additions & 1 deletion test/test_classes/test_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,9 +96,13 @@ def test_add_column(self) -> None:
t.add_column(c2)
self.assertEqual(c1.table, t)
self.assertEqual(c2.table, t)
self.assertEqual(t.columns, [c1, c2])
self.assertEqual(t.columns, (c1, c2))
with self.assertRaises(TypeError):
t.add_column('wrong type')
c3 = Column('value', 'varchar')
with self.assertRaises(AttributeError):
# shouldn't be possible to modify columns directly
t.columns.append(c3)

def test_delete_column(self) -> None:
t = Table('products')
Expand Down
3 changes: 3 additions & 0 deletions test/test_database.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,9 @@ def test_add_table_bad(self) -> None:
t2 = Table('test_table')
with self.assertRaises(DatabaseValidationError):
database.add_table(t2)
with self.assertRaises(AttributeError):
# shouldn't be possible to modify tables directly
database.tables.append(t2)

def test_delete_table(self) -> None:
c = Column('test', 'varchar', True)
Expand Down