convert index.schema.ArchiveResult and Link to pydantic

This commit is contained in:
Nick Sweeting 2024-11-19 06:32:48 -08:00
parent b948e49013
commit 44d337a167
No known key found for this signature in database

View file

@ -9,21 +9,15 @@ These are the old types we used to use before ArchiveBox v0.4 (before we switche
__package__ = 'archivebox.index'
from dataclasses import dataclass, asdict, field, fields
from datetime import datetime, timezone, timedelta
from pathlib import Path
from typing import List, Dict, Any, Optional, Union, ClassVar

import abx
from benedict import benedict
from django.utils.functional import cached_property
from pydantic import BaseModel, ConfigDict, Field, field_validator, computed_field

from archivebox.config import ARCHIVE_DIR, CONSTANTS
from archivebox.misc.logging import stderr, ANSI
from archivebox.misc.system import get_dir_size
from archivebox.misc.util import ts_to_date_str, parse_date
class ArchiveError(Exception):
@ -31,211 +25,223 @@ class ArchiveError(Exception):
super().__init__(message)
self.hints = hints
# Type aliases
# LinkDict: a link serialized as a plain mapping of field name -> value.
LinkDict = Dict[str, Any]
# ArchiveOutput: what an extractor produced (a string, an exception, or nothing).
ArchiveOutput = Optional[Union[str, Exception]]
class ArchiveResult(BaseModel):
    """Record of one extractor command run: what was run, when, and what it produced.

    Validation happens at construction time through the pydantic field
    validators below (non-empty ``status``, non-empty string ``cmd`` arguments,
    empty ``pwd``/``cmd_version`` normalized to ``None``).
    """

    # ``output`` may hold an Exception instance, which pydantic has no schema for.
    model_config = ConfigDict(arbitrary_types_allowed=True)

    TYPE: str = 'index.schema.ArchiveResult'
    cmd: list[str]
    pwd: str | None = None
    cmd_version: str | None = None
    output: ArchiveOutput | None = None
    status: str
    start_ts: datetime
    end_ts: datetime
    index_texts: list[str] | None = None

    # Class variables for compatibility: lazily-filled cache read by field_names()
    _field_names: ClassVar[list[str] | None] = None

    @field_validator('status')
    @classmethod
    def validate_status(cls, v: str) -> str:
        """Reject an empty status."""
        if not v:
            raise ValueError('status must be a non-empty string')
        return v

    @field_validator('cmd')
    @classmethod
    def validate_cmd(cls, v: List[str]) -> List[str]:
        """Require every command argument to be a non-empty string."""
        if not all(isinstance(arg, str) and arg for arg in v):
            raise ValueError('all command arguments must be non-empty strings')
        return v

    @field_validator('pwd')
    @classmethod
    def validate_pwd(cls, v: Optional[str]) -> Optional[str]:
        """Convert an empty ``pwd`` to None for consistency."""
        return None if v == '' else v

    @field_validator('cmd_version')
    @classmethod
    def validate_cmd_version(cls, v: Optional[str]) -> Optional[str]:
        """Convert an empty ``cmd_version`` to None for consistency."""
        return None if v == '' else v

    def model_dump(self, **kwargs) -> dict:
        """Backwards compatible with _asdict()"""
        return super().model_dump(**kwargs)

    def _asdict(self) -> dict:
        """Legacy alias for model_dump(), kept so pre-pydantic callers keep working."""
        return self.model_dump()

    @classmethod
    def field_names(cls) -> List[str]:
        """Get all field names of the model (computed once, then cached on the class)."""
        if cls._field_names is None:
            cls._field_names = list(cls.model_fields.keys())
        return cls._field_names

    @classmethod
    def guess_ts(cls, dict_info: dict) -> tuple[datetime, datetime]:
        """Guess (start_ts, end_ts) from the ``timestamp`` and ``duration`` keys.

        Raises KeyError when either key is missing from ``dict_info``.
        """
        start_ts = parse_date(dict_info["timestamp"])
        end_ts = start_ts + timedelta(seconds=int(dict_info["duration"]))
        return start_ts, end_ts

    @classmethod
    def from_json(cls, json_info: dict, guess: bool = False) -> 'ArchiveResult':
        """Create an instance from JSON data.

        Keys that are not model fields are dropped. With ``guess=True`` missing
        ``start_ts``/``end_ts``/``pwd``/``cmd_version``/``cmd`` values are filled
        in from the other keys of ``json_info``; otherwise the timestamps must
        be present and are parsed with ``parse_date``.
        """
        known_fields = cls.field_names()
        info = {key: val for key, val in json_info.items() if key in known_fields}

        if guess:
            if "start_ts" not in info:
                info["start_ts"], info["end_ts"] = cls.guess_ts(json_info)
            else:
                info['start_ts'] = parse_date(info['start_ts'])
                info['end_ts'] = parse_date(info['end_ts'])
            if "pwd" not in info:
                info["pwd"] = str(ARCHIVE_DIR / json_info["timestamp"])
            if "cmd_version" not in info:
                info["cmd_version"] = "Undefined"
            if "cmd" not in info:
                info["cmd"] = []
        else:
            info['start_ts'] = parse_date(info['start_ts'])
            info['end_ts'] = parse_date(info['end_ts'])
            info['cmd_version'] = info.get('cmd_version')

        # Handle string command as list
        if isinstance(info.get("cmd"), str):
            info["cmd"] = [info["cmd"]]

        return cls(**info)

    def to_dict(self, *keys: str) -> dict:
        """Convert to dictionary, optionally keeping only the given keys."""
        data = self.model_dump()
        if keys:
            return {k: v for k, v in data.items() if k in keys}
        return data

    def to_json(self, indent: int = 4, sort_keys: bool = True) -> str:
        """Convert to a JSON string, omitting None values.

        pydantic's serializer has no key-sorting option, so its output is
        re-encoded with the stdlib encoder to honor ``sort_keys``.
        """
        import json

        payload = json.loads(self.model_dump_json(exclude_none=True))
        return json.dumps(payload, indent=indent, sort_keys=sort_keys, ensure_ascii=False)

    def to_csv(self, cols: Optional[List[str]] = None, separator: str = ',', ljust: int = 0) -> str:
        """Convert to one CSV row; ``cols`` defaults to every model field."""
        data = self.model_dump()
        cols = cols or self.field_names()
        return separator.join(str(data.get(col, '')).ljust(ljust) for col in cols)

    # computed_field must be the outer decorator so pydantic registers the property.
    @computed_field
    @property
    def duration(self) -> int:
        """Calculate duration in seconds between start and end timestamps"""
        return int((self.end_ts - self.start_ts).total_seconds())
class Link(BaseModel):
    """A bookmarked URL together with the history of its archive attempts.

    ``timestamp`` is kept as a string; ``history`` maps an archive method name
    to the list of ArchiveResult entries recorded for it.
    """

    # ArchiveResult.output may hold an Exception, which pydantic has no schema for.
    model_config = ConfigDict(arbitrary_types_allowed=True)

    TYPE: str = 'index.schema.Link'
    timestamp: str
    url: str
    title: str | None = None
    tags: str | None = None
    sources: list[str] = Field(default_factory=list)
    history: dict[str, list[ArchiveResult]] = Field(default_factory=dict)
    downloaded_at: datetime | None = None

    # Class variables for compatibility: lazily-filled cache read by field_names()
    _field_names: ClassVar[list[str] | None] = None
def __str__(self) -> str:
    """Return a one-line ``[timestamp] url "title"`` summary of the link."""
    summary = f'[{self.timestamp}] {self.url} "{self.title}"'
    return summary
def __post_init__(self):
self.typecheck()
def overwrite(self, **kwargs):
"""pure functional version of dict.update that returns a new instance"""
return Link(**{**self._asdict(), **kwargs})
def __eq__(self, other: Any) -> bool:
    """Treat two links as equal when they have the same URL.

    Returns NotImplemented for non-Link operands so Python can try the
    reflected comparison.
    """
    if isinstance(other, Link):
        return self.url == other.url
    return NotImplemented
def __gt__(self, other: Any) -> bool:
    """Order links by their timestamp parsed as a float.

    Returns NotImplemented when ``other`` is not a Link or when either side
    has an empty timestamp, instead of falling through with a bare ``None``.
    """
    if not isinstance(other, Link):
        return NotImplemented
    if not (self.timestamp and other.timestamp):
        return NotImplemented
    return float(self.timestamp) > float(other.timestamp)
def typecheck(self) -> None:
try:
assert self.schema == self.__class__.__name__
assert isinstance(self.timestamp, str) and self.timestamp, f'timestamp must be a non-empty string, got: "{self.timestamp}"'
assert self.timestamp.replace('.', '').isdigit(), f'timestamp must be a float str, got: "{self.timestamp}"'
assert isinstance(self.url, str) and '://' in self.url, f'url must be a non-empty string, got: "{self.url}"'
assert self.downloaded_at is None or isinstance(self.downloaded_at, datetime), f'downloaded_at must be a datetime or None, got: {self.downloaded_at}'
assert self.title is None or (isinstance(self.title, str) and self.title), f'title must be a non-empty string or None, got: "{self.title}"'
assert self.tags is None or isinstance(self.tags, str), f'tags must be a string or None, got: "{self.tags}"'
assert isinstance(self.sources, list), f'sources must be a list, got: {self.sources}'
assert all(isinstance(source, str) and source for source in self.sources)
assert isinstance(self.history, dict)
for method, results in self.history.items():
assert isinstance(method, str) and method
assert isinstance(results, list)
assert all(isinstance(result, ArchiveResult) for result in results)
except Exception:
stderr('{red}[X] Error while loading link! [{}] {} "{}"{reset}'.format(self.timestamp, self.url, self.title, **ANSI))
raise
@field_validator('timestamp')
@classmethod
def validate_timestamp(cls, v: str) -> str:
    """Require a non-empty timestamp made only of digits and dots."""
    if not v:
        raise ValueError('timestamp must be a non-empty string')
    digits_only = v.replace('.', '')
    if digits_only.isdigit():
        return v
    raise ValueError('timestamp must be a float str')
def _asdict(self, extended=False):
info = {
'schema': 'Link',
'url': self.url,
'title': self.title or None,
'timestamp': self.timestamp,
'downloaded_at': self.downloaded_at or None,
'tags': self.tags or None,
'sources': self.sources or [],
'history': self.history or {},
}
if extended:
info.update({
'snapshot_id': self.snapshot_id,
'snapshot_abid': self.snapshot_abid,
@field_validator('url')
@classmethod
def validate_url(cls, v: str) -> str:
    """Require a non-empty URL that contains a ``://`` scheme separator."""
    if v and '://' in v:
        return v
    raise ValueError('url must be a valid URL string')
'link_dir': self.link_dir,
'archive_path': self.archive_path,
@field_validator('title')
@classmethod
def validate_title(cls, v: Optional[str]) -> Optional[str]:
    """Allow None or any non-empty title; reject the empty string."""
    if v is None or v:
        return v
    raise ValueError('title must be a non-empty string if provided')
'hash': self.url_hash,
'base_url': self.base_url,
'scheme': self.scheme,
'domain': self.domain,
'path': self.path,
'basename': self.basename,
'extension': self.extension,
'is_static': self.is_static,
@field_validator('sources')
@classmethod
def validate_sources(cls, v: List[str]) -> List[str]:
    """Require every entry in ``sources`` to be a non-empty string."""
    for source in v:
        if not (isinstance(source, str) and source):
            raise ValueError('all sources must be non-empty strings')
    return v
'tags_str': (self.tags or '').strip(','), # only used to render static index in index/html.py, remove if no longer needed there
'icons': None, # only used to render static index in index/html.py, remove if no longer needed there
# Backwards compatibility methods
def _asdict(self, extended: bool = False) -> dict:
    """Return this link wrapped in a ``benedict``.

    NOTE(review): ``extended`` is accepted but never read in this body, so
    both values produce the same result; confirm whether that is intended.
    """
    return benedict(self)
'bookmarked_date': self.bookmarked_date,
'downloaded_datestr': self.downloaded_datestr,
'oldest_archive_date': self.oldest_archive_date,
'newest_archive_date': self.newest_archive_date,
'is_archived': self.is_archived,
'num_outputs': self.num_outputs,
'num_failures': self.num_failures,
'latest': self.latest_outputs(),
'canonical': self.canonical_outputs(),
})
return info
def as_snapshot(self):
from core.models import Snapshot
return Snapshot.objects.get(url=self.url)
def overwrite(self, **kwargs) -> 'Link':
    """Pure functional version of dict.update that returns a new instance"""
    merged = {**self.model_dump(), **kwargs}
    return Link(**merged)
@classmethod
def field_names(cls) -> list[str]:
    """Return the model's field names, computed once and cached on the class."""
    if cls._field_names is None:
        cls._field_names = list(cls.model_fields)
    return cls._field_names
@classmethod
def from_json(cls, json_info: dict, guess: bool = False) -> 'Link':
info = {
key: val
for key, val in json_info.items()
if key in cls.field_names()
}
info['downloaded_at'] = parse_date(info.get('updated') or info.get('downloaded_at'))
# Handle downloaded_at
info['downloaded_at'] = cls._parse_date(info.get('updated') or info.get('downloaded_at'))
info['sources'] = info.get('sources') or []
# Handle history
json_history = info.get('history') or {}
cast_history = {}
@ -249,165 +255,75 @@ class Link:
info['history'] = cast_history
return cls(**info)
def to_json(self, indent: int = 4, sort_keys: bool = True) -> str:
    """Serialize the link to a JSON string.

    pydantic's serializer has no key-sorting option, so its output is
    re-encoded with the stdlib encoder to honor ``sort_keys``.
    """
    import json

    payload = json.loads(self.model_dump_json())
    return json.dumps(payload, indent=indent, sort_keys=sort_keys, ensure_ascii=False)
def to_csv(self, cols: Optional[List[str]] = None, separator: str = ',', ljust: int = 0) -> str:
    """Render the link as one CSV row.

    ``cols`` defaults to every model field; a column missing from the dumped
    data renders as an empty string. Each cell is left-justified to ``ljust``.
    """
    data = self.model_dump()
    cols = cols or self.field_names()
    return separator.join(str(data.get(col, '')).ljust(ljust) for col in cols)

@cached_property
def snapshot(self):
    """Fetch the Snapshot row matching this URL, loading only ``id`` and ``abid``."""
    from core.models import Snapshot
    return Snapshot.objects.only('id', 'abid').get(url=self.url)

@cached_property
def snapshot_id(self):
    """Primary key of the matching Snapshot, as a string."""
    return str(self.snapshot.pk)

@cached_property
def snapshot_abid(self):
    """ABID of the matching Snapshot, as a string."""
    return str(self.snapshot.ABID)
# Properties for compatibility
@property
def link_dir(self) -> str:
    """Path of this link's folder under ARCHIVE_DIR, as a string."""
    folder = ARCHIVE_DIR / self.timestamp
    return str(folder)
@property
def archive_path(self) -> str:
    """Relative path of this link's folder: ``<ARCHIVE_DIR_NAME>/<timestamp>``."""
    archive_dir_name = CONSTANTS.ARCHIVE_DIR_NAME
    return f'{archive_dir_name}/{self.timestamp}'
@property
def archive_size(self) -> float:
    """First element of ``get_dir_size(self.archive_path)``, or 0 if anything fails."""
    try:
        dir_stats = get_dir_size(self.archive_path)
        return dir_stats[0]
    except Exception:
        # Best-effort: any failure is reported as size 0.
        return 0
### URL Helpers
@property
def url_hash(self):
from archivebox.misc.util import hashurl
return hashurl(self.url)
@property
def scheme(self) -> str:
from archivebox.misc.util import scheme
return scheme(self.url)
@property
def extension(self) -> str:
from archivebox.misc.util import extension
return extension(self.url)
@property
def domain(self) -> str:
from archivebox.misc.util import domain
return domain(self.url)
@property
def path(self) -> str:
from archivebox.misc.util import path
return path(self.url)
@property
def basename(self) -> str:
from archivebox.misc.util import basename
return basename(self.url)
@property
def base_url(self) -> str:
from archivebox.misc.util import base_url
return base_url(self.url)
### Pretty Printing Helpers
# computed_field must be the outer decorator so pydantic registers the property.
@computed_field
@property
def bookmarked_date(self) -> Optional[str]:
    """Human-readable bookmark date derived from ``timestamp``.

    Returns None when the timestamp is empty or not made of digits and dots.
    Returns the raw timestamp string when it is not a parseable float or falls
    outside the range (0, now + 30 days). Otherwise returns the formatted date.
    """
    raw_ts = self.timestamp
    if not (raw_ts and raw_ts.replace('.', '').isdigit()):
        return None

    try:
        seconds = float(raw_ts)
    except ValueError:
        # e.g. "1.2.3" passes the digit check but is not a float
        return str(raw_ts)

    max_ts = (datetime.now(timezone.utc) + timedelta(days=30)).timestamp()
    if 0 < seconds < max_ts:
        return self._ts_to_date_str(datetime.fromtimestamp(seconds))
    return str(raw_ts)
# computed_field must be the outer decorator so pydantic registers the property.
@computed_field
@property
def downloaded_datestr(self) -> Optional[str]:
    """``downloaded_at`` formatted by ``_ts_to_date_str``, or None when it is unset."""
    if not self.downloaded_at:
        return None
    return self._ts_to_date_str(self.downloaded_at)
@property
def archive_dates(self) -> list[datetime]:
    """Start times of every result in ``history``, in history order.

    ``start_ts`` values that are already datetimes are used as-is; only other
    values go through ``_parse_date``, whose string handling would raise on a
    datetime. Entries that cannot be parsed are skipped.
    """
    dates: list[datetime] = []
    for results in self.history.values():
        for result in results:
            started = result.start_ts
            if not isinstance(started, datetime):
                started = self._parse_date(started)
            if started is not None:
                dates.append(started)
    return dates
@property
def oldest_archive_date(self) -> Optional[datetime]:
    """Earliest entry of ``archive_dates``, or None when there are none."""
    dates = self.archive_dates
    if not dates:
        return None
    return min(dates)
@property
def newest_archive_date(self) -> Optional[datetime]:
    """Latest entry of ``archive_dates``, or None when there are none."""
    dates = self.archive_dates
    if not dates:
        return None
    return max(dates)
### Archive Status Helpers
@property
def num_outputs(self) -> int:
    """``num_outputs`` of the matching snapshot, or 0 if the lookup fails."""
    try:
        snapshot = self.as_snapshot()
        return snapshot.num_outputs
    except Exception:
        # Best-effort: any lookup failure is reported as zero outputs.
        return 0
@property
def num_failures(self) -> int:
    """Count the results across all archive methods whose status is ``'failed'``."""
    return sum(
        1
        for results in self.history.values()
        for result in results
        if result.status == 'failed'
    )
@property
def is_static(self) -> bool:
from archivebox.misc.util import is_static_file
return is_static_file(self.url)
@property
def is_archived(self) -> bool:
from archivebox.misc.util import domain
output_paths = (
domain(self.url),
'output.html',
'output.pdf',
'screenshot.png',
'singlefile.html',
'readability/content.html',
'mercury/content.html',
'htmltotext.txt',
'media',
'git',
)
return any(
(ARCHIVE_DIR / self.timestamp / path).exists()
for path in output_paths
)
def latest_outputs(self, status: str=None) -> Dict[str, ArchiveOutput]:
"""get the latest output that each archive method produced for link"""
def latest_outputs(self, status: Optional[str] = None) -> dict[str, Any]:
"""Get the latest output that each archive method produced for link"""
ARCHIVE_METHODS = (
'title', 'favicon', 'wget', 'warc', 'singlefile', 'pdf',
'screenshot', 'dom', 'git', 'media', 'archive_org',
)
latest: Dict[str, ArchiveOutput] = {}
latest: Dict[str, Any] = {}
for archive_method in ARCHIVE_METHODS:
# get most recent succesful result in history for each archive method
history = self.history.get(archive_method) or []
@ -416,26 +332,22 @@ class Link:
history = list(filter(lambda result: result.status == status, history))
history = list(history)
if history:
latest[archive_method] = history[0].output
else:
latest[archive_method] = None
latest[archive_method] = history[0].output if history else None
return latest
def canonical_outputs(self) -> Dict[str, Optional[str]]:
"""predict the expected output paths that should be present after archiving"""
from abx_plugin_wget.wget import wget_output_path
from abx_plugin_favicon.config import FAVICON_CONFIG
"""Predict the expected output paths that should be present after archiving"""
# You'll need to implement the actual logic based on your requirements
# TODO: banish this awful duplication from the codebase and import these
# from their respective extractor files
from abx_plugin_favicon.config import FAVICON_CONFIG
canonical = {
'index_path': 'index.html',
'favicon_path': 'favicon.ico',
'google_favicon_path': FAVICON_CONFIG.FAVICON_PROVIDER.format(self.domain),
'wget_path': wget_output_path(self),
'wget_path': f'warc/{self.timestamp}',
'warc_path': 'warc/',
'singlefile_path': 'singlefile.html',
'readability_path': 'readability/content.html',
@ -444,17 +356,14 @@ class Link:
'pdf_path': 'output.pdf',
'screenshot_path': 'screenshot.png',
'dom_path': 'output.html',
'archive_org_path': 'https://web.archive.org/web/{}'.format(self.base_url),
'archive_org_path': f'https://web.archive.org/web/{self.base_url}',
'git_path': 'git/',
'media_path': 'media/',
'headers_path': 'headers.json',
}
if self.is_static:
# static binary files like PDF and images are handled slightly differently.
# they're just downloaded once and aren't archived separately multiple times,
# so the wget, screenshot, & pdf urls should all point to the same file
static_path = wget_output_path(self)
if self.is_static:
static_path = f'warc/{self.timestamp}'
canonical.update({
'title': self.basename,
'wget_path': static_path,
@ -468,3 +377,78 @@ class Link:
})
return canonical
# URL helper properties
@property
def url_hash(self) -> str:
    """First 8 hex characters of the SHA-256 digest of the encoded URL."""
    import hashlib

    digest = hashlib.sha256(self.url.encode()).hexdigest()
    return digest[:8]
@property
def scheme(self) -> str:
    """Text before the first ``://`` in the URL (the whole URL if there is none)."""
    prefix, _, _ = self.url.partition('://')
    return prefix
@property
def domain(self) -> str:
    """Network location (host, plus any userinfo/port) of the URL.

    Parsed with ``urlsplit`` so a query string or fragment that directly
    follows the host (``https://example.com?q=1``) is not returned as part of
    the domain. Falls back to plain string splitting if ``urlsplit`` rejects
    the URL.
    """
    from urllib.parse import urlsplit

    try:
        return urlsplit(self.url).netloc
    except ValueError:
        return self.url.split('://')[1].split('/')[0]
@property
def path(self) -> str:
    """Path component of the URL, or ``'/'`` when there is none.

    The query string and fragment are excluded, so ``basename`` and
    ``extension`` are not polluted by ``?...``/``#...`` suffixes. URLs without
    ``://`` yield ``'/'``.
    """
    from urllib.parse import urlsplit

    if '://' not in self.url:
        return '/'
    try:
        return urlsplit(self.url).path or '/'
    except ValueError:
        # urlsplit rejected the URL: fall back to plain string splitting
        _, slash, tail = self.url.split('://', 1)[1].partition('/')
        return '/' + tail if slash else '/'
@property
def basename(self) -> str:
    """Last ``/``-separated segment of ``path`` (empty when the path ends in ``/``)."""
    _, _, last_segment = self.path.rpartition('/')
    return last_segment
@property
def extension(self) -> str:
    """Text after the last ``.`` in ``basename``, or ``''`` when it has no dot."""
    name = self.basename
    _, dot, suffix = name.rpartition('.')
    if not dot:
        return ''
    return suffix
@property
def base_url(self) -> str:
    """``scheme`` and ``domain`` joined as ``scheme://domain``."""
    scheme, domain = self.scheme, self.domain
    return f'{scheme}://{domain}'
@property
def is_static(self) -> bool:
    """Whether the URL points at a static media/document file.

    True when either the raw URL or its path component ends (ignoring case) with one
    of the known static extensions, so ``file.pdf?download=1`` is recognized
    as well as ``file.pdf``.
    """
    from urllib.parse import urlsplit

    static_extensions = (
        '.pdf', '.jpg', '.jpeg', '.png', '.gif', '.webp',
        '.svg', '.mp4', '.mp3', '.wav', '.webm',
    )
    lowered = self.url.lower()
    if lowered.endswith(static_extensions):
        return True
    try:
        url_path = urlsplit(lowered).path
    except ValueError:
        # urlsplit rejected the URL: nothing more to check
        return False
    return url_path.endswith(static_extensions)
@property
def is_archived(self) -> bool:
    """Whether any of the known output paths exists in this link's archive folder."""
    expected_outputs = (
        self.domain,
        'output.html',
        'output.pdf',
        'screenshot.png',
        'singlefile.html',
        'readability/content.html',
        'mercury/content.html',
        'htmltotext.txt',
        'media',
        'git',
    )
    for relative in expected_outputs:
        candidate = Path(ARCHIVE_DIR) / self.timestamp / relative
        if candidate.exists():
            return True
    return False
def as_snapshot(self):
    """Return the ``core.models.Snapshot`` whose ``url`` equals this link's URL."""
    # Imported here rather than at module top level.
    from core.models import Snapshot
    return Snapshot.objects.get(url=self.url)
# Helper methods
@staticmethod
def _ts_to_date_str(dt: Optional[datetime]) -> Optional[str]:
    """Format ``dt`` as ``YYYY-MM-DD HH:MM:SS``; return None for a missing value."""
    if not dt:
        return None
    return dt.strftime('%Y-%m-%d %H:%M:%S')
@staticmethod
def _parse_date(date_str: Optional[str]) -> Optional[datetime]:
    """Best-effort conversion of a date value to a datetime.

    Accepts an ISO-8601 string (a trailing ``Z`` is read as UTC) or a Unix
    timestamp given as a string or number. A datetime is returned unchanged
    rather than crashing on the string handling. Returns None for falsy or
    unparseable input.
    """
    if not date_str:
        return None
    if isinstance(date_str, datetime):
        return date_str

    text = str(date_str)
    try:
        return datetime.fromisoformat(text.replace('Z', '+00:00'))
    except ValueError:
        pass

    # Not ISO-8601: try interpreting it as seconds since the epoch.
    try:
        return datetime.fromtimestamp(float(text))
    except (ValueError, TypeError, OverflowError, OSError):
        return None