1  """Tests for consumer handling of association responses 
  2   
  3  This duplicates some things that are covered by test_consumer, but 
  4  this works for now. 
  5  """ 
  6  from openid import oidutil 
  7  from openid.test.test_consumer import CatchLogs 
  8  from openid.message import Message, OPENID2_NS, OPENID_NS, no_default 
  9  from openid.server.server import DiffieHellmanSHA1ServerSession 
 10  from openid.consumer.consumer import GenericConsumer, \ 
 11       DiffieHellmanSHA1ConsumerSession, ProtocolError 
 12  from openid.consumer.discover import OpenIDServiceEndpoint, OPENID_1_1_TYPE, OPENID_2_0_TYPE 
 13  from openid.store import memstore 
 14  import unittest 
 15   
 16   
 17  association_response_values = { 
 18      'expires_in': '1000', 
 19      'assoc_handle':'a handle', 
 20      'assoc_type':'a type', 
 21      'session_type':'a session type', 
 22      'ns':OPENID2_NS, 
 23      } 
 24   
    """Build an association response message that contains the
    specified subset of keys. The values come from
    `association_response_values`.

    This is useful for testing for missing keys and other times that
    we don't care what the values are.

    Raises KeyError if any requested key is not one of the keys of
    `association_response_values`."""
    args = dict([(key, association_response_values[key]) for key in keys])
    return Message.fromOpenIDArgs(args)
  34   
 41   
        # Call func(*args, **kwargs) and require that it raises ProtocolError
        # whose first argument (e[0], Python 2 exception indexing) starts
        # with str_prefix. A normal return is reported as a test failure;
        # exceptions of any other type are not caught here.
        try:
            result = func(*args, **kwargs)
        except ProtocolError, e:
            message = 'Expected prefix %r, got %r' % (str_prefix, e[0])
            self.failUnless(e[0].startswith(str_prefix), message)
        else:
            self.fail('Expected ProtocolError, got %r' % (result,))
   50   
    """Factory for test methods that check association responses with
    missing fields.

    The generated test builds an association response containing only
    the fields named in `keys` (values come from
    `association_response_values`; every other field is left out). It then
    asserts that ``self.consumer._extractAssociation`` raises KeyError
    for that response.

    According to 'Association Session Response' subsection 'Common
    Response Parameters', the following fields are required for OpenID
    2.0:

     * ns
     * session_type
     * assoc_handle
     * assoc_type
     * expires_in

    If 'ns' is missing, it will fall back to OpenID 1 checking. In
    OpenID 1, everything except 'session_type' and 'ns' is required.
    """

    def test(self):
        msg = mkAssocResponse(*keys)

        self.failUnlessRaises(KeyError,
                              self.consumer._extractAssociation, msg, None)

    return test
 79   
    """Test that association responses missing required fields are
    rejected (KeyError) for OpenID 2"""

    # Every response here includes 'ns' (OPENID2_NS), so OpenID 2 rules
    # apply. The first case has no other fields; each later case omits
    # exactly one required field.
    test_noFields_openid2 = mkExtractAssocMissingTest(['ns'])

    test_missingExpires_openid2 = mkExtractAssocMissingTest(
        ['assoc_handle', 'assoc_type', 'session_type', 'ns'])

    test_missingHandle_openid2 = mkExtractAssocMissingTest(
        ['expires_in', 'assoc_type', 'session_type', 'ns'])

    test_missingAssocType_openid2 = mkExtractAssocMissingTest(
        ['expires_in', 'assoc_handle', 'session_type', 'ns'])

    test_missingSessionType_openid2 = mkExtractAssocMissingTest(
        ['expires_in', 'assoc_handle', 'assoc_type', 'ns'])
  97   
    """Test for returning an error upon missing fields in association
    responses for OpenID 1"""

    # None of these responses include 'ns', so they are checked with the
    # OpenID 1 rules: 'session_type' is optional, but 'expires_in',
    # 'assoc_handle' and 'assoc_type' are required. Each case leaves out
    # at least one of those.
    test_noFields_openid1 = mkExtractAssocMissingTest([])

    test_missingExpires_openid1 = mkExtractAssocMissingTest(
        ['assoc_handle', 'assoc_type'])

    test_missingHandle_openid1 = mkExtractAssocMissingTest(
        ['expires_in', 'assoc_type'])

    test_missingAssocType_openid1 = mkExtractAssocMissingTest(
        ['expires_in', 'assoc_handle'])
 112   
114 -    def __init__(self, session_type, allowed_assoc_types=()): 
  117   
131   
    # Each case requests one session type and receives a different one in
    # the response. The '_openid1' cases pass openid1=True to mkTest
    # (presumably to build an OpenID 1 response -- confirm in mkTest).
    test_typeMismatchNoEncBlank_openid2 = mkTest(
        requested_session_type='no-encryption',
        response_session_type='',
        )

    test_typeMismatchDHSHA1NoEnc_openid2 = mkTest(
        requested_session_type='DH-SHA1',
        response_session_type='no-encryption',
        )

    test_typeMismatchDHSHA256NoEnc_openid2 = mkTest(
        requested_session_type='DH-SHA256',
        response_session_type='no-encryption',
        )

    test_typeMismatchNoEncDHSHA1_openid2 = mkTest(
        requested_session_type='no-encryption',
        response_session_type='DH-SHA1',
        )

    # NOTE(review): despite "NoEnc" in its name, this case pairs a DH-SHA1
    # request with a DH-SHA256 response; neither side is 'no-encryption'.
    test_typeMismatchDHSHA1NoEnc_openid1 = mkTest(
        requested_session_type='DH-SHA1',
        response_session_type='DH-SHA256',
        openid1=True,
        )

    # NOTE(review): likewise, this pairs a DH-SHA256 request with a DH-SHA1
    # response; 'no-encryption' is not involved despite the name.
    test_typeMismatchDHSHA256NoEnc_openid1 = mkTest(
        requested_session_type='DH-SHA256',
        response_session_type='DH-SHA1',
        openid1=True,
        )

    test_typeMismatchNoEncDHSHA1_openid1 = mkTest(
        requested_session_type='no-encryption',
        response_session_type='DH-SHA1',
        openid1=True,
        )
169   
170   
172 -    def mkTest(expected_session_type, session_type_value): 
 173          """Return a test method that will check what session type will 
174          be used if the OpenID 1 response to an associate call sets the 
175          'session_type' field to `session_type_value` 
176          """ 
177          def test(self): 
178              self._doTest(expected_session_type, session_type_value) 
179              self.failUnlessEqual(0, len(self.messages)) 
 180   
181          return test 
 182   
183 -    def _doTest(self, expected_session_type, session_type_value): 
 184           
185           
186           
187          args = {} 
188          if session_type_value is not None: 
189              args['session_type'] = session_type_value 
190          message = Message.fromOpenIDArgs(args) 
191          self.failUnless(message.isOpenID1()) 
192   
193          actual_session_type = self.consumer._getOpenID1SessionType(message) 
194          error_message = ('Returned sesion type parameter %r was expected ' 
195                           'to yield session type %r, but yielded %r' % 
196                           (session_type_value, expected_session_type, 
197                            actual_session_type)) 
198          self.failUnlessEqual( 
199              expected_session_type, actual_session_type, error_message) 
 200   
    # A missing 'session_type' field (None => field omitted) or an empty
    # one in an OpenID 1 response is expected to be read as
    # 'no-encryption'.
    test_none = mkTest(
        session_type_value=None,
        expected_session_type='no-encryption',
        )

    test_empty = mkTest(
        session_type_value='',
        expected_session_type='no-encryption',
        )
210   
211       
        # An explicit 'no-encryption' value is accepted as 'no-encryption',
        # but exactly one warning about it is expected to be recorded.
        self._doTest(
            session_type_value='no-encryption',
            expected_session_type='no-encryption',
            )
        self.failUnlessEqual(1, len(self.messages))
        self.failUnless(self.messages[0].startswith(
            'WARNING: OpenID server sent "no-encryption"'))
 220   
    # DH session types are expected to be returned unchanged for OpenID 1
    # responses. As with every mkTest case, no messages may be recorded.
    test_dhSHA1 = mkTest(
        session_type_value='DH-SHA1',
        expected_session_type='DH-SHA1',
        )

    # NOTE(review): this expects 'DH-SHA256' to be passed through as-is for
    # an OpenID 1 response; whether OpenID 1 should accept it is not
    # checked here.
    test_dhSHA256 = mkTest(
        session_type_value='DH-SHA256',
        expected_session_type='DH-SHA256',
        )
235   
247   
272   
        """Handle a full successful association response"""
        # The association session must have been asked to extract the
        # secret. The returned association must carry that secret, a
        # lifetime of 1000, and self.assoc_handle / self.assoc_type.
        assoc = self.consumer._extractAssociation(
            self.assoc_response, self.assoc_session)
        self.failUnless(self.assoc_session.extract_secret_called)
        self.failUnlessEqual(self.assoc_session.secret, assoc.secret)
        self.failUnlessEqual(1000, assoc.lifetime)
        self.failUnlessEqual(self.assoc_handle, assoc.handle)
        self.failUnlessEqual(self.assoc_type, assoc.assoc_type)
 282   
290   
 297   
298   
299   
300   
301   
    # 20-byte secret that the server-side DH session answers with; the
    # consumer is expected to recover exactly this value.
    secret = 'x' * 20
304   
        # Create a consumer-side associate request using 'HMAC-SHA1' and
        # 'DH-SHA1'.
        sess, message = self.consumer._createAssociateRequest(
            self.endpoint, 'HMAC-SHA1', 'DH-SHA1')

        # The request's protocol version must follow the endpoint:
        # compatibility mode <=> OpenID 1 message.
        self.failUnlessEqual(self.endpoint.compatibilityMode(),
                             message.isOpenID1())

        # Answer the request with a real server-side DH-SHA1 session, fill
        # in the remaining response fields, and return the consumer session
        # together with the response as a Message.
        server_sess = DiffieHellmanSHA1ServerSession.fromMessage(message)
        server_resp = server_sess.answer(self.secret)
        server_resp['assoc_type'] = 'HMAC-SHA1'
        server_resp['assoc_handle'] = 'handle'
        server_resp['expires_in'] = '1000'
        server_resp['session_type'] = 'DH-SHA1'
        return sess, Message.fromOpenIDArgs(server_resp)
 320   
        # A well-formed DH-SHA1 response must yield an association that
        # carries the server's type, secret, handle and lifetime.
        sess, server_resp = self._setUpDH()
        ret = self.consumer._extractAssociation(server_resp, sess)
        self.failIf(ret is None)
        self.failUnlessEqual(ret.assoc_type, 'HMAC-SHA1')
        self.failUnlessEqual(ret.secret, self.secret)
        self.failUnlessEqual(ret.handle, 'handle')
        self.failUnlessEqual(ret.lifetime, 1000)
 329   
335   
        # Overwrite enc_mac_key with a 3-byte value (presumably invalid for
        # the DH session) and expect a ProtocolError whose message starts
        # with 'Malformed response for'.
        sess, server_resp = self._setUpDH()
        server_resp.setArg(OPENID_NS, 'enc_mac_key', '\x00\x00\x00')
        self.failUnlessProtocolError('Malformed response for',
            self.consumer._extractAssociation, server_resp, sess)
  341