import { useCallback, useState } from 'react';
import {
  Conversation,
  Message,
  OpenAiFormatMessage,
  ServerGptQuery,
  assertTruthy,
} from './types';
import { map } from 'lodash';

/**
 * Async generator type for a streamed chat response (see `getResponseStreamed`).
 *
 * Type arguments, in order:
 * - Yield: `getResponseStreamed` only ever yields `string` chunks of response
 *   text; the `{ done: true }` yield variant is not produced by it.
 * - Return: `getResponseStreamed` returns `{ done: true }` when the stream
 *   ends, and the interrupting hook passes `{ done: true }` to `return()`.
 *   The `{ text, isError }` variant is not produced by the visible code —
 *   NOTE(review): presumably reserved for error reporting; confirm or remove.
 * - Next: `unknown` — values sent in via `next()` are not read.
 */
type ResponseGenerator = AsyncGenerator<
  string | { done: true },
  | {
      text: string;
      isError: boolean;
    }
  | { done: true },
  unknown
>;

/**
 * React hook returning a function that streams a chat reply from the server
 * and interrupts whatever stream a previous call of that function started.
 *
 * Each call of the returned function:
 *  1. asks the previous call's generator (if any) to stop via `return()`;
 *  2. starts a new streamed request with `getResponseStreamed`;
 *  3. invokes `outputCallback` with the accumulated text after every chunk
 *     (`isComplete: false`) and once more when the stream ends
 *     (`isComplete: true`). A superseded call also ends with a final
 *     `isComplete: true` callback, because `return()` ends its loop normally.
 *
 * The returned function is stable across renders. It returns a promise that
 * resolves when this call's stream ends and rejects if the request fails, so
 * callers can `await` / `.catch()` it.
 */
export function useGetInterruptibleResponseStreamed() {
  // A lazily-initialised state object is a stable mutable box (same semantics
  // as a ref): its identity never changes and writing `.current` does not
  // re-render. Tracking the active generator here, rather than in React state
  // captured by a closure, guarantees every call sees the generator started by
  // the most recent call — even when two calls happen before a re-render.
  const [activeGenerator] = useState<{ current: ResponseGenerator | null }>(
    () => ({ current: null })
  );

  return useCallback(
    async (
      systemMessageContent: string,
      messages: Conversation,
      input: string,
      outputCallback: (response: { text: string; isComplete: boolean }) => void,
      temperature: number = 0.7
    ): Promise<void> => {
      // A previous message is still getting answered: stop that generator.
      // If it is suspended mid-`await`, the return request is queued by the
      // runtime and takes effect at its next yield; no need to wait for it.
      const previous = activeGenerator.current;
      if (previous) {
        void previous.return({ done: true });
      }

      const generator = getResponseStreamed(
        systemMessageContent,
        messages,
        input,
        temperature
      );
      activeGenerator.current = generator;

      let output = '';
      try {
        for await (const chunk of generator) {
          // The yield type also admits `{ done: true }`; only text is appended.
          if (typeof chunk !== 'string') {
            continue;
          }
          output += chunk;
          outputCallback({ text: output, isComplete: false });
        }
      } finally {
        // Forget this generator once it is over, unless a newer call replaced it.
        if (activeGenerator.current === generator) {
          activeGenerator.current = null;
        }
      }

      outputCallback({ text: output, isComplete: true });
    },
    [activeGenerator]
  );
}

/**
 * POSTs the query to `/chatStream` and yields the response body as decoded
 * text chunks while it arrives.
 *
 * @param systemMessageContent - Content of the leading system message.
 * @param messages - Prior conversation, converted by `getServerQuery`.
 * @param userInput - The new user message.
 * @param temperature - Sampling temperature forwarded in the query.
 * @returns An async generator of non-empty text chunks that completes with
 *   `{ done: true }` when the body is exhausted.
 * @throws Error if the response has no body; otherwise propagates whatever
 *   `fetch` or the body reader rejects with.
 */
async function* getResponseStreamed(
  systemMessageContent: string,
  messages: Conversation,
  userInput: string,
  temperature: number
): ResponseGenerator {
  const response = await fetch(`/chatStream`, {
    method: 'POST',
    headers: {
      'Content-Type': 'application/json',
    },
    body: JSON.stringify({
      query: getServerQuery(
        systemMessageContent,
        messages,
        userInput,
        temperature
      ),
    }),
  });

  // NOTE(review): `response.ok` is not checked, so an error response's body
  // is streamed to the caller as if it were model output.
  if (!response.body) {
    throw new Error('/chatStream response has no body to stream');
  }

  const reader = response.body.getReader();
  const decoder = new TextDecoder();
  let finished = false;
  try {
    while (true) {
      const { done, value } = await reader.read();
      if (done) {
        finished = true;
        // Flush bytes of a multi-byte character still buffered by the decoder.
        const tail = decoder.decode();
        if (tail) {
          yield tail;
        }
        return { done: true };
      }

      // `stream: true` keeps an incomplete UTF-8 sequence at the end of this
      // chunk buffered for the next one instead of emitting U+FFFD.
      const text = decoder.decode(value, { stream: true });
      if (text) {
        yield text;
      }
    }
  } finally {
    if (!finished) {
      // The consumer stopped early (e.g. `return()` from an interrupting
      // request) or reading failed: cancel so the HTTP download stops.
      void reader.cancel().catch(() => undefined);
    }
  }
}

/**
 * Builds the query sent to `/chatStream`: a system message, the prior
 * conversation mapped to OpenAI roles, then the new user message.
 *
 * @param systemMessageContent - Content of the leading `system` message.
 * @param messages - Prior conversation turns.
 * @param userInput - The new user message, appended last.
 * @param temperature - Sampling temperature.
 * @param model - Model name placed in the query (default `'gpt-4'`).
 * @param maxTokens - Token cap placed in the query (default `2000`).
 * @returns The query object. `openAiApiKey` is read from the
 *   `localStorage.accessKeyId` property (undefined when it is not set).
 */
function getServerQuery(
  systemMessageContent: string,
  messages: Conversation,
  userInput: string,
  temperature: number,
  model: ServerGptQuery['model'] = 'gpt-4',
  maxTokens: number = 2000
): ServerGptQuery {
  const systemMessage: OpenAiFormatMessage = {
    role: 'system',
    content: systemMessageContent,
  };
  const userMessage: OpenAiFormatMessage = {
    role: 'user',
    content: userInput,
  };

  // Key order matches the previous payload so the serialized JSON is unchanged.
  return {
    openAiApiKey: localStorage.accessKeyId,
    messages: [
      systemMessage,
      ...meldMessagesToOpenAiFormat(messages),
      userMessage,
    ],
    temperature,
    maxTokens,
    model,
  };
}

/**
 * Converts chat history into OpenAI chat messages: `user` stays `user`,
 * `bot` becomes `assistant`, and the message text becomes `content`.
 * Any other sender fails the `assertTruthy` check.
 */
function meldMessagesToOpenAiFormat(
  messages: Array<Message>
): Array<OpenAiFormatMessage> {
  const toOpenAiMessage = (message: Message): OpenAiFormatMessage => {
    const { sender } = message;
    assertTruthy(sender === 'user' || sender === 'bot');

    const role = sender === 'user' ? 'user' : 'assistant';
    return { role, content: message.text };
  };

  return map(messages, toOpenAiMessage);
}
