aurora-commits mailing list archives

From: ma...@apache.org
Subject: git commit: Adding grouping for the sla commands.
Date: Wed, 28 May 2014 01:01:21 GMT
Repository: incubator-aurora
Updated Branches:
  refs/heads/master 34173d1cd -> 2753763e7


Adding grouping for the sla commands.

Bugs closed: AURORA-441

Reviewed at https://reviews.apache.org/r/21741/


Project: http://git-wip-us.apache.org/repos/asf/incubator-aurora/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-aurora/commit/2753763e
Tree: http://git-wip-us.apache.org/repos/asf/incubator-aurora/tree/2753763e
Diff: http://git-wip-us.apache.org/repos/asf/incubator-aurora/diff/2753763e

Branch: refs/heads/master
Commit: 2753763e78716cd4b3bdbb5dcb573c9989cfc6a4
Parents: 34173d1
Author: Maxim Khutornenko <maxim@apache.org>
Authored: Tue May 27 18:00:51 2014 -0700
Committer: Maxim Khutornenko <maxim@apache.org>
Committed: Tue May 27 18:00:51 2014 -0700

----------------------------------------------------------------------
 .../apache/aurora/admin/host_maintenance.py     | 28 ++----
 src/main/python/apache/aurora/client/api/sla.py | 98 +++++++++++++-------
 src/main/python/apache/aurora/client/base.py    | 39 ++++++++
 .../apache/aurora/client/commands/admin.py      | 94 +++++++++++--------
 .../aurora/client/commands/maintenance.py       | 21 ++---
 .../aurora/admin/test_host_maintenance.py       |  6 +-
 .../python/apache/aurora/client/api/test_sla.py | 92 +++++++++++++++---
 .../aurora/client/commands/test_admin_sla.py    | 93 +++++++++++++------
 8 files changed, 325 insertions(+), 146 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-aurora/blob/2753763e/src/main/python/apache/aurora/admin/host_maintenance.py
----------------------------------------------------------------------
diff --git a/src/main/python/apache/aurora/admin/host_maintenance.py b/src/main/python/apache/aurora/admin/host_maintenance.py
index 15c2b57..ca26de1 100644
--- a/src/main/python/apache/aurora/admin/host_maintenance.py
+++ b/src/main/python/apache/aurora/admin/host_maintenance.py
@@ -19,40 +19,26 @@ from twitter.common import log
 from twitter.common.quantity import Amount, Time
 
 from apache.aurora.client.api import AuroraClientAPI
-from apache.aurora.client.base import check_and_log_response
+from apache.aurora.client.base import (
+    check_and_log_response,
+    DEFAULT_GROUPING,
+    group_hosts,
+    GROUPING_FUNCTIONS
+)
 
 from gen.apache.aurora.api.ttypes import Hosts, MaintenanceMode
 
 
-def group_by_host(hostname):
-  return hostname
-
-
 class HostMaintenance(object):
   """Submit requests to the scheduler to put hosts into and out of maintenance
   mode so they can be operated upon without causing LOST tasks.
   """
 
-  DEFAULT_GROUPING = 'by_host'
-  GROUPING_FUNCTIONS = {
-    'by_host': group_by_host,
-  }
   START_MAINTENANCE_DELAY = Amount(30, Time.SECONDS)
 
   @classmethod
-  def group_hosts(cls, hostnames, grouping_function=DEFAULT_GROUPING):
-    try:
-      grouping_function = cls.GROUPING_FUNCTIONS[grouping_function]
-    except KeyError:
-      raise ValueError('Unknown grouping function %s!' % grouping_function)
-    groups = defaultdict(set)
-    for hostname in hostnames:
-      groups[grouping_function(hostname)].add(hostname)
-    return groups
-
-  @classmethod
   def iter_batches(cls, hostnames, grouping_function=DEFAULT_GROUPING):
-    groups = cls.group_hosts(hostnames, grouping_function)
+    groups = group_hosts(hostnames, grouping_function)
     groups = sorted(groups.items(), key=lambda v: v[0])
     for group in groups:
       yield Hosts(group[1])

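For context only (not part of the commit): with grouping now delegated to the shared group_hosts helper in client.base, HostMaintenance.iter_batches yields one Hosts batch per group. A minimal sketch of the resulting batching under the default 'by_host' grouping, with illustrative hostnames:

from apache.aurora.admin.host_maintenance import HostMaintenance

# Under the default 'by_host' grouping each batch holds a single host; a
# custom grouping registered via add_grouping (e.g. 'by_rack') would instead
# yield one Hosts batch per group.
for batch in HostMaintenance.iter_batches(['h1', 'h2', 'h3']):
  print(batch)  # three batches, wrapping {'h1'}, {'h2'}, {'h3'} in turn
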
http://git-wip-us.apache.org/repos/asf/incubator-aurora/blob/2753763e/src/main/python/apache/aurora/client/api/sla.py
----------------------------------------------------------------------
diff --git a/src/main/python/apache/aurora/client/api/sla.py b/src/main/python/apache/aurora/client/api/sla.py
index ce48cb0..d15491a 100644
--- a/src/main/python/apache/aurora/client/api/sla.py
+++ b/src/main/python/apache/aurora/client/api/sla.py
@@ -17,7 +17,7 @@ import time
 from collections import defaultdict, namedtuple
 from copy import deepcopy
 
-from apache.aurora.client.base import log_response
+from apache.aurora.client.base import DEFAULT_GROUPING, group_hosts, log_response
 from apache.aurora.common.aurora_job_key import AuroraJobKey
 
 from gen.apache.aurora.api.constants import LIVE_STATES
@@ -151,10 +151,15 @@ class DomainUpTimeSlaVector(object):
     self._cluster = cluster
     self._tasks = tasks
     self._now = time.time()
-    self._tasks_by_job, self._jobs_by_host = self._init_mappings(min_instance_count)
+    self._tasks_by_job, self._jobs_by_host, self._hosts_by_job = self._init_mappings(
+        min_instance_count)
     self._host_filter = hosts
 
-  def get_safe_hosts(self, percentage, duration, job_limits=None):
+  def get_safe_hosts(self,
+      percentage,
+      duration,
+      job_limits=None,
+      grouping_function=DEFAULT_GROUPING):
     """Returns hosts safe to restart with respect to their job SLA.
        Every host is analyzed separately without considering other job hosts.
 
@@ -163,32 +168,34 @@ class DomainUpTimeSlaVector(object):
       duration -- default task uptime duration in seconds. Used if job_limits mapping is not found.
        job_limits -- optional SLA override map. Key: job key. Value JobUpTimeLimit. If specified,
                      replaces default percentage/duration within the job context.
+       grouping_function -- grouping function to use to group hosts.
     """
-    safe_hosts = defaultdict(list)
-    for host, job_keys in self._jobs_by_host.items():
-      if self._host_filter and host not in self._host_filter:
-        continue
+    safe_groups = []
+    for hosts, job_keys in self._iter_groups(
+        self._jobs_by_host.keys(), grouping_function, self._host_filter):
 
-      safe_limits = []
+      safe_hosts = defaultdict(list)
       for job_key in job_keys:
+        job_hosts = hosts.intersection(self._hosts_by_job[job_key])
         job_duration = duration
         job_percentage = percentage
         if job_limits and job_key in job_limits:
           job_duration = job_limits[job_key].duration_secs
           job_percentage = job_limits[job_key].percentage
 
-        filtered_percentage, _, _ = self._simulate_host_down(job_key, host, job_duration)
-        safe_limits.append(self.JobUpTimeLimit(job_key, filtered_percentage, job_duration))
-
+        filtered_percentage, _, _ = self._simulate_hosts_down(job_key, job_hosts, job_duration)
         if filtered_percentage < job_percentage:
           break
 
+        for host in job_hosts:
+          safe_hosts[host].append(self.JobUpTimeLimit(job_key, filtered_percentage, job_duration))
+
       else:
-        safe_hosts[host] = safe_limits
+        safe_groups.append(safe_hosts)
 
-    return safe_hosts
+    return safe_groups
 
-  def probe_hosts(self, percentage, duration):
+  def probe_hosts(self, percentage, duration, grouping_function=DEFAULT_GROUPING):
     """Returns predicted job SLAs following the removal of provided hosts.
 
        For every given host creates a list of JobUpTimeDetails with predicted job SLA details
@@ -199,12 +206,15 @@ class DomainUpTimeSlaVector(object):
        Arguments:
        percentage -- task up count percentage.
        duration -- task uptime duration in seconds.
+       grouping_function -- grouping function to use to group hosts.
     """
-    probed_hosts = defaultdict(list)
-    for host in self._host_filter or []:
-      for job_key in self._jobs_by_host.get(host, []):
-        filtered_percentage, total_count, filtered_vector = self._simulate_host_down(
-            job_key, host, duration)
+    probed_groups = []
+    for hosts, job_keys in self._iter_groups(self._host_filter or [], grouping_function):
+      probed_hosts = defaultdict(list)
+      for job_key in job_keys:
+        job_hosts = hosts.intersection(self._hosts_by_job[job_key])
+        filtered_percentage, total_count, filtered_vector = self._simulate_hosts_down(
+            job_key, job_hosts, duration)
 
         # Calculate wait time to SLA in case down host violates job's SLA.
         if filtered_percentage < percentage:
@@ -214,21 +224,41 @@ class DomainUpTimeSlaVector(object):
           safe = True
           wait_to_sla = 0
 
-        probed_hosts[host].append(
-            self.JobUpTimeDetails(job_key, filtered_percentage, safe, wait_to_sla))
+        for host in job_hosts:
+          probed_hosts[host].append(
+              self.JobUpTimeDetails(job_key, filtered_percentage, safe, wait_to_sla))
+
+      if probed_hosts:
+        probed_groups.append(probed_hosts)
+
+    return probed_groups
+
+  def _iter_groups(self, hosts_to_group, grouping_function, host_filter=None):
+    groups = group_hosts(hosts_to_group, grouping_function)
+    for _, hosts in sorted(groups.items(), key=lambda v: v[0]):
+      job_keys = set()
+      for host in hosts:
+        if host_filter and host not in self._host_filter:
+          continue
+        job_keys = job_keys.union(self._jobs_by_host.get(host, set()))
+      yield hosts, job_keys
+
+  def _create_group_results(self, group, uptime_details):
+    result = defaultdict(list)
+    for host in group.keys():
+      result[host].append(uptime_details)
 
-    return probed_hosts
 
-  def _simulate_host_down(self, job_key, host, duration):
+  def _simulate_hosts_down(self, job_key, hosts, duration):
     unfiltered_tasks = self._tasks_by_job[job_key]
 
     # Get total job task count to use in SLA calculation.
     total_count = len(unfiltered_tasks)
 
-    # Get a list of job tasks that would remain after the affected host goes down
+    # Get a list of job tasks that would remain after the affected hosts go down
     # and create an SLA vector with these tasks.
     filtered_tasks = [task for task in unfiltered_tasks
-                      if task.assignedTask.slaveHost != host]
+                      if task.assignedTask.slaveHost not in hosts]
     filtered_vector = JobUpTimeSlaVector(filtered_tasks, self._now)
 
     # Calculate the SLA that would be in effect should the host go down.
@@ -237,20 +267,24 @@ class DomainUpTimeSlaVector(object):
     return filtered_percentage, total_count, filtered_vector
 
   def _init_mappings(self, count):
-    jobs = defaultdict(list)
+    tasks_by_job = defaultdict(list)
     for task in self._tasks:
       if task.assignedTask.task.production:
-        jobs[job_key_from_scheduled(task, self._cluster)].append(task)
+        tasks_by_job[job_key_from_scheduled(task, self._cluster)].append(task)
 
     # Filter jobs by the min instance count.
-    jobs = defaultdict(list, ((job, tasks) for job, tasks in jobs.items() if len(tasks) >= count))
+    tasks_by_job = defaultdict(list, ((job, tasks) for job, tasks
+        in tasks_by_job.items() if len(tasks) >= count))
 
-    hosts = defaultdict(list)
-    for job_key, tasks in jobs.items():
+    jobs_by_host = defaultdict(set)
+    hosts_by_job = defaultdict(set)
+    for job_key, tasks in tasks_by_job.items():
       for task in tasks:
-        hosts[task.assignedTask.slaveHost].append(job_key)
+        host = task.assignedTask.slaveHost
+        jobs_by_host[host].add(job_key)
+        hosts_by_job[job_key].add(host)
 
-    return jobs, hosts
+    return tasks_by_job, jobs_by_host, hosts_by_job
 
 
 class Sla(object):

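As an illustration (not part of the commit): get_safe_hosts and probe_hosts now return a list with one dict per host group instead of a single host-keyed dict, each dict mapping a hostname to its JobUpTimeLimit (or JobUpTimeDetails) records. A minimal sketch of consuming the grouped result, assuming a DomainUpTimeSlaVector instance named vector (e.g. from AuroraClientAPI.sla_get_safe_domain_vector), a previously registered 'by_rack' grouping, and illustrative thresholds of 95% over 1800 seconds:

groups = vector.get_safe_hosts(95, 1800, grouping_function='by_rack')
for group in groups:                        # one dict per host group
  for host, limits in sorted(group.items()):
    for limit in limits:                    # JobUpTimeLimit entries for this host
      print('%s\t%s\t%.2f\t%d' % (
          host, limit.job.to_path(), limit.percentage, limit.duration_secs))
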
http://git-wip-us.apache.org/repos/asf/incubator-aurora/blob/2753763e/src/main/python/apache/aurora/client/base.py
----------------------------------------------------------------------
diff --git a/src/main/python/apache/aurora/client/base.py b/src/main/python/apache/aurora/client/base.py
index d54c7cb..ef0855d 100644
--- a/src/main/python/apache/aurora/client/base.py
+++ b/src/main/python/apache/aurora/client/base.py
@@ -106,6 +106,45 @@ HOSTS_OPTION = optparse.Option(
     help='Comma separated list of hosts')
 
 
+def group_by_host(hostname):
+  return hostname
+
+DEFAULT_GROUPING = 'by_host'
+GROUPING_FUNCTIONS = {
+    'by_host': group_by_host,
+}
+
+def add_grouping(name, function):
+  GROUPING_FUNCTIONS[name] = function
+
+def remove_grouping(name):
+  GROUPING_FUNCTIONS.pop(name)
+
+def get_grouping_or_die(grouping_function):
+  try:
+    return GROUPING_FUNCTIONS[grouping_function]
+  except KeyError:
+    die('Unknown grouping function %s. Must be one of: %s'
+        % (grouping_function, GROUPING_FUNCTIONS.keys()))
+
+def group_hosts(hostnames, grouping_function=DEFAULT_GROUPING):
+  grouping_function = get_grouping_or_die(grouping_function)
+  groups = defaultdict(set)
+  for hostname in hostnames:
+    groups[grouping_function(hostname)].add(hostname)
+  return groups
+
+
+GROUPING_OPTION = optparse.Option(
+    '--grouping',
+    type='string',
+    metavar='GROUPING',
+    default=DEFAULT_GROUPING,
+    dest='grouping',
+    help='Grouping function to use to group hosts.  Options: %s.  Default: %%default' % (
+        ', '.join(GROUPING_FUNCTIONS.keys())))
+
+
 def parse_host_list(host_list):
   hosts = [hostname.strip() for hostname in host_list.split(",")]
   if not hosts:

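For illustration only (not part of the diff): the grouping registry added to client/base.py is what both the maintenance and SLA commands now share. A minimal sketch of registering a custom grouping and inspecting group_hosts output, assuming '<cluster>-<rack>-<host>' hostnames (the same convention the tests below use):

from apache.aurora.client.base import add_grouping, group_hosts, remove_grouping

def rack_grouping(hostname):
  # Illustrative: take the middle '<rack>' component of the hostname.
  return hostname.split('-')[1]

add_grouping('by_rack', rack_grouping)
try:
  groups = group_hosts(
      ['west-aaa-001.example.com', 'west-aaa-002.example.com', 'west-bbb-003.example.com'],
      grouping_function='by_rack')
  # groups == {'aaa': {'west-aaa-001.example.com', 'west-aaa-002.example.com'},
  #            'bbb': {'west-bbb-003.example.com'}}
finally:
  remove_grouping('by_rack')
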
http://git-wip-us.apache.org/repos/asf/incubator-aurora/blob/2753763e/src/main/python/apache/aurora/client/commands/admin.py
----------------------------------------------------------------------
diff --git a/src/main/python/apache/aurora/client/commands/admin.py b/src/main/python/apache/aurora/client/commands/admin.py
index b4c0483..919eea9 100644
--- a/src/main/python/apache/aurora/client/commands/admin.py
+++ b/src/main/python/apache/aurora/client/commands/admin.py
@@ -31,6 +31,8 @@ from apache.aurora.client.base import (
     check_and_log_response,
     die,
     FILENAME_OPTION,
+    get_grouping_or_die,
+    GROUPING_OPTION,
     HOSTS_OPTION,
     parse_hosts,
     parse_hosts_optional,
@@ -43,12 +45,7 @@ from apache.aurora.common.shellify import shellify
 from gen.apache.aurora.api.constants import ACTIVE_STATES, TERMINAL_STATES
 from gen.apache.aurora.api.ttypes import ResponseCode, ScheduleStatus, TaskQuery
 
-"""Command-line client for managing admin-only interactions with the aurora scheduler.
-"""
-
-
-
-
+"""Command-line client for managing admin-only interactions with the aurora scheduler."""
 
 
 MIN_SLA_INSTANCE_COUNT = optparse.Option(
@@ -333,36 +330,40 @@ def scheduler_snapshot(cluster):
 
 
 @app.command
-@app.command_option('-I', '--include_file', dest='include_filename', default=None,
-    help='Inclusion filter. An optional text file listing host names (one per line)'
-         'to include into the result set if found.')
-@app.command_option('-i', '--include_hosts', dest='include_hosts', default=None,
-    help='Inclusion filter. An optional comma-separated list of host names'
-         'to include into the result set if found.')
 @app.command_option('-X', '--exclude_file', dest='exclude_filename', default=None,
     help='Exclusion filter. An optional text file listing host names (one per line)'
          'to exclude from the result set if found.')
 @app.command_option('-x', '--exclude_hosts', dest='exclude_hosts', default=None,
     help='Exclusion filter. An optional comma-separated list of host names'
          'to exclude from the result set if found.')
+@app.command_option(GROUPING_OPTION)
+@app.command_option('-I', '--include_file', dest='include_filename', default=None,
+    help='Inclusion filter. An optional text file listing host names (one per line)'
+         'to include into the result set if found.')
+@app.command_option('-i', '--include_hosts', dest='include_hosts', default=None,
+    help='Inclusion filter. An optional comma-separated list of host names'
+         'to include into the result set if found.')
 @app.command_option('-l', '--list_jobs', dest='list_jobs', default=False, action='store_true',
     help='Lists all affected job keys with projected new SLAs if their tasks get killed'
          'in the following column format:\n'
          'HOST  JOB  PREDICTED_SLA  DURATION_SECONDS')
+@app.command_option(MIN_SLA_INSTANCE_COUNT)
 @app.command_option('-o', '--override_file', dest='override_filename', default=None,
     help='An optional text file to load job specific SLAs that will override'
          'cluster-wide command line percentage and duration values.'
          'The file can have multiple lines in the following format:'
          '"cluster/role/env/job percentage duration". Example: cl/mesos/prod/labrat 95 2h')
-@app.command_option(MIN_SLA_INSTANCE_COUNT)
 @requires.exactly('cluster', 'percentage', 'duration')
 def sla_list_safe_domain(cluster, percentage, duration):
   """usage: sla_list_safe_domain
-            [--exclude_hosts=filename]
-            [--include_hosts=filename]
+            [--exclude_file=FILENAME]
+            [--exclude_hosts=HOSTS]
+            [--grouping=GROUPING]
+            [--include_file=FILENAME]
+            [--include_hosts=HOSTS]
             [--list_jobs]
-            [--override_jobs=filename]
-            [--min_job_instance_count]
+            [--min_job_instance_count=COUNT]
+            [--override_jobs=FILENAME]
             cluster percentage duration
 
   Returns a list of relevant hosts where it would be safe to kill
@@ -376,6 +377,12 @@ def sla_list_safe_domain(cluster, percentage, duration):
   Applied to all jobs except those listed in --override_jobs file.
   Format: XdYhZmWs (each field is optional but must be in that order.)
   Examples: 5m, 1d3h45m.
+
+  NOTE: if the --grouping option is specified and set to anything other than
+        the default (by_host), the results will be processed and filtered
+        based on the grouping function on an all-or-nothing basis. In other
+        words, the group is 'safe' IFF it is safe to kill tasks on all hosts
+        in the group at the same time.
   """
   def parse_jobs_file(filename):
     result = {}
@@ -403,36 +410,41 @@ def sla_list_safe_domain(cluster, percentage, duration):
   exclude_hosts = parse_hosts_optional(options.exclude_hosts, options.exclude_filename)
   include_hosts = parse_hosts_optional(options.include_hosts, options.include_filename)
   override_jobs = parse_jobs_file(options.override_filename) if options.override_filename else {}
+  get_grouping_or_die(options.grouping)
 
   vector = AuroraClientAPI(
       CLUSTERS[cluster],
       options.verbosity).sla_get_safe_domain_vector(options.min_instance_count, include_hosts)
-  hosts = vector.get_safe_hosts(sla_percentage, sla_duration.as_(Time.SECONDS), override_jobs)
+  groups = vector.get_safe_hosts(sla_percentage, sla_duration.as_(Time.SECONDS),
+      override_jobs, options.grouping)
 
   results = []
-  for host in sorted(hosts.keys()):
-    if exclude_hosts and host in exclude_hosts:
-      continue
-
-    if options.list_jobs:
-      results.append('\n'.join(['%s\t%s\t%.2f\t%d' %
-          (host, d.job.to_path(), d.percentage, d.duration_secs) for d in sorted(hosts[host])]))
-    else:
-      results.append('%s' % host)
+  for group in groups:
+    for host in sorted(group.keys()):
+      if exclude_hosts and host in exclude_hosts:
+        continue
+
+      if options.list_jobs:
+        results.append('\n'.join(['%s\t%s\t%.2f\t%d' %
+            (host, d.job.to_path(), d.percentage, d.duration_secs) for d in sorted(group[host])]))
+      else:
+        results.append('%s' % host)
 
   print_results(results)
 
 
 @app.command
 @app.command_option(FILENAME_OPTION)
+@app.command_option(GROUPING_OPTION)
 @app.command_option(HOSTS_OPTION)
 @app.command_option(MIN_SLA_INSTANCE_COUNT)
 @requires.exactly('cluster', 'percentage', 'duration')
 def sla_probe_hosts(cluster, percentage, duration):
   """usage: sla_probe_hosts
-            [--filename=filename]
-            [--hosts=hosts]
-            [--min_job_instance_count]
+            [--filename=FILENAME]
+            [--grouping=GROUPING]
+            [--hosts=HOSTS]
+            [--min_job_instance_count=COUNT]
             cluster percentage duration
 
   Probes individual hosts with respect to their job SLA.
@@ -455,22 +467,24 @@ def sla_probe_hosts(cluster, percentage, duration):
   sla_percentage = parse_sla_percentage(percentage)
   sla_duration = parse_time(duration)
   hosts = parse_hosts(options.filename, options.hosts)
+  get_grouping_or_die(options.grouping)
 
   vector = AuroraClientAPI(
       CLUSTERS[cluster],
       options.verbosity).sla_get_safe_domain_vector(options.min_instance_count, hosts)
-  probed_hosts = vector.probe_hosts(sla_percentage, sla_duration.as_(Time.SECONDS))
+  groups = vector.probe_hosts(sla_percentage, sla_duration.as_(Time.SECONDS), options.grouping)
 
   results = []
-  for host, job_details in sorted(probed_hosts.items()):
-    results.append('\n'.join(
-        ['%s\t%s\t%.2f\t%s\t%s' %
-            (host,
-             d.job.to_path(),
-             d.predicted_percentage,
-             d.safe,
-             'n/a' if d.safe_in_secs is None else d.safe_in_secs)
-            for d in sorted(job_details)]))
+  for group in groups:
+    for host, job_details in sorted(group.items()):
+      results.append('\n'.join(
+          ['%s\t%s\t%.2f\t%s\t%s' %
+              (host,
+               d.job.to_path(),
+               d.predicted_percentage,
+               d.safe,
+               'n/a' if d.safe_in_secs is None else d.safe_in_secs)
+              for d in sorted(job_details)]))
 
   print_results(results)
 

http://git-wip-us.apache.org/repos/asf/incubator-aurora/blob/2753763e/src/main/python/apache/aurora/client/commands/maintenance.py
----------------------------------------------------------------------
diff --git a/src/main/python/apache/aurora/client/commands/maintenance.py b/src/main/python/apache/aurora/client/commands/maintenance.py
index 72bde47..f6ebe3b 100644
--- a/src/main/python/apache/aurora/client/commands/maintenance.py
+++ b/src/main/python/apache/aurora/client/commands/maintenance.py
@@ -19,19 +19,17 @@ import subprocess
 from twitter.common import app, log
 
 from apache.aurora.admin.host_maintenance import HostMaintenance
-from apache.aurora.client.base import die, FILENAME_OPTION, HOSTS_OPTION, parse_hosts, requires
+from apache.aurora.client.base import (
+    die,
+    FILENAME_OPTION,
+    get_grouping_or_die,
+    GROUPING_OPTION,
+    HOSTS_OPTION,
+    parse_hosts,
+    requires
+)
 from apache.aurora.common.clusters import CLUSTERS
 
-GROUPING_OPTION = optparse.Option(
-    '--grouping',
-    type='choice',
-    choices=HostMaintenance.GROUPING_FUNCTIONS.keys(),
-    metavar='GROUPING',
-    default=HostMaintenance.DEFAULT_GROUPING,
-    dest='grouping',
-    help='Grouping function to use to group hosts.  Options: %s.  Default: %%default' % (
-        ', '.join(HostMaintenance.GROUPING_FUNCTIONS.keys())))
-
 
 @app.command
 @app.command_option(FILENAME_OPTION)
@@ -78,6 +76,7 @@ def perform_maintenance_hosts(cluster):
   """
   options = app.get_options()
   drainable_hosts = parse_hosts(options.filename, options.hosts)
+  get_grouping_or_die(options.grouping)
 
   if options.post_drain_script:
     if not os.path.exists(options.post_drain_script):

http://git-wip-us.apache.org/repos/asf/incubator-aurora/blob/2753763e/src/test/python/apache/aurora/admin/test_host_maintenance.py
----------------------------------------------------------------------
diff --git a/src/test/python/apache/aurora/admin/test_host_maintenance.py b/src/test/python/apache/aurora/admin/test_host_maintenance.py
index 2713bfe..ed0782b 100644
--- a/src/test/python/apache/aurora/admin/test_host_maintenance.py
+++ b/src/test/python/apache/aurora/admin/test_host_maintenance.py
@@ -18,6 +18,7 @@ import mock
 import pytest
 
 from apache.aurora.admin.host_maintenance import HostMaintenance
+from apache.aurora.client.base import add_grouping, remove_grouping
 from apache.aurora.common.cluster import Cluster
 
 from gen.apache.aurora.api.ttypes import Hosts, Response, ResponseCode
@@ -55,8 +56,7 @@ def rack_grouping(hostname):
 
 
 def test_rack_grouping():
-  old_grouping_functions = HostMaintenance.GROUPING_FUNCTIONS.copy()
-  HostMaintenance.GROUPING_FUNCTIONS['by_rack'] = rack_grouping
+  add_grouping('by_rack', rack_grouping)
 
   example_host_list = [
     'west-aaa-001.example.com',
@@ -79,4 +79,4 @@ def test_rack_grouping():
     ]))
 
   finally:
-    HostMaintenance.GROUPING_FUNCTIONS = old_grouping_functions
+    remove_grouping('by_rack')

http://git-wip-us.apache.org/repos/asf/incubator-aurora/blob/2753763e/src/test/python/apache/aurora/client/api/test_sla.py
----------------------------------------------------------------------
diff --git a/src/test/python/apache/aurora/client/api/test_sla.py b/src/test/python/apache/aurora/client/api/test_sla.py
index f723410..9c3bb6d 100644
--- a/src/test/python/apache/aurora/client/api/test_sla.py
+++ b/src/test/python/apache/aurora/client/api/test_sla.py
@@ -15,9 +15,11 @@
 import time
 import unittest
 
+from contextlib import contextmanager
 from mock import call, Mock, patch
 
 from apache.aurora.client.api.sla import DomainUpTimeSlaVector, JobUpTimeSlaVector, Sla, task_query
+from apache.aurora.client.base import add_grouping, DEFAULT_GROUPING, remove_grouping
 from apache.aurora.common.aurora_job_key import AuroraJobKey
 from apache.aurora.common.cluster import Cluster
 
@@ -38,6 +40,10 @@ from gen.apache.aurora.api.ttypes import (
 )
 
 
+def rack_grouping(hostname):
+  return hostname.split('-')[1]
+
+
 class SlaTest(unittest.TestCase):
   def setUp(self):
     self._scheduler = Mock()
@@ -103,20 +109,22 @@ class SlaTest(unittest.TestCase):
     )
     self.expect_task_status_call_job_scoped()
 
-  def assert_safe_domain_result(self, host, percentage, duration, in_limit=None, out_limit=None):
+  def assert_safe_domain_result(self, host, percentage, duration, in_limit=None, out_limit=None,
+      grouping=DEFAULT_GROUPING):
     vector = self._sla.get_domain_uptime_vector(self._cluster, self._min_count)
-    result = vector.get_safe_hosts(percentage, duration, in_limit)
+    result = vector.get_safe_hosts(percentage, duration, in_limit, grouping)
     assert 1 == len(result), ('Expected length:%s Actual length:%s' % (1, len(result)))
-    assert host in result, ('Expected host:%s not found in result' % host)
+    assert host in result[0], ('Expected host:%s not found in result' % host)
     if out_limit:
-      assert result[host][0].job.name == out_limit.job.name, (
-          'Expected job:%s Actual:%s' % (out_limit.job.name, result[host][0].job.name)
+      job_details = result[0][host][0]
+      assert job_details.job.name == out_limit.job.name, (
+          'Expected job:%s Actual:%s' % (out_limit.job.name, job_details.job.name)
       )
-      assert result[host][0].percentage == out_limit.percentage, (
-        'Expected %%:%s Actual %%:%s' % (out_limit.percentage, result[host][0].percentage)
+      assert job_details.percentage == out_limit.percentage, (
+        'Expected %%:%s Actual %%:%s' % (out_limit.percentage, job_details.percentage)
       )
-      assert result[host][0].duration == out_limit.duration, (
-        'Expected duration:%s Actual duration:%s' % (out_limit.duration, result[host][0].duration)
+      assert job_details.duration == out_limit.duration, (
+        'Expected duration:%s Actual duration:%s' % (out_limit.duration, job_details.duration)
       )
     self.expect_task_status_call_cluster_scoped()
 
@@ -126,10 +134,21 @@ class SlaTest(unittest.TestCase):
     assert len(hosts) == len(result), ('Expected length:%s Actual length:%s' % (1, len(result)))
     return result
 
+  def assert_probe_hosts_result_with_grouping(self, hosts, percent, duration, group_count):
+    vector = self._sla.get_domain_uptime_vector(self._cluster, self._min_count, hosts)
+    result = vector.probe_hosts(percent, duration, 'by_rack')
+    assert group_count == len(result), ('Expected length:%s Actual length:%s'
+        % (group_count, len(result)))
+    return result
+
   def assert_probe_host_job_details(self, result, host, f_percent, safe=True, wait_time=0):
-    assert host in result, ('Expected host:%s not found in result' % host)
+    job_details = None
+    for group in result:
+      if host in group:
+        job_details = group[host][0]
+        break
 
-    job_details = result[host][0]
+    assert job_details, ('Expected host:%s not found in result' % host)
     assert job_details.job.name == self._name, (
       'Expected job:%s Actual:%s' % (self._name, job_details.job.name)
     )
@@ -155,6 +174,11 @@ class SlaTest(unittest.TestCase):
   def expect_task_status_call_cluster_scoped(self):
     self._scheduler.getTasksStatus.assert_called_with(TaskQuery(statuses=LIVE_STATES))
 
+  @contextmanager
+  def group_by_rack(self):
+    add_grouping('by_rack', rack_grouping)
+    yield
+    remove_grouping('by_rack')
 
   def test_count_0(self):
     self.mock_get_tasks([])
@@ -293,6 +317,26 @@ class SlaTest(unittest.TestCase):
     assert 0 == len(vector.get_safe_hosts(50, 200)), 'Length must be empty.'
     self.expect_task_status_call_cluster_scoped()
 
+  def test_domain_uptime_with_grouping(self):
+    with self.group_by_rack():
+      self.mock_get_tasks([
+          self.create_task(100, 1, 'cl-r1-h01', self._name),
+          self.create_task(200, 3, 'cl-r2-h03', self._name),
+      ])
+      self.assert_safe_domain_result('cl-r1-h01', 50, 150, grouping='by_rack')
+
+  def test_domain_uptime_with_grouping_not_safe(self):
+    with self.group_by_rack():
+      self.mock_get_tasks([
+          self.create_task(200, 1, 'cl-r1-h01', self._name),
+          self.create_task(100, 2, 'cl-r1-h02', self._name),
+          self.create_task(200, 3, 'cl-r2-h03', self._name),
+          self.create_task(100, 4, 'cl-r2-h04', self._name),
+      ])
+      vector = self._sla.get_domain_uptime_vector(self._cluster, self._min_count)
+      assert 0 == len(vector.get_safe_hosts(50, 150, None, 'by_rack')), 'Length must be empty.'
+      self.expect_task_status_call_cluster_scoped()
+
 
   def test_probe_hosts_no_hosts(self):
     self.mock_get_tasks([])
@@ -361,6 +405,32 @@ class SlaTest(unittest.TestCase):
     vector = self._sla.get_domain_uptime_vector(self._cluster, ['h1', 'h2'])
     assert 0 == len(vector.probe_hosts(90, 200))
 
+  def test_probe_hosts_with_grouping_safe(self):
+    with self.group_by_rack():
+      self.mock_get_tasks([
+          self.create_task(100, 1, 'cl-r1-h01', self._name),
+          self.create_task(100, 3, 'cl-r2-h03', self._name),
+      ])
+      result = self.assert_probe_hosts_result_with_grouping(
+          ['cl-r1-h01', 'cl-r2-h03'], 50, 100, 2)
+      self.assert_probe_host_job_details(result, 'cl-r1-h01', 50.0)
+      self.assert_probe_host_job_details(result, 'cl-r2-h03', 50.0)
+
+  def test_probe_hosts_with_grouping_not_safe(self):
+    with self.group_by_rack():
+      self.mock_get_tasks([
+          self.create_task(100, 1, 'cl-r1-h01', self._name),
+          self.create_task(200, 2, 'cl-r1-h02', self._name),
+          self.create_task(100, 3, 'cl-r2-h03', self._name),
+          self.create_task(200, 4, 'cl-r2-h04', self._name),
+      ])
+      result = self.assert_probe_hosts_result_with_grouping(
+          ['cl-r1-h01', 'cl-r1-h02', 'cl-r2-h03', 'cl-r2-h04'], 50, 200, 2)
+      self.assert_probe_host_job_details(result, 'cl-r1-h01', 25.0, False, 100)
+      self.assert_probe_host_job_details(result, 'cl-r1-h02', 25.0, False, 100)
+      self.assert_probe_host_job_details(result, 'cl-r2-h03', 25.0, False, 100)
+      self.assert_probe_host_job_details(result, 'cl-r2-h04', 25.0, False, 100)
+
 
   def test_get_domain_uptime_vector_with_hosts(self):
     with patch('apache.aurora.client.api.sla.task_query', return_value=TaskQuery()) as (mock_query):

http://git-wip-us.apache.org/repos/asf/incubator-aurora/blob/2753763e/src/test/python/apache/aurora/client/commands/test_admin_sla.py
----------------------------------------------------------------------
diff --git a/src/test/python/apache/aurora/client/commands/test_admin_sla.py b/src/test/python/apache/aurora/client/commands/test_admin_sla.py
index d5b5dff..84a91d5 100644
--- a/src/test/python/apache/aurora/client/commands/test_admin_sla.py
+++ b/src/test/python/apache/aurora/client/commands/test_admin_sla.py
@@ -20,6 +20,7 @@ from twitter.common.contextutil import temporary_file
 
 from apache.aurora.client.api import AuroraClientAPI
 from apache.aurora.client.api.sla import DomainUpTimeSlaVector
+from apache.aurora.client.base import DEFAULT_GROUPING
 from apache.aurora.client.commands.admin import sla_list_safe_domain, sla_probe_hosts
 from apache.aurora.client.commands.util import AuroraClientCommandTest
 from apache.aurora.common.aurora_job_key import AuroraJobKey
@@ -31,10 +32,10 @@ class TestAdminSlaListSafeDomainCommand(AuroraClientCommandTest):
 
   @classmethod
   def setup_mock_options(cls, exclude=None, include=None, override=None,
-                         exclude_list=None, include_list=None, list_jobs=False):
+                         exclude_list=None, include_list=None, list_jobs=False, grouping=None):
     mock_options = Mock(spec=['exclude_filename', 'exclude_hosts', 'include_filename',
         'include_hosts', 'override_filename', 'list_jobs', 'verbosity', 'disable_all_hooks',
-        'min_instance_count'])
+        'min_instance_count', 'grouping'])
 
     mock_options.exclude_filename = exclude
     mock_options.exclude_hosts = exclude_list
@@ -45,6 +46,7 @@ class TestAdminSlaListSafeDomainCommand(AuroraClientCommandTest):
     mock_options.verbosity = False
     mock_options.disable_all_hooks = False
     mock_options.min_instance_count = MIN_INSTANCE_COUNT
+    mock_options.grouping = grouping or DEFAULT_GROUPING
     return mock_options
 
   @classmethod
@@ -54,7 +56,7 @@ class TestAdminSlaListSafeDomainCommand(AuroraClientCommandTest):
       host_name = 'h%s' % i
       job = AuroraJobKey.from_path('west/role/env/job%s' % i)
       hosts[host_name].append(DomainUpTimeSlaVector.JobUpTimeLimit(job, percentage, duration))
-    return hosts
+    return [hosts]
 
   @classmethod
   def create_mock_vector(cls, result):
@@ -67,7 +69,8 @@ class TestAdminSlaListSafeDomainCommand(AuroraClientCommandTest):
     mock_options = self.setup_mock_options()
     mock_vector = self.create_mock_vector(self.create_hosts(3, 80, 100))
     with contextlib.nested(
-        patch('apache.aurora.client.commands.admin.AuroraClientAPI', new=Mock(spec=AuroraClientAPI)),
+        patch('apache.aurora.client.commands.admin.AuroraClientAPI',
+            new=Mock(spec=AuroraClientAPI)),
         patch('apache.aurora.client.commands.admin.print_results'),
         patch('apache.aurora.client.commands.admin.CLUSTERS', new=self.TEST_CLUSTERS),
         patch('twitter.common.app.get_options', return_value=mock_options)
@@ -80,7 +83,7 @@ class TestAdminSlaListSafeDomainCommand(AuroraClientCommandTest):
       mock_api.return_value.sla_get_safe_domain_vector.return_value = mock_vector
       sla_list_safe_domain(['west', '50', '100s'])
 
-      mock_vector.get_safe_hosts.assert_called_once_with(50.0, 100.0, {})
+      mock_vector.get_safe_hosts.assert_called_once_with(50.0, 100.0, {}, DEFAULT_GROUPING)
       mock_print_results.assert_called_once_with(['h0', 'h1', 'h2'])
 
   def test_safe_domain_exclude_hosts(self):
@@ -91,7 +94,8 @@ class TestAdminSlaListSafeDomainCommand(AuroraClientCommandTest):
       fp.flush()
       mock_options = self.setup_mock_options(exclude=fp.name)
       with contextlib.nested(
-          patch('apache.aurora.client.commands.admin.AuroraClientAPI', new=Mock(spec=AuroraClientAPI)),
+          patch('apache.aurora.client.commands.admin.AuroraClientAPI',
+              new=Mock(spec=AuroraClientAPI)),
           patch('apache.aurora.client.commands.admin.print_results'),
           patch('apache.aurora.client.commands.admin.CLUSTERS', new=self.TEST_CLUSTERS),
           patch('twitter.common.app.get_options', return_value=mock_options)
@@ -105,7 +109,7 @@ class TestAdminSlaListSafeDomainCommand(AuroraClientCommandTest):
 
         sla_list_safe_domain(['west', '50', '100s'])
 
-        mock_vector.get_safe_hosts.assert_called_once_with(50.0, 100.0, {})
+        mock_vector.get_safe_hosts.assert_called_once_with(50.0, 100.0, {}, DEFAULT_GROUPING)
         mock_print_results.assert_called_once_with(['h0', 'h2'])
 
   def test_safe_domain_exclude_hosts_from_list(self):
@@ -113,7 +117,8 @@ class TestAdminSlaListSafeDomainCommand(AuroraClientCommandTest):
     mock_vector = self.create_mock_vector(self.create_hosts(3, 80, 100))
     mock_options = self.setup_mock_options(exclude_list=','.join(['h0', 'h1']))
     with contextlib.nested(
-        patch('apache.aurora.client.commands.admin.AuroraClientAPI', new=Mock(spec=AuroraClientAPI)),
+        patch('apache.aurora.client.commands.admin.AuroraClientAPI',
+            new=Mock(spec=AuroraClientAPI)),
         patch('apache.aurora.client.commands.admin.print_results'),
         patch('apache.aurora.client.commands.admin.CLUSTERS', new=self.TEST_CLUSTERS),
         patch('twitter.common.app.get_options', return_value=mock_options)
@@ -127,7 +132,7 @@ class TestAdminSlaListSafeDomainCommand(AuroraClientCommandTest):
 
       sla_list_safe_domain(['west', '50', '100s'])
 
-      mock_vector.get_safe_hosts.assert_called_once_with(50.0, 100.0, {})
+      mock_vector.get_safe_hosts.assert_called_once_with(50.0, 100.0, {}, DEFAULT_GROUPING)
       mock_print_results.assert_called_once_with(['h2'])
 
   def test_safe_domain_include_hosts(self):
@@ -139,7 +144,8 @@ class TestAdminSlaListSafeDomainCommand(AuroraClientCommandTest):
       fp.flush()
       mock_options = self.setup_mock_options(include=fp.name)
       with contextlib.nested(
-          patch('apache.aurora.client.commands.admin.AuroraClientAPI', new=Mock(spec=AuroraClientAPI)),
+          patch('apache.aurora.client.commands.admin.AuroraClientAPI',
+              new=Mock(spec=AuroraClientAPI)),
           patch('apache.aurora.client.commands.admin.print_results'),
           patch('apache.aurora.client.commands.admin.CLUSTERS', new=self.TEST_CLUSTERS),
           patch('twitter.common.app.get_options', return_value=mock_options)
@@ -155,7 +161,7 @@ class TestAdminSlaListSafeDomainCommand(AuroraClientCommandTest):
 
         mock_api.return_value.sla_get_safe_domain_vector.assert_called_once_with(
             MIN_INSTANCE_COUNT, [hostname])
-        mock_vector.get_safe_hosts.assert_called_once_with(50.0, 100.0, {})
+        mock_vector.get_safe_hosts.assert_called_once_with(50.0, 100.0, {}, DEFAULT_GROUPING)
         mock_print_results.assert_called_once_with([hostname])
 
   def test_safe_domain_include_hosts_from_list(self):
@@ -164,7 +170,8 @@ class TestAdminSlaListSafeDomainCommand(AuroraClientCommandTest):
     hosts = ['h0', 'h1']
     mock_options = self.setup_mock_options(include_list=','.join(hosts))
     with contextlib.nested(
-        patch('apache.aurora.client.commands.admin.AuroraClientAPI', new=Mock(spec=AuroraClientAPI)),
+        patch('apache.aurora.client.commands.admin.AuroraClientAPI',
+            new=Mock(spec=AuroraClientAPI)),
         patch('apache.aurora.client.commands.admin.print_results'),
         patch('apache.aurora.client.commands.admin.CLUSTERS', new=self.TEST_CLUSTERS),
         patch('twitter.common.app.get_options', return_value=mock_options)
@@ -180,7 +187,7 @@ class TestAdminSlaListSafeDomainCommand(AuroraClientCommandTest):
 
       mock_api.return_value.sla_get_safe_domain_vector.assert_called_once_with(
           MIN_INSTANCE_COUNT, hosts)
-      mock_vector.get_safe_hosts.assert_called_once_with(50.0, 100.0, {})
+      mock_vector.get_safe_hosts.assert_called_once_with(50.0, 100.0, {}, DEFAULT_GROUPING)
       mock_print_results.assert_called_once_with(hosts)
 
   def test_safe_domain_override_jobs(self):
@@ -191,7 +198,8 @@ class TestAdminSlaListSafeDomainCommand(AuroraClientCommandTest):
       fp.flush()
       mock_options = self.setup_mock_options(override=fp.name)
       with contextlib.nested(
-          patch('apache.aurora.client.commands.admin.AuroraClientAPI', new=Mock(spec=AuroraClientAPI)),
+          patch('apache.aurora.client.commands.admin.AuroraClientAPI',
+              new=Mock(spec=AuroraClientAPI)),
           patch('apache.aurora.client.commands.admin.print_results'),
           patch('apache.aurora.client.commands.admin.CLUSTERS', new=self.TEST_CLUSTERS),
           patch('twitter.common.app.get_options', return_value=mock_options)
@@ -207,7 +215,8 @@ class TestAdminSlaListSafeDomainCommand(AuroraClientCommandTest):
 
         job_key = AuroraJobKey.from_path('west/role/env/job1')
         override = {job_key: DomainUpTimeSlaVector.JobUpTimeLimit(job_key, 30, 200)}
-        mock_vector.get_safe_hosts.assert_called_once_with(50.0, 100.0, override)
+        mock_vector.get_safe_hosts.assert_called_once_with(50.0, 100.0,
+            override, DEFAULT_GROUPING)
         mock_print_results.assert_called_once_with(['h0', 'h1', 'h2'])
 
   def test_safe_domain_list_jobs(self):
@@ -215,7 +224,8 @@ class TestAdminSlaListSafeDomainCommand(AuroraClientCommandTest):
     mock_options = self.setup_mock_options(list_jobs=True)
     mock_vector = self.create_mock_vector(self.create_hosts(3, 50, 100))
     with contextlib.nested(
-        patch('apache.aurora.client.commands.admin.AuroraClientAPI', new=Mock(spec=AuroraClientAPI)),
+        patch('apache.aurora.client.commands.admin.AuroraClientAPI',
+            new=Mock(spec=AuroraClientAPI)),
         patch('apache.aurora.client.commands.admin.print_results'),
         patch('apache.aurora.client.commands.admin.CLUSTERS', new=self.TEST_CLUSTERS),
         patch('twitter.common.app.get_options', return_value=mock_options)
@@ -228,7 +238,7 @@ class TestAdminSlaListSafeDomainCommand(AuroraClientCommandTest):
       mock_api.return_value.sla_get_safe_domain_vector.return_value = mock_vector
       sla_list_safe_domain(['west', '50', '100s'])
 
-      mock_vector.get_safe_hosts.assert_called_once_with(50.0, 100.0, {})
+      mock_vector.get_safe_hosts.assert_called_once_with(50.0, 100.0, {}, DEFAULT_GROUPING)
       mock_print_results.assert_called_once_with([
           'h0\twest/role/env/job0\t50.00\t100',
           'h1\twest/role/env/job1\t50.00\t100',
@@ -237,7 +247,7 @@ class TestAdminSlaListSafeDomainCommand(AuroraClientCommandTest):
   def test_safe_domain_invalid_percentage(self):
     """Tests execution of the sla_list_safe_domain command with invalid percentage"""
     mock_options = self.setup_mock_options()
-    with patch('twitter.common.app.get_options', return_value=mock_options) as (mock_options):
+    with patch('twitter.common.app.get_options', return_value=mock_options) as (_):
 
       try:
         sla_list_safe_domain(['west', '0', '100s'])
@@ -252,7 +262,7 @@ class TestAdminSlaListSafeDomainCommand(AuroraClientCommandTest):
       fp.write('30 200s')
       fp.flush()
       mock_options = self.setup_mock_options(override=fp.name)
-      with patch('twitter.common.app.get_options', return_value=mock_options) as (mock_options):
+      with patch('twitter.common.app.get_options', return_value=mock_options) as (_):
 
         try:
           sla_list_safe_domain(['west', '50', '100s'])
@@ -264,7 +274,19 @@ class TestAdminSlaListSafeDomainCommand(AuroraClientCommandTest):
   def test_safe_domain_hosts_error(self):
     """Tests execution of the sla_list_safe_domain command with both include file and list"""
     mock_options = self.setup_mock_options(include='file', include_list='list')
-    with patch('twitter.common.app.get_options', return_value=mock_options) as (mock_options):
+    with patch('twitter.common.app.get_options', return_value=mock_options) as (_):
+
+      try:
+        sla_list_safe_domain(['west', '50', '100s'])
+      except SystemExit:
+        pass
+      else:
+        assert 'Expected error is not raised.'
+
+  def test_safe_domain_grouping_error(self):
+    """Tests execution of the sla_list_safe_domain command invalid grouping"""
+    mock_options = self.setup_mock_options(grouping='foo')
+    with patch('twitter.common.app.get_options', return_value=mock_options) as (_):
 
       try:
         sla_list_safe_domain(['west', '50', '100s'])
@@ -277,11 +299,12 @@ class TestAdminSlaListSafeDomainCommand(AuroraClientCommandTest):
 class TestAdminSlaProbeHostsCommand(AuroraClientCommandTest):
 
   @classmethod
-  def setup_mock_options(cls, hosts=None, filename=None):
-    mock_options = Mock()
+  def setup_mock_options(cls, hosts=None, filename=None, grouping=None):
+    mock_options = Mock(spec=['hosts', 'filename', 'verbosity', 'min_instance_count', 'grouping'])
     mock_options.hosts = hosts
     mock_options.filename = filename
     mock_options.verbosity = False
+    mock_options.grouping = grouping or DEFAULT_GROUPING
     return mock_options
 
   @classmethod
@@ -297,7 +320,7 @@ class TestAdminSlaProbeHostsCommand(AuroraClientCommandTest):
       host_name = 'h%s' % i
       job = AuroraJobKey.from_path('west/role/env/job%s' % i)
       hosts[host_name].append(DomainUpTimeSlaVector.JobUpTimeDetails(job, predicted, safe, safe_in))
-    return hosts
+    return [hosts]
 
   def test_probe_hosts_with_list(self):
     """Tests successful execution of the sla_probe_hosts command with host list."""
@@ -305,7 +328,8 @@ class TestAdminSlaProbeHostsCommand(AuroraClientCommandTest):
     mock_options = self.setup_mock_options(hosts=','.join(hosts))
     mock_vector = self.create_mock_vector(self.create_probe_hosts(2, 80, True, 0))
     with contextlib.nested(
-        patch('apache.aurora.client.commands.admin.AuroraClientAPI', new=Mock(spec=AuroraClientAPI)),
+        patch('apache.aurora.client.commands.admin.AuroraClientAPI',
+            new=Mock(spec=AuroraClientAPI)),
         patch('apache.aurora.client.commands.admin.print_results'),
         patch('apache.aurora.client.commands.admin.CLUSTERS', new=self.TEST_CLUSTERS),
         patch('twitter.common.app.get_options', return_value=mock_options)
@@ -320,7 +344,7 @@ class TestAdminSlaProbeHostsCommand(AuroraClientCommandTest):
 
       mock_api.return_value.sla_get_safe_domain_vector.assert_called_once_with(
           mock_options.min_instance_count, hosts)
-      mock_vector.probe_hosts.assert_called_once_with(90.0, 200.0)
+      mock_vector.probe_hosts.assert_called_once_with(90.0, 200.0, mock_options.grouping)
       mock_print_results.assert_called_once_with([
           'h0\twest/role/env/job0\t80.00\tTrue\t0',
           'h1\twest/role/env/job1\t80.00\tTrue\t0'
@@ -334,7 +358,8 @@ class TestAdminSlaProbeHostsCommand(AuroraClientCommandTest):
       fp.flush()
       mock_options = self.setup_mock_options(filename=fp.name)
       with contextlib.nested(
-          patch('apache.aurora.client.commands.admin.AuroraClientAPI', new=Mock(spec=AuroraClientAPI)),
+          patch('apache.aurora.client.commands.admin.AuroraClientAPI',
+              new=Mock(spec=AuroraClientAPI)),
           patch('apache.aurora.client.commands.admin.print_results'),
           patch('apache.aurora.client.commands.admin.CLUSTERS', new=self.TEST_CLUSTERS),
           patch('twitter.common.app.get_options', return_value=mock_options)
@@ -349,7 +374,7 @@ class TestAdminSlaProbeHostsCommand(AuroraClientCommandTest):
 
         mock_api.return_value.sla_get_safe_domain_vector.assert_called_once_with(
           mock_options.min_instance_count, ['h0'])
-        mock_vector.probe_hosts.assert_called_once_with(90.0, 200.0)
+        mock_vector.probe_hosts.assert_called_once_with(90.0, 200.0, mock_options.grouping)
         mock_print_results.assert_called_once_with([
             'h0\twest/role/env/job0\t80.00\tFalse\tn/a'
         ])
@@ -360,7 +385,7 @@ class TestAdminSlaProbeHostsCommand(AuroraClientCommandTest):
       fp.write('h0')
       fp.flush()
       mock_options = self.setup_mock_options(hosts='h0', filename=fp.name)
-      with patch('twitter.common.app.get_options', return_value=mock_options) as (mock_options):
+      with patch('twitter.common.app.get_options', return_value=mock_options) as (_):
 
         try:
           sla_probe_hosts(['west', '50', '100s'])
@@ -368,3 +393,15 @@ class TestAdminSlaProbeHostsCommand(AuroraClientCommandTest):
           pass
         else:
           assert 'Expected error is not raised.'
+
+  def test_probe_grouping_error(self):
+    """Tests execution of the sla_probe_hosts command with invalid grouping."""
+    mock_options = self.setup_mock_options(hosts='h0', grouping='foo')
+    with patch('twitter.common.app.get_options', return_value=mock_options) as (_):
+
+      try:
+        sla_probe_hosts(['west', '50', '100s'])
+      except SystemExit:
+        pass
+      else:
+        assert 'Expected error is not raised.'

