// glsl_lang_pp/processor/expr.rs

1use std::iter::Peekable;
2
3use crate::{
4    parser::SyntaxKind::{self, *},
5    types::Token,
6    util::Unescaped,
7};
8
9use super::{event::OutputToken, ProcessorState};
10
11#[derive(Debug, Clone)]
12pub struct ExprEvaluator<'i, I: Iterator<Item = &'i OutputToken>> {
13    input: Peekable<I>,
14    state: &'i ProcessorState,
15}
16
17#[derive(Debug, Clone, PartialEq, Eq)]
18pub enum EvalResult {
19    Constant(Result<i32, ()>),
20    Token(OutputToken),
21}
22
23impl<'i, I: Iterator<Item = &'i OutputToken>> ExprEvaluator<'i, I> {
24    pub fn new(input: I, state: &'i ProcessorState) -> Self {
25        Self {
26            input: input.peekable(),
27            state,
28        }
29    }
30
31    fn bump(&mut self) -> Option<&'i OutputToken> {
32        loop {
33            let token = self.input.next();
34            if let Some(token) = token {
35                if !token.kind().is_trivia() {
36                    return Some(token);
37                }
38            } else {
39                return None;
40            }
41        }
42    }
43
44    fn peek(&mut self) -> Option<&'i OutputToken> {
45        loop {
46            let token = self.input.peek().copied();
47            if let Some(token) = token {
48                if !token.kind().is_trivia() {
49                    return Some(token);
50                } else {
51                    self.input.next();
52                }
53            } else {
54                return None;
55            }
56        }
57    }
58
59    fn peek_kind(&mut self) -> Option<SyntaxKind> {
60        self.peek().map(|token| token.kind())
61    }
62
63    fn primary(&mut self) -> Option<Result<i32, ()>> {
64        let token = self.peek()?;
65
66        match token.kind() {
67            DIGITS => {
68                // Try to parse the value before bumping. If parsing fails, we'll return the DIGITS
69                // token unparsed
70                let value = match Token::parse_digits(&Unescaped::new(token.text()).to_string()) {
71                    Token::UINT_CONST(value) if value <= i32::MAX as u32 => Some(value as i32),
72                    Token::INT_CONST(value) => Some(value),
73                    _ => None,
74                };
75                if let Some(value) = value {
76                    self.bump();
77                    return Some(Ok(value));
78                }
79            }
80            IDENT_KW => {
81                self.bump();
82                return Some(Ok(0));
83            }
84            LPAREN => {
85                self.bump();
86
87                let inner = self.expr();
88
89                // Consume RPAREN, or return None
90                if self.peek()?.kind() == RPAREN {
91                    self.bump();
92                    return inner;
93                }
94            }
95            _ => {}
96        }
97
98        None
99    }
100
101    fn unary(&mut self) -> Option<Result<i32, ()>> {
102        match self.peek_kind()? {
103            PLUS => {
104                self.bump();
105                self.unary()
106            }
107            DASH => {
108                self.bump();
109                self.unary().map(|result| result.map(|val| -val))
110            }
111            TILDE => {
112                self.bump();
113                self.unary().map(|result| result.map(|val| !val))
114            }
115            BANG => {
116                self.bump();
117                self.unary()
118                    .map(|result| result.map(|val| if val == 0 { 1 } else { 0 }))
119            }
120            DEFINED => {
121                self.bump();
122
123                match self.peek()?.kind() {
124                    IDENT_KW => {
125                        // Free-standing form, get the ident name
126                        let ident = Unescaped::new(self.bump()?.text()).to_string();
127                        Some(Ok(if self.state.get_definition(&ident).is_some() {
128                            1
129                        } else {
130                            0
131                        }))
132                    }
133                    LPAREN => {
134                        // Parenthesis form
135                        self.bump();
136
137                        // Try to find an ident
138                        if let Some(ident) = self.peek().and_then(|token| {
139                            if token.kind() == IDENT_KW {
140                                Some(Unescaped::new(token.text()).to_string())
141                            } else {
142                                None
143                            }
144                        }) {
145                            // Found an ident, bump it
146                            self.bump();
147
148                            // Look for the RPAREN
149                            if self.peek()?.kind() == RPAREN {
150                                self.bump();
151                            } else {
152                                // Missing RPAREN or extra tokens
153                                return None;
154                            }
155
156                            return Some(Ok(if self.state.get_definition(&ident).is_some() {
157                                1
158                            } else {
159                                0
160                            }));
161                        }
162
163                        // Invalid
164                        None
165                    }
166                    _ => {
167                        // Invalid
168                        None
169                    }
170                }
171            }
172            _ => self.primary(),
173        }
174    }
175
176    fn binary_op(
177        lhs: Option<Result<i32, ()>>,
178        rhs: Option<Result<i32, ()>>,
179        f: impl FnOnce(i32, i32) -> Result<(i32, bool), ()>,
180    ) -> Option<Result<i32, ()>> {
181        lhs.zip(rhs).map(|(lhs, rhs)| {
182            lhs.and_then(|a| {
183                rhs.and_then(|b| f(a, b).and_then(|(val, ovf)| if ovf { Err(()) } else { Ok(val) }))
184            })
185        })
186    }
187
188    fn multiplicative(&mut self) -> Option<Result<i32, ()>> {
189        let mut lhs = self.unary();
190
191        while let Some(kind) = self.peek_kind() {
192            match kind {
193                ASTERISK => {
194                    self.bump();
195                    lhs = Self::binary_op(lhs, self.unary(), |a, b| Ok(a.overflowing_mul(b)));
196                }
197                SLASH => {
198                    self.bump();
199                    lhs = Self::binary_op(lhs, self.unary(), |a, b| {
200                        if b == 0 {
201                            Err(())
202                        } else {
203                            Ok(a.overflowing_div(b))
204                        }
205                    });
206                }
207                PERCENT => {
208                    self.bump();
209                    lhs = Self::binary_op(lhs, self.unary(), |a, b| {
210                        if b == 0 {
211                            Err(())
212                        } else {
213                            Ok(a.overflowing_rem(b))
214                        }
215                    });
216                }
217                _ => {
218                    break;
219                }
220            }
221        }
222
223        lhs
224    }
225
226    fn additive(&mut self) -> Option<Result<i32, ()>> {
227        let mut lhs = self.multiplicative();
228
229        while let Some(kind) = self.peek_kind() {
230            match kind {
231                PLUS => {
232                    self.bump();
233                    lhs =
234                        Self::binary_op(
235                            lhs,
236                            self.multiplicative(),
237                            |a, b| Ok(a.overflowing_add(b)),
238                        );
239                }
240                DASH => {
241                    self.bump();
242                    lhs =
243                        Self::binary_op(
244                            lhs,
245                            self.multiplicative(),
246                            |a, b| Ok(a.overflowing_sub(b)),
247                        );
248                }
249                _ => {
250                    break;
251                }
252            }
253        }
254
255        lhs
256    }
257
258    fn shift(&mut self) -> Option<Result<i32, ()>> {
259        let mut lhs = self.additive();
260
261        while let Some(kind) = self.peek_kind() {
262            match kind {
263                LEFT_OP => {
264                    self.bump();
265                    lhs = Self::binary_op(lhs, self.additive(), |a, b| {
266                        if b < 0 {
267                            Err(())
268                        } else {
269                            Ok(a.overflowing_shl(b as u32))
270                        }
271                    });
272                }
273                RIGHT_OP => {
274                    self.bump();
275                    lhs = Self::binary_op(lhs, self.additive(), |a, b| {
276                        if b < 0 {
277                            Err(())
278                        } else {
279                            Ok(a.overflowing_shr(b as u32))
280                        }
281                    });
282                }
283                _ => {
284                    break;
285                }
286            }
287        }
288
289        lhs
290    }
291
292    fn relational(&mut self) -> Option<Result<i32, ()>> {
293        let mut lhs = self.shift();
294
295        while let Some(kind) = self.peek_kind() {
296            match kind {
297                LANGLE => {
298                    self.bump();
299                    lhs = Self::binary_op(lhs, self.shift(), |a, b| {
300                        Ok((if a < b { 1 } else { 0 }, false))
301                    });
302                }
303                RANGLE => {
304                    self.bump();
305                    lhs = Self::binary_op(lhs, self.shift(), |a, b| {
306                        Ok((if a > b { 1 } else { 0 }, false))
307                    });
308                }
309                LE_OP => {
310                    self.bump();
311                    lhs = Self::binary_op(lhs, self.shift(), |a, b| {
312                        Ok((if a <= b { 1 } else { 0 }, false))
313                    });
314                }
315                GE_OP => {
316                    self.bump();
317                    lhs = Self::binary_op(lhs, self.shift(), |a, b| {
318                        Ok((if a >= b { 1 } else { 0 }, false))
319                    });
320                }
321                _ => {
322                    break;
323                }
324            }
325        }
326
327        lhs
328    }
329
330    fn equality(&mut self) -> Option<Result<i32, ()>> {
331        let mut lhs = self.relational();
332
333        while let Some(kind) = self.peek_kind() {
334            match kind {
335                EQ_OP => {
336                    self.bump();
337                    lhs = Self::binary_op(lhs, self.relational(), |a, b| {
338                        Ok((if a == b { 1 } else { 0 }, false))
339                    });
340                }
341                NE_OP => {
342                    self.bump();
343                    lhs = Self::binary_op(lhs, self.relational(), |a, b| {
344                        Ok((if a != b { 1 } else { 0 }, false))
345                    });
346                }
347                _ => {
348                    break;
349                }
350            }
351        }
352
353        lhs
354    }
355
356    fn and(&mut self) -> Option<Result<i32, ()>> {
357        let mut lhs = self.equality();
358
359        while self
360            .peek_kind()
361            .map(|kind| kind == AMPERSAND)
362            .unwrap_or(false)
363        {
364            self.bump();
365            lhs = Self::binary_op(lhs, self.equality(), |a, b| Ok((a & b, false)));
366        }
367
368        lhs
369    }
370
371    fn xor(&mut self) -> Option<Result<i32, ()>> {
372        let mut lhs = self.and();
373
374        while self.peek_kind().map(|kind| kind == CARET).unwrap_or(false) {
375            self.bump();
376            lhs = Self::binary_op(lhs, self.and(), |a, b| Ok((a ^ b, false)));
377        }
378
379        lhs
380    }
381
382    fn or(&mut self) -> Option<Result<i32, ()>> {
383        let mut lhs = self.xor();
384
385        while self.peek_kind().map(|kind| kind == BAR).unwrap_or(false) {
386            self.bump();
387            lhs = Self::binary_op(lhs, self.xor(), |a, b| Ok((a | b, false)));
388        }
389
390        lhs
391    }
392
393    fn logical_and(&mut self) -> Option<Result<i32, ()>> {
394        let mut lhs = self.or();
395
396        while self.peek_kind().map(|kind| kind == AND_OP).unwrap_or(false) {
397            self.bump();
398            lhs = Self::binary_op(lhs, self.or(), |a, b| {
399                Ok((if a != 0 && b != 0 { 1 } else { 0 }, false))
400            });
401        }
402
403        lhs
404    }
405
406    fn logical_or(&mut self) -> Option<Result<i32, ()>> {
407        let mut lhs = self.logical_and();
408
409        while self.peek_kind().map(|kind| kind == OR_OP).unwrap_or(false) {
410            self.bump();
411            lhs = Self::binary_op(lhs, self.logical_and(), |a, b| {
412                Ok((if a != 0 || b != 0 { 1 } else { 0 }, false))
413            });
414        }
415
416        lhs
417    }
418
419    fn expr(&mut self) -> Option<Result<i32, ()>> {
420        self.logical_or()
421    }
422
423    fn next_result(&mut self) -> Option<EvalResult> {
424        match self.expr() {
425            Some(value) => Some(EvalResult::Constant(value)),
426            None => Some(EvalResult::Token(self.bump().cloned()?)),
427        }
428    }
429}
430
431impl<'i, I: Iterator<Item = &'i OutputToken>> Iterator for ExprEvaluator<'i, I> {
432    type Item = EvalResult;
433
434    fn next(&mut self) -> Option<Self::Item> {
435        self.next_result()
436    }
437}
438
#[cfg(test)]
mod tests {
    use std::str::FromStr;

    use lang_util::FileId;

    use crate::{
        parser::SyntaxKind,
        processor::{
            event::Event,
            nodes::{Define, DefineObject},
            ProcessorState,
        },
    };

    use super::ExprEvaluator;

    use self::EvalResult::*;

    /// Comparable mirror of `super::EvalResult` that keeps only the token kind, so results
    /// can be checked with `assert_eq!`.
    #[derive(Debug, Clone, PartialEq)]
    enum EvalResult {
        Constant(Result<i32, ()>),
        Token(SyntaxKind),
    }

    /// Lexes `input` and evaluates it with a state where only `IS_DEFINED` is defined.
    fn eval(input: &str) -> Vec<EvalResult> {
        let tokens: Vec<_> = crate::processor::str::process(input, ProcessorState::default())
            .filter_map(|evt| evt.ok().and_then(Event::into_token))
            .collect();

        // Outside of unit tests this state is supplied by the processor
        let mut state = ProcessorState::default();
        state.definition(
            Define::object(
                "IS_DEFINED".into(),
                DefineObject::from_str("1").unwrap(),
                false,
            ),
            FileId::default(),
        );

        ExprEvaluator::new(tokens.iter(), &state)
            .map(|result| match result {
                super::EvalResult::Constant(value) => Constant(value),
                super::EvalResult::Token(token) => Token(token.kind()),
            })
            .collect()
    }

    /// Asserts each input evaluates to exactly one constant equal to its expected value.
    fn assert_values(cases: &[(&str, Result<i32, ()>)]) {
        for (input, expected) in cases {
            assert_eq!(eval(input), vec![Constant(*expected)], "evaluating {:?}", input);
        }
    }

    /// Asserts each input produces no results at all.
    fn assert_no_results(inputs: &[&str]) {
        for input in inputs {
            assert_eq!(eval(input), Vec::<EvalResult>::new(), "evaluating {:?}", input);
        }
    }

    #[test]
    fn test_parenthesis() {
        assert_values(&[
            ("2 + 3 * 4", Ok(14)),
            ("(2 + 3) * 4", Ok(20)),
            ("(((2) + (3)) * (4))", Ok(20)),
        ]);
    }

    #[test]
    fn test_primary() {
        assert_values(&[("0", Ok(0)), ("1", Ok(1)), ("FOO", Ok(0))]);
    }

    #[test]
    fn test_unary() {
        assert_values(&[
            ("+0", Ok(0)),
            ("+1", Ok(1)),
            ("+FOO", Ok(0)),
            ("-0", Ok(0)),
            ("-1", Ok(-1)),
            ("-FOO", Ok(0)),
            ("~0", Ok(!0)),
            ("~1", Ok(!1)),
            ("~FOO", Ok(!0)),
            ("!0", Ok(1)),
            ("!1", Ok(0)),
            ("!FOO", Ok(1)),
            ("defined IS_DEFINED", Ok(1)),
            ("defined NOT_DEFINED", Ok(0)),
            ("defined ( IS_DEFINED )", Ok(1)),
            ("defined ( NOT_DEFINED )", Ok(0)),
            ("!defined IS_DEFINED", Ok(0)),
            ("!defined NOT_DEFINED", Ok(1)),
            ("!defined ( IS_DEFINED )", Ok(0)),
            ("!defined ( NOT_DEFINED )", Ok(1)),
        ]);

        // Invalid expressions
        assert_no_results(&["defined +", "defined ( IS_DEFINED ", "defined ( NOT_DEFINED "]);
    }

    #[test]
    fn test_multiplicative() {
        assert_values(&[
            ("1 * 2", Ok(2)),
            ("2 * 3", Ok(6)),
            ("1 / 2", Ok(0)),
            ("2 / 3", Ok(0)),
            ("6 / 2", Ok(3)),
            ("1 / 0", Err(())),
            ("1 % 2", Ok(1)),
            ("2 % 3", Ok(2)),
            ("6 % 2", Ok(0)),
            ("1 % 0", Err(())),
        ]);
    }

    #[test]
    fn test_additive() {
        assert_values(&[
            ("1 + 2", Ok(3)),
            ("2 + 3", Ok(5)),
            ("1 - 2", Ok(-1)),
            ("2 - 3", Ok(-1)),
            ("6 - 2", Ok(4)),
            ("1 - 0", Ok(1)),
        ]);
    }

    #[test]
    fn test_shift() {
        assert_values(&[
            ("1 << 2", Ok(4)),
            ("2 << 3", Ok(16)),
            ("1 >> 2", Ok(0)),
            ("2 >> 3", Ok(0)),
            ("6 >> 2", Ok(1)),
            ("1 >> 0", Ok(1)),
        ]);
    }

    #[test]
    fn test_relational() {
        assert_values(&[
            ("1 < 2", Ok(1)),
            ("2 < 1", Ok(0)),
            ("2 < 2", Ok(0)),
            ("1 > 2", Ok(0)),
            ("2 > 1", Ok(1)),
            ("2 > 2", Ok(0)),
            ("1 <= 2", Ok(1)),
            ("2 <= 1", Ok(0)),
            ("2 <= 2", Ok(1)),
            ("1 >= 2", Ok(0)),
            ("2 >= 1", Ok(1)),
            ("2 >= 2", Ok(1)),
        ]);
    }

    #[test]
    fn test_equality() {
        assert_values(&[
            ("2 == 1", Ok(0)),
            ("2 == 2", Ok(1)),
            ("1 != 2", Ok(1)),
            ("2 != 2", Ok(0)),
        ]);
    }

    #[test]
    fn test_and() {
        assert_values(&[("2 & 1", Ok(0)), ("3 & 2", Ok(2))]);
    }

    #[test]
    fn test_xor() {
        assert_values(&[("2 ^ 1", Ok(3)), ("3 ^ 2", Ok(1))]);
    }

    #[test]
    fn test_or() {
        assert_values(&[("2 | 1", Ok(3)), ("3 | 2", Ok(3))]);
    }

    #[test]
    fn test_logical_and() {
        assert_values(&[
            ("2 && 0", Ok(0)),
            ("3 && 2", Ok(1)),
            ("0 && 2", Ok(0)),
            ("0 && 0", Ok(0)),
        ]);
    }

    #[test]
    fn test_logical_or() {
        assert_values(&[
            ("2 || 0", Ok(1)),
            ("3 || 2", Ok(1)),
            ("0 || 2", Ok(1)),
            ("0 || 0", Ok(0)),
        ]);
    }

    #[test]
    fn test_overflow() {
        assert_values(&[("1 << 60", Err(()))]);
    }
}
644}