# -*- coding: utf-8 -*-
# This file is part of HEPData.
# Copyright (C) 2016 CERN.
# HEPData is free software; you can redistribute it
# and/or modify it under the terms of the GNU General Public License as
# published by the Free Software Foundation; either version 2 of the
# License, or (at your option) any later version.
# HEPData is distributed in the hope that it will be
# useful, but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
# General Public License for more details.
# You should have received a copy of the GNU General Public License
# along with HEPData; if not, write to the
# Free Software Foundation, Inc., 59 Temple Place, Suite 330, Boston,
# MA 02111-1307, USA.
# In applying this license, CERN does not
# waive the privileges and immunities granted to it by virtue of its status
# as an Intergovernmental Organization or submit itself to any jurisdiction.
from invenio_db import db
from invenio_pidstore.errors import PIDDoesNotExistError
from invenio_pidstore.resolver import Resolver
from invenio_records.api import Record
import os
from sqlalchemy.orm.exc import NoResultFound

from hepdata.config import CFG_PUB_TYPE
from hepdata.ext.elasticsearch.api import get_record
from hepdata.modules.submission.models import HEPSubmission, License

    "py": "Python",
    "c": "C",
    "cpp": "C++",
    "cxx": "C++",
    "cc": "C++",
    "C": "C++",
    "sh": "Bash Shell",
    "pl": "Perl",
    "cs": "C# Source Code",
    "java": "Java",
    "root": "ROOT",
    "json": "JSON",
    "yaml": "YAML",
    "txt": "Text",
    "rtf": "Text",
    "xls": "Excel",
    "xlsx": "Excel",
    "slha": "SUSY Les Houches Accord",
    "f": "Fortran"



# File name suffixes accepted by allowed_file(); kept as a tuple so it can be
# passed directly to str.endswith().
ALLOWED_EXTENSIONS = (
    '.zip',
    '.tar',
    '.tar.gz',
    '.tgz',
    '.oldhepdata',
    '.yaml',
    '.yaml.gz',
)

def contains_accepted_url(file):
    """Check whether ``file`` contains one of the accepted URL patterns.

    :param file: value to search, tested with ``pattern in file``.
    :return: ``(True, pattern)`` for the first entry of ``URL_PATTERNS``
        found in ``file``, otherwise ``(False, None)``.
    """
    for pattern in URL_PATTERNS:
        if pattern in file:
            return True, pattern
    return False, None
def allowed_file(filename):
    """Return True if ``filename`` ends with one of ``ALLOWED_EXTENSIONS``.

    The comparison is case-sensitive.

    :param filename: name of the file to check.
    :return: bool
    """
    return filename.endswith(ALLOWED_EXTENSIONS)
def is_image(filename):
    """Return True if the extension of ``filename`` is listed in ``IMAGE_TYPES``.

    The extension is the text after the last dot and is lower-cased before
    the lookup. Names without a dot are never treated as images.

    :param filename: name of the file to check.
    :return: bool
    """
    if '.' not in filename:
        return False

    extension = filename.rsplit(".", 1)[1]
    return extension.lower() in IMAGE_TYPES
def infer_file_type(file):
    """Infer a descriptive type for a file name or location.

    :param file: file name or location string.
    :return: ``"resource"`` if ``file`` contains no dot; the matching pattern
        if ``contains_accepted_url`` finds one; otherwise the ``FILE_TYPES``
        entry for the extension, or the bare extension when it is unknown.
    """
    if "." not in file:
        return "resource"

    result, pattern = contains_accepted_url(file)
    if result:
        return pattern

    # Case is preserved on purpose: FILE_TYPES distinguishes "c" from "C".
    extension = file.rsplit(".", 1)[1]
    return FILE_TYPES.get(extension, extension)
def get_or_create(session, model, **kwargs):
    """Fetch the first ``model`` row matching ``kwargs``, creating it if absent.

    :param session: database session used to query, add and commit.
    :param model: model class to query and, if needed, instantiate.
    :param kwargs: column values used both as the filter and as the
        constructor arguments of a newly created instance.
    :return: the existing or newly created instance. A new instance is added
        to the session and the session is committed before returning.
    """
    instance = session.query(model).filter_by(**kwargs).first()
    if instance:
        return instance

    instance = model(**kwargs)
    session.add(instance)
    session.commit()
    return instance
def remove_file_extension(filename):
    """Strip the last dot-separated extension from ``filename``.

    Only the final extension is removed (``"a.tar.gz"`` becomes ``"a.tar"``).

    :param filename: file name to process.
    :return: the name without its last extension, or the name unchanged if
        it contains no dot.
    """
    # rsplit returns the whole string as its single element when there is no
    # dot, so no separate "no extension" branch is needed.
    return filename.rsplit('.', 1)[0]
def encode_string(string, type="utf-8"):
    """Encode ``string`` with the given codec, replacing unencodable characters.

    :param string: value to encode.
    :param type: codec name passed to ``encode`` (default ``"utf-8"``).
    :return: the encoded value, or ``string`` unchanged if it has no
        ``encode`` method (for example ``bytes`` or ``None``).
    """
    try:
        return string.encode(type, errors='replace')
    except AttributeError:
        return string
def decode_string(string, type="utf-8"):
    """Decode ``string`` with the given codec, replacing undecodable bytes.

    :param string: value to decode.
    :param type: codec name passed to ``decode`` (default ``"utf-8"``).
    :return: the decoded value, or ``string`` unchanged if it has no
        ``decode`` method (for example a Python 3 ``str``) or if decoding
        raises ``UnicodeEncodeError``.
    """
    try:
        return string.decode(type, errors='replace')
    except (AttributeError, UnicodeEncodeError):
        # NOTE(review): UnicodeEncodeError is presumably handled for Python 2
        # unicode objects - confirm whether it is still needed.
        return string
def get_license(license_obj):
    """Return the ``License`` row matching ``license_obj``, creating it if needed.

    :param license_obj: container that may hold the keys ``name``, ``url``
        and ``description``; any missing key defaults to an empty string.
    :return: the result of ``get_or_create`` for ``License`` with those
        three values.
    """
    # Named "fields" rather than "dict" so the builtin is not shadowed.
    fields = {
        field: license_obj[field] if field in license_obj else ""
        for field in ("name", "url", "description")
    }

    return get_or_create(
        db.session, License,
        name=fields['name'],
        url=fields['url'],
        description=fields['description'])
def find_file_in_directory(directory, file_predicate):
    """
    Finds a file in a directory. Useful for say when the submission.yaml file
    is not at the top level of the unzipped archive but one or more levels
    below.

    :param directory: directory to walk, including all of its subdirectories.
    :param file_predicate: a lambda that checks if it's the file you're
        looking for; it is called with the bare file name, not the full path.
    :return: ``(containing_directory, full_path)`` for the first file
        accepted by the predicate, or ``None`` if no file matches. Callers
        must check for ``None`` before unpacking the result.
    """
    for root, dirs, files in os.walk(directory):
        for filename in files:
            if file_predicate(filename):
                return root, os.path.join(root, filename)
    return None
def default_time(obj):
    """Default JSON serializer.

    :param obj: value to serialise.
    :return: ``obj`` formatted as ``'%Y-%m-%d %H:%M:%S %Z'`` if it is a
        ``datetime.datetime``; any other value is returned unchanged. For a
        naive datetime ``%Z`` is empty, so the string ends with a space.
    """
    import datetime

    if isinstance(obj, datetime.datetime):
        return obj.strftime('%Y-%m-%d %H:%M:%S %Z')
    return obj
def truncate_string(string, words):
    """Shorten ``string`` to at most ``words`` whitespace-separated words.

    Runs of whitespace in the input are collapsed to single spaces in the
    result, because the words are split and re-joined.

    :param string: text to truncate.
    :param words: maximum number of words to keep.
    :return: the first ``words`` words joined by single spaces, followed by
        ``'...'`` if the input contained more words than that.
    """
    all_words = string.split()
    truncated_string = ' '.join(all_words[:words])

    if len(all_words) > words:
        truncated_string += '...'

    return truncated_string
def get_record_contents(recid, status=None):
    """
    Tries to get record from Elasticsearch first. Failing that,
    it tries from the database.

    :param recid: Record ID to get.
    :param status: Status of submission. If provided and not 'finished',
        will not check elasticsearch first.
    :return: a dictionary containing the record contents if the recid exists,
        None otherwise.
    """
    record = None

    if status is None or status == 'finished':
        record = get_record(recid)

    if record is None:
        # Fall back to the database when Elasticsearch was skipped or had
        # no entry for this recid.
        try:
            record = get_record_by_id(recid)
        except PIDDoesNotExistError:
            return None

    return record
def get_record_by_id(recid):
    """Resolve ``recid`` to its record through the PID resolver.

    :param recid: record ID to resolve (PID type ``recid``, object type
        ``rec``).
    :return: the resolved record, or ``None`` if the lookup raises
        ``NoResultFound`` or ``PIDDoesNotExistError``. In both failure cases
        a message is printed to standard output.
    """
    try:
        resolver = Resolver(pid_type='recid', object_type='rec',
                            getter=Record.get_record)
        # Only the record is needed; the PID object is discarded.
        _pid, record = resolver.resolve(recid)
        return record
    except NoResultFound:
        print('No record found for recid {}'.format(recid))
        return None
    except PIDDoesNotExistError:
        print('The PID {0} does not exist'.format(recid))
        return None
def record_exists(*args, **kwargs):
    """Return True if at least one ``HEPSubmission`` matches ``kwargs``.

    :param args: accepted for call compatibility but not used.
    :param kwargs: column values passed to ``filter_by``.
    :return: bool
    """
    count = HEPSubmission.query.filter_by(**kwargs).count()
    return count > 0