Ver código fonte

Fix importing TSV files with S3

Colin Powell 2 anos atrás
pai
commit
9a2ba1fd07

+ 8 - 1
vrobbler/apps/scrobbles/models.py

@@ -5,6 +5,7 @@ from typing import Optional
 from uuid import uuid4
 
 from books.models import Book
+from django.conf import settings
 from django.contrib.auth import get_user_model
 from django.db import models
 from django.urls import reverse
@@ -193,10 +194,16 @@ class AudioScrobblerTSVImport(BaseFileImportMixin):
         self.mark_started()
 
         tz = None
+        user_id = None
         if self.user:
+            user_id = self.user.id
             tz = self.user.profile.tzinfo
+        if getattr(settings, "USE_S3_STORAGE"):
+            tsv_str = self.tsv_file.url
+        else:
+            tsv_str = self.tsv_file.path
         scrobbles = process_audioscrobbler_tsv_file(
-            self.tsv_file.path, self.user.id, user_tz=tz
+            tsv_str, user_id, user_tz=tz
         )
         self.record_log(scrobbles)
         self.mark_finished()

+ 69 - 63
vrobbler/apps/scrobbles/tsv.py

@@ -1,8 +1,10 @@
+import codecs
 import csv
 import logging
 from datetime import datetime
 
 import pytz
+import requests
 from music.utils import (
     get_or_create_album,
     get_or_create_artist,
@@ -20,75 +22,79 @@ def process_audioscrobbler_tsv_file(file_path, user_id, user_tz=None):
     if not user_tz:
         user_tz = pytz.utc
 
-    with open(file_path) as infile:
-        source = "Audioscrobbler File"
-        rows = csv.reader(infile, delimiter="\t")
+    is_os_file = "https://" not in file_path
 
-        source_id = ""
-        for row_num, row in enumerate(rows):
-            if row_num in [0, 1, 2]:
-                if "Rockbox" in row[0]:
-                    source = "Rockbox"
-                source_id += row[0] + "\n"
-                continue
-            if len(row) > 8:
-                logger.warning(
-                    "Improper row length during Audioscrobbler import",
-                    extra={"row": row},
-                )
-                continue
+    if not is_os_file:
+        r = requests.get(file_path)
+        tsv_data = codecs.iterdecode(r.iter_lines(), "utf-8")
+    else:
+        tsv_data = open(file_path)
 
-            artist = get_or_create_artist(
-                row[AsTsvColumn["ARTIST_NAME"].value]
-            )
-            album = get_or_create_album(
-                row[AsTsvColumn["ALBUM_NAME"].value], artist
-            )
+    source = "Audioscrobbler File"
+    rows = csv.reader(tsv_data, delimiter="\t")
 
-            track = get_or_create_track(
-                title=row[AsTsvColumn["TRACK_NAME"].value],
-                mbid=row[AsTsvColumn["MB_ID"].value],
-                artist=artist,
-                album=album,
-                run_time_seconds=int(
-                    row[AsTsvColumn["RUN_TIME_SECONDS"].value]
-                ),
+    source_id = ""
+    for row_num, row in enumerate(rows):
+        if row_num in [0, 1, 2]:
+            if "Rockbox" in row[0]:
+                source = "Rockbox"
+            source_id += row[0] + "\n"
+            continue
+        if len(row) > 8:
+            logger.warning(
+                "Improper row length during Audioscrobbler import",
+                extra={"row": row},
             )
-            if row[AsTsvColumn["COMPLETE"].value] == "S":
-                logger.info(
-                    f"Skipping track {track} by {artist} because not finished"
-                )
-                continue
+            continue
 
-            timestamp = (
-                datetime.utcfromtimestamp(
-                    int(row[AsTsvColumn["TIMESTAMP"].value])
-                )
-                .replace(tzinfo=user_tz)
-                .astimezone(pytz.utc)
-            )
+        artist = get_or_create_artist(row[AsTsvColumn["ARTIST_NAME"].value])
+        album = get_or_create_album(
+            row[AsTsvColumn["ALBUM_NAME"].value], artist
+        )
 
-            new_scrobble = Scrobble(
-                user_id=user_id,
-                timestamp=timestamp,
-                source=source,
-                source_id=source_id,
-                track=track,
-                played_to_completion=True,
-                in_progress=False,
+        track = get_or_create_track(
+            title=row[AsTsvColumn["TRACK_NAME"].value],
+            mbid=row[AsTsvColumn["MB_ID"].value],
+            artist=artist,
+            album=album,
+            run_time_seconds=int(row[AsTsvColumn["RUN_TIME_SECONDS"].value]),
+        )
+        if row[AsTsvColumn["COMPLETE"].value] == "S":
+            logger.info(
+                f"Skipping track {track} by {artist} because not finished"
             )
-            existing = Scrobble.objects.filter(
-                timestamp=timestamp, track=track
-            ).first()
-            if existing:
-                logger.debug(f"Skipping existing scrobble {new_scrobble}")
-                continue
-            logger.debug(f"Queued scrobble {new_scrobble} for creation")
-            new_scrobbles.append(new_scrobble)
+            continue
 
-        created = Scrobble.objects.bulk_create(new_scrobbles)
-        logger.info(
-            f"Created {len(created)} scrobbles",
-            extra={"created_scrobbles": created},
+        timestamp = (
+            datetime.utcfromtimestamp(int(row[AsTsvColumn["TIMESTAMP"].value]))
+            .replace(tzinfo=user_tz)
+            .astimezone(pytz.utc)
         )
-        return created
+
+        new_scrobble = Scrobble(
+            user_id=user_id,
+            timestamp=timestamp,
+            source=source,
+            source_id=source_id,
+            track=track,
+            played_to_completion=True,
+            in_progress=False,
+        )
+        existing = Scrobble.objects.filter(
+            timestamp=timestamp, track=track
+        ).first()
+        if existing:
+            logger.debug(f"Skipping existing scrobble {new_scrobble}")
+            continue
+        logger.debug(f"Queued scrobble {new_scrobble} for creation")
+        new_scrobbles.append(new_scrobble)
+
+    if is_os_file:
+        tsv_data.close()
+
+    created = Scrobble.objects.bulk_create(new_scrobbles)
+    logger.info(
+        f"Created {len(created)} scrobbles",
+        extra={"created_scrobbles": created},
+    )
+    return created

+ 3 - 1
vrobbler/settings.py

@@ -248,7 +248,9 @@ USE_TZ = True
 #
 from storages.backends import s3boto3
 
-if os.getenv("VROBBLER_USE_S3", "False").lower() in TRUTHY:
+USE_S3_STORAGE = os.getenv("VROBBLER_USE_S3", "False").lower() in TRUTHY
+
+if USE_S3_STORAGE:
     AWS_S3_ENDPOINT_URL = os.getenv("AWS_S3_ENDPOINT_URL", "")
     AWS_STORAGE_BUCKET_NAME = os.getenv("AWS_STORAGE_BUCKET_NAME", "")
     AWS_S3_ACCESS_KEY_ID = os.getenv("AWS_S3_ACCESS_KEY_ID")