from __future__ import annotations

import dataclasses
import datetime as dt
import enum
import json
from pathlib import Path
import re
from typing import ClassVar, Final, Optional, Union

from benchmark_results_collector.private import common
from benchmark_results_collector.private.db_adapter import Db

# Name written to the `caller` field of the records built in this module:
# the default `caller` of SyncInfo and the `CALLER` constant of DbOperation.
HIVEMIND_INDEXER: Final = 'hivemind_indexer'


class InfoType(enum.Enum):
    """Kinds of log fragments this module extracts and parses.

    The string values are also used to build the `method` names of the
    records produced by `DbOperation` (e.g. `<value>_real_time`).
    """

    # Fragment following a 'Total CREATING indexes time' log line.
    CREATING_INDEXES = 'creating_indexes'
    # Fragment around a 'hive.indexer.sync - [MASSIVE] ...' log line.
    BLOCKS_INFO = 'blocks_info'
    # Fragment following a 'Total final operations time' log line.
    FILLING_DATA = 'filling_data'


@dataclasses.dataclass
class SyncInfo:
    """One block-processing report parsed from the indexer log.

    Covers the inclusive block range ``range_from``..``range_to``. The two
    processing times are multiplied by 1000 before being stored with
    ``time_unit`` ('ms' by default), so they are presumably given in seconds.
    The memory figures are rounded and stored with ``mem_unit``.

    Creating an instance has a class-wide side effect: ``last_block_number``
    is set to the new instance's ``range_to``.
    """

    # pylint: disable=too-many-instance-attributes

    # Shared by all instances; holds the `range_to` of the most recently created one.
    last_block_number: ClassVar[int] = 0
    range_from: int
    range_to: int
    processing_n_blocks_time: float
    processing_total_time: float
    physical_memory: float
    virtual_memory: float
    shared_memory: float
    mem_unit: str
    time_unit: str = 'ms'
    caller: str = HIVEMIND_INDEXER

    def __post_init__(self):
        """Publish the last seen block on the class and derive the size of the range."""
        self.__class__.last_block_number = self.range_to
        # Both ends of the range are inclusive, hence the `+ 1`.
        self.block_count = 1 + self.range_to - self.range_from

    def to_mapped_db_data_instances(self, measurement_timestamp: dt.datetime) -> list[common.DbData]:
        """Return five `common.DbData` records: two timings followed by three memory readings."""
        builders = (
            self.__processing_blocks_partial_time,
            self.__processing_blocks_total_real_time,
            self.__memory_usage_physical,
            self.__memory_usage_virtual,
            self.__memory_usage_shared,
        )
        return [common.DbData(**build(), measurement_timestamp=measurement_timestamp) for build in builders]

    def __as_record(self, method: str, params: dict, value: float, unit: str) -> dict:
        """Assemble the keyword arguments of one record; `params` is JSON-encoded, `value` rounded."""
        return {
            'caller': self.caller,
            'method': method,
            'params': json.dumps(params),
            'value': round(value),
            'unit': unit,
        }

    def __processing_blocks_partial_time(self) -> dict:
        """Time spent on this particular range of blocks."""
        block_range = {'from': self.range_from, "to": self.range_to}
        method = f'processing_{self.block_count}_blocks_real_time'
        return self.__as_record(method, block_range, self.processing_n_blocks_time * 1000, self.time_unit)

    def __processing_blocks_total_real_time(self) -> dict:
        """Total processing time reported at block `range_to`."""
        return self.__as_record(
            'processing_blocks_total_real_time',
            {'block': self.range_to},
            self.processing_total_time * 1000,
            self.time_unit,
        )

    def __memory_usage_physical(self) -> dict:
        """Physical memory reading reported at block `range_to`."""
        return self.__as_record('memory_usage_physical', {'block': self.range_to}, self.physical_memory, self.mem_unit)

    def __memory_usage_virtual(self) -> dict:
        """Virtual memory reading reported at block `range_to`."""
        return self.__as_record('memory_usage_virtual', {'block': self.range_to}, self.virtual_memory, self.mem_unit)

    def __memory_usage_shared(self) -> dict:
        """Shared memory reading reported at block `range_to`."""
        return self.__as_record('memory_usage_shared', {'block': self.range_to}, self.shared_memory, self.mem_unit)


class DbOperation:
    """Summary of one database operation found in the log, with its per-table partials.

    `total_time` is multiplied by 1000 and stored with `TIME_UNIT` ('ms'), so it is
    presumably given in seconds. Every partial assigned to `partials` gets a
    back-reference to this operation in its `db_operation` attribute.

    Instances define `__eq__` without `__hash__`, so they are not hashable.
    """

    CALLER: Final[str] = HIVEMIND_INDEXER
    TIME_UNIT: Final[str] = 'ms'

    def __init__(self, info_type: InfoType, total_time: float, partials: list[Partial]):
        self.__info_type = info_type
        self.__total_time = total_time
        # Goes through the property setter, which links every partial back to this operation.
        self.partials = partials

    def __eq__(self, other):
        """
        Compare the operation's own fields and then its partials, pairwise and in order.

        Partials are compared without their `db_operation` back-reference: following it
        would recurse into the parent, which has already been compared at that point.
        """
        if not isinstance(other, DbOperation):
            return False  # don't attempt to compare against unrelated types

        for key, value in self.__dict__.items():
            if key != '_DbOperation__partials' and value != other.__dict__[key]:
                return False

        # zip() silently stops at the shorter list, so without this check operations
        # holding a different number of partials could compare equal.
        if len(self.partials) != len(other.partials):
            return False

        for this_partial, that_partial in zip(self.partials, other.partials):
            this_fields = this_partial.__dict__.copy()
            that_fields = that_partial.__dict__.copy()

            # A default is supplied because a partial appended to the list after
            # construction never went through the setter and has no back-reference.
            this_fields.pop('db_operation', None)
            that_fields.pop('db_operation', None)

            if this_fields != that_fields:
                return False

        return True

    @property
    def info_type(self) -> str:
        """String value of the operation's `InfoType`."""
        return self.__info_type.value

    @property
    def partials(self) -> list[Partial]:
        """Per-table partial results belonging to this operation."""
        return self.__partials

    @partials.setter
    def partials(self, container: list[Partial]):
        # Attach the back-reference first so each partial can read CALLER and info_type.
        for partial in container:
            partial.db_operation = self
        self.__partials = container

    def to_mapped_db_data_instances(self, measurement_timestamp: dt.datetime) -> list[common.DbData]:
        """Return a single `common.DbData` record holding the operation's total time."""
        return [common.DbData(**self.__real_time(), measurement_timestamp=measurement_timestamp)]

    def __real_time(self) -> dict:
        """Keyword arguments of the `<info_type>_real_time` record."""
        return {
            'caller': self.CALLER,
            'method': f'{self.info_type}_real_time',
            'params': '',
            'value': round(self.__total_time * 10**3),
            'unit': self.TIME_UNIT,
        }

    @dataclasses.dataclass
    class Partial:
        """Time attributed to a single table within a `DbOperation`.

        `db_operation` is not an `__init__` argument; it is filled in when the partial
        is assigned to `DbOperation.partials`.
        """

        db_operation: DbOperation = dataclasses.field(init=False)
        table_name: str
        total_time: float
        time_unit: str = 'ms'

        def __eq__(self, other):
            # Deliberately unsupported: partials are compared by DbOperation.__eq__,
            # which leaves out the recursive `db_operation` field.
            raise NotImplementedError

        def to_mapped_db_data_instances(self, measurement_timestamp: dt.datetime) -> list[common.DbData]:
            """Return a single `common.DbData` record holding this table's time."""
            return [common.DbData(**self.type_partial_time(), measurement_timestamp=measurement_timestamp)]

        def type_partial_time(self) -> dict:
            """Keyword arguments of the `<info_type>_cpu_time` record; needs `db_operation` to be set."""
            return {
                'caller': self.db_operation.CALLER,
                'method': f'{self.db_operation.info_type}_cpu_time',
                'params': json.dumps({'table_name': self.table_name}),
                'value': round(self.total_time * 10**3),
                'unit': self.time_unit,
            }


def extract_interesting_log_strings(text: str) -> dict[InfoType, list]:
    """Collect, per `InfoType`, the log fragments that the parsers of this module understand.

    For `CREATING_INDEXES` and `FILLING_DATA` a fragment is the text captured after
    the matching header line (the pattern has one capturing group). For `BLOCKS_INFO`
    it is the whole match: one line before the '[MASSIVE]' line, that line itself and
    the three lines after it.
    """
    patterns = {
        InfoType.CREATING_INDEXES: (
            r'INFO - hive\.utils\.stats - Total CREATING indexes time\n((?:.*seconds\n)*.*\n.*)'
        ),
        InfoType.BLOCKS_INFO: r'.*\nINFO - hive\.indexer\.sync - \[MASSIVE\] .*\n(?:.*\n){2}.*',
        InfoType.FILLING_DATA: (
            r'INFO - hive\.utils\.stats - Total final operations time\n((?:.*seconds\n)*.*\n.*)'
        ),
    }
    return {info_type: re.compile(pattern).findall(text) for info_type, pattern in patterns.items()}


def parse_database_operation(lines: list[str], info_type: InfoType) -> Optional[DbOperation]:
    """Parses lines about `creating indexes` or `filling data`.

    Per-table lines seen before the summary line become the operation's partials.
    Parsing stops at the first summary ('Elapsed time: ...') line.

    :param lines: lines of one extracted log fragment.
    :param info_type: kind of operation the fragment describes.
    :return: the parsed operation, or None when no summary line is found.

    Lines that match neither pattern are skipped. This includes lines whose time
    is missing or is not a plain decimal number.
    """
    # A plain decimal such as `12`, `12.` , `12.5` or `.5`. The looser `[\.\d]*` also
    # matched '' and '1.2.3', which made float() raise on malformed lines.
    decimal = r'(\d+(?:\.\d*)?|\.\d+)'
    partial_pattern = re.compile(
        r'INFO - hive\.utils\.stats - `(.*)`: Processed final operations in ' + decimal + r' seconds'
    )
    summary_pattern = re.compile(
        r'INFO - hive\.db\.db_state - Elapsed time: ' + decimal + r's\. Calculated elapsed time: [\.\d]*s\. '
        r'Difference: [-\.\d]*s'
    )

    partials = []
    for line in lines:
        if match := partial_pattern.match(line):
            partials.append(DbOperation.Partial(table_name=match[1], total_time=float(match[2])))
        elif match := summary_pattern.match(line):
            return DbOperation(info_type=info_type, total_time=float(match[1]), partials=partials)
    return None


def parse_blocks_info(lines: list[str]) -> Optional[SyncInfo]:
    """Parses lines about `blocks info`.

    Four kinds of lines are recognised: the current block number, the time spent on
    the last batch of blocks, the total elapsed time and the memory usage report.
    When a kind of line occurs more than once, the last occurrence wins.

    :param lines: lines of one extracted log fragment.
    :return: a `SyncInfo` whose range starts right after `SyncInfo.last_block_number`
        and ends at the reported block, or None when any of the four pieces is missing.

    Lines whose numbers are missing or are not plain decimals are treated as not
    recognised, so a malformed fragment gives None rather than an exception.
    """
    # A plain decimal such as `12`, `12.`, `12.5` or `.5`. The looser `[\.\d]*` also
    # matched '' and '1.2.3', which made float() raise on malformed lines.
    decimal = r'(\d+(?:\.\d*)?|\.\d+)'

    processing_n_blocks_pattern = re.compile(
        r'INFO - hive\.indexer\.blocks - \[PROCESS MULTI\] (\d*) blocks in ' + decimal + r's'
    )
    # `\d+` instead of `\d*`: an empty block number would make int() raise.
    current_block_pattern = re.compile(r'INFO - hive\.indexer\.sync - \[MASSIVE\] Got block (\d+) .*')
    processing_total_time_pattern = re.compile(
        r'INFO - hive\.indexer\.sync - \[MASSIVE\] Time elapsed: ' + decimal + r's'
    )
    memory_usage_pattern = re.compile(
        r'INFO - hive\.indexer\.sync - memory usage report: physical_memory = ' + decimal + r' (.*),'
        r' virtual_memory = ' + decimal + r' (.*), shared_memory = ' + decimal + r' (.*)'
    )

    current_block = processing_n_blocks_time = processing_total_time = None
    physical_memory = virtual_memory = shared_memory = mem_unit = None

    for line in lines:
        if match := current_block_pattern.match(line):
            current_block = int(match[1])
        elif match := processing_n_blocks_pattern.match(line):
            # Group 1 is the number of blocks in the batch; only the time is used.
            processing_n_blocks_time = float(match[2])
        elif match := processing_total_time_pattern.match(line):
            processing_total_time = float(match[1])
        elif match := memory_usage_pattern.match(line):
            physical_memory = float(match[1])
            virtual_memory = float(match[3])
            shared_memory = float(match[5])
            # A single unit is stored for all three readings, so they have to agree.
            mem_unit = match[2] if match[2] == match[4] == match[6] else 'unknown'

    required = (
        current_block,
        processing_n_blocks_time,
        processing_total_time,
        physical_memory,
        virtual_memory,
        shared_memory,
        mem_unit,
    )
    if any(value is None for value in required):
        return None

    return SyncInfo(
        # The range continues from the block reached by the previously created SyncInfo.
        range_from=SyncInfo.last_block_number + 1,
        range_to=current_block,
        processing_n_blocks_time=processing_n_blocks_time,
        processing_total_time=processing_total_time,
        physical_memory=physical_memory,
        virtual_memory=virtual_memory,
        shared_memory=shared_memory,
        mem_unit=mem_unit,
    )


def parse_log_strings_to_objects(
    info_type: InfoType, interesting_log_strings: list[str]
) -> Union[list[DbOperation], list[SyncInfo]]:
    """Turn the extracted fragments of one `InfoType` into parsed objects.

    Each fragment is split into lines and handed to the parser matching `info_type`.
    Fragments the parser cannot handle (it returns None) are left out. An
    unrecognised `info_type` gives an empty list.
    """
    if info_type in (InfoType.CREATING_INDEXES, InfoType.FILLING_DATA):

        def parse(fragment_lines):
            return parse_database_operation(fragment_lines, info_type)

    elif info_type == InfoType.BLOCKS_INFO:
        parse = parse_blocks_info
    else:
        return []

    parsed_objects = []
    for fragment in interesting_log_strings:
        parsed = parse(fragment.split('\n'))
        if parsed is not None:
            parsed_objects.append(parsed)
    return parsed_objects


async def main(database: Db, file: Path, benchmark_id: int, measurement_timestamp: dt.datetime):
    """Parse one indexer log file and insert the resulting records into the database.

    The log text is split into fragments per `InfoType`, and every fragment is parsed.
    For database operations the partials are queued ahead of the operation itself.
    All parsed objects are then mapped to `common.DbData` records, which are passed to
    `common.DbData.distinguish_objects_having_same_hash` and inserted one by one.
    """
    log_text = common.get_text_from_log_file(file)
    operation_types = (InfoType.CREATING_INDEXES, InfoType.FILLING_DATA)

    # Phase 1: parse everything, keeping the order of the extracted fragments.
    sources: list[SyncInfo | DbOperation | DbOperation.Partial] = []
    for info_type, fragments in extract_interesting_log_strings(log_text).items():
        for parsed in parse_log_strings_to_objects(info_type, fragments):
            if info_type in operation_types:
                sources.extend(parsed.partials)
            sources.append(parsed)

    # Phase 2: map every parsed object to its database records.
    records = [
        record
        for source in sources
        for record in source.to_mapped_db_data_instances(measurement_timestamp=measurement_timestamp)
    ]

    common.DbData.distinguish_objects_having_same_hash(records)

    # Inserted sequentially, in the order the records were produced.
    for record in records:
        await record.insert(database, benchmark_id)
