@@ -17,8 +17,6 @@ use crate::{extract_spans, Diagnostic};
17
17
use once_cell:: sync:: Lazy ;
18
18
use proc_macro2:: { Span , TokenStream } ;
19
19
use quote:: quote;
20
- use rustpython_codegen as codegen;
21
- use rustpython_compiler:: compile;
22
20
use rustpython_compiler_core:: { CodeObject , FrozenModule , Mode } ;
23
21
use std:: {
24
22
collections:: HashMap ,
@@ -51,15 +49,25 @@ struct CompilationSource {
51
49
span : ( Span , Span ) ,
52
50
}
53
51
52
+ pub trait Compiler {
53
+ fn compile (
54
+ & self ,
55
+ source : & str ,
56
+ mode : Mode ,
57
+ module_name : String ,
58
+ ) -> Result < CodeObject , Box < dyn std:: error:: Error > > ;
59
+ }
60
+
54
61
impl CompilationSource {
55
62
fn compile_string < D : std:: fmt:: Display , F : FnOnce ( ) -> D > (
56
63
& self ,
57
64
source : & str ,
58
65
mode : Mode ,
59
66
module_name : String ,
67
+ compiler : & dyn Compiler ,
60
68
origin : F ,
61
69
) -> Result < CodeObject , Diagnostic > {
62
- compile ( source, mode, module_name, codegen :: CompileOpts :: default ( ) ) . map_err ( |err| {
70
+ compiler . compile ( source, mode, module_name) . map_err ( |err| {
63
71
Diagnostic :: spans_error (
64
72
self . span ,
65
73
format ! ( "Python compile error from {}: {}" , origin( ) , err) ,
@@ -71,21 +79,30 @@ impl CompilationSource {
71
79
& self ,
72
80
mode : Mode ,
73
81
module_name : String ,
82
+ compiler : & dyn Compiler ,
74
83
) -> Result < HashMap < String , FrozenModule > , Diagnostic > {
75
84
match & self . kind {
76
- CompilationSourceKind :: Dir ( rel_path) => {
77
- self . compile_dir ( & CARGO_MANIFEST_DIR . join ( rel_path) , String :: new ( ) , mode)
78
- }
85
+ CompilationSourceKind :: Dir ( rel_path) => self . compile_dir (
86
+ & CARGO_MANIFEST_DIR . join ( rel_path) ,
87
+ String :: new ( ) ,
88
+ mode,
89
+ compiler,
90
+ ) ,
79
91
_ => Ok ( hashmap ! {
80
92
module_name. clone( ) => FrozenModule {
81
- code: self . compile_single( mode, module_name) ?,
93
+ code: self . compile_single( mode, module_name, compiler ) ?,
82
94
package: false ,
83
95
} ,
84
96
} ) ,
85
97
}
86
98
}
87
99
88
- fn compile_single ( & self , mode : Mode , module_name : String ) -> Result < CodeObject , Diagnostic > {
100
+ fn compile_single (
101
+ & self ,
102
+ mode : Mode ,
103
+ module_name : String ,
104
+ compiler : & dyn Compiler ,
105
+ ) -> Result < CodeObject , Diagnostic > {
89
106
match & self . kind {
90
107
CompilationSourceKind :: File ( rel_path) => {
91
108
let path = CARGO_MANIFEST_DIR . join ( rel_path) ;
@@ -95,10 +112,10 @@ impl CompilationSource {
95
112
format ! ( "Error reading file {:?}: {}" , path, err) ,
96
113
)
97
114
} ) ?;
98
- self . compile_string ( & source, mode, module_name, || rel_path. display ( ) )
115
+ self . compile_string ( & source, mode, module_name, compiler , || rel_path. display ( ) )
99
116
}
100
117
CompilationSourceKind :: SourceCode ( code) => {
101
- self . compile_string ( & textwrap:: dedent ( code) , mode, module_name, || {
118
+ self . compile_string ( & textwrap:: dedent ( code) , mode, module_name, compiler , || {
102
119
"string literal"
103
120
} )
104
121
}
@@ -113,6 +130,7 @@ impl CompilationSource {
113
130
path : & Path ,
114
131
parent : String ,
115
132
mode : Mode ,
133
+ compiler : & dyn Compiler ,
116
134
) -> Result < HashMap < String , FrozenModule > , Diagnostic > {
117
135
let mut code_map = HashMap :: new ( ) ;
118
136
let paths = fs:: read_dir ( path)
@@ -144,6 +162,7 @@ impl CompilationSource {
144
162
format ! ( "{}.{}" , parent, file_name)
145
163
} ,
146
164
mode,
165
+ compiler,
147
166
) ?) ;
148
167
} else if file_name. ends_with ( ".py" ) {
149
168
let stem = path. file_stem ( ) . unwrap ( ) . to_str ( ) . unwrap ( ) ;
@@ -163,7 +182,7 @@ impl CompilationSource {
163
182
format ! ( "Error reading file {:?}: {}" , path, err) ,
164
183
)
165
184
} ) ?;
166
- self . compile_string ( & source, mode, module_name. clone ( ) , || {
185
+ self . compile_string ( & source, mode, module_name. clone ( ) , compiler , || {
167
186
path. strip_prefix ( & * CARGO_MANIFEST_DIR )
168
187
. ok ( )
169
188
. unwrap_or ( & path)
@@ -239,35 +258,28 @@ impl PyCompileInput {
239
258
Some ( ident) => ident,
240
259
None => continue ,
241
260
} ;
261
+ let check_str = || match & name_value. lit {
262
+ Lit :: Str ( s) => Ok ( s) ,
263
+ _ => Err ( err_span ! ( name_value. lit, "{ident} must be a string" ) ) ,
264
+ } ;
242
265
if ident == "mode" {
243
- match & name_value. lit {
244
- Lit :: Str ( s) => match s. value ( ) . parse ( ) {
245
- Ok ( mode_val) => mode = Some ( mode_val) ,
246
- Err ( e) => bail_span ! ( s, "{}" , e) ,
247
- } ,
248
- _ => bail_span ! ( name_value. lit, "mode must be a string" ) ,
266
+ let s = check_str ( ) ?;
267
+ match s. value ( ) . parse ( ) {
268
+ Ok ( mode_val) => mode = Some ( mode_val) ,
269
+ Err ( e) => bail_span ! ( s, "{}" , e) ,
249
270
}
250
271
} else if ident == "module_name" {
251
- module_name = Some ( match & name_value. lit {
252
- Lit :: Str ( s) => s. value ( ) ,
253
- _ => bail_span ! ( name_value. lit, "module_name must be string" ) ,
254
- } )
272
+ module_name = Some ( check_str ( ) ?. value ( ) )
255
273
} else if ident == "source" {
256
274
assert_source_empty ( & source) ?;
257
- let code = match & name_value. lit {
258
- Lit :: Str ( s) => s. value ( ) ,
259
- _ => bail_span ! ( name_value. lit, "source must be a string" ) ,
260
- } ;
275
+ let code = check_str ( ) ?. value ( ) ;
261
276
source = Some ( CompilationSource {
262
277
kind : CompilationSourceKind :: SourceCode ( code) ,
263
278
span : extract_spans ( & name_value) . unwrap ( ) ,
264
279
} ) ;
265
280
} else if ident == "file" {
266
281
assert_source_empty ( & source) ?;
267
- let path = match & name_value. lit {
268
- Lit :: Str ( s) => PathBuf :: from ( s. value ( ) ) ,
269
- _ => bail_span ! ( name_value. lit, "source must be a string" ) ,
270
- } ;
282
+ let path = check_str ( ) ?. value ( ) . into ( ) ;
271
283
source = Some ( CompilationSource {
272
284
kind : CompilationSourceKind :: File ( path) ,
273
285
span : extract_spans ( & name_value) . unwrap ( ) ,
@@ -278,19 +290,13 @@ impl PyCompileInput {
278
290
}
279
291
280
292
assert_source_empty ( & source) ?;
281
- let path = match & name_value. lit {
282
- Lit :: Str ( s) => PathBuf :: from ( s. value ( ) ) ,
283
- _ => bail_span ! ( name_value. lit, "source must be a string" ) ,
284
- } ;
293
+ let path = check_str ( ) ?. value ( ) . into ( ) ;
285
294
source = Some ( CompilationSource {
286
295
kind : CompilationSourceKind :: Dir ( path) ,
287
296
span : extract_spans ( & name_value) . unwrap ( ) ,
288
297
} ) ;
289
298
} else if ident == "crate_name" {
290
- let name = match & name_value. lit {
291
- Lit :: Str ( s) => s. parse ( ) ?,
292
- _ => bail_span ! ( name_value. lit, "source must be a string" ) ,
293
- } ;
299
+ let name = check_str ( ) ?. parse ( ) ?;
294
300
crate_name = Some ( name) ;
295
301
}
296
302
}
@@ -351,12 +357,17 @@ struct PyCompileArgs {
351
357
crate_name : syn:: Path ,
352
358
}
353
359
354
- pub fn impl_py_compile ( input : TokenStream ) -> Result < TokenStream , Diagnostic > {
360
+ pub fn impl_py_compile (
361
+ input : TokenStream ,
362
+ compiler : & dyn Compiler ,
363
+ ) -> Result < TokenStream , Diagnostic > {
355
364
let input: PyCompileInput = parse2 ( input) ?;
356
365
let args = input. parse ( false ) ?;
357
366
358
367
let crate_name = args. crate_name ;
359
- let code = args. source . compile_single ( args. mode , args. module_name ) ?;
368
+ let code = args
369
+ . source
370
+ . compile_single ( args. mode , args. module_name , compiler) ?;
360
371
361
372
let bytes = code. to_bytes ( ) ;
362
373
let bytes = LitByteStr :: new ( & bytes, Span :: call_site ( ) ) ;
@@ -369,12 +380,15 @@ pub fn impl_py_compile(input: TokenStream) -> Result<TokenStream, Diagnostic> {
369
380
Ok ( output)
370
381
}
371
382
372
- pub fn impl_py_freeze ( input : TokenStream ) -> Result < TokenStream , Diagnostic > {
383
+ pub fn impl_py_freeze (
384
+ input : TokenStream ,
385
+ compiler : & dyn Compiler ,
386
+ ) -> Result < TokenStream , Diagnostic > {
373
387
let input: PyCompileInput = parse2 ( input) ?;
374
388
let args = input. parse ( true ) ?;
375
389
376
390
let crate_name = args. crate_name ;
377
- let code_map = args. source . compile ( args. mode , args. module_name ) ?;
391
+ let code_map = args. source . compile ( args. mode , args. module_name , compiler ) ?;
378
392
379
393
let data =
380
394
rustpython_compiler_core:: frozen_lib:: encode_lib ( code_map. iter ( ) . map ( |( k, v) | ( & * * k, v) ) ) ;
0 commit comments