incubator-heraldry-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From ket...@apache.org
Subject svn commit: r493368 - in /incubator/heraldry/libraries/python/openid/trunk/openid: consumer/consumer.py test/test_association_response.py
Date Sat, 06 Jan 2007 05:25:35 GMT
Author: keturn
Date: Fri Jan  5 21:25:34 2007
New Revision: 493368

URL: http://svn.apache.org/viewvc?view=rev&rev=493368
Log:
[python-to-heraldry @ Re-worked consumer association creation internal API]
XXX: this broke the (untested) association negotiation. The ticket is
still outstanding, though, so I feel OK about it.

Original author: Josh Hoyt <josh@janrain.com>
Date: 2006-12-22 21:40:38+00:00

Modified:
    incubator/heraldry/libraries/python/openid/trunk/openid/consumer/consumer.py
    incubator/heraldry/libraries/python/openid/trunk/openid/test/test_association_response.py

Modified: incubator/heraldry/libraries/python/openid/trunk/openid/consumer/consumer.py
URL: http://svn.apache.org/viewvc/incubator/heraldry/libraries/python/openid/trunk/openid/consumer/consumer.py?view=diff&rev=493368&r1=493367&r2=493368
==============================================================================
--- incubator/heraldry/libraries/python/openid/trunk/openid/consumer/consumer.py (original)
+++ incubator/heraldry/libraries/python/openid/trunk/openid/consumer/consumer.py Fri Jan  5 21:25:34 2007
@@ -420,6 +420,10 @@
         self.supported_assoc_type = supported_assoc_type
         self.supported_session_type = supported_session_type
 
+class ProtocolError(ValueError):
+    """Exception that indicates that a message violated the
+    protocol. It is raised and caught internally to this file."""
+
 class GenericConsumer(object):
     """This is the implementation of the common logic for OpenID
     consumers. It is unaware of the application in which it is
@@ -765,57 +769,101 @@
             return False
 
     def _getAssociation(self, endpoint):
+        """Get an association for the endpoint's server_url.
+
+        First try seeing if we have a good association in the
+        store. If we do not, then attempt to negotiate an association
+        with the server.
+
+        If we negotiate a good association, it will get stored.
+
+        @returns: A valid association for the endpoint's server_url or None
+        @rtype: openid.association.Association or NoneType
+        """
         if self.store.isDumb():
             return None
 
         assoc = self.store.getAssociation(endpoint.server_url)
 
         if assoc is None or assoc.expiresIn <= 0:
-            (assoc_type, session_type) = self.negotiator.getAllowedType()
-            tried_types = []
-            while (assoc_type, session_type) not in tried_types:
-                assoc_session, args = self._createAssociateRequest(
-                    endpoint,
-                    assoc_type,
-                    session_type)
+            assoc = self._negotiateAssociation(endpoint)
+            if assoc is not None:
+                self.store.storeAssociation(endpoint.server_url, assoc)
+
+        return assoc
+
+    def _negotiateAssociation(self, endpoint):
+        """Make association requests to the server, attempting to
+        create a new association.
+        """
+        # Get our preferred session/association type from the negotiator.
+        assoc_type, session_type = self.negotiator.getAllowedType()
+
+        try:
+            assoc = self._requestAssociation(
+                endpoint, assoc_type, session_type)
+        except UnsupportedAssocType, why:
+            # The server didn't like the association/session type that
+            # we sent, and it sent us back a message that might tell
+            # us how to handle it.
+            oidutil.log('Unsupported association type: %s' % (why.message,))
+
+            assoc_type = why.supported_assoc_type
+            session_type = why.supported_session_type
+
+            if assoc_type is None or session_type is None:
+                oidutil.log('Server responded with unsupported association '
+                            'session but did not supply a fallback.')
+                return None
+            elif not self.negotiator.isAllowed(assoc_type, session_type):
+                fmt = ('Server sent unsupported session/association type: '
+                       'session_type=%s, assoc_type=%s')
+                oidutil.log(fmt % (session_type, assoc_type))
+                return None
+            else:
                 try:
-                    response = self._makeKVPost(args, endpoint.server_url)
-                except fetchers.HTTPFetchingError, why:
-                    oidutil.log('openid.associate request failed: %s' %
-                                (why[0],))
+                    assoc = self._requestAssociation(
+                        endpoint, assoc_type, session_type)
+                except UnsupportedAssocType, why:
+                    # Do not keep trying, since it rejected the
+                    # association type that it told us to use.
+                    oidutil.log('Server %s refused its suggested association '
+                                'type: session_type=%s, assoc_type=%s'
+                                % (session_type, assoc_type))
                     assoc = None
-                    break
-                else:
-                    try:
-                        assoc = self._parseAssociation(
-                            response, assoc_session, endpoint.server_url)
-                    except UnsupportedAssocType, why:
-                        oidutil.log(
-                            'Unsupported assoc type: %s' % (why.message,))
-                        assoc_type = why.supported_assoc_type
-                        session_type = why.supported_session_type
-                        if assoc_type is None or session_type is None:
-                            oidutil.log('No allowed session type specified')
-                            assoc = None
-                            break
-
-                        if not self.negotiator.isAllowed(assoc_type,
-                                                         session_type):
-                            msg = (
-                                'Server sent unsupported session type %s '
-                                'for association type %s'
-                                ) % (session_type, assoc_type)
-                            oidutil.log(msg)
-                            assoc = None
-                            break
-                    else:
-                        break
-            else:
-                fmt = 'No association created. Tried types: %s'
-                oidutil.log(fmt % (tried_types,))
-                assoc = None
+        else:
+            return assoc
 
-        return assoc
+    def _requestAssociation(self, endpoint, assoc_type, session_type):
+        """Make and process one association request to this endpoint's
+        OP endpoint URL.
+
+        @returns: An association object or None if the association
+            processing failed.
+
+        @raises: UnsupportedAssocType XXX
+        """
+        assoc_session, args = self._createAssociateRequest(
+            endpoint, assoc_type, session_type)
+
+        try:
+            response = self._makeKVPost(args, endpoint.server_url)
+        except fetchers.HTTPFetchingError, why:
+            oidutil.log('openid.associate request failed: %s' % (why[0],))
+            return None
+
+        try:
+            assoc = self._extractAssociation(response, assoc_session)
+        except KeyError, why:
+            oidutil.log('Missing required parameter in response from %s: %s'
+                        % (endpoint.server_url, why[0]))
+            return None
+        except ProtocolError, why:
+            oidutil.log('Protocol error parsing response from %s: %s' % (
+                endpoint.server_url, why[0]))
+            return None
+        else:
+            return assoc
 
     def _createAssociateRequest(self, endpoint, assoc_type, session_type):
         session_type_class = self.session_types[session_type]
@@ -863,30 +911,31 @@
 
         return session_type
 
-    def _parseAssociation(self,
-                          association_response, assoc_session, server_url):
-        error_code = association_response.getArg(OPENID2_NS, 'error_code')
-        if error_code is not None:
-            return self._associateError(association_response)
-
+    def _extractAssociation(self, association_response, assoc_session):
+        # Extract the common fields from the response, raising an
+        # exception if they are not found
+        assoc_type = association_response.getArg(
+            OPENID_NS, 'assoc_type', no_default)
+        assoc_handle = association_response.getArg(
+            OPENID_NS, 'assoc_handle', no_default)
+
+        # expires_in is a base-10 string. The Python parsing will
+        # accept literals that have whitespace around them and will
+        # accept negative values. Neither of these are really in-spec,
+        # but we think it's OK to accept them.
+        expires_in_str = association_response.getArg(
+            OPENID_NS, 'expires_in', no_default)
         try:
-            assoc_type = association_response.getArg(
-                OPENID_NS, 'assoc_type', no_default)
-            assoc_handle = association_response.getArg(
-                OPENID_NS, 'assoc_handle', no_default)
-            expires_in_str = association_response.getArg(
-                OPENID_NS, 'expires_in', no_default)
-
-            if association_response.isOpenID1():
-                session_type = self._getOpenID1SessionType(
-                    association_response)
-            else:
-                session_type = association_response.getArg(
-                    OPENID2_NS, 'session_type', no_default)
-        except KeyError, e:
-            fmt = 'Getting association: missing key in response from %s: %s'
-            oidutil.log(fmt % (server_url, e[0]))
-            return None
+            expires_in = int(expires_in_str)
+        except ValueError, e:
+            raise ProtocolError('Invalid expires_in field: %s' % (e[0],))
+
+        # OpenID 1 has funny association session behaviour.
+        if association_response.isOpenID1():
+            session_type = self._getOpenID1SessionType(association_response)
+        else:
+            session_type = association_response.getArg(
+                OPENID2_NS, 'session_type', no_default)
 
         # Session type mismatch
         if assoc_session.session_type != session_type:
@@ -902,60 +951,26 @@
                 # Any other mismatch, regardless of protocol version
                 # results in the failure of the association session
                 # altogether.
-                message = 'Session type mismatch. Expected %r, got %r' % (
-                    assoc_session.session_type, session_type)
-                oidutil.log(message)
-                return None
+                fmt = 'Session type mismatch. Expected %r, got %r'
+                message = fmt % (assoc_session.session_type, session_type)
+                raise ProtocolError(message)
 
         # Make sure assoc_type is valid for session_type
         if assoc_type not in assoc_session.allowed_assoc_types:
-            msg = (
-                'Unsupported assoc_type for session %s returned '
-                'from server %s: %s'
-                ) % (server_url, assoc_session.session_type, assoc_type)
-            oidutil.log(msg)
-            return None
-
-        try:
-            expires_in = int(expires_in_str)
-        except ValueError, e:
-            fmt = 'Getting Association: invalid expires_in field: %s'
-            oidutil.log(fmt % (e[0],))
-            return None
+            fmt = 'Unsupported assoc_type for session %s returned: %s'
+            raise ProtocolError(fmt % (assoc_session.session_type, assoc_type))
 
+        # Delegate to the association session to extract the secret
+        # from the response, however is appropriate for that session
+        # type.
         try:
             secret = assoc_session.extractSecret(association_response)
         except ValueError, why:
-            oidutil.log('Malformed response for %s session: %s' % (
-                assoc_session.session_type, why[0]))
-            return None
-        except KeyError, why:
-            fmt = 'Getting association: missing key in response from %s: %s'
-            oidutil.log(fmt % (server_url, why[0]))
-            return None
+            fmt = 'Malformed response for %s session: %s'
+            raise ProtocolError(fmt % (assoc_session.session_type, why[0]))
 
-        assoc = Association.fromExpiresIn(
+        return Association.fromExpiresIn(
             expires_in, assoc_handle, secret, assoc_type)
-        self.store.storeAssociation(server_url, assoc)
-
-        return assoc
-
-    def _associateError(self, results):
-        error_code = results['error_code']
-        default_message = 'associate error from server: %s' % (error_code,)
-        message = results.get('error', default_message)
-        if error_code == 'unsupported-type':
-            supported_assoc_type = results.get('assoc_type')
-            supported_session_type = results.get('session_type')
-            raise UnsupportedAssocType(
-                message,
-                supported_assoc_type,
-                supported_session_type,
-                )
-        else:
-            fmt = 'Error response to associate request from server: %s'
-            oidutil.log(fmt % (message,))
-            return None
 
 class AuthRequest(object):
     def __init__(self, endpoint, assoc):

Modified: incubator/heraldry/libraries/python/openid/trunk/openid/test/test_association_response.py
URL: http://svn.apache.org/viewvc/incubator/heraldry/libraries/python/openid/trunk/openid/test/test_association_response.py?view=diff&rev=493368&r1=493367&r2=493368
==============================================================================
--- incubator/heraldry/libraries/python/openid/trunk/openid/test/test_association_response.py (original)
+++ incubator/heraldry/libraries/python/openid/trunk/openid/test/test_association_response.py Fri Jan  5 21:25:34 2007
@@ -5,17 +5,17 @@
 """
 from openid import oidutil
 from openid.test.test_consumer import CatchLogs
-from openid.message import Message, OPENID2_NS, OPENID_NS
+from openid.message import Message, OPENID2_NS, OPENID_NS, no_default
 from openid.server.server import DiffieHellmanSHA1ServerSession
 from openid.consumer.consumer import GenericConsumer, \
-     DiffieHellmanSHA1ConsumerSession
+     DiffieHellmanSHA1ConsumerSession, ProtocolError
 from openid.consumer.discover import OpenIDServiceEndpoint, OPENID_1_1_TYPE, OPENID_2_0_TYPE
 import _memstore
 import unittest
 
 # Some values we can use for convenience (see mkAssocResponse)
 association_response_values = {
-    'expires_in': 'a time',
+    'expires_in': '1000',
     'assoc_handle':'a handle',
     'assoc_type':'a type',
     'session_type':'a session type',
@@ -39,7 +39,16 @@
         self.consumer = GenericConsumer(self.store)
         self.endpoint = OpenIDServiceEndpoint()
 
-def mkParseAssocMissingTest(keys):
+    def failUnlessProtocolError(self, str_prefix, func, *args, **kwargs):
+        try:
+            result = func(*args, **kwargs)
+        except ProtocolError, e:
+            message = 'Expected prefix %r, got %r' % (str_prefix, e[0])
+            self.failUnless(e[0].startswith(str_prefix), message)
+        else:
+            self.fail('Expected ProtocolError, got %r' % (result,))
+
+def mkExtractAssocMissingTest(keys):
     """Factory function for creating test methods for generating
     missing field tests.
 
@@ -63,45 +72,42 @@
     def test(self):
         msg = mkAssocResponse(*keys)
 
-        result = self.consumer._parseAssociation(msg, None, 'dummy.url')
-        self.failUnless(result is None)
-        self.failUnlessEqual(len(self.messages), 1)
-        self.failUnless(self.messages[0].startswith(
-            'Getting association: missing key'))
+        self.failUnlessRaises(KeyError,
+                              self.consumer._extractAssociation, msg, None)
 
     return test
 
-class TestParseAssociationMissingFieldsOpenID2(BaseAssocTest):
+class TestExtractAssociationMissingFieldsOpenID2(BaseAssocTest):
     """Test for returning an error upon missing fields in association
     responses for OpenID 2"""
 
-    test_noFields_openid2 = mkParseAssocMissingTest(['ns'])
+    test_noFields_openid2 = mkExtractAssocMissingTest(['ns'])
 
-    test_missingExpires_openid2 = mkParseAssocMissingTest(
+    test_missingExpires_openid2 = mkExtractAssocMissingTest(
         ['assoc_handle', 'assoc_type', 'session_type', 'ns'])
 
-    test_missingHandle_openid2 = mkParseAssocMissingTest(
+    test_missingHandle_openid2 = mkExtractAssocMissingTest(
         ['expires_in', 'assoc_type', 'session_type', 'ns'])
 
-    test_missingAssocType_openid2 = mkParseAssocMissingTest(
+    test_missingAssocType_openid2 = mkExtractAssocMissingTest(
         ['expires_in', 'assoc_handle', 'session_type', 'ns'])
 
-    test_missingSessionType_openid2 = mkParseAssocMissingTest(
+    test_missingSessionType_openid2 = mkExtractAssocMissingTest(
         ['expires_in', 'assoc_handle', 'assoc_type', 'ns'])
 
-class TestParseAssociationMissingFieldsOpenID1(BaseAssocTest):
+class TestExtractAssociationMissingFieldsOpenID1(BaseAssocTest):
     """Test for returning an error upon missing fields in association
     responses for OpenID 2"""
 
-    test_noFields_openid1 = mkParseAssocMissingTest([])
+    test_noFields_openid1 = mkExtractAssocMissingTest([])
 
-    test_missingExpires_openid1 = mkParseAssocMissingTest(
+    test_missingExpires_openid1 = mkExtractAssocMissingTest(
         ['assoc_handle', 'assoc_type'])
 
-    test_missingHandle_openid1 = mkParseAssocMissingTest(
+    test_missingHandle_openid1 = mkExtractAssocMissingTest(
         ['expires_in', 'assoc_type'])
 
-    test_missingAssocType_openid1 = mkParseAssocMissingTest(
+    test_missingAssocType_openid1 = mkExtractAssocMissingTest(
         ['expires_in', 'assoc_handle'])
 
 class DummyAssocationSession(object):
@@ -109,7 +115,7 @@
         self.session_type = session_type
         self.allowed_assoc_types = allowed_assoc_types
 
-class ParseAssociationSessionTypeMismatch(BaseAssocTest):
+class ExtractAssociationSessionTypeMismatch(BaseAssocTest):
     def mkTest(requested_session_type, response_session_type, openid1=False):
         def test(self):
             assoc_session = DummyAssocationSession(requested_session_type)
@@ -118,9 +124,8 @@
                 keys.remove('ns')
             msg = mkAssocResponse(*keys)
             msg.setArg(OPENID_NS, 'session_type', response_session_type)
-            result = self.consumer._parseAssociation(
-                msg, assoc_session, server_url='dummy.url')
-            self.failUnless(result is None)
+            self.failUnlessProtocolError('Session type mismatch',
+                self.consumer._extractAssociation, msg, assoc_session)
 
         return test
 
@@ -146,13 +151,13 @@
 
     test_typeMismatchDHSHA1NoEnc_openid1 = mkTest(
         requested_session_type='DH-SHA1',
-        response_session_type='no-encryption',
+        response_session_type='DH-SHA256',
         openid1=True,
         )
 
     test_typeMismatchDHSHA256NoEnc_openid1 = mkTest(
         requested_session_type='DH-SHA256',
-        response_session_type='no-encryption',
+        response_session_type='DH-SHA1',
         openid1=True,
         )
 
@@ -228,35 +233,73 @@
         expected_session_type='DH-SHA256',
         )
 
-class TestAssocTypeInvalidForSession(BaseAssocTest):
-    def _setup(self, assoc_type):
-        no_encryption_session = DummyAssocationSession('matching-session-type',
-                                                       ['good-assoc-type'])
-        msg = mkAssocResponse(*association_response_values.keys())
-        msg.setArg(OPENID2_NS, 'session_type', 'matching-session-type')
-        msg.setArg(OPENID2_NS, 'assoc_type', assoc_type)
+class DummyAssociationSession(object):
+    secret = "shh! don't tell!"
+    extract_secret_called = False
+
+    session_type = None
+
+    allowed_assoc_types = None
+
+    def extractSecret(self, message):
+        self.extract_secret_called = True
+        return self.secret
+
+class TestInvalidFields(BaseAssocTest):
+    def setUp(self):
+        BaseAssocTest.setUp(self)
+        self.session_type = 'testing-session'
 
-        result = self.consumer._parseAssociation(
-            msg, no_encryption_session, 'dummy.url')
+        # This must be something that works for Association.fromExpiresIn
+        self.assoc_type = 'HMAC-SHA1'
 
+        self.assoc_handle = 'testing-assoc-handle'
+
+        # These arguments should all be valid
+        self.assoc_response = Message.fromOpenIDArgs({
+            'expires_in': '1000',
+            'assoc_handle':self.assoc_handle,
+            'assoc_type':self.assoc_type,
+            'session_type':self.session_type,
+            'ns':OPENID2_NS,
+            })
+
+        self.assoc_session = DummyAssociationSession()
+
+        # Make the session for the response's session type
+        self.assoc_session.session_type = self.session_type
+        self.assoc_session.allowed_assoc_types = [self.assoc_type]
+
+    def test_worksWithGoodFields(self):
+        """Handle a full successful association response"""
+        assoc = self.consumer._extractAssociation(
+            self.assoc_response, self.assoc_session)
+        self.failUnless(self.assoc_session.extract_secret_called)
+        self.failUnlessEqual(self.assoc_session.secret, assoc.secret)
+        self.failUnlessEqual(1000, assoc.lifetime)
+        self.failUnlessEqual(self.assoc_handle, assoc.handle)
+        self.failUnlessEqual(self.assoc_type, assoc.assoc_type)
 
     def test_badAssocType(self):
-        self._setup('unsupported')
-        self.failUnlessEqual(1, len(self.messages))
-        self.failUnless(self.messages[0].startswith(
-            'Unsupported assoc_type for session'))
+        # Make sure that the assoc type in the response is not valid
+        # for the given session.
+        self.assoc_session.allowed_assoc_types = []
+        self.failUnlessProtocolError('Unsupported assoc_type for session',
+            self.consumer._extractAssociation,
+            self.assoc_response, self.assoc_session)
 
     def test_badExpiresIn(self):
-        self._setup('good-assoc-type')
-        self.failUnlessEqual(1, len(self.messages))
-        self.failUnless(self.messages[0].startswith(
-            'Getting Association: invalid expires_in'))
+        # Invalid value for expires_in should cause failure
+        self.assoc_response.setArg(OPENID_NS, 'expires_in', 'forever')
+        self.failUnlessProtocolError('Invalid expires_in',
+            self.consumer._extractAssociation,
+            self.assoc_response, self.assoc_session)
 
 
 # XXX: This is what causes most of the imports in this file. It is
 # sort of a unit test and sort of a functional test. I'm not terribly
 # fond of it.
-class TestParseAssociation(BaseAssocTest):
+class TestExtractAssociationDiffieHellman(BaseAssocTest):
     secret = 'x' * 20
 
     def _setUpDH(self):
@@ -279,7 +322,7 @@
 
     def test_success(self):
         sess, server_resp = self._setUpDH()
-        ret = self.consumer._parseAssociation(server_resp, sess, 'server_url')
+        ret = self.consumer._extractAssociation(server_resp, sess)
         self.failIf(ret is None)
         self.failUnlessEqual(ret.assoc_type, 'HMAC-SHA1')
         self.failUnlessEqual(ret.secret, self.secret)
@@ -292,52 +335,8 @@
         self.endpoint.type_uris = [OPENID_2_0_TYPE, OPENID_1_1_TYPE]
         self.test_success()
 
-    def test_badAssocType(self):
-        sess, server_resp = self._setUpDH()
-        server_resp.setArg(OPENID_NS, 'assoc_type', 'Crazy Low Prices!!!')
-        ret = self.consumer._parseAssociation(server_resp, sess, 'server_url')
-        self.failUnless(ret is None)
-
-    def test_badExpiresIn(self):
-        sess, server_resp = self._setUpDH()
-        server_resp.setArg(OPENID_NS, 'expires_in', 'Crazy Low Prices!!!')
-        ret = self.consumer._parseAssociation(server_resp, sess, 'server_url')
-        self.failUnless(ret is None)
-
-    def test_badSessionType(self):
-        sess, server_resp = self._setUpDH()
-        server_resp.setArg(OPENID_NS, 'session_type', '|/iA6rA')
-        ret = self.consumer._parseAssociation(server_resp, sess, 'server_url')
-        self.failUnless(ret is None)
-
-    def test_plainFallback(self):
-        sess = DiffieHellmanSHA1ConsumerSession()
-        server_resp = Message.fromOpenIDArgs({
-            'assoc_type': 'HMAC-SHA1',
-            'assoc_handle': 'handle',
-            'expires_in': '1000',
-            'mac_key': oidutil.toBase64(self.secret),
-            })
-        ret = self.consumer._parseAssociation(server_resp, sess, 'server_url')
-        self.failIf(ret is None)
-        self.failUnlessEqual(ret.assoc_type, 'HMAC-SHA1')
-        self.failUnlessEqual(ret.secret, self.secret)
-        self.failUnlessEqual(ret.handle, 'handle')
-        self.failUnlessEqual(ret.lifetime, 1000)
-
-    def test_plainFallbackFailure(self):
-        sess = DiffieHellmanSHA1ConsumerSession()
-        # missing mac_key
-        server_resp = Message.fromOpenIDArgs({
-            'assoc_type': 'HMAC-SHA1',
-            'assoc_handle': 'handle',
-            'expires_in': '1000',
-            })
-        ret = self.consumer._parseAssociation(server_resp, sess, 'server_url')
-        self.failUnless(ret is None)
-
     def test_badDHValues(self):
         sess, server_resp = self._setUpDH()
         server_resp.setArg(OPENID_NS, 'enc_mac_key', '\x00\x00\x00')
-        ret = self.consumer._parseAssociation(server_resp, sess, 'server_url')
-        self.failUnless(ret is None)
+        self.failUnlessProtocolError('Malformed response for',
+            self.consumer._extractAssociation, server_resp, sess)



Mime
View raw message