import qutip as qt
import numpy as np
import matplotlib.pyplot as plt
import scipy.sparse as sp
from matplotlib.colors import LinearSegmentedColormap
from mpl_toolkits.mplot3d import Axes3D
import warnings
import os
import time
import scipy
from scipy.optimize import curve_fit

# Suppress FutureWarning / DeprecationWarning. These filters are process-wide,
# so they also silence those categories for every other imported module.
warnings.filterwarnings("ignore", category=FutureWarning)
warnings.filterwarnings("ignore", category=DeprecationWarning)

# System parameters
N = 2  # 2x2 lattice (4 sites) - limited for spin simulation due to Hilbert space size
N_sites = N * N  # Total number of sites
MAX_EXCITATIONS = 6  # Truncate Hilbert space to states with at most this many excitations
epsilon = 15.0  # Single-particle energy (meV)
U = 2.0  # Coulomb interaction (meV)
t_base = 1.0  # Base tunneling amplitude (meV)
Delta_0 = 10.0  # Energy gap at 0K (meV)
Tc = 300  # Critical temperature (K)

# Spin-related parameters
B_field = 0.1  # External magnetic field, expressed as a Zeeman energy (meV)
J_ex = 0.2  # Exchange coupling between adjacent spins (meV)
g_factor = 2.0  # g-factor for electron spins
spin_orbit = 0.05  # Spin-orbit coupling strength (meV)

# Physical constants
kB = 0.08617333262  # Boltzmann constant (meV/K)
hbar = 0.6582119569  # Reduced Planck constant (meV·ps), i.e. 6.582119569e-4 meV·ns
e_charge = 1.0  # Elementary charge (normalized)
mu_B = 0.05788  # Bohr magneton (meV/T)

def _bit_patterns_with_at_most(n_bits, max_ones):
    """
    Return every integer in [0, 2**n_bits) with at most ``max_ones`` set bits, ascending.

    ``max_ones=None`` means "no limit". Only the admissible patterns are generated
    (as combinations of bit positions), so the cost follows the size of the
    truncated space rather than 2**n_bits.
    """
    import math
    from itertools import combinations

    if max_ones is None or max_ones >= n_bits:
        return list(range(2 ** n_bits))
    patterns = [
        sum(1 << position for position in positions)
        for n_ones in range(math.floor(max_ones) + 1)  # negative limit -> empty range
        for positions in combinations(range(n_bits), n_ones)
    ]
    patterns.sort()
    return patterns


def get_basis_states(N_sites, max_excitations=None, include_spin=False):
    """
    Generate the computational-basis indices kept in the (optionally truncated) Hilbert space.

    Without spin each site is a single bit (empty / occupied), so indices run over
    2**N_sites values. With ``include_spin=True`` each site contributes two bits,
    (n_up, n_down), which give the four local states empty, spin-up, spin-down and
    doubly occupied. There are then 4**N_sites indices; callers in this file decode
    site 0 from the most significant bit pair.

    Parameters
    ----------
    N_sites : number of lattice sites.
    max_excitations : keep only states with at most this many electrons (set bits);
        None keeps the full space (for both the spinless and spinful cases).
    include_spin : whether each site carries separate spin-up and spin-down modes.

    Returns
    -------
    list[int]
        The retained basis indices in increasing order.
    """
    n_modes = 2 * N_sites if include_spin else N_sites
    return _bit_patterns_with_at_most(n_modes, max_excitations)

def _square_lattice_bonds(N):
    """Return the nearest-neighbour bonds (site, neighbour) of an open N x N lattice.

    Sites are numbered row-major (site = row * N + col); each bond appears once,
    pointing to the right or downward neighbour.
    """
    bonds = []
    for row in range(N):
        for col in range(N):
            site = row * N + col
            if col < N - 1:  # Right neighbour
                bonds.append((site, site + 1))
            if row < N - 1:  # Down neighbour
                bonds.append((site, site + N))
    return bonds


def _decode_spin_occupations(state_idx, n_sites):
    """Decode a spinful basis index into [(n_up, n_down), ...], one pair per site.

    Site 0 occupies the two most significant bits, spin-up bit first.
    """
    bits = format(state_idx, f'0{2 * n_sites}b')
    return [(int(bits[2 * s]), int(bits[2 * s + 1])) for s in range(n_sites)]


def _encode_spin_occupations(occupations):
    """Inverse of _decode_spin_occupations: (n_up, n_down) pairs -> basis index."""
    return int(''.join(f"{n_up}{n_down}" for n_up, n_down in occupations), 2)


def _move_electron(occupations, src, src_spin, dst, dst_spin):
    """Return a copy of ``occupations`` with one electron moved (src, src_spin) -> (dst, dst_spin).

    Spin index 0 is up, 1 is down. The caller guarantees the source is occupied
    and the destination empty.
    """
    moved = [list(pair) for pair in occupations]
    moved[src][src_spin] = 0
    moved[dst][dst_spin] = 1
    return moved


def create_hamiltonian_with_spin(N, epsilon, t_base, U, B_field, J_ex, spin_orbit, T, max_excitations=None):
    """
    Create the Hamiltonian of an N x N lattice with spin degrees of freedom.

    Each site holds at most one spin-up and one spin-down electron (empty, up, down,
    doubly occupied). The space is optionally truncated to states with at most
    ``max_excitations`` electrons; matrix elements leading outside it are dropped.

    Terms (nearest-neighbour bonds of the open N x N lattice):
      1. on-site energy         epsilon(T) * (n_up + n_down)
      2. on-site repulsion      U(T) * n_up * n_down
      3. Zeeman splitting       B_field * (n_up - n_down) / 2
      4. spin-conserving hop    -t(T)
      5. Heisenberg exchange    J_ex(T) * [S^z_i S^z_j + (S^+_i S^-_j + S^-_i S^+_j) / 2]
      6. spin-flip hop          +i*spin_orbit (up -> down), -i*spin_orbit (down -> up)
    Fermionic (Jordan-Wigner) sign factors are not included.

    Parameters
    ----------
    N : lattice dimension (N x N sites).
    epsilon, t_base, U, J_ex : scaled by (1 - T/Tc)**(1/3) below Tc, by 0.001 otherwise.
    B_field : Zeeman energy (not temperature-scaled).
    spin_orbit : spin-flip hopping amplitude (not temperature-scaled).
    T : temperature, compared with the module-level Tc.
    max_excitations : truncation passed to get_basis_states (None = full space).

    Returns
    -------
    qutip.Qobj
        Hermitian operator with dims [[dim], [dim]].
    """
    N_sites = N * N

    basis_states = get_basis_states(N_sites, max_excitations, include_spin=True)
    if basis_states is None:
        # Untruncated spinful space: every 2*N_sites-bit pattern is a valid state.
        basis_states = list(range(4 ** N_sites))
    dim = len(basis_states)
    basis_lookup = {state: idx for idx, state in enumerate(basis_states)}

    H_data = sp.lil_matrix((dim, dim), dtype=complex)

    # Temperature-dependent parameters
    if T < Tc:
        scaling = (1 - T / Tc) ** (1 / 3)
    else:
        scaling = 0.001  # Small non-zero value to avoid singularities

    epsilon_T = epsilon * scaling
    t = t_base * scaling
    U_T = U * scaling
    J_ex_T = J_ex * scaling

    bonds = _square_lattice_bonds(N)

    def add_element(row, target_occupations, amplitude, hermitian_partner=True):
        """Add ``amplitude`` at H[row, col(target)] (and its conjugate at H[col, row]).

        Silently skips targets outside the truncated basis.
        """
        col = basis_lookup.get(_encode_spin_occupations(target_occupations))
        if col is None:
            return
        H_data[row, col] += amplitude
        if hermitian_partner:
            H_data[col, row] += np.conj(amplitude)

    for idx, state_idx in enumerate(basis_states):
        occupations = _decode_spin_occupations(state_idx, N_sites)

        # --- Diagonal: 1. on-site energy, 2. Hubbard U, 3. Zeeman -------------
        energy = 0.0
        for n_up, n_down in occupations:
            energy += epsilon_T * (n_up + n_down)
            energy += U_T * n_up * n_down  # Cost for double occupation
            energy += B_field * (n_up - n_down) / 2  # Up gets +B/2, down gets -B/2

        # 5a. Ising part of the exchange, J * S^z_i S^z_j. It must be accumulated
        # before the diagonal element is stored below.
        for i, j in bonds:
            sz_i = (occupations[i][0] - occupations[i][1]) / 2
            sz_j = (occupations[j][0] - occupations[j][1]) / 2
            energy += J_ex_T * sz_i * sz_j

        H_data[idx, idx] = energy

        # --- Off-diagonal terms ----------------------------------------------
        for i, j in bonds:
            up_i, down_i = occupations[i]
            up_j, down_j = occupations[j]

            # 4. Spin-conserving hopping i -> j; j -> i is added as the Hermitian partner.
            if up_i and not up_j:
                add_element(idx, _move_electron(occupations, i, 0, j, 0), -t)
            if down_i and not down_j:
                add_element(idx, _move_electron(occupations, i, 1, j, 1), -t)

            # 5b. Flip-flop part (J/2)(S^+_i S^-_j + S^-_i S^+_j): swaps opposite single
            # spins on i and j. Both partner states are visited by this loop, so each
            # visit writes only its own element; adding the conjugate too would give J.
            if up_i + down_i == 1 and up_j + down_j == 1 and up_i != up_j:
                swapped = [list(pair) for pair in occupations]
                swapped[i], swapped[j] = swapped[j], swapped[i]
                add_element(idx, swapped, J_ex_T / 2, hermitian_partner=False)

            # 6. Spin-orbit spin-flip hopping; the conjugate partner keeps H Hermitian.
            if up_i and not down_j:
                add_element(idx, _move_electron(occupations, i, 0, j, 1), 1j * spin_orbit)
            if down_i and not up_j:
                add_element(idx, _move_electron(occupations, i, 1, j, 0), -1j * spin_orbit)

    return qt.Qobj(H_data.tocsr(), dims=[[dim], [dim]])

def create_hamiltonian_spinless(N, epsilon, t_base, U, T, max_excitations=None):
    """
    Build the spinless N x N lattice Hamiltonian (reference model without spin).

    Encoding: site s is bit s of the basis index (least-significant bit = site 0).
    Diagonal:     epsilon(T) per occupied site, plus U(T) for every pair of occupied
                  sites (all pairs, not only nearest neighbours).
    Off-diagonal: -t(T) for moving one particle along a nearest-neighbour bond,
                  kept only when both states lie in the (possibly truncated) basis.
    epsilon, t_base and U are scaled by (1 - T/Tc)**(1/3) below Tc, by 0.001 otherwise.

    Returns a qutip.Qobj with dims [[dim], [dim]], dim = number of basis states.
    """
    n_sites = N * N
    states = get_basis_states(n_sites, max_excitations, include_spin=False)
    dim = len(states)
    position_of = {state: pos for pos, state in enumerate(states)}
    H = sp.lil_matrix((dim, dim), dtype=complex)

    # Above Tc the couplings are kept small but non-zero.
    scaling = (1 - T / Tc) ** (1 / 3) if T < Tc else 0.001
    epsilon_T = epsilon * scaling
    t = t_base * scaling
    U_T = U * scaling

    # Nearest-neighbour bonds (right, then down) of the row-major N x N lattice.
    bonds = []
    for site in range(n_sites):
        row, col = divmod(site, N)
        if col < N - 1:
            bonds.append((site, site + 1))
        if row < N - 1:
            bonds.append((site, site + N))

    for pos, state in enumerate(states):
        n_occupied = bin(state).count('1')
        energy = epsilon_T * n_occupied
        # One U contribution per occupied pair (n choose 2 of them).
        for _ in range(n_occupied * (n_occupied - 1) // 2):
            energy += U_T
        H[pos, pos] = energy

        for src, dst in bonds:
            src_bit, dst_bit = 1 << src, 1 << dst
            if not state & src_bit or state & dst_bit:
                continue  # need src occupied and dst empty
            target = position_of.get(state ^ src_bit ^ dst_bit)
            if target is not None:
                H[pos, target] -= t
                H[target, pos] -= t  # Hermitian partner

    return qt.Qobj(H.tocsr(), dims=[[dim], [dim]])

def create_initial_state_with_spin(N_sites, config='singlet_pair', max_excitations=None):
    """
    Create an initial ket in the (possibly truncated) spinful basis.

    Each site is encoded as a 2-bit code (n_up, n_down); site 0 is the most
    significant pair.

    Parameters
    ----------
    N_sites : number of lattice sites.
    config :
        'singlet_pair' -> (|↑↓⟩ - |↓↑⟩)/√2 on sites 0 and 1, all other sites empty;
        'spin_up'      -> a single spin-up electron on site 0;
        'spin_down'    -> a single spin-down electron on site 0.
        Any other value, or a configuration whose states are not in the basis
        (truncated away, or too few sites), yields the empty lattice.
    max_excitations : truncation passed to get_basis_states (None = full space).

    Returns
    -------
    qutip ket of dimension len(basis).
    """
    basis_states = get_basis_states(N_sites, max_excitations, include_spin=True)
    if basis_states is None:
        # Untruncated spinful space: every 2*N_sites-bit pattern is a valid state.
        basis_states = list(range(4 ** N_sites))
    dim = len(basis_states)
    basis_lookup = {state: idx for idx, state in enumerate(basis_states)}

    def position_of(site_codes):
        """Basis position of the state whose leading sites carry ``site_codes`` and
        whose remaining sites are empty; None if it is not representable."""
        if len(site_codes) > N_sites:
            return None
        bits = ''.join(site_codes) + '00' * (N_sites - len(site_codes))
        return basis_lookup.get(int(bits, 2) if bits else 0)

    if config == 'singlet_pair':
        up_down = position_of(['10', '01'])  # |↑ on site 0, ↓ on site 1⟩
        down_up = position_of(['01', '10'])  # |↓ on site 0, ↑ on site 1⟩
        if up_down is not None and down_up is not None:
            return (qt.basis(dim, up_down) - qt.basis(dim, down_up)).unit()

    elif config == 'spin_up':
        idx = position_of(['10'])
        if idx is not None:
            return qt.basis(dim, idx)

    elif config == 'spin_down':
        idx = position_of(['01'])
        if idx is not None:
            return qt.basis(dim, idx)

    # Fallback: the empty lattice, or the first basis state if even that is missing.
    vacuum = position_of([])
    return qt.basis(dim, vacuum if vacuum is not None else 0)

def create_initial_state_spinless(N_sites, config='superposition', max_excitations=None):
    """
    Create an initial ket for the spinless model.

    Full space (``max_excitations is None``), built as a tensor product whose first
    factor is the most significant bit of the flattened index:
        'superposition'     -> the first two factors each (|0⟩ + |1⟩)/√2, rest empty
        'single_excitation' -> the first factor occupied, rest empty
    Truncated space:
        'superposition'     -> (|all empty⟩ + |most-significant bit set⟩)/√2
        'single_excitation' -> |most-significant bit set⟩
    Unknown configs, or states missing from the truncated basis, give the first
    basis state (the empty lattice).

    The ket always has flat dims [[dim], [1]]. This matches the operators built
    elsewhere in this file with dims [[dim], [dim]].
    """
    if max_excitations is None:
        dim = 2 ** N_sites
        empty, occupied = qt.basis(2, 0), qt.basis(2, 1)
        factors = [empty] * N_sites
        if config == 'superposition':
            plus = (empty + occupied).unit()
            for site in range(min(2, N_sites)):
                factors[site] = plus
        elif config == 'single_excitation' and N_sites > 0:
            factors[0] = occupied
        else:
            return qt.basis(dim, 0)
        if not factors:
            return qt.basis(dim, 0)
        # qt.tensor produces composite dims [[2]*N_sites, [1]*N_sites]; flatten them
        # so the ket can be combined with [[dim], [dim]] operators.
        return qt.Qobj(qt.tensor(factors).full(), dims=[[dim], [1]])

    # Truncated Hilbert space
    basis_states = get_basis_states(N_sites, max_excitations, include_spin=False)
    dim = len(basis_states)
    basis_lookup = {state: idx for idx, state in enumerate(basis_states)}

    ground_idx = basis_lookup.get(0)
    excited_idx = basis_lookup.get(1 << (N_sites - 1)) if N_sites > 0 else None

    if config == 'superposition':
        if ground_idx is not None and excited_idx is not None:
            return (qt.basis(dim, ground_idx) + qt.basis(dim, excited_idx)).unit()
    elif config == 'single_excitation':
        if excited_idx is not None:
            return qt.basis(dim, excited_idx)

    # Default (or requested state truncated away): the first basis state.
    return qt.basis(dim, 0)

def calculate_phase_memory(T, tau0=100, T0=77):
    """
    Return the phase-memory time tau0 * (1 + (T0 / T)**(2/3)) (the "Eq. 4" model).

    At T == T0 the result is 2 * tau0, and it grows as T falls below T0.
    T and T0 must share a unit, and T must be non-zero. The result has the unit of tau0.
    """
    temperature_ratio = T0 / T
    return tau0 * (1 + temperature_ratio ** (2 / 3))

def create_spin_measurement_operators(N_sites, max_excitations):
    """
    Create single-site and total spin operators in the spinful (possibly truncated) basis.

    Site s is encoded in bit pair (2s, 2s+1) of the 2*N_sites-bit basis string,
    spin-up bit first. On the singly occupied subspace {|↑⟩, |↓⟩} of a site, the
    operators are the spin-1/2 matrices
        S_z = diag(1/2, -1/2),  S_x = sigma_x / 2,  S_y = sigma_y / 2.
    They vanish on empty and doubly occupied sites.

    Parameters
    ----------
    N_sites : number of lattice sites.
    max_excitations : truncation passed to get_basis_states (None = full space).

    Returns
    -------
    S_z_ops, S_x_ops, S_y_ops : lists of per-site qutip operators (dims [[dim], [dim]])
    S_z_total, S_x_total, S_y_total : the sums of those lists over all sites
    """
    basis_states = get_basis_states(N_sites, max_excitations, include_spin=True)
    if basis_states is None:
        # Untruncated spinful space: every 2*N_sites-bit pattern is a valid state.
        basis_states = list(range(4 ** N_sites))
    dim = len(basis_states)
    basis_lookup = {state: idx for idx, state in enumerate(basis_states)}

    S_z_ops = []
    S_x_ops = []
    S_y_ops = []

    for site in range(N_sites):
        # Masks of this site's up/down bits in the integer index (site 0 = top pair).
        up_bit = 1 << (2 * (N_sites - site) - 1)
        down_bit = 1 << (2 * (N_sites - site) - 2)

        S_z = sp.lil_matrix((dim, dim), dtype=complex)
        S_x = sp.lil_matrix((dim, dim), dtype=complex)
        S_y = sp.lil_matrix((dim, dim), dtype=complex)

        for idx, state in enumerate(basis_states):
            n_up = 1 if state & up_bit else 0
            n_down = 1 if state & down_bit else 0
            S_z[idx, idx] = (n_up - n_down) / 2

            if n_up == n_down:
                continue  # empty or doubly occupied: no spin-flip elements
            # Singly occupied: toggling both bits exchanges |↑⟩ and |↓⟩ on this site.
            new_idx = basis_lookup.get(state ^ (up_bit | down_bit))
            if new_idx is None:
                continue
            S_x[idx, new_idx] = 0.5
            # ⟨↓|S_y|↑⟩ = +i/2 and ⟨↑|S_y|↓⟩ = -i/2, i.e. S_y = sigma_y / 2.
            S_y[idx, new_idx] = 0.5j if n_down else -0.5j

        S_z_ops.append(qt.Qobj(S_z.tocsr(), dims=[[dim], [dim]]))
        S_x_ops.append(qt.Qobj(S_x.tocsr(), dims=[[dim], [dim]]))
        S_y_ops.append(qt.Qobj(S_y.tocsr(), dims=[[dim], [dim]]))

    # Total spin operators
    S_z_total = sum(S_z_ops)
    S_x_total = sum(S_x_ops)
    S_y_total = sum(S_y_ops)

    return S_z_ops, S_x_ops, S_y_ops, S_z_total, S_x_total, S_y_total

def create_charge_measurement_operators(N_sites, max_excitations, include_spin=False):
    """
    Build one diagonal occupation-number operator per site.

    With ``include_spin`` the basis string has 2*N_sites bits, and site s owns
    characters (2s, 2s+1) (up, down). Its operator has eigenvalue n_up + n_down.
    Without spin, site s is character N_sites-1-s of the N_sites-bit string
    (least-significant bit = site 0), and its operator has eigenvalue 0 or 1.

    Returns a list of N_sites qutip operators with dims [[dim], [dim]].
    """
    if include_spin:
        basis_states = get_basis_states(N_sites, max_excitations, include_spin=True)
        width = 2 * N_sites
    else:
        basis_states = get_basis_states(N_sites, max_excitations, include_spin=False)
        width = N_sites
    dim = len(basis_states)

    # Render every basis index once instead of once per site.
    bit_strings = [format(state, f'0{width}b') for state in basis_states]

    operators = []
    for site in range(N_sites):
        occupation = sp.lil_matrix((dim, dim), dtype=complex)
        for position, bits in enumerate(bit_strings):
            if include_spin:
                occupation[position, position] = int(bits[2 * site]) + int(bits[2 * site + 1])
            elif bits[N_sites - 1 - site] == '1':
                occupation[position, position] = 1.0
        operators.append(qt.Qobj(occupation.tocsr(), dims=[[dim], [dim]]))

    return operators

def create_collapse_operators(N_sites, gamma, max_excitations, include_spin=False):
    """
    Create collapse operators for decoherence.
    
    With spin, we need operators for:
    - Spin dephasing (T2*)
    - Spin relaxation (T1)
    - Charge dephasing
    - Charge relaxation
    """
    if include_spin:
        basis_states = get_basis_states(N_sites, max_excitations, include_spin=True)
        dim = len(basis_states)
        basis_lookup = {state: idx for idx, state in enumerate(basis_states)}
        
        c_ops_list = []
        
        # Parameters for different decoherence channels
        gamma_spin_dephasing = gamma  # Spin dephasing rate
        gamma_spin_relaxation = gamma / 2  # Spin relaxation rate (typically T1 ≈ 2*T2*)
        gamma_charge_dephasing = gamma  # Charge dephasing rate
        gamma_charge_relaxation = gamma / 2  # Charge relaxation rate
        
        for site in range(N_sites):
            # 1. Spin dephasing (affects S_z)
            spin_dephasing = sp.lil_matrix((dim, dim), dtype=complex)
            
            for idx, state_idx in enumerate(basis_states):
                state_binary = format(state_idx, f'0{2*N_sites}b')
                occupations = []
                for s in range(N_sites):
                    n_up = int(state_binary[2*s])
                    n_down = int(state_binary[2*s + 1])
                    occupations.append((n_up, n_down))
                
                # Dephasing proportional to S_z value
                n_up = occupations[site][0]
                n_down = occupations[site][1]
                if n_up == 1 and n_down == 0:  # Spin up
                    spin_dephasing[idx, idx] = 1.0
                elif n_up == 0 and n_down == 1:  # Spin down
                    spin_dephasing[idx, idx] = -1.0
            
            c_ops_list.append(np.sqrt(gamma_spin_dephasing) * qt.Qobj(spin_dephasing.tocsr(), dims=[[dim], [dim]]))
            
            # 2. Spin relaxation (S_+ and S_-)
            # S_+ (up to down relaxation)
            spin_up_to_down = sp.lil_matrix((dim, dim), dtype=complex)
            
            for idx, state_idx in enumerate(basis_states):
                state_binary = format(state_idx, f'0{2*N_sites}b')
                occupations = []
                for s in range(N_sites):
                    n_up = int(state_binary[2*s])
                    n_down = int(state_binary[2*s + 1])
                    occupations.append((n_up, n_down))
                
                n_up = occupations[site][0]
                n_down = occupations[site][1]
                
                # Relaxation from up to down
                if n_up == 1 and n_down == 0:
                    new_occupations = occupations.copy()
                    new_occupations[site] = (0, 1)  # Flip up to down
                    new_binary = ''.join(f"{n_up}{n_down}" for n_up, n_down in new_occupations)
                    new_state = int(new_binary, 2)
                    
                    if new_state in basis_lookup:
                        new_idx = basis_lookup[new_state]
                        spin_up_to_down[idx, new_idx] = 1.0
            
            c_ops_list.append(np.sqrt(gamma_spin_relaxation) * qt.Qobj(spin_up_to_down.tocsr(), dims=[[dim], [dim]]))
            
            # S_- (down to up relaxation) - typically weaker due to thermal effects
            spin_down_to_up = sp.lil_matrix((dim, dim), dtype=complex)
            
            for idx, state_idx in enumerate(basis_states):
                state_binary = format(state_idx, f'0{2*N_sites}b')
                occupations = []
                for s in range(N_sites):
                    n_up = int(state_binary[2*s])
                    n_down = int(state_binary[2*s + 1])
                    occupations.append((n_up, n_down))
                
                n_up = occupations[site][0]
                n_down = occupations[site][1]
                
                # Relaxation from down to up (slower, due to thermal effects)
                if n_up == 0 and n_down == 1:
                    new_occupations = occupations.copy()
                    new_occupations[site] = (1, 0)  # Flip down to up
                    new_binary = ''.join(f"{n_up}{n_down}" for n_up, n_down in new_occupations)
                    new_state = int(new_binary, 2)
                    
                    if new_state in basis_lookup:
                        new_idx = basis_lookup[new_state]
                        spin_down_to_up[idx, new_idx] = 1.0
            
            c_ops_list.append(np.sqrt(gamma_spin_relaxation * 0.5) * qt.Qobj(spin_down_to_up.tocsr(), dims=[[dim], [dim]]))
            
            # 3. Charge dephasing
            charge_dephasing = sp.lil_matrix((dim, dim), dtype=complex)
            
            for idx, state_idx in enumerate(basis_states):
                state_binary = format(state_idx, f'0{2*N_sites}b')
                occupations = []
                for s in range(N_sites):
                    n_up = int(state_binary[2*s])
                    n_down = int(state_binary[2*s + 1])
                    occupations.append((n_up, n_down))
                
                # Dephasing proportional to charge
                total_charge = occupations[site][0] + occupations[site][1]
                charge_dephasing[idx, idx] = total_charge
            
            c_ops_list.append(np.sqrt(gamma_charge_dephasing) * qt.Qobj(charge_dephasing.tocsr(), dims=[[dim], [dim]]))
            
            # 4. Charge relaxation (both spins)
            for spin_idx in range(2):  # 0 for up, 1 for down
                charge_relaxation = sp.lil_matrix((dim, dim), dtype=complex)
                
                for idx, state_idx in enumerate(basis_states):
                    state_binary = format(state_idx, f'0{2*N_sites}b')
                    occupations = []
                    for s in range(N_sites):
                        n_up = int(state_binary[2*s])
                        n_down = int(state_binary[2*s + 1])
                        occupations.append((n_up, n_down))
                    
                    # Relaxation loses an electron
                    if occupations[site][spin_idx] == 1:
                        new_occupations = occupations.copy()
                        if spin_idx == 0:
                            new_occupations[site] = (0, new_occupations[site][1])
                        else:
                            new_occupations[site] = (new_occupations[site][0], 0)
                        
                        new_binary = ''.join(f"{n_up}{n_down}" for n_up, n_down in new_occupations)
                        new_state = int(new_binary, 2)
                        
                        if new_state in basis_lookup:
                            new_idx = basis_lookup[new_state]
                            charge_relaxation[idx, new_idx] = 1.0
                
                c_ops_list.append(np.sqrt(gamma_charge_relaxation) * qt.Qobj(charge_relaxation.tocsr(), dims=[[dim], [dim]]))
        
        return c_ops_list
    else:
        # Original spinless implementation
        basis_states = get_basis_states(N_sites, max_excitations, include_spin=False)
        dim = len(basis_states)
        basis_lookup = {state: idx for idx, state in enumerate(basis_states)}
        
        c_ops_list = []
        
        for site in range(N_sites):
            # Dephasing operator
            dephasing = sp.lil_matrix((dim, dim), dtype=complex)
            
            for basis_idx, state in enumerate(basis_states):
                state_binary = format(state, f'0{N_sites}b')
                if state_binary[N_sites-1-site] == '1':
                    dephasing[basis_idx, basis_idx] = 1.0
            
            c_ops_list.append(np.sqrt(gamma) * qt.Qobj(dephasing.tocsr(), dims=[[dim], [dim]]))
            
            # Relaxation operator
            relaxation = sp.lil_matrix((dim, dim), dtype=complex)
            
            for basis_idx, state in enumerate(basis_states):
                state_binary = format(state, f'0{N_sites}b')
                if state_binary[N_sites-1-site] == '1':
                    new_binary = list(state_binary)
                    new_binary[N_sites-1-site] = '0'
                    new_state = int(''.join(new_binary), 2)
                    
                    if new_state in basis_lookup:
                        new_idx = basis_lookup[new_state]
                        relaxation[basis_idx, new_idx] = 1.0
            
            c_ops_list.append(np.sqrt(gamma/2) * qt.Qobj(relaxation.tocsr(), dims=[[dim], [dim]]))
        
        return c_ops_list

def simulate_spin_decoherence(H, psi0, T, tmax=100, num_points=100, include_spin=True, max_excitations=None):
    """
    Simulate Lindblad time evolution with temperature-dependent decoherence.

    The pure state ``psi0`` is converted to a density matrix and evolved with
    ``qt.mesolve`` under ``H``. Collapse operators come from
    ``create_collapse_operators(N_sites, gamma, ...)`` with
    ``gamma = 1 / tau_phi`` and ``tau_phi = calculate_phase_memory(T)``
    (``gamma = 0`` when ``tau_phi <= 0``).

    Parameters:
    - H: Hamiltonian Qobj in the same truncated basis as psi0
    - psi0: initial ket
    - T: temperature (K)
    - tmax: final evolution time; output is sampled at ``num_points`` evenly
      spaced times in [0, tmax]
    - include_spin: measure charge and spin observables if True, charge only otherwise
    - max_excitations: Hilbert-space truncation forwarded to the operator builders

    Returns:
    - include_spin=True: (times, charge_data, spin_z_data, spin_x_data,
      spin_y_data, total_spin, tau_phi); total_spin rows are ordered
      [S_z_total, S_x_total, S_y_total]
    - include_spin=False: (times, charge_data, tau_phi)
    """
    N_sites = N * N
    tau_phi = calculate_phase_memory(T)
    gamma = 1/tau_phi if tau_phi > 0 else 0
    
    # Create collapse operators for decoherence
    c_ops_list = create_collapse_operators(N_sites, gamma, max_excitations, include_spin)
    
    # Create density matrix from pure state
    rho0 = psi0 * psi0.dag()
    
    # NOTE(review): qutip evolves with hbar = 1, so with H in meV one time unit is
    # hbar/meV rather than 1 ns unless H is rescaled elsewhere — confirm against
    # the 'Time (ns)' axis labels used by the plotting functions.
    times = np.linspace(0, tmax, num_points)
    
    # Build measurement operators and remember how many belong to each family,
    # so result.expect can be split without assuming N_sites entries per family.
    if include_spin:
        charge_ops = create_charge_measurement_operators(N_sites, max_excitations, include_spin)
        S_z_ops, S_x_ops, S_y_ops, S_z_total, S_x_total, S_y_total = create_spin_measurement_operators(N_sites, max_excitations)
        total_ops = [S_z_total, S_x_total, S_y_total]
        
        e_ops = charge_ops + S_z_ops + S_x_ops + S_y_ops + total_ops
        group_sizes = [len(charge_ops), len(S_z_ops), len(S_x_ops), len(S_y_ops), len(total_ops)]
    else:
        # For spinless simulations, just measure charge
        e_ops = create_charge_measurement_operators(N_sites, max_excitations, include_spin)
        group_sizes = [len(e_ops)]
    
    # Create solver options
    options = qt.Options(nsteps=5000, atol=1e-7, rtol=1e-5, max_step=tmax/50)
    
    # Simulate the system evolution
    print(f"Starting simulation at T={T}K with spin={include_spin}")
    start_time = time.time()
    
    result = qt.mesolve(H, rho0, times, c_ops=c_ops_list, e_ops=e_ops, options=options)
    
    end_time = time.time()
    print(f"Simulation completed in {end_time - start_time:.1f} seconds")
    
    # Split result.expect (same order as e_ops) into one array per observable family.
    groups = []
    start = 0
    for size in group_sizes:
        groups.append(np.array([result.expect[i] for i in range(start, start + size)]))
        start += size
    
    if include_spin:
        charge_data, spin_z_data, spin_x_data, spin_y_data, total_spin = groups
        return times, charge_data, spin_z_data, spin_x_data, spin_y_data, total_spin, tau_phi
    
    # Just return charge data for spinless case
    charge_data = groups[0]
    return times, charge_data, tau_phi

def analyze_temperature_dependence_with_spin(temps=None, tmax=100):
    """
    Analyze system behavior at different temperatures, including spin effects.

    For each temperature a spinful Hamiltonian is built, the gap E1 - E0 is
    taken from its lowest sparse eigenvalues, and the fixed 'singlet_pair'
    initial state is evolved with simulate_spin_decoherence.

    Parameters:
    - temps: temperatures in K (default: [77, 150, 225, 300])
    - tmax: final evolution time passed to simulate_spin_decoherence

    Returns:
    - temps, times, charge_data_list, spin_z_data_list, spin_x_data_list,
      spin_y_data_list, total_spin_list, tau_phi_values, energy_gaps;
      ``times`` is from the last run, or None when ``temps`` is empty
    """
    # A None default avoids sharing one mutable list across calls (it is also
    # returned to the caller, who could otherwise mutate the default).
    if temps is None:
        temps = [77, 150, 225, 300]
    
    N_sites = N * N
    charge_data_list = []
    spin_z_data_list = []
    spin_x_data_list = []
    spin_y_data_list = []
    total_spin_list = []
    tau_phi_values = []
    energy_gaps = []
    times = None  # overwritten by each run; stays None only if temps is empty
    
    # Use truncated Hilbert space for efficiency
    psi0 = create_initial_state_with_spin(N_sites, config='singlet_pair', max_excitations=MAX_EXCITATIONS)
    
    for T in temps:
        # Create Hamiltonian at this temperature
        H = create_hamiltonian_with_spin(N, epsilon, t_base, U, B_field, J_ex, spin_orbit, T, max_excitations=MAX_EXCITATIONS)
        
        # Gap between the two lowest eigenvalues returned by the sparse solver
        eigenvalues = H.eigenenergies(sparse=True, sort='low', eigvals=5)
        gap = eigenvalues[1] - eigenvalues[0] if len(eigenvalues) > 1 else 0
        energy_gaps.append(gap)
        
        # Simulate time evolution
        times, charge_data, spin_z_data, spin_x_data, spin_y_data, total_spin, tau_phi = \
            simulate_spin_decoherence(H, psi0, T, tmax, include_spin=True, max_excitations=MAX_EXCITATIONS)
        
        charge_data_list.append(charge_data)
        spin_z_data_list.append(spin_z_data)
        spin_x_data_list.append(spin_x_data)
        spin_y_data_list.append(spin_y_data)
        total_spin_list.append(total_spin)
        tau_phi_values.append(tau_phi)
        
        print(f"Temperature: {T}K, Theoretical τ_φ: {tau_phi:.2f} ns, Energy Gap: {gap:.4f} meV")
    
    return temps, times, charge_data_list, spin_z_data_list, spin_x_data_list, spin_y_data_list, total_spin_list, tau_phi_values, energy_gaps

def analyze_temperature_dependence_spinless(temps=None, tmax=100):
    """
    Analyze system behavior at different temperatures without spin effects.
    This is the original implementation for comparison.

    For each temperature a spinless Hamiltonian is built, the gap E1 - E0 is
    taken from its two lowest sparse eigenvalues, and the fixed
    'superposition' initial state is evolved with
    simulate_spin_decoherence(include_spin=False).

    Parameters:
    - temps: temperatures in K (default: [77, 150, 225, 300])
    - tmax: final evolution time passed to simulate_spin_decoherence

    Returns:
    - temps, times, coherence_data, tau_phi_values, energy_gaps, where
      coherence_data[i] holds the charge expectation traces at temps[i] and
      ``times`` is from the last run (None when ``temps`` is empty)
    """
    # A None default avoids sharing one mutable list across calls (it is also
    # returned to the caller, who could otherwise mutate the default).
    if temps is None:
        temps = [77, 150, 225, 300]
    
    N_sites = N * N
    coherence_data = []
    tau_phi_values = []
    energy_gaps = []
    times = None  # overwritten by each run; stays None only if temps is empty
    
    # Use truncated Hilbert space for efficiency
    psi0 = create_initial_state_spinless(N_sites, config='superposition', max_excitations=MAX_EXCITATIONS)
    
    for T in temps:
        # Create Hamiltonian at this temperature
        H = create_hamiltonian_spinless(N, epsilon, t_base, U, T, max_excitations=MAX_EXCITATIONS)
        
        # Gap between the two lowest eigenvalues returned by the sparse solver
        eigenvalues = H.eigenenergies(sparse=True, sort='low', eigvals=2)
        gap = eigenvalues[1] - eigenvalues[0] if len(eigenvalues) > 1 else 0
        energy_gaps.append(gap)
        
        # Simulate time evolution
        times, occupations, tau_phi = simulate_spin_decoherence(H, psi0, T, tmax, include_spin=False, max_excitations=MAX_EXCITATIONS)
        coherence_data.append(occupations)
        tau_phi_values.append(tau_phi)
        
        print(f"Temperature: {T}K, Theoretical τ_φ: {tau_phi:.2f} ns, Energy Gap: {gap:.4f} meV")
    
    return temps, times, coherence_data, tau_phi_values, energy_gaps

def plot_quantum_dot_lattice():
    """
    Render the N x N quantum-dot lattice in 3D and save quantum_dot_lattice.png.

    Dots sit at integer grid positions in the z = 0 plane. Each dot is joined
    to its right-hand neighbour and to the dot in the next row by a black line
    representing a tunneling link. Errors are printed rather than raised, and
    the figure is closed in every case.
    """
    print("Generating quantum dot lattice plot...")
    try:
        fig = plt.figure(figsize=(8, 6))
        ax = fig.add_subplot(111, projection='3d')

        grid_x, grid_y = np.meshgrid(range(N), range(N))
        xs = grid_x.flatten()
        ys = grid_y.flatten()
        zs = np.zeros(N*N)

        ax.scatter(xs, ys, zs, s=200, c='blue', alpha=0.7, label='Quantum Dots')

        # Sites are numbered row * N + col: +1 is the right neighbour, +N the next row.
        links = []
        for row in range(N):
            for col in range(N):
                here = row * N + col
                if col + 1 < N:
                    links.append((here, here + 1))
                if row + 1 < N:
                    links.append((here, here + N))

        for src, dst in links:
            ax.plot([xs[src], xs[dst]], [ys[src], ys[dst]], [zs[src], zs[dst]], 'k-', alpha=0.5)

        ax.set_xlabel('X')
        ax.set_ylabel('Y')
        ax.set_zlabel('Z')
        ax.set_title(f'{N}x{N} Quantum Dot Lattice')
        ax.set_xlim(-0.5, N-0.5)
        ax.set_ylim(-0.5, N-0.5)
        ax.set_zlim(-0.5, 0.5)
        plt.tight_layout()
        plt.savefig('quantum_dot_lattice.png', dpi=300, bbox_inches='tight')
        print("Successfully saved quantum_dot_lattice.png")
    except Exception as e:
        print(f"Error saving quantum dot lattice plot: {str(e)}")
    finally:
        plt.close()

def plot_charge_occupations(temps, times, charge_data_list):
    """
    Plot per-site charge occupation versus time, one panel per temperature.

    Parameters:
    - temps: temperatures (K); panel i is titled with temps[i]
    - times: time points shared by every run
    - charge_data_list: charge_data_list[i][site] is the occupation trace of
      ``site`` at temps[i]; only the real part is plotted

    At most the first four sites are drawn. Panels are laid out two per row,
    with as many rows as needed; unused panels are hidden. The figure is
    saved to charge_occupation_vs_temperature.png. Errors are printed rather
    than raised, and the figure is always closed.
    """
    print("Generating charge occupation plot...")
    try:
        n_panels = len(temps)
        n_rows = max(1, (n_panels + 1) // 2)
        # squeeze=False keeps a 2-D axes array for every grid size; with four
        # temperatures this is the same 2x2, 12x10-inch layout.
        fig, axs = plt.subplots(n_rows, 2, figsize=(12, 5 * n_rows), squeeze=False)
        axs = axs.flatten()
        for unused_ax in axs[n_panels:]:
            unused_ax.set_visible(False)
        
        # Maximum number of sites to plot (to avoid cluttering)
        max_sites = min(4, N*N)
        colors = plt.cm.tab10(np.linspace(0, 1, max_sites))
        
        for i, T in enumerate(temps):
            ax = axs[i]
            for site in range(max_sites):
                # Convert to real part for plotting
                site_data = np.real(charge_data_list[i][site])
                ax.plot(times, site_data, c=colors[site], label=f'Site {site}')
            
            ax.set_xlabel('Time (ns)')
            ax.set_ylabel('Charge Occupation')
            ax.set_title(f'T = {T}K')
            # NOTE(review): with spin a site can hold two electrons; if the charge
            # operators measure n_up + n_down this range may clip — confirm.
            ax.set_ylim(0, 1)
            ax.grid(True, alpha=0.3)
            if i == 0:
                ax.legend()
        
        plt.tight_layout()
        plt.savefig('charge_occupation_vs_temperature.png', dpi=300, bbox_inches='tight')
        print("Successfully saved charge_occupation_vs_temperature.png")
    except Exception as e:
        print(f"Error saving charge occupations plot: {str(e)}")
    finally:
        plt.close()

def plot_spin_occupations(temps, times, spin_z_data_list):
    """
    Plot per-site <S_z> versus time, one panel per temperature.

    Parameters:
    - temps: temperatures (K); panel i is titled with temps[i]
    - times: time points shared by every run
    - spin_z_data_list: spin_z_data_list[i][site] is the <S_z> trace of
      ``site`` at temps[i]; only the real part is plotted

    At most the first four sites are drawn. Panels are laid out two per row,
    with as many rows as needed; unused panels are hidden. The figure is
    saved to spin_z_vs_temperature.png. Errors are printed rather than
    raised, and the figure is always closed.
    """
    print("Generating spin occupation plot...")
    try:
        n_panels = len(temps)
        n_rows = max(1, (n_panels + 1) // 2)
        # squeeze=False keeps a 2-D axes array for every grid size; with four
        # temperatures this is the same 2x2, 12x10-inch layout.
        fig, axs = plt.subplots(n_rows, 2, figsize=(12, 5 * n_rows), squeeze=False)
        axs = axs.flatten()
        for unused_ax in axs[n_panels:]:
            unused_ax.set_visible(False)
        
        # Maximum number of sites to plot (to avoid cluttering)
        max_sites = min(4, N*N)
        colors = plt.cm.tab10(np.linspace(0, 1, max_sites))
        
        for i, T in enumerate(temps):
            ax = axs[i]
            for site in range(max_sites):
                # Convert to real part for plotting
                site_data = np.real(spin_z_data_list[i][site])
                ax.plot(times, site_data, c=colors[site], label=f'Site {site}')
            
            ax.set_xlabel('Time (ns)')
            ax.set_ylabel('Spin-z Expectation')
            ax.set_title(f'T = {T}K')
            ax.set_ylim(-0.5, 0.5)
            ax.grid(True, alpha=0.3)
            if i == 0:
                ax.legend()
        
        plt.tight_layout()
        plt.savefig('spin_z_vs_temperature.png', dpi=300, bbox_inches='tight')
        print("Successfully saved spin_z_vs_temperature.png")
    except Exception as e:
        print(f"Error saving spin occupations plot: {str(e)}")
    finally:
        plt.close()

def plot_total_spin(temps, times, total_spin_list):
    """
    Plot total spin components versus time, one panel per temperature.

    Parameters:
    - temps: temperatures (K); panel i is titled with temps[i]
    - times: time points shared by every run
    - total_spin_list: total_spin_list[i] has rows [S_z, S_x, S_y] (the order
      produced by simulate_spin_decoherence); only real parts are plotted

    Panels are laid out two per row, with as many rows as needed; unused
    panels are hidden. The figure is saved to total_spin_vs_temperature.png.
    Errors are printed rather than raised, and the figure is always closed.
    """
    print("Generating total spin plot...")
    try:
        n_panels = len(temps)
        n_rows = max(1, (n_panels + 1) // 2)
        # squeeze=False keeps a 2-D axes array for every grid size; with four
        # temperatures this is the same 2x2, 12x10-inch layout.
        fig, axs = plt.subplots(n_rows, 2, figsize=(12, 5 * n_rows), squeeze=False)
        axs = axs.flatten()
        for unused_ax in axs[n_panels:]:
            unused_ax.set_visible(False)
        
        labels = ['S_z', 'S_x', 'S_y']
        colors = ['r', 'g', 'b']
        
        for i, T in enumerate(temps):
            ax = axs[i]
            for component, (label, color) in enumerate(zip(labels, colors)):
                # Convert to real part for plotting
                data = np.real(total_spin_list[i][component])
                ax.plot(times, data, c=color, label=label)
            
            ax.set_xlabel('Time (ns)')
            ax.set_ylabel('Total Spin Expectation')
            ax.set_title(f'T = {T}K')
            # NOTE(review): a sum over several spins can exceed |0.5|; this fixed
            # range may clip — confirm it suits the initial states used.
            ax.set_ylim(-0.5, 0.5)
            ax.grid(True, alpha=0.3)
            if i == 0:
                ax.legend()
        
        plt.tight_layout()
        plt.savefig('total_spin_vs_temperature.png', dpi=300, bbox_inches='tight')
        print("Successfully saved total_spin_vs_temperature.png")
    except Exception as e:
        print(f"Error saving total spin plot: {str(e)}")
    finally:
        plt.close()

def plot_decoherence_time(temps, tau_phi_values, with_spin=True):
    """
    Plot tau_phi versus temperature and fit tau = a * (1 + (T0 / T)**(2/3)).

    Parameters:
    - temps: temperatures (K)
    - tau_phi_values: decoherence times, one per temperature
    - with_spin: selects the title prefix and the output filename

    Returns:
    - the fitted parameters [a, T0] when the fit succeeds, otherwise the
      placeholder [100, 77]

    The data points and the 50 ns reference line are saved even when the fit
    fails. Other errors are printed and the placeholder is returned.
    """
    print("Generating decoherence time plot...")

    def decoherence_model(T, a, T0):
        return a * (1 + (T0/T)**(2/3))

    try:
        plt.figure(figsize=(10, 6))
        plt.plot(temps, tau_phi_values, 'o-', linewidth=2, markersize=8)
        
        # The fit has its own try: too few points or a non-converging fit only
        # drops the fitted curve, not the whole figure.
        try:
            popt, _ = curve_fit(decoherence_model, temps, tau_phi_values, p0=[100, 77])
        except (RuntimeError, ValueError, TypeError) as fit_error:
            print(f"Could not fit decoherence model: {fit_error}")
            popt = None
        
        if popt is not None:
            fit_temps = np.linspace(min(temps), max(temps), 100)
            fit_values = decoherence_model(fit_temps, *popt)
            plt.plot(fit_temps, fit_values, 'r--', linewidth=1.5, 
                     label=f'$τ_φ = {popt[0]:.2f}[1+({popt[1]:.2f}/T)^{{2/3}}]$')
        
        plt.xlabel('Temperature (K)')
        plt.ylabel('Decoherence Time $τ_φ$ (ns)')
        title_prefix = "Spin " if with_spin else ""
        plt.title(f'{title_prefix}Decoherence Time vs Temperature')
        plt.axhline(y=50, color='k', linestyle='--', label='50 ns threshold')
        plt.grid(True, alpha=0.3)
        plt.legend()
        plt.tight_layout()
        
        filename = 'spin_decoherence_vs_temperature.png' if with_spin else 'decoherence_vs_temperature.png'
        plt.savefig(filename, dpi=300, bbox_inches='tight')
        print(f"Successfully saved {filename}")
        return popt if popt is not None else [100, 77]
    except Exception as e:
        print(f"Error saving decoherence time plot: {str(e)}")
        return [100, 77]
    finally:
        plt.close()

def plot_energy_gap(temps, energy_gaps, with_spin=True):
    """
    Compare simulated energy gaps with the model Delta_0 * (1 - T/Tc)**(1/3).

    The model curve is evaluated at 100 points from 0 K to max(temps) and is
    zero at or above Tc. Only the real part of the simulated gaps is drawn.
    Output goes to spin_energy_gap_vs_temperature.png when ``with_spin`` is
    true, otherwise to energy_gap_vs_temperature.png. Errors are printed
    rather than raised, and the figure is always closed.
    """
    print("Generating energy gap plot...")
    try:
        plt.figure(figsize=(10, 6))

        simulated = np.real(energy_gaps)
        plt.plot(temps, simulated, 'o-', linewidth=2, label='Simulated Gap')

        sweep = np.linspace(0, max(temps), 100)
        model_gaps = [Delta_0 * (1 - temp/Tc)**(1/3) if temp < Tc else 0 for temp in sweep]
        plt.plot(sweep, model_gaps, '--', 
                label='Predicted Gap $\\Delta_0(1-T/T_c)^{1/3}$')

        plt.xlabel('Temperature (K)')
        plt.ylabel('Energy Gap (meV)')
        prefix = "Spin " if with_spin else ""
        plt.title(f'{prefix}Energy Gap vs Temperature')
        plt.grid(True, alpha=0.3)
        plt.legend()
        plt.tight_layout()

        out_name = 'spin_energy_gap_vs_temperature.png' if with_spin else 'energy_gap_vs_temperature.png'
        plt.savefig(out_name, dpi=300, bbox_inches='tight')
        print(f"Successfully saved {out_name}")
    except Exception as e:
        print(f"Error saving energy gap plot: {str(e)}")
    finally:
        plt.close()

def calculate_collective_enhancement(N_values, N0=9, tau0=100):
    """
    Compute and plot the collective coherence-time scaling tau0 * sqrt(N / N0).

    Parameters:
    - N_values: numbers of quantum dots (list or array)
    - N0: reference number of dots, at which the coherence time equals tau0
    - tau0: coherence time at N0 (plotted as ns)

    Returns:
    - (N_values as an ndarray, tau_values). The values are returned even if
      plotting fails. The plot is saved to collective_enhancement.png.
    """
    print("Generating collective enhancement plot...")
    # Computed once, before the try. asarray makes plain lists work (list / int
    # would raise), and the plotting fallback no longer recomputes a value that
    # could itself fail.
    N_values = np.asarray(N_values)
    tau_values = tau0 * np.sqrt(N_values / N0)
    try:
        plt.figure(figsize=(10, 6))
        plt.plot(N_values, tau_values, 'g-', linewidth=2)
        plt.xscale('log')
        plt.xlabel('Number of Quantum Dots (N)')
        plt.ylabel('Coherence Time (ns)')
        plt.title('Collective Enhancement of Coherence')
        plt.grid(True, alpha=0.3)
        plt.tight_layout()
        plt.savefig('collective_enhancement.png', dpi=300, bbox_inches='tight')
        print("Successfully saved collective_enhancement.png")
        return N_values, tau_values
    except Exception as e:
        print(f"Error saving collective enhancement plot: {str(e)}")
        return N_values, tau_values
    finally:
        plt.close()

def plot_spin_coherence_vs_magnetic_field(B_values, T=77):
    """
    Sweep the Zeeman field and plot coherence time and final total <S_z>.

    For each B in B_values a spinful Hamiltonian is built and the
    'singlet_pair' initial state is evolved for 10 time units (20 samples).
    The final real <S_z_total> and the tau_phi reported by
    simulate_spin_decoherence are recorded. The plot is saved to
    spin_coherence_vs_magnetic_field.png.

    The module-level B_field is set to each B during the sweep, in case other
    helpers read it, and is restored to its original value afterwards, even
    on error.

    Returns:
    - (B_values, coherence_times, spin_polarizations), or the placeholders
      (B_values, [100]*len, [0]*len) if anything fails
    """
    global B_field
    print("Generating magnetic field dependence plot...")
    original_B = B_field
    try:
        N_sites = N * N
        coherence_times = []
        spin_polarizations = []
        
        # Use fixed initial state
        psi0 = create_initial_state_with_spin(N_sites, config='singlet_pair', max_excitations=MAX_EXCITATIONS)
        
        for B in B_values:
            B_field = B
            
            # Create Hamiltonian at this B-field
            H = create_hamiltonian_with_spin(N, epsilon, t_base, U, B_field, J_ex, spin_orbit, T, max_excitations=MAX_EXCITATIONS)
            
            # Short evolution to probe the initial dynamics
            _, _, _, _, _, total_spin, tau_phi = simulate_spin_decoherence(
                H, psi0, T, tmax=10, num_points=20, include_spin=True, max_excitations=MAX_EXCITATIONS
            )
            
            # total_spin[0] is <S_z_total>; keep its value at the final time
            final_polarization = np.real(total_spin[0][-1])
            spin_polarizations.append(final_polarization)
            # NOTE(review): tau_phi comes from calculate_phase_memory(T) with T fixed
            # here, so this curve is flat unless that helper depends on B_field — confirm.
            coherence_times.append(tau_phi)
        
        # Plot results
        fig, ax1 = plt.subplots(figsize=(10, 6))
        
        # Coherence time on left y-axis
        ax1.plot(B_values, coherence_times, 'b-o', label='Coherence Time')
        ax1.set_xlabel('Magnetic Field (meV)')
        ax1.set_ylabel('Coherence Time (ns)', color='b')
        ax1.tick_params(axis='y', labelcolor='b')
        
        # Spin polarization on right y-axis
        ax2 = ax1.twinx()
        ax2.plot(B_values, spin_polarizations, 'r-^', label='Spin Polarization')
        ax2.set_ylabel('Spin-z Polarization', color='r')
        ax2.tick_params(axis='y', labelcolor='r')
        
        # Add legend
        lines1, labels1 = ax1.get_legend_handles_labels()
        lines2, labels2 = ax2.get_legend_handles_labels()
        ax1.legend(lines1 + lines2, labels1 + labels2, loc='upper center')
        
        plt.title('Spin Coherence vs Magnetic Field (T = {}K)'.format(T))
        plt.grid(True, alpha=0.3)
        plt.tight_layout()
        
        filename = 'spin_coherence_vs_magnetic_field.png'
        plt.savefig(filename, dpi=300, bbox_inches='tight')
        print(f"Successfully saved {filename}")
        return B_values, coherence_times, spin_polarizations
    except Exception as e:
        print(f"Error saving magnetic field dependence plot: {str(e)}")
        return B_values, [100] * len(B_values), [0] * len(B_values)
    finally:
        # Undo the sweep's global mutation so later analyses see the original field.
        B_field = original_B
        plt.close()

def analyze_spin_exchange_coupling(J_values, T=77):
    """
    Sweep the exchange coupling J and plot a singlet "fidelity" and the energy gap.

    For each J a spinful Hamiltonian is built, E1 - E0 is taken from its
    lowest sparse eigenvalues, and the 'singlet_pair' state is evolved for
    50 time units (50 samples). The fidelity metric is 1 - mean_t |<S_total>|.
    It equals 1 for zero total spin and can drop below 0, the lower limit of
    the plotted axis.

    The module-level J_ex is set to each J during the sweep, in case other
    helpers read it, and is restored to its original value afterwards, even
    on error.

    Returns:
    - (J_values, singlet_fidelities, energy_gaps), or zero-filled placeholders
      on error. The plot is saved to spin_exchange_coupling_analysis.png.
    """
    global J_ex
    print("Generating exchange coupling analysis plot...")
    original_J = J_ex
    try:
        N_sites = N * N
        singlet_fidelities = []
        energy_gaps = []
        tmax = 50  # evolution time per run
        
        # Use singlet initial state
        psi0 = create_initial_state_with_spin(N_sites, config='singlet_pair', max_excitations=MAX_EXCITATIONS)
        
        for J in J_values:
            J_ex = J
            
            # Create Hamiltonian with this exchange coupling
            H = create_hamiltonian_with_spin(N, epsilon, t_base, U, B_field, J_ex, spin_orbit, T, max_excitations=MAX_EXCITATIONS)
            
            # Gap between the two lowest eigenvalues returned by the sparse solver
            eigenvalues = H.eigenenergies(sparse=True, sort='low', eigvals=5)
            gap = eigenvalues[1] - eigenvalues[0] if len(eigenvalues) > 1 else 0
            energy_gaps.append(gap)
            
            _, _, _, _, _, total_spin, _ = simulate_spin_decoherence(
                H, psi0, T, tmax=tmax, num_points=50, include_spin=True, max_excitations=MAX_EXCITATIONS
            )
            
            # For a true singlet, the total spin vector is zero at all times
            total_spin_mag = np.sqrt(
                np.abs(total_spin[0])**2 + 
                np.abs(total_spin[1])**2 + 
                np.abs(total_spin[2])**2
            )
            avg_fidelity = 1 - np.mean(total_spin_mag)  # Higher value means closer to singlet
            singlet_fidelities.append(avg_fidelity)
        
        # Plot results
        fig, ax1 = plt.subplots(figsize=(10, 6))
        
        # Singlet fidelity on left y-axis
        ax1.plot(J_values, singlet_fidelities, 'b-o', label='Singlet Fidelity')
        ax1.set_xlabel('Exchange Coupling J (meV)')
        ax1.set_ylabel('Singlet Fidelity', color='b')
        ax1.tick_params(axis='y', labelcolor='b')
        ax1.set_ylim(0, 1)
        
        # Energy gap on right y-axis
        ax2 = ax1.twinx()
        ax2.plot(J_values, np.real(energy_gaps), 'r-^', label='Energy Gap')
        ax2.set_ylabel('Energy Gap (meV)', color='r')
        ax2.tick_params(axis='y', labelcolor='r')
        
        # Add legend
        lines1, labels1 = ax1.get_legend_handles_labels()
        lines2, labels2 = ax2.get_legend_handles_labels()
        ax1.legend(lines1 + lines2, labels1 + labels2, loc='upper center')
        
        plt.title('Exchange Coupling Effects (T = {}K)'.format(T))
        plt.grid(True, alpha=0.3)
        plt.tight_layout()
        
        filename = 'spin_exchange_coupling_analysis.png'
        plt.savefig(filename, dpi=300, bbox_inches='tight')
        print(f"Successfully saved {filename}")
        return J_values, singlet_fidelities, energy_gaps
    except Exception as e:
        print(f"Error saving exchange coupling analysis plot: {str(e)}")
        return J_values, [0] * len(J_values), [0] * len(J_values)
    finally:
        # Undo the sweep's global mutation so later analyses see the original coupling.
        J_ex = original_J
        plt.close()

def analyze_spin_orbit_effects(SO_values, T=77):
    """
    Sweep the spin-orbit coupling and plot spin-flip rate and coherence time.

    For each value a spinful Hamiltonian is built and the 'spin_up' state is
    evolved for 50 time units (100 samples). When |<S_z_total>| decays, the
    flip rate is 1 / tau from a fit of a * exp(-t / tau) to it. If the fit
    raises, the average decay (|S_z(0)| - |S_z(tmax)|) / tmax is used instead.
    When |S_z| does not decay, the rate is 0.

    The module-level spin_orbit is set to each value during the sweep, in case
    other helpers read it, and is restored afterwards, even on error.

    Returns:
    - (SO_values, spin_flip_rates, coherence_times), or the placeholders
      (SO_values, [0]*len, [100]*len) on error. The plot is saved to
      spin_orbit_coupling_analysis.png.
    """
    global spin_orbit
    print("Generating spin-orbit coupling analysis plot...")
    original_so = spin_orbit

    def exp_decay(t, a, tau):
        return a * np.exp(-t / tau)

    try:
        N_sites = N * N
        spin_flip_rates = []
        coherence_times = []
        tmax = 50  # evolution time per run
        
        # Start with spin-up state
        psi0 = create_initial_state_with_spin(N_sites, config='spin_up', max_excitations=MAX_EXCITATIONS)
        
        for so_coupling in SO_values:
            spin_orbit = so_coupling
            
            # Create Hamiltonian with this spin-orbit coupling
            H = create_hamiltonian_with_spin(N, epsilon, t_base, U, B_field, J_ex, spin_orbit, T, max_excitations=MAX_EXCITATIONS)
            
            times, _, _, _, _, total_spin, tau_phi = simulate_spin_decoherence(
                H, psi0, T, tmax=tmax, num_points=100, include_spin=True, max_excitations=MAX_EXCITATIONS
            )
            
            # total_spin[0] is <S_z_total>; for a spin-up start it should decay with SO coupling
            spin_z_values = np.real(total_spin[0])
            flip_rate = 0
            if len(spin_z_values) > 1:
                # Fit the magnitude so sign changes do not break the exponential model
                abs_spin_z = np.abs(spin_z_values)
                try:
                    # Only fit if there's decay
                    if abs_spin_z[0] > abs_spin_z[-1]:
                        popt, _ = curve_fit(exp_decay, times, abs_spin_z, p0=[abs_spin_z[0], 10])
                        flip_rate = 1 / popt[1] if popt[1] > 0 else 0
                except Exception:
                    # If fitting fails, use simplified metric
                    flip_rate = (abs_spin_z[0] - abs_spin_z[-1]) / tmax if abs_spin_z[0] > 0 else 0
                
            spin_flip_rates.append(flip_rate)
            coherence_times.append(tau_phi)
        
        # Plot results
        fig, ax1 = plt.subplots(figsize=(10, 6))
        
        # Spin-flip rate on left y-axis
        ax1.plot(SO_values, spin_flip_rates, 'b-o', label='Spin-Flip Rate')
        ax1.set_xlabel('Spin-Orbit Coupling (meV)')
        ax1.set_ylabel('Spin-Flip Rate (ns$^{-1}$)', color='b')
        ax1.tick_params(axis='y', labelcolor='b')
        
        # Coherence time on right y-axis
        ax2 = ax1.twinx()
        ax2.plot(SO_values, coherence_times, 'r-^', label='Coherence Time')
        ax2.set_ylabel('Coherence Time (ns)', color='r')
        ax2.tick_params(axis='y', labelcolor='r')
        
        # Add legend
        lines1, labels1 = ax1.get_legend_handles_labels()
        lines2, labels2 = ax2.get_legend_handles_labels()
        ax1.legend(lines1 + lines2, labels1 + labels2, loc='upper center')
        
        plt.title('Spin-Orbit Coupling Effects (T = {}K)'.format(T))
        plt.grid(True, alpha=0.3)
        plt.tight_layout()
        
        filename = 'spin_orbit_coupling_analysis.png'
        plt.savefig(filename, dpi=300, bbox_inches='tight')
        print(f"Successfully saved {filename}")
        return SO_values, spin_flip_rates, coherence_times
    except Exception as e:
        print(f"Error saving spin-orbit coupling analysis plot: {str(e)}")
        return SO_values, [0] * len(SO_values), [100] * len(SO_values)
    finally:
        # Undo the sweep's global mutation so later analyses see the original coupling.
        spin_orbit = original_so
        plt.close()

def analyze_magnetic_field_dynamics(B_values=(0.05, 0.2, 0.5), T=77, tmax=50):
    """
    Plot spin dynamics of the 'spin_up' initial state for several Zeeman fields.

    Each field gets one row of three panels:
    - site-0 spin components
    - total spin components
    - |<S_total>|

    When site-0 <S_x> has at least two peaks, 1 / (mean peak spacing) is
    annotated as a Larmor frequency. The GHz label assumes times are in ns.

    The module-level B_field is set to each B during the sweep, in case other
    helpers read it, and is restored afterwards, even on error. The plot is
    saved to magnetic_field_spin_dynamics.png. Errors are printed rather than
    raised.
    """
    global B_field
    print("Generating magnetic field dynamics plot...")
    original_B = B_field
    try:
        N_sites = N * N
        
        # NOTE(review): if the field acts along z, a 'spin_up' state may show little
        # precession — confirm this is the intended initial state.
        psi0 = create_initial_state_with_spin(N_sites, config='spin_up', max_excitations=MAX_EXCITATIONS)
        
        # squeeze=False keeps axs 2-D even for a single field, so axs[i, k] always works.
        fig, axs = plt.subplots(len(B_values), 3, figsize=(16, 4*len(B_values)), squeeze=False)
        
        for i, B in enumerate(B_values):
            B_field = B
            
            # Create Hamiltonian with this B field
            H = create_hamiltonian_with_spin(N, epsilon, t_base, U, B_field, J_ex, spin_orbit, T, max_excitations=MAX_EXCITATIONS)
            
            # Simulate time evolution
            times, _, spin_z_data, spin_x_data, spin_y_data, total_spin, _ = simulate_spin_decoherence(
                H, psi0, T, tmax=tmax, num_points=200, include_spin=True, max_excitations=MAX_EXCITATIONS
            )
            
            # Plot spin components for site 0
            axs[i, 0].plot(times, np.real(spin_x_data[0]), 'r-', label='Spin-X')
            axs[i, 0].plot(times, np.real(spin_y_data[0]), 'g-', label='Spin-Y')
            axs[i, 0].plot(times, np.real(spin_z_data[0]), 'b-', label='Spin-Z')
            axs[i, 0].set_title(f'Site 0 Spin Components (B = {B} meV)')
            axs[i, 0].set_xlabel('Time (ns)')
            axs[i, 0].set_ylabel('Spin Expectation')
            axs[i, 0].legend()
            axs[i, 0].grid(True, alpha=0.3)
            
            # total_spin rows are ordered [S_z_total, S_x_total, S_y_total] by
            # simulate_spin_decoherence, so pick rows explicitly for each label.
            axs[i, 1].plot(times, np.real(total_spin[1]), 'r-', label='Total Spin-X')
            axs[i, 1].plot(times, np.real(total_spin[2]), 'g-', label='Total Spin-Y')
            axs[i, 1].plot(times, np.real(total_spin[0]), 'b-', label='Total Spin-Z')
            axs[i, 1].set_title(f'Total Spin Components (B = {B} meV)')
            axs[i, 1].set_xlabel('Time (ns)')
            axs[i, 1].set_ylabel('Spin Expectation')
            axs[i, 1].legend()
            axs[i, 1].grid(True, alpha=0.3)
            
            # Plot spin magnitude
            total_spin_mag = np.sqrt(
                np.abs(total_spin[0])**2 + 
                np.abs(total_spin[1])**2 + 
                np.abs(total_spin[2])**2
            )
            axs[i, 2].plot(times, np.real(total_spin_mag), 'k-', label='Spin Magnitude')
            axs[i, 2].set_title(f'Total Spin Magnitude (B = {B} meV)')
            axs[i, 2].set_xlabel('Time (ns)')
            axs[i, 2].set_ylabel('|S|')
            axs[i, 2].grid(True, alpha=0.3)
            
            # Estimate the Larmor frequency from the spacing of site-0 <S_x> peaks
            if len(times) > 10:
                try:
                    from scipy.signal import find_peaks
                    peaks, _ = find_peaks(np.real(spin_x_data[0]))
                    
                    if len(peaks) >= 2:
                        avg_period = np.mean(np.diff(times[peaks]))
                        larmor_freq = 1 / avg_period if avg_period > 0 else 0
                        axs[i, 2].text(0.05, 0.9, f'Larmor freq: {larmor_freq:.3f} GHz', 
                                      transform=axs[i, 2].transAxes, bbox=dict(facecolor='white', alpha=0.7))
                except Exception as e:
                    print(f"Could not calculate Larmor frequency: {str(e)}")
        
        plt.tight_layout()
        
        filename = 'magnetic_field_spin_dynamics.png'
        plt.savefig(filename, dpi=300, bbox_inches='tight')
        print(f"Successfully saved {filename}")
    except Exception as e:
        print(f"Error saving magnetic field dynamics plot: {str(e)}")
    finally:
        # Undo the sweep's global mutation so later analyses see the original field.
        B_field = original_B
        plt.close()

def analyze_spin_charge_interaction(T=77, tmax=100):
    """Analyze how spin and charge degrees of freedom interact.

    Runs the spin model for three coupling scenarios (spin-orbit, exchange and
    Zeeman values from the ``scenarios`` table below). For each scenario it plots:
    - the charge and spin-z dynamics of the first (up to four) sites;
    - their pointwise product Re(<n_i> * <Sz_i>), which is used as a simple
      stand-in for a charge-spin correlation.

    The figure is saved as 'spin_charge_interaction.png' in the current working
    directory. Errors are caught and printed rather than raised.

    The module-level ``spin_orbit``, ``J_ex`` and ``B_field`` are overridden for
    each scenario. They are always restored on exit, including when an error
    or interrupt occurs.

    Parameters:
    - T: Temperature (K) passed to the Hamiltonian and the simulation
    - tmax: Simulation end time (ns)
    """
    global spin_orbit, J_ex, B_field

    print("Generating spin-charge interaction plot...")
    # Snapshot before any override so the finally clause can always restore them.
    original_so, original_J, original_B = spin_orbit, J_ex, B_field
    try:
        N_sites = N * N
        n_shown = min(4, N_sites)  # Show first few sites only

        # Three scenarios
        scenarios = [
            {"name": "Base (Low Coupling)", "SO": 0.01, "J": 0.1, "B": 0.05},
            {"name": "Medium Coupling", "SO": 0.1, "J": 0.2, "B": 0.1},
            {"name": "Strong Coupling", "SO": 0.3, "J": 0.5, "B": 0.2}
        ]

        fig, axs = plt.subplots(len(scenarios), 2, figsize=(14, 4*len(scenarios)))

        for i, scenario in enumerate(scenarios):
            # Override the module-level couplings for this scenario
            spin_orbit = scenario["SO"]
            J_ex = scenario["J"]
            B_field = scenario["B"]

            psi0 = create_initial_state_with_spin(N_sites, config='singlet_pair', max_excitations=MAX_EXCITATIONS)
            H = create_hamiltonian_with_spin(N, epsilon, t_base, U, B_field, J_ex, spin_orbit, T, max_excitations=MAX_EXCITATIONS)

            times, charge_data, spin_z_data, _, _, _, _ = simulate_spin_decoherence(
                H, psi0, T, tmax=tmax, num_points=200, include_spin=True, max_excitations=MAX_EXCITATIONS
            )

            dyn_ax, corr_ax = axs[i, 0], axs[i, 1]

            for site in range(n_shown):
                charge = np.asarray(charge_data[site])
                spin_z = np.asarray(spin_z_data[site])

                # Charge and spin dynamics
                dyn_ax.plot(times, np.real(charge),
                            label=f'Charge Site {site}', linestyle='-', alpha=0.7)
                dyn_ax.plot(times, np.real(spin_z),
                            label=f'Spin-z Site {site}', linestyle='--', alpha=0.7)

                # Product of expectation values; a Pearson correlation would be
                # more rigorous, but the product is kept as a simple indicator.
                corr_ax.plot(times, np.real(charge * spin_z),
                             label=f'Site {site} Correlation', alpha=0.7)

            dyn_ax.set_title(f'Charge & Spin Dynamics: {scenario["name"]}')
            dyn_ax.set_xlabel('Time (ns)')
            dyn_ax.set_ylabel('Expectation Value')
            dyn_ax.grid(True, alpha=0.3)
            dyn_ax.legend(ncol=2, fontsize=8)

            corr_ax.set_title(f'Charge-Spin Correlation: {scenario["name"]}')
            corr_ax.set_xlabel('Time (ns)')
            corr_ax.set_ylabel('Correlation')
            corr_ax.grid(True, alpha=0.3)
            corr_ax.legend()

            # Add parameter information
            param_text = f"SO={scenario['SO']}, J={scenario['J']}, B={scenario['B']}"
            corr_ax.text(0.5, 0.95, param_text, transform=corr_ax.transAxes,
                         horizontalalignment='center', bbox=dict(facecolor='white', alpha=0.7))

        plt.tight_layout()

        filename = 'spin_charge_interaction.png'
        plt.savefig(filename, dpi=300, bbox_inches='tight')
        print(f"Successfully saved {filename}")
    except Exception as e:
        print(f"Error saving spin-charge interaction plot: {str(e)}")
    finally:
        # Restore original parameters on every exit path
        spin_orbit, J_ex, B_field = original_so, original_J, original_B
        plt.close()

def compare_spinless_and_spin_models(temps=(77, 300), tmax=50):
    """Compare the spinless and spin-inclusive models at different temperatures.

    One figure column is produced per temperature in ``temps``:
    - top row: site-0 occupation of the spinless model overlaid with the site-0
      charge of the spin model;
    - bottom row: site-0 spin-z and component 0 of the spin model's total spin.

    The decoherence times of both models are printed per temperature. The
    figure is saved as 'spinless_vs_spin_model_comparison.png' in the current
    working directory. Errors are caught and printed rather than raised.

    Parameters:
    - temps: Iterable of temperatures (K); any length >= 1
    - tmax: Simulation end time (ns) passed to simulate_spin_decoherence
    """
    print("Generating spinless vs spin model comparison plot...")
    try:
        temps = list(temps)
        if not temps:
            print("No temperatures given; skipping spinless vs spin model comparison")
            return

        N_sites = N * N

        # Initialize states
        psi0_spinless = create_initial_state_spinless(N_sites, config='superposition', max_excitations=MAX_EXCITATIONS)
        psi0_spin = create_initial_state_with_spin(N_sites, config='singlet_pair', max_excitations=MAX_EXCITATIONS)

        # Row 0: charge dynamics, row 1: spin dynamics; one column per temperature.
        n_temps = len(temps)
        fig, axs = plt.subplots(2, n_temps, figsize=(7 * n_temps, 12), squeeze=False)

        for col, T in enumerate(temps):
            charge_ax, spin_ax = axs[0, col], axs[1, col]

            # Spinless model (keeps its own time grid)
            H_spinless = create_hamiltonian_spinless(N, epsilon, t_base, U, T, max_excitations=MAX_EXCITATIONS)
            times_spinless, occ_spinless, tau_phi_spinless = simulate_spin_decoherence(
                H_spinless, psi0_spinless, T, tmax=tmax, include_spin=False, max_excitations=MAX_EXCITATIONS
            )

            # Spin model
            H_spin = create_hamiltonian_with_spin(N, epsilon, t_base, U, B_field, J_ex, spin_orbit, T, max_excitations=MAX_EXCITATIONS)
            times_spin, charge_data, spin_z_data, _, _, total_spin, tau_phi_spin = simulate_spin_decoherence(
                H_spin, psi0_spin, T, tmax=tmax, include_spin=True, max_excitations=MAX_EXCITATIONS
            )

            # Charge: spinless occupation vs spin-model charge (site 0)
            charge_ax.plot(times_spinless, np.real(occ_spinless[0]), 'b-', label='Spinless Model (Site 0)')
            charge_ax.plot(times_spin, np.real(charge_data[0]), 'r--', label='Spin Model (Site 0 Charge)')

            # Spin: site-0 spin-z and total spin
            spin_ax.plot(times_spin, np.real(spin_z_data[0]), 'g-', label='Spin-z (Site 0)')
            # NOTE(review): component 0 is labelled z here, but other plots in this
            # module label total_spin[2] as z (total_spin[j] <-> ['x', 'y', 'z'][j]);
            # confirm the ordering returned by simulate_spin_decoherence.
            spin_ax.plot(times_spin, np.real(total_spin[0]), 'm--', label='Total Spin-z')

            charge_ax.set_title(f'Charge Dynamics at T = {T}K')
            charge_ax.set_xlabel('Time (ns)')
            charge_ax.set_ylabel('Occupation Probability')
            charge_ax.legend()
            charge_ax.grid(True, alpha=0.3)

            spin_ax.set_title(f'Spin Dynamics at T = {T}K')
            spin_ax.set_xlabel('Time (ns)')
            spin_ax.set_ylabel('Spin Expectation Value')
            spin_ax.legend()
            spin_ax.grid(True, alpha=0.3)

            # Print coherence times for comparison
            print(f"T={T}K: Spinless τ_φ = {tau_phi_spinless:.2f} ns, Spin τ_φ = {tau_phi_spin:.2f} ns")

        plt.tight_layout()

        filename = 'spinless_vs_spin_model_comparison.png'
        plt.savefig(filename, dpi=300, bbox_inches='tight')
        print(f"Successfully saved {filename}")
    except Exception as e:
        print(f"Error saving model comparison plot: {str(e)}")
    finally:
        plt.close()

def plot_spin_entanglement(temps=(77, 150, 225, 300), tmax=50):
    """Plot a heuristic spin-entanglement indicator over time at several temperatures.

    Starts from the 'singlet_pair' initial state. For each temperature it
    evolves the state with the spin Hamiltonian and plots

        E(t) = clip((-Re(<Sz_0><Sz_1>) + 1 - |<S_total>|) / 2, 0, 1)

    together with |<Sz>| on sites 0 and 1. |<S_total>| is the norm of the
    three total-spin components.

    E(t) is built only from products of single-site expectation values and the
    total-spin norm. It is an ad-hoc indicator, not a formal entanglement
    measure.

    Panels are laid out two per row; the default four temperatures give a 2x2
    grid. The figure is saved as 'spin_entanglement_vs_temperature.png' in the
    current working directory. Errors are caught and printed rather than raised.

    Parameters:
    - temps: Iterable of temperatures (K)
    - tmax: Simulation end time (ns)
    """
    print("Generating spin entanglement plot...")
    try:
        N_sites = N * N
        temps = list(temps)

        # Two panels per row; enough rows for every temperature.
        n_cols = 2
        n_rows = max(1, -(-len(temps) // n_cols))
        fig, axs = plt.subplots(n_rows, n_cols, figsize=(14, 6 * n_rows), squeeze=False)
        axs = axs.flatten()

        # Initialize with spin-entangled state (singlet)
        psi0 = create_initial_state_with_spin(N_sites, config='singlet_pair', max_excitations=MAX_EXCITATIONS)

        for i, T in enumerate(temps):
            H = create_hamiltonian_with_spin(N, epsilon, t_base, U, B_field, J_ex, spin_orbit, T, max_excitations=MAX_EXCITATIONS)

            times, _, spin_z_data, _, _, total_spin, _ = simulate_spin_decoherence(
                H, psi0, T, tmax=tmax, include_spin=True, max_excitations=MAX_EXCITATIONS
            )

            sz0 = np.asarray(spin_z_data[0])
            sz1 = np.asarray(spin_z_data[1])

            # Positive when the two site spins point in opposite directions
            correlation = -np.real(sz0 * sz1)

            # A singlet has zero total spin, so a small norm counts towards E(t)
            total_spin_mag = np.sqrt(
                np.abs(np.asarray(total_spin[0]))**2 +
                np.abs(np.asarray(total_spin[1]))**2 +
                np.abs(np.asarray(total_spin[2]))**2
            )

            entanglement_measure = np.clip((correlation + (1 - total_spin_mag)) / 2, 0, 1)

            ax = axs[i]
            ax.plot(times, entanglement_measure, 'b-', linewidth=2, label='Entanglement Measure')

            # Add spins for reference
            ax.plot(times, np.abs(np.real(sz0)), 'r--', label='|Spin-z Site 0|', alpha=0.5)
            ax.plot(times, np.abs(np.real(sz1)), 'g--', label='|Spin-z Site 1|', alpha=0.5)

            ax.set_title(f'Spin Entanglement at T = {T}K')
            ax.set_xlabel('Time (ns)')
            ax.set_ylabel('Entanglement Measure')
            ax.set_ylim(0, 1)
            ax.grid(True, alpha=0.3)
            if i == 0:
                ax.legend(loc='lower right')

        # Hide panels left over when the temperatures do not fill the grid
        for ax in axs[len(temps):]:
            ax.axis('off')

        plt.tight_layout()

        filename = 'spin_entanglement_vs_temperature.png'
        plt.savefig(filename, dpi=300, bbox_inches='tight')
        print(f"Successfully saved {filename}")
    except Exception as e:
        print(f"Error saving spin entanglement plot: {str(e)}")
    finally:
        plt.close()

def main_with_spin_effects():
    """Run the complete spin-inclusive simulation pipeline for the current N x N lattice.

    Side effects:
    - Creates ``outputs/`` beside this script and changes the working directory
      into it. If creation fails, it uses the script directory instead.
    - The working directory is not restored afterwards, so every relative
      ``savefig`` here and in later calls writes into that directory.
    - Computes the lowest 0 K levels and runs the spin and spinless
      temperature sweeps.
    - Runs the magnetic-field, exchange, spin-orbit, spin-charge, model
      comparison, entanglement and disorder analyses.
    - Saves two spin-vs-spinless comparison figures and a six-panel summary
      figure, then prints a textual summary and the total runtime.
    """
    here = os.path.dirname(os.path.abspath(__file__))
    print(f"Running from directory: {here}")

    target_dir = os.path.join(here, 'outputs')
    try:
        os.makedirs(target_dir, exist_ok=True)
        print(f"Output will be saved to: {target_dir}")
    except Exception as err:
        print(f"Warning: Could not create outputs directory: {str(err)}")
        target_dir = here
        print(f"Falling back to script directory for outputs: {target_dir}")

    # Figures below are saved with bare filenames, i.e. into target_dir.
    os.chdir(target_dir)

    print(f"Starting simulation of {N}x{N} quantum dot lattice with spin effects")
    print(f"Hilbert space truncated to max {MAX_EXCITATIONS} excitations")

    t_start = time.time()

    # --- 1. Lowest levels of the spin Hamiltonian at T = 0 K ---------------
    H_zero = create_hamiltonian_with_spin(N, epsilon, t_base, U, B_field, J_ex, spin_orbit, 0,
                                          max_excitations=MAX_EXCITATIONS)
    levels = H_zero.eigenenergies(sparse=True, sort='low', eigvals=5)
    print(f"Lowest 5 energy levels with spin (meV): {levels}")
    gap_zero = levels[1] - levels[0]
    print(f"Energy gap with spin at 0K: {gap_zero:.4f} meV")

    # --- 2./3. Temperature sweeps with and without spin --------------------
    sweep = [77, 150, 225, 300]
    (temps, times, charge_data_list, spin_z_data_list, _sx_list, _sy_list,
     total_spin_list, tau_phi_values, energy_gaps) = analyze_temperature_dependence_with_spin(sweep)
    (temps_spinless, _times_spinless, _coherence_spinless,
     tau_phi_values_spinless, energy_gaps_spinless) = analyze_temperature_dependence_spinless(sweep)

    # --- 4. Per-topic visualizations ---------------------------------------
    plot_quantum_dot_lattice()
    plot_charge_occupations(temps, times, charge_data_list)
    plot_spin_occupations(temps, times, spin_z_data_list)
    plot_total_spin(temps, times, total_spin_list)

    def save_spin_vs_spinless(spin_y, plain_y, ylabel, title, filename, *, threshold_line=False, **line_kw):
        """Save a 'With Spin' / 'Without Spin' curve pair plotted against temperature."""
        plt.figure(figsize=(10, 6))
        plt.plot(temps, spin_y, 'o-', linewidth=2, label='With Spin', **line_kw)
        plt.plot(temps_spinless, plain_y, 's--', linewidth=2, label='Without Spin', **line_kw)
        plt.xlabel('Temperature (K)')
        plt.ylabel(ylabel)
        plt.title(title)
        if threshold_line:
            plt.axhline(y=50, color='k', linestyle='--', label='50 ns threshold')
        plt.grid(True, alpha=0.3)
        plt.legend()
        plt.tight_layout()
        plt.savefig(filename, dpi=300, bbox_inches='tight')
        print(f"Successfully saved {filename}")
        plt.close()

    # --- 5. Decoherence time comparison ------------------------------------
    save_spin_vs_spinless(tau_phi_values, tau_phi_values_spinless,
                          'Decoherence Time $τ_φ$ (ns)',
                          'Decoherence Time vs Temperature: Spin vs Spinless',
                          'decoherence_comparison_spin_vs_spinless.png',
                          threshold_line=True, markersize=8)

    # --- 6. Energy gap comparison ------------------------------------------
    save_spin_vs_spinless(np.real(energy_gaps), np.real(energy_gaps_spinless),
                          'Energy Gap (meV)',
                          'Energy Gap vs Temperature: Spin vs Spinless',
                          'energy_gap_comparison_spin_vs_spinless.png')

    # --- 7.-10. Parameter scans --------------------------------------------
    B_values, _coherence_times, spin_polarizations = plot_spin_coherence_vs_magnetic_field(
        np.linspace(0.05, 0.5, 10))
    analyze_magnetic_field_dynamics([0.05, 0.2, 0.5])
    J_values, singlet_fidelities, _exchange_gaps = analyze_spin_exchange_coupling(
        np.linspace(0.05, 0.5, 10))
    SO_values, spin_flip_rates, _so_coherence = analyze_spin_orbit_effects(
        np.linspace(0.01, 0.2, 10))

    # --- 11.-14. Further analyses (each saves its own figure) --------------
    analyze_spin_charge_interaction()
    compare_spinless_and_spin_models()
    plot_spin_entanglement()
    analyze_disorder_effects(disorder_levels=[0.0, 0.05, 0.1, 0.2], T=77, num_samples=5)

    # --- 15. Six-panel summary figure --------------------------------------
    plt.figure(figsize=(15, 12))

    def finish_panel(xlabel, ylabel, title, with_legend=True):
        """Apply axis labels, title, optional legend and grid to the current subplot."""
        plt.xlabel(xlabel)
        plt.ylabel(ylabel)
        plt.title(title)
        if with_legend:
            plt.legend()
        plt.grid(True, alpha=0.3)

    plt.subplot(2, 3, 1)
    plt.plot(temps, tau_phi_values, 'o-', linewidth=2, label='With Spin')
    plt.plot(temps_spinless, tau_phi_values_spinless, 's--', linewidth=2, label='Without Spin')
    finish_panel('Temperature (K)', 'Decoherence Time (ns)', 'Decoherence Time Comparison')

    plt.subplot(2, 3, 2)
    plt.plot(temps, np.real(energy_gaps), 'o-', linewidth=2, label='With Spin')
    plt.plot(temps_spinless, np.real(energy_gaps_spinless), 's--', linewidth=2, label='Without Spin')
    finish_panel('Temperature (K)', 'Energy Gap (meV)', 'Energy Gap Comparison')

    scan_panels = (
        (B_values, spin_polarizations, 'purple',
         'Magnetic Field (meV)', 'Spin-z Polarization', 'Magnetic Field Effects'),
        (J_values, singlet_fidelities, 'green',
         'Exchange Coupling (meV)', 'Singlet Fidelity', 'Exchange Coupling Effects'),
        (SO_values, spin_flip_rates, 'red',
         'Spin-Orbit Coupling (meV)', 'Spin-Flip Rate (ns⁻¹)', 'Spin-Orbit Effects'),
    )
    for slot, (xs, ys, colour, xlabel, ylabel, title) in enumerate(scan_panels, start=3):
        plt.subplot(2, 3, slot)
        plt.plot(xs, ys, 'o-', color=colour, linewidth=2)
        finish_panel(xlabel, ylabel, title, with_legend=False)

    plt.subplot(2, 3, 6)
    plt.plot(times, np.real(charge_data_list[0][0]), 'b-', label='Charge (Site 0)')
    plt.plot(times, np.real(spin_z_data_list[0][0]), 'r--', label='Spin-z (Site 0)')
    finish_panel('Time (ns)', 'Expectation Value', 'Charge vs Spin Dynamics (77K)')

    plt.tight_layout()
    plt.savefig('spin_effects_comprehensive_summary.png', dpi=300, bbox_inches='tight')
    print("Successfully saved spin_effects_comprehensive_summary.png")
    plt.close()

    # --- Textual summary ---------------------------------------------------
    print("\nSPIN SIMULATION RESULTS SUMMARY")
    print("===============================")
    print(f"Energy gap with spin at 0K: {gap_zero:.4f} meV")
    print(f"Room temperature coherence time (with spin): {tau_phi_values[-1]:.2f} ns")
    print(f"77K coherence time (with spin): {tau_phi_values[0]:.2f} ns")
    print(f"Coherence time ratio (77K/300K): {tau_phi_values[0]/tau_phi_values[-1]:.2f}")
    print(f"Magnetic field effect on polarization (B=0.5): {spin_polarizations[-1]:.4f}")
    print(f"Exchange coupling effect on singlet fidelity (J=0.5): {singlet_fidelities[-1]:.4f}")
    print(f"Spin-orbit coupling effect on spin-flip rate (SO=0.2): {spin_flip_rates[-1]:.4f} ns⁻¹")

    print("\nComparison with spinless model:")
    print(f"Room temperature coherence time difference: {tau_phi_values[-1] - tau_phi_values_spinless[-1]:.2f} ns")
    print(f"Energy gap difference at 77K: {np.real(energy_gaps[0] - energy_gaps_spinless[0]):.4f} meV")

    print("\nAdditional quantum effects demonstrated with spin:")
    for effect in ("Spin-dependent quantum tunneling",
                   "Spin-orbit coupling effects",
                   "Exchange interaction and spin entanglement",
                   "Zeeman splitting in magnetic fields",
                   "Spin coherence phenomena"):
        print(f"- {effect}")

    elapsed = time.time() - t_start
    print("\nSpin simulation complete! All results saved as PNG files.")
    print(f"Total runtime: {elapsed/60:.1f} minutes")

def calculate_phase_memory_alternative(T, model='standard', tau0=100, T0=77):
    """Return a phenomenological phase-memory (coherence) time at temperature T.

    Available models (result in the units of ``tau0``, nominally ns):
    - 'standard':    tau(T) = tau0 * [1 + (T0/T)^(2/3)]
    - 'exponential': tau(T) = tau0 * exp(T0/T - 1)
    - 'power_law':   tau(T) = tau0 * (T0/T)            (exponent alpha = 1)
    - 'ohmic':       tau(T) = tau0 / [1 + (T/T0)^2]    (Ohmic-spectral-density inspired)
    Any other ``model`` value falls back to 'standard'.

    Parameters:
    - T: Temperature (K). Must be non-zero for every model except 'ohmic',
      because those formulas divide by T.
    - model: Name of the decoherence model (see above)
    - tau0: Base coherence time at the reference temperature (ns)
    - T0: Reference temperature (K)

    Returns: Coherence time (ns)
    """
    def standard():
        return tau0 * (1 + (T0/T)**(2/3))

    # Checked in order with ==, exactly like an if/elif chain; each formula is
    # only evaluated when selected.
    formulas = (
        ('standard', standard),
        ('exponential', lambda: tau0 * np.exp(T0/T - 1)),
        ('power_law', lambda: tau0 * (T0/T)),
        ('ohmic', lambda: tau0 / (1 + (T/T0)**2)),
    )
    for name, formula in formulas:
        if model == name:
            return formula()

    # Unknown model name: default to the standard model
    return standard()
        
def parameter_sensitivity_analysis(base_temp=77, parameter_ranges=None):
    """Perform a one-at-a-time sensitivity analysis on model parameters.

    For each parameter in ``parameter_ranges``, the matching module-level
    global is set to each test value in turn; all other parameters keep their
    current values. For every test value:
    - the spin Hamiltonian is rebuilt at ``base_temp``;
    - the gap between its two lowest eigenvalues is recorded;
    - a short simulation (tmax=10, 20 points) gives the coherence time.

    The global is restored afterwards, even if a step raises.

    Coherence time is plotted against each test value divided by the
    parameter's original value, three panels per row. If the original value is
    zero, the raw test values are plotted instead. The figure is saved as
    'parameter_sensitivity_analysis.png' in the current working directory.

    Parameters:
    - base_temp: Temperature for analysis (K)
    - parameter_ranges: Dict mapping parameter names to absolute test values.
      Keys must be a subset of 'epsilon', 'U', 't_base', 'B_field', 'J_ex'
      and 'spin_orbit'. Defaults to five values for each of them.

    Returns: Dict keyed by parameter name. Each entry holds the lists
    'values', 'coherence' and 'energy_gap'.

    Raises: ValueError if ``parameter_ranges`` names an unsupported parameter.
    """
    # Current values of the tunable module-level parameters. Each sweep
    # overrides one of them through globals() and then puts it back.
    original_params = {
        'epsilon': epsilon,
        'U': U,
        't_base': t_base,
        'B_field': B_field,
        'J_ex': J_ex,
        'spin_orbit': spin_orbit
    }

    if parameter_ranges is None:
        # Default test values (absolute values in the same units as the globals)
        parameter_ranges = {
            'epsilon': np.linspace(10.0, 20.0, 5),  # Single-particle energy
            'U': np.linspace(1.0, 3.0, 5),  # Coulomb interaction
            't_base': np.linspace(0.5, 1.5, 5),  # Tunneling rate
            'B_field': np.linspace(0.05, 0.2, 5),  # Magnetic field
            'J_ex': np.linspace(0.1, 0.3, 5),  # Exchange coupling
            'spin_orbit': np.linspace(0.01, 0.1, 5)  # Spin-orbit coupling
        }

    # Reject unknown names up front: overriding a non-parameter global would
    # have no effect on the Hamiltonian and could not be restored.
    unsupported = sorted(set(parameter_ranges) - set(original_params))
    if unsupported:
        raise ValueError(
            f"Unsupported parameter(s) for sensitivity analysis: {unsupported}; "
            f"expected names from {sorted(original_params)}"
        )

    results = {param: {'values': [], 'coherence': [], 'energy_gap': []}
               for param in parameter_ranges}

    # Create initial state once
    N_sites = N * N
    psi0 = create_initial_state_with_spin(N_sites, config='singlet_pair', max_excitations=MAX_EXCITATIONS)

    module_globals = globals()
    for param_name, param_values in parameter_ranges.items():
        print(f"Testing sensitivity to {param_name}...")
        try:
            for val in param_values:
                # The Hamiltonian arguments below are read from module globals,
                # so overriding the global is what applies this test value.
                module_globals[param_name] = val

                H = create_hamiltonian_with_spin(N, epsilon, t_base, U, B_field, J_ex, spin_orbit,
                                                 base_temp, max_excitations=MAX_EXCITATIONS)

                eigenvalues = H.eigenenergies(sparse=True, sort='low', eigvals=5)
                gap = eigenvalues[1] - eigenvalues[0] if len(eigenvalues) > 1 else 0

                # Short simulation to get the coherence time
                _, _, _, _, _, _, tau_phi = simulate_spin_decoherence(
                    H, psi0, base_temp, tmax=10, num_points=20,
                    include_spin=True, max_excitations=MAX_EXCITATIONS
                )

                results[param_name]['values'].append(val)
                results[param_name]['coherence'].append(tau_phi)
                results[param_name]['energy_gap'].append(np.real(gap))
        finally:
            # Restore the original value even if a simulation step raised
            module_globals[param_name] = original_params[param_name]

    # Three panels per row, as many rows as needed
    n_cols = 3
    n_rows = max(1, -(-len(results) // n_cols))
    plt.figure(figsize=(15, 5 * n_rows))
    try:
        for i, (param_name, data) in enumerate(results.items()):
            reference = original_params[param_name]
            plt.subplot(n_rows, n_cols, i + 1)
            if reference:
                norm_values = np.asarray(data['values'], dtype=float) / reference
                plt.plot(norm_values, data['coherence'], 'o-', linewidth=2)
                plt.axvline(x=1.0, color='r', linestyle='--', alpha=0.5)  # Mark original value
                plt.xlabel(f'Normalized {param_name}')
            else:
                # Original value is zero: normalising would divide by zero
                plt.plot(data['values'], data['coherence'], 'o-', linewidth=2)
                plt.axvline(x=0.0, color='r', linestyle='--', alpha=0.5)  # Mark original value
                plt.xlabel(param_name)
            plt.ylabel('Coherence Time (ns)')
            plt.title(f'Sensitivity to {param_name}')
            plt.grid(True, alpha=0.3)

        plt.tight_layout()
        plt.savefig('parameter_sensitivity_analysis.png', dpi=300, bbox_inches='tight')
        print("Successfully saved parameter_sensitivity_analysis.png")
    finally:
        plt.close()

    return results
    
def run_3x3_simulation():
    """Run a more hardware-friendly 3x3 lattice simulation with spin effects.

    Temporarily sets the module-level ``N`` to 3 and ``MAX_EXCITATIONS`` to 2.
    It then changes into the ``outputs`` directory next to this script, or into
    the script directory if that cannot be created.

    It reports the lowest 0 K levels, simulates 77 K and 300 K, and saves
    '3x3_charge_dynamics.png', '3x3_spin_dynamics.png' and '3x3_summary.png'.
    The globals and the working directory are restored on exit, including
    when an exception propagates.

    Returns: Dict keyed by temperature (77 and 300). Each value holds
    'times', 'charge_data', 'spin_z_data', 'spin_x_data', 'spin_y_data',
    'total_spin', 'tau_phi' and 'gap'. 'gap' is the difference between the two
    lowest eigenvalues of the Hamiltonian at that temperature.
    """
    global N, MAX_EXCITATIONS

    # Save original parameters and directory so the finally clause can restore them
    original_N = N
    original_MAX_EXCITATIONS = MAX_EXCITATIONS
    original_dir = os.getcwd()

    # Set parameters for 3x3 simulation
    N = 3  # 3x3 lattice
    MAX_EXCITATIONS = 2  # Reduce max excitations to keep Hilbert space manageable

    try:
        print(f"Starting 3x3 lattice simulation with MAX_EXCITATIONS={MAX_EXCITATIONS}")
        print(f"This will create a more limited but hardware-friendly simulation")

        # Create output directory
        script_dir = os.path.dirname(os.path.abspath(__file__))
        output_dir = os.path.join(script_dir, 'outputs')
        try:
            os.makedirs(output_dir, exist_ok=True)
            print(f"3x3 output will be saved to: {output_dir}")
        except Exception as e:
            print(f"Warning: Could not create 3x3 outputs directory: {str(e)}")
            output_dir = script_dir

        # Figures below are saved with bare filenames into output_dir
        os.chdir(output_dir)

        start_time = time.time()

        # 1. Energy levels at 0K. These are kept in their own variable because the
        # temperature loop below recomputes eigenvalues per temperature.
        H_0K = create_hamiltonian_with_spin(N, epsilon, t_base, U, B_field, J_ex, spin_orbit, 0, max_excitations=MAX_EXCITATIONS)
        eigenvalues_0K = H_0K.eigenenergies(sparse=True, sort='low', eigvals=5)
        print(f"3x3 Lowest 5 energy levels with spin (meV): {eigenvalues_0K}")
        gap_0K = eigenvalues_0K[1] - eigenvalues_0K[0]
        print(f"3x3 Energy gap with spin at 0K: {gap_0K:.4f} meV")

        # 2. Temperature dependence (two temperatures to save computation)
        temperatures = [77, 300]
        tmax = 50  # Maximum simulation time (ns)
        num_points = 50  # Reduced number of time points to save memory

        N_sites = N * N
        n_shown = min(4, N_sites)  # Only the first few sites are plotted

        # Initialize with spin-entangled state (singlet)
        psi0 = create_initial_state_with_spin(N_sites, config='singlet_pair', max_excitations=MAX_EXCITATIONS)

        results = {}
        for T in temperatures:
            print(f"Simulating 3x3 lattice at T={T}K")
            H = create_hamiltonian_with_spin(N, epsilon, t_base, U, B_field, J_ex, spin_orbit, T, max_excitations=MAX_EXCITATIONS)

            eigenvalues = H.eigenenergies(sparse=True, sort='low', eigvals=5)
            gap = eigenvalues[1] - eigenvalues[0] if len(eigenvalues) > 1 else 0

            times, charge_data, spin_z_data, spin_x_data, spin_y_data, total_spin, tau_phi = simulate_spin_decoherence(
                H, psi0, T, tmax=tmax, num_points=num_points, include_spin=True, max_excitations=MAX_EXCITATIONS
            )

            results[T] = {
                'times': times,
                'charge_data': charge_data,
                'spin_z_data': spin_z_data,
                'spin_x_data': spin_x_data,
                'spin_y_data': spin_y_data,
                'total_spin': total_spin,
                'tau_phi': tau_phi,
                'gap': gap
            }

            print(f"3x3 lattice at T={T}K: τ_φ = {tau_phi:.2f} ns, Energy Gap = {gap:.4f} meV")

        # 3. Lattice visualization (called while N = 3 is in effect)
        plot_quantum_dot_lattice()

        # 4. Charge dynamics at each temperature
        plt.figure(figsize=(10, 8))
        for i, T in enumerate(temperatures):
            plt.subplot(len(temperatures), 1, i + 1)
            for site in range(n_shown):
                plt.plot(results[T]['times'], np.real(results[T]['charge_data'][site]),
                         label=f'Site {site}')

            plt.title(f'3x3 Charge Dynamics (T = {T}K)')
            plt.xlabel('Time (ns)')
            plt.ylabel('Charge Occupation')
            plt.grid(True, alpha=0.3)
            plt.legend()

        plt.tight_layout()
        plt.savefig('3x3_charge_dynamics.png', dpi=300, bbox_inches='tight')
        print("Successfully saved 3x3_charge_dynamics.png")
        plt.close()

        # 5. Total spin dynamics at each temperature. Only component 0 of
        # total_spin is plotted; it is labelled 'x', matching the
        # total_spin[j] <-> ['x', 'y', 'z'][j] labelling of the summary panel.
        plt.figure(figsize=(10, 8))
        for i, T in enumerate(temperatures):
            plt.subplot(len(temperatures), 1, i + 1)
            plt.plot(results[T]['times'], np.real(results[T]['total_spin'][0]),
                     label='Total Spin-x', linewidth=2)

            plt.title(f'3x3 Spin Dynamics (T = {T}K)')
            plt.xlabel('Time (ns)')
            plt.ylabel('Spin Expectation')
            plt.grid(True, alpha=0.3)
            plt.legend()

        plt.tight_layout()
        plt.savefig('3x3_spin_dynamics.png', dpi=300, bbox_inches='tight')
        print("Successfully saved 3x3_spin_dynamics.png")
        plt.close()

        # 6. Summary plot of the 3x3 results
        plt.figure(figsize=(12, 8))

        # Temperature vs coherence time
        plt.subplot(2, 2, 1)
        plt.plot(temperatures, [results[T]['tau_phi'] for T in temperatures], 'o-',
                 label='3x3 Lattice', linewidth=2, markersize=8)
        plt.axhline(y=50, color='k', linestyle='--', label='50 ns threshold')
        plt.xlabel('Temperature (K)')
        plt.ylabel('Coherence Time (ns)')
        plt.title('3x3 Decoherence Time vs Temperature')
        plt.grid(True, alpha=0.3)
        plt.legend()

        # Temperature vs energy gap
        plt.subplot(2, 2, 2)
        plt.plot(temperatures, [np.real(results[T]['gap']) for T in temperatures], 'o-',
                 label='3x3 Lattice', linewidth=2, markersize=8)
        plt.xlabel('Temperature (K)')
        plt.ylabel('Energy Gap (meV)')
        plt.title('3x3 Energy Gap vs Temperature')
        plt.grid(True, alpha=0.3)
        plt.legend()

        # Charge dynamics at 77K
        plt.subplot(2, 2, 3)
        for site in range(n_shown):
            plt.plot(results[77]['times'], np.real(results[77]['charge_data'][site]),
                     label=f'Site {site}')
        plt.xlabel('Time (ns)')
        plt.ylabel('Charge Occupation')
        plt.title('3x3 Charge Dynamics (T = 77K)')
        plt.grid(True, alpha=0.3)
        plt.legend()

        # Total spin components at 77K
        plt.subplot(2, 2, 4)
        for j, component in enumerate(['x', 'y', 'z']):
            plt.plot(results[77]['times'], np.real(results[77]['total_spin'][j]),
                     label=f'Total Spin-{component}')
        plt.xlabel('Time (ns)')
        plt.ylabel('Spin Expectation')
        plt.title('3x3 Spin Dynamics (T = 77K)')
        plt.grid(True, alpha=0.3)
        plt.legend()

        plt.tight_layout()
        plt.savefig('3x3_summary.png', dpi=300, bbox_inches='tight')
        print("Successfully saved 3x3_summary.png")
        plt.close()

        # Print summary (the 0K gap comes from the 0K Hamiltonian, not the last loop iteration)
        total_time = time.time() - start_time

        print("\n3x3 LATTICE SIMULATION RESULTS SUMMARY")
        print("======================================")
        print(f"Energy gap with spin at 0K: {gap_0K:.4f} meV")
        print(f"Room temperature coherence time: {results[300]['tau_phi']:.2f} ns")
        print(f"77K coherence time: {results[77]['tau_phi']:.2f} ns")
        print(f"Coherence time ratio (77K/300K): {results[77]['tau_phi']/results[300]['tau_phi']:.2f}")
        print(f"\nTotal 3x3 simulation runtime: {total_time/60:.1f} minutes")

        return results
    finally:
        # Restore original parameters and directory on every exit path
        N = original_N
        MAX_EXCITATIONS = original_MAX_EXCITATIONS
        os.chdir(original_dir)
    
def run_4x4_simulation():
    """Run a highly constrained 4x4 lattice simulation with spin effects.

    Temporarily sets the module globals ``N`` (lattice side length) to 4 and
    ``MAX_EXCITATIONS`` to 2, changes the working directory to an ``outputs``
    folder next to this script (falling back to the script directory if it
    cannot be created), and computes the energy gap and coherence time at
    77 K and 300 K. The globals and the working directory are restored before
    returning, even if an unexpected error escapes.

    Returns:
        dict mapping temperature (K) -> {'tau_phi': ..., 'gap': ...}. Either
        value is None when its calculation failed; a temperature is absent
        entirely if the run aborted before reaching it.
    """
    global N, MAX_EXCITATIONS
    # Save original state so the finally clause can always restore it.
    original_N = N
    original_MAX_EXCITATIONS = MAX_EXCITATIONS
    original_dir = os.getcwd()

    # Set parameters for 4x4 simulation
    N = 4  # 4x4 lattice
    MAX_EXCITATIONS = 2  # Very limited excitations for 4x4

    try:
        print(f"Starting 4x4 lattice simulation with MAX_EXCITATIONS={MAX_EXCITATIONS}")
        print("This creates a highly constrained simulation due to the large Hilbert space")

        # Create output directory
        script_dir = os.path.dirname(os.path.abspath(__file__))
        output_dir = os.path.join(script_dir, 'outputs')
        try:
            os.makedirs(output_dir, exist_ok=True)
            print(f"4x4 output will be saved to: {output_dir}")
        except Exception as e:
            print(f"Warning: Could not create 4x4 outputs directory: {str(e)}")
            output_dir = script_dir

        # Change to output directory (undone in the finally clause)
        os.chdir(output_dir)

        start_time = time.time()

        # For 4x4, we'll focus on just the essentials to keep computation time reasonable
        N_sites = N * N

        # Estimate the truncated Hilbert-space size: number of ways to place up to
        # MAX_EXCITATIONS excitations among 2*N_sites modes (factor 2 for spin up/down).
        hilbert_size = sum(
            scipy.special.comb(2 * N_sites, n_excit, exact=True)
            for n_excit in range(MAX_EXCITATIONS + 1)
        )
        print(f"Estimated size of truncated Hilbert space: {hilbert_size}")

        # Check if Hilbert space is too large
        if hilbert_size > 10000:
            print("WARNING: Hilbert space is very large. Simulation may take excessive time/memory.")
            print("Consider reducing MAX_EXCITATIONS.")

        # Just calculate coherence time at 77K and 300K
        temperatures = [77, 300]
        results = {}

        try:
            # Create initial state - use a simpler state for 4x4
            psi0 = create_initial_state_with_spin(N_sites, config='spin_up', max_excitations=MAX_EXCITATIONS)

            for T in temperatures:
                print(f"Simulating 4x4 lattice at T={T}K")

                # Create Hamiltonian
                H = create_hamiltonian_with_spin(N, epsilon, t_base, U, B_field, J_ex, spin_orbit,
                                                T, max_excitations=MAX_EXCITATIONS)

                # Calculate energy gap - may be slow for large matrices
                try:
                    eigenvalues = H.eigenenergies(sparse=True, sort='low', eigvals=2)
                    gap = eigenvalues[1] - eigenvalues[0] if len(eigenvalues) > 1 else 0
                except Exception as e:
                    print(f"Could not calculate eigenvalues: {str(e)}")
                    gap = None

                # Very limited simulation - just enough to get coherence time
                tmax = 20  # Short simulation time
                num_points = 20  # Few time points

                try:
                    _, _, _, _, _, _, tau_phi = simulate_spin_decoherence(
                        H, psi0, T, tmax=tmax, num_points=num_points,
                        include_spin=True, max_excitations=MAX_EXCITATIONS
                    )
                except Exception as e:
                    print(f"Simulation failed at T={T}K: {str(e)}")
                    results[T] = {'tau_phi': None, 'gap': gap}
                    continue

                # Store minimal results. This happens outside the try so a reporting
                # problem can never discard a successfully computed tau_phi.
                results[T] = {
                    'tau_phi': tau_phi,
                    'gap': gap
                }

                # gap is None when the eigenvalue step failed; '.4f' cannot format None.
                tau_text = f"{tau_phi:.2f} ns" if tau_phi is not None else "N/A"
                gap_text = f"{gap:.4f} meV" if gap is not None else "N/A"
                print(f"4x4 lattice at T={T}K: τ_φ = {tau_text}, Energy Gap = {gap_text}")

            # Plot coherence time vs temperature
            if all(results[T]['tau_phi'] is not None for T in temperatures):
                plt.figure(figsize=(8, 6))
                plt.plot(temperatures, [results[T]['tau_phi'] for T in temperatures], 'o-',
                        linewidth=2, markersize=8, label='4x4 Lattice')
                plt.axhline(y=50, color='k', linestyle='--', label='50 ns threshold')
                plt.xlabel('Temperature (K)')
                plt.ylabel('Coherence Time (ns)')
                plt.title('4x4 Lattice: Decoherence Time vs Temperature')
                plt.grid(True, alpha=0.3)
                plt.legend()
                plt.tight_layout()
                plt.savefig('4x4_coherence_vs_temperature.png', dpi=300, bbox_inches='tight')
                print("Successfully saved 4x4_coherence_vs_temperature.png")
                plt.close()

        except Exception as e:
            # Deliberate best-effort: report and still print whatever was collected.
            print(f"4x4 simulation failed: {str(e)}")

        # Print summary
        total_time = time.time() - start_time

        print("\n4x4 LATTICE SIMULATION RESULTS SUMMARY")
        print("======================================")
        print("Coherence times:")
        for T in temperatures:
            if T in results and results[T]['tau_phi'] is not None:
                print(f"  T={T}K: {results[T]['tau_phi']:.2f} ns")

        print(f"\nTotal 4x4 simulation runtime: {total_time/60:.1f} minutes")

        return results
    finally:
        # Restore original parameters and directory regardless of how we exit.
        N = original_N
        MAX_EXCITATIONS = original_MAX_EXCITATIONS
        os.chdir(original_dir)

def analyze_disorder_effects(disorder_levels=None, T=77, num_samples=5, seed=None):
    """
    Analyze the effect of parameter disorder (fabrication imperfections) on coherence.

    For each disorder level, ``num_samples`` realizations are drawn. In each
    realization, every one of ``epsilon, U, t_base, B_field, J_ex, spin_orbit``
    is multiplied by an independent factor drawn uniformly from
    [1 - disorder, 1 + disorder). The result is written to the module globals
    before the Hamiltonian is built. The original values are restored
    afterwards, even if a realization raises.

    Parameters:
    - disorder_levels: Iterable of disorder strengths (as fraction of parameter
      values). Defaults to [0.0, 0.05, 0.1, 0.2].
    - T: Temperature for analysis (K)
    - num_samples: Number of disorder realizations to average for each level (>= 1)
    - seed: Optional seed for a private NumPy Generator, making realizations
      reproducible. When None (default), the global ``np.random`` state is used.

    Returns: dict with keys 'disorder_levels', 'coherence_times',
    'energy_gaps', 'coherence_std', 'gap_std'. Apart from 'disorder_levels',
    each value is a list with one mean or standard deviation per disorder
    level. Also saves disorder_robustness_analysis.png and
    disorder_energy_gap_analysis.png to the current working directory.

    Raises: ValueError if num_samples < 1.
    """
    # None sentinel instead of a mutable list default: the list ends up in the
    # returned results, so a shared default could be mutated by callers.
    if disorder_levels is None:
        disorder_levels = [0.0, 0.05, 0.1, 0.2]
    disorder_levels = list(disorder_levels)
    if num_samples < 1:
        raise ValueError(f"num_samples must be at least 1, got {num_samples}")

    # np.random and a Generator both expose .random() -> float in [0, 1).
    rng = np.random if seed is None else np.random.default_rng(seed)

    print("Analyzing effects of parameter disorder (fabrication imperfections)...")

    # Parameters to apply disorder to
    param_names = ['epsilon', 'U', 't_base', 'B_field', 'J_ex', 'spin_orbit']

    # Store original parameters (module-level globals)
    module_globals = globals()
    original_params = {param: module_globals[param] for param in param_names}

    # Results storage
    results = {
        'disorder_levels': disorder_levels,
        'coherence_times': [],
        'energy_gaps': [],
        'coherence_std': [],
        'gap_std': []
    }

    N_sites = N * N

    try:
        # Test each disorder level
        for disorder in disorder_levels:
            print(f"Testing disorder level: {disorder*100:.1f}%")

            coherence_times = []
            energy_gaps = []

            # Run multiple samples with random disorder
            for _ in range(num_samples):
                # Apply random disorder to all parameters: factor in [1-disorder, 1+disorder)
                for param in param_names:
                    random_factor = 1.0 + disorder * (2.0 * rng.random() - 1.0)
                    module_globals[param] = original_params[param] * random_factor

                # Create initial state
                psi0 = create_initial_state_with_spin(N_sites, config='singlet_pair', max_excitations=MAX_EXCITATIONS)

                # The bare names below are not assigned in this function, so they
                # resolve to the module globals that were just perturbed.
                H = create_hamiltonian_with_spin(N, epsilon, t_base, U, B_field, J_ex, spin_orbit,
                                               T, max_excitations=MAX_EXCITATIONS)

                # Calculate energy gap
                eigenvalues = H.eigenenergies(sparse=True, sort='low', eigvals=5)
                gap = eigenvalues[1] - eigenvalues[0] if len(eigenvalues) > 1 else 0

                # Run short simulation to get coherence time
                _, _, _, _, _, _, tau_phi = simulate_spin_decoherence(
                    H, psi0, T, tmax=10, num_points=20,
                    include_spin=True, max_excitations=MAX_EXCITATIONS
                )

                # Store results
                coherence_times.append(tau_phi)
                energy_gaps.append(np.real(gap))

            # Calculate statistics
            mean_coherence = np.mean(coherence_times)
            std_coherence = np.std(coherence_times)
            mean_gap = np.mean(energy_gaps)
            std_gap = np.std(energy_gaps)

            results['coherence_times'].append(mean_coherence)
            results['coherence_std'].append(std_coherence)
            results['energy_gaps'].append(mean_gap)
            results['gap_std'].append(std_gap)

            print(f"  Average coherence time: {mean_coherence:.2f} ± {std_coherence:.2f} ns")
            print(f"  Average energy gap: {mean_gap:.4f} ± {std_gap:.4f} meV")
    finally:
        # Restore original parameters even if a realization failed, so later
        # simulations never run with permanently disordered globals.
        for param, value in original_params.items():
            module_globals[param] = value

    disorder_percent = [d * 100 for d in disorder_levels]

    # Plot coherence time vs disorder
    plt.figure(figsize=(10, 6))
    plt.errorbar(disorder_percent, results['coherence_times'],
                yerr=results['coherence_std'], fmt='o-', linewidth=2,
                label='Coherence Time', capsize=5)

    plt.axhline(y=50, color='k', linestyle='--', label='50 ns threshold')
    plt.xlabel('Disorder Level (%)')
    plt.ylabel('Coherence Time (ns)')
    plt.title(f'Effect of Fabrication Disorder on Coherence (T={T}K)')
    plt.grid(True, alpha=0.3)
    plt.legend()
    plt.tight_layout()

    plt.savefig('disorder_robustness_analysis.png', dpi=300, bbox_inches='tight')
    print("Successfully saved disorder_robustness_analysis.png")
    plt.close()

    # Create a second plot for energy gap
    plt.figure(figsize=(10, 6))
    plt.errorbar(disorder_percent, results['energy_gaps'],
                yerr=results['gap_std'], fmt='o-', linewidth=2,
                label='Energy Gap', capsize=5, color='orange')

    plt.xlabel('Disorder Level (%)')
    plt.ylabel('Energy Gap (meV)')
    plt.title(f'Effect of Fabrication Disorder on Energy Gap (T={T}K)')
    plt.grid(True, alpha=0.3)
    plt.legend()
    plt.tight_layout()

    plt.savefig('disorder_energy_gap_analysis.png', dpi=300, bbox_inches='tight')
    print("Successfully saved disorder_energy_gap_analysis.png")
    plt.close()

    return results

def _compare_decoherence_models(decoherence_models, temperatures):
    """Compute and plot coherence time vs temperature for each decoherence model.

    Temporarily rebinds the module-level ``calculate_phase_memory`` to
    ``calculate_phase_memory_alternative`` (called with a ``model`` keyword).
    The original binding is restored even if a model evaluation or the plot
    fails. Saves decoherence_model_comparison.png to the current directory.

    Returns: dict mapping model name -> list of coherence times (ns), one per
    entry of ``temperatures``.
    """
    # NOTE(review): the global rebinding is kept (rather than calling the
    # alternative directly) in case code reached from here looks the name up;
    # confirm whether that is actually needed.
    global calculate_phase_memory

    original_phase_memory_func = calculate_phase_memory
    calculate_phase_memory = calculate_phase_memory_alternative
    try:
        model_results = {}
        for model in decoherence_models:
            print(f"Testing decoherence model: {model}")
            model_coherence = []

            for T in temperatures:
                tau_phi = calculate_phase_memory(T, model=model)
                model_coherence.append(tau_phi)
                print(f"  {model} model at T={T}K: τ_φ = {tau_phi:.2f} ns")

            model_results[model] = model_coherence

        # Plot comparison of decoherence models
        plt.figure(figsize=(10, 6))
        for model, coherence_values in model_results.items():
            plt.plot(temperatures, coherence_values, 'o-', label=model.capitalize())

        plt.axhline(y=50, color='k', linestyle='--', label='50 ns threshold')
        plt.xlabel('Temperature (K)')
        plt.ylabel('Coherence Time (ns)')
        plt.title('Decoherence Time vs Temperature: Model Comparison')
        plt.grid(True, alpha=0.3)
        plt.legend()
        plt.tight_layout()
        plt.savefig('decoherence_model_comparison.png', dpi=300, bbox_inches='tight')
        print("Successfully saved decoherence_model_comparison.png")
        plt.close()
    finally:
        # Restore original function
        calculate_phase_memory = original_phase_memory_func

    return model_results


def _plot_lattice_scaling(lattice_scaling):
    """Plot coherence time vs number of quantum dots at 77 K and 300 K.

    ``lattice_scaling`` maps a size label ('2x2', '3x3', '4x4') to a two-item
    list [tau_phi at 77 K, tau_phi at 300 K]. None marks a missing value.
    Each temperature series uses only the sizes for which it has data, so the
    x and y sequences passed to ``plt.plot`` always have equal length.
    Does nothing if no value is available. Otherwise saves
    lattice_size_scaling.png to the current directory.
    """
    site_counts = {'2x2': 4, '3x3': 9, '4x4': 16}
    # (index into the per-size list, legend label, format string, color)
    series = [(0, '77K', 'o-', 'blue'), (1, '300K', 's-', 'red')]

    if not any(v is not None for values in lattice_scaling.values() for v in values):
        return

    plt.figure(figsize=(10, 6))

    for index, label, fmt, color in series:
        points = [(site_counts[size], values[index])
                  for size, values in lattice_scaling.items()
                  if values[index] is not None]
        if points:
            xs, ys = zip(*points)
            plt.plot(xs, ys, fmt, label=label, color=color)

    plt.xlabel('Number of Quantum Dots')
    plt.ylabel('Coherence Time (ns)')
    plt.title('Coherence Time vs Lattice Size')
    plt.axhline(y=50, color='k', linestyle='--', label='50 ns threshold')
    plt.grid(True, alpha=0.3)
    plt.legend()
    plt.tight_layout()
    plt.savefig('lattice_size_scaling.png', dpi=300, bbox_inches='tight')
    print("Successfully saved lattice_size_scaling.png")
    plt.close()


def comprehensive_analysis():
    """Run comprehensive analysis with parameter sensitivity and multiple decoherence models.

    Steps, in order:
      1. Compare decoherence models ('standard', 'exponential', 'power_law',
         'ohmic') at 77/150/225/300 K and plot them.
      2. Run ``parameter_sensitivity_analysis()``.
      3. Run the 3x3 simulation and, best-effort, the 4x4 simulation, then plot
         coherence time vs lattice size at 77 K and 300 K.

    Returns: dict with keys 'sensitivity' (result of
    parameter_sensitivity_analysis), 'decoherence_models' (model name -> list
    of tau_phi per temperature) and 'lattice_scaling' (size label ->
    [tau_phi at 77 K, tau_phi at 300 K], None where unavailable).
    """
    print("Starting comprehensive analysis...")

    results = {}

    # 1. Test different decoherence models
    decoherence_models = ['standard', 'exponential', 'power_law', 'ohmic']
    temperatures = [77, 150, 225, 300]
    results['decoherence_models'] = _compare_decoherence_models(decoherence_models, temperatures)

    # 2. Run parameter sensitivity analysis
    sensitivity_results = parameter_sensitivity_analysis()
    results['sensitivity'] = sensitivity_results

    # 3. Run lattice size scaling analysis (3x3, and 4x4 if possible)
    results_3x3 = run_3x3_simulation()
    try:
        results_4x4 = run_4x4_simulation()
    except Exception as e:
        print(f"4x4 simulation failed, proceeding without it: {str(e)}")
        results_4x4 = None

    scaling_temperatures = (77, 300)  # Just compare two temperatures

    def tau_at(size_results, T):
        # Coherence time at T from a per-lattice results dict, or None if missing.
        if not size_results or T not in size_results:
            return None
        return size_results[T].get('tau_phi')

    lattice_scaling = {
        # 2x2 values come from the main simulation and are not collected here.
        '2x2': [None for _ in scaling_temperatures],
        '3x3': [tau_at(results_3x3, T) for T in scaling_temperatures],
        '4x4': [tau_at(results_4x4, T) for T in scaling_temperatures],
    }
    results['lattice_scaling'] = lattice_scaling

    # Plot lattice scaling results if we have them
    _plot_lattice_scaling(lattice_scaling)

    print("Comprehensive analysis complete!")
    return results

if __name__ == "__main__":
    # Stage 1: the baseline simulation including spin effects.
    main_with_spin_effects()

    # Stage 2: the broader sweep (parameter sensitivity, decoherence-model
    # comparison and lattice-size scaling).
    _followup_banner = (
        "\nNow running comprehensive analysis including parameter sensitivity, "
        "decoherence models, and lattice scaling..."
    )
    print(_followup_banner)
    comprehensive_analysis()