# Copyright (C) 2016 Linaro Limited
#
# Author: Tyler Baker <tyler.baker@linaro.org>
#
# This file is part of LAVA Dispatcher.
#
# LAVA Dispatcher is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation; either version 2 of the License, or
# (at your option) any later version.
#
# LAVA Dispatcher is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, see <http://www.gnu.org/licenses>.

from lava_common.exceptions import (
    ConfigurationError,
    InfrastructureError,
)
from lava_dispatcher.action import (
    Action,
    Pipeline,
)
from lava_dispatcher.logical import Boot, RetryAction
from lava_dispatcher.actions.boot import BootAction
from lava_dispatcher.utils.udev import WaitDFUDeviceAction
from lava_dispatcher.connections.serial import ConnectDevice
from lava_dispatcher.power import ResetDevice
from lava_dispatcher.utils.shell import which
from lava_dispatcher.utils.strings import substitute


class DFU(Boot):
    """
    Boot strategy that inserts the DFU flashing pipeline (BootDFU) into the
    parent pipeline.
    """

    compatibility = 4  # FIXME: change this to 5 and update test cases

    def __init__(self, parent, parameters):
        super().__init__(parent)
        self.action = BootDFU()
        self.action.section = self.action_type
        self.action.job = self.job
        parent.add_action(self.action, parameters)

    @classmethod
    def accepts(cls, device, parameters):
        """
        Decide whether this strategy handles the given boot action.

        :param device: device configuration dictionary
        :param parameters: boot action parameters from the job definition
        :return: tuple of (accepted: bool, reason: str)
        """
        if 'dfu' not in device['actions']['boot']['methods']:
            return False, '"dfu" was not in the device configuration boot methods'
        if 'method' not in parameters:
            return False, '"method" was not in the parameters'
        if parameters['method'] != 'dfu':
            return False, '"method" was not "dfu"'
        if 'board_id' not in device:
            return False, '"board_id" is not in the device configuration'
        return True, 'accepted'


class BootDFU(BootAction):
    """
    Top-level DFU boot action.

    It does no work itself: its internal pipeline holds a single retrying
    child action that performs the actual flashing sequence.
    """

    name = 'boot-dfu-image'
    description = "boot dfu image with retry"
    summary = "boot dfu image with retry"

    def populate(self, parameters):
        # Create the child pipeline first, then place the retry wrapper in it.
        self.internal_pipeline = Pipeline(
            parent=self,
            job=self.job,
            parameters=parameters,
        )
        retry_action = BootDFURetry()
        self.internal_pipeline.add_action(retry_action)


class BootDFURetry(RetryAction):
    """
    Retrying wrapper around the DFU flashing sequence.

    The child actions run in this order: connect, reset the device, wait for
    the DFU device to appear, then flash the images.
    """

    name = 'boot-dfu-retry'
    description = "boot dfu image using the command line interface"
    summary = "boot dfu image"

    def populate(self, parameters):
        self.internal_pipeline = Pipeline(parent=self, job=self.job, parameters=parameters)
        # Each class is instantiated and added in sequence; order matters.
        sequence = (
            ConnectDevice,
            ResetDevice,
            WaitDFUDeviceAction,
            FlashDFUAction,
        )
        for step_cls in sequence:
            self.internal_pipeline.add_action(step_cls())


class FlashDFUAction(Action):
    """
    Flash downloaded images onto the device with a DFU command line tool.

    validate() builds one command per image recorded under the
    'download-action' namespace. run() executes those commands and appends
    '--reset' to the last one unless the device configuration sets
    'reset_works' to false.
    """

    name = "flash-dfu"
    description = "use dfu to flash the images"
    summary = "use dfu to flash the images"

    def __init__(self):
        super().__init__()
        # Tool binary + device options + --serial/--device selectors.
        self.base_command = []
        # One complete command (list of tokens) per image to flash.
        self.exec_list = []
        # Placeholder values; overwritten from the device configuration in validate().
        self.board_id = '0000000000'
        self.usb_vendor_id = '0000'
        self.usb_product_id = '0000'

    def validate(self):
        """
        Build the DFU commands from the device configuration and the
        download-action namespace data.

        Problems are reported by assigning to ``self.errors``, which is
        presumably accumulated by the Action base class (confirm).

        :raises ConfigurationError: if the device configuration has an
            unexpected structure that raises AttributeError.
        """
        super().validate()
        try:
            boot = self.job.device['actions']['boot']['methods']['dfu']
            dfu_binary = which(boot['parameters']['command'])
            self.base_command = [dfu_binary]
            self.base_command.extend(boot['parameters'].get('options', []))
            # Placeholder values mean the device dictionary was not filled in.
            if self.job.device['board_id'] == '0000000000':
                self.errors = "[FLASH_DFU] board_id unset"
            if self.job.device['usb_vendor_id'] == '0000':
                self.errors = '[FLASH_DFU] usb_vendor_id unset'
            if self.job.device['usb_product_id'] == '0000':
                self.errors = '[FLASH_DFU] usb_product_id unset'
            self.usb_vendor_id = self.job.device['usb_vendor_id']
            self.usb_product_id = self.job.device['usb_product_id']
            self.board_id = self.job.device['board_id']
            self.base_command.extend(['--serial', self.board_id])
            self.base_command.extend(['--device', '%s:%s' % (self.usb_vendor_id, self.usb_product_id)])
        except AttributeError as exc:
            raise ConfigurationError(exc) from exc
        except (KeyError, TypeError):
            self.errors = "Invalid parameters for %s" % self.name
        # Substitutions accumulate across images, so a later image_arg may
        # also reference an earlier download label, e.g. "{binary}".
        substitutions = {}
        for action in self.get_namespace_keys('download-action'):
            dfu_full_command = []
            image_arg = self.get_namespace_data(action='download-action', label=action, key='image_arg')
            action_arg = self.get_namespace_data(action='download-action', label=action, key='file')
            if not image_arg or not action_arg:
                self.errors = "Missing image_arg for %s. " % action
                continue
            if not isinstance(image_arg, str):
                self.errors = "image_arg is not a string (try quoting it)"
                continue
            substitutions["{%s}" % action] = action_arg
            dfu_full_command.extend(self.base_command)
            dfu_full_command.extend(substitute([image_arg], substitutions))
            self.exec_list.append(dfu_full_command)
        if not self.exec_list:
            self.errors = "No DFU command to execute"

    def run(self, connection, max_end_time):
        """
        Execute each prepared DFU command and check its output for success.

        :raises InfrastructureError: if a command produces no output or its
            output lacks the dfu success marker.
        """
        # Use the connection stored in the shared namespace, not the argument.
        connection = self.get_namespace_data(
            action='shared', label='shared', key='connection', deepcopy=False)
        connection = super().run(connection, max_end_time)
        last_index = len(self.exec_list) - 1
        for index, dfu_command in enumerate(self.exec_list):
            # Work on a copy: this action runs inside a RetryAction, and
            # extending the stored list would add another '--reset' on every
            # retry.
            command = list(dfu_command)
            if index == last_index:
                if self.job.device['actions']['boot']['methods']['dfu'].get('reset_works', True):
                    command.append('--reset')
            # image_arg entries are whole strings such as
            # "--alt 0 --download <file>", so the command is re-tokenised on
            # whitespace. split() avoids empty arguments from doubled spaces.
            dfu = ' '.join(command)
            output = self.run_command(dfu.split())
            if not output or "No error condition is present\nDone!\n" not in output:
                raise InfrastructureError("command failed: %s" % dfu)
        self.set_namespace_data(action='shared', label='shared', key='connection', value=connection)
        return connection
