#!/usr/bin/env python

# Copyright (c) 2010 Citrix Systems, Inc.
# Copyright 2010 United States Government as represented by the
# Administrator of the National Aeronautics and Space Administration.
# All Rights Reserved.
#
#    Licensed under the Apache License, Version 2.0 (the "License"); you may
#    not use this file except in compliance with the License. You may obtain
#    a copy of the License at
#
#         http://www.apache.org/licenses/LICENSE-2.0
#
#    Unless required by applicable law or agreed to in writing, software
#    distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
#    WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
#    License for the specific language governing permissions and limitations
#    under the License.

#
# XenAPI plugin for fetching images from nova-objectstore.
#

import base64
import errno
import hmac
import os
import os.path
import sha
import time
import urlparse

import XenAPIPlugin

from pluginlib_nova import *
# Set up logging for this plugin under the name 'objectstore'.
# configure_logging is not defined in this file; it is expected to come
# from the pluginlib_nova star import above.
configure_logging('objectstore')


# Directory below which get_kernel() stores the files it downloads.
KERNEL_DIR = '/boot/guest'

# Number of bytes requested per read while copying an HTTP response body
# (2 MiB).
DOWNLOAD_CHUNK_SIZE = 2 * 1024 * 1024
# Bytes per sector, used to turn an image size into a sector count.
SECTOR_SIZE = 512
# Sectors left in front of the primary partition when a partition table is
# written by write_partition().
MBR_SIZE_SECTORS = 63
# The same gap expressed in bytes: added to the VDI size by get_vdi() and
# used by get_vdi_() as the offset at which image data is written.
MBR_SIZE_BYTES = MBR_SIZE_SECTORS * SECTOR_SIZE


def is_vdi_pv(session, args):
    """Report whether the VDI named in args contains a PV (Xen) kernel.

    Plugin entry point.  args must contain 'vdi-ref'.  The VDI is passed to
    with_vdi_in_dom0(), and the device name that helper supplies is checked
    by _is_vdi_pv() as '/dev/<dev>'.

    Returns the string 'true' or 'false' (not a bool).
    """
    logging.debug("Checking whether VDI has PV kernel")
    vdi = exists(args, 'vdi-ref')
    # NOTE(review): the meaning of the third argument (False) is defined by
    # with_vdi_in_dom0 in pluginlib_nova and is not visible from this file.
    pv = with_vdi_in_dom0(session, vdi, False,
                          lambda dev: _is_vdi_pv('/dev/%s' % dev))
    if pv:
        return 'true'
    return 'false'


def _is_vdi_pv(dest):
    """Run pygrub against dest and report whether it finds a Xen kernel.

    dest is the path given to ``pygrub -qn``.  Each line of pygrub's output
    is searched for a path following ``kernel:``; the result is True when
    at least one such match contains the substring 'xen'.

    Returns a bool.
    """
    # This file has no top-level "import re"; importing it here means the
    # function does not depend on the pluginlib_nova star import leaking it.
    import re

    logging.debug("Running pygrub against %s", dest)
    output = os.popen('pygrub -qn %s' % dest)
    pv = False
    try:
        for line in output.readlines():
            # Try to find the kernel string.
            m = re.search('(?<=kernel:)/.*(?:>)', line)
            if m and 'xen' in m.group(0):
                pv = True
    finally:
        # Close the pipe even if parsing fails, so the child process is
        # waited for and the file descriptor is released.
        output.close()
    logging.debug("PV:%d", pv)
    return pv


def get_vdi(session, args):
    """Download the image at args['src_url'] into a newly created VDI.

    Plugin entry point.  args must contain 'src_url', 'username' and
    'password'; the optional booleans 'raw' and 'add_partition' both
    default to 'false'.

    The image size is taken from a HEAD request, a VDI of that size (plus
    MBR_SIZE_BYTES when add_partition is set) is created on the SR returned
    by find_sr(), and get_vdi_() fills the device that with_vdi_in_dom0()
    provides.

    Returns the UUID of the new VDI.
    Raises Exception when no SR is found or the size cannot be determined.

    NOTE(review): the MBR space is reserved whenever add_partition is set,
    even for raw images, although get_vdi_() writes no partition table for
    raw images -- confirm this is intended.
    """
    src_url = exists(args, 'src_url')
    username = exists(args, 'username')
    password = exists(args, 'password')
    raw_image = validate_bool(args, 'raw', 'false')
    add_partition = validate_bool(args, 'add_partition', 'false')
    url_parts = urlparse.urlparse(src_url)
    proto, netloc, url_path = url_parts[0], url_parts[1], url_parts[2]

    sr = find_sr(session)
    if sr is None:
        raise Exception('Cannot find SR to write VDI to')

    virtual_size = get_content_length(proto, netloc, url_path,
                                      username, password)
    if virtual_size < 0:
        raise Exception('Cannot get VDI size')

    if add_partition:
        # Make room for MBR.
        vdi_size = virtual_size + MBR_SIZE_BYTES
    else:
        vdi_size = virtual_size

    vdi = create_vdi(session, sr, src_url, vdi_size, False)

    def fill_device(dev):
        # Invoked by with_vdi_in_dom0 with the device name for the VDI.
        return get_vdi_(proto, netloc, url_path, username, password,
                        add_partition, raw_image, virtual_size,
                        '/dev/%s' % dev)

    with_vdi_in_dom0(session, vdi, False, fill_device)
    return session.xenapi.VDI.get_uuid(vdi)


def get_vdi_(proto, netloc, url_path, username, password,
             add_partition, raw_image, virtual_size, dest):
    """Write the remote image to dest, optionally behind a partition table.

    When add_partition is set and the image is not raw, a partition table
    is written to dest first and the image data is placed MBR_SIZE_BYTES
    into the device.  Otherwise the data is written from offset 0.
    """
    # The VDI should not be partitioned for raw images.
    if add_partition and not raw_image:
        write_partition(virtual_size, dest)
        offset = MBR_SIZE_BYTES
    else:
        offset = 0

    get(proto, netloc, url_path, username, password, dest, offset)


def write_partition(virtual_size, dest):
    """Write an msdos label and one primary partition to dest using parted.

    The partition starts at sector MBR_SIZE_SECTORS and spans
    virtual_size // SECTOR_SIZE sectors.

    virtual_size -- size of the image in bytes.
    dest         -- path of the device to partition.

    Raises Exception if either parted invocation exits with a non-zero
    status.
    """
    primary_first = MBR_SIZE_SECTORS
    # Explicit floor division: sector numbers must be whole numbers, and
    # this does not depend on how "/" treats integer operands.
    primary_last = MBR_SIZE_SECTORS + (virtual_size // SECTOR_SIZE) - 1

    logging.debug('Writing partition table %d %d to %s...',
                  primary_first, primary_last, dest)

    result = os.system('parted --script %s mklabel msdos' % dest)
    if result != 0:
        raise Exception('Failed to mklabel')
    result = os.system('parted --script %s mkpart primary %ds %ds' %
                       (dest, primary_first, primary_last))
    if result != 0:
        raise Exception('Failed to mkpart')

    logging.debug('Writing partition table %s done.', dest)


def find_sr(session):
    """Return the local-storage SR attached to this host, or None.

    An SR qualifies when its other_config has 'i18n-key' equal to
    'local-storage' and one of its PBDs has this host as its 'host'.
    The first qualifying SR in SR.get_all() order is returned.
    """
    this_host = get_this_host(session)
    for sr in session.xenapi.SR.get_all():
        record = session.xenapi.SR.get_record(sr)
        other_config = record['other_config']
        if 'i18n-key' not in other_config:
            continue
        if other_config['i18n-key'] != 'local-storage':
            continue
        for pbd in record['PBDs']:
            if session.xenapi.PBD.get_record(pbd)['host'] == this_host:
                return sr
    return None


def get_kernel(session, args):
    """Download the file at args['src_url'] to a path below KERNEL_DIR.

    Plugin entry point.  args must contain 'src_url', 'username' and
    'password'.  The URL path, minus its first character, is used as the
    relative path below KERNEL_DIR.  Missing parent directories are
    created and any existing file at the destination is removed before
    the download starts.

    Returns the destination path.
    Raises Exception if the destination would fall outside KERNEL_DIR or
    if the parent directory cannot be created.
    """
    import sys

    src_url = exists(args, 'src_url')
    username = exists(args, 'username')
    password = exists(args, 'password')

    (proto, netloc, url_path, _, _, _) = urlparse.urlparse(src_url)

    # Collapse '.' and '..' segments before validating.  os.path.join()
    # keeps them verbatim, so a prefix test on the un-normalised path would
    # accept e.g. KERNEL_DIR + '/../../etc/passwd'.
    kernel_root = os.path.normpath(KERNEL_DIR)
    dest = os.path.normpath(os.path.join(kernel_root, url_path[1:]))

    # Paranoid check against people using ../ to do rude things.  Comparing
    # against the root plus a separator also rejects sibling directories
    # that merely share the same leading characters.
    if not dest.startswith(kernel_root + os.sep):
        raise Exception('Illegal destination %s %s' % (url_path, dest))

    dirname = os.path.dirname(dest)
    try:
        os.makedirs(dirname)
    except OSError:
        # sys.exc_info() fetches the active exception without relying on
        # version-specific "except ... , e" binding syntax.
        e = sys.exc_info()[1]
        if e.errno != errno.EEXIST:
            raise
        if not os.path.isdir(dirname):
            raise Exception('Cannot make directory %s' % dirname)

    try:
        os.remove(dest)
    except OSError:
        # Best effort: the file usually does not exist yet.  Only
        # filesystem errors are ignored, not every possible exception.
        pass

    get(proto, netloc, url_path, username, password, dest, 0)

    return dest


def get_content_length(proto, netloc, url_path, username, password):
    """Return the Content-Length reported by a signed HEAD of url_path.

    Builds the request headers with make_headers() and runs
    get_content_length_() on the connection that with_http_connection()
    supplies, returning whatever that helper returns.
    """
    head_headers = make_headers('HEAD', url_path, username, password)

    def issue_head(conn):
        return get_content_length_(url_path, head_headers, conn)

    return with_http_connection(proto, netloc, issue_head)


def get_content_length_(url_path, headers, conn):
    """Send a HEAD request on conn and return the Content-Length header.

    Returns the header converted with long(), or -1 when the response has
    no Content-Length header.
    Raises Exception('<status> <reason>') for any status other than 200.
    """
    conn.request('HEAD', url_path, None, headers)
    response = conn.getresponse()
    if response.status == 200:
        return long(response.getheader('Content-Length', -1))

    raise Exception('%d %s' % (response.status, response.reason))


def get(proto, netloc, url_path, username, password, dest, offset):
    """Fetch url_path with a signed GET and write the body to dest.

    offset is passed through to download() together with the headers
    produced by make_headers().
    """
    download(proto, netloc, url_path,
             make_headers('GET', url_path, username, password),
             dest, offset)


def make_headers(verb, url_path, username, password):
    """Build the Date and Authorization headers for a request.

    The Authorization value is 'AWS <username>:<signature>', where the
    signature comes from s3_authorization().  The signature is computed
    while the header dict contains only 'Date'.

    Returns a dict with the keys 'Date' and 'Authorization'.
    """
    now = time.gmtime()
    headers = {'Date': time.strftime("%a, %d %b %Y %H:%M:%S GMT", now)}
    # Sign before Authorization is added, so only Date is covered.
    signature = s3_authorization(verb, url_path, password, headers)
    headers['Authorization'] = 'AWS %s:%s' % (username, signature)
    return headers


def s3_authorization(verb, path, password, headers):
    """Return the base64-encoded HMAC-SHA signature for a request.

    password is the HMAC key; the signed message is the string returned by
    plaintext(verb, path, headers).  Surrounding whitespace left by the
    base64 encoder is stripped from the result.
    """
    mac = hmac.new(password, digestmod=sha)
    message = plaintext(verb, path, headers)
    mac.update(message)
    encoded = base64.encodestring(mac.digest())
    return encoded.strip()


def plaintext(verb, path, headers):
    """Return the string that s3_authorization() signs.

    Layout, one item per line: the verb, two empty lines, every value of
    headers (in the dict's iteration order), then the path.
    """
    header_values = []
    for name in headers:
        header_values.append(headers[name])
    header_block = "\n".join(header_values)
    return '%s\n\n\n%s\n%s' % (verb, header_block, path)


def download(proto, netloc, url_path, headers, dest, offset):
    """Open a connection to netloc and copy url_path into dest.

    The transfer itself is done by download_(), which is run on the
    connection that with_http_connection() supplies.
    """
    def fetch(conn):
        return download_(url_path, dest, offset, headers, conn)

    with_http_connection(proto, netloc, fetch)


def download_(url_path, dest, offset, headers, conn):
    """Issue a GET on conn and write the response body to dest.

    url_path -- path to request.
    dest     -- path of the file or device to write to.
    offset   -- position in dest at which the body starts.
    headers  -- request headers.
    conn     -- connection object providing request() and getresponse().

    Raises Exception('<status> <reason>') for any status other than 200.
    """
    conn.request('GET', url_path, None, headers)
    response = conn.getresponse()
    if response.status != 200:
        raise Exception('%d %s' % (response.status, response.reason))

    # getheader() returns the header value as a string.  Convert it to a
    # number here so that the byte-count comparison in download_all() is
    # numeric; a missing or malformed header means "length unknown" (-1).
    try:
        length = int(response.getheader('Content-Length', -1))
    except (TypeError, ValueError):
        length = -1

    # NOTE(review): dest is opened in append mode and download_all() then
    # seeks to offset.  On regular files append mode sends every write to
    # the end of the file regardless of the seek -- confirm this is only
    # used with offset 0 or with device nodes.
    with_file(
        dest, 'a',
        lambda dest_file: download_all(response, length, dest_file, offset))


def download_all(response, length, dest_file, offset, chunk_size=None):
    """Copy the body of response into dest_file, starting at offset.

    response   -- object with a read(n) method returning the next chunk,
                  or an empty value at end of data.
    length     -- expected number of bytes, as a number or numeric string;
                  -1, None or an unparsable value means "unknown".
    dest_file  -- writable, seekable file object.
    offset     -- position in dest_file at which writing starts.
    chunk_size -- bytes requested per read; defaults to
                  DOWNLOAD_CHUNK_SIZE.

    Copying stops at the first empty read, or as soon as at least length
    bytes have been written when length is known.
    """
    if chunk_size is None:
        chunk_size = DOWNLOAD_CHUNK_SIZE

    # The Content-Length header arrives as a string.  Comparing the running
    # byte count against a string never gives the intended result, so
    # normalise to a number first.
    try:
        length = int(length)
    except (TypeError, ValueError):
        length = -1

    dest_file.seek(offset)
    copied = 0
    while True:
        buf = response.read(chunk_size)
        if not buf:
            return
        dest_file.write(buf)
        copied += len(buf)
        if length >= 0 and copied >= length:
            return


if __name__ == '__main__':
    # Map each plugin call name to the function that serves it.
    handlers = {
        'get_vdi': get_vdi,
        'get_kernel': get_kernel,
        'is_vdi_pv': is_vdi_pv,
    }
    XenAPIPlugin.dispatch(handlers)
