Files
stockdb/write_to_db.py
Leonard Excoffier 162106c8e0 e
2024-08-31 19:50:15 -04:00

89 lines
3.5 KiB
Python

import os
import pandas as pd
from sqlalchemy import create_engine, MetaData, Table
from sqlalchemy.dialects.mysql import insert
from dotenv import load_dotenv
import numpy as np
# Load environment variables from .env file so credentials stay out of source.
load_dotenv()

# Database connection settings, pulled from the environment.
_ENV_KEYS = ('DB_USER', 'DB_PASSWORD', 'DB_HOST', 'DB_PORT', 'DB_NAME')
_db = {key: os.getenv(key) for key in _ENV_KEYS}

# Assemble the SQLAlchemy URL and create the engine (connection is lazy).
connection_string = (
    "mariadb+pymysql://{DB_USER}:{DB_PASSWORD}@{DB_HOST}:{DB_PORT}/{DB_NAME}"
    .format(**_db)
)
engine = create_engine(connection_string)

# SEC quarterly data files to load: (path, target table, primary-key columns).
_QUARTER_DIR = 'sec_data/2015q1'
file_paths = [
    (f'{_QUARTER_DIR}/sub.txt', 'sub', ['adsh']),
    (f'{_QUARTER_DIR}/tag.txt', 'tag', ['tag', 'version']),
    (f'{_QUARTER_DIR}/num.txt', 'num', ['adsh', 'tag', 'version', 'coreg', 'ddate', 'qtrs', 'uom']),
    (f'{_QUARTER_DIR}/pre.txt', 'pre', ['adsh', 'report', 'line']),
]

# Shared metadata object used for table reflection below.
metadata = MetaData()
# Loop through each file, clean the data, and upsert it into its table.
for i, (file_path, table_name, primary_keys) in enumerate(file_paths):
    print(f"\nAnalyzing {file_path} (File {i+1}/{len(file_paths)})...")

    # Read the tab-delimited SEC data file into a DataFrame.
    df = pd.read_csv(file_path, sep='\t')

    # df.info() prints its report itself and returns None, so call it
    # directly instead of wrapping it in print() (which printed "None").
    print("\nSummary Information:")
    df.info()

    # 'coreg' is part of num's composite primary key but is frequently
    # empty in the raw files; substitute a sentinel so the key is non-NULL.
    if table_name == 'num':
        df['coreg'] = df['coreg'].fillna('nocoreg')
        print("\nUpdated 'coreg' column (NaN values replaced with 'nocoreg'):")
        print(df[['coreg']].head(10))  # first 10 rows for verification

    # Rows missing any primary-key component cannot be keyed — drop them.
    df.dropna(subset=primary_keys, inplace=True)

    # Replace NaN/±inf with None so they map to SQL NULL on insert.
    df = df.replace([np.nan, np.inf, -np.inf], None)

    print("\nUpdated Information:")
    df.info()

    # Reflect the already-existing table definition from the database.
    table = Table(table_name, metadata, autoload_with=engine)

    # Build ONE INSERT ... ON DUPLICATE KEY UPDATE statement and execute
    # it in batches (executemany) rather than one round trip per row.
    # to_dict('records') also avoids itertuples(), which silently renames
    # column names that are not valid Python identifiers.
    insert_stmt = insert(table)
    upsert_stmt = insert_stmt.on_duplicate_key_update(
        {col.name: insert_stmt.inserted[col.name] for col in table.columns}
    )
    rows = df.to_dict(orient='records')

    # engine.begin() commits on success; plain engine.connect() would leave
    # the work uncommitted and SQLAlchemy 2.0 rolls it back on close.
    BATCH_SIZE = 1000
    with engine.begin() as conn:
        for start in range(0, len(rows), BATCH_SIZE):
            conn.execute(upsert_stmt, rows[start:start + BATCH_SIZE])

    print(f"\nCleaned data from {file_path} has been written to the '{table_name}' table in the database with upsert functionality.\n")

print("\nAll files have been processed and cleaned data has been written to the database.")
#FIXME: Foreign key missing because usgapp is in the past constantly, Q1 gaap is based on the year before gaap.