#    Copyright 2013 IBM Corp.
#
#    Licensed under the Apache License, Version 2.0 (the "License"); you may
#    not use this file except in compliance with the License. You may obtain
#    a copy of the License at
#
#         http://www.apache.org/licenses/LICENSE-2.0
#
#    Unless required by applicable law or agreed to in writing, software
#    distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
#    WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
#    License for the specific language governing permissions and limitations
#    under the License.

"""Handles database requests from other rock services."""

import copy
import itertools

from oslo import messaging
import six


from rock.agent import api as agent_api
from rock.agent import rpcapi as agent_rpcapi
from rock.db import base
from rock import exception
from rock.i18n import _
#from rock import image
from rock import manager
#from rock import network
#from rock.network.security_group import openstack_driver
# from rock import notifications
from rock import objects
from rock.objects import base as rock_object
from rock.openstack.common import excutils
from rock.openstack.common import jsonutils
from rock.openstack.common import log as logging
from rock.openstack.common import periodic_task
from rock.openstack.common import timeutils
#from rock import quota
#from rock.scheduler import client as scheduler_client
#from rock.scheduler import driver as scheduler_driver
#from rock.scheduler import utils as scheduler_utils

LOG = logging.getLogger(__name__)

# Whitelist of instance fields that an update dict is allowed to touch.
# Instead of giving an update call (instance_update() in the nova code this
# module derives from) a huge list of arguments, callers pass a dict of fields
# which is validated against this list.
# NOTE(review): neither this list nor datetime_fields is referenced in the
# visible part of this module -- presumably consumed elsewhere or leftover
# from nova; verify before relying on it.
allowed_updates = ['task_state', 'vm_state', 'expected_task_state',
                   'power_state', 'access_ip_v4', 'access_ip_v6',
                   'launched_at', 'terminated_at', 'host', 'node',
                   'memory_mb', 'vcpus', 'root_gb', 'ephemeral_gb',
                   'instance_type_id', 'root_device_name', 'launched_on',
                   'progress', 'vm_mode', 'default_ephemeral_device',
                   'default_swap_device',
                   'system_metadata', 'updated_at',
                   ]

# Fields that we want to convert back into a datetime object.
datetime_fields = ['launched_at', 'terminated_at', 'updated_at']


class ConductorManager(manager.Manager):
    """Mission: Conduct things.

    The methods in the base API for rock-conductor are various proxy operations
    performed on behalf of the rock-agent service running on compute nodes.
    Compute nodes are not allowed to directly access the database, so this set
    of methods allows them to get specific work done without locally accessing
    the database.
    """

    target = messaging.Target(version='1.0')

    def __init__(self, *args, **kwargs):
        super(ConductorManager, self).__init__(service_name='server',
                                               *args, **kwargs)
        #self.security_group_api = (
        #    openstack_driver.get_openstack_security_group_driver())
        #self._network_api = None
        # Created lazily on first access through the ``agent_api`` property.
        self._agent_api = None
        self.agent_task_mgr = AgentTaskManager()
        #self.cells_rpcapi = cells_rpcapi.CellsAPI()
        # self.additional_endpoints.append(self.agent_task_mgr)

    @property
    def agent_api(self):
        """Return the agent API client, building it on first use."""
        if self._agent_api is None:
            self._agent_api = agent_api.API()
        return self._agent_api

    def ping(self, context, arg):
        """Echo ``arg`` back to the caller along with the service name."""
        # NOTE(russellb) This method can be removed in 2.0 of this API.  It is
        # now a part of the base rpc API.
        return jsonutils.to_primitive({'service': 'conductor', 'arg': arg})

    @messaging.expected_exceptions(KeyError, ValueError,
                                   exception.InvalidUUID,
                                   exception.InstanceNotFound,
                                   exception.UnexpectedTaskStateError)
    def bw_usage_update(self, context, uuid, mac, start_period,
                        bw_in, bw_out, last_ctr_in, last_ctr_out,
                        last_refreshed, update_cells):
        """Record bandwidth usage, then return the stored usage record.

        The DB write is skipped when all four counters are None, which
        turns the call into a pure lookup.
        """
        counters = (bw_in, bw_out, last_ctr_in, last_ctr_out)
        if any(counter is not None for counter in counters):
            self.db.bw_usage_update(context, uuid, mac, start_period,
                                    bw_in, bw_out, last_ctr_in, last_ctr_out,
                                    last_refreshed,
                                    update_cells=update_cells)
        usage = self.db.bw_usage_get(context, uuid, start_period, mac)
        return jsonutils.to_primitive(usage)

    @messaging.expected_exceptions(exception.ComputeHostNotFound,
                                   exception.HostBinaryNotFound)
    def service_get_all_by(self, context, topic, host, binary):
        """Look up service records by any supported combination of filters.

        Supported combinations: no filter (all services); topic + host;
        host + binary; topic only; host only.  For topic ``'agent'`` with a
        host, the single compute-host record is wrapped in a list.

        :raises ValueError: if only ``binary`` is given; that combination
            has no lookup (previously this surfaced as UnboundLocalError).
        """
        if not any((topic, host, binary)):
            result = self.db.service_get_all(context)
        elif all((topic, host)):
            if topic == 'agent':
                result = self.db.service_get_by_compute_host(context, host)
                # FIXME(comstud) Potentially remove this on bump to v3.0
                result = [result]
            else:
                result = self.db.service_get_by_host_and_topic(context,
                                                               host, topic)
        elif all((host, binary)):
            result = self.db.service_get_by_args(context, host, binary)
        elif topic:
            result = self.db.service_get_all_by_topic(context, topic)
        elif host:
            result = self.db.service_get_all_by_host(context, host)
        else:
            # Only ``binary`` was supplied: there is no DB lookup for it.
            raise ValueError("service_get_all_by: looking up services by "
                             "binary requires a host as well")

        return jsonutils.to_primitive(result)

    # NOTE: the string below is disabled code kept for reference; it is a
    # no-op expression statement in the class body.
    """@messaging.expected_exceptions(exception.InstanceActionNotFound)
    def action_event_start(self, context, values):
        evt = self.db.action_event_start(context, values)
        return jsonutils.to_primitive(evt)

    @messaging.expected_exceptions(exception.InstanceActionNotFound,
                                   exception.InstanceActionEventNotFound)
    def action_event_finish(self, context, values):
        evt = self.db.action_event_finish(context, values)
        return jsonutils.to_primitive(evt)"""

    # @periodic_task.periodic_task(run_immediately=True)
    def allocate_acc(self, context, compute_node_id='2'):
        """Push a fixed test accelerator capability to rock-agent.

        Throttled: only the 1st, 11th, 21st, ... invocation on this manager
        does any work; the others return immediately.  When it runs, it
        loads the accelerators of ``compute_node_id`` with an admin context,
        merges a hard-coded 'ipsec' / 'gb' capability (JSON-encoded) into
        the first accelerator's ``acc_capability``, and sends that
        accelerator to the agent via ``AgentAPI.allocate_acc``.

        :param context: request context used for the agent RPC call.
        :param compute_node_id: compute node whose accelerators are loaded
            (defaults to the previously hard-coded ``'2'``).
        """
        # Invocation counter stored on the instance (starts at 1).
        self.allocate = getattr(self, 'allocate', 0) + 1
        if self.allocate % 10 != 1:
            return
        from rock import context as rock_context
        acc_list = list(objects.AcceleratorList.get_by_compute_node(
            rock_context.get_admin_context(), compute_node_id))
        if not acc_list:
            LOG.warning(_("No accelerators found on compute node %s; "
                          "skipping allocation."), compute_node_id)
            return
        # Only the first accelerator is updated.
        acc = acc_list[0]
        acc_capability = {
            'ipsec': {'dh': {'num': 1, 'pps': 1, 'bps': 1}},
            'gb': {'gb': {'num': 1, 'pps': 1, 'bps': 1}},
        }
        # NOTE(review): assumes acc['acc_capability'] is a mutable mapping of
        # str -> str; it is updated in place and re-assigned.
        temp_capability = acc['acc_capability']
        for key, element in acc_capability.items():
            if isinstance(element, six.string_types):
                temp_capability[key] = element
            else:
                temp_capability[key] = jsonutils.dumps(element)
        acc['acc_capability'] = temp_capability
        LOG.info('==acc====> %s', acc)
        agent_rpcapi.AgentAPI().allocate_acc(context, acc)
        LOG.info("================> OK ")

    def dict_1(self, p_dict):
        """Recursively set every non-dict value in ``p_dict`` to 1.

        Nested dicts are walked and modified in place; ``p_dict`` itself is
        modified in place and returned.
        """
        for key, value in p_dict.items():
            if isinstance(value, dict):
                # Mutates ``value`` in place; the return value is not needed.
                self.dict_1(value)
            else:
                p_dict[key] = 1
        return p_dict

    def service_create(self, context, values):
        """Create a service record from ``values`` and return it."""
        svc = self.db.service_create(context, values)
        return jsonutils.to_primitive(svc)

    @messaging.expected_exceptions(exception.ServiceNotFound)
    def service_destroy(self, context, service_id):
        """Delete the service record with id ``service_id``."""
        self.db.service_destroy(context, service_id)

    def compute_node_create(self, context, values):
        """Create a compute node record from ``values`` and return it."""
        result = self.db.compute_node_create(context, values)
        return jsonutils.to_primitive(result)

    def compute_node_update(self, context, node, values):
        """Update the compute node identified by ``node['id']``."""
        result = self.db.compute_node_update(context, node['id'], values)
        return jsonutils.to_primitive(result)

    def compute_node_delete(self, context, node):
        """Delete the compute node identified by ``node['id']``."""
        result = self.db.compute_node_delete(context, node['id'])
        return jsonutils.to_primitive(result)

    @messaging.expected_exceptions(exception.ServiceNotFound)
    def service_update(self, context, service, values):
        """Update the service identified by ``service['id']``."""
        svc = self.db.service_update(context, service['id'], values)
        return jsonutils.to_primitive(svc)

    def task_log_get(self, context, task_name, begin, end, host, state):
        """Return the task log entry matching the given filters."""
        result = self.db.task_log_get(context, task_name, begin, end, host,
                                      state)
        return jsonutils.to_primitive(result)

    def task_log_begin_task(self, context, task_name, begin, end, host,
                            task_items, message):
        """Record the start of a task, using an elevated context."""
        result = self.db.task_log_begin_task(context.elevated(), task_name,
                                             begin, end, host, task_items,
                                             message)
        return jsonutils.to_primitive(result)

    def task_log_end_task(self, context, task_name, begin, end, host,
                          errors, message):
        """Record the end of a task, using an elevated context."""
        result = self.db.task_log_end_task(context.elevated(), task_name,
                                           begin, end, host, errors, message)
        return jsonutils.to_primitive(result)

    def _object_dispatch(self, target, method, context, args, kwargs):
        """Dispatch a call to an object method.

        This ensures that object methods get called and any exception
        that is raised gets wrapped in an ExpectedException for forwarding
        back to the caller (without spamming the conductor logs).
        """
        try:
            # NOTE(danms): Keep the getattr inside the try block since
            # a missing method is really a client problem
            return getattr(target, method)(context, *args, **kwargs)
        except Exception:
            raise messaging.ExpectedException()

    def object_class_action(self, context, objname, objmethod,
                            objver, args, kwargs):
        """Perform a classmethod action on an object."""
        objclass = rock_object.RockObject.obj_class_from_name(objname,
                                                              objver)
        result = self._object_dispatch(objclass, objmethod, context,
                                       args, kwargs)
        # NOTE(danms): The RPC layer will convert to primitives for us,
        # but in this case, we need to honor the version the client is
        # asking for, so we do it before returning here.
        return (result.obj_to_primitive(target_version=objver)
                if isinstance(result, rock_object.RockObject) else result)

    def object_action(self, context, objinst, objmethod, args, kwargs):
        """Perform an action on an object.

        :returns: ``(updates, result)`` where ``updates`` maps each field
            that changed during the call to its primitive value, plus the
            ``'obj_what_changed'`` key.
        """
        oldobj = objinst.obj_clone()
        result = self._object_dispatch(objinst, objmethod, context,
                                       args, kwargs)
        updates = dict()
        # NOTE(danms): Diff the object with the one passed to us and
        # generate a list of changes to forward back
        for name, field in objinst.fields.items():
            if not objinst.obj_attr_is_set(name):
                # Avoid demand-loading anything
                continue
            if (not oldobj.obj_attr_is_set(name) or
                    oldobj[name] != objinst[name]):
                updates[name] = field.to_primitive(objinst, name,
                                                   objinst[name])
        # This is safe since a field named this would conflict with the
        # method anyway
        updates['obj_what_changed'] = objinst.obj_what_changed()
        return updates, result

    def object_backport(self, context, objinst, target_version):
        """Serialize ``objinst`` at the older ``target_version``."""
        return objinst.obj_to_primitive(target_version=target_version)

    def accelerator_filter(self, context, host_name, acc_specs):
        """Rock accelerator filter for Nova.

        :param context: request context, forwarded to the DB layer.
        :param host_name: name of the candidate destination host.
        :param acc_specs: accelerator request mapping keys of the form
            ``'<type>:<sub_type>:<unit>'`` (exactly three ':'-separated
            parts) to a positive integer amount, given as an int or a
            numeric string, e.g. ``{'<type>:<sub_type>:<unit>': '2000'}``.
        :returns: the result of ``self.db.accelerator_filter``.
        :raises AllocateAcceleratorInvalidRequest: if a key does not have
            exactly three parts, or its amount is not a positive integer.
        """
        LOG.info("entry accelerator_filter")
        req_list = []
        for key, raw_capability in acc_specs.items():
            items = key.split(':')
            if len(items) != 3:
                raise exception.AllocateAcceleratorInvalidRequest(
                    reason="invalid accelerator type: %s:%s"
                    % (key, raw_capability))
            try:
                capability = int(raw_capability)
            except (TypeError, ValueError):
                raise exception.AllocateAcceleratorInvalidRequest(
                    reason="invalid capability: %s:%s"
                    % (key, raw_capability))
            if capability <= 0:
                raise exception.AllocateAcceleratorInvalidRequest(
                    reason="invalid capability: %s:%s"
                    % (key, raw_capability))
            req_list.append({"type": items[0],
                             'sub_type': items[1],
                             "unit": items[2],
                             "capability": capability})
        # NOTE(review): the DB call takes None ahead of the context; kept
        # as-is since the db API signature is not visible here.
        result = self.db.accelerator_filter(None, context, host_name, req_list)
        # Lazy formatting: a non-integer result no longer raises here.
        LOG.info("exit accelerator_filter:%d", result)
        return result

class AgentTaskManager(base.Base):
    """RPC endpoint for agent-side tasks, served by rock-conductor.

    Requests addressed to the 'agent_task' namespace (version 1.0) land on
    this object.  It holds an agent RPC client so its task methods can
    reach rock-agent services.
    """

    # Route RPC calls for the 'agent_task' namespace, API version 1.0, here.
    target = messaging.Target(version='1.0', namespace='agent_task')

    def __init__(self):
        super(AgentTaskManager, self).__init__()
        # RPC client used to talk to rock-agent services.
        rpc_client = agent_rpcapi.AgentAPI()
        self.agent_rpcapi = rpc_client
        #self.scheduler_client = scheduler_client.SchedulerClient()
