Skip to content

Commit

Permalink
optimized data writer & check
Browse files Browse the repository at this point in the history
  • Loading branch information
rmfish committed Dec 14, 2023
1 parent 1e1a808 commit d1e4306
Show file tree
Hide file tree
Showing 99 changed files with 324 additions and 390 deletions.
23 changes: 17 additions & 6 deletions tutake/api/base_dao.py
Original file line number Diff line number Diff line change
Expand Up @@ -381,10 +381,13 @@ def append(self, records):
self.fields = records.fields
self.items = self.items + records.items

def __str__(self):
    """Delegate string conversion to this object's DataFrame representation."""
    return str(self.data_frame())


class BatchWriter:

def __init__(self, engine, table: str, schema, database_dir):
def __init__(self, engine, table: str, schema, database_dir=None):
self.writer = None
self.engine = engine
self.table = table
Expand All @@ -402,11 +405,12 @@ def _init_writer(self):
return self.writer

def start(self):
self._init_writer()
conn = self.engine.connect()
result = conn.execute(f"""Select max(id) from {self.table}""").fetchone()
if result is not None and result[0] is not None:
self.max_id = int(result[0])
if self.writer is None:
self._init_writer()
conn = self.engine.connect()
result = conn.execute(f"""Select max(id) from {self.table}""").fetchone()
if result is not None and result[0] is not None:
self.max_id = int(result[0])

def rollback(self):
if self.writer:
Expand Down Expand Up @@ -438,12 +442,19 @@ def add_records(self, records):
data = records
else:
return

self.start()
data['id'] = range(self.max_id, self.max_id + len(data))
data = data.reindex(columns=['id'] + list(data.columns[:-1]))
table = pa.Table.from_pandas(df=data, schema=self.schema)
self.writer.write_table(table)
self.max_id = self.max_id + len(data)

def flush(self):
    """Force a DuckDB checkpoint so buffered writes are persisted to the database file.

    Opens a short-lived connection on ``self.engine``, issues ``FORCE CHECKPOINT;``,
    and always closes the connection — even if the statement raises — so the
    connection is not leaked on failure.
    """
    conn = self.engine.connect()
    try:
        conn.execute("FORCE CHECKPOINT;")
    finally:
        # Original code leaked the connection if execute() raised; try/finally fixes that.
        conn.close()

def close(self):
self.max_id = 0
if self.writer:
Expand Down
7 changes: 3 additions & 4 deletions tutake/api/checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from tutake.api.base_dao import checker_logger, BatchWriter


def _auto_data_repair(self, trade_date, ts_codes):
def _auto_data_repair(self, trade_date, ts_codes, tushare):
writer = BatchWriter(self.engine, self.table_name)
if len(ts_codes) < 100:
for ts_code in ts_codes:
Expand Down Expand Up @@ -56,11 +56,10 @@ def check_by_date(self, method, default_start, force_start=None, date_apply=lamb
tushare_pd = pd.DataFrame(tushare.items, columns=tushare.fields)
diff = list(set(tushare_pd[diff_column].unique().tolist()) - set(db[diff_column].unique().tolist()))
if diff_repair is not None:
diff_repair(self, trade_date, diff)
continue
diff_repair(self, trade_date, diff, tushare)
else:
checker_logger.warning(
f"Not equals data. The date is {trade_date}. tushare size is {tushare.size()}, db size is {db.shape[0]}, diff is {diff}")
f"Not equals data {self.name}. The date is {trade_date}. tushare size is {tushare.size()}, db size is {db.shape[0]}, diff is {diff}")
if force_start is None:
self.checker.error_point(trade_date=trade_date)
return
Expand Down
4 changes: 4 additions & 0 deletions tutake/api/process.py
Original file line number Diff line number Diff line change
Expand Up @@ -283,8 +283,12 @@ def _process_by_func(self, prepare_write, query_parameters, fetch_and_append, wr
logging.exception(err)
finally:
writer.close()
last_task = kwargs.get("last_task")
if last_task is None or last_task is True:
writer.flush()
return status


def _inner_process(self, process_params, fetch_and_append, status: ProcessStatus, writer: BatchWriter = None,
retry_cnt=0, entrypoint=None):
if retry_cnt > self.max_repeat:
Expand Down
11 changes: 7 additions & 4 deletions tutake/api/process_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,20 +111,23 @@ def _get_all_task(self) -> [Task]:
return apis

def _do_process(self, tasks, entrypoint="scheduler"):
def __process(_job_id, _task) -> ProcessStatus:
def __process(_job_id, _task, is_last_task=None) -> ProcessStatus:
if _task is not None:
return _task.process(entrypoint=entrypoint)
return _task.process(entrypoint=entrypoint, last_task=is_last_task)
else:
return None

start = time.time()
status_list = []
if isinstance(tasks, Task):
status_list.append(__process(f"tutake_{tasks.name}", tasks))
status_list.append(__process(f"tutake_{tasks.name}", tasks, None))
elif isinstance(tasks, Sequence):
for task in tasks:
try:
status_list.append(__process(f"tutake_{task.name}", task))
if task == tasks[-1]:
status_list.append(__process(f"tutake_{task.name}", task, True))
else:
status_list.append(__process(f"tutake_{task.name}", task, False))
except Exception as err:
# self.logger.error(f"Exception with {api} process,err is {err}")
continue
Expand Down
7 changes: 3 additions & 4 deletions tutake/api/ts/adj_factor.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,6 @@ def __new__(cls, *args, **kwargs):
def __init__(self, config):
self.table_name = "tushare_adj_factor"
self.database = 'tutake.duckdb'
self.database_dir = config.get_tutake_data_dir()
self.database_url = config.get_data_driver_url(self.database)
self.engine = create_shared_engine(self.database_url,
connect_args={
Expand All @@ -50,7 +49,8 @@ def __init__(self, config):
session_factory = sessionmaker()
session_factory.configure(bind=self.engine)
TushareAdjFactor.__table__.create(bind=self.engine, checkfirst=True)
self.schema = BaseDao.parquet_schema(TushareAdjFactor)
self.writer = BatchWriter(self.engine, self.table_name, BaseDao.parquet_schema(TushareAdjFactor),
config.get_tutake_data_dir())

query_fields = ['ts_code', 'trade_date', 'start_date', 'end_date', 'limit', 'offset']
self.tushare_fields = ["ts_code", "trade_date", "adj_factor"]
Expand Down Expand Up @@ -101,8 +101,7 @@ def process(self, **kwargs):
同步历史数据
:return:
"""
return super()._process(self.fetch_and_append,
BatchWriter(self.engine, self.table_name, self.schema, self.database_dir), **kwargs)
return super()._process(self.fetch_and_append, self.writer, **kwargs)

def fetch_and_append(self, **kwargs):
"""
Expand Down
7 changes: 3 additions & 4 deletions tutake/api/ts/anns.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,6 @@ def __new__(cls, *args, **kwargs):
def __init__(self, config):
self.table_name = "tushare_anns"
self.database = 'tutake.duckdb'
self.database_dir = config.get_tutake_data_dir()
self.database_url = config.get_data_driver_url(self.database)
self.engine = create_shared_engine(self.database_url,
connect_args={
Expand All @@ -55,7 +54,8 @@ def __init__(self, config):
session_factory = sessionmaker()
session_factory.configure(bind=self.engine)
TushareAnns.__table__.create(bind=self.engine, checkfirst=True)
self.schema = BaseDao.parquet_schema(TushareAnns)
self.writer = BatchWriter(self.engine, self.table_name, BaseDao.parquet_schema(TushareAnns),
config.get_tutake_data_dir())

query_fields = ['ts_code', 'ann_date', 'start_date', 'end_date', 'limit', 'offset']
self.tushare_fields = ["ts_code", "ann_date", "ann_type", "title", "content", "pub_time", "src_url", "filepath"]
Expand Down Expand Up @@ -131,8 +131,7 @@ def process(self, **kwargs):
同步历史数据
:return:
"""
return super()._process(self.fetch_and_append,
BatchWriter(self.engine, self.table_name, self.schema, self.database_dir), **kwargs)
return super()._process(self.fetch_and_append, self.writer, **kwargs)

def fetch_and_append(self, **kwargs):
"""
Expand Down
7 changes: 3 additions & 4 deletions tutake/api/ts/bak_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,6 @@ def __new__(cls, *args, **kwargs):
def __init__(self, config):
self.table_name = "tushare_bak_basic"
self.database = 'tutake.duckdb'
self.database_dir = config.get_tutake_data_dir()
self.database_url = config.get_data_driver_url(self.database)
self.engine = create_shared_engine(self.database_url,
connect_args={
Expand All @@ -71,7 +70,8 @@ def __init__(self, config):
session_factory = sessionmaker()
session_factory.configure(bind=self.engine)
TushareBakBasic.__table__.create(bind=self.engine, checkfirst=True)
self.schema = BaseDao.parquet_schema(TushareBakBasic)
self.writer = BatchWriter(self.engine, self.table_name, BaseDao.parquet_schema(TushareBakBasic),
config.get_tutake_data_dir())

query_fields = ['trade_date', 'ts_code', 'limit', 'offset']
self.tushare_fields = [
Expand Down Expand Up @@ -233,8 +233,7 @@ def process(self, **kwargs):
同步历史数据
:return:
"""
return super()._process(self.fetch_and_append,
BatchWriter(self.engine, self.table_name, self.schema, self.database_dir), **kwargs)
return super()._process(self.fetch_and_append, self.writer, **kwargs)

def fetch_and_append(self, **kwargs):
"""
Expand Down
7 changes: 3 additions & 4 deletions tutake/api/ts/bak_daily.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,6 @@ def __new__(cls, *args, **kwargs):
def __init__(self, config):
self.table_name = "tushare_bak_daily"
self.database = 'tutake.duckdb'
self.database_dir = config.get_tutake_data_dir()
self.database_url = config.get_data_driver_url(self.database)
self.engine = create_shared_engine(self.database_url,
connect_args={
Expand All @@ -78,7 +77,8 @@ def __init__(self, config):
session_factory = sessionmaker()
session_factory.configure(bind=self.engine)
TushareBakDaily.__table__.create(bind=self.engine, checkfirst=True)
self.schema = BaseDao.parquet_schema(TushareBakDaily)
self.writer = BatchWriter(self.engine, self.table_name, BaseDao.parquet_schema(TushareBakDaily),
config.get_tutake_data_dir())

query_fields = ['ts_code', 'trade_date', 'start_date', 'end_date', 'offset', 'limit']
self.tushare_fields = [
Expand Down Expand Up @@ -279,8 +279,7 @@ def process(self, **kwargs):
同步历史数据
:return:
"""
return super()._process(self.fetch_and_append,
BatchWriter(self.engine, self.table_name, self.schema, self.database_dir), **kwargs)
return super()._process(self.fetch_and_append, self.writer, **kwargs)

def fetch_and_append(self, **kwargs):
"""
Expand Down
7 changes: 3 additions & 4 deletions tutake/api/ts/balancesheet_vip.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,7 +195,6 @@ def __new__(cls, *args, **kwargs):
def __init__(self, config):
self.table_name = "tushare_balancesheet_vip"
self.database = 'tutake.duckdb'
self.database_dir = config.get_tutake_data_dir()
self.database_url = config.get_data_driver_url(self.database)
self.engine = create_shared_engine(self.database_url,
connect_args={
Expand All @@ -205,7 +204,8 @@ def __init__(self, config):
session_factory = sessionmaker()
session_factory.configure(bind=self.engine)
TushareBalancesheetVip.__table__.create(bind=self.engine, checkfirst=True)
self.schema = BaseDao.parquet_schema(TushareBalancesheetVip)
self.writer = BatchWriter(self.engine, self.table_name, BaseDao.parquet_schema(TushareBalancesheetVip),
config.get_tutake_data_dir())

query_fields = [
'ts_code', 'ann_date', 'f_ann_date', 'start_date', 'end_date', 'period', 'report_type', 'comp_type',
Expand Down Expand Up @@ -1099,8 +1099,7 @@ def process(self, **kwargs):
同步历史数据
:return:
"""
return super()._process(self.fetch_and_append,
BatchWriter(self.engine, self.table_name, self.schema, self.database_dir), **kwargs)
return super()._process(self.fetch_and_append, self.writer, **kwargs)

def fetch_and_append(self, **kwargs):
"""
Expand Down
7 changes: 3 additions & 4 deletions tutake/api/ts/cashflow_vip.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,6 @@ def __new__(cls, *args, **kwargs):
def __init__(self, config):
self.table_name = "tushare_cashflow_vip"
self.database = 'tutake.duckdb'
self.database_dir = config.get_tutake_data_dir()
self.database_url = config.get_data_driver_url(self.database)
self.engine = create_shared_engine(self.database_url,
connect_args={
Expand All @@ -144,7 +143,8 @@ def __init__(self, config):
session_factory = sessionmaker()
session_factory.configure(bind=self.engine)
TushareCashflowVip.__table__.create(bind=self.engine, checkfirst=True)
self.schema = BaseDao.parquet_schema(TushareCashflowVip)
self.writer = BatchWriter(self.engine, self.table_name, BaseDao.parquet_schema(TushareCashflowVip),
config.get_tutake_data_dir())

query_fields = [
'ts_code', 'ann_date', 'f_ann_date', 'start_date', 'end_date', 'period', 'report_type', 'comp_type',
Expand Down Expand Up @@ -718,8 +718,7 @@ def process(self, **kwargs):
同步历史数据
:return:
"""
return super()._process(self.fetch_and_append,
BatchWriter(self.engine, self.table_name, self.schema, self.database_dir), **kwargs)
return super()._process(self.fetch_and_append, self.writer, **kwargs)

def fetch_and_append(self, **kwargs):
"""
Expand Down
7 changes: 3 additions & 4 deletions tutake/api/ts/ccass_hold.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,6 @@ def __new__(cls, *args, **kwargs):
def __init__(self, config):
self.table_name = "tushare_ccass_hold"
self.database = 'tutake.duckdb'
self.database_dir = config.get_tutake_data_dir()
self.database_url = config.get_data_driver_url(self.database)
self.engine = create_shared_engine(self.database_url,
connect_args={
Expand All @@ -53,7 +52,8 @@ def __init__(self, config):
session_factory = sessionmaker()
session_factory.configure(bind=self.engine)
TushareCcassHold.__table__.create(bind=self.engine, checkfirst=True)
self.schema = BaseDao.parquet_schema(TushareCcassHold)
self.writer = BatchWriter(self.engine, self.table_name, BaseDao.parquet_schema(TushareCcassHold),
config.get_tutake_data_dir())

query_fields = ['ts_code', 'trade_date', 'start_date', 'end_date', 'type', 'hk_hold', 'limit', 'offset']
self.tushare_fields = ["trade_date", "ts_code", "name", "shareholding", "hold_nums", "hold_ratio"]
Expand Down Expand Up @@ -121,8 +121,7 @@ def process(self, **kwargs):
同步历史数据
:return:
"""
return super()._process(self.fetch_and_append,
BatchWriter(self.engine, self.table_name, self.schema, self.database_dir), **kwargs)
return super()._process(self.fetch_and_append, self.writer, **kwargs)

def fetch_and_append(self, **kwargs):
"""
Expand Down
7 changes: 3 additions & 4 deletions tutake/api/ts/ccass_hold_detail.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,6 @@ def __new__(cls, *args, **kwargs):
def __init__(self, config):
self.table_name = "tushare_ccass_hold_detail"
self.database = 'tutake.duckdb'
self.database_dir = config.get_tutake_data_dir()
self.database_url = config.get_data_driver_url(self.database)
self.engine = create_shared_engine(self.database_url,
connect_args={
Expand All @@ -54,7 +53,8 @@ def __init__(self, config):
session_factory = sessionmaker()
session_factory.configure(bind=self.engine)
TushareCcassHoldDetail.__table__.create(bind=self.engine, checkfirst=True)
self.schema = BaseDao.parquet_schema(TushareCcassHoldDetail)
self.writer = BatchWriter(self.engine, self.table_name, BaseDao.parquet_schema(TushareCcassHoldDetail),
config.get_tutake_data_dir())

query_fields = ['ts_code', 'trade_date', 'start_date', 'end_date', 'hk_code', 'limit', 'offset']
self.tushare_fields = [
Expand Down Expand Up @@ -132,8 +132,7 @@ def process(self, **kwargs):
同步历史数据
:return:
"""
return super()._process(self.fetch_and_append,
BatchWriter(self.engine, self.table_name, self.schema, self.database_dir), **kwargs)
return super()._process(self.fetch_and_append, self.writer, **kwargs)

def fetch_and_append(self, **kwargs):
"""
Expand Down
7 changes: 3 additions & 4 deletions tutake/api/ts/ci_daily.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,6 @@ def __new__(cls, *args, **kwargs):
def __init__(self, config):
self.table_name = "tushare_ci_daily"
self.database = 'tutake.duckdb'
self.database_dir = config.get_tutake_data_dir()
self.database_url = config.get_data_driver_url(self.database)
self.engine = create_shared_engine(self.database_url,
connect_args={
Expand All @@ -58,7 +57,8 @@ def __init__(self, config):
session_factory = sessionmaker()
session_factory.configure(bind=self.engine)
TushareCiDaily.__table__.create(bind=self.engine, checkfirst=True)
self.schema = BaseDao.parquet_schema(TushareCiDaily)
self.writer = BatchWriter(self.engine, self.table_name, BaseDao.parquet_schema(TushareCiDaily),
config.get_tutake_data_dir())

query_fields = ['ts_code', 'trade_date', 'start_date', 'end_date', 'limit', 'offset']
self.tushare_fields = [
Expand Down Expand Up @@ -155,8 +155,7 @@ def process(self, **kwargs):
同步历史数据
:return:
"""
return super()._process(self.fetch_and_append,
BatchWriter(self.engine, self.table_name, self.schema, self.database_dir), **kwargs)
return super()._process(self.fetch_and_append, self.writer, **kwargs)

def fetch_and_append(self, **kwargs):
"""
Expand Down
7 changes: 3 additions & 4 deletions tutake/api/ts/cn_cpi.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,6 @@ def __new__(cls, *args, **kwargs):
def __init__(self, config):
self.table_name = "tushare_cn_cpi"
self.database = 'tutake.duckdb'
self.database_dir = config.get_tutake_data_dir()
self.database_url = config.get_data_driver_url(self.database)
self.engine = create_shared_engine(self.database_url,
connect_args={
Expand All @@ -60,7 +59,8 @@ def __init__(self, config):
session_factory = sessionmaker()
session_factory.configure(bind=self.engine)
TushareCnCpi.__table__.create(bind=self.engine, checkfirst=True)
self.schema = BaseDao.parquet_schema(TushareCnCpi)
self.writer = BatchWriter(self.engine, self.table_name, BaseDao.parquet_schema(TushareCnCpi),
config.get_tutake_data_dir())

query_fields = ['m', 'start_m', 'end_m', 'limit', 'offset']
self.tushare_fields = [
Expand Down Expand Up @@ -166,8 +166,7 @@ def process(self, **kwargs):
同步历史数据
:return:
"""
return super()._process(self.fetch_and_append,
BatchWriter(self.engine, self.table_name, self.schema, self.database_dir), **kwargs)
return super()._process(self.fetch_and_append, self.writer, **kwargs)

def fetch_and_append(self, **kwargs):
"""
Expand Down
Loading

0 comments on commit d1e4306

Please sign in to comment.