Files
Audio-Classifier/backend/src/models/crud.py
2025-12-22 15:53:26 +01:00

403 lines
11 KiB
Python

"""CRUD operations for audio tracks."""
from typing import List, Optional, Dict, Tuple
from uuid import UUID
from sqlalchemy.orm import Session
from sqlalchemy import or_, and_, func, any_
from .schema import AudioTrack
from ..core.analyzer import AudioAnalysis
from ..utils.logging import get_logger
logger = get_logger(__name__)
def create_track(db: Session, analysis: AudioAnalysis) -> AudioTrack:
    """Persist a new AudioTrack row built from an analysis result.

    Args:
        db: Database session
        analysis: AudioAnalysis object whose attributes are copied onto the row

    Returns:
        The created AudioTrack instance, refreshed after the commit
    """
    # Attributes that carry the same name on the analysis and on the model.
    same_name_fields = (
        # File information
        "filepath",
        "filename",
        "duration_seconds",
        "file_size_bytes",
        "format",
        "analyzed_at",
        # Features
        "tempo_bpm",
        "key",
        "time_signature",
        "energy",
        "danceability",
        "valence",
        "loudness_lufs",
        "spectral_centroid",
        "zero_crossing_rate",
        # Classification
        "genre_primary",
        "genre_secondary",
        "genre_confidence",
        "mood_primary",
        "mood_secondary",
        "mood_arousal",
        "mood_valence",
        "instruments",
        # Vocals
        "has_vocals",
        "vocal_gender",
    )
    column_values = {name: getattr(analysis, name) for name in same_name_fields}
    # The analysis exposes `metadata`; the model attribute is `extra_metadata`.
    column_values["extra_metadata"] = analysis.metadata

    track = AudioTrack(**column_values)
    db.add(track)
    db.commit()
    # Reload so database-generated values are populated on the instance.
    db.refresh(track)

    logger.info(f"Created track: {track.id} - {track.filename}")
    return track
def get_track_by_id(db: Session, track_id: UUID) -> Optional[AudioTrack]:
    """Look up a single track by its ID.

    Args:
        db: Database session
        track_id: Track UUID

    Returns:
        The matching AudioTrack, or None when no row has that ID
    """
    id_matches = AudioTrack.id == track_id
    lookup = db.query(AudioTrack).filter(id_matches)
    return lookup.first()
def get_track_by_filepath(db: Session, filepath: str) -> Optional[AudioTrack]:
    """Look up a single track by its file path.

    Args:
        db: Database session
        filepath: File path to match exactly

    Returns:
        The first matching AudioTrack, or None when no row has that path
    """
    path_matches = AudioTrack.filepath == filepath
    lookup = db.query(AudioTrack).filter(path_matches)
    return lookup.first()
def get_tracks(
    db: Session,
    skip: int = 0,
    limit: int = 100,
    genre: Optional[str] = None,
    mood: Optional[str] = None,
    bpm_min: Optional[float] = None,
    bpm_max: Optional[float] = None,
    energy_min: Optional[float] = None,
    energy_max: Optional[float] = None,
    has_vocals: Optional[bool] = None,
    key: Optional[str] = None,
    instrument: Optional[str] = None,
    sort_by: str = "analyzed_at",
    sort_desc: bool = True,
) -> Tuple[List[AudioTrack], int]:
    """Get tracks with filters and pagination.

    Args:
        db: Database session
        skip: Number of records to skip
        limit: Maximum number of records to return
        genre: Filter by genre. Matches when genre_primary starts with this
            value (category matching) or when it appears in genre_secondary
        mood: Filter by mood (mood_primary or any entry of mood_secondary)
        bpm_min: Minimum BPM (inclusive)
        bpm_max: Maximum BPM (inclusive)
        energy_min: Minimum energy (0-1, inclusive)
        energy_max: Maximum energy (0-1, inclusive)
        has_vocals: Filter by vocal presence
        key: Filter by musical key
        instrument: Filter by instrument
        sort_by: Name of a mapped column attribute to sort by. Unknown or
            non-column names fall back to "analyzed_at"
        sort_desc: Sort descending if True

    Returns:
        Tuple of (tracks list, total count). The count covers every row that
        matches the filters, ignoring skip/limit.
    """
    # Local import: only needed here to enumerate the mapped column attributes.
    from sqlalchemy import inspect as sa_inspect

    query = db.query(AudioTrack)

    # Apply filters
    if genre:
        # Match genre category (e.g., "Pop" matches "Pop---Ballad",
        # "Pop---Indie Pop", etc.). autoescape keeps "%" and "_" in the
        # caller-supplied value literal instead of acting as LIKE wildcards.
        # A prefix match already covers the exact-match case.
        query = query.filter(
            or_(
                AudioTrack.genre_primary.startswith(genre, autoescape=True),
                AudioTrack.genre_secondary.any(genre),
            )
        )
    if mood:
        query = query.filter(
            or_(
                AudioTrack.mood_primary == mood,
                AudioTrack.mood_secondary.any(mood),
            )
        )
    if bpm_min is not None:
        query = query.filter(AudioTrack.tempo_bpm >= bpm_min)
    if bpm_max is not None:
        query = query.filter(AudioTrack.tempo_bpm <= bpm_max)
    if energy_min is not None:
        query = query.filter(AudioTrack.energy >= energy_min)
    if energy_max is not None:
        query = query.filter(AudioTrack.energy <= energy_max)
    if has_vocals is not None:
        query = query.filter(AudioTrack.has_vocals == has_vocals)
    if key:
        query = query.filter(AudioTrack.key == key)
    if instrument:
        query = query.filter(AudioTrack.instruments.any(instrument))

    # Get total count before pagination
    total = query.count()

    # Only mapped column attributes are valid sort keys. A plain hasattr()
    # check would also accept methods and other class attributes, which then
    # fail when .desc()/.asc() is called on them.
    sortable = set(sa_inspect(AudioTrack).column_attrs.keys())
    if sort_by not in sortable:
        logger.warning(
            "Unsupported sort field %r; falling back to analyzed_at", sort_by
        )
        sort_by = "analyzed_at"
    sort_column = getattr(AudioTrack, sort_by)
    ordering = [sort_column.desc() if sort_desc else sort_column.asc()]
    if sort_by != "id":
        # Tie-breaker: without a total order, OFFSET/LIMIT pages may repeat
        # or skip rows that share the same sort value.
        ordering.append(AudioTrack.id.asc())
    query = query.order_by(*ordering)

    # Apply pagination
    tracks = query.offset(skip).limit(limit).all()
    return tracks, total
def search_tracks(
    db: Session,
    query: str,
    genre: Optional[str] = None,
    mood: Optional[str] = None,
    limit: int = 100,
) -> List[AudioTrack]:
    """Search tracks by text query.

    The query is lower-cased and matched as a substring of the lower-cased
    filename, primary genre and primary mood, or as an exact element of the
    instruments array.

    Args:
        db: Database session
        query: Search query string
        genre: Optional genre filter (exact genre_primary, or any entry of
            genre_secondary)
        mood: Optional mood filter (exact mood_primary, or any entry of
            mood_secondary)
        limit: Maximum results

    Returns:
        List of matching AudioTrack instances, most recently analyzed first
    """
    term = query.lower()

    # Text search on multiple fields. autoescape keeps "%" and "_" typed by
    # the user literal instead of letting them act as LIKE wildcards.
    search_query = db.query(AudioTrack).filter(
        or_(
            func.lower(AudioTrack.filename).contains(term, autoescape=True),
            func.lower(AudioTrack.genre_primary).contains(term, autoescape=True),
            func.lower(AudioTrack.mood_primary).contains(term, autoescape=True),
            # Bound-parameter membership test. Formatting the term into a
            # '{...}' array literal breaks on input containing braces,
            # quotes or commas, and turns the word "null" into a NULL element.
            AudioTrack.instruments.any(term),
        )
    )

    # Apply additional filters
    if genre:
        search_query = search_query.filter(
            or_(
                AudioTrack.genre_primary == genre,
                AudioTrack.genre_secondary.any(genre),
            )
        )
    if mood:
        search_query = search_query.filter(
            or_(
                AudioTrack.mood_primary == mood,
                AudioTrack.mood_secondary.any(mood),
            )
        )

    # No relevance ranking is computed; newest analyses are returned first.
    search_query = search_query.order_by(AudioTrack.analyzed_at.desc())
    return search_query.limit(limit).all()
def get_similar_tracks(
    db: Session,
    track_id: UUID,
    limit: int = 10,
) -> List[AudioTrack]:
    """Find tracks that resemble a reference track.

    Args:
        db: Database session
        track_id: Reference track ID
        limit: Maximum results

    Returns:
        List of similar AudioTrack instances (empty when the reference track
        does not exist), most recently analyzed first

    Note:
        Vector similarity is not implemented yet. Candidates must share the
        reference's primary genre and primary mood (either as their primary
        value or in their secondary list) and have a BPM within ±10% of the
        reference. Each criterion is skipped when the reference has no value
        for it.
    """
    reference = get_track_by_id(db, track_id)
    if not reference:
        return []

    # TODO: Implement vector similarity when embeddings are available
    # For now, use genre + mood + BPM similarity
    candidates = db.query(AudioTrack).filter(AudioTrack.id != track_id)

    # Same genre (primary or secondary)
    ref_genre = reference.genre_primary
    if ref_genre:
        genre_matches = or_(
            AudioTrack.genre_primary == ref_genre,
            AudioTrack.genre_secondary.any(ref_genre),
        )
        candidates = candidates.filter(genre_matches)

    # Similar mood
    ref_mood = reference.mood_primary
    if ref_mood:
        mood_matches = or_(
            AudioTrack.mood_primary == ref_mood,
            AudioTrack.mood_secondary.any(ref_mood),
        )
        candidates = candidates.filter(mood_matches)

    # Similar BPM (±10%)
    ref_bpm = reference.tempo_bpm
    if ref_bpm:
        tolerance = ref_bpm * 0.1
        # Successive filters are ANDed together.
        candidates = candidates.filter(AudioTrack.tempo_bpm >= ref_bpm - tolerance)
        candidates = candidates.filter(AudioTrack.tempo_bpm <= ref_bpm + tolerance)

    # Order by analyzed_at (could be improved with similarity score)
    newest_first = AudioTrack.analyzed_at.desc()
    return candidates.order_by(newest_first).limit(limit).all()
def delete_track(db: Session, track_id: UUID) -> bool:
    """Remove a track from the database.

    Args:
        db: Database session
        track_id: Track UUID

    Returns:
        True if the track existed and was deleted, False if not found
    """
    doomed = get_track_by_id(db, track_id)
    was_found = bool(doomed)
    if was_found:
        db.delete(doomed)
        db.commit()
        logger.info(f"Deleted track: {track_id}")
    return was_found
def get_stats(db: Session) -> Dict:
    """Collect aggregate statistics about the stored tracks.

    Args:
        db: Database session

    Returns:
        Dictionary with the total track count, the ten most common primary
        genres and primary moods (with counts), the average BPM rounded to
        one decimal, and the total duration in hours rounded to one decimal
    """

    def top_ten(column):
        # Ten most frequent non-NULL values of `column`, most frequent first.
        return (
            db.query(column, func.count(AudioTrack.id))
            .filter(column.isnot(None))
            .group_by(column)
            .order_by(func.count(AudioTrack.id).desc())
            .limit(10)
            .all()
        )

    track_count = db.query(func.count(AudioTrack.id)).scalar()

    # Genre and mood distributions
    genre_rows = top_ten(AudioTrack.genre_primary)
    mood_rows = top_ten(AudioTrack.mood_primary)

    # Aggregates are None on an empty table.
    mean_bpm = db.query(func.avg(AudioTrack.tempo_bpm)).scalar()
    duration_sum = db.query(func.sum(AudioTrack.duration_seconds)).scalar()

    if mean_bpm:
        average_bpm = round(float(mean_bpm), 1)
    else:
        average_bpm = 0.0

    if duration_sum:
        duration_hours = round(float(duration_sum) / 3600, 1)
    else:
        duration_hours = 0.0

    stats = {}
    stats["total_tracks"] = track_count or 0
    stats["genres"] = [
        {"genre": name, "count": count} for name, count in genre_rows
    ]
    stats["moods"] = [
        {"mood": name, "count": count} for name, count in mood_rows
    ]
    stats["average_bpm"] = average_bpm
    stats["total_duration_hours"] = duration_hours
    return stats
def upsert_track(db: Session, analysis: AudioAnalysis) -> AudioTrack:
    """Create or update track (based on filepath).

    When a track with the same filepath already exists, every analysis field
    except the filepath is copied onto it; otherwise a new track is created.

    Args:
        db: Database session
        analysis: AudioAnalysis object

    Returns:
        The updated or newly created AudioTrack instance
    """
    # Check if track already exists
    existing_track = get_track_by_filepath(db, analysis.filepath)
    if not existing_track:
        # Create new track
        return create_track(db, analysis)

    # Analysis field names that differ from the model attribute name.
    # create_track stores analysis.metadata in `extra_metadata`; assigning to
    # `metadata` on the instance would not touch that attribute, so metadata
    # changes would be lost on update.
    # NOTE(review): `metadata` is presumably reserved by SQLAlchemy's
    # declarative base, which would explain the rename — confirm in schema.py.
    renamed_fields = {"metadata": "extra_metadata"}

    # Update existing track
    for field, value in analysis.dict(exclude={"filepath"}).items():
        setattr(existing_track, renamed_fields.get(field, field), value)

    db.commit()
    db.refresh(existing_track)
    logger.info(f"Updated track: {existing_track.id} - {existing_track.filename}")
    return existing_track