1  from openid.extensions import sreg 
  2  from openid.message import NamespaceMap, Message, registerNamespaceAlias 
  3  from openid.server.server import OpenIDRequest, OpenIDResponse 
  4   
  5  import unittest 
  6   
 10   
 15   
 17          self.failUnlessRaises(ValueError, sreg.checkFieldName, 'INVALID') 
  18   
 20          self.failUnlessRaises(ValueError, sreg.checkFieldName, None) 
   21   
 22   
 25          self.supported = supported 
 26          self.checked_uris = [] 
  27   
 29          self.checked_uris.append(namespace_uri) 
 30          return namespace_uri in self.supported 
   31   
 49   
 57   
 61   
 66   
 72   
 78   
 80          for openid_version in [True, False]: 
 81              for sreg_version in [sreg.ns_uri_1_0, sreg.ns_uri_1_1]: 
 82                  for alias in ['sreg', 'bogus']: 
 83                      self.setUp() 
 84   
 85                      self.msg.openid1 = openid_version 
 86                      self.msg.namespaces.addAlias(sreg_version, alias) 
 87                      ns_uri = sreg.getSRegNS(self.msg) 
 88                      self.failUnlessEqual(self.msg.namespaces.getAlias(ns_uri), alias) 
 89                      self.failUnlessEqual(sreg_version, ns_uri) 
  90   
 92          self.msg.openid1 = True 
 93          self.msg.namespaces.addAlias('http://invalid/', 'sreg') 
 94          self.failUnlessRaises(sreg.SRegNamespaceError, 
 95                                sreg.getSRegNS, self.msg) 
  96   
 98          self.msg.openid1 = False 
 99          self.msg.namespaces.addAlias('http://invalid/', 'sreg') 
100          self.failUnlessRaises(sreg.SRegNamespaceError, 
101                                sreg.getSRegNS, self.msg) 
 102   
107   
 118   
121          req = sreg.SRegRequest() 
122          self.failUnlessEqual([], req.optional) 
123          self.failUnlessEqual([], req.required) 
124          self.failUnlessEqual(None, req.policy_url) 
125          self.failUnlessEqual(sreg.ns_uri, req.ns_uri) 
 126   
128          req = sreg.SRegRequest( 
129              ['nickname'], 
130              ['gender'], 
131              'http://policy', 
132              'http://sreg.ns_uri') 
133          self.failUnlessEqual(['gender'], req.optional) 
134          self.failUnlessEqual(['nickname'], req.required) 
135          self.failUnlessEqual('http://policy', req.policy_url) 
136          self.failUnlessEqual('http://sreg.ns_uri', req.ns_uri) 
 137   
139          self.failUnlessRaises( 
140              ValueError, 
141              sreg.SRegRequest, ['elvis']) 
 142   
144          args = {} 
145          ns_sentinel = object() 
146          args_sentinel = object() 
147   
148          class FakeMessage(object): 
149              copied = False 
150   
151              def __init__(self): 
152                  self.message = Message() 
 153   
154              def getArgs(msg_self, ns_uri): 
155                  self.failUnlessEqual(ns_sentinel, ns_uri) 
156                  return args_sentinel 
 157   
158              def copy(msg_self): 
159                  msg_self.copied = True 
160                  return msg_self 
161   
162          class TestingReq(sreg.SRegRequest): 
163              def _getSRegNS(req_self, unused): 
164                  return ns_sentinel 
165   
166              def parseExtensionArgs(req_self, args): 
167                  self.failUnlessEqual(args_sentinel, args) 
168   
169          openid_req = OpenIDRequest() 
170   
171          msg = FakeMessage() 
172          openid_req.message = msg 
173   
174          req = TestingReq.fromOpenIDRequest(openid_req) 
175          self.failUnless(type(req) is TestingReq) 
176          self.failUnless(msg.copied) 
177   
182   
186   
191   
197   
202   
207   
212   
217   
222   
227   
234   
236          req = sreg.SRegRequest() 
237          req.parseExtensionArgs({'optional':'nickname', 
238                                  'required':'nickname'}) 
239          self.failUnlessEqual([], req.optional) 
240          self.failUnlessEqual(['nickname'], req.required) 
 241   
243          req = sreg.SRegRequest() 
244          self.failUnlessRaises( 
245              ValueError, 
246              req.parseExtensionArgs, 
247              {'optional':'nickname', 
248               'required':'nickname'}, 
249              strict=True) 
 250   
252          req = sreg.SRegRequest() 
253          req.parseExtensionArgs({'optional':'nickname,email', 
254                                  'required':'country,postcode'}, strict=True) 
255          self.failUnlessEqual(['nickname','email'], req.optional) 
256          self.failUnlessEqual(['country','postcode'], req.required) 
 257   
267   
273   
275          req = sreg.SRegRequest() 
276          for field_name in sreg.data_fields: 
277              self.failIf(field_name in req) 
278   
279          self.failIf('something else' in req) 
280   
281          req.requestField('nickname') 
282          for field_name in sreg.data_fields: 
283              if field_name == 'nickname': 
284                  self.failUnless(field_name in req) 
285              else: 
286                  self.failIf(field_name in req) 
 287   
289          req = sreg.SRegRequest() 
290          self.failUnlessRaises( 
291              ValueError, 
292              req.requestField, 'something else') 
293   
294          self.failUnlessRaises( 
295              ValueError, 
296              req.requestField, 'something else', strict=True) 
 297   
299           
300          req = sreg.SRegRequest() 
301          fields = list(sreg.data_fields) 
302          for field_name in fields: 
303              req.requestField(field_name) 
304   
305          self.failUnlessEqual(fields, req.optional) 
306          self.failUnlessEqual([], req.required) 
307   
308           
309          for field_name in fields: 
310              req.requestField(field_name) 
311   
312          self.failUnlessEqual(fields, req.optional) 
313          self.failUnlessEqual([], req.required) 
314   
315           
316          expected = list(fields) 
317          overridden = expected.pop(0) 
318          req.requestField(overridden, required=True) 
319          self.failUnlessEqual(expected, req.optional) 
320          self.failUnlessEqual([overridden], req.required) 
321   
322           
323          for field_name in fields: 
324              req.requestField(field_name, required=True) 
325   
326          self.failUnlessEqual([], req.optional) 
327          self.failUnlessEqual(fields, req.required) 
328   
329           
330          for field_name in fields: 
331              req.requestField(field_name) 
332   
333          self.failUnlessEqual([], req.optional) 
334          self.failUnlessEqual(fields, req.required) 
 335   
339   
341           
342          req = sreg.SRegRequest() 
343   
344          fields = list(sreg.data_fields) 
345          req.requestFields(fields) 
346   
347          self.failUnlessEqual(fields, req.optional) 
348          self.failUnlessEqual([], req.required) 
349   
350           
351          req.requestFields(fields) 
352   
353          self.failUnlessEqual(fields, req.optional) 
354          self.failUnlessEqual([], req.required) 
355   
356           
357          expected = list(fields) 
358          overridden = expected.pop(0) 
359          req.requestFields([overridden], required=True) 
360          self.failUnlessEqual(expected, req.optional) 
361          self.failUnlessEqual([overridden], req.required) 
362   
363           
364          req.requestFields(fields, required=True) 
365   
366          self.failUnlessEqual([], req.optional) 
367          self.failUnlessEqual(fields, req.required) 
368   
369           
370          req.requestFields(fields) 
371   
372          self.failUnlessEqual([], req.optional) 
373          self.failUnlessEqual(fields, req.required) 
 374   
376          req = sreg.SRegRequest() 
377          self.failUnlessEqual({}, req.getExtensionArgs()) 
378   
379          req.requestField('nickname') 
380          self.failUnlessEqual({'optional':'nickname'}, req.getExtensionArgs()) 
381   
382          req.requestField('email') 
383          self.failUnlessEqual({'optional':'nickname,email'}, 
384                               req.getExtensionArgs()) 
385   
386          req.requestField('gender', required=True) 
387          self.failUnlessEqual({'optional':'nickname,email', 
388                                'required':'gender'}, 
389                               req.getExtensionArgs()) 
390   
391          req.requestField('postcode', required=True) 
392          self.failUnlessEqual({'optional':'nickname,email', 
393                                'required':'gender,postcode'}, 
394                               req.getExtensionArgs()) 
395   
396          req.policy_url = 'http://policy.invalid/' 
397          self.failUnlessEqual({'optional':'nickname,email', 
398                                'required':'gender,postcode', 
399                                'policy_url':'http://policy.invalid/'}, 
400                               req.getExtensionArgs()) 
 401   
# Fixture profile values keyed by simple-registration field name.
# NOTE(review): presumably consumed by the response tests later in this
# module (not visible in this chunk) -- confirm before changing values.
data = dict(
    nickname='linusaur',
    postcode='12345',
    country='US',
    gender='M',
    fullname='Leonhard Euler',
    email='president@whitehouse.gov',
    dob='0000-00-00',
    language='en-us',
)
412   
414 -    def __init__(self, message, signed_stuff): 
 417   
419          return self.signed_stuff 
  420   
449   
482   
if __name__ == '__main__':
    # Run the unittest test cases defined in this module when the file is
    # executed directly as a script (not when it is imported).
    unittest.main()
485