/*
 * Decompiled with CFR 0.152.
 */
package org.apache.dolphinscheduler.plugin.task.sql;

import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.node.ArrayNode;
import com.fasterxml.jackson.databind.node.ObjectNode;
import java.sql.Connection;
import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.sql.ResultSetMetaData;
import java.sql.SQLException;
import java.sql.Statement;
import java.text.MessageFormat;
import java.util.ArrayList;
import java.util.Date;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import java.util.stream.Collectors;
import org.apache.commons.collections.CollectionUtils;
import org.apache.dolphinscheduler.plugin.datasource.api.plugin.DataSourceClientProvider;
import org.apache.dolphinscheduler.plugin.datasource.api.utils.CommonUtils;
import org.apache.dolphinscheduler.plugin.datasource.api.utils.DatasourceUtil;
import org.apache.dolphinscheduler.plugin.task.api.AbstractTaskExecutor;
import org.apache.dolphinscheduler.plugin.task.api.TaskException;
import org.apache.dolphinscheduler.plugin.task.sql.SqlBinds;
import org.apache.dolphinscheduler.plugin.task.sql.SqlParameters;
import org.apache.dolphinscheduler.plugin.task.sql.SqlType;
import org.apache.dolphinscheduler.plugin.task.util.MapUtils;
import org.apache.dolphinscheduler.spi.datasource.BaseConnectionParam;
import org.apache.dolphinscheduler.spi.datasource.ConnectionParam;
import org.apache.dolphinscheduler.spi.enums.DataType;
import org.apache.dolphinscheduler.spi.enums.DbType;
import org.apache.dolphinscheduler.spi.enums.TaskTimeoutStrategy;
import org.apache.dolphinscheduler.spi.task.AbstractParameters;
import org.apache.dolphinscheduler.spi.task.Direct;
import org.apache.dolphinscheduler.spi.task.Property;
import org.apache.dolphinscheduler.spi.task.TaskAlertInfo;
import org.apache.dolphinscheduler.spi.task.paramparser.ParamUtils;
import org.apache.dolphinscheduler.spi.task.paramparser.ParameterUtils;
import org.apache.dolphinscheduler.spi.task.request.SQLTaskExecutionContext;
import org.apache.dolphinscheduler.spi.task.request.TaskRequest;
import org.apache.dolphinscheduler.spi.task.request.UdfFuncRequest;
import org.apache.dolphinscheduler.spi.utils.JSONUtils;
import org.apache.dolphinscheduler.spi.utils.StringUtils;
import org.slf4j.Logger;

public class SqlTask
extends AbstractTaskExecutor {
    private TaskRequest taskExecutionContext;
    private SqlParameters sqlParameters;
    private BaseConnectionParam baseConnectionParam;
    private static final String CREATE_FUNCTION_FORMAT = "create temporary function {0} as ''{1}''";
    private static final int QUERY_LIMIT = 10000;

    public SqlTask(TaskRequest taskRequest) {
        super(taskRequest);
        this.taskExecutionContext = taskRequest;
        this.sqlParameters = (SqlParameters)((Object)JSONUtils.parseObject((String)this.taskExecutionContext.getTaskParams(), SqlParameters.class));
        assert (this.sqlParameters != null);
        if (!this.sqlParameters.checkParameters()) {
            throw new RuntimeException("sql task params is not valid");
        }
    }

    public AbstractParameters getParameters() {
        return this.sqlParameters;
    }

    public void handle() throws Exception {
        this.logger.info("Full sql parameters: {}", (Object)this.sqlParameters);
        this.logger.info("sql type : {}, datasource : {}, sql : {} , localParams : {},udfs : {},showType : {},connParams : {},varPool : {} ,query max result limit  {}", new Object[]{this.sqlParameters.getType(), this.sqlParameters.getDatasource(), this.sqlParameters.getSql(), this.sqlParameters.getLocalParams(), this.sqlParameters.getUdfs(), this.sqlParameters.getShowType(), this.sqlParameters.getConnParams(), this.sqlParameters.getVarPool(), this.sqlParameters.getLimit()});
        try {
            SQLTaskExecutionContext sqlTaskExecutionContext = this.taskExecutionContext.getSqlTaskExecutionContext();
            this.baseConnectionParam = (BaseConnectionParam)DatasourceUtil.buildConnectionParams((DbType)DbType.valueOf((String)this.sqlParameters.getType()), (String)sqlTaskExecutionContext.getConnectionParams());
            SqlBinds mainSqlBinds = this.getSqlAndSqlParamsMap(this.sqlParameters.getSql());
            List<SqlBinds> preStatementSqlBinds = ((List)Optional.ofNullable(this.sqlParameters.getPreStatements()).orElse(new ArrayList())).stream().map(this::getSqlAndSqlParamsMap).collect(Collectors.toList());
            List<SqlBinds> postStatementSqlBinds = ((List)Optional.ofNullable(this.sqlParameters.getPostStatements()).orElse(new ArrayList())).stream().map(this::getSqlAndSqlParamsMap).collect(Collectors.toList());
            List<String> createFuncs = SqlTask.createFuncs(sqlTaskExecutionContext.getUdfFuncTenantCodeMap(), sqlTaskExecutionContext.getDefaultFS(), this.logger);
            this.executeFuncAndSql(mainSqlBinds, preStatementSqlBinds, postStatementSqlBinds, createFuncs);
            this.setExitStatusCode(0);
        }
        catch (Exception e) {
            this.setExitStatusCode(-1);
            this.logger.error("sql task error: {}", (Object)e.toString());
            throw e;
        }
    }

    public void executeFuncAndSql(SqlBinds mainSqlBinds, List<SqlBinds> preStatementsBinds, List<SqlBinds> postStatementsBinds, List<String> createFuncs) throws Exception {
        Connection connection = null;
        PreparedStatement stmt = null;
        ResultSet resultSet = null;
        try {
            connection = DataSourceClientProvider.getInstance().getConnection(DbType.valueOf((String)this.sqlParameters.getType()), (ConnectionParam)this.baseConnectionParam);
            if (CollectionUtils.isNotEmpty(createFuncs)) {
                this.createTempFunction(connection, createFuncs);
            }
            this.preSql(connection, preStatementsBinds);
            stmt = this.prepareStatementAndBind(connection, mainSqlBinds);
            String result = null;
            if (this.sqlParameters.getSqlType() == SqlType.QUERY.ordinal()) {
                resultSet = stmt.executeQuery();
                result = this.resultProcess(resultSet);
            } else if (this.sqlParameters.getSqlType() == SqlType.NON_QUERY.ordinal()) {
                String updateResult = String.valueOf(stmt.executeUpdate());
                result = this.setNonQuerySqlReturn(updateResult, this.sqlParameters.getLocalParams());
            }
            this.sqlParameters.dealOutParam(result);
            this.postSql(connection, postStatementsBinds);
            this.close(resultSet, stmt, connection);
        }
        catch (Exception e) {
            try {
                this.logger.error("execute sql error: {}", (Object)e.getMessage());
                throw e;
            }
            catch (Throwable throwable) {
                this.close(resultSet, stmt, connection);
                throw throwable;
            }
        }
    }

    private String setNonQuerySqlReturn(String updateResult, List<Property> properties) {
        String result = null;
        for (Property info : properties) {
            if (Direct.OUT != info.getDirect()) continue;
            ArrayList updateRL = new ArrayList();
            HashMap<String, String> updateRM = new HashMap<String, String>();
            updateRM.put(info.getProp(), updateResult);
            updateRL.add(updateRM);
            result = JSONUtils.toJsonString(updateRL);
            break;
        }
        return result;
    }

    private String resultProcess(ResultSet resultSet) throws Exception {
        ArrayNode resultJSONArray = JSONUtils.createArrayNode();
        if (resultSet != null) {
            int i;
            int limit;
            ResultSetMetaData md = resultSet.getMetaData();
            int num = md.getColumnCount();
            int n = limit = this.sqlParameters.getLimit() == 0 ? 10000 : this.sqlParameters.getLimit();
            for (int rowCount = 0; rowCount < limit && resultSet.next(); ++rowCount) {
                ObjectNode mapOfColValues = JSONUtils.createObjectNode();
                for (i = 1; i <= num; ++i) {
                    mapOfColValues.set(md.getColumnLabel(i), JSONUtils.toJsonNode((Object)resultSet.getObject(i)));
                }
                resultJSONArray.add((JsonNode)mapOfColValues);
            }
            int displayRows = this.sqlParameters.getDisplayRows() > 0 ? this.sqlParameters.getDisplayRows() : 10;
            displayRows = Math.min(displayRows, resultJSONArray.size());
            this.logger.info("display sql result {} rows as follows:", (Object)displayRows);
            for (i = 0; i < displayRows; ++i) {
                String row = JSONUtils.toJsonString((Object)resultJSONArray.get(i));
                this.logger.info("row {} : {}", (Object)(i + 1), (Object)row);
            }
            if (resultSet.next()) {
                this.logger.info("sql result limit : {} exceeding results are filtered", (Object)limit);
                String log = String.format("sql result limit : %d exceeding results are filtered", limit);
                resultJSONArray.add(JSONUtils.toJsonNode((Object)log));
            }
        }
        String result = JSONUtils.toJsonString((Object)resultJSONArray);
        if (this.sqlParameters.getSendEmail() == null || this.sqlParameters.getSendEmail().booleanValue()) {
            this.sendAttachment(this.sqlParameters.getGroupId(), StringUtils.isNotEmpty((CharSequence)this.sqlParameters.getTitle()) ? this.sqlParameters.getTitle() : this.taskExecutionContext.getTaskName() + " query result sets", result);
        }
        this.logger.debug("execute sql result : {}", (Object)result);
        return result;
    }

    private void sendAttachment(int groupId, String title, String content) {
        this.setNeedAlert(Boolean.TRUE);
        TaskAlertInfo taskAlertInfo = new TaskAlertInfo();
        taskAlertInfo.setAlertGroupId(Integer.valueOf(groupId));
        taskAlertInfo.setContent(content);
        taskAlertInfo.setTitle(title);
        this.setTaskAlertInfo(taskAlertInfo);
    }

    private void preSql(Connection connection, List<SqlBinds> preStatementsBinds) throws Exception {
        for (SqlBinds sqlBind : preStatementsBinds) {
            PreparedStatement pstmt = this.prepareStatementAndBind(connection, sqlBind);
            Throwable throwable = null;
            try {
                int result = pstmt.executeUpdate();
                this.logger.info("pre statement execute result: {}, for sql: {}", (Object)result, (Object)sqlBind.getSql());
            }
            catch (Throwable throwable2) {
                throwable = throwable2;
                throw throwable2;
            }
            finally {
                if (pstmt == null) continue;
                if (throwable != null) {
                    try {
                        pstmt.close();
                    }
                    catch (Throwable throwable3) {
                        throwable.addSuppressed(throwable3);
                    }
                    continue;
                }
                pstmt.close();
            }
        }
    }

    private void postSql(Connection connection, List<SqlBinds> postStatementsBinds) throws Exception {
        for (SqlBinds sqlBind : postStatementsBinds) {
            PreparedStatement pstmt = this.prepareStatementAndBind(connection, sqlBind);
            Throwable throwable = null;
            try {
                int result = pstmt.executeUpdate();
                this.logger.info("post statement execute result: {},for sql: {}", (Object)result, (Object)sqlBind.getSql());
            }
            catch (Throwable throwable2) {
                throwable = throwable2;
                throw throwable2;
            }
            finally {
                if (pstmt == null) continue;
                if (throwable != null) {
                    try {
                        pstmt.close();
                    }
                    catch (Throwable throwable3) {
                        throwable.addSuppressed(throwable3);
                    }
                    continue;
                }
                pstmt.close();
            }
        }
    }

    private void createTempFunction(Connection connection, List<String> createFuncs) throws Exception {
        try (Statement funcStmt = connection.createStatement();){
            for (String createFunc : createFuncs) {
                this.logger.info("hive create function sql: {}", (Object)createFunc);
                funcStmt.execute(createFunc);
            }
        }
    }

    private void close(ResultSet resultSet, PreparedStatement pstmt, Connection connection) {
        if (resultSet != null) {
            try {
                resultSet.close();
            }
            catch (SQLException e) {
                this.logger.error("close result set error : {}", (Object)e.getMessage(), (Object)e);
            }
        }
        if (pstmt != null) {
            try {
                pstmt.close();
            }
            catch (SQLException e) {
                this.logger.error("close prepared statement error : {}", (Object)e.getMessage(), (Object)e);
            }
        }
        if (connection != null) {
            try {
                connection.close();
            }
            catch (SQLException e) {
                this.logger.error("close connection error : {}", (Object)e.getMessage(), (Object)e);
            }
        }
    }

    private PreparedStatement prepareStatementAndBind(Connection connection, SqlBinds sqlBinds) {
        boolean timeoutFlag = this.taskExecutionContext.getTaskTimeoutStrategy() == TaskTimeoutStrategy.FAILED || this.taskExecutionContext.getTaskTimeoutStrategy() == TaskTimeoutStrategy.WARNFAILED;
        try {
            Map<Integer, Property> params;
            PreparedStatement stmt = connection.prepareStatement(sqlBinds.getSql());
            if (timeoutFlag) {
                stmt.setQueryTimeout(this.taskExecutionContext.getTaskTimeout());
            }
            if ((params = sqlBinds.getParamsMap()) != null) {
                for (Map.Entry<Integer, Property> entry : params.entrySet()) {
                    Property prop = entry.getValue();
                    ParameterUtils.setInParameter((int)entry.getKey(), (PreparedStatement)stmt, (DataType)prop.getType(), (String)prop.getValue());
                }
            }
            this.logger.info("prepare statement replace sql : {} ", (Object)stmt);
            return stmt;
        }
        catch (Exception exception) {
            throw new TaskException("SQL task prepareStatementAndBind error", (Throwable)exception);
        }
    }

    private void printReplacedSql(String content, String formatSql, String rgex, Map<Integer, Property> sqlParamsMap) {
        this.logger.info("after replace sql , preparing : {}", (Object)formatSql);
        StringBuilder logPrint = new StringBuilder("replaced sql , parameters:");
        if (sqlParamsMap == null) {
            this.logger.info("printReplacedSql: sqlParamsMap is null.");
        } else {
            for (int i = 1; i <= sqlParamsMap.size(); ++i) {
                logPrint.append(sqlParamsMap.get(i).getValue()).append("(").append(sqlParamsMap.get(i).getType()).append(")");
            }
        }
        this.logger.info("Sql Params are {}", (Object)logPrint);
    }

    private SqlBinds getSqlAndSqlParamsMap(String sql) {
        HashMap<Integer, Property> sqlParamsMap = new HashMap<Integer, Property>();
        StringBuilder sqlBuilder = new StringBuilder();
        sql = ParameterUtils.replaceScheduleTime((String)sql, (Date)this.taskExecutionContext.getScheduleTime());
        Map paramsMap = ParamUtils.convert((TaskRequest)this.taskExecutionContext, (AbstractParameters)this.getParameters());
        if (paramsMap == null) {
            sqlBuilder.append(sql);
            return new SqlBinds(sqlBuilder.toString(), sqlParamsMap);
        }
        if (StringUtils.isNotEmpty((CharSequence)this.sqlParameters.getTitle())) {
            String title = ParameterUtils.convertParameterPlaceholders((String)this.sqlParameters.getTitle(), (Map)ParamUtils.convert((Map)paramsMap));
            this.logger.info("SQL title : {}", (Object)title);
            this.sqlParameters.setTitle(title);
        }
        this.setSqlParamsMap(sql, this.rgex, sqlParamsMap, paramsMap, this.taskExecutionContext.getTaskInstanceId());
        String rgexo = "['\"]*\\!\\{(.*?)\\}['\"]*";
        sql = this.replaceOriginalValue(sql, rgexo, paramsMap);
        String formatSql = sql.replaceAll(this.rgex, "?");
        sqlBuilder.append(formatSql);
        this.printReplacedSql(sql, formatSql, this.rgex, sqlParamsMap);
        return new SqlBinds(sqlBuilder.toString(), sqlParamsMap);
    }

    private String replaceOriginalValue(String content, String rgex, Map<String, Property> sqlParamsMap) {
        Matcher m;
        Pattern pattern = Pattern.compile(rgex);
        while ((m = pattern.matcher(content)).find()) {
            String paramName = m.group(1);
            String paramValue = sqlParamsMap.get(paramName).getValue();
            content = m.replaceFirst(paramValue);
        }
        return content;
    }

    public static List<String> createFuncs(Map<UdfFuncRequest, String> udfFuncTenantCodeMap, String defaultFS, Logger logger) {
        if (MapUtils.isEmpty(udfFuncTenantCodeMap)) {
            logger.info("can't find udf function resource");
            return null;
        }
        ArrayList<String> funcList = new ArrayList<String>();
        SqlTask.buildJarSql(funcList, udfFuncTenantCodeMap, defaultFS);
        SqlTask.buildTempFuncSql(funcList, new ArrayList<UdfFuncRequest>(udfFuncTenantCodeMap.keySet()));
        return funcList;
    }

    private static void buildTempFuncSql(List<String> sqls, List<UdfFuncRequest> udfFuncRequests) {
        if (CollectionUtils.isNotEmpty(udfFuncRequests)) {
            for (UdfFuncRequest udfFuncRequest : udfFuncRequests) {
                sqls.add(MessageFormat.format(CREATE_FUNCTION_FORMAT, udfFuncRequest.getFuncName(), udfFuncRequest.getClassName()));
            }
        }
    }

    private static void buildJarSql(List<String> sqls, Map<UdfFuncRequest, String> udfFuncTenantCodeMap, String defaultFS) {
        Set<Map.Entry<UdfFuncRequest, String>> entries = udfFuncTenantCodeMap.entrySet();
        for (Map.Entry<UdfFuncRequest, String> entry : entries) {
            String prefixPath = defaultFS.startsWith("file://") ? "file://" : defaultFS;
            String uploadPath = CommonUtils.getHdfsUdfDir((String)entry.getValue());
            String resourceFullName = entry.getKey().getResourceName();
            resourceFullName = resourceFullName.startsWith("/") ? resourceFullName : String.format("/%s", resourceFullName);
            sqls.add(String.format("add jar %s%s%s", prefixPath, uploadPath, resourceFullName));
        }
    }
}

