# Copyright (c) 2014 Mirantis Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import datetime
import pytz

from oslo_db import api as oslo_db_api

from sqlalchemy.orm import subqueryload
import sqlalchemy.sql.expression as expr
import sqlalchemy.sql.functions as func
from sqlalchemy import Integer, and_, asc, desc, or_
from wsme.exc import ClientSideError

from storyboard._i18n import _
from storyboard.common import exception as exc
from storyboard.db.api import base as api_base
from storyboard.db.api import projects as projects_api
from storyboard.db.api import story_tags
from storyboard.db.api import story_types
from storyboard.db.api import teams as teams_api
from storyboard.db.api import users as users_api
from storyboard.db import models


# Story status -> the task statuses that make a story have that status.
# Used when computing a story's status in SQL and when filtering by status.
STORY_STATUSES = {
    'active': ['todo', 'inprogress', 'review'],
    'merged': ['merged'],
    'invalid': ['invalid'],
}


class _StorySummary(object):

    """Abstraction for a Story's calculated status and task counts.

    Takes a row of results from a query for both the Story model and some
    extra data about its status and count of tasks per-status. The status
    is calculated at query time, hence needing this abstraction.

    This was previously implemented as a different database model, but
    that approach was inflexible about status calculations, and caused
    queries with very poor performance at scale to be generated.

    """

    def __init__(self, row):
        self._summary_dict = {}

        story_obj = row.Story
        if story_obj:
            self._summary_dict.update(story_obj.as_dict())

        self._summary_dict['status'] = row.status
        for task_status in models.Task.TASK_STATUSES:
            self._summary_dict[task_status] = getattr(row, task_status)

    def __getattr__(self, attr):
        if attr in self._summary_dict:
            return self._summary_dict[attr]
        raise AttributeError(attr)

    def __getitem__(self, key):
        return self._summary_dict[key]

    def as_dict(self):
        return self._summary_dict


def story_get_simple(story_id, session=None, current_user=None,
                     no_permissions=False):
    """Fetch a single ``models.Story`` by ID, with its tags eagerly loaded.

    Unlike ``story_get``, this returns the ORM model instance itself rather
    than a computed summary.

    ``no_permissions=True`` bypasses the privacy filter. Only pass it when
    access has already been checked, for example when bumping the
    "updated_at" field after a comment is made on a story.

    :param story_id: ID of the story to fetch.
    :param session: DB session to run the query in.
    :param current_user: The ID of the user making the request.
    :param no_permissions: If True, don't filter out private stories.
    :return: The matching story, or None if there is none (or it isn't
        visible to ``current_user``).

    """
    query = api_base.model_query(models.Story, session)
    query = query.options(subqueryload(models.Story.tags))
    query = query.filter_by(id=story_id)

    if no_permissions:
        return query.first()

    # Drop stories the requesting user isn't allowed to see.
    visible = api_base.filter_private_stories(query, current_user)
    return visible.first()


def story_get(story_id, session=None, current_user=None):
    """Return a summary object for the story with the given ID.

    The object returned from this function isn't a SQLAlchemy model, and
    so shouldn't be used with an expectation of modifications being reflected
    in the underlying database.

    If trying to query for an actual ``models.Story`` object,
    ``story_get_simple`` should be used instead.

    :param story_id: ID of the story to return.
    :param session: Accepted for API compatibility but currently unused; the
        query runs on the session created by ``_story_query_base``.
    :param current_user: The ID of the user making the request.
    :return: A summary object for the story, or None if nothing found.

    """
    query = _story_query_base()
    query = query.filter(models.Story.id == story_id)

    query = query.options(subqueryload(models.Story.due_dates))
    query = query.options(subqueryload(models.Story.tags))
    query = query.options(subqueryload(models.Story.permissions))

    # Filter out stories that the current user can't see
    query = api_base.filter_private_stories(query, current_user)

    # Group by story, as _story_build_query does, so the aggregate status
    # and task-count columns are computed per story. Without a GROUP BY the
    # query mixes aggregate and non-aggregated columns.
    query = query.group_by(models.Story.id)

    row = query.first()
    # Guard against a row whose Story entity is NULL (no matching story).
    if row is None or row[0] is None:
        return None
    return _StorySummary(row)


def _order_and_paginate(query, offset, limit, sort_keys, sort_dir=None,
                        sort_dirs=None):
    # NOTE(SotK): This is inspired by the `paginate_query` function in
    # oslo_db/sqlalchemy/utils.py, however it implements offset/limit
    # pagination rather than using a marker.
    #
    # Even if we supported marker-based pagination we'd still need a
    # custom implementation here, since we want to be able to do
    # pagination on fields that aren't part of the database model
    # (namely `status`). The oslo.db implementation requires that
    # all the sort keys are part of the ORM model, and also implements
    # comparison in a way that won't work for our use-case.
    if sort_dir and sort_dirs:
        raise AssertionError('Only one of sort_dir and sort_dirs can be set')

    if sort_dir is None and sort_dirs is None:
        sort_dir = 'asc'

    if sort_dirs is None:
        sort_dirs = [sort_dir for _key in sort_keys]

    if len(sort_dirs) != len(sort_keys):
        raise AssertionError(
            'sort_dirs and sort_keys must have the same length')

    for current_key, current_dir in zip(sort_keys, sort_dirs):
        try:
            sort_key_attr = getattr(models.Story, current_key)
        except AttributeError:
            # Fallback to ordering by the raw sort_key provided.
            # TODO(SotK): make sure that this is only an expected field?
            sort_key_attr = current_key
        sort_dir_func = {
            'asc': asc,
            'desc': desc
        }[current_dir]
        query = query.order_by(sort_dir_func(sort_key_attr))

    # TODO(SotK): We should switch to using marker-based pagination
    # as soon as we can, it just needs the web client to understand
    # what to do with a marker id.
    if offset is not None:
        query = query.offset(offset)

    if limit is not None:
        query = query.limit(limit)

    return query


def story_get_all(title=None, description=None, status=None, assignee_id=None,
                  creator_id=None, project_group_id=None, project_id=None,
                  subscriber_id=None, tags=None, updated_since=None,
                  board_id=None, worklist_id=None, marker=None, offset=None,
                  limit=None, tags_filter_type="all", sort_field='id',
                  sort_dir='asc', current_user=None):
    """List summaries of the stories matching the given filters.

    Results are ordered by ``sort_field``/``sort_dir`` and paginated with
    ``offset``/``limit``. ``marker`` is accepted but not used here.

    :param status: A story status or a list of them (keys of
        STORY_STATUSES); a single non-list value is wrapped in a list.
    :param current_user: The ID of the user making the request; stories
        they can't see are excluded.
    :return: A list of ``_StorySummary`` objects.

    """
    # Fall back to the defaults when a caller explicitly passes None (or
    # another falsy value) for the sort options.
    sort_field = sort_field or 'id'
    sort_dir = sort_dir or 'asc'
    if status is not None and not isinstance(status, list):
        status = [status]

    query = _story_build_query(
        title=title, description=description, status=status,
        assignee_id=assignee_id, creator_id=creator_id,
        project_group_id=project_group_id, project_id=project_id,
        subscriber_id=subscriber_id, tags=tags,
        updated_since=updated_since, board_id=board_id,
        worklist_id=worklist_id, tags_filter_type=tags_filter_type,
        current_user=current_user)

    query = _order_and_paginate(query, offset, limit, [sort_field],
                                sort_dir=sort_dir)
    query = query.options(subqueryload(models.Story.due_dates),
                          subqueryload(models.Story.tags),
                          subqueryload(models.Story.permissions))

    return [_StorySummary(row) for row in query.all()]


def story_get_count(title=None, description=None, status=None,
                    assignee_id=None, creator_id=None,
                    project_group_id=None, project_id=None,
                    subscriber_id=None, tags=None, updated_since=None,
                    board_id=None, worklist_id=None,
                    tags_filter_type="all", current_user=None):
    """Count the stories matching the given filters.

    Accepts the same filters as ``story_get_all`` (without sorting or
    pagination), so the count agrees with the listed results.

    :param status: A story status or a list of them (keys of
        STORY_STATUSES); a single non-list value is wrapped in a list.
    :param current_user: The ID of the user making the request; stories
        they can't see are not counted.
    :return: The number of matching stories.

    """
    # Normalise a scalar status exactly as story_get_all does. Otherwise
    # _story_build_query iterates a string such as 'active' character by
    # character and rejects it as an unknown status.
    if not isinstance(status, list) and status is not None:
        status = [status]

    query = _story_build_query(title=title,
                               description=description,
                               status=status,
                               assignee_id=assignee_id,
                               creator_id=creator_id,
                               project_group_id=project_group_id,
                               project_id=project_id,
                               subscriber_id=subscriber_id,
                               tags=tags,
                               updated_since=updated_since,
                               board_id=board_id,
                               worklist_id=worklist_id,
                               tags_filter_type=tags_filter_type,
                               current_user=current_user)
    return query.count()


def _story_query_base():
    """Build the base query behind story summaries.

    Each row holds the ``Story`` entity, a computed ``status`` column
    ('active', 'merged' or 'invalid') and, for every task status, an
    integer column named after that status counting matching tasks. Tasks
    are LEFT OUTER JOINed so stories without tasks are still included.
    No GROUP BY is added here; the aggregates are per story only once the
    caller groups by ``models.Story.id``.

    """
    session = api_base.get_session()

    active_tasks = func.sum(
        models.Task.status.in_(STORY_STATUSES['active']))
    merged_tasks = func.sum(models.Task.status == 'merged')
    status_column = expr.case(
        [(active_tasks > 0, 'active'),
         (merged_tasks > 0, 'merged')],
        else_='invalid'
    ).label('status')

    per_status_counts = [
        expr.cast(func.sum(models.Task.status == task_status),
                  Integer).label(task_status)
        for task_status in models.Task.TASK_STATUSES
    ]

    query = session.query(models.Story, status_column, *per_status_counts)

    # The status calculations aggregate over these joined task rows.
    return query.outerjoin(
        (models.Task, models.Task.story_id == models.Story.id))


def _story_build_query(title=None, description=None, status=None,
                       assignee_id=None, creator_id=None,
                       project_group_id=None, project_id=None,
                       subscriber_id=None, tags=None, updated_since=None,
                       board_id=None, worklist_id=None,
                       tags_filter_type="all", current_user=None):
    """Build a filtered query of story summaries.

    Starts from ``_story_query_base`` (Story plus computed status and
    per-task-status counts). It then hides stories the requester can't see,
    applies the given filters, groups by story id and finally filters on
    the computed story status with a HAVING clause.

    :param status: Iterable of story statuses (keys of STORY_STATUSES), or
        None for no status filter. It is iterated element by element, so a
        bare string would be treated as a sequence of characters.
    :param tags: Tag names, matched according to ``tags_filter_type``.
    :param tags_filter_type: 'all' (story must have every tag) or 'any'
        (story must have at least one of the tags).
    :param current_user: The ID of the user making the request; used for
        privacy filtering of stories, and of worklists/boards when
        filtering by them.
    :return: An unexecuted query whose rows can be wrapped in
        ``_StorySummary``.
    :raises exc.NotFound: if ``tags_filter_type`` is neither 'all' nor
        'any' (and ``tags`` is non-empty).
    :raises ValueError: if ``status`` contains an unknown story status.

    """
    # Get a basic story query, containing task summaries
    query = _story_query_base()

    # Get rid of anything the requester can't see
    query = api_base.filter_private_stories(query, current_user)

    # Apply basic filters
    query = api_base.apply_query_filters(query=query,
                                         model=models.Story,
                                         title=title,
                                         description=description,
                                         creator_id=creator_id)

    # Only get stories that have been updated since a given datetime
    if updated_since:
        query = query.filter(models.Story.updated_at > updated_since)

    # Filtering by tags
    if tags:
        if tags_filter_type == 'all':
            # One EXISTS-style filter per tag, so every tag must be present.
            for tag in tags:
                query = query.filter(models.Story.tags.any(name=tag))
        elif tags_filter_type == 'any':
            query = query.filter(models.Story.tags.any
                                 (models.StoryTag.name.in_(tags)))
        else:
            raise exc.NotFound("Tags filter not found.")

    # NOTE(review): the project group, assignee and project filters below
    # restrict the joined Task rows in WHERE, i.e. before GROUP BY, so the
    # computed status and per-status task counts then only cover the
    # matching tasks rather than all of the story's tasks. Confirm this is
    # intended.

    # Filtering by project group
    if project_group_id:
        # The project_group_mapping table contains all the information we need
        # here; namely the project id (to map to tasks) and the project group
        # id (which we're actually filtering by). There's no need to actually
        # join either of the projects or project_groups tables.
        pgm = models.project_group_mapping
        query = query.join((pgm, pgm.c.project_id == models.Task.project_id))
        query = query.filter(pgm.c.project_group_id == project_group_id)

    # Filtering by task assignee
    if assignee_id:
        query = query.filter(models.Task.assignee_id == assignee_id)

    # Filtering by project
    if project_id:
        query = query.filter(models.Task.project_id == project_id)

    # Filtering by subscriber
    if subscriber_id:
        # Subscriptions are polymorphic; match only those targeting stories.
        on_clause = and_(
            models.Subscription.target_id == models.Story.id,
            models.Subscription.target_type == 'story'
        )
        query = query.join((models.Subscription, on_clause))
        query = query.filter(models.Subscription.user_id == subscriber_id)

    # Filtering by either worklist or board requires joining the worklists
    # table, via worklist_items
    if worklist_id or board_id:
        # Worklist items are polymorphic; match only those that are stories.
        # The second join target (Worklist) has no explicit ON clause, so
        # SQLAlchemy infers it from the model relationships/foreign keys.
        on_clause = and_(
            models.WorklistItem.item_id == models.Story.id,
            models.WorklistItem.item_type == 'story'
        )
        query = query.join((models.WorklistItem, on_clause), models.Worklist)

    # Filter by worklist
    if worklist_id:
        query = api_base.filter_private_worklists(
            query, current_user, hide_lanes=False)
        query = query.filter(models.Worklist.id == worklist_id)

    # Filter by board
    if board_id:
        # Continues from the Worklist join above, through BoardWorklist to
        # Board (join conditions inferred by SQLAlchemy).
        query = query.join(models.BoardWorklist, models.Board)
        query = api_base.filter_private_boards(query, current_user)
        query = query.filter(models.Board.id == board_id)

    # We need to do GROUP BY to allow the status calculation to work correctly
    query = query.group_by(models.Story.id)

    # Filter by story status. This duplicates the definition of the different
    # story statuses from the SELECT clause of the query, but means we can
    # do all the filtering without a subquery.
    if status is not None:
        if any(s not in STORY_STATUSES for s in status):
            raise ValueError(
                f'Story status must be in {STORY_STATUSES.keys()}')
        # Each requested status contributes one HAVING criterion; a story
        # matches if it satisfies any of them.
        criteria = []
        if 'active' in status:
            # A story is active if it has at least one unmerged valid task,
            # so the check is simple
            criteria.append(func.sum(
                models.Task.status.in_(STORY_STATUSES['active'])) > 0
            )
        if 'merged' in status:
            # A story is merged if it has at least one merged task, and also
            # doesn't meet the criteria to be active
            criteria.append(and_(
                func.sum(models.Task.status.in_(STORY_STATUSES['merged'])) > 0,
                func.sum(models.Task.status.in_(STORY_STATUSES['active'])) == 0
            ))
        if 'invalid' in status:
            # A story is invalid if it only has invalid tasks, or no tasks
            # NOTE(review): the ``>= 0`` comparison holds whenever the sum is
            # non-NULL, so the first branch effectively reduces to "no
            # active or merged tasks". With no tasks at all the outer join
            # yields NULL sums, which the ``sum(Task.id) IS NULL`` branch
            # catches.
            criteria.append(or_(
                and_(
                    func.sum(
                        models.Task.status.in_(STORY_STATUSES['invalid'])
                    ) >= 0,
                    func.sum(
                        models.Task.status.in_(
                            STORY_STATUSES['active']
                            + STORY_STATUSES['merged']
                        )
                    ) == 0
                ),
                func.sum(models.Task.id) == None  # noqa
            ))
        query = query.having(or_(*criteria))

    return query


def story_create(values):
    """Create a Story from ``values`` via the generic entity helper.

    :param values: Field values for the new story.
    :return: Whatever ``api_base.entity_create`` returns for the new row.
    """
    created = api_base.entity_create(models.Story, values)
    return created


def story_update(story_id, values, current_user=None):
    """Update a story, touch its projects, and return a fresh summary.

    :param story_id: ID of the story to update.
    :param values: Field values to apply.
    :param current_user: The ID of the user making the request.
    :return: The updated story as returned by ``story_get``.
    :raises exc.NotFound: (from ``get_project_ids``) if the story isn't
        found or isn't visible to ``current_user``.
    """
    api_base.entity_update(models.Story, story_id, values)

    # Refresh the updated_at of every project that has a task in the story.
    affected = get_project_ids(story_id, current_user=current_user)
    for affected_project in affected:
        projects_api.project_update_updated_at(affected_project)

    return story_get(story_id, current_user=current_user)


def get_project_ids(story_id, current_user=None):
    """Return the set of project IDs referenced by a story's tasks.

    :param story_id: ID of the story.
    :param current_user: The ID of the user making the request.
    :return: A set of ``task.project_id`` values.
    :raises exc.NotFound: if the story doesn't exist or isn't visible.
    """
    session = api_base.get_session()
    with session.begin(subtransactions=True):
        story = story_get_simple(story_id, session=session,
                                 current_user=current_user)
        if not story:
            raise exc.NotFound(_("%(name)s not found") %
                               {'name': "Story"})
        ids = set(task.project_id for task in story.tasks)
    # Detach the story so it isn't tied to this session after returning.
    session.expunge(story)
    return ids


@oslo_db_api.wrap_db_retry(max_retries=3, retry_on_deadlock=True,
                           retry_interval=0.5, inc_retry_interval=True)
def story_update_updated_at(story_id):
    """Set a story's ``updated_at`` to the current UTC time.

    Permission filtering is skipped (``no_permissions=True``), so callers
    must already have checked access. The decorator retries the call on
    DB deadlock, up to three times.

    :param story_id: ID of the story to touch.
    :raises exc.NotFound: if the story doesn't exist.
    """
    session = api_base.get_session()
    with session.begin(subtransactions=True):
        target = story_get_simple(story_id, session=session,
                                  no_permissions=True)
        if not target:
            raise exc.NotFound(_("%(name)s %(id)s not found") %
                               {'name': "Story", 'id': story_id})
        target.updated_at = datetime.datetime.now(tz=pytz.utc)
        session.add(target)
    session.expunge(target)


def story_add_tag(story_id, tag_name, current_user=None):
    """Attach a tag to a story, creating the tag if it doesn't exist yet.

    Also bumps the story's ``updated_at``.

    :param story_id: ID of the story.
    :param tag_name: Name of the tag to attach.
    :param current_user: The ID of the user making the request.
    :raises exc.NotFound: if the story isn't found or isn't visible.
    :raises exc.DBDuplicateEntry: if the story already has that tag.
    """
    session = api_base.get_session()

    with session.begin(subtransactions=True):
        # Reuse an existing tag with this name, or create a new one.
        tag = (story_tags.tag_get_by_name(tag_name, session=session)
               or story_tags.tag_create({"name": tag_name}))

        story = story_get_simple(
            story_id, session=session, current_user=current_user)
        if not story:
            raise exc.NotFound(_("%(name)s %(id)s not found") %
                               {'name': "Story", 'id': story_id})

        already_tagged = any(t.name == tag_name for t in story.tags)
        if already_tagged:
            raise exc.DBDuplicateEntry(
                _("The Story %(id)d already has a tag %(tag)s") %
                {'id': story_id, 'tag': tag_name})

        story.tags.append(tag)
        story.updated_at = datetime.datetime.now(tz=pytz.utc)
        session.add(story)

    session.expunge(story)


def story_remove_tag(story_id, tag_name, current_user=None):
    """Detach a tag from a story and bump the story's ``updated_at``.

    :param story_id: ID of the story.
    :param tag_name: Name of the tag to remove.
    :param current_user: The ID of the user making the request.
    :raises exc.NotFound: if the story isn't found/visible, or doesn't
        have the tag.
    """
    session = api_base.get_session()

    with session.begin(subtransactions=True):
        story = story_get_simple(
            story_id, session=session, current_user=current_user)
        if not story:
            raise exc.NotFound(_("%(name)s %(id)s not found") %
                               {'name': "Story", 'id': story_id})

        # First tag on the story with a matching name, if any.
        match = next((t for t in story.tags if t.name == tag_name), None)
        if match is None:
            raise exc.NotFound(_("The Story %(story_id)d has "
                                 "no tag %(tag)s") %
                               {'story_id': story_id, 'tag': tag_name})

        story.tags.remove(match)
        story.updated_at = datetime.datetime.now(tz=pytz.utc)
        session.add(story)

    session.expunge(story)


def story_delete(story_id, current_user=None):
    """Hard-delete a story, if it exists and ``current_user`` can see it.

    Does nothing when the story isn't found or isn't visible.
    """
    if not story_get_simple(story_id, current_user=current_user):
        return
    api_base.entity_hard_delete(models.Story, story_id)


def story_check_story_type_id(story_dict):
    """Remove a falsy ``story_type_id`` entry from ``story_dict`` in place.

    A missing key or a truthy value leaves the dict untouched.
    """
    key = "story_type_id"
    if key not in story_dict:
        return
    if not story_dict[key]:
        del story_dict[key]


def story_can_create_story(story_type_id):
    """Return whether a story of the given type may be created.

    A falsy ``story_type_id`` means "no specific type" and is always
    allowed; otherwise the type must exist and be visible.

    :raises exc.NotFound: if the story type doesn't exist.
    """
    if story_type_id:
        story_type = story_types.story_type_get(story_type_id)
        if not story_type:
            raise exc.NotFound("Story type %s not found." % story_type_id)
        return bool(story_type.visible)
    return True


def story_can_mutate(story, new_story_type_id):
    """Return whether ``story`` may be changed to ``new_story_type_id``.

    :param story: The story (its ``id`` and ``story_type_id`` are read).
    :param new_story_type_id: Target story type; falsy means "no change".
    :return: True if the change is allowed, False otherwise.
    :raises exc.NotFound: if the new story type doesn't exist.
    """
    if not new_story_type_id:
        return True

    if story.story_type_id == new_story_type_id:
        return True

    old_story_type = story_types.story_type_get(story.story_type_id)
    new_story_type = story_types.story_type_get(new_story_type_id)

    if not new_story_type:
        raise exc.NotFound(_("Story type %s not found.") % new_story_type_id)

    # A non-private story type can't be changed into a private one here.
    if not old_story_type.private and new_story_type.private:
        return False

    # The transition must be explicitly allowed as a mutation.
    mutation = story_types.story_type_get_mutations(story.story_type_id,
                                                    new_story_type_id)
    if not mutation:
        return False

    if not new_story_type.restricted:
        return True

    # A restricted story type is refused if any of the story's tasks is on
    # a restricted branch.
    query = api_base.model_query(models.Task)
    tasks = query.filter_by(story_id=story.id).all()
    branch_ids = list({task.branch_id for task in tasks if task.branch_id})

    # No task targets a branch, so there is no restricted branch to find;
    # skip the query (and its empty IN clause) entirely.
    if not branch_ids:
        return True

    query = api_base.model_query(models.Branch)
    restricted_branch = query.filter(models.Branch.id.in_(branch_ids),
                                     models.Branch.restricted == 1).first()

    return not restricted_branch


def create_permission(story, users, teams, session=None):
    """Create the ``view_story`` permission for a story and grant it.

    :param story: The story (only its ``id`` is used; it is re-fetched).
    :param users: Users to grant the permission to, or None.
    :param teams: Teams to grant the permission to, or None.
    :param session: DB session to use for the story, user and team lookups.
    :return: The newly created permission.
    """
    story = story_get_simple(story.id, session=session, no_permissions=True)
    permission_dict = {
        'name': 'view_story_%d' % story.id,
        'codename': 'view_story'
    }
    permission = api_base.entity_create(models.Permission, permission_dict)
    story.permissions.append(permission)
    # Re-fetch users/teams in the same session as the story. Otherwise they
    # belong to a different session than the objects they are linked to.
    if users is not None:
        for user in users:
            user = users_api.user_get(user.id, session=session)
            user.permissions.append(permission)
    if teams is not None:
        for team in teams:
            team = teams_api.team_get(team.id, session=session)
            team.permissions.append(permission)
    return permission


def update_permission(story, users, teams, session=None):
    """Replace the user and/or team lists on a story's first permission.

    :param story: The story (only its ``id`` is used; it is re-fetched).
    :param users: New list of users, or None to leave users unchanged.
    :param teams: New list of teams, or None to leave teams unchanged.
    :param session: DB session used to re-fetch the story.
    :return: The result of ``api_base.entity_update`` on the permission.
    :raises exc.NotFound: if the story has no permissions.
    """
    story = story_get_simple(story.id, session=session, no_permissions=True)
    if not story.permissions:
        raise exc.NotFound(_("Permissions for story %d not found.")
                           % story.id)

    permission = story.permissions[0]
    changes = {'name': permission.name, 'codename': permission.codename}
    if users is not None:
        changes['users'] = [users_api.user_get(u.id) for u in users]
    if teams is not None:
        changes['teams'] = [teams_api.team_get(t.id) for t in teams]

    return api_base.entity_update(models.Permission, permission.id, changes)


def add_user(story_id, user_id, current_user=None):
    """Add a user to the list of users allowed to view a story.

    If the story has no permission yet, one is created containing just this
    user (via ``create_permission``). Otherwise the user is appended to the
    story's first permission.

    :param story_id: ID of the story.
    :param user_id: ID of the user to add.
    :param current_user: The ID of the user making the request.
    :return: The story, or None when a new permission had to be created.
    :raises exc.NotFound: if the story (as visible to ``current_user``) or
        the user doesn't exist.
    :raises ClientSideError: if the user is already in the permission list.
    """
    session = api_base.get_session()

    with session.begin(subtransactions=True):
        story = story_get_simple(
            story_id, session=session, current_user=current_user)
        if not story:
            raise exc.NotFound(_("Story %s not found") % story_id)

        user = users_api.user_get(user_id, session=session)
        if not user:
            raise exc.NotFound(_("User %s not found") % user_id)

        if not story.permissions:
            create_permission(story, [user], [], session)
            # NOTE(review): this path returns None rather than the story;
            # kept unchanged for existing callers.
            return
        permission = story.permissions[0]
        if user_id in [u.id for u in permission.users]:
            # Named %-placeholders must be written %(name)d; a %{name}d
            # form is invalid and would raise ValueError while formatting.
            raise ClientSideError(_("The User %(user_id)d is already in the "
                                    "permission list for Story "
                                    "%(story_id)d") %
                                  {"user_id": user_id, "story_id": story_id})
        permission.users.append(user)
        session.add(permission)

    return story


def delete_user(story_id, user_id, current_user=None):
    """Remove a user from the list of users allowed to view a story.

    Only the story's first permission is consulted.

    :param story_id: ID of the story.
    :param user_id: ID of the user to remove.
    :param current_user: The ID of the user making the request.
    :return: The story.
    :raises exc.NotFound: if the story (as visible to ``current_user``) or
        the user doesn't exist.
    :raises ClientSideError: if the user isn't in the permission list.
    """
    session = api_base.get_session()

    with session.begin(subtransactions=True):
        story = story_get_simple(
            story_id, session=session, current_user=current_user)
        if not story:
            raise exc.NotFound(_("Story %s not found") % story_id)

        user = users_api.user_get(user_id, session=session)
        if not user:
            raise exc.NotFound(_("User %s not found") % user_id)

        permission = story.permissions[0] if story.permissions else None
        entry = None
        if permission is not None:
            entry = next(
                (u for u in permission.users if u.id == user_id), None)
        if entry is None:
            # Named %-placeholders must be written %(name)d; a %{name}d
            # form is invalid and would raise ValueError while formatting.
            raise ClientSideError(_("The User %(user_id)d isn't in the "
                                    "permission list for Story "
                                    "%(story_id)d") %
                                  {"user_id": user_id, "story_id": story_id})

        permission.users.remove(entry)
        session.add(permission)

    return story


def add_team(story_id, team_id, current_user=None):
    """Add a team to the list of teams allowed to view a story.

    If the story has no permission yet, one is created containing just this
    team (via ``create_permission``). Otherwise the team is appended to the
    story's first permission.

    :param story_id: ID of the story.
    :param team_id: ID of the team to add.
    :param current_user: The ID of the user making the request.
    :return: The story, or None when a new permission had to be created.
    :raises exc.NotFound: if the story (as visible to ``current_user``) or
        the team doesn't exist.
    :raises ClientSideError: if the team is already in the permission list.
    """
    session = api_base.get_session()

    with session.begin(subtransactions=True):
        story = story_get_simple(
            story_id, session=session, current_user=current_user)
        if not story:
            raise exc.NotFound(_("Story %s not found") % story_id)

        team = teams_api.team_get(team_id, session=session)
        if not team:
            raise exc.NotFound(_("Team %s not found") % team_id)

        if not story.permissions:
            create_permission(story, [], [team], session)
            # NOTE(review): this path returns None rather than the story;
            # kept unchanged for existing callers.
            return
        permission = story.permissions[0]
        if team_id in [t.id for t in permission.teams]:
            # Named %-placeholders must be written %(name)d; a %{name}d
            # form is invalid and would raise ValueError while formatting.
            raise ClientSideError(_("The Team %(team_id)d is already in the "
                                    "permission list for Story "
                                    "%(story_id)d") %
                                  {"team_id": team_id, "story_id": story_id})
        permission.teams.append(team)
        session.add(permission)

    return story


def delete_team(story_id, team_id, current_user=None):
    """Remove a team from the list of teams allowed to view a story.

    Only the story's first permission is consulted.

    :param story_id: ID of the story.
    :param team_id: ID of the team to remove.
    :param current_user: The ID of the user making the request.
    :return: The story.
    :raises exc.NotFound: if the story (as visible to ``current_user``) or
        the team doesn't exist.
    :raises ClientSideError: if the team isn't in the permission list.
    """
    session = api_base.get_session()

    with session.begin(subtransactions=True):
        story = story_get_simple(
            story_id, session=session, current_user=current_user)
        if not story:
            raise exc.NotFound(_("Story %s not found") % story_id)

        team = teams_api.team_get(team_id, session=session)
        if not team:
            raise exc.NotFound(_("Team %s not found") % team_id)

        permission = story.permissions[0] if story.permissions else None
        entry = None
        if permission is not None:
            entry = next(
                (t for t in permission.teams if t.id == team_id), None)
        if entry is None:
            # Named %-placeholders must be written %(name)d; a %{name}d
            # form is invalid and would raise ValueError while formatting.
            raise ClientSideError(_("The Team %(team_id)d isn't in the "
                                    "permission list for Story "
                                    "%(story_id)d") %
                                  {"team_id": team_id, "story_id": story_id})

        permission.teams.remove(entry)
        session.add(permission)

    return story
