Skip to content

Commit 0a2d493

Browse files
committed
feat: improve search and symbols_is_namespace
- search accepts namespace as an argument - search and symbols_is_namespace return DataFrame (default) or json
1 parent babbc12 commit 0a2d493

File tree

5 files changed

+84
-16
lines changed

5 files changed

+84
-16
lines changed

main.py

Lines changed: 1 addition & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,3 @@
1-
import numpy as np
21
import okama as ok
32

4-
asset_list = ok.AssetList(assets=['MSFT.US'], ccy='USD')
5-
6-
print(asset_list.dividends_annual)
7-
8-
9-
x = asset_list.get_dividend_mean_growth_rate(period=20)
10-
# x.replace([np.inf, -np.inf], 0, inplace=True)
11-
12-
print(f'Growth rate:', x)
3+
print(ok.search('aeroflot', namespace=None, response_format='frame'))

okama/api/namespaces.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,11 +14,17 @@ def get_namespaces():
1414

1515

1616
@lru_cache()
17-
def symbols_in_namespace(namespace: str = default_namespace):
17+
def symbols_in_namespace(namespace: str = default_namespace, response_format: str = 'frame') -> pd.DataFrame:
1818
string_response = API.get_symbols_in_namespace(namespace.upper())
1919
list_of_symbols = json.loads(string_response)
20-
df = pd.DataFrame(list_of_symbols[1:], columns=list_of_symbols[0])
21-
return df.astype("string", copy=False)
20+
if response_format.lower() == 'frame':
21+
df = pd.DataFrame(list_of_symbols[1:], columns=list_of_symbols[0])
22+
return df.astype("string", copy=False)
23+
elif response_format.lower() == 'json':
24+
return list_of_symbols
25+
else:
26+
raise ValueError('response_format must be "json" or "frame"')
27+
2228

2329

2430
@lru_cache()

okama/api/search.py

Lines changed: 26 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,32 @@
11
import json
2+
from typing import Optional
3+
4+
import pandas as pd
25

36
from .api_methods import API
7+
from .namespaces import symbols_in_namespace
48

59

6-
def search(search_string: str) -> json:
10+
def search(search_string: str, namespace: Optional[str] = None, response_format: str = 'frame') -> json:
11+
# search for string in a single namespace
12+
if namespace:
13+
df = symbols_in_namespace(namespace.upper())
14+
condition1 = df['name'].str.contains(search_string, case=False)
15+
condition2 = df['ticker'].str.contains(search_string, case=False)
16+
frame_response = df[condition1 | condition2]
17+
if response_format.lower() == 'frame':
18+
return frame_response
19+
elif response_format.lower() == 'json':
20+
return frame_response.to_json(orient='records')
21+
else:
22+
raise ValueError('response_format must be "json" or "frame"')
23+
# search for string in all namespaces
724
string_response = API.search(search_string)
8-
return json.loads(string_response)
25+
json_response = json.loads(string_response)
26+
if response_format.lower() == 'frame':
27+
df = pd.DataFrame(json_response[1:], columns=json_response[0])
28+
return df
29+
elif response_format.lower() == 'json':
30+
return json_response
31+
else:
32+
raise ValueError('response_format must be "json" or "frame"')

okama/common/make_asset_list.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@ def __init__(
7979
self.assets_first_dates.update({self.inflation: Inflation(self.inflation).first_date})
8080
self.assets_last_dates.update({self.inflation: Inflation(self.inflation).last_date})
8181
self.assets_ror: pd.DataFrame = self.assets_ror[
82-
self.first_date : self.last_date
82+
self.first_date: self.last_date
8383
]
8484
self.period_length: float = round(
8585
(self.last_date - self.first_date) / np.timedelta64(365, "D"), ndigits=1

tests/test_search.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
"""
2+
Tests the search
3+
"""
4+
import json
5+
6+
import pytest
7+
8+
from okama.api.search import search
9+
10+
11+
def test_search_namespace_json():
12+
x = search(
13+
"aeroflot",
14+
namespace="MOEX",
15+
response_format="json",
16+
)
17+
assert json.loads(x)[0]['symbol'] == 'AFLT.MOEX'
18+
19+
20+
def test_search_namespace_frame():
21+
x = search(
22+
"aeroflot",
23+
namespace="MOEX",
24+
response_format="frame",
25+
)
26+
assert x['symbol'].values[0] == 'AFLT.MOEX'
27+
28+
29+
def test_search_all_json():
30+
x = search(
31+
"lkoh",
32+
response_format="json",
33+
)
34+
assert x[1][0] == 'LKOH.MOEX'
35+
36+
37+
def test_search_all_frame():
38+
x = search(
39+
"lkoh",
40+
response_format="frame",
41+
)
42+
assert x['symbol'].iloc[0] == 'LKOH.MOEX'
43+
44+
45+
def test_search_error():
46+
with pytest.raises(ValueError):
47+
search("arg", response_format='txt')

0 commit comments

Comments
 (0)