Skip to content

Instantly share code, notes, and snippets.

@aaronj1335
Created May 25, 2024 13:07
Show Gist options
  • Save aaronj1335/8b33d43b266f898c7f4334b55403732e to your computer and use it in GitHub Desktop.
Save aaronj1335/8b33d43b266f898c7f4334b55403732e to your computer and use it in GitHub Desktop.
import datetime
import sys
from collections import defaultdict
from datetime import date
from decimal import Decimal
from functools import cache, cached_property
from pathlib import Path
from typing import (Dict, Generator, Iterable, NamedTuple, Optional, Sequence,
                    Set, Tuple, Union)

from beancount.core.data import (Account, Commodity, Currency, Entries, Event,
                                 Open, Options, Price, Transaction)
from beancount.core.inventory import Inventory
from beancount.core.position import Position
from beancount.loader import load_file
from beancount.parser.printer import print_errors

from .data import amount
# Commodities that cost_basis() skips entirely when they are held without a
# cost, instead of reporting them as positions lacking a cost basis.
IGNORE_COST_BASIS: Set[str] = {'USDINVESTIOU'}

# Shared row type for the reports that walk the ledger's transactions in
# chronological order. Element by element:
TransactionAndPrices = Tuple[
    Transaction,                # the transaction itself
    Dict[Account, Inventory],   # Account -> Inventory mapping
    Dict[Account, float],       # Account -> Value mapping
    # NOTE(review): the producer fills this from Price directives'
    # `amount.number`, which is presumably a Decimal rather than a float in
    # beancount -- confirm before relying on float semantics.
    Dict[Currency, float],      # Currency -> Price mapping
]
class BalanceForAccounts(NamedTuple):
    """One row produced by ``Ledger.balances`` for a group of accounts."""

    # The transaction this row describes.
    transaction: Transaction
    # Net units the transaction posted to the selected accounts in `currency`.
    units: Decimal
    # Currency in which `units` is expressed.
    currency: Currency
    # Change in cost basis caused by this transaction alone.
    transaction_cost: float
    # Cost basis relative to the start of the reporting window.
    total_cost: float
    # Market value relative to the start of the reporting window.
    total_price: float
class Ledger:
    """
    Lazily loaded view over a beancount ledger file.

    The file is parsed on first use. Parsed entries and the collections
    derived from them are cached on the instance itself, so they are released
    together with the Ledger object.
    """

    def __init__(
        self,
        source: Union[str, Path],
        date: Optional[datetime.date] = None,
        strict=True,
    ):
        """
        Arguments:
            source: path of the beancount file to load.
            date: stored as ``self.date``; defaults to today's date.
            strict: when true, any loader error raises an Exception; when
                false, errors are printed to stderr and exposed via ``errors``.
        """
        self.source = Path(source)
        if date is None:
            date = datetime.date.today()
        self.date = date
        self._is_strict = strict

    # NOTE: caching uses functools.cached_property (stored in the instance
    # __dict__) rather than functools.cache on methods. functools.cache keys a
    # module-level cache on `self`, which would keep every Ledger -- and all
    # of its parsed entries -- alive for the lifetime of the process.

    @cached_property
    def _loaded(self) -> Tuple[Entries, Sequence, Options]:
        """Backing store for ``_load``; computed at most once per instance."""
        entries, errors, options = load_file(str(self.source))
        if errors and self._is_strict:
            raise Exception(errors)
        elif errors:
            self._errors = errors
            print_errors(self._errors, file=sys.stderr)
        else:
            self._errors = tuple()
        return (entries, errors, options)

    def _load(self) -> Tuple[Entries, Sequence, Options]:
        """
        Return ``(entries, errors, options)`` as produced by ``load_file``.

        The file is parsed on the first call only. In strict mode any loader
        error raises ``Exception(errors)``; nothing is cached in that case, so
        a later call parses again. In non-strict mode errors are printed to
        stderr and remembered for the ``errors`` property.
        """
        return self._loaded

    @cached_property
    def contents(self) -> str:
        """Return the contents of the ledger as a string."""
        # Decode explicitly instead of relying on the locale's default
        # encoding; beancount expects UTF-8 input files.
        return self.source.read_text(encoding='utf-8')

    @property
    def errors(self) -> Sequence:
        """Loader errors; empty when the ledger loaded cleanly."""
        self._load()
        return self._errors[:]

    @property
    def entries(self) -> Entries:
        """All directives returned by the loader."""
        return self._load()[0]

    @cached_property
    def accounts(self) -> list[Account]:
        """Names of all accounts that have an Open directive."""
        return [o.account for o in self.opens]

    @cached_property
    def commodities(self) -> list[Commodity]:
        """All Commodity directives."""
        return [d for d in self.entries if isinstance(d, Commodity)]

    @cached_property
    def prices(self) -> list[Price]:
        """All Price directives, stably sorted by date."""
        return sorted((d for d in self.entries if isinstance(d, Price)),
                      key=lambda d: d.date)

    @cached_property
    def opens(self) -> list[Open]:
        """All Open directives."""
        return [d for d in self.entries if isinstance(d, Open)]

    @cached_property
    def events(self) -> list[Event]:
        """All Event directives."""
        return [e for e in self.entries if isinstance(e, Event)]

    @cached_property
    def transactions(self) -> list[Transaction]:
        """All Transaction directives, stably sorted by date."""
        return sorted((d for d in self.entries if isinstance(d, Transaction)),
                      key=lambda d: d.date)

    def transactions_with_balances(
        self
    ) -> list[Tuple[Transaction, Dict[Account, Inventory]]]:
        """Cached list form of ``_transactions_with_balances``."""
        return self._transactions_with_balances_list

    @cached_property
    def _transactions_with_balances_list(
        self,
    ) -> list[Tuple[Transaction, Dict[Account, Inventory]]]:
        return list(self._transactions_with_balances())

    def _transactions_with_balances(
        self,
    ) -> Generator[Tuple[Transaction, Dict[Account, Inventory]], None, None]:
        """
        Returns transactions in chronological order with an inventory for each
        account.

        Each yielded mapping is an independent snapshot of every account's
        inventory after the transaction; looking up an account that has no
        postings yet yields an empty Inventory.
        """
        inventories: Dict[Account, Inventory] = defaultdict(Inventory)
        for transaction in self.transactions:
            for posting in transaction.postings:
                # Copy before mutating: snapshots already yielded share the
                # previous Inventory objects and must not see this posting.
                updated = inventories[posting.account].__copy__()
                updated.add_position(posting)
                inventories[posting.account] = updated
            yield transaction, defaultdict(Inventory, inventories)

    def transactions_with_prices(self) -> Sequence[TransactionAndPrices]:
        """
        Returns transactions in chronological order with an inventory and
        market value in USD of each account.

        The market value is determined from price directives in the ledger.
        This determines value by multiplying the quantity in the inventory by
        the price directive closest to the date of the transaction (ties go to
        the later directive). Commodities with no price directive at all are
        left out of the value. The result is computed once per Ledger.
        """
        return self._transactions_with_prices_list

    @cached_property
    def _transactions_with_prices_list(self) -> list[TransactionAndPrices]:
        return list(self._transactions_with_prices())

    def _transactions_with_prices(
        self,
    ) -> Generator[TransactionAndPrices, None, None]:
        """Generator behind ``transactions_with_prices``."""
        # Price history per commodity, oldest first (self.prices is sorted).
        # NOTE(review): prices are grouped by base commodity only; the quote
        # currency is never checked, so this presumes every Price directive
        # is quoted in USD -- confirm.
        prices_by_currency: Dict[Currency, list] = defaultdict(list)
        for price in self.prices:
            prices_by_currency[price.currency].append(price)
        # Cursor into each commodity's history. Transactions are visited in
        # date order, so the closest price can only move forward and one
        # cursor per commodity is enough for the whole pass.
        price_indices: Dict[Currency, int] = defaultdict(int)

        def closest_price(currency, as_of):
            """Return the number of the price directive closest to as_of."""
            history = prices_by_currency[currency]
            index = price_indices[currency]
            # Keep advancing while the next directive is at least as close.
            # Distances along a date-sorted history fall and then rise, so
            # this lands on the closest price however many are skipped
            # (the old single step lagged behind dense price data).
            while (index + 1 < len(history) and
                   abs(as_of - history[index + 1].date) <=
                   abs(as_of - history[index].date)):
                index += 1
            price_indices[currency] = index
            return history[index].amount.number

        def market_value_and_prices(inventory, as_of):
            """Value `inventory` at the prices closest to `as_of`."""
            prices = {}
            value = 0.0
            for currency in inventory.currencies():
                number = 1
                if currency != 'USD':
                    if not prices_by_currency.get(currency):
                        # No price directive for it: leave it out entirely.
                        continue
                    number = closest_price(currency, as_of)
                    prices[currency] = number
                value += float(
                    inventory.get_currency_units(currency).number * number)
            return value, prices

        for transaction, balances in self.transactions_with_balances():
            values = {}
            prices = {}
            for account in self.accounts:
                value, account_prices = market_value_and_prices(
                    balances[account], transaction.date)
                values[account] = value
                prices.update(account_prices)
            # Both dicts are rebuilt on every iteration, so they can be
            # yielded directly without defensive copies.
            yield transaction, balances, values, prices

    def transactions_for_accounts(
        self, accounts: Iterable[Account]
    ) -> Generator[TransactionAndPrices, None, None]:
        """
        Returns the transactions that post to any of the given accounts.

        This filters the output of transactions_with_prices to those which
        involve the given accounts, preserving chronological order.

        Arguments:
            accounts: Iterable of either a str or Account (it's the same).
                It is read exactly once, so one-shot iterables such as
                generators work. A single account name passed as a bare
                string is treated as that one account rather than as an
                iterable of characters.
        """
        if isinstance(accounts, str):
            accounts = (accounts,)
        # Materialize once: O(1) membership tests, and a generator argument
        # is not exhausted by the first lookup.
        wanted = frozenset(accounts)
        for item in self.transactions_with_prices():
            transaction = item[0]
            if any(p.account in wanted for p in transaction.postings):
                yield item

    def balances(
        self,
        accounts: Sequence[Account],
        start: Optional[date] = None,
        end: Optional[date] = None,
    ) -> Generator[BalanceForAccounts, None, None]:
        """
        Returns the transaction amount, total cost, and total price of the given
        accounts.

        If only one account is provided that just has USD, this would be similar to
        a bank statement with running account balance.

        With an investment account that has shares of some asset, it's a useful way
        to see returns over time.

        This sums the values for all accounts given.

        The date range is "inclusive", i.e. if the transaction date is == start or
        the date == end, the transaction is included. ``total_cost`` and
        ``total_price`` are measured relative to the balances just before the
        first transaction inside the window.

        A single account name passed as a bare string is treated as a
        one-element sequence. Positions without a cost basis are reported on
        stderr.
        """
        if isinstance(accounts, str):
            accounts = [accounts]
        account_set = frozenset(accounts)
        previous_cost = 0.0
        previous_price = 0.0
        starting_cost_and_price: Optional[Tuple[float, float]] = None
        for transaction, inventories, values, _ in \
                self.transactions_for_accounts(account_set):
            # Rows arrive in chronological order, so once past `end` nothing
            # more can be yielded; stop instead of computing discarded rows.
            if end is not None and transaction.date > end:
                break
            cost = 0.0
            for account in accounts:
                account_cost, uncosted = cost_basis(inventories[account])
                if uncosted:
                    sys.stderr.write(
                        'Position without cost for {}: {}\n'.format(
                            account, uncosted))
                cost += account_cost
            postings = [p for p in transaction.postings
                        if p.account in account_set]
            # Transactions with multiple currencies in one account may occur
            # with in-kind transfers.
            #
            # In this case we'll just show the first currency. This isn't all
            # of the information, but we can't have a variable number of
            # columns for units, and this is probably sufficient.
            # (`postings` is non-empty: transactions_for_accounts only yields
            # transactions touching these accounts.)
            currency = (postings[0].units or amount(0)).currency
            units = Decimal(0)
            for posting in postings:
                posting_units = posting.units
                if (posting_units is not None and
                        posting_units.currency == currency):
                    number = posting_units.number
                    if number is not None:
                        units += number
            price = sum(values[a] for a in accounts)
            if transaction_is_in_window(transaction, start, end):
                if starting_cost_and_price is None:
                    starting_cost_and_price = (previous_cost, previous_price)
                yield BalanceForAccounts(transaction,
                                         units,
                                         currency,
                                         cost - previous_cost,
                                         cost - starting_cost_and_price[0],
                                         price - starting_cost_and_price[1])
            previous_cost = cost
            previous_price = price
def transaction_is_in_window(
    transaction: Transaction, start: Optional[date], end: Optional[date]
) -> bool:
    """
    Tell whether ``transaction`` falls inside the closed interval [start, end].

    Either bound may be ``None``, meaning that side is unbounded. Both bounds
    are inclusive, so a transaction dated exactly ``start`` or exactly ``end``
    is inside the window.
    """
    if start is not None and transaction.date < start:
        return False
    if end is not None and transaction.date > end:
        return False
    return True
def cost_basis(
    inventory: Inventory,
) -> Tuple[
    float,  # Cost
    Optional[Sequence[Position]],  # List of positions without cost, if any
]:
    """
    Sum the cost basis of ``inventory``.

    Returns ``(cost, uncosted)`` where, position by position:
      * a position carrying a cost adds ``cost.number * units.number``;
      * a 'USD' position without a cost adds its unit count;
      * a currency listed in IGNORE_COST_BASIS is skipped;
      * any other non-zero position is collected in ``uncosted``, which stays
        ``None`` when no such position exists.

    NOTE(review): cost numbers are summed as-is without checking the cost
    currency, which presumably is always USD -- confirm.
    """
    total = 0.0
    uncosted: Optional[list] = None
    for position in inventory:
        units = position.units
        if position.cost is not None:
            total += float(position.cost.number * units.number)
            continue
        if units.currency == 'USD':
            total += float(units.number)
            continue
        if units.currency in IGNORE_COST_BASIS or units.number == 0.:
            continue
        if uncosted is None:
            uncosted = []
        uncosted.append(position)
    return (total, uncosted)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment