from __future__ import annotations

import datetime as dt
from enum import Enum
import json
from pathlib import Path
from typing import Final

from benchmark_results_collector.private import common
from benchmark_results_collector.private.db_adapter import Db


class MeasurementType(Enum):
    """Kind of replay measurement.

    The member value is used verbatim as the prefix of every generated
    ``method`` name, e.g. ``partial_measurement_real_time``.
    """

    # One entry of the replay log's ``measurements`` list.
    PARTIAL = 'partial_measurement'
    # The replay log's single ``total_measurement`` entry.
    TOTAL = 'total_measurement'


class Measurement:
    """A single replay-benchmark measurement parsed from one replay-log entry.

    Each instance produces four database records (real time, CPU time, current
    memory, peak memory) via ``to_db_data_instances``. It also owns one
    ``IndexMemoryDetail`` per entry of ``data['index_memory_details_cntr']``.

    Every record's timestamp is read from ``timestamp``. It can be assigned on
    the class, shared by all instances (as ``main`` does), or supplied per
    instance through the ``timestamp`` keyword argument of the constructor.
    """

    CALLER: Final[str] = 'replay_benchmark'
    TIME_UNIT: Final[str] = 'ms'
    MEM_UNIT: Final[str] = 'kB'
    # Declared but not assigned here: reading it before it has been set on the
    # class or on the instance raises AttributeError.
    timestamp: dt.datetime

    def __init__(
        self,
        data: dict,
        measurement_type: MeasurementType,
        *,
        timestamp: dt.datetime | None = None,
    ):
        """Store the raw values of one measurement entry.

        Args:
            data: Mapping with the keys ``block_number``, ``real_ms``, ``cpu_ms``,
                ``current_mem``, ``peak_mem`` and ``index_memory_details_cntr``.
                The last is a list of per-index mappings forwarded as keyword
                arguments to ``IndexMemoryDetail``.
            measurement_type: Partial or total; its ``value`` prefixes every
                generated method name.
            timestamp: Optional per-instance timestamp. When omitted, the
                class-level ``Measurement.timestamp`` is used instead.

        Raises:
            KeyError: If ``data`` or one of its index entries lacks a required key.
        """
        self.__measurement_type = measurement_type
        self.__block_number = data['block_number']
        self.__real_ms = data['real_ms']
        self.__cpu_ms = data['cpu_ms']
        self.__current_mem = data['current_mem']
        self.__peak_mem = data['peak_mem']
        if timestamp is not None:
            # Instance attribute shadows the shared class attribute.
            self.timestamp = timestamp

        self.indexes: list[Measurement.IndexMemoryDetail] = [
            self.IndexMemoryDetail(self, **index) for index in data['index_memory_details_cntr']
        ]

    @property
    def measurement_type(self) -> str:
        """String value of the measurement type, used as the method-name prefix."""
        return self.__measurement_type.value

    @property
    def block_number(self) -> int:
        """Block number taken from ``data['block_number']``."""
        return self.__block_number

    def to_db_data_instances(self) -> list[common.DbData]:
        """Build the four timing and memory records for this measurement."""
        timestamp = self.timestamp
        return [
            common.DbData(**self.__real_time_testcase(), measurement_timestamp=timestamp),
            common.DbData(**self.__cpu_time_testcase(), measurement_timestamp=timestamp),
            common.DbData(**self.__current_memory_usage_testcase(), measurement_timestamp=timestamp),
            common.DbData(**self.__peak_memory_usage_testcase(), measurement_timestamp=timestamp),
        ]

    def __detailed_dict(self, method: str, value: int, unit: str) -> dict:
        """Return the keyword arguments shared by all records of this measurement."""
        return {
            'caller': self.CALLER,
            'method': method,
            'params': json.dumps({'block': self.block_number}),
            'value': value,
            'unit': unit,
        }

    def __real_time_testcase(self) -> dict:
        return self.__detailed_dict(
            method=f'{self.measurement_type}_real_time',
            value=self.__real_ms,
            unit=self.TIME_UNIT,
        )

    def __cpu_time_testcase(self) -> dict:
        return self.__detailed_dict(
            method=f'{self.measurement_type}_cpu_time',
            value=self.__cpu_ms,
            unit=self.TIME_UNIT,
        )

    def __current_memory_usage_testcase(self) -> dict:
        return self.__detailed_dict(
            method=f'{self.measurement_type}_current_memory_usage',
            value=self.__current_mem,
            unit=self.MEM_UNIT,
        )

    def __peak_memory_usage_testcase(self) -> dict:
        return self.__detailed_dict(
            method=f'{self.measurement_type}_peak_memory_usage',
            value=self.__peak_mem,
            unit=self.MEM_UNIT,
        )

    class IndexMemoryDetail:
        """Per-index memory figures attached to a parent ``Measurement``."""

        UNIT: Final[str] = 'B'
        CALLER: Final[str] = 'replay_benchmark_index_memory_detail'

        def __init__(self, measurement: Measurement, **kwargs):
            """Capture one index entry.

            Args:
                measurement: The owning measurement; supplies type, block number
                    and timestamp.
                **kwargs: Must contain ``index_name``, ``index_size``,
                    ``item_sizeof`` and ``total_index_mem_usage``. Any other keys
                    are ignored.

            Raises:
                KeyError: If one of the required keys is missing.
            """
            self.__measurement = measurement
            # '::' is replaced so the name can be embedded in the method string.
            self.__index_name = kwargs.pop('index_name').replace('::', '_')
            self.__index_size = kwargs.pop('index_size')
            self.__item_sizeof = kwargs.pop('item_sizeof')
            self.__total_index_mem_usage = kwargs.pop('total_index_mem_usage')

        def to_db_data_instances(self) -> list[common.DbData]:
            """Build the three memory records for this index."""
            # Read the owning measurement's timestamp, not the class attribute,
            # so index records always match their parent's records.
            timestamp = self.__measurement.timestamp
            return [
                common.DbData(**self.__index_size_testcase(), measurement_timestamp=timestamp),
                common.DbData(**self.__item_sizeof_testcase(), measurement_timestamp=timestamp),
                common.DbData(**self.__total_index_mem_usage_testcase(), measurement_timestamp=timestamp),
            ]

        def __detailed_dict(self, method: str, value: int) -> dict:
            """Return the keyword arguments shared by all records of this index."""
            return {
                'caller': self.CALLER,
                'method': method,
                'params': json.dumps({'block': self.__measurement.block_number}),
                'value': value,
                'unit': self.UNIT,
            }

        def __index_size_testcase(self) -> dict:
            return self.__detailed_dict(
                method=f'{self.__measurement.measurement_type}_{self.__index_name}_index_size',
                value=self.__index_size,
            )

        def __item_sizeof_testcase(self) -> dict:
            return self.__detailed_dict(
                method=f'{self.__measurement.measurement_type}_{self.__index_name}_item_sizeof',
                value=self.__item_sizeof,
            )

        def __total_index_mem_usage_testcase(self) -> dict:
            return self.__detailed_dict(
                method=f'{self.__measurement.measurement_type}_{self.__index_name}_total_index_mem_usage',
                value=self.__total_index_mem_usage,
            )


async def main(database: Db, file: Path, benchmark_id: int, measurement_timestamp: dt.datetime):
    """Parse a replay benchmark log and insert every derived record into the database.

    The log is read through ``common.get_text_from_log_file`` and decoded as JSON.
    The function expects the keys ``measurements`` (partial entries) and
    ``total_measurement`` (a single entry). Records are de-duplicated with
    ``DbData.distinguish_objects_having_same_hash``. They are then inserted one
    after another, in order.

    Args:
        database: Target database adapter.
        file: Path of the replay log file.
        benchmark_id: Identifier forwarded to every insert.
        measurement_timestamp: Timestamp stamped on every record (stored on the
            ``Measurement`` class, so it is shared by all instances).
    """
    replay = json.loads(common.get_text_from_log_file(file))
    # Look up both keys before touching shared state, so a malformed log fails
    # before Measurement.timestamp is modified.
    partial_entries = replay['measurements']
    total_entry = replay['total_measurement']

    Measurement.timestamp = measurement_timestamp
    parsed = [Measurement(data=entry, measurement_type=MeasurementType.PARTIAL) for entry in partial_entries]
    parsed.append(Measurement(data=total_entry, measurement_type=MeasurementType.TOTAL))

    # Each measurement's own records come first, followed by those of its indexes.
    records = [
        record
        for item in parsed
        for source in (item, *item.indexes)
        for record in source.to_db_data_instances()
    ]

    common.DbData.distinguish_objects_having_same_hash(objects=records)
    for record in records:
        await record.insert(database=database, benchmark_id=benchmark_id)
