Source code for MEDfl.NetManager.network

# src/MEDfl/NetManager/network.py

from MEDfl.LearningManager.utils import *
from .net_helper import *
from .net_manager_queries import (CREATE_MASTER_DATASET_TABLE_QUERY,
                                  CREATE_DATASETS_TABLE_QUERY,
                                  DELETE_NETWORK_QUERY,
                                  INSERT_NETWORK_QUERY, LIST_ALL_NODES_QUERY,
                                  UPDATE_NETWORK_QUERY, GET_NETWORK_QUERY)
from .node import Node
import pandas as pd
from MEDfl.LearningManager.utils import params

from sqlalchemy import text
from sqlalchemy.exc import SQLAlchemyError

[docs]class Network: """ A class representing a network. Attributes: name (str): The name of the network. mtable_exists (int): An integer flag indicating whether the MasterDataset table exists (1) or not (0). """
[docs] def __init__(self, name: str = ""): """ Initialize a Network instance. Parameters: name (str): The name of the network. """ self.name = name self.mtable_exists = int(master_table_exists()) self.validate() db_manager = DatabaseManager() db_manager.connect() self.eng = db_manager.get_connection()
[docs] def validate(self): """Validate name""" if not isinstance(self.name, str): raise TypeError("name argument must be a string")
[docs] def create_network(self): """Create a new network in the database.""" try: print(self.name) self.eng.execute(text(INSERT_NETWORK_QUERY), {"name": self.name}) self.id = self.get_netid_from_name(self.name) except SQLAlchemyError as e: print(f"Error creating network: {e}")
[docs] def use_network(self, network_name: str): """Use a network in the database. Parameters: network_name (str): The name of the network to use. Returns: Network or None: An instance of the Network class if the network exists, else None. """ try: network = pd.read_sql( text(GET_NETWORK_QUERY), self.eng, params={"name": network_name} ) if not network.empty: self.name = network.iloc[0]['NetName'] self.id = network.iloc[0]['NetId'] self.mtable_exists = int(master_table_exists()) self.validate() return self else: return None except SQLAlchemyError as e: print(f"Error using network: {e}") return None
[docs] def delete_network(self): """Delete the network from the database.""" try: self.eng.execute(text(DELETE_NETWORK_QUERY), {"name": self.name}) except SQLAlchemyError as e: print(f"Error deleting network: {e}")
[docs] def update_network(self, FLsetupId: int): """Update the network's FLsetupId in the database. Parameters: FLsetupId (int): The FLsetupId to update. """ try: self.eng.execute( text(UPDATE_NETWORK_QUERY), {"FLsetupId": FLsetupId, "id": self.id} ) except SQLAlchemyError as e: print(f"Error updating network: {e}")
[docs] def add_node(self, node: Node): """Add a node to the network. Parameters: node (Node): The node to add. """ node.create_node(self.id)
[docs] def list_allnodes(self): """List all nodes in the network. Returns: DataFrame: A DataFrame containing information about all nodes in the network. """ try: query = text(LIST_ALL_NODES_QUERY) result_proxy = self.eng.execute(query, name=self.name) result_df = pd.DataFrame(result_proxy.fetchall(), columns=result_proxy.keys()) return result_df except SQLAlchemyError as e: print(f"Error listing all nodes: {e}") return pd.DataFrame()
[docs] def create_master_dataset(self, path_to_csv: str = params['path_to_master_csv']): """ Create the MasterDataset table and insert dataset values. :param path_to_csv: Path to the CSV file containing the dataset. """ try: print(path_to_csv) data_df = pd.read_csv(path_to_csv) if self.mtable_exists != 1: columns = data_df.columns.tolist() columns_str = ",\n".join( [ f"{col} {column_map[str(data_df[col].dtype)]}" for col in columns ] ) self.eng.execute( text(CREATE_MASTER_DATASET_TABLE_QUERY.format(columns_str)) ) self.eng.execute(text(CREATE_DATASETS_TABLE_QUERY.format(columns_str))) # Process data data_df = process_eicu(data_df) # Insert data in batches batch_size = 1000 # Adjust as needed for start_idx in range(0, len(data_df), batch_size): batch_data = data_df.iloc[start_idx:start_idx + batch_size] insert_query = f"INSERT INTO MasterDataset ({', '.join(columns)}) VALUES ({', '.join([':' + col for col in columns])})" data_to_insert = batch_data.to_dict(orient='records') self.eng.execute(text(insert_query), data_to_insert) self.mtable_exists = 1 except SQLAlchemyError as e: print(f"Error creating master dataset: {e}")
[docs] @staticmethod def list_allnetworks(): """List all networks in the database. Returns: DataFrame: A DataFrame containing information about all networks in the database. """ try: db_manager = DatabaseManager() db_manager.connect() my_eng = db_manager.get_connection() result_proxy = my_eng.execute("SELECT * FROM Networks") result = result_proxy.fetchall() return pd.DataFrame(result, columns=result_proxy.keys()) except SQLAlchemyError as e: print(f"Error listing all networks: {e}") return pd.DataFrame()
[docs] def get_netid_from_name(self, name): """Get network ID from network name.""" try: result = self.eng.execute(text("SELECT NetId FROM Networks WHERE NetName = :name"), {"name": name}).fetchone() if result: return result[0] else: return None except SQLAlchemyError as e: print(f"Error fetching network ID: {e}") return None