# mapper/util.py
# Copyright (C) 2005, 2006, 2007, 2008, 2009 Michael Bayer mike_mp@zzzcomputing.com
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
import sqlalchemy.exceptions as sa_exc
from sqlalchemy import sql, util
from sqlalchemy.sql import expression, util as sql_util, operators
from sqlalchemy.orm.interfaces import MapperExtension, EXT_CONTINUE, PropComparator, MapperProperty, AttributeExtension
from sqlalchemy.orm import attributes, exc
# Every cascade name that may appear in relation(cascade=...); CascadeOptions
# raises ArgumentError for any name outside this set.
all_cascades = frozenset([
    "delete", "delete-orphan", "all", "merge",
    "expunge", "save-update", "refresh-expire",
    "none",
])
# NOTE(review): two-element key tuple; its consumers are outside this view.
_INSTRUMENTOR = ('mapper', 'instrumentor')
class CascadeOptions(object):
    """Keeps track of the options sent to relation().cascade"""

    def __init__(self, arg=""):
        """Parse a comma-separated cascade specification.

        arg -- string such as ``"save-update, merge"``; whitespace around
        each name is stripped.  An empty / false value enables nothing.
        ``"all"`` enables delete, save-update, merge, expunge and
        refresh-expire, but not delete-orphan.

        Raises sa_exc.ArgumentError for a name not in ``all_cascades``.
        Warns (via util.warn) when 'delete-orphan' is requested without
        'delete' (or 'all').
        """
        if not arg:
            values = set()
        else:
            values = set(c.strip() for c in arg.split(','))

        # Validate up front so a bad argument fails before any warning is
        # emitted; sort so the reported name is deterministic when several
        # names are invalid.
        invalid = sorted(values.difference(all_cascades))
        if invalid:
            raise sa_exc.ArgumentError("Invalid cascade option '%s'" % invalid[0])

        everything = "all" in values
        self.delete_orphan = "delete-orphan" in values
        self.delete = "delete" in values or everything
        self.save_update = "save-update" in values or everything
        self.merge = "merge" in values or everything
        self.expunge = "expunge" in values or everything
        self.refresh_expire = "refresh-expire" in values or everything

        if self.delete_orphan and not self.delete:
            util.warn("The 'delete-orphan' cascade option requires "
                      "'delete'. This will raise an error in 0.6.")

    def __contains__(self, item):
        """Support ``"save-update" in options`` using hyphenated option names.

        The hyphens are mapped onto the underscored attribute names; unknown
        names are reported as absent.
        """
        return getattr(self, item.replace("-", "_"), False)

    def __repr__(self):
        # These must be attribute names (underscores), not option names: a
        # getattr() on 'refresh-expire' would always miss and the flag would
        # never be shown.
        flags = ['delete', 'save_update', 'merge', 'expunge',
                 'delete_orphan', 'refresh_expire']
        enabled = [x for x in flags if getattr(self, x, False) is True]
        return "CascadeOptions(%s)" % repr(",".join(enabled))
class Validator(AttributeExtension):
    """Runs a validation method on an attribute value to be set or appended.

    The Validator class is used by the :func:`~sqlalchemy.orm.validates`
    decorator, and direct access is usually not needed.
    """

    def __init__(self, key, validator):
        """Construct a new Validator.

        key - name of the attribute to be validated;
        will be passed as the second argument to
        the validation method (the first is the object instance itself).

        validator - a function or instance method which accepts
        three arguments; an instance (usually just 'self' for a method),
        the key name of the attribute, and the value. The function should
        return the same value given, unless it wishes to modify it.
        """
        self.validator = validator
        self.key = key

    def append(self, state, value, initiator):
        """Run the validator on a value being appended.

        Calls ``validator(state.obj(), key, value)`` and returns its result;
        ``initiator`` is not used.
        """
        instance = state.obj()
        return self.validator(instance, self.key, value)

    def set(self, state, value, oldvalue, initiator):
        """Run the validator on a value being assigned.

        Calls ``validator(state.obj(), key, value)`` and returns its result;
        ``oldvalue`` and ``initiator`` are not used.
        """
        instance = state.obj()
        return self.validator(instance, self.key, value)
def polymorphic_union(table_map, typecolname, aliasname='p_union'):
    """Create a ``UNION`` statement used by a polymorphic mapper.

    See :ref:`concrete_inheritance` for an example of how
    this is used.

    table_map
      dictionary mapping each discriminator value to a selectable.  A
      ``Select`` value is wrapped in an alias (MySQL will not select from
      a bare SELECT).  The dictionary itself is not modified.
    typecolname
      name of a literal column, holding the discriminator value, added to
      every SELECT; ``None`` omits it.
    aliasname
      name given to the alias of the resulting ``UNION ALL``.

    Each SELECT lists every column key found in any of the selectables; a
    selectable lacking a key contributes ``CAST(NULL AS <type>)`` labeled
    with that key, using the type of the last selectable seen that has it.
    Columns are ordered by first appearance (iterating ``table_map`` and
    each selectable's columns) rather than by set hashing order, so the
    rendered column order does not vary from process to process.
    """
    colnames = []   # ordered, de-duplicated column keys
    seen = set()
    types = {}
    parts = []      # (discriminator, selectable, {key: column})
    for discriminator, table in table_map.items():
        # mysql doesn't like selecting from a select; make it an alias of the select
        if isinstance(table, sql.Select):
            table = table.alias()
        colmap = {}
        for c in table.c:
            if c.key not in seen:
                seen.add(c.key)
                colnames.append(c.key)
            colmap[c.key] = c
            types[c.key] = c.type
        parts.append((discriminator, table, colmap))

    def col(name, colmap):
        # Real column if this selectable has it, else a typed NULL placeholder.
        try:
            return colmap[name]
        except KeyError:
            return sql.cast(sql.null(), types[name]).label(name)

    result = []
    for discriminator, table, colmap in parts:
        columns = [col(name, colmap) for name in colnames]
        if typecolname is not None:
            # The 1-tuple keeps tuple discriminators from being read as
            # several format arguments; doubling single quotes keeps the
            # value a well-formed SQL string literal.
            text = ("%s" % (discriminator,)).replace("'", "''")
            columns.append(sql.literal_column("'%s'" % text).label(typecolname))
        result.append(sql.select(columns, from_obj=[table]))
    return sql.union_all(*result).alias(aliasname)
def identity_key(*args, **kwargs):
    """Get an identity key.

    Valid call signatures:

    * ``identity_key(class, ident)``

      class
          mapped class (must be a positional argument)

      ident
          primary key, if the key is composite this is a tuple

    * ``identity_key(instance=instance)``

      instance
          object instance (must be given as a keyword arg)

    * ``identity_key(class, row=row)``

      class
          mapped class (must be a positional argument)

      row
          result proxy row (must be given as a keyword arg)

    ``identity_key(class, ident=ident)`` is accepted as well, and a third
    positional argument after ``ident`` is tolerated and ignored.

    Raises sa_exc.ArgumentError for more than three positional arguments or
    for unrecognized keyword arguments.  A missing required keyword
    (``ident`` / ``instance``) surfaces as KeyError.
    """
    if args:
        nargs = len(args)
        if nargs > 3:
            raise sa_exc.ArgumentError("expected up to three "
                "positional arguments, got %s" % nargs)
        class_ = args[0]
        from_row = False
        if nargs == 1:
            if "row" in kwargs:
                row = kwargs.pop("row")
                from_row = True
            else:
                ident = kwargs.pop("ident")
        else:
            # Only (class, ident) are used; unpacking all three positionals
            # into two names would raise ValueError.
            # NOTE(review): the third argument is presumably a legacy
            # parameter kept for signature compatibility -- confirm.
            ident = args[1]
        if kwargs:
            raise sa_exc.ArgumentError("unknown keyword arguments: %s"
                % ", ".join(kwargs.keys()))
        mapper = class_mapper(class_)
        if from_row:
            return mapper.identity_key_from_row(row)
        return mapper.identity_key_from_primary_key(ident)

    instance = kwargs.pop("instance")
    if kwargs:
        raise sa_exc.ArgumentError("unknown keyword arguments: %s"
            % ", ".join(kwargs.keys()))
    mapper = object_mapper(instance)
    return mapper.identity_key_from_instance(instance)
class ExtensionCarrier(dict):
    """Fronts an ordered collection of MapperExtension objects.

    Bundles multiple MapperExtensions into a unified callable unit,
    encapsulating ordering, looping and EXT_CONTINUE logic.  The
    ExtensionCarrier implements the MapperExtension interface, e.g.::

        carrier.after_insert(...args...)

    The dictionary interface provides containment for implemented
    method names mapped to a callable which executes that method
    for participating extensions.
    """

    # Public (non-underscore) attribute names of MapperExtension.
    interface = set(method for method in dir(MapperExtension)
                    if not method.startswith('_'))

    def __init__(self, extensions=None):
        """Create a carrier, appending each of ``extensions`` in order."""
        self._extensions = []
        for ext in extensions or ():
            self.append(ext)

    def copy(self):
        """Return a new ExtensionCarrier over the same extensions."""
        return ExtensionCarrier(self._extensions)

    def push(self, extension):
        """Insert a MapperExtension at the beginning of the collection."""
        self._register(extension)
        self._extensions.insert(0, extension)

    def append(self, extension):
        """Append a MapperExtension at the end of the collection."""
        self._register(extension)
        self._extensions.append(extension)

    def __iter__(self):
        """Iterate over MapperExtensions in the collection."""
        return iter(self._extensions)

    @staticmethod
    def _underlying_function(fn):
        """Return the plain function behind a bound or unbound method."""
        # Python 2 methods expose ``im_func``; 2.6+ and 3.x expose ``__func__``.
        # Anything else (plain function, other callable) is returned as is.
        return getattr(fn, 'im_func', getattr(fn, '__func__', fn))

    def _register(self, extension):
        """Register callable fronts for overridden interface methods.

        Only interface methods that are not fronted yet are examined.  A
        method counts as overridden when ``extension`` provides it and its
        underlying function differs from MapperExtension's own.
        """
        unwrap = self._underlying_function
        # Use keys() explicitly: set.difference() iterates its argument, and
        # iterating ``self`` goes through __iter__ above, which yields
        # extensions rather than the already-registered method names.
        for method in self.interface.difference(self.keys()):
            impl = getattr(extension, method, None)
            if not impl:
                continue
            # ``extension.method`` is a bound method, a new object on every
            # attribute access, so an identity test against the class
            # attribute never matches; compare the underlying functions.
            if unwrap(impl) is not unwrap(getattr(MapperExtension, method)):
                self[method] = self._create_do(method)

    def _create_do(self, method):
        """Return a closure that loops over impls of the named method.

        The closure calls ``method`` on every extension in order and returns
        the first result that is not EXT_CONTINUE, or EXT_CONTINUE when all
        of them continue (or there are no extensions).
        """
        def _do(*args, **kwargs):
            for ext in self._extensions:
                ret = getattr(ext, method)(*args, **kwargs)
                if ret is not EXT_CONTINUE:
                    return ret
            else:
                return EXT_CONTINUE
        _do.__name__ = method
        return _do

    @staticmethod
    def _pass(*args, **kwargs):
        """Stand-in for interface methods with no registered front."""
        return EXT_CONTINUE

    def __getattr__(self, key):
        """Delegate MapperExtension methods to bundled fronts.

        Interface names without a front resolve to ``_pass``; any other
        missing name raises AttributeError.
        """
        if key not in self.interface:
            raise AttributeError(key)
        return self.get(key, self._pass)
class ORMAdapter(sql_util.ColumnAdapter):
    """Extends ColumnAdapter to accept ORM entities.

    The selectable is extracted from the given entity,
    and the AliasedClass if any is referenced.
    """

    def __init__(self, entity, equivalents=None, chain_to=None):
        """Adapt against the selectable that ``_entity_info(entity)`` reports.

        ``aliased_class`` is ``entity`` itself when it is an aliased class,
        otherwise None.
        """
        mapper, selectable, is_aliased_class = _entity_info(entity)
        self.mapper = mapper
        if is_aliased_class:
            aliased = entity
        else:
            aliased = None
        self.aliased_class = aliased
        sql_util.ColumnAdapter.__init__(self, selectable, equivalents, chain_to)

    def replace(self, elem):
        """Adapt ``elem`` unless its 'parentmapper' annotation fails
        ``isa(self.mapper)``, in which case None is returned."""
        parent = elem._annotations.get('parentmapper', None)
        if parent and not parent.isa(self.mapper):
            return None
        return sql_util.ColumnAdapter.replace(self, elem)
class AliasedClass(object):
"""Represents an 'alias'ed form of a mapped class for usage with Query.
The ORM equivalent of a :class:`~sqlalchemy.sql.expression.Alias`
object, this object mimics the mapped class using a
__getattr__ scheme and maintains a reference to a
real Alias object. It indicates to Query that the
selectable produced for this class should be aliased,
and also adapts PropComparators produced by the class'
InstrumentedAttributes so that they adapt the
"local" side of SQL expressions against the alias.
"""
def __init__(self, cls, alias=None, name=None):
    """Build an alias of the mapped class ``cls``.

    alias -- selectable to alias against; when false, a fresh ``.alias()``
    of the mapper's ``_with_polymorphic_selectable`` is used.
    name -- label name kept as ``_sa_label_name`` (used for the RowTuple
    returned by Query).
    """
    mapper = _class_to_mapper(cls)
    self.__mapper = mapper
    self.__target = mapper.class_
    if not alias:
        alias = mapper._with_polymorphic_selectable.alias()
    self.__adapter = sql_util.ClauseAdapter(alias, equivalents=mapper._equivalent_columns)
    self.__alias = alias
    # used to assign a name to the RowTuple object returned by Query.
    self._sa_label_name = name
    self.__name__ = 'AliasedClass_' + str(self.__target)
def __getstate__(self):
    """Return the picklable state: mapper, alias and label name."""
    state = {}
    state['mapper'] = self.__mapper
    state['alias'] = self.__alias
    state['name'] = self._sa_label_name
    return state
def __setstate__(self, state):
    """Rebuild from the dict produced by ``__getstate__``.

    The adapter is recreated from the stored mapper and alias rather than
    being pickled.
    """
    mapper = state['mapper']
    self.__mapper = mapper
    self.__target = mapper.class_
    alias = state['alias']
    self.__adapter = sql_util.ClauseAdapter(alias, equivalents=mapper._equivalent_columns)
    self.__alias = alias
    self._sa_label_name = state['name']
    self.__name__ = 'AliasedClass_' + str(self.__target)
def __adapt_element(self, elem):
    """Adapt ``elem`` to the alias and annotate it with this entity/mapper."""
    adapted = self.__adapter.traverse(elem)
    annotations = {'parententity': self, 'parentmapper': self.__mapper}
    return adapted._annotate(annotations)
def __adapt_prop(self, prop):
    """Build a QueryableAttribute for ``prop`` whose comparator is adapted to
    the alias, cache it on this object under ``prop.key`` and return it."""
    key = prop.key
    existing = getattr(self.__target, key)
    comparator = existing.comparator.adapted(self.__adapt_element)
    queryattr = attributes.QueryableAttribute(
        key, impl=existing.impl, parententity=self, comparator=comparator)
    setattr(self, key, queryattr)
    return queryattr
def __getattr__(self, key):
    """Resolve ``key`` against the aliased class.

    Lookup order:

    1. A mapped property (``_get_property(key, raiseerr=False)``) is returned
       as an alias-adapted attribute via ``__adapt_prop``, which also stores
       it on this object with ``setattr`` so later lookups of the same key do
       not reach ``__getattr__`` again.
    2. Otherwise the target class's ``__mro__`` is walked, fetching the raw
       attribute from each base with ``object.__getattribute__``; the first
       base that has it wins.  If none does, AttributeError(key) is raised.
    3. The raw attribute is then translated:

       - objects with ``func_code`` (Python 2 functions/methods): when
         ``getattr(target, key)`` is bound (``im_self`` is not None), the
         underlying ``im_func`` is re-bound to this AliasedClass; otherwise
         ``None`` is returned.
       - other objects with ``__get__`` are invoked as
         ``attr.__get__(None, self)``, with this AliasedClass as the owner.
       - anything else is returned unchanged.
    """
    prop = self.__mapper._get_property(key, raiseerr=False)
    if prop:
        return self.__adapt_prop(prop)

    # Walk the MRO by hand to get the raw (unbound) attribute from whichever
    # base defines it.
    for base in self.__target.__mro__:
        try:
            attr = object.__getattribute__(base, key)
        except AttributeError:
            continue
        else:
            break
    else:
        raise AttributeError(key)

    if hasattr(attr, 'func_code'):
        is_method = getattr(self.__target, key, None)
        if is_method and is_method.im_self is not None:
            # Re-bind to the alias so the method sees the AliasedClass.
            return util.types.MethodType(attr.im_func, self, self)
        else:
            # NOTE(review): an unbound function (e.g. a plain instance method)
            # yields None rather than AttributeError -- confirm intended.
            return None
    elif hasattr(attr, '__get__'):
        return attr.__get__(None, self)
    else:
        return attr
def __repr__(self):
return '