module Whisper
  # Duck type for a container of audio samples: any object that reports its
  # size through +length+ and yields each sample as a Float from +each+.
  # Signatures in this file accept it as an alternative to Array[Float].
  interface _Samples
    def length: () -> Integer
    def each: { (Float) -> void } -> void
  end

  # Shapes of the callables accepted by Whisper.log_set and by the callback
  # setters of Whisper::Params. +user_data+ is always the last argument.
  # The second positional argument of the context-taking callbacks is typed
  # +void+, i.e. its value is not meant to be used.
  type log_callback = ^(Integer level, String message, Object user_data) -> void
  type new_segment_callback = ^(Whisper::Context, void, Integer n_new, Object user_data) -> void
  type progress_callback = ^(Whisper::Context, void, Integer progress, Object user_data) -> void
  type encoder_begin_callback = ^(Whisper::Context, void, Object user_data) -> void

  # Unlike the callbacks above, the abort callback receives the user data only,
  # as shown by the example on Whisper::Params#abort_callback= and by the block
  # signature of Whisper::Params#abort_on. A truthy result requests an abort.
  type abort_callback = ^(Object user_data) -> boolish

  # Version identifier, as a String.
  VERSION: String

  # Integer log level constants.
  # NOTE(review): presumably the values received as +level+ by a log_callback - confirm.
  LOG_LEVEL_NONE: Integer
  LOG_LEVEL_INFO: Integer
  LOG_LEVEL_WARN: Integer
  LOG_LEVEL_ERROR: Integer
  LOG_LEVEL_DEBUG: Integer
  LOG_LEVEL_CONT: Integer

  # Integer constants named after alignment-heads presets.
  # NOTE(review): presumably the values accepted by
  # Whisper::Context::Params#dtw_aheads_preset= - confirm.
  AHEADS_NONE: Integer
  AHEADS_N_TOP_MOST: Integer
  AHEADS_CUSTOM: Integer
  AHEADS_TINY_EN: Integer
  AHEADS_TINY: Integer
  AHEADS_BASE_EN: Integer
  AHEADS_BASE: Integer
  AHEADS_SMALL_EN: Integer
  AHEADS_SMALL: Integer
  AHEADS_MEDIUM_EN: Integer
  AHEADS_MEDIUM: Integer
  AHEADS_LARGE_V1: Integer
  AHEADS_LARGE_V2: Integer
  AHEADS_LARGE_V3: Integer
  AHEADS_LARGE_V3_TURBO: Integer

  # Largest language ID (judging by the name; confirm against the native library).
  def self.lang_max_id: () -> Integer

  # Language ID for the language called +name+.
  def self.lang_id: (string name) -> Integer

  # String form of the language ID +id+, e.g. one returned by Whisper::Context#full_lang_id.
  def self.lang_str: (Integer id) -> String

  # String form of the language ID +id+.
  # NOTE(review): presumably the full language name, as opposed to lang_str - confirm.
  def self.lang_str_full: (Integer id) -> String

  # Sets the log callback and the user data handed to it as its last argument.
  # Both arguments may be nil.
  def self.log_set: (log_callback?, Object? user_data) -> log_callback

  def self.system_info_str: () -> String

  # A model plus the results of the latest run, which are read back through
  # #each_segment and the full_* readers.
  #
  # Inside this class a bare +Params+ resolves to the nested
  # Whisper::Context::Params, so the transcription parameters are always
  # spelled Whisper::Params below.
  class Context
    # Builds a context from a model given as a String, a path-like object or an HTTP URI.
    def self.new: (String | path | ::URI::HTTP) -> instance

    # Transcribes a single file.
    # When a block is given, the resulting text is yielded to it.
    #
    #     params = Whisper::Params.new
    #     params.duration = 60_000
    #     whisper.transcribe "path/to/audio.wav", params do |text|
    #       puts text
    #     end
    #
    # If n_processors is greater than 1, you cannot set any callbacks including
    # new_segment_callback, progress_callback, encoder_begin_callback, abort_callback,
    # and log_callback set by Whisper.log_set
    def transcribe: (path, Whisper::Params, ?n_processors: Integer) -> self
                  | (path, Whisper::Params, ?n_processors: Integer) { (String) -> void } -> self

    # Readers for properties of the loaded model.
    def model_n_vocab: () -> Integer
    def model_n_audio_ctx: () -> Integer
    def model_n_audio_state: () -> Integer
    def model_n_text_head: () -> Integer
    def model_n_text_layer: () -> Integer
    def model_n_mels: () -> Integer
    def model_ftype: () -> Integer
    def model_type: () -> String

    # Yields each Whisper::Segment:
    #
    #     whisper.transcribe("path/to/audio.wav", params)
    #     whisper.each_segment do |segment|
    #       puts segment.text
    #     end
    #
    # Returns an Enumerator if no block given:
    #
    #     whisper.transcribe("path/to/audio.wav", params)
    #     enum = whisper.each_segment
    #     enum.to_a # => [#<Whisper::Segment>, ...]
    #
    def each_segment: { (Segment) -> void } -> void
                    | () -> Enumerator[Segment]

    def model: () -> Model

    # The +nth+ segment of the result.
    def full_get_segment: (Integer nth) -> Segment

    # Number of segments in the result.
    def full_n_segments: () -> Integer

    # Language ID, which can be converted to string by Whisper.lang_str and Whisper.lang_str_full.
    #
    def full_lang_id: () -> Integer

    # Start time of a segment indexed by +segment_index+ in centiseconds (units of 10 milliseconds).
    #
    #     full_get_segment_t0(3) # => 1668 (16680 ms)
    #
    def full_get_segment_t0: (Integer segment_index) -> Integer

    # End time of a segment indexed by +segment_index+ in centiseconds (units of 10 milliseconds).
    #
    #     full_get_segment_t1(3) # => 1668 (16680 ms)
    #
    def full_get_segment_t1: (Integer segment_index) -> Integer

    # Whether the next segment indexed by +segment_index+ is predicted as a speaker turn.
    #
    #     full_get_segment_speaker_turn_next(3) # => true
    #
    def full_get_segment_speaker_turn_next: (Integer segment_index) -> bool

    # Text of a segment indexed by +segment_index+.
    #
    #     full_get_segment_text(3) # => "ask not what your country can do for you, ..."
    #
    def full_get_segment_text: (Integer segment_index) -> String

    # No-speech probability of a segment indexed by +segment_index+.
    #
    def full_get_segment_no_speech_prob: (Integer segment_index) -> Float

    # Run the entire model: PCM -> log mel spectrogram -> encoder -> decoder -> text
    # Not thread safe for same context
    # Uses the specified decoding strategy to obtain the text.
    #
    # The second argument +samples+ must be an array of samples, respond to :length, or be a MemoryView of an array of float. It must be 32 bit float PCM audio data.
    #
    def full: (Whisper::Params, Array[Float] samples, ?Integer n_samples) -> self
            | (Whisper::Params, _Samples, ?Integer n_samples) -> self

    # Split the input audio in chunks and process each chunk separately using whisper_full_with_state()
    # Result is stored in the default state of the context
    # Not thread safe if executed in parallel on the same context.
    # It seems this approach can offer some speedup in some cases.
    # However, the transcription accuracy can be worse at the beginning and end of each chunk.
    #
    # If n_processors is greater than 1, you cannot set any callbacks including
    # new_segment_callback, progress_callback, encoder_begin_callback, abort_callback,
    # and log_callback set by Whisper.log_set
    def full_parallel: (Whisper::Params, Array[Float], ?Integer n_samples) -> self
                     | (Whisper::Params, _Samples, ?Integer n_samples) -> self
                     | (Whisper::Params, _Samples, ?Integer? n_samples, Integer n_processors) -> self

    # The result as a String.
    # NOTE(review): presumably formatted as SubRip (SRT), going by the name - confirm.
    def to_srt: () -> String

    # The result as a String.
    # NOTE(review): presumably formatted as WebVTT, going by the name - confirm.
    def to_webvtt: () -> String

    # Context-level options. Every keyword of +new+ is optional, in line with
    # Whisper::Params.new and Whisper::VAD::Params.new; each option can also be
    # set afterwards through its writer.
    #
    # NOTE(review): Context.new as declared above takes no Params argument -
    # confirm how an instance of this class is handed to a Context.
    class Params
      def self.new: (
        ?use_gpu: boolish,
        ?flash_attn: boolish,
        ?gpu_device: Integer,
        ?dtw_token_timestamps: boolish,
        ?dtw_aheads_preset: Integer,
        ?dtw_n_top: Integer | nil
      ) -> instance

      def use_gpu=: (boolish) -> boolish
      def use_gpu: () -> bool
      def flash_attn=: (boolish) -> boolish
      def flash_attn: () -> bool
      def gpu_device=: (Integer) -> Integer
      def gpu_device: () -> Integer
      def dtw_token_timestamps=: (boolish) -> boolish
      def dtw_token_timestamps: () -> bool
      def dtw_aheads_preset=: (Integer) -> Integer
      def dtw_aheads_preset: () -> Integer
      def dtw_n_top=: (Integer | nil) -> (Integer | nil)
      def dtw_n_top: () -> (Integer | nil)
    end
  end

  # Parameters of a transcription run. Every option can be passed to +new+ as
  # an optional keyword or assigned later through its writer.
  class Params
    def self.new: (
      ?language: string,
      ?translate: boolish,
      ?no_context: boolish,
      ?single_segment: boolish,
      ?print_special: boolish,
      ?print_progress: boolish,
      ?print_realtime: boolish,
      ?print_timestamps: boolish,
      ?suppress_blank: boolish,
      ?suppress_nst: boolish,
      ?token_timestamps: boolish,
      ?max_len: Integer,
      ?split_on_word: boolish,
      ?initial_prompt: string?,
      ?carry_initial_prompt: boolish,
      ?diarize: boolish,
      ?offset: Integer,
      ?duration: Integer,
      ?max_text_tokens: Integer,
      ?temperature: Float,
      ?max_initial_ts: Float,
      ?length_penalty: Float,
      ?temperature_inc: Float,
      ?entropy_thold: Float,
      ?logprob_thold: Float,
      ?no_speech_thold: Float,
      ?new_segment_callback: new_segment_callback,
      ?new_segment_callback_user_data: Object,
      ?progress_callback: progress_callback,
      ?progress_callback_user_data: Object,
      ?encoder_begin_callback: encoder_begin_callback,
      ?encoder_begin_callback_user_data: Object,
      ?abort_callback: abort_callback,
      ?abort_callback_user_data: Object,
      ?vad: boolish,
      ?vad_model_path: path | URI,
      ?vad_params: Whisper::VAD::Params
    ) -> instance

    # Language of the audio:
    #
    #     params.language = "auto" # or "en", etc...
    #
    def language: () -> String
    def language=: (String) -> String # TODO: Enumerate lang names

    def translate: () -> bool
    def translate=: (boolish) -> boolish

    # If true, does not use past transcription (if any) as initial prompt for the decoder.
    #
    def no_context: () -> bool
    def no_context=: (boolish) -> boolish

    # If true, forces single segment output (useful for streaming).
    #
    def single_segment: () -> bool
    def single_segment=: (boolish) -> boolish

    # If true, prints special tokens (e.g. <SOT>, <EOT>, <BEG>, etc.).
    #
    def print_special: () -> bool
    def print_special=: (boolish) -> boolish

    # If true, prints progress information.
    #
    def print_progress: () -> bool
    def print_progress=: (boolish) -> boolish

    # If true, prints results from within whisper.cpp. (avoid it, use callback instead)
    #
    def print_realtime: () -> bool
    def print_realtime=: (boolish) -> boolish

    # If true, prints timestamps for each text segment when printing realtime.
    #
    def print_timestamps: () -> bool
    def print_timestamps=: (boolish) -> boolish

    # If true, suppresses blank outputs.
    #
    def suppress_blank: () -> bool
    def suppress_blank=: (boolish) -> boolish

    # If true, suppresses non-speech-tokens.
    #
    def suppress_nst: () -> bool
    def suppress_nst=: (boolish) -> boolish

    # If true, enables token-level timestamps.
    #
    def token_timestamps: () -> bool
    def token_timestamps=: (boolish) -> boolish

    # Max segment length in characters.
    #
    def max_len: () -> Integer
    def max_len=: (Integer) -> Integer

    # If true, split on word rather than on token (when used with max_len).
    #
    def split_on_word: () -> bool
    def split_on_word=: (boolish) -> boolish

    # Tokens to provide to the whisper decoder as initial prompt
    # these are prepended to any existing text context from a previous call
    # use whisper_tokenize() to convert text to tokens.
    # Maximum of whisper_n_text_ctx()/2 tokens are used (typically 224).
    #
    def initial_prompt: () -> String?
    def initial_prompt=: (_ToS) -> _ToS

    def carry_initial_prompt: () -> bool
    def carry_initial_prompt=: (boolish) -> boolish

    # If true, enables diarization.
    #
    def diarize: () -> bool
    def diarize=: (boolish) -> boolish

    # Start offset in ms.
    #
    def offset: () -> Integer
    def offset=: (Integer) -> Integer

    # Audio duration to process in ms.
    #
    def duration: () -> Integer
    def duration=: (Integer) -> Integer

    # Max tokens to use from past text as prompt for the decoder.
    #
    def max_text_tokens: () -> Integer
    def max_text_tokens=: (Integer) -> Integer

    def temperature: () -> Float
    def temperature=: (Float) -> Float

    # See https://github.com/openai/whisper/blob/f82bc59f5ea234d4b97fb2860842ed38519f7e65/whisper/decoding.py#L97
    #
    def max_initial_ts: () -> Float
    def max_initial_ts=: (Float) -> Float

    def length_penalty: () -> Float
    def length_penalty=: (Float) -> Float

    def temperature_inc: () -> Float
    def temperature_inc=: (Float) -> Float

    # Similar to OpenAI's "compression_ratio_threshold"
    #
    def entropy_thold: () -> Float
    def entropy_thold=: (Float) -> Float

    def logprob_thold: () -> Float
    def logprob_thold=: (Float) -> Float

    def no_speech_thold: () -> Float
    def no_speech_thold=: (Float) -> Float

    # Sets new segment callback, called for every newly generated text segment.
    #
    #     params.new_segment_callback = ->(context, _, n_new, user_data) {
    #       # ...
    #     }
    #
    def new_segment_callback=: (new_segment_callback) -> new_segment_callback
    def new_segment_callback: () -> new_segment_callback?

    # Sets user data passed to the last argument of new segment callback.
    #
    def new_segment_callback_user_data=: (Object) -> Object
    def new_segment_callback_user_data: () -> Object

    # Sets progress callback, called on each progress update.
    #
    #     params.progress_callback = ->(context, _, progress, user_data) {
    #       # ...
    #     }
    #
    # +progress+ is an Integer between 0 and 100.
    #
    def progress_callback=: (progress_callback) -> progress_callback
    def progress_callback: () -> progress_callback?

    # Sets user data passed to the last argument of progress callback.
    #
    def progress_callback_user_data=: (Object) -> Object
    def progress_callback_user_data: () -> Object

    # Sets encoder begin callback, called when the encoder starts.
    #
    def encoder_begin_callback=: (encoder_begin_callback) -> encoder_begin_callback
    def encoder_begin_callback: () -> encoder_begin_callback?

    # Sets user data passed to the last argument of encoder begin callback.
    #
    def encoder_begin_callback_user_data=: (Object) -> Object
    def encoder_begin_callback_user_data: () -> Object

    # Sets abort callback, called to check if the process should be aborted.
    #
    #     params.abort_callback = ->(user_data) {
    #       # ...
    #     }
    #
    def abort_callback=: (abort_callback) -> abort_callback
    def abort_callback: () -> abort_callback?

    # Sets user data passed to the last argument of abort callback.
    #
    def abort_callback_user_data=: (Object) -> Object
    def abort_callback_user_data: () -> Object

    # Whether VAD is enabled.
    #
    def vad: () -> bool
    def vad=: (boolish) -> boolish

    # Path to the VAD model.
    #
    def vad_model_path: () -> String?
    def vad_model_path=: (path | URI | nil) -> (path | URI | nil)

    # Parameters used for VAD.
    #
    def vad_params: () -> Whisper::VAD::Params
    def vad_params=: (Whisper::VAD::Params) -> Whisper::VAD::Params

    # Hook called on new segment. Yields each Whisper::Segment.
    #
    #     params.on_new_segment do |segment|
    #       # ...
    #     end
    #
    def on_new_segment: { (Segment) -> void } -> void

    # Hook called on progress update. Yields each progress Integer between 0 and 100.
    #
    def on_progress: { (Integer progress) -> void } -> void

    # Hook called when the encoder starts.
    #
    def on_encoder_begin: { () -> void } -> void

    # Calls the block to decide whether to abort. Return +true+ from the block to abort.
    #
    #     params.abort_on do
    #       if some_condition
    #         true # abort
    #       else
    #         false # continue
    #       end
    #     end
    #
    def abort_on: { (Object user_data) -> boolish } -> void
  end

  # Description of a model, exposing its properties as readers.
  class Model
    # Model::URI instances indexed by String.
    # NOTE(review): the keys are presumably model names - confirm.
    def self.pre_converted_models: () -> Hash[String, Model::URI]

    # Model::ZipURI instances indexed by Model::URI.
    def self.coreml_compiled_models: () -> Hash[Model::URI, Model::ZipURI]

    def self.new: () -> instance

    def n_vocab: () -> Integer

    def n_audio_ctx: () -> Integer
    def n_audio_state: () -> Integer
    def n_audio_head: () -> Integer
    def n_audio_layer: () -> Integer

    def n_text_ctx: () -> Integer
    def n_text_state: () -> Integer
    def n_text_head: () -> Integer
    def n_text_layer: () -> Integer

    def n_mels: () -> Integer
    def ftype: () -> Integer
    def type: () -> String

    # Location of a model, built from a string or an HTTP URI.
    class URI
      def self.new: (string | ::URI::HTTP) -> instance

      # Path of the model as a String.
      def to_path: () -> String

      def clear_cache: () -> void
    end

    # A Model::URI that additionally exposes its cache location as a Pathname.
    class ZipURI < URI
      def cache: () -> Pathname
      def clear_cache: () -> void
    end
  end

  # One segment of a transcription result.
  class Segment
    # Hash returned by #deconstruct_keys; every value may be nil.
    type deconstructed_keys = {
      start_time: Integer?,
      end_time: Integer?,
      text: String?,
      no_speech_prob: Float?,
      speaker_turn_next: bool?,
      n_tokens: Integer?
    }

    # Start time in milliseconds.
    #
    def start_time: () -> Integer

    # End time in milliseconds.
    #
    def end_time: () -> Integer

    # Whether the next segment is predicted as a speaker turn.
    #
    def speaker_turn_next?: () -> bool

    # Text of the segment.
    #
    def text: () -> String

    def no_speech_prob: () -> Float

    # Number of tokens in the segment.
    #
    def n_tokens: () -> Integer

    # Yields each Whisper::Token:
    #
    #   whisper.each_segment.first.each_token do |token|
    #     p token
    #   end
    #
    # Returns an Enumerator if no block is given:
    #
    #   whisper.each_segment.first.each_token.to_a # => [#<Whisper::Token>, ...]
    #
    def each_token: { (Token) -> void } -> void
                  | () -> Enumerator[Token]

    def to_srt_cue: () -> String
    def to_webvtt_cue: () -> String

    # Support for pattern matching.
    # Possible keys: :start_time, :end_time, :text, :no_speech_prob, :speaker_turn_next, :n_tokens
    #
    #     whisper.each_segment do |segment|
    #       segment => {start_time:, end_time:, text:, no_speech_prob:, speaker_turn_next:}
    #
    #       puts "[#{start_time} --> #{end_time}] #{text} (no speech prob: #{no_speech_prob}#{speaker_turn_next ? ', speaker turns next' : ''})"
    #     end
    def deconstruct_keys: (Array[:start_time | :end_time | :text | :no_speech_prob | :speaker_turn_next | :n_tokens]?) -> deconstructed_keys
  end

  # One token of a segment, as yielded by Whisper::Segment#each_token.
  #
  # Declared as a class: values of this type are objects handed out by the
  # library (shown as #<Whisper::Token> in the Segment#each_token example),
  # each carrying its own readers, just like Whisper::Segment.
  class Token
    # Hash returned by #deconstruct_keys; every value may be nil.
    type deconstructed_keys = {
      id: Integer?,
      tid: Integer?,
      probability: Float?,
      log_probability: Float?,
      pt: Float?,
      ptsum: Float?,
      t_dtw: Integer?,
      voice_length: Float?,
      text: String?,
      start_time: Integer?,
      end_time: Integer?
    }

    # Token ID.
    #
    def id: () -> Integer

    # Forced timestamp token ID.
    #
    def tid: () -> Integer

    # Probability of the token.
    #
    def probability: () -> Float

    # Log probability of the token.
    #
    def log_probability: () -> Float

    # Probability of the timestamp token.
    #
    def pt: () -> Float

    # Sum of probability of all timestamp tokens.
    #
    def ptsum: () -> Float

    # [EXPERIMENTAL] Token-level timestamps with DTW
    #
    # Do not use if you haven't computed token-level timestamps with dtw.
    # Roughly corresponds to the moment in audio in which the token was output.
    #
    def t_dtw: () -> Integer

    # Voice length of the token.
    #
    def voice_length: () -> Float

    # Start time of the token.
    #
    # Token-level timestamp data.
    # Do not use if you haven't computed token-level timestamps.
    #
    def start_time: () -> Integer

    # End time of the token.
    #
    # Token-level timestamp data.
    # Do not use if you haven't computed token-level timestamps.
    #
    def end_time: () -> Integer

    # Text of the token.
    #
    def text: () -> String

    # Support for pattern matching; the argument lists the requested keys, or is nil.
    #
    def deconstruct_keys: (Array[:id | :tid | :probability | :log_probability | :pt | :ptsum | :t_dtw | :voice_length | :start_time | :end_time | :text]?) -> deconstructed_keys
  end

  # Voice activity detection (VAD).
  module VAD
    # VAD options. Every keyword of +new+ is optional and has a matching reader/writer pair.
    class Params
      def self.new: (
        ?threshold: Float,
        ?min_speech_duration_ms: Integer,
        ?min_silence_duration_ms: Integer,
        ?max_speech_duration_s: Float,
        ?speech_pad_ms: Integer,
        ?samples_overlap: Float
      ) -> instance

      # Probability threshold to consider as speech.
      #
      def threshold: () -> Float
      def threshold=: (Float) -> Float

      # Min duration for a valid speech segment.
      #
      def min_speech_duration_ms: () -> Integer
      def min_speech_duration_ms=: (Integer) -> Integer

      # Min silence duration to consider speech as ended.
      #
      def min_silence_duration_ms: () -> Integer
      def min_silence_duration_ms=: (Integer) -> Integer

      # Max duration of a speech segment before forcing a new segment.
      #
      def max_speech_duration_s: () -> Float
      def max_speech_duration_s=: (Float) -> Float

      # Padding added before and after speech segments.
      #
      def speech_pad_ms: () -> Integer
      def speech_pad_ms=: (Integer) -> Integer

      # Overlap in seconds when copying audio samples from speech segment.
      #
      def samples_overlap: () -> Float
      def samples_overlap=: (Float) -> Float

      # Comparison with another VAD::Params.
      #
      def ==: (Params) -> bool
    end

    # VAD context built from a model name, a path-like object or an HTTP URI.
    class Context
      def self.new: (String | path | ::URI::HTTP model_name_or_path) -> instance

      # Produces Segments from samples held in memory, given either as
      # Array[Float] or as any _Samples object.
      def segments_from_samples: (Params, Array[Float] samples, ?Integer n_samples) -> Segments
                               | (Params, _Samples, ?Integer n_samples) -> Segments

      # Produces Segments from the file at +wav_file_path+.
      def detect: (path wav_file_path, Params) -> Segments
    end

    # Enumerable collection of VAD::Segment.
    class Segments
      include Enumerable[Segment]

      # Yields each segment, or returns an Enumerator when no block is given.
      def each: { (Segment) -> void } -> void
              | () -> Enumerator[Segment]

      def length: () -> Integer
    end

    # One segment produced by VAD, described by its start and end time.
    class Segment
      # Hash returned by #deconstruct_keys; every value may be nil.
      type deconstructed_keys = {
        start_time: Integer?,
        end_time: Integer?
      }

      def start_time: () -> Integer
      def end_time: () -> Integer

      # Support for pattern matching; the argument lists the requested keys, or is nil.
      def deconstruct_keys: (Array[:start_time | :end_time]?) -> deconstructed_keys
    end
  end

  # Error class of this library; a StandardError that carries an Integer +code+.
  class Error < StandardError
    # The Integer code given to +new+.
    # NOTE(review): presumably the status code reported by the native library - confirm.
    attr_reader code: Integer

    def self.new: (Integer code) -> instance
  end
end
