/**
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */
package org.apache.pinot.segment.local.utils;

import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Preconditions;
import java.util.Map;
import java.util.OptionalInt;
import java.util.Set;
import java.util.concurrent.atomic.AtomicInteger;
import org.apache.pinot.common.concurrency.AdjustableSemaphore;
import org.apache.pinot.common.metrics.ServerMetrics;
import org.apache.pinot.spi.config.provider.PinotClusterConfigChangeListener;
import org.slf4j.Logger;


/**
 * Base class for segment operation throttlers. Contains the common logic for managing the semaphore and for switching
 * between the concurrency used before the server starts serving queries and the concurrency used afterwards. The
 * semaphore is never null and always holds > 0 total permits.
 * <p>
 * Thread-safety: {@link #startServingQueries()} and the config-change handlers are synchronized on this instance, so
 * the serving-queries flag, the configured concurrency values and the semaphore permits are updated consistently with
 * respect to each other. {@link #acquire()} and {@link #release()} are intentionally not synchronized, so a thread
 * blocked waiting for a permit never blocks a config update.
 */
public abstract class BaseSegmentOperationsThrottler implements PinotClusterConfigChangeListener {

  protected ServerMetrics _serverMetrics;
  protected AdjustableSemaphore _semaphore;
  /**
   * _maxConcurrency and _maxConcurrencyBeforeServingQueries must be > 0. To effectively disable throttling, this can
   * be set to a very high value
   */
  protected int _maxConcurrency;
  protected int _maxConcurrencyBeforeServingQueries;
  protected boolean _isServingQueries;
  // Number of callers currently holding a permit (acquire() minus release()); reported through updateCountMetric.
  private final AtomicInteger _numSegmentsAcquiredSemaphore = new AtomicInteger(0);
  private final Logger _logger;

  /**
   * Base segment operations throttler constructor
   * @param maxConcurrency configured concurrency, must be > 0
   * @param maxConcurrencyBeforeServingQueries configured concurrency before serving queries, must be > 0
   * @param isServingQueries whether the server is ready to serve queries or not
   * @param logger logger to use
   * @throws IllegalArgumentException if either concurrency value is <= 0
   */
  public BaseSegmentOperationsThrottler(int maxConcurrency, int maxConcurrencyBeforeServingQueries,
      boolean isServingQueries, Logger logger) {
    _logger = logger;
    _logger.info("Initializing SegmentOperationsThrottler, maxConcurrency: {}, maxConcurrencyBeforeServingQueries: {}, "
            + "isServingQueries: {}",
        maxConcurrency, maxConcurrencyBeforeServingQueries, isServingQueries);
    Preconditions.checkArgument(maxConcurrency > 0, "Max parallelism must be > 0, but found to be: %s",
        maxConcurrency);
    Preconditions.checkArgument(maxConcurrencyBeforeServingQueries > 0,
        "Max parallelism before serving queries must be > 0, but found to be: %s", maxConcurrencyBeforeServingQueries);

    _maxConcurrency = maxConcurrency;
    _maxConcurrencyBeforeServingQueries = maxConcurrencyBeforeServingQueries;
    _isServingQueries = isServingQueries;

    // maxConcurrencyBeforeServingQueries is only used prior to serving queries and once the server is
    // ready to serve queries this is not used again. This too is configurable via ZK CLUSTER config updates while the
    // server is starting up.
    if (!isServingQueries) {
      _logger.info("Serving queries is disabled, using concurrency as: {}", _maxConcurrencyBeforeServingQueries);
    }

    int concurrency = _isServingQueries ? _maxConcurrency : _maxConcurrencyBeforeServingQueries;
    _semaphore = new AdjustableSemaphore(concurrency, true);
    // NOTE: initializeMetrics() calls the abstract metric hooks, so subclass overrides run before the subclass
    // constructor body has executed; those overrides must not depend on fields the subclass constructor sets.
    initializeMetrics();
    _logger.info("Created semaphore with total permits: {}, available permits: {}", totalPermits(),
        availablePermits());
  }

  /**
   * Updates the throttle threshold metric
   * @param value value to update the metric to
   */
  public abstract void updateThresholdMetric(int value);

  /**
   * Updates the throttle count metric
   * @param value value to update the metric to
   */
  public abstract void updateCountMetric(int value);

  /**
   * The ServerMetrics may be created after these throttle objects are created. In that case, the initialization that
   * happens in the constructor may have occurred on the NOOP metrics. This should be called after the server metrics
   * are created and registered to ensure the correct metrics object is used and the metrics are updated correctly
   *
   * This is called in the same thread as the constructor so there is no need to make _serverMetrics volatile here
   */
  public void initializeMetrics() {
    _serverMetrics = ServerMetrics.get();
    updateThresholdMetric(_semaphore.getTotalPermits());
    updateCountMetric(0);
  }

  /**
   * Marks the server as serving queries and switches the semaphore from the before-serving-queries concurrency to
   * {@code _maxConcurrency}.
   */
  public synchronized void startServingQueries() {
    _logger.info("Serving queries is to be enabled, reset throttling threshold for segment operations concurrency, "
        + "total permits: {}, available permits: {}", totalPermits(), availablePermits());
    _isServingQueries = true;
    _semaphore.setPermits(_maxConcurrency);
    updateThresholdMetric(_maxConcurrency);
    _logger.info("Reset throttling completed, new concurrency: {}, total permits: {}, available permits: {}",
        _maxConcurrency, totalPermits(), availablePermits());
  }

  /**
   * Applies a cluster-config change to the max concurrency used once the server is serving queries. Invalid or
   * non-positive values are logged and ignored. If the server is not serving queries yet, only the stored value is
   * updated; the semaphore is resized later by {@link #startServingQueries()}.
   * <p>
   * Synchronized so it cannot interleave with {@link #startServingQueries()} or the before-serving-queries handler.
   *
   * @param changedConfigs names of the configs that changed; a null set is treated as "nothing changed"
   * @param clusterConfigs current cluster configs, may be null in which case the default value is used
   * @param configName name of the max concurrency config
   * @param defaultConfigValue value to use when the config is absent
   */
  protected synchronized void handleMaxConcurrencyChange(Set<String> changedConfigs,
      Map<String, String> clusterConfigs, String configName, String defaultConfigValue) {
    if (!isConfigChanged(changedConfigs, configName)) {
      return;
    }

    OptionalInt parsedConcurrency = parsePositiveConcurrency(clusterConfigs, configName, defaultConfigValue);
    if (!parsedConcurrency.isPresent()) {
      return;
    }
    int maxConcurrency = parsedConcurrency.getAsInt();

    if (maxConcurrency == _maxConcurrency) {
      _logger.info("No ZK update for config {}, value: {}, total permits: {}", configName, _maxConcurrency,
          totalPermits());
      return;
    }

    _logger.info("Updated config: {} from: {} to: {}", configName, _maxConcurrency, maxConcurrency);
    _maxConcurrency = maxConcurrency;

    if (!_isServingQueries) {
      _logger.info("Serving queries hasn't been enabled yet, not updating the permits with config {}", configName);
      return;
    }
    _semaphore.setPermits(_maxConcurrency);
    updateThresholdMetric(_maxConcurrency);
    _logger.info("Updated total permits: {}", totalPermits());
  }

  /**
   * Applies a cluster-config change to the max concurrency used before the server is serving queries. Invalid or
   * non-positive values are logged and ignored. The semaphore is resized only if the server is not serving queries
   * yet; afterwards only the stored value is updated.
   * <p>
   * Synchronized so that a concurrent {@link #startServingQueries()} cannot be overwritten with the
   * before-serving-queries concurrency.
   *
   * @param changedConfigs names of the configs that changed; a null set is treated as "nothing changed"
   * @param clusterConfigs current cluster configs, may be null in which case the default value is used
   * @param configName name of the before-serving-queries max concurrency config
   * @param defaultConfigValue value to use when the config is absent
   */
  protected synchronized void handleMaxConcurrencyBeforeServingQueriesChange(Set<String> changedConfigs,
      Map<String, String> clusterConfigs, String configName, String defaultConfigValue) {
    if (!isConfigChanged(changedConfigs, configName)) {
      return;
    }

    OptionalInt parsedConcurrency = parsePositiveConcurrency(clusterConfigs, configName, defaultConfigValue);
    if (!parsedConcurrency.isPresent()) {
      return;
    }
    int maxConcurrencyBeforeServingQueries = parsedConcurrency.getAsInt();

    if (maxConcurrencyBeforeServingQueries == _maxConcurrencyBeforeServingQueries) {
      _logger.info("No ZK update for config: {} value: {}, total permits: {}", configName,
          _maxConcurrencyBeforeServingQueries, totalPermits());
      return;
    }

    _logger.info("Updated config: {} from: {} to: {}", configName, _maxConcurrencyBeforeServingQueries,
        maxConcurrencyBeforeServingQueries);
    _maxConcurrencyBeforeServingQueries = maxConcurrencyBeforeServingQueries;
    if (!_isServingQueries) {
      _logger.info("config: {} was updated before serving queries was enabled, updating the permits", configName);
      _semaphore.setPermits(_maxConcurrencyBeforeServingQueries);
      updateThresholdMetric(_maxConcurrencyBeforeServingQueries);
      _logger.info("Updated total permits: {}", totalPermits());
    }
  }

  /**
   * Returns whether {@code configName} is listed in {@code changedConfigs}, logging a skip message when it is not.
   */
  private boolean isConfigChanged(Set<String> changedConfigs, String configName) {
    if (changedConfigs == null || !changedConfigs.contains(configName)) {
      _logger.info("changedConfigs list indicates config: {} was not updated, skipping updates", configName);
      return false;
    }
    return true;
  }

  /**
   * Reads {@code configName} (falling back to {@code defaultConfigValue}) and parses it as a positive int.
   *
   * @return the parsed value, or empty (after logging a warning) if it is not an int or is not > 0
   */
  private OptionalInt parsePositiveConcurrency(Map<String, String> clusterConfigs, String configName,
      String defaultConfigValue) {
    String configValue =
        clusterConfigs == null ? defaultConfigValue : clusterConfigs.getOrDefault(configName, defaultConfigValue);

    int concurrency;
    try {
      // Integer.parseInt throws NumberFormatException for null as well as malformed input
      concurrency = Integer.parseInt(configValue);
    } catch (NumberFormatException e) {
      _logger.warn("Invalid config {} set to: {}, not making change, fix config and try again", configName,
          configValue);
      return OptionalInt.empty();
    }

    if (concurrency <= 0) {
      _logger.warn("config {}: {} must be > 0, not making change, fix config and try again", configName,
          concurrency);
      return OptionalInt.empty();
    }
    return OptionalInt.of(concurrency);
  }

  /**
   * Block trying to acquire the semaphore to perform the segment operation steps unless interrupted.
   * <p>
   * {@link #release()} should be called after the segment operation completes. It is the responsibility of the caller
   * to ensure that {@link #release()} is called exactly once for each call to this method.
   *
   * @throws InterruptedException if the current thread is interrupted
   */
  public void acquire()
      throws InterruptedException {
    _semaphore.acquire();
    updateCountMetric(_numSegmentsAcquiredSemaphore.incrementAndGet());
  }

  /**
   * Should be called after the segment operation completes. It is the responsibility of the caller to
   * ensure that this method is called exactly once for each call to {@link #acquire()}.
   */
  public void release() {
    _semaphore.release();
    updateCountMetric(_numSegmentsAcquiredSemaphore.decrementAndGet());
  }

  /**
   * Get the estimated number of threads waiting for the semaphore
   * @return the estimated queue length
   */
  public int getQueueLength() {
    return _semaphore.getQueueLength();
  }

  /**
   * Get the number of available permits
   * @return number of available permits
   */
  @VisibleForTesting
  public int availablePermits() {
    return _semaphore.availablePermits();
  }

  /**
   * Get the total number of permits
   * @return total number of permits
   */
  @VisibleForTesting
  public int totalPermits() {
    return _semaphore.getTotalPermits();
  }
}
