diff --git a/src/annot.rs b/src/annot.rs index 128c289..30bf7d3 100644 --- a/src/annot.rs +++ b/src/annot.rs @@ -420,7 +420,7 @@ where Ok(AnnotPath { segments }) } - fn parse_datatype_ctor_args(&mut self) -> Result>> { + fn parse_arg_terms(&mut self) -> Result>> { if self.look_ahead_token(0).is_none() { return Ok(Vec::new()); } @@ -478,6 +478,30 @@ where FormulaOrTerm::Term(var, sort.clone()) } _ => { + // If the single-segment identifier is followed by parethesized arguments, + // parse them as user-defined predicate calls. + let next_tt = self.look_ahead_token_tree(0); + + if let Some(TokenTree::Delimited(_, _, Delimiter::Parenthesis, args)) = + next_tt + { + let args = args.clone(); + self.consume(); + + let pred_symbol = chc::UserDefinedPred::new(ident.name.to_string()); + let pred = chc::Pred::UserDefined(pred_symbol); + + let mut parser = Parser { + resolver: self.boxed_resolver(), + cursor: args.trees(), + formula_existentials: self.formula_existentials.clone(), + }; + let args = parser.parse_arg_terms()?; + + let atom = chc::Atom::new(pred, args); + let formula = chc::Formula::Atom(atom); + return Ok(FormulaOrTerm::Formula(formula)); + } let (v, sort) = self.resolve(*ident)?; FormulaOrTerm::Term(chc::Term::var(v), sort) } @@ -497,7 +521,7 @@ where cursor: s.trees(), formula_existentials: self.formula_existentials.clone(), }; - let args = parser.parse_datatype_ctor_args()?; + let args = parser.parse_arg_terms()?; parser.end_of_input()?; let (term, sort) = path.to_datatype_ctor(args); FormulaOrTerm::Term(term, sort) diff --git a/src/chc.rs b/src/chc.rs index 67ef92b..e9741bb 100644 --- a/src/chc.rs +++ b/src/chc.rs @@ -902,12 +902,39 @@ impl MatcherPred { } } +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub struct UserDefinedPred { + inner: String, +} + +impl std::fmt::Display for UserDefinedPred { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + self.inner.fmt(f) + } +} + +impl<'a, 'b, D> Pretty<'a, D, termcolor::ColorSpec> for &'b UserDefinedPred +where + D: pretty::DocAllocator<'a, termcolor::ColorSpec>, +{ + fn pretty(self, allocator: &'a D) -> pretty::DocBuilder<'a, D, termcolor::ColorSpec> { + allocator.text(self.inner.clone()) + } +} + +impl UserDefinedPred { + pub fn new(inner: String) -> Self { + Self { inner } + } +} + /// A predicate. #[derive(Debug, Clone, PartialEq, Eq, Hash)] pub enum Pred { Known(KnownPred), Var(PredVarId), Matcher(MatcherPred), + UserDefined(UserDefinedPred), } impl std::fmt::Display for Pred { @@ -916,6 +943,7 @@ impl std::fmt::Display for Pred { Pred::Known(p) => p.fmt(f), Pred::Var(p) => p.fmt(f), Pred::Matcher(p) => p.fmt(f), + Pred::UserDefined(p) => p.fmt(f), } } } @@ -930,6 +958,7 @@ where Pred::Known(p) => p.pretty(allocator), Pred::Var(p) => p.pretty(allocator), Pred::Matcher(p) => p.pretty(allocator), + Pred::UserDefined(p) => p.pretty(allocator), } } } @@ -952,12 +981,19 @@ impl From for Pred { } } +impl From for Pred { + fn from(p: UserDefinedPred) -> Pred { + Pred::UserDefined(p) + } +} + impl Pred { pub fn name(&self) -> std::borrow::Cow<'static, str> { match self { Pred::Known(p) => p.name().into(), Pred::Var(p) => p.to_string().into(), Pred::Matcher(p) => p.name().into(), + Pred::UserDefined(p) => p.to_string().into(), } } @@ -966,6 +1002,7 @@ impl Pred { Pred::Known(p) => p.is_negative(), Pred::Var(_) => false, Pred::Matcher(_) => false, + Pred::UserDefined(_) => false, } } @@ -974,6 +1011,7 @@ impl Pred { Pred::Known(p) => p.is_infix(), Pred::Var(_) => false, Pred::Matcher(_) => false, + Pred::UserDefined(_) => false, } } @@ -982,6 +1020,7 @@ impl Pred { Pred::Known(p) => p.is_top(), Pred::Var(_) => false, Pred::Matcher(_) => false, + Pred::UserDefined(_) => false, } } @@ -990,6 +1029,7 @@ impl Pred { Pred::Known(p) => p.is_bottom(), Pred::Var(_) => false, Pred::Matcher(_) => false, + Pred::UserDefined(_) => false, } } } diff --git a/src/chc/unbox.rs b/src/chc/unbox.rs index 8ed320f..30ea31c 100644 --- a/src/chc/unbox.rs +++ b/src/chc/unbox.rs @@ -42,6 +42,7 @@ fn unbox_pred(pred: Pred) -> Pred { Pred::Known(pred) => Pred::Known(pred), Pred::Var(pred) => Pred::Var(pred), Pred::Matcher(pred) => unbox_matcher_pred(pred), + Pred::UserDefined(pred) => Pred::UserDefined(pred), } } diff --git a/tests/ui/pass/annot_preds_raw_command.rs b/tests/ui/pass/annot_preds_raw_command.rs new file mode 100644 index 0000000..a596c4f --- /dev/null +++ b/tests/ui/pass/annot_preds_raw_command.rs @@ -0,0 +1,21 @@ +//@check-pass +//@compile-flags: -Adead_code -C debug-assertions=off + +#![feature(custom_inner_attributes)] +#![thrust::raw_command("(define-fun is_double ((x Int) (doubled_x Int)) Bool + (= + (* x 2) + doubled_x + ) +)")] + +#[thrust::requires(true)] +#[thrust::ensures(is_double(x, result))] +fn double(x: i64) -> i64 { + x + x +} + +fn main() { + let a = 3; + assert!(double(a) == 6); +} \ No newline at end of file diff --git a/tests/ui/pass/annot_preds_raw_command_multi.rs b/tests/ui/pass/annot_preds_raw_command_multi.rs new file mode 100644 index 0000000..0dbfb0d --- /dev/null +++ b/tests/ui/pass/annot_preds_raw_command_multi.rs @@ -0,0 +1,36 @@ +//@check-pass +//@compile-flags: -Adead_code -C debug-assertions=off + +#![feature(custom_inner_attributes)] +#![thrust::raw_command("(define-fun is_double ((x Int) (doubled_x Int)) Bool + (= + (* x 2) + doubled_x + ) +)")] + +#![thrust::raw_command("(define-fun is_triple ((x Int) (tripled_x Int)) Bool + (= + (* x 3) + tripled_x + ) +)")] + +#[thrust::requires(true)] +#[thrust::ensures(is_double(x, result))] +fn double(x: i64) -> i64 { + x + x +} + +#[thrust::require(is_double(x, doubled_x))] +#[thrust::ensures(is_triple(x, result))] +fn triple(x: i64, doubled_x: i64) -> i64 { + x + doubled_x +} + +fn main() { + let a = 3; + let double_a = double(a); + assert!(double_a == 6); + assert!(triple(a, double_a) == 9); +} \ No newline at end of file