# Source code for watson_machine_learning_client.training

################################################################################
#
# Licensed Materials - Property of IBM
# (C) Copyright IBM Corp. 2017
# US Government Users Restricted Rights - Use, duplication disclosure restricted
# by GSA ADP Schedule Contract with IBM Corp.
#
################################################################################

from __future__ import print_function
import requests
import json
import re
from watson_machine_learning_client.utils import print_text_header_h1, print_text_header_h2, TRAINING_RUN_DETAILS_TYPE, STR_TYPE, STR_TYPE_NAME, docstring_parameter, str_type_conv, meta_props_str_conv, group_metrics, StatusLogger
import time
from watson_machine_learning_client.metanames import TrainingConfigurationMetaNames
from watson_machine_learning_client.wml_client_error import WMLClientError
from watson_machine_learning_client.href_definitions import is_uid
from watson_machine_learning_client.wml_resource import WMLResource
from ibm_boto3.exceptions import Boto3Error



class Training(WMLResource):
    """
    Train new models.
    """

    def __init__(self, client):
        WMLResource.__init__(self, __name__, client)
        self.ConfigurationMetaNames = TrainingConfigurationMetaNames()

    @staticmethod
    def _is_training_uid(s):
        # Raw strings: the non-raw '\-' / '\_' escapes are invalid escape sequences
        # (DeprecationWarning on Python 3.6+); the resulting pattern is identical.
        res = re.match(r'training\-[a-zA-Z0-9\-\_]+', s)
        return res is not None

    @staticmethod
    def _is_training_url(s):
        res = re.match(r'\/v3\/models\/training\-[a-zA-Z0-9\-\_]+', s)
        return res is not None

    def get_frameworks(self):
        """
        Get list of supported frameworks.

        :returns: supported frameworks for training, or None if the request did not return HTTP 200 (the error is printed)
        :rtype: dict

        **Example**

        >>> frameworks = client.training.get_frameworks()
        """
        response_get = requests.get(self._href_definitions.get_repo_models_frameworks_href(), headers=self._client._get_headers())

        if response_get.status_code == 200:
            return response_get.json()

        error_msg = 'Getting supported frameworks failed.' + '\n' + "Error msg: " + response_get.text
        print(error_msg)
        return None

    @docstring_parameter({'str_type': STR_TYPE_NAME})
    def get_status(self, run_uid):
        """
        Get training status.

        :param run_uid: ID of trained model
        :type run_uid: str

        :returns: training run status (the 'status' section of the run's 'entity')
        :rtype: dict

        :raises WMLClientError: if the run details could not be fetched

        **Example**

        >>> training_status = client.training.get_status(run_uid)
        """
        run_uid = str_type_conv(run_uid)
        Training._validate_type(run_uid, 'run_uid', STR_TYPE, True)

        details = self.get_details(run_uid)

        if details is not None:
            return WMLResource._get_required_element_from_dict(details, u'details', [u'entity', u'status'])
        else:
            raise WMLClientError(u'Getting trained model status failed. Unable to get model details for run_uid: \'{}\'.'.format(run_uid))

    @docstring_parameter({'str_type': STR_TYPE_NAME})
    def get_details(self, run_uid=None, limit=None):
        """
        Get training run details.

        :param run_uid: ID of training run (optional, if not provided all runs details are returned)
        :type run_uid: str

        :param limit: limit number of fetched records (optional)
        :type limit: int

        :returns: training run(s) details
        :rtype: dict (if uid is not None) or {"resources": [dict]} (if uid is None)

        **Example**

        >>> training_run_details = client.training.get_details(run_uid)
        >>> training_runs_details = client.training.get_details()
        """
        run_uid = str_type_conv(run_uid)
        Training._validate_type(run_uid, 'run_uid', STR_TYPE, False)
        url = self._href_definitions.get_repo_models_href()

        return self._get_artifact_details(url, run_uid, limit, 'trained models')

    @staticmethod
    @docstring_parameter({'str_type': STR_TYPE_NAME})
    def get_run_url(run_details):
        """
        Get training run url from training run details.

        :param run_details: Created training run details
        :type run_details: dict

        :returns: training run URL that is used to manage the training (run_details['metadata']['url'])
        :rtype: str

        **Example**

        >>> run_url = client.training.get_run_url(run_details)
        """
        Training._validate_type(run_details, u'run_details', object, True)
        Training._validate_type_of_details(run_details, TRAINING_RUN_DETAILS_TYPE)
        return WMLResource._get_required_element_from_dict(run_details, u'run_details', [u'metadata', u'url'])

    @staticmethod
    @docstring_parameter({'str_type': STR_TYPE_NAME})
    def get_run_uid(run_details):
        """
        Get uid of training run.

        :param run_details: training run details
        :type run_details: dict

        :returns: uid of training run (run_details['metadata']['guid'])
        :rtype: str

        **Example**

        >>> model_uid = client.training.get_run_uid(run_details)
        """
        Training._validate_type(run_details, u'run_details', object, True)
        Training._validate_type_of_details(run_details, TRAINING_RUN_DETAILS_TYPE)
        return WMLResource._get_required_element_from_dict(run_details, u'run_details', [u'metadata', u'guid'])

    @docstring_parameter({'str_type': STR_TYPE_NAME})
    def cancel(self, run_uid):
        """
        Cancel model training.

        Sends a JSON-patch request that sets the run's '/status/state' to 'canceled'.

        :param run_uid: ID of trained model
        :type run_uid: str

        **Example**

        >>> client.training.cancel(run_uid)
        """
        run_uid = str_type_conv(run_uid)
        Training._validate_type(run_uid, u'run_uid', STR_TYPE, True)

        patch_endpoint = self._href_definitions.get_repo_model_href(run_uid)
        patch_payload = [
            {
                u'op': u'replace',
                u'path': u'/status/state',
                u'value': u'canceled'
            }
        ]

        response_patch = requests.patch(patch_endpoint, json=patch_payload, headers=self._client._get_headers())

        self._handle_response(204, u'model training cancel', response_patch, False)

    @docstring_parameter({'str_type': STR_TYPE_NAME})
    def run(self, definition_uid, meta_props, asynchronous=True):
        """
        Train new model.

        :param definition_uid: uid to saved model_definition/pipeline
        :type definition_uid: str

        :param meta_props: meta data of the training configuration. To see available meta names use:

            >>> client.training.ConfigurationMetaNames.show()

        :type meta_props: dict

        :param asynchronous: Default `True` means that training job is submitted and progress can be checked later.
               `False` - method will wait till job completion and print training stats.
        :type asynchronous: bool

        :returns: training run details (when asynchronous is False, the details fetched after the run finished)
        :rtype: dict

        :raises WMLClientError: if definition_uid is not a valid uid

        **Example**

        >>> metadata = {
        >>>  client.training.ConfigurationMetaNames.NAME: u'Hand-written Digit Recognition',
        >>>  client.training.ConfigurationMetaNames.TRAINING_DATA_REFERENCE: {
        >>>          u'connection': {
        >>>              u'endpoint_url': u'https://s3-api.us-geo.objectstorage.service.networklayer.com',
        >>>              u'access_key_id': u'***',
        >>>              u'secret_access_key': u'***'
        >>>          },
        >>>          u'source': {
        >>>              u'bucket': u'wml-dev',
        >>>          },
        >>>          u'type': u's3'
        >>>      },
        >>> client.training.ConfigurationMetaNames.TRAINING_RESULTS_REFERENCE: {
        >>>          u'connection': {
        >>>              u'endpoint_url': u'https://s3-api.us-geo.objectstorage.service.networklayer.com',
        >>>              u'access_key_id': u'***',
        >>>              u'secret_access_key': u'***'
        >>>          },
        >>>          u'target': {
        >>>              u'bucket': u'wml-dev-results',
        >>>          },
        >>>          u'type': u's3'
        >>>      },
        >>> client.training.ConfigurationMetaNames.COMPUTE_CONFIGURATION: {'name': 'p100'},
        >>> client.training.ConfigurationMetaNames.EXECUTION_COMMAND: "python3 tensorflow_mnist_softmax.py --trainingIters 20",
        >>> }
        >>> run_details = client.training.run(definition_uid, meta_props=metadata)
        >>> run_uid = client.training.get_run_uid(run_details)
        """
        definition_uid = str_type_conv(definition_uid)
        Training._validate_type(definition_uid, 'definition_uid', STR_TYPE, True)
        Training._validate_type(meta_props, 'meta_props', object, True)
        Training._validate_type(asynchronous, 'asynchronous', bool, True)
        meta_props_str_conv(meta_props)
        self.ConfigurationMetaNames._validate(meta_props)

        if definition_uid is not None and is_uid(definition_uid):
            definition_url = self._href_definitions.get_definition_href(definition_uid)
        elif definition_uid is not None:
            raise WMLClientError(u'Invalid uid: \'{}\'.'.format(definition_uid))
        else:
            raise WMLClientError(u'Both uid and url are empty.')

        details = self._client.repository.get_definition_details(definition_uid)

        # TODO remove when training service starts copying such data on their own
        framework_name = details[u'entity'][u'framework'][u'name']
        framework_version = details[u'entity'][u'framework'][u'version']

        # Fall back to the command stored with the definition. Resolved into a local
        # (instead of being written back into meta_props) so the caller's dict is
        # not mutated as a side effect of this call.
        if self.ConfigurationMetaNames.EXECUTION_COMMAND in meta_props:
            execution_command = meta_props[self.ConfigurationMetaNames.EXECUTION_COMMAND]
        else:
            execution_command = details['entity']['command']

        training_configuration_metadata = {
            u'model_definition': {
                u'framework': {
                    u'name': framework_name,
                    u'version': framework_version
                },
                u'name': meta_props[self.ConfigurationMetaNames.NAME],
                u'author': {
                },
                u'definition_href': definition_url,
                u'execution': {
                    u'command': execution_command,
                    u'compute_configuration': {u'name': self.ConfigurationMetaNames._COMPUTE_CONFIGURATION_DEFAULT}
                }
            },
            u'training_data_reference': meta_props[self.ConfigurationMetaNames.TRAINING_DATA_REFERENCE],
            u'training_results_reference': meta_props[self.ConfigurationMetaNames.TRAINING_RESULTS_REFERENCE]
        }

        if self.ConfigurationMetaNames.DESCRIPTION in meta_props:
            training_configuration_metadata[u'model_definition'].update({u'description': meta_props[self.ConfigurationMetaNames.DESCRIPTION]})

        if self.ConfigurationMetaNames.AUTHOR_NAME in meta_props:
            training_configuration_metadata[u'model_definition'][u'author'].update({u'name': meta_props[self.ConfigurationMetaNames.AUTHOR_NAME]})

        # TODO: once framework name/version and the execution command are truly optional
        # in the service, send them only when present in meta_props instead of copying
        # them from the stored definition above.

        if self.ConfigurationMetaNames.COMPUTE_CONFIGURATION in meta_props:
            # Merged over the default {'name': _COMPUTE_CONFIGURATION_DEFAULT}.
            training_configuration_metadata[u'model_definition'][u'execution'][u'compute_configuration'].update(meta_props[self.ConfigurationMetaNames.COMPUTE_CONFIGURATION])

        train_endpoint = u'{}/v3/models'.format(self._wml_credentials[u'url'])

        response_train_post = requests.post(train_endpoint, json=training_configuration_metadata, headers=self._client._get_headers())

        run_details = self._handle_response(202, u'training', response_train_post)

        trained_model_guid = self.get_run_uid(run_details)

        if asynchronous is True:
            return run_details

        print_text_header_h1(u'Running \'{}\''.format(trained_model_guid))

        status = self.get_status(trained_model_guid)
        state = status[u'state']

        with StatusLogger(state) as status_logger:
            while state not in ['error', 'completed', 'canceled']:
                time.sleep(5)
                # Keep the whole status dict up to date so the failure message below
                # reports the final status, not the one fetched before polling began.
                status = self.get_status(trained_model_guid)
                state = status[u'state']
                status_logger.log_state(state)

        if state == u'completed':
            print(u'\nTraining of \'{}\' finished successfully.'.format(str(trained_model_guid)))
        else:
            print(u'\nTraining of \'{}\' failed with status: \'{}\'.'.format(trained_model_guid, str(status)))

        self._logger.debug(u'Response({}): {}'.format(state, run_details))
        return self.get_details(trained_model_guid)

    def list(self, limit=None):
        """
        List training runs. If limit is set to None there will be only first 50 records shown.

        :param limit: maximum number of rows to print (optional)
        :type limit: int

        :returns: None
        :rtype: None

        .. note::
            This function only prints the list of training runs

        **Example**

        >>> client.training.list()
        """
        details = self.get_details()
        resources = details[u'resources']
        values = [(m[u'metadata'][u'guid'], m[u'entity'][u'model_definition'][u'name'], m[u'entity'][u'status'][u'state'], m[u'metadata'][u'created_at'], m[u'entity'][u'model_definition'][u'framework'][u'name']) for m in resources]

        self._list(values, [u'GUID (training)', u'NAME', u'STATE', u'CREATED', u'FRAMEWORK'], limit, 50)

    @docstring_parameter({'str_type': STR_TYPE_NAME})
    def delete(self, run_uid):
        """
        Delete training run.

        :param run_uid: ID of trained model
        :type run_uid: str

        :returns: returns the status message ("SUCCESS" or "FAILED")
        :rtype: str

        **Example**

        >>> client.training.delete(run_uid)
        """
        run_uid = str_type_conv(run_uid)
        Training._validate_type(run_uid, u'run_uid', STR_TYPE, True)

        response_delete = requests.delete(self._href_definitions.get_repo_model_href(run_uid), headers=self._client._get_headers())

        return self._handle_response(204, u'trained model deletion', response_delete, False)

    @docstring_parameter({'str_type': STR_TYPE_NAME})
    def monitor_logs(self, run_uid):
        """
        Monitor training log file (prints log content to console).

        For a finished run the log is read from the results bucket; otherwise it is
        streamed from the service's monitor websocket.

        :param run_uid: ID of trained model
        :type run_uid: str

        **Example**

        >>> client.training.monitor_logs(run_uid)
        """
        run_uid = str_type_conv(run_uid)
        Training._validate_type(run_uid, u'run_uid', STR_TYPE, True)

        self._simple_monitor_logs(run_uid, lambda: print_text_header_h1(u'Log monitor started for training run: ' + str(run_uid)))

        print_text_header_h2('Log monitor done.')

    @staticmethod
    def _to_bytes(value):
        """Encode a header name/value as bytes on both Python 2 and Python 3."""
        try:
            return bytes(value, "utf-8")
        except TypeError:
            # Python 2: `bytes` is `str`, which rejects the encoding argument.
            # Python 3: value is already bytes.
            return bytes(value)

    def _create_monitor_websocket(self, run_uid):
        """Create a lomond WebSocket for the run's /monitor endpoint with auth headers set."""
        from lomond import WebSocket

        monitor_endpoint = self._wml_credentials[u'url'].replace(u'https', u'wss') + u'/v3/models/' + run_uid + u'/monitor'
        websocket = WebSocket(monitor_endpoint)

        # Fetch the token once; a failure here should surface, not trigger a retry.
        token = self._client.service_instance._get_token()
        websocket.add_header(Training._to_bytes("Authorization"), Training._to_bytes("bearer " + token))

        if 'apikey' in self._wml_credentials.keys():
            websocket.add_header(Training._to_bytes('ML-Instance-ID'), Training._to_bytes(self._wml_credentials['instance_id']))

        return websocket

    @staticmethod
    def _get_run_error_message(run_details):
        """Return the first error message reported for a run, or a fallback naming its state."""
        status = run_details.get("entity", {}).get("status", {})
        try:
            return status["error"]["errors"][0]["message"]
        except (KeyError, IndexError, TypeError):
            return u'no error reported (state: {})'.format(status.get("state"))

    def _print_cos_result_file(self, run_uid, file_name, on_start):
        """
        Print a text file stored in COS under '<model_location>/learner-1/'.

        :param run_uid: ID of the training run
        :param file_name: name of the file inside the learner-1 directory
        :param on_start: callback invoked first (e.g. prints a header)

        :raises WMLClientError: if the run has no 'model_location' yet (not started)
        """
        on_start()
        run_details = self.get_details(run_uid)
        results_reference = run_details["entity"]["training_results_reference"]

        endpoint_url = results_reference["connection"]["endpoint_url"]
        # Rewrite the 'service.networklayer.com' host to 'softlayer.net'
        # (presumably the private endpoint is not reachable from the client — confirm).
        if "networklayer" in endpoint_url:
            endpoint_url = endpoint_url.replace("service.networklayer.com", "softlayer.net")

        aws_access_key = results_reference["connection"]["access_key_id"]
        aws_secret = results_reference["connection"]["secret_access_key"]
        location = results_reference["location"]
        bucket = location["bucket"]

        # 'model_location' is absent until the run has started.
        if "model_location" not in location:
            raise WMLClientError("The training-run has not started. Error - " + Training._get_run_error_message(run_details))

        if bucket == "":
            bucket = results_reference["target"]["bucket"]

        import ibm_boto3
        client_cos = ibm_boto3.client(service_name='s3', aws_access_key_id=aws_access_key, aws_secret_access_key=aws_secret, endpoint_url=endpoint_url)

        try:
            key = location["model_location"] + '/learner-1/' + file_name
            obj = client_cos.get_object(Bucket=bucket, Key=key)
            print(obj['Body'].read().decode('utf-8'))
        except ibm_boto3.exceptions.ibm_botocore.client.ClientError as ex:
            if ex.response['Error']['Code'] == 'NoSuchKey':
                print("ERROR - Cannot find " + file_name + " in the bucket")
            else:
                print(ex)

    def _COS_logs(self, run_uid, on_start=lambda: {}):
        """Print training-log.txt of a run from its results bucket."""
        self._print_cos_result_file(run_uid, 'training-log.txt', on_start)

    def _COS_metrics(self, run_uid, on_start=lambda: {}):
        """Print evaluation-metrics.txt of a run from its results bucket."""
        self._print_cos_result_file(run_uid, 'evaluation-metrics.txt', on_start)

    def _simple_monitor_logs(self, run_uid, on_start=lambda: {}):
        """
        Print a run's log: from COS if the run is finished, else streamed via websocket.

        :param on_start: callback invoked once, before any log content is printed
        """
        run_details = self.get_details(run_uid)
        status = run_details["entity"]["status"]["state"]

        if status in ("completed", "error", "failed", "canceled"):
            self._COS_logs(run_uid, on_start)
            return

        websocket = self._create_monitor_websocket(run_uid)
        on_start()
        try:
            for event in websocket:
                if event.name != u'text':
                    continue
                text = json.loads(event.text)
                message = text.get(u'status', {}).get(u'message')
                if message:
                    print(message)
        finally:
            websocket.close()

    @docstring_parameter({'str_type': STR_TYPE_NAME})
    def monitor_metrics(self, run_uid):
        """
        Monitor metrics log file (prints log content to console).

        For a finished run the metrics file is read from the results bucket; otherwise
        metrics are streamed from the service's monitor websocket.

        :param run_uid: ID of trained model
        :type run_uid: str

        **Example**

        >>> client.training.monitor_metrics(run_uid)
        """
        run_uid = str_type_conv(run_uid)
        Training._validate_type(run_uid, u'run_uid', STR_TYPE, True)

        run_details = self.get_details(run_uid)
        status = run_details["entity"]["status"]["state"]

        if status in ("completed", "error", "failed", "canceled"):
            self._COS_metrics(run_uid, lambda: print_text_header_h1('Metric monitor started for training run: ' + str(run_uid)))
            print_text_header_h2('Metric monitor done.')
            return

        websocket = self._create_monitor_websocket(run_uid)
        print_text_header_h1('Metric monitor started for training run: ' + str(run_uid))

        try:
            for event in websocket:
                if event.name != u'text':
                    continue
                text = json.loads(event.text)
                # Not every monitor message carries a 'status'; skip those instead of raising KeyError.
                metrics = text.get(u'status', {}).get(u'metrics', [])
                if not metrics:
                    continue
                metric = metrics[0]
                values = u''.join(x[u'name'] + ':' + str(x[u'value']) + u' ' for x in metric[u'values'])
                msg = u'{} iteration:{} phase:{} {}'.format(metric[u'timestamp'], metric[u'iteration'], metric[u'phase'], values)
                print(msg)
        finally:
            websocket.close()

        print_text_header_h2('Metric monitor done.')

    @docstring_parameter({'str_type': STR_TYPE_NAME})
    def get_metrics(self, run_uid):
        """
        Get metrics values.

        :param run_uid: ID of trained model
        :type run_uid: str

        :returns: metric values (empty list if the run has not reported any metrics yet)
        :rtype: list of dicts

        **Example**

        >>> client.training.get_metrics(run_uid)
        """
        run_uid = str_type_conv(run_uid)
        Training._validate_type(run_uid, u'run_uid', STR_TYPE, True)

        status = self.get_status(run_uid)
        return status.get('metrics', [])

    @docstring_parameter({'str_type': STR_TYPE_NAME})
    def get_latest_metrics(self, run_uid):
        """
        Get latest metrics values.

        For each group returned by group_metrics, the entry with the highest
        'iteration' is kept.

        :param run_uid: ID of trained model
        :type run_uid: str

        :returns: metric values (empty list if the run has not reported any metrics yet)
        :rtype: list of dicts

        **Example**

        >>> client.training.get_latest_metrics(run_uid)
        """
        run_uid = str_type_conv(run_uid)
        Training._validate_type(run_uid, u'run_uid', STR_TYPE, True)

        status = self.get_status(run_uid)
        metrics = status.get('metrics', [])
        latest_metrics = []

        if len(metrics) > 0:
            grouped_metrics = group_metrics(metrics)

            for key, value in grouped_metrics.items():
                # Stable sort + [-1]: among equal iterations the last-reported entry wins.
                sorted_value = sorted(value, key=lambda k: k['iteration'])
                latest_metrics.append(sorted_value[-1])

        return latest_metrics

    def list_definitions(self, limit=None):
        """
        List stored definitions. If limit is set to None there will be only first 50 records shown.

        :param limit: maximum number of rows to print (optional)
        :type limit: int

        :returns: None
        :rtype: None

        .. note::
            This function only prints the list of stored training-definitions

        **Example**:

        >>> client.training.list_definitions()
        """
        self._client.repository.list_definitions(limit=limit)

    @docstring_parameter({'str_type': STR_TYPE_NAME})
    def get_definition_details(self, definition_uid=None, limit=None):
        """
        Get metadata of stored definitions. If definition uid is not specified returns all model definitions metadata.

        :param definition_uid: stored model definition UID (optional)
        :type definition_uid: str

        :param limit: limit number of fetched records (optional)
        :type limit: int

        :returns: stored definition(s) metadata
        :rtype: dict

        **Example**

        >>> definition_details = client.training.get_definition_details(definition_uid)
        >>> definition_details = client.training.get_definition_details()
        """
        return self._client.repository.get_definition_details(definition_uid, limit=limit)

    def get_definition_uid(self, definition_details):
        """
        Get uid of stored definition.

        :param definition_details: stored definition details
        :type definition_details: dict

        :returns: uid of stored model
        :rtype: str

        **Example**

        >>> definition_uid = client.training.get_definition_uid(definition_details)
        """
        return self._client.repository.get_definition_uid(definition_details)

    def get_definition_url(self, definition_details):
        """
        Get url of stored definition.

        :param definition_details: stored definition details
        :type definition_details: dict

        :returns: url of stored definition
        :rtype: str

        **Example**

        >>> definition_url = client.training.get_definition_url(definition_details)
        """
        return self._client.repository.get_definition_url(definition_details)