@@ -285,6 +285,103 @@ fn set_compare_inner(
285
285
Ok ( vm. new_bool ( true ) )
286
286
}
287
287
288
+ fn set_union ( vm : & mut VirtualMachine , args : PyFuncArgs ) -> PyResult {
289
+ arg_check ! (
290
+ vm,
291
+ args,
292
+ required = [
293
+ ( zelf, Some ( vm. ctx. set_type( ) ) ) ,
294
+ ( other, Some ( vm. ctx. set_type( ) ) )
295
+ ]
296
+ ) ;
297
+
298
+ let mut elements = get_elements ( zelf) . clone ( ) ;
299
+ elements. extend ( get_elements ( other) . clone ( ) ) ;
300
+
301
+ Ok ( PyObject :: new (
302
+ PyObjectPayload :: Set { elements } ,
303
+ vm. ctx . set_type ( ) ,
304
+ ) )
305
+ }
306
+
307
+ fn set_intersection ( vm : & mut VirtualMachine , args : PyFuncArgs ) -> PyResult {
308
+ set_combine_inner ( vm, args, SetCombineOperation :: Intersection )
309
+ }
310
+
311
+ fn set_difference ( vm : & mut VirtualMachine , args : PyFuncArgs ) -> PyResult {
312
+ set_combine_inner ( vm, args, SetCombineOperation :: Difference )
313
+ }
314
+
315
+ fn set_symmetric_difference ( vm : & mut VirtualMachine , args : PyFuncArgs ) -> PyResult {
316
+ arg_check ! (
317
+ vm,
318
+ args,
319
+ required = [
320
+ ( zelf, Some ( vm. ctx. set_type( ) ) ) ,
321
+ ( other, Some ( vm. ctx. set_type( ) ) )
322
+ ]
323
+ ) ;
324
+
325
+ let mut elements = HashMap :: new ( ) ;
326
+
327
+ for element in get_elements ( zelf) . iter ( ) {
328
+ let value = vm. call_method ( other, "__contains__" , vec ! [ element. 1 . clone( ) ] ) ?;
329
+ if !objbool:: get_value ( & value) {
330
+ elements. insert ( element. 0 . clone ( ) , element. 1 . clone ( ) ) ;
331
+ }
332
+ }
333
+
334
+ for element in get_elements ( other) . iter ( ) {
335
+ let value = vm. call_method ( zelf, "__contains__" , vec ! [ element. 1 . clone( ) ] ) ?;
336
+ if !objbool:: get_value ( & value) {
337
+ elements. insert ( element. 0 . clone ( ) , element. 1 . clone ( ) ) ;
338
+ }
339
+ }
340
+
341
+ Ok ( PyObject :: new (
342
+ PyObjectPayload :: Set { elements } ,
343
+ vm. ctx . set_type ( ) ,
344
+ ) )
345
+ }
346
+
347
+ enum SetCombineOperation {
348
+ Intersection ,
349
+ Difference ,
350
+ }
351
+
352
+ fn set_combine_inner (
353
+ vm : & mut VirtualMachine ,
354
+ args : PyFuncArgs ,
355
+ op : SetCombineOperation ,
356
+ ) -> PyResult {
357
+ arg_check ! (
358
+ vm,
359
+ args,
360
+ required = [
361
+ ( zelf, Some ( vm. ctx. set_type( ) ) ) ,
362
+ ( other, Some ( vm. ctx. set_type( ) ) )
363
+ ]
364
+ ) ;
365
+
366
+ let mut elements = HashMap :: new ( ) ;
367
+
368
+ for element in get_elements ( zelf) . iter ( ) {
369
+ let value = vm. call_method ( other, "__contains__" , vec ! [ element. 1 . clone( ) ] ) ?;
370
+ let should_add = match op {
371
+ SetCombineOperation :: Intersection => objbool:: get_value ( & value) ,
372
+ SetCombineOperation :: Difference => !objbool:: get_value ( & value) ,
373
+ } ;
374
+ if should_add {
375
+ elements. insert ( element. 0 . clone ( ) , element. 1 . clone ( ) ) ;
376
+ }
377
+ }
378
+
379
+ Ok ( PyObject :: new (
380
+ PyObjectPayload :: Set { elements } ,
381
+ vm. ctx . set_type ( ) ,
382
+ ) )
383
+ }
384
+
288
385
fn frozenset_repr ( vm : & mut VirtualMachine , args : PyFuncArgs ) -> PyResult {
289
386
arg_check ! ( vm, args, required = [ ( o, Some ( vm. ctx. frozenset_type( ) ) ) ] ) ;
290
387
@@ -325,6 +422,30 @@ pub fn init(context: &PyContext) {
325
422
context. set_attr ( & set_type, "__lt__" , context. new_rustfunc ( set_lt) ) ;
326
423
context. set_attr ( & set_type, "issubset" , context. new_rustfunc ( set_le) ) ;
327
424
context. set_attr ( & set_type, "issuperset" , context. new_rustfunc ( set_ge) ) ;
425
+ context. set_attr ( & set_type, "union" , context. new_rustfunc ( set_union) ) ;
426
+ context. set_attr ( & set_type, "__or__" , context. new_rustfunc ( set_union) ) ;
427
+ context. set_attr (
428
+ & set_type,
429
+ "intersection" ,
430
+ context. new_rustfunc ( set_intersection) ,
431
+ ) ;
432
+ context. set_attr ( & set_type, "__and__" , context. new_rustfunc ( set_intersection) ) ;
433
+ context. set_attr (
434
+ & set_type,
435
+ "difference" ,
436
+ context. new_rustfunc ( set_difference) ,
437
+ ) ;
438
+ context. set_attr ( & set_type, "__sub__" , context. new_rustfunc ( set_difference) ) ;
439
+ context. set_attr (
440
+ & set_type,
441
+ "symmetric_difference" ,
442
+ context. new_rustfunc ( set_symmetric_difference) ,
443
+ ) ;
444
+ context. set_attr (
445
+ & set_type,
446
+ "__xor__" ,
447
+ context. new_rustfunc ( set_symmetric_difference) ,
448
+ ) ;
328
449
context. set_attr ( & set_type, "__doc__" , context. new_str ( set_doc. to_string ( ) ) ) ;
329
450
context. set_attr ( & set_type, "add" , context. new_rustfunc ( set_add) ) ;
330
451
context. set_attr ( & set_type, "remove" , context. new_rustfunc ( set_remove) ) ;
0 commit comments