langfuse._task_manager.media_manager

  1import logging
  2import time
  3from queue import Empty, Queue
  4from typing import Any, Callable, Optional, TypeVar
  5
  6import backoff
  7import requests
  8from typing_extensions import ParamSpec
  9
 10from langfuse.api import GetMediaUploadUrlRequest, PatchMediaBody
 11from langfuse.api.client import FernLangfuse
 12from langfuse.api.core import ApiError
 13from langfuse.media import LangfuseMedia
 14from langfuse.utils import _get_timestamp
 15
 16from .media_upload_queue import UploadMediaJob
 17
 18T = TypeVar("T")
 19P = ParamSpec("P")
 20
 21
 22class MediaManager:
 23    _log = logging.getLogger(__name__)
 24
 25    def __init__(
 26        self,
 27        *,
 28        api_client: FernLangfuse,
 29        media_upload_queue: Queue,
 30        max_retries: Optional[int] = 3,
 31    ):
 32        self._api_client = api_client
 33        self._queue = media_upload_queue
 34        self._max_retries = max_retries
 35
 36    def process_next_media_upload(self):
 37        try:
 38            upload_job = self._queue.get(block=True, timeout=1)
 39            self._log.debug(f"Processing upload for {upload_job['media_id']}")
 40            self._process_upload_media_job(data=upload_job)
 41
 42            self._queue.task_done()
 43        except Empty:
 44            self._log.debug("Media upload queue is empty")
 45            pass
 46        except Exception as e:
 47            self._log.error(f"Error uploading media: {e}")
 48            self._queue.task_done()
 49
 50    def process_media_in_event(self, event: dict):
 51        try:
 52            if "body" not in event:
 53                return
 54
 55            body = event["body"]
 56            trace_id = body.get("traceId", None) or (
 57                body.get("id", None)
 58                if "type" in event and "trace" in event["type"]
 59                else None
 60            )
 61
 62            if trace_id is None:
 63                raise ValueError("trace_id is required for media upload")
 64
 65            observation_id = (
 66                body.get("id", None)
 67                if "type" in event
 68                and ("generation" in event["type"] or "span" in event["type"])
 69                else None
 70            )
 71
 72            multimodal_fields = ["input", "output", "metadata"]
 73
 74            for field in multimodal_fields:
 75                if field in body:
 76                    processed_data = self._find_and_process_media(
 77                        data=body[field],
 78                        trace_id=trace_id,
 79                        observation_id=observation_id,
 80                        field=field,
 81                    )
 82
 83                    body[field] = processed_data
 84
 85        except Exception as e:
 86            self._log.error(f"Error processing multimodal event: {e}")
 87
 88    def _find_and_process_media(
 89        self,
 90        *,
 91        data: Any,
 92        trace_id: str,
 93        observation_id: Optional[str],
 94        field: str,
 95    ):
 96        seen = set()
 97        max_levels = 10
 98
 99        def _process_data_recursively(data: Any, level: int):
100            if id(data) in seen or level > max_levels:
101                return data
102
103            seen.add(id(data))
104
105            if isinstance(data, LangfuseMedia):
106                self._process_media(
107                    media=data,
108                    trace_id=trace_id,
109                    observation_id=observation_id,
110                    field=field,
111                )
112
113                return data
114
115            if isinstance(data, str) and data.startswith("data:"):
116                media = LangfuseMedia(
117                    obj=data,
118                    base64_data_uri=data,
119                )
120
121                self._process_media(
122                    media=media,
123                    trace_id=trace_id,
124                    observation_id=observation_id,
125                    field=field,
126                )
127
128                return media
129
130            if isinstance(data, list):
131                return [_process_data_recursively(item, level + 1) for item in data]
132
133            if isinstance(data, dict):
134                return {
135                    key: _process_data_recursively(value, level + 1)
136                    for key, value in data.items()
137                }
138
139            return data
140
141        return _process_data_recursively(data, 1)
142
143    def _process_media(
144        self,
145        *,
146        media: LangfuseMedia,
147        trace_id: str,
148        observation_id: Optional[str],
149        field: str,
150    ):
151        if (
152            media._content_length is None
153            or media._content_type is None
154            or media._content_sha256_hash is None
155            or media._content_bytes is None
156        ):
157            return
158
159        upload_url_response = self._request_with_backoff(
160            self._api_client.media.get_upload_url,
161            request=GetMediaUploadUrlRequest(
162                contentLength=media._content_length,
163                contentType=media._content_type,
164                sha256Hash=media._content_sha256_hash,
165                field=field,
166                traceId=trace_id,
167                observationId=observation_id,
168            ),
169        )
170
171        upload_url = upload_url_response.upload_url
172        media._media_id = upload_url_response.media_id  # Important as this is will be used in the media reference string in serializer
173
174        if upload_url is not None:
175            self._log.debug(f"Scheduling upload for {media._media_id}")
176            self._queue.put(
177                item={
178                    "upload_url": upload_url,
179                    "media_id": media._media_id,
180                    "content_bytes": media._content_bytes,
181                    "content_type": media._content_type,
182                    "content_sha256_hash": media._content_sha256_hash,
183                },
184                block=True,
185                timeout=1,
186            )
187
188        else:
189            self._log.debug(f"Media {media._media_id} already uploaded")
190
191    def _process_upload_media_job(
192        self,
193        *,
194        data: UploadMediaJob,
195    ):
196        upload_start_time = time.time()
197        upload_response = self._request_with_backoff(
198            requests.put,
199            data["upload_url"],
200            headers={
201                "Content-Type": data["content_type"],
202                "x-amz-checksum-sha256": data["content_sha256_hash"],
203            },
204            data=data["content_bytes"],
205        )
206        upload_time_ms = int((time.time() - upload_start_time) * 1000)
207
208        self._request_with_backoff(
209            self._api_client.media.patch,
210            media_id=data["media_id"],
211            request=PatchMediaBody(
212                uploadedAt=_get_timestamp(),
213                uploadHttpStatus=upload_response.status_code,
214                uploadHttpError=upload_response.text,
215                uploadTimeMs=upload_time_ms,
216            ),
217        )
218
219        self._log.debug(
220            f"Media upload completed for {data['media_id']} in {upload_time_ms}ms"
221        )
222
223    def _request_with_backoff(
224        self, func: Callable[P, T], *args: P.args, **kwargs: P.kwargs
225    ) -> T:
226        @backoff.on_exception(
227            backoff.expo, Exception, max_tries=self._max_retries, logger=None
228        )
229        def execute_task_with_backoff() -> T:
230            try:
231                return func(*args, **kwargs)
232            except ApiError as e:
233                if (
234                    e.status_code is not None
235                    and 400 <= e.status_code < 500
236                    and (e.status_code) != 429
237                ):
238                    raise e
239            except Exception as e:
240                raise e
241
242            raise Exception("Failed to execute task")
243
244        return execute_task_with_backoff()
P = ~P
class MediaManager:
 23class MediaManager:
 24    _log = logging.getLogger(__name__)
 25
 26    def __init__(
 27        self,
 28        *,
 29        api_client: FernLangfuse,
 30        media_upload_queue: Queue,
 31        max_retries: Optional[int] = 3,
 32    ):
 33        self._api_client = api_client
 34        self._queue = media_upload_queue
 35        self._max_retries = max_retries
 36
 37    def process_next_media_upload(self):
 38        try:
 39            upload_job = self._queue.get(block=True, timeout=1)
 40            self._log.debug(f"Processing upload for {upload_job['media_id']}")
 41            self._process_upload_media_job(data=upload_job)
 42
 43            self._queue.task_done()
 44        except Empty:
 45            self._log.debug("Media upload queue is empty")
 46            pass
 47        except Exception as e:
 48            self._log.error(f"Error uploading media: {e}")
 49            self._queue.task_done()
 50
 51    def process_media_in_event(self, event: dict):
 52        try:
 53            if "body" not in event:
 54                return
 55
 56            body = event["body"]
 57            trace_id = body.get("traceId", None) or (
 58                body.get("id", None)
 59                if "type" in event and "trace" in event["type"]
 60                else None
 61            )
 62
 63            if trace_id is None:
 64                raise ValueError("trace_id is required for media upload")
 65
 66            observation_id = (
 67                body.get("id", None)
 68                if "type" in event
 69                and ("generation" in event["type"] or "span" in event["type"])
 70                else None
 71            )
 72
 73            multimodal_fields = ["input", "output", "metadata"]
 74
 75            for field in multimodal_fields:
 76                if field in body:
 77                    processed_data = self._find_and_process_media(
 78                        data=body[field],
 79                        trace_id=trace_id,
 80                        observation_id=observation_id,
 81                        field=field,
 82                    )
 83
 84                    body[field] = processed_data
 85
 86        except Exception as e:
 87            self._log.error(f"Error processing multimodal event: {e}")
 88
 89    def _find_and_process_media(
 90        self,
 91        *,
 92        data: Any,
 93        trace_id: str,
 94        observation_id: Optional[str],
 95        field: str,
 96    ):
 97        seen = set()
 98        max_levels = 10
 99
100        def _process_data_recursively(data: Any, level: int):
101            if id(data) in seen or level > max_levels:
102                return data
103
104            seen.add(id(data))
105
106            if isinstance(data, LangfuseMedia):
107                self._process_media(
108                    media=data,
109                    trace_id=trace_id,
110                    observation_id=observation_id,
111                    field=field,
112                )
113
114                return data
115
116            if isinstance(data, str) and data.startswith("data:"):
117                media = LangfuseMedia(
118                    obj=data,
119                    base64_data_uri=data,
120                )
121
122                self._process_media(
123                    media=media,
124                    trace_id=trace_id,
125                    observation_id=observation_id,
126                    field=field,
127                )
128
129                return media
130
131            if isinstance(data, list):
132                return [_process_data_recursively(item, level + 1) for item in data]
133
134            if isinstance(data, dict):
135                return {
136                    key: _process_data_recursively(value, level + 1)
137                    for key, value in data.items()
138                }
139
140            return data
141
142        return _process_data_recursively(data, 1)
143
144    def _process_media(
145        self,
146        *,
147        media: LangfuseMedia,
148        trace_id: str,
149        observation_id: Optional[str],
150        field: str,
151    ):
152        if (
153            media._content_length is None
154            or media._content_type is None
155            or media._content_sha256_hash is None
156            or media._content_bytes is None
157        ):
158            return
159
160        upload_url_response = self._request_with_backoff(
161            self._api_client.media.get_upload_url,
162            request=GetMediaUploadUrlRequest(
163                contentLength=media._content_length,
164                contentType=media._content_type,
165                sha256Hash=media._content_sha256_hash,
166                field=field,
167                traceId=trace_id,
168                observationId=observation_id,
169            ),
170        )
171
172        upload_url = upload_url_response.upload_url
173        media._media_id = upload_url_response.media_id  # Important as this is will be used in the media reference string in serializer
174
175        if upload_url is not None:
176            self._log.debug(f"Scheduling upload for {media._media_id}")
177            self._queue.put(
178                item={
179                    "upload_url": upload_url,
180                    "media_id": media._media_id,
181                    "content_bytes": media._content_bytes,
182                    "content_type": media._content_type,
183                    "content_sha256_hash": media._content_sha256_hash,
184                },
185                block=True,
186                timeout=1,
187            )
188
189        else:
190            self._log.debug(f"Media {media._media_id} already uploaded")
191
192    def _process_upload_media_job(
193        self,
194        *,
195        data: UploadMediaJob,
196    ):
197        upload_start_time = time.time()
198        upload_response = self._request_with_backoff(
199            requests.put,
200            data["upload_url"],
201            headers={
202                "Content-Type": data["content_type"],
203                "x-amz-checksum-sha256": data["content_sha256_hash"],
204            },
205            data=data["content_bytes"],
206        )
207        upload_time_ms = int((time.time() - upload_start_time) * 1000)
208
209        self._request_with_backoff(
210            self._api_client.media.patch,
211            media_id=data["media_id"],
212            request=PatchMediaBody(
213                uploadedAt=_get_timestamp(),
214                uploadHttpStatus=upload_response.status_code,
215                uploadHttpError=upload_response.text,
216                uploadTimeMs=upload_time_ms,
217            ),
218        )
219
220        self._log.debug(
221            f"Media upload completed for {data['media_id']} in {upload_time_ms}ms"
222        )
223
224    def _request_with_backoff(
225        self, func: Callable[P, T], *args: P.args, **kwargs: P.kwargs
226    ) -> T:
227        @backoff.on_exception(
228            backoff.expo, Exception, max_tries=self._max_retries, logger=None
229        )
230        def execute_task_with_backoff() -> T:
231            try:
232                return func(*args, **kwargs)
233            except ApiError as e:
234                if (
235                    e.status_code is not None
236                    and 400 <= e.status_code < 500
237                    and (e.status_code) != 429
238                ):
239                    raise e
240            except Exception as e:
241                raise e
242
243            raise Exception("Failed to execute task")
244
245        return execute_task_with_backoff()
MediaManager( *, api_client: langfuse.api.client.FernLangfuse, media_upload_queue: queue.Queue, max_retries: Optional[int] = 3)
26    def __init__(
27        self,
28        *,
29        api_client: FernLangfuse,
30        media_upload_queue: Queue,
31        max_retries: Optional[int] = 3,
32    ):
33        self._api_client = api_client
34        self._queue = media_upload_queue
35        self._max_retries = max_retries
def process_next_media_upload(self):
37    def process_next_media_upload(self):
38        try:
39            upload_job = self._queue.get(block=True, timeout=1)
40            self._log.debug(f"Processing upload for {upload_job['media_id']}")
41            self._process_upload_media_job(data=upload_job)
42
43            self._queue.task_done()
44        except Empty:
45            self._log.debug("Media upload queue is empty")
46            pass
47        except Exception as e:
48            self._log.error(f"Error uploading media: {e}")
49            self._queue.task_done()
def process_media_in_event(self, event: dict):
51    def process_media_in_event(self, event: dict):
52        try:
53            if "body" not in event:
54                return
55
56            body = event["body"]
57            trace_id = body.get("traceId", None) or (
58                body.get("id", None)
59                if "type" in event and "trace" in event["type"]
60                else None
61            )
62
63            if trace_id is None:
64                raise ValueError("trace_id is required for media upload")
65
66            observation_id = (
67                body.get("id", None)
68                if "type" in event
69                and ("generation" in event["type"] or "span" in event["type"])
70                else None
71            )
72
73            multimodal_fields = ["input", "output", "metadata"]
74
75            for field in multimodal_fields:
76                if field in body:
77                    processed_data = self._find_and_process_media(
78                        data=body[field],
79                        trace_id=trace_id,
80                        observation_id=observation_id,
81                        field=field,
82                    )
83
84                    body[field] = processed_data
85
86        except Exception as e:
87            self._log.error(f"Error processing multimodal event: {e}")