spark-reviews mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From jkbradley <...@git.apache.org>
Subject [GitHub] spark pull request: [SPARK-6893][ML] default pipeline parameter ha...
Date Wed, 15 Apr 2015 22:40:01 GMT
GitHub user jkbradley commented on a diff in the pull request:

    https://github.com/apache/spark/pull/5534#discussion_r28470162
  
    --- Diff: python/pyspark/ml/param/__init__.py ---
    @@ -67,11 +66,112 @@ def params(self):
             return filter(lambda attr: isinstance(attr, Param),
                           [getattr(self, x) for x in dir(self) if x != "params"])
     
    -    def _merge_params(self, params):
    -        paramMap = self.paramMap.copy()
    -        paramMap.update(params)
    +    def _explain(self, param):
    +        """
    +        Explains a single param and returns its name, doc, and optional
    +        default value and user-supplied value in a string.
    +        """
    +        param = self._resolveParam(param)
    +        values = []
    +        if self.isDefined(param):
    +            if param in self.defaultParamMap:
    +                values.append("default: %s" % self.defaultParamMap[param])
    +            if param in self.paramMap:
    +                values.append("current: %s" % self.paramMap[param])
    +        else:
    +            values.append("undefined")
    +        valueStr = "(" + ", ".join(values) + ")"
    +        return "%s: %s %s" % (param.name, param.doc, valueStr)
    +
    +    def explainParams(self):
    +        """
    +        Returns the documentation of all params with their optionally
    +        default values and user-supplied values.
    +        """
    +        return "\n".join([self._explain(param) for param in self.params])
    +
    +    def getParam(self, paramName):
    +        """
    +        Gets a param by its name.
    +        """
    +        param = getattr(self, paramName)
    +        if isinstance(param, Param):
    +            return param
    +        else:
    +            raise ValueError("Cannot find param with name %s." % paramName)
    +
    +    def isSet(self, param):
    +        """
    +        Checks whether a param is explicitly set by user.
    +        """
    +        param = self._resolveParam(param)
    +        return param in self.paramMap
    +
    +    def hasDefault(self, param):
    +        """
    +        Checks whether a param has a default value.
    +        """
    +        param = self._resolveParam(param)
    +        return param in self.defaultParamMap
    +
    +    def isDefined(self, param):
    +        """
    +        Checks whether a param is explicitly set by user or has a default value.
    +        """
    +        return self.isSet(param) or self.hasDefault(param)
    +
    +    def getOrDefault(self, param):
    +        """
    +        Gets the value of a param in the user-supplied param map or its
    +        default value. Raises an error if either is set.
    +        """
    +        if isinstance(param, Param):
    +            if param in self.paramMap:
    +                return self.paramMap[param]
    +            else:
    +                return self.defaultParamMap[param]
    +        elif isinstance(param, str):
    +            return self.getOrDefault(self.getParam(param))
    +        else:
    +            raise KeyError("Cannot recognize %r as a param." % param)
    +
    +    def extractParamMap(self, extraParamMap={}):
    +        """
    +        Extracts the embedded default param values and user-supplied
    +        values, and then merges them with extra values from input into
    +        a flat param map, where the latter values is used if there
    +        exist conflicts, i.e., with ordering: default param values <
    +        user-supplied values < extraParamMap.
    +        :param extraParamMap: extra param values
    +        :return: merged param map
    +        """
    +        paramMap = self.defaultParamMap.copy()
    +        paramMap.update(self.paramMap)
    +        paramMap.update(extraParamMap)
             return paramMap
     
    +    def _shouldOwn(self, param):
    +        """
    +        Validates that the input param belongs to this Params instance.
    +        """
    +        if param.parent is not self:
    +            raise ValueError("Param %r does not belong to %r." % (param, self))
    +
    +    def _resolveParam(self, param):
    +        """
    +        Resolves a param and validates the ownership.
    +        :param param: param name or the param instance, which must
    +                      belongs to this Params instance
    --- End diff --
    
    Typo: "belongs" should be "belong" (i.e. "which must belong to this Params instance").


---
If your project is set up for it, you can reply to this email and have your
reply appear on GitHub as well. If your project does not have this feature
enabled and would like it enabled, or if the feature is enabled but not working, please
contact the infrastructure team at infrastructure@apache.org or file a JIRA ticket
with INFRA.
---

---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscribe@spark.apache.org
For additional commands, e-mail: reviews-help@spark.apache.org


Mime
View raw message