| /* |
| * Copyright (C) 2012 The Android Open Source Project |
| * |
| * Licensed under the Apache License, Version 2.0 (the "License"); |
| * you may not use this file except in compliance with the License. |
| * You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, software |
| * distributed under the License is distributed on an "AS IS" BASIS, |
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| * See the License for the specific language governing permissions and |
| * limitations under the License. |
| */ |
| |
| package android.bordeaux.services; |
| import android.util.Log; |
| |
| import android.bordeaux.learning.StochasticLinearRanker; |
| import java.util.HashMap; |
| import java.util.Map; |
| import java.io.Serializable; |
| |
| public class StochasticLinearRankerWithPrior extends StochasticLinearRanker { |
| private final String TAG = "StochasticLinearRankerWithPrior"; |
| private final float EPSILON = 0.0001f; |
| |
| /* If the is parameter is true, the final score would be a |
| linear combination of user model and prior model */ |
| private final String USE_PRIOR = "usePriorInformation"; |
| |
| /* When prior model is used, this parmaeter will set the mixing factor, alpha. */ |
| private final String SET_ALPHA = "setAlpha"; |
| |
| /* When prior model is used, If this parameter is true then algorithm will use |
| the automatic cross validated alpha for mixing user model and prior model */ |
| private final String USE_AUTO_ALPHA = "useAutoAlpha"; |
| |
| /* When automatic cross validation is active, this parameter will |
| set the forget rate in cross validation. */ |
| private final String SET_FORGET_RATE = "setForgetRate"; |
| |
| /* When automatic cross validation is active, this parameter will |
| set the minium number of required training pairs before using the user model */ |
| private final String SET_MIN_TRAIN_PAIR = "setMinTrainingPair"; |
| |
| private final String SET_USER_PERF = "setUserPerformance"; |
| private final String SET_PRIOR_PERF = "setPriorPerformance"; |
| private final String SET_NUM_TRAIN_PAIR = "setNumberTrainingPairs"; |
| private final String SET_AUTO_ALPHA = "setAutoAlpha"; |
| |
| |
| |
| private HashMap<String, Float> mPriorWeights = new HashMap<String, Float>(); |
| private float mAlpha = 0; |
| private float mAutoAlpha = 0; |
| private float mForgetRate = 0; |
| private float mUserRankerPerf = 0; |
| private float mPriorRankerPerf = 0; |
| private int mMinReqTrainingPair = 0; |
| private int mNumTrainPair = 0; |
| private boolean mUsePrior = false; |
| private boolean mUseAutoAlpha = false; |
| |
| static public class Model implements Serializable { |
| public StochasticLinearRanker.Model uModel = new StochasticLinearRanker.Model(); |
| public HashMap<String, Float> priorWeights = new HashMap<String, Float>(); |
| public HashMap<String, String> priorParameters = new HashMap<String, String>(); |
| } |
| |
| @Override |
| public void resetRanker(){ |
| super.resetRanker(); |
| mPriorWeights.clear(); |
| mAlpha = 0; |
| mAutoAlpha = 0; |
| mForgetRate = 0; |
| mMinReqTrainingPair = 0; |
| mUserRankerPerf = 0; |
| mPriorRankerPerf = 0; |
| mNumTrainPair = 0; |
| mUsePrior = false; |
| mUseAutoAlpha = false; |
| } |
| |
| @Override |
| public float scoreSample(String[] keys, float[] values) { |
| if (!mUsePrior){ |
| return super.scoreSample(keys, values); |
| } else { |
| if (mUseAutoAlpha) { |
| if (mNumTrainPair > mMinReqTrainingPair) |
| return (1 - mAutoAlpha) * super.scoreSample(keys,values) + |
| mAutoAlpha * priorScoreSample(keys,values); |
| else |
| return priorScoreSample(keys,values); |
| } else |
| return (1 - mAlpha) * super.scoreSample(keys,values) + |
| mAlpha * priorScoreSample(keys,values); |
| } |
| } |
| |
| public float priorScoreSample(String[] keys, float[] values) { |
| float score = 0; |
| for (int i=0; i< keys.length; i++){ |
| if (mPriorWeights.get(keys[i]) != null ) |
| score = score + mPriorWeights.get(keys[i]) * values[i]; |
| } |
| return score; |
| } |
| |
| @Override |
| public boolean updateClassifier(String[] keys_positive, |
| float[] values_positive, |
| String[] keys_negative, |
| float[] values_negative){ |
| if (mUsePrior && mUseAutoAlpha && (mNumTrainPair > mMinReqTrainingPair)) |
| updateAutoAlpha(keys_positive, values_positive, keys_negative, values_negative); |
| mNumTrainPair ++; |
| return super.updateClassifier(keys_positive, values_positive, |
| keys_negative, values_negative); |
| } |
| |
| void updateAutoAlpha(String[] keys_positive, |
| float[] values_positive, |
| String[] keys_negative, |
| float[] values_negative) { |
| float positiveUserScore = super.scoreSample(keys_positive, values_positive); |
| float negativeUserScore = super.scoreSample(keys_negative, values_negative); |
| float positivePriorScore = priorScoreSample(keys_positive, values_positive); |
| float negativePriorScore = priorScoreSample(keys_negative, values_negative); |
| float userDecision = 0; |
| float priorDecision = 0; |
| if (positiveUserScore > negativeUserScore) |
| userDecision = 1; |
| if (positivePriorScore > negativePriorScore) |
| priorDecision = 1; |
| mUserRankerPerf = (1 - mForgetRate) * mUserRankerPerf + userDecision; |
| mPriorRankerPerf = (1 - mForgetRate) * mPriorRankerPerf + priorDecision; |
| mAutoAlpha = (mPriorRankerPerf + EPSILON) / (mUserRankerPerf + mPriorRankerPerf + EPSILON); |
| } |
| |
| public Model getModel(){ |
| Model m = new Model(); |
| m.uModel = super.getUModel(); |
| m.priorWeights.putAll(mPriorWeights); |
| m.priorParameters.put(SET_ALPHA, String.valueOf(mAlpha)); |
| m.priorParameters.put(SET_AUTO_ALPHA, String.valueOf(mAutoAlpha)); |
| m.priorParameters.put(SET_FORGET_RATE, String.valueOf(mForgetRate)); |
| m.priorParameters.put(SET_MIN_TRAIN_PAIR, String.valueOf(mMinReqTrainingPair)); |
| m.priorParameters.put(SET_USER_PERF, String.valueOf(mUserRankerPerf)); |
| m.priorParameters.put(SET_PRIOR_PERF, String.valueOf(mPriorRankerPerf)); |
| m.priorParameters.put(SET_NUM_TRAIN_PAIR, String.valueOf(mNumTrainPair)); |
| m.priorParameters.put(USE_AUTO_ALPHA, String.valueOf(mUseAutoAlpha)); |
| m.priorParameters.put(USE_PRIOR, String.valueOf(mUsePrior)); |
| return m; |
| } |
| |
| public boolean loadModel(Model m) { |
| mPriorWeights.clear(); |
| mPriorWeights.putAll(m.priorWeights); |
| for (Map.Entry<String, String> e : m.priorParameters.entrySet()) { |
| boolean res = setModelParameter(e.getKey(), e.getValue()); |
| if (!res) return false; |
| } |
| return super.loadModel(m.uModel); |
| } |
| |
| public boolean setModelPriorWeights(HashMap<String, Float> pw){ |
| mPriorWeights.clear(); |
| mPriorWeights.putAll(pw); |
| return true; |
| } |
| |
| public boolean setModelParameter(String key, String value){ |
| if (key.equals(USE_AUTO_ALPHA)){ |
| mUseAutoAlpha = Boolean.parseBoolean(value); |
| } else if (key.equals(USE_PRIOR)){ |
| mUsePrior = Boolean.parseBoolean(value); |
| } else if (key.equals(SET_ALPHA)){ |
| mAlpha = Float.valueOf(value.trim()).floatValue(); |
| }else if (key.equals(SET_AUTO_ALPHA)){ |
| mAutoAlpha = Float.valueOf(value.trim()).floatValue(); |
| }else if (key.equals(SET_FORGET_RATE)){ |
| mForgetRate = Float.valueOf(value.trim()).floatValue(); |
| }else if (key.equals(SET_MIN_TRAIN_PAIR)){ |
| mMinReqTrainingPair = (int) Float.valueOf(value.trim()).floatValue(); |
| }else if (key.equals(SET_USER_PERF)){ |
| mUserRankerPerf = Float.valueOf(value.trim()).floatValue(); |
| }else if (key.equals(SET_PRIOR_PERF)){ |
| mPriorRankerPerf = Float.valueOf(value.trim()).floatValue(); |
| }else if (key.equals(SET_NUM_TRAIN_PAIR)){ |
| mNumTrainPair = (int) Float.valueOf(value.trim()).floatValue(); |
| }else |
| return super.setModelParameter(key, value); |
| return true; |
| } |
| |
| public void print(Model m){ |
| super.print(m.uModel); |
| String Spw = ""; |
| for (Map.Entry<String, Float> e : m.priorWeights.entrySet()) |
| Spw = Spw + "<" + e.getKey() + "," + e.getValue() + "> "; |
| Log.i(TAG, "Prior model is " + Spw); |
| String Spp = ""; |
| for (Map.Entry<String, String> e : m.priorParameters.entrySet()) |
| Spp = Spp + "<" + e.getKey() + "," + e.getValue() + "> "; |
| Log.i(TAG, "Prior parameters are " + Spp); |
| } |
| } |