test_pie/external/MRF/TRW-S.h

87 lines
2.7 KiB
C
Raw Permalink Normal View History

2023-09-14 11:12:02 +02:00
#ifndef __TRWS_H__
#define __TRWS_H__
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <assert.h>
#include "mrf.h"
// MRF optimizer deriving from the MRF base class declared in "mrf.h".
// NOTE(review): the name suggests the TRW-S (sequential tree-reweighted
// message passing) algorithm -- confirm against the implementation file.
class TRWS : public MRF
{
public:
    // Floating-point type used for the message array and the lower bound.
    typedef double REAL;

    // NOTE(review): presumably the first form builds a width x height grid
    // problem and the second a problem over nPixels sites -- the definitions
    // are not visible here, confirm in the implementation.
    TRWS(int width, int height, int nLabels, EnergyFunction *eng);
    TRWS(int nPixels, int nLabels, EnergyFunction *eng);
    ~TRWS();

    // Unsupported by this solver: prints a message and terminates the
    // process with exit status 1.
    void setNeighbors(int /*pix1*/, int /*pix2*/, CostVal /*weight*/)
    {
        printf("Not implemented");
        exit(1);
    }

    // Returns the label currently stored for `pixel` (no bounds check).
    Label getLabel(int pixel)
    {
        return m_answer[pixel];
    }

    // Stores `label` as the current label of `pixel` (no bounds check).
    void setLabel(int pixel, Label label)
    {
        m_answer[pixel] = label;
    }

    // Returns the internal label array itself, not a copy.
    Label *getAnswerPtr()
    {
        return m_answer;
    }

    void clearAnswer();

    // This solver has no optional parameters: prints a message and
    // terminates the process with exit status 1.
    void setParameters(int /*numParam*/, void * /*param*/)
    {
        printf("No optional parameters to set");
        exit(1);
    }

    EnergyVal smoothnessEnergy();
    EnergyVal dataEnergy();

    // Returns the stored lower bound (m_lowerBound) as a double.
    double lowerBound()
    {
        return static_cast<double>(m_lowerBound);
    }

    // For general smoothness functions, this code tries to cache all function
    // values in an array for efficiency. Call this before initialize() to
    // switch that caching off.
    // NOTE(review): the flag is not declared in this class; presumably it is
    // inherited from MRF -- confirm.
    void dontCacheSmoothnessCosts()
    {
        m_allocateArrayForSmoothnessCostFn = false;
    }

protected:
    // NOTE(review): these look like the hooks the MRF base class calls to
    // hand over the energy terms and drive the solver -- confirm in mrf.h.
    void setData(DataCostFn dcost);
    void setData(CostVal *data);
    void setSmoothness(SmoothCostGeneralFn cost);
    void setSmoothness(CostVal *V);
    void setSmoothness(int smoothExp, CostVal smoothMax, CostVal lambda);
    void setCues(CostVal *hCue, CostVal *vCue);
    void Allocate();
    void initializeAlg();
    void optimizeAlg(int nIterations);

private:
    // Kind of smoothness term in use; there is one optimize_GRID_* routine
    // below for each value other than NONE.
    enum
    {
        NONE,
        L1,
        L2,
        FIXED_MATRIX,
        GENERAL,
        BINARY
    } m_type;

    // The next two are used only if m_type == L1 or m_type == L2.
    CostVal m_smoothMax;
    CostVal m_lambda;

    Label *m_answer;

    // Points to an array of size nLabels^2 (if type == FIXED_MATRIX) or of
    // size nEdges*nLabels^2 (if type == GENERAL).
    CostVal *m_V;
    CostVal *m_D;
    CostVal *m_DBinary; // valid if type == BINARY
    CostVal *m_horzWeights;
    CostVal *m_vertWeights;
    CostVal *m_horzWeightsBinary;
    CostVal *m_vertWeightsBinary;

    DataCostFn m_dataFn;
    SmoothCostGeneralFn m_smoothFn;

    bool m_needToFreeV;
    bool m_needToFreeD;

    // Message storage. Size of one message: N = 1 if m_type == BINARY,
    // N = K otherwise (K is presumably the number of labels -- confirm).
    //   message for edge (x,y)-(x+1,y): m_messages + (2*x + 2*y*m_width)*N
    //   message for edge (x,y)-(x,y+1): m_messages + (2*x + 2*y*m_width + 1)*N
    REAL *m_messages;
    int m_messageArraySizeInBytes;

    REAL m_lowerBound;

    // One routine per smoothness type.
    void optimize_GRID_L1(int nIterations);
    void optimize_GRID_L2(int nIterations);
    void optimize_GRID_FIXED_MATRIX(int nIterations);
    void optimize_GRID_GENERAL(int nIterations);
    void optimize_GRID_BINARY(int nIterations);
};
#endif /* __TRWS_H__ */