mqc_group_batching.f90 Source File

Shared batching helpers for group-global coordinator flows


This file depends on

Dependency graph (A → B means A depends on B):
mqc_group_batching.f90 → mqc_mpi_tags.f90
mqc_group_batching.f90 → mqc_result_types.f90
mqc_result_types.f90 → mqc_error.f90

Files dependent on this one

Dependent-files graph (edges as drawn in the original graph, A → B):
mqc_gmbe_fragment_distribution_scheme.f90 → mqc_group_batching.f90
mqc_mbe_mpi_fragment_distribution_scheme.F90 → mqc_group_batching.f90
mqc_mbe_mpi_fragment_distribution_scheme.F90 → mqc_many_body_expansion.f90
mqc_many_body_expansion.f90 → mqc_gmbe_fragment_distribution_scheme.f90
mqc_driver.f90 → mqc_many_body_expansion.f90
main.f90 → mqc_driver.f90
mqc_calculation_interface.f90 → mqc_driver.f90

Source Code

!! Shared batching helpers for group-global coordinator flows
module mqc_group_batching
   use pic_types, only: int32, int64
   use pic_timer, only: timer_type
   use pic_mpi_lib, only: comm_t, isend, irecv, recv, wait, iprobe, MPI_Status, request_t, MPI_ANY_SOURCE, abort_comm
   use pic_logger, only: logger => global_logger
   use pic_io, only: to_char
   use mqc_mpi_tags, only: TAG_WORKER_SCALAR_RESULT, TAG_NODE_SCALAR_RESULT, TAG_GROUP_RESULT, TAG_GROUP_DONE
   use mqc_result_types, only: calculation_result_t, result_isend, result_irecv
   implicit none
   private

   public :: append_result_to_batch
   public :: flush_group_results
   public :: handle_local_worker_results_to_batch
   public :: handle_node_results_to_batch
   public :: handle_group_results

contains

   subroutine append_result_to_batch(item_idx, result, batch_count, batch_ids, batch_results)
      !! Append a completed fragment result to the current outbound batch.
      !!
      !! Increments `batch_count` and stores `item_idx` and `result` in the slot
      !! it now designates in `batch_ids` / `batch_results`.
      !!
      !! A free slot must exist. If `batch_count` is negative or already at the
      !! usable capacity (the smaller of the two array sizes, which also covers
      !! zero-capacity batches), the error is logged and execution is terminated
      !! with `error stop` instead of writing outside the arrays.
      integer(int64), intent(in) :: item_idx  !! Identifier recorded for the result
      type(calculation_result_t), intent(in) :: result  !! Result to store in the batch
      integer(int32), intent(inout) :: batch_count  !! Number of filled slots; incremented on success
      integer(int64), intent(inout) :: batch_ids(:)  !! Identifiers of the batched results
      type(calculation_result_t), intent(inout) :: batch_results(:)  !! Batched results

      integer :: capacity

      ! Both arrays are written at the same index, so the usable capacity is
      ! bounded by the shorter of the two.
      capacity = min(size(batch_ids), size(batch_results))

      if (batch_count < 0 .or. batch_count >= capacity) then
         call logger%error("Cannot append item "//to_char(item_idx)// &
                           " to batch: no free slot available")
         error stop "append_result_to_batch: batch overflow"
      end if

      batch_count = batch_count + 1
      batch_ids(batch_count) = item_idx
      batch_results(batch_count) = result
   end subroutine append_result_to_batch

   subroutine flush_group_results(world_comm, batch_count, batch_ids, batch_results)
      !! Ship the pending batch to rank 0 of `world_comm`, then empty it.
      !!
      !! Everything travels on TAG_GROUP_RESULT in this order: the entry count,
      !! the first `batch_count` identifiers, then each result one at a time.
      !! Every send is waited on before the next one is posted, and each result
      !! is destroyed right after it has been sent. An empty batch is a no-op.
      type(comm_t), intent(in) :: world_comm
      integer(int32), intent(inout) :: batch_count
      integer(int64), intent(inout) :: batch_ids(:)
      type(calculation_result_t), intent(inout) :: batch_results(:)

      integer, parameter :: root_rank = 0
      type(request_t) :: send_req
      integer :: slot

      if (batch_count > 0) then
         ! Header: how many entries follow.
         call isend(world_comm, batch_count, root_rank, TAG_GROUP_RESULT, send_req)
         call wait(send_req)

         ! Identifiers of the filled slots only.
         call isend(world_comm, batch_ids(:batch_count), root_rank, TAG_GROUP_RESULT, send_req)
         call wait(send_req)

         ! Payloads, in the same order as the identifiers.
         do slot = 1, batch_count
            call result_isend(batch_results(slot), world_comm, root_rank, TAG_GROUP_RESULT, send_req)
            call wait(send_req)
            call batch_results(slot)%destroy()
         end do

         batch_count = 0
      end if
   end subroutine flush_group_results

   subroutine handle_local_worker_results_to_batch(node_comm, world_comm, worker_map, batch_count, batch_ids, batch_results, &
                                                   results_received)
      !! Drain pending local worker results and append them to the outbound batch.
      !!
      !! Repeatedly probes `node_comm` (non-blocking) for TAG_WORKER_SCALAR_RESULT
      !! messages. For each pending message the result is received, the sender's
      !! entry in `worker_map` is reset to 0, and the result is appended to the
      !! batch. A full batch is flushed to rank 0 of `world_comm` both before and
      !! after the append. Returns once nothing is pending, or immediately when
      !! `node_comm` has at most one rank.
      !!
      !! The run is aborted on `world_comm` when a result arrives from a rank
      !! outside the bounds of `worker_map`, from a rank whose map entry is 0
      !! (no item assigned), or when the received result reports an error.
      type(comm_t), intent(in) :: node_comm  !! Communicator the workers report on
      type(comm_t), intent(in) :: world_comm  !! Communicator used for flushing and aborting
      integer(int64), intent(inout) :: worker_map(:)  !! Item assigned to each worker rank (0 = none)
      integer(int32), intent(inout) :: batch_count  !! Number of filled batch slots
      integer(int64), intent(inout) :: batch_ids(:)  !! Identifiers of the batched results
      type(calculation_result_t), intent(inout) :: batch_results(:)  !! Batched results
      integer(int64), intent(inout), optional :: results_received  !! Incremented once per drained result

      type(MPI_Status) :: local_status
      logical :: has_pending
      integer :: worker_source
      type(request_t) :: req
      type(calculation_result_t) :: worker_result
      integer(int64) :: item_idx

      if (node_comm%size() <= 1) return

      do
         call iprobe(node_comm, MPI_ANY_SOURCE, TAG_WORKER_SCALAR_RESULT, has_pending, local_status)
         if (.not. has_pending) exit

         worker_source = local_status%MPI_SOURCE

         ! worker_map is indexed directly by the sender's rank; as an
         ! assumed-shape dummy its valid indices are 1..size(worker_map).
         ! Reject anything else before it is used as a subscript.
         if (worker_source < 1 .or. worker_source > size(worker_map)) then
            call logger%error("Received result from worker "//to_char(worker_source)// &
                              " which is outside the worker map!")
            call abort_comm(world_comm, 1)
         end if

         if (worker_map(worker_source) == 0) then
            call logger%error("Received result from worker "//to_char(worker_source)// &
                              " but no item was assigned!")
            call abort_comm(world_comm, 1)
         end if

         call result_irecv(worker_result, node_comm, worker_source, TAG_WORKER_SCALAR_RESULT, req)
         call wait(req)

         if (worker_result%has_error) then
            call logger%error("Item "//to_char(worker_map(worker_source))// &
                              " calculation failed: "// &
                              worker_result%error%get_message())
            call abort_comm(world_comm, 1)
         end if

         ! Remember which item this was, then mark the worker as idle.
         item_idx = worker_map(worker_source)
         worker_map(worker_source) = 0

         ! Make room first if the batch is already full.
         if (batch_count >= size(batch_ids)) then
            call flush_group_results(world_comm, batch_count, batch_ids, batch_results)
         end if

         call append_result_to_batch(item_idx, worker_result, batch_count, batch_ids, batch_results)
         if (present(results_received)) results_received = results_received + 1_int64

         ! Send the batch on as soon as this append fills it.
         if (batch_count >= size(batch_ids)) then
            call flush_group_results(world_comm, batch_count, batch_ids, batch_results)
         end if
         call worker_result%destroy()
      end do
   end subroutine handle_local_worker_results_to_batch

   subroutine handle_node_results_to_batch(world_comm, batch_count, batch_ids, batch_results, results_received)
      !! Collect every node-level result currently waiting on `world_comm` into the batch.
      !!
      !! Loops while a TAG_NODE_SCALAR_RESULT message is pending from any rank.
      !! Each message pair (item identifier, then the result itself) is received
      !! from the probed sender and appended to the batch. A full batch is
      !! flushed both before and after the append. A result that reports an
      !! error is logged and the run is aborted on `world_comm`.
      type(comm_t), intent(in) :: world_comm
      integer(int32), intent(inout) :: batch_count
      integer(int64), intent(inout) :: batch_ids(:)
      type(calculation_result_t), intent(inout) :: batch_results(:)
      integer(int64), intent(inout), optional :: results_received

      type(calculation_result_t) :: incoming
      type(MPI_Status) :: probe_status
      type(request_t) :: recv_req
      integer(int64) :: incoming_id
      integer :: sender, capacity
      logical :: message_waiting

      capacity = size(batch_ids)

      drain: do
         call iprobe(world_comm, MPI_ANY_SOURCE, TAG_NODE_SCALAR_RESULT, message_waiting, probe_status)
         if (.not. message_waiting) exit drain

         sender = probe_status%MPI_SOURCE

         ! Identifier first, payload second, both from the probed sender.
         call irecv(world_comm, incoming_id, sender, TAG_NODE_SCALAR_RESULT, recv_req)
         call wait(recv_req)
         call result_irecv(incoming, world_comm, sender, TAG_NODE_SCALAR_RESULT, recv_req)
         call wait(recv_req)

         if (incoming%has_error) then
            call logger%error("Item "//to_char(incoming_id)//" calculation failed: "// &
                              incoming%error%get_message())
            call abort_comm(world_comm, 1)
         end if

         ! Free a slot if none is left.
         if (batch_count >= capacity) then
            call flush_group_results(world_comm, batch_count, batch_ids, batch_results)
         end if

         call append_result_to_batch(incoming_id, incoming, batch_count, batch_ids, batch_results)
         if (present(results_received)) then
            results_received = results_received + 1_int64
         end if

         ! Forward the batch immediately once it has been filled.
         if (batch_count >= capacity) then
            call flush_group_results(world_comm, batch_count, batch_ids, batch_results)
         end if

         call incoming%destroy()
      end do drain
   end subroutine handle_node_results_to_batch

 subroutine handle_group_results(world_comm, results, results_received, total_items, coord_timer, group_done_count, label)
      !! Receive grouped result batches on rank 0 and update global progress counters.
      !!
      !! First drains every pending TAG_GROUP_RESULT batch: the entry count, the
      !! identifier list, then one result per identifier, all from the probed
      !! sender. Each result is stored in `results` at the position given by its
      !! identifier and `results_received` is incremented. Progress is logged
      !! roughly every tenth of `total_items` and when the last item arrives.
      !! Afterwards every pending TAG_GROUP_DONE message is consumed and
      !! `group_done_count` is incremented once per message.
      !!
      !! The run is aborted on `world_comm` when an identifier lies outside
      !! `results` or when a received result reports an error.
      type(comm_t), intent(in) :: world_comm  !! Communicator the batches arrive on
      type(calculation_result_t), intent(inout) :: results(:)  !! Destination, indexed by item identifier
      integer(int64), intent(inout) :: results_received  !! Running count of received results
      integer(int64), intent(in) :: total_items  !! Total number of items, used for progress logging
      type(timer_type), intent(in) :: coord_timer  !! Timer queried for the elapsed time in progress logs
      integer, intent(inout) :: group_done_count  !! Running count of TAG_GROUP_DONE messages
      character(len=*), intent(in), optional :: label  !! Noun used in log messages (default "item")

      integer(int32) :: batch_count
      integer(int64), allocatable :: batch_ids(:)
      integer(int64) :: item_idx, progress_stride
      type(MPI_Status) :: status
      logical :: has_pending
      type(request_t) :: req
      integer :: i, dummy_msg, source_rank
      ! Deferred length so that long labels are not silently truncated.
      character(len=:), allocatable :: item_label

      if (present(label)) then
         item_label = label
      else
         item_label = "item"
      end if

      ! Loop-invariant logging interval; never less than one item.
      progress_stride = max(1_int64, total_items/10_int64)

      do
         call iprobe(world_comm, MPI_ANY_SOURCE, TAG_GROUP_RESULT, has_pending, status)
         if (.not. has_pending) exit

         ! Keep the sender in a local: `status` is overwritten by the recv
         ! below and must not be passed as both the source and the output
         ! argument of the same call.
         source_rank = status%MPI_SOURCE

         call irecv(world_comm, batch_count, source_rank, TAG_GROUP_RESULT, req)
         call wait(req)
         if (batch_count <= 0) cycle

         allocate (batch_ids(batch_count))
         call recv(world_comm, batch_ids, source_rank, TAG_GROUP_RESULT, status)
         do i = 1, batch_count
            item_idx = batch_ids(i)

            ! The identifier is used directly as a subscript into `results`.
            if (item_idx < 1_int64 .or. item_idx > size(results, kind=int64)) then
               call logger%error(trim(item_label)//" index "//to_char(item_idx)// &
                                 " is outside the results array")
               call abort_comm(world_comm, 1)
            end if

            call result_irecv(results(item_idx), world_comm, source_rank, TAG_GROUP_RESULT, req)
            call wait(req)

            if (results(item_idx)%has_error) then
               call logger%error(trim(item_label)//" "//to_char(item_idx)//" calculation failed: "// &
                                 results(item_idx)%error%get_message())
               call abort_comm(world_comm, 1)
            end if

            results_received = results_received + 1_int64
            if (mod(results_received, progress_stride) == 0 .or. &
                results_received == total_items) then
               call logger%info("  Processed "//to_char(results_received)//"/"// &
                                to_char(total_items)//" "//trim(item_label)//"s ["// &
                                to_char(coord_timer%get_elapsed_time())//" s]")
            end if
         end do
         deallocate (batch_ids)
      end do

      ! Consume completion notifications; the payload itself is ignored.
      do
         call iprobe(world_comm, MPI_ANY_SOURCE, TAG_GROUP_DONE, has_pending, status)
         if (.not. has_pending) exit
         call irecv(world_comm, dummy_msg, status%MPI_SOURCE, TAG_GROUP_DONE, req)
         call wait(req)
         group_done_count = group_done_count + 1
      end do
   end subroutine handle_group_results

end module mqc_group_batching