Skip to content

Commit

Permalink
Restores numeric values (OpenBB-finance#4183)
Browse files Browse the repository at this point in the history
* fix tests

* Revert "fix tests"

This reverts commit f013668.

* Revert "Revert "fix tests""

This reverts commit b49e633.

* Revert "fix tests"

This reverts commit f013668.

* revert_lambda_long_number_format

* enhanced the function to remove the formatting plus some cleaning

* tests

* minor adjustment in commnet

* fix introduced bug by parsing dates

* resolve conflicting tests

---------

Co-authored-by: james <[email protected]>
  • Loading branch information
hjoaquim and jmaslek authored Feb 9, 2023
1 parent cf84504 commit f143f83
Show file tree
Hide file tree
Showing 3 changed files with 101 additions and 146 deletions.
207 changes: 63 additions & 144 deletions openbb_terminal/helper_funcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,8 @@
import random
import re
import sys
import types
import urllib.parse
import webbrowser
from collections.abc import Iterable
from datetime import (
date as d,
datetime,
Expand Down Expand Up @@ -577,6 +575,15 @@ def valid_date(s: str) -> datetime:
raise argparse.ArgumentTypeError(f"Not a valid date: {s}") from value_error


def is_valid_date(s: str) -> bool:
"""Check if date is in valid format."""
try:
datetime.strptime(s, "%Y-%m-%d")
return True
except ValueError:
return False


def valid_repo(repo: str) -> str:
"""Argparse type to check github repo is in valid format."""
result = re.search(r"^[a-zA-Z0-9-_.]+\/[a-zA-Z0-9-_.]+$", repo) # noqa: W605
Expand All @@ -600,85 +607,6 @@ def valid_hour(hr: str) -> int:
return new_hr


def plot_view_stock(df: pd.DataFrame, symbol: str, interval: str):
"""Plot the loaded stock dataframe.
Parameters
----------
df: Dataframe
Dataframe of prices and volumes
symbol: str
Symbol of ticker
interval: str
Stock data resolution for plotting purposes
"""
df.sort_index(ascending=True, inplace=True)
bar_colors = ["r" if x[1].Open < x[1].Close else "g" for x in df.iterrows()]

try:
fig, ax = plt.subplots(
2,
1,
gridspec_kw={"height_ratios": [3, 1]},
figsize=plot_autoscale(),
dpi=cfgPlot.PLOT_DPI,
)
except Exception as e:
console.print(e)
console.print(
"Encountered an error trying to open a chart window. Check your X server configuration."
)
logging.exception("%s", type(e).__name__)
return

# In order to make nice Volume plot, make the bar width = interval
if interval == "1440min":
bar_width = timedelta(days=1)
title_string = "Daily"
else:
bar_width = timedelta(minutes=int(interval.split("m")[0]))
title_string = f"{int(interval.split('m')[0])} min"

ax[0].yaxis.tick_right()
if "Adj Close" in df.columns:
ax[0].plot(df.index, df["Adj Close"], c=cfgPlot.VIEW_COLOR)
else:
ax[0].plot(df.index, df["Close"], c=cfgPlot.VIEW_COLOR)
ax[0].set_xlim(df.index[0], df.index[-1])
ax[0].set_xticks([])
ax[0].yaxis.set_label_position("right")
ax[0].set_ylabel("Share Price ($)")
ax[0].grid(axis="y", color="gainsboro", linestyle="-", linewidth=0.5)

ax[0].spines["top"].set_visible(False)
ax[0].spines["left"].set_visible(False)
ax[1].bar(
df.index, df.Volume / 1_000_000, color=bar_colors, alpha=0.8, width=bar_width
)
ax[1].set_xlim(df.index[0], df.index[-1])
ax[1].yaxis.tick_right()
ax[1].yaxis.set_label_position("right")
ax[1].set_ylabel("Volume [1M]")
ax[1].grid(axis="y", color="gainsboro", linestyle="-", linewidth=0.5)
ax[1].spines["top"].set_visible(False)
ax[1].spines["left"].set_visible(False)
ax[1].set_xlabel("Time")
fig.suptitle(
symbol + " " + title_string,
size=20,
x=0.15,
y=0.95,
fontfamily="serif",
fontstyle="italic",
)
if obbff.USE_ION:
plt.ion()
fig.tight_layout(pad=2)
plt.setp(ax[1].get_xticklabels(), rotation=20, horizontalalignment="right")

plt.show()


def us_market_holidays(years) -> list:
"""Get US market holidays."""
if isinstance(years, int):
Expand Down Expand Up @@ -770,6 +698,58 @@ def lambda_long_number_format(num, round_decimal=3) -> str:
return num


def revert_lambda_long_number_format(num_str: str) -> Union[float, str]:
"""
Revert the formatting of a long number if the input is a formatted number, otherwise return the input as is.
Parameters
----------
num_str : str
The number to remove the formatting.
Returns
-------
Union[float, str]
The number as float (with no formatting) or the input as is.
"""
magnitude_dict = {
"K": 1000,
"M": 1000000,
"B": 1000000000,
"T": 1000000000000,
"P": 1000000000000000,
}

# Ensure the input is a string and not empty
if not num_str or not isinstance(num_str, str):
return num_str

num_as_list = num_str.strip().split()

# If the input string is a number parse it as float
if (
len(num_as_list) == 1
and num_as_list[0].replace(".", "").replace("-", "").isdigit()
and not is_valid_date(num_str)
):
return float(num_str)

# If the input string is a formatted number with magnitude
if (
len(num_as_list) == 2
and num_as_list[1] in magnitude_dict
and num_as_list[0].replace(".", "").replace("-", "").isdigit()
):
num, unit = num_as_list
magnitude = magnitude_dict.get(unit)
if magnitude:
return float(num) * magnitude

# Return the input string as is if it's not a formatted number
return num_str


def lambda_long_number_format_y_axis(df, y_column, ax):
"""Format long number that goes onto Y axis."""
max_values = df[y_column].values.max()
Expand Down Expand Up @@ -1358,6 +1338,8 @@ def export_data(
regex=True,
)

df = df.applymap(revert_lambda_long_number_format)

if exp_type.endswith("csv"):
exists, overwrite = ask_file_overwrite(saved_path)
if exists and not overwrite:
Expand Down Expand Up @@ -1551,69 +1533,6 @@ def camel_case_split(string: str) -> str:
return " ".join(results).title()


def choice_check_after_action(action=None, choices=None):
"""Return an action class that checks choice after action call.
Does that for argument of argparse.ArgumentParser.add_argument function.
Parameters
----------
action : Union[class, function]
Action for set args before check choices.
If action is class, it must implement argparse.Action methods
If action is function, it takes 4 args(parser, namespace, values, option_string)
and needs to return value to set dest
choices : Union[Iterable, function]
A container of values that should be allowed.
If choices is function, it takes 1 args(value) to check and
return bool that value is allowed or not
Returns
-------
Class
Class extended argparse.Action
"""
if isinstance(choices, Iterable):

def choice_checker(value):
return value in choices

elif isinstance(choices, types.FunctionType):
choice_checker = choices
else:
raise NotImplementedError("choices argument must be iterable or function")

if isinstance(action, type):

class ActionClass(action):
def __call__(self, parser, namespace, values, option_string=None):
super().__call__(parser, namespace, values, option_string)
if not choice_checker(getattr(namespace, self.dest)):
raise ValueError(
f"{getattr(namespace, self.dest)} is not in {choices}"
)

elif isinstance(action, types.FunctionType):

class ActionClass(argparse.Action):
def __call__(self, parser, namespace, values, option_string=None):
setattr(
namespace,
self.dest,
action(parser, namespace, values, option_string),
)
if not choice_checker(getattr(namespace, self.dest)):
raise ValueError(
f"{getattr(namespace, self.dest)} is not in {choices}"
)

else:
raise NotImplementedError("action argument must be class or function")

return ActionClass


def is_valid_axes_count(
axes: List[plt.Axes],
n: int,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -252,9 +252,9 @@ interactions:
vary:
- Origin,Accept-Encoding
x-envoy-decorator-operation:
- finance-chart-api--mtls-production-ir2.finance-k8s.svc.yahoo.local:4080/*
- finance-chart-api--mtls-canary-production-ir2.finance-k8s.svc.yahoo.local:4080/*
x-envoy-upstream-service-time:
- '7'
- '5'
x-request-id:
- 78ec0f7f-22e6-4cb9-9a7f-90991e11590e
x-yahoo-request-id:
Expand Down
36 changes: 36 additions & 0 deletions tests/openbb_terminal/test_helper_funcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
check_start_less_than_end,
export_data,
remove_timezone_from_dataframe,
revert_lambda_long_number_format,
)

# pylint: disable=E1101
Expand Down Expand Up @@ -156,3 +157,38 @@ def test_remove_timezone_from_dataframe(df, df_expected):

df_result = remove_timezone_from_dataframe(df)
assert df_result.equals(df_expected)


@pytest.mark.parametrize(
"value, expected",
[
(123, 123),
("xpto", "xpto"),
(
"this isssssssssss a veryyyyy long stringgg",
"this isssssssssss a veryyyyy long stringgg",
),
(None, None),
(True, True),
(0, 0),
("2022-01-01", "2022-01-01"),
("3/9/2022", "3/9/2022"),
("2022-03-09 10:30:00", "2022-03-09 10:30:00"),
("a 123", "a 123"),
([1, 2, 3], [1, 2, 3]),
("", ""),
("-3 K", -3000.0),
("-99 M", -99000000.0),
("-125 B", -125000000000.0),
("-15 T", -15000000000000.0),
("-15 P", -15000000000000000.0),
("-15 P xpto", "-15 P xpto"),
("-15 P 3 K", "-15 P 3 K"),
("15 P -3 K", "15 P -3 K"),
("2.130", 2.130),
("2,130.000", "2,130.000"), # this is not a valid number
("674,234.99", "674,234.99"), # this is not a valid number
],
)
def test_revert_lambda_long_number_format(value, expected):
assert revert_lambda_long_number_format(value) == expected

0 comments on commit f143f83

Please sign in to comment.