#!/usr/bin/python
# Copyright (c) 2013 The Chromium OS Authors. All rights reserved.
# Use of this source code is governed by a BSD-style license that can be
# found in the LICENSE file.

"""Test suite for timeout_util.py"""

import datetime
import os
import sys
import time

# Put the grandparent of this file's directory at the front of sys.path so the
# `chromite.lib` imports below resolve when this file is run directly
# (presumably that directory is the one containing the `chromite` package).
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.dirname(
    os.path.abspath(__file__)))))

from chromite.lib import cros_test_lib
from chromite.lib import timeout_util


# pylint: disable=W0212,R0904


class TestTimeouts(cros_test_lib.TestCase):
  """Tests for timeout_util.Timeout."""

  def testTimeout(self):
    """Tests that we can nest Timeout correctly."""
    # Sanity check that time.sleep has not been replaced by a mock: these tests
    # depend on real sleeps being interrupted by the timeout.
    self.assertNotIn('mock', str(time.sleep).lower())
    with timeout_util.Timeout(30):
      with timeout_util.Timeout(20):
        with timeout_util.Timeout(1):
          # The innermost (shortest) timeout must fire during this sleep.
          self.assertRaises(timeout_util.TimeoutError, time.sleep, 10)

        # Should not raise a timeout exception: the enclosing 20 second
        # timeout still has plenty of time left for a 1 second sleep.
        time.sleep(1)

  def testTimeoutNested(self):
    """Tests that we still re-raise an alarm if both are reached."""
    with timeout_util.Timeout(1):
      try:
        with timeout_util.Timeout(2):
          # Sleep well past the outer 1 second timeout so the alarm is
          # guaranteed to fire during the sleep. Sleeping exactly 1 second
          # races with the alarm: the sleep can return before the signal is
          # delivered, making assertRaises fail spuriously.
          self.assertRaises(timeout_util.TimeoutError, time.sleep, 10)

      # The outer timeout expired while the inner one was active; this test
      # expects that expiry to surface as a TimeoutError once the inner
      # context exits.
      except timeout_util.TimeoutError:
        pass
      else:
        self.fail('Should have thrown an exception')


class TestWaitFors(cros_test_lib.TestCase):
  """Tests for assorted timeout_utils WaitForX methods."""

  def setUp(self):
    # Bookkeeping updated by the functor returned from GetFunc().
    self.values_ix = 0  # Index of the next value to hand out == calls so far.
    self.timestart = None  # UTC time of the first call, or None.
    self.timestop = None  # UTC time of the most recent call, or None.

  def GetFunc(self, return_values):
    """Return a functor that returns given values in sequence with each call.

    Calling this resets the try count and the recorded start/stop times.

    Args:
      return_values: Indexable sequence of values to hand out, one per call.

    Returns:
      A zero-argument callable. Each call returns the next item of
      |return_values| (raising IndexError once they run out), and records the
      UTC time of the first call and of the latest call.
    """
    self.values_ix = 0
    self.timestart = None
    self.timestop = None

    def _Func():
      if self.timestart is None:
        self.timestart = datetime.datetime.utcnow()

      val = return_values[self.values_ix]
      self.values_ix += 1

      self.timestop = datetime.datetime.utcnow()
      return val

    return _Func

  def GetTryCount(self):
    """Get number of times func was tried."""
    return self.values_ix

  def GetTrySeconds(self):
    """Get number of seconds that span all func tries, rounded to nearest."""
    delta = self.timestop - self.timestart
    # timedelta.seconds is already an integer that drops the sub-second part,
    # so adding 0.5 to it never rounded anything (3.999s reported as 3).
    # Round the full span, including microseconds, instead.
    return int(delta.total_seconds() + 0.5)

  def _TestWaitForSuccess(self, maxval, timeout, **kwargs):
    """Run through a test for WaitForSuccess.

    Args:
      maxval: The retry check reports True for every value below this.
      timeout: Timeout passed straight through to WaitForSuccess.
      **kwargs: Extra keyword arguments for WaitForSuccess (e.g. period).

    Returns:
      Whatever timeout_util.WaitForSuccess returns.
    """
    func = self.GetFunc(range(20))

    def _RetryCheck(val):
      return val < maxval

    return timeout_util.WaitForSuccess(_RetryCheck, func, timeout, **kwargs)

  def _TestWaitForReturnValue(self, values, timeout, **kwargs):
    """Run through a test for WaitForReturnValue.

    Args:
      values: Acceptable return values, passed to WaitForReturnValue.
      timeout: Timeout passed straight through to WaitForReturnValue.
      **kwargs: Extra keyword arguments for WaitForReturnValue (e.g. period).

    Returns:
      Whatever timeout_util.WaitForReturnValue returns.
    """
    func = self.GetFunc(range(20))
    return timeout_util.WaitForReturnValue(values, func, timeout, **kwargs)

  def testWaitForSuccess1(self):
    """Test success after a few tries."""
    self.assertEqual(4, self._TestWaitForSuccess(4, 10, period=1))
    self.assertEqual(5, self.GetTryCount())
    self.assertEqual(4, self.GetTrySeconds())

  def testWaitForSuccess2(self):
    """Test timeout after a couple tries."""
    self.assertRaises(timeout_util.TimeoutError, self._TestWaitForSuccess,
                      4, 3, period=1)
    self.assertEqual(3, self.GetTryCount())
    self.assertEqual(2, self.GetTrySeconds())

  def testWaitForSuccess3(self):
    """Test success on first try."""
    self.assertEqual(0, self._TestWaitForSuccess(0, 10, period=1))
    self.assertEqual(1, self.GetTryCount())
    self.assertEqual(0, self.GetTrySeconds())

  def testWaitForSuccess4(self):
    """Test success after a few tries with longer period."""
    self.assertEqual(3, self._TestWaitForSuccess(3, 10, period=2))
    self.assertEqual(4, self.GetTryCount())
    self.assertEqual(6, self.GetTrySeconds())

  def testWaitForReturnValue1(self):
    """Test value found after a few tries."""
    self.assertEqual(4, self._TestWaitForReturnValue((4, 5), 10, period=1))
    self.assertEqual(5, self.GetTryCount())
    self.assertEqual(4, self.GetTrySeconds())

  def testWaitForReturnValue2(self):
    """Test value found on first try."""
    self.assertEqual(0, self._TestWaitForReturnValue((0, 1), 10, period=1))
    self.assertEqual(1, self.GetTryCount())
    self.assertEqual(0, self.GetTrySeconds())


if __name__ == '__main__':
  # When executed as a script, hand off to cros_test_lib's entry point
  # (presumably a unittest.main-style runner for the tests in this module).
  cros_test_lib.main()
