API Reference

Text

TextAnalyzer

Used to get text from a csv and then run the TextDetector on it.

Source code in ammico/text.py
class TextAnalyzer:
    """Used to get text from a csv and then run the TextDetector on it."""

    def __init__(
        self, csv_path: str, column_key: str = None, csv_encoding: str = "utf-8"
    ) -> None:
        """Init the TextTranslator class.

        Args:
            csv_path (str): Path to the CSV file containing the text entries.
            column_key (str): Key for the column containing the text entries.
                Defaults to None.
            csv_encoding (str): Encoding of the CSV file. Defaults to "utf-8".
        """
        self.csv_path = csv_path
        self.column_key = column_key
        self.csv_encoding = csv_encoding
        self._check_valid_csv_path()
        self._check_file_exists()
        if not self.column_key:
            print("No column key provided - using 'text' as default.")
            self.column_key = "text"
        if not self.csv_encoding:
            print("No encoding provided - using 'utf-8' as default.")
            self.csv_encoding = "utf-8"
        if not isinstance(self.column_key, str):
            raise ValueError("The provided column key is not a string.")
        if not isinstance(self.csv_encoding, str):
            raise ValueError("The provided encoding is not a string.")

    def _check_valid_csv_path(self):
        if not isinstance(self.csv_path, str):
            raise ValueError("The provided path to the CSV file is not a string.")
        if not self.csv_path.endswith(".csv"):
            raise ValueError("The provided file is not a CSV file.")

    def _check_file_exists(self):
        try:
            with open(self.csv_path, "r") as file:  # noqa
                pass
        except FileNotFoundError:
            raise FileNotFoundError("The provided CSV file does not exist.")

    def read_csv(self) -> dict:
        """Read the CSV file and return the dictionary with the text entries.

        Returns:
            dict: The dictionary with the text entries.
        """
        df = pd.read_csv(self.csv_path, encoding=self.csv_encoding)

        if self.column_key not in df:
            raise ValueError(
                "The provided column key is not in the CSV file. Please check."
            )
        self.mylist = df[self.column_key].to_list()
        self.mydict = {}
        for i, text in enumerate(self.mylist):
            self.mydict[self.csv_path + "row-" + str(i)] = {
                "filename": self.csv_path,
                "text": text,
            }
        return self.mydict

__init__(csv_path, column_key=None, csv_encoding='utf-8')

Init the TextAnalyzer class.

Parameters:

- csv_path (str, required): Path to the CSV file containing the text entries.
- column_key (str, default None): Key for the column containing the text entries.
- csv_encoding (str, default 'utf-8'): Encoding of the CSV file.
Source code in ammico/text.py
def __init__(
    self, csv_path: str, column_key: str = None, csv_encoding: str = "utf-8"
) -> None:
    """Init the TextTranslator class.

    Args:
        csv_path (str): Path to the CSV file containing the text entries.
        column_key (str): Key for the column containing the text entries.
            Defaults to None.
        csv_encoding (str): Encoding of the CSV file. Defaults to "utf-8".
    """
    self.csv_path = csv_path
    self.column_key = column_key
    self.csv_encoding = csv_encoding
    self._check_valid_csv_path()
    self._check_file_exists()
    if not self.column_key:
        print("No column key provided - using 'text' as default.")
        self.column_key = "text"
    if not self.csv_encoding:
        print("No encoding provided - using 'utf-8' as default.")
        self.csv_encoding = "utf-8"
    if not isinstance(self.column_key, str):
        raise ValueError("The provided column key is not a string.")
    if not isinstance(self.csv_encoding, str):
        raise ValueError("The provided encoding is not a string.")

read_csv()

Read the CSV file and return the dictionary with the text entries.

Returns:

- dict: The dictionary with the text entries.

Source code in ammico/text.py
def read_csv(self) -> dict:
    """Read the CSV file and return the dictionary with the text entries.

    Returns:
        dict: The dictionary with the text entries.
    """
    df = pd.read_csv(self.csv_path, encoding=self.csv_encoding)

    if self.column_key not in df:
        raise ValueError(
            "The provided column key is not in the CSV file. Please check."
        )
    self.mylist = df[self.column_key].to_list()
    self.mydict = {}
    for i, text in enumerate(self.mylist):
        self.mydict[self.csv_path + "row-" + str(i)] = {
            "filename": self.csv_path,
            "text": text,
        }
    return self.mydict
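
A minimal usage sketch for TextAnalyzer; the CSV file name my_texts.csv and the column name tweet_text are hypothetical.

from ammico.text import TextAnalyzer

# read text entries from a CSV file with a column named "tweet_text"
ta = TextAnalyzer(csv_path="my_texts.csv", column_key="tweet_text")
ta.read_csv()

# ta.mydict maps "<csv_path>row-<i>" to {"filename": <csv_path>, "text": <entry>},
# which is the input format expected by TextDetector with skip_extraction=True
for key, entry in ta.mydict.items():
    print(key, entry["text"])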

TextDetector

Bases: AnalysisMethod

Source code in ammico/text.py
class TextDetector(AnalysisMethod):
    def __init__(
        self,
        subdict: dict,
        skip_extraction: bool = False,
        accept_privacy: str = "PRIVACY_AMMICO",
    ) -> None:
        """Init text detection class.

        Args:
            subdict (dict): Dictionary containing file name/path, and possibly previous
                analysis results from other modules.
            skip_extraction (bool, optional): Decide if text will be extracted from images or
                is already provided via a csv. Defaults to False.
            accept_privacy (str, optional): Environment variable to accept the privacy
                statement for the Google Cloud processing of the data. Defaults to
                "PRIVACY_AMMICO".
        """
        super().__init__(subdict)
        # disable this for now
        # maybe it would be better to initialize the keys differently
        # the reason is that they are inconsistent depending on the selected
        # options, and also this may not be really necessary and rather restrictive
        # self.subdict.update(self.set_keys())
        self.accepted = privacy_disclosure(accept_privacy)
        if not self.accepted:
            raise ValueError(
                "Privacy disclosure not accepted - skipping text detection."
            )
        self.translator = Translator(raise_exception=True)
        self.skip_extraction = skip_extraction
        if not isinstance(skip_extraction, bool):
            raise ValueError("skip_extraction needs to be set to true or false")
        if self.skip_extraction:
            print("Skipping text extraction from image.")
            print("Reading text directly from provided dictionary.")
        self._initialize_spacy()

    def set_keys(self) -> dict:
        """Set the default keys for text analysis.

        Returns:
            dict: The dictionary with default text keys.
        """
        params = {"text": None, "text_language": None, "text_english": None}
        return params

    def _initialize_spacy(self):
        """Initialize the Spacy library for text analysis."""
        try:
            self.nlp = spacy.load("en_core_web_md")
        except Exception:
            spacy.cli.download("en_core_web_md")
            self.nlp = spacy.load("en_core_web_md")

    def _check_add_space_after_full_stop(self):
        """Add a space after a full stop. Required by googletrans."""
        # we have found text, now we check for full stops
        index_stop = [i.start() for i in re.finditer(r"\.", self.subdict["text"])]
        if not index_stop:  # no full stops found
            return
        # check if this includes the last string item
        end_of_list = False
        if len(self.subdict["text"]) <= (index_stop[-1] + 1):
            # the last found full stop is at the end of the string
            # but we can include all others
            if len(index_stop) == 1:
                end_of_list = True
            else:
                index_stop.pop()
        if end_of_list:  # only one full stop at end of string
            return
        # if this is not the end of the list, check if there is a space after the full stop
        no_space = [i for i in index_stop if self.subdict["text"][i + 1] != " "]
        if not no_space:  # all full stops have a space after them
            return
        # else, amend the text
        add_one = 1
        for i in no_space:
            self.subdict["text"] = (
                self.subdict["text"][: i + add_one]
                + " "
                + self.subdict["text"][i + add_one :]
            )
            add_one += 1

    def _truncate_text(self, max_length: int = 5000) -> str:
        """Truncate the text if it is too long for googletrans."""
        if self.subdict["text"] and len(self.subdict["text"]) > max_length:
            print("Text is too long - truncating to {} characters.".format(max_length))
            self.subdict["text_truncated"] = self.subdict["text"][:max_length]

    def analyse_image(self) -> dict:
        """Perform text extraction and analysis of the text.

        Returns:
            dict: The updated dictionary with text analysis results.
        """
        if not self.skip_extraction:
            self.get_text_from_image()
        # check that text was found
        if not self.subdict["text"]:
            print("No text found - skipping analysis.")
        else:
            # make sure all full stops are followed by whitespace
            # otherwise googletrans breaks
            self._check_add_space_after_full_stop()
            self._truncate_text()
            self.translate_text()
            self.remove_linebreaks()
            if self.subdict["text_english"]:
                self._run_spacy()
        return self.subdict

    def get_text_from_image(self):
        """Detect text on the image using Google Cloud Vision API."""
        if not self.accepted:
            raise ValueError(
                "Privacy disclosure not accepted - skipping text detection."
            )
        path = self.subdict["filename"]
        try:
            client = vision.ImageAnnotatorClient()
        except DefaultCredentialsError:
            raise DefaultCredentialsError(
                "Please provide credentials for google cloud vision API, see https://cloud.google.com/docs/authentication/application-default-credentials."
            )
        with io.open(path, "rb") as image_file:
            content = image_file.read()
        image = vision.Image(content=content)
        # check for usual connection errors and skip the image if necessary
        try:
            response = client.text_detection(image=image)
        except grpc.RpcError as exc:
            print("Cloud vision API connection failed")
            print("Skipping this image ..{}".format(path))
            print("Connection failed with code {}: {}".format(exc.code(), exc))
            # response is undefined if the request failed, so skip this image
            self.subdict["text"] = None
            return
        # here check if text was found on image
        if response:
            texts = response.text_annotations[0].description
            self.subdict["text"] = texts
        else:
            print("No text found on image.")
            self.subdict["text"] = None
        if response.error.message:
            print("Google Cloud Vision Error")
            raise ValueError(
                "{}\nFor more info on error messages, check: "
                "https://cloud.google.com/apis/design/errors".format(
                    response.error.message
                )
            )

    def translate_text(self):
        """Translate the detected text to English using the Translator object."""
        if not self.accepted:
            raise ValueError(
                "Privacy disclosure not accepted - skipping text translation."
            )
        text_to_translate = (
            self.subdict["text_truncated"]
            if "text_truncated" in self.subdict
            else self.subdict["text"]
        )
        try:
            translated = self.translator.translate(text_to_translate)
        except Exception as e:
            print("Could not translate the text with error {}.".format(e))
            translated = None
            print("Skipping translation for this text.")
        self.subdict["text_language"] = translated.src if translated else None
        self.subdict["text_english"] = translated.text if translated else None

    def remove_linebreaks(self):
        """Remove linebreaks from original and translated text."""
        if self.subdict["text"] and self.subdict["text_english"]:
            self.subdict["text"] = self.subdict["text"].replace("\n", " ")
            self.subdict["text_english"] = self.subdict["text_english"].replace(
                "\n", " "
            )

    def _run_spacy(self):
        """Generate Spacy doc object for further text analysis."""
        self.doc = self.nlp(self.subdict["text_english"])

__init__(subdict, skip_extraction=False, accept_privacy='PRIVACY_AMMICO')

Init text detection class.

Parameters:

- subdict (dict, required): Dictionary containing file name/path, and possibly previous analysis results from other modules.
- skip_extraction (bool, default False): Decide if text will be extracted from images or is already provided via a csv.
- accept_privacy (str, default 'PRIVACY_AMMICO'): Environment variable to accept the privacy statement for the Google Cloud processing of the data.
Source code in ammico/text.py
def __init__(
    self,
    subdict: dict,
    skip_extraction: bool = False,
    accept_privacy: str = "PRIVACY_AMMICO",
) -> None:
    """Init text detection class.

    Args:
        subdict (dict): Dictionary containing file name/path, and possibly previous
            analysis results from other modules.
        skip_extraction (bool, optional): Decide if text will be extracted from images or
            is already provided via a csv. Defaults to False.
        accept_privacy (str, optional): Environment variable to accept the privacy
            statement for the Google Cloud processing of the data. Defaults to
            "PRIVACY_AMMICO".
    """
    super().__init__(subdict)
    # disable this for now
    # maybe it would be better to initialize the keys differently
    # the reason is that they are inconsistent depending on the selected
    # options, and also this may not be really necessary and rather restrictive
    # self.subdict.update(self.set_keys())
    self.accepted = privacy_disclosure(accept_privacy)
    if not self.accepted:
        raise ValueError(
            "Privacy disclosure not accepted - skipping text detection."
        )
    self.translator = Translator(raise_exception=True)
    self.skip_extraction = skip_extraction
    if not isinstance(skip_extraction, bool):
        raise ValueError("skip_extraction needs to be set to true or false")
    if self.skip_extraction:
        print("Skipping text extraction from image.")
        print("Reading text directly from provided dictionary.")
    self._initialize_spacy()

analyse_image()

Perform text extraction and analysis of the text.

Returns:

- dict: The updated dictionary with text analysis results.

Source code in ammico/text.py
def analyse_image(self) -> dict:
    """Perform text extraction and analysis of the text.

    Returns:
        dict: The updated dictionary with text analysis results.
    """
    if not self.skip_extraction:
        self.get_text_from_image()
    # check that text was found
    if not self.subdict["text"]:
        print("No text found - skipping analysis.")
    else:
        # make sure all full stops are followed by whitespace
        # otherwise googletrans breaks
        self._check_add_space_after_full_stop()
        self._truncate_text()
        self.translate_text()
        self.remove_linebreaks()
        if self.subdict["text_english"]:
            self._run_spacy()
    return self.subdict
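
A usage sketch for analyse_image, assuming the privacy disclosure has been accepted and, for the extraction case, Google Cloud Vision credentials are configured; the file names are hypothetical.

import os

from ammico.text import TextDetector

os.environ["PRIVACY_AMMICO"] = "True"  # accept the privacy disclosure up front

# extract text from an image and analyse it
entry = {"filename": "example_image.png"}
result = TextDetector(entry).analyse_image()
print(result.get("text_language"), result.get("text_english"))

# or skip extraction and analyse text read from a CSV via TextAnalyzer
csv_entry = {"filename": "my_texts.csv", "text": "Ein kurzer Beispieltext."}
result = TextDetector(csv_entry, skip_extraction=True).analyse_image()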

get_text_from_image()

Detect text on the image using Google Cloud Vision API.

Source code in ammico/text.py
def get_text_from_image(self):
    """Detect text on the image using Google Cloud Vision API."""
    if not self.accepted:
        raise ValueError(
            "Privacy disclosure not accepted - skipping text detection."
        )
    path = self.subdict["filename"]
    try:
        client = vision.ImageAnnotatorClient()
    except DefaultCredentialsError:
        raise DefaultCredentialsError(
            "Please provide credentials for google cloud vision API, see https://cloud.google.com/docs/authentication/application-default-credentials."
        )
    with io.open(path, "rb") as image_file:
        content = image_file.read()
    image = vision.Image(content=content)
    # check for usual connection errors and skip the image if necessary
    try:
        response = client.text_detection(image=image)
    except grpc.RpcError as exc:
        print("Cloud vision API connection failed")
        print("Skipping this image ..{}".format(path))
        print("Connection failed with code {}: {}".format(exc.code(), exc))
        # response is undefined if the request failed, so skip this image
        self.subdict["text"] = None
        return
    # here check if text was found on image
    if response:
        texts = response.text_annotations[0].description
        self.subdict["text"] = texts
    else:
        print("No text found on image.")
        self.subdict["text"] = None
    if response.error.message:
        print("Google Cloud Vision Error")
        raise ValueError(
            "{}\nFor more info on error messages, check: "
            "https://cloud.google.com/apis/design/errors".format(
                response.error.message
            )
        )

remove_linebreaks()

Remove linebreaks from original and translated text.

Source code in ammico/text.py
def remove_linebreaks(self):
    """Remove linebreaks from original and translated text."""
    if self.subdict["text"] and self.subdict["text_english"]:
        self.subdict["text"] = self.subdict["text"].replace("\n", " ")
        self.subdict["text_english"] = self.subdict["text_english"].replace(
            "\n", " "
        )

set_keys()

Set the default keys for text analysis.

Returns:

- dict: The dictionary with default text keys.

Source code in ammico/text.py
def set_keys(self) -> dict:
    """Set the default keys for text analysis.

    Returns:
        dict: The dictionary with default text keys.
    """
    params = {"text": None, "text_language": None, "text_english": None}
    return params

translate_text()

Translate the detected text to English using the Translator object.

Source code in ammico/text.py
def translate_text(self):
    """Translate the detected text to English using the Translator object."""
    if not self.accepted:
        raise ValueError(
            "Privacy disclosure not accepted - skipping text translation."
        )
    text_to_translate = (
        self.subdict["text_truncated"]
        if "text_truncated" in self.subdict
        else self.subdict["text"]
    )
    try:
        translated = self.translator.translate(text_to_translate)
    except Exception as e:
        print("Could not translate the text with error {}.".format(e))
        translated = None
        print("Skipping translation for this text.")
    self.subdict["text_language"] = translated.src if translated else None
    self.subdict["text_english"] = translated.text if translated else None

privacy_disclosure(accept_privacy='PRIVACY_AMMICO')

Asks the user to accept the privacy statement.

Parameters:

- accept_privacy (str, default 'PRIVACY_AMMICO'): The name of the disclosure variable.
Source code in ammico/text.py
def privacy_disclosure(accept_privacy: str = "PRIVACY_AMMICO"):
    """
    Asks the user to accept the privacy statement.

    Args:
        accept_privacy (str): The name of the disclosure variable (default: "PRIVACY_AMMICO").
    """
    if not os.environ.get(accept_privacy):
        accepted = _ask_for_privacy_acceptance(accept_privacy)
    elif os.environ.get(accept_privacy) == "False":
        accepted = False
    elif os.environ.get(accept_privacy) == "True":
        accepted = True
    else:
        print(
            "Could not determine privacy disclosure - skipping \
              text detection and translation."
        )
        accepted = False
    return accepted
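
The interactive prompt can be avoided by setting the disclosure variable in the environment before constructing a TextDetector; a short sketch:

import os

from ammico.text import privacy_disclosure

# "True" accepts the disclosure non-interactively; "False" (or any other value)
# disables text detection and translation
os.environ["PRIVACY_AMMICO"] = "True"
assert privacy_disclosure("PRIVACY_AMMICO")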

Image Summary

ImageSummaryDetector

Bases: AnalysisMethod

Source code in ammico/image_summary.py
class ImageSummaryDetector(AnalysisMethod):
    token_prompt_config = {
        "default": {
            "summary": {"prompt": "Describe this image.", "max_new_tokens": 256},
            "questions": {"prompt": "", "max_new_tokens": 128},
        },
        "concise": {
            "summary": {
                "prompt": "Describe this image in one concise caption.",
                "max_new_tokens": 64,
            },
            "questions": {"prompt": "Answer concisely: ", "max_new_tokens": 128},
        },
    }
    MAX_QUESTIONS_PER_IMAGE = 32
    KEYS_BATCH_SIZE = 16

    def __init__(
        self,
        summary_model: MultimodalSummaryModel,
        subdict: Optional[Dict[str, Any]] = None,
    ) -> None:
        """
        Class for analysing images using QWEN-2.5-VL model.
        It provides methods for generating captions and answering questions about images.

        Args:
            summary_model (MultimodalSummaryModel): An instance of MultimodalSummaryModel to be used for analysis.
            subdict (dict, optional): Dictionary containing the image to be analysed. Defaults to {}.

        Returns:
            None.
        """
        if subdict is None:
            subdict = {}

        super().__init__(subdict)
        self.summary_model = summary_model

    def _load_pil_if_needed(
        self, filename: Union[str, os.PathLike, Image.Image]
    ) -> Image.Image:
        if isinstance(filename, (str, os.PathLike)):
            return Image.open(filename).convert("RGB")
        elif isinstance(filename, Image.Image):
            return filename.convert("RGB")
        else:
            raise ValueError("filename must be a path or PIL.Image")

    @staticmethod
    def _is_sequence_but_not_str(obj: Any) -> bool:
        """True for sequence-like but not a string/bytes/PIL.Image."""
        return isinstance(obj, _Sequence) and not isinstance(
            obj, (str, bytes, Image.Image)
        )

    def _prepare_inputs(
        self, list_of_questions: list[str], entry: Optional[Dict[str, Any]] = None
    ) -> Dict[str, torch.Tensor]:
        filename = entry.get("filename")
        if filename is None:
            raise ValueError("entry must contain key 'filename'")

        if isinstance(filename, (str, os.PathLike, Image.Image)):
            images_context = self._load_pil_if_needed(filename)
        elif self._is_sequence_but_not_str(filename):
            images_context = [self._load_pil_if_needed(i) for i in filename]
        else:
            raise ValueError(
                "Unsupported 'filename' entry: expected path, PIL.Image, or sequence."
            )

        images_only_messages = [
            {
                "role": "user",
                "content": [
                    *(
                        [{"type": "image", "image": img} for img in images_context]
                        if isinstance(images_context, list)
                        else [{"type": "image", "image": images_context}]
                    )
                ],
            }
        ]

        try:
            image_inputs, _ = process_vision_info(images_only_messages)
        except Exception as e:
            raise RuntimeError(f"Image processing failed: {e}")

        texts: List[str] = []
        for q in list_of_questions:
            messages = [
                {
                    "role": "user",
                    "content": [
                        *(
                            [
                                {"type": "image", "image": image}
                                for image in images_context
                            ]
                            if isinstance(images_context, list)
                            else [{"type": "image", "image": images_context}]
                        ),
                        {"type": "text", "text": q},
                    ],
                }
            ]
            text = self.summary_model.processor.apply_chat_template(
                messages, tokenize=False, add_generation_prompt=True
            )
            texts.append(text)

        images_batch = [image_inputs] * len(texts)
        inputs = self.summary_model.processor(
            text=texts,
            images=images_batch,
            padding=True,
            return_tensors="pt",
        )
        inputs = {k: v.to(self.summary_model.device) for k, v in inputs.items()}

        return inputs

    def _validate_analysis_type(
        self,
        analysis_type: Union["AnalysisType", str],
        list_of_questions: Optional[List[str]],
        max_questions_per_image: int,
    ) -> Tuple[str, List[str], bool, bool]:
        if isinstance(analysis_type, AnalysisType):
            analysis_type = analysis_type.value

        allowed = {"summary", "questions", "summary_and_questions"}
        if analysis_type not in allowed:
            raise ValueError(f"analysis_type must be one of {allowed}")

        if list_of_questions is None:
            list_of_questions = [
                "Are there people in the image?",
                "What is this picture about?",
            ]

        if analysis_type in ("questions", "summary_and_questions"):
            if len(list_of_questions) > max_questions_per_image:
                raise ValueError(
                    f"Number of questions per image ({len(list_of_questions)}) exceeds safety cap ({max_questions_per_image}). Reduce questions or increase max_questions_per_image."
                )

        is_summary = analysis_type in ("summary", "summary_and_questions")
        is_questions = analysis_type in ("questions", "summary_and_questions")

        return analysis_type, list_of_questions, is_summary, is_questions

    def analyse_image(
        self,
        entry: dict,
        analysis_type: Union[str, AnalysisType] = AnalysisType.SUMMARY_AND_QUESTIONS,
        list_of_questions: Optional[List[str]] = None,
        max_questions_per_image: int = MAX_QUESTIONS_PER_IMAGE,
        is_concise_summary: bool = True,
        is_concise_answer: bool = True,
    ) -> Dict[str, Any]:
        """
        Analyse a single image entry. Returns dict with keys depending on analysis_type:
            - 'caption' (str) if summary requested
            - 'vqa' (dict) if questions requested
        """
        self.subdict = entry
        analysis_type, list_of_questions, is_summary, is_questions = (
            self._validate_analysis_type(
                analysis_type, list_of_questions, max_questions_per_image
            )
        )

        if is_summary:
            try:
                caps = self.generate_caption(
                    entry,
                    num_return_sequences=1,
                    is_concise_summary=is_concise_summary,
                )
                self.subdict["caption"] = caps[0] if caps else ""
            except Exception as e:
                warnings.warn(f"Caption generation failed: {e}")

        if is_questions:
            try:
                vqa_map = self.answer_questions(
                    list_of_questions, entry, is_concise_answer
                )
                self.subdict["vqa"] = vqa_map
            except Exception as e:
                warnings.warn(f"VQA failed: {e}")

        return self.subdict

    def analyse_images_from_dict(
        self,
        analysis_type: Union[AnalysisType, str] = AnalysisType.SUMMARY_AND_QUESTIONS,
        list_of_questions: Optional[List[str]] = None,
        max_questions_per_image: int = MAX_QUESTIONS_PER_IMAGE,
        keys_batch_size: int = KEYS_BATCH_SIZE,
        is_concise_summary: bool = True,
        is_concise_answer: bool = True,
    ) -> Dict[str, dict]:
        """
        Analyse all images in the dictionary with the summary model.

        Args:
            analysis_type (str): type of the analysis.
            list_of_questions (list[str]): list of questions.
            max_questions_per_image (int): maximum number of questions per image.
                We recommend to keep it low to avoid long processing times and high memory usage.
            keys_batch_size (int): number of images to process in a batch.
            is_concise_summary (bool): whether to generate concise summary.
            is_concise_answer (bool): whether to generate concise answers.
        Returns:
            self.subdict (dict): dictionary with analysis results.
        """
        # TODO: add option to ask multiple questions per image as one batch.
        analysis_type, list_of_questions, is_summary, is_questions = (
            self._validate_analysis_type(
                analysis_type, list_of_questions, max_questions_per_image
            )
        )

        keys = list(self.subdict.keys())
        for batch_start in range(0, len(keys), keys_batch_size):
            batch_keys = keys[batch_start : batch_start + keys_batch_size]
            for key in batch_keys:
                entry = self.subdict[key]
                if is_summary:
                    try:
                        caps = self.generate_caption(
                            entry,
                            num_return_sequences=1,
                            is_concise_summary=is_concise_summary,
                        )
                        entry["caption"] = caps[0] if caps else ""
                    except Exception as e:
                        warnings.warn(f"Caption generation failed: {e}")

                if is_questions:
                    try:
                        vqa_map = self.answer_questions(
                            list_of_questions, entry, is_concise_answer
                        )
                        entry["vqa"] = vqa_map
                    except Exception as e:
                        warnings.warn(f"VQA failed: {e}")

                self.subdict[key] = entry
        return self.subdict

    def generate_caption(
        self,
        entry: Optional[Dict[str, Any]] = None,
        num_return_sequences: int = 1,
        is_concise_summary: bool = True,
    ) -> List[str]:
        """
        Create caption for image. Depending on is_concise_summary it will be either concise or detailed.

        Args:
            entry (dict): dictionary containing the image to be captioned.
            num_return_sequences (int): number of captions to generate.
            is_concise_summary (bool): whether to generate concise summary.

        Returns:
            results (list[str]): list of generated captions.
        """
        prompt = self.token_prompt_config[
            "concise" if is_concise_summary else "default"
        ]["summary"]["prompt"]
        max_new_tokens = self.token_prompt_config[
            "concise" if is_concise_summary else "default"
        ]["summary"]["max_new_tokens"]
        inputs = self._prepare_inputs([prompt], entry)

        gen_conf = GenerationConfig(
            max_new_tokens=max_new_tokens,
            do_sample=False,
            num_return_sequences=num_return_sequences,
        )

        with torch.inference_mode():
            try:
                if self.summary_model.device == "cuda":
                    with torch.amp.autocast("cuda", enabled=True):
                        generated_ids = self.summary_model.model.generate(
                            **inputs, generation_config=gen_conf
                        )
                else:
                    generated_ids = self.summary_model.model.generate(
                        **inputs, generation_config=gen_conf
                    )
            except RuntimeError as e:
                warnings.warn(
                    f"Retry without autocast failed: {e}. Attempting cudnn-disabled retry."
                )
                cudnn_was_enabled = (
                    torch.backends.cudnn.is_available() and torch.backends.cudnn.enabled
                )
                if cudnn_was_enabled:
                    torch.backends.cudnn.enabled = False
                try:
                    generated_ids = self.summary_model.model.generate(
                        **inputs, generation_config=gen_conf
                    )
                except Exception as retry_error:
                    raise RuntimeError(
                        f"Failed to generate ids after retry: {retry_error}"
                    ) from retry_error
                finally:
                    if cudnn_was_enabled:
                        torch.backends.cudnn.enabled = True

        decoded = None
        if "input_ids" in inputs:
            in_ids = inputs["input_ids"]
            trimmed = [
                out_ids[len(inp_ids) :]
                for inp_ids, out_ids in zip(in_ids, generated_ids)
            ]
            decoded = self.summary_model.tokenizer.batch_decode(
                trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
            )
        else:
            decoded = self.summary_model.tokenizer.batch_decode(
                generated_ids,
                skip_special_tokens=True,
                clean_up_tokenization_spaces=False,
            )

        results = [d.strip() for d in decoded]
        return results

    def _clean_list_of_questions(
        self, list_of_questions: list[str], prompt: str
    ) -> list[str]:
        """Clean the list of questions to contain correctly formatted strings."""
        # remove all None or empty questions
        list_of_questions = [i for i in list_of_questions if i and i.strip()]
        # ensure each question ends with a question mark
        list_of_questions = [
            i.strip() + "?" if not i.strip().endswith("?") else i.strip()
            for i in list_of_questions
        ]
        # ensure each question starts with the prompt
        list_of_questions = [
            i if i.lower().startswith(prompt.lower()) else prompt + i
            for i in list_of_questions
        ]
        return list_of_questions

    def answer_questions(
        self,
        list_of_questions: list[str],
        entry: Optional[Dict[str, Any]] = None,
        is_concise_answer: bool = True,
    ) -> List[str]:
        """
        Create answers for list of questions about image.
        Args:
            list_of_questions (list[str]): list of questions.
            entry (dict): dictionary containing the image to be captioned.
            is_concise_answer (bool): whether to generate concise answers.
        Returns:
            answers (list[str]): list of answers.
        """
        prompt = self.token_prompt_config[
            "concise" if is_concise_answer else "default"
        ]["questions"]["prompt"]
        max_new_tokens = self.token_prompt_config[
            "concise" if is_concise_answer else "default"
        ]["questions"]["max_new_tokens"]

        list_of_questions = self._clean_list_of_questions(list_of_questions, prompt)
        gen_conf = GenerationConfig(max_new_tokens=max_new_tokens, do_sample=False)

        question_chunk_size = 8
        answers: List[str] = []
        n = len(list_of_questions)
        for i in range(0, n, question_chunk_size):
            chunk = list_of_questions[i : i + question_chunk_size]
            inputs = self._prepare_inputs(chunk, entry)
            with torch.inference_mode():
                if self.summary_model.device == "cuda":
                    with torch.amp.autocast("cuda", enabled=True):
                        out_ids = self.summary_model.model.generate(
                            **inputs, generation_config=gen_conf
                        )
                else:
                    out_ids = self.summary_model.model.generate(
                        **inputs, generation_config=gen_conf
                    )

            if "input_ids" in inputs:
                in_ids = inputs["input_ids"]
                trimmed_batch = [
                    out_row[len(inp_row) :] for inp_row, out_row in zip(in_ids, out_ids)
                ]
                decoded = self.summary_model.tokenizer.batch_decode(
                    trimmed_batch,
                    skip_special_tokens=True,
                    clean_up_tokenization_spaces=False,
                )
            else:
                decoded = self.summary_model.tokenizer.batch_decode(
                    out_ids,
                    skip_special_tokens=True,
                    clean_up_tokenization_spaces=False,
                )

            answers.extend([d.strip() for d in decoded])

        if len(answers) != len(list_of_questions):
            raise ValueError(
                f"Expected {len(list_of_questions)} answers, but got {len(answers)}, try varying amount of questions"
            )

        return answers

__init__(summary_model, subdict=None)

Class for analysing images using QWEN-2.5-VL model. It provides methods for generating captions and answering questions about images.

Parameters:

- summary_model (MultimodalSummaryModel, required): An instance of MultimodalSummaryModel to be used for analysis.
- subdict (dict, default None): Dictionary containing the image to be analysed. Defaults to an empty dict if None.

Returns:

- None.

Source code in ammico/image_summary.py
def __init__(
    self,
    summary_model: MultimodalSummaryModel,
    subdict: Optional[Dict[str, Any]] = None,
) -> None:
    """
    Class for analysing images using QWEN-2.5-VL model.
    It provides methods for generating captions and answering questions about images.

    Args:
        summary_model (MultimodalSummaryModel): An instance of MultimodalSummaryModel to be used for analysis.
        subdict (dict, optional): Dictionary containing the image to be analysed. Defaults to {}.

    Returns:
        None.
    """
    if subdict is None:
        subdict = {}

    super().__init__(subdict)
    self.summary_model = summary_model

analyse_image(entry, analysis_type=AnalysisType.SUMMARY_AND_QUESTIONS, list_of_questions=None, max_questions_per_image=MAX_QUESTIONS_PER_IMAGE, is_concise_summary=True, is_concise_answer=True)

Analyse a single image entry. Returns a dict with keys depending on analysis_type:

- 'caption' (str) if summary requested
- 'vqa' (dict) if questions requested

Source code in ammico/image_summary.py
def analyse_image(
    self,
    entry: dict,
    analysis_type: Union[str, AnalysisType] = AnalysisType.SUMMARY_AND_QUESTIONS,
    list_of_questions: Optional[List[str]] = None,
    max_questions_per_image: int = MAX_QUESTIONS_PER_IMAGE,
    is_concise_summary: bool = True,
    is_concise_answer: bool = True,
) -> Dict[str, Any]:
    """
    Analyse a single image entry. Returns dict with keys depending on analysis_type:
        - 'caption' (str) if summary requested
        - 'vqa' (dict) if questions requested
    """
    self.subdict = entry
    analysis_type, list_of_questions, is_summary, is_questions = (
        self._validate_analysis_type(
            analysis_type, list_of_questions, max_questions_per_image
        )
    )

    if is_summary:
        try:
            caps = self.generate_caption(
                entry,
                num_return_sequences=1,
                is_concise_summary=is_concise_summary,
            )
            self.subdict["caption"] = caps[0] if caps else ""
        except Exception as e:
            warnings.warn(f"Caption generation failed: {e}")

    if is_questions:
        try:
            vqa_map = self.answer_questions(
                list_of_questions, entry, is_concise_answer
            )
            self.subdict["vqa"] = vqa_map
        except Exception as e:
            warnings.warn(f"VQA failed: {e}")

    return self.subdict
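
A usage sketch for a single entry; it assumes MultimodalSummaryModel is importable from ammico.model and can be constructed without arguments, and the image path is hypothetical.

from ammico.image_summary import ImageSummaryDetector
from ammico.model import MultimodalSummaryModel  # assumed import path

model = MultimodalSummaryModel()  # assumption: default construction loads the model
detector = ImageSummaryDetector(summary_model=model)

entry = {"filename": "example_image.png"}  # hypothetical image path
result = detector.analyse_image(
    entry,
    analysis_type="summary_and_questions",
    list_of_questions=["Are there people in the image?"],
)
print(result["caption"])
print(result["vqa"])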

analyse_images_from_dict(analysis_type=AnalysisType.SUMMARY_AND_QUESTIONS, list_of_questions=None, max_questions_per_image=MAX_QUESTIONS_PER_IMAGE, keys_batch_size=KEYS_BATCH_SIZE, is_concise_summary=True, is_concise_answer=True)

Analyse all images in the dictionary with the summary model.

Parameters:

- analysis_type (str, default SUMMARY_AND_QUESTIONS): type of the analysis.
- list_of_questions (list[str], default None): list of questions.
- max_questions_per_image (int, default MAX_QUESTIONS_PER_IMAGE): maximum number of questions per image. We recommend keeping it low to avoid long processing times and high memory usage.
- keys_batch_size (int, default KEYS_BATCH_SIZE): number of images to process in a batch.
- is_concise_summary (bool, default True): whether to generate a concise summary.
- is_concise_answer (bool, default True): whether to generate concise answers.

Returns:

- self.subdict (dict): dictionary with analysis results.

Source code in ammico/image_summary.py
def analyse_images_from_dict(
    self,
    analysis_type: Union[AnalysisType, str] = AnalysisType.SUMMARY_AND_QUESTIONS,
    list_of_questions: Optional[List[str]] = None,
    max_questions_per_image: int = MAX_QUESTIONS_PER_IMAGE,
    keys_batch_size: int = KEYS_BATCH_SIZE,
    is_concise_summary: bool = True,
    is_concise_answer: bool = True,
) -> Dict[str, dict]:
    """
    Analyse all images in the dictionary with the summary model.

    Args:
        analysis_type (str): type of the analysis.
        list_of_questions (list[str]): list of questions.
        max_questions_per_image (int): maximum number of questions per image.
            We recommend to keep it low to avoid long processing times and high memory usage.
        keys_batch_size (int): number of images to process in a batch.
        is_concise_summary (bool): whether to generate concise summary.
        is_concise_answer (bool): whether to generate concise answers.
    Returns:
        self.subdict (dict): dictionary with analysis results.
    """
    # TODO: add option to ask multiple questions per image as one batch.
    analysis_type, list_of_questions, is_summary, is_questions = (
        self._validate_analysis_type(
            analysis_type, list_of_questions, max_questions_per_image
        )
    )

    keys = list(self.subdict.keys())
    for batch_start in range(0, len(keys), keys_batch_size):
        batch_keys = keys[batch_start : batch_start + keys_batch_size]
        for key in batch_keys:
            entry = self.subdict[key]
            if is_summary:
                try:
                    caps = self.generate_caption(
                        entry,
                        num_return_sequences=1,
                        is_concise_summary=is_concise_summary,
                    )
                    entry["caption"] = caps[0] if caps else ""
                except Exception as e:
                    warnings.warn(f"Caption generation failed: {e}")

            if is_questions:
                try:
                    vqa_map = self.answer_questions(
                        list_of_questions, entry, is_concise_answer
                    )
                    entry["vqa"] = vqa_map
                except Exception as e:
                    warnings.warn(f"VQA failed: {e}")

            self.subdict[key] = entry
    return self.subdict
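
A sketch of batch analysis over an image dictionary; the image paths and the MultimodalSummaryModel construction are assumptions, as in the previous sketch.

from ammico.image_summary import ImageSummaryDetector
from ammico.model import MultimodalSummaryModel  # assumed import path

# each value must contain at least a "filename" entry (path or PIL.Image)
images = {
    "img0": {"filename": "images/img0.png"},  # hypothetical paths
    "img1": {"filename": "images/img1.png"},
}

model = MultimodalSummaryModel()  # assumption: default construction loads the model
detector = ImageSummaryDetector(summary_model=model, subdict=images)
results = detector.analyse_images_from_dict(analysis_type="summary", keys_batch_size=16)
for key, entry in results.items():
    print(key, entry.get("caption"))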

answer_questions(list_of_questions, entry=None, is_concise_answer=True)

Create answers for a list of questions about an image.

Parameters:

- list_of_questions (list[str], required): list of questions.
- entry (dict, default None): dictionary containing the image to be analysed.
- is_concise_answer (bool, default True): whether to generate concise answers.

Returns:

- answers (list[str]): list of answers.

Source code in ammico/image_summary.py
def answer_questions(
    self,
    list_of_questions: list[str],
    entry: Optional[Dict[str, Any]] = None,
    is_concise_answer: bool = True,
) -> List[str]:
    """
    Create answers for list of questions about image.
    Args:
        list_of_questions (list[str]): list of questions.
        entry (dict): dictionary containing the image to be captioned.
        is_concise_answer (bool): whether to generate concise answers.
    Returns:
        answers (list[str]): list of answers.
    """
    prompt = self.token_prompt_config[
        "concise" if is_concise_answer else "default"
    ]["questions"]["prompt"]
    max_new_tokens = self.token_prompt_config[
        "concise" if is_concise_answer else "default"
    ]["questions"]["max_new_tokens"]

    list_of_questions = self._clean_list_of_questions(list_of_questions, prompt)
    gen_conf = GenerationConfig(max_new_tokens=max_new_tokens, do_sample=False)

    question_chunk_size = 8
    answers: List[str] = []
    n = len(list_of_questions)
    for i in range(0, n, question_chunk_size):
        chunk = list_of_questions[i : i + question_chunk_size]
        inputs = self._prepare_inputs(chunk, entry)
        with torch.inference_mode():
            if self.summary_model.device == "cuda":
                with torch.amp.autocast("cuda", enabled=True):
                    out_ids = self.summary_model.model.generate(
                        **inputs, generation_config=gen_conf
                    )
            else:
                out_ids = self.summary_model.model.generate(
                    **inputs, generation_config=gen_conf
                )

        if "input_ids" in inputs:
            in_ids = inputs["input_ids"]
            trimmed_batch = [
                out_row[len(inp_row) :] for inp_row, out_row in zip(in_ids, out_ids)
            ]
            decoded = self.summary_model.tokenizer.batch_decode(
                trimmed_batch,
                skip_special_tokens=True,
                clean_up_tokenization_spaces=False,
            )
        else:
            decoded = self.summary_model.tokenizer.batch_decode(
                out_ids,
                skip_special_tokens=True,
                clean_up_tokenization_spaces=False,
            )

        answers.extend([d.strip() for d in decoded])

    if len(answers) != len(list_of_questions):
        raise ValueError(
            f"Expected {len(list_of_questions)} answers, but got {len(answers)}, try varying amount of questions"
        )

    return answers

generate_caption(entry=None, num_return_sequences=1, is_concise_summary=True)

Create caption for image. Depending on is_concise_summary it will be either concise or detailed.

Parameters:

- entry (dict, default None): dictionary containing the image to be captioned.
- num_return_sequences (int, default 1): number of captions to generate.
- is_concise_summary (bool, default True): whether to generate concise summary.

Returns:

- results (list[str]): list of generated captions.

Source code in ammico/image_summary.py
def generate_caption(
    self,
    entry: Optional[Dict[str, Any]] = None,
    num_return_sequences: int = 1,
    is_concise_summary: bool = True,
) -> List[str]:
    """
    Create caption for image. Depending on is_concise_summary it will be either concise or detailed.

    Args:
        entry (dict): dictionary containing the image to be captioned.
        num_return_sequences (int): number of captions to generate.
        is_concise_summary (bool): whether to generate concise summary.

    Returns:
        results (list[str]): list of generated captions.
    """
    prompt = self.token_prompt_config[
        "concise" if is_concise_summary else "default"
    ]["summary"]["prompt"]
    max_new_tokens = self.token_prompt_config[
        "concise" if is_concise_summary else "default"
    ]["summary"]["max_new_tokens"]
    inputs = self._prepare_inputs([prompt], entry)

    gen_conf = GenerationConfig(
        max_new_tokens=max_new_tokens,
        do_sample=False,
        num_return_sequences=num_return_sequences,
    )

    with torch.inference_mode():
        try:
            if self.summary_model.device == "cuda":
                with torch.amp.autocast("cuda", enabled=True):
                    generated_ids = self.summary_model.model.generate(
                        **inputs, generation_config=gen_conf
                    )
            else:
                generated_ids = self.summary_model.model.generate(
                    **inputs, generation_config=gen_conf
                )
        except RuntimeError as e:
            warnings.warn(
                f"Retry without autocast failed: {e}. Attempting cudnn-disabled retry."
            )
            cudnn_was_enabled = (
                torch.backends.cudnn.is_available() and torch.backends.cudnn.enabled
            )
            if cudnn_was_enabled:
                torch.backends.cudnn.enabled = False
            try:
                generated_ids = self.summary_model.model.generate(
                    **inputs, generation_config=gen_conf
                )
            except Exception as retry_error:
                raise RuntimeError(
                    f"Failed to generate ids after retry: {retry_error}"
                ) from retry_error
            finally:
                if cudnn_was_enabled:
                    torch.backends.cudnn.enabled = True

    decoded = None
    if "input_ids" in inputs:
        in_ids = inputs["input_ids"]
        trimmed = [
            out_ids[len(inp_ids) :]
            for inp_ids, out_ids in zip(in_ids, generated_ids)
        ]
        decoded = self.summary_model.tokenizer.batch_decode(
            trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
        )
    else:
        decoded = self.summary_model.tokenizer.batch_decode(
            generated_ids,
            skip_special_tokens=True,
            clean_up_tokenization_spaces=False,
        )

    results = [d.strip() for d in decoded]
    return results
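
generate_caption can also be called directly when only a caption is needed; the entry and the model construction below are assumptions, as in the sketches above.

from ammico.image_summary import ImageSummaryDetector
from ammico.model import MultimodalSummaryModel  # assumed import path

model = MultimodalSummaryModel()  # assumption: default construction loads the model
detector = ImageSummaryDetector(summary_model=model)
entry = {"filename": "example_image.png"}  # hypothetical image path

# detailed caption (default prompt, up to 256 new tokens)
detailed = detector.generate_caption(entry, is_concise_summary=False)
# concise caption (up to 64 new tokens)
concise = detector.generate_caption(entry, is_concise_summary=True)
print(concise[0])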

Video Summary

VideoSummaryDetector

Bases: AnalysisMethod

Source code in ammico/video_summary.py
class VideoSummaryDetector(AnalysisMethod):
    def __init__(
        self,
        summary_model: MultimodalSummaryModel = None,
        audio_model: Optional[AudioToTextModel] = None,
        subdict: Optional[Dict[str, Any]] = None,
    ) -> None:
        """
        Class for analysing videos using the QWEN-2.5-VL model.
        It provides methods for generating captions and answering questions about videos.

        Args:
            summary_model (MultimodalSummaryModel, optional): An instance of MultimodalSummaryModel to be used for analysis.
            audio_model (AudioToTextModel, optional): An instance of AudioToTextModel used to transcribe the audio track, if present.
            subdict (dict, optional): Dictionary containing the video to be analysed. Defaults to {}.

        Returns:
            None.
        """
        if subdict is None:
            subdict = {}

        super().__init__(subdict)
        _validate_subdict(subdict)
        self.summary_model = summary_model or None
        self.audio_model = audio_model
        self.prompt_builder = PromptBuilder()

    def _decode_trimmed_outputs(
        self,
        generated_ids: torch.Tensor,
        inputs: Dict[str, torch.Tensor],
        tokenizer,
        prompt_texts: List[str],
    ) -> List[str]:
        """
        Trim prompt tokens using attention_mask/input_ids when available and decode to strings.
        Then remove any literal prompt prefix using prompt_texts (one per batch element).
        Args:
            generated_ids (torch.Tensor): Generated token IDs from the model.
            inputs (Dict[str, torch.Tensor]): Original input tensors used for generation.
            tokenizer: Tokenizer used for decoding the generated outputs.
            prompt_texts (List[str]): List of prompt texts corresponding to each input in the batch.
        Returns:
            List[str]: Decoded generated texts after trimming and cleaning.
        """

        batch_size = generated_ids.shape[0]

        if "input_ids" in inputs:
            token_for_padding = (
                tokenizer.pad_token_id
                if getattr(tokenizer, "pad_token_id", None) is not None
                else getattr(tokenizer, "eos_token_id", None)
            )
            if token_for_padding is None:
                lengths = [int(inputs["input_ids"].shape[1])] * batch_size
            else:
                lengths = inputs["input_ids"].ne(token_for_padding).sum(dim=1).tolist()
        else:
            lengths = [0] * batch_size

        trimmed_ids = []
        for i in range(batch_size):
            out_ids = generated_ids[i]
            in_len = int(lengths[i]) if i < len(lengths) else 0
            if out_ids.size(0) > in_len:
                t = out_ids[in_len:]
            else:
                t = out_ids.new_empty((0,), dtype=out_ids.dtype)
            t_cpu = t.to("cpu")
            trimmed_ids.append(t_cpu.tolist())

        decoded = tokenizer.batch_decode(
            trimmed_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True
        )
        decoded_results = []
        for ptext, raw in zip(prompt_texts, decoded):
            cleaned = _strip_prompt_prefix_literal(raw, ptext)
            decoded_results.append(cleaned)
        return decoded_results

    def _generate_from_processor_inputs(
        self,
        processor_inputs: Dict[str, torch.Tensor],
        prompt_texts: List[str],
        tokenizer,
        len_objects: Optional[int] = None,
    ) -> List[str]:
        """
        Run model.generate on already-processed processor_inputs (tensors moved to device),
        then decode and trim prompt tokens & remove literal prompt prefixes using prompt_texts.
        Args:
            processor_inputs (Dict[str, torch.Tensor]): Inputs prepared by the processor.
            prompt_texts (List[str]): List of prompt texts corresponding to each input in the batch.
            tokenizer: Tokenizer used for decoding the generated outputs.
            len_objects (Optional[int], optional): Number of objects/frames to adjust max_new_tokens. Defaults to None.
        Returns:
            List[str]: Decoded generated texts after trimming and cleaning.
        """
        # In case of many frames, allow more max_new_tokens (TODO: recheck this logic)
        if len_objects is not None:
            max_new_tokens = len_objects * 128
        else:
            max_new_tokens = 128
        gen_conf = GenerationConfig(
            max_new_tokens=max_new_tokens,
            do_sample=False,
            num_return_sequences=1,
        )

        for k, v in processor_inputs.items():
            if isinstance(v, torch.Tensor):
                processor_inputs[k] = v.to(self.summary_model.device)

        with torch.inference_mode():
            try:
                if self.summary_model.device == "cuda":
                    with torch.amp.autocast("cuda", enabled=True):
                        generated_ids = self.summary_model.model.generate(
                            **processor_inputs, generation_config=gen_conf
                        )
                else:
                    generated_ids = self.summary_model.model.generate(
                        **processor_inputs, generation_config=gen_conf
                    )
            except RuntimeError as e:
                warnings.warn(
                    f"Generation failed with error: {e}. Retrying with cuDNN disabled.",
                    RuntimeWarning,
                )
                cudnn_was_enabled = (
                    torch.backends.cudnn.is_available() and torch.backends.cudnn.enabled
                )
                if cudnn_was_enabled:
                    torch.backends.cudnn.enabled = False
                try:
                    generated_ids = self.summary_model.model.generate(
                        **processor_inputs, generation_config=gen_conf
                    )
                except Exception as retry_error:
                    raise RuntimeError(
                        f"Failed to generate ids after retry: {retry_error}"
                    ) from retry_error
                finally:
                    if cudnn_was_enabled:
                        torch.backends.cudnn.enabled = True

        decoded = self._decode_trimmed_outputs(
            generated_ids, processor_inputs, tokenizer, prompt_texts
        )
        return decoded

    def _audio_to_text(self, audio_path: str) -> List[Dict[str, Any]]:
        """
        Convert an audio file to text using a whisperx model.
        Args:
            audio_path (str): Path to the audio file.
        Returns:
            List[Dict[str, Any]]: List of transcribed audio segments with start_time, end_time, text, and duration.
        """

        if not os.path.exists(audio_path):
            raise ValueError(f"Audio file {audio_path} does not exist.")

        try:
            torch.backends.cuda.matmul.allow_tf32 = True
            torch.backends.cudnn.allow_tf32 = True
            audio = whisperx.load_audio(audio_path)
            transcribe_result = self.audio_model.model.transcribe(audio)
            model_a, metadata = whisperx.load_align_model(
                language_code=transcribe_result["language"],
                device=self.audio_model.device,
            )
            aligned_result = whisperx.align(
                transcribe_result["segments"],
                model_a,
                metadata,
                audio,
                self.audio_model.device,
            )
            audio_descriptions = []
            for segment in aligned_result["segments"]:
                audio_descriptions.append(
                    {
                        "start_time": segment["start"],
                        "end_time": segment["end"],
                        "text": segment["text"].strip(),
                        "duration": segment["end"] - segment["start"],
                    }
                )
            return audio_descriptions
        except Exception as e:
            raise RuntimeError(f"Failed to transcribe audio: {e}")

    def _check_audio_stream(self, filename: str) -> bool:
        """
        Check if the video file has an audio stream.
        Args:
            filename (str): Path to the video file.
        Returns:
            bool: True if audio stream exists, False otherwise.
        """
        try:
            cmd = [
                "ffprobe",
                "-v",
                "error",
                "-select_streams",
                "a",
                "-show_entries",
                "stream=codec_type",
                "-of",
                "default=noprint_wrappers=1:nokey=1",
                filename,
            ]
            result = subprocess.run(
                cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True
            )
            output = result.stdout.strip()
            return bool(output)
        except Exception as e:
            warnings.warn(
                f"Failed to check audio stream in video {filename}: {e}",
                RuntimeWarning,
            )
            return False

    def _extract_transcribe_audio_part(
        self,
        filename: str,
    ) -> List[Dict[str, Any]]:
        """
        Extract audio part from the video file and generate captions using an audio whisperx model.
        Args:
            filename (str): Path to the video file.
        Returns:
            List[Dict[str, Any]]: List of transcribed audio segments with start_time, end_time, text, and duration.
        """

        if not self._check_audio_stream(filename):
            self.audio_model.close()
            self.audio_model = None
            return []

        with tempfile.TemporaryDirectory() as tmpdir:
            audio_output_path = os.path.join(tmpdir, "audio_extracted.wav")
            try:
                subprocess.run(
                    [
                        "ffmpeg",
                        "-i",
                        filename,
                        "-vn",
                        "-acodec",
                        "pcm_s16le",
                        "-ar",
                        "16000",
                        "-ac",
                        "1",
                        "-y",
                        audio_output_path,
                    ],
                    check=True,
                    stdout=subprocess.DEVNULL,
                    stderr=subprocess.DEVNULL,
                )
            except subprocess.CalledProcessError as e:
                raise RuntimeError(f"Failed to extract audio from video: {e}")

            audio_descriptions = self._audio_to_text(audio_output_path)

        # Close the audio model to free up resources
        self.audio_model.close()
        self.audio_model = None

        return audio_descriptions

    def _detect_scene_cuts(
        self,
        filename: str,
    ) -> Dict[str, Any]:
        """
        Detect scene cuts in the video using a frame-differencing method.
        Args:
            filename: Path to the video file.
        Returns:
            Dict with 'segments' (list of scene segments with 'start_time', 'end_time' and 'duration')
            and 'video_meta' (the original frame 'width' and 'height')
        """

        cap = cv2.VideoCapture(filename)
        fps = cap.get(cv2.CAP_PROP_FPS)

        frames = []
        img_height, img_width = None, None
        while True:
            ret, frame = cap.read()
            if not ret:
                break

            img_height, img_width = frame.shape[:2]

            try:
                if img_width / img_height > 1.2:
                    frame_small = cv2.resize(frame, (320, 240))
                elif img_width / img_height < 0.8:
                    frame_small = cv2.resize(frame, (240, 320))
                else:
                    frame_small = cv2.resize(frame, (320, 320))
            except Exception as e:
                raise RuntimeError(
                    f"Failed to resize frame for scene cut detection: {e}"
                )

            gray = cv2.cvtColor(
                frame_small, cv2.COLOR_BGR2GRAY
            )  # TODO: check whether grayscale is sufficient; colour information could also be used
            frames.append(gray)

        cap.release()
        if img_height is None or img_width is None:
            raise ValueError(
                "Failed to read frames from video for scene cut detection."
            )

        # Compute frame differences to keep memory usage low
        frame_diffs = []
        for i in range(1, len(frames)):
            diff = cv2.absdiff(frames[i], frames[i - 1])
            mean_diff = np.mean(diff)
            frame_diffs.append(mean_diff)

        # Find peaks in differences (scene cuts) via adaptive threshold based on median
        threshold = 25.0
        median_diff = np.median(frame_diffs)
        cut_threshold = median_diff + threshold

        cut_frames = signal.find_peaks(
            frame_diffs,
            height=cut_threshold,
            distance=int(fps * 0.5),  # At least 0.5s between cuts
        )[0]

        video_segments = []
        cut_frames_with_starts = [0] + list(cut_frames) + [len(frames)]

        for i in range(len(cut_frames_with_starts) - 1):
            start_frame = cut_frames_with_starts[i]
            end_frame = cut_frames_with_starts[i + 1]

            video_segments.append(
                {
                    "type": "video_scene",
                    "start_time": start_frame / fps,
                    "end_time": end_frame / fps,
                    "duration": (end_frame - start_frame) / fps,
                }
            )

        # Since there may be issues with last frame detection, slightly adjust the end_time of the last segment
        last_segment = video_segments[-1]
        last_segment["end_time"] -= 0.5
        # Ensure the end_time does not go below the start_time in case of very short last segment/video
        if last_segment["end_time"] < last_segment["start_time"]:
            last_segment["end_time"] = last_segment["start_time"]

        return {
            "segments": video_segments,
            "video_meta": {
                "width": img_width,
                "height": img_height,
            },
        }

    def _extract_frame_timestamps_from_clip(
        self,
        filename: str,
    ) -> Dict[str, Any]:
        """
        Extract frame timestamps for each detected video segment.
        Args:
            filename: Path to the video file.
        Returns:
            Dict with 'segments' (each segment extended with 'frame_timestamps') and 'video_meta'
        """
        base_frames_per_clip = 4.0
        result = self._detect_scene_cuts(filename)
        segments = result["segments"]
        video_meta = result["video_meta"]
        for seg in segments:
            if seg["duration"] < 2.0:
                frame_rate_per_clip = 2.0
            elif seg["duration"] > 20.0:
                frame_rate_per_clip = 6.0
            else:
                frame_rate_per_clip = base_frames_per_clip

            start_time = seg["start_time"]
            end_time = seg["end_time"]
            n_samples = max(1, int(frame_rate_per_clip))
            sample_times = torch.linspace(
                start_time, end_time, steps=n_samples, dtype=torch.float32
            )
            seg["frame_timestamps"] = sample_times.tolist()

        return {
            "segments": segments,
            "video_meta": video_meta,
        }

    def _reassign_video_timestamps_to_segments(
        self,
        segments: List[Dict[str, Any]],
        video_segs: List[Dict[str, Any]],
    ) -> None:
        """
        Reassign video frame timestamps to each new segment based on overlapping video scenes.
        Args:
            segments: List of segments to assign timestamps to.
            video_segs: List of video scenes with original frame timestamps.
        Returns:
            None
        """

        boundary_margin = 0.5
        eps = 1e-6

        video_list = list(video_segs)
        for seg in segments:
            seg_start = seg["start_time"]
            seg_end = seg["end_time"]

            merged_timestamps: List[float] = []
            for vscene in video_list:
                if "frame_timestamps" not in vscene:
                    raise ValueError("Video scene missing 'frame_timestamps' key.")

                contrib = [
                    float(t)
                    for t in vscene["frame_timestamps"]
                    if (t + eps) >= (seg_start - boundary_margin)
                    and (t - eps) <= (seg_end + boundary_margin)
                    and (t + eps) >= seg_start
                    and (t - eps) <= seg_end
                ]
                if contrib:
                    merged_timestamps.extend(contrib)

            # dedupe & sort
            seg["video_frame_timestamps"] = sorted(set(merged_timestamps))

    def _combine_visual_frames_by_time(
        self,
        video_segs: List[Dict[str, Any]],
    ) -> List[Dict[str, Any]]:
        """
        Split too-long video segments (>25s).
        Args:
            video_segs: List of video segments with 'start_time' and 'end_time'
        Returns:
            List of combined segments
        """

        if not video_segs:
            raise ValueError("No video segments to combine.")
        out = []
        for vs in video_segs:
            st, ed, dur = (
                float(vs["start_time"]),
                float(vs["end_time"]),
                float(vs["duration"]),
            )
            if dur > 25.0:
                parts = int(math.ceil(dur / 25.0))
                part_dur = dur / parts
                for p in range(parts):
                    ps = st + p * part_dur
                    pe = st + (p + 1) * part_dur if p < parts - 1 else ed
                    out.append(
                        {
                            "start_time": ps,
                            "end_time": pe,
                            "duration": pe - ps,
                            "audio_phrases": [],
                            "video_scenes": [vs],
                        }
                    )
            else:
                out.append(
                    {
                        "start_time": st,
                        "end_time": ed,
                        "duration": dur,
                        "audio_phrases": [],
                        "video_scenes": [vs],
                    }
                )

        self._reassign_video_timestamps_to_segments(out, video_segs)
        return out

    def merge_audio_visual_boundaries(
        self,
        audio_segs: List[Dict[str, Any]],
        video_segs: List[Dict[str, Any]],
        segment_threshold_duration: int = 8,
    ) -> List[Dict[str, Any]]:
        """
        Merge audio phrase boundaries and video scene cuts into coherent temporal segments for the model
        Args:
            audio_segs: List of audio segments with 'start_time' and 'end_time'
            video_segs: List of video segments with 'start_time' and 'end_time'
            segment_threshold_duration: Duration to create a new segment boundary
        Returns:
            List of merged segments
        """
        if not audio_segs:
            new_vid = self._combine_visual_frames_by_time(video_segs)
            return new_vid

        events = [
            ("audio", seg["start_time"], seg["end_time"], seg) for seg in audio_segs
        ] + [("video", seg["start_time"], seg["end_time"], seg) for seg in video_segs]

        if not events:
            raise ValueError("No audio and video segments to merge.")

        events.sort(key=lambda x: x[1])
        global_last_end = max(e[2] for e in events)
        # Create merged segments respecting both boundaries
        merged = []
        current_segment_start = 0
        current_audio_phrases = []
        current_video_scenes = []

        for event_type, start, _, data in events:
            current_duration = start - current_segment_start
            if current_duration > segment_threshold_duration:
                segment_end = start

                if segment_end < current_segment_start:
                    segment_end = current_segment_start

                merged.append(
                    {
                        "start_time": current_segment_start,
                        "end_time": segment_end,
                        "audio_phrases": current_audio_phrases,
                        "video_scenes": current_video_scenes,
                        "duration": segment_end - current_segment_start,
                    }
                )
                # start a new segment at the current event's start
                current_segment_start = segment_end
                current_audio_phrases = []
                current_video_scenes = []

            if event_type == "audio":
                current_audio_phrases.append(data)
            else:
                current_video_scenes.append(data)

        if current_audio_phrases or current_video_scenes:
            final_end = max(global_last_end, events[-1][2], current_segment_start)
            if final_end < current_segment_start:
                final_end = current_segment_start

            merged.append(
                {
                    "start_time": current_segment_start,
                    "end_time": final_end,
                    "audio_phrases": current_audio_phrases,
                    "video_scenes": current_video_scenes,
                    "duration": final_end - current_segment_start,
                }
            )

        self._reassign_video_timestamps_to_segments(merged, video_segs)
        return merged

    def _run_ffmpeg(
        self, cmd_args: List[str], timeout: Optional[float]
    ) -> subprocess.CompletedProcess:
        """
        Execute ffmpeg command and return the completed process.
        Args:
            cmd_args: List of ffmpeg command arguments.
            timeout: Timeout for the subprocess.
        Returns:
            CompletedProcess: Result of the subprocess execution.
        """

        cmd = ["ffmpeg", "-hide_banner", "-loglevel", "error"] + cmd_args
        return subprocess.run(
            cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, timeout=timeout
        )

    def _build_extract_command(
        self, filename: str, timestamp: float, accurate: bool, codec: str = "png"
    ) -> List[str]:
        """
        Build ffmpeg command for frame extraction.

        Args:
            filename: Path to video file
            timestamp: Time in seconds
            accurate: If True, seek after input (slow but accurate)
            codec: Output codec ('png' or 'mjpeg')
        Returns:
            List of ffmpeg command arguments
        """
        ss_arg = f"{timestamp:.6f}"
        cmd = []
        # Position -ss based on accuracy requirement
        if accurate:
            cmd = ["-i", filename, "-ss", ss_arg]  # accurate mode
        else:
            cmd = ["-ss", ss_arg, "-i", filename]  # fast mode

        # Common extraction parameters
        cmd += ["-frames:v", "1", "-f", "image2pipe"]

        # Codec-specific settings
        if codec == "png":
            cmd += ["-vcodec", "png", "-pix_fmt", "rgb24"]
        elif codec == "mjpeg":
            cmd += ["-vcodec", "mjpeg", "-pix_fmt", "yuvj420p"]

        cmd.append("pipe:1")
        return cmd

    def _run_ffmpeg_extraction(
        self,
        filename: str,
        timestamp: float,
        out_w: int,
        out_h: int,
        timeout: Optional[float] = 30.0,
    ) -> Image.Image:
        """
        Extract a single frame at the specified timestamp.

        Args:
            filename: Path to video file
            timestamp: Time in seconds
            out_w: Output width in pixels
            out_h: Output height in pixels
            timeout: Subprocess timeout in seconds

        Returns:
            PIL Image in RGB format

        Raises:
            RuntimeError: If frame extraction fails
        """

        strategies = [
            ("mjpeg", False),
            ("png", True),
        ]

        last_error = None

        for codec, use_accurate in strategies:
            try:
                cmd = self._build_extract_command(
                    filename, timestamp, use_accurate, codec
                )
                proc = self._run_ffmpeg(cmd, timeout)

                if proc.returncode == 0 and proc.stdout:
                    img = Image.open(BytesIO(proc.stdout)).convert("RGB")
                    img = img.resize((out_w, out_h), resample=Image.BILINEAR)
                    return img
                else:
                    last_error = proc.stderr.decode("utf-8", errors="replace")

            except Exception as e:
                last_error = str(e)
                warnings.warn(
                    f"Frame extraction failed at {timestamp:.3f}s with codec {codec} "
                    f"({'accurate' if use_accurate else 'fast'}): {last_error}",
                    RuntimeWarning,
                )

        raise RuntimeError(
            f"Failed to extract frame at {timestamp:.3f}s from {filename}. "
            f"Last error: {last_error[:500]}"
        )

    def _calculate_output_dimensions(
        self, original_w: int, original_h: int
    ) -> Tuple[int, int]:
        """
        Calculate output dimensions in a fully adaptive way, preserving aspect ratio, but decreasing size.
        It works both for landscape and portrait videos.
        Args:
            original_w: Original width
            original_h: Original height
        Returns:
            Tuple of (out_w, out_h)
        """
        aspect_ratio = original_w / original_h
        max_dimension = 720

        if aspect_ratio > 1.2:
            out_w = max_dimension
            out_h = int(max_dimension / aspect_ratio)
        elif aspect_ratio < 0.8:
            out_h = max_dimension
            out_w = int(max_dimension * aspect_ratio)
        else:
            out_w = max_dimension
            out_h = max_dimension
        return out_w, out_h

    def _extract_frames_ffmpeg(
        self,
        filename: str,
        timestamps: List[float],
        original_w: int,
        original_h: int,
        workers: int = 4,
    ) -> List[Tuple[float, Image.Image]]:
        """
        Extract multiple frames using a thread pool (parallel ffmpeg processes).
        Args:
            filename: Path to video file
            timestamps: List of times in seconds
            original_w: Original width of the video
            original_h: Original height of the video
            workers: Number of parallel threads
        Returns:
          List of (timestamp, PIL.Image) preserving order of timestamps.
        """
        results = {}
        out_w, out_h = self._calculate_output_dimensions(original_w, original_h)
        with ThreadPoolExecutor(max_workers=workers) as ex:
            futures = {
                ex.submit(self._run_ffmpeg_extraction, filename, t, out_w, out_h): i
                for i, t in enumerate(timestamps)
            }
            for fut in as_completed(futures):
                idx = futures[fut]
                try:
                    img = fut.result()
                    results[idx] = img
                except Exception as e:
                    raise RuntimeError(
                        f"Failed to extract frame for {timestamps[idx]}s: {e}"
                    ) from e

        return [(timestamps[i], results[i]) for i in range(len(timestamps))]

    def _make_captions_from_extracted_frames(
        self,
        filename: str,
        merged_segments: List[Dict[str, Any]],
        video_meta: Dict[str, Any],
        list_of_questions: Optional[List[str]] = None,
    ) -> None:
        """
        Generate captions for all extracted frames and then produce a concise summary of the video.
        Args:
            filename (str): Path to the video file.
            merged_segments (List[Dict[str, Any]]): List of merged segments with frame timestamps.
            video_meta (Dict[str, Any]): Video metadata with the original frame 'width' and 'height'.
            list_of_questions (Optional[List[str]]): List of questions for VQA.
        Returns:
            None. Modifies merged_segments in place to add 'summary_bullets' and 'vqa_bullets'.
        """
        proc = self.summary_model.processor

        img_width = video_meta.get("width")
        img_height = video_meta.get("height")
        if img_width is None or img_height is None:
            raise ValueError(
                "Frame dimensions not found in the last segment for extraction."
            )

        for seg in merged_segments:  # TODO: a generator might be faster; this may require changes to the ffmpeg extraction
            collected: List[Tuple[float, str]] = []
            frame_timestamps = seg.get("video_frame_timestamps", [])
            if not frame_timestamps:
                raise ValueError(
                    f"No frame timestamps found for segment {seg['start_time']:.2f}s to {seg['end_time']:.2f}s"
                )
            include_questions = bool(list_of_questions)
            caption_instruction = self.prompt_builder.build_frame_prompt(
                include_vqa=include_questions,
                questions=list_of_questions,
            )
            pairs = self._extract_frames_ffmpeg(
                filename,
                frame_timestamps,
                original_w=img_width,
                original_h=img_height,
                workers=min(8, (os.cpu_count() or 1) // 2),
            )
            prompt_texts = []

            for ts, img in pairs:
                messages = [
                    {
                        "role": "user",
                        "content": [
                            {"type": "image", "image": img},
                            {"type": "text", "text": caption_instruction},
                        ],
                    }
                ]

                prompt_text = proc.apply_chat_template(
                    messages, tokenize=False, add_generation_prompt=True
                )

                prompt_texts.append(prompt_text)

            processor_inputs = proc(
                text=prompt_texts,
                images=[img for _, img in pairs],
                return_tensors="pt",
                padding=True,
            )
            len_objects = len(pairs)
            if include_questions:
                len_objects *= 2  # because we expect two outputs per input when questions are included
            captions = self._generate_from_processor_inputs(
                processor_inputs,
                prompt_texts,
                self.summary_model.tokenizer,
                len_objects=len_objects,
            )
            for t, c in zip(frame_timestamps, captions):
                collected.append((float(t), c))

            collected.sort(key=lambda x: x[0])
            bullets_summary, bullets_vqa = _categorize_outputs(
                collected, include_questions
            )

            seg["summary_bullets"] = bullets_summary
            seg["vqa_bullets"] = bullets_vqa

    def make_captions_for_subclips(
        self,
        entry: Dict[str, Any],
        list_of_questions: Optional[List[str]] = None,
    ) -> List[Dict[str, Any]]:
        """
        Generate captions for video subclips using both audio and visual information, for a further full video summary/VQA.
        Args:
            entry (Dict[str, Any]): Dictionary containing the video file information.
            list_of_questions (Optional[List[str]]): List of questions for VQA.
        Returns:
            List[Dict[str, Any]]: List of dictionaries containing timestamps and generated captions.
        """

        filename = entry.get("filename")
        if not filename:
            raise ValueError("entry must contain key 'filename'")

        if not os.path.exists(filename):
            raise ValueError(f"Video file {filename} does not exist.")

        audio_generated_captions = []
        if self.audio_model is not None:
            audio_generated_captions = self._extract_transcribe_audio_part(filename)
            entry["audio_descriptions"] = audio_generated_captions

        video_result_segments = self._extract_frame_timestamps_from_clip(filename)
        video_segments_w_timestamps = video_result_segments["segments"]
        video_meta = video_result_segments["video_meta"]
        merged_segments = self.merge_audio_visual_boundaries(
            audio_generated_captions,
            video_segments_w_timestamps,
        )

        self._make_captions_from_extracted_frames(
            filename,
            merged_segments,
            video_meta,
            list_of_questions=list_of_questions,
        )
        results = []
        proc = self.summary_model.processor
        for seg in merged_segments:
            frame_timestamps = seg.get("video_frame_timestamps", [])

            collected: List[Tuple[float, str]] = []
            include_audio = False
            audio_lines = seg["audio_phrases"]
            if audio_lines:
                include_audio = True

            include_questions = bool(list_of_questions)
            caption_instruction = self.prompt_builder.build_clip_prompt(
                frame_bullets=seg.get("summary_bullets", []),
                include_audio=include_audio,
                audio_transcription=seg.get("audio_phrases", []),
                include_vqa=include_questions,
                questions=list_of_questions,
                vqa_bullets=seg.get("vqa_bullets", []),
            )
            messages = [
                {
                    "role": "user",
                    "content": [{"type": "text", "text": caption_instruction}],
                }
            ]
            prompt_text = proc.apply_chat_template(
                messages, tokenize=False, add_generation_prompt=True
            )
            processor_inputs = proc(
                text=[prompt_text],
                return_tensors="pt",
                padding=True,
            )
            final_outputs = self._generate_from_processor_inputs(
                processor_inputs,
                [prompt_text],
                self.summary_model.tokenizer,
            )
            for t, c in zip(frame_timestamps, final_outputs):
                collected.append((float(t), c))

            collected.sort(key=lambda x: x[0])
            bullets_summary, bullets_vqa = _categorize_outputs(
                collected, include_questions
            )

            results.append(
                {
                    "start_time": seg["start_time"],
                    "end_time": seg["end_time"],
                    "summary_bullets": bullets_summary,
                    "vqa_bullets": bullets_vqa,
                }
            )

        return results

    def final_summary(self, summary_dict: List[Dict[str, Any]]) -> Dict[str, Any]:
        """
        Produce a concise summary of the video, based on generated captions for all extracted frames.
        Args:
            summary_dict (List[Dict[str, Any]]): List of per-clip dictionaries containing the captions for the frames.
        Returns:
            Dict[str, Any]: A dictionary containing the final summary.
        """
        proc = self.summary_model.processor

        bullets = []
        for seg in summary_dict:
            seg_bullets = seg.get("summary_bullets", [])
            bullets.extend(seg_bullets)
        if not bullets:
            raise ValueError("No captions available for summary generation.")

        summary_user_prompt = self.prompt_builder.build_video_prompt(
            include_vqa=False,
            clip_summaries=bullets,
        )
        messages = [
            {
                "role": "user",
                "content": [{"type": "text", "text": summary_user_prompt}],
            }
        ]

        summary_prompt_text = proc.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )

        summary_inputs = proc(
            text=[summary_prompt_text], return_tensors="pt", padding=True
        )

        summary_inputs = {
            k: v.to(self.summary_model.device) if isinstance(v, torch.Tensor) else v
            for k, v in summary_inputs.items()
        }
        final_summary_list = self._generate_from_processor_inputs(
            summary_inputs,
            [summary_prompt_text],
            self.summary_model.tokenizer,
        )
        final_summary = final_summary_list[0].strip() if final_summary_list else ""

        return {
            "summary": final_summary,
        }

    def final_answers(
        self,
        answers_dict: List[Dict[str, Any]],
        list_of_questions: List[str],
    ) -> Dict[str, Any]:
        """
        Answer the list of questions for the video based on the VQA bullets from the frames.
        Args:
            answers_dict (List[Dict[str, Any]]): List of per-clip dictionaries containing the VQA bullets.
            list_of_questions (List[str]): List of questions to answer about the video.
        Returns:
            Dict[str, Any]: A dictionary containing the list of answers to the questions.
        """
        vqa_bullets = []
        summary_bullets = []
        for seg in answers_dict:
            summary_bullets.extend(seg.get("summary_bullets", []))
            seg_bullets = seg.get("vqa_bullets", [])
            vqa_bullets.extend(seg_bullets)

        if not vqa_bullets:
            raise ValueError(
                "No VQA bullets generated for single frames available for answering questions."
            )

        include_questions = bool(list_of_questions)
        if include_questions:
            prompt = self.prompt_builder.build_video_prompt(
                include_vqa=include_questions,
                questions=list_of_questions,
                vqa_bullets=vqa_bullets,
                clip_summaries=summary_bullets,
            )
        else:
            raise ValueError(
                "list_of_questions must be provided for making final answers."
            )

        proc = self.summary_model.processor
        messages = [
            {
                "role": "user",
                "content": [{"type": "text", "text": prompt}],
            }
        ]
        final_vqa_prompt_text = proc.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
        final_vqa_inputs = proc(
            text=[final_vqa_prompt_text], return_tensors="pt", padding=True
        )
        final_vqa_inputs = {
            k: v.to(self.summary_model.device) if isinstance(v, torch.Tensor) else v
            for k, v in final_vqa_inputs.items()
        }

        final_vqa_list = self._generate_from_processor_inputs(
            final_vqa_inputs,
            [final_vqa_prompt_text],
            self.summary_model.tokenizer,
        )

        final_vqa_output = final_vqa_list[0].strip() if final_vqa_list else ""
        vqa_answers = []
        answer_matches = re.findall(
            r"\d+\.\s+(.+?)(?=\n\d+\.|$)", final_vqa_output, flags=re.DOTALL
        )
        for answer in answer_matches:
            vqa_answers.append(answer.strip())
        return {
            "vqa_answers": vqa_answers,
        }

    def analyse_videos_from_dict(
        self,
        analysis_type: Union[AnalysisType, str] = AnalysisType.SUMMARY,
        list_of_questions: Optional[List[str]] = None,
    ) -> Dict[str, Any]:
        """
        Analyse the video specified in self.subdict using frame extraction and captioning.
        Args:
            analysis_type (Union[AnalysisType, str], optional): Type of analysis to perform. Defaults to AnalysisType.SUMMARY.
            list_of_questions (List[str], optional): List of questions to answer about the video. Required if analysis_type includes questions.
        Returns:
            Dict[str, Any]: A dictionary containing the analysis results, including the summary and the answers to the provided questions (if any).
        """
        if list_of_questions is not None and not isinstance(list_of_questions, list):
            raise TypeError("Expected list_of_questions to be a list of strings.")
        if list_of_questions and any(not isinstance(q, str) for q in list_of_questions):
            raise ValueError("All items in list_of_questions must be strings.")

        analysis_type, is_summary, is_questions = AnalysisType._validate_analysis_type(
            analysis_type, list_of_questions
        )

        for video_key, entry in self.subdict.items():
            answers_dict = self.make_captions_for_subclips(
                entry,
                list_of_questions=list_of_questions,
            )
            if is_summary:
                answer = self.final_summary(answers_dict)
                entry["summary"] = answer["summary"]
            if is_questions:
                answer = self.final_answers(answers_dict, list_of_questions)
                entry["vqa_answers"] = answer["vqa_answers"]

            self.subdict[video_key] = entry

        return self.subdict
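
The scene-cut detection in _detect_scene_cuts boils down to thresholding frame-to-frame differences: peaks in the mean absolute difference that exceed the median by 25 and are at least half a second apart are treated as cuts. A minimal numpy/scipy sketch with synthetic difference values (not part of ammico) illustrates the idea:

# Sketch of adaptive scene-cut thresholding on synthetic frame differences.
import numpy as np
from scipy import signal

fps = 25.0
rng = np.random.default_rng(0)
frame_diffs = rng.uniform(2.0, 6.0, size=300)   # synthetic "quiet" frame-to-frame differences
frame_diffs[[80, 190]] = 60.0                   # two artificial scene cuts

# adaptive threshold: median difference plus a fixed offset of 25
cut_threshold = np.median(frame_diffs) + 25.0
cut_frames, _ = signal.find_peaks(
    frame_diffs, height=cut_threshold, distance=int(fps * 0.5)  # at least 0.5 s apart
)
print(cut_frames)                     # -> [ 80 190]
print([f / fps for f in cut_frames])  # cut timestamps in seconds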

__init__(summary_model=None, audio_model=None, subdict=None)

Class for analysing videos using the QWEN-2.5-VL model. It provides methods for generating captions and answering questions about videos.

Parameters:

Name Type Description Default
summary_model MultimodalSummaryModel

An instance of MultimodalSummaryModel to be used for analysis.

None
audio_model AudioToTextModel

An instance of AudioToTextModel used to transcribe the audio track, if present.

None
subdict dict

Dictionary containing the video to be analysed. Defaults to {}.

None

Returns:

Type Description
None

None.

Source code in ammico/video_summary.py
def __init__(
    self,
    summary_model: MultimodalSummaryModel = None,
    audio_model: Optional[AudioToTextModel] = None,
    subdict: Optional[Dict[str, Any]] = None,
) -> None:
    """
    Class for analysing videos using the QWEN-2.5-VL model.
    It provides methods for generating captions and answering questions about videos.

    Args:
        summary_model (MultimodalSummaryModel, optional): An instance of MultimodalSummaryModel to be used for analysis.
        audio_model (AudioToTextModel, optional): An instance of AudioToTextModel used to transcribe the audio track, if present.
        subdict (dict, optional): Dictionary containing the video to be analysed. Defaults to {}.

    Returns:
        None.
    """
    if subdict is None:
        subdict = {}

    super().__init__(subdict)
    _validate_subdict(subdict)
    self.summary_model = summary_model or None
    self.audio_model = audio_model
    self.prompt_builder = PromptBuilder()
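
A hypothetical construction sketch is shown below; the import path of MultimodalSummaryModel, its default construction, and the example file path are assumptions, not part of the documented API.

# Hypothetical usage sketch for constructing a VideoSummaryDetector.
from ammico.video_summary import VideoSummaryDetector
from ammico.model import MultimodalSummaryModel  # import path assumed

# hypothetical subdict entry; each entry must contain a 'filename' key
subdict = {"my_video": {"filename": "videos/my_video.mp4"}}

summary_model = MultimodalSummaryModel()  # assumption: defaults load the QWEN-2.5-VL model
detector = VideoSummaryDetector(summary_model=summary_model, subdict=subdict)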

analyse_videos_from_dict(analysis_type=AnalysisType.SUMMARY, list_of_questions=None)

Analyse the video specified in self.subdict using frame extraction and captioning.

Parameters:

Name Type Description Default
analysis_type Union[AnalysisType, str]

Type of analysis to perform. Defaults to AnalysisType.SUMMARY.

AnalysisType.SUMMARY
list_of_questions Optional[List[str]]

List of questions to answer about the video. Required if analysis_type includes questions.

None

Returns:

Type Description
Dict[str, Any]

A dictionary containing the analysis results, including the summary and the answers to the provided questions (if any).

Source code in ammico/video_summary.py
def analyse_videos_from_dict(
    self,
    analysis_type: Union[AnalysisType, str] = AnalysisType.SUMMARY,
    list_of_questions: Optional[List[str]] = None,
) -> Dict[str, Any]:
    """
    Analyse the video specified in self.subdict using frame extraction and captioning.
    Args:
        analysis_type (Union[AnalysisType, str], optional): Type of analysis to perform. Defaults to AnalysisType.SUMMARY.
        list_of_questions (List[str], optional): List of questions to answer about the video. Required if analysis_type includes questions.
    Returns:
        Dict[str, Any]: A dictionary containing the analysis results, including the summary and the answers to the provided questions (if any).
    """
    if list_of_questions is not None and not isinstance(list_of_questions, list):
        raise TypeError("Expected list_of_questions to be a list of strings.")
    if list_of_questions and any(not isinstance(q, str) for q in list_of_questions):
        raise ValueError("All items in list_of_questions must be strings.")

    analysis_type, is_summary, is_questions = AnalysisType._validate_analysis_type(
        analysis_type, list_of_questions
    )

    for video_key, entry in self.subdict.items():
        answers_dict = self.make_captions_for_subclips(
            entry,
            list_of_questions=list_of_questions,
        )
        if is_summary:
            answer = self.final_summary(answers_dict)
            entry["summary"] = answer["summary"]
        if is_questions:
            answer = self.final_answers(answers_dict, list_of_questions)
            entry["vqa_answers"] = answer["vqa_answers"]

        self.subdict[video_key] = entry

    return self.subdict
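
A hypothetical call sketch follows, reusing the detector built in the construction example above. With the default AnalysisType.SUMMARY, only the 'summary' key is filled in for each entry; answering questions additionally requires list_of_questions and an analysis type that includes questions.

# Hypothetical call sketch; `detector` and its subdict are assumed from the example above.
results = detector.analyse_videos_from_dict()
for video_key, entry in results.items():
    print(video_key, entry.get("summary"))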

final_answers(answers_dict, list_of_questions)

Answer the list of questions for the video based on the VQA bullets from the frames.

Parameters:

Name Type Description Default
answers_dict List[Dict[str, Any]]

List of per-clip dictionaries containing the VQA bullets.

required
list_of_questions List[str]

List of questions to answer about the video.

required

Returns:

Type Description
Dict[str, Any]

A dictionary containing the list of answers to the questions.

Source code in ammico/video_summary.py
def final_answers(
    self,
    answers_dict: List[Dict[str, Any]],
    list_of_questions: List[str],
) -> Dict[str, Any]:
    """
    Answer the list of questions for the video based on the VQA bullets from the frames.
    Args:
        answers_dict (List[Dict[str, Any]]): List of per-clip dictionaries containing the VQA bullets.
        list_of_questions (List[str]): List of questions to answer about the video.
    Returns:
        Dict[str, Any]: A dictionary containing the list of answers to the questions.
    """
    vqa_bullets = []
    summary_bullets = []
    for seg in answers_dict:
        summary_bullets.extend(seg.get("summary_bullets", []))
        seg_bullets = seg.get("vqa_bullets", [])
        vqa_bullets.extend(seg_bullets)

    if not vqa_bullets:
        raise ValueError(
            "No VQA bullets generated for single frames available for answering questions."
        )

    include_questions = bool(list_of_questions)
    if include_questions:
        prompt = self.prompt_builder.build_video_prompt(
            include_vqa=include_questions,
            questions=list_of_questions,
            vqa_bullets=vqa_bullets,
            clip_summaries=summary_bullets,
        )
    else:
        raise ValueError(
            "list_of_questions must be provided for making final answers."
        )

    proc = self.summary_model.processor
    messages = [
        {
            "role": "user",
            "content": [{"type": "text", "text": prompt}],
        }
    ]
    final_vqa_prompt_text = proc.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    final_vqa_inputs = proc(
        text=[final_vqa_prompt_text], return_tensors="pt", padding=True
    )
    final_vqa_inputs = {
        k: v.to(self.summary_model.device) if isinstance(v, torch.Tensor) else v
        for k, v in final_vqa_inputs.items()
    }

    final_vqa_list = self._generate_from_processor_inputs(
        final_vqa_inputs,
        [final_vqa_prompt_text],
        self.summary_model.tokenizer,
    )

    final_vqa_output = final_vqa_list[0].strip() if final_vqa_list else ""
    vqa_answers = []
    answer_matches = re.findall(
        r"\d+\.\s+(.+?)(?=\n\d+\.|$)", final_vqa_output, flags=re.DOTALL
    )
    for answer in answer_matches:
        vqa_answers.append(answer.strip())
    return {
        "vqa_answers": vqa_answers,
    }
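
The final answers are recovered from the model output by splitting on the numbered-list pattern. A minimal sketch with a made-up model output shows the parsing in isolation (not ammico code):

# Sketch of the numbered-answer parsing, using a hypothetical model output string.
import re

final_vqa_output = "1. Two people are visible.\n2. The scene is outdoors,\nnear a river."
answers = [
    m.strip()
    for m in re.findall(r"\d+\.\s+(.+?)(?=\n\d+\.|$)", final_vqa_output, flags=re.DOTALL)
]
print(answers)  # ['Two people are visible.', 'The scene is outdoors,\nnear a river.']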

final_summary(summary_dict)

Produce a concise summary of the video, based on the captions generated for all extracted frames.

Parameters:

Name Type Description Default
summary_dict List[Dict[str, Any]]

List of per-clip dictionaries containing the captions for the frames.

required

Returns:

Type Description
Dict[str, Any]

A dictionary containing the final summary.

Source code in ammico/video_summary.py
def final_summary(self, summary_dict: List[Dict[str, Any]]) -> Dict[str, Any]:
    """
    Produce a concise summary of the video, based on generated captions for all extracted frames.
    Args:
        summary_dict (List[Dict[str, Any]]): List of per-clip dictionaries containing the captions for the frames.
    Returns:
        Dict[str, Any]: A dictionary containing the final summary.
    """
    proc = self.summary_model.processor

    bullets = []
    for seg in summary_dict:
        seg_bullets = seg.get("summary_bullets", [])
        bullets.extend(seg_bullets)
    if not bullets:
        raise ValueError("No captions available for summary generation.")

    summary_user_prompt = self.prompt_builder.build_video_prompt(
        include_vqa=False,
        clip_summaries=bullets,
    )
    messages = [
        {
            "role": "user",
            "content": [{"type": "text", "text": summary_user_prompt}],
        }
    ]

    summary_prompt_text = proc.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )

    summary_inputs = proc(
        text=[summary_prompt_text], return_tensors="pt", padding=True
    )

    summary_inputs = {
        k: v.to(self.summary_model.device) if isinstance(v, torch.Tensor) else v
        for k, v in summary_inputs.items()
    }
    final_summary_list = self._generate_from_processor_inputs(
        summary_inputs,
        [summary_prompt_text],
        self.summary_model.tokenizer,
    )
    final_summary = final_summary_list[0].strip() if final_summary_list else ""

    return {
        "summary": final_summary,
    }
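
A hypothetical sketch of the two-stage flow: per-clip captions are generated first, and final_summary then condenses their summary bullets into a single video summary. The detector instance and the file path are assumptions from the earlier construction example.

# Hypothetical two-stage flow: per-clip captions, then the final summary.
clip_results = detector.make_captions_for_subclips({"filename": "videos/my_video.mp4"})
summary = detector.final_summary(clip_results)["summary"]
print(summary)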

make_captions_for_subclips(entry, list_of_questions=None)

Generate captions for video subclips using both audio and visual information, for a further full video summary/VQA.

Parameters:

Name Type Description Default
entry Dict[str, Any]

Dictionary containing the video file information.

required
list_of_questions Optional[List[str]]

List of questions for VQA.

None

Returns:

Type Description
List[Dict[str, Any]]

List of dictionaries containing timestamps and generated captions.

Source code in ammico/video_summary.py
def make_captions_for_subclips(
    self,
    entry: Dict[str, Any],
    list_of_questions: Optional[List[str]] = None,
) -> List[Dict[str, Any]]:
    """
    Generate captions for video subclips using both audio and visual information, for a further full video summary/VQA.
    Args:
        entry (Dict[str, Any]): Dictionary containing the video file information.
        list_of_questions (Optional[List[str]]): List of questions for VQA.
    Returns:
        List[Dict[str, Any]]: List of dictionaries containing timestamps and generated captions.
    """

    filename = entry.get("filename")
    if not filename:
        raise ValueError("entry must contain key 'filename'")

    if not os.path.exists(filename):
        raise ValueError(f"Video file {filename} does not exist.")

    audio_generated_captions = []
    if self.audio_model is not None:
        audio_generated_captions = self._extract_transcribe_audio_part(filename)
        entry["audio_descriptions"] = audio_generated_captions

    video_result_segments = self._extract_frame_timestamps_from_clip(filename)
    video_segments_w_timestamps = video_result_segments["segments"]
    video_meta = video_result_segments["video_meta"]
    merged_segments = self.merge_audio_visual_boundaries(
        audio_generated_captions,
        video_segments_w_timestamps,
    )

    self._make_captions_from_extracted_frames(
        filename,
        merged_segments,
        video_meta,
        list_of_questions=list_of_questions,
    )
    results = []
    proc = self.summary_model.processor
    for seg in merged_segments:
        frame_timestamps = seg.get("video_frame_timestamps", [])

        collected: List[Tuple[float, str]] = []
        include_audio = False
        audio_lines = seg["audio_phrases"]
        if audio_lines:
            include_audio = True

        include_questions = bool(list_of_questions)
        caption_instruction = self.prompt_builder.build_clip_prompt(
            frame_bullets=seg.get("summary_bullets", []),
            include_audio=include_audio,
            audio_transcription=seg.get("audio_phrases", []),
            include_vqa=include_questions,
            questions=list_of_questions,
            vqa_bullets=seg.get("vqa_bullets", []),
        )
        messages = [
            {
                "role": "user",
                "content": [{"type": "text", "text": caption_instruction}],
            }
        ]
        prompt_text = proc.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
        processor_inputs = proc(
            text=[prompt_text],
            return_tensors="pt",
            padding=True,
        )
        final_outputs = self._generate_from_processor_inputs(
            processor_inputs,
            [prompt_text],
            self.summary_model.tokenizer,
        )
        for t, c in zip(frame_timestamps, final_outputs):
            collected.append((float(t), c))

        collected.sort(key=lambda x: x[0])
        bullets_summary, bullets_vqa = _categorize_outputs(
            collected, include_questions
        )

        results.append(
            {
                "start_time": seg["start_time"],
                "end_time": seg["end_time"],
                "summary_bullets": bullets_summary,
                "vqa_bullets": bullets_vqa,
            }
        )

    return results

merge_audio_visual_boundaries(audio_segs, video_segs, segment_threshold_duration=8)

Merge audio phrase boundaries and video scene cuts into coherent temporal segments for the model.

Parameters:

Name Type Description Default
audio_segs List[Dict[str, Any]]

List of audio segments with 'start_time' and 'end_time'.

required
video_segs List[Dict[str, Any]]

List of video segments with 'start_time' and 'end_time'.

required
segment_threshold_duration int

Accumulated duration after which a new segment boundary is created.

8

Returns: List of merged segments, each with 'start_time', 'end_time', 'audio_phrases', 'video_scenes' and 'duration'.

Source code in ammico/video_summary.py, lines 537-614
def merge_audio_visual_boundaries(
    self,
    audio_segs: List[Dict[str, Any]],
    video_segs: List[Dict[str, Any]],
    segment_threshold_duration: int = 8,
) -> List[Dict[str, Any]]:
    """
    Merge audio phrase boundaries and video scene cuts into coherent temporal segments for the model
    Args:
        audio_segs: List of audio segments with 'start_time' and 'end_time'
        video_segs: List of video segments with 'start_time' and 'end_time'
        segment_threshold_duration: Duration to create a new segment boundary
    Returns:
        List of merged segments
    """
    if not audio_segs:
        new_vid = self._combine_visual_frames_by_time(video_segs)
        return new_vid

    events = [
        ("audio", seg["start_time"], seg["end_time"], seg) for seg in audio_segs
    ] + [("video", seg["start_time"], seg["end_time"], seg) for seg in video_segs]

    if not events:
        raise ValueError("No audio and video segments to merge.")

    events.sort(key=lambda x: x[1])
    global_last_end = max(e[2] for e in events)
    # Create merged segments respecting both boundaries
    merged = []
    current_segment_start = 0
    current_audio_phrases = []
    current_video_scenes = []

    for event_type, start, _, data in events:
        current_duration = start - current_segment_start
        if current_duration > segment_threshold_duration:
            segment_end = start

            if segment_end < current_segment_start:
                segment_end = current_segment_start

            merged.append(
                {
                    "start_time": current_segment_start,
                    "end_time": segment_end,
                    "audio_phrases": current_audio_phrases,
                    "video_scenes": current_video_scenes,
                    "duration": segment_end - current_segment_start,
                }
            )
            # start a new segment at the current event's start
            current_segment_start = segment_end
            current_audio_phrases = []
            current_video_scenes = []

        if event_type == "audio":
            current_audio_phrases.append(data)
        else:
            current_video_scenes.append(data)

    if current_audio_phrases or current_video_scenes:
        final_end = max(global_last_end, events[-1][2], current_segment_start)
        if final_end < current_segment_start:
            final_end = current_segment_start

        merged.append(
            {
                "start_time": current_segment_start,
                "end_time": final_end,
                "audio_phrases": current_audio_phrases,
                "video_scenes": current_video_scenes,
                "duration": final_end - current_segment_start,
            }
        )

    self._reassign_video_timestamps_to_segments(merged, video_segs)
    return merged
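
A minimal usage sketch of this merging step. It assumes `video_detector` is an already constructed detector instance from ammico.video_summary (its construction and model loading are not shown in this section); only 'start_time' and 'end_time' are required in the input dictionaries:

# Hedged sketch: `video_detector` stands for an already constructed detector
# from ammico.video_summary; only the merge call itself is shown here.
audio_segs = [
    {"start_time": 0.0, "end_time": 4.5},
    {"start_time": 4.5, "end_time": 12.0},
]
video_segs = [
    {"start_time": 0.0, "end_time": 6.0},
    {"start_time": 6.0, "end_time": 15.0},
]
merged = video_detector.merge_audio_visual_boundaries(
    audio_segs, video_segs, segment_threshold_duration=8
)
# Each merged entry is a dict with "start_time", "end_time", "audio_phrases",
# "video_scenes" and "duration", as assembled in the source above.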

Colors

ColorDetector

Bases: AnalysisMethod

Source code in ammico/colors.py, lines 23-145
class ColorDetector(AnalysisMethod):
    def __init__(
        self,
        subdict: dict,
        delta_e_method: str = "CIE 1976",
    ) -> None:
        """Color Analysis class, analyse hue and identify named colors.

        Args:
            subdict (dict): The dictionary containing the image path.
            delta_e_method (str): The calculation method used for assigning the
                closest color name, defaults to "CIE 1976".
                The available options are: 'CIE 1976', 'CIE 1994', 'CIE 2000',
                'CMC', 'ITP', 'CAM02-LCD', 'CAM02-SCD', 'CAM02-UCS', 'CAM16-LCD',
                'CAM16-SCD', 'CAM16-UCS', 'DIN99'
        """
        super().__init__(subdict)
        self.subdict.update(self.set_keys())
        self.merge_color = True
        self.n_colors = 100
        if delta_e_method not in COLOR_SCHEMES:
            raise ValueError(
                "Invalid selection for assigning the color name. Please select one of {}".format(
                    COLOR_SCHEMES
                )
            )
        self.delta_e_method = delta_e_method

    def set_keys(self) -> dict:
        colors = {
            "red": 0,
            "green": 0,
            "blue": 0,
            "yellow": 0,
            "cyan": 0,
            "orange": 0,
            "purple": 0,
            "pink": 0,
            "brown": 0,
            "grey": 0,
            "white": 0,
            "black": 0,
        }
        return colors

    def analyse_image(self) -> dict:
        """
        Uses the colorgram library to extract the n most common colors from the images.
        One problem is, that the most common colors are taken before beeing categorized,
        so for small values it might occur that the ten most common colors are shades of grey,
        while other colors are present but will be ignored. Because of this n_colors=100 was chosen as default.

        The colors are then matched to the closest color in the CSS3 color list using the delta-e metric.
        They are then merged into one data frame.
        The colors can be reduced to a smaller list of colors using the get_color_table function.
        These colors are: "red", "green", "blue", "yellow","cyan", "orange", "purple", "pink", "brown", "grey", "white", "black".

        Returns:
            dict: Dictionary with color names as keys and percentage of color in image as values.
        """
        filename = self.subdict["filename"]

        colors = colorgram.extract(filename, self.n_colors)
        for color in colors:
            rgb_name = self.rgb2name(
                color.rgb,
                merge_color=self.merge_color,
                delta_e_method=self.delta_e_method,
            )
            self.subdict[rgb_name] += color.proportion

        # ensure color rounding
        for key in self.set_keys().keys():
            if self.subdict[key]:
                self.subdict[key] = round(self.subdict[key], 2)

        return self.subdict

    def rgb2name(
        self, c, merge_color: bool = True, delta_e_method: str = "CIE 1976"
    ) -> str:
        """Take an rgb color as input and return the closest color name from the CSS3 color list.

        Args:
            c (Union[List,tuple]): RGB value.
            merge_color (bool, Optional): Whether color name should be reduced, defaults to True.
        Returns:
            str: Closest matching color name.
        """
        if len(c) != 3:
            raise ValueError("Input color must be a list or tuple of length 3 (RGB).")

        h_color = "#{:02x}{:02x}{:02x}".format(int(c[0]), int(c[1]), int(c[2]))
        try:
            output_color = webcolors.hex_to_name(h_color, spec="css3")
            output_color = output_color.lower().replace("grey", "gray")
        except ValueError:
            delta_e_lst = []
            filtered_colors = webcolors._definitions._CSS3_NAMES_TO_HEX

            for _, img_hex in filtered_colors.items():
                cur_clr = webcolors.hex_to_rgb(img_hex)
                # calculate color Delta-E
                delta_e = colour.delta_E(c, cur_clr, method=delta_e_method)
                delta_e_lst.append(delta_e)
            # find lowest delta-e
            min_diff = np.argsort(delta_e_lst)[0]
            output_color = (
                str(list(filtered_colors.items())[min_diff][0])
                .lower()
                .replace("grey", "gray")
            )

        # match color to reduced list:
        if merge_color:
            for reduced_key, reduced_color_sub_list in get_color_table().items():
                if str(output_color).lower() in [
                    str(color_name).lower()
                    for color_name in reduced_color_sub_list["ColorName"]
                ]:
                    output_color = reduced_key.lower()
                    break
        return output_color
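
A short usage sketch, assuming an image file at the hypothetical path "example.jpg"; the subdict only needs a "filename" entry, and the detector adds one key per named color:

from ammico import colors

subdict = {"filename": "example.jpg"}  # hypothetical image path
color_detector = colors.ColorDetector(subdict, delta_e_method="CIE 1976")
result = color_detector.analyse_image()
# result maps the reduced color names ("red", "green", ..., "black") to the
# rounded share of the image covered by each color.
print(result["red"], result["blue"])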

__init__(subdict, delta_e_method='CIE 1976')

Color Analysis class, analyse hue and identify named colors.

Parameters:

Name Type Description Default
subdict dict

The dictionary containing the image path.

required
delta_e_method str

The calculation method used for assigning the closest color name, defaults to "CIE 1976". The available options are: 'CIE 1976', 'CIE 1994', 'CIE 2000', 'CMC', 'ITP', 'CAM02-LCD', 'CAM02-SCD', 'CAM02-UCS', 'CAM16-LCD', 'CAM16-SCD', 'CAM16-UCS', 'DIN99'

'CIE 1976'
Source code in ammico/colors.py, lines 24-49
def __init__(
    self,
    subdict: dict,
    delta_e_method: str = "CIE 1976",
) -> None:
    """Color Analysis class, analyse hue and identify named colors.

    Args:
        subdict (dict): The dictionary containing the image path.
        delta_e_method (str): The calculation method used for assigning the
            closest color name, defaults to "CIE 1976".
            The available options are: 'CIE 1976', 'CIE 1994', 'CIE 2000',
            'CMC', 'ITP', 'CAM02-LCD', 'CAM02-SCD', 'CAM02-UCS', 'CAM16-LCD',
            'CAM16-SCD', 'CAM16-UCS', 'DIN99'
    """
    super().__init__(subdict)
    self.subdict.update(self.set_keys())
    self.merge_color = True
    self.n_colors = 100
    if delta_e_method not in COLOR_SCHEMES:
        raise ValueError(
            "Invalid selection for assigning the color name. Please select one of {}".format(
                COLOR_SCHEMES
            )
        )
    self.delta_e_method = delta_e_method

analyse_image()

Uses the colorgram library to extract the n most common colors from the images. One problem is that the most common colors are extracted before being categorized, so for small values it can happen that the ten most common colors are all shades of grey while other colors are present but ignored. Because of this, n_colors=100 was chosen as the default.

The colors are then matched to the closest color in the CSS3 color list using the delta-E metric. They are then merged into one data frame. The colors can be reduced to a smaller list of colors using the get_color_table function. These colors are: "red", "green", "blue", "yellow", "cyan", "orange", "purple", "pink", "brown", "grey", "white", "black".

Returns:

Name Type Description
dict dict

Dictionary with color names as keys and percentage of color in image as values.

Source code in ammico/colors.py, lines 68-99
def analyse_image(self) -> dict:
    """
    Uses the colorgram library to extract the n most common colors from the images.
    One problem is, that the most common colors are taken before beeing categorized,
    so for small values it might occur that the ten most common colors are shades of grey,
    while other colors are present but will be ignored. Because of this n_colors=100 was chosen as default.

    The colors are then matched to the closest color in the CSS3 color list using the delta-e metric.
    They are then merged into one data frame.
    The colors can be reduced to a smaller list of colors using the get_color_table function.
    These colors are: "red", "green", "blue", "yellow","cyan", "orange", "purple", "pink", "brown", "grey", "white", "black".

    Returns:
        dict: Dictionary with color names as keys and percentage of color in image as values.
    """
    filename = self.subdict["filename"]

    colors = colorgram.extract(filename, self.n_colors)
    for color in colors:
        rgb_name = self.rgb2name(
            color.rgb,
            merge_color=self.merge_color,
            delta_e_method=self.delta_e_method,
        )
        self.subdict[rgb_name] += color.proportion

    # ensure color rounding
    for key in self.set_keys().keys():
        if self.subdict[key]:
            self.subdict[key] = round(self.subdict[key], 2)

    return self.subdict

rgb2name(c, merge_color=True, delta_e_method='CIE 1976')

Take an rgb color as input and return the closest color name from the CSS3 color list.

Parameters:

Name Type Description Default
c Union[List, tuple]

RGB value.

required
merge_color bool

Whether the color name should be reduced to the smaller color list, defaults to True.

True
delta_e_method str

The delta-E calculation method used for assigning the closest color name.

'CIE 1976'

Returns: str: Closest matching color name.

Source code in ammico/colors.py, lines 101-145
def rgb2name(
    self, c, merge_color: bool = True, delta_e_method: str = "CIE 1976"
) -> str:
    """Take an rgb color as input and return the closest color name from the CSS3 color list.

    Args:
        c (Union[List,tuple]): RGB value.
        merge_color (bool, Optional): Whether color name should be reduced, defaults to True.
    Returns:
        str: Closest matching color name.
    """
    if len(c) != 3:
        raise ValueError("Input color must be a list or tuple of length 3 (RGB).")

    h_color = "#{:02x}{:02x}{:02x}".format(int(c[0]), int(c[1]), int(c[2]))
    try:
        output_color = webcolors.hex_to_name(h_color, spec="css3")
        output_color = output_color.lower().replace("grey", "gray")
    except ValueError:
        delta_e_lst = []
        filtered_colors = webcolors._definitions._CSS3_NAMES_TO_HEX

        for _, img_hex in filtered_colors.items():
            cur_clr = webcolors.hex_to_rgb(img_hex)
            # calculate color Delta-E
            delta_e = colour.delta_E(c, cur_clr, method=delta_e_method)
            delta_e_lst.append(delta_e)
        # find lowest delta-e
        min_diff = np.argsort(delta_e_lst)[0]
        output_color = (
            str(list(filtered_colors.items())[min_diff][0])
            .lower()
            .replace("grey", "gray")
        )

    # match color to reduced list:
    if merge_color:
        for reduced_key, reduced_color_sub_list in get_color_table().items():
            if str(output_color).lower() in [
                str(color_name).lower()
                for color_name in reduced_color_sub_list["ColorName"]
            ]:
                output_color = reduced_key.lower()
                break
    return output_color
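
For a single RGB triple, the method can be called directly on a ColorDetector instance; the filename below is a placeholder and is not opened by this call:

from ammico import colors

color_detector = colors.ColorDetector({"filename": "example.jpg"})
name = color_detector.rgb2name((200, 30, 30), merge_color=True)
print(name)  # expected to land in the reduced list, e.g. "red"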

Utils

AnalysisMethod

Base class to be inherited by all analysis methods.

Source code in ammico/utils.py, lines 85-96
class AnalysisMethod:
    """Base class to be inherited by all analysis methods."""

    def __init__(self, subdict: dict) -> None:
        self.subdict = subdict
        # define keys that will be set by the analysis

    def set_keys(self):
        raise NotImplementedError()

    def analyse_image(self):
        raise NotImplementedError()
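
A minimal sketch of how a custom detector would plug into this interface; the subclass below is purely illustrative and not part of ammico:

import os

from ammico.utils import AnalysisMethod


class FileSizeDetector(AnalysisMethod):
    """Illustrative subclass: report the file size of the analysed image."""

    def set_keys(self) -> dict:
        return {"file_size_bytes": None}

    def analyse_image(self) -> dict:
        self.subdict.update(self.set_keys())
        self.subdict["file_size_bytes"] = os.path.getsize(self.subdict["filename"])
        return self.subdict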

DownloadResource

A remote resource that needs on-demand downloading.

We use this as a wrapper to the pooch library. The wrapper registers each data file and allows prefetching through the CLI entry point ammico_prefetch_models.

Source code in ammico/utils.py, lines 24-41
class DownloadResource:
    """A remote resource that needs on demand downloading.

    We use this as a wrapper to the pooch library. The wrapper registers
    each data file and allows prefetching through the CLI entry point
    ammico_prefetch_models.
    """

    # We store a list of defined resouces in a class variable, allowing
    # us prefetching from a CLI e.g. to bundle into a Docker image
    resources = []

    def __init__(self, **kwargs):
        DownloadResource.resources.append(self)
        self.kwargs = kwargs

    def get(self):
        return pooch.retrieve(**self.kwargs)
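
A hedged sketch of registering and fetching a resource; the URL below is a placeholder, and all keyword arguments are passed straight through to pooch.retrieve:

from ammico.utils import DownloadResource, ammico_prefetch_models

example_weights = DownloadResource(
    url="https://example.org/models/example_weights.bin",  # placeholder URL
    known_hash=None,  # pooch downloads without hash verification when None
)

# Fetch this one resource directly ...
local_path = example_weights.get()
# ... or prefetch everything registered so far (same as the CLI entry point):
ammico_prefetch_models()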

ammico_prefetch_models()

Prefetch all the download resources

Source code in ammico/utils.py, lines 44-47
def ammico_prefetch_models():
    """Prefetch all the download resources"""
    for res in DownloadResource.resources:
        res.get()

append_data_to_dict(mydict)

Append entries from nested dictionaries to keys in a global dict.

Source code in ammico/utils.py, lines 438-447
def append_data_to_dict(mydict: dict) -> dict:
    """Append entries from nested dictionaries to keys in a global dict."""

    # first initialize empty list for each key that is present
    outdict = {key: [] for key in list(mydict.values())[0].keys()}
    # now append the values to each key in a list
    for subdict in mydict.values():
        for key in subdict.keys():
            outdict[key].append(subdict[key])
    return outdict

dump_df(mydict)

Utility to dump the dictionary into a dataframe.

Source code in ammico/utils.py, lines 450-452
def dump_df(mydict: dict) -> DataFrame:
    """Utility to dump the dictionary into a dataframe."""
    return DataFrame.from_dict(mydict)
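
For example, collecting per-image results into one table (toy values shown):

from ammico.utils import append_data_to_dict, dump_df

mydict = {
    "img1": {"filename": "img1.jpg", "red": 0.31, "blue": 0.12},
    "img2": {"filename": "img2.jpg", "red": 0.05, "blue": 0.44},
}
outdict = append_data_to_dict(mydict)
# {"filename": ["img1.jpg", "img2.jpg"], "red": [0.31, 0.05], "blue": [0.12, 0.44]}
df = dump_df(outdict)
df.to_csv("results.csv")  # one row per image, one column per key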

find_files(path=None, pattern=None, recursive=True, limit=20, random_seed=None, return_as_list=False)

Find image files on the file system.

Parameters:

Name Type Description Default
path str

The base directory where we are looking for the images. Defaults to None, which uses the ammico data directory if set or the current working directory otherwise.

None
pattern str | list

The naming pattern that the filename should match. Use either '.ext' or just 'ext'. Defaults to ["png", "jpg", "jpeg", "gif", "webp", "avif", "tiff"]. Can be used to allow other patterns or to only include specific prefixes or suffixes.

None
recursive bool

Whether to recurse into subdirectories. Default is set to True.

True
limit int / list

The maximum number of images to be found. Provide a list or tuple of length 2 to batch the images. Defaults to 20. To return all images, set to None or -1.

20
random_seed int

The random seed to use for shuffling the images. If None is provided, the data will not be shuffled. Defaults to None.

None
return_as_list bool

Whether to return the list of files instead of a dict. Defaults to False.

False

Returns: dict: A nested dictionary with file ids and all filenames including the path, or list: a list of file paths if return_as_list is set to True.

Source code in ammico/utils.py, lines 318-373
def find_files(
    path: Optional[Union[str, Path, None]] = None,
    pattern: Optional[Iterable[str]] = None,
    recursive: bool = True,
    limit=20,
    random_seed: int = None,
    return_as_list: bool = False,
) -> Union[dict, list]:
    """Find image files on the file system.

    Args:
        path (str, optional): The base directory where we are looking for the images. Defaults
            to None, which uses the ammico data directory if set or the current
            working directory otherwise.
        pattern (str|list, optional): The naming pattern that the filename should match.
                Use either '.ext' or just 'ext'
                Defaults to ["png", "jpg", "jpeg", "gif", "webp", "avif","tiff"]. Can be used to allow other patterns or to only include
                specific prefixes or suffixes.
        recursive (bool, optional): Whether to recurse into subdirectories. Default is set to True.
        limit (int/list, optional): The maximum number of images to be found.
            Provide a list or tuple of length 2 to batch the images.
            Defaults to 20. To return all images, set to None or -1.
        random_seed (int, optional): The random seed to use for shuffling the images.
            If None is provided the data will not be shuffeled. Defaults to None.
        return_as_list (bool, optional): Whether to return the list of files instead of a dict.
            Defaults to False.
    Returns:
        dict: A nested dictionary with file ids and all filenames including the path.
        Or
        list: A list of file paths if return_as_list is set to True.
    """

    if path is None:
        path = os.environ.get("AMMICO_DATA_HOME", ".")
    if pattern is None:
        pattern = ["png", "jpg", "jpeg", "gif", "webp", "avif", "tiff"]

    if isinstance(pattern, str):
        pattern = [pattern]
    results = []
    for p in pattern:
        results.extend(_match_pattern(path, p, recursive=recursive))

    if len(results) == 0:
        raise FileNotFoundError(f"No files found in {path} with pattern '{pattern}'")

    if random_seed is not None:
        random.seed(random_seed)
        random.shuffle(results)

    images = _limit_results(results, limit)

    if return_as_list:
        return images

    return initialize_dict(images)
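
Typical usage, assuming a directory of images at the hypothetical path "data/images":

from ammico.utils import find_files

# All png/jpg images under data/images, without a limit:
mydict = find_files(path="data/images", pattern=["png", "jpg"], limit=None)

# A reproducible random sample of 20 images, returned as a plain list of paths:
sample = find_files(
    path="data/images", limit=20, random_seed=42, return_as_list=True
)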

find_videos(path=None, pattern=['mp4', 'mov', 'avi', 'mkv', 'webm'], recursive=True, limit=5, random_seed=None)

Find video files on the file system.

Source code in ammico/utils.py, lines 294-315
def find_videos(
    path: str = None,
    pattern=["mp4", "mov", "avi", "mkv", "webm"],
    recursive: bool = True,
    limit=5,
    random_seed: int = None,
) -> dict:
    """Find video files on the file system."""
    if path is None:
        path = os.environ.get("AMMICO_DATA_HOME", ".")
    if isinstance(pattern, str):
        pattern = [pattern]
    results = []
    for p in pattern:
        results.extend(_match_pattern(path, p, recursive=recursive))
    if len(results) == 0:
        raise FileNotFoundError(f"No files found in {path} with pattern '{pattern}'")
    if random_seed is not None:
        random.seed(random_seed)
        random.shuffle(results)
    videos = _limit_results(results, limit)
    return initialize_dict(videos)

get_supported_whisperx_languages()

Get the list of supported whisperx languages.

Source code in ammico/utils.py, lines 483-488
def get_supported_whisperx_languages() -> List[str]:
    """Get the list of supported whisperx languages."""
    supported_languages = set(DEFAULT_ALIGN_MODELS_TORCH.keys()) | set(
        DEFAULT_ALIGN_MODELS_HF.keys()
    )
    return sorted(supported_languages)

initialize_dict(filelist)

Initialize the nested dictionary for all the found images.

Parameters:

Name Type Description Default
filelist list

The list of files to be analyzed, including their paths.

required

Returns: dict: The nested dictionary with all image ids and their paths.

Source code in ammico/utils.py, lines 376-387
def initialize_dict(filelist: list) -> dict:
    """Initialize the nested dictionary for all the found images.

    Args:
        filelist (list): The list of files to be analyzed, including their paths.
    Returns:
        dict: The nested dictionary with all image ids and their paths."""
    mydict = {}
    for img_path in filelist:
        id_ = os.path.splitext(os.path.basename(img_path))[0]
        mydict[id_] = {"filename": img_path}
    return mydict

is_interactive()

Check if we are running in an interactive environment.

Source code in ammico/utils.py, lines 461-465
def is_interactive():
    """Check if we are running in an interactive environment."""
    import __main__ as main

    return not hasattr(main, "__file__")

load_image(image_path)

Load image from file path or return if already PIL Image.

Source code in ammico/utils.py, lines 491-500
def load_image(image_path: Union[str, Path, Image.Image]) -> Image.Image:
    """Load image from file path or return if already PIL Image."""
    if isinstance(image_path, Image.Image):
        return image_path

    image_path = Path(image_path)
    if not image_path.exists():
        raise FileNotFoundError(f"Image file not found: {image_path}")

    return Image.open(image_path).convert("RGB")

prepare_image(image, target_size=(512, 512), resize_mode='resize')

Prepare image for model input with optimal resolution.

Source code in ammico/utils.py, lines 503-524
def prepare_image(
    image: Image.Image,
    target_size: Tuple[int, int] = (512, 512),
    resize_mode: str = "resize",
) -> Image.Image:
    """Prepare image for model input with optimal resolution."""
    width, height = image.size
    target_w, target_h = target_size

    if resize_mode == "center_crop":
        scale = max(target_w / width, target_h / height)
        new_width = int(width * scale)
        new_height = int(height * scale)
        image = image.resize((new_width, new_height), Image.Resampling.LANCZOS)

        left = (new_width - target_w) // 2
        top = (new_height - target_h) // 2
        image = image.crop((left, top, left + target_w, top + target_h))
    else:
        image = image.resize(target_size, Image.Resampling.LANCZOS)

    return image
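
For example, resizing versus center-cropping an image to the model resolution (the file name is a placeholder):

from ammico.utils import load_image, prepare_image

image = load_image("example.jpg")  # placeholder path
stretched = prepare_image(image, target_size=(512, 512), resize_mode="resize")
cropped = prepare_image(image, target_size=(512, 512), resize_mode="center_crop")
# "resize" stretches to exactly 512x512 (aspect ratio may change);
# "center_crop" scales the image to cover 512x512 and then crops the center.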

Display

AnalysisExplorer

Source code in ammico/display.py, lines 31-482
class AnalysisExplorer:
    def __init__(self, mydict: dict) -> None:
        """Initialize the AnalysisExplorer class to create an interactive
        visualization of the analysis results.

        Args:
            mydict (dict): A nested dictionary containing image data for all images.

        """
        self.app = Dash(__name__, external_stylesheets=[dbc.themes.BOOTSTRAP])
        self.mydict = mydict
        self.theme = {
            "scheme": "monokai",
            "author": "wimer hazenberg (http://www.monokai.nl)",
            "base00": "#272822",
            "base01": "#383830",
            "base02": "#49483e",
            "base03": "#75715e",
            "base04": "#a59f85",
            "base05": "#f8f8f2",
            "base06": "#f5f4f1",
            "base07": "#f9f8f5",
            "base08": "#f92672",
            "base09": "#fd971f",
            "base0A": "#f4bf75",
            "base0B": "#a6e22e",
            "base0C": "#a1efe4",
            "base0D": "#66d9ef",
            "base0E": "#ae81ff",
            "base0F": "#cc6633",
        }

        # Setup the layout
        app_layout = html.Div(
            [
                # Top row, only file explorer
                dbc.Row(
                    [dbc.Col(self._top_file_explorer(mydict))],
                    id="Div_top",
                    style={
                        "width": "30%",
                    },
                ),
                # second row, middle picture and right output
                dbc.Row(
                    [
                        # first column: picture
                        dbc.Col(self._middle_picture_frame()),
                        dbc.Col(self._right_output_json()),
                    ]
                ),
            ],
            # style={"width": "95%", "display": "inline-block"},
        )
        self.app.layout = app_layout

        # Add callbacks to the app
        self.app.callback(
            Output("img_middle_picture_id", "src"),
            Input("left_select_id", "value"),
            prevent_initial_call=True,
        )(self.update_picture)

        self.app.callback(
            Output("right_json_viewer", "children"),
            Input("button_run", "n_clicks"),
            State("left_select_id", "options"),
            State("left_select_id", "value"),
            State("Dropdown_select_Detector", "value"),
            State("Dropdown_analysis_type", "value"),
            State("textarea_questions", "value"),
            State("setting_privacy_env_var", "value"),
            State("setting_Color_delta_e_method", "value"),
            prevent_initial_call=True,
        )(self._right_output_analysis)

        self.app.callback(
            Output("settings_TextDetector", "style"),
            Output("settings_ColorDetector", "style"),
            Output("settings_VQA", "style"),
            Input("Dropdown_select_Detector", "value"),
        )(self._update_detector_setting)

        self.app.callback(
            Output("textarea_questions", "style"),
            Input("Dropdown_analysis_type", "value"),
        )(self._show_questions_textarea_on_demand)

    # I split the different sections into subfunctions for better clarity
    def _top_file_explorer(self, mydict: dict) -> html.Div:
        """Initialize the file explorer dropdown for selecting the file to be analyzed.

        Args:
            mydict (dict): A dictionary containing image data.

        Returns:
            html.Div: The layout for the file explorer dropdown.
        """
        left_layout = html.Div(
            [
                dcc.Dropdown(
                    options={value["filename"]: key for key, value in mydict.items()},
                    id="left_select_id",
                )
            ]
        )
        return left_layout

    def _middle_picture_frame(self) -> html.Div:
        """Initialize the picture frame to display the image.

        Returns:
            html.Div: The layout for the picture frame.
        """
        middle_layout = html.Div(
            [
                html.Img(
                    id="img_middle_picture_id",
                    style={
                        "width": "80%",
                    },
                )
            ]
        )
        return middle_layout

    def _create_setting_layout(self):
        settings_layout = html.Div(
            [
                # text summary start
                html.Div(
                    id="settings_TextDetector",
                    style={"display": "none"},
                    children=[
                        # row 1
                        dbc.Row(
                            dbc.Col(
                                [
                                    html.P(
                                        "Privacy disclosure acceptance environment variable"
                                    ),
                                    dcc.Input(
                                        type="text",
                                        value="PRIVACY_AMMICO",
                                        id="setting_privacy_env_var",
                                        style={"width": "100%"},
                                    ),
                                ],
                                align="start",
                            ),
                        ),
                    ],
                ),  # text summary end
                html.Div(
                    id="settings_ColorDetector",
                    style={"display": "none"},
                    children=[
                        html.Div(
                            [
                                dcc.Dropdown(
                                    options=COLOR_SCHEMES,
                                    value="CIE 1976",
                                    id="setting_Color_delta_e_method",
                                )
                            ],
                            style={
                                "width": "49%",
                                "display": "inline-block",
                                "margin-top": "10px",
                            },
                        )
                    ],
                ),
                # start VQA settings
                html.Div(
                    id="settings_VQA",
                    style={"display": "none"},
                    children=[
                        dbc.Card(
                            [
                                dbc.CardBody(
                                    [
                                        dbc.Row(
                                            dbc.Col(
                                                dcc.Dropdown(
                                                    id="Dropdown_analysis_type",
                                                    options=[
                                                        {"label": v, "value": v}
                                                        for v in SUMMARY_ANALYSIS_TYPE
                                                    ],
                                                    value="summary_and_questions",
                                                    clearable=False,
                                                    style={
                                                        "width": "100%",
                                                        "minWidth": "240px",
                                                        "maxWidth": "520px",
                                                    },
                                                ),
                                            ),
                                            justify="start",
                                        ),
                                        html.Div(style={"height": "8px"}),
                                        dbc.Row(
                                            [
                                                dbc.Col(
                                                    dcc.Textarea(
                                                        id="textarea_questions",
                                                        value="Are there people in the image?\nWhat is this picture about?",
                                                        placeholder="One question per line...",
                                                        style={
                                                            "width": "100%",
                                                            "minHeight": "160px",
                                                            "height": "220px",
                                                            "resize": "vertical",
                                                            "overflow": "auto",
                                                        },
                                                        rows=8,
                                                    ),
                                                    width=12,
                                                ),
                                            ],
                                            justify="start",
                                        ),
                                    ]
                                )
                            ],
                            style={
                                "width": "100%",
                                "marginTop": "10px",
                                "zIndex": 2000,
                            },
                        )
                    ],
                ),
            ],
            style={"width": "100%", "display": "inline-block", "overflow": "visible"},
        )
        return settings_layout

    def _right_output_json(self) -> html.Div:
        """Initialize the DetectorDropdown, argument Div and JSON viewer for displaying the analysis output.

        Returns:
            html.Div: The layout for the JSON viewer.
        """
        right_layout = html.Div(
            [
                dbc.Col(
                    [
                        dbc.Row(
                            dcc.Dropdown(
                                options=[
                                    "TextDetector",
                                    "ColorDetector",
                                    "VQA",
                                ],
                                value="TextDetector",
                                id="Dropdown_select_Detector",
                                style={"width": "60%"},
                            ),
                            justify="start",
                        ),
                        dbc.Row(
                            children=[self._create_setting_layout()],
                            id="div_detector_args",
                            justify="start",
                        ),
                        dbc.Row(
                            html.Button(
                                "Run Detector",
                                id="button_run",
                                style={
                                    "margin-top": "15px",
                                    "margin-bottom": "15px",
                                    "margin-left": "11px",
                                    "width": "30%",
                                },
                            ),
                            justify="start",
                        ),
                        dbc.Row(
                            dcc.Loading(
                                id="loading-2",
                                children=[
                                    # This is where the json is shown.
                                    html.Div(id="right_json_viewer"),
                                ],
                                type="circle",
                            ),
                            justify="start",
                        ),
                    ],
                    align="start",
                )
            ]
        )
        return right_layout

    def run_server(self, port: int = 8050) -> None:
        """Run the Dash server to start the analysis explorer.


        Args:
            port (int, optional): The port number to run the server on (default: 8050).
        """

        self.app.run(debug=True, port=port)

    # Dash callbacks
    def update_picture(self, img_path: str) -> Optional[Image.Image]:
        """Callback function to update the displayed image.

        Args:
            img_path (str): The path of the selected image.

        Returns:
            Union[PIL.PngImagePlugin, None]: The image object to be displayed
                or None if the image path is

        """
        if img_path is not None:
            image = Image.open(img_path)
            return image
        else:
            return None

    def _update_detector_setting(self, setting_input):
        # return settings_TextDetector -> style,
        display_none = {"display": "none"}
        display_flex = {
            "display": "flex",
            "flexWrap": "wrap",
            "width": 400,
            "margin-top": "20px",
        }

        if setting_input == "TextDetector":
            return display_flex, display_none, display_none, display_none
        if setting_input == "ColorDetector":
            return display_none, display_none, display_flex, display_none
        if setting_input == "VQA":
            return display_none, display_none, display_none, display_flex
        else:
            return display_none, display_none, display_none, display_none

    def _parse_questions(self, text: Optional[str]) -> Optional[List[str]]:
        if not text:
            return None
        qs = [q.strip() for q in text.splitlines() if q.strip()]
        return qs if qs else None

    def _right_output_analysis(
        self,
        n_clicks,
        all_img_options: dict,
        current_img_value: str,
        detector_value: str,
        analysis_type_value: str,
        textarea_questions_value: str,
        setting_privacy_env_var: str,
        setting_color_delta_e_method: str,
    ) -> dict:
        """Callback function to perform analysis on the selected image and return the output.

        Args:
            all_options (dict): The available options in the file explorer dropdown.
            current_value (str): The current selected value in the file explorer dropdown.

        Returns:
            dict: The analysis output for the selected image.
        """
        identify_dict = {
            "TextDetector": text.TextDetector,
            "ColorDetector": colors.ColorDetector,
            "VQA": image_summary.ImageSummaryDetector,
        }

        # Get image ID from dropdown value, which is the filepath
        if current_img_value is None:
            return {}
        image_id = all_img_options[current_img_value]
        image_copy = self.mydict.get(image_id, {}).copy()

        analysis_dict: Dict[str, Any] = {}
        if detector_value == "VQA":
            try:
                qwen_model = MultimodalSummaryModel(
                    model_id="Qwen/Qwen2.5-VL-3B-Instruct"
                )  # TODO: allow user to specify model
                vqa_cls = identify_dict.get("VQA")
                vqa_detector = vqa_cls(qwen_model, subdict={})
                questions_list = self._parse_questions(textarea_questions_value)
                analysis_result = vqa_detector.analyse_image(
                    image_copy,
                    analysis_type=analysis_type_value,
                    list_of_questions=questions_list,
                    is_concise_summary=True,
                    is_concise_answer=True,
                )
                analysis_dict = analysis_result or {}
            except Exception as e:
                warnings.warn(f"VQA/Image tasks failed: {e}")
                analysis_dict = {"image_tasks_error": str(e)}
        else:
            # detector value is the string name of the chosen detector
            identify_function = identify_dict[detector_value]

            if detector_value == "TextDetector":
                detector_class = identify_function(
                    image_copy,
                    accept_privacy=(
                        setting_privacy_env_var
                        if setting_privacy_env_var
                        else "PRIVACY_AMMICO"
                    ),
                )
            elif detector_value == "ColorDetector":
                detector_class = identify_function(
                    image_copy,
                    delta_e_method=setting_color_delta_e_method,
                )
            else:
                detector_class = identify_function(image_copy)

            analysis_dict = detector_class.analyse_image()

        new_analysis_dict: Dict[str, Any] = {}

        # Iterate over the items in the original dictionary
        for k, v in analysis_dict.items():
            # Check if the value is a list
            if isinstance(v, list):
                # If it is, convert each item in the list to a string and join them with a comma
                new_value = ", ".join([str(f) for f in v])
            else:
                # If it's not a list, keep the value as it is
                new_value = v

            # Add the new key-value pair to the new dictionary
            new_analysis_dict[k] = new_value

        df = pd.DataFrame([new_analysis_dict]).set_index("filename").T
        df.index.rename("filename", inplace=True)
        return dbc.Table.from_dataframe(
            df, striped=True, bordered=True, hover=True, index=True
        )

    def _show_questions_textarea_on_demand(self, analysis_type_value: str) -> dict:
        if analysis_type_value in ("questions", "summary_and_questions"):
            return {"display": "block", "width": "100%"}
        else:
            return {"display": "none"}
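
A short usage sketch tying the explorer together with find_files (the image directory is a placeholder):

from ammico.utils import find_files
from ammico.display import AnalysisExplorer

mydict = find_files(path="data/images", limit=10)  # placeholder directory
explorer = AnalysisExplorer(mydict)
explorer.run_server(port=8050)  # then open http://localhost:8050 in a browser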

__init__(mydict)

Initialize the AnalysisExplorer class to create an interactive visualization of the analysis results.

Parameters:

Name Type Description Default
mydict dict

A nested dictionary containing image data for all images.

required
Source code in ammico/display.py, lines 32-117
def __init__(self, mydict: dict) -> None:
    """Initialize the AnalysisExplorer class to create an interactive
    visualization of the analysis results.

    Args:
        mydict (dict): A nested dictionary containing image data for all images.

    """
    self.app = Dash(__name__, external_stylesheets=[dbc.themes.BOOTSTRAP])
    self.mydict = mydict
    self.theme = {
        "scheme": "monokai",
        "author": "wimer hazenberg (http://www.monokai.nl)",
        "base00": "#272822",
        "base01": "#383830",
        "base02": "#49483e",
        "base03": "#75715e",
        "base04": "#a59f85",
        "base05": "#f8f8f2",
        "base06": "#f5f4f1",
        "base07": "#f9f8f5",
        "base08": "#f92672",
        "base09": "#fd971f",
        "base0A": "#f4bf75",
        "base0B": "#a6e22e",
        "base0C": "#a1efe4",
        "base0D": "#66d9ef",
        "base0E": "#ae81ff",
        "base0F": "#cc6633",
    }

    # Setup the layout
    app_layout = html.Div(
        [
            # Top row, only file explorer
            dbc.Row(
                [dbc.Col(self._top_file_explorer(mydict))],
                id="Div_top",
                style={
                    "width": "30%",
                },
            ),
            # second row, middle picture and right output
            dbc.Row(
                [
                    # first column: picture
                    dbc.Col(self._middle_picture_frame()),
                    dbc.Col(self._right_output_json()),
                ]
            ),
        ],
        # style={"width": "95%", "display": "inline-block"},
    )
    self.app.layout = app_layout

    # Add callbacks to the app
    self.app.callback(
        Output("img_middle_picture_id", "src"),
        Input("left_select_id", "value"),
        prevent_initial_call=True,
    )(self.update_picture)

    self.app.callback(
        Output("right_json_viewer", "children"),
        Input("button_run", "n_clicks"),
        State("left_select_id", "options"),
        State("left_select_id", "value"),
        State("Dropdown_select_Detector", "value"),
        State("Dropdown_analysis_type", "value"),
        State("textarea_questions", "value"),
        State("setting_privacy_env_var", "value"),
        State("setting_Color_delta_e_method", "value"),
        prevent_initial_call=True,
    )(self._right_output_analysis)

    self.app.callback(
        Output("settings_TextDetector", "style"),
        Output("settings_ColorDetector", "style"),
        Output("settings_VQA", "style"),
        Input("Dropdown_select_Detector", "value"),
    )(self._update_detector_setting)

    self.app.callback(
        Output("textarea_questions", "style"),
        Input("Dropdown_analysis_type", "value"),
    )(self._show_questions_textarea_on_demand)

run_server(port=8050)

Run the Dash server to start the analysis explorer.

Parameters:

Name Type Description Default
port int

The port number to run the server on (default: 8050).

8050
Source code in ammico/display.py, lines 329-337
def run_server(self, port: int = 8050) -> None:
    """Run the Dash server to start the analysis explorer.


    Args:
        port (int, optional): The port number to run the server on (default: 8050).
    """

    self.app.run(debug=True, port=port)

update_picture(img_path)

Callback function to update the displayed image.

Parameters:

Name Type Description Default
img_path str

The path of the selected image.

required

Returns:

Type Description
Optional[Image]

The image object to be displayed, or None if no image path is selected.

Source code in ammico/display.py, lines 340-355
def update_picture(self, img_path: str) -> Optional[Image.Image]:
    """Callback function to update the displayed image.

    Args:
        img_path (str): The path of the selected image.

    Returns:
        Union[PIL.PngImagePlugin, None]: The image object to be displayed
            or None if the image path is

    """
    if img_path is not None:
        image = Image.open(img_path)
        return image
    else:
        return None

Model

AudioToTextModel

Source code in ammico/model.py, lines 120-213
class AudioToTextModel:
    def __init__(
        self,
        model_size: str = "large",
        device: Optional[str] = None,
        language: Optional[str] = None,
    ) -> None:
        """
        Class for WhisperX model loading and inference.
        Args:
            model_size: Size of Whisper model to load (small, base, large).
            device: "cuda" or "cpu" (auto-detected when None).
            language: ISO-639-1 language code (e.g., "en", "fr", "de").
                     If None, language will be detected automatically.
                     Set this to avoid unreliable detection on small clips.
        """
        self.device = resolve_model_device(device)

        self.model_size = resolve_model_size(model_size)

        self.model = None

        self.language = self._validate_language(language)

        self._load_model()

    def _validate_language(self, language: Optional[str]) -> Optional[str]:
        """

        Validate the provided language code against whisperx's supported languages.
        Args:
            language: ISO-639-1 language code (e.g., "en", "fr", "de").
        Returns:
            Validated language code or None.
        Raises:
            ValueError: If the language code is invalid or unsupported.
        """

        if not language:
            return None

        language = language.strip().lower()
        supported_languages = get_supported_whisperx_languages()

        if len(language) != 2:
            raise ValueError(
                f"Invalid language code: '{language}'. Language codes must be 2 letters."
            )

        if not language.isalpha():
            raise ValueError(
                f"Invalid language code: '{language}'. Language codes must contain only alphabetic characters."
            )

        if language not in supported_languages:
            raise ValueError(
                f"Unsupported language code: '{language}'. Supported: {sorted(supported_languages)}"
            )

        return language

    def _load_model(self):
        if self.device == "cuda":
            self.model = whisperx.load_model(
                self.model_size,
                device=self.device,
                compute_type="float16",
                language=self.language,
            )
        else:
            self.model = whisperx.load_model(
                self.model_size,
                device=self.device,
                compute_type="int8",
                language=self.language,
            )

    def close(self) -> None:
        """Free model resources (helpful in long-running processes)."""
        try:
            if self.model is not None:
                del self.model
                self.model = None
        finally:
            try:
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()
            except Exception as e:
                warnings.warn(
                    "Failed to empty CUDA cache. This is not critical, but may lead to memory lingering: "
                    f"{e!r}",
                    RuntimeWarning,
                    stacklevel=2,
                )
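
A hedged construction sketch; the transcription call itself happens inside the video pipeline and is not shown in this section:

from ammico.model import AudioToTextModel

# Fixing the language avoids unreliable auto-detection on short clips.
audio_model = AudioToTextModel(model_size="base", device=None, language="en")
# ... pass the model to the video summary pipeline, then free the resources:
audio_model.close()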

__init__(model_size='large', device=None, language=None)

Class for WhisperX model loading and inference.

Parameters:

Name Type Description Default
model_size str

Size of Whisper model to load (small, base, large).

'large'
device Optional[str]

"cuda" or "cpu" (auto-detected when None).

None
language Optional[str]

ISO-639-1 language code (e.g., "en", "fr", "de"). If None, the language will be detected automatically. Set this to avoid unreliable detection on small clips.

None

Source code in ammico/model.py, lines 121-144
def __init__(
    self,
    model_size: str = "large",
    device: Optional[str] = None,
    language: Optional[str] = None,
) -> None:
    """
    Class for WhisperX model loading and inference.
    Args:
        model_size: Size of Whisper model to load (small, base, large).
        device: "cuda" or "cpu" (auto-detected when None).
        language: ISO-639-1 language code (e.g., "en", "fr", "de").
                 If None, language will be detected automatically.
                 Set this to avoid unreliable detection on small clips.
    """
    self.device = resolve_model_device(device)

    self.model_size = resolve_model_size(model_size)

    self.model = None

    self.language = self._validate_language(language)

    self._load_model()

close()

Free model resources (helpful in long-running processes).

Source code in ammico/model.py, lines 197-213
def close(self) -> None:
    """Free model resources (helpful in long-running processes)."""
    try:
        if self.model is not None:
            del self.model
            self.model = None
    finally:
        try:
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
        except Exception as e:
            warnings.warn(
                "Failed to empty CUDA cache. This is not critical, but may lead to memory lingering: "
                f"{e!r}",
                RuntimeWarning,
                stacklevel=2,
            )

MultimodalEmbeddingsModel

Source code in ammico/model.py, lines 216-318
class MultimodalEmbeddingsModel:
    def __init__(
        self,
        device: Optional[str] = None,
    ) -> None:
        """
        Class for Multimodal Embeddings model loading and inference. Uses Jina CLIP-V2 model.
        Args:
            device: "cuda" or "cpu" (auto-detected when None).
        """
        self.device = resolve_model_device(device)

        model_id = "jinaai/jina-clip-v2"

        self.model = SentenceTransformer(
            model_id,
            device=self.device,
            trust_remote_code=True,
            model_kwargs={"torch_dtype": "auto"},
        )

        self.model.eval()

        self.embedding_dim = 1024

    @torch.inference_mode()
    def encode_text(
        self,
        texts: Union[str, List[str]],
        batch_size: int = 64,
        truncate_dim: Optional[int] = None,
    ) -> Union[torch.Tensor, np.ndarray]:
        if isinstance(texts, str):
            texts = [texts]

        convert_to_tensor = self.device == "cuda"
        convert_to_numpy = not convert_to_tensor

        embeddings = self.model.encode(
            texts,
            batch_size=batch_size,
            convert_to_tensor=convert_to_tensor,
            convert_to_numpy=convert_to_numpy,
            normalize_embeddings=True,
        )

        if truncate_dim is not None:
            if not (64 <= truncate_dim <= self.embedding_dim):
                raise ValueError(
                    f"truncate_dim must be between 64 and {self.embedding_dim}"
                )
            embeddings = embeddings[:, :truncate_dim]
        return embeddings

    @torch.inference_mode()
    def encode_image(
        self,
        images: Union[Image.Image, List[Image.Image]],
        batch_size: int = 32,
        truncate_dim: Optional[int] = None,
    ) -> Union[torch.Tensor, np.ndarray]:
        if not isinstance(images, (Image.Image, list)):
            raise ValueError(
                "images must be a PIL.Image or a list of PIL.Image objects. Please load images properly."
            )

        convert_to_tensor = self.device == "cuda"
        convert_to_numpy = not convert_to_tensor

        embeddings = self.model.encode(
            images if isinstance(images, list) else [images],
            batch_size=batch_size,
            convert_to_tensor=convert_to_tensor,
            convert_to_numpy=convert_to_numpy,
            normalize_embeddings=True,
        )

        if truncate_dim is not None:
            if not (64 <= truncate_dim <= self.embedding_dim):
                raise ValueError(
                    f"truncate_dim must be between 64 and {self.embedding_dim}"
                )
            embeddings = embeddings[:, :truncate_dim]

        return embeddings

    def close(self) -> None:
        """Free model resources (helpful in long-running processes)."""
        try:
            if self.model is not None:
                del self.model
                self.model = None
        finally:
            try:
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()
            except Exception as e:
                warnings.warn(
                    "Failed to empty CUDA cache. This is not critical, but may lead to memory lingering: "
                    f"{e!r}",
                    RuntimeWarning,
                    stacklevel=2,
                )

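A short usage sketch for the methods shown above. The image path is a placeholder; because encode_text and encode_image normalize their embeddings, a plain dot product acts as cosine similarity (the results are NumPy arrays on CPU and torch tensors on CUDA).

from PIL import Image
from ammico.model import MultimodalEmbeddingsModel

embedder = MultimodalEmbeddingsModel(device="cpu")

# truncate_dim must lie between 64 and the full dimension of 1024.
text_emb = embedder.encode_text("a cat sitting on a sofa", truncate_dim=256)
image_emb = embedder.encode_image(Image.open("example.jpg"), truncate_dim=256)  # placeholder path

# Embeddings are L2-normalized, so the dot product is the cosine similarity.
similarity = float((text_emb * image_emb).sum())
print(f"text-image similarity: {similarity:.3f}")

embedder.close()
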
__init__(device=None)

Class for Multimodal Embeddings model loading and inference. Uses the Jina CLIP-v2 model.

Parameters:

- device: "cuda" or "cpu" (auto-detected when None).

Source code in ammico/model.py, lines 217-239
def __init__(
    self,
    device: Optional[str] = None,
) -> None:
    """
    Class for Multimodal Embeddings model loading and inference. Uses Jina CLIP-V2 model.
    Args:
        device: "cuda" or "cpu" (auto-detected when None).
    """
    self.device = resolve_model_device(device)

    model_id = "jinaai/jina-clip-v2"

    self.model = SentenceTransformer(
        model_id,
        device=self.device,
        trust_remote_code=True,
        model_kwargs={"torch_dtype": "auto"},
    )

    self.model.eval()

    self.embedding_dim = 1024

close()

Free model resources (helpful in long-running processes).

Source code in ammico/model.py, lines 302-318
def close(self) -> None:
    """Free model resources (helpful in long-running processes)."""
    try:
        if self.model is not None:
            del self.model
            self.model = None
    finally:
        try:
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
        except Exception as e:
            warnings.warn(
                "Failed to empty CUDA cache. This is not critical, but may lead to memory lingering: "
                f"{e!r}",
                RuntimeWarning,
                stacklevel=2,
            )

MultimodalSummaryModel

Source code in ammico/model.py, lines 22-117
class MultimodalSummaryModel:
    DEFAULT_CUDA_MODEL = "Qwen/Qwen2.5-VL-7B-Instruct"
    DEFAULT_CPU_MODEL = "Qwen/Qwen2.5-VL-3B-Instruct"

    def __init__(
        self,
        model_id: Optional[str] = None,
        device: Optional[str] = None,
        cache_dir: Optional[str] = None,
    ) -> None:
        """
        Class for QWEN-2.5-VL model loading and inference.
        Args:
            model_id: Type of model to load, defaults to a smaller version for CPU if device is "cpu".
            device: "cuda" or "cpu" (auto-detected when None).
            cache_dir: huggingface cache dir (optional).
        """
        self.device = resolve_model_device(device)

        if model_id is not None and model_id not in (
            self.DEFAULT_CUDA_MODEL,
            self.DEFAULT_CPU_MODEL,
        ):
            raise ValueError(
                f"model_id must be one of {self.DEFAULT_CUDA_MODEL} or {self.DEFAULT_CPU_MODEL}"
            )

        self.model_id = model_id or (
            self.DEFAULT_CUDA_MODEL if self.device == "cuda" else self.DEFAULT_CPU_MODEL
        )

        self.cache_dir = cache_dir
        self._trust_remote_code = True
        self._quantize = True

        self.model = None
        self.processor = None
        self.tokenizer = None

        self._load_model_and_processor()

    def _load_model_and_processor(self):
        load_kwargs = {"trust_remote_code": self._trust_remote_code, "use_cache": True}
        if self.cache_dir:
            load_kwargs["cache_dir"] = self.cache_dir

        self.processor = AutoProcessor.from_pretrained(
            self.model_id, padding_side="left", **load_kwargs
        )
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_id, **load_kwargs)

        if self.device == "cuda":
            compute_dtype = (
                torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
            )
            bnb_config = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_use_double_quant=True,
                bnb_4bit_quant_type="nf4",
                bnb_4bit_compute_dtype=compute_dtype,
            )
            load_kwargs["quantization_config"] = bnb_config
            load_kwargs["device_map"] = "auto"

        else:
            load_kwargs.pop("quantization_config", None)
            load_kwargs.pop("device_map", None)

        self.model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
            self.model_id, **load_kwargs
        )
        self.model.eval()

    def close(self) -> None:
        """Free model resources (helpful in long-running processes)."""
        try:
            if self.model is not None:
                del self.model
                self.model = None
            if self.processor is not None:
                del self.processor
                self.processor = None
            if self.tokenizer is not None:
                del self.tokenizer
                self.tokenizer = None
        finally:
            try:
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()
            except Exception as e:
                warnings.warn(
                    "Failed to empty CUDA cache. This is not critical, but may lead to memory lingering: "
                    f"{e!r}",
                    RuntimeWarning,
                    stacklevel=2,
                )

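A construction sketch for MultimodalSummaryModel. Only the two Qwen2.5-VL checkpoints listed above are accepted as model_id; any other value raises a ValueError, and leaving it as None picks the 7B model on CUDA (loaded 4-bit quantized) or the 3B model on CPU.

from ammico.model import MultimodalSummaryModel

summary_model = MultimodalSummaryModel(device=None)  # model_id chosen automatically from the device
try:
    ...  # hand summary_model to the analysis code that consumes it
finally:
    summary_model.close()

# Explicit selection is also possible, but only with the two supported checkpoints:
# MultimodalSummaryModel(model_id="Qwen/Qwen2.5-VL-3B-Instruct", device="cpu")
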
__init__(model_id=None, device=None, cache_dir=None)

Class for Qwen2.5-VL model loading and inference.

Parameters:

- model_id: Model checkpoint to load; defaults to a smaller version when the device is "cpu".
- device: "cuda" or "cpu" (auto-detected when None).
- cache_dir: Hugging Face cache directory (optional).

Source code in ammico/model.py, lines 26-61
def __init__(
    self,
    model_id: Optional[str] = None,
    device: Optional[str] = None,
    cache_dir: Optional[str] = None,
) -> None:
    """
    Class for QWEN-2.5-VL model loading and inference.
    Args:
        model_id: Type of model to load, defaults to a smaller version for CPU if device is "cpu".
        device: "cuda" or "cpu" (auto-detected when None).
        cache_dir: huggingface cache dir (optional).
    """
    self.device = resolve_model_device(device)

    if model_id is not None and model_id not in (
        self.DEFAULT_CUDA_MODEL,
        self.DEFAULT_CPU_MODEL,
    ):
        raise ValueError(
            f"model_id must be one of {self.DEFAULT_CUDA_MODEL} or {self.DEFAULT_CPU_MODEL}"
        )

    self.model_id = model_id or (
        self.DEFAULT_CUDA_MODEL if self.device == "cuda" else self.DEFAULT_CPU_MODEL
    )

    self.cache_dir = cache_dir
    self._trust_remote_code = True
    self._quantize = True

    self.model = None
    self.processor = None
    self.tokenizer = None

    self._load_model_and_processor()

close()

Free model resources (helpful in long-running processes).

Source code in ammico/model.py, lines 95-117
def close(self) -> None:
    """Free model resources (helpful in long-running processes)."""
    try:
        if self.model is not None:
            del self.model
            self.model = None
        if self.processor is not None:
            del self.processor
            self.processor = None
        if self.tokenizer is not None:
            del self.tokenizer
            self.tokenizer = None
    finally:
        try:
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
        except Exception as e:
            warnings.warn(
                "Failed to empty CUDA cache. This is not critical, but may lead to memory lingering: "
                f"{e!r}",
                RuntimeWarning,
                stacklevel=2,
            )

Prompt Builder

ProcessingLevel

Bases: Enum

Define the three processing levels in a pipeline:

- FRAME: individual frame analysis
- CLIP: video segment (multiple frames)
- VIDEO: full video (multiple clips)

Source code in ammico/prompt_builder.py, lines 5-13
class ProcessingLevel(Enum):
    """Define the three processing levels in a pipeline.
    FRAME: individual frame analysis
    CLIP: video segment (multiple frames)
    VIDEO: full video (multiple clips)"""

    FRAME = "frame"
    CLIP = "clip"
    VIDEO = "video"

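The enum values are plain strings, so levels can be passed around or reconstructed from their string value; a small sketch:

from ammico.prompt_builder import ProcessingLevel

level = ProcessingLevel.CLIP
print(level.value)                       # "clip"
print(ProcessingLevel("clip") is level)  # True: members can be looked up by value
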
PromptBuilder

Modular prompt builder for multi-level video analysis. Handles frame-level, clip-level, and video-level prompts.

Source code in ammico/prompt_builder.py, lines 16-288
class PromptBuilder:
    """
    Modular prompt builder for multi-level video analysis.
    Handles frame-level, clip-level, and video-level prompts.
    """

    ROLE_MODULE = """You are a precise video analysis AI. Your purpose is to:
    - Extract only information explicitly present in provided sources
    - Never hallucinate or infer beyond what is shown
    - Generate clear, concise, well-structured outputs
    - Maintain logical coherence across visual and audio sources"""

    CONSTRAINTS_MODULE = """## Quality Requirements

    - **Accuracy:** Use only explicitly provided information
    - **Conciseness:** Eliminate redundancy; be direct
    - **Clarity:** Use accessible language
    - **Consistency:** Align audio and visual information when both exist
    - **Format Compliance:** Follow output format exactly"""

    @staticmethod
    def visual_frames_module() -> str:
        """For frame-level processing with actual images."""
        str_to_return = """## Visual Information

        Visual information is represented by keyframe extracted from the video segment at specified timestamp."""
        return str_to_return

    @staticmethod
    def visual_captions_module(frame_bullets: List[str]) -> str:
        """For clip-level processing with frame summaries."""
        bullets_text = "\n".join(frame_bullets)
        str_to_return = f"""## Visual Information

        The following are summary bullets extracted from video frames with timestamps.
        These bullets represent the visual content detected in each frame of the video segment:

        {bullets_text}"""

        return str_to_return

    @staticmethod
    def visual_captions_final_module(clip_summaries: List[str]) -> str:
        """For video-level processing with clip summaries."""
        str_to_return = f"""## Visual Information

        The following are brief summaries obtained for each segment of the video.
        These summaries are associated with the timestamp of each segment's beginning:

        {clip_summaries}"""

        return str_to_return

    @staticmethod
    def audio_module(audio_transcription: List[Dict[str, Any]]) -> str:
        """Audio transcription with timestamps."""
        audio_text = "\n".join(
            [
                f"[{a['start_time']:.2f}s - {a['end_time']:.2f}s]: {a['text'].strip()}"
                for a in audio_transcription
            ]
        )
        str_to_return = f"""## Audio Information

        The following is the audio transcription for the same video segment,
        with precise timestamps for each spoken element:

        {audio_text}"""
        return str_to_return

    @staticmethod
    def summary_task(has_audio: bool = False) -> str:
        """Generate summary task (with or without audio)."""
        sources = "visual and audio information" if has_audio else "visual information"
        str_to_return = f"""## Task: Generate Concise Summary

        Based on the {sources} provided, generate a brief summary that:
        - Captures and summarizes the main events and themes
        - Uses clear, accessible language
        - Is between 1-3 sentences
        - Contains no unsupported claims

        Return ONLY this format:

        Summary: <your summary here>"""
        return str_to_return

    @staticmethod
    def summary_vqa_task(
        level: ProcessingLevel,
        has_audio: bool = False,
    ) -> str:
        """Generate summary+VQA task (adapts based on level and audio). For Frame and Clip levels."""

        if level == ProcessingLevel.FRAME:
            vqa_task = """Answer the provided questions based ONLY on information from the visual information. Answers must be brief and direct.

            **Critical Rule:** If you cannot answer a question from the provided sources,
            respond with: "Cannot be determined from provided information."
            """
        elif level == ProcessingLevel.CLIP:
            if has_audio:
                priority_list = """
                    1. **Frame-Level Answers** - Pre-computed per-frame answers (auxiliary reference only)
                    2. **Audio Information** - Spoken content from audio transcription
                    3. **Visual Information** - Direct visual content from video frames"""
            else:
                priority_list = """1. **Frame-Level Answers** - Pre-computed per-frame answers (auxiliary reference only)
                    2. **Visual Information** - Direct visual content from video frames"""
            vqa_task = f"""For each question, use the BEST available source in this priority:
                {priority_list}

                **Critical Logic:**
                - If frame-level answer is a REAL answer (not "Cannot be determined from provided information.") → use it
                - If frame-level answer is "Cannot be determined" → SKIP IT and check audio/visual instead
                - If answer is found in visual OR audio information → use that
                - ONLY respond "Cannot be determined" if truly no information exists anywhere

                """

        str_to_return = f"""## You have two tasks:

        ### task 1: Concise Summary
        Generate a brief summary that captures and summarizes main events and themes from the visual information (1-3 sentences).

        ### task 2: Question Answering
        {vqa_task}

        Return ONLY this format:

        Summary: <your summary here>

        VQA Answers:
        1. <answer to question 1>
        2. <answer to question 2>
        [etc.]"""

        return str_to_return

    @staticmethod
    def vqa_only_task() -> str:
        """VQA-only task for video-level processing."""
        str_to_return = """## Task: Answer Questions

        For each question, use the BEST available source in this priority:
            1. **Segment-Level Answers** - Pre-computed per-frame answers (auxiliary reference only)
            2. **Visual Information** - Direct visual content from video frames

        **Critical Logic:**
                - If segment-level answer is a REAL answer (not "Cannot be determined from provided information.") → use it
                - If frame-level answer is "Cannot be determined" → SKIP IT and check visual information instead
                - If answer is found in visual information → use that
                - ONLY respond "Cannot be determined" if truly no information exists anywhere

        Return ONLY this format:

        VQA Answers:
        1. <answer to question 1>
        2. <answer to question 2>
        [etc.]"""

        return str_to_return

    @staticmethod
    def questions_module(questions: List[str]) -> str:
        """Format questions list."""
        questions_text = "\n".join(
            [f"{i + 1}. {q.strip()}" for i, q in enumerate(questions)]
        )
        str_to_return = f"""## Questions to Answer

        {questions_text}
        """
        return str_to_return

    @staticmethod
    def vqa_context_module(vqa_bullets: List[str], is_final: bool = False) -> str:
        """VQA context (frame-level or clip-level answers)."""
        if is_final:
            header = """## SEGMENT-Level Answer Context (Reference Only)

            The following are answers to above questions obtained for each segment of the video.
            These answers are associated with the timestamp of each segment's beginning:"""
        else:
            header = """## FRAME-Level Answer Context (Reference Only)

            For each question, the following are frame-level answers provided as reference.
            If these answers are "Cannot be determined", do not accept that as final—instead, 
            use visual and audio information to answer the question:"""

        bullets_text = "\n".join(vqa_bullets)
        return f"{header}\n\n{bullets_text}"

    @classmethod
    def build_frame_prompt(
        cls, include_vqa: bool = False, questions: Optional[List[str]] = None
    ) -> str:
        """Build prompt for frame-level analysis."""
        modules = [cls.ROLE_MODULE, cls.visual_frames_module()]

        if include_vqa and not questions:
            raise ValueError("Questions must be provided when VQA should be included.")

        if include_vqa:
            modules.append(cls.summary_vqa_task(ProcessingLevel.FRAME))
            modules.append(cls.questions_module(questions))
        else:
            modules.append(cls.summary_task())

        modules.append(cls.CONSTRAINTS_MODULE)
        return "\n\n".join(modules)

    @classmethod
    def build_clip_prompt(
        cls,
        frame_bullets: List[str],
        include_audio: bool = False,
        audio_transcription: Optional[List[Dict]] = None,
        include_vqa: bool = False,
        questions: Optional[List[str]] = None,
        vqa_bullets: Optional[List[str]] = None,
    ) -> str:
        """Build prompt for clip-level analysis."""
        modules = [cls.ROLE_MODULE, cls.visual_captions_module(frame_bullets)]

        if include_audio and audio_transcription:
            modules.append(cls.audio_module(audio_transcription))

        if include_vqa and not questions:
            raise ValueError("Questions must be provided when VQA should be included.")

        if include_vqa:
            modules.append(
                cls.summary_vqa_task(ProcessingLevel.CLIP, has_audio=include_audio)
            )
            modules.append(cls.questions_module(questions))
            if vqa_bullets:
                modules.append(cls.vqa_context_module(vqa_bullets, is_final=False))
        else:
            modules.append(cls.summary_task(has_audio=include_audio))

        modules.append(cls.CONSTRAINTS_MODULE)
        return "\n\n".join(modules)

    @classmethod
    def build_video_prompt(
        cls,
        include_vqa: bool = False,
        clip_summaries: Optional[List[str]] = None,
        questions: Optional[List[str]] = None,
        vqa_bullets: Optional[List[str]] = None,
    ) -> str:
        """Build prompt for video-level analysis."""
        modules = [cls.ROLE_MODULE]

        if not include_vqa:
            modules.append(cls.visual_captions_final_module(clip_summaries))
            modules.append(cls.summary_task())
        else:
            if not questions:
                raise ValueError(
                    "Questions must be provided when VQA should be included."
                )
            if not vqa_bullets:
                raise ValueError("Vqa_bullets must be provided for video-level VQA.")

            modules.append(cls.visual_captions_final_module(clip_summaries))
            modules.append(cls.vqa_context_module(vqa_bullets, is_final=True))
            modules.append(cls.vqa_only_task())
            modules.append(cls.questions_module(questions))

        modules.append(cls.CONSTRAINTS_MODULE)
        return "\n\n".join(modules)

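A sketch of the frame-level entry point shown above; the questions are invented placeholders, and build_frame_prompt raises a ValueError if include_vqa=True is passed without questions.

from ammico.prompt_builder import PromptBuilder

# Summary-only frame prompt
frame_prompt = PromptBuilder.build_frame_prompt()

# Frame prompt with summary + VQA (questions are mandatory here)
questions = ["How many people are visible?", "Is any text shown on screen?"]
frame_vqa_prompt = PromptBuilder.build_frame_prompt(include_vqa=True, questions=questions)

print(frame_vqa_prompt.split("\n")[0])  # the prompt starts with the role module
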
audio_module(audio_transcription) staticmethod

Audio transcription with timestamps.

Source code in ammico/prompt_builder.py, lines 69-84
@staticmethod
def audio_module(audio_transcription: List[Dict[str, Any]]) -> str:
    """Audio transcription with timestamps."""
    audio_text = "\n".join(
        [
            f"[{a['start_time']:.2f}s - {a['end_time']:.2f}s]: {a['text'].strip()}"
            for a in audio_transcription
        ]
    )
    str_to_return = f"""## Audio Information

    The following is the audio transcription for the same video segment,
    with precise timestamps for each spoken element:

    {audio_text}"""
    return str_to_return

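audio_module expects each transcription segment as a dict with start_time, end_time, and text keys; the segments below are made up for illustration.

from ammico.prompt_builder import PromptBuilder

segments = [
    {"start_time": 0.0, "end_time": 2.5, "text": "Welcome to the briefing."},
    {"start_time": 2.5, "end_time": 6.0, "text": "Today we cover three topics."},
]
print(PromptBuilder.audio_module(segments))
# Timestamps are rendered as "[0.00s - 2.50s]: Welcome to the briefing." and so on.
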
build_clip_prompt(frame_bullets, include_audio=False, audio_transcription=None, include_vqa=False, questions=None, vqa_bullets=None) classmethod

Build prompt for clip-level analysis.

Source code in ammico/prompt_builder.py, lines 228-258
@classmethod
def build_clip_prompt(
    cls,
    frame_bullets: List[str],
    include_audio: bool = False,
    audio_transcription: Optional[List[Dict]] = None,
    include_vqa: bool = False,
    questions: Optional[List[str]] = None,
    vqa_bullets: Optional[List[str]] = None,
) -> str:
    """Build prompt for clip-level analysis."""
    modules = [cls.ROLE_MODULE, cls.visual_captions_module(frame_bullets)]

    if include_audio and audio_transcription:
        modules.append(cls.audio_module(audio_transcription))

    if include_vqa and not questions:
        raise ValueError("Questions must be provided when VQA should be included.")

    if include_vqa:
        modules.append(
            cls.summary_vqa_task(ProcessingLevel.CLIP, has_audio=include_audio)
        )
        modules.append(cls.questions_module(questions))
        if vqa_bullets:
            modules.append(cls.vqa_context_module(vqa_bullets, is_final=False))
    else:
        modules.append(cls.summary_task(has_audio=include_audio))

    modules.append(cls.CONSTRAINTS_MODULE)
    return "\n\n".join(modules)

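A clip-level sketch combining frame bullets, an audio transcription, and optional frame-level VQA answers as reference context; all inputs are placeholders.

from ammico.prompt_builder import PromptBuilder

frame_bullets = ["[0.0s] A person enters the room.", "[2.0s] The person sits down."]
segments = [{"start_time": 0.0, "end_time": 4.0, "text": "Hello, everyone."}]
questions = ["Who is speaking?"]
frame_answers = ["1. Cannot be determined from provided information."]

clip_prompt = PromptBuilder.build_clip_prompt(
    frame_bullets,
    include_audio=True,
    audio_transcription=segments,
    include_vqa=True,
    questions=questions,
    vqa_bullets=frame_answers,  # optional: appended as FRAME-level answer context
)
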
build_frame_prompt(include_vqa=False, questions=None) classmethod

Build prompt for frame-level analysis.

Source code in ammico/prompt_builder.py, lines 209-226
@classmethod
def build_frame_prompt(
    cls, include_vqa: bool = False, questions: Optional[List[str]] = None
) -> str:
    """Build prompt for frame-level analysis."""
    modules = [cls.ROLE_MODULE, cls.visual_frames_module()]

    if include_vqa and not questions:
        raise ValueError("Questions must be provided when VQA should be included.")

    if include_vqa:
        modules.append(cls.summary_vqa_task(ProcessingLevel.FRAME))
        modules.append(cls.questions_module(questions))
    else:
        modules.append(cls.summary_task())

    modules.append(cls.CONSTRAINTS_MODULE)
    return "\n\n".join(modules)

build_video_prompt(include_vqa=False, clip_summaries=None, questions=None, vqa_bullets=None) classmethod

Build prompt for video-level analysis.

Source code in ammico/prompt_builder.py, lines 260-288
@classmethod
def build_video_prompt(
    cls,
    include_vqa: bool = False,
    clip_summaries: Optional[List[str]] = None,
    questions: Optional[List[str]] = None,
    vqa_bullets: Optional[List[str]] = None,
) -> str:
    """Build prompt for video-level analysis."""
    modules = [cls.ROLE_MODULE]

    if not include_vqa:
        modules.append(cls.visual_captions_final_module(clip_summaries))
        modules.append(cls.summary_task())
    else:
        if not questions:
            raise ValueError(
                "Questions must be provided when VQA should be included."
            )
        if not vqa_bullets:
            raise ValueError("Vqa_bullets must be provided for video-level VQA.")

        modules.append(cls.visual_captions_final_module(clip_summaries))
        modules.append(cls.vqa_context_module(vqa_bullets, is_final=True))
        modules.append(cls.vqa_only_task())
        modules.append(cls.questions_module(questions))

    modules.append(cls.CONSTRAINTS_MODULE)
    return "\n\n".join(modules)

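At the video level, a summary-only prompt needs just the clip summaries, while VQA mode requires both questions and the per-segment answer bullets (otherwise a ValueError is raised); a sketch with placeholder inputs.

from ammico.prompt_builder import PromptBuilder

clip_summaries = [
    "[0.0s] A person enters the room and sits down.",
    "[30.0s] A discussion between two people begins.",
]

# Summary-only video prompt
video_prompt = PromptBuilder.build_video_prompt(clip_summaries=clip_summaries)

# VQA-only video prompt: questions and vqa_bullets are both mandatory
video_vqa_prompt = PromptBuilder.build_video_prompt(
    include_vqa=True,
    clip_summaries=clip_summaries,
    questions=["Who is speaking?"],
    vqa_bullets=["1. [0.0s] Cannot be determined from provided information."],
)
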
questions_module(questions) staticmethod

Format questions list.

Source code in ammico/prompt_builder.py, lines 179-189
@staticmethod
def questions_module(questions: List[str]) -> str:
    """Format questions list."""
    questions_text = "\n".join(
        [f"{i + 1}. {q.strip()}" for i, q in enumerate(questions)]
    )
    str_to_return = f"""## Questions to Answer

    {questions_text}
    """
    return str_to_return

summary_task(has_audio=False) staticmethod

Generate summary task (with or without audio).

Source code in ammico/prompt_builder.py, lines 86-101
@staticmethod
def summary_task(has_audio: bool = False) -> str:
    """Generate summary task (with or without audio)."""
    sources = "visual and audio information" if has_audio else "visual information"
    str_to_return = f"""## Task: Generate Concise Summary

    Based on the {sources} provided, generate a brief summary that:
    - Captures and summarizes the main events and themes
    - Uses clear, accessible language
    - Is between 1-3 sentences
    - Contains no unsupported claims

    Return ONLY this format:

    Summary: <your summary here>"""
    return str_to_return

summary_vqa_task(level, has_audio=False) staticmethod

Generate summary+VQA task (adapts based on level and audio). For Frame and Clip levels.

Source code in ammico/prompt_builder.py, lines 103-153
@staticmethod
def summary_vqa_task(
    level: ProcessingLevel,
    has_audio: bool = False,
) -> str:
    """Generate summary+VQA task (adapts based on level and audio). For Frame and Clip levels."""

    if level == ProcessingLevel.FRAME:
        vqa_task = """Answer the provided questions based ONLY on information from the visual information. Answers must be brief and direct.

        **Critical Rule:** If you cannot answer a question from the provided sources,
        respond with: "Cannot be determined from provided information."
        """
    elif level == ProcessingLevel.CLIP:
        if has_audio:
            priority_list = """
                1. **Frame-Level Answers** - Pre-computed per-frame answers (auxiliary reference only)
                2. **Audio Information** - Spoken content from audio transcription
                3. **Visual Information** - Direct visual content from video frames"""
        else:
            priority_list = """1. **Frame-Level Answers** - Pre-computed per-frame answers (auxiliary reference only)
                2. **Visual Information** - Direct visual content from video frames"""
        vqa_task = f"""For each question, use the BEST available source in this priority:
            {priority_list}

            **Critical Logic:**
            - If frame-level answer is a REAL answer (not "Cannot be determined from provided information.") → use it
            - If frame-level answer is "Cannot be determined" → SKIP IT and check audio/visual instead
            - If answer is found in visual OR audio information → use that
            - ONLY respond "Cannot be determined" if truly no information exists anywhere

            """

    str_to_return = f"""## You have two tasks:

    ### task 1: Concise Summary
    Generate a brief summary that captures and summarizes main events and themes from the visual information (1-3 sentences).

    ### task 2: Question Answering
    {vqa_task}

    Return ONLY this format:

    Summary: <your summary here>

    VQA Answers:
    1. <answer to question 1>
    2. <answer to question 2>
    [etc.]"""

    return str_to_return

visual_captions_final_module(clip_summaries) staticmethod

For video-level processing with clip summaries.

Source code in ammico/prompt_builder.py, lines 57-67
@staticmethod
def visual_captions_final_module(clip_summaries: List[str]) -> str:
    """For video-level processing with clip summaries."""
    str_to_return = f"""## Visual Information

    The following are brief summaries obtained for each segment of the video.
    These summaries are associated with the timestamp of each segment's beginning:

    {clip_summaries}"""

    return str_to_return

visual_captions_module(frame_bullets) staticmethod

For clip-level processing with frame summaries.

Source code in ammico/prompt_builder.py, lines 44-55
@staticmethod
def visual_captions_module(frame_bullets: List[str]) -> str:
    """For clip-level processing with frame summaries."""
    bullets_text = "\n".join(frame_bullets)
    str_to_return = f"""## Visual Information

    The following are summary bullets extracted from video frames with timestamps.
    These bullets represent the visual content detected in each frame of the video segment:

    {bullets_text}"""

    return str_to_return

visual_frames_module() staticmethod

For frame-level processing with actual images.

Source code in ammico/prompt_builder.py, lines 36-42
@staticmethod
def visual_frames_module() -> str:
    """For frame-level processing with actual images."""
    str_to_return = """## Visual Information

    Visual information is represented by keyframe extracted from the video segment at specified timestamp."""
    return str_to_return

vqa_context_module(vqa_bullets, is_final=False) staticmethod

VQA context (frame-level or clip-level answers).

Source code in ammico/prompt_builder.py, lines 191-207
@staticmethod
def vqa_context_module(vqa_bullets: List[str], is_final: bool = False) -> str:
    """VQA context (frame-level or clip-level answers)."""
    if is_final:
        header = """## SEGMENT-Level Answer Context (Reference Only)

        The following are answers to above questions obtained for each segment of the video.
        These answers are associated with the timestamp of each segment's beginning:"""
    else:
        header = """## FRAME-Level Answer Context (Reference Only)

        For each question, the following are frame-level answers provided as reference.
        If these answers are "Cannot be determined", do not accept that as final—instead, 
        use visual and audio information to answer the question:"""

    bullets_text = "\n".join(vqa_bullets)
    return f"{header}\n\n{bullets_text}"

vqa_only_task() staticmethod

VQA-only task for video-level processing.

Source code in ammico/prompt_builder.py, lines 155-177
@staticmethod
def vqa_only_task() -> str:
    """VQA-only task for video-level processing."""
    str_to_return = """## Task: Answer Questions

    For each question, use the BEST available source in this priority:
        1. **Segment-Level Answers** - Pre-computed per-frame answers (auxiliary reference only)
        2. **Visual Information** - Direct visual content from video frames

    **Critical Logic:**
            - If segment-level answer is a REAL answer (not "Cannot be determined from provided information.") → use it
            - If frame-level answer is "Cannot be determined" → SKIP IT and check visual information instead
            - If answer is found in visual information → use that
            - ONLY respond "Cannot be determined" if truly no information exists anywhere

    Return ONLY this format:

    VQA Answers:
    1. <answer to question 1>
    2. <answer to question 2>
    [etc.]"""

    return str_to_return