import { HttpClient, HttpHeaders, HttpParams } from '@angular/common/http';
import { Injectable } from '@angular/core';
import {
  combineLatest,
  Observable,
  of,
  throwError,
} from 'rxjs';
import {
  catchError,
  concatAll,
  delay,
  map,
  retryWhen,
  switchMap,
  take,
  tap,
  toArray,
} from 'rxjs/operators';
import { environment } from 'src/environments/environment';
import {
  TrainingSession,
  GenericParameter,
  MappingNode,
  TrainingModel,
  TrainingParadigm,
  StopCondition,
  TrainingSessionList,
  NodeList,
  Infrastructure,
  PrivateAttrs,
  TrainingSessionStatus,
  TrainingSessionRounds,
  UrlModelParameters,
  TrainingSessionState,
  EntityStatus,
  MetricsConfig,
} from '../models/';
import { Node, TrainingActions } from 'src/app/platform';
import { NodeService } from './node.service';
import { TrainingService } from 'src/app/platform/services/training.service';
import { getTrainedText, getUtcDateFormatted } from 'src/app/utils';

// JSON content-type header, passed explicitly as `{ headers }` by several requests below.
const headers = new HttpHeaders({ 'Content-Type': 'application/json' });

/**
 * Lookup table from entity-status codes to a training-session state.
 *
 * `calculateTrainingSessionState` scans this array from first to last entry and returns the
 * `state` of the first entry whose `ids` share a code with the session. ARRAY ORDER is
 * therefore what decides precedence. `priority` documents that order for readers, but the
 * lookup only reads `ids` and `state`. Two entries share priority 1, and `aborted` wins
 * over `stopped` purely because it comes first.
 */
const trainingSessionStateCodesPriority = [
  { priority: 1, ids: [3023, 3024, 3025, 3026, 3027, 3028, 3029], state: TrainingSessionState.aborted },
  { priority: 1, ids: [3022], state: TrainingSessionState.stopped },
  { priority: 2, ids: [3021], state: TrainingSessionState.finished },
  { priority: 3, ids: [3005], state: TrainingSessionState.paused },
  { priority: 4, ids: [3004, 3006], state: TrainingSessionState.running },
  { priority: 5, ids: [0, 3007], state: TrainingSessionState.pending },
];
@Injectable({
  providedIn: 'root',
})
export class TrainingSessionService {
  baseUrl: string = environment.baseUrl;

  // `_nodeService` and `_trainingService` are not referenced in this class; they are kept so
  // the injected constructor signature stays unchanged.
  constructor(
    private http: HttpClient,
    private _nodeService: NodeService,
    private _trainingService: TrainingService
  ) {}

  /**
   * Derives a session state from its entity-status codes.
   *
   * Walks `trainingSessionStateCodesPriority` in array order and returns the state of the
   * first entry whose `ids` contain any of the given status codes.
   *
   * @param entityStatus Status entries reported for the session.
   * @returns The matching state, or `undefined` when no statuses are given or none match.
   */
  private calculateTrainingSessionState = (
    entityStatus?: EntityStatus[]
  ): TrainingSessionState | undefined => {
    if (!entityStatus) {
      return undefined;
    }

    const presentCodes = new Set(entityStatus.map((entity) => entity.id));
    const match = trainingSessionStateCodesPriority.find(({ ids }) =>
      ids.some((code) => presentCodes.has(code))
    );
    return match?.state;
  };

  /**
   * Loads a training session with its node mappings, derived `state`, and `rounds`.
   *
   * @param org Organization id used in the request path.
   * @param project Project id used in the request path.
   * @param timestamp Session timestamp identifying the session.
   * @returns The result of `getTrainingSessionWithoutRounds`, extended with the
   *   `rounds` produced by `getTrainingSessionRounds`.
   */
  getTrainingSession(
    org: any,
    project: any,
    timestamp: any
  ): Observable<TrainingSession> {
    return this.getTrainingSessionWithoutRounds(org, project, timestamp).pipe(
      switchMap((trainingSession: TrainingSession) =>
        this.getTrainingSessionRounds(org, project, trainingSession).pipe(
          map((rounds) => ({ ...trainingSession, rounds }))
        )
      )
    );
  }

  /**
   * Loads a training session and its node mappings, and attaches the derived `state`.
   *
   * The session (GET `orgs/{org}/projects/{project}/sessions/{timestamp}`) and its mappings
   * are requested together. A single merged object is emitted once both have answered.
   */
  getTrainingSessionWithoutRounds(
    org: any,
    project: any,
    timestamp: any
  ): Observable<TrainingSession> {
    const sessionRequest = this.http.get<TrainingSession>(
      `${this.baseUrl}orgs/${org}/projects/${project}/sessions/${timestamp}`
    );
    const mappingsRequest = this.allTrainingSessionNodeMappings(org, project, timestamp);

    return combineLatest([sessionRequest, mappingsRequest]).pipe(
      map(([session, mappings]) => ({
        ...session,
        mappings,
        state: this.calculateTrainingSessionState(session.entityStatus),
      }))
    );
  }

  /**
   * Fetches the execution status of a session.
   *
   * `summaryTrainingSessionExecution.lastRound` is incremented by one, the same way
   * `getTrainingSessionRounds` increments `round`/`epoch`. This presumably converts a
   * 0-based index to 1-based; confirm with the API.
   *
   * @returns The status response. Errors are re-thrown as the HTTP error body (`res.error`).
   */
  getTrainingSessionStatus(
    orgId: string | undefined,
    projectId: string | undefined,
    trainingId: string | undefined,
    trainingStatusTimestamp: string | undefined
  ): Observable<TrainingSessionStatus | never> {
    return this.http
      .get<TrainingSessionStatus>(
        `${this.baseUrl}orgs/${orgId}/projects/${projectId}/trainings/${trainingId}/sessions/${trainingStatusTimestamp}/status`,
        { headers }
      )
      .pipe(
        map((res) => {
          // res.animation = (TrainingAnimationStatus as any)[res.status]; TO DO
          const execution = res.summaryTrainingSessionExecution;
          // Guarded: without it, a response lacking execution data threw a TypeError.
          // The catchError below then turned that into an `undefined` error, and a missing
          // lastRound became NaN.
          if (execution && typeof execution.lastRound === 'number') {
            execution.lastRound++;
          }
          return res;
        }),
        catchError((res) => throwError(res?.error))
      );
  }

  /**
   * Fetches every page of rounds for `session` and normalizes the combined result.
   *
   * Pages are requested 100 items at a time and follow `nextItem` until it is empty or the
   * string `'null'`. The merged items are then post-processed by `normalizeRounds`.
   *
   * @param session Session object. `trainingSessionTimestamp`, `state` and
   *   `summaryTrainingSessionExecution.lastRound` are read from it.
   * @param nextPageToken Page start point for the first request (default `'0'`).
   * @param accumulatedItems Items to prepend to the fetched ones (default empty).
   * @returns The normalized rounds. Errors are re-thrown as the HTTP error body
   *   (`error.error`), unwrapped exactly once whichever page failed.
   */
  getTrainingSessionRounds(
    orgId: string | undefined,
    projectId: string | undefined,
    session: any,
    nextPageToken: string = '0',
    accumulatedItems: any[] = []
  ): Observable<TrainingSessionRounds | never> {
    return this.fetchRoundPages(orgId, projectId, session, nextPageToken, accumulatedItems).pipe(
      map((response) => this.normalizeRounds(response, session)),
      // Applied once around the whole pagination chain. Previously every recursive page had
      // its own unwrap, so a failure on page 2+ reached callers as `error.error.error`
      // (usually undefined).
      catchError((error) => throwError(error?.error))
    );
  }

  /**
   * Recursively requests round pages and concatenates their items.
   * Emits the last page's response with `items` replaced by the concatenation of all pages.
   */
  private fetchRoundPages(
    orgId: string | undefined,
    projectId: string | undefined,
    session: any,
    pageToken: string,
    accumulatedItems: any[]
  ): Observable<TrainingSessionRounds> {
    const params = new HttpParams()
      .append('pageStartPoint', pageToken)
      .append('pageSize', '100')
      .append('page', '0');

    return this.http
      .get<TrainingSessionRounds>(
        `${this.baseUrl}orgs/${orgId}/projects/${projectId}/sessions/${session.trainingSessionTimestamp}/rounds`,
        { params }
      )
      .pipe(
        switchMap((response) => {
          // `?? []` stops a page without `items` from appending a literal `undefined` entry,
          // which would later crash the sort in normalizeRounds.
          const items = accumulatedItems.concat(response.items ?? []);
          if (response.nextItem && response.nextItem !== 'null') {
            return this.fetchRoundPages(orgId, projectId, session, response.nextItem, items);
          }
          response.items = items;
          return of(response);
        })
      );
  }

  /**
   * Sorts rounds, fills the summary fields, increments `round`/`epoch`, and trims the list.
   * Mutates and returns `res`. A response with no items is returned untouched.
   */
  private normalizeRounds(res: TrainingSessionRounds, session: any): TrainingSessionRounds {
    if (!res.items || res.items.length === 0) {
      return res;
    }

    res.items.sort((a, b) => a.round - b.round);

    const first = res.items[0];
    const last = res.items[res.items.length - 1];
    res.startTime = first.stamp?.created;
    res.startTimestampFirstRound = first.startTimestamp;
    res.endTimestampLastRound = last.endTimestamp;

    res.items.forEach((roundData) => {
      roundData.round++;
      roundData.epoch++;
    });
    // The highest sorted round (same object the original loop left in currentRound).
    res.currentRound = last;

    const lastRound = session.summaryTrainingSessionExecution?.lastRound;
    if (lastRound) {
      // NOTE(review): index 0 is always dropped and a lastRound of 0 skips trimming.
      // Presumably intentional (initial round?); confirm. The end bound is one shorter
      // while the session state is 'RUNNING'. Also confirm 'RUNNING' equals
      // TrainingSessionState.running, which is what calculateTrainingSessionState assigns.
      const end = session.state === 'RUNNING' ? lastRound : lastRound + 1;
      res.items = res.items.slice(1, end);
    }
    return res;
  }

  /**
   * Returns the single session of a training.
   * Errors if the training has zero sessions or more than one session.
   */
  getTrainingSessionId(
    org: any,
    project: any,
    training: any
  ): Observable<TrainingSession> {
    return this.http
      .get<TrainingSessionList>(
        `${this.baseUrl}orgs/${org}/projects/${project}/trainings/${training}/sessions`
      )
      .pipe(
        map((res: TrainingSessionList) => {
          if (res.items.length === 1) {
            return res.items[0];
          }
          throw new Error(
            res.items.length > 1
              ? 'Error: More than one training session found'
              : 'Error: No session found'
          );
        })
      );
  }

  /** Lists every session of a project, each with its derived `state` assigned in place. */
  getAllTrainingSession(
    org: any,
    project: any
  ): Observable<TrainingSession[]> {
    return this.http
      .get<TrainingSessionList>(`${this.baseUrl}orgs/${org}/projects/${project}/sessions`)
      .pipe(
        map((res: TrainingSessionList) =>
          res.items.map((trainingSession) => {
            trainingSession.state = this.calculateTrainingSessionState(trainingSession.entityStatus);
            return trainingSession;
          })
        )
      );
  }

  /** Fetches the node mappings of a session (page 0, size 100, `pagination=false`). */
  allTrainingSessionNodeMappings(
    org: any,
    project: any,
    session: any
  ): Observable<MappingNode[]> {
    const params = new HttpParams()
      .set('page', '0')
      .set('pageSize', '100')
      .set('pagination', false);

    return this.http
      .get<{ pageSize: number; page: number; items: MappingNode[] }>(
        `${this.baseUrl}orgs/${org}/projects/${project}/sessions/${session}/nodes/mappings`,
        { headers, params }
      )
      .pipe(map((res) => res.items));
  }

  /** POSTs `.../sessions/{timestamp}/start`. Errors are re-thrown as the HTTP error body. */
  startTraining(
    orgId: any,
    projectId: any,
    timestamp: string | undefined
  ) {
    return this.http
      .post(
        `${this.baseUrl}orgs/${orgId}/projects/${projectId}/sessions/${timestamp}/start`,
        {},
        { headers }
      )
      .pipe(catchError((res) => throwError(res?.error)));
  }

  /**
   * POSTs `.../sessions/{timestamp}/stop`.
   * Unlike startTraining, errors are re-thrown as the error body's `message` only.
   */
  stopTraining(
    orgId: any,
    projectId: any,
    timestamp: string | undefined
  ) {
    return this.http
      .post(
        `${this.baseUrl}orgs/${orgId}/projects/${projectId}/sessions/${timestamp}/stop`,
        {},
        { headers }
      )
      .pipe(catchError((res) => throwError(res?.error?.message)));
  }

  /** Fetches the paradigm configured for a training. */
  getParadigm(
    org: any,
    project: any,
    training: any
  ): Observable<TrainingParadigm> {
    return this.http.get<TrainingParadigm>(
      `${this.baseUrl}orgs/${org}/projects/${project}/trainings/${training}/paradigm`
    );
  }

  /** Fetches the model configured for a training. */
  getModel(org: any, project: any, training: any): Observable<TrainingModel> {
    return this.http.get<TrainingModel>(
      `${this.baseUrl}orgs/${org}/projects/${project}/trainings/${training}/model`
    );
  }

  /** PUTs `{ id: paradigmCode, parameters }` as the training's paradigm. */
  updateParadigm(
    org: any,
    project: any,
    training: any,
    paradigmCode: string,
    parameters: GenericParameter[]
  ): Observable<TrainingParadigm> {
    return this.http.put<TrainingParadigm>(
      `${this.baseUrl}orgs/${org}/projects/${project}/trainings/${training}/paradigm`,
      { id: paradigmCode, parameters }
    );
  }

  /** PUTs `{ id: modelCode, hyperparameters }` as the training's model. */
  updateModel(
    org: any,
    project: any,
    training: any,
    modelCode: string,
    hyperparameters: GenericParameter[]
  ): Observable<TrainingModel> {
    return this.http.put<TrainingModel>(
      `${this.baseUrl}orgs/${org}/projects/${project}/trainings/${training}/model`,
      { id: modelCode, hyperparameters }
    );
  }

  /**
   * Updates model and paradigm in parallel and emits `[model, paradigm]` once both respond.
   * Reads `idModel`, `hyperparameters`, `idParadigm` and `parameters` from `data`.
   */
  updateMechanics(org: any, project: any, training: any, data: any) {
    return combineLatest([
      this.updateModel(org, project, training, data.idModel, data.hyperparameters),
      this.updateParadigm(org, project, training, data.idParadigm, data.parameters),
    ]);
  }

  /**
   * Saves all node mappings in parallel and emits the array of responses.
   * An empty `nodes` list emits `[]`. A bare `combineLatest([])` would complete without
   * emitting, so subscribers would never receive a value.
   */
  mappingNodes(
    org: any,
    project: any,
    session: any,
    nodes: MappingNode[]
  ) {
    const mappingNodeCalls = nodes.map((node) =>
      this.mappingNode(org, project, session, node)
    );
    if (mappingNodeCalls.length === 0) {
      const noResponses: Object[] = [];
      return of(noResponses);
    }
    return combineLatest(mappingNodeCalls);
  }

  /**
   * Creates (POST) or, when `node.isMapped`, updates (PUT) one node mapping.
   * Optional fields are included in the payload only when truthy on `node`. The file path is
   * built from org/project/session.
   */
  mappingNode(
    org: any,
    project: any,
    session: any,
    node: MappingNode
  ) {
    const url = `${this.baseUrl}orgs/${org}/projects/${project}/sessions/${session}/nodes/${node.nodeSuid}/mappings`;

    let data: any = {
      nodeSuid: node.nodeSuid,
      role: node.role,
    };
    if (node.differentialPrivacy) {
      data = { ...data, differentialPrivacy: node.differentialPrivacy };
    }
    if (node.file) {
      data = {
        ...data,
        file: {
          uuid: node.file?.uuid,
          path: `orgs/${org}/projects/${project}/sessions/${session}/files/`,
        },
      };
    }
    if (node.softDeletePolicy) {
      data = { ...data, softDeletePolicy: node.softDeletePolicy };
    }
    if (node.erParams) {
      data = { ...data, erParams: node.erParams };
    }

    return node.isMapped ? this.http.put(url, data) : this.http.post(url, data);
  }

  /**
   * Legacy state resolution that mutates `trSession` in place.
   *
   * Statuses are applied in array order, so a later status overwrites the state set by an
   * earlier one. When no status sets a state, it falls back to `mapped` (mappings present
   * and first stop condition > 1) or `incomplete`. A truthy `lastRound` is incremented.
   *
   * NOTE(review): this mapping disagrees with `trainingSessionStateCodesPriority`. Here,
   * 3023/3024/3026+ mean finished (not aborted) and 3006 means stopped (not running). Unlike
   * getTrainingSessionStatus, a lastRound of 0 is not incremented. Confirm which is intended.
   */
  __setState(trSession: TrainingSession): TrainingSession {
    if (trSession) {
      if (trSession.entityStatus) {
        trSession.entityStatus.forEach((eStatus) => {
          if (eStatus.id === 3025) {
            trSession.state = TrainingSessionState.aborted;
          } else if (eStatus.id === 3021 || eStatus.id > 3022) {
            trSession.state = TrainingSessionState.finished;
            trSession.finished = eStatus.timestamp;
          } else if (eStatus.id === 3004) {
            trSession.state = TrainingSessionState.running;
            trSession.started = eStatus.timestamp;
          } else if (eStatus.id === 3006 || eStatus.id === 3022) {
            trSession.state = TrainingSessionState.stopped;
          } else if (eStatus.id === 3010) {
            trSession.state = TrainingSessionState.deployed;
            trSession.deployed = eStatus.timestamp;
          }
        });
      }
      if (!trSession.state) {
        if (
          trSession.mappings &&
          trSession.mappings.length > 0 &&
          trSession.stopConditions &&
          trSession.stopConditions.length > 0 &&
          trSession.stopConditions[0].value > 1
        ) {
          // PROVISIONAL
          trSession.state = TrainingSessionState.mapped;
        } else {
          trSession.state = TrainingSessionState.incomplete;
        }
      }
      if (trSession?.summaryTrainingSessionExecution?.lastRound) {
        trSession.summaryTrainingSessionExecution.lastRound++;
      }
    }
    return trSession;
  }

  /**
   * Copies `maxRounds` / `maxEpochs` stop-condition values onto the session (mutates it).
   * If a condition appears more than once, the last one wins.
   */
  __setMaxIterations(trSession: TrainingSession): TrainingSession {
    if (trSession?.stopConditions) {
      trSession.stopConditions.forEach((stopCondition) => {
        if (stopCondition.condition === 'maxRounds') {
          trSession.maxRounds = stopCondition.value;
        }
        if (stopCondition.condition === 'maxEpochs') {
          trSession.maxEpochs = stopCondition.value;
        }
      });
    }
    return trSession;
  }

  /**
   * Creates (POST) or, when `metricsConfig.isAdded`, updates (PUT) the session metrics config.
   * Only `evaluationInterval` is sent. `training` is unused and kept for signature compatibility.
   */
  setMetricsConfig(
    org: any,
    project: any,
    training: any,
    session: any,
    metricsConfig: MetricsConfig
  ) {
    const url = `${this.baseUrl}orgs/${org}/projects/${project}/sessions/${session}/metricsConfig`;
    const value = {
      evaluationInterval: metricsConfig.evaluationInterval.valueOf(),
    };
    return metricsConfig.isAdded ? this.http.put(url, value) : this.http.post(url, value);
  }

  /**
   * Saves the stop condition and, when provided, the metrics config, in parallel.
   * Emits the array of responses once all calls have answered.
   */
  setTrainingSessionStepFiveConfiguration(
    org: any,
    project: any,
    training: any,
    session: any,
    data: { stopCondition: StopCondition, metricsConfig?: MetricsConfig }) {
    const httpCalls = [this.setStopCondition(org, project, session, data.stopCondition)];
    if (data.metricsConfig) {
      httpCalls.push(this.setMetricsConfig(org, project, training, session, data.metricsConfig));
    }
    return combineLatest(httpCalls);
  }

  /** POSTs all stop conditions at once, sending only their persisted fields. */
  setStopConditions(
    org: any,
    project: any,
    session: any,
    data: StopCondition[]
  ) {
    const values = data.map((stopCondition) => ({
      operator: stopCondition.operator,
      order: stopCondition.order,
      eval: stopCondition.eval,
      condition: stopCondition.condition,
      value: stopCondition.value,
    }));

    return this.http.post(
      `${this.baseUrl}orgs/${org}/projects/${project}/sessions/${session}/stopConditions`,
      values
    );
  }

  /** Creates (POST) or, when `data.isAdded`, updates (PUT) a single stop condition. */
  setStopCondition(
    org: any,
    project: any,
    session: any,
    data: StopCondition
  ) {
    const url = `${this.baseUrl}orgs/${org}/projects/${project}/sessions/${session}/stopConditions`;
    const value = {
      operator: data.operator,
      order: data.order,
      eval: data.eval,
      condition: data.condition,
      value: data.value,
    };
    return data.isAdded ? this.http.put(url, value) : this.http.post(url, value);
  }

  /** Fetches the stop conditions of a session. `training` is unused. */
  getStopConditions(org: any, project: any, training: any, session: any) {
    return this.http.get<StopCondition[]>(
      `${this.baseUrl}orgs/${org}/projects/${project}/sessions/${session}/stopConditions`
    );
  }

  /** Fetches the metrics config of a session. `training` is unused. */
  getMetricsConfig(org: any, project: any, training: any, session: any) {
    return this.http.get<MetricsConfig>(
      `${this.baseUrl}orgs/${org}/projects/${project}/sessions/${session}/metricsConfig`
    );
  }

  /** Currently only unlinks the session's nodes (see unlinkNodesFromTrainingSession). */
  deleteTrainingSessionProcess(
    org: any,
    project: any,
    training: any,
    session: any
  ) {
    return this.unlinkNodesFromTrainingSession(org, project, training, session);
  }

  /** DELETEs `.../sessions/{session}/nodes`. `training` is unused. */
  unlinkNodesFromTrainingSession(
    org: any,
    project: any,
    training: any,
    session: any
  ) {
    return this.http.delete(
      `${this.baseUrl}orgs/${org}/projects/${project}/sessions/${session}/nodes`
    );
  }

  /** DELETEs the session resource. `training` is unused. */
  deleteTrainingSession(org: any, project: any, training: any, session: any) {
    return this.http.delete(
      `${this.baseUrl}orgs/${org}/projects/${project}/sessions/${session}`
    );
  }

  /** Lists the nodes linked to a session. `training` is unused. */
  getLinkedNodes(
    org: any,
    project: any,
    training: any,
    session: any
  ): Observable<Node[]> {
    return this.http
      .get<NodeList>(`${this.baseUrl}orgs/${org}/projects/${project}/sessions/${session}/nodes`)
      .pipe(map((res) => res.items));
  }

  /** Links node `nodeId` to the given project/session. */
  linkNode(
    org: any,
    project: any,
    session: any,
    nodeId: string
  ) {
    const params = {
      projectSuid: project,
      trainingSessionTimestamp: session,
    };
    return this.http.post(`${this.baseUrl}orgs/${org}/nodes/${nodeId}/link`, params);
  }

  /** Removes the link of node `nodeId`. */
  unlinkNode(org: any, nodeId: string) {
    return this.http.delete(`${this.baseUrl}orgs/${org}/nodes/${nodeId}/link`);
  }

  /** Links a node, then emits the session's refreshed list of linked nodes. */
  linkNodeProcess(
    org: any,
    project: any,
    training: any,
    session: any,
    nodeId: string
  ) {
    return this.linkNode(org, project, session, nodeId).pipe(
      take(1),
      switchMap(() => this.getLinkedNodes(org, project, training, session))
    );
  }

  /** Unlinks a node, then emits the session's refreshed list of linked nodes. */
  unlinkNodeProcess(
    org: any,
    project: any,
    training: any,
    session: any,
    nodeId: string
  ) {
    return this.unlinkNode(org, nodeId).pipe(
      take(1),
      switchMap(() => this.getLinkedNodes(org, project, training, session))
    );
  }

  /** POSTs `.../infrastructure`. Errors are re-thrown as the HTTP error body. */
  deployInfra(
    org: string | undefined,
    projectId: string | undefined,
    trainingId: string | undefined,
    session: string | undefined
  ): Observable<any | never> {
    return this.http
      .post(
        `${this.baseUrl}orgs/${org}/projects/${projectId}/sessions/${session}/infrastructure`,
        {},
        { headers }
      )
      .pipe(catchError((res) => throwError(res?.error)));
  }

  /**
   * Fetches the infrastructure status and maps it to a display code.
   * 'PENDING' and 'CANCELED_AND_PENDING' become 'NOT_DEPLOYED' when `mappings` is non-empty,
   * else 'NOT_CONFIGURED'. Other statuses pass through, and an empty response emits `undefined`.
   */
  loadMapInfra(
    org: string | undefined,
    projectId: string | undefined,
    session: string | undefined,
    mappings: MappingNode[]
  ): Observable<string | never> {
    return this.http
      .get(
        `${this.baseUrl}orgs/${org}/projects/${projectId}/sessions/${session}/infrastructure`,
        { headers }
      )
      .pipe(
        map((res: any) => {
          let text;
          if (res) {
            const isPending = res.status === 'PENDING' || res.status === 'CANCELED_AND_PENDING';
            if (!isPending) {
              text = res.status;
            } else if (mappings && mappings.length > 0) {
              text = 'NOT_DEPLOYED';
            } else {
              text = 'NOT_CONFIGURED';
            }
          }
          return text;
        }),
        catchError((res) => throwError(res?.error))
      );
  }

  /** Fetches the raw infrastructure resource. Errors are re-thrown as the HTTP error body. */
  loadInfra(
    org: string | undefined,
    projectId: string | undefined,
    trainingId: string | undefined,
    session: string | undefined
  ): Observable<Infrastructure | never> {
    return this.http
      .get<Infrastructure>(
        `${this.baseUrl}orgs/${org}/projects/${projectId}/sessions/${session}/infrastructure`,
        { headers }
      )
      .pipe(catchError((res) => throwError(res?.error)));
  }

  /** POSTs `.../infrastructure/cancel`. Errors are re-thrown as the HTTP error body. */
  cancelDeployInfra(
    org: string | undefined,
    projectId: string | undefined,
    trainingId: string | undefined,
    session: string | undefined
  ) {
    return this.http
      .post(
        `${this.baseUrl}orgs/${org}/projects/${projectId}/sessions/${session}/infrastructure/cancel`,
        {},
        { headers }
      )
      .pipe(catchError((res) => throwError(res?.error)));
  }

  /** Fetches the private attributes of a round. `training` is unused. */
  getPrivateAttrs(
    org: any,
    project: any,
    training: any,
    session: any,
    round: number
  ) {
    return this.http.get<PrivateAttrs>(
      `${this.baseUrl}orgs/${org}/projects/${project}/sessions/${session}/rounds/${round}/privateAttrs`
    );
  }

  /** POSTs `{ trainingDatasetLength }` as the private attributes of round `rounds`. */
  createPrivateAttrs(
    org: any,
    project: any,
    training: any,
    session: any,
    rounds: number,
    datasetLength: number
  ) {
    return this.http.post(
      `${this.baseUrl}orgs/${org}/projects/${project}/sessions/${session}/rounds/${rounds}/privateAttrs`,
      { trainingDatasetLength: datasetLength }
    );
  }

  /** Requests the `__SERIALIZED_INNER_MODEL__` custom file of a node (round 0). */
  downloadModelParameters(
    org: any,
    project: any,
    session: any,
    node: string
  ) {
    return this.http.get<UrlModelParameters>(
      `${this.baseUrl}orgs/${org}/projects/${project}/sessions/${session}/nodes/${node}/customFile?fileName=__SERIALIZED_INNER_MODEL__&round=0`
    );
  }

  /** Fetches the node history of a session (page 0, size 100, `pagination=false`). */
  getNodeHistoryList(
    org: any,
    project: any,
    session: any
  ): Observable<Node[]> {
    const params = new HttpParams()
      .set('page', '0')
      .set('pageSize', '100')
      .set('pagination', false);

    return this.http
      .get<{ pageSize: number; page: number; items: Node[] }>(
        `${this.baseUrl}orgs/${org}/projects/${project}/sessions/${session}/nodes/history`,
        { headers, params }
      )
      .pipe(map((res) => res.items));
  }

  /**
   * Emits 'Not trained', or `getTrainedText` of the first project session's statuses.
   *
   * NOTE(review): `trainingSession` is ignored. The text comes from the FIRST session of the
   * whole project, so every caller gets the same answer. Confirm whether it should filter by
   * training.
   */
  getTrainedInfo(org: any, project: any, trainingSession: any) {
    return this.getAllTrainingSession(org, project).pipe(
      map((response) => {
        const entityStatus = response?.[0]?.entityStatus;
        return entityStatus && entityStatus.length > 0
          ? getTrainedText(entityStatus)
          : 'Not trained';
      })
    );
  }

  /**
   * Runs getTrainedInfo for each id and emits the texts in the same order.
   * An empty id list emits `[]`, because a bare `combineLatest([])` would never emit.
   */
  getAllTrainedInfo(org: any, project: any, trainingIds: string[]) {
    if (trainingIds.length === 0) {
      const noTexts: string[] = [];
      return of(noTexts);
    }
    return combineLatest(trainingIds.map((id) => this.getTrainedInfo(org, project, id)));
  }
}
