Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
540 changes: 540 additions & 0 deletions docs/src/design/semantic-matching-spec.md

Large diffs are not rendered by default.

79 changes: 55 additions & 24 deletions src/datajoint/condition.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import decimal
import inspect
import json
import logging
import re
import uuid
from dataclasses import dataclass
Expand All @@ -14,6 +15,8 @@

from .errors import DataJointError

logger = logging.getLogger(__name__.split(".")[0])

JSON_PATTERN = re.compile(r"^(?P<attr>\w+)(\.(?P<path>[\w.*\[\]]+))?(:(?P<type>[\w(,\s)]+))?$")


Expand Down Expand Up @@ -95,39 +98,68 @@ def __init__(self, restriction):
self.restriction = restriction


def assert_join_compatibility(expr1, expr2):
def assert_join_compatibility(expr1, expr2, semantic_check=True):
"""
Determine if expressions expr1 and expr2 are join-compatible. To be join-compatible,
the matching attributes in the two expressions must be in the primary key of one or the
other expression.
Raises an exception if not compatible.
Determine if expressions expr1 and expr2 are join-compatible.

With semantic_check=True (default):
Raises an error if there are non-homologous namesakes (same name, different lineage).
This prevents accidental joins on attributes that share names but represent
different entities.

If the ~lineage table doesn't exist for either schema, a warning is issued
and semantic checking is disabled (join proceeds as natural join).

With semantic_check=False:
No lineage checking. All namesake attributes are matched (natural join behavior).

:param expr1: A QueryExpression object
:param expr2: A QueryExpression object
:param semantic_check: If True (default), use semantic matching and error on conflicts
"""
from .expression import QueryExpression, U

for rel in (expr1, expr2):
if not isinstance(rel, (U, QueryExpression)):
raise DataJointError("Object %r is not a QueryExpression and cannot be joined." % rel)
if not isinstance(expr1, U) and not isinstance(expr2, U): # dj.U is always compatible
try:
raise DataJointError(
"Cannot join query expressions on dependent attribute `%s`"
% next(r for r in set(expr1.heading.secondary_attributes).intersection(expr2.heading.secondary_attributes))
)
except StopIteration:
pass # all ok


def make_condition(query_expression, condition, columns):
# dj.U is always compatible (it represents all possible lineages)
if isinstance(expr1, U) or isinstance(expr2, U):
return

if semantic_check:
# Check if lineage tracking is available for both expressions
if not expr1.heading.lineage_available or not expr2.heading.lineage_available:
logger.warning(
"Semantic check disabled: ~lineage table not found. "
"To enable semantic matching, rebuild lineage with: "
"schema.rebuild_lineage()"
)
return

# Error on non-homologous namesakes
namesakes = set(expr1.heading.names) & set(expr2.heading.names)
for name in namesakes:
lineage1 = expr1.heading[name].lineage
lineage2 = expr2.heading[name].lineage
# Semantic match requires both lineages to be non-None and equal
if lineage1 is None or lineage2 is None or lineage1 != lineage2:
raise DataJointError(
f"Cannot join on attribute `{name}`: "
f"different lineages ({lineage1} vs {lineage2}). "
f"Use .proj() to rename one of the attributes."
)


def make_condition(query_expression, condition, columns, semantic_check=True):
"""
Translate the input condition into the equivalent SQL condition (a string)

:param query_expression: a dj.QueryExpression object to apply condition
:param condition: any valid restriction object.
:param columns: a set passed by reference to collect all column names used in the
condition.
:param semantic_check: If True (default), use semantic matching and error on conflicts.
:return: an SQL condition string or a boolean value.
"""
from .expression import Aggregation, QueryExpression, U
Expand Down Expand Up @@ -180,7 +212,11 @@ def combine_conditions(negate, conditions):
# restrict by AndList
if isinstance(condition, AndList):
# omit all conditions that evaluate to True
items = [item for item in (make_condition(query_expression, cond, columns) for cond in condition) if item is not True]
items = [
item
for item in (make_condition(query_expression, cond, columns, semantic_check) for cond in condition)
if item is not True
]
if any(item is False for item in items):
return negate # if any item is False, the whole thing is False
if not items:
Expand Down Expand Up @@ -226,14 +262,9 @@ def combine_conditions(negate, conditions):
condition = condition()

# restrict by another expression (aka semijoin and antijoin)
check_compatibility = True
if isinstance(condition, PromiscuousOperand):
condition = condition.operand
check_compatibility = False

if isinstance(condition, QueryExpression):
if check_compatibility:
assert_join_compatibility(query_expression, condition)
assert_join_compatibility(query_expression, condition, semantic_check=semantic_check)
# Always match on all namesakes (natural join semantics)
common_attributes = [q for q in condition.heading.names if q in query_expression.heading.names]
columns.update(common_attributes)
if isinstance(condition, Aggregation):
Expand All @@ -255,7 +286,7 @@ def combine_conditions(negate, conditions):

# if iterable (but not a string, a QueryExpression, or an AndList), treat as an OrList
try:
or_list = [make_condition(query_expression, q, columns) for q in condition]
or_list = [make_condition(query_expression, q, columns, semantic_check) for q in condition]
except TypeError:
raise DataJointError("Invalid restriction type %r" % condition)
else:
Expand Down
19 changes: 16 additions & 3 deletions src/datajoint/declare.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ def is_foreign_key(line):
return arrow_position >= 0 and not any(c in line[:arrow_position] for c in "\"#'")


def compile_foreign_key(line, context, attributes, primary_key, attr_sql, foreign_key_sql, index_sql):
def compile_foreign_key(line, context, attributes, primary_key, attr_sql, foreign_key_sql, index_sql, fk_attribute_map=None):
"""
:param line: a line from a table definition
:param context: namespace containing referenced objects
Expand All @@ -151,6 +151,7 @@ def compile_foreign_key(line, context, attributes, primary_key, attr_sql, foreig
:param attr_sql: list of sql statements defining attributes -- to be updated by this function.
:param foreign_key_sql: list of sql statements specifying foreign key constraints -- to be updated by this function.
:param index_sql: list of INDEX declaration statements, duplicate or redundant indexes are ok.
:param fk_attribute_map: dict mapping child attr -> (parent_table, parent_attr) -- to be updated by this function.
"""
# Parse and validate
from .expression import QueryExpression
Expand Down Expand Up @@ -194,6 +195,11 @@ def compile_foreign_key(line, context, attributes, primary_key, attr_sql, foreig
if primary_key is not None:
primary_key.append(attr)
attr_sql.append(ref.heading[attr].sql.replace("NOT NULL ", "", int(is_nullable)))
# Track FK attribute mapping for lineage: child_attr -> (parent_table, parent_attr)
if fk_attribute_map is not None:
parent_table = ref.support[0] # e.g., `schema`.`table`
parent_attr = ref.heading[attr].original_name
fk_attribute_map[attr] = (parent_table, parent_attr)

# declare the foreign key
foreign_key_sql.append(
Expand Down Expand Up @@ -223,6 +229,7 @@ def prepare_declare(definition, context):
foreign_key_sql = []
index_sql = []
external_stores = []
fk_attribute_map = {} # child_attr -> (parent_table, parent_attr)

for line in definition:
if not line or line.startswith("#"): # ignore additional comments
Expand All @@ -238,6 +245,7 @@ def prepare_declare(definition, context):
attribute_sql,
foreign_key_sql,
index_sql,
fk_attribute_map,
)
elif re.match(r"^(unique\s+)?index\s*.*$", line, re.I): # index
compile_index(line, index_sql)
Expand All @@ -258,6 +266,7 @@ def prepare_declare(definition, context):
foreign_key_sql,
index_sql,
external_stores,
fk_attribute_map,
)


Expand Down Expand Up @@ -285,6 +294,7 @@ def declare(full_table_name, definition, context):
foreign_key_sql,
index_sql,
external_stores,
fk_attribute_map,
) = prepare_declare(definition, context)

if config.get("add_hidden_timestamp", False):
Expand All @@ -297,11 +307,12 @@ def declare(full_table_name, definition, context):
if not primary_key:
raise DataJointError("Table must have a primary key")

return (
sql = (
"CREATE TABLE IF NOT EXISTS %s (\n" % full_table_name
+ ",\n".join(attribute_sql + ["PRIMARY KEY (`" + "`,`".join(primary_key) + "`)"] + foreign_key_sql + index_sql)
+ '\n) ENGINE=InnoDB, COMMENT "%s"' % table_comment
), external_stores
)
return sql, external_stores, primary_key, fk_attribute_map


def _make_attribute_alter(new, old, primary_key):
Expand Down Expand Up @@ -387,6 +398,7 @@ def alter(definition, old_definition, context):
foreign_key_sql,
index_sql,
external_stores,
_fk_attribute_map,
) = prepare_declare(definition, context)
(
table_comment_,
Expand All @@ -395,6 +407,7 @@ def alter(definition, old_definition, context):
foreign_key_sql_,
index_sql_,
external_stores_,
_fk_attribute_map_,
) = prepare_declare(old_definition, context)

# analyze differences between declarations
Expand Down
91 changes: 55 additions & 36 deletions src/datajoint/expression.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
from .condition import (
AndList,
Not,
PromiscuousOperand,
Top,
assert_join_compatibility,
extract_column_names,
Expand Down Expand Up @@ -152,13 +151,22 @@ def make_subquery(self):
result._heading = self.heading.make_subquery_heading()
return result

def restrict(self, restriction):
def restrict(self, restriction, semantic_check=True):
"""
Produces a new expression with the new restriction applied.
rel.restrict(restriction) is equivalent to rel & restriction.
rel.restrict(Not(restriction)) is equivalent to rel - restriction

:param restriction: a sequence or an array (treated as OR list), another QueryExpression,
an SQL condition string, or an AndList.
:param semantic_check: If True (default), use semantic matching - only match on
homologous namesakes and error on non-homologous namesakes.
If False, use natural matching on all namesakes (no lineage checking).
:return: A new QueryExpression with the restriction applied.

rel.restrict(restriction) is equivalent to rel & restriction.
rel.restrict(Not(restriction)) is equivalent to rel - restriction

The primary key of the result is unaffected.
Successive restrictions are combined as logical AND: r & a & b is equivalent to r & AndList((a, b))
Successive restrictions are combined as logical AND: r & a & b is equivalent to r & AndList((a, b))
Any QueryExpression, collection, or sequence other than an AndList are treated as OrLists
(logical disjunction of conditions)
Inverse restriction is accomplished by either using the subtraction operator or the Not class.
Expand All @@ -185,17 +193,14 @@ def restrict(self, restriction):
rel - None rel
rel - any_empty_entity_set rel

When arg is another QueryExpression, the restriction rel & arg restricts rel to elements that match at least
When arg is another QueryExpression, the restriction rel & arg restricts rel to elements that match at least
one element in arg (hence arg is treated as an OrList).
Conversely, rel - arg restricts rel to elements that do not match any elements in arg.
Conversely, rel - arg restricts rel to elements that do not match any elements in arg.
Two elements match when their common attributes have equal values or when they have no common attributes.
All shared attributes must be in the primary key of either rel or arg or both or an error will be raised.

QueryExpression.restrict is the only access point that modifies restrictions. All other operators must
ultimately call restrict()

:param restriction: a sequence or an array (treated as OR list), another QueryExpression, an SQL condition
string, or an AndList.
"""
attributes = set()
if isinstance(restriction, Top):
Expand All @@ -204,7 +209,7 @@ def restrict(self, restriction):
) # make subquery to avoid overwriting existing Top
result._top = restriction
return result
new_condition = make_condition(self, restriction, attributes)
new_condition = make_condition(self, restriction, attributes, semantic_check=semantic_check)
if new_condition is True:
return self # restriction has no effect, return the same object
# check that all attributes in condition are present in the query
Expand Down Expand Up @@ -240,14 +245,11 @@ def __and__(self, restriction):
return self.restrict(restriction)

def __xor__(self, restriction):
"""
Permissive restriction operator ignoring compatibility check e.g. ``q1 ^ q2``.
"""
if inspect.isclass(restriction) and issubclass(restriction, QueryExpression):
restriction = restriction()
if isinstance(restriction, Not):
return self.restrict(Not(PromiscuousOperand(restriction.restriction)))
return self.restrict(PromiscuousOperand(restriction))
"""The ^ operator has been removed in DataJoint 2.0."""
raise DataJointError(
"The ^ operator has been removed in DataJoint 2.0. "
"Use .restrict(other, semantic_check=False) for restrictions without semantic checking."
)

def __sub__(self, restriction):
"""
Expand All @@ -274,30 +276,37 @@ def __mul__(self, other):
return self.join(other)

def __matmul__(self, other):
"""
Permissive join of query expressions `self` and `other` ignoring compatibility check
e.g. ``q1 @ q2``.
"""
if inspect.isclass(other) and issubclass(other, QueryExpression):
other = other() # instantiate
return self.join(other, semantic_check=False)
"""The @ operator has been removed in DataJoint 2.0."""
raise DataJointError(
"The @ operator has been removed in DataJoint 2.0. "
"Use .join(other, semantic_check=False) for joins without semantic checking."
)

def join(self, other, semantic_check=True, left=False):
"""
create the joined QueryExpression.
a * b is short for A.join(B)
a @ b is short for A.join(B, semantic_check=False)
Additionally, left=True will retain the rows of self, effectively performing a left join.
Create the joined QueryExpression.

:param other: QueryExpression to join with
:param semantic_check: If True (default), use semantic matching - only match on
homologous namesakes (same lineage) and error on non-homologous namesakes.
If False, use natural join on all namesakes (no lineage checking).
:param left: If True, perform a left join (retain all rows from self)
:return: The joined QueryExpression

a * b is short for a.join(b)
"""
# trigger subqueries if joining on renamed attributes
# Joining with U is no longer supported
if isinstance(other, U):
return other * self
raise DataJointError(
"table * dj.U(...) is no longer supported in DataJoint 2.0. "
"This pattern is no longer necessary with the new semantic matching system."
)
if inspect.isclass(other) and issubclass(other, QueryExpression):
other = other() # instantiate
if not isinstance(other, QueryExpression):
raise DataJointError("The argument of join must be a QueryExpression")
if semantic_check:
assert_join_compatibility(self, other)
assert_join_compatibility(self, other, semantic_check=semantic_check)
# Always natural join on all namesakes
join_attributes = set(n for n in self.heading.names if n in other.heading.names)
# needs subquery if self's FROM clause has common attributes with other's FROM clause
need_subquery1 = need_subquery2 = bool(
Expand Down Expand Up @@ -826,8 +835,18 @@ def join(self, other, left=False):
return result

def __mul__(self, other):
"""shorthand for join"""
return self.join(other)
"""The * operator with dj.U has been removed in DataJoint 2.0."""
raise DataJointError(
"dj.U(...) * table is no longer supported in DataJoint 2.0. "
"This pattern is no longer necessary with the new semantic matching system."
)

def __sub__(self, other):
"""Anti-restriction with dj.U produces an infinite set."""
raise DataJointError(
"dj.U(...) - table produces an infinite set and is not supported. "
"Consider using a different approach for your query."
)

def aggr(self, group, **named_attributes):
"""
Expand Down
Loading
Loading