diff --git a/openbb_terminal/helper_funcs.py b/openbb_terminal/helper_funcs.py index d1dc8360c4ef..f09cb1f14074 100644 --- a/openbb_terminal/helper_funcs.py +++ b/openbb_terminal/helper_funcs.py @@ -9,10 +9,8 @@ import random import re import sys -import types import urllib.parse import webbrowser -from collections.abc import Iterable from datetime import ( date as d, datetime, @@ -577,6 +575,15 @@ def valid_date(s: str) -> datetime: raise argparse.ArgumentTypeError(f"Not a valid date: {s}") from value_error +def is_valid_date(s: str) -> bool: + """Check if date is in valid format.""" + try: + datetime.strptime(s, "%Y-%m-%d") + return True + except ValueError: + return False + + def valid_repo(repo: str) -> str: """Argparse type to check github repo is in valid format.""" result = re.search(r"^[a-zA-Z0-9-_.]+\/[a-zA-Z0-9-_.]+$", repo) # noqa: W605 @@ -600,85 +607,6 @@ def valid_hour(hr: str) -> int: return new_hr -def plot_view_stock(df: pd.DataFrame, symbol: str, interval: str): - """Plot the loaded stock dataframe. - - Parameters - ---------- - df: Dataframe - Dataframe of prices and volumes - symbol: str - Symbol of ticker - interval: str - Stock data resolution for plotting purposes - """ - df.sort_index(ascending=True, inplace=True) - bar_colors = ["r" if x[1].Open < x[1].Close else "g" for x in df.iterrows()] - - try: - fig, ax = plt.subplots( - 2, - 1, - gridspec_kw={"height_ratios": [3, 1]}, - figsize=plot_autoscale(), - dpi=cfgPlot.PLOT_DPI, - ) - except Exception as e: - console.print(e) - console.print( - "Encountered an error trying to open a chart window. Check your X server configuration." - ) - logging.exception("%s", type(e).__name__) - return - - # In order to make nice Volume plot, make the bar width = interval - if interval == "1440min": - bar_width = timedelta(days=1) - title_string = "Daily" - else: - bar_width = timedelta(minutes=int(interval.split("m")[0])) - title_string = f"{int(interval.split('m')[0])} min" - - ax[0].yaxis.tick_right() - if "Adj Close" in df.columns: - ax[0].plot(df.index, df["Adj Close"], c=cfgPlot.VIEW_COLOR) - else: - ax[0].plot(df.index, df["Close"], c=cfgPlot.VIEW_COLOR) - ax[0].set_xlim(df.index[0], df.index[-1]) - ax[0].set_xticks([]) - ax[0].yaxis.set_label_position("right") - ax[0].set_ylabel("Share Price ($)") - ax[0].grid(axis="y", color="gainsboro", linestyle="-", linewidth=0.5) - - ax[0].spines["top"].set_visible(False) - ax[0].spines["left"].set_visible(False) - ax[1].bar( - df.index, df.Volume / 1_000_000, color=bar_colors, alpha=0.8, width=bar_width - ) - ax[1].set_xlim(df.index[0], df.index[-1]) - ax[1].yaxis.tick_right() - ax[1].yaxis.set_label_position("right") - ax[1].set_ylabel("Volume [1M]") - ax[1].grid(axis="y", color="gainsboro", linestyle="-", linewidth=0.5) - ax[1].spines["top"].set_visible(False) - ax[1].spines["left"].set_visible(False) - ax[1].set_xlabel("Time") - fig.suptitle( - symbol + " " + title_string, - size=20, - x=0.15, - y=0.95, - fontfamily="serif", - fontstyle="italic", - ) - if obbff.USE_ION: - plt.ion() - fig.tight_layout(pad=2) - plt.setp(ax[1].get_xticklabels(), rotation=20, horizontalalignment="right") - - plt.show() - - def us_market_holidays(years) -> list: """Get US market holidays.""" if isinstance(years, int): @@ -770,6 +698,58 @@ def lambda_long_number_format(num, round_decimal=3) -> str: return num +def revert_lambda_long_number_format(num_str: str) -> Union[float, str]: + """ + Revert the formatting of a long number if the input is a formatted number, otherwise return the input as is. + + Parameters + ---------- + num_str : str + The number to remove the formatting. + + Returns + ------- + Union[float, str] + The number as float (with no formatting) or the input as is. + + """ + magnitude_dict = { + "K": 1000, + "M": 1000000, + "B": 1000000000, + "T": 1000000000000, + "P": 1000000000000000, + } + + # Ensure the input is a string and not empty + if not num_str or not isinstance(num_str, str): + return num_str + + num_as_list = num_str.strip().split() + + # If the input string is a number parse it as float + if ( + len(num_as_list) == 1 + and num_as_list[0].replace(".", "").replace("-", "").isdigit() + and not is_valid_date(num_str) + ): + return float(num_str) + + # If the input string is a formatted number with magnitude + if ( + len(num_as_list) == 2 + and num_as_list[1] in magnitude_dict + and num_as_list[0].replace(".", "").replace("-", "").isdigit() + ): + num, unit = num_as_list + magnitude = magnitude_dict.get(unit) + if magnitude: + return float(num) * magnitude + + # Return the input string as is if it's not a formatted number + return num_str + + def lambda_long_number_format_y_axis(df, y_column, ax): """Format long number that goes onto Y axis.""" max_values = df[y_column].values.max() @@ -1358,6 +1338,8 @@ def export_data( regex=True, ) + df = df.applymap(revert_lambda_long_number_format) + if exp_type.endswith("csv"): exists, overwrite = ask_file_overwrite(saved_path) if exists and not overwrite: @@ -1551,69 +1533,6 @@ def camel_case_split(string: str) -> str: return " ".join(results).title() -def choice_check_after_action(action=None, choices=None): - """Return an action class that checks choice after action call. - - Does that for argument of argparse.ArgumentParser.add_argument function. - - Parameters - ---------- - action : Union[class, function] - Action for set args before check choices. - If action is class, it must implement argparse.Action methods - If action is function, it takes 4 args(parser, namespace, values, option_string) - and needs to return value to set dest - - choices : Union[Iterable, function] - A container of values that should be allowed. - If choices is function, it takes 1 args(value) to check and - return bool that value is allowed or not - - Returns - ------- - Class - Class extended argparse.Action - """ - if isinstance(choices, Iterable): - - def choice_checker(value): - return value in choices - - elif isinstance(choices, types.FunctionType): - choice_checker = choices - else: - raise NotImplementedError("choices argument must be iterable or function") - - if isinstance(action, type): - - class ActionClass(action): - def __call__(self, parser, namespace, values, option_string=None): - super().__call__(parser, namespace, values, option_string) - if not choice_checker(getattr(namespace, self.dest)): - raise ValueError( - f"{getattr(namespace, self.dest)} is not in {choices}" - ) - - elif isinstance(action, types.FunctionType): - - class ActionClass(argparse.Action): - def __call__(self, parser, namespace, values, option_string=None): - setattr( - namespace, - self.dest, - action(parser, namespace, values, option_string), - ) - if not choice_checker(getattr(namespace, self.dest)): - raise ValueError( - f"{getattr(namespace, self.dest)} is not in {choices}" - ) - - else: - raise NotImplementedError("action argument must be class or function") - - return ActionClass - - def is_valid_axes_count( axes: List[plt.Axes], n: int, diff --git a/tests/openbb_terminal/stocks/fundamental_analysis/cassettes/test_business_insider_view/test_price_target_from_analysts_TSLA.yaml b/tests/openbb_terminal/stocks/fundamental_analysis/cassettes/test_business_insider_view/test_price_target_from_analysts_TSLA.yaml index 68683663c52b..2315a131c123 100644 --- a/tests/openbb_terminal/stocks/fundamental_analysis/cassettes/test_business_insider_view/test_price_target_from_analysts_TSLA.yaml +++ b/tests/openbb_terminal/stocks/fundamental_analysis/cassettes/test_business_insider_view/test_price_target_from_analysts_TSLA.yaml @@ -252,9 +252,9 @@ interactions: vary: - Origin,Accept-Encoding x-envoy-decorator-operation: - - finance-chart-api--mtls-production-ir2.finance-k8s.svc.yahoo.local:4080/* + - finance-chart-api--mtls-canary-production-ir2.finance-k8s.svc.yahoo.local:4080/* x-envoy-upstream-service-time: - - '7' + - '5' x-request-id: - 78ec0f7f-22e6-4cb9-9a7f-90991e11590e x-yahoo-request-id: diff --git a/tests/openbb_terminal/test_helper_funcs.py b/tests/openbb_terminal/test_helper_funcs.py index 5910f3c27290..c08515b2fafb 100644 --- a/tests/openbb_terminal/test_helper_funcs.py +++ b/tests/openbb_terminal/test_helper_funcs.py @@ -11,6 +11,7 @@ check_start_less_than_end, export_data, remove_timezone_from_dataframe, + revert_lambda_long_number_format, ) # pylint: disable=E1101 @@ -156,3 +157,38 @@ def test_remove_timezone_from_dataframe(df, df_expected): df_result = remove_timezone_from_dataframe(df) assert df_result.equals(df_expected) + + +@pytest.mark.parametrize( + "value, expected", + [ + (123, 123), + ("xpto", "xpto"), + ( + "this isssssssssss a veryyyyy long stringgg", + "this isssssssssss a veryyyyy long stringgg", + ), + (None, None), + (True, True), + (0, 0), + ("2022-01-01", "2022-01-01"), + ("3/9/2022", "3/9/2022"), + ("2022-03-09 10:30:00", "2022-03-09 10:30:00"), + ("a 123", "a 123"), + ([1, 2, 3], [1, 2, 3]), + ("", ""), + ("-3 K", -3000.0), + ("-99 M", -99000000.0), + ("-125 B", -125000000000.0), + ("-15 T", -15000000000000.0), + ("-15 P", -15000000000000000.0), + ("-15 P xpto", "-15 P xpto"), + ("-15 P 3 K", "-15 P 3 K"), + ("15 P -3 K", "15 P -3 K"), + ("2.130", 2.130), + ("2,130.000", "2,130.000"), # this is not a valid number + ("674,234.99", "674,234.99"), # this is not a valid number + ], +) +def test_revert_lambda_long_number_format(value, expected): + assert revert_lambda_long_number_format(value) == expected