/*
 * Decompiled with CFR 0.152.
 */
package org.apache.flink.table.planner.functions.inference;

import java.util.List;
import java.util.stream.Collectors;
import org.apache.calcite.rel.type.RelDataType;
import org.apache.calcite.rel.type.RelDataTypeFactory;
import org.apache.calcite.rel.type.StructKind;
import org.apache.calcite.sql.SqlCallBinding;
import org.apache.calcite.sql.SqlCharStringLiteral;
import org.apache.calcite.sql.SqlDataTypeSpec;
import org.apache.calcite.sql.SqlKind;
import org.apache.calcite.sql.SqlLiteral;
import org.apache.calcite.sql.SqlNode;
import org.apache.calcite.sql.SqlOperandCountRange;
import org.apache.calcite.sql.SqlOperator;
import org.apache.calcite.sql.fun.SqlStdOperatorTable;
import org.apache.calcite.sql.parser.SqlParserPos;
import org.apache.calcite.sql.type.SqlOperandMetadata;
import org.apache.calcite.sql.type.SqlOperandTypeChecker;
import org.apache.calcite.sql.type.SqlTypeName;
import org.apache.calcite.sql.type.SqlTypeUtil;
import org.apache.calcite.sql.validate.SqlValidator;
import org.apache.calcite.sql.validate.SqlValidatorNamespace;
import org.apache.flink.annotation.Internal;
import org.apache.flink.sql.parser.type.SqlRawTypeNameSpec;
import org.apache.flink.table.api.ValidationException;
import org.apache.flink.table.catalog.DataTypeFactory;
import org.apache.flink.table.functions.FunctionDefinition;
import org.apache.flink.table.planner.calcite.FlinkTypeFactory;
import org.apache.flink.table.planner.functions.inference.ArgumentCountRange;
import org.apache.flink.table.planner.functions.inference.CallBindingCallContext;
import org.apache.flink.table.planner.plan.schema.RawRelDataType;
import org.apache.flink.table.planner.typeutils.LogicalRelDataTypeConverter;
import org.apache.flink.table.planner.utils.ShortcutUtils;
import org.apache.flink.table.types.DataType;
import org.apache.flink.table.types.inference.ArgumentCount;
import org.apache.flink.table.types.inference.CallContext;
import org.apache.flink.table.types.inference.ConstantArgumentCount;
import org.apache.flink.table.types.inference.StaticArgument;
import org.apache.flink.table.types.inference.StaticArgumentTrait;
import org.apache.flink.table.types.inference.TypeInference;
import org.apache.flink.table.types.inference.TypeInferenceUtil;
import org.apache.flink.table.types.logical.LogicalType;
import org.apache.flink.table.types.logical.RawType;
import org.apache.flink.table.types.logical.utils.LogicalTypeCasts;
import org.apache.flink.table.types.logical.utils.LogicalTypeChecks;

/**
 * Calcite operand type checker and operand metadata backed by a Flink {@link TypeInference}.
 *
 * <p>Operand validation delegates to {@link TypeInferenceUtil#castArguments}. When validation
 * succeeds, {@code CAST} calls are inserted into the SQL call for operands whose derived type
 * cannot be used as-is for the expected type. Operand metadata (count range, optionality,
 * parameter names/types) is derived from the static arguments of the type inference, if declared.
 */
@Internal
public final class TypeInferenceOperandChecker
        implements SqlOperandTypeChecker, SqlOperandMetadata {

    private final DataTypeFactory dataTypeFactory;
    private final FunctionDefinition definition;
    private final TypeInference typeInference;
    private final SqlOperandCountRange countRange;

    public TypeInferenceOperandChecker(
            DataTypeFactory dataTypeFactory,
            FunctionDefinition definition,
            TypeInference typeInference) {
        this.dataTypeFactory = dataTypeFactory;
        this.definition = definition;
        this.typeInference = typeInference;
        this.countRange = new ArgumentCountRange(deriveArgumentCount(typeInference));
    }

    /**
     * Validates the operands of {@code callBinding} against the type inference and inserts
     * implicit casts where needed.
     *
     * @return {@code true} if the operands are valid; {@code false} if they are invalid and
     *     {@code throwOnFailure} is {@code false}
     * @throws RuntimeException the exception created by {@link
     *     TypeInferenceUtil#createInvalidCallException} on a validation failure when {@code
     *     throwOnFailure} is set, or by {@link TypeInferenceUtil#createUnexpectedException} for
     *     any other throwable
     */
    @Override
    public boolean checkOperandTypes(SqlCallBinding callBinding, boolean throwOnFailure) {
        final CallBindingCallContext callContext =
                new CallBindingCallContext(
                        dataTypeFactory,
                        definition,
                        callBinding,
                        null,
                        typeInference.getStaticArguments().orElse(null));
        try {
            return checkOperandTypesOrError(callBinding, callContext);
        } catch (ValidationException e) {
            if (throwOnFailure) {
                throw TypeInferenceUtil.createInvalidCallException(callContext, e);
            }
            return false;
        } catch (Throwable t) {
            throw TypeInferenceUtil.createUnexpectedException(callContext, t);
        }
    }

    @Override
    public SqlOperandCountRange getOperandCountRange() {
        return countRange;
    }

    @Override
    public String getAllowedSignatures(SqlOperator op, String opName) {
        return TypeInferenceUtil.generateSignature(typeInference, opName, definition);
    }

    @Override
    public SqlOperandTypeChecker.Consistency getConsistency() {
        return SqlOperandTypeChecker.Consistency.NONE;
    }

    /**
     * Returns whether the {@code i}-th static argument is declared optional.
     *
     * <p>Returns {@code false} if no static arguments are declared or if {@code i} is outside the
     * declared argument list (instead of throwing {@link IndexOutOfBoundsException}).
     */
    @Override
    public boolean isOptional(int i) {
        return typeInference
                .getStaticArguments()
                .filter(args -> i >= 0 && i < args.size())
                .map(args -> args.get(i).isOptional())
                .orElse(false);
    }

    /**
     * Returns {@code true} only if static arguments are declared and at least one of them is
     * optional; {@code false} otherwise.
     *
     * <p>NOTE(review): presumably this lets Calcite fill missing optional arguments with {@code
     * DEFAULT} nodes (which {@link #insertImplicitCasts} skips) while leaving other signatures to
     * this checker — confirm against the Calcite version in use.
     */
    @Override
    public boolean isFixedParameters() {
        return typeInference
                .getStaticArguments()
                .map(args -> args.stream().anyMatch(StaticArgument::isOptional))
                .orElse(false);
    }

    /**
     * Returns the Calcite types of the declared static arguments.
     *
     * @throws ValidationException if the type inference declares no static arguments
     */
    @Override
    public List<RelDataType> paramTypes(RelDataTypeFactory typeFactory) {
        return requireStaticArguments().stream()
                .map(arg -> toParamType(typeFactory, arg))
                .collect(Collectors.toList());
    }

    /**
     * Returns the names of the declared static arguments.
     *
     * @throws ValidationException if the type inference declares no static arguments
     */
    @Override
    public List<String> paramNames() {
        return requireStaticArguments().stream()
                .map(StaticArgument::getName)
                .collect(Collectors.toList());
    }

    /** Returns the static arguments or fails if the signature does not declare them. */
    private List<StaticArgument> requireStaticArguments() {
        return typeInference
                .getStaticArguments()
                .orElseThrow(
                        () ->
                                new ValidationException(
                                        "Unsupported function signature. Function must not be overloaded or use varargs."));
    }

    /**
     * Converts a static argument to a Calcite parameter type.
     *
     * <p>Arguments without a data type map to {@code ANY}. Table arguments with a composite type
     * map to a fully qualified struct type built from the composite's fields.
     */
    private RelDataType toParamType(RelDataTypeFactory typeFactory, StaticArgument arg) {
        final LogicalType type = arg.getDataType().map(DataType::getLogicalType).orElse(null);
        if (type == null) {
            return typeFactory.createSqlType(SqlTypeName.ANY);
        }
        if (arg.is(StaticArgumentTrait.TABLE) && LogicalTypeChecks.isCompositeType(type)) {
            final List<RelDataType> fieldTypes =
                    LogicalTypeChecks.getFieldTypes(type).stream()
                            .map(t -> LogicalRelDataTypeConverter.toRelDataType(t, typeFactory))
                            .collect(Collectors.toList());
            return typeFactory.createStructType(
                    StructKind.FULLY_QUALIFIED, fieldTypes, LogicalTypeChecks.getFieldNames(type));
        }
        return LogicalRelDataTypeConverter.toRelDataType(type, typeFactory);
    }

    /**
     * Casts the arguments according to the type inference and rewrites the call's operands.
     *
     * @return always {@code true}; failures are reported by throwing
     * @throws ValidationException wrapped via {@link TypeInferenceUtil#createInvalidInputException}
     *     if the arguments cannot be cast
     */
    private boolean checkOperandTypesOrError(SqlCallBinding callBinding, CallContext callContext) {
        final CallContext castCallContext;
        try {
            castCallContext = TypeInferenceUtil.castArguments(typeInference, callContext, null);
        } catch (ValidationException e) {
            throw TypeInferenceUtil.createInvalidInputException(typeInference, callContext, e);
        }
        insertImplicitCasts(callBinding, castCallContext.getArgumentDataTypes());
        return true;
    }

    /**
     * Wraps every operand whose derived type cannot be used as-is for the expected type in a
     * {@code CAST}, replaces it in the call, and records the expected type with the validator.
     * {@code DEFAULT} operands are left untouched.
     */
    private void insertImplicitCasts(SqlCallBinding callBinding, List<DataType> expectedDataTypes) {
        final FlinkTypeFactory flinkTypeFactory = ShortcutUtils.unwrapTypeFactory(callBinding);
        final List<SqlNode> operands = callBinding.operands();
        for (int i = 0; i < operands.size(); i++) {
            final LogicalType expectedType = expectedDataTypes.get(i).getLogicalType();
            final SqlNode operand = operands.get(i);
            // DEFAULT placeholders carry no type of their own and must not be cast.
            if (operand.getKind() == SqlKind.DEFAULT) {
                continue;
            }
            final LogicalType argumentType =
                    FlinkTypeFactory.toLogicalType(SqlTypeUtil.deriveType(callBinding, operand));
            if (LogicalTypeCasts.supportsAvoidingCast(argumentType, expectedType)) {
                continue;
            }
            final RelDataType expectedRelDataType =
                    flinkTypeFactory.createFieldTypeFromLogicalType(expectedType);
            final SqlNode castedOperand = castTo(operand, expectedRelDataType);
            callBinding.getCall().setOperand(i, castedOperand);
            updateInferredType(callBinding.getValidator(), castedOperand, expectedRelDataType);
        }
    }

    /** Creates a {@code CAST(node AS type)} call; RAW types get a dedicated type spec. */
    private SqlNode castTo(SqlNode node, RelDataType type) {
        final SqlDataTypeSpec dataType =
                type instanceof RawRelDataType
                        ? createRawDataTypeSpec((RawRelDataType) type)
                        : SqlTypeUtil.convertTypeToSpec(type).withNullable(type.isNullable());
        return SqlStdOperatorTable.CAST.createCall(SqlParserPos.ZERO, node, dataType);
    }

    /** Builds a RAW type spec from the originating class name and serializer string. */
    private SqlDataTypeSpec createRawDataTypeSpec(RawRelDataType type) {
        final RawType<?> rawType = type.getRawType();
        final SqlCharStringLiteral className =
                SqlLiteral.createCharString(
                        rawType.getOriginatingClass().getName(), SqlParserPos.ZERO);
        final SqlCharStringLiteral serializer =
                SqlLiteral.createCharString(rawType.getSerializerString(), SqlParserPos.ZERO);
        final SqlRawTypeNameSpec rawSpec =
                new SqlRawTypeNameSpec(className, serializer, SqlParserPos.ZERO);
        return new SqlDataTypeSpec(rawSpec, null, type.isNullable(), SqlParserPos.ZERO);
    }

    /** Registers {@code type} as the validated type of {@code node} and of its namespace, if any. */
    private void updateInferredType(SqlValidator validator, SqlNode node, RelDataType type) {
        validator.setValidatedNodeType(node, type);
        final SqlValidatorNamespace namespace = validator.getNamespace(node);
        if (namespace != null) {
            namespace.setType(type);
        }
    }

    /**
     * Derives the accepted argument count: the input type strategy's count when no static
     * arguments are declared, otherwise {@code [#required, #declared]}.
     */
    private static ArgumentCount deriveArgumentCount(TypeInference typeInference) {
        final List<StaticArgument> staticArgs =
                typeInference.getStaticArguments().orElse(null);
        if (staticArgs == null) {
            return typeInference.getInputTypeStrategy().getArgumentCount();
        }
        final int total = staticArgs.size();
        final int optional = (int) staticArgs.stream().filter(StaticArgument::isOptional).count();
        return ConstantArgumentCount.between(total - optional, total);
    }
}

