Last active
May 8, 2024 18:21
-
-
Save mymmrac/516a0a755bb24975ed953c6acdd1666b to your computer and use it in GitHub Desktop.
Math expression interpreter in Python (supports unary and binary operations)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from __future__ import annotations | |
import math | |
from enum import Enum | |
from typing import List | |
class TokenType(Enum):
    """Kinds of tokens the expression interpreter distinguishes."""

    # Grouping.
    LPAREN = 1
    RPAREN = 2

    # Binary arithmetic operators.
    PLUS = 3
    MINUS = 4
    MULTIPLY = 5
    DIVIDE = 6
    DIVIDE_INT = 7
    MODULO = 8
    POWER = 9

    # Prefix operators: named functions and the unary signs.
    FUNC = 10
    UNARY_PLUS = 11
    UNARY_MINUS = 12

    # Operands.
    NUMBER = 13
class Token: | |
def __init__(self, typ: TokenType, text: str, pos: int): | |
self.type = typ | |
self.text = text | |
self.pos = pos | |
def is_operand(self) -> bool: | |
return self.type == TokenType.NUMBER | |
def is_unary_operator(self) -> bool: | |
return self.type in { | |
TokenType.UNARY_PLUS, TokenType.UNARY_MINUS, TokenType.FUNC | |
} | |
def is_binary_operator(self) -> bool: | |
return self.type in { | |
TokenType.PLUS, TokenType.MINUS, TokenType.MULTIPLY, TokenType.DIVIDE, TokenType.DIVIDE_INT, | |
TokenType.MODULO, TokenType.POWER, | |
} | |
def precedence(self): | |
match self.type: | |
case TokenType.PLUS | TokenType.MINUS: | |
return 1 | |
case TokenType.MULTIPLY | TokenType.DIVIDE | TokenType.DIVIDE_INT | TokenType.MODULO: | |
return 2 | |
case TokenType.UNARY_PLUS | TokenType.UNARY_MINUS: | |
return 3 | |
case TokenType.POWER: | |
return 4 | |
case TokenType.FUNC: | |
return 5 | |
case _: | |
raise ValueError(f"Unexpected token `{self.text}` at {self.pos}") | |
def number(self) -> float: | |
if self.type == TokenType.NUMBER: | |
return float(self.text) | |
else: | |
raise ValueError(f"Expected a number, but got `{self.text}` at {self.pos}") | |
def apply(self, values: List[Token]): | |
result: float | |
if self.is_unary_operator(): | |
v1 = self.get_values(values, 1)[0] | |
match self.type: | |
case TokenType.UNARY_PLUS: | |
result = v1.number() | |
case TokenType.UNARY_MINUS: | |
result = -v1.number() | |
case TokenType.FUNC: | |
match self.text: | |
case "sin": | |
result = math.sin(v1.number()) | |
case "cos": | |
result = math.cos(v1.number()) | |
case "round": | |
result = round(v1.number()) | |
case _: | |
raise ValueError(f"Unknown function `{self.text}` at {self.pos}") | |
case _: | |
raise ValueError(f"Unhandled unary operator `{self.text}` at {self.pos}") | |
elif self.is_binary_operator(): | |
v1, v2 = self.get_values(values, 2) | |
match self.type: | |
case TokenType.PLUS: | |
result = v1.number() + v2.number() | |
case TokenType.MINUS: | |
result = v1.number() - v2.number() | |
case TokenType.MULTIPLY: | |
result = v1.number() * v2.number() | |
case TokenType.DIVIDE: | |
result = v1.number() / v2.number() | |
case TokenType.DIVIDE_INT: | |
result = v1.number() // v2.number() | |
case TokenType.MODULO: | |
result = v1.number() % v2.number() | |
case TokenType.POWER: | |
result = v1.number() ** v2.number() | |
case _: | |
raise ValueError(f"Unhandled binary operator `{self.text}` at {self.pos}") | |
else: | |
raise ValueError(f"Unexpected token `{self.text}` at {self.pos}") | |
values.append(Token(TokenType.NUMBER, str(result), -1)) | |
def get_values(self, values: List[Token], n: int) -> List[Token]: | |
if len(values) < n: | |
raise ValueError(f"Not enough values for for operator `{self.text}` at {self.pos}") | |
vs = values[-n:] | |
values[-n:] = [] | |
return vs | |
def __str__(self): | |
return f"Token({self.type}, {self.text}, {self.pos})" | |
def print_tokens(tokens: List[Token]):
    """Print the string form of every token on a single space-separated line."""
    rendered = [str(token) for token in tokens]
    print(" ".join(rendered))
def _scan_number(text: str, start: int) -> int:
    """Return the index just past the number literal that begins at `start`.

    A literal is a run of decimal digits, an optional `.` with optional
    fraction digits, and an optional exponent (`e`/`E`, optional sign, at
    least one digit).  An `e` that is not followed by digits is not consumed.
    """
    size = len(text)
    end = start
    while end < size and text[end].isdecimal():
        end += 1
    if end < size and text[end] == ".":
        end += 1
        while end < size and text[end].isdecimal():
            end += 1
    if end < size and text[end] in "eE":
        probe = end + 1
        # A sign belongs to the number only directly after the exponent
        # marker; anywhere else it is an operator.
        if probe < size and text[probe] in "+-":
            probe += 1
        if probe < size and text[probe].isdecimal():
            while probe < size and text[probe].isdecimal():
                probe += 1
            end = probe
    return end


def tokenize(text: str) -> List[Token]:
    """Split an expression string into a list of tokens.

    Recognises parentheses, the operators `+ - * / // % **`, number literals
    (with optional fraction and exponent) and alphanumeric function names.
    Whitespace between tokens is ignored.  `+` and `-` are always emitted as
    PLUS / MINUS tokens here.

    Args:
        text: the expression source.

    Returns:
        The tokens in source order, each carrying its offset in `text`.

    Raises:
        ValueError: on a character that cannot start any token.
    """
    tokens: List[Token] = []
    size = len(text)
    i = 0
    while i < size:
        char = text[i]
        if char.isspace():
            i += 1
        elif char == '(':
            tokens.append(Token(TokenType.LPAREN, '(', i))
            i += 1
        elif char == ')':
            tokens.append(Token(TokenType.RPAREN, ')', i))
            i += 1
        elif char == '+':
            tokens.append(Token(TokenType.PLUS, '+', i))
            i += 1
        elif char == '-':
            tokens.append(Token(TokenType.MINUS, '-', i))
            i += 1
        elif char == '%':
            tokens.append(Token(TokenType.MODULO, '%', i))
            i += 1
        # Two-character operators must be tried before their one-character
        # prefixes; startswith with an offset avoids copying the tail.
        elif text.startswith("//", i):
            tokens.append(Token(TokenType.DIVIDE_INT, "//", i))
            i += 2
        elif text.startswith("**", i):
            tokens.append(Token(TokenType.POWER, "**", i))
            i += 2
        elif char == '*':
            tokens.append(Token(TokenType.MULTIPLY, '*', i))
            i += 1
        elif char == '/':
            tokens.append(Token(TokenType.DIVIDE, '/', i))
            i += 1
        elif char.isdecimal():
            j = _scan_number(text, i)
            tokens.append(Token(TokenType.NUMBER, text[i:j], i))
            i = j
        else:
            j = i
            while j < size and text[j].isalnum():
                j += 1
            if i == j:
                raise ValueError(f"Invalid character `{char}` at {i}")
            tokens.append(Token(TokenType.FUNC, text[i:j], i))
            i = j
    return tokens
def type_check(tokens: List[Token]) -> List[Token]: | |
l_values = 0 | |
open_parents = 0 | |
last_open_paren_pos = -1 | |
for i, token in enumerate(tokens): | |
match token.type: | |
case TokenType.LPAREN: | |
open_parents += 1 | |
last_open_paren_pos = i | |
case TokenType.RPAREN: | |
if open_parents == 0: | |
raise ValueError(f"Unexpected closing parenthesis at {token.pos}") | |
open_parents -= 1 | |
case TokenType.NUMBER: | |
l_values += 1 | |
case TokenType.PLUS | TokenType.MINUS: | |
if l_values == 0: | |
if i == len(tokens) - 1: | |
raise ValueError(f"Expected an expression after `{token.text}` at {token.pos}") | |
tokens[i].type = TokenType.UNARY_PLUS if token.type == TokenType.PLUS else TokenType.UNARY_MINUS | |
else: | |
l_values -= 1 | |
case TokenType.MULTIPLY | TokenType.DIVIDE | TokenType.DIVIDE_INT | TokenType.MODULO | TokenType.POWER: | |
if l_values == 0: | |
raise ValueError(f"Expected an expression after `{token.text}` at {token.pos}") | |
l_values -= 1 | |
case TokenType.FUNC: | |
pass | |
case _: | |
raise ValueError(f"Unexpected token `{token.text}` at {token.pos}") | |
if open_parents != 0: | |
raise ValueError(f"Unexpected opening parenthesis at {last_open_paren_pos}") | |
if l_values != 1: | |
raise ValueError(f"Expected an expression after `{tokens[-1].text}` at {tokens[-1].pos}") | |
return tokens | |
def parse(tokens: List[Token]) -> List[Token]:
    """Reorder infix tokens into postfix order using an operator stack.

    Operands go straight to the output.  Unary operators and functions are
    prefix operators: they are pushed without popping anything, because no
    operator already on the stack can be completed by them.  A binary
    operator first pops every stacked operator that binds tighter, or equally
    tight when the new operator is left-associative; `**` is treated as
    right-associative, as in Python.

    Args:
        tokens: infix tokens in which unary signs are already typed as
            UNARY_PLUS / UNARY_MINUS.

    Returns:
        A new list with the same tokens in postfix order, without parentheses.

    Raises:
        ValueError: on unbalanced parentheses or a token that is neither an
            operand, an operator nor a parenthesis.
    """
    stack: List[Token] = []
    output: List[Token] = []
    for token in tokens:
        if token.is_operand():
            output.append(token)
        elif token.type == TokenType.LPAREN:
            stack.append(token)
        elif token.type == TokenType.RPAREN:
            while stack and stack[-1].type != TokenType.LPAREN:
                output.append(stack.pop())
            if not stack:
                raise ValueError(f"Unexpected closing parenthesis at {token.pos}")
            stack.pop()  # discard the matching "("
        elif token.is_unary_operator():
            stack.append(token)
        elif token.is_binary_operator():
            current = token.precedence()
            right_associative = token.type == TokenType.POWER
            while stack and stack[-1].type != TokenType.LPAREN:
                stacked = stack[-1].precedence()
                if stacked > current or (stacked == current and not right_associative):
                    output.append(stack.pop())
                else:
                    break
            stack.append(token)
        else:
            raise ValueError(f"Unexpected token `{token.text}` at {token.pos}")
    while stack:
        pending = stack.pop()
        if pending.type == TokenType.LPAREN:
            raise ValueError(f"Unexpected opening parenthesis at {pending.pos}")
        output.append(pending)
    return output
def evaluate(tokens: List[Token]) -> float:
    """Evaluate a token sequence with a value stack and return the result.

    Operand tokens are pushed onto the stack; every other token is asked to
    apply itself to the stack.  Exactly one value must remain at the end.

    Raises:
        ValueError: if no value remains ("Empty expression") or more than
            one value remains.
    """
    operands: List[Token] = []
    for current in tokens:
        if not current.is_operand():
            current.apply(operands)
            continue
        operands.append(current)

    remaining = len(operands)
    if remaining == 0:
        raise ValueError("Empty expression")
    if remaining == 1:
        return operands[0].number()
    leftover = operands[-1]
    raise ValueError(f"Unexpected token `{leftover.text}` at {leftover.pos}")
def main():
    """Run a sample expression through every stage, printing each intermediate result."""
    text = "-(round(1.4e-5 + 2 ** 3 * (4 - sin(-2.45) * cos(4)) - 2.3 // 2) % 16)"
    # A simpler input to try instead: "1 + 2 * 3 ** 2 + (2 + 3) * 4"

    # Each stage consumes the previous stage's output and yields a token list.
    stages = (
        ("Tokenize...", tokenize),
        ("Type check...", type_check),
        ("Parse...", parse),
    )
    current = text
    for banner, stage in stages:
        print(banner)
        current = stage(current)
        print_tokens(current)

    print("Evaluate...")
    result = evaluate(current)
    print(f"{text} = {result}")
# Run the demo only when the file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment