403 lines
11 KiB
Python
403 lines
11 KiB
Python
"""CRUD operations for audio tracks."""
|
|
from typing import List, Optional, Dict, Tuple
|
|
from uuid import UUID
|
|
from sqlalchemy.orm import Session
|
|
from sqlalchemy import or_, and_, func, any_
|
|
|
|
from .schema import AudioTrack
|
|
from ..core.analyzer import AudioAnalysis
|
|
from ..utils.logging import get_logger
|
|
|
|
# Module-level logger obtained from the project's get_logger helper using this
# module's name; the CRUD functions in this file log through it.
logger = get_logger(__name__)
|
|
|
|
|
|
def create_track(db: Session, analysis: AudioAnalysis) -> AudioTrack:
    """Persist a new AudioTrack built from an analysis result.

    Each listed attribute is read from ``analysis`` and passed to the
    ``AudioTrack`` constructor under the same name. The one exception is
    ``analysis.metadata``, which is passed as ``extra_metadata``.

    Args:
        db: Database session
        analysis: AudioAnalysis object to copy values from

    Returns:
        The AudioTrack after it has been added, committed and refreshed
    """
    # Attributes that carry the same name on the analysis and on the track.
    same_name_fields = (
        "filepath",
        "filename",
        "duration_seconds",
        "file_size_bytes",
        "format",
        "analyzed_at",
        # Features
        "tempo_bpm",
        "key",
        "time_signature",
        "energy",
        "danceability",
        "valence",
        "loudness_lufs",
        "spectral_centroid",
        "zero_crossing_rate",
        # Classification
        "genre_primary",
        "genre_secondary",
        "genre_confidence",
        "mood_primary",
        "mood_secondary",
        "mood_arousal",
        "mood_valence",
        "instruments",
        # Vocals
        "has_vocals",
        "vocal_gender",
    )
    constructor_kwargs = {
        name: getattr(analysis, name) for name in same_name_fields
    }
    # Metadata is the only value stored under a different attribute name.
    constructor_kwargs["extra_metadata"] = analysis.metadata

    new_track = AudioTrack(**constructor_kwargs)
    db.add(new_track)
    db.commit()
    db.refresh(new_track)

    logger.info(f"Created track: {new_track.id} - {new_track.filename}")
    return new_track
|
|
|
|
|
|
def get_track_by_id(db: Session, track_id: UUID) -> Optional[AudioTrack]:
    """Fetch the first track whose ``id`` equals ``track_id``.

    Args:
        db: Database session
        track_id: Track UUID

    Returns:
        The matching AudioTrack, or None when no row matches
    """
    id_matches = AudioTrack.id == track_id
    lookup = db.query(AudioTrack).filter(id_matches)
    return lookup.first()
|
|
|
|
|
|
def get_track_by_filepath(db: Session, filepath: str) -> Optional[AudioTrack]:
    """Fetch the first track whose ``filepath`` equals the given path.

    Args:
        db: Database session
        filepath: File path to look up (compared for equality)

    Returns:
        The matching AudioTrack, or None when no row matches
    """
    path_matches = AudioTrack.filepath == filepath
    lookup = db.query(AudioTrack).filter(path_matches)
    return lookup.first()
|
|
|
|
|
|
def get_tracks(
    db: Session,
    skip: int = 0,
    limit: int = 100,
    genre: Optional[str] = None,
    mood: Optional[str] = None,
    bpm_min: Optional[float] = None,
    bpm_max: Optional[float] = None,
    energy_min: Optional[float] = None,
    energy_max: Optional[float] = None,
    has_vocals: Optional[bool] = None,
    key: Optional[str] = None,
    instrument: Optional[str] = None,
    sort_by: str = "analyzed_at",
    sort_desc: bool = True,
) -> Tuple[List[AudioTrack], int]:
    """Get tracks with filters and pagination.

    Args:
        db: Database session
        skip: Number of records to skip
        limit: Maximum number of records to return
        genre: Filter by genre. Matches when ``genre_primary`` starts with
            this text (e.g. "Pop" matches "Pop---Ballad") or when it is an
            element of ``genre_secondary``. The text is matched literally;
            ``%`` and ``_`` are not treated as wildcards.
        mood: Filter by mood (``mood_primary`` or an element of
            ``mood_secondary``)
        bpm_min: Minimum BPM (inclusive)
        bpm_max: Maximum BPM (inclusive)
        energy_min: Minimum energy (0-1, inclusive)
        energy_max: Maximum energy (0-1, inclusive)
        has_vocals: Filter by vocal presence
        key: Filter by musical key
        instrument: Filter by instrument (element of ``instruments``)
        sort_by: Name of a mapped column attribute of AudioTrack to sort
            by. Any other value is logged and replaced by "analyzed_at".
        sort_desc: Sort descending if True

    Returns:
        Tuple of (tracks list, total count). The count is taken after
        filtering and before pagination.
    """
    # Imported locally so this function does not rely on a change to the
    # module's import block.
    from sqlalchemy import inspect as sa_inspect

    query = db.query(AudioTrack)

    # Apply filters
    if genre:
        query = query.filter(
            or_(
                # autoescape keeps "%" and "_" in the caller's text literal.
                # A prefix match also covers exact equality, so no separate
                # "== genre" clause is needed.
                AudioTrack.genre_primary.startswith(genre, autoescape=True),
                AudioTrack.genre_secondary.any(genre),
            )
        )

    if mood:
        query = query.filter(
            or_(
                AudioTrack.mood_primary == mood,
                AudioTrack.mood_secondary.any(mood),
            )
        )

    if bpm_min is not None:
        query = query.filter(AudioTrack.tempo_bpm >= bpm_min)

    if bpm_max is not None:
        query = query.filter(AudioTrack.tempo_bpm <= bpm_max)

    if energy_min is not None:
        query = query.filter(AudioTrack.energy >= energy_min)

    if energy_max is not None:
        query = query.filter(AudioTrack.energy <= energy_max)

    if has_vocals is not None:
        query = query.filter(AudioTrack.has_vocals == has_vocals)

    if key:
        query = query.filter(AudioTrack.key == key)

    if instrument:
        query = query.filter(AudioTrack.instruments.any(instrument))

    # Get total count before pagination
    total = query.count()

    # Only mapped column attributes are valid sort keys. A plain hasattr()
    # check would also accept methods and other class attributes, which have
    # no .asc()/.desc() and would raise at this point.
    sortable_fields = sa_inspect(AudioTrack).column_attrs.keys()
    if sort_by not in sortable_fields:
        logger.warning(
            f"Unknown sort field {sort_by!r}; falling back to 'analyzed_at'"
        )
        sort_by = "analyzed_at"

    sort_column = getattr(AudioTrack, sort_by)
    primary_order = sort_column.desc() if sort_desc else sort_column.asc()

    # The id tie-breaker keeps OFFSET/LIMIT pages stable when several rows
    # share the same value in the primary sort column.
    query = query.order_by(primary_order, AudioTrack.id)

    # Apply pagination
    tracks = query.offset(skip).limit(limit).all()

    return tracks, total
|
|
|
|
|
|
def search_tracks(
    db: Session,
    query: str,
    genre: Optional[str] = None,
    mood: Optional[str] = None,
    limit: int = 100,
) -> List[AudioTrack]:
    """Search tracks by text query.

    A track matches when the lower-cased query is a substring of its
    lower-cased ``filename``, ``genre_primary`` or ``mood_primary``, or is
    equal to one element of ``instruments``. The query is matched literally;
    ``%`` and ``_`` are not treated as wildcards.

    Args:
        db: Database session
        query: Search query string
        genre: Optional genre filter (exact ``genre_primary`` or an element
            of ``genre_secondary``)
        mood: Optional mood filter (exact ``mood_primary`` or an element of
            ``mood_secondary``)
        limit: Maximum results

    Returns:
        List of matching AudioTrack instances, most recently analyzed first
    """
    needle = query.lower()

    # Text search on multiple fields. autoescape=True makes SQLAlchemy escape
    # LIKE wildcards contained in the user's text.
    search_query = db.query(AudioTrack).filter(
        or_(
            func.lower(AudioTrack.filename).contains(needle, autoescape=True),
            func.lower(AudioTrack.genre_primary).contains(
                needle, autoescape=True
            ),
            func.lower(AudioTrack.mood_primary).contains(
                needle, autoescape=True
            ),
            # Bound-parameter element match. Formatting the text into an
            # array literal ('{...}') breaks or changes meaning when the
            # query contains commas, braces or quotes.
            AudioTrack.instruments.any(needle),
        )
    )

    # Apply additional filters
    if genre:
        search_query = search_query.filter(
            or_(
                AudioTrack.genre_primary == genre,
                AudioTrack.genre_secondary.any(genre),
            )
        )

    if mood:
        search_query = search_query.filter(
            or_(
                AudioTrack.mood_primary == mood,
                AudioTrack.mood_secondary.any(mood),
            )
        )

    # No relevance ranking is applied; newest analyses come first.
    search_query = search_query.order_by(AudioTrack.analyzed_at.desc())

    return search_query.limit(limit).all()
|
|
|
|
|
|
def get_similar_tracks(
    db: Session,
    track_id: UUID,
    limit: int = 10,
) -> List[AudioTrack]:
    """Find tracks that resemble a reference track.

    Args:
        db: Database session
        track_id: Reference track ID
        limit: Maximum results

    Returns:
        List of similar AudioTrack instances (empty when the reference
        track does not exist)

    Note:
        If embeddings are available, uses vector similarity.
        Otherwise, falls back to genre + mood + BPM similarity.
        Only the fallback is implemented at present.
    """
    reference = get_track_by_id(db, track_id)
    if not reference:
        return []

    # TODO: Implement vector similarity when embeddings are available
    # For now, use genre + mood + BPM similarity

    # Every criterion below must hold; the reference itself is excluded.
    criteria = [AudioTrack.id != track_id]

    # Same genre (primary or secondary)
    ref_genre = reference.genre_primary
    if ref_genre:
        criteria.append(
            or_(
                AudioTrack.genre_primary == ref_genre,
                AudioTrack.genre_secondary.any(ref_genre),
            )
        )

    # Similar mood
    ref_mood = reference.mood_primary
    if ref_mood:
        criteria.append(
            or_(
                AudioTrack.mood_primary == ref_mood,
                AudioTrack.mood_secondary.any(ref_mood),
            )
        )

    # Similar BPM (±10%)
    ref_bpm = reference.tempo_bpm
    if ref_bpm:
        tolerance = ref_bpm * 0.1
        criteria.append(
            and_(
                AudioTrack.tempo_bpm >= ref_bpm - tolerance,
                AudioTrack.tempo_bpm <= ref_bpm + tolerance,
            )
        )

    # Order by analyzed_at (could be improved with similarity score)
    candidates = (
        db.query(AudioTrack)
        .filter(*criteria)
        .order_by(AudioTrack.analyzed_at.desc())
    )
    return candidates.limit(limit).all()
|
|
|
|
|
|
def delete_track(db: Session, track_id: UUID) -> bool:
    """Remove the track with the given ID, if it exists.

    Args:
        db: Database session
        track_id: Track UUID

    Returns:
        True if deleted, False if not found
    """
    doomed = get_track_by_id(db, track_id)
    if doomed:
        db.delete(doomed)
        db.commit()
        logger.info(f"Deleted track: {track_id}")
        return True

    # Nothing matched, so there is nothing to delete or commit.
    return False
|
|
|
|
|
|
def get_stats(db: Session) -> Dict:
    """Collect summary statistics over all stored tracks.

    Args:
        db: Database session

    Returns:
        Dictionary with statistics: ``total_tracks``, ``genres`` and
        ``moods`` (up to ten most frequent non-null primary values with
        their counts), ``average_bpm`` (rounded to one decimal) and
        ``total_duration_hours`` (rounded to one decimal)
    """

    def top_ten(column):
        """Return up to ten (value, count) rows for column, most frequent first."""
        occurrences = func.count(AudioTrack.id)
        grouped = (
            db.query(column, occurrences)
            .filter(column.isnot(None))
            .group_by(column)
            .order_by(occurrences.desc())
        )
        return grouped.limit(10).all()

    track_count = db.query(func.count(AudioTrack.id)).scalar()

    # Genre distribution
    genre_rows = top_ten(AudioTrack.genre_primary)

    # Mood distribution
    mood_rows = top_ten(AudioTrack.mood_primary)

    # Average BPM
    mean_bpm = db.query(func.avg(AudioTrack.tempo_bpm)).scalar()

    # Total duration
    duration_sum = db.query(func.sum(AudioTrack.duration_seconds)).scalar()

    # Falsy aggregates (None or zero) are reported as 0.0.
    average_bpm = 0.0
    if mean_bpm:
        average_bpm = round(float(mean_bpm), 1)

    duration_hours = 0.0
    if duration_sum:
        duration_hours = round(float(duration_sum) / 3600, 1)

    stats = {}
    stats["total_tracks"] = track_count or 0
    stats["genres"] = [
        {"genre": name, "count": count} for name, count in genre_rows
    ]
    stats["moods"] = [
        {"mood": name, "count": count} for name, count in mood_rows
    ]
    stats["average_bpm"] = average_bpm
    stats["total_duration_hours"] = duration_hours
    return stats
|
|
|
|
|
|
def upsert_track(db: Session, analysis: AudioAnalysis) -> AudioTrack:
    """Create or update track (based on filepath).

    When a track with ``analysis.filepath`` already exists, every field
    returned by ``analysis.dict()`` except ``filepath`` is copied onto it
    and the change is committed. Otherwise a new track is created through
    ``create_track``.

    Args:
        db: Database session
        analysis: AudioAnalysis object

    Returns:
        The updated or newly created AudioTrack instance
    """
    existing_track = get_track_by_filepath(db, analysis.filepath)
    if not existing_track:
        # Create new track
        return create_track(db, analysis)

    # Analysis fields whose attribute name differs on AudioTrack.
    # create_track stores analysis.metadata as extra_metadata, so the update
    # path must write to the same attribute. Setting "metadata" directly
    # would leave extra_metadata stale and put a stray attribute on the
    # instance.
    renamed_fields = {"metadata": "extra_metadata"}

    # Update existing track
    for field_name, value in analysis.dict(exclude={"filepath"}).items():
        target_attr = renamed_fields.get(field_name, field_name)
        setattr(existing_track, target_attr, value)

    db.commit()
    db.refresh(existing_track)

    logger.info(f"Updated track: {existing_track.id} - {existing_track.filename}")
    return existing_track
|