kafka-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From rsiva...@apache.org
Subject [kafka] branch trunk updated: KAFKA-4292: Configurable SASL callback handlers (KIP-86) (#2022)
Date Thu, 05 Apr 2018 08:41:46 GMT
This is an automated email from the ASF dual-hosted git repository.

rsivaram pushed a commit to branch trunk
in repository https://gitbox.apache.org/repos/asf/kafka.git


The following commit(s) were added to refs/heads/trunk by this push:
     new 9f8c316  KAFKA-4292: Configurable SASL callback handlers (KIP-86) (#2022)
9f8c316 is described below

commit 9f8c3167eb2fcab158147eb4fefdabc933b8a3a1
Author: Rajini Sivaram <rajinisivaram@googlemail.com>
AuthorDate: Thu Apr 5 09:41:42 2018 +0100

    KAFKA-4292: Configurable SASL callback handlers (KIP-86) (#2022)
    
    Implementation of KIP-86. Client, server and login callback handlers have been made configurable for both brokers and clients.
    
    Reviewers: Jun Rao <junrao@gmail.com>, Ron Dagostino <rndgstn@gmail.com>, Manikumar Reddy <manikumar.reddy@gmail.com>
---
 .../apache/kafka/common/config/SaslConfigs.java    |  24 +-
 .../config/internals/BrokerSecurityConfigs.java    |   6 +
 .../apache/kafka/common/network/ListenerName.java  |   6 +-
 .../kafka/common/network/SaslChannelBuilder.java   | 105 ++++-
 .../apache/kafka/common/security/JaasContext.java  |   2 +-
 .../security/auth/AuthenticateCallbackHandler.java |  62 +++
 .../security/{authenticator => auth}/Login.java    |  21 +-
 .../security/authenticator/AbstractLogin.java      |  38 +-
 .../authenticator/AuthCallbackHandler.java         |  45 ---
 .../security/authenticator/LoginManager.java       | 114 ++++--
 .../authenticator/SaslClientAuthenticator.java     |  13 +-
 .../authenticator/SaslClientCallbackHandler.java   |  30 +-
 .../authenticator/SaslServerAuthenticator.java     |  37 +-
 .../authenticator/SaslServerCallbackHandler.java   |  37 +-
 .../KerberosClientCallbackHandler.java}            |  52 +--
 .../common/security/kerberos/KerberosLogin.java    |  26 +-
 .../security/plain/PlainAuthenticateCallback.java  |  63 +++
 .../common/security/plain/PlainLoginModule.java    |   2 +
 .../plain/{ => internal}/PlainSaslServer.java      |  33 +-
 .../{ => internal}/PlainSaslServerProvider.java    |   4 +-
 .../plain/internal/PlainServerCallbackHandler.java |  76 ++++
 .../common/security/scram/ScramCredential.java     |  20 +
 .../security/scram/ScramCredentialCallback.java    |  19 +-
 .../security/scram/ScramExtensionsCallback.java    |  12 +
 .../common/security/scram/ScramLoginModule.java    |   3 +
 .../scram/{ => internal}/ScramCredentialUtils.java |   3 +-
 .../scram/{ => internal}/ScramExtensions.java      |   4 +-
 .../scram/{ => internal}/ScramFormatter.java       |   9 +-
 .../scram/{ => internal}/ScramMechanism.java       |   2 +-
 .../scram/{ => internal}/ScramMessages.java        |   2 +-
 .../scram/{ => internal}/ScramSaslClient.java      |  24 +-
 .../{ => internal}/ScramSaslClientProvider.java    |   4 +-
 .../scram/{ => internal}/ScramSaslServer.java      |  22 +-
 .../{ => internal}/ScramSaslServerProvider.java    |   4 +-
 .../{ => internal}/ScramServerCallbackHandler.java |  22 +-
 .../token/delegation/DelegationTokenCache.java     |   4 +-
 .../apache/kafka/common/network/NioEchoServer.java |  12 +-
 .../kafka/common/security/TestSecurityConfig.java  |   3 +
 .../auth/DefaultKafkaPrincipalBuilderTest.java     |   2 +-
 .../security/authenticator/LoginManagerTest.java   |  24 +-
 .../authenticator/SaslAuthenticatorTest.java       | 430 +++++++++++++++++++--
 .../authenticator/SaslServerAuthenticatorTest.java |  11 +-
 .../authenticator/TestDigestLoginModule.java       |  68 +---
 .../security/authenticator/TestJaasConfig.java     |   2 +-
 .../plain/{ => internal}/PlainSaslServerTest.java  |   7 +-
 .../{ => internal}/ScramCredentialUtilsTest.java   |  16 +-
 .../scram/{ => internal}/ScramFormatterTest.java   |  13 +-
 .../scram/{ => internal}/ScramMessagesTest.java    |  21 +-
 .../scram/{ => internal}/ScramSaslServerTest.java  |  13 +-
 .../src/main/scala/kafka/admin/ConfigCommand.scala |   2 +-
 .../scala/kafka/security/CredentialProvider.scala  |   3 +-
 .../kafka/server/DelegationTokenManager.scala      |   3 +-
 .../scala/kafka/server/DynamicConfigManager.scala  |   4 +-
 core/src/main/scala/kafka/server/KafkaConfig.scala |  14 +-
 core/src/main/scala/kafka/server/KafkaServer.scala |   2 +-
 .../scala/kafka/utils/VerifiableProperties.scala   |  10 +-
 .../DelegationTokenEndToEndAuthorizationTest.scala |   2 +-
 .../kafka/api/SaslEndToEndAuthorizationTest.scala  |   1 +
 .../SaslPlainSslEndToEndAuthorizationTest.scala    |  87 ++++-
 .../SaslScramSslEndToEndAuthorizationTest.scala    |   2 +-
 .../scala/integration/kafka/api/SaslSetup.scala    |   2 +-
 .../scala/unit/kafka/admin/ConfigCommandTest.scala |   2 +-
 .../unit/kafka/network/SocketServerTest.scala      |   2 +-
 .../delegation/DelegationTokenManagerTest.scala    |   2 +-
 .../scala/unit/kafka/server/KafkaConfigTest.scala  |   4 +
 .../scala/unit/kafka/utils/JaasTestUtils.scala     |   2 +-
 66 files changed, 1262 insertions(+), 454 deletions(-)

diff --git a/clients/src/main/java/org/apache/kafka/common/config/SaslConfigs.java b/clients/src/main/java/org/apache/kafka/common/config/SaslConfigs.java
index f61b7dd..148ab15 100644
--- a/clients/src/main/java/org/apache/kafka/common/config/SaslConfigs.java
+++ b/clients/src/main/java/org/apache/kafka/common/config/SaslConfigs.java
@@ -49,7 +49,24 @@ public class SaslConfigs {
     public static final String SASL_JAAS_CONFIG = "sasl.jaas.config";
     public static final String SASL_JAAS_CONFIG_DOC = "JAAS login context parameters for SASL connections in the format used by JAAS configuration files. "
         + "JAAS configuration file format is described <a href=\"http://docs.oracle.com/javase/8/docs/technotes/guides/security/jgss/tutorials/LoginConfigFile.html\">here</a>. "
-        + "The format for the value is: '<loginModuleClass> <controlFlag> (<optionName>=<optionValue>)*;'";
+        + "The format for the value is: '<loginModuleClass> <controlFlag> (<optionName>=<optionValue>)*;'. For brokers, "
+        + "the config must be prefixed with listener prefix and SASL mechanism name in lower-case. For example, "
+        + "listener.name.sasl_ssl.scram-sha-256.sasl.jaas.config=com.example.ScramLoginModule required;";
+
+    public static final String SASL_CLIENT_CALLBACK_HANDLER_CLASS = "sasl.client.callback.handler.class";
+    public static final String SASL_CLIENT_CALLBACK_HANDLER_CLASS_DOC = "The fully qualified name of a SASL client callback handler class "
+        + "that implements the AuthenticateCallbackHandler interface.";
+
+    public static final String SASL_LOGIN_CALLBACK_HANDLER_CLASS = "sasl.login.callback.handler.class";
+    public static final String SASL_LOGIN_CALLBACK_HANDLER_CLASS_DOC = "The fully qualified name of a SASL login callback handler class "
+            + "that implements the AuthenticateCallbackHandler interface. For brokers, login callback handler config must be prefixed with "
+            + "listener prefix and SASL mechanism name in lower-case. For example, "
+            + "listener.name.sasl_ssl.scram-sha-256.sasl.login.callback.handler.class=com.example.CustomScramLoginCallbackHandler";
+
+    public static final String SASL_LOGIN_CLASS = "sasl.login.class";
+    public static final String SASL_LOGIN_CLASS_DOC = "The fully qualified name of a class that implements the Login interface. "
+        + "For brokers, login config must be prefixed with listener prefix and SASL mechanism name in lower-case. For example, "
+        + "listener.name.sasl_ssl.scram-sha-256.sasl.login.class=com.example.CustomScramLogin";
 
     public static final String SASL_KERBEROS_SERVICE_NAME = "sasl.kerberos.service.name";
     public static final String SASL_KERBEROS_SERVICE_NAME_DOC = "The Kerberos principal name that Kafka runs as. "
@@ -95,6 +112,9 @@ public class SaslConfigs {
                 .define(SaslConfigs.SASL_KERBEROS_TICKET_RENEW_JITTER, ConfigDef.Type.DOUBLE, SaslConfigs.DEFAULT_KERBEROS_TICKET_RENEW_JITTER, ConfigDef.Importance.LOW, SaslConfigs.SASL_KERBEROS_TICKET_RENEW_JITTER_DOC)
                 .define(SaslConfigs.SASL_KERBEROS_MIN_TIME_BEFORE_RELOGIN, ConfigDef.Type.LONG, SaslConfigs.DEFAULT_KERBEROS_MIN_TIME_BEFORE_RELOGIN, ConfigDef.Importance.LOW, SaslConfigs.SASL_KERBEROS_MIN_TIME_BEFORE_RELOGIN_DOC)
                 .define(SaslConfigs.SASL_MECHANISM, ConfigDef.Type.STRING, SaslConfigs.DEFAULT_SASL_MECHANISM, ConfigDef.Importance.MEDIUM, SaslConfigs.SASL_MECHANISM_DOC)
-                .define(SaslConfigs.SASL_JAAS_CONFIG, ConfigDef.Type.PASSWORD, null, ConfigDef.Importance.MEDIUM, SaslConfigs.SASL_JAAS_CONFIG_DOC);
+                .define(SaslConfigs.SASL_JAAS_CONFIG, ConfigDef.Type.PASSWORD, null, ConfigDef.Importance.MEDIUM, SaslConfigs.SASL_JAAS_CONFIG_DOC)
+                .define(SaslConfigs.SASL_CLIENT_CALLBACK_HANDLER_CLASS, ConfigDef.Type.CLASS, null, ConfigDef.Importance.MEDIUM, SaslConfigs.SASL_CLIENT_CALLBACK_HANDLER_CLASS_DOC)
+                .define(SaslConfigs.SASL_LOGIN_CALLBACK_HANDLER_CLASS, ConfigDef.Type.CLASS, null, ConfigDef.Importance.MEDIUM, SaslConfigs.SASL_LOGIN_CALLBACK_HANDLER_CLASS_DOC)
+                .define(SaslConfigs.SASL_LOGIN_CLASS, ConfigDef.Type.CLASS, null, ConfigDef.Importance.MEDIUM, SaslConfigs.SASL_LOGIN_CLASS_DOC);
     }
 }
diff --git a/clients/src/main/java/org/apache/kafka/common/config/internals/BrokerSecurityConfigs.java b/clients/src/main/java/org/apache/kafka/common/config/internals/BrokerSecurityConfigs.java
index 18616ec..a29d806 100644
--- a/clients/src/main/java/org/apache/kafka/common/config/internals/BrokerSecurityConfigs.java
+++ b/clients/src/main/java/org/apache/kafka/common/config/internals/BrokerSecurityConfigs.java
@@ -33,6 +33,7 @@ public class BrokerSecurityConfigs {
     public static final String SASL_KERBEROS_PRINCIPAL_TO_LOCAL_RULES_CONFIG = "sasl.kerberos.principal.to.local.rules";
     public static final String SSL_CLIENT_AUTH_CONFIG = "ssl.client.auth";
     public static final String SASL_ENABLED_MECHANISMS_CONFIG = "sasl.enabled.mechanisms";
+    public static final String SASL_SERVER_CALLBACK_HANDLER_CLASS = "sasl.server.callback.handler.class";
 
     public static final String PRINCIPAL_BUILDER_CLASS_DOC = "The fully qualified name of a class that implements the " +
             "KafkaPrincipalBuilder interface, which is used to build the KafkaPrincipal object used during " +
@@ -67,4 +68,9 @@ public class BrokerSecurityConfigs {
             + "Only GSSAPI is enabled by default.";
     public static final List<String> DEFAULT_SASL_ENABLED_MECHANISMS = Collections.singletonList(SaslConfigs.GSSAPI_MECHANISM);
 
+    public static final String SASL_SERVER_CALLBACK_HANDLER_CLASS_DOC = "The fully qualified name of a SASL server callback handler "
+            + "class that implements the AuthenticateCallbackHandler interface. Server callback handlers must be prefixed with "
+            + "listener prefix and SASL mechanism name in lower-case. For example, "
+            + "listener.name.sasl_ssl.plain.sasl.server.callback.handler.class=com.example.CustomPlainCallbackHandler.";
+
 }
diff --git a/clients/src/main/java/org/apache/kafka/common/network/ListenerName.java b/clients/src/main/java/org/apache/kafka/common/network/ListenerName.java
index fc0cb14..2decccb 100644
--- a/clients/src/main/java/org/apache/kafka/common/network/ListenerName.java
+++ b/clients/src/main/java/org/apache/kafka/common/network/ListenerName.java
@@ -73,6 +73,10 @@ public final class ListenerName {
     }
 
     public String saslMechanismConfigPrefix(String saslMechanism) {
-        return configPrefix() + saslMechanism.toLowerCase(Locale.ROOT) + ".";
+        return configPrefix() + saslMechanismPrefix(saslMechanism);
+    }
+
+    public static String saslMechanismPrefix(String saslMechanism) {
+        return saslMechanism.toLowerCase(Locale.ROOT) + ".";
     }
 }
diff --git a/clients/src/main/java/org/apache/kafka/common/network/SaslChannelBuilder.java b/clients/src/main/java/org/apache/kafka/common/network/SaslChannelBuilder.java
index 095f826..5502164 100644
--- a/clients/src/main/java/org/apache/kafka/common/network/SaslChannelBuilder.java
+++ b/clients/src/main/java/org/apache/kafka/common/network/SaslChannelBuilder.java
@@ -21,22 +21,35 @@ import org.apache.kafka.common.config.SaslConfigs;
 import org.apache.kafka.common.config.SslConfigs;
 import org.apache.kafka.common.config.internals.BrokerSecurityConfigs;
 import org.apache.kafka.common.memory.MemoryPool;
-import org.apache.kafka.common.security.auth.SecurityProtocol;
 import org.apache.kafka.common.security.JaasContext;
-import org.apache.kafka.common.security.token.delegation.DelegationTokenCache;
-import org.apache.kafka.common.security.kerberos.KerberosShortNamer;
+import org.apache.kafka.common.security.auth.AuthenticateCallbackHandler;
+import org.apache.kafka.common.security.auth.Login;
+import org.apache.kafka.common.security.auth.SecurityProtocol;
 import org.apache.kafka.common.security.authenticator.CredentialCache;
+import org.apache.kafka.common.security.authenticator.DefaultLogin;
 import org.apache.kafka.common.security.authenticator.LoginManager;
 import org.apache.kafka.common.security.authenticator.SaslClientAuthenticator;
+import org.apache.kafka.common.security.authenticator.SaslClientCallbackHandler;
 import org.apache.kafka.common.security.authenticator.SaslServerAuthenticator;
+import org.apache.kafka.common.security.authenticator.SaslServerCallbackHandler;
+import org.apache.kafka.common.security.kerberos.KerberosClientCallbackHandler;
+import org.apache.kafka.common.security.kerberos.KerberosLogin;
+import org.apache.kafka.common.security.kerberos.KerberosShortNamer;
+import org.apache.kafka.common.security.plain.internal.PlainSaslServer;
+import org.apache.kafka.common.security.plain.internal.PlainServerCallbackHandler;
+import org.apache.kafka.common.security.scram.ScramCredential;
+import org.apache.kafka.common.security.scram.internal.ScramMechanism;
+import org.apache.kafka.common.security.scram.internal.ScramServerCallbackHandler;
 import org.apache.kafka.common.security.ssl.SslFactory;
+import org.apache.kafka.common.security.token.delegation.DelegationTokenCache;
 import org.apache.kafka.common.utils.Java;
+import org.apache.kafka.common.utils.Utils;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
-import java.io.IOException;
 import java.lang.reflect.InvocationTargetException;
 import java.lang.reflect.Method;
+import java.io.IOException;
 import java.net.Socket;
 import java.nio.channels.SelectionKey;
 import java.nio.channels.SocketChannel;
@@ -66,6 +79,7 @@ public class SaslChannelBuilder implements ChannelBuilder, ListenerReconfigurabl
     private SslFactory sslFactory;
     private Map<String, ?> configs;
     private KerberosShortNamer kerberosShortNamer;
+    private Map<String, AuthenticateCallbackHandler> saslCallbackHandlers;
 
     public SaslChannelBuilder(Mode mode,
                               Map<String, JaasContext> jaasContexts,
@@ -87,15 +101,25 @@ public class SaslChannelBuilder implements ChannelBuilder, ListenerReconfigurabl
         this.clientSaslMechanism = clientSaslMechanism;
         this.credentialCache = credentialCache;
         this.tokenCache = tokenCache;
+        this.saslCallbackHandlers = new HashMap<>();
     }
 
+    @SuppressWarnings("unchecked")
     @Override
     public void configure(Map<String, ?> configs) throws KafkaException {
         try {
             this.configs = configs;
-            boolean hasKerberos = jaasContexts.containsKey(SaslConfigs.GSSAPI_MECHANISM);
+            if (mode == Mode.SERVER)
+                createServerCallbackHandlers(configs);
+            else
+                createClientCallbackHandler(configs);
+            for (Map.Entry<String, AuthenticateCallbackHandler> entry : saslCallbackHandlers.entrySet()) {
+                String mechanism = entry.getKey();
+                entry.getValue().configure(configs, mechanism, jaasContexts.get(mechanism).configurationEntries());
+            }
 
-            if (hasKerberos) {
+            Class<? extends Login> defaultLoginClass = DefaultLogin.class;
+            if (jaasContexts.containsKey(SaslConfigs.GSSAPI_MECHANISM)) {
                 String defaultRealm;
                 try {
                     defaultRealm = defaultKerberosRealm();
@@ -106,12 +130,13 @@ public class SaslChannelBuilder implements ChannelBuilder, ListenerReconfigurabl
                 List<String> principalToLocalRules = (List<String>) configs.get(BrokerSecurityConfigs.SASL_KERBEROS_PRINCIPAL_TO_LOCAL_RULES_CONFIG);
                 if (principalToLocalRules != null)
                     kerberosShortNamer = KerberosShortNamer.fromUnparsedRules(defaultRealm, principalToLocalRules);
+                defaultLoginClass = KerberosLogin.class;
             }
             for (Map.Entry<String, JaasContext> entry : jaasContexts.entrySet()) {
                 String mechanism = entry.getKey();
                 // With static JAAS configuration, use KerberosLogin if Kerberos is enabled. With dynamic JAAS configuration,
                 // use KerberosLogin only for the LoginContext corresponding to GSSAPI
-                LoginManager loginManager = LoginManager.acquireLoginManager(entry.getValue(), mechanism, hasKerberos, configs);
+                LoginManager loginManager = LoginManager.acquireLoginManager(entry.getValue(), mechanism, defaultLoginClass, configs);
                 loginManagers.put(mechanism, loginManager);
                 subjects.put(mechanism, loginManager.subject());
             }
@@ -120,7 +145,7 @@ public class SaslChannelBuilder implements ChannelBuilder, ListenerReconfigurabl
                 this.sslFactory = new SslFactory(mode, "none", isInterBrokerListener);
                 this.sslFactory.configure(configs);
             }
-        } catch (Exception e) {
+        } catch (Throwable e) {
             close();
             throw new KafkaException(e);
         }
@@ -156,11 +181,20 @@ public class SaslChannelBuilder implements ChannelBuilder, ListenerReconfigurabl
             TransportLayer transportLayer = buildTransportLayer(id, key, socketChannel);
             Authenticator authenticator;
             if (mode == Mode.SERVER) {
-                authenticator = buildServerAuthenticator(configs, id, transportLayer, subjects);
+                authenticator = buildServerAuthenticator(configs,
+                        saslCallbackHandlers,
+                        id,
+                        transportLayer,
+                        subjects);
             } else {
                 LoginManager loginManager = loginManagers.get(clientSaslMechanism);
-                authenticator = buildClientAuthenticator(configs, id, socket.getInetAddress().getHostName(),
-                        loginManager.serviceName(), transportLayer, loginManager.subject());
+                authenticator = buildClientAuthenticator(configs,
+                        saslCallbackHandlers.get(clientSaslMechanism),
+                        id,
+                        socket.getInetAddress().getHostName(),
+                        loginManager.serviceName(),
+                        transportLayer,
+                        subjects.get(clientSaslMechanism));
             }
             return new KafkaChannel(id, transportLayer, authenticator, maxReceiveSize, memoryPool != null ? memoryPool : MemoryPool.NONE);
         } catch (Exception e) {
@@ -174,6 +208,8 @@ public class SaslChannelBuilder implements ChannelBuilder, ListenerReconfigurabl
         for (LoginManager loginManager : loginManagers.values())
             loginManager.release();
         loginManagers.clear();
+        for (AuthenticateCallbackHandler handler : saslCallbackHandlers.values())
+            handler.close();
     }
 
     private TransportLayer buildTransportLayer(String id, SelectionKey key, SocketChannel socketChannel) throws IOException {
@@ -186,16 +222,23 @@ public class SaslChannelBuilder implements ChannelBuilder, ListenerReconfigurabl
     }
 
     // Visible to override for testing
-    protected SaslServerAuthenticator buildServerAuthenticator(Map<String, ?> configs, String id,
-            TransportLayer transportLayer, Map<String, Subject> subjects) throws IOException {
-        return new SaslServerAuthenticator(configs, id, jaasContexts, subjects,
-                kerberosShortNamer, credentialCache, listenerName, securityProtocol, transportLayer, tokenCache);
+    protected SaslServerAuthenticator buildServerAuthenticator(Map<String, ?> configs,
+                                                               Map<String, AuthenticateCallbackHandler> callbackHandlers,
+                                                               String id,
+                                                               TransportLayer transportLayer,
+                                                               Map<String, Subject> subjects) throws IOException {
+        return new SaslServerAuthenticator(configs, callbackHandlers, id, subjects,
+                kerberosShortNamer, listenerName, securityProtocol, transportLayer);
     }
 
     // Visible to override for testing
-    protected SaslClientAuthenticator buildClientAuthenticator(Map<String, ?> configs, String id,
-            String serverHost, String servicePrincipal, TransportLayer transportLayer, Subject subject) throws IOException {
-        return new SaslClientAuthenticator(configs, id, subject, servicePrincipal,
+    protected SaslClientAuthenticator buildClientAuthenticator(Map<String, ?> configs,
+                                                               AuthenticateCallbackHandler callbackHandler,
+                                                               String id,
+                                                               String serverHost,
+                                                               String servicePrincipal,
+                                                               TransportLayer transportLayer, Subject subject) throws IOException {
+        return new SaslClientAuthenticator(configs, callbackHandler, id, subject, servicePrincipal,
                 serverHost, clientSaslMechanism, handshakeRequestEnable, transportLayer);
     }
 
@@ -224,4 +267,30 @@ public class SaslChannelBuilder implements ChannelBuilder, ListenerReconfigurabl
         getDefaultRealmMethod = classRef.getDeclaredMethod("getDefaultRealm", new Class[0]);
         return (String) getDefaultRealmMethod.invoke(kerbConf, new Object[0]);
     }
+
+    private void createClientCallbackHandler(Map<String, ?> configs) {
+        Class<? extends AuthenticateCallbackHandler> clazz = (Class<? extends AuthenticateCallbackHandler>) configs.get(SaslConfigs.SASL_CLIENT_CALLBACK_HANDLER_CLASS);
+        if (clazz == null)
+            clazz = clientSaslMechanism.equals(SaslConfigs.GSSAPI_MECHANISM) ? KerberosClientCallbackHandler.class : SaslClientCallbackHandler.class;
+        AuthenticateCallbackHandler callbackHandler = Utils.newInstance(clazz);
+        saslCallbackHandlers.put(clientSaslMechanism, callbackHandler);
+    }
+
+    private void createServerCallbackHandlers(Map<String, ?> configs) throws ClassNotFoundException {
+        for (String mechanism : jaasContexts.keySet()) {
+            AuthenticateCallbackHandler callbackHandler;
+            String prefix = ListenerName.saslMechanismPrefix(mechanism);
+            Class<? extends AuthenticateCallbackHandler> clazz =
+                    (Class<? extends AuthenticateCallbackHandler>) configs.get(prefix + BrokerSecurityConfigs.SASL_SERVER_CALLBACK_HANDLER_CLASS);
+            if (clazz != null)
+                callbackHandler = Utils.newInstance(clazz);
+            else if (mechanism.equals(PlainSaslServer.PLAIN_MECHANISM))
+                callbackHandler = new PlainServerCallbackHandler();
+            else if (ScramMechanism.isScram(mechanism))
+                callbackHandler = new ScramServerCallbackHandler(credentialCache.cache(mechanism, ScramCredential.class), tokenCache);
+            else
+                callbackHandler = new SaslServerCallbackHandler();
+            saslCallbackHandlers.put(mechanism, callbackHandler);
+        }
+    }
 }
diff --git a/clients/src/main/java/org/apache/kafka/common/security/JaasContext.java b/clients/src/main/java/org/apache/kafka/common/security/JaasContext.java
index c2bfbe4..849a978 100644
--- a/clients/src/main/java/org/apache/kafka/common/security/JaasContext.java
+++ b/clients/src/main/java/org/apache/kafka/common/security/JaasContext.java
@@ -183,7 +183,7 @@ public class JaasContext {
      * Returns the configuration option for <code>key</code> from this context.
      * If login module name is specified, return option value only from that module.
      */
-    public String configEntryOption(String key, String loginModuleName) {
+    public static String configEntryOption(List<AppConfigurationEntry> configurationEntries, String key, String loginModuleName) {
         for (AppConfigurationEntry entry : configurationEntries) {
             if (loginModuleName != null && !loginModuleName.equals(entry.getLoginModuleName()))
                 continue;
diff --git a/clients/src/main/java/org/apache/kafka/common/security/auth/AuthenticateCallbackHandler.java b/clients/src/main/java/org/apache/kafka/common/security/auth/AuthenticateCallbackHandler.java
new file mode 100644
index 0000000..8951d3a
--- /dev/null
+++ b/clients/src/main/java/org/apache/kafka/common/security/auth/AuthenticateCallbackHandler.java
@@ -0,0 +1,62 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.kafka.common.security.auth;
+
+import java.util.List;
+import java.util.Map;
+
+import javax.security.auth.callback.CallbackHandler;
+import javax.security.auth.login.AppConfigurationEntry;
+
+/**
+ * Callback handler for SASL-based authentication.
+ */
+public interface AuthenticateCallbackHandler extends CallbackHandler {
+
+    /**
+     * Configures this callback handler for the specified SASL mechanism.
+     *
+     * @param configs Key-value pairs containing the parsed configuration options of
+     *        the client or broker. Note that these are the Kafka configuration options
+     *        and not the JAAS configuration options. JAAS config options may be obtained
+     *        from `jaasConfigEntries` for callbacks which obtain some configs from the
+     *        JAAS configuration. For configs that may be specified as both Kafka config
+     *        as well as JAAS config (e.g. sasl.kerberos.service.name), the configuration
+     *        is treated as invalid if conflicting values are provided.
+     * @param saslMechanism Negotiated SASL mechanism. For clients, this is the SASL
+     *        mechanism configured for the client. For brokers, this is the mechanism
+     *        negotiated with the client and is one of the mechanisms enabled on the broker.
+     * @param jaasConfigEntries JAAS configuration entries from the JAAS login context.
+     *        This list contains a single entry for clients and may contain more than
+     *        one entry for brokers if multiple mechanisms are enabled on a listener using
+     *        static JAAS configuration where there is no mapping between mechanisms and
+     *        login module entries. In this case, callback handlers can use the login module in
+     *        `jaasConfigEntries` to identify the entry corresponding to `saslMechanism`.
+     *        Alternatively, dynamic JAAS configuration option
+     *        {@link org.apache.kafka.common.config.SaslConfigs#SASL_JAAS_CONFIG} may be
+     *        configured on brokers with listener and mechanism prefix, in which case
+     *        only the configuration entry corresponding to `saslMechanism` will be provided
+     *        in `jaasConfigEntries`.
+     */
+    void configure(Map<String, ?> configs, String saslMechanism, List<AppConfigurationEntry> jaasConfigEntries);
+
+    /**
+     * Closes this instance.
+     */
+    void close();
+}
diff --git a/clients/src/main/java/org/apache/kafka/common/security/authenticator/Login.java b/clients/src/main/java/org/apache/kafka/common/security/auth/Login.java
similarity index 55%
rename from clients/src/main/java/org/apache/kafka/common/security/authenticator/Login.java
rename to clients/src/main/java/org/apache/kafka/common/security/auth/Login.java
index b41d1b2..eda5e7a 100644
--- a/clients/src/main/java/org/apache/kafka/common/security/authenticator/Login.java
+++ b/clients/src/main/java/org/apache/kafka/common/security/auth/Login.java
@@ -14,13 +14,12 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
-package org.apache.kafka.common.security.authenticator;
-
-import org.apache.kafka.common.security.JaasContext;
+package org.apache.kafka.common.security.auth;
 
 import java.util.Map;
 
 import javax.security.auth.Subject;
+import javax.security.auth.login.Configuration;
 import javax.security.auth.login.LoginContext;
 import javax.security.auth.login.LoginException;
 
@@ -31,8 +30,21 @@ public interface Login {
 
     /**
      * Configures this login instance.
+     * @param configs Key-value pairs containing the parsed configuration options of
+     *        the client or broker. Note that these are the Kafka configuration options
+     *        and not the JAAS configuration options. The JAAS options may be obtained
+     *        from `jaasConfiguration`.
+     * @param contextName JAAS context name for this login which may be used to obtain
+     *        the login context from `jaasConfiguration`.
+     * @param jaasConfiguration JAAS configuration containing the login context named
+     *        `contextName`. If static JAAS configuration is used, this `Configuration`
+     *        may also contain other login contexts.
+     * @param loginCallbackHandler Login callback handler instance to use for this Login.
+     *        Login callback handler class may be configured using
+     *        {@link org.apache.kafka.common.config.SaslConfigs#SASL_LOGIN_CALLBACK_HANDLER_CLASS}.
      */
-    void configure(Map<String, ?> configs, JaasContext jaasContext);
+    void configure(Map<String, ?> configs, String contextName, Configuration jaasConfiguration,
+                   AuthenticateCallbackHandler loginCallbackHandler);
 
     /**
      * Performs login for each login module specified for the login context of this instance.
@@ -54,4 +66,3 @@ public interface Login {
      */
     void close();
 }
-
diff --git a/clients/src/main/java/org/apache/kafka/common/security/authenticator/AbstractLogin.java b/clients/src/main/java/org/apache/kafka/common/security/authenticator/AbstractLogin.java
index 643f859..7e13508 100644
--- a/clients/src/main/java/org/apache/kafka/common/security/authenticator/AbstractLogin.java
+++ b/clients/src/main/java/org/apache/kafka/common/security/authenticator/AbstractLogin.java
@@ -16,20 +16,23 @@
  */
 package org.apache.kafka.common.security.authenticator;
 
+import javax.security.auth.login.AppConfigurationEntry;
+import javax.security.auth.login.Configuration;
 import javax.security.auth.login.LoginContext;
 import javax.security.auth.login.LoginException;
 import javax.security.sasl.RealmCallback;
 import javax.security.auth.callback.Callback;
-import javax.security.auth.callback.CallbackHandler;
 import javax.security.auth.callback.NameCallback;
 import javax.security.auth.callback.PasswordCallback;
 import javax.security.auth.callback.UnsupportedCallbackException;
 import javax.security.auth.Subject;
 
-import org.apache.kafka.common.security.JaasContext;
+import org.apache.kafka.common.security.auth.AuthenticateCallbackHandler;
+import org.apache.kafka.common.security.auth.Login;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
+import java.util.List;
 import java.util.Map;
 
 /**
@@ -38,17 +41,22 @@ import java.util.Map;
 public abstract class AbstractLogin implements Login {
     private static final Logger log = LoggerFactory.getLogger(AbstractLogin.class);
 
-    private JaasContext jaasContext;
+    private String contextName;
+    private Configuration configuration;
     private LoginContext loginContext;
+    private AuthenticateCallbackHandler loginCallbackHandler;
 
     @Override
-    public void configure(Map<String, ?> configs, JaasContext jaasContext) {
-        this.jaasContext = jaasContext;
+    public void configure(Map<String, ?> configs, String contextName, Configuration configuration,
+                          AuthenticateCallbackHandler loginCallbackHandler) {
+        this.contextName = contextName;
+        this.configuration = configuration;
+        this.loginCallbackHandler = loginCallbackHandler;
     }
 
     @Override
     public LoginContext login() throws LoginException {
-        loginContext = new LoginContext(jaasContext.name(), null, new LoginCallbackHandler(), jaasContext.configuration());
+        loginContext = new LoginContext(contextName, null, loginCallbackHandler, configuration);
         loginContext.login();
         log.info("Successfully logged in.");
         return loginContext;
@@ -59,8 +67,12 @@ public abstract class AbstractLogin implements Login {
         return loginContext.getSubject();
     }
 
-    protected JaasContext jaasContext() {
-        return jaasContext;
+    protected String contextName() {
+        return contextName;
+    }
+
+    protected Configuration configuration() {
+        return configuration;
     }
 
     /**
@@ -70,7 +82,11 @@ public abstract class AbstractLogin implements Login {
      * callback handlers which require additional user input.
      *
      */
-    public static class LoginCallbackHandler implements CallbackHandler {
+    public static class DefaultLoginCallbackHandler implements AuthenticateCallbackHandler {
+
+        @Override
+        public void configure(Map<String, ?> configs, String saslMechanism, List<AppConfigurationEntry> jaasConfigEntries) {
+        }
 
         @Override
         public void handle(Callback[] callbacks) throws UnsupportedCallbackException {
@@ -90,6 +106,10 @@ public abstract class AbstractLogin implements Login {
                 }
             }
         }
+
+        @Override
+        public void close() {
+        }
     }
 }
 
diff --git a/clients/src/main/java/org/apache/kafka/common/security/authenticator/AuthCallbackHandler.java b/clients/src/main/java/org/apache/kafka/common/security/authenticator/AuthCallbackHandler.java
deleted file mode 100644
index d517162..0000000
--- a/clients/src/main/java/org/apache/kafka/common/security/authenticator/AuthCallbackHandler.java
+++ /dev/null
@@ -1,45 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.apache.kafka.common.security.authenticator;
-
-import java.util.Map;
-
-import org.apache.kafka.common.network.Mode;
-
-import javax.security.auth.Subject;
-import javax.security.auth.callback.CallbackHandler;
-
-/*
- * Callback handler for SASL-based authentication
- */
-public interface AuthCallbackHandler extends CallbackHandler {
-
-    /**
-     * Configures this callback handler.
-     *
-     * @param configs Configuration
-     * @param mode The mode that indicates if this is a client or server connection
-     * @param subject Subject from login context
-     * @param saslMechanism Negotiated SASL mechanism
-     */
-    void configure(Map<String, ?> configs, Mode mode, Subject subject, String saslMechanism);
-
-    /**
-     * Closes this instance.
-     */
-    void close();
-}
diff --git a/clients/src/main/java/org/apache/kafka/common/security/authenticator/LoginManager.java b/clients/src/main/java/org/apache/kafka/common/security/authenticator/LoginManager.java
index 81dc063..4ae798d 100644
--- a/clients/src/main/java/org/apache/kafka/common/security/authenticator/LoginManager.java
+++ b/clients/src/main/java/org/apache/kafka/common/security/authenticator/LoginManager.java
@@ -23,11 +23,17 @@ import java.io.IOException;
 import java.util.ArrayList;
 import java.util.HashMap;
 import java.util.Map;
+import java.util.Objects;
 
+
+import org.apache.kafka.common.config.ConfigException;
 import org.apache.kafka.common.config.SaslConfigs;
 import org.apache.kafka.common.config.types.Password;
+import org.apache.kafka.common.network.ListenerName;
 import org.apache.kafka.common.security.JaasContext;
-import org.apache.kafka.common.security.kerberos.KerberosLogin;
+import org.apache.kafka.common.security.auth.AuthenticateCallbackHandler;
+import org.apache.kafka.common.security.auth.Login;
+import org.apache.kafka.common.utils.Utils;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
@@ -36,20 +42,23 @@ public class LoginManager {
     private static final Logger LOGGER = LoggerFactory.getLogger(LoginManager.class);
 
     // static configs (broker or client)
-    private static final Map<String, LoginManager> STATIC_INSTANCES = new HashMap<>();
+    private static final Map<LoginMetadata<String>, LoginManager> STATIC_INSTANCES = new HashMap<>();
 
-    // dynamic configs (client-only)
-    private static final Map<Password, LoginManager> DYNAMIC_INSTANCES = new HashMap<>();
+    // dynamic configs (broker or client)
+    private static final Map<LoginMetadata<Password>, LoginManager> DYNAMIC_INSTANCES = new HashMap<>();
 
     private final Login login;
-    private final Object cacheKey;
+    private final LoginMetadata<?> loginMetadata;
+    private final AuthenticateCallbackHandler loginCallbackHandler;
     private int refCount;
 
-    private LoginManager(JaasContext jaasContext, boolean hasKerberos, Map<String, ?> configs,
-                         Object cacheKey) throws IOException, LoginException {
-        this.cacheKey = cacheKey;
-        login = hasKerberos ? new KerberosLogin() : new DefaultLogin();
-        login.configure(configs, jaasContext);
+    private LoginManager(JaasContext jaasContext, String saslMechanism, Map<String, ?> configs,
+                         LoginMetadata<?> loginMetadata) throws IOException, LoginException {
+        this.loginMetadata = loginMetadata;
+        this.login = Utils.newInstance(loginMetadata.loginClass);
+        loginCallbackHandler = Utils.newInstance(loginMetadata.loginCallbackClass);
+        loginCallbackHandler.configure(configs, saslMechanism, jaasContext.configurationEntries());
+        login.configure(configs, jaasContext.name(), jaasContext.configuration(), loginCallbackHandler);
         login.login();
     }
 
@@ -72,28 +81,34 @@ public class LoginManager {
      * @param saslMechanism SASL mechanism for which login manager is being acquired. For dynamic contexts, the single
      *                      login module in `jaasContext` corresponds to this SASL mechanism. Hence `Login` class is
      *                      chosen based on this mechanism.
-     * @param hasKerberos Boolean flag that indicates if Kerberos is enabled for the server listener or client. Since
-     *                    static broker configuration may contain multiple login modules in a login context, KerberosLogin
-     *                    must be used if Kerberos is enabled on the listener, even if `saslMechanism` is not GSSAPI.
+     * @param defaultLoginClass Default login class to use if an override is not specified in `configs`
      * @param configs Config options used to configure `Login` if a new login manager is created.
      *
      */
-    public static LoginManager acquireLoginManager(JaasContext jaasContext, String saslMechanism, boolean hasKerberos,
+    public static LoginManager acquireLoginManager(JaasContext jaasContext, String saslMechanism,
+                                                   Class<? extends Login> defaultLoginClass,
                                                    Map<String, ?> configs) throws IOException, LoginException {
+        Class<? extends Login> loginClass = configuredClassOrDefault(configs, jaasContext,
+                saslMechanism, SaslConfigs.SASL_LOGIN_CLASS, defaultLoginClass);
+        Class<? extends AuthenticateCallbackHandler> loginCallbackClass = configuredClassOrDefault(configs,
+                jaasContext, saslMechanism, SaslConfigs.SASL_LOGIN_CALLBACK_HANDLER_CLASS,
+                AbstractLogin.DefaultLoginCallbackHandler.class);
         synchronized (LoginManager.class) {
             LoginManager loginManager;
             Password jaasConfigValue = jaasContext.dynamicJaasConfig();
             if (jaasConfigValue != null) {
-                loginManager = DYNAMIC_INSTANCES.get(jaasConfigValue);
+                LoginMetadata<Password> loginMetadata = new LoginMetadata<>(jaasConfigValue, loginClass, loginCallbackClass);
+                loginManager = DYNAMIC_INSTANCES.get(loginMetadata);
                 if (loginManager == null) {
-                    loginManager = new LoginManager(jaasContext, saslMechanism.equals(SaslConfigs.GSSAPI_MECHANISM), configs, jaasConfigValue);
-                    DYNAMIC_INSTANCES.put(jaasConfigValue, loginManager);
+                    loginManager = new LoginManager(jaasContext, saslMechanism, configs, loginMetadata);
+                    DYNAMIC_INSTANCES.put(loginMetadata, loginManager);
                 }
             } else {
-                loginManager = STATIC_INSTANCES.get(jaasContext.name());
+                LoginMetadata<String> loginMetadata = new LoginMetadata<>(jaasContext.name(), loginClass, loginCallbackClass);
+                loginManager = STATIC_INSTANCES.get(loginMetadata);
                 if (loginManager == null) {
-                    loginManager = new LoginManager(jaasContext, hasKerberos, configs, jaasContext.name());
-                    STATIC_INSTANCES.put(jaasContext.name(), loginManager);
+                    loginManager = new LoginManager(jaasContext, saslMechanism, configs, loginMetadata);
+                    STATIC_INSTANCES.put(loginMetadata, loginManager);
                 }
             }
             return loginManager.acquire();
@@ -110,7 +125,7 @@ public class LoginManager {
 
     // Only for testing
     Object cacheKey() {
-        return cacheKey;
+        return loginMetadata.configInfo;
     }
 
     private LoginManager acquire() {
@@ -127,12 +142,13 @@ public class LoginManager {
             if (refCount == 0)
                 throw new IllegalStateException("release() called on disposed " + this);
             else if (refCount == 1) {
-                if (cacheKey instanceof Password) {
-                    DYNAMIC_INSTANCES.remove(cacheKey);
+                if (loginMetadata.configInfo instanceof Password) {
+                    DYNAMIC_INSTANCES.remove(loginMetadata);
                 } else {
-                    STATIC_INSTANCES.remove(cacheKey);
+                    STATIC_INSTANCES.remove(loginMetadata);
                 }
                 login.close();
+                loginCallbackHandler.close();
             }
             --refCount;
             LOGGER.trace("{} released", this);
@@ -150,10 +166,56 @@ public class LoginManager {
     /* Should only be used in tests. */
     public static void closeAll() {
         synchronized (LoginManager.class) {
-            for (String key : new ArrayList<>(STATIC_INSTANCES.keySet()))
+            for (LoginMetadata<String> key : new ArrayList<>(STATIC_INSTANCES.keySet()))
                 STATIC_INSTANCES.remove(key).login.close();
-            for (Password key : new ArrayList<>(DYNAMIC_INSTANCES.keySet()))
+            for (LoginMetadata<Password> key : new ArrayList<>(DYNAMIC_INSTANCES.keySet()))
                 DYNAMIC_INSTANCES.remove(key).login.close();
         }
     }
+
+    private static <T> Class<? extends T> configuredClassOrDefault(Map<String, ?> configs,
+                                                     JaasContext jaasContext,
+                                                     String saslMechanism,
+                                                     String configName,
+                                                     Class<? extends T> defaultClass) {
+        String prefix  = jaasContext.type() == JaasContext.Type.SERVER ? ListenerName.saslMechanismPrefix(saslMechanism) : "";
+        Class<? extends T> clazz = (Class<? extends T>) configs.get(prefix + configName);
+        if (clazz != null && jaasContext.configurationEntries().size() != 1) {
+            String errorMessage = configName + " cannot be specified with multiple login modules in the JAAS context. " +
+                    SaslConfigs.SASL_JAAS_CONFIG + " must be configured to override mechanism-specific configs.";
+            throw new ConfigException(errorMessage);
+        }
+        if (clazz == null)
+            clazz = defaultClass;
+        return clazz;
+    }
+
+    private static class LoginMetadata<T> {
+        final T configInfo;
+        final Class<? extends Login> loginClass;
+        final Class<? extends AuthenticateCallbackHandler> loginCallbackClass;
+
+        LoginMetadata(T configInfo, Class<? extends Login> loginClass,
+                      Class<? extends AuthenticateCallbackHandler> loginCallbackClass) {
+            this.configInfo = configInfo;
+            this.loginClass = loginClass;
+            this.loginCallbackClass = loginCallbackClass;
+        }
+
+        @Override
+        public int hashCode() {
+            return Objects.hash(configInfo, loginClass, loginCallbackClass);
+        }
+
+        @Override
+        public boolean equals(Object o) {
+            if (this == o) return true;
+            if (o == null || getClass() != o.getClass()) return false;
+
+            LoginMetadata<?> loginMetadata = (LoginMetadata<?>) o;
+            return Objects.equals(configInfo, loginMetadata.configInfo) &&
+                   Objects.equals(loginClass, loginMetadata.loginClass) &&
+                   Objects.equals(loginCallbackClass, loginMetadata.loginCallbackClass);
+        }
+    }
 }
diff --git a/clients/src/main/java/org/apache/kafka/common/security/authenticator/SaslClientAuthenticator.java b/clients/src/main/java/org/apache/kafka/common/security/authenticator/SaslClientAuthenticator.java
index 8b01165..2ef6d77 100644
--- a/clients/src/main/java/org/apache/kafka/common/security/authenticator/SaslClientAuthenticator.java
+++ b/clients/src/main/java/org/apache/kafka/common/security/authenticator/SaslClientAuthenticator.java
@@ -24,9 +24,8 @@ import org.apache.kafka.common.errors.IllegalSaslStateException;
 import org.apache.kafka.common.errors.SaslAuthenticationException;
 import org.apache.kafka.common.errors.UnsupportedSaslMechanismException;
 import org.apache.kafka.common.network.Authenticator;
-import org.apache.kafka.common.network.Mode;
-import org.apache.kafka.common.network.NetworkReceive;
 import org.apache.kafka.common.network.NetworkSend;
+import org.apache.kafka.common.network.NetworkReceive;
 import org.apache.kafka.common.network.Send;
 import org.apache.kafka.common.network.TransportLayer;
 import org.apache.kafka.common.protocol.ApiKeys;
@@ -41,6 +40,7 @@ import org.apache.kafka.common.requests.SaslAuthenticateRequest;
 import org.apache.kafka.common.requests.SaslAuthenticateResponse;
 import org.apache.kafka.common.requests.SaslHandshakeRequest;
 import org.apache.kafka.common.requests.SaslHandshakeResponse;
+import org.apache.kafka.common.security.auth.AuthenticateCallbackHandler;
 import org.apache.kafka.common.security.auth.KafkaPrincipal;
 import org.apache.kafka.common.utils.Utils;
 import org.slf4j.Logger;
@@ -87,7 +87,7 @@ public class SaslClientAuthenticator implements Authenticator {
     private final SaslClient saslClient;
     private final Map<String, ?> configs;
     private final String clientPrincipalName;
-    private final AuthCallbackHandler callbackHandler;
+    private final AuthenticateCallbackHandler callbackHandler;
 
     // buffers used in `authenticate`
     private NetworkReceive netInBuffer;
@@ -105,6 +105,7 @@ public class SaslClientAuthenticator implements Authenticator {
     private short saslAuthenticateVersion;
 
     public SaslClientAuthenticator(Map<String, ?> configs,
+                                   AuthenticateCallbackHandler callbackHandler,
                                    String node,
                                    Subject subject,
                                    String servicePrincipal,
@@ -114,6 +115,7 @@ public class SaslClientAuthenticator implements Authenticator {
                                    TransportLayer transportLayer) throws IOException {
         this.node = node;
         this.subject = subject;
+        this.callbackHandler = callbackHandler;
         this.host = host;
         this.servicePrincipal = servicePrincipal;
         this.mechanism = mechanism;
@@ -133,9 +135,6 @@ public class SaslClientAuthenticator implements Authenticator {
             else
                 this.clientPrincipalName = null;
 
-            callbackHandler = new SaslClientCallbackHandler();
-            callbackHandler.configure(configs, Mode.CLIENT, subject, mechanism);
-
             saslClient = createSaslClient();
         } catch (Exception e) {
             throw new SaslAuthenticationException("Failed to configure SaslClientAuthenticator", e);
@@ -325,8 +324,6 @@ public class SaslClientAuthenticator implements Authenticator {
     public void close() throws IOException {
         if (saslClient != null)
             saslClient.dispose();
-        if (callbackHandler != null)
-            callbackHandler.close();
     }
 
     private byte[] receiveToken() throws IOException {
diff --git a/clients/src/main/java/org/apache/kafka/common/security/authenticator/SaslClientCallbackHandler.java b/clients/src/main/java/org/apache/kafka/common/security/authenticator/SaslClientCallbackHandler.java
index 31c51c2..5b2a281 100644
--- a/clients/src/main/java/org/apache/kafka/common/security/authenticator/SaslClientCallbackHandler.java
+++ b/clients/src/main/java/org/apache/kafka/common/security/authenticator/SaslClientCallbackHandler.java
@@ -16,6 +16,8 @@
  */
 package org.apache.kafka.common.security.authenticator;
 
+import java.security.AccessController;
+import java.util.List;
 import java.util.Map;
 
 import javax.security.auth.Subject;
@@ -23,52 +25,46 @@ import javax.security.auth.callback.Callback;
 import javax.security.auth.callback.NameCallback;
 import javax.security.auth.callback.PasswordCallback;
 import javax.security.auth.callback.UnsupportedCallbackException;
+import javax.security.auth.login.AppConfigurationEntry;
 import javax.security.sasl.AuthorizeCallback;
 import javax.security.sasl.RealmCallback;
 
 import org.apache.kafka.common.config.SaslConfigs;
-import org.apache.kafka.common.network.Mode;
 import org.apache.kafka.common.security.scram.ScramExtensionsCallback;
+import org.apache.kafka.common.security.auth.AuthenticateCallbackHandler;
 
 /**
- * Callback handler for Sasl clients. The callbacks required for the SASL mechanism
+ * Default callback handler for Sasl clients. The callbacks required for the SASL mechanism
  * configured for the client should be supported by this callback handler. See
  * <a href="https://docs.oracle.com/javase/8/docs/technotes/guides/security/sasl/sasl-refguide.html">Java SASL API</a>
  * for the list of SASL callback handlers required for each SASL mechanism.
  */
-public class SaslClientCallbackHandler implements AuthCallbackHandler {
+public class SaslClientCallbackHandler implements AuthenticateCallbackHandler {
 
-    private boolean isKerberos;
-    private Subject subject;
+    private String mechanism;
 
     @Override
-    public void configure(Map<String, ?> configs, Mode mode, Subject subject, String mechanism) {
-        this.isKerberos = mechanism.equals(SaslConfigs.GSSAPI_MECHANISM);
-        this.subject = subject;
+    public void configure(Map<String, ?> configs, String saslMechanism, List<AppConfigurationEntry> jaasConfigEntries) {
+        this.mechanism  = saslMechanism;
     }
 
     @Override
     public void handle(Callback[] callbacks) throws UnsupportedCallbackException {
+        Subject subject = Subject.getSubject(AccessController.getContext());
         for (Callback callback : callbacks) {
             if (callback instanceof NameCallback) {
                 NameCallback nc = (NameCallback) callback;
-                if (!isKerberos && subject != null && !subject.getPublicCredentials(String.class).isEmpty()) {
+                if (subject != null && !subject.getPublicCredentials(String.class).isEmpty()) {
                     nc.setName(subject.getPublicCredentials(String.class).iterator().next());
                 } else
                     nc.setName(nc.getDefaultName());
             } else if (callback instanceof PasswordCallback) {
-                if (!isKerberos && subject != null && !subject.getPrivateCredentials(String.class).isEmpty()) {
+                if (subject != null && !subject.getPrivateCredentials(String.class).isEmpty()) {
                     char[] password = subject.getPrivateCredentials(String.class).iterator().next().toCharArray();
                     ((PasswordCallback) callback).setPassword(password);
                 } else {
                     String errorMessage = "Could not login: the client is being asked for a password, but the Kafka" +
                              " client code does not currently support obtaining a password from the user.";
-                    if (isKerberos) {
-                        errorMessage += " Make sure -Djava.security.auth.login.config property passed to JVM and" +
-                             " the client is configured to use a ticket cache (using" +
-                             " the JAAS configuration setting 'useTicketCache=true)'. Make sure you are using" +
-                             " FQDN of the Kafka broker you are trying to connect to.";
-                    }
                     throw new UnsupportedCallbackException(callback, errorMessage);
                 }
             } else if (callback instanceof RealmCallback) {
@@ -83,7 +79,7 @@ public class SaslClientCallbackHandler implements AuthCallbackHandler {
                     ac.setAuthorizedID(authzId);
             } else if (callback instanceof ScramExtensionsCallback) {
                 ScramExtensionsCallback sc = (ScramExtensionsCallback) callback;
-                if (!isKerberos && subject != null && !subject.getPublicCredentials(Map.class).isEmpty()) {
+                if (!SaslConfigs.GSSAPI_MECHANISM.equals(mechanism) && subject != null && !subject.getPublicCredentials(Map.class).isEmpty()) {
                     sc.extensions((Map<String, String>) subject.getPublicCredentials(Map.class).iterator().next());
                 }
             }  else {
diff --git a/clients/src/main/java/org/apache/kafka/common/security/authenticator/SaslServerAuthenticator.java b/clients/src/main/java/org/apache/kafka/common/security/authenticator/SaslServerAuthenticator.java
index 2a80e5b..5140afb 100644
--- a/clients/src/main/java/org/apache/kafka/common/security/authenticator/SaslServerAuthenticator.java
+++ b/clients/src/main/java/org/apache/kafka/common/security/authenticator/SaslServerAuthenticator.java
@@ -28,7 +28,6 @@ import org.apache.kafka.common.errors.UnsupportedVersionException;
 import org.apache.kafka.common.network.Authenticator;
 import org.apache.kafka.common.network.ChannelBuilders;
 import org.apache.kafka.common.network.ListenerName;
-import org.apache.kafka.common.network.Mode;
 import org.apache.kafka.common.network.NetworkReceive;
 import org.apache.kafka.common.network.NetworkSend;
 import org.apache.kafka.common.network.Send;
@@ -46,18 +45,15 @@ import org.apache.kafka.common.requests.SaslAuthenticateRequest;
 import org.apache.kafka.common.requests.SaslAuthenticateResponse;
 import org.apache.kafka.common.requests.SaslHandshakeRequest;
 import org.apache.kafka.common.requests.SaslHandshakeResponse;
-import org.apache.kafka.common.security.JaasContext;
+import org.apache.kafka.common.security.auth.AuthenticateCallbackHandler;
 import org.apache.kafka.common.security.auth.KafkaPrincipal;
 import org.apache.kafka.common.security.auth.KafkaPrincipalBuilder;
 import org.apache.kafka.common.security.auth.SaslAuthenticationContext;
 import org.apache.kafka.common.security.kerberos.KerberosName;
 import org.apache.kafka.common.security.kerberos.KerberosShortNamer;
-import org.apache.kafka.common.security.scram.ScramCredential;
 import org.apache.kafka.common.security.scram.ScramLoginModule;
-import org.apache.kafka.common.security.scram.ScramMechanism;
-import org.apache.kafka.common.security.scram.ScramServerCallbackHandler;
+import org.apache.kafka.common.security.scram.internal.ScramMechanism;
 import org.apache.kafka.common.utils.Utils;
-import org.apache.kafka.common.security.token.delegation.DelegationTokenCache;
 import org.ietf.jgss.GSSContext;
 import org.ietf.jgss.GSSCredential;
 import org.ietf.jgss.GSSException;
@@ -102,14 +98,12 @@ public class SaslServerAuthenticator implements Authenticator {
     private final SecurityProtocol securityProtocol;
     private final ListenerName listenerName;
     private final String connectionId;
-    private final Map<String, JaasContext> jaasContexts;
     private final Map<String, Subject> subjects;
-    private final CredentialCache credentialCache;
     private final TransportLayer transportLayer;
     private final Set<String> enabledMechanisms;
     private final Map<String, ?> configs;
     private final KafkaPrincipalBuilder principalBuilder;
-    private final DelegationTokenCache tokenCache;
+    private final Map<String, AuthenticateCallbackHandler> callbackHandlers;
 
     // Current SASL state
     private SaslState saslState = SaslState.INITIAL_REQUEST;
@@ -119,7 +113,6 @@ public class SaslServerAuthenticator implements Authenticator {
     private AuthenticationException pendingException = null;
     private SaslServer saslServer;
     private String saslMechanism;
-    private AuthCallbackHandler callbackHandler;
 
     // buffers used in `authenticate`
     private NetworkReceive netInBuffer;
@@ -128,23 +121,19 @@ public class SaslServerAuthenticator implements Authenticator {
     private boolean enableKafkaSaslAuthenticateHeaders;
 
     public SaslServerAuthenticator(Map<String, ?> configs,
+                                   Map<String, AuthenticateCallbackHandler> callbackHandlers,
                                    String connectionId,
-                                   Map<String, JaasContext> jaasContexts,
                                    Map<String, Subject> subjects,
                                    KerberosShortNamer kerberosNameParser,
-                                   CredentialCache credentialCache,
                                    ListenerName listenerName,
                                    SecurityProtocol securityProtocol,
-                                   TransportLayer transportLayer,
-                                   DelegationTokenCache tokenCache) throws IOException {
+                                   TransportLayer transportLayer) throws IOException {
+        this.callbackHandlers = callbackHandlers;
         this.connectionId = connectionId;
-        this.jaasContexts = jaasContexts;
         this.subjects = subjects;
-        this.credentialCache = credentialCache;
         this.listenerName = listenerName;
         this.securityProtocol = securityProtocol;
         this.enableKafkaSaslAuthenticateHeaders = false;
-        this.tokenCache = tokenCache;
         this.transportLayer = transportLayer;
 
         this.configs = configs;
@@ -154,8 +143,8 @@ public class SaslServerAuthenticator implements Authenticator {
             throw new IllegalArgumentException("No SASL mechanisms are enabled");
         this.enabledMechanisms = new HashSet<>(enabledMechanisms);
         for (String mechanism : enabledMechanisms) {
-            if (!jaasContexts.containsKey(mechanism))
-                throw new IllegalArgumentException("Jaas context not specified for SASL mechanism " + mechanism);
+            if (!callbackHandlers.containsKey(mechanism))
+                throw new IllegalArgumentException("Callback handler not specified for SASL mechanism " + mechanism);
             if (!subjects.containsKey(mechanism))
                 throw new IllegalArgumentException("Subject cannot be null for SASL mechanism " + mechanism);
         }
@@ -168,11 +157,7 @@ public class SaslServerAuthenticator implements Authenticator {
     private void createSaslServer(String mechanism) throws IOException {
         this.saslMechanism = mechanism;
         Subject subject = subjects.get(mechanism);
-        if (!ScramMechanism.isScram(mechanism))
-            callbackHandler = new SaslServerCallbackHandler(jaasContexts.get(mechanism));
-        else
-            callbackHandler = new ScramServerCallbackHandler(credentialCache.cache(mechanism, ScramCredential.class), tokenCache);
-        callbackHandler.configure(configs, Mode.SERVER, subject, saslMechanism);
+        final AuthenticateCallbackHandler callbackHandler = callbackHandlers.get(mechanism);
         if (mechanism.equals(SaslConfigs.GSSAPI_MECHANISM)) {
             saslServer = createSaslKerberosServer(callbackHandler, configs, subject);
         } else {
@@ -189,7 +174,7 @@ public class SaslServerAuthenticator implements Authenticator {
         }
     }
 
-    private SaslServer createSaslKerberosServer(final AuthCallbackHandler saslServerCallbackHandler, final Map<String, ?> configs, Subject subject) throws IOException {
+    private SaslServer createSaslKerberosServer(final AuthenticateCallbackHandler saslServerCallbackHandler, final Map<String, ?> configs, Subject subject) throws IOException {
         // server is using a JAAS-authenticated subject: determine service principal name and hostname from kafka server's subject.
         final String servicePrincipal = SaslClientAuthenticator.firstPrincipal(subject);
         KerberosName kerberosName;
@@ -316,8 +301,6 @@ public class SaslServerAuthenticator implements Authenticator {
             Utils.closeQuietly((Closeable) principalBuilder, "principal builder");
         if (saslServer != null)
             saslServer.dispose();
-        if (callbackHandler != null)
-            callbackHandler.close();
     }
 
     private void setSaslState(SaslState saslState) throws IOException {
diff --git a/clients/src/main/java/org/apache/kafka/common/security/authenticator/SaslServerCallbackHandler.java b/clients/src/main/java/org/apache/kafka/common/security/authenticator/SaslServerCallbackHandler.java
index 7d5372d..d3d43cb 100644
--- a/clients/src/main/java/org/apache/kafka/common/security/authenticator/SaslServerCallbackHandler.java
+++ b/clients/src/main/java/org/apache/kafka/common/security/authenticator/SaslServerCallbackHandler.java
@@ -16,51 +16,46 @@
  */
 package org.apache.kafka.common.security.authenticator;
 
-import java.io.IOException;
+import java.util.List;
 import java.util.Map;
 
-import org.apache.kafka.common.security.JaasContext;
-import org.slf4j.Logger;
-import org.slf4j.LoggerFactory;
-
-import javax.security.auth.Subject;
 import javax.security.auth.callback.Callback;
 import javax.security.auth.callback.UnsupportedCallbackException;
+import javax.security.auth.login.AppConfigurationEntry;
 import javax.security.sasl.AuthorizeCallback;
 import javax.security.sasl.RealmCallback;
 
-import org.apache.kafka.common.network.Mode;
+import org.apache.kafka.common.config.SaslConfigs;
+import org.apache.kafka.common.security.auth.AuthenticateCallbackHandler;
+
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
 
 /**
- * Callback handler for Sasl servers. The callbacks required for all the SASL
+ * Default callback handler for Sasl servers. The callbacks required for all the SASL
  * mechanisms enabled in the server should be supported by this callback handler. See
  * <a href="https://docs.oracle.com/javase/8/docs/technotes/guides/security/sasl/sasl-refguide.html">Java SASL API</a>
  * for the list of SASL callback handlers required for each SASL mechanism.
  */
-public class SaslServerCallbackHandler implements AuthCallbackHandler {
+public class SaslServerCallbackHandler implements AuthenticateCallbackHandler {
     private static final Logger LOG = LoggerFactory.getLogger(SaslServerCallbackHandler.class);
-    private final JaasContext jaasContext;
 
-    public SaslServerCallbackHandler(JaasContext jaasContext) throws IOException {
-        this.jaasContext = jaasContext;
-    }
+    private String mechanism;
 
     @Override
-    public void configure(Map<String, ?> configs, Mode mode, Subject subject, String saslMechanism) {
-    }
-
-    public JaasContext jaasContext() {
-        return jaasContext;
+    public void configure(Map<String, ?> configs, String mechanism, List<AppConfigurationEntry> jaasConfigEntries) {
+        this.mechanism = mechanism;
     }
 
     @Override
     public void handle(Callback[] callbacks) throws UnsupportedCallbackException {
         for (Callback callback : callbacks) {
-            if (callback instanceof RealmCallback) {
+            if (callback instanceof RealmCallback)
                 handleRealmCallback((RealmCallback) callback);
-            } else if (callback instanceof AuthorizeCallback) {
+            else if (callback instanceof AuthorizeCallback && mechanism.equals(SaslConfigs.GSSAPI_MECHANISM))
                 handleAuthorizeCallback((AuthorizeCallback) callback);
-            }
+            else
+                throw new UnsupportedCallbackException(callback);
         }
     }
 
diff --git a/clients/src/main/java/org/apache/kafka/common/security/authenticator/SaslClientCallbackHandler.java b/clients/src/main/java/org/apache/kafka/common/security/kerberos/KerberosClientCallbackHandler.java
similarity index 53%
copy from clients/src/main/java/org/apache/kafka/common/security/authenticator/SaslClientCallbackHandler.java
copy to clients/src/main/java/org/apache/kafka/common/security/kerberos/KerberosClientCallbackHandler.java
index 31c51c2..fa9cad2 100644
--- a/clients/src/main/java/org/apache/kafka/common/security/authenticator/SaslClientCallbackHandler.java
+++ b/clients/src/main/java/org/apache/kafka/common/security/kerberos/KerberosClientCallbackHandler.java
@@ -14,37 +14,30 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
-package org.apache.kafka.common.security.authenticator;
+package org.apache.kafka.common.security.kerberos;
 
-import java.util.Map;
+import org.apache.kafka.common.config.SaslConfigs;
+import org.apache.kafka.common.security.auth.AuthenticateCallbackHandler;
 
-import javax.security.auth.Subject;
 import javax.security.auth.callback.Callback;
 import javax.security.auth.callback.NameCallback;
 import javax.security.auth.callback.PasswordCallback;
 import javax.security.auth.callback.UnsupportedCallbackException;
+import javax.security.auth.login.AppConfigurationEntry;
 import javax.security.sasl.AuthorizeCallback;
 import javax.security.sasl.RealmCallback;
-
-import org.apache.kafka.common.config.SaslConfigs;
-import org.apache.kafka.common.network.Mode;
-import org.apache.kafka.common.security.scram.ScramExtensionsCallback;
+import java.util.List;
+import java.util.Map;
 
 /**
- * Callback handler for Sasl clients. The callbacks required for the SASL mechanism
- * configured for the client should be supported by this callback handler. See
- * <a href="https://docs.oracle.com/javase/8/docs/technotes/guides/security/sasl/sasl-refguide.html">Java SASL API</a>
- * for the list of SASL callback handlers required for each SASL mechanism.
+ * Callback handler for SASL/GSSAPI clients.
  */
-public class SaslClientCallbackHandler implements AuthCallbackHandler {
-
-    private boolean isKerberos;
-    private Subject subject;
+public class KerberosClientCallbackHandler implements AuthenticateCallbackHandler {
 
     @Override
-    public void configure(Map<String, ?> configs, Mode mode, Subject subject, String mechanism) {
-        this.isKerberos = mechanism.equals(SaslConfigs.GSSAPI_MECHANISM);
-        this.subject = subject;
+    public void configure(Map<String, ?> configs, String saslMechanism, List<AppConfigurationEntry> jaasConfigEntries) {
+        if (!saslMechanism.equals(SaslConfigs.GSSAPI_MECHANISM))
+            throw new IllegalStateException("Kerberos callback handler should only be used with GSSAPI");
     }
 
     @Override
@@ -52,25 +45,15 @@ public class SaslClientCallbackHandler implements AuthCallbackHandler {
         for (Callback callback : callbacks) {
             if (callback instanceof NameCallback) {
                 NameCallback nc = (NameCallback) callback;
-                if (!isKerberos && subject != null && !subject.getPublicCredentials(String.class).isEmpty()) {
-                    nc.setName(subject.getPublicCredentials(String.class).iterator().next());
-                } else
-                    nc.setName(nc.getDefaultName());
+                nc.setName(nc.getDefaultName());
             } else if (callback instanceof PasswordCallback) {
-                if (!isKerberos && subject != null && !subject.getPrivateCredentials(String.class).isEmpty()) {
-                    char[] password = subject.getPrivateCredentials(String.class).iterator().next().toCharArray();
-                    ((PasswordCallback) callback).setPassword(password);
-                } else {
-                    String errorMessage = "Could not login: the client is being asked for a password, but the Kafka" +
+                String errorMessage = "Could not login: the client is being asked for a password, but the Kafka" +
                              " client code does not currently support obtaining a password from the user.";
-                    if (isKerberos) {
-                        errorMessage += " Make sure -Djava.security.auth.login.config property passed to JVM and" +
+                errorMessage += " Make sure -Djava.security.auth.login.config property passed to JVM and" +
                              " the client is configured to use a ticket cache (using" +
                              " the JAAS configuration setting 'useTicketCache=true)'. Make sure you are using" +
                              " FQDN of the Kafka broker you are trying to connect to.";
-                    }
-                    throw new UnsupportedCallbackException(callback, errorMessage);
-                }
+                throw new UnsupportedCallbackException(callback, errorMessage);
             } else if (callback instanceof RealmCallback) {
                 RealmCallback rc = (RealmCallback) callback;
                 rc.setText(rc.getDefaultText());
@@ -81,11 +64,6 @@ public class SaslClientCallbackHandler implements AuthCallbackHandler {
                 ac.setAuthorized(authId.equals(authzId));
                 if (ac.isAuthorized())
                     ac.setAuthorizedID(authzId);
-            } else if (callback instanceof ScramExtensionsCallback) {
-                ScramExtensionsCallback sc = (ScramExtensionsCallback) callback;
-                if (!isKerberos && subject != null && !subject.getPublicCredentials(Map.class).isEmpty()) {
-                    sc.extensions((Map<String, String>) subject.getPublicCredentials(Map.class).iterator().next());
-                }
             }  else {
                 throw new UnsupportedCallbackException(callback, "Unrecognized SASL ClientCallback");
             }
diff --git a/clients/src/main/java/org/apache/kafka/common/security/kerberos/KerberosLogin.java b/clients/src/main/java/org/apache/kafka/common/security/kerberos/KerberosLogin.java
index 65c3b1c..ec996a8 100644
--- a/clients/src/main/java/org/apache/kafka/common/security/kerberos/KerberosLogin.java
+++ b/clients/src/main/java/org/apache/kafka/common/security/kerberos/KerberosLogin.java
@@ -18,6 +18,7 @@ package org.apache.kafka.common.security.kerberos;
 
 import javax.security.auth.kerberos.KerberosPrincipal;
 import javax.security.auth.login.AppConfigurationEntry;
+import javax.security.auth.login.Configuration;
 import javax.security.auth.login.LoginContext;
 import javax.security.auth.login.LoginException;
 import javax.security.auth.kerberos.KerberosTicket;
@@ -25,6 +26,7 @@ import javax.security.auth.Subject;
 
 import org.apache.kafka.common.security.JaasContext;
 import org.apache.kafka.common.security.JaasUtils;
+import org.apache.kafka.common.security.auth.AuthenticateCallbackHandler;
 import org.apache.kafka.common.security.authenticator.AbstractLogin;
 import org.apache.kafka.common.config.SaslConfigs;
 import org.apache.kafka.common.utils.KafkaThread;
@@ -33,11 +35,12 @@ import org.apache.kafka.common.utils.Time;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
+import java.util.Arrays;
 import java.util.Date;
 import java.util.List;
+import java.util.Map;
 import java.util.Random;
 import java.util.Set;
-import java.util.Map;
 
 /**
  * This class is responsible for refreshing Kerberos credentials for
@@ -78,13 +81,15 @@ public class KerberosLogin extends AbstractLogin {
     private String serviceName;
     private long lastLogin;
 
-    public void configure(Map<String, ?> configs, JaasContext jaasContext) {
-        super.configure(configs, jaasContext);
+    @Override
+    public void configure(Map<String, ?> configs, String contextName, Configuration configuration,
+                          AuthenticateCallbackHandler callbackHandler) {
+        super.configure(configs, contextName, configuration, callbackHandler);
         this.ticketRenewWindowFactor = (Double) configs.get(SaslConfigs.SASL_KERBEROS_TICKET_RENEW_WINDOW_FACTOR);
         this.ticketRenewJitter = (Double) configs.get(SaslConfigs.SASL_KERBEROS_TICKET_RENEW_JITTER);
         this.minTimeBeforeRelogin = (Long) configs.get(SaslConfigs.SASL_KERBEROS_MIN_TIME_BEFORE_RELOGIN);
         this.kinitCmd = (String) configs.get(SaslConfigs.SASL_KERBEROS_KINIT_CMD);
-        this.serviceName = getServiceName(configs, jaasContext);
+        this.serviceName = getServiceName(configs, contextName, configuration);
     }
 
     /**
@@ -99,13 +104,13 @@ public class KerberosLogin extends AbstractLogin {
         subject = loginContext.getSubject();
         isKrbTicket = !subject.getPrivateCredentials(KerberosTicket.class).isEmpty();
 
-        List<AppConfigurationEntry> entries = jaasContext().configurationEntries();
-        if (entries.isEmpty()) {
+        AppConfigurationEntry[] entries = configuration().getAppConfigurationEntry(contextName());
+        if (entries.length == 0) {
             isUsingTicketCache = false;
             principal = null;
         } else {
             // there will only be a single entry
-            AppConfigurationEntry entry = entries.get(0);
+            AppConfigurationEntry entry = entries[0];
             if (entry.getOptions().get("useTicketCache") != null) {
                 String val = (String) entry.getOptions().get("useTicketCache");
                 isUsingTicketCache = val.equals("true");
@@ -280,8 +285,9 @@ public class KerberosLogin extends AbstractLogin {
         return serviceName;
     }
 
-    private static String getServiceName(Map<String, ?> configs, JaasContext jaasContext) {
-        String jaasServiceName = jaasContext.configEntryOption(JaasUtils.SERVICE_NAME, null);
+    private static String getServiceName(Map<String, ?> configs, String contextName, Configuration configuration) {
+        List<AppConfigurationEntry> configEntries = Arrays.asList(configuration.getAppConfigurationEntry(contextName));
+        String jaasServiceName = JaasContext.configEntryOption(configEntries, JaasUtils.SERVICE_NAME, null);
         String configServiceName = (String) configs.get(SaslConfigs.SASL_KERBEROS_SERVICE_NAME);
         if (jaasServiceName != null && configServiceName != null && !jaasServiceName.equals(configServiceName)) {
             String message = String.format("Conflicting serviceName values found in JAAS and Kafka configs " +
@@ -360,7 +366,7 @@ public class KerberosLogin extends AbstractLogin {
             loginContext.logout();
             //login and also update the subject field of this instance to
             //have the new credentials (pass it to the LoginContext constructor)
-            loginContext = new LoginContext(jaasContext().name(), subject, null, jaasContext().configuration());
+            loginContext = new LoginContext(contextName(), subject, null, configuration());
             log.info("Initiating re-login for {}", principal);
             loginContext.login();
         }
diff --git a/clients/src/main/java/org/apache/kafka/common/security/plain/PlainAuthenticateCallback.java b/clients/src/main/java/org/apache/kafka/common/security/plain/PlainAuthenticateCallback.java
new file mode 100644
index 0000000..7f42645
--- /dev/null
+++ b/clients/src/main/java/org/apache/kafka/common/security/plain/PlainAuthenticateCallback.java
@@ -0,0 +1,63 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.kafka.common.security.plain;
+
+import javax.security.auth.callback.Callback;
+
+/**
+ * Authentication callback for SASL/PLAIN authentication. The server-side callback
+ * handler must set the authenticated flag to true if the password provided by the
+ * client in this callback matches the expected password.
+ */
+public class PlainAuthenticateCallback implements Callback {
+    private final char[] password;
+    private boolean authenticated;
+
+    /**
+     * Creates a callback with the password provided by the client.
+     * The array is stored as-is; it is not copied.
+     *
+     * @param password The password provided by the client during SASL/PLAIN authentication
+     */
+    public PlainAuthenticateCallback(char[] password) {
+        this.password = password;
+    }
+
+    /**
+     * Returns the password provided by the client during SASL/PLAIN authentication.
+     *
+     * @return the array that was passed to the constructor (not a copy)
+     */
+    public char[] password() {
+        return password;
+    }
+
+    /**
+     * Returns true if client password matches expected password, false otherwise.
+     * This state is set by the server-side callback handler; it is false until then.
+     *
+     * @return whether the client-provided password was accepted
+     */
+    public boolean authenticated() {
+        return this.authenticated;
+    }
+
+    /**
+     * Sets the authenticated state. This is set by the server-side callback handler
+     * by matching the client provided password with the expected password.
+     *
+     * @param authenticated true indicates successful authentication
+     */
+    public void authenticated(boolean authenticated) {
+        this.authenticated = authenticated;
+    }
+}
diff --git a/clients/src/main/java/org/apache/kafka/common/security/plain/PlainLoginModule.java b/clients/src/main/java/org/apache/kafka/common/security/plain/PlainLoginModule.java
index c8b29fc..f0a5971 100644
--- a/clients/src/main/java/org/apache/kafka/common/security/plain/PlainLoginModule.java
+++ b/clients/src/main/java/org/apache/kafka/common/security/plain/PlainLoginModule.java
@@ -16,6 +16,8 @@
  */
 package org.apache.kafka.common.security.plain;
 
+import org.apache.kafka.common.security.plain.internal.PlainSaslServerProvider;
+
 import java.util.Map;
 
 import javax.security.auth.Subject;
diff --git a/clients/src/main/java/org/apache/kafka/common/security/plain/PlainSaslServer.java b/clients/src/main/java/org/apache/kafka/common/security/plain/internal/PlainSaslServer.java
similarity index 87%
rename from clients/src/main/java/org/apache/kafka/common/security/plain/PlainSaslServer.java
rename to clients/src/main/java/org/apache/kafka/common/security/plain/internal/PlainSaslServer.java
index e54887f..811d9e9 100644
--- a/clients/src/main/java/org/apache/kafka/common/security/plain/PlainSaslServer.java
+++ b/clients/src/main/java/org/apache/kafka/common/security/plain/internal/PlainSaslServer.java
@@ -14,21 +14,22 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
-package org.apache.kafka.common.security.plain;
+package org.apache.kafka.common.security.plain.internal;
 
 import java.io.UnsupportedEncodingException;
 import java.util.Arrays;
 import java.util.Map;
 
+import javax.security.auth.callback.Callback;
 import javax.security.auth.callback.CallbackHandler;
+import javax.security.auth.callback.NameCallback;
 import javax.security.sasl.Sasl;
 import javax.security.sasl.SaslException;
 import javax.security.sasl.SaslServer;
 import javax.security.sasl.SaslServerFactory;
 
 import org.apache.kafka.common.errors.SaslAuthenticationException;
-import org.apache.kafka.common.security.JaasContext;
-import org.apache.kafka.common.security.authenticator.SaslServerCallbackHandler;
+import org.apache.kafka.common.security.plain.PlainAuthenticateCallback;
 
 /**
  * Simple SaslServer implementation for SASL/PLAIN. In order to make this implementation
@@ -46,15 +47,13 @@ import org.apache.kafka.common.security.authenticator.SaslServerCallbackHandler;
 public class PlainSaslServer implements SaslServer {
 
     public static final String PLAIN_MECHANISM = "PLAIN";
-    private static final String JAAS_USER_PREFIX = "user_";
-
-    private final JaasContext jaasContext;
 
+    private final CallbackHandler callbackHandler;
     private boolean complete;
     private String authorizationId;
 
-    public PlainSaslServer(JaasContext jaasContext) {
-        this.jaasContext = jaasContext;
+    public PlainSaslServer(CallbackHandler callbackHandler) {
+        this.callbackHandler = callbackHandler;
     }
 
     /**
@@ -101,12 +100,15 @@ public class PlainSaslServer implements SaslServer {
             throw new SaslException("Authentication failed: password not specified");
         }
 
-        String expectedPassword = jaasContext.configEntryOption(JAAS_USER_PREFIX + username,
-                PlainLoginModule.class.getName());
-        if (!password.equals(expectedPassword)) {
-            throw new SaslAuthenticationException("Authentication failed: Invalid username or password");
+        NameCallback nameCallback = new NameCallback("username", username);
+        PlainAuthenticateCallback authenticateCallback = new PlainAuthenticateCallback(password.toCharArray());
+        try {
+            callbackHandler.handle(new Callback[]{nameCallback, authenticateCallback});
+        } catch (Throwable e) {
+            throw new SaslAuthenticationException("Authentication failed: credentials for user could not be verified", e);
         }
-
+        if (!authenticateCallback.authenticated())
+            throw new SaslAuthenticationException("Authentication failed: Invalid username or password");
         if (!authorizationIdFromClient.isEmpty() && !authorizationIdFromClient.equals(username))
             throw new SaslAuthenticationException("Authentication failed: Client requested an authorization id that is different from username");
 
@@ -167,10 +169,7 @@ public class PlainSaslServer implements SaslServer {
             if (!PLAIN_MECHANISM.equals(mechanism))
                 throw new SaslException(String.format("Mechanism \'%s\' is not supported. Only PLAIN is supported.", mechanism));
 
-            if (!(cbh instanceof SaslServerCallbackHandler))
-                throw new SaslException("CallbackHandler must be of type SaslServerCallbackHandler, but it is: " + cbh.getClass());
-
-            return new PlainSaslServer(((SaslServerCallbackHandler) cbh).jaasContext());
+            return new PlainSaslServer(cbh);
         }
 
         @Override
diff --git a/clients/src/main/java/org/apache/kafka/common/security/plain/PlainSaslServerProvider.java b/clients/src/main/java/org/apache/kafka/common/security/plain/internal/PlainSaslServerProvider.java
similarity index 90%
rename from clients/src/main/java/org/apache/kafka/common/security/plain/PlainSaslServerProvider.java
rename to clients/src/main/java/org/apache/kafka/common/security/plain/internal/PlainSaslServerProvider.java
index ae14244..c222953 100644
--- a/clients/src/main/java/org/apache/kafka/common/security/plain/PlainSaslServerProvider.java
+++ b/clients/src/main/java/org/apache/kafka/common/security/plain/internal/PlainSaslServerProvider.java
@@ -14,12 +14,12 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
-package org.apache.kafka.common.security.plain;
+package org.apache.kafka.common.security.plain.internal;
 
 import java.security.Provider;
 import java.security.Security;
 
-import org.apache.kafka.common.security.plain.PlainSaslServer.PlainSaslServerFactory;
+import org.apache.kafka.common.security.plain.internal.PlainSaslServer.PlainSaslServerFactory;
 
 public class PlainSaslServerProvider extends Provider {
 
diff --git a/clients/src/main/java/org/apache/kafka/common/security/plain/internal/PlainServerCallbackHandler.java b/clients/src/main/java/org/apache/kafka/common/security/plain/internal/PlainServerCallbackHandler.java
new file mode 100644
index 0000000..84fbdfd
--- /dev/null
+++ b/clients/src/main/java/org/apache/kafka/common/security/plain/internal/PlainServerCallbackHandler.java
@@ -0,0 +1,76 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.kafka.common.security.plain.internal;
+
+import org.apache.kafka.common.security.JaasContext;
+import org.apache.kafka.common.security.auth.AuthenticateCallbackHandler;
+import org.apache.kafka.common.KafkaException;
+import org.apache.kafka.common.security.plain.PlainAuthenticateCallback;
+import org.apache.kafka.common.security.plain.PlainLoginModule;
+
+import java.io.IOException;
+import java.util.Arrays;
+import java.util.List;
+import java.util.Map;
+
+import javax.security.auth.callback.Callback;
+import javax.security.auth.callback.NameCallback;
+import javax.security.auth.callback.UnsupportedCallbackException;
+import javax.security.auth.login.AppConfigurationEntry;
+
+/**
+ * Default server-side callback handler for SASL/PLAIN. The expected password for a user is
+ * the JAAS option {@code user_<username>}, looked up with
+ * {@link JaasContext#configEntryOption} for the {@link PlainLoginModule} login module in the
+ * JAAS configuration entries passed to {@link #configure(Map, String, List)}.
+ */
+public class PlainServerCallbackHandler implements AuthenticateCallbackHandler {
+
+    /** Prefix of the JAAS options holding per-user passwords, e.g. {@code user_alice="secret"}. */
+    private static final String JAAS_USER_PREFIX = "user_";
+    private List<AppConfigurationEntry> jaasConfigEntries;
+
+    /**
+     * Stores the JAAS configuration entries used to look up expected passwords.
+     * The configs and mechanism arguments are not used by this handler.
+     */
+    @Override
+    public void configure(Map<String, ?> configs, String mechanism, List<AppConfigurationEntry> jaasConfigEntries) {
+        this.jaasConfigEntries = jaasConfigEntries;
+    }
+
+    /**
+     * Handles SASL/PLAIN server callbacks. Callbacks are processed in array order, so a
+     * {@link NameCallback} must precede the {@link PlainAuthenticateCallback}; otherwise no
+     * username is known and the authenticate callback is marked as not authenticated.
+     *
+     * @throws UnsupportedCallbackException if any other callback type is supplied
+     */
+    @Override
+    public void handle(Callback[] callbacks) throws IOException, UnsupportedCallbackException {
+        String username = null;
+        for (Callback callback: callbacks) {
+            if (callback instanceof NameCallback)
+                username = ((NameCallback) callback).getDefaultName();
+            else if (callback instanceof PlainAuthenticateCallback) {
+                PlainAuthenticateCallback plainCallback = (PlainAuthenticateCallback) callback;
+                boolean authenticated = authenticate(username, plainCallback.password());
+                plainCallback.authenticated(authenticated);
+            } else
+                throw new UnsupportedCallbackException(callback);
+        }
+    }
+
+    /**
+     * Returns true if {@code password} matches the password configured for {@code username}.
+     * The comparison does not stop at the first differing character, so response timing does
+     * not reveal how much of a guessed password is correct.
+     *
+     * @param username the client-supplied username, or null if none was provided
+     * @param password the client-supplied password
+     * @return false if the username is null, has no configured password, or the password differs
+     */
+    protected boolean authenticate(String username, char[] password) throws IOException {
+        if (username == null)
+            return false;
+        String expectedPassword = JaasContext.configEntryOption(jaasConfigEntries,
+                JAAS_USER_PREFIX + username,
+                PlainLoginModule.class.getName());
+        if (expectedPassword == null)
+            return false;
+        char[] expected = expectedPassword.toCharArray();
+        try {
+            return isEqualConstantTime(password, expected);
+        } finally {
+            // Wipe our temporary copy; the configured String itself cannot be cleared.
+            Arrays.fill(expected, '\0');
+        }
+    }
+
+    /**
+     * Compares two char arrays without short-circuiting on the first mismatching character,
+     * unlike {@link Arrays#equals(char[], char[])}. A length mismatch still returns early,
+     * so the password length is not hidden.
+     */
+    private static boolean isEqualConstantTime(char[] first, char[] second) {
+        if (first == second)
+            return true;
+        if (first == null || second == null || first.length != second.length)
+            return false;
+        int diff = 0;
+        for (int i = 0; i < first.length; i++)
+            diff |= first[i] ^ second[i];
+        return diff == 0;
+    }
+
+    @Override
+    public void close() throws KafkaException {
+    }
+
+}
diff --git a/clients/src/main/java/org/apache/kafka/common/security/scram/ScramCredential.java b/clients/src/main/java/org/apache/kafka/common/security/scram/ScramCredential.java
index 09ff0aa..dfbfef1 100644
--- a/clients/src/main/java/org/apache/kafka/common/security/scram/ScramCredential.java
+++ b/clients/src/main/java/org/apache/kafka/common/security/scram/ScramCredential.java
@@ -16,6 +16,11 @@
  */
 package org.apache.kafka.common.security.scram;
 
+/**
+ * SCRAM credential class that encapsulates the credential data persisted for each user that is
+ * accessible to the server. See <a href="https://tools.ietf.org/html/rfc5802#section-5">RFC rfc5802</a>
+ * for details.
+ */
 public class ScramCredential {
 
     private final byte[] salt;
@@ -23,6 +28,9 @@ public class ScramCredential {
     private final byte[] storedKey;
     private final int iterations;
 
+    /**
+     * Constructs a new credential.
+     */
     public ScramCredential(byte[] salt, byte[] storedKey, byte[] serverKey, int iterations) {
         this.salt = salt;
         this.serverKey = serverKey;
@@ -30,18 +38,30 @@ public class ScramCredential {
         this.iterations = iterations;
     }
 
+    /**
+     * Returns the salt used to process this credential using the SCRAM algorithm.
+     */
     public byte[] salt() {
         return salt;
     }
 
+    /**
+     * Server key computed from the client password using the SCRAM algorithm.
+     */
     public byte[] serverKey() {
         return serverKey;
     }
 
+    /**
+     * Stored key computed from the client password using the SCRAM algorithm.
+     */
     public byte[] storedKey() {
         return storedKey;
     }
 
+    /**
+     * Number of iterations used to process this credential using the SCRAM algorithm.
+     */
     public int iterations() {
         return iterations;
     }
diff --git a/clients/src/main/java/org/apache/kafka/common/security/scram/ScramCredentialCallback.java b/clients/src/main/java/org/apache/kafka/common/security/scram/ScramCredentialCallback.java
index 931210a..d5988cb 100644
--- a/clients/src/main/java/org/apache/kafka/common/security/scram/ScramCredentialCallback.java
+++ b/clients/src/main/java/org/apache/kafka/common/security/scram/ScramCredentialCallback.java
@@ -18,14 +18,23 @@ package org.apache.kafka.common.security.scram;
 
 import javax.security.auth.callback.Callback;
 
+/**
+ * Callback used for SCRAM mechanisms.
+ */
 public class ScramCredentialCallback implements Callback {
     private ScramCredential scramCredential;
 
-    public ScramCredential scramCredential() {
-        return scramCredential;
-    }
-
+    /**
+     * Sets the SCRAM credential for this instance.
+     */
     public void scramCredential(ScramCredential scramCredential) {
         this.scramCredential = scramCredential;
     }
-}
\ No newline at end of file
+
+    /**
+     * Returns the SCRAM credential if set on this instance.
+     */
+    public ScramCredential scramCredential() {
+        return scramCredential;
+    }
+}
diff --git a/clients/src/main/java/org/apache/kafka/common/security/scram/ScramExtensionsCallback.java b/clients/src/main/java/org/apache/kafka/common/security/scram/ScramExtensionsCallback.java
index b40468b..debe163 100644
--- a/clients/src/main/java/org/apache/kafka/common/security/scram/ScramExtensionsCallback.java
+++ b/clients/src/main/java/org/apache/kafka/common/security/scram/ScramExtensionsCallback.java
@@ -21,13 +21,25 @@ import javax.security.auth.callback.Callback;
 import java.util.Collections;
 import java.util.Map;
 
+/**
+ * Optional callback used for SCRAM mechanisms if any extensions need to be set
+ * in the SASL/SCRAM exchange.
+ */
 public class ScramExtensionsCallback implements Callback {
     private Map<String, String> extensions = Collections.emptyMap();
 
+    /**
+     * Returns the extension names and values that are sent by the client to
+     * the server in the initial client SCRAM authentication message.
+     * Default is an empty map.
+     */
     public Map<String, String> extensions() {
         return extensions;
     }
 
+    /**
+     * Sets the SCRAM extensions on this callback.
+     */
     public void extensions(Map<String, String> extensions) {
         this.extensions = extensions;
     }
diff --git a/clients/src/main/java/org/apache/kafka/common/security/scram/ScramLoginModule.java b/clients/src/main/java/org/apache/kafka/common/security/scram/ScramLoginModule.java
index 43df515..20d1f22 100644
--- a/clients/src/main/java/org/apache/kafka/common/security/scram/ScramLoginModule.java
+++ b/clients/src/main/java/org/apache/kafka/common/security/scram/ScramLoginModule.java
@@ -16,6 +16,9 @@
  */
 package org.apache.kafka.common.security.scram;
 
+import org.apache.kafka.common.security.scram.internal.ScramSaslClientProvider;
+import org.apache.kafka.common.security.scram.internal.ScramSaslServerProvider;
+
 import java.util.Collections;
 import java.util.Map;
 
diff --git a/clients/src/main/java/org/apache/kafka/common/security/scram/ScramCredentialUtils.java b/clients/src/main/java/org/apache/kafka/common/security/scram/internal/ScramCredentialUtils.java
similarity index 96%
rename from clients/src/main/java/org/apache/kafka/common/security/scram/ScramCredentialUtils.java
rename to clients/src/main/java/org/apache/kafka/common/security/scram/internal/ScramCredentialUtils.java
index b4875d6..91e28a6 100644
--- a/clients/src/main/java/org/apache/kafka/common/security/scram/ScramCredentialUtils.java
+++ b/clients/src/main/java/org/apache/kafka/common/security/scram/internal/ScramCredentialUtils.java
@@ -14,12 +14,13 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
-package org.apache.kafka.common.security.scram;
+package org.apache.kafka.common.security.scram.internal;
 
 import java.util.Collection;
 import java.util.Properties;
 
 import org.apache.kafka.common.security.authenticator.CredentialCache;
+import org.apache.kafka.common.security.scram.ScramCredential;
 import org.apache.kafka.common.utils.Base64;
 
 /**
diff --git a/clients/src/main/java/org/apache/kafka/common/security/scram/ScramExtensions.java b/clients/src/main/java/org/apache/kafka/common/security/scram/internal/ScramExtensions.java
similarity index 95%
rename from clients/src/main/java/org/apache/kafka/common/security/scram/ScramExtensions.java
rename to clients/src/main/java/org/apache/kafka/common/security/scram/internal/ScramExtensions.java
index 0f461c0..66d9362 100644
--- a/clients/src/main/java/org/apache/kafka/common/security/scram/ScramExtensions.java
+++ b/clients/src/main/java/org/apache/kafka/common/security/scram/internal/ScramExtensions.java
@@ -14,7 +14,9 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
-package org.apache.kafka.common.security.scram;
+package org.apache.kafka.common.security.scram.internal;
+
+import org.apache.kafka.common.security.scram.ScramLoginModule;
 
 import java.util.Collections;
 import java.util.HashMap;
diff --git a/clients/src/main/java/org/apache/kafka/common/security/scram/ScramFormatter.java b/clients/src/main/java/org/apache/kafka/common/security/scram/internal/ScramFormatter.java
similarity index 94%
rename from clients/src/main/java/org/apache/kafka/common/security/scram/ScramFormatter.java
rename to clients/src/main/java/org/apache/kafka/common/security/scram/internal/ScramFormatter.java
index 406c285..6fcb7a1 100644
--- a/clients/src/main/java/org/apache/kafka/common/security/scram/ScramFormatter.java
+++ b/clients/src/main/java/org/apache/kafka/common/security/scram/internal/ScramFormatter.java
@@ -14,7 +14,7 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
-package org.apache.kafka.common.security.scram;
+package org.apache.kafka.common.security.scram.internal;
 
 import java.math.BigInteger;
 import java.nio.charset.StandardCharsets;
@@ -27,9 +27,10 @@ import javax.crypto.Mac;
 import javax.crypto.spec.SecretKeySpec;
 
 import org.apache.kafka.common.KafkaException;
-import org.apache.kafka.common.security.scram.ScramMessages.ClientFinalMessage;
-import org.apache.kafka.common.security.scram.ScramMessages.ClientFirstMessage;
-import org.apache.kafka.common.security.scram.ScramMessages.ServerFirstMessage;
+import org.apache.kafka.common.security.scram.ScramCredential;
+import org.apache.kafka.common.security.scram.internal.ScramMessages.ClientFinalMessage;
+import org.apache.kafka.common.security.scram.internal.ScramMessages.ClientFirstMessage;
+import org.apache.kafka.common.security.scram.internal.ScramMessages.ServerFirstMessage;
 
 /**
  * Scram message salt and hash functions defined in <a href="https://tools.ietf.org/html/rfc5802">RFC 5802</a>.
diff --git a/clients/src/main/java/org/apache/kafka/common/security/scram/ScramMechanism.java b/clients/src/main/java/org/apache/kafka/common/security/scram/internal/ScramMechanism.java
similarity index 97%
rename from clients/src/main/java/org/apache/kafka/common/security/scram/ScramMechanism.java
rename to clients/src/main/java/org/apache/kafka/common/security/scram/internal/ScramMechanism.java
index d8c0c6d..73be4cf 100644
--- a/clients/src/main/java/org/apache/kafka/common/security/scram/ScramMechanism.java
+++ b/clients/src/main/java/org/apache/kafka/common/security/scram/internal/ScramMechanism.java
@@ -14,7 +14,7 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
-package org.apache.kafka.common.security.scram;
+package org.apache.kafka.common.security.scram.internal;
 
 import java.util.Collection;
 import java.util.Collections;
diff --git a/clients/src/main/java/org/apache/kafka/common/security/scram/ScramMessages.java b/clients/src/main/java/org/apache/kafka/common/security/scram/internal/ScramMessages.java
similarity index 99%
rename from clients/src/main/java/org/apache/kafka/common/security/scram/ScramMessages.java
rename to clients/src/main/java/org/apache/kafka/common/security/scram/internal/ScramMessages.java
index 05b3d77..439b274 100644
--- a/clients/src/main/java/org/apache/kafka/common/security/scram/ScramMessages.java
+++ b/clients/src/main/java/org/apache/kafka/common/security/scram/internal/ScramMessages.java
@@ -14,7 +14,7 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
-package org.apache.kafka.common.security.scram;
+package org.apache.kafka.common.security.scram.internal;
 
 import org.apache.kafka.common.utils.Base64;
 
diff --git a/clients/src/main/java/org/apache/kafka/common/security/scram/ScramSaslClient.java b/clients/src/main/java/org/apache/kafka/common/security/scram/internal/ScramSaslClient.java
similarity index 90%
rename from clients/src/main/java/org/apache/kafka/common/security/scram/ScramSaslClient.java
rename to clients/src/main/java/org/apache/kafka/common/security/scram/internal/ScramSaslClient.java
index 71109df..a98a86d 100644
--- a/clients/src/main/java/org/apache/kafka/common/security/scram/ScramSaslClient.java
+++ b/clients/src/main/java/org/apache/kafka/common/security/scram/internal/ScramSaslClient.java
@@ -14,9 +14,8 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
-package org.apache.kafka.common.security.scram;
+package org.apache.kafka.common.security.scram.internal;
 
-import java.io.IOException;
 import java.nio.charset.StandardCharsets;
 import java.security.InvalidKeyException;
 import java.security.NoSuchAlgorithmException;
@@ -34,9 +33,10 @@ import javax.security.sasl.SaslClientFactory;
 import javax.security.sasl.SaslException;
 
 import org.apache.kafka.common.errors.IllegalSaslStateException;
-import org.apache.kafka.common.security.scram.ScramMessages.ClientFinalMessage;
-import org.apache.kafka.common.security.scram.ScramMessages.ServerFinalMessage;
-import org.apache.kafka.common.security.scram.ScramMessages.ServerFirstMessage;
+import org.apache.kafka.common.security.scram.ScramExtensionsCallback;
+import org.apache.kafka.common.security.scram.internal.ScramMessages.ClientFinalMessage;
+import org.apache.kafka.common.security.scram.internal.ScramMessages.ServerFinalMessage;
+import org.apache.kafka.common.security.scram.internal.ScramMessages.ServerFirstMessage;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
@@ -100,9 +100,15 @@ public class ScramSaslClient implements SaslClient {
                     ScramExtensionsCallback extensionsCallback = new ScramExtensionsCallback();
 
                     try {
-                        callbackHandler.handle(new Callback[]{nameCallback, extensionsCallback});
-                    } catch (IOException | UnsupportedCallbackException e) {
-                        throw new SaslException("User name could not be obtained", e);
+                        callbackHandler.handle(new Callback[]{nameCallback});
+                        try {
+                            callbackHandler.handle(new Callback[]{extensionsCallback});
+                        } catch (UnsupportedCallbackException e) {
+                            log.debug("Extensions callback is not supported by client callback handler {}, no extensions will be added",
+                                    callbackHandler);
+                        }
+                    } catch (Throwable e) {
+                        throw new SaslException("User name or extensions could not be obtained", e);
                     }
 
                     String username = nameCallback.getName();
@@ -121,7 +127,7 @@ public class ScramSaslClient implements SaslClient {
                     PasswordCallback passwordCallback = new PasswordCallback("Password:", false);
                     try {
                         callbackHandler.handle(new Callback[]{passwordCallback});
-                    } catch (IOException | UnsupportedCallbackException e) {
+                    } catch (Throwable e) {
                         throw new SaslException("User name could not be obtained", e);
                     }
                     this.clientFinalMessage = handleServerFirstMessage(passwordCallback.getPassword());
diff --git a/clients/src/main/java/org/apache/kafka/common/security/scram/ScramSaslClientProvider.java b/clients/src/main/java/org/apache/kafka/common/security/scram/internal/ScramSaslClientProvider.java
similarity index 90%
rename from clients/src/main/java/org/apache/kafka/common/security/scram/ScramSaslClientProvider.java
rename to clients/src/main/java/org/apache/kafka/common/security/scram/internal/ScramSaslClientProvider.java
index d389f04..4d9ff81 100644
--- a/clients/src/main/java/org/apache/kafka/common/security/scram/ScramSaslClientProvider.java
+++ b/clients/src/main/java/org/apache/kafka/common/security/scram/internal/ScramSaslClientProvider.java
@@ -14,12 +14,12 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
-package org.apache.kafka.common.security.scram;
+package org.apache.kafka.common.security.scram.internal;
 
 import java.security.Provider;
 import java.security.Security;
 
-import org.apache.kafka.common.security.scram.ScramSaslClient.ScramSaslClientFactory;
+import org.apache.kafka.common.security.scram.internal.ScramSaslClient.ScramSaslClientFactory;
 
 public class ScramSaslClientProvider extends Provider {
 
diff --git a/clients/src/main/java/org/apache/kafka/common/security/scram/ScramSaslServer.java b/clients/src/main/java/org/apache/kafka/common/security/scram/internal/ScramSaslServer.java
similarity index 92%
rename from clients/src/main/java/org/apache/kafka/common/security/scram/ScramSaslServer.java
rename to clients/src/main/java/org/apache/kafka/common/security/scram/internal/ScramSaslServer.java
index 314c1d4..deee0b8 100644
--- a/clients/src/main/java/org/apache/kafka/common/security/scram/ScramSaslServer.java
+++ b/clients/src/main/java/org/apache/kafka/common/security/scram/internal/ScramSaslServer.java
@@ -14,9 +14,8 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
-package org.apache.kafka.common.security.scram;
+package org.apache.kafka.common.security.scram.internal;
 
-import java.io.IOException;
 import java.security.InvalidKeyException;
 import java.security.NoSuchAlgorithmException;
 import java.util.Arrays;
@@ -27,17 +26,20 @@ import java.util.Set;
 import javax.security.auth.callback.Callback;
 import javax.security.auth.callback.CallbackHandler;
 import javax.security.auth.callback.NameCallback;
-import javax.security.auth.callback.UnsupportedCallbackException;
 import javax.security.sasl.SaslException;
 import javax.security.sasl.SaslServer;
 import javax.security.sasl.SaslServerFactory;
 
+import org.apache.kafka.common.errors.AuthenticationException;
 import org.apache.kafka.common.errors.IllegalSaslStateException;
 import org.apache.kafka.common.errors.SaslAuthenticationException;
-import org.apache.kafka.common.security.scram.ScramMessages.ClientFinalMessage;
-import org.apache.kafka.common.security.scram.ScramMessages.ClientFirstMessage;
-import org.apache.kafka.common.security.scram.ScramMessages.ServerFinalMessage;
-import org.apache.kafka.common.security.scram.ScramMessages.ServerFirstMessage;
+import org.apache.kafka.common.security.scram.ScramCredential;
+import org.apache.kafka.common.security.scram.ScramCredentialCallback;
+import org.apache.kafka.common.security.scram.ScramLoginModule;
+import org.apache.kafka.common.security.scram.internal.ScramMessages.ClientFinalMessage;
+import org.apache.kafka.common.security.scram.internal.ScramMessages.ClientFirstMessage;
+import org.apache.kafka.common.security.scram.internal.ScramMessages.ServerFinalMessage;
+import org.apache.kafka.common.security.scram.internal.ScramMessages.ServerFirstMessage;
 import org.apache.kafka.common.security.token.delegation.DelegationTokenCredentialCallback;
 import org.apache.kafka.common.utils.Utils;
 import org.slf4j.Logger;
@@ -133,7 +135,9 @@ public class ScramSaslServer implements SaslServer {
                                 scramCredential.iterations());
                         setState(State.RECEIVE_CLIENT_FINAL_MESSAGE);
                         return serverFirstMessage.toBytes();
-                    } catch (IOException | NumberFormatException | UnsupportedCallbackException e) {
+                    } catch (SaslException | AuthenticationException e) {
+                        throw e;
+                    } catch (Throwable e) {
                         throw new SaslException("Authentication failed: Credentials could not be obtained", e);
                     }
 
@@ -154,7 +158,7 @@ public class ScramSaslServer implements SaslServer {
                 default:
                     throw new IllegalSaslStateException("Unexpected challenge in Sasl server state " + state);
             }
-        } catch (SaslException e) {
+        } catch (SaslException | AuthenticationException e) {
             clearCredentials();
             setState(State.FAILED);
             throw e;
diff --git a/clients/src/main/java/org/apache/kafka/common/security/scram/ScramSaslServerProvider.java b/clients/src/main/java/org/apache/kafka/common/security/scram/internal/ScramSaslServerProvider.java
similarity index 90%
rename from clients/src/main/java/org/apache/kafka/common/security/scram/ScramSaslServerProvider.java
rename to clients/src/main/java/org/apache/kafka/common/security/scram/internal/ScramSaslServerProvider.java
index 9f2a6b3..099e50e 100644
--- a/clients/src/main/java/org/apache/kafka/common/security/scram/ScramSaslServerProvider.java
+++ b/clients/src/main/java/org/apache/kafka/common/security/scram/internal/ScramSaslServerProvider.java
@@ -14,12 +14,12 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
-package org.apache.kafka.common.security.scram;
+package org.apache.kafka.common.security.scram.internal;
 
 import java.security.Provider;
 import java.security.Security;
 
-import org.apache.kafka.common.security.scram.ScramSaslServer.ScramSaslServerFactory;
+import org.apache.kafka.common.security.scram.internal.ScramSaslServer.ScramSaslServerFactory;
 
 public class ScramSaslServerProvider extends Provider {
 
diff --git a/clients/src/main/java/org/apache/kafka/common/security/scram/ScramServerCallbackHandler.java b/clients/src/main/java/org/apache/kafka/common/security/scram/internal/ScramServerCallbackHandler.java
similarity index 82%
rename from clients/src/main/java/org/apache/kafka/common/security/scram/ScramServerCallbackHandler.java
rename to clients/src/main/java/org/apache/kafka/common/security/scram/internal/ScramServerCallbackHandler.java
index 5e37eae..377aa3d 100644
--- a/clients/src/main/java/org/apache/kafka/common/security/scram/ScramServerCallbackHandler.java
+++ b/clients/src/main/java/org/apache/kafka/common/security/scram/internal/ScramServerCallbackHandler.java
@@ -14,23 +14,25 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
-package org.apache.kafka.common.security.scram;
+package org.apache.kafka.common.security.scram.internal;
 
 import java.io.IOException;
+import java.util.List;
 import java.util.Map;
 
-import javax.security.auth.Subject;
 import javax.security.auth.callback.Callback;
 import javax.security.auth.callback.NameCallback;
 import javax.security.auth.callback.UnsupportedCallbackException;
+import javax.security.auth.login.AppConfigurationEntry;
 
-import org.apache.kafka.common.network.Mode;
-import org.apache.kafka.common.security.authenticator.AuthCallbackHandler;
+import org.apache.kafka.common.security.auth.AuthenticateCallbackHandler;
 import org.apache.kafka.common.security.authenticator.CredentialCache;
+import org.apache.kafka.common.security.scram.ScramCredential;
+import org.apache.kafka.common.security.scram.ScramCredentialCallback;
 import org.apache.kafka.common.security.token.delegation.DelegationTokenCache;
 import org.apache.kafka.common.security.token.delegation.DelegationTokenCredentialCallback;
 
-public class ScramServerCallbackHandler implements AuthCallbackHandler {
+public class ScramServerCallbackHandler implements AuthenticateCallbackHandler {
 
     private final CredentialCache.Cache<ScramCredential> credentialCache;
     private final DelegationTokenCache tokenCache;
@@ -43,6 +45,11 @@ public class ScramServerCallbackHandler implements AuthCallbackHandler {
     }
 
     @Override
+    public void configure(Map<String, ?> configs, String mechanism, List<AppConfigurationEntry> jaasConfigEntries) {
+        this.saslMechanism = mechanism;
+    }
+
+    @Override
     public void handle(Callback[] callbacks) throws IOException, UnsupportedCallbackException {
         String username = null;
         for (Callback callback : callbacks) {
@@ -61,11 +68,6 @@ public class ScramServerCallbackHandler implements AuthCallbackHandler {
     }
 
     @Override
-    public void configure(Map<String, ?> configs, Mode mode, Subject subject, String saslMechanism) {
-        this.saslMechanism = saslMechanism;
-    }
-
-    @Override
     public void close() {
     }
 }
diff --git a/clients/src/main/java/org/apache/kafka/common/security/token/delegation/DelegationTokenCache.java b/clients/src/main/java/org/apache/kafka/common/security/token/delegation/DelegationTokenCache.java
index 78575b8..adea210 100644
--- a/clients/src/main/java/org/apache/kafka/common/security/token/delegation/DelegationTokenCache.java
+++ b/clients/src/main/java/org/apache/kafka/common/security/token/delegation/DelegationTokenCache.java
@@ -19,8 +19,8 @@ package org.apache.kafka.common.security.token.delegation;
 
 import org.apache.kafka.common.security.authenticator.CredentialCache;
 import org.apache.kafka.common.security.scram.ScramCredential;
-import org.apache.kafka.common.security.scram.ScramCredentialUtils;
-import org.apache.kafka.common.security.scram.ScramMechanism;
+import org.apache.kafka.common.security.scram.internal.ScramCredentialUtils;
+import org.apache.kafka.common.security.scram.internal.ScramMechanism;
 
 import java.util.Collection;
 import java.util.HashMap;
diff --git a/clients/src/test/java/org/apache/kafka/common/network/NioEchoServer.java b/clients/src/test/java/org/apache/kafka/common/network/NioEchoServer.java
index 0352ade..fab8e93 100644
--- a/clients/src/test/java/org/apache/kafka/common/network/NioEchoServer.java
+++ b/clients/src/test/java/org/apache/kafka/common/network/NioEchoServer.java
@@ -24,8 +24,8 @@ import org.apache.kafka.common.metrics.KafkaMetric;
 import org.apache.kafka.common.metrics.Metrics;
 import org.apache.kafka.common.security.auth.SecurityProtocol;
 import org.apache.kafka.common.security.authenticator.CredentialCache;
-import org.apache.kafka.common.security.scram.ScramCredentialUtils;
-import org.apache.kafka.common.security.scram.ScramMechanism;
+import org.apache.kafka.common.security.scram.ScramCredential;
+import org.apache.kafka.common.security.scram.internal.ScramMechanism;
 import org.apache.kafka.common.utils.LogContext;
 import org.apache.kafka.common.utils.MockTime;
 import org.apache.kafka.test.TestCondition;
@@ -79,8 +79,12 @@ public class NioEchoServer extends Thread {
         this.newChannels = Collections.synchronizedList(new ArrayList<SocketChannel>());
         this.credentialCache = credentialCache;
         this.tokenCache = new DelegationTokenCache(ScramMechanism.mechanismNames());
-        if (securityProtocol == SecurityProtocol.SASL_PLAINTEXT || securityProtocol == SecurityProtocol.SASL_SSL)
-            ScramCredentialUtils.createCache(credentialCache, ScramMechanism.mechanismNames());
+        if (securityProtocol == SecurityProtocol.SASL_PLAINTEXT || securityProtocol == SecurityProtocol.SASL_SSL) {
+            for (String mechanism : ScramMechanism.mechanismNames()) {
+                if (credentialCache.cache(mechanism, ScramCredential.class) == null)
+                    credentialCache.createCache(mechanism, ScramCredential.class);
+            }
+        }
         if (channelBuilder == null)
             channelBuilder = ChannelBuilders.serverChannelBuilder(listenerName, false, securityProtocol, config, credentialCache, tokenCache);
         this.metrics = new Metrics();
diff --git a/clients/src/test/java/org/apache/kafka/common/security/TestSecurityConfig.java b/clients/src/test/java/org/apache/kafka/common/security/TestSecurityConfig.java
index 05294cf..81a883c 100644
--- a/clients/src/test/java/org/apache/kafka/common/security/TestSecurityConfig.java
+++ b/clients/src/test/java/org/apache/kafka/common/security/TestSecurityConfig.java
@@ -31,6 +31,9 @@ public class TestSecurityConfig extends AbstractConfig {
             .define(BrokerSecurityConfigs.SASL_ENABLED_MECHANISMS_CONFIG, Type.LIST,
                     BrokerSecurityConfigs.DEFAULT_SASL_ENABLED_MECHANISMS,
                     Importance.MEDIUM, BrokerSecurityConfigs.SASL_ENABLED_MECHANISMS_DOC)
+            .define(BrokerSecurityConfigs.SASL_SERVER_CALLBACK_HANDLER_CLASS, Type.CLASS,
+                    null,
+                    Importance.MEDIUM, BrokerSecurityConfigs.SASL_SERVER_CALLBACK_HANDLER_CLASS_DOC)
             .define(BrokerSecurityConfigs.PRINCIPAL_BUILDER_CLASS_CONFIG, Type.CLASS,
                     null, Importance.MEDIUM, BrokerSecurityConfigs.PRINCIPAL_BUILDER_CLASS_DOC)
             .withClientSslSupport()
diff --git a/clients/src/test/java/org/apache/kafka/common/security/auth/DefaultKafkaPrincipalBuilderTest.java b/clients/src/test/java/org/apache/kafka/common/security/auth/DefaultKafkaPrincipalBuilderTest.java
index a30c09f..fdf3687 100644
--- a/clients/src/test/java/org/apache/kafka/common/security/auth/DefaultKafkaPrincipalBuilderTest.java
+++ b/clients/src/test/java/org/apache/kafka/common/security/auth/DefaultKafkaPrincipalBuilderTest.java
@@ -22,7 +22,7 @@ import org.apache.kafka.common.network.TransportLayer;
 import org.apache.kafka.common.security.authenticator.DefaultKafkaPrincipalBuilder;
 import org.apache.kafka.common.security.kerberos.KerberosName;
 import org.apache.kafka.common.security.kerberos.KerberosShortNamer;
-import org.apache.kafka.common.security.scram.ScramMechanism;
+import org.apache.kafka.common.security.scram.internal.ScramMechanism;
 import org.easymock.EasyMock;
 import org.easymock.EasyMockSupport;
 import org.junit.Test;
diff --git a/clients/src/test/java/org/apache/kafka/common/security/authenticator/LoginManagerTest.java b/clients/src/test/java/org/apache/kafka/common/security/authenticator/LoginManagerTest.java
index 8be72fb..5436b2a 100644
--- a/clients/src/test/java/org/apache/kafka/common/security/authenticator/LoginManagerTest.java
+++ b/clients/src/test/java/org/apache/kafka/common/security/authenticator/LoginManagerTest.java
@@ -59,17 +59,17 @@ public class LoginManagerTest {
         JaasContext staticContext = JaasContext.loadClientContext(Collections.<String, Object>emptyMap());
 
         LoginManager dynamicLogin = LoginManager.acquireLoginManager(dynamicContext, "PLAIN",
-                false, configs);
+                DefaultLogin.class, configs);
         assertEquals(dynamicPlainContext, dynamicLogin.cacheKey());
         LoginManager staticLogin = LoginManager.acquireLoginManager(staticContext, "SCRAM-SHA-256",
-                false, configs);
+                DefaultLogin.class, configs);
         assertNotSame(dynamicLogin, staticLogin);
         assertEquals("KafkaClient", staticLogin.cacheKey());
 
         assertSame(dynamicLogin, LoginManager.acquireLoginManager(dynamicContext, "PLAIN",
-                false, configs));
+                DefaultLogin.class, configs));
         assertSame(staticLogin, LoginManager.acquireLoginManager(staticContext, "SCRAM-SHA-256",
-                false, configs));
+                DefaultLogin.class, configs));
 
         verifyLoginManagerRelease(dynamicLogin, 2, dynamicContext, configs);
         verifyLoginManagerRelease(staticLogin, 2, staticContext, configs);
@@ -86,23 +86,23 @@ public class LoginManagerTest {
         JaasContext scramJaasContext = JaasContext.loadServerContext(listenerName, "SCRAM-SHA-256", configs);
 
         LoginManager dynamicPlainLogin = LoginManager.acquireLoginManager(plainJaasContext, "PLAIN",
-                false, configs);
+                DefaultLogin.class, configs);
         assertEquals(dynamicPlainContext, dynamicPlainLogin.cacheKey());
         LoginManager dynamicDigestLogin = LoginManager.acquireLoginManager(digestJaasContext, "DIGEST-MD5",
-                false, configs);
+                DefaultLogin.class, configs);
         assertNotSame(dynamicPlainLogin, dynamicDigestLogin);
         assertEquals(dynamicDigestContext, dynamicDigestLogin.cacheKey());
         LoginManager staticScramLogin = LoginManager.acquireLoginManager(scramJaasContext, "SCRAM-SHA-256",
-                false, configs);
+                DefaultLogin.class, configs);
         assertNotSame(dynamicPlainLogin, staticScramLogin);
         assertEquals("KafkaServer", staticScramLogin.cacheKey());
 
         assertSame(dynamicPlainLogin, LoginManager.acquireLoginManager(plainJaasContext, "PLAIN",
-                false, configs));
+                DefaultLogin.class, configs));
         assertSame(dynamicDigestLogin, LoginManager.acquireLoginManager(digestJaasContext, "DIGEST-MD5",
-                false, configs));
+                DefaultLogin.class, configs));
         assertSame(staticScramLogin, LoginManager.acquireLoginManager(scramJaasContext, "SCRAM-SHA-256",
-                false, configs));
+                DefaultLogin.class, configs));
 
         verifyLoginManagerRelease(dynamicPlainLogin, 2, plainJaasContext, configs);
         verifyLoginManagerRelease(dynamicDigestLogin, 2, digestJaasContext, configs);
@@ -116,13 +116,13 @@ public class LoginManagerTest {
         for (int i = 0; i < acquireCount - 1; i++)
             loginManager.release();
         assertSame(loginManager, LoginManager.acquireLoginManager(jaasContext, "PLAIN",
-                false, configs));
+                DefaultLogin.class, configs));
 
         // Release all references and verify that new LoginManager is created on next acquire
         for (int i = 0; i < 2; i++) // release all references
             loginManager.release();
         LoginManager newLoginManager = LoginManager.acquireLoginManager(jaasContext, "PLAIN",
-                false, configs);
+                DefaultLogin.class, configs);
         assertNotSame(loginManager, newLoginManager);
         newLoginManager.release();
     }
diff --git a/clients/src/test/java/org/apache/kafka/common/security/authenticator/SaslAuthenticatorTest.java b/clients/src/test/java/org/apache/kafka/common/security/authenticator/SaslAuthenticatorTest.java
index b8edc61..bfd1d97 100644
--- a/clients/src/test/java/org/apache/kafka/common/security/authenticator/SaslAuthenticatorTest.java
+++ b/clients/src/test/java/org/apache/kafka/common/security/authenticator/SaslAuthenticatorTest.java
@@ -16,6 +16,32 @@
  */
 package org.apache.kafka.common.security.authenticator;
 
+import java.io.IOException;
+import java.net.InetSocketAddress;
+import java.nio.ByteBuffer;
+import java.nio.channels.SelectionKey;
+import java.security.NoSuchAlgorithmException;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.Iterator;
+import java.util.List;
+import java.util.Map;
+import java.util.Random;
+import java.util.concurrent.atomic.AtomicInteger;
+
+import javax.security.auth.Subject;
+import javax.security.auth.callback.Callback;
+import javax.security.auth.callback.CallbackHandler;
+import javax.security.auth.callback.NameCallback;
+import javax.security.auth.callback.PasswordCallback;
+import javax.security.auth.callback.UnsupportedCallbackException;
+import javax.security.auth.login.Configuration;
+import javax.security.auth.login.AppConfigurationEntry;
+import javax.security.auth.login.LoginContext;
+import javax.security.auth.login.LoginException;
+
 import org.apache.kafka.clients.NetworkClient;
 import org.apache.kafka.common.KafkaException;
 import org.apache.kafka.common.config.SaslConfigs;
@@ -37,6 +63,7 @@ import org.apache.kafka.common.network.Send;
 import org.apache.kafka.common.network.TransportLayer;
 import org.apache.kafka.common.protocol.ApiKeys;
 import org.apache.kafka.common.protocol.Errors;
+import org.apache.kafka.common.security.auth.Login;
 import org.apache.kafka.common.security.auth.SecurityProtocol;
 import org.apache.kafka.common.requests.AbstractRequest;
 import org.apache.kafka.common.requests.AbstractResponse;
@@ -54,32 +81,20 @@ import org.apache.kafka.common.security.TestSecurityConfig;
 import org.apache.kafka.common.security.auth.KafkaPrincipal;
 import org.apache.kafka.common.security.plain.PlainLoginModule;
 import org.apache.kafka.common.security.scram.ScramCredential;
-import org.apache.kafka.common.security.scram.ScramCredentialUtils;
-import org.apache.kafka.common.security.scram.ScramFormatter;
+import org.apache.kafka.common.security.scram.internal.ScramCredentialUtils;
+import org.apache.kafka.common.security.scram.internal.ScramFormatter;
 import org.apache.kafka.common.security.scram.ScramLoginModule;
-import org.apache.kafka.common.security.scram.ScramMechanism;
+import org.apache.kafka.common.security.scram.internal.ScramMechanism;
 import org.apache.kafka.common.security.token.delegation.TokenInformation;
 import org.apache.kafka.common.utils.SecurityUtils;
+import org.apache.kafka.common.security.auth.AuthenticateCallbackHandler;
+import org.apache.kafka.common.security.authenticator.TestDigestLoginModule.DigestServerCallbackHandler;
+import org.apache.kafka.common.security.plain.internal.PlainServerCallbackHandler;
+
 import org.junit.After;
 import org.junit.Before;
 import org.junit.Test;
 
-import javax.security.auth.Subject;
-import javax.security.auth.login.Configuration;
-import java.io.IOException;
-import java.net.InetSocketAddress;
-import java.nio.ByteBuffer;
-import java.nio.channels.SelectionKey;
-import java.security.NoSuchAlgorithmException;
-import java.util.ArrayList;
-import java.util.Arrays;
-import java.util.Collections;
-import java.util.HashMap;
-import java.util.Iterator;
-import java.util.List;
-import java.util.Map;
-import java.util.Random;
-
 import static org.junit.Assert.assertEquals;
 import static org.junit.Assert.assertFalse;
 import static org.junit.Assert.assertTrue;
@@ -110,6 +125,7 @@ public class SaslAuthenticatorTest {
         saslServerConfigs = serverCertStores.getTrustingConfig(clientCertStores);
         saslClientConfigs = clientCertStores.getTrustingConfig(serverCertStores);
         credentialCache = new CredentialCache();
+        TestLogin.loginCount.set(0);
     }
 
     @After
@@ -234,6 +250,7 @@ public class SaslAuthenticatorTest {
         String node = "0";
         SecurityProtocol securityProtocol = SecurityProtocol.SASL_SSL;
         configureMechanisms("DIGEST-MD5", Arrays.asList("DIGEST-MD5"));
+        configureDigestMd5ServerCallback(securityProtocol);
 
         server = createEchoServer(securityProtocol);
         createAndCheckClientConnection(securityProtocol, node);
@@ -247,6 +264,7 @@ public class SaslAuthenticatorTest {
     public void testMultipleServerMechanisms() throws Exception {
         SecurityProtocol securityProtocol = SecurityProtocol.SASL_SSL;
         configureMechanisms("DIGEST-MD5", Arrays.asList("DIGEST-MD5", "PLAIN", "SCRAM-SHA-256"));
+        configureDigestMd5ServerCallback(securityProtocol);
         server = createEchoServer(securityProtocol);
         updateScramCredentialCache(TestJaasConfig.USERNAME, TestJaasConfig.PASSWORD);
 
@@ -681,6 +699,207 @@ public class SaslAuthenticatorTest {
     }
 
     /**
+     * Verifies that a configured client callback handler supplies the SASL/PLAIN credentials.
+     */
+    @Test
+    public void testClientAuthenticateCallbackHandler() throws Exception {
+        TestJaasConfig jaas = configureMechanisms("PLAIN", Collections.singletonList("PLAIN"));
+        saslClientConfigs.put(SaslConfigs.SASL_CLIENT_CALLBACK_HANDLER_CLASS, TestClientCallbackHandler.class.getName());
+        jaas.setClientOptions("PLAIN", "", ""); // credentials must come from the callback handler, not the JAAS entry
+        SecurityProtocol protocol = SecurityProtocol.SASL_PLAINTEXT;
+        String userKey = "user_" + TestClientCallbackHandler.USERNAME;
+        Map<String, Object> serverOptions = new HashMap<>();
+        serverOptions.put(userKey, TestClientCallbackHandler.PASSWORD);
+        jaas.createOrUpdateEntry(TestJaasConfig.LOGIN_CONTEXT_SERVER, PlainLoginModule.class.getName(), serverOptions);
+        server = createEchoServer(protocol);
+        createAndCheckClientConnection(protocol, "good");
+        // Mutate the same map in place; presumably the running server still references it (confirm before replacing it)
+        serverOptions.clear();
+        serverOptions.put(userKey, "invalid-password");
+        jaas.createOrUpdateEntry(TestJaasConfig.LOGIN_CONTEXT_SERVER, PlainLoginModule.class.getName(), serverOptions);
+        createAndCheckClientConnectionFailure(protocol, "invalid");
+    }
+
+    /**
+     * Verifies that a listener+mechanism-prefixed server callback handler is used for SASL/PLAIN.
+     */
+    @Test
+    public void testServerAuthenticateCallbackHandler() throws Exception {
+        SecurityProtocol protocol = SecurityProtocol.SASL_PLAINTEXT;
+        TestJaasConfig jaas = configureMechanisms("PLAIN", Collections.singletonList("PLAIN"));
+        // Server JAAS entry carries no options, so no user credentials are configured there
+        jaas.createOrUpdateEntry(TestJaasConfig.LOGIN_CONTEXT_SERVER, PlainLoginModule.class.getName(), new HashMap<String, Object>());
+        String handlerKey = ListenerName.forSecurityProtocol(protocol).saslMechanismConfigPrefix("PLAIN")
+                + BrokerSecurityConfigs.SASL_SERVER_CALLBACK_HANDLER_CLASS;
+        saslServerConfigs.put(handlerKey, TestServerCallbackHandler.class.getName());
+        server = createEchoServer(protocol);
+        // Credentials accepted by `TestServerCallbackHandler`
+        jaas.setClientOptions("PLAIN", TestServerCallbackHandler.USERNAME, TestServerCallbackHandler.PASSWORD);
+        createAndCheckClientConnection(protocol, "good");
+
+        // Credentials rejected by `TestServerCallbackHandler`
+        jaas.setClientOptions("PLAIN", TestJaasConfig.USERNAME, "invalid-password");
+        createAndCheckClientConnectionFailure(protocol, "invalid");
+    }
+
+    /**
+     * Test that callback handlers are only applied to connections for the mechanisms configured
+     * for the handler, and only when configured with a listener+mechanism prefix. Test enables
+     * `PLAIN` and `DIGEST-MD5` on the server with a different callback handler for each
+     * mechanism and verifies that clients using both mechanisms authenticate successfully.
+     */
+    @Test
+    public void testAuthenticateCallbackHandlerMechanisms() throws Exception {
+        SecurityProtocol securityProtocol = SecurityProtocol.SASL_PLAINTEXT;
+        TestJaasConfig jaasConfig = configureMechanisms("DIGEST-MD5", Arrays.asList("DIGEST-MD5", "PLAIN"));
+
+        // Connections should fail using the digest callback handler if listener.mechanism prefix not specified
+        saslServerConfigs.put("plain." + BrokerSecurityConfigs.SASL_SERVER_CALLBACK_HANDLER_CLASS,
+                TestServerCallbackHandler.class);
+        saslServerConfigs.put("digest-md5." + BrokerSecurityConfigs.SASL_SERVER_CALLBACK_HANDLER_CLASS,
+                DigestServerCallbackHandler.class);
+        server = createEchoServer(securityProtocol);
+        createAndCheckClientConnectionFailure(securityProtocol, "invalid");
+        server.close(); // release the first server before `server` is reassigned below
+        // Connections should succeed using the server callback handler associated with the listener
+        ListenerName listener = ListenerName.forSecurityProtocol(securityProtocol);
+        saslServerConfigs.remove("plain." + BrokerSecurityConfigs.SASL_SERVER_CALLBACK_HANDLER_CLASS);
+        saslServerConfigs.remove("digest-md5." + BrokerSecurityConfigs.SASL_SERVER_CALLBACK_HANDLER_CLASS);
+        saslServerConfigs.put(listener.saslMechanismConfigPrefix("plain") + BrokerSecurityConfigs.SASL_SERVER_CALLBACK_HANDLER_CLASS,
+                TestServerCallbackHandler.class);
+        saslServerConfigs.put(listener.saslMechanismConfigPrefix("digest-md5") + BrokerSecurityConfigs.SASL_SERVER_CALLBACK_HANDLER_CLASS,
+                DigestServerCallbackHandler.class);
+        server = createEchoServer(securityProtocol);
+
+        // Verify that DIGEST-MD5 (currently configured for client) works with `DigestServerCallbackHandler`
+        createAndCheckClientConnection(securityProtocol, "good-digest-md5");
+
+        // Verify that PLAIN works with `TestServerCallbackHandler`
+        jaasConfig.setClientOptions("PLAIN", TestServerCallbackHandler.USERNAME, TestServerCallbackHandler.PASSWORD);
+        saslClientConfigs.put(SaslConfigs.SASL_MECHANISM, "PLAIN");
+        createAndCheckClientConnection(securityProtocol, "good-plain");
+    }
+
+    /**
+     * Verifies that a custom client {@link SaslConfigs#SASL_LOGIN_CLASS} populates the login Subject.
+     */
+    @Test
+    public void testClientLoginOverride() throws Exception {
+        SecurityProtocol protocol = SecurityProtocol.SASL_PLAINTEXT;
+        configureMechanisms("PLAIN", Collections.singletonList("PLAIN"))
+                .setClientOptions("PLAIN", "invaliduser", "invalidpassword");
+        server = createEchoServer(protocol);
+        // With the override, TestLogin replaces the bad JAAS credentials in the Subject
+        saslClientConfigs.put(SaslConfigs.SASL_LOGIN_CLASS, TestLogin.class.getName());
+        createAndCheckClientConnection(protocol, "1");
+        assertEquals(1, TestLogin.loginCount.get());
+
+        // Without the override the invalid JAAS credentials are used, so authentication fails;
+        // TestLogin must not have been invoked again
+        saslClientConfigs.remove(SaslConfigs.SASL_LOGIN_CLASS);
+        createAndCheckClientConnectionFailure(protocol, "invalid");
+        assertEquals(1, TestLogin.loginCount.get());
+    }
+
+    /**
+     * Verifies a server {@link SaslConfigs#SASL_LOGIN_CLASS} override configured with a listener+mechanism prefix.
+     */
+    @Test
+    public void testServerLoginOverride() throws Exception {
+        SecurityProtocol protocol = SecurityProtocol.SASL_PLAINTEXT;
+        configureMechanisms("PLAIN", Collections.singletonList("PLAIN"));
+        String loginClassKey = ListenerName.forSecurityProtocol(protocol).saslMechanismConfigPrefix("PLAIN")
+                + SaslConfigs.SASL_LOGIN_CLASS;
+        saslServerConfigs.put(loginClassKey, TestLogin.class.getName());
+        server = createEchoServer(protocol);
+        // Server login happens while its channel builder is created, before any connection is accepted
+        assertEquals(1, TestLogin.loginCount.get());
+        // Client connections must not trigger another TestLogin login
+        createAndCheckClientConnection(protocol, "1");
+        assertEquals(1, TestLogin.loginCount.get());
+    }
+
+    /**
+     * Tests client {@link SaslConfigs#SASL_LOGIN_CALLBACK_HANDLER_CLASS} override.
+     */
+    @Test
+    public void testClientLoginCallbackOverride() throws Exception {
+        SecurityProtocol securityProtocol = SecurityProtocol.SASL_PLAINTEXT;
+        TestJaasConfig jaasConfig = configureMechanisms("PLAIN", Collections.singletonList("PLAIN"));
+        jaasConfig.createOrUpdateEntry(TestJaasConfig.LOGIN_CONTEXT_CLIENT, TestPlainLoginModule.class.getName(),
+                Collections.<String, Object>emptyMap());
+        server = createEchoServer(securityProtocol);
+        // Connection should succeed using login callback override that sets correct username/password
+        saslClientConfigs.put(SaslConfigs.SASL_LOGIN_CALLBACK_HANDLER_CLASS, TestLoginCallbackHandler.class.getName());
+        createAndCheckClientConnection(securityProtocol, "1");
+
+        // Without the override, login should fail since the client JAAS entry has no credential options
+        saslClientConfigs.remove(SaslConfigs.SASL_LOGIN_CALLBACK_HANDLER_CLASS);
+        try {
+            createClientConnection(securityProtocol, "invalid");
+            fail("Client connection should have failed without the login callback handler override");
+        } catch (Exception e) {
+            assertTrue("Unexpected exception " + e.getCause(), e.getCause() instanceof LoginException);
+        }
+    }
+
+    /**
+     * Tests server login callback handler override; it must be honored only with a listener+mechanism prefix.
+     */
+    @Test
+    public void testServerLoginCallbackOverride() throws Exception {
+        SecurityProtocol securityProtocol = SecurityProtocol.SASL_PLAINTEXT;
+        TestJaasConfig jaasConfig = configureMechanisms("PLAIN", Collections.singletonList("PLAIN"));
+        jaasConfig.createOrUpdateEntry(TestJaasConfig.LOGIN_CONTEXT_SERVER, TestPlainLoginModule.class.getName(),
+                Collections.<String, Object>emptyMap());
+        jaasConfig.setClientOptions("PLAIN", TestServerCallbackHandler.USERNAME, TestServerCallbackHandler.PASSWORD);
+        ListenerName listenerName = ListenerName.forSecurityProtocol(securityProtocol);
+        String prefix = listenerName.saslMechanismConfigPrefix("PLAIN");
+        saslServerConfigs.put(prefix + BrokerSecurityConfigs.SASL_SERVER_CALLBACK_HANDLER_CLASS,
+                TestServerCallbackHandler.class);
+        Class<?> loginCallback = TestLoginCallbackHandler.class;
+
+        try {
+            createEchoServer(securityProtocol);
+            fail("Should have failed to create server with default login handler");
+        } catch (KafkaException e) {
+            // Expected exception
+        }
+
+        try {
+            saslServerConfigs.put(SaslConfigs.SASL_LOGIN_CALLBACK_HANDLER_CLASS, loginCallback);
+            createEchoServer(securityProtocol);
+            fail("Should have failed to create server with login handler config without listener+mechanism prefix");
+        } catch (KafkaException e) {
+            // Expected exception
+            saslServerConfigs.remove(SaslConfigs.SASL_LOGIN_CALLBACK_HANDLER_CLASS);
+        }
+
+        try {
+            saslServerConfigs.put("plain." + SaslConfigs.SASL_LOGIN_CALLBACK_HANDLER_CLASS, loginCallback);
+            createEchoServer(securityProtocol);
+            fail("Should have failed to create server with login handler config without listener prefix");
+        } catch (KafkaException e) {
+            // Expected exception
+            saslServerConfigs.remove("plain." + SaslConfigs.SASL_LOGIN_CALLBACK_HANDLER_CLASS);
+        }
+
+        try {
+            saslServerConfigs.put(listenerName.configPrefix() + SaslConfigs.SASL_LOGIN_CALLBACK_HANDLER_CLASS, loginCallback);
+            createEchoServer(securityProtocol);
+            fail("Should have failed to create server with login handler config without mechanism prefix");
+        } catch (KafkaException e) {
+            // Expected exception; remove the same listener-prefixed key that was added above
+            saslServerConfigs.remove(listenerName.configPrefix() + SaslConfigs.SASL_LOGIN_CALLBACK_HANDLER_CLASS);
+        }
+
+        // Connection should succeed using login callback override for mechanism
+        saslServerConfigs.put(prefix + SaslConfigs.SASL_LOGIN_CALLBACK_HANDLER_CLASS, loginCallback);
+        server = createEchoServer(securityProtocol);
+        createAndCheckClientConnection(securityProtocol, "1");
+    }
+
+    /**
      * Tests that mechanisms with default implementation in Kafka may be disabled in
      * the Kafka server by removing from the enabled mechanism list.
      */
@@ -1028,10 +1247,12 @@ public class SaslAuthenticatorTest {
                 securityProtocol, listenerName, false, saslMechanism, true, credentialCache, null) {
 
             @Override
-            protected SaslServerAuthenticator buildServerAuthenticator(Map<String, ?> configs, String id,
-                            TransportLayer transportLayer, Map<String, Subject> subjects) throws IOException {
-                return new SaslServerAuthenticator(configs, id, jaasContexts, subjects, null,
-                                credentialCache, listenerName, securityProtocol, transportLayer, null) {
+            protected SaslServerAuthenticator buildServerAuthenticator(Map<String, ?> configs,
+                                                                       Map<String, AuthenticateCallbackHandler> callbackHandlers,
+                                                                       String id,
+                                                                       TransportLayer transportLayer,
+                                                                       Map<String, Subject> subjects) throws IOException {
+                return new SaslServerAuthenticator(configs, callbackHandlers, id, subjects, null, listenerName, securityProtocol, transportLayer) {
 
                     @Override
                     protected ApiVersionsResponse apiVersionsResponse() {
@@ -1072,11 +1293,15 @@ public class SaslAuthenticatorTest {
                 securityProtocol, listenerName, false, saslMechanism, true, null, null) {
 
             @Override
-            protected SaslClientAuthenticator buildClientAuthenticator(Map<String, ?> configs, String id,
-                    String serverHost, String servicePrincipal,
-                    TransportLayer transportLayer, Subject subject) throws IOException {
-
-                return new SaslClientAuthenticator(configs, id, subject,
+            protected SaslClientAuthenticator buildClientAuthenticator(Map<String, ?> configs,
+                                                                       AuthenticateCallbackHandler callbackHandler,
+                                                                       String id,
+                                                                       String serverHost,
+                                                                       String servicePrincipal,
+                                                                       TransportLayer transportLayer,
+                                                                       Subject subject) throws IOException {
+
+                return new SaslClientAuthenticator(configs, callbackHandler, id, subject,
                         servicePrincipal, serverHost, saslMechanism, true, transportLayer) {
                     @Override
                     protected SaslHandshakeRequest createSaslHandshakeRequest(short version) {
@@ -1173,9 +1398,19 @@ public class SaslAuthenticatorTest {
     private TestJaasConfig configureMechanisms(String clientMechanism, List<String> serverMechanisms) {
         saslClientConfigs.put(SaslConfigs.SASL_MECHANISM, clientMechanism);
         saslServerConfigs.put(BrokerSecurityConfigs.SASL_ENABLED_MECHANISMS_CONFIG, serverMechanisms);
+        if (serverMechanisms.contains("DIGEST-MD5")) { // register the test digest handler under a mechanism-only prefix
+            String digestHandlerKey = "digest-md5." + BrokerSecurityConfigs.SASL_SERVER_CALLBACK_HANDLER_CLASS;
+            saslServerConfigs.put(digestHandlerKey, TestDigestLoginModule.DigestServerCallbackHandler.class.getName());
+        }
         return TestJaasConfig.createConfiguration(clientMechanism, serverMechanisms);
     }
 
+    private void configureDigestMd5ServerCallback(SecurityProtocol securityProtocol) {
+        ListenerName listenerName = ListenerName.forSecurityProtocol(securityProtocol); // key is listener+mechanism scoped
+        String handlerKey = listenerName.saslMechanismConfigPrefix("DIGEST-MD5") + BrokerSecurityConfigs.SASL_SERVER_CALLBACK_HANDLER_CLASS;
+        saslServerConfigs.put(handlerKey, TestDigestLoginModule.DigestServerCallbackHandler.class);
+    }
+
     private void createSelector(SecurityProtocol securityProtocol, Map<String, Object> clientConfigs) {
         if (selector != null) {
             selector.close();
@@ -1261,6 +1496,28 @@ public class SaslAuthenticatorTest {
         return selector.completedReceives().get(0).payload();
     }
 
+    public static class TestServerCallbackHandler extends PlainServerCallbackHandler {
+        // The only credentials this handler accepts; every other username/password is rejected
+        static final String USERNAME = "TestServerCallbackHandler-user";
+        static final String PASSWORD = "TestServerCallbackHandler-password";
+        private volatile boolean configured;
+
+        @Override
+        public void configure(Map<String, ?> configs, String mechanism, List<AppConfigurationEntry> jaasConfigEntries) {
+            if (configured) throw new IllegalStateException("Server callback handler configured twice");
+            configured = true;
+            super.configure(configs, mechanism, jaasConfigEntries);
+        }
+
+        @Override
+        protected boolean authenticate(String username, char[] password) throws IOException {
+            if (!configured) throw new IllegalStateException("Server callback handler not configured");
+            if (!USERNAME.equals(username))
+                return false;
+            return PASSWORD.equals(new String(password));
+        }
+    }
+
     @SuppressWarnings("unchecked")
     private void updateScramCredentialCache(String username, String password) throws NoSuchAlgorithmException {
         for (String mechanism : (List<String>) saslServerConfigs.get(BrokerSecurityConfigs.SASL_ENABLED_MECHANISMS_CONFIG)) {
@@ -1289,4 +1546,121 @@ public class SaslAuthenticatorTest {
             }
         }
     }
+
+    public static class TestClientCallbackHandler implements AuthenticateCallbackHandler {
+        // Credentials this handler hands to the SASL client
+        static final String USERNAME = "TestClientCallbackHandler-user";
+        static final String PASSWORD = "TestClientCallbackHandler-password";
+        private volatile boolean configured;
+
+        @Override
+        public void configure(Map<String, ?> configs, String mechanism, List<AppConfigurationEntry> jaasConfigEntries) {
+            if (configured) throw new IllegalStateException("Client callback handler configured twice");
+            configured = true;
+        }
+
+        @Override
+        public void handle(Callback[] callbacks) throws IOException, UnsupportedCallbackException {
+            if (!configured) throw new IllegalStateException("Client callback handler not configured");
+            for (Callback cb : callbacks) {
+                if (cb instanceof NameCallback) {
+                    ((NameCallback) cb).setName(USERNAME);
+                } else if (cb instanceof PasswordCallback) {
+                    ((PasswordCallback) cb).setPassword(PASSWORD.toCharArray());
+                } else {
+                    throw new UnsupportedCallbackException(cb);
+                }
+            }
+        }
+
+        @Override
+        public void close() {
+            // nothing to release
+        }
+    }
+
+    public static class TestLogin implements Login {
+        // Counts successful login() calls; tests reset it via set(0), so the reference itself is final
+        static final AtomicInteger loginCount = new AtomicInteger();
+
+        private String contextName;
+        private Configuration configuration;
+        private Subject subject; // assigned by login(); null until then
+        @Override
+        public void configure(Map<String, ?> configs, String contextName, Configuration configuration,
+                              AuthenticateCallbackHandler callbackHandler) {
+            assertEquals(1, configuration.getAppConfigurationEntry(contextName).length);
+            this.contextName = contextName;
+            this.configuration = configuration;
+        }
+
+        @Override
+        public LoginContext login() throws LoginException {
+            LoginContext context = new LoginContext(contextName, null, new AbstractLogin.DefaultLoginCallbackHandler(), configuration);
+            context.login();
+            subject = context.getSubject();
+            subject.getPublicCredentials().clear(); // drop whatever credentials the JAAS login module added...
+            subject.getPrivateCredentials().clear();
+            subject.getPublicCredentials().add(TestJaasConfig.USERNAME); // ...and substitute the valid test ones
+            subject.getPrivateCredentials().add(TestJaasConfig.PASSWORD);
+            loginCount.incrementAndGet();
+            return context;
+        }
+
+        @Override
+        public Subject subject() {
+            return subject;
+        }
+
+        @Override
+        public String serviceName() {
+            return "kafka";
+        }
+
+        @Override
+        public void close() {
+        }
+    }
+
+    public static class TestLoginCallbackHandler implements AuthenticateCallbackHandler {
+        private volatile boolean configured;
+        @Override
+        public void configure(Map<String, ?> configs, String saslMechanism, List<AppConfigurationEntry> jaasConfigEntries) {
+            if (configured)
+                throw new IllegalStateException("Login callback handler configured twice");
+            configured = true;
+        }
+
+        @Override
+        public void handle(Callback[] callbacks) throws IOException, UnsupportedCallbackException {
+            if (!configured) throw new IllegalStateException("Login callback handler not configured");
+            // Supply the valid test credentials; reject any other callback instead of silently ignoring it
+            for (Callback callback : callbacks) {
+                if (callback instanceof NameCallback)
+                    ((NameCallback) callback).setName(TestJaasConfig.USERNAME);
+                else if (callback instanceof PasswordCallback)
+                    ((PasswordCallback) callback).setPassword(TestJaasConfig.PASSWORD.toCharArray());
+                else
+                    throw new UnsupportedCallbackException(callback);
+            }
+        }
+
+        @Override
+        public void close() {}
+    }
+
+    public static final class TestPlainLoginModule extends PlainLoginModule {
+        @Override
+        public void initialize(Subject subject, CallbackHandler callbackHandler, Map<String, ?> sharedState, Map<String, ?> options) {
+            NameCallback nameCb = new NameCallback("name:");
+            PasswordCallback passwordCb = new PasswordCallback("password:", false);
+            try { // credentials come from the login callback handler; the JAAS options are not consulted
+                callbackHandler.handle(new Callback[] {nameCb, passwordCb});
+                subject.getPublicCredentials().add(nameCb.getName());
+                subject.getPrivateCredentials().add(new String(passwordCb.getPassword()));
+            } catch (Exception e) {
+                throw new SaslAuthenticationException("Login initialization failed", e);
+            }
+        }
+    }
 }
diff --git a/clients/src/test/java/org/apache/kafka/common/security/authenticator/SaslServerAuthenticatorTest.java b/clients/src/test/java/org/apache/kafka/common/security/authenticator/SaslServerAuthenticatorTest.java
index 17d31bd..3ec3031 100644
--- a/clients/src/test/java/org/apache/kafka/common/security/authenticator/SaslServerAuthenticatorTest.java
+++ b/clients/src/test/java/org/apache/kafka/common/security/authenticator/SaslServerAuthenticatorTest.java
@@ -22,13 +22,12 @@ import org.apache.kafka.common.network.InvalidReceiveException;
 import org.apache.kafka.common.network.ListenerName;
 import org.apache.kafka.common.network.TransportLayer;
 import org.apache.kafka.common.protocol.ApiKeys;
+import org.apache.kafka.common.security.auth.AuthenticateCallbackHandler;
 import org.apache.kafka.common.security.auth.SecurityProtocol;
 import org.apache.kafka.common.protocol.types.Struct;
 import org.apache.kafka.common.requests.RequestHeader;
 import org.apache.kafka.common.security.JaasContext;
 import org.apache.kafka.common.security.plain.PlainLoginModule;
-import org.apache.kafka.common.security.scram.ScramMechanism;
-import org.apache.kafka.common.security.token.delegation.DelegationTokenCache;
 import org.easymock.Capture;
 import org.easymock.EasyMock;
 import org.easymock.IAnswer;
@@ -41,7 +40,7 @@ import java.util.Collections;
 import java.util.HashMap;
 import java.util.Map;
 
-import static org.apache.kafka.common.security.scram.ScramMechanism.SCRAM_SHA_256;
+import static org.apache.kafka.common.security.scram.internal.ScramMechanism.SCRAM_SHA_256;
 import static org.junit.Assert.fail;
 
 public class SaslServerAuthenticatorTest {
@@ -112,8 +111,10 @@ public class SaslServerAuthenticatorTest {
         Map<String, JaasContext> jaasContexts = Collections.singletonMap(mechanism,
                 new JaasContext("jaasContext", JaasContext.Type.SERVER, jaasConfig, null));
         Map<String, Subject> subjects = Collections.singletonMap(mechanism, new Subject());
-        return new SaslServerAuthenticator(configs, "node", jaasContexts, subjects, null, new CredentialCache(),
-                new ListenerName("ssl"), SecurityProtocol.SASL_SSL, transportLayer, new DelegationTokenCache(ScramMechanism.mechanismNames()));
+        Map<String, AuthenticateCallbackHandler> callbackHandlers = Collections.<String, AuthenticateCallbackHandler>singletonMap(
+                mechanism, new SaslServerCallbackHandler());
+        return new SaslServerAuthenticator(configs, callbackHandlers, "node", subjects, null,
+                new ListenerName("ssl"), SecurityProtocol.SASL_SSL, transportLayer);
     }
 
 }
diff --git a/clients/src/test/java/org/apache/kafka/common/security/authenticator/TestDigestLoginModule.java b/clients/src/test/java/org/apache/kafka/common/security/authenticator/TestDigestLoginModule.java
index f1ef740..97b0b27 100644
--- a/clients/src/test/java/org/apache/kafka/common/security/authenticator/TestDigestLoginModule.java
+++ b/clients/src/test/java/org/apache/kafka/common/security/authenticator/TestDigestLoginModule.java
@@ -17,62 +17,47 @@
 package org.apache.kafka.common.security.authenticator;
 
 import java.io.IOException;
-import java.security.Provider;
-import java.security.Security;
-import java.util.Arrays;
-import java.util.Enumeration;
-import java.util.HashMap;
+import java.util.List;
 import java.util.Map;
 
 import javax.security.auth.callback.Callback;
-import javax.security.auth.callback.CallbackHandler;
 import javax.security.auth.callback.NameCallback;
 import javax.security.auth.callback.PasswordCallback;
 import javax.security.auth.callback.UnsupportedCallbackException;
+import javax.security.auth.login.AppConfigurationEntry;
 import javax.security.sasl.AuthorizeCallback;
 import javax.security.sasl.RealmCallback;
-import javax.security.sasl.Sasl;
-import javax.security.sasl.SaslException;
-import javax.security.sasl.SaslServer;
-import javax.security.sasl.SaslServerFactory;
 
+import org.apache.kafka.common.security.auth.AuthenticateCallbackHandler;
 import org.apache.kafka.common.security.plain.PlainLoginModule;
 
 /**
- * Digest-MD5 login module for multi-mechanism tests. Since callback handlers are not configurable in Kafka
- * yet, this replaces the standard Digest-MD5 SASL server provider with one that invokes the test callback handler.
+ * Digest-MD5 login module for multi-mechanism tests.
  * This login module uses the same format as PlainLoginModule and hence simply reuses the same methods.
  *
  */
 public class TestDigestLoginModule extends PlainLoginModule {
 
-    private static final SaslServerFactory STANDARD_DIGEST_SASL_SERVER_FACTORY;
-    static {
-        SaslServerFactory digestSaslServerFactory = null;
-        Enumeration<SaslServerFactory> factories = Sasl.getSaslServerFactories();
-        Map<String, Object> emptyProps = new HashMap<>();
-        while (factories.hasMoreElements()) {
-            SaslServerFactory factory = factories.nextElement();
-            if (Arrays.asList(factory.getMechanismNames(emptyProps)).contains("DIGEST-MD5")) {
-                digestSaslServerFactory = factory;
-                break;
-            }
-        }
-        STANDARD_DIGEST_SASL_SERVER_FACTORY = digestSaslServerFactory;
-        Security.insertProviderAt(new DigestSaslServerProvider(), 1);
-    }
+    public static class DigestServerCallbackHandler implements AuthenticateCallbackHandler {
 
-    public static class DigestServerCallbackHandler implements CallbackHandler {
+        @Override
+        public void configure(Map<String, ?> configs, String saslMechanism, List<AppConfigurationEntry> jaasConfigEntries) {
+        }
 
         @Override
         public void handle(Callback[] callbacks) throws IOException, UnsupportedCallbackException {
+            String username = null;
             for (Callback callback : callbacks) {
                 if (callback instanceof NameCallback) {
                     NameCallback nameCallback = (NameCallback) callback;
-                    nameCallback.setName(nameCallback.getDefaultName());
+                    if (TestJaasConfig.USERNAME.equals(nameCallback.getDefaultName())) {
+                        nameCallback.setName(nameCallback.getDefaultName());
+                        username = TestJaasConfig.USERNAME;
+                    }
                 } else if (callback instanceof PasswordCallback) {
                     PasswordCallback passwordCallback = (PasswordCallback) callback;
-                    passwordCallback.setPassword(TestJaasConfig.PASSWORD.toCharArray());
+                    if (TestJaasConfig.USERNAME.equals(username))
+                        passwordCallback.setPassword(TestJaasConfig.PASSWORD.toCharArray());
                 } else if (callback instanceof RealmCallback) {
                     RealmCallback realmCallback = (RealmCallback) callback;
                     realmCallback.setText(realmCallback.getDefaultText());
@@ -85,30 +70,9 @@ public class TestDigestLoginModule extends PlainLoginModule {
                 }
             }
         }
-    }
-
-    public static class DigestSaslServerFactory implements SaslServerFactory {
-
-        @Override
-        public SaslServer createSaslServer(String mechanism, String protocol, String serverName, Map<String, ?> props, CallbackHandler cbh)
-                throws SaslException {
-            return STANDARD_DIGEST_SASL_SERVER_FACTORY.createSaslServer(mechanism, protocol, serverName, props, new DigestServerCallbackHandler());
-        }
 
         @Override
-        public String[] getMechanismNames(Map<String, ?> props) {
-            return new String[] {"DIGEST-MD5"};
-        }
-    }
-
-    public static class DigestSaslServerProvider extends Provider {
-
-        private static final long serialVersionUID = 1L;
-
-        @SuppressWarnings("deprecation")
-        protected DigestSaslServerProvider() {
-            super("Test SASL/Digest-MD5 Server Provider", 1.0, "Test SASL/Digest-MD5 Server Provider for Kafka");
-            put("SaslServerFactory.DIGEST-MD5", TestDigestLoginModule.DigestSaslServerFactory.class.getName());
+        public void close() {
         }
     }
 }
diff --git a/clients/src/test/java/org/apache/kafka/common/security/authenticator/TestJaasConfig.java b/clients/src/test/java/org/apache/kafka/common/security/authenticator/TestJaasConfig.java
index dafa79d..3ee7c2c 100644
--- a/clients/src/test/java/org/apache/kafka/common/security/authenticator/TestJaasConfig.java
+++ b/clients/src/test/java/org/apache/kafka/common/security/authenticator/TestJaasConfig.java
@@ -28,7 +28,7 @@ import javax.security.auth.login.AppConfigurationEntry.LoginModuleControlFlag;
 import org.apache.kafka.common.config.types.Password;
 import org.apache.kafka.common.security.plain.PlainLoginModule;
 import org.apache.kafka.common.security.scram.ScramLoginModule;
-import org.apache.kafka.common.security.scram.ScramMechanism;
+import org.apache.kafka.common.security.scram.internal.ScramMechanism;
 
 public class TestJaasConfig extends Configuration {
 
diff --git a/clients/src/test/java/org/apache/kafka/common/security/plain/PlainSaslServerTest.java b/clients/src/test/java/org/apache/kafka/common/security/plain/internal/PlainSaslServerTest.java
similarity index 89%
rename from clients/src/test/java/org/apache/kafka/common/security/plain/PlainSaslServerTest.java
rename to clients/src/test/java/org/apache/kafka/common/security/plain/internal/PlainSaslServerTest.java
index 86baf3e..1410c8a 100644
--- a/clients/src/test/java/org/apache/kafka/common/security/plain/PlainSaslServerTest.java
+++ b/clients/src/test/java/org/apache/kafka/common/security/plain/internal/PlainSaslServerTest.java
@@ -14,8 +14,9 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
-package org.apache.kafka.common.security.plain;
+package org.apache.kafka.common.security.plain.internal;
 
+import org.apache.kafka.common.security.plain.PlainLoginModule;
 import org.junit.Before;
 import org.junit.Test;
 
@@ -46,7 +47,9 @@ public class PlainSaslServerTest {
         options.put("user_" + USER_B, PASSWORD_B);
         jaasConfig.addEntry("jaasContext", PlainLoginModule.class.getName(), options);
         JaasContext jaasContext = new JaasContext("jaasContext", JaasContext.Type.SERVER, jaasConfig, null);
-        saslServer = new PlainSaslServer(jaasContext);
+        PlainServerCallbackHandler callbackHandler = new PlainServerCallbackHandler();
+        callbackHandler.configure(null, "PLAIN", jaasContext.configurationEntries());
+        saslServer = new PlainSaslServer(callbackHandler);
     }
 
     @Test
diff --git a/clients/src/test/java/org/apache/kafka/common/security/scram/ScramCredentialUtilsTest.java b/clients/src/test/java/org/apache/kafka/common/security/scram/internal/ScramCredentialUtilsTest.java
similarity index 97%
rename from clients/src/test/java/org/apache/kafka/common/security/scram/ScramCredentialUtilsTest.java
rename to clients/src/test/java/org/apache/kafka/common/security/scram/internal/ScramCredentialUtilsTest.java
index e9dd285..a1a1d20 100644
--- a/clients/src/test/java/org/apache/kafka/common/security/scram/ScramCredentialUtilsTest.java
+++ b/clients/src/test/java/org/apache/kafka/common/security/scram/internal/ScramCredentialUtilsTest.java
@@ -14,23 +14,23 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
-package org.apache.kafka.common.security.scram;
+package org.apache.kafka.common.security.scram.internal;
 
-import org.junit.Test;
+import java.security.NoSuchAlgorithmException;
+import java.util.Arrays;
 
+import org.apache.kafka.common.security.authenticator.CredentialCache;
+import org.apache.kafka.common.security.scram.ScramCredential;
+
+import org.junit.Before;
+import org.junit.Test;
 import static org.junit.Assert.assertEquals;
 import static org.junit.Assert.assertNotEquals;
 import static org.junit.Assert.assertNotNull;
 import static org.junit.Assert.assertNull;
-
-import java.security.NoSuchAlgorithmException;
-import java.util.Arrays;
-
 import static org.junit.Assert.assertArrayEquals;
 import static org.junit.Assert.assertTrue;
 
-import org.apache.kafka.common.security.authenticator.CredentialCache;
-import org.junit.Before;
 
 public class ScramCredentialUtilsTest {
 
diff --git a/clients/src/test/java/org/apache/kafka/common/security/scram/ScramFormatterTest.java b/clients/src/test/java/org/apache/kafka/common/security/scram/internal/ScramFormatterTest.java
similarity index 91%
rename from clients/src/test/java/org/apache/kafka/common/security/scram/ScramFormatterTest.java
rename to clients/src/test/java/org/apache/kafka/common/security/scram/internal/ScramFormatterTest.java
index a86e0dd..b06b039 100644
--- a/clients/src/test/java/org/apache/kafka/common/security/scram/ScramFormatterTest.java
+++ b/clients/src/test/java/org/apache/kafka/common/security/scram/internal/ScramFormatterTest.java
@@ -14,19 +14,18 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
-package org.apache.kafka.common.security.scram;
+package org.apache.kafka.common.security.scram.internal;
 
 import org.apache.kafka.common.utils.Base64;
-import org.junit.Test;
+import org.apache.kafka.common.security.scram.internal.ScramMessages.ClientFinalMessage;
+import org.apache.kafka.common.security.scram.internal.ScramMessages.ClientFirstMessage;
+import org.apache.kafka.common.security.scram.internal.ScramMessages.ServerFinalMessage;
+import org.apache.kafka.common.security.scram.internal.ScramMessages.ServerFirstMessage;
 
+import org.junit.Test;
 import static org.junit.Assert.assertArrayEquals;
 import static org.junit.Assert.assertEquals;
 
-import org.apache.kafka.common.security.scram.ScramMessages.ClientFinalMessage;
-import org.apache.kafka.common.security.scram.ScramMessages.ClientFirstMessage;
-import org.apache.kafka.common.security.scram.ScramMessages.ServerFinalMessage;
-import org.apache.kafka.common.security.scram.ScramMessages.ServerFirstMessage;
-
 public class ScramFormatterTest {
 
     /**
diff --git a/clients/src/test/java/org/apache/kafka/common/security/scram/ScramMessagesTest.java b/clients/src/test/java/org/apache/kafka/common/security/scram/internal/ScramMessagesTest.java
similarity index 96%
rename from clients/src/test/java/org/apache/kafka/common/security/scram/ScramMessagesTest.java
rename to clients/src/test/java/org/apache/kafka/common/security/scram/internal/ScramMessagesTest.java
index 7b04ede..d856f37 100644
--- a/clients/src/test/java/org/apache/kafka/common/security/scram/ScramMessagesTest.java
+++ b/clients/src/test/java/org/apache/kafka/common/security/scram/internal/ScramMessagesTest.java
@@ -14,29 +14,28 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
-package org.apache.kafka.common.security.scram;
-
-import org.apache.kafka.common.utils.Base64;
-import org.junit.Before;
-import org.junit.Test;
+package org.apache.kafka.common.security.scram.internal;
 
 import java.nio.charset.StandardCharsets;
 import java.util.Collections;
 
 import javax.security.sasl.SaslException;
 
+import org.apache.kafka.common.security.scram.internal.ScramMessages.AbstractScramMessage;
+import org.apache.kafka.common.security.scram.internal.ScramMessages.ClientFinalMessage;
+import org.apache.kafka.common.security.scram.internal.ScramMessages.ClientFirstMessage;
+import org.apache.kafka.common.security.scram.internal.ScramMessages.ServerFinalMessage;
+import org.apache.kafka.common.security.scram.internal.ScramMessages.ServerFirstMessage;
+import org.apache.kafka.common.utils.Base64;
+
+import org.junit.Before;
+import org.junit.Test;
 import static org.junit.Assert.assertArrayEquals;
 import static org.junit.Assert.assertEquals;
 import static org.junit.Assert.assertNull;
 import static org.junit.Assert.assertTrue;
 import static org.junit.Assert.fail;
 
-import org.apache.kafka.common.security.scram.ScramMessages.AbstractScramMessage;
-import org.apache.kafka.common.security.scram.ScramMessages.ClientFinalMessage;
-import org.apache.kafka.common.security.scram.ScramMessages.ClientFirstMessage;
-import org.apache.kafka.common.security.scram.ScramMessages.ServerFinalMessage;
-import org.apache.kafka.common.security.scram.ScramMessages.ServerFirstMessage;
-
 public class ScramMessagesTest {
 
     private static final String[] VALID_EXTENSIONS = {
diff --git a/clients/src/test/java/org/apache/kafka/common/security/scram/ScramSaslServerTest.java b/clients/src/test/java/org/apache/kafka/common/security/scram/internal/ScramSaslServerTest.java
similarity index 96%
rename from clients/src/test/java/org/apache/kafka/common/security/scram/ScramSaslServerTest.java
rename to clients/src/test/java/org/apache/kafka/common/security/scram/internal/ScramSaslServerTest.java
index 82ad914..3c4b82d 100644
--- a/clients/src/test/java/org/apache/kafka/common/security/scram/ScramSaslServerTest.java
+++ b/clients/src/test/java/org/apache/kafka/common/security/scram/internal/ScramSaslServerTest.java
@@ -14,19 +14,20 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
-package org.apache.kafka.common.security.scram;
+package org.apache.kafka.common.security.scram.internal;
 
-import org.apache.kafka.common.security.token.delegation.DelegationTokenCache;
-import org.junit.Before;
-import org.junit.Test;
 
 import java.nio.charset.StandardCharsets;
 import java.util.HashMap;
 
-import static org.junit.Assert.assertTrue;
-
 import org.apache.kafka.common.errors.SaslAuthenticationException;
 import org.apache.kafka.common.security.authenticator.CredentialCache;
+import org.apache.kafka.common.security.scram.ScramCredential;
+import org.apache.kafka.common.security.token.delegation.DelegationTokenCache;
+
+import org.junit.Before;
+import org.junit.Test;
+import static org.junit.Assert.assertTrue;
 
 public class ScramSaslServerTest {
 
diff --git a/core/src/main/scala/kafka/admin/ConfigCommand.scala b/core/src/main/scala/kafka/admin/ConfigCommand.scala
index 3563448..c19599d 100644
--- a/core/src/main/scala/kafka/admin/ConfigCommand.scala
+++ b/core/src/main/scala/kafka/admin/ConfigCommand.scala
@@ -32,7 +32,7 @@ import org.apache.kafka.clients.CommonClientConfigs
 import org.apache.kafka.clients.admin.{AlterConfigsOptions, ConfigEntry, DescribeConfigsOptions, AdminClient => JAdminClient, Config => JConfig}
 import org.apache.kafka.common.config.ConfigResource
 import org.apache.kafka.common.security.JaasUtils
-import org.apache.kafka.common.security.scram._
+import org.apache.kafka.common.security.scram.internal.{ScramCredentialUtils, ScramFormatter, ScramMechanism}
 import org.apache.kafka.common.utils.{Sanitizer, Time, Utils}
 
 import scala.collection._
diff --git a/core/src/main/scala/kafka/security/CredentialProvider.scala b/core/src/main/scala/kafka/security/CredentialProvider.scala
index 0e7ebb6..6f9c252 100644
--- a/core/src/main/scala/kafka/security/CredentialProvider.scala
+++ b/core/src/main/scala/kafka/security/CredentialProvider.scala
@@ -20,9 +20,10 @@ package kafka.security
 import java.util.{Collection, Properties}
 
 import org.apache.kafka.common.security.authenticator.CredentialCache
-import org.apache.kafka.common.security.scram.{ScramCredential, ScramCredentialUtils, ScramMechanism}
+import org.apache.kafka.common.security.scram.ScramCredential
 import org.apache.kafka.common.config.ConfigDef
 import org.apache.kafka.common.config.ConfigDef._
+import org.apache.kafka.common.security.scram.internal.{ScramCredentialUtils, ScramMechanism}
 import org.apache.kafka.common.security.token.delegation.DelegationTokenCache
 
 class CredentialProvider(scramMechanisms: Collection[String], val tokenCache: DelegationTokenCache) {
diff --git a/core/src/main/scala/kafka/server/DelegationTokenManager.scala b/core/src/main/scala/kafka/server/DelegationTokenManager.scala
index 008dc32..4a947a1 100644
--- a/core/src/main/scala/kafka/server/DelegationTokenManager.scala
+++ b/core/src/main/scala/kafka/server/DelegationTokenManager.scala
@@ -29,7 +29,8 @@ import kafka.utils.{CoreUtils, Json, Logging}
 import kafka.zk.{DelegationTokenChangeNotificationSequenceZNode, DelegationTokenChangeNotificationZNode, DelegationTokensZNode, KafkaZkClient}
 import org.apache.kafka.common.protocol.Errors
 import org.apache.kafka.common.security.auth.KafkaPrincipal
-import org.apache.kafka.common.security.scram.{ScramCredential, ScramFormatter, ScramMechanism}
+import org.apache.kafka.common.security.scram.internal.{ScramFormatter, ScramMechanism}
+import org.apache.kafka.common.security.scram.ScramCredential
 import org.apache.kafka.common.security.token.delegation.{DelegationToken, DelegationTokenCache, TokenInformation}
 import org.apache.kafka.common.utils.{Base64, Sanitizer, SecurityUtils, Time}
 
diff --git a/core/src/main/scala/kafka/server/DynamicConfigManager.scala b/core/src/main/scala/kafka/server/DynamicConfigManager.scala
index 728f88c..d56b46f 100644
--- a/core/src/main/scala/kafka/server/DynamicConfigManager.scala
+++ b/core/src/main/scala/kafka/server/DynamicConfigManager.scala
@@ -22,9 +22,9 @@ import java.nio.charset.StandardCharsets
 import kafka.common.{NotificationHandler, ZkNodeChangeNotificationListener}
 import kafka.utils.{Json, Logging}
 import kafka.utils.json.JsonObject
-import kafka.zk.{KafkaZkClient, AdminZkClient, ConfigEntityChangeNotificationZNode, ConfigEntityChangeNotificationSequenceZNode}
+import kafka.zk.{AdminZkClient, ConfigEntityChangeNotificationSequenceZNode, ConfigEntityChangeNotificationZNode, KafkaZkClient}
 import org.apache.kafka.common.config.types.Password
-import org.apache.kafka.common.security.scram.ScramMechanism
+import org.apache.kafka.common.security.scram.internal.ScramMechanism
 import org.apache.kafka.common.utils.Time
 
 import scala.collection.JavaConverters._
diff --git a/core/src/main/scala/kafka/server/KafkaConfig.scala b/core/src/main/scala/kafka/server/KafkaConfig.scala
index 5a1dca3..78aac68 100755
--- a/core/src/main/scala/kafka/server/KafkaConfig.scala
+++ b/core/src/main/scala/kafka/server/KafkaConfig.scala
@@ -425,6 +425,10 @@ object KafkaConfig {
   val SaslMechanismInterBrokerProtocolProp = "sasl.mechanism.inter.broker.protocol"
   val SaslJaasConfigProp = SaslConfigs.SASL_JAAS_CONFIG
   val SaslEnabledMechanismsProp = BrokerSecurityConfigs.SASL_ENABLED_MECHANISMS_CONFIG
+  val SaslServerCallbackHandlerClassProp = BrokerSecurityConfigs.SASL_SERVER_CALLBACK_HANDLER_CLASS
+  val SaslClientCallbackHandlerClassProp = SaslConfigs.SASL_CLIENT_CALLBACK_HANDLER_CLASS
+  val SaslLoginClassProp = SaslConfigs.SASL_LOGIN_CLASS
+  val SaslLoginCallbackHandlerClassProp = SaslConfigs.SASL_LOGIN_CALLBACK_HANDLER_CLASS
   val SaslKerberosServiceNameProp = SaslConfigs.SASL_KERBEROS_SERVICE_NAME
   val SaslKerberosKinitCmdProp = SaslConfigs.SASL_KERBEROS_KINIT_CMD
   val SaslKerberosTicketRenewWindowFactorProp = SaslConfigs.SASL_KERBEROS_TICKET_RENEW_WINDOW_FACTOR
@@ -713,7 +717,11 @@ object KafkaConfig {
   /** ********* Sasl Configuration ****************/
   val SaslMechanismInterBrokerProtocolDoc = "SASL mechanism used for inter-broker communication. Default is GSSAPI."
   val SaslJaasConfigDoc = SaslConfigs.SASL_JAAS_CONFIG_DOC
-  val SaslEnabledMechanismsDoc = SaslConfigs.SASL_ENABLED_MECHANISMS_DOC
+  val SaslEnabledMechanismsDoc = BrokerSecurityConfigs.SASL_ENABLED_MECHANISMS_DOC
+  val SaslServerCallbackHandlerClassDoc = BrokerSecurityConfigs.SASL_SERVER_CALLBACK_HANDLER_CLASS_DOC
+  val SaslClientCallbackHandlerClassDoc = SaslConfigs.SASL_CLIENT_CALLBACK_HANDLER_CLASS_DOC
+  val SaslLoginClassDoc = SaslConfigs.SASL_LOGIN_CLASS_DOC
+  val SaslLoginCallbackHandlerClassDoc = SaslConfigs.SASL_LOGIN_CALLBACK_HANDLER_CLASS_DOC
   val SaslKerberosServiceNameDoc = SaslConfigs.SASL_KERBEROS_SERVICE_NAME_DOC
   val SaslKerberosKinitCmdDoc = SaslConfigs.SASL_KERBEROS_KINIT_CMD_DOC
   val SaslKerberosTicketRenewWindowFactorDoc = SaslConfigs.SASL_KERBEROS_TICKET_RENEW_WINDOW_FACTOR_DOC
@@ -937,6 +945,10 @@ object KafkaConfig {
       .define(SaslMechanismInterBrokerProtocolProp, STRING, Defaults.SaslMechanismInterBrokerProtocol, MEDIUM, SaslMechanismInterBrokerProtocolDoc)
       .define(SaslJaasConfigProp, PASSWORD, null, MEDIUM, SaslJaasConfigDoc)
       .define(SaslEnabledMechanismsProp, LIST, Defaults.SaslEnabledMechanisms, MEDIUM, SaslEnabledMechanismsDoc)
+      .define(SaslServerCallbackHandlerClassProp, CLASS, null, MEDIUM, SaslServerCallbackHandlerClassDoc)
+      .define(SaslClientCallbackHandlerClassProp, CLASS, null, MEDIUM, SaslClientCallbackHandlerClassDoc)
+      .define(SaslLoginClassProp, CLASS, null, MEDIUM, SaslLoginClassDoc)
+      .define(SaslLoginCallbackHandlerClassProp, CLASS, null, MEDIUM, SaslLoginCallbackHandlerClassDoc)
       .define(SaslKerberosServiceNameProp, STRING, null, MEDIUM, SaslKerberosServiceNameDoc)
       .define(SaslKerberosKinitCmdProp, STRING, Defaults.SaslKerberosKinitCmd, MEDIUM, SaslKerberosKinitCmdDoc)
       .define(SaslKerberosTicketRenewWindowFactorProp, DOUBLE, Defaults.SaslKerberosTicketRenewWindowFactor, MEDIUM, SaslKerberosTicketRenewWindowFactorDoc)
diff --git a/core/src/main/scala/kafka/server/KafkaServer.scala b/core/src/main/scala/kafka/server/KafkaServer.scala
index d7ca656..53632cd 100755
--- a/core/src/main/scala/kafka/server/KafkaServer.scala
+++ b/core/src/main/scala/kafka/server/KafkaServer.scala
@@ -44,7 +44,7 @@ import org.apache.kafka.common.network._
 import org.apache.kafka.common.protocol.Errors
 import org.apache.kafka.common.requests.{ControlledShutdownRequest, ControlledShutdownResponse}
 import org.apache.kafka.common.security.auth.SecurityProtocol
-import org.apache.kafka.common.security.scram.ScramMechanism
+import org.apache.kafka.common.security.scram.internal.ScramMechanism
 import org.apache.kafka.common.security.token.delegation.DelegationTokenCache
 import org.apache.kafka.common.security.{JaasContext, JaasUtils}
 import org.apache.kafka.common.utils.{AppInfoParser, LogContext, Time}
diff --git a/core/src/main/scala/kafka/utils/VerifiableProperties.scala b/core/src/main/scala/kafka/utils/VerifiableProperties.scala
index de4f654..5d70db5 100755
--- a/core/src/main/scala/kafka/utils/VerifiableProperties.scala
+++ b/core/src/main/scala/kafka/utils/VerifiableProperties.scala
@@ -122,7 +122,7 @@ class VerifiableProperties(val props: Properties) extends Logging {
     require(v >= range._1 && v <= range._2, name + " has value " + v + " which is not in the range " + range + ".")
     v
   }
-  
+
   /**
    * Get a required argument as a double
    * @param name The property name
@@ -130,7 +130,7 @@ class VerifiableProperties(val props: Properties) extends Logging {
    * @throws IllegalArgumentException If the given property is not present
    */
   def getDouble(name: String): Double = getString(name).toDouble
-  
+
   /**
    * Get an optional argument as a double
    * @param name The property name
@@ -141,7 +141,7 @@ class VerifiableProperties(val props: Properties) extends Logging {
       getDouble(name)
     else
       default
-  } 
+  }
 
   /**
    * Read a boolean value from the properties instance
@@ -158,7 +158,7 @@ class VerifiableProperties(val props: Properties) extends Logging {
       v.toBoolean
     }
   }
-  
+
   def getBoolean(name: String) = getString(name).toBoolean
 
   /**
@@ -178,7 +178,7 @@ class VerifiableProperties(val props: Properties) extends Logging {
     require(containsKey(name), "Missing required property '" + name + "'")
     getProperty(name)
   }
-  
+
   /**
    * Get a Map[String, String] from a property list in the form k1:v2, k2:v2, ...
    */
diff --git a/core/src/test/scala/integration/kafka/api/DelegationTokenEndToEndAuthorizationTest.scala b/core/src/test/scala/integration/kafka/api/DelegationTokenEndToEndAuthorizationTest.scala
index 24f435e..27a6d11 100644
--- a/core/src/test/scala/integration/kafka/api/DelegationTokenEndToEndAuthorizationTest.scala
+++ b/core/src/test/scala/integration/kafka/api/DelegationTokenEndToEndAuthorizationTest.scala
@@ -24,7 +24,7 @@ import kafka.utils.{JaasTestUtils, TestUtils, ZkUtils}
 import org.apache.kafka.clients.admin.AdminClientConfig
 import org.apache.kafka.common.config.SaslConfigs
 import org.apache.kafka.common.security.auth.SecurityProtocol
-import org.apache.kafka.common.security.scram.ScramMechanism
+import org.apache.kafka.common.security.scram.internal.ScramMechanism
 import org.apache.kafka.common.security.token.delegation.DelegationToken
 import org.junit.Before
 
diff --git a/core/src/test/scala/integration/kafka/api/SaslEndToEndAuthorizationTest.scala b/core/src/test/scala/integration/kafka/api/SaslEndToEndAuthorizationTest.scala
index f5409b1..a5bf331 100644
--- a/core/src/test/scala/integration/kafka/api/SaslEndToEndAuthorizationTest.scala
+++ b/core/src/test/scala/integration/kafka/api/SaslEndToEndAuthorizationTest.scala
@@ -63,6 +63,7 @@ abstract class SaslEndToEndAuthorizationTest extends EndToEndAuthorizationTest {
     consumer2Config ++= consumerConfig
     // consumer2 retrieves its credentials from the static JAAS configuration, so we test also this path
     consumer2Config.remove(SaslConfigs.SASL_JAAS_CONFIG)
+    consumer2Config.remove(SaslConfigs.SASL_CLIENT_CALLBACK_HANDLER_CLASS)
 
     val consumer2 = TestUtils.createNewConsumer(brokerList,
                                                 securityProtocol = securityProtocol,
diff --git a/core/src/test/scala/integration/kafka/api/SaslPlainSslEndToEndAuthorizationTest.scala b/core/src/test/scala/integration/kafka/api/SaslPlainSslEndToEndAuthorizationTest.scala
index 08351aa..efb8c48 100644
--- a/core/src/test/scala/integration/kafka/api/SaslPlainSslEndToEndAuthorizationTest.scala
+++ b/core/src/test/scala/integration/kafka/api/SaslPlainSslEndToEndAuthorizationTest.scala
@@ -16,22 +16,33 @@
   */
 package kafka.api
 
-import kafka.utils.{CoreUtils, JaasTestUtils, TestUtils, ZkUtils}
+import java.security.AccessController
+import javax.security.auth.callback._
+import javax.security.auth.Subject
+import javax.security.auth.login.AppConfigurationEntry
+
+import kafka.server.KafkaConfig
+import kafka.utils.{CoreUtils, TestUtils, ZkUtils}
+import kafka.utils.JaasTestUtils._
+import org.apache.kafka.common.config.SaslConfigs
 import org.apache.kafka.common.config.internals.BrokerSecurityConfigs
+import org.apache.kafka.common.network.ListenerName
 import org.apache.kafka.common.security.JaasUtils
-import org.apache.kafka.common.security.auth.{AuthenticationContext, KafkaPrincipal, KafkaPrincipalBuilder, SaslAuthenticationContext}
+import org.apache.kafka.common.security.auth._
+import org.apache.kafka.common.security.plain.PlainAuthenticateCallback
 import org.junit.Test
 
 object SaslPlainSslEndToEndAuthorizationTest {
+
   class TestPrincipalBuilder extends KafkaPrincipalBuilder {
 
     override def build(context: AuthenticationContext): KafkaPrincipal = {
       context match {
         case ctx: SaslAuthenticationContext =>
           ctx.server.getAuthorizationID match {
-            case JaasTestUtils.KafkaPlainAdmin =>
+            case KafkaPlainAdmin =>
               new KafkaPrincipal(KafkaPrincipal.USER_TYPE, "admin")
-            case JaasTestUtils.KafkaPlainUser =>
+            case KafkaPlainUser =>
               new KafkaPrincipal(KafkaPrincipal.USER_TYPE, "user")
             case _ =>
               KafkaPrincipal.ANONYMOUS
@@ -39,18 +50,84 @@ object SaslPlainSslEndToEndAuthorizationTest {
       }
     }
   }
+
+  object Credentials {
+    val allUsers = Map(KafkaPlainUser -> "user1-password",
+      KafkaPlainUser2 -> KafkaPlainPassword2,
+      KafkaPlainAdmin -> "broker-password")
+  }
+
+  class TestServerCallbackHandler extends AuthenticateCallbackHandler {
+    def configure(configs: java.util.Map[String, _], saslMechanism: String, jaasConfigEntries: java.util.List[AppConfigurationEntry]) {}
+    def handle(callbacks: Array[Callback]) {
+      var username: String = null
+      for (callback <- callbacks) {
+        if (callback.isInstanceOf[NameCallback])
+          username = callback.asInstanceOf[NameCallback].getDefaultName
+        else if (callback.isInstanceOf[PlainAuthenticateCallback]) {
+          val plainCallback = callback.asInstanceOf[PlainAuthenticateCallback]
+          plainCallback.authenticated(Credentials.allUsers(username) == new String(plainCallback.password))
+        } else
+          throw new UnsupportedCallbackException(callback)
+      }
+    }
+    def close() {}
+  }
+
+  class TestClientCallbackHandler extends AuthenticateCallbackHandler {
+    def configure(configs: java.util.Map[String, _], saslMechanism: String, jaasConfigEntries: java.util.List[AppConfigurationEntry]) {}
+    def handle(callbacks: Array[Callback]) {
+      val subject = Subject.getSubject(AccessController.getContext())
+      val username = subject.getPublicCredentials(classOf[String]).iterator().next()
+      for (callback <- callbacks) {
+        if (callback.isInstanceOf[NameCallback])
+          callback.asInstanceOf[NameCallback].setName(username)
+        else if (callback.isInstanceOf[PasswordCallback]) {
+          if (username == KafkaPlainUser || username == KafkaPlainAdmin)
+            callback.asInstanceOf[PasswordCallback].setPassword(Credentials.allUsers(username).toCharArray)
+        } else
+          throw new UnsupportedCallbackException(callback)
+      }
+    }
+    def close() {}
+  }
 }
 
+
+// This test uses SASL callback handler overrides for server connections of Kafka broker
+// and client connections of Kafka producers and consumers. Client connections from Kafka brokers
+// used for inter-broker communication also use custom callback handlers. The second client used in
+// the multi-user test SaslEndToEndAuthorizationTest#testTwoConsumersWithDifferentSaslCredentials uses
+// static JAAS configuration with default callback handlers to test those code paths as well.
 class SaslPlainSslEndToEndAuthorizationTest extends SaslEndToEndAuthorizationTest {
-  import SaslPlainSslEndToEndAuthorizationTest.TestPrincipalBuilder
+  import SaslPlainSslEndToEndAuthorizationTest._
 
   this.serverConfig.setProperty(BrokerSecurityConfigs.PRINCIPAL_BUILDER_CLASS_CONFIG, classOf[TestPrincipalBuilder].getName)
+  this.serverConfig.put(KafkaConfig.SaslClientCallbackHandlerClassProp, classOf[TestClientCallbackHandler].getName)
+  val mechanismPrefix = ListenerName.forSecurityProtocol(SecurityProtocol.SASL_SSL).saslMechanismConfigPrefix("PLAIN")
+  this.serverConfig.put(s"$mechanismPrefix${KafkaConfig.SaslServerCallbackHandlerClassProp}", classOf[TestServerCallbackHandler].getName)
+  this.producerConfig.put(SaslConfigs.SASL_CLIENT_CALLBACK_HANDLER_CLASS, classOf[TestClientCallbackHandler].getName)
+  this.consumerConfig.put(SaslConfigs.SASL_CLIENT_CALLBACK_HANDLER_CLASS, classOf[TestClientCallbackHandler].getName)
+  private val plainLogin = s"org.apache.kafka.common.security.plain.PlainLoginModule username=$KafkaPlainUser required;"
+  this.producerConfig.put(SaslConfigs.SASL_JAAS_CONFIG, plainLogin)
+  this.consumerConfig.put(SaslConfigs.SASL_JAAS_CONFIG, plainLogin)
 
   override protected def kafkaClientSaslMechanism = "PLAIN"
   override protected def kafkaServerSaslMechanisms = List("PLAIN")
+
   override val clientPrincipal = "user"
   override val kafkaPrincipal = "admin"
 
+  override def jaasSections(kafkaServerSaslMechanisms: Seq[String],
+                            kafkaClientSaslMechanism: Option[String],
+                            mode: SaslSetupMode,
+                            kafkaServerEntryName: String): Seq[JaasSection] = {
+    val brokerLogin = new PlainLoginModule(KafkaPlainAdmin, "") // Password provided by callback handler
+    val clientLogin = new PlainLoginModule(KafkaPlainUser2, KafkaPlainPassword2)
+    Seq(JaasSection(kafkaServerEntryName, Seq(brokerLogin)),
+      JaasSection(KafkaClientContextName, Seq(clientLogin))) ++ zkSections
+  }
+
   /**
    * Checks that secure paths created by broker and acl paths created by AclCommand
    * have expected ACLs.
diff --git a/core/src/test/scala/integration/kafka/api/SaslScramSslEndToEndAuthorizationTest.scala b/core/src/test/scala/integration/kafka/api/SaslScramSslEndToEndAuthorizationTest.scala
index 000fc21..d304ffc 100644
--- a/core/src/test/scala/integration/kafka/api/SaslScramSslEndToEndAuthorizationTest.scala
+++ b/core/src/test/scala/integration/kafka/api/SaslScramSslEndToEndAuthorizationTest.scala
@@ -16,9 +16,9 @@
   */
 package kafka.api
 
-import org.apache.kafka.common.security.scram.ScramMechanism
 import kafka.utils.JaasTestUtils
 import kafka.zk.ConfigEntityChangeNotificationZNode
+import org.apache.kafka.common.security.scram.internal.ScramMechanism
 
 import scala.collection.JavaConverters._
 import org.junit.Before
diff --git a/core/src/test/scala/integration/kafka/api/SaslSetup.scala b/core/src/test/scala/integration/kafka/api/SaslSetup.scala
index 273b247..ab2819e 100644
--- a/core/src/test/scala/integration/kafka/api/SaslSetup.scala
+++ b/core/src/test/scala/integration/kafka/api/SaslSetup.scala
@@ -30,7 +30,7 @@ import org.apache.kafka.common.config.SaslConfigs
 import org.apache.kafka.common.config.internals.BrokerSecurityConfigs
 import org.apache.kafka.common.security.JaasUtils
 import org.apache.kafka.common.security.authenticator.LoginManager
-import org.apache.kafka.common.security.scram.ScramMechanism
+import org.apache.kafka.common.security.scram.internal.ScramMechanism
 
 /*
  * Implements an enumeration for the modes enabled here:
diff --git a/core/src/test/scala/unit/kafka/admin/ConfigCommandTest.scala b/core/src/test/scala/unit/kafka/admin/ConfigCommandTest.scala
index a17f060..66e98f5 100644
--- a/core/src/test/scala/unit/kafka/admin/ConfigCommandTest.scala
+++ b/core/src/test/scala/unit/kafka/admin/ConfigCommandTest.scala
@@ -28,7 +28,7 @@ import org.apache.kafka.clients.admin._
 import org.apache.kafka.common.config.ConfigResource
 import org.apache.kafka.common.internals.KafkaFutureImpl
 import org.apache.kafka.common.Node
-import org.apache.kafka.common.security.scram.ScramCredentialUtils
+import org.apache.kafka.common.security.scram.internal.ScramCredentialUtils
 import org.apache.kafka.common.utils.Sanitizer
 import org.easymock.EasyMock
 import org.junit.Assert._
diff --git a/core/src/test/scala/unit/kafka/network/SocketServerTest.scala b/core/src/test/scala/unit/kafka/network/SocketServerTest.scala
index 0dad3c7..e6dadbb 100644
--- a/core/src/test/scala/unit/kafka/network/SocketServerTest.scala
+++ b/core/src/test/scala/unit/kafka/network/SocketServerTest.scala
@@ -38,7 +38,7 @@ import org.apache.kafka.common.protocol.{ApiKeys, Errors}
 import org.apache.kafka.common.record.MemoryRecords
 import org.apache.kafka.common.requests.{AbstractRequest, ProduceRequest, RequestHeader}
 import org.apache.kafka.common.security.auth.{KafkaPrincipal, SecurityProtocol}
-import org.apache.kafka.common.security.scram.ScramMechanism
+import org.apache.kafka.common.security.scram.internal.ScramMechanism
 import org.apache.kafka.common.utils.{LogContext, MockTime, Time}
 import org.apache.log4j.Level
 import org.junit.Assert._
diff --git a/core/src/test/scala/unit/kafka/security/token/delegation/DelegationTokenManagerTest.scala b/core/src/test/scala/unit/kafka/security/token/delegation/DelegationTokenManagerTest.scala
index 8c03548..b8388b4 100644
--- a/core/src/test/scala/unit/kafka/security/token/delegation/DelegationTokenManagerTest.scala
+++ b/core/src/test/scala/unit/kafka/security/token/delegation/DelegationTokenManagerTest.scala
@@ -29,7 +29,7 @@ import kafka.utils.TestUtils
 import kafka.zk.ZooKeeperTestHarness
 import org.apache.kafka.common.protocol.Errors
 import org.apache.kafka.common.security.auth.KafkaPrincipal
-import org.apache.kafka.common.security.scram.ScramMechanism
+import org.apache.kafka.common.security.scram.internal.ScramMechanism
 import org.apache.kafka.common.security.token.delegation.{DelegationToken, DelegationTokenCache, TokenInformation}
 import org.apache.kafka.common.utils.{MockTime, SecurityUtils}
 import org.junit.Assert._
diff --git a/core/src/test/scala/unit/kafka/server/KafkaConfigTest.scala b/core/src/test/scala/unit/kafka/server/KafkaConfigTest.scala
index 0213c12..81470c0 100755
--- a/core/src/test/scala/unit/kafka/server/KafkaConfigTest.scala
+++ b/core/src/test/scala/unit/kafka/server/KafkaConfigTest.scala
@@ -692,6 +692,10 @@ class KafkaConfigTest {
         //Sasl Configs
         case KafkaConfig.SaslMechanismInterBrokerProtocolProp => // ignore
         case KafkaConfig.SaslEnabledMechanismsProp =>
+        case KafkaConfig.SaslClientCallbackHandlerClassProp =>
+        case KafkaConfig.SaslServerCallbackHandlerClassProp =>
+        case KafkaConfig.SaslLoginClassProp =>
+        case KafkaConfig.SaslLoginCallbackHandlerClassProp =>
         case KafkaConfig.SaslKerberosServiceNameProp => // ignore string
         case KafkaConfig.SaslKerberosKinitCmdProp =>
         case KafkaConfig.SaslKerberosTicketRenewWindowFactorProp =>
diff --git a/core/src/test/scala/unit/kafka/utils/JaasTestUtils.scala b/core/src/test/scala/unit/kafka/utils/JaasTestUtils.scala
index 9ce3b01..10c7345 100644
--- a/core/src/test/scala/unit/kafka/utils/JaasTestUtils.scala
+++ b/core/src/test/scala/unit/kafka/utils/JaasTestUtils.scala
@@ -114,7 +114,7 @@ object JaasTestUtils {
   val KafkaServerContextName = "KafkaServer"
   val KafkaServerPrincipalUnqualifiedName = "kafka"
   private val KafkaServerPrincipal = KafkaServerPrincipalUnqualifiedName + "/localhost@EXAMPLE.COM"
-  private val KafkaClientContextName = "KafkaClient"
+  val KafkaClientContextName = "KafkaClient"
   val KafkaClientPrincipalUnqualifiedName = "client"
   private val KafkaClientPrincipal = KafkaClientPrincipalUnqualifiedName + "@EXAMPLE.COM"
   val KafkaClientPrincipalUnqualifiedName2 = "client2"

-- 
To stop receiving notification emails like this one, please contact
rsivaram@apache.org.

Mime
View raw message