Skip to content

Commit 396a0ca

Browse files
authored
Basic Match statements (RustPython#5485)
Signed-off-by: Ashwin Naren <[email protected]>
1 parent a500178 commit 396a0ca

File tree

3 files changed

+166
-9
lines changed

3 files changed

+166
-9
lines changed

compiler/codegen/src/compile.rs

Lines changed: 151 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,10 @@ use num_complex::Complex64;
1818
use num_traits::ToPrimitive;
1919
use rustpython_ast::located::{self as located_ast, Located};
2020
use rustpython_compiler_core::{
21-
bytecode::{self, Arg as OpArgMarker, CodeObject, ConstantData, Instruction, OpArg, OpArgType},
21+
bytecode::{
22+
self, Arg as OpArgMarker, CodeObject, ComparisonOperator, ConstantData, Instruction, OpArg,
23+
OpArgType,
24+
},
2225
Mode,
2326
};
2427
use rustpython_parser_core::source_code::{LineNumber, SourceLocation};
@@ -211,6 +214,12 @@ macro_rules! emit {
211214
};
212215
}
213216

217+
struct PatternContext {
218+
current_block: usize,
219+
blocks: Vec<ir::BlockIdx>,
220+
allow_irrefutable: bool,
221+
}
222+
214223
impl Compiler {
215224
fn new(opts: CompileOpts, source_path: String, code_name: String) -> Self {
216225
let module_code = ir::CodeInfo {
@@ -1755,14 +1764,152 @@ impl Compiler {
17551764
Ok(())
17561765
}
17571766

1767+
fn compile_pattern_value(
1768+
&mut self,
1769+
value: &located_ast::PatternMatchValue,
1770+
_pattern_context: &mut PatternContext,
1771+
) -> CompileResult<()> {
1772+
self.compile_expression(&value.value)?;
1773+
emit!(
1774+
self,
1775+
Instruction::CompareOperation {
1776+
op: ComparisonOperator::Equal
1777+
}
1778+
);
1779+
Ok(())
1780+
}
1781+
1782+
fn compile_pattern_as(
1783+
&mut self,
1784+
as_pattern: &located_ast::PatternMatchAs,
1785+
pattern_context: &mut PatternContext,
1786+
) -> CompileResult<()> {
1787+
if as_pattern.pattern.is_none() && !pattern_context.allow_irrefutable {
1788+
// TODO: better error message
1789+
if let Some(_name) = as_pattern.name.as_ref() {
1790+
return Err(
1791+
self.error_loc(CodegenErrorType::InvalidMatchCase, as_pattern.location())
1792+
);
1793+
}
1794+
return Err(self.error_loc(CodegenErrorType::InvalidMatchCase, as_pattern.location()));
1795+
}
1796+
// Need to make a copy for (possibly) storing later:
1797+
emit!(self, Instruction::Duplicate);
1798+
if let Some(pattern) = &as_pattern.pattern {
1799+
self.compile_pattern_inner(pattern, pattern_context)?;
1800+
}
1801+
if let Some(name) = as_pattern.name.as_ref() {
1802+
self.store_name(name.as_str())?;
1803+
} else {
1804+
emit!(self, Instruction::Pop);
1805+
}
1806+
Ok(())
1807+
}
1808+
1809+
fn compile_pattern_inner(
1810+
&mut self,
1811+
pattern_type: &located_ast::Pattern,
1812+
pattern_context: &mut PatternContext,
1813+
) -> CompileResult<()> {
1814+
match &pattern_type {
1815+
located_ast::Pattern::MatchValue(value) => {
1816+
self.compile_pattern_value(value, pattern_context)
1817+
}
1818+
located_ast::Pattern::MatchAs(as_pattern) => {
1819+
self.compile_pattern_as(as_pattern, pattern_context)
1820+
}
1821+
_ => {
1822+
eprintln!("not implemented pattern type: {pattern_type:?}");
1823+
Err(self.error(CodegenErrorType::NotImplementedYet))
1824+
}
1825+
}
1826+
}
1827+
1828+
fn compile_pattern(
1829+
&mut self,
1830+
pattern_type: &located_ast::Pattern,
1831+
pattern_context: &mut PatternContext,
1832+
) -> CompileResult<()> {
1833+
self.compile_pattern_inner(pattern_type, pattern_context)?;
1834+
emit!(
1835+
self,
1836+
Instruction::JumpIfFalse {
1837+
target: pattern_context.blocks[pattern_context.current_block + 1]
1838+
}
1839+
);
1840+
Ok(())
1841+
}
1842+
1843+
fn compile_match_inner(
1844+
&mut self,
1845+
subject: &located_ast::Expr,
1846+
cases: &[located_ast::MatchCase],
1847+
pattern_context: &mut PatternContext,
1848+
) -> CompileResult<()> {
1849+
self.compile_expression(subject)?;
1850+
pattern_context.blocks = std::iter::repeat_with(|| self.new_block())
1851+
.take(cases.len() + 1)
1852+
.collect::<Vec<_>>();
1853+
let end_block = *pattern_context.blocks.last().unwrap();
1854+
1855+
let _match_case_type = cases.last().expect("cases is not empty");
1856+
// TODO: get proper check for default case
1857+
// let has_default = match_case_type.pattern.is_match_as() && 1 < cases.len();
1858+
let has_default = false;
1859+
for i in 0..cases.len() - (has_default as usize) {
1860+
self.switch_to_block(pattern_context.blocks[i]);
1861+
pattern_context.current_block = i;
1862+
pattern_context.allow_irrefutable = cases[i].guard.is_some() || i == cases.len() - 1;
1863+
let m = &cases[i];
1864+
// Only copy the subject if we're *not* on the last case:
1865+
if i != cases.len() - has_default as usize - 1 {
1866+
emit!(self, Instruction::Duplicate);
1867+
}
1868+
self.compile_pattern(&m.pattern, pattern_context)?;
1869+
self.compile_statements(&m.body)?;
1870+
emit!(self, Instruction::Jump { target: end_block });
1871+
}
1872+
// TODO: below code is not called and does not work
1873+
if has_default {
1874+
// A trailing "case _" is common, and lets us save a bit of redundant
1875+
// pushing and popping in the loop above:
1876+
let m = &cases.last().unwrap();
1877+
self.switch_to_block(*pattern_context.blocks.last().unwrap());
1878+
if cases.len() == 1 {
1879+
// No matches. Done with the subject:
1880+
emit!(self, Instruction::Pop);
1881+
} else {
1882+
// Show line coverage for default case (it doesn't create bytecode)
1883+
// emit!(self, Instruction::Nop);
1884+
}
1885+
self.compile_statements(&m.body)?;
1886+
}
1887+
1888+
self.switch_to_block(end_block);
1889+
1890+
let code = self.current_code_info();
1891+
pattern_context
1892+
.blocks
1893+
.iter()
1894+
.zip(pattern_context.blocks.iter().skip(1))
1895+
.for_each(|(a, b)| {
1896+
code.blocks[a.0 as usize].next = *b;
1897+
});
1898+
Ok(())
1899+
}
1900+
17581901
fn compile_match(
17591902
&mut self,
17601903
subject: &located_ast::Expr,
17611904
cases: &[located_ast::MatchCase],
17621905
) -> CompileResult<()> {
1763-
eprintln!("match subject: {subject:?}");
1764-
eprintln!("match cases: {cases:?}");
1765-
Err(self.error(CodegenErrorType::NotImplementedYet))
1906+
let mut pattern_context = PatternContext {
1907+
current_block: usize::MAX,
1908+
blocks: Vec::new(),
1909+
allow_irrefutable: false,
1910+
};
1911+
self.compile_match_inner(subject, cases, &mut pattern_context)?;
1912+
Ok(())
17661913
}
17671914

17681915
fn compile_chained_comparison(

compiler/codegen/src/error.rs

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,8 @@ pub enum CodegenErrorType {
3030
TooManyStarUnpack,
3131
EmptyWithItems,
3232
EmptyWithBody,
33+
DuplicateStore(String),
34+
InvalidMatchCase,
3335
NotImplementedYet, // RustPython marker for unimplemented features
3436
}
3537

@@ -75,6 +77,12 @@ impl fmt::Display for CodegenErrorType {
7577
EmptyWithBody => {
7678
write!(f, "empty body on With")
7779
}
80+
DuplicateStore(s) => {
81+
write!(f, "duplicate store {s}")
82+
}
83+
InvalidMatchCase => {
84+
write!(f, "invalid match case")
85+
}
7886
NotImplementedYet => {
7987
write!(f, "RustPython does not implement this feature yet")
8088
}

compiler/codegen/src/symboltable.rs

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -886,11 +886,13 @@ impl SymbolTableBuilder {
886886
self.scan_statements(orelse)?;
887887
self.scan_statements(finalbody)?;
888888
}
889-
Stmt::Match(StmtMatch { subject, .. }) => {
890-
return Err(SymbolTableError {
891-
error: "match expression is not implemented yet".to_owned(),
892-
location: Some(subject.location()),
893-
});
889+
Stmt::Match(StmtMatch { subject, cases, .. }) => {
890+
self.scan_expression(subject, ExpressionContext::Load)?;
891+
for case in cases {
892+
// TODO: below
893+
// self.scan_pattern(&case.pattern, ExpressionContext::Load)?;
894+
self.scan_statements(&case.body)?;
895+
}
894896
}
895897
Stmt::Raise(StmtRaise { exc, cause, .. }) => {
896898
if let Some(expression) = exc {

0 commit comments

Comments
 (0)