subroutine group_global_coordinator_impl(ctx)
   !! Group-global coordinator: receives this group's shard of the fragment
   !! assignment from the global coordinator, serves work items to the node
   !! coordinators of the group (and to workers on its own node), collects
   !! finished results in fixed-size batches, and forwards them to rank 0.
   use mqc_many_body_expansion, only: many_body_expansion_t
   class(many_body_expansion_t), intent(in) :: ctx

   integer(int64), allocatable :: shard_fragment_ids(:)
   integer, allocatable :: shard_polymers(:, :)
   type(queue_t) :: work_queue
   integer(int64), allocatable :: seq_ids(:)
   integer(int64) :: i
   integer(int32) :: n_batched
   integer(int64), allocatable :: pending_ids(:)
   type(calculation_result_t), allocatable :: pending_results(:)
   integer(int64) :: n_results
   integer(int64) :: n_total
   integer :: n_nodes_done
   integer :: n_local_workers_done
   integer :: n_group_nodes
   integer :: leader_rank, my_group
   integer :: leader_node_done
   integer(int64) :: worker_assignment(ctx%resources%mpi_comms%node_comm%size())
   type(request_t) :: done_req

   associate (world => ctx%resources%mpi_comms%world_comm, &
              node => ctx%resources%mpi_comms%node_comm)

      ! Sanity check: this routine is only valid on the group leader rank.
      call get_group_leader_rank(ctx, world%rank(), leader_rank, my_group)
      if (leader_rank /= world%rank()) then
         call logger%error("group_global_coordinator_impl called on non-group leader rank")
         call abort_comm(world, 1)
      end if

      n_group_nodes = count(ctx%group_ids == my_group)
      call receive_group_assignment_matrix(world, shard_fragment_ids, shard_polymers)

      ! The queue holds local indices (1..N) into shard_fragment_ids/shard_polymers.
      if (size(shard_fragment_ids) > 0) then
         seq_ids = [(i, i = 1_int64, size(shard_fragment_ids, kind=int64))]
         call queue_init_from_list(work_queue, seq_ids)
         deallocate (seq_ids)
      else
         ! Nothing to distribute: set up an empty queue directly.
         work_queue%count = 0_int64
         work_queue%head = 1_int64
      end if

      n_batched = 0
      allocate (pending_ids(GROUP_RESULT_BATCH_SIZE))
      allocate (pending_results(GROUP_RESULT_BATCH_SIZE))
      n_results = 0_int64
      n_total = size(shard_fragment_ids, kind=int64)
      n_nodes_done = 0
      n_local_workers_done = 0
      leader_node_done = 0
      worker_assignment = 0

      ! Poll until every node in the group has drained its work AND every
      ! expected result has been collected.
      do while (n_nodes_done < n_group_nodes .or. n_results < n_total)
         call handle_local_worker_results_to_batch(node, world, worker_assignment, &
                                                   n_batched, pending_ids, pending_results, &
                                                   n_results)
         call handle_node_results_to_batch(world, n_batched, pending_ids, pending_results, &
                                           n_results)
         call handle_group_node_requests(ctx, work_queue, shard_fragment_ids, shard_polymers, &
                                         n_nodes_done)
         if (node%size() > 1 .and. n_local_workers_done < node%size() - 1) then
            call handle_local_worker_requests_group(ctx, work_queue, shard_fragment_ids, &
                                                    shard_polymers, worker_assignment, &
                                                    n_local_workers_done)
         end if
         ! The leader's own node counts toward n_nodes_done once the queue is
         ! empty and all of its local workers (if it has any) reported done.
         if (leader_node_done == 0) then
            if (queue_is_empty(work_queue) .and. &
                (node%size() == 1 .or. n_local_workers_done >= node%size() - 1)) then
               leader_node_done = 1
               n_nodes_done = n_nodes_done + 1
            end if
         end if
         if (n_batched >= GROUP_RESULT_BATCH_SIZE) then
            call flush_group_results(world, n_batched, pending_ids, pending_results)
         end if
      end do

      ! Forward any partial batch, then signal completion to rank 0.
      call flush_group_results(world, n_batched, pending_ids, pending_results)
      call isend(world, 0, 0, TAG_GROUP_DONE, done_req)
      call wait(done_req)
   end associate

   call queue_destroy(work_queue)
   deallocate (shard_fragment_ids)
   deallocate (shard_polymers)
   deallocate (pending_ids)
   deallocate (pending_results)
end subroutine group_global_coordinator_impl