subroutine global_coordinator_impl(ctx, json_data)
   !! Internal implementation of global_coordinator with typed context.
   !!
   !! World-rank-0 "super-global" coordinator. Responsibilities, in order:
   !!   1. Partition all fragments into per-group shards (chunked round-robin).
   !!   2. Dispatch each remote group's shard to its group leader; keep the
   !!      shard for this rank's own group ("group 0") local.
   !!   3. Run the main service loop: drain results from group leaders, local
   !!      shared-memory workers, and node coordinators, while handing out
   !!      group-0 work, until every group and every fragment is accounted for.
   !!   4. Compute the many-body expansion from the gathered results.
   use mqc_json_output_types, only: json_output_data_t
   use mqc_many_body_expansion, only: mbe_context_t
   type(mbe_context_t), intent(in) :: ctx
   type(json_output_data_t), intent(out), optional :: json_data !! JSON output data

   ! Per-group slice of the workload: global fragment ids plus the matching
   ! rows of the polymer index matrix.
   type :: group_shard_t
      integer(int64), allocatable :: fragment_ids(:)
      integer, allocatable :: polymers(:, :)
   end type group_shard_t

   type(timer_type) :: coord_timer
   integer(int64) :: results_received     ! total fragment results collected so far
   integer :: group_done_count            ! number of groups that reported completion
   integer :: group0_node_count           ! nodes belonging to group 0
   integer :: group0_finished_nodes       ! group-0 nodes that have drained their work
   integer :: group_id
   integer :: i
   integer :: local_finished_workers      ! shared-memory workers on this node that are done
   integer :: group0_done                 ! 0/1 latch: group 0 already counted in group_done_count
   integer :: local_node_done             ! 0/1 latch: this node already counted in group0_finished_nodes
   integer(int32) :: calc_type_local
   ! Storage for results
   type(calculation_result_t), allocatable :: results(:)
   ! Maps each local worker rank to the fragment it is currently computing.
   integer(int64) :: worker_fragment_map(ctx%resources%mpi_comms%node_comm%size())
   type(queue_t) :: group0_queue
   integer(int64), allocatable :: group0_fragment_ids(:)
   integer, allocatable :: group0_polymers(:, :)
   integer(int64) :: fragment_idx
   integer(int64) :: chunk_id, chunk_size
   integer(int64), allocatable :: group_counts(:)    ! fragments assigned to each group
   integer(int64), allocatable :: group_fill(:)      ! fill cursor while packing shards
   integer, allocatable :: group_leader_by_group(:)  ! world rank of each group's leader
   integer, allocatable :: group_node_counts(:)      ! number of nodes in each group
   integer :: n_cols
   type(group_shard_t), allocatable :: group_shards(:)

   calc_type_local = ctx%calc_type
   results_received = 0_int64
   group_done_count = 0
   group0_finished_nodes = 0
   local_finished_workers = 0
   group0_done = 0
   local_node_done = 0

   ! Allocate storage for results
   allocate (results(ctx%total_fragments))
   worker_fragment_map = 0

   call logger%verbose("Super-global coordinator starting with "//to_char(ctx%total_fragments)// &
                       " fragments for "//to_char(ctx%num_nodes)//" nodes and "// &
                       to_char(ctx%global_groups)//" groups")

   ! Build group leader map and node counts. A group's leader is the first
   ! node-leader rank encountered for that group.
   allocate (group_leader_by_group(ctx%global_groups))
   group_leader_by_group = -1
   allocate (group_node_counts(ctx%global_groups))
   group_node_counts = 0
   do i = 1, size(ctx%node_leader_ranks)
      group_id = ctx%group_ids(i)
      group_node_counts(group_id) = group_node_counts(group_id) + 1
      if (group_leader_by_group(group_id) == -1) then
         group_leader_by_group(group_id) = ctx%node_leader_ranks(i)
      end if
   end do
   ! NOTE(review): assumes group ids are 1-based and that this rank's own
   ! group ("group 0") is group index 1 — confirm against group_ids setup.
   group0_node_count = group_node_counts(1)

   ! Partition fragments into group shards (chunked round-robin).
   ! Pass 1: count how many fragments land in each group so shards can be
   ! allocated exactly. chunk_id groups consecutive fragments into chunks of
   ! chunk_size, and chunks are dealt round-robin across groups.
   allocate (group_counts(ctx%global_groups))
   group_counts = 0_int64
   if (ctx%total_fragments > 0) then
      chunk_size = max(1_int64, ctx%total_fragments/int(ctx%global_groups, int64))
      do fragment_idx = 1_int64, ctx%total_fragments
         chunk_id = (fragment_idx - 1_int64)/chunk_size + 1_int64
         group_id = int(mod(chunk_id - 1_int64, int(ctx%global_groups, int64)) + 1_int64)
         group_counts(group_id) = group_counts(group_id) + 1_int64
      end do
   end if
   allocate (group_shards(ctx%global_groups))
   allocate (group_fill(ctx%global_groups))
   group_fill = 0_int64
   n_cols = size(ctx%polymers, 2)
   do i = 1, ctx%global_groups
      if (group_counts(i) > 0_int64) then
         allocate (group_shards(i)%fragment_ids(group_counts(i)))
         allocate (group_shards(i)%polymers(group_counts(i), n_cols))
      end if
   end do
   ! Pass 2: fill the shards using the identical chunk mapping as pass 1
   ! (chunk_size is set above whenever this loop body can execute).
   if (ctx%total_fragments > 0) then
      do fragment_idx = 1_int64, ctx%total_fragments
         chunk_id = (fragment_idx - 1_int64)/chunk_size + 1_int64
         group_id = int(mod(chunk_id - 1_int64, int(ctx%global_groups, int64)) + 1_int64)
         group_fill(group_id) = group_fill(group_id) + 1_int64
         group_shards(group_id)%fragment_ids(group_fill(group_id)) = fragment_idx
         group_shards(group_id)%polymers(group_fill(group_id), :) = ctx%polymers(fragment_idx, :)
      end do
   end if

   ! Dispatch shards to group globals. The shard led by world rank 0 stays
   ! local (moved, not copied, via move_alloc); every other shard is sent to
   ! its group leader over the world communicator.
   do i = 1, ctx%global_groups
      if (group_leader_by_group(i) == 0) then
         if (allocated(group_shards(i)%fragment_ids)) then
            call move_alloc(group_shards(i)%fragment_ids, group0_fragment_ids)
            call move_alloc(group_shards(i)%polymers, group0_polymers)
         else
            allocate (group0_fragment_ids(0))
            allocate (group0_polymers(0, n_cols))
         end if
      else if (group_leader_by_group(i) > 0) then
         ! NOTE(review): when group_counts(i) == 0 these shard components are
         ! unallocated actuals — send_group_assignment_matrix must accept
         ! unallocated/zero-size arguments; confirm its dummy declarations.
         call send_group_assignment_matrix(ctx%resources%mpi_comms%world_comm, group_leader_by_group(i), &
                                           group_shards(i)%fragment_ids, group_shards(i)%polymers)
      end if
      if (allocated(group_shards(i)%fragment_ids)) deallocate (group_shards(i)%fragment_ids)
      if (allocated(group_shards(i)%polymers)) deallocate (group_shards(i)%polymers)
   end do
   deallocate (group_shards)
   deallocate (group_counts)
   deallocate (group_fill)

   ! Initialize local group queue (group 0). Guard against the case where no
   ! group leader equalled rank 0 above, so the group-0 buffers always exist.
   if (.not. allocated(group0_fragment_ids)) then
      allocate (group0_fragment_ids(0))
      allocate (group0_polymers(0, n_cols))
   end if
   block
      integer(int64), allocatable :: temp_ids(:)
      integer(int64) :: idx
      if (size(group0_fragment_ids) > 0) then
         ! Queue stores local indices (1..N) into group0_fragment_ids/polymers.
         allocate (temp_ids(size(group0_fragment_ids)))
         do idx = 1_int64, size(group0_fragment_ids, kind=int64)
            temp_ids(idx) = idx
         end do
         call queue_init_from_list(group0_queue, temp_ids)
         deallocate (temp_ids)
      else
         ! Empty shard: put the queue directly into its drained state.
         group0_queue%count = 0_int64
         group0_queue%head = 1_int64
      end if
   end block

   ! Main service loop: poll message sources in priority order until every
   ! group has reported completion AND every fragment result has arrived.
   call coord_timer%start()
   do while (group_done_count < ctx%global_groups .or. results_received < ctx%total_fragments)
      ! PRIORITY 1: Receive batched results from group globals
      call handle_group_results(ctx%resources%mpi_comms%world_comm, results, results_received, &
                                ctx%total_fragments, coord_timer, group_done_count, "fragment")
      ! PRIORITY 2: Check for incoming results from local workers
      if (ctx%resources%mpi_comms%node_comm%size() > 1) then
         call handle_local_worker_results(ctx, worker_fragment_map, results, results_received, coord_timer)
      end if
      ! PRIORITY 3: Check for incoming results from node coordinators (group 0 only)
      call handle_node_results(ctx, results, results_received, coord_timer)
      ! PRIORITY 4: Remote node coordinator requests for group 0
      call handle_group_node_requests(ctx, group0_queue, group0_fragment_ids, group0_polymers, group0_finished_nodes)
      ! PRIORITY 5: Local workers (shared memory) - send new work for group 0
      if (ctx%resources%mpi_comms%node_comm%size() > 1 .and. &
          local_finished_workers < ctx%resources%mpi_comms%node_comm%size() - 1) then
         call handle_local_worker_requests_group(ctx, group0_queue, group0_fragment_ids, group0_polymers, &
                                                 worker_fragment_map, local_finished_workers)
      end if
      ! Mark local node completion once all local workers are finished and
      ! the group-0 queue is empty (single-rank nodes complete immediately).
      if (local_node_done == 0) then
         if (queue_is_empty(group0_queue) .and. &
             (ctx%resources%mpi_comms%node_comm%size() == 1 .or. &
              local_finished_workers >= ctx%resources%mpi_comms%node_comm%size() - 1)) then
            local_node_done = 1
            group0_finished_nodes = group0_finished_nodes + 1
         end if
      end if
      ! Once every group-0 node has finished, count group 0 as done (exactly
      ! once, guarded by the group0_done latch).
      if (group0_done == 0) then
         if (group0_finished_nodes >= group0_node_count) then
            group0_done = 1
            group_done_count = group_done_count + 1
         end if
      end if
   end do
   call logger%verbose("Super-global coordinator finished all fragments")
   call coord_timer%stop()
   call logger%info("Time to evaluate all fragments "//to_char(coord_timer%get_elapsed_time())//" s")

   block
      use mqc_result_types, only: mbe_result_t
      type(mbe_result_t) :: mbe_result
      ! Compute the many-body expansion
      call logger%info(" ")
      call logger%info("Computing Many-Body Expansion (MBE)...")
      call coord_timer%start()
      ! Allocate mbe_result components based on calc_type. Gradient/Hessian
      ! runs additionally require a system geometry; abort the whole world
      ! communicator if it is missing rather than failing later.
      call mbe_result%allocate_dipole() ! Always compute dipole
      if (calc_type_local == CALC_TYPE_HESSIAN) then
         if (.not. ctx%has_geometry()) then
            call logger%error("sys_geom required for Hessian calculation in global_coordinator")
            call abort_comm(ctx%resources%mpi_comms%world_comm, 1)
         end if
         call mbe_result%allocate_gradient(ctx%sys_geom%total_atoms)
         call mbe_result%allocate_hessian(ctx%sys_geom%total_atoms)
      else if (calc_type_local == CALC_TYPE_GRADIENT) then
         if (.not. ctx%has_geometry()) then
            call logger%error("sys_geom required for gradient calculation in global_coordinator")
            call abort_comm(ctx%resources%mpi_comms%world_comm, 1)
         end if
         call mbe_result%allocate_gradient(ctx%sys_geom%total_atoms)
      end if
      call compute_mbe(ctx%polymers, ctx%total_fragments, ctx%max_level, results, mbe_result, &
                       ctx%sys_geom, ctx%resources%mpi_comms%world_comm, json_data)
      call mbe_result%destroy()
      call coord_timer%stop()
      call logger%info("Time to compute MBE "//to_char(coord_timer%get_elapsed_time())//" s")
   end block

   ! Cleanup
   call queue_destroy(group0_queue)
   if (allocated(group0_fragment_ids)) deallocate (group0_fragment_ids)
   if (allocated(group0_polymers)) deallocate (group0_polymers)
   if (allocated(group_leader_by_group)) deallocate (group_leader_by_group)
   if (allocated(group_node_counts)) deallocate (group_node_counts)
   deallocate (results)
end subroutine global_coordinator_impl