/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.ml.common.connector.functions.postprocess;

import java.lang.reflect.Array;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import org.opensearch.ml.common.connector.functions.postprocess.ConnectorPostProcessFunction;
import org.opensearch.ml.common.output.model.MLResultDataType;
import org.opensearch.ml.common.output.model.ModelTensor;

/**
 * Connector post-process function that unwraps a response produced by a remote ML Commons
 * endpoint into {@link ModelTensor}s.
 *
 * <p>Response shape this function understands (only the keys it reads are shown):
 *
 * <pre>{@code
 * {
 *   "inference_results": [
 *     { "output": [ { "name": ..., "data_type": ..., "shape": [...], "data": [...],
 *                     "result": ..., "dataAsMap": {...} } ] }
 *   ]
 * }
 * }</pre>
 *
 * Any other response map is passed through unchanged as a single tensor named
 * {@code "response"}.
 */
public class RemoteMlCommonsPassthroughPostProcessFunction
extends ConnectorPostProcessFunction<Map<String, Object>> {
    private static final String INFERENCE_RESULTS_FIELD = "inference_results";
    private static final String OUTPUT_FIELD = "output";
    private static final String NAME_FIELD = "name";
    private static final String DATA_TYPE_FIELD = "data_type";
    private static final String SHAPE_FIELD = "shape";
    private static final String DATA_FIELD = "data";
    private static final String RESULT_FIELD = "result";
    private static final String DATA_AS_MAP_FIELD = "dataAsMap";
    /** Tensor name used when an output entry has no string {@code name}. */
    private static final String DEFAULT_OUTPUT_NAME = "output";
    /** Name of the single tensor produced for responses without {@code inference_results}. */
    private static final String PASSTHROUGH_TENSOR_NAME = "response";

    /**
     * Checks that the raw connector response has a type this function accepts.
     *
     * @param input the raw response object
     * @throws IllegalArgumentException if {@code input} is neither a {@link Map} nor a {@link List}
     */
    @Override
    public void validate(Object input) {
        if (!(input instanceof Map) && !(input instanceof List)) {
            throw new IllegalArgumentException("Post process function input must be a Map or List");
        }
    }

    /**
     * Converts an ML Commons response into model tensors.
     *
     * <p>When {@code inference_results} is a list, every map found in each result's
     * {@code output} list becomes one tensor. The following are skipped rather than failing the
     * whole response:
     *
     * <ul>
     *   <li>entries that are not maps;
     *   <li>results without an {@code output} list;
     *   <li>empty output maps.
     * </ul>
     *
     * Otherwise the whole response is wrapped, via {@code dataAsMap}, in a single tensor named
     * {@code "response"}.
     *
     * @param mlCommonsResponse the parsed response; must not be {@code null}
     * @param dataType not used; each tensor's data type is read from its own {@code data_type}
     *     entry
     * @return the extracted tensors (possibly empty), or a single pass-through tensor
     */
    @Override
    public List<ModelTensor> process(Map<String, Object> mlCommonsResponse, MLResultDataType dataType) {
        if (mlCommonsResponse.get(INFERENCE_RESULTS_FIELD) instanceof List<?> inferenceResults) {
            List<ModelTensor> modelTensors = new ArrayList<>();
            for (Object result : inferenceResults) {
                if (!(result instanceof Map<?, ?> resultMap)
                        || !(resultMap.get(OUTPUT_FIELD) instanceof List<?> outputs)) {
                    continue;
                }
                for (Object output : outputs) {
                    if (!(output instanceof Map<?, ?> outputMap)) {
                        continue;
                    }
                    ModelTensor modelTensor = this.createModelTensorFromMap(outputMap);
                    if (modelTensor != null) {
                        modelTensors.add(modelTensor);
                    }
                }
            }
            return modelTensors;
        }
        ModelTensor tensor = ModelTensor.builder().name(PASSTHROUGH_TENSOR_NAME).dataAsMap(mlCommonsResponse).build();
        return List.of(tensor);
    }

    /**
     * Builds one tensor from a single {@code output} entry.
     *
     * <p>Fields with an unexpected type are treated as absent:
     *
     * <ul>
     *   <li>a missing or non-string {@code name} becomes {@code "output"};
     *   <li>a non-string {@code result}, a non-map {@code dataAsMap}, and an unknown
     *       {@code data_type} become {@code null};
     *   <li>a {@code shape} containing non-numeric items is dropped, since a partial shape would
     *       be wrong.
     * </ul>
     *
     * @param map the output entry; may be {@code null}
     * @return the tensor, or {@code null} when {@code map} is {@code null} or empty
     */
    private ModelTensor createModelTensorFromMap(Map<?, ?> map) {
        if (map == null || map.isEmpty()) {
            return null;
        }
        String name = map.get(NAME_FIELD) instanceof String castedName ? castedName : DEFAULT_OUTPUT_NAME;
        String result = map.get(RESULT_FIELD) instanceof String castedResult ? castedResult : null;

        Map<String, Object> dataAsMap = null;
        if (map.get(DATA_AS_MAP_FIELD) instanceof Map<?, ?> rawDataAsMap) {
            // Parsed JSON objects have String keys; the value type is left to the consumer.
            @SuppressWarnings("unchecked")
            Map<String, Object> typedDataAsMap = (Map<String, Object>) rawDataAsMap;
            dataAsMap = typedDataAsMap;
        }

        MLResultDataType dataType = null;
        if (map.get(DATA_TYPE_FIELD) instanceof String dataTypeName) {
            try {
                dataType = MLResultDataType.valueOf(dataTypeName);
            }
            catch (IllegalArgumentException ignored) {
                // Best effort: an unrecognized data type name is tolerated and the tensor is
                // built without a data type instead of failing the whole response.
            }
        }

        long[] shape = null;
        Number[] shapeValues = RemoteMlCommonsPassthroughPostProcessFunction.processNumericalArray(map, SHAPE_FIELD, Number.class);
        // processNumericalArray leaves null in slots whose item is not a Number; longValue() on
        // such a slot would throw, so an incompletely numeric shape is treated as absent.
        if (shapeValues != null && Arrays.stream(shapeValues).noneMatch(value -> value == null)) {
            shape = Arrays.stream(shapeValues).mapToLong(Number::longValue).toArray();
        }

        // Non-numeric items keep their position as null entries, matching the list's length.
        Number[] data = RemoteMlCommonsPassthroughPostProcessFunction.processNumericalArray(map, DATA_FIELD, Number.class);

        return ModelTensor.builder().name(name).dataType(dataType).shape(shape).data(data).result(result).dataAsMap(dataAsMap).build();
    }

    /**
     * Copies the list stored under {@code key} into a typed array of the same length.
     *
     * @param map the map to read from
     * @param key the key whose value is expected to be a {@link List}
     * @param type the element class; items that are not instances of it leave a {@code null} slot
     * @param <T> the element type
     * @return the typed array, or {@code null} if the value is absent or not a {@link List}
     */
    private static <T> T[] processNumericalArray(Map<?, ?> map, String key, Class<T> type) {
        if (!(map.get(key) instanceof List<?> list)) {
            return null;
        }
        // Array.newInstance(type, n) creates an array whose runtime component type is T.
        @SuppressWarnings("unchecked")
        T[] array = (T[]) Array.newInstance(type, list.size());
        int index = 0;
        for (Object item : list) {
            if (type.isInstance(item)) {
                array[index] = type.cast(item);
            }
            index++;
        }
        return array;
    }
}

