Skip to content

Commit

Permalink
fix: tests and linter
Browse files Browse the repository at this point in the history
  • Loading branch information
fengkx committed Oct 14, 2023
1 parent 14513ff commit 17a97df
Show file tree
Hide file tree
Showing 5 changed files with 63 additions and 63 deletions.
57 changes: 30 additions & 27 deletions src/fava/beans/prices.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from collections import defaultdict
from decimal import Decimal
from itertools import groupby
from typing import Callable
from typing import Iterable
from typing import Sequence
from typing import TYPE_CHECKING
Expand Down Expand Up @@ -151,15 +152,18 @@ def get_nested_price(
base_quote: BaseQuote,
date: datetime.date | None = None,
) -> Decimal | None:
all_commodity = [tuple[0] for tuple in list(self._map.keys())]
all_commodity = [
commodity_pairs[0] for commodity_pairs in list(self._map.keys())
]
groups = groupby(self._map.keys(), key=lambda x: x[0])
conversions: dict[str, set[str]] = {}
# Iterate over the groups

for key, group in groups:
# Extract the second element of each tuple in the group and add it to the dictionary as the value for the key
if key not in conversions:
conversions[key] = set()
conversions[key].update((tuple[1] for tuple in group))
conversions[key].update(
commodity_pair[1] for commodity_pair in group
)

bellman_ford = BellmanFord(
all_commodity,
Expand All @@ -179,40 +183,49 @@ def get_nested_price(


class BellmanFord:
"""An Bellman-Ford algorithm implementation.
optimized and modified for currency conversion scenarios.
"""

def __init__(
self,
all_nodes: list[str],
edges: dict[str, set[str]],
get_widget,
start,
):
get_widget: Callable[[BaseQuote], Decimal | None],
start: str,
) -> None:
self.searched = False
self.get_widget = get_widget
self.start = start
self.all_nodes = all_nodes
self.edges = edges

table: dict[str, (Decimal, str | None)] = {}
table: dict[str, tuple[Decimal, str | None]] = {}
# init table
for n in all_nodes:
table[n] = (Decimal("Infinity"), None)
self.table = table

def _get_path(self, end_node):
def _get_path(self, end_node: str) -> Iterable[str]:
if self.searched is False:
self.search()
r = [end_node]
from_record = self.table.get(end_node)
if from_record is None:
return r
while from_record[1] is not None and from_record[1] is not self.start:
while (
from_record is not None
and from_record[1] is not None
and from_record[1] is not self.start
):
if from_record[1] in r:
return reversed(r)
r.append(from_record[1])
from_record = self.table.get(from_record[1])
return reversed(r)

def update_table(self):
def update_table(self) -> bool:
updated = False
for from_node in self.all_nodes:
if self.table[from_node][0] == Decimal("Infinity"):
Expand All @@ -223,14 +236,6 @@ def update_table(self):
if widget is None:
continue
target_value = self.table[from_node][0] * widget
# print(
# (
# # target_value,
# from_node,
# to,
# self.get_widget((from_node, to)),
# )
# )
if to == self.start:
continue
if target_value.compare(self.table[to][0]) < 0 and (
Expand All @@ -241,20 +246,18 @@ def update_table(self):
continue
self.table[to] = (target_value, from_node)

# print(
# f"from {from_node} to {to} m"
# f" {target_value} {self.table[to][0]} {(target_value - self.table[to][0]).compare(Decimal(0.001)) == 1}"
# )
updated = True

return updated

def print_table(self):
def print_table(self) -> None:
heads = sorted(self.table.keys())
print("\n".join([str([t, self.table[t]]) for t in heads]))
print("+++++++++++++++++++++++++")
print( # noqa: T201
"\n".join([str([t, self.table[t]]) for t in heads]),
)
print("+++++++++++++++++++++++++") # noqa: T201

def search(self):
def search(self) -> dict[str, tuple[Decimal, str | None]]:
if self.searched is True:
return self.table
self.searched = True
Expand Down
2 changes: 0 additions & 2 deletions src/fava/core/conversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,6 @@ def convert_position(
# try the direct conversion
base_quote = (units_.currency, target_currency)
price_number = prices.get_nested_price(base_quote, date)
print("=====PRICE====", base_quote, price_number)
if price_number is not None:
return create.amount((units_.number * price_number, target_currency))

Expand Down Expand Up @@ -142,5 +141,4 @@ def cost_or_value(
return inventory.reduce(get_market_value, prices, date)
if conversion == "units":
return inventory.reduce(get_units)
print("========conversion", conversion)
return inventory.reduce(convert_position, conversion, prices, date)
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
[
{
"balance": {
"USD": 100
"USD": 100.00
},
"date": "2000-01-01"
},
{
"balance": {
"USD": 50,
"USD": 50.00,
"XYZ": 1
},
"date": "2000-01-02"
Expand Down
Loading

0 comments on commit 17a97df

Please sign in to comment.