Package openid :: Package test :: Module test_sreg
[frames | no frames]

Source Code for Module openid.test.test_sreg

  1  from openid.extensions import sreg 
  2  from openid.message import NamespaceMap, Message, registerNamespaceAlias 
  3  from openid.server.server import OpenIDRequest, OpenIDResponse 
  4   
  5  import unittest 
  6   
class SRegURITest(unittest.TestCase):
    """Check which Simple Registration namespace URI is the module default."""

    def test_is11(self):
        # The default namespace URI exported by the module is the 1.1 one.
        self.assertEqual(sreg.ns_uri_1_1, sreg.ns_uri)
10
class CheckFieldNameTest(unittest.TestCase):
    """Exercise ``sreg.checkFieldName`` with valid and invalid input."""

    def test_goodNamePasses(self):
        # Every name listed in sreg.data_fields must be accepted silently.
        for field_name in sreg.data_fields:
            sreg.checkFieldName(field_name)

    def test_badNameFails(self):
        # A string that is not a known field name is rejected.
        self.assertRaises(ValueError, sreg.checkFieldName, 'INVALID')

    def test_badTypeFails(self):
        # A non-string value is rejected with the same exception type.
        self.assertRaises(ValueError, sreg.checkFieldName, None)
# For supportsSReg test
class FakeEndpoint(object):
    """Endpoint stand-in that records each namespace URI it is asked about."""

    def __init__(self, supported):
        # Container of namespace URIs this fake reports as supported.
        self.supported = supported
        # Every URI passed to usesExtension(), in call order.
        self.checked_uris = []

    def usesExtension(self, namespace_uri):
        """Record ``namespace_uri`` and return whether it is in ``supported``."""
        self.checked_uris.append(namespace_uri)
        return namespace_uri in self.supported
31
class SupportsSRegTest(unittest.TestCase):
    """Check ``sreg.supportsSReg`` against endpoints with varying support."""

    def test_unsupported(self):
        # With nothing supported, both namespaces are probed (1.1 first)
        # and the overall answer is false.
        endpoint = FakeEndpoint([])
        self.assertFalse(sreg.supportsSReg(endpoint))
        self.assertEqual([sreg.ns_uri_1_1, sreg.ns_uri_1_0],
                         endpoint.checked_uris)

    def test_supported_1_1(self):
        # A 1.1 hit is enough; the 1.0 namespace is never checked.
        endpoint = FakeEndpoint([sreg.ns_uri_1_1])
        self.assertTrue(sreg.supportsSReg(endpoint))
        self.assertEqual([sreg.ns_uri_1_1], endpoint.checked_uris)

    def test_supported_1_0(self):
        # 1.0-only support is found after the 1.1 probe misses.
        endpoint = FakeEndpoint([sreg.ns_uri_1_0])
        self.assertTrue(sreg.supportsSReg(endpoint))
        self.assertEqual([sreg.ns_uri_1_1, sreg.ns_uri_1_0],
                         endpoint.checked_uris)
49
class FakeMessage(object):
    """Minimal message stand-in exposing ``namespaces`` and ``isOpenID1``."""

    def __init__(self):
        # Tests flip this flag to simulate an OpenID 1 message.
        self.openid1 = False
        self.namespaces = NamespaceMap()

    def isOpenID1(self):
        """Return the current value of the ``openid1`` flag."""
        return self.openid1
57
class GetNSTest(unittest.TestCase):
    """Check how ``sreg.getSRegNS`` picks and registers the sreg namespace."""

    def setUp(self):
        self.msg = FakeMessage()

    def test_openID2Empty(self):
        # No namespace declared: the default URI is returned and is
        # registered on the message under the 'sreg' alias.
        ns_uri = sreg.getSRegNS(self.msg)
        self.assertEqual(self.msg.namespaces.getAlias(ns_uri), 'sreg')
        self.assertEqual(sreg.ns_uri, ns_uri)

    def test_openID1Empty(self):
        self.msg.openid1 = True
        ns_uri = sreg.getSRegNS(self.msg)
        self.assertEqual(self.msg.namespaces.getAlias(ns_uri), 'sreg')
        self.assertEqual(sreg.ns_uri, ns_uri)

    def test_openID1Defined_1_0(self):
        # An already-declared 1.0 namespace is used as-is.
        self.msg.openid1 = True
        self.msg.namespaces.add(sreg.ns_uri_1_0)
        ns_uri = sreg.getSRegNS(self.msg)
        self.assertEqual(sreg.ns_uri_1_0, ns_uri)

    # NOTE(review): the ``def`` line for this test was missing from the
    # extracted listing, leaving the body orphaned in the class; the name
    # below is reconstructed -- confirm against the upstream test suite.
    def test_openID1Defined_1_0_overrideAlias(self):
        for openid_version in [True, False]:
            for sreg_version in [sreg.ns_uri_1_0, sreg.ns_uri_1_1]:
                for alias in ['sreg', 'bogus']:
                    # Start every combination from a fresh message.
                    self.setUp()

                    self.msg.openid1 = openid_version
                    self.msg.namespaces.addAlias(sreg_version, alias)
                    ns_uri = sreg.getSRegNS(self.msg)
                    self.assertEqual(
                        self.msg.namespaces.getAlias(ns_uri), alias)
                    self.assertEqual(sreg_version, ns_uri)

    def test_openID1DefinedBadly(self):
        # The 'sreg' alias is already bound to an unrelated URI.
        self.msg.openid1 = True
        self.msg.namespaces.addAlias('http://invalid/', 'sreg')
        self.assertRaises(sreg.SRegNamespaceError,
                          sreg.getSRegNS, self.msg)

    def test_openID2DefinedBadly(self):
        self.msg.openid1 = False
        self.msg.namespaces.addAlias('http://invalid/', 'sreg')
        self.assertRaises(sreg.SRegNamespaceError,
                          sreg.getSRegNS, self.msg)

    def test_openID2Defined_1_0(self):
        self.msg.namespaces.add(sreg.ns_uri_1_0)
        ns_uri = sreg.getSRegNS(self.msg)
        self.assertEqual(sreg.ns_uri_1_0, ns_uri)

    # NOTE(review): the ``def`` line for this test was missing from the
    # extracted listing; the name below is reconstructed -- confirm against
    # the upstream test suite.
    def test_openID1_sregNSfromArgs(self):
        args = {
            'sreg.optional': 'nickname',
            'sreg.required': 'dob',
        }

        m = Message.fromOpenIDArgs(args)

        # assertEqual reports both values on failure, unlike a bare
        # truth check of ``a == b``.
        self.assertEqual(m.getArg(sreg.ns_uri_1_1, 'optional'), 'nickname')
        self.assertEqual(m.getArg(sreg.ns_uri_1_1, 'required'), 'dob')
118
class SRegRequestTest(unittest.TestCase):
    """Exercise construction, parsing and field bookkeeping of SRegRequest."""

    def test_constructEmpty(self):
        req = sreg.SRegRequest()
        self.assertEqual([], req.optional)
        self.assertEqual([], req.required)
        self.assertEqual(None, req.policy_url)
        self.assertEqual(sreg.ns_uri, req.ns_uri)

    def test_constructFields(self):
        # Positional order is: required, optional, policy_url, ns_uri.
        req = sreg.SRegRequest(
            ['nickname'],
            ['gender'],
            'http://policy',
            'http://sreg.ns_uri')
        self.assertEqual(['gender'], req.optional)
        self.assertEqual(['nickname'], req.required)
        self.assertEqual('http://policy', req.policy_url)
        self.assertEqual('http://sreg.ns_uri', req.ns_uri)

    def test_constructBadFields(self):
        self.assertRaises(
            ValueError,
            sreg.SRegRequest, ['elvis'])

    def test_fromOpenIDRequest(self):
        # Sentinels let the fakes verify that fromOpenIDRequest threads the
        # namespace and the extension args through unchanged.
        ns_sentinel = object()
        args_sentinel = object()

        # The fakes name their own instance ``msg_self`` / ``req_self`` so
        # that ``self`` keeps referring to this TestCase inside them.
        class FakeMessage(object):
            copied = False

            def __init__(msg_self):
                msg_self.message = Message()

            def getArgs(msg_self, ns_uri):
                self.assertEqual(ns_sentinel, ns_uri)
                return args_sentinel

            def copy(msg_self):
                msg_self.copied = True
                return msg_self

        class TestingReq(sreg.SRegRequest):
            def _getSRegNS(req_self, unused):
                return ns_sentinel

            def parseExtensionArgs(req_self, args):
                self.assertEqual(args_sentinel, args)

        openid_req = OpenIDRequest()

        msg = FakeMessage()
        openid_req.message = msg

        req = TestingReq.fromOpenIDRequest(openid_req)
        # Exact type check: the classmethod must build the subclass.
        self.assertTrue(type(req) is TestingReq)
        self.assertTrue(msg.copied)

    def test_parseExtensionArgs_empty(self):
        req = sreg.SRegRequest()
        results = req.parseExtensionArgs({})
        self.assertEqual(None, results)

    def test_parseExtensionArgs_extraIgnored(self):
        # Unknown keys must not raise.
        req = sreg.SRegRequest()
        req.parseExtensionArgs({'janrain': 'inc'})

    def test_parseExtensionArgs_nonStrict(self):
        # Non-strict parsing silently drops an unknown field name.
        req = sreg.SRegRequest()
        req.parseExtensionArgs({'required': 'beans'})
        self.assertEqual([], req.required)

    def test_parseExtensionArgs_strict(self):
        req = sreg.SRegRequest()
        self.assertRaises(
            ValueError,
            req.parseExtensionArgs, {'required': 'beans'}, strict=True)

    def test_parseExtensionArgs_policy(self):
        req = sreg.SRegRequest()
        req.parseExtensionArgs({'policy_url': 'http://policy'}, strict=True)
        self.assertEqual('http://policy', req.policy_url)

    def test_parseExtensionArgs_requiredEmpty(self):
        req = sreg.SRegRequest()
        req.parseExtensionArgs({'required': ''}, strict=True)
        self.assertEqual([], req.required)

    def test_parseExtensionArgs_optionalEmpty(self):
        req = sreg.SRegRequest()
        req.parseExtensionArgs({'optional': ''}, strict=True)
        self.assertEqual([], req.optional)

    def test_parseExtensionArgs_optionalSingle(self):
        req = sreg.SRegRequest()
        req.parseExtensionArgs({'optional': 'nickname'}, strict=True)
        self.assertEqual(['nickname'], req.optional)

    def test_parseExtensionArgs_optionalList(self):
        req = sreg.SRegRequest()
        req.parseExtensionArgs({'optional': 'nickname,email'}, strict=True)
        self.assertEqual(['nickname', 'email'], req.optional)

    def test_parseExtensionArgs_optionalListBadNonStrict(self):
        # The unknown 'beer' entry is dropped; valid entries are kept.
        req = sreg.SRegRequest()
        req.parseExtensionArgs({'optional': 'nickname,email,beer'})
        self.assertEqual(['nickname', 'email'], req.optional)

    def test_parseExtensionArgs_optionalListBadStrict(self):
        req = sreg.SRegRequest()
        self.assertRaises(
            ValueError,
            req.parseExtensionArgs, {'optional': 'nickname,email,beer'},
            strict=True)

    def test_parseExtensionArgs_bothNonStrict(self):
        # A field listed as both optional and required ends up required.
        req = sreg.SRegRequest()
        req.parseExtensionArgs({'optional': 'nickname',
                                'required': 'nickname'})
        self.assertEqual([], req.optional)
        self.assertEqual(['nickname'], req.required)

    def test_parseExtensionArgs_bothStrict(self):
        req = sreg.SRegRequest()
        self.assertRaises(
            ValueError,
            req.parseExtensionArgs,
            {'optional': 'nickname',
             'required': 'nickname'},
            strict=True)

    def test_parseExtensionArgs_bothList(self):
        req = sreg.SRegRequest()
        req.parseExtensionArgs({'optional': 'nickname,email',
                                'required': 'country,postcode'}, strict=True)
        self.assertEqual(['nickname', 'email'], req.optional)
        self.assertEqual(['country', 'postcode'], req.required)

    def test_allRequestedFields(self):
        req = sreg.SRegRequest()
        self.assertEqual([], req.allRequestedFields())
        req.requestField('nickname')
        self.assertEqual(['nickname'], req.allRequestedFields())
        req.requestField('gender', required=True)
        # Sort before comparing: this test does not depend on the order
        # in which optional and required fields are returned.
        requested = sorted(req.allRequestedFields())
        self.assertEqual(['gender', 'nickname'], requested)

    def test_wereFieldsRequested(self):
        req = sreg.SRegRequest()
        self.assertFalse(req.wereFieldsRequested())
        req.requestField('gender')
        self.assertTrue(req.wereFieldsRequested())

    def test_contains(self):
        req = sreg.SRegRequest()
        for field_name in sreg.data_fields:
            self.assertFalse(field_name in req)

        self.assertFalse('something else' in req)

        req.requestField('nickname')
        for field_name in sreg.data_fields:
            if field_name == 'nickname':
                self.assertTrue(field_name in req)
            else:
                self.assertFalse(field_name in req)

    def test_requestField_bogus(self):
        req = sreg.SRegRequest()
        self.assertRaises(
            ValueError,
            req.requestField, 'something else')

        self.assertRaises(
            ValueError,
            req.requestField, 'something else', strict=True)

    def test_requestField(self):
        # Add all of the fields, one at a time
        req = sreg.SRegRequest()
        fields = list(sreg.data_fields)
        for field_name in fields:
            req.requestField(field_name)

        self.assertEqual(fields, req.optional)
        self.assertEqual([], req.required)

        # By default, adding the same fields over again has no effect
        for field_name in fields:
            req.requestField(field_name)

        self.assertEqual(fields, req.optional)
        self.assertEqual([], req.required)

        # Requesting a field as required overrides requesting it as optional
        expected = list(fields)
        overridden = expected.pop(0)
        req.requestField(overridden, required=True)
        self.assertEqual(expected, req.optional)
        self.assertEqual([overridden], req.required)

        # Upgrading every field to required empties the optional list
        for field_name in fields:
            req.requestField(field_name, required=True)

        self.assertEqual([], req.optional)
        self.assertEqual(fields, req.required)

        # Requesting it as optional does not downgrade it to optional
        for field_name in fields:
            req.requestField(field_name)

        self.assertEqual([], req.optional)
        self.assertEqual(fields, req.required)

    def test_requestFields_type(self):
        # A bare string is rejected rather than iterated per character.
        req = sreg.SRegRequest()
        self.assertRaises(TypeError, req.requestFields, 'nickname')

    def test_requestFields(self):
        # Add all of the fields
        req = sreg.SRegRequest()

        fields = list(sreg.data_fields)
        req.requestFields(fields)

        self.assertEqual(fields, req.optional)
        self.assertEqual([], req.required)

        # By default, adding the same fields over again has no effect
        req.requestFields(fields)

        self.assertEqual(fields, req.optional)
        self.assertEqual([], req.required)

        # Requesting a field as required overrides requesting it as optional
        expected = list(fields)
        overridden = expected.pop(0)
        req.requestFields([overridden], required=True)
        self.assertEqual(expected, req.optional)
        self.assertEqual([overridden], req.required)

        # Upgrading every field to required empties the optional list
        req.requestFields(fields, required=True)

        self.assertEqual([], req.optional)
        self.assertEqual(fields, req.required)

        # Requesting it as optional does not downgrade it to optional
        req.requestFields(fields)

        self.assertEqual([], req.optional)
        self.assertEqual(fields, req.required)

    def test_getExtensionArgs(self):
        req = sreg.SRegRequest()
        self.assertEqual({}, req.getExtensionArgs())

        req.requestField('nickname')
        self.assertEqual({'optional': 'nickname'}, req.getExtensionArgs())

        req.requestField('email')
        self.assertEqual({'optional': 'nickname,email'},
                         req.getExtensionArgs())

        req.requestField('gender', required=True)
        self.assertEqual({'optional': 'nickname,email',
                          'required': 'gender'},
                         req.getExtensionArgs())

        req.requestField('postcode', required=True)
        self.assertEqual({'optional': 'nickname,email',
                          'required': 'gender,postcode'},
                         req.getExtensionArgs())

        req.policy_url = 'http://policy.invalid/'
        self.assertEqual({'optional': 'nickname,email',
                          'required': 'gender,postcode',
                          'policy_url': 'http://policy.invalid/'},
                         req.getExtensionArgs())
# Sample Simple Registration values shared by the response tests below.
data = {
    'nickname': 'linusaur',
    'postcode': '12345',
    'country': 'US',
    'gender': 'M',
    'fullname': 'Leonhard Euler',
    'email': 'president@whitehouse.gov',
    'dob': '0000-00-00',
    'language': 'en-us',
}
class DummySuccessResponse(object):
    """Success-response stand-in returning canned signed arguments."""

    def __init__(self, message, signed_stuff):
        self.message = message
        # Value handed back by getSignedNS(), whatever namespace is asked for.
        self.signed_stuff = signed_stuff

    def getSignedNS(self, ns_uri):
        """Return ``signed_stuff``; ``ns_uri`` is accepted but ignored."""
        return self.signed_stuff
420
class SRegResponseTest(unittest.TestCase):
    """Check construction and extraction of ``sreg.SRegResponse`` objects."""

    def test_construct(self):
        # A response carrying data is truthy; an empty one is falsy.
        resp = sreg.SRegResponse(data)

        self.assertTrue(resp)

        empty_resp = sreg.SRegResponse({})
        self.assertFalse(empty_resp)

        # XXX: finish this test

    # NOTE(review): the ``def`` line for this test was missing from the
    # extracted listing; the name below is reconstructed -- confirm against
    # the upstream test suite.
    def test_fromSuccessResponse_signed(self):
        message = Message.fromOpenIDArgs({
            'sreg.nickname': 'The Mad Stork',
        })
        # Nothing is signed, so the default extraction yields a falsy
        # response.
        success_resp = DummySuccessResponse(message, {})
        sreg_resp = sreg.SRegResponse.fromSuccessResponse(success_resp)
        self.assertFalse(sreg_resp)

    # NOTE(review): the ``def`` line for this test was missing from the
    # extracted listing; the name below is reconstructed -- confirm against
    # the upstream test suite.
    def test_fromSuccessResponse_unsigned(self):
        message = Message.fromOpenIDArgs({
            'sreg.nickname': 'The Mad Stork',
        })
        success_resp = DummySuccessResponse(message, {})
        # With signed_only=False the unsigned message args are used.
        sreg_resp = sreg.SRegResponse.fromSuccessResponse(success_resp,
                                                          signed_only=False)
        self.assertEqual([('nickname', 'The Mad Stork')],
                         sreg_resp.items())
449
class SendFieldsTest(unittest.TestCase):
    """Round-trip check: requested sreg fields end up in the response."""

    def test(self):
        # Create a request message with simple registration fields
        sreg_req = sreg.SRegRequest(required=['nickname', 'email'],
                                    optional=['fullname'])
        req_msg = Message()
        req_msg.updateArgs(sreg.ns_uri, sreg_req.getExtensionArgs())

        req = OpenIDRequest()
        req.message = req_msg
        req.namespace = req_msg.getOpenIDNamespace()

        # -> send checkid_* request

        # Create an empty response message
        resp_msg = Message()
        resp = OpenIDResponse(req)
        resp.fields = resp_msg

        # Put the requested data fields in the response message
        sreg_resp = sreg.SRegResponse.extractResponse(sreg_req, data)
        resp.addExtension(sreg_resp)

        # <- send id_res response

        # Extract the fields that were sent; only the three requested
        # fields are expected, not everything in ``data``.
        sreg_data_resp = resp_msg.getArgs(sreg.ns_uri)
        self.assertEqual(
            {'nickname': 'linusaur',
             'email': 'president@whitehouse.gov',
             'fullname': 'Leonhard Euler',
             }, sreg_data_resp)
# Run the test suite when this module is executed as a script.
if __name__ == '__main__':
    unittest.main()