HEX
Server: Apache/2.4.52 (Ubuntu)
System: Linux spn-python 5.15.0-89-generic #99-Ubuntu SMP Mon Oct 30 20:42:41 UTC 2023 x86_64
User: arjun (1000)
PHP: 8.1.2-1ubuntu2.20
Disabled: NONE
Upload Files
File: //home/arjun/projects/good-life-be/helper/gpt_schedule_optimizer.py
import json
import sys
from dataclasses import dataclass
from datetime import datetime, timedelta
from typing import List, Dict, Any, Tuple
from enum import Enum
from pydantic import BaseModel, Field
from dotenv import load_dotenv
import os
import re
from difflib import SequenceMatcher
from langchain_openai import ChatOpenAI 


# Runs at import time: python-dotenv reads a local .env file (if present) into
# the process environment. NOTE(review): presumably this supplies the API key
# that ChatOpenAI needs — confirm against the deployment setup.
load_dotenv()


class Category(str, Enum):
    """Enum choices for categories"""
    # NOTE: the docstring above is left untouched on purpose; schema generators
    # may surface an enum's docstring as its description.
    # Members subclass ``str`` so they compare equal to their display strings.
    CORESCHEDULE = 'Core Schedule'
    HEALTHANDWELLNESS = 'Health & Wellness'
    PERSONALDEVELOPMENTANDLEARNING = 'Personal Development & Learning'
    LEISUREANDRECREATION = 'Leisure & Recreation'
    HOUSEHOLDANDROUTINEMAINTENANCE = 'Household & Routine Maintenance'
    # Historical (misspelled) name; stays the canonical member so existing
    # references and ``.name`` lookups keep working.
    RELATTIONSHIPSANDSOCIAL = 'Relationships & Social'
    # Correctly spelled alias. Sharing the value makes this an Enum alias, so
    # iteration and ``len`` still report six members.
    RELATIONSHIPSANDSOCIAL = 'Relationships & Social'


class SubCategory(str, Enum):
    """Enum choices for subcategories"""
    # String-valued members: each one compares equal to its display string.
    # Used as the type of ConflictManagedevent.sub_category.
    # Several of these values are matched literally elsewhere in this file
    # ("Mental Wellness", "Holidays", "Sleep Patterns", "Meal Schedule"), so
    # the strings must not be edited casually.
    SleepPattern = "Sleep Patterns"
    EnergyPatterns = "Energy Patterns"
    WorkSchedule = "Work Schedule"
    MealSchedule = "Meal Schedule"
    ExerciseAndPhysicalActivity = "Exercise & Physical Activity"
    MentalWellness = "Mental Wellness"
    HealthMaintenance = "Health Maintenance"
    LeisureTime = "Leisure Time"
    VacationPlanning = "Vacation Planning"
    RegularMaintenance = "Regular Maintenance"
    Chores = "Chores"
    AdditionalMaintenanceTasks = "Additional Maintenance Tasks"
    FamilyTime = "Family Time"
    PartnerTime = "Partner Time"
    SocialConnections = "Social Connections"
    SpecialEvents = "Special Events"
    Holidays = "Holidays"
    General = "General"


# Pydantic model for one rescheduled event.
# NOTE(review): presumably used as the structured-output schema for the LLM, in
# which case the Field descriptions below are part of the prompt — confirm
# before rewording any of them.
class ConflictManagedevent(BaseModel):
    event_title: str = Field(description="Use the given event title.")
    event_description: str = Field(description="Use the given event description.")
    # Times are plain "HH:MM" strings, not datetime objects.
    start_time: str = Field(description="Event start time in HH:MM format. Use 24 hour format.")
    end_time: str = Field(description="Event end time in HH:MM format. Use 24 hour format.")
    # camelCase here matches the 'bufferTime' key read from event dicts in this file.
    bufferTime: int = Field(description="Buffer duration in minutes")
    category: Category = Field(description="Use the given category")
    # snake_case here, whereas input event dicts use 'subCategory';
    # process_llm_event accepts both spellings.
    sub_category: SubCategory = Field(description="Use the given sub-category")
    priority: float = Field(description="Denotes the priority level of the event. Use the given priority.")

    class Config:
        # Store the enum's string value rather than the Enum member on the model.
        use_enum_values = True


# Wrapper model holding a list of ConflictManagedevent items.
# The docstring below is left as-is: pydantic can expose a model's docstring as
# its schema description, so it may be visible to whatever consumes the schema.
class ListofConflictmanagedEvent(BaseModel):
    """Model to hold a list of all events."""
    events: List[ConflictManagedevent] = Field(
        default_factory=list, 
        description="List of all events"
    )


@dataclass
class DistributionTracker:
    """Running per-month and per-ISO-week event counts for each distribution key.

    ``data`` maps a distribution key to a dict with two sections,
    ``monthly_count`` and ``weekly_count``. Each section holds an
    ``original_count`` list (the reference counts) and a ``current_count`` list
    (the running tally). Lists are indexed by ``month - 1`` and
    ``ISO week - 1`` respectively.

    The dict passed in is mutated in place (missing sections and tallies are
    created on it), so callers keep seeing updates through their own reference.
    """
    data: Dict[str, Dict] = None

    def __post_init__(self):
        if self.data is None:
            self.data = {}
        # Ensure every key has both sections and a tally at least as long as
        # its reference list; a missing or too-short tally is reset to zeros.
        for entry in self.data.values():
            for section_name in ('monthly_count', 'weekly_count'):
                # setdefault (not get) so a missing section is actually stored
                # on the entry instead of raising KeyError on assignment.
                section = entry.setdefault(section_name, {})
                reference = section.get('original_count', [])
                tally_missing = 'current_count' not in section
                if tally_missing or len(section.get('current_count', [])) < len(reference):
                    section['current_count'] = [0] * len(reference)

    @staticmethod
    def _indices(date: str):
        """Return ``(month_idx, week_idx)``, both zero-based, for a ``YYYY-MM-DD`` date."""
        dt = datetime.strptime(date, "%Y-%m-%d")
        return dt.month - 1, dt.isocalendar()[1] - 1

    @staticmethod
    def _tally(entry: Dict, section_name: str, list_name: str) -> list:
        """Fetch one count list from an entry, tolerating missing sections."""
        return entry.get(section_name, {}).get(list_name, [])

    @staticmethod
    def _ratio(current: list, reference: list, idx: int) -> float:
        """current/reference at ``idx``; 1.0 when there is no positive reference there."""
        if 0 <= idx < len(reference) and idx < len(current) and reference[idx] > 0:
            return current[idx] / reference[idx]
        return 1.0

    def update_counts(self, date: str, subcategory: str, is_removal: bool = False):
        """Add (or, with ``is_removal``, subtract) one event for ``subcategory`` on ``date``.

        Unknown subcategories are ignored. An index outside a tally list (for
        example ISO week 53 against a 52-entry list) is skipped for that
        granularity instead of raising IndexError.

        Raises:
            ValueError: if ``date`` is not in ``YYYY-MM-DD`` format.
        """
        month_idx, week_idx = self._indices(date)
        if subcategory not in self.data:
            return
        change = -1 if is_removal else 1
        entry = self.data[subcategory]
        for section_name, idx in (('monthly_count', month_idx), ('weekly_count', week_idx)):
            tally = self._tally(entry, section_name, 'current_count')
            if 0 <= idx < len(tally):
                tally[idx] += change

    def get_distribution_score(self, date: str, subcategory: str) -> float:
        """Return the mean of the monthly and weekly fill ratios for ``subcategory``.

        Each ratio is ``current / original`` at the date's month/week index, or
        1.0 when the reference is zero, missing, or the index is out of range.
        Unknown subcategories score 0.0.

        Raises:
            ValueError: if ``date`` is not in ``YYYY-MM-DD`` format (only
                checked for known subcategories).
        """
        if subcategory not in self.data:
            return 0.0
        month_idx, week_idx = self._indices(date)
        entry = self.data[subcategory]

        monthly_ratio = self._ratio(
            self._tally(entry, 'monthly_count', 'current_count'),
            self._tally(entry, 'monthly_count', 'original_count'),
            month_idx,
        )
        weekly_ratio = self._ratio(
            self._tally(entry, 'weekly_count', 'current_count'),
            self._tally(entry, 'weekly_count', 'original_count'),
            week_idx,
        )
        return (monthly_ratio + weekly_ratio) / 2


class ScheduleOptimizer:
    """Greedy fallback placer for events that still need a time slot.

    Events are ranked by a weighted blend of keyword priority (40%),
    distribution deficit (40%) and fit against the largest free slot (20%),
    then placed one by one into the first compatible slot of a shared, mutable
    slot list.
    """

    def __init__(self, distribution_tracker: "DistributionTracker", priority_levels: Dict[str, int]):
        self.distribution_tracker = distribution_tracker
        # Maps a keyword to its priority; keywords are matched as
        # case-insensitive substrings of the event's string form.
        self.priority_levels = priority_levels
        self.min_slot_duration = timedelta(minutes=20)

    def _to_datetime(self, time_str: str) -> datetime:
        """Parse an ``HH:MM`` string (the date part defaults to 1900-01-01)."""
        return datetime.strptime(time_str, "%H:%M")

    def _get_duration(self, start_time: str, end_time: str) -> timedelta:
        """Return ``end - start``; negative when the end is earlier than the start."""
        return self._to_datetime(end_time) - self._to_datetime(start_time)

    def _get_priority_score(self, event: Dict) -> float:
        """Return the first matching keyword's priority, normalised by the maximum priority.

        Returns 0.0 when nothing matches, when there are no priority levels, or
        when the maximum priority is 0 (which would otherwise divide by zero).
        """
        if not self.priority_levels:
            return 0.0
        top_priority = max(self.priority_levels.values())
        if not top_priority:
            return 0.0
        # The whole dict repr is searched, so a keyword can match any field.
        haystack = str(event).lower()
        for key, priority in self.priority_levels.items():
            if key.lower() in haystack:
                return priority / top_priority
        return 0.0

    def _get_event_score(self, event: Dict, date: str, available_duration: timedelta) -> float:
        """Score an event for ordering: 0.4*priority + 0.4*(1 - distribution) + 0.2*fit.

        ``fit`` peaks when the event takes 80% of ``available_duration``. A
        zero ``available_duration`` contributes a fit of 0.0 instead of raising
        ZeroDivisionError.
        """
        score = self._get_priority_score(event) * 0.4

        distribution_key = get_distribution_key_from_event(event)
        distribution_score = 1.0 - self.distribution_tracker.get_distribution_score(date, distribution_key)
        score += distribution_score * 0.4

        event_duration = self._get_duration(event['start_time'], event['end_time'])
        if available_duration != timedelta(0):
            duration_ratio = event_duration / available_duration
            fit_score = 1.0 - abs(0.8 - duration_ratio)
        else:
            fit_score = 0.0
        score += fit_score * 0.2
        return score

    def optimize_failed_events(self, date: str, failed_events: List[Dict], available_slots: List[Tuple[str, str]]) -> List[Dict]:
        """Place as many ``failed_events`` as possible into ``available_slots``.

        Events are tried in descending score order, scored against the longest
        current slot. ``available_slots`` is mutated: used slots are shrunk or
        removed.

        Returns:
            Copies of the events that were placed, with updated
            ``start_time`` / ``end_time``. Events that fit nowhere are dropped.
            With no slots at all the result is an empty list.
        """
        # Nothing can be placed without slots; returning here also avoids
        # scoring against a zero reference duration.
        if not available_slots or not failed_events:
            return []

        scheduled_events = []
        max_duration = max(self._get_duration(s, e) for s, e in available_slots)

        remaining_events = sorted(
            failed_events,
            key=lambda e: self._get_event_score(e, date, max_duration),
            reverse=True
        )

        for event in remaining_events:
            self._schedule_event(event, date, available_slots, set(), scheduled_events)

        return scheduled_events

    def _schedule_event(self, event: Dict, date: str, available_slots: List[Tuple[str, str]], 
                        used_slots: set, scheduled_events: List[Dict]) -> bool:
        """Place ``event`` at the start of the first compatible slot that is long enough.

        A slot is compatible when its time-of-day categories intersect the
        categories allowed for the event, and long enough when it holds the
        event duration plus buffer. On success a copy of the event with new
        times is appended to ``scheduled_events`` and the slot is consumed.
        ``date`` and ``used_slots`` are accepted for interface compatibility
        but not read.

        Returns:
            True if the event was placed, otherwise False.
        """
        event_duration = self._get_duration(event['start_time'], event['end_time'])
        total_needed = event_duration + timedelta(minutes=event['bufferTime'])
        allowed = allowed_categories_for_event(event)

        for idx, (slot_start, slot_end) in enumerate(available_slots):
            if allowed.isdisjoint(get_slot_categories(slot_start, slot_end)):
                continue

            slot_duration = self._get_duration(slot_start, slot_end)
            # Empty or inverted slots cannot host anything and would make the
            # usage fraction below divide by zero.
            if slot_duration <= timedelta(0) or slot_duration < total_needed:
                continue

            new_event = event.copy()
            new_event['start_time'] = slot_start
            new_event['end_time'] = (self._to_datetime(slot_start) + event_duration).strftime("%H:%M")
            scheduled_events.append(new_event)

            # Mutating the list mid-iteration is safe because we return right after.
            if total_needed / slot_duration >= 0.75:
                # At least 75% consumed: drop the slot, leftover included.
                available_slots.pop(idx)
            else:
                # Otherwise keep the tail of the slot for later events.
                new_slot_start = (self._to_datetime(slot_start) + total_needed).strftime("%H:%M")
                available_slots[idx] = (new_slot_start, slot_end)
            return True
        return False


def are_titles_similar(title1, title2, threshold=0.9):
    """
    Decide whether two titles are near-duplicates.

    Both titles are lowercased, stripped of everything except ASCII letters,
    digits and whitespace, and have whitespace runs collapsed before comparison.

    Args:
        title1 (str): First title.
        title2 (str): Second title.
        threshold (float): Minimum SequenceMatcher ratio (0 to 1) that counts
            as similar. Default is 0.9.

    Returns:
        bool: True when the similarity ratio is at least ``threshold``.
    """
    cleaned = []
    for raw_title in (title1, title2):
        stripped = re.sub(r'[^a-zA-Z0-9\s]', '', raw_title.lower())
        cleaned.append(re.sub(r'\s+', ' ', stripped).strip())

    first, second = cleaned
    return SequenceMatcher(None, first, second).ratio() >= threshold

    
def remove_duplicate_dicts(list_of_dicts):
    """Return the dicts from ``list_of_dicts`` with later duplicates dropped.

    Two dicts are duplicates when they hold the same key/value pairs,
    regardless of insertion order. The first occurrence is kept and the
    original relative order is preserved. Values must be hashable and keys
    mutually sortable.
    """
    fingerprints = set()
    kept = []
    for entry in list_of_dicts:
        # Sorted item tuple is an order-independent, hashable fingerprint.
        fingerprint = tuple(sorted(entry.items()))
        if fingerprint in fingerprints:
            continue
        fingerprints.add(fingerprint)
        kept.append(entry)
    return kept

    
def get_time_range_category(time_str: str) -> str:
    """
    Map an ``HH:MM`` time to its named part of the day.

    Boundaries fall on whole hours: 05-08 early_morning, 09-11 late_morning,
    12-14 early_afternoon, 15-17 late_afternoon, 18-20 evening, 21-23 night,
    and anything before 05:00 is late_night.
    """
    hour = datetime.strptime(time_str, "%H:%M").hour
    if hour < 5:
        return 'late_night'
    if hour < 9:
        return "early_morning"
    if hour < 12:
        return "late_morning"
    if hour < 15:
        return "early_afternoon"
    if hour < 18:
        return "late_afternoon"
    if hour < 21:
        return "evening"
    return "night"
        
        
def get_slot_category(slot_start: str, slot_end: str) -> str:
    """
    Classify a slot by the part of the day its midpoint falls in.

    The midpoint is halfway between ``slot_start`` and ``slot_end`` (both
    ``HH:MM``), so a slot straddling two categories takes the one containing
    its middle.
    """
    opens = datetime.strptime(slot_start, "%H:%M")
    closes = datetime.strptime(slot_end, "%H:%M")
    half_span = (closes - opens) / 2
    return get_time_range_category((opens + half_span).strftime("%H:%M"))


def allowed_categories_for_event(event: Dict) -> set:
    """
    Work out which slot categories an event may be moved into.

    The event's start and end times are each mapped to a part of the day.

    - If either falls outside the ordered daytime list (i.e. "late_night"), or
      the two are identical, adjacent, or the end comes before the start in
      the ordering, the result is just those two categories.
    - If the end lies more than one step after the start, the result is the
      start category plus the one immediately after it. For example an event
      from "early_afternoon" to "evening" yields
      {"early_afternoon", "late_afternoon"}.
    """
    day_parts = ["early_morning", "late_morning", "early_afternoon", "late_afternoon", "evening", "night"]

    first_cat = get_time_range_category(event['start_time'])
    last_cat = get_time_range_category(event['end_time'])

    # Categories missing from the ordering cannot be ranked; fall back to both.
    if first_cat not in day_parts or last_cat not in day_parts:
        return {first_cat, last_cat}

    first_pos = day_parts.index(first_cat)
    steps = day_parts.index(last_cat) - first_pos

    if steps > 1:
        # Wide span: restrict to the start category and its successor.
        return {day_parts[first_pos], day_parts[first_pos + 1]}
    return {first_cat, last_cat}


def get_slot_categories(slot_start: str, slot_end: str) -> set:
    """
    Collect every part of the day that a slot (``HH:MM`` to ``HH:MM``) touches.

    When both ends map to a known daytime category, all categories from the
    start through the end are returned; if the end ranks before the start the
    range wraps around the end of the ordering. If either end is outside the
    ordering (i.e. "late_night"), only the two end categories are returned.
    """
    day_parts = ["early_morning", "late_morning", "early_afternoon", "late_afternoon", "evening", "night"]
    first_cat = get_time_range_category(slot_start)
    last_cat = get_time_range_category(slot_end)

    if first_cat not in day_parts or last_cat not in day_parts:
        return {first_cat, last_cat}

    lo = day_parts.index(first_cat)
    hi = day_parts.index(last_cat)

    if hi < lo:
        # Wrap-around: tail of the ordering plus its head up to the end category.
        return set(day_parts[lo:] + day_parts[:hi + 1])
    return set(day_parts[lo:hi + 1])


def get_distribution_key_from_event(event: Dict[str, Any]) -> str:
    """
    Determine the distribution key for an event.

    - For events in the "Mental Wellness" subcategory:
         * If the event title contains "meditation" (case-insensitive), return "Meditation".
         * If the event title contains "journaling" (case-insensitive), return "Journaling".
         * Otherwise, return "Mental Wellness".
    - For other events in the "Personal Development & Learning" category,
         return "Personal Development & Learning".
    - For all other events, return the event's subcategory (whitespace-stripped;
         an empty string when the event has none).

    This function works with events in the original format (with a nested 'event_details')
    or already-unwrapped event dictionaries. Fields that are missing or explicitly
    ``None`` are treated as empty strings.
    """
    details = event.get("event_details", event)
    # ``or ""`` covers keys that exist but hold None (e.g. JSON null), which a
    # plain ``.get(key, "")`` default would not.
    category = (details.get("category") or "").strip()
    subcategory = (details.get("subCategory") or "").strip()
    title_lower = (details.get("event_title") or "").lower()

    if subcategory == "Mental Wellness":
        if "meditation" in title_lower:
            return "Meditation"
        if "journaling" in title_lower:
            return "Journaling"
        return "Mental Wellness"
    if category == "Personal Development & Learning":
        # Tracked at category level rather than per subcategory.
        return "Personal Development & Learning"
    return subcategory


def to_datetime(time_str):
    """Parse an ``HH:MM`` clock string into a datetime (date part defaults to 1900-01-01)."""
    clock_format = "%H:%M"
    return datetime.strptime(time_str, clock_format)
    
    
def get_available_time_slots(fixed_events: List[dict], all_events: List[dict]) -> List[Tuple[str, str]]:
    """
    List the free gaps of a day that are not taken by fixed events.

    - The day runs from the earliest start among same-day events in
      ``all_events`` (start <= end) to 23:59.
    - Each fixed event blocks its own span plus its ``bufferTime`` minutes,
      clipped at 23:59. An overnight fixed event (end < start) blocks from its
      start to 23:59, without buffer.
    - Overlapping or touching blocked spans are merged.
    - Only gaps of at least 20 minutes are returned, as ("HH:MM", "HH:MM") pairs
      in chronological order.

    Returns an empty list when there are no fixed events or no same-day events.
    """
    if not fixed_events:
        return []

    closing_time = datetime.strptime("23:59", "%H:%M")
    shortest_gap = timedelta(minutes=20)

    # The day opens at the earliest start among events that do not cross midnight.
    same_day_starts = [
        to_datetime(item['start_time'])
        for item in all_events
        if to_datetime(item['start_time']) <= to_datetime(item['end_time'])
    ]
    if not same_day_starts:
        return []
    opening_time = min(same_day_starts)

    # Collect the blocked spans.
    busy = []
    for item in fixed_events:
        begins = to_datetime(item['start_time'])
        finishes = to_datetime(item['end_time'])
        padding = timedelta(minutes=item['bufferTime'])
        if finishes < begins:
            # Crosses midnight: only the part up to the end of this day counts.
            busy.append((begins, closing_time))
        else:
            busy.append((begins, min(finishes + padding, closing_time)))
    busy.sort(key=lambda span: span[0])

    # Fold overlapping/touching spans together.
    merged = []
    for begins, finishes in busy:
        if merged and begins <= merged[-1][1]:
            prev_begins, prev_finishes = merged[-1]
            merged[-1] = (prev_begins, max(prev_finishes, finishes))
        else:
            merged.append((begins, finishes))

    # Walk the day and emit each sufficiently long gap before a blocked span.
    free = []
    cursor = opening_time
    for blocked_from, blocked_to in merged:
        if blocked_from - cursor >= shortest_gap:
            free.append((cursor.strftime("%H:%M"), blocked_from.strftime("%H:%M")))
        cursor = max(cursor, blocked_to)

    # Trailing gap between the last blocked span and the end of the day.
    if closing_time - cursor >= shortest_gap:
        free.append((cursor.strftime("%H:%M"), closing_time.strftime("%H:%M")))

    return free


def is_fixed_activity(event: Dict[str, Any]) -> bool:
    """Tell whether an event must stay where it is.

    An event is fixed when its ``ismanual`` flag is truthy, or when the string
    form of the whole dict contains one of the fixed keywords
    (case-insensitive), whichever field it appears in.
    """
    if event.get("ismanual", False):
        return True

    flattened = str(event).lower()
    for marker in ("Sleep Patterns", "Work Hour", "Meal Schedule", "Focus Work"):
        if marker.lower() in flattened:
            return True
    return False


def get_llm_priority_score(event: Dict, date: str, distribution_tracker: DistributionTracker, priority_levels: Dict[str, int]) -> float:
    """
    Blend keyword priority and distribution deficit into one score for an event.

    The priority part is the priority of the first ``priority_levels`` keyword
    found (case-insensitively) in the event's string form, divided by the
    largest priority; 0.0 if none matches. The distribution part is
    ``1 - distribution_score`` for the event's distribution key on ``date``, so
    less-filled keys score higher. Both parts weigh 50%.
    """
    haystack = str(event).lower()
    priority_part = 0.0
    for keyword, level in priority_levels.items():
        if keyword.lower() in haystack:
            priority_part = level / max(priority_levels.values())
            break

    # The key lookup handles the Mental Wellness / Personal Development special cases.
    key = get_distribution_key_from_event(event)
    distribution_part = 1.0 - distribution_tracker.get_distribution_score(date, key)

    return (priority_part * 0.5) + (distribution_part * 0.5)


def _normalize_time_string(value: str) -> str:
    """Coerce a loosely formatted clock string to zero-padded ``HH:MM``.

    Handled shapes: ``H:MM``/``HH:M`` (each side padded to two digits),
    ``HHMM``, ``HMM``, ``HM`` (read as hour digit + tens-of-minutes digit) and
    ``H``. Anything else is returned unchanged.
    """
    if ":" in value:
        hours, _, minutes = value.partition(":")
        if hours.isdigit() and minutes.isdigit() and len(hours) <= 2 and len(minutes) <= 2:
            return f"{hours.zfill(2)}:{minutes.zfill(2)}"
        return value
    if len(value) == 4:
        # "0930" -> "09:30"
        return f"{value[:2]}:{value[2:]}"
    if len(value) == 3:
        # "930" -> "09:30"
        return f"0{value[:1]}:{value[1:]}"
    if len(value) == 2:
        # NOTE(review): keeps the existing reading "93" -> 9:30; a two-digit
        # value could also mean a whole hour ("14") — confirm upstream format.
        return f"0{value[:1]}:{value[1:]}0"
    if len(value) == 1:
        # "9" -> "09:00"
        return f"0{value}:00"
    return value


def format_schedule_for_prompt(events_list, date:str, distribution_tracker, priority_levels):
    """Split a day's events into fixed / movable groups and render prompt text.

    Each item of ``events_list`` is a dict with an ``event_details`` dict and,
    for non-fixed events, a ``conflict`` flag. The ``event_details`` dicts are
    modified in place: times are normalised to ``HH:MM`` and non-fixed events
    gain a ``priority`` score.

    Returns:
        dict with keys
        ``all_events`` (every event_details, chronological),
        ``conflicting_events_list`` (text, conflicting events by priority),
        ``formatted_schedule`` (text, full chronological schedule),
        ``fixed_events_raw`` (event_details of fixed events),
        ``all_non_fixed_events`` (event_details by descending priority),
        ``all_non_fixed_events_raw`` (text rendering of the previous list).
    """
    # Normalise before sorting so the string sort below is chronological.
    for event in events_list:
        details = event['event_details']
        details['start_time'] = _normalize_time_string(details['start_time'])
        details['end_time'] = _normalize_time_string(details['end_time'])

    sorted_events = sorted(events_list, key=lambda x: x['event_details']['start_time'])

    fixed_events = []
    conflicting_events = []
    all_non_fixed_events = []
    all_events = []

    for event in sorted_events:
        event_details = event['event_details']
        all_events.append(event_details)

        if is_fixed_activity(event_details):
            fixed_events.append(event_details)
            continue

        has_conflict = event['conflict']
        priority = get_llm_priority_score(event_details, date, distribution_tracker, priority_levels)
        if has_conflict:
            conflicting_events.append((priority, event_details))
        event_details['priority'] = priority
        all_non_fixed_events.append(event_details)

    conflicting_events.sort(key=lambda x: x[0], reverse=True)
    all_non_fixed_events.sort(key=lambda x: x['priority'], reverse=True)

    conflicting_events_str = "\n".join(
        f"- Priority {p}: {e['event_title']} ({e['start_time']}-{e['end_time']}, Buffer: {e['bufferTime']} mins)"
        for p, e in conflicting_events
    )

    current_schedule = "\n".join(
        f"[{e['event_details']['start_time']}-{e['event_details']['end_time']}] {e['event_details']['event_title']} (Buffer: {e['event_details']['bufferTime']} mins)"
        for e in sorted_events
    )

    all_non_fixed_events_str = "\n".join(
        f"Title: {e['event_title']} - Description: {e['event_description']} - Category: {e['category']} - Subcategory: {e['subCategory']} - Priority: {e['priority']} - Current start: {e['start_time']} - end: {e['end_time']} - Buffer: {e['bufferTime']} mins"
        for e in all_non_fixed_events
    )

    return {
        "all_events": all_events,
        "conflicting_events_list": conflicting_events_str,
        "formatted_schedule": current_schedule,
        "fixed_events_raw": fixed_events,
        "all_non_fixed_events": all_non_fixed_events,
        "all_non_fixed_events_raw": all_non_fixed_events_str
    }


def process_llm_event(event: Dict[str, Any]) -> Dict[str, Any]:
    """
    Return a shallow copy of an LLM event with its enum fields flattened.

    Values under 'category', 'sub_category' or 'subCategory' that are Enum
    members are replaced by their ``.value``; every other entry is copied
    through untouched.
    """
    enum_fields = ("category", "sub_category", "subCategory")
    return {
        field: (payload.value if field in enum_fields and isinstance(payload, Enum) else payload)
        for field, payload in event.items()
    }



# Instruction text for the rescheduling request. It carries two placeholders,
# {all_non_fixed_events_raw} and {available_time_slots_str}, which must be
# filled before use (presumably via str.format or a prompt-template wrapper —
# the call site is not in this part of the file).
# NOTE(review): the "Morning: 06:00-11:59" range below differs from
# get_time_range_category, where "early_morning" starts at 05:00 — confirm
# which boundary is intended.
prompt_template = """You are an expert schedule optimizer. Your task is to reschedule the provided events within the available time slots by following the RULES, CONSIDERATIONS AND EDGE CASE HANDLING GUIDELINES.
ABSOLUTE ADHERENCE TO THE RULES AND EDGE CASE HANDLING GUIDELINES ARE MANDATORY. 

**RULES**
1. The rescheduled event including its buffer time must be STRICTLY WITHIN the available time slot.
2. Reschedule events starting with the HIGHEST PRIORITY.
3. No event (including its buffer) should overlap with another.
4. Each event's original duration MUST BE PRESERVED.
5. Buffer times are CRITICAL and must be applied immediately after each event.
    - Example: If Event A ends at 09:00 with a 15-minute buffer, Event B cannot start before 09:15.

**OPTIONAL CONSIDERATIONS**
1. Try to schedule events in their contextually appropriate slots (e.g., morning routines in the morning).
    - Morning: 06:00-11:59, Afternoon: 12:00-17:59, Evening: 18:00-20:59, Night: 21:00-23:59

**EDGE CASE HANDLING GUIDELINES**
1. Reduce buffer times starting with the lowest priority events and reschedule.
2. If rescheduling is not possible even after reducing buffers, shorten the duration of the lowest priority events and reschedule.

## Events to be rescheduled:
{all_non_fixed_events_raw}

## Available time slots:
{available_time_slots_str}
"""


# --- Validation Functions ---

def validate_schedule(rescheduled_events: List[Dict[str, Any]], fixed_events: List[Dict[str, Any]], original_events: List[Dict[str, Any]]) -> Dict[str, List[Dict[str, Any]]]:
    """Check a rescheduled day against five rules and report every violation.

    Args:
        rescheduled_events: Flat event dicts with ``event_title``,
            ``start_time``, ``end_time`` (``HH:MM``) and ``bufferTime`` (minutes).
        fixed_events: Flat event dicts with ``start_time`` / ``end_time``.
        original_events: Wrapped events (``{'event_details': {...}}``) used as
            the reference for durations; fixed activities are ignored.

    Returns:
        dict of violation lists, each entry tagged with a numeric ``type``:
        ``fixed_event_overlaps`` (1), ``buffer_fixed_overlaps`` (2),
        ``rescheduled_event_overlaps`` (3), ``buffer_violations`` (4),
        ``duration_mismatches`` (5).

    Rules 3 and 4 compare every rescheduled event with all later-starting
    events it can reach, not just its immediate neighbour, so an event nested
    inside a long earlier one is also reported.
    """
    validation_results = {
        "fixed_event_overlaps": [],
        "rescheduled_event_overlaps": [],
        "buffer_violations": [],
        "duration_mismatches": [],
        "buffer_fixed_overlaps": []
    }

    def parse(time_str):
        return datetime.strptime(time_str, "%H:%M")

    # Rules 1 and 2: the event itself, and its trailing buffer, must stay clear
    # of every fixed event.
    for event in rescheduled_events:
        event_start = parse(event['start_time'])
        event_end = parse(event['end_time'])
        buffer_end = event_end + timedelta(minutes=event['bufferTime'])
        for fixed_event in fixed_events:
            fixed_start = parse(fixed_event['start_time'])
            fixed_end = parse(fixed_event['end_time'])

            if not (event_end <= fixed_start or event_start >= fixed_end):
                validation_results["fixed_event_overlaps"].append({
                    "type": 1,
                    "error": "Rescheduled event overlaps with fixed event",
                    "rescheduled_event": event,
                    "fixed_event": fixed_event,
                    "overlap_period": {
                        "start": max(event_start, fixed_start).strftime("%H:%M"),
                        "end": min(event_end, fixed_end).strftime("%H:%M")
                    }
                })

            if not (buffer_end <= fixed_start or event_end >= fixed_end):
                validation_results["buffer_fixed_overlaps"].append({
                    "type": 2,
                    "error": "Event buffer overlaps with fixed event",
                    "rescheduled_event": event,
                    "fixed_event": fixed_event,
                    "buffer_period": {
                        "start": event_end.strftime("%H:%M"),
                        "end": buffer_end.strftime("%H:%M")
                    }
                })

    # Rules 3 and 4: rescheduled events must not overlap each other, and each
    # event's buffer must elapse before any later event starts.
    sorted_events = sorted(rescheduled_events, key=lambda x: parse(x['start_time']))
    for i, event1 in enumerate(sorted_events):
        event1_end = parse(event1['end_time'])
        buffered_end = event1_end + timedelta(minutes=event1['bufferTime'])
        reach = max(event1_end, buffered_end)
        for event2 in sorted_events[i + 1:]:
            event2_start = parse(event2['start_time'])
            # Sorted by start: once one event starts past our reach, all later do.
            if event2_start >= reach:
                break

            if event2_start < event1_end:
                validation_results["rescheduled_event_overlaps"].append({
                    "type": 3,
                    "error": "Two rescheduled events overlap",
                    "first_event": event1,
                    "second_event": event2,
                    "overlap_period": {
                        "start": event2_start.strftime("%H:%M"),
                        # The overlap stops at whichever event ends first.
                        "end": min(event1_end, parse(event2['end_time'])).strftime("%H:%M")
                    }
                })

            if event2_start < buffered_end:
                validation_results["buffer_violations"].append({
                    "type": 4,
                    "error": "Buffer time not respected",
                    "first_event": event1,
                    "second_event": event2,
                    "required_buffer": event1['bufferTime'],
                    # Negative when the events actually overlap.
                    "actual_gap": (event2_start - event1_end).total_seconds() / 60
                })

    # Rule 5: each event keeps its original duration. Events are matched by
    # title, so with duplicate titles the last original wins.
    original_durations = {
        e['event_details']['event_title']:
            parse(e['event_details']['end_time']) - parse(e['event_details']['start_time'])
        for e in original_events if not is_fixed_activity(e['event_details'])
    }
    for event in rescheduled_events:
        title = event['event_title']
        if title not in original_durations:
            continue
        orig_duration = original_durations[title]
        new_duration = parse(event['end_time']) - parse(event['start_time'])
        if new_duration != orig_duration:
            validation_results["duration_mismatches"].append({
                "type": 5,
                "error": "Event duration has changed",
                "event": event,
                "original_duration": orig_duration.total_seconds() / 60,
                "rescheduled_duration": new_duration.total_seconds() / 60
            })

    return validation_results


def process_daily_schedule(date: str, events: List[Dict], priority_levels: Dict[str, int],
                           distribution_data: Dict[str, Dict], llm) -> Tuple[Dict[str, List[Dict]], Dict[str, Dict]]:
    """
    Processes a single day's schedule:
      1. Formats event data and computes available slots.
      2. Uses the LLM for initial rescheduling.
      3. Validates the LLM output.
      4. Collects the events that failed validation or that the LLM left out.
      5. Explicitly schedules long failed events (>= 2.5 hours), then hands every
         failed/missed event to the ScheduleOptimizer.
      6. Returns the combined schedule and the updated distribution data.

    Args:
        date: Identifier of the day being processed; passed through to the
            distribution tracker and the optimizer.
        events: The day's events. Each entry is read through its
            'event_details' dict and its 'conflict' flag.
        priority_levels: Passed to format_schedule_for_prompt and ScheduleOptimizer.
        distribution_data: Seed data for the DistributionTracker.
        llm: Chat model exposing ``with_structured_output``.

    Returns:
        A ``(schedule, distribution_data)`` tuple. ``schedule`` is a dict with
        two keys, "rescheduled_events" and "fixed_events". On the early-return
        path (no conflicts, or a vacation/holiday day) "fixed_events" holds the
        untouched input events and "rescheduled_events" is empty; otherwise
        "fixed_events" is ``schedule_data["fixed_events_raw"]``.
    """
    # A single vacation or holiday event switches rescheduling off for the whole day.
    vacation_or_holiday_flag = any(
        event['event_details'].get('is_vacation_event', False)
        or event['event_details'].get('subCategory') == "Holidays"
        for event in events
    )

    distribution_tracker = DistributionTracker(distribution_data)

    # Format schedule data early so we can check for conflicts.
    schedule_data = format_schedule_for_prompt(events, date, distribution_tracker, priority_levels)

    # Nothing to resolve: record the events in the distribution data and return them unchanged.
    if not schedule_data["conflicting_events_list"] or vacation_or_holiday_flag:
        for event in events:
            distribution_tracker.update_counts(date, get_distribution_key_from_event(event))
        return {"rescheduled_events": [], "fixed_events": events}, distribution_tracker.data

    final_fixed_events = schedule_data["fixed_events_raw"]
    available_slots = get_available_time_slots(final_fixed_events, schedule_data["all_events"])
    available_time_slots_str = "\n".join(f"- {start}-{end}" for start, end in available_slots)

    # --- LLM Initial Scheduling ---
    prompt = prompt_template.format(
        all_non_fixed_events_raw=schedule_data["all_non_fixed_events_raw"],
        available_time_slots_str=available_time_slots_str
    )
    structured_llm = llm.with_structured_output(ListofConflictmanagedEvent)
    try:
        llm_output = structured_llm.invoke(prompt)
        llm_result = json.loads(llm_output.model_dump_json())
        llm_events = llm_result.get("events", [])

        # Process each event so that enum values become plain strings.
        processed_llm_events = [process_llm_event(event) for event in llm_events]
        rescheduled_events = [ConflictManagedevent(**event).model_dump() for event in processed_llm_events]
    except Exception as e:
        # Best-effort: an LLM or parsing failure must not abort the day. With an
        # empty result every non-fixed event is treated as "missed" below and goes
        # to the optimizer. Report on stderr so stdout stays reserved for results.
        print(f"LLM scheduling failed for {date}; falling back to the optimizer: {e}", file=sys.stderr)
        rescheduled_events = []

    # --- Validation ---
    original_non_fixed = [e for e in events if e['conflict']]
    validation_results = validate_schedule(rescheduled_events, final_fixed_events, original_non_fixed)
    failed_events = []
    for error_list in validation_results.values():
        for error in error_list:
            if error['type'] in (1, 2):
                # Single-event errors carry the offending event under 'rescheduled_event'.
                failed_events.append(error['rescheduled_event'])
            elif error['type'] in (3, 4):
                # Pairwise errors (e.g. buffer violations) fail both events involved.
                failed_events.extend([error['first_event'], error['second_event']])
            # Type 5 (duration mismatch) errors are not treated as failures here.

    # Remove duplicates based on event_title (the last occurrence wins).
    failed_events = list({e['event_title']: e for e in failed_events}.values())
    failed_titles = {e['event_title'] for e in failed_events}

    # --- Map failed events back to their original definitions ---
    all_non_fixed_events = schedule_data["all_non_fixed_events"]
    manual_optimize_events = []
    for fe in failed_events:
        found_original = next(
            (orig for orig in all_non_fixed_events
             if are_titles_similar(fe["event_title"], orig["event_title"])),
            None,
        )
        # Fall back to the LLM's version when no similarly titled original exists.
        manual_optimize_events.append(found_original if found_original else fe)

    # --- Identify events missed by the LLM ---
    llm_event_titles = {e["event_title"] for e in rescheduled_events}
    manual_optimize_events.extend(
        event for event in all_non_fixed_events
        if event["event_title"] not in llm_event_titles
        and event["event_title"] not in failed_titles
    )

    # --- Combine events ---
    successful_llm_events = [e for e in rescheduled_events if e['event_title'] not in failed_titles]

    # Slots already taken by fixed events and accepted LLM placements are unavailable to the optimizer.
    combined_fixed_events_for_optimizer = final_fixed_events + successful_llm_events
    available_slots_for_optimizer = get_available_time_slots(
        combined_fixed_events_for_optimizer, schedule_data["all_events"]
    )

    optimizer = ScheduleOptimizer(distribution_tracker, priority_levels)
    to_dt = optimizer._to_datetime  # for convenience

    # --- Explicitly schedule long events that failed (>=2.5 hours) ---
    used_slots = set()
    scheduled_long_events = []
    long_event_threshold = timedelta(hours=2.5)
    long_failed_events = [
        e for e in manual_optimize_events
        if to_dt(e['end_time']) - to_dt(e['start_time']) >= long_event_threshold
    ]
    for event in long_failed_events:
        # Only try if the LLM did not return this event at all.
        if event['event_title'] not in llm_event_titles:
            # Presumably appends to scheduled_long_events and records used_slots
            # in place -- confirm against ScheduleOptimizer._schedule_event.
            optimizer._schedule_event(event, date, available_slots_for_optimizer, used_slots, scheduled_long_events)

    # --- Optimize remaining failed events ---
    optimized_events = optimizer.optimize_failed_events(date, manual_optimize_events, available_slots_for_optimizer)

    final_schedule = successful_llm_events + scheduled_long_events + optimized_events

    # --- Update distribution counts ---
    # NOTE(review): counts are updated before de-duplication, so an event present
    # twice in final_schedule is counted twice -- confirm whether that is intended.
    for event in final_schedule:
        distribution_tracker.update_counts(date, get_distribution_key_from_event(event))

    # --- Remove any duplicates ---
    unique_final_schedule = remove_duplicate_dicts(final_schedule)

    return (
        {"rescheduled_events": unique_final_schedule, "fixed_events": final_fixed_events},
        distribution_tracker.data,
    )


def check_conflicts_(events):
    """Flag every event whose time range overlaps another event in ``events``.

    Each event dict gets two keys set in place:
      * ``conflict``: True when the event overlaps at least one other event.
      * ``conflict_reason``: "time overlap" for conflicting events, "" otherwise.

    If any event has subCategory "Holidays" or ``is_vacation_event`` set to
    True, every event gets ``conflict = False`` and ``conflict_reason = ""``
    and no overlap check is performed.

    Only time overlap is checked; buffer times are not taken into account.
    An event whose start is not strictly before its end is never flagged itself.

    Args:
        events: List of event dicts, each with an 'event_details' dict holding
            'start_time' and 'end_time' values accepted by ``to_datetime``.

    Returns:
        The same ``events`` list, with the flags updated in place.
    """
    # A holiday or vacation event anywhere clears conflicts for the whole list.
    for event in events:
        details = event.get('event_details', {})
        if details.get('subCategory') == "Holidays" or details.get('is_vacation_event') is True:
            for ev in events:
                ev['conflict'] = False
                ev['conflict_reason'] = ""
            return events  # Exit early

    # Parse each interval once instead of re-parsing inside the pairwise loop.
    intervals = [
        (to_datetime(event['event_details']['start_time']),
         to_datetime(event['event_details']['end_time']))
        for event in events
    ]

    for i, (event, (current_start, current_end)) in enumerate(zip(events, intervals)):
        has_overlap = bool(
            current_start < current_end
            and any(
                current_start < other_end and current_end > other_start
                for j, (other_start, other_end) in enumerate(intervals)
                if j != i
            )
        )
        event['conflict'] = has_overlap
        event['conflict_reason'] = "time overlap" if has_overlap else ""

    return events

def filter_events(input_list):
    """Collapse events that share the same (event_title, category) pair.

    The first event seen for a pair is kept, unless it is flagged as a
    conflict and a later event for the same pair is not; in that case the
    later, conflict-free event takes its place (and its position in the
    result). Events are expected to carry a 'conflict' key.

    Returns:
        A list with one event per (event_title, category) pair, in order of
        first appearance.
    """
    chosen = {}

    for candidate in input_list:
        info = candidate.get('event_details', {})
        identity = (info.get('event_title', ''), info.get('category', ''))

        # First sighting of this title/category pair: keep it.
        if identity not in chosen:
            chosen[identity] = candidate
            continue

        # Prefer a conflict-free duplicate over a conflicting incumbent;
        # when both carry the same flag the incumbent stays.
        incumbent = chosen[identity]
        if incumbent['conflict'] and not candidate['conflict']:
            chosen[identity] = candidate

    return list(chosen.values())


def filter_dinner_events(calendar_data):
    """Drop redundant dinner events from each day of a calendar.

    An event counts as a dinner event when its title contains "dinner"
    (case-insensitive). For a day with more than one dinner event:
      * if any dinner has subCategory "Partner Time", the "Meal Schedule" and
        "Family Time" dinners are removed;
      * otherwise, if both "Meal Schedule" and "Family Time" dinners exist,
        the "Meal Schedule" dinners are removed.
    All other days and events are left as they are.

    Args:
        calendar_data: Mapping of date -> day dict; each day dict may hold an
            'events' list of event dicts with an 'event_details' dict.

    Returns:
        A new mapping with a fresh dict per day. The input mapping, its day
        dicts and its event lists are not modified; the event dicts themselves
        are shared with the input.
    """
    def _is_dinner(event):
        return 'dinner' in event.get('event_details', {}).get('event_title', '').lower()

    def _sub_category(event):
        return event.get('event_details', {}).get('subCategory')

    filtered_calendar = {}

    for date, day_data in calendar_data.items():
        # Copy the day dict as well: copying only the outer mapping would make
        # the 'events' assignment below write through to the caller's data.
        day_copy = dict(day_data)
        filtered_calendar[date] = day_copy

        events = day_data.get('events', [])
        dinner_sub_categories = [_sub_category(event) for event in events if _is_dinner(event)]

        # A single dinner (or none) never needs de-duplication.
        if len(dinner_sub_categories) <= 1:
            continue

        if 'Partner Time' in dinner_sub_categories:
            # A partner dinner replaces both the regular meal and the family dinner.
            dropped = ('Meal Schedule', 'Family Time')
        elif 'Meal Schedule' in dinner_sub_categories and 'Family Time' in dinner_sub_categories:
            # A family dinner replaces the regular meal.
            dropped = ('Meal Schedule',)
        else:
            continue

        day_copy['events'] = [
            event for event in events
            if not (_is_dinner(event) and _sub_category(event) in dropped)
        ]

    return filtered_calendar


def format_calendar_data(Calendar_data:dict, Conflict_Resolved_Data:dict):
    """Merge the resolved schedule back into the day's calendar entry.

    Only the first date key of ``Calendar_data`` is processed. For that day:
      * stale 'conflict' / 'conflict_reason' keys are dropped from each event;
      * holiday and vacation events are stretched to 00:00-23:59;
      * events matching a rescheduled entry (same title and category) take
        over its start time, end time and bufferTime;
      * only events matching a rescheduled or fixed entry are kept;
      * a weekday name in the title that differs from the day's own name is
        removed for non-manual events;
      * conflicts are recomputed, duplicates and redundant dinner events are
        filtered out, and the events are sorted by start time.

    The event dicts inside ``Calendar_data`` are modified in place.

    Args:
        Calendar_data: ``{date: {"day": <weekday name>, "events": [...]}}``.
        Conflict_Resolved_Data: Dict with "rescheduled_events" (flat event
            dicts) and "fixed_events". Fixed events may be flat dicts or dicts
            wrapping the fields in 'event_details'.

    Returns:
        ``{"Calendar": {date: {"day": ..., "events": [...]}}}``.
    """
    Date = list(Calendar_data.keys())[0]
    Day = Calendar_data[Date]["day"]
    events = Calendar_data[Date]["events"]

    Conflict_Resolved_Events = Conflict_Resolved_Data["rescheduled_events"]
    Fixed_Events = Conflict_Resolved_Data["fixed_events"]

    Updated_Data = []
    for event in events:
        event.pop("conflict", None)
        event.pop("conflict_reason", None)
        details = event["event_details"]
        # `== True` is kept on purpose so a truthy 1 is still accepted.
        if details.get("subCategory") == "Holidays" or details.get("is_vacation_event") == True:  # noqa: E712
            details["start_time"] = "00:00"
            details["end_time"] = "23:59"
        event_title = details["event_title"]
        category = details["category"]

        if Conflict_Resolved_Events:
            for resolved in Conflict_Resolved_Events:
                if resolved["event_title"] == event_title and resolved["category"] == category:
                    details["start_time"] = resolved["start_time"]
                    details["end_time"] = resolved["end_time"]
                    details["bufferTime"] = resolved["bufferTime"]
                    Updated_Data.append(event)
            for fixed in Fixed_Events:
                if fixed["event_title"] == event_title and fixed["category"] == category:
                    Updated_Data.append(event)
        else:
            for fixed in Fixed_Events:
                # Fixed events normally arrive wrapped in 'event_details' when
                # nothing was rescheduled, but they are flat when the full
                # pipeline ran and produced no placements; accept both shapes.
                fixed_details = fixed.get("event_details", fixed)
                if fixed_details["event_title"] == event_title and fixed_details["category"] == category:
                    Updated_Data.append(event)

    # Strip a weekday name from the title when it does not match the actual day
    # (skipped for manually created events).
    weekday_names = ["Sunday","Monday", "Tuesday", "Wednesday", "Thursday", "Friday", "Saturday"]
    for data in Updated_Data:
        event_title = data["event_details"]["event_title"]
        title_words = event_title.split()
        weekday_names_found = [name for name in weekday_names if name in title_words]
        if weekday_names_found and weekday_names_found[0] != Day and not data["event_details"]["ismanual"]:
            data["event_details"]["event_title"] = event_title.replace(weekday_names_found[0], "").strip()

    # filter_events needs the 'conflict' flags, so conflicts are computed once
    # before de-duplication and once more on the de-duplicated list.
    conflict_checked_events = check_conflicts_(Updated_Data)
    filtered_list = check_conflicts_(filter_events(conflict_checked_events))

    Calendar = {
        Date: {
            "day": Day,
            "events": filtered_list
        }
    }

    Updated_Calendar = filter_dinner_events(Calendar)
    # Sort events by start_time for each day (plain string comparison).
    for day_details in Updated_Calendar.values():
        if "events" in day_details:
            day_details["events"].sort(key=lambda event: event["event_details"]["start_time"])

    return {"Calendar": Updated_Calendar}
    

def schedule_optimization_task(calendar_data: Dict,
                               priority_levels: Dict[str, int],
                               distribution_data: Dict[str, Dict]) -> Dict[str, Any]:
    """
    Task entry point (meant to be run as a cron job or Celery task).

    Takes one day's calendar entry plus the distribution data, runs the daily
    scheduling pipeline with a freshly created chat model, and formats the
    result.

    Only the first date key of ``calendar_data`` is processed.

    Returns:
        A dict with "Calendar" (the output of format_calendar_data) and
        "distribution_data" (the updated distribution data).
    """
    day_key = list(calendar_data)[0]
    day_events = calendar_data[day_key].get('events', [])

    chat_model = ChatOpenAI(model="gpt-4o-mini", max_tokens=6400, temperature=0.25, top_p=0.01)

    resolved_schedule, distribution = process_daily_schedule(
        day_key, day_events, priority_levels, distribution_data, chat_model
    )

    return {
        "Calendar": format_calendar_data(Calendar_data=calendar_data,
                                         Conflict_Resolved_Data=resolved_schedule),
        "distribution_data": distribution,
    }

if __name__ == "__main__":
    # Command-line entry point. The three positional arguments are JSON
    # documents: the calendar data, the priority levels and the distribution data.
    # Argument parsing happens outside the try block, so malformed or missing
    # arguments raise instead of being reported through the failure line below.
    calendar_data = json.loads(sys.argv[1])
    priority_levels = json.loads(sys.argv[2])
    distribution_data = json.loads(sys.argv[3])
    try:
        output = schedule_optimization_task(
            calendar_data=calendar_data,
            priority_levels=priority_levels,
            distribution_data=distribution_data,
        )
        # The result is written to stdout as a single JSON document.
        print(json.dumps(output))
    except Exception as e:
        output = None
        # NOTE(review): 'fali' looks like a typo for 'fail'; it is left as is
        # because a caller may be matching on this exact output.
        print('fali', e)