1 from openid.extensions import sreg
2 from openid.message import NamespaceMap, Message, registerNamespaceAlias
3 from openid.server.server import OpenIDRequest, OpenIDResponse
4
5 import unittest
6
10
15
17 self.failUnlessRaises(ValueError, sreg.checkFieldName, 'INVALID')
18
20 self.failUnlessRaises(ValueError, sreg.checkFieldName, None)
21
22
25 self.supported = supported
26 self.checked_uris = []
27
29 self.checked_uris.append(namespace_uri)
30 return namespace_uri in self.supported
31
49
57
61
66
72
78
80 for openid_version in [True, False]:
81 for sreg_version in [sreg.ns_uri_1_0, sreg.ns_uri_1_1]:
82 for alias in ['sreg', 'bogus']:
83 self.setUp()
84
85 self.msg.openid1 = openid_version
86 self.msg.namespaces.addAlias(sreg_version, alias)
87 ns_uri = sreg.getSRegNS(self.msg)
88 self.failUnlessEqual(self.msg.namespaces.getAlias(ns_uri), alias)
89 self.failUnlessEqual(sreg_version, ns_uri)
90
92 self.msg.openid1 = True
93 self.msg.namespaces.addAlias('http://invalid/', 'sreg')
94 self.failUnlessRaises(sreg.SRegNamespaceError,
95 sreg.getSRegNS, self.msg)
96
98 self.msg.openid1 = False
99 self.msg.namespaces.addAlias('http://invalid/', 'sreg')
100 self.failUnlessRaises(sreg.SRegNamespaceError,
101 sreg.getSRegNS, self.msg)
102
107
118
121 req = sreg.SRegRequest()
122 self.failUnlessEqual([], req.optional)
123 self.failUnlessEqual([], req.required)
124 self.failUnlessEqual(None, req.policy_url)
125 self.failUnlessEqual(sreg.ns_uri, req.ns_uri)
126
128 req = sreg.SRegRequest(
129 ['nickname'],
130 ['gender'],
131 'http://policy',
132 'http://sreg.ns_uri')
133 self.failUnlessEqual(['gender'], req.optional)
134 self.failUnlessEqual(['nickname'], req.required)
135 self.failUnlessEqual('http://policy', req.policy_url)
136 self.failUnlessEqual('http://sreg.ns_uri', req.ns_uri)
137
139 self.failUnlessRaises(
140 ValueError,
141 sreg.SRegRequest, ['elvis'])
142
144 args = {}
145 ns_sentinel = object()
146 args_sentinel = object()
147
148 class FakeMessage(object):
149 copied = False
150
151 def __init__(self):
152 self.message = Message()
153
154 def getArgs(msg_self, ns_uri):
155 self.failUnlessEqual(ns_sentinel, ns_uri)
156 return args_sentinel
157
158 def copy(msg_self):
159 msg_self.copied = True
160 return msg_self
161
162 class TestingReq(sreg.SRegRequest):
163 def _getSRegNS(req_self, unused):
164 return ns_sentinel
165
166 def parseExtensionArgs(req_self, args):
167 self.failUnlessEqual(args_sentinel, args)
168
169 openid_req = OpenIDRequest()
170
171 msg = FakeMessage()
172 openid_req.message = msg
173
174 req = TestingReq.fromOpenIDRequest(openid_req)
175 self.failUnless(type(req) is TestingReq)
176 self.failUnless(msg.copied)
177
182
186
191
197
202
207
212
217
222
227
234
236 req = sreg.SRegRequest()
237 req.parseExtensionArgs({'optional':'nickname',
238 'required':'nickname'})
239 self.failUnlessEqual([], req.optional)
240 self.failUnlessEqual(['nickname'], req.required)
241
243 req = sreg.SRegRequest()
244 self.failUnlessRaises(
245 ValueError,
246 req.parseExtensionArgs,
247 {'optional':'nickname',
248 'required':'nickname'},
249 strict=True)
250
252 req = sreg.SRegRequest()
253 req.parseExtensionArgs({'optional':'nickname,email',
254 'required':'country,postcode'}, strict=True)
255 self.failUnlessEqual(['nickname','email'], req.optional)
256 self.failUnlessEqual(['country','postcode'], req.required)
257
267
273
275 req = sreg.SRegRequest()
276 for field_name in sreg.data_fields:
277 self.failIf(field_name in req)
278
279 self.failIf('something else' in req)
280
281 req.requestField('nickname')
282 for field_name in sreg.data_fields:
283 if field_name == 'nickname':
284 self.failUnless(field_name in req)
285 else:
286 self.failIf(field_name in req)
287
289 req = sreg.SRegRequest()
290 self.failUnlessRaises(
291 ValueError,
292 req.requestField, 'something else')
293
294 self.failUnlessRaises(
295 ValueError,
296 req.requestField, 'something else', strict=True)
297
299
300 req = sreg.SRegRequest()
301 fields = list(sreg.data_fields)
302 for field_name in fields:
303 req.requestField(field_name)
304
305 self.failUnlessEqual(fields, req.optional)
306 self.failUnlessEqual([], req.required)
307
308
309 for field_name in fields:
310 req.requestField(field_name)
311
312 self.failUnlessEqual(fields, req.optional)
313 self.failUnlessEqual([], req.required)
314
315
316 expected = list(fields)
317 overridden = expected.pop(0)
318 req.requestField(overridden, required=True)
319 self.failUnlessEqual(expected, req.optional)
320 self.failUnlessEqual([overridden], req.required)
321
322
323 for field_name in fields:
324 req.requestField(field_name, required=True)
325
326 self.failUnlessEqual([], req.optional)
327 self.failUnlessEqual(fields, req.required)
328
329
330 for field_name in fields:
331 req.requestField(field_name)
332
333 self.failUnlessEqual([], req.optional)
334 self.failUnlessEqual(fields, req.required)
335
339
341
342 req = sreg.SRegRequest()
343
344 fields = list(sreg.data_fields)
345 req.requestFields(fields)
346
347 self.failUnlessEqual(fields, req.optional)
348 self.failUnlessEqual([], req.required)
349
350
351 req.requestFields(fields)
352
353 self.failUnlessEqual(fields, req.optional)
354 self.failUnlessEqual([], req.required)
355
356
357 expected = list(fields)
358 overridden = expected.pop(0)
359 req.requestFields([overridden], required=True)
360 self.failUnlessEqual(expected, req.optional)
361 self.failUnlessEqual([overridden], req.required)
362
363
364 req.requestFields(fields, required=True)
365
366 self.failUnlessEqual([], req.optional)
367 self.failUnlessEqual(fields, req.required)
368
369
370 req.requestFields(fields)
371
372 self.failUnlessEqual([], req.optional)
373 self.failUnlessEqual(fields, req.required)
374
376 req = sreg.SRegRequest()
377 self.failUnlessEqual({}, req.getExtensionArgs())
378
379 req.requestField('nickname')
380 self.failUnlessEqual({'optional':'nickname'}, req.getExtensionArgs())
381
382 req.requestField('email')
383 self.failUnlessEqual({'optional':'nickname,email'},
384 req.getExtensionArgs())
385
386 req.requestField('gender', required=True)
387 self.failUnlessEqual({'optional':'nickname,email',
388 'required':'gender'},
389 req.getExtensionArgs())
390
391 req.requestField('postcode', required=True)
392 self.failUnlessEqual({'optional':'nickname,email',
393 'required':'gender,postcode'},
394 req.getExtensionArgs())
395
396 req.policy_url = 'http://policy.invalid/'
397 self.failUnlessEqual({'optional':'nickname,email',
398 'required':'gender,postcode',
399 'policy_url':'http://policy.invalid/'},
400 req.getExtensionArgs())
401
402 data = {
403 'nickname':'linusaur',
404 'postcode':'12345',
405 'country':'US',
406 'gender':'M',
407 'fullname':'Leonhard Euler',
408 'email':'president@whitehouse.gov',
409 'dob':'0000-00-00',
410 'language':'en-us',
411 }
412
414 - def __init__(self, message, signed_stuff):
417
419 return self.signed_stuff
420
449
482
483 if __name__ == '__main__':
484 unittest.main()
485