diff --git a/ereuse_devicehub/resources/lot/models.py b/ereuse_devicehub/resources/lot/models.py index 0da6ff85..ecc9d066 100644 --- a/ereuse_devicehub/resources/lot/models.py +++ b/ereuse_devicehub/resources/lot/models.py @@ -2,9 +2,9 @@ import uuid from datetime import datetime from flask import g +from sqlalchemy import TEXT from sqlalchemy.dialects.postgresql import UUID -from sqlalchemy.orm import aliased -from sqlalchemy.sql import expression +from sqlalchemy.sql import expression as exp from sqlalchemy_utils import LtreeType from sqlalchemy_utils.types.ltree import LQUERY from teal.db import UUIDLtree @@ -53,22 +53,32 @@ class Lot(Thing): Path.add(self.id, child) db.session.refresh(self) # todo is this useful? - def remove_child(self, child: 'Lot'): - Path.delete(self.id, child.id) + def remove_child(self, child): + if isinstance(child, Lot): + Path.delete(self.id, child.id) + else: + assert isinstance(child, uuid.UUID) + Path.delete(self.id, child) @property def children(self): """The children lots.""" # From https://stackoverflow.com/a/41158890 - # todo test - cls = self.__class__ - exp = '*.{}.*{{1}}'.format(UUIDLtree.convert(self.id)) - child_lots = aliased(Lot) - + id = UUIDLtree.convert(self.id) return self.query \ - .join(cls.paths) \ - .filter(Path.path.lquery(expression.cast(exp, LQUERY))) \ - .join(child_lots, Path.lot) + .join(self.__class__.paths) \ + .filter(Path.path.lquery(exp.cast('*.{}.*{{1}}'.format(id), LQUERY))) + + @property + def parents(self): + """The parent lots.""" + id = UUIDLtree.convert(self.id) + i = db.func.index(Path.path, id) + parent_id = db.func.replace(exp.cast(db.func.subpath(Path.path, i - 1, i), TEXT), '_', '-') + join_clause = parent_id == exp.cast(Lot.id, TEXT) + return self.query.join(Path, join_clause).filter( + Path.path.lquery(exp.cast('*{{1}}.{}.*'.format(id), LQUERY)) + ) def __contains__(self, child: 'Lot'): return Path.has_lot(self.id, child.id) diff --git a/ereuse_devicehub/resources/lot/models.pyi b/ereuse_devicehub/resources/lot/models.pyi index c50b8165..1d10774e 100644 --- a/ereuse_devicehub/resources/lot/models.pyi +++ b/ereuse_devicehub/resources/lot/models.pyi @@ -31,7 +31,7 @@ class Lot(Thing): def add_child(self, child: Union[Lot, uuid.UUID]): pass - def remove_child(self, child: Lot): + def remove_child(self, child: Union[Lot, uuid.UUID]): pass @classmethod @@ -42,6 +42,10 @@ class Lot(Thing): def children(self) -> LotQuery: pass + @property + def parents(self) -> LotQuery: + pass + class Path: id = ... # type: Column diff --git a/ereuse_devicehub/resources/lot/schemas.py b/ereuse_devicehub/resources/lot/schemas.py index c148e537..d0c723a3 100644 --- a/ereuse_devicehub/resources/lot/schemas.py +++ b/ereuse_devicehub/resources/lot/schemas.py @@ -12,6 +12,5 @@ class Lot(Thing): name = f.String(validate=f.validate.Length(max=STR_SIZE), required=True) closed = f.Boolean(missing=False, description=m.Lot.closed.comment) devices = NestedOn(Device, many=True, dump_only=True) - children = NestedOn('Lot', - many=True, - dump_only=True) + children = NestedOn('Lot', many=True, dump_only=True) + parents = NestedOn('Lot', many=True, dump_only=True) diff --git a/ereuse_devicehub/resources/lot/views.py b/ereuse_devicehub/resources/lot/views.py index d8cb8f81..fd1c7bfa 100644 --- a/ereuse_devicehub/resources/lot/views.py +++ b/ereuse_devicehub/resources/lot/views.py @@ -49,6 +49,7 @@ class LotBaseChildrenView(View): lot = self.get_lot(id) self._post(lot, self.get_ids()) db.session.commit() + ret = self.schema.jsonify(lot) ret.status_code = 201 return ret diff --git a/tests/test_lot.py b/tests/test_lot.py index 20ddbd6a..c3c91b2b 100644 --- a/tests/test_lot.py +++ b/tests/test_lot.py @@ -210,10 +210,15 @@ def test_post_get_lot(user: UserClient): def test_post_add_children_view(user: UserClient): """Tests adding children lots to a lot through the view.""" - l, _ = user.post(({'name': 'Parent'}), res=Lot) + parent, _ = user.post(({'name': 'Parent'}), res=Lot) child, _ = user.post(({'name': 'Child'}), res=Lot) - l, _ = user.post({}, res=Lot, item='{}/children'.format(l['id']), query=[('id', child['id'])]) - assert l['children'][0]['id'] == child['id'] + parent, _ = user.post({}, + res=Lot, + item='{}/children'.format(parent['id']), + query=[('id', child['id'])]) + assert parent['children'][0]['id'] == child['id'] + child, _ = user.get(res=Lot, item=child['id']) + assert child['parents'][0]['id'] == parent['id'] @pytest.mark.xfail(reason='Just develop the test')