/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.ml.action.memorycontainer.memory;

import java.io.IOException;
import java.time.DateTimeException;
import java.time.Instant;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import lombok.Generated;
import org.apache.commons.lang3.StringUtils;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.OpenSearchException;
import org.opensearch.OpenSearchStatusException;
import org.opensearch.action.search.SearchRequest;
import org.opensearch.action.search.SearchResponse;
import org.opensearch.action.support.ActionFilters;
import org.opensearch.action.support.HandledTransportAction;
import org.opensearch.common.inject.Inject;
import org.opensearch.commons.authuser.User;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.rest.RestStatus;
import org.opensearch.core.xcontent.NamedXContentRegistry;
import org.opensearch.core.xcontent.XContentBuilder;
import org.opensearch.index.query.QueryBuilder;
import org.opensearch.index.query.QueryBuilders;
import org.opensearch.ml.common.memorycontainer.MLMemoryContainer;
import org.opensearch.ml.common.memorycontainer.MemoryStorageConfig;
import org.opensearch.ml.common.memorycontainer.MemoryType;
import org.opensearch.ml.common.settings.MLCommonsSettings;
import org.opensearch.ml.common.settings.MLFeatureEnabledSetting;
import org.opensearch.ml.common.transport.memorycontainer.memory.MLSearchMemoriesInput;
import org.opensearch.ml.common.transport.memorycontainer.memory.MLSearchMemoriesRequest;
import org.opensearch.ml.common.transport.memorycontainer.memory.MLSearchMemoriesResponse;
import org.opensearch.ml.common.transport.memorycontainer.memory.MemorySearchResult;
import org.opensearch.ml.helper.ConnectorAccessControlHelper;
import org.opensearch.ml.helper.MemoryContainerHelper;
import org.opensearch.ml.utils.MemorySearchQueryBuilder;
import org.opensearch.ml.utils.RestActionUtils;
import org.opensearch.ml.utils.TenantAwareHelper;
import org.opensearch.search.SearchHit;
import org.opensearch.search.builder.SearchSourceBuilder;
import org.opensearch.tasks.Task;
import org.opensearch.transport.TransportService;
import org.opensearch.transport.client.Client;

public class TransportSearchMemoriesAction
extends HandledTransportAction<MLSearchMemoriesRequest, MLSearchMemoriesResponse> {
    @Generated
    private static final Logger log = LogManager.getLogger(TransportSearchMemoriesAction.class);
    private final Client client;
    private final ConnectorAccessControlHelper connectorAccessControlHelper;
    private final MLFeatureEnabledSetting mlFeatureEnabledSetting;
    private final NamedXContentRegistry xContentRegistry;
    private final MemoryContainerHelper memoryContainerHelper;

    @Inject
    public TransportSearchMemoriesAction(TransportService transportService, ActionFilters actionFilters, Client client, ConnectorAccessControlHelper connectorAccessControlHelper, MLFeatureEnabledSetting mlFeatureEnabledSetting, NamedXContentRegistry xContentRegistry, MemoryContainerHelper memoryContainerHelper) {
        super("cluster:admin/opensearch/ml/memory_containers/memories/search", transportService, actionFilters, MLSearchMemoriesRequest::new);
        this.client = client;
        this.connectorAccessControlHelper = connectorAccessControlHelper;
        this.mlFeatureEnabledSetting = mlFeatureEnabledSetting;
        this.xContentRegistry = xContentRegistry;
        this.memoryContainerHelper = memoryContainerHelper;
    }

    protected void doExecute(Task task, MLSearchMemoriesRequest request, ActionListener<MLSearchMemoriesResponse> actionListener) {
        if (!this.mlFeatureEnabledSetting.isAgenticMemoryEnabled()) {
            actionListener.onFailure((Exception)new OpenSearchStatusException(MLCommonsSettings.ML_COMMONS_AGENTIC_MEMORY_DISABLED_MESSAGE, RestStatus.FORBIDDEN, new Object[0]));
            return;
        }
        MLSearchMemoriesInput input = request.getMlSearchMemoriesInput();
        String tenantId = request.getTenantId();
        if (input == null) {
            actionListener.onFailure((Exception)new IllegalArgumentException("Search memories input is required"));
            return;
        }
        if (StringUtils.isBlank((CharSequence)input.getMemoryContainerId())) {
            actionListener.onFailure((Exception)new IllegalArgumentException("Memory container ID is required"));
            return;
        }
        if (!TenantAwareHelper.validateTenantId(this.mlFeatureEnabledSetting, tenantId, actionListener)) {
            return;
        }
        this.memoryContainerHelper.getMemoryContainer(input.getMemoryContainerId(), tenantId, (ActionListener<MLMemoryContainer>)ActionListener.wrap(container -> {
            User user = RestActionUtils.getUserContext(this.client);
            if (!this.memoryContainerHelper.checkMemoryContainerAccess(user, (MLMemoryContainer)container)) {
                actionListener.onFailure((Exception)new OpenSearchStatusException("User doesn't have permissions to search memories in this container", RestStatus.FORBIDDEN, new Object[0]));
                return;
            }
            this.searchMemories(input, (MLMemoryContainer)container, actionListener);
        }, arg_0 -> actionListener.onFailure(arg_0)));
    }

    private void searchMemories(MLSearchMemoriesInput input, MLMemoryContainer container, ActionListener<MLSearchMemoriesResponse> actionListener) {
        try {
            MemoryStorageConfig storageConfig = container.getMemoryStorageConfig();
            String indexName = storageConfig != null ? storageConfig.getMemoryIndexName() : "ml-static-memory-" + container.getName().toLowerCase() + "-" + RestActionUtils.getUserContext(this.client).getName();
            SearchRequest searchRequest = this.buildSearchRequest(input.getQuery(), storageConfig, indexName);
            this.client.search(searchRequest, ActionListener.wrap(response -> {
                try {
                    MLSearchMemoriesResponse searchResponse = this.parseSearchResponse((SearchResponse)response);
                    actionListener.onResponse((Object)searchResponse);
                }
                catch (Exception e) {
                    log.error("Failed to parse search response", (Throwable)e);
                    actionListener.onFailure((Exception)new OpenSearchException("Failed to parse search response", (Throwable)e, new Object[0]));
                }
            }, e -> {
                log.error("Search execution failed", (Throwable)e);
                actionListener.onFailure((Exception)new OpenSearchException("Search execution failed: " + e.getMessage(), (Throwable)e, new Object[0]));
            }));
        }
        catch (Exception e2) {
            log.error("Failed to build search request", (Throwable)e2);
            actionListener.onFailure((Exception)new OpenSearchException("Failed to build search request: " + e2.getMessage(), (Throwable)e2, new Object[0]));
        }
    }

    private SearchRequest buildSearchRequest(String query, MemoryStorageConfig storageConfig, String indexName) throws IOException {
        XContentBuilder queryBuilder = MemorySearchQueryBuilder.buildQueryByStorageType(query, storageConfig);
        SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder();
        searchSourceBuilder.query((QueryBuilder)QueryBuilders.wrapperQuery((String)queryBuilder.toString()));
        searchSourceBuilder.fetchSource(null, new String[]{"memory_embedding"});
        return new SearchRequest().indices(new String[]{indexName}).source(searchSourceBuilder);
    }

    private MLSearchMemoriesResponse parseSearchResponse(SearchResponse searchResponse) throws IOException {
        ArrayList<MemorySearchResult> results = new ArrayList<MemorySearchResult>();
        float maxScore = searchResponse.getHits().getMaxScore();
        for (SearchHit hit : searchResponse.getHits().getHits()) {
            Map sourceMap = hit.getSourceAsMap();
            String memoryId = hit.getId();
            String memory = (String)sourceMap.get("memory");
            float score = hit.getScore();
            String sessionId = (String)sourceMap.get("session_id");
            String agentId = (String)sourceMap.get("agent_id");
            String userId = (String)sourceMap.get("user_id");
            String role = (String)sourceMap.get("role");
            MemoryType memoryType = null;
            String memoryTypeStr = (String)sourceMap.get("memory_type");
            if (memoryTypeStr != null) {
                try {
                    memoryType = MemoryType.valueOf((String)memoryTypeStr);
                }
                catch (IllegalArgumentException e) {
                    log.warn("Invalid memory type: {}", (Object)memoryTypeStr);
                }
            }
            Map tags = (Map)sourceMap.get("tags");
            Instant createdTime = null;
            Instant lastUpdatedTime = null;
            try {
                Object lastUpdatedTimeObj;
                Object createdTimeObj = sourceMap.get("created_time");
                if (createdTimeObj instanceof Number) {
                    createdTime = Instant.ofEpochMilli(((Number)createdTimeObj).longValue());
                }
                if ((lastUpdatedTimeObj = sourceMap.get("last_updated_time")) instanceof Number) {
                    lastUpdatedTime = Instant.ofEpochMilli(((Number)lastUpdatedTimeObj).longValue());
                }
            }
            catch (Exception e) {
                log.warn("Failed to parse timestamps", (Throwable)e);
            }
            MemorySearchResult result = MemorySearchResult.builder().memoryId(memoryId).memory(memory).score(score).sessionId(sessionId).agentId(agentId).userId(userId).memoryType(memoryType).role(role).tags(tags).createdTime(createdTime).lastUpdatedTime(lastUpdatedTime).build();
            results.add(result);
        }
        return MLSearchMemoriesResponse.builder().hits(results).totalHits(searchResponse.getHits().getTotalHits().value()).maxScore(maxScore).timedOut(searchResponse.isTimedOut()).build();
    }
}

