Created April 13, 2023 21:01
Save betodealmeida/588289331d49add3dbb937e419d2e8f6 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import logging
from typing import Any, Dict, Generic, List, TypeVar

from flask_appbuilder import Model
from marshmallow import Schema, fields

from superset import security_manager
from superset.connectors.sqla.models import SqlaTable
from superset.datasets.commands.exceptions import (
    DatasetForbiddenError,
    DatasetNotFoundError,
    DatasetRefreshFailedError,
)
from superset.datasets.dao import DatasetDAO
from superset.exceptions import SupersetSecurityException
# Module-level logger, named after the importing module.
logger = logging.getLogger(name=__name__)
class CommandSchema(Schema):
    """
    Default Marshmallow schema attached to ``NewCommand``.

    Concrete commands replace it with their own schema (for example
    ``RefreshDatasetSchema``).
    """

    # Optional integer field (``required`` is left at its default, False).
    number = fields.Int()
# Type parameter for the model(s) a command operates on. ``Generic`` only
# accepts type variables, so the concrete ``Model`` class is used as the bound.
ModelT = TypeVar("ModelT", bound=Model)


class NewCommand(Generic[ModelT]):
    """
    Base class for commands that load models from validated input and act on them.

    Subclasses set ``schema`` and implement ``load_models`` and ``run``.
    ``load_models`` is called from ``__init__``, so any error it raises
    surfaces when the command is constructed, before ``run`` is called.
    """

    # Marshmallow schema used by ``from_serialized`` to validate raw payloads.
    schema: Schema = CommandSchema()

    def __init__(self, deserialized: Dict[str, Any]):
        self.models = self.load_models(deserialized)

    @classmethod
    def from_serialized(cls, serialized: Dict[str, Any]) -> "NewCommand[ModelT]":
        """
        Instantiate the command from a serialized dictionary.

        The payload is validated and deserialized with ``cls.schema``.

        :param serialized: the raw payload.
        :returns: an instance of ``cls`` (the concrete subclass it was called on).
        :raises marshmallow.ValidationError: if the payload fails validation.
        """
        deserialized = cls.schema.load(serialized)
        # Dispatch through ``cls`` so a subclass builds an instance of itself,
        # not of this abstract base (whose ``load_models`` always raises).
        return cls.from_deserialized(deserialized)

    @classmethod
    def from_deserialized(cls, deserialized: Dict[str, Any]) -> "NewCommand[ModelT]":
        """
        Instantiate the command from an already deserialized dictionary.

        This is useful when the command is called from an API endpoint. The
        API can run the Marshmallow schema validation on the request body
        itself and pass the resulting dictionary here, so the payload is not
        loaded a second time.

        :param deserialized: the validated payload.
        :returns: an instance of ``cls``.
        """
        return cls(deserialized)

    def load_models(self, deserialized: Dict[str, Any]) -> List[ModelT]:
        """
        Load the models the command operates on from the deserialized input.

        Even if the command works on a single model, it should return a list
        with a single element.

        :raises NotImplementedError: always; subclasses must override this.
        """
        raise NotImplementedError("Subclasses must implement load_models")

    def run(self) -> Any:
        """
        Execute the command.

        :raises NotImplementedError: always; subclasses must override this.
        """
        raise NotImplementedError("Subclasses must implement run")
class RefreshDatasetSchema(Schema):
    """Payload schema for ``RefreshDatasetCommand``."""

    # Primary key of the dataset to refresh. It is required, so a payload
    # without it fails ``Schema.load`` validation.
    pk = fields.Int(required=True)
class RefreshDatasetCommand(NewCommand[SqlaTable]):
    """
    Re-fetch the metadata of a single dataset via ``SqlaTable.fetch_metadata``.
    """

    schema = RefreshDatasetSchema()

    def load_models(self, deserialized: Dict[str, Any]) -> List[SqlaTable]:
        """
        Look up the dataset by primary key and enforce the ownership check.

        :param deserialized: validated payload; must contain ``"pk"``.
        :returns: a one-element list holding the dataset.
        :raises DatasetNotFoundError: if no dataset matches ``pk``.
        :raises DatasetForbiddenError: if ``raise_for_ownership`` rejects it.
        """
        dataset = DatasetDAO.find_by_id(deserialized["pk"])
        if not dataset:
            raise DatasetNotFoundError()

        # Translate the generic security failure into the dataset-specific
        # error, keeping the original exception as the cause.
        try:
            security_manager.raise_for_ownership(dataset)
        except SupersetSecurityException as ex:
            raise DatasetForbiddenError() from ex

        return [dataset]

    def run(self) -> SqlaTable:
        """
        Fetch fresh metadata for the loaded dataset and return the dataset.

        :raises DatasetRefreshFailedError: if ``fetch_metadata`` raises
            anything. The failure is logged with its traceback and chained
            as the cause.
        """
        dataset = self.models[0]
        try:
            dataset.fetch_metadata()
        except Exception as ex:
            logger.exception(
                "An error occurred while fetching dataset metadata"
            )
            raise DatasetRefreshFailedError() from ex
        return dataset
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.